Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 101 additions & 0 deletions crates/corro-pg/src/codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@ use pgwire::{
api::{self, ClientInfo},
error::PgWireError,
messages as msg,
types::FromSqlText,
};
use postgres_types::Type;
use std::{collections::HashMap, io};
use tokio_util::codec;

Expand Down Expand Up @@ -169,3 +171,102 @@ where
}
}
}

/// Decoding of a whole array value from its PostgreSQL text-format bytes.
pub trait VecFromSqlText: Sized {
/// Builds `Self` from `input`, the raw text-format bytes of an array value.
///
/// `ty` is the PostgreSQL type forwarded to the element parser.
///
/// # Errors
///
/// Returns a boxed error when `input` cannot be decoded.
fn from_vec_sql_text(
ty: &Type,
input: &[u8],
) -> Result<Self, Box<dyn std::error::Error + Sync + Send>>;
}

// Re-implementation of pgwire's `FromSqlText` array parsing, made generic over the element type.
// In pgwire this is implemented as a macro:
// https://github.com/sunng87/pgwire/blob/6cbce9d444cc86a01d992f6b35f84c024f10ceda/src/types/from_sql_text.rs#L402
impl<T: FromSqlText> VecFromSqlText for Vec<T> {
    /// Parses a text-format array such as `{elem1,elem2,elem3}` into a `Vec<T>`.
    ///
    /// Each extracted element is handed to `T::from_sql_text` together with `ty`.
    /// Empty input and `{}` both produce an empty vector.
    ///
    /// # Errors
    ///
    /// Fails when `input` is not valid UTF-8, when it is non-empty but not
    /// wrapped in `{` ... `}`, or when an element fails to parse as `T`.
    fn from_vec_sql_text(
        ty: &Type,
        input: &[u8],
    ) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
        let text = std::str::from_utf8(input)?;

        // An empty payload is treated as an empty array.
        if text.is_empty() {
            return Ok(Vec::new());
        }

        // Peel off the outer braces; anything not wrapped in them is rejected.
        let inner = match text
            .strip_prefix('{')
            .and_then(|rest| rest.strip_suffix('}'))
        {
            Some(inner) => inner,
            None => {
                return Err("Invalid array format: must start with '{' and end with '}'".into());
            }
        };

        if inner.is_empty() {
            return Ok(Vec::new());
        }

        // Parse element by element, stopping at the first failure.
        let mut parsed = Vec::new();
        for raw in extract_array_elements(inner)? {
            parsed.push(T::from_sql_text(ty, raw.as_bytes())?);
        }

        Ok(parsed)
    }
}

// Helper function to extract array elements
// https://github.com/sunng87/pgwire/blob/6cbce9d444cc86a01d992f6b35f84c024f10ceda/src/types/from_sql_text.rs#L402
/// Splits the contents of a PostgreSQL text-format array (outer braces already
/// removed) into its top-level element strings.
///
/// * A comma separates elements only outside double quotes and outside nested
///   `{...}` groups.
/// * Double quotes delimit an element's text and are not copied to the output.
/// * A backslash makes the following character literal; the backslash itself
///   is dropped.
/// * An unquoted `NULL` (any ASCII case, surrounding whitespace ignored) is
///   skipped, while a quoted `"NULL"` is kept as the literal string.
/// * An empty trailing element is kept only when it was written quoted (`""`).
///
/// The function currently never returns `Err`; the `Result` return type is
/// kept so callers can continue to use `?`.
fn extract_array_elements(
    input: &str,
) -> Result<Vec<String>, Box<dyn std::error::Error + Sync + Send>> {
    if input.is_empty() {
        return Ok(Vec::new());
    }

    let mut elements = Vec::new();
    let mut current = String::new();
    let mut in_quotes = false;
    // True once the element being built has contained a double quote; a quoted
    // element is always a literal value, never SQL NULL.
    let mut quoted = false;
    let mut escape_next = false;
    let mut depth: i32 = 0; // For nested arrays

    for ch in input.chars() {
        match ch {
            // Review note from the PR: unlike pgwire's implementation, the
            // escaping backslash is NOT copied into the element, so values sent
            // by Go's postgres client decode correctly.
            '\\' if !escape_next => {
                escape_next = true;
            }
            '"' if !escape_next => {
                in_quotes = !in_quotes;
                quoted = true;
                // Don't include the quotes in the output
            }
            '{' if !in_quotes && !escape_next => {
                depth += 1;
                current.push(ch);
            }
            '}' if !in_quotes && !escape_next => {
                depth -= 1;
                current.push(ch);
            }
            ',' if !in_quotes && depth == 0 && !escape_next => {
                // End of current element. Always reset the buffer, even when
                // the element is a skipped NULL, so its text cannot leak into
                // the next element.
                let element = std::mem::take(&mut current);
                if quoted || !element.trim().eq_ignore_ascii_case("NULL") {
                    elements.push(element);
                }
                quoted = false;
            }
            _ => {
                current.push(ch);
                escape_next = false;
            }
        }
    }

    // Process the last element: keep it when it was quoted (even if empty), or
    // when it is non-empty and not an unquoted NULL.
    if quoted || (!current.is_empty() && !current.trim().eq_ignore_ascii_case("NULL")) {
        elements.push(current);
    }

    Ok(elements)
}
11 changes: 7 additions & 4 deletions crates/corro-pg/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ mod ssl;
pub mod utils;
mod vtab;

