• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_REFERENCE_UTIL_H_
17 #define TENSORFLOW_COMPILER_XLA_REFERENCE_UTIL_H_
18 
19 #include <array>
20 #include <functional>
21 #include <memory>
22 #include <utility>
23 #include <vector>
24 
25 #include "absl/memory/memory.h"
26 #include "absl/types/span.h"
27 #include "tensorflow/compiler/xla/array2d.h"
28 #include "tensorflow/compiler/xla/array3d.h"
29 #include "tensorflow/compiler/xla/array4d.h"
30 #include "tensorflow/compiler/xla/client/padding.h"
31 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
32 #include "tensorflow/compiler/xla/util.h"
33 #include "tensorflow/compiler/xla/xla_data.pb.h"
34 #include "tensorflow/core/platform/macros.h"
35 #include "tensorflow/core/platform/types.h"
36 
37 namespace xla {
38 
39 // Utility class for reference implementations of linear algebra routines.
40 class ReferenceUtil {
41  public:
42   // Returns the result of a transpose operation on the input matrix.
43   template <typename T>
TransposeArray2D(const Array2D<T> & operand)44   static std::unique_ptr<Array2D<T>> TransposeArray2D(
45       const Array2D<T>& operand) {
46     auto result =
47         absl::make_unique<Array2D<T>>(operand.width(), operand.height());
48     for (int64 w = 0; w < operand.width(); ++w) {
49       for (int64 h = 0; h < operand.height(); ++h) {
50         (*result)(w, h) = operand(h, w);
51       }
52     }
53 
54     return result;
55   }
56 
57   // Returns the result of a matrix multiply `lhs x rhs`.
58   template <typename T>
MatmulArray2D(const Array2D<T> & lhs,const Array2D<T> & rhs)59   static std::unique_ptr<Array2D<T>> MatmulArray2D(const Array2D<T>& lhs,
60                                                    const Array2D<T>& rhs) {
61     return HloEvaluator::MatmulArray2D(lhs, rhs);
62   }
63 
64   // Converts the input operand to use f64 values instead of f32 values.
65   static std::unique_ptr<Array2D<double>> Array2DF32ToF64(
66       const Array2D<float>& input);
67 
68   // Returns the result of a convolution `lhs <conv> rhs`, with the default
69   // convolution dimension numbers returned from
70   // ComputationBuilder::CreateDefaultConvDimensionNumbers().
71   static std::unique_ptr<Array4D<float>> ConvArray4D(
72       const Array4D<float>& lhs, const Array4D<float>& rhs,
73       std::pair<int64, int64> kernel_stride, Padding padding);
74 
75   // Returns the result of a convolution `lhs <conv> rhs`, with the given
76   // convolution dimension numbers.
77   static std::unique_ptr<Array4D<float>> ConvArray4DGeneralDimensions(
78       const Array4D<float>& lhs, const Array4D<float>& rhs,
79       std::pair<int64, int64> kernel_stride, Padding padding,
80       ConvolutionDimensionNumbers dimension_numbers);
81 
82   // Returns the result of a convolution `lhs <conv> rhs`, with the given
83   // dilation factors.
84   static std::unique_ptr<Array4D<float>> ConvArray4DGeneralDimensionsDilated(
85       const Array4D<float>& lhs, const Array4D<float>& rhs,
86       std::pair<int64, int64> kernel_stride, Padding padding,
87       std::pair<int64, int64> lhs_dilation,
88       std::pair<int64, int64> rhs_dilation, ConvolutionDimensionNumbers dnums);
89 
90   // Returns the result of a convolution `lhs <conv> rhs`, with the default
91   // convolution dimension numbers returned from
92   // ComputationBuilder::CreateDefaultConvDimensionNumbers().
93   static std::unique_ptr<Array3D<float>> ConvArray3D(const Array3D<float>& lhs,
94                                                      const Array3D<float>& rhs,
95                                                      int64 kernel_stride,
96                                                      Padding padding);
97 
98   // Returns the result of a convolution `lhs <conv> rhs`.
99   static std::unique_ptr<Array3D<float>> ConvArray3DGeneralDimensionsDilated(
100       const Array3D<float>& lhs, const Array3D<float>& rhs, int64 kernel_stride,
101       Padding padding, int64 lhs_dilation, int64 rhs_dilation,
102       const ConvolutionDimensionNumbers& dnums);
103 
104   // Returns the result of a separable  convolution with the given parameters.
105   // kernel_stride and padding applies to the depthwise convolution during
106   // the separable convolution. pointwise_weights.depth() must be equal to
107   // input.depth() * depthwise_weights.planes().
108   static std::unique_ptr<Array4D<float>> SeparableConvArray4D(
109       const Array4D<float>& input, const Array4D<float>& depthwise_weights,
110       const Array4D<float>& pointwise_weights,
111       std::pair<int64, int64> kernel_stride, Padding padding);
112 
113   // Returns the result of reducing a matrix to a column vector. init is the
114   // initial value for the reduce operation, and reduce_function is the function
115   // to apply for each reduction step.
116   static std::unique_ptr<std::vector<float>> ReduceToColArray2D(
117       const Array2D<float>& matrix, float init,
118       const std::function<float(float, float)>& reduce_function);
119 
120   // Returns the result of reducing a matrix to a row vector. init is the
121   // initial value for the reduce operation, and reduce_function is the function
122   // to apply for each reduction step.
123   static std::unique_ptr<std::vector<float>> ReduceToRowArray2D(
124       const Array2D<float>& matrix, float init,
125       const std::function<float(float, float)>& reduce_function);
126 
127   // Performs a R2=>R1 reduction by reducing away the dimension specified in
128   // 'dimension_to_reduce'.
129   template <typename T>
ReduceR2ToR1(const Array2D<T> & input,int dimension_to_reduce,T init,const std::function<T (T,T)> & freduce)130   static std::vector<T> ReduceR2ToR1(const Array2D<T>& input,
131                                      int dimension_to_reduce, T init,
132                                      const std::function<T(T, T)>& freduce) {
133     std::vector<T> result(dimension_to_reduce == 0 ? input.n2() : input.n1(),
134                           init);
135     for (int i0 = 0; i0 < input.n1(); ++i0) {
136       for (int i1 = 0; i1 < input.n2(); ++i1) {
137         int output = dimension_to_reduce == 0 ? i1 : i0;
138         result[output] = freduce(result[output], input(i0, i1));
139       }
140     }
141     return result;
142   }
143 
144   // Returns the result of reducing the 4D array to a vector, reducing away
145   // the dimensions specified in dims.
146   static std::vector<float> Reduce4DTo1D(
147       const Array4D<float>& array, float init, absl::Span<const int64> dims,
148       const std::function<float(float, float)>& reduce_function);
149 
150   // Broadcast 1D dimension to 4D, from the dimension `broadcast_from_dim`.
151   static std::unique_ptr<Array4D<float>> Broadcast1DTo4D(
152       const std::vector<float>& array, const std::vector<int64>& bounds,
153       int64 broadcast_from_dim);
154 
155   // Returns the result of reducing the 3D array to a 2D array, reducing away
156   // the dimensions specified in dims.
157   static std::unique_ptr<Array2D<float>> Reduce3DTo2D(
158       const Array3D<float>& array, float init, absl::Span<const int64> dims,
159       const std::function<float(float, float)>& reduce_function);
160 
161   // Applies map_function to each element in the input (2D array) and returns
162   // the result.
163   static std::unique_ptr<Array2D<float>> MapArray2D(
164       const Array2D<float>& matrix,
165       const std::function<float(float)>& map_function);
166 
167   // Applies map_function to each pair of corresponding elements in the two
168   // inputs arrays and returns the result.
169   static std::unique_ptr<Array2D<float>> MapArray2D(
170       const Array2D<float>& lhs, const Array2D<float>& rhs,
171       const std::function<float(float, float)>& map_function);
172 
173   // Number of windows in a given dimension. Calculation taken from
174   // xla::MakePadding().
175   static int64 WindowCount(int64 unpadded_width, int64 window_len, int64 stride,
176                            Padding padding);
177 
178   // Windowed reductions with Add as the function to apply.
179   static std::unique_ptr<std::vector<float>> ReduceWindow1DAdd(
180       absl::Span<const float> operand, float init,
181       absl::Span<const int64> window, absl::Span<const int64> stride,
182       Padding padding);
183   static std::unique_ptr<Array3D<float>> ReduceWindow3DAdd(
184       const Array3D<float>& operand, float init, absl::Span<const int64> window,
185       absl::Span<const int64> stride, Padding padding);
186   static std::unique_ptr<Array4D<float>> ReduceWindow4DAdd(
187       const Array4D<float>& operand, float init, absl::Span<const int64> window,
188       absl::Span<const int64> stride, Padding padding);
189 
190   // Windowed reductions with a generic reduce function.
191   static std::unique_ptr<std::vector<float>> ReduceWindow1DGeneric(
192       absl::Span<const float> operand, float init,
193       const std::function<float(float, float)>& reduce_func,
194       absl::Span<const int64> window, absl::Span<const int64> stride,
195       absl::Span<const std::pair<int64, int64>> padding);
196   static std::unique_ptr<Array4D<float>> ReduceWindow4DGeneric(
197       const Array4D<float>& operand, float init,
198       const std::function<float(float, float)>& reduce_func,
199       absl::Span<const int64> window, absl::Span<const int64> stride,
200       Padding padding);
201   // With arbitrary padding.
202   static std::unique_ptr<Array4D<float>> ReduceWindow4DGeneric(
203       const Array4D<float>& operand, float init,
204       const std::function<float(float, float)>& reduce_func,
205       absl::Span<const int64> window, absl::Span<const int64> stride,
206       absl::Span<const std::pair<int64, int64>> padding);
207 
208   // Batch normalize data.
209   static std::unique_ptr<Array4D<float>> BatchNorm4D(
210       const Array4D<float>& input, const Array4D<float>& mean,
211       const Array4D<float>& var, const Array4D<float>& scale,
212       const Array4D<float>& offset, float epsilon);
213 
214   // Performs select and scatter with Greater Than or equal as the select, plus
215   // as the scatter, and Same Padding.
216   // TODO(b/74533103) Switch tests to evaluator and remove this implementation.
217   static std::unique_ptr<Array4D<float>> SelectAndScatter4DGePlus(
218       const Array4D<float>& operand, const Array4D<float>& source, float init,
219       absl::Span<const int64> window, absl::Span<const int64> stride,
220       bool same_padding);
221 
222   // Concatenates the lhs and rhs arrays along the concatenate_dimension.
223   // E.g. if concatenate_dimension is 0, the "n1"/height dimension is
224   // concatenated, so the arrays are stacked on top of each other.
225   template <typename T>
Concat2D(const Array2D<T> & lhs,const Array2D<T> & rhs,int concatenate_dimension)226   static std::unique_ptr<Array2D<T>> Concat2D(const Array2D<T>& lhs,
227                                               const Array2D<T>& rhs,
228                                               int concatenate_dimension) {
229     CHECK(0 <= concatenate_dimension && concatenate_dimension < 2);
230     auto result = absl::make_unique<Array2D<T>>(
231         concatenate_dimension == 0 ? lhs.n1() + rhs.n1() : lhs.n1(),
232         concatenate_dimension == 1 ? lhs.n2() + rhs.n2() : lhs.n2());
233     for (int64 i0 = 0; i0 < result->n1(); ++i0) {
234       for (int64 i1 = 0; i1 < result->n2(); ++i1) {
235         // If we exceed the bounds of the LHS, draw from the RHS, where the
236         // result index is adjusted by the number of values present in the LHS.
237         (*result)(i0, i1) = i0 < lhs.n1() && i1 < lhs.n2()
238                                 ? lhs(i0, i1)
239                                 : rhs(i0 >= lhs.n1() ? i0 - lhs.n1() : i0,
240                                       i1 >= lhs.n2() ? i1 - lhs.n2() : i1);
241       }
242     }
243     return result;
244   }
245 
246   // Concatenates the lhs and rhs 3D arrays along the concatenate_dimension. lhs
247   // and rhs must have the same dimensions except for the concatenate dimension.
248   template <typename T>
Concat3D(const Array3D<T> & lhs,const Array3D<T> & rhs,int concatenate_dimension)249   static std::unique_ptr<Array3D<T>> Concat3D(const Array3D<T>& lhs,
250                                               const Array3D<T>& rhs,
251                                               int concatenate_dimension) {
252     CHECK(0 <= concatenate_dimension && concatenate_dimension < 3);
253     const int64 lhs_dims[] = {lhs.n1(), lhs.n2(), lhs.n3()};
254     const int64 rhs_dims[] = {rhs.n1(), rhs.n2(), rhs.n3()};
255     int64 out_dims[] = {rhs.n1(), rhs.n2(), rhs.n3()};
256     for (int i = 0; i < 3; ++i) {
257       if (i != concatenate_dimension) {
258         out_dims[i] = lhs_dims[i];
259         CHECK_EQ(lhs_dims[i], rhs_dims[i]);
260       } else {
261         out_dims[i] = lhs_dims[i] + rhs_dims[i];
262       }
263     }
264     auto result =
265         absl::make_unique<Array3D<T>>(out_dims[0], out_dims[1], out_dims[2]);
266     for (int64 i0 = 0; i0 < result->n1(); ++i0) {
267       for (int64 i1 = 0; i1 < result->n2(); ++i1) {
268         for (int64 i2 = 0; i2 < result->n3(); ++i2) {
269           (*result)(i0, i1, i2) =
270               i0 < lhs.n1() && i1 < lhs.n2() && i2 < lhs.n3()
271                   ? lhs(i0, i1, i2)
272                   : rhs(i0 >= lhs.n1() ? i0 - lhs.n1() : i0,
273                         i1 >= lhs.n2() ? i1 - lhs.n2() : i1,
274                         i2 >= lhs.n3() ? i2 - lhs.n3() : i2);
275         }
276       }
277     }
278     return result;
279   }
280 
281   // Concatenates the lhs and rhs 4D arrays along the concatenate_dimension. lhs
282   // and rhs must have the same dimensions except for the concatenate dimension.
283   template <typename T>
Concat4D(const Array4D<T> & lhs,const Array4D<T> & rhs,int concatenate_dimension)284   static std::unique_ptr<Array4D<T>> Concat4D(const Array4D<T>& lhs,
285                                               const Array4D<T>& rhs,
286                                               int concatenate_dimension) {
287     CHECK(0 <= concatenate_dimension && concatenate_dimension < 4);
288     const int64 lhs_dims[] = {lhs.n1(), lhs.n2(), lhs.n3(), lhs.n4()};
289     const int64 rhs_dims[] = {rhs.n1(), rhs.n2(), rhs.n3(), rhs.n4()};
290     int64 out_dims[] = {rhs.n1(), rhs.n2(), rhs.n3(), rhs.n4()};
291     for (int i = 0; i < 4; ++i) {
292       if (i != concatenate_dimension) {
293         out_dims[i] = lhs_dims[i];
294         CHECK_EQ(lhs_dims[i], rhs_dims[i]);
295       } else {
296         out_dims[i] = lhs_dims[i] + rhs_dims[i];
297       }
298     }
299     auto result = absl::make_unique<Array4D<T>>(out_dims[0], out_dims[1],
300                                                 out_dims[2], out_dims[3]);
301     for (int64 i0 = 0; i0 < result->n1(); ++i0) {
302       for (int64 i1 = 0; i1 < result->n2(); ++i1) {
303         for (int64 i2 = 0; i2 < result->n3(); ++i2) {
304           for (int64 i3 = 0; i3 < result->n4(); ++i3) {
305             (*result)(i0, i1, i2, i3) =
306                 i0 < lhs.n1() && i1 < lhs.n2() && i2 < lhs.n3() && i3 < lhs.n4()
307                     ? lhs(i0, i1, i2, i3)
308                     : rhs(i0 >= lhs.n1() ? i0 - lhs.n1() : i0,
309                           i1 >= lhs.n2() ? i1 - lhs.n2() : i1,
310                           i2 >= lhs.n3() ? i2 - lhs.n3() : i2,
311                           i3 >= lhs.n4() ? i3 - lhs.n4() : i3);
312           }
313         }
314       }
315     }
316     return result;
317   }
318 
319   // Slices with index clamping
320   template <typename T>
ClampSlice1D(absl::Span<const T> input,int64 start,int64 size)321   static std::vector<T> ClampSlice1D(absl::Span<const T> input, int64 start,
322                                      int64 size) {
323     start = std::min<int64>(std::max<int64>(0, start), input.size() - size);
324     std::vector<T> result;
325     for (int64 i = 0; i < size; ++i) {
326       result.push_back(input[(start + i)]);
327     }
328     return result;
329   }
330 
331   // Slices the input array given starting indices, limit indices, and strides
332   // in each dimension.
333   template <typename T>
Slice2D(const Array2D<T> & input,std::array<int64,2> starts,std::array<int64,2> limits,std::array<int64,2> strides)334   static std::unique_ptr<Array2D<T>> Slice2D(const Array2D<T>& input,
335                                              std::array<int64, 2> starts,
336                                              std::array<int64, 2> limits,
337                                              std::array<int64, 2> strides) {
338     CHECK_LE(starts[0], input.n1());
339     CHECK_LE(starts[1], input.n2());
340     CHECK_LE(limits[0], input.n1());
341     CHECK_LE(limits[1], input.n2());
342     CHECK_GE(strides[0], 1);
343     CHECK_GE(strides[1], 1);
344     auto result = absl::make_unique<Array2D<T>>(
345         CeilOfRatio(limits[0] - starts[0], strides[0]),
346         CeilOfRatio(limits[1] - starts[1], strides[1]));
347     for (int64 i0 = 0; i0 < result->n1(); ++i0) {
348       for (int64 i1 = 0; i1 < result->n2(); ++i1) {
349         (*result)(i0, i1) =
350             input(starts[0] + i0 * strides[0], starts[1] + i1 * strides[1]);
351       }
352     }
353     return result;
354   }
355 
356   template <typename T>
Slice3D(const Array3D<T> & input,std::array<int64,3> starts,std::array<int64,3> limits,std::array<int64,3> strides)357   static std::unique_ptr<Array3D<T>> Slice3D(const Array3D<T>& input,
358                                              std::array<int64, 3> starts,
359                                              std::array<int64, 3> limits,
360                                              std::array<int64, 3> strides) {
361     CHECK_LE(starts[0], input.n1());
362     CHECK_LE(starts[1], input.n2());
363     CHECK_LE(starts[2], input.n3());
364     CHECK_LE(limits[0], input.n1());
365     CHECK_LE(limits[1], input.n2());
366     CHECK_LE(limits[2], input.n3());
367     CHECK_GE(strides[0], 1);
368     CHECK_GE(strides[1], 1);
369     CHECK_GE(strides[2], 1);
370     auto result = absl::make_unique<Array3D<T>>(
371         CeilOfRatio(limits[0] - starts[0], strides[0]),
372         CeilOfRatio(limits[1] - starts[1], strides[1]),
373         CeilOfRatio(limits[2] - starts[2], strides[2]));
374 
375     for (int64 i0 = 0; i0 < result->n1(); ++i0) {
376       for (int64 i1 = 0; i1 < result->n2(); ++i1) {
377         for (int64 i2 = 0; i2 < result->n3(); ++i2) {
378           (*result)(i0, i1, i2) =
379               input(starts[0] + i0 * strides[0], starts[1] + i1 * strides[1],
380                     starts[2] + i2 * strides[2]);
381         }
382       }
383     }
384     return result;
385   }
386 
387   template <typename T>
Slice4D(const Array4D<T> & input,std::array<int64,4> starts,std::array<int64,4> limits,std::array<int64,4> strides)388   static std::unique_ptr<Array4D<T>> Slice4D(const Array4D<T>& input,
389                                              std::array<int64, 4> starts,
390                                              std::array<int64, 4> limits,
391                                              std::array<int64, 4> strides) {
392     CHECK_LE(starts[0], input.n1());
393     CHECK_LE(starts[1], input.n2());
394     CHECK_LE(starts[2], input.n3());
395     CHECK_LE(starts[3], input.n4());
396     CHECK_LE(limits[0], input.n1());
397     CHECK_LE(limits[1], input.n2());
398     CHECK_LE(limits[2], input.n3());
399     CHECK_LE(limits[3], input.n4());
400     CHECK_GE(strides[0], 1);
401     CHECK_GE(strides[1], 1);
402     CHECK_GE(strides[2], 1);
403     CHECK_GE(strides[3], 1);
404     auto result = absl::make_unique<Array4D<T>>(
405         CeilOfRatio(limits[0] - starts[0], strides[0]),
406         CeilOfRatio(limits[1] - starts[1], strides[1]),
407         CeilOfRatio(limits[2] - starts[2], strides[2]),
408         CeilOfRatio(limits[3] - starts[3], strides[3]));
409     for (int64 i0 = 0; i0 < result->n1(); ++i0) {
410       for (int64 i1 = 0; i1 < result->n2(); ++i1) {
411         for (int64 i2 = 0; i2 < result->n3(); ++i2) {
412           for (int64 i3 = 0; i3 < result->n4(); ++i3) {
413             (*result)(i0, i1, i2, i3) =
414                 input(starts[0] + i0 * strides[0], starts[1] + i1 * strides[1],
415                       starts[2] + i2 * strides[2], starts[3] + i3 * strides[3]);
416           }
417         }
418       }
419     }
420     return result;
421   }
422 
423   // Applies map_function to each element in the input (2D array) and returns
424   // the result.
425   // (row, column) index of each element is also provided as arguments to
426   // map_function.
427   static std::unique_ptr<Array2D<float>> MapWithIndexArray2D(
428       const Array2D<float>& matrix,
429       const std::function<float(float, int64, int64)>& map_function);
430 
431   // Applies map_function to each element in the input (4D array) and returns
432   // the result.
433   template <typename F>
MapArray4D(const Array4D<float> & input,F && map_function)434   static std::unique_ptr<Array4D<float>> MapArray4D(const Array4D<float>& input,
435                                                     F&& map_function) {
436     return MapWithIndexArray4D(input,
437                                [&](float value, int64, int64, int64, int64) {
438                                  return map_function(value);
439                                });
440   }
441 
442   // Applies map_function to each element in the input (4D array) and returns
443   // the result.
444   // (plane, depth, height, width) index of each element is also provided as
445   // arguments to map_function.
446   template <typename F>
MapWithIndexArray4D(const Array4D<float> & input,F && map_function)447   static std::unique_ptr<Array4D<float>> MapWithIndexArray4D(
448       const Array4D<float>& input, F&& map_function) {
449     auto result = absl::make_unique<Array4D<float>>(
450         input.planes(), input.depth(), input.height(), input.width());
451     for (int64 plane = 0; plane < input.planes(); ++plane) {
452       for (int64 depth = 0; depth < input.depth(); ++depth) {
453         for (int64 height = 0; height < input.height(); ++height) {
454           for (int64 width = 0; width < input.width(); ++width) {
455             (*result)(plane, depth, height, width) =
456                 map_function(input(plane, depth, height, width), plane, depth,
457                              height, width);
458           }
459         }
460       }
461     }
462     return result;
463   }
464 
465   // Applies map_function to each pair of elements in the input lhs and rhs
466   // (4D array) and returns the result.
467   template <typename F>
MapArray4D(const Array4D<float> & lhs,const Array4D<float> & rhs,F && map_function)468   static std::unique_ptr<Array4D<float>> MapArray4D(const Array4D<float>& lhs,
469                                                     const Array4D<float>& rhs,
470                                                     F&& map_function) {
471     return MapWithIndexArray4D(
472         lhs, rhs, [&](float lhs, float rhs, int64, int64, int64, int64) {
473           return map_function(lhs, rhs);
474         });
475   }
476 
477   // Applies map_function to each pair of element in lhs and rhs (4D array) and
478   // returns the result.
479   // (plane, depth, height, width) index of each element is also provided as
480   // arguments to map_function.
481   template <typename F>
MapWithIndexArray4D(const Array4D<float> & lhs,const Array4D<float> & rhs,F && map_function)482   static std::unique_ptr<Array4D<float>> MapWithIndexArray4D(
483       const Array4D<float>& lhs, const Array4D<float>& rhs, F&& map_function) {
484     auto result = absl::make_unique<Array4D<float>>(lhs.planes(), lhs.depth(),
485                                                     lhs.height(), lhs.width());
486     for (int64 plane = 0; plane < lhs.planes(); ++plane) {
487       for (int64 depth = 0; depth < lhs.depth(); ++depth) {
488         for (int64 height = 0; height < lhs.height(); ++height) {
489           for (int64 width = 0; width < lhs.width(); ++width) {
490             (*result)(plane, depth, height, width) = map_function(
491                 lhs(plane, depth, height, width),
492                 rhs(plane, depth, height, width), plane, depth, height, width);
493           }
494         }
495       }
496     }
497     return result;
498   }
499 
500   // Returns the result of a 2D pad on an input matrix.
501   template <typename NativeT>
PadArray2D(const Array2D<NativeT> & operand,const PaddingConfig & padding,const NativeT pad)502   static std::unique_ptr<Array2D<NativeT>> PadArray2D(
503       const Array2D<NativeT>& operand, const PaddingConfig& padding,
504       const NativeT pad) {
505     int64 in0 = operand.n1();
506     int64 high_padding0 = padding.dimensions(0).edge_padding_high();
507     int64 low_padding0 = padding.dimensions(0).edge_padding_low();
508     int64 interior_padding0 = padding.dimensions(0).interior_padding();
509     int64 out0 =
510         in0 + low_padding0 + high_padding0 + (in0 - 1) * interior_padding0;
511 
512     int64 in1 = operand.n2();
513     int64 high_padding1 = padding.dimensions(1).edge_padding_high();
514     int64 low_padding1 = padding.dimensions(1).edge_padding_low();
515     int64 interior_padding1 = padding.dimensions(1).interior_padding();
516     int64 out1 =
517         in1 + low_padding1 + high_padding1 + (in1 - 1) * interior_padding1;
518 
519     auto result = absl::make_unique<Array2D<NativeT>>(out0, out1);
520     result->Fill(pad);
521     int64 o0 = low_padding0;
522     for (int64 i0 = 0; i0 < in0; ++i0) {
523       int64 o1 = low_padding1;
524       for (int64 i1 = 0; i1 < in1; ++i1) {
525         if (o0 >= 0 && o1 >= 0 && o0 < out0 && o1 < out1) {
526           (*result)(o0, o1) = operand(i0, i1);
527         }
528         o1 += interior_padding1 + 1;
529       }
530       o0 += interior_padding0 + 1;
531     }
532     return result;
533   }
534 
535   // Returns the result of a 3D pad on an input matrix.
536   template <typename NativeT>
PadArray3D(const Array3D<NativeT> & operand,const PaddingConfig & padding,const NativeT pad)537   static Array3D<NativeT> PadArray3D(const Array3D<NativeT>& operand,
538                                      const PaddingConfig& padding,
539                                      const NativeT pad) {
540     CHECK_EQ(padding.dimensions_size(), 3);
541 
542     const int64 input_bounds[] = {operand.n1(), operand.n2(), operand.n3()};
543     int64 pad_low[3];
544     int64 pad_high[3];
545     int64 pad_interior[3];
546     int64 output_bounds[3];
547     for (int64 i = 0; i < 3; ++i) {
548       pad_low[i] = padding.dimensions(i).edge_padding_low();
549       pad_high[i] = padding.dimensions(i).edge_padding_high();
550       CHECK_LE(0, pad_low[i]);
551       CHECK_LE(0, pad_high[i]);
552       CHECK_LE(0, padding.dimensions(i).interior_padding())
553           << "not implemented";
554       pad_interior[i] = padding.dimensions(i).interior_padding();
555 
556       output_bounds[i] = pad_low[i] + input_bounds[i] + pad_high[i] +
557                          (input_bounds[i] - 1) * pad_interior[i];
558     }
559 
560     Array3D<NativeT> result(output_bounds[0], output_bounds[1],
561                             output_bounds[2]);
562     int indices[] = {0, 0, 0};
563     for (indices[0] = 0; indices[0] < output_bounds[0]; ++indices[0]) {
564       for (indices[1] = 0; indices[1] < output_bounds[1]; ++indices[1]) {
565         for (indices[2] = 0; indices[2] < output_bounds[2]; ++indices[2]) {
566           NativeT* value = &result(indices[0], indices[1], indices[2]);
567           bool value_padded = false;
568           for (int i = 0; i < 3; ++i) {
569             bool in_low_padding = indices[i] < pad_low[i];
570             bool in_high_padding = indices[i] >= output_bounds[i] - pad_high[i];
571             if (in_low_padding || in_high_padding) {
572               *value = pad;
573               value_padded = true;
574             }
575             if (pad_interior[i] &&
576                 (indices[i] - pad_low[i]) % (pad_interior[i] + 1)) {
577               *value = pad;
578               value_padded = true;
579             }
580           }
581           if (value_padded) {
582             continue;
583           }
584           *value = operand((indices[0] - pad_low[0]) / (pad_interior[0] + 1),
585                            (indices[1] - pad_low[1]) / (pad_interior[1] + 1),
586                            (indices[2] - pad_low[2]) / (pad_interior[2] + 1));
587         }
588       }
589     }
590     return result;
591   }
592 
593   // Returns the result of a 4D pad on an input array.
594   template <typename NativeT>
PadArray4D(const Array4D<NativeT> & operand,const PaddingConfig & padding,const NativeT pad)595   static Array4D<NativeT> PadArray4D(const Array4D<NativeT>& operand,
596                                      const PaddingConfig& padding,
597                                      const NativeT pad) {
598     CHECK_EQ(padding.dimensions_size(), 4);
599 
600     const int64 input_bounds[] = {operand.n1(), operand.n2(), operand.n3(),
601                                   operand.n4()};
602     int64 pad_low[4];
603     int64 pad_high[4];
604     int64 pad_interior[4];
605     int64 output_bounds[4];
606     for (int64 i = 0; i < 4; ++i) {
607       pad_low[i] = padding.dimensions(i).edge_padding_low();
608       pad_high[i] = padding.dimensions(i).edge_padding_high();
609       CHECK_LE(0, padding.dimensions(i).interior_padding())
610           << "not implemented";
611       pad_interior[i] = padding.dimensions(i).interior_padding();
612 
613       output_bounds[i] = pad_low[i] + input_bounds[i] + pad_high[i] +
614                          (input_bounds[i] - 1) * pad_interior[i];
615     }
616 
617     Array4D<NativeT> result(output_bounds[0], output_bounds[1],
618                             output_bounds[2], output_bounds[3]);
619     result.Each(
620         [&](absl::Span<const int64> indices, NativeT* value) {
621           for (int i = 0; i < 4; ++i) {
622             bool in_low_padding = indices[i] < pad_low[i];
623             bool in_high_padding = indices[i] >= output_bounds[i] - pad_high[i];
624             if (in_low_padding || in_high_padding) {
625               *value = pad;
626               return;
627             }
628             if (pad_interior[i] &&
629                 (indices[i] - pad_low[i]) % (pad_interior[i] + 1)) {
630               *value = pad;
631               return;
632             }
633           }
634           *value = operand((indices[0] - pad_low[0]) / (pad_interior[0] + 1),
635                            (indices[1] - pad_low[1]) / (pad_interior[1] + 1),
636                            (indices[2] - pad_low[2]) / (pad_interior[2] + 1),
637                            (indices[3] - pad_low[3]) / (pad_interior[3] + 1));
638         });
639     return result;
640   }
641 
642   // ApplyElementwise2D(f, x, y, ...) returns the Array2D formed by running
643   // f(x[i], y[i], ...) for each array element in the Array2Ds x, y, ....
644   //
645   // The given arrays must have the same size and element type, and the return
646   // type of f must be implicitly convertible to the arrays' element type.
647   //
648   // Example usage:
649   //
650   //   Array2D<float> x, y, z = ...;
651   //   std::unique_ptr<Array2D> result = ReferenceUtil::ApplyElementwise2D(
652   //     [](float a, float b, float c) { return a * b + c; }, x, y, z);
653   //
654   template <typename F, typename T1, typename... Ts>
ApplyElementwise2D(F && f,const Array2D<T1> & array1,const Array2D<Ts> &...arrays)655   static std::unique_ptr<Array2D<T1>> ApplyElementwise2D(
656       F&& f, const Array2D<T1>& array1, const Array2D<Ts>&... arrays) {
657     AssertSameSize2D(array1, arrays...);
658     auto result = absl::make_unique<Array2D<T1>>(array1.n1(), array1.n2());
659     for (int64 i = 0; i < array1.n1(); ++i) {
660       for (int64 j = 0; j < array1.n2(); ++j) {
661         (*result)(i, j) = f(array1(i, j), arrays(i, j)...);
662       }
663     }
664     return result;
665   }
666 
667  private:
668   template <typename T1, typename T2, typename... Ts>
AssertSameSize2D(const Array2D<T1> & array1,const Array2D<T2> & array2,const Array2D<Ts> &...arrays)669   static void AssertSameSize2D(const Array2D<T1>& array1,
670                                const Array2D<T2>& array2,
671                                const Array2D<Ts>&... arrays) {
672     static_assert(std::is_same<T1, T2>::value, "Args must be same type.");
673     CHECK_EQ(array1.n1(), array2.n1());
674     CHECK_EQ(array1.n2(), array2.n2());
675     AssertSameSize2D(array2, arrays...);
676   }
677 
678   // Recursive base case for AssertSameSize2D.
679   template <typename Array1>
AssertSameSize2D(const Array1 & array1)680   static void AssertSameSize2D(const Array1& array1) {}
681 
682   TF_DISALLOW_COPY_AND_ASSIGN(ReferenceUtil);
683 };
684 
685 }  // namespace xla
686 
687 #endif  // TENSORFLOW_COMPILER_XLA_REFERENCE_UTIL_H_
688