Matrix Multiplication
Perform a matrix multiplication (a @ b.T) with two FP4-quantized tensors provided
in row-major layout.
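Concretely, `a @ b.T` contracts the last dimension of both inputs. If `a` has shape (M, K) and `b` has shape (N, K), the result has shape (M, N). The plain-Python reference below spells out that contraction on nested lists. `matmul_abt` is a hypothetical helper for illustration only and is not part of the library.

```python
def matmul_abt(a, b):
    """Reference for a @ b.T on row-major nested lists.

    a has shape (M, K) and b has shape (N, K). Both are read along their
    last (contiguous) dimension, and the result has shape (M, N).
    """
    k = len(a[0])
    if any(len(row) != k for row in a) or any(len(row) != k for row in b):
        raise ValueError("a and b must share the same inner dimension K")
    # out[i][j] is the dot product of row i of a with row j of b
    return [[sum(x * y for x, y in zip(a_row, b_row)) for b_row in b] for a_row in a]


a = [[1.0, 2.0, 3.0],
     [4.0, 5.0, 6.0]]   # shape (2, 3)
b = [[1.0, 0.0, 1.0],
     [0.0, 1.0, 0.0],
     [2.0, 2.0, 2.0],
     [1.0, 1.0, 1.0]]   # shape (4, 3)
out = matmul_abt(a, b)  # shape (2, 4)
```

This is the same convention as a linear layer, where `b` plays the role of an (out_features, in_features) weight matrix.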
Sample Code
Each tensor may be provided in either high or low precision. If a tensor is provided in high precision, it is quantized to FP4 before the matrix multiplication. This quantization can be configured with the `a_quantize_kwargs` and `b_quantize_kwargs` parameters. For example, the following two code samples are equivalent:
With High-Precision Inputs
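```python
import torch
from fouroversix.frontend import fp4_matmul

a = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
b = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
out = fp4_matmul(a, b)  # a and b are quantized to FP4 internally
```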
With Low-Precision Inputs
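```python
import torch
from fouroversix.frontend import fp4_matmul, quantize_to_fp4  # quantize_to_fp4 location assumed

a = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
b = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")

# Quantize each tensor: packed E2M1 values, scale factors, normalization constant
a_e2m1, a_sf, a_normconst = quantize_to_fp4(a)
b_e2m1, b_sf, b_normconst = quantize_to_fp4(b)

out = fp4_matmul(
    a_e2m1=a_e2m1,
    a_sf=a_sf,
    a_normconst=a_normconst,
    b_e2m1=b_e2m1,
    b_sf=b_sf,
    b_normconst=b_normconst,
)
```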
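The following simplified, pure-Python sketch of NVFP4-style blockwise quantization gives some intuition for the values returned by `quantize_to_fp4`. It makes several assumptions:

- 16-element blocks.
- A per-block scale of amax / 6, since 6 is the largest E2M1 magnitude.
- Round-to-nearest onto the E2M1 grid, with ties going to the smaller magnitude rather than to even.
- Plain float scales instead of FP8.

It does not model the per-tensor normalization constant or byte packing. This illustrates the general technique, not the library's actual quantization algorithm, and `round_to_e2m1`, `quantize_blockwise` and `dequantize_blockwise` are hypothetical names.

```python
# Magnitudes representable in E2M1 (1 sign bit, 2 exponent bits, 1 mantissa bit).
E2M1_GRID = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
BLOCK_SIZE = 16  # assumed number of values sharing one scale


def round_to_e2m1(x):
    """Round x to the nearest E2M1 value; ties go to the smaller magnitude (a simplification)."""
    magnitude = min(E2M1_GRID, key=lambda g: abs(abs(x) - g))
    return -magnitude if x < 0 else magnitude


def quantize_blockwise(values):
    """Return (codes, scales): E2M1 values plus one float scale per block."""
    codes, scales = [], []
    for start in range(0, len(values), BLOCK_SIZE):
        block = values[start:start + BLOCK_SIZE]
        amax = max(abs(v) for v in block)
        scale = amax / 6.0 if amax > 0 else 1.0  # map the block's largest magnitude onto 6
        scales.append(scale)
        codes.extend(round_to_e2m1(v / scale) for v in block)
    return codes, scales


def dequantize_blockwise(codes, scales):
    """Approximately invert quantize_blockwise (rounding error is not recoverable)."""
    return [c * scales[i // BLOCK_SIZE] for i, c in enumerate(codes)]


values = [0.1 * i for i in range(-8, 8)] + [3.0] * 16  # two blocks of 16
codes, scales = quantize_blockwise(values)
approx = dequantize_blockwise(codes, scales)
```

The largest gap on the E2M1 grid is 2 (between 4 and 6). As a result, each reconstructed value in this sketch is within one block scale of the original.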
Backends
We provide two different implementations of FP4 matrix multiplication:
- CUTLASS: Uses CUTLASS kernels to perform fast FP4 matrix multiplication. Requires a Blackwell GPU.
- PyTorch: A slow implementation that dequantizes the FP4 tensors and then performs a high-precision matrix multiplication.
Note that our CUTLASS kernels accumulate in FP32, so their results should closely match simulations run with the PyTorch backend.
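The PyTorch backend's approach of dequantizing and then multiplying in high precision can be illustrated in plain Python. This sketch rests on several assumptions:

- Each byte holds two E2M1 codes, with the first value in the low nibble.
- Each row has one float scale factor per 16 consecutive values along K.
- The normalization constant multiplies the dequantized values.
- Accumulation uses Python floats (double precision).

The library's real packing order, scale-factor dtype and layout, and normalization convention may differ. `decode_e2m1`, `unpack_e2m1`, `dequantize` and `reference_fp4_matmul` are hypothetical helpers.

```python
# E2M1 code -> value. Bit 3 is the sign; bits 0-2 index the magnitude.
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]
BLOCK_SIZE = 16  # assumed number of values sharing one scale factor


def decode_e2m1(code):
    magnitude = E2M1_MAGNITUDES[code & 0b0111]
    return -magnitude if code & 0b1000 else magnitude


def unpack_e2m1(packed_row):
    """Unpack bytes holding two 4-bit codes each (low nibble first, assumed)."""
    values = []
    for byte in packed_row:
        values.append(decode_e2m1(byte & 0x0F))
        values.append(decode_e2m1(byte >> 4))
    return values


def dequantize(packed, scales, normconst):
    """Rebuild a matrix from packed codes, per-block scales and a per-tensor constant."""
    out = []
    for packed_row, row_scales in zip(packed, scales):
        codes = unpack_e2m1(packed_row)
        out.append([v * row_scales[k // BLOCK_SIZE] * normconst for k, v in enumerate(codes)])
    return out


def reference_fp4_matmul(a_packed, a_sf, a_norm, b_packed, b_sf, b_norm):
    """Dequantize both operands, then compute a @ b.T in Python floats."""
    a = dequantize(a_packed, a_sf, a_norm)
    b = dequantize(b_packed, b_sf, b_norm)
    return [[sum(x * y for x, y in zip(a_row, b_row)) for b_row in b] for a_row in a]


# One row of 16 values (8 bytes) per operand; code 0x2 decodes to 1.0, so each byte is 0x22.
a_packed = [[0x22] * 8]
b_packed = [[0x22] * 8]
out = reference_fp4_matmul(a_packed, [[0.5]], 2.0, b_packed, [[0.25]], 1.0)
```

Here every dequantized value of `a` is 1.0 and every value of `b` is 0.25. The single output element is therefore 16 × 0.25.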
Parameters
| Name | Type | Description | Default |
|---|---|---|---|
| `a` | `Tensor` | The high-precision input tensor A. | `None` |
| `b` | `Tensor` | The high-precision input tensor B. | `None` |
| `backend` | `MatmulBackend` | The backend to use for the matrix multiplication, either CUTLASS or PyTorch (see Backends). | `None` |
| `a_e2m1` | `Tensor` | The values of the first input tensor in packed E2M1 format (2 values per byte). | `None` |
| `a_sf` | `Tensor` | The scale factors of the first input tensor. | `None` |
| `a_normconst` | `Tensor` | The per-tensor normalization constant of the first input tensor. | `None` |
| `b_e2m1` | `Tensor` | The values of the second input tensor in packed E2M1 format (2 values per byte). | `None` |
| `b_sf` | `Tensor` | The scale factors of the second input tensor. | `None` |
| `b_normconst` | `Tensor` | The per-tensor normalization constant of the second input tensor. | `None` |
| `a_quantize_kwargs` | `dict` | If `a` is provided in high precision, keyword arguments that configure its quantization to FP4. | `None` |
| `b_quantize_kwargs` | `dict` | If `b` is provided in high precision, keyword arguments that configure its quantization to FP4. | `None` |
| `fp4_format` | `FP4Format` | The FP4 format of the input tensors, either … | `nvfp4` |
| `out_dtype` | `DataType` | The data type of the output tensor, either … | `bfloat16` |
| `out_shape` | `tuple[int, int] \| None` | The shape of the output tensor. This is useful when the input tensors have shapes that are not multiples of 64 but were padded to multiples of 64 during quantization. | `None` |
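The sketch below shows why an explicit output shape is needed when inputs are padded. It is a hypothetical illustration, not library code, and it assumes both dimensions of each operand are zero-padded up to multiples of 64. Zero padding along K adds nothing to any dot product, so the padded product differs from the unpadded one only by extra rows and columns, which cropping to `out_shape` removes.

```python
PAD_MULTIPLE = 64  # assumed padding granularity, matching the out_shape description


def round_up(n, multiple=PAD_MULTIPLE):
    """Smallest multiple of `multiple` that is >= n."""
    return -(-n // multiple) * multiple


def pad_matrix(m, rows, cols):
    """Zero-pad a nested-list matrix to shape (rows, cols)."""
    padded = [row + [0.0] * (cols - len(row)) for row in m]
    padded += [[0.0] * cols for _ in range(rows - len(m))]
    return padded


def matmul_abt(a, b):
    """a @ b.T for nested lists: a is (M, K), b is (N, K)."""
    return [[sum(x * y for x, y in zip(a_row, b_row)) for b_row in b] for a_row in a]


def padded_matmul(a, b, out_shape):
    """Compute a @ b.T on padded operands, then crop the result to out_shape = (M, N)."""
    k = round_up(len(a[0]))
    a_pad = pad_matrix(a, round_up(len(a)), k)
    b_pad = pad_matrix(b, round_up(len(b)), k)
    full = matmul_abt(a_pad, b_pad)  # shape (round_up(M), round_up(N))
    m, n = out_shape
    return [row[:n] for row in full[:m]]


a = [[float(i + j) for j in range(5)] for i in range(3)]  # shape (3, 5)
b = [[float(i * j) for j in range(5)] for i in range(2)]  # shape (2, 5)
result = padded_matmul(a, b, out_shape=(3, 2))
```

Without the crop, `full` would be a 64 × 64 matrix whose extra rows and columns are all zeros.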
Returns:

| Type | Description |
|---|---|
| `Tensor` | The output tensor. |
Source code in src/fouroversix/frontend.py