์ ์: Pritam Damania and Yi Wang
๋ฒ์ญ: ๋ฐ๋ค์
Note
|edit| ์ด ํํ ๋ฆฌ์ผ์ ์์ค ์ฝ๋๋ GitHub ์์ ํ์ธํ๊ณ ๋ณ๊ฒฝํด ๋ณผ ์ ์์ต๋๋ค.
์ด ํํ ๋ฆฌ์ผ์ ๊ฐ๋จํ ์์ ๋ฅผ ์ฌ์ฉํ์ฌ ๋ถ์ฐ ๋ฐ์ดํฐ ๋ณ๋ ฌ ์ฒ๋ฆฌ(distributed data parallelism)์ ๋ถ์ฐ ๋ชจ๋ธ ๋ณ๋ ฌ ์ฒ๋ฆฌ(distributed model parallelism)๋ฅผ ๊ฒฐํฉํ์ฌ ๊ฐ๋จํ ๋ชจ๋ธ ํ์ต์ํฌ ๋ ๋ถ์ฐ ๋ฐ์ดํฐ ๋ณ๋ ฌ(DistributedDataParallel) (DDP)๊ณผ ๋ถ์ฐ RPC ํ๋ ์์ํฌ(Distributed RPC framework) ๋ฅผ ๊ฒฐํฉํ๋ ๋ฐฉ๋ฒ์ ๋ํด ์ค๋ช ํฉ๋๋ค. ์์ ์ ์์ค ์ฝ๋๋ ์ฌ๊ธฐ ์์ ํ์ธํ ์ ์์ต๋๋ค.
์ด์ ํํ ๋ฆฌ์ผ ๋ด์ฉ์ด์๋ ๋ถ์ฐ ๋ฐ์ดํฐ ๋ณ๋ ฌ ์์ํ๊ธฐ ์ ๋ถ์ฐ RPC ํ๋ ์์ํฌ ์์ํ๊ธฐ ๋ ๋ถ์ฐ ๋ฐ์ดํฐ ๋ณ๋ ฌ ๋ฐ ๋ถ์ฐ ๋ชจ๋ธ ๋ณ๋ ฌ ํ์ต์ ๊ฐ๊ฐ ์ํํ๋ ๋ฐฉ๋ฒ์ ๋ํด ์ค๋ช ํฉ๋๋ค. ๊ทธ๋ฌ๋ ์ด ๋ ๊ฐ์ง ๊ธฐ์ ์ ๊ฒฐํฉํ ์ ์๋ ๋ช ๊ฐ์ง ํ์ต ํจ๋ฌ๋ค์์ด ์์ต๋๋ค. ์๋ฅผ ๋ค์ด:
- ํฌ์ ๋ถ๋ถ(ํฐ ์๋ฒ ๋ฉ ํ ์ด๋ธ)๊ณผ ๋ฐ์ง ๋ถ๋ถ(FC ๋ ์ด์ด)์ด ์๋ ๋ชจ๋ธ์ด ์๋ ๊ฒฝ์ฐ, ๋งค๊ฐ๋ณ์ ์๋ฒ(parameter server)์ ์๋ฒ ๋ฉ ํ ์ด๋ธ(embedding table)์ ๋๊ณ ๋ถ์ฐ ๋ฐ์ดํฐ ๋ณ๋ ฌ ์ ์ฌ์ฉํ์ฌ ์ฌ๋ฌ ํธ๋ ์ด๋์ ๊ฑธ์ณ FC ๋ ์ด์ด๋ฅผ ๋ณต์ ํ๋ ๊ฒ์ ์ํ ์๋ ์์ต๋๋ค. ์ด๋ ๋ถ์ฐ RPC ํ๋ ์์ํฌ ๋ ๋งค๊ฐ๋ณ์ ์๋ฒ์์ ์๋ฒ ๋ฉ ์ฐพ๊ธฐ ์์ (embedding lookup)์ ์ํํ๋ ๋ฐ ์ฌ์ฉํ ์ ์์ต๋๋ค.
- ๋ค์์ PipeDream ๋ฌธ์์์ ์ค๋ช ๋ ํ์ด๋ธ๋ฆฌ๋ ๋ณ๋ ฌ ์ฒ๋ฆฌ ํ์ฑํํ๊ธฐ ์ ๋๋ค. ๋ถ์ฐ RPC ํ๋ ์์ํฌ ๋ฅผ ์ฌ์ฉํ์ฌ ์ฌ๋ฌ worker์ ๊ฑธ์ณ ๋ชจ๋ธ์ ๋จ๊ณ๋ฅผ ํ์ดํ๋ผ์ธ(pipeline)ํ ์ ์๊ณ (ํ์์ ๋ฐ๋ผ) ๋ถ์ฐ ๋ฐ์ดํฐ ๋ณ๋ ฌ ์ ์ด์ฉํด์ ๊ฐ ๋จ๊ณ๋ฅผ ๋ณต์ ํ ์ ์์ต๋๋ค.
์ด ํํ ๋ฆฌ์ผ์์๋ ์์์ ์ธ๊ธํ ์ฒซ ๋ฒ์งธ ๊ฒฝ์ฐ๋ฅผ ๋ค๋ฃฐ ๊ฒ์ ๋๋ค. ๋ค์๊ณผ ๊ฐ์ด ์ด 4๊ฐ์ worker๊ฐ ์์ต๋๋ค:
- 1๊ฐ์ ๋ง์คํฐ๋ ๋งค๊ฐ๋ณ์ ์๋ฒ์ ์๋ฒ ๋ฉ ํ ์ด๋ธ(nn.EmbeddingBag) ์์ฑ์ ๋ด๋นํฉ๋๋ค. ๋ํ ๋ง์คํฐ๋ ๋ ํธ๋ ์ด๋์ ํ์ต ๋ฃจํ๋ฅผ ์ํํฉ๋๋ค.
- 1๊ฐ์ ๋งค๊ฐ๋ณ์ ์๋ฒ๋ ๊ธฐ๋ณธ์ ์ผ๋ก ๋ฉ๋ชจ๋ฆฌ์ ์๋ฒ ๋ฉ ํ ์ด๋ธ์ ๋ณด์ ํ๊ณ ๋ง์คํฐ ๋ฐ ํธ๋ ์ด๋์ RPC์ ์๋ตํฉ๋๋ค.
- 2๊ฐ์ ํธ๋ ์ด๋๋ ๋ถ์ฐ ๋ฐ์ดํฐ ๋ณ๋ ฌ ์ ์ฌ์ฉํ์ฌ ์์ฒด์ ์ผ๋ก ๋ณต์ ๋๋ FC ๋ ์ด์ด(nn.Linear)๋ฅผ ์ ์ฅํฉ๋๋ค. ํธ๋ ์ด๋๋ ๋ํ ์๋ฐฉํฅ ์ ๋ฌ(forward pass), ์ญ๋ฐฉํฅ ์ ๋ฌ(backward pass) ๋ฐ ์ต์ ํ ๋จ๊ณ๋ฅผ ์คํํด์ผ ํฉ๋๋ค.
์ ์ฒด์ ์ธ ํ์ต๊ณผ์ ์ ๋ค์๊ณผ ๊ฐ์ด ์คํ๋ฉ๋๋ค:
- ๋ง์คํฐ๋ ๋งค๊ฐ๋ณ์ ์๋ฒ์ ์๋ฒ ๋ฉ ํ ์ด๋ธ์ ๋ด๊ณ ์๋ ์๊ฒฉ ๋ชจ๋(RemoteModule) ์ ์์ฑํฉ๋๋ค.
- ๊ทธ๋ฐ ๋ค์ ๋ง์คํฐ๋ ํธ๋ ์ด๋์ ํ์ต ๋ฃจํ๋ฅผ ์์ํ๊ณ ์๊ฒฉ ๋ชจ๋(remote module)์ ํธ๋ ์ด๋์๊ฒ ์ ๋ฌํฉ๋๋ค.
- ํธ๋ ์ด๋๋ ๋จผ์ ๋ง์คํฐ์์ ์ ๊ณตํ๋ ์๊ฒฉ ๋ชจ๋์ ์ฌ์ฉํ์ฌ
์๋ฒ ๋ฉ ์ฐพ๊ธฐ ์์
(embedding lookup)์ ์ํํ ๋ค์ DDP ๋ด๋ถ์ ๊ฐ์ธ์ง FC ๋ ์ด์ด๋ฅผ ์คํํ๋
HybridModel์ ์์ฑํฉ๋๋ค. - ํธ๋ ์ด๋๋ ๋ชจ๋ธ์ ์๋ฐฉํฅ ์ ๋ฌ์ ์คํํ๊ณ ์์ค์ ์ฌ์ฉํ์ฌ ๋ถ์ฐ Autograd ๋ฅผ ์ฌ์ฉํ์ฌ ์ญ๋ฐฉํฅ ์ ๋ฌ์ ์คํํฉ๋๋ค.
- ์ญ๋ฐฉํฅ ์ ๋ฌ์ ์ผ๋ถ๋ก FC ๋ ์ด์ด์ ๋ณํ๋๊ฐ ๋จผ์ ๊ณ์ฐ๋๊ณ DDP์ allreduce๋ฅผ ํตํด ๋ชจ๋ ํธ๋ ์ด๋์ ๋๊ธฐํ๋ฉ๋๋ค.
- ๋ค์์ผ๋ก, ๋ถ์ฐ Autograd๋ ๋งค๊ฐ๋ณ์ ์๋ฒ๋ก ๋ณํ๋๋ฅผ ์ ํํ๊ณ ๊ทธ๊ณณ์์ ์๋ฒ ๋ฉ ํ ์ด๋ธ์ ๋ณํ๋๊ฐ ์ ๋ฐ์ดํธ๋ฉ๋๋ค.
- ๋ง์ง๋ง์ผ๋ก, ๋ถ์ฐ ์ตํฐ๋ง์ด์ (DistributedOptimizer) ๋ ๋ชจ๋ ๋งค๊ฐ๋ณ์๋ฅผ ์ ๋ฐ์ดํธํ๋ ๋ฐ ์ฌ์ฉ๋ฉ๋๋ค.
Warning
DDP์ RPC๋ฅผ ๊ฒฐํฉํ ๋, ์ญ๋ฐฉํฅ ์ ๋ฌ์ ๋ํด ํญ์ ๋ถ์ฐ Autograd ๋ฅผ ์ฌ์ฉํด์ผ ํฉ๋๋ค.
์ด์ ๊ฐ ๋ถ๋ถ์ ์์ธํ ์ดํด๋ณด๊ฒ ์ต๋๋ค. ๋จผ์ ํ์ต์ ์ํํ๊ธฐ ์ ์ ๋ชจ๋ ์์ ์๋ฅผ ์ค์ ํด์ผ ํฉ๋๋ค. ์์ 0๊ณผ 1์ ํธ๋ ์ด๋, ์์ 2๋ ๋ง์คํฐ, ์์ 3์ ๋งค๊ฐ๋ณ์ ์๋ฒ์ธ 4๊ฐ์ ํ๋ก์ธ์ค๋ฅผ ๋ง๋ญ๋๋ค.
TCP init_method๋ฅผ ์ฌ์ฉํ์ฌ 4๊ฐ์ ๋ชจ๋ worker์์ RPC ํ๋ ์์ํฌ๋ฅผ ์ด๊ธฐํํฉ๋๋ค.
RPC ์ด๊ธฐํ๊ฐ ๋๋๋ฉด, ๋ง์คํฐ๋ EmbeddingBag ๋ ์ด์ด๋ฅผ
์๊ฒฉ ๋ชจ๋(RemoteModule) ์ ์ฌ์ฉํ์ฌ
๋งค๊ฐ๋ณ์ ์๋ฒ์ ๋ด๊ณ ์๋ ์๊ฒฉ ๋ชจ๋ ํ๋๋ฅผ ์์ฑํฉ๋๋ค.
๊ทธ๋ฐ ๋ค์ ๋ง์คํฐ๋ ๊ฐ ํธ๋ ์ด๋๋ฅผ ๋ฐ๋ณตํ๊ณ rpc_async ๋ฅผ
์ฌ์ฉํ์ฌ ๊ฐ ํธ๋ ์ด๋์์ _run_trainer ๋ฅผ ํธ์ถํ์ฌ ๋ฐ๋ณต ํ์ต์ ์์ํฉ๋๋ค.
๋ง์ง๋ง์ผ๋ก ๋ง์คํฐ๋ ์ข
๋ฃํ๊ธฐ ์ ์ ๋ชจ๋ ํ์ต์ด ์๋ฃ๋ ๋๊น์ง ๊ธฐ๋ค๋ฆฝ๋๋ค.
ํธ๋ ์ด๋๋ init_process_group ์ ์ฌ์ฉํ์ฌ
(2๊ฐ์ ํธ๋ ์ด๋) world_size=2๋ก DDP๋ฅผ ์ํด ProcessGroup ์ ์ด๊ธฐํํฉ๋๋ค.
๋ค์์ผ๋ก TCP init_method๋ฅผ ์ฌ์ฉํ์ฌ RPC ํ๋ ์์ํฌ๋ฅผ ์ด๊ธฐํํฉ๋๋ค.
์ฌ๊ธฐ์ ์ฃผ์ ํ ์ ์ RPC ์ด๊ธฐํ์ ProgressGroup ์ด๊ธฐํ์์ ์ฐ์ด๋ ํฌํธ(port)๊ฐ ๋ค๋ฅด๋ค๋ ๊ฒ์
๋๋ค.
์ด๋ ๋ ํ๋ ์์ํฌ์ ์ด๊ธฐํ ๊ฐ์ ํฌํธ ์ถฉ๋์ ํผํ๊ธฐ ์ํด์ ์
๋๋ค.
์ด๊ธฐํ๊ฐ ์๋ฃ๋๋ฉด ํธ๋ ์ด๋๋ ๋ง์คํฐ์ _run_trainer RPC๋ฅผ ๊ธฐ๋ค๋ฆฌ๊ธฐ๋ง ํ๋ฉด ๋ฉ๋๋ค.
ํ๋ผํผํฐ ์๋ฒ๋ RPC ํ๋ ์์ํฌ๋ฅผ ์ด๊ธฐํํ๊ณ ํธ๋ ์ด๋์ ๋ง์คํฐ์ RPC๋ฅผ ๊ธฐ๋ค๋ฆฝ๋๋ค.
.. literalinclude:: ../advanced_source/rpc_ddp_tutorial/main.py :language: py :start-after: BEGIN run_worker :end-before: END run_worker
ํธ๋ ์ด๋์ ๋ํ ์์ธํ ์ค๋ช
์ ์์, ํธ๋ ์ด๋๊ฐ ์ฌ์ฉํ๋ HybridModel ์ ๋ํด ์ค๋ช
๋๋ฆฌ๊ฒ ์ต๋๋ค.
์๋์ ์ค๋ช
๋ ๋๋ก HybridModel ์ ๋งค๊ฐ๋ณ์ ์๋ฒ์ ์๋ฒ ๋ฉ ํ
์ด๋ธ(remote_emb_module)๊ณผ DDP์ ์ฌ์ฉํ device ๋ฅผ ๋ณด์ ํ๋ ์๊ฒฉ ๋ชจ๋์ ์ฌ์ฉํ์ฌ ์ด๊ธฐํ๋ฉ๋๋ค.
๋ชจ๋ธ ์ด๊ธฐํ๋ DDP ๋ด๋ถ์ nn.Linear ๋ ์ด์ด๋ฅผ
๊ฐ์ธ ๋ชจ๋ ํธ๋ ์ด๋์์ ์ด ๋ ์ด์ด๋ฅผ ๋ณต์ ํ๊ณ ๋๊ธฐํํฉ๋๋ค.
๋ชจ๋ธ์ ์๋ฐฉํฅ(forward) ํจ์๋ ๊ฝค ๊ฐ๋จํฉ๋๋ค.
RemoteModule์ forward ๋ฅผ ์ฌ์ฉํ์ฌ ๋งค๊ฐ๋ณ์ ์๋ฒ์์ ์๋ฒ ๋ฉ ์ฐพ๊ธฐ ์์
(embedding lookup)์ ์ํํ๊ณ ๊ทธ ์ถ๋ ฅ์ FC ๋ ์ด์ด์ ์ ๋ฌํฉ๋๋ค.
.. literalinclude:: ../advanced_source/rpc_ddp_tutorial/main.py :language: py :start-after: BEGIN hybrid_model :end-before: END hybrid_model
๋ค์์ผ๋ก ํธ๋ ์ด๋์ ์ค์ ์ ์ดํด๋ณด๊ฒ ์ต๋๋ค.
ํธ๋ ์ด๋๋ ๋จผ์ ๋งค๊ฐ๋ณ์ ์๋ฒ์ ์๋ฒ ๋ฉ ํ
์ด๋ธ๊ณผ ์์ฒด ์์๋ฅผ ๋ณด์ ํ๋ ์๊ฒฉ ๋ชจ๋์ ์ฌ์ฉํ์ฌ
์์์ ์ค๋ช
ํ HybridModel ์ ์์ฑํฉ๋๋ค.
์ด์ ๋ถ์ฐ ์ตํฐ๋ง์ด์ (DistributedOptimizer) ๋ก
์ต์ ํํ๋ ค๋ ๋ชจ๋ ๋งค๊ฐ๋ณ์์ ๋ํ RRef ๋ชฉ๋ก์ ๊ฒ์ํด์ผ ํฉ๋๋ค.
๋งค๊ฐ๋ณ์ ์๋ฒ์์ ์๋ฒ ๋ฉ ํ
์ด๋ธ์ ๋งค๊ฐ๋ณ์๋ฅผ ๊ฒ์ํ๊ธฐ ์ํด
RemoteModule์ remote_parameters ๋ฅผ ํธ์ถํ ์ ์์ต๋๋ค.
๊ทธ๋ฆฌ๊ณ ์ด๊ฒ์ ๊ธฐ๋ณธ์ ์ผ๋ก ์๋ฒ ๋ฉ ํ
์ด๋ธ์ ๋ชจ๋ ๋งค๊ฐ๋ณ์๋ฅผ ์ดํด๋ณด๊ณ RRef ๋ชฉ๋ก์ ๋ฐํํฉ๋๋ค.
ํธ๋ ์ด๋๋ RPC๋ฅผ ํตํด ๋งค๊ฐ๋ณ์ ์๋ฒ์์ ์ด ๋ฉ์๋๋ฅผ ํธ์ถํ์ฌ ์ํ๋ ๋งค๊ฐ๋ณ์์ ๋ํ RRef ๋ชฉ๋ก์ ์์ ํฉ๋๋ค.
DistributedOptimizer๋ ํญ์ ์ต์ ํํด์ผ ํ๋ ๋งค๊ฐ๋ณ์์ ๋ํ RRef ๋ชฉ๋ก์ ๊ฐ์ ธ์ค๊ธฐ ๋๋ฌธ์ FC ๋ ์ด์ด์ ์ ์ญ ๋งค๊ฐ๋ณ์์ ๋ํด์๋ RRef๋ฅผ ์์ฑํด์ผ ํฉ๋๋ค.
์ด๊ฒ์ model.fc.parameters() ๋ฅผ ํ์ํ๊ณ ๊ฐ ๋งค๊ฐ๋ณ์์ ๋ํ RRef๋ฅผ ์์ฑํ๊ณ
remote_parameters() ์์ ๋ฐํ๋ ๋ชฉ๋ก์ ์ถ๊ฐํจ์ผ๋ก์จ ์ํ๋ฉ๋๋ค.
์ฐธ๊ณ ๋ก model.parameters() ๋ ์ฌ์ฉํ ์ ์์ต๋๋ค. RemoteModule ์์ ์ง์ํ์ง ์๋ model.remote_emb_module.parameters() ๋ฅผ ์ฌ๊ท์ ์ผ๋ก ํธ์ถํ๊ธฐ ๋๋ฌธ์
๋๋ค.
๋ง์ง๋ง์ผ๋ก ๋ชจ๋ RRef๋ฅผ ์ฌ์ฉํ์ฌ DistributedOptimizer๋ฅผ ๋ง๋ค๊ณ CrossEntropyLoss ํจ์๋ฅผ ์ ์ํฉ๋๋ค.
.. literalinclude:: ../advanced_source/rpc_ddp_tutorial/main.py :language: py :start-after: BEGIN setup_trainer :end-before: END setup_trainer
์ด์ ๊ฐ ํธ๋ ์ด๋์์ ์คํ๋๋ ๊ธฐ๋ณธ ํ์ต ๋ฃจํ๋ฅผ ์๊ฐํ๊ฒ ์ต๋๋ค.
get_next_batch ๋ ํ์ต์ ์ํ ์์์ ์
๋ ฅ๊ณผ ๋์์ ์์ฑํ๋ ๊ฒ์ ๋์์ฃผ๋ ํจ์์ผ ๋ฟ์
๋๋ค.
์ฌ๋ฌ ์ํญ(epoch)๊ณผ ๊ฐ ๋ฐฐ์น(batch)์ ๋ํด ํ์ต ๋ฃจํ๋ฅผ ์คํํฉ๋๋ค:
- ๋จผ์ ๋ถ์ฐ Autograd์ ๋ํด ๋ถ์ฐ Autograd Context ๋ฅผ ์ค์ ํฉ๋๋ค.
- ๋ชจ๋ธ์ ์๋ฐฉํฅ ์ ๋ฌ์ ์คํํ๊ณ ํด๋น ์ถ๋ ฅ์ ๊ฒ์(retrieve)ํฉ๋๋ค.
- ์์ค ํจ์๋ฅผ ์ฌ์ฉํ์ฌ ์ถ๋ ฅ๊ณผ ๋ชฉํ๋ฅผ ๊ธฐ๋ฐ์ผ๋ก ์์ค์ ๊ณ์ฐํฉ๋๋ค.
- ๋ถ์ฐ Autograd๋ฅผ ์ฌ์ฉํ์ฌ ์์ค์ ์ฌ์ฉํ์ฌ ๋ถ์ฐ ์ญ๋ฐฉํฅ ์ ๋ฌ์ ์คํํฉ๋๋ค.
- ๋ง์ง๋ง์ผ๋ก ๋ถ์ฐ ์ตํฐ๋ง์ด์ ๋จ๊ณ๋ฅผ ์คํํ์ฌ ๋ชจ๋ ๋งค๊ฐ๋ณ์๋ฅผ ์ต์ ํํฉ๋๋ค.
.. literalinclude:: ../advanced_source/rpc_ddp_tutorial/main.py :language: py :start-after: BEGIN run_trainer :end-before: END run_trainer
์ ์ฒด ์์ ์ ์์ค ์ฝ๋๋ ์ฌ๊ธฐ ์์ ์ฐพ์ ์ ์์ต๋๋ค.