Matrix Multiplication

Perform a matrix multiplication (a @ b.T) with two FP4-quantized tensors provided in row-major layout.

Sample Code

Each tensor may be provided in either high or low precision. If provided in high precision, tensors are quantized to FP4 prior to the matrix multiplication, and quantization may be configured with the a_quantize_kwargs and b_quantize_kwargs parameters. For example, the following two code samples are equivalent:

With High-Precision Inputs

a = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
b = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
out = fp4_matmul(a, b)

With Low-Precision Inputs

a = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
b = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")

a_e2m1, a_sf, a_normconst = quantize_to_fp4(a)
b_e2m1, b_sf, b_normconst = quantize_to_fp4(b)

out = fp4_matmul(
    a_e2m1=a_e2m1,
    a_sf=a_sf,
    a_normconst=a_normconst,
    b_e2m1=b_e2m1,
    b_sf=b_sf,
    b_normconst=b_normconst
)
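
Configuring Quantization

If a and b are provided in high precision, the a_quantize_kwargs and b_quantize_kwargs dictionaries are forwarded to the corresponding quantize_to_fp4 calls. A minimal sketch (the keyword shown is hypothetical; the accepted keys depend on the signature of quantize_to_fp4):

out = fp4_matmul(
    a,
    b,
    a_quantize_kwargs={"fp4_format": FP4Format.mxfp4},  # hypothetical keyword
    b_quantize_kwargs={"fp4_format": FP4Format.mxfp4},  # hypothetical keyword
    fp4_format=FP4Format.mxfp4,
)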

Backends

We provide two different implementations of FP4 matrix multiplication:

  • CUTLASS: Uses CUTLASS kernels to perform fast FP4 matrix multiplication. Requires a Blackwell GPU.
  • PyTorch: A slow implementation which dequantizes FP4 tensors, and then performs a high-precision matrix multiplication.

Note that our CUTLASS kernels accumulate in FP32, so results should be roughly equivalent to those produced by the PyTorch backend.
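
To force a specific implementation, pass backend explicitly. For example, to run the reference PyTorch path even when CUTLASS is available:

out = fp4_matmul(a, b, backend=MatmulBackend.pytorch)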

Parameters

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| a | Tensor | The high-precision input tensor A. | None |
| b | Tensor | The high-precision input tensor B. | None |
| backend | MatmulBackend | The backend to use for the matrix multiplication, either MatmulBackend.cutlass or MatmulBackend.pytorch. If no backend is provided, CUTLASS will be used if the machine has a Blackwell GPU, and PyTorch will be used otherwise. | None |
| a_e2m1 | Tensor | The values of the first input tensor in packed E2M1 format (2 values per byte). | None |
| a_sf | Tensor | The scale factors of the first input tensor. | None |
| a_normconst | Tensor | The per-tensor normalization constant of the first input tensor. | None |
| b_e2m1 | Tensor | The values of the second input tensor in packed E2M1 format (2 values per byte). | None |
| b_sf | Tensor | The scale factors of the second input tensor. | None |
| b_normconst | Tensor | The per-tensor normalization constant of the second input tensor. | None |
| a_quantize_kwargs | dict | If a is provided in high precision, these parameters will be passed to the quantize_to_fp4 call made prior to the matrix multiplication. | None |
| b_quantize_kwargs | dict | If b is provided in high precision, these parameters will be passed to the quantize_to_fp4 call made prior to the matrix multiplication. | None |
| fp4_format | FP4Format | The FP4 format of the input tensors, either FP4Format.nvfp4 or FP4Format.mxfp4. | nvfp4 |
| out_dtype | DataType | The data type of the output tensor, either DataType.bfloat16 or DataType.float16. | bfloat16 |
| out_shape | tuple[int, int] \| None | The shape of the output tensor. Helpful when the input tensors have shapes that are not multiples of 64 but were padded to multiples of 64 during quantization. | None |
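
For example, out_shape can trim padding introduced during quantization. A sketch, assuming quantize_to_fp4 pads inputs up to the next multiple of 64:

a = torch.randn(1000, 1000, dtype=torch.bfloat16, device="cuda")
b = torch.randn(1000, 1000, dtype=torch.bfloat16, device="cuda")
out = fp4_matmul(a, b, out_shape=(1000, 1000))  # trims the padded 1024 x 1024 result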

Returns:

| Type | Description |
| --- | --- |
| Tensor | The output tensor. |

Source code in src/fouroversix/frontend.py
def fp4_matmul(
    a: torch.Tensor | None = None,
    b: torch.Tensor | None = None,
    *,
    backend: MatmulBackend | None = None,
    a_e2m1: torch.Tensor | None = None,
    a_sf: torch.Tensor | None = None,
    a_normconst: torch.Tensor | None = None,
    b_e2m1: torch.Tensor | None = None,
    b_sf: torch.Tensor | None = None,
    b_normconst: torch.Tensor | None = None,
    a_quantize_kwargs: dict[str, Any] | None = None,
    b_quantize_kwargs: dict[str, Any] | None = None,
    fp4_format: FP4Format = FP4Format.nvfp4,
    out_dtype: DataType = DataType.bfloat16,
    out_shape: tuple[int, int] | None = None,
) -> torch.Tensor:
    """
    Perform a matrix multiplication (`a @ b.T`) with two FP4-quantized tensors provided
    in row-major layout.

    ## Sample Code

    Each tensor may be provided in either high or low precision. If provided in high
    precision, tensors will be quantized to FP4 prior to the matrix multiplication, and
    quantization may be configured with the `a_quantize_kwargs` and
    `b_quantize_kwargs` parameters. For example, the following two code samples are
    equivalent:

    ### With High-Precision Inputs

    ```
    a = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
    b = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
    out = fp4_matmul(a, b)
    ```

    ### With Low-Precision Inputs

    ```
    a = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
    b = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")

    a_e2m1, a_sf, a_normconst = quantize_to_fp4(a)
    b_e2m1, b_sf, b_normconst = quantize_to_fp4(b)

    out = fp4_matmul(
        a_e2m1=a_e2m1,
        a_sf=a_sf,
        a_normconst=a_normconst,
        b_e2m1=b_e2m1,
        b_sf=b_sf,
        b_normconst=b_normconst
    )
    ```

    ## Backends

    We provide two different implementations of FP4 matrix multiplication:

    - **CUTLASS**: Uses CUTLASS kernels to perform fast FP4 matrix multiplication.
        Requires a Blackwell GPU.
    - **PyTorch**: A slow implementation which dequantizes FP4 tensors, and then
        performs a high-precision matrix multiplication.

    Note that our CUTLASS kernels accumulate in FP32, so results should be roughly
    equivalent to those produced by the PyTorch backend.

    ## Parameters

    Args:
        a (torch.Tensor): The high-precision input tensor A.
        b (torch.Tensor): The high-precision input tensor B.
        backend (MatmulBackend): The backend to use for the matrix multiplication,
            either `MatmulBackend.cutlass` or `MatmulBackend.pytorch`. If no backend is
            provided, CUTLASS will be used if the machine has a Blackwell GPU, and
            PyTorch will be used otherwise.
        a_e2m1 (torch.Tensor): The values of the first input tensor in packed E2M1
            format (2 values per byte).
        a_sf (torch.Tensor): The scale factors of the first input tensor.
        a_normconst (torch.Tensor): The per-tensor normalization constant of the
            first input tensor.
        b_e2m1 (torch.Tensor): The values of the second input tensor in packed E2M1
            format (2 values per byte).
        b_sf (torch.Tensor): The scale factors of the second input tensor.
        b_normconst (torch.Tensor): The per-tensor normalization constant of the
            second input tensor.
        a_quantize_kwargs (dict): If `a` is provided in high precision, these parameters
            will be passed to the `quantize_to_fp4` call done prior to the matrix
            multiplication.
        b_quantize_kwargs (dict): If `b` is provided in high precision, these parameters
            will be passed to the `quantize_to_fp4` call done prior to the matrix
            multiplication.
        fp4_format (FP4Format): The FP4 format of the input tensors, either
            `FP4Format.nvfp4` or `FP4Format.mxfp4`.
        out_dtype (DataType): The data type of the output tensor, either
            `DataType.bfloat16` or `DataType.float16`.
        out_shape (tuple[int, int] | None): The shape of the output tensor. This is
            helpful when the input tensors have shapes that are not multiples of 64,
            but which were padded to multiples of 64 during quantization.

    Returns:
        The output tensor.

    """

    if a is None and (a_e2m1 is None or a_sf is None):
        msg = "If a is None, a_e2m1 and a_sf must be provided"
        raise ValueError(msg)

    if b is None and (b_e2m1 is None or b_sf is None):
        msg = "If b is None, b_e2m1 and b_sf must be provided"
        raise ValueError(msg)

    if a_quantize_kwargs is None:
        a_quantize_kwargs = {}

    if b_quantize_kwargs is None:
        b_quantize_kwargs = {}

    # Quantize any high-precision inputs not already provided in packed FP4 form.
    if a_e2m1 is None or a_sf is None:
        a_e2m1, a_sf, a_normconst = quantize_to_fp4(a, **a_quantize_kwargs)

    if b_e2m1 is None or b_sf is None:
        b_e2m1, b_sf, b_normconst = quantize_to_fp4(b, **b_quantize_kwargs)

    kwargs = {
        "fp4_format": fp4_format,
        "out_dtype": out_dtype,
        "out_shape": out_shape,
    }

    # Auto-select the fastest available backend unless one was requested explicitly.
    if backend is None:
        backend = MatmulBackend.auto_select()
    elif not backend.is_available():
        msg = f"Backend {backend} is not available"
        raise ValueError(msg)

    return backend.fp4_matmul(
        a_e2m1,
        a_sf,
        a_normconst,
        b_e2m1,
        b_sf,
        b_normconst,
        **kwargs,
    )
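
Note that the validation above is per-tensor, so the operands may also be supplied in mixed form: one in high precision (quantized on the fly) and the other pre-quantized. A sketch based on the checks in the source:

a = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
b = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")

b_e2m1, b_sf, b_normconst = quantize_to_fp4(b)

out = fp4_matmul(a=a, b_e2m1=b_e2m1, b_sf=b_sf, b_normconst=b_normconst)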