Matrix Multiplication
Perform a matrix multiplication (a @ b.T) with two FP4-quantized tensors provided
in row-major layout.
Sample Code
Each tensor may be provided in either high or low precision. Tensors provided in high
precision are quantized to FP4 before the matrix multiplication, and the quantization
may be configured with the a_quantize_kwargs and b_quantize_kwargs parameters. For
example, the following two code samples are equivalent:
With High-Precision Inputs
import torch
from fouroversix import fp4_matmul

a = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
b = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
out = fp4_matmul(a, b)
With Low-Precision Inputs
import torch
from fouroversix import fp4_matmul, quantize_to_fp4

a = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
b = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
a_e2m1, a_sf, a_normconst = quantize_to_fp4(a)
b_e2m1, b_sf, b_normconst = quantize_to_fp4(b)
out = fp4_matmul(
    a_e2m1=a_e2m1,
    a_sf=a_sf,
    a_normconst=a_normconst,
    b_e2m1=b_e2m1,
    b_sf=b_sf,
    b_normconst=b_normconst,
)
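When the inputs are high precision, the a_quantize_kwargs and b_quantize_kwargs
dictionaries are forwarded to the quantization step. A minimal sketch of the
mechanism; the "rounding" key below is hypothetical, and the accepted options are
whatever quantize_to_fp4 actually takes:

# Hypothetical key: check quantize_to_fp4 for the options it actually accepts.
out = fp4_matmul(
    a,
    b,
    a_quantize_kwargs={"rounding": "nearest"},
    b_quantize_kwargs={"rounding": "nearest"},
)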
Backends
We provide two different implementations of FP4 matrix multiplication:
- CUTLASS: Uses CUTLASS kernels to perform fast FP4 matrix multiplication. Requires a Blackwell GPU.
- PyTorch: A slow reference implementation that dequantizes the FP4 tensors and then performs a high-precision matrix multiplication.
Note that our CUTLASS kernels accumulate in FP32, so their results should be roughly equivalent to simulations done with the PyTorch backend.
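A backend can be selected explicitly via the backend parameter. The sketch below
assumes MatmulBackend exposes CUTLASS and PYTORCH members; the member names are
inferred from the list above, not confirmed:

import torch
# Member names CUTLASS / PYTORCH are assumptions based on the backend list above;
# a and b are the bfloat16 tensors from the samples in the previous section.
from fouroversix import MatmulBackend, fp4_matmul

out_fast = fp4_matmul(a, b, backend=MatmulBackend.CUTLASS)  # requires a Blackwell GPU
out_ref = fp4_matmul(a, b, backend=MatmulBackend.PYTORCH)   # slow, runs anywhere
# Both paths accumulate in FP32, so the results should closely agree
# (loose tolerances: bit-exact agreement is not guaranteed).
torch.testing.assert_close(out_fast, out_ref, rtol=1e-2, atol=1e-2)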
Parameters

| Name | Type | Description | Default |
|---|---|---|---|
| `a` | `Tensor` | The high-precision input tensor A. | `None` |
| `b` | `Tensor` | The high-precision input tensor B. | `None` |
| `backend` | `MatmulBackend` | The backend to use for the matrix multiplication, either CUTLASS or PyTorch (see Backends above). | `None` |
| `a_e2m1` | `Tensor` | The values of the first input tensor in packed E2M1 format (two values per byte). | `None` |
| `a_sf` | `Tensor` | The scale factors of the first input tensor. | `None` |
| `a_normconst` | `Tensor` | The per-tensor normalization constant of the first input tensor. | `None` |
| `b_e2m1` | `Tensor` | The values of the second input tensor in packed E2M1 format (two values per byte). | `None` |
| `b_sf` | `Tensor` | The scale factors of the second input tensor. | `None` |
| `b_normconst` | `Tensor` | The per-tensor normalization constant of the second input tensor. | `None` |
| `a_quantize_kwargs` | `dict` | If `a` is provided in high precision, keyword arguments passed to its quantization. | `None` |
| `b_quantize_kwargs` | `dict` | If `b` is provided in high precision, keyword arguments passed to its quantization. | `None` |
| `fp4_format` | `FP4Format` | The FP4 format of the input tensors. | `nvfp4` |
| `out_dtype` | `DataType` | The data type of the output tensor. | `bfloat16` |
| `out_shape` | `tuple[int, int] \| None` | The shape of the output tensor. Helpful when the input tensors have shapes that are not multiples of 64 but were padded to multiples of 64 during quantization. | `None` |
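As an illustration of out_shape, a sketch assuming quantization pads inputs up to
multiples of 64, per the parameter description above:

import torch
from fouroversix import fp4_matmul

# 1000 is not a multiple of 64, so quantization pads each input up to 1024;
# out_shape crops the padded product back down to the true shape.
a = torch.randn(1000, 1000, dtype=torch.bfloat16, device="cuda")
b = torch.randn(1000, 1000, dtype=torch.bfloat16, device="cuda")
out = fp4_matmul(a, b, out_shape=(1000, 1000))
assert out.shape == (1000, 1000)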
Returns:

| Type | Description |
|---|---|
| `Tensor` | The output tensor. |
Source code in src/fouroversix/frontend.py