Skip to content

Commit 2805e6f

Browse files
allanrenucciGoogle-ML-Automation
authored andcommitted
[Mosaic GPU][NFC] Make Equation a type of Constraint.
This change renames the `Equation` class to `Equals` and includes it in the `Constraint` type alias. The `EquationSystem` now stores all equality constraints within its `constraints` list, removing the separate `equations` list. PiperOrigin-RevId: 836228575
1 parent 0b0073c commit 2805e6f

File tree

4 files changed

+123
-158
lines changed

4 files changed

+123
-158
lines changed

jax/experimental/mosaic/gpu/equations.py

Lines changed: 51 additions & 99 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,21 @@ def reduce_expression(
340340
case _:
341341
assert_never(expr)
342342

343+
@dataclasses.dataclass(frozen=True)
344+
class Equals:
345+
"""States that `lhs` and `rhs` are equal."""
346+
lhs: Expression
347+
rhs: Expression
348+
349+
def holds(self) -> bool | None:
350+
if self.lhs == self.rhs:
351+
return True
352+
if isinstance(self.lhs, Constant) and isinstance(self.rhs, Constant):
353+
return False
354+
return None
355+
356+
def __str__(self):
357+
return f"Equals({self.lhs} == {self.rhs})"
343358

344359
_SUPPORTED_TILED_RELAYOUTS = frozenset([
345360
# Transposed layouts.
@@ -548,7 +563,7 @@ def __str__(self):
548563
return f"{self.tiling_multiple} % {self.expr} == 0"
549564

550565

551-
Constraint = Relayout | NotOfType | IsTransferable | Divides
566+
Constraint = Equals | Relayout | NotOfType | IsTransferable | Divides
552567

553568

554569
def reduce_constraint(
@@ -558,6 +573,14 @@ def reduce_constraint(
558573

559574
new_constraint: Constraint
560575
match constraint:
576+
case Equals(lhs=lhs, rhs=rhs):
577+
lhs_red = reduce_expression(lhs, assignments)
578+
if isinstance(lhs_red, Unsatisfiable):
579+
return Unsatisfiable()
580+
rhs_red = reduce_expression(rhs, assignments)
581+
if isinstance(rhs_red, Unsatisfiable):
582+
return Unsatisfiable()
583+
new_constraint = Equals(lhs_red, rhs_red)
561584
case Relayout(source=source, target=target):
562585
source_red = reduce_expression(source, assignments)
563586
target_red = reduce_expression(target, assignments)
@@ -591,51 +614,6 @@ def reduce_constraint(
591614
return Tautological() if constraint_holds else Unsatisfiable()
592615

593616

594-
@dataclasses.dataclass(frozen=True)
595-
class Equation:
596-
lhs: Expression
597-
rhs: Expression
598-
599-
def __str__(self):
600-
return f"{self.lhs} == {self.rhs}"
601-
602-
603-
def reduce_equation(
604-
eq: Equation, assignments: dict[Variable, Constant]
605-
) -> Solution:
606-
"""Reduces an equation.
607-
608-
Args:
609-
eq: the equation to reduce.
610-
assignments: a set of known variable assignments.
611-
612-
Returns:
613-
A Solution object representing the result of the evaluation. That is:
614-
- Unsatisfiable(): if the equation is unsatisfiable.
615-
- Tautological(): if the equation is tautological.
616-
- Satisfiable(): if the equation is satisfiable by assigning a value to
617-
a variable.
618-
- Unknown(): if the equation contains remaining unknown variables.
619-
"""
620-
lhs = reduce_expression(eq.lhs, assignments)
621-
rhs = reduce_expression(eq.rhs, assignments)
622-
match (lhs, rhs):
623-
case (Variable(), Constant()):
624-
return SatisfiedBy((lhs, rhs))
625-
case (Constant(), Variable()):
626-
return SatisfiedBy((rhs, lhs))
627-
case (Constant(), Constant()) if lhs != rhs:
628-
return Unsatisfiable()
629-
case _ if isinstance(lhs, Unsatisfiable) or isinstance(rhs, Unsatisfiable):
630-
return Unsatisfiable()
631-
case _ if lhs == rhs:
632-
return Tautological()
633-
case _:
634-
# This is covered above. Add a check here to appease the type checker.
635-
assert not isinstance(lhs, Unsatisfiable) and not isinstance(rhs, Unsatisfiable)
636-
return Unknown(Equation(lhs, rhs))
637-
638-
639617
@dataclasses.dataclass
640618
class EquationSystem:
641619
"""An equation system contains a set of equations and assignments.
@@ -650,7 +628,6 @@ class EquationSystem:
650628
assignments: dict[Variable, Constant] = dataclasses.field(
651629
default_factory=dict
652630
)
653-
equations: list[Equation] = dataclasses.field(default_factory=list)
654631
constraints: Sequence[Constraint] = dataclasses.field(default_factory=list)
655632

656633
def unknowns(self) -> list[Variable]:
@@ -681,11 +658,11 @@ def extract_variables(expr: Expression) -> None:
681658
extract_variables(e)
682659
case _:
683660
assert_never(expr)
684-
for equation in self.equations:
685-
extract_variables(equation.lhs)
686-
extract_variables(equation.rhs)
687661
for constraint in self.constraints:
688662
match constraint:
663+
case Equals(lhs=lhs, rhs=rhs):
664+
extract_variables(lhs)
665+
extract_variables(rhs)
689666
case Relayout(source=source, target=target):
690667
extract_variables(source)
691668
extract_variables(target)
@@ -706,7 +683,6 @@ def __and__(self, other: EquationSystem) -> EquationSystem | Unsatisfiable:
706683
return Unsatisfiable()
707684
return EquationSystem(
708685
assignments=self.assignments | other.assignments,
709-
equations=self.equations + other.equations,
710686
constraints=[*self.constraints, *other.constraints],
711687
)
712688

@@ -715,9 +691,6 @@ def __str__(self):
715691
r += " assignments:\n"
716692
for assignment, constant in self.assignments.items():
717693
r += f" {assignment}{constant}\n"
718-
r += " equations:\n"
719-
for equation in self.equations:
720-
r += f" {equation}\n"
721694
r += " constraints:\n"
722695
for constraint in self.constraints:
723696
r += f" {constraint}\n"
@@ -729,16 +702,6 @@ def __and__(self, other: EquationSystem | Unsatisfiable) -> Unsatisfiable:
729702
return self
730703

731704

732-
@dataclasses.dataclass(frozen=True)
733-
class SatisfiedBy:
734-
assignment: tuple[Variable, Constant]
735-
736-
737-
@dataclasses.dataclass(frozen=True)
738-
class Unknown:
739-
equation: Equation
740-
741-
742705
class Tautological:
743706
...
744707

@@ -756,14 +719,6 @@ def non_splat_variables(
756719
return vars
757720

758721

759-
# The result of reducing an equation---and by extension, a system of
760-
# equations. An equation can either be unsatisfiable (i.e. there exists no
761-
# assignment for which it holds), satisfied by an assignment, unknown (i.e.
762-
# still undetermined), or tautological (i.e. the equation is guaranteed to
763-
# hold for any assignment).
764-
Solution = Unsatisfiable | SatisfiedBy | Unknown | Tautological
765-
766-
767722
def _has_relayout_of_non_splat_to_splat(constraints: Sequence[Constraint]) -> bool:
768723
"""Returns whether the constraints imply a non-splat to splat relayout.
769724
@@ -854,11 +809,14 @@ def union(v1: Variable, v2: Variable):
854809
parent[root2] = root1
855810

856811
all_vars: set[Variable] = set()
857-
for eq in system.equations:
858-
if isinstance(eq.lhs, Variable) and isinstance(eq.rhs, Variable):
859-
all_vars.add(eq.lhs)
860-
all_vars.add(eq.rhs)
861-
union(eq.lhs, eq.rhs)
812+
for constraint in system.constraints:
813+
match constraint:
814+
case Equals(lhs=Variable() as lhs, rhs=Variable() as rhs):
815+
assert isinstance(lhs, Variable) # make pytype happy
816+
assert isinstance(rhs, Variable) # make pytype happy
817+
all_vars.add(lhs)
818+
all_vars.add(rhs)
819+
union(lhs, rhs)
862820

863821
# Group variables by their component representative.
864822
components: dict[Variable, list[Variable]] = {}
@@ -934,33 +892,28 @@ def _reduce_system_once(
934892
- None: if the equation system is not known unsatisfiable, but hasn't been
935893
reduced.
936894
"""
895+
assignments = equation_system.assignments
896+
constraints: list[Constraint] = []
937897
changed = False
938-
assignments: dict[Variable, Constant] = {}
939-
equations: list[Equation] = []
940-
for equation in equation_system.equations:
941-
match reduce_equation(equation, equation_system.assignments):
942-
case Unsatisfiable():
943-
return Unsatisfiable()
944-
case Tautological():
945-
changed = True
946-
case SatisfiedBy() as result:
947-
variable, expression = result.assignment
948-
if variable in assignments and assignments[variable] != expression:
949-
return Unsatisfiable()
950-
assignments[variable] = expression
951-
changed = True
952-
case Unknown(equation=reduced_equation):
953-
equations.append(reduced_equation)
954-
changed |= reduced_equation != equation
955-
case _ as never:
956-
assert_never(never)
957898

958-
assignments |= equation_system.assignments
959-
constraints: list[Constraint] = []
899+
def try_assign(var: Variable, cst: Constant) -> bool:
900+
if var in assignments and assignments[var] != cst:
901+
return False
902+
assignments[var] = cst
903+
return True
904+
960905
for constraint in equation_system.constraints:
961906
match reduce_constraint(constraint, assignments):
962907
case Unsatisfiable():
963908
return Unsatisfiable()
909+
case Equals(lhs=Variable() as var, rhs=Constant() as cst):
910+
if not try_assign(var, cst):
911+
return Unsatisfiable()
912+
changed = True
913+
case Equals(lhs=Constant() as cst, rhs=Variable() as var):
914+
if not try_assign(var, cst):
915+
return Unsatisfiable()
916+
changed = True
964917
case Tautological():
965918
changed = True
966919
case _ as new_constraint:
@@ -979,7 +932,6 @@ def _reduce_system_once(
979932
if changed:
980933
return EquationSystem(
981934
assignments=assignments | equation_system.assignments,
982-
equations=equations,
983935
constraints=constraints,
984936
)
985937
return None

jax/experimental/mosaic/gpu/layout_inference.py

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1052,12 +1052,12 @@ def _vector_reduction_equation_system(
10521052
return eqns.EquationSystem(), {in_variable: [in_variable.key]}, []
10531053

10541054

1055-
def _reduction_equation_and_hint(
1055+
def _reduction_constraint_and_hint(
10561056
larger: eqns.Variable,
10571057
smaller: eqns.Variable,
10581058
larger_shape: tuple[int, ...],
1059-
reduction_dims: tuple[int, ...]
1060-
) -> tuple[eqns.Equation, Hint]:
1059+
reduction_dims: tuple[int, ...],
1060+
) -> tuple[eqns.Constraint, Hint]:
10611061
reduce_expr = eqns.Reduce(larger, reduction_dims)
10621062
# There are always many options for broadcasting a layout, so we can only
10631063
# derive a broadcast hint in the out_variable -> source_variable direction.
@@ -1066,7 +1066,7 @@ def _reduction_equation_and_hint(
10661066
)
10671067
broadcast_expr = eqns.BroadcastInDim(smaller, broadcast_dims, larger_shape)
10681068
broadcast_hint = Hint(variable=larger, expression=broadcast_expr)
1069-
return eqns.Equation(lhs=smaller, rhs=reduce_expr), broadcast_hint
1069+
return eqns.Equals(lhs=smaller, rhs=reduce_expr), broadcast_hint
10701070

10711071

10721072
@_add_equation_system_derivation_rule(vector.MultiDimReductionOp)
@@ -1081,15 +1081,17 @@ def _multi_dim_reduction_equation_system(
10811081
source_variable = eqns.Variable(source)
10821082
out_variable = eqns.Variable(out)
10831083

1084-
reduction_equation, broadcast_hint = _reduction_equation_and_hint(
1085-
source_variable, out_variable,
1086-
tuple(ir.ShapedType(op.source.type).shape), tuple(op.reduction_dims)
1084+
reduction_constraint, broadcast_hint = _reduction_constraint_and_hint(
1085+
source_variable,
1086+
out_variable,
1087+
tuple(ir.ShapedType(op.source.type).shape),
1088+
tuple(op.reduction_dims),
10871089
)
10881090
# TODO(bchetioui): in the future, we may need to add rules that prevent
10891091
# strided layouts from being chosen---since trying to reduce a strided layout
10901092
# may cause us to raise an Exception at the moment.
10911093
return (
1092-
eqns.EquationSystem(equations=[reduction_equation]),
1094+
eqns.EquationSystem(constraints=[reduction_constraint]),
10931095
{source_variable: [source], out_variable: [acc, out]},
10941096
[broadcast_hint],
10951097
)
@@ -1108,13 +1110,16 @@ def _broadcast_in_dim_equation_system(
11081110
i for i in range(len(out_shape)) if i not in op.broadcast_dimensions
11091111
)
11101112

1111-
reduction_equation, broadcast_hint = _reduction_equation_and_hint(
1113+
reduction_constraint, broadcast_hint = _reduction_constraint_and_hint(
11121114
out_variable, source_variable, out_shape, reduction_dims
11131115
)
11141116

11151117
return (
1116-
eqns.EquationSystem(equations=[reduction_equation]),
1117-
{source_variable: [source_variable.key], out_variable: [out_variable.key]},
1118+
eqns.EquationSystem(constraints=[reduction_constraint]),
1119+
{
1120+
source_variable: [source_variable.key],
1121+
out_variable: [out_variable.key],
1122+
},
11181123
[broadcast_hint],
11191124
)
11201125

@@ -1153,9 +1158,9 @@ def _shape_cast_equation_system(
11531158

11541159
return (
11551160
eqns.EquationSystem(
1156-
equations=[
1157-
eqns.Equation(lhs=out_variable, rhs=in_to_out),
1158-
eqns.Equation(lhs=in_variable, rhs=out_to_in),
1161+
constraints=[
1162+
eqns.Equals(lhs=out_variable, rhs=in_to_out),
1163+
eqns.Equals(lhs=in_variable, rhs=out_to_in),
11591164
],
11601165
),
11611166
{in_variable: [in_variable.key], out_variable: [out_variable.key]},
@@ -1196,7 +1201,7 @@ def _custom_primitive_equation_system(
11961201
op: mgpu.CustomPrimitiveOp,
11971202
) -> tuple[eqns.EquationSystem, ValueSitesForVariable, list[Hint]]:
11981203
assignments: dict[eqns.Variable, eqns.Constant] = {}
1199-
equations: list[eqns.Equation] = []
1204+
constraints: list[eqns.Constraint] = []
12001205
in_layouts = iter(op.in_layouts)
12011206
in_transforms = iter(op.in_transforms)
12021207
variables: list[eqns.Variable] = []
@@ -1220,7 +1225,7 @@ def _custom_primitive_equation_system(
12201225
value_site = ValueSite(op, VariableType.OPERAND, i)
12211226
source_var = ctx.producer_ref(value_site)
12221227
v = eqns.Variable(value_site)
1223-
equations.append(eqns.Equation(lhs=source_var, rhs=v))
1228+
constraints.append(eqns.Equals(lhs=source_var, rhs=v))
12241229
variables.append(v)
12251230
transforms = next(in_transforms)
12261231
ref_ty = value_site.value.type
@@ -1236,7 +1241,7 @@ def _custom_primitive_equation_system(
12361241
layouts_lib.from_layout_attr(next(out_layouts))
12371242
)
12381243
return (
1239-
eqns.EquationSystem(equations=equations, assignments=assignments),
1244+
eqns.EquationSystem(assignments, constraints),
12401245
{v: [v.key] for v in variables},
12411246
[],
12421247
)
@@ -1512,11 +1517,11 @@ def _memref_transpose_op_equation_system(
15121517
return (eqns.EquationSystem(), {source_var: [source, dest]}, [])
15131518

15141519
dest_var = eqns.Variable(dest)
1515-
equations = [
1516-
eqns.Equation(source_var, eqns.Transpose(dest_var)),
1517-
eqns.Equation(eqns.Transpose(source_var), dest_var),
1520+
constraints = [
1521+
eqns.Equals(eqns.Transpose(source_var), dest_var),
1522+
eqns.Equals(source_var, eqns.Transpose(dest_var)),
15181523
]
1519-
system = eqns.EquationSystem(equations=equations)
1524+
system = eqns.EquationSystem(constraints=constraints)
15201525
return system, {source_var: [source], dest_var: [dest]}, []
15211526

15221527

0 commit comments

Comments
 (0)