diff --git a/proxy/server/target.go b/proxy/server/target.go index b85e2f8a..6f97ca5d 100644 --- a/proxy/server/target.go +++ b/proxy/server/target.go @@ -178,12 +178,19 @@ func (s *TargetStream) CloseWith(err error) { // Send the supplied request to the target stream, returning // an error if the context has already been cancelled. func (s *TargetStream) Send(req proto.Message) error { - ctx := s.getStream().Context() select { case s.reqChan <- req: return nil - case <-ctx.Done(): - return ctx.Err() + // s.ctx is cancelled by our cancelFunc — e.g. when the target dies + // mid-stream (SendMsg returns EOF) or when NewStream fails. Without + // this, Send would keep enqueueing data into a dead reqChan. + case <-s.ctx.Done(): + return s.ctx.Err() + // getStream().Context() is cancelled by gRPC on transport-level + // failures. After setStream() it is a different object from s.ctx, + // so we need both to react to whichever fires first. + case <-s.getStream().Context().Done(): + return s.getStream().Context().Err() } } @@ -221,6 +228,7 @@ func (s *TargetStream) Run(nonce uint32, replyChan chan *pb.ProxyReply) { if err != nil { // We cannot create a new stream to the target. So we need to cancel this stream. s.logger.Info("unable to create stream", "status", err) + s.cancelFunc() return fmt.Errorf("could not connect to target from the proxy: %w", err) } @@ -345,6 +353,7 @@ func (s *TargetStream) Run(nonce uint32, replyChan chan *pb.ProxyReply) { // can return nil, and the error will be picked // up by the receiving goroutine if err == io.EOF { + s.cancelFunc() return nil } // Otherwise, this is the 'final' error. The underlying diff --git a/proxy/server/target_test.go b/proxy/server/target_test.go index 61cc7805..5e6332d3 100644 --- a/proxy/server/target_test.go +++ b/proxy/server/target_test.go @@ -20,16 +20,20 @@ import ( "context" "errors" "fmt" + "io" + "sync" + "sync/atomic" "testing" "time" "google.golang.org/grpc" "google.golang.org/grpc/codes" + "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/anypb" pb "github.com/Snowflake-Labs/sansshell/proxy" - _ "github.com/Snowflake-Labs/sansshell/proxy/testdata" + td "github.com/Snowflake-Labs/sansshell/proxy/testdata" "github.com/Snowflake-Labs/sansshell/testing/testutil" ) @@ -188,6 +192,200 @@ func TestTargetStreamAddNonBlocking(t *testing.T) { } } +// TestSendUnblocksWhenTargetUnreachable verifies that TargetStream.Send does +// not hang when the target is unreachable and the reqChan buffer is full. +// When NewStream fails, Run must call cancelFunc so that Send unblocks +// instead of waiting forever on the full reqChan. +func TestSendUnblocksWhenTargetUnreachable(t *testing.T) { + // Arrange: create a stream set with a dialer whose NewStream blocks + // until the context is cancelled (simulating an unreachable target). + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + serviceMap := LoadGlobalServiceMap() + ss := NewTargetStreamSet(serviceMap, blockingClientDialer{}, nil) + + replyChan := make(chan *pb.ProxyReply, 100) + doneChan := make(chan uint64, 1) + + req := &pb.StartStream{ + Target: "unreachable:9500", + Nonce: 42, + MethodName: "/Testdata.TestService/TestClientStream", + } + if err := ss.Add(ctx, req, replyChan, doneChan); err != nil { + t.Fatalf("Add: %v", err) + } + + var streamID uint64 + select { + case msg := <-replyChan: + sid := msg.GetStartStreamReply().GetStreamId() + if sid == 0 { + t.Fatalf("expected stream ID, got: %+v", msg) + } + streamID = sid + case <-time.After(2 * time.Second): + t.Fatal("no reply from Add") + } + + payload, err := anypb.New(&td.TestRequest{Input: "chunk"}) + if err != nil { + t.Fatal(err) + } + data := &pb.StreamData{ + StreamIds: []uint64{streamID}, + Payload: payload, + } + + for i := 0; i < ReqBufferSize; i++ { + if err := ss.Send(ctx, data); err != nil { + t.Fatalf("Send[%d]: %v", i, err) + } + } + + // Act: attempt one more Send (reqChan is full, so it blocks), + // then cancel the context to simulate connection failure. + sendDone := make(chan error, 1) + go func() { + sendDone <- ss.Send(ctx, data) + }() + + select { + case err := <-sendDone: + t.Fatalf("Send returned immediately (want block): %v", err) + case <-time.After(100 * time.Millisecond): + } + + cancel() + + // Assert: Send must unblock promptly with a context error. + select { + case err := <-sendDone: + if err == nil { + t.Fatal("Send returned nil after cancel, want context error") + } + case <-time.After(2 * time.Second): + t.Fatal("Send still blocked 2s after context cancel — stall not fixed") + } +} + +// dyingClientStream simulates a gRPC stream to a target that dies +// after accepting sendLimit messages. SendMsg returns io.EOF once the +// limit is reached, and RecvMsg returns an Unavailable error (as a real +// broken connection would). +type dyingClientStream struct { + ctx context.Context + mu sync.Mutex + sent int + sendLimit int +} + +func (d *dyingClientStream) Header() (metadata.MD, error) { return nil, nil } +func (d *dyingClientStream) Trailer() metadata.MD { return nil } +func (d *dyingClientStream) CloseSend() error { return nil } +func (d *dyingClientStream) Context() context.Context { return d.ctx } +func (d *dyingClientStream) SendMsg(interface{}) error { + d.mu.Lock() + defer d.mu.Unlock() + if d.sent >= d.sendLimit { + return io.EOF + } + d.sent++ + return nil +} +func (d *dyingClientStream) RecvMsg(interface{}) error { + return status.Error(codes.Unavailable, "connection closed") +} + +// dyingClientConn returns a dyingClientStream from NewStream. +type dyingClientConn struct { + stream *dyingClientStream +} + +func (c *dyingClientConn) Invoke(ctx context.Context, method string, args, reply interface{}, opts ...grpc.CallOption) error { + return status.Error(codes.Unimplemented, "not supported") +} +func (c *dyingClientConn) NewStream(ctx context.Context, desc *grpc.StreamDesc, method string, opts ...grpc.CallOption) (grpc.ClientStream, error) { + return c.stream, nil +} +func (c *dyingClientConn) Close() error { return nil } + +type dyingClientDialer struct { + stream *dyingClientStream +} + +func (d *dyingClientDialer) DialContext(ctx context.Context, target string, dialOpts ...grpc.DialOption) (ClientConnCloser, error) { + return &dyingClientConn{stream: d.stream}, nil +} + +// TestSendFailsAfterTargetDiesMidStream verifies that when the target dies +// mid-stream (SendMsg returns io.EOF), subsequent calls to +// TargetStream.Send return an error instead of silently enqueueing data. +func TestSendFailsAfterTargetDiesMidStream(t *testing.T) { + // Arrange: a stream that accepts 2 messages then returns io.EOF on SendMsg. + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + stream := &dyingClientStream{ctx: ctx, sendLimit: 2} + dialer := &dyingClientDialer{stream: stream} + + serviceMap := LoadGlobalServiceMap() + ss := NewTargetStreamSet(serviceMap, dialer, nil) + + replyChan := make(chan *pb.ProxyReply, 100) + doneChan := make(chan uint64, 1) + + req := &pb.StartStream{ + Target: "dying-target:9500", + Nonce: 99, + MethodName: "/Testdata.TestService/TestClientStream", + } + if err := ss.Add(ctx, req, replyChan, doneChan); err != nil { + t.Fatalf("Add: %v", err) + } + + var streamID uint64 + select { + case msg := <-replyChan: + sid := msg.GetStartStreamReply().GetStreamId() + if sid == 0 { + t.Fatalf("expected stream ID, got: %+v", msg) + } + streamID = sid + case <-time.After(2 * time.Second): + t.Fatal("no reply from Add") + } + + payload, err := anypb.New(&td.TestRequest{Input: "chunk"}) + if err != nil { + t.Fatal(err) + } + data := &pb.StreamData{ + StreamIds: []uint64{streamID}, + Payload: payload, + } + + // Act: send messages until Send fails. The stream accepts 2 via + // SendMsg, then returns io.EOF. The send loop should call cancelFunc, + // causing subsequent Send() calls to fail with a context error. + var sendErr atomic.Value + for i := 0; i < 20; i++ { + if err := ss.Send(ctx, data); err != nil { + sendErr.Store(err) + break + } + time.Sleep(10 * time.Millisecond) + } + + // Assert: Send must have returned an error. + stored := sendErr.Load() + if stored == nil { + t.Fatal("Send never returned error after target died — stall not fixed") + } + t.Logf("Send failed with: %v (good)", stored) +} + func TestIsCardinalityViolation(t *testing.T) { for _, tc := range []struct { name string diff --git a/services/localfile/client/client.go b/services/localfile/client/client.go index 3131c42f..017a65b1 100644 --- a/services/localfile/client/client.go +++ b/services/localfile/client/client.go @@ -1232,7 +1232,14 @@ func (p *cpCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{ } } progress := progressbar.DefaultBytes(fileSize) - defer progress.Close() + sendFailed := false + defer func() { + if sendFailed { + _ = progress.Clear() + return + } + _ = progress.Close() + }() buf := make([]byte, util.StreamingChunkSize) for { @@ -1248,6 +1255,7 @@ func (p *cpCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{ }, } if err := stream.Send(req); err != nil { + sendFailed = true // Emit this to every error file as it's not specific to a given target. for _, e := range state.Err { fmt.Fprintf(e, "All targets - error sending on stream - %v\n", err) @@ -1258,6 +1266,7 @@ func (p *cpCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{ } resp, err := stream.CloseAndRecv() if err != nil && err != io.EOF { + sendFailed = true // Emit this to every error file as it's not specific to a given target. for _, e := range state.Err { fmt.Fprintf(e, "All targets - error closing stream - %v\n", err) @@ -1274,6 +1283,7 @@ func (p *cpCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{ retCode = subcommands.ExitFailure } } + sendFailed = retCode != subcommands.ExitSuccess return retCode }