Skip to content

Commit f65410b

Browse files
committed
add notebook for reinforce-learning
1 parent 334bb09 commit f65410b

File tree

9 files changed

+1933
-0
lines changed

9 files changed

+1933
-0
lines changed

sdk/python/foundation-models/system/reinforcement-learning/reinforcement-learning.ipynb

Lines changed: 783 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
azure-ai-ml
2+
azure-identity
3+
huggingface_hub
4+
matplotlib
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
import os
2+
import json
3+
import subprocess
4+
import pandas as pd
5+
from tempfile import TemporaryDirectory
6+
from azure.ai.ml import MLClient
7+
from azure.ai.ml.entities import Data
8+
from azure.ai.ml.constants import AssetTypes
9+
10+
11+
def register_dataset(ml_client: MLClient, dataset_name: str, file_path: str):
    """Create or update a URI-file data asset in Azure ML for a single file.

    Args:
        ml_client: Client whose ``data.create_or_update`` registers the asset.
        dataset_name: Name under which the asset is registered.
        file_path: Path of the file backing the asset.

    Returns:
        The object returned by ``ml_client.data.create_or_update``.
    """
    # The asset is always submitted as version "1" with fixed FinQA metadata.
    asset_spec = {
        "name": dataset_name,
        "path": file_path,
        "type": AssetTypes.URI_FILE,
        "description": "FinQA dataset",
        "tags": {"source": "https://github.com/czyssrs/FinQA"},
        "version": "1",
    }

    registered = ml_client.data.create_or_update(Data(**asset_spec))
    print(f"Registered dataset {registered.name}.")
    return registered
25+
26+
27+
def download_finqa_dataset(src: str, target_dir: str = "data/raw"):
    """Clone the FinQA repository and convert its dataset files to JSON Lines.

    The repository is cloned into a temporary directory that is removed when
    the function returns. Every ``.json`` file found in the clone's
    ``dataset`` folder is converted and written to ``target_dir`` under the
    same base name with a ``.jsonl`` extension.

    Args:
        src: Git URL (or path) of the repository to clone.
        target_dir: Directory that receives the converted ``.jsonl`` files;
            created if it does not exist.

    Raises:
        subprocess.CalledProcessError: If ``git clone`` exits with a non-zero
            status.
    """
    with TemporaryDirectory() as tmpdir:
        print(f"Cloning raw FinQA dataset to {tmpdir} ...")
        subprocess.run(["git", "clone", src, tmpdir], check=True)
        print("Converting FinQA dataset to jsonl format ...")
        dataset_dir = os.path.join(tmpdir, "dataset")

        # Creating the output directory does not depend on the file being
        # converted, so do it once instead of once per file.
        os.makedirs(target_dir, exist_ok=True)

        for file_name in sorted(os.listdir(dataset_dir)):
            source_path = os.path.join(dataset_dir, file_name)
            stem, extension = os.path.splitext(file_name)
            # Only regular .json files can be parsed by convert_to_jsonl;
            # anything else (subdirectories, READMEs, ...) is skipped rather
            # than aborting the whole conversion.
            if extension.lower() != ".json" or not os.path.isfile(source_path):
                continue
            convert_to_jsonl(
                current_path=source_path,
                target_path=os.path.join(target_dir, stem + ".jsonl"),
            )
38+
39+
40+
def convert_to_jsonl(current_path: str, target_path: str):
    """Convert a JSON file holding an array into a JSON Lines file.

    Args:
        current_path: Path of the source file; its content is parsed as one
            JSON document and iterated item by item.
        target_path: Path of the output file; each item is written as one
            JSON document per line.

    Raises:
        OSError: If either file cannot be opened.
        json.JSONDecodeError: If the source file is not valid JSON.
    """
    # Parse the source completely before touching the target, so a missing or
    # malformed source never leaves an empty/truncated output file behind.
    # The encoding is pinned so the result does not depend on the platform's
    # default locale.
    with open(current_path, "r", encoding="utf-8") as rf:
        items = json.load(rf)

    with open(target_path, "w", encoding="utf-8") as wf:
        wf.writelines(json.dumps(item) + "\n" for item in items)

    print(f"Converted {current_path} to {target_path}.")
