# Owner(s): ["module: inductor"]

import functools
import unittest

import torch
from torch import Tensor
from torch._inductor import utils
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8, SM90OrLater
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    TEST_WITH_ROCM,
)
from torch.testing._internal.inductor_utils import HAS_CUDA


torch.set_float32_matmul_precision("high")

f8_msg = "FP8 is only supported on H100+, SM 8.9, and MI300+ devices"

# Define the float8 max-value constants (e4m3/e5m2 and their fnuz variants)
E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max
E4M3FNUZ_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max
E5M2FNUZ_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max

FP16_MAX_POS: float = torch.finfo(torch.float16).max
EPS: float = 1e-12

def _to_fp8_saturated(x: Tensor, float8_dtype: torch.dtype) -> Tensor:
    # By default, PyTorch does not saturate when casting to `float8_e4m3fn`
    # or `float8_e5m2`; in this context we want a saturating cast. A common
    # case is when a tensor's history has a maximum value of `amax1` while the
    # current amax is `amax2` with `amax1 < amax2`, which happens frequently
    # with delayed scaling. (An illustrative sketch follows this helper.)
    if float8_dtype == torch.float8_e4m3fn:
        x = x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
    elif float8_dtype == torch.float8_e5m2:
        x = x.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS)
    elif float8_dtype == torch.float8_e4m3fnuz:
        x = x.clamp(min=-1 * E4M3FNUZ_MAX_POS, max=E4M3FNUZ_MAX_POS)
    elif float8_dtype == torch.float8_e5m2fnuz:
        x = x.clamp(min=-1 * E5M2FNUZ_MAX_POS, max=E5M2FNUZ_MAX_POS)
    else:
        raise TypeError(f"Unsupported float8_dtype: {float8_dtype}")
    return x.to(float8_dtype)
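

# Illustrative sketch (not part of the test suite): a small check of the
# saturating cast above. The name `_example_saturated_cast` is ours and is
# never called by these tests; it only assumes the float8 dtypes used above.
def _example_saturated_cast() -> None:
    # 1000.0 is far above E4M3_MAX_POS (448.0), so the helper clamps it to the
    # fp8 maximum before casting; a plain `.to()` cast would not clamp it.
    x = torch.tensor([1000.0, -1000.0, 0.5])
    saturated = _to_fp8_saturated(x, torch.float8_e4m3fn)
    torch.testing.assert_close(saturated.float(), torch.tensor([448.0, -448.0, 0.5]))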


@torch.no_grad()
def _amax_to_scale(
    amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
) -> torch.Tensor:
    # Compute the scale in fp32 for accuracy (a worked example follows this helper)
    amax = amax.float()
    if float8_dtype == torch.float8_e4m3fn:
        res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
    else:  # e5m2
        res = E5M2_MAX_POS / torch.clamp(amax, min=EPS)

    # Ensure that the scale is representable in float16; this helps when amax
    # is small. We assume this is not a concern for float32/bfloat16.
    if orig_dtype is torch.float16:
        res = torch.clamp(res, max=FP16_MAX_POS)
    return res
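

# Worked example for _amax_to_scale (illustration only, our numbers): with
# amax = 2.0 and float8_e4m3fn (max = 448.0), the scale is 448.0 / 2.0 = 224.0,
# so the tensor's largest value maps to the top of the e4m3 range before the
# saturating cast. The helper below is ours and is never called by these tests.
def _example_amax_to_scale() -> None:
    amax = torch.tensor(2.0)
    scale = _amax_to_scale(amax, torch.float8_e4m3fn, torch.float32)
    torch.testing.assert_close(scale, torch.tensor(224.0))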


def _quantize_tensorwise(x: Tensor, float8_dtype: torch.dtype):
    amax = torch.max(torch.abs(x))
    scale = _amax_to_scale(amax, float8_dtype, x.dtype)
    x_fp8 = _to_fp8_saturated(x * scale, float8_dtype)
    inverse_scale = scale.reciprocal()
    return x_fp8, inverse_scale


def _quantize_rowwise(x: Tensor, float8_dtype: torch.dtype):
    amax = torch.max(torch.abs(x), dim=1, keepdim=True).values
    scale = _amax_to_scale(amax, float8_dtype, x.dtype)
    x_fp8 = _to_fp8_saturated(x * scale, float8_dtype)
    inverse_scale = scale.reciprocal()
    return x_fp8, inverse_scale
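

# Illustrative sketch (not part of the test suite): shows how the helpers above
# compose into a quantize/dequantize round trip, and how the scale shapes differ
# between tensor-wise quantization (a single scalar scale) and row-wise
# quantization (one scale per row). The function name and the loose tolerances,
# chosen only to absorb fp8 rounding error, are ours; the tests never call it.
def _example_quantize_round_trip() -> None:
    x = torch.randn(4, 8, dtype=torch.float32)

    # Tensor-wise: inverse_scale is a 0-dim tensor shared by the whole tensor.
    x_fp8, x_inverse_scale = _quantize_tensorwise(x, torch.float8_e4m3fn)
    x_dequant = x_fp8.float() * x_inverse_scale

    # Row-wise: inverse_scale has shape (4, 1), one entry per row of x.
    x_fp8_row, x_inverse_scale_row = _quantize_rowwise(x, torch.float8_e4m3fn)
    x_dequant_row = x_fp8_row.float() * x_inverse_scale_row

    torch.testing.assert_close(x_dequant, x, rtol=1e-1, atol=1e-1)
    torch.testing.assert_close(x_dequant_row, x, rtol=1e-1, atol=1e-1)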


@instantiate_parametrized_tests
class TestFP8Types(TestCase):
    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    @unittest.skipIf(TEST_WITH_ROCM, "Not supported yet")
    @parametrize("dtype", (torch.float16, torch.bfloat16))
    def test_eager_fallback(self, dtype: torch.dtype):
        weight_shape = (32, 16)

        e4m3_type = (
            torch.float8_e4m3fn if torch.version.hip is None else torch.float8_e4m3fnuz
        )

        def fp8_matmul_unwrapped(x):
            a_scale = torch.tensor([1.0], device="cuda")
            b_scale = torch.tensor([1.0], device="cuda")
            output_scale = None
            input_bias = torch.rand(32, device="cuda", dtype=dtype)
            weight = torch.rand(*weight_shape, device="cuda", dtype=dtype).T.to(
                e4m3_type
            )
            a_inverse_scale = 1 / a_scale
            b_inverse_scale = 1 / b_scale
            output = torch._scaled_mm(
                x,
                weight,
                bias=input_bias,
                out_dtype=dtype,
                scale_a=a_inverse_scale,
                scale_b=b_inverse_scale,
                scale_result=output_scale,
            )
            return output

        compiled_fp8_matmul = torch.compile(
            fp8_matmul_unwrapped, backend="inductor", dynamic=True
        )

        x_shape = (16, 16)
        x = torch.rand(*x_shape, device="cuda", dtype=dtype).to(e4m3_type)
        y_fp8 = compiled_fp8_matmul(x)

        # Call again with a different first dimension to exercise the
        # dynamic-shape path of the compiled kernel.
        x_shape = (15, 16)
        x = torch.rand(*x_shape, device="cuda", dtype=dtype).to(e4m3_type)
        y_fp8 = compiled_fp8_matmul(x)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    @parametrize("dtype", (torch.float16, torch.bfloat16, torch.float))
    @parametrize("shape", ("15,3,13", "4,2048,4096"))
    @parametrize(
        "dst_types",
        [(torch.float8_e4m3fn, torch.float8_e5m2)]
        if torch.version.hip is None
        else [(torch.float8_e4m3fnuz, torch.float8_e5m2fnuz)],
    )
    def test_valid_cast(self, dtype: torch.dtype, shape: str, dst_types: tuple):
        e4m3, e5m2 = dst_types

        def fp8_cast(x):
            y0 = x.to(dtype=e4m3).to(dtype)
            y1 = x.to(dtype=e5m2).to(dtype)
            return y0, y1

        compiled_fp8_cast = torch.compile(fp8_cast, backend="inductor", dynamic=True)

        shape = [int(dim) for dim in shape.split(",")]
        x = torch.rand(*shape, device="cuda", dtype=dtype)
        y0_fp8, y1_fp8 = compiled_fp8_cast(x)

        torch.testing.assert_close(y0_fp8, x, rtol=5e-1, atol=5e-1)
        torch.testing.assert_close(y1_fp8, x, rtol=5e-1, atol=5e-1)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    def test_bad_cast(self):
        def fp8_cast(x, dtype):
            return x.to(dtype=dtype)

        compiled_fp8_cast = torch.compile(fp8_cast, backend="inductor", dynamic=True)

        x_shape = (16, 16, 16)

        with self.assertRaisesRegex(
            torch._dynamo.exc.BackendCompilerFailed,
            "Conversions between float8_e5m2 and float8_e4m3fn is not supported!",
        ):
            x = torch.rand(*x_shape, device="cuda").to(dtype=torch.float8_e4m3fn)
            y = compiled_fp8_cast(x, torch.float8_e5m2)

        with self.assertRaisesRegex(
            torch._dynamo.exc.BackendCompilerFailed,
            "Conversions between float8_e5m2 and float8_e4m3fn is not supported!",
        ):
            x = torch.rand(*x_shape, device="cuda").to(dtype=torch.float8_e5m2)
            y = compiled_fp8_cast(x, torch.float8_e4m3fn)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    @parametrize("src_dtype", (torch.float16, torch.bfloat16, torch.float))
    @parametrize(
        "dst_dtype",
        (torch.float8_e4m3fn, torch.float8_e5m2)
        if torch.version.hip is None
        else (torch.float8_e4m3fnuz, torch.float8_e5m2fnuz),
    )
    @parametrize("shape", ("16,16,16", "4,2048,4096"))
    def test_to_fp8_saturated(
        self, src_dtype: torch.dtype, dst_dtype: torch.dtype, shape: str
    ):
        def fp8_saturated(x, dtype):
            return _to_fp8_saturated(x, dtype)

        compiled_fp8_cast = torch.compile(
            fp8_saturated, backend="inductor", dynamic=True
        )
        shape = [int(dim) for dim in shape.split(",")]
        x = torch.rand(*shape, device="cuda", dtype=src_dtype)
        y_compiled = compiled_fp8_cast(x, dst_dtype)
        y = fp8_saturated(x, dst_dtype)

        torch.testing.assert_close(y_compiled.half(), y.half(), rtol=5e-1, atol=5e-1)

    @unittest.skipIf(TEST_WITH_ROCM, "ROCm fails with accuracy issue")
    @unittest.skipIf(not SM90OrLater, "FP8 is only supported on H100+")
    @parametrize(
        "float8_dtype",
        (torch.float8_e4m3fn, torch.float8_e5m2)
        if torch.version.hip is None
        else (torch.float8_e4m3fnuz, torch.float8_e5m2fnuz),
    )
    @parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096"))
    def test_amax_fp8_quant(self, float8_dtype: torch.dtype, shape: str):
        shape = [int(dim) for dim in shape.split(",")]
        batch_size, sequence_length, hidden_size = shape

        def amax_fp8(x: Tensor, scale: Tensor):
            y = torch.amax(torch.abs(x))
            y_scaled = y.to(dtype=torch.float) * scale
            bits_fp8 = _to_fp8_saturated(y_scaled, float8_dtype)
            return bits_fp8

        compiled_amax_fp8_quant = torch.compile(amax_fp8, backend="inductor")

        x_shape = (batch_size, sequence_length, hidden_size)
        x = torch.rand(*x_shape, device="cuda", dtype=torch.half)
        scale = torch.tensor(0.2, device="cuda", dtype=torch.float)

        y_compiled = compiled_amax_fp8_quant(x, scale)
        y = amax_fp8(x, scale)

        torch.testing.assert_close(y_compiled.half(), y.half(), rtol=1e-2, atol=1e-2)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    @parametrize(
        "float8_dtype",
        (torch.float8_e4m3fn, torch.float8_e5m2)
        if torch.version.hip is None
        else (torch.float8_e4m3fnuz, torch.float8_e5m2fnuz),
    )
    @parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096"))
    def test_amax_along_with_fp8_quant(self, float8_dtype: torch.dtype, shape: str):
        shape = [int(dim) for dim in shape.split(",")]
        batch_size, sequence_length, hidden_size = shape

        def amax_fp8(x: Tensor, scale: Tensor, amax_buffer: Tensor):
            amax_buffer.fill_(torch.amax(torch.abs(x)))
            x_scaled = x.to(dtype=torch.float) * scale
            bits_fp8 = _to_fp8_saturated(x_scaled, float8_dtype)
            return bits_fp8

        compiled_amax_fp8_quant = torch.compile(amax_fp8, backend="inductor")

        x_shape = (batch_size, sequence_length, hidden_size)
        x = torch.rand(*x_shape, device="cuda", dtype=torch.half)
        scale = torch.tensor(1.0, device="cuda", dtype=torch.float)

        amax_buffer_compiled = torch.zeros((1), device="cuda", dtype=torch.half)
        y_compiled = compiled_amax_fp8_quant(x, scale, amax_buffer_compiled)
        amax_buffer = torch.zeros((1), device="cuda", dtype=torch.half)
        y = amax_fp8(x, scale, amax_buffer)

        torch.testing.assert_close(y_compiled.half(), y.half(), rtol=1e-1, atol=1e-1)
        torch.testing.assert_close(
            amax_buffer_compiled, amax_buffer, rtol=1e-2, atol=1e-2
        )

    @unittest.skipIf(TEST_WITH_ROCM, "ROCm fails with accuracy issue")
    @unittest.skipIf(not SM90OrLater, "FP8 is only supported on H100+")
    @parametrize(
        "float8_dtype",
        (torch.float8_e4m3fn, torch.float8_e5m2)
        if torch.version.hip is None
        else (torch.float8_e4m3fnuz, torch.float8_e5m2fnuz),
    )
    @parametrize("amax_keep_dim", (True, False))
    @parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096"))
    def test_layernorm_fp8_quant(
        self, float8_dtype: torch.dtype, amax_keep_dim: bool, shape: str
    ):
        shape = [int(dim) for dim in shape.split(",")]
        batch_size, sequence_length, hidden_size = shape

        def ln_fp8(x: Tensor, scale: Tensor, amax_buffer: Tensor):
            x = torch.nn.functional.layer_norm(
                x.to(dtype=torch.float),
                [hidden_size],
                weight=None,
                bias=None,
                eps=1e-05,
            )
            amax_buffer.fill_(
                torch.amax(torch.abs(x), keepdim=amax_keep_dim).reshape(-1)[0]
            )
            x_scaled = x * scale
            bits_fp8 = _to_fp8_saturated(x_scaled, float8_dtype)
            return bits_fp8

        compiled_ln_fp8_quant = torch.compile(ln_fp8, backend="inductor")

        x_shape = (batch_size, sequence_length, hidden_size)
        x = torch.rand(*x_shape, device="cuda", dtype=torch.half)
        scale = torch.tensor(0.2, device="cuda", dtype=torch.float)

        amax_buffer_compiled = torch.zeros((1), device="cuda", dtype=torch.half)
        y_compiled = compiled_ln_fp8_quant(x, scale, amax_buffer_compiled)
        amax_buffer = torch.zeros((1), device="cuda", dtype=torch.half)
        y = ln_fp8(x, scale, amax_buffer)

        torch.testing.assert_close(y_compiled.half(), y.half(), rtol=1e-1, atol=1e-1)
        torch.testing.assert_close(
            amax_buffer_compiled, amax_buffer, rtol=1e-2, atol=1e-2
        )

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    @parametrize(
        "float8_dtype",
        (torch.float8_e4m3fn, torch.float8_e5m2)
        if torch.version.hip is None
        else (torch.float8_e4m3fnuz, torch.float8_e5m2fnuz),
    )
    @parametrize("shape", ("4,2048,4096",))
    @parametrize("keepdim", (False, True))
    def test_layernorm_fp8_quant_benchmark(
        self,
        float8_dtype: torch.dtype,
        shape: str,
        keepdim: bool,
    ):
        shape = [int(dim) for dim in shape.split(",")]
        batch_size, sequence_length, hidden_size = shape

        def ln(x: Tensor):
            x = torch.nn.functional.layer_norm(
                x.to(dtype=torch.float),
                [hidden_size],
                weight=None,
                bias=None,
                eps=1e-05,
            )
            return x

        def ln_fp8(x: Tensor, scale: Tensor, amax_buffer: Tensor):
            x = torch.nn.functional.layer_norm(
                x.to(dtype=torch.float),
                [hidden_size],
                weight=None,
                bias=None,
                eps=1e-05,
            )
            amax = torch.amax(torch.abs(x), keepdim=keepdim)
            amax_buffer.view_as(amax).copy_(amax)
            x_scaled = x * scale
            bits_fp8 = _to_fp8_saturated(x_scaled, float8_dtype)
            return bits_fp8

        compiled_ln_fp8_quant = torch.compile(ln_fp8, backend="inductor")

        x_shape = (batch_size, sequence_length, hidden_size)
        x = torch.rand(*x_shape, device="cuda", dtype=torch.half)
        scale = torch.tensor(0.2, device="cuda", dtype=torch.float)

        amax_buffer_compiled = torch.zeros((1), device="cuda", dtype=torch.half)
        amax_buffer = torch.zeros((1), device="cuda", dtype=torch.half)
        _ = compiled_ln_fp8_quant(x, scale, amax_buffer_compiled)
        compiled_latency = utils.do_bench_using_profiling(
            functools.partial(compiled_ln_fp8_quant, x, scale, amax_buffer_compiled)
        )
        eager_latency = utils.do_bench_using_profiling(
            functools.partial(ln_fp8, x, scale, amax_buffer)
        )

        compiled_ln = torch.compile(ln, backend="inductor")
        _ = compiled_ln(x)
        ln_latency = utils.do_bench_using_profiling(functools.partial(compiled_ln, x))

        print(
            f"Config: {float8_dtype=}, {shape=}, {keepdim=}. "
            f"Benchmark results: Inductor: {compiled_latency}ms, Eager: {eager_latency}ms, "
            f"LN only Inductor: {ln_latency}ms."
        )


@instantiate_parametrized_tests
class TestFP8Lowering(TestCase):
    @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
    @unittest.skipIf(not SM90OrLater, "FP8 is only supported on H100+")
    @parametrize("dtype", (torch.bfloat16, torch.float32))
    @parametrize("shape", ("16,16,32", "1024,1024,512"))
    @parametrize("has_bias", (False, True))
    @parametrize("use_fast_accum", (False, True))
    def test_tensorwise_scaling(
        self, dtype: torch.dtype, shape: str, has_bias: bool, use_fast_accum: bool
    ):
        if dtype is torch.float32 and has_bias:
            self.skipTest("bias is not supported when output dtype is float32")

        device = "cuda"
        dtype_float8 = torch.float8_e4m3fn

        shape = [int(dim) for dim in shape.split(",")]
        M, K, N = shape  # Matmul Y = X [M, K] x W [N, K]
        # input and output dtypes of _scaled_mm do not need to be the same, but
        # typically in a model they are
        x = torch.randn(M, K, dtype=dtype, device=device)
        w = torch.randn(N, K, dtype=dtype, device=device)
        bias = None
        if has_bias:
            bias = torch.randn(N, device=device, dtype=torch.bfloat16)

        # quantize weight (prior to inference)
        w_fp8, w_inverse_scale = _quantize_tensorwise(w, dtype_float8)
        w_t_fp8 = w_fp8.t()

        # quantize input x
        x_fp8, x_inverse_scale = _quantize_tensorwise(x, dtype_float8)

        def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias):
            y = torch._scaled_mm(
                x_fp8,
                w_t_fp8,
                x_inverse_scale,
                w_inverse_scale,
                bias,
                out_dtype=dtype,
                use_fast_accum=use_fast_accum,
            )
            return y

        y_eager = linear(
            x_fp8,
            x_inverse_scale,
            w_t_fp8,
            w_inverse_scale,
            bias,
        )
        linear_compiled = torch.compile(linear, backend="inductor", mode="max-autotune")
        y_compiled = linear_compiled(
            x_fp8,
            x_inverse_scale,
            w_t_fp8,
            w_inverse_scale,
            bias,
        )
        self.assertEqual(y_eager.dtype, dtype)
        self.assertEqual(y_compiled.dtype, dtype)
        # Depending on the kernel config (BLOCK_M size, etc.) selected during
        # Inductor autotuning, the compiled results can differ from eager because
        # partial results are accumulated in a different block order (float
        # addition is not associative), so these tests use a small absolute
        # tolerance.
        torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.05)

    @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
    @unittest.skipIf(not SM90OrLater, "FP8 is only supported on H100+")
    @parametrize("shape", ("16,16,32", "1024,1024,512"))
    @parametrize("has_bias", (False, True))
    @parametrize("use_fast_accum", (False, True))
    def test_rowwise_scaling(self, shape: str, has_bias: bool, use_fast_accum: bool):
        # Only bf16 output type is supported for row-wise scaling, not fp32
        dtype: torch.dtype = torch.bfloat16
        device = "cuda"
        dtype_float8 = torch.float8_e4m3fn

        shape = [int(dim) for dim in shape.split(",")]
        M, K, N = shape  # Matmul Y = X [M, K] x W [N, K]
        x = torch.randn(M, K, dtype=dtype, device=device)
        w = torch.randn(N, K, dtype=dtype, device=device)
        bias = None
        if has_bias:
            bias = torch.randn(N, device=device, dtype=torch.bfloat16)

        # quantize weight (prior to inference)
        w_fp8, w_inverse_scale = _quantize_rowwise(w, dtype_float8)
        w_t_fp8 = w_fp8.t()
        w_inverse_scale = w_inverse_scale.t()  # scale_b should be (1, N)

        # quantize input x
        x_fp8, x_inverse_scale = _quantize_rowwise(x, dtype_float8)

        def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias):
            y = torch._scaled_mm(
                x_fp8,
                w_t_fp8,
                x_inverse_scale,
                w_inverse_scale,
                bias,
                out_dtype=dtype,
                use_fast_accum=use_fast_accum,
            )
            return y

        y_eager = linear(
            x_fp8,
            x_inverse_scale,
            w_t_fp8,
            w_inverse_scale,
            bias,
        )
        linear_compiled = torch.compile(linear, backend="inductor", mode="max-autotune")
        y_compiled = linear_compiled(
            x_fp8,
            x_inverse_scale,
            w_t_fp8,
            w_inverse_scale,
            bias,
        )
        self.assertEqual(y_eager.dtype, dtype)
        self.assertEqual(y_compiled.dtype, dtype)
        torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.05)

    @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
    @unittest.skipIf(not SM90OrLater, "FP8 is only supported on H100+")
    @parametrize("M", (1, 3, 33, 257, 1024))
    @parametrize("K", (16, 1024))
    @parametrize("N", (16, 2048))
    def test_tensorwise_scaling_acceptable_input_dims(self, M: int, K: int, N: int):
        # alignment requirements: K and N divisible by 16
        dtype: torch.dtype = torch.bfloat16
        use_fast_accum = True
        device = "cuda"
        dtype_float8 = torch.float8_e4m3fn

        x = torch.randn(M, K, dtype=dtype, device=device)
        w = torch.randn(N, K, dtype=dtype, device=device)
        bias = None
        w_fp8, w_inverse_scale = _quantize_tensorwise(w, dtype_float8)
        w_t_fp8 = w_fp8.t()
        x_fp8, x_inverse_scale = _quantize_tensorwise(x, dtype_float8)

        def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias):
            y = torch._scaled_mm(
                x_fp8,
                w_t_fp8,
                x_inverse_scale,
                w_inverse_scale,
                bias,
                out_dtype=dtype,
                use_fast_accum=use_fast_accum,
            )
            return y

        y_eager = linear(
            x_fp8,
            x_inverse_scale,
            w_t_fp8,
            w_inverse_scale,
            bias,
        )
        linear_compiled = torch.compile(linear, backend="inductor", mode="max-autotune")
        y_compiled = linear_compiled(
            x_fp8,
            x_inverse_scale,
            w_t_fp8,
            w_inverse_scale,
            bias,
        )
        self.assertEqual(y_eager.dtype, dtype)
        self.assertEqual(y_compiled.dtype, dtype)
        torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.07)

    @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
    @unittest.skipIf(not SM90OrLater, "FP8 is only supported on H100+")
    @parametrize("M", (1, 3, 33, 257, 1024))
    @parametrize("K", (16, 1024))
    @parametrize("N", (16, 2048))
    def test_rowwise_scaling_acceptable_input_dims(self, M: int, K: int, N: int):
        dtype: torch.dtype = torch.bfloat16
        use_fast_accum = True
        device = "cuda"
        dtype_float8 = torch.float8_e4m3fn

        x = torch.randn(M, K, dtype=dtype, device=device)
        w = torch.randn(N, K, dtype=dtype, device=device)
        bias = torch.randn(N, device=device, dtype=torch.bfloat16)

        w_fp8, w_inverse_scale = _quantize_rowwise(w, dtype_float8)
        w_t_fp8 = w_fp8.t()
        w_inverse_scale = w_inverse_scale.t()  # scale_b should be (1, N)
        x_fp8, x_inverse_scale = _quantize_rowwise(x, dtype_float8)

        def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias):
            y = torch._scaled_mm(
                x_fp8,
                w_t_fp8,
                x_inverse_scale,
                w_inverse_scale,
                bias,
                out_dtype=dtype,
                use_fast_accum=use_fast_accum,
            )
            return y

        y_eager = linear(
            x_fp8,
            x_inverse_scale,
            w_t_fp8,
            w_inverse_scale,
            bias,
        )
        linear_compiled = torch.compile(linear, backend="inductor", mode="max-autotune")
        y_compiled = linear_compiled(
            x_fp8,
            x_inverse_scale,
            w_t_fp8,
            w_inverse_scale,
            bias,
        )
        self.assertEqual(y_eager.dtype, dtype)
        self.assertEqual(y_compiled.dtype, dtype)
        torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.07)

    @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
    @unittest.skipIf(not SM90OrLater, "FP8 is only supported on H100+")
    def test_unacceptable_input_dims(self):
        # for compiled ops, type checking is in torch/_meta_registrations.py
        dtype: torch.dtype = torch.bfloat16
        device = "cuda"
        dtype_float8 = torch.float8_e4m3fn
        M, K, N = 64, 15, 2048  # K needs to be a multiple of 16
        x = torch.randn(M, K, dtype=dtype, device=device)
        w = torch.randn(N, K, dtype=dtype, device=device)
        bias = torch.randn(N, device=device, dtype=torch.bfloat16)
        w_fp8, w_inverse_scale = _quantize_tensorwise(w, dtype_float8)
        w_t_fp8 = w_fp8.t()

        def linear(x, w_t_fp8, w_inverse_scale, bias):
            x_fp8, x_inverse_scale = _quantize_tensorwise(x, dtype_float8)
            y = torch._scaled_mm(
                x_fp8,
                w_t_fp8,
                x_inverse_scale,
                w_inverse_scale,
                bias,
                out_dtype=dtype,
                use_fast_accum=True,
            )
            return y

        linear_compiled = torch.compile(linear, backend="inductor", mode="max-autotune")
        with self.assertRaises(torch._dynamo.exc.TorchRuntimeError) as cm:
            y_compiled = linear_compiled(
                x,
                w_t_fp8,
                w_inverse_scale,
                bias,
            )
        self.assertTrue(
            f"Expected self.size(1) to be divisible by 16, but got self.size(1)={K}"
            in str(cm.exception)
        )

    @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
    @unittest.skipIf(not SM90OrLater, "FP8 is only supported on H100+")
    def test_unacceptable_scale_dims_rowwise_scaling(self):
        dtype: torch.dtype = torch.bfloat16
        device = "cuda"
        dtype_float8 = torch.float8_e4m3fn
        M, K, N = 233, 32, 128
        x = torch.randn(M, K, dtype=dtype, device=device)
        w = torch.randn(N, K, dtype=dtype, device=device)
        bias = torch.randn(N, device=device, dtype=torch.bfloat16)
        w_fp8, w_inverse_scale = _quantize_rowwise(w, dtype_float8)
        w_t_fp8 = w_fp8.t()

        def linear(x, w_t_fp8, w_inverse_scale, bias):
            x_fp8, x_inverse_scale = _quantize_rowwise(x, dtype_float8)
            y = torch._scaled_mm(
                x_fp8,
                w_t_fp8,
                w_inverse_scale.t(),  # testing with w and x scales switched
                x_inverse_scale,
                bias,
                out_dtype=dtype,
                use_fast_accum=True,
            )
            return y

        linear_compiled = torch.compile(linear, backend="inductor", mode="max-autotune")
        with self.assertRaises(torch._dynamo.exc.TorchRuntimeError) as cm:
            y_compiled = linear_compiled(
                x,
                w_t_fp8,
                w_inverse_scale,
                bias,
            )
        self.assertTrue("Invalid scaling configuration." in str(cm.exception))

if __name__ == "__main__":
    if HAS_CUDA:
        run_tests()