Skip to content

Commit 6806158

Browse files
committed
fix: Changed test file for datacollector
1 parent eb8d56c commit 6806158

File tree

1 file changed

+16
-4
lines changed

1 file changed

+16
-4
lines changed

tests/test_datacollector.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -185,8 +185,9 @@ def test_collect(self, fix1_model):
185185
with pytest.raises(pl.exceptions.ColumnNotFoundError, match="max_wealth"):
186186
collected_data["model"]["max_wealth"]
187187

188-
assert collected_data["agent"].shape == (4, 7)
188+
assert collected_data["agent"].shape == (4, 8)
189189
assert set(collected_data["agent"].columns) == {
190+
"agent_id",
190191
"wealth",
191192
"age_ExampleAgentSet1",
192193
"age_ExampleAgentSet2",
@@ -195,6 +196,7 @@ def test_collect(self, fix1_model):
195196
"seed",
196197
"batch",
197198
}
199+
assert collected_data["agent"]["agent_id"].to_list() == [0, 1, 2, 3]
198200
assert collected_data["agent"]["wealth"].to_list() == [1, 2, 3, 4]
199201
assert collected_data["agent"]["age_ExampleAgentSet1"].to_list() == [
200202
10,
@@ -242,8 +244,9 @@ def test_collect_step(self, fix1_model):
242244
assert collected_data["model"]["step"].to_list() == [5]
243245
assert collected_data["model"]["total_agents"].to_list() == [12]
244246

245-
assert collected_data["agent"].shape == (4, 7)
247+
assert collected_data["agent"].shape == (4, 8)
246248
assert set(collected_data["agent"].columns) == {
249+
"agent_id",
247250
"wealth",
248251
"age_ExampleAgentSet1",
249252
"age_ExampleAgentSet2",
@@ -252,6 +255,7 @@ def test_collect_step(self, fix1_model):
252255
"seed",
253256
"batch",
254257
}
258+
assert collected_data["agent"]["agent_id"].to_list() == [0, 1, 2, 3]
255259
assert collected_data["agent"]["wealth"].to_list() == [6, 7, 8, 9]
256260
assert collected_data["agent"]["age_ExampleAgentSet1"].to_list() == [
257261
10,
@@ -297,8 +301,9 @@ def test_conditional_collect(self, fix1_model):
297301
assert collected_data["model"]["step"].to_list() == [2, 4]
298302
assert collected_data["model"]["total_agents"].to_list() == [12, 12]
299303

300-
assert collected_data["agent"].shape == (8, 7)
304+
assert collected_data["agent"].shape == (8, 8)
301305
assert set(collected_data["agent"].columns) == {
306+
"agent_id",
302307
"wealth",
303308
"age_ExampleAgentSet1",
304309
"age_ExampleAgentSet2",
@@ -308,6 +313,7 @@ def test_conditional_collect(self, fix1_model):
308313
"batch",
309314
}
310315
assert set(collected_data["agent"].columns) == {
316+
"agent_id",
311317
"wealth",
312318
"age_ExampleAgentSet1",
313319
"age_ExampleAgentSet2",
@@ -399,6 +405,7 @@ def test_flush_local_csv(self, fix1_model):
399405
schema_overrides={"seed": pl.Utf8},
400406
)
401407
assert set(agent_df.columns) == {
408+
"agent_id",
402409
"wealth",
403410
"age_ExampleAgentSet1",
404411
"age_ExampleAgentSet2",
@@ -580,8 +587,9 @@ def test_batch_memory(self, fix2_model):
580587
assert collected_data["model"]["batch"].to_list() == [0, 1, 0, 1]
581588
assert collected_data["model"]["total_agents"].to_list() == [12, 12, 12, 12]
582589

583-
assert collected_data["agent"].shape == (16, 7)
590+
assert collected_data["agent"].shape == (16, 8)
584591
assert set(collected_data["agent"].columns) == {
592+
"agent_id",
585593
"wealth",
586594
"age_ExampleAgentSet1",
587595
"age_ExampleAgentSet2",
@@ -592,6 +600,7 @@ def test_batch_memory(self, fix2_model):
592600
}
593601

594602
assert set(collected_data["agent"].columns) == {
603+
"agent_id",
595604
"wealth",
596605
"age_ExampleAgentSet1",
597606
"age_ExampleAgentSet2",
@@ -779,6 +788,7 @@ def test_batch_save(self, fix2_model):
779788
schema_overrides={"seed": pl.Utf8},
780789
)
781790
assert set(agent_df_step2_batch0.columns) == {
791+
"agent_id",
782792
"wealth",
783793
"age_ExampleAgentSet1",
784794
"age_ExampleAgentSet2",
@@ -813,6 +823,7 @@ def test_batch_save(self, fix2_model):
813823
schema_overrides={"seed": pl.Utf8},
814824
)
815825
assert set(agent_df_step2_batch1.columns) == {
826+
"agent_id",
816827
"wealth",
817828
"age_ExampleAgentSet1",
818829
"age_ExampleAgentSet2",
@@ -847,6 +858,7 @@ def test_batch_save(self, fix2_model):
847858
schema_overrides={"seed": pl.Utf8},
848859
)
849860
assert set(agent_df_step4_batch0.columns) == {
861+
"agent_id",
850862
"wealth",
851863
"age_ExampleAgentSet1",
852864
"age_ExampleAgentSet2",

0 commit comments

Comments
 (0)