Skip to content

Latest commit

ย 

History

History
146 lines (114 loc) ยท 11 KB

File metadata and controls

146 lines (114 loc) ยท 11 KB

๋ถ„์‚ฐ ๋ฐ์ดํ„ฐ ๋ณ‘๋ ฌ(DDP)๊ณผ ๋ถ„์‚ฐ RPC ํ”„๋ ˆ์ž„์›Œํฌ ๊ฒฐํ•ฉ

์ €์ž: Pritam Damania and Yi Wang

๋ฒˆ์—ญ: ๋ฐ•๋‹ค์ •

Note

|edit| ์ด ํŠœํ† ๋ฆฌ์–ผ์˜ ์†Œ์Šค ์ฝ”๋“œ๋Š” GitHub ์—์„œ ํ™•์ธํ•˜๊ณ  ๋ณ€๊ฒฝํ•ด ๋ณผ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

์ด ํŠœํ† ๋ฆฌ์–ผ์€ ๊ฐ„๋‹จํ•œ ์˜ˆ์ œ๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๋ถ„์‚ฐ ๋ฐ์ดํ„ฐ ๋ณ‘๋ ฌ ์ฒ˜๋ฆฌ(distributed data parallelism)์™€ ๋ถ„์‚ฐ ๋ชจ๋ธ ๋ณ‘๋ ฌ ์ฒ˜๋ฆฌ(distributed model parallelism)๋ฅผ ๊ฒฐํ•ฉํ•˜์—ฌ ๊ฐ„๋‹จํ•œ ๋ชจ๋ธ ํ•™์Šต์‹œํ‚ฌ ๋•Œ ๋ถ„์‚ฐ ๋ฐ์ดํ„ฐ ๋ณ‘๋ ฌ(DistributedDataParallel) (DDP)๊ณผ ๋ถ„์‚ฐ RPC ํ”„๋ ˆ์ž„์›Œํฌ(Distributed RPC framework) ๋ฅผ ๊ฒฐํ•ฉํ•˜๋Š” ๋ฐฉ๋ฒ•์— ๋Œ€ํ•ด ์„ค๋ช…ํ•ฉ๋‹ˆ๋‹ค. ์˜ˆ์ œ์˜ ์†Œ์Šค ์ฝ”๋“œ๋Š” ์—ฌ๊ธฐ ์—์„œ ํ™•์ธํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

