|
| 1 | +from typing import Union |
| 2 | + |
| 3 | +from mlir import ir |
| 4 | +from mlir.dialects import linalg, arith, tensor, math |
| 5 | + |
| 6 | +from .utils import ( |
| 7 | + affine_map, |
| 8 | + get_bias, |
| 9 | + get_outputs, |
| 10 | + get_weights, |
| 11 | + parallel, |
| 12 | + reduction, |
| 13 | +) |
| 14 | + |
| 15 | + |
def affine_maps_and_iter_types(rank: int):
    """Return the indexing maps and iterator types for a matmul-style generic.

    The layout is selected by ``rank``:

    * ``2`` -- plain 2D weights: iteration space ``(M, N, K)``.
    * ``4`` -- tiled weights without vnni blocking: iteration space
      ``(M, N, K, mb, nb, kb)``.
    * ``5`` -- tiled weights with vnni blocking: iteration space
      ``(M, N, K, k_vnni, mb, nb, kb)``.

    Args:
        rank: Rank selecting one of the layouts above.

    Returns:
        A ``(affine_maps, iterator_types)`` pair. ``affine_maps`` holds three
        maps in the order (A operand, B operand, result); ``iterator_types``
        holds one entry per iteration dimension.

    Raises:
        ValueError: If ``rank`` is not 2, 4 or 5.
    """
    # Validate explicitly instead of `assert False`: asserts are stripped
    # under `python -O`, which would turn an unsupported rank into an
    # UnboundLocalError at the return statement.
    if rank not in (2, 4, 5):
        raise ValueError(f"unsupported rank {rank!r}: expected 2, 4 or 5")

    M, N, K = [ir.AffineDimExpr.get(i) for i in range(3)]

    if rank == 2:  # plain 2D weights
        affine_maps = [
            affine_map(3, [M, K]),
            affine_map(3, [K, N]),
            affine_map(3, [M, N]),
        ]
        iterator_types = [parallel, parallel, reduction]
    elif rank == 4:  # tiled weights, no vnni blocking
        mb, nb, kb = [ir.AffineDimExpr.get(i) for i in range(3, 6)]
        affine_maps = [
            affine_map(6, [M, K, mb, kb]),
            affine_map(6, [N, K, kb, nb]),  # transposed K and N on B
            affine_map(6, [M, N, mb, nb]),
        ]
        # (M, N, K) followed by (mb, nb, kb): same pattern twice.
        iterator_types = [parallel, parallel, reduction] * 2
    else:  # rank == 5: tiled weights with vnni blocking
        # FIXME: due to replicating C++ code, vnni dim is in middle instead of at end.
        k_vnni, mb, nb, kb = [ir.AffineDimExpr.get(i) for i in range(3, 7)]

        affine_maps = [
            affine_map(7, [M, K, mb, kb, k_vnni]),
            # TODO(RM): check if kb and (k_)vnni _not_ being contiguous makes sense.
            affine_map(7, [N, K, kb, nb, k_vnni]),  # transposed K and N on B
            affine_map(7, [M, N, mb, nb]),
        ]
        iterator_types = [
            parallel,  # M
            parallel,  # N
            reduction,  # K
            reduction,  # vnni
            parallel,  # mb
            parallel,  # nb
            reduction,  # kb
        ]

    return affine_maps, iterator_types
