@@ -967,6 +967,133 @@ def _make_node_map(self, source, dest):
         self.max_size = self.data_map.max_size
 
     def extend(self, rollout, *, return_node: bool = False):
+        """Add a rollout to the forest.
+
+        Nodes are only added to a tree at points where rollouts diverge from
+        each other and at the endpoints of rollouts.
+
+        If there is no existing tree that matches the first steps of the
+        rollout, a new tree is added. Only one node is created, for the final
+        step.
+
+        If there is an existing tree that matches, the rollout is added to that
+        tree. If the rollout diverges from all other rollouts in the tree at
+        some step, a new node is created before the step where the rollouts
+        diverge, and a leaf node is created for the final step of the rollout.
+        If all of the rollout's steps match with a previously added rollout,
+        nothing changes. If the rollout matches up to a leaf node of a tree but
+        continues beyond it, that node is extended to the end of the rollout,
+        and no new nodes are created.
+
+        Args:
+            rollout (TensorDict): The rollout to add to the forest.
+            return_node (bool, optional): If ``True``, the method returns the
+                added node. Default is ``False``.
+
+        Returns:
+            Tree: The node that was added to the forest. This is only
+                returned if ``return_node`` is ``True``.
+
+        Examples:
+            >>> from torchrl.data import MCTSForest
+            >>> from tensordict import TensorDict
+            >>> import torch
+            >>> forest = MCTSForest()
+            >>> r0 = TensorDict({
+            ...     'action': torch.tensor([1, 2, 3, 4, 5]),
+            ...     'next': {'observation': torch.tensor([123, 392, 989, 809, 847])},
+            ...     'observation': torch.tensor([ 0, 123, 392, 989, 809])
+            ... }, [5])
+            >>> r1 = TensorDict({
+            ...     'action': torch.tensor([1, 2, 6, 7]),
+            ...     'next': {'observation': torch.tensor([123, 392, 235, 38])},
+            ...     'observation': torch.tensor([ 0, 123, 392, 235])
+            ... }, [4])
+            >>> td_root = r0[0].exclude("next")
+            >>> forest.extend(r0)
+            >>> forest.extend(r1)
+            >>> tree = forest.get_tree(td_root)
+            >>> print(tree)
+            Tree(
+                count=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int32, is_shared=False),
+                index=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False),
+                node_data=TensorDict(
+                    fields={
+                        observation: Tensor(shape=torch.Size([]), device=cpu, dtype=torch.int64, is_shared=False)},
+                    batch_size=torch.Size([]),
+                    device=cpu,
+                    is_shared=False),
+                node_id=NonTensorData(data=0, batch_size=torch.Size([]), device=None),
+                rollout=TensorDict(
+                    fields={
+                        action: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False),
+                        next: TensorDict(
+                            fields={
+                                observation: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)},
+                            batch_size=torch.Size([2]),
+                            device=cpu,
+                            is_shared=False),
+                        observation: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)},
+                    batch_size=torch.Size([2]),
+                    device=cpu,
+                    is_shared=False),
+                subtree=Tree(
+                    _parent=NonTensorStack(
+                        [<weakref at 0x716eeb78fbf0; to 'TensorDict' at 0x...,
+                        batch_size=torch.Size([2]),
+                        device=None),
+                    count=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int32, is_shared=False),
+                    hash=NonTensorStack(
+                        [4341220243998689835, 6745467818783115365],
+                        batch_size=torch.Size([2]),
+                        device=None),
+                    node_data=LazyStackedTensorDict(
+                        fields={
+                            observation: Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.int64, is_shared=False)},
+                        exclusive_fields={
+                        },
+                        batch_size=torch.Size([2]),
+                        device=cpu,
+                        is_shared=False,
+                        stack_dim=0),
+                    node_id=NonTensorStack(
+                        [1, 2],
+                        batch_size=torch.Size([2]),
+                        device=None),
+                    rollout=LazyStackedTensorDict(
+                        fields={
+                            action: Tensor(shape=torch.Size([2, -1]), device=cpu, dtype=torch.int64, is_shared=False),
+                            next: LazyStackedTensorDict(
+                                fields={
+                                    observation: Tensor(shape=torch.Size([2, -1]), device=cpu, dtype=torch.int64, is_shared=False)},
+                                exclusive_fields={
+                                },
+                                batch_size=torch.Size([2, -1]),
+                                device=cpu,
+                                is_shared=False,
+                                stack_dim=0),
+                            observation: Tensor(shape=torch.Size([2, -1]), device=cpu, dtype=torch.int64, is_shared=False)},
+                        exclusive_fields={
+                        },
+                        batch_size=torch.Size([2, -1]),
+                        device=cpu,
+                        is_shared=False,
+                        stack_dim=0),
+                    wins=Tensor(shape=torch.Size([2]), device=cpu, dtype=torch.float32, is_shared=False),
+                    index=None,
+                    subtree=None,
+                    specs=None,
+                    batch_size=torch.Size([2]),
+                    device=None,
+                    is_shared=False),
+                wins=Tensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
+                hash=None,
+                _parent=None,
+                specs=None,
+                batch_size=torch.Size([]),
+                device=None,
+                is_shared=False)
+        """
         source, dest = (
             rollout.exclude("next").copy(),
             rollout.select("next", *self.action_keys).copy(),
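The doctest in the new docstring covers the divergence case (r0 and r1 split after two shared steps). As a supplementary sketch, not part of the patch, the snippet below walks through the other case the docstring describes: a rollout that repeats every step of an earlier one and then continues past its leaf. The rollout values and the two-step continuation are made-up assumptions for illustration; only MCTSForest, extend (with the return_node flag added in this diff), and get_tree come from the code above, and the reading that the extended leaf is the node returned follows from the Returns section rather than being verified here.

    import torch
    from tensordict import TensorDict
    from torchrl.data import MCTSForest

    forest = MCTSForest()

    # First rollout: four steps, built the same way as in the doctest.
    r1 = TensorDict({
        'action': torch.tensor([1, 2, 6, 7]),
        'next': {'observation': torch.tensor([123, 392, 235, 38])},
        'observation': torch.tensor([0, 123, 392, 235]),
    }, [4])

    # r2 repeats all four of r1's steps, then continues for two more
    # (the continuation values 8, 9 / 40, 41 are arbitrary).
    r2 = TensorDict({
        'action': torch.tensor([1, 2, 6, 7, 8, 9]),
        'next': {'observation': torch.tensor([123, 392, 235, 38, 40, 41])},
        'observation': torch.tensor([0, 123, 392, 235, 38, 40]),
    }, [6])

    forest.extend(r1)
    # r2 matches r1 up to r1's leaf and continues beyond it, so per the
    # docstring the existing leaf node is extended in place and no new
    # node is created. With return_node=True, extend returns that node.
    node = forest.extend(r2, return_node=True)

    tree = forest.get_tree(r1[0].exclude("next"))
    print(tree)  # a single branch whose rollout now spans all six steps
    print(node)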