/**
 * Copyright 2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_LITE_EXTENDRT_GRAPH_COMPILER_COMPILE_RESULT_H_
#define MINDSPORE_LITE_EXTENDRT_GRAPH_COMPILER_COMPILE_RESULT_H_
#include <string>
#include <memory>
#include <sstream>
#include <unordered_map>
#include <utility>
#include <vector>
#include "ir/anf.h"
#include "src/infer/tensor.h"
#include "include/model.h"
#include "ops/base_operator.h"
#include "utils/hash_map.h"
#include "include/api/status.h"
#include "kernel/common_utils.h"
#include "src/infer/primitive_type.h"

namespace mindspore {
namespace lite {
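// CompileNode: one node of the compiled graph. It keeps the node name, primitive type,
// the originating CNode and its base operator, and the input/output tensors resolved
// for inference.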
class CompileNode {
 public:
  explicit CompileNode(std::string name, const kernel::PrimitiveType &type) : name_(std::move(name)), type_(type) {}
  static std::shared_ptr<CompileNode> Create(CNodePtr cnode);

  virtual ~CompileNode() = default;

  std::string GetName() const { return name_; }
  kernel::PrimitiveType GetType() const { return type_; }
  std::shared_ptr<ops::BaseOperator> GetBaseOperator() const { return base_operator_; }
  CNodePtr GetCNode() const { return cnode_; }
  const std::vector<InferTensor *> &GetInputs() const { return inputs_; }
  InferTensor *GetInput(size_t i) const { return inputs_.at(i); }
  size_t InputSize() const { return inputs_.size(); }
  const std::vector<InferTensor *> &GetOutputs() const { return outputs_; }
  InferTensor *GetOutput(size_t i) const { return outputs_.at(i); }
  size_t OutputSize() const { return outputs_.size(); }

  void SetName(const std::string &name) { name_ = name; }
  void AppendInputTensor(InferTensor *tensor);
  void AppendOutputTensor(InferTensor *tensor);
  void ReplaceInputTensor(InferTensor *dst, const InferTensor *src);
  kernel::KernelAttr GetKernelAttr() const;
  std::string Dump(int indent = 0) const;

 private:
  std::string name_{};
  kernel::PrimitiveType type_{};
  std::shared_ptr<ops::BaseOperator> base_operator_{nullptr};
  CNodePtr cnode_{nullptr};
  std::vector<InferTensor *> inputs_{};
  std::vector<InferTensor *> outputs_{};
};
using CompileNodePtr = std::shared_ptr<CompileNode>;

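// CompileResult: the graph-level product of compilation. It holds the ordered node list,
// all tensors, the graph inputs/outputs, parameter/return/argument nodes, and name-based
// lookup maps; Assemble() marks the result as assembled.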
class CompileResult {
 public:
  CompileResult() = default;
  virtual ~CompileResult() = default;

  CompileNodePtr GetNode(const std::string &name);
  CompileNodePtr GetArgNode(const std::string &name);
  const std::vector<CompileNodePtr> &GetNodes() const { return nodes_; }
  size_t NodeSize() const { return nodes_.size(); }
  const std::vector<InferTensor *> &GetTensors() const { return tensors_; }
  size_t TensorSize() const { return tensors_.size(); }
  const std::vector<InferTensor *> &GetInputs() const { return inputs_; }
  InferTensor *GetInput(size_t i) const { return inputs_.at(i); }
  size_t InputSize() const { return inputs_.size(); }
  const std::vector<InferTensor *> &GetOutputs() const { return outputs_; }
  InferTensor *GetOutput(size_t i) const { return outputs_.at(i); }
  size_t OutputSize() const { return outputs_.size(); }
  const std::vector<CompileNodePtr> &GetParamNodes() const { return param_nodes_; }
  const std::vector<CompileNodePtr> &GetReturnNodes() const { return return_nodes_; }

  std::vector<CompileNodePtr> &GetMutableNodes();
  std::vector<InferTensor *> &GetMutableInputs();
  std::vector<InferTensor *> &GetMutableOutputs();
  StatusCode AppendNode(CompileNodePtr node);
  StatusCode AppendArgNode(CompileNodePtr node);
  StatusCode AppendTensor(InferTensor *tensor);
  StatusCode AppendInputTensor(InferTensor *tensor, bool is_borrow = false);
  StatusCode AppendOutputTensor(InferTensor *tensor, bool is_borrow = false);

  StatusCode AppendNodeInputTensor(const CompileNodePtr &compile_node, InferTensor *tensor, bool is_borrow = false);
  StatusCode AppendNodeInputTensor(const std::string &node_name, InferTensor *tensor, bool is_borrow = false);
  StatusCode AppendNodeOutputTensor(const CompileNodePtr &compile_node, InferTensor *tensor, bool is_borrow = false);
  StatusCode AppendNodeOutputTensor(const std::string &node_name, InferTensor *tensor, bool is_borrow = false);

  void Assemble() { this->assembled_ = true; }

  std::string Dump(int indent = 0) const;

 private:
  bool assembled_ = false;
  std::vector<CompileNodePtr> nodes_{};
  std::vector<InferTensor *> tensors_{};
  std::vector<InferTensor *> inputs_{};
  std::vector<InferTensor *> outputs_{};
  HashMap<std::string, CompileNodePtr> node_map_{};
  HashMap<std::string, InferTensor *> tensor_map_{};
  std::vector<CompileNodePtr> param_nodes_{};
  std::vector<CompileNodePtr> return_nodes_{};
  std::vector<CompileNodePtr> arg_nodes_{};
  HashMap<std::string, CompileNodePtr> arg_node_map_{};
};
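// Usage sketch (illustrative, not part of the original header): how a caller might build a
// CompileResult from a FuncGraph node. The CNodePtr `cnode` and the InferTensor pointers
// `in_tensor`/`out_tensor` are assumed to exist; only members declared above are used, and
// the call order is a plausible sequence rather than a documented contract.
//
//   auto result = std::make_shared<CompileResult>();
//   auto node = CompileNode::Create(cnode);
//   result->AppendNode(node);
//   result->AppendNodeInputTensor(node, in_tensor);
//   result->AppendNodeOutputTensor(node, out_tensor);
//   result->AppendInputTensor(in_tensor);
//   result->AppendOutputTensor(out_tensor);
//   result->Assemble();
//   auto text = result->Dump();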
using CompileResultPtr = std::shared_ptr<CompileResult>;
}  // namespace lite
}  // namespace mindspore

#endif  // MINDSPORE_LITE_EXTENDRT_GRAPH_COMPILER_COMPILE_RESULT_H_