@@ -112,46 +112,47 @@ def clone_from_template(stage: Usd.Stage, num_clones: int, template_clone_cfg: T
112112 cfg .template_root ,
113113 predicate = lambda prim : str (prim .GetPath ()).split ("/" )[- 1 ].startswith (prototype_id ),
114114 )
115- prototype_root_set = {"/" .join (str (prototype .GetPath ()).split ("/" )[:- 1 ]) for prototype in prototypes }
116- for prototype_root in prototype_root_set :
117- protos = sim_utils .find_matching_prim_paths (f"{ prototype_root } /.*" )
118- protos = [proto for proto in protos if proto .split ("/" )[- 1 ].startswith (prototype_id )]
119- m = torch .zeros ((len (protos ), num_clones ), dtype = torch .bool , device = cfg .device )
120- # Optionally select prototypes randomly per environment; else round-robin by modulo
121- if cfg .random_heterogenous_cloning :
122- rand_idx = torch .randint (len (protos ), (num_clones ,), device = cfg .device )
123- m [rand_idx , world_indices ] = True
115+ if len (prototypes ) > 0 :
116+ prototype_root_set = {"/" .join (str (prototype .GetPath ()).split ("/" )[:- 1 ]) for prototype in prototypes }
117+ for prototype_root in prototype_root_set :
118+ protos = sim_utils .find_matching_prim_paths (f"{ prototype_root } /.*" )
119+ protos = [proto for proto in protos if proto .split ("/" )[- 1 ].startswith (prototype_id )]
120+ m = torch .zeros ((len (protos ), num_clones ), dtype = torch .bool , device = cfg .device )
121+ # Optionally select prototypes randomly per environment; else round-robin by modulo
122+ if cfg .random_heterogenous_cloning :
123+ rand_idx = torch .randint (len (protos ), (num_clones ,), device = cfg .device )
124+ m [rand_idx , world_indices ] = True
125+ else :
126+ m [world_indices % len (protos ), world_indices ] = True
127+
128+ clone_plan ["src" ].extend (protos )
129+ clone_plan ["dest" ].extend ([prototype_root .replace (cfg .template_root , clone_path_fmt )] * len (protos ))
130+ clone_plan ["mapping" ] = torch .cat ((clone_plan ["mapping" ].reshape (- 1 , m .size (1 )), m ), dim = 0 )
131+
132+ proto_idx = clone_plan ["mapping" ].to (torch .int32 ).argmax (dim = 1 )
133+ proto_mask = torch .zeros_like (clone_plan ["mapping" ])
134+ proto_mask .scatter_ (1 , proto_idx .view (- 1 , 1 ).to (torch .long ), clone_plan ["mapping" ].any (dim = 1 , keepdim = True ))
135+ usd_replicate (stage , clone_plan ["src" ], clone_plan ["dest" ], world_indices , proto_mask )
136+ stage .GetPrimAtPath (cfg .template_root ).SetActive (False )
137+
138+ # If all prototypes map to env_0, clone whole env_0 to all envs; else clone per-object
139+ if torch .all (proto_idx == 0 ):
140+ replicate_args = [clone_path_fmt .format (0 )], [clone_path_fmt ], world_indices , clone_plan ["mapping" ]
141+ if cfg .clone_usd :
142+ # parse env_origins directly from clone_path
143+ get_translate = (
144+ lambda prim_path : stage .GetPrimAtPath (prim_path ).GetAttribute ("xformOp:translate" ).Get ()
145+ ) # noqa: E731
146+ positions = torch .tensor ([get_translate (clone_path_fmt .format (i )) for i in world_indices ])
147+ usd_replicate (stage , * replicate_args , positions = positions )
124148 else :
125- m [world_indices % len (protos ), world_indices ] = True
126-
127- clone_plan ["src" ].extend (protos )
128- clone_plan ["dest" ].extend ([prototype_root .replace (cfg .template_root , clone_path_fmt )] * len (protos ))
129- clone_plan ["mapping" ] = torch .cat ((clone_plan ["mapping" ].reshape (- 1 , m .size (1 )), m ), dim = 0 )
130-
131- proto_idx = clone_plan ["mapping" ].to (torch .int32 ).argmax (dim = 1 )
132- proto_mask = torch .zeros_like (clone_plan ["mapping" ])
133- proto_mask .scatter_ (1 , proto_idx .view (- 1 , 1 ).to (torch .long ), clone_plan ["mapping" ].any (dim = 1 , keepdim = True ))
134- usd_replicate (stage , clone_plan ["src" ], clone_plan ["dest" ], world_indices , proto_mask )
135- stage .GetPrimAtPath (cfg .template_root ).SetActive (False )
136-
137- # If all prototypes map to env_0, clone whole env_0 to all envs; else clone per-object
138- if torch .all (proto_idx == 0 ):
139- replicate_args = [clone_path_fmt .format (0 )], [clone_path_fmt ], world_indices , clone_plan ["mapping" ]
140- if cfg .clone_usd :
141- # parse env_origins directly from clone_path
142- get_translate = (
143- lambda prim_path : stage .GetPrimAtPath (prim_path ).GetAttribute ("xformOp:translate" ).Get ()
144- ) # noqa: E731
145- positions = torch .tensor ([get_translate (clone_path_fmt .format (i )) for i in world_indices ])
146- usd_replicate (stage , * replicate_args , positions = positions )
147- else :
148- src = [tpl .format (int (idx )) for tpl , idx in zip (clone_plan ["dest" ], proto_idx .tolist ())]
149- replicate_args = src , clone_plan ["dest" ], world_indices , clone_plan ["mapping" ]
150- if cfg .clone_usd :
151- usd_replicate (stage , * replicate_args )
152-
153- if cfg .clone_physx :
154- physx_replicate (stage , * replicate_args , use_fabric = cfg .clone_in_fabric )
149+ src = [tpl .format (int (idx )) for tpl , idx in zip (clone_plan ["dest" ], proto_idx .tolist ())]
150+ replicate_args = src , clone_plan ["dest" ], world_indices , clone_plan ["mapping" ]
151+ if cfg .clone_usd :
152+ usd_replicate (stage , * replicate_args )
153+
154+ if cfg .clone_physx :
155+ physx_replicate (stage , * replicate_args , use_fabric = cfg .clone_in_fabric )
155156
156157
157158def usd_replicate (
0 commit comments