/**
 * Copyright 2020-2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_TOOLS_COMMON_TENSOR_UTIL_H
#define MINDSPORE_LITE_TOOLS_COMMON_TENSOR_UTIL_H

#include <cmath>
#include <unordered_map>
#include <memory>
#include <algorithm>
#include <utility>
#include <string>
#include <vector>
#include "schema/inner/model_generated.h"
#include "src/common/log_adapter.h"
#include "ir/dtype/type_id.h"
#include "ir/tensor.h"
#include "src/common/utils.h"

namespace mindspore {
namespace lite {
using schema::CNodeT;
using schema::Format;
using schema::FusedBatchNormT;
using schema::MetaGraphT;
using schema::QuantParamT;
using schema::TensorT;

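// Returns a copy of the quantization parameter attached to a schema tensor (nullptr when none is set).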
std::unique_ptr<QuantParamT> GetTensorQuantParam(const std::unique_ptr<TensorT> &tensor);

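// Creates a tensor::Tensor with the given shape and data type; when data is non-null, data_size bytes are copied in.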
tensor::TensorPtr CreateTensorInfo(const void *data, size_t data_size, const std::vector<int64_t> &shape,
                                   TypeId data_type);

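// Builds an abstract tensor description (shape and data type) for attaching to IR nodes.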
AbstractBasePtr CreateTensorAbstract(const std::vector<int64_t> &shape, TypeId data_type);

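// Sets a Parameter node's abstract and its default tensor from the given raw data; returns a status code.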
int SetParameterAbstractAndParam(const ParameterPtr &parameter, const void *data, size_t data_size,
                                 const std::vector<int64_t> &shape, TypeId data_type);

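// Copies data_size bytes into an existing tensor::Tensor; returns a status code.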
int SetTensorData(const tensor::TensorPtr &tensor_info, const void *data, size_t data_size);

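// Converts a tensor::Tensor into a flatbuffer schema::TensorT, optionally naming the result.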
std::unique_ptr<schema::TensorT> CreateTensorTFromTensorInfo(const tensor::TensorPtr &tensor_info,
                                                             const std::string &tensor_name = "");

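// Refreshes an existing schema::TensorT from the contents of a tensor::Tensor.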
int UpdateTensorTFromTensorInfo(const tensor::TensorPtr &src_tensor, std::unique_ptr<schema::TensorT> *dst_tensor);

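// Initializes a Parameter node (abstract and default value) from a tensor::Tensor.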
int InitParameterFromTensorInfo(const ParameterPtr &param_node, const tensor::TensorPtr &tensor_info);

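// Size helpers: GetElementSize gives the byte width of one element of the given type or tensor,
// GetShapeSize the number of elements implied by a shape or tensor.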
size_t GetElementSize(const TensorT &tensor);

size_t GetElementSize(const TypeId &dataType);

size_t GetShapeSize(const TensorT &tensor);

size_t GetShapeSize(const std::vector<int32_t> &shape);

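// Deep-copies a schema::TensorT.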
std::unique_ptr<TensorT> CopyTensorDefT(const std::unique_ptr<TensorT> &);

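// Counts how many node inputs in graphT reference the tensor at tensorIdx.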
size_t GetRefCount(schema::MetaGraphT *graphT, uint32_t tensorIdx);

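// Deep-copy helpers for quantization parameters.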
std::unique_ptr<schema::QuantParamT> CopyQuantParamT(const std::unique_ptr<schema::QuantParamT> &srcQuantParam);

std::unique_ptr<schema::QuantParamT> CopyQuantParamArrayT(
  const std::unique_ptr<schema::QuantParamT> &srcQuantParamArray);

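// Role a tensor plays when it is registered in the TensorCache below.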
enum Category { CONSTANT = 0, GRAPH_INPUT = 1, OP_OUTPUT = 2, TF_CONST = 3 };

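// Registry used while converting a model: it stores raw TensorT pointers in insertion order
// (it does not take ownership) and keeps a name -> index map so tensors can be looked up by name.
// A minimal usage sketch (names are illustrative only):
//
//   TensorCache cache;
//   auto *input = new schema::TensorT();
//   int idx = cache.AddTensor("graph_input_0", input, GRAPH_INPUT);  // first call returns 0
//   int found = cache.FindTensor("graph_input_0");                   // == idx
//
// Plain CONSTANT tensors are stored but not name-indexed, so FindTensor() returns -1 for them;
// only GRAPH_INPUT, OP_OUTPUT, and TF_CONST names are added to the index.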
class TensorCache {
 public:
  TensorCache() = default;

  ~TensorCache() { tensors.clear(); }

  int AddTensor(const std::string &name, TensorT *tensor, int Category) {
    index++;
    if (Category == CONSTANT || Category == TF_CONST || Category == GRAPH_INPUT) {
      tensor->refCount = 1;
      tensor->nodeType = static_cast<int32_t>(MSNodeType::NodeType_ValueNode);
    } else {
      tensor->nodeType = static_cast<int32_t>(MSNodeType::NodeType_Parameter);
    }
    tensor->name = name;
    tensors.push_back(tensor);

    if (Category == GRAPH_INPUT) {
      graphInputs.push_back(index);
    }

    if (Category == GRAPH_INPUT || Category == OP_OUTPUT || Category == TF_CONST) {
      UpdateTensorIndex(name, index);
    }
    return index;
  }

  // Find the index registered for name; returns -1 if it is not cached.
  int FindTensor(const std::string &name) {
    auto iter = tensorIndex.find(name);
    if (iter != tensorIndex.end()) {
      return iter->second;
    }
    return -1;
  }

  void UpdateTensorIndex(const std::string &name, int idx) {
    auto iter = tensorIndex.find(name);
    if (iter != tensorIndex.end()) {
      tensorIndex[name] = idx;
    } else {
      tensorIndex.insert(make_pair(name, idx));
    }
  }

  // Return all cached tensors.
  const std::vector<TensorT *> &GetCachedTensor() const { return tensors; }

  const std::vector<int> &GetGraphInputs() const { return graphInputs; }

 private:
  std::vector<TensorT *> tensors;
  std::unordered_map<std::string, int> tensorIndex;
  std::vector<int> graphInputs;
  int index = -1;
};
}  // namespace lite
}  // namespace mindspore

#endif  // MINDSPORE_LITE_TOOLS_COMMON_TENSOR_UTIL_H