1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "minddata/dataset/kernels/image/lite_cv/lite_mat.h"
17
18 #include <limits>
19 #include <algorithm>
20 #include <cmath>
21 #ifdef ENABLE_NEON
22 #include <arm_neon.h>
23 #endif
24
25 namespace mindspore {
26 namespace dataset {
27
LiteMat()28 LiteMat::LiteMat() {
29 data_ptr_ = nullptr;
30 elem_size_ = 0;
31 width_ = 0;
32 height_ = 0;
33 channel_ = 0;
34 c_step_ = 0;
35 dims_ = 0;
36 size_ = 0;
37 data_type_ = LDataType::UINT8;
38 ref_count_ = nullptr;
39 setSteps(0, 0, 0);
40 release_flag = false;
41 }
42
LiteMat(int width,LDataType data_type)43 LiteMat::LiteMat(int width, LDataType data_type) {
44 data_ptr_ = nullptr;
45 elem_size_ = 0;
46 width_ = 0;
47 height_ = 0;
48 channel_ = 0;
49 c_step_ = 0;
50 dims_ = 0;
51 data_type_ = LDataType::UINT8;
52 ref_count_ = nullptr;
53 size_ = 0;
54 setSteps(0, 0, 0);
55 release_flag = false;
56 Init(width, data_type);
57 }
58
LiteMat(int width,int height,LDataType data_type)59 LiteMat::LiteMat(int width, int height, LDataType data_type) {
60 data_ptr_ = nullptr;
61 elem_size_ = 0;
62 width_ = 0;
63 height_ = 0;
64 channel_ = 0;
65 c_step_ = 0;
66 dims_ = 0;
67 data_type_ = LDataType::UINT8;
68 ref_count_ = nullptr;
69 size_ = 0;
70 setSteps(0, 0, 0);
71 release_flag = false;
72 Init(width, height, data_type);
73 }
74
LiteMat(int width,int height,void * p_data,LDataType data_type)75 LiteMat::LiteMat(int width, int height, void *p_data, LDataType data_type) {
76 data_ptr_ = nullptr;
77 elem_size_ = 0;
78 width_ = 0;
79 height_ = 0;
80 channel_ = 0;
81 c_step_ = 0;
82 dims_ = 0;
83 data_type_ = LDataType::UINT8;
84 ref_count_ = nullptr;
85 size_ = 0;
86 setSteps(0, 0, 0);
87 release_flag = false;
88 Init(width, height, p_data, data_type);
89 }
90
LiteMat(int width,int height,int channel,LDataType data_type)91 LiteMat::LiteMat(int width, int height, int channel, LDataType data_type) {
92 data_ptr_ = nullptr;
93 elem_size_ = 0;
94 width_ = 0;
95 height_ = 0;
96 channel_ = 0;
97 c_step_ = 0;
98 dims_ = 0;
99 data_type_ = LDataType::UINT8;
100 ref_count_ = nullptr;
101 size_ = 0;
102 setSteps(0, 0, 0);
103 release_flag = false;
104 Init(width, height, channel, data_type);
105 }
106
LiteMat(int width,int height,int channel,void * p_data,LDataType data_type)107 LiteMat::LiteMat(int width, int height, int channel, void *p_data, LDataType data_type) {
108 data_ptr_ = nullptr;
109 elem_size_ = 0;
110 width_ = 0;
111 height_ = 0;
112 channel_ = 0;
113 c_step_ = 0;
114 dims_ = 0;
115 data_type_ = LDataType::UINT8;
116 ref_count_ = nullptr;
117 size_ = 0;
118 setSteps(0, 0, 0);
119 release_flag = false;
120 Init(width, height, channel, p_data, data_type);
121 }
122
~LiteMat()123 LiteMat::~LiteMat() { Release(); }
124
addRef(int * p,int value)125 int LiteMat::addRef(int *p, int value) {
126 int v = *p;
127 *p += value;
128 return v;
129 }
130
// Shallow copy: shares m's buffer and bumps the shared reference count when m has one.
LiteMat::LiteMat(const LiteMat &m) {
  // Buffer ownership state.
  data_ptr_ = m.data_ptr_;
  ref_count_ = m.ref_count_;
  release_flag = m.release_flag;

  // Element type and shape.
  data_type_ = m.data_type_;
  elem_size_ = m.elem_size_;
  dims_ = m.dims_;
  width_ = m.width_;
  height_ = m.height_;
  channel_ = m.channel_;
  c_step_ = m.c_step_;
  size_ = m.size_;
  for (int i = 0; i < 3; ++i) {
    steps_[i] = m.steps_[i];
  }

  if (ref_count_ != nullptr) {
    addRef(ref_count_, 1);
  }
}
148
setSteps(int c0,int c1,int c2)149 void LiteMat::setSteps(int c0, int c1, int c2) {
150 steps_[0] = c0;
151 steps_[1] = c1;
152 steps_[2] = c2;
153 }
154
// Shallow assignment: drops this object's current buffer reference and starts sharing m's.
// Returns *this.
LiteMat &LiteMat::operator=(const LiteMat &m) {
  if (this == &m) {
    return *this;
  }

  // Take the new reference before releasing the old one. If both objects share one buffer,
  // its count then never reaches zero in between.
  if (m.ref_count_) {
    addRef(m.ref_count_, 1);
  }

  Release();
  data_ptr_ = m.data_ptr_;
  elem_size_ = m.elem_size_;
  width_ = m.width_;
  height_ = m.height_;
  channel_ = m.channel_;
  c_step_ = m.c_step_;
  dims_ = m.dims_;
  data_type_ = m.data_type_;
  ref_count_ = m.ref_count_;
  setSteps(m.steps_[0], m.steps_[1], m.steps_[2]);
  size_ = m.size_;
  // AlignFree only frees when release_flag is set. Without copying it (as the copy
  // constructor does), a buffer whose last owner is *this would never be freed.
  release_flag = m.release_flag;
  return *this;
}
178
Init(int width,LDataType data_type)179 void LiteMat::Init(int width, LDataType data_type) {
180 Release();
181 data_type_ = data_type;
182 InitElemSize(data_type);
183 width_ = width;
184 dims_ = 1;
185 height_ = 1;
186 channel_ = 1;
187 if (!CheckLiteMat()) {
188 Release();
189 return;
190 }
191 c_step_ = width;
192 size_ = c_step_ * elem_size_;
193 data_ptr_ = AlignMalloc(size_);
194 ref_count_ = new int[1];
195 *ref_count_ = 1;
196 steps_[0] = elem_size_;
197 }
198
Init(int width,int height,LDataType data_type)199 void LiteMat::Init(int width, int height, LDataType data_type) {
200 Release();
201 data_type_ = data_type;
202 InitElemSize(data_type);
203 width_ = width;
204 height_ = height;
205 dims_ = 2;
206 channel_ = 1;
207 if (!CheckLiteMat()) {
208 Release();
209 return;
210 }
211 c_step_ = width_ * height_;
212 size_ = c_step_ * elem_size_;
213 data_ptr_ = AlignMalloc(size_);
214 ref_count_ = new int[1];
215 *ref_count_ = 1;
216 steps_[1] = elem_size_;
217 steps_[0] = width_ * steps_[1];
218 }
219
Init(int width,int height,void * p_data,LDataType data_type)220 void LiteMat::Init(int width, int height, void *p_data, LDataType data_type) {
221 data_type_ = data_type;
222 InitElemSize(data_type);
223 width_ = width;
224 height_ = height;
225 dims_ = 2;
226 channel_ = 1;
227 if (!CheckLiteMat()) {
228 Release();
229 return;
230 }
231 c_step_ = height_ * width_;
232 size_ = c_step_ * channel_ * elem_size_;
233 data_ptr_ = p_data;
234 ref_count_ = nullptr;
235 steps_[1] = elem_size_;
236 steps_[0] = width_ * steps_[1];
237 }
238
Init(int width,int height,int channel,const LDataType & data_type,bool align_memory)239 void LiteMat::Init(int width, int height, int channel, const LDataType &data_type, bool align_memory) {
240 Release();
241 data_type_ = data_type;
242 InitElemSize(data_type);
243 width_ = width;
244 height_ = height;
245 dims_ = 3;
246 channel_ = channel;
247 if (!CheckLiteMat()) {
248 Release();
249 return;
250 }
251 if (align_memory) {
252 c_step_ = ((height_ * width_ * elem_size_ + ALIGN - 1) & (-ALIGN)) / elem_size_;
253 } else {
254 c_step_ = height_ * width_;
255 }
256 size_ = c_step_ * channel_ * elem_size_;
257 data_ptr_ = AlignMalloc(size_);
258 ref_count_ = new int[1];
259 *ref_count_ = 1;
260
261 steps_[2] = elem_size_;
262 steps_[1] = channel * steps_[2];
263 steps_[0] = width_ * steps_[1];
264 }
265
Init(int width,int height,int channel,void * p_data,LDataType data_type)266 void LiteMat::Init(int width, int height, int channel, void *p_data, LDataType data_type) {
267 data_type_ = data_type;
268 InitElemSize(data_type);
269 width_ = width;
270 height_ = height;
271 dims_ = 3;
272 channel_ = channel;
273 if (!CheckLiteMat()) {
274 Release();
275 return;
276 }
277 c_step_ = height_ * width_;
278 size_ = c_step_ * channel_ * elem_size_;
279 data_ptr_ = p_data;
280 ref_count_ = nullptr;
281 steps_[2] = elem_size_;
282 steps_[1] = channel * steps_[2];
283 steps_[0] = width_ * steps_[1];
284 }
285
IsEmpty() const286 bool LiteMat::IsEmpty() const { return data_ptr_ == nullptr || c_step_ * channel_ == 0; }
287
Release()288 void LiteMat::Release() {
289 if (ref_count_ && (addRef(ref_count_, -1) == 1)) {
290 if (data_ptr_) {
291 AlignFree(data_ptr_);
292 }
293 delete[] ref_count_;
294 }
295 data_ptr_ = nullptr;
296 elem_size_ = 0;
297 width_ = 0;
298 height_ = 0;
299 channel_ = 0;
300 c_step_ = 0;
301 ref_count_ = nullptr;
302 size_ = 0;
303 setSteps(0, 0, 0);
304 }
305
AlignMalloc(unsigned int size)306 void *LiteMat::AlignMalloc(unsigned int size) {
307 unsigned int length = sizeof(void *) + ALIGN - 1;
308 if (size > std::numeric_limits<uint32_t>::max() - length) {
309 return nullptr;
310 }
311 void *p_raw = reinterpret_cast<void *>(malloc(size + length));
312 if (p_raw) {
313 release_flag = true;
314 void **p_algin = reinterpret_cast<void **>(((size_t)(p_raw) + length) & ~(ALIGN - 1));
315 p_algin[-1] = p_raw;
316 return p_algin;
317 }
318 return nullptr;
319 }
320
AlignFree(void * ptr)321 void LiteMat::AlignFree(void *ptr) {
322 if (release_flag) {
323 (void)free(reinterpret_cast<void **>(ptr)[-1]);
324 ptr = nullptr;
325 release_flag = false;
326 }
327 }
328
InitElemSize(LDataType data_type)329 inline void LiteMat::InitElemSize(LDataType data_type) { elem_size_ = data_type.SizeInBytes(); }
330
CheckLiteMat()331 bool LiteMat::CheckLiteMat() {
332 if (width_ <= 0 || height_ <= 0 || channel_ <= 0 || elem_size_ <= 0) {
333 return false;
334 }
335 if (height_ != 1 && height_ > std::numeric_limits<int>::max() / width_) {
336 return false;
337 }
338 int area = height_ * width_;
339 if (channel_ != 1 && channel_ > std::numeric_limits<int>::max() / area) {
340 return false;
341 }
342 int size = area * channel_;
343 if (elem_size_ > std::numeric_limits<int>::max() / size) {
344 return false;
345 }
346 return true;
347 }
348
GetROI(int x,int y,int w,int h,LiteMat & m)349 bool LiteMat::GetROI(int x, int y, int w, int h, LiteMat &m) {
350 if (x < 0 || y < 0 || x > width_ - w || h > height_ - y || w <= 0 || h <= 0) {
351 return false;
352 }
353 if (!m.IsEmpty()) {
354 m.Release();
355 }
356
357 if (ref_count_) {
358 addRef(ref_count_, 1);
359 }
360
361 m.height_ = h;
362 m.width_ = w;
363 m.dims_ = dims_;
364 m.elem_size_ = elem_size_;
365 m.data_ptr_ = reinterpret_cast<uint8_t *>(data_ptr_) + y * steps_[0] + x * elem_size_ * channel_;
366 m.channel_ = channel_;
367 m.c_step_ = c_step_;
368 m.data_type_ = data_type_;
369 m.ref_count_ = ref_count_;
370 m.setSteps(steps_[0], steps_[1], steps_[2]);
371 return true;
372 }
373
// Element-wise dst[i] = src0[i] - src1[i] for i in [0, total_size).
// The result is converted back to T with the usual conversion rules.
template <typename T>
inline void SubtractImpl(const T *src0, const T *src1, T *dst, int64_t total_size) {
  for (int64_t idx = 0; idx < total_size; ++idx) {
    dst[idx] = static_cast<T>(src0[idx] - src1[idx]);
  }
}

// uint8_t: saturating subtraction (negative differences clamp to 0).
template <>
inline void SubtractImpl(const uint8_t *src0, const uint8_t *src1, uint8_t *dst, int64_t total_size) {
  int64_t x = 0;
#ifdef ENABLE_NEON
  // 32 bytes per iteration, as two 16-lane saturating subtractions.
  constexpr int64_t kBlock = 32;
  for (; x <= total_size - kBlock; x += kBlock) {
    const uint8x16_t lhs_lo = vld1q_u8(src0 + x);
    const uint8x16_t lhs_hi = vld1q_u8(src0 + x + 16);
    const uint8x16_t rhs_lo = vld1q_u8(src1 + x);
    const uint8x16_t rhs_hi = vld1q_u8(src1 + x + 16);
    vst1q_u8(dst + x, vqsubq_u8(lhs_lo, rhs_lo));
    vst1q_u8(dst + x + 16, vqsubq_u8(lhs_hi, rhs_hi));
  }
#endif
  // Scalar tail (or the whole range without NEON).
  for (; x < total_size; ++x) {
    dst[x] = src0[x] > src1[x] ? static_cast<uint8_t>(src0[x] - src1[x]) : 0;
  }
}

// uint16_t: saturating subtraction (negative differences clamp to 0).
template <>
inline void SubtractImpl(const uint16_t *src0, const uint16_t *src1, uint16_t *dst, int64_t total_size) {
  for (int64_t idx = 0; idx < total_size; ++idx) {
    dst[idx] = src0[idx] > src1[idx] ? static_cast<uint16_t>(src0[idx] - src1[idx]) : 0;
  }
}

// uint32_t: saturating subtraction (negative differences clamp to 0).
template <>
inline void SubtractImpl(const uint32_t *src0, const uint32_t *src1, uint32_t *dst, int64_t total_size) {
  for (int64_t idx = 0; idx < total_size; ++idx) {
    dst[idx] = src0[idx] > src1[idx] ? src0[idx] - src1[idx] : 0u;
  }
}
424
CheckSubstract(const LiteMat & src_a,const LiteMat & src_b,LiteMat * dst)425 inline bool CheckSubstract(const LiteMat &src_a, const LiteMat &src_b, LiteMat *dst) {
426 if (dst == nullptr) {
427 return false;
428 }
429
430 if (src_a.width_ != src_b.width_ || src_a.height_ != src_b.height_ || src_a.channel_ != src_b.channel_) {
431 return false;
432 }
433
434 return src_a.data_type_ == src_b.data_type_;
435 }
436
Subtract(const LiteMat & src_a,const LiteMat & src_b,LiteMat * dst)437 bool Subtract(const LiteMat &src_a, const LiteMat &src_b, LiteMat *dst) {
438 if (!CheckSubstract(src_a, src_b, dst)) {
439 return false;
440 }
441
442 if (dst->IsEmpty()) {
443 dst->Init(src_a.width_, src_a.height_, src_a.channel_, src_a.data_type_);
444 } else if (src_a.width_ != dst->width_ || src_a.height_ != dst->height_ || src_a.channel_ != dst->channel_) {
445 return false;
446 } else if (src_a.data_type_ != dst->data_type_) {
447 return false;
448 }
449
450 int64_t total_size = src_a.height_ * src_a.width_ * src_a.channel_;
451 if (src_a.data_type_ == LDataType::BOOL) {
452 SubtractImpl<bool>(src_a, src_b, *dst, total_size);
453 } else if (src_a.data_type_ == LDataType::INT8) {
454 SubtractImpl<int8_t>(src_a, src_b, *dst, total_size);
455 } else if (src_a.data_type_ == LDataType::UINT8) {
456 SubtractImpl<uint8_t>(src_a, src_b, *dst, total_size);
457 } else if (src_a.data_type_ == LDataType::INT16) {
458 SubtractImpl<int16_t>(src_a, src_b, *dst, total_size);
459 } else if (src_a.data_type_ == LDataType::UINT16) {
460 SubtractImpl<uint16_t>(src_a, src_b, *dst, total_size);
461 } else if (src_a.data_type_ == LDataType::INT32) {
462 SubtractImpl<int32_t>(src_a, src_b, *dst, total_size);
463 } else if (src_a.data_type_ == LDataType::UINT32) {
464 SubtractImpl<uint32_t>(src_a, src_b, *dst, total_size);
465 } else if (src_a.data_type_ == LDataType::INT64) {
466 SubtractImpl<int64_t>(src_a, src_b, *dst, total_size);
467 } else if (src_a.data_type_ == LDataType::UINT64) {
468 SubtractImpl<uint64_t>(src_a, src_b, *dst, total_size);
469 } else if (src_a.data_type_ == LDataType::FLOAT32) {
470 SubtractImpl<float>(src_a, src_b, *dst, total_size);
471 } else if (src_a.data_type_ == LDataType::FLOAT64) {
472 SubtractImpl<double>(src_a, src_b, *dst, total_size);
473 } else {
474 return false;
475 }
476
477 return true;
478 }
479
480 #ifdef ENABLE_NEON
reciprocal_simd(float32x4_t val)481 inline float32x4_t reciprocal_simd(float32x4_t val) {
482 // get an initial estimate of 1/val
483 float32x4_t reciprocal = vrecpeq_f32(val);
484
485 // use Newton-Raphson steps to refine the estimate
486 reciprocal = vmulq_f32(vrecpsq_f32(val, reciprocal), reciprocal);
487 reciprocal = vmulq_f32(vrecpsq_f32(val, reciprocal), reciprocal);
488 return reciprocal;
489 }
490
// Adds +0.5 or -0.5 to each lane, taking the sign bit from v. A subsequent truncating
// float->int conversion then rounds (approximately) half away from zero.
inline float32x4_t round_simd(const float32x4_t &v) {
  const int32x4_t sign_bits = vandq_s32(vdupq_n_s32(1U << 31), vreinterpretq_s32_f32(v));
  const int32x4_t half_bits = vreinterpretq_s32_f32(vdupq_n_f32(0.5f));
  const float32x4_t signed_half = vreinterpretq_f32_s32(vorrq_s32(half_bits, sign_bits));
  return vaddq_f32(v, signed_half);
}
497 #endif
498
// Element-wise dst[i] = src0[i] / src1[i] for i in [0, total_size); a zero divisor yields 0.
// Integral T uses C++ integer division (truncation toward zero).
template <typename T>
inline void DivideImpl(const T *src0, const T *src1, T *dst, int64_t total_size) {
  for (int64_t i = 0; i < total_size; i++) {
    dst[i] = src1[i] ? src0[i] / src1[i] : 0;
  }
}

// uint8_t: quotient rounded to nearest (half away from zero) and clamped to [0, 255];
// a zero divisor yields 0.
template <>
inline void DivideImpl(const uint8_t *src0, const uint8_t *src1, uint8_t *dst, int64_t total_size) {
  int64_t x = 0;
#ifdef ENABLE_NEON
  const int64_t step = 16;
  for (; x <= total_size - step; x += step) {
    __builtin_prefetch(reinterpret_cast<const char *>(src0 + x) + 32 * 10);
    __builtin_prefetch(reinterpret_cast<const char *>(src1 + x) + 32 * 10);

    uint8x16_t v_a = vld1q_u8(src0 + x);
    uint8x16_t v_b = vld1q_u8(src1 + x);
    // 0xFF in lanes with a non-zero divisor, 0x00 otherwise. ANDed in at the end so x / 0 == 0.
    uint8x16_t v_mask = vtstq_u8(v_b, v_b);

    // Widen u8 -> u16 -> u32 -> f32 in four groups of four lanes.
    uint16x8_t va_l_16x8 = vmovl_u8(vget_low_u8(v_a));
    uint16x8_t va_h_16x8 = vmovl_u8(vget_high_u8(v_a));
    uint16x8_t vb_l_16x8 = vmovl_u8(vget_low_u8(v_b));
    uint16x8_t vb_h_16x8 = vmovl_u8(vget_high_u8(v_b));

    float32x4_t va_ll_f32x4 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(va_l_16x8)));
    float32x4_t va_lh_f32x4 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(va_l_16x8)));
    float32x4_t va_hl_f32x4 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(va_h_16x8)));
    float32x4_t va_hh_f32x4 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(va_h_16x8)));
    float32x4_t vb_ll_f32x4 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(vb_l_16x8)));
    float32x4_t vb_lh_f32x4 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vb_l_16x8)));
    float32x4_t vb_hl_f32x4 = vcvtq_f32_u32(vmovl_u16(vget_low_u16(vb_h_16x8)));
    float32x4_t vb_hh_f32x4 = vcvtq_f32_u32(vmovl_u16(vget_high_u16(vb_h_16x8)));

    // a * (1 / b), then offset by +/-0.5 for rounding.
    float32x4_t vb_ll_re_f32x4 = reciprocal_simd(vb_ll_f32x4);
    float32x4_t vb_lh_re_f32x4 = reciprocal_simd(vb_lh_f32x4);
    float32x4_t vb_hl_re_f32x4 = reciprocal_simd(vb_hl_f32x4);
    float32x4_t vb_hh_re_f32x4 = reciprocal_simd(vb_hh_f32x4);

    float32x4_t dst_ll_f32x4 = round_simd(vmulq_f32(va_ll_f32x4, vb_ll_re_f32x4));
    float32x4_t dst_lh_f32x4 = round_simd(vmulq_f32(va_lh_f32x4, vb_lh_re_f32x4));
    float32x4_t dst_hl_f32x4 = round_simd(vmulq_f32(va_hl_f32x4, vb_hl_re_f32x4));
    float32x4_t dst_hh_f32x4 = round_simd(vmulq_f32(va_hh_f32x4, vb_hh_re_f32x4));

    // vcvtq_u32_f32 truncates, completing the rounding; then saturate-narrow back to u8.
    uint32x4_t dst_ll_32x4 = vcvtq_u32_f32(dst_ll_f32x4);
    uint32x4_t dst_lh_32x4 = vcvtq_u32_f32(dst_lh_f32x4);
    uint32x4_t dst_hl_32x4 = vcvtq_u32_f32(dst_hl_f32x4);
    uint32x4_t dst_hh_32x4 = vcvtq_u32_f32(dst_hh_f32x4);

    uint16x4_t dst_ll_16x4 = vqmovn_u32(dst_ll_32x4);
    uint16x4_t dst_lh_16x4 = vqmovn_u32(dst_lh_32x4);
    uint16x4_t dst_hl_16x4 = vqmovn_u32(dst_hl_32x4);
    uint16x4_t dst_hh_16x4 = vqmovn_u32(dst_hh_32x4);

    uint16x8_t dst_l_16x8 = vcombine_u16(dst_ll_16x4, dst_lh_16x4);
    uint16x8_t dst_h_16x8 = vcombine_u16(dst_hl_16x4, dst_hh_16x4);

    // vqmovn_u16 / vcombine_u8 / vandq_u8 operate on unsigned lanes.
    uint8x8_t dst_l_8x8 = vqmovn_u16(dst_l_16x8);
    uint8x8_t dst_h_8x8 = vqmovn_u16(dst_h_16x8);
    uint8x16_t dst_8x16 = vcombine_u8(dst_l_8x8, dst_h_8x8);

    dst_8x16 = vandq_u8(dst_8x16, v_mask);
    vst1q_u8(dst + x, dst_8x16);
  }
