• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "minddata/dataset/kernels/image/lite_cv/lite_mat.h"
17 
18 #include <limits>
19 #include <algorithm>
20 #include <cmath>
21 #ifdef ENABLE_NEON
22 #include <arm_neon.h>
23 #endif
24 
25 namespace mindspore {
26 namespace dataset {
27 
LiteMat()28 LiteMat::LiteMat() {
29   data_ptr_ = nullptr;
30   elem_size_ = 0;
31   width_ = 0;
32   height_ = 0;
33   channel_ = 0;
34   c_step_ = 0;
35   dims_ = 0;
36   size_ = 0;
37   data_type_ = LDataType::UINT8;
38   ref_count_ = nullptr;
39   setSteps(0, 0, 0);
40   release_flag = false;
41 }
42 
LiteMat(int width,LDataType data_type)43 LiteMat::LiteMat(int width, LDataType data_type) {
44   data_ptr_ = nullptr;
45   elem_size_ = 0;
46   width_ = 0;
47   height_ = 0;
48   channel_ = 0;
49   c_step_ = 0;
50   dims_ = 0;
51   data_type_ = LDataType::UINT8;
52   ref_count_ = nullptr;
53   size_ = 0;
54   setSteps(0, 0, 0);
55   release_flag = false;
56   Init(width, data_type);
57 }
58 
LiteMat(int width,int height,LDataType data_type)59 LiteMat::LiteMat(int width, int height, LDataType data_type) {
60   data_ptr_ = nullptr;
61   elem_size_ = 0;
62   width_ = 0;
63   height_ = 0;
64   channel_ = 0;
65   c_step_ = 0;
66   dims_ = 0;
67   data_type_ = LDataType::UINT8;
68   ref_count_ = nullptr;
69   size_ = 0;
70   setSteps(0, 0, 0);
71   release_flag = false;
72   Init(width, height, data_type);
73 }
74 
LiteMat(int width,int height,void * p_data,LDataType data_type)75 LiteMat::LiteMat(int width, int height, void *p_data, LDataType data_type) {
76   data_ptr_ = nullptr;
77   elem_size_ = 0;
78   width_ = 0;
79   height_ = 0;
80   channel_ = 0;
81   c_step_ = 0;
82   dims_ = 0;
83   data_type_ = LDataType::UINT8;
84   ref_count_ = nullptr;
85   size_ = 0;
86   setSteps(0, 0, 0);
87   release_flag = false;
88   Init(width, height, p_data, data_type);
89 }
90 
LiteMat(int width,int height,int channel,LDataType data_type)91 LiteMat::LiteMat(int width, int height, int channel, LDataType data_type) {
92   data_ptr_ = nullptr;
93   elem_size_ = 0;
94   width_ = 0;
95   height_ = 0;
96   channel_ = 0;
97   c_step_ = 0;
98   dims_ = 0;
99   data_type_ = LDataType::UINT8;
100   ref_count_ = nullptr;
101   size_ = 0;
102   setSteps(0, 0, 0);
103   release_flag = false;
104   Init(width, height, channel, data_type);
105 }
106 
LiteMat(int width,int height,int channel,void * p_data,LDataType data_type)107 LiteMat::LiteMat(int width, int height, int channel, void *p_data, LDataType data_type) {
108   data_ptr_ = nullptr;
109   elem_size_ = 0;
110   width_ = 0;
111   height_ = 0;
112   channel_ = 0;
113   c_step_ = 0;
114   dims_ = 0;
115   data_type_ = LDataType::UINT8;
116   ref_count_ = nullptr;
117   size_ = 0;
118   setSteps(0, 0, 0);
119   release_flag = false;
120   Init(width, height, channel, p_data, data_type);
121 }
122 
~LiteMat()123 LiteMat::~LiteMat() { Release(); }
124 
addRef(int * p,int value)125 int LiteMat::addRef(int *p, int value) {
126   int v = *p;
127   *p += value;
128   return v;
129 }
130 
// Copy constructor: makes a shallow copy that points at the same buffer as `m`.
// When `m` carries a reference counter, the counter is incremented.
LiteMat::LiteMat(const LiteMat &m) {
  // Buffer handle and ownership bookkeeping.
  data_ptr_ = m.data_ptr_;
  ref_count_ = m.ref_count_;
  release_flag = m.release_flag;

  // Geometry and element description.
  data_type_ = m.data_type_;
  elem_size_ = m.elem_size_;
  dims_ = m.dims_;
  width_ = m.width_;
  height_ = m.height_;
  channel_ = m.channel_;

  // Derived sizes and strides.
  c_step_ = m.c_step_;
  size_ = m.size_;
  setSteps(m.steps_[0], m.steps_[1], m.steps_[2]);

  if (ref_count_ != nullptr) {
    addRef(ref_count_, 1);
  }
}
148 
setSteps(int c0,int c1,int c2)149 void LiteMat::setSteps(int c0, int c1, int c2) {
150   steps_[0] = c0;
151   steps_[1] = c1;
152   steps_[2] = c2;
153 }
154 
// Copy assignment: drops the buffer currently referenced by this object and
// makes it a shallow copy of `m` (same buffer, shared reference counter).
//
// The reference on `m`'s buffer is taken BEFORE Release() so that the buffer
// survives when both objects already share it.
LiteMat &LiteMat::operator=(const LiteMat &m) {
  if (this == &m) {
    return *this;
  }

  if (m.ref_count_) {
    addRef(m.ref_count_, 1);
  }

  // Release() consults this object's own release_flag, so the flag must only be
  // overwritten after the old buffer has been dropped.
  Release();
  data_ptr_ = m.data_ptr_;
  elem_size_ = m.elem_size_;
  width_ = m.width_;
  height_ = m.height_;
  channel_ = m.channel_;
  c_step_ = m.c_step_;
  dims_ = m.dims_;
  data_type_ = m.data_type_;
  ref_count_ = m.ref_count_;
  setSteps(m.steps_[0], m.steps_[1], m.steps_[2]);
  size_ = m.size_;
  // Fix: carry over the ownership flag exactly like the copy constructor does.
  // Without it, AlignFree() would skip free() when this object turns out to be
  // the last owner, leaking the buffer.
  release_flag = m.release_flag;
  return *this;
}
178 
Init(int width,LDataType data_type)179 void LiteMat::Init(int width, LDataType data_type) {
180   Release();
181   data_type_ = data_type;
182   InitElemSize(data_type);
183   width_ = width;
184   dims_ = 1;
185   height_ = 1;
186   channel_ = 1;
187   if (!CheckLiteMat()) {
188     Release();
189     return;
190   }
191   c_step_ = width;
192   size_ = c_step_ * elem_size_;
193   data_ptr_ = AlignMalloc(size_);
194   ref_count_ = new int[1];
195   *ref_count_ = 1;
196   steps_[0] = elem_size_;
197 }
198 
Init(int width,int height,LDataType data_type)199 void LiteMat::Init(int width, int height, LDataType data_type) {
200   Release();
201   data_type_ = data_type;
202   InitElemSize(data_type);
203   width_ = width;
204   height_ = height;
205   dims_ = 2;
206   channel_ = 1;
207   if (!CheckLiteMat()) {
208     Release();
209     return;
210   }
211   c_step_ = width_ * height_;
212   size_ = c_step_ * elem_size_;
213   data_ptr_ = AlignMalloc(size_);
214   ref_count_ = new int[1];
215   *ref_count_ = 1;
216   steps_[1] = elem_size_;
217   steps_[0] = width_ * steps_[1];
218 }
219 
Init(int width,int height,void * p_data,LDataType data_type)220 void LiteMat::Init(int width, int height, void *p_data, LDataType data_type) {
221   data_type_ = data_type;
222   InitElemSize(data_type);
223   width_ = width;
224   height_ = height;
225   dims_ = 2;
226   channel_ = 1;
227   if (!CheckLiteMat()) {
228     Release();
229     return;
230   }
231   c_step_ = height_ * width_;
232   size_ = c_step_ * channel_ * elem_size_;
233   data_ptr_ = p_data;
234   ref_count_ = nullptr;
235   steps_[1] = elem_size_;
236   steps_[0] = width_ * steps_[1];
237 }
238 
Init(int width,int height,int channel,const LDataType & data_type,bool align_memory)239 void LiteMat::Init(int width, int height, int channel, const LDataType &data_type, bool align_memory) {
240   Release();
241   data_type_ = data_type;
242   InitElemSize(data_type);
243   width_ = width;
244   height_ = height;
245   dims_ = 3;
246   channel_ = channel;
247   if (!CheckLiteMat()) {
248     Release();
249     return;
250   }
251   if (align_memory) {
252     c_step_ = ((height_ * width_ * elem_size_ + ALIGN - 1) & (-ALIGN)) / elem_size_;
253   } else {
254     c_step_ = height_ * width_;
255   }
256   size_ = c_step_ * channel_ * elem_size_;
257   data_ptr_ = AlignMalloc(size_);
258   ref_count_ = new int[1];
259   *ref_count_ = 1;
260 
261   steps_[2] = elem_size_;
262   steps_[1] = channel * steps_[2];
263   steps_[0] = width_ * steps_[1];
264 }
265 
Init(int width,int height,int channel,void * p_data,LDataType data_type)266 void LiteMat::Init(int width, int height, int channel, void *p_data, LDataType data_type) {
267   data_type_ = data_type;
268   InitElemSize(data_type);
269   width_ = width;
270   height_ = height;
271   dims_ = 3;
272   channel_ = channel;
273   if (!CheckLiteMat()) {
274     Release();
275     return;
276   }
277   c_step_ = height_ * width_;
278   size_ = c_step_ * channel_ * elem_size_;
279   data_ptr_ = p_data;
280   ref_count_ = nullptr;
281   steps_[2] = elem_size_;
282   steps_[1] = channel * steps_[2];
283   steps_[0] = width_ * steps_[1];
284 }
285 
IsEmpty() const286 bool LiteMat::IsEmpty() const { return data_ptr_ == nullptr || c_step_ * channel_ == 0; }
287 
Release()288 void LiteMat::Release() {
289   if (ref_count_ && (addRef(ref_count_, -1) == 1)) {
290     if (data_ptr_) {
291       AlignFree(data_ptr_);
292     }
293     delete[] ref_count_;
294   }
295   data_ptr_ = nullptr;
296   elem_size_ = 0;
297   width_ = 0;
298   height_ = 0;
299   channel_ = 0;
300   c_step_ = 0;
301   ref_count_ = nullptr;
302   size_ = 0;
303   setSteps(0, 0, 0);
304 }
305 
AlignMalloc(unsigned int size)306 void *LiteMat::AlignMalloc(unsigned int size) {
307   unsigned int length = sizeof(void *) + ALIGN - 1;
308   if (size > std::numeric_limits<uint32_t>::max() - length) {
309     return nullptr;
310   }
311   void *p_raw = reinterpret_cast<void *>(malloc(size + length));
312   if (p_raw) {
313     release_flag = true;
314     void **p_algin = reinterpret_cast<void **>(((size_t)(p_raw) + length) & ~(ALIGN - 1));
315     p_algin[-1] = p_raw;
316     return p_algin;
317   }
318   return nullptr;
319 }
320 
AlignFree(void * ptr)321 void LiteMat::AlignFree(void *ptr) {
322   if (release_flag) {
323     (void)free(reinterpret_cast<void **>(ptr)[-1]);
324     ptr = nullptr;
325     release_flag = false;
326   }
327 }
328 
InitElemSize(LDataType data_type)329 inline void LiteMat::InitElemSize(LDataType data_type) { elem_size_ = data_type.SizeInBytes(); }
330 
CheckLiteMat()331 bool LiteMat::CheckLiteMat() {
332   if (width_ <= 0 || height_ <= 0 || channel_ <= 0 || elem_size_ <= 0) {
333     return false;
334   }
335   if (height_ != 1 && height_ > std::numeric_limits<int>::max() / width_) {
336     return false;
337   }
338   int area = height_ * width_;
339   if (channel_ != 1 && channel_ > std::numeric_limits<int>::max() / area) {
340     return false;
341   }
342   int size = area * channel_;
343   if (elem_size_ > std::numeric_limits<int>::max() / size) {
344     return false;
345   }
346   return true;
347 }
348 
GetROI(int x,int y,int w,int h,LiteMat & m)349 bool LiteMat::GetROI(int x, int y, int w, int h, LiteMat &m) {
350   if (x < 0 || y < 0 || x > width_ - w || h > height_ - y || w <= 0 || h <= 0) {
351     return false;
352   }
353   if (!m.IsEmpty()) {
354     m.Release();
355   }
356 
357   if (ref_count_) {
358     addRef(ref_count_, 1);
359   }
360 
361   m.height_ = h;
362   m.width_ = w;
363   m.dims_ = dims_;
364   m.elem_size_ = elem_size_;
365   m.data_ptr_ = reinterpret_cast<uint8_t *>(data_ptr_) + y * steps_[0] + x * elem_size_ * channel_;
366   m.channel_ = channel_;
367   m.c_step_ = c_step_;
368   m.data_type_ = data_type_;
369   m.ref_count_ = ref_count_;
370   m.setSteps(steps_[0], steps_[1], steps_[2]);
371   return true;
372 }
373 
// Element-wise subtraction dst[i] = src0[i] - src1[i] over `total_size` elements.
// Generic version: uses the type's own arithmetic, without saturation.
template <typename T>
inline void SubtractImpl(const T *src0, const T *src1, T *dst, int64_t total_size) {
  int64_t idx = 0;
  while (idx < total_size) {
    dst[idx] = src0[idx] - src1[idx];
    ++idx;
  }
}

// uint8_t version: saturating subtraction (results below 0 become 0).
// With ENABLE_NEON, 32 elements per iteration are handled with vqsubq_u8; the
// remaining elements go through the scalar loop.
template <>
inline void SubtractImpl(const uint8_t *src0, const uint8_t *src1, uint8_t *dst, int64_t total_size) {
  int64_t pos = 0;
#ifdef ENABLE_NEON
  const int64_t kLanesPerIter = 32;
  while (pos <= total_size - kLanesPerIter) {
    const uint8x16_t lhs_lo = vld1q_u8(src0 + pos);
    const uint8x16_t lhs_hi = vld1q_u8(src0 + pos + 16);
    const uint8x16_t rhs_lo = vld1q_u8(src1 + pos);
    const uint8x16_t rhs_hi = vld1q_u8(src1 + pos + 16);

    vst1q_u8(dst + pos, vqsubq_u8(lhs_lo, rhs_lo));
    vst1q_u8(dst + pos + 16, vqsubq_u8(lhs_hi, rhs_hi));
    pos += kLanesPerIter;
  }
#endif
  for (; pos < total_size; ++pos) {
    // The difference of two uint8_t values can never exceed 255, so only the lower bound matters.
    dst[pos] = (src0[pos] > src1[pos]) ? (src0[pos] - src1[pos]) : 0;
  }
}

// uint16_t version: saturating subtraction (results below 0 become 0).
template <>
inline void SubtractImpl(const uint16_t *src0, const uint16_t *src1, uint16_t *dst, int64_t total_size) {
  for (int64_t pos = 0; pos < total_size; ++pos) {
    dst[pos] = (src0[pos] > src1[pos]) ? (src0[pos] - src1[pos]) : 0;
  }
}

// uint32_t version: saturating subtraction (results below 0 become 0).
template <>
inline void SubtractImpl(const uint32_t *src0, const uint32_t *src1, uint32_t *dst, int64_t total_size) {
  for (int64_t pos = 0; pos < total_size; ++pos) {
    dst[pos] = (src0[pos] > src1[pos]) ? (src0[pos] - src1[pos]) : 0u;
  }
}
424 
CheckSubstract(const LiteMat & src_a,const LiteMat & src_b,LiteMat * dst)425 inline bool CheckSubstract(const LiteMat &src_a, const LiteMat &src_b, LiteMat *dst) {
426   if (dst == nullptr) {
427     return false;
428   }
429 
430   if (src_a.width_ != src_b.width_ || src_a.height_ != src_b.height_ || src_a.channel_ != src_b.channel_) {
431     return false;
432   }
433 
434   return src_a.data_type_ == src_b.data_type_;
435 }
436 
Subtract(const LiteMat & src_a,const LiteMat & src_b,LiteMat * dst)437 bool Subtract(const LiteMat &src_a, const LiteMat &src_b, LiteMat *dst) {
438   if (!CheckSubstract(src_a, src_b, dst)) {
439     return false;
440   }
441 
442   if (dst->IsEmpty()) {
443     dst->Init(src_a.width_, src_a.height_, src_a.channel_, src_a.data_type_);
444   } else if (src_a.width_ != dst->width_ || src_a.height_ != dst->height_ || src_a.channel_ != dst->channel_) {
445     return false;
446   } else if (src_a.data_type_ != dst->data_type_) {
447     return false;
448   }
449 
450   int64_t total_size = src_a.height_ * src_a.width_ * src_a.channel_;
451   if (src_a.data_type_ == LDataType::BOOL) {
452     SubtractImpl<bool>(src_a, src_b, *dst, total_size);
453   } else if (src_a.data_type_ == LDataType::INT8) {
454     SubtractImpl<int8_t>(src_a, src_b, *dst, total_size);
455   } else if (src_a.data_type_ == LDataType::UINT8) {
456     SubtractImpl<uint8_t>(src_a, src_b, *dst, total_size);
457   } else if (src_a.data_type_ == LDataType::INT16) {
458     SubtractImpl<int16_t>(src_a, src_b, *dst, total_size);
459   } else if (src_a.data_type_ == LDataType::UINT16) {
460     SubtractImpl<uint16_t>(src_a, src_b, *dst, total_size);
461   } else if (src_a.data_type_ == LDataType::INT32) {
462     SubtractImpl<int32_t>(src_a, src_b, *dst, total_size);
463   } else if (src_a.data_type_ == LDataType::UINT32) {
464     SubtractImpl<uint32_t>(src_a, src_b, *dst, total_size);
465   } else if (src_a.data_type_ == LDataType::INT64) {
466     SubtractImpl<int64_t>(src_a, src_b, *dst, total_size);
467   } else if (src_a.data_type_ == LDataType::UINT64) {
468     SubtractImpl<uint64_t>(src_a, src_b, *dst, total_size);
469   } else if (src_a.data_type_ == LDataType::FLOAT32) {
470     SubtractImpl<float>(src_a, src_b, *dst, total_size);
471   } else if (src_a.data_type_ == LDataType::FLOAT64) {
472     SubtractImpl<double>(src_a, src_b, *dst, total_size);
473   } else {
474     return false;
475   }
476 
477   return true;
478 }
479 
480 #ifdef ENABLE_NEON
reciprocal_simd(float32x4_t val)481 inline float32x4_t reciprocal_simd(float32x4_t val) {
482   // get an initial estimate of 1/val
483   float32x4_t reciprocal = vrecpeq_f32(val);
484 
485   // use Newton-Raphson steps to refine the estimate
486   reciprocal = vmulq_f32(vrecpsq_f32(val, reciprocal), reciprocal);
487   reciprocal = vmulq_f32(vrecpsq_f32(val, reciprocal), reciprocal);
488   return reciprocal;
489 }
490 
// Adds 0.5 carrying the sign of each lane to `v` (i.e. +0.5 for non-negative
// lanes, -0.5 for negative ones), so that a following truncating conversion
// rounds half away from zero.
inline float32x4_t round_simd(const float32x4_t &v) {
  const int32x4_t sign_bit = vdupq_n_s32(1U << 31);
  const int32x4_t half_bits = vreinterpretq_s32_f32(vdupq_n_f32(0.5f));

  // Isolate the sign bit of every lane and stamp it onto 0.5.
  const int32x4_t lane_signs = vandq_s32(sign_bit, vreinterpretq_s32_f32(v));
  const float32x4_t signed_half = vreinterpretq_f32_s32(vorrq_s32(half_bits, lane_signs));

  return vaddq_f32(v, signed_half);
}
497 #endif
498 
// Element-wise division dst[i] = src0[i] / src1[i] over `total_size` elements.
// A zero divisor yields 0 instead of trapping.
// Generic version: uses the type's own division (truncating for integers).
// The loop index is signed like `total_size`, so a non-positive size runs zero iterations.
template <typename T>
inline void DivideImpl(const T *src0, const T *src1, T *dst, int64_t total_size) {
  for (int64_t i = 0; i < total_size; i++) {
    dst[i] = src1[i] ? src0[i] / src1[i] : 0;
  }
}

// uint8_t version: quotient rounded to nearest (halves round up); x / 0 == 0.
// With ENABLE_NEON, 16 elements per iteration are divided in float via a
// refined reciprocal and rounded with round_simd(); the rest use the scalar loop.
//
// Fix: the scalar loop used std::round() on an already truncated integer
// quotient, which did nothing and disagreed with the NEON path (3 / 2 gave 1
// instead of 2). It now rounds with exact integer arithmetic.
template <>
inline void DivideImpl(const uint8_t *src0, const uint8_t *src1, uint8_t *dst, int64_t total_size) {
  int64_t x = 0;
#ifdef ENABLE_NEON
  const int64_t step = 16;
  for (; x <= total_size - step; x += step) {
    __builtin_prefetch(reinterpret_cast<const char *>(src0 + x) + 32 * 10);
    __builtin_prefetch(reinterpret_cast<const char *>(src1 + x) + 32 * 10);

    uint8x16_t v_a = vld1q_u8(src0 + x);
    uint8x16_t v_b = vld1q_u8(src1 + x);
    // All-ones in lanes with a non-zero divisor, zero elsewhere.
    uint8x16_t v_mask = vtstq_u8(v_b, v_b);

    // Widen u8 -> u16 -> u32 -> f32.
    uint16x8_t va_l_16x8 = vmovl_u8(vget_low_u8(v_a));
    uint16x8_t va_h_16x8 = vmovl_u8(vget_high_u8(v_a));
    uint16x8_t vb_l_16x8 = vmovl_u8(vget_low_u8(v_b));
    uint16x8_t vb_h_16x8 = vmovl_u8(vget_high_u8(v_b));

    float32x4_t va_ll_f32x4 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(va_l_16x8)));
    float32x4_t va_lh_f32x4 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(va_l_16x8)));
    float32x4_t va_hl_f32x4 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(va_h_16x8)));
    float32x4_t va_hh_f32x4 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(va_h_16x8)));
    float32x4_t vb_ll_f32x4 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(vb_l_16x8)));
    float32x4_t vb_lh_f32x4 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vb_l_16x8)));
    float32x4_t vb_hl_f32x4 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(vb_h_16x8)));
    float32x4_t vb_hh_f32x4 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vb_h_16x8)));

    float32x4_t vb_ll_re_f32x4 = reciprocal_simd(vb_ll_f32x4);
    float32x4_t vb_lh_re_f32x4 = reciprocal_simd(vb_lh_f32x4);
    float32x4_t vb_hl_re_f32x4 = reciprocal_simd(vb_hl_f32x4);
    float32x4_t vb_hh_re_f32x4 = reciprocal_simd(vb_hh_f32x4);

    float32x4_t dst_ll_f32x4 = round_simd(vmulq_f32(va_ll_f32x4, vb_ll_re_f32x4));
    float32x4_t dst_lh_f32x4 = round_simd(vmulq_f32(va_lh_f32x4, vb_lh_re_f32x4));
    float32x4_t dst_hl_f32x4 = round_simd(vmulq_f32(va_hl_f32x4, vb_hl_re_f32x4));
    float32x4_t dst_hh_f32x4 = round_simd(vmulq_f32(va_hh_f32x4, vb_hh_re_f32x4));

    uint32x4_t dst_ll_32x4 = vcvtq_u32_f32(dst_ll_f32x4);
    uint32x4_t dst_lh_32x4 = vcvtq_u32_f32(dst_lh_f32x4);
    uint32x4_t dst_hl_32x4 = vcvtq_u32_f32(dst_hl_f32x4);
    uint32x4_t dst_hh_32x4 = vcvtq_u32_f32(dst_hh_f32x4);

    // Narrow back u32 -> u16 -> u8 with saturation.
    uint16x4_t dst_ll_16x4 = vqmovn_u32(dst_ll_32x4);
    uint16x4_t dst_lh_16x4 = vqmovn_u32(dst_lh_32x4);
    uint16x4_t dst_hl_16x4 = vqmovn_u32(dst_hl_32x4);
    uint16x4_t dst_hh_16x4 = vqmovn_u32(dst_hh_32x4);

    uint16x8_t dst_l_16x8 = vcombine_u16(dst_ll_16x4, dst_lh_16x4);
    uint16x8_t dst_h_16x8 = vcombine_u16(dst_hl_16x4, dst_hh_16x4);

    // Fix: these vectors are unsigned; they were declared with signed vector
    // types, which only compiles with lax vector conversions.
    uint8x8_t dst_l_8x8 = vqmovn_u16(dst_l_16x8);
    uint8x8_t dst_h_8x8 = vqmovn_u16(dst_h_16x8);
    uint8x16_t dst_8x16 = vcombine_u8(dst_l_8x8, dst_h_8x8);

    // Force lanes with a zero divisor to 0.
    dst_8x16 = vandq_u8(dst_8x16, v_mask);
    vst1q_u8(dst + x, dst_8x16);
  }
