1 /**
2 * Copyright 2021 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #include "src/runtime/kernel/arm/fp32/instance_norm_fp32.h"
17 #include "schema/model_generated.h"
18 #include "src/kernel_registry.h"
19 #include "include/errorcode.h"
20 #include "nnacl/fp32/instance_norm_fp32.h"
21 #include "nnacl/fp32/pack_fp32.h"
22
23 using mindspore::kernel::KERNEL_ARCH;
24 using mindspore::lite::KernelRegistrar;
25 using mindspore::lite::RET_ERROR;
26 using mindspore::lite::RET_OK;
27 using mindspore::schema::PrimitiveType_InstanceNorm;
28
29 namespace mindspore::kernel {
Init()30 int InstanceNormCPUKernel::Init() {
31 CHECK_LESS_RETURN(in_tensors_.size(), DIMENSION_3D);
32 CHECK_LESS_RETURN(out_tensors_.size(), 1);
33 if (!InferShapeDone()) {
34 return RET_OK;
35 }
36 return ReSize();
37 }
38
ReSize()39 int InstanceNormCPUKernel::ReSize() {
40 auto in_tensor = in_tensors_.front();
41 param_->batch_ = in_tensor->Batch();
42 param_->inner_size_ = in_tensor->Height() * in_tensor->Width();
43 param_->channel_ = in_tensor->Channel();
44 param_->op_parameter_.thread_num_ = MSMIN(UP_DIV(param_->channel_, C8NUM), op_parameter_->thread_num_);
45 return RET_OK;
46 }
47
DoInstanceNorm(int task_id)48 int InstanceNormCPUKernel::DoInstanceNorm(int task_id) {
49 int ret = 0;
50 if (in_tensors_[0]->format() == NC4HW4) { // arm64 x86-avx x86-sse x86
51 #ifdef ENABLE_AVX
52 ret = InstanceNormNC8HW8(tmp_src_data_, dst_data_, gamma_data_, beta_data_, param_, task_id);
53 #else
54 ret = InstanceNormNC4HW4(tmp_src_data_, dst_data_, gamma_data_, beta_data_, param_, task_id);
55 #endif
56 } else {
57 ret = InstanceNorm(tmp_src_data_, dst_data_, gamma_data_, beta_data_, param_, task_id);
58 }
59 if (ret != RET_OK) {
60 MS_LOG(ERROR) << "DoInstanceNorm error error_code[" << ret << "]";
61 return ret;
62 }
63 return RET_OK;
64 }
65
InstanceNormRun(void * cdata,int task_id,float lhs_scale,float rhs_scale)66 int InstanceNormRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
67 auto kernel = reinterpret_cast<InstanceNormCPUKernel *>(cdata);
68 auto ret = kernel->DoInstanceNorm(task_id);
69 if (ret != RET_OK) {
70 MS_LOG(ERROR) << "InstanceNormRun error task_id[" << task_id << "] error_code[" << ret << "]";
71 return RET_ERROR;
72 }
73 return RET_OK;
74 }
75
Run()76 int InstanceNormCPUKernel::Run() {
77 src_data_ = reinterpret_cast<float *>(in_tensors_.at(0)->data());
78 gamma_data_ = reinterpret_cast<float *>(in_tensors_.at(1)->data());
79 beta_data_ = reinterpret_cast<float *>(in_tensors_.at(2)->data());
80 dst_data_ = reinterpret_cast<float *>(out_tensors_.at(0)->data());
81 CHECK_NULL_RETURN(src_data_);
82 CHECK_NULL_RETURN(gamma_data_);
83 CHECK_NULL_RETURN(beta_data_);
84 CHECK_NULL_RETURN(dst_data_);
85 if (in_tensors_[0]->format() == NC4HW4) {
86 #if defined(ENABLE_AVX) || defined(ENABLE_ARM64)
87 tmp_src_data_ = src_data_;
88 #else // other platform is not support nc4hw4 and must be pack to nc4hw4
89 tmp_src_data_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc(in_tensors_[0]->Size()));
90 CHECK_NULL_RETURN(tmp_src_data_);
91 PackNHWCToNC4HW4Fp32(src_data_, tmp_src_data_, param_->batch_, param_->inner_size_, param_->channel_);
92 #endif
93 } else {
94 tmp_src_data_ = src_data_;
95 }
96 auto ret = ParallelLaunch(this->ms_context_, InstanceNormRun, this, op_parameter_->thread_num_);
97 if (ret != RET_OK) {
98 MS_LOG(ERROR) << "InstanceNormRun error error_code[" << ret << "]";
99 }
100 if (in_tensors_[0]->format() == NC4HW4) {
101 #if (!defined(ENABLE_AVX) && !defined(ENABLE_ARM64))
102 FreeTmpBuffer();
103 #endif
104 }
105 return ret;
106 }
107
108 REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_InstanceNorm, LiteKernelCreator<InstanceNormCPUKernel>)
109 } // namespace mindspore::kernel
110