• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef MINDSPORE_LITE_EXTENDRT_GRAPH_COMPILER_ANFNODE_TENSOR_ADAPTER_H_
18 #define MINDSPORE_LITE_EXTENDRT_GRAPH_COMPILER_ANFNODE_TENSOR_ADAPTER_H_
19 #include <string>
20 #include <memory>
21 #include <utility>
22 #include <unordered_map>
23 #include <vector>
24 #include "src/infer/tensor.h"
25 #include "abstract/abstract_value.h"
26 #include "ir/anf.h"
27 #include "include/api/status.h"
28 
29 namespace mindspore {
30 namespace lite {
31 class TensorAdapter;
32 using TensorAdapterPtr = std::shared_ptr<TensorAdapter>;
33 class TensorAdapter {
34  public:
TensorAdapter(std::string name)35   explicit TensorAdapter(std::string name) : name_(std::move(name)) {}
~TensorAdapter()36   virtual ~TensorAdapter() {
37     if (own_data_) {
38       free(data_);
39     }
40   }
41 
42   InferTensor *ToTensor();
43 
44   static TensorAdapterPtr Create(const ParameterPtr &param_node, Format format = DEFAULT_FORMAT);
45   static TensorAdapterPtr Create(const ValueNodePtr &value_node, Format format = DEFAULT_FORMAT);
46   static TensorAdapterPtr Create(const mindspore::abstract::AbstractTensorPtr &abstract,
47                                  Format format = DEFAULT_FORMAT);
48   static TensorAdapterPtr Create(const mindspore::abstract::AbstractBasePtr &abstract, Format format = DEFAULT_FORMAT);
49 
50   static std::vector<std::unique_ptr<InferTensor>> CreateTensorsFromAbstract(const AbstractBasePtr &abstract,
51                                                                              Format format = Format::DEFAULT_FORMAT);
52   static std::vector<InferTensor *> Convert2Tensor(const CNodePtr &cnode, Format format = DEFAULT_FORMAT);
53   static InferTensor *Convert2Tensor(const ParameterPtr &param_node, Format format = DEFAULT_FORMAT);
54   static InferTensor *Convert2Tensor(const ValueNodePtr &value_node, Format format = DEFAULT_FORMAT);
55   static InferTensor *Convert2Tensor(const mindspore::abstract::AbstractTensorPtr &abstract,
56                                      Format format = DEFAULT_FORMAT);
57   static InferTensor *Convert2Tensor(const mindspore::abstract::AbstractBasePtr &abstract,
58                                      Format format = DEFAULT_FORMAT);
59 
60   static StatusCode GetDTAndShapeFromAbTensor(const mindspore::abstract::AbstractTensorPtr &abstract, TypeId *data_type,
61                                               ShapeVector *shape_vector);
62   static StatusCode SetDTAndShapeFromAbTensor(const TypeId &data_type, const ShapeVector &shape,
63                                               const mindspore::abstract::AbstractTensorPtr &abstract);
64   static StatusCode SetDTAndShapeFromAbTensor(const TypeId &data_type, const std::vector<int> &shape,
65                                               const mindspore::abstract::AbstractTensorPtr &abstract);
66 
67   static bool SetDTAndShapeFromAbTensorToLiteTensor(const AbstractBasePtr &abstract, InferTensor *tensor);
68   static bool SetDTAndShapeFromLiteTensorToAbTensor(const InferTensor &tensor, const AbstractBasePtr &abstract);
69 
70  private:
71   static StatusCode GetDTAndShapeFromParameter(const ParameterPtr &param_node, TypeId *data_type, ShapeVector *shape);
72 
73   static TensorAdapterPtr CreateFromTensorValueNode(const ValueNodePtr &value_node);
74 
75   static TensorAdapterPtr CreateFromInt32ImmValue(const ValueNodePtr &value_node);
76 
77   static TensorAdapterPtr CreateFromInt64ImmValue(const ValueNodePtr &value_node);
78 
79   static TensorAdapterPtr CreateFromBoolImmValue(const ValueNodePtr &value_node);
80 
81   static TensorAdapterPtr CreateFromNumberTypeValue(const ValueNodePtr &value_node);
82 
83   static TensorAdapterPtr CreateFromIntSequenceValue(const ValueNodePtr &value_node);
84 
85  public:
86   Format format_{DEFAULT_FORMAT};
87   TensorCompressionType compress_type_ = TensorCompressionType::kNoCompression;
88   TypeId data_type_{kTypeUnknown};
89   bool is_const_{false};
90   ShapeVector shape_{};
91   void *data_{nullptr};
92   size_t data_len_{0};
93   bool own_data_{true};
94   std::string name_;
95 };
96 }  // namespace lite
97 }  // namespace mindspore
98 
99 #endif
100