
Commit 481a91c

Regenerate MLIR Bindings (#1862)
Co-authored-by: enzyme-ci-bot[bot] <78882869+enzyme-ci-bot[bot]@users.noreply.github.com>
1 parent 9d45b84 commit 481a91c

File tree: 3 files changed, +323 -96 lines changed


src/mlir/Dialects/MemRef.jl

Lines changed: 37 additions & 35 deletions
@@ -461,9 +461,10 @@ region. To return a value, one should use `memref.alloca_scope.return`
 operation:

 ```mlir
-%result = memref.alloca_scope {
+%result = memref.alloca_scope -> f32 {
+  %value = arith.constant 1.0 : f32
   ...
-  memref.alloca_scope.return %value
+  memref.alloca_scope.return %value : f32
 }
 ```
@@ -498,7 +499,7 @@ the return operation may be omitted. Otherwise, it has to be present
 to indicate which values are going to be returned. For example:

 ```mlir
-memref.alloca_scope.return %value
+memref.alloca_scope.return %value : f32
 ```
 """
 function alloca_scope_return(results::Vector{Value}; location=Location())
@@ -563,11 +564,11 @@ address space.
 # Example

 ```mlir
-Cast to concrete shape.
-%4 = memref.cast %1 : memref<*xf32> to memref<4x?xf32>
+// Cast to concrete shape.
+%4 = memref.cast %1 : memref<*xf32> to memref<4x?xf32>

-Erase rank information.
-%5 = memref.cast %1 : memref<4x?xf32> to memref<*xf32>
+// Erase rank information.
+%5 = memref.cast %1 : memref<4x?xf32> to memref<*xf32>
 ```
 """
 function cast(source::Value; dest::IR.Type, location=Location())
@@ -662,8 +663,8 @@ alloc\'d memref (e.g. memrefs returned by `view` operations).
 # Example

 ```mlir
-%0 = memref.alloc() : memref<8x64xf32, affine_map<(d0, d1) -> (d0, d1), 1>>
-memref.dealloc %0 : memref<8x64xf32, affine_map<(d0, d1) -> (d0, d1), 1>>
+%0 = memref.alloc() : memref<8x64xf32, affine_map<(d0, d1) -> (d0, d1)>, 1>
+memref.dealloc %0 : memref<8x64xf32, affine_map<(d0, d1) -> (d0, d1)>, 1>
 ```
 """
 function dealloc(memref::Value; location=Location())
@@ -766,22 +767,22 @@ For example, a DmaStartOp operation that transfers 256 elements of a memref
 space 1 at indices [%k, %l], would be specified as follows:

 ```mlir
-%num_elements = arith.constant 256
+%num_elements = arith.constant 256 : index
 %idx = arith.constant 0 : index
-%tag = memref.alloc() : memref<1 x i32, affine_map<(d0) -> (d0)>, 4>
-dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx] :
-  memref<40 x 128 x f32>, affine_map<(d0) -> (d0)>, 0>,
-  memref<2 x 1024 x f32>, affine_map<(d0) -> (d0)>, 1>,
-  memref<1 x i32>, affine_map<(d0) -> (d0)>, 2>
+%tag = memref.alloc() : memref<1 x i32, affine_map<(d0) -> (d0)>, 2>
+memref.dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx] :
+  memref<40 x 128 x f32, affine_map<(d0, d1) -> (d0, d1)>, 0>,
+  memref<2 x 1024 x f32, affine_map<(d0, d1) -> (d0, d1)>, 1>,
+  memref<1 x i32, affine_map<(d0) -> (d0)>, 2>
 ```

 If %stride and %num_elt_per_stride are specified, the DMA is expected to
 transfer %num_elt_per_stride elements every %stride elements apart from
 memory space 0 until %num_elements are transferred.

 ```mlir
-dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx], %stride,
-          %num_elt_per_stride :
+memref.dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%idx], %stride,
+                 %num_elt_per_stride :
 ```

 * TODO: add additional operands to allow source and destination striding, and
@@ -818,10 +819,10 @@ number of elements associated with the DMA operation.
 # Example

 ```mlir
-dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%index] :
-  memref<2048 x f32>, affine_map<(d0) -> (d0)>, 0>,
-  memref<256 x f32>, affine_map<(d0) -> (d0)>, 1>
-  memref<1 x i32>, affine_map<(d0) -> (d0)>, 2>
+memref.dma_start %src[%i, %j], %dst[%k, %l], %num_elements, %tag[%index] :
+  memref<2048 x f32, affine_map<(d0) -> (d0)>, 0>,
+  memref<256 x f32, affine_map<(d0) -> (d0)>, 1>,
+  memref<1 x i32, affine_map<(d0) -> (d0)>, 2>
 ...
 ...
 dma_wait %tag[%index], %num_elements : memref<1 x i32, affine_map<(d0) -> (d0)>, 2>
@@ -998,16 +999,16 @@ This makes lowering more progressive and brings the following benefits:

 ```mlir
 %base, %offset, %sizes:2, %strides:2 =
-  memref.extract_strided_metadata %memref :
-  memref<10x?xf32>, index, index, index, index, index
+  memref.extract_strided_metadata %memref : memref<10x?xf32>
+  -> memref<f32>, index, index, index, index, index

 // After folding, the type of %m2 can be memref<10x?xf32> and further
 // folded to %memref.
 %m2 = memref.reinterpret_cast %base to
   offset: [%offset],
   sizes: [%sizes#0, %sizes#1],
   strides: [%strides#0, %strides#1]
-  : memref<f32> to memref<?x?xf32, offset: ?, strides: [?, ?]>
+  : memref<f32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
 ```
 """
 function extract_strided_metadata(
@@ -1096,10 +1097,10 @@ given global variable will always return the same memref descriptor).

 ```mlir
 // Private variable with an initial value.
-memref.global \"private\" @x : memref<2xf32> = dense<0.0,2.0>
+memref.global \"private\" @x : memref<2xf32> = dense<[0.0, 2.0]>

 // Private variable with an initial value and an alignment (power of 2).
-memref.global \"private\" @x : memref<2xf32> = dense<0.0,2.0> {alignment = 64}
+memref.global \"private\" @x : memref<2xf32> = dense<[0.0, 2.0]> {alignment = 64}

 // Declaration of an external variable.
 memref.global \"private\" @y : memref<4xi32>
@@ -1108,7 +1109,7 @@ memref.global \"private\" @y : memref<4xi32>
 memref.global @z : memref<3xf16> = uninitialized

 // Externally visible constant variable.
-memref.global constant @c : memref<2xi32> = dense<1, 4>
+memref.global constant @c : memref<2xi32> = dense<[1, 4]>
 ```
 """
 function global_(;
@@ -1328,8 +1329,8 @@ behavior.

 ```mlir
 %new = memref.realloc %old : memref<64xf32> to memref<124xf32>
-%4 = memref.load %new[%index] // ok
-%5 = memref.load %old[%index] // undefined behavior
+%4 = memref.load %new[%index] : memref<124xf32> // ok
+%5 = memref.load %old[%index] : memref<64xf32> // undefined behavior
 ```
 """
 function realloc(
@@ -1458,7 +1459,8 @@ In other words:
 %dst = memref.reinterpret_cast %src to
   offset: [%offset],
   sizes: [%sizes],
-  strides: [%strides]
+  strides: [%strides] :
+  memref<*xf32> to memref<?x?xf32, strided<[?, ?], offset: ?>>
 ```
 means that `%dst`\'s descriptor will be:
 ```mlir
@@ -1521,12 +1523,12 @@ Result type is ranked.
 ```mlir
 // Reshape statically-shaped memref.
 %dst = memref.reshape %src(%shape)
-    : (memref<4x1xf32>, memref<1xi32>) to memref<4xf32>
+    : (memref<4x1xf32>, memref<1xi32>) -> memref<4xf32>
 %dst0 = memref.reshape %src(%shape0)
-    : (memref<4x1xf32>, memref<2xi32>) to memref<2x2xf32>
+    : (memref<4x1xf32>, memref<2xi32>) -> memref<2x2xf32>
 // Flatten unranked memref.
 %dst = memref.reshape %src(%shape)
-    : (memref<*xf32>, memref<1xi32>) to memref<?xf32>
+    : (memref<*xf32>, memref<1xi32>) -> memref<?xf32>
 ```

 b. Source type is ranked or unranked. Shape argument has dynamic size.
@@ -1535,10 +1537,10 @@ Result type is unranked.
 ```mlir
 // Reshape dynamically-shaped 1D memref.
 %dst = memref.reshape %src(%shape)
-    : (memref<?xf32>, memref<?xi32>) to memref<*xf32>
+    : (memref<?xf32>, memref<?xi32>) -> memref<*xf32>
 // Reshape unranked memref.
 %dst = memref.reshape %src(%shape)
-    : (memref<*xf32>, memref<?xi32>) to memref<*xf32>
+    : (memref<*xf32>, memref<?xi32>) -> memref<*xf32>
 ```
 """
 function reshape(source::Value, shape::Value; result::IR.Type, location=Location())
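
A minimal usage sketch for the regenerated MemRef bindings (not part of this commit). Only the signature `reshape(source::Value, shape::Value; result::IR.Type, location=Location())` is taken from the diff above; the module path `Reactant.MLIR.Dialects.memref`, the lowercase module name, and the `emit_reshape` helper are assumptions based on how these generated files are normally consumed.

```julia
# Sketch only: module paths below are assumptions, not taken from this commit.
using Reactant.MLIR: IR
using Reactant.MLIR.Dialects: memref  # assumed lowercase dialect module name

# Emit `memref.reshape %src(%shape) : (...) -> memref<4xf32>`, mirroring the
# corrected docstring example. `src` and `shape` must be SSA values already
# inserted in a block under an active MLIR context.
function emit_reshape(src::IR.Value, shape::IR.Value, result_ty::IR.Type)
    return memref.reshape(src, shape; result=result_ty)
end
```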

src/mlir/Dialects/MosaicGPU.jl

Lines changed: 87 additions & 0 deletions
@@ -322,6 +322,25 @@ function custom_primitive(
     )
 end

+function debug_print(value::Value; format, location=Location())
+    op_ty_results = IR.Type[]
+    operands = Value[value,]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[namedattribute("format", format),]
+
+    return create_operation(
+        "mosaic_gpu.debug_print",
+        location;
+        operands,
+        owned_regions,
+        successors,
+        attributes,
+        results=op_ty_results,
+        result_inference=false,
+    )
+end
+
 """
 `initialize_barrier`

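
The new `debug_print` wrapper takes a single operand and a required `format` attribute. A hedged usage sketch, assuming the lowercase `mosaic_gpu` module name and that `namedattribute` converts a plain Julia `String` into a string attribute (both conventions of these generated bindings, not facts stated by this commit):

```julia
using Reactant.MLIR: IR
using Reactant.MLIR.Dialects: mosaic_gpu  # assumed module name

# Emit `mosaic_gpu.debug_print` for one SSA value. The printf-style format
# string is illustrative; which conversions the op accepts is defined by the
# dialect, not by this binding.
emit_debug(val::IR.Value) = mosaic_gpu.debug_print(val; format="value: %d\n")
```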
@@ -669,6 +688,74 @@ function tmem_relinquish_alloc_permit(; collective=nothing, location=Location())
     )
 end

+"""
+`vector_load`
+
+Similar to `vector.load` (vector dialect) but supports loading from
+non-contiguous memory.
+
+If `optimized` is true, raises an error if we cannot generate an optimised
+transfer. If unset, fall back to a non-optimized transfer if unable to
+generate an optimized transfer.
+"""
+function vector_load(
+    source::Value;
+    result_0=nothing::Union{Nothing,IR.Type},
+    optimized=nothing,
+    location=Location(),
+)
+    op_ty_results = IR.Type[]
+    operands = Value[source,]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+    !isnothing(result_0) && push!(op_ty_results, result_0)
+    !isnothing(optimized) && push!(attributes, namedattribute("optimized", optimized))
+
+    return create_operation(
+        "mosaic_gpu.vector_load",
+        location;
+        operands,
+        owned_regions,
+        successors,
+        attributes,
+        results=(length(op_ty_results) == 0 ? nothing : op_ty_results),
+        result_inference=(length(op_ty_results) == 0 ? true : false),
+    )
+end
+
+"""
+`vector_store`
+
+Similar to `vector.store` (vector dialect) but supports storing to
+non-contiguous memory.
+
+If `optimized` is true, raises an error if we cannot generate an optimised
+transfer. If unset, fall back to a non-optimized transfer if unable to
+generate an optimized transfer.
+"""
+function vector_store(
+    valueToStore::Value, destination::Value; optimized=nothing, location=Location()
+)
+    op_ty_results = IR.Type[]
+    operands = Value[valueToStore, destination]
+    owned_regions = Region[]
+    successors = Block[]
+    attributes = NamedAttribute[]
+    !isnothing(optimized) && push!(attributes, namedattribute("optimized", optimized))
+
+    return create_operation(
+        "mosaic_gpu.vector_store",
+        location;
+        operands,
+        owned_regions,
+        successors,
+        attributes,
+        results=op_ty_results,
+        result_inference=false,
+    )
+end
+
 """
 `wgmma`

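
A sketch of how the tri-state `optimized` flag flows through these wrappers: passing `true` or `false` attaches the attribute, while leaving it as `nothing` omits it so the lowering may fall back to a non-optimized transfer. The helper name, module path, and `IR.result` call are assumptions; the attribute plumbing is read directly from the generated code above.

```julia
using Reactant.MLIR: IR
using Reactant.MLIR.Dialects: mosaic_gpu  # assumed module name

function emit_transfers(src::IR.Value, dst::IR.Value, vec_ty::IR.Type)
    # optimized=true: the lowering should error if it cannot generate an
    # optimized transfer for this load.
    load_op = mosaic_gpu.vector_load(src; result_0=vec_ty, optimized=true)
    # Omitting `optimized` leaves the attribute unset, allowing a silent
    # fallback to a non-optimized transfer for the store.
    mosaic_gpu.vector_store(IR.result(load_op, 1), dst)
end
```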
