This repository was archived by the owner on Oct 9, 2024. It is now read-only.

Commit 25e19f5

add model_class argument (#29)
1 parent 223481c commit 25e19f5

10 files changed: 30 additions, 13 deletions

Makefile

Lines changed: 7 additions & 1 deletion
@@ -15,6 +15,7 @@ gen-proto:
 bloom-176b:
 	TOKENIZERS_PARALLELISM=false \
 	MODEL_NAME=microsoft/bloom-deepspeed-inference-fp16 \
+	MODEL_CLASS=AutoModelForCausalLM \
 	DEPLOYMENT_FRAMEWORK=ds_inference \
 	DTYPE=fp16 \
 	MAX_INPUT_LENGTH=2048 \
@@ -25,16 +26,18 @@ bloom-176b:
 bloomz-176b:
 	TOKENIZERS_PARALLELISM=false \
 	MODEL_NAME=bigscience/bloomz \
+	MODEL_CLASS=AutoModelForCausalLM \
 	DEPLOYMENT_FRAMEWORK=ds_inference \
 	DTYPE=fp16 \
 	MAX_INPUT_LENGTH=2048 \
 	MAX_BATCH_SIZE=4 \
 	CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
 	gunicorn -t 0 -w 1 -b 127.0.0.1:5000 inference_server.server:app --access-logfile - --access-logformat '%(h)s %(t)s "%(r)s" %(s)s %(b)s'
 
-bloomz-560m:
+bloom-560m:
 	TOKENIZERS_PARALLELISM=false \
 	MODEL_NAME=bigscience/bloom-560m \
+	MODEL_CLASS=AutoModelForCausalLM \
 	DEPLOYMENT_FRAMEWORK=hf_accelerate \
 	DTYPE=bf16 \
 	MAX_INPUT_LENGTH=2048 \
@@ -45,6 +48,7 @@ bloomz-560m:
 flan-t5-xxl:
 	TOKENIZERS_PARALLELISM=false \
 	MODEL_NAME=google/flan-t5-xxl \
+	MODEL_CLASS=AutoModelForSeq2SeqLM
 	DEPLOYMENT_FRAMEWORK=hf_accelerate \
 	DTYPE=fp \
 	MAX_INPUT_LENGTH=2048 \
@@ -55,6 +59,7 @@ flan-t5-xxl:
 ul2:
 	TOKENIZERS_PARALLELISM=false \
 	MODEL_NAME=google/ul2 \
+	MODEL_CLASS=AutoModelForSeq2SeqLM \
 	DEPLOYMENT_FRAMEWORK=hf_accelerate \
 	DTYPE=fp16 \
 	MAX_INPUT_LENGTH=2048 \
@@ -65,6 +70,7 @@ ul2:
 codegen-mono:
 	TOKENIZERS_PARALLELISM=false \
 	MODEL_NAME=Salesforce/codegen-16B-mono \
+	MODEL_CLASS=AutoModelForSeq2SeqLM \
 	DEPLOYMENT_FRAMEWORK=hf_accelerate \
 	DTYPE=fp16 \
 	MAX_INPUT_LENGTH=2048 \
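
Every deployment target now exports MODEL_CLASS alongside MODEL_NAME. The server reads both from the environment and resolves the class name against the transformers namespace (see the server.py and model.py hunks below). A minimal sketch of that lookup, assuming only Python's standard library and an installed transformers package; the example values are illustrative:

    import os
    import transformers

    # values that a target such as bloom-560m would export (illustrative)
    os.environ.setdefault("MODEL_NAME", "bigscience/bloom-560m")
    os.environ.setdefault("MODEL_CLASS", "AutoModelForCausalLM")

    model_name = os.getenv("MODEL_NAME")
    model_class = os.getenv("MODEL_CLASS")

    # the class name is looked up on the transformers module, so any exported
    # Auto* class (AutoModelForCausalLM, AutoModelForSeq2SeqLM, ...) can be requested
    hf_class = getattr(transformers, model_class)
    print(model_name, hf_class.__name__)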

inference_server/README.md

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@ CFLAGS="-I$CONDA_PREFIX/include/" LDFLAGS="-L$CONDA_PREFIX/lib/" TORCH_CUDA_ARCH
 
 All the provided scripts are tested on 8 A100 80GB GPUs for BLOOM 176B (fp16/bf16) and 4 A100 80GB GPUs for BLOOM 176B (int8). These scripts might not work for other models or a different number of GPUs.
 
-DS inference is deployed using the DeepSpeed MII library which requires the resharded checkpoints for 8 x Tensor Parallel.
+DS inference is deployed using logic borrowed from DeepSpeed MII library.
 
 Note: Sometimes GPU memory is not freed when DS inference deployment crashes. You can free this memory by running `killall python` in terminal.
 
inference_server/model_handler/deployment.py

Lines changed: 3 additions & 0 deletions
@@ -1,3 +1,6 @@
+"""
+Copyright 2022 The Microsoft DeepSpeed Team
+"""
 import argparse
 import asyncio
 import os

inference_server/model_handler/launch.py

Lines changed: 3 additions & 0 deletions
@@ -1,3 +1,6 @@
+"""
+Copyright 2022 The Microsoft DeepSpeed Team
+"""
 import argparse
 
 import torch.distributed as dist

inference_server/models/ds_inference.py

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ def __init__(self, args: Namespace) -> None:
         # the actual weights while calling deepspeed.init_inference in the
         # following code
         with deepspeed.OnDevice(dtype=torch.float16, device="meta"):
-            self.model = get_hf_model_class(args.model_name).from_config(
+            self.model = get_hf_model_class(args.model_class).from_config(
                 AutoConfig.from_pretrained(downloaded_model_path), torch_dtype=torch.bfloat16
             )
         self.model = self.model.eval()
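
The DS-inference path still instantiates the architecture on the meta device from its config, so no real weights are allocated until deepspeed.init_inference loads the checkpoint; the only change is that the auto class comes from args.model_class instead of being guessed from the model name. A sketch of the pattern, assuming deepspeed and transformers are installed; the checkpoint path and class name below are placeholders:

    import deepspeed
    import torch
    import transformers
    from transformers import AutoConfig

    downloaded_model_path = "/path/to/checkpoint"   # placeholder
    model_class = "AutoModelForCausalLM"            # value passed via MODEL_CLASS / --model_class

    # build the model skeleton on the meta device: parameters have shapes but no storage,
    # and are materialized later when deepspeed.init_inference loads the real weights
    with deepspeed.OnDevice(dtype=torch.float16, device="meta"):
        model = getattr(transformers, model_class).from_config(
            AutoConfig.from_pretrained(downloaded_model_path), torch_dtype=torch.bfloat16
        )
    model = model.eval()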

inference_server/models/ds_zero.py

Lines changed: 3 additions & 1 deletion
@@ -57,7 +57,9 @@ def __init__(self, args: Namespace) -> None:
         self.tokenizer = AutoTokenizer.from_pretrained(downloaded_model_path)
         self.pad = self.tokenizer.pad_token_id
 
-        self.model = get_hf_model_class(args.model_name).from_pretrained(downloaded_model_path, torch_dtype=args.dtype)
+        self.model = get_hf_model_class(args.model_class).from_pretrained(
+            downloaded_model_path, torch_dtype=args.dtype
+        )
         self.model = self.model.eval()
 
         # convert model to a fully sharded model using ZeRO

inference_server/models/hf_accelerate.py

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ def __init__(self, args: Namespace) -> None:
 
         # this is the CUDA device for the current process. This will be used
         # later to identify the GPU on which to transfer tensors
-        self.model = get_hf_model_class(args.model_name).from_pretrained(**kwargs)
+        self.model = get_hf_model_class(args.model_class).from_pretrained(**kwargs)
 
         self.model.requires_grad_(False)
         self.model.eval()

inference_server/models/model.py

Lines changed: 3 additions & 7 deletions
@@ -5,6 +5,7 @@
 
 import torch
 
+import transformers
 from huggingface_hub import snapshot_download
 from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM
 from transformers.utils import is_offline_mode
@@ -112,10 +113,5 @@ def check_batch_size(batch_size: int, max_batch_size: int) -> None:
 
 
 # this is a hack for now
-def get_hf_model_class(model_name: str) -> Union[AutoModelForCausalLM, AutoModelForSeq2SeqLM]:
-    if "bloom" in model_name:
-        return AutoModelForCausalLM
-    elif "t5" in model_name:
-        return AutoModelForSeq2SeqLM
-    elif "ul2" in model_name:
-        return AutoModelForSeq2SeqLM
+def get_hf_model_class(model_class: str) -> Union[AutoModelForCausalLM, AutoModelForSeq2SeqLM]:
+    return getattr(transformers, model_class)
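
get_hf_model_class previously inferred the auto class by matching substrings of the model name and silently returned None for anything it did not recognize; it now resolves the requested class name directly on the transformers module. A short sketch of the new behaviour, assuming transformers is installed (the try/except only illustrates standard getattr semantics):

    import transformers
    from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM

    def get_hf_model_class(model_class: str):
        # resolve e.g. "AutoModelForSeq2SeqLM" to transformers.AutoModelForSeq2SeqLM
        return getattr(transformers, model_class)

    assert get_hf_model_class("AutoModelForCausalLM") is AutoModelForCausalLM
    assert get_hf_model_class("AutoModelForSeq2SeqLM") is AutoModelForSeq2SeqLM

    # an unknown class name now fails loudly instead of returning None
    try:
        get_hf_model_class("NotARealAutoClass")
    except AttributeError as exc:
        print(exc)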

inference_server/server.py

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@ class Args:
     def __init__(self) -> None:
         self.deployment_framework = os.getenv("DEPLOYMENT_FRAMEWORK", HF_ACCELERATE)
         self.model_name = os.getenv("MODEL_NAME")
+        self.model_class = os.getenv("MODEL_CLASS")
         self.dtype = get_torch_dtype(os.getenv("DTYPE"))
         self.allowed_max_new_tokens = int(os.getenv("ALLOWED_MAX_NEW_TOKENS", 100))
         self.max_input_length = int(os.getenv("MAX_INPUT_LENGTH", 512))

inference_server/utils/utils.py

Lines changed: 7 additions & 1 deletion
@@ -38,7 +38,13 @@ def get_argument_parser() -> argparse.ArgumentParser:
         "--model_name",
         type=str,
         required=True,
-        help="model to use",
+        help="model name to use",
+    )
+    group.add_argument(
+        "--model_class",
+        type=str,
+        required=True,
+        help="model class to use",
     )
     group.add_argument("--dtype", type=str, required=True, choices=["bf16", "fp16", "int8"], help="dtype for model")
     group.add_argument(
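
On the CLI side, --model_class becomes a second required flag next to --model_name. A self-contained sketch of how the two flags parse; the argument-group title and the example values are illustrative and not taken from the repository:

    import argparse

    parser = argparse.ArgumentParser()
    group = parser.add_argument_group(title="model")  # illustrative group title
    group.add_argument("--model_name", type=str, required=True, help="model name to use")
    group.add_argument("--model_class", type=str, required=True, help="model class to use")

    args = parser.parse_args(
        ["--model_name", "bigscience/bloom-560m", "--model_class", "AutoModelForCausalLM"]
    )
    print(args.model_name, args.model_class)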