์ด์ „ ํŠœํ† ๋ฆฌ์–ผ ๋‚ด์šฉ์ด์—ˆ๋˜ ๋ถ„์‚ฐ ๋ฐ์ดํ„ฐ ๋ณ‘๋ ฌ ์‹œ์ž‘ํ•˜๊ธฐ ์™€ ๋ถ„์‚ฐ RPC ํ”„๋ ˆ์ž„์›Œํฌ ์‹œ์ž‘ํ•˜๊ธฐ ๋Š” ๋ถ„์‚ฐ ๋ฐ์ดํ„ฐ ๋ณ‘๋ ฌ ๋ฐ ๋ถ„์‚ฐ ๋ชจ๋ธ ๋ณ‘๋ ฌ ํ•™์Šต์„ ๊ฐ๊ฐ ์ˆ˜ํ–‰ํ•˜๋Š” ๋ฐฉ๋ฒ•์— ๋Œ€ํ•ด ์„ค๋ช…ํ•ฉ๋‹ˆ๋‹ค. ๊ทธ๋Ÿฌ๋‚˜ ์ด ๋‘ ๊ฐ€์ง€ ๊ธฐ์ˆ ์„ ๊ฒฐํ•ฉํ•  ์ˆ˜ ์žˆ๋Š” ๋ช‡ ๊ฐ€์ง€ ํ•™์Šต ํŒจ๋Ÿฌ๋‹ค์ž„์ด ์žˆ์Šต๋‹ˆ๋‹ค. ์˜ˆ๋ฅผ ๋“ค์–ด:

  1. ํฌ์†Œ ๋ถ€๋ถ„(ํฐ ์ž„๋ฒ ๋”ฉ ํ…Œ์ด๋ธ”)๊ณผ ๋ฐ€์ง‘ ๋ถ€๋ถ„(FC ๋ ˆ์ด์–ด)์ด ์žˆ๋Š” ๋ชจ๋ธ์ด ์žˆ๋Š” ๊ฒฝ์šฐ, ๋งค๊ฐœ๋ณ€์ˆ˜ ์„œ๋ฒ„(parameter server)์— ์ž„๋ฒ ๋”ฉ ํ…Œ์ด๋ธ”(embedding table)์„ ๋†“๊ณ  ๋ถ„์‚ฐ ๋ฐ์ดํ„ฐ ๋ณ‘๋ ฌ ์„ ์‚ฌ์šฉํ•˜์—ฌ ์—ฌ๋Ÿฌ ํŠธ๋ ˆ์ด๋„ˆ์— ๊ฑธ์ณ FC ๋ ˆ์ด์–ด๋ฅผ ๋ณต์ œํ•˜๋Š” ๊ฒƒ์„ ์›ํ•  ์ˆ˜๋„ ์žˆ์Šต๋‹ˆ๋‹ค. ์ด๋•Œ ๋ถ„์‚ฐ RPC ํ”„๋ ˆ์ž„์›Œํฌ ๋Š” ๋งค๊ฐœ๋ณ€์ˆ˜ ์„œ๋ฒ„์—์„œ ์ž„๋ฒ ๋”ฉ ์ฐพ๊ธฐ ์ž‘์—…(embedding lookup)์„ ์ˆ˜ํ–‰ํ•˜๋Š” ๋ฐ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.
  2. ๋‹ค์Œ์€ PipeDream ๋ฌธ์„œ์—์„œ ์„ค๋ช…๋œ ํ•˜์ด๋ธŒ๋ฆฌ๋“œ ๋ณ‘๋ ฌ ์ฒ˜๋ฆฌ ํ™œ์„ฑํ™”ํ•˜๊ธฐ ์ž…๋‹ˆ๋‹ค. ๋ถ„์‚ฐ RPC ํ”„๋ ˆ์ž„์›Œํฌ ๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ์—ฌ๋Ÿฌ worker์— ๊ฑธ์ณ ๋ชจ๋ธ์˜ ๋‹จ๊ณ„๋ฅผ ํŒŒ์ดํ”„๋ผ์ธ(pipeline)ํ•  ์ˆ˜ ์žˆ๊ณ  (ํ•„์š”์— ๋”ฐ๋ผ) ๋ถ„์‚ฐ ๋ฐ์ดํ„ฐ ๋ณ‘๋ ฌ ์„ ์ด์šฉํ•ด์„œ ๊ฐ ๋‹จ๊ณ„๋ฅผ ๋ณต์ œํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.

