1 /** 2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). 3 * 4 * Copyright 2019 Huawei Technologies Co., Ltd 5 * 6 * Licensed under the Apache License, Version 2.0 (the "License"); 7 * you may not use this file except in compliance with the License. 8 * You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, software 13 * distributed under the License is distributed on an "AS IS" BASIS, 14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 * See the License for the specific language governing permissions and 16 * limitations under the License. 17 */ 18 19 #ifndef MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_H_ 20 #define MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_H_ 21 22 #include <vector> 23 #include <string> 24 #include <unordered_map> 25 #include <utility> 26 #include <map> 27 #include <set> 28 #include <memory> 29 #include "frontend/operator/composite/zip_operation.h" 30 #include "frontend/operator/composite/list_append_operation.h" 31 #include "frontend/operator/composite/do_signature.h" 32 #include "frontend/operator/composite/unpack_call.h" 33 #include "frontend/operator/composite/multitype_funcgraph.h" 34 #include "pipeline/jit/static_analysis/static_analysis.h" 35 #include "utils/misc.h" 36 #include "utils/any.h" 37 #include "ir/dtype.h" 38 #include "ir/meta_func_graph.h" 39 40 namespace mindspore { 41 // namespace to support composite operators definition 42 namespace prim { 43 using AbstractSlicePtr = abstract::AbstractSlicePtr; 44 using AbstractScalarPtr = abstract::AbstractScalarPtr; 45 using AbstractTensorPtr = abstract::AbstractTensorPtr; 46 using ElemwiseMap = std::unordered_map<std::string, PrimitivePtr>; 47 using ArgsPairList = std::vector<std::pair<AnfNodePtr, TypePtr>>; 48 49 class HyperMap : public MetaFuncGraph { 50 public: 51 explicit HyperMap(bool reverse = false, const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr); 52 HyperMap(const HyperMap &h); 53 void Init(); 54 HyperMap &operator=(const HyperMap &h) { 55 if (this != &h) { 56 fn_leaf_ = h.fn_leaf_; 57 reverse_ = h.reverse_; 58 broadcast_ = h.broadcast_; 59 nonleaf_ = h.nonleaf_; 60 if (fn_leaf_) { 61 name_ = "hyper_map[" + fn_leaf_->name() + "]"; 62 } 63 } 64 return *this; 65 } 66 ~HyperMap() override = default; 67 MS_DECLARE_PARENT(HyperMap, MetaFuncGraph) 68 69 abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_spec_list) const override; 70 FuncGraphPtr GenerateFromTypes(const TypePtrList &args_spec_list) override; GetFnLeaf()71 MetaFuncGraphPtr GetFnLeaf() { return fn_leaf_; } 72 73 private: 74 AnfNodePtr FullMake(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map); 75 AnfNodePtr FullMake(const std::shared_ptr<List> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, 76 const ArgsPairList &arg_map); 77 AnfNodePtr FullMake(const std::shared_ptr<Tuple> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, 78 const ArgsPairList &arg_map); 79 AnfNodePtr FullMake(const std::shared_ptr<Class> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, 80 const ArgsPairList &arg_map); 81 AnfNodePtr Make(const FuncGraphPtr &graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map); 82 ArgsPairList Harmonize(const FuncGraphPtr &graph, const ArgsPairList &args_spec_list); 83 84 MultitypeFuncGraphPtr fn_leaf_; 85 bool reverse_; 86 bool broadcast_; 87 std::set<TypeId> nonleaf_; 88 }; 89 using HyperMapPtr = std::shared_ptr<HyperMap>; 90 91 class HyperMapPy : public HyperMap { 92 public: 93 explicit HyperMapPy(bool reverse = false, const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr) HyperMap(reverse,fn_leaf)94 : HyperMap(reverse, fn_leaf) {} 95 ~HyperMapPy() override = default; 96 MS_DECLARE_PARENT(HyperMapPy, HyperMap) 97 }; 98 using HyperMapPyPtr = std::shared_ptr<HyperMapPy>; 99 100 extern ValuePtr kCompositeHyperMap; 101 102 enum TailType { kGradAll, kGradFirst, kNotGrad }; 103 104 class Tail : public MetaFuncGraph { 105 public: 106 explicit Tail(const std::string &name, TailType tail_type = kNotGrad) MetaFuncGraph(name)107 : MetaFuncGraph(name), tail_type_(tail_type), enable_tuple_grad_(false) {} 108 ~Tail() override = default; 109 MS_DECLARE_PARENT(Tail, MetaFuncGraph) 110 111 FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; 112 FuncGraphPtr GenerateSequeueFuncGraph(const abstract::AbstractSequeuePtr &sequeue) const; 113 114 friend bool operator==(const Tail &lhs, const Tail &rhs) { return lhs.name_ == rhs.name_; } set_enable_tuple_grad(bool enable_tuple_grad)115 void set_enable_tuple_grad(bool enable_tuple_grad) { enable_tuple_grad_ = enable_tuple_grad; } 116 117 private: 118 TailType tail_type_; 119 bool enable_tuple_grad_; 120 }; 121 using TailPtr = std::shared_ptr<Tail>; 122 123 class MakeTupleGradient : public MetaFuncGraph { 124 public: MakeTupleGradient(const std::string & name)125 explicit MakeTupleGradient(const std::string &name) : MetaFuncGraph(name) {} 126 ~MakeTupleGradient() override = default; 127 MS_DECLARE_PARENT(MakeTupleGradient, MetaFuncGraph) 128 FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; 129 friend bool operator==(const MakeTupleGradient &lhs, const MakeTupleGradient &rhs) { return lhs.name_ == rhs.name_; } 130 }; 131 using MakeTupleGradientPtr = std::shared_ptr<MakeTupleGradient>; 132 133 class MakeListGradient : public MetaFuncGraph { 134 public: MakeListGradient(const std::string & name)135 explicit MakeListGradient(const std::string &name) : MetaFuncGraph(name) {} 136 ~MakeListGradient() override = default; 137 MS_DECLARE_PARENT(MakeListGradient, MetaFuncGraph) 138 FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; 139 friend bool operator==(const MakeListGradient &lhs, const MakeListGradient &rhs) { return lhs.name_ == rhs.name_; } 140 }; 141 using MakeListGradientPtr = std::shared_ptr<MakeListGradient>; 142 143 class GradOperation : public MetaFuncGraph { 144 public: 145 explicit GradOperation(const std::string &name, bool get_all = false, bool get_by_list = false, 146 bool sens_param = false); 147 ~GradOperation() override = default; 148 MS_DECLARE_PARENT(GradOperation, MetaFuncGraph) 149 150 FuncGraphPtr GetGrad(const AnfNodePtr &k, const AnfNodePtr &weights, 151 const std::vector<AnfNodePtr> &forward_graph_params, bool enable_tuple_grad, 152 const std::vector<AnfNodePtr> &weight_args = {}); 153 154 FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; sens_param()155 bool sens_param() const { return sens_param_; } 156 bool get_all_; 157 bool get_by_list_; 158 bool sens_param_; 159 160 private: 161 void GradByParameter(const FuncGraphPtr &k_child, const AnfNodePtr &f_app, const AnfNodePtr &bprop, 162 const AnfNodePtr &weights, bool enable_tuple_grad); 163 }; 164 using GradOperationPtr = std::shared_ptr<GradOperation>; 165 166 class ListMap { 167 public: ListMap(const std::string & name)168 explicit ListMap(const std::string &name) : name_(name) { cache_.clear(); } 169 ~ListMap() = default; 170 void MakeCond(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr &gnext_ptr, const FuncGraphPtr &graph_ptr); 171 void MakeNext(const std::vector<AnfNodePtr> &lists, const FuncGraphPtr &gcond_ptr, const FuncGraphPtr &graph_ptr); 172 FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list); 173 174 private: 175 std::string name_; 176 std::map<std::vector<AnyPtr>, FuncGraphPtr> cache_; 177 }; 178 179 class TupleAdd : public MetaFuncGraph { 180 public: TupleAdd(const std::string & name)181 explicit TupleAdd(const std::string &name) : MetaFuncGraph(name) {} 182 ~TupleAdd() override = default; 183 MS_DECLARE_PARENT(TupleAdd, MetaFuncGraph) 184 FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; 185 friend bool operator==(const TupleAdd &lhs, const TupleAdd &rhs) { return lhs.name_ == rhs.name_; } 186 }; 187 using TupleAddPtr = std::shared_ptr<TupleAdd>; 188 189 class TupleSlice : public MetaFuncGraph { 190 public: TupleSlice(const std::string & name)191 explicit TupleSlice(const std::string &name) : MetaFuncGraph(name) {} 192 ~TupleSlice() override = default; 193 MS_DECLARE_PARENT(TupleSlice, MetaFuncGraph) 194 FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; 195 friend bool operator==(const TupleSlice &lhs, const TupleSlice &rhs) { return lhs.name_ == rhs.name_; } 196 }; 197 using TupleSlicePtr = std::shared_ptr<TupleSlice>; 198 199 class TupleGetItemTensor : public MetaFuncGraph { 200 public: TupleGetItemTensor(const std::string & name)201 explicit TupleGetItemTensor(const std::string &name) : MetaFuncGraph(name) {} 202 ~TupleGetItemTensor() override = default; 203 MS_DECLARE_PARENT(TupleGetItemTensor, MetaFuncGraph) 204 FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_spec_list) override; 205 friend bool operator==(const TupleGetItemTensor &lhs, const TupleGetItemTensor &rhs) { 206 return lhs.name_ == rhs.name_; 207 } 208 }; 209 using TupleGetItemTensorPtr = std::shared_ptr<TupleGetItemTensor>; 210 } // namespace prim 211 } // namespace mindspore 212 213 #endif // MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_H_ 214