1 /** 2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). 3 * 4 * Copyright 2019-2024 Huawei Technologies Co., Ltd 5 * 6 * Licensed under the Apache License, Version 2.0 (the "License"); 7 * you may not use this file except in compliance with the License. 8 * You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, software 13 * distributed under the License is distributed on an "AS IS" BASIS, 14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 * See the License for the specific language governing permissions and 16 * limitations under the License. 17 */ 18 19 #ifndef MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_H_ 20 #define MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_H_ 21 22 #include <vector> 23 #include <string> 24 #include <utility> 25 #include <map> 26 #include <set> 27 #include <memory> 28 #include "utils/hash_map.h" 29 #include "frontend/operator/composite/zip_operation.h" 30 #include "frontend/operator/composite/list_operation.h" 31 #include "frontend/operator/composite/do_signature.h" 32 #include "frontend/operator/composite/unpack_call.h" 33 #include "frontend/operator/composite/multitype_funcgraph.h" 34 #include "frontend/operator/composite/starred_operation.h" 35 #include "pipeline/jit/ps/static_analysis/static_analysis.h" 36 #include "utils/misc.h" 37 #include "utils/any.h" 38 #include "ir/dtype.h" 39 #include "ir/meta_func_graph.h" 40 41 namespace mindspore { 42 // namespace to support composite operators definition 43 namespace prim { 44 using AbstractSlicePtr = abstract::AbstractSlicePtr; 45 using AbstractScalarPtr = abstract::AbstractScalarPtr; 46 using AbstractTensorPtr = abstract::AbstractTensorPtr; 47 using ElemwiseMap = mindspore::HashMap<std::string, PrimitivePtr>; 48 using ArgsPairList = std::vector<std::pair<AnfNodePtr, TypePtr>>; 49 using AbstractListPtr = abstract::AbstractListPtr; 50 51 class HyperMap : public MetaFuncGraph { 52 public: 53 explicit HyperMap(bool reverse = false, const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr); 54 HyperMap(const HyperMap &h); 55 void Init(); 56 HyperMap &operator=(const HyperMap &h) noexcept { 57 if (this != &h) { 58 fn_leaf_ = h.fn_leaf_; 59 reverse_ = h.reverse_; 60 nonleaf_ = h.nonleaf_; 61 if (fn_leaf_) { 62 name_ = "hyper_map[" + fn_leaf_->name() + "]"; 63 } 64 } 65 return *this; 66 } 67 ~HyperMap() override = default; 68 MS_DECLARE_PARENT(HyperMap, MetaFuncGraph) 69 70 abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_abs_list) const override; 71 FuncGraphPtr GenerateFromTypes(const TypePtrList &args_abs_list) override; GetFnLeaf()72 MetaFuncGraphPtr GetFnLeaf() { return fn_leaf_; } 73 void SetObjectForFnLeaf(const py::object &leaf_object); 74 75 private: 76 AnfNodePtr FullMake(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) const; 77 AnfNodePtr FullMake(const std::shared_ptr<List> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, 78 const ArgsPairList &arg_map) const; 79 AnfNodePtr FullMake(const std::shared_ptr<Tuple> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, 80 const ArgsPairList &arg_map) const; 81 AnfNodePtr FullMake(const std::shared_ptr<Dictionary> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, 82 const ArgsPairList &arg_map) const; 83 AnfNodePtr Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_map) const; 84 std::pair<std::string, std::string> GetHyperMapInputIndex(size_t num) const; 85 86 MultitypeFuncGraphPtr fn_leaf_; 87 bool reverse_; 88 std::set<TypeId> nonleaf_; 89 }; 90 using HyperMapPtr = std::shared_ptr<HyperMap>; 91 92 class HyperMapPy : public HyperMap { 93 public: 94 explicit HyperMapPy(bool reverse = false, const py::object &fn_leaf = py::none()) 95 : HyperMap(reverse, fn_leaf.cast<prim::MultitypeFuncGraphPtr>()) { 96 SetObjectForFnLeaf(fn_leaf); 97 } 98 ~HyperMapPy() override = default; 99 MS_DECLARE_PARENT(HyperMapPy, HyperMap) 100 }; 101 using HyperMapPyPtr = std::shared_ptr<HyperMapPy>; 102 103 extern ValuePtr kCompositeHyperMap; 104 105 enum TailType { kGradAll, kGradFirst, kGradByPosition, kNotGrad }; 106 107 class Tail : public MetaFuncGraph { 108 public: 109 explicit Tail(const std::string &name, TailType tail_type = kNotGrad, bool return_ids = false) MetaFuncGraph(name)110 : MetaFuncGraph(name), tail_type_(tail_type), enable_tuple_grad_first_(false), return_ids_(return_ids) {} 111 ~Tail() override = default; 112 MS_DECLARE_PARENT(Tail, MetaFuncGraph) 113 114 FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) override; 115 116 friend bool operator==(const Tail &lhs, const Tail &rhs) { return lhs.name_ == rhs.name_; } set_enable_tuple_grad_first(bool enable_tuple_grad_first)117 void set_enable_tuple_grad_first(bool enable_tuple_grad_first) { enable_tuple_grad_first_ = enable_tuple_grad_first; } 118 119 private: 120 FuncGraphPtr GenerateTailFuncGraph(const abstract::AbstractSequencePtr &sequence_arg) const; 121 FuncGraphPtr GenerateGradFuncGraph(const abstract::AbstractTuplePtr &tuple_arg, 122 const abstract::AbstractTuplePtr &position = nullptr) const; 123 124 TailType tail_type_; 125 bool enable_tuple_grad_first_; 126 bool return_ids_; 127 }; 128 using TailPtr = std::shared_ptr<Tail>; 129 130 class MakeTupleGradient : public MetaFuncGraph { 131 public: MakeTupleGradient(const std::string & name)132 explicit MakeTupleGradient(const std::string &name) : MetaFuncGraph(name) {} 133 ~MakeTupleGradient() override = default; 134 MS_DECLARE_PARENT(MakeTupleGradient, MetaFuncGraph) 135 FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) override; 136 friend bool operator==(const MakeTupleGradient &lhs, const MakeTupleGradient &rhs) { return lhs.name_ == rhs.name_; } 137 }; 138 using MakeTupleGradientPtr = std::shared_ptr<MakeTupleGradient>; 139 140 class MakeListGradient : public MetaFuncGraph { 141 public: MakeListGradient(const std::string & name)142 explicit MakeListGradient(const std::string &name) : MetaFuncGraph(name) {} 143 ~MakeListGradient() override = default; 144 MS_DECLARE_PARENT(MakeListGradient, MetaFuncGraph) 145 FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) override; 146 friend bool operator==(const MakeListGradient &lhs, const MakeListGradient &rhs) { return lhs.name_ == rhs.name_; } 147 }; 148 using MakeListGradientPtr = std::shared_ptr<MakeListGradient>; 149 150 class MakeDictGradient : public MetaFuncGraph { 151 public: MakeDictGradient(const std::string & name)152 explicit MakeDictGradient(const std::string &name) : MetaFuncGraph(name) {} 153 ~MakeDictGradient() override = default; 154 MS_DECLARE_PARENT(MakeDictGradient, MetaFuncGraph) 155 FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) override; 156 friend bool operator==(const MakeDictGradient &lhs, const MakeDictGradient &rhs) { return lhs.name_ == rhs.name_; } 157 }; 158 using MakeDictGradientPtr = std::shared_ptr<MakeDictGradient>; 159 160 class PyExecuteGradient : public MetaFuncGraph { 161 public: PyExecuteGradient(const std::string & name)162 explicit PyExecuteGradient(const std::string &name) : MetaFuncGraph(name) {} 163 ~PyExecuteGradient() override = default; 164 MS_DECLARE_PARENT(PyExecuteGradient, MetaFuncGraph) 165 FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) override; 166 friend bool operator==(const PyExecuteGradient &lhs, const PyExecuteGradient &rhs) { return lhs.name_ == rhs.name_; } 167 }; 168 using PyExecuteGradientPtr = std::shared_ptr<PyExecuteGradient>; 169 170 class MutableGradient : public MetaFuncGraph { 171 public: MutableGradient(const std::string & name)172 explicit MutableGradient(const std::string &name) : MetaFuncGraph(name) {} 173 ~MutableGradient() override = default; 174 MS_DECLARE_PARENT(MutableGradient, MetaFuncGraph) 175 FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) override; 176 friend bool operator==(const MutableGradient &lhs, const MutableGradient &rhs) { return lhs.name_ == rhs.name_; } 177 }; 178 using MutableGradientPtr = std::shared_ptr<MutableGradient>; 179 180 class GradOperation : public MetaFuncGraph { 181 public: 182 explicit GradOperation(const std::string &name, bool get_all = false, bool get_by_list = false, 183 bool sens_param = false, bool get_by_position = false, bool has_aux = false, 184 bool get_value = false, bool return_ids = false, bool merge_forward = false); 185 ~GradOperation() override = default; 186 MS_DECLARE_PARENT(GradOperation, MetaFuncGraph) 187 188 FuncGraphPtr GetGrad(const AnfNodePtr &j, const AnfNodePtr &weights, const AnfNodePtr &position, 189 const std::vector<AnfNodePtr> &forward_graph_params, bool enable_tuple_grad, 190 bool is_weights_none) const; 191 192 FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) override; 193 sens_param()194 bool sens_param() const { return sens_param_; } 195 196 bool get_all_; 197 bool get_by_list_; 198 bool sens_param_; 199 bool get_by_position_; 200 bool has_aux_; 201 bool get_value_; 202 bool return_ids_; 203 bool merge_forward_; 204 205 private: 206 void GradByParameter(const FuncGraphPtr &k_child, const AnfNodePtr &f_app, const AnfNodePtr &bprop, 207 const AnfNodePtr &weights, const AnfNodePtr &position, bool enable_tuple_grad, 208 bool is_weights_none) const; 209 CNodePtr SetNodeByParameter(const CNodePtr &grad, const FuncGraphPtr &fg) const; 210 AbstractBasePtr weight_value_; 211 }; 212 using GradOperationPtr = std::shared_ptr<GradOperation>; 213 214 class GradAux : public MetaFuncGraph { 215 public: GradAux(const std::string & name)216 explicit GradAux(const std::string &name) : MetaFuncGraph(name) {} 217 ~GradAux() override = default; 218 MS_DECLARE_PARENT(GradAux, MetaFuncGraph); 219 FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) override; 220 }; 221 using GradAuxPtr = std::shared_ptr<GradAux>; 222 223 class TaylorOperation : public MetaFuncGraph { 224 public: 225 explicit TaylorOperation(const std::string &name); 226 ~TaylorOperation() override = default; 227 MS_DECLARE_PARENT(TaylorOperation, MetaFuncGraph); 228 FuncGraphPtr GetTaylorGrad(const AnfNodePtr &k, const std::vector<AnfNodePtr> &forward_graph_params) const; 229 230 FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) override; 231 }; 232 using TaylorOperationPtr = std::shared_ptr<TaylorOperation>; 233 234 class TupleAdd : public MetaFuncGraph { 235 public: TupleAdd(const std::string & name)236 explicit TupleAdd(const std::string &name) : MetaFuncGraph(name) {} 237 ~TupleAdd() override = default; 238 MS_DECLARE_PARENT(TupleAdd, MetaFuncGraph) 239 FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) override; 240 friend bool operator==(const TupleAdd &lhs, const TupleAdd &rhs) { return lhs.name_ == rhs.name_; } 241 }; 242 using TupleAddPtr = std::shared_ptr<TupleAdd>; 243 244 class ListAdd : public MetaFuncGraph { 245 public: ListAdd(const std::string & name)246 explicit ListAdd(const std::string &name) : MetaFuncGraph(name) {} 247 ~ListAdd() override = default; 248 MS_DECLARE_PARENT(ListAdd, MetaFuncGraph) 249 FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) override; 250 friend bool operator==(const ListAdd &lhs, const ListAdd &rhs) { return lhs.name_ == rhs.name_; } 251 }; 252 using ListAddPtr = std::shared_ptr<ListAdd>; 253 254 class SequenceSlice : public MetaFuncGraph { 255 public: SequenceSlice(const std::string & name)256 explicit SequenceSlice(const std::string &name) : MetaFuncGraph(name) {} 257 ~SequenceSlice() override = default; 258 MS_DECLARE_PARENT(SequenceSlice, MetaFuncGraph) 259 FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) final; 260 friend bool operator==(const SequenceSlice &lhs, const SequenceSlice &rhs) { return lhs.name_ == rhs.name_; } 261 262 protected: 263 virtual void CheckArgs(const AbstractBasePtrList &args_abs_list) = 0; 264 virtual FuncGraphPtr BuildFuncGraph(int64_t start_index, int64_t stop_index, int64_t step_value) = 0; 265 abstract::AbstractSequencePtr sequence_ = nullptr; 266 AbstractSlicePtr slice_ = nullptr; 267 }; 268 269 class SequenceSliceGetItem : public SequenceSlice { 270 public: SequenceSliceGetItem(const std::string & name,const std::string & prim_name,const std::string & get_item_name)271 explicit SequenceSliceGetItem(const std::string &name, const std::string &prim_name, const std::string &get_item_name) 272 : SequenceSlice(name), 273 prim_(std::make_shared<Primitive>(prim_name)), 274 get_item_(std::make_shared<Primitive>(get_item_name)) {} 275 ~SequenceSliceGetItem() override = default; 276 MS_DECLARE_PARENT(SequenceSliceGetItem, MetaFuncGraph) 277 friend bool operator==(const SequenceSliceGetItem &lhs, const SequenceSliceGetItem &rhs) { 278 return lhs.name_ == rhs.name_; 279 } 280 281 protected: 282 void CheckArgs(const AbstractBasePtrList &args_abs_list) override; 283 FuncGraphPtr BuildFuncGraph(int64_t start_index, int64_t stop_index, int64_t step_value) override; 284 285 private: 286 PrimitivePtr prim_; 287 PrimitivePtr get_item_; 288 }; 289 290 class ListSliceSetItem : public SequenceSlice { 291 public: ListSliceSetItem(const std::string & name)292 explicit ListSliceSetItem(const std::string &name) : SequenceSlice(name) {} 293 ~ListSliceSetItem() override = default; 294 MS_DECLARE_PARENT(ListSliceSetItem, MetaFuncGraph) 295 friend bool operator==(const ListSliceSetItem &lhs, const ListSliceSetItem &rhs) { return lhs.name_ == rhs.name_; } 296 297 protected: 298 void CheckArgs(const AbstractBasePtrList &args_abs_list) override; 299 FuncGraphPtr BuildFuncGraph(int64_t start_index, int64_t stop_index, int64_t step_value) override; 300 301 private: 302 void CheckAssignRange(int64_t start_index, int64_t stop_index, int64_t step_value); 303 AnfNodePtr GetAssignNode(const FuncGraphPtr &func_graph, const AnfNodePtr &assign_node, int64_t step_value); 304 AbstractListPtr value_list_ = nullptr; 305 }; 306 307 class TupleGetItemTensor : public MetaFuncGraph { 308 public: TupleGetItemTensor(const std::string & name)309 explicit TupleGetItemTensor(const std::string &name) : MetaFuncGraph(name) {} 310 ~TupleGetItemTensor() override = default; 311 MS_DECLARE_PARENT(TupleGetItemTensor, MetaFuncGraph) 312 FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) override; 313 friend bool operator==(const TupleGetItemTensor &lhs, const TupleGetItemTensor &rhs) { 314 return lhs.name_ == rhs.name_; 315 } 316 }; 317 using TupleGetItemTensorPtr = std::shared_ptr<TupleGetItemTensor>; 318 319 class Shard : public MetaFuncGraph { 320 public: Shard(const string & name)321 explicit Shard(const string &name) : MetaFuncGraph(name) { 322 signatures_ = 323 // def shard(func:read, weight_list:read, in_axes:read, out_axes:read, parameter_plan:read, device:read, 324 // level:read): 325 std::vector<Signature>({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault}, 326 {"in_axes", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault}, 327 {"out_axes", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault}, 328 {"device", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault}, 329 {"level", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault}}); 330 kShardInputSize = signatures_.size(); 331 } 332 ~Shard() override = default; 333 MS_DECLARE_PARENT(Shard, MetaFuncGraph) 334 335 FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) override; 336 337 private: 338 size_t kShardInputSize = 0; 339 }; 340 341 class VmapOperation : public MetaFuncGraph { 342 public: 343 explicit VmapOperation(const std::string &name); 344 ~VmapOperation() override = default; 345 MS_DECLARE_PARENT(VmapOperation, MetaFuncGraph) 346 347 FuncGraphPtr GetVmap(const AnfNodePtr &vmap, int param_number) const; 348 349 FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) override; 350 }; 351 using VmapOperationPtr = std::shared_ptr<VmapOperation>; 352 353 class ZerosLike : public MetaFuncGraph { 354 public: 355 explicit ZerosLike(const std::string &name, const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr) MetaFuncGraph(name)356 : MetaFuncGraph(name), fn_leaf_(fn_leaf) {} 357 ~ZerosLike() override = default; 358 MS_DECLARE_PARENT(ZerosLike, MetaFuncGraph) 359 FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) override; 360 friend bool operator==(const ZerosLike &lhs, const ZerosLike &rhs) { return lhs.name_ == rhs.name_; } 361 362 private: 363 MultitypeFuncGraphPtr fn_leaf_; 364 }; 365 using ZerosLikePtr = std::shared_ptr<ZerosLike>; 366 367 class IterConverter : public MetaFuncGraph { 368 public: IterConverter(const std::string & name)369 explicit IterConverter(const std::string &name) : MetaFuncGraph(name) {} 370 ~IterConverter() override = default; 371 MS_DECLARE_PARENT(IterConverter, MetaFuncGraph) 372 FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) override; 373 friend bool operator==(const IterConverter &lhs, const IterConverter &rhs) { return lhs.name_ == rhs.name_; } 374 }; 375 using IterConverterPtr = std::shared_ptr<IterConverter>; 376 377 class HasNext : public MetaFuncGraph { 378 public: HasNext(const std::string & name)379 explicit HasNext(const std::string &name) : MetaFuncGraph(name) {} 380 ~HasNext() override = default; 381 MS_DECLARE_PARENT(HasNext, MetaFuncGraph) 382 FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) override; 383 friend bool operator==(const HasNext &lhs, const HasNext &rhs) { return lhs.name_ == rhs.name_; } 384 }; 385 using HasNextPtr = std::shared_ptr<HasNext>; 386 387 class Next : public MetaFuncGraph { 388 public: Next(const std::string & name)389 explicit Next(const std::string &name) : MetaFuncGraph(name) {} 390 ~Next() override = default; 391 MS_DECLARE_PARENT(Next, MetaFuncGraph) 392 FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) override; 393 friend bool operator==(const Next &lhs, const Next &rhs) { return lhs.name_ == rhs.name_; } 394 }; 395 using NextPtr = std::shared_ptr<Next>; 396 397 class TupleFunc : public MetaFuncGraph { 398 public: TupleFunc(const std::string & name)399 explicit TupleFunc(const std::string &name) : MetaFuncGraph(name) {} 400 ~TupleFunc() override = default; 401 MS_DECLARE_PARENT(TupleFunc, MetaFuncGraph) 402 FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) override; 403 friend bool operator==(const TupleFunc &lhs, const TupleFunc &rhs) { return lhs.name_ == rhs.name_; } 404 }; 405 using TupleFuncPtr = std::shared_ptr<TupleFunc>; 406 407 class ListFunc : public MetaFuncGraph { 408 public: ListFunc(const std::string & name)409 explicit ListFunc(const std::string &name) : MetaFuncGraph(name) {} 410 ~ListFunc() override = default; 411 MS_DECLARE_PARENT(ListFunc, MetaFuncGraph) 412 FuncGraphPtr GenerateFuncGraph(const AbstractBasePtrList &args_abs_list) override; 413 friend bool operator==(const ListFunc &lhs, const ListFunc &rhs) { return lhs.name_ == rhs.name_; } 414 }; 415 using ListFuncPtr = std::shared_ptr<ListFunc>; 416 } // namespace prim 417 } // namespace mindspore 418 419 #endif // MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_H_ 420