from typing import Any, Iterable, Optional, Union

import numpy as np
import jax
import ray
from ray.util.accelerators import tpu

from jetstream.engine import engine_api, tokenizer_pb2
from jetstream_pt.ray_worker import PyTorchRayWorker

Params = Any
Prefix = Any
DecodeState = Any


class PyTorchRayEngine(engine_api.Engine):
  """Ray PyTorch Engine implementation for multi-host inference serving.

  Key features:
    1. Manages all Ray workers.
    2. Initializes model parameters on each Ray worker.
    3. Routes incoming inference requests to the Ray workers.
    4. Collects token responses from the Ray workers.

  See the usage sketch that follows the class definition for the typical
  call sequence.
  """

  def __init__(
      self,
      engine_workers: Iterable[PyTorchRayWorker],
      tokenizer_path: str,
      context_length: int,
      batch_size: int,
  ):
    self.engine_workers = engine_workers
    self.tokenizer_path = tokenizer_path
    self.context_length = context_length
    self.batch_size = batch_size

  # pylint: disable-next=all
  def load_params(self) -> Params:
    # Dispatch the load to every worker and block until all of them finish.
    all_outputs = []
    for worker in self.engine_workers:
      output = worker.load_params_ray.remote()
      all_outputs.append(output)
    _ = ray.get(all_outputs)
    # The weights live on the workers, so nothing is returned to the caller.
    return None

  # pylint: disable-next=all
  def init_decode_state(
      self,
  ) -> DecodeState:
    all_outputs = []
    for worker in self.engine_workers:
      output = worker.init_decode_state_ray.remote()
      all_outputs.append(output)
    _ = ray.get(all_outputs)
    # The decode state is held on the workers, so nothing is returned.
    return None

  def prefill(
      self,
      *,
      params: Any,  # Weights
      existing_prefix: Optional[Prefix] = None,
      padded_tokens: np.ndarray,  # PrefillInputs[np.ndarray]
      true_length: int,
  ) -> Prefix:
    all_outputs = []
    for worker in self.engine_workers:
      output = worker.prefill_ray.remote(
          params=params,
          existing_prefix=existing_prefix,
          padded_tokens=padded_tokens,
          true_length=true_length,
      )
      all_outputs.append(output)
    _ = ray.get(all_outputs)
    # The prefill function does not return any values;
    # the worker itself manages and maintains the prefill states.
    return None

  def insert(
      self,
      prefix: Prefix,
      decode_state: DecodeState,
      slot: int,
  ) -> DecodeState:
    all_outputs = []
    for worker in self.engine_workers:
      output = worker.insert_ray.remote(
          prefix=prefix, decode_state=decode_state, slot=slot
      )
      all_outputs.append(output)
    _ = ray.get(all_outputs)
    # The insert function does not return any values;
    # the worker itself manages and maintains the DecodeState.
    return None

  def generate(
      self, params: Any, decode_state: DecodeState
  ) -> tuple[None, engine_api.ResultTokens]:
    all_outputs = []
    for worker in self.engine_workers:
      output = worker.generate_ray.remote(
          params=params, decode_state=decode_state
      )
      all_outputs.append(output)
    # All workers perform an all_gather, so every worker holds an identical
    # result; return the result from worker 0.
    state, result_tokens = ray.get(all_outputs)[0]
    return state, result_tokens

  # pylint: disable-next=all
  def get_tokenizer(self) -> tokenizer_pb2.TokenizerParameters:
    # pylint: disable-next=all
    return tokenizer_pb2.TokenizerParameters(path=self.tokenizer_path)

  @property
  def max_concurrent_decodes(self) -> int:
    return self.batch_size

  @property
  def samples_per_slot(self) -> int:
    return 1

  @property
  def max_prefill_length(self) -> int:
    return self.context_length

  @property
  def colocated_cpus(self) -> Union[list[engine_api.CpuDevices], None]:
    # Return the list of host CPU devices to match the declared return type.
    return jax.devices("cpu")

  def get_prefix_destination_sharding(self) -> Prefix:
    """No implementation."""
    return None

  @property
  def mesh(self):
    """No implementation."""
    return None
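
# Typical serving flow against this engine, shown as an illustrative sketch
# only: in practice the JetStream orchestrator drives these calls, and
# `tokens`/`length` below are hypothetical placeholders.
#
#   params = engine.load_params()             # returns None; weights stay on workers
#   decode_state = engine.init_decode_state() # returns None; state stays on workers
#   prefix = engine.prefill(
#       params=params, padded_tokens=tokens, true_length=length
#   )
#   decode_state = engine.insert(prefix, decode_state, slot=0)
#   decode_state, result_tokens = engine.generate(params, decode_state)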


# pylint: disable-next=all
def create_pytorch_ray_engine(
    tokenizer_path: str,
    ckpt_path: Optional[str] = None,
    samples_per_slot: int = 1,
    bf16_enable: bool = False,
    param_size: str = "7b",
    context_length: int = 1024,
    batch_size: int = 1,
    max_decode_length: int = 4096,
    model_name: str = "llama",
    quantize_weights: bool = False,
    quantize_kv: bool = False,
    max_cache_length: int = 1024,
) -> PyTorchRayEngine:

  ray.init(ignore_reinit_error=True)
  pod_name = tpu.get_current_pod_name()
  num_hosts = tpu.get_current_pod_worker_count()
  print(f"pod_name: {pod_name}, number of hosts: {num_hosts}")
  assert (
      pod_name is not None
  ), f"TPU pod name (current value: {pod_name}) cannot be None"
  assert (
      num_hosts > 0
  ), f"num_hosts (current value: {num_hosts}) must be a positive number"
  # pylint: disable-next=all
  engine_worker_with_tpu_resource = PyTorchRayWorker.options(
      resources={"TPU": 4}  # Reserve one host's worth of TPU chips per worker.
  )
  engine_workers = []
  # Create one Ray worker per TPU host in the pod.
  for _ in range(num_hosts):
    engine_worker = engine_worker_with_tpu_resource.remote(
        tokenizer_path=tokenizer_path,
        ckpt_path=ckpt_path,
        samples_per_slot=samples_per_slot,
        bf16_enable=bf16_enable,
        param_size=param_size,
        context_length=context_length,
        batch_size=batch_size,
        max_decode_length=max_decode_length,
        model_name=model_name,
        quantize_weights=quantize_weights,
        quantize_kv=quantize_kv,
        max_cache_length=max_cache_length,
    )
    engine_workers.append(engine_worker)
  engine_master = PyTorchRayEngine(
      engine_workers=engine_workers,
      tokenizer_path=tokenizer_path,
      context_length=context_length,
      batch_size=batch_size,
  )
  return engine_master
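

# Example (illustrative sketch): creating the engine on a TPU pod. The paths
# below are hypothetical placeholders, not defaults shipped with the library.
#
#   engine = create_pytorch_ray_engine(
#       tokenizer_path="/tmp/tokenizer.model",  # hypothetical path
#       ckpt_path="/tmp/llama-7b",              # hypothetical path
#       bf16_enable=True,
#       context_length=1024,
#       batch_size=1,
#   )
#   engine.load_params()
#   engine.init_decode_state()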