Skip to content

Commit c1e3052

Browse files
committed
add cloning length check
1 parent 8ebcad8 commit c1e3052

File tree

2 files changed

+44
-41
lines changed

2 files changed

+44
-41
lines changed

source/isaaclab/isaaclab/scene/cloner.py

Lines changed: 40 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -112,46 +112,47 @@ def clone_from_template(stage: Usd.Stage, num_clones: int, template_clone_cfg: T
112112
cfg.template_root,
113113
predicate=lambda prim: str(prim.GetPath()).split("/")[-1].startswith(prototype_id),
114114
)
115-
prototype_root_set = {"/".join(str(prototype.GetPath()).split("/")[:-1]) for prototype in prototypes}
116-
for prototype_root in prototype_root_set:
117-
protos = sim_utils.find_matching_prim_paths(f"{prototype_root}/.*")
118-
protos = [proto for proto in protos if proto.split("/")[-1].startswith(prototype_id)]
119-
m = torch.zeros((len(protos), num_clones), dtype=torch.bool, device=cfg.device)
120-
# Optionally select prototypes randomly per environment; else round-robin by modulo
121-
if cfg.random_heterogenous_cloning:
122-
rand_idx = torch.randint(len(protos), (num_clones,), device=cfg.device)
123-
m[rand_idx, world_indices] = True
115+
if len(prototypes) > 0:
116+
prototype_root_set = {"/".join(str(prototype.GetPath()).split("/")[:-1]) for prototype in prototypes}
117+
for prototype_root in prototype_root_set:
118+
protos = sim_utils.find_matching_prim_paths(f"{prototype_root}/.*")
119+
protos = [proto for proto in protos if proto.split("/")[-1].startswith(prototype_id)]
120+
m = torch.zeros((len(protos), num_clones), dtype=torch.bool, device=cfg.device)
121+
# Optionally select prototypes randomly per environment; else round-robin by modulo
122+
if cfg.random_heterogenous_cloning:
123+
rand_idx = torch.randint(len(protos), (num_clones,), device=cfg.device)
124+
m[rand_idx, world_indices] = True
125+
else:
126+
m[world_indices % len(protos), world_indices] = True
127+
128+
clone_plan["src"].extend(protos)
129+
clone_plan["dest"].extend([prototype_root.replace(cfg.template_root, clone_path_fmt)] * len(protos))
130+
clone_plan["mapping"] = torch.cat((clone_plan["mapping"].reshape(-1, m.size(1)), m), dim=0)
131+
132+
proto_idx = clone_plan["mapping"].to(torch.int32).argmax(dim=1)
133+
proto_mask = torch.zeros_like(clone_plan["mapping"])
134+
proto_mask.scatter_(1, proto_idx.view(-1, 1).to(torch.long), clone_plan["mapping"].any(dim=1, keepdim=True))
135+
usd_replicate(stage, clone_plan["src"], clone_plan["dest"], world_indices, proto_mask)
136+
stage.GetPrimAtPath(cfg.template_root).SetActive(False)
137+
138+
# If all prototypes map to env_0, clone whole env_0 to all envs; else clone per-object
139+
if torch.all(proto_idx == 0):
140+
replicate_args = [clone_path_fmt.format(0)], [clone_path_fmt], world_indices, clone_plan["mapping"]
141+
if cfg.clone_usd:
142+
# parse env_origins directly from clone_path
143+
get_translate = (
144+
lambda prim_path: stage.GetPrimAtPath(prim_path).GetAttribute("xformOp:translate").Get()
145+
) # noqa: E731
146+
positions = torch.tensor([get_translate(clone_path_fmt.format(i)) for i in world_indices])
147+
usd_replicate(stage, *replicate_args, positions=positions)
124148
else:
125-
m[world_indices % len(protos), world_indices] = True
126-
127-
clone_plan["src"].extend(protos)
128-
clone_plan["dest"].extend([prototype_root.replace(cfg.template_root, clone_path_fmt)] * len(protos))
129-
clone_plan["mapping"] = torch.cat((clone_plan["mapping"].reshape(-1, m.size(1)), m), dim=0)
130-
131-
proto_idx = clone_plan["mapping"].to(torch.int32).argmax(dim=1)
132-
proto_mask = torch.zeros_like(clone_plan["mapping"])
133-
proto_mask.scatter_(1, proto_idx.view(-1, 1).to(torch.long), clone_plan["mapping"].any(dim=1, keepdim=True))
134-
usd_replicate(stage, clone_plan["src"], clone_plan["dest"], world_indices, proto_mask)
135-
stage.GetPrimAtPath(cfg.template_root).SetActive(False)
136-
137-
# If all prototypes map to env_0, clone whole env_0 to all envs; else clone per-object
138-
if torch.all(proto_idx == 0):
139-
replicate_args = [clone_path_fmt.format(0)], [clone_path_fmt], world_indices, clone_plan["mapping"]
140-
if cfg.clone_usd:
141-
# parse env_origins directly from clone_path
142-
get_translate = (
143-
lambda prim_path: stage.GetPrimAtPath(prim_path).GetAttribute("xformOp:translate").Get()
144-
) # noqa: E731
145-
positions = torch.tensor([get_translate(clone_path_fmt.format(i)) for i in world_indices])
146-
usd_replicate(stage, *replicate_args, positions=positions)
147-
else:
148-
src = [tpl.format(int(idx)) for tpl, idx in zip(clone_plan["dest"], proto_idx.tolist())]
149-
replicate_args = src, clone_plan["dest"], world_indices, clone_plan["mapping"]
150-
if cfg.clone_usd:
151-
usd_replicate(stage, *replicate_args)
152-
153-
if cfg.clone_physx:
154-
physx_replicate(stage, *replicate_args, use_fabric=cfg.clone_in_fabric)
149+
src = [tpl.format(int(idx)) for tpl, idx in zip(clone_plan["dest"], proto_idx.tolist())]
150+
replicate_args = src, clone_plan["dest"], world_indices, clone_plan["mapping"]
151+
if cfg.clone_usd:
152+
usd_replicate(stage, *replicate_args)
153+
154+
if cfg.clone_physx:
155+
physx_replicate(stage, *replicate_args, use_fabric=cfg.clone_in_fabric)
155156

156157

157158
def usd_replicate(

source/isaaclab/isaaclab/scene/interactive_scene.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -136,15 +136,17 @@ def __init__(self, cfg: InteractiveSceneCfg):
136136
self._physics_scene_path = None
137137
# prepare cloner for environment replication
138138
self.env_prim_paths = [f"{self.env_ns}/env_{i}" for i in range(self.cfg.num_envs)]
139-
# create source prim
140-
self.stage.DefinePrim(self.env_prim_paths[0], "Xform")
141139

142140
self.cloner_cfg = cloner.TemplateCloneCfg(
143141
clone_regex=self.env_regex_ns,
144142
random_heterogenous_cloning=self.cfg.random_heterogenous_cloning,
145143
clone_in_fabric=self.cfg.clone_in_fabric,
146144
device=self.device,
147145
)
146+
147+
# create source prim
148+
self.stage.DefinePrim(self.env_prim_paths[0], "Xform")
149+
self.stage.DefinePrim(self.cloner_cfg.template_root, "Xform")
148150
self.env_fmt = self.env_regex_ns.replace(".*", "{}")
149151
# allocate env indices
150152
self._ALL_INDICES = torch.arange(self.cfg.num_envs, dtype=torch.long, device=self.device)

0 commit comments

Comments
 (0)