• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_PIPELINE_PYNATIVE_PYNATIVE_ABS_CACHE_H
17 #define MINDSPORE_CCSRC_PIPELINE_PYNATIVE_PYNATIVE_ABS_CACHE_H
#include <algorithm>
#include <functional>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "ir/anf.h"
24 
25 namespace mindspore::pynative {
26 struct AbsCacheKey {
27   std::string prim_name_;
28   size_t prim_hash_value_;
29   std::unordered_map<std::string, ValuePtr> prim_attrs_;
30 };
31 
32 struct AbsCacheKeyHasher {
operatorAbsCacheKeyHasher33   size_t operator()(const AbsCacheKey &key) const { return key.prim_hash_value_; }
34 };
35 
36 struct AbsCacheKeyEqual {
operatorAbsCacheKeyEqual37   bool operator()(const AbsCacheKey &lk, const AbsCacheKey &rk) const {
38     if (lk.prim_attrs_.size() != rk.prim_attrs_.size()) {
39       return false;
40     }
41     if (lk.prim_name_ != rk.prim_name_) {
42       return false;
43     }
44 
45     auto all = std::all_of(lk.prim_attrs_.begin(), lk.prim_attrs_.end(),
46                            [&rk](const std::pair<std::string, ValuePtr> &item) -> bool {
47                              auto iter = rk.prim_attrs_.find(item.first);
48                              if (iter == rk.prim_attrs_.end()) {
49                                return false;
50                              }
51                              if (item.second == iter->second) {
52                                return true;
53                              }
54                              MS_EXCEPTION_IF_NULL(item.second);
55                              MS_EXCEPTION_IF_NULL(iter->second);
56                              return *item.second == *iter->second;
57                            });
58     return all;
59   }
60 };
61 
62 struct PrimAbsInfo {
63   abstract::AbstractBasePtr abs;
64   bool is_dynamic_shape = false;
65   std::unordered_map<std::string, ValuePtr> attrs;
66 };
67 using AbstractListMap = std::unordered_map<abstract::AbstractBasePtrList, PrimAbsInfo,
68                                            abstract::AbstractBasePtrListHasher, abstract::AbstractBasePtrListEqual>;
69 using PrimAbsCache = std::unordered_map<AbsCacheKey, AbstractListMap, AbsCacheKeyHasher, AbsCacheKeyEqual>;
70 
// Hash/equality functors and the cache type that maps a Python object handle to its id string.
72 struct PyObjectHasher {
operatorPyObjectHasher73   size_t operator()(const py::handle &key) const { return py::hash(key); }
74 };
75 
76 struct PyObjectEqual {
operatorPyObjectEqual77   bool operator()(const py::handle &p1, const py::handle &p2) const { return p1 == p2; }
78 };
79 using PyObjectIdCache = std::unordered_map<py::handle, std::string, PyObjectHasher, PyObjectEqual>;
80 
81 struct PrimSignature {
82   bool has_dtype_sig;
83   std::vector<SignatureEnumDType> dtypes;
84   std::unordered_map<SignatureEnumDType, std::vector<size_t>> type_indexes;
85 };
86 using ImplicitCastCache = std::unordered_map<std::string, PrimSignature>;
87 }  // namespace mindspore::pynative
88 #endif  // MINDSPORE_CCSRC_PIPELINE_PYNATIVE_PYNATIVE_ABS_CACHE_H
89