• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/dynamic_index_splitter.h"
17 
18 #include "tensorflow/compiler/xla/service/hlo_computation.h"
19 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
20 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
21 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
22 #include "tensorflow/compiler/xla/test.h"
23 #include "tensorflow/compiler/xla/test_helpers.h"
24 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
25 
26 namespace xla {
27 namespace {
28 
29 namespace op = xla::testing::opcode_matchers;
30 class DynamicIndexSplitterTest : public HloTestBase {};
31 
// Checks that DynamicIndexSplitter rewrites a dynamic-slice whose start
// indices are packed into one rank-1 operand (s32[3]) into a dynamic-slice
// that takes one index operand per dimension of the sliced operand. Each new
// index operand must be reshape(slice(indices)), where the i-th slice selects
// element [i, i+1) with stride 1, and the reshape yields a rank-0 scalar.
TEST_F(DynamicIndexSplitterTest, DynamicSlice) {
  const char* const kDynamicSlice = R"(
    HloModule DynamicSlice_module

    ENTRY entry (operand: s32[4,5,6], indices: s32[3]) -> s32[1,1,1] {
      operand = s32[4,5,6] parameter(0)
      indices = s32[3] parameter(1)
      ROOT dynamic-slice = s32[1,1,1] dynamic-slice(operand, indices), dynamic_slice_sizes={1,1,1}
    }
  )";

  // Enable scalar-index dynamic ops, the form this pass produces (presumably
  // required for the rewritten module to verify -- confirm against the flag).
  HloModuleConfig config;
  DebugOptions debug_options = config.debug_options();
  debug_options.set_xla_allow_scalar_index_dynamic_ops(true);
  config.set_debug_options(debug_options);

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kDynamicSlice, config));
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          DynamicIndexSplitter().Run(module.get()));
  EXPECT_TRUE(changed);

  const HloInstruction* root = module->entry_computation()->root_instruction();
  ASSERT_THAT(root, op::DynamicSlice(op::Parameter(0),
                                     op::Reshape(op::Slice(op::Parameter(1))),
                                     op::Reshape(op::Slice(op::Parameter(1))),
                                     op::Reshape(op::Slice(op::Parameter(1)))));

  // One scalar index per dimension of the s32[4,5,6] operand. The sliced
  // operand is operand 0, so the split indices start at operand 1.
  constexpr int kNumIndices = 3;
  for (int i = 0; i < kNumIndices; ++i) {
    const HloInstruction* index = root->operand(i + 1);
    // The whole point of the pass: every start index is now a scalar.
    EXPECT_EQ(index->shape().dimensions_size(), 0);

    const HloInstruction* slice = index->operand(0);
    EXPECT_EQ(slice->slice_starts(0), i);
    EXPECT_EQ(slice->slice_limits(0), i + 1);
    EXPECT_EQ(slice->slice_strides(0), 1);
  }
}
68 
// Checks that DynamicIndexSplitter rewrites a dynamic-update-slice whose start
// indices are packed into one rank-1 operand (s32[3]) into a
// dynamic-update-slice taking one index operand per dimension. Each new index
// operand must be reshape(slice(indices)), where the i-th slice selects
// element [i, i+1) with stride 1, and the reshape yields a rank-0 scalar.
TEST_F(DynamicIndexSplitterTest, DynamicUpdateSlice) {
  const char* const kDynamicUpdateSlice = R"(
    HloModule DynamicUpdatedSlice_module

    ENTRY entry (operand: s32[4,5,6], indices: s32[3], update: s32[1,1,1]) -> s32[4,5,6] {
      operand = s32[4,5,6] parameter(0)
      indices = s32[3] parameter(1)
      update = s32[1,1,1] parameter(2)
      ROOT dynamic-update-slice = s32[4,5,6] dynamic-update-slice(operand, update, indices)
    }
  )";

  // Enable scalar-index dynamic ops, the form this pass produces (presumably
  // required for the rewritten module to verify -- confirm against the flag).
  HloModuleConfig config;
  DebugOptions debug_options = config.debug_options();
  debug_options.set_xla_allow_scalar_index_dynamic_ops(true);
  config.set_debug_options(debug_options);

  TF_ASSERT_OK_AND_ASSIGN(
      auto module, ParseAndReturnVerifiedModule(kDynamicUpdateSlice, config));
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          DynamicIndexSplitter().Run(module.get()));
  EXPECT_TRUE(changed);

  const HloInstruction* root = module->entry_computation()->root_instruction();
  ASSERT_THAT(root,
              op::DynamicUpdateSlice(op::Parameter(0), op::Parameter(2),
                                     op::Reshape(op::Slice(op::Parameter(1))),
                                     op::Reshape(op::Slice(op::Parameter(1))),
                                     op::Reshape(op::Slice(op::Parameter(1)))));

  // One scalar index per dimension of the s32[4,5,6] operand. Operands 0 and 1
  // are the operand and the update, so the split indices start at operand 2.
  constexpr int kNumIndices = 3;
  for (int i = 0; i < kNumIndices; ++i) {
    const HloInstruction* index = root->operand(i + 2);
    // The whole point of the pass: every start index is now a scalar.
    EXPECT_EQ(index->shape().dimensions_size(), 0);

    const HloInstruction* slice = index->operand(0);
    EXPECT_EQ(slice->slice_starts(0), i);
    EXPECT_EQ(slice->slice_limits(0), i + 1);
    EXPECT_EQ(slice->slice_strides(0), 1);
  }
}
106 
// When a dynamic-slice already takes one scalar start index per dimension,
// DynamicIndexSplitter has nothing to split. The pass must report "no change",
// and the root must still consume the original parameters directly.
TEST_F(DynamicIndexSplitterTest, AlreadyScalar) {
  const char* const kHloText = R"(
    HloModule DynamicSlice_module

    ENTRY entry (operand: s32[4,5,6], index.0: s32[], index.1: s32[], index.2: s32[]) -> s32[1,1,1] {
      operand = s32[4,5,6] parameter(0)
      index.0 = s32[] parameter(1)
      index.1 = s32[] parameter(2)
      index.2 = s32[] parameter(3)
      ROOT dynamic-slice = s32[1,1,1] dynamic-slice(operand, index.0, index.1, index.2), dynamic_slice_sizes={1,1,1}
    }
  )";

  // Module config with scalar-index dynamic ops allowed.
  HloModuleConfig module_config;
  {
    DebugOptions opts = module_config.debug_options();
    opts.set_xla_allow_scalar_index_dynamic_ops(true);
    module_config.set_debug_options(opts);
  }

  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseAndReturnVerifiedModule(kHloText, module_config));

  DynamicIndexSplitter splitter;
  TF_ASSERT_OK_AND_ASSIGN(bool modified, splitter.Run(hlo_module.get()));
  EXPECT_FALSE(modified);

  const HloInstruction* root =
      hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::DynamicSlice(op::Parameter(0), op::Parameter(1),
                                     op::Parameter(2), op::Parameter(3)));
}
134 
135 }  // namespace
136 }  // namespace xla
137