@@ -27,7 +27,7 @@ use arrow::pyarrow::FromPyArrow;
2727use datafusion:: arrow:: datatypes:: { DataType , Schema , SchemaRef } ;
2828use datafusion:: arrow:: pyarrow:: PyArrowType ;
2929use datafusion:: arrow:: record_batch:: RecordBatch ;
30- use datafusion:: catalog:: { CatalogProvider , CatalogProviderList } ;
30+ use datafusion:: catalog:: { CatalogProvider , CatalogProviderList , TableProviderFactory } ;
3131use datafusion:: common:: { ScalarValue , TableReference , exec_err} ;
3232use datafusion:: datasource:: file_format:: file_compression_type:: FileCompressionType ;
3333use datafusion:: datasource:: file_format:: parquet:: ParquetFormat ;
@@ -51,6 +51,7 @@ use datafusion_ffi::catalog_provider::FFI_CatalogProvider;
5151use datafusion_ffi:: catalog_provider_list:: FFI_CatalogProviderList ;
5252use datafusion_ffi:: execution:: FFI_TaskContextProvider ;
5353use datafusion_ffi:: proto:: logical_extension_codec:: FFI_LogicalExtensionCodec ;
54+ use datafusion_ffi:: table_provider_factory:: FFI_TableProviderFactory ;
5455use datafusion_proto:: logical_plan:: DefaultLogicalExtensionCodec ;
5556use datafusion_python_util:: {
5657 create_logical_extension_capsule, ffi_logical_codec_from_pycapsule, get_global_ctx,
@@ -81,7 +82,7 @@ use crate::record_batch::PyRecordBatchStream;
8182use crate :: sql:: logical:: PyLogicalPlan ;
8283use crate :: sql:: util:: replace_placeholders_with_strings;
8384use crate :: store:: StorageContexts ;
84- use crate :: table:: PyTable ;
85+ use crate :: table:: { PyTable , RustWrappedPyTableProviderFactory } ;
8586use crate :: udaf:: PyAggregateUDF ;
8687use crate :: udf:: PyScalarUDF ;
8788use crate :: udtf:: PyTableFunction ;
@@ -659,6 +660,43 @@ impl PySessionContext {
659660 Ok ( ( ) )
660661 }
661662
663+ pub fn register_table_factory (
664+ & self ,
665+ format : & str ,
666+ mut factory : Bound < ' _ , PyAny > ,
667+ ) -> PyDataFusionResult < ( ) > {
668+ if factory. hasattr ( "__datafusion_table_provider_factory__" ) ? {
669+ let py = factory. py ( ) ;
670+ let codec_capsule = create_logical_extension_capsule ( py, self . logical_codec . as_ref ( ) ) ?;
671+ factory = factory
672+ . getattr ( "__datafusion_table_provider_factory__" ) ?
673+ . call1 ( ( codec_capsule, ) ) ?;
674+ }
675+
676+ let factory: Arc < dyn TableProviderFactory > =
677+ if let Ok ( capsule) = factory. cast :: < PyCapsule > ( ) . map_err ( py_datafusion_err) {
678+ validate_pycapsule ( capsule, "datafusion_table_provider_factory" ) ?;
679+
680+ let data: NonNull < FFI_TableProviderFactory > = capsule
681+ . pointer_checked ( Some ( c_str ! ( "datafusion_table_provider_factory" ) ) ) ?
682+ . cast ( ) ;
683+ let factory = unsafe { data. as_ref ( ) } ;
684+ factory. into ( )
685+ } else {
686+ Arc :: new ( RustWrappedPyTableProviderFactory :: new (
687+ factory. into ( ) ,
688+ self . logical_codec . clone ( ) ,
689+ ) )
690+ } ;
691+
692+ let st = self . ctx . state_ref ( ) ;
693+ let mut lock = st. write ( ) ;
694+ lock. table_factories_mut ( )
695+ . insert ( format. to_owned ( ) , factory) ;
696+
697+ Ok ( ( ) )
698+ }
699+
662700 pub fn register_catalog_provider_list (
663701 & self ,
664702 mut provider : Bound < PyAny > ,
0 commit comments