1 /** 2 * Copyright 2021-2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_LITE_GRAPH_H_ 17 #define MINDSPORE_CCSRC_BACKEND_OPTIMIZER_GRAPH_KERNEL_MODEL_LITE_GRAPH_H_ 18 19 #include <memory> 20 #include <string> 21 #include "backend/common/graph_kernel/model/node.h" 22 23 namespace mindspore::graphkernel::inner { 24 class LiteGraph { 25 public: 26 class GraphBuilderBase; name_(name)27 explicit LiteGraph(const std::string &name = "") : name_(name), output_(new OutputNode()) {} 28 29 const NodePtrList &GetOrderedNodes(); 30 std::string ToString(bool reset_node_name = false) const; name()31 const std::string &name() const { return name_; } ops()32 const NodePtrList &ops() const { return ops_; } inputs()33 const NodePtrList &inputs() const { return inputs_; } output(size_t i)34 const NodePtr &output(size_t i) const { return output_->input(i); } GetOutputs()35 const NodePtrList &GetOutputs() const { return output_->inputs(); } 36 SetOutput(size_t i,const NodePtr & node)37 void SetOutput(size_t i, const NodePtr &node) { output_->SetInput(i, node); } SetOutputs(const NodePtrList & nodes)38 void SetOutputs(const NodePtrList &nodes) { output_->SetInputs(nodes); } 39 40 protected: 41 std::string name_; 42 NodePtrList ops_; // save all operators in topo order 43 NodePtrList inputs_; 44 NodePtr output_; 45 46 private: ParamName()47 std::string ParamName() const { return "input_" + std::to_string(param_id_++); } NodeName()48 std::string NodeName() const { return "output_" + std::to_string(node_id_++); } 49 mutable int param_id_{0}; 50 mutable int node_id_{0}; 51 }; 52 using LiteGraphPtr = std::shared_ptr<LiteGraph>; 53 class LiteGraph::GraphBuilderBase { 54 public: 55 explicit GraphBuilderBase(const std::string &name = "") { graph_ = std::make_shared<LiteGraph>(name); } 56 ~GraphBuilderBase() = default; 57 58 // Create a parameter of graph Parameter(const NodeBase & baseinfo)59 NodePtr Parameter(const NodeBase &baseinfo) const { 60 auto para = std::make_shared<ParamNode>(baseinfo); 61 para->SetDebugName(graph_->ParamName()); 62 graph_->inputs_.push_back(para); 63 return para; 64 } 65 66 // Create a const value node Value(const tensor::TensorPtr & data)67 NodePtr Value(const tensor::TensorPtr &data) const { return std::make_shared<ConstTensorNode>(data); } 68 SetOutputs(const NodePtrList & nodes)69 void SetOutputs(const NodePtrList &nodes) const { graph_->output_->SetInputs(nodes); } 70 71 // Emit op, auto inferring the baseinfo of Node. 72 NodePtr Emit(const std::string &op, const NodePtrList &inputs, const DAttrs &attrs = {}) const; 73 74 // Create op node with given baseinfo. 75 NodePtr Op(const std::string &op, const NodeBaseList &baseinfolist, const NodePtrList &inputs, 76 const DAttrs &attrs = {}) const; 77 NodePtr Op(const std::string &op, const NodeBase &baseinfo, const NodePtrList &inputs, 78 const DAttrs &attrs = {}) const; Get()79 LiteGraphPtr Get() const { return graph_; } 80 81 private: 82 LiteGraphPtr graph_; 83 }; 84 } // namespace mindspore::graphkernel::inner 85 #endif 86