• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/reference_util.h"
17 
18 #include <array>
19 #include <utility>
20 
21 #include "absl/container/flat_hash_set.h"
22 #include "absl/memory/memory.h"
23 #include "tensorflow/compiler/xla/client/xla_builder.h"
24 #include "tensorflow/compiler/xla/literal_util.h"
25 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
26 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
27 #include "tensorflow/compiler/xla/service/shape_inference.h"
28 #include "tensorflow/compiler/xla/window_util.h"
29 #include "tensorflow/compiler/xla/xla_data.pb.h"
30 #include "tensorflow/core/lib/math/math_util.h"
31 #include "tensorflow/core/platform/logging.h"
32 
33 namespace xla {
34 
Array2DF32ToF64(const Array2D<float> & input)35 /* static */ std::unique_ptr<Array2D<double>> ReferenceUtil::Array2DF32ToF64(
36     const Array2D<float>& input) {
37   auto result =
38       absl::make_unique<Array2D<double>>(input.height(), input.width());
39   for (int64_t rowno = 0; rowno < input.height(); ++rowno) {
40     for (int64_t colno = 0; colno < input.height(); ++colno) {
41       (*result)(rowno, colno) = input(rowno, colno);
42     }
43   }
44   return result;
45 }
46 
ConvArray3D(const Array3D<float> & lhs,const Array3D<float> & rhs,int64_t kernel_stride,Padding padding)47 /*  static */ std::unique_ptr<Array3D<float>> ReferenceUtil::ConvArray3D(
48     const Array3D<float>& lhs, const Array3D<float>& rhs, int64_t kernel_stride,
49     Padding padding) {
50   return ConvArray3DGeneralDimensionsDilated(
51       lhs, rhs, kernel_stride, padding, 1, 1,
52       XlaBuilder::CreateDefaultConvDimensionNumbers(1));
53 }
54 
55 /*static*/ std::unique_ptr<Array3D<float>>
ConvArray3DGeneralDimensionsDilated(const Array3D<float> & lhs,const Array3D<float> & rhs,int64_t kernel_stride,Padding padding,int64_t lhs_dilation,int64_t rhs_dilation,const ConvolutionDimensionNumbers & dnums)56 ReferenceUtil::ConvArray3DGeneralDimensionsDilated(
57     const Array3D<float>& lhs, const Array3D<float>& rhs, int64_t kernel_stride,
58     Padding padding, int64_t lhs_dilation, int64_t rhs_dilation,
59     const ConvolutionDimensionNumbers& dnums) {
60   CHECK_EQ(dnums.input_spatial_dimensions_size(), 1);
61   CHECK_EQ(dnums.kernel_spatial_dimensions_size(), 1);
62   CHECK_EQ(dnums.output_spatial_dimensions_size(), 1);
63   // Reuse the code for Array4D-convolution by extending the 3D input into a 4D
64   // array by adding a fourth dummy dimension of size 1 without stride, padding
65   // and dilation.
66   Array4D<float> a4dlhs(lhs.n1(), lhs.n2(), lhs.n3(), 1);
67   a4dlhs.Each([&](absl::Span<const int64> indices, float* value_ptr) {
68     CHECK_EQ(indices[3], 0);
69     *value_ptr = lhs.operator()(indices[0], indices[1], indices[2]);
70   });
71   Array4D<float> a4drhs(rhs.n1(), rhs.n2(), rhs.n3(), 1);
72   a4drhs.Each([&](absl::Span<const int64> indices, float* value_ptr) {
73     CHECK_EQ(indices[3], 0);
74     *value_ptr = rhs.operator()(indices[0], indices[1], indices[2]);
75   });
76   // Add a second dummy spatial dimensions.
77   ConvolutionDimensionNumbers dnums2d = dnums;
78   dnums2d.add_input_spatial_dimensions(3);
79   dnums2d.add_kernel_spatial_dimensions(3);
80   dnums2d.add_output_spatial_dimensions(3);
81   std::unique_ptr<Array4D<float>> convr4 = ConvArray4DGeneralDimensionsDilated(
82       a4dlhs, a4drhs, {kernel_stride, 1}, padding, {lhs_dilation, 1},
83       {rhs_dilation, 1}, dnums2d);
84 
85   auto convr3 = absl::make_unique<Array3D<float>>(
86       convr4->planes(), convr4->depth(), convr4->height());
87   convr4->Each([&](absl::Span<const int64> indices, float* value_ptr) {
88     CHECK_EQ(indices[3], 0);
89     convr3->operator()(indices[0], indices[1], indices[2]) = *value_ptr;
90   });
91   return convr3;
92 }
93 
ConvArray4D(const Array4D<float> & lhs,const Array4D<float> & rhs,std::pair<int64,int64> kernel_stride,Padding padding)94 /* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::ConvArray4D(
95     const Array4D<float>& lhs, const Array4D<float>& rhs,
96     std::pair<int64, int64> kernel_stride, Padding padding) {
97   return ConvArray4DGeneralDimensions(
98       lhs, rhs, kernel_stride, padding,
99       XlaBuilder::CreateDefaultConvDimensionNumbers());
100 }
101 
102 /* static */ std::unique_ptr<Array4D<float>>
SeparableConvArray4D(const Array4D<float> & input,const Array4D<float> & depthwise_weights,const Array4D<float> & pointwise_weights,std::pair<int64,int64> kernel_stride,Padding padding)103 ReferenceUtil::SeparableConvArray4D(const Array4D<float>& input,
104                                     const Array4D<float>& depthwise_weights,
105                                     const Array4D<float>& pointwise_weights,
106                                     std::pair<int64, int64> kernel_stride,
107                                     Padding padding) {
108   const int64_t depth_multiplier = depthwise_weights.planes();
109   CHECK_EQ(pointwise_weights.depth(), input.depth() * depth_multiplier);
110 
111   // Combine the two weights by reducing the depth_multiplier, so that we can
112   // apply a single convolution on the combined weights.
113   Array4D<float> weights(pointwise_weights.planes(), input.depth(),
114                          depthwise_weights.height(), depthwise_weights.width());
115   for (int64_t kx = 0; kx < depthwise_weights.width(); ++kx) {
116     for (int64_t ky = 0; ky < depthwise_weights.height(); ++ky) {
117       for (int64_t kz = 0; kz < input.depth(); ++kz) {
118         for (int64_t out = 0; out < pointwise_weights.planes(); ++out) {
119           float weight = 0.0;
120           for (int64_t depth = 0; depth < depth_multiplier; ++depth) {
121             weight +=
122                 depthwise_weights(depth, kz, ky, kx) *
123                 pointwise_weights(out, depth + kz * depth_multiplier, 0, 0);
124           }
125           weights(out, kz, ky, kx) = weight;
126         }
127       }
128     }
129   }
130 
131   return ConvArray4D(input, weights, kernel_stride, padding);
132 }
133 
WindowCount(int64_t unpadded_width,int64_t window_len,int64_t stride,Padding padding)134 /* static */ int64 ReferenceUtil::WindowCount(int64_t unpadded_width,
135                                               int64_t window_len,
136                                               int64_t stride, Padding padding) {
137   if (padding == Padding::kValid) {
138     return window_util::StridedBound(unpadded_width, window_len, stride);
139   }
140   return tensorflow::MathUtil::CeilOfRatio(unpadded_width, stride);
141 }
142 
143 /* static  */ std::unique_ptr<std::vector<float>>
ReduceWindow1DGeneric(absl::Span<const float> operand,float init,const std::function<float (float,float)> & reduce_func,absl::Span<const int64> window,absl::Span<const int64> stride,absl::Span<const std::pair<int64,int64>> padding)144 ReferenceUtil::ReduceWindow1DGeneric(
145     absl::Span<const float> operand, float init,
146     const std::function<float(float, float)>& reduce_func,
147     absl::Span<const int64> window, absl::Span<const int64> stride,
148     absl::Span<const std::pair<int64, int64>> padding) {
149   CHECK_EQ(window.size(), 1);
150   CHECK_EQ(stride.size(), 1);
151   CHECK_EQ(padding.size(), 1);
152 
153   int64_t padded_width = padding[0].first + operand.size() + padding[0].second;
154   int64_t stride_amount = stride[0];
155   int64_t window_size = window[0];
156   int64_t result_size =
157       window_util::StridedBound(padded_width, window_size, stride_amount);
158   int64_t pad_low = padding[0].first;
159   auto result = absl::make_unique<std::vector<float>>(result_size);
160 
161   // Do a full 1D reduce window.
162   for (int64_t i0 = 0; i0 < result_size; ++i0) {
163     int64_t i0_base = i0 * stride_amount - pad_low;
164     float val = init;
165     for (int64_t i0_win = 0; i0_win < window_size; ++i0_win) {
166       if (i0_base + i0_win >= 0 && i0_base + i0_win < operand.size()) {
167         val = reduce_func(val, operand[i0_base + i0_win]);
168       }
169     }
170     (*result)[i0] = val;
171   }
172   return result;
173 }
174 
175 /* static  */ std::unique_ptr<std::vector<float>>
ReduceWindow1DAdd(absl::Span<const float> operand,float init,absl::Span<const int64> window,absl::Span<const int64> stride,Padding padding)176 ReferenceUtil::ReduceWindow1DAdd(absl::Span<const float> operand, float init,
177                                  absl::Span<const int64> window,
178                                  absl::Span<const int64> stride,
179                                  Padding padding) {
180   const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
181   std::vector<int64> dim_lengths{static_cast<int64>(operand.size())};
182   return ReduceWindow1DGeneric(
183       operand, init, add_reduce, window, stride,
184       xla::MakePadding(dim_lengths, window, stride, padding));
185 }
186 
ReduceWindow3DAdd(const Array3D<float> & operand,float init,absl::Span<const int64> window,absl::Span<const int64> stride,Padding padding)187 /* static  */ std::unique_ptr<Array3D<float>> ReferenceUtil::ReduceWindow3DAdd(
188     const Array3D<float>& operand, float init, absl::Span<const int64> window,
189     absl::Span<const int64> stride, Padding padding) {
190   std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3()};
191   auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding);
192 
193   std::vector<int64> window_counts(window.size(), 0);
194   std::vector<int64> pad_low(window.size(), 0);
195   for (int64_t i = 0; i < window.size(); ++i) {
196     window_counts[i] =
197         WindowCount(dim_lengths[i], window[i], stride[i], padding);
198     pad_low[i] = padding_both[i].first;
199   }
200   auto result = absl::make_unique<Array3D<float>>(
201       window_counts[0], window_counts[1], window_counts[2]);
202 
203   for (int64_t i0 = 0; i0 < window_counts[0]; ++i0) {
204     for (int64_t i1 = 0; i1 < window_counts[1]; ++i1) {
205       for (int64_t i2 = 0; i2 < window_counts[2]; ++i2) {
206         int64_t i0_base = i0 * stride[0] - pad_low[0];
207         int64_t i1_base = i1 * stride[1] - pad_low[1];
208         int64_t i2_base = i2 * stride[2] - pad_low[2];
209 
210         float val = init;
211         for (int64_t i0_win = 0; i0_win < window[0]; ++i0_win) {
212           for (int64_t i1_win = 0; i1_win < window[1]; ++i1_win) {
213             for (int64_t i2_win = 0; i2_win < window[2]; ++i2_win) {
214               if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 &&
215                   i2_base + i2_win >= 0 && i0_base + i0_win < operand.n1() &&
216                   i1_base + i1_win < operand.n2() &&
217                   i2_base + i2_win < operand.n3()) {
218                 val += operand(i0_base + i0_win, i1_base + i1_win,
219                                i2_base + i2_win);
220               }
221             }
222           }
223         }
224         (*result)(i0, i1, i2) = val;
225       }
226     }
227   }
228   return result;
229 }
230 
231 /* static */ std::unique_ptr<Array4D<float>>
ReduceWindow4DGeneric(const Array4D<float> & operand,float init,const std::function<float (float,float)> & reduce_func,absl::Span<const int64> window,absl::Span<const int64> stride,Padding padding)232 ReferenceUtil::ReduceWindow4DGeneric(
233     const Array4D<float>& operand, float init,
234     const std::function<float(float, float)>& reduce_func,
235     absl::Span<const int64> window, absl::Span<const int64> stride,
236     Padding padding) {
237   std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
238                                  operand.n4()};
239   return ReduceWindow4DGeneric(
240       operand, init, reduce_func, window, stride,
241       xla::MakePadding(dim_lengths, window, stride, padding));
242 }
243 
244 /* static */ std::unique_ptr<Array4D<float>>
ReduceWindow4DGeneric(const Array4D<float> & operand,float init,const std::function<float (float,float)> & reduce_func,absl::Span<const int64> window,absl::Span<const int64> stride,absl::Span<const std::pair<int64,int64>> padding)245 ReferenceUtil::ReduceWindow4DGeneric(
246     const Array4D<float>& operand, float init,
247     const std::function<float(float, float)>& reduce_func,
248     absl::Span<const int64> window, absl::Span<const int64> stride,
249     absl::Span<const std::pair<int64, int64>> padding) {
250   std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
251                                  operand.n4()};
252 
253   std::vector<int64> window_counts(window.size(), 0);
254   std::vector<int64> pad_low(window.size(), 0);
255   for (int64_t i = 0; i < window.size(); ++i) {
256     int64_t padded_width =
257         padding[i].first + dim_lengths[i] + padding[i].second;
258     window_counts[i] =
259         window_util::StridedBound(padded_width, window[i], stride[i]);
260     pad_low[i] = padding[i].first;
261   }
262   auto result = absl::make_unique<Array4D<float>>(
263       window_counts[0], window_counts[1], window_counts[2], window_counts[3]);
264   // Do a full 4D reduce window.
265   for (int64_t i0 = 0; i0 < window_counts[0]; ++i0) {
266     for (int64_t i1 = 0; i1 < window_counts[1]; ++i1) {
267       for (int64_t i2 = 0; i2 < window_counts[2]; ++i2) {
268         for (int64_t i3 = 0; i3 < window_counts[3]; ++i3) {
269           int64_t i0_base = i0 * stride[0] - pad_low[0];
270           int64_t i1_base = i1 * stride[1] - pad_low[1];
271           int64_t i2_base = i2 * stride[2] - pad_low[2];
272           int64_t i3_base = i3 * stride[3] - pad_low[3];
273 
274           float val = init;
275           for (int64_t i0_win = 0; i0_win < window[0]; ++i0_win) {
276             for (int64_t i1_win = 0; i1_win < window[1]; ++i1_win) {
277               for (int64_t i2_win = 0; i2_win < window[2]; ++i2_win) {
278                 for (int64_t i3_win = 0; i3_win < window[3]; ++i3_win) {
279                   if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 &&
280                       i2_base + i2_win >= 0 && i3_base + i3_win >= 0 &&
281                       i0_base + i0_win < operand.n1() &&
282                       i1_base + i1_win < operand.n2() &&
283                       i2_base + i2_win < operand.n3() &&
284                       i3_base + i3_win < operand.n4()) {
285                     val = reduce_func(
286                         val, operand(i0_base + i0_win, i1_base + i1_win,
287                                      i2_base + i2_win, i3_base + i3_win));
288                   }
289                 }
290               }
291             }
292           }
293           (*result)(i0, i1, i2, i3) = val;
294         }
295       }
296     }
297   }
298   return result;
299 }
300 
ReduceWindow4DAdd(const Array4D<float> & operand,float init,absl::Span<const int64> window,absl::Span<const int64> stride,Padding padding)301 /* static  */ std::unique_ptr<Array4D<float>> ReferenceUtil::ReduceWindow4DAdd(
302     const Array4D<float>& operand, float init, absl::Span<const int64> window,
303     absl::Span<const int64> stride, Padding padding) {
304   const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
305   return ReduceWindow4DGeneric(operand, init, add_reduce, window, stride,
306                                padding);
307 }
308 
BatchNorm4D(const Array4D<float> & input,const Array4D<float> & mean,const Array4D<float> & var,const Array4D<float> & scale,const Array4D<float> & offset,float epsilon)309 /* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::BatchNorm4D(
310     const Array4D<float>& input, const Array4D<float>& mean,
311     const Array4D<float>& var, const Array4D<float>& scale,
312     const Array4D<float>& offset, float epsilon) {
313   auto normalized =
314       *MapArray4D(input, mean, [](float a, float b) { return a - b; });
315   normalized = *MapArray4D(normalized, var, [&](float a, float b) {
316     return a / std::sqrt(b + epsilon);
317   });
318   normalized =
319       *MapArray4D(normalized, scale, [](float a, float b) { return a * b; });
320   return MapArray4D(normalized, offset, [](float a, float b) { return a + b; });
321 }
322 
323 /* static  */ std::unique_ptr<Array4D<float>>
SelectAndScatter4DGePlus(const Array4D<float> & operand,const Array4D<float> & source,float init,absl::Span<const int64> window,absl::Span<const int64> stride,bool same_padding)324 ReferenceUtil::SelectAndScatter4DGePlus(const Array4D<float>& operand,
325                                         const Array4D<float>& source,
326                                         float init,
327                                         absl::Span<const int64> window,
328                                         absl::Span<const int64> stride,
329                                         bool same_padding) {
330   Padding padding = same_padding ? Padding::kSame : Padding::kValid;
331   auto result = absl::make_unique<Array4D<float>>(operand.n1(), operand.n2(),
332                                                   operand.n3(), operand.n4());
333   std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
334                                  operand.n4()};
335   auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding);
336   // Fill the output, with the initial value.
337   result->Fill(init);
338 
339   std::vector<int64> window_counts(window.size(), 0);
340   std::vector<int64> pad_low(window.size(), 0);
341   for (int64_t i = 0; i < window.size(); ++i) {
342     window_counts[i] =
343         WindowCount(dim_lengths[i], window[i], stride[i], padding);
344     pad_low[i] = padding_both[i].first;
345   }
346   CHECK_EQ(window_counts[0], source.n1());
347   CHECK_EQ(window_counts[1], source.n2());
348   CHECK_EQ(window_counts[2], source.n3());
349   CHECK_EQ(window_counts[3], source.n4());
350 
351   // Do a full 4D select and Scatter.
352   for (int64_t i0 = 0; i0 < window_counts[0]; ++i0) {
353     for (int64_t i1 = 0; i1 < window_counts[1]; ++i1) {
354       for (int64_t i2 = 0; i2 < window_counts[2]; ++i2) {
355         for (int64_t i3 = 0; i3 < window_counts[3]; ++i3) {
356           // Now we are inside a window and need to find the max and the argmax.
357           int64_t i0_base = i0 * stride[0] - pad_low[0];
358           int64_t i1_base = i1 * stride[1] - pad_low[1];
359           int64_t i2_base = i2 * stride[2] - pad_low[2];
360           int64_t i3_base = i3 * stride[3] - pad_low[3];
361           int64_t scatter_0 = (i0_base >= 0) ? i0_base : 0;
362           int64_t scatter_1 = (i1_base >= 0) ? i1_base : 0;
363           int64_t scatter_2 = (i2_base >= 0) ? i2_base : 0;
364           int64_t scatter_3 = (i3_base >= 0) ? i3_base : 0;
365           float val = operand(scatter_0, scatter_1, scatter_2, scatter_3);
366           for (int64_t i0_win = 0; i0_win < window[0]; ++i0_win) {
367             for (int64_t i1_win = 0; i1_win < window[1]; ++i1_win) {
368               for (int64_t i2_win = 0; i2_win < window[2]; ++i2_win) {
369                 for (int64_t i3_win = 0; i3_win < window[3]; ++i3_win) {
370                   if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 &&
371                       i2_base + i2_win >= 0 && i3_base + i3_win >= 0 &&
372                       i0_base + i0_win < operand.n1() &&
373                       i1_base + i1_win < operand.n2() &&
374                       i2_base + i2_win < operand.n3() &&
375                       i3_base + i3_win < operand.n4()) {
376                     float tmp = operand(i0_base + i0_win, i1_base + i1_win,
377                                         i2_base + i2_win, i3_base + i3_win);
378                     if (tmp > val) {
379                       val = tmp;
380                       scatter_0 = i0_base + i0_win;
381                       scatter_1 = i1_base + i1_win;
382                       scatter_2 = i2_base + i2_win;
383                       scatter_3 = i3_base + i3_win;
384                     }
385                   }
386                 }
387               }
388             }
389           }
390           (*result)(scatter_0, scatter_1, scatter_2, scatter_3) +=
391               source(i0, i1, i2, i3);
392         }
393       }
394     }
395   }
396   return result;
397 }
398 
399 /* static */ std::unique_ptr<Array4D<float>>
ConvArray4DGeneralDimensions(const Array4D<float> & lhs,const Array4D<float> & rhs,std::pair<int64,int64> kernel_stride,Padding padding,ConvolutionDimensionNumbers dimension_numbers)400 ReferenceUtil::ConvArray4DGeneralDimensions(
401     const Array4D<float>& lhs, const Array4D<float>& rhs,
402     std::pair<int64, int64> kernel_stride, Padding padding,
403     ConvolutionDimensionNumbers dimension_numbers) {
404   return ConvArray4DGeneralDimensionsDilated(lhs, rhs, kernel_stride, padding,
405                                              {1, 1}, {1, 1},
406                                              std::move(dimension_numbers));
407 }
408 
409 /* static */ std::unique_ptr<Array4D<float>>
ConvArray4DGeneralDimensionsDilated(const Array4D<float> & lhs,const Array4D<float> & rhs,std::pair<int64,int64> kernel_stride,Padding padding,std::pair<int64,int64> lhs_dilation,std::pair<int64,int64> rhs_dilation,ConvolutionDimensionNumbers dnums)410 ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
411     const Array4D<float>& lhs, const Array4D<float>& rhs,
412     std::pair<int64, int64> kernel_stride, Padding padding,
413     std::pair<int64, int64> lhs_dilation, std::pair<int64, int64> rhs_dilation,
414     ConvolutionDimensionNumbers dnums) {
415   HloComputation::Builder b("ConvArray4DGeneralDimensionDilated");
416   auto lhs_literal = LiteralUtil::CreateR4FromArray4D<float>(lhs);
417   auto rhs_literal = LiteralUtil::CreateR4FromArray4D<float>(rhs);
418 
419   std::array<int64, 2> ordered_kernel_strides;
420   std::array<int64, 2> ordered_input_dimensions;
421   std::array<int64, 2> ordered_kernel_dimensions;
422   if (dnums.kernel_spatial_dimensions(0) > dnums.kernel_spatial_dimensions(1)) {
423     ordered_kernel_strides[0] = kernel_stride.second;
424     ordered_kernel_strides[1] = kernel_stride.first;
425   } else {
426     ordered_kernel_strides[0] = kernel_stride.first;
427     ordered_kernel_strides[1] = kernel_stride.second;
428   }
429 
430   ordered_input_dimensions[0] =
431       lhs_literal.shape().dimensions(dnums.input_spatial_dimensions(0));
432   ordered_input_dimensions[1] =
433       lhs_literal.shape().dimensions(dnums.input_spatial_dimensions(1));
434   ordered_kernel_dimensions[0] =
435       rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(0));
436   ordered_kernel_dimensions[1] =
437       rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(1));
438 
439   std::vector<std::pair<int64, int64>> paddings =
440       MakePadding(ordered_input_dimensions, ordered_kernel_dimensions,
441                   ordered_kernel_strides, padding);
442   CHECK_EQ(paddings.size(), 2);
443 
444   Window window;
445 
446   WindowDimension dim;
447   dim.set_size(
448       rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(0)));
449   dim.set_stride(kernel_stride.first);
450   dim.set_padding_low(paddings[0].first);
451   dim.set_padding_high(paddings[0].second);
452   dim.set_window_dilation(rhs_dilation.first);
453   dim.set_base_dilation(lhs_dilation.first);
454   *window.add_dimensions() = dim;
455 
456   WindowDimension dim2;
457   dim2.set_size(
458       rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(1)));
459   dim2.set_stride(kernel_stride.second);
460   dim2.set_padding_low(paddings[1].first);
461   dim2.set_padding_high(paddings[1].second);
462   dim2.set_window_dilation(rhs_dilation.second);
463   dim2.set_base_dilation(lhs_dilation.second);
464   *window.add_dimensions() = dim2;
465 
466   const Shape& shape =
467       ShapeInference::InferConvolveShape(
468           lhs_literal.shape(), rhs_literal.shape(),
469           /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums,
470           /*preferred_element_type=*/absl::nullopt)
471           .ConsumeValueOrDie();
472 
473   HloInstruction* lhs_instruction =
474       b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
475   HloInstruction* rhs_instruction =
476       b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
477 
478   PrecisionConfig precision_config;
479   precision_config.mutable_operand_precision()->Resize(
480       /*new_size=*/2, PrecisionConfig::DEFAULT);
481   b.AddInstruction(HloInstruction::CreateConvolve(
482       shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
483       /*batch_group_count=*/1, window, dnums, precision_config));
484   HloModuleConfig config;
485   HloModule module("ReferenceUtil", config);
486   auto computation = module.AddEntryComputation(b.Build());
487 
488   HloEvaluator evaluator;
489   Literal result_literal =
490       evaluator.Evaluate(*computation, {}).ConsumeValueOrDie();
491 
492   CHECK_EQ(result_literal.shape().rank(), 4);
493   auto result =
494       absl::make_unique<Array4D<float>>(result_literal.shape().dimensions(0),
495                                         result_literal.shape().dimensions(1),
496                                         result_literal.shape().dimensions(2),
497                                         result_literal.shape().dimensions(3));
498 
499   result->Each([&](absl::Span<const int64> indices, float* value) {
500     *value = result_literal.Get<float>(indices);
501   });
502 
503   return result;
504 }
505 
506 /* static */ std::unique_ptr<std::vector<float>>
ReduceToColArray2D(const Array2D<float> & matrix,float init,const std::function<float (float,float)> & reduce_function)507 ReferenceUtil::ReduceToColArray2D(
508     const Array2D<float>& matrix, float init,
509     const std::function<float(float, float)>& reduce_function) {
510   int64_t rows = matrix.height();
511   int64_t cols = matrix.width();
512   auto result = absl::make_unique<std::vector<float>>();
513   for (int64_t i = 0; i < rows; ++i) {
514     float acc = init;
515     for (int64_t j = 0; j < cols; ++j) {
516       acc = reduce_function(acc, matrix(i, j));
517     }
518     result->push_back(acc);
519   }
520   return result;
521 }
522 
523 /* static */ std::unique_ptr<std::vector<float>>
ReduceToRowArray2D(const Array2D<float> & matrix,float init,const std::function<float (float,float)> & reduce_function)524 ReferenceUtil::ReduceToRowArray2D(
525     const Array2D<float>& matrix, float init,
526     const std::function<float(float, float)>& reduce_function) {
527   int64_t rows = matrix.height();
528   int64_t cols = matrix.width();
529   auto result = absl::make_unique<std::vector<float>>();
530   for (int64_t i = 0; i < cols; ++i) {
531     float acc = init;
532     for (int64_t j = 0; j < rows; ++j) {
533       acc = reduce_function(acc, matrix(j, i));
534     }
535     result->push_back(acc);
536   }
537   return result;
538 }
539 
Reduce4DTo1D(const Array4D<float> & array,float init,absl::Span<const int64> dims,const std::function<float (float,float)> & reduce_function)540 /*static*/ std::vector<float> ReferenceUtil::Reduce4DTo1D(
541     const Array4D<float>& array, float init, absl::Span<const int64> dims,
542     const std::function<float(float, float)>& reduce_function) {
543   std::vector<float> result;
544   CHECK_EQ(dims.size(), 3);
545   const absl::flat_hash_set<int64> dim_set(dims.begin(), dims.end());
546   CHECK_EQ(dim_set.size(), 3);
547   for (int64_t a0 = 0; a0 == 0 || (!dim_set.contains(0) && a0 < array.n1());
548        ++a0) {
549     for (int64_t a1 = 0; a1 == 0 || (!dim_set.contains(1) && a1 < array.n2());
550          ++a1) {
551       for (int64_t a2 = 0; a2 == 0 || (!dim_set.contains(2) && a2 < array.n3());
552            ++a2) {
553         for (int64_t a3 = 0;
554              a3 == 0 || (!dim_set.contains(3) && a3 < array.n4()); ++a3) {
555           float accumulator = init;
556           for (int64_t i0 = 0;
557                i0 == 0 || (dim_set.contains(0) && i0 < array.n1()); ++i0) {
558             for (int64_t i1 = 0;
559                  i1 == 0 || (dim_set.contains(1) && i1 < array.n2()); ++i1) {
560               for (int64_t i2 = 0;
561                    i2 == 0 || (dim_set.contains(2) && i2 < array.n3()); ++i2) {
562                 for (int64_t i3 = 0;
563                      i3 == 0 || (dim_set.contains(3) && i3 < array.n4());
564                      ++i3) {
565                   // Handle zero-sized arrays.
566                   if (array.n1() > 0 && array.n2() > 0 && array.n3() > 0 &&
567                       array.n4() > 0) {
568                     accumulator = reduce_function(
569                         accumulator, array(a0 + i0, a1 + i1, a2 + i2, a3 + i3));
570                   }
571                 }
572               }
573             }
574           }
575           result.push_back(accumulator);
576         }
577       }
578     }
579   }
580   return result;
581 }
582 
Broadcast1DTo4D(const std::vector<float> & array,const std::vector<int64> & bounds,int64_t broadcast_from_dim)583 /* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::Broadcast1DTo4D(
584     const std::vector<float>& array, const std::vector<int64>& bounds,
585     int64_t broadcast_from_dim) {
586   auto result = absl::make_unique<Array4D<float>>(bounds[0], bounds[1],
587                                                   bounds[2], bounds[3]);
588   for (int64_t i = 0; i < result->n1(); ++i) {
589     for (int64_t j = 0; j < result->n2(); ++j) {
590       for (int64_t k = 0; k < result->n3(); ++k) {
591         for (int64_t l = 0; l < result->n4(); ++l) {
592           switch (broadcast_from_dim) {
593             case 0:
594               (*result)(i, j, k, l) = array[i];
595               break;
596             case 1:
597               (*result)(i, j, k, l) = array[j];
598               break;
599             case 2:
600               (*result)(i, j, k, l) = array[k];
601               break;
602             case 3:
603               (*result)(i, j, k, l) = array[l];
604               break;
605             default:
606               break;
607           }
608         }
609       }
610     }
611   }
612   return result;
613 }
614 
Reduce3DTo2D(const Array3D<float> & array,float init,absl::Span<const int64> dims,const std::function<float (float,float)> & reduce_function)615 /* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::Reduce3DTo2D(
616     const Array3D<float>& array, float init, absl::Span<const int64> dims,
617     const std::function<float(float, float)>& reduce_function) {
618   CHECK_EQ(dims.size(), 1);
619   int64_t rows = dims[0] == 0 ? array.n2() : array.n1();
620   int64_t cols = dims[0] == 2 ? array.n2() : array.n3();
621   auto result = absl::make_unique<Array2D<float>>(rows, cols);
622   result->Fill(init);
623   for (int i0 = 0; i0 < array.n1(); ++i0) {
624     for (int i1 = 0; i1 < array.n2(); ++i1) {
625       for (int i2 = 0; i2 < array.n3(); ++i2) {
626         int64_t row = dims[0] == 0 ? i1 : i0;
627         int64_t col = dims[0] == 2 ? i1 : i2;
628         (*result)(row, col) =
629             reduce_function((*result)(row, col), array(i0, i1, i2));
630       }
631     }
632   }
633   return result;
634 }
635 
MapArray2D(const Array2D<float> & matrix,const std::function<float (float)> & map_function)636 /* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapArray2D(
637     const Array2D<float>& matrix,
638     const std::function<float(float)>& map_function) {
639   int64_t rows = matrix.height();
640   int64_t cols = matrix.width();
641   auto result = absl::make_unique<Array2D<float>>(rows, cols);
642   for (int64_t i = 0; i < rows; ++i) {
643     for (int64_t j = 0; j < cols; ++j) {
644       (*result)(i, j) = map_function(matrix(i, j));
645     }
646   }
647   return result;
648 }
649 
MapArray2D(const Array2D<float> & lhs,const Array2D<float> & rhs,const std::function<float (float,float)> & map_function)650 /* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapArray2D(
651     const Array2D<float>& lhs, const Array2D<float>& rhs,
652     const std::function<float(float, float)>& map_function) {
653   CHECK_EQ(lhs.height(), rhs.height());
654   CHECK_EQ(lhs.width(), rhs.width());
655   int64_t rows = lhs.height();
656   int64_t cols = rhs.width();
657   auto result = absl::make_unique<Array2D<float>>(rows, cols);
658   for (int64_t i = 0; i < rows; ++i) {
659     for (int64_t j = 0; j < cols; ++j) {
660       (*result)(i, j) = map_function(lhs(i, j), rhs(i, j));
661     }
662   }
663   return result;
664 }
665 
MapWithIndexArray2D(const Array2D<float> & matrix,const std::function<float (float,int64_t,int64_t)> & map_function)666 /* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapWithIndexArray2D(
667     const Array2D<float>& matrix,
668     const std::function<float(float, int64_t, int64_t)>& map_function) {
669   int64_t rows = matrix.height();
670   int64_t cols = matrix.width();
671   auto result = absl::make_unique<Array2D<float>>(rows, cols);
672   for (int64_t i = 0; i < rows; ++i) {
673     for (int64_t j = 0; j < cols; ++j) {
674       (*result)(i, j) = map_function(matrix(i, j), i, j);
675     }
676   }
677   return result;
678 }
679 
680 }  // namespace xla
681