# Owner(s): ["module: inductor"]

import functools
import unittest

import torch
from torch import Tensor
from torch._inductor import utils
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FP8, SM90OrLater
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    parametrize,
    TEST_WITH_ROCM,
)
from torch.testing._internal.inductor_utils import HAS_CUDA


torch.set_float32_matmul_precision("high")


f8_msg = "FP8 is only supported on H100+, SM 8.9, and MI300+ devices"

# define the e4m3/e5m2 constants
E4M3_MAX_POS = torch.finfo(torch.float8_e4m3fn).max
E5M2_MAX_POS = torch.finfo(torch.float8_e5m2).max
E4M3FNUZ_MAX_POS = torch.finfo(torch.float8_e4m3fnuz).max
E5M2FNUZ_MAX_POS = torch.finfo(torch.float8_e5m2fnuz).max

FP16_MAX_POS: float = torch.finfo(torch.float16).max
EPS: float = 1e-12


def _to_fp8_saturated(x: Tensor, float8_dtype: torch.dtype) -> Tensor:
    # The default behavior in PyTorch for casting to `float8_e4m3fn`
    # and `e5m2` is to not saturate. In this context, we should saturate.
    # A common case where we want to saturate is when the history of a
    # tensor has a maximum value of `amax1`, and the current amax value
    # is `amax2`, where `amax1 < amax2`. This is common when using delayed
    # scaling.
    if float8_dtype == torch.float8_e4m3fn:
        x = x.clamp(min=-1 * E4M3_MAX_POS, max=E4M3_MAX_POS)
    elif float8_dtype == torch.float8_e5m2:
        x = x.clamp(min=-1 * E5M2_MAX_POS, max=E5M2_MAX_POS)
    elif float8_dtype == torch.float8_e4m3fnuz:
        x = x.clamp(min=-1 * E4M3FNUZ_MAX_POS, max=E4M3FNUZ_MAX_POS)
    elif float8_dtype == torch.float8_e5m2fnuz:
        x = x.clamp(min=-1 * E5M2FNUZ_MAX_POS, max=E5M2FNUZ_MAX_POS)
    else:
        raise TypeError(f"Unsupported float8_dtype: {float8_dtype}")
    return x.to(float8_dtype)


@torch.no_grad()
def _amax_to_scale(
    amax: torch.Tensor, float8_dtype: torch.dtype, orig_dtype: torch.dtype
) -> torch.Tensor:
    # Convert amax to fp32 so the scale is computed in full precision.
    amax = amax.float()
    if float8_dtype == torch.float8_e4m3fn:
        res = E4M3_MAX_POS / torch.clamp(amax, min=EPS)
    else:  # e5m2
        res = E5M2_MAX_POS / torch.clamp(amax, min=EPS)

    # Ensure that the scale is representable in float16; this helps when amax
    # is small. We assume this is not a concern for float32/bfloat16.
    if orig_dtype is torch.float16:
        res = torch.clamp(res, max=FP16_MAX_POS)
    return res


def _quantize_tensorwise(x: Tensor, float8_dtype: torch.dtype):
    # Scale the whole tensor by a single scalar derived from its global amax,
    # then cast to fp8; the inverse scale is returned for use with _scaled_mm.
    amax = torch.max(torch.abs(x))
    scale = _amax_to_scale(amax, float8_dtype, x.dtype)
    x_fp8 = _to_fp8_saturated(x * scale, float8_dtype)
    inverse_scale = scale.reciprocal()
    return x_fp8, inverse_scale


def _quantize_rowwise(x: Tensor, float8_dtype: torch.dtype):
    # Same as above, but with one scale per row (amax taken over dim=1).
    amax = torch.max(torch.abs(x), dim=1, keepdim=True).values
    scale = _amax_to_scale(amax, float8_dtype, x.dtype)
    x_fp8 = _to_fp8_saturated(x * scale, float8_dtype)
    inverse_scale = scale.reciprocal()
    return x_fp8, inverse_scale


