• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "src/runtime/kernel/arm/int8/group_convolution_int8.h"
18 #include "src/runtime/kernel/arm/int8/convolution_int8_creator.h"
19 
20 using mindspore::lite::RET_OK;
21 
22 namespace mindspore::kernel {
SeparateInput(int group_id)23 int GroupConvolutionInt8CPUKernel::SeparateInput(int group_id) {
24   int in_plane = conv_param_->input_h_ * conv_param_->input_w_ * conv_param_->input_batch_;
25   int sub_in_channel = conv_param_->input_channel_;
26   int ori_in_channel = sub_in_channel * group_num_;
27   auto sub_in_data =
28     reinterpret_cast<int8_t *>(static_cast<lite::Tensor *>(group_convs_.at(group_id)->in_tensors().front())->data());
29   int8_t *src_ptr = reinterpret_cast<int8_t *>(ori_in_data_) + group_id * sub_in_channel;
30   int8_t *dst_ptr = sub_in_data;
31   for (int i = 0; i < in_plane; ++i) {
32     memcpy(dst_ptr, src_ptr, static_cast<size_t>(sub_in_channel) * sizeof(int8_t));
33     src_ptr += ori_in_channel;
34     dst_ptr += sub_in_channel;
35   }
36   return RET_OK;
37 }
38 
PostConcat(int group_id)39 int GroupConvolutionInt8CPUKernel::PostConcat(int group_id) {
40   int out_plane = conv_param_->output_h_ * conv_param_->output_w_ * conv_param_->output_batch_;
41   int sub_out_channel = conv_param_->output_channel_;
42   int ori_out_channel = sub_out_channel * group_num_;
43   auto sub_out_data =
44     reinterpret_cast<int8_t *>(static_cast<lite::Tensor *>(group_convs_.at(group_id)->out_tensors().front())->data());
45   int8_t *src_ptr = sub_out_data;
46   int8_t *dst_ptr = reinterpret_cast<int8_t *>(ori_out_data_) + group_id * sub_out_channel;
47   for (int i = 0; i < out_plane; ++i) {
48     memcpy(dst_ptr, src_ptr, static_cast<size_t>(sub_out_channel) * sizeof(int8_t));
49     src_ptr += sub_out_channel;
50     dst_ptr += ori_out_channel;
51   }
52   return RET_OK;
53 }
54 
Init()55 int GroupConvolutionInt8CPUKernel::Init() {
56   if (group_conv_creator_ == nullptr) {
57     return lite::RET_ERROR;
58   }
59   group_conv_creator_->SetShapeOfTensors();
60   for (int i = 0; i < conv_param_->group_; ++i) {
61     auto *new_conv_param = CreateNewConvParameter(conv_param_);
62     std::vector<lite::Tensor *> new_inputs;
63     std::vector<lite::Tensor *> new_outputs;
64     auto ret = group_conv_creator_->GetSingleConvParam(new_conv_param, &new_inputs, &new_outputs, i);
65     if (ret != RET_OK) {
66       MS_LOG(ERROR) << "GetSingleConv for fp32 group conv failed.";
67       return lite::RET_ERROR;
68     }
69     group_conv_creator_->CopyQuantParam(&new_inputs);
70     group_convs_.emplace_back(
71       CpuConvInt8KernelSelect(new_inputs, new_outputs, reinterpret_cast<OpParameter *>(new_conv_param), ctx_));
72   }
73   return GroupConvolutionBaseCPUKernel::Init();
74 }
75 }  // namespace mindspore::kernel
76