#endif
  for (; x < total_size; x++) {
    // Divide in floating point: std::round on an integer quotient would be a no-op
    // (truncation), disagreeing with the rounding NEON path above.
    int32_t val = src1[x] ? static_cast<int32_t>(std::round(static_cast<double>(src0[x]) / src1[x])) : 0;
    dst[x] = std::max<int32_t>(std::numeric_limits<uint8_t>::min(),
                               std::min<int32_t>(std::numeric_limits<uint8_t>::max(), val));
  }
}

// uint16_t: quotient rounded to nearest (half away from zero) and clamped to the uint16_t
// range; a zero divisor yields 0.
template <>
inline void DivideImpl(const uint16_t *src0, const uint16_t *src1, uint16_t *dst, int64_t total_size) {
  for (int64_t i = 0; i < total_size; i++) {
    int32_t val = src1[i] ? static_cast<int32_t>(std::round(static_cast<double>(src0[i]) / src1[i])) : 0;
    dst[i] = std::max<int32_t>(std::numeric_limits<uint16_t>::min(),
                               std::min<int32_t>(std::numeric_limits<uint16_t>::max(), val));
  }
}

// uint32_t: quotient rounded to nearest (half away from zero) and clamped to the uint32_t
// range; a zero divisor yields 0. A double represents every uint32_t quotient exactly enough
// for the .5 ties to be preserved.
template <>
inline void DivideImpl(const uint32_t *src0, const uint32_t *src1, uint32_t *dst, int64_t total_size) {
  for (int64_t i = 0; i < total_size; i++) {
    int64_t val = src1[i] ? static_cast<int64_t>(std::round(static_cast<double>(src0[i]) / src1[i])) : 0;
    dst[i] = std::max<int64_t>(std::numeric_limits<uint32_t>::min(),
                               std::min<int64_t>(std::numeric_limits<uint32_t>::max(), val));
  }
}
588
CheckDivide(const LiteMat & src_a,const LiteMat & src_b,LiteMat * dst)589 inline bool CheckDivide(const LiteMat &src_a, const LiteMat &src_b, LiteMat *dst) {
590 if (dst == nullptr) {
591 return false;
592 }
593
594 if (src_a.width_ != src_b.width_ || src_a.height_ != src_b.height_ || src_a.channel_ != src_b.channel_) {
595 return false;
596 }
597
598 return src_a.data_type_ == src_b.data_type_;
599 }
600
Divide(const LiteMat & src_a,const LiteMat & src_b,LiteMat * dst)601 bool Divide(const LiteMat &src_a, const LiteMat &src_b, LiteMat *dst) {
602 if (!CheckDivide(src_a, src_b, dst)) {
603 return false;
604 }
605
606 if (dst->IsEmpty()) {
607 dst->Init(src_a.width_, src_a.height_, src_a.channel_, src_a.data_type_);
608 } else if (src_a.width_ != dst->width_ || src_a.height_ != dst->height_ || src_a.channel_ != dst->channel_) {
609 return false;
610 } else if (src_a.data_type_ != dst->data_type_) {
611 return false;
612 }
613
614 int64_t total_size = src_a.height_ * src_a.width_ * src_a.channel_;
615 if (src_a.data_type_ == LDataType::INT8) {
616 DivideImpl<int8_t>(src_a, src_b, *dst, total_size);
617 } else if (src_a.data_type_ == LDataType::UINT8) {
618 DivideImpl<uint8_t>(src_a, src_b, *dst, total_size);
619 } else if (src_a.data_type_ == LDataType::INT16) {
620 DivideImpl<int16_t>(src_a, src_b, *dst, total_size);
621 } else if (src_a.data_type_ == LDataType::UINT16) {
622 DivideImpl<uint16_t>(src_a, src_b, *dst, total_size);
623 } else if (src_a.data_type_ == LDataType::INT32) {
624 DivideImpl<int32_t>(src_a, src_b, *dst, total_size);
625 } else if (src_a.data_type_ == LDataType::UINT32) {
626 DivideImpl<uint32_t>(src_a, src_b, *dst, total_size);
627 } else if (src_a.data_type_ == LDataType::INT64) {
628 DivideImpl<int64_t>(src_a, src_b, *dst, total_size);
629 } else if (src_a.data_type_ == LDataType::UINT64) {
630 DivideImpl<uint64_t>(src_a, src_b, *dst, total_size);
631 } else if (src_a.data_type_ == LDataType::FLOAT32) {
632 DivideImpl<float>(src_a, src_b, *dst, total_size);
633 } else if (src_a.data_type_ == LDataType::FLOAT64) {
634 DivideImpl<double>(src_a, src_b, *dst, total_size);
635 } else {
636 return false;
637 }
638 return true;
639 }
640
// Element-wise dst[i] = src0[i] * src1[i] for i in [0, total_size).
// The product is converted back to T with the usual conversion rules.
template <typename T>
inline void MultiplyImpl(const T *src0, const T *src1, T *dst, int64_t total_size) {
  for (int64_t i = 0; i < total_size; i++) {
    dst[i] = src0[i] * src1[i];
  }
}

