JAX를 사용하여 Cloud TPU에서 프로덕션 AI 빌드
JAX AI 스택은 Google에서 지원하는 구성 가능 라이브러리 모음을 사용하여 JAX 숫자 핵심을 확장해 극단적 규모의 머신러닝에 사용할 수 있는 강력한 엔드 투 엔드 오픈소스 플랫폼으로 발전시킵니다. 따라서 JAX AI 스택은 전체 ML 수명 주기를 처리하는 포괄적이면서 강력한 생태계로 구성됩니다.
산업 규모의 기반: JAX AI 스택은 대규모로 설계되었으며 ML Pathways를 활용하여 수만 개의 칩에서 학습을 조정하고 Orbax를 사용하여 복원력이 우수한 높은 처리량 비동기 체크포인트를 지원하므로 최첨단 모델의 프로덕션 등급 학습이 가능합니다.
프로덕션에 즉시 사용 가능한 완전 툴킷: JAX AI 스택은 유연한 모델 작성을 위한 Flax, 구성 가능 최적화 전략을 위한 Optax, 재현 가능한 대규모 실행에 필수적인 결정적 데이터 파이프라인을 위한 Grain 등 전체 개발 프로세스에 사용할 수 있는 포괄적인 라이브러리 세트를 제공합니다.
최고의 전문 성능: 하드웨어 활용도가 극대화되도록 JAX AI 스택은 최첨단 커스텀 커널을 위한 Tokamax, 학습 및 추론 속도를 향상시키는 비침입적 양자화를 위한 Qwix, 심층적인 하드웨어 통합 성능 프로파일링을 위한 XProf 등 전문 라이브러리를 제공합니다.
프로덕션 전체 경로: JAX AI 스택은 연구부터 배포에 이르기까지 원활한 전환을 제공합니다. 여기에는 파운데이션 모델 학습에 사용할 수 있는 확장 가능한 참조인 MaxText, 최첨단 강화 학습(RL) 및 정렬에 사용되는 Tunix, vLLM TPU 통합 및 JAX 서빙 런타임이 포함된 통합 추론 솔루션이 포함됩니다.
JAX AI 스택 철학은 느슨하게 결합된 한 가지 구성요소이며 각 구성요소는 한 가지 작업을 잘 수행합니다. JAX 자체는 범위를 좁혀 모놀리식 ML 프레임워크가 아닌 효율적인 배열 작업과 프로그램 변환에 중점을 둡니다. 이 생태계는 이 핵심 프레임워크를 기반으로 빌드되어 ML 모델 학습과 기타 유형의 워크로드(예: 과학 컴퓨팅)와 관련된 다양한 기능을 제공합니다.
느슨하게 결합된 구성요소로 구성된 이 시스템을 사용하면 요구사항에 가장 적합한 방식으로 라이브러리를 선택하고 결합할 수 있습니다. 소프트웨어 엔지니어링 관점에서 이 아키텍처를 사용하면 핵심 프레임워크 불안정화 및 출시 주기에 구애받지 않고 기존에 핵심 프레임워크 구성요소(예: 데이터 파이프라인 및 체크포인트)로 간주되던 기능을 반복적으로 업데이트할 수 있습니다. 대부분의 기능이 모놀리식 프레임워크 변경사항이 아닌 라이브러리에서 구현되므로 핵심 숫자 라이브러리는 기술 환경의 향후 변화에 더 안정적이고 적응할 수 있습니다.
다음 섹션에서는 JAX AI 스택, 주요 기능, 기능 설계 배경, 최신 ML 워크로드용 지속 가능한 플랫폼을 빌드하기 위해 이러한 요소가 결합되는 방식에 대한 기술을 간략하게 설명합니다.
JAX AI 스택 및 기타 생태계 구성요소
| 구성요소 | 기능/설명 |
|---|---|
| JAX AI 스택 코어 및 구성요소1 | |
| JAX | 가속기 지향 배열 계산 및 프로그램 변환(JIT, grad, vmap, pmap) |
| Flax | 직관적인 모델 생성 및 수정에 사용되는 유연한 신경망 작성 라이브러리입니다. |
| Optax | 구성 가능 경사 처리 및 최적화 변환 라이브러리입니다. |
| Orbax | 영웅 규모 학습 복원력을 위한 '모든 규모' 분산 체크포인트 라이브러리입니다. |
| Grain | 확장 가능하고 결정적이며 체크포인트가 가능한 입력 데이터 파이프라인 라이브러리입니다. |
| JAX AI 스택 - 인프라 | |
| XLA | TPU, CPU, GPU용 오픈소스 머신러닝 컴파일러입니다. |
| Pathways | 수만 개의 칩에서 계산을 조정하는 분산 런타임입니다. |
| JAX AI 스택 - 고급 개발 | |
| Pallas | Python으로 구현된 저급 고성능 커스텀 커널을 작성하기 위한 JAX 확장 프로그램입니다. |
| Tokamax | 최첨단 고성능 커스텀 커널(예: 어텐션)의 선별된 라이브러리입니다. |
| Qwix | 양자화(PTQ, QAT, QLoRA)를 위한 포괄적인 비침입적 라이브러리입니다. |
| JAX AI 스택 - 애플리케이션 | |
| MaxText/MaxDiffusion | 파운데이션 모델(예: LLM 및 확산) 학습에 사용되는 확장 가능한 대표 참조 프레임워크입니다. |
| Tunix | 최첨단 사후 학습 및 정렬(RLHF, DPO)을 위한 프레임워크입니다. |
| vLLM | vLLM 프레임워크의 기본 제공 통합을 사용하는 고성능 LLM 추론 솔루션입니다. |
| XProf | 시스템 전체 성능 분석을 위한 심층적인 하드웨어 통합 프로파일러입니다. |
1jax-ai-stack Python 패키지에 포함되어 있습니다.
그림 1: JAX AI 스택 및 생태계 구성요소

아키텍처 필수 요소: 프레임워크를 뛰어넘는 성능
모델 아키텍처가 예를 들어 멀티모달 전문가 조합(MoE) 트랜스포머에서 수렴됨에 따라 최고 성능을 추구하는 과정에서 메가커널이 등장하고 있습니다. 메가커널은 효과적인 특정 모델 하나의 전체 순방향 패스(또는 상당 부분)이며 NVIDIA GPU에서 CUDA SDK와 같은 저급 API를 통해 수동으로 코딩됩니다. 이 방식은 컴퓨팅, 메모리, 통신을 적극적으로 중첩하여 하드웨어 활용도를 극대화합니다. 연구 커뮤니티의 최근 연구에 따르면 이 방식을 사용하면 GPU에서 추론할 때 처리량이 크게 향상될 수 있습니다(경우에 따라 22% 이상). 이러한 추세는 추론에만 국한되지 않습니다. 상당한 효율성 향상을 달성하기 위해 일부에서 대규모로 학습할 때 낮은 수준의 하드웨어 제어를 활용하는 것으로 보입니다.
이러한 추세가 가속화되면 현재 존재하는 모든 고급 프레임워크에서 관련성이 저하될 위험이 있습니다. 성숙하고 안정적인 아키텍처 성능에 궁극적으로 중요한 것은 하드웨어에 대한 저급 액세스이기 때문입니다. 이는 모든 최신 ML 스택에 있어 공통된 과제입니다. 즉, 고급 프레임워크의 생산성과 유연성을 희생시키지 않고 전문가 수준의 하드웨어 제어를 어떻게 제공하느냐입니다.
TPU에서 이러한 수준의 성능을 제공하려면 이러한 고도로 전문화된 커널을 개발할 수 있도록 생태계에 하드웨어에 더 가까운 API 레이어가 노출되어야 합니다. JAX 스택은 XLA 컴파일러의 자동화된 고급 최적화부터 Pallas 커널 작성 라이브러리의 세분화된 수동 제어에 이르기까지 추상화 연속체를 제공하여 이 문제를 해결하도록 설계되었습니다(그림 2 참조).
그림 2: JAX 추상화 연속체

