• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_REFERENCE_UTIL_H_
17 #define TENSORFLOW_COMPILER_XLA_REFERENCE_UTIL_H_
18 
19 #include <array>
20 #include <functional>
21 #include <memory>
22 #include <utility>
23 #include <vector>
24 
25 #include "absl/memory/memory.h"
26 #include "absl/types/span.h"
27 #include "tensorflow/compiler/xla/array2d.h"
28 #include "tensorflow/compiler/xla/array3d.h"
29 #include "tensorflow/compiler/xla/array4d.h"
30 #include "tensorflow/compiler/xla/client/padding.h"
31 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
32 #include "tensorflow/compiler/xla/util.h"
33 #include "tensorflow/compiler/xla/xla_data.pb.h"
34 #include "tensorflow/core/platform/macros.h"
35 #include "tensorflow/core/platform/types.h"
36 
37 namespace xla {
38 
39 // Utility class for reference implementations of linear algebra routines.
40 class ReferenceUtil {
41  public:
42   // Returns the result of a transpose operation on the input matrix.
43   template <typename T>
TransposeArray2D(const Array2D<T> & operand)44   static std::unique_ptr<Array2D<T>> TransposeArray2D(
45       const Array2D<T>& operand) {
46     auto result =
47         absl::make_unique<Array2D<T>>(operand.width(), operand.height());
48     for (int64_t w = 0; w < operand.width(); ++w) {
49       for (int64_t h = 0; h < operand.height(); ++h) {
50         (*result)(w, h) = operand(h, w);
51       }
52     }
53 
54     return result;
55   }
56 
57   // Returns the result of a matrix multiply `lhs x rhs`.
58   template <typename T>
MatmulArray2D(const Array2D<T> & lhs,const Array2D<T> & rhs)59   static std::unique_ptr<Array2D<T>> MatmulArray2D(const Array2D<T>& lhs,
60                                                    const Array2D<T>& rhs) {
61     return HloEvaluator::MatmulArray2D(lhs, rhs);
62   }
63 
64   // Converts the input operand to use f64 values instead of f32 values.
65   static std::unique_ptr<Array2D<double>> Array2DF32ToF64(
66       const Array2D<float>& input);
67 
68   // Returns the result of a convolution `lhs <conv> rhs`, with the default
69   // convolution dimension numbers returned from
70   // ComputationBuilder::CreateDefaultConvDimensionNumbers().
71   static std::unique_ptr<Array4D<float>> ConvArray4D(
72       const Array4D<float>& lhs, const Array4D<float>& rhs,
73       std::pair<int64, int64> kernel_stride, Padding padding);
74 
75   // Returns the result of a convolution `lhs <conv> rhs`, with the given
76   // convolution dimension numbers.
77   static std::unique_ptr<Array4D<float>> ConvArray4DGeneralDimensions(
78       const Array4D<float>& lhs, const Array4D<float>& rhs,
79       std::pair<int64, int64> kernel_stride, Padding padding,
80       ConvolutionDimensionNumbers dimension_numbers);
81 
82   // Returns the result of a convolution `lhs <conv> rhs`, with the given
83   // dilation factors.
84   static std::unique_ptr<Array4D<float>> ConvArray4DGeneralDimensionsDilated(
85       const Array4D<float>& lhs, const Array4D<float>& rhs,
86       std::pair<int64, int64> kernel_stride, Padding padding,
87       std::pair<int64, int64> lhs_dilation,
88       std::pair<int64, int64> rhs_dilation, ConvolutionDimensionNumbers dnums);
89 
90   // Returns the result of a convolution `lhs <conv> rhs`, with the default
91   // convolution dimension numbers returned from
92   // ComputationBuilder::CreateDefaultConvDimensionNumbers().
93   static std::unique_ptr<Array3D<float>> ConvArray3D(const Array3D<float>& lhs,
94                                                      const Array3D<float>& rhs,
95                                                      int64_t kernel_stride,
96                                                      Padding padding);
97 
98   // Returns the result of a convolution `lhs <conv> rhs`.
99   static std::unique_ptr<Array3D<float>> ConvArray3DGeneralDimensionsDilated(
100       const Array3D<float>& lhs, const Array3D<float>& rhs,
101       int64_t kernel_stride, Padding padding, int64_t lhs_dilation,
102       int64_t rhs_dilation, const ConvolutionDimensionNumbers& dnums);
103 
104   // Returns the result of a separable  convolution with the given parameters.
105   // kernel_stride and padding applies to the depthwise convolution during
106   // the separable convolution. pointwise_weights.depth() must be equal to
107   // input.depth() * depthwise_weights.planes().
108   static std::unique_ptr<Array4D<float>> SeparableConvArray4D(
109       const Array4D<float>& input, const Array4D<float>& depthwise_weights,
110       const Array4D<float>& pointwise_weights,
111       std::pair<int64, int64> kernel_stride, Padding padding);
112 
113   // Returns the result of reducing a matrix to a column vector. init is the
114   // initial value for the reduce operation, and reduce_function is the function
115   // to apply for each reduction step.
116   static std::unique_ptr<std::vector<float>> ReduceToColArray2D(
117       const Array2D<float>& matrix, float init,
118       const std::function<float(float, float)>& reduce_function);
119 
120   // Returns the result of reducing a matrix to a row vector. init is the
121   // initial value for the reduce operation, and reduce_function is the function
122   // to apply for each reduction step.
123   static std::unique_ptr<std::vector<float>> ReduceToRowArray2D(
124       const Array2D<float>& matrix, float init,
125       const std::function<float(float, float)>& reduce_function);
126 
127   // Performs a R2=>R1 reduction by reducing away the dimension specified in
128   // 'dimension_to_reduce'.
129   template <typename T>
ReduceR2ToR1(const Array2D<T> & input,int dimension_to_reduce,T init,const std::function<T (T,T)> & freduce)130   static std::vector<T> ReduceR2ToR1(const Array2D<T>& input,
131                                      int dimension_to_reduce, T init,
132                                      const std::function<T(T, T)>& freduce) {
133     std::vector<T> result(dimension_to_reduce == 0 ? input.n2() : input.n1(),
134                           init);
135     for (int i0 = 0; i0 < input.n1(); ++i0) {
136       for (int i1 = 0; i1 < input.n2(); ++i1) {
137         int output = dimension_to_reduce == 0 ? i1 : i0;
138         result[output] = freduce(result[output], input(i0, i1));
139       }
140     }
141     return result;
142   }
143 
144   // Returns the result of reducing the 4D array to a vector, reducing away
145   // the dimensions specified in dims.
146   static std::vector<float> Reduce4DTo1D(
147       const Array4D<float>& array, float init, absl::Span<const int64> dims,
148       const std::function<float(float, float)>& reduce_function);
149 
150   // Broadcast 1D dimension to 4D, from the dimension `broadcast_from_dim`.
151   static std::unique_ptr<Array4D<float>> Broadcast1DTo4D(
152       const std::vector<float>& array, const std::vector<int64>& bounds,
153       int64_t broadcast_from_dim);
154 
155   // Returns the result of reducing the 3D array to a 2D array, reducing away
156   // the dimensions specified in dims.
157   static std::unique_ptr<Array2D<float>> Reduce3DTo2D(
158       const Array3D<float>& array, float init, absl::Span<const int64> dims,
159       const std::function<float(float, float)>& reduce_function);
160 
161   // Applies map_function to each element in the input (2D array) and returns
162   // the result.
163   static std::unique_ptr<Array2D<float>> MapArray2D(
164       const Array2D<float>& matrix,
165       const std::function<float(float)>& map_function);
166 
167   // Applies map_function to each pair of corresponding elements in the two
168   // inputs arrays and returns the result.
169   static std::unique_ptr<Array2D<float>> MapArray2D(
170       const Array2D<float>& lhs, const Array2D<float>& rhs,
171       const std::function<float(float, float)>& map_function);
172 
173   // Number of windows in a given dimension. Calculation taken from
174   // xla::MakePadding().
175   static int64 WindowCount(int64_t unpadded_width, int64_t window_len,
176                            int64_t stride, Padding padding);
177 
178   // Windowed reductions with Add as the function to apply.
179   static std::unique_ptr<std::vector<float>> ReduceWindow1DAdd(
180       absl::Span<const float> operand, float init,
181       absl::Span<const int64> window, absl::Span<const int64> stride,
182       Padding padding);
183   static std::unique_ptr<Array3D<float>> ReduceWindow3DAdd(
184       const Array3D<float>& operand, float init, absl::Span<const int64> window,
185       absl::Span<const int64> stride, Padding padding);
186   static std::unique_ptr<Array4D<float>> ReduceWindow4DAdd(
187       const Array4D<float>& operand, float init, absl::Span<const int64> window,
188       absl::Span<const int64> stride, Padding padding);
189 
190   // Windowed reductions with a generic reduce function.
191   static std::unique_ptr<std::vector<float>> ReduceWindow1DGeneric(
192       absl::Span<const float> operand, float init,
193       const std::function<float(float, float)>& reduce_func,
194       absl::Span<const int64> window, absl::Span<const int64> stride,
195       absl::Span<const std::pair<int64, int64>> padding);
196   static std::unique_ptr<Array4D<float>> ReduceWindow4DGeneric(
197       const Array4D<float>& operand, float init,
198       const std::function<float(float, float)>& reduce_func,
199       absl::Span<const int64> window, absl::Span<const int64> stride,
200       Padding padding);
201   // With arbitrary padding.
202   static std::unique_ptr<Array4D<float>> ReduceWindow4DGeneric(
203       const Array4D<float>& operand, float init,
204       const std::function<float(float, float)>& reduce_func,
205       absl::Span<const int64> window, absl::Span<const int64> stride,
206       absl::Span<const std::pair<int64, int64>> padding);
207 
208   // Batch normalize data.
209   static std::unique_ptr<Array4D<float>> BatchNorm4D(
210       const Array4D<float>& input, const Array4D<float>& mean,
211       const Array4D<float>& var, const Array4D<float>& scale,
212       const Array4D<float>& offset, float epsilon);
213 
214   // Performs select and scatter with Greater Than or equal as the select, plus
215   // as the scatter, and Same Padding.
216   // TODO(b/74533103) Switch tests to evaluator and remove this implementation.
217   static std::unique_ptr<Array4D<float>> SelectAndScatter4DGePlus(
218       const Array4D<float>& operand, const Array4D<float>& source, float init,
219       absl::Span<const int64> window, absl::Span<const int64> stride,
220       bool same_padding);
221 
222   // Concatenates the lhs and rhs arrays along the concatenate_dimension.
223   // E.g. if concatenate_dimension is 0, the "n1"/height dimension is
224   // concatenated, so the arrays are stacked on top of each other.
225   template <typename T>
Concat2D(const Array2D<T> & lhs,const Array2D<T> & rhs,int concatenate_dimension)226   static std::unique_ptr<Array2D<T>> Concat2D(const Array2D<T>& lhs,
227                                               const Array2D<T>& rhs,
228                                               int concatenate_dimension) {
229     CHECK(0 <= concatenate_dimension && concatenate_dimension < 2);
230     auto result = absl::make_unique<Array2D<T>>(
231         concatenate_dimension == 0 ? lhs.n1() + rhs.n1() : lhs.n1(),
232         concatenate_dimension == 1 ? lhs.n2() + rhs.n2() : lhs.n2());
233     for (int64_t i0 = 0; i0 < result->n1(); ++i0) {
234       for (int64_t i1 = 0; i1 < result->n2(); ++i1) {
235         // If we exceed the bounds of the LHS, draw from the RHS, where the
236         // result index is adjusted by the number of values present in the LHS.
237         (*result)(i0, i1) = i0 < lhs.n1() && i1 < lhs.n2()
238                                 ? lhs(i0, i1)
239                                 : rhs(i0 >= lhs.n1() ? i0 - lhs.n1() : i0,
240                                       i1 >= lhs.n2() ? i1 - lhs.n2() : i1);
241       }
242     }
243     return result;
244   }
245 
246   // Concatenates the lhs and rhs 3D arrays along the concatenate_dimension. lhs
247   // and rhs must have the same dimensions except for the concatenate dimension.
248   template <typename T>
Concat3D(const Array3D<T> & lhs,const Array3D<T> & rhs,int concatenate_dimension)249   static std::unique_ptr<Array3D<T>> Concat3D(const Array3D<T>& lhs,
250                                               const Array3D<T>& rhs,
251                                               int concatenate_dimension) {
252     CHECK(0 <= concatenate_dimension && concatenate_dimension < 3);
253     const int64 lhs_dims[] = {lhs.n1(), lhs.n2(), lhs.n3()};
254     const int64 rhs_dims[] = {rhs.n1(), rhs.n2(), rhs.n3()};
255     int64 out_dims[] = {rhs.n1(), rhs.n2(), rhs.n3()};
256     for (int i = 0; i < 3; ++i) {
257       if (i != concatenate_dimension) {
258         out_dims[i] = lhs_dims[i];
259         CHECK_EQ(lhs_dims[i], rhs_dims[i]);
260       } else {
261         out_dims[i] = lhs_dims[i] + rhs_dims[i];
262       }
263     }
264     auto result =
265         absl::make_unique<Array3D<T>>(out_dims[0], out_dims[1], out_dims[2]);
266     for (int64_t i0 = 0; i0 < result->n1(); ++i0) {
267       for (int64_t i1 = 0; i1 < result->n2(); ++i1) {
268         for (int64_t i2 = 0; i2 < result->n3(); ++i2) {
269           (*result)(i0, i1, i2) =
270               i0 < lhs.n1() && i1 < lhs.n2() && i2 < lhs.n3()
271                   ? lhs(i0, i1, i2)
272                   : rhs(i0 >= lhs.n1() ? i0 - lhs.n1() : i0,
273                         i1 >= lhs.n2() ? i1 - lhs.n2() : i1,
274                         i2 >= lhs.n3() ? i2 - lhs.n3() : i2);
275         }
276       }
277     }
278     return result;
279   }
280 
281   // Concatenates the lhs and rhs 4D arrays along the concatenate_dimension. lhs
282   // and rhs must have the same dimensions except for the concatenate dimension.
283   template <typename T>
Concat4D(const Array4D<T> & lhs,const Array4D<T> & rhs,int concatenate_dimension)284   static std::unique_ptr<Array4D<T>> Concat4D(const Array4D<T>& lhs,
285                                               const Array4D<T>& rhs,
286                                               int concatenate_dimension) {
287     CHECK(0 <= concatenate_dimension && concatenate_dimension < 4);
288     const int64 lhs_dims[] = {lhs.n1(), lhs.n2(), lhs.n3(), lhs.n4()};
289     const int64 rhs_dims[] = {rhs.n1(), rhs.n2(), rhs.n3(), rhs.n4()};
290     int64 out_dims[] = {rhs.n1(), rhs.n2(), rhs.n3(), rhs.n4()};
291     for (int i = 0; i < 4; ++i) {
292       if (i != concatenate_dimension) {
293         out_dims[i] = lhs_dims[i];
294         CHECK_EQ(lhs_dims[i], rhs_dims[i]);
295       } else {
296         out_dims[i] = lhs_dims[i] + rhs_dims[i];
297       }
298     }
299     auto result = absl::make_unique<Array4D<T>>(out_dims[0], out_dims[1],
300                                                 out_dims[2], out_dims[3]);
301     for (int64_t i0 = 0; i0 < result->n1(); ++i0) {
302       for (int64_t i1 = 0; i1 < result->n2(); ++i1) {
303         for (int64_t i2 = 0; i2 < result->n3(); ++i2) {
304           for (int64_t i3 = 0; i3 < result->n4(); ++i3) {
305             (*result)(i0, i1, i2, i3) =
306                 i0 < lhs.n1() && i1 < lhs.n2() && i2 < lhs.n3() && i3 < lhs.n4()
307                     ? lhs(i0, i1, i2, i3)
308                     : rhs(i0 >= lhs.n1() ? i0 - lhs.n1() : i0,
309                           i1 >= lhs.n2() ? i1 - lhs.n2() : i1,
310                           i2 >= lhs.n3() ? i2 - lhs.n3() : i2,
311                           i3 >= lhs.n4() ? i3 - lhs.n4() : i3);
312           }
313         }
314       }
315     }
316     return result;
317   }
318 
319   // Slices with index clamping
320   template <typename T>
ClampSlice1D(absl::Span<const T> input,int64_t start,int64_t size)321   static std::vector<T> ClampSlice1D(absl::Span<const T> input, int64_t start,
322                                      int64_t size) {
323     start = std::min<int64>(std::max<int64>(0, start), input.size() - size);
324     std::vector<T> result;
325     for (int64_t i = 0; i < size; ++i) {
326       result.push_back(input[(start + i)]);
327     }
328     return result;
329   }
330 
331   // Slices the input array given starting indices, limit indices, and strides
332   // in each dimension.
333   template <typename T>
Slice2D(const Array2D<T> & input,std::array<int64,2> starts,std::array<int64,2> limits,std::array<int64,2> strides)334   static std::unique_ptr<Array2D<T>> Slice2D(const Array2D<T>& input,
335                                              std::array<int64, 2> starts,
336                                              std::array<int64, 2> limits,
337                                              std::array<int64, 2> strides) {
338     CHECK_LE(starts[0], input.n1());
339     CHECK_LE(starts[1], input.n2());
340     CHECK_LE(limits[0], input.n1());
341     CHECK_LE(limits[1], input.n2());
342     CHECK_GE(strides[0], 1);
343     CHECK_GE(strides[1], 1);
344     auto result = absl::make_unique<Array2D<T>>(
345         CeilOfRatio(limits[0] - starts[0], strides[0]),
346         CeilOfRatio(limits[1] - starts[1], strides[1]));
347     for (int64_t i0 = 0; i0 < result->n1(); ++i0) {
348       for (int64_t i1 = 0; i1 < result->n2(); ++i1) {
349         (*result)(i0, i1) =
350             input(starts[0] + i0 * strides[0], starts[1] + i1 * strides[1]);
351       }
352     }
353     return result;
354   }
355 
356   template <typename T>
Slice3D(const Array3D<T> & input,std::array<int64,3> starts,std::array<int64,3> limits,std::array<int64,3> strides)357   static std::unique_ptr<Array3D<T>> Slice3D(const Array3D<T>& input,
358                                              std::array<int64, 3> starts,
359                                              std::array<int64, 3> limits,
360                                              std::array<int64, 3> strides) {
361     CHECK_LE(starts[0], input.n1());
362     CHECK_LE(starts[1], input.n2());
363     CHECK_LE(starts[2], input.n3());
364     CHECK_LE(limits[0], input.n1());
365     CHECK_LE(limits[1], input.n2());
366     CHECK_LE(limits[2], input.n3());
367     CHECK_GE(strides[0], 1);
368     CHECK_GE(strides[1], 1);
369     CHECK_GE(strides[2], 1);
370     auto result = absl::make_unique<Array3D<T>>(
371         CeilOfRatio(limits[0] - starts[0], strides[0]),
372         CeilOfRatio(limits[1] - starts[1], strides[1]),
373         CeilOfRatio(limits[2] - starts[2], strides[2]));
374 
375     for (int64_t i0 = 0; i0 < result->n1(); ++i0) {
376       for (int64_t i1 = 0; i1 < result->n2(); ++i1) {
377         for (int64_t i2 = 0; i2 < result->n3(); ++i2) {
378           (*result)(i0, i1, i2) =
379               input(starts[0] + i0 * strides[0], starts[1] + i1 * strides[1],
380                     starts[2] + i2 * strides[2]);
381         }
382       }
383     }
384     return result;
385   }
386 
387   template <typename T>
Slice4D(const Array4D<T> & input,std::array<int64,4> starts,std::array<int64,4> limits,std::array<int64,4> strides)388   static std::unique_ptr<Array4D<T>> Slice4D(const Array4D<T>& input,
389                                              std::array<int64, 4> starts,
390                                              std::array<int64, 4> limits,
391                                              std::array<int64, 4> strides) {
392     CHECK_LE(starts[0], input.n1());
393     CHECK_LE(starts[1], input.n2());
394     CHECK_LE(starts[2], input.n3());
395     CHECK_LE(starts[3], input.n4());
396     CHECK_LE(limits[0], input.n1());
397     CHECK_LE(limits[1], input.n2());
398     CHECK_LE(limits[2], input.n3());
399     CHECK_LE(limits[3], input.n4());
400     CHECK_GE(strides[0], 1);
401     CHECK_GE(strides[1], 1);
402     CHECK_GE(strides[2], 1);
403     CHECK_GE(strides[3], 1);
404     auto result = absl::make_unique<Array4D<T>>(
405         CeilOfRatio(limits[0] - starts[0], strides[0]),
406         CeilOfRatio(limits[1] - starts[1], strides[1]),
407         CeilOfRatio(limits[2] - starts[2], strides[2]),
408         CeilOfRatio(limits[3] - starts[3], strides[3]));
409     for (int64_t i0 = 0; i0 < result->n1(); ++i0) {
410       for (int64_t i1 = 0; i1 < result->n2(); ++i1) {
411         for (int64_t i2 = 0; i2 < result->n3(); ++i2) {
412           for (int64_t i3 = 0; i3 < result->n4(); ++i3) {
413             (*result)(i0, i1, i2, i3) =
414                 input(starts[0] + i0 * strides[0], starts[1] + i1 * strides[1],
415                       starts[2] + i2 * strides[2], starts[3] + i3 * strides[3]);
416           }
417         }
418       }
419     }
420     return result;
421   }
422 
423   // Applies map_function to each element in the input (2D array) and returns
424   // the result.
425   // (row, column) index of each element is also provided as arguments to
426   // map_function.
427   static std::unique_ptr<Array2D<float>> MapWithIndexArray2D(
428       const Array2D<float>& matrix,
429       const std::function<float(float, int64_t, int64_t)>& map_function);
430 
431   // Applies map_function to each element in the input (4D array) and returns
432   // the result.
433   template <typename F>
MapArray4D(const Array4D<float> & input,F && map_function)434   static std::unique_ptr<Array4D<float>> MapArray4D(const Array4D<float>& input,
435                                                     F&& map_function) {
436     return MapWithIndexArray4D(
437         input, [&](float value, int64_t, int64_t, int64_t, int64_t) {
438           return map_function(value);
439         });
440   }
441 
442   // Applies map_function to each element in the input (4D array) and returns
443   // the result.
444   // (plane, depth, height, width) index of each element is also provided as
445   // arguments to map_function.
446   template <typename F>
MapWithIndexArray4D(const Array4D<float> & input,F && map_function)447   static std::unique_ptr<Array4D<float>> MapWithIndexArray4D(
448       const Array4D<float>& input, F&& map_function) {
449     auto result = absl::make_unique<Array4D<float>>(
450         input.planes(), input.depth(), input.height(), input.width());
451     for (int64_t plane = 0; plane < input.planes(); ++plane) {
452       for (int64_t depth = 0; depth < input.depth(); ++depth) {
453         for (int64_t height = 0; height < input.height(); ++height) {
454           for (int64_t width = 0; width < input.width(); ++width) {
455             (*result)(plane, depth, height, width) =
456                 map_function(input(plane, depth, height, width), plane, depth,
457                              height, width);
458           }
459         }
460       }
461     }
462     return result;
463   }
464 
465   // Applies map_function to each pair of elements in the input lhs and rhs
466   // (4D array) and returns the result.
467   template <typename F>
MapArray4D(const Array4D<float> & lhs,const Array4D<float> & rhs,F && map_function)468   static std::unique_ptr<Array4D<float>> MapArray4D(const Array4D<float>& lhs,
469                                                     const Array4D<float>& rhs,
470                                                     F&& map_function) {
471     return MapWithIndexArray4D(
472         lhs, rhs,
473         [&](float lhs, float rhs, int64_t, int64_t, int64_t, int64_t) {
474           return map_function(lhs, rhs);
475         });
476   }
477 
478   // Applies map_function to each pair of element in lhs and rhs (4D array) and
479   // returns the result.
480   // (plane, depth, height, width) index of each element is also provided as
481   // arguments to map_function.
482   template <typename F>
MapWithIndexArray4D(const Array4D<float> & lhs,const Array4D<float> & rhs,F && map_function)483   static std::unique_ptr<Array4D<float>> MapWithIndexArray4D(
484       const Array4D<float>& lhs, const Array4D<float>& rhs, F&& map_function) {
485     auto result = absl::make_unique<Array4D<float>>(lhs.planes(), lhs.depth(),
486                                                     lhs.height(), lhs.width());
487     for (int64_t plane = 0; plane < lhs.planes(); ++plane) {
488       for (int64_t depth = 0; depth < lhs.depth(); ++depth) {
489         for (int64_t height = 0; height < lhs.height(); ++height) {
490           for (int64_t width = 0; width < lhs.width(); ++width) {
491             (*result)(plane, depth, height, width) = map_function(
492                 lhs(plane, depth, height, width),
493                 rhs(plane, depth, height, width), plane, depth, height, width);
494           }
495         }
496       }
497     }
498     return result;
499   }
500 
501   // Returns the result of a 2D pad on an input matrix.
502   template <typename NativeT>
PadArray2D(const Array2D<NativeT> & operand,const PaddingConfig & padding,const NativeT pad)503   static std::unique_ptr<Array2D<NativeT>> PadArray2D(
504       const Array2D<NativeT>& operand, const PaddingConfig& padding,
505       const NativeT pad) {
506     int64_t in0 = operand.n1();
507     int64_t high_padding0 = padding.dimensions(0).edge_padding_high();
508     int64_t low_padding0 = padding.dimensions(0).edge_padding_low();
509     int64_t interior_padding0 = padding.dimensions(0).interior_padding();
510     int64_t out0 =
511         in0 + low_padding0 + high_padding0 + (in0 - 1) * interior_padding0;
512 
513     int64_t in1 = operand.n2();
514     int64_t high_padding1 = padding.dimensions(1).edge_padding_high();
515     int64_t low_padding1 = padding.dimensions(1).edge_padding_low();
516     int64_t interior_padding1 = padding.dimensions(1).interior_padding();
517     int64_t out1 =
518         in1 + low_padding1 + high_padding1 + (in1 - 1) * interior_padding1;
519 
520     auto result = absl::make_unique<Array2D<NativeT>>(out0, out1);
521     result->Fill(pad);
522     int64_t o0 = low_padding0;
523     for (int64_t i0 = 0; i0 < in0; ++i0) {
524       int64_t o1 = low_padding1;
525       for (int64_t i1 = 0; i1 < in1; ++i1) {
526         if (o0 >= 0 && o1 >= 0 && o0 < out0 && o1 < out1) {
527           (*result)(o0, o1) = operand(i0, i1);
528         }
529         o1 += interior_padding1 + 1;
530       }
531       o0 += interior_padding0 + 1;
532     }
533     return result;
534   }
535 
536   // Returns the result of a 3D pad on an input matrix.
537   template <typename NativeT>
PadArray3D(const Array3D<NativeT> & operand,const PaddingConfig & padding,const NativeT pad)538   static Array3D<NativeT> PadArray3D(const Array3D<NativeT>& operand,
539                                      const PaddingConfig& padding,
540                                      const NativeT pad) {
541     CHECK_EQ(padding.dimensions_size(), 3);
542 
543     const int64 input_bounds[] = {operand.n1(), operand.n2(), operand.n3()};
544     int64 pad_low[3];
545     int64 pad_high[3];
546     int64 pad_interior[3];
547     int64 output_bounds[3];
548     for (int64_t i = 0; i < 3; ++i) {
549       pad_low[i] = padding.dimensions(i).edge_padding_low();
550       pad_high[i] = padding.dimensions(i).edge_padding_high();
551       CHECK_LE(0, pad_low[i]);
552       CHECK_LE(0, pad_high[i]);
553       CHECK_LE(0, padding.dimensions(i).interior_padding())
554           << "not implemented";
555       pad_interior[i] = padding.dimensions(i).interior_padding();
556 
557       output_bounds[i] = pad_low[i] + input_bounds[i] + pad_high[i] +
558                          (input_bounds[i] - 1) * pad_interior[i];
559     }
560 
561     Array3D<NativeT> result(output_bounds[0], output_bounds[1],
562                             output_bounds[2]);
563     int indices[] = {0, 0, 0};
564     for (indices[0] = 0; indices[0] < output_bounds[0]; ++indices[0]) {
565       for (indices[1] = 0; indices[1] < output_bounds[1]; ++indices[1]) {
566         for (indices[2] = 0; indices[2] < output_bounds[2]; ++indices[2]) {
567           NativeT* value = &result(indices[0], indices[1], indices[2]);
568           bool value_padded = false;
569           for (int i = 0; i < 3; ++i) {
570             bool in_low_padding = indices[i] < pad_low[i];
571             bool in_high_padding = indices[i] >= output_bounds[i] - pad_high[i];
572             if (in_low_padding || in_high_padding) {
573               *value = pad;
574               value_padded = true;
575             }
576             if (pad_interior[i] &&
577                 (indices[i] - pad_low[i]) % (pad_interior[i] + 1)) {
578               *value = pad;
579               value_padded = true;
580             }
581           }
582           if (value_padded) {
583             continue;
584           }
585           *value = operand((indices[0] - pad_low[0]) / (pad_interior[0] + 1),
586                            (indices[1] - pad_low[1]) / (pad_interior[1] + 1),
587                            (indices[2] - pad_low[2]) / (pad_interior[2] + 1));
588         }
589       }
590     }
591     return result;
592   }
593 
594   // Returns the result of a 4D pad on an input array.
595   template <typename NativeT>
PadArray4D(const Array4D<NativeT> & operand,const PaddingConfig & padding,const NativeT pad)596   static Array4D<NativeT> PadArray4D(const Array4D<NativeT>& operand,
597                                      const PaddingConfig& padding,
598                                      const NativeT pad) {
599     CHECK_EQ(padding.dimensions_size(), 4);
600 
601     const int64 input_bounds[] = {operand.n1(), operand.n2(), operand.n3(),
602                                   operand.n4()};
603     int64 pad_low[4];
604     int64 pad_high[4];
605     int64 pad_interior[4];
606     int64 output_bounds[4];
607     for (int64_t i = 0; i < 4; ++i) {
608       pad_low[i] = padding.dimensions(i).edge_padding_low();
609       pad_high[i] = padding.dimensions(i).edge_padding_high();
610       CHECK_LE(0, padding.dimensions(i).interior_padding())
611           << "not implemented";
612       pad_interior[i] = padding.dimensions(i).interior_padding();
613 
614       output_bounds[i] = pad_low[i] + input_bounds[i] + pad_high[i] +
615                          (input_bounds[i] - 1) * pad_interior[i];
616     }
617 
618     Array4D<NativeT> result(output_bounds[0], output_bounds[1],
619                             output_bounds[2], output_bounds[3]);
620     result.Each(
621         [&](absl::Span<const int64> indices, NativeT* value) {
622           for (int i = 0; i < 4; ++i) {
623             bool in_low_padding = indices[i] < pad_low[i];
624             bool in_high_padding = indices[i] >= output_bounds[i] - pad_high[i];
625             if (in_low_padding || in_high_padding) {
626               *value = pad;
627               return;
628             }
629             if (pad_interior[i] &&
630                 (indices[i] - pad_low[i]) % (pad_interior[i] + 1)) {
631               *value = pad;
632               return;
633             }
634           }
635           *value = operand((indices[0] - pad_low[0]) / (pad_interior[0] + 1),
636                            (indices[1] - pad_low[1]) / (pad_interior[1] + 1),
637                            (indices[2] - pad_low[2]) / (pad_interior[2] + 1),
638                            (indices[3] - pad_low[3]) / (pad_interior[3] + 1));
639         });
640     return result;
641   }
642 
643   // ApplyElementwise2D(f, x, y, ...) returns the Array2D formed by running
644   // f(x[i], y[i], ...) for each array element in the Array2Ds x, y, ....
645   //
646   // The given arrays must have the same size and element type, and the return
647   // type of f must be implicitly convertible to the arrays' element type.
648   //
649   // Example usage:
650   //
651   //   Array2D<float> x, y, z = ...;
652   //   std::unique_ptr<Array2D> result = ReferenceUtil::ApplyElementwise2D(
653   //     [](float a, float b, float c) { return a * b + c; }, x, y, z);
654   //
655   template <typename F, typename T1, typename... Ts>
ApplyElementwise2D(F && f,const Array2D<T1> & array1,const Array2D<Ts> &...arrays)656   static std::unique_ptr<Array2D<T1>> ApplyElementwise2D(
657       F&& f, const Array2D<T1>& array1, const Array2D<Ts>&... arrays) {
658     AssertSameSize2D(array1, arrays...);
659     auto result = absl::make_unique<Array2D<T1>>(array1.n1(), array1.n2());
660     for (int64_t i = 0; i < array1.n1(); ++i) {
661       for (int64_t j = 0; j < array1.n2(); ++j) {
662         (*result)(i, j) = f(array1(i, j), arrays(i, j)...);
663       }
664     }
665     return result;
666   }
667 
668  private:
669   template <typename T1, typename T2, typename... Ts>
AssertSameSize2D(const Array2D<T1> & array1,const Array2D<T2> & array2,const Array2D<Ts> &...arrays)670   static void AssertSameSize2D(const Array2D<T1>& array1,
671                                const Array2D<T2>& array2,
672                                const Array2D<Ts>&... arrays) {
673     static_assert(std::is_same<T1, T2>::value, "Args must be same type.");
674     CHECK_EQ(array1.n1(), array2.n1());
675     CHECK_EQ(array1.n2(), array2.n2());
676     AssertSameSize2D(array2, arrays...);
677   }
678 
679   // Recursive base case for AssertSameSize2D.
680   template <typename Array1>
AssertSameSize2D(const Array1 & array1)681   static void AssertSameSize2D(const Array1& array1) {}
682 
683   TF_DISALLOW_COPY_AND_ASSIGN(ReferenceUtil);
684 };
685 
686 }  // namespace xla
687 
688 #endif  // TENSORFLOW_COMPILER_XLA_REFERENCE_UTIL_H_
689