1111๋ถ์ฐ ๋ชจ๋ธ ๋ณ๋ ฌ ์ฒ๋ฆฌ(distributed model parallelism)๋ฅผ ๊ฒฐํฉํ์ฌ ๊ฐ๋จํ ๋ชจ๋ธ ํ์ต์ํฌ ๋
1212`๋ถ์ฐ ๋ฐ์ดํฐ ๋ณ๋ ฌ(DistributedDataParallel) <https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel >`__ (DDP)๊ณผ
1313`๋ถ์ฐ RPC ํ๋ ์์ํฌ(Distributed RPC framework) <https://pytorch.org/docs/master/rpc.html >`__ ๋ฅผ ๊ฒฐํฉํ๋ ๋ฐฉ๋ฒ์ ๋ํด ์ค๋ช
ํฉ๋๋ค.
14- ์์ ์ ์์ค ์ฝ๋๋ `์ฌ๊ธฐ <https://github.com/pytorch/examples/tree/master/distributed/rpc/ddp_rpc>`__์์ ํ์ธํ ์ ์์ต๋๋ค.
14+ ์์ ์ ์์ค ์ฝ๋๋ `์ฌ๊ธฐ <https://github.com/pytorch/examples/tree/master/distributed/rpc/ddp_rpc >`__ ์์ ํ์ธํ ์ ์์ต๋๋ค.
1515
1616์ด์ ํํ ๋ฆฌ์ผ ๋ด์ฉ์ด์๋
17- `๋ถ์ฐ ๋ฐ์ดํฐ ๋ณ๋ ฌ ์์ํ๊ธฐ <https://tutorials.pytorch.kr/intermediate/ddp_tutorial.html>`__์
18- `๋ถ์ฐ RPC ํ๋ ์์ํฌ ์์ํ๊ธฐ <https://tutorials.pytorch.kr/intermediate/rpc_tutorial.html>`__๋
17+ `๋ถ์ฐ ๋ฐ์ดํฐ ๋ณ๋ ฌ ์์ํ๊ธฐ <https://tutorials.pytorch.kr/intermediate/ddp_tutorial.html >`__ ์
18+ `๋ถ์ฐ RPC ํ๋ ์์ํฌ ์์ํ๊ธฐ <https://tutorials.pytorch.kr/intermediate/rpc_tutorial.html >`__ ๋
1919๋ถ์ฐ ๋ฐ์ดํฐ ๋ณ๋ ฌ ๋ฐ ๋ถ์ฐ ๋ชจ๋ธ ๋ณ๋ ฌ ํ์ต์ ๊ฐ๊ฐ ์ํํ๋ ๋ฐฉ๋ฒ์ ๋ํด ์ค๋ช
ํฉ๋๋ค.
2020๊ทธ๋ฌ๋ ์ด ๋ ๊ฐ์ง ๊ธฐ์ ์ ๊ฒฐํฉํ ์ ์๋ ๋ช ๊ฐ์ง ํ์ต ํจ๋ฌ๋ค์์ด ์์ต๋๋ค. ์๋ฅผ ๋ค์ด:
2121
22221) ํฌ์ ๋ถ๋ถ(ํฐ ์๋ฒ ๋ฉ ํ
์ด๋ธ)๊ณผ ๋ฐ์ง ๋ถ๋ถ(FC ๋ ์ด์ด)์ด ์๋ ๋ชจ๋ธ์ด ์๋ ๊ฒฝ์ฐ,
23- ๋งค๊ฐ๋ณ์ ์๋ฒ(parameter server)์ ์๋ฒ ๋ฉ ํ
์ด๋ธ(embedding table)์ ๋๊ณ `๋ถ์ฐ ๋ฐ์ดํฐ ๋ณ๋ ฌ <https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel>`__์ ์ฌ์ฉํ์ฌ
23+ ๋งค๊ฐ๋ณ์ ์๋ฒ(parameter server)์ ์๋ฒ ๋ฉ ํ
์ด๋ธ(embedding table)์ ๋๊ณ `๋ถ์ฐ ๋ฐ์ดํฐ ๋ณ๋ ฌ <https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel >`__ ์ ์ฌ์ฉํ์ฌ
2424 ์ฌ๋ฌ ํธ๋ ์ด๋์ ๊ฑธ์ณ FC ๋ ์ด์ด๋ฅผ ๋ณต์ ํ๋ ๊ฒ์ ์ํ ์๋ ์์ต๋๋ค.
25- ์ด๋ `๋ถ์ฐ RPC ํ๋ ์์ํฌ <https://pytorch.org/docs/master/rpc.html>`__๋
25+ ์ด๋ `๋ถ์ฐ RPC ํ๋ ์์ํฌ <https://pytorch.org/docs/master/rpc.html >`__ ๋
2626 ๋งค๊ฐ๋ณ์ ์๋ฒ์์ ์๋ฒ ๋ฉ ์ฐพ๊ธฐ ์์
(embedding lookup)์ ์ํํ๋ ๋ฐ ์ฌ์ฉํ ์ ์์ต๋๋ค.
27272) ๋ค์์ `PipeDream <https://arxiv.org/abs/1806.03377 >`__ ๋ฌธ์์์ ์ค๋ช
๋ ํ์ด๋ธ๋ฆฌ๋ ๋ณ๋ ฌ ์ฒ๋ฆฌ ํ์ฑํํ๊ธฐ ์
๋๋ค.
2828 `๋ถ์ฐ RPC ํ๋ ์์ํฌ <https://pytorch.org/docs/master/rpc.html >`__ ๋ฅผ ์ฌ์ฉํ์ฌ
2929 ์ฌ๋ฌ worker์ ๊ฑธ์ณ ๋ชจ๋ธ์ ๋จ๊ณ๋ฅผ ํ์ดํ๋ผ์ธ(pipeline)ํ ์ ์๊ณ
30- (ํ์์ ๋ฐ๋ผ) `๋ถ์ฐ ๋ฐ์ดํฐ ๋ณ๋ ฌ <https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel>`__์ ์ด์ฉํด์
30+ (ํ์์ ๋ฐ๋ผ) `๋ถ์ฐ ๋ฐ์ดํฐ ๋ณ๋ ฌ <https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel >`__ ์ ์ด์ฉํด์
3131 ๊ฐ ๋จ๊ณ๋ฅผ ๋ณต์ ํ ์ ์์ต๋๋ค.
3232
3333|
38381) 1๊ฐ์ ๋ง์คํฐ๋ ๋งค๊ฐ๋ณ์ ์๋ฒ์ ์๋ฒ ๋ฉ ํ
์ด๋ธ(nn.EmbeddingBag) ์์ฑ์ ๋ด๋นํฉ๋๋ค.
3939 ๋ํ ๋ง์คํฐ๋ ๋ ํธ๋ ์ด๋์ ํ์ต ๋ฃจํ๋ฅผ ์ํํฉ๋๋ค.
40402) 1๊ฐ์ ๋งค๊ฐ๋ณ์ ์๋ฒ๋ ๊ธฐ๋ณธ์ ์ผ๋ก ๋ฉ๋ชจ๋ฆฌ์ ์๋ฒ ๋ฉ ํ
์ด๋ธ์ ๋ณด์ ํ๊ณ ๋ง์คํฐ ๋ฐ ํธ๋ ์ด๋์ RPC์ ์๋ตํฉ๋๋ค.
41- 3) 2๊ฐ์ ํธ๋ ์ด๋๋ `๋ถ์ฐ ๋ฐ์ดํฐ ๋ณ๋ ฌ <https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel>`__์
41+ 3) 2๊ฐ์ ํธ๋ ์ด๋๋ `๋ถ์ฐ ๋ฐ์ดํฐ ๋ณ๋ ฌ <https://pytorch.org/docs/stable/nn.html#torch.nn.parallel.DistributedDataParallel >`__ ์
4242 ์ฌ์ฉํ์ฌ ์์ฒด์ ์ผ๋ก ๋ณต์ ๋๋ FC ๋ ์ด์ด(nn.Linear)๋ฅผ ์ ์ฅํฉ๋๋ค.
4343 ํธ๋ ์ด๋๋ ๋ํ ์๋ฐฉํฅ ์ ๋ฌ(forward pass), ์ญ๋ฐฉํฅ ์ ๋ฌ(backward pass) ๋ฐ ์ต์ ํ ๋จ๊ณ๋ฅผ ์คํํด์ผ ํฉ๋๋ค.
4444
4545|
4646 ์ ์ฒด์ ์ธ ํ์ต๊ณผ์ ์ ๋ค์๊ณผ ๊ฐ์ด ์คํ๋ฉ๋๋ค:
4747
48481) ๋ง์คํฐ๋ ๋งค๊ฐ๋ณ์ ์๋ฒ์ ์๋ฒ ๋ฉ ํ
์ด๋ธ์ ๋ด๊ณ ์๋
49- `์๊ฒฉ ๋ชจ๋(RemoteModule) <https://pytorch.org/docs/master/rpc.html#remotemodule>`__์ ์์ฑํฉ๋๋ค.
49+ `์๊ฒฉ ๋ชจ๋(RemoteModule) <https://pytorch.org/docs/master/rpc.html#remotemodule >`__ ์ ์์ฑํฉ๋๋ค.
50502) ๊ทธ๋ฐ ๋ค์ ๋ง์คํฐ๋ ํธ๋ ์ด๋์ ํ์ต ๋ฃจํ๋ฅผ ์์ํ๊ณ ์๊ฒฉ ๋ชจ๋(remote module)์ ํธ๋ ์ด๋์๊ฒ ์ ๋ฌํฉ๋๋ค.
51513) ํธ๋ ์ด๋๋ ๋จผ์ ๋ง์คํฐ์์ ์ ๊ณตํ๋ ์๊ฒฉ ๋ชจ๋์ ์ฌ์ฉํ์ฌ
52- ์๋ฒ ๋ฉ ์ฐพ๊ธฐ ์์
(embedding lookup)์ ์ํํ ๋ค์ DDP ๋ด๋ถ์ ๊ฐ์ธ์ง FC ๋ ์ด์ด๋ฅผ ์คํํ๋ ``HybridModel``์ ์์ฑํฉ๋๋ค.
52+ ์๋ฒ ๋ฉ ์ฐพ๊ธฐ ์์
(embedding lookup)์ ์ํํ ๋ค์ DDP ๋ด๋ถ์ ๊ฐ์ธ์ง FC ๋ ์ด์ด๋ฅผ ์คํํ๋ ``HybridModel `` ์ ์์ฑํฉ๋๋ค.
53534) ํธ๋ ์ด๋๋ ๋ชจ๋ธ์ ์๋ฐฉํฅ ์ ๋ฌ์ ์คํํ๊ณ ์์ค์ ์ฌ์ฉํ์ฌ `๋ถ์ฐ Autograd <https://pytorch.org/docs/master/rpc.html#distributed-autograd-framework >`__ ๋ฅผ
5454 ์ฌ์ฉํ์ฌ ์ญ๋ฐฉํฅ ์ ๋ฌ์ ์คํํฉ๋๋ค.
55555) ์ญ๋ฐฉํฅ ์ ๋ฌ์ ์ผ๋ถ๋ก FC ๋ ์ด์ด์ ๋ณํ๋๊ฐ ๋จผ์ ๊ณ์ฐ๋๊ณ DDP์ allreduce๋ฅผ ํตํด ๋ชจ๋ ํธ๋ ์ด๋์ ๋๊ธฐํ๋ฉ๋๋ค.
56566) ๋ค์์ผ๋ก, ๋ถ์ฐ Autograd๋ ๋งค๊ฐ๋ณ์ ์๋ฒ๋ก ๋ณํ๋๋ฅผ ์ ํํ๊ณ ๊ทธ๊ณณ์์ ์๋ฒ ๋ฉ ํ
์ด๋ธ์ ๋ณํ๋๊ฐ ์
๋ฐ์ดํธ๋ฉ๋๋ค.
57- 7) ๋ง์ง๋ง์ผ๋ก, `๋ถ์ฐ ์ตํฐ๋ง์ด์ (DistributedOptimizer) <https://pytorch.org/docs/master/rpc.html#module-torch.distributed.optim>`__๋ ๋ชจ๋ ๋งค๊ฐ๋ณ์๋ฅผ ์
๋ฐ์ดํธํ๋ ๋ฐ ์ฌ์ฉ๋ฉ๋๋ค.
57+ 7) ๋ง์ง๋ง์ผ๋ก, `๋ถ์ฐ ์ตํฐ๋ง์ด์ (DistributedOptimizer) <https://pytorch.org/docs/master/rpc.html#module-torch.distributed.optim >`__ ๋ ๋ชจ๋ ๋งค๊ฐ๋ณ์๋ฅผ ์
๋ฐ์ดํธํ๋ ๋ฐ ์ฌ์ฉ๋ฉ๋๋ค.
5858
5959.. warning ::
6060
6868
6969TCP init_method๋ฅผ ์ฌ์ฉํ์ฌ 4๊ฐ์ ๋ชจ๋ worker์์ RPC ํ๋ ์์ํฌ๋ฅผ ์ด๊ธฐํํฉ๋๋ค.
7070RPC ์ด๊ธฐํ๊ฐ ๋๋๋ฉด, ๋ง์คํฐ๋ `EmbeddingBag <https://pytorch.org/docs/master/generated/torch.nn.EmbeddingBag.html >`__ ๋ ์ด์ด๋ฅผ
71- `์๊ฒฉ ๋ชจ๋(RemoteModule) <https://pytorch.org/docs/master/rpc.html#remotemodule>`__์ ์ฌ์ฉํ์ฌ
71+ `์๊ฒฉ ๋ชจ๋(RemoteModule) <https://pytorch.org/docs/master/rpc.html#remotemodule >`__ ์ ์ฌ์ฉํ์ฌ
7272๋งค๊ฐ๋ณ์ ์๋ฒ์ ๋ด๊ณ ์๋ ์๊ฒฉ ๋ชจ๋ ํ๋๋ฅผ ์์ฑํฉ๋๋ค.
7373๊ทธ๋ฐ ๋ค์ ๋ง์คํฐ๋ ๊ฐ ํธ๋ ์ด๋๋ฅผ ๋ฐ๋ณตํ๊ณ `rpc_async <https://pytorch.org/docs/master/rpc.html#torch.distributed.rpc.rpc_async >`__ ๋ฅผ
74- ์ฌ์ฉํ์ฌ ๊ฐ ํธ๋ ์ด๋์์ ``_run_trainer``๋ฅผ ํธ์ถํ์ฌ ๋ฐ๋ณต ํ์ต์ ์์ํฉ๋๋ค.
74+ ์ฌ์ฉํ์ฌ ๊ฐ ํธ๋ ์ด๋์์ ``_run_trainer `` ๋ฅผ ํธ์ถํ์ฌ ๋ฐ๋ณต ํ์ต์ ์์ํฉ๋๋ค.
7575๋ง์ง๋ง์ผ๋ก ๋ง์คํฐ๋ ์ข
๋ฃํ๊ธฐ ์ ์ ๋ชจ๋ ํ์ต์ด ์๋ฃ๋ ๋๊น์ง ๊ธฐ๋ค๋ฆฝ๋๋ค.
7676
77- ํธ๋ ์ด๋๋ `init_process_group <https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group>`__์ ์ฌ์ฉํ์ฌ
78- (2๊ฐ์ ํธ๋ ์ด๋) world_size=2๋ก DDP๋ฅผ ์ํด ``ProcessGroup``์ ์ด๊ธฐํํฉ๋๋ค.
77+ ํธ๋ ์ด๋๋ `init_process_group <https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group >`__ ์ ์ฌ์ฉํ์ฌ
78+ (2๊ฐ์ ํธ๋ ์ด๋) world_size=2๋ก DDP๋ฅผ ์ํด ``ProcessGroup `` ์ ์ด๊ธฐํํฉ๋๋ค.
7979๋ค์์ผ๋ก TCP init_method๋ฅผ ์ฌ์ฉํ์ฌ RPC ํ๋ ์์ํฌ๋ฅผ ์ด๊ธฐํํฉ๋๋ค.
8080์ฌ๊ธฐ์ ์ฃผ์ ํ ์ ์ RPC ์ด๊ธฐํ์ ProgressGroup ์ด๊ธฐํ์์ ์ฐ์ด๋ ํฌํธ(port)๊ฐ ๋ค๋ฅด๋ค๋ ๊ฒ์
๋๋ค.
8181์ด๋ ๋ ํ๋ ์์ํฌ์ ์ด๊ธฐํ ๊ฐ์ ํฌํธ ์ถฉ๋์ ํผํ๊ธฐ ์ํด์ ์
๋๋ค.
82- ์ด๊ธฐํ๊ฐ ์๋ฃ๋๋ฉด ํธ๋ ์ด๋๋ ๋ง์คํฐ์ ``_run_trainer` RPC๋ฅผ ๊ธฐ๋ค๋ฆฌ๊ธฐ๋ง ํ๋ฉด ๋ฉ๋๋ค.
82+ ์ด๊ธฐํ๊ฐ ์๋ฃ๋๋ฉด ํธ๋ ์ด๋๋ ๋ง์คํฐ์ ``_run_trainer `` RPC๋ฅผ ๊ธฐ๋ค๋ฆฌ๊ธฐ๋ง ํ๋ฉด ๋ฉ๋๋ค.
8383
8484ํ๋ผํผํฐ ์๋ฒ๋ RPC ํ๋ ์์ํฌ๋ฅผ ์ด๊ธฐํํ๊ณ ํธ๋ ์ด๋์ ๋ง์คํฐ์ RPC๋ฅผ ๊ธฐ๋ค๋ฆฝ๋๋ค.
8585
@@ -89,14 +89,14 @@ RPC ์ด๊ธฐํ๊ฐ ๋๋๋ฉด, ๋ง์คํฐ๋ `EmbeddingBag <https://pytorch.org/docs
8989 :start-after: BEGIN run_worker
9090 :end-before: END run_worker
9191
92- ํธ๋ ์ด๋์ ๋ํ ์์ธํ ์ค๋ช
์ ์์, ํธ๋ ์ด๋๊ฐ ์ฌ์ฉํ๋ ``HybridModel``์ ๋ํด ์ค๋ช
๋๋ฆฌ๊ฒ ์ต๋๋ค.
93- ์๋์ ์ค๋ช
๋ ๋๋ก ``HybridModel``์ ๋งค๊ฐ๋ณ์ ์๋ฒ์ ์๋ฒ ๋ฉ ํ
์ด๋ธ(``remote_emb_module ``)๊ณผ DDP์ ์ฌ์ฉํ ``device``๋ฅผ ๋ณด์ ํ๋ ์๊ฒฉ ๋ชจ๋์ ์ฌ์ฉํ์ฌ ์ด๊ธฐํ๋ฉ๋๋ค.
92+ ํธ๋ ์ด๋์ ๋ํ ์์ธํ ์ค๋ช
์ ์์, ํธ๋ ์ด๋๊ฐ ์ฌ์ฉํ๋ ``HybridModel `` ์ ๋ํด ์ค๋ช
๋๋ฆฌ๊ฒ ์ต๋๋ค.
93+ ์๋์ ์ค๋ช
๋ ๋๋ก ``HybridModel `` ์ ๋งค๊ฐ๋ณ์ ์๋ฒ์ ์๋ฒ ๋ฉ ํ
์ด๋ธ(``remote_emb_module ``)๊ณผ DDP์ ์ฌ์ฉํ ``device `` ๋ฅผ ๋ณด์ ํ๋ ์๊ฒฉ ๋ชจ๋์ ์ฌ์ฉํ์ฌ ์ด๊ธฐํ๋ฉ๋๋ค.
9494๋ชจ๋ธ ์ด๊ธฐํ๋ DDP ๋ด๋ถ์ `nn.Linear <https://pytorch.org/docs/master/generated/torch.nn.Linear.html >`__ ๋ ์ด์ด๋ฅผ
9595๊ฐ์ธ ๋ชจ๋ ํธ๋ ์ด๋์์ ์ด ๋ ์ด์ด๋ฅผ ๋ณต์ ํ๊ณ ๋๊ธฐํํฉ๋๋ค.
9696
9797
9898๋ชจ๋ธ์ ์๋ฐฉํฅ(forward) ํจ์๋ ๊ฝค ๊ฐ๋จํฉ๋๋ค.
99- RemoteModule์ ``forward``๋ฅผ ์ฌ์ฉํ์ฌ ๋งค๊ฐ๋ณ์ ์๋ฒ์์ ์๋ฒ ๋ฉ ์ฐพ๊ธฐ ์์
(embedding lookup)์ ์ํํ๊ณ ๊ทธ ์ถ๋ ฅ์ FC ๋ ์ด์ด์ ์ ๋ฌํฉ๋๋ค.
99+ RemoteModule์ ``forward `` ๋ฅผ ์ฌ์ฉํ์ฌ ๋งค๊ฐ๋ณ์ ์๋ฒ์์ ์๋ฒ ๋ฉ ์ฐพ๊ธฐ ์์
(embedding lookup)์ ์ํํ๊ณ ๊ทธ ์ถ๋ ฅ์ FC ๋ ์ด์ด์ ์ ๋ฌํฉ๋๋ค.
100100
101101
102102.. literalinclude :: ../advanced_source/rpc_ddp_tutorial/main.py
@@ -106,18 +106,18 @@ RemoteModule์ ``forward``๋ฅผ ์ฌ์ฉํ์ฌ ๋งค๊ฐ๋ณ์ ์๋ฒ์์ ์๋ฒ ๋ฉ
106106
107107๋ค์์ผ๋ก ํธ๋ ์ด๋์ ์ค์ ์ ์ดํด๋ณด๊ฒ ์ต๋๋ค.
108108ํธ๋ ์ด๋๋ ๋จผ์ ๋งค๊ฐ๋ณ์ ์๋ฒ์ ์๋ฒ ๋ฉ ํ
์ด๋ธ๊ณผ ์์ฒด ์์๋ฅผ ๋ณด์ ํ๋ ์๊ฒฉ ๋ชจ๋์ ์ฌ์ฉํ์ฌ
109- ์์์ ์ค๋ช
ํ ``HybridModel``์ ์์ฑํฉ๋๋ค.
109+ ์์์ ์ค๋ช
ํ ``HybridModel `` ์ ์์ฑํฉ๋๋ค.
110110
111- ์ด์ `๋ถ์ฐ ์ตํฐ๋ง์ด์ (DistributedOptimizer) <https://pytorch.org/docs/master/rpc.html#module-torch.distributed.optim>`__๋ก
111+ ์ด์ `๋ถ์ฐ ์ตํฐ๋ง์ด์ (DistributedOptimizer) <https://pytorch.org/docs/master/rpc.html#module-torch.distributed.optim >`__ ๋ก
112112์ต์ ํํ๋ ค๋ ๋ชจ๋ ๋งค๊ฐ๋ณ์์ ๋ํ RRef ๋ชฉ๋ก์ ๊ฒ์ํด์ผ ํฉ๋๋ค.
113113๋งค๊ฐ๋ณ์ ์๋ฒ์์ ์๋ฒ ๋ฉ ํ
์ด๋ธ์ ๋งค๊ฐ๋ณ์๋ฅผ ๊ฒ์ํ๊ธฐ ์ํด
114114RemoteModule์ `remote_parameters <https://pytorch.org/docs/master/rpc.html#torch.distributed.nn.api.remote_module.RemoteModule.remote_parameters >`__ ๋ฅผ ํธ์ถํ ์ ์์ต๋๋ค.
115115๊ทธ๋ฆฌ๊ณ ์ด๊ฒ์ ๊ธฐ๋ณธ์ ์ผ๋ก ์๋ฒ ๋ฉ ํ
์ด๋ธ์ ๋ชจ๋ ๋งค๊ฐ๋ณ์๋ฅผ ์ดํด๋ณด๊ณ RRef ๋ชฉ๋ก์ ๋ฐํํฉ๋๋ค.
116116ํธ๋ ์ด๋๋ RPC๋ฅผ ํตํด ๋งค๊ฐ๋ณ์ ์๋ฒ์์ ์ด ๋ฉ์๋๋ฅผ ํธ์ถํ์ฌ ์ํ๋ ๋งค๊ฐ๋ณ์์ ๋ํ RRef ๋ชฉ๋ก์ ์์ ํฉ๋๋ค.
117117DistributedOptimizer๋ ํญ์ ์ต์ ํํด์ผ ํ๋ ๋งค๊ฐ๋ณ์์ ๋ํ RRef ๋ชฉ๋ก์ ๊ฐ์ ธ์ค๊ธฐ ๋๋ฌธ์ FC ๋ ์ด์ด์ ์ ์ญ ๋งค๊ฐ๋ณ์์ ๋ํด์๋ RRef๋ฅผ ์์ฑํด์ผ ํฉ๋๋ค.
118- ์ด๊ฒ์ ``model.fc.parameters()``๋ฅผ ํ์ํ๊ณ ๊ฐ ๋งค๊ฐ๋ณ์์ ๋ํ RRef๋ฅผ ์์ฑํ๊ณ
119- ``remote_parameters()``์์ ๋ฐํ๋ ๋ชฉ๋ก์ ์ถ๊ฐํจ์ผ๋ก์จ ์ํ๋ฉ๋๋ค.
120- ์ฐธ๊ณ ๋ก ``model.parameters()``๋ ์ฌ์ฉํ ์ ์์ต๋๋ค. ``RemoteModule``์์ ์ง์ํ์ง ์๋ ``model.remote_emb_module.parameters()``๋ฅผ ์ฌ๊ท์ ์ผ๋ก ํธ์ถํ๊ธฐ ๋๋ฌธ์
๋๋ค.
118+ ์ด๊ฒ์ ``model.fc.parameters() `` ๋ฅผ ํ์ํ๊ณ ๊ฐ ๋งค๊ฐ๋ณ์์ ๋ํ RRef๋ฅผ ์์ฑํ๊ณ
119+ ``remote_parameters() `` ์์ ๋ฐํ๋ ๋ชฉ๋ก์ ์ถ๊ฐํจ์ผ๋ก์จ ์ํ๋ฉ๋๋ค.
120+ ์ฐธ๊ณ ๋ก ``model.parameters() `` ๋ ์ฌ์ฉํ ์ ์์ต๋๋ค. ``RemoteModule `` ์์ ์ง์ํ์ง ์๋ ``model.remote_emb_module.parameters() `` ๋ฅผ ์ฌ๊ท์ ์ผ๋ก ํธ์ถํ๊ธฐ ๋๋ฌธ์
๋๋ค.
121121
122122๋ง์ง๋ง์ผ๋ก ๋ชจ๋ RRef๋ฅผ ์ฌ์ฉํ์ฌ DistributedOptimizer๋ฅผ ๋ง๋ค๊ณ CrossEntropyLoss ํจ์๋ฅผ ์ ์ํฉ๋๋ค.
123123
@@ -127,7 +127,7 @@ DistributedOptimizer๋ ํญ์ ์ต์ ํํด์ผ ํ๋ ๋งค๊ฐ๋ณ์์ ๋ํ RRe
127127 :end-before: END setup_trainer
128128
129129์ด์ ๊ฐ ํธ๋ ์ด๋์์ ์คํ๋๋ ๊ธฐ๋ณธ ํ์ต ๋ฃจํ๋ฅผ ์๊ฐํ๊ฒ ์ต๋๋ค.
130- ``get_next_batch``๋ ํ์ต์ ์ํ ์์์ ์
๋ ฅ๊ณผ ๋์์ ์์ฑํ๋ ๊ฒ์ ๋์์ฃผ๋ ํจ์์ผ ๋ฟ์
๋๋ค.
130+ ``get_next_batch `` ๋ ํ์ต์ ์ํ ์์์ ์
๋ ฅ๊ณผ ๋์์ ์์ฑํ๋ ๊ฒ์ ๋์์ฃผ๋ ํจ์์ผ ๋ฟ์
๋๋ค.
131131์ฌ๋ฌ ์ํญ(epoch)๊ณผ ๊ฐ ๋ฐฐ์น(batch)์ ๋ํด ํ์ต ๋ฃจํ๋ฅผ ์คํํฉ๋๋ค:
132132
1331331) ๋จผ์ ๋ถ์ฐ Autograd์ ๋ํด
@@ -143,4 +143,4 @@ DistributedOptimizer๋ ํญ์ ์ต์ ํํด์ผ ํ๋ ๋งค๊ฐ๋ณ์์ ๋ํ RRe
143143 :end-before: END run_trainer
144144.. code :: python
145145
146- ์ ์ฒด ์์ ์ ์์ค ์ฝ๋๋ `์ฌ๊ธฐ <https://github.com/pytorch/examples/tree/master/distributed/rpc/ddp_rpc>`__์์ ์ฐพ์ ์ ์์ต๋๋ค.
146+ ์ ์ฒด ์์ ์ ์์ค ์ฝ๋๋ `์ฌ๊ธฐ <https://github.com/pytorch/examples/tree/master/distributed/rpc/ddp_rpc >`__ ์์ ์ฐพ์ ์ ์์ต๋๋ค.
0 commit comments