• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_GRID_SAMPLER_2D_GRAD_CPU_KERNEL_H_
17 #define MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_GRID_SAMPLER_2D_GRAD_CPU_KERNEL_H_
18 #include <map>
19 #include <vector>
20 #include <memory>
21 #include <algorithm>
22 #include <bitset>
23 #include <cmath>
24 #include <iostream>
25 #include <type_traits>
26 #include <array>
27 #include <functional>
28 #include <utility>
29 #include <tuple>
30 #include "plugin/device/cpu/kernel/cpu_kernel.h"
31 #include "plugin/factory/ms_factory.h"
32 #include "mindspore/core/ops/ops_func_impl/grid_sampler_2d_grad.h"
33 
34 namespace mindspore {
// Small named integer constants used throughout this header in place of
// magic numbers (dimension counts, gather scales, enable_if bounds, ...).
// Declared constexpr (rather than const) so they are guaranteed usable in
// constant expressions and carry no runtime storage semantics.
constexpr int64_t hZero = 0;
constexpr int64_t hOne = 1;
constexpr int64_t hTwo = 2;
constexpr int64_t hThree = 3;
constexpr int64_t hFour = 4;
constexpr int64_t hFive = 5;
constexpr int64_t hSix = 6;
constexpr int64_t hSeven = 7;
constexpr int64_t hEight = 8;
44 namespace kernel {
// Interpolation schemes supported by the sampler kernels in this header.
enum class GridSamplerInterpolation { Bilinear, Nearest };
// How grid coordinates that fall outside the input are handled.
enum class GridSamplerPadding { Zeros, Border, Reflection };
47 
// CPU kernel for the backward pass of GridSampler2D.
// Per GetOpSupport(): three floating-point tensor inputs (presumably grad,
// x and grid -- matching the *_shape_ members below; confirm against the op
// definition), plus scalar interpolation mode (int64), padding mode (int64)
// and align_corners (bool); produces two tensors of the same float type
// (dx and dgrid).
class GridSampler2DGradCpuKernelMod : public NativeCpuKernelMod {
 public:
  GridSampler2DGradCpuKernelMod() = default;
  ~GridSampler2DGradCpuKernelMod() override = default;

  // One-time setup from the kernel's input/output descriptors.
  bool Init(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &outputs) override;
  // Runs one execution of the kernel.
  bool Launch(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &workspace,
              const std::vector<KernelTensor *> &outputs) override;
  // Refreshes cached shapes/sizes when input dimensions change.
  int Resize(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &outputs) override;

  // Typed entry point; T is float or double, selected via dtype_.
  template <typename T>
  void LaunchKernel(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &outputs);

  // Supported signatures: all-float32 or all-float64 tensors, each with
  // int64/int64/bool scalar attributes and two outputs of the tensor dtype.
  std::vector<KernelAttr> GetOpSupport() override {
    static std::vector<KernelAttr> support_list = {KernelAttr()
                                                     .AddInputAttr(kNumberTypeFloat32)
                                                     .AddInputAttr(kNumberTypeFloat32)
                                                     .AddInputAttr(kNumberTypeFloat32)
                                                     .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
                                                     .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
                                                     .AddInputAttr(kObjectTypeNumber, kNumberTypeBool)
                                                     .AddOutputAttr(kNumberTypeFloat32)
                                                     .AddOutputAttr(kNumberTypeFloat32),
                                                   KernelAttr()
                                                     .AddInputAttr(kNumberTypeFloat64)
                                                     .AddInputAttr(kNumberTypeFloat64)
                                                     .AddInputAttr(kNumberTypeFloat64)
                                                     .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
                                                     .AddInputAttr(kObjectTypeNumber, kNumberTypeInt64)
                                                     .AddInputAttr(kObjectTypeNumber, kNumberTypeBool)
                                                     .AddOutputAttr(kNumberTypeFloat64)
                                                     .AddOutputAttr(kNumberTypeFloat64)};
    return support_list;
  }

 private:
  ShapeVector grad_shape_;   // incoming gradient tensor shape
  ShapeVector x_shape_;      // forward input x shape
  ShapeVector grid_shape_;   // sampling grid shape
  ShapeVector dx_shape_;     // output gradient-w.r.t.-x shape
  ShapeVector dgrid_shape_;  // output gradient-w.r.t.-grid shape
  int64_t interpolation_mode_;  // presumably indexes GridSamplerInterpolation -- confirm in Init
  int64_t padding_mode_;        // presumably indexes GridSamplerPadding -- confirm in Init
  bool align_corners_;
  size_t dx_size_;    // cached dx extent (element or byte count set elsewhere -- TODO confirm)
  size_t grid_size_;  // cached grid extent, same caveat as dx_size_
  TypeId dtype_{kTypeUnknown};  // tensor dtype used to pick LaunchKernel<T>

  // Worker performing the actual backward computation for one dtype.
  template <typename T>
  void ComputeTask(const std::vector<KernelTensor *> &inputs, const std::vector<KernelTensor *> &outputs);
};
99 
100 // *******************VEC256***********************
101 
102 namespace vec256 {
// Maps a byte width n to the signed integer type of exactly that width.
// The primary template is intentionally left undefined: requesting an
// unsupported width is a compile-time error.
template <size_t n>
struct int_of_size;

// Explicit specializations for the four supported widths, written out
// directly instead of via the former DEFINE_INT_OF_SIZE helper macro --
// same instantiations, but greppable and friendlier to tooling.
template <>
struct int_of_size<sizeof(int64_t)> {
  using type = int64_t;
};
template <>
struct int_of_size<sizeof(int32_t)> {
  using type = int32_t;
};
template <>
struct int_of_size<sizeof(int16_t)> {
  using type = int16_t;
};
template <>
struct int_of_size<sizeof(int8_t)> {
  using type = int8_t;
};

// Signed integer type with the same size as T (float -> int32_t,
// double -> int64_t); used for bit masks and index vectors below.
template <typename T>
using int_same_size_t = typename int_of_size<sizeof(T)>::type;
121 template <class T>
122 struct Vec256 {
123  private:
124   T values[32 / sizeof(T)];  // 32
125 
126  public:
127   using value_type = T;
sizeVec256128   static constexpr int size() { return 32 / sizeof(T); }
Vec256Vec256129   Vec256() : values{0} {}
Vec256Vec256130   explicit Vec256(T val) {
131     for (int i = 0; i != size(); i++) {
132       values[i] = val;
133     }
134   }
135 
136   template <typename... Args, typename = typename std::enable_if<(sizeof...(Args) == size()), void>::type>
Vec256Vec256137   explicit Vec256(Args... vals) {
138     values = {vals...};
139   }
140 
truncVec256141   Vec256<T> trunc() const { return map(std::trunc); }
LoadUVec256142   static Vec256<T> LoadU(const void *ptr) {
143     Vec256 vec;
144     size_t count = 32;
145     auto cp_ret = memcpy_s(static_cast<void *>(vec.values), count, ptr, count);
146     if (cp_ret != EOK) {
147       MS_LOG(ERROR) << "memcpy_s failed. errorno is: " << cp_ret;
148     }
149     return vec;
150   }
LoadUVec256151   static Vec256<T> LoadU(const void *ptr, int64_t count) {
152     Vec256 vec;
153     auto cp_ret = memcpy_s(static_cast<void *>(vec.values), count * sizeof(T), ptr, count * sizeof(T));
154     if (cp_ret != EOK) {
155       MS_LOG(ERROR) << "memcpy_s failed. errorno is: " << cp_ret;
156     }
157     return vec;
158   }
159   void store(void *ptr, int count = size()) const {
160     auto cp_ret = memcpy_s(ptr, count * sizeof(T), values, count * sizeof(T));
161     if (cp_ret != EOK) {
162       MS_LOG(ERROR) << "memcpy_s failed. errorno is: " << cp_ret;
163     }
164   }
165   const T &operator[](int idx) const { return values[idx]; }
166   T &operator[](int idx) { return values[idx]; }
zero_maskVec256167   int zero_mask() const {
168     int mask = 0;
169     for (int i = 0; i < size(); ++i) {
170       if (values[i] == static_cast<T>(0)) {
171         mask |= (1 << i);
172       }
173     }
174     return mask;
175   }
mapVec256176   Vec256<T> map(T (*f)(T)) const {
177     Vec256<T> ret;
178     for (size_t i = 0; i != IntToSize(size()); i++) {
179       ret[i] = f(values[i]);
180     }
181     return ret;
182   }
183   template <typename other_t_abs = T,
184             typename std::enable_if<!std::is_floating_point<other_t_abs>::value, int>::type = 0>
absVec256185   Vec256<T> abs() const {
186     static_assert(std::is_same<other_t_abs, T>::value, "other_t_abs must be T");
187     return map([](T x) -> T { return x < static_cast<T>(0) ? -x : x; });
188   }
189   template <typename float_t_abs = T,
190             typename std::enable_if<std::is_floating_point<float_t_abs>::value, int>::type = 0>
absVec256191   Vec256<T> abs() const {
192     static_assert(std::is_same<float_t_abs, T>::value, "float_t_abs must be T");
193     return map(std::abs);
194   }
blendvVec256195   static Vec256<T> blendv(const Vec256<T> &a, const Vec256<T> &b, const Vec256<T> &mask) {
196     Vec256 vec;
197     int_same_size_t<T> buffer[size()];
198     mask.store(buffer);
199     for (size_t i = 0; i < IntToSize(size()); i++) {
200       if (buffer[i] & 0x01) {
201         vec[i] = b[i];
202       } else {
203         vec[i] = a[i];
204       }
205     }
206     return vec;
207   }
208   static Vec256<T> arange(T base = static_cast<T>(0), T step = static_cast<T>(1)) {
209     Vec256 vec;
210     for (size_t i = 0; i < IntToSize(size()); i++) {
211       vec.values[i] = base + i * step;
212     }
213     return vec;
214   }
215   static Vec256<T> set(const Vec256<T> &a, const Vec256<T> &b, int64_t count = size()) {
216     Vec256 vec;
217     for (size_t i = 0; i < IntToSize(size()); i++) {
218       if (i < LongToSize(count)) {
219         vec[i] = b[i];
220       } else {
221         vec[i] = a[i];
222       }
223     }
224     return vec;
225   }
floorVec256226   Vec256<T> floor() const {
227     Vec256<T> ret;
228     for (size_t i = 0; i != IntToSize(size()); i++) {
229       ret[i] = std::floor(values[i]);
230     }
231     return ret;
232   }
233 
round_implVec256234   inline T round_impl(const T z) { return std::nearbyint(z); }
roundVec256235   Vec256<T> round() const {
236     Vec256<T> ret;
237     for (size_t i = 0; i != IntToSize(size()); i++) {
238       ret[i] = std::nearbyint(values[i]);
239     }
240     return ret;
241   }
242   Vec256<T> operator==(const Vec256<T> &other) const { return binary_pred(other, std::equal_to<T>()); }
243   Vec256<T> operator!=(const Vec256<T> &other) const { return binary_pred(other, std::not_equal_to<T>()); }
244   Vec256<T> operator>=(const Vec256<T> &other) const { return binary_pred(other, std::greater_equal<T>()); }
245   Vec256<T> operator<=(const Vec256<T> &other) const { return binary_pred(other, std::less_equal<T>()); }
246   Vec256<T> operator>(const Vec256<T> &other) const { return binary_pred(other, std::greater<T>()); }
247   Vec256<T> operator<(const Vec256<T> &other) const { return binary_pred(other, std::less<T>()); }
248 
249  private:
250   template <typename Op>
binary_predVec256251   inline Vec256<T> binary_pred(const Vec256<T> &other, Op op) const {
252     Vec256<T> vec;
253     for (int64_t i = 0; i != size(); i++) {
254       if (op(values[i], other.values[i])) {
255         auto ret = memset_s(static_cast<void *>(vec.values + i), sizeof(T), 0xFF, sizeof(T));
256         if (ret != 0) {
257           MS_LOG(ERROR) << "memset_s error, errorno(" << ret << ")";
258         }
259       } else {
260         auto ret = memset_s(static_cast<void *>(vec.values + i), sizeof(T), 0, sizeof(T));
261         if (ret != 0) {
262           MS_LOG(ERROR) << "memset_s error, errorno(" << ret << ")";
263         }
264       }
265     }
266     return vec;
267   }
268 };
269 
270 template <class T>
271 Vec256<T> inline operator+(const Vec256<T> &a, const Vec256<T> &b) {
272   Vec256<T> c = Vec256<T>();
273   for (int i = 0; i != Vec256<T>::size(); i++) {
274     c[i] = a[i] + b[i];
275   }
276   return c;
277 }
278 template <class T>
279 Vec256<T> operator-(const Vec256<T> &a, const Vec256<T> &b) {
280   Vec256<T> c = Vec256<T>();
281   for (int i = 0; i != Vec256<T>::size(); i++) {
282     c[i] = a[i] - b[i];
283   }
284   return c;
285 }
286 template <class T>
287 Vec256<T> inline operator*(const Vec256<T> &a, const Vec256<T> &b) {
288   Vec256<T> c = Vec256<T>();
289   for (int i = 0; i != Vec256<T>::size(); i++) {
290     c[i] = a[i] * b[i];
291   }
292   return c;
293 }
294 template <class T>
295 Vec256<T> inline operator/(const Vec256<T> &a, const Vec256<T> &b) {
296   Vec256<T> c = Vec256<T>();
297   for (int i = 0; i != Vec256<T>::size(); i++) {
298     c[i] = a[i] / b[i];
299   }
300   return c;
301 }
302 template <class T>
303 inline Vec256<T> operator&(const Vec256<T> &a, const Vec256<T> &b) {
304   return bitwise_binary_op(a, b, std::bit_and<int_same_size_t<T>>());
305 }
306 template <class T>
307 inline Vec256<T> operator^(const Vec256<T> &a, const Vec256<T> &b) {
308   return bitwise_binary_op(a, b, std::bit_xor<int_same_size_t<T>>());
309 }
310 #define RETURN_TYPE std::enable_if<scale == hOne || scale == hTwo || scale == hFour || scale == hEight, Vec256<T>>::type
311 template <int64_t scale = hOne, typename T = void>
GatherVec(T const * base_addr,const Vec256<int_same_size_t<T>> & vindex)312 typename RETURN_TYPE inline GatherVec(T const *base_addr, const Vec256<int_same_size_t<T>> &vindex) {
313   static constexpr int kSize = Vec256<T>::size();
314   int_same_size_t<T> index_arr[kSize];
315   vindex.store(static_cast<void *>(index_arr));
316   T buffer[kSize];
317   for (int64_t i = 0; i < kSize; i++) {
318     buffer[i] = base_addr[index_arr[i] * scale / sizeof(T)];
319   }
320   return Vec256<T>::LoadU(static_cast<void *>(buffer));
321 }
// Masked software gather: lanes whose mask has bit 0 set read
// base_addr[index * scale / sizeof(T)]; all other lanes keep the value from
// `src`.  On return *mask is reset to an all-zero vector (the code below
// overwrites it after use).
// NOTE(review): for active lanes the index is funneled through size_t, so a
// negative index would wrap to a huge offset -- presumably callers only set
// the mask for in-bounds indices; verify at the call sites.
template <int64_t scale = hOne, typename T = void>
typename RETURN_TYPE inline MaskGather(const Vec256<T> &src, T const *base_addr,
                                       const Vec256<int_same_size_t<T>> &vindex, Vec256<T> *mask) {
  static constexpr int kSize = Vec256<T>::size();
  T src_arr[kSize];
  int_same_size_t<T> mask_arr[kSize];  // use int type so we can logical and
  int_same_size_t<T> index_arr[kSize];
  src.store(static_cast<void *>(src_arr));
  mask->store(static_cast<void *>(mask_arr));
  vindex.store(static_cast<void *>(index_arr));
  T buffer[kSize];
  for (int64_t i = 0; i < kSize; i++) {
    if (mask_arr[i] & 0x01) {
      buffer[i] = base_addr[static_cast<size_t>(index_arr[i] * static_cast<size_t>(scale) / sizeof(T))];
    } else {
      buffer[i] = src_arr[i];
    }
  }
  *mask = Vec256<T>();
  return Vec256<T>::LoadU(static_cast<void *>(buffer));
}
343 template <typename T>
ConvertToIntOfSameSize(const Vec256<T> & src)344 inline Vec256<int_same_size_t<T>> ConvertToIntOfSameSize(const Vec256<T> &src) {
345   static constexpr int kSize = Vec256<T>::size();
346   T src_arr[kSize];
347   src.store(static_cast<void *>(src_arr));
348   int_same_size_t<T> buffer[kSize];
349   for (int64_t i = 0; i < kSize; i++) {
350     buffer[i] = static_cast<int_same_size_t<T>>(src_arr[i]);
351   }
352   return Vec256<int_same_size_t<T>>::LoadU(static_cast<void *>(buffer));
353 }
// Reinterprets the raw bytes of a Vec256<src_t> as a Vec256<dst_t> (a bit
// cast, not a value conversion): the payload is stored to a buffer and
// reloaded unchanged under the destination type.
template <typename dst_t, typename src_t>
struct CastImpl {
  static inline Vec256<dst_t> apply(const Vec256<src_t> &src) {
    src_t src_arr[Vec256<src_t>::size()];
    src.store(static_cast<void *>(src_arr));
    return Vec256<dst_t>::LoadU(static_cast<const void *>(src_arr));
  }
};
// Identity specialization: casting a vector to its own type is a plain copy
// with no store/reload round trip.
template <typename T>
struct CastImpl<T, T> {
  static inline Vec256<T> apply(const Vec256<T> &src) { return src; }
};
366 template <typename T>
367 inline typename std::enable_if<Vec256<T>::size() % hTwo == hZero, std::pair<Vec256<T>, Vec256<T>>>::type deinterleave2(
368   const Vec256<T> &a, const Vec256<T> &b) {
369   static constexpr int kSize = Vec256<T>::size();
370   static constexpr int half_size = kSize / 2;
371   T a_arr[kSize];
372   T b_arr[kSize];
373   T buffer1[kSize];
374   T buffer2[kSize];
375   a.store(static_cast<void *>(a_arr));
376   b.store(static_cast<void *>(b_arr));
377   for (int64_t i = 0; i < half_size; i++) {
378     buffer1[i] = a_arr[i * hTwo];
379     buffer1[half_size + i] = b_arr[i * hTwo];
380     buffer2[i] = a_arr[i * hTwo + hOne];
381     buffer2[half_size + i] = b_arr[i * hTwo + hOne];
382   }
383   return std::make_pair(Vec256<T>::LoadU(static_cast<void *>(buffer1)), Vec256<T>::LoadU(static_cast<void *>(buffer2)));
384 }
385 template <typename T>
386 inline typename std::enable_if<Vec256<T>::size() % hTwo == hZero, std::pair<Vec256<T>, Vec256<T>>>::type interleave2(
387   const Vec256<T> &a, const Vec256<T> &b) {
388   static constexpr int kSize = Vec256<T>::size();
389   static constexpr int half_size = kSize / 2;
390   T a_arr[kSize];
391   T b_arr[kSize];
392   T buffer1[kSize];
393   T buffer2[kSize];
394   a.store(static_cast<void *>(a_arr));
395   b.store(static_cast<void *>(b_arr));
396   for (int64_t i = 0; i < half_size; i++) {
397     buffer1[i * hTwo] = a_arr[i];
398     buffer1[i * hTwo + hOne] = b_arr[i];
399     buffer2[i * hTwo] = a_arr[half_size + i];
400     buffer2[i * hTwo + hOne] = b_arr[half_size + i];
401   }
402   return std::make_pair(Vec256<T>::LoadU(static_cast<void *>(buffer1)), Vec256<T>::LoadU(static_cast<void *>(buffer2)));
403 }
404 
405 template <typename dst_t, typename src_t>
406 inline Vec256<dst_t> cast(const Vec256<src_t> &src) {
407   return CastImpl<dst_t, src_t>::apply(src);
408 }
409 
// Returns true when val is NaN (always false for integral T, which uses the
// integral overload of std::isnan).  The former "T(val)" was a redundant
// cast of val to its own type and has been dropped; behaviour is unchanged.
template <typename T>
inline bool _isnan(T val) {
  return std::isnan(val);
}
414 template <class T>
415 Vec256<T> inline maximum(const Vec256<T> &a, const Vec256<T> &b) {
416   Vec256<T> c = Vec256<T>();
417   for (int i = 0; i != Vec256<T>::size(); i++) {
418     c[i] = (a[i] > b[i]) ? a[i] : b[i];
419     if (_isnan(a[i])) {
420       c[i] = a[i];
421     }
422   }
423   return c;
424 }
425 template <typename T>
426 Vec256<T> inline minimum(const Vec256<T> &a, const Vec256<T> &b) {
427   Vec256<T> c = Vec256<T>();
428   for (int i = 0; i != Vec256<T>::size(); i++) {
429     c[i] = (a[i] < b[i]) ? a[i] : b[i];
430     if (_isnan(a[i])) {
431       c[i] = a[i];
432     }
433   }
434   return c;
435 }
// Applies an integer bitwise operation to the bit patterns of two vectors'
// lanes: each lane pair is copied into locals, viewed as the same-size
// integer type, combined with `op`, and the results reloaded as T.
// NOTE(review): the reinterpret_cast reads T storage through an integer
// pointer (type punning); a memcpy-based reinterpretation would be the
// strictly conforming alternative -- behaviour left as-is here.
template <class T, typename Op>
static inline Vec256<T> bitwise_binary_op(const Vec256<T> &a, const Vec256<T> &b, Op op) {
  using iT = int_same_size_t<T>;
  iT buffer[Vec256<T>::size()];
  for (int i = 0; i != Vec256<T>::size(); i++) {
    auto a_val = a[i];
    auto b_val = b[i];
    iT *i_a_ptr = reinterpret_cast<iT *>(&a_val);
    iT *i_b_ptr = reinterpret_cast<iT *>(&b_val);
    buffer[i] = op(*i_a_ptr, *i_b_ptr);
  }
  return Vec256<T>::LoadU(buffer);
}
449 }  // namespace vec256
450 
// Non-owning N-dimensional accessor over a flat buffer: operator[] peels
// off the outermost dimension, yielding a rank-(N-1) accessor that shares
// the same sizes/strides arrays (offset by one).  The accessor does NOT
// own data, sizes or strides -- their lifetime is the caller's problem.
template <typename T, size_t N>
class TensorAcc {
 public:
  TensorAcc(T *data_, int64_t *sizes_, int64_t *strides_) : dataptr(data_), sizes(sizes_), strides(strides_) {}
  // Construction from a rank-4 accessor.  BUG FIX: the previous body built
  // a discarded temporary ("TensorAcc(tacc.dataptr, ...);") and left this
  // object's members uninitialized; the members are now copied directly in
  // the initializer list.
  TensorAcc(const TensorAcc<T, 4> &tacc) : dataptr(tacc.dataptr), sizes(tacc.sizes), strides(tacc.strides) {}
  // Stride (in elements) of dimension i.
  int64_t stride(int64_t i) const { return strides[i]; }
  // Extent of dimension i.
  int64_t size(int64_t i) const { return sizes[i]; }
  T *data() { return dataptr; }
  const T *data() const { return dataptr; }
  // Sub-accessor for slice i of the outermost dimension.
  TensorAcc<T, N - 1> operator[](const int64_t i) {
    return TensorAcc<T, N - 1>(this->dataptr + this->strides[0] * i, this->sizes + 1, this->strides + 1);
  }

  const TensorAcc<T, N - 1> operator[](int64_t i) const {
    return TensorAcc<T, N - 1>(this->dataptr + this->strides[0] * i, this->sizes + 1, this->strides + 1);
  }
  ~TensorAcc() {}

 private:
  T *dataptr;        // first element of this view
  int64_t *sizes;    // N extents, not owned
  int64_t *strides;  // N strides, not owned
};
474 
// Builds a rank-N accessor over data_ptr with contiguous row-major strides
// derived from sizess (innermost stride 1).
// NOTE(review): the `strid` and `sizes` arrays are heap-allocated here and
// never freed -- TensorAcc does not own them, so every call leaks
// 2 * N * sizeof(int64_t) bytes.  A safe fix needs an ownership redesign
// (sub-accessors share these arrays); flagged rather than changed here.
template <typename T, size_t N>
TensorAcc<T, N> accessor(T *data_ptr, std::vector<int64_t> sizess) {
  static_assert(N > 0, "accessor is used for indexing tensor, for scalars use *data_ptr<T>()");
  int64_t stride_tmp = 1;
  int64_t *strid = new int64_t[N];
  // Walk dimensions from innermost to outermost, accumulating the product
  // of extents as the stride of the next-outer dimension.
  for (int64_t i = N - 1; i > -1; --i) {
    strid[i] = stride_tmp;
    stride_tmp *= static_cast<int64_t>(sizess[i]);
  }
  int64_t *sizes = new int64_t[N];
  for (size_t k = 0; k < N; ++k) sizes[k] = sizess[k];
  return TensorAcc<T, N>(data_ptr, sizes, strid);
}
488 
489 bool GeometryIsContiguous(std::array<int64_t, hFour> sizes, std::array<int64_t, hFour> strides) {
490   int64_t dim = sizes.size();
491   int64_t expected_stride = 1;
492   bool contig_if_nonempty = true;
493   for (int64_t i = dim - 1; i >= 0; i--) {
494     if (sizes[LongToSize(i)] == 0) {
495       return true;
496     }
497     if (contig_if_nonempty) {
498       if (sizes[LongToSize(i)] != 1 && strides[LongToSize(i)] != expected_stride) {
499         contig_if_nonempty = false;
500       }
501       expected_stride *= sizes[i];
502     }
503   }
504   return contig_if_nonempty;
505 }
506 
507 template <typename T, bool align_corners>
508 struct ComputeLocationBase;
509 
510 template <typename T>
511 struct ComputeLocationBase<T, true> {
512   using Vec = vec256::Vec256<T>;
513   const T max_val;
514   const T scaling_factor;
515   const T low;
516   const T twice_span;
517   const bool empty;
518 
519   explicit ComputeLocationBase(int64_t size)
520       : max_val(static_cast<T>(size - 1)),
521         scaling_factor(static_cast<T>(size - hOne) / hTwo),
522         low(static_cast<T>(0)),
523         twice_span(static_cast<T>(size - hOne) * hTwo),
524         empty(size <= 0) {}
525 
526   inline Vec unnormalize(const Vec &in) const { return (in + Vec(1)) * Vec(scaling_factor); }
527 
528   inline Vec clip_coordinates(const Vec &in) const { return minimum(Vec(max_val), maximum(in, Vec(0))); }
529   inline std::pair<Vec, Vec> clip_coordinates_get_grad(const Vec &in) const {
530     using int_t = vec256::int_same_size_t<T>;
531     auto bounded_lo = maximum(in, Vec(0));
532     auto in_bound_lo = vec256::cast<T>(cast<int_t>(bounded_lo) != vec256::cast<int_t>(Vec(0)));
533     auto res = minimum(bounded_lo, Vec(max_val));
534     auto in_bound_hi = vec256::cast<T>(cast<int_t>(res) != vec256::cast<int_t>(Vec(max_val)));
535     return std::make_pair(res, in_bound_lo & in_bound_hi);
536   }
537 
538   inline vec256::Vec256<T> reflect_coordinates(const vec256::Vec256<T> &in) const {
539     if (empty) {
540       return Vec(0);
541     }
542     Vec twice_span_vec(twice_span);
543     auto abs_in = in.abs();
544     auto fdouble_flips = abs_in / twice_span_vec;
545     auto double_flips = fdouble_flips.trunc();
546     auto extra = abs_in - double_flips * twice_span_vec;
547     return minimum(extra, twice_span_vec - extra);
548   }
549 
550   inline std::pair<vec256::Vec256<T>, vec256::Vec256<T>> reflect_coordinates_get_grad(
551     const vec256::Vec256<T> &in) const {
552     if (empty) {
553       return std::make_pair(Vec(0), Vec(0));
554     }
555     Vec twice_span_vec(twice_span);
556     auto neg_in = in < Vec(0);
557     auto abs_in = in.abs();
558     auto fdouble_flips = abs_in / twice_span_vec;
559     auto double_flips = fdouble_flips.trunc();
560 
561     auto extra = abs_in - double_flips * twice_span_vec;
562     auto reflected_extra = twice_span_vec - extra;
563     auto one_more_flip = extra > reflected_extra;
564 
565     return std::make_pair(Vec::blendv(extra, reflected_extra, one_more_flip),
566                           Vec::blendv(Vec(1), Vec(-1), one_more_flip ^ neg_in));
567   }
568 };
569 
// align_corners == false: -1 and +1 refer to the outer EDGES of the corner
// pixels, so the scale is size/2 and coordinates are shifted by -0.5.
template <typename T>
struct ComputeLocationBase<T, false> {
  using Vec = vec256::Vec256<T>;
  const T max_val;         // size - 1, largest valid pixel coordinate
  const T scaling_factor;  // size / 2, the [-1,1] -> pixel scale
  const T low;             // -0.5, coordinate of the lower pixel edge
  const T twice_span;      // 2 * size, the reflection period
  const bool empty;  // only used when align_corners=True

  explicit ComputeLocationBase(int64_t size)
      : max_val(static_cast<T>(size - 1)),
        scaling_factor(static_cast<T>(size) / 2),
        low(static_cast<T>(-0.5)),
        twice_span(static_cast<T>(size) * 2),
        empty(size <= 0) {}

  // Maps [-1, 1] onto [-0.5, size - 0.5].
  inline Vec unnormalize(const Vec &in) const { return (in + Vec(1)) * Vec(scaling_factor) - Vec(0.5); }

  // Clamps coordinates into [0, max_val].
  inline Vec clip_coordinates(const Vec &in) const { return minimum(Vec(max_val), maximum(in, Vec(0))); }
  // Clamp plus a mask acting as the clamp's gradient: lanes clamped at
  // either bound produce zero, in-range lanes an all-ones pattern.
  inline std::pair<Vec, Vec> clip_coordinates_get_grad(const Vec &in) const {
    using int_t = vec256::int_same_size_t<T>;
    auto bounded_lo = maximum(in, Vec(0));
    auto in_bound_lo = vec256::cast<T>(vec256::cast<int_t>(bounded_lo) != vec256::cast<int_t>(Vec(0)));
    auto res = minimum(bounded_lo, Vec(max_val));
    auto in_bound_hi = vec256::cast<T>(vec256::cast<int_t>(res) != vec256::cast<int_t>(Vec(max_val)));
    return std::make_pair(res, in_bound_lo & in_bound_hi);
  }

  // Folds out-of-range coordinates back into range by reflecting off the
  // pixel edges (the fold is performed relative to `low`).
  inline Vec reflect_coordinates(const Vec &in) const {
    Vec twice_span_vec(twice_span), low_vec(low);
    auto abs_in = (in - low_vec).abs();
    auto fdouble_flips = abs_in / twice_span_vec;
    auto double_flips = fdouble_flips.trunc();
    auto extra = abs_in - double_flips * twice_span_vec;
    return minimum(extra, twice_span_vec - extra) + low_vec;
  }

  // Reflection plus its derivative, +1 or -1 per lane depending on whether
  // the coordinate was flipped an odd number of times (including the sign
  // flip for inputs below `low`).
  inline std::pair<vec256::Vec256<T>, vec256::Vec256<T>> reflect_coordinates_get_grad(
    const vec256::Vec256<T> &in) const {
    Vec twice_span_vec(twice_span), low_vec(low);
    Vec in_minus_low = in - low_vec;
    auto neg_in = in_minus_low < Vec(0);
    auto abs_in = in_minus_low.abs();
    auto fdouble_flips = abs_in / twice_span_vec;
    auto double_flips = fdouble_flips.trunc();

    auto extra = abs_in - double_flips * twice_span_vec;
    auto reflected_extra = twice_span_vec - extra;
    auto one_more_flip = extra > reflected_extra;
    auto boolex = one_more_flip ^ neg_in;
    return std::make_pair(Vec::blendv(extra, reflected_extra, one_more_flip) + low_vec,
                          Vec::blendv(Vec(1), Vec(-1), boolex));
  }
};
624 
625 template <typename T, GridSamplerPadding padding, bool align_corners>
626 struct ComputeLocation;
627 
628 template <typename T, bool align_corners>
629 struct ComputeLocation<T, GridSamplerPadding::Zeros, align_corners> : ComputeLocationBase<T, align_corners> {
630   using Vec = vec256::Vec256<T>;
631   using ComputeLocationBase<T, align_corners>::unnormalize;
632   using ComputeLocationBase<T, align_corners>::scaling_factor;
633 
634   using ComputeLocationBase<T, align_corners>::ComputeLocationBase;
635 
636   inline Vec apply(const Vec &in) const { return unnormalize(in); }
637 
638   inline std::pair<Vec, Vec> ApplyGetGrad(const Vec &in) const {
639     return std::make_pair(unnormalize(in), Vec(scaling_factor));
640   }
641 };
642 
643 template <typename T, bool align_corners>
644 struct ComputeLocation<T, GridSamplerPadding::Border, align_corners> : ComputeLocationBase<T, align_corners> {
645   using Vec = vec256::Vec256<T>;
646   using ComputeLocationBase<T, align_corners>::unnormalize;
647   using ComputeLocationBase<T, align_corners>::clip_coordinates;
648   using ComputeLocationBase<T, align_corners>::clip_coordinates_get_grad;
649   using ComputeLocationBase<T, align_corners>::scaling_factor;
650 
651   using ComputeLocationBase<T, align_corners>::ComputeLocationBase;
652 
653   inline Vec apply(const Vec &in) const { return clip_coordinates(unnormalize(in)); }
654 
655   inline std::pair<Vec, Vec> ApplyGetGrad(const Vec &in) const {
656     Vec res, grad_clip;
657     std::tie(res, grad_clip) = clip_coordinates_get_grad(unnormalize(in));
658     return std::make_pair(res, grad_clip & Vec(scaling_factor));
659   }
660 };
661 
662 template <typename T, bool align_corners>
663 struct ComputeLocation<T, GridSamplerPadding::Reflection, align_corners> : ComputeLocationBase<T, align_corners> {
664   using Vec = vec256::Vec256<T>;
665   using ComputeLocationBase<T, align_corners>::unnormalize;
666   using ComputeLocationBase<T, align_corners>::clip_coordinates;
667   using ComputeLocationBase<T, align_corners>::clip_coordinates_get_grad;
668   using ComputeLocationBase<T, align_corners>::reflect_coordinates;
669   using ComputeLocationBase<T, align_corners>::reflect_coordinates_get_grad;
670   using ComputeLocationBase<T, align_corners>::scaling_factor;
671 
672   using ComputeLocationBase<T, align_corners>::ComputeLocationBase;
673 
674   inline Vec apply(const Vec &in) const {
675     auto res = reflect_coordinates(unnormalize(in));
676     res = clip_coordinates(res);
677     return res;
678   }
679 
680   inline std::pair<Vec, Vec> ApplyGetGrad(const Vec &in) const {
681     Vec res, grad_refl, grad_clip, grad(scaling_factor);
682     std::tie(res, grad_refl) = reflect_coordinates_get_grad(unnormalize(in));
683     grad = grad_refl * grad;
684     std::tie(res, grad_clip) = clip_coordinates_get_grad(res);
685     grad = grad_clip & grad;
686     return std::make_pair(res, grad);
687   }
688 };
689 
690 template <typename T>
691 static inline void MaskScatterAdd(const T *src, T *base_addr, const vec256::int_same_size_t<T> *offsets,
692                                   const vec256::int_same_size_t<T> *mask, int64_t len) {
693   for (int64_t i = 0; i < len; i++) {
694     if (mask[i] & 0x01) {
695       base_addr[offsets[i]] += src[i];
696     }
697   }
698 }
699 
// Grid-sampler backward worker, specialized on spatial rank, interpolation
// and padding mode; the 2-D bilinear specialization follows below.
template <typename T, int spatial_dim, GridSamplerInterpolation interp, GridSamplerPadding padding, bool align_corners>
struct ApplyGridSample2D;
702 
703 template <typename T, GridSamplerPadding padding, bool align_corners>
704 struct ApplyGridSample2D<T, hTwo, GridSamplerInterpolation::Bilinear, padding, align_corners> {
705   using Vec = vec256::Vec256<T>;
706   using integer_t = vec256::int_same_size_t<T>;
707   using iVec = vec256::Vec256<integer_t>;
708 
709   const int64_t InpH;
710   const int64_t InpW;
711   const int64_t InpSH;
712   const int64_t InpSW;
713   const int64_t C;
714   const int64_t InpSC;
715   const ComputeLocation<T, padding, align_corners> ComputeH;
716   const ComputeLocation<T, padding, align_corners> ComputeW;
717   const bool MustInBound = padding != GridSamplerPadding::Zeros;
718 
719   explicit ApplyGridSample2D(const TensorAcc<T, 4> &input)
720       : InpH(input.size(2)),
721         InpW(input.size(3)),
722         InpSH(input.stride(2)),
723         InpSW(input.stride(3)),
724         C(input.size(1)),
725         InpSC(input.stride(1)),
726         ComputeH(input.size(2)),
727         ComputeW(input.size(3)) {}
728   inline std::tuple<Vec, Vec, Vec, Vec, Vec, Vec, Vec, Vec, Vec, Vec, Vec, Vec, iVec, iVec> ComputeInterpParams(
729     const Vec &X, const Vec &Y) const {
730     auto XW = X.floor();
731     auto YN = Y.floor();
732     auto W = X - XW;
733     auto E = Vec(1) - W;
734     auto N = Y - YN;
735     auto S = Vec(1) - N;
736     auto NW = S * E;
737     auto NE = S * W;
738     auto SW = N * E;
739     auto SE = N * W;
740     auto IXW = vec256::ConvertToIntOfSameSize(XW);
741     auto IYN = vec256::ConvertToIntOfSameSize(YN);
742     auto IXE = IXW + iVec(1);
743     auto IYS = IYN + iVec(1);
744     auto WMask = MustInBound ? iVec(-1) : (IXW > iVec(-1)) & (IXW < iVec(InpW));
745     auto NMask = MustInBound ? iVec(-1) : (IYN > iVec(-1)) & (IYN < iVec(InpH));
746     auto EMask = MustInBound ? (IXE < iVec(InpW)) : (IXE > iVec(-1)) & (IXE < iVec(InpW));
747     auto SMask = MustInBound ? (IYS < iVec(InpH)) : (IYS > iVec(-1)) & (IYS < iVec(InpH));
748     auto NWMask = vec256::cast<T>(MustInBound ? iVec(-1) : (WMask & NMask));
749     auto NEMask = vec256::cast<T>(EMask & NMask);
750     auto SWMask = vec256::cast<T>(WMask & SMask);
751     auto SEMask = vec256::cast<T>(EMask & SMask);
752 
753     return std::make_tuple(N, S, W, E, NW, NE, SW, SE, NWMask, NEMask, SWMask, SEMask, IYN, IXW);
754   }
755 
  // Backward pass for one vector pack of sample points (bilinear mode):
  // scatter-adds the corner-weighted output gradient into *GInpSlice and
  // writes the interleaved (x, y) grid gradient into *GGridSlice.
  //
  // GInpSlice  - gradient w.r.t. the input, per-channel planes, contiguous.
  // GGridSlice - gradient w.r.t. the grid, interleaved (x, y) pairs.
  // GOutSlice  - incoming output gradient, per-channel planes.
  // InpSlice   - forward input values (needed only for the grid gradient).
  // offset     - linear spatial offset of this pack in the output plane.
  // grid_x/y   - raw grid coordinates for up to Vec::size() sample points.
  // len        - number of valid lanes in this pack (may be < Vec::size()).
  inline void Backward(TensorAcc<T, 3> *GInpSlice, TensorAcc<T, 3> *GGridSlice, const TensorAcc<T, 3> &GOutSlice,
                       const TensorAcc<T, 3> &InpSlice, int64_t offset, const Vec &grid_x, const Vec &grid_y,
                       int64_t len) const {
    // Unnormalize the coordinates; the second tuple element is the chain-rule
    // multiplier d(unnormalized coord)/d(grid coord).
    Vec X, Y, GxMult, GyMult;
    std::tie(X, GxMult) = ComputeW.ApplyGetGrad(grid_x);
    std::tie(Y, GyMult) = ComputeH.ApplyGetGrad(grid_y);

    iVec IYN, IXW;
    Vec N, S, W, E, NW, NE, SW, SE, NWMask, NEMask, SWMask, SEMask;

    // Fractional weights, corner weights, in-bounds masks, and the integer
    // north-west corner indices for every lane.
    std::tie(N, S, W, E, NW, NE, SW, SE, NWMask, NEMask, SWMask, SEMask, IYN, IXW) = ComputeInterpParams(X, Y);

    // Strided element offsets of the four corners inside the (possibly
    // non-contiguous) forward-input slice.
    auto INWOffset = IYN * iVec(InpSH) + IXW * iVec(InpSW);
    auto INEOffset = iVec(InpSW) + INWOffset;
    auto ISWOffset = iVec(InpSH) + INWOffset;
    auto ISEOffset = iVec(InpSW) + ISWOffset;

    // Offsets of the same corners in the contiguous input-gradient buffer.
    auto IGInpNWOffset = IYN * iVec(InpW) + IXW;
    auto IGInpNEOffset = IGInpNWOffset + iVec(1);
    auto IGInpSWOffset = IGInpNWOffset + iVec(InpW);
    auto IGInpSEOffset = IGInpSWOffset + iVec(1);
    static constexpr int kSize = iVec::size();
    integer_t IGInpNWOffsetArr[kSize];
    integer_t IGInpNEOffsetArr[kSize];
    integer_t IGInpSWOffsetArr[kSize];
    integer_t IGInpSEOffsetArr[kSize];
    IGInpNWOffset.store(IGInpNWOffsetArr);
    IGInpNEOffset.store(IGInpNEOffsetArr);
    IGInpSWOffset.store(IGInpSWOffsetArr);
    IGInpSEOffset.store(IGInpSEOffsetArr);

    // Spill masks to scalar arrays for the scalar MaskScatterAdd loop below.
    integer_t INWMaskArr[kSize], INEMaskArr[kSize], ISWMaskArr[kSize], ISEMaskArr[kSize];
    NWMask.store(INWMaskArr);
    NEMask.store(INEMaskArr);
    SWMask.store(ISWMaskArr);
    SEMask.store(ISEMaskArr);

    // Scratch buffer reused for each corner's weighted gradient contribution.
    T GInpCornerArr[Vec::size()];

    auto GX = Vec(hZero), GY = Vec(hZero);
    int64_t i = 0;
    while (i < C) {  // accumulate over channels
      auto InpSliceCPtr = InpSlice[i].data();
      auto GInpSliceCPtr = (*GInpSlice)[i].data();
      auto GOut = Vec::LoadU(offset + GOutSlice[i].data(), len);

      // d(output)/d(input): every in-bounds corner receives weight * GOut.
      (NW * GOut).store(GInpCornerArr);
      MaskScatterAdd(GInpCornerArr, GInpSliceCPtr, IGInpNWOffsetArr, INWMaskArr, len);
      (NE * GOut).store(GInpCornerArr);
      MaskScatterAdd(GInpCornerArr, GInpSliceCPtr, IGInpNEOffsetArr, INEMaskArr, len);
      (SW * GOut).store(GInpCornerArr);
      MaskScatterAdd(GInpCornerArr, GInpSliceCPtr, IGInpSWOffsetArr, ISWMaskArr, len);
      (SE * GOut).store(GInpCornerArr);
      MaskScatterAdd(GInpCornerArr, GInpSliceCPtr, IGInpSEOffsetArr, ISEMaskArr, len);
      // MaskGather takes the mask by pointer (and presumably may modify it),
      // so gather through per-iteration copies to keep the masks intact.
      Vec NWMaskCopy = NWMask;
      Vec NEMaskCopy = NEMask;
      Vec SWMaskCopy = SWMask;
      Vec SEMaskCopy = SEMask;
      auto NWVal = vec256::MaskGather<sizeof(T)>(Vec(0), InpSliceCPtr, INWOffset, &NWMaskCopy);
      auto NEVal = vec256::MaskGather<sizeof(T)>(Vec(0), InpSliceCPtr, INEOffset, &NEMaskCopy);
      auto SWVal = vec256::MaskGather<sizeof(T)>(Vec(0), InpSliceCPtr, ISWOffset, &SWMaskCopy);
      auto SEVal = vec256::MaskGather<sizeof(T)>(Vec(0), InpSliceCPtr, ISEOffset, &SEMaskCopy);

      // d(output)/d(grid): differences of corner values weighted by the
      // opposite-axis fraction (the bilinear-gradient formula).
      GX = GX + (S * (NEVal - NWVal) + N * (SEVal - SWVal)) * GOut;
      GY = GY + (E * (SWVal - NWVal) + W * (SEVal - NEVal)) * GOut;
      ++i;
    }

    // Chain rule through the coordinate unnormalization.
    GX = GX * GxMult;
    GY = GY * GyMult;

    // Re-interleave (GX, GY) into (x, y) pairs and store both halves; the
    // last pack may hold fewer than a full vector of sample points.
    constexpr int64_t step = Vec::size();
    auto InterleavedGGrid = interleave2(GX, GY);
    auto GGridPtr = (*GGridSlice)[0].data() + offset * 2;
    std::get<0>(InterleavedGGrid).store(GGridPtr, std::min(len * hTwo, step));
    std::get<1>(InterleavedGGrid).store(GGridPtr + step, std::max(static_cast<int64_t>(0), len * hTwo - step));
  }
833 };
834 
835 template <typename T, GridSamplerPadding padding, bool align_corners>
836 struct ApplyGridSample2D<T, hTwo, GridSamplerInterpolation::Nearest, padding, align_corners> {
837   using Vec = vec256::Vec256<T>;
838   using integer_t = vec256::int_same_size_t<T>;
839   using iVec = vec256::Vec256<integer_t>;
840 
841   const int64_t InpH;
842   const int64_t InpW;
843   const int64_t InpSH;
844   const int64_t InpSW;
845   const int64_t C;
846   const int64_t InpSC;
847   const ComputeLocation<T, padding, align_corners> ComputeH;
848   const ComputeLocation<T, padding, align_corners> ComputeW;
849   const bool MustInBound = padding != GridSamplerPadding::Zeros;
850 
851   explicit ApplyGridSample2D(const TensorAcc<T, 4> &input)
852       : InpH(input.size(2)),
853         InpW(input.size(3)),
854         InpSH(input.stride(2)),
855         InpSW(input.stride(3)),
856         C(input.size(1)),
857         InpSC(input.stride(1)),
858         ComputeH(input.size(2)),
859         ComputeW(input.size(3)) {}
860 
861   inline void Backward(TensorAcc<T, 3> *GInpSlice, TensorAcc<T, 3> *GGridSlice, const TensorAcc<T, 3> &GOutSlice,
862                        const TensorAcc<T, 3> &, int64_t offset, const Vec &grid_x, const Vec &grid_y,
863                        int64_t len) const {
864     auto X = ComputeW.apply(grid_x);
865     auto XNearest = X.round();
866     auto IXNearest = vec256::ConvertToIntOfSameSize<T>(XNearest);
867     auto Y = ComputeH.apply(grid_y);
868     auto YNearest = Y.round();
869     auto IYNearest = vec256::ConvertToIntOfSameSize<T>(YNearest);
870 
871     auto IMask = MustInBound ? iVec(-1)
872                              : (IXNearest > iVec(-1)) & (IXNearest < iVec(InpW)) & (IYNearest > iVec(-1)) &
873                                  (IYNearest < iVec(InpH));
874 
875     auto IGInpOffset = IXNearest + iVec(InpW) * IYNearest;  // gInp is contiguous
876     static constexpr int kSize = iVec::size();
877     integer_t MaskArr[kSize], GInpOffsetArr[kSize];
878     IMask.store(MaskArr);
879     IGInpOffset.store(GInpOffsetArr);
880 
881     int64_t i = 0;
882     while (i < C) {
883       MaskScatterAdd(GOutSlice[i].data() + offset, (*GInpSlice)[i].data(), GInpOffsetArr, MaskArr, len);
884       ++i;
885     }
886     auto GGridPtr = (*GGridSlice)[0].data() + offset * 2;
887     auto ret = memset_s(static_cast<void *>(GGridPtr), sizeof(T) * len * hTwo, 0, sizeof(T) * len * hTwo);
888     if (ret != 0) {
889       MS_LOG(ERROR) << "memset_s error, errorno(" << ret << ")";
890     }
891   }
892 };
893 
894 template <typename T, typename ApplyFn>
895 static inline void GridSampler2DGridSliceIterator(const TensorAcc<T, 3> &GridSlice, const ApplyFn &ApplyFN) {
896   int64_t OutH = GridSlice.size(0);
897   int64_t OutW = GridSlice.size(1);
898   int64_t GridSH = GridSlice.stride(0);
899   int64_t GridSW = GridSlice.stride(1);
900   int64_t GridSCoor = GridSlice.stride(2);
901   auto GridPtr = GridSlice.data();
902 
903   using Vec = vec256::Vec256<T>;
904   using iVec = vec256::Vec256<vec256::int_same_size_t<T>>;
905   constexpr int64_t step = Vec::size();
906 
907   if (GeometryIsContiguous({OutH, OutW, 2}, {GridSH, GridSW, GridSCoor})) {
908     int64_t tSize, spatial_offset;
909     tSize = OutH * OutW;
910     spatial_offset = 0;
911     while (spatial_offset < tSize) {
912       int64_t grid_offset, len;
913       grid_offset = spatial_offset * hTwo;
914       len = std::min(step, tSize - spatial_offset);
915       auto vec1 = Vec::LoadU(GridPtr + grid_offset, std::min(step, len * 2));
916       auto vec2 = Vec::LoadU(GridPtr + grid_offset + step, std::max(static_cast<int64_t>(0), len * 2 - step));
917       auto PairVecXY = deinterleave2(vec1, vec2);
918 
919       auto Y = std::get<1>(PairVecXY);
920       auto X = std::get<0>(PairVecXY);
921       if (len < step) {
922         X = Vec::set(Vec(0), X, len);
923         Y = Vec::set(Vec(0), Y, len);
924       }
925 
926       ApplyFN(X, Y, spatial_offset, len);
927       spatial_offset += step;
928     }
929   } else if (GridSW == hOne || OutW == hOne) {
930     auto LineFn = [&ApplyFN, &step](const T *grid_ptr_x, const T *grid_ptr_y, int64_t out_base_offset, int64_t tSize) {
931       int64_t i = 0;
932       while (i < tSize) {
933         int64_t len;
934         len = std::min(step, tSize - i);
935         auto X = Vec::LoadU(grid_ptr_x + i, len);
936         auto Y = Vec::LoadU(grid_ptr_y + i, len);
937         // make sure that X and Y are valid grid sample locations
938         if (len < step) {
939           X = Vec::set(Vec(0), X, len);
940           Y = Vec::set(Vec(0), Y, len);
941         }
942         ApplyFN(X, Y, out_base_offset + i, len);
943         i += step;
944       }
945     };
946 
947     if (GeometryIsContiguous({OutH, OutW}, {GridSH, GridSW})) {
948       LineFn(GridPtr, GridPtr + GridSCoor, 0, OutH * OutW);
949     } else {
950       auto grid_ptr_NH = GridPtr;
951       int64_t h = 0;
952       while (h < OutH) {
953         LineFn(grid_ptr_NH, grid_ptr_NH + GridSCoor, h * OutW, OutW);
954         grid_ptr_NH += GridSH;
955         h++;
956       }
957     }
958   } else {
959     auto spatial_offset = 0;
960     auto i_offsets_delta = iVec(GridSW * step);
961     int64_t h = 0;
962     while (h < OutH) {
963       auto grid_ptr_x = h * GridSH + GridPtr;
964       auto grid_ptr_y = GridSCoor + grid_ptr_x;
965       auto i_offsets = iVec::arange(0, LongToInt(GridSW));
966       int64_t w = 0;
967       while (w < OutW) {
968         auto len = std::min(step, OutW - w);
969         if (len < step) {
970           i_offsets = iVec::set(iVec(0), i_offsets, len);
971         }
972         ApplyFN(vec256::GatherVec<sizeof(T)>(grid_ptr_x, i_offsets),
973                 vec256::GatherVec<sizeof(T)>(grid_ptr_y, i_offsets), spatial_offset, len);
974 
975         i_offsets = i_offsets + i_offsets_delta;
976         spatial_offset += len;
977         w += step;
978       }
979       h++;
980     }
981   }
982 }
983 }  // namespace kernel
984 }  // namespace mindspore
985 
986 #endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_CPU_GRID_SAMPLER_2D_GRAD_CPU_KERNEL_H_
987