• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_MAP_H_
18 #define MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_MAP_H_
19 
20 #include <memory>
21 #include <set>
22 #include <utility>
23 #include <vector>
24 
25 #include "ir/dtype.h"
26 #include "ir/meta_func_graph.h"
27 #include "frontend/operator/composite/multitype_funcgraph.h"
28 
29 namespace mindspore {
30 // namespace to support composite operators definition
31 namespace prim {
32 using ArgsPairList = std::vector<std::pair<AnfNodePtr, TypePtr>>;
33 
34 class Map : public MetaFuncGraph {
35  public:
36   explicit Map(bool reverse = false, const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr)
37       : MetaFuncGraph("map"),
38         fn_leaf_(fn_leaf),
39         reverse_(reverse),
40         broadcast_(false),
41         nonleaf_({kObjectTypeList, kObjectTypeTuple, kObjectTypeClass}) {
42     Init();
43   }
Map(const Map & map)44   Map(const Map &map)
45       : MetaFuncGraph("map"),
46         fn_leaf_(map.fn_leaf_),
47         reverse_(map.reverse_),
48         broadcast_(map.broadcast_),
49         nonleaf_(map.nonleaf_) {
50     Init();
51   }
52   Map &operator=(const Map &map) {
53     if (this != &map) {
54       fn_leaf_ = map.fn_leaf_;
55       reverse_ = map.reverse_;
56       broadcast_ = map.broadcast_;
57       nonleaf_ = map.nonleaf_;
58       if (fn_leaf_) {
59         name_ = "map[" + fn_leaf_->name() + "]";
60       }
61     }
62     return *this;
63   }
64   ~Map() override = default;
65   MS_DECLARE_PARENT(Map, MetaFuncGraph)
66   abstract::AbstractBasePtrList NormalizeArgs(const abstract::AbstractBasePtrList &args_spec_list) const override;
67   FuncGraphPtr GenerateFromTypes(const TypePtrList &args_spec_list) override;
GetFnLeaf()68   MetaFuncGraphPtr GetFnLeaf() { return fn_leaf_; }
69 
70  private:
71   FuncGraphPtr GenerateLeafFunc(const size_t &args_size);
72   AnfNodePtr FullMakeLeaf(const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg, const AnfNodePtrList &args);
73   AnfNodePtr FullMakeList(const std::shared_ptr<List> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
74                           const ArgsPairList &arg_pairs);
75   AnfNodePtr FullMakeTuple(const std::shared_ptr<Tuple> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
76                            const ArgsPairList &arg_pairs);
77   AnfNodePtr FullMakeClass(const std::shared_ptr<Class> &type, const FuncGraphPtr &func_graph, const AnfNodePtr &fn_arg,
78                            const ArgsPairList &arg_pairs);
79   AnfNodePtr Make(const FuncGraphPtr &graph, const AnfNodePtr &fn_arg, const ArgsPairList &arg_pairs);
Init()80   void Init() {
81     if (fn_leaf_ != nullptr) {
82       name_ = "map[" + fn_leaf_->name() + "]";
83     }
84     signatures_ =
85       // def map(func:read, *args:ref):
86       std::vector<Signature>({{"func", SignatureEnumRW::kRWRead, SignatureEnumKind::kKindDefault},
87                               {"args", SignatureEnumRW::kRWRef, SignatureEnumKind::kKindVarPositional}});
88   }
89 
90   MultitypeFuncGraphPtr fn_leaf_;
91   bool reverse_;
92   bool broadcast_;
93   std::set<TypeId> nonleaf_;
94 };
95 using MapPtr = std::shared_ptr<Map>;
96 class MapPy : public Map {
97  public:
98   explicit MapPy(bool reverse = false, const std::shared_ptr<MultitypeFuncGraph> &fn_leaf = nullptr)
Map(reverse,fn_leaf)99       : Map(reverse, fn_leaf) {}
100   ~MapPy() override = default;
101   MS_DECLARE_PARENT(MapPy, Map)
102 };
103 using MapPyPtr = std::shared_ptr<MapPy>;
104 }  // namespace prim
105 }  // namespace mindspore
106 
107 #endif  // MINDSPORE_CCSRC_FRONTEND_OPERATOR_COMPOSITE_MAP_H_
108