Skip to content

Commit 7b61352

Browse files
committed
[naga wgsl-in] Short-circuiting of && and || operators
Fixes #4394 Mostly fixes #6302, but #8440 remains an issue
1 parent 13a9c1b commit 7b61352

File tree

7 files changed

+1084
-561
lines changed

7 files changed

+1084
-561
lines changed

naga/src/front/wgsl/lower/mod.rs

Lines changed: 250 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,13 @@ impl TypeContext for ExpressionContext<'_, '_, '_> {
427427
}
428428

429429
impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
430+
const fn is_runtime(&self) -> bool {
431+
match self.expr_type {
432+
ExpressionContextType::Runtime(_) => true,
433+
ExpressionContextType::Constant(_) | ExpressionContextType::Override => false,
434+
}
435+
}
436+
430437
#[allow(dead_code)]
431438
fn as_const(&mut self) -> ExpressionContext<'source, '_, '_> {
432439
ExpressionContext {
@@ -589,6 +596,16 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
589596
}
590597
}
591598

599+
fn get(&self, handle: Handle<crate::Expression>) -> &crate::Expression {
600+
match self.expr_type {
601+
ExpressionContextType::Runtime(ref ctx)
602+
| ExpressionContextType::Constant(Some(ref ctx)) => &ctx.function.expressions[handle],
603+
ExpressionContextType::Constant(None) | ExpressionContextType::Override => {
604+
&self.module.global_expressions[handle]
605+
}
606+
}
607+
}
608+
592609
fn local(
593610
&mut self,
594611
local: &Handle<ast::Local>,
@@ -615,6 +632,52 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
615632
}
616633
}
617634

635+
fn with_nested_runtime_expression_ctx<'a, F, T>(
636+
&mut self,
637+
span: Span,
638+
f: F,
639+
) -> Result<'source, (T, crate::Block)>
640+
where
641+
for<'t> F: FnOnce(&mut ExpressionContext<'source, 't, 't>) -> Result<'source, T>,
642+
{
643+
let mut block = crate::Block::new();
644+
let rctx = match self.expr_type {
645+
ExpressionContextType::Runtime(ref mut rctx) => Ok(rctx),
646+
ExpressionContextType::Constant(_) | ExpressionContextType::Override => {
647+
Err(Error::UnexpectedOperationInConstContext(span))
648+
}
649+
}?;
650+
651+
rctx.block
652+
.extend(rctx.emitter.finish(&rctx.function.expressions));
653+
rctx.emitter.start(&rctx.function.expressions);
654+
655+
let nested_rctx = LocalExpressionContext {
656+
local_table: rctx.local_table,
657+
function: rctx.function,
658+
block: &mut block,
659+
emitter: rctx.emitter,
660+
typifier: rctx.typifier,
661+
local_expression_kind_tracker: rctx.local_expression_kind_tracker,
662+
};
663+
let mut nested_ctx = ExpressionContext {
664+
enable_extensions: self.enable_extensions,
665+
expr_type: ExpressionContextType::Runtime(nested_rctx),
666+
ast_expressions: self.ast_expressions,
667+
globals: self.globals,
668+
module: self.module,
669+
const_typifier: self.const_typifier,
670+
layouter: self.layouter,
671+
global_expression_kind_tracker: self.global_expression_kind_tracker,
672+
};
673+
let ret = f(&mut nested_ctx)?;
674+
675+
block.extend(rctx.emitter.finish(&rctx.function.expressions));
676+
rctx.emitter.start(&rctx.function.expressions);
677+
678+
Ok((ret, block))
679+
}
680+
618681
fn gather_component(
619682
&mut self,
620683
expr: Handle<ir::Expression>,
@@ -2534,6 +2597,130 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
25342597
Ok(ty)
25352598
}
25362599

