
Commit a401619

Merge pull request #215 from facebookresearch/pr/tc-coding-guide
TC coding guide
2 parents b74440d + 37bf9bc commit a401619

34 files changed: 698 additions and 571 deletions

CodingConventions.md

Lines changed: 7 additions & 0 deletions
@@ -588,3 +588,10 @@ changes) while working on a feature and even in "WIP" pull requests,
 as long as the pieces are recombined (e.g., through an interactive rebase)
 into logical units when the feature is ready for merging.
 Force-pushing in PR branches is fine.
+
+Coding Conventions for writing Tensor Comprehensions
+====================================================
+
+Please see the following documentation
+[entry](https://facebookresearch.github.io/TensorComprehensions/coding_conventions.html)
+on how to write Tensor Comprehensions in a standard legible fashion.

benchmarks/MLP_model.cc

Lines changed: 43 additions & 43 deletions
@@ -64,23 +64,23 @@ DEFINE_uint32(Q, 2, "W4_h");
 //   float(E1, D) LUT1, int32(B, L1) I1,
 //   float(E2, D) LUT2, int32(B, L2) I2) -> (O1, O2)
 // {
-//   O1(i, j) +=! LUT1(I1(i, k), j)
-//   O2(i, j) +=! LUT2(I2(i, k), j)
+//   O1(b, d) +=! LUT1(I1(b, r_l1), d)
+//   O2(b, d) +=! LUT2(I2(b, r_l2), d)
 // }
 // def _3FCRELU(
 //   float(B,M) I, float(O,N) W2, float(O) B2,
 //   float(P,O) W3, float(P) B3, float(Q,P) W4,
 //   float(Q) B4) -> (O1, O2, O3, O4)
 // {
-//   O2(b, o) = B2(o)
-//   O2(b, o) += O1(b, n) * W2(o, n)
-//   O2(b, o) = fmax(O2(b, o), 0)
-//   O3(b, p) = B3(p)
-//   O3(b, p) += O2(b, o) * W3(p, o)
-//   O3(b, p) = fmax(O3(b, p), 0)
-//   O4(b, q) = B4(q)
-//   O4(b, q) += O3(b, p) * W4(q, p)
-//   O4(b, q) = fmax(O4(b, q), 0)
+//   O2(b, o) = B2(o)
+//   O2(b, o) += O1(b, n) * W2(o, n)
+//   O2(b, o) = fmax(O2(b, o), 0)
+//   O3(b, p) = B3(p)
+//   O3(b, p) += O2(b, o) * W3(p, o)
+//   O3(b, p) = fmax(O3(b, p), 0)
+//   O4(b, q) = B4(q)
+//   O4(b, q) += O3(b, p) * W4(q, p)
+//   O4(b, q) = fmax(O4(b, q), 0)
 // }
 // def prod_model(float(E1, D) LUT1, int32(B, L1) I1,
 //   float(E2, D) LUT2, int32(B, L2) I2,
@@ -91,15 +91,15 @@ DEFINE_uint32(Q, 2, "W4_h");
 //   float(Q,P) W4, float(Q) B4)
 //   -> (C1, C2, C3, I, O1, O2, O3, O4)
 // {
-//   (C1, C2) = _2LUT(LUT1, I1, LUT2, I2)
-//   C3(b, wy) += I3(b, wxx) * W(wy, wxx)
-//   I(b, m) = Concat(C1, C2, C3) // not in TC atm
-//   O1(b, n) = B1(n)
-//   O1(b, n) += I(b, m) * W1(n, m)
-//   O1(b, n) = fmax(O1(b, n), 0)
+//   (C1, C2) = _2LUT(LUT1, I1, LUT2, I2)
+//   C3(b, wy) +=! I3(b, r_wx) * W(wy, r_wx)
+//   I(b, m) = Concat(C1, C2, C3) // not in TC atm
+//   O1(b, n) = B1(n)
+//   O1(b, n) +=! I(b, m) * W1(n, m)
+//   O1(b, n) = fmax(O1(b, n), 0)
 //   (O2, O3, O4) =
-//     _3FCRELU(I, W1, B1, W2, B2, W3, B3, W4, B4)
-//   # O4 goes out to binary classifier, omitted here
+//     _3FCRELU(I, W1, B1, W2, B2, W3, B3, W4, B4)
+//   # O4 goes out to binary classifier, omitted here
 // }
 
 class ProductionModel : public Benchmark {
@@ -191,9 +191,9 @@ void ProductionModel::run1LUT(
 
   std::vector<at::Tensor> inputs = {LUT1, IDX1};
   std::string tc = R"(
-def _1LUT(float(E1, D) LUT1, int32(B, L1) I1) -> (O1) {
-    O1(i, j) +=! LUT1(I1(i, k), j)
-}
+def _1LUT(float(E1, D) LUT1, int32(B, L1) I1) -> (O1) {
+    O1(b, d) +=! LUT1(I1(b, r_l1), d)
+}
 )";
 
   std::string suffix = std::string("_B_") + std::to_string(FLAGS_B) +
@@ -294,10 +294,10 @@ void ProductionModel::run2LUT(
 
   std::vector<at::Tensor> inputs = {LUT1, IDX1, LUT2, IDX2};
   std::string tc = R"(
-def _2LUT(float(E1, D) LUT1, int32(B, L1) I1, float(E2, D) LUT2, int32(B, L2) I2) -> (O1, O2) {
-    O1(i, j) +=! LUT1(I1(i, k), j)
-    O2(i, j) +=! LUT2(I2(i, k), j)
-}
+def _2LUT(float(E1, D) LUT1, int32(B, L1) I1, float(E2, D) LUT2, int32(B, L2) I2) -> (O1, O2) {
+    O1(b, d) +=! LUT1(I1(b, r_l1), d)
+    O2(b, d) +=! LUT2(I2(b, r_l2), d)
+}
 )";
 
   std::string suffix = std::string("_B_") + std::to_string(FLAGS_B) +
@@ -353,9 +353,9 @@ void ProductionModel::runC3(
 
   std::vector<at::Tensor> inputs = {I, W};
  std::string tc = R"TC(
-def _C3(float(B,WX) I, float(WY, WX) W) -> (C3) {
-    C3(b, wy) +=! I(b, wxx) * W(wy, wxx)
-}
+def _C3(float(B,WX) I, float(WY, WX) W) -> (C3) {
+    C3(b, wy) +=! I(b, r_wx) * W(wy, r_wx)
+}
 )TC";
 
   std::string suffix = std::string("_B_") + std::to_string(FLAGS_B) +
@@ -408,11 +408,11 @@ void ProductionModel::runMLP1(
 
   std::vector<at::Tensor> inputs = {I, W1, B1};
   std::string tc = R"TC(
-def mlp1(float(B,M) I, float(M, N) W1, float(N) B1) -> (O1) {
-    O1(b, n) +=! I(b, mm) * W1(mm, n)
-    O1(b, n) = O1(b, n) + B1(n)
-    O1(b, n) = fmax(O1(b, n), 0)
-}
+def mlp1(float(B,M) I, float(M, N) W1, float(N) B1) -> (O1) {
+    O1(b, n) +=! I(b, r_m) * W1(r_m, n)
+    O1(b, n) = O1(b, n) + B1(n)
+    O1(b, n) = fmax(O1(b, n), 0)
+}
 )TC";
 
   std::string suffix = std::string("_B_") + std::to_string(FLAGS_B) +
@@ -474,17 +474,17 @@ void ProductionModel::runMLP3(
 
   std::vector<at::Tensor> inputs = {I, W2, B2, W3, B3, W4, B4};
   std::string tc = R"TC(
-def mlp3(float(B,N) I, float(O,N) W2, float(O) B2, float(P,O) W3, float(P) B3, float(Q,P) W4, float(Q) B4) -> (O2, O3, O4) {
-    O2(b, o) +=! I(b, n) * W2(o, n)
-    O2(b, o) = O2(b, o) + B2(o)
-    O2(b, o) = fmax(O2(b, o), 0)
+def mlp3(float(B,N) I, float(O,N) W2, float(O) B2, float(P,O) W3, float(P) B3, float(Q,P) W4, float(Q) B4) -> (O2, O3, O4) {
+    O2(b, o) +=! I(b, n) * W2(o, n)
+    O2(b, o) = O2(b, o) + B2(o)
+    O2(b, o) = fmax(O2(b, o), 0)
     O3(b, p) +=! O2(b, o) * W3(p, o)
-    O3(b, p) = O3(b, p) + B3(p)
-    O3(b, p) = fmax(O3(b, p), 0)
+    O3(b, p) = O3(b, p) + B3(p)
+    O3(b, p) = fmax(O3(b, p), 0)
     O4(b, q) +=! O3(b, p) * W4(q, p)
-    O4(b, q) = O4(b, q) + B4(q)
-    O4(b, q) = fmax(O4(b, q), 0)
-}
+    O4(b, q) = O4(b, q) + B4(q)
+    O4(b, q) = fmax(O4(b, q), 0)
+}
 )TC";
 
   std::string suffix = std::string("_B_") + std::to_string(FLAGS_B) +
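A note on reading the renamed lookup-table kernels: `O1(b, d) +=! LUT1(I1(b, r_l1), d)` gathers the rows of `LUT1` selected by `I1` and sums them per batch element, with `+=!` zero-initializing the output. A plain C++ loop nest with the same semantics might look like the sketch below (the function name and flat row-major layouts are assumptions for illustration, not part of the benchmark):

    // Sketch of the semantics of:  O1(b, d) +=! LUT1(I1(b, r_l1), d)
    // LUT1 is E1 x D, I1 is B x L1, O1 is B x D (flat, row-major arrays assumed).
    void lut_reference(const float* LUT1, const int* I1, float* O1,
                       int B, int L1, int D) {
      for (int b = 0; b < B; ++b) {
        for (int d = 0; d < D; ++d) {
          float acc = 0.0f;  // "+=!" zero-initializes the accumulator
          for (int r_l1 = 0; r_l1 < L1; ++r_l1) {
            acc += LUT1[I1[b * L1 + r_l1] * D + d];  // gather row, accumulate
          }
          O1[b * D + d] = acc;
        }
      }
    }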

benchmarks/batchmatmul.cc

Lines changed: 3 additions & 3 deletions
@@ -76,9 +76,9 @@ void BatchMatMul::runBatchMatMul(
 
   std::vector<at::Tensor> inputs = {X, Y};
   std::string tc = R"(
-def batch_matmul(float(B, N, M) X, float(B, M, K) Y) -> (Z) {
-    Z(b, n, k) +=! X(b, n, mm) * Y(b, mm, k)
-}
+def batch_matmul(float(B, N, M) X, float(B, M, K) Y) -> (Z) {
+    Z(b, n, k) +=! X(b, n, r_m) * Y(b, r_m, k)
+}
 )";
 
   std::string suffix = std::string("_B_") + std::to_string(FLAGS_B) +
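The rename of the reduction index (`mm` to `r_m`) does not change what `batch_matmul` computes. For reference, a C++ sketch of the same contraction (function name and flat row-major layouts are assumed for illustration):

    // Sketch of:  Z(b, n, k) +=! X(b, n, r_m) * Y(b, r_m, k)
    // X is B x N x M, Y is B x M x K, Z is B x N x K.
    void batch_matmul_reference(const float* X, const float* Y, float* Z,
                                int B, int N, int M, int K) {
      for (int b = 0; b < B; ++b)
        for (int n = 0; n < N; ++n)
          for (int k = 0; k < K; ++k) {
            float acc = 0.0f;  // "+=!" resets the accumulator
            for (int r_m = 0; r_m < M; ++r_m)
              acc += X[(b * N + n) * M + r_m] * Y[(b * M + r_m) * K + k];
            Z[(b * N + n) * K + k] = acc;
          }
    }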

benchmarks/group_convolution.cc

Lines changed: 6 additions & 6 deletions
@@ -122,13 +122,13 @@ void GroupConvolution::runGroupConvolution(
       .resize_({G, F});
   std::vector<at::Tensor> inputs = {tI, tW, tB};
   std::string tc = R"(
-def group_convolution(float(N,G,C,H,W) I, float(G,F,C,KH,KW) W1, float(G,F) B)
--> (O)
-{
+def group_convolution(float(N,G,C,H,W) I, float(G,F,C,KH,KW) W1, float(G,F) B)
+-> (O)
+{
     O(n, g, f, h, w) +=!
-        I(n, g, c, h + kh, w + kw) * W1(g, f, c, kh, kw)
-    O(n, g, f, h, w) = O(n, g, f, h, w) + B(g, f)
-}
+        I(n, g, r_c, h + r_kh, w + r_kw) * W1(g, f, r_c, r_kh, r_kw)
+    O(n, g, f, h, w) = O(n, g, f, h, w) + B(g, f)
+}
 )";
 
   std::string suffix = std::string("_N_") + std::to_string(FLAGS_N) +
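In the grouped convolution all three `r_`-prefixed indices are reductions, and with no padding range inference yields output spatial extents of `H - KH + 1` by `W - KW + 1`. A C++ sketch of the computed values (the names `group_conv_reference` and `Bias`, plus flat row-major layouts, are assumptions made only for illustration):

    // Sketch of the group_convolution TC above (no stride, no padding).
    // I: N x G x C x H x W,  W1: G x F x C x KH x KW,  Bias: G x F,
    // O: N x G x F x (H-KH+1) x (W-KW+1) -- output extents from range inference.
    void group_conv_reference(const float* I, const float* W1, const float* Bias,
                              float* O, int N, int G, int C, int H, int W,
                              int F, int KH, int KW) {
      const int HO = H - KH + 1, WO = W - KW + 1;
      for (int n = 0; n < N; ++n)
        for (int g = 0; g < G; ++g)
          for (int f = 0; f < F; ++f)
            for (int h = 0; h < HO; ++h)
              for (int w = 0; w < WO; ++w) {
                float acc = 0.0f;
                for (int r_c = 0; r_c < C; ++r_c)
                  for (int r_kh = 0; r_kh < KH; ++r_kh)
                    for (int r_kw = 0; r_kw < KW; ++r_kw)
                      acc += I[(((n * G + g) * C + r_c) * H + (h + r_kh)) * W + (w + r_kw)]
                           * W1[(((g * F + f) * C + r_c) * KH + r_kh) * KW + r_kw];
                // second TC statement adds the per-group, per-filter bias
                O[(((n * G + g) * F + f) * HO + h) * WO + w] = acc + Bias[g * F + f];
              }
    }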

benchmarks/tmm.cc

Lines changed: 3 additions & 3 deletions
@@ -73,9 +73,9 @@ void TransposedMatMul::runTransposedMatMul(
 
   std::vector<at::Tensor> inputs = {A, B};
   std::string tc = R"TC(
-def tmm(float(M,K) A, float(N,K) B) -> (C) {
-    C(m, n) +=! A(m, kk) * B(n, kk)
-}
+def tmm(float(M,K) A, float(N,K) B) -> (C) {
+    C(m, n) +=! A(m, r_k) * B(n, r_k)
+}
 )TC";
 
   std::string suffix = std::string("_M_") + std::to_string(FLAGS_M) +
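`tmm` keeps the second operand stored transposed, so both accesses run along the `K` dimension, now spelled `r_k`. Roughly, in C++ terms (a sketch with an assumed function name and row-major flat arrays):

    // Sketch of:  C(m, n) +=! A(m, r_k) * B(n, r_k)
    // A is M x K, B is N x K (stored transposed), C is M x N.
    void tmm_reference(const float* A, const float* B, float* C,
                       int M, int N, int K) {
      for (int m = 0; m < M; ++m)
        for (int n = 0; n < N; ++n) {
          float acc = 0.0f;
          for (int r_k = 0; r_k < K; ++r_k)
            acc += A[m * K + r_k] * B[n * K + r_k];  // both rows read contiguously
          C[m * N + n] = acc;
        }
    }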

docs/doxygen/index.md

Lines changed: 14 additions & 15 deletions
@@ -13,30 +13,30 @@ with a few basic functionalities.
 
 Tensor Comprehension Notation
 -----------------------------
-TC borrow three ideas from Einstein notation that make expressions concise:
+TC borrows three ideas from Einstein notation that make expressions concise:
 
 1. Loop index variables are defined implicitly by using them in an expression and their range is aggressively inferred based on what they index.
 2. Indices that appear on the right of an expression but not on the left are assumed to be reduction dimensions.
 3. The evaluation order of points in the iteration space does not affect the output.
 
 Let's start with a simple example, a matrix-vector product:
 
-    def mv(float(R,C) A, float(C) B) -> (o) {
-        o(i) +=! A(i,j) * B(j)
+    def mv(float(R,C) A, float(C) x) -> (o) {
+        o(r) +=! A(r,r_c) * x(r_c)
     }
 
 `A` and `x` are input tensors. `o` is an output tensor.
-The statement `o(i) += A(i,j) * b(j)` introduces two index variables `i` and `j`.
-Their range is inferred by their use indexing `A` and `B`. `i = [0,R)`, `j = [0,C)`.
-Because `j` only appears on the right side,
-stores into `o` will reduce over `j` with the reduction specified for the loop.
+The statement `o(r) +=! A(r,r_c) * x(r_c)` introduces two index variables `r` and `r_c`.
+Their range is inferred by their use indexing `A` and `x`. `r = [0,R)`, `r_c = [0,C)`.
+Because `r_c` only appears on the righthand side,
+stores into `o` will reduce over `r_c` with the reduction specified for the loop.
 Reductions can occur across multiple variables, but they all share the same kind of associative reduction (e.g. +=)
 to maintain invariant (3). `mv` computes the same thing as this C++ loop:
 
     for(int i = 0; i < R; i++) {
       o(i) = 0.0f;
      for(int j = 0; j < C; j++) {
-       o(i) += A(i,j) * B(j);
+       o(i) += A(i,j) * x(j);
      }
    }
 
@@ -50,7 +50,7 @@ We provide a few basic examples.
 **Simple matrix-vector**:
 
     def mv(float(R,C) A, float(C) B) -> (o) {
-        o(i) += A(i,j) * B(j)
+        o(r) +=! A(r,r_c) * B(r_c)
     }
 
 **Simple matrix-multiply:**
@@ -59,21 +59,20 @@ Note the layout for B is transposed and matches the
 traditional layout of the weight matrix in a linear layer):
 
     def mm(float(X,Y) A, float(Y,Z) B) -> (R) {
-        R(i,j) += A(i,j) * B(j,k)
+        R(x,z) +=! A(x,r_y) * B(r_y,z)
     }
 
 **Simple 2-D convolution (no stride, no padding):**
 
     def conv(float(B,IP,H,W) input, float(OP,IP,KH,KW) weight) -> (output) {
-        output(b, op, h, w) += input(b, ip, h + kh, w + kw) * weight(op, ip, kh, kw)
+        output(b, op, h, w) +=! input(b, r_ip, h + r_kh, w + r_kw) * weight(op, r_ip, r_kh, r_kw)
     }
 
 **Simple 2D max pooling:**
 
-Note the similarity with a convolution with a
-"select"-style kernel):
+Note the similarity with a convolution with a "select"-style kernel:
 
     def maxpool2x2(float(B,C,H,W) input) -> (output) {
-        output(b,c,i,j) max= input(b,c,2*i + kw, 2*j + kh)
-            where kw = [0, 2[, kh = [0, 2[
+        output(b,c,h,w) max=! input(b,c,2*h + r_kw, 2*w + r_kh)
+            where r_kw in 0:2, r_kh in 0:2
     }
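In the same spirit as the C++ loop given for `mv` above, `maxpool2x2` corresponds roughly to the loop nest below (a sketch: the function name, flat row-major layout, and even `H`/`W` giving an `H/2 x W/2` output are assumptions):

    #include <cmath>  // fmaxf
    // Sketch of:  output(b,c,h,w) max=! input(b,c,2*h + r_kw, 2*w + r_kh)
    //             where r_kw in 0:2, r_kh in 0:2
    // input is B x C x H x W; for even H, W the output is B x C x H/2 x W/2.
    void maxpool2x2_reference(const float* input, float* output,
                              int B, int C, int H, int W) {
      const int HO = H / 2, WO = W / 2;
      for (int b = 0; b < B; ++b)
        for (int c = 0; c < C; ++c)
          for (int h = 0; h < HO; ++h)
            for (int w = 0; w < WO; ++w) {
              // start from the first element of the 2x2 window
              float m = input[((b * C + c) * H + 2 * h) * W + 2 * w];
              for (int r_kw = 0; r_kw < 2; ++r_kw)
                for (int r_kh = 0; r_kh < 2; ++r_kh)
                  m = fmaxf(m, input[((b * C + c) * H + 2 * h + r_kw) * W
                                     + 2 * w + r_kh]);
              output[((b * C + c) * HO + h) * WO + w] = m;
            }
    }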

docs/source/coding_conventions.rst

Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
+Coding Conventions
+==================
+
+In order to increase readability across Tensor Comprehensions written by
+multiple authors and to reduce the amount of surprising behavior, the
+following conventions should be adopted when writing TC. Generally in TC, one
+should increment nesting by 4 whitespaces at each level and align tensor names
+and indices where appropriate to make memory access patterns emerge. Since
+these two goals can easily be conflicting, use your best judgement to tradeoff
+between the two goals. Such examples are provided below.
+
+Use indices named after parameters
+----------------------------------
+
+Use upper-case names for parameters and capital-case names for input/output tensors.
+Use lower-case names for indices to match the name of the parameter
+corresponding to the dimension upon which they iterate.
+In other words, prefer:
+
+.. code::
+
+    def copy2d(float(M, N) I) -> (O) {
+        O(m, n) = I(m, n)
+    }
+
+to:
+
+.. code::
+
+    def copy2d(float(M, N) I) -> (O) {
+        O(i, j) = I(i, j)
+    }
+
+Prefix reduction index names with :code:`r_`
+--------------------------------------------
+
+By definition, reduction indices are the ones that appear on the RHS of a TC
+expression but not on the LHS. On larger expressions it can get challenging to easily
+detect the reduction variables by mentally parsing the set of indices on the
+RHS and subtracting the set of indices on the LHS from it. To alleviate such
+issues, name the reduction variables with a :code:`r_` prefix.
+In other words, prefer:
+
+.. code::
+
+    def matmul(float(M, K) A, float(K, N) B) -> (C) {
+        C(m, n) +=! A(m, r_k) * B(r_k, n)
+    }
+
+to:
+
+.. code::
+
+    def matmul(float(M, K) A, float(K, N) B) -> (C) {
+        C(m, n) +=! A(m, k) * B(k, n)
+    }
+
+Filter non-rectangular regions with data-dependencies
+-----------------------------------------------------
+
+TC semantics are restricted to (hyper-)rectangular iteration spaces.
+This is a hard requirement to ensure range inference is non-ambiguous (see inference_).
+To simulate non-rectangular iteration spaces, one can use the following:
+
+.. code::
+
+    def matmul(float(M, K) L, float(K, M) U) -> (LU) {
+        LU(m1, m2) +=! (r_k >= m1 and r_k <= m2) ? L(m1, r_k) * U(r_k, m2) : 0
+    }
+
+However, non-(hyper)-rectangular iteration spaces (e.g. triangular) are
+incompatible with range inference and will fail the semantic checks in the TC
+compiler:
+
+.. code::
+
+    def matmul(float(M, K) L, float(K, M) U) -> (LU) {
+        LU(m1, m2) +=! L(m1, r_k) * U(r_k, m2) where r_k in m1:M, r_k in 0:m2+1
+    }
+
+The reader may remark that this is an inefficient way of writing
+matrix-multiplication of triangular matrices.
+Lowering such operations efficiently from TC is the subject of future work.
+
+Prefix gradient tensor names with :code:`d_`
+---------------------------------------------
+
+When implementing backward operations, pass the inputs to the backwards pass
+in the same order as the outputs of the forward pass and use the same tensor
+name prefixed by :code:`d_`. For instance:
+
+.. code::
+
+    def conv(float(N,C,H,W) I, float(M,C,KH,KW) Wt) -> (O) {
+        ...
+    }
+
+    def conv_bw(float(N,C,H,W) I, float(M,C,KH,KW) Wt, float(N,M,HO,WO) d_O) -> (d_I) {
+        ...
+    }
+
+A more complex example
+----------------------
+
+The following shows a possible implementation for a more complex forward and
+backward example. Notice the proper alignment of indices in the backward pass
+and the emergence of an antidiagonal pattern in the reduction accesses:
+
+.. code::
+
+    def matmul(float(M,K) A, float(K,N) B) -> (C) {
+        C(m, n) +=! A(m, r_k) * B(r_k, n)
+    }
+    def matmul_bw(float(M,K) A, float(K,N) B, float(M,N) d_C) -> (d_A, d_B){
+        d_A(m, k) +=! d_C( m, r_n) * B( k, r_n)
+        d_B(k, n) +=! d_C(r_m, n) * A(r_m, k)
+    }
+
+Reasoning on such reduction patterns at the level of TC has already proven
+valuable in other circumstances.
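To make the backward-pass reduction pattern concrete: `d_A` reduces over `r_n` and `d_B` over `r_m`, i.e. `d_A = d_C * B^T` and `d_B = A^T * d_C`. A C++ sketch under assumed row-major flat layouts (the function name is illustrative only):

    // Sketch of matmul_bw:  d_A(m, k) +=! d_C(m, r_n) * B(k, r_n)
    //                       d_B(k, n) +=! d_C(r_m, n) * A(r_m, k)
    // A is M x K, B is K x N, d_C is M x N, d_A is M x K, d_B is K x N.
    void matmul_bw_reference(const float* A, const float* B, const float* d_C,
                             float* d_A, float* d_B, int M, int K, int N) {
      for (int m = 0; m < M; ++m)
        for (int k = 0; k < K; ++k) {
          float acc = 0.0f;
          for (int r_n = 0; r_n < N; ++r_n)
            acc += d_C[m * N + r_n] * B[k * N + r_n];  // reduce over r_n
          d_A[m * K + k] = acc;
        }
      for (int k = 0; k < K; ++k)
        for (int n = 0; n < N; ++n) {
          float acc = 0.0f;
          for (int r_m = 0; r_m < M; ++r_m)
            acc += d_C[r_m * N + n] * A[r_m * K + k];  // reduce over r_m
          d_B[k * N + n] = acc;
        }
    }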
