1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_LITE_EXTENDRT_GRAPH_COMPILER_ANFNODE_TENSOR_ADAPTER_H_ 18 #define MINDSPORE_LITE_EXTENDRT_GRAPH_COMPILER_ANFNODE_TENSOR_ADAPTER_H_ 19 #include <string> 20 #include <memory> 21 #include <utility> 22 #include <unordered_map> 23 #include <vector> 24 #include "src/infer/tensor.h" 25 #include "abstract/abstract_value.h" 26 #include "ir/anf.h" 27 #include "include/api/status.h" 28 29 namespace mindspore { 30 namespace lite { 31 class TensorAdapter; 32 using TensorAdapterPtr = std::shared_ptr<TensorAdapter>; 33 class TensorAdapter { 34 public: TensorAdapter(std::string name)35 explicit TensorAdapter(std::string name) : name_(std::move(name)) {} ~TensorAdapter()36 virtual ~TensorAdapter() { 37 if (own_data_) { 38 free(data_); 39 } 40 } 41 42 InferTensor *ToTensor(); 43 44 static TensorAdapterPtr Create(const ParameterPtr ¶m_node, Format format = DEFAULT_FORMAT); 45 static TensorAdapterPtr Create(const ValueNodePtr &value_node, Format format = DEFAULT_FORMAT); 46 static TensorAdapterPtr Create(const mindspore::abstract::AbstractTensorPtr &abstract, 47 Format format = DEFAULT_FORMAT); 48 static TensorAdapterPtr Create(const mindspore::abstract::AbstractBasePtr &abstract, Format format = DEFAULT_FORMAT); 49 50 static std::vector<std::unique_ptr<InferTensor>> CreateTensorsFromAbstract(const AbstractBasePtr &abstract, 51 Format format = Format::DEFAULT_FORMAT); 52 static std::vector<InferTensor *> Convert2Tensor(const CNodePtr &cnode, Format format = DEFAULT_FORMAT); 53 static InferTensor *Convert2Tensor(const ParameterPtr ¶m_node, Format format = DEFAULT_FORMAT); 54 static InferTensor *Convert2Tensor(const ValueNodePtr &value_node, Format format = DEFAULT_FORMAT); 55 static InferTensor *Convert2Tensor(const mindspore::abstract::AbstractTensorPtr &abstract, 56 Format format = DEFAULT_FORMAT); 57 static InferTensor *Convert2Tensor(const mindspore::abstract::AbstractBasePtr &abstract, 58 Format format = DEFAULT_FORMAT); 59 60 static StatusCode GetDTAndShapeFromAbTensor(const mindspore::abstract::AbstractTensorPtr &abstract, TypeId *data_type, 61 ShapeVector *shape_vector); 62 static StatusCode SetDTAndShapeFromAbTensor(const TypeId &data_type, const ShapeVector &shape, 63 const mindspore::abstract::AbstractTensorPtr &abstract); 64 static StatusCode SetDTAndShapeFromAbTensor(const TypeId &data_type, const std::vector<int> &shape, 65 const mindspore::abstract::AbstractTensorPtr &abstract); 66 67 static bool SetDTAndShapeFromAbTensorToLiteTensor(const AbstractBasePtr &abstract, InferTensor *tensor); 68 static bool SetDTAndShapeFromLiteTensorToAbTensor(const InferTensor &tensor, const AbstractBasePtr &abstract); 69 70 private: 71 static StatusCode GetDTAndShapeFromParameter(const ParameterPtr ¶m_node, TypeId *data_type, ShapeVector *shape); 72 73 static TensorAdapterPtr CreateFromTensorValueNode(const ValueNodePtr &value_node); 74 75 static TensorAdapterPtr CreateFromInt32ImmValue(const ValueNodePtr &value_node); 76 77 static TensorAdapterPtr CreateFromInt64ImmValue(const ValueNodePtr &value_node); 78 79 static TensorAdapterPtr CreateFromBoolImmValue(const ValueNodePtr &value_node); 80 81 static TensorAdapterPtr CreateFromNumberTypeValue(const ValueNodePtr &value_node); 82 83 static TensorAdapterPtr CreateFromIntSequenceValue(const ValueNodePtr &value_node); 84 85 public: 86 Format format_{DEFAULT_FORMAT}; 87 TensorCompressionType compress_type_ = TensorCompressionType::kNoCompression; 88 TypeId data_type_{kTypeUnknown}; 89 bool is_const_{false}; 90 ShapeVector shape_{}; 91 void *data_{nullptr}; 92 size_t data_len_{0}; 93 bool own_data_{true}; 94 std::string name_; 95 }; 96 } // namespace lite 97 } // namespace mindspore 98 99 #endif 100