@@ -340,6 +340,21 @@ def reduce_expression(
340340 case _:
341341 assert_never (expr )
342342
343+ @dataclasses .dataclass (frozen = True )
344+ class Equals :
345+ """States that `lhs` and `rhs` are equal."""
346+ lhs : Expression
347+ rhs : Expression
348+
349+ def holds (self ) -> bool | None :
350+ if self .lhs == self .rhs :
351+ return True
352+ if isinstance (self .lhs , Constant ) and isinstance (self .rhs , Constant ):
353+ return False
354+ return None
355+
356+ def __str__ (self ):
357+ return f"Equals({ self .lhs } == { self .rhs } )"
343358
344359_SUPPORTED_TILED_RELAYOUTS = frozenset ([
345360 # Transposed layouts.
@@ -548,7 +563,7 @@ def __str__(self):
548563 return f"{ self .tiling_multiple } % { self .expr } == 0"
549564
550565
551- Constraint = Relayout | NotOfType | IsTransferable | Divides
566+ Constraint = Equals | Relayout | NotOfType | IsTransferable | Divides
552567
553568
554569def reduce_constraint (
@@ -558,6 +573,14 @@ def reduce_constraint(
558573
559574 new_constraint : Constraint
560575 match constraint :
576+ case Equals (lhs = lhs , rhs = rhs ):
577+ lhs_red = reduce_expression (lhs , assignments )
578+ if isinstance (lhs_red , Unsatisfiable ):
579+ return Unsatisfiable ()
580+ rhs_red = reduce_expression (rhs , assignments )
581+ if isinstance (rhs_red , Unsatisfiable ):
582+ return Unsatisfiable ()
583+ new_constraint = Equals (lhs_red , rhs_red )
561584 case Relayout (source = source , target = target ):
562585 source_red = reduce_expression (source , assignments )
563586 target_red = reduce_expression (target , assignments )
@@ -591,51 +614,6 @@ def reduce_constraint(
591614 return Tautological () if constraint_holds else Unsatisfiable ()
592615
593616
594- @dataclasses .dataclass (frozen = True )
595- class Equation :
596- lhs : Expression
597- rhs : Expression
598-
599- def __str__ (self ):
600- return f"{ self .lhs } == { self .rhs } "
601-
602-
603- def reduce_equation (
604- eq : Equation , assignments : dict [Variable , Constant ]
605- ) -> Solution :
606- """Reduces an equation.
607-
608- Args:
609- eq: the equation to reduce.
610- assignments: a set of known variable assignments.
611-
612- Returns:
613- A Solution object representing the result of the evaluation. That is:
614- - Unsatisfiable(): if the equation is unsatisfiable.
615- - Tautological(): if the equation is tautological.
616- - Satisfiable(): if the equation is satisfiable by assigning a value to
617- a variable.
618- - Unknown(): if the equation contains remaining unknown variables.
619- """
620- lhs = reduce_expression (eq .lhs , assignments )
621- rhs = reduce_expression (eq .rhs , assignments )
622- match (lhs , rhs ):
623- case (Variable (), Constant ()):
624- return SatisfiedBy ((lhs , rhs ))
625- case (Constant (), Variable ()):
626- return SatisfiedBy ((rhs , lhs ))
627- case (Constant (), Constant ()) if lhs != rhs :
628- return Unsatisfiable ()
629- case _ if isinstance (lhs , Unsatisfiable ) or isinstance (rhs , Unsatisfiable ):
630- return Unsatisfiable ()
631- case _ if lhs == rhs :
632- return Tautological ()
633- case _:
634- # This is covered above. Add a check here to appease the type checker.
635- assert not isinstance (lhs , Unsatisfiable ) and not isinstance (rhs , Unsatisfiable )
636- return Unknown (Equation (lhs , rhs ))
637-
638-
639617@dataclasses .dataclass
640618class EquationSystem :
641619 """An equation system contains a set of equations and assignments.
@@ -650,7 +628,6 @@ class EquationSystem:
650628 assignments : dict [Variable , Constant ] = dataclasses .field (
651629 default_factory = dict
652630 )
653- equations : list [Equation ] = dataclasses .field (default_factory = list )
654631 constraints : Sequence [Constraint ] = dataclasses .field (default_factory = list )
655632
656633 def unknowns (self ) -> list [Variable ]:
@@ -681,11 +658,11 @@ def extract_variables(expr: Expression) -> None:
681658 extract_variables (e )
682659 case _:
683660 assert_never (expr )
684- for equation in self .equations :
685- extract_variables (equation .lhs )
686- extract_variables (equation .rhs )
687661 for constraint in self .constraints :
688662 match constraint :
663+ case Equals (lhs = lhs , rhs = rhs ):
664+ extract_variables (lhs )
665+ extract_variables (rhs )
689666 case Relayout (source = source , target = target ):
690667 extract_variables (source )
691668 extract_variables (target )
@@ -706,7 +683,6 @@ def __and__(self, other: EquationSystem) -> EquationSystem | Unsatisfiable:
706683 return Unsatisfiable ()
707684 return EquationSystem (
708685 assignments = self .assignments | other .assignments ,
709- equations = self .equations + other .equations ,
710686 constraints = [* self .constraints , * other .constraints ],
711687 )
712688
@@ -715,9 +691,6 @@ def __str__(self):
715691 r += " assignments:\n "
716692 for assignment , constant in self .assignments .items ():
717693 r += f" { assignment } ⟵ { constant } \n "
718- r += " equations:\n "
719- for equation in self .equations :
720- r += f" { equation } \n "
721694 r += " constraints:\n "
722695 for constraint in self .constraints :
723696 r += f" { constraint } \n "
@@ -729,16 +702,6 @@ def __and__(self, other: EquationSystem | Unsatisfiable) -> Unsatisfiable:
729702 return self
730703
731704
732- @dataclasses .dataclass (frozen = True )
733- class SatisfiedBy :
734- assignment : tuple [Variable , Constant ]
735-
736-
737- @dataclasses .dataclass (frozen = True )
738- class Unknown :
739- equation : Equation
740-
741-
742705class Tautological :
743706 ...
744707
@@ -756,14 +719,6 @@ def non_splat_variables(
756719 return vars
757720
758721
759- # The result of reducing an equation---and by extension, a system of
760- # equations. An equation can either be unsatisfiable (i.e. there exists no
761- # assignment for which it holds), satisfied by an assignment, unknown (i.e.
762- # still undetermined), or tautological (i.e. the equation is guaranteed to
763- # hold for any assignment).
764- Solution = Unsatisfiable | SatisfiedBy | Unknown | Tautological
765-
766-
767722def _has_relayout_of_non_splat_to_splat (constraints : Sequence [Constraint ]) -> bool :
768723 """Returns whether the constraints imply a non-splat to splat relayout.
769724
@@ -854,11 +809,14 @@ def union(v1: Variable, v2: Variable):
854809 parent [root2 ] = root1
855810
856811 all_vars : set [Variable ] = set ()
857- for eq in system .equations :
858- if isinstance (eq .lhs , Variable ) and isinstance (eq .rhs , Variable ):
859- all_vars .add (eq .lhs )
860- all_vars .add (eq .rhs )
861- union (eq .lhs , eq .rhs )
812+ for constraint in system .constraints :
813+ match constraint :
814+ case Equals (lhs = Variable () as lhs , rhs = Variable () as rhs ):
815+ assert isinstance (lhs , Variable ) # make pytype happy
816+ assert isinstance (rhs , Variable ) # make pytype happy
817+ all_vars .add (lhs )
818+ all_vars .add (rhs )
819+ union (lhs , rhs )
862820
863821 # Group variables by their component representative.
864822 components : dict [Variable , list [Variable ]] = {}
@@ -934,33 +892,28 @@ def _reduce_system_once(
934892 - None: if the equation system is not known unsatisfiable, but hasn't been
935893 reduced.
936894 """
895+ assignments = equation_system .assignments
896+ constraints : list [Constraint ] = []
937897 changed = False
938- assignments : dict [Variable , Constant ] = {}
939- equations : list [Equation ] = []
940- for equation in equation_system .equations :
941- match reduce_equation (equation , equation_system .assignments ):
942- case Unsatisfiable ():
943- return Unsatisfiable ()
944- case Tautological ():
945- changed = True
946- case SatisfiedBy () as result :
947- variable , expression = result .assignment
948- if variable in assignments and assignments [variable ] != expression :
949- return Unsatisfiable ()
950- assignments [variable ] = expression
951- changed = True
952- case Unknown (equation = reduced_equation ):
953- equations .append (reduced_equation )
954- changed |= reduced_equation != equation
955- case _ as never :
956- assert_never (never )
957898
958- assignments |= equation_system .assignments
959- constraints : list [Constraint ] = []
899+ def try_assign (var : Variable , cst : Constant ) -> bool :
900+ if var in assignments and assignments [var ] != cst :
901+ return False
902+ assignments [var ] = cst
903+ return True
904+
960905 for constraint in equation_system .constraints :
961906 match reduce_constraint (constraint , assignments ):
962907 case Unsatisfiable ():
963908 return Unsatisfiable ()
909+ case Equals (lhs = Variable () as var , rhs = Constant () as cst ):
910+ if not try_assign (var , cst ):
911+ return Unsatisfiable ()
912+ changed = True
913+ case Equals (lhs = Constant () as cst , rhs = Variable () as var ):
914+ if not try_assign (var , cst ):
915+ return Unsatisfiable ()
916+ changed = True
964917 case Tautological ():
965918 changed = True
966919 case _ as new_constraint :
@@ -979,7 +932,6 @@ def _reduce_system_once(
979932 if changed :
980933 return EquationSystem (
981934 assignments = assignments | equation_system .assignments ,
982- equations = equations ,
983935 constraints = constraints ,
984936 )
985937 return None
0 commit comments