์ด ํŠœํ† ๋ฆฌ์–ผ์—์„œ๋Š” ์œ„์—์„œ ์–ธ๊ธ‰ํ•œ ์ฒซ ๋ฒˆ์งธ ๊ฒฝ์šฐ๋ฅผ ๋‹ค๋ฃฐ ๊ฒƒ์ž…๋‹ˆ๋‹ค. ๋‹ค์Œ๊ณผ ๊ฐ™์ด ์ด 4๊ฐœ์˜ worker๊ฐ€ ์žˆ์Šต๋‹ˆ๋‹ค:

  1. 1๊ฐœ์˜ ๋งˆ์Šคํ„ฐ๋Š” ๋งค๊ฐœ๋ณ€์ˆ˜ ์„œ๋ฒ„์— ์ž„๋ฒ ๋”ฉ ํ…Œ์ด๋ธ”(nn.EmbeddingBag) ์ƒ์„ฑ์„ ๋‹ด๋‹นํ•ฉ๋‹ˆ๋‹ค. ๋˜ํ•œ ๋งˆ์Šคํ„ฐ๋Š” ๋‘ ํŠธ๋ ˆ์ด๋„ˆ์˜ ํ•™์Šต ๋ฃจํ”„๋ฅผ ์ˆ˜ํ–‰ํ•ฉ๋‹ˆ๋‹ค.
  2. 1๊ฐœ์˜ ๋งค๊ฐœ๋ณ€์ˆ˜ ์„œ๋ฒ„๋Š” ๊ธฐ๋ณธ์ ์œผ๋กœ ๋ฉ”๋ชจ๋ฆฌ์— ์ž„๋ฒ ๋”ฉ ํ…Œ์ด๋ธ”์„ ๋ณด์œ ํ•˜๊ณ  ๋งˆ์Šคํ„ฐ ๋ฐ ํŠธ๋ ˆ์ด๋„ˆ์˜ RPC์— ์‘๋‹ตํ•ฉ๋‹ˆ๋‹ค.
  3. 2๊ฐœ์˜ ํŠธ๋ ˆ์ด๋„ˆ๋Š” ๋ถ„์‚ฐ ๋ฐ์ดํ„ฐ ๋ณ‘๋ ฌ ์„ ์‚ฌ์šฉํ•˜์—ฌ ์ž์ฒด์ ์œผ๋กœ ๋ณต์ œ๋˜๋Š” FC ๋ ˆ์ด์–ด(nn.Linear)๋ฅผ ์ €์žฅํ•ฉ๋‹ˆ๋‹ค. ํŠธ๋ ˆ์ด๋„ˆ๋Š” ๋˜ํ•œ ์ˆœ๋ฐฉํ–ฅ ์ „๋‹ฌ(forward pass), ์—ญ๋ฐฉํ–ฅ ์ „๋‹ฌ(backward pass) ๋ฐ ์ตœ์ ํ™” ๋‹จ๊ณ„๋ฅผ ์‹คํ–‰ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.

