1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/reference_util.h"
17
18 #include <array>
19 #include <utility>
20
21 #include "absl/container/flat_hash_set.h"
22 #include "absl/memory/memory.h"
23 #include "tensorflow/compiler/xla/client/xla_builder.h"
24 #include "tensorflow/compiler/xla/literal_util.h"
25 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
26 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
27 #include "tensorflow/compiler/xla/service/shape_inference.h"
28 #include "tensorflow/compiler/xla/window_util.h"
29 #include "tensorflow/compiler/xla/xla_data.pb.h"
30 #include "tensorflow/core/lib/math/math_util.h"
31 #include "tensorflow/core/platform/logging.h"
32
33 namespace xla {
34
Array2DF32ToF64(const Array2D<float> & input)35 /* static */ std::unique_ptr<Array2D<double>> ReferenceUtil::Array2DF32ToF64(
36 const Array2D<float>& input) {
37 auto result =
38 absl::make_unique<Array2D<double>>(input.height(), input.width());
39 for (int64 rowno = 0; rowno < input.height(); ++rowno) {
40 for (int64 colno = 0; colno < input.height(); ++colno) {
41 (*result)(rowno, colno) = input(rowno, colno);
42 }
43 }
44 return result;
45 }
46
ConvArray3D(const Array3D<float> & lhs,const Array3D<float> & rhs,int64 kernel_stride,Padding padding)47 /* static */ std::unique_ptr<Array3D<float>> ReferenceUtil::ConvArray3D(
48 const Array3D<float>& lhs, const Array3D<float>& rhs, int64 kernel_stride,
49 Padding padding) {
50 return ConvArray3DGeneralDimensionsDilated(
51 lhs, rhs, kernel_stride, padding, 1, 1,
52 XlaBuilder::CreateDefaultConvDimensionNumbers(1));
53 }
54
// Convolution of two 3D arrays (batch, feature, one spatial dim) with
// stride, padding and base/window dilation. Implemented by lifting both
// operands into 4D with a trailing dummy spatial dimension of size 1 and
// reusing the 4D convolution path, then projecting the result back to 3D.
/*static*/ std::unique_ptr<Array3D<float>>
ReferenceUtil::ConvArray3DGeneralDimensionsDilated(
    const Array3D<float>& lhs, const Array3D<float>& rhs, int64 kernel_stride,
    Padding padding, int64 lhs_dilation, int64 rhs_dilation,
    const ConvolutionDimensionNumbers& dnums) {
  // This path only supports exactly one spatial dimension on each operand.
  CHECK_EQ(dnums.input_spatial_dimensions_size(), 1);
  CHECK_EQ(dnums.kernel_spatial_dimensions_size(), 1);
  CHECK_EQ(dnums.output_spatial_dimensions_size(), 1);
  // Reuse the code for Array4D-convolution by extending the 3D input into a 4D
  // array by adding a fourth dummy dimension of size 1 without stride, padding
  // and dilation.
  Array4D<float> a4dlhs(lhs.n1(), lhs.n2(), lhs.n3(), 1);
  a4dlhs.Each([&](absl::Span<const int64> indices, float* value_ptr) {
    CHECK_EQ(indices[3], 0);  // dummy dimension must be a singleton
    *value_ptr = lhs.operator()(indices[0], indices[1], indices[2]);
  });
  Array4D<float> a4drhs(rhs.n1(), rhs.n2(), rhs.n3(), 1);
  a4drhs.Each([&](absl::Span<const int64> indices, float* value_ptr) {
    CHECK_EQ(indices[3], 0);  // dummy dimension must be a singleton
    *value_ptr = rhs.operator()(indices[0], indices[1], indices[2]);
  });
  // Add a second dummy spatial dimension (logical dimension index 3) so the
  // dimension numbers describe a rank-4 convolution.
  ConvolutionDimensionNumbers dnums2d = dnums;
  dnums2d.add_input_spatial_dimensions(3);
  dnums2d.add_kernel_spatial_dimensions(3);
  dnums2d.add_output_spatial_dimensions(3);
  // The dummy dimension gets stride 1 and no dilation so it is a no-op.
  std::unique_ptr<Array4D<float>> convr4 = ConvArray4DGeneralDimensionsDilated(
      a4dlhs, a4drhs, {kernel_stride, 1}, padding, {lhs_dilation, 1},
      {rhs_dilation, 1}, dnums2d);

  // Project the 4D result back to 3D by dropping the singleton last dimension.
  auto convr3 = absl::make_unique<Array3D<float>>(
      convr4->planes(), convr4->depth(), convr4->height());
  convr4->Each([&](absl::Span<const int64> indices, float* value_ptr) {
    CHECK_EQ(indices[3], 0);
    convr3->operator()(indices[0], indices[1], indices[2]) = *value_ptr;
  });
  return convr3;
}
93
ConvArray4D(const Array4D<float> & lhs,const Array4D<float> & rhs,std::pair<int64,int64> kernel_stride,Padding padding)94 /* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::ConvArray4D(
95 const Array4D<float>& lhs, const Array4D<float>& rhs,
96 std::pair<int64, int64> kernel_stride, Padding padding) {
97 return ConvArray4DGeneralDimensions(
98 lhs, rhs, kernel_stride, padding,
99 XlaBuilder::CreateDefaultConvDimensionNumbers());
100 }
101
// Reference separable convolution: a depthwise convolution (each input
// channel convolved with `depth_multiplier` filters) followed by a 1x1
// pointwise convolution. Implemented by pre-combining the depthwise and
// pointwise weights into a single 4D weight array and running one ordinary
// convolution with it.
/* static */ std::unique_ptr<Array4D<float>>
ReferenceUtil::SeparableConvArray4D(const Array4D<float>& input,
                                    const Array4D<float>& depthwise_weights,
                                    const Array4D<float>& pointwise_weights,
                                    std::pair<int64, int64> kernel_stride,
                                    Padding padding) {
  const int64 depth_multiplier = depthwise_weights.planes();
  // Pointwise weights must consume all depthwise output channels.
  CHECK_EQ(pointwise_weights.depth(), input.depth() * depth_multiplier);

  // Combine the two weights by reducing the depth_multiplier, so that we can
  // apply a single convolution on the combined weights.
  Array4D<float> weights(pointwise_weights.planes(), input.depth(),
                         depthwise_weights.height(), depthwise_weights.width());
  for (int64 kx = 0; kx < depthwise_weights.width(); ++kx) {
    for (int64 ky = 0; ky < depthwise_weights.height(); ++ky) {
      for (int64 kz = 0; kz < input.depth(); ++kz) {
        for (int64 out = 0; out < pointwise_weights.planes(); ++out) {
          float weight = 0.0;
          // Sum over the depth multiplier: channel kz contributes the
          // depthwise outputs kz*depth_multiplier .. +depth_multiplier-1.
          for (int64 depth = 0; depth < depth_multiplier; ++depth) {
            weight +=
                depthwise_weights(depth, kz, ky, kx) *
                pointwise_weights(out, depth + kz * depth_multiplier, 0, 0);
          }
          weights(out, kz, ky, kx) = weight;
        }
      }
    }
  }

  return ConvArray4D(input, weights, kernel_stride, padding);
}
133
WindowCount(int64 unpadded_width,int64 window_len,int64 stride,Padding padding)134 /* static */ int64 ReferenceUtil::WindowCount(int64 unpadded_width,
135 int64 window_len, int64 stride,
136 Padding padding) {
137 if (padding == Padding::kValid) {
138 return window_util::StridedBound(unpadded_width, window_len, stride);
139 }
140 return tensorflow::MathUtil::CeilOfRatio(unpadded_width, stride);
141 }
142
143 /* static */ std::unique_ptr<std::vector<float>>
ReduceWindow1DGeneric(absl::Span<const float> operand,float init,const std::function<float (float,float)> & reduce_func,absl::Span<const int64> window,absl::Span<const int64> stride,absl::Span<const std::pair<int64,int64>> padding)144 ReferenceUtil::ReduceWindow1DGeneric(
145 absl::Span<const float> operand, float init,
146 const std::function<float(float, float)>& reduce_func,
147 absl::Span<const int64> window, absl::Span<const int64> stride,
148 absl::Span<const std::pair<int64, int64>> padding) {
149 CHECK_EQ(window.size(), 1);
150 CHECK_EQ(stride.size(), 1);
151 CHECK_EQ(padding.size(), 1);
152
153 int64 padded_width = padding[0].first + operand.size() + padding[0].second;
154 int64 stride_amount = stride[0];
155 int64 window_size = window[0];
156 int64 result_size =
157 window_util::StridedBound(padded_width, window_size, stride_amount);
158 int64 pad_low = padding[0].first;
159 auto result = absl::make_unique<std::vector<float>>(result_size);
160
161 // Do a full 1D reduce window.
162 for (int64 i0 = 0; i0 < result_size; ++i0) {
163 int64 i0_base = i0 * stride_amount - pad_low;
164 float val = init;
165 for (int64 i0_win = 0; i0_win < window_size; ++i0_win) {
166 if (i0_base + i0_win >= 0 && i0_base + i0_win < operand.size()) {
167 val = reduce_func(val, operand[i0_base + i0_win]);
168 }
169 }
170 (*result)[i0] = val;
171 }
172 return result;
173 }
174
175 /* static */ std::unique_ptr<std::vector<float>>
ReduceWindow1DAdd(absl::Span<const float> operand,float init,absl::Span<const int64> window,absl::Span<const int64> stride,Padding padding)176 ReferenceUtil::ReduceWindow1DAdd(absl::Span<const float> operand, float init,
177 absl::Span<const int64> window,
178 absl::Span<const int64> stride,
179 Padding padding) {
180 const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
181 std::vector<int64> dim_lengths{static_cast<int64>(operand.size())};
182 return ReduceWindow1DGeneric(
183 operand, init, add_reduce, window, stride,
184 xla::MakePadding(dim_lengths, window, stride, padding));
185 }
186
// 3D reduce-window with addition: for every window position, sums the
// in-bounds operand elements under the window (padding contributes nothing
// beyond `init`).
/* static */ std::unique_ptr<Array3D<float>> ReferenceUtil::ReduceWindow3DAdd(
    const Array3D<float>& operand, float init, absl::Span<const int64> window,
    absl::Span<const int64> stride, Padding padding) {
  std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3()};
  // Resolve the symbolic padding into explicit (low, high) pairs per dim.
  auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding);

  // Per-dimension output extent and low padding.
  std::vector<int64> window_counts(window.size(), 0);
  std::vector<int64> pad_low(window.size(), 0);
  for (int64 i = 0; i < window.size(); ++i) {
    window_counts[i] =
        WindowCount(dim_lengths[i], window[i], stride[i], padding);
    pad_low[i] = padding_both[i].first;
  }
  auto result = absl::make_unique<Array3D<float>>(
      window_counts[0], window_counts[1], window_counts[2]);

  // Outer loops: one iteration per output element (window position).
  for (int64 i0 = 0; i0 < window_counts[0]; ++i0) {
    for (int64 i1 = 0; i1 < window_counts[1]; ++i1) {
      for (int64 i2 = 0; i2 < window_counts[2]; ++i2) {
        // First operand index covered by this window in each dimension;
        // may be negative when the window starts inside the low padding.
        int64 i0_base = i0 * stride[0] - pad_low[0];
        int64 i1_base = i1 * stride[1] - pad_low[1];
        int64 i2_base = i2 * stride[2] - pad_low[2];

        float val = init;
        // Inner loops: walk the window, accumulating only in-bounds elements.
        for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) {
          for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) {
            for (int64 i2_win = 0; i2_win < window[2]; ++i2_win) {
              if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 &&
                  i2_base + i2_win >= 0 && i0_base + i0_win < operand.n1() &&
                  i1_base + i1_win < operand.n2() &&
                  i2_base + i2_win < operand.n3()) {
                val += operand(i0_base + i0_win, i1_base + i1_win,
                               i2_base + i2_win);
              }
            }
          }
        }
        (*result)(i0, i1, i2) = val;
      }
    }
  }
  return result;
}
230
231 /* static */ std::unique_ptr<Array4D<float>>
ReduceWindow4DGeneric(const Array4D<float> & operand,float init,const std::function<float (float,float)> & reduce_func,absl::Span<const int64> window,absl::Span<const int64> stride,Padding padding)232 ReferenceUtil::ReduceWindow4DGeneric(
233 const Array4D<float>& operand, float init,
234 const std::function<float(float, float)>& reduce_func,
235 absl::Span<const int64> window, absl::Span<const int64> stride,
236 Padding padding) {
237 std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
238 operand.n4()};
239 return ReduceWindow4DGeneric(
240 operand, init, reduce_func, window, stride,
241 xla::MakePadding(dim_lengths, window, stride, padding));
242 }
243
// Generic 4D reduce-window with explicit per-dimension (low, high) padding:
// for every window position, folds the in-bounds operand elements into
// `init` via `reduce_func`. Padding positions contribute nothing.
/* static */ std::unique_ptr<Array4D<float>>
ReferenceUtil::ReduceWindow4DGeneric(
    const Array4D<float>& operand, float init,
    const std::function<float(float, float)>& reduce_func,
    absl::Span<const int64> window, absl::Span<const int64> stride,
    absl::Span<const std::pair<int64, int64>> padding) {
  std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
                                 operand.n4()};

  // Per-dimension output extent (number of window positions over the padded
  // width) and low padding.
  std::vector<int64> window_counts(window.size(), 0);
  std::vector<int64> pad_low(window.size(), 0);
  for (int64 i = 0; i < window.size(); ++i) {
    int64 padded_width = padding[i].first + dim_lengths[i] + padding[i].second;
    window_counts[i] =
        window_util::StridedBound(padded_width, window[i], stride[i]);
    pad_low[i] = padding[i].first;
  }
  auto result = absl::make_unique<Array4D<float>>(
      window_counts[0], window_counts[1], window_counts[2], window_counts[3]);
  // Do a full 4D reduce window.
  for (int64 i0 = 0; i0 < window_counts[0]; ++i0) {
    for (int64 i1 = 0; i1 < window_counts[1]; ++i1) {
      for (int64 i2 = 0; i2 < window_counts[2]; ++i2) {
        for (int64 i3 = 0; i3 < window_counts[3]; ++i3) {
          // First operand index covered by this window in each dimension;
          // negative values mean the window starts inside the low padding.
          int64 i0_base = i0 * stride[0] - pad_low[0];
          int64 i1_base = i1 * stride[1] - pad_low[1];
          int64 i2_base = i2 * stride[2] - pad_low[2];
          int64 i3_base = i3 * stride[3] - pad_low[3];

          float val = init;
          // Walk the window; only in-bounds elements are reduced in.
          for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) {
            for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) {
              for (int64 i2_win = 0; i2_win < window[2]; ++i2_win) {
                for (int64 i3_win = 0; i3_win < window[3]; ++i3_win) {
                  if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 &&
                      i2_base + i2_win >= 0 && i3_base + i3_win >= 0 &&
                      i0_base + i0_win < operand.n1() &&
                      i1_base + i1_win < operand.n2() &&
                      i2_base + i2_win < operand.n3() &&
                      i3_base + i3_win < operand.n4()) {
                    val = reduce_func(
                        val, operand(i0_base + i0_win, i1_base + i1_win,
                                     i2_base + i2_win, i3_base + i3_win));
                  }
                }
              }
            }
          }
          (*result)(i0, i1, i2, i3) = val;
        }
      }
    }
  }
  return result;
}
299
ReduceWindow4DAdd(const Array4D<float> & operand,float init,absl::Span<const int64> window,absl::Span<const int64> stride,Padding padding)300 /* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::ReduceWindow4DAdd(
301 const Array4D<float>& operand, float init, absl::Span<const int64> window,
302 absl::Span<const int64> stride, Padding padding) {
303 const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
304 return ReduceWindow4DGeneric(operand, init, add_reduce, window, stride,
305 padding);
306 }
307
BatchNorm4D(const Array4D<float> & input,const Array4D<float> & mean,const Array4D<float> & var,const Array4D<float> & scale,const Array4D<float> & offset,float epsilon)308 /* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::BatchNorm4D(
309 const Array4D<float>& input, const Array4D<float>& mean,
310 const Array4D<float>& var, const Array4D<float>& scale,
311 const Array4D<float>& offset, float epsilon) {
312 auto normalized =
313 *MapArray4D(input, mean, [](float a, float b) { return a - b; });
314 normalized = *MapArray4D(normalized, var, [&](float a, float b) {
315 return a / std::sqrt(b + epsilon);
316 });
317 normalized =
318 *MapArray4D(normalized, scale, [](float a, float b) { return a * b; });
319 return MapArray4D(normalized, offset, [](float a, float b) { return a + b; });
320 }
321
// Reference select-and-scatter with a >= (max) select and a + scatter:
// for each window over `operand`, finds the position of the maximum element
// in that window and adds the corresponding `source` element to the result
// at that position. The result starts filled with `init`.
/* static */ std::unique_ptr<Array4D<float>>
ReferenceUtil::SelectAndScatter4DGePlus(const Array4D<float>& operand,
                                        const Array4D<float>& source,
                                        float init,
                                        absl::Span<const int64> window,
                                        absl::Span<const int64> stride,
                                        bool same_padding) {
  Padding padding = same_padding ? Padding::kSame : Padding::kValid;
  // The scatter target has the same shape as the operand.
  auto result = absl::make_unique<Array4D<float>>(operand.n1(), operand.n2(),
                                                  operand.n3(), operand.n4());
  std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
                                 operand.n4()};
  auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding);
  // Fill the output, with the initial value.
  result->Fill(init);

  // Per-dimension window-position counts and low padding.
  std::vector<int64> window_counts(window.size(), 0);
  std::vector<int64> pad_low(window.size(), 0);
  for (int64 i = 0; i < window.size(); ++i) {
    window_counts[i] =
        WindowCount(dim_lengths[i], window[i], stride[i], padding);
    pad_low[i] = padding_both[i].first;
  }
  // `source` must have one element per window position.
  CHECK_EQ(window_counts[0], source.n1());
  CHECK_EQ(window_counts[1], source.n2());
  CHECK_EQ(window_counts[2], source.n3());
  CHECK_EQ(window_counts[3], source.n4());

  // Do a full 4D select and Scatter.
  for (int64 i0 = 0; i0 < window_counts[0]; ++i0) {
    for (int64 i1 = 0; i1 < window_counts[1]; ++i1) {
      for (int64 i2 = 0; i2 < window_counts[2]; ++i2) {
        for (int64 i3 = 0; i3 < window_counts[3]; ++i3) {
          // Now we are inside a window and need to find the max and the argmax.
          int64 i0_base = i0 * stride[0] - pad_low[0];
          int64 i1_base = i1 * stride[1] - pad_low[1];
          int64 i2_base = i2 * stride[2] - pad_low[2];
          int64 i3_base = i3 * stride[3] - pad_low[3];
          // Seed the argmax with the window's first in-bounds corner (the
          // base clamped to 0 when it falls in the low padding).
          int64 scatter_0 = (i0_base >= 0) ? i0_base : 0;
          int64 scatter_1 = (i1_base >= 0) ? i1_base : 0;
          int64 scatter_2 = (i2_base >= 0) ? i2_base : 0;
          int64 scatter_3 = (i3_base >= 0) ? i3_base : 0;
          float val = operand(scatter_0, scatter_1, scatter_2, scatter_3);
          // Scan the window for a strictly greater element, tracking argmax.
          for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) {
            for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) {
              for (int64 i2_win = 0; i2_win < window[2]; ++i2_win) {
                for (int64 i3_win = 0; i3_win < window[3]; ++i3_win) {
                  if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 &&
                      i2_base + i2_win >= 0 && i3_base + i3_win >= 0 &&
                      i0_base + i0_win < operand.n1() &&
                      i1_base + i1_win < operand.n2() &&
                      i2_base + i2_win < operand.n3() &&
                      i3_base + i3_win < operand.n4()) {
                    float tmp = operand(i0_base + i0_win, i1_base + i1_win,
                                        i2_base + i2_win, i3_base + i3_win);
                    if (tmp > val) {
                      val = tmp;
                      scatter_0 = i0_base + i0_win;
                      scatter_1 = i1_base + i1_win;
                      scatter_2 = i2_base + i2_win;
                      scatter_3 = i3_base + i3_win;
                    }
                  }
                }
              }
            }
          }
          // Scatter this window's source element onto the argmax position.
          (*result)(scatter_0, scatter_1, scatter_2, scatter_3) +=
              source(i0, i1, i2, i3);
        }
      }
    }
  }
  return result;
}
397
398 /* static */ std::unique_ptr<Array4D<float>>
ConvArray4DGeneralDimensions(const Array4D<float> & lhs,const Array4D<float> & rhs,std::pair<int64,int64> kernel_stride,Padding padding,ConvolutionDimensionNumbers dimension_numbers)399 ReferenceUtil::ConvArray4DGeneralDimensions(
400 const Array4D<float>& lhs, const Array4D<float>& rhs,
401 std::pair<int64, int64> kernel_stride, Padding padding,
402 ConvolutionDimensionNumbers dimension_numbers) {
403 return ConvArray4DGeneralDimensionsDilated(lhs, rhs, kernel_stride, padding,
404 {1, 1}, {1, 1},
405 std::move(dimension_numbers));
406 }
407
// Fully general 4D convolution reference: builds a one-instruction HLO
// module (constant lhs, constant rhs, a convolve op) and runs it through
// HloEvaluator, then copies the resulting literal into an Array4D. Supports
// arbitrary dimension numbers, per-spatial-dim strides, SAME/VALID padding,
// and base (lhs) / window (rhs) dilation.
/* static */ std::unique_ptr<Array4D<float>>
ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
    const Array4D<float>& lhs, const Array4D<float>& rhs,
    std::pair<int64, int64> kernel_stride, Padding padding,
    std::pair<int64, int64> lhs_dilation, std::pair<int64, int64> rhs_dilation,
    ConvolutionDimensionNumbers dnums) {
  HloComputation::Builder b("ConvArray4DGeneralDimensionDilated");
  auto lhs_literal = LiteralUtil::CreateR4FromArray4D<float>(lhs);
  auto rhs_literal = LiteralUtil::CreateR4FromArray4D<float>(rhs);

  // MakePadding expects its arguments ordered by increasing spatial
  // dimension, so reorder strides when the kernel spatial dims are given in
  // descending order.
  std::array<int64, 2> ordered_kernel_strides;
  std::array<int64, 2> ordered_input_dimensions;
  std::array<int64, 2> ordered_kernel_dimensions;
  if (dnums.kernel_spatial_dimensions(0) > dnums.kernel_spatial_dimensions(1)) {
    ordered_kernel_strides[0] = kernel_stride.second;
    ordered_kernel_strides[1] = kernel_stride.first;
  } else {
    ordered_kernel_strides[0] = kernel_stride.first;
    ordered_kernel_strides[1] = kernel_stride.second;
  }

  ordered_input_dimensions[0] =
      lhs_literal.shape().dimensions(dnums.input_spatial_dimensions(0));
  ordered_input_dimensions[1] =
      lhs_literal.shape().dimensions(dnums.input_spatial_dimensions(1));
  ordered_kernel_dimensions[0] =
      rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(0));
  ordered_kernel_dimensions[1] =
      rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(1));

  // Resolve SAME/VALID into explicit (low, high) padding per spatial dim.
  std::vector<std::pair<int64, int64>> paddings =
      MakePadding(ordered_input_dimensions, ordered_kernel_dimensions,
                  ordered_kernel_strides, padding);
  CHECK_EQ(paddings.size(), 2);

  // Build the Window proto: one WindowDimension per spatial dimension.
  Window window;

  WindowDimension dim;
  dim.set_size(
      rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(0)));
  dim.set_stride(kernel_stride.first);
  dim.set_padding_low(paddings[0].first);
  dim.set_padding_high(paddings[0].second);
  dim.set_window_dilation(rhs_dilation.first);
  dim.set_base_dilation(lhs_dilation.first);
  *window.add_dimensions() = dim;

  WindowDimension dim2;
  dim2.set_size(
      rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(1)));
  dim2.set_stride(kernel_stride.second);
  dim2.set_padding_low(paddings[1].first);
  dim2.set_padding_high(paddings[1].second);
  dim2.set_window_dilation(rhs_dilation.second);
  dim2.set_base_dilation(lhs_dilation.second);
  *window.add_dimensions() = dim2;

  // Infer the output shape for the convolve instruction.
  const Shape& shape =
      ShapeInference::InferConvolveShape(
          lhs_literal.shape(), rhs_literal.shape(),
          /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums,
          /*preferred_element_type=*/absl::nullopt)
          .ConsumeValueOrDie();

  // Both operands become constants embedded in the module.
  HloInstruction* lhs_instruction =
      b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
  HloInstruction* rhs_instruction =
      b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));

  PrecisionConfig precision_config;
  precision_config.mutable_operand_precision()->Resize(
      /*new_size=*/2, PrecisionConfig::DEFAULT);
  b.AddInstruction(HloInstruction::CreateConvolve(
      shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
      /*batch_group_count=*/1, window, dnums, precision_config));
  HloModuleConfig config;
  HloModule module("ReferenceUtil", config);
  auto computation = module.AddEntryComputation(b.Build());

  // Evaluate the single-convolve computation on the embedded constants.
  HloEvaluator evaluator;
  Literal result_literal =
      evaluator.Evaluate(*computation, {}).ConsumeValueOrDie();

  // Copy the rank-4 result literal element-by-element into an Array4D.
  CHECK_EQ(result_literal.shape().rank(), 4);
  auto result =
      absl::make_unique<Array4D<float>>(result_literal.shape().dimensions(0),
                                        result_literal.shape().dimensions(1),
                                        result_literal.shape().dimensions(2),
                                        result_literal.shape().dimensions(3));

  result->Each([&](absl::Span<const int64> indices, float* value) {
    *value = result_literal.Get<float>(indices);
  });

  return result;
}
504
505 /* static */ std::unique_ptr<std::vector<float>>
ReduceToColArray2D(const Array2D<float> & matrix,float init,const std::function<float (float,float)> & reduce_function)506 ReferenceUtil::ReduceToColArray2D(
507 const Array2D<float>& matrix, float init,
508 const std::function<float(float, float)>& reduce_function) {
509 int64 rows = matrix.height();
510 int64 cols = matrix.width();
511 auto result = absl::make_unique<std::vector<float>>();
512 for (int64 i = 0; i < rows; ++i) {
513 float acc = init;
514 for (int64 j = 0; j < cols; ++j) {
515 acc = reduce_function(acc, matrix(i, j));
516 }
517 result->push_back(acc);
518 }
519 return result;
520 }
521
522 /* static */ std::unique_ptr<std::vector<float>>
ReduceToRowArray2D(const Array2D<float> & matrix,float init,const std::function<float (float,float)> & reduce_function)523 ReferenceUtil::ReduceToRowArray2D(
524 const Array2D<float>& matrix, float init,
525 const std::function<float(float, float)>& reduce_function) {
526 int64 rows = matrix.height();
527 int64 cols = matrix.width();
528 auto result = absl::make_unique<std::vector<float>>();
529 for (int64 i = 0; i < cols; ++i) {
530 float acc = init;
531 for (int64 j = 0; j < rows; ++j) {
532 acc = reduce_function(acc, matrix(j, i));
533 }
534 result->push_back(acc);
535 }
536 return result;
537 }
538
// Reduces a 4D array over exactly three dimensions (given in `dims`),
// producing a flat vector with one accumulated value per index of the
// remaining dimension. Loop bounds use the pattern
// `x == 0 || (predicate && x < n)`: a dimension being reduced runs its full
// extent in the inner loops and exactly once in the outer loops, and vice
// versa for the kept dimension.
/*static*/ std::vector<float> ReferenceUtil::Reduce4DTo1D(
    const Array4D<float>& array, float init, absl::Span<const int64> dims,
    const std::function<float(float, float)>& reduce_function) {
  std::vector<float> result;
  CHECK_EQ(dims.size(), 3);
  // The three reduced dimensions must be distinct.
  const absl::flat_hash_set<int64> dim_set(dims.begin(), dims.end());
  CHECK_EQ(dim_set.size(), 3);
  // Outer loops: iterate the kept dimension fully; reduced dimensions run
  // a single iteration (a0..a3 stay 0 there).
  for (int64 a0 = 0; a0 == 0 || (!dim_set.contains(0) && a0 < array.n1());
       ++a0) {
    for (int64 a1 = 0; a1 == 0 || (!dim_set.contains(1) && a1 < array.n2());
         ++a1) {
      for (int64 a2 = 0; a2 == 0 || (!dim_set.contains(2) && a2 < array.n3());
           ++a2) {
        for (int64 a3 = 0; a3 == 0 || (!dim_set.contains(3) && a3 < array.n4());
             ++a3) {
          float accumulator = init;
          // Inner loops: iterate the reduced dimensions fully; the kept
          // dimension runs a single iteration (its i index stays 0).
          for (int64 i0 = 0;
               i0 == 0 || (dim_set.contains(0) && i0 < array.n1()); ++i0) {
            for (int64 i1 = 0;
                 i1 == 0 || (dim_set.contains(1) && i1 < array.n2()); ++i1) {
              for (int64 i2 = 0;
                   i2 == 0 || (dim_set.contains(2) && i2 < array.n3()); ++i2) {
                for (int64 i3 = 0;
                     i3 == 0 || (dim_set.contains(3) && i3 < array.n4());
                     ++i3) {
                  // Handle zero-sized arrays: the `== 0` clauses force one
                  // iteration even when an extent is 0, so guard the access.
                  if (array.n1() > 0 && array.n2() > 0 && array.n3() > 0 &&
                      array.n4() > 0) {
                    accumulator = reduce_function(
                        accumulator, array(a0 + i0, a1 + i1, a2 + i2, a3 + i3));
                  }
                }
              }
            }
          }
          result.push_back(accumulator);
        }
      }
    }
  }
  return result;
}
581
Broadcast1DTo4D(const std::vector<float> & array,const std::vector<int64> & bounds,int64 broadcast_from_dim)582 /* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::Broadcast1DTo4D(
583 const std::vector<float>& array, const std::vector<int64>& bounds,
584 int64 broadcast_from_dim) {
585 auto result = absl::make_unique<Array4D<float>>(bounds[0], bounds[1],
586 bounds[2], bounds[3]);
587 for (int64 i = 0; i < result->n1(); ++i) {
588 for (int64 j = 0; j < result->n2(); ++j) {
589 for (int64 k = 0; k < result->n3(); ++k) {
590 for (int64 l = 0; l < result->n4(); ++l) {
591 switch (broadcast_from_dim) {
592 case 0:
593 (*result)(i, j, k, l) = array[i];
594 break;
595 case 1:
596 (*result)(i, j, k, l) = array[j];
597 break;
598 case 2:
599 (*result)(i, j, k, l) = array[k];
600 break;
601 case 3:
602 (*result)(i, j, k, l) = array[l];
603 break;
604 default:
605 break;
606 }
607 }
608 }
609 }
610 }
611 return result;
612 }
613
Reduce3DTo2D(const Array3D<float> & array,float init,absl::Span<const int64> dims,const std::function<float (float,float)> & reduce_function)614 /* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::Reduce3DTo2D(
615 const Array3D<float>& array, float init, absl::Span<const int64> dims,
616 const std::function<float(float, float)>& reduce_function) {
617 CHECK_EQ(dims.size(), 1);
618 int64 rows = dims[0] == 0 ? array.n2() : array.n1();
619 int64 cols = dims[0] == 2 ? array.n2() : array.n3();
620 auto result = absl::make_unique<Array2D<float>>(rows, cols);
621 result->Fill(init);
622 for (int i0 = 0; i0 < array.n1(); ++i0) {
623 for (int i1 = 0; i1 < array.n2(); ++i1) {
624 for (int i2 = 0; i2 < array.n3(); ++i2) {
625 int64 row = dims[0] == 0 ? i1 : i0;
626 int64 col = dims[0] == 2 ? i1 : i2;
627 (*result)(row, col) =
628 reduce_function((*result)(row, col), array(i0, i1, i2));
629 }
630 }
631 }
632 return result;
633 }
634
MapArray2D(const Array2D<float> & matrix,const std::function<float (float)> & map_function)635 /* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapArray2D(
636 const Array2D<float>& matrix,
637 const std::function<float(float)>& map_function) {
638 int64 rows = matrix.height();
639 int64 cols = matrix.width();
640 auto result = absl::make_unique<Array2D<float>>(rows, cols);
641 for (int64 i = 0; i < rows; ++i) {
642 for (int64 j = 0; j < cols; ++j) {
643 (*result)(i, j) = map_function(matrix(i, j));
644 }
645 }
646 return result;
647 }
648
MapArray2D(const Array2D<float> & lhs,const Array2D<float> & rhs,const std::function<float (float,float)> & map_function)649 /* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapArray2D(
650 const Array2D<float>& lhs, const Array2D<float>& rhs,
651 const std::function<float(float, float)>& map_function) {
652 CHECK_EQ(lhs.height(), rhs.height());
653 CHECK_EQ(lhs.width(), rhs.width());
654 int64 rows = lhs.height();
655 int64 cols = rhs.width();
656 auto result = absl::make_unique<Array2D<float>>(rows, cols);
657 for (int64 i = 0; i < rows; ++i) {
658 for (int64 j = 0; j < cols; ++j) {
659 (*result)(i, j) = map_function(lhs(i, j), rhs(i, j));
660 }
661 }
662 return result;
663 }
664
MapWithIndexArray2D(const Array2D<float> & matrix,const std::function<float (float,int64,int64)> & map_function)665 /* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapWithIndexArray2D(
666 const Array2D<float>& matrix,
667 const std::function<float(float, int64, int64)>& map_function) {
668 int64 rows = matrix.height();
669 int64 cols = matrix.width();
670 auto result = absl::make_unique<Array2D<float>>(rows, cols);
671 for (int64 i = 0; i < rows; ++i) {
672 for (int64 j = 0; j < cols; ++j) {
673 (*result)(i, j) = map_function(matrix(i, j), i, j);
674 }
675 }
676 return result;
677 }
678
679 } // namespace xla
680