use codec::VecFromSqlText;
use eyre::WrapErr;
use std::{
collections::{BTreeSet, HashMap, VecDeque},
Expand Down Expand Up @@ -42,6 +43,7 @@ use pgwire::{
startup::ParameterStatus,
PgWireBackendMessage, PgWireFrontendMessage,
},
types::FromSqlText,
};
use postgres_types::{FromSql, Type};
use rusqlite::{
Expand Down Expand Up @@ -2647,13 +2649,13 @@ fn from_type_and_format<'a, E, T: FromSql<'a> + FromStr<Err = E>>(
})
}

fn from_array_type_and_format<'a, T: FromSql<'a>>(
fn from_array_type_and_format<'a, T: FromSql<'a> + FromSqlText>(
t: &Type,
b: &'a [u8],
format_code: FormatCode,
) -> Result<Vec<T>, ToParamError<String>> {
Ok(match format_code {
FormatCode::Text => panic!("Impossible - arrays are only sent in binary format"),
FormatCode::Text => Vec::<T>::from_vec_sql_text(t, b).map_err(ToParamError::FromSql)?,
FormatCode::Binary => Vec::<T>::from_sql(t, b).map_err(ToParamError::FromSql)?,
})
}
Expand All @@ -2675,7 +2677,7 @@ impl From<UnsupportedSqliteToPostgresType> for ErrorResponse {
}

#[derive(Debug, thiserror::Error)]
#[error("Untyped array argument for unnest(), please use CAST($N AS T) where T is one of: TEXT[] BLOB[] INT[] INTEGER[] BIGINT[] REAL[] FLOAT[] DOUBLE[]")]
#[error("Untyped array argument for unnest() (or corro_unnest()), please use CAST($N AS T) where T is one of: TEXT[] BLOB[] INT[] INTEGER[] BIGINT[] REAL[] FLOAT[] DOUBLE[]")]
struct UntypedUnnestParameter;

impl From<UntypedUnnestParameter> for PgWireBackendMessage {
Expand Down Expand Up @@ -3105,7 +3107,8 @@ fn handle_table_call_params<'schema, 'stmt>(
params: &mut ParamsList<'stmt, 'schema>,
) -> Result<(), UntypedUnnestParameter> {
if let Some(exprs) = args {
let is_unnest = qname.name.0.eq_ignore_ascii_case("UNNEST");
let is_unnest = qname.name.0.eq_ignore_ascii_case("CORRO_UNNEST")
|| qname.name.0.eq_ignore_ascii_case("UNNEST");

for expr in exprs.iter() {
// If not unnest, just extract params
Expand Down
99 changes: 75 additions & 24 deletions crates/corro-pg/tests/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ use corro_types::{
config::{PgConfig, PgTlsConfig},
tls::{generate_ca, generate_client_cert, generate_server_cert},
};
use postgres_types::ToSql;
use pgwire::types::ToSqlText;
use postgres_types::{Format, IsNull, ToSql, Type};
use rcgen::Certificate;
use rustls::pki_types::pem::PemObject;
use spawn::wait_for_all_pending_handles;
Expand Down Expand Up @@ -805,6 +806,43 @@ async fn test_unnest_typing() {
wait_for_all_pending_handles().await;
}

// wrapper so we can easily switch between text and binary formats
#[derive(Debug)]
struct SqlVec<'a, T> {
// Borrowed values that are sent as the array parameter.
inner: &'a Vec<T>,
// Wire format reported by `encode_format` and used by `to_sql` to pick the encoder.
format: Format,
}

// test text encoding/decoding
impl<'a, T: ToSqlText + ToSql> ToSql for SqlVec<'a, T> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Love this approach

/// Encodes the wrapped vector into `out`, using the binary encoder
/// (`ToSql`) or the text encoder (`ToSqlText`) depending on `self.format`.
fn to_sql(
    &self,
    ty: &Type,
    out: &mut bytes::BytesMut,
) -> Result<IsNull, Box<dyn std::error::Error + Sync + Send>> {
    let values = self.inner;
    match self.format {
        Format::Binary => values.to_sql(ty, out),
        Format::Text => values.to_sql_text(ty, out),
    }
}

/// Returns `true` only when `ty` is an array type; every other kind is rejected.
fn accepts(ty: &postgres_types::Type) -> bool
where
    Self: Sized,
{
    // `matches!` replaces the `match` that returned `true`/`false` literals.
    matches!(ty.kind(), postgres_types::Kind::Array(_))
}

/// Reports the format stored in this wrapper; the parameter type is ignored.
fn encode_format(&self, _ty: &Type) -> Format {
self.format
}

postgres_types::to_sql_checked!();
}

#[tokio::test(flavor = "multi_thread")]
async fn test_unnest_max_parameters() {
let (tripwire, tripwire_worker, tripwire_tx) = Tripwire::new_simple();
Expand Down Expand Up @@ -925,37 +963,50 @@ async fn test_unnest_vtab() {

// Test single array unnest with text type
{
let col1 = vec!["a", "b", "c", "d", "e", "f"];
let rows = client
.query(
"SELECT CAST(value0 AS text) FROM unnest(CAST($1 AS text[]))",
&[&col1],
)
.await
.unwrap();
for (i, row) in rows.iter().enumerate() {
let val: String = row.get(0);
assert_eq!(val, col1[i]);
for format in [Format::Text, Format::Binary] {
let col1 = vec!["a", "b", "c", "d", "e", "f"];
let sql_vec = SqlVec {
inner: &col1,
format,
};
let rows = client
.query(
"SELECT CAST(value0 AS text) FROM unnest(CAST($1 AS text[]))",
&[&sql_vec],
)
.await
.unwrap();
for (i, row) in rows.iter().enumerate() {
let val: String = row.get(0);
assert_eq!(val, col1[i]);
}
}
}

// Test single array unnest with float type
{
let col1 = vec![1.0, 2.0, 3.0, 4.0, 1337.0, 12312312312.0];
let rows = client
.query(
"SELECT CAST(value0 AS float) FROM unnest(CAST($1 AS float[]))",
&[&col1],
)
.await
.unwrap();
for (i, row) in rows.iter().enumerate() {
let val: f64 = row.get(0);
assert_eq!(val, col1[i]);
for format in [Format::Text, Format::Binary] {
let sql_vec = SqlVec {
inner: &col1,
format,
};
let rows = client
.query(
"SELECT CAST(value0 AS float) FROM unnest(CAST($1 AS float[]))",
&[&sql_vec],
)
.await
.unwrap();
for (i, row) in rows.iter().enumerate() {
let val: f64 = row.get(0);
assert_eq!(val, col1[i]);
}
}
}

// Test single array unnest with blob type
// TODO: pgwire's text encoding for blob[] is currently broken, but we would work with clients that encode it properly
{
let col1 = vec![b"a", b"b", b"c", b"d", b"e", b"f"];
let rows = client
Expand All @@ -971,7 +1022,7 @@ async fn test_unnest_vtab() {
}
}

// Now try all at once with different types
// Now try all at once with different types, use corro_unnest
{
let col1 = vec![1i64, 2, 3, 4, 1337, 12312312312];
let col2 = vec!["a", "b", "c", "d", "e", "f"];
Expand All @@ -980,7 +1031,7 @@ async fn test_unnest_vtab() {
let rows = client
.query(
"SELECT
CAST(value0 AS int), CAST(value1 AS text), CAST(value2 AS float), CAST(value3 AS blob) FROM unnest(CAST($1 AS int[]), CAST($2 AS text[]), CAST($3 AS float[]), CAST($4 AS blob[]))",
CAST(value0 AS int), CAST(value1 AS text), CAST(value2 AS float), CAST(value3 AS blob) FROM corro_unnest(CAST($1 AS int[]), CAST($2 AS text[]), CAST($3 AS float[]), CAST($4 AS blob[]))",
&[&col1, &col2, &col3, &col4],
)
.await
Expand Down
1 change: 1 addition & 0 deletions crates/corro-types/src/sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ pub fn setup_conn(conn: &Connection) -> Result<(), rusqlite::Error> {

// Register unnest for PostgreSQL-style multi-array unnesting
conn.create_module("unnest", eponymous_only_module::<UnnestTab>(), None)?;
conn.create_module("corro_unnest", eponymous_only_module::<UnnestTab>(), None)?;

Ok(())
}
Expand Down
Loading