TPU#

SGLang supports high-performance TPU inference through the SGLang-JAX backend, which is specifically optimized for Google Cloud TPUs. The JAX-based implementation delivers high throughput and low latency for large language model (LLM) serving workloads on TPU hardware.

For TPU-specific issues or feature requests, please visit the sglang-jax GitHub issues page.

NOTE: SGLang TPU support is implemented via the SGLang-JAX backend, a dedicated JAX-based inference engine maintained as a separate repository at sgl-project/sglang-jax.

System Requirements#

Supported TPU Hardware#

| TPU Type | HBM Memory     | Availability |
|----------|----------------|--------------|
| TPU v6e  | 32 GB          | Google Cloud |
| TPU v7   | 96 GB per core | Google Cloud |

Software Requirements#

  • Python: 3.12 or higher

  • JAX: Latest version with TPU support

  • Environment: Google Cloud TPU VM or compatible TPU runtime

  • Optional: SkyPilot for simplified cloud deployment
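
Before deploying, you can confirm that JAX detects the TPU with a quick check using the standard JAX API:

python3 -c "import jax; print(jax.devices())"

On a correctly configured TPU VM this should print a list of TpuDevice entries.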

Feature Support Matrix#

SGLang-JAX provides comprehensive TPU-optimized features for production LLM serving:

| Feature | Support Status | Description |
|---------|----------------|-------------|
| High-Throughput Continuous Batching | ✅ | Dynamic request batching for maximum TPU utilization |
| Radix Tree KV Cache | ✅ | Memory-efficient prefix sharing between requests |
| FlashAttention Backend | ✅ | TPU-optimized attention kernel for long sequences |
| Tensor Parallelism | ✅ | Distribute models across multiple TPU cores |
| Paged Attention | ✅ | Flexible KV cache management with paging |
| Speculative Decoding (EAGLE/EAGLE3) | ✅ | 20-40% throughput improvement for compatible models |
| Chunked Prefill | ✅ | Mixed prefill-decode batching |
| OpenAI-Compatible API | ✅ | Drop-in replacement for the OpenAI API |
| Data Parallel Attention | 🚧 In development | Attention computation with data parallelism |
| Quantization | 🚧 In development | Model quantization for reduced memory usage |
| Multi-LoRA | 🚧 In development | Serve multiple LoRA adapters simultaneously |

Attention Backend Comparison#

| Backend | Paged Attention | Spec Decoding | MLA | Sliding Window |
|---------|-----------------|---------------|-----|----------------|
| FlashAttention (fa) | Native | | | |

NOTE: FlashAttention backend is recommended for production workloads due to superior memory efficiency and performance.

Optimized Model List#

The following models have been tested and optimized for TPU deployment:

| Model Family | Performance Status |
|--------------|--------------------|
| Qwen 3 | ⭐ Recommended for production |
| Qwen 3 MoE | ⭐ Best performance |
| Qwen 2 | Needs improvement |
| Qwen 2 MoE | Needs improvement |
| Qwen 1.5 | Needs improvement |
| Llama/LLaMA | Needs improvement |
| Grok-2 | Needs improvement |
| Gemma 2 | Verified on TPU |
| Bailing MoE | Needs improvement |

Installation#
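
Method 1: From PyPI#

A minimal sketch, assuming the package is published to PyPI under the name sglang-jax:

pip install sglang-jax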

Method 2: From Source#

git clone https://github.com/sgl-project/sglang-jax
cd sglang-jax
uv venv --python 3.12 && source .venv/bin/activate
uv pip install -e "python[all]"

Method 3: Using Docker#

NOTE: Docker support for TPU is currently under development. Please use PyPI or source installation methods.

Method 4: Cloud TPU with SkyPilot#

SkyPilot provides simplified deployment on Google Cloud TPU:

  1. Install SkyPilot and configure GCP access (see SkyPilot documentation)

  2. Create a SkyPilot configuration file:

SkyPilot YAML: sglang-jax.sky.yaml
# sglang-jax.sky.yaml
resources:
  accelerators: tpu-v6e-4
  accelerator_args:
    tpu_vm: True
    runtime_version: v2-alpha-tpuv6e

run: |
  git clone https://github.com/sgl-project/sglang-jax.git
  cd sglang-jax
  uv venv --python 3.12
  source .venv/bin/activate
  uv pip install -e "python[all]"

  3. Launch your TPU cluster:

# Standard deployment
sky launch -c sglang-jax sglang-jax.sky.yaml --infra=gcp

# With spot instances for cost savings
sky launch -c sglang-jax sglang-jax.sky.yaml --infra=gcp --use-spot
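
Once the launch completes, standard SkyPilot commands let you inspect and access the cluster (the name matches the -c flag above):

# Check cluster status
sky status

# Open a shell on the cluster (SkyPilot configures an SSH alias per cluster)
ssh sglang-jax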

Launching the Serving Engine#

Basic Example: Qwen-7B#

JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache python3 -u -m sgl_jax.launch_server \
    --model-path Qwen/Qwen-7B-Chat \
    --trust-remote-code \
    --dist-init-addr=0.0.0.0:10011 \
    --nnodes=1 \
    --tp-size=4 \
    --device=tpu \
    --random-seed=3 \
    --node-rank=0 \
    --mem-fraction-static=0.8 \
    --max-prefill-tokens=8192 \
    --download-dir=/tmp \
    --dtype=bfloat16 \
    --skip-server-warmup \
    --host 0.0.0.0 \
    --port 30000

Key Parameters Explained:

  1. JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache - Enables JIT compilation caching to accelerate server startup on subsequent runs

  2. --tp-size=4 - Tensor parallelism size; match this to your TPU core count (typically 1, 4, or 8)

  3. --device=tpu - Specifies TPU device (this is the default for sglang-jax)

  4. --dtype=bfloat16 - Uses bfloat16 precision, which TPUs are optimized for

  5. --mem-fraction-static=0.8 - Allocates 80% of TPU HBM for static memory (adjustable from 0.2 to 0.9)

  6. --max-prefill-tokens=8192 - Maximum number of tokens processed in the prefill phase
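
Once the server is up, you can exercise the OpenAI-compatible API. A minimal sketch with curl, assuming the host and port from the launch command above (the /v1/chat/completions route follows the OpenAI API convention):

curl http://localhost:30000/v1/chat/completions \
    -H "Content-Type: application/json" \
    -d '{
        "model": "Qwen/Qwen-7B-Chat",
        "messages": [{"role": "user", "content": "Hello!"}],
        "max_tokens": 64
    }'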

High-Performance Configuration: Qwen3-8B#

For production workloads with optimal throughput:

python3 -u -m sgl_jax.launch_server \
    --model-path Qwen/Qwen3-8B \
    --trust-remote-code \
    --tp-size=4 \
    --device=tpu \
    --mem-fraction-static=0.8 \
    --chunked-prefill-size=2048 \
    --dtype=bfloat16 \
    --max-running-requests=256 \
    --page-size=128 \
    --attention-backend=fa

Advanced: Speculative Decoding (EAGLE3)#

Speculative decoding can improve throughput by 20-40% for compatible models:

python3 -u -m sgl_jax.launch_server \
    --model-path Qwen/Qwen3-32B \
    --trust-remote-code \
    --device=tpu \
    --tp-size=4 \
    --mem-fraction-static=0.8 \
    --max-prefill-tokens=4096 \
    --attention-backend=fa \
    --dtype=bfloat16 \
    --port=30000 \
    --host=0.0.0.0 \
    --disable-overlap-schedule \
    --speculative-algorithm=EAGLE3 \
    --speculative-draft-model-path=AngelSlim/Qwen3-32B_eagle3 \
    --page-size=64 \
    --speculative-eagle-topk=1 \
    --speculative-num-steps=3 \
    --speculative-num-draft-tokens=4

NOTE: Speculative decoding is currently supported for Qwen3 and LLaMA model families. See the Speculative Decoding documentation for detailed configuration guidance.

Multi-Node Distributed Serving#

For large models requiring multiple TPU VMs:

# Node 0 (coordinator)
python3 -m sgl_jax.launch_server \
    --model-path MODEL_PATH \
    --dist-init-addr=NODE0_IP:10011 \
    --nnodes=2 \
    --node-rank=0 \
    --tp-size=8 \
    [other parameters...]

# Node 1 (worker)
python3 -m sgl_jax.launch_server \
    --model-path MODEL_PATH \
    --dist-init-addr=NODE0_IP:10011 \
    --nnodes=2 \
    --node-rank=1 \
    --tp-size=8 \
    [other parameters...]

Benchmarking with Requests#

Throughput Testing#

Basic throughput benchmark:

python3 -m sgl_jax.bench_serving \
    --backend sgl-jax \
    --dataset-name random \
    --num-prompts=100 \
    --random-input=512 \
    --random-output=128 \
    --max-concurrency=8 \
    --random-range-ratio=1 \
    --warmup-requests=0

Latency Testing#

Measure single-batch latency:

python3 -m sgl_jax.bench_one_batch_server \
    --base-url http://127.0.0.1:30000 \
    --model-path Qwen/Qwen-7B-Chat \
    --batch-size=32 \
    --input-len=256 \
    --output-len=32

Comprehensive Benchmark Script#

For systematic performance evaluation across different configurations:

#!/bin/bash
set -e

backend=${1:-sgl-jax}
num_prompts_per_concurrency=3
input_seq_lens=(1024 4096 8192)
output_seq_lens=(1 1024)
max_concurrencies=(8 16 32 64 128 256)

for input_seq_len in "${input_seq_lens[@]}"; do
    for output_seq_len in "${output_seq_lens[@]}"; do
        echo "======================================="
        echo "Testing ISL/OSL: $input_seq_len/$output_seq_len"
        echo "======================================="
        for max_concurrency in "${max_concurrencies[@]}"; do
            num_prompts=$((num_prompts_per_concurrency * max_concurrency))
            python3 -m sgl_jax.bench_serving \
                --backend ${backend} \
                --dataset-name random \
                --num-prompts ${num_prompts} \
                --random-input ${input_seq_len} \
                --random-output ${output_seq_len} \
                --max-concurrency ${max_concurrency} \
                --random-range-ratio 1 \
                --disable-ignore-eos \
                --warmup-requests 0
        done
    done
done

For detailed help on all benchmark parameters:

python3 -m sgl_jax.bench_serving --help

See the Benchmark and Profiling Guide for advanced benchmarking techniques and profiling with JAX Profiler.

Performance Optimization#

Memory Optimization#

Reduce memory usage:

  • Lower --mem-fraction-static (from 0.8 → 0.5 → 0.3)

  • Decrease --max-prefill-tokens (from 16384 → 8192 → 4096)

  • Reduce --max-running-requests

Handle OOM errors (see the sketch after this list):

  • Start with conservative memory settings (--mem-fraction-static=0.5)

  • Gradually increase until you find the optimal balance

  • Increase --page-size for better memory locality (1 → 16 → 64 → 128)
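
Putting these together, an illustrative conservative starting configuration (the model path and values are examples to tune upward once the server is stable):

python3 -u -m sgl_jax.launch_server \
    --model-path Qwen/Qwen3-8B \
    --trust-remote-code \
    --tp-size=4 \
    --device=tpu \
    --dtype=bfloat16 \
    --mem-fraction-static=0.5 \
    --max-prefill-tokens=4096 \
    --max-running-requests=64 \
    --page-size=64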

Throughput Optimization#

To maximize tokens per second:

  • Use FlashAttention backend: --attention-backend=fa

  • Enable speculative decoding (EAGLE3) for Qwen3 models (20-40% improvement)

  • Increase --max-running-requests to 256+

  • Set --mem-fraction-static to 0.8+ (if memory allows)

  • Use larger page sizes (64-128)

  • Enable chunked prefill: --chunked-prefill-size=2048

Latency Optimization#

To minimize time-to-first-token (TTFT) and inter-token latency (an example launch follows this list):

  • Reduce --page-size to 1-4

  • Lower --max-running-requests (16-32) for smaller batches

  • Reduce --chunked-prefill-size

  • Use conservative memory settings to avoid GC pauses
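
Combining these, an illustrative low-latency launch (the model path and values are examples, not tuned recommendations):

python3 -u -m sgl_jax.launch_server \
    --model-path Qwen/Qwen3-8B \
    --trust-remote-code \
    --tp-size=4 \
    --device=tpu \
    --attention-backend=fa \
    --dtype=bfloat16 \
    --page-size=4 \
    --max-running-requests=32 \
    --chunked-prefill-size=512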

TPU-Specific Optimizations#

  1. JIT Compilation Cache:

    export JAX_COMPILATION_CACHE_DIR=/tmp/jit_cache
    

    Always set this environment variable to cache compiled kernels and accelerate server startup.

  2. Data Type Optimization: Use --dtype=bfloat16 for TPU native optimization. TPUs are specifically designed for bfloat16 computations.

  3. Tensor Parallelism: Match --tp-size to your TPU core configuration (1, 4, or 8) for optimal model distribution.

  4. Attention Backend: Always use --attention-backend=fa (FlashAttention) for production workloads.

Troubleshooting#

OOM (Out of Memory) Errors#

If you encounter out-of-memory errors:

  1. Reduce --mem-fraction-static from 0.8 to 0.5 or lower

  2. Decrease --max-prefill-tokens from 8192 to 4096 or 2048

  3. Lower --max-running-requests to reduce concurrent batch size

  4. Increase --page-size for better memory layout efficiency

Long Compilation Times#

If the server takes too long to start:

  1. Ensure JAX_COMPILATION_CACHE_DIR is properly set

  2. Understand that the first run requires JIT compilation (this is normal)

  3. Subsequent runs will be significantly faster with cached compilations

  4. Consider using --skip-server-warmup to defer compilation until the first request

Low Throughput#

If you’re not achieving expected throughput:

  1. Verify --tp-size matches your TPU core configuration

  2. Check that --attention-backend=fa is enabled

  3. Increase --max-running-requests to enable larger batch formation

  4. Consider enabling speculative decoding for compatible models

  5. Ensure memory settings allow for sufficient batch sizes

Connection Issues#

If clients cannot connect to the server:

  1. Ensure --host=0.0.0.0 for external access (not just 127.0.0.1)

  2. Verify firewall rules allow traffic on the specified port (default: 30000)

  3. Check that the server process is running: curl http://localhost:30000/health

Advanced Features#

Speculative Decoding#

SGLang-JAX supports EAGLE and EAGLE3 speculative decoding algorithms for Qwen3 and LLaMA model families. Speculative decoding can improve throughput by 20-40% without affecting output quality.

See the Speculative Decoding documentation for detailed configuration and supported model combinations.

Chunked Prefill#

Enable mixed prefill-decode batching for better TPU utilization:

--chunked-prefill-size=2048 --enable-mixed-chunk

This allows the scheduler to mix prefill operations with decode operations in the same batch, improving overall throughput.

Custom Attention Backends#

SGLang-JAX supports a plugin-based attention backend system. You can implement custom attention kernels optimized for specific use cases.

See the Attention Backend documentation for implementation details.

Environment Verification#

Verify your TPU setup before deploying:

python -c "from sgl_jax import check_env; check_env.check_env()"

This command checks:

  • Installed package versions

  • TPU device availability and specifications

  • System resources and configuration

  • Compatibility of settings

Contributing#

We welcome contributions to improve TPU support in SGLang-JAX!

Areas for Contribution#

Check the Development Roadmap to see planned features and find opportunities to contribute new functionality.

Current contribution areas include:

  • Performance optimizations for specific TPU generations

  • Support for additional model architectures

  • Documentation improvements and examples

  • Bug reports and fixes

  • Benchmark results and performance analysis

How to Contribute#

  1. Visit the sglang-jax repository

  2. Read the Contribution Guide

  3. Join the SGL-JAX Slack community for discussions

  4. Report issues at sglang-jax/issues

Testing on TPU#

For contributors who need TPU access for testing:

References#

Documentation#

External Resources#