@instantiate_parametrized_tests
class TestFP8Types(TestCase):
    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    @unittest.skipIf(TEST_WITH_ROCM, "Not supported yet")
    @parametrize("dtype", (torch.float16, torch.bfloat16))
    def test_eager_fallback(self, dtype: torch.dtype):
        weight_shape = (32, 16)

        e4m3_type = (
            torch.float8_e4m3fn if torch.version.hip is None else torch.float8_e4m3fnuz
        )

        def fp8_matmul_unwrapped(x):
            a_scale = torch.Tensor([1.0]).to(device="cuda")
            b_scale = torch.Tensor([1.0]).to(device="cuda")
            output_scale = None
            input_bias = torch.rand(32, device="cuda", dtype=dtype)
            weight = torch.rand(*weight_shape, device="cuda", dtype=dtype).T.to(
                e4m3_type
            )
            a_inverse_scale = 1 / a_scale
            b_inverse_scale = 1 / b_scale
            output = torch._scaled_mm(
                x,
                weight,
                bias=input_bias,
                out_dtype=dtype,
                scale_a=a_inverse_scale,
                scale_b=b_inverse_scale,
                scale_result=output_scale,
            )
            return output

        compiled_fp8_matmul = torch.compile(
            fp8_matmul_unwrapped, backend="inductor", dynamic=True
        )

        x_shape = (16, 16)
        x = torch.rand(*x_shape, device="cuda", dtype=dtype).to(e4m3_type)
        y_fp8 = compiled_fp8_matmul(x)

        x_shape = (15, 16)
        x = torch.rand(*x_shape, device="cuda", dtype=dtype).to(e4m3_type)
        y_fp8 = compiled_fp8_matmul(x)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    @parametrize("dtype", (torch.float16, torch.bfloat16, torch.float))
    @parametrize("shape", ("15,3,13", "4,2048,4096"))
    @parametrize(
        "dst_types",
        [(torch.float8_e4m3fn, torch.float8_e5m2)]
        if torch.version.hip is None
        else [(torch.float8_e4m3fnuz, torch.float8_e5m2fnuz)],
    )
    def test_valid_cast(self, dtype: torch.dtype, shape: str, dst_types: tuple):
        e4m3, e5m2 = dst_types

        def fp8_cast(x):
            y0 = x.to(dtype=e4m3).to(dtype)
            y1 = x.to(dtype=e5m2).to(dtype)
            return y0, y1

        compiled_fp8_cast = torch.compile(fp8_cast, backend="inductor", dynamic=True)

        shape = [int(dim) for dim in shape.split(",")]
        x = torch.rand(*shape, device="cuda", dtype=dtype)
        y0_fp8, y1_fp8 = compiled_fp8_cast(x)

        torch.testing.assert_close(y0_fp8, x, rtol=5e-1, atol=5e-1)
        torch.testing.assert_close(y1_fp8, x, rtol=5e-1, atol=5e-1)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    def test_bad_cast(self):
        def fp8_cast(x, dtype):
            return x.to(dtype=dtype)

        compiled_fp8_cast = torch.compile(fp8_cast, backend="inductor", dynamic=True)

        x_shape = (16, 16, 16)

        with self.assertRaisesRegex(
            torch._dynamo.exc.BackendCompilerFailed,
            "Conversions between float8_e5m2 and float8_e4m3fn is not supported!",
        ):
            x = torch.rand(*x_shape, device="cuda").to(dtype=torch.float8_e4m3fn)
            y = compiled_fp8_cast(x, torch.float8_e5m2)

        with self.assertRaisesRegex(
            torch._dynamo.exc.BackendCompilerFailed,
            "Conversions between float8_e5m2 and float8_e4m3fn is not supported!",
        ):
            x = torch.rand(*x_shape, device="cuda").to(dtype=torch.float8_e5m2)
            y = compiled_fp8_cast(x, torch.float8_e4m3fn)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    @parametrize("src_dtype", (torch.float16, torch.bfloat16, torch.float))
    @parametrize(
        "dst_dtype",
        (torch.float8_e4m3fn, torch.float8_e5m2)
        if torch.version.hip is None
        else (torch.float8_e4m3fnuz, torch.float8_e5m2fnuz),
    )
    @parametrize("shape", ("16,16,16", "4,2048,4096"))
    def test_to_fp8_saturated(
        self, src_dtype: torch.dtype, dst_dtype: torch.dtype, shape: str
    ):
        def fp8_saturated(x, dtype):
            return _to_fp8_saturated(x, dtype)

        compiled_fp8_cast = torch.compile(
            fp8_saturated, backend="inductor", dynamic=True
        )
        shape = [int(dim) for dim in shape.split(",")]
        x = torch.rand(*shape, device="cuda", dtype=src_dtype)
        y_compiled = compiled_fp8_cast(x, dst_dtype)
        y = fp8_saturated(x, dst_dtype)

        torch.testing.assert_close(y_compiled.half(), y.half(), rtol=5e-1, atol=5e-1)

    @unittest.skipIf(TEST_WITH_ROCM, "ROCm fails with accuracy issue")
    @unittest.skipIf(not SM90OrLater, "FP8 is only supported on H100+")
    @parametrize(
        "float8_dtype",
        (torch.float8_e4m3fn, torch.float8_e5m2)
        if torch.version.hip is None
        else (torch.float8_e4m3fnuz, torch.float8_e5m2fnuz),
    )
    @parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096"))
    def test_amax_fp8_quant(self, float8_dtype: torch.dtype, shape: str):
        shape = [int(dim) for dim in shape.split(",")]
        batch_size, sequence_length, hidden_size = shape

        def amax_fp8(x: Tensor, scale: Tensor):
            y = torch.amax(torch.abs(x))
            y_scaled = y.to(dtype=torch.float) * scale
            bits_fp8 = _to_fp8_saturated(y_scaled, float8_dtype)
            return bits_fp8

        compiled_amax_fp8_quant = torch.compile(amax_fp8, backend="inductor")

        x_shape = (batch_size, sequence_length, hidden_size)
        x = torch.rand(*x_shape, device="cuda", dtype=torch.half)
        scale = torch.tensor(0.2, device="cuda", dtype=torch.float)

        y_compiled = compiled_amax_fp8_quant(x, scale)
        y = amax_fp8(x, scale)

        torch.testing.assert_close(y_compiled.half(), y.half(), rtol=1e-2, atol=1e-2)

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    @parametrize(
        "float8_dtype",
        (torch.float8_e4m3fn, torch.float8_e5m2)
        if torch.version.hip is None
        else (torch.float8_e4m3fnuz, torch.float8_e5m2fnuz),
    )
    @parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096"))
    def test_amax_along_with_fp8_quant(self, float8_dtype: torch.dtype, shape: str):
        shape = [int(dim) for dim in shape.split(",")]
        batch_size, sequence_length, hidden_size = shape

        def amax_fp8(x: Tensor, scale: Tensor, amax_buffer: Tensor):
            amax_buffer.fill_(torch.amax(torch.abs(x)))
            x_scaled = x.to(dtype=torch.float) * scale
            bits_fp8 = _to_fp8_saturated(x_scaled, float8_dtype)
            return bits_fp8

        compiled_amax_fp8_quant = torch.compile(amax_fp8, backend="inductor")

        x_shape = (batch_size, sequence_length, hidden_size)
        x = torch.rand(*x_shape, device="cuda", dtype=torch.half)
        scale = torch.tensor(1.0, device="cuda", dtype=torch.float)

        amax_buffer_compiled = torch.zeros((1), device="cuda", dtype=torch.half)
        y_compiled = compiled_amax_fp8_quant(x, scale, amax_buffer_compiled)
        amax_buffer = torch.zeros((1), device="cuda", dtype=torch.half)
        y = amax_fp8(x, scale, amax_buffer)

        torch.testing.assert_close(y_compiled.half(), y.half(), rtol=1e-1, atol=1e-1)
        torch.testing.assert_close(
            amax_buffer_compiled, amax_buffer, rtol=1e-2, atol=1e-2
        )

    @unittest.skipIf(TEST_WITH_ROCM, "ROCm fails with accuracy issue")
    @unittest.skipIf(not SM90OrLater, "FP8 is only supported on H100+")
    @parametrize(
        "float8_dtype",
        (torch.float8_e4m3fn, torch.float8_e5m2)
        if torch.version.hip is None
        else (torch.float8_e4m3fnuz, torch.float8_e5m2fnuz),
    )
    @parametrize("amax_keep_dim", (True, False))
    @parametrize("shape", ("1,1,15", "1,10,15", "1,10,512", "1,10,4096", "4,2048,4096"))
    def test_layernorm_fp8_quant(
        self, float8_dtype: torch.dtype, amax_keep_dim: bool, shape: str
    ):
        shape = [int(dim) for dim in shape.split(",")]
        batch_size, sequence_length, hidden_size = shape

        def ln_fp8(x: Tensor, scale: Tensor, amax_buffer: Tensor):
            x = torch.nn.functional.layer_norm(
                x.to(dtype=torch.float),
                [hidden_size],
                weight=None,
                bias=None,
                eps=1e-05,
            )
            amax_buffer.fill_(
                torch.amax(torch.abs(x), keepdim=amax_keep_dim).reshape(-1)[0]
            )
            x_scaled = x * scale
            bits_fp8 = _to_fp8_saturated(x_scaled, float8_dtype)
            return bits_fp8

        compiled_ln_fp8_quant = torch.compile(ln_fp8, backend="inductor")

        x_shape = (batch_size, sequence_length, hidden_size)
        x = torch.rand(*x_shape, device="cuda", dtype=torch.half)
        scale = torch.tensor(0.2, device="cuda", dtype=torch.float)

        amax_buffer_compiled = torch.zeros((1), device="cuda", dtype=torch.half)
        y_compiled = compiled_ln_fp8_quant(x, scale, amax_buffer_compiled)
        amax_buffer = torch.zeros((1), device="cuda", dtype=torch.half)
        y = ln_fp8(x, scale, amax_buffer)

        torch.testing.assert_close(y_compiled.half(), y.half(), rtol=1e-1, atol=1e-1)
        torch.testing.assert_close(
            amax_buffer_compiled, amax_buffer, rtol=1e-2, atol=1e-2
        )

    @unittest.skipIf(not PLATFORM_SUPPORTS_FP8, f8_msg)
    @parametrize(
        "float8_dtype",
        (torch.float8_e4m3fn, torch.float8_e5m2)
        if torch.version.hip is None
        else (torch.float8_e4m3fnuz, torch.float8_e5m2fnuz),
    )
    @parametrize("shape", ("4,2048,4096",))
    @parametrize("keepdim", (False, True))
    def test_layernorm_fp8_quant_benchmark(
        self,
        float8_dtype: torch.dtype,
        shape: str,
        keepdim: bool,
    ):
        shape = [int(dim) for dim in shape.split(",")]
        batch_size, sequence_length, hidden_size = shape

        def ln(x: Tensor):
            x = torch.nn.functional.layer_norm(
                x.to(dtype=torch.float),
                [hidden_size],
                weight=None,
                bias=None,
                eps=1e-05,
            )
            return x

        def ln_fp8(x: Tensor, scale: Tensor, amax_buffer: Tensor):
            x = torch.nn.functional.layer_norm(
                x.to(dtype=torch.float),
                [hidden_size],
                weight=None,
                bias=None,
                eps=1e-05,
            )
            amax = torch.amax(torch.abs(x), keepdim=keepdim)
            amax_buffer.view_as(amax).copy_(amax)
            x_scaled = x * scale
            bits_fp8 = _to_fp8_saturated(x_scaled, float8_dtype)
            return bits_fp8

        compiled_ln_fp8_quant = torch.compile(ln_fp8, backend="inductor")

        x_shape = (batch_size, sequence_length, hidden_size)
        x = torch.rand(*x_shape, device="cuda", dtype=torch.half)
        scale = torch.tensor(0.2, device="cuda", dtype=torch.float)

        amax_buffer_compiled = torch.zeros((1), device="cuda", dtype=torch.half)
        amax_buffer = torch.zeros((1), device="cuda", dtype=torch.half)
        _ = compiled_ln_fp8_quant(x, scale, amax_buffer_compiled)
        compiled_latency = utils.do_bench_using_profiling(
            functools.partial(compiled_ln_fp8_quant, x, scale, amax_buffer_compiled)
        )
        eager_latency = utils.do_bench_using_profiling(
            functools.partial(ln_fp8, x, scale, amax_buffer)
        )

        compiled_ln = torch.compile(ln, backend="inductor")
        _ = compiled_ln(x)
        ln_latency = utils.do_bench_using_profiling(functools.partial(compiled_ln, x))

        print(
            f"Config: {float8_dtype=}, {shape=}, {keepdim=}. "
            f"Benchmark results: Inductor: {compiled_latency}ms, Eager: {eager_latency}ms, "
            f"LN only Inductor: {ln_latency}ms."
        )