์ „์ฒด์ ์ธ ํ•™์Šต๊ณผ์ •์€ ๋‹ค์Œ๊ณผ ๊ฐ™์ด ์‹คํ–‰๋ฉ๋‹ˆ๋‹ค:

  1. ๋งˆ์Šคํ„ฐ๋Š” ๋งค๊ฐœ๋ณ€์ˆ˜ ์„œ๋ฒ„์— ์ž„๋ฒ ๋”ฉ ํ…Œ์ด๋ธ”์„ ๋‹ด๊ณ  ์žˆ๋Š” ์›๊ฒฉ ๋ชจ๋“ˆ(RemoteModule) ์„ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.
  2. ๊ทธ๋Ÿฐ ๋‹ค์Œ ๋งˆ์Šคํ„ฐ๋Š” ํŠธ๋ ˆ์ด๋„ˆ์˜ ํ•™์Šต ๋ฃจํ”„๋ฅผ ์‹œ์ž‘ํ•˜๊ณ  ์›๊ฒฉ ๋ชจ๋“ˆ(remote module)์„ ํŠธ๋ ˆ์ด๋„ˆ์—๊ฒŒ ์ „๋‹ฌํ•ฉ๋‹ˆ๋‹ค.
  3. ํŠธ๋ ˆ์ด๋„ˆ๋Š” ๋จผ์ € ๋งˆ์Šคํ„ฐ์—์„œ ์ œ๊ณตํ•˜๋Š” ์›๊ฒฉ ๋ชจ๋“ˆ์„ ์‚ฌ์šฉํ•˜์—ฌ ์ž„๋ฒ ๋”ฉ ์ฐพ๊ธฐ ์ž‘์—…(embedding lookup)์„ ์ˆ˜ํ–‰ํ•œ ๋‹ค์Œ DDP ๋‚ด๋ถ€์— ๊ฐ์‹ธ์ง„ FC ๋ ˆ์ด์–ด๋ฅผ ์‹คํ–‰ํ•˜๋Š” HybridModel ์„ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.
  4. ํŠธ๋ ˆ์ด๋„ˆ๋Š” ๋ชจ๋ธ์˜ ์ˆœ๋ฐฉํ–ฅ ์ „๋‹ฌ์„ ์‹คํ–‰ํ•˜๊ณ  ์†์‹ค์„ ์‚ฌ์šฉํ•˜์—ฌ ๋ถ„์‚ฐ Autograd ๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ์—ญ๋ฐฉํ–ฅ ์ „๋‹ฌ์„ ์‹คํ–‰ํ•ฉ๋‹ˆ๋‹ค.
  5. ์—ญ๋ฐฉํ–ฅ ์ „๋‹ฌ์˜ ์ผ๋ถ€๋กœ FC ๋ ˆ์ด์–ด์˜ ๋ณ€ํ™”๋„๊ฐ€ ๋จผ์ € ๊ณ„์‚ฐ๋˜๊ณ  DDP์˜ allreduce๋ฅผ ํ†ตํ•ด ๋ชจ๋“  ํŠธ๋ ˆ์ด๋„ˆ์™€ ๋™๊ธฐํ™”๋ฉ๋‹ˆ๋‹ค.
  6. ๋‹ค์Œ์œผ๋กœ, ๋ถ„์‚ฐ Autograd๋Š” ๋งค๊ฐœ๋ณ€์ˆ˜ ์„œ๋ฒ„๋กœ ๋ณ€ํ™”๋„๋ฅผ ์ „ํŒŒํ•˜๊ณ  ๊ทธ๊ณณ์—์„œ ์ž„๋ฒ ๋”ฉ ํ…Œ์ด๋ธ”์˜ ๋ณ€ํ™”๋„๊ฐ€ ์—…๋ฐ์ดํŠธ๋ฉ๋‹ˆ๋‹ค.
  7. ๋งˆ์ง€๋ง‰์œผ๋กœ, ๋ถ„์‚ฐ ์˜ตํ‹ฐ๋งˆ์ด์ €(DistributedOptimizer) ๋Š” ๋ชจ๋“  ๋งค๊ฐœ๋ณ€์ˆ˜๋ฅผ ์—…๋ฐ์ดํŠธํ•˜๋Š” ๋ฐ ์‚ฌ์šฉ๋ฉ๋‹ˆ๋‹ค.

Warning

DDP์™€ RPC๋ฅผ ๊ฒฐํ•ฉํ•  ๋•Œ, ์—ญ๋ฐฉํ–ฅ ์ „๋‹ฌ์— ๋Œ€ํ•ด ํ•ญ์ƒ ๋ถ„์‚ฐ Autograd ๋ฅผ ์‚ฌ์šฉํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค.

์ด์ œ ๊ฐ ๋ถ€๋ถ„์„ ์ž์„ธํžˆ ์‚ดํŽด๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค. ๋จผ์ € ํ•™์Šต์„ ์ˆ˜ํ–‰ํ•˜๊ธฐ ์ „์— ๋ชจ๋“  ์ž‘์—…์ž๋ฅผ ์„ค์ •ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค. ์ˆœ์œ„ 0๊ณผ 1์€ ํŠธ๋ ˆ์ด๋„ˆ, ์ˆœ์œ„ 2๋Š” ๋งˆ์Šคํ„ฐ, ์ˆœ์œ„ 3์€ ๋งค๊ฐœ๋ณ€์ˆ˜ ์„œ๋ฒ„์ธ 4๊ฐœ์˜ ํ”„๋กœ์„ธ์Šค๋ฅผ ๋งŒ๋“ญ๋‹ˆ๋‹ค.

