1 /** 2 * Copyright 2020 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_MAP_H_ 18 #define MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_MAP_H_ 19 20 #include <memory> 21 #include <set> 22 #include <utility> 23 #include <vector> 24 #include <string> 25 26 #include "ir/dtype.h" 27 #include "ir/meta_func_graph.h" 28 #include "frontend/operator/composite/multitype_funcgraph.h" 29 30 namespace mindspore { 31 // namespace to support composite operators definition 32 namespace prim { 33 using ArgsPairList = std::vector<std::pair<AnfNodePtr, TypePtr>>; 34 35 class Map : public MetaFuncGraph { 36 public: 37 explicit Map(bool reverse = false, const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr) 38 : MetaFuncGraph("map"), 39 fn_leaf_(fn_leaf), 40 reverse_(reverse), 41 broadcast_(false), 42 nonleaf_({kObjectTypeList, kObjectTypeTuple}) { 43 Init(); 44 } Map(const Map & map)45 Map(const Map &map) 46 : MetaFuncGraph("map"), 47 fn_leaf_(map.fn_leaf_), 48 reverse_(map.reverse_), 49 broadcast_(map.broadcast_), 50 nonleaf_(map.nonleaf_) { 51 Init(); 52 } 53 Map &operator=(const Map &map) noexcept { 54 if (this != &map) { 55 fn_leaf_ = map.fn_leaf_; 56 reverse_ = map.reverse_; 57 broadcast_ = map.broadcast_; 58 nonleaf_ = map.nonleaf_; 59 if (fn_leaf_) { 60 name_ = "map[" + fn_leaf_->name() + "]"; 61 } 62 } 63 return *this; 64 } 65 ~Map() override = default; 66 MS_DECLARE_PARENT(Map, MetaFuncGraph) 67 abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_abs_list) const override; 68 FuncGraphPtr GenerateFromTypes(const TypePtrList &args_abs_list) override; GetFnLeaf()69 MetaFuncGraphPtr GetFnLeaf() { return fn_leaf_; } 70 71 private: 72 FuncGraphPtr GenerateLeafFunc(const size_t &args_size); 73 AnfNodePtr FullMakeLeaf(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const AnfNodePtrList &args); 74 AnfNodePtr FullMakeList(const std::shared_ptr<List> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, 75 const ArgsPairList &arg_pairs); 76 AnfNodePtr FullMakeTuple(const std::shared_ptr<Tuple> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, 77 const ArgsPairList &arg_pairs); 78 AnfNodePtr Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs); 79 std::pair<std::string, std::string> GetMapInputIndex(size_t num) const; Init()80 void Init() { 81 if (fn_leaf_ != nullptr) { 82 name_ = "map[" + fn_leaf_->name() + "]"; 83 } 84 signatures_ = 85 // def map(func:read, *args:ref): 86 std::vector<Signature>({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault}, 87 {"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}}); 88 } 89 90 MultitypeFuncGraphPtr fn_leaf_; 91 bool reverse_; 92 bool broadcast_; 93 std::set<TypeId> nonleaf_; 94 }; 95 using MapPtr = std::shared_ptr<Map>; 96 class MapPy : public Map { 97 public: 98 explicit MapPy(bool reverse = false, const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr) Map(reverse,fn_leaf)99 : Map(reverse, fn_leaf) {} 100 ~MapPy() override = default; 101 MS_DECLARE_PARENT(MapPy, Map) 102 }; 103 using MapPyPtr = std::shared_ptr<MapPy>; 104 } // namespace prim 105 } // namespace mindspore 106 107 #endif // MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_MAP_H_ 108