
Matrix Multiplication

Perform a matrix multiplication (a @ b.T) with two FP4-quantized tensors provided in row-major layout.

Sample Code

Each tensor may be provided in either high or low precision. If provided in high precision, tensors will be quantized to FP4 prior to the matrix multiplication, and quantization may be configured with the input_quantize_kwargs and other_quantize_kwargs parameters. For example, the following two code samples are equivalent:

With High-Precision Inputs

a = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
b = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
out = fp4_matmul(a, b)

With Low-Precision Inputs

a = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
b = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")

a_quantized = quantize_to_fp4(a)
b_quantized = quantize_to_fp4(b)

out = fp4_matmul(a_quantized, b_quantized)
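
Since the multiplication computes a @ b.T, the two inputs share their inner (K) dimension and the output takes its shape from their leading dimensions. A minimal shape sketch (the sizes here are illustrative only):

a = torch.randn(512, 256, dtype=torch.bfloat16, device="cuda")  # (M, K)
b = torch.randn(128, 256, dtype=torch.bfloat16, device="cuda")  # (N, K)
out = fp4_matmul(a, b)  # (M, N) == (512, 128), approximately a @ b.T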

Backends

We provide two different implementations of FP4 matrix multiplication:

  • CUTLASS: Uses CUTLASS kernels to perform fast FP4 matrix multiplication. Requires a Blackwell GPU.
  • PyTorch: A slow implementation that dequantizes the FP4 tensors and then performs a high-precision matrix multiplication.

Note that our CUTLASS kernels accumulate in FP32, so their results should be roughly equivalent to simulations run with the PyTorch backend.
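
A backend may also be requested explicitly with the backend parameter. For example, to force the slow PyTorch reference path (a minimal sketch; a and b are the CUDA tensors from the samples above):

out_ref = fp4_matmul(a, b, backend=MatmulBackend.pytorch)

Requesting a backend that is unavailable on the current machine raises a ValueError.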

Parameters

  • input (Tensor | FP4Tensor, required): The first tensor to be multiplied.
  • other (Tensor | FP4Tensor, required): The second tensor to be multiplied.
  • backend (MatmulBackend | None, default None): The backend to use for the matrix multiplication, either MatmulBackend.cutlass or MatmulBackend.pytorch. If no backend is provided, CUTLASS is used if the machine has a Blackwell GPU, and PyTorch otherwise.
  • input_quantize_kwargs (dict, default None): If input is provided in high precision, these parameters are passed to the quantize_to_fp4 call made prior to the matrix multiplication.
  • other_quantize_kwargs (dict, default None): If other is provided in high precision, these parameters are passed to the quantize_to_fp4 call made prior to the matrix multiplication.
  • out_dtype (torch.dtype, default torch.bfloat16): The data type of the output tensor, either torch.bfloat16 or torch.float16.
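
As the quantize-kwargs entries above describe, passing input_quantize_kwargs (or other_quantize_kwargs) is equivalent to calling quantize_to_fp4 with the same keyword arguments yourself before the matmul. A sketch of the equivalence (kwargs stands in for whatever options quantize_to_fp4 accepts; no specific option names are assumed):

kwargs = {}  # any keyword arguments accepted by quantize_to_fp4
out = fp4_matmul(a, b, input_quantize_kwargs=kwargs)
# ...is equivalent to:
out = fp4_matmul(quantize_to_fp4(a, **kwargs), b)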

Returns:

  • Tensor: The output tensor.
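
The output precision can be selected with out_dtype (a minimal sketch reusing a and b from the samples above):

out = fp4_matmul(a, b, out_dtype=torch.float16)
assert out.dtype == torch.float16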

Source code in src/fouroversix/frontend.py
def fp4_matmul(
    input: torch.Tensor | FP4Tensor,
    other: torch.Tensor | FP4Tensor,
    *,
    backend: MatmulBackend | None = None,
    input_quantize_kwargs: dict[str, Any] | None = None,
    other_quantize_kwargs: dict[str, Any] | None = None,
    out_dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """
    Perform a matrix multiplication (`a @ b.T`) with two FP4-quantized tensors provided
    in row-major layout.

    ## Sample Code

    Each tensor may be provided in either high or low precision. If provided in high
    precision, tensors will be quantized to FP4 prior to the matrix multiplication, and
    quantization may be configured with the `input_quantize_kwargs` and
    `other_quantize_kwargs` parameters. For example, the following two code samples are
    equivalent:

    ### With High-Precision Inputs

    ```
    a = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
    b = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
    out = fp4_matmul(a, b)
    ```

    ### With Low-Precision Inputs

    ```
    a = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")
    b = torch.randn(1024, 1024, dtype=torch.bfloat16, device="cuda")

    a_quantized = quantize_to_fp4(a)
    b_quantized = quantize_to_fp4(b)

    out = fp4_matmul(a_quantized, b_quantized)
    ```

    ## Backends

    We provide two different implementations of FP4 matrix multiplication:

    - **CUTLASS**: Uses CUTLASS kernels to perform fast FP4 matrix multiplication.
        Requires a Blackwell GPU.
    - **PyTorch**: A slow implementation that dequantizes the FP4 tensors and then
        performs a high-precision matrix multiplication.

    Note that our CUTLASS kernels accumulate in FP32, so their results should be
    roughly equivalent to simulations run with the PyTorch backend.

    ## Parameters

    Args:
        input (torch.Tensor | FP4Tensor): The first tensor to be multiplied.
        other (torch.Tensor | FP4Tensor): The second tensor to be multiplied.
        backend (MatmulBackend | None): The backend to use for the matrix
            multiplication, either `MatmulBackend.cutlass` or `MatmulBackend.pytorch`.
            If no backend is provided, CUTLASS will be used if the machine has a
            Blackwell GPU, and PyTorch will be used otherwise.
        input_quantize_kwargs (dict): If `input` is provided in high precision, these
            parameters will be passed to the `quantize_to_fp4` call made prior to the
            matrix multiplication.
        other_quantize_kwargs (dict): If `other` is provided in high precision, these
            parameters will be passed to the `quantize_to_fp4` call made prior to the
            matrix multiplication.
        out_dtype (torch.dtype): The data type of the output tensor, either
            `torch.bfloat16` or `torch.float16`.

    Returns:
        The output tensor.

    """

    if input_quantize_kwargs is None:
        input_quantize_kwargs = {}

    if other_quantize_kwargs is None:
        other_quantize_kwargs = {}

    # Quantize any inputs still provided in high precision. The kwargs are
    # guaranteed non-None by the guards above.
    if isinstance(input, torch.Tensor):
        input = quantize_to_fp4(input, **input_quantize_kwargs)

    if isinstance(other, torch.Tensor):
        other = quantize_to_fp4(other, **other_quantize_kwargs)

    # Auto-select a backend when none is given; otherwise make sure the
    # requested backend is usable on this machine.
    if backend is None:
        backend = MatmulBackend.auto_select()
    elif not backend.is_available():
        msg = f"Backend {backend} is not available"
        raise ValueError(msg)

    return backend.fp4_matmul(input, other, out_dtype=out_dtype)