TCP init_method๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ 4๊ฐœ์˜ ๋ชจ๋“  worker์—์„œ RPC ํ”„๋ ˆ์ž„์›Œํฌ๋ฅผ ์ดˆ๊ธฐํ™”ํ•ฉ๋‹ˆ๋‹ค. RPC ์ดˆ๊ธฐํ™”๊ฐ€ ๋๋‚˜๋ฉด, ๋งˆ์Šคํ„ฐ๋Š” EmbeddingBag ๋ ˆ์ด์–ด๋ฅผ ์›๊ฒฉ ๋ชจ๋“ˆ(RemoteModule) ์„ ์‚ฌ์šฉํ•˜์—ฌ ๋งค๊ฐœ๋ณ€์ˆ˜ ์„œ๋ฒ„์— ๋‹ด๊ณ  ์žˆ๋Š” ์›๊ฒฉ ๋ชจ๋“ˆ ํ•˜๋‚˜๋ฅผ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค. ๊ทธ๋Ÿฐ ๋‹ค์Œ ๋งˆ์Šคํ„ฐ๋Š” ๊ฐ ํŠธ๋ ˆ์ด๋„ˆ๋ฅผ ๋ฐ˜๋ณตํ•˜๊ณ  rpc_async ๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๊ฐ ํŠธ๋ ˆ์ด๋„ˆ์—์„œ _run_trainer ๋ฅผ ํ˜ธ์ถœํ•˜์—ฌ ๋ฐ˜๋ณต ํ•™์Šต์„ ์‹œ์ž‘ํ•ฉ๋‹ˆ๋‹ค. ๋งˆ์ง€๋ง‰์œผ๋กœ ๋งˆ์Šคํ„ฐ๋Š” ์ข…๋ฃŒํ•˜๊ธฐ ์ „์— ๋ชจ๋“  ํ•™์Šต์ด ์™„๋ฃŒ๋  ๋•Œ๊นŒ์ง€ ๊ธฐ๋‹ค๋ฆฝ๋‹ˆ๋‹ค.

ํŠธ๋ ˆ์ด๋„ˆ๋Š” init_process_group ์„ ์‚ฌ์šฉํ•˜์—ฌ (2๊ฐœ์˜ ํŠธ๋ ˆ์ด๋„ˆ) world_size=2๋กœ DDP๋ฅผ ์œ„ํ•ด ProcessGroup ์„ ์ดˆ๊ธฐํ™”ํ•ฉ๋‹ˆ๋‹ค. ๋‹ค์Œ์œผ๋กœ TCP init_method๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ RPC ํ”„๋ ˆ์ž„์›Œํฌ๋ฅผ ์ดˆ๊ธฐํ™”ํ•ฉ๋‹ˆ๋‹ค. ์—ฌ๊ธฐ์„œ ์ฃผ์˜ ํ•  ์ ์€ RPC ์ดˆ๊ธฐํ™”์™€ ProgressGroup ์ดˆ๊ธฐํ™”์—์„œ ์“ฐ์ด๋Š” ํฌํŠธ(port)๊ฐ€ ๋‹ค๋ฅด๋‹ค๋Š” ๊ฒƒ์ž…๋‹ˆ๋‹ค. ์ด๋Š” ๋‘ ํ”„๋ ˆ์ž„์›Œํฌ์˜ ์ดˆ๊ธฐํ™” ๊ฐ„์— ํฌํŠธ ์ถฉ๋Œ์„ ํ”ผํ•˜๊ธฐ ์œ„ํ•ด์„œ ์ž…๋‹ˆ๋‹ค. ์ดˆ๊ธฐํ™”๊ฐ€ ์™„๋ฃŒ๋˜๋ฉด ํŠธ๋ ˆ์ด๋„ˆ๋Š” ๋งˆ์Šคํ„ฐ์˜ _run_trainer RPC๋ฅผ ๊ธฐ๋‹ค๋ฆฌ๊ธฐ๋งŒ ํ•˜๋ฉด ๋ฉ๋‹ˆ๋‹ค.

ํŒŒ๋ผํ”ผํ„ฐ ์„œ๋ฒ„๋Š” RPC ํ”„๋ ˆ์ž„์›Œํฌ๋ฅผ ์ดˆ๊ธฐํ™”ํ•˜๊ณ  ํŠธ๋ ˆ์ด๋„ˆ์™€ ๋งˆ์Šคํ„ฐ์˜ RPC๋ฅผ ๊ธฐ๋‹ค๋ฆฝ๋‹ˆ๋‹ค.

