|
10 | 10 | """ |
11 | 11 |
|
12 | 12 | from abc import ABC, abstractmethod |
13 | | -from dataclasses import dataclass |
14 | 13 | from typing import Dict, List, Literal, Optional, Protocol, Sequence, Set, Tuple, Type, Union |
15 | 14 |
|
16 | 15 | import torch |
| 16 | +from pydantic import BaseModel, ConfigDict, Field, field_validator |
17 | 17 | from torch._ops import OpOverloadPacket |
18 | 18 | from torch.fx import Node |
19 | 19 | from torch.types import Number |
|
24 | 24 | Constant = Union[int, float, str, None] |
25 | 25 |
|
26 | 26 |
|
27 | | -@dataclass |
28 | | -class CacheConfig: |
29 | | - """A dataclass to hold information how to configure the cache.""" |
class CacheConfig(BaseModel):
    """Cache configuration for attention-related dtypes.

    Each field accepts a ``torch.dtype``, the name of a dtype attribute on the
    ``torch`` module given as a string (e.g. ``"float16"``), or ``None`` to leave
    the field unset. Unknown field names are rejected (``extra="forbid"``).
    """

    model_config = ConfigDict(
        # torch.dtype is not a pydantic-native type, so it must be allowed explicitly.
        arbitrary_types_allowed=True,
        extra="forbid",
    )

    dtype: Optional[torch.dtype] = Field(default=None, description="KV cache dtype.")
    mamba_dtype: Optional[torch.dtype] = Field(default=None, description="Mamba cache dtype.")

    @field_validator("dtype", "mamba_dtype", mode="before")
    @classmethod
    def _coerce_dtype(cls, value):
        """Resolve a dtype given by name into the matching ``torch.dtype``.

        Args:
            value: Raw field input; ``None``, a ``torch.dtype``, a string naming
                an attribute of ``torch``, or anything else.

        Returns:
            ``None`` or a ``torch.dtype`` for recognised inputs. Any other
            non-string value is returned unchanged so that pydantic's own type
            validation can reject it.

        Raises:
            ValueError: If ``value`` is a string that does not name a
                ``torch.dtype``. Pydantic reports this as a validation error.
        """
        if value is None or isinstance(value, torch.dtype):
            return value
        if isinstance(value, str):
            resolved = getattr(torch, value, None)
            # Raise explicitly instead of asserting: asserts are stripped under
            # ``python -O``, which would silently turn a bad name into None.
            if not isinstance(resolved, torch.dtype):
                raise ValueError(f"Invalid dtype={value!r}")
            return resolved
        return value

    def __or__(self, other: "CacheConfig") -> "CacheConfig":
        """Combine two CacheConfig objects field-wise.

        For each field, selects the value from `self` unless it is None, in which
        case the value from `other` is used.

        Args:
            other: The configuration supplying fallback values.

        Returns:
            A new CacheConfig; neither operand is modified.

        Raises:
            NotImplementedError: If `other` is not a CacheConfig.
        """
        if not isinstance(other, CacheConfig):
            raise NotImplementedError(f"Cannot combine CacheConfig with {type(other)}")
        # Test against None explicitly so that a set-but-falsy value on `self`
        # is never replaced by the value from `other`.
        merged_kwargs = {
            field_name: (
                getattr(self, field_name)
                if getattr(self, field_name) is not None
                else getattr(other, field_name)
            )
            for field_name in type(self).model_fields
        }
        return CacheConfig(**merged_kwargs)
32 | 60 |
|
33 | 61 |
|
34 | 62 | class SequenceInfo: |
|
0 commit comments