
Commit aa1b470

yanboliang authored and Artyom17 committed
Merge pull request #118 from yanboliang/cleanup
Clean up mixtral-moe
2 parents b262949 + a3e825f commit aa1b470


12 files changed: +119, -804 lines


README.md

Lines changed: 15 additions & 7 deletions
@@ -22,10 +22,10 @@ Please check the rest of this page about benchmark of LLaMA family models.
 ### Mixtral 8x7B
 We also supported [Mixtral 8x7B](https://mistral.ai/news/mixtral-of-experts/) which is a high-quality sparse mixture of experts (MoE) model, the average token generation rates are:
 
-|                  |   1 GPU |    2 GPU  | 4 GPU  |    8 GPU   |
+|                  |   1 GPU |    2 GPU  | 4 GPU  |    8 GPU   |
 |------------------|---------|-----------|--------|------------|
-|baseline(bfloat16)|    OOM  |    78.75  | 118.23 |   203.69   |
-|        int8      |   56.04 |    99.91  | 149.53 |   218.48   |
+|baseline(bfloat16)|    OOM  |    96.67  | 155.35 |   227.82   |
+|        int8      |   97.92 |   155.03  | 216.87 |   279.35   |
 
 Note that the benchmarks run on an 8xA100-80GB, power limited to 330W with a hybrid cube mesh topology. Note that all benchmarks are run at *batch size=1*, making the reported tokens/s numbers equivalent to "tokens/s/user". In addition, they are run with a very small prompt length (just 5 tokens).
 
@@ -59,6 +59,9 @@ meta-llama/Llama-2-13b-chat-hf
 meta-llama/Llama-2-70b-chat-hf
 codellama/CodeLlama-7b-Python-hf
 codellama/CodeLlama-34b-Python-hf
+mistralai/Mistral-7B-v0.1
+mistralai/Mistral-7B-Instruct-v0.1
+mistralai/Mistral-7B-Instruct-v0.2
 ```
 
 For example, to convert Llama-2-7b-chat-hf
@@ -120,6 +123,11 @@ python generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model.pth
 To squeeze out a little bit more performance, you can also compile the prefill with `--compile_prefill`. This will increase compilation times though.
 
 ## Quantization
+Choose device to use by
+```bash
+# The current support devices: cuda, cpu
+export DEVICE=cuda
+```
 ### Int8 Weight-Only Quantization
 To generate this version of the model
 ```bash
@@ -128,19 +136,19 @@ python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode in
 ```
 To run with int8, just pass the int8 checkpoint to generate.py.
 ```bash
-python generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model_int8.pth
+python generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model_int8.pth --device $DEVICE
 ```
 
 ### Int4 Weight-Only Quantization
 To generate int4 version of model
 ```bash
-# Spits out model at checkpoints/$MODEL_REPO/model_int4.g32.pth
-python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4 --groupsize 32
+# Spits out model at checkpoints/$MODEL_REPO/model_int4.g32.$DEVICE.pth
+python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4 --groupsize 32 --device $DEVICE
 ```
 
 To run with int4, just pass the int4 checkpoint to generate.py.
 ```bash
-python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.pth --compile
+python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.$DEVICE.pth --compile --device $DEVICE
 ```
 
 ## Speculative Sampling
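
Taken together, these README changes make the quantization workflow device-aware: `quantize.py` and `generate.py` both take a `--device` flag, and the packed int4 checkpoint now carries the device name in its filename. A minimal end-to-end sketch of the resulting flow, assuming the checkpoint has already been converted; the model choice and the `--mode int8` spelling (inferred from the truncated hunk header) are illustrative, not copied from the diff:

```bash
# Sketch of the device-aware quantization flow described above.
export DEVICE=cuda                            # the README lists cuda and cpu as supported
export MODEL_REPO=mistralai/Mistral-7B-v0.1   # illustrative choice from the newly added model list

# int8: the README only adds --device to the generate step
python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int8
python generate.py --compile --checkpoint_path checkpoints/$MODEL_REPO/model_int8.pth --device $DEVICE

# int4: the packed checkpoint is written as model_int4.g32.$DEVICE.pth,
# so quantization and generation must agree on the device
python quantize.py --checkpoint_path checkpoints/$MODEL_REPO/model.pth --mode int4 --groupsize 32 --device $DEVICE
python generate.py --checkpoint_path checkpoints/$MODEL_REPO/model_int4.g32.$DEVICE.pth --compile --device $DEVICE
```

Because the int4 weight packing differs per device, generate.py checks the device name embedded in the filename against `--device` and aborts on a mismatch (see the generate.py diff below).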

generate.py

Lines changed: 14 additions & 9 deletions
@@ -16,7 +16,7 @@
 def device_sync(device):
     if "cuda" in device:
         torch.cuda.synchronize(device)
-    elif "cpu" in device:
+    elif ("cpu" in device) or ("mps" in device):
         pass
     else:
         print(f"device={device} is not yet suppported")
@@ -26,6 +26,7 @@ def device_sync(device):
 torch._inductor.config.triton.unique_kernel_names = True
 torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
 
+default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 # support running without installing as a package
 wd = Path(__file__).parent.parent.resolve()
@@ -206,13 +207,14 @@ def generate(
     }
     return seq, generate_stats
 
-def encode_tokens(tokenizer, string, bos=True, device='cuda'):
+def encode_tokens(tokenizer, string, bos=True, device=default_device):
     tokens = tokenizer.encode(string)
     if bos:
         tokens = [tokenizer.bos_id()] + tokens
     return torch.tensor(tokens, dtype=torch.int, device=device)
 
 def _load_model(checkpoint_path, device, precision, use_tp):
+    use_cuda = 'cuda' in device
     with torch.device('meta'):
         model = Transformer.from_name(checkpoint_path.parent.name)
 
@@ -223,15 +225,18 @@ def _load_model(checkpoint_path, device, precision, use_tp):
         model = simple_quantizer.convert_for_runtime()
 
     if "int4" in str(checkpoint_path):
-        print("Using int4 quantization!")
+        print("Using int4 weight-only quantization!")
         path_comps = checkpoint_path.name.split(".")
-        assert path_comps[-2].startswith("g")
-        groupsize = int(path_comps[-2][1:])
+        assert path_comps[-3].startswith("g")
+        assert path_comps[-2] in device, "weight packed format mismatch, please rerun quantize.py!"
+        groupsize = int(path_comps[-3][1:])
         from quantize import WeightOnlyInt4QuantHandler
         simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
-        model = simple_quantizer.convert_for_runtime()
+        model = simple_quantizer.convert_for_runtime(use_cuda)
 
     checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
+    if "model" in checkpoint and "stories" in str(checkpoint_path):
+        checkpoint = checkpoint["model"]
     model.load_state_dict(checkpoint, assign=True)
 
     if use_tp:
@@ -257,7 +262,7 @@ def main(
     profile: Optional[Path] = None,
     draft_checkpoint_path: Optional[Path] = None,
     speculate_k: int = 5,
-    device='cuda',
+    device=default_device,
 ) -> None:
     """Generates text samples based on a pre-trained Transformer model and tokenizer.
     """
@@ -310,7 +315,7 @@ def main(
         decode_one_token = torch.compile(decode_one_token, mode="reduce-overhead", fullgraph=True)
 
     # Uncomment to squeeze more perf out of prefill
-    if args.compile_prefill:
+    if compile_prefill:
         prefill = torch.compile(prefill, fullgraph=True, dynamic=True)
 
 
@@ -412,7 +417,7 @@ def callback(x):
     parser.add_argument('--profile', type=Path, default=None, help='Profile path.')
     parser.add_argument('--speculate_k', type=int, default=5, help='Speculative execution depth.')
     parser.add_argument('--draft_checkpoint_path', type=Path, default=None, help='Draft checkpoint path.')
-    parser.add_argument('--device', type=str, default="cuda", help='device to use')
+    parser.add_argument('--device', type=str, default=default_device, help='Device to use')
 
     args = parser.parse_args()
     main(
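
The int4 branch of `_load_model` now derives both the groupsize and the expected device from the checkpoint filename (the `model_int4.g32.cuda.pth` pattern produced by the updated quantize.py) instead of just the groupsize. A small standalone sketch of that check; `parse_int4_name` is a hypothetical helper written here only to mirror the `path_comps` logic above:

```python
from pathlib import Path

def parse_int4_name(checkpoint_path: Path, device: str) -> int:
    """Hypothetical mirror of the path_comps checks in _load_model:
    expects names of the form model_int4.g{groupsize}.{device}.pth."""
    path_comps = checkpoint_path.name.split(".")
    # e.g. ["model_int4", "g32", "cuda", "pth"]
    assert path_comps[-3].startswith("g"), "missing groupsize component"
    assert path_comps[-2] in device, "weight packed format mismatch, please rerun quantize.py!"
    return int(path_comps[-3][1:])  # groupsize

# A checkpoint packed for cuda parses cleanly on a cuda run...
print(parse_int4_name(Path("checkpoints/model_int4.g32.cuda.pth"), device="cuda"))  # -> 32
# ...while the same file on a cpu run would trip the mismatch assertion.
```

This is why the README examples pass the same `$DEVICE` to both quantize.py and generate.py: a checkpoint packed for one device fails fast on another instead of loading an incompatible packed-weight layout.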

mixtral-moe/README.md

Lines changed: 3 additions & 4 deletions
@@ -12,11 +12,10 @@ python scripts/convert_hf_checkpoint.py --checkpoint_dir checkpoints/$MODEL_REPO
 ## Benchmarks
 Benchmarks run on an 8xA100-80GB, power limited to 330W with a hybrid cube mesh topology. Note that all benchmarks are run at *batch size=1*, making the reported tokens/s numbers equivalent to "tokens/s/user". In addition, they are run with a very small prompt length (just 5 tokens).
 
-|                  |   1 GPU |    2 GPU  | 4 GPU  |    8 GPU   |
+|                  |   1 GPU |    2 GPU  | 4 GPU  |    8 GPU   |
 |------------------|---------|-----------|--------|------------|
-|baseline(bfloat16)|    OOM  |    78.75  | 118.23 |   203.69   |
-|        int8      |   56.04 |    99.91  | 149.53 |   218.48   |
-
+|baseline(bfloat16)|    OOM  |    96.67  | 155.35 |   227.82   |
+|        int8      |   97.92 |   155.03  | 216.87 |   279.35   |
 
 
 ## Generate Text
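
For a sense of how the refreshed numbers compare, a quick ratio check using only the figures in the table above (illustrative arithmetic, not part of the repo):

```python
# Tokens/s figures copied from the updated table above.
bf16 = {2: 96.67, 4: 155.35, 8: 227.82}             # baseline(bfloat16); 1 GPU is OOM
int8 = {1: 97.92, 2: 155.03, 4: 216.87, 8: 279.35}  # int8 weight-only

for gpus, baseline in bf16.items():
    print(f"{gpus} GPU: int8 runs at {int8[gpus] / baseline:.2f}x the bfloat16 rate")
# Roughly 1.60x, 1.40x, and 1.23x, and int8 additionally fits on a single GPU.
```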
