Skip to content

Commit a08bcc5

Browse files
committed
address review suggestions
Signed-off-by: dchigarev <dmitry.chigarev@intel.com>
1 parent 499db16 commit a08bcc5

File tree

1 file changed

+3
-5
lines changed
  • ingress/Torch-MLIR/py_src/export_lib

1 file changed

+3
-5
lines changed

ingress/Torch-MLIR/py_src/export_lib/utils.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -63,15 +63,13 @@ def parse_shape_str(shape: str) -> tuple[tuple[int], torch.dtype]:
6363
shape : str
6464
A string representing the shape and dtype, e.g. '1,3,224,224,float32'.
6565
"""
66-
components = shape.split(",")
67-
shapes = components[:-1]
68-
dtype = components[-1]
66+
*shapes, dtype = shape.split(",")
6967
tdtype = getattr(torch, dtype)
7068
if tdtype is None:
7169
raise ValueError(f"Unsupported dtype: {dtype}")
72-
if any(dim == "?" for dim in shapes):
70+
if "?" in shapes:
7371
raise ValueError(f"Dynamic shapes are not supported yet: {shape}")
74-
return (tuple(int(dim) for dim in shapes if dim), tdtype)
72+
return (tuple(map(int, shapes)), tdtype)
7573

7674

7775
def generate_fake_tensor(shape: tuple[int], dtype: torch.dtype) -> torch.Tensor:

0 commit comments

Comments
 (0)