3434}
3535
3636
# The "shadow" flag keeps tensordict from complaining about fields such as
# 'type' and 'size' that shadow attributes of the underlying TensorDict.
class ContentBase(TensorClass["nocast", "shadow"]):
    """Base class for all message content types.

    Exactly one content kind is expressed per instance via ``type``; the
    remaining fields are optional and populated depending on that kind.

    Attributes:
        type (str): The kind of content. One of ``"text"``, ``"image"``,
            ``"audio"``, ``"video"``, ``"file"`` or ``"function_call"``.
        text (str, optional): The text content.
        url (str, optional): HTTP URL pointing to the content.
        data (str, optional): Base64-encoded content.
        mime_type (str, optional): MIME type of the content, e.g.
            ``"image/jpeg"``, ``"audio/mp3"`` or ``"application/pdf"``.
        name (str, optional): Original filename or human-readable description.
        size (int, optional): Content size in bytes.
        function_name (str, optional): Name of the called function
            (for AI-agent function calling).
        function_args (dict, optional): Arguments of the called function.

    Examples:
        >>> from tensordict import lazy_stack
        >>> content1 = ContentBase(type="text", text="Hello, world!")
        >>> content2 = ContentBase(type="image", url="https://example.com/image.jpg")
        >>> content = lazy_stack([content1, content2])
        >>> print(content)
        ContentBase(
            type=NonTensorStack(
                ['text', 'image'],
                batch_size=torch.Size([2]),
                device=None),
            url=None,
            data=None,
            mime_type=None,
            name=None,
            size=None,
            function_name=None,
            function_args=None,
            text=None,
            batch_size=torch.Size([2]),
            device=None,
            is_shared=False)
        >>> # A content is typically used in a History object. Usually, its
        >>> # batch dimension is one greater than the History object's.
        >>> history = History(role="user", content=content)

    """

    # Required discriminator for the content kind.
    type: Literal["text", "image", "audio", "video", "file", "function_call"]

    # Text content.
    text: str | None = None

    # Media/file content — supplied either as a URL or as inline data.
    url: str | None = None  # HTTP URL to the content
    data: str | None = None  # base64-encoded content

    # Metadata describing the payload.
    mime_type: str | None = None  # e.g. "image/jpeg", "audio/mp3", "application/pdf"
    name: str | None = None  # original filename or description
    size: int | None = None  # size in bytes

    # Function calling (for AI agents).
    function_name: str | None = None
    function_args: dict | None = None
128+
37129class History (TensorClass ["nocast" ]):
38130 """A class representing a structured history of messages in a conversation, designed for efficient manipulation and integration with language models.
39131
@@ -98,7 +190,7 @@ class History(TensorClass["nocast"]):
98190 """
99191
100192 role : str
101- content : str
193+ content : str | ContentBase
102194
103195 def __post_init__ (self ):
104196 if not list_to_stack ():
@@ -110,27 +202,29 @@ def __post_init__(self):
110202 def apply_chat_template (
111203 self ,
112204 * ,
113- tokenizer : transformers .AutoTokenizer , # noqa
205+ tokenizer : transformers .AutoTokenizer | transformers . AutoProcessor , # noqa
114206 add_generation_prompt : bool = True ,
115207 chat_template : str | None = None ,
116208 continue_final_message : bool = False ,
117209 tokenize : bool = False ,
118210 padding : bool | str = False ,
119211 truncation : bool | str = False ,
120212 return_tensors : str | None = "pt" ,
213+ return_dict : bool = False ,
121214 ** kwargs ,
122215 ):
123216 """Applies a chat template to the history.
124217
125218 Keyword Args:
126- tokenizer (transformers.PreTrainedTokenizer): The tokenizer to use.
127- add_generation_prompt (bool, optional): Whether to add a generation prompt. Defaults to True.
219+ tokenizer (transformers.PreTrainedTokenizer | transformers.AutoProcessor ): The tokenizer to use.
220+ add_generation_prompt (bool, optional): Whether to add a generation prompt. Defaults to ` True` .
128221 chat_template (str, optional): The chat template to use. Defaults to the tokenizer's default template.
129- continue_final_message (bool, optional): Whether to continue the final message. Defaults to False.
130- tokenize (bool, optional): Whether to tokenize the output. Defaults to False.
131- padding (bool | str, optional): The padding strategy to use. Defaults to False.
132- truncation (bool | str, optional): The truncation strategy to use. Defaults to False.
222+ continue_final_message (bool, optional): Whether to continue the final message. Defaults to ` False` .
223+ tokenize (bool, optional): Whether to tokenize the output. Defaults to ` False` .
224+ padding (bool | str, optional): The padding strategy to use. Defaults to ` False` .
225+ truncation (bool | str, optional): The truncation strategy to use. Defaults to ` False` .
133226 return_tensors (str | None, optional): The type of tensors to return. Defaults to "pt".
227+ return_dict (bool, optional): Whether to return a dictionary. Defaults to `False`.
134228 **kwargs: Additional keyword arguments to pass to the tokenizer `apply_chat_template` method.
135229
136230 Returns:
@@ -155,20 +249,24 @@ def apply_chat_template(
155249 truncation = truncation ,
156250 return_tensors = return_tensors ,
157251 continue_final_message = continue_final_message ,
252+ return_dict = return_dict ,
158253 ** kwargs ,
159254 )
160255 for i in range (self .batch_size [0 ])
161256 ]
162- self_flat = self .view (- 1 ).tolist ()
257+ self_flat = self .view (- 1 )
258+ # tolist_first=True is needed to avoid having a list of dict of dicts, but a list of dicts of lists of dicts
259+ self_flat = self_flat .tolist (tolist_first = True )
163260 return tokenizer .apply_chat_template (
164- self_flat ,
261+ conversation = self_flat ,
165262 add_generation_prompt = add_generation_prompt ,
166263 chat_template = chat_template ,
167264 tokenize = tokenize ,
168265 padding = padding ,
169266 truncation = truncation ,
170267 return_tensors = return_tensors ,
171268 continue_final_message = continue_final_message ,
269+ return_dict = return_dict ,
172270 )
173271
174272 @classmethod
@@ -275,7 +373,7 @@ def append(
275373
276374 Args:
277375 history (History): The new history to append.
278- inplace (bool, optional): Whether to perform the operation in-place. Defaults to True.
376+ inplace (bool, optional): Whether to perform the operation in-place. Defaults to ` True` .
279377 dim (int, optional): The dimension to append along. Defaults to -1.
280378
281379 Returns:
0 commit comments