@@ -523,15 +523,14 @@ class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic, []> {
523523 AffineMapArrayAttr:$indexing_maps,
524524 ArrayAttr:$iterator_types,
525525 OptionalAttr<StrAttr>:$doc,
526- OptionalAttr<FlatSymbolRefAttr>:$fun,
527526 OptionalAttr<StrAttr>:$library_call);
528527 let results = (outs Variadic<AnyRankedTensor>:$output_tensors);
529528 let regions = (region AnyRegion:$region);
530529 let extraClassDeclaration = [{
531530 SmallVector<StringRef, 8> linalgTraitAttrNames() {
532531 return SmallVector<StringRef, 8>{
533532 getArgsInAttrName(), getArgsOutAttrName(), getDocAttrName(),
534- getFunAttrName(), getIndexingMapsAttrName(), getLibraryCallAttrName(),
533+ getIndexingMapsAttrName(), getLibraryCallAttrName(),
535534 getIteratorTypesAttrName()
536535 };
537536 }
@@ -540,12 +539,6 @@ class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic, []> {
540539
541540 unsigned getNumOutputs() { return args_out().getSExtValue(); }
542541
543- FuncOp getFunction() {
544- auto moduleOp = getParentOfType<ModuleOp>();
545- return fun().hasValue() ?
546- moduleOp.lookupSymbol<FuncOp>(fun().getValue()) : FuncOp();
547- }
548-
549542 StringRef getLibraryCallName() {
550543 return library_call().hasValue() ? library_call().getValue() : "";
551544 }
@@ -581,13 +574,6 @@ def GenericOp : GenericOpBase<"generic"> {
581574 - args_in: an I64Attr representing the number of input (readonly) views
582575 - args_out: an I64Attr representing the number of output (readwrite) views
583576 - doc [optional]: a documentation string
584- - fun: a FlatSymbolRefAttr that must resolve to an existing function
585- symbol. To support inplace updates in a generic fashion, the signature
586- of the function must be:
587- ```
588- fun([input views element types], [output views element types])
589- -> ([output views element types])
590- ```
591577 - indexing_maps: a list of AffineMapAttr, one AffineMapAttr per each input
592578 and output view. Such AffineMapAttr specifies the mapping between the
593579 loops and the indexing within each view.
@@ -604,19 +590,13 @@ def GenericOp : GenericOpBase<"generic"> {
604590 Example:
605591 Defining a #matmul_trait attribute in MLIR can be done as follows:
606592 ```mlir
607- func @fma(%a: f32, %b: f32, %c: f32) -> f32 {
608- %d = mulf %a, %b: f32
609- %e = addf %c, %d: f32
610- return %e: f32
611- }
612593 #matmul_accesses = [
613594 (m, n, k) -> (m, k),
614595 (m, n, k) -> (k, n),
615596 (m, n, k) -> (m, n)
616597 ]
617598 #matmul_trait = {
618599 doc = "C(m, n) += A(m, k) * B(k, n)",
619- fun = @fma,
620600 indexing_maps = #matmul_accesses,
621601 library_call = "linalg_matmul",
622602 n_views = [2, 1],
@@ -626,10 +606,14 @@ def GenericOp : GenericOpBase<"generic"> {
626606
627607 And can be reused in multiple places as:
628608 ```mlir
629- linalg.generic #matmul_trait %A, %B, %C [other-attributes] :
630- memref<?x?xf32, stride_specification>,
631- memref<?x?xf32, stride_specification>,
632- memref<?x?xf32, stride_specification>
609+ linalg.generic #matmul_trait %A, %B, %C [other-attributes] {
610+ (%a: f32, %b: f32, %c: f32) :
611+ %d = mulf %a, %b: f32
612+ %e = addf %c, %d: f32
613+ linalg_yield %e : f32
614+ } : memref<?x?xf32, stride_specification>,
615+ memref<?x?xf32, stride_specification>,
616+ memref<?x?xf32, stride_specification>
633617 ```
634618
635619 This may lower to either:
@@ -649,9 +633,9 @@ def GenericOp : GenericOpBase<"generic"> {
649633 %a = load %A[%m, %k] : memref<?x?xf32, stride_specification>
650634 %b = load %B[%k, %n] : memref<?x?xf32, stride_specification>
651635 %c = load %C[%m, %n] : memref<?x?xf32, stride_specification>
652- %d = call @func_of_elements( %a, %b, %c)
653- : ( f32, f32, f32) -> (f32)
654- store %d , %C[%m, %n] : memref<?x?x?xf32, stride_specification>
636+ %d = mulf %a, %b: f32
637+ %e = addf %c, %d: f32
638+          store %e , %C[%m, %n] : memref<?x?xf32, stride_specification>
655639 }
656640 }
657641 }
@@ -662,7 +646,7 @@ def GenericOp : GenericOpBase<"generic"> {
662646 mixing input and output ranked tensor values with input and output memrefs.
663647
664648 ```mlir
665- %C = linalg.generic #trait_attribute %A, %B {other-attributes} :
649+ %C = linalg.generic #trait_attribute %A, %B {other-attributes} {region} :
666650 tensor<?x?xf32>,
667651 memref<?x?xf32, stride_specification>
668652 -> (tensor<?x?xf32>)
@@ -708,13 +692,6 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
708692 - args_in: an I64Attr representing the number of input (readonly) views
709693 - args_out: an I64Attr representing the number of output (readwrite) views
710694 - doc [optional]: a documentation string
711- - fun: a FlatSymbolRefAttr that must resolve to an existing function
712- symbol. To support inplace updates in a generic fashion, the signature
713- of the function must be:
714- ```
715- fun([index types of induction variables], [input views element types],
716- [output views element types]) -> ([output views element types])
717- ```
718695 - indexing_maps: a list of AffineMapAttr, one AffineMapAttr per each input
719696 and output view. Such AffineMapAttr specifies the mapping between the
720697 loops and the indexing within each view.
@@ -732,23 +709,13 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
732709 Defining a #matmul_trait attribute in MLIR can be done as follows:
733710
734711 ```mlir
735- func @fma(%offset_m: index, %offset_n: index, %offset_k: index,
736- %a: f32, %b: f32, %c: f32)
737- -> f32
738- {
739- "some_optional_condition"(%offset_m, %offset_n, %offset_k)
740- %d = mulf %a, %b: f32
741- %e = addf %c, %d: f32
742- return %e: f32
743- }
744712 #matmul_accesses = [
745713 (m, n, k) -> (m, k),
746714 (m, n, k) -> (k, n),
747715 (m, n, k) -> (m, n)
748716 ]
749717 #matmul_trait = {
750718 doc = "C(m, n) += A(m, k) * B(k, n)",
751- fun = @fma,
752719 indexing_maps = #matmul_accesses,
753720 library_call = "linalg_matmul",
754721 n_views = [2, 1],
@@ -759,10 +726,16 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
759726 And can be reused in multiple places as:
760727
761728 ```mlir
762- linalg.indexed_generic #matmul_trait %A, %B, %C [other-attributes] :
763- memref<?x?xf32, stride_specification>,
764- memref<?x?xf32, stride_specification>,
765- memref<?x?xf32, stride_specification>
729+ linalg.indexed_generic #matmul_trait %A, %B, %C [other-attributes] {
730+ (%offset_m: index, %offset_n: index, %offset_k: index,
731+ %a: f32, %b: f32, %c: f32) :
732+ "some_optional_computation"(%offset_m, %offset_n, %offset_k)
733+ %d = mulf %a, %b: f32
734+ %e = addf %c, %d: f32
735+ linalg_yield %e : f32
736+ } : memref<?x?xf32, stride_specification>,
737+ memref<?x?xf32, stride_specification>,
738+ memref<?x?xf32, stride_specification>
766739 ```
767740
768741 This may lower to either:
@@ -784,8 +757,9 @@ def IndexedGenericOp : GenericOpBase<"indexed_generic"> {
784757 %a = load %A[%m, %k] : memref<?x?xf32, stride_specification>
785758 %b = load %B[%k, %n] : memref<?x?xf32, stride_specification>
786759 %c = load %C[%m, %n] : memref<?x?xf32, stride_specification>
787- %d = call @func_of_elements_and_indices(%m, %n, %k, %a, %b, %c)
788- : (index, index, index, f32, f32, f32) -> (f32)
760+ "some_optional_computation"(%m, %n, %k)
761+ %d = mulf %a, %b: f32
762+ %e = addf %c, %d: f32
789-          store %d, %C[%m, %n] : memref<?x?x?xf32, stride_specification>
763+          store %e, %C[%m, %n] : memref<?x?xf32, stride_specification>
790764 }
791765 }
0 commit comments