@@ -795,6 +795,17 @@ def input_spec(self) -> TensorSpec:
795795 input_spec = self .__dict__ .get ("_input_spec" , None )
796796 return input_spec
797797
798+ def rand_action (self , tensordict : Optional [TensorDictBase ] = None ) -> TensorDict :
799+ if self .base_env .rand_action is not EnvBase .rand_action :
800+ # TODO: this will fail if the transform modifies the input.
801+ # For instance, if PendulumEnv overrides rand_action and we build a
802+ # env = PendulumEnv().append_transform(ActionDiscretizer(num_intervals=4))
803+ # env.rand_action will NOT have a discrete action!
804+ # Getting a discrete action would require coding the inverse transform of an action within
805+ # ActionDiscretizer (ie, float->int, not int->float).
806+ return self .base_env .rand_action (tensordict )
807+ return super ().rand_action (tensordict )
808+
798809 def _step (self , tensordict : TensorDictBase ) -> TensorDictBase :
799810 # No need to clone here because inv does it already
800811 # tensordict = tensordict.clone(False)
@@ -4415,10 +4426,12 @@ class UnaryTransform(Transform):
44154426 Args:
44164427 in_keys (sequence of NestedKey): the keys of inputs to the unary operation.
44174428 out_keys (sequence of NestedKey): the keys of the outputs of the unary operation.
4418- fn (Callable ): the function to use as the unary operation. If it accepts
4419- a non-tensor input, it must also accept ``None`` .
4429+ in_keys_inv (sequence of NestedKey, optional ): the keys of inputs to the unary operation during inverse call.
4430+ out_keys_inv (sequence of NestedKey, optional): the keys of the outputs of the unary operation durin inverse call .
44204431
44214432 Keyword Args:
4433+ fn (Callable): the function to use as the unary operation. If it accepts
4434+ a non-tensor input, it must also accept ``None``.
44224435 use_raw_nontensor (bool, optional): if ``False``, data is extracted from
44234436 :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack` inputs before ``fn`` is called
44244437 on them. If ``True``, the raw :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack`
@@ -4489,11 +4502,18 @@ def __init__(
44894502 self ,
44904503 in_keys : Sequence [NestedKey ],
44914504 out_keys : Sequence [NestedKey ],
4492- fn : Callable ,
4505+ in_keys_inv : Sequence [NestedKey ] | None = None ,
4506+ out_keys_inv : Sequence [NestedKey ] | None = None ,
44934507 * ,
4508+ fn : Callable ,
44944509 use_raw_nontensor : bool = False ,
44954510 ):
4496- super ().__init__ (in_keys = in_keys , out_keys = out_keys )
4511+ super ().__init__ (
4512+ in_keys = in_keys ,
4513+ out_keys = out_keys ,
4514+ in_keys_inv = in_keys_inv ,
4515+ out_keys_inv = out_keys_inv ,
4516+ )
44974517 self ._fn = fn
44984518 self ._use_raw_nontensor = use_raw_nontensor
44994519
@@ -4508,13 +4528,49 @@ def _apply_transform(self, value):
45084528 value = value .tolist ()
45094529 return self ._fn (value )
45104530
4531+ def _inv_apply_transform (self , state : torch .Tensor ) -> torch .Tensor :
4532+ if not self ._use_raw_nontensor :
4533+ if isinstance (state , NonTensorData ):
4534+ if state .dim () == 0 :
4535+ state = state .get ("data" )
4536+ else :
4537+ state = state .tolist ()
4538+ elif isinstance (state , NonTensorStack ):
4539+ state = state .tolist ()
4540+ return self ._fn (state )
4541+
45114542 def _reset (
45124543 self , tensordict : TensorDictBase , tensordict_reset : TensorDictBase
45134544 ) -> TensorDictBase :
45144545 with _set_missing_tolerance (self , True ):
45154546 tensordict_reset = self ._call (tensordict_reset )
45164547 return tensordict_reset
45174548
4549+ def transform_input_spec (self , input_spec : Composite ) -> Composite :
4550+ input_spec = input_spec .clone ()
4551+
4552+ # Make a generic input from the spec, call the transform with that
4553+ # input, and then generate the output spec from the output.
4554+ zero_input_ = input_spec .zero ()
4555+ test_input = zero_input_ ["full_action_spec" ].update (
4556+ zero_input_ ["full_state_spec" ]
4557+ )
4558+ test_output = self .inv (test_input )
4559+ test_input_spec = make_composite_from_td (
4560+ test_output , unsqueeze_null_shapes = False
4561+ )
4562+
4563+ input_spec ["full_action_spec" ] = self .transform_action_spec (
4564+ input_spec ["full_action_spec" ],
4565+ test_input_spec ,
4566+ )
4567+ if "full_state_spec" in input_spec .keys ():
4568+ input_spec ["full_state_spec" ] = self .transform_state_spec (
4569+ input_spec ["full_state_spec" ],
4570+ test_input_spec ,
4571+ )
4572+ return input_spec
4573+
45184574 def transform_output_spec (self , output_spec : Composite ) -> Composite :
45194575 output_spec = output_spec .clone ()
45204576
@@ -4575,19 +4631,31 @@ def transform_done_spec(
45754631 ) -> TensorSpec :
45764632 return self ._transform_spec (done_spec , test_output_spec )
45774633
4634+ def transform_action_spec (
4635+ self , action_spec : TensorSpec , test_input_spec : TensorSpec
4636+ ) -> TensorSpec :
4637+ return self ._transform_spec (action_spec , test_input_spec )
4638+
4639+ def transform_state_spec (
4640+ self , state_spec : TensorSpec , test_input_spec : TensorSpec
4641+ ) -> TensorSpec :
4642+ return self ._transform_spec (state_spec , test_input_spec )
4643+
45784644
45794645class Hash (UnaryTransform ):
45804646 r"""Adds a hash value to a tensordict.
45814647
45824648 Args:
45834649 in_keys (sequence of NestedKey): the keys of the values to hash.
45844650 out_keys (sequence of NestedKey): the keys of the resulting hashes.
4651+ in_keys_inv (sequence of NestedKey, optional): the keys of the values to hash during inv call.
4652+ out_keys_inv (sequence of NestedKey, optional): the keys of the resulting hashes during inv call.
4653+
4654+ Keyword Args:
45854655 hash_fn (Callable, optional): the hash function to use. If ``seed`` is given,
45864656 the hash function must accept it as its second argument. Default is
45874657 ``Hash.reproducible_hash``.
45884658 seed (optional): seed to use for the hash function, if it requires one.
4589-
4590- Keyword Args:
45914659 use_raw_nontensor (bool, optional): if ``False``, data is extracted from
45924660 :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack` inputs before ``fn`` is called
45934661 on them. If ``True``, the raw :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack`
@@ -4673,9 +4741,11 @@ def __init__(
46734741 self ,
46744742 in_keys : Sequence [NestedKey ],
46754743 out_keys : Sequence [NestedKey ],
4744+ in_keys_inv : Sequence [NestedKey ] | None = None ,
4745+ out_keys_inv : Sequence [NestedKey ] | None = None ,
4746+ * ,
46764747 hash_fn : Callable = None ,
46774748 seed : Any | None = None ,
4678- * ,
46794749 use_raw_nontensor : bool = False ,
46804750 ):
46814751 if hash_fn is None :
@@ -4686,6 +4756,8 @@ def __init__(
46864756 super ().__init__ (
46874757 in_keys = in_keys ,
46884758 out_keys = out_keys ,
4759+ in_keys_inv = in_keys_inv ,
4760+ out_keys_inv = out_keys_inv ,
46894761 fn = self .call_hash_fn ,
46904762 use_raw_nontensor = use_raw_nontensor ,
46914763 )
@@ -4714,7 +4786,7 @@ def reproducible_hash(cls, string, seed=None):
47144786 if seed is not None :
47154787 seeded_string = seed + string
47164788 else :
4717- seeded_string = string
4789+ seeded_string = str ( string )
47184790
47194791 # Create a new SHA-256 hash object
47204792 hash_object = hashlib .sha256 ()
@@ -4728,6 +4800,77 @@ def reproducible_hash(cls, string, seed=None):
47284800 return torch .frombuffer (hash_bytes , dtype = torch .uint8 )
47294801
47304802
4803+ class Tokenizer (UnaryTransform ):
4804+ r"""Applies a tokenization operation on the specified inputs.
4805+
4806+ Args:
4807+ in_keys (sequence of NestedKey): the keys of inputs to the tokenization operation.
4808+ out_keys (sequence of NestedKey): the keys of the outputs of the tokenization operation.
4809+ in_keys_inv (sequence of NestedKey, optional): the keys of inputs to the tokenization operation during inverse call.
4810+ out_keys_inv (sequence of NestedKey, optional): the keys of the outputs of the tokenization operation during inverse call.
4811+
4812+ Keyword Args:
4813+ tokenizer (transformers.PretrainedTokenizerBase or str, optional): the tokenizer to use. If ``None``,
4814+ "bert-base-uncased" will be used by default. If a string is provided, it should be the name of a
4815+ pre-trained tokenizer.
4816+ use_raw_nontensor (bool, optional): if ``False``, data is extracted from
4817+ :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack` inputs before the tokenization
4818+ function is called on them. If ``True``, the raw :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack`
4819+ inputs are given directly to the tokenization function, which must support those inputs. Default is ``False``.
4820+ additional_tokens (List[str], optional): list of additional tokens to add to the tokenizer's vocabulary.
4821+ """
4822+
4823+ def __init__ (
4824+ self ,
4825+ in_keys : Sequence [NestedKey ],
4826+ out_keys : Sequence [NestedKey ],
4827+ in_keys_inv : Sequence [NestedKey ] | None = None ,
4828+ out_keys_inv : Sequence [NestedKey ] | None = None ,
4829+ * ,
4830+ tokenizer : "transformers.PretrainedTokenizerBase" = None , # noqa: F821
4831+ use_raw_nontensor : bool = False ,
4832+ additional_tokens : List [str ] | None = None ,
4833+ ):
4834+ if tokenizer is None :
4835+ from transformers import AutoTokenizer
4836+
4837+ tokenizer = AutoTokenizer .from_pretrained ("bert-base-uncased" )
4838+ elif isinstance (tokenizer , str ):
4839+ from transformers import AutoTokenizer
4840+
4841+ tokenizer = AutoTokenizer .from_pretrained (tokenizer )
4842+
4843+ self .tokenizer = tokenizer
4844+ if additional_tokens :
4845+ self .tokenizer .add_tokens (additional_tokens )
4846+ super ().__init__ (
4847+ in_keys = in_keys ,
4848+ out_keys = out_keys ,
4849+ in_keys_inv = in_keys_inv ,
4850+ out_keys_inv = out_keys_inv ,
4851+ fn = self .call_tokenizer_fn ,
4852+ use_raw_nontensor = use_raw_nontensor ,
4853+ )
4854+
4855+ @property
4856+ def device (self ):
4857+ if "_device" in self .__dict__ :
4858+ return self ._device
4859+ parent = self .parent
4860+ if parent is None :
4861+ return None
4862+ device = parent .device
4863+ self ._device = device
4864+ return device
4865+
4866+ def call_tokenizer_fn (self , value : str | List [str ]):
4867+ device = self .device
4868+ out = self .tokenizer .encode (value , return_tensors = "pt" )
4869+ if device is not None and out .device != device :
4870+ out = out .to (device )
4871+ return out
4872+
4873+
47314874class Stack (Transform ):
47324875 """Stacks tensors and tensordicts.
47334876
0 commit comments