@@ -57,15 +57,15 @@ More concretely, key-value cache acts as a memory bank for these generative mode
5757 >> > from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache
5858
5959 >> > model_id = " TinyLlama/TinyLlama-1.1B-Chat-v1.0"
60- >> > model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype = torch.bfloat16, device_map = " cuda:0 " )
60+ >> > model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype = torch.bfloat16, device_map = " auto " )
6161 >> > tokenizer = AutoTokenizer.from_pretrained(model_id)
6262
6363 >> > past_key_values = DynamicCache()
6464 >> > messages = [{" role" : " user" , " content" : " Hello, what's your name." }]
65- >> > inputs = tokenizer.apply_chat_template(messages, add_generation_prompt = True , return_tensors = " pt" , return_dict = True ).to(" cuda:0 " )
65+ >> > inputs = tokenizer.apply_chat_template(messages, add_generation_prompt = True , return_tensors = " pt" , return_dict = True ).to(model.device )
6666
6767 >> > generated_ids = inputs.input_ids
68- >> > cache_position = torch.arange(inputs.input_ids.shape[1 ], dtype = torch.int64, device = " cuda:0 " )
68+ >> > cache_position = torch.arange(inputs.input_ids.shape[1 ], dtype = torch.int64, device = model.device )
6969 >> > max_new_tokens = 10
7070
7171 >> > for _ in range (max_new_tokens):
@@ -139,7 +139,7 @@ Cache quantization can be detrimental in terms of latency if the context length
139139>> > from transformers import AutoTokenizer, AutoModelForCausalLM
140140
141141>> > tokenizer = AutoTokenizer.from_pretrained(" TinyLlama/TinyLlama-1.1B-Chat-v1.0" )
142- >> > model = AutoModelForCausalLM.from_pretrained(" TinyLlama/TinyLlama-1.1B-Chat-v1.0" , torch_dtype = torch.float16).to( " cuda:0 " )
142+ >> > model = AutoModelForCausalLM.from_pretrained(" TinyLlama/TinyLlama-1.1B-Chat-v1.0" , torch_dtype = torch.float16, device_map = " auto " )
143143>> > inputs = tokenizer(" I like rock music because" , return_tensors = " pt" ).to(model.device)
144144
145145>> > out = model.generate(** inputs, do_sample = False , max_new_tokens = 20 , cache_implementation = " quantized" , cache_config = {" nbits" : 4 , " backend" : " quanto" })
@@ -168,7 +168,7 @@ Use `cache_implementation="offloaded_static"` for an offloaded static cache (see
168168>> > ckpt = " microsoft/Phi-3-mini-4k-instruct"
169169
170170>> > tokenizer = AutoTokenizer.from_pretrained(ckpt)
171- >> > model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype = torch.float16).to( " cuda:0 " )
171+ >> > model = AutoModelForCausalLM.from_pretrained(ckpt, torch_dtype = torch.float16, device_map = " auto " )
172172>> > inputs = tokenizer(" Fun fact: The shortest" , return_tensors = " pt" ).to(model.device)
173173
174174>> > out = model.generate(** inputs, do_sample = False , max_new_tokens = 23 , cache_implementation = " offloaded" )
@@ -278,7 +278,7 @@ Note that you can use this cache only for models that support sliding window, e.
278278>> > from transformers import AutoTokenizer, AutoModelForCausalLM, SinkCache
279279
280280>> > tokenizer = AutoTokenizer.from_pretrained(" teknium/OpenHermes-2.5-Mistral-7B" )
281- >> > model = AutoModelForCausalLM.from_pretrained(" teknium/OpenHermes-2.5-Mistral-7B" , torch_dtype = torch.float16).to( " cuda:0 " )
281+ >> > model = AutoModelForCausalLM.from_pretrained(" teknium/OpenHermes-2.5-Mistral-7B" , torch_dtype = torch.float16, device_map = " auto " )
282282>> > inputs = tokenizer(" Yesterday I was on a rock concert and." , return_tensors = " pt" ).to(model.device)
283283
284284>> > # can be used by passing in cache implementation
@@ -298,7 +298,7 @@ Unlike other cache classes, this one can't be used directly by indicating a `cac
298298>> > from transformers import AutoTokenizer, AutoModelForCausalLM, SinkCache
299299
300300>> > tokenizer = AutoTokenizer.from_pretrained(" TinyLlama/TinyLlama-1.1B-Chat-v1.0" )
301- >> > model = AutoModelForCausalLM.from_pretrained(" TinyLlama/TinyLlama-1.1B-Chat-v1.0" , torch_dtype = torch.float16).to( " cuda:0 " )
301+ >> > model = AutoModelForCausalLM.from_pretrained(" TinyLlama/TinyLlama-1.1B-Chat-v1.0" , torch_dtype = torch.float16, device_map = " auto " )
302302>> > inputs = tokenizer(" This is a long story about unicorns, fairies and magic." , return_tensors = " pt" ).to(model.device)
303303
304304>> > # get our cache, specify number of sink tokens and window size
@@ -377,25 +377,27 @@ Sometimes you would want to first fill-in cache object with key/values for certa
377377>> > import copy
378378>> > import torch
379379>> > from transformers import AutoModelForCausalLM, AutoTokenizer, DynamicCache, StaticCache
380+ >> > from accelerate.test_utils.testing import get_backend
380381
382+ >> > DEVICE , _, _ = get_backend() # automatically detects the underlying device type (CUDA, CPU, XPU, MPS, etc.)
381383>> > model_id = " TinyLlama/TinyLlama-1.1B-Chat-v1.0"
382- >> > model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype = torch.bfloat16, device_map = " cuda " )
384+ >> > model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype = torch.bfloat16, device_map = DEVICE )
383385>> > tokenizer = AutoTokenizer.from_pretrained(model_id)
384386
385387>> > # Init StaticCache with big enough max-length (1024 tokens for the below example)
386388>> > # You can also init a DynamicCache, if that suits you better
387- >> > prompt_cache = StaticCache(config = model.config, max_batch_size = 1 , max_cache_len = 1024 , device = " cuda " , dtype = torch.bfloat16)
389+ >> > prompt_cache = StaticCache(config = model.config, max_batch_size = 1 , max_cache_len = 1024 , device = DEVICE , dtype = torch.bfloat16)
388390
389391>> > INITIAL_PROMPT = " You are a helpful assistant. "
390- >> > inputs_initial_prompt = tokenizer(INITIAL_PROMPT , return_tensors = " pt" ).to(" cuda " )
392+ >> > inputs_initial_prompt = tokenizer(INITIAL_PROMPT , return_tensors = " pt" ).to(DEVICE )
391393>> > # This is the common prompt cached, we need to run forward without grad to be abel to copy
392394>> > with torch.no_grad():
393395... prompt_cache = model(** inputs_initial_prompt, past_key_values = prompt_cache).past_key_values
394396
395397>> > prompts = [" Help me to write a blogpost about travelling." , " What is the capital of France?" ]
396398>> > responses = []
397399>> > for prompt in prompts:
398- ... new_inputs = tokenizer(INITIAL_PROMPT + prompt, return_tensors = " pt" ).to(" cuda " )
400+ ... new_inputs = tokenizer(INITIAL_PROMPT + prompt, return_tensors = " pt" ).to(DEVICE )
399401... past_key_values = copy.deepcopy(prompt_cache)
400402... outputs = model.generate(** new_inputs, past_key_values = past_key_values,max_new_tokens = 20 )
401403... response = tokenizer.batch_decode(outputs)[0 ]
0 commit comments