Skip to content

Commit 0e1044c

Browse files
committed
refactor: streamline key management methods in AgentSetRegistry and enforce model consistency
1 parent 2935894 commit 0e1044c

File tree

2 files changed

+33
-75
lines changed

2 files changed

+33
-75
lines changed

mesa_frames/abstract/agentsetregistry.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -538,6 +538,13 @@ def __setitem__(
538538
- For name keys, the key is authoritative for the assigned set's name
539539
- For index keys, collisions on a different entry's name must raise
540540
"""
541+
if value.model is not self.model:
542+
raise TypeError("Assigned AgentSet must belong to the same model")
543+
if isinstance(key, (int, str)):
544+
# Delegate to replace() so subclasses centralize invariant handling.
545+
self.replace({key: value}, inplace=True, atomic=True)
546+
return
547+
raise TypeError("Key must be int index or str name")
541548

542549
@abstractmethod
543550
def __getattr__(self, name: str) -> Any | dict[str, Any]:
@@ -568,14 +575,23 @@ def __str__(self) -> str:
568575
"""Get a string representation of the AgentSets in the registry."""
569576
...
570577

571-
@abstractmethod
572578
def keys(
573579
self, *, key_by: KeyBy = "name"
574580
) -> Iterable[str | int | type[mesa_frames.abstract.agentset.AbstractAgentSet]]:
575581
"""Iterate keys for contained AgentSets (by name|index|type)."""
576-
...
582+
if key_by == "index":
583+
yield from range(len(self))
584+
return
585+
if key_by == "type":
586+
for agentset in self:
587+
yield type(agentset)
588+
return
589+
if key_by != "name":
590+
raise ValueError("key_by must be 'name'|'index'|'type'")
591+
for agentset in self:
592+
if agentset.name is not None:
593+
yield agentset.name
577594

578-
@abstractmethod
579595
def items(
580596
self, *, key_by: KeyBy = "name"
581597
) -> Iterable[
@@ -585,12 +601,23 @@ def items(
585601
]
586602
]:
587603
"""Iterate (key, AgentSet) pairs for contained sets."""
588-
...
604+
if key_by == "index":
605+
for idx, agentset in enumerate(self):
606+
yield idx, agentset
607+
return
608+
if key_by == "type":
609+
for agentset in self:
610+
yield type(agentset), agentset
611+
return
612+
if key_by != "name":
613+
raise ValueError("key_by must be 'name'|'index'|'type'")
614+
for agentset in self:
615+
if agentset.name is not None:
616+
yield agentset.name, agentset
589617

590-
@abstractmethod
591618
def values(self) -> Iterable[mesa_frames.abstract.agentset.AbstractAgentSet]:
592619
"""Iterate contained AgentSets (values view)."""
593-
...
620+
yield from self
594621

595622
@property
596623
def model(self) -> mesa_frames.concrete.model.Model:

mesa_frames/concrete/agentsetregistry.py

Lines changed: 0 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -669,78 +669,9 @@ def __repr__(self) -> str:
669669
def __reversed__(self) -> Iterator[AgentSet]:
670670
return reversed(self._agentsets)
671671

672-
def __setitem__(self, key: int | str, value: AgentSet) -> None:
673-
"""Assign/replace a single AgentSet at an index or name.
674-
675-
Enforces name uniqueness and model consistency.
676-
"""
677-
if value.model is not self.model:
678-
raise TypeError("Assigned AgentSet must belong to the same model")
679-
if isinstance(key, int):
680-
if value.name is not None:
681-
for i, s in enumerate(self._agentsets):
682-
if i != key and s.name == value.name:
683-
raise ValueError(
684-
f"Duplicate agent set name disallowed: {value.name}"
685-
)
686-
self._agentsets[key] = value
687-
elif isinstance(key, str):
688-
try:
689-
value.rename(key)
690-
except Exception:
691-
if hasattr(value, "_name"):
692-
setattr(value, "_name", key)
693-
idx = None
694-
for i, s in enumerate(self._agentsets):
695-
if s.name == key:
696-
idx = i
697-
break
698-
if idx is None:
699-
self._agentsets.append(value)
700-
else:
701-
self._agentsets[idx] = value
702-
else:
703-
raise TypeError("Key must be int index or str name")
704-
# Recompute ids cache
705-
self._recompute_ids()
706-
707672
def __str__(self) -> str:
708673
return "\n".join([str(agentset) for agentset in self._agentsets])
709674

710-
def keys(self, *, key_by: KeyBy = "name") -> Iterable[Any]:
711-
if key_by not in ("name", "index", "type"):
712-
raise ValueError("key_by must be 'name'|'index'|'type'")
713-
if key_by == "index":
714-
yield from range(len(self._agentsets))
715-
return
716-
if key_by == "type":
717-
for s in self._agentsets:
718-
yield type(s)
719-
return
720-
# name
721-
for s in self._agentsets:
722-
if s.name is not None:
723-
yield s.name
724-
725-
def items(self, *, key_by: KeyBy = "name") -> Iterable[tuple[Any, AgentSet]]:
726-
if key_by not in ("name", "index", "type"):
727-
raise ValueError("key_by must be 'name'|'index'|'type'")
728-
if key_by == "index":
729-
for i, s in enumerate(self._agentsets):
730-
yield i, s
731-
return
732-
if key_by == "type":
733-
for s in self._agentsets:
734-
yield type(s), s
735-
return
736-
# name
737-
for s in self._agentsets:
738-
if s.name is not None:
739-
yield s.name, s
740-
741-
def values(self) -> Iterable[AgentSet]:
742-
return iter(self._agentsets)
743-
744675
@property
745676
def ids(self) -> pl.Series:
746677
"""Public view of all agent unique_id values across contained sets."""

0 commit comments

Comments
 (0)