1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/dynamic_index_splitter.h"
17
18 #include "tensorflow/compiler/xla/service/hlo_computation.h"
19 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
20 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
21 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
22 #include "tensorflow/compiler/xla/test.h"
23 #include "tensorflow/compiler/xla/test_helpers.h"
24 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
25
26 namespace xla {
27 namespace {
28
29 namespace op = xla::testing::opcode_matchers;
30 class DynamicIndexSplitterTest : public HloTestBase {};
31
// Checks that a dynamic-slice whose start indices arrive as a single s32[3]
// operand is rewritten by DynamicIndexSplitter into a dynamic-slice with three
// separate index operands. Each index operand must be
// reshape(slice(indices)) picking out one element, and must be a scalar.
TEST_F(DynamicIndexSplitterTest, DynamicSlice) {
  const char* const kDynamicSlice = R"(
HloModule DynamicSlice_module

ENTRY entry (operand: s32[4,5,6], indices: s32[3]) -> s32[1,1,1] {
  operand = s32[4,5,6] parameter(0)
  indices = s32[3] parameter(1)
  ROOT dynamic-slice = s32[1,1,1] dynamic-slice(operand, indices), dynamic_slice_sizes={1,1,1}
}
)";

  // Allow scalar-index dynamic ops in this module's config. Presumably the
  // rewritten (scalar-index) form is only legal with this flag set; confirm
  // against the pass if it changes.
  HloModuleConfig config;
  DebugOptions debug_options = config.debug_options();
  debug_options.set_xla_allow_scalar_index_dynamic_ops(true);
  config.set_debug_options(debug_options);

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kDynamicSlice, config));
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          DynamicIndexSplitter().Run(module.get()));
  EXPECT_TRUE(changed);

  const HloInstruction* root = module->entry_computation()->root_instruction();
  ASSERT_THAT(root, op::DynamicSlice(op::Parameter(0),
                                     op::Reshape(op::Slice(op::Parameter(1))),
                                     op::Reshape(op::Slice(op::Parameter(1))),
                                     op::Reshape(op::Slice(op::Parameter(1)))));

  // Operand 0 is the sliced array; the split start indices follow it.
  for (int i = 0; i < 3; ++i) {
    const HloInstruction* index = root->operand(i + 1);
    // The whole point of the split is a scalar (rank-0) index per dimension.
    EXPECT_EQ(index->shape().dimensions_size(), 0);

    // The i-th index is taken as elements [i, i + 1) of the indices vector.
    const HloInstruction* slice = index->operand(0);
    EXPECT_EQ(slice->slice_starts(0), i);
    EXPECT_EQ(slice->slice_limits(0), i + 1);
    EXPECT_EQ(slice->slice_strides(0), 1);
  }
}
67
// Checks that a dynamic-update-slice whose start indices arrive as a single
// s32[3] operand is rewritten by DynamicIndexSplitter into a
// dynamic-update-slice with three separate index operands. Each index operand
// must be reshape(slice(indices)) picking out one element, and must be a
// scalar.
TEST_F(DynamicIndexSplitterTest, DynamicUpdateSlice) {
  const char* const kDynamicUpdateSlice = R"(
HloModule DynamicUpdateSlice_module

ENTRY entry (operand: s32[4,5,6], indices: s32[3], update: s32[1,1,1]) -> s32[4,5,6] {
  operand = s32[4,5,6] parameter(0)
  indices = s32[3] parameter(1)
  update = s32[1,1,1] parameter(2)
  ROOT dynamic-update-slice = s32[4,5,6] dynamic-update-slice(operand, update, indices)
}
)";

  // Allow scalar-index dynamic ops in this module's config. Presumably the
  // rewritten (scalar-index) form is only legal with this flag set; confirm
  // against the pass if it changes.
  HloModuleConfig config;
  DebugOptions debug_options = config.debug_options();
  debug_options.set_xla_allow_scalar_index_dynamic_ops(true);
  config.set_debug_options(debug_options);

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseHloString(kDynamicUpdateSlice, config));
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          DynamicIndexSplitter().Run(module.get()));
  EXPECT_TRUE(changed);

  const HloInstruction* root = module->entry_computation()->root_instruction();
  ASSERT_THAT(root,
              op::DynamicUpdateSlice(op::Parameter(0), op::Parameter(2),
                                     op::Reshape(op::Slice(op::Parameter(1))),
                                     op::Reshape(op::Slice(op::Parameter(1))),
                                     op::Reshape(op::Slice(op::Parameter(1)))));

  // Operands 0 and 1 are the target array and the update; the split start
  // indices follow them.
  for (int i = 0; i < 3; ++i) {
    const HloInstruction* index = root->operand(i + 2);
    // The whole point of the split is a scalar (rank-0) index per dimension.
    EXPECT_EQ(index->shape().dimensions_size(), 0);

    // The i-th index is taken as elements [i, i + 1) of the indices vector.
    const HloInstruction* slice = index->operand(0);
    EXPECT_EQ(slice->slice_starts(0), i);
    EXPECT_EQ(slice->slice_limits(0), i + 1);
    EXPECT_EQ(slice->slice_strides(0), 1);
  }
}
105
// When every start index of a dynamic-slice is already its own scalar operand,
// DynamicIndexSplitter must leave the module alone. Run() reports no change,
// and the root still consumes the original parameters directly.
TEST_F(DynamicIndexSplitterTest, AlreadyScalar) {
  const char* const kHloText = R"(
HloModule DynamicSlice_module

ENTRY entry (operand: s32[4,5,6], index.0: s32[], index.1: s32[], index.2: s32[]) -> s32[1,1,1] {
  operand = s32[4,5,6] parameter(0)
  index.0 = s32[] parameter(1)
  index.1 = s32[] parameter(2)
  index.2 = s32[] parameter(3)
  ROOT dynamic-slice = s32[1,1,1] dynamic-slice(operand, index.0, index.1, index.2), dynamic_slice_sizes={1,1,1}
}
)";

  // Same module config as the other tests: scalar-index dynamic ops allowed.
  HloModuleConfig module_config;
  DebugOptions options = module_config.debug_options();
  options.set_xla_allow_scalar_index_dynamic_ops(true);
  module_config.set_debug_options(options);

  TF_ASSERT_OK_AND_ASSIGN(auto hlo_module,
                          ParseHloString(kHloText, module_config));

  DynamicIndexSplitter splitter;
  TF_ASSERT_OK_AND_ASSIGN(bool modified, splitter.Run(hlo_module.get()));
  EXPECT_FALSE(modified);

  const HloInstruction* entry_root =
      hlo_module->entry_computation()->root_instruction();
  EXPECT_THAT(entry_root,
              op::DynamicSlice(op::Parameter(0), op::Parameter(1),
                               op::Parameter(2), op::Parameter(3)));
}
132
133 } // namespace
134 } // namespace xla
135