// uint8_t: saturating multiplication (products above 255 clamp to 255).
template <>
inline void MultiplyImpl(const uint8_t *src0, const uint8_t *src1, uint8_t *dst, int64_t total_size) {
  int64_t x = 0;
#ifdef ENABLE_NEON
  const int64_t step = 32;
  for (; x <= total_size - step; x += step) {
    uint8x16_t v_src00 = vld1q_u8(src0 + x);
    uint8x16_t v_src01 = vld1q_u8(src0 + x + 16);
    uint8x16_t v_src10 = vld1q_u8(src1 + x);
    uint8x16_t v_src11 = vld1q_u8(src1 + x + 16);

    // vmull_u8 widens to 16-bit products; vqmovn_u16 narrows back, saturating at 255.
    uint16x8_t v_dst_l = vmull_u8(vget_low_u8(v_src00), vget_low_u8(v_src10));
    uint16x8_t v_dst_h = vmull_u8(vget_high_u8(v_src00), vget_high_u8(v_src10));
    vst1q_u8(dst + x, vcombine_u8(vqmovn_u16(v_dst_l), vqmovn_u16(v_dst_h)));

    v_dst_l = vmull_u8(vget_low_u8(v_src01), vget_low_u8(v_src11));
    v_dst_h = vmull_u8(vget_high_u8(v_src01), vget_high_u8(v_src11));
    vst1q_u8(dst + x + 16, vcombine_u8(vqmovn_u16(v_dst_l), vqmovn_u16(v_dst_h)));
  }
#endif
  for (; x < total_size; x++) {
    int32_t val = src0[x] * src1[x];  // at most 255 * 255, fits in int32_t
    dst[x] = std::max<int32_t>(std::numeric_limits<uint8_t>::min(),
                               std::min<int32_t>(std::numeric_limits<uint8_t>::max(), val));
  }
}

