• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_LITE_GRAPH_H_
17 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_LITE_GRAPH_H_
18 
19 #include <memory>
20 #include <string>
21 #include "backend/common/graph_kernel/model/node.h"
22 
23 namespace mindspore::graphkernel::inner {
24 class LiteGraph {
25  public:
26   class GraphBuilderBase;
name_(name)27   explicit LiteGraph(const std::string &name = "") : name_(name), output_(new OutputNode()) {}
28 
29   const NodePtrList &GetOrderedNodes();
30   std::string ToString(bool reset_node_name = false) const;
name()31   const std::string &name() const { return name_; }
ops()32   const NodePtrList &ops() const { return ops_; }
inputs()33   const NodePtrList &inputs() const { return inputs_; }
output(size_t i)34   const NodePtr &output(size_t i) const { return output_->input(i); }
GetOutputs()35   const NodePtrList &GetOutputs() const { return output_->inputs(); }
36 
SetOutput(size_t i,const NodePtr & node)37   void SetOutput(size_t i, const NodePtr &node) { output_->SetInput(i, node); }
SetOutputs(const NodePtrList & nodes)38   void SetOutputs(const NodePtrList &nodes) { output_->SetInputs(nodes); }
39 
40  protected:
41   std::string name_;
42   NodePtrList ops_;  // save all operators in topo order
43   NodePtrList inputs_;
44   NodePtr output_;
45 
46  private:
ParamName()47   std::string ParamName() const { return "input_" + std::to_string(param_id_++); }
NodeName()48   std::string NodeName() const { return "output_" + std::to_string(node_id_++); }
49   mutable int param_id_{0};
50   mutable int node_id_{0};
51 };
52 using LiteGraphPtr = std::shared_ptr<LiteGraph>;
53 class LiteGraph::GraphBuilderBase {
54  public:
55   explicit GraphBuilderBase(const std::string &name = "") { graph_ = std::make_shared<LiteGraph>(name); }
56   ~GraphBuilderBase() = default;
57 
58   // Create a parameter of graph
Parameter(const NodeBase & baseinfo)59   NodePtr Parameter(const NodeBase &baseinfo) const {
60     auto para = std::make_shared<ParamNode>(baseinfo);
61     para->SetDebugName(graph_->ParamName());
62     graph_->inputs_.push_back(para);
63     return para;
64   }
65 
66   // Create a const value node
Value(const tensor::TensorPtr & data)67   NodePtr Value(const tensor::TensorPtr &data) const { return std::make_shared<ConstTensorNode>(data); }
68 
SetOutputs(const NodePtrList & nodes)69   void SetOutputs(const NodePtrList &nodes) const { graph_->output_->SetInputs(nodes); }
70 
71   // Emit op, auto inferring the baseinfo of Node.
72   NodePtr Emit(const std::string &op, const NodePtrList &inputs, const DAttrs &attrs = {}) const;
73 
74   // Create op node with given baseinfo.
75   NodePtr Op(const std::string &op, const NodeBaseList &baseinfolist, const NodePtrList &inputs,
76              const DAttrs &attrs = {}) const;
77   NodePtr Op(const std::string &op, const NodeBase &baseinfo, const NodePtrList &inputs,
78              const DAttrs &attrs = {}) const;
Get()79   LiteGraphPtr Get() const { return graph_; }
80 
81  private:
82   LiteGraphPtr graph_;
83 };
84 }  // namespace mindspore::graphkernel::inner
85 #endif
86