Skip to content

Commit f3f5bad

Browse files
yuvaltassacopybara-github
authored andcommitted
Refactor MJMODEL_POINTERS into category-specific macros.
PiperOrigin-RevId: 802089588 Change-Id: I95f6d3f64452328169b2a05d5c6bfd525d7128b1
1 parent 05f6475 commit f3f5bad

File tree

4 files changed

+18
-4
lines changed

4 files changed

+18
-4
lines changed

dm_control/autowrap/binding_generator.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,8 @@ def parse_hints(self, xmacro_src):
160160
for tokens, _, _ in parser.scanString(xmacro_src):
161161
for xmacro in tokens:
162162
for member in xmacro.members:
163+
if not hasattr(member, "name") or not member.name:
164+
continue
163165
# "Squeeze out" singleton dimensions.
164166
shape = self.get_shape_tuple(member.dims, squeeze=True)
165167
self.hints_dict.update({member.name: shape})

dm_control/autowrap/codegen_util.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,14 +78,18 @@ def __setitem__(self, k, v):
7878

7979
def macro_struct_name(name, suffix=None):
8080
"""Converts mjxmacro struct names, e.g. "MJDATA_POINTERS" to "mjdata"."""
81+
if name.startswith("MJMODEL_POINTERS"):
82+
return "mjmodel"
8183
if suffix is None:
8284
suffix = _MJXMACRO_SUFFIX
83-
return name[:-len(suffix)].lower()
85+
if name.endswith(suffix):
86+
return name[:-len(suffix)].lower()
87+
return name.lower()
8488

8589

8690
def is_macro_pointer(name):
8791
"""Returns True if the mjxmacro struct name contains pointer sizes."""
88-
return name.endswith(_MJXMACRO_SUFFIX)
92+
return name.endswith(_MJXMACRO_SUFFIX) or name.startswith("MJMODEL_POINTERS")
8993

9094

9195
def try_coerce_to_num(s, try_types=(int, float)):

dm_control/autowrap/header_parsing.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -173,12 +173,13 @@ def _nested_ifn_else(ifn_, pred, else_, endif, match_if_true, match_if_false):
173173
pp.delimitedList(XDIM, delim=COMMA)("dims") +
174174
RPAREN)
175175

176+
XMACRO_LINE = XMEMBER | NAME
176177
XMACRO = pp.Group(
177178
pp.Optional(COMMENT("comment")) +
178179
DEFINE +
179180
NAME("name") +
180181
CONTINUATION +
181-
pp.delimitedList(XMEMBER, delim=CONTINUATION)("members"))
182+
pp.delimitedList(XMACRO_LINE, delim=CONTINUATION)("members"))
182183

183184

184185
# Type/variable declarations.

dm_control/mujoco/index.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,14 @@ def _is_name_pointer(field_name):
300300

301301
def _get_size_name(field_name, struct_name='mjmodel'):
302302
# Look up size name in metadata.
303-
return sizes.array_sizes[struct_name][field_name][0]
303+
try:
304+
return sizes.array_sizes[struct_name][field_name][0]
305+
except KeyError:
306+
# Special handling required for name pointers in mjModel.
307+
if _is_name_pointer(field_name):
308+
return 'n' + field_name.split('_')[1][:-3]
309+
else:
310+
raise
304311

305312

306313
def _validate_key_item(key_item):

0 commit comments

Comments
 (0)