#endif
  for (; x < total_size; x++) {
    const uint32_t divisor = src1[x];
    // (a + b / 2) / b rounds a / b to nearest; the result never exceeds a, so it fits uint8_t.
    dst[x] = divisor ? static_cast<uint8_t>((src0[x] + divisor / 2) / divisor) : 0;
  }
}

// uint16_t version: quotient rounded to nearest (halves round up); x / 0 == 0.
// Fix: see the uint8_t version; std::round() on an integer quotient was a no-op.
template <>
inline void DivideImpl(const uint16_t *src0, const uint16_t *src1, uint16_t *dst, int64_t total_size) {
  for (int64_t i = 0; i < total_size; i++) {
    const uint32_t divisor = src1[i];
    dst[i] = divisor ? static_cast<uint16_t>((src0[i] + divisor / 2) / divisor) : 0;
  }
}

// uint32_t version: quotient rounded to nearest (halves round up); x / 0 == 0.
// The rounding term is added in 64 bits so it cannot wrap.
// Fix: see the uint8_t version; std::round() on an integer quotient was a no-op.
template <>
inline void DivideImpl(const uint32_t *src0, const uint32_t *src1, uint32_t *dst, int64_t total_size) {
  for (int64_t i = 0; i < total_size; i++) {
    const uint64_t divisor = src1[i];
    dst[i] = divisor ? static_cast<uint32_t>((static_cast<uint64_t>(src0[i]) + divisor / 2) / divisor) : 0;
  }
}
588 
CheckDivide(const LiteMat & src_a,const LiteMat & src_b,LiteMat * dst)589 inline bool CheckDivide(const LiteMat &src_a, const LiteMat &src_b, LiteMat *dst) {
590   if (dst == nullptr) {
591     return false;
592   }
593 
594   if (src_a.width_ != src_b.width_ || src_a.height_ != src_b.height_ || src_a.channel_ != src_b.channel_) {
595     return false;
596   }
597 
598   return src_a.data_type_ == src_b.data_type_;
599 }
600 
Divide(const LiteMat & src_a,const LiteMat & src_b,LiteMat * dst)601 bool Divide(const LiteMat &src_a, const LiteMat &src_b, LiteMat *dst) {
602   if (!CheckDivide(src_a, src_b, dst)) {
603     return false;
604   }
605 
606   if (dst->IsEmpty()) {
607     dst->Init(src_a.width_, src_a.height_, src_a.channel_, src_a.data_type_);
608   } else if (src_a.width_ != dst->width_ || src_a.height_ != dst->height_ || src_a.channel_ != dst->channel_) {
609     return false;
610   } else if (src_a.data_type_ != dst->data_type_) {
611     return false;
612   }
613 
614   int64_t total_size = src_a.height_ * src_a.width_ * src_a.channel_;
615   if (src_a.data_type_ == LDataType::INT8) {
616     DivideImpl<int8_t>(src_a, src_b, *dst, total_size);
617   } else if (src_a.data_type_ == LDataType::UINT8) {
618     DivideImpl<uint8_t>(src_a, src_b, *dst, total_size);
619   } else if (src_a.data_type_ == LDataType::INT16) {
620     DivideImpl<int16_t>(src_a, src_b, *dst, total_size);
621   } else if (src_a.data_type_ == LDataType::UINT16) {
622     DivideImpl<uint16_t>(src_a, src_b, *dst, total_size);
623   } else if (src_a.data_type_ == LDataType::INT32) {
624     DivideImpl<int32_t>(src_a, src_b, *dst, total_size);
625   } else if (src_a.data_type_ == LDataType::UINT32) {
626     DivideImpl<uint32_t>(src_a, src_b, *dst, total_size);
627   } else if (src_a.data_type_ == LDataType::INT64) {
628     DivideImpl<int64_t>(src_a, src_b, *dst, total_size);
629   } else if (src_a.data_type_ == LDataType::UINT64) {
630     DivideImpl<uint64_t>(src_a, src_b, *dst, total_size);
631   } else if (src_a.data_type_ == LDataType::FLOAT32) {
632     DivideImpl<float>(src_a, src_b, *dst, total_size);
633   } else if (src_a.data_type_ == LDataType::FLOAT64) {
634     DivideImpl<double>(src_a, src_b, *dst, total_size);
635   } else {
636     return false;
637   }
638   return true;
639 }
640 
// Element-wise multiplication dst[i] = src0[i] * src1[i] over `total_size` elements.
// Generic version: uses the type's own arithmetic, without saturation.
// The loop index is signed like `total_size`, so a non-positive size runs zero iterations.
template <typename T>
inline void MultiplyImpl(const T *src0, const T *src1, T *dst, int64_t total_size) {
  for (int64_t i = 0; i < total_size; i++) {
    dst[i] = src0[i] * src1[i];
  }
}

