From 83a8f4357eb660f72bb720f899223b17b612fa76 Mon Sep 17 00:00:00 2001 From: simon Date: Thu, 30 Apr 2026 10:29:47 +0200 Subject: [PATCH 01/15] Extract Lakebase target resolver into shared libs/lakebase/target Move the Postgres autoscaling and provisioned target-resolution helpers out of cmd/psql/ into a shared package so a second consumer (the new experimental postgres query command, in a follow-up commit) can reuse the same SDK shapes. cmd/psql keeps its interactive UX by wrapping the shared AutoSelect* helpers with errors.As fallbacks on AmbiguousError. No behavior change for cmd/psql; existing acceptance tests pass. Co-authored-by: Isaac --- cmd/psql/psql.go | 61 +++--------- cmd/psql/psql_autoscaling.go | 121 +++++++---------------- cmd/psql/psql_provisioned.go | 46 +++------ cmd/psql/psql_test.go | 83 ---------------- libs/lakebase/target/autoscaling.go | 122 +++++++++++++++++++++++ libs/lakebase/target/provisioned.go | 64 ++++++++++++ libs/lakebase/target/target.go | 145 ++++++++++++++++++++++++++++ libs/lakebase/target/target_test.go | 136 ++++++++++++++++++++++++++ 8 files changed, 523 insertions(+), 255 deletions(-) delete mode 100644 cmd/psql/psql_test.go create mode 100644 libs/lakebase/target/autoscaling.go create mode 100644 libs/lakebase/target/provisioned.go create mode 100644 libs/lakebase/target/target.go create mode 100644 libs/lakebase/target/target_test.go diff --git a/cmd/psql/psql.go b/cmd/psql/psql.go index e7f3a65f8b3..e5cfaff5cff 100644 --- a/cmd/psql/psql.go +++ b/cmd/psql/psql.go @@ -11,6 +11,7 @@ import ( "github.com/databricks/cli/libs/cmdctx" "github.com/databricks/cli/libs/cmdgroup" "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/lakebase/target" libpsql "github.com/databricks/cli/libs/psql" "github.com/databricks/databricks-sdk-go" "github.com/databricks/databricks-sdk-go/service/database" @@ -86,9 +87,9 @@ For more information, see: https://docs.databricks.com/aws/en/oltp/ if argsLenAtDash < 0 { 
argsLenAtDash = len(args) } - target := "" + targetArg := "" if argsLenAtDash == 1 { - target = args[0] + targetArg = args[0] } else if argsLenAtDash > 1 { return errors.New("expected at most one positional argument for target") } @@ -109,16 +110,17 @@ For more information, see: https://docs.databricks.com/aws/en/oltp/ } // Positional argument takes precedence - if target != "" { - if strings.HasPrefix(target, "projects/") { + if targetArg != "" { + if target.IsAutoscalingPath(targetArg) { if provisionedFlag { return errors.New("cannot use --provisioned flag with an autoscaling resource path") } - projectID, branchID, endpointID, err := parseResourcePath(target) + spec, err := target.ParseAutoscalingPath(targetArg) if err != nil { return err } + projectID, branchID, endpointID := spec.ProjectID, spec.BranchID, spec.EndpointID // Check for conflicts between path and flags if projectFlag != "" && projectFlag != projectID { @@ -149,7 +151,7 @@ For more information, see: https://docs.databricks.com/aws/en/oltp/ if autoscalingFlag { return errors.New("cannot use --autoscaling flag with a provisioned instance name") } - return connectProvisioned(ctx, target, retryConfig, extraArgs) + return connectProvisioned(ctx, targetArg, retryConfig, extraArgs) } // No positional argument - use flags only @@ -197,45 +199,6 @@ For more information, see: https://docs.databricks.com/aws/en/oltp/ return cmd } -// parseResourcePath extracts project, branch, and endpoint IDs from a resource path. -// Returns an error for malformed paths. 
-func parseResourcePath(input string) (project, branch, endpoint string, err error) { - parts := strings.Split(input, "/") - - // Must start with projects/{project_id} - if len(parts) < 2 || parts[0] != "projects" { - return "", "", "", fmt.Errorf("invalid resource path: %s", input) - } - if parts[1] == "" { - return "", "", "", errors.New("invalid resource path: missing project ID") - } - project = parts[1] - - // Optional: branches/{branch_id} - if len(parts) > 2 { - if len(parts) < 4 || parts[2] != "branches" { - return "", "", "", errors.New("invalid resource path: expected 'branches' after project") - } - if parts[3] == "" { - return "", "", "", errors.New("invalid resource path: missing branch ID") - } - branch = parts[3] - } - - // Optional: endpoints/{endpoint_id} - if len(parts) > 4 { - if len(parts) < 6 || parts[4] != "endpoints" { - return "", "", "", errors.New("invalid resource path: expected 'endpoints' after branch") - } - if parts[5] == "" { - return "", "", "", errors.New("invalid resource path: missing endpoint ID") - } - endpoint = parts[5] - } - - return project, branch, endpoint, nil -} - // listAllDatabases fetches all database instances and projects in parallel. // Errors are silently ignored; callers should check for empty results. 
func listAllDatabases(ctx context.Context, w *databricks.WorkspaceClient) ([]database.DatabaseInstance, []postgres.Project) { @@ -248,12 +211,12 @@ func listAllDatabases(ctx context.Context, w *databricks.WorkspaceClient) ([]dat projectsCh := make(chan result[postgres.Project], 1) go func() { - instances, err := w.Database.ListDatabaseInstancesAll(ctx, database.ListDatabaseInstancesRequest{}) + instances, err := target.ListProvisionedInstances(ctx, w) instancesCh <- result[database.DatabaseInstance]{instances, err} }() go func() { - projects, err := w.Postgres.ListProjectsAll(ctx, postgres.ListProjectsRequest{}) + projects, err := target.ListProjects(ctx, w) projectsCh <- result[postgres.Project]{projects, err} }() @@ -294,7 +257,7 @@ func showSelectionAndConnect(ctx context.Context, retryConfig libpsql.RetryConfi }) } for _, proj := range projects { - displayName := extractIDFromName(proj.Name, "projects") + displayName := target.ExtractID(proj.Name, target.PathSegmentProjects) if proj.Status != nil && proj.Status.DisplayName != "" { displayName = proj.Status.DisplayName } @@ -315,7 +278,7 @@ func showSelectionAndConnect(ctx context.Context, retryConfig libpsql.RetryConfi } if after, ok := strings.CutPrefix(selected, "autoscaling:"); ok { projectName := after - projectID := extractIDFromName(projectName, "projects") + projectID := target.ExtractID(projectName, target.PathSegmentProjects) return connectAutoscaling(ctx, projectID, "", "", retryConfig, extraArgs) } diff --git a/cmd/psql/psql_autoscaling.go b/cmd/psql/psql_autoscaling.go index 00c555e4c12..4273dad3b50 100644 --- a/cmd/psql/psql_autoscaling.go +++ b/cmd/psql/psql_autoscaling.go @@ -4,10 +4,10 @@ import ( "context" "errors" "fmt" - "strings" "github.com/databricks/cli/libs/cmdctx" "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/lakebase/target" libpsql "github.com/databricks/cli/libs/psql" "github.com/databricks/databricks-sdk-go" 
"github.com/databricks/databricks-sdk-go/service/postgres" @@ -16,18 +16,6 @@ import ( // autoscalingDefaultDatabase is the default database for Lakebase Autoscaling projects. const autoscalingDefaultDatabase = "databricks_postgres" -// extractIDFromName extracts the ID component from a resource name. -// For example, extractIDFromName("projects/foo/branches/bar", "branches") returns "bar". -func extractIDFromName(name, component string) string { - parts := strings.Split(name, "/") - for i := range len(parts) - 1 { - if parts[i] == component { - return parts[i+1] - } - } - return name -} - // connectAutoscaling connects to a Lakebase Autoscaling endpoint. func connectAutoscaling(ctx context.Context, projectID, branchID, endpointID string, retryConfig libpsql.RetryConfig, extraArgs []string) error { w := cmdctx.WorkspaceClient(ctx) @@ -50,11 +38,9 @@ func connectAutoscaling(ctx context.Context, projectID, branchID, endpointID str return errors.New("endpoint host information is not available") } - cred, err := w.Postgres.GenerateDatabaseCredential(ctx, postgres.GenerateDatabaseCredentialRequest{ - Endpoint: endpoint.Name, - }) + token, err := target.AutoscalingCredential(ctx, w, endpoint.Name) if err != nil { - return fmt.Errorf("failed to get database credentials: %w", err) + return err } var endpointType string @@ -83,7 +69,7 @@ func connectAutoscaling(ctx context.Context, projectID, branchID, endpointID str return libpsql.Connect(ctx, libpsql.ConnectOptions{ Host: endpoint.Status.Hosts.Host, Username: user.UserName, - Password: cred.Token, + Password: token, DefaultDatabase: autoscalingDefaultDatabase, ExtraArgs: extraArgs, }, retryConfig) @@ -102,7 +88,7 @@ func resolveEndpoint(ctx context.Context, w *databricks.WorkspaceClient, project } // Get project to display its name - project, err := w.Postgres.GetProject(ctx, postgres.GetProjectRequest{Name: "projects/" + projectID}) + project, err := target.GetProject(ctx, w, projectID) if err != nil { return nil, 
fmt.Errorf("failed to get project: %w", err) } @@ -136,7 +122,7 @@ func resolveEndpoint(ctx context.Context, w *databricks.WorkspaceClient, project } // Get endpoint to validate and return it - endpoint, err := w.Postgres.GetEndpoint(ctx, postgres.GetEndpointRequest{Name: branch.Name + "/endpoints/" + endpointID}) + endpoint, err := target.GetEndpoint(ctx, w, projectID, branchID, endpointID) if err != nil { return nil, fmt.Errorf("failed to get endpoint: %w", err) } @@ -145,38 +131,31 @@ func resolveEndpoint(ctx context.Context, w *databricks.WorkspaceClient, project return endpoint, nil } +// selectAmbiguous prompts the user to pick one of the choices in an +// AmbiguousError. Caller is expected to have logged a header (e.g. via the +// spinner) before invoking. Used to keep psql's interactive UX while letting +// the shared lib do the actual list+filter work. +func selectAmbiguous(ctx context.Context, amb *target.AmbiguousError, prompt string) (string, error) { + items := make([]cmdio.Tuple, 0, len(amb.Choices)) + for _, c := range amb.Choices { + items = append(items, cmdio.Tuple{Name: c.DisplayName, Id: c.ID}) + } + return cmdio.SelectOrdered(ctx, items, prompt) +} + // selectProjectID auto-selects if there's only one project, otherwise prompts user to select. // Returns the project ID (not the full project object). 
func selectProjectID(ctx context.Context, w *databricks.WorkspaceClient) (string, error) { sp := cmdio.NewSpinner(ctx) sp.Update("Loading projects...") - projects, err := w.Postgres.ListProjectsAll(ctx, postgres.ListProjectsRequest{}) + id, err := target.AutoSelectProject(ctx, w) sp.Close() - if err != nil { - return "", err - } - - if len(projects) == 0 { - return "", errors.New("no Lakebase Autoscaling projects found in workspace") - } - // Auto-select if there's only one project - if len(projects) == 1 { - return extractIDFromName(projects[0].Name, "projects"), nil - } - - // Multiple projects, prompt user to select - var items []cmdio.Tuple - for _, project := range projects { - projectID := extractIDFromName(project.Name, "projects") - displayName := projectID - if project.Status != nil && project.Status.DisplayName != "" { - displayName = project.Status.DisplayName - } - items = append(items, cmdio.Tuple{Name: displayName, Id: projectID}) + var amb *target.AmbiguousError + if !errors.As(err, &amb) { + return id, err } - - return cmdio.SelectOrdered(ctx, items, "Select project") + return selectAmbiguous(ctx, amb, "Select project") } // selectBranchID auto-selects if there's only one branch, otherwise prompts user to select. 
@@ -184,31 +163,14 @@ func selectProjectID(ctx context.Context, w *databricks.WorkspaceClient) (string func selectBranchID(ctx context.Context, w *databricks.WorkspaceClient, projectName string) (string, error) { sp := cmdio.NewSpinner(ctx) sp.Update("Loading branches...") - branches, err := w.Postgres.ListBranchesAll(ctx, postgres.ListBranchesRequest{ - Parent: projectName, - }) + id, err := target.AutoSelectBranch(ctx, w, projectName) sp.Close() - if err != nil { - return "", err - } - - if len(branches) == 0 { - return "", errors.New("no branches found in project") - } - - // Auto-select if there's only one branch - if len(branches) == 1 { - return extractIDFromName(branches[0].Name, "branches"), nil - } - // Multiple branches, prompt user to select - var items []cmdio.Tuple - for _, branch := range branches { - branchID := extractIDFromName(branch.Name, "branches") - items = append(items, cmdio.Tuple{Name: branchID, Id: branchID}) + var amb *target.AmbiguousError + if !errors.As(err, &amb) { + return id, err } - - return cmdio.SelectOrdered(ctx, items, "Select branch") + return selectAmbiguous(ctx, amb, "Select branch") } // selectEndpointID auto-selects if there's only one endpoint, otherwise prompts user to select. 
@@ -216,29 +178,12 @@ func selectBranchID(ctx context.Context, w *databricks.WorkspaceClient, projectN func selectEndpointID(ctx context.Context, w *databricks.WorkspaceClient, branchName string) (string, error) { sp := cmdio.NewSpinner(ctx) sp.Update("Loading endpoints...") - endpoints, err := w.Postgres.ListEndpointsAll(ctx, postgres.ListEndpointsRequest{ - Parent: branchName, - }) + id, err := target.AutoSelectEndpoint(ctx, w, branchName) sp.Close() - if err != nil { - return "", err - } - - if len(endpoints) == 0 { - return "", errors.New("no endpoints found in branch") - } - // Auto-select if there's only one endpoint - if len(endpoints) == 1 { - return extractIDFromName(endpoints[0].Name, "endpoints"), nil + var amb *target.AmbiguousError + if !errors.As(err, &amb) { + return id, err } - - // Multiple endpoints, prompt user to select - var items []cmdio.Tuple - for _, endpoint := range endpoints { - endpointID := extractIDFromName(endpoint.Name, "endpoints") - items = append(items, cmdio.Tuple{Name: endpointID, Id: endpointID}) - } - - return cmdio.SelectOrdered(ctx, items, "Select endpoint") + return selectAmbiguous(ctx, amb, "Select endpoint") } diff --git a/cmd/psql/psql_provisioned.go b/cmd/psql/psql_provisioned.go index 88ca1bb9181..9ea88def5ce 100644 --- a/cmd/psql/psql_provisioned.go +++ b/cmd/psql/psql_provisioned.go @@ -7,10 +7,10 @@ import ( "github.com/databricks/cli/libs/cmdctx" "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/lakebase/target" libpsql "github.com/databricks/cli/libs/psql" "github.com/databricks/databricks-sdk-go" "github.com/databricks/databricks-sdk-go/service/database" - "github.com/google/uuid" ) // provisionedDefaultDatabase is the default database for Lakebase Provisioned instances. 
@@ -39,12 +39,9 @@ func connectProvisioned(ctx context.Context, instanceName string, retryConfig li return errors.New("database instance is not ready for accepting connections") } - cred, err := w.Database.GenerateDatabaseCredential(ctx, database.GenerateDatabaseCredentialRequest{ - InstanceNames: []string{instance.Name}, - RequestId: uuid.NewString(), - }) + token, err := target.ProvisionedCredential(ctx, w, instance.Name) if err != nil { - return fmt.Errorf("failed to get database credentials: %w", err) + return err } cmdio.LogString(ctx, "Connecting to database instance...") @@ -52,7 +49,7 @@ func connectProvisioned(ctx context.Context, instanceName string, retryConfig li return libpsql.Connect(ctx, libpsql.ConnectOptions{ Host: instance.ReadWriteDns, Username: user.UserName, - Password: cred.Token, + Password: token, DefaultDatabase: provisionedDefaultDatabase, ExtraArgs: extraArgs, }, retryConfig) @@ -61,7 +58,6 @@ func connectProvisioned(ctx context.Context, instanceName string, retryConfig li // resolveInstance resolves an instance name to a full instance object. // If instanceName is empty, prompts the user to select one. 
func resolveInstance(ctx context.Context, w *databricks.WorkspaceClient, instanceName string) (*database.DatabaseInstance, error) { - // If instance not specified, select one if instanceName == "" { var err error instanceName, err = selectInstanceID(ctx, w) @@ -70,15 +66,9 @@ func resolveInstance(ctx context.Context, w *databricks.WorkspaceClient, instanc } } - instance, err := w.Database.GetDatabaseInstance(ctx, database.GetDatabaseInstanceRequest{ - Name: instanceName, - }) + instance, err := target.GetProvisioned(ctx, w, instanceName) if err != nil { - return nil, fmt.Errorf("failed to get database instance: %w", err) - } - // Ensure Name is set (API response may not include it) - if instance.Name == "" { - instance.Name = instanceName + return nil, err } cmdio.LogString(ctx, "Instance: "+instance.Name) @@ -90,26 +80,12 @@ func resolveInstance(ctx context.Context, w *databricks.WorkspaceClient, instanc func selectInstanceID(ctx context.Context, w *databricks.WorkspaceClient) (string, error) { sp := cmdio.NewSpinner(ctx) sp.Update("Loading instances...") - instances, err := w.Database.ListDatabaseInstancesAll(ctx, database.ListDatabaseInstancesRequest{}) + id, err := target.AutoSelectProvisioned(ctx, w) sp.Close() - if err != nil { - return "", err - } - if len(instances) == 0 { - return "", errors.New("no Lakebase Provisioned instances found in workspace") + var amb *target.AmbiguousError + if !errors.As(err, &amb) { + return id, err } - - // Auto-select if there's only one instance - if len(instances) == 1 { - return instances[0].Name, nil - } - - // Multiple instances, prompt user to select - var items []cmdio.Tuple - for _, inst := range instances { - items = append(items, cmdio.Tuple{Name: inst.Name, Id: inst.Name}) - } - - return cmdio.SelectOrdered(ctx, items, "Select instance") + return selectAmbiguous(ctx, amb, "Select instance") } diff --git a/cmd/psql/psql_test.go b/cmd/psql/psql_test.go deleted file mode 100644 index fc8a7e53cba..00000000000 --- 
a/cmd/psql/psql_test.go +++ /dev/null @@ -1,83 +0,0 @@ -package psql - -import ( - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestParseResourcePath(t *testing.T) { - tests := []struct { - name string - input string - project string - branch string - endpoint string - wantErr string - }{ - { - name: "project only", - input: "projects/my-project", - project: "my-project", - }, - { - name: "project and branch", - input: "projects/my-project/branches/main", - project: "my-project", - branch: "main", - }, - { - name: "full path", - input: "projects/my-project/branches/main/endpoints/primary", - project: "my-project", - branch: "main", - endpoint: "primary", - }, - { - name: "missing project ID", - input: "projects/", - wantErr: "missing project ID", - }, - { - name: "missing branch ID", - input: "projects/my-project/branches/", - wantErr: "missing branch ID", - }, - { - name: "missing endpoint ID", - input: "projects/my-project/branches/main/endpoints/", - wantErr: "missing endpoint ID", - }, - { - name: "invalid segment after project", - input: "projects/my-project/invalid/foo", - wantErr: "expected 'branches' after project", - }, - { - name: "invalid segment after branch", - input: "projects/my-project/branches/main/invalid/foo", - wantErr: "expected 'endpoints' after branch", - }, - { - name: "not a projects path", - input: "something/else", - wantErr: "invalid resource path", - }, - } - - for _, tc := range tests { - t.Run(tc.name, func(t *testing.T) { - project, branch, endpoint, err := parseResourcePath(tc.input) - if tc.wantErr != "" { - require.Error(t, err) - assert.Contains(t, err.Error(), tc.wantErr) - return - } - require.NoError(t, err) - assert.Equal(t, tc.project, project) - assert.Equal(t, tc.branch, branch) - assert.Equal(t, tc.endpoint, endpoint) - }) - } -} diff --git a/libs/lakebase/target/autoscaling.go b/libs/lakebase/target/autoscaling.go new file mode 100644 index 00000000000..f1edef216d4 
--- /dev/null +++ b/libs/lakebase/target/autoscaling.go @@ -0,0 +1,122 @@ +package target + +import ( + "context" + "errors" + "fmt" + + "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/service/postgres" +) + +// ListProjects returns all autoscaling projects in the workspace. +func ListProjects(ctx context.Context, w *databricks.WorkspaceClient) ([]postgres.Project, error) { + return w.Postgres.ListProjectsAll(ctx, postgres.ListProjectsRequest{}) +} + +// ListBranches returns all branches under the given project. +// projectName is the SDK resource name like "projects/foo". +func ListBranches(ctx context.Context, w *databricks.WorkspaceClient, projectName string) ([]postgres.Branch, error) { + return w.Postgres.ListBranchesAll(ctx, postgres.ListBranchesRequest{Parent: projectName}) +} + +// ListEndpoints returns all endpoints under the given branch. +// branchName is the SDK resource name like "projects/foo/branches/bar". +func ListEndpoints(ctx context.Context, w *databricks.WorkspaceClient, branchName string) ([]postgres.Endpoint, error) { + return w.Postgres.ListEndpointsAll(ctx, postgres.ListEndpointsRequest{Parent: branchName}) +} + +// GetProject fetches a single project by ID. +func GetProject(ctx context.Context, w *databricks.WorkspaceClient, projectID string) (*postgres.Project, error) { + return w.Postgres.GetProject(ctx, postgres.GetProjectRequest{Name: PathSegmentProjects + "/" + projectID}) +} + +// GetEndpoint fetches a single endpoint by ID, given its parent IDs. +func GetEndpoint(ctx context.Context, w *databricks.WorkspaceClient, projectID, branchID, endpointID string) (*postgres.Endpoint, error) { + name := fmt.Sprintf("projects/%s/branches/%s/endpoints/%s", projectID, branchID, endpointID) + return w.Postgres.GetEndpoint(ctx, postgres.GetEndpointRequest{Name: name}) +} + +// AutoSelectProject returns the only project in the workspace, or an +// AmbiguousError carrying the choices if there are multiple. 
Returns a plain +// error if there are no projects. +func AutoSelectProject(ctx context.Context, w *databricks.WorkspaceClient) (string, error) { + projects, err := ListProjects(ctx, w) + if err != nil { + return "", err + } + if len(projects) == 0 { + return "", errors.New("no Lakebase Autoscaling projects found in workspace") + } + if len(projects) == 1 { + return ExtractID(projects[0].Name, PathSegmentProjects), nil + } + + choices := make([]Choice, 0, len(projects)) + for _, p := range projects { + id := ExtractID(p.Name, PathSegmentProjects) + display := id + if p.Status != nil && p.Status.DisplayName != "" { + display = p.Status.DisplayName + } + choices = append(choices, Choice{ID: id, DisplayName: display}) + } + return "", &AmbiguousError{Kind: "project", FlagHint: "--project", Choices: choices} +} + +// AutoSelectBranch returns the only branch under projectName, or an +// AmbiguousError if there are multiple. +func AutoSelectBranch(ctx context.Context, w *databricks.WorkspaceClient, projectName string) (string, error) { + branches, err := ListBranches(ctx, w, projectName) + if err != nil { + return "", err + } + if len(branches) == 0 { + return "", errors.New("no branches found in project") + } + if len(branches) == 1 { + return ExtractID(branches[0].Name, pathSegmentBranches), nil + } + + choices := make([]Choice, 0, len(branches)) + for _, b := range branches { + id := ExtractID(b.Name, pathSegmentBranches) + choices = append(choices, Choice{ID: id, DisplayName: id}) + } + return "", &AmbiguousError{Kind: "branch", Parent: projectName, FlagHint: "--branch", Choices: choices} +} + +// AutoSelectEndpoint returns the only endpoint under branchName, or an +// AmbiguousError if there are multiple. 
+func AutoSelectEndpoint(ctx context.Context, w *databricks.WorkspaceClient, branchName string) (string, error) { + endpoints, err := ListEndpoints(ctx, w, branchName) + if err != nil { + return "", err + } + if len(endpoints) == 0 { + return "", errors.New("no endpoints found in branch") + } + if len(endpoints) == 1 { + return ExtractID(endpoints[0].Name, pathSegmentEndpoints), nil + } + + choices := make([]Choice, 0, len(endpoints)) + for _, e := range endpoints { + id := ExtractID(e.Name, pathSegmentEndpoints) + choices = append(choices, Choice{ID: id, DisplayName: id}) + } + return "", &AmbiguousError{Kind: "endpoint", Parent: branchName, FlagHint: "--endpoint", Choices: choices} +} + +// AutoscalingCredential issues a short-lived OAuth token that can be used to +// authenticate to the given autoscaling endpoint. endpointName is the SDK +// resource name (e.g. "projects/foo/branches/bar/endpoints/baz"). +func AutoscalingCredential(ctx context.Context, w *databricks.WorkspaceClient, endpointName string) (string, error) { + cred, err := w.Postgres.GenerateDatabaseCredential(ctx, postgres.GenerateDatabaseCredentialRequest{ + Endpoint: endpointName, + }) + if err != nil { + return "", fmt.Errorf("failed to get database credentials: %w", err) + } + return cred.Token, nil +} diff --git a/libs/lakebase/target/provisioned.go b/libs/lakebase/target/provisioned.go new file mode 100644 index 00000000000..773cc867ce0 --- /dev/null +++ b/libs/lakebase/target/provisioned.go @@ -0,0 +1,64 @@ +package target + +import ( + "context" + "errors" + "fmt" + + "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/service/database" + "github.com/google/uuid" +) + +// ListProvisionedInstances returns all provisioned database instances in the workspace. 
+func ListProvisionedInstances(ctx context.Context, w *databricks.WorkspaceClient) ([]database.DatabaseInstance, error) { + return w.Database.ListDatabaseInstancesAll(ctx, database.ListDatabaseInstancesRequest{}) +} + +// GetProvisioned fetches a single provisioned instance by name. +// The Name field on the response can be empty; this function ensures it is +// populated from the input so downstream callers do not have to re-set it. +func GetProvisioned(ctx context.Context, w *databricks.WorkspaceClient, name string) (*database.DatabaseInstance, error) { + instance, err := w.Database.GetDatabaseInstance(ctx, database.GetDatabaseInstanceRequest{Name: name}) + if err != nil { + return nil, fmt.Errorf("failed to get database instance: %w", err) + } + if instance.Name == "" { + instance.Name = name + } + return instance, nil +} + +// AutoSelectProvisioned returns the only provisioned instance in the workspace, +// or an AmbiguousError if there are multiple. Returns a plain error if none. +func AutoSelectProvisioned(ctx context.Context, w *databricks.WorkspaceClient) (string, error) { + instances, err := ListProvisionedInstances(ctx, w) + if err != nil { + return "", err + } + if len(instances) == 0 { + return "", errors.New("no Lakebase Provisioned instances found in workspace") + } + if len(instances) == 1 { + return instances[0].Name, nil + } + + choices := make([]Choice, 0, len(instances)) + for _, inst := range instances { + choices = append(choices, Choice{ID: inst.Name, DisplayName: inst.Name}) + } + return "", &AmbiguousError{Kind: "instance", FlagHint: "--target", Choices: choices} +} + +// ProvisionedCredential issues a short-lived OAuth token for the provisioned +// instance with the given name. 
+func ProvisionedCredential(ctx context.Context, w *databricks.WorkspaceClient, instanceName string) (string, error) { + cred, err := w.Database.GenerateDatabaseCredential(ctx, database.GenerateDatabaseCredentialRequest{ + InstanceNames: []string{instanceName}, + RequestId: uuid.NewString(), + }) + if err != nil { + return "", fmt.Errorf("failed to get database credentials: %w", err) + } + return cred.Token, nil +} diff --git a/libs/lakebase/target/target.go b/libs/lakebase/target/target.go new file mode 100644 index 00000000000..d02c95903ce --- /dev/null +++ b/libs/lakebase/target/target.go @@ -0,0 +1,145 @@ +// Package target resolves Lakebase Postgres targets (provisioned instances and +// autoscaling endpoints) into the host, credential, and SDK metadata that +// callers need to open a connection. It is shared by `cmd/psql` and the +// `experimental postgres query` command so that both speak the same SDK. +package target + +import ( + "errors" + "fmt" + "strings" +) + +const ( + // PathSegmentProjects is the leading path segment that identifies an + // autoscaling resource path. Provisioned instance names never start with it. + PathSegmentProjects = "projects" + pathSegmentBranches = "branches" + pathSegmentEndpoints = "endpoints" +) + +// AutoscalingSpec is a partial or full specification for an autoscaling endpoint. +// Empty fields signal "auto-select if exactly one exists, otherwise error". +type AutoscalingSpec struct { + ProjectID string + BranchID string + EndpointID string +} + +// Choice is a single candidate returned alongside an AmbiguousError so callers +// can either render the list to the user or prompt interactively. +type Choice struct { + ID string + DisplayName string +} + +// AmbiguousError is returned by AutoSelect* helpers when the SDK returns more +// than one candidate and the caller did not specify which one to pick. +// +// Callers that have a TTY (e.g. `databricks psql`) can use errors.As to detect +// this and prompt interactively. 
Callers that are non-interactive (e.g. the +// scriptable `postgres query` command) propagate it as a plain error: the +// formatted message already enumerates the choices. +type AmbiguousError struct { + // Kind identifies what was ambiguous: "project", "branch", or "endpoint". + Kind string + // Parent is the SDK resource name that contained the ambiguity (e.g. + // "projects/foo" when listing branches), or empty when listing projects. + Parent string + // FlagHint is the flag a user would set to disambiguate (e.g. "--branch"). + FlagHint string + // Choices enumerates the candidates returned by the SDK. + Choices []Choice +} + +func (e *AmbiguousError) Error() string { + plural := map[string]string{ + "project": "projects", + "branch": "branches", + "endpoint": "endpoints", + "instance": "instances", + }[e.Kind] + if plural == "" { + plural = e.Kind + } + + var sb strings.Builder + if e.Parent == "" { + fmt.Fprintf(&sb, "multiple %s found; specify %s:", plural, e.FlagHint) + } else { + fmt.Fprintf(&sb, "multiple %s found in %s; specify %s:", plural, e.Parent, e.FlagHint) + } + for _, c := range e.Choices { + sb.WriteString("\n - ") + sb.WriteString(c.ID) + if c.DisplayName != "" && c.DisplayName != c.ID { + fmt.Fprintf(&sb, " (%s)", c.DisplayName) + } + } + return sb.String() +} + +// ParseAutoscalingPath extracts project, branch, and endpoint IDs from a +// resource path. Accepts partial paths: +// +// projects/foo +// projects/foo/branches/bar +// projects/foo/branches/bar/endpoints/baz +// +// Returns an error if the path is malformed or does not start with "projects/". 
+func ParseAutoscalingPath(input string) (AutoscalingSpec, error) { + parts := strings.Split(input, "/") + + if len(parts) < 2 || parts[0] != PathSegmentProjects { + return AutoscalingSpec{}, fmt.Errorf("invalid resource path: %s", input) + } + if parts[1] == "" { + return AutoscalingSpec{}, errors.New("invalid resource path: missing project ID") + } + spec := AutoscalingSpec{ProjectID: parts[1]} + + if len(parts) > 2 { + if len(parts) < 4 || parts[2] != pathSegmentBranches { + return AutoscalingSpec{}, errors.New("invalid resource path: expected 'branches' after project") + } + if parts[3] == "" { + return AutoscalingSpec{}, errors.New("invalid resource path: missing branch ID") + } + spec.BranchID = parts[3] + } + + if len(parts) > 4 { + if len(parts) < 6 || parts[4] != pathSegmentEndpoints { + return AutoscalingSpec{}, errors.New("invalid resource path: expected 'endpoints' after branch") + } + if parts[5] == "" { + return AutoscalingSpec{}, errors.New("invalid resource path: missing endpoint ID") + } + spec.EndpointID = parts[5] + } + + if len(parts) > 6 { + return AutoscalingSpec{}, fmt.Errorf("invalid resource path: trailing components after endpoint: %s", input) + } + + return spec, nil +} + +// ExtractID returns the value following component in a resource name. +// ExtractID("projects/foo/branches/bar", "branches") returns "bar". +// Returns the original name unchanged if component is not found. +func ExtractID(name, component string) string { + parts := strings.Split(name, "/") + for i := range len(parts) - 1 { + if parts[i] == component { + return parts[i+1] + } + } + return name +} + +// IsAutoscalingPath reports whether s is an autoscaling resource path +// (i.e. starts with "projects/"). Provisioned instance names never do. 
+func IsAutoscalingPath(s string) bool { + return strings.HasPrefix(s, PathSegmentProjects+"/") +} diff --git a/libs/lakebase/target/target_test.go b/libs/lakebase/target/target_test.go new file mode 100644 index 00000000000..4b4a763c122 --- /dev/null +++ b/libs/lakebase/target/target_test.go @@ -0,0 +1,136 @@ +package target + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParseAutoscalingPath(t *testing.T) { + tests := []struct { + name string + input string + want AutoscalingSpec + wantErr string + }{ + { + name: "project only", + input: "projects/my-project", + want: AutoscalingSpec{ProjectID: "my-project"}, + }, + { + name: "project and branch", + input: "projects/my-project/branches/main", + want: AutoscalingSpec{ProjectID: "my-project", BranchID: "main"}, + }, + { + name: "full path", + input: "projects/my-project/branches/main/endpoints/primary", + want: AutoscalingSpec{ProjectID: "my-project", BranchID: "main", EndpointID: "primary"}, + }, + { + name: "missing project ID", + input: "projects/", + wantErr: "missing project ID", + }, + { + name: "missing branch ID", + input: "projects/my-project/branches/", + wantErr: "missing branch ID", + }, + { + name: "missing endpoint ID", + input: "projects/my-project/branches/main/endpoints/", + wantErr: "missing endpoint ID", + }, + { + name: "invalid segment after project", + input: "projects/my-project/invalid/foo", + wantErr: "expected 'branches' after project", + }, + { + name: "invalid segment after branch", + input: "projects/my-project/branches/main/invalid/foo", + wantErr: "expected 'endpoints' after branch", + }, + { + name: "not a projects path", + input: "something/else", + wantErr: "invalid resource path", + }, + { + name: "trailing components after endpoint", + input: "projects/foo/branches/bar/endpoints/baz/extra", + wantErr: "trailing components after endpoint", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t 
*testing.T) { + got, err := ParseAutoscalingPath(tc.input) + if tc.wantErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.wantErr) + return + } + require.NoError(t, err) + assert.Equal(t, tc.want, got) + }) + } +} + +func TestExtractID(t *testing.T) { + assert.Equal(t, "bar", ExtractID("projects/foo/branches/bar", "branches")) + assert.Equal(t, "foo", ExtractID("projects/foo", "projects")) + assert.Equal(t, "baz", ExtractID("projects/foo/branches/bar/endpoints/baz", "endpoints")) + assert.Equal(t, "no-component", ExtractID("no-component", "missing")) +} + +func TestIsAutoscalingPath(t *testing.T) { + assert.True(t, IsAutoscalingPath("projects/foo")) + assert.True(t, IsAutoscalingPath("projects/foo/branches/bar")) + assert.False(t, IsAutoscalingPath("my-instance")) + assert.False(t, IsAutoscalingPath("")) + assert.False(t, IsAutoscalingPath("projects")) +} + +func TestAmbiguousErrorMessage(t *testing.T) { + t.Run("with parent", func(t *testing.T) { + err := &AmbiguousError{ + Kind: "branch", + Parent: "projects/foo", + FlagHint: "--branch", + Choices: []Choice{ + {ID: "main", DisplayName: "main"}, + {ID: "feature-x", DisplayName: "feature-x"}, + }, + } + assert.Equal(t, + "multiple branches found in projects/foo; specify --branch:\n - main\n - feature-x", + err.Error(), + ) + }) + + t.Run("without parent", func(t *testing.T) { + err := &AmbiguousError{ + Kind: "project", + FlagHint: "--project", + Choices: []Choice{ + {ID: "alpha", DisplayName: "Alpha Project"}, + {ID: "beta", DisplayName: "beta"}, + }, + } + assert.Equal(t, + "multiple projects found; specify --project:\n - alpha (Alpha Project)\n - beta", + err.Error(), + ) + }) + + t.Run("errors.As", func(t *testing.T) { + var amb *AmbiguousError + err := error(&AmbiguousError{Kind: "endpoint", FlagHint: "--endpoint"}) + assert.ErrorAs(t, err, &amb) + assert.Equal(t, "endpoint", amb.Kind) + }) +} From bfb632090cd6e4e5ec75e69720de781a2eba2be8 Mon Sep 17 00:00:00 2001 From: simon Date: Thu, 30 
Apr 2026 10:30:02 +0200 Subject: [PATCH 02/15] Add experimental postgres query command (autoscaling, text output) Scaffolds 'databricks experimental postgres query', a scriptable SQL runner against a Lakebase Postgres autoscaling endpoint that does not require a system psql binary. This PR ships the smallest useful slice: - Single positional SQL statement. - --target (autoscaling resource path), --project, --branch, --endpoint targeting forms; provisioned-shaped targets return a pointer at 'databricks psql' for now. - Connect retry on idle/waking endpoints (08xxx SQLSTATE family, dial errors). - Text output (static table for rows-producing statements, command tag for command-only). Provisioned support, JSON/CSV streaming output, multi-statement input, cancellation, and integration tests come in follow-up PRs. Driver: github.com/jackc/pgx/v5 v5.9.1 (MIT). Already a direct dep of the universe monorepo's Lakebase services; aligning here keeps the SDK surface consistent. Co-authored-by: Isaac --- NEXT_CHANGELOG.md | 2 + NOTICE | 4 + .../query/ambiguous-targeting/out.test.toml | 8 + .../query/ambiguous-targeting/output.txt | 18 ++ .../postgres/query/ambiguous-targeting/script | 8 + .../query/ambiguous-targeting/test.toml | 62 +++++++ .../query/argument-errors/out.test.toml | 8 + .../postgres/query/argument-errors/output.txt | 40 ++++ .../postgres/query/argument-errors/script | 29 +++ .../postgres/query/argument-errors/test.toml | 3 + cmd/experimental/experimental.go | 2 + experimental/postgres/cmd/cmd.go | 25 +++ experimental/postgres/cmd/connect.go | 147 +++++++++++++++ experimental/postgres/cmd/connect_test.go | 149 +++++++++++++++ experimental/postgres/cmd/execute.go | 62 +++++++ experimental/postgres/cmd/query.go | 133 ++++++++++++++ experimental/postgres/cmd/render.go | 74 ++++++++ experimental/postgres/cmd/render_test.go | 67 +++++++ experimental/postgres/cmd/targeting.go | 173 ++++++++++++++++++ experimental/postgres/cmd/targeting_test.go | 81 ++++++++ go.mod | 3 
+ go.sum | 10 + 22 files changed, 1108 insertions(+) create mode 100644 acceptance/cmd/experimental/postgres/query/ambiguous-targeting/out.test.toml create mode 100644 acceptance/cmd/experimental/postgres/query/ambiguous-targeting/output.txt create mode 100644 acceptance/cmd/experimental/postgres/query/ambiguous-targeting/script create mode 100644 acceptance/cmd/experimental/postgres/query/ambiguous-targeting/test.toml create mode 100644 acceptance/cmd/experimental/postgres/query/argument-errors/out.test.toml create mode 100644 acceptance/cmd/experimental/postgres/query/argument-errors/output.txt create mode 100644 acceptance/cmd/experimental/postgres/query/argument-errors/script create mode 100644 acceptance/cmd/experimental/postgres/query/argument-errors/test.toml create mode 100644 experimental/postgres/cmd/cmd.go create mode 100644 experimental/postgres/cmd/connect.go create mode 100644 experimental/postgres/cmd/connect_test.go create mode 100644 experimental/postgres/cmd/execute.go create mode 100644 experimental/postgres/cmd/query.go create mode 100644 experimental/postgres/cmd/render.go create mode 100644 experimental/postgres/cmd/render_test.go create mode 100644 experimental/postgres/cmd/targeting.go create mode 100644 experimental/postgres/cmd/targeting_test.go diff --git a/NEXT_CHANGELOG.md b/NEXT_CHANGELOG.md index 00152d550ea..be66fe3964b 100644 --- a/NEXT_CHANGELOG.md +++ b/NEXT_CHANGELOG.md @@ -7,3 +7,5 @@ ### Bundles ### Dependency updates + +* Added `github.com/jackc/pgx/v5` v5.9.1 (MIT) as a new dependency. Used by an experimental Postgres command added in this release; the package is dormant for users who do not invoke that command. 
diff --git a/NOTICE b/NOTICE index 1e286df6f91..7077be46928 100644 --- a/NOTICE +++ b/NOTICE @@ -127,6 +127,10 @@ google/jsonschema-go - https://github.com/google/jsonschema-go Copyright 2025 Google LLC License - https://github.com/google/jsonschema-go/blob/main/LICENSE +jackc/pgx - https://github.com/jackc/pgx +Copyright (c) 2013-2021 Jack Christensen +License - https://github.com/jackc/pgx/blob/master/LICENSE + charmbracelet/bubbles - https://github.com/charmbracelet/bubbles Copyright (c) 2020-2025 Charmbracelet, Inc License - https://github.com/charmbracelet/bubbles/blob/master/LICENSE diff --git a/acceptance/cmd/experimental/postgres/query/ambiguous-targeting/out.test.toml b/acceptance/cmd/experimental/postgres/query/ambiguous-targeting/out.test.toml new file mode 100644 index 00000000000..40bb0d10471 --- /dev/null +++ b/acceptance/cmd/experimental/postgres/query/ambiguous-targeting/out.test.toml @@ -0,0 +1,8 @@ +Local = true +Cloud = false + +[GOOS] + windows = false + +[EnvMatrix] + DATABRICKS_BUNDLE_ENGINE = ["terraform", "direct"] diff --git a/acceptance/cmd/experimental/postgres/query/ambiguous-targeting/output.txt b/acceptance/cmd/experimental/postgres/query/ambiguous-targeting/output.txt new file mode 100644 index 00000000000..e95a7b3613d --- /dev/null +++ b/acceptance/cmd/experimental/postgres/query/ambiguous-targeting/output.txt @@ -0,0 +1,18 @@ + +=== Project with multiple branches and no --branch should error with choices: +>>> musterr [CLI] experimental postgres query --project foo SELECT 1 +Error: multiple branches found in projects/foo; specify --branch: + - main + - dev + +=== Project with multiple endpoints in only branch should error with choices: +>>> musterr [CLI] experimental postgres query --project bar SELECT 1 +Error: multiple endpoints found in projects/bar/branches/only; specify --endpoint: + - read-write + - read-only + +=== Partial path with multiple branches should error with choices: +>>> musterr [CLI] experimental postgres query 
--target projects/foo SELECT 1 +Error: multiple branches found in projects/foo; specify --branch: + - main + - dev diff --git a/acceptance/cmd/experimental/postgres/query/ambiguous-targeting/script b/acceptance/cmd/experimental/postgres/query/ambiguous-targeting/script new file mode 100644 index 00000000000..6143fd96f02 --- /dev/null +++ b/acceptance/cmd/experimental/postgres/query/ambiguous-targeting/script @@ -0,0 +1,8 @@ +title "Project with multiple branches and no --branch should error with choices:" +trace musterr $CLI experimental postgres query --project foo "SELECT 1" + +title "Project with multiple endpoints in only branch should error with choices:" +trace musterr $CLI experimental postgres query --project bar "SELECT 1" + +title "Partial path with multiple branches should error with choices:" +trace musterr $CLI experimental postgres query --target projects/foo "SELECT 1" diff --git a/acceptance/cmd/experimental/postgres/query/ambiguous-targeting/test.toml b/acceptance/cmd/experimental/postgres/query/ambiguous-targeting/test.toml new file mode 100644 index 00000000000..2a61e7e8e25 --- /dev/null +++ b/acceptance/cmd/experimental/postgres/query/ambiguous-targeting/test.toml @@ -0,0 +1,62 @@ +GOOS.windows = false + +[[Server]] +Pattern = "GET /api/2.0/postgres/projects" +Response.Body = ''' +{ + "projects": [ + {"name": "projects/alpha", "status": {"display_name": "Alpha"}}, + {"name": "projects/beta", "status": {"display_name": "Beta"}} + ] +} +''' + +[[Server]] +Pattern = "GET /api/2.0/postgres/projects/foo" +Response.Body = ''' +{ + "name": "projects/foo", + "status": {"display_name": "Foo Project"} +} +''' + +[[Server]] +Pattern = "GET /api/2.0/postgres/projects/foo/branches" +Response.Body = ''' +{ + "branches": [ + {"name": "projects/foo/branches/main"}, + {"name": "projects/foo/branches/dev"} + ] +} +''' + +[[Server]] +Pattern = "GET /api/2.0/postgres/projects/bar" +Response.Body = ''' +{ + "name": "projects/bar", + "status": {"display_name": "Bar 
Project"} +} +''' + +[[Server]] +Pattern = "GET /api/2.0/postgres/projects/bar/branches" +Response.Body = ''' +{ + "branches": [ + {"name": "projects/bar/branches/only"} + ] +} +''' + +[[Server]] +Pattern = "GET /api/2.0/postgres/projects/bar/branches/only/endpoints" +Response.Body = ''' +{ + "endpoints": [ + {"name": "projects/bar/branches/only/endpoints/read-write"}, + {"name": "projects/bar/branches/only/endpoints/read-only"} + ] +} +''' diff --git a/acceptance/cmd/experimental/postgres/query/argument-errors/out.test.toml b/acceptance/cmd/experimental/postgres/query/argument-errors/out.test.toml new file mode 100644 index 00000000000..40bb0d10471 --- /dev/null +++ b/acceptance/cmd/experimental/postgres/query/argument-errors/out.test.toml @@ -0,0 +1,8 @@ +Local = true +Cloud = false + +[GOOS] + windows = false + +[EnvMatrix] + DATABRICKS_BUNDLE_ENGINE = ["terraform", "direct"] diff --git a/acceptance/cmd/experimental/postgres/query/argument-errors/output.txt b/acceptance/cmd/experimental/postgres/query/argument-errors/output.txt new file mode 100644 index 00000000000..59ddbfedc6e --- /dev/null +++ b/acceptance/cmd/experimental/postgres/query/argument-errors/output.txt @@ -0,0 +1,40 @@ + +=== No SQL argument should error: +>>> musterr [CLI] experimental postgres query --target projects/foo +Error: accepts 1 arg(s), received 0 + +=== Empty SQL should error: +>>> musterr [CLI] experimental postgres query --target projects/foo +Error: no SQL provided + +=== Neither targeting form should error: +>>> musterr [CLI] experimental postgres query SELECT 1 +Error: must specify --target or --project + +=== Both --target and --project should error: +>>> musterr [CLI] experimental postgres query --target projects/foo --project foo SELECT 1 +Error: if any flags in the group [target project] are set none of the others can be; [project target] were all set + +=== Both --target and --branch should error: +>>> musterr [CLI] experimental postgres query --target projects/foo --branch 
main SELECT 1 +Error: if any flags in the group [target branch] are set none of the others can be; [branch target] were all set + +=== Branch without project should error: +>>> musterr [CLI] experimental postgres query --branch main SELECT 1 +Error: --project is required when using --branch or --endpoint + +=== Endpoint without project should error: +>>> musterr [CLI] experimental postgres query --endpoint primary SELECT 1 +Error: --project is required when using --branch or --endpoint + +=== Provisioned-shaped target should error pointing at psql: +>>> musterr [CLI] experimental postgres query --target my-instance SELECT 1 +Error: provisioned instances are not yet supported by this experimental command; use 'databricks psql ' for now + +=== Malformed autoscaling path should error: +>>> musterr [CLI] experimental postgres query --target projects/ SELECT 1 +Error: invalid resource path: missing project ID + +=== Trailing components after endpoint should error: +>>> musterr [CLI] experimental postgres query --target projects/foo/branches/bar/endpoints/baz/extra SELECT 1 +Error: invalid resource path: trailing components after endpoint: projects/foo/branches/bar/endpoints/baz/extra diff --git a/acceptance/cmd/experimental/postgres/query/argument-errors/script b/acceptance/cmd/experimental/postgres/query/argument-errors/script new file mode 100644 index 00000000000..5874c843a03 --- /dev/null +++ b/acceptance/cmd/experimental/postgres/query/argument-errors/script @@ -0,0 +1,29 @@ +title "No SQL argument should error:" +trace musterr $CLI experimental postgres query --target projects/foo + +title "Empty SQL should error:" +trace musterr $CLI experimental postgres query --target projects/foo " " + +title "Neither targeting form should error:" +trace musterr $CLI experimental postgres query "SELECT 1" + +title "Both --target and --project should error:" +trace musterr $CLI experimental postgres query --target projects/foo --project foo "SELECT 1" + +title "Both --target 
and --branch should error:" +trace musterr $CLI experimental postgres query --target projects/foo --branch main "SELECT 1" + +title "Branch without project should error:" +trace musterr $CLI experimental postgres query --branch main "SELECT 1" + +title "Endpoint without project should error:" +trace musterr $CLI experimental postgres query --endpoint primary "SELECT 1" + +title "Provisioned-shaped target should error pointing at psql:" +trace musterr $CLI experimental postgres query --target my-instance "SELECT 1" + +title "Malformed autoscaling path should error:" +trace musterr $CLI experimental postgres query --target projects/ "SELECT 1" + +title "Trailing components after endpoint should error:" +trace musterr $CLI experimental postgres query --target projects/foo/branches/bar/endpoints/baz/extra "SELECT 1" diff --git a/acceptance/cmd/experimental/postgres/query/argument-errors/test.toml b/acceptance/cmd/experimental/postgres/query/argument-errors/test.toml new file mode 100644 index 00000000000..3371f08de12 --- /dev/null +++ b/acceptance/cmd/experimental/postgres/query/argument-errors/test.toml @@ -0,0 +1,3 @@ +# Argument validation runs before any SDK call. No mocked HTTP responses are +# needed; CLI either errors at flag-parse time or at our own validate function. +GOOS.windows = false diff --git a/cmd/experimental/experimental.go b/cmd/experimental/experimental.go index 36ad8765898..52c6bac79b0 100644 --- a/cmd/experimental/experimental.go +++ b/cmd/experimental/experimental.go @@ -2,6 +2,7 @@ package experimental import ( aitoolscmd "github.com/databricks/cli/experimental/aitools/cmd" + postgrescmd "github.com/databricks/cli/experimental/postgres/cmd" "github.com/spf13/cobra" ) @@ -21,6 +22,7 @@ development. 
They may change or be removed in future versions without notice.`, } cmd.AddCommand(aitoolscmd.NewAitoolsCmd()) + cmd.AddCommand(postgrescmd.New()) cmd.AddCommand(newWorkspaceOpenCommand()) return cmd diff --git a/experimental/postgres/cmd/cmd.go b/experimental/postgres/cmd/cmd.go new file mode 100644 index 00000000000..8db7b46be86 --- /dev/null +++ b/experimental/postgres/cmd/cmd.go @@ -0,0 +1,25 @@ +// Package postgrescmd registers the `databricks experimental postgres ...` +// command tree. The current sub-tree provides `query`, a scriptable SQL +// runner against any Lakebase Postgres endpoint that does not require a +// system `psql` binary. +package postgrescmd + +import ( + "github.com/spf13/cobra" +) + +// New returns the root `postgres` experimental command. It is hidden by its +// experimental parent; the command itself is always visible once one of its +// subcommands is reached. +func New() *cobra.Command { + cmd := &cobra.Command{ + Use: "postgres", + Short: "Experimental Lakebase Postgres commands", + Long: `Experimental commands for interacting with Lakebase Postgres endpoints. + +These commands are still under development and may change without notice.`, + } + + cmd.AddCommand(newQueryCmd()) + return cmd +} diff --git a/experimental/postgres/cmd/connect.go b/experimental/postgres/cmd/connect.go new file mode 100644 index 00000000000..a0674b81ead --- /dev/null +++ b/experimental/postgres/cmd/connect.go @@ -0,0 +1,147 @@ +package postgrescmd + +import ( + "context" + "errors" + "fmt" + "net" + "time" + + "github.com/databricks/cli/libs/cmdio" + "github.com/databricks/cli/libs/log" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" +) + +// defaultConnectTimeout is the dial timeout for a single connect attempt. +// Lakebase autoscaling endpoints can be cold-starting; Postgres' own dial +// keeps trying within this window before giving up. 
+const defaultConnectTimeout = 120 * time.Second + +// connectConfig collects everything pgx needs to dial Postgres. Kept as a +// struct rather than passed through positional args because the pgx config +// has many fields and the call sites differ between code paths (production +// vs unit tests stubbing connectFunc). +type connectConfig struct { + Host string + Port int + Username string + Password string + Database string + ConnectTimeout time.Duration +} + +// retryConfig controls connect retry on idle/waking endpoints. MaxAttempts is +// the total number of attempts: 1 means no retry, 3 means up to two retries +// with backoff between. We use the count-of-attempts reading rather than +// count-of-retries to match libs/psql.RetryConfig.MaxRetries semantics, so +// behavior stays consistent across the two commands sharing a flag name. +type retryConfig struct { + MaxAttempts int + InitialDelay time.Duration + MaxDelay time.Duration +} + +// connectFunc is a seam for unit tests: production wires pgx.ConnectConfig, +// tests inject failures (DNS, auth, ctx-cancel mid-connect). We deliberately +// do not wrap *pgx.Conn behind an interface for query execution; that surface +// is exercised by integration tests against real Lakebase endpoints. +type connectFunc func(ctx context.Context, cfg *pgx.ConnConfig) (*pgx.Conn, error) + +// buildPgxConfig parses a base DSN to inherit pgx's TLS shape, then patches +// in the resolved values. The DSN-then-patch pattern is the recommended way +// to configure pgx for `sslmode=require` because building a pgx.ConnConfig +// by hand omits internal fields that the parser sets. 
+func buildPgxConfig(c connectConfig) (*pgx.ConnConfig, error) { + cfg, err := pgx.ParseConfig("postgresql:///?sslmode=require") + if err != nil { + return nil, fmt.Errorf("parse pgx config: %w", err) + } + cfg.Host = c.Host + cfg.Port = uint16(c.Port) + cfg.User = c.Username + cfg.Password = c.Password + cfg.Database = c.Database + cfg.ConnectTimeout = c.ConnectTimeout + return cfg, nil +} + +// connectWithRetry dials Postgres, retrying on connect-time errors that +// indicate the endpoint is asleep or in the middle of a wake-up. Errors that +// cannot be improved by retrying (auth failures, permission errors, +// post-query errors) are returned immediately. +func connectWithRetry(ctx context.Context, cfg *pgx.ConnConfig, rc retryConfig, dial connectFunc) (*pgx.Conn, error) { + if rc.MaxAttempts < 1 { + rc.MaxAttempts = 1 + } + + delay := rc.InitialDelay + var lastErr error + + for attempt := 1; attempt <= rc.MaxAttempts; attempt++ { + if attempt > 1 { + cmdio.LogString(ctx, fmt.Sprintf("Connection attempt %d/%d failed, retrying in %v...", attempt-1, rc.MaxAttempts, delay)) + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-time.After(delay): + } + if rc.MaxDelay > 0 { + delay = min(delay*2, rc.MaxDelay) + } + } + + conn, err := dial(ctx, cfg) + if err == nil { + return conn, nil + } + lastErr = err + + if !isRetryableConnectError(err) { + return nil, err + } + log.Debugf(ctx, "retryable connect error on attempt %d: %v", attempt, err) + } + + return nil, fmt.Errorf("failed to connect after %d attempts: %w", rc.MaxAttempts, lastErr) +} + +// isRetryableConnectError classifies whether an error from the connect path +// is a transient "endpoint asleep / cold-starting" failure. +// +// Retryable: +// - net.OpError with Op == "dial" (DNS resolution, TCP connect refused, +// host unreachable). The "endpoint asleep" cases. +// - pgconn.ConnectError that wraps a retryable network error. +// - Postgres connection-establishment SQLSTATE codes (08xxx). 
Lakebase +// emits these during cold-start. +// +// Not retryable: auth errors (28xxx), permission errors (42501), +// context cancellation/deadlines, anything after Query has been issued +// (caller never passes that to us; we only run before Query). +func isRetryableConnectError(err error) bool { + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return false + } + + var pgErr *pgconn.PgError + if errors.As(err, &pgErr) { + // 08xxx is the connection_exception class. + if len(pgErr.Code) == 5 && pgErr.Code[:2] == "08" { + return true + } + return false + } + + var connectErr *pgconn.ConnectError + if errors.As(err, &connectErr) { + return isRetryableConnectError(connectErr.Unwrap()) + } + + var opErr *net.OpError + if errors.As(err, &opErr) { + return opErr.Op == "dial" + } + + return false +} diff --git a/experimental/postgres/cmd/connect_test.go b/experimental/postgres/cmd/connect_test.go new file mode 100644 index 00000000000..0f7614b1f31 --- /dev/null +++ b/experimental/postgres/cmd/connect_test.go @@ -0,0 +1,149 @@ +package postgrescmd + +import ( + "context" + "errors" + "net" + "testing" + "time" + + "github.com/databricks/cli/libs/cmdio" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func testCtx(t *testing.T) context.Context { + return cmdio.MockDiscard(t.Context()) +} + +func TestIsRetryableConnectError(t *testing.T) { + tests := []struct { + name string + err error + want bool + }{ + { + name: "dial error", + err: &net.OpError{Op: "dial", Err: errors.New("connection refused")}, + want: true, + }, + { + name: "non-dial net.OpError", + err: &net.OpError{Op: "read", Err: errors.New("oops")}, + want: false, + }, + { + name: "08006 connection failure", + err: &pgconn.PgError{Code: "08006", Message: "connection failure"}, + want: true, + }, + { + name: "08001 cannot establish", + err: &pgconn.PgError{Code: "08001", Message: 
"sqlclient unable to establish sqlconnection"}, + want: true, + }, + { + name: "28000 invalid auth", + err: &pgconn.PgError{Code: "28000", Message: "invalid authorization specification"}, + want: false, + }, + { + name: "28P01 invalid password", + err: &pgconn.PgError{Code: "28P01", Message: "invalid password"}, + want: false, + }, + { + name: "42501 insufficient privilege", + err: &pgconn.PgError{Code: "42501", Message: "permission denied"}, + want: false, + }, + { + name: "context cancelled", + err: context.Canceled, + want: false, + }, + { + name: "context deadline exceeded", + err: context.DeadlineExceeded, + want: false, + }, + { + name: "nil error never retryable", + err: nil, + want: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + assert.Equal(t, tc.want, isRetryableConnectError(tc.err)) + }) + } +} + +func TestConnectWithRetry_RespectsMaxAttempts(t *testing.T) { + ctx := testCtx(t) + calls := 0 + dialErr := &pgconn.PgError{Code: "08006"} + dial := func(ctx context.Context, cfg *pgx.ConnConfig) (*pgx.Conn, error) { + calls++ + return nil, dialErr + } + cfg := &pgx.ConnConfig{} + rc := retryConfig{MaxAttempts: 3, InitialDelay: 0, MaxDelay: 0} + + _, err := connectWithRetry(ctx, cfg, rc, dial) + require.Error(t, err) + assert.Equal(t, 3, calls, "expected 3 attempts (1 initial + 2 retries)") +} + +func TestConnectWithRetry_StopsOnNonRetryable(t *testing.T) { + ctx := testCtx(t) + calls := 0 + authErr := &pgconn.PgError{Code: "28P01"} + dial := func(ctx context.Context, cfg *pgx.ConnConfig) (*pgx.Conn, error) { + calls++ + return nil, authErr + } + cfg := &pgx.ConnConfig{} + rc := retryConfig{MaxAttempts: 3, InitialDelay: 0} + + _, err := connectWithRetry(ctx, cfg, rc, dial) + require.Error(t, err) + assert.Equal(t, 1, calls, "auth errors should not retry") +} + +func TestConnectWithRetry_ZeroMaxAttemptsTreatedAsOne(t *testing.T) { + ctx := testCtx(t) + calls := 0 + dial := func(ctx context.Context, cfg *pgx.ConnConfig) 
(*pgx.Conn, error) { + calls++ + return nil, errors.New("nope") + } + cfg := &pgx.ConnConfig{} + rc := retryConfig{MaxAttempts: 0, InitialDelay: time.Millisecond} + + _, err := connectWithRetry(ctx, cfg, rc, dial) + require.Error(t, err) + assert.Equal(t, 1, calls) +} + +func TestBuildPgxConfig(t *testing.T) { + cfg, err := buildPgxConfig(connectConfig{ + Host: "host.example.com", + Port: 5432, + Username: "user", + Password: "secret", + Database: "db", + ConnectTimeout: 30 * time.Second, + }) + require.NoError(t, err) + assert.Equal(t, "host.example.com", cfg.Host) + assert.Equal(t, uint16(5432), cfg.Port) + assert.Equal(t, "user", cfg.User) + assert.Equal(t, "secret", cfg.Password) + assert.Equal(t, "db", cfg.Database) + assert.Equal(t, 30*time.Second, cfg.ConnectTimeout) +} diff --git a/experimental/postgres/cmd/execute.go b/experimental/postgres/cmd/execute.go new file mode 100644 index 00000000000..c29f7ce59d6 --- /dev/null +++ b/experimental/postgres/cmd/execute.go @@ -0,0 +1,62 @@ +package postgrescmd + +import ( + "context" + "fmt" + + "github.com/jackc/pgx/v5" +) + +// executeOne runs a single SQL statement against an open connection and +// captures the result in a queryResult. +// +// We pass QueryExecModeExec explicitly (not the pgx default +// QueryExecModeCacheStatement) for two reasons: +// +// 1. Statement caching has no benefit for a one-shot CLI: the connection is +// closed at the end of the command, so the cached prepared statement +// never gets reused. +// 2. Exec mode uses Postgres' extended-protocol "exec" path with text-format +// result columns. That makes rendering canonical-Postgres-text output +// (PR 1) and CSV (later PR) straightforward; the cache mode defaults to +// binary and we'd be reformatting back to text. +// +// QueryExecModeExec still uses extended protocol with a single statement and +// no implicit transaction wrap, so transaction-disallowed DDL like +// `CREATE DATABASE` works. 
+func executeOne(ctx context.Context, conn *pgx.Conn, sql string) (*queryResult, error) { + rows, err := conn.Query(ctx, sql, pgx.QueryExecModeExec) + if err != nil { + return nil, fmt.Errorf("query failed: %w", err) + } + defer rows.Close() + + result := &queryResult{SQL: sql} + + fields := rows.FieldDescriptions() + if len(fields) > 0 { + result.Columns = make([]string, len(fields)) + for i, f := range fields { + result.Columns[i] = f.Name + } + } + + for rows.Next() { + raw := rows.RawValues() + row := make([]string, len(raw)) + for i, b := range raw { + if b == nil { + row[i] = "NULL" + continue + } + row[i] = string(b) + } + result.Rows = append(result.Rows, row) + } + if err := rows.Err(); err != nil { + return nil, fmt.Errorf("query failed: %w", err) + } + + result.CommandTag = rows.CommandTag().String() + return result, nil +} diff --git a/experimental/postgres/cmd/query.go b/experimental/postgres/cmd/query.go new file mode 100644 index 00000000000..643aa496e84 --- /dev/null +++ b/experimental/postgres/cmd/query.go @@ -0,0 +1,133 @@ +package postgrescmd + +import ( + "context" + "errors" + "fmt" + "strings" + "time" + + "github.com/databricks/cli/cmd/root" + "github.com/databricks/cli/libs/cmdio" + "github.com/jackc/pgx/v5" + "github.com/spf13/cobra" +) + +// defaultDatabase is the database name used when --database is not set. +// Lakebase Autoscaling and Provisioned both use this name as their default. +const defaultDatabase = "databricks_postgres" + +// queryFlags is the union of every flag the query command exposes. Lifted +// out of newQueryCmd so unit-tested helpers (resolveTarget, etc.) can take +// it directly without poking at cobra internals. 
+type queryFlags struct { + targetingFlags + database string + connectTimeout time.Duration + maxRetries int +} + +func newQueryCmd() *cobra.Command { + var f queryFlags + + cmd := &cobra.Command{ + Use: "query [SQL]", + Short: "Run a SQL statement against a Lakebase Postgres endpoint", + GroupID: "", + Long: `Execute a single SQL statement against a Lakebase Postgres endpoint and +render the result as text. + +Targeting (exactly one form required): + --target STRING Autoscaling resource path + (e.g. projects/foo/branches/main/endpoints/primary) + --project ID Autoscaling project ID + --branch ID Autoscaling branch ID (default: auto-select if exactly one) + --endpoint ID Autoscaling endpoint ID (default: auto-select if exactly one) + +This is an experimental command. The flag set, output shape, and supported +target kinds will expand in subsequent releases. + +Limitations (this release): + + - Single SQL statement per invocation (multi-statement support comes later). + - Text output only. JSON and CSV output come in a follow-up release. + - Only Lakebase Autoscaling endpoints are supported. Provisioned instance + support comes in a follow-up release; use 'databricks psql ' as a + workaround for now. + - No interactive REPL. 'databricks psql' continues to own that surface. + - Multi-statement strings (e.g. "SELECT 1; SELECT 2") are not supported. + - The OAuth token is generated once per invocation and is valid for 1h. + Queries longer than that fail with an auth error. +`, + Args: cobra.ExactArgs(1), + PreRunE: root.MustWorkspaceClient, + RunE: func(cmd *cobra.Command, args []string) error { + return runQuery(cmd.Context(), cmd, args[0], f) + }, + } + + cmd.Flags().StringVar(&f.target, "target", "", "Autoscaling resource path (e.g. 
projects/foo/branches/main/endpoints/primary)") + cmd.Flags().StringVar(&f.project, "project", "", "Autoscaling project ID") + cmd.Flags().StringVar(&f.branch, "branch", "", "Autoscaling branch ID (default: auto-select if exactly one)") + cmd.Flags().StringVar(&f.endpoint, "endpoint", "", "Autoscaling endpoint ID (default: auto-select if exactly one)") + cmd.Flags().StringVarP(&f.database, "database", "d", defaultDatabase, "Database name") + cmd.Flags().DurationVar(&f.connectTimeout, "connect-timeout", defaultConnectTimeout, "Connect timeout") + cmd.Flags().IntVar(&f.maxRetries, "max-retries", 3, "Total connect attempts on idle/waking endpoint (1 disables retry)") + + cmd.MarkFlagsMutuallyExclusive("target", "project") + cmd.MarkFlagsMutuallyExclusive("target", "branch") + cmd.MarkFlagsMutuallyExclusive("target", "endpoint") + + return cmd +} + +// runQuery is the production entry point. It is split out from RunE so unit +// tests can call it directly with a stubbed connectFunc once we add seam-based +// tests in a later PR. 
+func runQuery(ctx context.Context, cmd *cobra.Command, sql string, f queryFlags) error { + sql = strings.TrimSpace(sql) + if sql == "" { + return errors.New("no SQL provided") + } + if err := validateTargeting(f.targetingFlags); err != nil { + return err + } + + resolved, err := resolveTarget(ctx, f.targetingFlags) + if err != nil { + return err + } + + cmdio.LogString(ctx, fmt.Sprintf("Connecting to %s...", resolved.DisplayName)) + + pgxCfg, err := buildPgxConfig(connectConfig{ + Host: resolved.Host, + Port: 5432, + Username: resolved.Username, + Password: resolved.Token, + Database: f.database, + ConnectTimeout: f.connectTimeout, + }) + if err != nil { + return err + } + + rc := retryConfig{ + MaxAttempts: max(1, f.maxRetries), + InitialDelay: time.Second, + MaxDelay: 10 * time.Second, + } + + conn, err := connectWithRetry(ctx, pgxCfg, rc, pgx.ConnectConfig) + if err != nil { + return err + } + defer conn.Close(context.WithoutCancel(ctx)) + + result, err := executeOne(ctx, conn, sql) + if err != nil { + return err + } + + return renderText(cmd.OutOrStdout(), result) +} diff --git a/experimental/postgres/cmd/render.go b/experimental/postgres/cmd/render.go new file mode 100644 index 00000000000..ff923c4a92e --- /dev/null +++ b/experimental/postgres/cmd/render.go @@ -0,0 +1,74 @@ +package postgrescmd + +import ( + "fmt" + "io" + "strings" + "text/tabwriter" +) + +// queryResult is the rendered shape of a single SQL execution. PR 1 only +// renders text; later PRs add JSON and CSV against the same struct. +// +// columns is empty for command-only statements (INSERT, CREATE DATABASE, ...); +// rows is empty when no rows were returned (or for command-only statements). +type queryResult struct { + SQL string + // CommandTag is the Postgres command tag for the statement (e.g. "INSERT 0 5", + // "CREATE DATABASE"). Always set; used for command-only statements and as a + // trailer for rows-producing ones. 
+ CommandTag string + Columns []string + Rows [][]string +} + +// IsRowsProducing reports whether the statement returned a row description. +// Determined at runtime via FieldDescriptions() rather than by parsing the +// leading SQL keyword: `INSERT ... RETURNING` and CTEs ending in a SELECT are +// rows-producing despite their leading keywords. +func (r *queryResult) IsRowsProducing() bool { + return len(r.Columns) > 0 +} + +// renderText writes a result in plain text. +// +// For rows-producing statements we use a tabwriter-aligned table followed by +// a `(N rows)` footer, mimicking psql's compact text shape. For command-only +// statements we just print the command tag. +// +// PR 1 always uses the static (buffered) shape. The interactive table viewer +// for >30 rows lands in a later PR alongside the multi-input output shapes. +func renderText(out io.Writer, r *queryResult) error { + if !r.IsRowsProducing() { + _, err := fmt.Fprintln(out, r.CommandTag) + return err + } + + tw := tabwriter.NewWriter(out, 0, 0, 2, ' ', 0) + fmt.Fprintln(tw, strings.Join(r.Columns, "\t")) + fmt.Fprintln(tw, strings.Join(headerSeparator(r.Columns), "\t")) + for _, row := range r.Rows { + fmt.Fprintln(tw, strings.Join(row, "\t")) + } + if err := tw.Flush(); err != nil { + return err + } + + _, err := fmt.Fprintf(out, "(%d %s)\n", len(r.Rows), pluralize(len(r.Rows), "row", "rows")) + return err +} + +func headerSeparator(cols []string) []string { + out := make([]string, len(cols)) + for i, c := range cols { + out[i] = strings.Repeat("-", max(len(c), 3)) + } + return out +} + +func pluralize(n int, singular, plural string) string { + if n == 1 { + return singular + } + return plural +} diff --git a/experimental/postgres/cmd/render_test.go b/experimental/postgres/cmd/render_test.go new file mode 100644 index 00000000000..29aeb3c36fc --- /dev/null +++ b/experimental/postgres/cmd/render_test.go @@ -0,0 +1,67 @@ +package postgrescmd + +import ( + "bytes" + "testing" + + 
"github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestRenderText_RowsProducing(t *testing.T) { + r := &queryResult{ + Columns: []string{"id", "name"}, + Rows: [][]string{ + {"1", "alice"}, + {"2", "bob"}, + }, + CommandTag: "SELECT 2", + } + var buf bytes.Buffer + require.NoError(t, renderText(&buf, r)) + + assert.Equal(t, + "id name\n"+ + "--- ----\n"+ + "1 alice\n"+ + "2 bob\n"+ + "(2 rows)\n", + buf.String(), + ) +} + +func TestRenderText_SingleRow(t *testing.T) { + r := &queryResult{ + Columns: []string{"id"}, + Rows: [][]string{{"42"}}, + CommandTag: "SELECT 1", + } + var buf bytes.Buffer + require.NoError(t, renderText(&buf, r)) + assert.Contains(t, buf.String(), "(1 row)\n") +} + +func TestRenderText_Empty(t *testing.T) { + r := &queryResult{ + Columns: []string{"id", "name"}, + CommandTag: "SELECT 0", + } + var buf bytes.Buffer + require.NoError(t, renderText(&buf, r)) + assert.Contains(t, buf.String(), "(0 rows)\n") +} + +func TestRenderText_CommandOnly(t *testing.T) { + r := &queryResult{ + CommandTag: "INSERT 0 5", + } + var buf bytes.Buffer + require.NoError(t, renderText(&buf, r)) + assert.Equal(t, "INSERT 0 5\n", buf.String()) +} + +func TestQueryResultIsRowsProducing(t *testing.T) { + assert.False(t, (&queryResult{}).IsRowsProducing()) + assert.False(t, (&queryResult{CommandTag: "INSERT 0 1"}).IsRowsProducing()) + assert.True(t, (&queryResult{Columns: []string{"a"}}).IsRowsProducing()) +} diff --git a/experimental/postgres/cmd/targeting.go b/experimental/postgres/cmd/targeting.go new file mode 100644 index 00000000000..e8a17fadfce --- /dev/null +++ b/experimental/postgres/cmd/targeting.go @@ -0,0 +1,173 @@ +package postgrescmd + +import ( + "context" + "errors" + "fmt" + + "github.com/databricks/cli/libs/cmdctx" + "github.com/databricks/cli/libs/lakebase/target" + "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/service/postgres" +) + +// resolvedTarget carries everything the 
query command needs to dial Postgres: +// the endpoint host (resolved through the SDK) and a short-lived OAuth token. +// `kind` records whether we resolved an autoscaling endpoint or a provisioned +// instance, so the caller can pick the right default database name and emit +// kind-appropriate logging. +type resolvedTarget struct { + Kind targetKind + Host string + Username string + Token string + // Display strings used only for human-readable logs / errors. + DisplayName string +} + +type targetKind int + +const ( + kindAutoscaling targetKind = iota + kindProvisioned +) + +// targetingFlags is the user-supplied targeting input. Exactly one of: +// - target (full path or instance name) +// - project (with optional branch and endpoint) +// +// must be set. Validated by validateTargeting before any SDK call. +type targetingFlags struct { + target string + project string + branch string + endpoint string +} + +func (f targetingFlags) hasGranular() bool { + return f.project != "" || f.branch != "" || f.endpoint != "" +} + +// validateTargeting enforces "exactly one targeting form" before any SDK call. +// Returns a typed error so the JSON envelope renderer (added in a later PR) +// can surface a structured error. +func validateTargeting(f targetingFlags) error { + switch { + case f.target == "" && !f.hasGranular(): + return errors.New("must specify --target or --project") + case f.target != "" && f.hasGranular(): + return errors.New("--target is mutually exclusive with --project, --branch, --endpoint") + case f.target == "" && f.project == "" && (f.branch != "" || f.endpoint != ""): + return errors.New("--project is required when using --branch or --endpoint") + } + return nil +} + +// resolveTarget translates the validated flags into a resolvedTarget. +// PR 1 supports autoscaling targeting only; provisioned support is added in +// the next PR. A provisioned-shaped --target returns a clear error pointing at +// the experimental status. 
+func resolveTarget(ctx context.Context, f targetingFlags) (*resolvedTarget, error) { + w := cmdctx.WorkspaceClient(ctx) + + switch { + case f.target != "" && target.IsAutoscalingPath(f.target): + spec, err := target.ParseAutoscalingPath(f.target) + if err != nil { + return nil, err + } + return resolveAutoscaling(ctx, w, spec) + + case f.target != "": + // Provisioned-shaped target. Out of scope for this PR; will be wired in + // the follow-up PR alongside JSON/CSV output. + return nil, errors.New("provisioned instances are not yet supported by this experimental command; use 'databricks psql ' for now") + + default: + spec := target.AutoscalingSpec{ + ProjectID: f.project, + BranchID: f.branch, + EndpointID: f.endpoint, + } + return resolveAutoscaling(ctx, w, spec) + } +} + +// resolveAutoscaling expands a partial spec into a fully-resolved endpoint and +// issues a short-lived OAuth token. Missing branch/endpoint IDs are +// auto-selected when exactly one candidate exists; ambiguity propagates as an +// AmbiguousError with the list of choices. 
+func resolveAutoscaling(ctx context.Context, w *databricks.WorkspaceClient, spec target.AutoscalingSpec) (*resolvedTarget, error) { + if spec.ProjectID == "" { + var err error + spec.ProjectID, err = target.AutoSelectProject(ctx, w) + if err != nil { + return nil, err + } + } + + project, err := target.GetProject(ctx, w, spec.ProjectID) + if err != nil { + return nil, fmt.Errorf("failed to get project: %w", err) + } + + if spec.BranchID == "" { + spec.BranchID, err = target.AutoSelectBranch(ctx, w, project.Name) + if err != nil { + return nil, err + } + } + + if spec.EndpointID == "" { + branchName := project.Name + "/branches/" + spec.BranchID + spec.EndpointID, err = target.AutoSelectEndpoint(ctx, w, branchName) + if err != nil { + return nil, err + } + } + + endpoint, err := target.GetEndpoint(ctx, w, spec.ProjectID, spec.BranchID, spec.EndpointID) + if err != nil { + return nil, fmt.Errorf("failed to get endpoint: %w", err) + } + + if err := checkEndpointReady(endpoint); err != nil { + return nil, err + } + + user, err := w.CurrentUser.Me(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get current user: %w", err) + } + + token, err := target.AutoscalingCredential(ctx, w, endpoint.Name) + if err != nil { + return nil, err + } + + return &resolvedTarget{ + Kind: kindAutoscaling, + Host: endpoint.Status.Hosts.Host, + Username: user.UserName, + Token: token, + DisplayName: endpoint.Name, + }, nil +} + +// checkEndpointReady returns an error if the endpoint is not in a connectable +// state. Idle endpoints are considered connectable (Lakebase wakes them on +// dial); the connect retry loop handles the wake-up window. 
+func checkEndpointReady(endpoint *postgres.Endpoint) error { + if endpoint.Status == nil { + return errors.New("endpoint status is not available") + } + if endpoint.Status.Hosts == nil || endpoint.Status.Hosts.Host == "" { + return errors.New("endpoint host information is not available") + } + switch endpoint.Status.CurrentState { + case postgres.EndpointStatusStateActive, postgres.EndpointStatusStateIdle: + return nil + default: + return fmt.Errorf("endpoint is not ready for accepting connections (state: %s)", endpoint.Status.CurrentState) + } +} diff --git a/experimental/postgres/cmd/targeting_test.go b/experimental/postgres/cmd/targeting_test.go new file mode 100644 index 00000000000..dfdab0d405c --- /dev/null +++ b/experimental/postgres/cmd/targeting_test.go @@ -0,0 +1,81 @@ +package postgrescmd + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestValidateTargeting(t *testing.T) { + tests := []struct { + name string + flags targetingFlags + wantErr string + }{ + { + name: "neither form", + flags: targetingFlags{}, + wantErr: "must specify --target or --project", + }, + { + name: "only target", + flags: targetingFlags{ + target: "projects/foo", + }, + }, + { + name: "only project", + flags: targetingFlags{ + project: "foo", + }, + }, + { + name: "project and branch", + flags: targetingFlags{ + project: "foo", + branch: "main", + }, + }, + { + name: "project, branch, endpoint", + flags: targetingFlags{ + project: "foo", + branch: "main", + endpoint: "primary", + }, + }, + { + name: "target and project both set", + flags: targetingFlags{ + target: "projects/foo", + project: "foo", + }, + wantErr: "mutually exclusive", + }, + { + name: "branch without project", + flags: targetingFlags{ + branch: "main", + }, + wantErr: "--project is required when using --branch or --endpoint", + }, + { + name: "endpoint without project", + flags: targetingFlags{ + endpoint: "primary", + }, + wantErr: "--project is required when using --branch or 
--endpoint", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := validateTargeting(tc.flags) + if tc.wantErr != "" { + assert.ErrorContains(t, err, tc.wantErr) + return + } + assert.NoError(t, err) + }) + } +} diff --git a/go.mod b/go.mod index f376aa0a98d..170414de39f 100644 --- a/go.mod +++ b/go.mod @@ -23,6 +23,7 @@ require ( github.com/hashicorp/terraform-exec v0.25.0 // MPL-2.0 github.com/hashicorp/terraform-json v0.27.2 // MPL-2.0 github.com/hexops/gotextdiff v1.0.3 // BSD-3-Clause + github.com/jackc/pgx/v5 v5.9.1 // MIT github.com/manifoldco/promptui v0.9.0 // BSD-3-Clause github.com/mattn/go-isatty v0.0.20 // MIT github.com/nwidger/jsoncolor v0.3.2 // MIT @@ -80,6 +81,8 @@ require ( github.com/hashicorp/go-cleanhttp v0.5.2 // indirect github.com/hashicorp/go-retryablehttp v0.7.8 // indirect github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/lucasb-eyer/go-colorful v1.3.0 // indirect github.com/mattn/go-colorable v0.1.14 // indirect github.com/mattn/go-localereader v0.0.1 // indirect diff --git a/go.sum b/go.sum index f9181b898a2..715807887cd 100644 --- a/go.sum +++ b/go.sum @@ -147,6 +147,14 @@ github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUq github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg= github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo= +github.com/jackc/pgservicefile 
v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.9.1 h1:uwrxJXBnx76nyISkhr33kQLlUqjv7et7b9FjCen/tdc= +github.com/jackc/pgx/v5 v5.9.1/go.mod h1:mal1tBGAFfLHvZzaYh77YS/eC6IX9OWbRV1QIIM0Jn4= +github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= +github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 h1:BQSFePA1RWJOlocH6Fxy8MmwDt+yVQYULKfN0RoTN8A= github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99/go.mod h1:1lJo3i6rXxKeerYnT8Nvf0QmHCRC1n8sfWVwXF2Frvo= github.com/kevinburke/ssh_config v1.2.0 h1:x584FjTGwHzMwvHx18PXxbBVzfnxogHaAReU4gf13a4= @@ -213,7 +221,9 @@ github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSS github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= From edbd6bef6bacee89210398b51c3e37a8582ec3fd Mon Sep 17 00:00:00 2001 From: simon Date: Thu, 30 Apr 2026 10:45:56 +0200 Subject: [PATCH 03/15] Address review feedback on PR 1 - Replace exported PathSegmentProjects/ExtractID with focused helpers (ProjectIDFromName/BranchIDFromName/EndpointIDFromName); keeps SDK literals out of call sites. 
- Type AmbiguousError.Kind as a typed enum (KindProject/Branch/Endpoint/Instance) so producers and the pluralisation switch stay in sync. - Stop setting Choice.DisplayName when it equals the ID; Error() relies on empty-suppression rather than mixed empty/equal-to-ID checks. - Add 57P03 (cannot_connect_now) to the connect-retry allow-list. Postgres emits this during server startup and Lakebase autoscaling can plausibly return it during the wake-up handshake. Tests exercise 57P03/57P01/57014 to lock the boundary. - Require --branch when --endpoint is set. The auto-select-then-look-up flow produces confusing errors when the auto-selected branch does not contain the requested endpoint, and this command is non-interactive so asking the user to be explicit is friendlier. - Reject --max-retries < 1 explicitly instead of silently clamping. Help text already advertised the constraint; matching it at validation time is consistent with the repo's "reject incompatible inputs early" rule. - Harmonise the "endpoint is not ready" error in cmd/psql to include the state, matching the experimental command and giving operators something to act on. - Restore comments removed during the cmd/psql refactor and add a breadcrumb at the GetProvisioned call site about the Name patch. - Add doc comments to AutoSelect* helpers documenting the returned string shape (trailing ID for autoscaling vs full name for provisioned). - Reject trailing components after endpoint in ParseAutoscalingPath; new acceptance test in cmd/psql exercises this. - Drop dead GroupID: "" assignment. 
Co-authored-by: Isaac --- .../postgres/query/argument-errors/output.txt | 8 +++ .../postgres/query/argument-errors/script | 6 ++ .../cmd/psql/argument-errors/output.txt | 4 ++ acceptance/cmd/psql/argument-errors/script | 3 + acceptance/cmd/psql/postgres/output.txt | 2 +- cmd/psql/psql.go | 4 +- cmd/psql/psql_autoscaling.go | 2 +- cmd/psql/psql_provisioned.go | 3 + experimental/postgres/cmd/connect.go | 14 ++-- experimental/postgres/cmd/connect_test.go | 30 ++++---- experimental/postgres/cmd/query.go | 12 ++-- experimental/postgres/cmd/targeting.go | 7 ++ experimental/postgres/cmd/targeting_test.go | 8 +++ libs/lakebase/target/autoscaling.go | 51 +++++++------ libs/lakebase/target/provisioned.go | 10 +-- libs/lakebase/target/target.go | 72 ++++++++++++++----- libs/lakebase/target/target_test.go | 39 ++++++---- 17 files changed, 188 insertions(+), 87 deletions(-) diff --git a/acceptance/cmd/experimental/postgres/query/argument-errors/output.txt b/acceptance/cmd/experimental/postgres/query/argument-errors/output.txt index 59ddbfedc6e..c071466a1e3 100644 --- a/acceptance/cmd/experimental/postgres/query/argument-errors/output.txt +++ b/acceptance/cmd/experimental/postgres/query/argument-errors/output.txt @@ -27,6 +27,14 @@ Error: --project is required when using --branch or --endpoint >>> musterr [CLI] experimental postgres query --endpoint primary SELECT 1 Error: --project is required when using --branch or --endpoint +=== Endpoint without branch should error: +>>> musterr [CLI] experimental postgres query --project foo --endpoint primary SELECT 1 +Error: --branch is required when using --endpoint + +=== Max-retries 0 should error: +>>> musterr [CLI] experimental postgres query --project foo --branch main --max-retries 0 SELECT 1 +Error: --max-retries must be at least 1; got 0 + === Provisioned-shaped target should error pointing at psql: >>> musterr [CLI] experimental postgres query --target my-instance SELECT 1 Error: provisioned instances are not yet supported by 
this experimental command; use 'databricks psql ' for now diff --git a/acceptance/cmd/experimental/postgres/query/argument-errors/script b/acceptance/cmd/experimental/postgres/query/argument-errors/script index 5874c843a03..8d64bf307ed 100644 --- a/acceptance/cmd/experimental/postgres/query/argument-errors/script +++ b/acceptance/cmd/experimental/postgres/query/argument-errors/script @@ -19,6 +19,12 @@ trace musterr $CLI experimental postgres query --branch main "SELECT 1" title "Endpoint without project should error:" trace musterr $CLI experimental postgres query --endpoint primary "SELECT 1" +title "Endpoint without branch should error:" +trace musterr $CLI experimental postgres query --project foo --endpoint primary "SELECT 1" + +title "Max-retries 0 should error:" +trace musterr $CLI experimental postgres query --project foo --branch main --max-retries 0 "SELECT 1" + title "Provisioned-shaped target should error pointing at psql:" trace musterr $CLI experimental postgres query --target my-instance "SELECT 1" diff --git a/acceptance/cmd/psql/argument-errors/output.txt b/acceptance/cmd/psql/argument-errors/output.txt index 35da5961dec..cbf6c093b21 100644 --- a/acceptance/cmd/psql/argument-errors/output.txt +++ b/acceptance/cmd/psql/argument-errors/output.txt @@ -59,6 +59,10 @@ Error: invalid resource path: missing branch ID >>> musterr [CLI] psql projects/my-project/branches/main/endpoints/ Error: invalid resource path: missing endpoint ID +=== Trailing components after endpoint should error: +>>> musterr [CLI] psql projects/my-project/branches/main/endpoints/primary/extra +Error: invalid resource path: trailing components after endpoint: projects/my-project/branches/main/endpoints/primary/extra + === Provisioned flag with --project should error: >>> musterr [CLI] psql --provisioned --project foo Error: cannot use --project, --branch, or --endpoint flags with --provisioned diff --git a/acceptance/cmd/psql/argument-errors/script 
b/acceptance/cmd/psql/argument-errors/script index 7806efb0744..7db1cdbd271 100644 --- a/acceptance/cmd/psql/argument-errors/script +++ b/acceptance/cmd/psql/argument-errors/script @@ -38,6 +38,9 @@ trace musterr $CLI psql projects/my-project/branches/ title "Invalid path with missing endpoint ID should error:" trace musterr $CLI psql projects/my-project/branches/main/endpoints/ +title "Trailing components after endpoint should error:" +trace musterr $CLI psql projects/my-project/branches/main/endpoints/primary/extra + title "Provisioned flag with --project should error:" trace musterr $CLI psql --provisioned --project foo diff --git a/acceptance/cmd/psql/postgres/output.txt b/acceptance/cmd/psql/postgres/output.txt index 5269553a0ce..8df91c6321c 100644 --- a/acceptance/cmd/psql/postgres/output.txt +++ b/acceptance/cmd/psql/postgres/output.txt @@ -50,7 +50,7 @@ PGSSLMODE=require Project: Init Project Branch: main Endpoint: init-ep -Error: endpoint is not ready for accepting connections +Error: endpoint is not ready for accepting connections (state: INIT) === Branch flag without project should fail: >>> musterr [CLI] psql --branch some-branch diff --git a/cmd/psql/psql.go b/cmd/psql/psql.go index e5cfaff5cff..9be7fb5c5df 100644 --- a/cmd/psql/psql.go +++ b/cmd/psql/psql.go @@ -257,7 +257,7 @@ func showSelectionAndConnect(ctx context.Context, retryConfig libpsql.RetryConfi }) } for _, proj := range projects { - displayName := target.ExtractID(proj.Name, target.PathSegmentProjects) + displayName := target.ProjectIDFromName(proj.Name) if proj.Status != nil && proj.Status.DisplayName != "" { displayName = proj.Status.DisplayName } @@ -278,7 +278,7 @@ func showSelectionAndConnect(ctx context.Context, retryConfig libpsql.RetryConfi } if after, ok := strings.CutPrefix(selected, "autoscaling:"); ok { projectName := after - projectID := target.ExtractID(projectName, target.PathSegmentProjects) + projectID := target.ProjectIDFromName(projectName) return 
connectAutoscaling(ctx, projectID, "", "", retryConfig, extraArgs) } diff --git a/cmd/psql/psql_autoscaling.go b/cmd/psql/psql_autoscaling.go index 4273dad3b50..a4c3293cc18 100644 --- a/cmd/psql/psql_autoscaling.go +++ b/cmd/psql/psql_autoscaling.go @@ -61,7 +61,7 @@ func connectAutoscaling(ctx context.Context, projectID, branchID, endpointID str case postgres.EndpointStatusStateIdle: suffix = " (idle, waking up)" default: - return errors.New("endpoint is not ready for accepting connections") + return fmt.Errorf("endpoint is not ready for accepting connections (state: %s)", state) } cmdio.LogString(ctx, fmt.Sprintf("Connecting to %s endpoint%s...", endpointType, suffix)) diff --git a/cmd/psql/psql_provisioned.go b/cmd/psql/psql_provisioned.go index 9ea88def5ce..c7208906aa8 100644 --- a/cmd/psql/psql_provisioned.go +++ b/cmd/psql/psql_provisioned.go @@ -58,6 +58,7 @@ func connectProvisioned(ctx context.Context, instanceName string, retryConfig li // resolveInstance resolves an instance name to a full instance object. // If instanceName is empty, prompts the user to select one. func resolveInstance(ctx context.Context, w *databricks.WorkspaceClient, instanceName string) (*database.DatabaseInstance, error) { + // If instance not specified, select one if instanceName == "" { var err error instanceName, err = selectInstanceID(ctx, w) @@ -66,6 +67,8 @@ func resolveInstance(ctx context.Context, w *databricks.WorkspaceClient, instanc } } + // target.GetProvisioned patches Name on the response; the SDK's + // GetDatabaseInstance does not always populate it. 
instance, err := target.GetProvisioned(ctx, w, instanceName) if err != nil { return nil, err diff --git a/experimental/postgres/cmd/connect.go b/experimental/postgres/cmd/connect.go index a0674b81ead..920f02e932f 100644 --- a/experimental/postgres/cmd/connect.go +++ b/experimental/postgres/cmd/connect.go @@ -70,11 +70,10 @@ func buildPgxConfig(c connectConfig) (*pgx.ConnConfig, error) { // indicate the endpoint is asleep or in the middle of a wake-up. Errors that // cannot be improved by retrying (auth failures, permission errors, // post-query errors) are returned immediately. +// +// MaxAttempts must be >= 1 (caller validates). 1 means a single attempt +// with no retries. func connectWithRetry(ctx context.Context, cfg *pgx.ConnConfig, rc retryConfig, dial connectFunc) (*pgx.Conn, error) { - if rc.MaxAttempts < 1 { - rc.MaxAttempts = 1 - } - delay := rc.InitialDelay var lastErr error @@ -115,6 +114,10 @@ func connectWithRetry(ctx context.Context, cfg *pgx.ConnConfig, rc retryConfig, // - pgconn.ConnectError that wraps a retryable network error. // - Postgres connection-establishment SQLSTATE codes (08xxx). Lakebase // emits these during cold-start. +// - Postgres "cannot_connect_now" (57P03), which Postgres returns during +// server startup ("the database system is starting up"). Plausibly emitted +// during the wake-up handshake window. We do NOT broaden to all of class 57: +// 57P01/57P02 are admin shutdowns (debatable) and 57014 is query_canceled. 
// // Not retryable: auth errors (28xxx), permission errors (42501), // context cancellation/deadlines, anything after Query has been issued @@ -130,6 +133,9 @@ func isRetryableConnectError(err error) bool { if len(pgErr.Code) == 5 && pgErr.Code[:2] == "08" { return true } + if pgErr.Code == "57P03" { + return true + } return false } diff --git a/experimental/postgres/cmd/connect_test.go b/experimental/postgres/cmd/connect_test.go index 0f7614b1f31..d58fc52cc74 100644 --- a/experimental/postgres/cmd/connect_test.go +++ b/experimental/postgres/cmd/connect_test.go @@ -44,6 +44,21 @@ func TestIsRetryableConnectError(t *testing.T) { err: &pgconn.PgError{Code: "08001", Message: "sqlclient unable to establish sqlconnection"}, want: true, }, + { + name: "57P03 cannot_connect_now", + err: &pgconn.PgError{Code: "57P03", Message: "the database system is starting up"}, + want: true, + }, + { + name: "57P01 admin shutdown not retryable", + err: &pgconn.PgError{Code: "57P01"}, + want: false, + }, + { + name: "57014 query_canceled not retryable", + err: &pgconn.PgError{Code: "57014"}, + want: false, + }, { name: "28000 invalid auth", err: &pgconn.PgError{Code: "28000", Message: "invalid authorization specification"}, @@ -115,21 +130,6 @@ func TestConnectWithRetry_StopsOnNonRetryable(t *testing.T) { assert.Equal(t, 1, calls, "auth errors should not retry") } -func TestConnectWithRetry_ZeroMaxAttemptsTreatedAsOne(t *testing.T) { - ctx := testCtx(t) - calls := 0 - dial := func(ctx context.Context, cfg *pgx.ConnConfig) (*pgx.Conn, error) { - calls++ - return nil, errors.New("nope") - } - cfg := &pgx.ConnConfig{} - rc := retryConfig{MaxAttempts: 0, InitialDelay: time.Millisecond} - - _, err := connectWithRetry(ctx, cfg, rc, dial) - require.Error(t, err) - assert.Equal(t, 1, calls) -} - func TestBuildPgxConfig(t *testing.T) { cfg, err := buildPgxConfig(connectConfig{ Host: "host.example.com", diff --git a/experimental/postgres/cmd/query.go b/experimental/postgres/cmd/query.go index 
643aa496e84..fe5cc528ea7 100644 --- a/experimental/postgres/cmd/query.go +++ b/experimental/postgres/cmd/query.go @@ -31,9 +31,8 @@ func newQueryCmd() *cobra.Command { var f queryFlags cmd := &cobra.Command{ - Use: "query [SQL]", - Short: "Run a SQL statement against a Lakebase Postgres endpoint", - GroupID: "", + Use: "query [SQL]", + Short: "Run a SQL statement against a Lakebase Postgres endpoint", Long: `Execute a single SQL statement against a Lakebase Postgres endpoint and render the result as text. @@ -72,7 +71,7 @@ Limitations (this release): cmd.Flags().StringVar(&f.endpoint, "endpoint", "", "Autoscaling endpoint ID (default: auto-select if exactly one)") cmd.Flags().StringVarP(&f.database, "database", "d", defaultDatabase, "Database name") cmd.Flags().DurationVar(&f.connectTimeout, "connect-timeout", defaultConnectTimeout, "Connect timeout") - cmd.Flags().IntVar(&f.maxRetries, "max-retries", 3, "Total connect attempts on idle/waking endpoint (1 disables retry)") + cmd.Flags().IntVar(&f.maxRetries, "max-retries", 3, "Total connect attempts on idle/waking endpoint (must be >= 1; 1 disables retry)") cmd.MarkFlagsMutuallyExclusive("target", "project") cmd.MarkFlagsMutuallyExclusive("target", "branch") @@ -89,6 +88,9 @@ func runQuery(ctx context.Context, cmd *cobra.Command, sql string, f queryFlags) if sql == "" { return errors.New("no SQL provided") } + if f.maxRetries < 1 { + return fmt.Errorf("--max-retries must be at least 1; got %d", f.maxRetries) + } if err := validateTargeting(f.targetingFlags); err != nil { return err } @@ -113,7 +115,7 @@ func runQuery(ctx context.Context, cmd *cobra.Command, sql string, f queryFlags) } rc := retryConfig{ - MaxAttempts: max(1, f.maxRetries), + MaxAttempts: f.maxRetries, InitialDelay: time.Second, MaxDelay: 10 * time.Second, } diff --git a/experimental/postgres/cmd/targeting.go b/experimental/postgres/cmd/targeting.go index e8a17fadfce..5e72840f952 100644 --- a/experimental/postgres/cmd/targeting.go +++ 
b/experimental/postgres/cmd/targeting.go @@ -51,6 +51,11 @@ func (f targetingFlags) hasGranular() bool { // validateTargeting enforces "exactly one targeting form" before any SDK call. // Returns a typed error so the JSON envelope renderer (added in a later PR) // can surface a structured error. +// +// We require --branch when --endpoint is set: this command is non-interactive +// and scriptable, and the alternative (auto-select-then-look-up-endpoint) +// produces confusing errors when the resolved branch does not contain the +// requested endpoint. Asking the user to be explicit is friendlier. func validateTargeting(f targetingFlags) error { switch { case f.target == "" && !f.hasGranular(): @@ -59,6 +64,8 @@ func validateTargeting(f targetingFlags) error { return errors.New("--target is mutually exclusive with --project, --branch, --endpoint") case f.target == "" && f.project == "" && (f.branch != "" || f.endpoint != ""): return errors.New("--project is required when using --branch or --endpoint") + case f.endpoint != "" && f.branch == "": + return errors.New("--branch is required when using --endpoint") } return nil } diff --git a/experimental/postgres/cmd/targeting_test.go b/experimental/postgres/cmd/targeting_test.go index dfdab0d405c..62f43d22496 100644 --- a/experimental/postgres/cmd/targeting_test.go +++ b/experimental/postgres/cmd/targeting_test.go @@ -66,6 +66,14 @@ func TestValidateTargeting(t *testing.T) { }, wantErr: "--project is required when using --branch or --endpoint", }, + { + name: "endpoint with project but no branch", + flags: targetingFlags{ + project: "foo", + endpoint: "primary", + }, + wantErr: "--branch is required when using --endpoint", + }, } for _, tc := range tests { diff --git a/libs/lakebase/target/autoscaling.go b/libs/lakebase/target/autoscaling.go index f1edef216d4..3e496611d6b 100644 --- a/libs/lakebase/target/autoscaling.go +++ b/libs/lakebase/target/autoscaling.go @@ -26,20 +26,23 @@ func ListEndpoints(ctx context.Context, w 
*databricks.WorkspaceClient, branchNam return w.Postgres.ListEndpointsAll(ctx, postgres.ListEndpointsRequest{Parent: branchName}) } -// GetProject fetches a single project by ID. +// GetProject fetches a single project by ID. Unlike GetProvisioned, the +// Postgres autoscaling API populates the Name field on the response so we do +// not need to patch it. func GetProject(ctx context.Context, w *databricks.WorkspaceClient, projectID string) (*postgres.Project, error) { - return w.Postgres.GetProject(ctx, postgres.GetProjectRequest{Name: PathSegmentProjects + "/" + projectID}) + return w.Postgres.GetProject(ctx, postgres.GetProjectRequest{Name: pathSegmentProjects + "/" + projectID}) } -// GetEndpoint fetches a single endpoint by ID, given its parent IDs. +// GetEndpoint fetches a single endpoint by ID, given its parent IDs. Unlike +// GetProvisioned, the Postgres autoscaling API populates the Name field. func GetEndpoint(ctx context.Context, w *databricks.WorkspaceClient, projectID, branchID, endpointID string) (*postgres.Endpoint, error) { name := fmt.Sprintf("projects/%s/branches/%s/endpoints/%s", projectID, branchID, endpointID) return w.Postgres.GetEndpoint(ctx, postgres.GetEndpointRequest{Name: name}) } -// AutoSelectProject returns the only project in the workspace, or an -// AmbiguousError carrying the choices if there are multiple. Returns a plain -// error if there are no projects. +// AutoSelectProject returns the trailing project ID (e.g. "foo", not +// "projects/foo") if exactly one project exists. Returns an *AmbiguousError +// carrying the choices if there are multiple, or a plain error if there are none. 
func AutoSelectProject(ctx context.Context, w *databricks.WorkspaceClient) (string, error) { projects, err := ListProjects(ctx, w) if err != nil { @@ -49,23 +52,24 @@ func AutoSelectProject(ctx context.Context, w *databricks.WorkspaceClient) (stri return "", errors.New("no Lakebase Autoscaling projects found in workspace") } if len(projects) == 1 { - return ExtractID(projects[0].Name, PathSegmentProjects), nil + return extractID(projects[0].Name, pathSegmentProjects), nil } choices := make([]Choice, 0, len(projects)) for _, p := range projects { - id := ExtractID(p.Name, PathSegmentProjects) - display := id - if p.Status != nil && p.Status.DisplayName != "" { + id := extractID(p.Name, pathSegmentProjects) + var display string + if p.Status != nil && p.Status.DisplayName != "" && p.Status.DisplayName != id { display = p.Status.DisplayName } choices = append(choices, Choice{ID: id, DisplayName: display}) } - return "", &AmbiguousError{Kind: "project", FlagHint: "--project", Choices: choices} + return "", &AmbiguousError{Kind: KindProject, FlagHint: "--project", Choices: choices} } -// AutoSelectBranch returns the only branch under projectName, or an -// AmbiguousError if there are multiple. +// AutoSelectBranch returns the trailing branch ID under projectName if +// exactly one branch exists. Returns an *AmbiguousError if there are multiple. +// projectName is the SDK resource name (e.g. "projects/foo"). 
func AutoSelectBranch(ctx context.Context, w *databricks.WorkspaceClient, projectName string) (string, error) { branches, err := ListBranches(ctx, w, projectName) if err != nil { @@ -75,19 +79,20 @@ func AutoSelectBranch(ctx context.Context, w *databricks.WorkspaceClient, projec return "", errors.New("no branches found in project") } if len(branches) == 1 { - return ExtractID(branches[0].Name, pathSegmentBranches), nil + return extractID(branches[0].Name, pathSegmentBranches), nil } choices := make([]Choice, 0, len(branches)) for _, b := range branches { - id := ExtractID(b.Name, pathSegmentBranches) - choices = append(choices, Choice{ID: id, DisplayName: id}) + id := extractID(b.Name, pathSegmentBranches) + choices = append(choices, Choice{ID: id}) } - return "", &AmbiguousError{Kind: "branch", Parent: projectName, FlagHint: "--branch", Choices: choices} + return "", &AmbiguousError{Kind: KindBranch, Parent: projectName, FlagHint: "--branch", Choices: choices} } -// AutoSelectEndpoint returns the only endpoint under branchName, or an -// AmbiguousError if there are multiple. +// AutoSelectEndpoint returns the trailing endpoint ID under branchName if +// exactly one endpoint exists. Returns an *AmbiguousError if there are multiple. +// branchName is the SDK resource name (e.g. "projects/foo/branches/bar"). 
func AutoSelectEndpoint(ctx context.Context, w *databricks.WorkspaceClient, branchName string) (string, error) { endpoints, err := ListEndpoints(ctx, w, branchName) if err != nil { @@ -97,15 +102,15 @@ func AutoSelectEndpoint(ctx context.Context, w *databricks.WorkspaceClient, bran return "", errors.New("no endpoints found in branch") } if len(endpoints) == 1 { - return ExtractID(endpoints[0].Name, pathSegmentEndpoints), nil + return extractID(endpoints[0].Name, pathSegmentEndpoints), nil } choices := make([]Choice, 0, len(endpoints)) for _, e := range endpoints { - id := ExtractID(e.Name, pathSegmentEndpoints) - choices = append(choices, Choice{ID: id, DisplayName: id}) + id := extractID(e.Name, pathSegmentEndpoints) + choices = append(choices, Choice{ID: id}) } - return "", &AmbiguousError{Kind: "endpoint", Parent: branchName, FlagHint: "--endpoint", Choices: choices} + return "", &AmbiguousError{Kind: KindEndpoint, Parent: branchName, FlagHint: "--endpoint", Choices: choices} } // AutoscalingCredential issues a short-lived OAuth token that can be used to diff --git a/libs/lakebase/target/provisioned.go b/libs/lakebase/target/provisioned.go index 773cc867ce0..261ef37a6a8 100644 --- a/libs/lakebase/target/provisioned.go +++ b/libs/lakebase/target/provisioned.go @@ -29,8 +29,10 @@ func GetProvisioned(ctx context.Context, w *databricks.WorkspaceClient, name str return instance, nil } -// AutoSelectProvisioned returns the only provisioned instance in the workspace, -// or an AmbiguousError if there are multiple. Returns a plain error if none. +// AutoSelectProvisioned returns the only provisioned instance's name (e.g. +// "my-instance"; the database SDK uses flat names, not the "projects/..." +// path shape used by autoscaling). Returns an *AmbiguousError if there are +// multiple, or a plain error if none. 
func AutoSelectProvisioned(ctx context.Context, w *databricks.WorkspaceClient) (string, error) { instances, err := ListProvisionedInstances(ctx, w) if err != nil { @@ -45,9 +47,9 @@ func AutoSelectProvisioned(ctx context.Context, w *databricks.WorkspaceClient) ( choices := make([]Choice, 0, len(instances)) for _, inst := range instances { - choices = append(choices, Choice{ID: inst.Name, DisplayName: inst.Name}) + choices = append(choices, Choice{ID: inst.Name}) } - return "", &AmbiguousError{Kind: "instance", FlagHint: "--target", Choices: choices} + return "", &AmbiguousError{Kind: KindInstance, FlagHint: "--target", Choices: choices} } // ProvisionedCredential issues a short-lived OAuth token for the provisioned diff --git a/libs/lakebase/target/target.go b/libs/lakebase/target/target.go index d02c95903ce..f0fd2c069e3 100644 --- a/libs/lakebase/target/target.go +++ b/libs/lakebase/target/target.go @@ -11,9 +11,11 @@ import ( ) const ( - // PathSegmentProjects is the leading path segment that identifies an - // autoscaling resource path. Provisioned instance names never start with it. - PathSegmentProjects = "projects" + // pathSegmentProjects is the leading path segment that identifies an + // autoscaling resource path. Provisioned instance names never start with + // it. Use IsAutoscalingPath / ProjectIDFromName from outside this package + // instead of comparing the literal. + pathSegmentProjects = "projects" pathSegmentBranches = "branches" pathSegmentEndpoints = "endpoints" ) @@ -28,11 +30,27 @@ type AutoscalingSpec struct { // Choice is a single candidate returned alongside an AmbiguousError so callers // can either render the list to the user or prompt interactively. +// +// DisplayName is the optional friendlier label for the choice. Producers +// should leave it empty when no friendlier label exists; callers that prompt +// interactively can fall back to the ID. 
type Choice struct { ID string DisplayName string } +// AmbiguousKind is the typed enum for what an AmbiguousError refers to. A +// typed enum (vs raw string) keeps producers and the pluralisation switch in +// AmbiguousError.Error in sync. +type AmbiguousKind string + +const ( + KindProject AmbiguousKind = "project" + KindBranch AmbiguousKind = "branch" + KindEndpoint AmbiguousKind = "endpoint" + KindInstance AmbiguousKind = "instance" +) + // AmbiguousError is returned by AutoSelect* helpers when the SDK returns more // than one candidate and the caller did not specify which one to pick. // @@ -41,26 +59,27 @@ type Choice struct { // scriptable `postgres query` command) propagate it as a plain error: the // formatted message already enumerates the choices. type AmbiguousError struct { - // Kind identifies what was ambiguous: "project", "branch", or "endpoint". - Kind string + Kind AmbiguousKind // Parent is the SDK resource name that contained the ambiguity (e.g. // "projects/foo" when listing branches), or empty when listing projects. Parent string // FlagHint is the flag a user would set to disambiguate (e.g. "--branch"). FlagHint string - // Choices enumerates the candidates returned by the SDK. + // Choices enumerates the candidates returned by the SDK. DisplayName is + // only set when it carries information beyond ID; an empty DisplayName + // suppresses the parenthetical suffix in Error(). 
Choices []Choice } func (e *AmbiguousError) Error() string { - plural := map[string]string{ - "project": "projects", - "branch": "branches", - "endpoint": "endpoints", - "instance": "instances", + plural := map[AmbiguousKind]string{ + KindProject: "projects", + KindBranch: "branches", + KindEndpoint: "endpoints", + KindInstance: "instances", }[e.Kind] if plural == "" { - plural = e.Kind + plural = string(e.Kind) } var sb strings.Builder @@ -72,7 +91,7 @@ func (e *AmbiguousError) Error() string { for _, c := range e.Choices { sb.WriteString("\n - ") sb.WriteString(c.ID) - if c.DisplayName != "" && c.DisplayName != c.ID { + if c.DisplayName != "" { fmt.Fprintf(&sb, " (%s)", c.DisplayName) } } @@ -90,7 +109,7 @@ func (e *AmbiguousError) Error() string { func ParseAutoscalingPath(input string) (AutoscalingSpec, error) { parts := strings.Split(input, "/") - if len(parts) < 2 || parts[0] != PathSegmentProjects { + if len(parts) < 2 || parts[0] != pathSegmentProjects { return AutoscalingSpec{}, fmt.Errorf("invalid resource path: %s", input) } if parts[1] == "" { @@ -125,10 +144,10 @@ func ParseAutoscalingPath(input string) (AutoscalingSpec, error) { return spec, nil } -// ExtractID returns the value following component in a resource name. -// ExtractID("projects/foo/branches/bar", "branches") returns "bar". +// extractID returns the value following component in a resource name. +// extractID("projects/foo/branches/bar", "branches") returns "bar". // Returns the original name unchanged if component is not found. -func ExtractID(name, component string) string { +func extractID(name, component string) string { parts := strings.Split(name, "/") for i := range len(parts) - 1 { if parts[i] == component { @@ -138,8 +157,25 @@ func ExtractID(name, component string) string { return name } +// ProjectIDFromName extracts the project ID from a fully-qualified +// SDK resource name like "projects/foo" or "projects/foo/branches/bar". 
+// Returns the input unchanged if the name does not contain a "projects/" segment. +func ProjectIDFromName(name string) string { + return extractID(name, pathSegmentProjects) +} + +// BranchIDFromName extracts the branch ID from an SDK resource name. +func BranchIDFromName(name string) string { + return extractID(name, pathSegmentBranches) +} + +// EndpointIDFromName extracts the endpoint ID from an SDK resource name. +func EndpointIDFromName(name string) string { + return extractID(name, pathSegmentEndpoints) +} + // IsAutoscalingPath reports whether s is an autoscaling resource path // (i.e. starts with "projects/"). Provisioned instance names never do. func IsAutoscalingPath(s string) bool { - return strings.HasPrefix(s, PathSegmentProjects+"/") + return strings.HasPrefix(s, pathSegmentProjects+"/") } diff --git a/libs/lakebase/target/target_test.go b/libs/lakebase/target/target_test.go index 4b4a763c122..f502cf6e70c 100644 --- a/libs/lakebase/target/target_test.go +++ b/libs/lakebase/target/target_test.go @@ -64,6 +64,16 @@ func TestParseAutoscalingPath(t *testing.T) { input: "projects/foo/branches/bar/endpoints/baz/extra", wantErr: "trailing components after endpoint", }, + { + name: "empty input", + input: "", + wantErr: "invalid resource path", + }, + { + name: "single slash", + input: "/", + wantErr: "invalid resource path", + }, } for _, tc := range tests { @@ -80,11 +90,12 @@ func TestParseAutoscalingPath(t *testing.T) { } } -func TestExtractID(t *testing.T) { - assert.Equal(t, "bar", ExtractID("projects/foo/branches/bar", "branches")) - assert.Equal(t, "foo", ExtractID("projects/foo", "projects")) - assert.Equal(t, "baz", ExtractID("projects/foo/branches/bar/endpoints/baz", "endpoints")) - assert.Equal(t, "no-component", ExtractID("no-component", "missing")) +func TestIDFromName(t *testing.T) { + assert.Equal(t, "foo", ProjectIDFromName("projects/foo")) + assert.Equal(t, "foo", ProjectIDFromName("projects/foo/branches/bar")) + assert.Equal(t, "bar", 
BranchIDFromName("projects/foo/branches/bar")) + assert.Equal(t, "bar", BranchIDFromName("projects/foo/branches/bar/endpoints/baz")) + assert.Equal(t, "baz", EndpointIDFromName("projects/foo/branches/bar/endpoints/baz")) } func TestIsAutoscalingPath(t *testing.T) { @@ -96,14 +107,14 @@ func TestIsAutoscalingPath(t *testing.T) { } func TestAmbiguousErrorMessage(t *testing.T) { - t.Run("with parent", func(t *testing.T) { + t.Run("with parent, no display names", func(t *testing.T) { err := &AmbiguousError{ - Kind: "branch", + Kind: KindBranch, Parent: "projects/foo", FlagHint: "--branch", Choices: []Choice{ - {ID: "main", DisplayName: "main"}, - {ID: "feature-x", DisplayName: "feature-x"}, + {ID: "main"}, + {ID: "feature-x"}, }, } assert.Equal(t, @@ -112,13 +123,13 @@ func TestAmbiguousErrorMessage(t *testing.T) { ) }) - t.Run("without parent", func(t *testing.T) { + t.Run("without parent, mixed display names", func(t *testing.T) { err := &AmbiguousError{ - Kind: "project", + Kind: KindProject, FlagHint: "--project", Choices: []Choice{ {ID: "alpha", DisplayName: "Alpha Project"}, - {ID: "beta", DisplayName: "beta"}, + {ID: "beta"}, }, } assert.Equal(t, @@ -129,8 +140,8 @@ func TestAmbiguousErrorMessage(t *testing.T) { t.Run("errors.As", func(t *testing.T) { var amb *AmbiguousError - err := error(&AmbiguousError{Kind: "endpoint", FlagHint: "--endpoint"}) + err := error(&AmbiguousError{Kind: KindEndpoint, FlagHint: "--endpoint"}) assert.ErrorAs(t, err, &amb) - assert.Equal(t, "endpoint", amb.Kind) + assert.Equal(t, KindEndpoint, amb.Kind) }) } From 030a2790dd31086bbfc47eaeb21140987eed6246 Mon Sep 17 00:00:00 2001 From: simon Date: Thu, 30 Apr 2026 10:51:59 +0200 Subject: [PATCH 04/15] Address review feedback round 2 - Fix selectAmbiguous: fall back to ID when DisplayName is empty. 
Round-1 fix to Choice semantics left producers emitting empty DisplayName for branches/endpoints/instances; the psql interactive selector passed that straight to cmdio.Tuple.Name and rendered blank rows. Add the documented fallback. - Drop unused BranchIDFromName / EndpointIDFromName exports; only ProjectIDFromName has callers in this PR. Re-add when first consumed. - Convert chained ifs in isRetryableConnectError to a switch. Co-authored-by: Isaac --- cmd/psql/psql_autoscaling.go | 11 ++++++++++- experimental/postgres/cmd/connect.go | 9 +++++---- libs/lakebase/target/target.go | 10 ---------- libs/lakebase/target/target_test.go | 6 ++---- 4 files changed, 17 insertions(+), 19 deletions(-) diff --git a/cmd/psql/psql_autoscaling.go b/cmd/psql/psql_autoscaling.go index a4c3293cc18..04ccd4bef6b 100644 --- a/cmd/psql/psql_autoscaling.go +++ b/cmd/psql/psql_autoscaling.go @@ -135,10 +135,19 @@ func resolveEndpoint(ctx context.Context, w *databricks.WorkspaceClient, project // AmbiguousError. Caller is expected to have logged a header (e.g. via the // spinner) before invoking. Used to keep psql's interactive UX while letting // the shared lib do the actual list+filter work. +// +// Choice.DisplayName is empty when the producer has no friendlier label than +// the ID (e.g. branches and endpoints, where the ID is the human label). +// The promptui template renders an empty Name as a blank row, so we fall back +// to the ID before handing off to cmdio.SelectOrdered. 
func selectAmbiguous(ctx context.Context, amb *target.AmbiguousError, prompt string) (string, error) { items := make([]cmdio.Tuple, 0, len(amb.Choices)) for _, c := range amb.Choices { - items = append(items, cmdio.Tuple{Name: c.DisplayName, Id: c.ID}) + name := c.DisplayName + if name == "" { + name = c.ID + } + items = append(items, cmdio.Tuple{Name: name, Id: c.ID}) } return cmdio.SelectOrdered(ctx, items, prompt) } diff --git a/experimental/postgres/cmd/connect.go b/experimental/postgres/cmd/connect.go index 920f02e932f..2eefc681868 100644 --- a/experimental/postgres/cmd/connect.go +++ b/experimental/postgres/cmd/connect.go @@ -129,14 +129,15 @@ func isRetryableConnectError(err error) bool { var pgErr *pgconn.PgError if errors.As(err, &pgErr) { + switch { // 08xxx is the connection_exception class. - if len(pgErr.Code) == 5 && pgErr.Code[:2] == "08" { + case len(pgErr.Code) == 5 && pgErr.Code[:2] == "08": return true - } - if pgErr.Code == "57P03" { + case pgErr.Code == "57P03": return true + default: + return false } - return false } var connectErr *pgconn.ConnectError diff --git a/libs/lakebase/target/target.go b/libs/lakebase/target/target.go index f0fd2c069e3..1874829acce 100644 --- a/libs/lakebase/target/target.go +++ b/libs/lakebase/target/target.go @@ -164,16 +164,6 @@ func ProjectIDFromName(name string) string { return extractID(name, pathSegmentProjects) } -// BranchIDFromName extracts the branch ID from an SDK resource name. -func BranchIDFromName(name string) string { - return extractID(name, pathSegmentBranches) -} - -// EndpointIDFromName extracts the endpoint ID from an SDK resource name. -func EndpointIDFromName(name string) string { - return extractID(name, pathSegmentEndpoints) -} - // IsAutoscalingPath reports whether s is an autoscaling resource path // (i.e. starts with "projects/"). Provisioned instance names never do. 
func IsAutoscalingPath(s string) bool { diff --git a/libs/lakebase/target/target_test.go b/libs/lakebase/target/target_test.go index f502cf6e70c..f1726890330 100644 --- a/libs/lakebase/target/target_test.go +++ b/libs/lakebase/target/target_test.go @@ -90,12 +90,10 @@ func TestParseAutoscalingPath(t *testing.T) { } } -func TestIDFromName(t *testing.T) { +func TestProjectIDFromName(t *testing.T) { assert.Equal(t, "foo", ProjectIDFromName("projects/foo")) assert.Equal(t, "foo", ProjectIDFromName("projects/foo/branches/bar")) - assert.Equal(t, "bar", BranchIDFromName("projects/foo/branches/bar")) - assert.Equal(t, "bar", BranchIDFromName("projects/foo/branches/bar/endpoints/baz")) - assert.Equal(t, "baz", EndpointIDFromName("projects/foo/branches/bar/endpoints/baz")) + assert.Equal(t, "no-projects", ProjectIDFromName("no-projects")) } func TestIsAutoscalingPath(t *testing.T) { From 5e0e3dd2157ea324c819795b80c8ebf0dd79cd4d Mon Sep 17 00:00:00 2001 From: simon Date: Thu, 30 Apr 2026 11:01:07 +0200 Subject: [PATCH 05/15] Provisioned targeting + JSON/CSV streaming + typed values This is PR 2 of the experimental postgres query stack. Builds on PR 1's scaffold to fill in the rest of the single-input output story. Provisioned support: --target accepts both autoscaling resource paths (starts with "projects/") and provisioned instance names (everything else). Granular --project/--branch/--endpoint targeting stays autoscaling-only. resolveProvisioned validates the instance is in the AVAILABLE state and has read/write DNS before issuing a token. Output renderers are now sinks fed by executeOne row-by-row instead of buffering. textSink keeps buffering (tabwriter needs the widest cell to align); jsonSink and csvSink stream. jsonSink uses separator-before-element writing throughout so a mid-stream error can close the array cleanly via OnError, leaving stdout as parseable JSON with a partial result. 
JSON value rendering follows the typed mapping: numbers stay numeric
inside ±2^53, become strings outside; NaN/Inf become "NaN"/"Infinity"/
"-Infinity"; timestamps render in RFC3339; jsonb passes through as
json.RawMessage so e.g. {"id": 9007199254740993} keeps its digits; bytea
base64-encodes; everything else falls back to canonical Postgres text.

CSV and text use Postgres' canonical text representation, with NULL
rendered as the literal "NULL" in text and as empty in CSV (matches
psql --csv).

Output mode auto-selection mirrors aitools query: --output text on a
non-TTY stdout falls back to JSON. DATABRICKS_OUTPUT_FORMAT is honoured
when --output is not explicitly set; invalid env values are silently
ignored.

Duplicate column names are deterministically renamed (id, id__2, id__3)
with a stderr warning.

Acceptance: argument-errors loses the now-obsolete "provisioned not yet
supported" case; new provisioned-targeting test exercises
not-AVAILABLE / no-DNS / 404 paths via the SDK testserver mock.
Co-authored-by: Isaac --- .../postgres/query/argument-errors/output.txt | 4 - .../postgres/query/argument-errors/script | 3 - .../query/provisioned-targeting/out.test.toml | 8 + .../query/provisioned-targeting/output.txt | 12 ++ .../query/provisioned-targeting/script | 8 + .../query/provisioned-targeting/test.toml | 30 +++ experimental/postgres/cmd/execute.go | 67 ++++--- experimental/postgres/cmd/output.go | 71 +++++++ experimental/postgres/cmd/output_test.go | 79 ++++++++ experimental/postgres/cmd/query.go | 70 +++++-- experimental/postgres/cmd/render.go | 77 ++++---- experimental/postgres/cmd/render_csv.go | 80 ++++++++ experimental/postgres/cmd/render_csv_test.go | 49 +++++ experimental/postgres/cmd/render_json.go | 173 ++++++++++++++++++ experimental/postgres/cmd/render_json_test.go | 118 ++++++++++++ experimental/postgres/cmd/render_test.go | 68 +++---- experimental/postgres/cmd/targeting.go | 47 ++++- experimental/postgres/cmd/value.go | 152 +++++++++++++++ experimental/postgres/cmd/value_test.go | 84 +++++++++ 19 files changed, 1079 insertions(+), 121 deletions(-) create mode 100644 acceptance/cmd/experimental/postgres/query/provisioned-targeting/out.test.toml create mode 100644 acceptance/cmd/experimental/postgres/query/provisioned-targeting/output.txt create mode 100644 acceptance/cmd/experimental/postgres/query/provisioned-targeting/script create mode 100644 acceptance/cmd/experimental/postgres/query/provisioned-targeting/test.toml create mode 100644 experimental/postgres/cmd/output.go create mode 100644 experimental/postgres/cmd/output_test.go create mode 100644 experimental/postgres/cmd/render_csv.go create mode 100644 experimental/postgres/cmd/render_csv_test.go create mode 100644 experimental/postgres/cmd/render_json.go create mode 100644 experimental/postgres/cmd/render_json_test.go create mode 100644 experimental/postgres/cmd/value.go create mode 100644 experimental/postgres/cmd/value_test.go diff --git 
a/acceptance/cmd/experimental/postgres/query/argument-errors/output.txt b/acceptance/cmd/experimental/postgres/query/argument-errors/output.txt index c071466a1e3..238e099299c 100644 --- a/acceptance/cmd/experimental/postgres/query/argument-errors/output.txt +++ b/acceptance/cmd/experimental/postgres/query/argument-errors/output.txt @@ -35,10 +35,6 @@ Error: --branch is required when using --endpoint >>> musterr [CLI] experimental postgres query --project foo --branch main --max-retries 0 SELECT 1 Error: --max-retries must be at least 1; got 0 -=== Provisioned-shaped target should error pointing at psql: ->>> musterr [CLI] experimental postgres query --target my-instance SELECT 1 -Error: provisioned instances are not yet supported by this experimental command; use 'databricks psql ' for now - === Malformed autoscaling path should error: >>> musterr [CLI] experimental postgres query --target projects/ SELECT 1 Error: invalid resource path: missing project ID diff --git a/acceptance/cmd/experimental/postgres/query/argument-errors/script b/acceptance/cmd/experimental/postgres/query/argument-errors/script index 8d64bf307ed..ac6ac42746e 100644 --- a/acceptance/cmd/experimental/postgres/query/argument-errors/script +++ b/acceptance/cmd/experimental/postgres/query/argument-errors/script @@ -25,9 +25,6 @@ trace musterr $CLI experimental postgres query --project foo --endpoint primary title "Max-retries 0 should error:" trace musterr $CLI experimental postgres query --project foo --branch main --max-retries 0 "SELECT 1" -title "Provisioned-shaped target should error pointing at psql:" -trace musterr $CLI experimental postgres query --target my-instance "SELECT 1" - title "Malformed autoscaling path should error:" trace musterr $CLI experimental postgres query --target projects/ "SELECT 1" diff --git a/acceptance/cmd/experimental/postgres/query/provisioned-targeting/out.test.toml b/acceptance/cmd/experimental/postgres/query/provisioned-targeting/out.test.toml new file mode 
100644 index 00000000000..40bb0d10471 --- /dev/null +++ b/acceptance/cmd/experimental/postgres/query/provisioned-targeting/out.test.toml @@ -0,0 +1,8 @@ +Local = true +Cloud = false + +[GOOS] + windows = false + +[EnvMatrix] + DATABRICKS_BUNDLE_ENGINE = ["terraform", "direct"] diff --git a/acceptance/cmd/experimental/postgres/query/provisioned-targeting/output.txt b/acceptance/cmd/experimental/postgres/query/provisioned-targeting/output.txt new file mode 100644 index 00000000000..0f00f8b3e44 --- /dev/null +++ b/acceptance/cmd/experimental/postgres/query/provisioned-targeting/output.txt @@ -0,0 +1,12 @@ + +=== Provisioned target in non-AVAILABLE state should error: +>>> musterr [CLI] experimental postgres query --target starting-instance SELECT 1 +Error: database instance "starting-instance" is not ready for accepting connections (state: STARTING) + +=== Provisioned target with no DNS should error: +>>> musterr [CLI] experimental postgres query --target no-dns-instance SELECT 1 +Error: database instance "no-dns-instance" has no read/write DNS yet + +=== Provisioned target not found should surface SDK 404: +>>> musterr [CLI] experimental postgres query --target missing-instance SELECT 1 +Error: failed to get database instance: instance not found diff --git a/acceptance/cmd/experimental/postgres/query/provisioned-targeting/script b/acceptance/cmd/experimental/postgres/query/provisioned-targeting/script new file mode 100644 index 00000000000..d8995c62a6c --- /dev/null +++ b/acceptance/cmd/experimental/postgres/query/provisioned-targeting/script @@ -0,0 +1,8 @@ +title "Provisioned target in non-AVAILABLE state should error:" +trace musterr $CLI experimental postgres query --target starting-instance "SELECT 1" + +title "Provisioned target with no DNS should error:" +trace musterr $CLI experimental postgres query --target no-dns-instance "SELECT 1" + +title "Provisioned target not found should surface SDK 404:" +trace musterr $CLI experimental postgres query --target 
missing-instance "SELECT 1" diff --git a/acceptance/cmd/experimental/postgres/query/provisioned-targeting/test.toml b/acceptance/cmd/experimental/postgres/query/provisioned-targeting/test.toml new file mode 100644 index 00000000000..4821dab5741 --- /dev/null +++ b/acceptance/cmd/experimental/postgres/query/provisioned-targeting/test.toml @@ -0,0 +1,30 @@ +GOOS.windows = false + +[[Server]] +Pattern = "GET /api/2.0/database/instances/starting-instance" +Response.Body = ''' +{ + "name": "starting-instance", + "state": "STARTING", + "read_write_dns": "starting.example.com" +} +''' + +[[Server]] +Pattern = "GET /api/2.0/database/instances/no-dns-instance" +Response.Body = ''' +{ + "name": "no-dns-instance", + "state": "AVAILABLE" +} +''' + +[[Server]] +Pattern = "GET /api/2.0/database/instances/missing-instance" +Response.StatusCode = 404 +Response.Body = ''' +{ + "error_code": "NOT_FOUND", + "message": "instance not found" +} +''' diff --git a/experimental/postgres/cmd/execute.go b/experimental/postgres/cmd/execute.go index c29f7ce59d6..61d93bd7bc2 100644 --- a/experimental/postgres/cmd/execute.go +++ b/experimental/postgres/cmd/execute.go @@ -5,10 +5,29 @@ import ( "fmt" "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" ) -// executeOne runs a single SQL statement against an open connection and -// captures the result in a queryResult. +// rowSink consumes a query result one row at a time. Sinks that maintain open +// output structures (e.g. a streaming JSON array) implement OnError so they +// can close cleanly when the iteration terminates with a partial result. +type rowSink interface { + // Begin is called once with the column descriptions before any Row. + // For command-only statements (no rows), Begin is still called with an + // empty slice so the sink can lock in its rows-vs-command shape. + Begin(fields []pgconn.FieldDescription) error + // Row delivers one decoded row. 
Values aligns with the fields passed to + // Begin and uses pgx's Go type mapping (int64, float64, time.Time, ...). + Row(values []any) error + // End is called once on successful completion. + End(commandTag string) error + // OnError is called if iteration errors after Begin returned. The sink + // is expected to flush any in-progress output structures so stdout + // remains well-formed. The caller still surfaces err to its caller. + OnError(err error) +} + +// executeOne runs a single SQL statement and streams the result through sink. // // We pass QueryExecModeExec explicitly (not the pgx default // QueryExecModeCacheStatement) for two reasons: @@ -17,46 +36,38 @@ import ( // closed at the end of the command, so the cached prepared statement // never gets reused. // 2. Exec mode uses Postgres' extended-protocol "exec" path with text-format -// result columns. That makes rendering canonical-Postgres-text output -// (PR 1) and CSV (later PR) straightforward; the cache mode defaults to -// binary and we'd be reformatting back to text. +// result columns, which keeps the canonical-Postgres-text rendering for +// --output text and --output csv straightforward. // // QueryExecModeExec still uses extended protocol with a single statement and // no implicit transaction wrap, so transaction-disallowed DDL like -// `CREATE DATABASE` works. -func executeOne(ctx context.Context, conn *pgx.Conn, sql string) (*queryResult, error) { +// CREATE DATABASE works. 
+func executeOne(ctx context.Context, conn *pgx.Conn, sql string, sink rowSink) error { rows, err := conn.Query(ctx, sql, pgx.QueryExecModeExec) if err != nil { - return nil, fmt.Errorf("query failed: %w", err) + return fmt.Errorf("query failed: %w", err) } defer rows.Close() - result := &queryResult{SQL: sql} - - fields := rows.FieldDescriptions() - if len(fields) > 0 { - result.Columns = make([]string, len(fields)) - for i, f := range fields { - result.Columns[i] = f.Name - } + if err := sink.Begin(rows.FieldDescriptions()); err != nil { + return err } for rows.Next() { - raw := rows.RawValues() - row := make([]string, len(raw)) - for i, b := range raw { - if b == nil { - row[i] = "NULL" - continue - } - row[i] = string(b) + values, err := rows.Values() + if err != nil { + sink.OnError(err) + return fmt.Errorf("decode row: %w", err) + } + if err := sink.Row(values); err != nil { + sink.OnError(err) + return err } - result.Rows = append(result.Rows, row) } if err := rows.Err(); err != nil { - return nil, fmt.Errorf("query failed: %w", err) + sink.OnError(err) + return fmt.Errorf("query failed: %w", err) } - result.CommandTag = rows.CommandTag().String() - return result, nil + return sink.End(rows.CommandTag().String()) } diff --git a/experimental/postgres/cmd/output.go b/experimental/postgres/cmd/output.go new file mode 100644 index 00000000000..c293b424b73 --- /dev/null +++ b/experimental/postgres/cmd/output.go @@ -0,0 +1,71 @@ +package postgrescmd + +import ( + "context" + "fmt" + "strings" + + "github.com/databricks/cli/libs/env" +) + +// outputFormat is the user-selectable output shape. Using a string typedef +// instead of an int enum keeps the help text and DATABRICKS_OUTPUT_FORMAT env +// var values self-describing. +type outputFormat string + +const ( + outputText outputFormat = "text" + outputJSON outputFormat = "json" + outputCSV outputFormat = "csv" + + // envOutputFormat matches the env var name in cmd/root/io.go. 
Reading it + // here lets pipelines set DATABRICKS_OUTPUT_FORMAT once for all + // commands. See aitools query for a parallel pattern. + envOutputFormat = "DATABRICKS_OUTPUT_FORMAT" +) + +// allOutputFormats is the canonical order shown in completions / help. +var allOutputFormats = []outputFormat{outputText, outputJSON, outputCSV} + +// resolveOutputFormat picks the effective output format. Precedence: +// +// 1. The local --output flag if it was explicitly set. +// 2. DATABRICKS_OUTPUT_FORMAT env var if set to a known value (invalid +// values are silently ignored, matching cmd/root/io.go and aitools). +// 3. The flag default ("text"). +// +// Then the auto-selection rule applies: text on a non-TTY stdout falls back +// to JSON. This matches the aitools query command and means scripts piping +// stdout get machine-readable output by default. +// +// flagSet is true if the user explicitly passed --output. stdoutTTY is true +// if stdout is a terminal. +func resolveOutputFormat(ctx context.Context, flagValue string, flagSet, stdoutTTY bool) (outputFormat, error) { + chosen := outputFormat(strings.ToLower(flagValue)) + + if !flagSet { + if v, ok := env.Lookup(ctx, envOutputFormat); ok { + candidate := outputFormat(strings.ToLower(v)) + if isKnownOutputFormat(candidate) { + chosen = candidate + } + } + } + + if !isKnownOutputFormat(chosen) { + return "", fmt.Errorf("unsupported output format %q; expected one of: text, json, csv", flagValue) + } + + if chosen == outputText && !stdoutTTY { + return outputJSON, nil + } + return chosen, nil +} + +func isKnownOutputFormat(f outputFormat) bool { + switch f { + case outputText, outputJSON, outputCSV: + return true + } + return false +} diff --git a/experimental/postgres/cmd/output_test.go b/experimental/postgres/cmd/output_test.go new file mode 100644 index 00000000000..79289a43e56 --- /dev/null +++ b/experimental/postgres/cmd/output_test.go @@ -0,0 +1,79 @@ +package postgrescmd + +import ( + "testing" + + 
"github.com/databricks/cli/libs/env" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestResolveOutputFormat_Defaults(t *testing.T) { + ctx := t.Context() + + got, err := resolveOutputFormat(ctx, "text", false, true) + require.NoError(t, err) + assert.Equal(t, outputText, got) +} + +func TestResolveOutputFormat_TextOnPipeFallsBackToJSON(t *testing.T) { + ctx := t.Context() + got, err := resolveOutputFormat(ctx, "text", false, false) + require.NoError(t, err) + assert.Equal(t, outputJSON, got) +} + +func TestResolveOutputFormat_ExplicitTextOnPipeAlsoFallsBackToJSON(t *testing.T) { + ctx := t.Context() + got, err := resolveOutputFormat(ctx, "text", true, false) + require.NoError(t, err) + assert.Equal(t, outputJSON, got) +} + +func TestResolveOutputFormat_ExplicitJSON(t *testing.T) { + ctx := t.Context() + got, err := resolveOutputFormat(ctx, "json", true, true) + require.NoError(t, err) + assert.Equal(t, outputJSON, got) +} + +func TestResolveOutputFormat_ExplicitCSV(t *testing.T) { + ctx := t.Context() + got, err := resolveOutputFormat(ctx, "csv", true, true) + require.NoError(t, err) + assert.Equal(t, outputCSV, got) +} + +func TestResolveOutputFormat_EnvVarHonoredWhenFlagNotSet(t *testing.T) { + ctx := env.Set(t.Context(), envOutputFormat, "csv") + got, err := resolveOutputFormat(ctx, "text", false, true) + require.NoError(t, err) + assert.Equal(t, outputCSV, got) +} + +func TestResolveOutputFormat_FlagOverridesEnvVar(t *testing.T) { + ctx := env.Set(t.Context(), envOutputFormat, "csv") + got, err := resolveOutputFormat(ctx, "json", true, true) + require.NoError(t, err) + assert.Equal(t, outputJSON, got) +} + +func TestResolveOutputFormat_InvalidEnvVarIgnored(t *testing.T) { + ctx := env.Set(t.Context(), envOutputFormat, "yaml") + got, err := resolveOutputFormat(ctx, "text", false, true) + require.NoError(t, err) + assert.Equal(t, outputText, got) +} + +func TestResolveOutputFormat_InvalidFlagErrors(t *testing.T) { + ctx 
:= t.Context() + _, err := resolveOutputFormat(ctx, "yaml", true, true) + assert.ErrorContains(t, err, "unsupported output format") +} + +func TestResolveOutputFormat_CaseInsensitive(t *testing.T) { + ctx := t.Context() + got, err := resolveOutputFormat(ctx, "JSON", true, true) + require.NoError(t, err) + assert.Equal(t, outputJSON, got) +} diff --git a/experimental/postgres/cmd/query.go b/experimental/postgres/cmd/query.go index fe5cc528ea7..c3078f24d82 100644 --- a/experimental/postgres/cmd/query.go +++ b/experimental/postgres/cmd/query.go @@ -4,6 +4,7 @@ import ( "context" "errors" "fmt" + "io" "strings" "time" @@ -25,6 +26,11 @@ type queryFlags struct { database string connectTimeout time.Duration maxRetries int + + // outputFormat is the raw flag value. resolveOutputFormat turns it into + // the effective format (which may differ when stdout is piped). + outputFormat string + outputFormatSet bool } func newQueryCmd() *cobra.Command { @@ -33,15 +39,29 @@ func newQueryCmd() *cobra.Command { cmd := &cobra.Command{ Use: "query [SQL]", Short: "Run a SQL statement against a Lakebase Postgres endpoint", - Long: `Execute a single SQL statement against a Lakebase Postgres endpoint and -render the result as text. + Long: `Execute a single SQL statement against a Lakebase Postgres endpoint. Targeting (exactly one form required): - --target STRING Autoscaling resource path - (e.g. projects/foo/branches/main/endpoints/primary) + --target STRING Provisioned instance name OR autoscaling resource path + (e.g. my-instance, projects/foo/branches/main/endpoints/primary) --project ID Autoscaling project ID --branch ID Autoscaling branch ID (default: auto-select if exactly one) - --endpoint ID Autoscaling endpoint ID (default: auto-select if exactly one) + --endpoint ID Autoscaling endpoint ID + +Output: + --output text Aligned table for rows-producing statements (default). 
+ Falls back to JSON when stdout is not a terminal so + scripts piping the output get machine-readable results. + --output json Top-level array of row objects, streamed for + rows-producing statements. Command-only statements + emit a single {"command": "...", "rows_affected": N} + object. Numbers, booleans, NULL, jsonb, timestamps + render with their JSON-native types. + --output csv Header row + one CSV row per result row, streamed. + Command-only statements write the command tag to + stderr. + +DATABRICKS_OUTPUT_FORMAT is honoured when --output is not explicitly set. This is an experimental command. The flag set, output shape, and supported target kinds will expand in subsequent releases. @@ -49,10 +69,6 @@ target kinds will expand in subsequent releases. Limitations (this release): - Single SQL statement per invocation (multi-statement support comes later). - - Text output only. JSON and CSV output come in a follow-up release. - - Only Lakebase Autoscaling endpoints are supported. Provisioned instance - support comes in a follow-up release; use 'databricks psql ' as a - workaround for now. - No interactive REPL. 'databricks psql' continues to own that surface. - Multi-statement strings (e.g. "SELECT 1; SELECT 2") are not supported. - The OAuth token is generated once per invocation and is valid for 1h. @@ -61,17 +77,26 @@ Limitations (this release): Args: cobra.ExactArgs(1), PreRunE: root.MustWorkspaceClient, RunE: func(cmd *cobra.Command, args []string) error { + f.outputFormatSet = cmd.Flag("output").Changed return runQuery(cmd.Context(), cmd, args[0], f) }, } - cmd.Flags().StringVar(&f.target, "target", "", "Autoscaling resource path (e.g. 
projects/foo/branches/main/endpoints/primary)") + cmd.Flags().StringVar(&f.target, "target", "", "Provisioned instance name OR autoscaling resource path") cmd.Flags().StringVar(&f.project, "project", "", "Autoscaling project ID") cmd.Flags().StringVar(&f.branch, "branch", "", "Autoscaling branch ID (default: auto-select if exactly one)") cmd.Flags().StringVar(&f.endpoint, "endpoint", "", "Autoscaling endpoint ID (default: auto-select if exactly one)") cmd.Flags().StringVarP(&f.database, "database", "d", defaultDatabase, "Database name") cmd.Flags().DurationVar(&f.connectTimeout, "connect-timeout", defaultConnectTimeout, "Connect timeout") cmd.Flags().IntVar(&f.maxRetries, "max-retries", 3, "Total connect attempts on idle/waking endpoint (must be >= 1; 1 disables retry)") + cmd.Flags().StringVarP(&f.outputFormat, "output", "o", string(outputText), "Output format: text, json, or csv") + cmd.RegisterFlagCompletionFunc("output", func(*cobra.Command, []string, string) ([]string, cobra.ShellCompDirective) { + out := make([]string, len(allOutputFormats)) + for i, f := range allOutputFormats { + out[i] = string(f) + } + return out, cobra.ShellCompDirectiveNoFileComp + }) cmd.MarkFlagsMutuallyExclusive("target", "project") cmd.MarkFlagsMutuallyExclusive("target", "branch") @@ -95,6 +120,12 @@ func runQuery(ctx context.Context, cmd *cobra.Command, sql string, f queryFlags) return err } + stdoutTTY := cmdio.SupportsColor(ctx, cmd.OutOrStdout()) + format, err := resolveOutputFormat(ctx, f.outputFormat, f.outputFormatSet, stdoutTTY) + if err != nil { + return err + } + resolved, err := resolveTarget(ctx, f.targetingFlags) if err != nil { return err @@ -126,10 +157,19 @@ func runQuery(ctx context.Context, cmd *cobra.Command, sql string, f queryFlags) } defer conn.Close(context.WithoutCancel(ctx)) - result, err := executeOne(ctx, conn, sql) - if err != nil { - return err - } + sink := newSink(format, cmd.OutOrStdout(), cmd.ErrOrStderr()) + return executeOne(ctx, conn, sql, sink) 
+} - return renderText(cmd.OutOrStdout(), result) +// newSink returns the rowSink for the chosen output format. Kept separate +// from runQuery so tests can build sinks without going through pgx. +func newSink(format outputFormat, out, stderr io.Writer) rowSink { + switch format { + case outputJSON: + return newJSONSink(out, stderr) + case outputCSV: + return newCSVSink(out, stderr) + default: + return newTextSink(out) + } } diff --git a/experimental/postgres/cmd/render.go b/experimental/postgres/cmd/render.go index ff923c4a92e..bc45c89e0d0 100644 --- a/experimental/postgres/cmd/render.go +++ b/experimental/postgres/cmd/render.go @@ -5,59 +5,68 @@ import ( "io" "strings" "text/tabwriter" + + "github.com/jackc/pgx/v5/pgconn" ) -// queryResult is the rendered shape of a single SQL execution. PR 1 only -// renders text; later PRs add JSON and CSV against the same struct. +// textSink renders results as plain text: a tabwriter-aligned table for +// rows-producing statements, the command tag for command-only ones. // -// columns is empty for command-only statements (INSERT, CREATE DATABASE, ...); -// rows is empty when no rows were returned (or for command-only statements). -type queryResult struct { - SQL string - // CommandTag is the Postgres command tag for the statement (e.g. "INSERT 0 5", - // "CREATE DATABASE"). Always set; used for command-only statements and as a - // trailer for rows-producing ones. - CommandTag string - Columns []string - Rows [][]string +// Text output buffers all rows because tabwriter needs the widest cell in each +// column before it can align. Streaming output is provided by the JSON and CSV +// sinks; users with huge result sets should pick those. +type textSink struct { + out io.Writer + columns []string + rows [][]string } -// IsRowsProducing reports whether the statement returned a row description. -// Determined at runtime via FieldDescriptions() rather than by parsing the -// leading SQL keyword: `INSERT ... 
RETURNING` and CTEs ending in a SELECT are -// rows-producing despite their leading keywords. -func (r *queryResult) IsRowsProducing() bool { - return len(r.Columns) > 0 +func newTextSink(out io.Writer) *textSink { + return &textSink{out: out} } -// renderText writes a result in plain text. -// -// For rows-producing statements we use a tabwriter-aligned table followed by -// a `(N rows)` footer, mimicking psql's compact text shape. For command-only -// statements we just print the command tag. -// -// PR 1 always uses the static (buffered) shape. The interactive table viewer -// for >30 rows lands in a later PR alongside the multi-input output shapes. -func renderText(out io.Writer, r *queryResult) error { - if !r.IsRowsProducing() { - _, err := fmt.Fprintln(out, r.CommandTag) +func (s *textSink) Begin(fields []pgconn.FieldDescription) error { + s.columns = make([]string, len(fields)) + for i, f := range fields { + s.columns[i] = f.Name + } + return nil +} + +func (s *textSink) Row(values []any) error { + row := make([]string, len(values)) + for i, v := range values { + row[i] = textValue(v) + } + s.rows = append(s.rows, row) + return nil +} + +func (s *textSink) End(commandTag string) error { + if len(s.columns) == 0 { + _, err := fmt.Fprintln(s.out, commandTag) return err } - tw := tabwriter.NewWriter(out, 0, 0, 2, ' ', 0) - fmt.Fprintln(tw, strings.Join(r.Columns, "\t")) - fmt.Fprintln(tw, strings.Join(headerSeparator(r.Columns), "\t")) - for _, row := range r.Rows { + tw := tabwriter.NewWriter(s.out, 0, 0, 2, ' ', 0) + fmt.Fprintln(tw, strings.Join(s.columns, "\t")) + fmt.Fprintln(tw, strings.Join(headerSeparator(s.columns), "\t")) + for _, row := range s.rows { fmt.Fprintln(tw, strings.Join(row, "\t")) } if err := tw.Flush(); err != nil { return err } - _, err := fmt.Fprintf(out, "(%d %s)\n", len(r.Rows), pluralize(len(r.Rows), "row", "rows")) + _, err := fmt.Fprintf(s.out, "(%d %s)\n", len(s.rows), pluralize(len(s.rows), "row", "rows")) return err } +// 
OnError for text sinks is a no-op: text output prints whatever rows have +// already been collected, with no open structure to close. The caller +// surfaces the error separately (cobra's default error rendering). +func (s *textSink) OnError(err error) {} + func headerSeparator(cols []string) []string { out := make([]string, len(cols)) for i, c := range cols { diff --git a/experimental/postgres/cmd/render_csv.go b/experimental/postgres/cmd/render_csv.go new file mode 100644 index 00000000000..940e11324f5 --- /dev/null +++ b/experimental/postgres/cmd/render_csv.go @@ -0,0 +1,80 @@ +package postgrescmd + +import ( + "encoding/csv" + "fmt" + "io" + + "github.com/jackc/pgx/v5/pgconn" +) + +// csvSink streams query results as CSV. Header row is written on Begin, each +// data row is written and flushed individually so large exports do not buffer +// in memory. +// +// For command-only statements CSV has nothing meaningful to emit (no header, +// no rows): we write the command tag to stderr so machine consumers reading +// stdout still receive an empty document, while humans get a confirmation. +type csvSink struct { + out io.Writer + stderr io.Writer + w *csv.Writer + + // rowsProducing is true once Begin saw a non-empty fields slice. End + // uses it to decide whether to write the command-tag stderr line. 
+ rowsProducing bool +} + +func newCSVSink(out, stderr io.Writer) *csvSink { + return &csvSink{out: out, stderr: stderr, w: csv.NewWriter(out)} +} + +func (s *csvSink) Begin(fields []pgconn.FieldDescription) error { + if len(fields) == 0 { + return nil + } + s.rowsProducing = true + + header := make([]string, len(fields)) + for i, f := range fields { + header[i] = f.Name + } + if err := s.w.Write(header); err != nil { + return fmt.Errorf("write CSV header: %w", err) + } + s.w.Flush() + return s.w.Error() +} + +func (s *csvSink) Row(values []any) error { + row := make([]string, len(values)) + for i, v := range values { + // CSV represents NULL as an empty field, matching `psql --csv`. + if v == nil { + row[i] = "" + continue + } + row[i] = textValue(v) + } + if err := s.w.Write(row); err != nil { + return fmt.Errorf("write CSV row: %w", err) + } + s.w.Flush() + return s.w.Error() +} + +func (s *csvSink) End(commandTag string) error { + if !s.rowsProducing { + _, err := fmt.Fprintln(s.stderr, commandTag) + return err + } + s.w.Flush() + return s.w.Error() +} + +// OnError flushes whatever is buffered in the csv.Writer so the partial result +// is visible to the consumer. csv.Writer has no concept of "open structure", +// so there is nothing more to do. 
+func (s *csvSink) OnError(err error) { + s.w.Flush() +} diff --git a/experimental/postgres/cmd/render_csv_test.go b/experimental/postgres/cmd/render_csv_test.go new file mode 100644 index 00000000000..35d1c3596f1 --- /dev/null +++ b/experimental/postgres/cmd/render_csv_test.go @@ -0,0 +1,49 @@ +package postgrescmd + +import ( + "bytes" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestCSVSink_TwoRows(t *testing.T) { + var stdout, stderr bytes.Buffer + s := newCSVSink(&stdout, &stderr) + require.NoError(t, s.Begin(fields("id", "name"))) + require.NoError(t, s.Row([]any{int64(1), "alice"})) + require.NoError(t, s.Row([]any{int64(2), "bob"})) + require.NoError(t, s.End("SELECT 2")) + + assert.Equal(t, "id,name\n1,alice\n2,bob\n", stdout.String()) + assert.Empty(t, stderr.String()) +} + +func TestCSVSink_NULLEmptyField(t *testing.T) { + var stdout, stderr bytes.Buffer + s := newCSVSink(&stdout, &stderr) + require.NoError(t, s.Begin(fields("id", "note"))) + require.NoError(t, s.Row([]any{int64(1), nil})) + require.NoError(t, s.End("SELECT 1")) + + assert.Equal(t, "id,note\n1,\n", stdout.String()) +} + +func TestCSVSink_CommandOnly(t *testing.T) { + var stdout, stderr bytes.Buffer + s := newCSVSink(&stdout, &stderr) + require.NoError(t, s.Begin(nil)) + require.NoError(t, s.End("CREATE DATABASE")) + assert.Empty(t, stdout.String()) + assert.Equal(t, "CREATE DATABASE\n", stderr.String()) +} + +func TestCSVSink_QuotesFieldsWithCommas(t *testing.T) { + var stdout, stderr bytes.Buffer + s := newCSVSink(&stdout, &stderr) + require.NoError(t, s.Begin(fields("note"))) + require.NoError(t, s.Row([]any{"a,b"})) + require.NoError(t, s.End("SELECT 1")) + assert.Contains(t, stdout.String(), `"a,b"`) +} diff --git a/experimental/postgres/cmd/render_json.go b/experimental/postgres/cmd/render_json.go new file mode 100644 index 00000000000..1d9a53a8e8d --- /dev/null +++ b/experimental/postgres/cmd/render_json.go @@ -0,0 +1,173 @@ 
+package postgrescmd + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "strconv" + + "github.com/jackc/pgx/v5/pgconn" +) + +// jsonSink streams query results as JSON. +// +// For rows-producing statements the output is a top-level array of row +// objects. We use the separator-before-element pattern to avoid the +// "rewrite the trailing comma" trick and keep the JSON parseable even when +// iteration ends with a partial result (caller closes the array on OnError). +// +// For command-only statements the output is a single object describing the +// command tag. +type jsonSink struct { + out io.Writer + stderr io.Writer + + // columns are the disambiguated column names: duplicates beyond the first + // occurrence are renamed to "__2", "__3", etc. Postgres + // allows duplicate output names (`SELECT 1, 1`, joins with two unaliased + // `id` columns) but JSON consumers usually want unique keys; we rename + // deterministically and warn once on stderr. + columns []string + oids []uint32 + + // hasOpenedArray is true once the leading `[\n` has been written. Used + // by OnError to decide whether to emit the closing `]\n` to keep stdout + // well-formed. + hasOpenedArray bool + // rowsWritten counts emitted rows so the separator decision is trivial: + // 0 means "first row, no separator", anything else means "separator first". + rowsWritten int +} + +func newJSONSink(out, stderr io.Writer) *jsonSink { + return &jsonSink{out: out, stderr: stderr} +} + +func (s *jsonSink) Begin(fields []pgconn.FieldDescription) error { + if len(fields) == 0 { + // Command-only; we wait until End to emit the {"command": ...} object. 
+ return nil + } + + s.columns = make([]string, len(fields)) + s.oids = make([]uint32, len(fields)) + seen := make(map[string]int, len(fields)) + dupes := false + for i, f := range fields { + s.oids[i] = f.DataTypeOID + name := f.Name + seen[name]++ + if seen[name] > 1 { + dupes = true + name = fmt.Sprintf("%s__%d", f.Name, seen[name]) + } + s.columns[i] = name + } + if dupes { + fmt.Fprintln(s.stderr, "Warning: query returned duplicate column names; renamed duplicates to __N. Use AS aliases for stable names.") + } + + if _, err := io.WriteString(s.out, "[\n"); err != nil { + return err + } + s.hasOpenedArray = true + return nil +} + +func (s *jsonSink) Row(values []any) error { + if s.rowsWritten > 0 { + if _, err := io.WriteString(s.out, ",\n"); err != nil { + return err + } + } + + // Build the row object as a *map* of column to converted value, then let + // json.Marshal handle the encoding. We don't preserve key insertion order + // (json package sorts map keys), which is fine for machine consumers; the + // columns slice is the canonical order. + // + // Using ordered emission would require a manual writer. Worth the cost + // only if a downstream consumer needs schema-positional output, which + // none do today. + obj := make(map[string]any, len(s.columns)) + for i, name := range s.columns { + obj[name] = jsonValueWithOID(values[i], s.oids[i]) + } + + var buf bytes.Buffer + enc := json.NewEncoder(&buf) + enc.SetEscapeHTML(false) + if err := enc.Encode(obj); err != nil { + return fmt.Errorf("encode row: %w", err) + } + // json.Encoder always writes a trailing newline; trim it so our outer + // formatting controls the layout. + out := bytes.TrimRight(buf.Bytes(), "\n") + if _, err := s.out.Write(out); err != nil { + return err + } + s.rowsWritten++ + return nil +} + +func (s *jsonSink) End(commandTag string) error { + if s.hasOpenedArray { + _, err := io.WriteString(s.out, "\n]\n") + return err + } + // Command-only path: emit a single object. 
+ obj := map[string]any{"command": commandTagVerb(commandTag)} + if rows, ok := commandTagRowCount(commandTag); ok { + obj["rows_affected"] = rows + } + + var buf bytes.Buffer + enc := json.NewEncoder(&buf) + enc.SetEscapeHTML(false) + if err := enc.Encode(obj); err != nil { + return fmt.Errorf("encode command tag: %w", err) + } + _, err := s.out.Write(buf.Bytes()) + return err +} + +// OnError closes the array cleanly so stdout remains parseable JSON. The +// caller still propagates the original error, which the command writes to +// stderr. +func (s *jsonSink) OnError(err error) { + if !s.hasOpenedArray { + return + } + // Best-effort; if this Write fails the stream is already corrupted + // and there is nothing more we can do. + _, _ = io.WriteString(s.out, "\n]\n") +} + +// commandTagVerb extracts the leading verb from a Postgres command tag (e.g. +// "INSERT 0 5" -> "INSERT"). Returns the input unchanged if there is no space. +func commandTagVerb(tag string) string { + for i, r := range tag { + if r == ' ' { + return tag[:i] + } + } + return tag +} + +// commandTagRowCount extracts the trailing row count from a Postgres command +// tag. INSERT tags have the shape "INSERT "; UPDATE/DELETE/SELECT +// have "VERB ". Returns ok=false for tags without a trailing integer +// (e.g. "CREATE DATABASE", "SET"). 
+func commandTagRowCount(tag string) (int64, bool) { + for i := len(tag) - 1; i >= 0; i-- { + if tag[i] == ' ' { + n, err := strconv.ParseInt(tag[i+1:], 10, 64) + if err != nil { + return 0, false + } + return n, true + } + } + return 0, false +} diff --git a/experimental/postgres/cmd/render_json_test.go b/experimental/postgres/cmd/render_json_test.go new file mode 100644 index 00000000000..a2617b27bc6 --- /dev/null +++ b/experimental/postgres/cmd/render_json_test.go @@ -0,0 +1,118 @@ +package postgrescmd + +import ( + "bytes" + "testing" + + "github.com/jackc/pgx/v5/pgconn" + "github.com/jackc/pgx/v5/pgtype" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func fieldsWithOIDs(names []string, oids []uint32) []pgconn.FieldDescription { + out := make([]pgconn.FieldDescription, len(names)) + for i, n := range names { + out[i] = pgconn.FieldDescription{Name: n, DataTypeOID: oids[i]} + } + return out +} + +func TestJSONSink_TwoRows(t *testing.T) { + var stdout, stderr bytes.Buffer + s := newJSONSink(&stdout, &stderr) + + require.NoError(t, s.Begin(fieldsWithOIDs([]string{"id", "name"}, []uint32{pgtype.Int8OID, pgtype.TextOID}))) + require.NoError(t, s.Row([]any{int64(1), "alice"})) + require.NoError(t, s.Row([]any{int64(2), "bob"})) + require.NoError(t, s.End("SELECT 2")) + + assert.Equal(t, + "[\n"+ + `{"id":1,"name":"alice"}`+",\n"+ + `{"id":2,"name":"bob"}`+ + "\n]\n", + stdout.String(), + ) + assert.Empty(t, stderr.String()) +} + +func TestJSONSink_EmptyRowsProducing(t *testing.T) { + var stdout, stderr bytes.Buffer + s := newJSONSink(&stdout, &stderr) + require.NoError(t, s.Begin(fieldsWithOIDs([]string{"id"}, []uint32{pgtype.Int8OID}))) + require.NoError(t, s.End("SELECT 0")) + assert.Equal(t, "[\n\n]\n", stdout.String()) +} + +func TestJSONSink_CommandOnly_WithRowCount(t *testing.T) { + var stdout, stderr bytes.Buffer + s := newJSONSink(&stdout, &stderr) + require.NoError(t, s.Begin(nil)) + require.NoError(t, s.End("INSERT 0 
5")) + assert.JSONEq(t, `{"command":"INSERT","rows_affected":5}`, stdout.String()) +} + +func TestJSONSink_CommandOnly_NoRowCount(t *testing.T) { + var stdout, stderr bytes.Buffer + s := newJSONSink(&stdout, &stderr) + require.NoError(t, s.Begin(nil)) + require.NoError(t, s.End("CREATE DATABASE")) + assert.JSONEq(t, `{"command":"CREATE"}`, stdout.String()) +} + +func TestJSONSink_DuplicateColumns(t *testing.T) { + var stdout, stderr bytes.Buffer + s := newJSONSink(&stdout, &stderr) + require.NoError(t, s.Begin(fieldsWithOIDs([]string{"id", "id", "id"}, []uint32{pgtype.Int8OID, pgtype.Int8OID, pgtype.Int8OID}))) + require.NoError(t, s.Row([]any{int64(1), int64(2), int64(3)})) + require.NoError(t, s.End("SELECT 1")) + + assert.Contains(t, stdout.String(), `"id":1`) + assert.Contains(t, stdout.String(), `"id__2":2`) + assert.Contains(t, stdout.String(), `"id__3":3`) + assert.Contains(t, stderr.String(), "duplicate column names") +} + +func TestJSONSink_OnError_AfterRows(t *testing.T) { + var stdout, stderr bytes.Buffer + s := newJSONSink(&stdout, &stderr) + require.NoError(t, s.Begin(fieldsWithOIDs([]string{"id"}, []uint32{pgtype.Int8OID}))) + require.NoError(t, s.Row([]any{int64(1)})) + s.OnError(assert.AnError) + + assert.Contains(t, stdout.String(), "[\n") + assert.Contains(t, stdout.String(), `{"id":1}`) + assert.Contains(t, stdout.String(), "\n]\n") +} + +func TestJSONSink_OnError_BeforeBegin(t *testing.T) { + var stdout, stderr bytes.Buffer + s := newJSONSink(&stdout, &stderr) + s.OnError(assert.AnError) + assert.Empty(t, stdout.String()) +} + +func TestCommandTagParse(t *testing.T) { + tests := []struct { + tag string + verb string + rows int64 + hasCount bool + }{ + {"INSERT 0 5", "INSERT", 5, true}, + {"UPDATE 3", "UPDATE", 3, true}, + {"DELETE 0", "DELETE", 0, true}, + {"SELECT 100", "SELECT", 100, true}, + {"CREATE DATABASE", "CREATE", 0, false}, + {"SET", "SET", 0, false}, + } + for _, tc := range tests { + assert.Equal(t, tc.verb, commandTagVerb(tc.tag), 
"verb for %q", tc.tag) + count, ok := commandTagRowCount(tc.tag) + assert.Equal(t, tc.hasCount, ok, "hasCount for %q", tc.tag) + if tc.hasCount { + assert.Equal(t, tc.rows, count, "rows for %q", tc.tag) + } + } +} diff --git a/experimental/postgres/cmd/render_test.go b/experimental/postgres/cmd/render_test.go index 29aeb3c36fc..06190323e43 100644 --- a/experimental/postgres/cmd/render_test.go +++ b/experimental/postgres/cmd/render_test.go @@ -4,21 +4,29 @@ import ( "bytes" "testing" + "github.com/jackc/pgx/v5/pgconn" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestRenderText_RowsProducing(t *testing.T) { - r := &queryResult{ - Columns: []string{"id", "name"}, - Rows: [][]string{ - {"1", "alice"}, - {"2", "bob"}, - }, - CommandTag: "SELECT 2", +// fields is a small helper to build []pgconn.FieldDescription with just names +// (no OIDs), so renderer tests don't need to know about Postgres OIDs. +func fields(names ...string) []pgconn.FieldDescription { + out := make([]pgconn.FieldDescription, len(names)) + for i, n := range names { + out[i] = pgconn.FieldDescription{Name: n} } + return out +} + +func TestTextSink_RowsProducing(t *testing.T) { var buf bytes.Buffer - require.NoError(t, renderText(&buf, r)) + s := newTextSink(&buf) + + require.NoError(t, s.Begin(fields("id", "name"))) + require.NoError(t, s.Row([]any{int64(1), "alice"})) + require.NoError(t, s.Row([]any{int64(2), "bob"})) + require.NoError(t, s.End("SELECT 2")) assert.Equal(t, "id name\n"+ @@ -30,38 +38,36 @@ func TestRenderText_RowsProducing(t *testing.T) { ) } -func TestRenderText_SingleRow(t *testing.T) { - r := &queryResult{ - Columns: []string{"id"}, - Rows: [][]string{{"42"}}, - CommandTag: "SELECT 1", - } +func TestTextSink_SingleRow(t *testing.T) { var buf bytes.Buffer - require.NoError(t, renderText(&buf, r)) + s := newTextSink(&buf) + require.NoError(t, s.Begin(fields("id"))) + require.NoError(t, s.Row([]any{int64(42)})) + require.NoError(t, s.End("SELECT 
1")) assert.Contains(t, buf.String(), "(1 row)\n") } -func TestRenderText_Empty(t *testing.T) { - r := &queryResult{ - Columns: []string{"id", "name"}, - CommandTag: "SELECT 0", - } +func TestTextSink_Empty(t *testing.T) { var buf bytes.Buffer - require.NoError(t, renderText(&buf, r)) + s := newTextSink(&buf) + require.NoError(t, s.Begin(fields("id", "name"))) + require.NoError(t, s.End("SELECT 0")) assert.Contains(t, buf.String(), "(0 rows)\n") } -func TestRenderText_CommandOnly(t *testing.T) { - r := &queryResult{ - CommandTag: "INSERT 0 5", - } +func TestTextSink_CommandOnly(t *testing.T) { var buf bytes.Buffer - require.NoError(t, renderText(&buf, r)) + s := newTextSink(&buf) + require.NoError(t, s.Begin(nil)) + require.NoError(t, s.End("INSERT 0 5")) assert.Equal(t, "INSERT 0 5\n", buf.String()) } -func TestQueryResultIsRowsProducing(t *testing.T) { - assert.False(t, (&queryResult{}).IsRowsProducing()) - assert.False(t, (&queryResult{CommandTag: "INSERT 0 1"}).IsRowsProducing()) - assert.True(t, (&queryResult{Columns: []string{"a"}}).IsRowsProducing()) +func TestTextSink_NULLRendersAsNULL(t *testing.T) { + var buf bytes.Buffer + s := newTextSink(&buf) + require.NoError(t, s.Begin(fields("id"))) + require.NoError(t, s.Row([]any{nil})) + require.NoError(t, s.End("SELECT 1")) + assert.Contains(t, buf.String(), "NULL") } diff --git a/experimental/postgres/cmd/targeting.go b/experimental/postgres/cmd/targeting.go index 5e72840f952..78e230adaac 100644 --- a/experimental/postgres/cmd/targeting.go +++ b/experimental/postgres/cmd/targeting.go @@ -8,6 +8,7 @@ import ( "github.com/databricks/cli/libs/cmdctx" "github.com/databricks/cli/libs/lakebase/target" "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/service/database" "github.com/databricks/databricks-sdk-go/service/postgres" ) @@ -71,9 +72,10 @@ func validateTargeting(f targetingFlags) error { } // resolveTarget translates the validated flags into a resolvedTarget. 
-// PR 1 supports autoscaling targeting only; provisioned support is added in -// the next PR. A provisioned-shaped --target returns a clear error pointing at -// the experimental status. +// +// --target accepts either an autoscaling resource path (starts with "projects/") +// or a provisioned instance name (everything else). Granular flags +// (--project, --branch, --endpoint) target autoscaling only. func resolveTarget(ctx context.Context, f targetingFlags) (*resolvedTarget, error) { w := cmdctx.WorkspaceClient(ctx) @@ -86,9 +88,7 @@ func resolveTarget(ctx context.Context, f targetingFlags) (*resolvedTarget, erro return resolveAutoscaling(ctx, w, spec) case f.target != "": - // Provisioned-shaped target. Out of scope for this PR; will be wired in - // the follow-up PR alongside JSON/CSV output. - return nil, errors.New("provisioned instances are not yet supported by this experimental command; use 'databricks psql ' for now") + return resolveProvisioned(ctx, w, f.target) default: spec := target.AutoscalingSpec{ @@ -100,6 +100,41 @@ func resolveTarget(ctx context.Context, f targetingFlags) (*resolvedTarget, erro } } +// resolveProvisioned looks up a provisioned instance and issues a token. The +// instance must be in the AVAILABLE state; transitional states return an +// error pointing the user at the lifecycle they are waiting on. 
+func resolveProvisioned(ctx context.Context, w *databricks.WorkspaceClient, instanceName string) (*resolvedTarget, error) { + instance, err := target.GetProvisioned(ctx, w, instanceName) + if err != nil { + return nil, err + } + + if instance.State != database.DatabaseInstanceStateAvailable { + return nil, fmt.Errorf("database instance %q is not ready for accepting connections (state: %s)", instance.Name, instance.State) + } + if instance.ReadWriteDns == "" { + return nil, fmt.Errorf("database instance %q has no read/write DNS yet", instance.Name) + } + + user, err := w.CurrentUser.Me(ctx) + if err != nil { + return nil, fmt.Errorf("failed to get current user: %w", err) + } + + token, err := target.ProvisionedCredential(ctx, w, instance.Name) + if err != nil { + return nil, err + } + + return &resolvedTarget{ + Kind: kindProvisioned, + Host: instance.ReadWriteDns, + Username: user.UserName, + Token: token, + DisplayName: instance.Name, + }, nil +} + // resolveAutoscaling expands a partial spec into a fully-resolved endpoint and // issues a short-lived OAuth token. Missing branch/endpoint IDs are // auto-selected when exactly one candidate exists; ambiguity propagates as an diff --git a/experimental/postgres/cmd/value.go b/experimental/postgres/cmd/value.go new file mode 100644 index 00000000000..3049b44a82a --- /dev/null +++ b/experimental/postgres/cmd/value.go @@ -0,0 +1,152 @@ +package postgrescmd + +import ( + "encoding/base64" + "encoding/hex" + "encoding/json" + "fmt" + "math" + "math/big" + "strconv" + "time" + + "github.com/jackc/pgx/v5/pgtype" +) + +// safeIntegerBound is the largest absolute integer value that can be +// represented exactly in IEEE 754 double precision. Beyond this, encoding an +// int64 as a JSON number silently loses precision in JavaScript-style +// consumers. We render those as JSON strings to preserve the original digits. 
+const safeIntegerBound = 1<<53 - 1 + +// textValue renders a Go value (as decoded by pgx) to its canonical Postgres +// text representation. Used by --output text and --output csv. +// +// NULL renders as the literal "NULL" so it lines up with the column rather +// than appearing as an empty cell. CSV converts that back to an empty field +// at write time (matches `psql --csv`). +func textValue(v any) string { + if v == nil { + return "NULL" + } + + switch x := v.(type) { + case string: + return x + case []byte: + return `\x` + hex.EncodeToString(x) + case bool: + if x { + return "t" + } + return "f" + case time.Time: + return x.Format(time.RFC3339Nano) + case fmt.Stringer: + return x.String() + } + + return fmt.Sprintf("%v", v) +} + +// jsonValue renders a Go value (as decoded by pgx) to a JSON-encodable +// representation. Returns a value the standard json.Marshal can handle +// directly and the JSON shape we want; never returns Go values that would +// silently lose information (e.g. NaN, oversized integers). +// +// The mapping intentionally favours machine-friendly output: +// - jsonb / json bytes round-trip as raw JSON (preserves bigint precision +// inside JSON values, e.g. {"id": 9007199254740993}). +// - bytea encodes as base64. +// - timestamps render in RFC3339 with subsecond precision. +// - Postgres NaN / +Inf / -Inf become JSON strings (JSON has no IEEE-special). +// - Integers outside ±2^53 become JSON strings to preserve precision. +// - Numerics, intervals, geometric types, and unknown types fall back to +// the canonical Postgres text representation as a JSON string. 
+func jsonValue(v any) any { + if v == nil { + return nil + } + + switch x := v.(type) { + case bool: + return x + case string: + return x + case int8, int16, int32, int, uint8, uint16, uint32: + return x + case int64: + if x > safeIntegerBound || x < -safeIntegerBound { + return strconv.FormatInt(x, 10) + } + return x + case uint64: + if x > safeIntegerBound { + return strconv.FormatUint(x, 10) + } + return x + case float32: + return jsonFloat(float64(x)) + case float64: + return jsonFloat(x) + case []byte: + // Postgres jsonb / json arrive as []byte holding raw JSON. Anything + // else we'd like to base64-encode. We can't tell them apart from the + // Go type alone; the sink calls jsonValueWithOID for oid-aware + // disambiguation. This bare path is the conservative fallback and + // treats unknown bytes as base64 (lossless and correct for bytea). + return base64.StdEncoding.EncodeToString(x) + case time.Time: + return x.UTC().Format(time.RFC3339Nano) + case *big.Int: + // numeric without scale; preserve as string to keep precision. + return x.String() + case fmt.Stringer: + return x.String() + } + + return fmt.Sprintf("%v", v) +} + +// jsonFloat handles the IEEE-special cases that JSON cannot represent. +// Finite values pass through unchanged. +func jsonFloat(f float64) any { + switch { + case math.IsNaN(f): + return "NaN" + case math.IsInf(f, 1): + return "Infinity" + case math.IsInf(f, -1): + return "-Infinity" + } + return f +} + +// jsonValueWithOID applies oid-aware overrides on top of jsonValue. The two +// places this matters today are JSON/JSONB and bytea: both arrive from pgx as +// []byte but want different JSON shapes (raw JSON passthrough vs base64). +func jsonValueWithOID(v any, oid uint32) any { + if v == nil { + return nil + } + + switch oid { + case pgtype.JSONOID, pgtype.JSONBOID: + // pgx returns json/jsonb as already-decoded Go values when no codec + // is registered; with the default codec, they decode to map/slice/etc. 
+ // In QueryExecModeExec text-mode, pgx returns the raw JSON bytes as + // string (since the wire is text). We accept both shapes. + switch x := v.(type) { + case []byte: + return json.RawMessage(x) + case string: + return json.RawMessage(x) + } + case pgtype.ByteaOID: + if b, ok := v.([]byte); ok { + return base64.StdEncoding.EncodeToString(b) + } + } + + return jsonValue(v) +} diff --git a/experimental/postgres/cmd/value_test.go b/experimental/postgres/cmd/value_test.go new file mode 100644 index 00000000000..092fc6f7284 --- /dev/null +++ b/experimental/postgres/cmd/value_test.go @@ -0,0 +1,84 @@ +package postgrescmd + +import ( + "encoding/json" + "math" + "testing" + "time" + + "github.com/jackc/pgx/v5/pgtype" + "github.com/stretchr/testify/assert" +) + +func TestJSONValue_PrimitiveTypes(t *testing.T) { + assert.Equal(t, true, jsonValue(true)) + assert.Equal(t, "hello", jsonValue("hello")) + assert.Equal(t, int64(42), jsonValue(int64(42))) + assert.InDelta(t, 3.14, jsonValue(float64(3.14)), 1e-9) +} + +func TestJSONValue_NULL(t *testing.T) { + assert.Nil(t, jsonValue(nil)) +} + +func TestJSONValue_FloatSpecials(t *testing.T) { + assert.Equal(t, "NaN", jsonValue(math.NaN())) + assert.Equal(t, "Infinity", jsonValue(math.Inf(1))) + assert.Equal(t, "-Infinity", jsonValue(math.Inf(-1))) +} + +func TestJSONValue_LargeIntPreservedAsString(t *testing.T) { + big := int64(1<<53 + 1) + assert.Equal(t, "9007199254740993", jsonValue(big)) + + negBig := -int64(1<<53 + 1) + assert.Equal(t, "-9007199254740993", jsonValue(negBig)) +} + +func TestJSONValue_SafeIntPreservedAsNumber(t *testing.T) { + safe := int64(1<<53 - 1) + assert.Equal(t, safe, jsonValue(safe)) +} + +func TestJSONValue_TimestampToRFC3339(t *testing.T) { + tm := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC) + v := jsonValue(tm) + assert.Equal(t, "2024-01-15T10:30:00Z", v) +} + +func TestJSONValueWithOID_JSONBPassthrough(t *testing.T) { + raw := []byte(`{"id":9007199254740993,"name":"alice"}`) + v := 
jsonValueWithOID(raw, pgtype.JSONBOID) + + encoded, err := json.Marshal(v) + assert.NoError(t, err) + assert.JSONEq(t, string(raw), string(encoded)) +} + +func TestJSONValueWithOID_ByteaToBase64(t *testing.T) { + v := jsonValueWithOID([]byte{0xde, 0xad, 0xbe, 0xef}, pgtype.ByteaOID) + assert.Equal(t, "3q2+7w==", v) +} + +func TestJSONValueWithOID_FallsBackToJSONValue(t *testing.T) { + assert.Equal(t, int64(42), jsonValueWithOID(int64(42), pgtype.Int8OID)) + assert.Nil(t, jsonValueWithOID(nil, pgtype.TextOID)) +} + +func TestTextValue_NULL(t *testing.T) { + assert.Equal(t, "NULL", textValue(nil)) +} + +func TestTextValue_Bool(t *testing.T) { + assert.Equal(t, "t", textValue(true)) + assert.Equal(t, "f", textValue(false)) +} + +func TestTextValue_BytesAsHex(t *testing.T) { + assert.Equal(t, `\xdeadbeef`, textValue([]byte{0xde, 0xad, 0xbe, 0xef})) +} + +func TestTextValue_Time(t *testing.T) { + tm := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC) + assert.Equal(t, "2024-01-15T10:30:00Z", textValue(tm)) +} From fdd0e1b8cd8778281d18aa022c0ee2a6567b06f1 Mon Sep 17 00:00:00 2001 From: simon Date: Thu, 30 Apr 2026 11:15:34 +0200 Subject: [PATCH 06/15] Address PR 2 review feedback round 1 - Fix non-finite float text: textValue had no float branch and fell through to fmt.Sprintf, which emits +Inf/-Inf instead of Postgres' Infinity/-Infinity. Added the explicit float case + tests. - Emit JSON object keys in column order, not alphabetical. The map approach inadvertently sorted keys; switched to manual ordered emission (write '{', encode key:value pairs in column order, write '}'). Added a regression test with non-alphabetical column names. - Honor --output text on a pipe instead of silently rewriting to JSON. Repo rule says "reject incompatible inputs early; never silently ignore a flag the current mode can't honor". Auto-fallback now only fires when the flag was not explicitly set (or not pinned by env). 
- Trim impossible Go types from jsonValue (pgx never decodes int8 / uint8/16/32 / uint64 from PG columns). - Drop the redundant ReadWriteDns guard in resolveProvisioned; an AVAILABLE Lakebase instance is documented to have DNS, and cmd/psql doesn't carry the same guard. - Build the unsupported-format error from allOutputFormats so the message stays in sync if a fourth format is added. - Update execute.go's QueryExecModeExec doc to acknowledge that we now call rows.Values() (not RawValues), so all sinks see Go-typed input. - Collapse empty rows-producing JSON to "[\n]\n" and matching OnError. - Add stderr warning helper (commandTagRowCount now covered for MERGE/COPY/FETCH/MOVE). - Test gaps: text +Inf, text finite float, JSON column order, OnError for csv/text sinks, CSV with embedded newline + quote. Co-authored-by: Isaac --- .../query/provisioned-targeting/output.txt | 4 - .../query/provisioned-targeting/script | 3 - .../query/provisioned-targeting/test.toml | 9 -- experimental/postgres/cmd/execute.go | 7 +- experimental/postgres/cmd/output.go | 21 +++- experimental/postgres/cmd/output_test.go | 18 +++- experimental/postgres/cmd/query.go | 5 + experimental/postgres/cmd/render_csv_test.go | 20 ++++ experimental/postgres/cmd/render_json.go | 97 +++++++++++++------ experimental/postgres/cmd/render_json_test.go | 15 ++- experimental/postgres/cmd/render_test.go | 12 +++ experimental/postgres/cmd/targeting.go | 3 - experimental/postgres/cmd/value.go | 33 +++++-- experimental/postgres/cmd/value_test.go | 11 +++ 14 files changed, 192 insertions(+), 66 deletions(-) diff --git a/acceptance/cmd/experimental/postgres/query/provisioned-targeting/output.txt b/acceptance/cmd/experimental/postgres/query/provisioned-targeting/output.txt index 0f00f8b3e44..bb7ebe1ee69 100644 --- a/acceptance/cmd/experimental/postgres/query/provisioned-targeting/output.txt +++ b/acceptance/cmd/experimental/postgres/query/provisioned-targeting/output.txt @@ -3,10 +3,6 @@ >>> musterr [CLI] 
experimental postgres query --target starting-instance SELECT 1 Error: database instance "starting-instance" is not ready for accepting connections (state: STARTING) -=== Provisioned target with no DNS should error: ->>> musterr [CLI] experimental postgres query --target no-dns-instance SELECT 1 -Error: database instance "no-dns-instance" has no read/write DNS yet - === Provisioned target not found should surface SDK 404: >>> musterr [CLI] experimental postgres query --target missing-instance SELECT 1 Error: failed to get database instance: instance not found diff --git a/acceptance/cmd/experimental/postgres/query/provisioned-targeting/script b/acceptance/cmd/experimental/postgres/query/provisioned-targeting/script index d8995c62a6c..5459e01dfcc 100644 --- a/acceptance/cmd/experimental/postgres/query/provisioned-targeting/script +++ b/acceptance/cmd/experimental/postgres/query/provisioned-targeting/script @@ -1,8 +1,5 @@ title "Provisioned target in non-AVAILABLE state should error:" trace musterr $CLI experimental postgres query --target starting-instance "SELECT 1" -title "Provisioned target with no DNS should error:" -trace musterr $CLI experimental postgres query --target no-dns-instance "SELECT 1" - title "Provisioned target not found should surface SDK 404:" trace musterr $CLI experimental postgres query --target missing-instance "SELECT 1" diff --git a/acceptance/cmd/experimental/postgres/query/provisioned-targeting/test.toml b/acceptance/cmd/experimental/postgres/query/provisioned-targeting/test.toml index 4821dab5741..25513a7a975 100644 --- a/acceptance/cmd/experimental/postgres/query/provisioned-targeting/test.toml +++ b/acceptance/cmd/experimental/postgres/query/provisioned-targeting/test.toml @@ -10,15 +10,6 @@ Response.Body = ''' } ''' -[[Server]] -Pattern = "GET /api/2.0/database/instances/no-dns-instance" -Response.Body = ''' -{ - "name": "no-dns-instance", - "state": "AVAILABLE" -} -''' - [[Server]] Pattern = "GET 
/api/2.0/database/instances/missing-instance" Response.StatusCode = 404 diff --git a/experimental/postgres/cmd/execute.go b/experimental/postgres/cmd/execute.go index 61d93bd7bc2..51f70f836a9 100644 --- a/experimental/postgres/cmd/execute.go +++ b/experimental/postgres/cmd/execute.go @@ -36,8 +36,11 @@ type rowSink interface { // closed at the end of the command, so the cached prepared statement // never gets reused. // 2. Exec mode uses Postgres' extended-protocol "exec" path with text-format -// result columns, which keeps the canonical-Postgres-text rendering for -// --output text and --output csv straightforward. +// result columns. We still call rows.Values() (not RawValues) so all +// three sinks see uniform Go-typed input; jsonValue/textValue then map +// those types back to canonical strings for text/CSV and to JSON-typed +// values for JSON. The wire format being text means pgx's decode is +// cheap (text -> Go) rather than binary -> Go. // // QueryExecModeExec still uses extended protocol with a single statement and // no implicit transaction wrap, so transaction-disallowed DDL like diff --git a/experimental/postgres/cmd/output.go b/experimental/postgres/cmd/output.go index c293b424b73..9976cd0d548 100644 --- a/experimental/postgres/cmd/output.go +++ b/experimental/postgres/cmd/output.go @@ -34,34 +34,45 @@ var allOutputFormats = []outputFormat{outputText, outputJSON, outputCSV} // values are silently ignored, matching cmd/root/io.go and aitools). // 3. The flag default ("text"). // -// Then the auto-selection rule applies: text on a non-TTY stdout falls back -// to JSON. This matches the aitools query command and means scripts piping -// stdout get machine-readable output by default. +// Then the auto-selection rule applies: a *defaulted* text mode on a non-TTY +// stdout falls back to JSON, so scripts piping the output get machine- +// readable output by default. 
An *explicit* --output text is honoured even +// on a pipe; per CLAUDE.md we don't silently override flags the user set. // // flagSet is true if the user explicitly passed --output. stdoutTTY is true // if stdout is a terminal. func resolveOutputFormat(ctx context.Context, flagValue string, flagSet, stdoutTTY bool) (outputFormat, error) { chosen := outputFormat(strings.ToLower(flagValue)) + chosenExplicit := flagSet if !flagSet { if v, ok := env.Lookup(ctx, envOutputFormat); ok { candidate := outputFormat(strings.ToLower(v)) if isKnownOutputFormat(candidate) { chosen = candidate + chosenExplicit = true } } } if !isKnownOutputFormat(chosen) { - return "", fmt.Errorf("unsupported output format %q; expected one of: text, json, csv", flagValue) + return "", fmt.Errorf("unsupported output format %q; expected one of: %s", flagValue, joinOutputFormats(allOutputFormats)) } - if chosen == outputText && !stdoutTTY { + if chosen == outputText && !stdoutTTY && !chosenExplicit { return outputJSON, nil } return chosen, nil } +func joinOutputFormats(formats []outputFormat) string { + parts := make([]string, len(formats)) + for i, f := range formats { + parts[i] = string(f) + } + return strings.Join(parts, ", ") +} + func isKnownOutputFormat(f outputFormat) bool { switch f { case outputText, outputJSON, outputCSV: diff --git a/experimental/postgres/cmd/output_test.go b/experimental/postgres/cmd/output_test.go index 79289a43e56..4598085805a 100644 --- a/experimental/postgres/cmd/output_test.go +++ b/experimental/postgres/cmd/output_test.go @@ -23,11 +23,25 @@ func TestResolveOutputFormat_TextOnPipeFallsBackToJSON(t *testing.T) { assert.Equal(t, outputJSON, got) } -func TestResolveOutputFormat_ExplicitTextOnPipeAlsoFallsBackToJSON(t *testing.T) { +func TestResolveOutputFormat_ExplicitTextOnPipeIsHonoured(t *testing.T) { ctx := t.Context() got, err := resolveOutputFormat(ctx, "text", true, false) require.NoError(t, err) - assert.Equal(t, outputJSON, got) + assert.Equal(t, 
outputText, got) +} + +func TestResolveOutputFormat_EnvVarTextOnPipeIsHonoured(t *testing.T) { + ctx := env.Set(t.Context(), envOutputFormat, "text") + got, err := resolveOutputFormat(ctx, "text", false, false) + require.NoError(t, err) + assert.Equal(t, outputText, got) +} + +func TestResolveOutputFormat_EnvVarCSVOnPipe(t *testing.T) { + ctx := env.Set(t.Context(), envOutputFormat, "csv") + got, err := resolveOutputFormat(ctx, "text", false, false) + require.NoError(t, err) + assert.Equal(t, outputCSV, got) } func TestResolveOutputFormat_ExplicitJSON(t *testing.T) { diff --git a/experimental/postgres/cmd/query.go b/experimental/postgres/cmd/query.go index c3078f24d82..2b4f12694f9 100644 --- a/experimental/postgres/cmd/query.go +++ b/experimental/postgres/cmd/query.go @@ -120,6 +120,11 @@ func runQuery(ctx context.Context, cmd *cobra.Command, sql string, f queryFlags) return err } + // SupportsColor is the public TTY-ish signal libs/cmdio exposes today; it + // also folds in NO_COLOR / TERM=dumb, which strictly speaking are colour + // preferences rather than TTY signals. Users who hit that edge case can + // pass --output text explicitly; that path is honoured (see + // resolveOutputFormat). Mirrors the aitools query command. 
stdoutTTY := cmdio.SupportsColor(ctx, cmd.OutOrStdout()) format, err := resolveOutputFormat(ctx, f.outputFormat, f.outputFormatSet, stdoutTTY) if err != nil { diff --git a/experimental/postgres/cmd/render_csv_test.go b/experimental/postgres/cmd/render_csv_test.go index 35d1c3596f1..5a3ee277e2c 100644 --- a/experimental/postgres/cmd/render_csv_test.go +++ b/experimental/postgres/cmd/render_csv_test.go @@ -47,3 +47,23 @@ func TestCSVSink_QuotesFieldsWithCommas(t *testing.T) { require.NoError(t, s.End("SELECT 1")) assert.Contains(t, stdout.String(), `"a,b"`) } + +func TestCSVSink_EmbeddedNewlineAndQuote(t *testing.T) { + var stdout, stderr bytes.Buffer + s := newCSVSink(&stdout, &stderr) + require.NoError(t, s.Begin(fields("note"))) + require.NoError(t, s.Row([]any{"line1\nline2 \"quoted\""})) + require.NoError(t, s.End("SELECT 1")) + assert.Contains(t, stdout.String(), "\"line1\nline2 \"\"quoted\"\"\"") +} + +func TestCSVSink_OnError_NoOp(t *testing.T) { + var stdout, stderr bytes.Buffer + s := newCSVSink(&stdout, &stderr) + require.NoError(t, s.Begin(fields("id"))) + require.NoError(t, s.Row([]any{int64(1)})) + s.OnError(assert.AnError) + // CSV has no open structure to close; partial row count plus header is + // what the consumer sees. The sink must not panic on OnError. + assert.Contains(t, stdout.String(), "id\n1\n") +} diff --git a/experimental/postgres/cmd/render_json.go b/experimental/postgres/cmd/render_json.go index 1d9a53a8e8d..dc713b6b786 100644 --- a/experimental/postgres/cmd/render_json.go +++ b/experimental/postgres/cmd/render_json.go @@ -82,53 +82,84 @@ func (s *jsonSink) Row(values []any) error { } } - // Build the row object as a *map* of column to converted value, then let - // json.Marshal handle the encoding. We don't preserve key insertion order - // (json package sorts map keys), which is fine for machine consumers; the - // columns slice is the canonical order. - // - // Using ordered emission would require a manual writer. 
Worth the cost - // only if a downstream consumer needs schema-positional output, which - // none do today. - obj := make(map[string]any, len(s.columns)) + // Emit keys in column order. json.Marshal on a map sorts keys + // alphabetically; SELECT order is what consumers expect, so we write + // `{`, walk columns, encode key:value pairs ourselves, then `}`. + if _, err := io.WriteString(s.out, "{"); err != nil { + return err + } for i, name := range s.columns { - obj[name] = jsonValueWithOID(values[i], s.oids[i]) + if i > 0 { + if _, err := io.WriteString(s.out, ","); err != nil { + return err + } + } + key, err := marshalJSON(name) + if err != nil { + return fmt.Errorf("encode column name %q: %w", name, err) + } + if _, err := s.out.Write(key); err != nil { + return err + } + if _, err := io.WriteString(s.out, ":"); err != nil { + return err + } + val, err := marshalJSON(jsonValueWithOID(values[i], s.oids[i])) + if err != nil { + return fmt.Errorf("encode value for %q: %w", name, err) + } + if _, err := s.out.Write(val); err != nil { + return err + } } + if _, err := io.WriteString(s.out, "}"); err != nil { + return err + } + s.rowsWritten++ + return nil +} +// marshalJSON encodes v with HTML escaping disabled (so jsonb values like +// {"url":""} round-trip without `<` rewrites). encoding/json's Encoder +// is the only path that exposes SetEscapeHTML, so we route through it and +// strip the trailing newline it always appends. +func marshalJSON(v any) ([]byte, error) { var buf bytes.Buffer enc := json.NewEncoder(&buf) enc.SetEscapeHTML(false) - if err := enc.Encode(obj); err != nil { - return fmt.Errorf("encode row: %w", err) - } - // json.Encoder always writes a trailing newline; trim it so our outer - // formatting controls the layout. 
- out := bytes.TrimRight(buf.Bytes(), "\n") - if _, err := s.out.Write(out); err != nil { - return err + if err := enc.Encode(v); err != nil { + return nil, err } - s.rowsWritten++ - return nil + return bytes.TrimRight(buf.Bytes(), "\n"), nil } func (s *jsonSink) End(commandTag string) error { if s.hasOpenedArray { + if s.rowsWritten == 0 { + // Empty result: collapse to "[]\n" rather than "[\n\n]\n". + _, err := io.WriteString(s.out, "]\n") + return err + } _, err := io.WriteString(s.out, "\n]\n") return err } - // Command-only path: emit a single object. - obj := map[string]any{"command": commandTagVerb(commandTag)} - if rows, ok := commandTagRowCount(commandTag); ok { - obj["rows_affected"] = rows + // Command-only path: emit a single ordered object. + if _, err := io.WriteString(s.out, `{"command":`); err != nil { + return err } - - var buf bytes.Buffer - enc := json.NewEncoder(&buf) - enc.SetEscapeHTML(false) - if err := enc.Encode(obj); err != nil { - return fmt.Errorf("encode command tag: %w", err) + verbBytes, err := marshalJSON(commandTagVerb(commandTag)) + if err != nil { + return fmt.Errorf("encode command tag verb: %w", err) + } + if _, err := s.out.Write(verbBytes); err != nil { + return err } - _, err := s.out.Write(buf.Bytes()) + if rows, ok := commandTagRowCount(commandTag); ok { + if _, err := fmt.Fprintf(s.out, `,"rows_affected":%d`, rows); err != nil { + return err + } + } + _, err = io.WriteString(s.out, "}\n") return err } @@ -141,6 +172,10 @@ func (s *jsonSink) OnError(err error) { } // Best-effort; if this Write fails the stream is already corrupted // and there is nothing more we can do. 
+ if s.rowsWritten == 0 { + _, _ = io.WriteString(s.out, "]\n") + return + } _, _ = io.WriteString(s.out, "\n]\n") } diff --git a/experimental/postgres/cmd/render_json_test.go b/experimental/postgres/cmd/render_json_test.go index a2617b27bc6..26aa79cc832 100644 --- a/experimental/postgres/cmd/render_json_test.go +++ b/experimental/postgres/cmd/render_json_test.go @@ -42,7 +42,16 @@ func TestJSONSink_EmptyRowsProducing(t *testing.T) { s := newJSONSink(&stdout, &stderr) require.NoError(t, s.Begin(fieldsWithOIDs([]string{"id"}, []uint32{pgtype.Int8OID}))) require.NoError(t, s.End("SELECT 0")) - assert.Equal(t, "[\n\n]\n", stdout.String()) + assert.Equal(t, "[\n]\n", stdout.String()) +} + +func TestJSONSink_KeysInColumnOrder(t *testing.T) { + var stdout, stderr bytes.Buffer + s := newJSONSink(&stdout, &stderr) + require.NoError(t, s.Begin(fieldsWithOIDs([]string{"b", "a"}, []uint32{pgtype.Int8OID, pgtype.Int8OID}))) + require.NoError(t, s.Row([]any{int64(1), int64(2)})) + require.NoError(t, s.End("SELECT 1")) + assert.Equal(t, "[\n"+`{"b":1,"a":2}`+"\n]\n", stdout.String()) } func TestJSONSink_CommandOnly_WithRowCount(t *testing.T) { @@ -104,6 +113,10 @@ func TestCommandTagParse(t *testing.T) { {"UPDATE 3", "UPDATE", 3, true}, {"DELETE 0", "DELETE", 0, true}, {"SELECT 100", "SELECT", 100, true}, + {"MERGE 5", "MERGE", 5, true}, + {"COPY 1000", "COPY", 1000, true}, + {"FETCH 7", "FETCH", 7, true}, + {"MOVE 3", "MOVE", 3, true}, {"CREATE DATABASE", "CREATE", 0, false}, {"SET", "SET", 0, false}, } diff --git a/experimental/postgres/cmd/render_test.go b/experimental/postgres/cmd/render_test.go index 06190323e43..d451febb191 100644 --- a/experimental/postgres/cmd/render_test.go +++ b/experimental/postgres/cmd/render_test.go @@ -71,3 +71,15 @@ func TestTextSink_NULLRendersAsNULL(t *testing.T) { require.NoError(t, s.End("SELECT 1")) assert.Contains(t, buf.String(), "NULL") } + +func TestTextSink_OnError_NoOp(t *testing.T) { + var buf bytes.Buffer + s := newTextSink(&buf) + 
require.NoError(t, s.Begin(fields("id"))) + require.NoError(t, s.Row([]any{int64(1)})) + s.OnError(assert.AnError) + // Text sink has no open structure to close. OnError must not panic and + // must not emit a partial table; the partial result lives in s.rows but + // is never flushed. + assert.Empty(t, buf.String()) +} diff --git a/experimental/postgres/cmd/targeting.go b/experimental/postgres/cmd/targeting.go index 78e230adaac..4c46bee02dc 100644 --- a/experimental/postgres/cmd/targeting.go +++ b/experimental/postgres/cmd/targeting.go @@ -112,9 +112,6 @@ func resolveProvisioned(ctx context.Context, w *databricks.WorkspaceClient, inst if instance.State != database.DatabaseInstanceStateAvailable { return nil, fmt.Errorf("database instance %q is not ready for accepting connections (state: %s)", instance.Name, instance.State) } - if instance.ReadWriteDns == "" { - return nil, fmt.Errorf("database instance %q has no read/write DNS yet", instance.Name) - } user, err := w.CurrentUser.Me(ctx) if err != nil { diff --git a/experimental/postgres/cmd/value.go b/experimental/postgres/cmd/value.go index 3049b44a82a..21beedd04f0 100644 --- a/experimental/postgres/cmd/value.go +++ b/experimental/postgres/cmd/value.go @@ -25,6 +25,11 @@ const safeIntegerBound = 1<<53 - 1 // NULL renders as the literal "NULL" so it lines up with the column rather // than appearing as an empty cell. CSV converts that back to an empty field // at write time (matches `psql --csv`). +// +// Floats are rendered with Postgres' canonical wording for the IEEE specials +// ("NaN" / "Infinity" / "-Infinity"), not Go's `fmt.Sprintf("%v")` defaults +// (which would emit "+Inf"/"-Inf"). This keeps text/CSV consistent with what +// `psql` would print. 
func textValue(v any) string { if v == nil { return "NULL" @@ -40,6 +45,10 @@ func textValue(v any) string { return "t" } return "f" + case float64: + return floatTextForm(x) + case float32: + return floatTextForm(float64(x)) case time.Time: return x.Format(time.RFC3339Nano) case fmt.Stringer: @@ -49,6 +58,20 @@ func textValue(v any) string { return fmt.Sprintf("%v", v) } +// floatTextForm formats a float using Postgres' canonical text wording for +// the IEEE specials and Go's shortest-round-trip 'g' format otherwise. +func floatTextForm(f float64) string { + switch { + case math.IsNaN(f): + return "NaN" + case math.IsInf(f, 1): + return "Infinity" + case math.IsInf(f, -1): + return "-Infinity" + } + return strconv.FormatFloat(f, 'g', -1, 64) +} + // jsonValue renders a Go value (as decoded by pgx) to a JSON-encodable // representation. Returns a value the standard json.Marshal can handle // directly and the JSON shape we want; never returns Go values that would @@ -73,18 +96,16 @@ func jsonValue(v any) any { return x case string: return x - case int8, int16, int32, int, uint8, uint16, uint32: + case int16, int32: return x case int64: + // pgx decodes Postgres int8 to Go int64. Outside the IEEE-754 safe + // integer range we render as a string so JavaScript-style consumers + // don't silently lose precision. 
if x > safeIntegerBound || x < -safeIntegerBound { return strconv.FormatInt(x, 10) } return x - case uint64: - if x > safeIntegerBound { - return strconv.FormatUint(x, 10) - } - return x case float32: return jsonFloat(float64(x)) case float64: diff --git a/experimental/postgres/cmd/value_test.go b/experimental/postgres/cmd/value_test.go index 092fc6f7284..d52edae90bc 100644 --- a/experimental/postgres/cmd/value_test.go +++ b/experimental/postgres/cmd/value_test.go @@ -82,3 +82,14 @@ func TestTextValue_Time(t *testing.T) { tm := time.Date(2024, 1, 15, 10, 30, 0, 0, time.UTC) assert.Equal(t, "2024-01-15T10:30:00Z", textValue(tm)) } + +func TestTextValue_FloatSpecials(t *testing.T) { + assert.Equal(t, "NaN", textValue(math.NaN())) + assert.Equal(t, "Infinity", textValue(math.Inf(1))) + assert.Equal(t, "-Infinity", textValue(math.Inf(-1))) +} + +func TestTextValue_FiniteFloat(t *testing.T) { + assert.Equal(t, "3.14", textValue(float64(3.14))) + assert.Equal(t, "0", textValue(float64(0))) +} From 287dd62aab56a80597288b85c30f9068e09b1a11 Mon Sep 17 00:00:00 2001 From: simon Date: Thu, 30 Apr 2026 11:21:18 +0200 Subject: [PATCH 07/15] Address PR 2 review feedback round 2 - Doc fix: textSink.OnError doc said "prints whatever rows have been collected" but text mode buffers everything to End. New doc states the buffered partial result is discarded on iteration error. - Doc fix: textValue float comment overstated psql parity. Tightened to acknowledge Go's 'g' format may differ from psql in exponential vs fixed notation around the boundary. - Tighten OnError contract: explicitly states it is NOT called when Begin itself errors. - Replace switch-by-format in isKnownOutputFormat with slices.Contains on allOutputFormats so adding a fourth format is one edit. - Tighten command-only JSON tests from JSONEq (key-order ignored) to byte-equal so a future field addition is caught. 
- Tighten JSONSink_OnError tests to byte-equal; add the Begin-but-no-rows case which exercises the rowsWritten==0 branch. Co-authored-by: Isaac --- experimental/postgres/cmd/execute.go | 8 +++++--- experimental/postgres/cmd/output.go | 7 ++----- experimental/postgres/cmd/render.go | 7 ++++--- experimental/postgres/cmd/render_json_test.go | 17 ++++++++++++----- experimental/postgres/cmd/value.go | 10 ++++++---- 5 files changed, 29 insertions(+), 20 deletions(-) diff --git a/experimental/postgres/cmd/execute.go b/experimental/postgres/cmd/execute.go index 51f70f836a9..8d0b896031c 100644 --- a/experimental/postgres/cmd/execute.go +++ b/experimental/postgres/cmd/execute.go @@ -21,9 +21,11 @@ type rowSink interface { Row(values []any) error // End is called once on successful completion. End(commandTag string) error - // OnError is called if iteration errors after Begin returned. The sink - // is expected to flush any in-progress output structures so stdout - // remains well-formed. The caller still surfaces err to its caller. + // OnError is called if iteration errors after Begin returned successfully. + // The sink is expected to flush any in-progress output structures so + // stdout remains well-formed. The caller still surfaces err to its caller. + // If Begin itself errors, OnError is NOT called: sinks must not write any + // framing before Begin succeeds. 
OnError(err error) } diff --git a/experimental/postgres/cmd/output.go b/experimental/postgres/cmd/output.go index 9976cd0d548..e5b59fec96f 100644 --- a/experimental/postgres/cmd/output.go +++ b/experimental/postgres/cmd/output.go @@ -3,6 +3,7 @@ package postgrescmd import ( "context" "fmt" + "slices" "strings" "github.com/databricks/cli/libs/env" @@ -74,9 +75,5 @@ func joinOutputFormats(formats []outputFormat) string { } func isKnownOutputFormat(f outputFormat) bool { - switch f { - case outputText, outputJSON, outputCSV: - return true - } - return false + return slices.Contains(allOutputFormats, f) } diff --git a/experimental/postgres/cmd/render.go b/experimental/postgres/cmd/render.go index bc45c89e0d0..2e1daf6376b 100644 --- a/experimental/postgres/cmd/render.go +++ b/experimental/postgres/cmd/render.go @@ -62,9 +62,10 @@ func (s *textSink) End(commandTag string) error { return err } -// OnError for text sinks is a no-op: text output prints whatever rows have -// already been collected, with no open structure to close. The caller -// surfaces the error separately (cobra's default error rendering). +// OnError for text sinks is a no-op. Text mode buffers all rows for +// tabwriter alignment, so a partial result is discarded on iteration error; +// only cobra's error message reaches the user. The streaming sinks (json, +// csv) handle the partial-result case themselves. 
func (s *textSink) OnError(err error) {} func headerSeparator(cols []string) []string { diff --git a/experimental/postgres/cmd/render_json_test.go b/experimental/postgres/cmd/render_json_test.go index 26aa79cc832..4e6f474d257 100644 --- a/experimental/postgres/cmd/render_json_test.go +++ b/experimental/postgres/cmd/render_json_test.go @@ -59,7 +59,9 @@ func TestJSONSink_CommandOnly_WithRowCount(t *testing.T) { s := newJSONSink(&stdout, &stderr) require.NoError(t, s.Begin(nil)) require.NoError(t, s.End("INSERT 0 5")) - assert.JSONEq(t, `{"command":"INSERT","rows_affected":5}`, stdout.String()) + // Byte-equal: pins the field order so adding a future field (e.g. last_oid) + // must update the test rather than silently drift. + assert.Equal(t, `{"command":"INSERT","rows_affected":5}`+"\n", stdout.String()) } func TestJSONSink_CommandOnly_NoRowCount(t *testing.T) { @@ -67,7 +69,7 @@ func TestJSONSink_CommandOnly_NoRowCount(t *testing.T) { s := newJSONSink(&stdout, &stderr) require.NoError(t, s.Begin(nil)) require.NoError(t, s.End("CREATE DATABASE")) - assert.JSONEq(t, `{"command":"CREATE"}`, stdout.String()) + assert.Equal(t, `{"command":"CREATE"}`+"\n", stdout.String()) } func TestJSONSink_DuplicateColumns(t *testing.T) { @@ -89,10 +91,15 @@ func TestJSONSink_OnError_AfterRows(t *testing.T) { require.NoError(t, s.Begin(fieldsWithOIDs([]string{"id"}, []uint32{pgtype.Int8OID}))) require.NoError(t, s.Row([]any{int64(1)})) s.OnError(assert.AnError) + assert.Equal(t, "[\n"+`{"id":1}`+"\n]\n", stdout.String()) +} - assert.Contains(t, stdout.String(), "[\n") - assert.Contains(t, stdout.String(), `{"id":1}`) - assert.Contains(t, stdout.String(), "\n]\n") +func TestJSONSink_OnError_AfterBeginNoRows(t *testing.T) { + var stdout, stderr bytes.Buffer + s := newJSONSink(&stdout, &stderr) + require.NoError(t, s.Begin(fieldsWithOIDs([]string{"id"}, []uint32{pgtype.Int8OID}))) + s.OnError(assert.AnError) + assert.Equal(t, "[\n]\n", stdout.String()) } func 
TestJSONSink_OnError_BeforeBegin(t *testing.T) { diff --git a/experimental/postgres/cmd/value.go b/experimental/postgres/cmd/value.go index 21beedd04f0..1578c7efecf 100644 --- a/experimental/postgres/cmd/value.go +++ b/experimental/postgres/cmd/value.go @@ -26,10 +26,12 @@ const safeIntegerBound = 1<<53 - 1 // than appearing as an empty cell. CSV converts that back to an empty field // at write time (matches `psql --csv`). // -// Floats are rendered with Postgres' canonical wording for the IEEE specials -// ("NaN" / "Infinity" / "-Infinity"), not Go's `fmt.Sprintf("%v")` defaults -// (which would emit "+Inf"/"-Inf"). This keeps text/CSV consistent with what -// `psql` would print. +// IEEE special floats use Postgres' canonical wording ("NaN" / "Infinity" +// / "-Infinity"), not Go's `fmt.Sprintf("%v")` defaults (which would emit +// "+Inf"/"-Inf"). Finite floats use Go's shortest-round-trip 'g' format, +// which may differ from psql in exponential vs fixed notation around the +// 'g' boundary (e.g. Go prints `1e+10`; psql prints `10000000000`). Full +// psql parity is not worth a custom formatter. func textValue(v any) string { if v == nil { return "NULL" From 5a27bf0034e479666ae6a7aa28105721a0c2fe6a Mon Sep 17 00:00:00 2001 From: simon Date: Thu, 30 Apr 2026 12:58:53 +0200 Subject: [PATCH 08/15] Cut blast radius: keep target package + acceptance tests inside experimental Two reductions for an experimental command, per a maintainer comment: - Move libs/lakebase/target into experimental/postgres/cmd/internal/target so the experiment is self-contained. cmd/psql is no longer touched (no refactor, no behavior change). When/if this command graduates from experimental, that's the right time to extract the shared package. - Drop acceptance tests for the new command. Aitools (the other experimental command) has none either; locking down user-visible wording for an experimental surface is overinvestment. 
Unit tests still cover argument validation, retry classification, and rendering. Acceptance tests can be added when the command graduates. Net diff on cmd/psql is now zero. The experiment lives entirely under experimental/postgres/cmd/. Co-authored-by: Isaac --- .../query/ambiguous-targeting/out.test.toml | 8 -- .../query/ambiguous-targeting/output.txt | 18 --- .../postgres/query/ambiguous-targeting/script | 8 -- .../query/ambiguous-targeting/test.toml | 62 -------- .../query/argument-errors/out.test.toml | 8 -- .../postgres/query/argument-errors/output.txt | 48 ------- .../postgres/query/argument-errors/script | 35 ----- .../postgres/query/argument-errors/test.toml | 3 - .../cmd/psql/argument-errors/output.txt | 4 - acceptance/cmd/psql/argument-errors/script | 3 - acceptance/cmd/psql/postgres/output.txt | 2 +- cmd/psql/psql.go | 61 ++++++-- cmd/psql/psql_autoscaling.go | 132 ++++++++++++------ cmd/psql/psql_provisioned.go | 47 +++++-- cmd/psql/psql_test.go | 83 +++++++++++ .../cmd/internal}/target/autoscaling.go | 0 .../cmd/internal}/target/provisioned.go | 0 .../postgres/cmd/internal}/target/target.go | 0 .../cmd/internal}/target/target_test.go | 0 experimental/postgres/cmd/targeting.go | 2 +- 20 files changed, 257 insertions(+), 267 deletions(-) delete mode 100644 acceptance/cmd/experimental/postgres/query/ambiguous-targeting/out.test.toml delete mode 100644 acceptance/cmd/experimental/postgres/query/ambiguous-targeting/output.txt delete mode 100644 acceptance/cmd/experimental/postgres/query/ambiguous-targeting/script delete mode 100644 acceptance/cmd/experimental/postgres/query/ambiguous-targeting/test.toml delete mode 100644 acceptance/cmd/experimental/postgres/query/argument-errors/out.test.toml delete mode 100644 acceptance/cmd/experimental/postgres/query/argument-errors/output.txt delete mode 100644 acceptance/cmd/experimental/postgres/query/argument-errors/script delete mode 100644 acceptance/cmd/experimental/postgres/query/argument-errors/test.toml create 
mode 100644 cmd/psql/psql_test.go rename {libs/lakebase => experimental/postgres/cmd/internal}/target/autoscaling.go (100%) rename {libs/lakebase => experimental/postgres/cmd/internal}/target/provisioned.go (100%) rename {libs/lakebase => experimental/postgres/cmd/internal}/target/target.go (100%) rename {libs/lakebase => experimental/postgres/cmd/internal}/target/target_test.go (100%) diff --git a/acceptance/cmd/experimental/postgres/query/ambiguous-targeting/out.test.toml b/acceptance/cmd/experimental/postgres/query/ambiguous-targeting/out.test.toml deleted file mode 100644 index 40bb0d10471..00000000000 --- a/acceptance/cmd/experimental/postgres/query/ambiguous-targeting/out.test.toml +++ /dev/null @@ -1,8 +0,0 @@ -Local = true -Cloud = false - -[GOOS] - windows = false - -[EnvMatrix] - DATABRICKS_BUNDLE_ENGINE = ["terraform", "direct"] diff --git a/acceptance/cmd/experimental/postgres/query/ambiguous-targeting/output.txt b/acceptance/cmd/experimental/postgres/query/ambiguous-targeting/output.txt deleted file mode 100644 index e95a7b3613d..00000000000 --- a/acceptance/cmd/experimental/postgres/query/ambiguous-targeting/output.txt +++ /dev/null @@ -1,18 +0,0 @@ - -=== Project with multiple branches and no --branch should error with choices: ->>> musterr [CLI] experimental postgres query --project foo SELECT 1 -Error: multiple branches found in projects/foo; specify --branch: - - main - - dev - -=== Project with multiple endpoints in only branch should error with choices: ->>> musterr [CLI] experimental postgres query --project bar SELECT 1 -Error: multiple endpoints found in projects/bar/branches/only; specify --endpoint: - - read-write - - read-only - -=== Partial path with multiple branches should error with choices: ->>> musterr [CLI] experimental postgres query --target projects/foo SELECT 1 -Error: multiple branches found in projects/foo; specify --branch: - - main - - dev diff --git a/acceptance/cmd/experimental/postgres/query/ambiguous-targeting/script 
b/acceptance/cmd/experimental/postgres/query/ambiguous-targeting/script deleted file mode 100644 index 6143fd96f02..00000000000 --- a/acceptance/cmd/experimental/postgres/query/ambiguous-targeting/script +++ /dev/null @@ -1,8 +0,0 @@ -title "Project with multiple branches and no --branch should error with choices:" -trace musterr $CLI experimental postgres query --project foo "SELECT 1" - -title "Project with multiple endpoints in only branch should error with choices:" -trace musterr $CLI experimental postgres query --project bar "SELECT 1" - -title "Partial path with multiple branches should error with choices:" -trace musterr $CLI experimental postgres query --target projects/foo "SELECT 1" diff --git a/acceptance/cmd/experimental/postgres/query/ambiguous-targeting/test.toml b/acceptance/cmd/experimental/postgres/query/ambiguous-targeting/test.toml deleted file mode 100644 index 2a61e7e8e25..00000000000 --- a/acceptance/cmd/experimental/postgres/query/ambiguous-targeting/test.toml +++ /dev/null @@ -1,62 +0,0 @@ -GOOS.windows = false - -[[Server]] -Pattern = "GET /api/2.0/postgres/projects" -Response.Body = ''' -{ - "projects": [ - {"name": "projects/alpha", "status": {"display_name": "Alpha"}}, - {"name": "projects/beta", "status": {"display_name": "Beta"}} - ] -} -''' - -[[Server]] -Pattern = "GET /api/2.0/postgres/projects/foo" -Response.Body = ''' -{ - "name": "projects/foo", - "status": {"display_name": "Foo Project"} -} -''' - -[[Server]] -Pattern = "GET /api/2.0/postgres/projects/foo/branches" -Response.Body = ''' -{ - "branches": [ - {"name": "projects/foo/branches/main"}, - {"name": "projects/foo/branches/dev"} - ] -} -''' - -[[Server]] -Pattern = "GET /api/2.0/postgres/projects/bar" -Response.Body = ''' -{ - "name": "projects/bar", - "status": {"display_name": "Bar Project"} -} -''' - -[[Server]] -Pattern = "GET /api/2.0/postgres/projects/bar/branches" -Response.Body = ''' -{ - "branches": [ - {"name": "projects/bar/branches/only"} - ] -} -''' - 
-[[Server]] -Pattern = "GET /api/2.0/postgres/projects/bar/branches/only/endpoints" -Response.Body = ''' -{ - "endpoints": [ - {"name": "projects/bar/branches/only/endpoints/read-write"}, - {"name": "projects/bar/branches/only/endpoints/read-only"} - ] -} -''' diff --git a/acceptance/cmd/experimental/postgres/query/argument-errors/out.test.toml b/acceptance/cmd/experimental/postgres/query/argument-errors/out.test.toml deleted file mode 100644 index 40bb0d10471..00000000000 --- a/acceptance/cmd/experimental/postgres/query/argument-errors/out.test.toml +++ /dev/null @@ -1,8 +0,0 @@ -Local = true -Cloud = false - -[GOOS] - windows = false - -[EnvMatrix] - DATABRICKS_BUNDLE_ENGINE = ["terraform", "direct"] diff --git a/acceptance/cmd/experimental/postgres/query/argument-errors/output.txt b/acceptance/cmd/experimental/postgres/query/argument-errors/output.txt deleted file mode 100644 index c071466a1e3..00000000000 --- a/acceptance/cmd/experimental/postgres/query/argument-errors/output.txt +++ /dev/null @@ -1,48 +0,0 @@ - -=== No SQL argument should error: ->>> musterr [CLI] experimental postgres query --target projects/foo -Error: accepts 1 arg(s), received 0 - -=== Empty SQL should error: ->>> musterr [CLI] experimental postgres query --target projects/foo -Error: no SQL provided - -=== Neither targeting form should error: ->>> musterr [CLI] experimental postgres query SELECT 1 -Error: must specify --target or --project - -=== Both --target and --project should error: ->>> musterr [CLI] experimental postgres query --target projects/foo --project foo SELECT 1 -Error: if any flags in the group [target project] are set none of the others can be; [project target] were all set - -=== Both --target and --branch should error: ->>> musterr [CLI] experimental postgres query --target projects/foo --branch main SELECT 1 -Error: if any flags in the group [target branch] are set none of the others can be; [branch target] were all set - -=== Branch without project should error: ->>> 
musterr [CLI] experimental postgres query --branch main SELECT 1 -Error: --project is required when using --branch or --endpoint - -=== Endpoint without project should error: ->>> musterr [CLI] experimental postgres query --endpoint primary SELECT 1 -Error: --project is required when using --branch or --endpoint - -=== Endpoint without branch should error: ->>> musterr [CLI] experimental postgres query --project foo --endpoint primary SELECT 1 -Error: --branch is required when using --endpoint - -=== Max-retries 0 should error: ->>> musterr [CLI] experimental postgres query --project foo --branch main --max-retries 0 SELECT 1 -Error: --max-retries must be at least 1; got 0 - -=== Provisioned-shaped target should error pointing at psql: ->>> musterr [CLI] experimental postgres query --target my-instance SELECT 1 -Error: provisioned instances are not yet supported by this experimental command; use 'databricks psql ' for now - -=== Malformed autoscaling path should error: ->>> musterr [CLI] experimental postgres query --target projects/ SELECT 1 -Error: invalid resource path: missing project ID - -=== Trailing components after endpoint should error: ->>> musterr [CLI] experimental postgres query --target projects/foo/branches/bar/endpoints/baz/extra SELECT 1 -Error: invalid resource path: trailing components after endpoint: projects/foo/branches/bar/endpoints/baz/extra diff --git a/acceptance/cmd/experimental/postgres/query/argument-errors/script b/acceptance/cmd/experimental/postgres/query/argument-errors/script deleted file mode 100644 index 8d64bf307ed..00000000000 --- a/acceptance/cmd/experimental/postgres/query/argument-errors/script +++ /dev/null @@ -1,35 +0,0 @@ -title "No SQL argument should error:" -trace musterr $CLI experimental postgres query --target projects/foo - -title "Empty SQL should error:" -trace musterr $CLI experimental postgres query --target projects/foo " " - -title "Neither targeting form should error:" -trace musterr $CLI experimental 
postgres query "SELECT 1" - -title "Both --target and --project should error:" -trace musterr $CLI experimental postgres query --target projects/foo --project foo "SELECT 1" - -title "Both --target and --branch should error:" -trace musterr $CLI experimental postgres query --target projects/foo --branch main "SELECT 1" - -title "Branch without project should error:" -trace musterr $CLI experimental postgres query --branch main "SELECT 1" - -title "Endpoint without project should error:" -trace musterr $CLI experimental postgres query --endpoint primary "SELECT 1" - -title "Endpoint without branch should error:" -trace musterr $CLI experimental postgres query --project foo --endpoint primary "SELECT 1" - -title "Max-retries 0 should error:" -trace musterr $CLI experimental postgres query --project foo --branch main --max-retries 0 "SELECT 1" - -title "Provisioned-shaped target should error pointing at psql:" -trace musterr $CLI experimental postgres query --target my-instance "SELECT 1" - -title "Malformed autoscaling path should error:" -trace musterr $CLI experimental postgres query --target projects/ "SELECT 1" - -title "Trailing components after endpoint should error:" -trace musterr $CLI experimental postgres query --target projects/foo/branches/bar/endpoints/baz/extra "SELECT 1" diff --git a/acceptance/cmd/experimental/postgres/query/argument-errors/test.toml b/acceptance/cmd/experimental/postgres/query/argument-errors/test.toml deleted file mode 100644 index 3371f08de12..00000000000 --- a/acceptance/cmd/experimental/postgres/query/argument-errors/test.toml +++ /dev/null @@ -1,3 +0,0 @@ -# Argument validation runs before any SDK call. No mocked HTTP responses are -# needed; CLI either errors at flag-parse time or at our own validate function. 
-GOOS.windows = false diff --git a/acceptance/cmd/psql/argument-errors/output.txt b/acceptance/cmd/psql/argument-errors/output.txt index cbf6c093b21..35da5961dec 100644 --- a/acceptance/cmd/psql/argument-errors/output.txt +++ b/acceptance/cmd/psql/argument-errors/output.txt @@ -59,10 +59,6 @@ Error: invalid resource path: missing branch ID >>> musterr [CLI] psql projects/my-project/branches/main/endpoints/ Error: invalid resource path: missing endpoint ID -=== Trailing components after endpoint should error: ->>> musterr [CLI] psql projects/my-project/branches/main/endpoints/primary/extra -Error: invalid resource path: trailing components after endpoint: projects/my-project/branches/main/endpoints/primary/extra - === Provisioned flag with --project should error: >>> musterr [CLI] psql --provisioned --project foo Error: cannot use --project, --branch, or --endpoint flags with --provisioned diff --git a/acceptance/cmd/psql/argument-errors/script b/acceptance/cmd/psql/argument-errors/script index 7db1cdbd271..7806efb0744 100644 --- a/acceptance/cmd/psql/argument-errors/script +++ b/acceptance/cmd/psql/argument-errors/script @@ -38,9 +38,6 @@ trace musterr $CLI psql projects/my-project/branches/ title "Invalid path with missing endpoint ID should error:" trace musterr $CLI psql projects/my-project/branches/main/endpoints/ -title "Trailing components after endpoint should error:" -trace musterr $CLI psql projects/my-project/branches/main/endpoints/primary/extra - title "Provisioned flag with --project should error:" trace musterr $CLI psql --provisioned --project foo diff --git a/acceptance/cmd/psql/postgres/output.txt b/acceptance/cmd/psql/postgres/output.txt index 8df91c6321c..5269553a0ce 100644 --- a/acceptance/cmd/psql/postgres/output.txt +++ b/acceptance/cmd/psql/postgres/output.txt @@ -50,7 +50,7 @@ PGSSLMODE=require Project: Init Project Branch: main Endpoint: init-ep -Error: endpoint is not ready for accepting connections (state: INIT) +Error: endpoint is not 
ready for accepting connections === Branch flag without project should fail: >>> musterr [CLI] psql --branch some-branch diff --git a/cmd/psql/psql.go b/cmd/psql/psql.go index 9be7fb5c5df..e7f3a65f8b3 100644 --- a/cmd/psql/psql.go +++ b/cmd/psql/psql.go @@ -11,7 +11,6 @@ import ( "github.com/databricks/cli/libs/cmdctx" "github.com/databricks/cli/libs/cmdgroup" "github.com/databricks/cli/libs/cmdio" - "github.com/databricks/cli/libs/lakebase/target" libpsql "github.com/databricks/cli/libs/psql" "github.com/databricks/databricks-sdk-go" "github.com/databricks/databricks-sdk-go/service/database" @@ -87,9 +86,9 @@ For more information, see: https://docs.databricks.com/aws/en/oltp/ if argsLenAtDash < 0 { argsLenAtDash = len(args) } - targetArg := "" + target := "" if argsLenAtDash == 1 { - targetArg = args[0] + target = args[0] } else if argsLenAtDash > 1 { return errors.New("expected at most one positional argument for target") } @@ -110,17 +109,16 @@ For more information, see: https://docs.databricks.com/aws/en/oltp/ } // Positional argument takes precedence - if targetArg != "" { - if target.IsAutoscalingPath(targetArg) { + if target != "" { + if strings.HasPrefix(target, "projects/") { if provisionedFlag { return errors.New("cannot use --provisioned flag with an autoscaling resource path") } - spec, err := target.ParseAutoscalingPath(targetArg) + projectID, branchID, endpointID, err := parseResourcePath(target) if err != nil { return err } - projectID, branchID, endpointID := spec.ProjectID, spec.BranchID, spec.EndpointID // Check for conflicts between path and flags if projectFlag != "" && projectFlag != projectID { @@ -151,7 +149,7 @@ For more information, see: https://docs.databricks.com/aws/en/oltp/ if autoscalingFlag { return errors.New("cannot use --autoscaling flag with a provisioned instance name") } - return connectProvisioned(ctx, targetArg, retryConfig, extraArgs) + return connectProvisioned(ctx, target, retryConfig, extraArgs) } // No positional argument 
- use flags only @@ -199,6 +197,45 @@ For more information, see: https://docs.databricks.com/aws/en/oltp/ return cmd } +// parseResourcePath extracts project, branch, and endpoint IDs from a resource path. +// Returns an error for malformed paths. +func parseResourcePath(input string) (project, branch, endpoint string, err error) { + parts := strings.Split(input, "/") + + // Must start with projects/{project_id} + if len(parts) < 2 || parts[0] != "projects" { + return "", "", "", fmt.Errorf("invalid resource path: %s", input) + } + if parts[1] == "" { + return "", "", "", errors.New("invalid resource path: missing project ID") + } + project = parts[1] + + // Optional: branches/{branch_id} + if len(parts) > 2 { + if len(parts) < 4 || parts[2] != "branches" { + return "", "", "", errors.New("invalid resource path: expected 'branches' after project") + } + if parts[3] == "" { + return "", "", "", errors.New("invalid resource path: missing branch ID") + } + branch = parts[3] + } + + // Optional: endpoints/{endpoint_id} + if len(parts) > 4 { + if len(parts) < 6 || parts[4] != "endpoints" { + return "", "", "", errors.New("invalid resource path: expected 'endpoints' after branch") + } + if parts[5] == "" { + return "", "", "", errors.New("invalid resource path: missing endpoint ID") + } + endpoint = parts[5] + } + + return project, branch, endpoint, nil +} + // listAllDatabases fetches all database instances and projects in parallel. // Errors are silently ignored; callers should check for empty results. 
func listAllDatabases(ctx context.Context, w *databricks.WorkspaceClient) ([]database.DatabaseInstance, []postgres.Project) { @@ -211,12 +248,12 @@ func listAllDatabases(ctx context.Context, w *databricks.WorkspaceClient) ([]dat projectsCh := make(chan result[postgres.Project], 1) go func() { - instances, err := target.ListProvisionedInstances(ctx, w) + instances, err := w.Database.ListDatabaseInstancesAll(ctx, database.ListDatabaseInstancesRequest{}) instancesCh <- result[database.DatabaseInstance]{instances, err} }() go func() { - projects, err := target.ListProjects(ctx, w) + projects, err := w.Postgres.ListProjectsAll(ctx, postgres.ListProjectsRequest{}) projectsCh <- result[postgres.Project]{projects, err} }() @@ -257,7 +294,7 @@ func showSelectionAndConnect(ctx context.Context, retryConfig libpsql.RetryConfi }) } for _, proj := range projects { - displayName := target.ProjectIDFromName(proj.Name) + displayName := extractIDFromName(proj.Name, "projects") if proj.Status != nil && proj.Status.DisplayName != "" { displayName = proj.Status.DisplayName } @@ -278,7 +315,7 @@ func showSelectionAndConnect(ctx context.Context, retryConfig libpsql.RetryConfi } if after, ok := strings.CutPrefix(selected, "autoscaling:"); ok { projectName := after - projectID := target.ProjectIDFromName(projectName) + projectID := extractIDFromName(projectName, "projects") return connectAutoscaling(ctx, projectID, "", "", retryConfig, extraArgs) } diff --git a/cmd/psql/psql_autoscaling.go b/cmd/psql/psql_autoscaling.go index 04ccd4bef6b..00c555e4c12 100644 --- a/cmd/psql/psql_autoscaling.go +++ b/cmd/psql/psql_autoscaling.go @@ -4,10 +4,10 @@ import ( "context" "errors" "fmt" + "strings" "github.com/databricks/cli/libs/cmdctx" "github.com/databricks/cli/libs/cmdio" - "github.com/databricks/cli/libs/lakebase/target" libpsql "github.com/databricks/cli/libs/psql" "github.com/databricks/databricks-sdk-go" "github.com/databricks/databricks-sdk-go/service/postgres" @@ -16,6 +16,18 @@ import ( 
// autoscalingDefaultDatabase is the default database for Lakebase Autoscaling projects. const autoscalingDefaultDatabase = "databricks_postgres" +// extractIDFromName extracts the ID component from a resource name. +// For example, extractIDFromName("projects/foo/branches/bar", "branches") returns "bar". +func extractIDFromName(name, component string) string { + parts := strings.Split(name, "/") + for i := range len(parts) - 1 { + if parts[i] == component { + return parts[i+1] + } + } + return name +} + // connectAutoscaling connects to a Lakebase Autoscaling endpoint. func connectAutoscaling(ctx context.Context, projectID, branchID, endpointID string, retryConfig libpsql.RetryConfig, extraArgs []string) error { w := cmdctx.WorkspaceClient(ctx) @@ -38,9 +50,11 @@ func connectAutoscaling(ctx context.Context, projectID, branchID, endpointID str return errors.New("endpoint host information is not available") } - token, err := target.AutoscalingCredential(ctx, w, endpoint.Name) + cred, err := w.Postgres.GenerateDatabaseCredential(ctx, postgres.GenerateDatabaseCredentialRequest{ + Endpoint: endpoint.Name, + }) if err != nil { - return err + return fmt.Errorf("failed to get database credentials: %w", err) } var endpointType string @@ -61,7 +75,7 @@ func connectAutoscaling(ctx context.Context, projectID, branchID, endpointID str case postgres.EndpointStatusStateIdle: suffix = " (idle, waking up)" default: - return fmt.Errorf("endpoint is not ready for accepting connections (state: %s)", state) + return errors.New("endpoint is not ready for accepting connections") } cmdio.LogString(ctx, fmt.Sprintf("Connecting to %s endpoint%s...", endpointType, suffix)) @@ -69,7 +83,7 @@ func connectAutoscaling(ctx context.Context, projectID, branchID, endpointID str return libpsql.Connect(ctx, libpsql.ConnectOptions{ Host: endpoint.Status.Hosts.Host, Username: user.UserName, - Password: token, + Password: cred.Token, DefaultDatabase: autoscalingDefaultDatabase, ExtraArgs: extraArgs, }, 
retryConfig) @@ -88,7 +102,7 @@ func resolveEndpoint(ctx context.Context, w *databricks.WorkspaceClient, project } // Get project to display its name - project, err := target.GetProject(ctx, w, projectID) + project, err := w.Postgres.GetProject(ctx, postgres.GetProjectRequest{Name: "projects/" + projectID}) if err != nil { return nil, fmt.Errorf("failed to get project: %w", err) } @@ -122,7 +136,7 @@ func resolveEndpoint(ctx context.Context, w *databricks.WorkspaceClient, project } // Get endpoint to validate and return it - endpoint, err := target.GetEndpoint(ctx, w, projectID, branchID, endpointID) + endpoint, err := w.Postgres.GetEndpoint(ctx, postgres.GetEndpointRequest{Name: branch.Name + "/endpoints/" + endpointID}) if err != nil { return nil, fmt.Errorf("failed to get endpoint: %w", err) } @@ -131,40 +145,38 @@ func resolveEndpoint(ctx context.Context, w *databricks.WorkspaceClient, project return endpoint, nil } -// selectAmbiguous prompts the user to pick one of the choices in an -// AmbiguousError. Caller is expected to have logged a header (e.g. via the -// spinner) before invoking. Used to keep psql's interactive UX while letting -// the shared lib do the actual list+filter work. -// -// Choice.DisplayName is empty when the producer has no friendlier label than -// the ID (e.g. branches and endpoints, where the ID is the human label). -// The promptui template renders an empty Name as a blank row, so we fall back -// to the ID before handing off to cmdio.SelectOrdered. -func selectAmbiguous(ctx context.Context, amb *target.AmbiguousError, prompt string) (string, error) { - items := make([]cmdio.Tuple, 0, len(amb.Choices)) - for _, c := range amb.Choices { - name := c.DisplayName - if name == "" { - name = c.ID - } - items = append(items, cmdio.Tuple{Name: name, Id: c.ID}) - } - return cmdio.SelectOrdered(ctx, items, prompt) -} - // selectProjectID auto-selects if there's only one project, otherwise prompts user to select. 
// Returns the project ID (not the full project object). func selectProjectID(ctx context.Context, w *databricks.WorkspaceClient) (string, error) { sp := cmdio.NewSpinner(ctx) sp.Update("Loading projects...") - id, err := target.AutoSelectProject(ctx, w) + projects, err := w.Postgres.ListProjectsAll(ctx, postgres.ListProjectsRequest{}) sp.Close() + if err != nil { + return "", err + } + + if len(projects) == 0 { + return "", errors.New("no Lakebase Autoscaling projects found in workspace") + } + + // Auto-select if there's only one project + if len(projects) == 1 { + return extractIDFromName(projects[0].Name, "projects"), nil + } - var amb *target.AmbiguousError - if !errors.As(err, &amb) { - return id, err + // Multiple projects, prompt user to select + var items []cmdio.Tuple + for _, project := range projects { + projectID := extractIDFromName(project.Name, "projects") + displayName := projectID + if project.Status != nil && project.Status.DisplayName != "" { + displayName = project.Status.DisplayName + } + items = append(items, cmdio.Tuple{Name: displayName, Id: projectID}) } - return selectAmbiguous(ctx, amb, "Select project") + + return cmdio.SelectOrdered(ctx, items, "Select project") } // selectBranchID auto-selects if there's only one branch, otherwise prompts user to select. 
@@ -172,14 +184,31 @@ func selectProjectID(ctx context.Context, w *databricks.WorkspaceClient) (string func selectBranchID(ctx context.Context, w *databricks.WorkspaceClient, projectName string) (string, error) { sp := cmdio.NewSpinner(ctx) sp.Update("Loading branches...") - id, err := target.AutoSelectBranch(ctx, w, projectName) + branches, err := w.Postgres.ListBranchesAll(ctx, postgres.ListBranchesRequest{ + Parent: projectName, + }) sp.Close() + if err != nil { + return "", err + } + + if len(branches) == 0 { + return "", errors.New("no branches found in project") + } + + // Auto-select if there's only one branch + if len(branches) == 1 { + return extractIDFromName(branches[0].Name, "branches"), nil + } - var amb *target.AmbiguousError - if !errors.As(err, &amb) { - return id, err + // Multiple branches, prompt user to select + var items []cmdio.Tuple + for _, branch := range branches { + branchID := extractIDFromName(branch.Name, "branches") + items = append(items, cmdio.Tuple{Name: branchID, Id: branchID}) } - return selectAmbiguous(ctx, amb, "Select branch") + + return cmdio.SelectOrdered(ctx, items, "Select branch") } // selectEndpointID auto-selects if there's only one endpoint, otherwise prompts user to select. 
@@ -187,12 +216,29 @@ func selectBranchID(ctx context.Context, w *databricks.WorkspaceClient, projectN func selectEndpointID(ctx context.Context, w *databricks.WorkspaceClient, branchName string) (string, error) { sp := cmdio.NewSpinner(ctx) sp.Update("Loading endpoints...") - id, err := target.AutoSelectEndpoint(ctx, w, branchName) + endpoints, err := w.Postgres.ListEndpointsAll(ctx, postgres.ListEndpointsRequest{ + Parent: branchName, + }) sp.Close() + if err != nil { + return "", err + } + + if len(endpoints) == 0 { + return "", errors.New("no endpoints found in branch") + } - var amb *target.AmbiguousError - if !errors.As(err, &amb) { - return id, err + // Auto-select if there's only one endpoint + if len(endpoints) == 1 { + return extractIDFromName(endpoints[0].Name, "endpoints"), nil } - return selectAmbiguous(ctx, amb, "Select endpoint") + + // Multiple endpoints, prompt user to select + var items []cmdio.Tuple + for _, endpoint := range endpoints { + endpointID := extractIDFromName(endpoint.Name, "endpoints") + items = append(items, cmdio.Tuple{Name: endpointID, Id: endpointID}) + } + + return cmdio.SelectOrdered(ctx, items, "Select endpoint") } diff --git a/cmd/psql/psql_provisioned.go b/cmd/psql/psql_provisioned.go index c7208906aa8..88ca1bb9181 100644 --- a/cmd/psql/psql_provisioned.go +++ b/cmd/psql/psql_provisioned.go @@ -7,10 +7,10 @@ import ( "github.com/databricks/cli/libs/cmdctx" "github.com/databricks/cli/libs/cmdio" - "github.com/databricks/cli/libs/lakebase/target" libpsql "github.com/databricks/cli/libs/psql" "github.com/databricks/databricks-sdk-go" "github.com/databricks/databricks-sdk-go/service/database" + "github.com/google/uuid" ) // provisionedDefaultDatabase is the default database for Lakebase Provisioned instances. 
@@ -39,9 +39,12 @@ func connectProvisioned(ctx context.Context, instanceName string, retryConfig li return errors.New("database instance is not ready for accepting connections") } - token, err := target.ProvisionedCredential(ctx, w, instance.Name) + cred, err := w.Database.GenerateDatabaseCredential(ctx, database.GenerateDatabaseCredentialRequest{ + InstanceNames: []string{instance.Name}, + RequestId: uuid.NewString(), + }) if err != nil { - return err + return fmt.Errorf("failed to get database credentials: %w", err) } cmdio.LogString(ctx, "Connecting to database instance...") @@ -49,7 +52,7 @@ func connectProvisioned(ctx context.Context, instanceName string, retryConfig li return libpsql.Connect(ctx, libpsql.ConnectOptions{ Host: instance.ReadWriteDns, Username: user.UserName, - Password: token, + Password: cred.Token, DefaultDatabase: provisionedDefaultDatabase, ExtraArgs: extraArgs, }, retryConfig) @@ -67,11 +70,15 @@ func resolveInstance(ctx context.Context, w *databricks.WorkspaceClient, instanc } } - // target.GetProvisioned patches Name on the response; the SDK's - // GetDatabaseInstance does not always populate it. 
- instance, err := target.GetProvisioned(ctx, w, instanceName) + instance, err := w.Database.GetDatabaseInstance(ctx, database.GetDatabaseInstanceRequest{ + Name: instanceName, + }) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to get database instance: %w", err) + } + // Ensure Name is set (API response may not include it) + if instance.Name == "" { + instance.Name = instanceName } cmdio.LogString(ctx, "Instance: "+instance.Name) @@ -83,12 +90,26 @@ func resolveInstance(ctx context.Context, w *databricks.WorkspaceClient, instanc func selectInstanceID(ctx context.Context, w *databricks.WorkspaceClient) (string, error) { sp := cmdio.NewSpinner(ctx) sp.Update("Loading instances...") - id, err := target.AutoSelectProvisioned(ctx, w) + instances, err := w.Database.ListDatabaseInstancesAll(ctx, database.ListDatabaseInstancesRequest{}) sp.Close() + if err != nil { + return "", err + } - var amb *target.AmbiguousError - if !errors.As(err, &amb) { - return id, err + if len(instances) == 0 { + return "", errors.New("no Lakebase Provisioned instances found in workspace") } - return selectAmbiguous(ctx, amb, "Select instance") + + // Auto-select if there's only one instance + if len(instances) == 1 { + return instances[0].Name, nil + } + + // Multiple instances, prompt user to select + var items []cmdio.Tuple + for _, inst := range instances { + items = append(items, cmdio.Tuple{Name: inst.Name, Id: inst.Name}) + } + + return cmdio.SelectOrdered(ctx, items, "Select instance") } diff --git a/cmd/psql/psql_test.go b/cmd/psql/psql_test.go new file mode 100644 index 00000000000..fc8a7e53cba --- /dev/null +++ b/cmd/psql/psql_test.go @@ -0,0 +1,83 @@ +package psql + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestParseResourcePath(t *testing.T) { + tests := []struct { + name string + input string + project string + branch string + endpoint string + wantErr string + }{ + { + name: "project 
only", + input: "projects/my-project", + project: "my-project", + }, + { + name: "project and branch", + input: "projects/my-project/branches/main", + project: "my-project", + branch: "main", + }, + { + name: "full path", + input: "projects/my-project/branches/main/endpoints/primary", + project: "my-project", + branch: "main", + endpoint: "primary", + }, + { + name: "missing project ID", + input: "projects/", + wantErr: "missing project ID", + }, + { + name: "missing branch ID", + input: "projects/my-project/branches/", + wantErr: "missing branch ID", + }, + { + name: "missing endpoint ID", + input: "projects/my-project/branches/main/endpoints/", + wantErr: "missing endpoint ID", + }, + { + name: "invalid segment after project", + input: "projects/my-project/invalid/foo", + wantErr: "expected 'branches' after project", + }, + { + name: "invalid segment after branch", + input: "projects/my-project/branches/main/invalid/foo", + wantErr: "expected 'endpoints' after branch", + }, + { + name: "not a projects path", + input: "something/else", + wantErr: "invalid resource path", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + project, branch, endpoint, err := parseResourcePath(tc.input) + if tc.wantErr != "" { + require.Error(t, err) + assert.Contains(t, err.Error(), tc.wantErr) + return + } + require.NoError(t, err) + assert.Equal(t, tc.project, project) + assert.Equal(t, tc.branch, branch) + assert.Equal(t, tc.endpoint, endpoint) + }) + } +} diff --git a/libs/lakebase/target/autoscaling.go b/experimental/postgres/cmd/internal/target/autoscaling.go similarity index 100% rename from libs/lakebase/target/autoscaling.go rename to experimental/postgres/cmd/internal/target/autoscaling.go diff --git a/libs/lakebase/target/provisioned.go b/experimental/postgres/cmd/internal/target/provisioned.go similarity index 100% rename from libs/lakebase/target/provisioned.go rename to experimental/postgres/cmd/internal/target/provisioned.go diff --git 
a/libs/lakebase/target/target.go b/experimental/postgres/cmd/internal/target/target.go similarity index 100% rename from libs/lakebase/target/target.go rename to experimental/postgres/cmd/internal/target/target.go diff --git a/libs/lakebase/target/target_test.go b/experimental/postgres/cmd/internal/target/target_test.go similarity index 100% rename from libs/lakebase/target/target_test.go rename to experimental/postgres/cmd/internal/target/target_test.go diff --git a/experimental/postgres/cmd/targeting.go b/experimental/postgres/cmd/targeting.go index 5e72840f952..7f6a6830daa 100644 --- a/experimental/postgres/cmd/targeting.go +++ b/experimental/postgres/cmd/targeting.go @@ -5,8 +5,8 @@ import ( "errors" "fmt" + "github.com/databricks/cli/experimental/postgres/cmd/internal/target" "github.com/databricks/cli/libs/cmdctx" - "github.com/databricks/cli/libs/lakebase/target" "github.com/databricks/databricks-sdk-go" "github.com/databricks/databricks-sdk-go/service/postgres" ) From 1a0798825ce9be527e1f7f796dc640bcd9e797a4 Mon Sep 17 00:00:00 2001 From: simon Date: Thu, 30 Apr 2026 13:11:53 +0200 Subject: [PATCH 09/15] Extract output-mode handling into experimental/libs/sqlcli Both aitools query and postgres query had near-identical output-mode selection: same DATABRICKS_OUTPUT_FORMAT env var, same flag-vs-env precedence, same staticTableThreshold=30, same Format type with text/json/csv values. Promote the shared bits to experimental/libs/sqlcli: - sqlcli.EnvOutputFormat, sqlcli.StaticTableThreshold consts - sqlcli.Format typedef + sqlcli.OutputText/JSON/CSV consts - sqlcli.AllFormats slice (canonical order for completions) - sqlcli.ResolveFormat: handles flag > env > default precedence with the explicit-text-on-pipe-is-honoured rule Both consumers now import sqlcli. The package lives under experimental/libs/ rather than libs/ so it inherits the experimental- stability guarantee of its consumers; when both commands graduate, the package can be promoted alongside. 
The aitools migration is a pure refactor (no behavior change). The postgres command's output.go and output_test.go are deleted; tests moved to experimental/libs/sqlcli. Co-authored-by: Isaac --- experimental/aitools/cmd/query.go | 65 ++++++--------- experimental/aitools/cmd/query_test.go | 22 ++--- experimental/libs/sqlcli/output.go | 93 +++++++++++++++++++++ experimental/libs/sqlcli/output_test.go | 100 +++++++++++++++++++++++ experimental/postgres/cmd/output.go | 79 ------------------ experimental/postgres/cmd/output_test.go | 93 --------------------- experimental/postgres/cmd/query.go | 15 ++-- 7 files changed, 239 insertions(+), 228 deletions(-) create mode 100644 experimental/libs/sqlcli/output.go create mode 100644 experimental/libs/sqlcli/output_test.go delete mode 100644 experimental/postgres/cmd/output.go delete mode 100644 experimental/postgres/cmd/output_test.go diff --git a/experimental/aitools/cmd/query.go b/experimental/aitools/cmd/query.go index 7e9ae1d030d..45c5669c699 100644 --- a/experimental/aitools/cmd/query.go +++ b/experimental/aitools/cmd/query.go @@ -14,10 +14,9 @@ import ( "github.com/databricks/cli/cmd/root" "github.com/databricks/cli/experimental/aitools/lib/middlewares" "github.com/databricks/cli/experimental/aitools/lib/session" + "github.com/databricks/cli/experimental/libs/sqlcli" "github.com/databricks/cli/libs/cmdctx" "github.com/databricks/cli/libs/cmdio" - "github.com/databricks/cli/libs/env" - "github.com/databricks/cli/libs/flags" "github.com/databricks/cli/libs/log" "github.com/databricks/databricks-sdk-go/service/sql" "github.com/spf13/cobra" @@ -35,16 +34,6 @@ const ( // cancelTimeout is how long to wait for server-side cancellation. cancelTimeout = 10 * time.Second - - // staticTableThreshold is the maximum number of rows rendered as a static table. - // Beyond this, an interactive scrollable table is used. - staticTableThreshold = 30 - - // outputCSV is the csv output format, supported only by the query command. 
- outputCSV = "csv" - - // envOutputFormat matches the env var name in cmd/root/io.go. - envOutputFormat = "DATABRICKS_OUTPUT_FORMAT" ) type queryOutputMode int @@ -55,8 +44,13 @@ const ( queryOutputModeInteractiveTable ) -func selectQueryOutputMode(outputType flags.Output, stdoutInteractive, promptSupported bool, rowCount int) queryOutputMode { - if outputType == flags.OutputJSON { +// selectQueryOutputMode picks the rendering mode for a single-query result. +// JSON is the only machine-readable option; static and interactive are +// table variants chosen by row count and TTY capabilities. Sharing only +// the threshold with sqlcli; the three-way decision is aitools-specific +// because the postgres command's renderers have a different shape. +func selectQueryOutputMode(format sqlcli.Format, stdoutInteractive, promptSupported bool, rowCount int) queryOutputMode { + if format == sqlcli.OutputJSON { return queryOutputModeJSON } if !stdoutInteractive { @@ -67,7 +61,7 @@ func selectQueryOutputMode(outputType flags.Output, stdoutInteractive, promptSup if !promptSupported { return queryOutputModeStaticTable } - if rowCount <= staticTableThreshold { + if rowCount <= sqlcli.StaticTableThreshold { return queryOutputModeStaticTable } return queryOutputModeInteractiveTable @@ -119,24 +113,15 @@ interactive table browser. Use --output csv to export results as CSV.`, RunE: func(cmd *cobra.Command, args []string) error { ctx := cmd.Context() - // Normalize case to match root --output behavior (flags.Output.Set lowercases). - outputFormat = strings.ToLower(outputFormat) - - // If --output wasn't explicitly passed, check the env var. - // Invalid env values are silently ignored, matching cmd/root/io.go. 
- if !cmd.Flag("output").Changed { - if v, ok := env.Lookup(ctx, envOutputFormat); ok { - switch flags.Output(strings.ToLower(v)) { - case flags.OutputText, flags.OutputJSON, outputCSV: - outputFormat = strings.ToLower(v) - } - } - } - - switch flags.Output(outputFormat) { - case flags.OutputText, flags.OutputJSON, outputCSV: - default: - return fmt.Errorf("unsupported output format %q, accepted values: text, json, csv", outputFormat) + // Resolve the effective format via sqlcli so the env-var + // precedence and explicit-text-on-pipe handling stays in sync + // across commands. We pass stdoutTTY=true to keep the original + // aitools behavior of not auto-falling-back to JSON here; the + // per-result render mode further down already handles the pipe + // case via selectQueryOutputMode. + format, err := sqlcli.ResolveFormat(ctx, outputFormat, cmd.Flag("output").Changed, true) + if err != nil { + return err } sqls, err := resolveSQLs(ctx, cmd, args, filePaths) @@ -146,7 +131,7 @@ interactive table browser. Use --output csv to export results as CSV.`, // Reject incompatible flag combinations before any API call so the // user sees the real error instead of an auth/warehouse failure. - if len(sqls) > 1 && flags.Output(outputFormat) != flags.OutputJSON { + if len(sqls) > 1 && format != sqlcli.OutputJSON { return fmt.Errorf("multiple queries require --output json (got %q); pass --output json to receive a JSON array of per-statement results", outputFormat) } @@ -173,7 +158,7 @@ interactive table browser. Use --output csv to export results as CSV.`, } // CSV bypasses the normal output mode selection. - if flags.Output(outputFormat) == outputCSV { + if format == sqlcli.OutputCSV { if len(columns) == 0 && len(rows) == 0 { return nil } @@ -190,7 +175,7 @@ interactive table browser. 
Use --output csv to export results as CSV.`, stdoutInteractive := cmdio.SupportsColor(ctx, cmd.OutOrStdout()) promptSupported := cmdio.IsPromptSupported(ctx) - switch selectQueryOutputMode(flags.Output(outputFormat), stdoutInteractive, promptSupported, len(rows)) { + switch selectQueryOutputMode(format, stdoutInteractive, promptSupported, len(rows)) { case queryOutputModeJSON: return renderJSON(cmd.OutOrStdout(), columns, rows) case queryOutputModeStaticTable: @@ -206,9 +191,13 @@ interactive table browser. Use --output csv to export results as CSV.`, cmd.Flags().IntVar(&concurrency, "concurrency", defaultBatchConcurrency, "Maximum in-flight statements when running a batch of queries") // Local --output flag shadows the root command's persistent --output flag, // adding csv support for this command only. - cmd.Flags().StringVarP(&outputFormat, "output", "o", string(flags.OutputText), "Output format: text, json, or csv") + cmd.Flags().StringVarP(&outputFormat, "output", "o", string(sqlcli.OutputText), "Output format: text, json, or csv") cmd.RegisterFlagCompletionFunc("output", func(*cobra.Command, []string, string) ([]string, cobra.ShellCompDirective) { - return []string{string(flags.OutputText), string(flags.OutputJSON), string(outputCSV)}, cobra.ShellCompDirectiveNoFileComp + out := make([]string, len(sqlcli.AllFormats)) + for i, f := range sqlcli.AllFormats { + out[i] = string(f) + } + return out, cobra.ShellCompDirectiveNoFileComp }) return cmd diff --git a/experimental/aitools/cmd/query_test.go b/experimental/aitools/cmd/query_test.go index 59de11d578a..c85edc64722 100644 --- a/experimental/aitools/cmd/query_test.go +++ b/experimental/aitools/cmd/query_test.go @@ -10,9 +10,9 @@ import ( "time" "github.com/databricks/cli/cmd/root" + "github.com/databricks/cli/experimental/libs/sqlcli" "github.com/databricks/cli/libs/cmdio" "github.com/databricks/cli/libs/env" - "github.com/databricks/cli/libs/flags" mocksql 
"github.com/databricks/databricks-sdk-go/experimental/mocks/service/sql" "github.com/databricks/databricks-sdk-go/service/sql" "github.com/spf13/cobra" @@ -271,7 +271,7 @@ func TestResolveWarehouseIDWithFlag(t *testing.T) { func TestSelectQueryOutputMode(t *testing.T) { tests := []struct { name string - outputType flags.Output + format sqlcli.Format stdoutInteractive bool promptSupported bool rowCount int @@ -279,7 +279,7 @@ func TestSelectQueryOutputMode(t *testing.T) { }{ { name: "json flag always returns json", - outputType: flags.OutputJSON, + format: sqlcli.OutputJSON, stdoutInteractive: true, promptSupported: true, rowCount: 999, @@ -287,7 +287,7 @@ func TestSelectQueryOutputMode(t *testing.T) { }, { name: "non interactive stdout returns json", - outputType: flags.OutputText, + format: sqlcli.OutputText, stdoutInteractive: false, promptSupported: true, rowCount: 5, @@ -295,33 +295,33 @@ func TestSelectQueryOutputMode(t *testing.T) { }, { name: "missing stdin interactivity falls back to static table", - outputType: flags.OutputText, + format: sqlcli.OutputText, stdoutInteractive: true, promptSupported: false, - rowCount: staticTableThreshold + 10, + rowCount: sqlcli.StaticTableThreshold + 10, want: queryOutputModeStaticTable, }, { name: "small results use static table", - outputType: flags.OutputText, + format: sqlcli.OutputText, stdoutInteractive: true, promptSupported: true, - rowCount: staticTableThreshold, + rowCount: sqlcli.StaticTableThreshold, want: queryOutputModeStaticTable, }, { name: "large results use interactive table", - outputType: flags.OutputText, + format: sqlcli.OutputText, stdoutInteractive: true, promptSupported: true, - rowCount: staticTableThreshold + 1, + rowCount: sqlcli.StaticTableThreshold + 1, want: queryOutputModeInteractiveTable, }, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { - got := selectQueryOutputMode(tc.outputType, tc.stdoutInteractive, tc.promptSupported, tc.rowCount) + got := 
selectQueryOutputMode(tc.format, tc.stdoutInteractive, tc.promptSupported, tc.rowCount) assert.Equal(t, tc.want, got) }) } diff --git a/experimental/libs/sqlcli/output.go b/experimental/libs/sqlcli/output.go new file mode 100644 index 00000000000..4643303cd23 --- /dev/null +++ b/experimental/libs/sqlcli/output.go @@ -0,0 +1,93 @@ +// Package sqlcli holds patterns shared by experimental SQL-running commands +// (currently `experimental aitools tools query` and `experimental postgres +// query`). The package lives under experimental/libs/ rather than libs/ so +// the commands depending on it inherit experimental-stability guarantees: +// when both consumers graduate, this package can be promoted alongside +// (or its API stabilised first). +package sqlcli + +import ( + "context" + "fmt" + "slices" + "strings" + + "github.com/databricks/cli/libs/env" +) + +// EnvOutputFormat matches the env var name in cmd/root/io.go. +// Reading it lets pipelines set DATABRICKS_OUTPUT_FORMAT once for all +// commands. +const EnvOutputFormat = "DATABRICKS_OUTPUT_FORMAT" + +// StaticTableThreshold is the row count above which interactive callers may +// hand off to libs/tableview's scrollable viewer. Smaller results stay in a +// static tabwriter table so they pipe to scripts unchanged. +const StaticTableThreshold = 30 + +// Format is the user-selectable output shape. Using a string typedef instead +// of an int enum keeps the help text and DATABRICKS_OUTPUT_FORMAT env var +// values self-describing. +type Format string + +const ( + OutputText Format = "text" + OutputJSON Format = "json" + OutputCSV Format = "csv" +) + +// AllFormats is the canonical order shown in completions / help. Sharing +// the slice avoids drift between consumers when a new format is added. +var AllFormats = []Format{OutputText, OutputJSON, OutputCSV} + +// ResolveFormat picks the effective output format. Precedence: +// +// 1. The local --output flag if it was explicitly set. +// 2. 
DATABRICKS_OUTPUT_FORMAT env var if set to a known value (invalid +// values are silently ignored, matching cmd/root/io.go and aitools). +// 3. The flag default (whatever the caller passes as flagValue). +// +// Then the auto-selection rule applies: a *defaulted* text mode on a non-TTY +// stdout falls back to JSON, so scripts piping the output get machine- +// readable output by default. An *explicit* --output text (flag or env) is +// honoured even on a pipe; per AGENTS.md we don't silently override flags +// the user set. +// +// flagSet is true if the user explicitly passed --output on the CLI. +// stdoutTTY is true if stdout is a terminal. +func ResolveFormat(ctx context.Context, flagValue string, flagSet, stdoutTTY bool) (Format, error) { + chosen := Format(strings.ToLower(flagValue)) + chosenExplicit := flagSet + + if !flagSet { + if v, ok := env.Lookup(ctx, EnvOutputFormat); ok { + candidate := Format(strings.ToLower(v)) + if IsKnown(candidate) { + chosen = candidate + chosenExplicit = true + } + } + } + + if !IsKnown(chosen) { + return "", fmt.Errorf("unsupported output format %q; expected one of: %s", flagValue, joinFormats(AllFormats)) + } + + if chosen == OutputText && !stdoutTTY && !chosenExplicit { + return OutputJSON, nil + } + return chosen, nil +} + +// IsKnown reports whether f is one of the formats in AllFormats. 
+func IsKnown(f Format) bool { + return slices.Contains(AllFormats, f) +} + +func joinFormats(formats []Format) string { + parts := make([]string, len(formats)) + for i, f := range formats { + parts[i] = string(f) + } + return strings.Join(parts, ", ") +} diff --git a/experimental/libs/sqlcli/output_test.go b/experimental/libs/sqlcli/output_test.go new file mode 100644 index 00000000000..1e91bd9cf3d --- /dev/null +++ b/experimental/libs/sqlcli/output_test.go @@ -0,0 +1,100 @@ +package sqlcli + +import ( + "testing" + + "github.com/databricks/cli/libs/env" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestResolveFormat_Defaults(t *testing.T) { + ctx := t.Context() + got, err := ResolveFormat(ctx, "text", false, true) + require.NoError(t, err) + assert.Equal(t, OutputText, got) +} + +func TestResolveFormat_TextOnPipeFallsBackToJSON(t *testing.T) { + ctx := t.Context() + got, err := ResolveFormat(ctx, "text", false, false) + require.NoError(t, err) + assert.Equal(t, OutputJSON, got) +} + +func TestResolveFormat_ExplicitTextOnPipeIsHonoured(t *testing.T) { + ctx := t.Context() + got, err := ResolveFormat(ctx, "text", true, false) + require.NoError(t, err) + assert.Equal(t, OutputText, got) +} + +func TestResolveFormat_EnvVarTextOnPipeIsHonoured(t *testing.T) { + ctx := env.Set(t.Context(), EnvOutputFormat, "text") + got, err := ResolveFormat(ctx, "text", false, false) + require.NoError(t, err) + assert.Equal(t, OutputText, got) +} + +func TestResolveFormat_EnvVarCSVOnPipe(t *testing.T) { + ctx := env.Set(t.Context(), EnvOutputFormat, "csv") + got, err := ResolveFormat(ctx, "text", false, false) + require.NoError(t, err) + assert.Equal(t, OutputCSV, got) +} + +func TestResolveFormat_ExplicitJSON(t *testing.T) { + ctx := t.Context() + got, err := ResolveFormat(ctx, "json", true, true) + require.NoError(t, err) + assert.Equal(t, OutputJSON, got) +} + +func TestResolveFormat_ExplicitCSV(t *testing.T) { + ctx := t.Context() + got, 
err := ResolveFormat(ctx, "csv", true, true) + require.NoError(t, err) + assert.Equal(t, OutputCSV, got) +} + +func TestResolveFormat_EnvVarHonoredWhenFlagNotSet(t *testing.T) { + ctx := env.Set(t.Context(), EnvOutputFormat, "csv") + got, err := ResolveFormat(ctx, "text", false, true) + require.NoError(t, err) + assert.Equal(t, OutputCSV, got) +} + +func TestResolveFormat_FlagOverridesEnvVar(t *testing.T) { + ctx := env.Set(t.Context(), EnvOutputFormat, "csv") + got, err := ResolveFormat(ctx, "json", true, true) + require.NoError(t, err) + assert.Equal(t, OutputJSON, got) +} + +func TestResolveFormat_InvalidEnvVarIgnored(t *testing.T) { + ctx := env.Set(t.Context(), EnvOutputFormat, "yaml") + got, err := ResolveFormat(ctx, "text", false, true) + require.NoError(t, err) + assert.Equal(t, OutputText, got) +} + +func TestResolveFormat_InvalidFlagErrors(t *testing.T) { + ctx := t.Context() + _, err := ResolveFormat(ctx, "yaml", true, true) + assert.ErrorContains(t, err, "unsupported output format") +} + +func TestResolveFormat_CaseInsensitive(t *testing.T) { + ctx := t.Context() + got, err := ResolveFormat(ctx, "JSON", true, true) + require.NoError(t, err) + assert.Equal(t, OutputJSON, got) +} + +func TestIsKnown(t *testing.T) { + assert.True(t, IsKnown(OutputText)) + assert.True(t, IsKnown(OutputJSON)) + assert.True(t, IsKnown(OutputCSV)) + assert.False(t, IsKnown(Format("yaml"))) + assert.False(t, IsKnown(Format(""))) +} diff --git a/experimental/postgres/cmd/output.go b/experimental/postgres/cmd/output.go deleted file mode 100644 index e5b59fec96f..00000000000 --- a/experimental/postgres/cmd/output.go +++ /dev/null @@ -1,79 +0,0 @@ -package postgrescmd - -import ( - "context" - "fmt" - "slices" - "strings" - - "github.com/databricks/cli/libs/env" -) - -// outputFormat is the user-selectable output shape. Using a string typedef -// instead of an int enum keeps the help text and DATABRICKS_OUTPUT_FORMAT env -// var values self-describing. 
-type outputFormat string - -const ( - outputText outputFormat = "text" - outputJSON outputFormat = "json" - outputCSV outputFormat = "csv" - - // envOutputFormat matches the env var name in cmd/root/io.go. Reading it - // here lets pipelines set DATABRICKS_OUTPUT_FORMAT once for all - // commands. See aitools query for a parallel pattern. - envOutputFormat = "DATABRICKS_OUTPUT_FORMAT" -) - -// allOutputFormats is the canonical order shown in completions / help. -var allOutputFormats = []outputFormat{outputText, outputJSON, outputCSV} - -// resolveOutputFormat picks the effective output format. Precedence: -// -// 1. The local --output flag if it was explicitly set. -// 2. DATABRICKS_OUTPUT_FORMAT env var if set to a known value (invalid -// values are silently ignored, matching cmd/root/io.go and aitools). -// 3. The flag default ("text"). -// -// Then the auto-selection rule applies: a *defaulted* text mode on a non-TTY -// stdout falls back to JSON, so scripts piping the output get machine- -// readable output by default. An *explicit* --output text is honoured even -// on a pipe; per CLAUDE.md we don't silently override flags the user set. -// -// flagSet is true if the user explicitly passed --output. stdoutTTY is true -// if stdout is a terminal. 
-func resolveOutputFormat(ctx context.Context, flagValue string, flagSet, stdoutTTY bool) (outputFormat, error) { - chosen := outputFormat(strings.ToLower(flagValue)) - chosenExplicit := flagSet - - if !flagSet { - if v, ok := env.Lookup(ctx, envOutputFormat); ok { - candidate := outputFormat(strings.ToLower(v)) - if isKnownOutputFormat(candidate) { - chosen = candidate - chosenExplicit = true - } - } - } - - if !isKnownOutputFormat(chosen) { - return "", fmt.Errorf("unsupported output format %q; expected one of: %s", flagValue, joinOutputFormats(allOutputFormats)) - } - - if chosen == outputText && !stdoutTTY && !chosenExplicit { - return outputJSON, nil - } - return chosen, nil -} - -func joinOutputFormats(formats []outputFormat) string { - parts := make([]string, len(formats)) - for i, f := range formats { - parts[i] = string(f) - } - return strings.Join(parts, ", ") -} - -func isKnownOutputFormat(f outputFormat) bool { - return slices.Contains(allOutputFormats, f) -} diff --git a/experimental/postgres/cmd/output_test.go b/experimental/postgres/cmd/output_test.go deleted file mode 100644 index 4598085805a..00000000000 --- a/experimental/postgres/cmd/output_test.go +++ /dev/null @@ -1,93 +0,0 @@ -package postgrescmd - -import ( - "testing" - - "github.com/databricks/cli/libs/env" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" -) - -func TestResolveOutputFormat_Defaults(t *testing.T) { - ctx := t.Context() - - got, err := resolveOutputFormat(ctx, "text", false, true) - require.NoError(t, err) - assert.Equal(t, outputText, got) -} - -func TestResolveOutputFormat_TextOnPipeFallsBackToJSON(t *testing.T) { - ctx := t.Context() - got, err := resolveOutputFormat(ctx, "text", false, false) - require.NoError(t, err) - assert.Equal(t, outputJSON, got) -} - -func TestResolveOutputFormat_ExplicitTextOnPipeIsHonoured(t *testing.T) { - ctx := t.Context() - got, err := resolveOutputFormat(ctx, "text", true, false) - require.NoError(t, err) - 
assert.Equal(t, outputText, got) -} - -func TestResolveOutputFormat_EnvVarTextOnPipeIsHonoured(t *testing.T) { - ctx := env.Set(t.Context(), envOutputFormat, "text") - got, err := resolveOutputFormat(ctx, "text", false, false) - require.NoError(t, err) - assert.Equal(t, outputText, got) -} - -func TestResolveOutputFormat_EnvVarCSVOnPipe(t *testing.T) { - ctx := env.Set(t.Context(), envOutputFormat, "csv") - got, err := resolveOutputFormat(ctx, "text", false, false) - require.NoError(t, err) - assert.Equal(t, outputCSV, got) -} - -func TestResolveOutputFormat_ExplicitJSON(t *testing.T) { - ctx := t.Context() - got, err := resolveOutputFormat(ctx, "json", true, true) - require.NoError(t, err) - assert.Equal(t, outputJSON, got) -} - -func TestResolveOutputFormat_ExplicitCSV(t *testing.T) { - ctx := t.Context() - got, err := resolveOutputFormat(ctx, "csv", true, true) - require.NoError(t, err) - assert.Equal(t, outputCSV, got) -} - -func TestResolveOutputFormat_EnvVarHonoredWhenFlagNotSet(t *testing.T) { - ctx := env.Set(t.Context(), envOutputFormat, "csv") - got, err := resolveOutputFormat(ctx, "text", false, true) - require.NoError(t, err) - assert.Equal(t, outputCSV, got) -} - -func TestResolveOutputFormat_FlagOverridesEnvVar(t *testing.T) { - ctx := env.Set(t.Context(), envOutputFormat, "csv") - got, err := resolveOutputFormat(ctx, "json", true, true) - require.NoError(t, err) - assert.Equal(t, outputJSON, got) -} - -func TestResolveOutputFormat_InvalidEnvVarIgnored(t *testing.T) { - ctx := env.Set(t.Context(), envOutputFormat, "yaml") - got, err := resolveOutputFormat(ctx, "text", false, true) - require.NoError(t, err) - assert.Equal(t, outputText, got) -} - -func TestResolveOutputFormat_InvalidFlagErrors(t *testing.T) { - ctx := t.Context() - _, err := resolveOutputFormat(ctx, "yaml", true, true) - assert.ErrorContains(t, err, "unsupported output format") -} - -func TestResolveOutputFormat_CaseInsensitive(t *testing.T) { - ctx := t.Context() - got, err := 
resolveOutputFormat(ctx, "JSON", true, true) - require.NoError(t, err) - assert.Equal(t, outputJSON, got) -} diff --git a/experimental/postgres/cmd/query.go b/experimental/postgres/cmd/query.go index 2b4f12694f9..5a7f3e577cc 100644 --- a/experimental/postgres/cmd/query.go +++ b/experimental/postgres/cmd/query.go @@ -9,6 +9,7 @@ import ( "time" "github.com/databricks/cli/cmd/root" + "github.com/databricks/cli/experimental/libs/sqlcli" "github.com/databricks/cli/libs/cmdio" "github.com/jackc/pgx/v5" "github.com/spf13/cobra" @@ -89,10 +90,10 @@ Limitations (this release): cmd.Flags().StringVarP(&f.database, "database", "d", defaultDatabase, "Database name") cmd.Flags().DurationVar(&f.connectTimeout, "connect-timeout", defaultConnectTimeout, "Connect timeout") cmd.Flags().IntVar(&f.maxRetries, "max-retries", 3, "Total connect attempts on idle/waking endpoint (must be >= 1; 1 disables retry)") - cmd.Flags().StringVarP(&f.outputFormat, "output", "o", string(outputText), "Output format: text, json, or csv") + cmd.Flags().StringVarP(&f.outputFormat, "output", "o", string(sqlcli.OutputText), "Output format: text, json, or csv") cmd.RegisterFlagCompletionFunc("output", func(*cobra.Command, []string, string) ([]string, cobra.ShellCompDirective) { - out := make([]string, len(allOutputFormats)) - for i, f := range allOutputFormats { + out := make([]string, len(sqlcli.AllFormats)) + for i, f := range sqlcli.AllFormats { out[i] = string(f) } return out, cobra.ShellCompDirectiveNoFileComp @@ -126,7 +127,7 @@ func runQuery(ctx context.Context, cmd *cobra.Command, sql string, f queryFlags) // pass --output text explicitly; that path is honoured (see // resolveOutputFormat). Mirrors the aitools query command. 
stdoutTTY := cmdio.SupportsColor(ctx, cmd.OutOrStdout()) - format, err := resolveOutputFormat(ctx, f.outputFormat, f.outputFormatSet, stdoutTTY) + format, err := sqlcli.ResolveFormat(ctx, f.outputFormat, f.outputFormatSet, stdoutTTY) if err != nil { return err } @@ -168,11 +169,11 @@ func runQuery(ctx context.Context, cmd *cobra.Command, sql string, f queryFlags) // newSink returns the rowSink for the chosen output format. Kept separate // from runQuery so tests can build sinks without going through pgx. -func newSink(format outputFormat, out, stderr io.Writer) rowSink { +func newSink(format sqlcli.Format, out, stderr io.Writer) rowSink { switch format { - case outputJSON: + case sqlcli.OutputJSON: return newJSONSink(out, stderr) - case outputCSV: + case sqlcli.OutputCSV: return newCSVSink(out, stderr) default: return newTextSink(out) From a5dff81d25b4ac21d390935ee5061e4972dc72e3 Mon Sep 17 00:00:00 2001 From: simon Date: Thu, 30 Apr 2026 15:16:57 +0200 Subject: [PATCH 10/15] Address nitpicker findings: NO_COLOR-safe TTY check, dup-column collision, control-char escape Three P2 findings from the nitpicker bot, all in code introduced or strengthened in this PR: - stdoutTTY now uses cmdio.IsOutputTTY (a new tiny public helper that wraps the existing private isTTY) instead of cmdio.SupportsColor. SupportsColor folds in NO_COLOR / TERM=dumb, which are colour preferences and have nothing to do with whether stdout is a pipe; using it for the auto-fall-back-to-JSON decision silently demoted interactive text output to JSON for users with NO_COLOR set on a real terminal. IsOutputTTY is the right primitive for this. - jsonSink dup-column rename: the previous logic generated id__2 for the second `id` without checking whether id__2 was already taken by the original column list. A query returning ["id", "id__2", "id"] produced two id__2 keys. Now we keep bumping the suffix until unique. - textSink escapes \t, \n, \r in cell values before tabwriter sees them. 
tabwriter uses \t as a column boundary and \n as a row boundary, so an embedded tab silently shifted subsequent columns and an embedded newline split a logical row across multiple output lines. psql does the same backslash-letter escape. Co-authored-by: Isaac --- experimental/postgres/cmd/query.go | 12 +++++----- experimental/postgres/cmd/render.go | 15 +++++++++++- experimental/postgres/cmd/render_json.go | 20 ++++++++++++---- experimental/postgres/cmd/render_json_test.go | 23 +++++++++++++++++++ experimental/postgres/cmd/render_test.go | 13 +++++++++++ libs/cmdio/tty.go | 10 ++++++++ 6 files changed, 82 insertions(+), 11 deletions(-) diff --git a/experimental/postgres/cmd/query.go b/experimental/postgres/cmd/query.go index 5a7f3e577cc..f05b3e01503 100644 --- a/experimental/postgres/cmd/query.go +++ b/experimental/postgres/cmd/query.go @@ -121,12 +121,12 @@ func runQuery(ctx context.Context, cmd *cobra.Command, sql string, f queryFlags) return err } - // SupportsColor is the public TTY-ish signal libs/cmdio exposes today; it - // also folds in NO_COLOR / TERM=dumb, which strictly speaking are colour - // preferences rather than TTY signals. Users who hit that edge case can - // pass --output text explicitly; that path is honoured (see - // resolveOutputFormat). Mirrors the aitools query command. - stdoutTTY := cmdio.SupportsColor(ctx, cmd.OutOrStdout()) + // IsOutputTTY checks the file-descriptor only. SupportsColor would also + // AND in NO_COLOR / TERM=dumb, which are colour preferences and have + // nothing to do with whether stdout is a pipe; folding them in here + // would silently demote interactive text output to JSON for users who + // have NO_COLOR set on a real terminal. 
+ stdoutTTY := cmdio.IsOutputTTY(cmd.OutOrStdout()) format, err := sqlcli.ResolveFormat(ctx, f.outputFormat, f.outputFormatSet, stdoutTTY) if err != nil { return err diff --git a/experimental/postgres/cmd/render.go b/experimental/postgres/cmd/render.go index 2e1daf6376b..a3c6aa53344 100644 --- a/experimental/postgres/cmd/render.go +++ b/experimental/postgres/cmd/render.go @@ -36,12 +36,25 @@ func (s *textSink) Begin(fields []pgconn.FieldDescription) error { func (s *textSink) Row(values []any) error { row := make([]string, len(values)) for i, v := range values { - row[i] = textValue(v) + row[i] = escapeControlForTabwriter(textValue(v)) } s.rows = append(s.rows, row) return nil } +// escapeControlForTabwriter replaces tabs, newlines, and carriage returns in +// a cell value with the two-character backslash-letter sequence. tabwriter +// uses '\t' as a column boundary and '\n' as a row boundary, so an embedded +// tab silently shifts subsequent columns and an embedded newline splits one +// logical row into two. psql's text mode applies the same escapes. +func escapeControlForTabwriter(s string) string { + if !strings.ContainsAny(s, "\t\n\r") { + return s + } + r := strings.NewReplacer("\t", `\t`, "\n", `\n`, "\r", `\r`) + return r.Replace(s) +} + func (s *textSink) End(commandTag string) error { if len(s.columns) == 0 { _, err := fmt.Fprintln(s.out, commandTag) diff --git a/experimental/postgres/cmd/render_json.go b/experimental/postgres/cmd/render_json.go index dc713b6b786..c50739e2d1f 100644 --- a/experimental/postgres/cmd/render_json.go +++ b/experimental/postgres/cmd/render_json.go @@ -52,16 +52,28 @@ func (s *jsonSink) Begin(fields []pgconn.FieldDescription) error { s.columns = make([]string, len(fields)) s.oids = make([]uint32, len(fields)) - seen := make(map[string]int, len(fields)) + // assigned tracks every name we have committed to s.columns so far. 
This + // must include both first-occurrence names and __N suffixed renames, so a + // query whose original column list contains the same suffix we'd generate + // (e.g. ["id", "id__2", "id"]) does not produce two id__2 keys. + assigned := make(map[string]struct{}, len(fields)) dupes := false for i, f := range fields { s.oids[i] = f.DataTypeOID name := f.Name - seen[name]++ - if seen[name] > 1 { + if _, taken := assigned[name]; taken { dupes = true - name = fmt.Sprintf("%s__%d", f.Name, seen[name]) + suffix := 2 + for { + candidate := fmt.Sprintf("%s__%d", f.Name, suffix) + if _, taken := assigned[candidate]; !taken { + name = candidate + break + } + suffix++ + } } + assigned[name] = struct{}{} s.columns[i] = name } if dupes { diff --git a/experimental/postgres/cmd/render_json_test.go b/experimental/postgres/cmd/render_json_test.go index 4e6f474d257..9cf386cb14d 100644 --- a/experimental/postgres/cmd/render_json_test.go +++ b/experimental/postgres/cmd/render_json_test.go @@ -2,6 +2,7 @@ package postgrescmd import ( "bytes" + "strings" "testing" "github.com/jackc/pgx/v5/pgconn" @@ -136,3 +137,25 @@ func TestCommandTagParse(t *testing.T) { } } } + +func TestJSONSink_DuplicateColumns_DoesNotCollideWithExistingSuffix(t *testing.T) { + // Source columns ["id", "id__2", "id"]: the second `id` would naively + // rename to id__2, colliding with the existing id__2 from the source. + // Verify the dedup logic bumps the suffix until unique. + var stdout, stderr bytes.Buffer + s := newJSONSink(&stdout, &stderr) + require.NoError(t, s.Begin(fieldsWithOIDs( + []string{"id", "id__2", "id"}, + []uint32{pgtype.Int8OID, pgtype.Int8OID, pgtype.Int8OID}, + ))) + require.NoError(t, s.Row([]any{int64(1), int64(2), int64(3)})) + require.NoError(t, s.End("SELECT 1")) + + // All three keys present with no duplicates. + out := stdout.String() + assert.Contains(t, out, `"id":1`) + assert.Contains(t, out, `"id__2":2`) + assert.Contains(t, out, `"id__3":3`) + // And NOT two id__2 keys. 
+ assert.Equal(t, 1, strings.Count(out, `"id__2"`)) +} diff --git a/experimental/postgres/cmd/render_test.go b/experimental/postgres/cmd/render_test.go index d451febb191..bdd2bddd4f6 100644 --- a/experimental/postgres/cmd/render_test.go +++ b/experimental/postgres/cmd/render_test.go @@ -83,3 +83,16 @@ func TestTextSink_OnError_NoOp(t *testing.T) { // is never flushed. assert.Empty(t, buf.String()) } + +func TestTextSink_EscapesTabAndNewlineInCells(t *testing.T) { + var buf bytes.Buffer + s := newTextSink(&buf) + require.NoError(t, s.Begin(fields("note"))) + require.NoError(t, s.Row([]any{"a\tb\nc\rd"})) + require.NoError(t, s.End("SELECT 1")) + // The escape replaces tabs/newlines/CR with their backslash-letter forms + // so the tabwriter doesn't treat them as column or row boundaries. + assert.Contains(t, buf.String(), `a\tb\nc\rd`) + assert.NotContains(t, buf.String(), "a\tb") + assert.NotContains(t, buf.String(), "c\rd") +} diff --git a/libs/cmdio/tty.go b/libs/cmdio/tty.go index 40148bb0895..c2607b8909f 100644 --- a/libs/cmdio/tty.go +++ b/libs/cmdio/tty.go @@ -7,6 +7,16 @@ import ( "github.com/mattn/go-isatty" ) +// IsOutputTTY reports whether w is connected to a terminal. Unlike +// SupportsColor this does NOT consult NO_COLOR or TERM=dumb, which are +// colour preferences and not TTY signals. Use this when a command needs +// to decide "should I default to interactive output" or "should I +// auto-fall-back to machine-readable output on a pipe", and use +// SupportsColor only for the colour-rendering decision itself. +func IsOutputTTY(w io.Writer) bool { + return isTTY(w) +} + // isTTY detects if the given reader or writer is a terminal. func isTTY(v any) bool { // Check if it's a fakeTTY first. 
From e81ab278c761501fb2b2058b658563bddbd3bcce Mon Sep 17 00:00:00 2001 From: simon Date: Fri, 1 May 2026 09:01:00 +0200 Subject: [PATCH 11/15] PR 1 lint fix: drop unused provisioned helpers from internal/target This PR only uses autoscaling targeting; provisioned helpers in internal/target/provisioned.go have no caller in PR 1's net diff, which the task deadcode check (run by CI's lint job, not lint-q) correctly flags. Provisioned support lands in PR 2; the necessary subset of helpers (GetProvisioned, ProvisionedCredential) is added there alongside the first caller. Co-authored-by: Isaac --- .../cmd/internal/target/provisioned.go | 66 ------------------- 1 file changed, 66 deletions(-) delete mode 100644 experimental/postgres/cmd/internal/target/provisioned.go diff --git a/experimental/postgres/cmd/internal/target/provisioned.go b/experimental/postgres/cmd/internal/target/provisioned.go deleted file mode 100644 index 261ef37a6a8..00000000000 --- a/experimental/postgres/cmd/internal/target/provisioned.go +++ /dev/null @@ -1,66 +0,0 @@ -package target - -import ( - "context" - "errors" - "fmt" - - "github.com/databricks/databricks-sdk-go" - "github.com/databricks/databricks-sdk-go/service/database" - "github.com/google/uuid" -) - -// ListProvisionedInstances returns all provisioned database instances in the workspace. -func ListProvisionedInstances(ctx context.Context, w *databricks.WorkspaceClient) ([]database.DatabaseInstance, error) { - return w.Database.ListDatabaseInstancesAll(ctx, database.ListDatabaseInstancesRequest{}) -} - -// GetProvisioned fetches a single provisioned instance by name. -// The Name field on the response can be empty; this function ensures it is -// populated from the input so downstream callers do not have to re-set it. 
-func GetProvisioned(ctx context.Context, w *databricks.WorkspaceClient, name string) (*database.DatabaseInstance, error) { - instance, err := w.Database.GetDatabaseInstance(ctx, database.GetDatabaseInstanceRequest{Name: name}) - if err != nil { - return nil, fmt.Errorf("failed to get database instance: %w", err) - } - if instance.Name == "" { - instance.Name = name - } - return instance, nil -} - -// AutoSelectProvisioned returns the only provisioned instance's name (e.g. -// "my-instance"; the database SDK uses flat names, not the "projects/..." -// path shape used by autoscaling). Returns an *AmbiguousError if there are -// multiple, or a plain error if none. -func AutoSelectProvisioned(ctx context.Context, w *databricks.WorkspaceClient) (string, error) { - instances, err := ListProvisionedInstances(ctx, w) - if err != nil { - return "", err - } - if len(instances) == 0 { - return "", errors.New("no Lakebase Provisioned instances found in workspace") - } - if len(instances) == 1 { - return instances[0].Name, nil - } - - choices := make([]Choice, 0, len(instances)) - for _, inst := range instances { - choices = append(choices, Choice{ID: inst.Name}) - } - return "", &AmbiguousError{Kind: KindInstance, FlagHint: "--target", Choices: choices} -} - -// ProvisionedCredential issues a short-lived OAuth token for the provisioned -// instance with the given name. 
-func ProvisionedCredential(ctx context.Context, w *databricks.WorkspaceClient, instanceName string) (string, error) { - cred, err := w.Database.GenerateDatabaseCredential(ctx, database.GenerateDatabaseCredentialRequest{ - InstanceNames: []string{instanceName}, - RequestId: uuid.NewString(), - }) - if err != nil { - return "", fmt.Errorf("failed to get database credentials: %w", err) - } - return cred.Token, nil -} From f714c237979c3f2ddb551d52a6b991e8067d617a Mon Sep 17 00:00:00 2001 From: simon Date: Fri, 1 May 2026 09:01:55 +0200 Subject: [PATCH 12/15] PR 2 lint fix: re-add provisioned.go with only the helpers used here PR 1's lint fix dropped the entire provisioned.go because PR 1 had no caller. Re-add a slim version with just GetProvisioned and ProvisionedCredential (the two functions resolveProvisioned actually calls). Drop ListProvisionedInstances and AutoSelectProvisioned: they were originally intended for cmd/psql interactive selection, but the cmd/psql refactor was reverted, so they have no caller anywhere. Co-authored-by: Isaac --- .../cmd/internal/target/provisioned.go | 37 +++++++++++++++++++ 1 file changed, 37 insertions(+) create mode 100644 experimental/postgres/cmd/internal/target/provisioned.go diff --git a/experimental/postgres/cmd/internal/target/provisioned.go b/experimental/postgres/cmd/internal/target/provisioned.go new file mode 100644 index 00000000000..786e86d2886 --- /dev/null +++ b/experimental/postgres/cmd/internal/target/provisioned.go @@ -0,0 +1,37 @@ +package target + +import ( + "context" + "fmt" + + "github.com/databricks/databricks-sdk-go" + "github.com/databricks/databricks-sdk-go/service/database" + "github.com/google/uuid" +) + +// GetProvisioned fetches a single provisioned instance by name. +// The Name field on the response can be empty; this function ensures it is +// populated from the input so downstream callers do not have to re-set it. 
+func GetProvisioned(ctx context.Context, w *databricks.WorkspaceClient, name string) (*database.DatabaseInstance, error) { + instance, err := w.Database.GetDatabaseInstance(ctx, database.GetDatabaseInstanceRequest{Name: name}) + if err != nil { + return nil, fmt.Errorf("failed to get database instance: %w", err) + } + if instance.Name == "" { + instance.Name = name + } + return instance, nil +} + +// ProvisionedCredential issues a short-lived OAuth token for the provisioned +// instance with the given name. +func ProvisionedCredential(ctx context.Context, w *databricks.WorkspaceClient, instanceName string) (string, error) { + cred, err := w.Database.GenerateDatabaseCredential(ctx, database.GenerateDatabaseCredentialRequest{ + InstanceNames: []string{instanceName}, + RequestId: uuid.NewString(), + }) + if err != nil { + return "", fmt.Errorf("failed to get database credentials: %w", err) + } + return cred.Token, nil +} From 4c4433420d30c4f8faea6347e974de6009aa8a5c Mon Sep 17 00:00:00 2001 From: simon Date: Tue, 5 May 2026 11:08:11 +0200 Subject: [PATCH 13/15] Fix TLS missing in postgres query connect pgx.ParseConfig with an empty host falls back to a unix-socket path and sets TLSConfig=nil. Patching Host after the parse leaves TLSConfig nil, so the connection goes plaintext and Lakebase rejects the pgwire startup ("Invalid protocol version: 196608"). Build the DSN with the real host so pgx derives TLSConfig correctly, and keep user/password/connect-timeout as field patches. 
Co-authored-by: Isaac --- experimental/postgres/cmd/connect.go | 21 +++++++++++++-------- experimental/postgres/cmd/connect_test.go | 6 ++++++ 2 files changed, 19 insertions(+), 8 deletions(-) diff --git a/experimental/postgres/cmd/connect.go b/experimental/postgres/cmd/connect.go index 2eefc681868..a211e19b1ce 100644 --- a/experimental/postgres/cmd/connect.go +++ b/experimental/postgres/cmd/connect.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net" + "net/url" "time" "github.com/databricks/cli/libs/cmdio" @@ -48,20 +49,24 @@ type retryConfig struct { // is exercised by integration tests against real Lakebase endpoints. type connectFunc func(ctx context.Context, cfg *pgx.ConnConfig) (*pgx.Conn, error) -// buildPgxConfig parses a base DSN to inherit pgx's TLS shape, then patches -// in the resolved values. The DSN-then-patch pattern is the recommended way -// to configure pgx for `sslmode=require` because building a pgx.ConnConfig -// by hand omits internal fields that the parser sets. +// buildPgxConfig parses a DSN that includes the real host so pgx derives the +// right TLSConfig and Fallbacks for sslmode=require. An empty host in the DSN +// makes pgx fall back to defaultHost(), which resolves to a unix-socket path. +// pgconn classifies that as a unix socket and assigns TLSConfig=nil; patching +// cfg.Host after the parse does not re-derive TLSConfig, so the connection +// goes out in plaintext and Lakebase rejects the pgwire startup with +// "Invalid protocol version: 196608". User, password, and connect timeout are +// patched as fields because tokens can contain characters that would need +// URL-escaping in userinfo. 
func buildPgxConfig(c connectConfig) (*pgx.ConnConfig, error) { - cfg, err := pgx.ParseConfig("postgresql:///?sslmode=require") + dsn := fmt.Sprintf("postgresql://%s:%d/%s?sslmode=require", + c.Host, c.Port, url.PathEscape(c.Database)) + cfg, err := pgx.ParseConfig(dsn) if err != nil { return nil, fmt.Errorf("parse pgx config: %w", err) } - cfg.Host = c.Host - cfg.Port = uint16(c.Port) cfg.User = c.Username cfg.Password = c.Password - cfg.Database = c.Database cfg.ConnectTimeout = c.ConnectTimeout return cfg, nil } diff --git a/experimental/postgres/cmd/connect_test.go b/experimental/postgres/cmd/connect_test.go index d58fc52cc74..fd294ef2765 100644 --- a/experimental/postgres/cmd/connect_test.go +++ b/experimental/postgres/cmd/connect_test.go @@ -146,4 +146,10 @@ func TestBuildPgxConfig(t *testing.T) { assert.Equal(t, "secret", cfg.Password) assert.Equal(t, "db", cfg.Database) assert.Equal(t, 30*time.Second, cfg.ConnectTimeout) + + // sslmode=require must produce a non-nil TLSConfig for the real host. + // Connecting in plaintext makes Lakebase reject the pgwire startup with + // "Invalid protocol version: 196608". + require.NotNil(t, cfg.TLSConfig, "TLSConfig must be set for sslmode=require") + assert.Equal(t, "host.example.com", cfg.TLSConfig.ServerName) } From a51dc831131b920d5d2d19d73fb8c9dbb4ad593c Mon Sep 17 00:00:00 2001 From: simon Date: Tue, 5 May 2026 13:13:27 +0200 Subject: [PATCH 14/15] Use net.JoinHostPort in pgx DSN to satisfy nosprintfhostport The golangci-lint nosprintfhostport check flags fmt.Sprintf with %s:%d for host:port in URLs. Switch to net.JoinHostPort. 
Co-authored-by: Isaac --- experimental/postgres/cmd/connect.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/experimental/postgres/cmd/connect.go b/experimental/postgres/cmd/connect.go index a211e19b1ce..b2038efac45 100644 --- a/experimental/postgres/cmd/connect.go +++ b/experimental/postgres/cmd/connect.go @@ -6,6 +6,7 @@ import ( "fmt" "net" "net/url" + "strconv" "time" "github.com/databricks/cli/libs/cmdio" @@ -59,8 +60,9 @@ type connectFunc func(ctx context.Context, cfg *pgx.ConnConfig) (*pgx.Conn, erro // patched as fields because tokens can contain characters that would need // URL-escaping in userinfo. func buildPgxConfig(c connectConfig) (*pgx.ConnConfig, error) { - dsn := fmt.Sprintf("postgresql://%s:%d/%s?sslmode=require", - c.Host, c.Port, url.PathEscape(c.Database)) + dsn := fmt.Sprintf("postgresql://%s/%s?sslmode=require", + net.JoinHostPort(c.Host, strconv.Itoa(c.Port)), + url.PathEscape(c.Database)) cfg, err := pgx.ParseConfig(dsn) if err != nil { return nil, fmt.Errorf("parse pgx config: %w", err) From 6757d4e26f01e3592927005a81115f434687f0c9 Mon Sep 17 00:00:00 2001 From: simon Date: Tue, 5 May 2026 14:37:07 +0200 Subject: [PATCH 15/15] Show connecting status as a spinner that clears on success The previous "Connecting to ..." line went to stderr but stayed in the terminal forever, even after results arrived. Use cmdio.NewSpinner so the status disappears once the connection succeeds. 
Co-authored-by: Isaac --- experimental/postgres/cmd/query.go | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/experimental/postgres/cmd/query.go b/experimental/postgres/cmd/query.go index fe5cc528ea7..47b3a00755f 100644 --- a/experimental/postgres/cmd/query.go +++ b/experimental/postgres/cmd/query.go @@ -100,8 +100,6 @@ func runQuery(ctx context.Context, cmd *cobra.Command, sql string, f queryFlags) return err } - cmdio.LogString(ctx, fmt.Sprintf("Connecting to %s...", resolved.DisplayName)) - pgxCfg, err := buildPgxConfig(connectConfig{ Host: resolved.Host, Port: 5432, @@ -120,7 +118,13 @@ func runQuery(ctx context.Context, cmd *cobra.Command, sql string, f queryFlags) MaxDelay: 10 * time.Second, } + // Spinner clears its line on Close, so the "Connecting to ..." status + // disappears once the connection is up. cmdio.NewSpinner already writes + // to stderr and degrades to a no-op in non-interactive terminals. + sp := cmdio.NewSpinner(ctx) + sp.Update("Connecting to " + resolved.DisplayName) conn, err := connectWithRetry(ctx, pgxCfg, rc, pgx.ConnectConfig) + sp.Close() if err != nil { return err }