• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/dynamic_index_splitter.h"
17 
18 #include "tensorflow/compiler/xla/service/hlo_computation.h"
19 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
20 #include "tensorflow/compiler/xla/service/hlo_matchers.h"
21 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
22 #include "tensorflow/compiler/xla/test.h"
23 #include "tensorflow/compiler/xla/test_helpers.h"
24 #include "tensorflow/compiler/xla/tests/hlo_test_base.h"
25 
26 namespace xla {
27 namespace {
28 
29 namespace op = xla::testing::opcode_matchers;
30 class DynamicIndexSplitterTest : public HloTestBase {};
31 
// A dynamic-slice whose start indices all come from a single s32[3] vector
// should be rewritten so that each start index is a separate operand, built as
// Reshape(Slice(indices)) over the matching element of that vector.
TEST_F(DynamicIndexSplitterTest, DynamicSlice) {
  const char* const kDynamicSlice = R"(
    HloModule DynamicSlice_module

    ENTRY entry (operand: s32[4,5,6], indices: s32[3]) -> s32[1,1,1] {
      operand = s32[4,5,6] parameter(0)
      indices = s32[3] parameter(1)
      ROOT dynamic-slice = s32[1,1,1] dynamic-slice(operand, indices), dynamic_slice_sizes={1,1,1}
    }
  )";

  // Turn on scalar-index dynamic ops for this module through its debug
  // options before parsing it.
  HloModuleConfig config;
  DebugOptions options = config.debug_options();
  options.set_xla_allow_scalar_index_dynamic_ops(true);
  config.set_debug_options(options);

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kDynamicSlice, config));
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          DynamicIndexSplitter().Run(module.get()));
  EXPECT_TRUE(changed);

  // Every split start index has the same form: a reshaped slice of the
  // original index vector (parameter 1).
  const auto split_index = op::Reshape(op::Slice(op::Parameter(1)));
  const HloInstruction* root = module->entry_computation()->root_instruction();
  ASSERT_THAT(root, op::DynamicSlice(op::Parameter(0), split_index,
                                     split_index, split_index));

  // Operand 0 is the sliced array, so start index k sits at operand k + 1
  // and must read exactly element [k, k + 1) of the index vector.
  for (int k = 0; k < 3; ++k) {
    const HloInstruction* index_slice = root->operand(k + 1)->operand(0);
    EXPECT_EQ(index_slice->slice_starts(0), k);
    EXPECT_EQ(index_slice->slice_limits(0), k + 1);
  }
}
67 
// A dynamic-update-slice whose start indices all come from a single s32[3]
// vector should be rewritten so that each start index becomes its own operand,
// built as Reshape(Slice(indices)) over the matching element of that vector.
TEST_F(DynamicIndexSplitterTest, DynamicUpdateSlice) {
  const char* const kDynamicUpdateSlice = R"(
    HloModule DynamicUpdateSlice_module

    ENTRY entry (operand: s32[4,5,6], indices: s32[3], update: s32[1,1,1]) -> s32[4,5,6] {
      operand = s32[4,5,6] parameter(0)
      indices = s32[3] parameter(1)
      update = s32[1,1,1] parameter(2)
      ROOT dynamic-update-slice = s32[4,5,6] dynamic-update-slice(operand, update, indices)
    }
  )";

  // Turn on scalar-index dynamic ops for this module through its debug
  // options before parsing it.
  HloModuleConfig config;
  DebugOptions options = config.debug_options();
  options.set_xla_allow_scalar_index_dynamic_ops(true);
  config.set_debug_options(options);

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          ParseHloString(kDynamicUpdateSlice, config));
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          DynamicIndexSplitter().Run(module.get()));
  EXPECT_TRUE(changed);

  // Every split start index has the same form: a reshaped slice of the
  // original index vector (parameter 1).
  const auto split_index = op::Reshape(op::Slice(op::Parameter(1)));
  const HloInstruction* root = module->entry_computation()->root_instruction();
  ASSERT_THAT(root,
              op::DynamicUpdateSlice(op::Parameter(0), op::Parameter(2),
                                     split_index, split_index, split_index));

  // Operands 0 and 1 are the target array and the update, so start index k
  // sits at operand k + 2 and must read exactly element [k, k + 1) of the
  // index vector.
  for (int k = 0; k < 3; ++k) {
    const HloInstruction* index_slice = root->operand(k + 2)->operand(0);
    EXPECT_EQ(index_slice->slice_starts(0), k);
    EXPECT_EQ(index_slice->slice_limits(0), k + 1);
  }
}
105 
// When a dynamic-slice already receives one s32[] start index per dimension,
// there is nothing to split. The pass must report no change, and the root must
// still consume parameters 0..3 directly.
TEST_F(DynamicIndexSplitterTest, AlreadyScalar) {
  const char* const kDynamicSlice = R"(
    HloModule DynamicSlice_module

    ENTRY entry (operand: s32[4,5,6], index.0: s32[], index.1: s32[], index.2: s32[]) -> s32[1,1,1] {
      operand = s32[4,5,6] parameter(0)
      index.0 = s32[] parameter(1)
      index.1 = s32[] parameter(2)
      index.2 = s32[] parameter(3)
      ROOT dynamic-slice = s32[1,1,1] dynamic-slice(operand, index.0, index.1, index.2), dynamic_slice_sizes={1,1,1}
    }
  )";

  // Turn on scalar-index dynamic ops for this module through its debug
  // options before parsing it.
  HloModuleConfig config;
  DebugOptions options = config.debug_options();
  options.set_xla_allow_scalar_index_dynamic_ops(true);
  config.set_debug_options(options);

  TF_ASSERT_OK_AND_ASSIGN(auto module, ParseHloString(kDynamicSlice, config));
  TF_ASSERT_OK_AND_ASSIGN(bool changed,
                          DynamicIndexSplitter().Run(module.get()));
  EXPECT_FALSE(changed);

  // The graph must be left exactly as parsed.
  const HloInstruction* root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root, op::DynamicSlice(op::Parameter(0), op::Parameter(1),
                                     op::Parameter(2), op::Parameter(3)));
}
132 
133 }  // namespace
134 }  // namespace xla
135