핵심 JAX AI 스택
핵심 JAX AI 스택은 모델 개발의 기반을 제공하는 5가지 주요 라이브러리로 구성됩니다.
JAX: 구성 가능 고성능 프로그램 변환을 위한 기반
JAX는 고성능 수치 계산과 대규모 머신러닝을 위해 설계된 가속기 지향 배열 계산과 프로그램 변환에 사용할 수 있는 Python 라이브러리입니다. 함수형 프로그래밍 모델과 NumPy와 유사한 API가 포함된 JAX는 고급 라이브러리를 위한 견고한 기반을 제공합니다.
컴파일러 우선 설계로 인해 JAX는 적극적인 전체 프로그램 분석, 최적화, 하드웨어 타겟팅을 위한 XLA(XLA 섹션 참조)를 활용하여 확장성을 본질적으로 강화합니다. 함수형 프로그래밍(예: 순수 함수)에 대한 JAX의 중점은 핵심 프로그램 변환을 더 쉽게 처리하고 무엇보다도 구성 가능하게 만드는 것입니다.
이러한 핵심 변환을 혼합하고 일치시켜 모델 크기, 클러스터 크기, 하드웨어 유형에 관계없이 워크로드의 고성능과 확장성을 달성할 수 있습니다.
- jit: Python 함수를 최적화된 융합 XLA 실행 파일로 적시에 컴파일합니다.
- grad: 자동 미분으로, 정방향 및 역방향 모드와 고차 미분을 지원합니다.
- vmap: 자동 벡터화로, 함수 로직을 수정하지 않고도 원활한 일괄 처리 및 데이터 병렬 구조를 지원합니다.
- pmap / shard_map: 여러 기기(예: TPU 코어)에서 병렬화를 자동화하여 분산 학습의 기반을 형성합니다.
XLA의 GSPMD(범용 SPMD) 모델과 원활하게 통합되므로 JAX는 최소한의 코드 변경으로 대규모 TPU 포드에서 계산을 자동으로 병렬화할 수 있습니다. 대부분의 경우 확장에는 상위 수준 샤딩 주석만 필요합니다.
Flax: 유연한 신경망 작성
Flax는 모델 빌드에 직관적인 객체 지향적 방식을 제공하여 JAX에서 신경망 생성, 디버깅, 분석을 간소화합니다. JAX의 함수형 API는 강력하며 성능 저하 없이 PyTorch와 같은 프레임워크에 익숙한 개발자에게 더욱 친숙한 레이어 기반 추상화를 제공합니다.
이러한 설계로 학습된 모델 구성요소를 더욱 쉽게 수정하거나 결합할 수 있습니다.
LoRA 및 양자화와 같은 기법을 사용하려면 조작 가능한 모델 정의가 필요하며 Flax의 NNX API는 Pythonic 인터페이스를 통해 이러한 정의를 제공합니다. NNX는 모델 상태를 캡슐화하여 사용자 인지 부하를 줄이고 모델 계층 구조를 프로그래매틱 방식으로 순회하고 수정할 수 있습니다.
주요 강점:
- 직관적인 객체 지향 API: 모델 구성을 간소화하고 하위 모듈 교체 및 부분 초기화와 같은 고급 사용 사례를 지원합니다.
- Core JAX와 일관성 유지: Flax는 JAX의 함수형 패러다임과 완전히 호환되는 리프트된 변환을 제공하여 더욱 개발자 친화적인 JAX의 전체 성능을 제공합니다.
Optax: 구성 가능 경사 처리 및 최적화 전략
Optax는 JAX용 경사 처리 및 최적화 라이브러리입니다. 모델 빌더에게 다른 애플리케이션 중에서 딥 러닝 모델을 학습시키기 위해 커스텀 방식으로 재조합될 수 있는 빌딩 블록을 제공하도록 설계되었습니다. 핵심 JAX 라이브러리의 기능을 기반으로 빌드되어 ML 모델을 학습시키는 데 사용할 수 있는 손실 함수와 옵티마이저 함수, 관련 기술의 잘 테스트된 고성능 라이브러리를 제공합니다.
동기
손실 계산 및 최소화는 ML 모델 학습을 지원하는 핵심 요소입니다. 자동 미분을 지원하는 핵심 JAX 라이브러리는 모델을 학습시키는 숫자 기능을 제공하지만 인기 있는 옵티마이저(예: RMSProp 또는 Adam) 또는 손실(예: CrossEntropy 또는 MSE)의 표준 구현을 제공하지 않습니다. 이러한 함수를 구현할 수 있지만(일부 고급 개발자는 이렇게 함) 옵티마이저 구현에 버그가 있으면 모델 품질 문제를 진단하기가 어려워집니다. 사용자가 이러한 중요한 부분을 구현하는 대신 Optax는 정확성과 성능이 테스트된 알고리즘을 구현합니다.
최적화 이론 분야는 연구 영역에 속하지만 학습에서 중심적인 역할을 하므로 프로덕션 ML 모델 학습에 필수적인 부분이기도 합니다. 이 역할을 수행하는 라이브러리는 빠른 연구 반복을 수용할 수 있을 만큼 충분히 유연해야 하며 프로덕션 모델 학습에 사용할 수 있을 만큼 충분히 강력하고 성능이 우수해야 합니다. 또한 표준 방정식과 일치하는 최신 알고리즘 구현을 테스트해야 합니다. Optax 라이브러리는 모듈식 구성 가능 아키텍처와 올바르고 읽기 쉬운 코드에 중점을 두어 이를 달성하도록 설계되었습니다.
디자인
Optax는 읽기 쉽고 테스트를 거친 효율적인 알고리즘 구현을 제공함으로써 연구 속도와 연구에서 프로덕션으로의 전환이 모두 향상되도록 설계되었습니다. Optax는 딥 러닝 컨텍스트 그 이상으로 사용되지만 이 컨텍스트에서는 JAX 철학에 따라 순수 함수형 방식으로 구현된 잘 알려진 손실 함수, 최적화 알고리즘, 경사 변환의 모음으로 보일 수 있습니다. 잘 알려진 손실 및 옵티마이저 모음을 통해 사용자는 쉽고 자신감 있게 시작할 수 있습니다.
Optax의 모듈식 방식을 사용하면 옵티마이저 여러 개를 연결한 후 다른 일반적인 변환(예: 경사 제한)을 수행하고 MultiStep 또는 Lookahead와 같은 일반적인 기법을 사용하여 옵티마이저를 래핑해 몇 줄의 코드로 강력한 최적화 전략을 얻을 수 있습니다. 유연한 인터페이스를 통해 새로운 최적화 알고리즘을 연구하고 샴푸나 뮤온과 같은 강력한 2차 최적화 기법을 사용할 수 있습니다.
# Optax implementation of a RMSProp optimizer with a custom learning rate
# schedule, gradient clipping and gradient accumulation.
optimizer = optax.chain(
optax.clip_by_global_norm(GRADIENT_CLIP_VALUE),
optax.rmsprop(learning_rate=optax.cosine_decay_schedule(init_value=lr,decay_steps=decay)),
optax.apply_every(k=ACCUMULATION_STEPS)
)
# The same thing, in PyTorch
optimizer = optim.RMSprop(model_params, lr=LEARNING_RATE)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=TOTAL_STEPS)
for i, (inputs, targets) in enumerate(data_loader):
# ... Training loop body ...
if (i + 1) % ACCUMULATION_STEPS == 0:
torch.nn.utils.clip_grad_norm_(model.parameters(), GRADIENT_CLIP_VALUE)
optimizer.step()
scheduler.step()
optimizer.zero_grad()
이전 코드 스니펫에서는 커스텀 학습률, 경사 제한, 경사 누적으로 옵티마이저를 설정하는 방법을 보여줍니다.
주요 강점
- 강력한 라이브러리: 정확성과 가독성에 중점을 둔 손실, 옵티마이저, 알고리즘의 포괄적인 라이브러리를 제공합니다.
- 모듈식 연결 가능한 변환: 이 유연한 API를 사용하면 학습 루프를 수정하지 않고도 강력하고 복잡한 최적화 전략을 선언적으로 만들 수 있습니다.
- 기능적 및 확장 가능: 순수 함수형 구현은 JAX의 병렬화 메커니즘(예: pmap)과 원활하게 통합되므로 같은 코드를 사용하여 단일 호스트에서 대규모 클러스터로 확장할 수 있습니다.
Orbax/TensorStore - 대규모 분산 체크포인트
Orbax는 단일 기기부터 대규모 분산 학습에 이르기까지 모든 규모에 맞게 설계된 JAX용 체크포인트 라이브러리입니다. 목표는 단편화된 체크포인트 구현을 통합하고 비동기 및 멀티 계층 체크포인트와 같은 중요한 성능 기능을 다양한 잠재고객에게 제공하는 것입니다. Orbax는 대규모 학습 작업에 필요한 복원력을 지원하고 체크포인트를 게시하는 데 사용할 수 있는 유연한 형식을 제공합니다.
전체 시스템 상태에 대한 스냅샷을 만드는 일반화된 체크포인트 및 복원 시스템과 달리 Orbax를 사용한 ML 체크포인트는 학습 모델 가중치, 옵티마이저 상태, 데이터 로더 상태를 재개하는 데 필수적인 정보만 선택적으로 유지합니다. 이 타겟팅된 방식은 액셀러레이터 다운타임을 최소화합니다. Orbax는 I/O 작업과 계산을 중첩하여 다운타임을 최소화하며 이는 대규모 워크로드에 중요한 기능입니다. 시간 액셀러레이터는 기기에서 데이터 전송을 호스팅하는 기간으로 줄어든 유휴 상태이며 이는 다음 학습 단계와 겹칠 수 있으므로 체크포인트가 성능에 영향을 미치지 않습니다.
핵심은 Orbax는 배열 데이터를 효율적으로 동시에 읽고 쓰는 데 TensorStore를 사용한다는 점입니다. Orbax API는 이러한 복잡성을 추상화하여 사용자 친화적인 인터페이스를 제공해 JAX의 표준 모델 표현인 PyTree를 처리합니다.
주요 강점:
- 광범위한 채택: 월별 다운로드 수가 수백만 건에 달하는 Orbax는 ML 아티팩트를 공유하는 일반적인 매개체 역할을 합니다.
- 복잡성 간소화: Orbax는 비동기 저장, 원자성, 파일 시스템 세부정보를 포함한 분산 체크포인트 복잡성을 추상화합니다.
- 유연성: Orbax를 사용하면 일반적인 사용 사례에 사용할 수 있는 API를 제공하는 동시에 전문적인 요구사항을 처리하도록 워크플로를 맞춤설정할 수 있습니다.
- 성능 및 확장 가능: 비동기 체크포인트, 효율적인 스토리지 형식(OCDBT), 지능형 데이터 로드 전략과 같은 기능을 통해 Orbax는 수만 개의 노드가 포함된 학습 실행으로 확장할 수 있습니다.
Grain: 결정적이고 확장 가능한 입력 데이터 파이프라인
Grain은 JAX 모델을 학습하고 평가할 수 있도록 사용 데이터를 읽고 처리하는 Python 라이브러리입니다. 유연하고 빠르며 결정적이며 대규모 워크로드를 성공적으로 학습하는 데 필수적인 체크포인트와 같은 고급 기능을 지원합니다. 이 도구는 많이 사용되는 데이터 형식과 스토리지 백엔드를 지원하며 기본적으로 지원되지 않는 사용자별 형식과 백엔드로 지원을 확장할 수 있는 유연한 API도 제공합니다. Grain은 기본적으로 JAX와 함께 작동하도록 설계되었지만 독립적인 프레임워크로, JAX를 실행할 필요가 없고 다른 프레임워크와 함께 사용할 수도 있습니다.
동기
데이터 파이프라인은 학습 인프라에서 중요한 부분을 형성합니다. 일반적인 변환을 효율적으로 표현할 수 있도록 유연해야 하며 항상 액셀러레이터를 계속 사용할 수 있을 만큼 충분한 성능이 확보되어야 합니다. 또한 여러 스토리지 형식과 백엔드를 수용할 수 있어야 합니다. 단계 시간이 길기 때문에 대규모 모델을 대규모로 학습하려면 결정론과 재현성에 중점을 둔 일반 학습 워크로드에 필요한 것 이상으로 데이터 파이프라인에 대한 추가 요구사항이 필요합니다2. Grain 라이브러리는 이러한 니즈를 충족하는 유연한 아키텍처로 설계되었습니다.
2PaLM 논문의 섹션 5.1에서 저자는 경사 제한을 사용 설정했음에도 불구하고 매우 큰 손실 급증이 관찰되었다고 언급합니다. 문제가 되는 데이터 배치를 삭제하고 손실이 급증하기 전에 체크포인트에서 학습을 다시 시작함으로써 이 문제를 해결했습니다. 이 해결책은 완전히 결정적이고 재현 가능한 학습 설정에서만 가능합니다.
디자인
최고 수준에서 입력 파이프라인을 구성하는 방법에는 두 가지가 있습니다. 하나는 데이터 작업자의 별도 클러스터로 구성하는 것이고 다른 하나는 액셀러레이터를 구동하는 호스트에 데이터 작업자를 공동 배치하는 것입니다. Grain은 다양한 이유로 후자를 선택합니다.
액셀러레이터는 일반적으로 학습 단계에서 유휴 상태인 강력한 호스트와 결합되므로 입력 데이터 파이프라인을 실행하는 데 적합합니다. 이 구현에는 추가적인 이점이 있습니다. 입력과 컴퓨팅 전반에서 샤딩의 일관된 뷰를 제공하므로 데이터 샤딩 뷰가 간소화됩니다. 데이터 작업자를 액셀러레이터 호스트에 배치하면 호스트 CPU가 포화될 수 있는 위험이 있다고 주장할 수 있지만 RPC를 사용하여 컴퓨팅 집약적 변환을 다른 클러스터로 오프로드할 수 있습니다3.
API 측면에서 여러 프로세스를 지원하는 순수 Python 구현과 유연한 API와 함께 Grain을 사용하면 잘 이해된 변환 패러다임을 기반으로 여러 파이프라인 단계를 함께 구성하여 임의로 복잡한 데이터를 변환할 수 있습니다.
Grain은 기본적으로 ArrayRecord 및 Bagz와 같은 효율적인 임의 액세스 데이터 형식과 Parquet 및 TFDS와 같은 인기 있는 다른 데이터 형식을 지원합니다. Grain에는 로컬 파일 시스템에서 읽기와 Cloud Storage에서 읽기가 기본적으로 포함되어 있습니다. 널리 사용되는 스토리지 형식과 백엔드 지원과 더불어 스토리지 레이어에 대한 명확한 추상화를 통해 기존 데이터 소스를 추가하거나 Grain 라이브러리와 호환되도록 래핑할 수 있습니다.
3이는 멀티모달 데이터 파이프라인이 작동하는 데 필요한 방식입니다. 예를 들어 이미지 및 오디오 토크나이저는 자체 액셀러레이터의 자체 클러스터에서 실행되는 모델 자체이며 입력 파이프라인은 RPC 호출을 통해 데이터 예시를 토큰 스트림으로 변환합니다.
주요 강점
- 결정적 데이터 피드: 데이터 작업자를 액셀러레이터와 공동 배치하고 안정적인 전역 셔플 및 체크포인트 가능 반복자와 결합하면 Orbax를 사용하여 일관된 스냅샷에서 모델 상태와 데이터 파이프라인 상태를 함께 체크포인트할 수 있으므로 학습 프로세스의 결정론이 향상됩니다.
- 강력한 데이터 변환을 지원하는 유연한 API: 유연한 순수 Python 변환 API를 사용하면 입력 처리 파이프라인 내에서 다양한 데이터를 변환할 수 있습니다.
- 여러 형식과 백엔드에 대한 확장 가능한 지원: 확장 가능한 데이터 소스 API는 인기 있는 스토리지 형식과 백엔드를 지원하며 새 형식과 백엔드에 대한 지원을 추가할 수 있습니다.
- 강력한 디버깅 인터페이스: 데이터 파이프라인 시각화 도구와 디버그 모드를 사용하면 데이터 파이프라인의 성능을 검사하고 디버그하고 최적화할 수 있습니다.
확장된 JAX AI 스택
풍부한 전문 라이브러리 생태계는 핵심 스택 외에도 엔드 투 엔드 ML 개발에 필요한 인프라, 고급 도구, 애플리케이션 레이어 솔루션을 제공합니다.
기본 인프라: 컴파일러 및 런타임
XLA: 하드웨어 독립적인 컴파일러 중심 엔진
동기
XLA(Accelerated Linear Algebra)는 Google의 도메인별 컴파일러로, JAX에 통합되어 있으며 TPU, CPU, GPU 하드웨어 기기를 지원합니다. XLA는 TPU, GPU, CPU를 타겟팅하는 하드웨어 독립 코드 생성기로 설계되었습니다.
XLA 컴파일러의 컴파일러 우선 설계는 빠르게 진화하는 연구 환경에서 지속적인 이점을 만드는 기본적인 아키텍처 선택입니다. 반면 다른 생태계의 일반적인 커널 중심 방식은 성능을 위해 수동으로 최적화된 라이브러리를 사용합니다. 이는 안정적이고 잘 정립된 모델 아키텍처에는 매우 효과적이지만 혁신에는 적합하지 않습니다. 새로운 연구에서 새로운 아키텍처를 도입하면 생태계는 새로운 커널이 작성되고 최적화될 때까지 기다려야 합니다. 하지만 Google의 컴파일러 중심 설계는 새로운 패턴으로 일반화되는 경우가 많아 최첨단 연구를 위한 고성능 경로를 처음부터 제공할 수 있습니다.
디자인
XLA는 추적 프로세스 중에 JAX에서 생성하는 계산 그래프를 적시(JIT)에 컴파일함으로써 작동합니다(예: 함수가 @jax.jit로 데코레이션된 경우).
이 컴파일은 다단계 파이프라인을 따릅니다.
- JAX 계산 그래프
- 고수준 옵티마이저(HLO)
- 저수준 옵티마이저(LLO)
- 하드웨어 코드
- JAX 그래프에서 HLO로: JAX 계산 그래프가 XLA의 HLO 표현으로 변환됩니다. 이 고수준에서는 연산자 병합 및 효율적인 메모리 관리와 같은 강력한 하드웨어 독립적 최적화가 적용됩니다. StableHLO 언어는 이 단계의 지속적인 버전 인터페이스 역할을 합니다.
- HLO에서 LLO로: 고수준 최적화 후 하드웨어별 백엔드가 인계되어 HLO 표현을 머신 중심 LLO로 낮춥니다.
- LLO에서 하드웨어 코드로: LLO는 최종적으로 매우 효율적인 기계어 코드로 컴파일됩니다. TPU의 경우 이 코드는 하드웨어로 직접 전송되는 긴 명령어 워드(VLIW) 패킷으로 함께 제공됩니다.
확장성을 위해 XLA의 설계는 병렬 구조를 중심으로 빌드됩니다. 칩의 행렬 곱셈 단위(MXU)를 최대한 사용하기 위해 알고리즘을 사용합니다. 칩 간에 XLA는 모든 기기에서 단일 프로그램을 사용하는 컴파일러 기반 병렬화 기법인 SPMD(Single Program Multiple Data)를 사용합니다. 이 강력한 모델은 JAX API를 통해 노출되므로 고수준 샤딩 주석을 사용하여 데이터, 모델 또는 파이프라인 병렬 구조를 관리할 수 있습니다.
더 복잡한 병렬 구조 패턴의 경우 MPMD(Multiple Program Multiple Data)도 사용할 수 있으며 PartIR:MPMD와 같은 라이브러리를 사용하면 JAX 사용자가 MPMD 주석을 제공할 수도 있습니다.
주요 강점
- 컴파일: 계산 그래프의 적시 컴파일을 통해 메모리 레이아웃, 버퍼 할당, 메모리 관리를 최적화할 수 있습니다. 커널 기반 방법론과 같은 대안은 개발자에게 부담이 됩니다. 대부분의 경우 XLA는 개발자 속도를 늦추지 않고도 우수한 성능을 달성할 수 있습니다.
- 병렬 구조: XLA는 SPMD를 사용하여 여러 형태의 병렬 구조를 구현하며 이는 JAX 수준에서 노출됩니다. 이를 통해 샤딩 전략을 표현하여 칩 수천 개에 걸쳐 모델을 실험하고 확장할 수 있습니다.
Pathways: 대규모 분산 컴퓨팅을 위한 통합 런타임
Pathways는 기본 제공 내결함성과 복구 기능을 통해 분산 학습과 추론을 위한 추상화를 제공하므로 ML 연구자는 강력한 단일 머신을 사용하는 것처럼 코딩할 수 있습니다.
동기
대규모 모델을 학습시키고 배포할 수 있으려면 칩이 수백 개에서 수천 개까지 필요합니다. 이러한 칩은 여러 랙과 호스트 머신에 분산되어 있습니다. 학습 작업은 이러한 모든 칩이 필요하고 각 호스트가 병렬화(샤딩)된 XLA 계산에서 함께 작동해야 하는 대규모 동기 프로그램입니다. 칩이 수만 개 이상 필요할 수 있는 대규모 언어 모델의 경우 이 서비스는 포드 내에서 ICI(interchip interconnect) 및 OCI(on-chip interconnect) 패브릭 사용 외에도 데이터 센터 패브릭에서 여러 포드에 걸쳐 있을 수 있어야 합니다.
디자인
ML Pathways는 호스트와 TPU 칩 간에 분산된 계산을 조정하는 데 사용되는 시스템입니다. 액셀러레이터 수십만 개에서 확장성과 효율성이 제공되도록 설계되었습니다. 대규모 학습의 경우 여러 포드 작업, Megascale XLA 통합, 컴파일 서비스, 원격 Python에 사용되는 단일 Python 클라이언트를 제공합니다. 또한 슬라이스 간 병렬 구조 및 선점 허용을 지원하므로 리소스 선점에서 자동 복구가 가능합니다.
Pathways에는 최적화된 교차 호스트 집합이 통합되어 있으므로 XLA 계산 그래프가 단일 TPU Pod 이상으로 확장될 수 있습니다. XLA 통신 기본 요소로 데이터 센터 네트워크(DCN) 통신을 관리하는 분산 런타임을 통합하여 데이터, 모델, 파이프라인 병렬 처리에 대한 XLA 지원이 DCN을 통해 TPU 슬라이스 경계 전반에서 작동하도록 확장됩니다.
주요 강점
JAX와 통합된 단일 컨트롤러 아키텍처가 핵심 추상화입니다. 이를 통해 연구자들은 학습 및 배포를 위한 다양한 샤딩 및 병렬 구조 전략을 탐색하면서 간편하게 칩을 수만 개까지 확장할 수 있습니다.
고급 개발: 성능, 데이터, 효율성
Pallas: JAX에서 고성능 커스텀 커널 작성
JAX에서는 컴파일러가 우선시되지만 최대 성능을 얻기 위해서는 하드웨어를 세부적으로 제어해야 하는 상황이 있을 수 있습니다. Pallas는 GPU 및 TPU용 커스텀 커널을 작성할 수 있는 JAX 확장 프로그램입니다. JAX 추적 및 jax.numpy API의 고수준 인체공학적 기능이 결합되어 있어 생성된 코드를 정밀하게 제어할 수 있습니다.
Pallas는 사용자 정의 커널 함수가 병렬 작업 그룹의 다차원 그리드에서 실행되는 그리드 기반 병렬 구조 모델을 노출합니다. 그리드 위치를 특정 데이터 블록과 연결하는 색인 지도를 사용하여 텐서가 타일링되고 느리고 큰 메모리(예: HBM)와 빠르고 작은 온칩 메모리(예: TPU의 VMEM, GPU의 공유 메모리) 간에 전송되는 방식을 정의해 메모리 계층 구조를 명시적으로 관리할 수 있습니다. Pallas는 커널을 대상 아키텍처에 적합한 중간 표현(TPU의 경우 모자이크)으로 컴파일하거나 GPU의 경우 Triton과 같은 기술을 활용함으로써 동일한 커널 정의를 낮춰 Google의 TPU와 다양한 GPU에서 효율적으로 실행될 수 있습니다. Pallas를 사용하면 공급업체별 툴킷을 사용하지 않고도 어텐션과 같은 블록을 전문화하는 고성능 커널을 작성하여 대상 하드웨어에서 최상의 모델 성능을 얻을 수 있습니다.
Tokamax: 최신 커널의 선별된 라이브러리
Pallas가 커널을 작성하는 도구라면 Tokamax는 TPU와 GPU를 모두 지원하는 최첨단 커스텀 액셀러레이터 커널 라이브러리입니다. Tokamax는 JAX와 Pallas를 기반으로 빌드되며 하드웨어의 모든 기능을 사용할 수 있게 해줍니다. 또한 커스텀 커널을 빌드하고 자동 조정할 수 있는 도구도 제공합니다.
동기
XLA에 기반을 둔 JAX는 컴파일러 우선 프레임워크이지만 최대 성능을 얻으려면 하드웨어를 직접 제어해야 하는 경우가 제한적으로 존재합니다4. 커스텀 커널은 TPU 및 GPU와 같은 고가의 ML 액셀러레이터 리소스에서 최상의 성능을 얻는 데 중요합니다. 커스텀 커널은 어텐션과 같은 주요 연산자가 우수한 성능으로 실행되기 위해 널리 사용되지만 커스텀 커널을 구현하려면 모델과 대상 하드웨어 아키텍처 모두 심도 있게 이해해야 합니다. Tokamax는 개발, 유지보수, 수명 주기 관리에 사용되는 강력한 공유 인프라와 함께 선별되고 테스트를 거친 고성능 커널의 공신력 있는 소스 하나를 제공합니다. 이러한 라이브러리는 필요에 따라 빌드하고 맞춤설정할 수 있는 참조 구현으로도 사용될 수 있습니다. 따라서 인프라를 걱정할 필요 없이 모델링에 모든 노력을 집중할 수 있습니다.
4이는 잘 정립된 패러다임이며 컴파일된 코드가 프로그램의 대부분을 형성하고 개발자가 성능에 중요한 섹션을 최적화하기 위해 내부 기능이나 인라인 어셈블리로 전환하는 CPU 세계에 선례가 있습니다.
디자인
Tokamax는 여러 구현으로 지원될 수 있는 공통 API를 지정된 커널에 제공합니다. 예를 들어 TPU 커널은 표준 XLA를 낮추거나 Pallas/모자이크-TPU를 사용하여 명시적으로 구현될 수 있습니다. GPU 커널은 모자이크-GPU 또는 Triton을 사용하여 표준 XLA을 낮춰 구현될 수 있습니다. 기본적으로 Tokamax API는 주기적인 자동 조정 및 벤치마킹 실행에서 캐시된 결과에 따라 지정된 구성에 가장 적합한 구현을 선택하지만 필요한 경우 개발자가 특정 구현을 선택할 수 있습니다. 시간이 지남에 따라 새로운 하드웨어 세대의 특정 기능을 더 잘 활용하여 성능을 더욱 향상하기 위해 새로운 구현이 추가될 수 있습니다.
커널 자체 외에도 Tokamax 라이브러리의 주요 구성요소는 커스텀 커널을 작성할 수 있는 지원 인프라입니다. 예를 들어 자동 조정 인프라를 사용하면 Tokamax에서 가능한 최적의 조정된 설정을 결정하고 캐시하기 위해 철저한 스윕을 수행할 수 있는 구성 가능한 파라미터 집합(예: 타일 크기)을 정의할 수 있습니다. 나이틀리 회귀는 기본 컴파일러 인프라나 기타 종속 항목의 변경으로 인해 발생하는 예기치 않은 성능 문제와 수치 문제를 방지합니다.
주요 강점
- 원활한 개발자 환경: 선별된 통합 라이브러리는 지원되는 하드웨어 세대와 예상 성능을 프로그래매틱 방식으로 문서에 명확하게 표현하여 주요 커널의 알려진 고성능 구현을 제공합니다. 이렇게 하면 단편화와 이탈을 최소화할 수 있습니다.
- 유연성 및 수명 주기 관리: 적절한 경우 시간 경과에 따라 구현을 변경하는 등 다양한 구현을 선택할 수 있습니다. 예를 들어 XLA 컴파일러에서 특정 작업에 대한 지원을 강화하여 더 이상 커스텀 커널이 필요하지 않은 경우 지원 중단 및 마이그레이션에 대한 경로가 있습니다.
- 확장성: 잘 지원되는 공유 인프라를 활용하면서 자체 커널을 구현할 수 있으므로 부가 가치 기능과 최적화에 집중할 수 있습니다. 명확하게 작성된 표준 구현은 사용자가 학습하고 확장할 수 있는 시작점 역할을 합니다.
Qwix: 비침입적인 포괄적 양자화
Qwix는 JAX AI 스택에 사용되는 포괄적인 양자화 라이브러리로, 학습(양자화 인식 학습(QAT), 양자화 기법(QT), 양자화된 낮은 순위 적응(QLoRA)) 및 추론(학습 후 양자화(PTQ))을 포함한 모든 단계에서 LLM과 기타 모델 유형을 지원하고 XLA 및 온디바이스 런타임을 타겟팅합니다.
동기
기존 양자화 라이브러리(특히 PyTorch 생태계)는 제한된 용도(예: PTQ 전용 또는 QLoRA 전용)를 제공하는 경우가 많습니다. 이러한 분산된 환경으로 인해 도구를 전환해야 하므로 학습과 추론 간에 일관된 코드를 사용하고 정확하게 숫자를 일치시키기가 어렵습니다. 또한 여러 솔루션을 사용하려면 모델을 상당히 수정해야 하므로 모델 로직이 양자화 로직과 긴밀하게 결합됩니다.
디자인
Qwix의 설계 철학은 포괄적인 솔루션과 중요한 비침입적 모델 통합에 중점을 둡니다. 재사용 가능한 함수 API를 기반으로 빌드된 계층적이고 확장 가능한 디자인으로 설계되었습니다.
이 비침입적 통합은 JAX 함수를 양자화된 대응 항목으로 리디렉션하는 세심하게 설계된 인터셉션 메커니즘을 통해 달성됩니다. 이렇게 하면 수정 없이 모델을 통합하여 모델 정의에서 양자화 코드를 완벽하게 분리할 수 있습니다.
다음 예시에서는 LLM의 MLP 레이어에 w4a4(4비트 가중치, 4비트 활성화) 양자화를 적용하고 임베더에 w8(8비트 가중치) 양자화를 적용하는 방법을 보여줍니다. 양자화 레시피를 변경하려면 규칙 목록만 업데이트하면 됩니다.
fp_model = ModelWithoutQuantization(...)
rules = [
qwix.QuantizationRule(
module_path=r'embedder',
weight_qtype='int8',
),
qwix.QuantizationRule(
module_path=r'layers_\d+/mlp',
weight_qtype='int4',
act_qtype='int4',
tile_size=128,
weight_calibration_method='rms,7',
),
]
quantized_model = qwix.quantize_model(fp_model, qwix.PtqProvider(rules))
주요 강점
- 포괄적인 솔루션: Qwix는 다양한 양자화 시나리오에 광범위하게 적용될 수 있으므로 학습과 추론 간에 코드를 일관되게 사용할 수 있습니다.
- 비침입적 모델 통합: 예시에서 볼 수 있듯이 코드 한 줄로 모델을 통합할 수 있습니다. 이를 통해 여러 양자화 스킴에서 하이퍼파라미터를 사용하여 최적의 품질과 성능의 절충안을 찾을 수 있습니다.
- 다른 라이브러리와 제휴: Qwix는 JAX AI 스택과 원활하게 통합됩니다. 예를 들어 Tokamax는 모델이 Qwix로 양자화될 때 추가 사용자 코드가 없어도 양자화된 커널 버전을 사용하도록 자동으로 적응합니다.
- 연구 친화적: Qwix의 기본 API와 확장 가능한 아키텍처를 통해 연구자는 새로운 알고리즘을 탐색하고 통합 벤치마크 도구와 평가 도구를 사용하여 이 알고리즘을 간단하게 비교할 수 있습니다.
애플리케이션 계층: 학습 및 조정
파운데이션 모델 학습: MaxText 및 MaxDiffusion
MaxText 및 MaxDiffusion은 각각 Google의 대표적인 LLM 및 확산 모델 학습 프레임워크입니다. 이러한 저장소에서 인기 있는 오픈 가중치 모델의 고도로 최적화된 구현을 선택할 수 있습니다. 이러한 구현은 즉시 사용할 수 있는 모델 학습 코드베이스 및 파운데이션 모델 빌더가 빌드하는 데 사용할 수 있는 참조 등 두 가지 용도로 제공합니다.
동기
업계 전반에서 생성형 AI 모델 학습에 대한 관심이 빠르게 증가하고 있습니다. 오픈 모델의 인기로 이 추세가 가속화되어 검증된 아키텍처가 제공됩니다. 이러한 모델을 학습하고 적응시키려면 고성능, 효율성, 대규모 칩으로의 확장성, 명확하고 이해하기 쉬운 코드가 필요합니다. MaxText 및 MaxDiffusion은 TPU 또는 GPU에서 사용할 수 있는 포괄적인 솔루션으로, 이러한 니즈가 충족되도록 설계되었습니다.
디자인
MaxText 및 MaxDiffusion은 가독성과 성능을 고려하여 설계된 파운데이션 모델 코드베이스입니다. 테스트를 거친 재사용 가능한 구성요소로 구성되어 있습니다. 여기에는 최대 성능을 위해 커스텀 커널(예: Tokamax)을 사용하는 모델 정의, 조정 및 모니터링을 위한 학습 하네스, 직관적인 인터페이스를 통해 샤딩 및 양자화(Qwix 사용)와 같은 세부정보를 제어할 수 있는 강력한 구성 시스템이 포함됩니다. 지속적인 굿풋이 보장되도록 멀티 계층 체크포인트와 같은 고급 신뢰성 기능이 통합되어 있습니다.
MaxText 및 MaxDiffusion은 동급 최고의 JAX 라이브러리인 Qwix, Tunix, Orbax, Optax를 사용하여 핵심 기능을 제공합니다. 이러한 라이브러리는 강력하고 확장 가능한 인프라를 제공하므로 개발 오버헤드가 줄어들고 모델링 태스크에 집중할 수 있습니다. 추론의 경우 효율적이고 확장 가능한 서빙이 지원되도록 모델 코드가 공유됩니다.
주요 강점
- 설계 단계부터 성능을 고려: MaxText 및 MaxDiffusion은 높은 '굿풋'(유용한 처리량)이 달성되도록 설정된 학습 인프라와 높은 MFU(Model Flops Utilization)가 달성되도록 최적화된 모델 구현을 통해 기본적으로 대규모로 고성능을 제공합니다.
- 확장 가능: JAX AI 스택(특히 Pathways) 기능을 활용하는 이러한 프레임워크로 칩을 수십 개에서 수만 개로 원활하게 확장할 수 있습니다.
- 파운데이션 모델 빌더를 위한 견고한 기반: 가독성이 우수한 고품질 구현은 개발자가 엔드 투 엔드 솔루션으로 사용하거나 자체 맞춤설정을 위한 참조 구현으로 사용할 수 있는 견고한 시작점 역할을 합니다.
학습 후 정렬: Tunix 프레임워크
Tunix는 최첨단 오픈소스 강화 학습(RL) 알고리즘과 함께 강력한 프레임워크와 인프라를 제공하므로 개발자는 간소화된 경로로 JAX 및 TPU를 사용한 지도 미세 조정(SFT)과 정렬을 포함한 LLM 학습 후 기법을 실험할 수 있습니다.
동기
학습 후 단계는 LLM의 진정한 힘을 발휘하는 데 있어 중요한 단계입니다. 강화 학습(RL) 단계는 특히 정렬 및 추론 기능을 개발하는 데 중요합니다. 이 분야의 오픈소스 개발은 거의 전적으로 PyTorch 및 GPU를 기반으로 진행되며 JAX 및 TPU 솔루션에는 근본적인 격차가 있습니다. Tunix(Tune-in-JAX)는 이 격차를 해소하기 위해 설계된 고성능 JAX 네이티브 라이브러리입니다.
디자인

