According to its definition, the operator `aten.conv_transpose3d.input` carries an `input` suffix, so its naming style differs noticeably from other convolution operators such as `aten.conv_transpose1d` and `aten.conv1d`. Is there a particular reason for this?
The operator name also shows up when exporting with `torch.export.export` (torch 2.6.0), for example:
```shell
(py311-source-172) root@998ee80b761b:/torch-mlir# cat conv3dTran_new.py
```

```python
import torch
from torch_mlir import fx


class SimpleModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.ConvTranspose3d(
            in_channels=1,
            out_channels=8,
            kernel_size=(2, 3, 3),
            stride=(2, 2, 2),
            padding=(1, 1, 1),
        )

    def forward(self, x):
        return self.conv(x)


model = SimpleModel()
model.eval()

# Input shape: (N=1, C_in=1, D=5, H=5, W=5)
example_input = torch.randn(1, 1, 5, 5, 5, dtype=torch.float32)
fx_graph = torch.export.export(model, (example_input,))

# Import the exported program into MLIR without lowering (RAW output)
module_fx = fx.export_and_import(
    fx_graph,
    func_name="forward",
    output_type=fx.OutputType.RAW,
)
print("\n=== FX Graph to MLIR ===")
print(module_fx)
```
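To make the naming question concrete: the `.input` part appears to be an overload name, following the `namespace::name.overload` convention used for ATen operators, where operators registered without an overload name are exposed in Python under `.default`. The sketch below is a hypothetical helper (`parse_op_name` and `OpName` are illustrative names, not part of PyTorch or torch-mlir) that splits such names so they can be compared side by side. With torch installed, `torch.ops.aten.conv_transpose3d.overloads()` lists the overload names that are actually registered.

```python
from typing import NamedTuple


class OpName(NamedTuple):
    """Components of a qualified ATen operator name (hypothetical helper)."""
    namespace: str
    base: str
    overload: str


def parse_op_name(qualified: str) -> OpName:
    """Split 'aten::conv_transpose3d.input' or 'aten.conv_transpose3d.input'.

    An empty overload name is reported as 'default', mirroring how
    torch.ops exposes operators that have no named overload.
    """
    if "::" in qualified:
        # Schema-style spelling: namespace::base[.overload]
        namespace, rest = qualified.split("::", 1)
        base, _, overload = rest.partition(".")
    else:
        # Python-style spelling: namespace.base[.overload]
        parts = qualified.split(".")
        if len(parts) < 2:
            raise ValueError(f"not a qualified op name: {qualified!r}")
        namespace, base = parts[0], parts[1]
        overload = ".".join(parts[2:])
    return OpName(namespace, base, overload or "default")


for name in ["aten.conv_transpose3d.input", "aten.conv_transpose1d", "aten::conv1d"]:
    print(name, "->", parse_op_name(name))
```

Under this reading, the names differ only in whether the overload is named, which is the inconsistency the question is about.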