| 57 | + |
| 58 | + |
def times_weights(
    inputs: ir.Value,
    weights_or_weights_type: Union[ir.Value, ir.RankedTensorType],
    outputs_or_outputs_type: Union[ir.Value, ir.RankedTensorType],
) -> ir.Value:
    """Emit a ``linalg.generic`` that multiply-accumulates inputs and weights.

    The weights and outputs operands are obtained through ``get_weights`` and
    ``get_outputs``. The indexing maps and iterator types come from
    ``affine_maps_and_iter_types`` keyed on the weights' rank.

    When the weights have rank 5 (tiled weights with vnni blocking), the
    innermost dimension of ``inputs`` is first split into
    ``(innermost // vnni_block, vnni_block)`` with ``tensor.expand_shape``,
    where ``vnni_block`` is the size of the weights' dimension 4.

    Returns:
        The result of the emitted generic op.
    """
    weights: ir.Value = get_weights(weights_or_weights_type)
    outputs: ir.Value = get_outputs(outputs_or_outputs_type)
    weights_rank = weights.type.rank

    if weights_rank == 5:  # tiled weights with vnni blocking
        vnni_block = weights.type.get_dim_size(4)
        input_shape = inputs.type.shape
        innermost = input_shape[-1]
        assert innermost % vnni_block == 0

        # Split the innermost dim into (outer, vnni_block).
        expanded_shape = input_shape[:-1] + [innermost // vnni_block, vnni_block]
        expanded_type = ir.RankedTensorType.get(
            expanded_shape, inputs.type.element_type
        )
        # NOTE(review): the reassociation is hard-coded for rank-4 inputs.
        inputs = tensor.expand_shape(
            expanded_type,
            inputs,
            reassociation=[[0], [1], [2], [3, 4]],
            output_shape=[],
            static_output_shape=expanded_shape,
        )

    indexing_maps, iter_types = affine_maps_and_iter_types(weights_rank)

    @linalg.generic([inputs], [outputs], indexing_maps, iter_types) if False else linalg.generic(
        [inputs, weights], [outputs], indexing_maps, iter_types
    )
    def accumulated(lhs, rhs, acc):
        product = arith.MulFOp(lhs, rhs)
        return arith.AddFOp(product.result, acc)

    return accumulated
| 92 | + |
| 93 | + |
def add_bias(inputs: ir.Value, bias_or_bias_type: Union[ir.Value, ir.Type] = None):
    """Emit a ``linalg.generic`` that adds a bias onto ``inputs``.

    The bias operand is obtained through ``get_bias``. ``inputs`` is used as
    the generic's output operand, and the body adds the bias element to the
    current output element.

    Supported ranks of ``inputs``:

    * 2 -- the bias is indexed by ``N`` only.
    * 4 -- the bias is indexed by ``(N, nb)``.

    Any other rank raises ``KeyError`` from the table lookup.

    Returns:
        The result of the emitted generic op.
    """
    bias: ir.Value = get_bias(bias_or_bias_type)

    m, n, m_blk, n_blk = (ir.AffineDimExpr.get(i) for i in range(4))

    # Per-rank (indexing maps, iterator types); maps are ordered (bias, output).
    plain = ([affine_map(2, [n]), affine_map(2, [m, n])], [parallel] * 2)
    tiled = (
        [affine_map(4, [n, n_blk]), affine_map(4, [m, n, m_blk, n_blk])],
        [parallel] * 4,
    )
    by_rank = {2: plain, 4: tiled}
    indexing_maps, iter_types = by_rank[inputs.type.rank]

    @linalg.generic([bias], [inputs], indexing_maps, iter_types)
    def with_bias(bias_elem, out_elem):
        return arith.AddFOp(bias_elem, out_elem)

    return with_bias
| 108 | + |
| 109 | + |
def relu(inputs: ir.Value):
    """Emit a ``linalg.generic`` that computes ``max(x, 0)`` over ``inputs``.

    ``inputs`` is the generic's only operand (passed as the output), and every
    iteration dimension is parallel. Ranks 2 and 4 are supported; any other
    rank raises ``KeyError`` from the table lookup.

    Returns:
        The result of the emitted generic op.
    """
    # Emitted before the generic so the body can capture it.
    zero = arith.constant(inputs.type.element_type, 0.0)

    m, n, m_blk, n_blk = (ir.AffineDimExpr.get(i) for i in range(4))

    # Per-rank (indexing maps, iterator types) for the single operand.
    plain = ([affine_map(2, [m, n])], [parallel, parallel])
    tiled = ([affine_map(4, [m, n, m_blk, n_blk])], [parallel, parallel] * 2)
    by_rank = {2: plain, 4: tiled}
    indexing_maps, iter_types = by_rank[inputs.type.rank]

    @linalg.generic([], [inputs], indexing_maps, iter_types)
    def rectified(elem):
        return arith.MaximumFOp(elem, zero)

    return rectified
| 124 | + |
| 125 | + |
def softmax(
    inputs: ir.Value, softmax_buf_or_softmax_buf_type: Union[ir.Value, ir.Type]
) -> ir.Value:
    """Emit a softmax over ``inputs`` as a chain of ``linalg.generic`` ops.

    The emitted sequence is:

    1. ``exp`` of every element into a fresh tensor.
    2. Sum of the exponentials into a zero-filled ``(shape[0], 1)`` tensor.
    3. Broadcast of that sum back to the full shape.
    4. Elementwise division of (1) by (3), written to the softmax buffer
       obtained through ``get_outputs``.

    NOTE(review): the reduction map ``(d0, 0)``, the ``(shape[0], 1)``
    accumulator and the ``[parallel, reduction]`` iterator list look like they
    assume rank-2 inputs -- confirm before using with other ranks.
    NOTE(review): the exponentials are taken without first subtracting a
    per-row maximum.

    Returns:
        The result of the final (division) generic op.
    """
    softmax_buf = get_outputs(softmax_buf_or_softmax_buf_type)

    shape, elem_type = inputs.type.shape, inputs.type.element_type
    exp_out_uninit = tensor.EmptyOp(shape, elem_type)

    dims = [ir.AffineDimExpr.get(i) for i in range(inputs.type.rank)]
    par_affine_map = affine_map(inputs.type.rank, dims)
    par_iter_types = [parallel] * inputs.type.rank
    red_affine_map = affine_map(
        inputs.type.rank, [dims[0], ir.AffineConstantExpr.get(0)]
    )
    red_iter_types = [parallel, reduction] * (inputs.type.rank // 2)

    # One indexing map per operand (one input + one output). The count must
    # follow the operand count, not the tensor rank.
    @linalg.generic(
        [inputs], [exp_out_uninit.result], [par_affine_map] * 2, par_iter_types
    )
    def exped(value, _output):
        return math.exp(value)

    zero = arith.constant(elem_type, 0.0)
    reduction_out_uninit = tensor.EmptyOp((shape[0], 1), elem_type)
    reduction_out = linalg.fill(zero, outs=reduction_out_uninit)

    @linalg.generic(
        [exped], [reduction_out], [par_affine_map, red_affine_map], red_iter_types
    )
    def summed_exped(exped_input, redex):
        return arith.AddFOp(exped_input, redex)

    bcast_out_uninit = tensor.EmptyOp(shape, elem_type)

    @linalg.generic(
        [summed_exped],
        [bcast_out_uninit.result],
        [red_affine_map, par_affine_map],
        par_iter_types,
    )
    def bcasted_summed_exped(row_sum, _output):
        return row_sum

    @linalg.generic(
        [exped, bcasted_summed_exped],
        [softmax_buf],
        [par_affine_map] * 3,
        par_iter_types,
    )
    def dived_bcasted_summed_exped(exped_input, normalizing_term, _output):
        return arith.DivFOp(exped_input, normalizing_term)

    return dived_bcasted_summed_exped
0 commit comments