
Commit c130706

pcuenca and FL33TW00D authored
Stateful cache, MLTensor (#257)
* feat: preview
* Remove Random (#115)
* Throwing error when the configs fail JSON serialization (#114)
* Added error for JSON serialization errors
* Fix merge commit

---------

Co-authored-by: Pedro Cuenca <pedro@huggingface.co>

* Allow archiving for Mac (#121)

  SPM dependencies are always compiled for the standard architectures, but Float16 is not available for `x86_64`. Thanks @joshnewnham for the workaround 🙌

* chore: strategic deletes avoid OOM
* Remove RepetitionPenaltyWarper, fix build
* Remove GenerationTests
* Restore TokenizerError
* Fix deprecation warnings in tests
* Move transformers-cli to an example
* Format
* Relax requirements for main package

  But keep iOS 18 / macOS 15 for Core ML

* Revert platform requirements
* Relative package location plus comment
* Mistral example: uv-ify and unpin
* Remove obsolete GenerationTests again

---------

Co-authored-by: FL33TW00D <FL33TW00D@users.noreply.github.com>
1 parent 105b915 commit c130706

28 files changed (+945, -1174 lines)

Examples/Mistral7B/README.md

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
### Export Mistral 7B Instruct v0.3

```shell
✗ uv run export.py

Loading checkpoint shards: 100%|███████████████████████████| 3/3 [00:12<00:00, 4.11s/it]
Converting PyTorch Frontend ==> MIL Ops: 100%|███| 5575/5575 [00:02<00:00, 2440.66 ops/s]
Running MIL frontend_pytorch pipeline: 100%|██████████| 5/5 [00:00<00:00, 7.12 passes/s]
Running MIL default pipeline: 100%|█████████████████| 79/79 [02:36<00:00, 1.98s/ passes]
Running MIL backend_mlprogram pipeline: 100%|███████| 12/12 [00:00<00:00, 22.90 passes/s]
Running compression: 100%|███████████████████████████| 296/296 [03:04<00:00, 1.60 ops/s]
...
```

### Generate Text

```shell
✗ swift run transformers-cli "Best recommendations for a place to visit in Paris in August 2024:" --max-length 128 StatefulMistral7BInstructInt4.mlpackage

Best recommendations for a place to visit in Paris in August 2024:

1. Palace of Versailles: This iconic palace is a must-visit. It's a short train ride from Paris and offers a glimpse into the opulence of the French monarchy.

2. Eiffel Tower: No trip to Paris is complete without a visit to the Eiffel Tower. You can take an elevator ride to the top for a stunning view of the city.

3. Louvre Museum: Home to thousands of works of art, including the Mona Lisa and the Winged Victory of Samothrace, the Louvre is a cultural treasure.
```
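
For reference, the same package can also be driven from Python with the `generate.py` helpers added later in this commit. A minimal sketch (assuming the `.mlpackage` sits next to the script in `Examples/Mistral7B`, and reusing the prompt from the example above):

```python
# Sketch only: load() and generate() come from Examples/Mistral7B/generate.py.
from generate import load, generate

# Load the Core ML package plus the tokenizer recorded in its metadata.
model, tokenizer = load("StatefulMistral7BInstructInt4.mlpackage")

text = generate(
    model,
    prompt="Best recommendations for a place to visit in Paris in August 2024:",
    tokenizer=tokenizer,
    max_new_tokens=128,
)
print(text)
```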

Examples/Mistral7B/export.py

Lines changed: 230 additions & 0 deletions
@@ -0,0 +1,230 @@
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "coremltools",
#     "numpy",
#     "sentencepiece",
#     "torch",
#     "tqdm",
#     "transformers",
# ]
# ///
import logging
import os
import warnings
from typing import List, Optional, Tuple

import coremltools as ct
import numpy as np
import torch
from transformers.cache_utils import Cache
from transformers.models.mistral.modeling_mistral import (
    MISTRAL_ATTENTION_CLASSES,
    MistralAttention,
    MistralConfig,
    MistralForCausalLM,
    apply_rotary_pos_emb,
    repeat_kv,
)

warnings.filterwarnings("ignore")
logging.getLogger("coremltools").setLevel(logging.ERROR)
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3
MODEL_ID: str = "mistralai/Mistral-7B-Instruct-v0.3"
METADATA_TOKENIZER: str = "co.huggingface.exporters.name"


class SliceUpdateKeyValueCache(Cache):
    def __init__(
        self,
        shape: Tuple[int, ...],
        device="cpu",
        dtype=torch.float32,
    ) -> None:
        """KV cache of shape (#layers, batch_size, #kv_heads, context_size, head_dim)."""
        super().__init__()
        self.past_seen_tokens: int = 0
        self.k_cache: torch.Tensor = torch.zeros(shape, dtype=dtype, device=device)
        self.v_cache: torch.Tensor = torch.zeros(shape, dtype=dtype, device=device)

    def update(
        self,
        k_state: torch.Tensor,
        v_state: torch.Tensor,
        layer_idx: int,
        slice_indices: torch.LongTensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Update key/value cache tensors for slice [slice_indices[0], slice_indices[1]).
        Return slice of key/value cache tensors from [0, slice_indices[1]).
        """
        if len(slice_indices) != 2:
            raise ValueError(f"Expect tuple of integers [start, end), got {slice_indices=}.")
        begin, end = slice_indices
        self.k_cache[layer_idx, :, : k_state.shape[1], begin:end, :] = k_state
        self.v_cache[layer_idx, :, : v_state.shape[1], begin:end, :] = v_state
        k_cache: torch.Tensor = self.k_cache[layer_idx, :, :, :end, :]
        v_cache: torch.Tensor = self.v_cache[layer_idx, :, :, :end, :]
        return k_cache, v_cache

    def get_seq_length(self, _: int | None = 0) -> int:
        """Get the sequence length of the cache."""
        return self.past_seen_tokens


class SliceUpdateMistralAttention(MistralAttention):
    def __init__(self, config: MistralConfig, layer_idx: Optional[int] = None):
        super().__init__(config=config, layer_idx=layer_idx)

    @torch.no_grad()
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor | None, ...]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(
            1, 2
        )
        value_states = value_states.view(
            bsz, q_len, self.num_key_value_heads, self.head_dim
        ).transpose(1, 2)

        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        # Slice update key/value cache
        end_step = attention_mask.shape[-1]
        key_states, value_states = past_key_value.update(
            key_states,
            value_states,
            self.layer_idx,
            slice_indices=(end_step - q_len, end_step),
        )

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
        )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)
        return attn_output, None, None


class StatefulMistralForCausalLM(torch.nn.Module):
    def __init__(self, model_path: str, max_context_size: int = 2048, batch_size: int = 1) -> None:
        super().__init__()

        # Custom attention implementation for the stateful slice-update key/value cache; override
        # "sdpa" to comply with transformers.modeling_utils._autoset_attn_implementation
        MISTRAL_ATTENTION_CLASSES["sdpa"] = SliceUpdateMistralAttention
        self.model = MistralForCausalLM.from_pretrained(model_path)

        # Register KV cache buffers to be recognized as Core ML states
        config: MistralConfig = self.model.config
        self.kv_cache_shape: Tuple[int, ...] = (
            config.num_hidden_layers,
            batch_size,
            config.num_key_value_heads,
            max_context_size,
            config.hidden_size // config.num_attention_heads,
        )
        self.kv_cache = SliceUpdateKeyValueCache(shape=self.kv_cache_shape)
        self.register_buffer("keyCache", self.kv_cache.k_cache)
        self.register_buffer("valueCache", self.kv_cache.v_cache)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.LongTensor,
        causal_mask: torch.Tensor,
    ) -> torch.Tensor:
        # Compute past seen tokens used for updating key/value cache slices
        self.kv_cache.past_seen_tokens = causal_mask.shape[-1] - input_ids.shape[-1]
        return self.model(
            input_ids,
            attention_mask=causal_mask,
            past_key_values=self.kv_cache,
            use_cache=True,
        ).logits


def export() -> None:
    # Construct model from transformers and trace to TorchScript
    max_context_size: int = 2048
    torch_model = StatefulMistralForCausalLM(MODEL_ID, max_context_size=max_context_size)
    torch_model.eval()
    input_ids: torch.Tensor = torch.zeros((1, 2), dtype=torch.int32)
    causal_mask: torch.Tensor = torch.zeros((1, 1, 2, 5), dtype=torch.float32)
    traced_model = torch.jit.trace(torch_model, [input_ids, causal_mask])
    kv_cache_shape = torch_model.kv_cache_shape
    del torch_model

    # Convert traced TorchScript to Core ML format
    query_length = ct.RangeDim(lower_bound=1, upper_bound=max_context_size, default=1)
    end_step_dim = ct.RangeDim(lower_bound=1, upper_bound=max_context_size, default=1)
    inputs: List[ct.TensorType] = [
        ct.TensorType(shape=(1, query_length), dtype=np.int32, name="inputIds"),
        ct.TensorType(
            shape=(1, 1, query_length, end_step_dim),
            dtype=np.float16,
            name="causalMask",
        ),
    ]
    outputs: List[ct.TensorType] = [ct.TensorType(dtype=np.float16, name="logits")]
    states: List[ct.StateType] = [
        ct.StateType(
            wrapped_type=ct.TensorType(shape=kv_cache_shape, dtype=np.float16),
            name="keyCache",
        ),
        ct.StateType(
            wrapped_type=ct.TensorType(shape=kv_cache_shape, dtype=np.float16),
            name="valueCache",
        ),
    ]

    # Convert model with FP16 precision
    mlmodel_fp16: ct.MLModel = ct.convert(
        traced_model,
        inputs=inputs,
        outputs=outputs,
        states=states,
        minimum_deployment_target=ct.target.iOS18,
        skip_model_load=True,
    )
    del traced_model

    # Block-wise quantize model weights to int4
    op_config = ct.optimize.coreml.OpLinearQuantizerConfig(
        mode="linear_symmetric",
        dtype="int4",
        granularity="per_block",
        block_size=32,
    )
    config = ct.optimize.coreml.OptimizationConfig(global_config=op_config)
    mlmodel_int4 = ct.optimize.coreml.linear_quantize_weights(mlmodel_fp16, config=config)
    mlmodel_int4._spec.description.metadata.userDefined.update({METADATA_TOKENIZER: MODEL_ID})
    del mlmodel_fp16
    mlmodel_int4.save("StatefulMistral7BInstructInt4.mlpackage")


if __name__ == "__main__":
    export()
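
The cache above is what lets the traced model write keys/values in place, and read back only the filled prefix, instead of concatenating ever-growing tensors. A tiny sketch of that contract, with toy shapes and assuming `export.py` is importable from the same directory: `update()` writes the new slice at `[begin, end)` and returns the cache contents up to `end`. These are the same buffers that `StatefulMistralForCausalLM` registers as `keyCache`/`valueCache` and that `export()` exposes as Core ML states.

```python
import torch

from export import SliceUpdateKeyValueCache  # the class defined above

# Toy cache: 2 layers, batch 1, 1 kv head, context length 8, head_dim 4 (illustrative only).
cache = SliceUpdateKeyValueCache(shape=(2, 1, 1, 8, 4))

# Prompt pass with 3 tokens: write slice [0, 3) for layer 0, read back [0, 3).
k, v = cache.update(
    torch.randn(1, 1, 3, 4), torch.randn(1, 1, 3, 4), layer_idx=0, slice_indices=(0, 3)
)
assert k.shape == (1, 1, 3, 4) and v.shape == (1, 1, 3, 4)

# One decode step: write slice [3, 4), read back the whole prefix [0, 4).
k, v = cache.update(
    torch.randn(1, 1, 1, 4), torch.randn(1, 1, 1, 4), layer_idx=0, slice_indices=(3, 4)
)
assert k.shape == (1, 1, 4, 4) and v.shape == (1, 1, 4, 4)
```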

Examples/Mistral7B/generate.py

Lines changed: 99 additions & 0 deletions
@@ -0,0 +1,99 @@
# /// script
# requires-python = ">=3.10"
# dependencies = [
#     "coremltools",
#     "numpy",
#     "sentencepiece",
#     "torch",
#     "tqdm",
#     "transformers",
# ]
# ///
import argparse
from typing import Dict, Generator, List, Tuple

import numpy as np
from coremltools.models import MLModel
from transformers import AutoTokenizer

from export import METADATA_TOKENIZER


def load(model_path: str) -> Tuple[MLModel, AutoTokenizer]:
    """Load a Core ML model and corresponding tokenizer."""
    model: MLModel = MLModel(model_path)
    description = model.get_spec().description
    if METADATA_TOKENIZER not in description.metadata.userDefined:
        raise ValueError("Model metadata does not contain tokenizer path.")
    tokenizer_path: str = description.metadata.userDefined[METADATA_TOKENIZER]
    tokenizer: AutoTokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    return model, tokenizer


def get_next_token(model: MLModel, prompt_tokens: np.ndarray) -> Generator[int, None, None]:
    """Generate a sequence of tokens with naive greedy decoding."""

    def sample(logits: np.ndarray) -> int:
        """Perform greedy decoding on the logits array to get the next token."""
        return int(np.argmax(logits[0][-1], axis=-1))

    def inference(model: MLModel, input_ids: np.ndarray, num_past_tokens: int) -> np.ndarray:
        """Perform inference with the given model and input data."""
        causal_mask: np.ndarray = np.triu(
            np.full(
                (1, 1, input_ids.shape[-1], num_past_tokens + input_ids.shape[-1]),
                fill_value=-np.inf if num_past_tokens == 0 else 0,
            ),
            k=1,
        ).astype(np.float16)
        outputs: Dict[str, np.ndarray] = model.predict(
            data={"inputIds": input_ids, "causalMask": causal_mask},
            state=kv_cache_state,
        )
        return outputs["logits"]

    kv_cache_state = model.make_state()
    logits: np.ndarray = inference(model, input_ids=prompt_tokens, num_past_tokens=0)
    token: int = sample(logits=logits)
    num_past_tokens: int = prompt_tokens.shape[-1]

    while True:
        yield token
        logits: np.ndarray = inference(
            model,
            input_ids=np.array([[token]], dtype=np.int32),
            num_past_tokens=num_past_tokens,
        )
        token: int = sample(logits=logits)
        num_past_tokens += 1


def generate(
    model: MLModel,
    prompt: str,
    tokenizer: AutoTokenizer,
    max_new_tokens: int,
) -> str:
    prompt_tokens: np.ndarray = tokenizer(prompt, return_tensors="np").input_ids
    extend_tokens: List[int] = []
    for i, token in enumerate(get_next_token(model, prompt_tokens=prompt_tokens.astype(np.int32))):
        if token == tokenizer.eos_token_id or i == max_new_tokens:
            break
        extend_tokens.append(token)
    return tokenizer.decode(prompt_tokens[0].tolist() + extend_tokens)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("model_path", type=str)
    parser.add_argument("--prompt", type=str, default="Hello")
    parser.add_argument("--max_new_tokens", type=int, default=128)
    args = parser.parse_args()
    model, tokenizer = load(args.model_path)
    extend_text: str = generate(
        model,
        prompt=args.prompt,
        tokenizer=tokenizer,
        max_new_tokens=args.max_new_tokens,
    )
    print(extend_text)
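
The causal mask fed to the model is built inline inside `inference()` above; pulling that expression into a standalone sketch (numpy only, toy sizes) makes its two regimes easier to see: strictly upper-triangular `-inf` on the prompt pass, all zeros once there are past tokens in the cache.

```python
import numpy as np


def causal_mask(query_len: int, num_past_tokens: int) -> np.ndarray:
    # Same expression as in generate.py's inference(), with toy dimensions.
    return np.triu(
        np.full(
            (1, 1, query_len, num_past_tokens + query_len),
            fill_value=-np.inf if num_past_tokens == 0 else 0,
        ),
        k=1,
    ).astype(np.float16)


# Prompt pass with 3 tokens: positions above the diagonal are -inf (masked).
print(causal_mask(3, 0)[0, 0])
# Single-token decode step with 3 past tokens: one row of zeros (nothing masked).
print(causal_mask(1, 3)[0, 0])
```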
Lines changed: 6 additions & 0 deletions
@@ -0,0 +1,6 @@
coremltools
numpy
torch
tqdm
transformers
sentencepiece
Lines changed: 24 additions & 0 deletions
@@ -0,0 +1,24 @@
// swift-tools-version: 6.2
// The swift-tools-version declares the minimum version of Swift required to build this package.

import PackageDescription

let package = Package(
    name: "transformers-cli",
    platforms: [.iOS(.v18), .macOS(.v15)],
    dependencies: [
        .package(path: "../.."),
        // If you copy this manifest as a template, use the following line instead
        //.package(url: "https://github.com/huggingface/swift-transformers", from: "1.0.0"),
        .package(url: "https://github.com/apple/swift-argument-parser", from: "1.3.0"),
    ],
    targets: [
        .executableTarget(
            name: "transformers-cli",
            dependencies: [
                .product(name: "Transformers", package: "swift-transformers"),
                .product(name: "ArgumentParser", package: "swift-argument-parser"),
            ]
        )
    ]
)
