• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_MAP_H_
18 #define MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_MAP_H_
19 
20 #include <memory>
21 #include <set>
22 #include <utility>
23 #include <vector>
24 #include <string>
25 
26 #include "ir/dtype.h"
27 #include "ir/meta_func_graph.h"
28 #include "frontend/operator/composite/multitype_funcgraph.h"
29 
30 namespace mindspore {
31 // namespace to support composite operators definition
32 namespace prim {
33 using ArgsPairList = std::vector<std::pair<AnfNodePtr, TypePtr>>;
34 
35 class Map : public MetaFuncGraph {
36  public:
37   explicit Map(bool reverse = false, const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr)
38       : MetaFuncGraph("map"),
39         fn_leaf_(fn_leaf),
40         reverse_(reverse),
41         broadcast_(false),
42         nonleaf_({kObjectTypeList, kObjectTypeTuple}) {
43     Init();
44   }
Map(const Map & map)45   Map(const Map &map)
46       : MetaFuncGraph("map"),
47         fn_leaf_(map.fn_leaf_),
48         reverse_(map.reverse_),
49         broadcast_(map.broadcast_),
50         nonleaf_(map.nonleaf_) {
51     Init();
52   }
53   Map &operator=(const Map &map) noexcept {
54     if (this != &map) {
55       fn_leaf_ = map.fn_leaf_;
56       reverse_ = map.reverse_;
57       broadcast_ = map.broadcast_;
58       nonleaf_ = map.nonleaf_;
59       if (fn_leaf_) {
60         name_ = "map[" + fn_leaf_->name() + "]";
61       }
62     }
63     return *this;
64   }
65   ~Map() override = default;
66   MS_DECLARE_PARENT(Map, MetaFuncGraph)
67   abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_abs_list) const override;
68   FuncGraphPtr GenerateFromTypes(const TypePtrList &args_abs_list) override;
GetFnLeaf()69   MetaFuncGraphPtr GetFnLeaf() { return fn_leaf_; }
70 
71  private:
72   FuncGraphPtr GenerateLeafFunc(const size_t &args_size);
73   AnfNodePtr FullMakeLeaf(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const AnfNodePtrList &args);
74   AnfNodePtr FullMakeList(const std::shared_ptr<List> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
75                           const ArgsPairList &arg_pairs);
76   AnfNodePtr FullMakeTuple(const std::shared_ptr<Tuple> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
77                            const ArgsPairList &arg_pairs);
78   AnfNodePtr Make(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs);
79   std::pair<std::string, std::string> GetMapInputIndex(size_t num) const;
Init()80   void Init() {
81     if (fn_leaf_ != nullptr) {
82       name_ = "map[" + fn_leaf_->name() + "]";
83     }
84     signatures_ =
85       // def map(func:read, *args:ref):
86       std::vector<Signature>({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault},
87                               {"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}});
88   }
89 
90   MultitypeFuncGraphPtr fn_leaf_;
91   bool reverse_;
92   bool broadcast_;
93   std::set<TypeId> nonleaf_;
94 };
95 using MapPtr = std::shared_ptr<Map>;
96 class MapPy : public Map {
97  public:
98   explicit MapPy(bool reverse = false, const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr)
Map(reverse,fn_leaf)99       : Map(reverse, fn_leaf) {}
100   ~MapPy() override = default;
101   MS_DECLARE_PARENT(MapPy, Map)
102 };
103 using MapPyPtr = std::shared_ptr<MapPy>;
104 }  // namespace prim
105 }  // namespace mindspore
106 
107 #endif  // MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_MAP_H_
108