1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/reference_util.h"
17
18 #include <array>
19 #include <utility>
20
21 #include "absl/container/flat_hash_set.h"
22 #include "absl/memory/memory.h"
23 #include "tensorflow/compiler/xla/client/xla_builder.h"
24 #include "tensorflow/compiler/xla/literal_util.h"
25 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
26 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
27 #include "tensorflow/compiler/xla/service/shape_inference.h"
28 #include "tensorflow/compiler/xla/window_util.h"
29 #include "tensorflow/compiler/xla/xla_data.pb.h"
30 #include "tensorflow/core/lib/math/math_util.h"
31 #include "tensorflow/core/platform/logging.h"
32
33 namespace xla {
34
MatmulArray2D(const Array2D<Eigen::half> & lhs,const Array2D<Eigen::half> & rhs)35 /* static */ std::unique_ptr<Array2D<Eigen::half>> ReferenceUtil::MatmulArray2D(
36 const Array2D<Eigen::half>& lhs, const Array2D<Eigen::half>& rhs) {
37 return HloEvaluator::MatmulArray2D(lhs, rhs);
38 }
39
MatmulArray2D(const Array2D<float> & lhs,const Array2D<float> & rhs)40 /* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MatmulArray2D(
41 const Array2D<float>& lhs, const Array2D<float>& rhs) {
42 return HloEvaluator::MatmulArray2D(lhs, rhs);
43 }
44
MatmulArray2D(const Array2D<double> & lhs,const Array2D<double> & rhs)45 /* static */ std::unique_ptr<Array2D<double>> ReferenceUtil::MatmulArray2D(
46 const Array2D<double>& lhs, const Array2D<double>& rhs) {
47 return HloEvaluator::MatmulArray2D(lhs, rhs);
48 }
49
Array2DF32ToF64(const Array2D<float> & input)50 /* static */ std::unique_ptr<Array2D<double>> ReferenceUtil::Array2DF32ToF64(
51 const Array2D<float>& input) {
52 auto result =
53 absl::make_unique<Array2D<double>>(input.height(), input.width());
54 for (int64 rowno = 0; rowno < input.height(); ++rowno) {
55 for (int64 colno = 0; colno < input.height(); ++colno) {
56 (*result)(rowno, colno) = input(rowno, colno);
57 }
58 }
59 return result;
60 }
61
ConvArray3D(const Array3D<float> & lhs,const Array3D<float> & rhs,int64 kernel_stride,Padding padding)62 /* static */ std::unique_ptr<Array3D<float>> ReferenceUtil::ConvArray3D(
63 const Array3D<float>& lhs, const Array3D<float>& rhs, int64 kernel_stride,
64 Padding padding) {
65 return ConvArray3DGeneralDimensionsDilated(
66 lhs, rhs, kernel_stride, padding, 1, 1,
67 XlaBuilder::CreateDefaultConvDimensionNumbers(1));
68 }
69
70 /*static*/ std::unique_ptr<Array3D<float>>
ConvArray3DGeneralDimensionsDilated(const Array3D<float> & lhs,const Array3D<float> & rhs,int64 kernel_stride,Padding padding,int64 lhs_dilation,int64 rhs_dilation,const ConvolutionDimensionNumbers & dnums)71 ReferenceUtil::ConvArray3DGeneralDimensionsDilated(
72 const Array3D<float>& lhs, const Array3D<float>& rhs, int64 kernel_stride,
73 Padding padding, int64 lhs_dilation, int64 rhs_dilation,
74 const ConvolutionDimensionNumbers& dnums) {
75 CHECK_EQ(dnums.input_spatial_dimensions_size(), 1);
76 CHECK_EQ(dnums.kernel_spatial_dimensions_size(), 1);
77 CHECK_EQ(dnums.output_spatial_dimensions_size(), 1);
78 // Reuse the code for Array4D-convolution by extending the 3D input into a 4D
79 // array by adding a fourth dummy dimension of size 1 without stride, padding
80 // and dilation.
81 Array4D<float> a4dlhs(lhs.n1(), lhs.n2(), lhs.n3(), 1);
82 a4dlhs.Each([&](absl::Span<const int64> indices, float* value_ptr) {
83 CHECK_EQ(indices[3], 0);
84 *value_ptr = lhs.operator()(indices[0], indices[1], indices[2]);
85 });
86 Array4D<float> a4drhs(rhs.n1(), rhs.n2(), rhs.n3(), 1);
87 a4drhs.Each([&](absl::Span<const int64> indices, float* value_ptr) {
88 CHECK_EQ(indices[3], 0);
89 *value_ptr = rhs.operator()(indices[0], indices[1], indices[2]);
90 });
91 // Add a second dummy spatial dimensions.
92 ConvolutionDimensionNumbers dnums2d = dnums;
93 dnums2d.add_input_spatial_dimensions(3);
94 dnums2d.add_kernel_spatial_dimensions(3);
95 dnums2d.add_output_spatial_dimensions(3);
96 std::unique_ptr<Array4D<float>> convr4 = ConvArray4DGeneralDimensionsDilated(
97 a4dlhs, a4drhs, {kernel_stride, 1}, padding, {lhs_dilation, 1},
98 {rhs_dilation, 1}, dnums2d);
99
100 auto convr3 = absl::make_unique<Array3D<float>>(
101 convr4->planes(), convr4->depth(), convr4->height());
102 convr4->Each([&](absl::Span<const int64> indices, float* value_ptr) {
103 CHECK_EQ(indices[3], 0);
104 convr3->operator()(indices[0], indices[1], indices[2]) = *value_ptr;
105 });
106 return convr3;
107 }
108
ConvArray4D(const Array4D<float> & lhs,const Array4D<float> & rhs,std::pair<int64,int64> kernel_stride,Padding padding)109 /* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::ConvArray4D(
110 const Array4D<float>& lhs, const Array4D<float>& rhs,
111 std::pair<int64, int64> kernel_stride, Padding padding) {
112 return ConvArray4DGeneralDimensions(
113 lhs, rhs, kernel_stride, padding,
114 XlaBuilder::CreateDefaultConvDimensionNumbers());
115 }
116
117 /* static */ std::unique_ptr<Array4D<float>>
SeparableConvArray4D(const Array4D<float> & input,const Array4D<float> & depthwise_weights,const Array4D<float> & pointwise_weights,std::pair<int64,int64> kernel_stride,Padding padding)118 ReferenceUtil::SeparableConvArray4D(const Array4D<float>& input,
119 const Array4D<float>& depthwise_weights,
120 const Array4D<float>& pointwise_weights,
121 std::pair<int64, int64> kernel_stride,
122 Padding padding) {
123 const int64 depth_multiplier = depthwise_weights.planes();
124 CHECK_EQ(pointwise_weights.depth(), input.depth() * depth_multiplier);
125
126 // Combine the two weights by reducing the depth_multiplier, so that we can
127 // apply a single convolution on the combined weights.
128 Array4D<float> weights(pointwise_weights.planes(), input.depth(),
129 depthwise_weights.height(), depthwise_weights.width());
130 for (int64 kx = 0; kx < depthwise_weights.width(); ++kx) {
131 for (int64 ky = 0; ky < depthwise_weights.height(); ++ky) {
132 for (int64 kz = 0; kz < input.depth(); ++kz) {
133 for (int64 out = 0; out < pointwise_weights.planes(); ++out) {
134 float weight = 0.0;
135 for (int64 depth = 0; depth < depth_multiplier; ++depth) {
136 weight +=
137 depthwise_weights(depth, kz, ky, kx) *
138 pointwise_weights(out, depth + kz * depth_multiplier, 0, 0);
139 }
140 weights(out, kz, ky, kx) = weight;
141 }
142 }
143 }
144 }
145
146 return ConvArray4D(input, weights, kernel_stride, padding);
147 }
148
WindowCount(int64 unpadded_width,int64 window_len,int64 stride,Padding padding)149 /* static */ int64 ReferenceUtil::WindowCount(int64 unpadded_width,
150 int64 window_len, int64 stride,
151 Padding padding) {
152 if (padding == Padding::kValid) {
153 return window_util::StridedBound(unpadded_width, window_len, stride);
154 }
155 return tensorflow::MathUtil::CeilOfRatio(unpadded_width, stride);
156 }
157
158 /* static */ std::unique_ptr<std::vector<float>>
ReduceWindow1DGeneric(absl::Span<const float> operand,float init,const std::function<float (float,float)> & reduce_func,absl::Span<const int64> window,absl::Span<const int64> stride,absl::Span<const std::pair<int64,int64>> padding)159 ReferenceUtil::ReduceWindow1DGeneric(
160 absl::Span<const float> operand, float init,
161 const std::function<float(float, float)>& reduce_func,
162 absl::Span<const int64> window, absl::Span<const int64> stride,
163 absl::Span<const std::pair<int64, int64>> padding) {
164 std::vector<int64> dim_lengths{static_cast<int64>(operand.size())};
165 std::vector<int64> window_counts(window.size(), 0);
166 std::vector<int64> pad_low(window.size(), 0);
167 for (int64 i = 0; i < window.size(); ++i) {
168 int64 padded_width = padding[i].first + dim_lengths[i] + padding[i].second;
169 window_counts[i] =
170 window_util::StridedBound(padded_width, window[i], stride[i]);
171 pad_low[i] = padding[i].first;
172 }
173 auto result = absl::make_unique<std::vector<float>>(window_counts[0]);
174
175 // Do a full 1D reduce window.
176 for (int64 i0 = 0; i0 < window_counts[0]; ++i0) {
177 int64 i0_base = i0 * stride[0] - pad_low[0];
178
179 float val = init;
180 for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) {
181 if (i0_base + i0_win >= 0 && i0_base + i0_win < dim_lengths[0]) {
182 val = reduce_func(val, operand[i0_base + i0_win]);
183 }
184 }
185 (*result)[i0] = val;
186 }
187 return result;
188 }
189
190 /* static */ std::unique_ptr<std::vector<float>>
ReduceWindow1DAdd(absl::Span<const float> operand,float init,absl::Span<const int64> window,absl::Span<const int64> stride,Padding padding)191 ReferenceUtil::ReduceWindow1DAdd(absl::Span<const float> operand, float init,
192 absl::Span<const int64> window,
193 absl::Span<const int64> stride,
194 Padding padding) {
195 const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
196 std::vector<int64> dim_lengths{static_cast<int64>(operand.size())};
197 return ReduceWindow1DGeneric(
198 operand, init, add_reduce, window, stride,
199 xla::MakePadding(dim_lengths, window, stride, padding));
200 }
201
202 /* static */ std::unique_ptr<Array2D<float>>
ReduceWindow2DGeneric(const Array2D<float> & operand,float init,const std::function<float (float,float)> & reduce_func,absl::Span<const int64> window,absl::Span<const int64> stride,absl::Span<const std::pair<int64,int64>> padding)203 ReferenceUtil::ReduceWindow2DGeneric(
204 const Array2D<float>& operand, float init,
205 const std::function<float(float, float)>& reduce_func,
206 absl::Span<const int64> window, absl::Span<const int64> stride,
207 absl::Span<const std::pair<int64, int64>> padding) {
208 std::vector<int64> dim_lengths{operand.height(), operand.width()};
209
210 std::vector<int64> window_counts(window.size(), 0);
211 std::vector<int64> pad_low(window.size(), 0);
212 for (int64 i = 0; i < window.size(); ++i) {
213 int64 padded_width = padding[i].first + dim_lengths[i] + padding[i].second;
214 window_counts[i] =
215 window_util::StridedBound(padded_width, window[i], stride[i]);
216 pad_low[i] = padding[i].first;
217 }
218 auto result =
219 absl::make_unique<Array2D<float>>(window_counts[0], window_counts[1]);
220
221 // Do a full 2D reduce window.
222 for (int64 i0 = 0; i0 < window_counts[0]; ++i0) {
223 for (int64 i1 = 0; i1 < window_counts[1]; ++i1) {
224 int64 i0_base = i0 * stride[0] - pad_low[0];
225 int64 i1_base = i1 * stride[1] - pad_low[1];
226
227 float val = init;
228 for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) {
229 for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) {
230 if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 &&
231 i0_base + i0_win < operand.n1() &&
232 i1_base + i1_win < operand.n2()) {
233 val = reduce_func(val, operand(i0_base + i0_win, i1_base + i1_win));
234 }
235 }
236 }
237 (*result)(i0, i1) = val;
238 }
239 }
240 return result;
241 }
242
ReduceWindow2DAdd(const Array2D<float> & operand,float init,absl::Span<const int64> window,absl::Span<const int64> stride,Padding padding)243 /* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::ReduceWindow2DAdd(
244 const Array2D<float>& operand, float init, absl::Span<const int64> window,
245 absl::Span<const int64> stride, Padding padding) {
246 const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
247 std::vector<int64> dim_lengths{operand.height(), operand.width()};
248 return ReduceWindow2DGeneric(
249 operand, init, add_reduce, window, stride,
250 xla::MakePadding(dim_lengths, window, stride, padding));
251 }
252
ReduceWindow3DAdd(const Array3D<float> & operand,float init,absl::Span<const int64> window,absl::Span<const int64> stride,Padding padding)253 /* static */ std::unique_ptr<Array3D<float>> ReferenceUtil::ReduceWindow3DAdd(
254 const Array3D<float>& operand, float init, absl::Span<const int64> window,
255 absl::Span<const int64> stride, Padding padding) {
256 std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3()};
257 auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding);
258
259 std::vector<int64> window_counts(window.size(), 0);
260 std::vector<int64> pad_low(window.size(), 0);
261 for (int64 i = 0; i < window.size(); ++i) {
262 window_counts[i] =
263 WindowCount(dim_lengths[i], window[i], stride[i], padding);
264 pad_low[i] = padding_both[i].first;
265 }
266 auto result = absl::make_unique<Array3D<float>>(
267 window_counts[0], window_counts[1], window_counts[2]);
268
269 for (int64 i0 = 0; i0 < window_counts[0]; ++i0) {
270 for (int64 i1 = 0; i1 < window_counts[1]; ++i1) {
271 for (int64 i2 = 0; i2 < window_counts[2]; ++i2) {
272 int64 i0_base = i0 * stride[0] - pad_low[0];
273 int64 i1_base = i1 * stride[1] - pad_low[1];
274 int64 i2_base = i2 * stride[2] - pad_low[2];
275
276 float val = init;
277 for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) {
278 for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) {
279 for (int64 i2_win = 0; i2_win < window[2]; ++i2_win) {
280 if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 &&
281 i2_base + i2_win >= 0 && i0_base + i0_win < operand.n1() &&
282 i1_base + i1_win < operand.n2() &&
283 i2_base + i2_win < operand.n3()) {
284 val += operand(i0_base + i0_win, i1_base + i1_win,
285 i2_base + i2_win);
286 }
287 }
288 }
289 }
290 (*result)(i0, i1, i2) = val;
291 }
292 }
293 }
294 return result;
295 }
296
297 /* static */ std::unique_ptr<Array4D<float>>
ReduceWindow4DGeneric(const Array4D<float> & operand,float init,const std::function<float (float,float)> & reduce_func,absl::Span<const int64> window,absl::Span<const int64> stride,Padding padding)298 ReferenceUtil::ReduceWindow4DGeneric(
299 const Array4D<float>& operand, float init,
300 const std::function<float(float, float)>& reduce_func,
301 absl::Span<const int64> window, absl::Span<const int64> stride,
302 Padding padding) {
303 std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
304 operand.n4()};
305 return ReduceWindow4DGeneric(
306 operand, init, reduce_func, window, stride,
307 xla::MakePadding(dim_lengths, window, stride, padding));
308 }
309
310 /* static */ std::unique_ptr<Array4D<float>>
ReduceWindow4DGeneric(const Array4D<float> & operand,float init,const std::function<float (float,float)> & reduce_func,absl::Span<const int64> window,absl::Span<const int64> stride,absl::Span<const std::pair<int64,int64>> padding)311 ReferenceUtil::ReduceWindow4DGeneric(
312 const Array4D<float>& operand, float init,
313 const std::function<float(float, float)>& reduce_func,
314 absl::Span<const int64> window, absl::Span<const int64> stride,
315 absl::Span<const std::pair<int64, int64>> padding) {
316 std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
317 operand.n4()};
318
319 std::vector<int64> window_counts(window.size(), 0);
320 std::vector<int64> pad_low(window.size(), 0);
321 for (int64 i = 0; i < window.size(); ++i) {
322 int64 padded_width = padding[i].first + dim_lengths[i] + padding[i].second;
323 window_counts[i] =
324 window_util::StridedBound(padded_width, window[i], stride[i]);
325 pad_low[i] = padding[i].first;
326 }
327 auto result = absl::make_unique<Array4D<float>>(
328 window_counts[0], window_counts[1], window_counts[2], window_counts[3]);
329 // Do a full 4D reduce window.
330 for (int64 i0 = 0; i0 < window_counts[0]; ++i0) {
331 for (int64 i1 = 0; i1 < window_counts[1]; ++i1) {
332 for (int64 i2 = 0; i2 < window_counts[2]; ++i2) {
333 for (int64 i3 = 0; i3 < window_counts[3]; ++i3) {
334 int64 i0_base = i0 * stride[0] - pad_low[0];
335 int64 i1_base = i1 * stride[1] - pad_low[1];
336 int64 i2_base = i2 * stride[2] - pad_low[2];
337 int64 i3_base = i3 * stride[3] - pad_low[3];
338
339 float val = init;
340 for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) {
341 for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) {
342 for (int64 i2_win = 0; i2_win < window[2]; ++i2_win) {
343 for (int64 i3_win = 0; i3_win < window[3]; ++i3_win) {
344 if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 &&
345 i2_base + i2_win >= 0 && i3_base + i3_win >= 0 &&
346 i0_base + i0_win < operand.n1() &&
347 i1_base + i1_win < operand.n2() &&
348 i2_base + i2_win < operand.n3() &&
349 i3_base + i3_win < operand.n4()) {
350 val = reduce_func(
351 val, operand(i0_base + i0_win, i1_base + i1_win,
352 i2_base + i2_win, i3_base + i3_win));
353 }
354 }
355 }
356 }
357 }
358 (*result)(i0, i1, i2, i3) = val;
359 }
360 }
361 }
362 }
363 return result;
364 }
365
ReduceWindow4DAdd(const Array4D<float> & operand,float init,absl::Span<const int64> window,absl::Span<const int64> stride,Padding padding)366 /* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::ReduceWindow4DAdd(
367 const Array4D<float>& operand, float init, absl::Span<const int64> window,
368 absl::Span<const int64> stride, Padding padding) {
369 const auto add_reduce = [](float arg1, float arg2) { return arg1 + arg2; };
370 return ReduceWindow4DGeneric(operand, init, add_reduce, window, stride,
371 padding);
372 }
373
BatchNorm4D(const Array4D<float> & input,const Array4D<float> & mean,const Array4D<float> & var,const Array4D<float> & scale,const Array4D<float> & offset,float epsilon)374 /* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::BatchNorm4D(
375 const Array4D<float>& input, const Array4D<float>& mean,
376 const Array4D<float>& var, const Array4D<float>& scale,
377 const Array4D<float>& offset, float epsilon) {
378 auto normalized =
379 *MapArray4D(input, mean, [](float a, float b) { return a - b; });
380 normalized = *MapArray4D(normalized, var, [&](float a, float b) {
381 return a / std::sqrt(b + epsilon);
382 });
383 normalized =
384 *MapArray4D(normalized, scale, [](float a, float b) { return a * b; });
385 return MapArray4D(normalized, offset, [](float a, float b) { return a + b; });
386 }
387
388 /* static */ std::unique_ptr<Array4D<float>>
SelectAndScatter4DGePlus(const Array4D<float> & operand,const Array4D<float> & source,float init,absl::Span<const int64> window,absl::Span<const int64> stride,bool same_padding)389 ReferenceUtil::SelectAndScatter4DGePlus(const Array4D<float>& operand,
390 const Array4D<float>& source,
391 float init,
392 absl::Span<const int64> window,
393 absl::Span<const int64> stride,
394 bool same_padding) {
395 Padding padding = same_padding ? Padding::kSame : Padding::kValid;
396 auto result = absl::make_unique<Array4D<float>>(operand.n1(), operand.n2(),
397 operand.n3(), operand.n4());
398 std::vector<int64> dim_lengths{operand.n1(), operand.n2(), operand.n3(),
399 operand.n4()};
400 auto padding_both = xla::MakePadding(dim_lengths, window, stride, padding);
401 // Fill the output, with the initial value.
402 result->Fill(init);
403
404 std::vector<int64> window_counts(window.size(), 0);
405 std::vector<int64> pad_low(window.size(), 0);
406 for (int64 i = 0; i < window.size(); ++i) {
407 window_counts[i] =
408 WindowCount(dim_lengths[i], window[i], stride[i], padding);
409 pad_low[i] = padding_both[i].first;
410 }
411 CHECK_EQ(window_counts[0], source.n1());
412 CHECK_EQ(window_counts[1], source.n2());
413 CHECK_EQ(window_counts[2], source.n3());
414 CHECK_EQ(window_counts[3], source.n4());
415
416 // Do a full 4D select and Scatter.
417 for (int64 i0 = 0; i0 < window_counts[0]; ++i0) {
418 for (int64 i1 = 0; i1 < window_counts[1]; ++i1) {
419 for (int64 i2 = 0; i2 < window_counts[2]; ++i2) {
420 for (int64 i3 = 0; i3 < window_counts[3]; ++i3) {
421 // Now we are inside a window and need to find the max and the argmax.
422 int64 i0_base = i0 * stride[0] - pad_low[0];
423 int64 i1_base = i1 * stride[1] - pad_low[1];
424 int64 i2_base = i2 * stride[2] - pad_low[2];
425 int64 i3_base = i3 * stride[3] - pad_low[3];
426 int64 scatter_0 = (i0_base >= 0) ? i0_base : 0;
427 int64 scatter_1 = (i1_base >= 0) ? i1_base : 0;
428 int64 scatter_2 = (i2_base >= 0) ? i2_base : 0;
429 int64 scatter_3 = (i3_base >= 0) ? i3_base : 0;
430 float val = operand(scatter_0, scatter_1, scatter_2, scatter_3);
431 for (int64 i0_win = 0; i0_win < window[0]; ++i0_win) {
432 for (int64 i1_win = 0; i1_win < window[1]; ++i1_win) {
433 for (int64 i2_win = 0; i2_win < window[2]; ++i2_win) {
434 for (int64 i3_win = 0; i3_win < window[3]; ++i3_win) {
435 if (i0_base + i0_win >= 0 && i1_base + i1_win >= 0 &&
436 i2_base + i2_win >= 0 && i3_base + i3_win >= 0 &&
437 i0_base + i0_win < operand.n1() &&
438 i1_base + i1_win < operand.n2() &&
439 i2_base + i2_win < operand.n3() &&
440 i3_base + i3_win < operand.n4()) {
441 float tmp = operand(i0_base + i0_win, i1_base + i1_win,
442 i2_base + i2_win, i3_base + i3_win);
443 if (tmp > val) {
444 val = tmp;
445 scatter_0 = i0_base + i0_win;
446 scatter_1 = i1_base + i1_win;
447 scatter_2 = i2_base + i2_win;
448 scatter_3 = i3_base + i3_win;
449 }
450 }
451 }
452 }
453 }
454 }
455 (*result)(scatter_0, scatter_1, scatter_2, scatter_3) +=
456 source(i0, i1, i2, i3);
457 }
458 }
459 }
460 }
461 return result;
462 }
463
464 /* static */ std::unique_ptr<Array4D<float>>
ConvArray4DGeneralDimensions(const Array4D<float> & lhs,const Array4D<float> & rhs,std::pair<int64,int64> kernel_stride,Padding padding,ConvolutionDimensionNumbers dimension_numbers)465 ReferenceUtil::ConvArray4DGeneralDimensions(
466 const Array4D<float>& lhs, const Array4D<float>& rhs,
467 std::pair<int64, int64> kernel_stride, Padding padding,
468 ConvolutionDimensionNumbers dimension_numbers) {
469 return ConvArray4DGeneralDimensionsDilated(lhs, rhs, kernel_stride, padding,
470 {1, 1}, {1, 1},
471 std::move(dimension_numbers));
472 }
473
474 /* static */ std::unique_ptr<Array4D<float>>
ConvArray4DGeneralDimensionsDilated(const Array4D<float> & lhs,const Array4D<float> & rhs,std::pair<int64,int64> kernel_stride,Padding padding,std::pair<int64,int64> lhs_dilation,std::pair<int64,int64> rhs_dilation,ConvolutionDimensionNumbers dnums)475 ReferenceUtil::ConvArray4DGeneralDimensionsDilated(
476 const Array4D<float>& lhs, const Array4D<float>& rhs,
477 std::pair<int64, int64> kernel_stride, Padding padding,
478 std::pair<int64, int64> lhs_dilation, std::pair<int64, int64> rhs_dilation,
479 ConvolutionDimensionNumbers dnums) {
480 HloComputation::Builder b("ConvArray4DGeneralDimensionDilated");
481 auto lhs_literal = LiteralUtil::CreateR4FromArray4D<float>(lhs);
482 auto rhs_literal = LiteralUtil::CreateR4FromArray4D<float>(rhs);
483
484 std::array<int64, 2> ordered_kernel_strides;
485 std::array<int64, 2> ordered_input_dimensions;
486 std::array<int64, 2> ordered_kernel_dimensions;
487 if (dnums.kernel_spatial_dimensions(0) > dnums.kernel_spatial_dimensions(1)) {
488 ordered_kernel_strides[0] = kernel_stride.second;
489 ordered_kernel_strides[1] = kernel_stride.first;
490 } else {
491 ordered_kernel_strides[0] = kernel_stride.first;
492 ordered_kernel_strides[1] = kernel_stride.second;
493 }
494
495 ordered_input_dimensions[0] =
496 lhs_literal.shape().dimensions(dnums.input_spatial_dimensions(0));
497 ordered_input_dimensions[1] =
498 lhs_literal.shape().dimensions(dnums.input_spatial_dimensions(1));
499 ordered_kernel_dimensions[0] =
500 rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(0));
501 ordered_kernel_dimensions[1] =
502 rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(1));
503
504 std::vector<std::pair<int64, int64>> paddings =
505 MakePadding(ordered_input_dimensions, ordered_kernel_dimensions,
506 ordered_kernel_strides, padding);
507 CHECK_EQ(paddings.size(), 2);
508
509 Window window;
510
511 WindowDimension dim;
512 dim.set_size(
513 rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(0)));
514 dim.set_stride(kernel_stride.first);
515 dim.set_padding_low(paddings[0].first);
516 dim.set_padding_high(paddings[0].second);
517 dim.set_window_dilation(rhs_dilation.first);
518 dim.set_base_dilation(lhs_dilation.first);
519 *window.add_dimensions() = dim;
520
521 WindowDimension dim2;
522 dim2.set_size(
523 rhs_literal.shape().dimensions(dnums.kernel_spatial_dimensions(1)));
524 dim2.set_stride(kernel_stride.second);
525 dim2.set_padding_low(paddings[1].first);
526 dim2.set_padding_high(paddings[1].second);
527 dim2.set_window_dilation(rhs_dilation.second);
528 dim2.set_base_dilation(lhs_dilation.second);
529 *window.add_dimensions() = dim2;
530
531 const Shape& shape =
532 ShapeInference::InferConvolveShape(
533 lhs_literal.shape(), rhs_literal.shape(),
534 /*feature_group_count=*/1, /*batch_group_count=*/1, window, dnums)
535 .ConsumeValueOrDie();
536
537 HloInstruction* lhs_instruction =
538 b.AddInstruction(HloInstruction::CreateConstant(std::move(lhs_literal)));
539 HloInstruction* rhs_instruction =
540 b.AddInstruction(HloInstruction::CreateConstant(std::move(rhs_literal)));
541
542 PrecisionConfig precision_config;
543 precision_config.mutable_operand_precision()->Resize(
544 /*new_size=*/2, PrecisionConfig::DEFAULT);
545 b.AddInstruction(HloInstruction::CreateConvolve(
546 shape, lhs_instruction, rhs_instruction, /*feature_group_count=*/1,
547 /*batch_group_count=*/1, window, dnums, precision_config));
548 HloModuleConfig config;
549 HloModule module("ReferenceUtil", config);
550 auto computation = module.AddEntryComputation(b.Build());
551
552 HloEvaluator evaluator;
553 Literal result_literal =
554 evaluator.Evaluate(*computation, {}).ConsumeValueOrDie();
555
556 CHECK_EQ(result_literal.shape().rank(), 4);
557 auto result =
558 absl::make_unique<Array4D<float>>(result_literal.shape().dimensions(0),
559 result_literal.shape().dimensions(1),
560 result_literal.shape().dimensions(2),
561 result_literal.shape().dimensions(3));
562
563 result->Each([&](absl::Span<const int64> indices, float* value) {
564 *value = result_literal.Get<float>(indices);
565 });
566
567 return result;
568 }
569
570 /* static */ std::unique_ptr<std::vector<float>>
ReduceToColArray2D(const Array2D<float> & matrix,float init,const std::function<float (float,float)> & reduce_function)571 ReferenceUtil::ReduceToColArray2D(
572 const Array2D<float>& matrix, float init,
573 const std::function<float(float, float)>& reduce_function) {
574 int64 rows = matrix.height();
575 int64 cols = matrix.width();
576 auto result = absl::make_unique<std::vector<float>>();
577 for (int64 i = 0; i < rows; ++i) {
578 float acc = init;
579 for (int64 j = 0; j < cols; ++j) {
580 acc = reduce_function(acc, matrix(i, j));
581 }
582 result->push_back(acc);
583 }
584 return result;
585 }
586
587 /* static */ std::unique_ptr<std::vector<float>>
ReduceToRowArray2D(const Array2D<float> & matrix,float init,const std::function<float (float,float)> & reduce_function)588 ReferenceUtil::ReduceToRowArray2D(
589 const Array2D<float>& matrix, float init,
590 const std::function<float(float, float)>& reduce_function) {
591 int64 rows = matrix.height();
592 int64 cols = matrix.width();
593 auto result = absl::make_unique<std::vector<float>>();
594 for (int64 i = 0; i < cols; ++i) {
595 float acc = init;
596 for (int64 j = 0; j < rows; ++j) {
597 acc = reduce_function(acc, matrix(j, i));
598 }
599 result->push_back(acc);
600 }
601 return result;
602 }
603
Reduce4DTo1D(const Array4D<float> & array,float init,absl::Span<const int64> dims,const std::function<float (float,float)> & reduce_function)604 /*static*/ std::vector<float> ReferenceUtil::Reduce4DTo1D(
605 const Array4D<float>& array, float init, absl::Span<const int64> dims,
606 const std::function<float(float, float)>& reduce_function) {
607 std::vector<float> result;
608 CHECK_EQ(dims.size(), 3);
609 const absl::flat_hash_set<int64> dim_set(dims.begin(), dims.end());
610 CHECK_EQ(dim_set.size(), 3);
611 for (int64 a0 = 0; a0 == 0 || (!dim_set.contains(0) && a0 < array.n1());
612 ++a0) {
613 for (int64 a1 = 0; a1 == 0 || (!dim_set.contains(1) && a1 < array.n2());
614 ++a1) {
615 for (int64 a2 = 0; a2 == 0 || (!dim_set.contains(2) && a2 < array.n3());
616 ++a2) {
617 for (int64 a3 = 0; a3 == 0 || (!dim_set.contains(3) && a3 < array.n4());
618 ++a3) {
619 float accumulator = init;
620 for (int64 i0 = 0;
621 i0 == 0 || (dim_set.contains(0) && i0 < array.n1()); ++i0) {
622 for (int64 i1 = 0;
623 i1 == 0 || (dim_set.contains(1) && i1 < array.n2()); ++i1) {
624 for (int64 i2 = 0;
625 i2 == 0 || (dim_set.contains(2) && i2 < array.n3()); ++i2) {
626 for (int64 i3 = 0;
627 i3 == 0 || (dim_set.contains(3) && i3 < array.n4());
628 ++i3) {
629 // Handle zero-sized arrays.
630 if (array.n1() > 0 && array.n2() > 0 && array.n3() > 0 &&
631 array.n4() > 0) {
632 accumulator = reduce_function(
633 accumulator, array(a0 + i0, a1 + i1, a2 + i2, a3 + i3));
634 }
635 }
636 }
637 }
638 }
639 result.push_back(accumulator);
640 }
641 }
642 }
643 }
644 return result;
645 }
646
Broadcast1DTo4D(const std::vector<float> & array,const std::vector<int64> & bounds,int64 broadcast_from_dim)647 /* static */ std::unique_ptr<Array4D<float>> ReferenceUtil::Broadcast1DTo4D(
648 const std::vector<float>& array, const std::vector<int64>& bounds,
649 int64 broadcast_from_dim) {
650 auto result = absl::make_unique<Array4D<float>>(bounds[0], bounds[1],
651 bounds[2], bounds[3]);
652 for (int64 i = 0; i < result->n1(); ++i) {
653 for (int64 j = 0; j < result->n2(); ++j) {
654 for (int64 k = 0; k < result->n3(); ++k) {
655 for (int64 l = 0; l < result->n4(); ++l) {
656 switch (broadcast_from_dim) {
657 case 0:
658 (*result)(i, j, k, l) = array[i];
659 break;
660 case 1:
661 (*result)(i, j, k, l) = array[j];
662 break;
663 case 2:
664 (*result)(i, j, k, l) = array[k];
665 break;
666 case 3:
667 (*result)(i, j, k, l) = array[l];
668 break;
669 default:
670 break;
671 }
672 }
673 }
674 }
675 }
676 return result;
677 }
678
Reduce3DTo2D(const Array3D<float> & array,float init,absl::Span<const int64> dims,const std::function<float (float,float)> & reduce_function)679 /* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::Reduce3DTo2D(
680 const Array3D<float>& array, float init, absl::Span<const int64> dims,
681 const std::function<float(float, float)>& reduce_function) {
682 CHECK_EQ(dims.size(), 1);
683 int64 rows = dims[0] == 0 ? array.n2() : array.n1();
684 int64 cols = dims[0] == 2 ? array.n2() : array.n3();
685 auto result = absl::make_unique<Array2D<float>>(rows, cols);
686 result->Fill(init);
687 for (int i0 = 0; i0 < array.n1(); ++i0) {
688 for (int i1 = 0; i1 < array.n2(); ++i1) {
689 for (int i2 = 0; i2 < array.n3(); ++i2) {
690 int64 row = dims[0] == 0 ? i1 : i0;
691 int64 col = dims[0] == 2 ? i1 : i2;
692 (*result)(row, col) =
693 reduce_function((*result)(row, col), array(i0, i1, i2));
694 }
695 }
696 }
697 return result;
698 }
699
MapArray2D(const Array2D<float> & matrix,const std::function<float (float)> & map_function)700 /* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapArray2D(
701 const Array2D<float>& matrix,
702 const std::function<float(float)>& map_function) {
703 int64 rows = matrix.height();
704 int64 cols = matrix.width();
705 auto result = absl::make_unique<Array2D<float>>(rows, cols);
706 for (int64 i = 0; i < rows; ++i) {
707 for (int64 j = 0; j < cols; ++j) {
708 (*result)(i, j) = map_function(matrix(i, j));
709 }
710 }
711 return result;
712 }
713
MapArray2D(const Array2D<float> & lhs,const Array2D<float> & rhs,const std::function<float (float,float)> & map_function)714 /* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapArray2D(
715 const Array2D<float>& lhs, const Array2D<float>& rhs,
716 const std::function<float(float, float)>& map_function) {
717 CHECK_EQ(lhs.height(), rhs.height());
718 CHECK_EQ(lhs.width(), rhs.width());
719 int64 rows = lhs.height();
720 int64 cols = rhs.width();
721 auto result = absl::make_unique<Array2D<float>>(rows, cols);
722 for (int64 i = 0; i < rows; ++i) {
723 for (int64 j = 0; j < cols; ++j) {
724 (*result)(i, j) = map_function(lhs(i, j), rhs(i, j));
725 }
726 }
727 return result;
728 }
729
MapWithIndexArray2D(const Array2D<float> & matrix,const std::function<float (float,int64,int64)> & map_function)730 /* static */ std::unique_ptr<Array2D<float>> ReferenceUtil::MapWithIndexArray2D(
731 const Array2D<float>& matrix,
732 const std::function<float(float, int64, int64)>& map_function) {
733 int64 rows = matrix.height();
734 int64 cols = matrix.width();
735 auto result = absl::make_unique<Array2D<float>>(rows, cols);
736 for (int64 i = 0; i < rows; ++i) {
737 for (int64 j = 0; j < cols; ++j) {
738 (*result)(i, j) = map_function(matrix(i, j), i, j);
739 }
740 }
741 return result;
742 }
743
744 } // namespace xla
745