TPU#
SGLang supports high-performance TPU inference through the SGLang-JAX backend, which is specifically optimized for Google Cloud TPUs. The JAX-based implementation delivers exceptional throughput and low latency for Large Language Model (LLM) serving workloads on TPU hardware.
For TPU-specific issues or feature requests, please visit the sglang-jax GitHub issues page.
NOTE: SGLang TPU support is implemented via the SGLang-JAX backend, a dedicated JAX-based inference engine maintained as a separate repository at sgl-project/sglang-jax.
System Requirements#
Supported TPU Hardware#
| TPU Type | HBM Memory | Availability |
|---|---|---|
| TPU v6e | 32 GB | Google Cloud |
| TPU v7 | 96 GB per core | Google Cloud |
Software Requirements#
Python: 3.12 or higher
JAX: Latest version with TPU support
Environment: Google Cloud TPU VM or compatible TPU runtime
Optional: SkyPilot for simplified cloud deployment
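Before installing SGLang-JAX, it can help to confirm that the runtime actually sees the TPU. A quick check, assuming JAX with TPU support is already installed on the TPU VM:
# Confirm the Python version and that JAX can see the TPU devices
python3 --version
python3 -c "import jax; print(jax.devices())"
On a correctly configured TPU VM, the second command should print a list of TpuDevice entries.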
Feature Support Matrix#
SGLang-JAX provides comprehensive TPU-optimized features for production LLM serving:
| Feature | Support Status | Description |
|---|---|---|
| High-Throughput Continuous Batching | ✅ | Dynamic request batching for maximum TPU utilization |
| Radix Tree KV Cache | ✅ | Memory-efficient prefix sharing between requests |
| FlashAttention Backend | ✅ | TPU-optimized attention kernel for long sequences |
| Tensor Parallelism | ✅ | Distribute models across multiple TPU cores |
| Paged Attention | ✅ | Flexible KV cache management with paging |
| Speculative Decoding (EAGLE/EAGLE3) | ✅ | 20-40% throughput improvement for compatible models |
| Chunked Prefill | ✅ | Mixed prefill-decode batching |
| OpenAI-Compatible API | ✅ | Drop-in replacement for the OpenAI API |
| Data Parallel Attention | 🚧 | In development - Attention computation with data parallelism |
| Quantization | 🚧 | In development - Model quantization for reduced memory usage |
| Multi-LoRA | 🚧 | In development - Serve multiple LoRA adapters simultaneously |
Attention Backend Comparison#
| Backend | Paged Attention | Spec Decoding | MLA | Sliding Window |
|---|---|---|---|---|
| FlashAttention (fa) | ✅ | ✅ | ❌ | ✅ |
| Native | ❌ | ❌ | ❌ | ❌ |
NOTE: FlashAttention backend is recommended for production workloads due to superior memory efficiency and performance.
Optimized Model List#
The following models have been tested and optimized for TPU deployment:
| Model Family | Performance Status |
|---|---|
|  | ⭐ Recommended for production |
|  | ⭐ Best performance |
|  | Needs improvement |
|  | Needs improvement |
|  | Needs improvement |
|  | Needs improvement |
|  | Needs improvement |
|  | Verified on TPU |
| Bailing MoE | Needs improvement |
Installation#
Method 1: Using PyPI (Recommended)#
pip install sglang-jax
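To confirm the install, you can check the package metadata and that the sgl_jax module imports (a quick sanity check only; see Environment Verification below for a fuller check):
# Show the installed package version and verify the module imports
pip show sglang-jax
python3 -c "import sgl_jax"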
Method 2: From Source#
git clone https://github.com/sgl-project/sglang-jax
cd sglang-jax
uv venv --python 3.12 && source .venv/bin/activate
uv pip install -e "python[all]"
Method 3: Using Docker#
NOTE: Docker support for TPU is currently under development. Please use PyPI or source installation methods.
Method 4: Cloud TPU with SkyPilot#
SkyPilot provides simplified deployment on Google Cloud TPU:
Install SkyPilot and configure GCP access (see SkyPilot documentation)
Create a SkyPilot configuration file:
SkyPilot YAML: sglang-jax.sky.yaml
# sglang-jax.sky.yaml
resources:
  accelerators: tpu-v6e-4
  accelerator_args:
    tpu_vm: True
    runtime_version: v2-alpha-tpuv6e

run: |
  git clone https://github.com/sgl-project/sglang-jax.git
  cd sglang-jax
  uv venv --python 3.12
  source .venv/bin/activate
  uv pip install -e "python[all]"
Launch your TPU cluster:
# Standard deployment
sky launch -c sglang-jax sglang-jax.sky.yaml --infra=gcp
# With spot instances for cost savings
sky launch -c sglang-jax sglang-jax.sky.yaml --infra=gcp --use-spot
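Once the launch completes, the cluster can be inspected and accessed with standard SkyPilot commands (the name sglang-jax matches the -c flag used above):
# List managed clusters and their state
sky status
# Open a shell on the TPU VM; SkyPilot creates an SSH alias named after the cluster
ssh sglang-jax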
Launching the Serving Engine#
Basic Example: Qwen-7B#
JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache python3 -u -m sgl_jax.launch_server \
--model-path Qwen/Qwen-7B-Chat \
--trust-remote-code \
--dist-init-addr=0.0.0.0:10011 \
--nnodes=1 \
--tp-size=4 \
--device=tpu \
--random-seed=3 \
--node-rank=0 \
--mem-fraction-static=0.8 \
--max-prefill-tokens=8192 \
--download-dir=/tmp \
--dtype=bfloat16 \
--skip-server-warmup \
--host 0.0.0.0 \
--port 30000
Key Parameters Explained:
JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache - Enables JIT compilation caching to accelerate server startup on subsequent runs
--tp-size=4 - Tensor parallelism size; match this to your TPU core count (typically 1, 4, or 8)
--device=tpu - Specifies the TPU device (this is the default for sglang-jax)
--dtype=bfloat16 - Uses bfloat16 precision, which TPUs are optimized for
--mem-fraction-static=0.8 - Allocates 80% of TPU HBM for static memory (adjustable from 0.2 to 0.9)
--max-prefill-tokens=8192 - Maximum number of tokens processed in the prefill phase
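With the server running, you can exercise the OpenAI-compatible API directly. A minimal chat completion request against the example above might look like this (the model field mirrors the --model-path value):
curl http://localhost:30000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "Qwen/Qwen-7B-Chat",
    "messages": [{"role": "user", "content": "Say hello in one sentence."}],
    "max_tokens": 32
  }'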
High-Performance Configuration: Qwen3-8B#
For production workloads with optimal throughput:
python3 -u -m sgl_jax.launch_server \
--model-path Qwen/Qwen3-8B \
--trust-remote-code \
--tp-size=4 \
--device=tpu \
--mem-fraction-static=0.8 \
--chunked-prefill-size=2048 \
--dtype=bfloat16 \
--max-running-requests=256 \
--page-size=128 \
--attention-backend=fa
Advanced: Speculative Decoding (EAGLE3)#
Speculative decoding can improve throughput by 20-40% for compatible models:
python3 -u -m sgl_jax.launch_server \
--model-path Qwen/Qwen3-32B \
--trust-remote-code \
--device=tpu \
--tp-size=4 \
--mem-fraction-static=0.8 \
--max-prefill-tokens=4096 \
--attention-backend=fa \
--dtype=bfloat16 \
--port=30000 \
--host=0.0.0.0 \
--disable-overlap-schedule \
--speculative-algorithm=EAGLE3 \
--speculative-draft-model-path=AngelSlim/Qwen3-32B_eagle3 \
--page-size=64 \
--speculative-eagle-topk=1 \
--speculative-num-steps=3 \
--speculative-num-draft-tokens=4
NOTE: Speculative decoding is currently supported for Qwen3 and LLaMA model families. See the Speculative Decoding documentation for detailed configuration guidance.
Multi-Node Distributed Serving#
For large models requiring multiple TPU VMs:
# Node 0 (coordinator)
python3 -m sgl_jax.launch_server \
--model-path MODEL_PATH \
--dist-init-addr=NODE0_IP:10011 \
--nnodes=2 \
--node-rank=0 \
--tp-size=8 \
[other parameters...]
# Node 1 (worker)
python3 -m sgl_jax.launch_server \
--model-path MODEL_PATH \
--dist-init-addr=NODE0_IP:10011 \
--nnodes=2 \
--node-rank=1 \
--tp-size=8 \
[other parameters...]
Benchmarking with Requests#
Throughput Testing#
Basic throughput benchmark:
python3 -m sgl_jax.bench_serving \
--backend sgl-jax \
--dataset-name random \
--num-prompts=100 \
--random-input=512 \
--random-output=128 \
--max-concurrency=8 \
--random-range-ratio=1 \
--warmup-requests=0
Latency Testing#
Measure single-batch latency:
python3 -m sgl_jax.bench_one_batch_server \
--base-url http://127.0.0.1:30000 \
--model-path Qwen/Qwen-7B-Chat \
--batch-size=32 \
--input-len=256 \
--output-len=32
Comprehensive Benchmark Script#
For systematic performance evaluation across different configurations:
#!/bin/bash
set -e
backend=${1:-sgl-jax}
num_prompts_per_concurrency=3
input_seq_lens=(1024 4096 8192)
output_seq_lens=(1 1024)
max_concurrencies=(8 16 32 64 128 256)
for input_seq_len in "${input_seq_lens[@]}"; do
for output_seq_len in "${output_seq_lens[@]}"; do
echo "======================================="
echo "Testing ISL/OSL: $input_seq_len/$output_seq_len"
echo "======================================="
for max_concurrency in "${max_concurrencies[@]}"; do
num_prompts=$((num_prompts_per_concurrency * max_concurrency))
python3 -m sgl_jax.bench_serving \
--backend ${backend} \
--dataset-name random \
--num-prompts ${num_prompts} \
--random-input ${input_seq_len} \
--random-output ${output_seq_len} \
--max-concurrency ${max_concurrency} \
--random-range-ratio 1 \
--disable-ignore-eos \
--warmup-requests 0
done
done
done
For detailed help on all benchmark parameters:
python3 -m sgl_jax.bench_serving --help
See the Benchmark and Profiling Guide for advanced benchmarking techniques and profiling with JAX Profiler.
Performance Optimization#
Memory Optimization#
Reduce memory usage:
Lower --mem-fraction-static (from 0.8 → 0.5 → 0.3)
Decrease --max-prefill-tokens (from 16384 → 8192 → 4096)
Reduce --max-running-requests
Handle OOM errors:
Start with conservative memory settings (--mem-fraction-static=0.5)
Gradually increase until you find the optimal balance
Increase --page-size for better memory locality (1 → 16 → 64 → 128)
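As a starting point, a conservative launch combining these settings might look like the sketch below (illustrative values only; the other flags follow the basic example above):
# Conservative memory profile: smaller static pool, shorter prefill, fewer concurrent requests
python3 -u -m sgl_jax.launch_server \
  --model-path Qwen/Qwen-7B-Chat \
  --trust-remote-code \
  --device=tpu \
  --tp-size=4 \
  --dtype=bfloat16 \
  --mem-fraction-static=0.5 \
  --max-prefill-tokens=4096 \
  --max-running-requests=32 \
  --page-size=64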
Throughput Optimization#
To maximize tokens per second:
Use the FlashAttention backend: --attention-backend=fa
Enable speculative decoding (EAGLE3) for Qwen3 models (20-40% improvement)
Increase --max-running-requests to 256+
Set --mem-fraction-static to 0.8+ (if memory allows)
Use larger page sizes (64-128)
Enable chunked prefill: --chunked-prefill-size=2048
Latency Optimization#
To minimize time-to-first-token (TTFT) and inter-token latency:
Reduce --page-size to 1-4
Lower --max-running-requests (16-32) for smaller batches
Reduce --chunked-prefill-size
Use conservative memory settings to avoid GC pauses
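A latency-oriented launch might combine these flags as follows (a sketch based on the guidance above, not a tuned configuration):
# Latency-oriented profile: small pages, small batches, FlashAttention backend
python3 -u -m sgl_jax.launch_server \
  --model-path Qwen/Qwen3-8B \
  --trust-remote-code \
  --device=tpu \
  --tp-size=4 \
  --dtype=bfloat16 \
  --attention-backend=fa \
  --page-size=4 \
  --max-running-requests=16 \
  --chunked-prefill-size=1024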
TPU-Specific Optimizations#
JIT Compilation Cache:
export JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache
Always set this environment variable to cache compiled kernels and accelerate server startup.
Data Type Optimization: Use --dtype=bfloat16 for TPU-native optimization. TPUs are specifically designed for bfloat16 computations.
Tensor Parallelism: Match --tp-size to your TPU core configuration (1, 4, or 8) for optimal model distribution.
Attention Backend: Always use --attention-backend=fa (FlashAttention) for production workloads.
Troubleshooting#
OOM (Out of Memory) Errors#
If you encounter out-of-memory errors:
Reduce --mem-fraction-static from 0.8 to 0.5 or lower
Decrease --max-prefill-tokens from 8192 to 4096 or 2048
Lower --max-running-requests to reduce concurrent batch size
Increase --page-size for better memory layout efficiency
Long Compilation Times#
If the server takes too long to start:
Ensure JAX_COMPILATION_CACHE_DIR is properly set
Understand that the first run requires JIT compilation (this is normal)
Subsequent runs will be significantly faster with cached compilations
Consider using --skip-server-warmup to defer compilation until the first request
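To confirm that caching is working, check that the cache directory is populated after the first startup (assuming the environment variable is exported as shown in the TPU-Specific Optimizations section):
# Compiled artifacts should appear here after the first run
ls -lh "$JAX_COMPILATION_CACHE_DIR"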
Low Throughput#
If you’re not achieving expected throughput:
Verify --tp-size matches your TPU core configuration
Check that --attention-backend=fa is enabled
Increase --max-running-requests to enable larger batch formation
Consider enabling speculative decoding for compatible models
Ensure memory settings allow for sufficient batch sizes
Connection Issues#
If clients cannot connect to the server:
Ensure --host=0.0.0.0 for external access (not just 127.0.0.1)
Verify firewall rules allow traffic on the specified port (default: 30000)
Check that the server process is running: curl http://localhost:30000/health
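To rule out network configuration rather than the server itself, repeat the health check from a machine outside the TPU VM, substituting the VM's reachable address (EXTERNAL_IP is a placeholder):
curl http://EXTERNAL_IP:30000/health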
Advanced Features#
Speculative Decoding#
SGLang-JAX supports EAGLE and EAGLE3 speculative decoding algorithms for Qwen3 and LLaMA model families. Speculative decoding can improve throughput by 20-40% without affecting output quality.
See the Speculative Decoding documentation for detailed configuration and supported model combinations.
Chunked Prefill#
Enable mixed prefill-decode batching for better TPU utilization:
--chunked-prefill-size=2048 --enable-mixed-chunk
This allows the scheduler to mix prefill operations with decode operations in the same batch, improving overall throughput.
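In practice these flags simply extend a normal launch command; a sketch reusing the earlier Qwen3-8B example:
# Launch with mixed prefill-decode batching enabled
python3 -u -m sgl_jax.launch_server \
  --model-path Qwen/Qwen3-8B \
  --trust-remote-code \
  --device=tpu \
  --tp-size=4 \
  --dtype=bfloat16 \
  --chunked-prefill-size=2048 \
  --enable-mixed-chunk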
Custom Attention Backends#
SGLang-JAX supports a plugin-based attention backend system. You can implement custom attention kernels optimized for specific use cases.
See the Attention Backend documentation for implementation details.
Environment Verification#
Verify your TPU setup before deploying:
python -c "from sgl_jax import check_env; check_env.check_env()"
This command checks:
Installed package versions
TPU device availability and specifications
System resources and configuration
Compatibility of settings
Contributing#
We welcome contributions to improve TPU support in SGLang-JAX!
Areas for Contribution#
Check the Development Roadmap to see planned features and find opportunities to contribute new functionality.
Current contribution areas include:
Performance optimizations for specific TPU generations
Support for additional model architectures
Documentation improvements and examples
Bug reports and fixes
Benchmark results and performance analysis
How to Contribute#
Visit the sglang-jax repository
Read the Contribution Guide
Join the SGL-JAX Slack community for discussions
Report issues at sglang-jax/issues
Testing on TPU#
For contributors who need TPU access for testing:
Refer to the TPU Resources Guide for information on accessing TPU hardware
Use SkyPilot with spot instances for cost-effective testing
Follow the Benchmark and Profiling Guide for performance validation