47+
48+
49+
def _format_list_to_string(data_list) -> str:
    """Join the items of a list with newlines; strings pass through, empty input gives ""."""
    if not data_list:
        return ""
    if isinstance(data_list, str):
        return data_list
    return "\n".join(str(item) for item in data_list)


def _format_table(table_list) -> str:
    """Render table rows as " | "-separated lines under a "Table:" header."""
    if not table_list:
        return ""
    rendered_rows = [
        " | ".join(str(cell) for cell in row) if isinstance(row, list) else str(row)
        for row in table_list
    ]
    return "\nTable:\n" + "".join(line + "\n" for line in rendered_rows)


def _build_finqa_record(example: pd.Series, idx, split: str, data_source: str) -> dict:
    """Turn one raw FinQA row into the prompt/reward record written to disk."""
    pre_instruction = "Please answer the following financial question based on the context provided."
    post_instruction = 'Let\'s think step by step and output the final answer after "####".'

    qa = example.get("qa", {})
    if not isinstance(qa, dict):
        # A row without a usable "qa" object still yields a record with empty fields.
        qa = {}
    question = qa.get("question", "")

    answer = qa.get("answer")
    if answer is None or answer == "":
        # A present-but-empty "answer" would give the reward function nothing
        # to compare against, so fall back to "exe_ans" in that case as well
        # as when the key is missing.
        answer = qa.get("exe_ans", "")

    # str() keeps the join from failing should an evidence value not be a string.
    gold_evidence = "\n".join(str(value) for value in (qa.get("gold_inds") or {}).values())
    pre_text = _format_list_to_string(example.get("pre_text", []))
    post_text = _format_list_to_string(example.get("post_text", []))
    table = _format_table(example.get("table", [])).strip()

    # Build prompt content according to specified schema
    prompt_content = "\n\n".join(
        [
            pre_instruction,
            "Context: " + pre_text,
            gold_evidence,
            post_text,
            table,
            "Question: " + question,
            post_instruction,
        ]
    )
    return {
        "data_source": data_source,
        "prompt": [
            {
                "role": "user",
                "content": prompt_content,
            }
        ],
        "ability": "financial_reasoning",
        "reward_model": {"style": "rule", "ground_truth": answer},
        "extra_info": {
            "index": idx,
            "answer": answer,
            "question": question,
            "split": split,
        },
    }


