|
19 | 19 | # Import necessary libraries |
20 | 20 | import numpy as np |
21 | 21 | import seaborn as sns |
22 | | -from mesa_models.boltzmann_wealth_model.model import ( |
23 | | - BoltzmannWealthModel, |
24 | | - MoneyAgent, |
25 | | - compute_gini, |
26 | | -) |
| 22 | +from mesa.examples.basic.boltzmann_wealth_model.agents import MoneyAgent |
| 23 | +from mesa.examples.basic.boltzmann_wealth_model.model import BoltzmannWealth |
27 | 24 |
|
28 | 25 | NUM_AGENTS = 10 |
29 | 26 |
|
30 | 27 |
|
31 | 28 | # Define the agent class |
32 | 29 | class MoneyAgentRL(MoneyAgent): |
33 | | - def __init__(self, unique_id, model): |
34 | | - super().__init__(unique_id, model) |
| 30 | + def __init__(self, model): |
| 31 | + super().__init__(model) |
35 | 32 | self.wealth = np.random.randint(1, NUM_AGENTS) |
36 | 33 |
|
37 | 34 | def move(self, action): |
@@ -74,45 +71,46 @@ def take_money(self): |
74 | 71 |
|
    def step(self):
        """Advance this agent one tick using the externally supplied action.

        Reads the per-agent action from ``model.action_dict`` (set by the
        RL environment's ``step``), moves accordingly, then trades.
        """
        # Get the action for the agent
        # TODO: figure out why agents are being made twice
        # NOTE(review): the "- 11" offset maps unique_id back to a 0-based
        # index into action_dict; it compensates for ids being shifted
        # (agents apparently created twice). This silently breaks if
        # NUM_AGENTS or the creation order changes — fix the root cause.
        action = self.model.action_dict[self.unique_id - 11]
        # Move the agent based on the action
        self.move(action)
        # Take money from other agents in the same cell
        self.take_money()
82 | 80 |
|
83 | 81 |
|
# Define the model class
class BoltzmannWealthModelRL(BoltzmannWealth, gymnasium.Env):
    """Boltzmann wealth model exposed as a Gymnasium environment."""

    def __init__(self, n, width, height):
        """Build the model with ``n`` agents on a ``width`` x ``height`` grid.

        Args:
            n: Number of agents (also sizes the observation/action spaces).
            width: Grid width in cells.
            height: Grid height in cells.
        """
        super().__init__(n, width, height)
        # Define the observation and action space for the RL model
        # The observation space is the wealth of each agent and their position
        # NOTE(review): high=10 * n is a generous cap on observed wealth;
        # total wealth is conserved, so confirm this bound is sufficient.
        self.observation_space = gymnasium.spaces.Box(low=0, high=10 * n, shape=(n, 3))
        # The action space is a MultiDiscrete space with 5 possible actions for each agent
        self.action_space = gymnasium.spaces.MultiDiscrete([5] * n)
        # When True, render the grid during stepping (toggled elsewhere).
        self.is_visualize = False
94 | 92 |
|
95 | 93 | def step(self, action): |
96 | 94 | self.action_dict = action |
97 | 95 | # Perform one step of the model |
98 | | - self.schedule.step() |
| 96 | + self.agents.shuffle_do("step") |
99 | 97 | # Collect data for visualization |
100 | 98 | self.datacollector.collect(self) |
101 | 99 | # Compute the new Gini coefficient |
102 | | - new_gini = compute_gini(self) |
| 100 | + new_gini = self.compute_gini() |
103 | 101 | # Compute the reward based on the change in Gini coefficient |
104 | 102 | reward = self.calculate_reward(new_gini) |
105 | 103 | self.prev_gini = new_gini |
106 | 104 | # Get the observation for the RL model |
107 | 105 | obs = self._get_obs() |
108 | | - if self.schedule.time > 5 * NUM_AGENTS: |
| 106 | + if self.time > 5 * NUM_AGENTS: |
109 | 107 | # Terminate the episode if the model has run for a certain number of timesteps |
110 | 108 | done = True |
111 | 109 | reward = -1 |
112 | 110 | elif new_gini < 0.1: |
113 | 111 | # Terminate the episode if the Gini coefficient is below a certain threshold |
114 | 112 | done = True |
115 | | - reward = 50 / self.schedule.time |
| 113 | + reward = 50 / self.time |
116 | 114 | else: |
117 | 115 | done = False |
118 | 116 | info = {} |
@@ -142,20 +140,18 @@ def reset(self, *, seed=None, options=None): |
142 | 140 | self.visualize() |
143 | 141 | super().reset() |
144 | 142 | self.grid = mesa.space.MultiGrid(self.grid.width, self.grid.height, True) |
145 | | - self.schedule = mesa.time.RandomActivation(self) |
| 143 | + self.remove_all_agents() |
146 | 144 | for i in range(self.num_agents): |
147 | 145 | # Create MoneyAgentRL instances and add them to the schedule |
148 | | - a = MoneyAgentRL(i, self) |
149 | | - self.schedule.add(a) |
| 146 | + a = MoneyAgentRL(self) |
150 | 147 | x = self.random.randrange(self.grid.width) |
151 | 148 | y = self.random.randrange(self.grid.height) |
152 | 149 | self.grid.place_agent(a, (x, y)) |
153 | | - self.prev_gini = compute_gini(self) |
| 150 | + self.prev_gini = self.compute_gini() |
154 | 151 | return self._get_obs(), {} |
155 | 152 |
|
156 | 153 | def _get_obs(self): |
157 | 154 | # The observation is the wealth of each agent and their position |
158 | | - obs = [] |
159 | | - for a in self.schedule.agents: |
160 | | - obs.append([a.wealth, *list(a.pos)]) |
| 155 | + obs = [[a.wealth, *a.pos] for a in self.agents] |
| 156 | + obs = np.array(obs) |
161 | 157 | return np.array(obs) |
0 commit comments