@instantiate_parametrized_tests
class TestFP8Lowering(TestCase):
    @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
    @unittest.skipIf(not SM90OrLater, "FP8 is only supported on H100+")
    @parametrize("dtype", (torch.bfloat16, torch.float32))
    @parametrize("shape", ("16,16,32", "1024,1024,512"))
    @parametrize("has_bias", (False, True))
    @parametrize("use_fast_accum", (False, True))
    def test_tensorwise_scaling(
        self, dtype: torch.dtype, shape: str, has_bias: bool, use_fast_accum: bool
    ):
        if dtype is torch.float32 and has_bias:
            self.skipTest("bias is not supported when output dtype is float32")

        device = "cuda"
        dtype_float8 = torch.float8_e4m3fn

        shape = [int(dim) for dim in shape.split(",")]
        M, K, N = shape  # matmul: Y = X [M, K] x W [N, K]
        # The input and output dtypes of _scaled_mm do not need to match, but
        # they typically do in a model.
        x = torch.randn(M, K, dtype=dtype, device=device)
        w = torch.randn(N, K, dtype=dtype, device=device)
        bias = None
        if has_bias:
            bias = torch.randn(N, device=device, dtype=torch.bfloat16)

        # quantize weight (prior to inference)
        w_fp8, w_inverse_scale = _quantize_tensorwise(w, dtype_float8)
        w_t_fp8 = w_fp8.t()

        # quantize input x
        x_fp8, x_inverse_scale = _quantize_tensorwise(x, dtype_float8)

        def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias):
            y = torch._scaled_mm(
                x_fp8,
                w_t_fp8,
                x_inverse_scale,
                w_inverse_scale,
                bias,
                out_dtype=dtype,
                use_fast_accum=use_fast_accum,
            )
            return y

        y_eager = linear(
            x_fp8,
            x_inverse_scale,
            w_t_fp8,
            w_inverse_scale,
            bias,
        )
        linear_compiled = torch.compile(linear, backend="inductor", mode="max-autotune")
        y_compiled = linear_compiled(
            x_fp8,
            x_inverse_scale,
            w_t_fp8,
            w_inverse_scale,
            bias,
        )
        self.assertEqual(y_eager.dtype, dtype)
        self.assertEqual(y_compiled.dtype, dtype)
        # Depending on the kernel config (BLOCK_M size, etc.) selected during
        # Inductor autotuning for the compiled case, results can differ because
        # blocks of partial results are accumulated in a different order (float
        # addition is not associative), so these tests use a small absolute tolerance.
        torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.05)

    @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
    @unittest.skipIf(not SM90OrLater, "FP8 is only supported on H100+")
    @parametrize("shape", ("16,16,32", "1024,1024,512"))
    @parametrize("has_bias", (False, True))
