• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/client/lib/slicing.h"
17 
18 #include "tensorflow/compiler/xla/client/xla_builder.h"
19 #include "tensorflow/compiler/xla/literal_util.h"
20 #include "tensorflow/compiler/xla/test.h"
21 #include "tensorflow/compiler/xla/tests/client_library_test_base.h"
22 #include "tensorflow/compiler/xla/tests/test_macros.h"
23 #include "tensorflow/compiler/xla/types.h"
24 
25 namespace xla {
26 namespace {
27 
28 using SlicingTest = xla::ClientLibraryTestBase;
29 
BValsRight()30 xla::Array2D<float> BValsRight() {
31   return {{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}};
32 }
33 
BValsLeft()34 xla::Array2D<float> BValsLeft() {
35   return {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}, {10, 11, 12}};
36 }
37 
AValsFull()38 xla::Array2D<float> AValsFull() {
39   return {{2, 0, 1, 2}, {3, 6, 0, 1}, {4, 7, 9, 0}, {5, 8, 10, 11}};
40 }
41 
BatchedAValsFull()42 xla::Array3D<float> BatchedAValsFull() {
43   return {{
44               {2, 0, 1, 2},
45               {3, 6, 0, 1},
46               {4, 7, 9, 0},
47               {5, 8, 10, 11},
48           },
49           {
50               {16, 24, 8, 12},
51               {24, 61, 82, 48},
52               {8, 82, 456, 106},
53               {12, 48, 106, 62},
54           }};
55 }
56 
// Checks DynamicSliceInMinorDims on a rank-2 operand: a 1x1 window taken from
// BValsRight() with the runtime start values (2, 1) is expected to hold 10.
XLA_TEST_F(SlicingTest, Simple2dLookup) {
  xla::XlaBuilder builder(TestName());

  // Parameter 0 is the matrix; parameters 1 and 2 are the scalar start
  // values, supplied as execution-time arguments rather than constants.
  xla::XlaOp matrix;
  auto matrix_data =
      CreateR2Parameter<float>(BValsRight(), 0, "a", &builder, &matrix);
  xla::XlaOp row_start;
  auto row_start_data =
      CreateR0Parameter<int>(2, 1, "x", &builder, &row_start);
  xla::XlaOp col_start;
  auto col_start_data =
      CreateR0Parameter<int>(1, 2, "y", &builder, &col_start);

  // The returned op is intentionally unused: the computation is evaluated
  // through `builder` below.
  DynamicSliceInMinorDims(matrix, {row_start, col_start}, {1, 1});

  const xla::Array2D<float> expected = {{10}};
  ComputeAndCompareR2<float>(
      &builder, expected,
      {matrix_data.get(), row_start_data.get(), col_start_data.get()},
      xla::ErrorSpec(1e-2, 1e-2));
}
70 
// Checks DynamicSliceInMinorDims on a rank-3 operand: a 1x4 window starting at
// (index = 1, 0) in the two minor dimensions is expected to return the second
// row of each matrix in BatchedAValsFull().
XLA_TEST_F(SlicingTest, Simple3dLookup) {
  xla::XlaBuilder builder(TestName());

  xla::XlaOp batched;
  auto batched_data = CreateR3Parameter<float>(BatchedAValsFull(), 0, "a",
                                               &builder, &batched);
  // The first start value is a runtime parameter.
  xla::XlaOp row_index;
  auto row_index_data =
      CreateR0Parameter<int>(1, 1, "index", &builder, &row_index);

  // The second start value is a compile-time constant zero.
  xla::XlaOp col_start = xla::ConstantR0<int32>(&builder, 0);
  DynamicSliceInMinorDims(batched, {row_index, col_start}, {1, 4});

  const xla::Array3D<float> expected = {
      {{3, 6, 0, 1}},
      {{24, 61, 82, 48}},
  };
  ComputeAndCompareR3<float>(&builder, expected,
                             {batched_data.get(), row_index_data.get()});
}
85 
// Checks DynamicUpdateSliceInMinorDims: a 1x3 patch written into AValsFull()
// with the runtime start values (2, 1) is expected to replace the last three
// entries of the third row and leave every other entry untouched.
XLA_TEST_F(SlicingTest, SimpleSliceUpdate) {
  xla::XlaBuilder builder(TestName());

  // Parameter 0: the matrix being updated.
  xla::XlaOp target;
  auto target_data =
      CreateR2Parameter<float>(AValsFull(), 0, "a", &builder, &target);

  // Parameter 1: the patch to write.
  const xla::Array2D<float> patch_values = {{9, 1, -10}};
  xla::XlaOp patch;
  auto patch_data =
      CreateR2Parameter<float>(patch_values, 1, "b", &builder, &patch);

  // Parameters 2 and 3: scalar start values for the update.
  xla::XlaOp row_start;
  auto row_start_data =
      CreateR0Parameter<int>(2, 2, "x", &builder, &row_start);
  xla::XlaOp col_start;
  auto col_start_data =
      CreateR0Parameter<int>(1, 3, "y", &builder, &col_start);

  DynamicUpdateSliceInMinorDims(target, patch, {row_start, col_start});

  const xla::Array2D<float> expected = {
      {2, 0, 1, 2},
      {3, 6, 0, 1},
      {4, 9, 1, -10},
      {5, 8, 10, 11},
  };
  ComputeAndCompareR2<float>(&builder, expected,
                             {target_data.get(), patch_data.get(),
                              row_start_data.get(), col_start_data.get()});
}
104 
// Checks TorchGather along dimension 1 on a 2x2 integer input. With the data
// below, each expected entry [i][j] equals input[i][index[i][j]].
XLA_TEST_F(SlicingTest, TorchGather) {
  xla::XlaBuilder builder(TestName());

  const xla::Array2D<int> input_values = {{1, 2}, {3, 4}};
  xla::XlaOp input_op;
  auto input_data =
      CreateR2Parameter<int>(input_values, 0, "input", &builder, &input_op);

  const xla::Array2D<int> index_values = {{0, 0}, {1, 0}};
  xla::XlaOp index_op;
  auto index_data =
      CreateR2Parameter<int>(index_values, 1, "index", &builder, &index_op);

  TorchGather(input_op, index_op, 1);

  const xla::Array2D<int> expected = {{1, 1}, {4, 3}};
  ComputeAndCompareR2<int>(&builder, expected,
                           {input_data.get(), index_data.get()});
}
118 
// Checks TorchIndexSelect along dimension 0: applying indices {0, 2} to a 3x4
// input is expected to yield the first and third input rows, compared without
// an error tolerance.
XLA_TEST_F(SlicingTest, TorchIndexSelectOn0) {
  xla::XlaBuilder builder(TestName());

  const xla::Array2D<float> input_values = {
      {0.1427, 0.0231, -0.5414, -1.0009},
      {-0.4664, 0.2647, -0.1228, -1.1068},
      {-1.1734, -0.6571, 0.7230, -0.6004},
  };
  xla::XlaOp input_op;
  auto input_data =
      CreateR2Parameter<float>(input_values, 0, "input", &builder, &input_op);

  xla::XlaOp index_op;
  auto index_data =
      CreateR1Parameter<int>({0, 2}, 1, "index", &builder, &index_op);

  TorchIndexSelect(input_op, index_op, 0);

  const xla::Array2D<float> expected = {
      {0.1427, 0.0231, -0.5414, -1.0009},
      {-1.1734, -0.6571, 0.7230, -0.6004},
  };
  ComputeAndCompareR2<float>(&builder, expected,
                             {input_data.get(), index_data.get()});
}
137 
// Checks TorchIndexSelect along dimension 1: applying indices {0, 2} to a 3x4
// input is expected to yield the first and third entry of every input row,
// compared without an error tolerance.
XLA_TEST_F(SlicingTest, TorchIndexSelectOn1) {
  xla::XlaBuilder builder(TestName());

  const xla::Array2D<float> input_values = {
      {0.1427, 0.0231, -0.5414, -1.0009},
      {-0.4664, 0.2647, -0.1228, -1.1068},
      {-1.1734, -0.6571, 0.7230, -0.6004},
  };
  xla::XlaOp input_op;
  auto input_data =
      CreateR2Parameter<float>(input_values, 0, "input", &builder, &input_op);

  xla::XlaOp index_op;
  auto index_data =
      CreateR1Parameter<int>({0, 2}, 1, "index", &builder, &index_op);

  TorchIndexSelect(input_op, index_op, 1);

  const xla::Array2D<float> expected = {
      {0.1427, -0.5414},
      {-0.4664, -0.1228},
      {-1.1734, 0.7230},
  };
  ComputeAndCompareR2<float>(&builder, expected,
                             {input_data.get(), index_data.get()});
}
155 
156 }  // namespace
157 }  // namespace xla
158