Semantic similarity #13
Open
Chibukach wants to merge 96 commits into main from semantic_similarity
Changes from all commits (96):
757f02b base semantic gen
95d93d2 base requirements
8b7bf3c simple package
5e5276c base method added
b8ec696 remove missing libraru inserts
113f0df clean up variables
aa2de29 clean up input variables
02d6087 clean up prompt logs
9faf81d fix device count
8e407a0 add semantic_similarity_args
82eaa9a update model input vars
70d0ad4 added more log
bf7d817 initialize vllm
1d85fa6 download model beforehand
6be969d template score
0798f61 update task name
ca9ff84 rouge score array
21f54d0 base scoring script
37b82d9 remove snapshot downlad
7e13853 test vllm server
1f30150 add requests query
a153996 clean libs
92fe4c9 fix parse issue
a8751c2 use start_vllm_server
5b9539e test llm generate
2a73701 updated vllm server
a8547f5 add debug logging level
274948b base LLM
051afae try except for vllm
d51d3cf use vllmm server
e9cd534 retry snapshot download
ab7fe4a snapshot down
bc2897b snapshot with download_dir
78b005a add model dir
aac6d69 add dtype
73abff7 model dir
41bc34a add trust remote code
d37541a download safetensors
67fee2c move vllm server up
e59e1df use the same dir
bef036e redo snapshot download
044cef7 trigger
77bc96d combined
67fca56 use vllm server
bf80fd4 add process tree import
f5a21f5 add clearml conditional
b471178 add task import
08914e5 retrieve current task
2c3a299 output server logs
c1a0b3c print vllm command
f83b044 output as json
d9b447a output artifact
8ebd724 retry with python llm interface
b9ae4c1 reference the downloaded model
ecf9f4b add results directory creation
05b8f0f fix download and read
389a5d8 clean up repo
5e84115 clean up scoring and remove hardcoding
64c5369 add low score indices
bed4991 add f1 score to enum
29b650c simplify output path
b84a102 add examples and clean up
c4e1aea clean up example
7cd5a3a add scoring args dict
d5e4210 add support for variable score limits
9d349e8 clearml get model_id
98609a2 add clearml model import
391ecc5 check for clearml model
23a5f95 reference huggingface dir
5feeff7 implement semantic feedback
4768cf8 add db path debug
2b3a4f7 more debug
6dcfd74 debug dataset_args
11ba9fc hardcode dataset args
0da978e update examples
44f0c62 moved from utils
3b09a3f dataset args through parse
640ae0a add more dataset arg prints
ae9d928 add dict flattening
e74e0d6 added dictionary flattening
0ee0288 update prompt to chat
8c71fac string output prompts
1ae047a moved from prompt to conversation
e382646 retry with tqdm
a3ecdb5 re-add messages list
9f56b0b clean up convos
57a2596 add debug to know which model is being initialised
b36b998 add mistral exception
91009b9 allow existing results dir
3c400ec snapshot download only
e21f100 add tokenizer mode for mistral
96e33be llm direct from model id
ff351b9 add print
48c0e6d fix format
ce27ec4 add list dir
aeaca8f add params download
@@ -0,0 +1,17 @@
from automation.tasks import SemanticSimilarityGenerateTask

task = SemanticSimilarityGenerateTask(
    project_name="semantic_similarity_debug",
    task_name="semantic_generation_qwen3_14b_w4a16_feedback",
    # task_name="semantic_generation_qwen3_14b_feedback",
    branch="semantic_similarity",
    packages=["huggingface-hub==0.34.3", "triton==3.3.1", "vllm==0.10.1.1"],
    dataset_args={"tatsu-lab/alpaca": 300, "garage-bAInd/Open-Platypus": 310, "allenai/tulu-3-sft-mixture": 320},
    model_id="Qwen/Qwen3-14B",
    max_new_tokens=1024,
    max_model_len=4096,
    # Key uses an underscore so the generate script's
    # semantic_similarity_args.get("enable_chunked_prefill", ...) actually finds it.
    semantic_similarity_args={"enable_chunked_prefill": True, "enforce_eager": True, "dtype": "auto", "device_map": "auto", "temperature": 0.0},
)

task.execute_remotely("oneshot-a100x1")
@@ -0,0 +1,20 @@
from automation.tasks import SemanticSimilarityScoreTask

task = SemanticSimilarityScoreTask(
    project_name="semantic_similarity_debug",
    # task_name="semantic_scoring_14b",
    task_name="semantic_scoring_4b",
    branch="semantic_similarity",
    packages=["huggingface-hub==0.34.3", "networkx==3.4.2", "datasets==4.2.0", "rouge_score==0.1.2", "bert-score==0.3.13", "sentence-transformers==5.1.1", "matplotlib"],
    reference_model_project_name="semantic_similarity_debug",
    candidate_model_project_name="semantic_similarity_debug",
    reference_model_task_name="semantic_generation_qwen3_14b_feedback",
    # reference_model_task_name="semantic_generation_qwen3_14b_base",
    candidate_model_task_name="semantic_generation_qwen3_14b_w4a16_feedback",
    # candidate_model_task_name="semantic_generation_qwen3_14b_w4a16",
    sts_model_id="all-MiniLM-L6-v2",
    rouge_scores=["rouge1", "rougeL"],
    low_score_threshold_args={"f1": 0.79, "rouge1": 0.65, "sts": 0.71},
)

task.execute_remotely("oneshot-a100x1")
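
The scoring implementation itself is not part of this diff view, but as a rough per-pair illustration of the three metrics configured above (BERTScore F1, ROUGE, and STS cosine similarity), using only the packages pinned in packages, something like the following could apply. This is a minimal sketch under those assumptions, not the task's actual code; score_pair and its return keys are hypothetical names chosen to mirror low_score_threshold_args.

from bert_score import score as bert_score
from rouge_score import rouge_scorer
from sentence_transformers import SentenceTransformer, util

# Reuse heavyweight models across calls.
ROUGE = rouge_scorer.RougeScorer(["rouge1", "rougeL"], use_stemmer=True)
STS_MODEL = SentenceTransformer("all-MiniLM-L6-v2")  # the configured sts_model_id

def score_pair(reference: str, candidate: str) -> dict:
    # ROUGE F-measures for the metrics listed in rouge_scores.
    rouge = {name: s.fmeasure for name, s in ROUGE.score(reference, candidate).items()}
    # BERTScore F1 (assumed to correspond to the "f1" threshold key).
    _, _, f1 = bert_score([candidate], [reference], lang="en")
    # Cosine similarity of sentence embeddings.
    emb = STS_MODEL.encode([reference, candidate], convert_to_tensor=True)
    sts = util.cos_sim(emb[0], emb[1]).item()
    return {"f1": f1.item(), "sts": sts, **rouge}

A pair whose value falls below the matching entry in low_score_threshold_args (for example sts < 0.71) would then presumably be reported among the low score indices.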
@@ -0,0 +1,22 @@
def make_alpaca_prompt(sample):
    instruction = sample["instruction"].strip()
    input_text = sample.get("input", "").strip()

    if input_text == "":
        messages = [
            {
                "role": "user",
                "content": instruction,
            }
        ]
    else:
        messages = [
            {
                "role": "user",
                "content": f"{instruction}\n{input_text}",
            }
        ]

    return messages
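
As a quick illustration (hypothetical sample, not taken from the dataset), a record with a non-empty input collapses to a single user message:

make_alpaca_prompt({"instruction": "Translate to French.", "input": "Good morning"})
# [{'role': 'user', 'content': 'Translate to French.\nGood morning'}]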
@@ -0,0 +1,12 @@
import json

def make_default_prompt(sample):
    messages = [
        {
            "role": "user",
            "content": json.dumps(sample),
        }
    ]

    # Unlike the other prompt builders, this returns a flattened string
    # rather than a list of chat messages.
    prompt = "\n".join([f"{msg['role']}: {msg['content']}" for msg in messages])

    return prompt
@@ -0,0 +1,21 @@
def make_openplatypus_prompt(sample):
    instruction = sample["instruction"].strip()
    input_text = sample.get("input", "").strip()

    if input_text == "":
        messages = [
            {
                "role": "user",
                "content": instruction,
            }
        ]
    else:
        messages = [
            {
                "role": "user",
                "content": f"{instruction}\n{input_text}",
            }
        ]

    return messages
@@ -0,0 +1,3 @@
def make_tulu_prompt(sample):
    return sample["messages"]
src/automation/tasks/scripts/semantic_similarity_generate_script.py (150 additions, 0 deletions)
@@ -0,0 +1,150 @@
import json
import os
from torch.cuda import device_count
from datasets import load_dataset
from vllm import LLM, SamplingParams

from automation.utils import parse_argument, flatten_nested_dict
from automation.datasets.tulu import make_tulu_prompt
from automation.datasets.openplatypus import make_openplatypus_prompt
from automation.datasets.alpaca import make_alpaca_prompt
from automation.datasets.defaults import make_default_prompt

try:
    from clearml import Task, Model
    clearml_available = True
except ImportError:
    clearml_available = False

RESULTS_DIR = os.path.join(os.getcwd(), "results")
os.makedirs(RESULTS_DIR, exist_ok=True)


def semantic_similarity_generate_main(
    model_id,
    trust_remote_code,
    dataset_args,
    semantic_similarity_args,
    max_model_len,
    max_new_tokens,
    clearml_model,
):
    from collections import defaultdict
    from huggingface_hub import snapshot_download

    all_conversations = []
    all_samples_dict = defaultdict(list)

    print(">>> Loading dataset...")
    for dataset_path, num_samples_per_dataset in dataset_args.items():
        dataset_name = dataset_path.split("/")[1].lower()
        print(f">>> Loading dataset {dataset_name}...")
        dataset = load_dataset(dataset_path, split=f"train[:{int(num_samples_per_dataset)}]")
        all_samples_dict[dataset_name].extend(dataset)

    sorted_all_samples_dict = dict(sorted(all_samples_dict.items()))

    for dataset_name, dataset_samples in sorted_all_samples_dict.items():
        print(f">>> Loading values for {dataset_name}...")
        for sample in dataset_samples:
            if dataset_name == "alpaca":
                prompt = make_alpaca_prompt(sample)
            elif dataset_name == "open-platypus":
                prompt = make_openplatypus_prompt(sample)
            elif dataset_name == "tulu-3-sft-mixture":
                prompt = make_tulu_prompt(sample)
            else:
                print("Using default prompt")
                prompt = make_default_prompt(sample)
            all_conversations.append(prompt)

    print("Define sampling parameters")
    sampling_params = SamplingParams(
        temperature=semantic_similarity_args.get("temperature", 0.0),
        max_tokens=max_new_tokens,
    )

    HUGGINGFACE_DIR = "/home"
    if clearml_model:
        HUGGINGFACE_DIR = Model(model_id).get_local_copy()
    else:
        print("Download snapshot")
        snapshot_download(repo_id=model_id, local_dir=HUGGINGFACE_DIR)
        print(os.listdir(HUGGINGFACE_DIR))
        if "mistral" in model_id.lower():
            from huggingface_hub import hf_hub_download
            hf_hub_download(repo_id="mistralai/Mistral-Small-3.1-24B-Instruct-2503", filename="params.json", local_dir=HUGGINGFACE_DIR)

    # Mistral models ship their own weight/config format and tokenizer;
    # everything else uses vLLM's defaults.
    is_mistral = "mistral" in model_id.lower()
    try:
        print(f"Initializing vLLM: {model_id}...")
        llm = LLM(
            model=HUGGINGFACE_DIR,
            dtype=semantic_similarity_args.get("dtype", "auto"),
            trust_remote_code=trust_remote_code,
            tensor_parallel_size=device_count(),
            enforce_eager=semantic_similarity_args.get("enforce_eager", True),
            enable_chunked_prefill=semantic_similarity_args.get("enable_chunked_prefill", True),
            max_model_len=max_model_len,
            load_format="mistral" if is_mistral else "auto",
            config_format="mistral" if is_mistral else "auto",
            tokenizer_mode="mistral" if is_mistral else "auto",
        )
        print("Completed the model initialization")
        print(">>> Running vLLM generation...")
        outputs = llm.chat(messages=all_conversations, sampling_params=sampling_params)
    except Exception as e:
        # Re-raise so the caller never reaches the return with `outputs` unbound.
        print(f"Error initializing LLM: {e}")
        raise

    return all_conversations, outputs


def main(configurations=None, args=None):
    if clearml_available:
        task = Task.current_task()
        args = task.get_parameters_as_dict(cast=True)["Args"]
        clearml_model = parse_argument(args["clearml_model"], bool)
    else:
        args = args["Args"]
        clearml_model = False

    # Parse arguments
    force_download = parse_argument(args["force_download"], bool)
    trust_remote_code = parse_argument(args["trust_remote_code"], bool)
    model_id = parse_argument(args["model_id"], str)
    max_model_len = parse_argument(args["max_model_len"], int)
    max_new_tokens = parse_argument(args["max_new_tokens"], int)
    dataset_args = flatten_nested_dict(parse_argument(args["dataset_args"], dict))
    # Fall back to an empty dict so .get() lookups below never hit None.
    semantic_similarity_args = args.get("semantic_similarity_args", None) or {}
    tags = args.get("tags", None)

    all_conversations, outputs = semantic_similarity_generate_main(
        model_id,
        trust_remote_code,
        dataset_args,
        semantic_similarity_args,
        max_model_len,
        max_new_tokens,
        clearml_model,
    )

    OUTPUT_FILE = os.path.join(RESULTS_DIR, f"{model_id.replace('/', '_')}.jsonl")
    print(">>> Writing outputs to file...")
    with open(OUTPUT_FILE, "w") as fout:
        for idx, (prompt, output) in enumerate(zip(all_conversations, outputs)):
            response = output.outputs[0].text.strip()
            fout.write(json.dumps({
                "index": idx,
                "prompt": prompt,
                "response": response,
            }) + "\n")

    print(f">>> Completed. Saved {len(outputs)} outputs to {OUTPUT_FILE}")

    if clearml_available:
        task.upload_artifact("jsonl_output", OUTPUT_FILE)


if __name__ == '__main__':
    main()
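
Each line of the resulting JSONL artifact pairs one conversation with its generation. Given the writer above, a record would look roughly like this (illustrative values, not real output):

{"index": 0, "prompt": [{"role": "user", "content": "Give three tips for staying healthy."}], "response": "1. Eat a balanced diet ..."}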
Why are we using the LLM class instead of vllm serve?
The main branch has an old src/automation/vllm/server.py file with the class VLLMServer, but other branches use start_vllm_server. Also, shouldn't the output of the LLM class be identical to the vllm serve API endpoint?
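
For reference, a minimal sketch of the two approaches under discussion, assuming stock vLLM APIs: the offline LLM class runs the engine in-process, while vllm serve exposes the same engine behind an OpenAI-compatible HTTP endpoint. With identical sampling parameters the generations should match, up to nondeterminism from batching and kernels.

from vllm import LLM, SamplingParams
import requests

conversation = [{"role": "user", "content": "Hello"}]

# Offline: in-process engine, no server round-trip.
llm = LLM(model="Qwen/Qwen3-14B")
params = SamplingParams(temperature=0.0, max_tokens=1024)
offline_text = llm.chat(messages=[conversation], sampling_params=params)[0].outputs[0].text

# Served: the same model behind `vllm serve Qwen/Qwen3-14B`,
# queried through the OpenAI-compatible chat completions endpoint.
resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "Qwen/Qwen3-14B",
        "messages": conversation,
        "temperature": 0.0,
        "max_tokens": 1024,
    },
)
served_text = resp.json()["choices"][0]["message"]["content"]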