Tensor Comprehension Notation
-----------------------------
TC borrows three ideas from Einstein notation that make expressions concise:

1. Loop index variables are defined implicitly by using them in an expression, and their range is aggressively inferred based on what they index.
2. Indices that appear on the right side of an expression but not on the left are assumed to be reduction dimensions.
3. The evaluation order of points in the iteration space does not affect the output.

Let's start with a simple example: a matrix-vector product:

    def mv(float(R,C) A, float(C) B) -> (o) {
      o(i) +=! A(i,j) * B(j)
    }

`A` and `B` are input tensors. `o` is an output tensor.
The statement `o(i) +=! A(i,j) * B(j)` introduces two index variables `i` and `j`.
Their range is inferred by their use indexing `A` and `B`: `i = [0,R)`, `j = [0,C)`.
Because `j` only appears on the right side,
stores into `o` will reduce over `j` with the reduction specified for the loop.
Reductions can occur across multiple variables, but they all share the same kind of associative reduction (e.g. `+=`)
to maintain invariant (3). `mv` computes the same thing as this C++ loop:

    for(int i = 0; i < R; i++) {
      o(i) = 0.0f;
      for(int j = 0; j < C; j++) {
        o(i) += A(i,j) * B(j);
      }
    }

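The same loop nest can be checked numerically. Below is a NumPy sketch of `mv`'s semantics (the sizes and the variable names `A`, `B`, `o` are illustrative choices, not part of TC itself):

```python
import numpy as np

# Illustrative sizes; R and C are the dimensions from the TC signature.
R, C = 4, 3
rng = np.random.default_rng(0)
A = rng.standard_normal((R, C))
B = rng.standard_normal(C)

# Same loop nest as the C++ version: i ranges over [0,R), j over [0,C),
# and j (absent from the left-hand side) is reduced with +=.
o = np.zeros(R)
for i in range(R):
    for j in range(C):
        o[i] += A[i, j] * B[j]

# The reduction over j is exactly a matrix-vector product.
assert np.allclose(o, A @ B)
```

Because the reduction is associative, the `j` iterations can run in any order (invariant 3) without changing the result.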
Examples of TC
--------------

We provide a few basic examples.

**Simple matrix-vector:**

    def mv(float(R,C) A, float(C) B) -> (o) {
      o(i) += A(i,j) * B(j)
    }

**Simple matrix-multiply:**

Note the layout for B is transposed and matches the
traditional layout of the weight matrix in a linear layer:

    def mm(float(X,Y) A, float(Z,Y) B) -> (R) {
      R(i,j) += A(i,k) * B(j,k)
    }

**Simple 2-D convolution (no stride, no padding):**

    def conv(float(B,IP,H,W) input, float(OP,IP,KH,KW) weight) -> (output) {
      output(b, op, h, w) += input(b, ip, h + kh, w + kw) * weight(op, ip, kh, kw)
    }

**Simple 2D max pooling:**

Note the similarity with a convolution with a
"select"-style kernel:

    def maxpool2x2(float(B,C,H,W) input) -> (output) {
      output(b,c,i,j) max= input(b,c,2*i + kh, 2*j + kw)
        where kh = [0, 2[, kw = [0, 2[
    }
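`max=` reduces with `max` instead of `+`, and the `where` clause pins `kh` and `kw` to `[0, 2)` explicitly, so each output element covers a 2x2 window. A NumPy sketch of these semantics, with illustrative sizes:

```python
import numpy as np

Bn, C_, H, W = 1, 1, 4, 4
rng = np.random.default_rng(3)
inp = rng.standard_normal((Bn, C_, H, W))

# i and j are inferred to range over [0, H/2) and [0, W/2); within each
# 2x2 window, kh and kw are reduced with max instead of +.
out = np.zeros((Bn, C_, H // 2, W // 2))
for b in range(Bn):
    for c in range(C_):
        for i in range(H // 2):
            for j in range(W // 2):
                out[b, c, i, j] = inp[b, c, 2*i:2*i + 2, 2*j:2*j + 2].max()

assert out.shape == (1, 1, 2, 2)
```

Any associative reduction can stand in for `+=` this way, which is what keeps invariant (3) intact.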