1 /** 2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). 3 * 4 * Copyright 2019 Huawei Technologies Co., Ltd 5 * 6 * Licensed under the Apache License, Version 2.0 (the "License"); 7 * you may not use this file except in compliance with the License. 8 * You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, software 13 * distributed under the License is distributed on an "AS IS" BASIS, 14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 * See the License for the specific language governing permissions and 16 * limitations under the License. 17 */ 18 19 #ifndef MINDSPORE_CCSRC_VM_TRANSFORM_H_ 20 #define MINDSPORE_CCSRC_VM_TRANSFORM_H_ 21 22 #include <string> 23 #include <memory> 24 #include <functional> 25 #include <utility> 26 #include <unordered_map> 27 #include <vector> 28 29 #include "vm/vm.h" 30 #include "ir/anf.h" 31 #include "frontend/operator/ops.h" 32 #include "vm/segment_runner.h" 33 #include "vm/backend.h" 34 #include "vm/graph_partition.h" 35 36 // mindspore namespace is the top level namespace of MindSpore project. 37 // Other namespace should be a sub namespace of mindspore namespace in the ME project. 38 namespace mindspore { 39 extern const char kMsVm[]; 40 extern const char kGeVm[]; 41 42 // compile namespace 43 // A sub namespace in ME to support compile related definition. 44 namespace compile { 45 extern std::vector<PrimitivePtr> nonlinear_ops; 46 extern std::vector<PrimitivePtr> control_ops; 47 const std::vector<PrimitivePtr> &GetMsNonlinearOps(); 48 FuncGraphPtr WrapPrimitives(const FuncGraphPtr &graph); 49 using VmEvalFunc = std::function<BaseRef(const VectorRef &)>; 50 using VmEvalFuncPtr = std::shared_ptr<std::function<BaseRef(const VectorRef &)>>; 51 52 class CompileGraph { 53 public: 54 explicit CompileGraph(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list = nonlinear_ops); 55 56 virtual ~CompileGraph() = default; 57 58 InstSet Run(const FuncGraphPtr &func_graph); 59 bool IsCut(const AnfNodePtr &node); 60 void Push(const AnfNodePtr &node); Tie(const AnfNodePtr & n1,const AnfNodePtr & n2)61 void Tie(const AnfNodePtr &n1, const AnfNodePtr &n2) { slots_[n2] = slots_[n1]; } 62 void Ret(int64_t nargs); 63 virtual int64_t Ref(const AnfNodePtr &node); 64 set_height(int64_t h)65 void set_height(int64_t h) { 66 height_ = h; 67 if (height_ > max_height_) { 68 max_height_ = height_; 69 } 70 } 71 Reset()72 void Reset() { 73 height_ = 0; 74 max_height_ = 0; 75 slots_.clear(); 76 inst_.clear(); 77 } 78 79 protected: 80 virtual void PushParameters(const FuncGraphPtr &func_graph); 81 bool Compile(const FuncGraphPtr &func_graph); 82 int64_t LinConvert(const FuncGraphPtr &func_graph, const GraphSegmentPtr &segment, const std::string &target = ""); 83 int64_t InterpretNode(const FuncGraphPtr &func_graph, const CNodePtr &node); 84 virtual int64_t AddCall(const FuncGraphPtr &graph, const CNodePtr &node); 85 void AddPadStack(int64_t param_height); 86 void AddTailCall(const AnfNodePtr &fn, size_t size); 87 virtual void AddPartial(const CNodePtr &node); 88 void AddMakeTuple(const CNodePtr &node); 89 void AddSwitch(const CNodePtr &node); 90 void AddSwitchLayer(const CNodePtr &node); 91 void AddReturn(const CNodePtr &node); 92 void AddPrimitive(const CNodePtr &node, const PrimitivePtr &prim); 93 virtual void AddInput(const AnfNodePtr &node); 94 virtual void AddExternal(const LinConvertResult &result); 95 void AddInst(const Instruction &inst, const int64_t &arg); 96 void AddInst(const Instruction &inst, const ValuePtr &arg); 97 void AddInst(const Instruction &inst, const VectorRef &args); 98 99 BackendPtr backend_; 100 GraphPartitionPtr graph_partition_; 101 LinkFuncType lin_convert_; 102 103 int64_t height_{0}; 104 int64_t max_height_{0}; 105 106 std::unordered_map<AnfNodePtr, int64_t> slots_; 107 InstSet inst_; 108 }; 109 110 using CompileGraphPtr = std::shared_ptr<CompileGraph>; 111 112 // CompileGraphs is used to Convert a graph cluster into instruction lists. 113 class CompileGraphs { 114 public: 115 explicit CompileGraphs(const BackendPtr &backend, const std::vector<PrimitivePtr> &cut_list = nonlinear_ops); 116 117 virtual ~CompileGraphs() = default; 118 Reset()119 void Reset() { 120 insts_.clear(); 121 mapping_.clear(); 122 } 123 124 void Compile(const FuncGraphPtr &func_graph); 125 FinalVMPtr Link(); 126 FinalVMPtr CompileAndLink(const FuncGraphPtr &func_graph); 127 128 protected: 129 InstSet insts_; 130 std::unordered_map<FuncGraphPtr, int64_t> mapping_; 131 CompileGraphPtr transform_; 132 BackendPtr backend_; 133 }; 134 135 BackendPtr CreateBackend(); 136 137 // Set mindRT whether enable. GPU and CPU use mindRT currently, and other hardwares will use it in the future. 138 void SetMindRTEnable(); 139 140 } // namespace compile 141 } // namespace mindspore 142 143 #endif // MINDSPORE_CCSRC_VM_TRANSFORM_H_ 144