TFLOPS Gap: Why FP4 MoE Kernel Engineering Matters on Blackwell
How to achieve 3.54x speedup over BF16 and 1.32x over vLLM on interactive inference through kernel fusion, Blackwell optimization, and expert-aware computation
by https://x.com/advpropx (on X)
Introduction
When NVIDIA announced Blackwell's native FP4 support, the promise was clear: 2x memory bandwidth savings and massive throughput gains for LLM inference. But hardware capabilities are only half the story. The other half? Kernel engineering.
I benchmarked three leading MoE backends (vLLM, SGLang, and FlashInfer CuteDSL) running Mixture of Experts models with FP4 quantization on a Blackwell B200 GPU. The model: GPT-OSS-20B with 32 experts and top-4 routing. Same hardware, same model architecture, different kernels.
The results at peak throughput (batch size 4096):
- SGLang: 1262 TFLOPS
- FlashInfer CuteDSL: 1225 TFLOPS
- vLLM: 1117 TFLOPS
That's a 145 TFLOPS gap between SGLang and vLLM. At batch size 1 (where interactive inference lives), SGLang is 1.32x faster than vLLM FP4. More importantly, SGLang FP4 is 2.23x faster than SGLang BF16 at batch size 128, saving 171 seconds per 1000-token generation.
This isn't about distributed training or multi-node setups. This is single-GPU inference with grouped GEMM kernels. The difference comes down to three critical optimizations:
- Kernel fusion (7 memory passes to 5, eliminating shuffle overhead)
- Blackwell-specific CUTLASS schedules (native FP4 warp specialization)
- Adaptive grid sizing (maximizing SM occupancy at small batches)
Let's dive into the data, then unpack exactly what makes these kernels different.
The Benchmark: GPT-OSS-20B on Blackwell B200
Model Configuration
- Architecture: GPT-OSS-20B
- Experts: 32 total, top-4 routing
- Hidden size: 2880
- Intermediate size: 7680 (per expert)
- Quantization: NVFP4 (4-bit floating point, E2M1 format)
- Hardware: NVIDIA Blackwell B200 (sm_100a)
Peak Throughput (Batch Size 4096)
Figure 1: Effective TFLOPS across batch sizes. SGLang maintains consistent lead, widening at small batches.
At batch size 4096, we see:
- SGLang FP4: 1262 TFLOPS (3.54x faster than BF16)
- FlashInfer FP4: 1225 TFLOPS (3.43x faster than BF16)
- vLLM FP4: 1117 TFLOPS (3.24x faster than BF16)
SGLang's 145 TFLOPS advantage over vLLM compounds across layers and tokens.
Latency Breakdown (Batch Size 128)
Figure 2: Per-layer latency at batch size 128 (decode sweet spot).
At the decode-optimized batch size of 128:
- vLLM FP4: 0.604ms per layer (112.6 TFLOPS)
- SGLang FP4: 0.433ms per layer (157.1 TFLOPS)
SGLang's per-layer latency is 28.3% lower.
The Small-Batch Advantage
Here's where it gets interesting. The performance gap widens at smaller batch sizes.
At batch size 1:
- vLLM FP4: 369.5μs per layer
- SGLang FP4: 206.9μs per layer (1.78x faster)
- FlashInfer FP4: 481.9μs per layer
This matters. Interactive inference (chatbots, code completion, agents) operates at batch sizes 1-16. That's where SGLang's advantage (1.32x over vLLM FP4, 2.23x over its own BF16 path) translates directly to user experience.
Kernel Finding #1: Fusion Eliminates Memory Bottlenecks
vLLM's Sequential Approach (7 Kernel Launches)
vLLM's MoE forward pass launches seven separate CUDA kernels:
Source: vllm/model_executor/layers/fused_moe/cutlass_moe.py:671-712
# 1. Quantize activations to FP4
a_fp4, a_blockscale = ops.scaled_fp4_experts_quant(a, a1_gscale, ...)
# 2. Reorder tokens by expert assignment
rep_a_fp4 = ops.shuffle_rows(a_fp4, a_map, (m * topk, k))
rep_a_blockscale = ops.shuffle_rows(a_blockscale, a_map, (m * topk, k // 16))
# 3. First GEMM: gate_up projection
c1 = ops.cutlass_fp4_moe_mm(rep_a_fp4, w1_fp4, rep_a_blockscale, w1_blockscale, ...)
# 4. SiLU activation
intermediate = torch.empty((m * topk, n), device=device, dtype=out_dtype)
torch.ops._C.silu_and_mul(intermediate, c1)
# 5. Quantize intermediate activations
int_fp4, int_blockscale = ops.scaled_fp4_experts_quant(intermediate, a2_gscale, ...)
# 6. Second GEMM: down projection
c2 = ops.cutlass_fp4_moe_mm(int_fp4, w2_fp4, int_blockscale, w2_blockscale, ...)
# 7. Reorder output back to original token order
output = ops.shuffle_rows(c2, c_map, (m, k))
The cost:
- 7 kernel launches (each with ~5-10μs overhead)
- 7 global memory roundtrips
- 6 synchronization points between kernels
- Intermediate buffers allocated for rep_a_fp4, c1, intermediate, int_fp4, c2
At batch size 4, kernel launch overhead alone consumes 10-20% of total latency.
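To put that number in perspective, here's a quick back-of-envelope in Python using the ~5-10μs per-launch cost above and vLLM's measured 369.5μs per-layer latency at batch size 1 (the real launch cost depends on driver version and whether CUDA graphs are in play, so treat this as indicative):

# Rough share of a small-batch MoE layer spent purely on kernel launches.
num_kernels = 7
layer_latency_us = 369.5                      # measured vLLM FP4 latency at BS=1
for launch_cost_us in (5.0, 10.0):            # optimistic / pessimistic per-launch cost
    overhead_us = num_kernels * launch_cost_us
    share = 100 * overhead_us / layer_latency_us
    print(f"{launch_cost_us:.0f}us/launch -> {overhead_us:.0f}us overhead ({share:.0f}% of the layer)")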
SGLang's Fused Reduction Kernel
SGLang fuses the first and last shuffle operations plus the final reduction into a single kernel:
Source: sglang/sgl-kernel/csrc/moe/prepare_moe_input.cu:258-321
template <typename T, typename scalar_t>
__global__ void apply_shuffle_mul_sum_kernel(
    const T* __restrict__ input,             // [m*topk, k] expert outputs
    const int* __restrict__ permutation,     // [m*topk] mapping
    const scalar_t* __restrict__ weights,    // [m, topk] routing weights
    T* __restrict__ output,                  // [m, k]
    int m, int k, int topk
) {
  // 128-bit vectorized loads
  constexpr uint32_t vec_size = 16 / sizeof(T);
  using vec_t = flashinfer::vec_t<T, vec_size>;

  const int token_idx = blockIdx.x;
  const int feature_idx = threadIdx.x * vec_size;

  vec_t sum;
  sum.fill(T(0));

  // Iterate over top-k experts for this token
  for (int k_idx = 0; k_idx < topk; k_idx++) {
    const int src_idx = permutation[token_idx * topk + k_idx];
    const scalar_t weight = weights[token_idx * topk + k_idx];

    // Vectorized load from expert output
    vec_t expert_output;
    expert_output.load(input + src_idx * k + feature_idx);

    // Multiply by routing weight and accumulate
    #pragma unroll
    for (int i = 0; i < vec_size; i++) {
      sum[i] += expert_output[i] * weight;
    }
  }

  // Vectorized store to output
  sum.store(output + token_idx * k + feature_idx);
}
Three operations in one kernel:
- Token reordering (permutation lookup)
- Routing weight multiplication
- TopK reduction (sum across selected experts)
The payoff:
- 3x memory bandwidth reduction (single global memory pass)
- Better cache locality (reuse permutation and weights in L1)
- Eliminates 2 intermediate buffer allocations
- Reduces kernel launch count: 7 to 5
128-bit vectorization means each thread processes 8 bfloat16 elements per load, saturating memory bandwidth even at small batch sizes.
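If CUDA isn't your daily driver, here's the same computation as a few lines of PyTorch. This is a semantic reference for what the fused kernel does, not SGLang's API; the function name and argument layout are mine:

import torch

def shuffle_mul_sum_reference(expert_out, permutation, weights, topk):
    """Gather each token's top-k expert outputs, scale by routing weights, and sum.

    expert_out:  [m * topk, k]  expert outputs in expert-sorted order
    permutation: [m * topk]     row index into expert_out for each (token, k) slot
    weights:     [m, topk]      routing weights
    returns:     [m, k]
    """
    m = weights.shape[0]
    gathered = expert_out[permutation.view(m, topk)]       # [m, topk, k]
    return (gathered * weights.unsqueeze(-1)).sum(dim=1)   # weighted top-k reduction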
FlashInfer CuteDSL: A Different Trade-Off
FlashInfer takes yet another approach: expert-first layout instead of token-first.
Source: benchmark script
def prepare_flashinfer_input_vectorized(hidden_states, topk_ids, topk_weights, num_experts, topk, device, dtype):
    """Prepare input in expert-first format for FlashInfer CuteDSL.

    Reshapes from [batch, hidden] to [num_experts, max_tokens_per_expert, hidden].
    """
    batch_size, hidden_dim = hidden_states.shape

    # Replicate each token once per selected expert and apply its routing weight
    weighted_hidden = hidden_states.repeat_interleave(topk, dim=0) * topk_weights.reshape(-1, 1).to(dtype)

    # Count tokens per expert using vectorized bincount
    flat_ids = topk_ids.flatten()
    expert_counts = torch.bincount(flat_ids.to(torch.int64), minlength=num_experts).to(torch.int32)
    max_tokens_per_expert = expert_counts.max().item()

    # Sort tokens by expert ID
    sorted_indices = torch.argsort(flat_ids)
    sorted_hidden = weighted_hidden[sorted_indices]
    sorted_expert_ids = flat_ids[sorted_indices]

    # Create expert-first tensor [num_experts, max_tokens, hidden_dim]
    expert_hidden = torch.zeros((num_experts, max_tokens_per_expert, hidden_dim), device=device, dtype=dtype)

    # Fill using advanced indexing
    expert_offsets = torch.zeros(num_experts + 1, dtype=torch.int64, device=device)
    expert_offsets[1:] = expert_counts.cumsum(0)
    token_positions = torch.arange(len(sorted_expert_ids), device=device)
    position_in_expert = token_positions - expert_offsets[sorted_expert_ids]
    expert_hidden[sorted_expert_ids, position_in_expert] = sorted_hidden

    return expert_hidden, expert_counts
Trade-off: FlashInfer's preprocessing is heavier (sorting, scatter), but the expert-first layout enables better expert-level batching. At small batch sizes (BS=1-16), this overhead hurts performance. At large batches (BS=4096), it achieves 1225 TFLOPS, competitive with SGLang's 1262 TFLOPS.
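For shape intuition, here's what a call looks like with dummy random routing (only the function above comes from the benchmark script; the routing tensors here are illustrative):

import torch

device, dtype = "cuda", torch.bfloat16
batch, hidden, num_experts, topk = 8, 2880, 32, 4
hidden_states = torch.randn(batch, hidden, device=device, dtype=dtype)
router_logits = torch.randn(batch, num_experts, device=device)
topk_weights, topk_ids = torch.topk(torch.softmax(router_logits, dim=-1), topk, dim=-1)

expert_hidden, expert_counts = prepare_flashinfer_input_vectorized(
    hidden_states, topk_ids, topk_weights, num_experts, topk, device, dtype)
print(expert_hidden.shape)    # [32, max_tokens_per_expert, 2880]
print(expert_counts.sum())    # 32 token-expert pairs (batch x topk)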
Kernel Finding #2: Blackwell-Specific CUTLASS Schedules
SGLang's Native FP4 Schedule
SGLang uses a Blackwell-optimized CUTLASS schedule designed specifically for grouped FP4 GEMM on sm_100a:
Source: sglang/sgl-kernel/csrc/moe/nvfp4_blockwise_moe.cu:196-201
// SM100/Blackwell B200 configuration
using ThreadBlockShape = Shape<_128, _128, _128>;
using ClusterShape = Shape<_1, _1, _1>;
static constexpr int AlignmentA = 32;
static constexpr int AlignmentB = 32;
using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmNvf4Sm100;
What KernelPtrArrayTmaWarpSpecialized1SmNvf4Sm100 gives you:
- Warp specialization for FP4: Dedicated warp roles for loading FP4 data and block scales, issuing block-scaled FP4 tensor core MMAs, and running the epilogue, avoiding the generic load-convert-compute path.
- TMA (Tensor Memory Accelerator) integration: Asynchronous bulk tensor loads that bypass L1 cache and feed directly into shared memory. Requires strict 128-byte alignment.
- Ptr-array grouped GEMM (1 SM variant): A single kernel launch walks an array of per-expert problem descriptors, with MMAs issued from one SM rather than a 2-SM cluster. Well suited to MoE workloads where per-expert token counts vary.
- Native NvFP4 support: Uses Blackwell's hardware FP4 instructions instead of software emulation.
TMA Alignment Enforcement
Source: sglang/sgl-kernel/csrc/moe/nvfp4_blockwise_moe.cu:89-103
// Strict TMA alignment enforcement
assert((reinterpret_cast<uintptr_t>(a_scales_offsets[expert_id]) % 128) == 0
&& "TMA requires 128-byte alignment");
*layout_sfa_ptr = ScaleConfig::tile_atom_to_shape_SFA(
cute::make_shape(static_cast<int>(m), static_cast<int>(n),
static_cast<int>(k), 1)
);
SGLang pads blockscale offsets to 128-token boundaries to guarantee TMA alignment:
Source: sglang/sgl-kernel/csrc/moe/prepare_moe_input.cu:55-73
// Round to 128-token boundaries for TMA
blockscale_offsets[expert_id + 1] =
(expert_offsets[expert_id + 1] + 127) / 128 * 128;
This wastes a small amount of memory (at most 127 padded token rows of block scales per expert) but guarantees TMA never falls back to a slower, misaligned path.
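The rounding itself is trivial to reproduce on the host side; here's the formula from the kernel above as a standalone Python helper (the function name is mine, not SGLang's):

def pad_blockscale_offsets(expert_offsets, alignment=128):
    """Round each expert's cumulative token offset up to the next 128-token
    boundary so every expert's block-scale region starts TMA-aligned."""
    return [((off + alignment - 1) // alignment) * alignment for off in expert_offsets]

# Experts with 5, 130 and 17 tokens -> cumulative offsets [5, 135, 152]
print(pad_blockscale_offsets([5, 135, 152]))   # [128, 256, 256]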
vLLM's Generic CUTLASS Configuration
vLLM uses standard CUTLASS 3.x schedules that work across Ampere, Hopper, and Blackwell but lack Blackwell-specific optimizations.
Source: vllm/csrc/quantization/fp4/nvfp4_blockwise_moe_kernel.cu:93-101
// Generic TMA alignment checks (no padding)
assert(reinterpret_cast<uintptr_t>(a_scales) % 128 == 0);
assert(reinterpret_cast<uintptr_t>(b_scales) % 128 == 0);
vLLM checks alignment but doesn't enforce it through padding. If token counts don't align naturally, TMA falls back to slower paths.
Net effect of the Blackwell-specific schedule over the generic path:
- Higher effective memory bandwidth (TMA vs L1 cache)
- Better warp utilization (specialized roles vs generic)
- Fewer register spills (tuned for sm_100a)
Kernel Finding #3: Adaptive Grid Sizing for Small Batches
The Occupancy Problem at Small Batch Sizes
GPU kernels achieve peak performance when they saturate the GPU with enough parallelism. On a B200 with 142 SMs, you need at least 142 thread blocks to keep all SMs busy.
The problem with MoE at batch size 1:
- 32 experts x 4 topk = 128 tokens to process
- If each thread block handles 128 tokens: only 1 block
- 99.3% of SMs sit idle
Standard CUTLASS launch heuristics don't adapt well to this regime.
SGLang's Dynamic Block Sizing
SGLang uses adaptive grid sizing that trades block size for grid size when parallelism is low:
Source: sglang/sgl-kernel/csrc/gemm/nvfp4_expert_quant.cu:456-477
// Adaptive kernel launch configuration
int const workSizePerRow = k / ELTS_PER_THREAD;  // 8 FP4 elements per thread
int const totalWorkSize = m_topk * workSizePerRow;

dim3 block(std::min(workSizePerRow, 512));
int const numBlocksPerSM = 2048 / block.x;
dim3 grid(std::min(
    static_cast<int>((totalWorkSize + block.x - 1) / block.x),
    multiProcessorCount * numBlocksPerSM));

// Dynamic adjustment: halve block size, double grid size
while (grid.x <= multiProcessorCount && block.x > 64) {
  grid.x *= 2;
  block.x = (block.x + 1) / 2;
}
Example for batch size 1:
Initial configuration:
m_topk = 128 tokens (1 batch x 32 experts x 4 topk)
k = 2880 (hidden size)
workSizePerRow = 2880 / 8 = 360
totalWorkSize = 128 x 360 = 46,080
block.x = min(360, 512) = 360
grid.x = min(46080 / 360, 142 x 5) = min(128, 710) = 128
After adaptive adjustment:
- Iteration 1: grid.x=128 < 142, adjust: grid=256, block=180
- Iteration 2: grid.x=256 > 142, stop
- Final: grid=256, block=180
Result: Maximizes SM occupancy by finding the right balance between block size and grid size.
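The heuristic is easy to experiment with outside CUDA; this Python mirror of the C++ above (function name and the 142-SM default are mine) reproduces the walkthrough and lets you plug in other shapes:

def adaptive_launch_config(m_topk, k, num_sms=142, elts_per_thread=8):
    """Mirror of the adaptive grid/block heuristic above (same integer math)."""
    work_per_row = k // elts_per_thread
    total_work = m_topk * work_per_row
    block = min(work_per_row, 512)
    blocks_per_sm = 2048 // block
    grid = min((total_work + block - 1) // block, num_sms * blocks_per_sm)
    # Trade block size for grid size until every SM has at least one block.
    while grid <= num_sms and block > 64:
        grid *= 2
        block = (block + 1) // 2
    return grid, block

print(adaptive_launch_config(128, 2880))   # (256, 180), matching the walkthrough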
vLLM's Fixed Heuristics
vLLM relies on CUTLASS's default launch heuristics, which optimize for large matrix sizes. At small batch sizes, this results in:
- Larger block sizes (256-512 threads)
- Fewer blocks (underutilization)
- Lower occupancy
Measured impact: SGLang's adaptive sizing accounts for a significant portion of the 1.32x speedup at BS=1.
The DeepSeek Connection: Expert Parallelism at Scale
DeepSeek-V3 Architecture
DeepSeek-V3 pushed MoE to new extremes:
- 256 experts per layer
- Top-8 routing (2x more active experts)
- 7168 hidden dim, 18432 intermediate (massive per-expert compute)
This architecture is designed for expert parallelism (EP): splitting experts across multiple GPUs/nodes and using all-to-all communication to route tokens.
I benchmarked a scaled-down version that fits on a single B200:
- Experts: 256
- TopK: 8
- Hidden: 2560 (scaled from 7168)
- Intermediate: 8960 (scaled from 18432)
Results at batch size 4096:
- SGLang FP4: 993 TFLOPS (4.12x faster than BF16)
- FlashInfer FP4: 1132 TFLOPS (4.69x faster than BF16)
- vLLM FP4: 968 TFLOPS (4.08x faster than BF16)
Why the Gap Shrinks with More Experts
With 256 experts, there's more inherent parallelism. Even with suboptimal launch heuristics, vLLM saturates the GPU at large batch sizes. The kernel fusion and Blackwell optimization advantages become less pronounced when compute-bound.
But at small batches, SGLang's adaptive grid sizing still wins. At BS=1, SGLang achieves 1.47x speedup over BF16 baseline while vLLM shows 0.86x (slower than BF16 due to quantization overhead dominating at low occupancy).
DeepEP: Multi-Node Expert Parallelism
For true DeepSeek-V3 scale (256 experts at full size), you need expert parallelism across nodes. SGLang implements DeepEP (DeepSeek Expert Parallelism), which uses:
- All-to-All dispatch: Route tokens to the rank that owns each expert
- Local GEMM: Each rank computes its assigned experts
- All-to-All combine: Gather results back to original token order
Source: sglang/srt/layers/moe/token_dispatcher/deepep.py:398-457
def _dispatch_core(
    self,
    x: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
    topk_ids: torch.Tensor,
    topk_weights: torch.Tensor,
    previous_event,
):
    buffer = self._get_buffer()

    # Compute dispatch layout (which tokens go to which rank)
    (
        num_tokens_per_rank,
        num_tokens_per_rdma_rank,
        num_tokens_per_expert,
        is_token_in_rank,
        previous_event,
    ) = buffer.get_dispatch_layout(
        topk_ids,
        self.num_experts,
        previous_event=previous_event,
        async_finish=self.async_finish,
        allocate_on_comm_stream=previous_event is not None,
    )

    # All-to-all dispatch
    (
        recv_x,
        recv_topk_ids,
        recv_topk_weights,
        num_recv_tokens_per_expert,
        self.handle,
        event,
    ) = buffer.dispatch(
        x,
        topk_idx=topk_ids,
        topk_weights=topk_weights,
        num_tokens_per_rank=num_tokens_per_rank,
        num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
        is_token_in_rank=is_token_in_rank,
        num_tokens_per_expert=num_tokens_per_expert,
        previous_event=previous_event,
        async_finish=self.async_finish,
        allocate_on_comm_stream=(previous_event is not None) and self.async_finish,
        expert_alignment=128 if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM else 1,
        config=DeepEPConfig.get_instance().normal_dispatch_config,
    )

    return (
        recv_x,
        recv_topk_ids,
        recv_topk_weights,
        num_recv_tokens_per_expert,
        event,
    )
Key details:
- Uses NCCL/RDMA for low-latency all-to-all
- Supports FP8 communication to reduce bandwidth (quantize activations before dispatch)
- Has two modes: Normal (prefill) and Low-Latency (decode)
The same kernel optimizations (fusion, Blackwell scheduling, adaptive sizing) apply to the local GEMM step after dispatch. So the 1.32x-2.23x speedup compounds with multi-node parallelism.
FlashInfer CuteDSL: The Third Contender
CuteDSL is a relatively new entrant, focusing on template-based kernel generation for MoE workloads.
Performance Comparison (GPT-OSS-20B, 32 experts, top-4)
At peak throughput (BS=4096), FlashInfer achieves 1225 TFLOPS, just 3% slower than SGLang (1262 TFLOPS). But at small batches, it's 2.3x slower than SGLang.
At BS=1:
- SGLang FP4: 206.9μs
- vLLM FP4: 369.5μs
- FlashInfer FP4: 481.9μs
Why FlashInfer Struggles at Small Batches
FlashInfer's expert-first layout requires expensive preprocessing:
- Bincount to count tokens per expert
- Argsort to group tokens by expert
- Scatter to fill expert-first tensor with potential padding
At batch size 1 with 32 experts, this overhead dominates. At batch size 4096, the overhead amortizes and FlashInfer's expert-level batching shines.
FlashInfer's Strength: Masked GEMM
FlashInfer supports masked GEMM, where each expert's output is padded to a fixed size. This enables:
- Better memory coalescing (no irregular strides)
- Simpler kernel logic (no variable-length batching)
At large expert counts (256), FlashInfer shows its strength: 1132 TFLOPS on DeepSeek-scaled config, outperforming SGLang's 993 TFLOPS at BS=4096.
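To make the padded layout concrete, the validity mask a masked GEMM consumes is a one-liner given expert_counts from the preparation step (illustrative PyTorch, not FlashInfer's API):

import torch

expert_counts = torch.tensor([3, 0, 5, 1])   # tokens routed to each of 4 experts
max_tokens = int(expert_counts.max())
# mask[e, t] is True for real tokens of expert e, False for padding rows
mask = torch.arange(max_tokens)[None, :] < expert_counts[:, None]
print(mask.int())
# Padded rows hold zeros; the kernel skips or ignores them per expert.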
Memory Bandwidth Analysis
Let's quantify the memory bandwidth savings from kernel fusion.
vLLM's Memory Traffic
For batch size 128, hidden size 2880, 32 experts, topk 4:
- Tokens processed: 128 x 4 = 512 tokens
- Hidden size: 2880
- Data type: bfloat16 (2 bytes)
Memory operations:
- shuffle_rows (input): Read 512 x 2880 x 2 = 2.95 MB, Write 2.95 MB
- scaled_fp4_quant: Read 2.95 MB, Write 1.47 MB (FP4) + 0.09 MB (scales)
- cutlass_fp4_moe_mm (GEMM1): Read 1.47 + 0.09 MB (activations) + 56 MB (weights), Write 7.87 MB (intermediate)
- silu_and_mul: Read 7.87 MB, Write 3.94 MB
- scaled_fp4_quant: Read 3.94 MB, Write 1.97 MB + 0.12 MB (scales)
- cutlass_fp4_moe_mm (GEMM2): Read 1.97 + 0.12 MB + 28 MB (weights), Write 2.95 MB
- shuffle_rows (output): Read 2.95 MB, Write 2.95 MB
Total activation memory traffic: 2x(2.95) + 2x(2.95) + 7.87 + 3.94 + 2.95 = 26.5 MB (Excluding weight reads, which dominate but are common to both approaches)
SGLang's Fused Memory Traffic
SGLang fuses steps 1 and 7 into apply_shuffle_mul_sum:
- scaled_fp4_quant: Read 2.95 MB, Write 1.47 MB + 0.09 MB
- cutlass_fp4_moe_mm (GEMM1): Read 1.47 + 0.09 MB + 56 MB, Write 7.87 MB
- silu_and_mul: Read 7.87 MB, Write 3.94 MB
- scaled_fp4_quant: Read 3.94 MB, Write 1.97 MB + 0.12 MB
- cutlass_fp4_moe_mm (GEMM2): Read 1.97 + 0.12 MB + 28 MB, Write 2.95 MB
- apply_shuffle_mul_sum: Read 2.95 MB (c2), Read 0.01 MB (c_map, topk_weights), Write 2.95 MB
Total activation memory traffic: 2.95 + 7.87 + 3.94 + 2.95 + 2.95 + 0.01 = 20.7 MB
Savings: (26.5 - 20.7) / 26.5 = 21.9% reduction in activation memory traffic.
At B200's memory bandwidth (8 TB/s), this translates to:
- vLLM: 26.5 MB / 8000 GB/s = 0.0033 ms
- SGLang: 20.7 MB / 8000 GB/s = 0.0026 ms
- Savings: 0.0007 ms per layer
Over 24 layers and 1000 forward passes (1000 token generation):
- Total savings: 0.0007 x 24 x 1000 = 16.8 ms
This is a conservative estimate. Actual savings are higher due to cache effects and reduced launch overhead.
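The bookkeeping above fits in a few lines; this reproduces the per-layer totals, the bandwidth-limited time at 8 TB/s, and the 21.9% figure (weight reads excluded, decimal MB/GB as in the text):

vllm_mb, sglang_mb = 26.5, 20.7      # activation traffic per layer from the breakdown above
bandwidth_gb_s = 8000                # B200 HBM3e, ~8 TB/s

for name, mb in (("vLLM", vllm_mb), ("SGLang", sglang_mb)):
    time_ms = mb / 1000 / bandwidth_gb_s * 1000    # MB -> GB, seconds -> ms
    print(f"{name}: {mb} MB -> {time_ms:.4f} ms per layer")

print(f"traffic reduction: {100 * (vllm_mb - sglang_mb) / vllm_mb:.1f}%")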
Why Kernel Engineering Compounds
These optimizations might seem incremental: a 145 TFLOPS gap, 171 seconds saved per 1000 tokens at BS=128, a 21.9% cut in activation memory traffic. But they compound across:
- Layer Count
Modern LLMs have 24-80 layers. Every layer runs the MoE forward pass. Multiply the per-layer savings by 24-80x.
- Token Count
A single request generates hundreds to thousands of tokens. Chat applications, code generation, and agentic workflows routinely exceed 10K tokens.
At 24 layers with batch size 128:
- vLLM FP4: 0.604ms/layer x 24 layers x 1000 tokens = 14.5 seconds per 1000 tokens
- SGLang FP4: 0.433ms/layer x 24 layers x 1000 tokens = 10.4 seconds per 1000 tokens
- Savings: 4.1 seconds per 1000 tokens (a 28% reduction in MoE latency)
For a mid-sized inference workload (1M tokens/day), kernel optimization saves substantial GPU hours and cost; see the back-of-envelope sketch after this list.
- Multi-Request Batching
The savings scale linearly with batch size up to the decode sweet spot (BS=128). For serving workloads with continuous batching, the 2.23x speedup at BS=128 directly translates to 2.23x higher throughput per GPU.
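Putting the layer-count and token-count arithmetic together for the 1M tokens/day case, here is a rough single-GPU estimate, following the same per-token framing as above and ignoring attention, sampling, and everything outside the MoE layers:

layers = 24
decode_steps_per_day = 1_000_000      # "1M tokens/day", one MoE pass per generated token
vllm_ms, sglang_ms = 0.604, 0.433     # per-layer MoE latency at BS=128

saved_s = (vllm_ms - sglang_ms) * layers * decode_steps_per_day / 1000
print(f"~{saved_s:.0f} s (~{saved_s / 3600:.1f} GPU-hours) of MoE time saved per day")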
Recommendations for Framework Developers
If you're building an inference framework and targeting Blackwell GPUs, here's what matters:
- Fuse Aggressively
vLLM's 7-kernel approach is defensible for debugging and modularity, but production kernels should fuse:
- Shuffle + reduce (SGLang's apply_shuffle_mul_sum)
- Quantization + GEMM (avoid separate quant kernel)
- Activation + quantization (fuse SiLU with subsequent quant)
Target: Get the MoE forward pass down to 3-4 kernel launches per layer (prepare, GEMM1, GEMM2, reduce).
- Use Hardware-Specific Schedules
Don't rely on generic CUTLASS configs. NVIDIA provides optimized schedules for each architecture:
- Ampere (sm_80): multistage cp.async schedules (e.g., KernelMultistage)
- Hopper (sm_90): KernelTmaWarpSpecialized / KernelTmaWarpSpecializedPingpong
- Blackwell (sm_100): KernelPtrArrayTmaWarpSpecialized1SmNvf4Sm100 for grouped NVFP4 GEMM
Test on target hardware. Hopper-optimized kernels may underperform on Blackwell due to different tensor core configurations.
- Adaptive Launch Heuristics
Fixed grid/block sizes fail at extreme batch sizes. Implement:
- Small batches (1-16): Maximize grid size, minimize block size
- Large batches (512+): Standard CUTLASS heuristics
- Dynamic tuning: Profile on first run, cache configurations
SGLang's while loop is a simple but effective heuristic. More sophisticated approaches (e.g., auto-tuning in the spirit of TVM) can push this further; a minimal sketch of the profile-and-cache idea follows this list.
- Enforce Alignment for TMA
If you're using TMA (and you should on Blackwell), pad your tensors to 128-byte boundaries. The memory overhead is negligible compared to the performance gain.
- Benchmark at Small Batches
Most public benchmarks focus on batch size 128-512. But interactive inference lives at BS=1-16. If your framework targets chatbots, code completion, or agents, optimize for small batches first.
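On the dynamic-tuning point: the simplest useful version is a dictionary keyed by problem shape, timed once per new shape. A minimal sketch, where pick_launch_config, run_kernel, and the candidate list are hypothetical placeholders rather than any framework's API:

import time

_config_cache = {}   # (m, k) -> best (grid, block)

def pick_launch_config(m, k, candidates, run_kernel):
    """Time each candidate (grid, block) once for a new (m, k) shape,
    then reuse the winner on every subsequent call."""
    key = (m, k)
    if key not in _config_cache:
        timings = []
        for grid, block in candidates:
            start = time.perf_counter()
            run_kernel(grid, block)   # caller launches the kernel and synchronizes
            timings.append((time.perf_counter() - start, (grid, block)))
        _config_cache[key] = min(timings)[1]
    return _config_cache[key]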
Conclusion
The 145 TFLOPS gap between SGLang and vLLM on FP4 MoE inference isn't about CUDA magic or secret sauce. It's about systematic kernel engineering:
- Kernel fusion eliminates 21.9% of activation memory traffic
- Blackwell-specific CUTLASS schedules unlock native FP4 and TMA acceleration
- Adaptive grid sizing maximizes SM occupancy at small batches
These optimizations compound across layers, tokens, and requests to deliver:
- 1.32x speedup at batch size 1 (interactive inference)
- 2.23x speedup at batch size 128 (decode sweet spot)
- 3.54x speedup over BF16 baseline at large batches
FlashInfer shows that expert-first layouts can match SGLang at large batches (especially with 256 experts) but struggle at small batches due to preprocessing overhead.
The takeaway: Hardware support for FP4 is necessary but not sufficient. You need kernels that exploit Blackwell's unique features (TMA, warp specialization, native FP4 instructions) to unlock the full potential.
As models scale to 256+ experts and multi-node expert parallelism becomes standard, these optimizations will matter even more. The frameworks that invest in kernel engineering today will define the performance envelope tomorrow.
Appendix: Full Benchmark Data
Hardware
- GPU: NVIDIA Blackwell B200 on Nebius
- Compute Capability: sm_100a
Software
- vLLM: v0.11.0
- SGLang: v0.5.5rc2
- FlashInfer CuteDSL: benchmarked via SGLang's CuteDSL MoE integration
- CUDA: 13.0
Benchmark Methodology
- Warmup: 20 iterations
- Measurement: 200 iterations
- Metric: Mean latency (μs/ms), standard deviation, TFLOPS
- TFLOPS calculation (a worked example follows this list):
flops = batch_size * topk * (
2 * hidden_dim * inter_dim * 2 + # up-projection (gate + up)
2 * inter_dim * hidden_dim # down-projection
)
tflops = flops * 1e-12 / (mean_ms * 1e-3)
- Synchronization: torch.cuda.synchronize() after each iteration
- Memory: GPU cache cleared between configs
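As a sanity check on the formula, plugging in the GPT-OSS-20B configuration at batch size 4096 gives about 2.17 TFLOP per MoE layer, which at SGLang's measured 1262 TFLOPS implies a mean latency of roughly 1.7 ms:

batch_size, topk = 4096, 4
hidden_dim, inter_dim = 2880, 7680

flops = batch_size * topk * (
    2 * hidden_dim * inter_dim * 2 +   # up-projection (gate + up)
    2 * inter_dim * hidden_dim         # down-projection
)
print(f"{flops / 1e12:.2f} TFLOP per layer")            # ~2.17
print(f"{flops / 1e12 / 1262 * 1e3:.2f} ms implied")    # ~1.72 ms at 1262 TFLOPS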
Figure 3: Speedups over BF16.
Figure 4: DeepSeek-like MoE results.
References
My twitter https://x.com/advpropx
SGLang
- GitHub: sgl-project/sglang
- Fused reduction kernel: prepare_moe_input.cu
- Blackwell FP4 GEMM: nvfp4_blockwise_moe.cu
- Expert quantization: nvfp4_expert_quant.cu
- DeepEP integration: deepep.py
- CuteDSL (the CuteDSL MoE path benchmarked here is SGLang's own implementation)
vLLM
- GitHub: vllm-project/vllm
- FP4 MoE layer: cutlass_moe.py
- CUTLASS kernel: nvfp4_blockwise_moe_kernel.cu
FlashInfer
- GitHub: flashinfer-ai/flashinfer
NVIDIA
- CUTLASS: NVIDIA/cutlass
- Blackwell Architecture: NVIDIA Blackwell White Paper
- TMA Documentation: CUDA Programming Guide - Tensor Memory Accelerator