프레임워크 관점에서 Tunix는 RL 알고리즘을 인프라와 명확하게 구분하는 최신 설정을 지원합니다. RL 인프라의 복잡성을 숨기는 간단한 클라이언트 유사 API를 제공하므로 새로운 알고리즘을 개발할 수 있습니다. Tunix는 근위 정책 최적화(PPO), 직접 선호 최적화(DPO)를 포함한 인기 알고리즘을 즉시 사용할 수 있는 솔루션을 제공합니다.
인프라 측면에서 Tunix는 Pathways와 통합되어 멀티 노드 RL 학습에 액세스할 수 있는 단일 컨트롤러 아키텍처를 지원합니다. 학습 측면에서 Tunix는 파라미터 효율적인 학습(예: LoRA)을 기본 지원하고 JAX 샤딩 및 XLA(General and Scalable Parallelization for ML Computation Graph(GSPMD))를 활용하여 성능이 우수한 컴퓨팅 그래프를 생성합니다. 널리 사용되는 오픈소스 모델(예: Gemma 및 Llama)을 기본 지원합니다.
주요 강점
- 단순성: 기본 분산 인프라의 복잡성을 추상화하는 고수준 클라이언트 유사 API를 제공합니다.
- 개발자 효율성: Tunix는 기본 제공 알고리즘과 '레시피'로 R&D 수명 주기를 가속화하므로 개발자는 작동하는 모델을 제공하고 빠르게 반복할 수 있습니다.
- 성능 및 확장성: Tunix는 백엔드에서 Pathways를 단일 컨트롤러로 활용함으로써 매우 효율적이고 수평으로 확장 가능한 학습 인프라를 지원합니다.
애플리케이션 계층: 프로덕션 및 추론
JAX 채택의 역사적 과제는 연구부터 프로덕션에까지 이르는 경로였습니다. 이제 JAX AI 스택은 생태계 호환성과 JAX 성능을 모두 제공하는 성숙한 이중 프로덕션 스토리를 제공합니다.
고성능 LLM 추론: vLLM 솔루션
vLLM-TPU는 Cloud TPU에서 PyTorch 및 JAX 대규모 언어 모델(LLM)을 효율적으로 실행하도록 설계된 Google의 고성능 추론 스택입니다. 널리 사용되는 오픈소스 vLLM 프레임워크를 Google의 JAX 및 TPU 생태계와의 기본적인 통합을 통해 이를 이룩했습니다.
동기
원활하고 고성능이며 사용하기 쉬운 추론 솔루션에 대한 수요가 증가함에 따라 업계가 빠르게 진화하고 있습니다. 개발자는 복잡하고 일관되지 않은 도구, 수준 이하의 성능, 제한된 모델 호환성으로 인해 상당한 어려움에 직면하는 경우가 많습니다. vLLM 스택은 성능이 우수하고 직관적인 통합 플랫폼을 제공함으로써 이러한 문제를 해결합니다.
디자인
이 솔루션은 vLLM 프레임워크를 재창조하지 않고 확장합니다. vLLM-TPU는 PagedAttention(단편화를 최소화하기 위해 가상 메모리처럼 KV 캐시 관리) 및 지속적 일괄 처리(활용도가 향상되도록 일괄 처리에 요청을 동적으로 추가)와 같은 주요 기능을 사용하여 높은 처리량을 달성하는 것으로 알려진 고도로 최적화된 오픈소스 LLM 서빙 엔진입니다.
vLLM-TPU는 이 기반을 토대로 빌드되었으며 요청 처리, 예약, 메모리 관리를 위한 핵심 구성요소를 개발합니다. vLLM의 연산 그래프와 메모리 작업을 TPU 실행 가능 코드로 변환하는 브리지 역할을 하는 JAX 기반 백엔드가 도입되었습니다. 이 백엔드는 기기 상호작용, JAX 모델 실행, TPU 하드웨어의 KV 캐시 관리 관련 세부사항을 처리합니다. 효율적인 어텐션 메커니즘(예: Ragged Paged Attention을 위해 JAX Pallas 커널 활용) 및 양자화와 같은 TPU별 최적화가 통합되어 있으며 모두 TPU 아키텍처에 맞게 조정됩니다.
주요 강점
- 사용자 온보딩/오프보딩 무료: 사용자는 큰 어려움 없이 이 솔루션을 채택할 수 있습니다. 사용자 경험 관점에서 TPU에서 추론 요청을 처리하는 것은 GPU에서 처리하는 것과 동일해야 합니다. 서버를 시작하고 프롬프트를 수락하고 출력을 반환하는 CLI는 모두 공유됩니다.
- 생태계 완전 수용: 이 방식은 vLLM 인터페이스와 사용자 경험을 활용하고 이에 기여하여 호환성과 사용 편의성을 보장합니다.
- TPU 및 GPU 간 대체 가능성: 이 솔루션은 TPU 및 GPU 모두에서 효율적으로 작동하므로 유연성이 높습니다.
- 경제적(최고의 $당 성능): 성능을 최적화하여 인기 모델에 대한 최고의 가성비를 제공합니다.
JAX 서빙: Orbax 직렬화 및 Neptune 서빙 엔진
LLM 이외의 모델이나 완전한 JAX 네이티브 파이프라인을 원하는 사용자를 위해 Orbax 직렬화 라이브러리 및 Neptune 서빙 엔진(NSE) 시스템에서 엔드 투 엔드 고성능 서빙 솔루션을 제공합니다.
동기
이전에는 JAX 모델이 TensorFlow 그래프로 래핑되고 TensorFlow Serving을 사용하여 배포되는 등 프로덕션까지의 우회 경로를 사용하는 경우가 많았습니다. 이 방식은 상당한 제한사항과 비효율성을 발생시켜 개발자는 별도의 생태계와 상호작용해야 했고 반복이 늦어졌습니다. 전용 JAX 네이티브 서빙 시스템은 지속 가능성, 복잡성 감소, 최적화된 성능에 매우 중요합니다.
디자인
이 솔루션은 다음 다이어그램에 표시된 것처럼 두 가지 핵심 구성요소로 구성됩니다.

