Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -4542,6 +4542,17 @@ func (c *clientConn) handleExecute(body []byte) {
rowCount++
}

if err := rows.Err(); err != nil {
if isQueryCancelled(err) {
c.sendError("ERROR", "57014", "canceling statement due to user request")
} else {
slog.Error("Row iteration error.", "user", c.username, "error", err)
c.sendError("ERROR", "42000", err.Error())
}
c.setTxError()
return
}

c.updateTxStatus(cmdType)
tag := buildCommandTagFromRowCount(cmdType, int64(rowCount))
_ = writeCommandComplete(c.writer, tag)
Expand Down
44 changes: 22 additions & 22 deletions server/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -156,65 +156,65 @@ func mapDuckDBType(typeName string) TypeInfo {
return TypeInfo{OID: arrayOID, Size: -1, Typmod: elemInfo.Typmod}
}
// Unknown element type — fall through to text
return TypeInfo{OID: OidText, Size: -1}
return TypeInfo{OID: OidText, Size: -1, Typmod: -1}
}

switch {
case upper == "BOOLEAN" || upper == "BOOL":
return TypeInfo{OID: OidBool, Size: 1}
return TypeInfo{OID: OidBool, Size: 1, Typmod: -1}
case upper == "TINYINT" || upper == "INT1":
return TypeInfo{OID: OidInt2, Size: 2} // PostgreSQL doesn't have int1
return TypeInfo{OID: OidInt2, Size: 2, Typmod: -1} // PostgreSQL doesn't have int1
case upper == "SMALLINT" || upper == "INT2":
return TypeInfo{OID: OidInt2, Size: 2}
return TypeInfo{OID: OidInt2, Size: 2, Typmod: -1}
case upper == "INTEGER" || upper == "INT4" || upper == "INT":
return TypeInfo{OID: OidInt4, Size: 4}
return TypeInfo{OID: OidInt4, Size: 4, Typmod: -1}
case upper == "BIGINT" || upper == "INT8":
return TypeInfo{OID: OidInt8, Size: 8}
return TypeInfo{OID: OidInt8, Size: 8, Typmod: -1}
case upper == "HUGEINT" || upper == "INT128":
// Map to NUMERIC(38,0) so postgres_scanner reads it as DECIMAL(38,0) → INT128,
// matching the HUGEINT physical type. Typmod = ((38 << 16) | 0) + 4 = 2490372.
return TypeInfo{OID: OidNumeric, Size: -1, Typmod: 2490372}
case upper == "UTINYINT" || upper == "USMALLINT":
return TypeInfo{OID: OidInt4, Size: 4}
return TypeInfo{OID: OidInt4, Size: 4, Typmod: -1}
case upper == "UINTEGER":
return TypeInfo{OID: OidOid, Size: 4} // PostgreSQL oid type for pg_catalog columns
return TypeInfo{OID: OidOid, Size: 4, Typmod: -1} // PostgreSQL oid type for pg_catalog columns
case upper == "UBIGINT":
// Map to NUMERIC(20,0) so postgres_scanner reads it as DECIMAL(20,0) → INT128.
// UBIGINT max (2^64-1 = 18446744073709551615) is 20 digits.
// Without typmod, the extension can't determine buffer size and fails with
// "out of buffer in ReadInteger". Typmod = ((20 << 16) | 0) + 4 = 1310724.
return TypeInfo{OID: OidNumeric, Size: -1, Typmod: 1310724}
case upper == "REAL" || upper == "FLOAT4" || upper == "FLOAT":
return TypeInfo{OID: OidFloat4, Size: 4}
return TypeInfo{OID: OidFloat4, Size: 4, Typmod: -1}
case upper == "DOUBLE" || upper == "FLOAT8":
return TypeInfo{OID: OidFloat8, Size: 8}
return TypeInfo{OID: OidFloat8, Size: 8, Typmod: -1}
case strings.HasPrefix(upper, "DECIMAL") || strings.HasPrefix(upper, "NUMERIC"):
return TypeInfo{OID: OidNumeric, Size: -1, Typmod: parseNumericTypmod(typeName)}
case upper == "VARCHAR" || strings.HasPrefix(upper, "VARCHAR("):
return TypeInfo{OID: OidVarchar, Size: -1}
return TypeInfo{OID: OidVarchar, Size: -1, Typmod: -1}
case upper == "TEXT" || upper == "STRING":
return TypeInfo{OID: OidText, Size: -1}
return TypeInfo{OID: OidText, Size: -1, Typmod: -1}
case upper == "BLOB" || upper == "BYTEA":
return TypeInfo{OID: OidBytea, Size: -1}
return TypeInfo{OID: OidBytea, Size: -1, Typmod: -1}
case upper == "DATE":
return TypeInfo{OID: OidDate, Size: 4}
return TypeInfo{OID: OidDate, Size: 4, Typmod: -1}
case upper == "TIME":
return TypeInfo{OID: OidTime, Size: 8}
return TypeInfo{OID: OidTime, Size: 8, Typmod: -1}
case upper == "TIME WITH TIME ZONE" || upper == "TIMETZ":
return TypeInfo{OID: OidTimetz, Size: 12}
return TypeInfo{OID: OidTimetz, Size: 12, Typmod: -1}
case upper == "TIMESTAMP":
return TypeInfo{OID: OidTimestamp, Size: 8}
return TypeInfo{OID: OidTimestamp, Size: 8, Typmod: -1}
case upper == "TIMESTAMP WITH TIME ZONE" || upper == "TIMESTAMPTZ":
return TypeInfo{OID: OidTimestamptz, Size: 8}
return TypeInfo{OID: OidTimestamptz, Size: 8, Typmod: -1}
case upper == "INTERVAL":
return TypeInfo{OID: OidInterval, Size: 16}
return TypeInfo{OID: OidInterval, Size: 16, Typmod: -1}
case upper == "UUID":
return TypeInfo{OID: OidUUID, Size: 16}
return TypeInfo{OID: OidUUID, Size: 16, Typmod: -1}
case upper == "JSON":
return TypeInfo{OID: OidJSON, Size: -1}
return TypeInfo{OID: OidJSON, Size: -1, Typmod: -1}
default:
// Default to text for unknown types
return TypeInfo{OID: OidText, Size: -1}
return TypeInfo{OID: OidText, Size: -1, Typmod: -1}
}
}

