
Commit 7ca347c

Noah Ziems authored

Param Passthrough and Consistent Tutorial Script (stanfordnlp#3)

* Add param passthrough and default banking77 tutorial
* Add more threads
* Update banking tutorial

Co-authored-by: Noah Ziems <nziems2@nziems2@nd.edu>

1 parent 8220c8f commit 7ca347c
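
In practice, "param passthrough" means that GRPO hyperparameters supplied via train_kwargs are now forwarded through ArborReinforceJob to the Arbor server instead of being dropped. A minimal sketch of the intended usage, assuming the remaining GRPO constructor arguments keep their defaults (the values shown are illustrative, drawn from the tutorial below):

from dspy.teleprompt.grpo import GRPO

# Keys given here are resolved against ArborReinforceJob.DEFAULT_TRAIN_KWARGS
# and sent to the Arbor GRPO endpoint (see dspy/clients/lm_local_arbor.py below).
compiler = GRPO(
    metric=lambda x, y, trace=None: x.label == y.label,
    num_train_steps=500,
    train_kwargs={"learning_rate": 1e-5, "bf16": True},
)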

File tree

banking77_tutorial.py
dspy/clients/lm_local_arbor.py

2 files changed: +94 −1 lines changed

banking77_tutorial.py

Lines changed: 75 additions & 0 deletions
@@ -0,0 +1,75 @@
import dspy
import random
from dspy.datasets import DataLoader
from datasets import load_dataset

# Load the Banking77 label names.
CLASSES = load_dataset("PolyAI/banking77", split="train", trust_remote_code=True).features['label'].names
kwargs = dict(fields=("text", "label"), input_keys=("text",), split="train", trust_remote_code=True)

# Load the first 2000 examples, mapping each integer label to its class name.
raw_data = [
    dspy.Example(x, label=CLASSES[x.label]).with_inputs("text")
    for x in DataLoader().from_huggingface(dataset_name="PolyAI/banking77", **kwargs)[:2000]
]

random.Random(0).shuffle(raw_data)
print(len(CLASSES), CLASSES[:10])

trainset = raw_data[:1500]      # 1500 examples for training
valset = raw_data[1500:1600]    # 100 examples for validation
print(trainset[0])

# Classify each query into one of the 77 Banking77 intents.
classify = dspy.ChainOfThought(f"text -> label: Literal{CLASSES}")

# Connect to a locally running Arbor server.
from dspy.clients.lm_local_arbor import ArborProvider
port = 7453
arbor_api_base = f"http://localhost:{port}/v1/"
api_key = "arbor"
provider = ArborProvider()

# student_lm_name = "meta-llama/Llama-3.2-3B-Instruct"
# student_lm_name = "Qwen/Qwen2.5-1.5B-Instruct"
student_lm_name = "Qwen/Qwen3-0.6B"
student_lm = dspy.LM(model=f"openai/arbor:{student_lm_name}", provider=provider, temperature=0.5, api_base=arbor_api_base, api_key=api_key, max_tokens=3000)

student_classify = classify.deepcopy()
student_classify.set_lm(student_lm)

# Exact-match metric on the predicted label.
metric = (lambda x, y, trace=None: x.label == y.label)


from dspy.teleprompt.grpo import GRPO

# These kwargs are passed through to the Arbor GRPO trainer.
train_kwargs = {
    "update_interval": 10,
    "per_device_train_batch_size": 4,
    "temperature": 0.5,
    "beta": 0.04,
    "learning_rate": 1e-5,
    "gradient_accumulation_steps": 1,
    "gradient_checkpointing": True,
    "bf16": True,
    "lr_scheduler_type": "constant_with_warmup",
}

compiler = GRPO(
    metric=metric,
    multitask=True,
    num_dspy_examples_per_grpo_step=4,
    num_samples_per_input=4,
    exclude_demos=True,
    num_train_steps=500,
    num_threads=4,
    use_train_as_val=False,
    num_steps_for_val=100,
    train_kwargs=train_kwargs,
)

classify_ft = compiler.compile(
    student=student_classify,
    trainset=trainset,
    valset=valset,
)

# evaluate = dspy.Evaluate(devset=valset, metric=metric, display_progress=True, display_table=5, num_threads=16)
# print(evaluate(classify_ft))
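
The evaluation block is left commented out. As a hedged sketch of how one might compare the student before and after GRPO (assuming the Arbor server is still serving the fine-tuned checkpoint, and reusing the dspy.Evaluate arguments from the commented lines above):

evaluate = dspy.Evaluate(devset=valset, metric=metric, display_progress=True, num_threads=16)

print(evaluate(student_classify))  # accuracy before GRPO
print(evaluate(classify_ft))       # accuracy after GRPO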

dspy/clients/lm_local_arbor.py

Lines changed: 19 additions & 1 deletion
@@ -44,9 +44,15 @@ def status(self) -> TrainingStatus:
 
 class ArborReinforceJob(ReinforceJob):
     DEFAULT_TRAIN_KWARGS = {
-        "update_interval": 5,
+        "update_interval": 10,
         "temperature": 0.9,
         "beta": 0.04,
+        "per_device_train_batch_size": 8,
+        "learning_rate": 1e-5,
+        "gradient_accumulation_steps": 1,
+        "gradient_checkpointing": True,
+        "bf16": True,
+        "lr_scheduler_type": "constant_with_warmup",
     }
 
     def __init__(self, lm: "LM", train_kwargs: Dict[str, Any]):

@@ -66,6 +72,12 @@ def initialize(self):
         update_interval = self.train_kwargs.get("update_interval", self.DEFAULT_TRAIN_KWARGS["update_interval"])
         temperature = self.train_kwargs.get("temperature", self.DEFAULT_TRAIN_KWARGS["temperature"])
         beta = self.train_kwargs.get("beta", self.DEFAULT_TRAIN_KWARGS["beta"])
+        per_device_train_batch_size = self.train_kwargs.get("per_device_train_batch_size", self.DEFAULT_TRAIN_KWARGS["per_device_train_batch_size"])
+        learning_rate = self.train_kwargs.get("learning_rate", self.DEFAULT_TRAIN_KWARGS["learning_rate"])
+        gradient_accumulation_steps = self.train_kwargs.get("gradient_accumulation_steps", self.DEFAULT_TRAIN_KWARGS["gradient_accumulation_steps"])
+        gradient_checkpointing = self.train_kwargs.get("gradient_checkpointing", self.DEFAULT_TRAIN_KWARGS["gradient_checkpointing"])
+        bf16 = self.train_kwargs.get("bf16", self.DEFAULT_TRAIN_KWARGS["bf16"])
+        lr_scheduler_type = self.train_kwargs.get("lr_scheduler_type", self.DEFAULT_TRAIN_KWARGS["lr_scheduler_type"])
 
         api_base = self.lm.kwargs["api_base"]
         # api_key = self.lm.kwargs["api_key"]

@@ -79,6 +91,12 @@ def initialize(self):
             'update_interval': update_interval,
             'temperature': temperature,
             'beta': beta,
+            'per_device_train_batch_size': per_device_train_batch_size,
+            'learning_rate': learning_rate,
+            'gradient_accumulation_steps': gradient_accumulation_steps,
+            'gradient_checkpointing': gradient_checkpointing,
+            'bf16': bf16,
+            'lr_scheduler_type': lr_scheduler_type,
         }
         url = f"{api_base}fine_tuning/grpo/initialize"
         headers = {'Content-Type': 'application/json'}
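
The passthrough itself is plain dict resolution: each hyperparameter is looked up in the user-supplied train_kwargs first and falls back to DEFAULT_TRAIN_KWARGS. A standalone illustration of that pattern (hypothetical values, not part of the commit):

defaults = {"update_interval": 10, "temperature": 0.9, "beta": 0.04}
user_kwargs = {"temperature": 0.5}  # caller overrides only the temperature

# Mirrors the .get(key, DEFAULT_TRAIN_KWARGS[key]) calls in initialize().
resolved = {k: user_kwargs.get(k, v) for k, v in defaults.items()}
print(resolved)  # {'update_interval': 10, 'temperature': 0.5, 'beta': 0.04}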
