77// ===----------------------------------------------------------------------===//
88
99#include " mlir/Dialect/Tensor/IR/Tensor.h"
10+ #include " mlir/IR/DialectImplementation.h"
1011#include " mlir/Transforms/InliningUtils.h"
12+ #include " llvm/ADT/TypeSwitch.h"
1113
1214using namespace mlir ;
1315using namespace mlir ::tensor;
1416
17+ // ===----------------------------------------------------------------------===//
18+ // TableGen'd Attributes Methods
19+ // ===----------------------------------------------------------------------===//
20+
21+ #define GET_ATTRDEF_CLASSES
22+ #include " mlir/Dialect/Tensor/IR/TensorAttrDefs.cpp.inc"
23+
24+ // Dictionary keys.
25+ static constexpr StringRef getSparseDimLevelTypeAttrName () {
26+ return " sparseDimLevelType" ;
27+ }
28+ static constexpr StringRef getSparseDimOrderingAttrName () {
29+ return " sparseDimOrdering" ;
30+ }
31+ static constexpr StringRef getSparsePointerBitWidthAttrName () {
32+ return " sparsePointerBitWidth" ;
33+ }
34+ static constexpr StringRef getSparseIndexBitWidthAttrName () {
35+ return " sparseIndexBitWidth" ;
36+ }
37+
38+ // Dictionary values.
39+ static constexpr StringRef getDenseDimLevelTypeVal () { return " dense" ; }
40+ static constexpr StringRef getCompressedDimLevelTypeVal () {
41+ return " compressed" ;
42+ }
43+ static constexpr StringRef getSingletonDimLevelTypeVal () { return " singleton" ; }
44+
45+ Attribute SparseTensorEncodingAttr::parse (MLIRContext *context,
46+ DialectAsmParser &parser, Type type) {
47+ if (failed (parser.parseLess ()))
48+ return {};
49+ DictionaryAttr dict;
50+ if (failed (parser.parseAttribute (dict)))
51+ return {};
52+ if (failed (parser.parseGreater ()))
53+ return {};
54+ return SparseTensorEncodingAttr::get (context, dict);
55+ }
56+
57+ void SparseTensorEncodingAttr::print (DialectAsmPrinter &printer) const {
58+ printer << " sparse<" << getDict () << " >" ;
59+ }
60+
61+ LogicalResult SparseTensorEncodingAttr::verifyEncoding (
62+ llvm::ArrayRef<int64_t > shape, Type elementType,
63+ llvm::function_ref<mlir::InFlightDiagnostic()> emitError) const {
64+ unsigned size = shape.size ();
65+ for (const NamedAttribute &attr : getDict ()) {
66+ if (attr.first == getSparseDimLevelTypeAttrName ()) {
67+ // Dimension level type verification.
68+ auto arrayAttr = attr.second .dyn_cast <ArrayAttr>();
69+ if (!arrayAttr || size != static_cast <int64_t >(arrayAttr.size ()))
70+ return emitError () << " expected an array of size " << size
71+ << " for dimension level types" ;
72+ for (unsigned i = 0 ; i < size; i++) {
73+ auto strAttr = arrayAttr[i].dyn_cast <StringAttr>();
74+ if (!strAttr)
75+ return emitError ()
76+ << " expected string value in dimension level types" ;
77+ auto strVal = strAttr.getValue ();
78+ if (strVal != getDenseDimLevelTypeVal () &&
79+ strVal != getCompressedDimLevelTypeVal () &&
80+ strVal != getSingletonDimLevelTypeVal ())
81+ return emitError () << " unexpected dimension level type: " << strAttr;
82+ }
83+ } else if (attr.first == getSparseDimOrderingAttrName ()) {
84+ // Dimension order verification.
85+ auto affineAttr = attr.second .dyn_cast <AffineMapAttr>();
86+ if (!affineAttr)
87+ return emitError () << " expected an affine map for dimension ordering" ;
88+ AffineMap map = affineAttr.getValue ();
89+ if (size != map.getNumResults () || !map.isPermutation ())
90+ return emitError () << " expected a permutation affine map of size "
91+ << size << " for dimension ordering" ;
92+ } else if (attr.first == getSparsePointerBitWidthAttrName () ||
93+ attr.first == getSparseIndexBitWidthAttrName ()) {
94+ // Pointer or index bitwidth verification.
95+ auto intAttr = attr.second .dyn_cast <IntegerAttr>();
96+ if (!intAttr)
97+ return emitError () << " expected an integral bitwidth" ;
98+ switch (intAttr.getInt ()) {
99+ case 0 :
100+ case 8 :
101+ case 16 :
102+ case 32 :
103+ case 64 :
104+ continue ;
105+ default :
106+ return emitError () << " unexpected bitwidth: " << intAttr.getInt ();
107+ }
108+ } else {
109+ return emitError () << " unexpected key: " << attr.first .str ();
110+ }
111+ }
112+ return success ();
113+ }
114+
115+ SparseTensorEncodingAttr::DimLevelType
116+ SparseTensorEncodingAttr::getDimLevelType (unsigned dim) const {
117+ if (auto value = getDict ().get (getSparseDimLevelTypeAttrName ())) {
118+ auto strVal =
119+ value.dyn_cast <ArrayAttr>()[dim].cast <StringAttr>().getValue ();
120+ if (strVal == getCompressedDimLevelTypeVal ())
121+ return DimLevelType::Compressed;
122+ if (strVal == getSingletonDimLevelTypeVal ())
123+ return DimLevelType::Singleton;
124+ }
125+ return DimLevelType::Dense;
126+ }
127+
128+ AffineMap SparseTensorEncodingAttr::getDimOrdering () const {
129+ if (auto value = getDict ().get (getSparseDimOrderingAttrName ()))
130+ return value.cast <AffineMapAttr>().getValue ();
131+ return {};
132+ }
133+
134+ unsigned SparseTensorEncodingAttr::getPointerBitWidth () const {
135+ if (auto value = getDict ().get (getSparsePointerBitWidthAttrName ()))
136+ return value.cast <IntegerAttr>().getInt ();
137+ return 0 ;
138+ }
139+
140+ unsigned SparseTensorEncodingAttr::getIndexBitWidth () const {
141+ if (auto value = getDict ().get (getSparseIndexBitWidthAttrName ()))
142+ return value.cast <IntegerAttr>().getInt ();
143+ return 0 ;
144+ }
145+
15146// ===----------------------------------------------------------------------===//
16147// TensorDialect Dialect Interfaces
17148// ===----------------------------------------------------------------------===//
@@ -30,10 +161,38 @@ struct TensorInlinerInterface : public DialectInlinerInterface {
30161};
31162} // end anonymous namespace
32163
164+ // ===----------------------------------------------------------------------===//
165+ // TensorDialect Methods
166+ // ===----------------------------------------------------------------------===//
167+
33168void TensorDialect::initialize () {
169+ addAttributes<
170+ #define GET_ATTRDEF_LIST
171+ #include " mlir/Dialect/Tensor/IR/TensorAttrDefs.cpp.inc"
172+ >();
34173 addOperations<
35174#define GET_OP_LIST
36175#include " mlir/Dialect/Tensor/IR/TensorOps.cpp.inc"
37176 >();
38177 addInterfaces<TensorInlinerInterface>();
39178}
179+
180+ Attribute TensorDialect::parseAttribute (DialectAsmParser &parser,
181+ Type type) const {
182+ StringRef attrTag;
183+ if (failed (parser.parseKeyword (&attrTag)))
184+ return Attribute ();
185+ Attribute attr;
186+ auto parseResult =
187+ generatedAttributeParser (getContext (), parser, attrTag, type, attr);
188+ if (parseResult.hasValue ())
189+ return attr;
190+ parser.emitError (parser.getNameLoc (), " unknown tensor attribute" );
191+ return Attribute ();
192+ }
193+
194+ void TensorDialect::printAttribute (::mlir::Attribute attr,
195+ ::mlir::DialectAsmPrinter &printer) const {
196+ if (succeeded (generatedAttributePrinter (attr, printer)))
197+ return ;
198+ }
0 commit comments