// uint8_t version: saturating multiplication (products above 255 become 255).
// With ENABLE_NEON, 32 elements per iteration are multiplied into 16-bit lanes
// and narrowed with saturation; the rest use the scalar loop.
template <>
inline void MultiplyImpl(const uint8_t *src0, const uint8_t *src1, uint8_t *dst, int64_t total_size) {
  int64_t x = 0;
#ifdef ENABLE_NEON
  const int64_t step = 32;
  for (; x <= total_size - step; x += step) {
    uint8x16_t v_src00 = vld1q_u8(src0 + x);
    uint8x16_t v_src01 = vld1q_u8(src0 + x + 16);
    uint8x16_t v_src10 = vld1q_u8(src1 + x);
    uint8x16_t v_src11 = vld1q_u8(src1 + x + 16);
    // Fix: vmull_u8 produces eight 16-bit lanes; the temporaries were declared
    // as uint8x16_t, which only compiles with lax vector conversions.
    uint16x8_t v_dst_l, v_dst_h;

    v_dst_l = vmull_u8(vget_low_u8(v_src00), vget_low_u8(v_src10));
    v_dst_h = vmull_u8(vget_high_u8(v_src00), vget_high_u8(v_src10));
    vst1q_u8(dst + x, vcombine_u8(vqmovn_u16(v_dst_l), vqmovn_u16(v_dst_h)));

    v_dst_l = vmull_u8(vget_low_u8(v_src01), vget_low_u8(v_src11));
    v_dst_h = vmull_u8(vget_high_u8(v_src01), vget_high_u8(v_src11));
    vst1q_u8(dst + x + 16, vcombine_u8(vqmovn_u16(v_dst_l), vqmovn_u16(v_dst_h)));
  }
#endif
  for (; x < total_size; x++) {
    // 255 * 255 fits comfortably in int32_t; only the upper bound can be exceeded.
    const int32_t product = src0[x] * src1[x];
    dst[x] = static_cast<uint8_t>(std::min<int32_t>(std::numeric_limits<uint8_t>::max(), product));
  }
}

