Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 87 additions & 0 deletions include/lbug_rs.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,9 @@ struct TypeListBuilder {
};

std::unique_ptr<TypeListBuilder> create_type_list();
inline void type_list_insert(TypeListBuilder& list, std::unique_ptr<lbug::common::LogicalType> type) {
list.insert(std::move(type));
}

struct QueryParams {
std::unordered_map<std::string, std::unique_ptr<lbug::common::Value>> inputParams;
Expand All @@ -42,6 +45,10 @@ struct QueryParams {
};

std::unique_ptr<QueryParams> new_params();
inline void query_params_insert(QueryParams& params, const rust::Str key,
std::unique_ptr<lbug::common::Value> value) {
params.insert(key, std::move(value));
}

std::unique_ptr<lbug::common::LogicalType> create_logical_type(lbug::common::LogicalTypeID id);
std::unique_ptr<lbug::common::LogicalType> create_logical_type_list(
Expand Down Expand Up @@ -93,6 +100,10 @@ inline uint32_t logical_type_get_decimal_precision(const lbug::common::LogicalTy
inline uint32_t logical_type_get_decimal_scale(const lbug::common::LogicalType& logicalType) {
return lbug::common::DecimalType::getScale(logicalType);
}
inline lbug::common::LogicalTypeID logical_type_get_logical_type_id(
const lbug::common::LogicalType& logicalType) {
return logicalType.getLogicalTypeID();
}

/* Database */
std::unique_ptr<lbug::main::Database> new_database(std::string_view databasePath,
Expand All @@ -110,20 +121,56 @@ inline std::unique_ptr<lbug::main::QueryResult> connection_query(lbug::main::Con
std::string_view query) {
return connection.query(query);
}
inline std::unique_ptr<lbug::main::PreparedStatement> connection_prepare(
lbug::main::Connection& connection, std::string_view query) {
return connection.prepare(query);
}
inline uint64_t connection_get_max_num_thread_for_exec(lbug::main::Connection& connection) {
return connection.getMaxNumThreadForExec();
}
inline void connection_set_max_num_thread_for_exec(lbug::main::Connection& connection,
uint64_t numThreads) {
connection.setMaxNumThreadForExec(numThreads);
}
inline void connection_interrupt(lbug::main::Connection& connection) {
connection.interrupt();
}
inline void connection_set_query_timeout(lbug::main::Connection& connection, uint64_t timeoutMs) {
connection.setQueryTimeOut(timeoutMs);
}

/* PreparedStatement */
rust::String prepared_statement_error_message(const lbug::main::PreparedStatement& statement);
inline lbug::common::StatementType prepared_statement_get_statement_type(
const lbug::main::PreparedStatement& statement) {
return statement.getStatementType();
}
inline bool prepared_statement_is_success(const lbug::main::PreparedStatement& statement) {
return statement.isSuccess();
}

/* QueryResult */
rust::String query_result_to_string(const lbug::main::QueryResult& result);
rust::String query_result_get_error_message(const lbug::main::QueryResult& result);
inline bool query_result_is_success(const lbug::main::QueryResult& result) {
return result.isSuccess();
}
inline bool query_result_has_next(const lbug::main::QueryResult& result) {
return result.hasNext();
}
inline std::shared_ptr<lbug::processor::FlatTuple> query_result_get_next(
lbug::main::QueryResult& result) {
return result.getNext();
}

double query_result_get_compiling_time(const lbug::main::QueryResult& result);
double query_result_get_execution_time(const lbug::main::QueryResult& result);
inline size_t query_result_get_num_columns(const lbug::main::QueryResult& result) {
return result.getNumColumns();
}
inline uint64_t query_result_get_num_tuples(const lbug::main::QueryResult& result) {
return result.getNumTuples();
}

std::unique_ptr<std::vector<lbug::common::LogicalType>> query_result_column_data_types(
const lbug::main::QueryResult& query_result);
Expand Down Expand Up @@ -156,6 +203,7 @@ const lbug::common::Value& recursive_rel_get_nodes(const lbug::common::Value& va
const lbug::common::Value& recursive_rel_get_rels(const lbug::common::Value& val);

/* FlatTuple */
uint32_t flat_tuple_len(const lbug::processor::FlatTuple& flatTuple);
const lbug::common::Value& flat_tuple_get_value(const lbug::processor::FlatTuple& flatTuple,
uint32_t index);

Expand Down Expand Up @@ -185,6 +233,42 @@ inline lbug::common::PhysicalTypeID value_get_physical_type(const lbug::common::
return value.getDataType().getPhysicalType();
}
rust::String value_to_string(const lbug::common::Value& val);
inline bool value_get_bool(const lbug::common::Value& value) {
return value.getValue<bool>();
}
inline int8_t value_get_i8(const lbug::common::Value& value) {
return value.getValue<int8_t>();
}
inline int16_t value_get_i16(const lbug::common::Value& value) {
return value.getValue<int16_t>();
}
inline int32_t value_get_i32(const lbug::common::Value& value) {
return value.getValue<int32_t>();
}
inline int64_t value_get_i64(const lbug::common::Value& value) {
return value.getValue<int64_t>();
}
inline uint8_t value_get_u8(const lbug::common::Value& value) {
return value.getValue<uint8_t>();
}
inline uint16_t value_get_u16(const lbug::common::Value& value) {
return value.getValue<uint16_t>();
}
inline uint32_t value_get_u32(const lbug::common::Value& value) {
return value.getValue<uint32_t>();
}
inline uint64_t value_get_u64(const lbug::common::Value& value) {
return value.getValue<uint64_t>();
}
inline float value_get_float(const lbug::common::Value& value) {
return value.getValue<float>();
}
inline double value_get_double(const lbug::common::Value& value) {
return value.getValue<double>();
}
inline bool value_is_null(const lbug::common::Value& value) {
return value.isNull();
}

std::unique_ptr<lbug::common::Value> create_value_string(lbug::common::LogicalTypeID typ,
const rust::Slice<const unsigned char> value);
Expand Down Expand Up @@ -237,6 +321,9 @@ struct ValueListBuilder {
std::unique_ptr<lbug::common::Value> get_list_value(std::unique_ptr<lbug::common::LogicalType> typ,
std::unique_ptr<ValueListBuilder> value);
std::unique_ptr<ValueListBuilder> create_list();
inline void value_list_insert(ValueListBuilder& list, std::unique_ptr<lbug::common::Value> value) {
list.insert(std::move(value));
}

inline std::string_view string_view_from_str(rust::Str s) {
return {s.data(), s.size()};
Expand Down
25 changes: 12 additions & 13 deletions src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,15 +98,12 @@ impl<'a> Connection<'a> {
/// # Arguments
/// * `num_threads`: The maximum number of threads to use for execution in the current connection
pub fn set_max_num_threads_for_exec(&mut self, num_threads: u64) {
self.conn
.get_mut()
.pin_mut()
.setMaxNumThreadForExec(num_threads);
ffi::connection_set_max_num_thread_for_exec(self.conn.get_mut().pin_mut(), num_threads);
}

/// Returns the maximum number of threads used for execution in the current connection
pub fn get_max_num_threads_for_exec(&self) -> u64 {
unsafe { (*self.conn.get()).pin_mut().getMaxNumThreadForExec() }
ffi::connection_get_max_num_thread_for_exec(unsafe { (*self.conn.get()).pin_mut() })
}

/// Prepares the given query and returns the prepared statement. [`PreparedStatement`]s can be run
Expand All @@ -116,9 +113,11 @@ impl<'a> Connection<'a> {
/// * `query`: The query to prepare. See <https://ladybugdb.com/docs/cypher> for details on the
/// query format.
pub fn prepare(&self, query: &str) -> Result<PreparedStatement, Error> {
let statement =
unsafe { (*self.conn.get()).pin_mut() }.prepare(ffi::StringView::new(query))?;
if statement.isSuccess() {
let statement = ffi::connection_prepare(
unsafe { (*self.conn.get()).pin_mut() },
ffi::StringView::new(query),
)?;
if ffi::prepared_statement_is_success(&statement) {
Ok(PreparedStatement { statement })
} else {
Err(Error::FailedPreparedStatement(
Expand All @@ -143,7 +142,7 @@ impl<'a> Connection<'a> {
pub fn query(&self, query: &str) -> Result<QueryResult<'a>, Error> {
let conn = unsafe { (*self.conn.get()).pin_mut() };
let result = ffi::connection_query(conn, ffi::StringView::new(query))?;
if result.isSuccess() {
if ffi::query_result_is_success(&result) {
Ok(QueryResult { result })
} else {
Err(Error::FailedQuery(ffi::query_result_get_error_message(
Expand Down Expand Up @@ -186,12 +185,12 @@ impl<'a> Connection<'a> {
let mut cxx_params = ffi::new_params();
for (key, value) in params {
let ffi_value: cxx::UniquePtr<ffi::Value> = value.try_into()?;
cxx_params.pin_mut().insert(key, ffi_value);
ffi::query_params_insert(cxx_params.pin_mut(), key, ffi_value);
}
let conn = unsafe { (*self.conn.get()).pin_mut() };
let result =
ffi::connection_execute(conn, prepared_statement.statement.pin_mut(), cxx_params)?;
if result.isSuccess() {
if ffi::query_result_is_success(&result) {
Ok(QueryResult { result })
} else {
Err(Error::FailedQuery(ffi::query_result_get_error_message(
Expand All @@ -203,15 +202,15 @@ impl<'a> Connection<'a> {
/// Interrupts all queries currently executing within this connection
pub fn interrupt(&self) -> Result<(), Error> {
let conn = unsafe { (*self.conn.get()).pin_mut() };
Ok(conn.interrupt()?)
Ok(ffi::connection_interrupt(conn)?)
}

/// Sets the query timeout value of the current connection
///
/// A value of zero (the default) disables the timeout.
pub fn set_query_timeout(&self, timeout_ms: u64) {
let conn = unsafe { (*self.conn.get()).pin_mut() };
conn.setQueryTimeOut(timeout_ms);
ffi::connection_set_query_timeout(conn, timeout_ms);
}
}

Expand Down
Loading
Loading