Skip to content

Commit c5be394

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 5c1525f commit c5be394

File tree

2 files changed

+94
-47
lines changed

2 files changed

+94
-47
lines changed

mesa_frames/concrete/datacollector.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def __init__(
113113
f"Agent reporter for '{key}' must be a string (the column name), "
114114
f"not a {type(value)}. Callable reporters are not supported for agents."
115115
)
116-
116+
117117
super().__init__(
118118
model=model,
119119
model_reporters=model_reporters,
@@ -223,7 +223,7 @@ def _collect_agent_reporters(self, current_model_step: int, batch_id: int):
223223
if source_col in available_cols:
224224
## Add the column, aliasing it if the key is different
225225
cols_to_select.append(pl.col(source_col).alias(final_name))
226-
226+
227227
## Only proceed if we have more than just unique_id
228228
if len(cols_to_select) > 1:
229229
set_frame = agent_df.select(cols_to_select)
@@ -516,7 +516,7 @@ def _validate_reporter_table_columns(
516516
If any expected columns are missing from the table.
517517
"""
518518
expected_columns = set()
519-
519+
520520
## Add columns required for the new long agent format
521521
if table_name == "agent_data":
522522
expected_columns.add("unique_id")
@@ -546,11 +546,7 @@ def _validate_reporter_table_columns(
546546

547547
existing_columns = {row[0] for row in result}
548548
missing_columns = expected_columns - existing_columns
549-
required_columns = {
550-
"step": "Integer",
551-
"seed": "Varchar",
552-
"batch": "Integer"
553-
}
549+
required_columns = {"step": "Integer", "seed": "Varchar", "batch": "Integer"}
554550

555551
missing_required = {
556552
col: col_type
@@ -596,4 +592,4 @@ def _execute_query_with_result(self, conn: connection, query: str) -> list[tuple
596592
"""
597593
with conn.cursor() as cur:
598594
cur.execute(query)
599-
return cur.fetchall()
595+
return cur.fetchall()

tests/test_datacollector.py

Lines changed: 89 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -113,17 +113,17 @@ def postgres_uri():
113113

114114
@pytest.fixture
115115
def fix1_AgentSet() -> ExampleAgentSet1:
116-
return ExampleAgentSet1(Model(seed = 1))
116+
return ExampleAgentSet1(Model(seed=1))
117117

118118

119119
@pytest.fixture
120120
def fix2_AgentSet() -> ExampleAgentSet2:
121-
return ExampleAgentSet2(Model(seed = 1))
121+
return ExampleAgentSet2(Model(seed=1))
122122

123123

124124
@pytest.fixture
125125
def fix3_AgentSet() -> ExampleAgentSet3:
126-
return ExampleAgentSet3(Model(seed = 1))
126+
return ExampleAgentSet3(Model(seed=1))
127127

128128

129129
@pytest.fixture
@@ -176,8 +176,8 @@ def test__init__(self, fix1_model, postgres_uri):
176176
match="Agent reporter for 'wealth' must be a string",
177177
):
178178
model.test_dc = DataCollector(
179-
model=model,
180-
agent_reporters={"wealth": lambda m: 1} # This is now illegal
179+
model=model,
180+
agent_reporters={"wealth": lambda m: 1}, # This is now illegal
181181
)
182182

183183
with pytest.raises(
@@ -225,7 +225,7 @@ def test_collect(self, fix1_model):
225225
collected_data["model"]["max_wealth"]
226226

227227
agent_df = collected_data["agent"]
228-
228+
229229
## 3 agent sets * 4 agents/set = 12 rows
230230
assert agent_df.shape == (12, 7)
231231
assert set(agent_df.columns) == {
@@ -237,17 +237,26 @@ def test_collect(self, fix1_model):
237237
"seed",
238238
"batch",
239239
}
240-
240+
241241
expected_wealth = [1, 2, 3, 4] + [10, 20, 30, 40] + [1, 2, 3, 4]
242242
expected_age = [10, 20, 30, 40] + [11, 22, 33, 44] + [1, 2, 3, 4]
243-
243+
244244
assert sorted(agent_df["wealth"].to_list()) == sorted(expected_wealth)
245245
assert sorted(agent_df["age"].to_list()) == sorted(expected_age)
246246

247247
type_counts = agent_df["agent_type"].value_counts(sort=True)
248-
assert type_counts.filter(pl.col("agent_type") == "ExampleAgentSet1")["count"][0] == 4
249-
assert type_counts.filter(pl.col("agent_type") == "ExampleAgentSet2")["count"][0] == 4
250-
assert type_counts.filter(pl.col("agent_type") == "ExampleAgentSet3")["count"][0] == 4
248+
assert (
249+
type_counts.filter(pl.col("agent_type") == "ExampleAgentSet1")["count"][0]
250+
== 4
251+
)
252+
assert (
253+
type_counts.filter(pl.col("agent_type") == "ExampleAgentSet2")["count"][0]
254+
== 4
255+
)
256+
assert (
257+
type_counts.filter(pl.col("agent_type") == "ExampleAgentSet3")["count"][0]
258+
== 4
259+
)
251260
assert agent_df["step"].to_list() == [0] * 12
252261
with pytest.raises(pl.exceptions.ColumnNotFoundError, match="max_wealth"):
253262
collected_data["agent"]["max_wealth"]
@@ -284,12 +293,18 @@ def test_collect_step(self, fix1_model):
284293
agent_df = collected_data["agent"]
285294
assert agent_df.shape == (12, 7)
286295
assert set(agent_df.columns) == {
287-
"unique_id", "agent_type", "wealth", "age", "step", "seed", "batch"
296+
"unique_id",
297+
"agent_type",
298+
"wealth",
299+
"age",
300+
"step",
301+
"seed",
302+
"batch",
288303
}
289304

290305
expected_wealth = [6, 7, 8, 9] + [20, 30, 40, 50] + [1, 2, 3, 4]
291306
expected_age = [10, 20, 30, 40] + [11, 22, 33, 44] + [6, 7, 8, 9]
292-
307+
293308
assert sorted(agent_df["wealth"].to_list()) == sorted(expected_wealth)
294309
assert sorted(agent_df["age"].to_list()) == sorted(expected_age)
295310
assert agent_df["step"].to_list() == [5] * 12
@@ -324,17 +339,23 @@ def test_conditional_collect(self, fix1_model):
324339
assert collected_data["model"]["total_agents"].to_list() == [12, 12]
325340

326341
agent_df = collected_data["agent"]
327-
342+
328343
# 12 agents * 2 steps = 24 rows
329344
assert agent_df.shape == (24, 7)
330345
assert set(agent_df.columns) == {
331-
"unique_id", "agent_type", "wealth", "age", "step", "seed", "batch"
346+
"unique_id",
347+
"agent_type",
348+
"wealth",
349+
"age",
350+
"step",
351+
"seed",
352+
"batch",
332353
}
333354

334355
df_step_2 = agent_df.filter(pl.col("step") == 2)
335356
expected_wealth_s2 = [3, 4, 5, 6] + [14, 24, 34, 44] + [1, 2, 3, 4]
336357
expected_age_s2 = [10, 20, 30, 40] + [11, 22, 33, 44] + [3, 4, 5, 6]
337-
358+
338359
assert df_step_2.shape == (12, 7)
339360
assert sorted(df_step_2["wealth"].to_list()) == sorted(expected_wealth_s2)
340361
assert sorted(df_step_2["age"].to_list()) == sorted(expected_age_s2)
@@ -398,12 +419,18 @@ def test_flush_local_csv(self, fix1_model):
398419
)
399420
assert agent_df.shape == (12, 7)
400421
assert set(agent_df.columns) == {
401-
"unique_id", "agent_type", "wealth", "age", "step", "seed", "batch"
422+
"unique_id",
423+
"agent_type",
424+
"wealth",
425+
"age",
426+
"step",
427+
"seed",
428+
"batch",
402429
}
403-
430+
404431
expected_wealth_s2 = [3, 4, 5, 6] + [14, 24, 34, 44] + [1, 2, 3, 4]
405432
expected_age_s2 = [10, 20, 30, 40] + [11, 22, 33, 44] + [3, 4, 5, 6]
406-
433+
407434
assert agent_df["step"].to_list() == [2] * 12
408435
assert sorted(agent_df["wealth"].to_list()) == sorted(expected_wealth_s2)
409436
assert sorted(agent_df["age"].to_list()) == sorted(expected_age_s2)
@@ -460,9 +487,14 @@ def test_flush_local_parquet(self, fix1_model):
460487
# 12 rows. 6 cols: unique_id, agent_type, wealth, step, seed, batch
461488
assert agent_df.shape == (12, 6)
462489
assert set(agent_df.columns) == {
463-
"unique_id", "agent_type", "wealth", "step", "seed", "batch"
490+
"unique_id",
491+
"agent_type",
492+
"wealth",
493+
"step",
494+
"seed",
495+
"batch",
464496
}
465-
497+
466498
expected_wealth = [1, 2, 3, 4] + [10, 20, 30, 40] + [1, 2, 3, 4]
467499
assert agent_df["step"].to_list() == [0] * 12
468500
assert sorted(agent_df["wealth"].to_list()) == sorted(expected_wealth)
@@ -476,14 +508,14 @@ def test_postgress(self, fix1_model, postgres_uri):
476508

477509
# Connect directly and validate data
478510
import psycopg2
479-
511+
480512
try:
481513
conn = psycopg2.connect(postgres_uri)
482514
except psycopg2.OperationalError as e:
483515
pytest.skip(f"Could not connect to PostgreSQL: {e}")
484-
516+
485517
cur = conn.cursor()
486-
518+
487519
## Cleaning up tables first
488520
cur.execute("DROP TABLE IF EXISTS public.model_data, public.agent_data;")
489521
conn.commit()
@@ -529,7 +561,7 @@ def test_postgress(self, fix1_model, postgres_uri):
529561
storage_uri=postgres_uri,
530562
)
531563

