diff --git a/apps/ingest/internal/db/queries/terminal_sessions.sql.go b/apps/ingest/internal/db/queries/terminal_sessions.sql.go index 1451feb5..10acd363 100644 --- a/apps/ingest/internal/db/queries/terminal_sessions.sql.go +++ b/apps/ingest/internal/db/queries/terminal_sessions.sql.go @@ -52,25 +52,55 @@ func ValidateAndActivateTerminalSession(ctx context.Context, pool *pgxpool.Pool, AND h.deleted_at IS NULL RETURNING ts.instance_id, ts.host_id, ts.user_id, - COALESCE(NULLIF(h.hostname, ''), h.ip_addresses->>0) AS host, + COALESCE(h.hostname, '') AS hostname, + COALESCE(h.ip_addresses, '[]'::jsonb)::text AS ip_addresses, COALESCE(ts.username, '') AS username, COALESCE((o.metadata->>'terminalLoggingEnabled')::boolean, false) AS logging_enabled ` var info TerminalSessionInfo + var hostname string + var rawIPAddresses string err := pool.QueryRow(ctx, q, sessionID, tokenHash).Scan( &info.InstanceID, &info.HostID, &info.UserID, - &info.Host, + &hostname, + &rawIPAddresses, &info.Username, &info.LoggingEnabled, ) if err != nil { return nil, fmt.Errorf("terminal session not found or expired: %w", err) } + info.Host = terminalSSHTarget(hostname, terminalIPAddressesFromJSON(rawIPAddresses)) return &info, nil } +func terminalIPAddressesFromJSON(raw string) []string { + var ips []string + if err := json.Unmarshal([]byte(raw), &ips); err != nil { + return nil + } + return ips +} + +func terminalSSHTarget(hostname string, ipAddresses []string) string { + if useful := FilterHostIdentityIPs(ipAddresses); len(useful) > 0 { + return useful[0] + } + + if hostname = strings.TrimSpace(hostname); hostname != "" { + return hostname + } + + for _, ip := range ipAddresses { + if ip = strings.TrimSpace(ip); ip != "" { + return ip + } + } + return "" +} + const ( terminalAuthWindow = 15 * time.Minute terminalAuthMaxFailures = 5 diff --git a/apps/ingest/internal/db/queries/terminal_sessions_target_test.go b/apps/ingest/internal/db/queries/terminal_sessions_target_test.go new file mode 100644 index 00000000..67b6a3b8 --- /dev/null +++ b/apps/ingest/internal/db/queries/terminal_sessions_target_test.go @@ -0,0 +1,24 @@ +package queries + +import "testing" + +func TestTerminalSSHTargetPrefersUsefulIPAddress(t *testing.T) { + got := terminalSSHTarget("ct-ops", []string{"172.17.0.1", "192.168.1.42"}) + if got != "192.168.1.42" { + t.Fatalf("terminalSSHTarget() = %q, want first useful IP", got) + } +} + +func TestTerminalSSHTargetFallsBackToHostname(t *testing.T) { + got := terminalSSHTarget("ct-ops", []string{"127.0.0.1", "172.17.0.1"}) + if got != "ct-ops" { + t.Fatalf("terminalSSHTarget() = %q, want hostname fallback", got) + } +} + +func TestTerminalSSHTargetUsesFirstIPWhenHostnameMissing(t *testing.T) { + got := terminalSSHTarget("", []string{"127.0.0.1", "172.17.0.1"}) + if got != "127.0.0.1" { + t.Fatalf("terminalSSHTarget() = %q, want first recorded IP fallback", got) + } +} diff --git a/apps/ingest/internal/handlers/terminal_ws.go b/apps/ingest/internal/handlers/terminal_ws.go index 44e28e13..48e5f6a5 100644 --- a/apps/ingest/internal/handlers/terminal_ws.go +++ b/apps/ingest/internal/handlers/terminal_ws.go @@ -16,6 +16,7 @@ import ( "strconv" "strings" "sync" + "syscall" "time" "github.com/coder/websocket" @@ -378,6 +379,26 @@ func terminalSSHFailureDetails(err error) (reason string, message string) { if isSSHAuthenticationFailure(err) { return "ssh authentication failed", "SSH authentication failed" } + + var dnsErr *net.DNSError + if errors.As(err, &dnsErr) || strings.Contains(strings.ToLower(err.Error()), "no such host") { + return "ssh connection failed", "SSH connection failed: host name could not be resolved" + } + + var netErr net.Error + if errors.Is(err, context.DeadlineExceeded) || + errors.Is(err, syscall.ETIMEDOUT) || + (errors.As(err, &netErr) && netErr.Timeout()) { + return "ssh connection failed", "SSH connection failed: connection timed out" + } + + if errors.Is(err, syscall.ECONNREFUSED) { + return "ssh connection failed", "SSH connection failed: connection refused" + } + if errors.Is(err, syscall.EHOSTUNREACH) || errors.Is(err, syscall.ENETUNREACH) { + return "ssh connection failed", "SSH connection failed: host is unreachable" + } + return "ssh connection failed", "SSH connection failed" } diff --git a/apps/ingest/internal/handlers/terminal_ws_security_test.go b/apps/ingest/internal/handlers/terminal_ws_security_test.go index a38ae749..341bfbb2 100644 --- a/apps/ingest/internal/handlers/terminal_ws_security_test.go +++ b/apps/ingest/internal/handlers/terminal_ws_security_test.go @@ -1,12 +1,14 @@ package handlers import ( + "context" "errors" "os" "path/filepath" "reflect" "runtime" "strings" + "syscall" "testing" "golang.org/x/crypto/ssh" @@ -97,8 +99,18 @@ func TestTerminalSSHFailureDetails(t *testing.T) { } reason, message = terminalSSHFailureDetails(errors.New("dial tcp: lookup host.example.test: no such host")) - if reason != "ssh connection failed" || message != "SSH connection failed" { - t.Fatalf("terminalSSHFailureDetails(network) = %q, %q", reason, message) + if reason != "ssh connection failed" || message != "SSH connection failed: host name could not be resolved" { + t.Fatalf("terminalSSHFailureDetails(dns) = %q, %q", reason, message) + } + + reason, message = terminalSSHFailureDetails(syscall.ECONNREFUSED) + if reason != "ssh connection failed" || message != "SSH connection failed: connection refused" { + t.Fatalf("terminalSSHFailureDetails(refused) = %q, %q", reason, message) + } + + reason, message = terminalSSHFailureDetails(context.DeadlineExceeded) + if reason != "ssh connection failed" || message != "SSH connection failed: connection timed out" { + t.Fatalf("terminalSSHFailureDetails(timeout) = %q, %q", reason, message) } }