@parametrize("use_fast_accum", (False, True)) 462 def test_rowwise_scaling(self, shape: str, has_bias: bool, use_fast_accum: bool): 463 # Only bf16 output type is supported for row-wise scaling, not fp32 464 dtype: torch.dtype = torch.bfloat16 465 device = "cuda" 466 dtype_float8 = torch.float8_e4m3fn 467 468 shape = [int(dim) for dim in shape.split(",")] 469 M, K, N = shape # Matmul Y = X [M, K] x W [N, K] 470 x = torch.randn(M, K, dtype=dtype, device=device) 471 w = torch.randn(N, K, dtype=dtype, device=device) 472 bias = None 473 if has_bias: 474 bias = torch.randn(N, device=device, dtype=torch.bfloat16) 475 476 # quantize weight (prior to inference) 477 w_fp8, w_inverse_scale = _quantize_rowwise(w, dtype_float8) 478 w_t_fp8 = w_fp8.t() 479 w_inverse_scale = w_inverse_scale.t() # scale_b should be (1, N) 480 481 # quantize input x 482 x_fp8, x_inverse_scale = _quantize_rowwise(x, dtype_float8) 483 484 def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias): 485 y = torch._scaled_mm( 486 x_fp8, 487 w_t_fp8, 488 x_inverse_scale, 489 w_inverse_scale, 490 bias, 491 out_dtype=dtype, 492 use_fast_accum=use_fast_accum, 493 ) 494 return y 495 496 y_eager = linear( 497 x_fp8, 498 x_inverse_scale, 499 w_t_fp8, 500 w_inverse_scale, 501 bias, 502 ) 503 linear_compiled = torch.compile(linear, backend="inductor", mode="max-autotune") 504 y_compiled = linear_compiled( 505 x_fp8, 506 x_inverse_scale, 507 w_t_fp8, 508 w_inverse_scale, 509 bias, 510 ) 511 self.assertEqual(y_eager.dtype, dtype) 512 self.assertEqual(y_compiled.dtype, dtype) 513 torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.05) 514 515 @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM") 516 @unittest.skipIf(not SM90OrLater, "FP8 is only supported on H100+") 517 @parametrize("M", (1, 3, 33, 257, 1024)) 518 @parametrize("K", (16, 1024)) 519 @parametrize("N", (16, 2048)) 520 def test_tensorwise_scaling_acceptable_input_dims(self, M: int, K: int, N: int): 521 # alignment requirements: K and N divisible by 16 522 dtype: torch.dtype = torch.bfloat16 523 use_fast_accum = True 524 device = "cuda" 525 dtype_float8 = torch.float8_e4m3fn 526 527 x = torch.randn(M, K, dtype=dtype, device=device) 528 w = torch.randn(N, K, dtype=dtype, device=device) 529 bias = None 530 w_fp8, w_inverse_scale = _quantize_tensorwise(w, dtype_float8) 531 w_t_fp8 = w_fp8.t() 532 x_fp8, x_inverse_scale = _quantize_tensorwise(x, dtype_float8) 533 534 def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias): 535 y = torch._scaled_mm( 536 x_fp8, 537 w_t_fp8, 538 x_inverse_scale, 539 w_inverse_scale, 540 bias, 541 out_dtype=dtype, 542 use_fast_accum=use_fast_accum, 543 ) 544 return y 545 546 y_eager = linear( 547 x_fp8, 548 x_inverse_scale, 549 w_t_fp8, 550 w_inverse_scale, 551 bias, 552 ) 553 linear_compiled = torch.compile(linear, backend="inductor", mode="max-autotune") 554 y_compiled = linear_compiled( 555 x_fp8, 556 x_inverse_scale, 557 w_t_fp8, 558 w_inverse_scale, 559 bias, 560 ) 561 self.assertEqual(y_eager.dtype, dtype) 562 self.assertEqual(y_compiled.dtype, dtype) 563 torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.07) 564 565 @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM") 566 @unittest.skipIf(not SM90OrLater, "FP8 is only supported on H100+") 567 @parametrize("M", (1, 3, 33, 257, 1024)) 568 @parametrize("K", (16, 1024)) 569 @parametrize("N", (16, 2048)) 570 def test_rowwise_scaling_acceptable_input_dims(self, M: int, K: int, N: int): 571 dtype: torch.dtype = torch.bfloat16 572 
        use_fast_accum = True
        device = "cuda"
        dtype_float8 = torch.float8_e4m3fn

        x = torch.randn(M, K, dtype=dtype, device=device)
        w = torch.randn(N, K, dtype=dtype, device=device)
        bias = torch.randn(N, device=device, dtype=torch.bfloat16)

        w_fp8, w_inverse_scale = _quantize_rowwise(w, dtype_float8)
        w_t_fp8 = w_fp8.t()
        w_inverse_scale = w_inverse_scale.t()  # scale_b should be (1, N)
        x_fp8, x_inverse_scale = _quantize_rowwise(x, dtype_float8)

        def linear(x_fp8, x_inverse_scale, w_t_fp8, w_inverse_scale, bias):
            y = torch._scaled_mm(
                x_fp8,
                w_t_fp8,
                x_inverse_scale,
                w_inverse_scale,
                bias,
                out_dtype=dtype,
                use_fast_accum=use_fast_accum,
            )
            return y

        y_eager = linear(
            x_fp8,
            x_inverse_scale,
            w_t_fp8,
            w_inverse_scale,
            bias,
        )
        linear_compiled = torch.compile(linear, backend="inductor", mode="max-autotune")
        y_compiled = linear_compiled(
            x_fp8,
            x_inverse_scale,
            w_t_fp8,
            w_inverse_scale,
            bias,
        )
        self.assertEqual(y_eager.dtype, dtype)
        self.assertEqual(y_compiled.dtype, dtype)
        torch.testing.assert_close(y_eager, y_compiled, rtol=1e-2, atol=0.07)

    @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
    @unittest.skipIf(not SM90OrLater, "FP8 is only supported on H100+")
    def test_unacceptable_input_dims(self):
        # for compiled ops, type checking is done in torch/_meta_registrations.py
        dtype: torch.dtype = torch.bfloat16
        device = "cuda"
        dtype_float8 = torch.float8_e4m3fn
        M, K, N = 64, 15, 2048  # K needs to be a multiple of 16
        x = torch.randn(M, K, dtype=dtype, device=device)
        w = torch.randn(N, K, dtype=dtype, device=device)
        bias = torch.randn(N, device=device, dtype=torch.bfloat16)
        w_fp8, w_inverse_scale = _quantize_tensorwise(w, dtype_float8)
        w_t_fp8 = w_fp8.t()

        def linear(x, w_t_fp8, w_inverse_scale, bias):
            x_fp8, x_inverse_scale = _quantize_tensorwise(x, dtype_float8)
            y = torch._scaled_mm(
                x_fp8,
                w_t_fp8,
                x_inverse_scale,
                w_inverse_scale,
                bias,
                out_dtype=dtype,
                use_fast_accum=True,
            )
            return y

        linear_compiled = torch.compile(linear, backend="inductor", mode="max-autotune")
        with self.assertRaises(torch._dynamo.exc.TorchRuntimeError) as cm:
            y_compiled = linear_compiled(
                x,
                w_t_fp8,
                w_inverse_scale,
                bias,
            )
        self.assertTrue(
            f"Expected self.size(1) to be divisible by 16, but got self.size(1)={K}"
            in str(cm.exception)
        )

    @unittest.skipIf(TEST_WITH_ROCM, "FP8 is not supported on ROCM")
    @unittest.skipIf(not SM90OrLater, "FP8 is only supported on H100+")
    def test_unacceptable_scale_dims_rowwise_scaling(self):
        dtype: torch.dtype = torch.bfloat16
        device = "cuda"
        dtype_float8 = torch.float8_e4m3fn
        M, K, N = 233, 32, 128
        x = torch.randn(M, K, dtype=dtype, device=device)
        w = torch.randn(N, K, dtype=dtype, device=device)
        bias = torch.randn(N, device=device, dtype=torch.bfloat16)
        w_fp8, w_inverse_scale = _quantize_rowwise(w, dtype_float8)
        w_t_fp8 = w_fp8.t()

        def linear(x, w_t_fp8, w_inverse_scale, bias):
            x_fp8, x_inverse_scale = _quantize_rowwise(x, dtype_float8)
            y = torch._scaled_mm(
                x_fp8,
                w_t_fp8,
                w_inverse_scale.t(),  # testing with w and x scales switched
                x_inverse_scale,
                bias,
                out_dtype=dtype,
                use_fast_accum=True,
            )
            return y

        linear_compiled = torch.compile(linear, backend="inductor", mode="max-autotune")
mode="max-autotune") 683 with self.assertRaises(torch._dynamo.exc.TorchRuntimeError) as cm: 684 y_compiled = linear_compiled( 685 x, 686 w_t_fp8, 687 w_inverse_scale, 688 bias, 689 ) 690 self.assertTrue("Invalid scaling configuration." in str(cm.exception)) 691 692 693if __name__ == "__main__": 694 if HAS_CUDA: 695 run_tests() 696