- Orbax 직렬화 라이브러리: JAX 모델을 새로운 강력한 Orbax 직렬화 형식으로 직렬화할 수 있는 사용자 친화적인 API를 제공합니다. 이 형식은 프로덕션 배포에 최적화되어 있습니다. StableHLO를 사용하여 JAX 모델 계산을 직접 나타내므로 계산 그래프를 기본적으로 나타낼 수 있습니다. 또한 가중치를 저장하는 데 TensorStore를 활용하여 서빙을 위한 빠른 체크포인트 로딩을 지원합니다.
- Neptune 서빙 엔진(NSE): Orbax 형식으로 JAX 모델을 기본적으로 실행할 수 있도록 설계된 유연한 동반 고성능 서빙 엔진입니다(일반적으로 C++ 바이너리로 배포됨). NSE는 빠른 모델 로드, 기본 제공 일괄 처리 기능을 사용한 높은 처리량 동시 제공, 여러 모델 버전 지원, PJRT 및 Pathways를 활용한 단일 및 다중 호스트 서빙과 같은 프로덕션에 필수적인 기능을 제공합니다. 다음과 같은 경우 Neptune Serving Engine을 사용합니다.
- LLM이 아닌 모델: 추천자 시스템, 확산 모델, 기타 AI 모델과 같은 워크로드에 적합한 범용 솔루션입니다.
- 소형 LLM 및 '원샷' 제공: 비자기 회귀 모델이나 '단항' 방식으로 제공되는 소형 모델을 위해 설계되었습니다. 여기서 전체 출력은 KV 캐시와 같은 복잡한 상태 관리 없이 단일 패스에서 생성됩니다.
간단히 말해 Neptune 서빙 엔진은 크지 않고 자기회귀 언어 모델이 아닌 다양한 모델을 제공하는 데 있어 격차를 해소하여 더욱 다양한 ML 생태계를 위한 고성능 TPU 네이티브 솔루션을 제공합니다.
주요 강점
- JAX 네이티브 서빙: 이 솔루션은 기본적으로 JAX용으로 빌드되어 모델 직렬화와 서빙에서 프레임워크 간 오버헤드를 제거합니다. 이렇게 하면 CPU, GPU, TPU에서 모델을 빠르게 로드하고 실행을 최적화할 수 있습니다.
- 간편한 프로덕션 배포: 직렬화된 모델은 Python 종속 항목의 드리프트에 영향을 받지 않는 밀폐된 배포 경로를 제공하고 런타임 모델 무결성 검사를 지원합니다. 이를 통해 JAX 모델 프로덕션화에 사용할 수 있는 원활하고 직관적인 경로가 제공됩니다.
- 개발자 환경 개선: 이 솔루션은 번거로운 프레임워크 래핑의 필요성을 없애 종속 항목과 시스템 복잡성을 크게 줄이므로 JAX 개발자를 위한 반복이 가속화됩니다.
시스템 전체 분석 및 프로파일링
XProf: 하드웨어 통합 심층 성능 프로파일링
XProf는 ML 워크로드 실행의 다양한 관점에 대한 심층 있는 가시성을 제공하여 성능을 디버그하고 최적화할 수 있는 프로파일링 및 성능 분석 도구입니다. JAX 및 TPU 생태계에 긴밀하게 통합되어 있습니다.
동기
한편 ML 워크로드는 점점 복잡해지고 있습니다. 반면 이러한 워크로드를 타겟팅하는 전문 하드웨어 기능이 폭발적으로 증가하고 있습니다. ML 인프라의 막대한 비용을 고려할 때 최대 성능과 효율성을 보장하려면 두 가지를 효과적으로 일치시켜야 합니다. 이렇게 하려면 워크로드와 하드웨어 모두에 대한 심도 있는 가시성이 필요하며 이 가시성은 빠르게 소비할 수 있는 방식으로 제공되어야 합니다. XProf는 이 부분에서 뛰어납니다.
디자인
XProf는 수집과 분석이라는 두 가지 기본 구성요소로 구성됩니다.
- 수집: XProf는 JAX 코드의 주석, XLA 컴파일러 내 작업의 비용 모델, TPU 내의 맞춤형 하드웨어 프로파일링 기능과 같은 다양한 소스에서 정보를 캡처합니다. 이 수집은 프로그래매틱 방식 또는 주문형으로 트리거되므로 포괄적인 이벤트 아티팩트를 생성할 수 있습니다.
- 분석: XProf는 수집된 데이터를 후처리하고 브라우저로 액세스할 수 있는 강력한 시각화 모음을 만듭니다.
주요 강점
XProf의 진정한 힘은 풀 스택과의 긴밀한 통합에서 비롯되며 공동 설계된 JAX/TPU 생태계의 실질적인 이점인 폭넓은 심층적인 분석을 제공합니다.
- TPU와 공동 설계: XProf는 원활한 프로필 수집을 위해 특별히 설계된 하드웨어 기능을 활용하여 1%미만의 수집 오버헤드를 지원합니다. 이를 통해 프로파일링은 개발의 가벼운 반복 부분이 될 수 있습니다.
- 폭넓은 심층적인 분석: XProf는 여러 축에 걸쳐 심층적인 분석을 제공합니다. 이 도구에는 다음이 포함됩니다.
- Trace 뷰어: 다양한 하드웨어 유닛(예: TensorCore)에서의 실행 작업 타임라인 뷰입니다.
- HLO 작업 프로필: 총 소요 시간을 다양한 작업 카테고리로 분류합니다.
- 메모리 뷰어: 프로파일링된 창에서 다양한 작업별 메모리 할당을 자세히 보여줍니다.
- 루프라인 분석: 특정 작업이 컴퓨팅 또는 메모리 바운드인지, 하드웨어 최대 기능과 동떨어져 있는 정도를 파악하는 데 도움이 됩니다.
- 그래프 뷰어: 하드웨어에서 실행되는 전체 HLO 그래프를 볼 수 있습니다.
비교 관점: 매력적인 선택으로 JAX/TPU 스택
최신 머신러닝 환경에서는 성숙하고 우수한 도구 모음을 다양하게 제공합니다. JAX AI 스택은 모듈식 설계와 심층적인 하드웨어 공동 설계에서 직접 비롯되는 대규모 고성능 ML에 중점을 두는 개발자에게 고유하고 매력적인 이점을 제공합니다.
많은 프레임워크에서 다양한 기능을 제공하지만 JAX AI 스택은 개발 수명 주기의 주요 영역에서 구체적이고 강력한 차별화 요소를 제공합니다.
- 더 간단하고 강력한 개발자 환경: Optax의 연결 가능한 경사 변환 패러다임을 사용하면 학습 루프에서 명령형으로 관리하는 대신 한 번 선언되는 더 강력하고 유연한 최적화 전략을 사용할 수 있습니다. 시스템 수준에서 Pathways의 단순화된 단일 컨트롤러 인터페이스는 멀티슬라이스 학습 복잡성을 추상화하여 연구자에게 상당한 간소화를 제공합니다.
- 놀라운 규모의 복원력을 위해 설계: JAX 스택은 대규모 학습을 위해 설계되었습니다. Orbax는 긴급 및 멀티 계층 체크포인트와 같은 '놀라운 규모의 학습 복원력' 기능을 제공합니다. 이는 결정론적 전역 셔플과 체크포인트 가능 데이터 로더를 통해 재현성을 완벽하게 지원하는 Grain으로 보완됩니다. 모델 상태(Orbax)로 데이터 파이프라인 상태(Grain)를 원자적으로 체크포인트하는 기능은 장기 실행 작업에서 재현성을 보장하는 데 중요한 기능입니다.
- 완전한 엔드 투 엔드 생태계: 스택은 일관된 엔드 투 엔드 솔루션을 제공합니다. 개발자는 MaxText를 학습용 SOTA 참조로, Tunix를 정렬용으로 사용하고 vLLM-TPU(vLLM 호환성을 위해) 및 NSE(JAX 성능을 위해)를 사용하여 프로덕션으로 이어지는 명확한 이중 경로를 따를 수 있습니다.
많은 스택이 고수준 소프트웨어 관점에서 유사하지만 결정적인 요소는 성능/TCO인 경우가 많습니다. 이 부분이 JAX와 TPU의 공동 설계가 명확한 이점을 제공하는 부분입니다. 이 성능/TCO 이점은 소프트웨어와 TPU 하드웨어 간의 수직적 통합의 직접적인 결과입니다. XLA 컴파일러에서 TPU 아키텍처를 위해 특별히 작업을 융합하거나 XProf 프로파일러에서 1% 미만의 오버헤드 프로파일링을 위해 하드웨어 후크를 사용할 수 있는 기능은 이러한 긴밀한 통합의 실질적인 이점입니다.
이 스택을 채택하는 조직은 JAX AI 스택의 모든 기능을 활용하여 마이그레이션 비용을 최소화할 수 있습니다. 인기 있는 개방형 모델 아키텍처를 사용하는 고객의 경우 다른 프레임워크에서 MaxText로 전환은 구성 파일 설정과 관련된 문제인 경우가 많습니다. 또한 스택에서 safetensors와 같은 인기 있는 체크포인트 형식을 수집할 수 있으므로 비용이 많이 드는 재학습 없이 기존 체크포인트를 마이그레이션할 수 있습니다.
다음 표에서는 JAX AI 스택에서 제공하는 구성요소와 다른 프레임워크나 라이브러리의 해당 구성요소를 매핑합니다.
| 기능 | JAX | 기타 프레임워크의 대안/상응하는 항목5 |
| 컴파일러/런타임 | XLA | Inductor, 쉬움 |
| MultiPod 학습 | Pathways | Torch lightning 전략, Ray Train, Monarch(신규) |
| 핵심 프레임워크 | JAX | PyTorch |
| 모델 작성 | Flax, Max* 모델 | torch.nn.*, NVidia TransformerEngine, HuggingFace Transformers
|
| 옵티마이저 및 손실 | Optax | torch.optim.*, torch.nn.*Loss |
| 데이터 로더 | Grain | Ray Data, HuggingFace dataloaders |
| 체크포인트 | Orbax | PyTorch 분산 체크포인트, NeMo 체크포인트 |
| 양자화 | Qwix | TorchAO, bitsandbytes |
| 커널 작성 및 잘 알려진 구현 | Pallas/Tokamax | Triton/Helion, Liger-kernel, TransformerEngine |
| 사후 학습/조정 | Tunix | VERL, NeMoRL |
| 프로파일링 | XProf | PyTorch 프로파일러, NSight 시스템, NSight Compute |
| 파운데이션 모델 학습 | MaxText, MaxDiffusion | NeMo-Megatron, DeepSpeed, TorchTitan |
| LLM 추론 | vLLM | SGLang |
| LLM이 아닌 추론 | NSE | Triton 추론 서버, RayServe |
5여기에서 일부 동등한 항목은 다른 프레임워크가 JAX와 비교하여 API 경계를 다르게 그리기 때문에 항상 정확한 비교는 아닙니다. 동등한 항목 목록은 전체 목록이 아니며 새로운 라이브러리가 자주 표시됩니다.
결론: AI 미래를 위한 지속 가능하고 프로덕션에 즉시 사용 가능한 플랫폼
이전 표에 제공된 데이터는 명백한 결론을 보여줍니다. 이러한 스택은 소수의 영역에서 자체적인 강점과 약점을 가지고 있지만 전체적으로 소프트웨어 관점에서 매우 유사합니다. 두 스택 모두 파운데이션 모델의 사전 학습, 학습 후 적응 및 배포에 사용할 수 있는 턴키 솔루션을 제공합니다.
JAX AI 스택은 모든 규모에서 ML 모델을 학습시키고 배포할 수 있는 강력한 솔루션을 제공합니다. 소프트웨어와 TPU 하드웨어 전반에서 심층적인 수직 통합을 활용하여 동급 최고의 성능을 제공하고 총소유비용을 낮춥니다.
실전에서 검증된 내부 시스템을 기반으로 빌드된 스택은 내재된 신뢰성과 확장성을 제공하도록 진화하고 있으며 사용자는 이를 사용하여 가장 큰 모델도 자신 있게 개발하고 배포할 수 있습니다. JAX AI 스택 철학에 기반한 모듈식 및 구성 가능 설계로 사용자에게 비교할 수 없는 자유와 관리를 제공하여 모놀리식 프레임워크의 제약 조건 없이 특정 니즈에 맞게 스택을 맞춤설정할 수 있습니다.
확장 가능하고 내결함성 기반을 제공하는 XLA 및 Pathways, 우수한 성능과 표현력이 풍부한 숫자 라이브러리를 제공하는 JAX, 강력한 핵심 개발 라이브러리(예: Flax, Optax, Grain, Orbax), 고급 성능 도구(예: Pallas, Tokamax, Qwix), MaxText, vLLM, NSE의 강력한 애플리케이션 및 프로덕션 레이어를 사용하는 JAX AI 스택은 사용자가 최첨단 연구를 빌드하고 프로덕션에 신속하게 적용할 수 있는 지속 가능한 기반을 제공합니다.