TPU에서 모델 확장
이 문서에서는 언어 모델을 확장하는 방법, 즉 TPU가 작동하는 방식과 서로 통신하는 방식, 실제 하드웨어에서 LLM이 실행되는 방식, 대규모로 효율적으로 실행되도록 학습 및 추론 중에 모델을 병렬화하는 방법을 간략하게 설명합니다. LLM 학습에 드는 비용, 모델을 제공하는 데 필요한 메모리, 여러 TPU에 모델을 효과적으로 샤딩하는 방법을 평가하는 데 도움이 되는 정보를 제공합니다.
딥 러닝은 복잡한 부분이 많지만, 모델의 성능을 최적화하는 것은 대규모로도 어렵지 않습니다. 기본 원칙은 단일 액셀러레이터부터 수만 개에 이르기까지 어디에나 적용되며, 이를 이해하면 다음과 같은 유용한 작업을 많이 할 수 있습니다.
- 모델의 일부가 이론적 최적에 얼마나 가까운지 추정합니다.
- 다양한 규모에서 다양한 병렬 처리 방식 (여러 기기에 걸쳐 계산을 분할하는 방법)에 대해 정보에 입각한 선택을 합니다.
- 대규모 트랜스포머 모델을 학습시키고 실행하는 데 필요한 비용과 시간을 추정합니다.
- TPU 아키텍처를 활용하는 알고리즘을 설계합니다.
- 알고리즘 성능을 제한하는 요소를 명시적으로 이해하여 모델 아키텍처를 설계합니다.
기본 요건
LLM과 트랜스포머 아키텍처에 대한 기본적인 이해가 있어야 하지만 대규모로 작동하는 방식은 몰라도 됩니다. LLM 학습의 기본사항을 이해하고 JAX에 대한 기본적인 지식이 있으면 좋습니다. 트랜스포머 아키텍처에 관한 유용한 배경 자료는 다음과 같습니다.
- 그림으로 설명된 트랜스포머: 트랜스포머 아키텍처에 관한 블로그 게시물
- Attention is all you need: Transformer 원본 논문
이러한 기본 요건에 익숙해지면 특정 TPU 플랫폼에서 트랜스포머 모델에 가장 적합한 병렬 처리 방식을 쉽게 추정할 수 있습니다. 학습 및 추론에 걸리는 시간도 추정할 수 있습니다.
모델 확장 중요성
오늘날 LLM과 대부분의 소형 모델은 하드웨어 한계에 매우 근접하게 실행되므로 모델을 개발할 때는 대규모 효율성을 고려해야 합니다. 벤치마크에서 20% 승리하더라도 루프라인 효율성이 20% 저하된다면 의미가 없습니다. 유망한 모델 아키텍처는 대규모로 효율적으로 실행할 수 없거나 그렇게 하기 위한 최적화 노력이 부족하여 정기적으로 실패합니다.
모델 확장 목표는 처리량의 비례적이고 선형적인 증가를 달성하면서 학습 또는 추론에 사용되는 칩 수를 늘릴 수 있는 것입니다. 이를 강력한 확장이라고 합니다. 일반적으로 칩을 추가(병렬 처리)하면 계산 시간이 줄어들지만 칩 간 통신이 추가되는 비용도 발생합니다. 통신이 계산보다 오래 걸리면 모델이 통신에 바인딩되어 확장할 수 없습니다. 이러한 병목 현상이 발생할 위치를 예측할 수 있을 만큼 하드웨어를 잘 이해하면 이러한 병목 현상을 방지하도록 모델을 설계하거나 재구성할 수 있습니다.
다음 섹션에서는 TPU 하드웨어를 확장하는 방법과 트랜스포머 아키텍처가 어떻게 발전해 왔는지 간략하게 설명합니다. 이 정보는 새로운 아키텍처를 설계하는 연구자와 현재 세대의 LLM을 빠르게 실행하기 위해 노력하는 엔지니어 모두에게 유용합니다.
1부: 개념
이 부분에서는 루프라인 분석과 모델의 확장 능력을 제한하는 요소 (통신, 컴퓨팅, 메모리)를 설명합니다. 다음으로 TPU가 개별 칩으로, 그리고 제한된 대역폭과 지연 시간이 있는 칩 간 링크로 상호 연결된 시스템으로 작동하는 방식을 설명합니다.
- 루프라인 분석 소개: 이 섹션에서는 컴퓨팅, 통신, 메모리 한도를 기반으로 알고리즘이 얼마나 빠르게 실행되는지 근사하는 방법을 설명합니다.
- TPU 아키텍처의 작업: 이 섹션에서는 TPU의 아키텍처, TPU의 다양한 하드웨어 모듈이 작동하는 방식, 모델 학습 및 서빙에 미치는 영향을 설명합니다.
- 다중 TPU 병렬 처리를 위한 모델 샤딩: 이 섹션에서는 샤딩된 행렬 곱셈을 설명하여 모델 샤딩과 다중 TPU 병렬 처리를 자세히 살펴봅니다.
2부: 트랜스포머 확장
Transformer 아키텍처의 모든 부분을 이해하는 것이 중요합니다. 모든 행렬의 정확한 크기, 정규화가 발생하는 위치, 각 부분에 있는 매개변수와 FLOP 수 등을 알아야 합니다. 이 파트에서는 Transformer 수학을 자세히 살펴보고 학습과 추론 모두에 대해 파라미터와 FLOP를 계산하는 방법을 보여줍니다. 이를 통해 모델이 사용하는 메모리 양, 컴퓨팅 또는 통신에 소요되는 시간, 피드 포워드 블록과 관련하여 어텐션이 중요해지는 시점을 알 수 있습니다.
마지막으로 이 부분에서는 특정 크기의 모델이 주어지고 특정 수의 칩이 제공될 때 강력한 확장 조건에 머물도록 모델을 병렬화하는 방법에 관한 기본적인 질문에 답을 얻을 수 있습니다. 이 질문에 답하기 위해 이 부분에서는 여러 칩에 모델을 분할하는 데 사용되는 네 가지 기본 병렬 처리 기법인 데이터, 텐서, 파이프라인, 전문가를 설명합니다. 또한 재구체화, ZeRO 기반 모델 샤딩, 호스트 오프로드, 그라데이션 누적과 같은 메모리 요구사항을 줄이는 다른 기술도 설명합니다.
- 트랜스포머 수학 연산 소개: 이 섹션에서는 정방향 및 역방향 패스 중에 트랜스포머에서 사용되는 FLOP 수, 매개변수 수를 계산하는 계산, KV 캐시 크기에 관한 질문에 답하기 위해 수학을 살펴봅니다.
- 학습을 위한 트랜스포머 병렬화: 이 섹션에서는 FSDP, Megatron 샤딩, 파이프라인 병렬화를 조정하여 학습 효율성을 극대화하는 프로세스를 자세히 설명합니다. 최고 처리량을 달성하기 위해 고정된 칩 수에 걸쳐 특정 모델 크기와 배치 크기에 대한 최적의 분산을 결정하는 방법을 설명합니다.
- TPU에서 Llama 3 학습: 이 하위 섹션에서는 TPU에서 Llama 3를 학습하는 방법, 소요 시간, 비용을 설명합니다.
- 추론을 위한 트랜스포머 확장: 모델을 학습시킨 후에는 모델을 제공해야 합니다. 추론은 새로운 고려사항인 지연 시간을 추가하고 메모리 환경을 변경합니다. 이 섹션에서는 분리된 서빙의 작동 방식과 KV 캐시에 대해 생각하는 방법을 설명합니다.
- TPU에서 Llama 3 제공: 이 하위 섹션에서는 TPU에서 Llama 3를 제공하는 방법, 비용, 지연 시간 및 처리량 절충에 대해 설명합니다.
3부: 실제 구현
이 부분에서는 JAX를 사용하여 확장 개념을 구현하는 방법과 문제가 발생할 때 코드를 프로파일링하고 디버깅하는 방법을 설명합니다.
- TPU 프로그램 프로파일링: 실제 LLM은 복잡하며 개발, 최적화, 디버그가 어렵습니다. 이 섹션에서는 JAX + XLA 스택과 JAX/TensorBoard 프로파일러를 사용하여 실제 문제를 디버그하고 수정하는 방법을 설명합니다.
- JAX에서 TPU 프로그래밍: 이 섹션에서는 JAX API를 사용하여 계산을 병렬화하는 방법을 설명합니다.