/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
#include <utility>

#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/core/platform/test.h"
23
namespace xla {
namespace gpu {
namespace {

28 class GpuSliceInputFusionTest : public GpuCodegenTest {
29 protected:
GpuSliceInputFusionTest()30 GpuSliceInputFusionTest() {}
31
ConfigWithoutLayoutAssignment()32 HloModuleConfig ConfigWithoutLayoutAssignment() {
33 HloModuleConfig config;
34 auto debug_options = HloTestBase::GetDebugOptionsForTest();
35 // Disable the layout_assignment pass to use the preassigned layouts;
36 // otherwise, the pass throws away the layouts in the fusion computation.
37 debug_options.add_xla_disable_hlo_passes("layout-assignment");
38 config.set_debug_options(debug_options);
39 return config;
40 }
41 };
42
// Checks that an input fusion whose root is a tuple of slices (including a
// zero-element slice and a single-element slice) compiles to a single kernel
// and produces correct results.
TEST_F(GpuSliceInputFusionTest, InputFusionWithATupleOfSlices) {
  const char *const kHloString = R"(
  HloModule input_fusion_with_a_tuple_of_slices

  fused_computation {
    arg.1 = f16[1024,512]{1,0} parameter(0)
    arg.2 = f16[1024,512]{1,0} parameter(1)
    mul.1 = f16[1024,512]{1,0} multiply(arg.1, arg.2)
    add.1 = f16[1024,512]{1,0} add(mul.1, arg.2)
    slice.1 = f16[512,511]{1,0} slice(arg.1), slice={[512:1024], [1:512]}
    slice.2 = f16[0,512]{1,0} slice(add.1), slice={[512:512], [0:512]}
    slice.3 = f16[1,1]{1,0} slice(add.1), slice={[512:513], [511:512]}
    ROOT tuple.1 = (f16[512,511]{1,0}, f16[0,512]{1,0}, f16[1,1]{1,0})
        tuple(slice.1, slice.2, slice.3)
  }

  ENTRY kernel_entry {
    arg.1 = f16[1024,512]{1,0} parameter(0)
    arg.2 = f16[1024,512]{1,0} parameter(1)
    ROOT fusion = (f16[512,511]{1,0}, f16[0,512]{1,0}, f16[1,1]{1,0})
        fusion(arg.1, arg.2), kind=kInput, calls=fused_computation
  })";

  auto hlo_module =
      ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment())
          .ValueOrDie();
  // The kernel entry point differs between ROCm (amdgpu_kernel calling
  // convention) and CUDA; the `slice2` check ensures the emitter produced code
  // for the third slice within the same fused kernel.
  auto expected_ir = is_built_with_rocm_ ? R"(
; CHECK-LABEL: define amdgpu_kernel void @fusion
; CHECK: slice2
; CHECK: }
)"
                                         : R"(
; CHECK-LABEL: define void @fusion
; CHECK: slice2
; CHECK: }
)";
  CompileAndVerifyIr(std::move(hlo_module), expected_ir,
                     /*match_optimized_ir=*/false);
  // Check that the kernel runs correctly.
  EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0, 0}));
}
84
// Checks codegen for a fusion that concatenates two arrays and then splits the
// result back apart with slices (the last slice being empty); verifies a
// single fused kernel is emitted and the result is numerically exact.
TEST_F(GpuSliceInputFusionTest, ConcatThenSplit) {
  const char *const kHloString = R"(
  HloModule input_fusion_with_a_tuple_of_slices

  fused_computation {
    arg.1 = f16[1024]{0} parameter(0)
    arg.2 = f16[1024]{0} parameter(1)
    arg.3 = f16[1023]{0} parameter(2)
    arg.4 = f16[1023]{0} parameter(3)
    mul.1 = f16[1024]{0} multiply(arg.1, arg.2)
    add.1 = f16[1023]{0} add(arg.3, arg.4)
    concat.1 = f16[2047]{0} concatenate(mul.1, add.1), dimensions={0}
    slice.1 = f16[1024]{0} slice(concat.1), slice={[0:1024]}
    slice.2 = f16[1023]{0} slice(concat.1), slice={[1024:2047]}
    slice.3 = f16[0]{0} slice(concat.1), slice={[2047:2047]}
    ROOT tuple.1 = (f16[1024]{0}, f16[1023]{0}, f16[0]{0})
        tuple(slice.1, slice.2, slice.3)
  }

  ENTRY kernel_entry {
    arg.1 = f16[1024]{0} parameter(0)
    arg.2 = f16[1024]{0} parameter(1)
    arg.3 = f16[1023]{0} parameter(2)
    arg.4 = f16[1023]{0} parameter(3)
    ROOT fusion = (f16[1024]{0}, f16[1023]{0}, f16[0]{0})
        fusion(arg.1, arg.2, arg.3, arg.4), kind=kInput, calls=fused_computation
  })";

  auto hlo_module =
      ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment())
          .ValueOrDie();
  // The kernel entry point differs between ROCm (amdgpu_kernel calling
  // convention) and CUDA; the `slice2` check ensures the emitter produced code
  // for the third slice within the same fused kernel.
  auto expected_ir = is_built_with_rocm_ ? R"(
; CHECK-LABEL: define amdgpu_kernel void @fusion
; CHECK: slice2
; CHECK: }
)"
                                         : R"(
; CHECK-LABEL: define void @fusion
; CHECK: slice2
; CHECK: }
)";
  CompileAndVerifyIr(std::move(hlo_module), expected_ir,
                     /*match_optimized_ir=*/false);
  // Check that the kernel runs correctly.
  EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0, 0}));
}
131
}  // namespace
}  // namespace gpu
}  // namespace xla