Matrix Multiplication
Perform a matrix multiplication (a @ b.T) with two FP4-quantized tensors provided
in row-major layout.
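The a @ b.T convention means both operands share their last (inner) dimension: for a of shape (M, K) and b of shape (N, K), the output has shape (M, N), and entry (i, j) is the dot product of row i of a with row j of b. A minimal pure-Python sketch of that shape contract on nested lists follows; the helper name matmul_nt is hypothetical and not part of the library, and no quantization is involved here.

```python
def matmul_nt(a, b):
    """Return a @ b.T for row-major nested lists.

    a has shape (M, K), b has shape (N, K); the result has shape (M, N).
    """
    k = len(a[0])
    if any(len(row) != k for row in a + b):
        raise ValueError("all rows of a and b must share the same inner dimension K")
    # out[i][j] is the dot product of row i of a with row j of b.
    return [[sum(x * y for x, y in zip(row_a, row_b)) for row_b in b] for row_a in a]


a = [[1.0, 2.0, 3.0],
     [4.0, 5.0, 6.0]]            # shape (2, 3)
b = [[1.0, 0.0, 0.0],
     [0.0, 1.0, 0.0],
     [0.0, 0.0, 1.0],
     [1.0, 1.0, 1.0]]            # shape (4, 3)

out = matmul_nt(a, b)            # shape (2, 4)
print(out)  # [[1.0, 2.0, 3.0, 6.0], [4.0, 5.0, 6.0, 15.0]]
```

Because b is consumed row by row, neither operand needs an explicit transpose in memory when both are stored row-major.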
Sample Code
Each tensor may be provided in either high or low precision. If provided in high
precision, tensors will be quantized to FP4 prior to the matrix multiplication, and
quantization may be configured with the input_quantize_kwargs and
other_quantize_kwargs parameters. For example, the following two code samples are
equivalent:
With High-Precision Inputs
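import torch
# torch.randn creates a tensor of the given shape; torch.tensor expects data, not a shape.
a = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
b = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
out = fp4_matmul(a, b)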
With Low-Precision Inputs
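import torch
# torch.randn creates a tensor of the given shape; torch.tensor expects data, not a shape.
a = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
b = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
# Quantize explicitly, then multiply the already-quantized tensors.
a_quantized = quantize_to_fp4(a)
b_quantized = quantize_to_fp4(b)
out = fp4_matmul(a_quantized, b_quantized)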
Backends
We provide two different implementations of FP4 matrix multiplication:
- CUTLASS: Uses CUTLASS kernels to perform fast FP4 matrix multiplication. Requires a Blackwell GPU.
- PyTorch: A slow implementation that dequantizes the FP4 tensors and then performs a high-precision matrix multiplication.
Note that our CUTLASS kernels accumulate in FP32, so their results should be roughly equivalent to simulations done with the PyTorch backend.
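To make the "dequantize, then multiply in high precision" idea concrete, here is a minimal pure-Python sketch. It is not the library's implementation. It assumes the E2M1 FP4 value grid and, as a simplification, one shared scale per row; the library's actual format, scaling scheme, and rounding may differ. The names fake_quantize_row and simulated_fp4_matmul are hypothetical, and accumulation here uses Python floats rather than FP32.

```python
import math

# Magnitudes representable by the E2M1 FP4 format (assumed here; the
# library's actual FP4 format and scaling scheme may differ).
E2M1_MAGNITUDES = [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0]


def fake_quantize_row(row):
    """Quantize a row to the FP4 grid with one shared scale, then dequantize."""
    amax = max(abs(x) for x in row)
    if amax == 0.0:
        return [0.0 for _ in row]
    scale = amax / 6.0  # map the largest magnitude onto the top grid value
    dequantized = []
    for x in row:
        # Round to the nearest grid magnitude (ties go to the smaller one here).
        mag = min(E2M1_MAGNITUDES, key=lambda g: abs(abs(x) / scale - g))
        dequantized.append(math.copysign(mag * scale, x))
    return dequantized


def simulated_fp4_matmul(a, b):
    """Dequantize-then-multiply: returns a_dq @ b_dq.T for nested lists."""
    a_dq = [fake_quantize_row(r) for r in a]
    b_dq = [fake_quantize_row(r) for r in b]
    return [[sum(x * y for x, y in zip(ra, rb)) for rb in b_dq] for ra in a_dq]


a = [[6.0, 3.0, 0.4, -1.2]]
b = [[6.0, 0.0, 6.0, 6.0]]
print(fake_quantize_row(a[0]))     # [6.0, 3.0, 0.5, -1.0]
print(simulated_fp4_matmul(a, b))  # [[33.0]]
```

In this example the unquantized product would be 31.2, so the gap to 33.0 comes entirely from rounding the inputs to the FP4 grid, not from the multiplication itself.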
Parameters
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input | Tensor \| FP4Tensor | The first tensor to be multiplied. | required |
| other | Tensor \| FP4Tensor | The second tensor to be multiplied. | required |
| backend | MatmulBackend | The backend to use for the matrix multiplication, either … | None |
| input_quantize_kwargs | dict | If … | None |
| other_quantize_kwargs | dict | If … | None |
| out_dtype | DataType | The data type of the output tensor, either … | bfloat16 |
Returns:
| Type | Description |
|---|---|
| Tensor | The output tensor. |
Source code in src/fouroversix/frontend.py