• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_GROUP_CONVOLUTION_CREATOR_H_
18 #define MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_GROUP_CONVOLUTION_CREATOR_H_
19 
#include <algorithm>
#include <utility>
#include <vector>
#include "src/inner_kernel.h"
#include "nnacl/conv_parameter.h"
24 
25 namespace mindspore::kernel {
26 struct TensorInfo {
27   std::vector<int> shape_;
28   mindspore::Format format_;
29   TypeId data_type_;
30   lite::Tensor::Category tensor_type_;
31   bool is_in_;
32 };
33 
34 class GroupConvCreator {
35  public:
GroupConvCreator(std::vector<lite::Tensor * > inputs,std::vector<lite::Tensor * > outputs,OpParameter * op_parameter,const lite::InnerContext * ctx,bool is_quant,TypeId data_type)36   GroupConvCreator(std::vector<lite::Tensor *> inputs, std::vector<lite::Tensor *> outputs, OpParameter *op_parameter,
37                    const lite::InnerContext *ctx, bool is_quant, TypeId data_type)
38       : origin_inputs_(std::move(inputs)),
39         origin_outputs_(std::move(outputs)),
40         is_quant_(is_quant),
41         data_type_(data_type),
42         ctx_(ctx) {
43     auto shape = origin_outputs_.front()->shape();
44     infered_ = std::find(shape.begin(), shape.end(), -1) == shape.end();
45     conv_param_ = reinterpret_cast<ConvParameter *>(op_parameter);
46   }
47 
48   ~GroupConvCreator() = default;
49 
50  public:
51   void SetShapeOfTensors();
52   int CreateConvs(std::vector<kernel::InnerKernel *> *group_convs);
get_group_conv()53   std::vector<kernel::InnerKernel *> *get_group_conv() { return &group_convs_; }
54   void CopyQuantParam(std::vector<lite::Tensor *> *tensors);
55   int GetSingleConvParam(ConvParameter *conv_param, std::vector<lite::Tensor *> *new_inputs,
56                          std::vector<lite::Tensor *> *new_outputs, int group_id);
57 
58  protected:
set_input_shape(const std::vector<int> & shape)59   void set_input_shape(const std::vector<int> &shape) { input_shape_ = shape; }
set_output_shape(const std::vector<int> & shape)60   void set_output_shape(const std::vector<int> &shape) { output_shape_ = shape; }
set_filter_shape(const std::vector<int> & shape)61   void set_filter_shape(const std::vector<int> &shape) { filter_shape_ = shape; }
set_bias_shape(const std::vector<int> & shape)62   void set_bias_shape(const std::vector<int> &shape) { bias_shape_ = shape; }
63   void FreeGroupConvs();
64   int NewInputTensor(std::vector<lite::Tensor *> *tensors);
65   int NewConstTensor(std::vector<lite::Tensor *> *tensors, int group_id);
66   int NewOutputTensor(std::vector<lite::Tensor *> *tensors, lite::Tensor *output);
67 
68  private:
69   std::vector<lite::Tensor *> origin_inputs_;
70   std::vector<lite::Tensor *> origin_outputs_;
71   std::vector<kernel::InnerKernel *> group_convs_;
72   std::vector<int> input_shape_;
73   std::vector<int> output_shape_;
74   std::vector<int> filter_shape_;
75   std::vector<int> bias_shape_;
76   ConvParameter *conv_param_;
77   bool infered_ = false;
78   bool is_quant_ = false;
79   TypeId data_type_;
80   const lite::InnerContext *ctx_ = nullptr;
81 };
82 
83 ConvParameter *CreateNewConvParameter(ConvParameter *parameter);
84 }  // namespace mindspore::kernel
85 
86 #endif  // MINDSPORE_LITE_SRC_RUNTIME_KERNEL_ARM_GROUP_CONVOLUTION_CREATOR_H_
87