/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/index_util.h"

#include <algorithm>
#include <string>

#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"

namespace xla {

/* static */ int64 IndexUtil::MultidimensionalIndexToLinearIndex(
    const Shape& shape, tensorflow::gtl::ArraySlice<int64> multi_index) {
  DCHECK_EQ(shape.dimensions_size(), multi_index.size());
  // Padding and nested layouts not supported yet.
  DCHECK_EQ(0, shape.layout().padded_dimensions_size());

  for (size_t i = 0; i < multi_index.size(); ++i) {
    DCHECK_GE(multi_index[i], 0);
    DCHECK_LT(multi_index[i], shape.dimensions(i))
        << "indexing beyond extent in dimension " << i << ":"
        << "\n\tindex: " << tensorflow::str_util::Join(multi_index, ",")
        << "\n\tshape: " << ShapeUtil::HumanString(shape);
  }

  // Let the array be sized like so for dimensions i from 0 to n-1:
  //
  //   [D{n-1} x D{n-2} x .. x D{0}]
  //
  // Let the order of the dimensions in the minor_to_major field in
  // Layout be:
  //
  //   L(0), L(1), ... , L(n-1)
  //
  // where L(0) is the most-minor dimension and L(n-1) the most-major. The
  // multidimensional index:
  //
  //   [I{0}, I{1}, ... , I{n-1}]
  //
  // then corresponds to the following linear index:
  //
  // linear_index =
  //   ((... + I{L(2)}) * D{L(1)} + I{L(1)}) * D{L(0)} + I{L(0)}
  //
  // or equivalently:
  //
  // linear_index =
  //   I{L(n-1)} * (D{L(n-2)} * D{L(n-3)} * D{L(n-4)} *     ....    D{L(0)}) +
  //   I{L(n-2)} *             (D{L(n-3)} * D{L(n-4)} *     ....    D{L(0)}) +
  //   I{L(n-3)} *                         (D{L(n-4)} *     ....    D{L(0)}) +
  //                                   ...                                   +
  //   I{L(2)} *                                         (D{L(1)} * D{L(0)}) +
  //   I{L(1)} *                                                    D{L(0)}  +
  //   I{L(0)}
  //
  // We compute the linear index value by accumulating the terms above from
  // I{L(0)} up to I{L(n-1)}. Scale accumulates the product term D{L(0)} *
  // D{L(1)} * ...
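  //
  // As a concrete (purely illustrative) example: for a hypothetical shape
  // with dimensions [2, 3] and minor_to_major = {1, 0} (i.e. row-major),
  // the multi-index {1, 2} maps to
  //
  //   linear_index = I{L(1)} * D{L(0)} + I{L(0)} = 1 * 3 + 2 = 5.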

  // Scale factor holding the growing product of D{L(i)} terms.
  int64 scale = 1;
  int64 linear_index = 0;
  bool first = true;
  for (auto dimension : LayoutUtil::MinorToMajor(shape)) {
    if (first) {
      // Avoid two multiplies on the first loop iteration.
      linear_index = multi_index[dimension];
      scale = shape.dimensions(dimension);
      first = false;
    } else {
      linear_index += scale * multi_index[dimension];
      scale *= shape.dimensions(dimension);
    }
  }
  return linear_index;
}

/* static */ std::vector<int64> IndexUtil::LinearIndexToMultidimensionalIndex(
    const Shape& shape, int64 linear_index) {
  // Padding and nested layouts not supported yet.
  DCHECK_EQ(0, shape.layout().padded_dimensions_size());
  DCHECK_GE(linear_index, 0);
  DCHECK_LT(linear_index, ShapeUtil::ElementsIn(shape));

  // The following formula computes each element of the multidimensional index
  // (see comments in MultidimensionalIndexToLinearIndex for notation):
  //
  //   I{L(0)} = linear_index % D{L(0)}
  //   I{L(1)} = (linear_index / D{L(0)}) % D{L(1)}
  //   I{L(2)} = (linear_index / (D{L(0)} * D{L(1)})) % D{L(2)}
  //   ...
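  //
  // Continuing the illustrative example above (a hypothetical shape with
  // dimensions [2, 3] and minor_to_major = {1, 0}), linear_index = 5
  // recovers the multi-index {1, 2}:
  //
  //   I{1} = 5 % 3 = 2,  I{0} = (5 / 3) % 2 = 1.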
  std::vector<int64> multi_index(shape.dimensions_size());

  // Accumulated product D{L(0)} * D{L(1)} * ...
  int64 divisor = 1;
  for (auto dimension : LayoutUtil::MinorToMajor(shape)) {
    multi_index[dimension] =
        (linear_index / divisor) % shape.dimensions(dimension);
    divisor *= shape.dimensions(dimension);
  }
  return multi_index;
}

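// Bumps `indices` to the next multidimensional index of `shape`, treating the
// last dimension as the fastest-varying (like an odometer). For example, with
// a hypothetical shape [2, 3], {0, 2} is bumped to {1, 0}. Returns false when
// the index cannot be bumped any further.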
/* static */ bool IndexUtil::BumpIndices(
    const Shape& shape, tensorflow::gtl::MutableArraySlice<int64> indices) {
  for (int64 dimno = indices.size() - 1; dimno >= 0; --dimno) {
    int64 limit = shape.dimensions(dimno);
    if (indices[dimno] + 1 < limit) {
      indices[dimno]++;
      std::fill(indices.begin() + dimno + 1, indices.end(), 0);
      return true;
    }
  }
  return false;
}

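// Returns the stride of `dimension`: the product of the (padded, if padding
// is present) sizes of all dimensions more minor than `dimension` in the
// shape's layout. For example, for a hypothetical shape [2, 3] with
// minor_to_major = {1, 0}, dimension 0 has stride 3 and dimension 1 has
// stride 1.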
/* static */ int64 IndexUtil::GetDimensionStride(const Shape& shape,
                                                 int64 dimension) {
  int64 pdim_size = LayoutUtil::PaddedDimensions(shape).size();
  int64 stride = 1;
  DCHECK(pdim_size == 0 || pdim_size == shape.dimensions_size());
  for (auto dim : LayoutUtil::MinorToMajor(shape)) {
    if (dim == dimension) {
      break;
    }
    if (pdim_size == 0) {
      stride *= shape.dimensions(dim);
    } else {
      stride *= LayoutUtil::PaddedDimension(shape, dim);
    }
  }
  return stride;
}

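// Returns true if `index` has the same rank as `shape` and every element is
// within the corresponding dimension bound (elements are assumed to be
// non-negative).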
/* static */ bool IndexUtil::IndexInBounds(
    const Shape& shape, tensorflow::gtl::ArraySlice<int64> index) {
  int64 rank = ShapeUtil::Rank(shape);
  if (rank != index.size()) {
    return false;
  }
  for (int64 d = 0; d < rank; ++d) {
    if (index[d] >= shape.dimensions(d)) {
      return false;
    }
  }
  return true;
}

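// Lexicographically compares two indices of equal rank, returning -1, 0, or 1
// as `lhs` is less than, equal to, or greater than `rhs`. For example,
// {0, 2} precedes {1, 0}, so CompareIndices({0, 2}, {1, 0}) returns -1.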
/* static */ int IndexUtil::CompareIndices(
    tensorflow::gtl::ArraySlice<int64> lhs,
    tensorflow::gtl::ArraySlice<int64> rhs) {
  int64 rank = lhs.size();
  CHECK_EQ(rhs.size(), rank);
  for (int64 dim = 0; dim < rank; ++dim) {
    if (lhs[dim] < rhs[dim]) {
      return -1;
    } else if (lhs[dim] > rhs[dim]) {
      return 1;
    }
  }
  return 0;
}

}  // namespace xla