def prepare_finqa_dataset(ml_client: MLClient, data_dir: str = "data", register_datasets: bool = False) -> tuple[str, str, str]:
    """Download FinQA, reshape it into prompt/reward records and save it as JSON Lines.

    The raw dataset is downloaded into ``<data_dir>/raw``; the transformed
    train, test and validation splits are written to ``<data_dir>/train.jsonl``,
    ``<data_dir>/test.jsonl`` and ``<data_dir>/valid.jsonl``.

    Args:
        ml_client: Azure ML client; only used when ``register_datasets`` is true.
        data_dir: Directory that receives the raw and the transformed files.
        register_datasets: When true, register the three files as data assets
            named ``finqa_train``, ``finqa_test`` and ``finqa_valid``.

    Returns:
        A ``(train, test, valid)`` tuple. These are the registered asset IDs
        when registration was requested and all three assets report an ID;
        otherwise they are the local file paths.
    """
    # VERL finetuning relies on acceptable data sources for reward modeling and evaluation
    data_source = "openai/gsm8k"

    # download and convert dataset
    raw_data_dir = os.path.join(data_dir, "raw")
    FINQA_GIT_REPO = "https://github.com/czyssrs/FinQA"
    download_finqa_dataset(src=FINQA_GIT_REPO, target_dir=raw_data_dir)

    # (split label, raw file name, output file name); the raw validation
    # split is stored as "dev" and is published as "valid".
    split_files = (
        ("train", "train.jsonl", "train.jsonl"),
        ("test", "test.jsonl", "test.jsonl"),
        ("valid", "dev.jsonl", "valid.jsonl"),
    )

    # Load and map every split before writing anything, so a failure while
    # reading one split does not leave a partially written output set.
    mapped = {}
    for split, raw_name, _ in split_files:
        frame = pd.read_json(os.path.join(raw_data_dir, raw_name), lines=True)
        mapped[split] = frame.apply(
            lambda row, split=split: _build_finqa_record(row, row.name, split, data_source),
            axis=1,
        )

    # save locally as jsonl
    output_paths = {}
    for split, _, out_name in split_files:
        output_paths[split] = os.path.join(data_dir, out_name)
        mapped[split].to_json(output_paths[split], orient="records", lines=True)

    train_dataset_path = output_paths["train"]
    test_dataset_path = output_paths["test"]
    valid_dataset_path = output_paths["valid"]

    # register datasets
    if register_datasets:
        train_data = register_dataset(ml_client, "finqa_train", train_dataset_path)
        test_data = register_dataset(ml_client, "finqa_test", test_dataset_path)
        valid_data = register_dataset(ml_client, "finqa_valid", valid_dataset_path)
        if all(asset and asset.id for asset in (train_data, test_data, valid_data)):
            return train_data.id, test_data.id, valid_data.id

    # Local paths are returned when registration was not requested or when
    # any registered asset came back without an ID.
    return train_dataset_path, test_dataset_path, valid_dataset_path
Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
import uuid
2+
import requests
3+
from typing import Optional
4+
from azure.ai.ml import MLClient
5+
from azure.ai.ml.entities import (
6+
EndpointAuthKeys,
7+
ManagedOnlineEndpoint,
8+
ManagedOnlineDeployment,
9+
KubernetesOnlineEndpoint,
10+
KubernetesOnlineDeployment,
11+
ProbeSettings,
12+
OnlineRequestSettings,
13+
)
14+
15+
16+
def get_default_probe_settings() -> ProbeSettings:
    """Build the default ``ProbeSettings`` used for deployments.

    Probes are APIs exposed by the deployment which inform the framework
    whether the deployment is healthy and ready to receive traffic.
    """
    probe_options = {
        "initial_delay": 1400,
        "period": 30,
        "timeout": 2,
        "success_threshold": 1,
        "failure_threshold": 30,
    }
    return ProbeSettings(**probe_options)
25+
26+
27+
def get_default_request_settings() -> OnlineRequestSettings:
    """Build the default ``OnlineRequestSettings`` used for deployments.

    The settings control the request timeout and the number of concurrent
    requests allowed per instance.
    """
    request_options = {
        "request_timeout_ms": 90000,
        "max_concurrent_requests_per_instance": 4,
    }
    return OnlineRequestSettings(**request_options)
