Skip to content

Conversation

@mdazz
Copy link
Contributor

@mdazz mdazz commented Oct 30, 2025

This commit teaches the folding methods of AtenFloor, AtenCeil, AtenRound, and AtenTruc to constant-fold roundings when the operand is a splat DenseElementsAttr.

This commit teaches the folding methods of `AtenFloor`, `AtenCeil`, `AtenRound`, and `AtenTruc`
to constant-fold roundings when the operand is a splat `DenseElementsAttr`.
@mdazz mdazz force-pushed the mdazz/ad-rounding-folder branch from 58550c3 to 175d2a2 Compare October 30, 2025 10:14
@mdazz
Copy link
Contributor Author

mdazz commented Oct 30, 2025

@sahas3 @zjgarvey @vivekkhandelwal1 can you please take a look at this?

Copy link
Member

@sahas3 sahas3 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice change, looks good to me for most part.

}

// Common helper for splat-only rounding-based folders.
static OpFoldResult foldSplatRounding(ValueTensorType resultType,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think naming it foldFloatSplatWithRounding will be more appropriate since it only handles float data.

return {};

auto outShaped = resultType.toBuiltinTensor();
if (!outShaped.hasStaticShape())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why is static shape a requirement?

@export
@annotate_args([None])
def forward(self):
return torch.ops.aten.round(self.const)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am guessing the other ops are already covered in e2e tests?

// CHECK: %[[C:.*]] = torch.vtensor.literal(dense<-1.000000e+00> : tensor<2x2xf32>)
// CHECK: return %[[C]]
func.func @torch.aten.ceil$fold() -> !torch.vtensor<[2,2],f32> {
%cst = torch.vtensor.literal(dense<-1.100000e+00> : tensor<2x2xf32>)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a +ve value too for completeness?

// CHECK: %[[C:.*]] = torch.vtensor.literal(dense<1.000000e+00> : tensor<3x4xf32>)
// CHECK: return %[[C]]
func.func @torch.aten.floor$fold() -> !torch.vtensor<[3,4],f32> {
%cst = torch.vtensor.literal(dense<1.900000e+00> : tensor<3x4xf32>)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Similarly adding a negative value will give full coverage

Comment on lines +253 to +254
// NaNs and infs are dealt with consistently with torch, so side-effects
// can be discarded.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Clarification: Can you elaborate on what you meant by "NaNs and infs are dealt with consistently with torch" ?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants