• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/gather_expander.h"
17 
18 #include "tensorflow/compiler/xla/service/hlo_query.h"
19 #include "tensorflow/compiler/xla/test.h"
20 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
21 #include "tensorflow/compiler/xla/tests/test_macros.h"
22 
23 namespace xla {
24 namespace {
25 
26 using GatherExpanderTest = HloTestBase;
27 
// Running the expander on a gather with too many indices must fail with
// UNIMPLEMENTED and an explanatory message rather than succeed.
// The indices operand is s32[2147483647,5] and index_vector_dim (2) equals its
// rank, so every scalar in it is its own gather index: 2147483647 * 5 indices,
// which exceeds the limit quoted in the expected error message.
TEST_F(GatherExpanderTest, ErrorStatusOnTooManyIndices) {
  const string hlo_text = R"(
HloModule TensorFlowGatherMultipleBatchDims

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2147483647,5] parameter(1)
  ROOT gather = s32[2147483647,3,5] gather(operand, indices),
      offset_dims={1},
      collapsed_slice_dims={1},
      start_index_map={1},
      index_vector_dim=2,
      slice_sizes={3, 1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));

  // Run the pass directly (not through RunHloPass) so that only the returned
  // status is inspected.
  GatherExpander expander(GatherExpander::kEliminateAllGathers);
  const auto run_result = expander.Run(module.get());
  const Status& status = run_result.status();
  EXPECT_EQ(status.code(), tensorflow::error::UNIMPLEMENTED);

  ASSERT_THAT(
      status.error_message(),
      ::testing::HasSubstr("Gather operations with more than 2147483647 gather "
                           "indices are not supported."));
}
56 
// Expanding this TensorFlow-style gather (rank-1 indices, index_vector_dim=1)
// must yield exactly one while loop in the entry computation. The loop's state
// tuple must be free of degenerate dimensions.
TEST_F(GatherExpanderTest, AvoidDegenerateDims) {
  const string hlo_text = R"(
HloModule TensorFlowGatherV2

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  ROOT gather = s32[3,2] gather(operand, indices),
      offset_dims={0},
      collapsed_slice_dims={1},
      start_index_map={1},
      index_vector_dim=1,
      slice_sizes={3, 1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  GatherExpander expander(GatherExpander::kEliminateAllGathers);
  TF_ASSERT_OK_AND_ASSIGN(bool changed, expander.Run(module.get()));
  ASSERT_TRUE(changed);

  // Find the unique kWhile in the entry computation. Seeing a second one, or
  // none at all, fails the test.
  HloInstruction* while_instr = nullptr;
  for (HloInstruction* instr : module->entry_computation()->instructions()) {
    if (instr->opcode() != HloOpcode::kWhile) {
      continue;
    }
    ASSERT_EQ(while_instr, nullptr)
        << "Expected exactly one while instruction in the entry computation "
           "after gather expansion";
    while_instr = instr;
  }
  ASSERT_NE(while_instr, nullptr)
      << "Expected exactly one while instruction in the entry computation "
         "after gather expansion";

  // For TF-style gathers we want to avoid creating a while loop whose state
  // shapes contain degenerate dimensions. Here the loop state is expected to
  // have shape (sNN[], s32[3,3]{1,0}, s32[2]{0}, s32[2,3]{1,0}).
  // The leading sNN element is an implementation detail of
  // WhileUtil::MakeCountedLoop, so it is not checked. In principle the layout
  // of the whole loop state is such an implementation detail too.
  const Shape& loop_state = while_instr->shape();
  ASSERT_TRUE(loop_state.IsTuple());
  ASSERT_EQ(ShapeUtil::TupleElementCount(loop_state), 4);

  auto state_element = [&loop_state](int index) -> const Shape& {
    return ShapeUtil::GetTupleElementShape(loop_state, index);
  };

  EXPECT_TRUE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {3, 3}),
                                        state_element(1)));
  EXPECT_TRUE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {2}),
                                        state_element(2)));
  EXPECT_TRUE(ShapeUtil::SameDimensions(ShapeUtil::MakeShape(S32, {2, 3}),
                                        state_element(3)));
}
116 
// After the expander replaces the gather, the resulting while loop must carry
// the op_name metadata ("Gather") that was attached to the original gather.
TEST_F(GatherExpanderTest, CheckOpMetadata) {
  const string hlo_text = R"(
HloModule TensorFlowGatherV2

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  ROOT gather = s32[3,2] gather(operand, indices),
      offset_dims={0},
      collapsed_slice_dims={1},
      start_index_map={1},
      index_vector_dim=1,
      slice_sizes={3, 1}
}
)";
  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));

  // Tag the gather (the entry root) so the tag can be looked for afterwards.
  HloInstruction* gather = module->entry_computation()->root_instruction();
  OpMetadata gather_metadata;
  gather_metadata.set_op_name("Gather");
  gather->set_metadata(gather_metadata);

  GatherExpander expander(GatherExpander::kEliminateAllGathers);
  TF_ASSERT_OK_AND_ASSIGN(bool changed, expander.Run(module.get()));
  ASSERT_TRUE(changed);

  // Find the unique kWhile in the entry computation.
  HloInstruction* while_instr = nullptr;
  for (HloInstruction* instr : module->entry_computation()->instructions()) {
    if (instr->opcode() != HloOpcode::kWhile) {
      continue;
    }
    ASSERT_EQ(while_instr, nullptr)
        << "Expected exactly one while instruction in the entry computation "
           "after gather expansion";
    while_instr = instr;
  }
  ASSERT_NE(while_instr, nullptr)
      << "Expected exactly one while instruction in the entry computation "
         "after gather expansion";

  EXPECT_EQ(while_instr->metadata().op_name(), "Gather");
}
157 
// In kEliminateSimpleGathers mode the expander must leave this non-trivial
// gather untouched, so the pass reports that nothing changed. The gather
// selects two full rows of a 3x3 operand.
TEST_F(GatherExpanderTest, EliminateSimpleGathersSkipsNontrivialGather) {
  const string hlo_text = R"(
HloModule TensorFlowGatherV1

ENTRY main {
  operand = s32[3,3] parameter(0)
  indices = s32[2] parameter(1)
  ROOT gather = s32[2,3] gather(operand, indices),
      offset_dims={1},
      collapsed_slice_dims={0},
      start_index_map={0},
      index_vector_dim=1,
      slice_sizes={1, 3}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  GatherExpander simple_only_pass(GatherExpander::kEliminateSimpleGathers);
  TF_ASSERT_OK_AND_ASSIGN(bool modified,
                          RunHloPass(&simple_only_pass, module.get()));
  ASSERT_FALSE(modified);
}
180 
// In kEliminateSimpleGathers mode the expander must still rewrite a trivial
// gather: a single start index (indices s32[1], index_vector_dim=0) that
// selects one 10-element slice of a 100-element operand. Afterwards no kGather
// may remain in the entry computation.
TEST_F(GatherExpanderTest, EliminateSimpleGathersRewritesTrivialGather) {
  const string hlo_text = R"(
HloModule test

ENTRY main {
  operand = s32[100] parameter(0)
  indices = s32[1] parameter(1)
  ROOT gather = s32[10] gather(operand, indices),
      offset_dims={0},
      collapsed_slice_dims={},
      start_index_map={0},
      index_vector_dim=0,
      slice_sizes={10}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
                          ParseAndReturnVerifiedModule(hlo_text));
  // The mode under test is kEliminateSimpleGathers, as the test name says.
  // kEliminateAllGathers targets every gather, so with that mode this test
  // would pass whether or not the "simple" mode recognizes trivial gathers.
  GatherExpander pass(GatherExpander::kEliminateSimpleGathers);
  TF_ASSERT_OK_AND_ASSIGN(bool changed, RunHloPass(&pass, module.get()));
  ASSERT_TRUE(changed);
  ASSERT_FALSE(hlo_query::ContainsInstrWithOpcode(module->entry_computation(),
                                                  {HloOpcode::kGather}));
}
205 
206 }  // namespace
207 }  // namespace xla
208