@@ -4426,10 +4426,12 @@ class UnaryTransform(Transform):
44264426 Args:
44274427 in_keys (sequence of NestedKey): the keys of inputs to the unary operation.
44284428 out_keys (sequence of NestedKey): the keys of the outputs of the unary operation.
4429- fn (Callable ): the function to use as the unary operation. If it accepts
4430- a non-tensor input, it must also accept ``None`` .
4429+ in_keys_inv (sequence of NestedKey ): the keys of inputs to the unary operation during inverse call.
4430+ out_keys_inv (sequence of NestedKey): the keys of the outputs of the unary operation durin inverse call .
44314431
44324432 Keyword Args:
4433+ fn (Callable): the function to use as the unary operation. If it accepts
4434+ a non-tensor input, it must also accept ``None``.
44334435 use_raw_nontensor (bool, optional): if ``False``, data is extracted from
44344436 :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack` inputs before ``fn`` is called
44354437 on them. If ``True``, the raw :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack`
@@ -4500,11 +4502,18 @@ def __init__(
45004502 self ,
45014503 in_keys : Sequence [NestedKey ],
45024504 out_keys : Sequence [NestedKey ],
4503- fn : Callable ,
4505+ in_keys_inv : Sequence [NestedKey ] | None = None ,
4506+ out_keys_inv : Sequence [NestedKey ] | None = None ,
45044507 * ,
4508+ fn : Callable ,
45054509 use_raw_nontensor : bool = False ,
45064510 ):
4507- super ().__init__ (in_keys = in_keys , out_keys = out_keys )
4511+ super ().__init__ (
4512+ in_keys = in_keys ,
4513+ out_keys = out_keys ,
4514+ in_keys_inv = in_keys_inv ,
4515+ out_keys_inv = out_keys_inv ,
4516+ )
45084517 self ._fn = fn
45094518 self ._use_raw_nontensor = use_raw_nontensor
45104519
@@ -4519,13 +4528,50 @@ def _apply_transform(self, value):
45194528 value = value .tolist ()
45204529 return self ._fn (value )
45214530
4531+ def _inv_apply_transform (self , state : torch .Tensor ) -> torch .Tensor :
4532+ if not self ._use_raw_nontensor :
4533+ if isinstance (state , NonTensorData ):
4534+ if state .dim () == 0 :
4535+ state = state .get ("data" )
4536+ else :
4537+ state = state .tolist ()
4538+ elif isinstance (state , NonTensorStack ):
4539+ state = state .tolist ()
4540+ return self ._fn (state )
4541+
45224542 def _reset (
45234543 self , tensordict : TensorDictBase , tensordict_reset : TensorDictBase
45244544 ) -> TensorDictBase :
45254545 with _set_missing_tolerance (self , True ):
45264546 tensordict_reset = self ._call (tensordict_reset )
45274547 return tensordict_reset
45284548
4549+ def transform_input_spec (self , input_spec : Composite ) -> Composite :
4550+ input_spec = input_spec .clone ()
4551+
4552+ # Make a generic input from the spec, call the transform with that
4553+ # input, and then generate the output spec from the output.
4554+ zero_input_ = input_spec .zero ()
4555+ test_input = zero_input_ ["full_action_spec" ].update (
4556+ zero_input_ ["full_state_spec" ]
4557+ )
4558+ test_output = self .inv (test_input )
4559+ test_input_spec = make_composite_from_td (
4560+ test_output , unsqueeze_null_shapes = False
4561+ )
4562+
4563+ input_spec ["full_action_spec" ] = self .transform_action_spec (
4564+ input_spec ["full_action_spec" ],
4565+ test_input_spec ,
4566+ )
4567+ if "full_state_spec" in input_spec .keys ():
4568+ input_spec ["full_state_spec" ] = self .transform_state_spec (
4569+ input_spec ["full_state_spec" ],
4570+ test_input_spec ,
4571+ )
4572+ print (input_spec )
4573+ return input_spec
4574+
45294575 def transform_output_spec (self , output_spec : Composite ) -> Composite :
45304576 output_spec = output_spec .clone ()
45314577
@@ -4586,19 +4632,31 @@ def transform_done_spec(
45864632 ) -> TensorSpec :
45874633 return self ._transform_spec (done_spec , test_output_spec )
45884634
4635+ def transform_action_spec (
4636+ self , action_spec : TensorSpec , test_input_spec : TensorSpec
4637+ ) -> TensorSpec :
4638+ return self ._transform_spec (action_spec , test_input_spec )
4639+
4640+ def transform_state_spec (
4641+ self , state_spec : TensorSpec , test_input_spec : TensorSpec
4642+ ) -> TensorSpec :
4643+ return self ._transform_spec (state_spec , test_input_spec )
4644+
45894645
45904646class Hash (UnaryTransform ):
45914647 r"""Adds a hash value to a tensordict.
45924648
45934649 Args:
45944650 in_keys (sequence of NestedKey): the keys of the values to hash.
45954651 out_keys (sequence of NestedKey): the keys of the resulting hashes.
4652+ in_keys_inv (sequence of NestedKey): the keys of the values to hash during inv call.
4653+ out_keys_inv (sequence of NestedKey): the keys of the resulting hashes during inv call.
4654+
4655+ Keyword Args:
45964656 hash_fn (Callable, optional): the hash function to use. If ``seed`` is given,
45974657 the hash function must accept it as its second argument. Default is
45984658 ``Hash.reproducible_hash``.
45994659 seed (optional): seed to use for the hash function, if it requires one.
4600-
4601- Keyword Args:
46024660 use_raw_nontensor (bool, optional): if ``False``, data is extracted from
46034661 :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack` inputs before ``fn`` is called
46044662 on them. If ``True``, the raw :class:`~tensordict.NonTensorData`/:class:`~tensordict.NonTensorStack`
@@ -4684,9 +4742,11 @@ def __init__(
46844742 self ,
46854743 in_keys : Sequence [NestedKey ],
46864744 out_keys : Sequence [NestedKey ],
4745+ in_keys_inv : Sequence [NestedKey ] | None = None ,
4746+ out_keys_inv : Sequence [NestedKey ] | None = None ,
4747+ * ,
46874748 hash_fn : Callable = None ,
46884749 seed : Any | None = None ,
4689- * ,
46904750 use_raw_nontensor : bool = False ,
46914751 ):
46924752 if hash_fn is None :
@@ -4697,6 +4757,8 @@ def __init__(
46974757 super ().__init__ (
46984758 in_keys = in_keys ,
46994759 out_keys = out_keys ,
4760+ in_keys_inv = in_keys_inv ,
4761+ out_keys_inv = out_keys_inv ,
47004762 fn = self .call_hash_fn ,
47014763 use_raw_nontensor = use_raw_nontensor ,
47024764 )
@@ -4725,7 +4787,7 @@ def reproducible_hash(cls, string, seed=None):
47254787 if seed is not None :
47264788 seeded_string = seed + string
47274789 else :
4728- seeded_string = string
4790+ seeded_string = str ( string )
47294791
47304792 # Create a new SHA-256 hash object
47314793 hash_object = hashlib .sha256 ()
0 commit comments