• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <string>
18 #include <memory>
19 #include <algorithm>
20 #include "common/anf_util.h"
21 #include "common/op_enum.h"
22 #include "checker/op_checker.h"
23 
24 namespace mindspore {
25 namespace dpico {
~OpCheckerRegistry()26 OpCheckerRegistry::~OpCheckerRegistry() {
27   for (auto ite : checkers) {
28     if (ite.second != nullptr) {
29       delete ite.second;
30       ite.second = nullptr;
31     }
32   }
33 }
34 
GetInstance()35 OpCheckerRegistry *OpCheckerRegistry::GetInstance() {
36   static OpCheckerRegistry instance;
37   return &instance;
38 }
39 
GetOpChecker(const std::string & type)40 OpChecker *OpCheckerRegistry::GetOpChecker(const std::string &type) {
41   auto it = checkers.find(type);
42   if (it != checkers.end()) {
43     return it->second;
44   }
45   return nullptr;
46 }
47 
GetWidth(const std::vector<int64_t> & shape,mindspore::Format format,int64_t * width)48 STATUS GetWidth(const std::vector<int64_t> &shape, mindspore::Format format, int64_t *width) {
49   if (width == nullptr) {
50     MS_LOG(ERROR) << "width is nullptr.";
51     return RET_ERROR;
52   }
53   if (shape.size() == kDims4) {
54     if (format == mindspore::Format::NCHW) {
55       *width = shape.at(kInputIndex3);
56     } else if (format == mindspore::Format::NHWC) {
57       *width = shape.at(kInputIndex2);
58     } else {
59       MS_LOG(ERROR) << "format should be NCHW or NHWC";
60       return RET_ERROR;
61     }
62   } else {
63     *width = shape.back();
64   }
65   return RET_OK;
66 }
67 
GetTensorChannel(const std::vector<int64_t> & shape,mindspore::Format format,int64_t * channel)68 STATUS GetTensorChannel(const std::vector<int64_t> &shape, mindspore::Format format, int64_t *channel) {
69   if (channel == nullptr) {
70     MS_LOG(ERROR) << "channel is nullptr.";
71     return RET_ERROR;
72   }
73   if (shape.size() != kDims4) {
74     MS_LOG(ERROR) << "shape size should be 4, but is " << shape.size();
75     return RET_ERROR;
76   } else {
77     if (format == mindspore::Format::NCHW) {
78       *channel = shape.at(kInputIndex1);
79     } else if (format == mindspore::Format::NHWC) {
80       *channel = shape.at(kInputIndex3);
81     } else {
82       MS_LOG(ERROR) << "format should be NCHW or NHWC";
83       return RET_ERROR;
84     }
85   }
86   return RET_OK;
87 }
88 
GetVectorChannel(const std::vector<int64_t> & shape,int64_t * channel)89 STATUS GetVectorChannel(const std::vector<int64_t> &shape, int64_t *channel) {
90   if (channel == nullptr) {
91     MS_LOG(ERROR) << "channel is nullptr.";
92     return RET_ERROR;
93   }
94   if (shape.size() != kDims2) {
95     MS_LOG(ERROR) << "shape size should be 2, but is " << shape.size();
96     return RET_ERROR;
97   }
98   *channel = shape.back();
99   return RET_OK;
100 }
101 
HasOfflineData(const api::AnfNodePtr & node)102 bool HasOfflineData(const api::AnfNodePtr &node) {
103   if (node == nullptr) {
104     MS_LOG(ERROR) << "node is nullptr.";
105     return false;
106   }
107   auto param = node->cast<api::ParameterPtr>();
108 
109   return param != nullptr && param->has_default();
110 }
111 
CheckInputW(const api::CNodePtr & op,size_t index,mindspore::Format format,int limit_w)112 bool CheckInputW(const api::CNodePtr &op, size_t index, mindspore::Format format, int limit_w) {
113   if (index >= op->size()) {
114     MS_LOG(ERROR) << "index:" << index << " is greater than " << op->fullname_with_scope()
115                   << " inputs size:" << op->size();
116     return false;
117   }
118   std::vector<int64_t> input_shape;
119   if (GetInputShapeFromCNode(op, index, &input_shape) == RET_OK && !input_shape.empty()) {
120     int64_t input_w;
121     if (GetWidth(input_shape, format, &input_w) != RET_OK) {
122       MS_LOG(ERROR) << "get input_w failed " << op->fullname_with_scope();
123       return false;
124     }
125     if (input_shape.size() == kDims4 && input_w > limit_w) {
126       MS_LOG(INFO) << op->fullname_with_scope() << "'s input_w:" << input_w << " exceed the maximum limit " << limit_w;
127       return false;
128     }
129   }
130   return true;
131 }
132 }  // namespace dpico
133 }  // namespace mindspore
134