Expand Down
61 changes: 61 additions & 0 deletions server/types_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,67 @@ func TestMapDuckDBType(t *testing.T) {
}
}

func TestMapDuckDBTypeTypmod(t *testing.T) {
// PostgreSQL uses typmod=-1 to mean "no modifier". JDBC clients (pgjdbc)
// interpret typmod=0 differently from -1. For example, INTERVAL with typmod=0
// means "second precision 0" (no fractional seconds), while typmod=-1 means
// default precision (microseconds). Sending the wrong typmod breaks JDBC
// metadata and can cause client-side errors.
tests := []struct {
typeName string
wantTypmod int32
}{
// Types without modifiers should have typmod=-1
{"BOOLEAN", -1},
{"TINYINT", -1},
{"SMALLINT", -1},
{"INTEGER", -1},
{"BIGINT", -1},
{"REAL", -1},
{"DOUBLE", -1},
{"VARCHAR", -1},
{"TEXT", -1},
{"BYTEA", -1},
{"DATE", -1},
{"TIME", -1},
{"TIMETZ", -1},
{"TIMESTAMP", -1},
{"TIMESTAMPTZ", -1},
{"INTERVAL", -1},
{"UUID", -1},
{"JSON", -1},
{"UTINYINT", -1},
{"USMALLINT", -1},
{"UINTEGER", -1},
// NUMERIC without precision should have typmod=-1
{"NUMERIC", -1},
{"DECIMAL", -1},
// NUMERIC with precision should have a positive typmod
{"DECIMAL(10,2)", int32((10<<16)|2) + 4},
// HUGEINT and UBIGINT have specific typmods for postgres_scanner
{"HUGEINT", int32(38<<16) + 4},
{"UBIGINT", int32(20<<16) + 4},
// Aliases and long-form type names
{"STRING", -1},
{"TIME WITH TIME ZONE", -1},
{"TIMESTAMP WITH TIME ZONE", -1},
// Array types inherit element typmod
{"INTEGER[]", -1},
{"INTERVAL[]", -1},
// Unknown types should default to typmod=-1
{"SOMECUSTOMTYPE", -1},
}

for _, tt := range tests {
t.Run(tt.typeName, func(t *testing.T) {
info := mapDuckDBType(tt.typeName)
if info.Typmod != tt.wantTypmod {
t.Errorf("mapDuckDBType(%q).Typmod = %d, want %d", tt.typeName, info.Typmod, tt.wantTypmod)
}
})
}
}