.. literalinclude:: ../advanced_source/rpc_ddp_tutorial/main.py
  :language: py
  :start-after: BEGIN run_worker
  :end-before: END run_worker

ํŠธ๋ ˆ์ด๋„ˆ์— ๋Œ€ํ•œ ์ž์„ธํ•œ ์„ค๋ช…์— ์•ž์„œ, ํŠธ๋ ˆ์ด๋„ˆ๊ฐ€ ์‚ฌ์šฉํ•˜๋Š” HybridModel ์— ๋Œ€ํ•ด ์„ค๋ช…๋“œ๋ฆฌ๊ฒ ์Šต๋‹ˆ๋‹ค. ์•„๋ž˜์— ์„ค๋ช…๋œ ๋Œ€๋กœ HybridModel ์€ ๋งค๊ฐœ๋ณ€์ˆ˜ ์„œ๋ฒ„์˜ ์ž„๋ฒ ๋”ฉ ํ…Œ์ด๋ธ”(remote_emb_module)๊ณผ DDP์— ์‚ฌ์šฉํ•  device ๋ฅผ ๋ณด์œ ํ•˜๋Š” ์›๊ฒฉ ๋ชจ๋“ˆ์„ ์‚ฌ์šฉํ•˜์—ฌ ์ดˆ๊ธฐํ™”๋ฉ๋‹ˆ๋‹ค. ๋ชจ๋ธ ์ดˆ๊ธฐํ™”๋Š” DDP ๋‚ด๋ถ€์˜ nn.Linear ๋ ˆ์ด์–ด๋ฅผ ๊ฐ์‹ธ ๋ชจ๋“  ํŠธ๋ ˆ์ด๋„ˆ์—์„œ ์ด ๋ ˆ์ด์–ด๋ฅผ ๋ณต์ œํ•˜๊ณ  ๋™๊ธฐํ™”ํ•ฉ๋‹ˆ๋‹ค.

๋ชจ๋ธ์˜ ์ˆœ๋ฐฉํ–ฅ(forward) ํ•จ์ˆ˜๋Š” ๊ฝค ๊ฐ„๋‹จํ•ฉ๋‹ˆ๋‹ค. RemoteModule์˜ forward ๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ๋งค๊ฐœ๋ณ€์ˆ˜ ์„œ๋ฒ„์—์„œ ์ž„๋ฒ ๋”ฉ ์ฐพ๊ธฐ ์ž‘์—…(embedding lookup)์„ ์ˆ˜ํ–‰ํ•˜๊ณ  ๊ทธ ์ถœ๋ ฅ์„ FC ๋ ˆ์ด์–ด์— ์ „๋‹ฌํ•ฉ๋‹ˆ๋‹ค.

.. literalinclude:: ../advanced_source/rpc_ddp_tutorial/main.py
  :language: py
  :start-after: BEGIN hybrid_model
  :end-before: END hybrid_model

๋‹ค์Œ์œผ๋กœ ํŠธ๋ ˆ์ด๋„ˆ์˜ ์„ค์ •์„ ์‚ดํŽด๋ณด๊ฒ ์Šต๋‹ˆ๋‹ค. ํŠธ๋ ˆ์ด๋„ˆ๋Š” ๋จผ์ € ๋งค๊ฐœ๋ณ€์ˆ˜ ์„œ๋ฒ„์˜ ์ž„๋ฒ ๋”ฉ ํ…Œ์ด๋ธ”๊ณผ ์ž์ฒด ์ˆœ์œ„๋ฅผ ๋ณด์œ ํ•˜๋Š” ์›๊ฒฉ ๋ชจ๋“ˆ์„ ์‚ฌ์šฉํ•˜์—ฌ ์œ„์—์„œ ์„ค๋ช…ํ•œ HybridModel ์„ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.

