Quantization
Quantize a tensor to FP4.
Sample Code
With Four Over Six
```python
import torch
from fouroversix import quantize_to_fp4

a = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
a_e2m1, a_sf, a_normconst = quantize_to_fp4(a)
```
Without Four Over Six
```python
import torch
from fouroversix import AdaptiveBlockScalingRule, quantize_to_fp4

a = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
a_e2m1, a_sf, a_normconst = quantize_to_fp4(
    a,
    scale_rule=AdaptiveBlockScalingRule.always_6,
)
```
With Stochastic Rounding
```python
import torch
from fouroversix import RoundStyle, quantize_to_fp4

a = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
a_e2m1, a_sf, a_normconst = quantize_to_fp4(a, round_style=RoundStyle.stochastic)
```
With the Random Hadamard Transform
```python
import torch
from scipy.linalg import hadamard
from fouroversix import quantize_to_fp4

a = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
had = torch.tensor(hadamard(16), dtype=torch.bfloat16, device="cuda")
a_e2m1, a_sf, a_normconst = quantize_to_fp4(a, had=had)
```
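These options can also be combined in a single call, since the training-oriented operations (2D block scaling, stochastic rounding, the Hadamard transform) are all supported by the Triton and PyTorch backends. The following is a hedged sketch, not an official sample; the top-level `fouroversix` import path is assumed from the source location, and only parameters documented in the table below are used:

```python
import torch
from scipy.linalg import hadamard

# Import path assumed; adjust to wherever quantize_to_fp4 is exposed.
from fouroversix import RoundStyle, quantize_to_fp4

w = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
had = torch.tensor(hadamard(16), dtype=torch.bfloat16, device="cuda")

w_e2m1, w_sf, w_normconst = quantize_to_fp4(
    w,
    block_scale_2d=True,                # 16x16 scale blocks, so W and W.T quantize consistently
    round_style=RoundStyle.stochastic,  # stochastic rounding instead of round-to-nearest
    had=had,                            # apply the Hadamard transform before quantization
)
```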
Backends
We provide three different implementations of FP4 quantization:
- CUDA: A fast implementation written in CUDA which currently does not support the operations required for training (2D block scaling, stochastic rounding, random Hadamard transform). Requires a Blackwell GPU.
- Triton: A slightly slower implementation written in Triton which supports all operations needed for training. Requires a Blackwell GPU.
- PyTorch: A slow implementation written in PyTorch which supports all operations and can be run on any GPU.
If `quantize_to_fp4` is called with `backend=None`, a backend will be selected
automatically based on the following rules (sketched in code after the list):
- If there is no GPU available, or if the available GPU is not a Blackwell GPU, select PyTorch.
- If any quantization options are set other than `scale_rule`, select Triton.
    - However, if the available GPU is SM120 (i.e. RTX 5090, RTX 6000) and `round_style` is set to `RoundStyle.stochastic`, select PyTorch, as stochastic rounding does not have hardware support on SM120 GPUs.
- Otherwise, select CUDA.
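The same rules, written out as an illustrative sketch. This is not the library's implementation; the string return values stand in for the actual `QuantizeBackend` members:

```python
# Illustrative only: mirrors the bullet points above.
def pick_backend(is_blackwell: bool, is_sm120: bool, **options) -> str:
    if not is_blackwell:
        return "pytorch"  # no GPU, or not a Blackwell GPU
    # Any option other than scale_rule rules out the CUDA backend.
    if any(name != "scale_rule" for name in options):
        if is_sm120 and options.get("round_style") == "stochastic":
            return "pytorch"  # SM120 has no hardware stochastic rounding
        return "triton"
    return "cuda"
```

For example, `pick_backend(is_blackwell=True, is_sm120=True, round_style="stochastic")` falls back to PyTorch, while the same call with `is_sm120=False` selects Triton.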
Parameters
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Tensor` | The input tensor to quantize. | required |
| `backend` | `QuantizeBackend` | The backend to use for quantization. If `None`, a backend is selected automatically as described above. | `None` |
| `scale_rule` | `AdaptiveBlockScalingRule` | The scaling rule to use during quantization. See [Adaptive Block Scaling](/adaptive_block_scaling) for more details. | `mse` |
| `block_scale_2d` | `bool` | If `True`, scale factors will be computed across 16x16 chunks of the input rather than 1x16 chunks. This is useful to apply to the weight matrix during training, so that `W` and `W.T` will be equivalent after quantization. | `False` |
| `had` | `Tensor` | A high-precision Hadamard matrix to apply to the input prior to quantization. | `None` |
| `fp4_format` | `FP4Format` | The FP4 format to quantize to. | `nvfp4` |
| `round_style` | `RoundStyle` | The rounding style to apply during quantization, either `RoundStyle.nearest` or `RoundStyle.stochastic`. | `nearest` |
| `transpose` | `bool` | If `True`, the output will be a quantized version of the transposed input. This may be helpful for certain operations during training. | `False` |
Returns:

| Type | Description |
|---|---|
| `Tensor` | The packed E2M1 values. |
| `Tensor` | The FP8 scale factors. |
| `Tensor \| None` | The tensor-wide FP32 scale factor. |
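As an end-to-end sketch tying the parameters and return values together (the import path is assumed, and the output shapes and dtypes are simply whatever the library returns, not asserted here):

```python
import torch
from fouroversix import quantize_to_fp4  # import path assumed

a = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")

# transpose=True quantizes the transposed input rather than the input itself.
a_t_e2m1, a_t_sf, a_t_normconst = quantize_to_fp4(a, transpose=True)

print(a_t_e2m1.shape, a_t_e2m1.dtype)  # packed E2M1 values
print(a_t_sf.shape, a_t_sf.dtype)      # FP8 scale factors
print(a_t_normconst)                   # tensor-wide FP32 scale factor, or None
```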
Source code in src/fouroversix/frontend.py