1 /**
2 * Copyright 2020-2023 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "minddata/dataset/kernels/image/lite_cv/lite_mat.h"
17
18 #include <algorithm>
19 #include <cmath>
20 #include <limits>
21
22 #ifdef ENABLE_NEON
23 #include <arm_neon.h>
24 #endif
25
26 namespace mindspore {
27 namespace dataset {
LiteMat()28 LiteMat::LiteMat() {
29 data_ptr_ = nullptr;
30 elem_size_ = 0;
31 width_ = 0;
32 height_ = 0;
33 channel_ = 0;
34 c_step_ = 0;
35 dims_ = 0;
36 size_ = 0;
37 data_type_ = LDataType::UINT8;
38 ref_count_ = nullptr;
39 setSteps(0, 0, 0);
40 release_flag_ = false;
41 }
42
LiteMat(int width,LDataType data_type)43 LiteMat::LiteMat(int width, LDataType data_type) {
44 data_ptr_ = nullptr;
45 elem_size_ = 0;
46 width_ = 0;
47 height_ = 0;
48 channel_ = 0;
49 c_step_ = 0;
50 dims_ = 0;
51 data_type_ = LDataType::UINT8;
52 ref_count_ = nullptr;
53 size_ = 0;
54 setSteps(0, 0, 0);
55 release_flag_ = false;
56 Init(width, data_type);
57 }
58
LiteMat(int width,int height,LDataType data_type)59 LiteMat::LiteMat(int width, int height, LDataType data_type) {
60 data_ptr_ = nullptr;
61 elem_size_ = 0;
62 width_ = 0;
63 height_ = 0;
64 channel_ = 0;
65 c_step_ = 0;
66 dims_ = 0;
67 data_type_ = LDataType::UINT8;
68 ref_count_ = nullptr;
69 size_ = 0;
70 setSteps(0, 0, 0);
71 release_flag_ = false;
72 Init(width, height, data_type);
73 }
74
LiteMat(int width,int height,void * p_data,LDataType data_type)75 LiteMat::LiteMat(int width, int height, void *p_data, LDataType data_type) {
76 data_ptr_ = nullptr;
77 elem_size_ = 0;
78 width_ = 0;
79 height_ = 0;
80 channel_ = 0;
81 c_step_ = 0;
82 dims_ = 0;
83 data_type_ = LDataType::UINT8;
84 ref_count_ = nullptr;
85 size_ = 0;
86 setSteps(0, 0, 0);
87 release_flag_ = false;
88 Init(width, height, p_data, data_type);
89 }
90
LiteMat(int width,int height,int channel,LDataType data_type)91 LiteMat::LiteMat(int width, int height, int channel, LDataType data_type) {
92 data_ptr_ = nullptr;
93 elem_size_ = 0;
94 width_ = 0;
95 height_ = 0;
96 channel_ = 0;
97 c_step_ = 0;
98 dims_ = 0;
99 data_type_ = LDataType::UINT8;
100 ref_count_ = nullptr;
101 size_ = 0;
102 setSteps(0, 0, 0);
103 release_flag_ = false;
104 Init(width, height, channel, data_type);
105 }
106
LiteMat(int width,int height,int channel,void * p_data,LDataType data_type)107 LiteMat::LiteMat(int width, int height, int channel, void *p_data, LDataType data_type) {
108 data_ptr_ = nullptr;
109 elem_size_ = 0;
110 width_ = 0;
111 height_ = 0;
112 channel_ = 0;
113 c_step_ = 0;
114 dims_ = 0;
115 data_type_ = LDataType::UINT8;
116 ref_count_ = nullptr;
117 size_ = 0;
118 setSteps(0, 0, 0);
119 release_flag_ = false;
120 Init(width, height, channel, p_data, data_type);
121 }
122
~LiteMat()123 LiteMat::~LiteMat() { Release(); }
124
addRef(int * p,int value)125 int LiteMat::addRef(int *p, int value) {
126 int v = *p;
127 *p += value;
128 return v;
129 }
130
// Copy constructor: a shallow copy. The new object points at the same buffer and
// the same reference counter as `m`; the counter (when present) is incremented.
LiteMat::LiteMat(const LiteMat &m) {
  // Geometry and element description.
  dims_ = m.dims_;
  width_ = m.width_;
  height_ = m.height_;
  channel_ = m.channel_;
  c_step_ = m.c_step_;
  size_ = m.size_;
  elem_size_ = m.elem_size_;
  data_type_ = m.data_type_;
  setSteps(m.steps_[0], m.steps_[1], m.steps_[2]);

  // Shared storage.
  data_ptr_ = m.data_ptr_;
  release_flag_ = m.release_flag_;
  ref_count_ = m.ref_count_;
  if (ref_count_ != nullptr) {
    addRef(ref_count_, 1);
  }
}
148
setSteps(size_t c0,size_t c1,size_t c2)149 void LiteMat::setSteps(size_t c0, size_t c1, size_t c2) {
150 steps_[0] = c0;
151 steps_[1] = c1;
152 steps_[2] = c2;
153 }
154
// Copy assignment: a shallow copy that shares m's buffer and reference counter.
// Self-assignment is a no-op.
LiteMat &LiteMat::operator=(const LiteMat &m) {
  if (this != &m) {
    // Take the new reference BEFORE dropping the old one, so a buffer shared by
    // both objects cannot be freed in between.
    if (m.ref_count_ != nullptr) {
      addRef(m.ref_count_, 1);
    }
    Release();

    // Shared storage.
    data_ptr_ = m.data_ptr_;
    ref_count_ = m.ref_count_;
    release_flag_ = m.release_flag_;

    // Geometry and element description.
    data_type_ = m.data_type_;
    elem_size_ = m.elem_size_;
    dims_ = m.dims_;
    width_ = m.width_;
    height_ = m.height_;
    channel_ = m.channel_;
    c_step_ = m.c_step_;
    size_ = m.size_;
    setSteps(m.steps_[0], m.steps_[1], m.steps_[2]);
  }
  return *this;
}
179
Init(int width,LDataType data_type)180 void LiteMat::Init(int width, LDataType data_type) {
181 Release();
182 data_type_ = data_type;
183 InitElemSize(data_type);
184 width_ = width;
185 dims_ = 1;
186 height_ = 1;
187 channel_ = 1;
188 if (!CheckLiteMat()) {
189 Release();
190 return;
191 }
192 c_step_ = width;
193 size_ = c_step_ * elem_size_;
194 data_ptr_ = AlignMalloc(size_);
195 ref_count_ = new int[1];
196 *ref_count_ = 1;
197 steps_[0] = elem_size_;
198 }
199
Init(int width,int height,LDataType data_type)200 void LiteMat::Init(int width, int height, LDataType data_type) {
201 Release();
202 data_type_ = data_type;
203 InitElemSize(data_type);
204 width_ = width;
205 height_ = height;
206 dims_ = 2;
207 channel_ = 1;
208 if (!CheckLiteMat()) {
209 Release();
210 return;
211 }
212 c_step_ = width_ * height_;
213 size_ = c_step_ * elem_size_;
214 data_ptr_ = AlignMalloc(size_);
215 ref_count_ = new int[1];
216 *ref_count_ = 1;
217 steps_[1] = elem_size_;
218 steps_[0] = width_ * steps_[1];
219 }
220
Init(int width,int height,void * p_data,LDataType data_type)221 void LiteMat::Init(int width, int height, void *p_data, LDataType data_type) {
222 data_type_ = data_type;
223 InitElemSize(data_type);
224 width_ = width;
225 height_ = height;
226 dims_ = 2;
227 channel_ = 1;
228 if (!CheckLiteMat()) {
229 Release();
230 return;
231 }
232 c_step_ = height_ * width_;
233 size_ = c_step_ * channel_ * elem_size_;
234 data_ptr_ = p_data;
235 ref_count_ = nullptr;
236 steps_[1] = elem_size_;
237 steps_[0] = width_ * steps_[1];
238 }
239
Init(int width,int height,int channel,const LDataType & data_type,bool align_memory)240 void LiteMat::Init(int width, int height, int channel, const LDataType &data_type, bool align_memory) {
241 Release();
242 data_type_ = data_type;
243 InitElemSize(data_type);
244 width_ = width;
245 height_ = height;
246 dims_ = 3;
247 channel_ = channel;
248 if (!CheckLiteMat()) {
249 Release();
250 return;
251 }
252 if (align_memory) {
253 c_step_ = ((height_ * width_ * elem_size_ + kAlign - 1) & (-kAlign)) / elem_size_;
254 } else {
255 c_step_ = height_ * width_;
256 }
257 size_ = c_step_ * channel_ * elem_size_;
258 data_ptr_ = AlignMalloc(size_);
259 ref_count_ = new int[1];
260 *ref_count_ = 1;
261
262 steps_[2] = elem_size_;
263 steps_[1] = channel * steps_[2];
264 steps_[0] = width_ * steps_[1];
265 }
266
Init(int width,int height,int channel,void * p_data,LDataType data_type)267 void LiteMat::Init(int width, int height, int channel, void *p_data, LDataType data_type) {
268 data_type_ = data_type;
269 InitElemSize(data_type);
270 width_ = width;
271 height_ = height;
272 dims_ = 3;
273 channel_ = channel;
274 if (!CheckLiteMat()) {
275 Release();
276 return;
277 }
278 c_step_ = height_ * width_;
279 size_ = c_step_ * channel_ * elem_size_;
280 data_ptr_ = p_data;
281 ref_count_ = nullptr;
282 steps_[2] = elem_size_;
283 steps_[1] = channel * steps_[2];
284 steps_[0] = width_ * steps_[1];
285 }
286
IsEmpty() const287 bool LiteMat::IsEmpty() const { return data_ptr_ == nullptr || c_step_ * channel_ == 0; }
288
Release()289 void LiteMat::Release() {
290 if (ref_count_ && (addRef(ref_count_, -1) == 1)) {
291 if (data_ptr_) {
292 AlignFree(data_ptr_);
293 }
294 delete[] ref_count_;
295 }
296 data_ptr_ = nullptr;
297 elem_size_ = 0;
298 width_ = 0;
299 height_ = 0;
300 channel_ = 0;
301 c_step_ = 0;
302 ref_count_ = nullptr;
303 size_ = 0;
304 setSteps(0, 0, 0);
305 }
306
AlignMalloc(unsigned int size)307 void *LiteMat::AlignMalloc(unsigned int size) {
308 unsigned int length = sizeof(void *) + kAlign - 1;
309 if (size > std::numeric_limits<uint32_t>::max() - length) {
310 return nullptr;
311 }
312 void *p_raw = reinterpret_cast<void *>(malloc(size + length));
313 if (p_raw) {
314 release_flag_ = true;
315 void **p_algin = reinterpret_cast<void **>((reinterpret_cast<size_t>(p_raw) + length) & ~(kAlign - 1));
316 p_algin[-1] = p_raw;
317 return p_algin;
318 }
319 return nullptr;
320 }
321
AlignFree(void * ptr)322 void LiteMat::AlignFree(void *ptr) {
323 if (release_flag_) {
324 (void)free(reinterpret_cast<void **>(ptr)[-1]);
325 ptr = nullptr;
326 release_flag_ = false;
327 }
328 }
329
InitElemSize(LDataType data_type)330 inline void LiteMat::InitElemSize(LDataType data_type) { elem_size_ = data_type.SizeInBytes(); }
331
CheckLiteMat() const332 bool LiteMat::CheckLiteMat() const {
333 if (width_ <= 0 || height_ <= 0 || channel_ <= 0 || elem_size_ <= 0) {
334 return false;
335 }
336 if (height_ != 1 && height_ > std::numeric_limits<int>::max() / width_) {
337 return false;
338 }
339 int area = height_ * width_;
340 if (channel_ != 1 && channel_ > std::numeric_limits<int>::max() / area) {
341 return false;
342 }
343 int size = area * channel_;
344 if (elem_size_ > std::numeric_limits<int>::max() / size) {
345 return false;
346 }
347 return true;
348 }
349
GetROI(int x,int y,int w,int h,LiteMat & m)350 bool LiteMat::GetROI(int x, int y, int w, int h, LiteMat &m) {
351 if (x < 0 || y < 0 || x > width_ - w || h > height_ - y || w <= 0 || h <= 0) {
352 return false;
353 }
354 if (!m.IsEmpty()) {
355 m.Release();
356 }
357
358 if (ref_count_) {
359 addRef(ref_count_, 1);
360 }
361
362 m.height_ = h;
363 m.width_ = w;
364 m.dims_ = dims_;
365 m.elem_size_ = elem_size_;
366 m.data_ptr_ = reinterpret_cast<uint8_t *>(data_ptr_) + y * steps_[0] + x * elem_size_ * channel_;
367 m.channel_ = channel_;
368 m.c_step_ = c_step_;
369 m.data_type_ = data_type_;
370 m.ref_count_ = ref_count_;
371 m.setSteps(steps_[0], steps_[1], steps_[2]);
372 return true;
373 }
374
// Element-wise dst = src0 - src1 over `total_size` elements.
// Generic version: plain subtraction in T's arithmetic (no saturation).
template <typename T>
inline void SubtractImpl(const T *src0, const T *src1, T *dst, int64_t total_size) {
  for (int64_t left = total_size; left > 0; --left, ++src0, ++src1, ++dst) {
    *dst = *src0 - *src1;
  }
}

// uint8 version: saturating subtraction (results below 0 become 0).
// With NEON enabled, 32 elements are processed per iteration and the remainder
// falls through to the scalar loop.
template <>
inline void SubtractImpl(const uint8_t *src0, const uint8_t *src1, uint8_t *dst, int64_t total_size) {
  int64_t pos = 0;
#ifdef ENABLE_NEON
  const int64_t block = 32;
  for (; pos <= total_size - block; pos += block) {
    // Load both 16-byte halves of each operand before storing anything.
    const uint8x16_t lhs_lo = vld1q_u8(src0 + pos);
    const uint8x16_t lhs_hi = vld1q_u8(src0 + pos + 16);
    const uint8x16_t rhs_lo = vld1q_u8(src1 + pos);
    const uint8x16_t rhs_hi = vld1q_u8(src1 + pos + 16);

    // vqsubq_u8 is the saturating subtract.
    vst1q_u8(dst + pos, vqsubq_u8(lhs_lo, rhs_lo));
    vst1q_u8(dst + pos + 16, vqsubq_u8(lhs_hi, rhs_hi));
  }
#endif
  for (; pos < total_size; pos++) {
    const uint8_t lhs = src0[pos];
    const uint8_t rhs = src1[pos];
    dst[pos] = lhs > rhs ? static_cast<uint8_t>(lhs - rhs) : 0;
  }
}

// uint16 version: saturating subtraction (results below 0 become 0).
template <>
inline void SubtractImpl(const uint16_t *src0, const uint16_t *src1, uint16_t *dst, int64_t total_size) {
  for (int64_t pos = 0; pos < total_size; pos++) {
    const uint16_t lhs = src0[pos];
    const uint16_t rhs = src1[pos];
    dst[pos] = lhs > rhs ? static_cast<uint16_t>(lhs - rhs) : 0;
  }
}

// uint32 version: saturating subtraction (results below 0 become 0).
template <>
inline void SubtractImpl(const uint32_t *src0, const uint32_t *src1, uint32_t *dst, int64_t total_size) {
  for (int64_t pos = 0; pos < total_size; pos++) {
    const uint32_t lhs = src0[pos];
    const uint32_t rhs = src1[pos];
    dst[pos] = lhs > rhs ? lhs - rhs : 0u;
  }
}
425
CheckSubstract(const LiteMat & src_a,const LiteMat & src_b,LiteMat * dst)426 inline bool CheckSubstract(const LiteMat &src_a, const LiteMat &src_b, LiteMat *dst) {
427 if (dst == nullptr) {
428 return false;
429 }
430
431 if (src_a.width_ != src_b.width_ || src_a.height_ != src_b.height_ || src_a.channel_ != src_b.channel_) {
432 return false;
433 }
434
435 return src_a.data_type_ == src_b.data_type_;
436 }
437
Subtract(const LiteMat & src_a,const LiteMat & src_b,LiteMat * dst)438 bool Subtract(const LiteMat &src_a, const LiteMat &src_b, LiteMat *dst) {
439 if (!CheckSubstract(src_a, src_b, dst)) {
440 return false;
441 }
442
443 if (dst->IsEmpty()) {
444 dst->Init(src_a.width_, src_a.height_, src_a.channel_, src_a.data_type_);
445 }
446 if (src_a.width_ != dst->width_ || src_a.height_ != dst->height_ || src_a.channel_ != dst->channel_) {
447 return false;
448 }
449 if (src_a.data_type_ != dst->data_type_) {
450 return false;
451 }
452
453 int64_t total_size = src_a.height_ * src_a.width_ * src_a.channel_;
454 if (src_a.data_type_ == LDataType::BOOL) {
455 SubtractImpl<bool>(src_a, src_b, *dst, total_size);
456 } else if (src_a.data_type_ == LDataType::INT8) {
457 SubtractImpl<int8_t>(src_a, src_b, *dst, total_size);
458 } else if (src_a.data_type_ == LDataType::UINT8) {
459 SubtractImpl<uint8_t>(src_a, src_b, *dst, total_size);
460 } else if (src_a.data_type_ == LDataType::INT16) {
461 SubtractImpl<int16_t>(src_a, src_b, *dst, total_size);
462 } else if (src_a.data_type_ == LDataType::UINT16) {
463 SubtractImpl<uint16_t>(src_a, src_b, *dst, total_size);
464 } else if (src_a.data_type_ == LDataType::INT32) {
465 SubtractImpl<int32_t>(src_a, src_b, *dst, total_size);
466 } else if (src_a.data_type_ == LDataType::UINT32) {
467 SubtractImpl<uint32_t>(src_a, src_b, *dst, total_size);
468 } else if (src_a.data_type_ == LDataType::INT64) {
469 SubtractImpl<int64_t>(src_a, src_b, *dst, total_size);
470 } else if (src_a.data_type_ == LDataType::UINT64) {
471 SubtractImpl<uint64_t>(src_a, src_b, *dst, total_size);
472 } else if (src_a.data_type_ == LDataType::FLOAT32) {
473 SubtractImpl<float>(src_a, src_b, *dst, total_size);
474 } else if (src_a.data_type_ == LDataType::FLOAT64) {
475 SubtractImpl<double>(src_a, src_b, *dst, total_size);
476 } else {
477 return false;
478 }
479
480 return true;
481 }
482
483 #ifdef ENABLE_NEON
reciprocal_simd(float32x4_t val)484 inline float32x4_t reciprocal_simd(float32x4_t val) {
485 // get an initial estimate of 1/val
486 float32x4_t reciprocal = vrecpeq_f32(val);
487
488 // use Newton-Raphson steps to refine the estimate
489 reciprocal = vmulq_f32(vrecpsq_f32(val, reciprocal), reciprocal);
490 reciprocal = vmulq_f32(vrecpsq_f32(val, reciprocal), reciprocal);
491 return reciprocal;
492 }
493
// Adds 0.5 carrying each lane's own sign to that lane (i.e. +0.5 for non-negative
// lanes, -0.5 for negative ones). The result is still a float vector; callers are
// expected to follow up with a truncating float-to-integer conversion to obtain
// round-half-away-from-zero.
inline float32x4_t round_simd(const float32x4_t &v) {
  const int32x4_t sign_bit = vdupq_n_s32(1U << 31);
  const int32x4_t half_bits = vreinterpretq_s32_f32(vdupq_n_f32(0.5f));

  // Isolate the sign of every lane and stamp it onto the bit pattern of 0.5f.
  const int32x4_t lane_sign = vandq_s32(sign_bit, vreinterpretq_s32_f32(v));
  const float32x4_t signed_half = vreinterpretq_f32_s32(vorrq_s32(half_bits, lane_sign));

  return vaddq_f32(v, signed_half);
}
500 #endif
501
// Element-wise dst = src0 / src1 over `total_size` elements.
// Generic version: uses T's own division (truncating for integers). A zero
// divisor yields 0. For signed integer types the single overflowing case,
// lowest / -1, saturates to max instead of invoking undefined behaviour (which
// traps on some targets).
template <typename T>
inline void DivideImpl(const T *src0, const T *src1, T *dst, int64_t total_size) {
  // Only signed integers can overflow on division.
  constexpr bool kCanOverflow = std::numeric_limits<T>::is_integer && std::numeric_limits<T>::is_signed;
  for (int64_t i = 0; i < total_size; i++) {
    if (!src1[i]) {
      dst[i] = 0;
    } else if (kCanOverflow && src1[i] == static_cast<T>(-1) && src0[i] == std::numeric_limits<T>::lowest()) {
      dst[i] = std::numeric_limits<T>::max();
    } else {
      dst[i] = src0[i] / src1[i];
    }
  }
}

// uint8 version: quotient rounded to the nearest integer (halves round up); a zero
// divisor yields 0. With NEON enabled, 16 elements are processed per iteration in
// float arithmetic and the remainder falls through to the scalar loop.
template <>
inline void DivideImpl(const uint8_t *src0, const uint8_t *src1, uint8_t *dst, int64_t total_size) {
  int64_t x = 0;
#ifdef ENABLE_NEON
  const int64_t step = 16;
  for (; x <= total_size - step; x += step) {
    __builtin_prefetch(reinterpret_cast<const char *>(src0 + x) + 32 * 10);
    __builtin_prefetch(reinterpret_cast<const char *>(src1 + x) + 32 * 10);

    uint8x16_t v_a = vld1q_u8(src0 + x);
    uint8x16_t v_b = vld1q_u8(src1 + x);
    // All-ones in lanes whose divisor is non-zero; used at the end to force the
    // result of a division by zero to 0.
    uint8x16_t v_mask = vtstq_u8(v_b, v_b);

    // Widen u8 -> u16 -> u32 -> f32 in four groups of four lanes.
    uint16x8_t va_l_16x8 = vmovl_u8(vget_low_u8(v_a));
    uint16x8_t va_h_16x8 = vmovl_u8(vget_high_u8(v_a));
    uint16x8_t vb_l_16x8 = vmovl_u8(vget_low_u8(v_b));
    uint16x8_t vb_h_16x8 = vmovl_u8(vget_high_u8(v_b));

    float32x4_t va_ll_f32x4 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(va_l_16x8)));
    float32x4_t va_lh_f32x4 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(va_l_16x8)));
    float32x4_t va_hl_f32x4 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(va_h_16x8)));
    float32x4_t va_hh_f32x4 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(va_h_16x8)));
    float32x4_t vb_ll_f32x4 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(vb_l_16x8)));
    float32x4_t vb_lh_f32x4 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vb_l_16x8)));
    float32x4_t vb_hl_f32x4 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(vb_h_16x8)));
    float32x4_t vb_hh_f32x4 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vb_h_16x8)));

    // a / b computed as a * (1 / b).
    float32x4_t vb_ll_re_f32x4 = reciprocal_simd(vb_ll_f32x4);
    float32x4_t vb_lh_re_f32x4 = reciprocal_simd(vb_lh_f32x4);
    float32x4_t vb_hl_re_f32x4 = reciprocal_simd(vb_hl_f32x4);
    float32x4_t vb_hh_re_f32x4 = reciprocal_simd(vb_hh_f32x4);

    float32x4_t dst_ll_f32x4 = round_simd(vmulq_f32(va_ll_f32x4, vb_ll_re_f32x4));
    float32x4_t dst_lh_f32x4 = round_simd(vmulq_f32(va_lh_f32x4, vb_lh_re_f32x4));
    float32x4_t dst_hl_f32x4 = round_simd(vmulq_f32(va_hl_f32x4, vb_hl_re_f32x4));
    float32x4_t dst_hh_f32x4 = round_simd(vmulq_f32(va_hh_f32x4, vb_hh_re_f32x4));

    uint32x4_t dst_ll_32x4 = vcvtq_u32_f32(dst_ll_f32x4);
    uint32x4_t dst_lh_32x4 = vcvtq_u32_f32(dst_lh_f32x4);
    uint32x4_t dst_hl_32x4 = vcvtq_u32_f32(dst_hl_f32x4);
    uint32x4_t dst_hh_32x4 = vcvtq_u32_f32(dst_hh_f32x4);

    // Narrow back u32 -> u16 -> u8 with saturation.
    uint16x4_t dst_ll_16x4 = vqmovn_u32(dst_ll_32x4);
    uint16x4_t dst_lh_16x4 = vqmovn_u32(dst_lh_32x4);
    uint16x4_t dst_hl_16x4 = vqmovn_u32(dst_hl_32x4);
    uint16x4_t dst_hh_16x4 = vqmovn_u32(dst_hh_32x4);

    uint16x8_t dst_l_16x8 = vcombine_u16(dst_ll_16x4, dst_lh_16x4);
    uint16x8_t dst_h_16x8 = vcombine_u16(dst_hl_16x4, dst_hh_16x4);

    uint8x8_t dst_l_8x8 = vqmovn_u16(dst_l_16x8);
    uint8x8_t dst_h_8x8 = vqmovn_u16(dst_h_16x8);
    uint8x16_t dst_8x16 = vcombine_u8(dst_l_8x8, dst_h_8x8);

    dst_8x16 = vandq_u8(dst_8x16, v_mask);
    vst1q_u8(dst + x, dst_8x16);
  }
#endif
  for (; x < total_size; x++) {
    // Rounded division in integer arithmetic: adding half the divisor before
    // dividing rounds to nearest. (Rounding the already-truncated integer quotient,
    // as done previously, was a no-op and disagreed with the vector path.)
    // The result never exceeds src0[x], so it always fits in uint8_t.
    const int32_t divisor = src1[x];
    dst[x] = divisor ? static_cast<uint8_t>((src0[x] + divisor / 2) / divisor) : 0;
  }
}

// uint16 version: quotient rounded to the nearest integer (halves round up); a
// zero divisor yields 0.
template <>
inline void DivideImpl(const uint16_t *src0, const uint16_t *src1, uint16_t *dst, int64_t total_size) {
  for (int64_t i = 0; i < total_size; i++) {
    // The result never exceeds src0[i], so it always fits in uint16_t.
    const int32_t divisor = src1[i];
    dst[i] = divisor ? static_cast<uint16_t>((src0[i] + divisor / 2) / divisor) : 0;
  }
}

// uint32 version: quotient rounded to the nearest integer (halves round up); a
// zero divisor yields 0.
template <>
inline void DivideImpl(const uint32_t *src0, const uint32_t *src1, uint32_t *dst, int64_t total_size) {
  for (int64_t i = 0; i < total_size; i++) {
    // 64-bit intermediate so that src0 + divisor / 2 cannot wrap.
    // The result never exceeds src0[i], so it always fits in uint32_t.
    const uint64_t divisor = src1[i];
    dst[i] = divisor ? static_cast<uint32_t>((src0[i] + divisor / 2) / divisor) : 0;
  }
}
591
CheckDivide(const LiteMat & src_a,const LiteMat & src_b,LiteMat * dst)592 inline bool CheckDivide(const LiteMat &src_a, const LiteMat &src_b, LiteMat *dst) {
593 if (dst == nullptr) {
594 return false;
595 }
596
597 if (src_a.width_ != src_b.width_ || src_a.height_ != src_b.height_ || src_a.channel_ != src_b.channel_) {
598 return false;
599 }
600
601 return src_a.data_type_ == src_b.data_type_;
602 }
603
Divide(const LiteMat & src_a,const LiteMat & src_b,LiteMat * dst)604 bool Divide(const LiteMat &src_a, const LiteMat &src_b, LiteMat *dst) {
605 if (!CheckDivide(src_a, src_b, dst)) {
606 return false;
607 }
608
609 if (dst->IsEmpty()) {
610 dst->Init(src_a.width_, src_a.height_, src_a.channel_, src_a.data_type_);
611 } else if (src_a.width_ != dst->width_ || src_a.height_ != dst->height_ || src_a.channel_ != dst->channel_) {
612 return false;
613 } else if (src_a.data_type_ != dst->data_type_) {
614 return false;
615 }
616
617 int64_t total_size = src_a.height_ * src_a.width_ * src_a.channel_;
618 if (src_a.data_type_ == LDataType::INT8) {
619 DivideImpl<int8_t>(src_a, src_b, *dst, total_size);
620 } else if (src_a.data_type_ == LDataType::UINT8) {
621 DivideImpl<uint8_t>(src_a, src_b, *dst, total_size);
622 } else if (src_a.data_type_ == LDataType::INT16) {
623 DivideImpl<int16_t>(src_a, src_b, *dst, total_size);
624 } else if (src_a.data_type_ == LDataType::UINT16) {
625 DivideImpl<uint16_t>(src_a, src_b, *dst, total_size);
626 } else if (src_a.data_type_ == LDataType::INT32) {
627 DivideImpl<int32_t>(src_a, src_b, *dst, total_size);
628 } else if (src_a.data_type_ == LDataType::UINT32) {
629 DivideImpl<uint32_t>(src_a, src_b, *dst, total_size);
630 } else if (src_a.data_type_ == LDataType::INT64) {
631 DivideImpl<int64_t>(src_a, src_b, *dst, total_size);
632 } else if (src_a.data_type_ == LDataType::UINT64) {
633 DivideImpl<uint64_t>(src_a, src_b, *dst, total_size);
634 } else if (src_a.data_type_ == LDataType::FLOAT32) {
635 DivideImpl<float>(src_a, src_b, *dst, total_size);
636 } else if (src_a.data_type_ == LDataType::FLOAT64) {
637 DivideImpl<double>(src_a, src_b, *dst, total_size);
638 } else {
639 return false;
640 }
641 return true;
642 }
643
// Element-wise dst = src0 * src1 over `total_size` elements.
// Generic version: plain multiplication in T's arithmetic (no saturation).
template <typename T>
inline void MultiplyImpl(const T *src0, const T *src1, T *dst, int64_t total_size) {
  for (int64_t i = 0; i < total_size; i++) {
    dst[i] = src0[i] * src1[i];
  }
}

// uint8 version: saturating multiplication (products above 255 become 255).
// With NEON enabled, 32 elements are processed per iteration and the remainder
// falls through to the scalar loop.
template <>
inline void MultiplyImpl(const uint8_t *src0, const uint8_t *src1, uint8_t *dst, int64_t total_size) {
  int64_t x = 0;
#ifdef ENABLE_NEON
  const int64_t step = 32;
  for (; x <= total_size - step; x += step) {
    uint8x16_t v_src00 = vld1q_u8(src0 + x);
    uint8x16_t v_src01 = vld1q_u8(src0 + x + 16);
    uint8x16_t v_src10 = vld1q_u8(src1 + x);
    uint8x16_t v_src11 = vld1q_u8(src1 + x + 16);
    uint16x8_t v_dst_l, v_dst_h;

    // Widening multiply to 16 bits, then saturating narrow back to 8 bits.
    v_dst_l = vmull_u8(vget_low_u8(v_src00), vget_low_u8(v_src10));
    v_dst_h = vmull_u8(vget_high_u8(v_src00), vget_high_u8(v_src10));
    vst1q_u8(dst + x, vcombine_u8(vqmovn_u16(v_dst_l), vqmovn_u16(v_dst_h)));

    v_dst_l = vmull_u8(vget_low_u8(v_src01), vget_low_u8(v_src11));
    v_dst_h = vmull_u8(vget_high_u8(v_src01), vget_high_u8(v_src11));
    vst1q_u8(dst + x + 16, vcombine_u8(vqmovn_u16(v_dst_l), vqmovn_u16(v_dst_h)));
  }
#endif
  for (; x < total_size; x++) {
    // At most 255 * 255, which fits comfortably in 32 bits.
    const uint32_t product = static_cast<uint32_t>(src0[x]) * src1[x];
    dst[x] = static_cast<uint8_t>(std::min<uint32_t>(product, std::numeric_limits<uint8_t>::max()));
  }
}

// uint16 version: saturating multiplication (products above 65535 become 65535).
template <>
inline void MultiplyImpl(const uint16_t *src0, const uint16_t *src1, uint16_t *dst, int64_t total_size) {
  for (int64_t i = 0; i < total_size; i++) {
    // Multiply as unsigned 32-bit: two uint16_t operands are promoted to int, and
    // 65535 * 65535 does not fit in a signed 32-bit int.
    const uint32_t product = static_cast<uint32_t>(src0[i]) * src1[i];
    dst[i] = static_cast<uint16_t>(std::min<uint32_t>(product, std::numeric_limits<uint16_t>::max()));
  }
}

// uint32 version: saturating multiplication (products above UINT32_MAX become
// UINT32_MAX).
template <>
inline void MultiplyImpl(const uint32_t *src0, const uint32_t *src1, uint32_t *dst, int64_t total_size) {
  for (int64_t i = 0; i < total_size; i++) {
    // Widen BEFORE multiplying; a 32-bit product would wrap before it could be
    // clamped. The largest possible product fits in unsigned 64 bits.
    const uint64_t product = static_cast<uint64_t>(src0[i]) * src1[i];
    dst[i] = static_cast<uint32_t>(std::min<uint64_t>(product, std::numeric_limits<uint32_t>::max()));
  }
}
696
CheckMultiply(const LiteMat & src_a,const LiteMat & src_b,LiteMat * dst)697 inline bool CheckMultiply(const LiteMat &src_a, const LiteMat &src_b, LiteMat *dst) {
698 if (dst == nullptr) {
699 return false;
700 }
701
702 if (src_a.width_ != src_b.width_ || src_a.height_ != src_b.height_ || src_a.channel_ != src_b.channel_) {
703 return false;
704 }
705
706 return src_a.data_type_ == src_b.data_type_;
707 }
708
Multiply(const LiteMat & src_a,const LiteMat & src_b,LiteMat * dst)709 bool Multiply(const LiteMat &src_a, const LiteMat &src_b, LiteMat *dst) {
710 if (!CheckMultiply(src_a, src_b, dst)) {
711 return false;
712 }
713 if (dst->IsEmpty()) {
714 dst->Init(src_a.width_, src_a.height_, src_a.channel_, src_a.data_type_);
715 } else if (src_a.width_ != dst->width_ || src_a.height_ != dst->height_ || src_a.channel_ != dst->channel_) {
716 return false;
717 } else if (src_a.data_type_ != dst->data_type_) {
718 return false;
719 }
720
721 int64_t total_size = src_a.height_ * src_a.width_ * src_a.channel_;
722 if (src_a.data_type_ == LDataType::INT8) {
723 MultiplyImpl<int8_t>(src_a, src_b, *dst, total_size);
724 } else if (src_a.data_type_ == LDataType::UINT8) {
725 MultiplyImpl<uint8_t>(src_a, src_b, *dst, total_size);
726 } else if (src_a.data_type_ == LDataType::INT16) {
727 MultiplyImpl<int16_t>(src_a, src_b, *dst, total_size);
728 } else if (src_a.data_type_ == LDataType::UINT16) {
729 MultiplyImpl<uint16_t>(src_a, src_b, *dst, total_size);
730 } else if (src_a.data_type_ == LDataType::INT32) {
731 MultiplyImpl<int32_t>(src_a, src_b, *dst, total_size);
732 } else if (src_a.data_type_ == LDataType::UINT32) {
733 MultiplyImpl<uint32_t>(src_a, src_b, *dst, total_size);
734 } else if (src_a.data_type_ == LDataType::INT64) {
735 MultiplyImpl<int64_t>(src_a, src_b, *dst, total_size);
736 } else if (src_a.data_type_ == LDataType::UINT64) {
737 MultiplyImpl<uint64_t>(src_a, src_b, *dst, total_size);
738 } else if (src_a.data_type_ == LDataType::FLOAT32) {
739 MultiplyImpl<float>(src_a, src_b, *dst, total_size);
740 } else if (src_a.data_type_ == LDataType::FLOAT64) {
741 MultiplyImpl<double>(src_a, src_b, *dst, total_size);
742 } else {
743 return false;
744 }
745 return true;
746 }
747
748 } // namespace dataset
749 } // namespace mindspore
750