func TestEncodeBool(t *testing.T) {
tests := []struct {
name string
Expand Down
63 changes: 63 additions & 0 deletions tests/integration/clients/clients_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package clients

import (
"context"
"database/sql"
"fmt"
"os"
Expand Down Expand Up @@ -564,6 +565,68 @@ func TestPreparedStatements(t *testing.T) {
})
}

// TestExtendedQueryErrorHandling verifies that errors during the extended query
// protocol (prepared statements) are properly reported to the client and don't
// corrupt the connection state.
func TestExtendedQueryErrorHandling(t *testing.T) {
t.Run("error_during_prepared_execution", func(t *testing.T) {
// Use a single connection so we can verify the SAME connection survives errors.
conn, err := testDB.Conn(context.Background())
if err != nil {
t.Fatalf("Conn failed: %v", err)
}
defer func() { _ = conn.Close() }()

// CAST($1 AS INTEGER) with a non-numeric string uses the extended query
// protocol (Parse/Bind/Execute) and produces a conversion error at execution time.
var val int
err = conn.QueryRowContext(context.Background(), "SELECT CAST($1 AS INTEGER)", "not-a-number").Scan(&val)
if err == nil {
t.Fatal("Expected conversion error, got none")
}

// The SAME connection should still be usable after the error.
// If handleExecute didn't properly handle the error (e.g., protocol desync),
// this would fail or hang.
var result int
err = conn.QueryRowContext(context.Background(), "SELECT $1::int", 42).Scan(&result)
if err != nil {
t.Fatalf("Connection unusable after error: %v", err)
}
if result != 42 {
t.Errorf("Expected 42, got %d", result)
}
})

t.Run("error_recovery_multi_cycle", func(t *testing.T) {
// Run multiple error/success cycles on the same connection.
conn, err := testDB.Conn(context.Background())
if err != nil {
t.Fatalf("Conn failed: %v", err)
}
defer func() { _ = conn.Close() }()

for i := 0; i < 3; i++ {
// Failing query via extended protocol
var badVal int
err := conn.QueryRowContext(context.Background(), "SELECT CAST($1 AS INTEGER)", "bad").Scan(&badVal)
if err == nil {
t.Fatalf("Expected error on iteration %d", i)
}

// Successful query via extended protocol on the same connection
var val int
err = conn.QueryRowContext(context.Background(), "SELECT $1::int + 1", i).Scan(&val)
if err != nil {
t.Fatalf("Success query failed on iteration %d: %v", i, err)
}
if val != i+1 {
t.Errorf("Expected %d, got %d on iteration %d", i+1, val, i)
}
}
})
}

// TestTransactions tests transaction handling
func TestTransactions(t *testing.T) {
t.Run("commit", func(t *testing.T) {
Expand Down