1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <memory>
17
18 #include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
19 #include "tensorflow/compiler/xla/tests/verified_hlo_module.h"
20
21 // Check that the ftz (flush denormals to zero) flag is reflected in PTX as
22 // expected.
23
24 namespace xla {
25 namespace gpu {
26 namespace {
27
28 class GpuFtzTest : public GpuCodegenTest {
29 public:
GpuFtzTest(bool ftz)30 explicit GpuFtzTest(bool ftz) : ftz_(ftz) {}
31
32 // Creates an HLO module that performs the given binary operation on some
33 // data.
CreateBinaryOpModule(HloOpcode op)34 std::unique_ptr<VerifiedHloModule> CreateBinaryOpModule(HloOpcode op) {
35 HloComputation::Builder builder(TestName());
36
37 Shape param_shape = ShapeUtil::MakeShapeWithLayout(
38 F32, /*dimensions=*/{100, 100}, /*minor_to_major=*/{1, 0});
39 HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
40 /* parameter_number=*/0, param_shape, "x"));
41 HloInstruction* y = builder.AddInstruction(HloInstruction::CreateParameter(
42 /* parameter_number=*/1, param_shape, "y"));
43 builder.AddInstruction(HloInstruction::CreateBinary(param_shape, op, x, y));
44
45 auto hlo_module = CreateNewVerifiedModuleWithFTZ(ftz_);
46 hlo_module->AddEntryComputation(builder.Build());
47 return hlo_module;
48 }
49
50 // Creates an HLO module that performs the given unary operation on some data.
CreateUnaryOpModule(HloOpcode op)51 std::unique_ptr<VerifiedHloModule> CreateUnaryOpModule(HloOpcode op) {
52 HloComputation::Builder builder(TestName());
53
54 Shape param_shape = ShapeUtil::MakeShapeWithLayout(
55 F32, /*dimensions=*/{100, 100}, /*minor_to_major=*/{1, 0});
56 HloInstruction* x = builder.AddInstruction(HloInstruction::CreateParameter(
57 /* parameter_number=*/0, param_shape, "x"));
58 builder.AddInstruction(HloInstruction::CreateUnary(param_shape, op, x));
59
60 auto hlo_module = CreateNewVerifiedModuleWithFTZ(ftz_);
61 hlo_module->AddEntryComputation(builder.Build());
62 return hlo_module;
63 }
64
65 bool ftz_;
66 };
67
68 class GpuFtzEnabledTest : public GpuFtzTest {
69 public:
GpuFtzEnabledTest()70 GpuFtzEnabledTest() : GpuFtzTest(/*ftz=*/true) {}
71 };
72
73 class GpuFtzDisabledTest : public GpuFtzTest {
74 public:
GpuFtzDisabledTest()75 GpuFtzDisabledTest() : GpuFtzTest(/*ftz=*/false) {}
76 };
77
// Check that we emit mul.rn.ftz.f32 when in ftz mode, and plain mul.rn.f32
// otherwise.
TEST_F(GpuFtzEnabledTest, MultiplyFtz) {
  // FileCheck pattern: a mul.rn.ftz.f32 must be present, and the non-ftz
  // mul.rn.f32 must appear neither before nor after it.
  constexpr char kExpectedPtx[] = R"(
    CHECK-NOT: mul.rn.f32
    CHECK: mul.rn.ftz.f32
    CHECK-NOT: mul.rn.f32
  )";
  CompileAndOptionallyVerifyPtx(CreateBinaryOpModule(HloOpcode::kMultiply),
                                kExpectedPtx);
}
TEST_F(GpuFtzDisabledTest, MultiplyFtz) {
  // FileCheck pattern: a plain mul.rn.f32 must be present, and the ftz form
  // mul.rn.ftz.f32 must appear neither before nor after it.
  constexpr char kExpectedPtx[] = R"(
    CHECK-NOT: mul.rn.ftz.f32
    CHECK: mul.rn.f32
    CHECK-NOT: mul.rn.ftz.f32
  )";
  CompileAndOptionallyVerifyPtx(CreateBinaryOpModule(HloOpcode::kMultiply),
                                kExpectedPtx);
}
93
94 // In NVPTX, exp(float) is implemented in libdevice, and consults __nvvm_reflect
95 // to determine whether or not ftz is enabled.
96 // The implementation in CUDA 11 uses one ex2.approx.ftz, irrespective of ftz
97 // being enabled or not. The ftz flag is reflected in the Newton iteration.
TEST_F(GpuFtzEnabledTest, ExpFtz) {
  // FileCheck pattern: ex2.approx.ftz.f32 must be followed on the very next
  // line by mul.rn.ftz.f32, with no non-ftz ex2.approx.f32 after that.
  constexpr char kExpectedPtx[] = R"(
    CHECK: ex2.approx.ftz.f32
    CHECK-NEXT: mul.rn.ftz.f32
    CHECK-NOT: ex2.approx.f32
  )";
  CompileAndOptionallyVerifyPtx(CreateUnaryOpModule(HloOpcode::kExp),
                                kExpectedPtx);
}
105
TEST_F(GpuFtzDisabledTest, ExpFtz) {
  // FileCheck pattern: ex2.approx.ftz.f32 must be followed on the very next
  // line by a plain mul.rn.f32, with no non-ftz ex2.approx.f32 after that.
  constexpr char kExpectedPtx[] = R"(
    CHECK: ex2.approx.ftz.f32
    CHECK-NEXT: mul.rn.f32
    CHECK-NOT: ex2.approx.f32
  )";
  CompileAndOptionallyVerifyPtx(CreateUnaryOpModule(HloOpcode::kExp),
                                kExpectedPtx);
}
113
114 } // namespace
115 } // namespace gpu
116 } // namespace xla
117