Changes from all commits
96 commits
757f02b
base semantic gen
Oct 15, 2025
95d93d2
base requirements
Oct 15, 2025
8b7bf3c
simple package
Oct 15, 2025
5e5276c
base method added
Oct 15, 2025
b8ec696
remove missing library inserts
Oct 15, 2025
113f0df
clean up variables
Oct 15, 2025
aa2de29
clean up input variables
Oct 15, 2025
02d6087
clean up prompt logs
Oct 15, 2025
9faf81d
fix device count
Oct 15, 2025
8e407a0
add semantic_similarity_args
Oct 15, 2025
82eaa9a
update model input vars
Oct 16, 2025
70d0ad4
added more log
Oct 16, 2025
bf7d817
initialize vllm
Oct 16, 2025
1d85fa6
download model beforehand
Oct 16, 2025
6be969d
template score
Oct 16, 2025
0798f61
update task name
Oct 16, 2025
ca9ff84
rouge score array
Oct 16, 2025
21f54d0
base scoring script
Oct 16, 2025
37b82d9
remove snapshot download
Oct 16, 2025
7e13853
test vllm server
Oct 16, 2025
1f30150
add requests query
Oct 16, 2025
a153996
clean libs
Oct 16, 2025
92fe4c9
fix parse issue
Oct 16, 2025
a8751c2
use start_vllm_server
Oct 16, 2025
5b9539e
test llm generate
Oct 16, 2025
2a73701
updated vllm server
Oct 16, 2025
a8547f5
add debug logging level
Oct 16, 2025
274948b
base LLM
Oct 16, 2025
051afae
try except for vllm
Oct 16, 2025
d51d3cf
use vllm server
Oct 16, 2025
e9cd534
retry snapshot download
Oct 17, 2025
ab7fe4a
snapshot download
Oct 17, 2025
bc2897b
snapshot with download_dir
Oct 17, 2025
78b005a
add model dir
Oct 17, 2025
aac6d69
add dtype
Oct 17, 2025
73abff7
model dir
Oct 17, 2025
41bc34a
add trust remote code
Oct 17, 2025
d37541a
download safetensors
Oct 17, 2025
67fee2c
move vllm server up
Oct 17, 2025
e59e1df
use the same dir
Oct 17, 2025
bef036e
redo snapshot download
Oct 17, 2025
044cef7
trigger
Oct 17, 2025
77bc96d
combined
Oct 17, 2025
67fca56
use vllm server
Oct 17, 2025
bf80fd4
add process tree import
Oct 17, 2025
f5a21f5
add clearml conditional
Oct 17, 2025
b471178
add task import
Oct 17, 2025
08914e5
retrieve current task
Oct 17, 2025
2c3a299
output server logs
Oct 17, 2025
c1a0b3c
print vllm command
Oct 17, 2025
f83b044
output as json
Oct 20, 2025
d9b447a
output artifact
Oct 20, 2025
8ebd724
retry with python llm interface
Oct 22, 2025
b9ae4c1
reference the downloaded model
Oct 23, 2025
ecf9f4b
add results directory creation
Oct 23, 2025
05b8f0f
fix download and read
Oct 23, 2025
389a5d8
clean up repo
Oct 23, 2025
5e84115
clean up scoring and remove hardcoding
Oct 23, 2025
64c5369
add low score indices
Oct 23, 2025
bed4991
add f1 score to enum
Oct 23, 2025
29b650c
simplify output path
Oct 23, 2025
b84a102
add examples and clean up
Oct 23, 2025
c4e1aea
clean up example
Oct 23, 2025
7cd5a3a
add scoring args dict
Oct 23, 2025
d5e4210
add support for variable score limits
Oct 23, 2025
9d349e8
clearml get model_id
Oct 28, 2025
98609a2
add clearml model import
Oct 28, 2025
391ecc5
check for clearml model
Oct 29, 2025
23a5f95
reference huggingface dir
Oct 29, 2025
5feeff7
implement semantic feedback
Oct 31, 2025
4768cf8
add db path debug
Oct 31, 2025
2b3a4f7
more debug
Oct 31, 2025
6dcfd74
debug dataset_args
Oct 31, 2025
11ba9fc
hardcode dataset args
Oct 31, 2025
0da978e
update examples
Oct 31, 2025
44f0c62
moved from utils
Oct 31, 2025
3b09a3f
dataset args through parse
Oct 31, 2025
640ae0a
add more dataset arg prints
Oct 31, 2025
ae9d928
add dict flattening
Oct 31, 2025
e74e0d6
added dictionary flattening
Nov 3, 2025
0ee0288
update prompt to chat
Nov 11, 2025
8c71fac
string output prompts
Nov 11, 2025
1ae047a
moved from prompt to conversation
Nov 12, 2025
e382646
retry with tqdm
Nov 12, 2025
a3ecdb5
re-add messages list
Nov 12, 2025
9f56b0b
clean up convos
Nov 12, 2025
57a2596
add debug to know which model is being initialised
Nov 13, 2025
b36b998
add mistral exception
Nov 13, 2025
91009b9
allow existing results dir
Nov 13, 2025
3c400ec
snapshot download only
Nov 13, 2025
e21f100
add tokenizer mode for mistral
Nov 13, 2025
96e33be
llm direct from model id
Nov 13, 2025
ff351b9
add print
Nov 13, 2025
48c0e6d
fix format
Nov 14, 2025
ce27ec4
add list dir
Nov 14, 2025
aeaca8f
add params download
Nov 14, 2025
17 changes: 17 additions & 0 deletions examples/semantic_similarity_generate.py
@@ -0,0 +1,17 @@
from automation.tasks import SemanticSimilarityGenerateTask

task = SemanticSimilarityGenerateTask(
    project_name="semantic_similarity_debug",
    task_name="semantic_generation_qwen3_14b_w4a16_feedback",
    #task_name="semantic_generation_qwen3_14b_feedback",
    branch="semantic_similarity",
    packages=["huggingface-hub==0.34.3", "triton==3.3.1", "vllm==0.10.1.1"],
    dataset_args={"tatsu-lab/alpaca": 300, "garage-bAInd/Open-Platypus": 310, "allenai/tulu-3-sft-mixture": 320},
    model_id="Qwen/Qwen3-14B",
    max_new_tokens=1024,
    max_model_len=4096,
    semantic_similarity_args={"enable_chunked_prefill": True, "enforce_eager": True, "dtype": "auto", "device_map": "auto", "temperature": 0.0},
)

task.execute_remotely("oneshot-a100x1")

20 changes: 20 additions & 0 deletions examples/semantic_similarity_score.py
@@ -0,0 +1,20 @@
from automation.tasks import SemanticSimilarityScoreTask

task = SemanticSimilarityScoreTask(
    project_name="semantic_similarity_debug",
    #task_name="semantic_scoring_14b",
    task_name="semantic_scoring_4b",
    branch="semantic_similarity",
    packages=["huggingface-hub==0.34.3", "networkx==3.4.2", "datasets==4.2.0", "rouge_score==0.1.2", "bert-score==0.3.13", "sentence-transformers==5.1.1", "matplotlib"],
    reference_model_project_name="semantic_similarity_debug",
    candidate_model_project_name="semantic_similarity_debug",
    reference_model_task_name="semantic_generation_qwen3_14b_feedback",
    #reference_model_task_name="semantic_generation_qwen3_14b_base",
    candidate_model_task_name="semantic_generation_qwen3_14b_w4a16_feedback",
    #candidate_model_task_name="semantic_generation_qwen3_14b_w4a16",
    sts_model_id="all-MiniLM-L6-v2",
    rouge_scores=["rouge1", "rougeL"],
    low_score_threshold_args={"f1": 0.79, "rouge1": 0.65, "sts": 0.71},
)

task.execute_remotely("oneshot-a100x1")
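
For orientation, here is a minimal sketch of how the three thresholded metrics above can be computed with the pinned packages; the actual SemanticSimilarityScoreTask implementation is not part of this diff, so the function and variable names below are illustrative assumptions.

import torch
from bert_score import score as bert_score
from rouge_score import rouge_scorer
from sentence_transformers import SentenceTransformer, util

rouge = rouge_scorer.RougeScorer(["rouge1", "rougeL"], use_stemmer=True)
sts_model = SentenceTransformer("all-MiniLM-L6-v2")

def score_pair(reference: str, candidate: str) -> dict:
    # BERTScore F1 between the candidate and reference responses.
    _, _, f1 = bert_score([candidate], [reference], lang="en")
    # ROUGE-1 F-measure (rougeL is obtained from the same scorer).
    rouge1 = rouge.score(reference, candidate)["rouge1"].fmeasure
    # STS: cosine similarity of sentence embeddings.
    embeddings = sts_model.encode([reference, candidate], convert_to_tensor=True)
    sts = util.cos_sim(embeddings[0], embeddings[1]).item()
    return {"f1": f1.item(), "rouge1": rouge1, "sts": sts}

A pair would then be flagged as low-scoring when any metric falls below its configured limit, e.g. the {"f1": 0.79, "rouge1": 0.65, "sts": 0.71} values passed via low_score_threshold_args.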
10 changes: 9 additions & 1 deletion src/automation/datasets/__init__.py
@@ -4,6 +4,10 @@
from automation.datasets.openthoughts import DATASET_PATH as OPENTHOUGHTSDATASET
from automation.datasets.utils import load_llm_messages, load_vlm_messages
from automation.datasets.fleurs import load_fleurs_dataset
from automation.datasets.tulu import make_tulu_prompt
from automation.datasets.openplatypus import make_openplatypus_prompt
from automation.datasets.alpaca import make_alpaca_prompt
from automation.datasets.defaults import make_default_prompt

SUPPORTED_DATASETS = {
    "calibration": load_calibration_dataset,
@@ -17,6 +21,10 @@
    "load_openthoughts_dataset",
    "load_llm_messages",
    "load_vlm_messages",
    "make_tulu_prompt",
    "make_openplatypus_prompt",
    "make_alpaca_prompt",
    "make_default_prompt",
    "load_fleurs_dataset",
    "SUPPORTED_DATASETS",
]
22 changes: 22 additions & 0 deletions src/automation/datasets/alpaca.py
@@ -0,0 +1,22 @@
def make_alpaca_prompt(sample):
    instruction = sample["instruction"].strip()
    input_text = sample.get("input", "").strip()

    if input_text == "":
        messages = [
            {
                "role": "user",
                "content": instruction,
            }
        ]
    else:
        messages = [
            {
                "role": "user",
                "content": f"{instruction}\n{input_text}",
            }
        ]

    return messages
12 changes: 12 additions & 0 deletions src/automation/datasets/defaults.py
@@ -0,0 +1,12 @@
import json


def make_default_prompt(sample):
    # Fallback for datasets without a dedicated prompt builder: serialize the
    # raw sample as the user message, returning the same message-list format
    # as the other make_*_prompt helpers so llm.chat() can consume it.
    messages = [
        {
            "role": "user",
            "content": json.dumps(sample),
        }
    ]

    return messages

21 changes: 21 additions & 0 deletions src/automation/datasets/openplatypus.py
@@ -0,0 +1,21 @@
def make_openplatypus_prompt(sample):
    instruction = sample["instruction"].strip()
    input_text = sample.get("input", "").strip()

    if input_text == "":
        messages = [
            {
                "role": "user",
                "content": instruction,
            }
        ]
    else:
        messages = [
            {
                "role": "user",
                "content": f"{instruction}\n{input_text}",
            }
        ]

    return messages
3 changes: 3 additions & 0 deletions src/automation/datasets/tulu.py
@@ -0,0 +1,3 @@

def make_tulu_prompt(sample):
    return sample["messages"]
6 changes: 5 additions & 1 deletion src/automation/tasks/__init__.py
@@ -1,4 +1,6 @@
from automation.tasks.base_task import BaseTask
from automation.tasks.semantic_similarity_generate import SemanticSimilarityGenerateTask
from automation.tasks.semantic_similarity_score import SemanticSimilarityScoreTask
from automation.tasks.llmcompressor import LLMCompressorTask
from automation.tasks.lmeval import LMEvalTask
from automation.tasks.lighteval import LightEvalTask
@@ -7,9 +9,11 @@

__all__ = [
    "BaseTask",
    "SemanticSimilarityGenerateTask",
    "SemanticSimilarityScoreTask",
    "LLMCompressorTask",
    "LMEvalTask",
    "LightEvalTask",
    "GuideLLMTask",
    "DebugTask",
]
150 changes: 150 additions & 0 deletions src/automation/tasks/scripts/semantic_similarity_generate_script.py
@@ -0,0 +1,150 @@
import json
import os

from torch.cuda import device_count
from datasets import load_dataset
from vllm import LLM, SamplingParams

from automation.utils import parse_argument, flatten_nested_dict
from automation.datasets.tulu import make_tulu_prompt
from automation.datasets.openplatypus import make_openplatypus_prompt
from automation.datasets.alpaca import make_alpaca_prompt
from automation.datasets.defaults import make_default_prompt

try:
    from clearml import Task, Model
    clearml_available = True
except ImportError:
    clearml_available = False

RESULTS_DIR = os.path.join(os.getcwd(), "results")
os.makedirs(RESULTS_DIR, exist_ok=True)

def semantic_similarity_generate_main(
    model_id,
    trust_remote_code,
    dataset_args,
    semantic_similarity_args,
    max_model_len,
    max_new_tokens,
    clearml_model,
):
    from collections import defaultdict
    from huggingface_hub import snapshot_download

    all_conversations = []
    all_samples_dict = defaultdict(list)

    print(">>> Loading dataset...")
    for dataset_path, num_samples_per_dataset in dataset_args.items():
        dataset_name = dataset_path.split("/")[1].lower()
        print(f">>> Loading dataset {dataset_name}...")
        dataset = load_dataset(dataset_path, split=f"train[:{int(num_samples_per_dataset)}]")
        all_samples_dict[dataset_name].extend(dataset)

    sorted_all_samples_dict = dict(sorted(all_samples_dict.items()))

    for dataset_name, dataset_samples in sorted_all_samples_dict.items():
        print(f">>> Loading values for {dataset_name}...")
        for sample in dataset_samples:
            if dataset_name == "alpaca":
                prompt = make_alpaca_prompt(sample)
            elif dataset_name == "open-platypus":
                prompt = make_openplatypus_prompt(sample)
            elif dataset_name == "tulu-3-sft-mixture":
                prompt = make_tulu_prompt(sample)
            else:
                print("Using default prompt")
                prompt = make_default_prompt(sample)
            all_conversations.append(prompt)

    print("Define sampling parameters")
    sampling_params = SamplingParams(
        temperature=semantic_similarity_args.get("temperature", 0.0),
        max_tokens=max_new_tokens,
    )

    HUGGINGFACE_DIR = "/home"
    if clearml_model:
        HUGGINGFACE_DIR = Model(model_id).get_local_copy()
    else:
        print("Download snapshot")
        snapshot_download(repo_id=model_id, local_dir=HUGGINGFACE_DIR)
        print(os.listdir(HUGGINGFACE_DIR))
        if "mistral" in model_id.lower():
            from huggingface_hub import hf_hub_download
            # Mistral models also need params.json next to the weights; note
            # this fetches it from the Mistral-Small-3.1 repo regardless of model_id.
            hf_hub_download(repo_id="mistralai/Mistral-Small-3.1-24B-Instruct-2503", filename="params.json", local_dir=HUGGINGFACE_DIR)

    try:
        print(f"Initializing vLLM: {model_id}...")
        llm = LLM(
Member:
Why are we using the LLM class instead of vllm serve?

Collaborator (Author):
The main branch has an old src/automation/vllm/server.py file with the class VLLMServer, but other branches use start_vllm_server.
Also, shouldn't the output of the LLM class be identical to the vllm serve API endpoint?

            model=HUGGINGFACE_DIR,
            # model = model_id if "mistral" in model_id.lower() else HUGGINGFACE_DIR,
            dtype=semantic_similarity_args.get("dtype", "auto"),
            trust_remote_code=trust_remote_code,
            tensor_parallel_size=device_count(),
            enforce_eager=semantic_similarity_args.get("enforce_eager", True),
            enable_chunked_prefill=semantic_similarity_args.get("enable_chunked_prefill", True),
            max_model_len=max_model_len,
            # Mistral checkpoints ship consolidated weights plus params.json,
            # so they need the mistral formats; all other models use defaults.
            load_format="mistral" if "mistral" in model_id.lower() else "auto",
            config_format="mistral" if "mistral" in model_id.lower() else "auto",
            tokenizer_mode="mistral" if "mistral" in model_id.lower() else "auto",
        )
        print("Completed the model initialization")
        print(">>> Running vLLM generation...")
        outputs = llm.chat(messages=all_conversations, sampling_params=sampling_params)
    except Exception as e:
        print(f"Error initializing LLM: {e}")
        # Re-raise so the return below never runs with `outputs` undefined.
        raise

    return all_conversations, outputs
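
For reference, a minimal sketch of the server-based flow the review thread above asks about, assuming a separately launched "vllm serve <model> --port 8000" and its standard OpenAI-compatible route; the helper name and default values here are illustrative, not part of this PR:

import requests

def chat_via_vllm_server(messages, base_url="http://localhost:8000", model="Qwen/Qwen3-14B"):
    # POST one conversation to vLLM's OpenAI-compatible chat endpoint;
    # vllm serve must already be running and serving `model`.
    response = requests.post(
        f"{base_url}/v1/chat/completions",
        json={
            "model": model,
            "messages": messages,
            "temperature": 0.0,
            "max_tokens": 1024,
        },
        timeout=600,
    )
    response.raise_for_status()
    return response.json()["choices"][0]["message"]["content"]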


def main(configurations=None, args=None):
    if clearml_available:
        task = Task.current_task()
        args = task.get_parameters_as_dict(cast=True)["Args"]
        clearml_model = parse_argument(args["clearml_model"], bool)
    else:
        args = args["Args"]
        clearml_model = False

    # Parse arguments
    force_download = parse_argument(args["force_download"], bool)
    trust_remote_code = parse_argument(args["trust_remote_code"], bool)
    model_id = parse_argument(args["model_id"], str)
    max_model_len = parse_argument(args["max_model_len"], int)
    max_new_tokens = parse_argument(args["max_new_tokens"], int)
    dataset_args = flatten_nested_dict(parse_argument(args["dataset_args"], dict))
    # Default to {} so the .get() lookups in semantic_similarity_generate_main
    # never run against None.
    semantic_similarity_args = args.get("semantic_similarity_args") or {}
    tags = args.get("tags", None)

    all_conversations, outputs = semantic_similarity_generate_main(
        model_id,
        trust_remote_code,
        dataset_args,
        semantic_similarity_args,
        max_model_len,
        max_new_tokens,
        clearml_model,
    )

    OUTPUT_FILE = os.path.join(RESULTS_DIR, f"{model_id.replace('/', '_')}.jsonl")
    print(">>> Writing outputs to file...")
    with open(OUTPUT_FILE, "w") as fout:
        for idx, (prompt, output) in enumerate(zip(all_conversations, outputs)):
            response = output.outputs[0].text.strip()
            fout.write(json.dumps({
                "index": idx,
                "prompt": prompt,
                "response": response
            }) + "\n")

    print(f">>> Completed. Saved {len(outputs)} outputs to {OUTPUT_FILE}")

    if clearml_available:
        task.upload_artifact("jsonl_output", OUTPUT_FILE)

if __name__ == '__main__':
    main()