์ด์ œ ๋ถ„์‚ฐ ์˜ตํ‹ฐ๋งˆ์ด์ €(DistributedOptimizer) ๋กœ ์ตœ์ ํ™”ํ•˜๋ ค๋Š” ๋ชจ๋“  ๋งค๊ฐœ๋ณ€์ˆ˜์— ๋Œ€ํ•œ RRef ๋ชฉ๋ก์„ ๊ฒ€์ƒ‰ํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค. ๋งค๊ฐœ๋ณ€์ˆ˜ ์„œ๋ฒ„์—์„œ ์ž„๋ฒ ๋”ฉ ํ…Œ์ด๋ธ”์˜ ๋งค๊ฐœ๋ณ€์ˆ˜๋ฅผ ๊ฒ€์ƒ‰ํ•˜๊ธฐ ์œ„ํ•ด RemoteModule์˜ remote_parameters ๋ฅผ ํ˜ธ์ถœํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. ๊ทธ๋ฆฌ๊ณ  ์ด๊ฒƒ์€ ๊ธฐ๋ณธ์ ์œผ๋กœ ์ž„๋ฒ ๋”ฉ ํ…Œ์ด๋ธ”์˜ ๋ชจ๋“  ๋งค๊ฐœ๋ณ€์ˆ˜๋ฅผ ์‚ดํŽด๋ณด๊ณ  RRef ๋ชฉ๋ก์„ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค. ํŠธ๋ ˆ์ด๋„ˆ๋Š” RPC๋ฅผ ํ†ตํ•ด ๋งค๊ฐœ๋ณ€์ˆ˜ ์„œ๋ฒ„์—์„œ ์ด ๋ฉ”์„œ๋“œ๋ฅผ ํ˜ธ์ถœํ•˜์—ฌ ์›ํ•˜๋Š” ๋งค๊ฐœ๋ณ€์ˆ˜์— ๋Œ€ํ•œ RRef ๋ชฉ๋ก์„ ์ˆ˜์‹ ํ•ฉ๋‹ˆ๋‹ค. DistributedOptimizer๋Š” ํ•ญ์ƒ ์ตœ์ ํ™”ํ•ด์•ผ ํ•˜๋Š” ๋งค๊ฐœ๋ณ€์ˆ˜์— ๋Œ€ํ•œ RRef ๋ชฉ๋ก์„ ๊ฐ€์ ธ์˜ค๊ธฐ ๋•Œ๋ฌธ์— FC ๋ ˆ์ด์–ด์˜ ์ „์—ญ ๋งค๊ฐœ๋ณ€์ˆ˜์— ๋Œ€ํ•ด์„œ๋„ RRef๋ฅผ ์ƒ์„ฑํ•ด์•ผ ํ•ฉ๋‹ˆ๋‹ค. ์ด๊ฒƒ์€ model.fc.parameters() ๋ฅผ ํƒ์ƒ‰ํ•˜๊ณ  ๊ฐ ๋งค๊ฐœ๋ณ€์ˆ˜์— ๋Œ€ํ•œ RRef๋ฅผ ์ƒ์„ฑํ•˜๊ณ  remote_parameters() ์—์„œ ๋ฐ˜ํ™˜๋œ ๋ชฉ๋ก์— ์ถ”๊ฐ€ํ•จ์œผ๋กœ์จ ์ˆ˜ํ–‰๋ฉ๋‹ˆ๋‹ค. ์ฐธ๊ณ ๋กœ model.parameters() ๋Š” ์‚ฌ์šฉํ•  ์ˆ˜ ์—†์Šต๋‹ˆ๋‹ค. RemoteModule ์—์„œ ์ง€์›ํ•˜์ง€ ์•Š๋Š” model.remote_emb_module.parameters() ๋ฅผ ์žฌ๊ท€์ ์œผ๋กœ ํ˜ธ์ถœํ•˜๊ธฐ ๋•Œ๋ฌธ์ž…๋‹ˆ๋‹ค.

