1 /** 2 * Copyright 2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_LITE_GRAPH_H_ 17 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_LITE_GRAPH_H_ 18 19 #include <memory> 20 #include <vector> 21 #include <list> 22 #include <unordered_map> 23 #include <unordered_set> 24 #include <stack> 25 #include <string> 26 #include "backend/optimizer/graph_kernel/model/node.h" 27 #include "backend/optimizer/graph_kernel/model/op_node.h" 28 29 namespace mindspore { 30 namespace opt { 31 namespace graphkernel { 32 class LiteGraph { 33 public: 34 class GraphBuilder; name_(name)35 explicit LiteGraph(const std::string &name = "") : name_(name), output_(new OutputNode()) {} 36 ~LiteGraph() = default; Add(PrimOpPtr op)37 NodePtr &Add(PrimOpPtr op) { 38 ops_.emplace_back(op); 39 return ops_.back(); 40 } 41 42 const NodePtrList &GetOrderedNodes(); 43 44 std::string Dump() const; name()45 const std::string &name() const { return name_; } ops()46 const NodePtrList &ops() const { return ops_; } inputs()47 const NodePtrList &inputs() const { return inputs_; } output()48 const NodePtr &output() const { return output_; } GetOutputs()49 const NodePtrList &GetOutputs() const { return output_->inputs(); } 50 51 protected: 52 std::string name_; 53 NodePtrList ops_; // save all operators in topo order 54 NodePtrList inputs_; 55 NodePtr output_; 56 57 private: 58 int name_id_{0}; 59 }; 60 using LiteGraphPtr = std::shared_ptr<LiteGraph>; 61 62 class LiteGraph::GraphBuilder { 63 public: 64 explicit GraphBuilder(const std::string &name = "") { graph_ = std::make_shared<LiteGraph>(name); } 65 ~GraphBuilder() = default; 66 NodePtr Parameter(const NodeBase &baseinfo, std::string name = "") { 67 if (name.empty()) name = NewName(); 68 auto para = std::make_shared<ParamNode>(name, baseinfo); 69 graph_->inputs_.push_back(para); 70 return para; 71 } 72 NodePtr Value(const tensor::TensorPtr &data, const std::string &name = "") { 73 return std::make_shared<ConstTensorNode>(data, name); 74 } 75 SetOutputs(const NodePtrList & nodes)76 void SetOutputs(const NodePtrList &nodes) { graph_->output_->SetInputs(nodes); } 77 78 NodePtr Emit(const std::string &op, const NodePtrList &inputs, const DAttrs &attrs = {}, std::string node_name = ""); 79 NodePtr Op(const std::string &op, const NodeBase &baseinfo, const NodePtrList &inputs, const DAttrs &attrs = {}, 80 std::string node_name = ""); Get()81 LiteGraphPtr Get() { return graph_; } 82 83 private: 84 PrimOpPtr CreateOp(const std::string &id, const std::string &name); 85 std::string NewName(std::string prefix = "output_") { return prefix + std::to_string(graph_->name_id_++); } 86 87 LiteGraphPtr graph_; 88 }; 89 } // namespace graphkernel 90 } // namespace opt 91 } // namespace mindspore 92 #endif 93