• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef MINDSPORE_CCSRC_FRONTEND_OPERATE_OPS_FRONT_INFER_FUNCTION_H_
17 #define MINDSPORE_CCSRC_FRONTEND_OPERATE_OPS_FRONT_INFER_FUNCTION_H_
18 #include <string>
19 #include <vector>
20 #include <optional>
21 #include "abstract/abstract_value.h"
22 #include "abstract/ops/primitive_infer_map.h"
23 namespace mindspore {
24 namespace abstract {
// Closure descriptions for the sparse-tensor constructor primitives (MakeCSRTensor, MakeCOOTensor,
// MakeRowTensor). Presumably matched against the string form of an abstract closure at use sites to
// recognize these primitives — confirm against callers.
// `inline` (C++17) yields a single program-wide object: a plain namespace-scope `const` variable has
// internal linkage, so every translation unit including this header would otherwise build (and later
// destroy) its own private copy of this vector during static initialization.
inline const std::vector<std::string> kSparsePrimStr = {"PrimitiveAbstractClosure: S_Prim_MakeCSRTensor",
                                                        "PrimitiveAbstractClosure: S_Prim_MakeCOOTensor",
                                                        "PrimitiveAbstractClosure: S_Prim_MakeRowTensor"};
28 // String
29 AbstractBasePtr InferImplStringMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
30                                    const AbstractBasePtrList &args_abs_list);
31 AbstractBasePtr InferImplStringGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
32                                        const AbstractBasePtrList &args_abs_list);
33 // Tuple
34 AbstractBasePtr InferImplTupleReversed(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
35                                        const AbstractBasePtrList &args_abs_list);
36 AbstractBasePtr InferImplTupleDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
37                                   const AbstractBasePtrList &args_abs_list);
38 AbstractBasePtr InferImplTuple2Array(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
39                                      const AbstractBasePtrList &args_abs_list);
40 // List
41 AbstractBasePtr InferImplListReduce(const AnalysisEnginePtr &engine, const PrimitivePtr &primitive,
42                                     const AbstractBasePtrList &args_abs_list);
43 // Dict
44 AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
45                                  const AbstractBasePtrList &args_abs_list);
46 // Slice
47 AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
48                                    const AbstractBasePtrList &args_abs_list);
49 AbstractBasePtr InferImplSliceGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
50                                       const AbstractBasePtrList &args_abs_list);
51 // Type checking
52 AbstractBasePtr InferImplTypeof(const AnalysisEnginePtr &, const PrimitivePtr &,
53                                 const AbstractBasePtrList &args_abs_list);
54 AbstractBasePtr InferImplHasType(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
55                                  const AbstractBasePtrList &args_abs_list);
56 AbstractBasePtr InferImplIsInstance(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
57                                     const AbstractBasePtrList &args_abs_list);
58 // Shape processing
59 AbstractBasePtr InferImplReduceShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
60                                      const AbstractBasePtrList &args_abs_list);
61 // Auto-Grad
62 AbstractBasePtr InferImplStopGradient(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
63                                       const AbstractBasePtrList &args_abs_list);
64 AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
65                                    const AbstractBasePtrList &args_abs_list);
66 AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
67                            const AbstractBasePtrList &args_abs_list);
68 AbstractBasePtr InferImplBroadcastGradientArgs(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
69                                                const AbstractBasePtrList &args_abs_list);
70 // Other
71 AbstractBasePtr InferImplReusing(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
72                                  const AbstractBasePtrList &args_abs_list);
73 AbstractBasePtr InferImplTaylor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
74                                 const AbstractBasePtrList &args_abs_list);
75 AbstractBasePtr InferImplShard(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
76                                const AbstractBasePtrList &args_abs_list);
77 AbstractBasePtr InferImplVmap(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
78                               const AbstractBasePtrList &args_abs_list);
79 AbstractBasePtr InferImplConvertToAdapterTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
80                                                 const AbstractBasePtrList &args_abs_list);
81 AbstractBasePtr InferImplConvertToMsTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
82                                            const AbstractBasePtrList &args_abs_list);
83 AbstractBasePtr InferImplDtypeToEnum(const AnalysisEnginePtr &, const PrimitivePtr &primitive,
84                                      const AbstractBasePtrList &args_abs_list);
85 
86 // Delete this when the infer value can be mapped to the CPU backend operator.
87 bool PrimNeedFrontendInferValue(const PrimitivePtr &primitive);
88 
89 const PrimitiveEvalImplMap &GetFrontendPrimitiveInferMap();
90 PrimitiveEvalImplMap *GetFrontendPrimitiveInferMapPtr();
91 // get prim infer from core/ops infer map or frontend infer map
92 std::optional<StandardPrimitiveImplReg> GetFrontendPrimitiveInferImpl(const PrimitivePtr &primitive);
// Defines a helper object `helper_<name>` initialized by RegisterStandardPrimitiveEvalHelper with the
// frontend infer map (GetFrontendPrimitiveInferMapPtr()), `primitive`, `infer_impl`, `infer_value_impl`
// and a trailing `false`. Presumably this inserts the infer functions for `primitive` into that map during
// static initialization — confirm against RegisterStandardPrimitiveEvalHelper.
// - `static` gives the helper internal linkage, so reusing the same `name` in two translation units
//   (or expanding the macro in a header) no longer produces a duplicate-symbol link error.
// - Names are fully qualified so the macro also expands correctly outside `namespace mindspore`.
// The expansion already ends with ';'; a trailing ';' at the call site is a harmless empty declaration.
#define REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(name, primitive, infer_impl, infer_value_impl)                    \
  static auto helper_##name = ::mindspore::abstract::RegisterStandardPrimitiveEvalHelper(                   \
    ::mindspore::abstract::GetFrontendPrimitiveInferMapPtr(), primitive, infer_impl, infer_value_impl, false);
96 }  // namespace abstract
97 }  // namespace mindspore
98 #endif  // MINDSPORE_CCSRC_FRONTEND_OPERATE_OPS_FRONT_INFER_FUNCTION_H_
99