Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
113 changes: 113 additions & 0 deletions crates/corro-pg/src/codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@ use pgwire::{
api::{self, ClientInfo},
error::PgWireError,
messages as msg,
types::FromSqlText,
};
use postgres_types::Type;
use std::{collections::HashMap, io};
use tokio_util::codec;
use tracing::debug;

pub struct Client {
pub socket_addr: std::net::SocketAddr,
Expand Down Expand Up @@ -169,3 +172,113 @@ where
}
}
}

pub trait VecFromSqlText: Sized {
fn from_vec_sql_text(
ty: &Type,
input: &[u8],
) -> Result<Self, Box<dyn std::error::Error + Sync + Send>>;
}

// Re-implementation of the ToSqlText trait from pg_wire to make it generic over different types.
// Implemented as a macro in pgwire
// https://github.com/sunng87/pgwire/blob/6cbce9d444cc86a01d992f6b35f84c024f10ceda/src/types/from_sql_text.rs#L402
impl<T: FromSqlText> VecFromSqlText for Vec<T> {
fn from_vec_sql_text(
ty: &Type,
input: &[u8],
) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
// PostgreSQL array text format: {elem1,elem2,elem3}
// Remove the outer braces
let input_str = std::str::from_utf8(input)?;

if input_str.is_empty() {
return Ok(Vec::new());
}

// Check if it's an array format
if !input_str.starts_with('{') || !input_str.ends_with('}') {
return Err(format!(
"Invalid array format: must start with '{{' and end with '}}', input: {input_str}"
)
.into());
}

let inner = &input_str[1..input_str.len() - 1];

if inner.is_empty() {
return Ok(Vec::new());
}

let elements = extract_array_elements(inner)?;
let mut result = Vec::new();

for element_str in elements {
let element = T::from_sql_text(ty, element_str.as_bytes())?;
result.push(element);
}

Ok(result)
}
}

// Helper function to extract array elements
// https://github.com/sunng87/pgwire/blob/6cbce9d444cc86a01d992f6b35f84c024f10ceda/src/types/from_sql_text.rs#L402
fn extract_array_elements(
input: &str,
) -> Result<Vec<String>, Box<dyn std::error::Error + Sync + Send>> {
if input.is_empty() {
return Ok(Vec::new());
}

let mut elements = Vec::new();
let mut current = String::new();
let mut in_quotes = false;
let mut escape_next = false;
let mut seen_content = false; // helpful for tracking when the last element is an empty string

for ch in input.chars() {
match ch {
'\\' if !escape_next => {
escape_next = true;
}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Possible bug in pgwire's implementation where we add escape baskslashes to the string was fixed here so it works fine with go's postgres client. I have opened an issue in pgwire

'"' if !escape_next => {
in_quotes = !in_quotes;
// we have seen a new element surrounded by quotes
if !in_quotes {
seen_content = true;
}
// Don't include the quotes in the output
}
'{' if !in_quotes && !escape_next => {
return Err("Nested arrays are not supported".into());
}
'}' if !in_quotes && !escape_next => {
return Err("Nested arrays are not supported".into());
}
',' if !in_quotes && !escape_next => {
// End of current element
if !current.trim().eq_ignore_ascii_case("NULL") {
elements.push(std::mem::take(&mut current));
seen_content = false;
}
}
_ => {
current.push(ch);
escape_next = false;
seen_content = true;
}
}
}

// Process the last element
if seen_content && !current.trim().eq_ignore_ascii_case("NULL") {
elements.push(current);
}

debug!(
"extracted elements: {elements:?} from input: {input}, lenght: {}",
elements.len()
);
Ok(elements)
}
27 changes: 20 additions & 7 deletions crates/corro-pg/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ mod ssl;
pub mod utils;
mod vtab;

use codec::VecFromSqlText;
use eyre::WrapErr;
use std::{
collections::{BTreeSet, HashMap, VecDeque},
Expand Down Expand Up @@ -42,6 +43,7 @@ use pgwire::{
startup::ParameterStatus,
PgWireBackendMessage, PgWireFrontendMessage,
},
types::FromSqlText,
};
use postgres_types::{FromSql, Type};
use rusqlite::{
Expand Down Expand Up @@ -2647,13 +2649,13 @@ fn from_type_and_format<'a, E, T: FromSql<'a> + FromStr<Err = E>>(
})
}

fn from_array_type_and_format<'a, T: FromSql<'a>>(
fn from_array_type_and_format<'a, T: FromSql<'a> + FromSqlText>(
t: &Type,
b: &'a [u8],
format_code: FormatCode,
) -> Result<Vec<T>, ToParamError<String>> {
Ok(match format_code {
FormatCode::Text => panic!("Impossible - arrays are only sent in binary format"),
FormatCode::Text => Vec::<T>::from_vec_sql_text(t, b).map_err(ToParamError::FromSql)?,
FormatCode::Binary => Vec::<T>::from_sql(t, b).map_err(ToParamError::FromSql)?,
})
}
Expand All @@ -2675,7 +2677,7 @@ impl From<UnsupportedSqliteToPostgresType> for ErrorResponse {
}

#[derive(Debug, thiserror::Error)]
#[error("Untyped array argument for unnest(), please use CAST($N AS T) where T is one of: TEXT[] BLOB[] INT[] INTEGER[] BIGINT[] REAL[] FLOAT[] DOUBLE[]")]
#[error("Untyped array argument for unnest() (or corro_unnest()), please use CAST($N AS T) where T is one of: TEXT[] BLOB[] INT[] INTEGER[] BIGINT[] REAL[] FLOAT[] DOUBLE[]")]
struct UntypedUnnestParameter;

impl From<UntypedUnnestParameter> for PgWireBackendMessage {
Expand Down Expand Up @@ -2902,10 +2904,16 @@ fn extract_params<'schema, 'stmt>(
Expr::FunctionCall {
name: _,
distinctness: _,
args: _,
args,
filter_over: _,
order_by: _,
} => {}
} => {
if let Some(args) = args {
for expr in args.iter() {
extract_params(schema, expr, tables, params)?
}
}
}

Expr::FunctionCallStar {
name: _,
Expand Down Expand Up @@ -3105,7 +3113,8 @@ fn handle_table_call_params<'schema, 'stmt>(
params: &mut ParamsList<'stmt, 'schema>,
) -> Result<(), UntypedUnnestParameter> {
if let Some(exprs) = args {
let is_unnest = qname.name.0.eq_ignore_ascii_case("UNNEST");
let is_unnest = qname.name.0.eq_ignore_ascii_case("CORRO_UNNEST")
|| qname.name.0.eq_ignore_ascii_case("UNNEST");

for expr in exprs.iter() {
// If not unnest, just extract params
Expand Down Expand Up @@ -3352,7 +3361,11 @@ fn parameter_types<'schema, 'stmt>(

let mut tables = HashMap::new();
if let Some(tbl) = schema.tables.get(&tbl_name.name.0) {
tables.insert(tbl_name.name.0.clone(), tbl);
if let Some(alias) = &tbl_name.alias {
tables.insert(alias.0.clone(), tbl);
} else {
tables.insert(tbl_name.name.0.clone(), tbl);
}
}
if let Some(where_clause) = where_clause {
extract_params(schema, where_clause, &tables, &mut params)?;
Expand Down
Loading
Loading