MoE + FP8¶
FP8 training (via torchao.float8) converts dense nn.Linear modules
to Float8Linear. Three classes of Linear in an MoE model get
excluded from that conversion:
- Routed experts (`experts.*`)
- Shared expert (`shared_expert.*`)
- Router gate (`router.gate`)
Everything else — attention projections, output head, dense-layer MLPs — converts normally.
Filter rule¶
From `apply_float8`:

# kempnerforge/distributed/parallel.py — apply_float8._filter_fn
def _filter_fn(module, fqn):
    if "experts" in fqn or "shared_expert" in fqn:
        return False
    return "router" not in fqn
convert_to_float8_training walks the model, asks the filter for
each Linear, and only replaces the module when the filter returns
True.
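A minimal sketch of that wiring, assuming only that `apply_float8` passes the filter to torchao's converter (the `module_filter_fn` argument is torchao's own API; the toy model is illustrative):

```python
import torch.nn as nn
from torchao.float8 import convert_to_float8_training

def _filter_fn(module: nn.Module, fqn: str) -> bool:
    # Same predicate as above: skip routed experts, shared expert, and router.
    if "experts" in fqn or "shared_expert" in fqn:
        return False
    return "router" not in fqn

# Toy stand-in for the real model; eligible nn.Linear modules are swapped in place.
model = nn.Sequential(nn.Linear(256, 256), nn.Linear(256, 256))
convert_to_float8_training(model, module_filter_fn=_filter_fn)
```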
Why experts are excluded¶
The grouped GEMM path uses torch._grouped_mm, which is a distinct
kernel from torch._scaled_mm (what Float8Linear.forward calls
into). Even if an expert’s nn.Linear were wrapped in Float8Linear,
the wrap would be bypassed:
- The forward path stacks per-expert weights into a single `(E, dim, hidden)` tensor (or reads them from pre-packed parameters) and calls `torch._grouped_mm` directly.
- `Float8Linear.forward` is never called on the expert weights.
The result would be FP8 parameter storage but bf16 compute — the worst of both worlds (lost precision at storage, no matmul speedup).
See Capacity and dispatch § Path A for the grouped GEMM call site.
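A minimal sketch of that bypass, with `torch.bmm` standing in for `torch._grouped_mm` and purely illustrative shapes (the real call site is in Capacity and dispatch § Path A):

```python
import torch
import torch.nn as nn

E, dim, hidden, tokens_per_expert = 4, 64, 128, 8
experts = nn.ModuleList(nn.Linear(dim, hidden, bias=False) for _ in range(E))

# The expert forward reads the raw weight tensors and stacks them, so even if
# each expert Linear had been swapped for Float8Linear, its forward() never runs.
stacked_w = torch.stack([e.weight.t() for e in experts])  # (E, dim, hidden)
x = torch.randn(E, tokens_per_expert, dim)                # per-expert token groups

# Compute goes straight to a batched/grouped kernel in the weights' native
# precision (bf16 in training), never through torch._scaled_mm.
out = torch.bmm(x, stacked_w)                             # (E, tokens, hidden)
```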
Why the router is excluded¶
Two reasons:
1. Tiny output dim. The router gate is `Linear(dim, num_experts)`, and `num_experts` is typically 8-64, often not a multiple of 16, which `torch._scaled_mm` requires for its fast path. The fallback paths give no speedup.
2. Not compute-bound. Routing is essentially a decision: the gate matmul is a tiny fraction of total FLOPs, and FP8 quantization error on the routing decision could measurably perturb which experts get picked, which is a stability risk.
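For a rough sense of scale, under illustrative shapes (assumed, not taken from this repo):

```python
dim, hidden, num_experts, top_k = 4096, 14336, 8, 2   # assumed shapes

gate_flops = 2 * dim * num_experts               # router gate matmul, per token
expert_flops = top_k * 3 * 2 * dim * hidden      # routed gate/up/down projections, per token

print(num_experts % 16 == 0)                     # False: fails the 16-alignment fast path
print(gate_flops / (gate_flops + expert_flops))  # ~1e-4 of the layer's MLP compute
```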
What stays FP8¶
Everything not caught by the three filter rules:
| Module | Converted | FQN pattern |
|---|---|---|
| Attention Q/K/V/O | yes | |
| Dense MLP gate/up/down | yes | |
| Output head | yes | |
| Router gate | no | `router.gate` |
| Routed experts | no | `experts.*` |
| Shared expert | no | `shared_expert.*` |
| Embeddings | no | not `nn.Linear` |
| RMSNorm | no | not `nn.Linear` |
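To inspect the split on a converted model, walking the module tree and testing for torchao's `Float8Linear` class works; a sketch (the helper name is ours):

```python
import torch.nn as nn
from torchao.float8 import Float8Linear

def report_fp8(model: nn.Module) -> None:
    # Float8Linear subclasses nn.Linear, so test for the subclass first.
    for fqn, mod in model.named_modules():
        if isinstance(mod, Float8Linear):
            print(f"{fqn}: fp8")
        elif isinstance(mod, nn.Linear):
            print(f"{fqn}: bf16")
```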
With moe_frequency = 2 (alternating dense / MoE layers), the dense
layers’ MLPs get FP8 and the MoE layers’ MLPs don’t — which is fine,
because the MoE layers’ compute is dominated by grouped GEMM on the
expert weights anyway.
Memory and throughput¶
Excluding expert Linears means FP8 gives smaller gains for MoE than for dense training. A 4B-active MoE where half the FLOPs are in grouped-GEMM experts sees FP8 speedup only on the remaining half (attention + shared expert + output head + dense-layer MLPs).
For the 7B dense Llama reference (see Benchmarks § MFU scaling), FP8 provides a measurable throughput lift at 16 GPUs. For MoE runs, the same config would see a smaller lift because the expert portion of compute runs at bf16 regardless.
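A back-of-the-envelope Amdahl's-law estimate makes the gap concrete; the numbers below are illustrative, not benchmark results:

```python
# f: fraction of matmul FLOPs that stays bf16 (the experts); s: FP8 speedup on the rest.
f, s = 0.5, 1.3
print(1 / (f + (1 - f) / s))  # ~1.13x end to end, vs. ~1.3x for a fully dense model
```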
Config¶
One switch:
[train]
mixed_precision = "fp8" # flips the conversion on
No separate MoE knob. The expert/router exclusions are hardcoded in
_filter_fn. If you want to experiment with FP8 on a specific
expert path — for example, trying FP8 on the shared expert — you’d
edit apply_float8 directly.
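As a concrete (unsupported) sketch, letting the shared expert convert while keeping the other exclusions is a one-line change to the filter:

```python
def _filter_fn(module, fqn):
    # Routed experts stay excluded: their compute goes through grouped GEMM.
    if "experts" in fqn:
        return False
    # Router stays excluded; "shared_expert" is no longer filtered, so it converts.
    return "router" not in fqn
```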
FP8 + EP + TP: what actually composes¶
Three separate constraints:
- FP8 + TP: Not supported. `JobConfig.__post_init__` raises when `train.mixed_precision = "fp8"` and `distributed.tp > 1`. Reason: `Float8Linear`'s weight wrapper calls `aten.is_pinned` on DTensor, which has no sharding strategy yet. See FP8 § TP incompatibility.
- FP8 + EP: Fine. Experts are excluded from FP8, so EP, which operates on experts, is orthogonal to the conversion.
- FP8 + MoE without EP: Fine. Non-expert Linears convert, expert Linears stay bf16, and grouped GEMM continues to work.
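The FP8 + TP guard reduces to a simple check; a sketch of the rule (the real check lives in `JobConfig.__post_init__`):

```python
def check_fp8_tp(mixed_precision: str, tp: int) -> None:
    # Float8Linear's weight wrapper calls aten.is_pinned on DTensor, which has no
    # sharding strategy yet, so FP8 cannot be combined with tensor parallelism.
    if mixed_precision == "fp8" and tp > 1:
        raise ValueError("train.mixed_precision = 'fp8' requires distributed.tp == 1")
```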
See also¶
- FP8 — the canonical FP8 reference (full conversion details, `enable_fsdp_float8_all_gather`, master weights, hardware requirements).
- Expert parallelism § EP + FP8 — EP compatibility notes with FP8.
- Capacity and dispatch § Path A — the grouped-GEMM path that bypasses `Float8Linear`.
- Parallelism order § Float8 before AC and FSDP — where FP8 conversion sits in the apply sequence.
- Validation rules — the FP8 + TP config check.