@@ -560,6 +560,21 @@ def transform_output_spec(self, output_spec: Composite) -> Composite:
560560 output_spec ["full_done_spec" ] = self .transform_done_spec (
561561 output_spec ["full_done_spec" ]
562562 )
563+ output_spec_keys = [
564+ unravel_key (k [1 :]) for k in output_spec .keys (True ) if isinstance (k , tuple )
565+ ]
566+ out_keys = {unravel_key (k ) for k in self .out_keys }
567+ in_keys = {unravel_key (k ) for k in self .in_keys }
568+ for key in out_keys - in_keys :
569+ if unravel_key (key ) not in output_spec_keys :
570+ warnings .warn (
571+ f"The key '{ key } ' is unaccounted for by the transform (expected keys { output_spec_keys } ). "
572+ f"Every new entry in the tensordict resulting from a call to a transform must be "
573+ f"registered in the specs for torchrl rollouts to be consistently built. "
574+ f"Make sure transform_output_spec/transform_observation_spec/... is coded correctly. "
575+ "This warning will trigger a KeyError in v0.9, make sure to adapt your code accordingly." ,
576+ category = FutureWarning ,
577+ )
563578 return output_spec
564579
565580 def transform_input_spec (self , input_spec : TensorSpec ) -> TensorSpec :
@@ -1468,33 +1483,57 @@ def transform_input_spec(self, input_spec: TensorSpec) -> TensorSpec:
14681483 # the action spec from the env, map it using t0 then t1 (going from in to out).
14691484 for t in self .transforms :
14701485 input_spec = t .transform_input_spec (input_spec )
1486+ if not isinstance (input_spec , Composite ):
1487+ raise TypeError (
1488+ f"Expected Compose but got { type (input_spec )} with transform { t } "
1489+ )
14711490 return input_spec
14721491
14731492 def transform_action_spec (self , action_spec : TensorSpec ) -> TensorSpec :
14741493 # To understand why we don't invert, look up at transform_input_spec
14751494 for t in self .transforms :
14761495 action_spec = t .transform_action_spec (action_spec )
1496+ if not isinstance (action_spec , TensorSpec ):
1497+ raise TypeError (
1498+ f"Expected TensorSpec but got { type (action_spec )} with transform { t } "
1499+ )
14771500 return action_spec
14781501
14791502 def transform_state_spec (self , state_spec : TensorSpec ) -> TensorSpec :
14801503 # To understand why we don't invert, look up at transform_input_spec
14811504 for t in self .transforms :
14821505 state_spec = t .transform_state_spec (state_spec )
1506+ if not isinstance (state_spec , Composite ):
1507+ raise TypeError (
1508+ f"Expected Compose but got { type (state_spec )} with transform { t } "
1509+ )
14831510 return state_spec
14841511
14851512 def transform_observation_spec (self , observation_spec : TensorSpec ) -> TensorSpec :
14861513 for t in self .transforms :
14871514 observation_spec = t .transform_observation_spec (observation_spec )
1515+ if not isinstance (observation_spec , TensorSpec ):
1516+ raise TypeError (
1517+ f"Expected TensorSpec but got { type (observation_spec )} with transform { t } "
1518+ )
14881519 return observation_spec
14891520
14901521 def transform_output_spec (self , output_spec : TensorSpec ) -> TensorSpec :
14911522 for t in self .transforms :
14921523 output_spec = t .transform_output_spec (output_spec )
1524+ if not isinstance (output_spec , Composite ):
1525+ raise TypeError (
1526+ f"Expected Compose but got { type (output_spec )} with transform { t } "
1527+ )
14931528 return output_spec
14941529
14951530 def transform_reward_spec (self , reward_spec : TensorSpec ) -> TensorSpec :
14961531 for t in self .transforms :
14971532 reward_spec = t .transform_reward_spec (reward_spec )
1533+ if not isinstance (reward_spec , TensorSpec ):
1534+ raise TypeError (
1535+ f"Expected TensorSpec but got { type (reward_spec )} with transform { t } "
1536+ )
14981537 return reward_spec
14991538
15001539 def __getitem__ (self , item : Union [int , slice , List ]) -> Union :
0 commit comments