@@ -64,23 +64,23 @@ DEFINE_uint32(Q, 2, "W4_h");
6464// float(E1, D) LUT1, int32(B, L1) I1,
6565// float(E2, D) LUT2, int32(B, L2) I2) -> (O1, O2)
6666// {
67- // O1(i, j ) +=! LUT1(I1(i, k ), j )
68- // O2(i, j ) +=! LUT2(I2(i, k ), j )
67+ // O1(b, d ) +=! LUT1(I1(b, r_l1 ), d )
68+ // O2(b, d ) +=! LUT2(I2(b, r_l2 ), d )
6969// }
7070// def _3FCRELU(
7171// float(B,M) I, float(O,N) W2, float(O) B2,
7272// float(P,O) W3, float(P) B3, float(Q,P) W4,
7373// float(Q) B4) -> (O1, O2, O3, O4)
7474// {
75- // O2(b, o) = B2(o)
76- // O2(b, o) += O1(b, n) * W2(o, n)
77- // O2(b, o) = fmax(O2(b, o), 0)
78- // O3(b, p) = B3(p)
79- // O3(b, p) += O2(b, o) * W3(p, o)
80- // O3(b, p) = fmax(O3(b, p), 0)
81- // O4(b, q) = B4(q)
82- // O4(b, q) += O3(b, p) * W4(q, p)
83- // O4(b, q) = fmax(O4(b, q), 0)
75+ // O2(b, o) = B2(o)
76+ // O2(b, o) += O1(b, n) * W2(o, n)
77+ // O2(b, o) = fmax(O2(b, o), 0)
78+ // O3(b, p) = B3(p)
79+ // O3(b, p) += O2(b, o) * W3(p, o)
80+ // O3(b, p) = fmax(O3(b, p), 0)
81+ // O4(b, q) = B4(q)
82+ // O4(b, q) += O3(b, p) * W4(q, p)
83+ // O4(b, q) = fmax(O4(b, q), 0)
8484// }
8585// def prod_model(float(E1, D) LUT1, int32(B, L1) I1,
8686// float(E2, D) LUT2, int32(B, L2) I2,
@@ -91,15 +91,15 @@ DEFINE_uint32(Q, 2, "W4_h");
9191// float(Q,P) W4, float(Q) B4)
9292// -> (C1, C2, C3, I, O1, O2, O3, O4)
9393// {
94- // (C1, C2) = _2LUT(LUT1, I1, LUT2, I2)
95- // C3(b, wy) += I3(b, wxx ) * W(wy, wxx )
96- // I(b, m) = Concat(C1, C2, C3) // not in TC atm
97- // O1(b, n) = B1(n)
98- // O1(b, n) += I(b, m) * W1(n, m)
99- // O1(b, n) = fmax(O1(b, n), 0)
94+ // (C1, C2) = _2LUT(LUT1, I1, LUT2, I2)
95+ // C3(b, wy) +=! I3(b, r_wx ) * W(wy, r_wx )
96+ // I(b, m) = Concat(C1, C2, C3) // not in TC atm
97+ // O1(b, n) = B1(n)
98+ // O1(b, n) += I(b, m) * W1(n, m)
99+ // O1(b, n) = fmax(O1(b, n), 0)
100100// (O2, O3, O4) =
101- // _3FCRELU(I, W1, B1, W2, B2, W3, B3, W4, B4)
102- // # O4 goes out to binary classifier, omitted here
101+ // _3FCRELU(O1, W2, B2, W3, B3, W4, B4)
102+ // # O4 goes out to binary classifier, omitted here
103103// }
104104
105105class ProductionModel : public Benchmark {
@@ -191,9 +191,9 @@ void ProductionModel::run1LUT(
191191
192192 std::vector<at::Tensor> inputs = {LUT1, IDX1};
193193 std::string tc = R"(
194- def _1LUT(float(E1, D) LUT1, int32(B, L1) I1) -> (O1) {
195- O1(i, j ) +=! LUT1(I1(i, k ), j )
196- }
194+ def _1LUT(float(E1, D) LUT1, int32(B, L1) I1) -> (O1) {
195+ O1(b, d ) +=! LUT1(I1(b, r_l1 ), d )
196+ }
197197 )" ;
198198
199199 std::string suffix = std::string (" _B_" ) + std::to_string (FLAGS_B) +
@@ -294,10 +294,10 @@ void ProductionModel::run2LUT(
294294
295295 std::vector<at::Tensor> inputs = {LUT1, IDX1, LUT2, IDX2};
296296 std::string tc = R"(
297- def _2LUT(float(E1, D) LUT1, int32(B, L1) I1, float(E2, D) LUT2, int32(B, L2) I2) -> (O1, O2) {
298- O1(i, j ) +=! LUT1(I1(i, k ), j )
299- O2(i, j ) +=! LUT2(I2(i, k ), j )
300- }
297+ def _2LUT(float(E1, D) LUT1, int32(B, L1) I1, float(E2, D) LUT2, int32(B, L2) I2) -> (O1, O2) {
298+ O1(b, d ) +=! LUT1(I1(b, r_l1 ), d )
299+ O2(b, d ) +=! LUT2(I2(b, r_l2 ), d )
300+ }
301301 )" ;
302302
303303 std::string suffix = std::string (" _B_" ) + std::to_string (FLAGS_B) +
@@ -353,9 +353,9 @@ void ProductionModel::runC3(
353353
354354 std::vector<at::Tensor> inputs = {I, W};
355355 std::string tc = R"TC(
356- def _C3(float(B,WX) I, float(WY, WX) W) -> (C3) {
357- C3(b, wy) +=! I(b, wxx ) * W(wy, wxx )
358- }
356+ def _C3(float(B,WX) I, float(WY, WX) W) -> (C3) {
357+ C3(b, wy) +=! I(b, r_wx ) * W(wy, r_wx )
358+ }
359359)TC" ;
360360
361361 std::string suffix = std::string (" _B_" ) + std::to_string (FLAGS_B) +
@@ -408,11 +408,11 @@ void ProductionModel::runMLP1(
408408
409409 std::vector<at::Tensor> inputs = {I, W1, B1};
410410 std::string tc = R"TC(
411- def mlp1(float(B,M) I, float(M, N) W1, float(N) B1) -> (O1) {
412- O1(b, n) +=! I(b, mm ) * W1(mm , n)
413- O1(b, n) = O1(b, n) + B1(n)
414- O1(b, n) = fmax(O1(b, n), 0)
415- }
411+ def mlp1(float(B,M) I, float(M, N) W1, float(N) B1) -> (O1) {
412+ O1(b, n) +=! I(b, r_m ) * W1(r_m , n)
413+ O1(b, n) = O1(b, n) + B1(n)
414+ O1(b, n) = fmax(O1(b, n), 0)
415+ }
416416)TC" ;
417417
418418 std::string suffix = std::string (" _B_" ) + std::to_string (FLAGS_B) +
@@ -474,17 +474,17 @@ void ProductionModel::runMLP3(
474474
475475 std::vector<at::Tensor> inputs = {I, W2, B2, W3, B3, W4, B4};
476476 std::string tc = R"TC(
477- def mlp3(float(B,N) I, float(O,N) W2, float(O) B2, float(P,O) W3, float(P) B3, float(Q,P) W4, float(Q) B4) -> (O2, O3, O4) {
478- O2(b, o) +=! I(b, n) * W2(o, n)
479- O2(b, o) = O2(b, o) + B2(o)
480- O2(b, o) = fmax(O2(b, o), 0)
477+ def mlp3(float(B,N) I, float(O,N) W2, float(O) B2, float(P,O) W3, float(P) B3, float(Q,P) W4, float(Q) B4) -> (O2, O3, O4) {
478+ O2(b, o) +=! I(b, n) * W2(o, n)
479+ O2(b, o) = O2(b, o) + B2(o)
480+ O2(b, o) = fmax(O2(b, o), 0)
481481 O3(b, p) +=! O2(b, o) * W3(p, o)
482- O3(b, p) = O3(b, p) + B3(p)
483- O3(b, p) = fmax(O3(b, p), 0)
482+ O3(b, p) = O3(b, p) + B3(p)
483+ O3(b, p) = fmax(O3(b, p), 0)
484484 O4(b, q) +=! O3(b, p) * W4(q, p)
485- O4(b, q) = O4(b, q) + B4(q)
486- O4(b, q) = fmax(O4(b, q), 0)
487- }
485+ O4(b, q) = O4(b, q) + B4(q)
486+ O4(b, q) = fmax(O4(b, q), 0)
487+ }
488488)TC" ;
489489
490490 std::string suffix = std::string (" _B_" ) + std::to_string (FLAGS_B) +
0 commit comments