2600+
/// Generate IR for the short-circuiting operators `&&` and `||`.
2601+
///
2602+
/// `binary` has already lowered the LHS expression and resolved its type.
2603+
fn logical(
2604+
&mut self,
2605+
op: crate::BinaryOperator,
2606+
left: Handle<crate::Expression>,
2607+
right: Handle<ast::Expression<'source>>,
2608+
span: Span,
2609+
ctx: &mut ExpressionContext<'source, '_, '_>,
2610+
) -> Result<'source, Typed<crate::Expression>> {
2611+
debug_assert!(
2612+
op == crate::BinaryOperator::LogicalAnd || op == crate::BinaryOperator::LogicalOr
2613+
);
2614+
2615+
if ctx.is_runtime() {
2616+
// To simulate short-circuiting behavior, we want to generate IR
2617+
// like the following for `&&`. For `||`, the condition is `!_lhs`
2618+
// and the else value is `true`.
2619+
//
2620+
// var _e0: bool;
2621+
// if _lhs {
2622+
// _e0 = _rhs;
2623+
// } else {
2624+
// _e0 = false;
2625+
// }
2626+
2627+
let (condition, else_val) = if op == crate::BinaryOperator::LogicalAnd {
2628+
let condition = left;
2629+
let else_val = ctx.append_expression(
2630+
crate::Expression::Literal(crate::Literal::Bool(false)),
2631+
span,
2632+
)?;
2633+
(condition, else_val)
2634+
} else {
2635+
let condition = ctx.append_expression(
2636+
crate::Expression::Unary {
2637+
op: crate::UnaryOperator::LogicalNot,
2638+
expr: left,
2639+
},
2640+
span,
2641+
)?;
2642+
let else_val = ctx.append_expression(
2643+
crate::Expression::Literal(crate::Literal::Bool(true)),
2644+
span,
2645+
)?;
2646+
(condition, else_val)
2647+
};
2648+
2649+
let bool_ty = ctx.ensure_type_exists(crate::TypeInner::Scalar(crate::Scalar::BOOL));
2650+
2651+
let rctx = ctx.runtime_expression_ctx(span)?;
2652+
let result_var = rctx.function.local_variables.append(
2653+
crate::LocalVariable {
2654+
name: None,
2655+
ty: bool_ty,
2656+
init: None,
2657+
},
2658+
span,
2659+
);
2660+
let pointer =
2661+
ctx.append_expression(crate::Expression::LocalVariable(result_var), span)?;
2662+
2663+
let (right, mut accept) = ctx.with_nested_runtime_expression_ctx(span, |ctx| {
2664+
let right = self.expression_for_abstract(right, ctx)?;
2665+
ctx.grow_types(right)?;
2666+
Ok(right)
2667+
})?;
2668+
2669+
accept.push(
2670+
crate::Statement::Store {
2671+
pointer,
2672+
value: right,
2673+
},
2674+
span,
2675+
);
2676+
2677+
let mut reject = crate::Block::with_capacity(1);
2678+
reject.push(
2679+
crate::Statement::Store {
2680+
pointer,
2681+
value: else_val,
2682+
},
2683+
span,
2684+
);
2685+
2686+
let rctx = ctx.runtime_expression_ctx(span)?;
2687+
rctx.block.push(
2688+
crate::Statement::If {
2689+
condition,
2690+
accept,
2691+
reject,
2692+
},
2693+
span,
2694+
);
2695+
2696+
Ok(Typed::Reference(crate::Expression::LocalVariable(
2697+
result_var,
2698+
)))
2699+
} else {
2700+
let left_expr = ctx.get(left);
2701+
// Constant or override context in either function or module scope
2702+
let &crate::Expression::Literal(crate::Literal::Bool(left_val)) = left_expr else {
2703+
return Err(Box::new(Error::NotBool(span)));
2704+
};
2705+
2706+
if op == crate::BinaryOperator::LogicalAnd && !left_val
2707+
|| op == crate::BinaryOperator::LogicalOr && left_val
2708+
{
2709+
// Short-circuit behavior: don't evaluate the RHS. Ideally we
2710+
// would do _some_ validity checks of the RHS here, but that's
2711+
// tricky, because the RHS is allowed to have things that aren't
2712+
// legal in const contexts.
2713+
2714+
Ok(Typed::Plain(left_expr.clone()))
2715+
} else {
2716+
let right = self.expression_for_abstract(right, ctx)?;
2717+
ctx.grow_types(right)?;
2718+
2719+
Ok(Typed::Plain(crate::Expression::Binary { op, left, right }))
2720+
}
2721+
}
2722+
}
2723+
25372724
fn binary(
25382725
&mut self,
25392726
op: ir::BinaryOperator,
@@ -2542,57 +2729,74 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
25422729
span: Span,
25432730
ctx: &mut ExpressionContext<'source, '_, '_>,
25442731
) -> Result<'source, Typed<ir::Expression>> {
2545-
// Load both operands.
2546-
let mut left = self.expression_for_abstract(left, ctx)?;
2547-
let mut right = self.expression_for_abstract(right, ctx)?;
2548-
2549-
// Convert `scalar op vector` to `vector op vector` by introducing
2550-
// `Splat` expressions.
2551-
ctx.binary_op_splat(op, &mut left, &mut right)?;
2552-
2553-
// Apply automatic conversions.
2554-
match op {
2555-
ir::BinaryOperator::ShiftLeft | ir::BinaryOperator::ShiftRight => {
2556-
// Shift operators require the right operand to be `u32` or
2557-
// `vecN<u32>`. We can let the validator sort out vector length
2558-
// issues, but the right operand must be, or convert to, a u32 leaf
2559-
// scalar.
2560-
right =
2561-
ctx.try_automatic_conversion_for_leaf_scalar(right, ir::Scalar::U32, span)?;
2562-
2563-
// Additionally, we must concretize the left operand if the right operand
2564-
// is not a const-expression.
2565-
// See https://www.w3.org/TR/WGSL/#overload-resolution-section.
2566-
//
2567-
// 2. Eliminate any candidate where one of its subexpressions resolves to
2568-
// an abstract type after feasible automatic conversions, but another of
2569-
// the candidate’s subexpressions is not a const-expression.
2570-
//
2571-
// We only have to explicitly do so for shifts as their operands may be
2572-
// of different types - for other binary ops this is achieved by finding
2573-
// the conversion consensus for both operands.
2574-
if !ctx.is_const(right) {
2575-
left = ctx.concretize(left)?;
2576-
}
2732+
if op == ir::BinaryOperator::LogicalAnd || op == ir::BinaryOperator::LogicalOr {
2733+
let left = self.expression_for_abstract(left, ctx)?;
2734+
ctx.grow_types(left)?;
2735+
2736+
if !matches!(
2737+
resolve_inner!(ctx, left),
2738+
&ir::TypeInner::Scalar(ir::Scalar::BOOL)
2739+
) {
2740+
// Pass it through as-is, will fail validation
2741+
let right = self.expression_for_abstract(right, ctx)?;
2742+
ctx.grow_types(right)?;
2743+
Ok(Typed::Plain(crate::Expression::Binary { op, left, right }))
2744+
} else {
2745+
self.logical(op, left, right, span, ctx)
25772746
}
2747+
} else {
2748+
// Load both operands.
2749+
let mut left = self.expression_for_abstract(left, ctx)?;
2750+
let mut right = self.expression_for_abstract(right, ctx)?;
2751+
2752+
// Convert `scalar op vector` to `vector op vector` by introducing
2753+
// `Splat` expressions.
2754+
ctx.binary_op_splat(op, &mut left, &mut right)?;
2755+
2756+
// Apply automatic conversions.
2757+
match op {
2758+
ir::BinaryOperator::ShiftLeft | ir::BinaryOperator::ShiftRight => {
2759+
// Shift operators require the right operand to be `u32` or
2760+
// `vecN<u32>`. We can let the validator sort out vector length
2761+
// issues, but the right operand must be, or convert to, a u32 leaf
2762+
// scalar.
2763+
right =
2764+
ctx.try_automatic_conversion_for_leaf_scalar(right, ir::Scalar::U32, span)?;
2765+
2766+
// Additionally, we must concretize the left operand if the right operand
2767+
// is not a const-expression.
2768+
// See https://www.w3.org/TR/WGSL/#overload-resolution-section.
2769+
//
2770+
// 2. Eliminate any candidate where one of its subexpressions resolves to
2771+
// an abstract type after feasible automatic conversions, but another of
2772+
// the candidate’s subexpressions is not a const-expression.
2773+
//
2774+
// We only have to explicitly do so for shifts as their operands may be
2775+
// of different types - for other binary ops this is achieved by finding
2776+
// the conversion consensus for both operands.
2777+
if !ctx.is_const(right) {
2778+
left = ctx.concretize(left)?;
2779+
}
2780+
}
25782781

2579-
// All other operators follow the same pattern: reconcile the
2580-
// scalar leaf types. If there's no reconciliation possible,
2581-
// leave the expressions as they are: validation will report the
2582-
// problem.
2583-
_ => {
2584-
ctx.grow_types(left)?;
2585-
ctx.grow_types(right)?;
2586-
if let Ok(consensus_scalar) =
2587-
ctx.automatic_conversion_consensus([left, right].iter())
2588-
{
2589-
ctx.convert_to_leaf_scalar(&mut left, consensus_scalar)?;
2590-
ctx.convert_to_leaf_scalar(&mut right, consensus_scalar)?;
2782+
// All other operators follow the same pattern: reconcile the
2783+
// scalar leaf types. If there's no reconciliation possible,
2784+
// leave the expressions as they are: validation will report the
2785+
// problem.
2786+
_ => {
2787+
ctx.grow_types(left)?;
2788+
ctx.grow_types(right)?;
2789+
if let Ok(consensus_scalar) =
2790+
ctx.automatic_conversion_consensus([left, right].iter())
2791+
{
2792+
ctx.convert_to_leaf_scalar(&mut left, consensus_scalar)?;
2793+
ctx.convert_to_leaf_scalar(&mut right, consensus_scalar)?;
2794+
}
25912795
}
25922796
}
2593-
}
25942797

2595-
Ok(Typed::Plain(ir::Expression::Binary { op, left, right }))
2798+
Ok(Typed::Plain(ir::Expression::Binary { op, left, right }))
2799+
}
25962800
}
25972801

25982802
/// Generate Naga IR for call expressions and statements, and type

naga/tests/in/wgsl/operators.wgsl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,11 @@ fn bool_cast(x: vec3<f32>) -> vec3<f32> {
4040
return vec3<f32>(y);
4141
}
4242

43+
fn p() -> bool { return true; }
44+
fn q() -> bool { return false; }
45+
fn r() -> bool { return true; }
46+
fn s() -> bool { return false; }
47+
4348
fn logical() {
4449
let t = true;
4550
let f = false;
@@ -55,6 +60,7 @@ fn logical() {
5560
let bitwise_or1 = vec3(t) | vec3(f);
5661
let bitwise_and0 = t & f;
5762
let bitwise_and1 = vec4(t) & vec4(f);
63+
let short_circuit = (p() || q()) && (r() || s());
5864
}
5965

6066
fn arithmetic() {

0 commit comments

Comments
 (0)