1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/reference_util.h"
17
18 #include <array>
19 #include <utility>
20
21 #include "absl/container/flat_hash_set.h"
22 #include "absl/memory/memory.h"
23 #include "tensorflow/compiler/xla/client/xla_builder.h"
24 #include "tensorflow/compiler/xla/literal_util.h"
25 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
26 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
27 #include "tensorflow/compiler/xla/service/shape_inference.h"
28 #include "tensorflow/compiler/xla/window_util.h"
29 #include "tensorflow/compiler/xla/xla_data.pb.h"
30 #include "tensorflow/core/lib/math/math_util.h"
31 #include "tensorflow/core/platform/logging.h"
32
33 namespace xla {
34
Array2DF32ToF64(const Array2D<float> & input)35 /* static */ std::unique_ptr<Array2D<double>> ReferenceUtil::Array2DF32ToF64(
36 const Array2D<float>& input) {
37 auto result =
38 absl::make_unique<Array2D<double>>(input.height(), input.width());
39 for (int64_t rowno = 0; rowno < input.height(); ++rowno) {
40 for (int64_t colno = 0; colno < input.height(); ++colno) {
41 (*result)(rowno, colno) = input(rowno, colno);
42 }
43 }
44 return result;
45 }
46
ConvArray3D(const Array3D<float> & lhs,const Array3D<float> & rhs,int64_t kernel_stride,Padding padding)47 /* static */ std::unique_ptr<Array3D<float>> ReferenceUtil::ConvArray3D(
48 const Array3D<float>& lhs, const Array3D<float>& rhs, int64_t kernel_stride,
49 Padding padding) {
50 return ConvArray3DGeneralDimensionsDilated(
51 lhs, rhs, kernel_stride, padding, 1, 1,
52 XlaBuilder::CreateDefaultConvDimensionNumbers(1));
53 }
54
55 /*static*/ std::unique_ptr<Array3D<float>>
ConvArray3DGeneralDimensionsDilated(const Array3D<float> & lhs,const Array3D<float> & rhs,int64_t kernel_stride,Padding padding,int64_t lhs_dilation,int64_t rhs_dilation,const ConvolutionDimensionNumbers & dnums)56 ReferenceUtil::ConvArray3DGeneralDimensionsDilated(
57 const Array3D<float>& lhs, const Array3D<float>& rhs, int64_t kernel_stride,
58 Padding padding, int64_t lhs_dilation, int64_t rhs_dilation,
59 const ConvolutionDimensionNumbers& dnums) {
60 CHECK_EQ(dnums.input_spatial_dimensions_size(), 1);
61 CHECK_EQ(dnums.kernel_spatial_dimensions_size(), 1);
62 CHECK_EQ(dnums.output_spatial_dimensions_size(), 1);
63 // Reuse the code for Array4D-convolution by extending the 3D input into a 4D
64 // array by adding a fourth dummy dimension of size 1 without stride, padding
65 // and dilation.
66 Array4D<float> a4dlhs(lhs.n1(), lhs.n2(), lhs.n3(), 1);
67 a4dlhs.Each([&](absl::Span<const int64> indices, float* value_ptr) {
68 CHECK_EQ(indices[3], 0);
69 *value_ptr = lhs.operator()(indices[0], indices[1], indices[2]);
70 });
71 Array4D<float> a4drhs(rhs.n1(), rhs.n2(), rhs.n3(), 1);
72 a4drhs.Each([&](absl::Span<const int64> indices, float* value_ptr) {
73 CHECK_EQ(indices[3], 0);
74 *value_ptr = rhs.operator()(indices[0], indices[1], indices[2]);
75 });
76 // Add a second dummy spatial dimensions.
77 ConvolutionDimensionNumbers dnums2d = dnums;
78 dnums2d.add_input_spatial_dimensions(3);
79 dnums2d.add_kernel_spatial_dimensions(3);
80 dnums2d.add_output_spatial_dimensions(3);
81 std::unique_ptr<Array4D<float>> convr4 = ConvArray4DGeneralDimensionsDilated(
82 a4dlhs, a4drhs, {kernel_stride, 1}, padding, {lhs_dilation, 1},
83 {rhs_dilation, 1}, dnums2d);
84
85 auto convr3 = absl::make_unique<Array3D<float>>(
86 convr4->planes(), convr4->depth(), convr4->height());
87 convr4->Each([&](absl::Span<const int64> indices, float* value_ptr) {
88 CHECK_EQ(indices[3], 0);
89 convr3->operator()(indices[0], indices[1], indices[2]) = *value_ptr;
90 });
91 return convr3;
92 }
93
ConvArray4D(const Array4D<float> & lhs,const Array4D<float> & rhs,std::pair<int64,int64> kernel_stride,Padding padding)94 /* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::ConvArray4D(
95 const Array4D<float>& lhs, const Array4D<float>& rhs,
96 std::pair<int64, int64> kernel_stride, Padding padding) {
97 return ConvArray4DGeneralDimensions(
98 lhs, rhs, kernel_stride, padding,
99 XlaBuilder::CreateDefaultConvDimensionNumbers());
100 }
101
102 /* static */ std::unique_ptr<Array4D<float>>
SeparableConvArray4D(const Array4D<float> & input,const Array4D<float> & depthwise_weights,const Array4D<float> & pointwise_weights,std::pair<int64,int64> kernel_stride,Padding padding)103 ReferenceUtil::SeparableConvArray4D(const Array4D<float>& input,
104 const Array4D<float>& depthwise_weights,
105 const Array4D<float>& pointwise_weights,
106 std::pair<int64, int64> kernel_stride,
107 Padding padding) {
108 const int64_t depth_multiplier = depthwise_weights.planes();
109 CHECK_EQ(pointwise_weights.depth(), input.depth() * depth_multiplier);
110
111 // Combine the two weights by reducing the depth_multiplier, so that we can
112 // apply a single convolution on the combined weights.
113 Array4D<float> weights(pointwise_weights.planes(), input.depth(),
114 depthwise_weights.height(), depthwise_weights.width());
115 for (int64_t kx = 0; kx < depthwise_weights.width(); ++kx) {
116 for (int64_t ky = 0; ky < depthwise_weights.height(); ++ky) {
117 for (int64_t kz = 0; kz < input.depth(); ++kz) {
118 for (int64_t out = 0; out < pointwise_weights.planes(); ++out) {
119 float weight = 0.0;
120 for (int64_t depth = 0; depth < depth_multiplier; ++depth) {
121 weight +=
122 depthwise_weights(depth, kz, ky, kx) *
123 pointwise_weights(out, depth + kz * depth_multiplier, 0, 0);
124 }
125 weights(out, kz, ky, kx) = weight;
126 }
127 }
128 }
129 }
130
131 return ConvArray4D(input, weights, kernel_stride, padding);
132 }
133
WindowCount(int64_t unpadded_width,int64_t window_len,int64_t stride,Padding padding)134 /* static */ int64 ReferenceUtil::WindowCount(int64_t unpadded_width,
135 int64_t window_len,
136 int64_t stride, Padding padding) {
137 if (padding == Padding::kValid) {
138 return window_util::StridedBound(unpadded_width, window_len, stride);
139 }
140 return tensorflow::MathUtil::CeilOfRatio(unpadded_width, stride);
141 }
142
143 /* static */ std::unique_ptr<std::vector<float>>
ReduceWindow1DGeneric(absl::Span<const float> operand,float init,const std::function<float (float,float)> & reduce_func,absl::Span<const int64> window,absl::Span<const int64> stride,absl::Span<const std::pair<int64,int64>> padding)144 ReferenceUtil::ReduceWindow1DGeneric(
145 absl::Span<const float> operand, float init,
146 const std::function<float(float, float)>& reduce_func,
147 absl::Span<const int64> window, absl::Span<const int64> stride,
148 absl::Span<const std::pair<int64, int64>> padding) {
149 CHECK_EQ(window.size(), 1);
150 CHECK_EQ(stride.size(), 1);
151 CHECK_EQ(padding.size(), 1);
152
153 int64_t padded_width = padding[0].first + operand.size() + padding[0].second;
154 int64_t stride_amount = stride[0];
155 int64_t window_size = window[0];
156 int64_t result_size =
157 window_util::StridedBound(padded_width, window_size, stride_amount);
158 int64_t pad_low = padding[0].first;
159 auto result = absl::make_unique<std::vector<float>>(result_size);
160
161 // Do a full 1D reduce window.
162 for (int64_t i0 = 0; i0 < result_size; ++i0) {
163 int64_t i0_base = i0 * stride_amount - pad_low;
164 float val = init;
165 for (int64_t i0_win = 0; i0_win < window_size; ++i0_win) {
166 if (i0_base + i0_win >= 0 && i0_base + i0_win < operand.size()) {
167 val = reduce_func(val, operand[i0_base + i0_win]);
168 }
169 }
170 (*result)[i0] = val;
171 }
172 return result;
173 }
174
175 /* static */ std::unique_ptr<std::vector<float>>
ReduceWindow1DAdd(absl::Span<const float> operand,float init,absl::Span<const int64> window,absl::Span<const int64> stride,Padding padding)176 ReferenceUtil::ReduceWindow1DAdd(absl::Span<const float> operand, float init,
177 absl::Span<const int64> window,
178 absl::Span<const int64> stride,
179 Padding padding) {
180 const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
181 std::vector<int64> dim_lengths{static_cast<int64>(operand.size())};
182 return ReduceWindow1DGeneric(
183 operand, init, add_reduce, window, stride,
184 xla::MakePadding(dim_lengths, window, stride, padding));
185 }
186
ReduceWindow3DAdd(const Array3D<float> & operand,float init,absl::Span<const int64> window,absl::Span<const int64> stride,Padding padding)187 /* static */ std::unique_ptr<Array3D<float>> ReferenceUtil::ReduceWindow3DAdd(
188 const Array3D<float>& operand, float init, absl::Span<const int64> window,
189 absl::Span<const int64> stride, Padding padding) {
190 std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3()};
191 auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding);
192
193 std::vector<int64> window_counts(window.size(), 0);
194 std::vector<int64> pad_low(window.size(), 0);
195 for (int64_t i = 0; i < window.size(); ++i) {
196 window_counts[i] =
197 WindowCount(dim_lengths[i], window[i], stride[i], padding);
198 pad_low[i] = padding_both[i].first;
199 }
200 auto result = absl::make_unique<Array3D<float>>(
201 window_counts[0], window_counts[1], window_counts[2]);
202
203 for (int64_t i0 = 0; i0 < window_counts[0]; ++i0) {
204 for (int64_t i1 = 0; i1 < window_counts[1]; ++i1) {
205 for (int64_t i2 = 0; i2 < window_counts[2]; ++i2) {
206 int64_t i0_base = i0 * stride[0] - pad_low[0];
207 int64_t i1_base = i1 * stride[1] - pad_low[1];
208 int64_t i2_base = i2 * stride[2] - pad_low[2];
209
210 float val = init;
211 for (int64_t i0_win = 0; i0_win < window[0]; ++i0_win) {
212 for (int64_t i1_win = 0; i1_win < window[1]; ++i1_win) {
213 for (int64_t i2_win = 0; i2_win < window[2]; ++i2_win) {
214 if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 &&
215 i2_base + i2_win >= 0 && i0_base + i0_win < operand.n1() &&
216 i1_base + i1_win < operand.n2() &&
217 i2_base + i2_win < operand.n3()) {
218 val += operand(i0_base + i0_win, i1_base + i1_win,
219 i2_base + i2_win);
220 }
221 }
222 }
223 }
224 (*result)(i0, i1, i2) = val;
225 }
226 }
227 }
228 return result;
229 }
230
231 /* static */ std::unique_ptr<Array4D<float>>
ReduceWindow4DGeneric(const Array4D<float> & operand,float init,const std::function<float (float,float)> & reduce_func,absl::Span<const int64> window,absl::Span<const int64> stride,Padding padding)232 ReferenceUtil::ReduceWindow4DGeneric(
233 const Array4D<float>& operand, float init,
234 const std::function<float(float, float)>& reduce_func,
235 absl::Span<const int64> window, absl::Span<const int64> stride,
236 Padding padding) {
237 std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
238 operand.n4()};
239 return ReduceWindow4DGeneric(
240 operand, init, reduce_func, window, stride,
241 xla::MakePadding(dim_lengths, window, stride, padding));
242 }
243
244 /* static */ std::unique_ptr<Array4D<float>>
ReduceWindow4DGeneric(const Array4D<float> & operand,float init,const std::function<float (float,float)> & reduce_func,absl::Span<const int64> window,absl::Span<const int64> stride,absl::Span<const std::pair<int64,int64>> padding)245 ReferenceUtil::ReduceWindow4DGeneric(
246 const Array4D<float>& operand, float init,
247 const std::function<float(float, float)>& reduce_func,
248 absl::Span<const int64> window, absl::Span<const int64> stride,
249 absl::Span<const std::pair<int64, int64>> padding) {
250 std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
251 operand.n4()};
252
253 std::vector<int64> window_counts(window.size(), 0);
254 std::vector<int64> pad_low(window.size(), 0);
255 for (int64_t i = 0; i < window.size(); ++i) {
256 int64_t padded_width =
257 padding[i].first + dim_lengths[i] + padding[i].second;
258 window_counts[i] =
259 window_util::StridedBound(padded_width, window[i], stride[i]);
260 pad_low[i] = padding[i].first;
261 }
262 auto result = absl::make_unique<Array4D<float>>(
263 window_counts[0], window_counts[1], window_counts[2], window_counts[3]);
264 // Do a full 4D reduce window.
265 for (int64_t i0 = 0; i0 < window_counts[0]; ++i0) {
266 for (int64_t i1 = 0; i1 < window_counts[1]; ++i1) {
267 for (int64_t i2 = 0; i2 < window_counts[2]; ++i2) {
268 for (int64_t i3 = 0; i3 < window_counts[3]; ++i3) {
269 int64_t i0_base = i0 * stride[0] - pad_low[0];
270 int64_t i1_base = i1 * stride[1] - pad_low[1];
271 int64_t i2_base = i2 * stride[2] - pad_low[2];
272 int64_t i3_base = i3 * stride[3] - pad_low[3];
273
274 float val = init;
275 for (int64_t i0_win = 0; i0_win < window[0]; ++i0_win) {
276 for (int64_t i1_win = 0; i1_win < window[1]; ++i1_win) {
277 for (int64_t i2_win = 0; i2_win < window[2]; ++i2_win) {
278 for (int64_t i3_win = 0; i3_win < window[3]; ++i3_win) {
279 if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 &&
280 i2_base + i2_win >= 0 && i3_base + i3_win >= 0 &&
281 i0_base + i0_win < operand.n1() &&
282 i1_base + i1_win < operand.n2() &&
283 i2_base + i2_win < operand.n3() &&
284 i3_base + i3_win < operand.n4()) {
285 val = reduce_func(
286 val, operand(i0_base + i0_win, i1_base + i1_win,
287 i2_base + i2_win, i3_base + i3_win));
288 }
289 }
290 }
291 }
292 }
293 (*result)(i0, i1, i2, i3) = val;
294 }
295 }
296 }
297 }
298 return result;
299 }
300
ReduceWindow4DAdd(const Array4D<float> & operand,float init,absl::Span<const int64> window,absl::Span<const int64> stride,Padding padding)301 /* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::ReduceWindow4DAdd(
302 const Array4D<float>& operand, float init, absl::Span<const int64> window,
303 absl::Span<const int64> stride, Padding padding) {
304 const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
305 return ReduceWindow4DGeneric(operand, init, add_reduce, window, stride,
306 padding);
307 }
308
BatchNorm4D(const Array4D<float> & input,const Array4D<float> & mean,const Array4D<float> & var,const Array4D<float> & scale,const Array4D<float> & offset,float epsilon)309 /* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::BatchNorm4D(
310 const Array4D<float>& input, const Array4D<float>& mean,
311 const Array4D<float>& var, const Array4D<float>& scale,
312 const Array4D<float>& offset, float epsilon) {
313 auto normalized =
314 *MapArray4D(input, mean, [](float a, float b) { return a - b; });
315 normalized = *MapArray4D(normalized, var, [&](float a, float b) {
316 return a / std::sqrt(b + epsilon);
317 });
318 normalized =
319 *MapArray4D(normalized, scale, [](float a, float b) { return a * b; });
320 return MapArray4D(normalized, offset, [](float a, float b) { return a + b; });
321 }
322
323 /* static */ std::unique_ptr<Array4D<float>>
SelectAndScatter4DGePlus(const Array4D<float> & operand,const Array4D<float> & source,float init,absl::Span<const int64> window,absl::Span<const int64> stride,bool same_padding)324 ReferenceUtil::SelectAndScatter4DGePlus(const Array4D<float>& operand,
325 const Array4D<float>& source,
326 float init,
327 absl::Span<const int64> window,
328 absl::Span<const int64> stride,
329 bool same_padding) {
330 Padding padding = same_padding ? Padding::kSame : Padding::kValid;
331 auto result = absl::make_unique<Array4D<float>>(operand.n1(), operand.n2(),
332 operand.n3(), operand.n4());
333 std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
334 operand.n4()};
335 auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding);
336 // Fill the output, with the initial value.
337 result->Fill(init);
338
339 std::vector<int64> window_counts(window.size(), 0);
340 std::vector<int64> pad_low(window.size(), 0);
341 for (int64_t i = 0; i < window.size(); ++i) {
342 window_counts[i] =
343 WindowCount(dim_lengths[i], window[i], stride[i], padding);
344 pad_low[i] = padding_both[i].first;
345 }
346 CHECK_EQ(window_counts[0], source.n1());
347 CHECK_EQ(window_counts[1], source.n2());
348 CHECK_EQ(window_counts[2], source.n3());
349 CHECK_EQ(window_counts[3], source.n4());
350
351 // Do a full 4D select and Scatter.
352 for (int64_t i0 = 0; i0 < window_counts[0]; ++i0) {
353 for (int64_t i1 = 0; i1 < window_counts[1]; ++i1) {
354 for (int64_t i2 = 0; i2 < window_counts[2]; ++i2) {
355 for (int64_t i3 = 0; i3 < window_counts[3]; ++i3) {
356 // Now we are inside a window and need to find the max and the argmax.
357 int64_t i0_base = i0 * stride[0] - pad_low[0];
358 int64_t i1_base = i1 * stride[1] - pad_low[1];
359 int64_t i2_base = i2 * stride[2] - pad_low[2];
360 int64_t i3_base = i3 * stride[3] - pad_low[3];
361 int64_t scatter_0 = (i0_base >= 0) ? i0_base : 0;
362 int64_t scatter_1 = (i1_base >= 0) ? i1_base : 0;
363 int64_t scatter_2 = (i2_base >= 0) ? i2_base : 0;
364 int64_t scatter_3 = (i3_base >= 0) ? i3_base : 0;
365 float val = operand(scatter_0, scatter_1, scatter_2, scatter_3);
366 for (int64_t i0_win = 0; i0_win < window[0]; ++i0_win) {
367 for (int64_t i1_win = 0; i1_win < window[1]; ++i1_win) {
368 for (int64_t i2_win = 0; i2_win < window[2]; ++i2_win) {
369 for (int64_t i3_win = 0; i3_win < window[3]; ++i3_win) {
370 if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 &&
371 i2_base + i2_win >= 0 && i3_base + i3_win >= 0 &&
372 i0_base + i0_win < operand.n1() &&
373 i1_base + i1_win < operand.n2() &&
374 i2_base + i2_win < operand.n3() &&
375 i3_base + i3_win < operand.n4()) {
376 float tmp = operand(i0_base + i0_win, i1_base + i1_win,
377 i2_base + i2_win, i3_base + i3_win);
378 if (tmp > val) {
379 val = tmp;
380 scatter_0 = i0_base + i0_win;
381 scatter_1 = i1_base + i1_win;
382 scatter_2 = i2_base + i2_win;
383 scatter_3 = i3_base + i3_win;
384 }
385 }
386 }
387 }
388 }
389 }
390 (*result)(scatter_0, scatter_1, scatter_2, scatter_3) +=
391 source(i0, i1, i2, i3);
392 }
393 }
394 }
395 }
396 return result;
397 }
398
399 /* static */ std::unique_ptr<Array4D<float>>
ConvArray4DGeneralDimensions(const Array4D<float> & lhs,const Array4D<float> & rhs,std::pair<int64,int64> kernel_stride,Padding padding,ConvolutionDimensionNumbers dimension_numbers)400 ReferenceUtil::ConvArray4DGeneralDimensions(
401 const Array4D<float>& lhs, const Array4D<float>& rhs,
402 std::pair<int64, int64> kernel_stride, Padding padding,
403 ConvolutionDimensionNumbers dimension_numbers) {
404 return ConvArray4DGeneralDimensionsDilated(lhs, rhs, kernel_stride, padding,
405 {1, 1}, {1, 1},
406 std::move(dimension_numbers));
407 }
408
409 /* static */ std::unique_ptr<Array4D<float>>
ConvArray4DGeneralDimensionsDilated(const Array4D<float> & lhs,const Array4D<float> & rhs,std::pair<int64,int64> kernel_stride,Padding padding,std::pair<int64,int64> lhs_dilation,std::pair<int64,int64> rhs_dilation,ConvolutionDimensionNumbers dnums)410 ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
411 const Array4D<float>& lhs, const Array4D<float>& rhs,
412 std::pair<int64, int64> kernel_stride, Padding padding,
413 std::pair<int64, int64> lhs_dilation, std::pair<int64, int64> rhs_dilation,
414 ConvolutionDimensionNumbers dnums) {
415 HloComputation::Builder b("ConvArray4DGeneralDimensionDilated");
416 auto lhs_literal = LiteralUtil::CreateR4FromArray4D<float>(lhs);
417 auto rhs_literal = LiteralUtil::CreateR4FromArray4D<float>(rhs);
418
419 std::array<int64, 2> ordered_kernel_strides;
420 std::array<int64, 2> ordered_input_dimensions;
421 std::array<int64, 2> ordered_kernel_dimensions;
422 if (dnums.kernel_spatial_dimensions(0) > dnums.kernel_spatial_dimensions(1)) {
423 ordered_kernel_strides[0] = kernel_stride.second;
424 ordered_kernel_strides[1] = kernel_stride.first;
425 } else {
426 ordered_kernel_strides[0] = kernel_stride.first;
427 ordered_kernel_strides[1] = kernel_stride.second;
428 }
429
430 ordered_input_dimensions[0] =
431 lhs_literal.shape().dimensions(dnums.input_spatial_dimensions(0));
432 ordered_input_dimensions[1] =
433 lhs_literal.shape().dimensions(dnums.input_spatial_dimensions(1));
434 ordered_kernel_dimensions[0] =
435 rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(0));
436 ordered_kernel_dimensions[1] =
437 rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(1));
438
439 std::vector<std::pair<int64, int64>> paddings =
440 MakePadding(ordered_input_dimensions, ordered_kernel_dimensions,
441 ordered_kernel_strides, padding);
442 CHECK_EQ(paddings.size(), 2);
443
444 Window window;
445
446 WindowDimension dim;
447 dim.set_size(
448 rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(0)));
449 dim.set_stride(kernel_stride.first);
450 dim.set_padding_low(paddings[0].first);
451 dim.set_padding_high(paddings[0].second);
452 dim.set_window_dilation(rhs_dilation.first);
453 dim.set_base_dilation(lhs_dilation.first);
454 *window.add_dimensions() = dim;
455
456 WindowDimension dim2;
457 dim2.set_size(
458 rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(1)));
459 dim2.set_stride(kernel_stride.second);
460 dim2.set_padding_low(paddings[1].first);
461 dim2.set_padding_high(paddings[1].second);
462 dim2.set_window_dilation(rhs_dilation.second);
463 dim2.set_base_dilation(lhs_dilation.second);
464 *window.add_dimensions() = dim2;
465
466 const Shape& shape =
467 ShapeInference::InferConvolveShape(
468 lhs_literal.shape(), rhs_literal.shape(),
469 /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums,
470 /*preferred_element_type=*/absl::nullopt)
471 .ConsumeValueOrDie();
472
473 HloInstruction* lhs_instruction =
474 b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
475 HloInstruction* rhs_instruction =
476 b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
477
478 PrecisionConfig precision_config;
479 precision_config.mutable_operand_precision()->Resize(
480 /*new_size=*/2, PrecisionConfig::DEFAULT);
481 b.AddInstruction(HloInstruction::CreateConvolve(
482 shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
483 /*batch_group_count=*/1, window, dnums, precision_config));
484 HloModuleConfig config;
485 HloModule module("ReferenceUtil", config);
486 auto computation = module.AddEntryComputation(b.Build());
487
488 HloEvaluator evaluator;
489 Literal result_literal =
490 evaluator.Evaluate(*computation, {}).ConsumeValueOrDie();
491
492 CHECK_EQ(result_literal.shape().rank(), 4);
493 auto result =
494 absl::make_unique<Array4D<float>>(result_literal.shape().dimensions(0),
495 result_literal.shape().dimensions(1),
496 result_literal.shape().dimensions(2),
497 result_literal.shape().dimensions(3));
498
499 result->Each([&](absl::Span<const int64> indices, float* value) {
500 *value = result_literal.Get<float>(indices);
501 });
502
503 return result;
504 }
505
506 /* static */ std::unique_ptr<std::vector<float>>
ReduceToColArray2D(const Array2D<float> & matrix,float init,const std::function<float (float,float)> & reduce_function)507 ReferenceUtil::ReduceToColArray2D(
508 const Array2D<float>& matrix, float init,
509 const std::function<float(float, float)>& reduce_function) {
510 int64_t rows = matrix.height();
511 int64_t cols = matrix.width();
512 auto result = absl::make_unique<std::vector<float>>();
513 for (int64_t i = 0; i < rows; ++i) {
514 float acc = init;
515 for (int64_t j = 0; j < cols; ++j) {
516 acc = reduce_function(acc, matrix(i, j));
517 }
518 result->push_back(acc);
519 }
520 return result;
521 }
522
523 /* static */ std::unique_ptr<std::vector<float>>
ReduceToRowArray2D(const Array2D<float> & matrix,float init,const std::function<float (float,float)> & reduce_function)524 ReferenceUtil::ReduceToRowArray2D(
525 const Array2D<float>& matrix, float init,
526 const std::function<float(float, float)>& reduce_function) {
527 int64_t rows = matrix.height();
528 int64_t cols = matrix.width();
529 auto result = absl::make_unique<std::vector<float>>();
530 for (int64_t i = 0; i < cols; ++i) {
531 float acc = init;
532 for (int64_t j = 0; j < rows; ++j) {
533 acc = reduce_function(acc, matrix(j, i));
534 }
535 result->push_back(acc);
536 }
537 return result;
538 }
539
Reduce4DTo1D(const Array4D<float> & array,float init,absl::Span<const int64> dims,const std::function<float (float,float)> & reduce_function)540 /*static*/ std::vector<float> ReferenceUtil::Reduce4DTo1D(
541 const Array4D<float>& array, float init, absl::Span<const int64> dims,
542 const std::function<float(float, float)>& reduce_function) {
543 std::vector<float> result;
544 CHECK_EQ(dims.size(), 3);
545 const absl::flat_hash_set<int64> dim_set(dims.begin(), dims.end());
546 CHECK_EQ(dim_set.size(), 3);
547 for (int64_t a0 = 0; a0 == 0 || (!dim_set.contains(0) && a0 < array.n1());
548 ++a0) {
549 for (int64_t a1 = 0; a1 == 0 || (!dim_set.contains(1) && a1 < array.n2());
550 ++a1) {
551 for (int64_t a2 = 0; a2 == 0 || (!dim_set.contains(2) && a2 < array.n3());
552 ++a2) {
553 for (int64_t a3 = 0;
554 a3 == 0 || (!dim_set.contains(3) && a3 < array.n4()); ++a3) {
555 float accumulator = init;
556 for (int64_t i0 = 0;
557 i0 == 0 || (dim_set.contains(0) && i0 < array.n1()); ++i0) {
558 for (int64_t i1 = 0;
559 i1 == 0 || (dim_set.contains(1) && i1 < array.n2()); ++i1) {
560 for (int64_t i2 = 0;
561 i2 == 0 || (dim_set.contains(2) && i2 < array.n3()); ++i2) {
562 for (int64_t i3 = 0;
563 i3 == 0 || (dim_set.contains(3) && i3 < array.n4());
564 ++i3) {
565 // Handle zero-sized arrays.
566 if (array.n1() > 0 && array.n2() > 0 && array.n3() > 0 &&
567 array.n4() > 0) {
568 accumulator = reduce_function(
569 accumulator, array(a0 + i0, a1 + i1, a2 + i2, a3 + i3));
570 }
571 }
572 }
573 }
574 }
575 result.push_back(accumulator);
576 }
577 }
578 }
579 }
580 return result;
581 }
582
Broadcast1DTo4D(const std::vector<float> & array,const std::vector<int64> & bounds,int64_t broadcast_from_dim)583 /* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::Broadcast1DTo4D(
584 const std::vector<float>& array, const std::vector<int64>& bounds,
585 int64_t broadcast_from_dim) {
586 auto result = absl::make_unique<Array4D<float>>(bounds[0], bounds[1],
587 bounds[2], bounds[3]);
588 for (int64_t i = 0; i < result->n1(); ++i) {
589 for (int64_t j = 0; j < result->n2(); ++j) {
590 for (int64_t k = 0; k < result->n3(); ++k) {
591 for (int64_t l = 0; l < result->n4(); ++l) {
592 switch (broadcast_from_dim) {
593 case 0:
594 (*result)(i, j, k, l) = array[i];
595 break;
596 case 1:
597 (*result)(i, j, k, l) = array[j];
598 break;
599 case 2:
600 (*result)(i, j, k, l) = array[k];
601 break;
602 case 3:
603 (*result)(i, j, k, l) = array[l];
604 break;
605 default:
606 break;
607 }
608 }
609 }
610 }
611 }
612 return result;
613 }
614
Reduce3DTo2D(const Array3D<float> & array,float init,absl::Span<const int64> dims,const std::function<float (float,float)> & reduce_function)615 /* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::Reduce3DTo2D(
616 const Array3D<float>& array, float init, absl::Span<const int64> dims,
617 const std::function<float(float, float)>& reduce_function) {
618 CHECK_EQ(dims.size(), 1);
619 int64_t rows = dims[0] == 0 ? array.n2() : array.n1();
620 int64_t cols = dims[0] == 2 ? array.n2() : array.n3();
621 auto result = absl::make_unique<Array2D<float>>(rows, cols);
622 result->Fill(init);
623 for (int i0 = 0; i0 < array.n1(); ++i0) {
624 for (int i1 = 0; i1 < array.n2(); ++i1) {
625 for (int i2 = 0; i2 < array.n3(); ++i2) {
626 int64_t row = dims[0] == 0 ? i1 : i0;
627 int64_t col = dims[0] == 2 ? i1 : i2;
628 (*result)(row, col) =
629 reduce_function((*result)(row, col), array(i0, i1, i2));
630 }
631 }
632 }
633 return result;
634 }
635
MapArray2D(const Array2D<float> & matrix,const std::function<float (float)> & map_function)636 /* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapArray2D(
637 const Array2D<float>& matrix,
638 const std::function<float(float)>& map_function) {
639 int64_t rows = matrix.height();
640 int64_t cols = matrix.width();
641 auto result = absl::make_unique<Array2D<float>>(rows, cols);
642 for (int64_t i = 0; i < rows; ++i) {
643 for (int64_t j = 0; j < cols; ++j) {
644 (*result)(i, j) = map_function(matrix(i, j));
645 }
646 }
647 return result;
648 }
649
MapArray2D(const Array2D<float> & lhs,const Array2D<float> & rhs,const std::function<float (float,float)> & map_function)650 /* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapArray2D(
651 const Array2D<float>& lhs, const Array2D<float>& rhs,
652 const std::function<float(float, float)>& map_function) {
653 CHECK_EQ(lhs.height(), rhs.height());
654 CHECK_EQ(lhs.width(), rhs.width());
655 int64_t rows = lhs.height();
656 int64_t cols = rhs.width();
657 auto result = absl::make_unique<Array2D<float>>(rows, cols);
658 for (int64_t i = 0; i < rows; ++i) {
659 for (int64_t j = 0; j < cols; ++j) {
660 (*result)(i, j) = map_function(lhs(i, j), rhs(i, j));
661 }
662 }
663 return result;
664 }
665
MapWithIndexArray2D(const Array2D<float> & matrix,const std::function<float (float,int64_t,int64_t)> & map_function)666 /* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapWithIndexArray2D(
667 const Array2D<float>& matrix,
668 const std::function<float(float, int64_t, int64_t)>& map_function) {
669 int64_t rows = matrix.height();
670 int64_t cols = matrix.width();
671 auto result = absl::make_unique<Array2D<float>>(rows, cols);
672 for (int64_t i = 0; i < rows; ++i) {
673 for (int64_t j = 0; j < cols; ++j) {
674 (*result)(i, j) = map_function(matrix(i, j), i, j);
675 }
676 }
677 return result;
678 }
679
680 } // namespace xla
681