|
| 1 | +import dspy |
| 2 | +import random |
| 3 | +from dspy.datasets import DataLoader |
| 4 | +from datasets import load_dataset |
| 5 | + |
# Banking77 intent-classification dataset: fetch the 77 label names first.
CLASSES = load_dataset(
    "PolyAI/banking77", split="train", trust_remote_code=True
).features['label'].names

# Take the first 2000 training rows and map each integer label to its name.
loader_kwargs = dict(fields=("text", "label"), input_keys=("text",), split="train", trust_remote_code=True)
hf_rows = DataLoader().from_huggingface(dataset_name="PolyAI/banking77", **loader_kwargs)
raw_data = []
for row in hf_rows[:2000]:
    raw_data.append(dspy.Example(row, label=CLASSES[row.label]).with_inputs("text"))

# Seeded shuffle so the train/val split below is reproducible across runs.
random.Random(0).shuffle(raw_data)
print(len(CLASSES), CLASSES[:10])

trainset = raw_data[:1500]     # first 1500 shuffled examples for training
valset = raw_data[1500:1600]   # next 100 for validation
print(trainset[0])
| 22 | + |
| 23 | +classify = dspy.ChainOfThought(f"text -> label: Literal{CLASSES}") |
| 24 | + |
from dspy.clients.lm_local_arbor import ArborProvider

# Local Arbor server that hosts (and fine-tunes) the student model.
port = 7453
arbor_api_base = f"http://localhost:{port}/v1/"
api_key = "arbor"
provider = ArborProvider()

# Alternative students, kept for reference:
#   meta-llama/Llama-3.2-3B-Instruct
#   Qwen/Qwen2.5-1.5B-Instruct
student_lm_name = "Qwen/Qwen3-0.6B"
student_lm = dspy.LM(
    model=f"openai/arbor:{student_lm_name}",
    provider=provider,
    temperature=0.5,
    api_base=arbor_api_base,
    api_key=api_key,
    max_tokens=3000,
)

# Give the student its own copy of the program, bound to the local LM.
student_classify = classify.deepcopy()
student_classify.set_lm(student_lm)
| 38 | + |
| 39 | +metric = (lambda x, y, trace=None: x.label == y.label) |
| 40 | + |
| 41 | + |
from dspy.teleprompt.grpo import GRPO

# Trainer hyperparameters forwarded to the Arbor GRPO fine-tuning job.
train_kwargs = {
    "update_interval": 10,                  # sync policy weights every 10 steps
    "per_device_train_batch_size": 4,
    "temperature": 0.5,                     # matches the student LM's sampling temperature
    "beta": 0.04,                           # KL penalty coefficient
    "learning_rate": 1e-5,
    "gradient_accumulation_steps": 1,
    "gradient_checkpointing": True,
    "bf16": True,
    "lr_scheduler_type": "constant_with_warmup",
}

# GRPO optimizer: 4 examples per step, 4 rollouts each, validated every 100 steps.
compiler = GRPO(
    metric=metric,
    multitask=True,
    num_dspy_examples_per_grpo_step=4,
    num_samples_per_input=4,
    exclude_demos=True,
    num_train_steps=500,
    num_threads=4,
    use_train_as_val=False,
    num_steps_for_val=100,
    train_kwargs=train_kwargs,
)
| 67 | + |
# Run GRPO fine-tuning of the student program on the train/val splits.
classify_ft = compiler.compile(
    student=student_classify,
    trainset=trainset,
    valset=valset,
)

# To score the fine-tuned program on the validation set, uncomment:
# evaluate = dspy.Evaluate(devset=valset, metric=metric, display_progress=True, display_table=5, num_threads=16)
# print(evaluate(classify_ft))
0 commit comments