1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/dynamic_index_splitter.h"
17
18 #include "tensorflow/compiler/xla/service/hlo_computation.h"
19 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
20 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
21 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
22 #include "tensorflow/compiler/xla/test.h"
23 #include "tensorflow/compiler/xla/test_helpers.h"
24 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
25
26 namespace xla {
27 namespace {
28
29 namespace op = xla::testing::opcode_matchers;
30 class DynamicIndexSplitterTest : public HloTestBase {};
31
// Checks that a dynamic-slice whose start indices arrive as one s32[3] vector
// is rewritten so each start index becomes its own operand of the form
// reshape(slice(indices)). The i-th slice must extract element [i, i+1) of the
// index vector, and each resulting index must be a rank-0 scalar.
TEST_F(DynamicIndexSplitterTest, DynamicSlice) {
  const char* const kDynamicSlice = R"(
HloModule DynamicSlice_module

ENTRY entry (operand: s32[4,5,6], indices: s32[3]) -> s32[1,1,1] {
operand = s32[4,5,6] parameter(0)
indices = s32[3] parameter(1)
ROOT dynamic-slice = s32[1,1,1] dynamic-slice(operand, indices), dynamic_slice_sizes={1,1,1}
}
)";

  // Turn on xla_allow_scalar_index_dynamic_ops for this module. Presumably
  // this is needed so the scalar-index form produced by the pass is accepted.
  HloModuleConfig config;
  DebugOptions debug_options = config.debug_options();
  debug_options.set_xla_allow_scalar_index_dynamic_ops(true);
  config.set_debug_options(debug_options);

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseAndReturnVerifiedModule(kDynamicSlice, config));
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          DynamicIndexSplitter().Run(module.get()));
  EXPECT_TRUE(changed);

  const HloInstruction* root = module->entry_computation()->root_instruction();
  ASSERT_THAT(root, op::DynamicSlice(op::Parameter(0),
                                     op::Reshape(op::Slice(op::Parameter(1))),
                                     op::Reshape(op::Slice(op::Parameter(1))),
                                     op::Reshape(op::Slice(op::Parameter(1)))));

  // Operand 0 is the sliced array; operands 1..3 are the split start indices.
  for (int i = 0; i < 3; ++i) {
    const HloInstruction* index = root->operand(i + 1);
    // op::Reshape above does not constrain the result shape, and the module
    // is not re-verified after the pass. The point of the split is to replace
    // the index vector with scalars, so require rank 0 explicitly.
    EXPECT_EQ(index->shape().dimensions_size(), 0);

    const HloInstruction* slice = index->operand(0);
    EXPECT_EQ(slice->slice_starts(0), i);
    EXPECT_EQ(slice->slice_limits(0), i + 1);
  }
}
68
// Checks that a dynamic-update-slice whose start indices arrive as one s32[3]
// vector is rewritten so each start index becomes its own operand of the form
// reshape(slice(indices)). The i-th slice must extract element [i, i+1) of the
// index vector, and each resulting index must be a rank-0 scalar.
TEST_F(DynamicIndexSplitterTest, DynamicUpdateSlice) {
  const char* const kDynamicUpdateSlice = R"(
HloModule DynamicUpdatedSlice_module

ENTRY entry (operand: s32[4,5,6], indices: s32[3], update: s32[1,1,1]) -> s32[4,5,6] {
operand = s32[4,5,6] parameter(0)
indices = s32[3] parameter(1)
update = s32[1,1,1] parameter(2)
ROOT dynamic-update-slice = s32[4,5,6] dynamic-update-slice(operand, update, indices)
}
)";

  // Turn on xla_allow_scalar_index_dynamic_ops for this module. Presumably
  // this is needed so the scalar-index form produced by the pass is accepted.
  HloModuleConfig config;
  DebugOptions debug_options = config.debug_options();
  debug_options.set_xla_allow_scalar_index_dynamic_ops(true);
  config.set_debug_options(debug_options);

  TF_ASSERT_OK_AND_ASSIGN(
      auto module, ParseAndReturnVerifiedModule(kDynamicUpdateSlice, config));
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          DynamicIndexSplitter().Run(module.get()));
  EXPECT_TRUE(changed);

  const HloInstruction* root = module->entry_computation()->root_instruction();
  ASSERT_THAT(root,
              op::DynamicUpdateSlice(op::Parameter(0), op::Parameter(2),
                                     op::Reshape(op::Slice(op::Parameter(1))),
                                     op::Reshape(op::Slice(op::Parameter(1))),
                                     op::Reshape(op::Slice(op::Parameter(1)))));

  // Operand 0 is the array being updated and operand 1 is the update;
  // operands 2..4 are the split start indices.
  for (int i = 0; i < 3; ++i) {
    const HloInstruction* index = root->operand(i + 2);
    // op::Reshape above does not constrain the result shape, and the module
    // is not re-verified after the pass. The point of the split is to replace
    // the index vector with scalars, so require rank 0 explicitly.
    EXPECT_EQ(index->shape().dimensions_size(), 0);

    const HloInstruction* slice = index->operand(0);
    EXPECT_EQ(slice->slice_starts(0), i);
    EXPECT_EQ(slice->slice_limits(0), i + 1);
  }
}
106
// When a dynamic-slice already receives its start indices as separate scalar
// parameters, there is nothing to split. The pass must report no change, and
// the root must still consume parameters 0..3 directly.
TEST_F(DynamicIndexSplitterTest, AlreadyScalar) {
  // Module config with xla_allow_scalar_index_dynamic_ops enabled, which the
  // scalar-index input HLO below relies on.
  HloModuleConfig module_config;
  {
    DebugOptions opts = module_config.debug_options();
    opts.set_xla_allow_scalar_index_dynamic_ops(true);
    module_config.set_debug_options(opts);
  }

  const char* const kScalarIndexHlo = R"(
HloModule DynamicSlice_module

ENTRY entry (operand: s32[4,5,6], index.0: s32[], index.1: s32[], index.2: s32[]) -> s32[1,1,1] {
operand = s32[4,5,6] parameter(0)
index.0 = s32[] parameter(1)
index.1 = s32[] parameter(2)
index.2 = s32[] parameter(3)
ROOT dynamic-slice = s32[1,1,1] dynamic-slice(operand, index.0, index.1, index.2), dynamic_slice_sizes={1,1,1}
}
)";

  TF_ASSERT_OK_AND_ASSIGN(
      auto hlo_module,
      ParseAndReturnVerifiedModule(kScalarIndexHlo, module_config));
  TF_ASSERT_OK_AND_ASSIGN(bool modified,
                          DynamicIndexSplitter().Run(hlo_module.get()));
  EXPECT_FALSE(modified);

  const HloInstruction* root =
      hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::DynamicSlice(op::Parameter(0), op::Parameter(1),
                                     op::Parameter(2), op::Parameter(3)));
}
134
135 } // namespace
136 } // namespace xla
137