1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/index_util.h"
17
18 #include <algorithm>
19 #include <string>
20
21 #include "tensorflow/compiler/xla/shape_util.h"
22 #include "tensorflow/compiler/xla/types.h"
23 #include "tensorflow/compiler/xla/xla_data.pb.h"
24 #include "tensorflow/core/lib/strings/str_util.h"
25 #include "tensorflow/core/platform/logging.h"
26
27 namespace xla {
28
MultidimensionalIndexToLinearIndex(const Shape & shape,tensorflow::gtl::ArraySlice<int64> multi_index)29 /* static */ int64 IndexUtil::MultidimensionalIndexToLinearIndex(
30 const Shape& shape, tensorflow::gtl::ArraySlice<int64> multi_index) {
31 DCHECK_EQ(shape.dimensions_size(), multi_index.size());
32 // Padding and nested layouts not supported yet.
33 DCHECK_EQ(0, shape.layout().padded_dimensions_size());
34
35 for (size_t i = 0; i < multi_index.size(); ++i) {
36 DCHECK_GE(multi_index[i], 0);
37 DCHECK_LT(multi_index[i], shape.dimensions(i))
38 << "indexing beyond extent in dimension " << i << ":"
39 << "\n\tindex: " << tensorflow::str_util::Join(multi_index, ",")
40 << "\n\tshape: " << ShapeUtil::HumanString(shape);
41 }
42
43 // Let the array be sized like so for dimensions i from 0 to n-1:
44 //
45 // [D{n-1} x D{n-2} x .. x D{0}]
46 //
47 // Let the order of the dimensions in the minor_to_major field in
48 // Layout be:
49 //
50 // L(0), L(1), ... , L(n-1)
51 //
52 // where L(0) is the most-minor dimension and L(n-1) the most-major. The
53 // multidimensional index:
54 //
55 // [I{0}, I{1}, ... , I{n-1}]
56 //
57 // then corresponds to the following linear index:
58 //
59 // linear_index =
60 // ((( ... + I{L(2)}) * D{L(1)} + I{L(1)}) * D{L(0)} + I{L(0)}
61 //
62 // or equivalently:
63 //
64 // linear_index =
65 // I{L(n-1)} * (D{L(n-2)} * D{L(n-3)} * D{L(n-4)} * .... D{L(0)}) +
66 // I{L(n-2)} * (D{L(n-3)} * D{L(n-4)} * .... D{L(0)}) +
67 // I{L(n-3)} * (D{L(n-4)} * .... D{L(0)}) +
68 // ... +
69 // I{L(2)} * (D{L(1)} * D{L(0)}) +
70 // I{L(1)} * D{L(0)} +
71 // I{L(0)}
72 //
73 // We compute the linear index value by accumulating the terms above from
74 // I{L(0)} up to I{L(n-1)}. Scale accumulates the product term D{L(0}} *
75 // D{L(1)} * ...
76
77 // Scale factor holding the growing product of D{L(i)} terms.
78 int64 scale = 1;
79 int64 linear_index = 0;
80 bool first = true;
81 for (auto dimension : LayoutUtil::MinorToMajor(shape)) {
82 if (first) {
83 // Avoid two multiplies on the first loop iteration
84 linear_index = multi_index[dimension];
85 scale = shape.dimensions(dimension);
86 first = false;
87 } else {
88 linear_index += scale * multi_index[dimension];
89 scale *= shape.dimensions(dimension);
90 }
91 }
92 return linear_index;
93 }
94
LinearIndexToMultidimensionalIndex(const Shape & shape,int64 linear_index)95 /* static */ std::vector<int64> IndexUtil::LinearIndexToMultidimensionalIndex(
96 const Shape& shape, int64 linear_index) {
97 // Padding and nested layouts not supported yet.
98 DCHECK_EQ(0, shape.layout().padded_dimensions_size());
99 DCHECK_GE(linear_index, 0);
100 DCHECK_LT(linear_index, ShapeUtil::ElementsIn(shape));
101
102 // The following formula computes each element of the multidimensional index
103 // (See comments in MultidimensionalIndexToLinearIndex for notation):
104 //
105 // I{L(0)} = linear_index % D{L(0)}
106 // I{L(1)} = (linear_index / D{L(0)}) % D{L(1)}
107 // I{L(2)} = (linear_index / (D{L(0)} * D{L(1)})) % D{L(2)}
108 // ...
109 std::vector<int64> multi_index(shape.dimensions_size());
110
111 // Accumulated product D{L(0)} * D{L(1)} * ...
112 int64 divisor = 1;
113 for (auto dimension : LayoutUtil::MinorToMajor(shape)) {
114 multi_index[dimension] =
115 (linear_index / divisor) % shape.dimensions(dimension);
116 divisor *= shape.dimensions(dimension);
117 }
118 return multi_index;
119 }
120
BumpIndices(const Shape & shape,tensorflow::gtl::MutableArraySlice<int64> indices)121 /* static */ bool IndexUtil::BumpIndices(
122 const Shape& shape, tensorflow::gtl::MutableArraySlice<int64> indices) {
123 for (int64 dimno = indices.size() - 1; dimno >= 0; --dimno) {
124 int64 limit = shape.dimensions(dimno);
125 if (indices[dimno] + 1 < limit) {
126 indices[dimno]++;
127 std::fill(indices.begin() + dimno + 1, indices.end(), 0);
128 return true;
129 }
130 }
131 return false;
132 }
133
GetDimensionStride(const Shape & shape,int64 dimension)134 /* static */ int64 IndexUtil::GetDimensionStride(const Shape& shape,
135 int64 dimension) {
136 int64 pdim_size = LayoutUtil::PaddedDimensions(shape).size();
137 int64 stride = 1;
138 DCHECK(pdim_size == 0 || pdim_size == shape.dimensions_size());
139 for (auto dim : LayoutUtil::MinorToMajor(shape)) {
140 if (dim == dimension) {
141 break;
142 }
143 if (pdim_size == 0) {
144 stride *= shape.dimensions(dim);
145 } else {
146 stride *= LayoutUtil::PaddedDimension(shape, dim);
147 }
148 }
149 return stride;
150 }
151
IndexInBounds(const Shape & shape,tensorflow::gtl::ArraySlice<int64> index)152 /* static */ bool IndexUtil::IndexInBounds(
153 const Shape& shape, tensorflow::gtl::ArraySlice<int64> index) {
154 int64 rank = ShapeUtil::Rank(shape);
155 if (rank != index.size()) {
156 return false;
157 }
158 for (int64 d = 0; d < rank; ++d) {
159 if (index[d] >= shape.dimensions(d)) {
160 return false;
161 }
162 }
163 return true;
164 }
165
CompareIndices(tensorflow::gtl::ArraySlice<int64> lhs,tensorflow::gtl::ArraySlice<int64> rhs)166 /* static */ int IndexUtil::CompareIndices(
167 tensorflow::gtl::ArraySlice<int64> lhs,
168 tensorflow::gtl::ArraySlice<int64> rhs) {
169 int64 rank = lhs.size();
170 CHECK_EQ(rhs.size(), rank);
171 for (int64 dim = 0; dim < rank; ++dim) {
172 if (lhs[dim] < rhs[dim]) {
173 return -1;
174 } else if (lhs[dim] > rhs[dim]) {
175 return 1;
176 }
177 }
178 return 0;
179 }
180
181 } // namespace xla
182