@@ -921,6 +921,7 @@ def get_schema(
921921 spec : Any ,
922922 overrides : Optional [List [ParamSpec ]] = None ,
923923 mode : str = 'parameter' ,
924+ masks : Optional [List [bool ]] = None ,
924925) -> Tuple [List [ParamSpec ], str , str ]:
925926 """
926927 Expand a return type annotation into a list of types and field names.
@@ -933,6 +934,8 @@ def get_schema(
933934 List of SQL type specifications for the return type
934935 mode : str
935936 The mode of the function, either 'parameter' or 'return'
937+ is_masked : bool
938+ Whether the type is wrapped in a Masked type
936939
937940 Returns
938941 -------
@@ -994,7 +997,13 @@ def get_schema(
994997 'dataclass, TypedDict, or pydantic model' ,
995998 )
996999 spec = typing .get_args (unpacked_spec [0 ])[0 ]
997- data_format = 'list'
1000+ # Lists as output from TVFs are considered scalar outputs
1001+ # since they correspond to individual Python objects, not
1002+ # a true vector type.
1003+ if function_type == 'tvf' :
1004+ data_format = 'scalar'
1005+ else :
1006+ data_format = 'list'
9981007
9991008 elif all ([utils .is_vector (x , include_masks = True ) for x in unpacked_spec ]):
10001009 pass
@@ -1111,7 +1120,11 @@ def get_schema(
11111120 _ , inner_apply_meta = unpack_annotated (typing .get_args (spec )[0 ])
11121121 if inner_apply_meta .sql_type :
11131122 udf_attrs = inner_apply_meta
1114- colspec = get_schema (typing .get_args (spec )[0 ], mode = mode )[0 ]
1123+ colspec = get_schema (
1124+ typing .get_args (spec )[0 ],
1125+ mode = mode ,
1126+ masks = [masks [0 ]] if masks else None ,
1127+ )[0 ]
11151128 else :
11161129 colspec = [
11171130 ParamSpec (
@@ -1142,6 +1155,7 @@ def get_schema(
11421155 overrides = [overrides [i ]] if overrides else [],
11431156 # Always pass UDF mode for individual items
11441157 mode = mode ,
1158+ masks = [masks [i ]] if masks else None ,
11451159 )
11461160
11471161 # Use the name from the overrides if specified
@@ -1183,7 +1197,7 @@ def get_schema(
11831197 out = []
11841198
11851199 # Normalize colspec data types
1186- for c in colspec :
1200+ for i , c in enumerate ( colspec ) :
11871201
11881202 # if the dtype is a string, it is resolved already
11891203 if isinstance (c .dtype , str ):
@@ -1201,13 +1215,27 @@ def get_schema(
12011215 include_null = c .is_optional ,
12021216 )
12031217
1218+ sql_type = c .sql_type if isinstance (c .sql_type , str ) else udf_attrs .sql_type
1219+
1220+ is_optional = (
1221+ c .is_optional
1222+ or bool (dtype and dtype .endswith ('?' ))
1223+ or bool (masks and masks [i ])
1224+ )
1225+
1226+ if is_optional :
1227+ if dtype and not dtype .endswith ('?' ):
1228+ dtype += '?'
1229+ if sql_type and re .search (r' NOT NULL\b' , sql_type ):
1230+ sql_type = re .sub (r' NOT NULL\b' , r' NULL' , sql_type )
1231+
12041232 p = ParamSpec (
12051233 name = c .name ,
12061234 dtype = dtype ,
1207- sql_type = c . sql_type if isinstance ( c . sql_type , str ) else udf_attrs . sql_type ,
1208- is_optional = c . is_optional or bool ( dtype and dtype . endswith ( '?' )) ,
1209- transformer = udf_attrs .input_transformer
1210- if mode == 'parameter' else udf_attrs .output_transformer ,
1235+ sql_type = sql_type ,
1236+ is_optional = is_optional ,
1237+ transformer = udf_attrs .args_transformer
1238+ if mode == 'parameter' else udf_attrs .returns_transformer ,
12111239 )
12121240
12131241 out .append (p )
@@ -1345,6 +1373,7 @@ def get_signature(
13451373 unpack_masked_type (param .annotation ),
13461374 overrides = [args_colspec [i ]] if args_colspec else [],
13471375 mode = 'parameter' ,
1376+ masks = [args_masks [i ]] if args_masks else [],
13481377 )
13491378 args_data_formats .append (args_data_format )
13501379
@@ -1404,6 +1433,7 @@ def get_signature(
14041433 unpack_masked_type (signature .return_annotation ),
14051434 overrides = returns_colspec if returns_colspec else None ,
14061435 mode = 'return' ,
1436+ masks = ret_masks or [],
14071437 )
14081438
14091439 rdf = out ['returns_data_format' ] = out ['returns_data_format' ] or 'scalar'
@@ -1419,6 +1449,12 @@ def get_signature(
14191449 'scalar or vector types.' ,
14201450 )
14211451
1452+ # If we hava function parameters and the function is a TVF, then
1453+ # the return type should just match the parameter vector types. This ensures
1454+ # the output producers for scalars and vectors are consistent.
1455+ elif function_type == 'tvf' and rdf == 'scalar' and args_schema :
1456+ out ['returns_data_format' ] = out ['args_data_format' ]
1457+
14221458 # All functions have to return a value, so if none was specified try to
14231459 # insert a reasonable default that includes NULLs.
14241460 if not ret_schema :
0 commit comments