diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index 87170f595f413..8f9861c9487df 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -1453,6 +1453,9 @@ impl SessionContext { RegisterFunction::Window(f) => { self.state.write().register_udwf(f)?; } + RegisterFunction::HigherOrder(f) => { + self.state.write().register_higher_order_function(f)?; + } RegisterFunction::Table(name, f) => self.register_udtf(&name, f), }; @@ -1467,6 +1470,11 @@ impl SessionContext { dropped |= self.state.write().deregister_udaf(&stmt.name)?.is_some(); dropped |= self.state.write().deregister_udwf(&stmt.name)?.is_some(); dropped |= self.state.write().deregister_udtf(&stmt.name)?.is_some(); + dropped |= self + .state + .write() + .deregister_higher_order_function(&stmt.name)? + .is_some(); // DROP FUNCTION IF EXISTS drops the specified function only if that // function exists and in this way, it avoids error. While the DROP FUNCTION @@ -1566,6 +1574,20 @@ impl SessionContext { state.register_udf(Arc::new(f)).ok(); } + /// Registers a higher-order function within this context. + /// + /// Note in SQL queries, function names are looked up using + /// lowercase unless the query uses quotes. For example, + /// + /// - `SELECT MY_HIGHER_ORDER_FUNC(x)...` will look for a function named `"my_higher_order_func"` + /// - `SELECT "my_HIGHER_ORDER_FUNC"(x)` will look for a function named `"my_HIGHER_ORDER_FUNC"` + /// + /// Any functions registered with the function name or its aliases will be overwritten with this new function + pub fn register_higher_order_function(&self, f: Arc) { + let mut state = self.state.write(); + state.register_higher_order_function(f).ok(); + } + /// Registers an aggregate UDF within this context. /// /// Note in SQL queries, aggregate names are looked up using @@ -1605,6 +1627,14 @@ impl SessionContext { self.state.write().deregister_udf(name).ok(); } + /// Deregisters a higher-order function within this context. + pub fn deregister_higher_order_function(&self, name: &str) { + self.state + .write() + .deregister_higher_order_function(name) + .ok(); + } + /// Deregisters a UDAF within this context. pub fn deregister_udaf(&self, name: &str) { self.state.write().deregister_udaf(name).ok(); @@ -2135,6 +2165,8 @@ pub enum RegisterFunction { Aggregate(Arc), /// Window user defined function Window(Arc), + /// Higher-order user defined function + HigherOrder(Arc), /// Table user defined function Table(String, Arc), }