@@ -113,17 +113,17 @@ def postgres_uri():
113113
114114@pytest .fixture
115115def fix1_AgentSet () -> ExampleAgentSet1 :
116- return ExampleAgentSet1 (Model (seed = 1 ))
116+ return ExampleAgentSet1 (Model (seed = 1 ))
117117
118118
119119@pytest .fixture
120120def fix2_AgentSet () -> ExampleAgentSet2 :
121- return ExampleAgentSet2 (Model (seed = 1 ))
121+ return ExampleAgentSet2 (Model (seed = 1 ))
122122
123123
124124@pytest .fixture
125125def fix3_AgentSet () -> ExampleAgentSet3 :
126- return ExampleAgentSet3 (Model (seed = 1 ))
126+ return ExampleAgentSet3 (Model (seed = 1 ))
127127
128128
129129@pytest .fixture
@@ -176,8 +176,8 @@ def test__init__(self, fix1_model, postgres_uri):
176176 match = "Agent reporter for 'wealth' must be a string" ,
177177 ):
178178 model .test_dc = DataCollector (
179- model = model ,
180- agent_reporters = {"wealth" : lambda m : 1 } # This is now illegal
179+ model = model ,
180+ agent_reporters = {"wealth" : lambda m : 1 }, # This is now illegal
181181 )
182182
183183 with pytest .raises (
@@ -225,7 +225,7 @@ def test_collect(self, fix1_model):
225225 collected_data ["model" ]["max_wealth" ]
226226
227227 agent_df = collected_data ["agent" ]
228-
228+
229229 ## 3 agent sets * 4 agents/set = 12 rows
230230 assert agent_df .shape == (12 , 7 )
231231 assert set (agent_df .columns ) == {
@@ -237,17 +237,26 @@ def test_collect(self, fix1_model):
237237 "seed" ,
238238 "batch" ,
239239 }
240-
240+
241241 expected_wealth = [1 , 2 , 3 , 4 ] + [10 , 20 , 30 , 40 ] + [1 , 2 , 3 , 4 ]
242242 expected_age = [10 , 20 , 30 , 40 ] + [11 , 22 , 33 , 44 ] + [1 , 2 , 3 , 4 ]
243-
243+
244244 assert sorted (agent_df ["wealth" ].to_list ()) == sorted (expected_wealth )
245245 assert sorted (agent_df ["age" ].to_list ()) == sorted (expected_age )
246246
247247 type_counts = agent_df ["agent_type" ].value_counts (sort = True )
248- assert type_counts .filter (pl .col ("agent_type" ) == "ExampleAgentSet1" )["count" ][0 ] == 4
249- assert type_counts .filter (pl .col ("agent_type" ) == "ExampleAgentSet2" )["count" ][0 ] == 4
250- assert type_counts .filter (pl .col ("agent_type" ) == "ExampleAgentSet3" )["count" ][0 ] == 4
248+ assert (
249+ type_counts .filter (pl .col ("agent_type" ) == "ExampleAgentSet1" )["count" ][0 ]
250+ == 4
251+ )
252+ assert (
253+ type_counts .filter (pl .col ("agent_type" ) == "ExampleAgentSet2" )["count" ][0 ]
254+ == 4
255+ )
256+ assert (
257+ type_counts .filter (pl .col ("agent_type" ) == "ExampleAgentSet3" )["count" ][0 ]
258+ == 4
259+ )
251260 assert agent_df ["step" ].to_list () == [0 ] * 12
252261 with pytest .raises (pl .exceptions .ColumnNotFoundError , match = "max_wealth" ):
253262 collected_data ["agent" ]["max_wealth" ]
@@ -284,12 +293,18 @@ def test_collect_step(self, fix1_model):
284293 agent_df = collected_data ["agent" ]
285294 assert agent_df .shape == (12 , 7 )
286295 assert set (agent_df .columns ) == {
287- "unique_id" , "agent_type" , "wealth" , "age" , "step" , "seed" , "batch"
296+ "unique_id" ,
297+ "agent_type" ,
298+ "wealth" ,
299+ "age" ,
300+ "step" ,
301+ "seed" ,
302+ "batch" ,
288303 }
289304
290305 expected_wealth = [6 , 7 , 8 , 9 ] + [20 , 30 , 40 , 50 ] + [1 , 2 , 3 , 4 ]
291306 expected_age = [10 , 20 , 30 , 40 ] + [11 , 22 , 33 , 44 ] + [6 , 7 , 8 , 9 ]
292-
307+
293308 assert sorted (agent_df ["wealth" ].to_list ()) == sorted (expected_wealth )
294309 assert sorted (agent_df ["age" ].to_list ()) == sorted (expected_age )
295310 assert agent_df ["step" ].to_list () == [5 ] * 12
@@ -324,17 +339,23 @@ def test_conditional_collect(self, fix1_model):
324339 assert collected_data ["model" ]["total_agents" ].to_list () == [12 , 12 ]
325340
326341 agent_df = collected_data ["agent" ]
327-
342+
328343 # 12 agents * 2 steps = 24 rows
329344 assert agent_df .shape == (24 , 7 )
330345 assert set (agent_df .columns ) == {
331- "unique_id" , "agent_type" , "wealth" , "age" , "step" , "seed" , "batch"
346+ "unique_id" ,
347+ "agent_type" ,
348+ "wealth" ,
349+ "age" ,
350+ "step" ,
351+ "seed" ,
352+ "batch" ,
332353 }
333354
334355 df_step_2 = agent_df .filter (pl .col ("step" ) == 2 )
335356 expected_wealth_s2 = [3 , 4 , 5 , 6 ] + [14 , 24 , 34 , 44 ] + [1 , 2 , 3 , 4 ]
336357 expected_age_s2 = [10 , 20 , 30 , 40 ] + [11 , 22 , 33 , 44 ] + [3 , 4 , 5 , 6 ]
337-
358+
338359 assert df_step_2 .shape == (12 , 7 )
339360 assert sorted (df_step_2 ["wealth" ].to_list ()) == sorted (expected_wealth_s2 )
340361 assert sorted (df_step_2 ["age" ].to_list ()) == sorted (expected_age_s2 )
@@ -398,12 +419,18 @@ def test_flush_local_csv(self, fix1_model):
398419 )
399420 assert agent_df .shape == (12 , 7 )
400421 assert set (agent_df .columns ) == {
401- "unique_id" , "agent_type" , "wealth" , "age" , "step" , "seed" , "batch"
422+ "unique_id" ,
423+ "agent_type" ,
424+ "wealth" ,
425+ "age" ,
426+ "step" ,
427+ "seed" ,
428+ "batch" ,
402429 }
403-
430+
404431 expected_wealth_s2 = [3 , 4 , 5 , 6 ] + [14 , 24 , 34 , 44 ] + [1 , 2 , 3 , 4 ]
405432 expected_age_s2 = [10 , 20 , 30 , 40 ] + [11 , 22 , 33 , 44 ] + [3 , 4 , 5 , 6 ]
406-
433+
407434 assert agent_df ["step" ].to_list () == [2 ] * 12
408435 assert sorted (agent_df ["wealth" ].to_list ()) == sorted (expected_wealth_s2 )
409436 assert sorted (agent_df ["age" ].to_list ()) == sorted (expected_age_s2 )
@@ -460,9 +487,14 @@ def test_flush_local_parquet(self, fix1_model):
460487 # 12 rows. 6 cols: unique_id, agent_type, wealth, step, seed, batch
461488 assert agent_df .shape == (12 , 6 )
462489 assert set (agent_df .columns ) == {
463- "unique_id" , "agent_type" , "wealth" , "step" , "seed" , "batch"
490+ "unique_id" ,
491+ "agent_type" ,
492+ "wealth" ,
493+ "step" ,
494+ "seed" ,
495+ "batch" ,
464496 }
465-
497+
466498 expected_wealth = [1 , 2 , 3 , 4 ] + [10 , 20 , 30 , 40 ] + [1 , 2 , 3 , 4 ]
467499 assert agent_df ["step" ].to_list () == [0 ] * 12
468500 assert sorted (agent_df ["wealth" ].to_list ()) == sorted (expected_wealth )
@@ -476,14 +508,14 @@ def test_postgress(self, fix1_model, postgres_uri):
476508
477509 # Connect directly and validate data
478510 import psycopg2
479-
511+
480512 try :
481513 conn = psycopg2 .connect (postgres_uri )
482514 except psycopg2 .OperationalError as e :
483515 pytest .skip (f"Could not connect to PostgreSQL: { e } " )
484-
516+
485517 cur = conn .cursor ()
486-
518+
487519 ## Cleaning up tables first
488520 cur .execute ("DROP TABLE IF EXISTS public.model_data, public.agent_data;" )
489521 conn .commit ()
@@ -529,7 +561,7 @@ def test_postgress(self, fix1_model, postgres_uri):
529561 storage_uri = postgres_uri ,
530562 )
531563
532- model .run_model_with_conditional_collect (4 ) ## Runs 1,2,3,4. Collects at 2, 4.
564+ model .run_model_with_conditional_collect (4 ) ## Runs 1,2,3,4. Collects at 2, 4.
533565 model .dc .flush ()
534566
535567 # Connect directly and validate data
@@ -545,17 +577,27 @@ def test_postgress(self, fix1_model, postgres_uri):
545577 model_rows = cur .fetchall ()
546578 assert model_rows == [(2 , 12 ), (4 , 12 )]
547579
548- # MODIFIED: Check agent data
580+ # MODIFIED: Check agent data
549581 cur .execute (
550582 "SELECT wealth, age FROM agent_data WHERE step=2 ORDER BY wealth, age"
551583 )
552584 agent_rows = cur .fetchall ()
553585
554586 expected_rows_s2 = [
555- (1 , 3 ), (2 , 4 ), (3 , 5 ), (3 , 10 ), (4 , 6 ), (4 , 20 ),
556- (5 , 30 ), (6 , 40 ), (14 , 11 ), (24 , 22 ), (34 , 33 ), (44 , 44 )
587+ (1 , 3 ),
588+ (2 , 4 ),
589+ (3 , 5 ),
590+ (3 , 10 ),
591+ (4 , 6 ),
592+ (4 , 20 ),
593+ (5 , 30 ),
594+ (6 , 40 ),
595+ (14 , 11 ),
596+ (24 , 22 ),
597+ (34 , 33 ),
598+ (44 , 44 ),
557599 ]
558-
600+
559601 assert sorted (agent_rows ) == sorted (expected_rows_s2 )
560602
561603 cur .execute ("DROP TABLE public.model_data, public.agent_data;" )
@@ -573,7 +615,7 @@ def test_batch_memory(self, fix2_model):
573615 len (agentset ) for agentset in model .sets ._agentsets
574616 )
575617 },
576- agent_reporters = { "wealth" : "wealth" , "age" : "age" },
618+ agent_reporters = {"wealth" : "wealth" , "age" : "age" },
577619 )
578620
579621 model .run_model_with_conditional_collect_multiple_batch (5 )
@@ -586,7 +628,13 @@ def test_batch_memory(self, fix2_model):
586628 agent_df = collected_data ["agent" ]
587629 assert agent_df .shape == (48 , 7 )
588630 assert set (agent_df .columns ) == {
589- "unique_id" , "agent_type" , "wealth" , "age" , "step" , "seed" , "batch"
631+ "unique_id" ,
632+ "agent_type" ,
633+ "wealth" ,
634+ "age" ,
635+ "step" ,
636+ "seed" ,
637+ "batch" ,
590638 }
591639
592640 df_s2_b0 = agent_df .filter ((pl .col ("step" ) == 2 ) & (pl .col ("batch" ) == 0 ))
@@ -622,10 +670,10 @@ def test_batch_save(self, fix2_model):
622670
623671 model .run_model_with_conditional_collect_multiple_batch (5 )
624672 model .dc .flush ()
625-
673+
626674 for _ in range (20 ): # wait up to ~2 seconds
627675 created_files = os .listdir (tmpdir )
628- if len (created_files ) >= 8 : # 4 collects * 2 files/collect = 8 files
676+ if len (created_files ) >= 8 : # 4 collects * 2 files/collect = 8 files
629677 break
630678 time .sleep (0.1 )
631679
@@ -654,27 +702,30 @@ def test_batch_save(self, fix2_model):
654702 )
655703 assert model_df_step2_batch1 ["step" ].to_list () == [2 ]
656704 assert model_df_step2_batch1 ["total_agents" ].to_list () == [12 ]
657-
705+
658706 model_df_step4_batch0 = pl .read_csv (
659707 os .path .join (tmpdir , "model_step4_batch0.csv" ),
660708 schema_overrides = {"seed" : pl .Utf8 },
661709 )
662710 assert model_df_step4_batch0 ["step" ].to_list () == [4 ]
663711 assert model_df_step4_batch0 ["total_agents" ].to_list () == [12 ]
664-
712+
665713 # test agent batch reset
666714 agent_df_step2_batch0 = pl .read_csv (
667715 os .path .join (tmpdir , "agent_step2_batch0.csv" ),
668716 schema_overrides = {"seed" : pl .Utf8 , "unique_id" : pl .UInt64 },
669717 )
670718
671719 expected_wealth_s2b0 = [2 , 3 , 4 , 5 ] + [12 , 22 , 32 , 42 ] + [1 , 2 , 3 , 4 ]
672- assert sorted (agent_df_step2_batch0 ["wealth" ].to_list ()) == sorted (expected_wealth_s2b0 )
720+ assert sorted (agent_df_step2_batch0 ["wealth" ].to_list ()) == sorted (
721+ expected_wealth_s2b0
722+ )
673723
674724 agent_df_step2_batch1 = pl .read_csv (
675725 os .path .join (tmpdir , "agent_step2_batch1.csv" ),
676726 schema_overrides = {"seed" : pl .Utf8 , "unique_id" : pl .UInt64 },
677727 )
678728 expected_wealth_s2b1 = [3 , 4 , 5 , 6 ] + [14 , 24 , 34 , 44 ] + [1 , 2 , 3 , 4 ]
679- assert sorted (agent_df_step2_batch1 ["wealth" ].to_list ()) == sorted (expected_wealth_s2b1 )
680-
729+ assert sorted (agent_df_step2_batch1 ["wealth" ].to_list ()) == sorted (
730+ expected_wealth_s2b1
731+ )
0 commit comments