Skip to content

Commit 3d79bc8

Browse files
allanrenucciGoogle-ML-Automation
authored andcommitted
[Mosaic GPU][NFC] Move result inference logic to .cc file.
PiperOrigin-RevId: 836231197
1 parent efdc83d commit 3d79bc8

File tree

2 files changed

+57
-77
lines changed

2 files changed

+57
-77
lines changed

jaxlib/mosaic/dialect/gpu/mosaic_gpu.cc

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ limitations under the License.
1616
#include "jaxlib/mosaic/dialect/gpu/mosaic_gpu.h"
1717

1818
#include <cstdint>
19+
#include <optional>
1920
#include <string_view>
2021
#include <vector>
2122

@@ -50,6 +51,7 @@ limitations under the License.
5051
#include "mlir/IR/MLIRContext.h"
5152
#include "mlir/IR/Operation.h"
5253
#include "mlir/IR/OperationSupport.h"
54+
#include "mlir/IR/Region.h"
5355
#include "mlir/IR/TypeRange.h"
5456
#include "mlir/IR/Types.h"
5557
#include "mlir/IR/Value.h"
@@ -334,6 +336,18 @@ llvm::LogicalResult AsyncStoreOp::verify() {
334336
getSliceLengths(), getIndices().size());
335337
}
336338

