1717 "FragmentTransformer" ,
1818 "TransformedElaboratable" ,
1919 "DomainCollector" , "DomainRenamer" , "DomainLowerer" ,
20+ "LHSMaskCollector" ,
2021 "ResetInserter" , "EnableInserter" ]
2122
2223
@@ -575,6 +576,71 @@ def on_fragment(self, fragment):
575576 return super ().on_fragment (fragment )
576577
577578
579+ class LHSMaskCollector :
580+ def __init__ (self ):
581+ self .lhs = SignalDict ()
582+
583+ def visit_stmt (self , stmt ):
584+ if type (stmt ) is Assign :
585+ self .visit_value (stmt .lhs , ~ 0 )
586+ elif type (stmt ) is Switch :
587+ for (_ , substmt , _ ) in stmt .cases :
588+ self .visit_stmt (substmt )
589+ elif type (stmt ) in (Property , Print ):
590+ pass
591+ elif isinstance (stmt , Iterable ):
592+ for substmt in stmt :
593+ self .visit_stmt (substmt )
594+ else :
595+ assert False # :nocov:
596+
597+ def visit_value (self , value , mask ):
598+ if type (value ) in (Signal , ClockSignal , ResetSignal ):
599+ mask &= (1 << len (value )) - 1
600+ self .lhs .setdefault (value , 0 )
601+ self .lhs [value ] |= mask
602+ elif type (value ) is Operator :
603+ assert value .operator in ("s" , "u" )
604+ self .visit_value (value .operands [0 ], mask )
605+ elif type (value ) is Slice :
606+ slice_mask = (1 << value .stop ) - (1 << value .start )
607+ mask <<= value .start
608+ mask &= slice_mask
609+ self .visit_value (value .value , mask )
610+ elif type (value ) is Part :
611+ # Could be more accurate, but if you're relying on such details, you're not seeing
612+ # the Light of Heaven.
613+ self .visit_value (value .value , ~ 0 )
614+ elif type (value ) is Concat :
615+ for part in value .parts :
616+ self .visit_value (part , mask )
617+ mask >>= len (part )
618+ elif type (value ) is SwitchValue :
619+ for (_ , subvalue ) in value .cases :
620+ self .visit_value (subvalue , mask )
621+ else :
622+ assert False # :nocov:
623+
624+ def chunks (self ):
625+ for signal , mask in self .lhs .items ():
626+ if mask == (1 << len (signal )) - 1 :
627+ yield signal , 0 , None
628+ else :
629+ start = 0
630+ while start < len (signal ):
631+ if ((mask >> start ) & 1 ) == 0 :
632+ start += 1
633+ else :
634+ stop = start
635+ while stop < len (signal ) and ((mask >> stop ) & 1 ) == 1 :
636+ stop += 1
637+ yield (signal , start , stop )
638+ start = stop
639+
640+ def masks (self ):
641+ yield from self .lhs .items ()
642+
643+
578644class _ControlInserter (FragmentTransformer ):
579645 def __init__ (self , controls ):
580646 self .src_loc = None
@@ -589,10 +655,9 @@ def on_fragment(self, fragment):
589655 for domain , statements in fragment .statements .items ():
590656 if domain == "comb" or domain not in self .controls :
591657 continue
592- signals = SignalSet ()
593- for stmt in statements :
594- signals |= stmt ._lhs_signals ()
595- self ._insert_control (new_fragment , domain , signals )
658+ lhs_masks = LHSMaskCollector ()
659+ lhs_masks .visit_stmt (statements )
660+ self ._insert_control (new_fragment , domain , lhs_masks )
596661 return new_fragment
597662
598663 def _insert_control (self , fragment , domain , signals ):
@@ -604,13 +669,20 @@ def __call__(self, value, *, src_loc_at=0):
604669
605670
606671class ResetInserter (_ControlInserter ):
607- def _insert_control (self , fragment , domain , signals ):
608- stmts = [s .eq (Const (s .init , s .shape ())) for s in signals if not s .reset_less ]
672+ def _insert_control (self , fragment , domain , lhs_masks ):
673+ stmts = []
674+ for signal , start , stop in lhs_masks .chunks ():
675+ if signal .reset_less :
676+ continue
677+ if start == 0 and stop is None :
678+ stmts .append (signal .eq (Const (signal .init , signal .shape ())))
679+ else :
680+ stmts .append (signal [start :stop ].eq (Const (signal .init , signal .shape ())[start :stop ]))
609681 fragment .add_statements (domain , Switch (self .controls [domain ], [(1 , stmts , None )], src_loc = self .src_loc ))
610682
611683
612684class EnableInserter (_ControlInserter ):
613- def _insert_control (self , fragment , domain , signals ):
685+ def _insert_control (self , fragment , domain , _lhs_masks ):
614686 if domain in fragment .statements :
615687 fragment .statements [domain ] = _StatementList ([Switch (
616688 self .controls [domain ],
0 commit comments