Skip to content

Commit 18d1197

Browse files
authored
Merge pull request #8741 from tautschnig/case_exprt
Add case_exprt to std_expr.h and refactor code to use it
2 parents 8640fc9 + 18d5083 commit 18d1197

File tree

6 files changed

+223
-8
lines changed

6 files changed

+223
-8
lines changed

src/solvers/flattening/boolbv.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ bvt boolbvt::convert_bitvector(const exprt &expr)
123123
else if(expr.id() == ID_update_bit)
124124
return convert_update_bit(to_update_bit_expr(expr));
125125
else if(expr.id()==ID_case)
126-
return convert_case(expr);
126+
return convert_case(to_case_expr(expr));
127127
else if(expr.id()==ID_cond)
128128
return convert_cond(to_cond_expr(expr));
129129
else if(expr.id()==ID_if)
@@ -390,7 +390,7 @@ literalt boolbvt::convert_rest(const exprt &expr)
390390
}
391391
else if(expr.id()==ID_case)
392392
{
393-
bvt bv=convert_case(expr);
393+
bvt bv = convert_case(to_case_expr(expr));
394394
CHECK_RETURN(bv.size() == 1);
395395
return bv[0];
396396
}

src/solvers/flattening/boolbv.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ class boolbvt:public arrayst
185185
virtual bvt convert_update(const update_exprt &);
186186
virtual bvt convert_update_bit(const update_bit_exprt &);
187187
virtual bvt convert_update_bits(const update_bits_exprt &);
188-
virtual bvt convert_case(const exprt &expr);
188+
virtual bvt convert_case(const case_exprt &);
189189
virtual bvt convert_cond(const cond_exprt &);
190190
virtual bvt convert_shift(const binary_exprt &expr);
191191
virtual bvt convert_bitwise(const exprt &expr);

src/solvers/flattening/boolbv_case.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,14 +6,13 @@ Author: Daniel Kroening, kroening@kroening.com
66
77
\*******************************************************************/
88

9-
#include "boolbv.h"
10-
119
#include <util/invariant.h>
10+
#include <util/std_expr.h>
1211

