• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "minddata/dataset/kernels/image/lite_cv/lite_mat.h"
17 
18 #include <algorithm>
19 #include <cmath>
20 #include <limits>
21 
22 #ifdef ENABLE_NEON
23 #include <arm_neon.h>
24 #endif
25 
26 namespace mindspore {
27 namespace dataset {
LiteMat()28 LiteMat::LiteMat() {
29   data_ptr_ = nullptr;
30   elem_size_ = 0;
31   width_ = 0;
32   height_ = 0;
33   channel_ = 0;
34   c_step_ = 0;
35   dims_ = 0;
36   size_ = 0;
37   data_type_ = LDataType::UINT8;
38   ref_count_ = nullptr;
39   setSteps(0, 0, 0);
40   release_flag_ = false;
41 }
42 
LiteMat(int width,LDataType data_type)43 LiteMat::LiteMat(int width, LDataType data_type) {
44   data_ptr_ = nullptr;
45   elem_size_ = 0;
46   width_ = 0;
47   height_ = 0;
48   channel_ = 0;
49   c_step_ = 0;
50   dims_ = 0;
51   data_type_ = LDataType::UINT8;
52   ref_count_ = nullptr;
53   size_ = 0;
54   setSteps(0, 0, 0);
55   release_flag_ = false;
56   Init(width, data_type);
57 }
58 
LiteMat(int width,int height,LDataType data_type)59 LiteMat::LiteMat(int width, int height, LDataType data_type) {
60   data_ptr_ = nullptr;
61   elem_size_ = 0;
62   width_ = 0;
63   height_ = 0;
64   channel_ = 0;
65   c_step_ = 0;
66   dims_ = 0;
67   data_type_ = LDataType::UINT8;
68   ref_count_ = nullptr;
69   size_ = 0;
70   setSteps(0, 0, 0);
71   release_flag_ = false;
72   Init(width, height, data_type);
73 }
74 
LiteMat(int width,int height,void * p_data,LDataType data_type)75 LiteMat::LiteMat(int width, int height, void *p_data, LDataType data_type) {
76   data_ptr_ = nullptr;
77   elem_size_ = 0;
78   width_ = 0;
79   height_ = 0;
80   channel_ = 0;
81   c_step_ = 0;
82   dims_ = 0;
83   data_type_ = LDataType::UINT8;
84   ref_count_ = nullptr;
85   size_ = 0;
86   setSteps(0, 0, 0);
87   release_flag_ = false;
88   Init(width, height, p_data, data_type);
89 }
90 
LiteMat(int width,int height,int channel,LDataType data_type)91 LiteMat::LiteMat(int width, int height, int channel, LDataType data_type) {
92   data_ptr_ = nullptr;
93   elem_size_ = 0;
94   width_ = 0;
95   height_ = 0;
96   channel_ = 0;
97   c_step_ = 0;
98   dims_ = 0;
99   data_type_ = LDataType::UINT8;
100   ref_count_ = nullptr;
101   size_ = 0;
102   setSteps(0, 0, 0);
103   release_flag_ = false;
104   Init(width, height, channel, data_type);
105 }
106 
LiteMat(int width,int height,int channel,void * p_data,LDataType data_type)107 LiteMat::LiteMat(int width, int height, int channel, void *p_data, LDataType data_type) {
108   data_ptr_ = nullptr;
109   elem_size_ = 0;
110   width_ = 0;
111   height_ = 0;
112   channel_ = 0;
113   c_step_ = 0;
114   dims_ = 0;
115   data_type_ = LDataType::UINT8;
116   ref_count_ = nullptr;
117   size_ = 0;
118   setSteps(0, 0, 0);
119   release_flag_ = false;
120   Init(width, height, channel, p_data, data_type);
121 }
122 
~LiteMat()123 LiteMat::~LiteMat() { Release(); }
124 
addRef(int * p,int value)125 int LiteMat::addRef(int *p, int value) {
126   int v = *p;
127   *p += value;
128   return v;
129 }
130 
// Copy constructor: copies every field of `m`, including the buffer pointer,
// and increments the shared reference counter when one exists.
LiteMat::LiteMat(const LiteMat &m) {
  // Storage and ownership bookkeeping.
  data_ptr_ = m.data_ptr_;
  ref_count_ = m.ref_count_;
  release_flag_ = m.release_flag_;

  // Element description.
  data_type_ = m.data_type_;
  elem_size_ = m.elem_size_;

  // Geometry.
  width_ = m.width_;
  height_ = m.height_;
  channel_ = m.channel_;
  dims_ = m.dims_;

  // Derived sizes and strides.
  c_step_ = m.c_step_;
  size_ = m.size_;
  setSteps(m.steps_[0], m.steps_[1], m.steps_[2]);

  if (ref_count_ != nullptr) {
    addRef(ref_count_, 1);
  }
}
148 
setSteps(size_t c0,size_t c1,size_t c2)149 void LiteMat::setSteps(size_t c0, size_t c1, size_t c2) {
150   steps_[0] = c0;
151   steps_[1] = c1;
152   steps_[2] = c2;
153 }
154 
// Copy assignment: releases this object's current reference and takes over
// all fields of `m`, sharing its buffer. Self-assignment is a no-op.
LiteMat &LiteMat::operator=(const LiteMat &m) {
  if (this != &m) {
    // The source's counter is raised before our own reference is dropped.
    if (m.ref_count_ != nullptr) {
      addRef(m.ref_count_, 1);
    }
    Release();

    // Storage and ownership bookkeeping.
    data_ptr_ = m.data_ptr_;
    ref_count_ = m.ref_count_;
    release_flag_ = m.release_flag_;

    // Element description.
    data_type_ = m.data_type_;
    elem_size_ = m.elem_size_;

    // Geometry.
    width_ = m.width_;
    height_ = m.height_;
    channel_ = m.channel_;
    dims_ = m.dims_;

    // Derived sizes and strides.
    c_step_ = m.c_step_;
    size_ = m.size_;
    setSteps(m.steps_[0], m.steps_[1], m.steps_[2]);
  }
  return *this;
}
179 
Init(int width,LDataType data_type)180 void LiteMat::Init(int width, LDataType data_type) {
181   Release();
182   data_type_ = data_type;
183   InitElemSize(data_type);
184   width_ = width;
185   dims_ = 1;
186   height_ = 1;
187   channel_ = 1;
188   if (!CheckLiteMat()) {
189     Release();
190     return;
191   }
192   c_step_ = width;
193   size_ = c_step_ * elem_size_;
194   data_ptr_ = AlignMalloc(size_);
195   ref_count_ = new int[1];
196   *ref_count_ = 1;
197   steps_[0] = elem_size_;
198 }
199 
Init(int width,int height,LDataType data_type)200 void LiteMat::Init(int width, int height, LDataType data_type) {
201   Release();
202   data_type_ = data_type;
203   InitElemSize(data_type);
204   width_ = width;
205   height_ = height;
206   dims_ = 2;
207   channel_ = 1;
208   if (!CheckLiteMat()) {
209     Release();
210     return;
211   }
212   c_step_ = width_ * height_;
213   size_ = c_step_ * elem_size_;
214   data_ptr_ = AlignMalloc(size_);
215   ref_count_ = new int[1];
216   *ref_count_ = 1;
217   steps_[1] = elem_size_;
218   steps_[0] = width_ * steps_[1];
219 }
220 
Init(int width,int height,void * p_data,LDataType data_type)221 void LiteMat::Init(int width, int height, void *p_data, LDataType data_type) {
222   data_type_ = data_type;
223   InitElemSize(data_type);
224   width_ = width;
225   height_ = height;
226   dims_ = 2;
227   channel_ = 1;
228   if (!CheckLiteMat()) {
229     Release();
230     return;
231   }
232   c_step_ = height_ * width_;
233   size_ = c_step_ * channel_ * elem_size_;
234   data_ptr_ = p_data;
235   ref_count_ = nullptr;
236   steps_[1] = elem_size_;
237   steps_[0] = width_ * steps_[1];
238 }
239 
Init(int width,int height,int channel,const LDataType & data_type,bool align_memory)240 void LiteMat::Init(int width, int height, int channel, const LDataType &data_type, bool align_memory) {
241   Release();
242   data_type_ = data_type;
243   InitElemSize(data_type);
244   width_ = width;
245   height_ = height;
246   dims_ = 3;
247   channel_ = channel;
248   if (!CheckLiteMat()) {
249     Release();
250     return;
251   }
252   if (align_memory) {
253     c_step_ = ((height_ * width_ * elem_size_ + kAlign - 1) & (-kAlign)) / elem_size_;
254   } else {
255     c_step_ = height_ * width_;
256   }
257   size_ = c_step_ * channel_ * elem_size_;
258   data_ptr_ = AlignMalloc(size_);
259   ref_count_ = new int[1];
260   *ref_count_ = 1;
261 
262   steps_[2] = elem_size_;
263   steps_[1] = channel * steps_[2];
264   steps_[0] = width_ * steps_[1];
265 }
266 
Init(int width,int height,int channel,void * p_data,LDataType data_type)267 void LiteMat::Init(int width, int height, int channel, void *p_data, LDataType data_type) {
268   data_type_ = data_type;
269   InitElemSize(data_type);
270   width_ = width;
271   height_ = height;
272   dims_ = 3;
273   channel_ = channel;
274   if (!CheckLiteMat()) {
275     Release();
276     return;
277   }
278   c_step_ = height_ * width_;
279   size_ = c_step_ * channel_ * elem_size_;
280   data_ptr_ = p_data;
281   ref_count_ = nullptr;
282   steps_[2] = elem_size_;
283   steps_[1] = channel * steps_[2];
284   steps_[0] = width_ * steps_[1];
285 }
286 
IsEmpty() const287 bool LiteMat::IsEmpty() const { return data_ptr_ == nullptr || c_step_ * channel_ == 0; }
288 
Release()289 void LiteMat::Release() {
290   if (ref_count_ && (addRef(ref_count_, -1) == 1)) {
291     if (data_ptr_) {
292       AlignFree(data_ptr_);
293     }
294     delete[] ref_count_;
295   }
296   data_ptr_ = nullptr;
297   elem_size_ = 0;
298   width_ = 0;
299   height_ = 0;
300   channel_ = 0;
301   c_step_ = 0;
302   ref_count_ = nullptr;
303   size_ = 0;
304   setSteps(0, 0, 0);
305 }
306 
AlignMalloc(unsigned int size)307 void *LiteMat::AlignMalloc(unsigned int size) {
308   unsigned int length = sizeof(void *) + kAlign - 1;
309   if (size > std::numeric_limits<uint32_t>::max() - length) {
310     return nullptr;
311   }
312   void *p_raw = reinterpret_cast<void *>(malloc(size + length));
313   if (p_raw) {
314     release_flag_ = true;
315     void **p_algin = reinterpret_cast<void **>((reinterpret_cast<size_t>(p_raw) + length) & ~(kAlign - 1));
316     p_algin[-1] = p_raw;
317     return p_algin;
318   }
319   return nullptr;
320 }
321 
AlignFree(void * ptr)322 void LiteMat::AlignFree(void *ptr) {
323   if (release_flag_) {
324     (void)free(reinterpret_cast<void **>(ptr)[-1]);
325     ptr = nullptr;
326     release_flag_ = false;
327   }
328 }
329 
InitElemSize(LDataType data_type)330 inline void LiteMat::InitElemSize(LDataType data_type) { elem_size_ = data_type.SizeInBytes(); }
331 
CheckLiteMat() const332 bool LiteMat::CheckLiteMat() const {
333   if (width_ <= 0 || height_ <= 0 || channel_ <= 0 || elem_size_ <= 0) {
334     return false;
335   }
336   if (height_ != 1 && height_ > std::numeric_limits<int>::max() / width_) {
337     return false;
338   }
339   int area = height_ * width_;
340   if (channel_ != 1 && channel_ > std::numeric_limits<int>::max() / area) {
341     return false;
342   }
343   int size = area * channel_;
344   if (elem_size_ > std::numeric_limits<int>::max() / size) {
345     return false;
346   }
347   return true;
348 }
349 
GetROI(int x,int y,int w,int h,LiteMat & m)350 bool LiteMat::GetROI(int x, int y, int w, int h, LiteMat &m) {
351   if (x < 0 || y < 0 || x > width_ - w || h > height_ - y || w <= 0 || h <= 0) {
352     return false;
353   }
354   if (!m.IsEmpty()) {
355     m.Release();
356   }
357 
358   if (ref_count_) {
359     addRef(ref_count_, 1);
360   }
361 
362   m.height_ = h;
363   m.width_ = w;
364   m.dims_ = dims_;
365   m.elem_size_ = elem_size_;
366   m.data_ptr_ = reinterpret_cast<uint8_t *>(data_ptr_) + y * steps_[0] + x * elem_size_ * channel_;
367   m.channel_ = channel_;
368   m.c_step_ = c_step_;
369   m.data_type_ = data_type_;
370   m.ref_count_ = ref_count_;
371   m.setSteps(steps_[0], steps_[1], steps_[2]);
372   return true;
373 }
374 
// Element-wise dst[i] = src0[i] - src1[i] over total_size elements, using the
// plain arithmetic of T.
template <typename T>
inline void SubtractImpl(const T *src0, const T *src1, T *dst, int64_t total_size) {
  for (int64_t idx = 0; idx < total_size; ++idx) {
    dst[idx] = src0[idx] - src1[idx];
  }
}

// uint8_t specialisation: saturating subtraction (results below zero become 0).
// With ENABLE_NEON, 32 elements per iteration are handled by vqsubq_u8; the
// scalar loop processes whatever remains.
template <>
inline void SubtractImpl(const uint8_t *src0, const uint8_t *src1, uint8_t *dst, int64_t total_size) {
  int64_t pos = 0;
#ifdef ENABLE_NEON
  const int64_t lanes = 32;
  while (pos <= total_size - lanes) {
    const uint8x16_t a_lo = vld1q_u8(src0 + pos);
    const uint8x16_t a_hi = vld1q_u8(src0 + pos + 16);
    const uint8x16_t b_lo = vld1q_u8(src1 + pos);
    const uint8x16_t b_hi = vld1q_u8(src1 + pos + 16);

    vst1q_u8(dst + pos, vqsubq_u8(a_lo, b_lo));
    vst1q_u8(dst + pos + 16, vqsubq_u8(a_hi, b_hi));
    pos += lanes;
  }
#endif
  while (pos < total_size) {
    // The difference of two uint8_t values never exceeds 255, so only the
    // lower bound needs clamping.
    const int32_t diff = static_cast<int32_t>(src0[pos]) - static_cast<int32_t>(src1[pos]);
    dst[pos] = static_cast<uint8_t>(diff < 0 ? 0 : diff);
    ++pos;
  }
}

// uint16_t specialisation: saturating subtraction (results below zero become 0).
template <>
inline void SubtractImpl(const uint16_t *src0, const uint16_t *src1, uint16_t *dst, int64_t total_size) {
  for (int64_t idx = 0; idx < total_size; ++idx) {
    const int32_t diff = static_cast<int32_t>(src0[idx]) - static_cast<int32_t>(src1[idx]);
    dst[idx] = static_cast<uint16_t>(diff < 0 ? 0 : diff);
  }
}

// uint32_t specialisation: saturating subtraction (results below zero become 0).
template <>
inline void SubtractImpl(const uint32_t *src0, const uint32_t *src1, uint32_t *dst, int64_t total_size) {
  for (int64_t idx = 0; idx < total_size; ++idx) {
    const int64_t diff = static_cast<int64_t>(src0[idx]) - static_cast<int64_t>(src1[idx]);
    dst[idx] = static_cast<uint32_t>(diff < 0 ? 0 : diff);
  }
}
425 
CheckSubstract(const LiteMat & src_a,const LiteMat & src_b,LiteMat * dst)426 inline bool CheckSubstract(const LiteMat &src_a, const LiteMat &src_b, LiteMat *dst) {
427   if (dst == nullptr) {
428     return false;
429   }
430 
431   if (src_a.width_ != src_b.width_ || src_a.height_ != src_b.height_ || src_a.channel_ != src_b.channel_) {
432     return false;
433   }
434 
435   return src_a.data_type_ == src_b.data_type_;
436 }
437 
Subtract(const LiteMat & src_a,const LiteMat & src_b,LiteMat * dst)438 bool Subtract(const LiteMat &src_a, const LiteMat &src_b, LiteMat *dst) {
439   if (!CheckSubstract(src_a, src_b, dst)) {
440     return false;
441   }
442 
443   if (dst->IsEmpty()) {
444     dst->Init(src_a.width_, src_a.height_, src_a.channel_, src_a.data_type_);
445   }
446   if (src_a.width_ != dst->width_ || src_a.height_ != dst->height_ || src_a.channel_ != dst->channel_) {
447     return false;
448   }
449   if (src_a.data_type_ != dst->data_type_) {
450     return false;
451   }
452 
453   int64_t total_size = src_a.height_ * src_a.width_ * src_a.channel_;
454   if (src_a.data_type_ == LDataType::BOOL) {
455     SubtractImpl<bool>(src_a, src_b, *dst, total_size);
456   } else if (src_a.data_type_ == LDataType::INT8) {
457     SubtractImpl<int8_t>(src_a, src_b, *dst, total_size);
458   } else if (src_a.data_type_ == LDataType::UINT8) {
459     SubtractImpl<uint8_t>(src_a, src_b, *dst, total_size);
460   } else if (src_a.data_type_ == LDataType::INT16) {
461     SubtractImpl<int16_t>(src_a, src_b, *dst, total_size);
462   } else if (src_a.data_type_ == LDataType::UINT16) {
463     SubtractImpl<uint16_t>(src_a, src_b, *dst, total_size);
464   } else if (src_a.data_type_ == LDataType::INT32) {
465     SubtractImpl<int32_t>(src_a, src_b, *dst, total_size);
466   } else if (src_a.data_type_ == LDataType::UINT32) {
467     SubtractImpl<uint32_t>(src_a, src_b, *dst, total_size);
468   } else if (src_a.data_type_ == LDataType::INT64) {
469     SubtractImpl<int64_t>(src_a, src_b, *dst, total_size);
470   } else if (src_a.data_type_ == LDataType::UINT64) {
471     SubtractImpl<uint64_t>(src_a, src_b, *dst, total_size);
472   } else if (src_a.data_type_ == LDataType::FLOAT32) {
473     SubtractImpl<float>(src_a, src_b, *dst, total_size);
474   } else if (src_a.data_type_ == LDataType::FLOAT64) {
475     SubtractImpl<double>(src_a, src_b, *dst, total_size);
476   } else {
477     return false;
478   }
479 
480   return true;
481 }
482 
483 #ifdef ENABLE_NEON
reciprocal_simd(float32x4_t val)484 inline float32x4_t reciprocal_simd(float32x4_t val) {
485   // get an initial estimate of 1/val
486   float32x4_t reciprocal = vrecpeq_f32(val);
487 
488   // use Newton-Raphson steps to refine the estimate
489   reciprocal = vmulq_f32(vrecpsq_f32(val, reciprocal), reciprocal);
490   reciprocal = vmulq_f32(vrecpsq_f32(val, reciprocal), reciprocal);
491   return reciprocal;
492 }
493 
// Adds 0.5 carrying the sign of each lane to that lane, so that a following
// truncating conversion rounds to the nearest integer.
inline float32x4_t round_simd(const float32x4_t &v) {
  const int32x4_t sign_bit_mask = vdupq_n_s32(1U << 31);
  const int32x4_t half_bits = vreinterpretq_s32_f32(vdupq_n_f32(0.5f));

  // Isolate each lane's sign bit and merge it into the bit pattern of 0.5.
  const int32x4_t sign_bits = vandq_s32(sign_bit_mask, vreinterpretq_s32_f32(v));
  const float32x4_t signed_half = vreinterpretq_f32_s32(vorrq_s32(half_bits, sign_bits));

  return vaddq_f32(v, signed_half);
}
500 #endif
501 
// Element-wise dst[i] = src0[i] / src1[i] over total_size elements, using the
// plain division of T. A zero divisor yields 0.
template <typename T>
inline void DivideImpl(const T *src0, const T *src1, T *dst, int64_t total_size) {
  for (int64_t i = 0; i < total_size; i++) {
    dst[i] = src1[i] ? src0[i] / src1[i] : 0;
  }
}

// uint8_t specialisation: quotient rounded to the nearest integer (halves
// round up), 0 for a zero divisor.
//
// With ENABLE_NEON, 16 elements per iteration are widened to float, multiplied
// by an approximate reciprocal, rounded and narrowed back with saturation. The
// scalar loop handles the remaining elements and divides in floating point so
// that it rounds the same way; rounding an already truncated integer quotient
// would make the tail disagree with the vector part.
template <>
inline void DivideImpl(const uint8_t *src0, const uint8_t *src1, uint8_t *dst, int64_t total_size) {
  int64_t x = 0;
#ifdef ENABLE_NEON
  const int64_t step = 16;
  for (; x <= total_size - step; x += step) {
    __builtin_prefetch(reinterpret_cast<const char *>(src0 + x) + 32 * 10);
    __builtin_prefetch(reinterpret_cast<const char *>(src1 + x) + 32 * 10);

    uint8x16_t v_a = vld1q_u8(src0 + x);
    uint8x16_t v_b = vld1q_u8(src1 + x);
    // All-ones in lanes whose divisor is non-zero; used to zero the others at the end.
    uint8x16_t v_mask = vtstq_u8(v_b, v_b);

    // Widen u8 -> u16 -> u32 -> f32, four lanes at a time.
    uint16x8_t va_l_16x8 = vmovl_u8(vget_low_u8(v_a));
    uint16x8_t va_h_16x8 = vmovl_u8(vget_high_u8(v_a));
    uint16x8_t vb_l_16x8 = vmovl_u8(vget_low_u8(v_b));
    uint16x8_t vb_h_16x8 = vmovl_u8(vget_high_u8(v_b));

    float32x4_t va_ll_f32x4 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(va_l_16x8)));
    float32x4_t va_lh_f32x4 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(va_l_16x8)));
    float32x4_t va_hl_f32x4 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(va_h_16x8)));
    float32x4_t va_hh_f32x4 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(va_h_16x8)));
    float32x4_t vb_ll_f32x4 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(vb_l_16x8)));
    float32x4_t vb_lh_f32x4 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vb_l_16x8)));
    float32x4_t vb_hl_f32x4 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(vb_h_16x8)));
    float32x4_t vb_hh_f32x4 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vb_h_16x8)));

    float32x4_t vb_ll_re_f32x4 = reciprocal_simd(vb_ll_f32x4);
    float32x4_t vb_lh_re_f32x4 = reciprocal_simd(vb_lh_f32x4);
    float32x4_t vb_hl_re_f32x4 = reciprocal_simd(vb_hl_f32x4);
    float32x4_t vb_hh_re_f32x4 = reciprocal_simd(vb_hh_f32x4);

    // a * (1 / b), then add a signed 0.5 so the conversion below rounds.
    float32x4_t dst_ll_f32x4 = round_simd(vmulq_f32(va_ll_f32x4, vb_ll_re_f32x4));
    float32x4_t dst_lh_f32x4 = round_simd(vmulq_f32(va_lh_f32x4, vb_lh_re_f32x4));
    float32x4_t dst_hl_f32x4 = round_simd(vmulq_f32(va_hl_f32x4, vb_hl_re_f32x4));
    float32x4_t dst_hh_f32x4 = round_simd(vmulq_f32(va_hh_f32x4, vb_hh_re_f32x4));

    uint32x4_t dst_ll_32x4 = vcvtq_u32_f32(dst_ll_f32x4);
    uint32x4_t dst_lh_32x4 = vcvtq_u32_f32(dst_lh_f32x4);
    uint32x4_t dst_hl_32x4 = vcvtq_u32_f32(dst_hl_f32x4);
    uint32x4_t dst_hh_32x4 = vcvtq_u32_f32(dst_hh_f32x4);

    // Saturating narrow u32 -> u16 -> u8.
    uint16x4_t dst_ll_16x4 = vqmovn_u32(dst_ll_32x4);
    uint16x4_t dst_lh_16x4 = vqmovn_u32(dst_lh_32x4);
    uint16x4_t dst_hl_16x4 = vqmovn_u32(dst_hl_32x4);
    uint16x4_t dst_hh_16x4 = vqmovn_u32(dst_hh_32x4);

    uint16x8_t dst_l_16x8 = vcombine_u16(dst_ll_16x4, dst_lh_16x4);
    uint16x8_t dst_h_16x8 = vcombine_u16(dst_hl_16x4, dst_hh_16x4);

    uint8x8_t dst_l_8x8 = vqmovn_u16(dst_l_16x8);
    uint8x8_t dst_h_8x8 = vqmovn_u16(dst_h_16x8);
    uint8x16_t dst_8x16 = vcombine_u8(dst_l_8x8, dst_h_8x8);

    dst_8x16 = vandq_u8(dst_8x16, v_mask);
    vst1q_u8(dst + x, dst_8x16);
  }
