Commit 32ecb16
[QEff. Finetune] Updated handling of custom dataset in FT. Updated finetune.md readme file. (#520)
- Introduced handling of custom datasets via the --custom_dataset_config argument. This argument expects a JSON file containing the parameters needed to enable custom preprocessing for any dataset.
- Updated the docs to reflect the changes in the custom dataset usage interface.

Signed-off-by: meetkuma <meetkuma@qti.qualcomm.com>
1 parent 7e0ad94 commit 32ecb16
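
For orientation, a minimal invocation with a custom dataset under the new interface might look like the sketch below. Only the --custom_dataset_config flag and the "custom_dataset" dataset name come from this commit; the entry-point module path, the --dataset flag spelling, and the remaining flags are assumptions pieced together from the docstring shown in QEfficient/cloud/finetune.py.

python -m QEfficient.cloud.finetune \
    --model_name "meta-llama/Llama-3.2-1B" \
    --dataset "custom_dataset" \
    --custom_dataset_config "./my_dataset_config.json" \
    --lr 5e-4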

File tree: 11 files changed, +268 -61 lines changed

QEfficient/cloud/finetune.py

Lines changed: 2 additions & 3 deletions

@@ -288,11 +288,10 @@ def main(**kwargs) -> None:
         --model_name "meta-llama/Llama-3.2-1B" \\
         --lr 5e-4
     """
-    # TODO:Remove TrainConfig() and update_config() as all params are passed in kwargs by parser
     train_config = TrainConfig()
     update_config(train_config, **kwargs)
-    dataset_config = generate_dataset_config(train_config.dataset)
-    update_config(dataset_config, **kwargs)
+    custom_dataset_config_file = kwargs.pop("custom_dataset_config", None)
+    dataset_config = generate_dataset_config(train_config.dataset, custom_dataset_config_file)
 
     logger.prepare_for_logs(train_config.output_dir, train_config.dump_logs, train_config.log_level)

QEfficient/finetune/configs/dataset_config.py

Lines changed: 0 additions & 2 deletions

@@ -41,7 +41,5 @@ class imdb_dataset:
 @dataclass
 class custom_dataset:
     dataset: str = "custom_dataset"
-    file: str = "dataset/custom_dataset.py"
     train_split: str = "train"
     test_split: str = "validation"
-    data_path: str = ""

Lines changed: 17 additions & 0 deletions (new file)

@@ -0,0 +1,17 @@
+{
+    "r": 32,
+    "lora_alpha": 64,
+    "target_modules": [
+        "q_proj",
+        "k_proj",
+        "v_proj",
+        "o_proj",
+        "up_proj",
+        "down_proj",
+        "gate_proj"
+    ],
+    "bias": "none",
+    "task_type": "CAUSAL_LM",
+    "lora_dropout": 0.05,
+    "inference_mode": false
+}
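
The JSON above mirrors the constructor fields of peft's LoraConfig (the same class config_utils.py below imports as PeftLoraConfig). As a hedged sketch, assuming the file is saved as lora_config.json (the on-disk path is not shown in this commit), it could be loaded like this:

import json

from peft import LoraConfig

# Hypothetical file name; the commit does not show where this JSON lives.
with open("lora_config.json") as f:
    cfg = json.load(f)

lora_config = LoraConfig(**cfg)
print(lora_config.r, lora_config.lora_alpha, lora_config.target_modules)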

QEfficient/finetune/dataset/custom_dataset.py

Lines changed: 45 additions & 14 deletions

@@ -6,6 +6,7 @@
 # -----------------------------------------------------------------------------
 
 import importlib
+import logging
 from pathlib import Path
 
 from QEfficient.finetune.utils.logging_utils import logger
@@ -26,51 +27,81 @@ def load_module_from_py_file(py_file: str) -> object:
 
 
 def get_custom_dataset(dataset_config, tokenizer, split: str, context_length=None):
-    if ":" in dataset_config.file:
-        module_path, func_name = dataset_config.file.split(":")
-    else:
-        module_path, func_name = dataset_config.file, "get_custom_dataset"
+    if not hasattr(dataset_config, "preproc_file"):
+        logger.raise_error("Can not find preproc_file key in dataset_config file.", RuntimeError)
+
+    if ":" not in dataset_config.preproc_file:
+        logger.raise_error(
+            "The 'preproc_file' key in dataset_config file should follow the format: python_file_path:function_name",
+            RuntimeError,
+        )
+
+    module_path, func_name = dataset_config.preproc_file.split(":")
+    logger.log_rank_zero(
+        f"Using '{func_name}' function from {module_path} as preprocessing function in dataset preprocessing.",
+        logging.DEBUG,
+    )
 
     if not module_path.endswith(".py"):
-        logger.raise_error(f"Dataset file {module_path} is not a .py file.", ValueError)
+        logger.raise_error(f"Custom dataset preprocessing file {module_path} is not a .py file.", ValueError)
 
     module_path = Path(module_path)
     if not module_path.is_file():
         logger.raise_error(
-            f"Dataset py file {module_path.as_posix()} does not exist or is not a file.", FileNotFoundError
+            f"Custom dataset file {module_path.as_posix()} does not exist or is not a file.", FileNotFoundError
         )
 
     module = load_module_from_py_file(module_path.as_posix())
     try:
         return getattr(module, func_name)(dataset_config, tokenizer, split, context_length)
     except AttributeError:
         logger.raise_error(
-            f"It seems like the given method name ({func_name}) is not present in the dataset .py file ({module_path.as_posix()}).",
+            f"For custom dataset preprocessing, the method ({func_name}) is not "
+            f"present in the file ({module_path.as_posix()}).",
             AttributeError,
         )
 
 
 def get_data_collator(dataset_processer, dataset_config):
-    if ":" in dataset_config.file:
-        module_path, func_name = dataset_config.file.split(":")
+    if not hasattr(dataset_config, "collate_file"):
+        logger.log_rank_zero(
+            "Can not find collate_file key in dataset_config file. Using the default data collator function instead.",
+            logging.WARNING,
+        )
+        return None
+
+    if ":" not in dataset_config.collate_file:
+        logger.log_rank_zero(
+            "Can not find function name in 'collate_file' key in dataset_config "
+            "file. Using the default data collator function instead. If this is "
+            "not intended then change the format of the 'collate_file' key in "
+            "dataset_config file to follow the format: python_file_path:function_name",
+            logging.WARNING,
+        )
+        return None
     else:
-        module_path, func_name = dataset_config.file, "get_data_collator"
+        module_path, func_name = dataset_config.collate_file.split(":")
+        logger.log_rank_zero(
+            f"Using '{func_name}' function from {module_path} as collate_fn in dataset preprocessing.",
+            logging.DEBUG,
+        )
 
     if not module_path.endswith(".py"):
-        logger.raise_error(f"Dataset file {module_path} is not a .py file.", ValueError)
+        logger.raise_error(f"Custom dataset collate file {module_path} is not a .py file.", ValueError)
 
     module_path = Path(module_path)
     if not module_path.is_file():
         logger.raise_error(
-            f"Dataset py file {module_path.as_posix()} does not exist or is not a file.", FileNotFoundError
+            f"Custom dataset collate file {module_path.as_posix()} does not exist or is not a file.", FileNotFoundError
         )
 
     module = load_module_from_py_file(module_path.as_posix())
     try:
         return getattr(module, func_name)(dataset_processer)
     except AttributeError:
         logger.log_rank_zero(
-            f"Can not find the custom data_collator in the dataset.py file ({module_path.as_posix()})."
+            f"Can not find the function {func_name} in file "
+            f"({module_path.as_posix()}). Using the default data collator "
+            "function instead."
         )
-        logger.log_rank_zero("Using the default data_collator instead.")
         return None
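
Taken together, the new contract is: get_custom_dataset imports the function named after the colon in 'preproc_file' and calls it as func(dataset_config, tokenizer, split, context_length), and get_data_collator optionally does the same for 'collate_file', calling func(dataset_processer). A minimal user-side sketch of a preprocessing file satisfying that contract is shown below; the file name my_preproc.py, the dataset, and the column name are hypothetical, not part of this commit.

# my_preproc.py -- hypothetical example of the preproc_file/collate_file contract.
import datasets
from transformers.data import DataCollatorForSeq2Seq


def get_preprocessed_my_dataset(dataset_config, tokenizer, split, context_length=None):
    # dataset_config is built from the custom dataset JSON, so any extra keys in
    # that JSON (e.g. a hypothetical "hf_dataset_name") are available as attributes.
    dataset = datasets.load_dataset(dataset_config.hf_dataset_name, split=split)

    def tokenize(sample):
        return tokenizer(sample["text"], max_length=context_length, truncation=True)

    return dataset.map(tokenize, remove_columns=list(dataset.features))


def my_collate_fn(dataset_processer):
    # dataset_processer is the tokenizer passed in by get_data_collator.
    return DataCollatorForSeq2Seq(dataset_processer)

The matching JSON would then set "preproc_file": "path/to/my_preproc.py:get_preprocessed_my_dataset" and, optionally, "collate_file": "path/to/my_preproc.py:my_collate_fn", as the sample config in the next file shows.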

Lines changed: 7 additions & 0 deletions (new file)

@@ -0,0 +1,7 @@
+{
+    "train_split": "train",
+    "test_split": "test",
+    "test_split_ratio": 0.15,
+    "preproc_file": "./QEfficient/finetune/dataset/custom_dataset/disc_preproc.py:get_preprocessed_disc",
+    "disc_style": "sarcasm_more"
+}

Lines changed: 87 additions & 0 deletions (new file)

@@ -0,0 +1,87 @@
+# -----------------------------------------------------------------------------
+#
+# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# -----------------------------------------------------------------------------
+
+
+import datasets
+from transformers.data import DataCollatorForSeq2Seq
+
+
+def get_data_collator(tokenizer):
+    return DataCollatorForSeq2Seq(tokenizer)
+
+
+def get_preprocessed_disc(dataset_config, tokenizer, split, context_length=None):
+    dataset = datasets.load_dataset("hallisky/DiSC")
+
+    # Considering 'train' split as this dataset has only one split.
+    dataset = dataset["train"]
+
+    test_split_ratio = dataset_config.test_split_ratio
+    disc_style = dataset_config.disc_style
+
+    # Only collect the samples for a given style.
+    available_styles = set(dataset["category"])
+    if disc_style not in available_styles:
+        raise RuntimeError(f"For DiSC dataset the provided disc_style '{disc_style}' is not supported.")
+
+    dataset = dataset.filter(lambda example: example["category"] == disc_style)
+
+    # Shuffle the dataset before splitting
+    dataset = dataset.shuffle(seed=42)
+
+    # Split the data in train and test split.
+    total_samples = len(dataset)
+    test_size = int(total_samples * test_split_ratio)
+    train_size = total_samples - test_size
+
+    if split == "test":
+        indices = range(train_size, total_samples)
+    else:
+        indices = range(0, train_size)
+
+    dataset = dataset.select(indices)
+
+    if tokenizer.pad_token is None:
+        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
+
+    # Below is the template of the DiSC dataset.
+    # <bos>### Original:{original} \n ### Rewrite: {rewrite} <eos>
+    template = "### Original:{original} \n ### Rewrite: "
+
+    def apply_prompt_template(sample):
+        return {
+            "input": template.format(original=sample["original"]),
+            "label": sample["generation"],
+        }
+
+    dataset = dataset.map(apply_prompt_template, remove_columns=list(dataset.features))
+
+    def tokenize_add_label(sample):
+        input = tokenizer.encode(
+            tokenizer.bos_token + sample["input"],
+            add_special_tokens=False,
+            max_length=context_length,
+            pad_to_max_length=True,
+        )
+        label = tokenizer.encode(
+            sample["label"] + tokenizer.pad_token + tokenizer.eos_token,
+            add_special_tokens=False,
+            max_length=context_length,
+            pad_to_max_length=True,
+        )
+
+        sample = {
+            "input_ids": (input + label),
+            "attention_mask": [1] * (len(input) + len(label)),
+            "labels": [-100] * len(input) + label,
+        }
+
+        return sample
+
+    dataset = dataset.map(tokenize_add_label, remove_columns=list(dataset.features))
+
+    return dataset

QEfficient/finetune/dataset/dataset_config.py

Lines changed: 1 addition & 2 deletions

@@ -5,7 +5,6 @@
 #
 # -----------------------------------------------------------------------------
 
-from functools import partial
 
 from QEfficient.finetune.dataset.alpaca_dataset import (
     InstructionDataset as get_alpaca_dataset,
@@ -23,7 +22,7 @@
 )
 
 DATASET_PREPROC = {
-    "alpaca_dataset": partial(get_alpaca_dataset),
+    "alpaca_dataset": get_alpaca_dataset,
     "grammar_dataset": get_grammar_dataset,
     "gsm8k_dataset": get_gsm8k_dataset,
     "custom_dataset": get_custom_dataset,

QEfficient/finetune/utils/config_utils.py

Lines changed: 21 additions & 4 deletions

@@ -8,13 +8,14 @@
 import inspect
 import json
 import os
+from collections import namedtuple
 from dataclasses import asdict
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 
 import yaml
 from peft import LoraConfig as PeftLoraConfig
 
-import QEfficient.finetune.configs.dataset_config as datasets
+import QEfficient.finetune.configs.dataset_config as qeff_datasets
 from QEfficient.finetune.configs.peft_config import LoraConfig
 from QEfficient.finetune.configs.training import TrainConfig
 from QEfficient.finetune.dataset.dataset_config import DATASET_PREPROC
@@ -86,11 +87,14 @@ def generate_peft_config(train_config: TrainConfig, **kwargs) -> Any:
     return peft_config
 
 
-def generate_dataset_config(dataset_name: str) -> Any:
+def generate_dataset_config(dataset_name: str, custom_dataset_config: Optional[str] = None) -> Any:
     """Generate a dataset configuration based on the specified dataset.
 
     Args:
         dataset_name (str): Name of the dataset to be used for finetuning.
+        custom_dataset_config (str): Dataset config json file for custom dataset.
+            This file contains dataset specific arguments to be used in dataset
+            preprocessing step.
 
     Returns:
         Any: A dataset configuration object.
@@ -101,7 +105,20 @@ def generate_dataset_config(dataset_name: str) -> Any:
     supported_datasets = DATASET_PREPROC.keys()
     assert dataset_name in supported_datasets, f"Given dataset '{dataset_name}' is not supported."
     # FIXME (Meet): Replace below logic by creating using auto registry of datasets.
-    dataset_config = {k: v for k, v in inspect.getmembers(datasets)}[dataset_name]()
+    dataset_config = {k: v for k, v in inspect.getmembers(qeff_datasets)}[dataset_name]()
+    if dataset_name == "custom_dataset":
+        if custom_dataset_config is None:
+            logger.raise_error(
+                "For 'custom_dataset', please provide dataset config file via 'custom_dataset_config' flag.",
+                RuntimeError,
+            )
+        custom_dataset_dict = asdict(dataset_config)
+        custom_dataset_dict_override = load_config_file(custom_dataset_config)
+        # Override existing and add new params to dataset_config.
+        custom_dataset_dict.update(custom_dataset_dict_override)
+
+        custom_dataset_class = namedtuple("custom_dataset", custom_dataset_dict.keys())
+        dataset_config = custom_dataset_class(**custom_dataset_dict)
     return dataset_config
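
The effect of the new custom_dataset branch is easiest to see in isolation: the dataclass defaults are converted to a dict, updated with the user's JSON (loaded via load_config_file), and frozen into a namedtuple so that extra keys become attributes. A hedged standalone sketch, with the override dict inlined instead of read from a file (values are illustrative):

from collections import namedtuple
from dataclasses import asdict, dataclass


@dataclass
class custom_dataset:
    dataset: str = "custom_dataset"
    train_split: str = "train"
    test_split: str = "validation"


# Defaults from the dataclass...
base = asdict(custom_dataset())
# ...overridden and extended by the user's JSON config.
override = {"test_split": "test", "test_split_ratio": 0.15, "disc_style": "sarcasm_more"}
base.update(override)

# Unknown JSON keys such as disc_style become regular attributes on the namedtuple.
config_cls = namedtuple("custom_dataset", base.keys())
dataset_config = config_cls(**base)
print(dataset_config.test_split, dataset_config.disc_style)  # test sarcasm_more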

QEfficient/finetune/utils/dataset_utils.py

Lines changed: 3 additions & 2 deletions

@@ -64,8 +64,9 @@ def get_dataloader_kwargs(train_config, dataset, dataset_processer, split):
         kwargs["drop_last"] = False
     else:
         kwargs["batch_size"] = batch_size
-        kwargs["drop_last"] = False
-        kwargs["collate_fn"] = DataCollatorForSeq2Seq(dataset_processer)
+        kwargs["drop_last"] = True
+        # todo: -100 should be changed to a variable. or tokenizer.pad_token_id
+        kwargs["collate_fn"] = DataCollatorForSeq2Seq(dataset_processer, label_pad_token_id=-100)
     return kwargs
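
The explicit label_pad_token_id=-100 matters because -100 is the default ignore_index of PyTorch's cross-entropy loss, so padded label positions are excluded from the loss. A small hedged illustration (the gpt2 tokenizer and the toy feature dicts are only for demonstration):

from transformers import AutoTokenizer, DataCollatorForSeq2Seq

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 has no pad token by default

collator = DataCollatorForSeq2Seq(tokenizer, label_pad_token_id=-100)
features = [
    {"input_ids": [1, 2, 3], "attention_mask": [1, 1, 1], "labels": [10, 11, 12]},
    {"input_ids": [4, 5], "attention_mask": [1, 1], "labels": [13, 14]},
]
batch = collator(features)
# The shorter example is padded; its extra label slot is filled with -100, not a token id.
print(batch["labels"])  # tensor([[ 10,  11,  12], [ 13,  14, -100]])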

QEfficient/finetune/utils/parser.py

Lines changed: 7 additions & 0 deletions

@@ -43,6 +43,13 @@ def get_finetune_parser():
         default=None,
         help="Name of the tokenizer,if not passed as an argument, it uses the value of model_name",
     )
+    parser.add_argument(
+        "--custom_dataset_config",
+        "--custom-dataset-config",
+        type=str,
+        default=None,
+        help="Path of custom dataset config json file to override the custom dataset params such as test_split_ratio, test_split etc.",
+    )
     parser.add_argument(
         "--run_validation",
         "--run-validation",