๋งˆ์ง€๋ง‰์œผ๋กœ ๋ชจ๋“  RRef๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ DistributedOptimizer๋ฅผ ๋งŒ๋“ค๊ณ  CrossEntropyLoss ํ•จ์ˆ˜๋ฅผ ์ •์˜ํ•ฉ๋‹ˆ๋‹ค.

.. literalinclude:: ../advanced_source/rpc_ddp_tutorial/main.py
  :language: py
  :start-after: BEGIN setup_trainer
  :end-before: END setup_trainer

์ด์ œ ๊ฐ ํŠธ๋ ˆ์ด๋„ˆ์—์„œ ์‹คํ–‰๋˜๋Š” ๊ธฐ๋ณธ ํ•™์Šต ๋ฃจํ”„๋ฅผ ์†Œ๊ฐœํ•˜๊ฒ ์Šต๋‹ˆ๋‹ค. get_next_batch ๋Š” ํ•™์Šต์„ ์œ„ํ•œ ์ž„์˜์˜ ์ž…๋ ฅ๊ณผ ๋Œ€์ƒ์„ ์ƒ์„ฑํ•˜๋Š” ๊ฒƒ์„ ๋„์™€์ฃผ๋Š” ํ•จ์ˆ˜์ผ ๋ฟ์ž…๋‹ˆ๋‹ค. ์—ฌ๋Ÿฌ ์—ํญ(epoch)๊ณผ ๊ฐ ๋ฐฐ์น˜(batch)์— ๋Œ€ํ•ด ํ•™์Šต ๋ฃจํ”„๋ฅผ ์‹คํ–‰ํ•ฉ๋‹ˆ๋‹ค:

  1. ๋จผ์ € ๋ถ„์‚ฐ Autograd์— ๋Œ€ํ•ด ๋ถ„์‚ฐ Autograd Context ๋ฅผ ์„ค์ •ํ•ฉ๋‹ˆ๋‹ค.
  2. ๋ชจ๋ธ์˜ ์ˆœ๋ฐฉํ–ฅ ์ „๋‹ฌ์„ ์‹คํ–‰ํ•˜๊ณ  ํ•ด๋‹น ์ถœ๋ ฅ์„ ๊ฒ€์ƒ‰(retrieve)ํ•ฉ๋‹ˆ๋‹ค.
  3. ์†์‹ค ํ•จ์ˆ˜๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ์ถœ๋ ฅ๊ณผ ๋ชฉํ‘œ๋ฅผ ๊ธฐ๋ฐ˜์œผ๋กœ ์†์‹ค์„ ๊ณ„์‚ฐํ•ฉ๋‹ˆ๋‹ค.
  4. ๋ถ„์‚ฐ Autograd๋ฅผ ์‚ฌ์šฉํ•˜์—ฌ ์†์‹ค์„ ์‚ฌ์šฉํ•˜์—ฌ ๋ถ„์‚ฐ ์—ญ๋ฐฉํ–ฅ ์ „๋‹ฌ์„ ์‹คํ–‰ํ•ฉ๋‹ˆ๋‹ค.
  5. ๋งˆ์ง€๋ง‰์œผ๋กœ ๋ถ„์‚ฐ ์˜ตํ‹ฐ๋งˆ์ด์ € ๋‹จ๊ณ„๋ฅผ ์‹คํ–‰ํ•˜์—ฌ ๋ชจ๋“  ๋งค๊ฐœ๋ณ€์ˆ˜๋ฅผ ์ตœ์ ํ™”ํ•ฉ๋‹ˆ๋‹ค.
.. literalinclude:: ../advanced_source/rpc_ddp_tutorial/main.py
  :language: py
  :start-after: BEGIN run_trainer
  :end-before: END run_trainer

์ „์ฒด ์˜ˆ์ œ์˜ ์†Œ์Šค ์ฝ”๋“œ๋Š” ์—ฌ๊ธฐ ์—์„œ ์ฐพ์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.