Skip to content

Commit efb8144

Browse files
refactor operator to explicit methods
1 parent 7cdfe53 commit efb8144

File tree

3 files changed

+201
-66
lines changed

3 files changed

+201
-66
lines changed

headers/wasm/concrete_rt.hpp

Lines changed: 108 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -17,30 +17,121 @@ struct Num {
1717
Num(int64_t value) : value(value) {}
1818
Num() : value(0) {}
1919
int64_t value;
20+
2021
int32_t toInt() const { return static_cast<int32_t>(value); }
22+
uint32_t toUInt() const { return static_cast<uint32_t>(value); }
2123

24+
// Helper to create a Wasm Boolean result (1 or 0 as Num)
25+
Num WasmBool(bool condition) const { return Num(condition ? 1 : 0); }
2226
// TODO: support different bit width operations, for now we just assume all
2327
// oprands are i32
24-
bool operator==(const Num &other) const { return toInt() == other.toInt(); }
25-
bool operator!=(const Num &other) const { return !(*this == other); }
26-
Num operator+(const Num &other) const { return Num(toInt() + other.toInt()); }
27-
Num operator-(const Num &other) const { return Num(toInt() - other.toInt()); }
28-
Num operator*(const Num &other) const { return Num(toInt() * other.toInt()); }
29-
Num operator/(const Num &other) const {
30-
if (other.toInt() == 0) {
31-
throw std::runtime_error("Division by zero");
28+
// i32.eq (Equals): *this == other
29+
inline Num i32_eq(const Num &other) const {
30+
return WasmBool(this->toUInt() == other.toUInt());
31+
}
32+
33+
// i32.ne (Not Equals): *this != other
34+
inline Num i32_ne(const Num &other) const {
35+
return WasmBool(this->toUInt() != other.toUInt());
36+
}
37+
38+
// i32.lt_s (Signed Less Than): *this < other
39+
inline Num i32_lt_s(const Num &other) const {
40+
return WasmBool(this->toInt() < other.toInt());
41+
}
42+
43+
// i32.lt_u (Unsigned Less Than): *this < other (unsigned)
44+
inline Num i32_lt_u(const Num &other) const {
45+
return WasmBool(this->toUInt() < other.toUInt());
46+
}
47+
48+
// i32.le_s (Signed Less Than or Equal): *this <= other
49+
inline Num i32_le_s(const Num &other) const {
50+
return WasmBool(this->toInt() <= other.toInt());
51+
}
52+
// i32.le_u (Unsigned Less Than or Equal): *this <= other (unsigned)
53+
inline Num i32_le_u(const Num &other) const {
54+
return WasmBool(this->toUInt() <= other.toUInt());
55+
}
56+
57+
// i32.gt_s (Signed Greater Than): *this > other
58+
inline Num i32_gt_s(const Num &other) const {
59+
return WasmBool(this->toInt() > other.toInt());
60+
}
61+
62+
// i32.gt_u (Unsigned Greater Than): *this > other (unsigned)
63+
inline Num i32_gt_u(const Num &other) const {
64+
return WasmBool(this->toUInt() > other.toUInt());
65+
}
66+
67+
// i32.ge_s (Signed Greater Than or Equal): *this >= other
68+
inline Num i32_ge_s(const Num &other) const {
69+
return WasmBool(this->toInt() >= other.toInt());
70+
}
71+
72+
// i32.ge_u (Unsigned Greater Than or Equal): *this >= other (unsigned)
73+
inline Num i32_ge_u(const Num &other) const {
74+
return WasmBool(this->toUInt() >= other.toUInt());
75+
}
76+
77+
// i32.add (Wrapping addition)
78+
inline Num i32_add(const Num &other) const {
79+
uint32_t result_u = this->toUInt() + other.toUInt();
80+
return Num(static_cast<int32_t>(result_u));
81+
}
82+
83+
// i32.sub (Wrapping subtraction)
84+
inline Num i32_sub(const Num &other) const {
85+
uint32_t result_u = this->toUInt() - other.toUInt();
86+
return Num(static_cast<int32_t>(result_u));
87+
}
88+
89+
// i32.mul (Wrapping multiplication)
90+
inline Num i32_mul(const Num &other) const {
91+
uint32_t result_u = this->toUInt() * other.toUInt();
92+
return Num(static_cast<int32_t>(result_u));
93+
}
94+
95+
// i32.div_s (Signed division with traps)
96+
inline Num i32_div_s(const Num &other) const {
97+
int32_t divisor = other.toInt();
98+
int32_t dividend = this->toInt();
99+
100+
if (divisor == 0) {
101+
throw std::runtime_error("i32.div_s: Division by zero");
32102
}
33-
return Num(toInt() / other.toInt());
103+
104+
return Num(dividend / divisor);
105+
}
106+
107+
// i32.shl (Shift Left): *this << other (shift count masked by 31)
108+
inline Num i32_shl(const Num &other) const {
109+
uint32_t shift_amount = other.toUInt() & 0x1F;
110+
uint32_t result_u = toUInt() << shift_amount;
111+
return Num(static_cast<int32_t>(result_u));
112+
}
113+
114+
// i32.shr_s (Signed Shift Right): *this >> other (Arithmetic shift)
115+
inline Num i32_shr_s(const Num &other) const {
116+
// Wasm masks the shift amount by 31 (0x1F)
117+
uint32_t shift_amount = other.toUInt() & 0x1F;
118+
int32_t result_s = toInt() >> shift_amount;
119+
return Num(result_s);
34120
}
35-
Num operator<(const Num &other) const { return Num(toInt() < other.toInt()); }
36-
Num operator<=(const Num &other) const {
37-
return Num(toInt() <= other.toInt());
121+
122+
// i32.shr_u (Unsigned Shift Right): *this >>> other (Logical shift)
123+
inline Num i32_shr_u(const Num &other) const {
124+
// Wasm masks the shift amount by 31 (0x1F)
125+
uint32_t shift_amount = other.toUInt() & 0x1F;
126+
uint32_t result_u = toUInt() >> shift_amount;
127+
return Num(static_cast<int32_t>(result_u));
38128
}
39-
Num operator>(const Num &other) const { return Num(toInt() > other.toInt()); }
40-
Num operator>=(const Num &other) const {
41-
return Num(toInt() >= other.toInt());
129+
130+
// i32.and (Bitwise AND)
131+
inline Num i32_and(const Num &other) const {
132+
uint32_t result_u = this->toUInt() & other.toUInt();
133+
return Num(static_cast<int32_t>(result_u));
42134
}
43-
Num operator&(const Num &other) const { return Num(toInt() & other.toInt()); }
44135
};
45136

46137
static Num I32V(int v) { return v; }
@@ -303,6 +394,7 @@ struct Memory_t {
303394
}
304395
};
305396

397+
306398
static Memory_t Memory(1); // 1 page memory
307399

308400
#endif // WASM_CONCRETE_RT_HPP

headers/wasm/symbolic_rt.hpp

Lines changed: 57 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1251,29 +1251,75 @@ static EvalRes eval_sym_expr(const SymVal &sym, SymEnv_t &sym_env) {
12511251
auto rhs_res = eval_sym_expr(operation->rhs, sym_env);
12521252
auto lhs = lhs_res.value;
12531253
auto rhs = rhs_res.value;
1254+
auto lhs_width = lhs_res.width;
1255+
auto rhs_width = rhs_res.width;
12541256
switch (operation->op) {
12551257
case ADD:
1256-
return EvalRes(lhs + rhs, 32);
1258+
if (lhs_width == 32 && rhs_width == 32) {
1259+
return EvalRes(lhs.i32_add(rhs), 32);
1260+
} else {
1261+
assert(false && "TODO");
1262+
}
12571263
case SUB:
1258-
return EvalRes(lhs - rhs, 32);
1264+
if (lhs_width == 32 && rhs_width == 32) {
1265+
return EvalRes(lhs.i32_sub(rhs), 32);
1266+
} else {
1267+
assert(false && "TODO");
1268+
}
12591269
case MUL:
1260-
return EvalRes(lhs * rhs, 32);
1270+
if (lhs_width == 32 && rhs_width == 32) {
1271+
return EvalRes(lhs.i32_mul(rhs), 32);
1272+
} else {
1273+
assert(false && "TODO");
1274+
}
12611275
case DIV:
1262-
return EvalRes(lhs / rhs, 32);
1276+
if (lhs_width == 32 && rhs_width == 32) {
1277+
return EvalRes(lhs.i32_div_s(rhs), 32);
1278+
} else {
1279+
assert(false && "TODO");
1280+
}
12631281
case LT:
1264-
return EvalRes(lhs < rhs, 32);
1282+
if (lhs_width == 32 && rhs_width == 32) {
1283+
return EvalRes(lhs.i32_lt_s(rhs), 32);
1284+
} else {
1285+
assert(false && "TODO");
1286+
}
12651287
case LEQ:
1266-
return EvalRes(lhs <= rhs, 32);
1288+
if (lhs_width == 32 && rhs_width == 32) {
1289+
return EvalRes(lhs.i32_le_s(rhs), 32);
1290+
} else {
1291+
assert(false && "TODO");
1292+
}
12671293
case GT:
1268-
return EvalRes(lhs > rhs, 32);
1294+
if (lhs_width == 32 && rhs_width == 32) {
1295+
return EvalRes(lhs.i32_gt_s(rhs), 32);
1296+
} else {
1297+
assert(false && "TODO");
1298+
}
12691299
case GEQ:
1270-
return EvalRes(lhs >= rhs, 32);
1300+
if (lhs_width == 32 && rhs_width == 32) {
1301+
return EvalRes(lhs.i32_ge_s(rhs), 32);
1302+
} else {
1303+
assert(false && "TODO");
1304+
}
12711305
case NEQ:
1272-
return EvalRes(lhs != rhs, 32);
1306+
if (lhs_width == 32 && rhs_width == 32) {
1307+
return EvalRes(lhs.i32_ne(rhs), 32);
1308+
} else {
1309+
assert(false && "TODO");
1310+
}
12731311
case EQ:
1274-
return EvalRes(lhs == rhs, 32);
1312+
if (lhs_width == 32 && rhs_width == 32) {
1313+
return EvalRes(lhs.i32_eq(rhs), 32);
1314+
} else {
1315+
assert(false && "TODO");
1316+
}
12751317
case B_AND:
1276-
return EvalRes(Num(I64V(lhs.value & rhs.value)), 32);
1318+
if (lhs_width == 32 && rhs_width == 32) {
1319+
return EvalRes(lhs.i32_and(rhs), 32);
1320+
} else {
1321+
assert(false && "TODO");
1322+
}
12771323
case CONCAT: {
12781324
auto lhs_width = lhs_res.width;
12791325
auto rhs_width = rhs_res.width;

0 commit comments

Comments
 (0)