1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_MAP_H_ 18 #define MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_MAP_H_ 19 20 #include <memory> 21 #include <set> 22 #include <utility> 23 #include <vector> 24 25 #include "ir/dtype.h" 26 #include "ir/meta_func_graph.h" 27 #include "frontend/operator/composite/multitype_funcgraph.h" 28 29 namespace mindspore { 30 // namespace to support composite operators definition 31 namespace prim { 32 using ArgsPairList = std::vector<std::pair<AnfNodePtr, TypePtr>>; 33 34 class Map : public MetaFuncGraph { 35 public: 36 explicit Map(bool reverse = false, const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr) 37 : MetaFuncGraph("map"), 38 fn_leaf_(fn_leaf), 39 reverse_(reverse), 40 broadcast_(false), 41 nonleaf_({kObjectTypeList, kObjectTypeTuple, kObjectTypeClass}) { 42 Init(); 43 } Map(const Map & map)44 Map(const Map &map) 45 : MetaFuncGraph("map"), 46 fn_leaf_(map.fn_leaf_), 47 reverse_(map.reverse_), 48 broadcast_(map.broadcast_), 49 nonleaf_(map.nonleaf_) { 50 Init(); 51 } 52 Map &operator=(const Map &map) { 53 if (this != &map) { 54 fn_leaf_ = map.fn_leaf_; 55 reverse_ = map.reverse_; 56 broadcast_ = map.broadcast_; 57 nonleaf_ = map.nonleaf_; 58 if (fn_leaf_) { 59 name_ = "map[" + fn_leaf_->name() + "]"; 60 } 61 } 62 return *this; 63 } 64 ~Map() override = default; 65 MS_DECLARE_PARENT(Map, MetaFuncGraph) 66 abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_spec_list) const override; 67 FuncGraphPtr GenerateFromTypes(const TypePtrList &args_spec_list) override; GetFnLeaf()68 MetaFuncGraphPtr GetFnLeaf() { return fn_leaf_; } 69 70 private: 71 FuncGraphPtr GenerateLeafFunc(const size_t &args_size); 72 AnfNodePtr FullMakeLeaf(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const AnfNodePtrList &args); 73 AnfNodePtr FullMakeList(const std::shared_ptr<List> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, 74 const ArgsPairList &arg_pairs); 75 AnfNodePtr FullMakeTuple(const std::shared_ptr<Tuple> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, 76 const ArgsPairList &arg_pairs); 77 AnfNodePtr FullMakeClass(const std::shared_ptr<Class> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, 78 const ArgsPairList &arg_pairs); 79 AnfNodePtr Make(const FuncGraphPtr &graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs); Init()80 void Init() { 81 if (fn_leaf_ != nullptr) { 82 name_ = "map[" + fn_leaf_->name() + "]"; 83 } 84 signatures_ = 85 // def map(func:read, *args:ref): 86 std::vector<Signature>({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault}, 87 {"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}}); 88 } 89 90 MultitypeFuncGraphPtr fn_leaf_; 91 bool reverse_; 92 bool broadcast_; 93 std::set<TypeId> nonleaf_; 94 }; 95 using MapPtr = std::shared_ptr<Map>; 96 class MapPy : public Map { 97 public: 98 explicit MapPy(bool reverse = false, const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr) Map(reverse,fn_leaf)99 : Map(reverse, fn_leaf) {} 100 ~MapPy() override = default; 101 MS_DECLARE_PARENT(MapPy, Map) 102 }; 103 using MapPyPtr = std::shared_ptr<MapPy>; 104 } // namespace prim 105 } // namespace mindspore 106 107 #endif // MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_MAP_H_ 108