1+ #!/usr/bin/env python3
12"""
23Experimental script for bulk generation of MaD models based on a list of projects.
34
78import os .path
89import subprocess
910import sys
10- from typing import NotRequired , TypedDict , List
11+ from typing import Required , TypedDict , List , Callable , Optional
1112from concurrent .futures import ThreadPoolExecutor , as_completed
1213import time
1314import argparse
14- import json
15- import requests
1615import zipfile
1716import tarfile
18- from functools import cmp_to_key
17+ import shutil
18+
19+
def missing_module(module_name: str) -> None:
    """
    Print an error about a missing third-party dependency and abort.

    Args:
        module_name: The pip package name the user should install.
    """
    # Error output belongs on stderr so it does not pollute anything the
    # script writes to stdout.
    print(
        f"ERROR: {module_name} is not installed. Please install it with 'pip install {module_name}'.",
        file=sys.stderr,
    )
    sys.exit(1)
25+
26+
27+ try :
28+ import yaml
29+ except ImportError :
30+ missing_module ("pyyaml" )
31+
32+ try :
33+ import requests
34+ except ImportError :
35+ missing_module ("requests" )
1936
2037import generate_mad as mad
2138
2845
2946
3047# A project to generate models for
31- class Project (TypedDict ):
32- """
33- Type definition for projects (acquired via a GitHub repo) to model.
34-
35- Attributes:
36- name: The name of the project
37- git_repo: URL to the git repository
38- git_tag: Optional Git tag to check out
39- """
40-
41- name : str
42- git_repo : NotRequired [str ]
43- git_tag : NotRequired [str ]
44- with_sinks : NotRequired [bool ]
45- with_sinks : NotRequired [bool ]
46- with_summaries : NotRequired [bool ]
# NOTE: the functional TypedDict form is required here because the keys
# contain hyphens (mirroring the YAML config file) and are therefore not
# valid Python identifiers.
Project = TypedDict(
    "Project",
    {
        # The name of the project (the only required key).
        "name": Required[str],
        # URL of the git repository to clone the project from.
        "git-repo": str,
        # Optional git tag to check out after cloning.
        "git-tag": str,
        # Whether to generate sink models for this project.
        "with-sinks": bool,
        # Whether to generate source models for this project.
        "with-sources": bool,
        # Whether to generate summary models for this project.
        "with-summaries": bool,
    },
    # All keys except "name" (marked Required above) are optional.
    total=False,
)
4760
4861
4962def should_generate_sinks (project : Project ) -> bool :
@@ -63,14 +76,14 @@ def clone_project(project: Project) -> str:
6376 Shallow clone a project into the build directory.
6477
6578 Args:
66- project: A dictionary containing project information with 'name', 'git_repo ', and optional 'git_tag ' keys.
79+ project: A dictionary containing project information with 'name', 'git-repo ', and optional 'git-tag ' keys.
6780
6881 Returns:
6982 The path to the cloned project directory.
7083 """
7184 name = project ["name" ]
72- repo_url = project ["git_repo " ]
73- git_tag = project .get ("git_tag " )
85+ repo_url = project ["git-repo " ]
86+ git_tag = project .get ("git-tag " )
7487
7588 # Determine target directory
7689 target_dir = os .path .join (build_dir , name )
@@ -103,6 +116,39 @@ def clone_project(project: Project) -> str:
103116 return target_dir
104117
105118
119+ def run_in_parallel [
120+ T , U
121+ ](
122+ func : Callable [[T ], U ],
123+ items : List [T ],
124+ * ,
125+ on_error = lambda item , exc : None ,
126+ error_summary = lambda failures : None ,
127+ max_workers = 8 ,
128+ ) -> List [Optional [U ]]:
129+ if not items :
130+ return []
131+ max_workers = min (max_workers , len (items ))
132+ results = [None for _ in range (len (items ))]
133+ with ThreadPoolExecutor (max_workers = max_workers ) as executor :
134+ # Start cloning tasks and keep track of them
135+ futures = {
136+ executor .submit (func , item ): index for index , item in enumerate (items )
137+ }
138+ # Process results as they complete
139+ for future in as_completed (futures ):
140+ index = futures [future ]
141+ try :
142+ results [index ] = future .result ()
143+ except Exception as e :
144+ on_error (items [index ], e )
145+ failed = [item for item , result in zip (items , results ) if result is None ]
146+ if failed :
147+ error_summary (failed )
148+ sys .exit (1 )
149+ return results
150+
151+
106152def clone_projects (projects : List [Project ]) -> List [tuple [Project , str ]]:
107153 """
108154 Clone all projects in parallel.
@@ -114,40 +160,19 @@ def clone_projects(projects: List[Project]) -> List[tuple[Project, str]]:
114160 List of (project, project_dir) pairs in the same order as the input projects
115161 """
116162 start_time = time .time ()
117- max_workers = min (8 , len (projects )) # Use at most 8 threads
118- project_dirs_map = {} # Map to store results by project name
119-
120- with ThreadPoolExecutor (max_workers = max_workers ) as executor :
121- # Start cloning tasks and keep track of them
122- future_to_project = {
123- executor .submit (clone_project , project ): project for project in projects
124- }
125-
126- # Process results as they complete
127- for future in as_completed (future_to_project ):
128- project = future_to_project [future ]
129- try :
130- project_dir = future .result ()
131- project_dirs_map [project ["name" ]] = (project , project_dir )
132- except Exception as e :
133- print (f"ERROR: Failed to clone { project ['name' ]} : { e } " )
134-
135- if len (project_dirs_map ) != len (projects ):
136- failed_projects = [
137- project ["name" ]
138- for project in projects
139- if project ["name" ] not in project_dirs_map
140- ]
141- print (
142- f"ERROR: Only { len (project_dirs_map )} out of { len (projects )} projects were cloned successfully. Failed projects: { ', ' .join (failed_projects )} "
143- )
144- sys .exit (1 )
145-
146- project_dirs = [project_dirs_map [project ["name" ]] for project in projects ]
147-
163+ dirs = run_in_parallel (
164+ clone_project ,
165+ projects ,
166+ on_error = lambda project , exc : print (
167+ f"ERROR: Failed to clone project { project ['name' ]} : { exc } "
168+ ),
169+ error_summary = lambda failures : print (
170+ f"ERROR: Failed to clone { len (failures )} projects: { ', ' .join (p ['name' ] for p in failures )} "
171+ ),
172+ )
148173 clone_time = time .time () - start_time
149174 print (f"Cloning completed in { clone_time :.2f} seconds" )
150- return project_dirs
175+ return list ( zip ( projects , dirs ))
151176
152177
153178def build_database (
@@ -159,7 +184,7 @@ def build_database(
159184 Args:
160185 language: The language for which to build the database (e.g., "rust").
161186 extractor_options: Additional options for the extractor.
162- project: A dictionary containing project information with 'name' and 'git_repo ' keys.
187+ project: A dictionary containing project information with 'name' and 'git-repo ' keys.
163188 project_dir: Path to the CodeQL database.
164189
165190 Returns:
@@ -307,7 +332,10 @@ def pretty_name_from_artifact_name(artifact_name: str) -> str:
307332
308333
309334def download_dca_databases (
310- experiment_name : str , pat : str , projects : List [Project ]
335+ language : str ,
336+ experiment_name : str ,
337+ pat : str ,
338+ projects : List [Project ],
311339) -> List [tuple [Project , str | None ]]:
312340 """
313341 Download databases from a DCA experiment.
@@ -318,14 +346,14 @@ def download_dca_databases(
318346 Returns:
319347 List of (project_name, database_dir) pairs, where database_dir is None if the download failed.
320348 """
321- database_results = {}
322349 print ("\n === Finding projects ===" )
323350 response = get_json_from_github (
324351 f"https://raw.githubusercontent.com/github/codeql-dca-main/data/{ experiment_name } /reports/downloads.json" ,
325352 pat ,
326353 )
327354 targets = response ["targets" ]
328355 project_map = {project ["name" ]: project for project in projects }
356+ analyzed_databases = {}
329357 for data in targets .values ():
330358 downloads = data ["downloads" ]
331359 analyzed_database = downloads ["analyzed_database" ]
@@ -336,6 +364,15 @@ def download_dca_databases(
336364 print (f"Skipping { pretty_name } as it is not in the list of projects" )
337365 continue
338366
367+ if pretty_name in analyzed_databases :
368+ print (
369+ f"Skipping previous database { analyzed_databases [pretty_name ]['artifact_name' ]} for { pretty_name } "
370+ )
371+
372+ analyzed_databases [pretty_name ] = analyzed_database
373+
374+ def download_and_decompress (analyzed_database : dict ) -> str :
375+ artifact_name = analyzed_database ["artifact_name" ]
339376 repository = analyzed_database ["repository" ]
340377 run_id = analyzed_database ["run_id" ]
341378 print (f"=== Finding artifact: { artifact_name } ===" )
@@ -351,27 +388,40 @@ def download_dca_databases(
351388 artifact_zip_location = download_artifact (
352389 archive_download_url , artifact_name , pat
353390 )
354- print (f"=== Extracting artifact: { artifact_name } ===" )
391+ print (f"=== Decompressing artifact: { artifact_name } ===" )
355392 # The database is in a zip file, which contains a tar.gz file with the DB
356393 # First we open the zip file
357394 with zipfile .ZipFile (artifact_zip_location , "r" ) as zip_ref :
358395 artifact_unzipped_location = os .path .join (build_dir , artifact_name )
396+ # clean up any remnants of previous runs
397+ shutil .rmtree (artifact_unzipped_location , ignore_errors = True )
359398 # And then we extract it to build_dir/artifact_name
360399 zip_ref .extractall (artifact_unzipped_location )
361- # And then we iterate over the contents of the extracted directory
362- # and extract the tar.gz files inside it
363- for entry in os .listdir (artifact_unzipped_location ):
364- artifact_tar_location = os .path .join (artifact_unzipped_location , entry )
365- with tarfile .open (artifact_tar_location , "r:gz" ) as tar_ref :
366- # And we just untar it to the same directory as the zip file
367- tar_ref .extractall (artifact_unzipped_location )
368- database_results [pretty_name ] = os .path .join (
369- artifact_unzipped_location , remove_extension (entry )
370- )
400+ # And then we extract the language tar.gz file inside it
401+ artifact_tar_location = os .path .join (
402+ artifact_unzipped_location , f"{ language } .tar.gz"
403+ )
404+ with tarfile .open (artifact_tar_location , "r:gz" ) as tar_ref :
405+ # And we just untar it to the same directory as the zip file
406+ tar_ref .extractall (artifact_unzipped_location )
407+ ret = os .path .join (artifact_unzipped_location , language )
408+ print (f"Decompression complete: { ret } " )
409+ return ret
410+
411+ results = run_in_parallel (
412+ download_and_decompress ,
413+ list (analyzed_databases .values ()),
414+ on_error = lambda db , exc : print (
415+ f"ERROR: Failed to download and decompress { db ["artifact_name" ]} : { exc } "
416+ ),
417+ error_summary = lambda failures : print (
418+ f"ERROR: Failed to download { len (failures )} databases: { ', ' .join (item [0 ] for item in failures )} "
419+ ),
420+ )
371421
372- print (f"\n === Extracted { len (database_results )} databases ===" )
422+ print (f"\n === Fetched { len (results )} databases ===" )
373423
374- return [(project , database_results [ project [ "name" ]] ) for project in projects ]
424+ return [(project_map [ n ], r ) for n , r in zip ( analyzed_databases , results ) ]
375425
376426
377427def get_mad_destination_for_project (config , name : str ) -> str :
@@ -422,7 +472,9 @@ def main(config, args) -> None:
422472 case "repo" :
423473 extractor_options = config .get ("extractor_options" , [])
424474 database_results = build_databases_from_projects (
425- language , extractor_options , projects
475+ language ,
476+ extractor_options ,
477+ projects ,
426478 )
427479 case "dca" :
428480 experiment_name = args .dca
@@ -439,7 +491,10 @@ def main(config, args) -> None:
439491 with open (args .pat , "r" ) as f :
440492 pat = f .read ().strip ()
441493 database_results = download_dca_databases (
442- experiment_name , pat , projects
494+ language ,
495+ experiment_name ,
496+ pat ,
497+ projects ,
443498 )
444499
445500 # Generate models for all projects
@@ -492,9 +547,9 @@ def main(config, args) -> None:
492547 sys .exit (1 )
493548 try :
494549 with open (args .config , "r" ) as f :
495- config = json . load (f )
496- except json . JSONDecodeError as e :
497- print (f"ERROR: Failed to parse JSON file { args .config } : { e } " )
550+ config = yaml . safe_load (f )
551+ except yaml . YAMLError as e :
552+ print (f"ERROR: Failed to parse YAML file { args .config } : { e } " )
498553 sys .exit (1 )
499554
500555 main (config , args )
0 commit comments