532-
model.run_model_with_conditional_collect(4) ## Runs 1,2,3,4. Collects at 2, 4.
564+
model.run_model_with_conditional_collect(4) ## Runs 1,2,3,4. Collects at 2, 4.
533565
model.dc.flush()
534566

535567
# Connect directly and validate data
@@ -545,17 +577,27 @@ def test_postgress(self, fix1_model, postgres_uri):
545577
model_rows = cur.fetchall()
546578
assert model_rows == [(2, 12), (4, 12)]
547579

548-
# MODIFIED: Check agent data
580+
# MODIFIED: Check agent data
549581
cur.execute(
550582
"SELECT wealth, age FROM agent_data WHERE step=2 ORDER BY wealth, age"
551583
)
552584
agent_rows = cur.fetchall()
553585

554586
expected_rows_s2 = [
555-
(1, 3), (2, 4), (3, 5), (3, 10), (4, 6), (4, 20),
556-
(5, 30), (6, 40), (14, 11), (24, 22), (34, 33), (44, 44)
587+
(1, 3),
588+
(2, 4),
589+
(3, 5),
590+
(3, 10),
591+
(4, 6),
592+
(4, 20),
593+
(5, 30),
594+
(6, 40),
595+
(14, 11),
596+
(24, 22),
597+
(34, 33),
598+
(44, 44),
557599
]
558-
600+
559601
assert sorted(agent_rows) == sorted(expected_rows_s2)
560602

