@@ -232,9 +232,11 @@ def contains(
232232 agents : PolarsIdsLike ,
233233 ) -> bool | pl .Series :
234234 if isinstance (agents , pl .Series ):
235- return agents .is_in (self ._df ["unique_id" ])
235+ return agents .is_in (self ._df ["unique_id" ]. implode () )
236236 elif isinstance (agents , Collection ) and not isinstance (agents , str ):
237- return pl .Series (agents , dtype = pl .UInt64 ).is_in (self ._df ["unique_id" ])
237+ return pl .Series (agents , dtype = pl .UInt64 ).is_in (
238+ self ._df ["unique_id" ].implode ()
239+ )
238240 else :
239241 return agents in self ._df ["unique_id" ]
240242
@@ -322,7 +324,7 @@ def remove(self, agents: PolarsIdsLike | AgentMask, inplace: bool = True) -> Sel
322324 # Normalize to Series of unique_ids
323325 ids = obj ._df_index (obj ._get_masked_df (agents ), "unique_id" )
324326 # Validate presence
325- if not ids .is_in (obj ._df ["unique_id" ]).all ():
327+ if not ids .is_in (obj ._df ["unique_id" ]. implode () ).all ():
326328 raise KeyError ("Some 'unique_id' of mask are not present in this AgentSet." )
327329 # Remove by ids
328330 return obj ._discard (ids )
@@ -396,8 +398,8 @@ def select(
396398 if filter_func :
397399 mask = mask & filter_func (obj )
398400 if n is not None :
399- mask = ( obj ._df ["unique_id" ]) .is_in (
400- obj ._df .filter (mask ).sample (n )["unique_id" ]
401+ mask = obj ._df ["unique_id" ].is_in (
402+ obj ._df .filter (mask ).sample (n )["unique_id" ]. implode ()
401403 )
402404 if negate :
403405 mask = mask .not_ ()
@@ -456,7 +458,9 @@ def _concatenate_agentsets(
456458 for obj in iter (agentsets ):
457459 # Remove agents that are already in the final DataFrame
458460 final_dfs .append (
459- obj ._df .filter (pl .col ("unique_id" ).is_in (final_indices ).not_ ())
461+ obj ._df .filter (
462+ pl .col ("unique_id" ).is_in (final_indices .implode ()).not_ ()
463+ )
460464 )
461465 # Add the indices of the active agents of current AgentSet
462466 final_active_indices .append (obj ._df .filter (obj ._mask )["unique_id" ])
@@ -476,13 +480,13 @@ def _concatenate_agentsets(
476480 final_active_index = pl .concat (
477481 [obj ._df .filter (obj ._mask )["unique_id" ] for obj in agentsets ]
478482 )
479- final_mask = final_df ["unique_id" ].is_in (final_active_index )
483+ final_mask = final_df ["unique_id" ].is_in (final_active_index . implode () )
480484 self ._df = final_df
481485 self ._mask = final_mask
482486 # If some ids were removed in the do-method, we need to remove them also from final_df
483487 if not isinstance (original_masked_index , type (None )):
484488 ids_to_remove = original_masked_index .filter (
485- original_masked_index .is_in (self ._df ["unique_id" ]).not_ ()
489+ original_masked_index .is_in (self ._df ["unique_id" ]. implode () ).not_ ()
486490 )
487491 if not ids_to_remove .is_empty ():
488492 self .remove (ids_to_remove , inplace = True )
@@ -499,7 +503,7 @@ def bool_mask_from_series(mask: pl.Series) -> pl.Series:
499503 and len (mask ) == len (self ._df )
500504 ):
501505 return mask
502- return self ._df ["unique_id" ].is_in (mask )
506+ return self ._df ["unique_id" ].is_in (mask . implode () )
503507
504508 if isinstance (mask , pl .Expr ):
505509 return mask
@@ -532,13 +536,13 @@ def _get_masked_df(
532536 ):
533537 return self ._df .filter (mask )
534538 elif isinstance (mask , pl .DataFrame ):
535- if not mask ["unique_id" ].is_in (self ._df ["unique_id" ]).all ():
539+ if not mask ["unique_id" ].is_in (self ._df ["unique_id" ]. implode () ).all ():
536540 raise KeyError (
537541 "Some 'unique_id' of mask are not present in DataFrame 'unique_id'."
538542 )
539543 return mask .select ("unique_id" ).join (self ._df , on = "unique_id" , how = "left" )
540544 elif isinstance (mask , pl .Series ):
541- if not mask .is_in (self ._df ["unique_id" ]).all ():
545+ if not mask .is_in (self ._df ["unique_id" ]. implode () ).all ():
542546 raise KeyError (
543547 "Some 'unique_id' of mask are not present in DataFrame 'unique_id'."
544548 )
@@ -553,7 +557,7 @@ def _get_masked_df(
553557 mask_series = pl .Series (mask , dtype = pl .UInt64 )
554558 else :
555559 mask_series = pl .Series ([mask ], dtype = pl .UInt64 )
556- if not mask_series .is_in (self ._df ["unique_id" ]).all ():
560+ if not mask_series .is_in (self ._df ["unique_id" ]. implode () ).all ():
557561 raise KeyError (
558562 "Some 'unique_id' of mask are not present in DataFrame 'unique_id'."
559563 )
@@ -585,12 +589,13 @@ def _discard(self, ids: PolarsIdsLike) -> Self:
585589 def _update_mask (
586590 self , original_active_indices : pl .Series , new_indices : pl .Series | None = None
587591 ) -> None :
592+ original_active = original_active_indices .implode ()
588593 if new_indices is not None :
589- self ._mask = self ._df ["unique_id" ].is_in (
590- original_active_indices
591- ) | self . _df [ "unique_id" ].is_in (new_indices )
594+ self ._mask = self ._df ["unique_id" ].is_in (original_active ) | self . _df [
595+ "unique_id"
596+ ].is_in (new_indices . implode () )
592597 else :
593- self ._mask = self ._df ["unique_id" ].is_in (original_active_indices )
598+ self ._mask = self ._df ["unique_id" ].is_in (original_active )
594599
595600 def __getattr__ (self , key : str ) -> Any :
596601 if key == "name" :
0 commit comments