// uint16_t version: saturating multiplication (products above 65535 become 65535).
// Fix: the operands were promoted to int, so 65535 * 65535 overflowed a signed
// 32-bit int (undefined behaviour). The product is now formed in uint32_t,
// which holds every possible result.
template <>
inline void MultiplyImpl(const uint16_t *src0, const uint16_t *src1, uint16_t *dst, int64_t total_size) {
  for (int64_t i = 0; i < total_size; i++) {
    const uint32_t product = static_cast<uint32_t>(src0[i]) * static_cast<uint32_t>(src1[i]);
    dst[i] = static_cast<uint16_t>(std::min<uint32_t>(std::numeric_limits<uint16_t>::max(), product));
  }
}

// uint32_t version: saturating multiplication (products above UINT32_MAX become UINT32_MAX).
// Fix: the product was computed in 32 bits and wrapped before being widened,
// so the saturation never triggered. It is now formed in uint64_t, which holds
// every possible result.
template <>
inline void MultiplyImpl(const uint32_t *src0, const uint32_t *src1, uint32_t *dst, int64_t total_size) {
  for (int64_t i = 0; i < total_size; i++) {
    const uint64_t product = static_cast<uint64_t>(src0[i]) * static_cast<uint64_t>(src1[i]);
    dst[i] = static_cast<uint32_t>(std::min<uint64_t>(std::numeric_limits<uint32_t>::max(), product));
  }
}
693 
CheckMultiply(const LiteMat & src_a,const LiteMat & src_b,LiteMat * dst)694 inline bool CheckMultiply(const LiteMat &src_a, const LiteMat &src_b, LiteMat *dst) {
695   if (dst == nullptr) {
696     return false;
697   }
698 
699   if (src_a.width_ != src_b.width_ || src_a.height_ != src_b.height_ || src_a.channel_ != src_b.channel_) {
700     return false;
701   }
702 
703   return src_a.data_type_ == src_b.data_type_;
704 }
705 
Multiply(const LiteMat & src_a,const LiteMat & src_b,LiteMat * dst)706 bool Multiply(const LiteMat &src_a, const LiteMat &src_b, LiteMat *dst) {
707   if (!CheckMultiply(src_a, src_b, dst)) {
708     return false;
709   }
710   if (dst->IsEmpty()) {
711     dst->Init(src_a.width_, src_a.height_, src_a.channel_, src_a.data_type_);
712   } else if (src_a.width_ != dst->width_ || src_a.height_ != dst->height_ || src_a.channel_ != dst->channel_) {
713     return false;
714   } else if (src_a.data_type_ != dst->data_type_) {
715     return false;
716   }
717 
718   int64_t total_size = src_a.height_ * src_a.width_ * src_a.channel_;
719   if (src_a.data_type_ == LDataType::INT8) {
720     MultiplyImpl<int8_t>(src_a, src_b, *dst, total_size);
721   } else if (src_a.data_type_ == LDataType::UINT8) {
722     MultiplyImpl<uint8_t>(src_a, src_b, *dst, total_size);
723   } else if (src_a.data_type_ == LDataType::INT16) {
724     MultiplyImpl<int16_t>(src_a, src_b, *dst, total_size);
725   } else if (src_a.data_type_ == LDataType::UINT16) {
726     MultiplyImpl<uint16_t>(src_a, src_b, *dst, total_size);
727   } else if (src_a.data_type_ == LDataType::INT32) {
728     MultiplyImpl<int32_t>(src_a, src_b, *dst, total_size);
729   } else if (src_a.data_type_ == LDataType::UINT32) {
730     MultiplyImpl<uint32_t>(src_a, src_b, *dst, total_size);
731   } else if (src_a.data_type_ == LDataType::INT64) {
732     MultiplyImpl<int64_t>(src_a, src_b, *dst, total_size);
733   } else if (src_a.data_type_ == LDataType::UINT64) {
734     MultiplyImpl<uint64_t>(src_a, src_b, *dst, total_size);
735   } else if (src_a.data_type_ == LDataType::FLOAT32) {
736     MultiplyImpl<float>(src_a, src_b, *dst, total_size);
737   } else if (src_a.data_type_ == LDataType::FLOAT64) {
738     MultiplyImpl<double>(src_a, src_b, *dst, total_size);
739   } else {
740     return false;
741   }
742   return true;
743 }
744 
745 }  // namespace dataset
746 }  // namespace mindspore
747