@@ -26,22 +26,26 @@ limitations under the License.
2626#include " llvm/ADT/SmallVector.h"
2727#include " mlir/IR/MLIRContext.h"
2828#include " xla/codegen/tiling/constraint_expression.h"
29- #include " xla/hlo/analysis/indexing_map .h"
29+ #include " xla/hlo/analysis/interval .h"
3030#include " xla/hlo/ir/hlo_instruction.h"
3131#include " xla/hlo/utils/hlo_traversal.h"
3232#include " xla/service/gpu/model/experimental/symbolic_tile.h"
3333#include " xla/shape.h"
3434
3535namespace xla ::gpu::experimental {
3636
37- // TilingSpace contains information about all parallel and sequential dimensions
38- // and runtime variables in a fusion.
39- // The parallel dimensions correspond to the dimensions of the outputs of the
40- // fusion.
41- // The sequential dimensions correspond to the contraction/reduction dimensions
42- // of the dots/reduces in the fusion.
43- // The runtime variables correspond to the offsets of the dynamic slices in the
44- // fusion.
37+ // TilingSpace holds information about all tiling parameters of a fusion.
38+ //
39+ // It defines symbolic tiles for the fusion roots as symbolic expressions and
40+ // constraints of possible tile "variables":
41+ // * parallel dimensions - output dimensions of the fusion;
42+ // * sequential dimensions - contraction/reduction dimensions of operations in
43+ // the fusion;
44+ // * runtime variables - for example, offsets of the dynamic slices.
45+ //
46+ // This information allows us later to explore the space of all possible tilings
47+ // and assign concrete tilings for every instruction of the fusion with
48+ // SymbolicTilePropagation.
4549class TilingSpace {
4650 public:
4751 TilingSpace () : constraints_(ConstraintExpression::GetAlwaysSatisfied()) {}
@@ -51,29 +55,46 @@ class TilingSpace {
5155
5256 enum class DimensionSemantics { kParallel , kSequential };
5357 struct DimensionInfo {
54- // Unique ID for the dimension.
58+ // Unique ID for the dimension within the tiling space.
5559 ID id;
5660 // Size of the dimension.
5761 int64_t dimension_size;
5862 // Type of the dimension.
5963 DimensionSemantics type;
60- // HLO instruction that defines the dimension.
64+ // HLO instruction that defines (introduces) the dimension. For example,
65+ // the fusion root instruction defines the parallel dimensions, and a
66+ // dot/reduce defines the sequential (contraction/reduction) dimensions.
6167 const HloInstruction* hlo;
62- // Index into the ordered list of dimensions of the HLO instruction.
63- // All dimensions in the HLO instruction are described as
68+ // Index into the ordered list of dimensions of the HLO instruction `hlo`
69+ // that defines the dimension.
70+ // All dimensions in the HLO instruction are ordered as
6471 // [all parallel dims of the output, all reduction/contraction dims].
6572 //
66- // Example:
67- // [output_dims] dot(lhs, rhs, lhs_contracting_dims, rhs_contracting_dims)
68- // The dimensions are ordered as [output_dims, LHS[lhs_contracting_dims]].
73+ // Example: for `[a,b,c] = dot(lhs, rhs, lhs_contracting_dims={d,e}, ...)`,
74+ // the ordered list of dimensions is [a,b,c,d,e].
6975 int64_t dim_position;
7076 };
7177
78+ // Information about a runtime variable.
79+ // For example:
80+ //
81+ // off = s32[] parameter(0)
82+ // ds = dynamic-slice(tensor, off), ...
83+ //
84+ // The `off = s32[] parameter(0)` instruction (`hlo`) defines the runtime
85+ // variable.
86+ // The user's (dynamic-slice) semantics set the `bounds` of possible values.
87+ //
88+ // If the same hlo is used as a runtime variable multiple times, there will be
89+ // multiple entries in `rt_vars_` with different IDs.
90+ //
91+ // An RTVarInfo is accessed by (user_hlo, operand_id); in this case it is
92+ // (dynamic-slice, 1).
7293 struct RTVarInfo {
73- // Unique ID for the runtime variable.
94+ // Unique ID for the runtime variable within the tiling space.
7495 ID id;
75- // Feasible bounds of the runtime variable. The values outside of the bounds
76- // will be clamped.
96+ // Feasible bounds of the runtime variable.
97+ // The values outside of the bounds will be clamped.
7798 Interval bounds;
7899 // HLO instruction that defines the runtime variable.
79100 const HloInstruction* hlo;
@@ -90,9 +111,15 @@ class TilingSpace {
90111 sink.Append (space.ToString ());
91112 }
92113
114+ // Returns the dimension info for the given `hlo` and `dim_position`.
115+ // `dim_position` is the index into the ordered list of dimensions of the HLO
116+ // instruction `hlo` that defines the dimension. The dimension info must
117+ // exist.
93118 const DimensionInfo& GetDimensionInfo (const HloInstruction& hlo,
94119 int64_t dim_position) const ;
95120
121+ // Returns the runtime variable info for the user `hlo` and the `operand_id`
122+ // at which it consumes the runtime variable. The info must exist.
96123 const RTVarInfo& GetRTVarInfo (const HloInstruction& hlo,
97124 int64_t operand_id) const ;
98125
@@ -115,6 +142,7 @@ class TilingSpace {
115142 void ProcessDot (const HloInstruction& hlo);
116143 void ProcessReduce (const HloInstruction& hlo);
117144 void ProcessDynamicSlice (const HloInstruction& hlo);
145+ void ProcessInstruction (const HloInstruction& hlo);
118146
119147 // Maps from (hlo, dim_position) to the dimension info.
120148 absl::flat_hash_map<std::pair<const HloInstruction*, int64_t >,
@@ -130,7 +158,9 @@ class TilingSpace {
130158 // The deque is used to guarantee the pointer stability.
131159 std::deque<RTVarInfo> rt_vars_;
132160
133- // Root instruction of the fusion.
161+ // Symbolic tiles for the fusion roots.
162+ // For tuple roots, there will be one tile per tuple element. Otherwise,
163+ // there will be only one symbolic tile.
134164 llvm::SmallVector<SymbolicTile, 2 > tiled_roots_;
135165
136166 // Constraint expression for the tiling space.
0 commit comments