TPU7x (Ironwood) 성능 최적화
이 가이드에서는 다중 계층 메모리 시스템 간의 데이터 이동을 효율적으로 관리하여 TPU7x(Ironwood)로 성능을 최적화하는 여러 방법을 설명합니다. 여기에는 저정밀도 학습, 샤딩, 통신 최적화, 활성화 재구체화, 범위 지정 가상 메모리 조정, 맞춤 액셀러레이터 커널과 같은 기법이 포함됩니다.
TPU7x로 성능을 최적화하려면 먼저 Ironwood 아키텍처, 특히 메모리 계층 구조와 상호 연결 토폴로지를 잘 알고 있어야 합니다. 자세한 내용은 TPU7x (Ironwood)를 참고하세요.
FP8을 사용한 저정밀도 학습
FP8 (8비트 부동 소수점)은 주로 모델 학습 및 추론을 가속화하는 데 사용되는 효율적인 수치 데이터 형식입니다. 표준 16비트 형식 (FP16 또는 BF16) 및 32비트 (FP32) 대신 8비트를 사용하여 숫자를 표현하면 TPU가 데이터를 훨씬 빠르게 처리하고 메모리를 적게 사용할 수 있습니다.
TPU7x는 FP8 데이터 유형의 내장 하드웨어 가속을 지원하여 칩당 이론상 최대 성능이 4, 614TFLOPS입니다. 이 기능을 사용하면 엔드 투 엔드 학습 시간을 크게 단축할 수 있습니다. 호환되는 작업, 특히 AI 워크로드에 일반적인 밀도 높은 행렬 곱셈의 경우 FP8을 사용하면 표준 BF16 학습에 비해 성능이 1.3배 향상될 수 있습니다. BF16과 비교할 때 FP8은 가중치 및 활성화의 최대 FLOP가 두 배로 늘어나고 메모리 사용량이 절반으로 줄어듭니다. FP8은 컴퓨팅 바운드 워크로드와 메모리 용량 또는 대역폭에 의해 제한되는 시나리오 모두의 기본 조정 레버여야 합니다.
FP8을 사용하면 다음과 같은 성능 이점이 있습니다.
- 고대역폭 메모리 (HBM) 압력 감소: 메모리 사용 공간이 작아지면 추론 중에 더 큰 모델이나 KV 캐시가 더 큰 모델이 192GB의 HBM 내에 완전히 맞출 수 있습니다. 이렇게 하면 느린 호스트 메모리로의 오프로딩을 방지할 수 있습니다.
- 유효 배치 크기 증가: 활성화에 필요한 메모리를 줄임으로써 FP8을 사용하면 더 큰 배치 크기를 사용할 수 있습니다. 이렇게 하면 데이터 병렬 처리가 개선되고 처리량이 증가하며 컴퓨팅 단위 활용도가 높아질 수 있습니다.
- 메모리 대역폭 요구사항 감소: 각 작업에 대해 데이터 양을 절반으로 이동하면 HBM-MXU 데이터 경로의 요구사항이 줄어듭니다. 데이터 이동이 일반적인 병목 현상인 시스템에서 이는 MXU가 작업으로 포화 상태를 유지하는 데 도움이 됩니다.
성능 저하가 없거나 제한적인 FP8을 사용하려면 양자화 기법을 신중하게 선택해야 합니다. FP8 학습 시 고려해야 할 몇 가지 권장사항은 다음과 같습니다.
- 확장 세부사항: 텐서별 스케일링을 기준선으로 시작합니다. 품질 또는 성능 문제가 있는 경우 축별 스케일링으로 전환합니다. 하위 채널 확장이 필요하지 않을 수 있습니다.
- 스케일링 모드: 런타임에 스케일링 요소를 계산하는 동적 스케일링은 품질을 유지하는 데 적합한 기본값입니다. 정적 스케일링은 계산을 제거하여 성능을 크게 향상할 수 있지만 올바른 스케일링 요소를 결정하려면 신중한 프로파일링이 필요하며, 특히 모델 구성이 변경되는 경우에는 모든 사용 사례에 적합하지 않을 수 있습니다. 반대로 일부 강력한 모델과 구성은 가중치 또는 활성화의 스케일을 FP8 한도로 고정하여 정확도를 유지하고 성능을 개선하면서 양자화 오버헤드를 줄일 수 있습니다.
- FP8 형식 (E4M3 및 E5M2): 일반적이고 효과적인 방법은 FP8 형식을 혼합하여 사용하는 것입니다. 예를 들어 순방향 패스에서 가중치와 활성화에 E4M3을 사용하여 E4M3의 높은 정밀도를 활용하고 역방향 패스에서 그라데이션에 E5M2를 사용하여 그라데이션의 더 넓은 동적 범위를 수용합니다.
- 반올림: 그라데이션에 확률적 반올림 대신 '가장 가까운 짝수로 반올림' (RNE)을 사용하면 품질을 유지하면서 성능과 재현성을 개선할 수 있습니다.
- MaxText에서 FP8 사용 설정: MaxText는 QWIX 양자화 라이브러리를 통해 FP8 학습을 지원합니다. 양자화를 활성화하려면 구성에서
use_qwix_quantization=true플래그를 설정합니다.
샤딩 및 동시 로드
샤딩은 대규모 모델 또는 학습 데이터를 더 작은 조각으로 나누어 여러 TPU 칩 또는 코어에 분산하는 프로세스입니다. TPU7x에서 높은 성능을 달성하려면 적절한 샤딩 전략을 선택하는 것이 중요합니다.
병렬 처리 정도를 순수하게 최대화하는 단순한 접근 방식은 통신에 바인딩되어 성능이 저하되는 경우가 많습니다. 가장 좋은 방법은 메모리 제약 조건을 충족하는 가장 간단한 샤딩 전략을 선택하는 것입니다. 이렇게 하면 통신 오버헤드가 최소화되고 컴퓨팅 단위를 효율적으로 사용할 수 있습니다.
샤딩 전략을 선택하기 전에 성능 조정 노력의 첫 번째 단계는 산술 강도 분석이어야 합니다. 이 분석은 특정 계산이 컴퓨팅, 메모리 대역폭 또는 상호 연결 대역폭에 의해 제한되는지 여부를 확인합니다. 이 값은 이동해야 하는 데이터의 바이트 수에 대한 부동 소수점 연산의 비율로 계산됩니다.
산술 강도가 높으면 컴퓨팅 바운드 워크로드를 나타냅니다. 산술 강도가 낮으면 메모리 또는 통신 바운드 워크로드를 나타내며, 여기서는 HBM에서 또는 ICI 네트워크를 통해 데이터를 이동할 수 있는 속도에 따라 성능이 제한됩니다. 이 분석을 통해 이상적인 배치 크기와 샤딩 전략을 알 수 있습니다. 예를 들어 통신에 바인딩된 워크로드는 높은 수준의 텐서 병렬 처리와 같이 통신을 더 많이 도입하는 샤딩 전략의 이점을 누릴 수 없습니다.
샤딩 전략 결정 프레임워크
MaxText는 다양한 샤딩 전략을 제공합니다. 최적의 선택은 모델 아키텍처, 시퀀스 길이, 통신 오버헤드에 대한 컴퓨팅 부하의 균형 필요성에 따라 달라집니다.
- 완전 샤딩된 데이터 동시 로드 (FSDP): 데이터 동시 로드에 권장되는 기본 전략입니다. FSDP는 데이터 병렬 처리 기기 간에 모델 가중치, 그라데이션, 옵티마이저 상태를 샤딩합니다. 계산 중에 각 기기는 All-Gather 작업을 실행하여 로컬 마이크로 배치에 필요한 전체 가중치를 가져옵니다. FSDP는 기기당 배치 크기가 이 All-Gather 통신의 지연 시간을 숨길 만큼 충분히 큰 경우 매우 효과적입니다. Mixture-of-Experts (MoE) 모델의 경우 산술 강도 계산에서 희소성을 고려해야 합니다.
- 텐서 병렬 처리 (TP): TP는 기기 간에 개별 텐서를 샤딩합니다. 일반적으로 텐서는 다층 퍼셉트론 (MLP) 및 어텐션 블록의 가중치 행렬입니다. 하드웨어의 높은 산술 강도 (11.5k)로 인해 ICI보다 TP를 실행하려면 모델의 차원에 매우 높은 요구사항이 적용되며, TP를 사용하려고 하면 시스템이 통신에 바인딩될 수 있습니다.
- Expert Parallelism (EP): MoE 모델을 학습하는 데 필요한 표준 전략입니다. EP는 여러 기기에 '전문가' 레이어를 샤딩하고, All-to-All 통신 집합을 사용하여 토큰을 지정된 전문가 기기로 라우팅합니다. 모델의 MLP 차원이 루프라인에 근접할 만큼 크면 EP가 효율적일 수 있습니다.
- 컨텍스트 병렬 처리 (CP): CP는 시퀀스 길이가 매우 긴 모델을 학습하는 데 필수적인 특화된 전략입니다. 기본 기능은 시퀀스 길이에 따라 2차 함수로 증가하고 HBM 용량을 초과할 수 있는 활성화의 메모리 소비를 관리하는 것입니다. CP는 활성화 텐서의 시퀀스 차원을 샤딩하므로 기기별 배치 크기의 일부를 사용할 수 있습니다. CP는 FSDP보다 더 많은 통신을 도입하므로 일반적인 규칙은 메모리 제약 조건을 충족하고 배치 축 샤드가 정수로 유지되도록 하는 데 필요한 최소한의 CP를 사용하는 것입니다.
다음 표에서는 일반적인 워크로드 유형과 최적의 샤딩 전략을 매핑합니다.
| 워크로드 유형 | 권장 기본 샤딩 | 보조 샤딩 | 주요 병목 현상 | 근거 |
|---|---|---|---|---|
| 밀도 모델 - 짧은 시퀀스 | FSDP | 해당 사항 없음 | 재구체화, FF Matmuls | FSDP가 가장 균형이 잘 잡혀 있습니다. 짧은 시퀀스의 경우 활성화 메모리가 주요 문제가 아닐 수 있습니다. 키는 FSDP의 가중치 All-Gather를 숨길 수 있을 만큼 큰 전역 배치입니다. 배치 크기가 증가하면 활성화 크기가 증가하므로 이 구성에 메모리가 부족하지 않도록 적절한 재구체화 정책이 필요합니다. |
| 밀도 모델 - 긴 시퀀스 | FSDP | CP | 플래시 어텐션, 활성화 메모리 | 활성화 메모리가 기본 제약 조건이 됩니다. CP는 기기별 배치 크기를 사용 설정하고 메모리 부족 (OOM) 문제를 방지하는 데 필요합니다. 플래시 어텐션이 컴퓨팅 및 낭비되는 시간의 주요 원인입니다. |
| MoE 모델 - 짧은 시퀀스 | FSDP + EP | 해당 사항 없음 | All-to-All (전문가 라우팅), rematerialization | MoE 모델은 전문가를 샤딩하기 위해 EP가 필요합니다. 토큰 라우팅을 위한 All-to-All 통신은 오버랩되어야 하는 주요 병목 현상입니다. 재료화도 상당한 낭비의 원인입니다. |
| MoE 모델 - 매우 큰 규모 | FSDP + EP + PP | 모델 병렬 처리 (MP) | 이전에 언급된 모든 병목 현상과 파이프라인 버블 | 단일 포드의 메모리를 초과하는 모델의 경우 포드 간에 레이어를 샤딩하려면 PP가 필요합니다. 이로 인해 DCN 통신 및 파이프라인 버블 오버헤드가 발생합니다. 이는 신중한 조정이 필요한 매우 복잡한 구성입니다. |
커뮤니케이션 최적화
TPU7x에서 통신과 계산을 중첩하는 기본 메커니즘을 SparseCore 집단 오프로드라고 합니다. Ironwood 아키텍처에는 ICI 패브릭을 통한 데이터 이동을 관리할 수 있는 독립적인 제어 스레드 역할을 하는 전용 SparseCore 단위가 포함되어 있습니다. 이를 통해 TensorCore에서 발생하는 기본 계산과 병렬로 집단 통신 작업 (예: All-Gather 또는 Reduce-Scatter)을 실행할 수 있습니다. TPU7x에서 비동기 집합에 권장되는 방법입니다. 권장 플래그를 사용하여 가장 일반적인 집합의 오프로드를 사용 설정합니다.
활성화 재구체화
활성화 재구체화(그라데이션 체크포인트라고도 함)는 모델의 HBM 공간을 줄이는 기본 기법입니다. 역방향 전달 중에 사용할 순방향 전달의 모든 중간 활성화를 HBM에 저장하는 대신 몇 가지 주요 활성화 (체크포인트)만 저장하고 역방향 전달 중에 필요에 따라 다른 활성화를 다시 계산합니다. 이렇게 하면 계산이 증가 (표준 트랜스포머 블록의 경우 약 25~30% 추가 FLOP)하는 대신 상당한 양의 메모리가 절약됩니다.
리매터리얼라이제이션을 얼마나 적극적으로 적용할지 결정하는 것은 기본 병목 현상에 전적으로 의존하는 중요한 조정 매개변수이며, 이는 시퀀스 길이에 따라 달라지는 경우가 많습니다.
긴 시퀀스 워크로드 (예: 128k): 이러한 경우 활성화 텐서의 크기가 HBM의 주요 소비자입니다. 워크로드는 일반적으로 메모리 바운드입니다. 따라서 적극적인 재구체화 정책을 적용하는 것이 매우 유용합니다. 메모리 절약으로 인해 메모리 부족 오류 없이 학습을 진행할 수 있으며 배치 크기도 더 커질 수 있습니다. 재계산의 컴퓨팅 오버헤드는 가치 있는 절충안입니다.
짧은 시퀀스 워크로드 (예: 8k): 이러한 경우 활성화 메모리는 훨씬 덜 문제가 되며 워크로드가 컴퓨팅에 더 많이 바인딩될 가능성이 높습니다. 재구체화의 계산 오버헤드는 비효율성의 가장 큰 단일 원인이 될 수 있습니다.
MaxText에서 재구체화 정책 조정
MaxText는 remat_policy 플래그를 사용하여 구성된 사전 설정 및 맞춤 정책을 통해 리매터리얼라이제이션을 세부적으로 제어할 수 있습니다.
사전 설정된 정책
MaxText는 다음과 같은 기본 제공 정책을 제공합니다.
full: 가장 적극적인 정책으로, 거의 모든 것을 다시 구체화합니다. 이렇게 하면 HBM 사용량이 최소화되지만 재계산 오버헤드가 최대화됩니다. 메모리 제약이 매우 심한 긴 시퀀스 시나리오에 적합합니다.minimal: 가장 소극적인 정책으로, 대부분의 활성화를 저장합니다. 이렇게 하면 HBM 사용량이 최대화되지만 재계산은 최소화됩니다. 메모리가 문제가 되지 않는 짧은 시퀀스, 컴퓨팅 바운드 워크로드에 가장 적합합니다.- 중간 정책:
save_dot_with_context_except_mlp,save_qkv_proj,save_out_proj과 같은 옵션은 비용이 많이 드는 점곱 연산의 출력을 선택적으로 검사점 지정하는 동시에 저렴한 요소별 연산을 다시 구체화하여 다양한 절충안을 제공합니다.
맞춤 정책
더 세밀하게 제어하려면 remat_policy를 custom로 설정하면 됩니다. 이를 통해 모델의 디코딩 모듈 내에서 개별 레이어의 동작을 지정할 수 있습니다. 각 레이어에는 다음 세 가지 동작 중 하나를 할당할 수 있습니다.
device: 활성화는 TPU 기기의 HBM에 저장됩니다.remat: 활성화가 삭제되고 역방향 전달 중에 다시 구체화됩니다.offload: 활성화가 HBM에서 CPU 호스트의 메모리로 이동하여 PCIe 전송 지연 시간의 비용으로 HBM이 확보됩니다.
범위가 지정된 VMEM 조정
플래시 어텐션과 같은 커널 성능은 커널에서 선택한 타일 크기에 따라 달라지며, 이 크기는 사용 가능한 벡터 메모리 (VMEM)에 의해 제한됩니다. TPU7x 칩에는 64MB의 VMEM이 있으며, 이는 현재 범위 (범위가 지정된 VMEM)와 향후 가중치 프리패치 간에 분할될 수 있습니다. 범위가 지정된 VMEM을 늘리면 커널에서 타일 크기를 늘릴 수 있어 메모리 정체가 줄어들고 커널 성능이 향상될 수 있습니다. xla_tpu_scoped_vmem_limit_kib (LIBTPU_INIT_ARGS)를 설정하여 범위가 지정된 VMEM 크기를 변경할 수 있습니다. 이를 통해 커널 성능과 엔드 투 엔드 성능 한계를 살펴볼 수 있습니다.
범위가 지정된 VMEM 크기를 최적화하면 범위가 지정된 VMEM이 커널 내 타일 크기에 더 큰 초매개변수 검색 공간을 제공하므로 맞춤 Pallas 커널 성능에 간접적으로 영향을 줄 수 있습니다.
Tokamax 커널
Tokamax는 고도로 최적화된 TPU 커널이 많은 고성능 JAX 커널 라이브러리로 다음과 같은 여러 일반적인 하드웨어 관련 병목 현상을 해결합니다.
- 스플래시 어텐션: 스플래시 어텐션은 표준 어텐션의 HBM 병목 현상을 제거하는 기본 어텐션 구현으로 사용되며 TPU에서 가장 효율적인 어텐션 구현을 사용합니다.
- Megablox 그룹화된 행렬 곱셈 (GMM): MoE 워크로드의 경우 Megablox는 불규칙한 활성화 표현식을 통해 계산하여 그룹화된 행렬 곱셈을 효율적으로 처리합니다. 불규칙한 차원을 효율적으로 매핑하여 LHS의 불규칙한 행 그룹과 해당 전문가 행렬 간의 행렬 곱셈을 계산하므로 배치를 고정된 크기로 패딩할 필요가 없습니다.
tune-jax을 사용한 경험적 조정:tune-jax라이브러리에는 최적의 블록 크기를 경험적으로 검색하는 유틸리티가 있습니다. 기본 커널 크기는 최적이 아닌 경우가 많습니다. 튜닝을 통해 하드웨어 친화적인 VMEM 타일 크기를 선택하여 하드웨어 활용도를 극대화할 수 있습니다.- 최대 로짓 추정치:
max_logit_const값을 설정하여 Tokamax Splash 어텐션 커널을 추가로 최적화할 수 있습니다. 설정된 경우 어텐션의 소프트맥스 작업 (softmax(Q * KT)) 중에 최대 로짓의 감소 계산을 대체하여 일부 계산 및 동기화 오버헤드를 줄입니다. MaxText에서는use_max_logits_estimate구성으로 구현되며, 이 구성은None(사용 중지됨) 또는 부동 소수점 값으로 설정할 수 있습니다. 특정 모델의 로짓 범위가 추정치와 호환되는지 확인하여 숫자 오버플로를 방지합니다. 이 값이 설정된 경우 수렴 테스트를 권장합니다.