// uint16_t: saturating multiplication (products above 65535 clamp to 65535).
template <>
inline void MultiplyImpl(const uint16_t *src0, const uint16_t *src1, uint16_t *dst, int64_t total_size) {
  for (int64_t i = 0; i < total_size; i++) {
    // Widen first: uint16_t operands promote to int, and 65535 * 65535 overflows int (UB).
    uint64_t val = static_cast<uint64_t>(src0[i]) * src1[i];
    dst[i] = static_cast<uint16_t>(std::min<uint64_t>(std::numeric_limits<uint16_t>::max(), val));
  }
}

// uint32_t: saturating multiplication (products above UINT32_MAX clamp to UINT32_MAX).
template <>
inline void MultiplyImpl(const uint32_t *src0, const uint32_t *src1, uint32_t *dst, int64_t total_size) {
  for (int64_t i = 0; i < total_size; i++) {
    // Widen first: a 32-bit product would wrap before it could be clamped.
    // (2^32 - 1)^2 < 2^64, so the 64-bit product is exact.
    uint64_t val = static_cast<uint64_t>(src0[i]) * src1[i];
    dst[i] = static_cast<uint32_t>(std::min<uint64_t>(std::numeric_limits<uint32_t>::max(), val));
  }
}
693
CheckMultiply(const LiteMat & src_a,const LiteMat & src_b,LiteMat * dst)694 inline bool CheckMultiply(const LiteMat &src_a, const LiteMat &src_b, LiteMat *dst) {
695 if (dst == nullptr) {
696 return false;
697 }
698
699 if (src_a.width_ != src_b.width_ || src_a.height_ != src_b.height_ || src_a.channel_ != src_b.channel_) {
700 return false;
701 }
702
703 return src_a.data_type_ == src_b.data_type_;
704 }
705
Multiply(const LiteMat & src_a,const LiteMat & src_b,LiteMat * dst)706 bool Multiply(const LiteMat &src_a, const LiteMat &src_b, LiteMat *dst) {
707 if (!CheckMultiply(src_a, src_b, dst)) {
708 return false;
709 }
710 if (dst->IsEmpty()) {
711 dst->Init(src_a.width_, src_a.height_, src_a.channel_, src_a.data_type_);
712 } else if (src_a.width_ != dst->width_ || src_a.height_ != dst->height_ || src_a.channel_ != dst->channel_) {
713 return false;
714 } else if (src_a.data_type_ != dst->data_type_) {
715 return false;
716 }
717
718 int64_t total_size = src_a.height_ * src_a.width_ * src_a.channel_;
719 if (src_a.data_type_ == LDataType::INT8) {
720 MultiplyImpl<int8_t>(src_a, src_b, *dst, total_size);
721 } else if (src_a.data_type_ == LDataType::UINT8) {
722 MultiplyImpl<uint8_t>(src_a, src_b, *dst, total_size);
723 } else if (src_a.data_type_ == LDataType::INT16) {
724 MultiplyImpl<int16_t>(src_a, src_b, *dst, total_size);
725 } else if (src_a.data_type_ == LDataType::UINT16) {
726 MultiplyImpl<uint16_t>(src_a, src_b, *dst, total_size);
727 } else if (src_a.data_type_ == LDataType::INT32) {
728 MultiplyImpl<int32_t>(src_a, src_b, *dst, total_size);
729 } else if (src_a.data_type_ == LDataType::UINT32) {
730 MultiplyImpl<uint32_t>(src_a, src_b, *dst, total_size);
731 } else if (src_a.data_type_ == LDataType::INT64) {
732 MultiplyImpl<int64_t>(src_a, src_b, *dst, total_size);
733 } else if (src_a.data_type_ == LDataType::UINT64) {
734 MultiplyImpl<uint64_t>(src_a, src_b, *dst, total_size);
735 } else if (src_a.data_type_ == LDataType::FLOAT32) {
736 MultiplyImpl<float>(src_a, src_b, *dst, total_size);
737 } else if (src_a.data_type_ == LDataType::FLOAT64) {
738 MultiplyImpl<double>(src_a, src_b, *dst, total_size);
739 } else {
740 return false;
741 }
742 return true;
743 }
744
745 } // namespace dataset
746 } // namespace mindspore
747