230 changes: 230 additions & 0 deletions examples/llm_compress_eval_example.py
@@ -0,0 +1,230 @@
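"""Example: build a ClearML pipeline that compresses Llama models with
transform-plus-quantization recipes (SpinQuant/QuIP rotations followed by
RTN or GPTQ W4A16) and evaluates each compressed checkpoint with lm_eval
on the OpenLLM benchmark tasks.

One compress step and one dependent eval step are created per
(model, recipe) combination; the eval step picks up the model produced by
its parent compress step.
"""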
from typing import Literal
from clearml import Task

from automation.pipelines import Pipeline
from automation.tasks import LMEvalTask, LLMCompressorTask

PROJECT_NAME = "brian_transforms_v1"


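# The helpers below construct llmcompressor / compressed_tensors objects.
# Imports are kept local to each function, presumably so the heavy libraries
# only need to be importable where the recipes are actually built.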
def get_spinquant_modifier(
transform_block_size: int | None,
rotations: list[Literal["R1", "R2", "R4"]] = ["R1", "R2"],
):
from llmcompressor.modifiers.transform import SpinQuantModifier

return SpinQuantModifier(
transform_type="hadamard",
transform_block_size=transform_block_size,
rotations=rotations,
)


def get_quip_modifier(
transform_block_size: int | None, rotations: list[Literal["u", "v"]] = ["u", "v"]
):
from llmcompressor.modifiers.transform import QuIPModifier

return QuIPModifier(
transform_type="hadamard",
transform_block_size=transform_block_size,
rotations=rotations,
)


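# W4A16 scheme: 4-bit symmetric integer weights quantized group-wise
# (one scale per `group_size` weights); activations stay in 16-bit.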
def get_w4a16_scheme(group_size: int = 128):
from compressed_tensors.quantization import (
QuantizationScheme,
QuantizationStrategy,
QuantizationType,
QuantizationArgs,
)

return QuantizationScheme(
targets=["Linear"],
weights=QuantizationArgs(
num_bits=4,
type=QuantizationType.INT,
strategy=QuantizationStrategy.GROUP,
group_size=group_size,
symmetric=True,
dynamic=False,
),
)


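# RTN applies the W4A16 scheme with plain round-to-nearest quantization,
# while GPTQ (below) additionally uses calibration data to reduce
# layer-wise quantization error.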
def get_rtn_modifier(group_size: int = 128):
from llmcompressor.modifiers.quantization import (
QuantizationModifier,
)

return QuantizationModifier(
config_groups={"group_0": get_w4a16_scheme(group_size)}, ignore=["lm_head"]
)


def get_gptq_modifier(group_size: int = 128):
from llmcompressor.modifiers.quantization import (
GPTQModifier,
)

return GPTQModifier(
config_groups={"group_0": get_w4a16_scheme(group_size)}, ignore=["lm_head"]
)


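# Recipe names encode transform, transform block size, and quantization scheme,
# e.g. SpinQuantR1R2_B128_GPTQ_W4A16G128 = SpinQuant with R1+R2 rotations at
# block size 128, followed by GPTQ W4A16 with group size 128. Commented-out
# entries are alternative baselines (dense, RTN, QuIP) left for reference.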
recipes = {
# "DENSE": [],
# "RTN_W4A16G128": get_rtn_modifier(128),
# "GPTQ_W4A16G128": get_gptq_modifier(128),
# "QUIPv_B128_RTN_W4A16G128": [get_quip_modifier(128, ["v"]), get_rtn_modifier(128)],
# "QUIPv_B128_GPTQ_W4A16G128": [
# get_quip_modifier(128, ["v"]),
# get_gptq_modifier(128),
# ],
# "QUIPuv_B128_RTN_W4A16G128": [
# get_quip_modifier(128, ["u", "v"]),
# get_rtn_modifier(128),
# ],
# "QUIPuv_B128_GPTQ_W4A16G128": [
# get_quip_modifier(128, ["u", "v"]),
# get_gptq_modifier(128),
# ],
"SpinQuantR1R2_B128_GPTQ_W4A16G128": [
get_spinquant_modifier(128, ["R1", "R2"]),
get_gptq_modifier(128),
],
"SpinQuantR1R2R4_B128_GPTQ_W4A16G128": [
get_spinquant_modifier(128, ["R1", "R2", "R4"]),
get_gptq_modifier(128),
],
# "RTN_W4A16G64": get_rtn_modifier(64),
# "GPTQ_W4A16G64": get_gptq_modifier(64),
# "QUIPv_B64_RTN_W4A16G64": [get_quip_modifier(64, ["v"]), get_rtn_modifier(64)],
# "QUIPv_B64_GPTQ_W4A16G64": [
# get_quip_modifier(64, ["v"]),
# get_gptq_modifier(64),
# ],
# "QUIPuv_B64_RTN_W4A16G64": [get_quip_modifier(64, ["u", "v"]), get_rtn_modifier(64)],
# "QUIPuv_B64_GPTQ_W4A16G64": [
# get_quip_modifier(64, ["u", "v"]),
# get_gptq_modifier(64),
# ],
"SpinQuantR1R2_B64_GPTQ_W4A16G64": [
get_spinquant_modifier(64, ["R1", "R2"]),
get_gptq_modifier(64),
],
"SpinQuantR1R2R4_B64_GPTQ_W4A16G64": [
get_spinquant_modifier(64, ["R1", "R2", "R4"]),
get_gptq_modifier(64),
],
# "RTN_W4A16G32": get_rtn_modifier(32),
# "GPTQ_W4A16G32": get_gptq_modifier(32),
# "QUIPv_B32_RTN_W4A16G32": [get_quip_modifier(32, ["v"]), get_rtn_modifier(32)],
# "QUIPv_B32_GPTQ_W4A16G32": [
# get_quip_modifier(32, ["v"]),
# get_gptq_modifier(32),
# ],
# "QUIPuv_B32_RTN_W4A16G32": [get_quip_modifier(32, ["u", "v"]), get_rtn_modifier(32)],
# "QUIPuv_B32_GPTQ_W4A16G32": [
# get_quip_modifier(32, ["u", "v"]),
# get_gptq_modifier(32),
# ],
"SpinQuantR1R2_B32_GPTQ_W4A16G32": [
get_spinquant_modifier(32, ["R1", "R2"]),
get_gptq_modifier(32),
],
"SpinQuantR1R2R4_B32_GPTQ_W4A16G32": [
get_spinquant_modifier(32, ["R1", "R2", "R4"]),
get_gptq_modifier(32),
],
}


if __name__ == "__main__":
from llmcompressor.recipe import Recipe

pipeline = Pipeline(
project_name=PROJECT_NAME,
pipeline_name=f"{PROJECT_NAME}_pipeline",
)

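    # One compress task and one dependent eval task are created per
    # (model, recipe) combination and chained as pipeline steps below.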
for model_id in [
"meta-llama/Llama-3.2-3B-Instruct",
"meta-llama/Llama-3.1-8B-Instruct",
]:
model_name = model_id.split("/")[-1].replace(".", "_").replace("-", "_")
for recipe_id, recipe_modifiers in recipes.items():
            # NOTE: passing the recipe in as a raw list of modifiers results in
            # parsing errors. Build it with `Recipe.from_modifiers(...)` and pass
            # a serialized form (e.g. `.yaml()` as below, or `.model_dump_json()`) instead.
recipe = Recipe.from_modifiers(recipe_modifiers)
compress_step_name = f"compress--{model_name}--{recipe_id}"
compress_step = LLMCompressorTask(
project_name=PROJECT_NAME,
task_name=compress_step_name,
model_id=model_id,
text_samples=512,
recipe=recipe.yaml(),
)
compress_step.create_task()
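            # create_task() only registers the ClearML task; the pipeline
            # controller (add_step below) is what launches it on the execution
            # queue when the pipeline runs.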

            # NOTE: lm_eval settings are chosen to match those in
            # src/automation/standards/evaluations/openllm.yaml, with
            # apply_chat_template set to False.
            # anmarques: "We notice that apply_chat_template tends to mess up
            # loglikelihood-based evals, which are most of the openllm benchmarks
            # (the model tends to blab before predicting the answer)"
eval_step = LMEvalTask(
project_name=PROJECT_NAME,
task_name=f"eval--{model_name}--{recipe_id}",
model_id="dummuy", # overridden
clearml_model=True,
tasks=[
# openllm tasks + llama variants
"arc_challenge",
"gsm8k",
"hellaswag",
"mmlu",
"winogrande",
"truthfulqa_mc2",
"arc_challenge_llama",
"gsm8k_llama",
# TODO: PPL based metrics broken in lm_eval+vllm
# https://github.com/EleutherAI/lm-evaluation-harness/issues/3134
# "wikitext"
],
num_fewshot=5,
apply_chat_template=False,
model_args=(
"gpu_memory_utilization=0.4,dtype=auto,max_model_len=4096,"
"add_bos_token=True,enable_chunked_prefill=True"
),
batch_size="auto",
)
eval_step.create_task()

pipeline.add_step(
name=compress_step_name,
base_task_id=compress_step.id,
execution_queue="oneshot-a100x1",
monitor_models=[
compress_step.get_arguments()["Args"]["save_directory"]
],
monitor_artifacts=["recipe"],
)

pipeline.add_step(
name=f"eval-{model_name}-{recipe_id}",
base_task_id=eval_step.id,
parents=[compress_step_name],
execution_queue="oneshot-a100x1",
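                # ClearML resolves this at runtime to the ID of the last output
                # model registered by the compress step, so the eval runs
                # against the freshly compressed checkpoint.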
parameter_override={
"Args/model_id": "${" + compress_step_name + ".models.output.-1.id}"
},
monitor_metrics=[
("gsm8k", "exact_match,strict-match"),
("winogrande", "acc,none"),
],
)

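    # Launch the pipeline controller remotely; each step then runs on the
    # execution_queue given to add_step.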
pipeline.execute_remotely()
61 changes: 31 additions & 30 deletions src/automation/tasks/base_task.py
@@ -6,25 +6,29 @@

try:
from clearml import Task

clearml_available = True
except ImportError:
print("ClearML not available. Will run tasks locally and not report to ClearML.")
clearml_available = False

class BaseTask():

class BaseTask:

def __init__(
self,
project_name: str,
task_name: str,
docker_image: str,
branch: Optional[str] = DEFAULT_RESEARCH_BRANCH,
packages: Optional[Sequence[str]]=None,
task_type: str="training",
packages: Optional[Sequence[str]] = None,
task_type: str = "training",
):
branch_name = branch or DEFAULT_RESEARCH_BRANCH
base_packages = [f"git+https://github.com/neuralmagic/research.git@{branch_name}"]

base_packages = [
f"git+https://github.com/neuralmagic/research.git@{branch_name}"
]

if packages is not None:
packages = list(set(packages + base_packages))
else:
@@ -45,7 +49,6 @@ def __init__(
self.branch = branch
self.script_path = None
self.callable_artifacts = None


@property
def id(self):
@@ -55,11 +58,10 @@ def id(self):
def name(self):
return self.task_name


def process_config(self, config):
if config is None:
return {}

if config in STANDARD_CONFIGS:
return yaml.safe_load(open(STANDARD_CONFIGS[config], "r"))
elif os.path.exists(config):
@@ -69,75 +71,75 @@ def process_config(self, config):
else:
return yaml.safe_load(config)


def get_arguments(self):
return {}


def set_arguments(self):
args = self.get_arguments()
if clearml_available:
for args_name, args_dict in args.items():
self.task.connect(args_dict, args_name)

return args

return args

def get_configurations(self):
return {}


def set_configurations(self):
configurations = self.get_configurations()
if clearml_available:
for name, config in configurations.items():
self.task.connect_configuration(config, name=name)

return configurations

return configurations

def script(self, configurations, args):
raise NotImplementedError


def create_task(self):
self.task = Task.create(
project_name=self.project_name,
task_name=self.task_name,
task_type=self.task_type,
docker=self.docker_image,
packages=self.packages,
self.task: Task = Task.create(
project_name=self.project_name,
task_name=self.task_name,
task_type=self.task_type,
docker=self.docker_image,
packages=self.packages,
add_task_init_call=True,
script=self.script_path,
repo="https://github.com/neuralmagic/research.git",
branch=self.branch,
)
# To avoid precompiling VLLM when installing from main, add env var
self.task.set_base_docker(
docker_image=self.docker_image,
docker_arguments="-e VLLM_USE_PRECOMPILED=1",
)
self.task.output_uri = DEFAULT_OUTPUT_URI
self.set_arguments()
self.set_configurations()


def get_task_id(self):
if self.task is not None:
return self.task.id
else:
raise ValueError("Task ID not available since ClearML task not yet created. Try task.create_task() firts.")

raise ValueError(
"Task ID not available since ClearML task not yet created. Try task.create_task() firts."
)

def execute_remotely(self, queue_name):
if self.task is None:
self.create_task()
self.task.execute_remotely(queue_name=queue_name, clone=False, exit_process=True)

self.task.execute_remotely(
queue_name=queue_name, clone=False, exit_process=True
)

def execute_locally(self):
if clearml_available:
if self.task is not None:
raise Exception("Can only execute locally if task is not yet created.")

self.task = Task.init(
project_name=self.project_name,
task_name=self.task_name,
project_name=self.project_name,
task_name=self.task_name,
task_type=self.task_type,
auto_connect_arg_parser=False,
)
@@ -149,4 +151,3 @@ def execute_locally(self):
args = self.set_arguments()
configurations = self.set_configurations()
self.script(configurations, args)