33+
34+
35+
def create_managed_deployment(
    ml_client: MLClient,
    model_asset_id: str,  # Asset ID of the model to deploy
    instance_type: str,  # Supported instance type for managed deployment
    environment_asset_id: Optional[str] = None,  # Asset ID of the serving engine to use
    endpoint_name: Optional[str] = None,
    endpoint_description: str = "Sample endpoint",
    endpoint_tags: Optional[dict] = None,
    deployment_name: Optional[str] = None,
    deployment_env_vars: Optional[dict] = None,
) -> str:
    """Create a managed online endpoint with one deployment receiving all traffic.

    Args:
        ml_client: Azure ML client used to create the endpoint and deployment.
        model_asset_id: Asset ID of the model to deploy.
        instance_type: Instance type for the managed deployment.
        environment_asset_id: Asset ID of the serving environment, if any.
        endpoint_name: Base name of the endpoint; defaults to ``"rl-endpoint"``.
            An 8-character UUID fragment is always appended.
        endpoint_description: Description stored on the endpoint.
        endpoint_tags: Tags stored on the endpoint; defaults to no tags.
        deployment_name: Name of the deployment; defaults to ``"default"``.
        deployment_env_vars: Environment variables for the deployment;
            defaults to none.

    Returns:
        The final endpoint name, including the generated suffix.
    """
    # Copy the caller's dictionaries (and avoid shared mutable defaults) so the
    # SDK entities never hold a reference to an object owned by someone else.
    endpoint_tags = dict(endpoint_tags or {})
    deployment_env_vars = dict(deployment_env_vars or {})

    guid = str(uuid.uuid4())[:8]  # Unique suffix to avoid name collisions
    endpoint_name = endpoint_name or "rl-endpoint"
    endpoint_name = f"{endpoint_name}-{guid}"  # Unique names prevent collisions and allow parallel experiments
    deployment_name = deployment_name or "default"

    endpoint = ManagedOnlineEndpoint(  # Use AzureML endpoint abstraction for traffic management and auth
        name=endpoint_name,
        auth_mode="key",
        description=endpoint_description,
        tags=endpoint_tags,
    )

    print(f"Creating endpoint: {endpoint_name}")
    # Submit the endpoint object to create the actual endpoint in the AML
    # workspace and block until the operation finishes.
    ml_client.online_endpoints.begin_create_or_update(endpoint).wait()

    deployment = ManagedOnlineDeployment(  # Use deployment abstraction for scaling, versioning, and isolation
        name=deployment_name,
        endpoint_name=endpoint_name,
        model=model_asset_id,
        instance_type=instance_type,
        instance_count=1,
        environment=environment_asset_id,
        environment_variables=deployment_env_vars,
        liveness_probe=get_default_probe_settings(),
        readiness_probe=get_default_probe_settings(),
        request_settings=get_default_request_settings(),
    )

    print("Creating deployment (15-20 min)...")
    ml_client.online_deployments.begin_create_or_update(deployment).wait()

    # Route all traffic to new deployment for immediate use
    endpoint.traffic = {deployment_name: 100}
    ml_client.online_endpoints.begin_create_or_update(endpoint).result()

    print(f"Endpoint ready: {endpoint_name}")

    return endpoint_name
85+
86+
87+
def create_kubernetes_deployment(
    ml_client: MLClient,
    model_asset_id: str,  # Asset ID of the model to deploy
    environment_asset_id: str,  # Asset ID of the serving engine to use
    instance_type: str,  # Kubernetes supports partial node usage, granular up to the GPU level
    compute_name: str,  # Name of the compute which will be used for endpoint creation
    endpoint_name: Optional[str] = None,
    endpoint_description: str = "Sample endpoint",
    endpoint_tags: Optional[dict] = None,
    deployment_name: Optional[str] = None,
    deployment_env_vars: Optional[dict] = None,
    model_mount_path: str = "/var/model-mount",
) -> str:
    """Create a Kubernetes online endpoint with one deployment receiving all traffic.

    Args:
        ml_client: Azure ML client used to create the endpoint and deployment.
        model_asset_id: Asset ID of the model to deploy.
        environment_asset_id: Asset ID of the serving environment.
        instance_type: Instance type for the Kubernetes deployment.
        compute_name: Name of the compute the endpoint is created on.
        endpoint_name: Base name of the endpoint; defaults to ``"rl-endpoint"``.
            An 8-character UUID fragment is always appended.
        endpoint_description: Description stored on the endpoint.
        endpoint_tags: Tags stored on the endpoint; defaults to no tags.
        deployment_name: Name of the deployment; defaults to ``"default"``.
        deployment_env_vars: Environment variables for the deployment;
            defaults to none.
        model_mount_path: Path passed to the deployment as the model mount path.

    Returns:
        The final endpoint name, including the generated suffix.
    """
    # Copy the caller's dictionaries (and avoid shared mutable defaults) so the
    # SDK entities never hold a reference to an object owned by someone else.
    endpoint_tags = dict(endpoint_tags or {})
    deployment_env_vars = dict(deployment_env_vars or {})

    print("🌐 Creating endpoint...")

    guid = str(uuid.uuid4())[:8]  # Unique suffix to avoid name collisions
    endpoint_name = endpoint_name or "rl-endpoint"
    endpoint_name = f"{endpoint_name}-{guid}"  # Unique names prevent collisions and allow parallel experiments
    deployment_name = deployment_name or "default"

    endpoint = KubernetesOnlineEndpoint(  # Use AzureML endpoint abstraction for traffic management and auth
        name=endpoint_name,
        auth_mode="key",
        compute=compute_name,
        description=endpoint_description,
        tags=endpoint_tags,
    )

    print(f"Creating endpoint: {endpoint_name}")
    # Submit the endpoint object to create the actual endpoint in the AML
    # workspace and block until the operation finishes.
    ml_client.online_endpoints.begin_create_or_update(endpoint).wait()

    deployment = KubernetesOnlineDeployment(  # Use deployment abstraction for scaling, versioning, and isolation
        name=deployment_name,
        endpoint_name=endpoint_name,
        model=model_asset_id,
        model_mount_path=model_mount_path,
        instance_type=instance_type,
        instance_count=1,
        environment=environment_asset_id,
        environment_variables=deployment_env_vars,
        liveness_probe=get_default_probe_settings(),
        readiness_probe=get_default_probe_settings(),
        request_settings=get_default_request_settings(),
    )

    print("Creating deployment (15-20 min)...")
    ml_client.online_deployments.begin_create_or_update(deployment).wait()

    # Route all traffic to new deployment for immediate use
    endpoint.traffic = {deployment_name: 100}
    ml_client.online_endpoints.begin_create_or_update(endpoint).result()

    print(f"Endpoint ready: {endpoint_name}")

    return endpoint_name
