1 /** 2 * Copyright 2020-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_GROUP_CONVOLUTION_CREATOR_H_ 18 #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_GROUP_CONVOLUTION_CREATOR_H_ 19 20 #include <utility> 21 #include <vector> 22 #include "src/inner_kernel.h" 23 #include "nnacl/conv_parameter.h" 24 25 namespace mindspore::kernel { 26 struct TensorInfo { 27 std::vector<int> shape_; 28 mindspore::Format format_; 29 TypeId data_type_; 30 lite::Tensor::Category tensor_type_; 31 bool is_in_; 32 }; 33 34 class GroupConvCreator { 35 public: GroupConvCreator(std::vector<lite::Tensor * > inputs,std::vector<lite::Tensor * > outputs,OpParameter * op_parameter,const lite::InnerContext * ctx,bool is_quant,TypeId data_type)36 GroupConvCreator(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs, OpParameter *op_parameter, 37 const lite::InnerContext *ctx, bool is_quant, TypeId data_type) 38 : origin_inputs_(std::move(inputs)), 39 origin_outputs_(std::move(outputs)), 40 is_quant_(is_quant), 41 data_type_(data_type), 42 ctx_(ctx) { 43 auto shape = origin_outputs_.front()->shape(); 44 infered_ = std::find(shape.begin(), shape.end(), -1) == shape.end(); 45 conv_param_ = reinterpret_cast<ConvParameter *>(op_parameter); 46 } 47 48 ~GroupConvCreator() = default; 49 50 public: 51 void SetShapeOfTensors(); 52 int CreateConvs(std::vector<kernel::InnerKernel *> *group_convs); get_group_conv()53 std::vector<kernel::InnerKernel *> *get_group_conv() { return &group_convs_; } 54 void CopyQuantParam(std::vector<lite::Tensor *> *tensors); 55 int GetSingleConvParam(ConvParameter *conv_param, std::vector<lite::Tensor *> *new_inputs, 56 std::vector<lite::Tensor *> *new_outputs, int group_id); 57 58 protected: set_input_shape(const std::vector<int> & shape)59 void set_input_shape(const std::vector<int> &shape) { input_shape_ = shape; } set_output_shape(const std::vector<int> & shape)60 void set_output_shape(const std::vector<int> &shape) { output_shape_ = shape; } set_filter_shape(const std::vector<int> & shape)61 void set_filter_shape(const std::vector<int> &shape) { filter_shape_ = shape; } set_bias_shape(const std::vector<int> & shape)62 void set_bias_shape(const std::vector<int> &shape) { bias_shape_ = shape; } 63 void FreeGroupConvs(); 64 int NewInputTensor(std::vector<lite::Tensor *> *tensors); 65 int NewConstTensor(std::vector<lite::Tensor *> *tensors, int group_id); 66 int NewOutputTensor(std::vector<lite::Tensor *> *tensors, lite::Tensor *output); 67 68 private: 69 std::vector<lite::Tensor *> origin_inputs_; 70 std::vector<lite::Tensor *> origin_outputs_; 71 std::vector<kernel::InnerKernel *> group_convs_; 72 std::vector<int> input_shape_; 73 std::vector<int> output_shape_; 74 std::vector<int> filter_shape_; 75 std::vector<int> bias_shape_; 76 ConvParameter *conv_param_; 77 bool infered_ = false; 78 bool is_quant_ = false; 79 TypeId data_type_; 80 const lite::InnerContext *ctx_ = nullptr; 81 }; 82 83 ConvParameter *CreateNewConvParameter(ConvParameter *parameter); 84 } // namespace mindspore::kernel 85 86 #endif // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_GROUP_CONVOLUTION_CREATOR_H_ 87