
Quantization

Quantize a tensor to FP4.

Sample Code

With Four Over Six

a = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
a_e2m1, a_sf, a_normconst = quantize_to_fp4(a)

Without Four Over Six

a = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
a_e2m1, a_sf, a_normconst = quantize_to_fp4(
    a,
    scale_rule=AdaptiveBlockScalingRule.always_6,
)

With Stochastic Rounding

a = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
a_e2m1, a_sf, a_normconst = quantize_to_fp4(a, round_style=RoundStyle.stochastic)

With the Random Hadamard Transform

from scipy.linalg import hadamard

a = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
had = torch.tensor(hadamard(16), dtype=torch.bfloat16, device="cuda")
a_e2m1, a_sf, a_normconst = quantize_to_fp4(a, had=had)
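
Combining Options for Training

The options above can be combined in a single call. The following sketch is not taken from the library's documentation; it simply combines parameters from the signature shown further down, quantizing a weight matrix with 2D block scaling and a second tensor with stochastic rounding and a Hadamard transform:

from scipy.linalg import hadamard

w = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
g = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
had = torch.tensor(hadamard(16), dtype=torch.bfloat16, device="cuda")

# 16x16 block scales so that W and W.T quantize equivalently.
w_e2m1, w_sf, w_normconst = quantize_to_fp4(w, block_scale_2d=True)

# Stochastic rounding plus a Hadamard transform, e.g. for gradient tensors.
g_e2m1, g_sf, g_normconst = quantize_to_fp4(
    g,
    round_style=RoundStyle.stochastic,
    had=had,
)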

Backends

We provide three different implementations of FP4 quantization:

  • CUDA: A fast implementation written in CUDA which currently does not support the operations required for training (2D block scaling, stochastic rounding, random Hadamard transform). Requires a Blackwell GPU.
  • Triton: A slightly slower implementation written in Triton which supports all operations needed for training. Requires a Blackwell GPU.
  • PyTorch: A slow implementation written in PyTorch which supports all operations and can be run on any GPU.

If quantize_to_fp4 is called with backend=None, a backend will be selected automatically based on the following rules:

  • If there is no GPU available, or if the available GPU is not a Blackwell GPU, select PyTorch.
  • If any quantization options are set other than scale_rule, select Triton.
    • However, if the available GPU is SM120 (e.g., RTX 5090, RTX 6000) and round_style is set to RoundStyle.stochastic, select PyTorch, as stochastic rounding does not have hardware support on SM120 GPUs.
  • Otherwise, select CUDA.
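
Automatic selection can also be bypassed by passing backend explicitly. The sketch below forces the PyTorch backend, which supports every option on any GPU; the import path is an assumption and may differ in your install:

# Hypothetical import path; the names come from the signature below.
from fouroversix import QuantizeBackend, quantize_to_fp4

a = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
a_e2m1, a_sf, a_normconst = quantize_to_fp4(a, backend=QuantizeBackend.pytorch)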

Parameters

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| x | Tensor | The input tensor to quantize. | required |
| backend | QuantizeBackend | The backend to use for quantization, either QuantizeBackend.cuda, QuantizeBackend.triton, or QuantizeBackend.pytorch. If no backend is provided, one will be selected automatically based on the available GPU and the options provided. See above for more details. | None |
| scale_rule | AdaptiveBlockScalingRule | The scaling rule to use during quantization. See [Adaptive Block Scaling](/adaptive_block_scaling) for more details. | mse |
| block_scale_2d | bool | If True, scale factors will be computed across 16x16 chunks of the input rather than 1x16 chunks. This is useful to apply to the weight matrix during training, so that W and W.T will be equivalent after quantization. | False |
| had | Tensor | A high-precision Hadamard matrix to apply to the input prior to quantization. | None |
| fp4_format | FP4Format | The FP4 format to quantize to, either FP4Format.mxfp4 or FP4Format.nvfp4. | nvfp4 |
| round_style | RoundStyle | The rounding style to apply during quantization, either RoundStyle.nearest for round-to-nearest quantization, or RoundStyle.stochastic for stochastic rounding. | nearest |
| transpose | bool | If True, the output will be a quantized version of the transposed input. This may be helpful for certain operations during training, as fp4_matmul requires that both tensors are provided in row-major format. | False |
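
As a concrete illustration of block_scale_2d and transpose together (a sketch based on the descriptions above, not an official example):

w = torch.randn(4096, 1024, dtype=torch.bfloat16, device="cuda")

# 16x16 block scales keep W and W.T equivalent after quantization, and
# transpose=True returns the quantized transpose directly in the row-major
# layout that fp4_matmul expects.
wt_e2m1, wt_sf, wt_normconst = quantize_to_fp4(
    w,
    block_scale_2d=True,
    transpose=True,
)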

Returns:

| Type | Description |
| --- | --- |
| Tensor | The packed E2M1 values. |
| Tensor | The FP8 scale factors. |
| Tensor or None | The tensor-wide FP32 scale factor. |
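
A minimal sketch of handling the return values; the tensor-wide scale factor is typed Tensor | None, so guard before using it:

a = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
a_e2m1, a_sf, a_normconst = quantize_to_fp4(a)

# The packed E2M1 values and FP8 scale factors are always returned; the
# tensor-wide FP32 scale factor may be absent.
if a_normconst is not None:
    print(a_normconst.item())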

Source code in src/fouroversix/frontend.py
def quantize_to_fp4(
    x: torch.Tensor,
    *,
    backend: QuantizeBackend | None = None,
    scale_rule: AdaptiveBlockScalingRule = AdaptiveBlockScalingRule.mse,
    block_scale_2d: bool = False,
    had: torch.Tensor | None = None,
    fp4_format: FP4Format = FP4Format.nvfp4,
    round_style: RoundStyle = RoundStyle.nearest,
    transpose: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
    """
    Quantize a tensor to FP4.

    ## Sample Code

    ### With Four Over Six

    ```
    a = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
    a_e2m1, a_sf, a_normconst = quantize_to_fp4(a)
    ```

    ### Without Four Over Six

    ```
    a = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
    a_e2m1, a_sf, a_normconst = quantize_to_fp4(
        a,
        scale_rule=AdaptiveBlockScalingRule.always_6,
    )
    ```

    ### With Stochastic Rounding

    ```
    a = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
    a_e2m1, a_sf, a_normconst = quantize_to_fp4(a, round_style=RoundStyle.stochastic)
    ```

    ### With the Random Hadamard Transform

    ```
    from scipy.linalg import hadamard

    a = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
    had = torch.tensor(hadamard(16), dtype=torch.bfloat16, device="cuda")
    a_e2m1, a_sf, a_normconst = quantize_to_fp4(a, had=had)
    ```

    ## Backends

    We provide three different implementations of FP4 quantization:

    - **CUDA**: A fast implementation written in CUDA which currently does not support
        the operations required for training (2D block scaling, stochastic rounding,
        random Hadamard transform). Requires a Blackwell GPU.
    - **Triton**: A slightly slower implementation written in Triton which supports all
        operations needed for training. Requires a Blackwell GPU.
    - **PyTorch**: A slow implementation written in PyTorch which supports all
        operations and can be run on any GPU.

    If `quantize_to_fp4` is called with `backend=None`, a backend will be selected
    automatically based on the following rules:

    - If there is no GPU available, or if the available GPU is not a Blackwell GPU,
        select PyTorch.
    - If any quantization options are set other than `scale_rule`, select Triton.
        - However, if the available GPU is SM120 (e.g., RTX 5090, RTX 6000) and
            `round_style` is set to `RoundStyle.stochastic`, select PyTorch, as
            stochastic rounding does not have hardware support on SM120 GPUs.
    - Otherwise, select CUDA.

    ## Parameters

    Args:
        x (torch.Tensor): The input tensor to quantize.
        backend (QuantizeBackend): The backend to use for quantization, either
            `QuantizeBackend.cuda`, `QuantizeBackend.triton`, or
            `QuantizeBackend.pytorch`. If no backend is provided, one will be selected
            automatically based on the available GPU and the options provided. See above
            for more details.
        scale_rule (AdaptiveBlockScalingRule): The scaling rule to use during
            quantization. See [Adaptive Block Scaling](/adaptive_block_scaling) for more
            details.
        block_scale_2d (bool): If True, scale factors will be computed across 16x16
            chunks of the input rather than 1x16 chunks. This is useful to apply to the
            weight matrix during training, so that W and W.T will be equivalent after
            quantization.
        had (torch.Tensor): A high-precision Hadamard matrix to apply to the input prior
            to quantization.
        fp4_format (FP4Format): The FP4 format to quantize to, either `FP4Format.mxfp4`
            or `FP4Format.nvfp4`.
        round_style (RoundStyle): The rounding style to apply during quantization,
            either `RoundStyle.nearest` for round-to-nearest quantization, or
            `RoundStyle.stochastic` for stochastic rounding.
        transpose (bool): If True, the output will be a quantized version of the
            transposed input. This may be helpful for certain operations during training
            as `fp4_matmul` requires that both tensors are provided in row-major format.

    Returns:
        The packed E2M1 values.
        The FP8 scale factors.
        The tensor-wide FP32 scale factor.

    """

    kwargs = {
        "scale_rule": scale_rule,
        "block_scale_2d": block_scale_2d,
        "had": had,
        "fp4_format": fp4_format,
        "round_style": round_style,
        "transpose": transpose,
    }

    if backend is None:
        backend = QuantizeBackend.auto_select(x, **kwargs)
    elif not backend.is_supported(x, **kwargs):
        msg = f"Backend {backend} does not support the given parameters"
        raise ValueError(msg)

    return backend.quantize_to_fp4(x, **kwargs)
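
Because an explicitly requested backend is validated with backend.is_supported, forcing an option that the backend cannot handle raises a ValueError. A sketch of this, assuming (per the backend notes above) that the CUDA backend does not support stochastic rounding:

a = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")

try:
    quantize_to_fp4(
        a,
        backend=QuantizeBackend.cuda,
        round_style=RoundStyle.stochastic,  # not supported by the CUDA backend
    )
except ValueError as err:
    # The message states that the chosen backend does not support the given parameters.
    print(err)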