diff --git a/modules/agents/updet_agent.py b/modules/agents/updet_agent.py index 097cc84..d63391e 100644 --- a/modules/agents/updet_agent.py +++ b/modules/agents/updet_agent.py @@ -16,6 +16,9 @@ def init_hidden(self): return torch.zeros(1, self.args.emb).cpu() def forward(self, inputs, hidden_state, task_enemy_num, task_ally_num): + + # Call the mask directly from the env.. As did in DGNs. + outputs, _ = self.transformer.forward(inputs, hidden_state, None) # first output for 6 action (no_op stop up down left right) q_basic_actions = self.q_basic(outputs[:, 0, :])