#endif
  for (; x < total_size; x++) {
    const int32_t val =
      src1[x] ? static_cast<int32_t>(std::round(static_cast<float>(src0[x]) / static_cast<float>(src1[x]))) : 0;
    dst[x] = std::max<int32_t>(std::numeric_limits<uint8_t>::min(),
                               std::min<int32_t>(std::numeric_limits<uint8_t>::max(), val));
  }
}

// uint16_t specialisation: quotient rounded to the nearest integer (halves
// round up), 0 for a zero divisor. The division is done in double so that
// std::round sees the real quotient.
template <>
inline void DivideImpl(const uint16_t *src0, const uint16_t *src1, uint16_t *dst, int64_t total_size) {
  // Signed index: a negative total_size must result in zero iterations.
  for (int64_t i = 0; i < total_size; i++) {
    const int32_t val =
      src1[i] ? static_cast<int32_t>(std::round(static_cast<double>(src0[i]) / static_cast<double>(src1[i]))) : 0;
    dst[i] = std::max<int32_t>(std::numeric_limits<uint16_t>::min(),
                               std::min<int32_t>(std::numeric_limits<uint16_t>::max(), val));
  }
}

// uint32_t specialisation: quotient rounded to the nearest integer (halves
// round up), 0 for a zero divisor. The division is done in double so that
// std::round sees the real quotient.
template <>
inline void DivideImpl(const uint32_t *src0, const uint32_t *src1, uint32_t *dst, int64_t total_size) {
  // Signed index: a negative total_size must result in zero iterations.
  for (int64_t i = 0; i < total_size; i++) {
    const int64_t val =
      src1[i] ? static_cast<int64_t>(std::round(static_cast<double>(src0[i]) / static_cast<double>(src1[i]))) : 0;
    dst[i] = std::max<int64_t>(std::numeric_limits<uint32_t>::min(),
                               std::min<int64_t>(std::numeric_limits<uint32_t>::max(), val));
  }
}
591 
CheckDivide(const LiteMat & src_a,const LiteMat & src_b,LiteMat * dst)592 inline bool CheckDivide(const LiteMat &src_a, const LiteMat &src_b, LiteMat *dst) {
593   if (dst == nullptr) {
594     return false;
595   }
596 
597   if (src_a.width_ != src_b.width_ || src_a.height_ != src_b.height_ || src_a.channel_ != src_b.channel_) {
598     return false;
599   }
600 
601   return src_a.data_type_ == src_b.data_type_;
602 }
603 
Divide(const LiteMat & src_a,const LiteMat & src_b,LiteMat * dst)604 bool Divide(const LiteMat &src_a, const LiteMat &src_b, LiteMat *dst) {
605   if (!CheckDivide(src_a, src_b, dst)) {
606     return false;
607   }
608 
609   if (dst->IsEmpty()) {
610     dst->Init(src_a.width_, src_a.height_, src_a.channel_, src_a.data_type_);
611   } else if (src_a.width_ != dst->width_ || src_a.height_ != dst->height_ || src_a.channel_ != dst->channel_) {
612     return false;
613   } else if (src_a.data_type_ != dst->data_type_) {
614     return false;
615   }
616 
617   int64_t total_size = src_a.height_ * src_a.width_ * src_a.channel_;
618   if (src_a.data_type_ == LDataType::INT8) {
619     DivideImpl<int8_t>(src_a, src_b, *dst, total_size);
620   } else if (src_a.data_type_ == LDataType::UINT8) {
621     DivideImpl<uint8_t>(src_a, src_b, *dst, total_size);
622   } else if (src_a.data_type_ == LDataType::INT16) {
623     DivideImpl<int16_t>(src_a, src_b, *dst, total_size);
624   } else if (src_a.data_type_ == LDataType::UINT16) {
625     DivideImpl<uint16_t>(src_a, src_b, *dst, total_size);
626   } else if (src_a.data_type_ == LDataType::INT32) {
627     DivideImpl<int32_t>(src_a, src_b, *dst, total_size);
628   } else if (src_a.data_type_ == LDataType::UINT32) {
629     DivideImpl<uint32_t>(src_a, src_b, *dst, total_size);
630   } else if (src_a.data_type_ == LDataType::INT64) {
631     DivideImpl<int64_t>(src_a, src_b, *dst, total_size);
632   } else if (src_a.data_type_ == LDataType::UINT64) {
633     DivideImpl<uint64_t>(src_a, src_b, *dst, total_size);
634   } else if (src_a.data_type_ == LDataType::FLOAT32) {
635     DivideImpl<float>(src_a, src_b, *dst, total_size);
636   } else if (src_a.data_type_ == LDataType::FLOAT64) {
637     DivideImpl<double>(src_a, src_b, *dst, total_size);
638   } else {
639     return false;
640   }
641   return true;
642 }
643 
// Element-wise dst[i] = src0[i] * src1[i] over total_size elements, using the
// plain arithmetic of T.
template <typename T>
inline void MultiplyImpl(const T *src0, const T *src1, T *dst, int64_t total_size) {
  for (int64_t i = 0; i < total_size; i++) {
    dst[i] = src0[i] * src1[i];
  }
}