13-
bvt boolbvt::convert_case(const exprt &expr)
14-
{
15-
PRECONDITION(expr.id() == ID_case);
12+
#include "boolbv.h"
1613

14+
bvt boolbvt::convert_case(const case_exprt &expr)
15+
{
1716
const std::vector<exprt> &operands=expr.operands();
1817

1918
std::size_t width=boolbv_width(expr.type());

src/util/std_expr.h

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3578,6 +3578,124 @@ inline cond_exprt &to_cond_expr(exprt &expr)
35783578
return ret;
35793579
}
35803580

3581+
/// \brief Case expression: evaluates to the value corresponding to the first
3582+
/// matching case. The first operand is the value to compare against. Subsequent
3583+
/// operands alternate between compare values and result values. The syntax is:
3584+
/// case(select_value, case1_value, result1, case2_value, result2, ...)
3585+
class case_exprt : public multi_ary_exprt
3586+
{
3587+
public:
3588+
case_exprt(operandst _operands, typet _type)
3589+
: multi_ary_exprt(ID_case, std::move(_operands), std::move(_type))
3590+
{
3591+
}
3592+
3593+
/// Constructor with select value
3594+
case_exprt(exprt _select_value, typet _type)
3595+
: multi_ary_exprt(ID_case, {std::move(_select_value)}, std::move(_type))
3596+
{
3597+
}
3598+
3599+
/// Get the value that is being compared against
3600+
const exprt &select_value() const
3601+
{
3602+
PRECONDITION(!operands().empty());
3603+
return operands()[0];
3604+
}
3605+
3606+
/// Get the value that is being compared against
3607+
exprt &select_value()
3608+
{
3609+
PRECONDITION(!operands().empty());
3610+
return operands()[0];
3611+
}
3612+
3613+
/// Add a case: value to compare and corresponding result
3614+
/// \param case_value: the value to compare against select_value
3615+
/// \param result_value: the value to return if case_value matches
3616+
/// select_value
3617+
void add_case(const exprt &case_value, const exprt &result_value)
3618+
{
3619+
operands().reserve(operands().size() + 2);
3620+
operands().push_back(case_value);
3621+
operands().push_back(result_value);
3622+
}
3623+
3624+
/// Get the number of cases (excluding the select value)
3625+
std::size_t number_of_cases() const
3626+
{
3627+
PRECONDITION(operands().size() >= 1);
3628+
return (operands().size() - 1) / 2;
3629+
}
3630+
3631+
/// Get the case value for the i-th case
3632+
const exprt &case_value(std::size_t i) const
3633+
{
3634+
PRECONDITION(i < number_of_cases());
3635+
return operands()[1 + 2 * i];
3636+
}
3637+
3638+
/// Get the case value for the i-th case
3639+
exprt &case_value(std::size_t i)
3640+
{
3641+
PRECONDITION(i < number_of_cases());
3642+
return operands()[1 + 2 * i];
3643+
}
3644+
3645+
/// Get the result value for the i-th case
3646+
const exprt &result_value(std::size_t i) const
3647+
{
3648+
PRECONDITION(i < number_of_cases());
3649+
return operands()[1 + 2 * i + 1];
3650+
}
3651+
3652+
/// Get the result value for the i-th case
3653+
exprt &result_value(std::size_t i)
3654+
{
3655+
PRECONDITION(i < number_of_cases());
3656+
return operands()[1 + 2 * i + 1];
3657+
}
3658+
3659+
static void validate_expr(const case_exprt &value)
3660+
{
3661+
DATA_INVARIANT(
3662+
value.operands().size() >= 1,
3663+
"case expression must have at least one operand");
3664+
DATA_INVARIANT(
3665+
value.operands().size() % 2 == 1,
3666+
"case expression must have odd number of operands");
3667+
}
3668+
};
3669+
3670+
template <>
3671+
inline bool can_cast_expr<case_exprt>(const exprt &base)
3672+
{
3673+
return base.id() == ID_case;
3674+
}
3675+
3676+
/// \brief Cast an exprt to a \ref case_exprt
3677+
///
3678+
/// \a expr must be known to be \ref case_exprt.
3679+
///
3680+
/// \param expr: Source expression
3681+
/// \return Object of type \ref case_exprt
3682+
inline const case_exprt &to_case_expr(const exprt &expr)
3683+
{
3684+
PRECONDITION(expr.id() == ID_case);
3685+
const case_exprt &ret = static_cast<const case_exprt &>(expr);
3686+
case_exprt::validate_expr(ret);
3687+
return ret;
3688+
}
3689+
3690+
/// \copydoc to_case_expr(const exprt &)
3691+
inline case_exprt &to_case_expr(exprt &expr)
3692+
{
3693+
PRECONDITION(expr.id() == ID_case);
3694+
case_exprt &ret = static_cast<case_exprt &>(expr);
3695+
case_exprt::validate_expr(ret);
3696+
return ret;
3697+
}
3698+
35813699
/// \brief Expression to define a mapping from an argument (index) to elements.
35823700
/// This enables constructing an array via an anonymous function.
35833701
/// Not all kinds of array comprehension can be expressed, only those of the

unit/Makefile

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,7 @@ SRC += analyses/ai/ai.cpp \
142142
solvers/strings/string_refinement/substitute_array_list.cpp \
143143
solvers/strings/string_refinement/union_find_replace.cpp \
144144
util/bitvector_expr.cpp \
145+
util/case_expr.cpp \
145146
util/cmdline.cpp \
146147
util/dense_integer_map.cpp \
147148
util/edit_distance.cpp \

unit/util/case_expr.cpp

Lines changed: 97 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
/*******************************************************************\
2+
3+
Module: Unit tests for case_exprt
4+
5+
Author: Unit test
6+
7+
\*******************************************************************/
8+
9+
#include <util/arith_tools.h>
10+
#include <util/bitvector_types.h>
11+
#include <util/std_expr.h>
12+
13+
#include <testing-utils/use_catch.h>
14+
15+
TEST_CASE("case_exprt construction and access", "[core][util][case_expr]")
16+
{
17+
const signedbv_typet int_type(32);
18+
const symbol_exprt select_value("x", int_type);
19+
20+
SECTION("Basic construction")
21+
{
22+
case_exprt case_expr(select_value, int_type);
23+
24+
REQUIRE(case_expr.id() == ID_case);
25+
REQUIRE(case_expr.select_value() == select_value);
26+
REQUIRE(case_expr.number_of_cases() == 0);
27+
}
28+
29+
SECTION("Adding cases")
30+
{
31+
case_exprt case_expr(select_value, int_type);
32+
33+
const constant_exprt case1_value = from_integer(1, int_type);
34+
const constant_exprt result1_value = from_integer(10, int_type);
35+
36+
const constant_exprt case2_value = from_integer(2, int_type);
37+
const constant_exprt result2_value = from_integer(20, int_type);
38+
39+
case_expr.add_case(case1_value, result1_value);
40+
REQUIRE(case_expr.number_of_cases() == 1);
41+
REQUIRE(case_expr.case_value(0) == case1_value);
42+
REQUIRE(case_expr.result_value(0) == result1_value);
43+
44+
case_expr.add_case(case2_value, result2_value);
45+
REQUIRE(case_expr.number_of_cases() == 2);
46+
REQUIRE(case_expr.case_value(1) == case2_value);
47+
REQUIRE(case_expr.result_value(1) == result2_value);
48+
49+
// Verify operands structure: 1 select + 2*2 case/result pairs = 5
50+
REQUIRE(case_expr.operands().size() == 5);
51+
// Verify odd number of operands
52+
REQUIRE(case_expr.operands().size() % 2 == 1);
53+
}
54+
55+
SECTION("to_case_expr conversion")
56+
{
57+
case_exprt case_expr(select_value, int_type);
58+
const constant_exprt case_value = from_integer(1, int_type);
59+
const constant_exprt result_value = from_integer(10, int_type);
60+
case_expr.add_case(case_value, result_value);
61+
62+
exprt &base = case_expr;
63+
case_exprt &converted = to_case_expr(base);
64+
65+
REQUIRE(&converted == &case_expr);
66+
REQUIRE(converted.number_of_cases() == 1);
67+
REQUIRE(converted.case_value(0) == case_value);
68+
}
69+
70+
SECTION("can_cast_expr")
71+
{
72+
case_exprt case_expr(select_value, int_type);
73+
exprt &base = case_expr;
74+
75+
REQUIRE(can_cast_expr<case_exprt>(base));
76+
REQUIRE_FALSE(can_cast_expr<if_exprt>(base));
77+
}
78+
79+
SECTION("Construction with operands")
80+
{
81+
const constant_exprt case_value = from_integer(1, int_type);
82+
const constant_exprt result_value = from_integer(10, int_type);
83+
84+
case_exprt::operandst ops;
85+
ops.push_back(select_value);
86+
ops.push_back(case_value);
87+
ops.push_back(result_value);
88+
89+
case_exprt case_expr(std::move(ops), int_type);
90+
91+
REQUIRE(case_expr.id() == ID_case);
92+
REQUIRE(case_expr.number_of_cases() == 1);
93+
REQUIRE(case_expr.select_value() == select_value);
94+
REQUIRE(case_expr.case_value(0) == case_value);
95+
REQUIRE(case_expr.result_value(0) == result_value);
96+
}
97+
}

0 commit comments

Comments
 (0)