/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <utility>

#include "tensorflow/compiler/xla/service/gpu/tests/gpu_codegen_test.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/hlo_parser.h"
#include "tensorflow/compiler/xla/tests/hlo_test_base.h"
#include "tensorflow/core/platform/test.h"

namespace xla {
namespace gpu {
namespace {

class GpuSliceInputFusionTest : public GpuCodegenTest {
 protected:
  GpuSliceInputFusionTest() {}

  HloModuleConfig ConfigWithoutLayoutAssignment() {
    HloModuleConfig config;
    auto debug_options = HloTestBase::GetDebugOptionsForTest();
    // Disable the layout_assignment pass to use the preassigned layouts;
    // otherwise, the pass throws away the layouts in the fusion computation.
    debug_options.add_xla_disable_hlo_passes("layout-assignment");
    config.set_debug_options(debug_options);
    return config;
  }
};

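// Fuses two elementwise ops (multiply, add) whose results feed a tuple of
// three slices, including a zero-element slice and a single-element slice,
// then checks both the emitted kernel IR and the numerical result of running
// the fused kernel.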
TEST_F(GpuSliceInputFusionTest, InputFusionWithATupleOfSlices) {
  const char *const kHloString = R"(
  HloModule input_fusion_with_a_tuple_of_slices

  fused_computation {
    arg.1 = f16[1024,512]{1,0} parameter(0)
    arg.2 = f16[1024,512]{1,0} parameter(1)
    mul.1 = f16[1024,512]{1,0} multiply(arg.1, arg.2)
    add.1 = f16[1024,512]{1,0} add(mul.1, arg.2)
    slice.1 = f16[512,511]{1,0} slice(arg.1), slice={[512:1024], [1:512]}
    slice.2 = f16[0,512]{1,0} slice(add.1), slice={[512:512], [0:512]}
    slice.3 = f16[1,1]{1,0} slice(add.1), slice={[512:513], [511:512]}
    ROOT tuple.1 = (f16[512,511]{1,0}, f16[0,512]{1,0}, f16[1,1]{1,0})
        tuple(slice.1, slice.2, slice.3)
  }

  ENTRY kernel_entry {
    arg.1 = f16[1024,512]{1,0} parameter(0)
    arg.2 = f16[1024,512]{1,0} parameter(1)
    ROOT fusion = (f16[512,511]{1,0}, f16[0,512]{1,0}, f16[1,1]{1,0})
        fusion(arg.1, arg.2), kind=kInput, calls=fused_computation
  })";

  auto hlo_module =
      ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment())
          .ValueOrDie();
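  // The ROCm build lowers GPU kernels with LLVM's amdgpu_kernel calling
  // convention, so the CHECK-LABEL differs per platform. The "slice2"
  // pattern matches that name in the kernel body, presumably the IR emitted
  // for the third slice in the output tuple.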
  auto expected_ir = is_built_with_rocm_ ? R"(
; CHECK-LABEL: define amdgpu_kernel void @fusion
; CHECK: slice2
; CHECK: }
)"
                                         : R"(
; CHECK-LABEL: define void @fusion
; CHECK: slice2
; CHECK: }
)";
  CompileAndVerifyIr(std::move(hlo_module), expected_ir,
                     /*match_optimized_ir=*/false);
  // Check that the kernel runs correctly.
  EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0, 0}));
}

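// Concatenates two elementwise results and then splits the concatenation back
// apart with slices (the last slice is empty), again fused into a single
// input-fusion kernel with a tuple root.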
TEST_F(GpuSliceInputFusionTest, ConcatThenSplit) {
  const char *const kHloString = R"(
  HloModule input_fusion_with_a_tuple_of_slices

  fused_computation {
    arg.1 = f16[1024]{0} parameter(0)
    arg.2 = f16[1024]{0} parameter(1)
    arg.3 = f16[1023]{0} parameter(2)
    arg.4 = f16[1023]{0} parameter(3)
    mul.1 = f16[1024]{0} multiply(arg.1, arg.2)
    add.1 = f16[1023]{0} add(arg.3, arg.4)
    concat.1 = f16[2047]{0} concatenate(mul.1, add.1), dimensions={0}
    slice.1 = f16[1024]{0} slice(concat.1), slice={[0:1024]}
    slice.2 = f16[1023]{0} slice(concat.1), slice={[1024:2047]}
    slice.3 = f16[0]{0} slice(concat.1), slice={[2047:2047]}
    ROOT tuple.1 = (f16[1024]{0}, f16[1023]{0}, f16[0]{0})
        tuple(slice.1, slice.2, slice.3)
  }

  ENTRY kernel_entry {
    arg.1 = f16[1024]{0} parameter(0)
    arg.2 = f16[1024]{0} parameter(1)
    arg.3 = f16[1023]{0} parameter(2)
    arg.4 = f16[1023]{0} parameter(3)
    ROOT fusion = (f16[1024]{0}, f16[1023]{0}, f16[0]{0})
        fusion(arg.1, arg.2, arg.3, arg.4), kind=kInput, calls=fused_computation
  })";

  auto hlo_module =
      ParseAndReturnVerifiedModule(kHloString, ConfigWithoutLayoutAssignment())
          .ValueOrDie();
  auto expected_ir = is_built_with_rocm_ ? R"(
; CHECK-LABEL: define amdgpu_kernel void @fusion
; CHECK: slice2
; CHECK: }
)"
                                         : R"(
; CHECK-LABEL: define void @fusion
; CHECK: slice2
; CHECK: }
)";
  CompileAndVerifyIr(std::move(hlo_module), expected_ir,
                     /*match_optimized_ir=*/false);
  // Check that the kernel runs correctly.
  EXPECT_TRUE(RunAndCompareNoHloPasses(kHloString, ErrorSpec{0, 0}));
}

}  // namespace
}  // namespace gpu
}  // namespace xla