// uint8_t specialisation: saturating multiplication (products above 255
// become 255). With ENABLE_NEON, 32 elements per iteration are multiplied into
// 16-bit lanes and narrowed with saturation; the scalar loop processes
// whatever remains.
template <>
inline void MultiplyImpl(const uint8_t *src0, const uint8_t *src1, uint8_t *dst, int64_t total_size) {
  int64_t x = 0;
#ifdef ENABLE_NEON
  const int64_t step = 32;
  for (; x <= total_size - step; x += step) {
    uint8x16_t v_src00 = vld1q_u8(src0 + x);
    uint8x16_t v_src01 = vld1q_u8(src0 + x + 16);
    uint8x16_t v_src10 = vld1q_u8(src1 + x);
    uint8x16_t v_src11 = vld1q_u8(src1 + x + 16);
    uint16x8_t v_dst_l, v_dst_h;

    v_dst_l = vmull_u8(vget_low_u8(v_src00), vget_low_u8(v_src10));
    v_dst_h = vmull_u8(vget_high_u8(v_src00), vget_high_u8(v_src10));
    vst1q_u8(dst + x, vcombine_u8(vqmovn_u16(v_dst_l), vqmovn_u16(v_dst_h)));

    v_dst_l = vmull_u8(vget_low_u8(v_src01), vget_low_u8(v_src11));
    v_dst_h = vmull_u8(vget_high_u8(v_src01), vget_high_u8(v_src11));
    vst1q_u8(dst + x + 16, vcombine_u8(vqmovn_u16(v_dst_l), vqmovn_u16(v_dst_h)));
  }
#endif
  for (; x < total_size; x++) {
    // 255 * 255 fits comfortably in int32_t.
    const int32_t val = static_cast<int32_t>(src0[x]) * static_cast<int32_t>(src1[x]);
    dst[x] = std::max<int32_t>(std::numeric_limits<uint8_t>::min(),
                               std::min<int32_t>(std::numeric_limits<uint8_t>::max(), val));
  }
}