144+
145+
146+
def test_deployment(ml_client, endpoint_name, timeout: float = 120.0):
    """Run a test request against a deployed endpoint and print the result.

    Args:
        ml_client: Azure ML client used to look up the endpoint's scoring URI
            and authentication keys.
        endpoint_name: Name of the endpoint to call.
        timeout: Seconds to wait for the HTTP request before giving up, so a
            stalled endpoint cannot block the caller indefinitely.

    Returns:
        The decoded JSON response when the endpoint answers with HTTP 200,
        otherwise ``None``.

    Raises:
        ValueError: If the endpoint has no scoring URI or no primary key.
        requests.exceptions.RequestException: If the request fails at the
            transport level, including when ``timeout`` elapses.
    """
    print("Testing endpoint...")
    # Retrieve endpoint URI and API key to authenticate test request
    scoring_uri = ml_client.online_endpoints.get(endpoint_name).scoring_uri
    if not scoring_uri:
        raise ValueError("Scoring URI not found for endpoint.")

    api_keys = ml_client.online_endpoints.get_keys(endpoint_name)
    if not isinstance(api_keys, EndpointAuthKeys) or not api_keys.primary_key:
        raise ValueError("API key not found for endpoint.")

    # Use a realistic financial question to verify model reasoning and output format
    prompt = (
        "Please answer the following financial question:\n"
        "\n"
        "Context: A company has revenue of $1,000,000 and expenses of $750,000.\n"
        "\n"
        "Question: What is the profit margin as a percentage?\n"
        "Let's think step by step and put final answer after ####."
    )
    payload = {
        "messages": [
            {
                "role": "user",
                "content": prompt,
            }
        ],
        "max_tokens": 512,
        "temperature": 0.7,
    }

    # Set headers for JSON content and bearer authentication
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {api_keys.primary_key}",
    }

    response = requests.post(scoring_uri, json=payload, headers=headers, timeout=timeout)

    if response.status_code != 200:
        print(f" ✗ Error: {response.status_code}")
        print(f" {response.text}")
        return None

    result = response.json()
    # Extract the model response
    if "choices" in result and len(result["choices"]) > 0:
        answer = result["choices"][0]["message"]["content"]
        print("Response received")
        print(f"\n{'='*60}")
        print(answer)
        print(f"{'='*60}\n")
    return result

0 commit comments

Comments
 (0)