561603
cur.execute("DROP TABLE public.model_data, public.agent_data;")
@@ -573,7 +615,7 @@ def test_batch_memory(self, fix2_model):
573615
len(agentset) for agentset in model.sets._agentsets
574616
)
575617
},
576-
agent_reporters={ "wealth": "wealth", "age": "age" },
618+
agent_reporters={"wealth": "wealth", "age": "age"},
577619
)
578620

579621
model.run_model_with_conditional_collect_multiple_batch(5)
@@ -586,7 +628,13 @@ def test_batch_memory(self, fix2_model):
586628
agent_df = collected_data["agent"]
587629
assert agent_df.shape == (48, 7)
588630
assert set(agent_df.columns) == {
589-
"unique_id", "agent_type", "wealth", "age", "step", "seed", "batch"
631+
"unique_id",
632+
"agent_type",
633+
"wealth",
634+
"age",
635+
"step",
636+
"seed",
637+
"batch",
590638
}
591639

592640
df_s2_b0 = agent_df.filter((pl.col("step") == 2) & (pl.col("batch") == 0))
@@ -622,10 +670,10 @@ def test_batch_save(self, fix2_model):
622670

623671
model.run_model_with_conditional_collect_multiple_batch(5)
624672
model.dc.flush()
625-
673+
626674
for _ in range(20): # wait up to ~2 seconds
627675
created_files = os.listdir(tmpdir)
628-
if len(created_files) >= 8: # 4 collects * 2 files/collect = 8 files
676+
if len(created_files) >= 8: # 4 collects * 2 files/collect = 8 files
629677
break
630678
time.sleep(0.1)
631679

@@ -654,27 +702,30 @@ def test_batch_save(self, fix2_model):
654702
)
655703
assert model_df_step2_batch1["step"].to_list() == [2]
656704
assert model_df_step2_batch1["total_agents"].to_list() == [12]
657-
705+
658706
model_df_step4_batch0 = pl.read_csv(
659707
os.path.join(tmpdir, "model_step4_batch0.csv"),
660708
schema_overrides={"seed": pl.Utf8},
661709
)
662710
assert model_df_step4_batch0["step"].to_list() == [4]
663711
assert model_df_step4_batch0["total_agents"].to_list() == [12]
664-
712+
665713
# test agent batch reset
666714
agent_df_step2_batch0 = pl.read_csv(
667715
os.path.join(tmpdir, "agent_step2_batch0.csv"),
668716
schema_overrides={"seed": pl.Utf8, "unique_id": pl.UInt64},
669717
)
670718

671719
expected_wealth_s2b0 = [2, 3, 4, 5] + [12, 22, 32, 42] + [1, 2, 3, 4]
672-
assert sorted(agent_df_step2_batch0["wealth"].to_list()) == sorted(expected_wealth_s2b0)
720+
assert sorted(agent_df_step2_batch0["wealth"].to_list()) == sorted(
721+
expected_wealth_s2b0
722+
)
673723

674724
agent_df_step2_batch1 = pl.read_csv(
675725
os.path.join(tmpdir, "agent_step2_batch1.csv"),
676726
schema_overrides={"seed": pl.Utf8, "unique_id": pl.UInt64},
677727
)
678728
expected_wealth_s2b1 = [3, 4, 5, 6] + [14, 24, 34, 44] + [1, 2, 3, 4]
679-
assert sorted(agent_df_step2_batch1["wealth"].to_list()) == sorted(expected_wealth_s2b1)
680-
729+
assert sorted(agent_df_step2_batch1["wealth"].to_list()) == sorted(
730+
expected_wealth_s2b1
731+
)

0 commit comments

Comments
 (0)