339+
llvm::LogicalResult WGMMAOp::inferReturnTypes(
340+
mlir::MLIRContext*, std::optional<mlir::Location> location,
341+
mlir::ValueRange operands, mlir::DictionaryAttr attributes,
342+
mlir::OpaqueProperties properties, mlir::RegionRange regions,
343+
llvm::SmallVectorImpl<mlir::Type>& inferredReturnTypes) {
344+
if (operands.empty()) {
345+
return mlir::emitOptionalError(location, "expected non-empty operands");
346+
}
347+
inferredReturnTypes.assign({operands[0].getType()});
348+
return mlir::success();
349+
}
350+
337351
llvm::LogicalResult WGMMAOp::verify() {
338352
auto error = [this](auto... params) {
339353
return getOperation()->emitOpError(llvm::formatv(params...));
@@ -644,6 +658,19 @@ llvm::LogicalResult TmemDeallocOp::verify() {
644658
return VerifyTmemRefType(getOperation(), getTmemRef().getType());
645659
}
646660

661+
llvm::LogicalResult AsyncLoadTmemOp::inferReturnTypes(
662+
mlir::MLIRContext*, std::optional<mlir::Location> location,
663+
mlir::ValueRange operands, mlir::DictionaryAttr attributes,
664+
mlir::OpaqueProperties properties, mlir::RegionRange regions,
665+
llvm::SmallVectorImpl<mlir::Type>& inferredReturnTypes) {
666+
mlir::MemRefType memref_type =
667+
mlir::cast<mlir::MemRefType>(operands[0].getType());
668+
auto vector_type = mlir::VectorType::get(memref_type.getShape(),
669+
memref_type.getElementType());
670+
inferredReturnTypes.assign({vector_type});
671+
return mlir::success();
672+
}
673+
647674
llvm::LogicalResult AsyncLoadTmemOp::verify() {
648675
if (getSource().getType().getElementType() !=
649676
getResult().getType().getElementType()) {
@@ -690,6 +717,18 @@ llvm::LogicalResult SliceTmemOp::verify() {
690717
return llvm::success();
691718
}
692719

720+
llvm::LogicalResult VectorLoadOp::inferReturnTypes(
721+
mlir::MLIRContext*, std::optional<mlir::Location>,
722+
mlir::ValueRange operands, mlir::DictionaryAttr, mlir::OpaqueProperties,
723+
mlir::RegionRange, llvm::SmallVectorImpl<mlir::Type>& inferredReturnTypes) {
724+
mlir::MemRefType memref_type =
725+
mlir::cast<mlir::MemRefType>(operands[0].getType());
726+
auto vector_type = mlir::VectorType::get(memref_type.getShape(),
727+
memref_type.getElementType());
728+
inferredReturnTypes.assign({vector_type});
729+
return mlir::success();
730+
}
731+
693732
llvm::LogicalResult VectorStoreOp::verify() {
694733
mlir::VectorType src_type = getValueToStore().getType();
695734
mlir::MemRefType dst_type = getDestination().getType();
@@ -726,6 +765,19 @@ llvm::LogicalResult PrintLayoutOp::verify() {
726765
return llvm::success();
727766
}
728767

768+
llvm::LogicalResult OptimizationBarrierOp::inferReturnTypes(
769+
mlir::MLIRContext*, std::optional<mlir::Location> location,
770+
mlir::ValueRange operands, mlir::DictionaryAttr attributes,
771+
mlir::OpaqueProperties properties, mlir::RegionRange regions,
772+
llvm::SmallVectorImpl<mlir::Type>& inferredReturnTypes) {
773+
if (operands.empty()) {
774+
return mlir::emitOptionalError(location, "expected non-empty operands");
775+
}
776+
mlir::TypeRange operand_types = operands.getTypes();
777+
inferredReturnTypes.assign(operand_types.begin(), operand_types.end());
778+
return mlir::success();
779+
}
780+
729781
void MosaicGPUDialect::initialize() {
730782
addTypes<
731783
#define GET_TYPEDEF_LIST

jaxlib/mosaic/dialect/gpu/mosaic_gpu.td

Lines changed: 5 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,7 @@ def MosaicGPU_AsyncStoreOp : Op<MosaicGPU_Dialect, "async_store",
370370
}
371371

372372
def MosaicGPU_VectorLoadOp : Op<MosaicGPU_Dialect, "vector_load",
373-
[InferTypeOpInterface, MemoryEffects<[MemRead]>]> {
373+
[DeclareOpInterfaceMethods<InferTypeOpInterface>, MemoryEffects<[MemRead]>]> {
374374
let summary = "Reads an n-D slice of memory into an n-D vector.";
375375
let description = [{
376376
Similar to `vector.load` (vector dialect) but supports loading from
@@ -386,24 +386,6 @@ def MosaicGPU_VectorLoadOp : Op<MosaicGPU_Dialect, "vector_load",
386386
OptionalAttr<BoolAttr>:$optimized
387387
);
388388
let results = (outs AnyFixedVectorOfNonZeroRank);
389-
390-
let extraClassDeclaration = [{
391-
static llvm::LogicalResult inferReturnTypes(
392-
mlir::MLIRContext *,
393-
std::optional<mlir::Location> location,
394-
mlir::ValueRange operands,
395-
mlir::DictionaryAttr attributes,
396-
mlir::OpaqueProperties properties,
397-
mlir::RegionRange regions,
398-
llvm::SmallVectorImpl<mlir::Type> &inferredReturnTypes) {
399-
mlir::MemRefType memref_type =
400-
mlir::cast<mlir::MemRefType>(operands[0].getType());
401-
auto vector_type = mlir::VectorType::get(
402-
memref_type.getShape(), memref_type.getElementType());
403-
inferredReturnTypes.assign({vector_type});
404-
return ::mlir::success();
405-
}
406-
}];
407389
}
408390

409391
def MosaicGPU_VectorStoreOp : Op<MosaicGPU_Dialect, "vector_store",
@@ -502,7 +484,8 @@ def MosaicGPU_WGMMASupportedAccumulatorType : AnyTypeOf<[F16, F32, I32],
502484
"A type supported by the accumulator `wgmma.mma_async` instruction">;
503485

504486

505-
def MosaicGPU_WGMMAOp : Op<MosaicGPU_Dialect, "wgmma", [InferTypeOpInterface]> {
487+
def MosaicGPU_WGMMAOp : Op<MosaicGPU_Dialect, "wgmma",
488+
[DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
506489
let summary = "Multiply two matrices asynchronously using warpgroup level matrix multiply operations.";
507490
let description = [{
508491
Schedules WGMMA operations that perform the following matrix multiply and
@@ -554,24 +537,6 @@ def MosaicGPU_WGMMAOp : Op<MosaicGPU_Dialect, "wgmma", [InferTypeOpInterface]> {
554537
`->` type(results)
555538
}];
556539

557-
let extraClassDeclaration = [{
558-
static llvm::LogicalResult inferReturnTypes(
559-
mlir::MLIRContext *,
560-
std::optional<mlir::Location> location,
561-
mlir::ValueRange operands,
562-
mlir::DictionaryAttr attributes,
563-
mlir::OpaqueProperties properties,
564-
mlir::RegionRange regions,
565-
llvm::SmallVectorImpl<mlir::Type> &inferredReturnTypes) {
566-
if (operands.empty()) {
567-
return ::mlir::emitOptionalError(
568-
location, "expected non-empty operands");
569-
}
570-
inferredReturnTypes.assign({operands[0].getType()});
571-
return ::mlir::success();
572-
}
573-
}];
574-
575540
let hasVerifier = 1;
576541
}
577542

@@ -630,32 +595,13 @@ def MosaicGPU_TcGen05MMAOp : Op<MosaicGPU_Dialect, "tcgen05_mma", [AttrSizedOper
630595
}
631596

632597
def MosaicGPU_OptimizationBarrierOp : Op<MosaicGPU_Dialect, "optimization_barrier",
633-
[InferTypeOpInterface]> {
598+
[DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
634599
let summary = "Prevents MLIR from moving operations across the barrier.";
635600

636601
let arguments = (ins
637602
Variadic<AnyType>:$operands
638603
);
639604
let results = (outs Variadic<AnyType>);
640-
641-
let extraClassDeclaration = [{
642-
static llvm::LogicalResult inferReturnTypes(
643-
mlir::MLIRContext *,
644-
std::optional<mlir::Location> location,
645-
mlir::ValueRange operands,
646-
mlir::DictionaryAttr attributes,
647-
mlir::OpaqueProperties properties,
648-
mlir::RegionRange regions,
649-
llvm::SmallVectorImpl<mlir::Type> &inferredReturnTypes) {
650-
if (operands.empty()) {
651-
return ::mlir::emitOptionalError(
652-
location, "expected non-empty operands");
653-
}
654-
::mlir::TypeRange operand_types = operands.getTypes();
655-
inferredReturnTypes.assign(operand_types.begin(), operand_types.end());
656-
return ::mlir::success();
657-
}
658-
}];
659605
}
660606

661607
def MosaicGPU_ReturnOp : Op<MosaicGPU_Dialect, "return",
@@ -794,31 +740,13 @@ def MosaicGPU_TmemDeallocOp : Op<MosaicGPU_Dialect, "tmem_dealloc", []> {
794740
}
795741

796742
def MosaicGPU_AsyncLoadTmemOp : Op<MosaicGPU_Dialect, "async_load_tmem",
797-
[InferTypeOpInterface]> {
743+
[DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
798744
let summary = "Copies TMEM to registers asynchronously.";
799745

800746
let arguments = (ins MemRefRankOf<[AnyType], [2]>:$source);
801747
let results = (outs VectorOfRank<[2]>);
802748

803749
let hasVerifier = 1;
804-
805-
let extraClassDeclaration = [{
806-
static llvm::LogicalResult inferReturnTypes(
807-
mlir::MLIRContext *,
808-
std::optional<mlir::Location> location,
809-
mlir::ValueRange operands,
810-
mlir::DictionaryAttr attributes,
811-
mlir::OpaqueProperties properties,
812-
mlir::RegionRange regions,
813-
llvm::SmallVectorImpl<mlir::Type> &inferredReturnTypes) {
814-
mlir::MemRefType memref_type =
815-
mlir::cast<mlir::MemRefType>(operands[0].getType());
816-
auto vector_type = mlir::VectorType::get(
817-
memref_type.getShape(), memref_type.getElementType());
818-
inferredReturnTypes.assign({vector_type});
819-
return ::mlir::success();
820-
}
821-
}];
822750
}
823751

824752
def MosaicGPU_AsyncStoreTmemOp : Op<MosaicGPU_Dialect, "async_store_tmem", []> {

0 commit comments

Comments
 (0)