Skip to content

Commit 2a979f8

Browse files
Regenerate MLIR Bindings (#1833)
Co-authored-by: enzyme-ci-bot[bot] <78882869+enzyme-ci-bot[bot]@users.noreply.github.com>
1 parent 8d245f9 commit 2a979f8

File tree

2 files changed

+67
-2
lines changed

2 files changed

+67
-2
lines changed

src/mlir/Dialects/Enzyme.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -520,8 +520,8 @@ function getFlattenedSamplesFromTrace(
520520
)
521521
end
522522

523-
function get(gradient::Value; result_0::IR.Type, location=Location())
524-
op_ty_results = IR.Type[result_0,]
523+
function get(gradient::Value; result::IR.Type, location=Location())
524+
op_ty_results = IR.Type[result,]
525525
operands = Value[gradient,]
526526
owned_regions = Region[]
527527
successors = Block[]

src/mlir/Dialects/TPU.jl

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -910,6 +910,38 @@ function prng_set_seed_32(seeds::Vector{Value}; location=Location())
910910
)
911911
end
912912

913+
"""
914+
`pack_elementwise`
915+
916+
Packs multiple `sources` elementwise into a single vector of a narrower `target_type`.
917+
918+
The number of `sources` must equal the packing factor, which is the ratio of
919+
the element bitwidth of the `sources` to the element bitwidth of the
920+
`target_type`. Elements from the `sources` are interleaved and packed into
921+
each word of the `output`, ordered from lowest to highest bits,
922+
corresponding to their order in the `sources`.
923+
"""
924+
function pack_elementwise(
925+
sources::Vector{Value}; output::IR.Type, target_type, location=Location()
926+
)
927+
op_ty_results = IR.Type[output,]
928+
operands = Value[sources...,]
929+
owned_regions = Region[]
930+
successors = Block[]
931+
attributes = NamedAttribute[namedattribute("target_type", target_type),]
932+
933+
return create_operation(
934+
"tpu.pack_elementwise",
935+
location;
936+
operands,
937+
owned_regions,
938+
successors,
939+
attributes,
940+
results=op_ty_results,
941+
result_inference=false,
942+
)
943+
end
944+
913945
function pack_vmsk(low::Value, high::Value; output::IR.Type, location=Location())
914946
op_ty_results = IR.Type[output,]
915947
operands = Value[low, high]
@@ -1620,6 +1652,39 @@ function truncf(in::Value; out::IR.Type, rounding_mode, location=Location())
16201652
)
16211653
end
16221654

1655+
"""
1656+
`unpack_elementwise`
1657+
1658+
Unpacks a single vector from `source`, which contains multiple `source_type`
1659+
vectors packed elementwise.
1660+
1661+
The `index` selects which packed value to extract from each word of `source`.
1662+
An `index` of 0 corresponds to the lowest bits. The extracted values are
1663+
cast to the output element type.
1664+
"""
1665+
function unpack_elementwise(
1666+
source::Value; output::IR.Type, source_type, index, location=Location()
1667+
)
1668+
op_ty_results = IR.Type[output,]
1669+
operands = Value[source,]
1670+
owned_regions = Region[]
1671+
successors = Block[]
1672+
attributes = NamedAttribute[
1673+
namedattribute("source_type", source_type), namedattribute("index", index)
1674+
]
1675+
1676+
return create_operation(
1677+
"tpu.unpack_elementwise",
1678+
location;
1679+
operands,
1680+
owned_regions,
1681+
successors,
1682+
attributes,
1683+
results=op_ty_results,
1684+
result_inference=false,
1685+
)
1686+
end
1687+
16231688
function unpack_subelements(
16241689
source::Value;
16251690
output::IR.Type,

0 commit comments

Comments
 (0)