/**
 * Copyright 2020 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "src/runtime/kernel/arm/int8/relux_int8.h"
#include "schema/model_generated.h"
#include "src/kernel_registry.h"
#include "include/errorcode.h"

using mindspore::kernel::KERNEL_ARCH;
using mindspore::lite::KernelRegistrar;
using mindspore::lite::RET_ERROR;
using mindspore::lite::RET_OK;
using mindspore::schema::ActivationType_RELU;

namespace mindspore::kernel {
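// Init caches the input/output quantization parameters and precomputes the
// fixed-point multiplier (plus left/right shifts) used to rescale values from
// the input scale to the output scale.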
int ReluXInt8CPUKernel::Init() {
  lite::Tensor *input = in_tensors_.at(0);
  lite::Tensor *output = out_tensors_.at(0);
  MS_ASSERT(input);
  MS_ASSERT(output);

  quant_arg_.input_arg.scale_ = input->quant_params().front().scale;
  quant_arg_.input_arg.zp_ = input->quant_params().front().zeroPoint;
  quant_arg_.output_arg.scale_ = output->quant_params().front().scale;
  quant_arg_.output_arg.zp_ = output->quant_params().front().zeroPoint;

  const double multiplier = quant_arg_.input_arg.scale_ / quant_arg_.output_arg.scale_;
  QuantizeRoundParameterWithDoublePrecision(multiplier, &quant_arg_.input_multiplier_, &quant_arg_.left_shift_,
                                            &quant_arg_.right_shift_);

  return RET_OK;
}

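// The kernel keeps no shape-dependent buffers, so resizing requires no work.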
int ReluXInt8CPUKernel::ReSize() { return RET_OK; }

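// DoActivation applies the quantized ReluX to the slice of the tensor assigned
// to this task: the element range is split evenly across threads and each task
// handles at most `stride` elements.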
int ReluXInt8CPUKernel::DoActivation(int task_id) {
  auto input_addr = reinterpret_cast<int8_t *>(in_tensors_.at(0)->MutableData());
  MS_ASSERT(input_addr);
  auto output_addr = reinterpret_cast<int8_t *>(out_tensors_.at(0)->MutableData());
  MS_ASSERT(output_addr);
  auto length = in_tensors_.at(0)->ElementsNum();

  int stride = UP_DIV(length, op_parameter_->thread_num_);
  int count = MSMIN(stride, length - stride * task_id);
  if (count <= 0) {
    // Trailing tasks receive an empty slice when the tensor has fewer elements
    // than stride * task_id; there is nothing to compute for them.
    return RET_OK;
  }

  ReluXInt8(input_addr + stride * task_id, count, output_addr + stride * task_id, &quant_arg_);
  return RET_OK;
}

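// ParallelLaunch callback: forwards one task to the kernel's DoActivation and
// maps any failure to RET_ERROR.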
int ReluXInt8Run(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
  auto activation_kernel = reinterpret_cast<ReluXInt8CPUKernel *>(cdata);
  auto error_code = activation_kernel->DoActivation(task_id);
  if (error_code != RET_OK) {
    MS_LOG(ERROR) << "ReluXInt8Run error task_id[" << task_id << "] error_code[" << error_code << "]";
    return RET_ERROR;
  }
  return RET_OK;
}

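// Run splits the activation across op_parameter_->thread_num_ worker tasks.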
int ReluXInt8CPUKernel::Run() {
  int error_code = ParallelLaunch(this->ms_context_, ReluXInt8Run, this, op_parameter_->thread_num_);
  if (error_code != RET_OK) {
    MS_LOG(ERROR) << "ReluXInt8Run function error error_code[" << error_code << "]";
    return RET_ERROR;
  }
  return RET_OK;
}
}  // namespace mindspore::kernel