• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "src/runtime/kernel/arm/fp32/instance_norm_fp32.h"
17 #include "schema/model_generated.h"
18 #include "src/kernel_registry.h"
19 #include "include/errorcode.h"
20 #include "nnacl/fp32/instance_norm_fp32.h"
21 #include "nnacl/fp32/pack_fp32.h"
22 
23 using mindspore::kernel::KERNEL_ARCH;
24 using mindspore::lite::KernelRegistrar;
25 using mindspore::lite::RET_ERROR;
26 using mindspore::lite::RET_OK;
27 using mindspore::schema::PrimitiveType_InstanceNorm;
28 
29 namespace mindspore::kernel {
Init()30 int InstanceNormCPUKernel::Init() {
31   CHECK_LESS_RETURN(in_tensors_.size(), DIMENSION_3D);
32   CHECK_LESS_RETURN(out_tensors_.size(), 1);
33   if (!InferShapeDone()) {
34     return RET_OK;
35   }
36   return ReSize();
37 }
38 
ReSize()39 int InstanceNormCPUKernel::ReSize() {
40   auto in_tensor = in_tensors_.front();
41   param_->batch_ = in_tensor->Batch();
42   param_->inner_size_ = in_tensor->Height() * in_tensor->Width();
43   param_->channel_ = in_tensor->Channel();
44   param_->op_parameter_.thread_num_ = MSMIN(UP_DIV(param_->channel_, C8NUM), op_parameter_->thread_num_);
45   return RET_OK;
46 }
47 
DoInstanceNorm(int task_id)48 int InstanceNormCPUKernel::DoInstanceNorm(int task_id) {
49   int ret = 0;
50   if (in_tensors_[0]->format() == NC4HW4) {  // arm64 x86-avx x86-sse x86
51 #ifdef ENABLE_AVX
52     ret = InstanceNormNC8HW8(tmp_src_data_, dst_data_, gamma_data_, beta_data_, param_, task_id);
53 #else
54     ret = InstanceNormNC4HW4(tmp_src_data_, dst_data_, gamma_data_, beta_data_, param_, task_id);
55 #endif
56   } else {
57     ret = InstanceNorm(tmp_src_data_, dst_data_, gamma_data_, beta_data_, param_, task_id);
58   }
59   if (ret != RET_OK) {
60     MS_LOG(ERROR) << "DoInstanceNorm error error_code[" << ret << "]";
61     return ret;
62   }
63   return RET_OK;
64 }
65 
InstanceNormRun(void * cdata,int task_id,float lhs_scale,float rhs_scale)66 int InstanceNormRun(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
67   auto kernel = reinterpret_cast<InstanceNormCPUKernel *>(cdata);
68   auto ret = kernel->DoInstanceNorm(task_id);
69   if (ret != RET_OK) {
70     MS_LOG(ERROR) << "InstanceNormRun error task_id[" << task_id << "] error_code[" << ret << "]";
71     return RET_ERROR;
72   }
73   return RET_OK;
74 }
75 
Run()76 int InstanceNormCPUKernel::Run() {
77   src_data_ = reinterpret_cast<float *>(in_tensors_.at(0)->data());
78   gamma_data_ = reinterpret_cast<float *>(in_tensors_.at(1)->data());
79   beta_data_ = reinterpret_cast<float *>(in_tensors_.at(2)->data());
80   dst_data_ = reinterpret_cast<float *>(out_tensors_.at(0)->data());
81   CHECK_NULL_RETURN(src_data_);
82   CHECK_NULL_RETURN(gamma_data_);
83   CHECK_NULL_RETURN(beta_data_);
84   CHECK_NULL_RETURN(dst_data_);
85   if (in_tensors_[0]->format() == NC4HW4) {
86 #if defined(ENABLE_AVX) || defined(ENABLE_ARM64)
87     tmp_src_data_ = src_data_;
88 #else  // other platform is not support nc4hw4 and must be pack to nc4hw4
89     tmp_src_data_ = reinterpret_cast<float *>(ms_context_->allocator->Malloc(in_tensors_[0]->Size()));
90     CHECK_NULL_RETURN(tmp_src_data_);
91     PackNHWCToNC4HW4Fp32(src_data_, tmp_src_data_, param_->batch_, param_->inner_size_, param_->channel_);
92 #endif
93   } else {
94     tmp_src_data_ = src_data_;
95   }
96   auto ret = ParallelLaunch(this->ms_context_, InstanceNormRun, this, op_parameter_->thread_num_);
97   if (ret != RET_OK) {
98     MS_LOG(ERROR) << "InstanceNormRun error error_code[" << ret << "]";
99   }
100   if (in_tensors_[0]->format() == NC4HW4) {
101 #if (!defined(ENABLE_AVX) && !defined(ENABLE_ARM64))
102     FreeTmpBuffer();
103 #endif
104   }
105   return ret;
106 }
107 
// Registers InstanceNormCPUKernel, created through LiteKernelCreator, as the kCPU /
// kNumberTypeFloat32 implementation of PrimitiveType_InstanceNorm.
REG_KERNEL(kCPU, kNumberTypeFloat32, PrimitiveType_InstanceNorm, LiteKernelCreator<InstanceNormCPUKernel>)
109 }  // namespace mindspore::kernel
110