1 /**
2 * Copyright 2020 Huawei Technologies Co., Ltd
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include "src/runtime/kernel/arm/int8/relux_int8.h"
18 #include "schema/model_generated.h"
19 #include "src/kernel_registry.h"
20 #include "include/errorcode.h"
21
22 using mindspore::kernel::KERNEL_ARCH;
23 using mindspore::lite::KernelRegistrar;
24 using mindspore::lite::RET_ERROR;
25 using mindspore::lite::RET_OK;
26 using mindspore::schema::ActivationType_RELU;
27
28 namespace mindspore::kernel {
Init()29 int ReluXInt8CPUKernel::Init() {
30 lite::Tensor *input = in_tensors_.at(0);
31 lite::Tensor *output = out_tensors_.at(0);
32 MS_ASSERT(input);
33 MS_ASSERT(output);
34
35 quant_arg_.input_arg.scale_ = input->quant_params().front().scale;
36 quant_arg_.input_arg.zp_ = input->quant_params().front().zeroPoint;
37 quant_arg_.output_arg.scale_ = output->quant_params().front().scale;
38 quant_arg_.output_arg.zp_ = output->quant_params().front().zeroPoint;
39
40 const double multiplier = quant_arg_.input_arg.scale_ / quant_arg_.output_arg.scale_;
41 QuantizeRoundParameterWithDoublePrecision(multiplier, &quant_arg_.input_multiplier_, &quant_arg_.left_shift_,
42 &quant_arg_.right_shift_);
43
44 return RET_OK;
45 }
46
ReSize()47 int ReluXInt8CPUKernel::ReSize() { return RET_OK; }
48
DoActivation(int task_id)49 int ReluXInt8CPUKernel::DoActivation(int task_id) {
50 auto input_addr = reinterpret_cast<int8_t *>(in_tensors_.at(0)->MutableData());
51 MS_ASSERT(input_addr);
52 auto output_addr = reinterpret_cast<int8_t *>(out_tensors_.at(0)->MutableData());
53 MS_ASSERT(output_addr);
54 auto length = in_tensors_.at(0)->ElementsNum();
55
56 int stride = UP_DIV(length, op_parameter_->thread_num_);
57 int count = MSMIN(stride, length - stride * task_id);
58
59 ReluXInt8(input_addr + stride * task_id, count, output_addr + stride * task_id, &quant_arg_);
60 return RET_OK;
61 }
62
ReluXInt8Run(void * cdata,int task_id,float lhs_scale,float rhs_scale)63 int ReluXInt8Run(void *cdata, int task_id, float lhs_scale, float rhs_scale) {
64 auto activation_kernel = reinterpret_cast<ReluXInt8CPUKernel *>(cdata);
65 auto error_code = activation_kernel->DoActivation(task_id);
66 if (error_code != RET_OK) {
67 MS_LOG(ERROR) << "ReluXInt8Run error task_id[" << task_id << "] error_code[" << error_code << "]";
68 return RET_ERROR;
69 }
70 return RET_OK;
71 }
72
Run()73 int ReluXInt8CPUKernel::Run() {
74 int error_code = ParallelLaunch(this->ms_context_, ReluXInt8Run, this, op_parameter_->thread_num_);
75 if (error_code != RET_OK) {
76 MS_LOG(ERROR) << "ReluXInt8Run function error error_code[" << error_code << "]";
77 return RET_ERROR;
78 }
79 return RET_OK;
80 }
81 } // namespace mindspore::kernel
82