// uint16_t specialisation: saturating multiplication (products above 65535
// become 65535).
template <>
inline void MultiplyImpl(const uint16_t *src0, const uint16_t *src1, uint16_t *dst, int64_t total_size) {
  // Signed index: a negative total_size must result in zero iterations.
  for (int64_t i = 0; i < total_size; i++) {
    // Multiply as uint32_t: 65535 * 65535 does not fit in a signed int, but
    // it does fit in 32 unsigned bits.
    const uint32_t product = static_cast<uint32_t>(src0[i]) * static_cast<uint32_t>(src1[i]);
    dst[i] = static_cast<uint16_t>(std::min<uint32_t>(std::numeric_limits<uint16_t>::max(), product));
  }
}

// uint32_t specialisation: saturating multiplication (products above
// 4294967295 become 4294967295).
template <>
inline void MultiplyImpl(const uint32_t *src0, const uint32_t *src1, uint32_t *dst, int64_t total_size) {
  // Signed index: a negative total_size must result in zero iterations.
  for (int64_t i = 0; i < total_size; i++) {
    // Multiply as uint64_t: a 32-bit product would wrap around before the
    // clamp, and the largest product does not fit in int64_t.
    const uint64_t product = static_cast<uint64_t>(src0[i]) * static_cast<uint64_t>(src1[i]);
    dst[i] = static_cast<uint32_t>(std::min<uint64_t>(std::numeric_limits<uint32_t>::max(), product));
  }
}
696 
CheckMultiply(const LiteMat & src_a,const LiteMat & src_b,LiteMat * dst)697 inline bool CheckMultiply(const LiteMat &src_a, const LiteMat &src_b, LiteMat *dst) {
698   if (dst == nullptr) {
699     return false;
700   }
701 
702   if (src_a.width_ != src_b.width_ || src_a.height_ != src_b.height_ || src_a.channel_ != src_b.channel_) {
703     return false;
704   }
705 
706   return src_a.data_type_ == src_b.data_type_;
707 }
708 
Multiply(const LiteMat & src_a,const LiteMat & src_b,LiteMat * dst)709 bool Multiply(const LiteMat &src_a, const LiteMat &src_b, LiteMat *dst) {
710   if (!CheckMultiply(src_a, src_b, dst)) {
711     return false;
712   }
713   if (dst->IsEmpty()) {
714     dst->Init(src_a.width_, src_a.height_, src_a.channel_, src_a.data_type_);
715   } else if (src_a.width_ != dst->width_ || src_a.height_ != dst->height_ || src_a.channel_ != dst->channel_) {
716     return false;
717   } else if (src_a.data_type_ != dst->data_type_) {
718     return false;
719   }
720 
721   int64_t total_size = src_a.height_ * src_a.width_ * src_a.channel_;
722   if (src_a.data_type_ == LDataType::INT8) {
723     MultiplyImpl<int8_t>(src_a, src_b, *dst, total_size);
724   } else if (src_a.data_type_ == LDataType::UINT8) {
725     MultiplyImpl<uint8_t>(src_a, src_b, *dst, total_size);
726   } else if (src_a.data_type_ == LDataType::INT16) {
727     MultiplyImpl<int16_t>(src_a, src_b, *dst, total_size);
728   } else if (src_a.data_type_ == LDataType::UINT16) {
729     MultiplyImpl<uint16_t>(src_a, src_b, *dst, total_size);
730   } else if (src_a.data_type_ == LDataType::INT32) {
731     MultiplyImpl<int32_t>(src_a, src_b, *dst, total_size);
732   } else if (src_a.data_type_ == LDataType::UINT32) {
733     MultiplyImpl<uint32_t>(src_a, src_b, *dst, total_size);
734   } else if (src_a.data_type_ == LDataType::INT64) {
735     MultiplyImpl<int64_t>(src_a, src_b, *dst, total_size);
736   } else if (src_a.data_type_ == LDataType::UINT64) {
737     MultiplyImpl<uint64_t>(src_a, src_b, *dst, total_size);
738   } else if (src_a.data_type_ == LDataType::FLOAT32) {
739     MultiplyImpl<float>(src_a, src_b, *dst, total_size);
740   } else if (src_a.data_type_ == LDataType::FLOAT64) {
741     MultiplyImpl<double>(src_a, src_b, *dst, total_size);
742   } else {
743     return false;
744   }
745   return true;
746 }
747 
748 }  // namespace dataset
749 }  // namespace mindspore
750