11# Copyright 2018-present Kensho Technologies, LLC.
22"""Transform a SqlNode tree into an executable SQLAlchemy query."""
33from dataclasses import dataclass
4- from typing import Dict , Iterator , List , NamedTuple , Optional , Set , Tuple , Union
4+ from typing import AbstractSet , Dict , Iterator , List , NamedTuple , Optional , Set , Tuple , Union
55
66import six
77import sqlalchemy
2121from . import blocks
2222from ..global_utils import VertexPath
2323from ..schema import COUNT_META_FIELD_NAME
24- from ..schema .schema_info import DirectJoinDescriptor , SQLAlchemySchemaInfo
24+ from ..schema .schema_info import (
25+ CompositeJoinDescriptor ,
26+ DirectJoinDescriptor ,
27+ JoinDescriptor ,
28+ SQLAlchemySchemaInfo ,
29+ )
2530from .compiler_entities import BasicBlock
2631from .compiler_frontend import IrAndMetadata
2732from .expressions import ContextField , Expression
@@ -143,21 +148,30 @@ def _find_used_columns(
143148 )
144149 vertex_field_name = f"{ edge_direction } _{ edge_name } "
145150 edge = sql_schema_info .join_descriptors [location_info .type .name ][vertex_field_name ]
146- used_columns .setdefault (get_vertex_path (location ), set ()).add (edge .from_column )
147- used_columns .setdefault (get_vertex_path (child_location ), set ()).add (edge .to_column )
151+ if isinstance (edge , DirectJoinDescriptor ):
152+ columns_at_location = {edge .from_column }
153+ columns_at_child = {edge .to_column }
154+ elif isinstance (edge , CompositeJoinDescriptor ):
155+ columns_at_location = {column_pair [0 ] for column_pair in edge .column_pairs }
156+ columns_at_child = {column_pair [1 ] for column_pair in edge .column_pairs }
157+ else :
158+ raise AssertionError (f"Unknown join descriptor type { edge } : { type (edge )} " )
159+
160+ used_columns .setdefault (get_vertex_path (location ), set ()).update (columns_at_location )
161+ used_columns .setdefault (get_vertex_path (child_location ), set ()).update (columns_at_child )
148162
149163 # Check if the edge is recursive
150164 child_location_info = ir .query_metadata_table .get_location_info (child_location )
151165 if child_location_info .recursive_scopes_depth > location_info .recursive_scopes_depth :
152166 # The primary key may be used if the recursive cte base semijoins to
153167 # the pre-recurse cte by primary key.
154168 alias = sql_schema_info .vertex_name_to_table [location_info .type .name ].alias ()
155- primary_key_name = _get_primary_key_name ( alias , location_info . type . name , "@recurse" )
156- used_columns .setdefault (get_vertex_path (location ), set ()).add ( primary_key_name )
169+ primary_keys = { column . name for column in alias . primary_key }
170+ used_columns .setdefault (get_vertex_path (location ), set ()).update ( primary_keys )
157171
158172 # The from_column is used at the destination as well, inside the recursive step
159- used_columns .setdefault (get_vertex_path (child_location ), set ()).add (
160- edge . from_column
173+ used_columns .setdefault (get_vertex_path (child_location ), set ()).update (
174+ columns_at_location
161175 )
162176
163177 # Find outputs used
@@ -780,7 +794,9 @@ def __init__(self, sql_schema_info: SQLAlchemySchemaInfo, ir: IrAndMetadata):
780794 # Move to the beginning location of the query.
781795 self ._relocate (ir .query_metadata_table .root_location )
782796
783- # Mapping aliases to the column used to join into them.
797+ # Mapping aliases to one of the columns used to join into them. We use this column
798+ # to check for LEFT JOIN misses, since it helps us distinguish actual NULL values
799+ # from values that are NULL because of a LEFT JOIN miss.
784800 self ._came_from : Dict [Union [Alias , ColumnRouter ], Column ] = {}
785801
786802 self ._recurse_needs_cte : bool = False
@@ -840,9 +856,8 @@ def _relocate(self, new_location: BaseLocation):
840856 self ._current_alias , self ._current_location , output_fields
841857 )
842858
843- # TODO merge from_column and to_column into a joindescriptor
844859 def _join_to_parent_location (
845- self , parent_alias : Alias , from_column : str , to_column : str , optional : bool
860+ self , parent_alias : Alias , join_descriptor : JoinDescriptor , optional : bool
846861 ):
847862 """Join the current location to the parent location using the column names specified."""
848863 if self ._current_alias is None :
@@ -851,7 +866,25 @@ def _join_to_parent_location(
851866 f"during fold { self } ."
852867 )
853868
854- self ._came_from [self ._current_alias ] = self ._current_alias .c [to_column ]
869+ # construct on clause for join
870+ if isinstance (join_descriptor , DirectJoinDescriptor ):
871+ matching_column_pairs : AbstractSet [Tuple [str , str ]] = {
872+ (join_descriptor .from_column , join_descriptor .to_column ),
873+ }
874+ elif isinstance (join_descriptor , CompositeJoinDescriptor ):
875+ matching_column_pairs = join_descriptor .column_pairs
876+ else :
877+ raise AssertionError (
878+ f"Unknown join descriptor type { join_descriptor } : { type (join_descriptor )} "
879+ )
880+
881+ if not matching_column_pairs :
882+ raise AssertionError (
883+ f"Invalid join descriptor { join_descriptor } , produced no matching column pairs."
884+ )
885+
886+ _ , non_null_column = sorted (matching_column_pairs )[0 ]
887+ self ._came_from [self ._current_alias ] = self ._current_alias .c [non_null_column ]
855888
856889 if self ._is_in_optional_scope () and not optional :
857890 # For mandatory edges in optional scope, we emit LEFT OUTER JOIN and enforce the
@@ -879,10 +912,17 @@ def _join_to_parent_location(
879912 )
880913 )
881914
915+ on_clause = sqlalchemy .and_ (
916+ * (
917+ parent_alias .c [from_column ] == self ._current_alias .c [to_column ]
918+ for from_column , to_column in sorted (matching_column_pairs )
919+ )
920+ )
921+
882922 # Join to where we came from.
883923 self ._from_clause = self ._from_clause .join (
884924 self ._current_alias ,
885- onclause = ( parent_alias . c [ from_column ] == self . _current_alias . c [ to_column ]) ,
925+ onclause = on_clause ,
886926 isouter = self ._is_in_optional_scope (),
887927 )
888928
@@ -932,11 +972,14 @@ def traverse(self, vertex_field: str, optional: bool) -> None:
932972 "Attempting to traverse inside a fold while the _current_location was not a "
933973 f"FoldScopeLocation. _current_location was set to { self ._current_location } ."
934974 )
975+ if not isinstance (edge , DirectJoinDescriptor ):
976+ raise NotImplementedError (
977+ f"Edge { vertex_field } is backed by a CompositeJoinDescriptor, "
978+ "so it can't be used inside a @fold scope."
979+ )
935980 self ._current_fold .add_traversal (edge , previous_alias , self ._current_alias )
936981 else :
937- self ._join_to_parent_location (
938- previous_alias , edge .from_column , edge .to_column , optional
939- )
982+ self ._join_to_parent_location (previous_alias , edge , optional )
940983
941984 def _wrap_into_cte (self ) -> None :
942985 """Wrap the current query into a cte."""
@@ -1017,6 +1060,11 @@ def recurse(self, vertex_field: str, depth: int) -> None:
10171060 )
10181061
10191062 edge = self ._sql_schema_info .join_descriptors [self ._current_classname ][vertex_field ]
1063+ if not isinstance (edge , DirectJoinDescriptor ):
1064+ raise NotImplementedError (
1065+ f"Edge { vertex_field } requires a JOIN across a composite key, which is currently "
1066+ f"not supported for use with @recurse."
1067+ )
10201068 primary_key = self ._get_current_primary_key_name ("@recurse" )
10211069
10221070 # Wrap the query so far into a CTE if it would speed up the recursive query.
@@ -1074,8 +1122,8 @@ def recurse(self, vertex_field: str, depth: int) -> None:
10741122 .where (base .c [CTE_DEPTH_NAME ] < literal_depth )
10751123 )
10761124
1077- # TODO(bojanserafimov): This creates an unused alias if there's no tags or outputs so far
1078- self ._join_to_parent_location (previous_alias , primary_key , CTE_KEY_NAME , False )
1125+ join_descriptor = DirectJoinDescriptor ( primary_key , CTE_KEY_NAME )
1126+ self ._join_to_parent_location (previous_alias , join_descriptor , False )
10791127
10801128 def start_global_operations (self ) -> None :
10811129 """Execute a GlobalOperationsStart block."""
@@ -1131,6 +1179,11 @@ def fold(self, fold_scope_location: FoldScopeLocation) -> None:
11311179 join_descriptor = self ._sql_schema_info .join_descriptors [self ._current_classname ][
11321180 full_edge_name
11331181 ]
1182+ if not isinstance (join_descriptor , DirectJoinDescriptor ):
1183+ raise NotImplementedError (
1184+ f"Edge { full_edge_name } requires a JOIN across a composite key, which is currently "
1185+ "not supported for use with @fold."
1186+ )
11341187
11351188 # 3. Initialize fold object.
11361189 self ._current_fold = FoldSubqueryBuilder (
0 commit comments