• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_PIPELINE_PYNATIVE_PYNATIVE_ABS_CACHE_H
17 #define MINDSPORE_CCSRC_PIPELINE_PYNATIVE_PYNATIVE_ABS_CACHE_H
18 #include <string>
19 #include <utility>
20 #include <vector>
21 #include <memory>
22 #include <unordered_map>
23 #include "pybind11/pytypes.h"
24 #include "utils/hash_map.h"
25 #include "utils/ms_utils.h"
26 #include "ir/anf.h"
27 #include "ir/signature.h"
28 
29 namespace mindspore {
30 namespace pynative {
31 // The following structures used to get output abstract of op from cache
32 struct AbsCacheKey {
33   std::string prim_name_;
34   size_t prim_hash_value_;
35   mindspore::HashMap<std::string, ValuePtr> prim_attrs_;
36 };
37 
38 struct AbsCacheKeyHasher {
operatorAbsCacheKeyHasher39   size_t operator()(const AbsCacheKey &key) const { return key.prim_hash_value_; }
40 };
41 
42 struct AbsCacheKeyEqual {
operatorAbsCacheKeyEqual43   bool operator()(const AbsCacheKey &lk, const AbsCacheKey &rk) const {
44     if (lk.prim_name_ != rk.prim_name_) {
45       return false;
46     }
47     return common::IsAttrsEqual(lk.prim_attrs_, rk.prim_attrs_);
48   }
49 };
50 
51 struct PrimAbsInfo {
52   abstract::AbstractBasePtr abs;
53   bool is_dynamic_shape = false;
54   mindspore::HashMap<std::string, ValuePtr> attrs;
55 };
56 using AbstractListMap = std::unordered_map<abstract::AbstractBasePtrList, PrimAbsInfo,
57                                            abstract::AbstractBasePtrListHasher, abstract::AbstractBasePtrListEqual>;
58 using PrimAbsCache = std::unordered_map<AbsCacheKey, AbstractListMap, AbsCacheKeyHasher, AbsCacheKeyEqual>;
59 
60 // Used to get input abstract of op from cache
61 // Key is id of input obj, value is the abstract of input obj
62 using NodeAbsCache = mindspore::HashMap<std::string, abstract::AbstractBasePtr>;
63 
64 // Used to cache implicit cast info according to primitive
65 // Key is primitive name, value is the implicit cast info
66 struct PrimSignature {
67   bool has_dtype_sig;
68   std::vector<SignatureEnumDType> dtypes;
69 };
70 using ImplicitCastCache = mindspore::HashMap<std::string, PrimSignature>;
71 }  // namespace pynative
72 }  // namespace mindspore
73 #endif  // MINDSPORE_CCSRC_PIPELINE_PYNATIVE_PYNATIVE_ABS_CACHE_H
74