Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions proxy/server/target.go
Original file line number Diff line number Diff line change
Expand Up @@ -178,12 +178,19 @@ func (s *TargetStream) CloseWith(err error) {
// Send enqueues the supplied request on reqChan for delivery to the target
// stream.
//
// It returns nil once req has been accepted by reqChan. If either s.ctx or
// the current stream's context is done, it returns that context's error
// instead, both when the context was already cancelled on entry and when it
// is cancelled while Send is blocked on a full reqChan.
func (s *TargetStream) Send(req proto.Message) error {
	// Capture the stream context exactly once so that the Done channel we
	// wait on and the Err we report always belong to the same context, even
	// if the underlying stream is replaced while we are blocked.
	streamCtx := s.getStream().Context()

	// A select with several ready cases picks one at random, so without this
	// check a request could still be enqueued on a stream that is already
	// cancelled whenever reqChan has spare capacity.
	if err := s.ctx.Err(); err != nil {
		return err
	}
	if err := streamCtx.Err(); err != nil {
		return err
	}

	select {
	case s.reqChan <- req:
		return nil
	// s.ctx is cancelled by our cancelFunc — e.g. when the target dies
	// mid-stream (SendMsg returns EOF) or when NewStream fails. Without
	// this, Send would keep enqueueing data into a dead reqChan.
	case <-s.ctx.Done():
		return s.ctx.Err()
	// The stream's context is cancelled by gRPC on transport-level
	// failures. After setStream() it is a different object from s.ctx,
	// so we need both to react to whichever fires first.
	case <-streamCtx.Done():
		return streamCtx.Err()
	}
}

Expand Down Expand Up @@ -221,6 +228,7 @@ func (s *TargetStream) Run(nonce uint32, replyChan chan *pb.ProxyReply) {
if err != nil {
// We cannot create a new stream to the target. So we need to cancel this stream.
s.logger.Info("unable to create stream", "status", err)
s.cancelFunc()
return fmt.Errorf("could not connect to target from the proxy: %w", err)
}

Expand Down Expand Up @@ -345,6 +353,7 @@ func (s *TargetStream) Run(nonce uint32, replyChan chan *pb.ProxyReply) {
// can return nil, and the error will be picked
// up by the receiving goroutine
if err == io.EOF {
s.cancelFunc()
return nil
}
// Otherwise, this is the 'final' error. The underlying
Expand Down
200 changes: 199 additions & 1 deletion proxy/server/target_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,16 +20,20 @@ import (
"context"
"errors"
"fmt"
"io"
"sync"
"sync/atomic"
"testing"
"time"

"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/anypb"

pb "github.com/Snowflake-Labs/sansshell/proxy"
_ "github.com/Snowflake-Labs/sansshell/proxy/testdata"
td "github.com/Snowflake-Labs/sansshell/proxy/testdata"
"github.com/Snowflake-Labs/sansshell/testing/testutil"
)

Expand Down Expand Up @@ -188,6 +192,200 @@ func TestTargetStreamAddNonBlocking(t *testing.T) {
}
}

// TestSendUnblocksWhenTargetUnreachable verifies that TargetStream.Send does
// not hang when the target is unreachable and the reqChan buffer is full.
// When NewStream fails, Run must call cancelFunc so that Send unblocks
// instead of waiting forever on the full reqChan.
func TestSendUnblocksWhenTargetUnreachable(t *testing.T) {
// Arrange: create a stream set with a dialer whose NewStream blocks
// until the context is cancelled (simulating an unreachable target).
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

serviceMap := LoadGlobalServiceMap()
ss := NewTargetStreamSet(serviceMap, blockingClientDialer{}, nil)

replyChan := make(chan *pb.ProxyReply, 100)
doneChan := make(chan uint64, 1)

req := &pb.StartStream{
Target: "unreachable:9500",
Nonce: 42,
MethodName: "/Testdata.TestService/TestClientStream",
}
if err := ss.Add(ctx, req, replyChan, doneChan); err != nil {
t.Fatalf("Add: %v", err)
}

var streamID uint64
select {
case msg := <-replyChan:
sid := msg.GetStartStreamReply().GetStreamId()
if sid == 0 {
t.Fatalf("expected stream ID, got: %+v", msg)
}
streamID = sid
case <-time.After(2 * time.Second):
t.Fatal("no reply from Add")
}

payload, err := anypb.New(&td.TestRequest{Input: "chunk"})
if err != nil {
t.Fatal(err)
}
data := &pb.StreamData{
StreamIds: []uint64{streamID},
Payload: payload,
}

for i := 0; i < ReqBufferSize; i++ {
if err := ss.Send(ctx, data); err != nil {
t.Fatalf("Send[%d]: %v", i, err)
}
}

// Act: attempt one more Send (reqChan is full, so it blocks),
// then cancel the context to simulate connection failure.
sendDone := make(chan error, 1)
go func() {
sendDone <- ss.Send(ctx, data)
}()

select {
case err := <-sendDone:
t.Fatalf("Send returned immediately (want block): %v", err)
case <-time.After(100 * time.Millisecond):
}

cancel()

// Assert: Send must unblock promptly with a context error.
select {
case err := <-sendDone:
if err == nil {
t.Fatal("Send returned nil after cancel, want context error")
}
case <-time.After(2 * time.Second):
t.Fatal("Send still blocked 2s after context cancel — stall not fixed")
}
}

// dyingClientStream is a fake gRPC client stream for a target that dies
// partway through a call: SendMsg succeeds for the first sendLimit calls and
// returns io.EOF afterwards, while RecvMsg always fails with an Unavailable
// status, as a broken connection would.
type dyingClientStream struct {
	ctx context.Context // handed back unchanged by Context

	mu        sync.Mutex // guards sent
	sent      int        // number of SendMsg calls accepted so far
	sendLimit int        // SendMsg returns io.EOF once sent reaches this value
}

// Header returns no metadata and no error.
func (d *dyingClientStream) Header() (metadata.MD, error) {
	return nil, nil
}

// Trailer returns no metadata.
func (d *dyingClientStream) Trailer() metadata.MD {
	return nil
}

// CloseSend is a no-op that always succeeds.
func (d *dyingClientStream) CloseSend() error {
	return nil
}

// Context returns the context the fake was constructed with.
func (d *dyingClientStream) Context() context.Context {
	return d.ctx
}

// SendMsg accepts messages until sendLimit have been counted, then reports
// io.EOF on every later call. The message itself is ignored.
func (d *dyingClientStream) SendMsg(interface{}) error {
	d.mu.Lock()
	defer d.mu.Unlock()

	if d.sent < d.sendLimit {
		d.sent++
		return nil
	}
	return io.EOF
}

// RecvMsg always fails with an Unavailable status error.
func (d *dyingClientStream) RecvMsg(interface{}) error {
	return status.Error(codes.Unavailable, "connection closed")
}

// dyingClientConn is a fake client connection whose NewStream always hands
// out the same preconfigured dyingClientStream.
type dyingClientConn struct {
	stream *dyingClientStream
}

// Invoke is not supported by this fake and always returns an Unimplemented
// status error.
func (c *dyingClientConn) Invoke(
	ctx context.Context,
	method string,
	args, reply interface{},
	opts ...grpc.CallOption,
) error {
	return status.Error(codes.Unimplemented, "not supported")
}

// NewStream ignores its arguments and returns the configured stream with a
// nil error.
func (c *dyingClientConn) NewStream(
	ctx context.Context,
	desc *grpc.StreamDesc,
	method string,
	opts ...grpc.CallOption,
) (grpc.ClientStream, error) {
	return c.stream, nil
}

// Close is a no-op that always succeeds.
func (c *dyingClientConn) Close() error {
	return nil
}

// dyingClientDialer is a fake dialer: every DialContext call succeeds and
// yields a dyingClientConn wrapping the configured stream.
type dyingClientDialer struct {
	stream *dyingClientStream
}

// DialContext ignores the target and dial options and returns a new
// dyingClientConn that serves d.stream.
func (d *dyingClientDialer) DialContext(
	ctx context.Context,
	target string,
	dialOpts ...grpc.DialOption,
) (ClientConnCloser, error) {
	conn := &dyingClientConn{stream: d.stream}
	return conn, nil
}

// TestSendFailsAfterTargetDiesMidStream verifies that when the target dies
// mid-stream (SendMsg returns io.EOF), subsequent calls to
// TargetStream.Send return an error instead of silently enqueueing data.
func TestSendFailsAfterTargetDiesMidStream(t *testing.T) {
// Arrange: a stream that accepts 2 messages then returns io.EOF on SendMsg.
ctx, cancel := context.WithCancel(context.Background())
defer cancel()

stream := &dyingClientStream{ctx: ctx, sendLimit: 2}
dialer := &dyingClientDialer{stream: stream}

serviceMap := LoadGlobalServiceMap()
ss := NewTargetStreamSet(serviceMap, dialer, nil)

replyChan := make(chan *pb.ProxyReply, 100)
doneChan := make(chan uint64, 1)

req := &pb.StartStream{
Target: "dying-target:9500",
Nonce: 99,
MethodName: "/Testdata.TestService/TestClientStream",
}
if err := ss.Add(ctx, req, replyChan, doneChan); err != nil {
t.Fatalf("Add: %v", err)
}

var streamID uint64
select {
case msg := <-replyChan:
sid := msg.GetStartStreamReply().GetStreamId()
if sid == 0 {
t.Fatalf("expected stream ID, got: %+v", msg)
}
streamID = sid
case <-time.After(2 * time.Second):
t.Fatal("no reply from Add")
}

payload, err := anypb.New(&td.TestRequest{Input: "chunk"})
if err != nil {
t.Fatal(err)
}
data := &pb.StreamData{
StreamIds: []uint64{streamID},
Payload: payload,
}

// Act: send messages until Send fails. The stream accepts 2 via
// SendMsg, then returns io.EOF. The send loop should call cancelFunc,
// causing subsequent Send() calls to fail with a context error.
var sendErr atomic.Value
for i := 0; i < 20; i++ {
if err := ss.Send(ctx, data); err != nil {
sendErr.Store(err)
break
}
time.Sleep(10 * time.Millisecond)
}

// Assert: Send must have returned an error.
stored := sendErr.Load()
if stored == nil {
t.Fatal("Send never returned error after target died — stall not fixed")
}
t.Logf("Send failed with: %v (good)", stored)
}

func TestIsCardinalityViolation(t *testing.T) {
for _, tc := range []struct {
name string
Expand Down
12 changes: 11 additions & 1 deletion services/localfile/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -1232,7 +1232,14 @@ func (p *cpCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{
}
}
progress := progressbar.DefaultBytes(fileSize)
defer progress.Close()
sendFailed := false
defer func() {
if sendFailed {
_ = progress.Clear()
return
}
_ = progress.Close()
}()

buf := make([]byte, util.StreamingChunkSize)
for {
Expand All @@ -1248,6 +1255,7 @@ func (p *cpCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{
},
}
if err := stream.Send(req); err != nil {
sendFailed = true
// Emit this to every error file as it's not specific to a given target.
for _, e := range state.Err {
fmt.Fprintf(e, "All targets - error sending on stream - %v\n", err)
Expand All @@ -1258,6 +1266,7 @@ func (p *cpCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{
}
resp, err := stream.CloseAndRecv()
if err != nil && err != io.EOF {
sendFailed = true
// Emit this to every error file as it's not specific to a given target.
for _, e := range state.Err {
fmt.Fprintf(e, "All targets - error closing stream - %v\n", err)
Expand All @@ -1274,6 +1283,7 @@ func (p *cpCmd) Execute(ctx context.Context, f *flag.FlagSet, args ...interface{
retCode = subcommands.ExitFailure
}
}
sendFailed = retCode != subcommands.ExitSuccess
return retCode
}

Expand Down
Loading