|
3 | 3 | import argparse |
4 | 4 | import os |
5 | 5 | import sys |
| 6 | +import math |
6 | 7 |
|
7 | 8 | import torch |
8 | 9 | import torch.optim as optim |
|
13 | 14 | from model import ActorCritic |
14 | 15 | from train import train |
15 | 16 | from test import test |
| 17 | +from utils import logger |
16 | 18 | import my_optim |
17 | 19 |
|
# NOTE(review): this rebinds the name `logger` — which the import above bound
# to the `utils.logger` module — to a logger INSTANCE named 'main'. The module
# itself is unreachable under this name afterwards. It works, but
# `from utils.logger import getLogger` would be clearer — confirm `utils.logger`
# is a module before changing.
logger = logger.getLogger('main')
18 | 22 | # Based on |
19 | 23 | # https://github.com/pytorch/examples/tree/master/mnist_hogwild |
20 | 24 | # Training settings |
|
37 | 41 | help='environment to train on (default: PongDeterministic-v3)') |
def _str2bool(value):
    """Parse an explicit command-line boolean value, case-insensitively.

    Previously `--no-shared` had no `type=`/`action=`, so ANY supplied value
    — including the string 'False' — was truthy downstream. Parsing the value
    explicitly fixes that while remaining backward compatible: the flag still
    takes a value (`--no-shared True` keeps working) and the default when the
    flag is absent is still `False`.

    Raises:
        argparse.ArgumentTypeError: if the value is not a recognizable boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = value.lower()
    if lowered in ('true', '1', 'yes'):
        return True
    if lowered in ('false', '0', 'no'):
        return False
    raise argparse.ArgumentTypeError('expected a boolean, got %r' % value)


parser.add_argument('--no-shared', type=_str2bool, default=False, metavar='O',
                    help='use an optimizer without shared momentum.')
# NOTE: the default bypasses `type=`, so when the flag is absent the value is
# the float `math.inf`; when supplied it is parsed as an int. Either way a
# comparison such as `iteration < args.max_iters` behaves correctly.
parser.add_argument('--max-iters', type=int, default=math.inf,
                    help='maximum iterations per process.')

# Run training single-process in the parent (no multiprocessing) so that
# breakpoints and stack traces behave normally.
parser.add_argument('--debug', action='store_true', default=False,
                    help='run in a way its easier to debug')
|
if __name__ == '__main__':
    args = parser.parse_args()

    # Seed the parent process RNG; workers presumably derive per-rank seeds
    # from args.seed — TODO confirm in train()/test().
    torch.manual_seed(args.seed)
    env = create_atari_env(args.env_name)
    shared_model = ActorCritic(
        env.observation_space.shape[0], env.action_space)
    # Move the parameters into shared memory so every worker process reads and
    # writes the SAME tensors (Hogwild-style lock-free updates).
    shared_model.share_memory()

    if args.no_shared:
        # Each worker will construct its own optimizer.
        optimizer = None
    else:
        optimizer = my_optim.SharedAdam(shared_model.parameters(), lr=args.lr)
        optimizer.share_memory()

    if not args.debug:
        processes = []

        # The evaluation process is given rank == num_processes, presumably so
        # its seed does not collide with any trainer's — confirm against test().
        p = mp.Process(
            target=test, args=(args.num_processes, args, shared_model))
        p.start()
        processes.append(p)
        for rank in range(args.num_processes):
            p = mp.Process(
                target=train, args=(rank, args, shared_model, optimizer))
            p.start()
            processes.append(p)
        # Block until every worker (trainers and the evaluator) exits.
        for p in processes:
            p.join()
    else:  # debug is enabled
        # Run a single trainer directly in this process — easier to debug.
        train(0, args, shared_model, optimizer)
0 commit comments