@@ -160,64 +160,45 @@ def __assert_indexer_outputs(
160160 stats = json .loads ((output_path / "stats.json" ).read_bytes ().decode ("utf-8" ))
161161
162162 # Check all workflows run
163- expected_artifacts = 0
164163 expected_workflows = set (workflow_config .keys ())
165164 workflows = set (stats ["workflows" ].keys ())
166165 assert workflows == expected_workflows , (
167166 f"Workflows missing from stats.json: { expected_workflows - workflows } . Unexpected workflows in stats.json: { workflows - expected_workflows } "
168167 )
169168
170169 # [OPTIONAL] Check runtime
171- for workflow in expected_workflows :
170+ for workflow , config in workflow_config . items () :
172171 # Check expected artifacts
173- expected_artifacts = expected_artifacts + workflow_config [workflow ].get (
174- "expected_artifacts" , 1
175- )
172+ workflow_artifacts = config .get ("expected_artifacts" , [])
176173 # Check max runtime
177- max_runtime = workflow_config [ workflow ] .get ("max_runtime" , None )
174+ max_runtime = config .get ("max_runtime" , None )
178175 if max_runtime :
179176 assert stats ["workflows" ][workflow ]["overall" ] <= max_runtime , (
180177 f"Expected max runtime of { max_runtime } , found: { stats ['workflows' ][workflow ]['overall' ]} for workflow: { workflow } "
181178 )
182-
183- # Check artifacts
184- artifact_files = os .listdir (output_path )
185-
186- # check that the number of workflows matches the number of artifacts
187- assert len (artifact_files ) == (expected_artifacts + 3 ), (
188- f"Expected { expected_artifacts + 3 } artifacts, found: { len (artifact_files )} "
189- ) # Embeddings add to the count
190-
191- for artifact in artifact_files :
192- if artifact .endswith (".parquet" ):
193- output_df = pd .read_parquet (output_path / artifact )
194- artifact_name = artifact .split ("." )[0 ]
195-
196- try :
197- workflow = workflow_config [artifact_name ]
179+ # Check expected artifacts
180+ for artifact in workflow_artifacts :
181+ if artifact .endswith (".parquet" ):
182+ output_df = pd .read_parquet (output_path / artifact )
198183
199184 # Check number of rows between range
200185 assert (
201- workflow ["row_range" ][0 ]
186+ config ["row_range" ][0 ]
202187 <= len (output_df )
203- <= workflow ["row_range" ][1 ]
188+ <= config ["row_range" ][1 ]
204189 ), (
205- f"Expected between { workflow ['row_range' ][0 ]} and { workflow ['row_range' ][1 ]} , found: { len (output_df )} for file: { artifact } "
190+ f"Expected between { config ['row_range' ][0 ]} and { config ['row_range' ][1 ]} , found: { len (output_df )} for file: { artifact } "
206191 )
207192
208193 # Get non-nan rows
209194 nan_df = output_df .loc [
210195 :,
211- ~ output_df .columns .isin (
212- workflow .get ("nan_allowed_columns" , [])
213- ),
196+ ~ output_df .columns .isin (config .get ("nan_allowed_columns" , [])),
214197 ]
215198 nan_df = nan_df [nan_df .isna ().any (axis = 1 )]
216199 assert len (nan_df ) == 0 , (
217200 f"Found { len (nan_df )} rows with NaN values for file: { artifact } on columns: { nan_df .columns [nan_df .isna ().any ()].tolist ()} "
218201 )
219- except KeyError :
220- log .warning ("No workflow config found %s" , artifact_name )
221202
222203 def __run_query (self , root : Path , query_config : dict [str , str ]):
223204 command = [
0 commit comments