55Note: This file must be formatted using the Black Python formatter.
66"""
77
8- import os . path
8+ import pathlib
99import subprocess
1010import sys
1111from typing import Required , TypedDict , List , Callable , Optional
@@ -41,7 +41,7 @@ def missing_module(module_name: str) -> None:
4141 .decode ("utf-8" )
4242 .strip ()
4343)
44- build_dir = os . path . join (gitroot , "mad-generation-build" )
44+ build_dir = pathlib . Path (gitroot , "mad-generation-build" )
4545
4646
4747# A project to generate models for
@@ -86,10 +86,10 @@ def clone_project(project: Project) -> str:
8686 git_tag = project .get ("git-tag" )
8787
8888 # Determine target directory
89- target_dir = os . path . join ( build_dir , name )
89+ target_dir = build_dir / name
9090
9191 # Clone only if directory doesn't already exist
92- if not os . path . exists (target_dir ):
92+ if not target_dir . exists ():
9393 if git_tag :
9494 print (f"Cloning { name } from { repo_url } at tag { git_tag } " )
9595 else :
@@ -191,10 +191,10 @@ def build_database(
191191 name = project ["name" ]
192192
193193 # Create database directory path
194- database_dir = os . path . join ( build_dir , f"{ name } -db" )
194+ database_dir = build_dir / f"{ name } -db"
195195
196196 # Only build the database if it doesn't already exist
197- if not os . path . exists (database_dir ):
197+ if not database_dir . exists ():
198198 print (f"Building CodeQL database for { name } ..." )
199199 extractor_options = [option for x in extractor_options for option in ("-O" , x )]
200200 try :
@@ -236,13 +236,16 @@ def generate_models(config, args, project: Project, database_dir: str) -> None:
236236 language = config ["language" ]
237237
238238 generator = mad .Generator (language )
239- # Note: The argument parser converts with-sinks to with_sinks, etc.
240- generator .generateSinks = should_generate_sinks (project )
241- generator .generateSources = should_generate_sources (project )
242- generator .generateSummaries = should_generate_summaries (project )
243- generator .setenvironment (database = database_dir , folder = name )
239+ generator .with_sinks = should_generate_sinks (project )
240+ generator .with_sources = should_generate_sources (project )
241+ generator .with_summaries = should_generate_summaries (project )
244242 generator .threads = args .codeql_threads
245243 generator .ram = args .codeql_ram
244+ if config .get ("single-file" , False ):
245+ generator .single_file = name
246+ else :
247+ generator .folder = name
248+ generator .setenvironment (database = database_dir )
246249 generator .run ()
247250
248251
@@ -313,20 +316,14 @@ def download_artifact(url: str, artifact_name: str, pat: str) -> str:
313316 if response .status_code != 200 :
314317 print (f"Failed to download file. Status code: { response .status_code } " )
315318 sys .exit (1 )
316- target_zip = os . path . join ( build_dir , zipName )
319+ target_zip = build_dir / zipName
317320 with open (target_zip , "wb" ) as file :
318321 for chunk in response .iter_content (chunk_size = 8192 ):
319322 file .write (chunk )
320323 print (f"Download complete: { target_zip } " )
321324 return target_zip
322325
323326
324- def remove_extension (filename : str ) -> str :
325- while "." in filename :
326- filename , _ = os .path .splitext (filename )
327- return filename
328-
329-
def pretty_name_from_artifact_name(artifact_name: str) -> str:
    """Return the second "___"-separated field of `artifact_name`.

    Raises IndexError if `artifact_name` contains no "___" separator.
    """
    fields = artifact_name.split("___")
    return fields[1]
332329
@@ -348,7 +345,7 @@ def download_dca_databases(
348345 """
349346 print ("\n === Finding projects ===" )
350347 project_map = {project ["name" ]: project for project in projects }
351- analyzed_databases = {}
348+ analyzed_databases = {n : None for n in project_map }
352349 for experiment_name in experiment_names :
353350 response = get_json_from_github (
354351 f"https://raw.githubusercontent.com/github/codeql-dca-main/data/{ experiment_name } /reports/downloads.json" ,
@@ -361,17 +358,24 @@ def download_dca_databases(
361358 artifact_name = analyzed_database ["artifact_name" ]
362359 pretty_name = pretty_name_from_artifact_name (artifact_name )
363360
364- if not pretty_name in project_map :
361+ if not pretty_name in analyzed_databases :
365362 print (f"Skipping { pretty_name } as it is not in the list of projects" )
366363 continue
367364
368- if pretty_name in analyzed_databases :
365+ if analyzed_databases [ pretty_name ] is not None :
369366 print (
370367 f"Skipping previous database { analyzed_databases [pretty_name ]['artifact_name' ]} for { pretty_name } "
371368 )
372369
373370 analyzed_databases [pretty_name ] = analyzed_database
374371
372+ not_found = [name for name , db in analyzed_databases .items () if db is None ]
373+ if not_found :
374+ print (
375+ f"ERROR: The following projects were not found in the DCA experiments: { ', ' .join (not_found )} "
376+ )
377+ sys .exit (1 )
378+
375379 def download_and_decompress (analyzed_database : dict ) -> str :
376380 artifact_name = analyzed_database ["artifact_name" ]
377381 repository = analyzed_database ["repository" ]
@@ -393,19 +397,17 @@ def download_and_decompress(analyzed_database: dict) -> str:
393397 # The database is in a zip file, which contains a tar.gz file with the DB
394398 # First we open the zip file
395399 with zipfile .ZipFile (artifact_zip_location , "r" ) as zip_ref :
396- artifact_unzipped_location = os . path . join ( build_dir , artifact_name )
400+ artifact_unzipped_location = build_dir / artifact_name
397401 # clean up any remnants of previous runs
398402 shutil .rmtree (artifact_unzipped_location , ignore_errors = True )
399403 # And then we extract it to build_dir/artifact_name
400404 zip_ref .extractall (artifact_unzipped_location )
401405 # And then we extract the language tar.gz file inside it
402- artifact_tar_location = os .path .join (
403- artifact_unzipped_location , f"{ language } .tar.gz"
404- )
406+ artifact_tar_location = artifact_unzipped_location / f"{ language } .tar.gz"
405407 with tarfile .open (artifact_tar_location , "r:gz" ) as tar_ref :
406408 # And we just untar it to the same directory as the zip file
407409 tar_ref .extractall (artifact_unzipped_location )
408- ret = os . path . join ( artifact_unzipped_location , language )
410+ ret = artifact_unzipped_location / language
409411 print (f"Decompression complete: { ret } " )
410412 return ret
411413
@@ -425,8 +427,16 @@ def download_and_decompress(analyzed_database: dict) -> str:
425427 return [(project_map [n ], r ) for n , r in zip (analyzed_databases , results )]
426428
427429
428- def get_mad_destination_for_project (config , name : str ) -> str :
429- return os .path .join (config ["destination" ], name )
def clean_up_mad_destination_for_project(config, name: str):
    """Delete previously generated MaD output for the project `name`.

    When `config["single-file"]` is truthy the target is the file
    `<destination>/<name>.model.yml`; otherwise it is the directory
    `<destination>/<name>`. Nothing is done if the target does not exist.

    NOTE(review): assumes the generator writes single-file output by
    appending ".model.yml" to the project name -- confirm against the
    generator implementation.
    """
    destination = pathlib.Path(config["destination"])

    if config.get("single-file", False):
        # Append the suffix rather than calling Path.with_suffix():
        # with_suffix() replaces an existing extension, so a project named
        # "foo.bar" would resolve to "foo.model.yml" and the wrong file
        # (possibly another project's) would be checked and deleted.
        target = destination / f"{name}.model.yml"
        if target.exists():
            print(f"Deleting existing MaD file at {target}")
            target.unlink()
        return

    target = destination / name
    if target.exists():
        print(f"Deleting existing MaD directory at {target}")
        # Best-effort removal, matching the original ignore_errors=True.
        shutil.rmtree(target, ignore_errors=True)
430440
431441
432442def get_strategy (config ) -> str :
@@ -448,8 +458,7 @@ def main(config, args) -> None:
448458 language = config ["language" ]
449459
450460 # Create build directory if it doesn't exist
451- if not os .path .exists (build_dir ):
452- os .makedirs (build_dir )
461+ build_dir .mkdir (parents = True , exist_ok = True )
453462
454463 database_results = []
455464 match get_strategy (config ):
@@ -469,7 +478,7 @@ def main(config, args) -> None:
469478 if args .pat is None :
470479 print ("ERROR: --pat argument is required for DCA strategy" )
471480 sys .exit (1 )
472- if not os . path .exists (args . pat ):
481+ if not args . pat .exists ():
473482 print (f"ERROR: Personal Access Token file '{ pat } ' does not exist." )
474483 sys .exit (1 )
475484 with open (args .pat , "r" ) as f :
@@ -493,12 +502,9 @@ def main(config, args) -> None:
493502 )
494503 sys .exit (1 )
495504
496- # Delete the MaD directory for each project
497- for project , database_dir in database_results :
498- mad_dir = get_mad_destination_for_project (config , project ["name" ])
499- if os .path .exists (mad_dir ):
500- print (f"Deleting existing MaD directory at { mad_dir } " )
501- subprocess .check_call (["rm" , "-rf" , mad_dir ])
505+ # clean up existing MaD data for the projects
506+ for project , _ in database_results :
507+ clean_up_mad_destination_for_project (config , project ["name" ])
502508
503509 for project , database_dir in database_results :
504510 if database_dir is not None :
@@ -508,7 +514,10 @@ def main(config, args) -> None:
508514if __name__ == "__main__" :
509515 parser = argparse .ArgumentParser ()
510516 parser .add_argument (
511- "--config" , type = str , help = "Path to the configuration file." , required = True
517+ "--config" ,
518+ type = pathlib .Path ,
519+ help = "Path to the configuration file." ,
520+ required = True ,
512521 )
513522 parser .add_argument (
514523 "--dca" ,
@@ -519,13 +528,13 @@ def main(config, args) -> None:
519528 )
520529 parser .add_argument (
521530 "--pat" ,
522- type = str ,
531+ type = pathlib . Path ,
523532 help = "Path to a file containing the PAT token required to grab DCA databases (the same as the one you use for DCA)" ,
524533 )
525534 parser .add_argument (
526535 "--codeql-ram" ,
527536 type = int ,
528- help = "What `--ram` value to pass to `codeql` while generating models (by default the flag is not passed )" ,
537+ help = "What `--ram` value to pass to `codeql` while generating models (by default 2048 MB per thread )" ,
529538 default = None ,
530539 )
531540 parser .add_argument (
@@ -538,7 +547,7 @@ def main(config, args) -> None:
538547
539548 # Load config file
540549 config = {}
541- if not os . path .exists (args . config ):
550+ if not args . config .exists ():
542551 print (f"ERROR: Config file '{ args .config } ' does not exist." )
543552 sys .exit (1 )
544553 try :
0 commit comments