diff --git a/examples/semantic_similarity_generate.py b/examples/semantic_similarity_generate.py
new file mode 100644
index 00000000..22056365
--- /dev/null
+++ b/examples/semantic_similarity_generate.py
@@ -0,0 +1,17 @@
+from automation.tasks import SemanticSimilarityGenerateTask
+
+task = SemanticSimilarityGenerateTask(
+    project_name="semantic_similarity_debug",
+    task_name="semantic_generation_qwen3_14b_w4a16_feedback",
+    #task_name="semantic_generation_qwen3_14b_feedback",
+    branch="semantic_similarity",
+    packages=["huggingface-hub==0.34.3", "triton==3.3.1", "vllm==0.10.1.1"],
+    dataset_args={"tatsu-lab/alpaca": 300, "garage-bAInd/Open-Platypus": 310, "allenai/tulu-3-sft-mixture": 320},
+    model_id="Qwen/Qwen3-14B",
+    max_new_tokens=1024,
+    max_model_len=4096,
+    semantic_similarity_args={"enable_chunked_prefill": True, "enforce_eager": True, "dtype": "auto", "device_map": "auto", "temperature": 0.0},
+)
+
+task.execute_remotely("oneshot-a100x1")
+
diff --git a/examples/semantic_similarity_score.py b/examples/semantic_similarity_score.py
new file mode 100644
index 00000000..c00417ce
--- /dev/null
+++ b/examples/semantic_similarity_score.py
@@ -0,0 +1,20 @@
+from automation.tasks import SemanticSimilarityScoreTask
+
+task = SemanticSimilarityScoreTask(
+    project_name="semantic_similarity_debug",
+    #task_name="semantic_scoring_14b",
+    task_name="semantic_scoring_4b",
+    branch="semantic_similarity",
+    packages=["huggingface-hub==0.34.3", "networkx==3.4.2", "datasets==4.2.0", "rouge_score==0.1.2", "bert-score==0.3.13", "sentence-transformers==5.1.1", "matplotlib"],
+    reference_model_project_name="semantic_similarity_debug",
+    candidate_model_project_name="semantic_similarity_debug",
+    reference_model_task_name="semantic_generation_qwen3_14b_feedback",
+    #reference_model_task_name="semantic_generation_qwen3_14b_base",
+    candidate_model_task_name="semantic_generation_qwen3_14b_w4a16_feedback",
+    #candidate_model_task_name="semantic_generation_qwen3_14b_w4a16",
+    sts_model_id="all-MiniLM-L6-v2",
+    rouge_scores=["rouge1", "rougeL"],
+    low_score_threshold_args={"f1": 0.79, "rouge1": 0.65, "sts": 0.71},
+)
+
+task.execute_remotely("oneshot-a100x1")
diff --git a/src/automation/datasets/__init__.py b/src/automation/datasets/__init__.py
index 629001aa..4f7c8d28 100644
--- a/src/automation/datasets/__init__.py
+++ b/src/automation/datasets/__init__.py
@@ -4,6 +4,10 @@
 from automation.datasets.openthoughts import DATASET_PATH as OPENTHOUGHTSDATASET
 from automation.datasets.utils import load_llm_messages, load_vlm_messages
 from automation.datasets.fleurs import load_fleurs_dataset
+from automation.datasets.tulu import make_tulu_prompt
+from automation.datasets.openplatypus import make_openplatypus_prompt
+from automation.datasets.alpaca import make_alpaca_prompt
+from automation.datasets.defaults import make_default_prompt
 
 SUPPORTED_DATASETS = {
     "calibration": load_calibration_dataset,
@@ -17,6 +21,10 @@
     "load_openthoughts_dataset",
     "load_llm_messages",
     "load_vlm_messages",
+    "make_tulu_prompt",
+    "make_openplatypus_prompt",
+    "make_alpaca_prompt",
+    "make_default_prompt",
     "load_fleurs_dataset",
     "SUPPORTED_DATASETS",
-]
\ No newline at end of file
+]
diff --git a/src/automation/datasets/alpaca.py b/src/automation/datasets/alpaca.py
new file mode 100644
index 00000000..7d07be4d
--- /dev/null
+++ b/src/automation/datasets/alpaca.py
@@ -0,0 +1,22 @@
+def make_alpaca_prompt(sample):
+    instruction = sample["instruction"].strip()
+    input_text = sample.get("input", "").strip()
+
+    if input_text == "":
"": + messages = [ + { + "role": "user", + "content": f"{instruction}", + } + ] + + + else: + messages = [ + { + "role": "user", + "content": f"{instruction}\n{input_text}", + } + ] + + return messages diff --git a/src/automation/datasets/defaults.py b/src/automation/datasets/defaults.py new file mode 100644 index 00000000..36ab8953 --- /dev/null +++ b/src/automation/datasets/defaults.py @@ -0,0 +1,12 @@ +def make_default_prompt(sample): + messages = [ + { + "role": "user", + "content": f"{json.dumps(sample)}", + } + ] + + prompt = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages]) + + return prompt + diff --git a/src/automation/datasets/openplatypus.py b/src/automation/datasets/openplatypus.py new file mode 100644 index 00000000..6b6f2116 --- /dev/null +++ b/src/automation/datasets/openplatypus.py @@ -0,0 +1,21 @@ +def make_openplatypus_prompt(sample): + instruction = sample["instruction"].strip() + input_text = sample.get("input", "").strip() + + if input_text == "": + messages = [ + { + "role": "user", + "content": f"{instruction}", + } + ] + + else: + messages = [ + { + "role": "user", + "content": f"{instruction}\n{input_text}", + } + ] + + return messages diff --git a/src/automation/datasets/tulu.py b/src/automation/datasets/tulu.py new file mode 100644 index 00000000..6dcf2ac8 --- /dev/null +++ b/src/automation/datasets/tulu.py @@ -0,0 +1,3 @@ + +def make_tulu_prompt(sample): + return sample["messages"] diff --git a/src/automation/tasks/__init__.py b/src/automation/tasks/__init__.py index 62e70841..baae3580 100644 --- a/src/automation/tasks/__init__.py +++ b/src/automation/tasks/__init__.py @@ -1,4 +1,6 @@ from automation.tasks.base_task import BaseTask +from automation.tasks.semantic_similarity_generate import SemanticSimilarityGenerateTask +from automation.tasks.semantic_similarity_score import SemanticSimilarityScoreTask from automation.tasks.llmcompressor import LLMCompressorTask from automation.tasks.lmeval import LMEvalTask from automation.tasks.lighteval import LightEvalTask @@ -7,9 +9,11 @@ __all__ = [ "BaseTask", + "SemanticSimilarityGenerateTask", + "SemanticSimilarityScoreTask", "LLMCompressorTask", "LMEvalTask", "LightEvalTask", "GuideLLMTask", "DebugTask", -] \ No newline at end of file +] diff --git a/src/automation/tasks/scripts/semantic_similarity_generate_script.py b/src/automation/tasks/scripts/semantic_similarity_generate_script.py new file mode 100644 index 00000000..be1e2630 --- /dev/null +++ b/src/automation/tasks/scripts/semantic_similarity_generate_script.py @@ -0,0 +1,150 @@ +import json +import os +import requests +from torch.cuda import device_count +from tqdm import tqdm +from datasets import load_dataset +from vllm import LLM, SamplingParams +from transformers import AutoTokenizer + +from automation.utils import kill_process_tree, parse_argument, flatten_nested_dict +from automation.datasets.tulu import make_tulu_prompt +from automation.datasets.openplatypus import make_openplatypus_prompt +from automation.datasets.alpaca import make_alpaca_prompt +from automation.datasets.defaults import make_default_prompt + +try: + from clearml import OutputModel, Task, Model + clearml_available = True +except ImportError: + clearml_available = False + +RESULTS_DIR = os.path.join(os.getcwd(), "results") +os.makedirs(RESULTS_DIR, exist_ok=True) + +def semantic_similarity_generate_main( + model_id, + trust_remote_code, + dataset_args, + semantic_similarity_args, + max_model_len, + max_new_tokens, + clearml_model, +): + from collections import 
+    from huggingface_hub import snapshot_download
+
+    all_conversations = []
+    all_samples_dict = defaultdict(list)
+
+    print(">>> Loading dataset...")
+    for dataset_path, num_samples_per_dataset in dataset_args.items():
+        dataset_name = dataset_path.split("/")[1].lower()
+        print(f">>> Loading dataset {dataset_name}...")
+        dataset = load_dataset(dataset_path, split=f"train[:{int(num_samples_per_dataset)}]")
+        all_samples_dict[dataset_name].extend(dataset)
+
+    sorted_all_samples_dict = dict(sorted(all_samples_dict.items()))
+
+    for dataset_name, dataset_samples in sorted_all_samples_dict.items():
+        print(f">>> Loading values for {dataset_name}...")
+        for sample in dataset_samples:
+            if dataset_name == "alpaca":
+                prompt = make_alpaca_prompt(sample)
+            elif dataset_name == "open-platypus":
+                prompt = make_openplatypus_prompt(sample)
+            elif dataset_name == "tulu-3-sft-mixture":
+                prompt = make_tulu_prompt(sample)
+            else:
+                print("Using default prompt")
+                prompt = make_default_prompt(sample)
+            all_conversations.append(prompt)
+
+    print("Define sampling parameters")
+    sampling_params = SamplingParams(
+        temperature=semantic_similarity_args.get("temperature", 0.0),
+        max_tokens=max_new_tokens
+    )
+
+    HUGGINGFACE_DIR = "/home"
+    if clearml_model:
+        HUGGINGFACE_DIR = Model(model_id).get_local_copy()
+    else:
+        print("Download snapshot")
+        snapshot_download(repo_id=model_id, local_dir=HUGGINGFACE_DIR)
+        print(os.listdir(HUGGINGFACE_DIR))
+        if "mistral" in model_id.lower():
+            from huggingface_hub import hf_hub_download
+            hf_hub_download(repo_id="mistralai/Mistral-Small-3.1-24B-Instruct-2503", filename="params.json", local_dir=HUGGINGFACE_DIR)
+
+    try:
+        print(f"Initializing vLLM: {model_id}...")
+        llm = LLM(
+            model=HUGGINGFACE_DIR,
+            #model= model_id if "mistral" in model_id.lower() else HUGGINGFACE_DIR,
+            dtype=semantic_similarity_args.get("dtype", "auto"),
+            trust_remote_code=trust_remote_code,
+            tensor_parallel_size=device_count(),
+            enforce_eager=semantic_similarity_args.get("enforce_eager", True),
+            enable_chunked_prefill=semantic_similarity_args.get("enable_chunked_prefill", True),
+            max_model_len=max_model_len,
+            load_format="mistral" if "mistral" in model_id.lower() else "auto",
+            config_format="mistral" if "mistral" in model_id.lower() else "auto",
+            tokenizer_mode="mistral" if "mistral" in model_id.lower() else "auto"
+        )
+        print("Completed the model initialization")
+        print(">>> Running vLLM generation...")
+        outputs = llm.chat(messages=all_conversations, sampling_params=sampling_params)
+    except Exception as e:
+        print(f"Error initializing LLM: {e}")
+        raise
+    return all_conversations, outputs
+
+
+def main(configurations=None, args=None):
+    if clearml_available:
+        task = Task.current_task()
+        args = task.get_parameters_as_dict(cast=True)["Args"]
+        clearml_model = parse_argument(args["clearml_model"], bool)
+    else:
+        args = args["Args"]
+        clearml_model = False
+
+    # Parse arguments
+    force_download = parse_argument(args["force_download"], bool)
+    trust_remote_code = parse_argument(args["trust_remote_code"], bool)
+    model_id = parse_argument(args["model_id"], str)
+    max_model_len = parse_argument(args["max_model_len"], int)
+    max_new_tokens = parse_argument(args["max_new_tokens"], int)
+    dataset_args = flatten_nested_dict(parse_argument(args["dataset_args"], dict))
+    semantic_similarity_args = args.get("semantic_similarity_args", None)
+    tags = args.get("tags", None)
+
+    all_conversations, outputs = semantic_similarity_generate_main(
+        model_id,
+        trust_remote_code,
+        dataset_args,
+        semantic_similarity_args,
+        max_model_len,
+        max_new_tokens,
+        clearml_model,
+    )
+
+    OUTPUT_FILE = os.path.join(RESULTS_DIR, f"{model_id.replace('/', '_')}.jsonl")
+    print(">>> Writing outputs to file...")
+    with open(OUTPUT_FILE, "w") as fout:
+        for idx, (prompt, output) in enumerate(zip(all_conversations, outputs)):
+            response = output.outputs[0].text.strip()
+            fout.write(json.dumps({
+                "index": idx,
+                "prompt": prompt,
+                "response": response
+            }) + "\n")
+
+    print(f">>> Completed. Saved {len(outputs)} outputs to {OUTPUT_FILE}")
+
+    if clearml_available:
+        task.upload_artifact("jsonl_output", OUTPUT_FILE)
+
+if __name__ == '__main__':
+    main()
diff --git a/src/automation/tasks/scripts/semantic_similarity_score_script.py b/src/automation/tasks/scripts/semantic_similarity_score_script.py
new file mode 100644
index 00000000..1b447104
--- /dev/null
+++ b/src/automation/tasks/scripts/semantic_similarity_score_script.py
@@ -0,0 +1,159 @@
+import json
+import os
+from automation.utils import parse_argument
+
+try:
+    from clearml import OutputModel, Task
+    clearml_available = True
+except ImportError:
+    clearml_available = False
+
+SCORING_DIR = os.path.join(os.getcwd(), "scoresdirectory")
+os.makedirs(SCORING_DIR, exist_ok=True)
+
+def semantic_similarity_score_main(
+    reference_file,
+    candidate_file,
+    sts_model_id,
+    rouge_scores,
+    bert_score_limit,
+    rouge1_score_limit,
+    sts_score_limit,
+):
+    from bert_score import score
+    from rouge_score import rouge_scorer
+    from sentence_transformers import SentenceTransformer, util
+
+    # Load reference and candidate data
+    with open(reference_file, "r") as f_ref, open(candidate_file, "r") as f_cand:
+        reference_data = [json.loads(line) for line in f_ref]
+        candidate_data = [json.loads(line) for line in f_cand]
+
+    assert len(reference_data) == len(candidate_data), "Mismatched number of entries!"
+
+    # Extract answers
+    references = [ref.get("output") or ref["response"] for ref in reference_data]
+    candidates = [cand["response"] for cand in candidate_data]
+
+    # Load models
+    sts_model = SentenceTransformer(sts_model_id)
+    rouge = rouge_scorer.RougeScorer(rouge_scores, use_stemmer=True)
+
+    # Compute BERTScore
+    _, _, f1_scores = score(candidates, references, lang="en", verbose=False)
+    #all_bert_f1 = [ f1.item() for f1 in f1_scores ]
+
+    # Evaluate metrics
+    all_rouge1_f1, all_rougeL_f1, all_sts, all_bert_f1 = [], [], [], []
+    low_score_indices = []
+
+    for i, (ref, cand, f1) in enumerate(zip(references, candidates, f1_scores)):
+        emb_ref = sts_model.encode(ref, convert_to_tensor=True)
+        emb_cand = sts_model.encode(cand, convert_to_tensor=True)
+        raw_sts = util.cos_sim(emb_cand, emb_ref).item()
+        sts = (raw_sts + 1) / 2  # Normalize to [0, 1]
+        all_sts.append(sts)
+
+        rouge_result = rouge.score(ref, cand)
+        rouge1 = rouge_result["rouge1"].fmeasure
+        rougeL = rouge_result["rougeL"].fmeasure
+        all_rouge1_f1.append(rouge1)
+        all_rougeL_f1.append(rougeL)
+
+        all_bert_f1.append(f1.item())
+
+        if f1 < bert_score_limit or rouge1 < rouge1_score_limit or sts < sts_score_limit:
+            low_score_indices.append(i)
+
+    # Compute averages
+    num_samples = len(references)
+    avg_bert = sum(all_bert_f1) / num_samples
+    avg_rouge1 = sum(all_rouge1_f1) / num_samples
+    avg_rougeL = sum(all_rougeL_f1) / num_samples
+    avg_sts = sum(all_sts) / num_samples
+    return avg_bert, avg_rouge1, avg_rougeL, avg_sts, low_score_indices
+
+def main(configurations=None, args=None):
+    if clearml_available:
+        task = Task.current_task()
+        args = task.get_parameters_as_dict(cast=True)["Args"]
+    else:
+        args = args["Args"]
+
+    # Parse arguments
+    clearml_model = parse_argument(args["clearml_model"], bool)
+    force_download = parse_argument(args["force_download"], bool)
+    trust_remote_code = parse_argument(args["trust_remote_code"], bool)
+    low_score_threshold_args = args.get("low_score_threshold_args", {})
+    sts_model_id = args.get("sts_model_id")
+    rouge_scores = args.get("rouge_scores", ["rouge1", "rougeL"])
+    tags = args.get("tags", None)
+
+    print(args)
+    print(low_score_threshold_args)
+
+    if clearml_available:
+        reference_model_project_name = parse_argument(args["reference_model_project_name"], str)
+        candidate_model_project_name = parse_argument(args["candidate_model_project_name"], str)
+        candidate_model_task_name = parse_argument(args["candidate_model_task_name"], str)
+        reference_model_task_name = parse_argument(args["reference_model_task_name"], str)
+        reference_task = Task.query_tasks(project_name=reference_model_project_name, task_name=reference_model_task_name, task_filter={'order_by': ['-last_update'], 'status': ['completed']})
+        reference_task = Task.get_task(reference_task[0])
+        reference_file = reference_task.artifacts['jsonl_output'].get_local_copy()
+
+        candidate_task = Task.query_tasks(project_name=candidate_model_project_name, task_name=candidate_model_task_name, task_filter={'order_by': ['-last_update'], 'status': ['completed']})
+        candidate_task = Task.get_task(candidate_task[0])
+        candidate_file = candidate_task.artifacts['jsonl_output'].get_local_copy()
+    else:
+        ref_model_jsonl = args.get("ref_model_jsonl")
+        cand_model_jsonl = args.get("cand_model_jsonl")
+        reference_file = os.path.join(SCORING_DIR, ref_model_jsonl)
+        candidate_file = os.path.join(SCORING_DIR, cand_model_jsonl)
+
+    bert_score_limit = low_score_threshold_args.get("f1", 0.75)
+    rouge1_score_limit = low_score_threshold_args.get("rouge1", 0.6)
+    sts_score_limit = low_score_threshold_args.get("sts", 0.75)
+
+    avg_bert, avg_rouge1, avg_rougeL, avg_sts, low_score_indices = semantic_similarity_score_main(
+        reference_file,
+        candidate_file,
+        sts_model_id,
+        rouge_scores,
+        bert_score_limit,
+        rouge1_score_limit,
+        sts_score_limit,
+    )
+    # Print summary
+    print("\n=== Averages (for Google Sheets) ===")
+    print("BERTScore F1 | ROUGE-1 F1 | ROUGE-L F1 | STS CosSim")
+    print(f"{avg_bert:.3f} | {avg_rouge1:.3f} | {avg_rougeL:.3f} | {avg_sts:.3f}")
+
+    print(f"\n=== Low-score indices (BERT < {bert_score_limit}, ROUGE-1 < {rouge1_score_limit}, STS < {sts_score_limit}) ===")
+    print(low_score_indices)
+
+    data = {
+        "BERTScore F1": f"{avg_bert:.3f}",
+        "ROUGE-1 F1": f"{avg_rouge1:.3f}",
+        "ROUGE-L F1": f"{avg_rougeL:.3f}",
+        "STS CosSim": f"{avg_sts:.3f}",
+    }
+
+    from pathlib import Path
+
+    reference_file = Path(reference_file).stem.lower()
+    candidate_file = Path(candidate_file).stem.lower()
+    out_filename = f"scores_{reference_file}__vs__{candidate_file}.txt"
+    out_filename = os.path.join(SCORING_DIR, out_filename)
+
+    # Save results
+    with open(out_filename, "w") as file:
+        json.dump(data, file, indent=4)
+
+    print(f"\nSaved results to {out_filename}")
+    if clearml_available:
+        task.upload_artifact("scores", data)
+        task.upload_artifact("outscores", out_filename)
+        print("Pushing clearml artifact")
+
+if __name__ == '__main__':
+    main()
diff --git a/src/automation/tasks/semantic_similarity_generate.py b/src/automation/tasks/semantic_similarity_generate.py
new file mode 100644
index 00000000..87ef543e
--- /dev/null
+++ b/src/automation/tasks/semantic_similarity_generate.py
@@ -0,0 +1,117 @@
+from automation.tasks.base_task import BaseTask
+from automation.configs import DEFAULT_DOCKER_IMAGE
+from typing import Union, List, Optional, Sequence, Any, Callable
+import os
+import yaml
+
+class SemanticSimilarityGenerateTask(BaseTask):
+    task_packages = [
+        "vllm",
+        "hf_xet",
+        "pyzmq",
+        #"vllm==0.10.1.1",
+    ]
+
+    def __init__(
+        self,
+        project_name: str,
+        task_name: str,
+        model_id: str,
+        branch: str,
+        max_new_tokens: int,
+        max_model_len: int,
+        dataset_args: Optional[dict]=None,
+        semantic_similarity_args: Optional[dict]=None,
+        docker_image: str=DEFAULT_DOCKER_IMAGE,
+        packages: Optional[Sequence[str]]=None,
+        clearml_model: bool=False,
+        force_download: bool=False,
+        save_directory: str="output",
+        trust_remote_code: bool=False,
+        tags: Optional[Union[str, List[str]]]=None,
+        task_type: str="training",
+        config: Optional[str]=None,
+    ):
+
+        # Process config
+        config_kwargs = self.process_config(config)
+
+        # Set packages, taking into account default packages
+        # for the SemanticSimilarityGenerateTask and packages set in the config
+        if packages is not None:
+            packages = list(set(packages + self.task_packages))
+        else:
+            packages = self.task_packages
+
+        if "packages" in config_kwargs:
+            packages = list(set(packages + config_kwargs.pop("packages")))
+
+        # Initialize base parameters
+        super().__init__(
+            project_name=project_name,
+            task_name=task_name,
+            branch=branch,
+            docker_image=docker_image,
+            packages=packages,
+            task_type=task_type,
+        )
+
+        if dataset_args is None:
+            self.dataset_args = config_kwargs.pop("dataset_args", None)
+        else:
+            config_dataset_args = config_kwargs.pop("dataset_args", {})
+            config_dataset_args.update(dataset_args)
+            self.dataset_args = config_dataset_args
+
+        if semantic_similarity_args is None:
+            self.semantic_similarity_args = config_kwargs.pop("semantic_similarity_args", None)
+        else:
+            config_semantic_similarity_args = config_kwargs.pop("semantic_similarity_args", {})
+            config_semantic_similarity_args.update(semantic_similarity_args)
+            self.semantic_similarity_args = config_semantic_similarity_args
+
+        self.max_new_tokens = config_kwargs.pop("max_new_tokens", max_new_tokens)
+        self.max_model_len = config_kwargs.pop("max_model_len", max_model_len)
+        self.trust_remote_code = config_kwargs.pop("trust_remote_code", trust_remote_code)
+
+        if tags is not None:
+            tags = list(set(config_kwargs.pop("tags", []) + ([tags] if isinstance(tags, str) else list(tags))))
+        else:
+            tags = config_kwargs.pop("tags", None)
+        self.tags = tags
+
+        # Store class attributes
+        self.model_id = model_id
+        self.clearml_model = clearml_model
+        self.force_download = force_download
+        self.save_directory = save_directory
+        self.script_path = os.path.join(".", "src", "automation", "tasks", "scripts", "semantic_similarity_generate_script.py")
+
+
+    def script(self, configurations, args):
+        from automation.tasks.scripts.semantic_similarity_generate_script import main
+        main(configurations, args)
+
+
+    def get_configurations(self):
+        configs = {}
+        return configs
+
+
+    def get_arguments(self):
+        return {
+            "Args": {
+                "model_id": self.model_id,
+                "dataset_args": self.dataset_args,
+                "semantic_similarity_args": self.semantic_similarity_args,
+                "clearml_model": self.clearml_model,
+                "force_download": self.force_download,
+                "save_directory": self.save_directory,
+                "max_new_tokens": self.max_new_tokens,
+                "max_model_len": self.max_model_len,
+                "trust_remote_code": self.trust_remote_code,
+                "tags": self.tags,
+            },
+        }
+
+
diff --git a/src/automation/tasks/semantic_similarity_score.py b/src/automation/tasks/semantic_similarity_score.py
new file mode 100644
index 00000000..0194b6b9
--- /dev/null
+++ b/src/automation/tasks/semantic_similarity_score.py
@@ -0,0 +1,122 @@
+from automation.tasks.base_task import BaseTask
+from automation.configs import DEFAULT_DOCKER_IMAGE
+from typing import Union, List, Optional, Sequence, Any, Callable
+import os
+import yaml
+
+class SemanticSimilarityScoreTask(BaseTask):
+    task_packages = [
+        "hf_xet",
+        "pyzmq",
+    ]
+
+    def __init__(
+        self,
+        project_name: str,
+        task_name: str,
+        reference_model_project_name: str,
+        candidate_model_project_name: str,
+        reference_model_task_name: str,
+        candidate_model_task_name: str,
+        sts_model_id: str,
+        branch: str,
+        rouge_scores: Optional[List[str]]=None,
+        low_score_threshold_args: Optional[dict]=None,
+        docker_image: str=DEFAULT_DOCKER_IMAGE,
+        packages: Optional[Sequence[str]]=None,
+        clearml_model: bool=False,
+        force_download: bool=False,
+        save_directory: str="output",
+        trust_remote_code: bool=False,
+        tags: Optional[Union[str, List[str]]]=None,
+        task_type: str="training",
+        config: Optional[str]=None,
+    ):
+
+        # Process config
+        config_kwargs = self.process_config(config)
+
+        # Set packages, taking into account default packages
+        # for the SemanticSimilarityScoreTask and packages set in the config
+        if packages is not None:
+            packages = list(set(packages + self.task_packages))
+        else:
+            packages = self.task_packages
+
+        if "packages" in config_kwargs:
+            packages = list(set(packages + config_kwargs.pop("packages")))
+
+        # Initialize base parameters
+        super().__init__(
+            project_name=project_name,
+            task_name=task_name,
+            branch=branch,
+            docker_image=docker_image,
+            packages=packages,
+            task_type=task_type,
+        )
+
+
+        if rouge_scores is None:
+            self.rouge_scores = config_kwargs.pop("rouge_scores", None)
+        else:
+            config_rouge_scores = config_kwargs.pop("rouge_scores", [])
+            config_rouge_scores += rouge_scores
+            self.rouge_scores = config_rouge_scores
+
+        if low_score_threshold_args is None:
+            self.low_score_threshold_args = config_kwargs.pop("low_score_threshold_args", None)
+        else:
+            config_low_score_threshold_args = config_kwargs.pop("low_score_threshold_args", {})
+            config_low_score_threshold_args.update(low_score_threshold_args)
+            self.low_score_threshold_args = config_low_score_threshold_args
+
+        self.trust_remote_code = config_kwargs.pop("trust_remote_code", trust_remote_code)
+
+        if tags is not None:
+            tags = list(set(config_kwargs.pop("tags", []) + ([tags] if isinstance(tags, str) else list(tags))))
+        else:
+            tags = config_kwargs.pop("tags", None)
+        self.tags = tags
+
+        # Store class attributes
+        self.reference_model_project_name = reference_model_project_name
+        self.candidate_model_project_name = candidate_model_project_name
+        self.reference_model_task_name = reference_model_task_name
+        self.candidate_model_task_name = candidate_model_task_name
+        self.sts_model_id = sts_model_id
+        self.clearml_model = clearml_model
+        self.force_download = force_download
+        self.save_directory = save_directory
+        self.script_path = os.path.join(".", "src", "automation", "tasks", "scripts", "semantic_similarity_score_script.py")
+
+
+    def script(self, configurations, args):
+        from automation.tasks.scripts.semantic_similarity_score_script import main
+        main(configurations, args)
+
+
+    def get_configurations(self):
+        configs = {}
+        return configs
+
+
+    def get_arguments(self):
+        return {
+            "Args": {
+                "reference_model_project_name": self.reference_model_project_name,
+                "candidate_model_project_name": self.candidate_model_project_name,
+                "reference_model_task_name": self.reference_model_task_name,
+                "candidate_model_task_name": self.candidate_model_task_name,
+                "sts_model_id": self.sts_model_id,
+                "rouge_scores": self.rouge_scores,
+                "low_score_threshold_args": self.low_score_threshold_args,
+                "clearml_model": self.clearml_model,
+                "force_download": self.force_download,
+                "save_directory": self.save_directory,
+                "trust_remote_code": self.trust_remote_code,
+                "tags": self.tags,
+            },
+        }
+
+
diff --git a/src/automation/utils.py b/src/automation/utils.py
index 7b4263f9..c90d3ff9 100644
--- a/src/automation/utils.py
+++ b/src/automation/utils.py
@@ -177,4 +177,11 @@ def merge_dicts(d1, d2):
         else:
             raise ValueError(f"{key} already defined. It can't be defined again.")
     d1.update(d2)
-    return d1
\ No newline at end of file
+    return d1
+
+def flatten_nested_dict(nested_dataset_args):
+    flattened_dict = {}
+    for org, datasets in nested_dataset_args.items():
+        for dataset, count in datasets.items():
+            flattened_dict[f"{org}/{dataset}"] = count
+    return flattened_dict
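
Note (not part of the diff): a minimal sketch of driving the scoring script locally when ClearML is unavailable. In that branch the script reads two generation JSONL files (one line per sample with "index", "prompt", and "response" fields, as written by the generate script) from ./scoresdirectory. The file names below are hypothetical placeholders, and the boolean handling of parse_argument is assumed to pass plain booleans through.

# Local smoke test for semantic_similarity_score_script without ClearML.
# Assumes both JSONL files already exist in ./scoresdirectory; names are placeholders.
from automation.tasks.scripts.semantic_similarity_score_script import main

main(args={
    "Args": {
        "clearml_model": False,
        "force_download": False,
        "trust_remote_code": False,
        "sts_model_id": "all-MiniLM-L6-v2",
        "rouge_scores": ["rouge1", "rougeL"],
        "low_score_threshold_args": {"f1": 0.79, "rouge1": 0.65, "sts": 0.71},
        "ref_model_jsonl": "reference_generations.jsonl",   # placeholder file name
        "cand_model_jsonl": "candidate_generations.jsonl",  # placeholder file name
    }
})

The averages are printed to stdout and written to scores_<reference>__vs__<candidate>.txt inside ./scoresdirectory; indices of samples falling below any threshold are printed for manual inspection.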