• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include "parser/caffe/caffe_innerproduct_parser.h"
18 #include <memory>
19 #include "common/op_enum.h"
20 #include "common/op_attr.h"
21 #include "common/data_transpose_utils.h"
22 #include "ops/fusion/full_connection.h"
23 
24 namespace mindspore {
25 namespace lite {
26 namespace {
27 constexpr int kInnerProductAxis = 2;
TransformShape(caffe::BlobShape * shape)28 void TransformShape(caffe::BlobShape *shape) {
29   auto origin_row = shape->dim(0);
30   auto origin_col = shape->dim(1);
31   shape->clear_dim();
32   shape->add_dim(origin_col);
33   shape->add_dim(origin_row);
34 }
35 }  // namespace
Parse(const caffe::LayerParameter & proto,const caffe::LayerParameter & weight)36 BaseOperatorPtr CaffeInnerProductParser::Parse(const caffe::LayerParameter &proto,
37                                                const caffe::LayerParameter &weight) {
38   auto prim = std::make_shared<ops::FullConnection>();
39   if (prim == nullptr) {
40     MS_LOG(ERROR) << "prim is nullptr.";
41     return nullptr;
42   }
43   prim->set_activation_type(mindspore::ActivationType::NO_ACTIVATION);
44   const caffe::InnerProductParameter &innerProductParam = proto.inner_product_param();
45 
46   if (innerProductParam.has_transpose() && innerProductParam.transpose()) {
47     auto mutable_weight = const_cast<caffe::LayerParameter *>(&weight);
48     if (mutable_weight == nullptr) {
49       MS_LOG(ERROR) << "weight is nullptr.";
50       return nullptr;
51     }
52     auto blob = mutable_weight->mutable_blobs(0);
53     if (blob == nullptr) {
54       MS_LOG(ERROR) << "blob is nullptr.";
55       return nullptr;
56     }
57     auto shape = blob->mutable_shape();
58     if (shape == nullptr) {
59       MS_LOG(ERROR) << "shape is nullptr.";
60       return nullptr;
61     }
62     if (blob->mutable_data() == nullptr) {
63       MS_LOG(ERROR) << "blob mutable data is nullptr.";
64       return nullptr;
65     }
66     if (shape->dim_size() < kNums2) {
67       MS_LOG(ERROR) << "weight shape size " << shape->dim_size() << " should greater than 1";
68       return nullptr;
69     }
70     if (shape->dim(0) == 0 || shape->dim(1) == 0) {
71       MS_LOG(ERROR) << "dim val can't be 0.";
72       return nullptr;
73     }
74     dpico::TransposeMatrix(blob->mutable_data()->mutable_data(), static_cast<int>(shape->dim(0)),
75                            static_cast<int>(shape->dim(1)));
76     TransformShape(shape);
77   }
78 
79   if (!innerProductParam.has_num_output()) {
80     MS_LOG(ERROR) << "InnerProduct Parse num_output for " << proto.name().c_str() << " failed.";
81     return nullptr;
82   } else {
83     (void)prim->AddAttr(dpico::kNumOutput, api::MakeValue(static_cast<int64_t>(innerProductParam.num_output())));
84   }
85 
86   if (innerProductParam.axis() == 1 || innerProductParam.axis() == kInnerProductAxis) {
87     prim->set_axis(innerProductParam.axis());
88     prim->set_use_axis(true);
89   } else {
90     MS_LOG(ERROR) << "InnerProduct Parse axis only support default 1 OR 2, but actually " << innerProductParam.axis();
91     return nullptr;
92   }
93   if (innerProductParam.bias_term()) {
94     prim->set_has_bias(true);
95   }
96 
97   return prim;
98 }
99 
100 CaffeNodeRegistrar g_caffeInnerProductParser("InnerProduct", new CaffeInnerProductParser());
101 }  // namespace lite
102 }  // namespace mindspore
103