1 /** 2 * Copyright 2019-2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CCSRC_FRONTEND_OPERATE_OPS_FRONT_INFER_FUNCTION_H_ 17 #define MINDSPORE_CCSRC_FRONTEND_OPERATE_OPS_FRONT_INFER_FUNCTION_H_ 18 #include <string> 19 #include <vector> 20 #include <optional> 21 #include "abstract/abstract_value.h" 22 #include "abstract/ops/primitive_infer_map.h" 23 namespace mindspore { 24 namespace abstract { 25 const std::vector<std::string> kSparsePrimStr = {"PrimitiveAbstractClosure: S_Prim_MakeCSRTensor", 26 "PrimitiveAbstractClosure: S_Prim_MakeCOOTensor", 27 "PrimitiveAbstractClosure: S_Prim_MakeRowTensor"}; 28 // String 29 AbstractBasePtr InferImplStringMul(const AnalysisEnginePtr &, const PrimitivePtr &primitive, 30 const AbstractBasePtrList &args_abs_list); 31 AbstractBasePtr InferImplStringGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, 32 const AbstractBasePtrList &args_abs_list); 33 // Tuple 34 AbstractBasePtr InferImplTupleReversed(const AnalysisEnginePtr &, const PrimitivePtr &primitive, 35 const AbstractBasePtrList &args_abs_list); 36 AbstractBasePtr InferImplTupleDiv(const AnalysisEnginePtr &, const PrimitivePtr &primitive, 37 const AbstractBasePtrList &args_abs_list); 38 AbstractBasePtr InferImplTuple2Array(const AnalysisEnginePtr &, const PrimitivePtr &primitive, 39 const AbstractBasePtrList &args_abs_list); 40 // List 41 AbstractBasePtr InferImplListReduce(const AnalysisEnginePtr &engine, const PrimitivePtr &primitive, 42 const AbstractBasePtrList &args_abs_list); 43 // Dict 44 AbstractBasePtr InferImplDictLen(const AnalysisEnginePtr &, const PrimitivePtr &primitive, 45 const AbstractBasePtrList &args_abs_list); 46 // Slice 47 AbstractBasePtr InferImplMakeSlice(const AnalysisEnginePtr &, const PrimitivePtr &primitive, 48 const AbstractBasePtrList &args_abs_list); 49 AbstractBasePtr InferImplSliceGetItem(const AnalysisEnginePtr &, const PrimitivePtr &primitive, 50 const AbstractBasePtrList &args_abs_list); 51 // Type checking 52 AbstractBasePtr InferImplTypeof(const AnalysisEnginePtr &, const PrimitivePtr &, 53 const AbstractBasePtrList &args_abs_list); 54 AbstractBasePtr InferImplHasType(const AnalysisEnginePtr &, const PrimitivePtr &primitive, 55 const AbstractBasePtrList &args_abs_list); 56 AbstractBasePtr InferImplIsInstance(const AnalysisEnginePtr &, const PrimitivePtr &primitive, 57 const AbstractBasePtrList &args_abs_list); 58 // Shape processing 59 AbstractBasePtr InferImplReduceShape(const AnalysisEnginePtr &, const PrimitivePtr &primitive, 60 const AbstractBasePtrList &args_abs_list); 61 // Auto-Grad 62 AbstractBasePtr InferImplStopGradient(const AnalysisEnginePtr &, const PrimitivePtr &primitive, 63 const AbstractBasePtrList &args_abs_list); 64 AbstractBasePtr InferImplFakeBprop(const AnalysisEnginePtr &, const PrimitivePtr &primitive, 65 const AbstractBasePtrList &args_abs_list); 66 AbstractBasePtr InferImplJ(const AnalysisEnginePtr &, const PrimitivePtr &primitive, 67 const AbstractBasePtrList &args_abs_list); 68 AbstractBasePtr InferImplBroadcastGradientArgs(const AnalysisEnginePtr &, const PrimitivePtr &primitive, 69 const AbstractBasePtrList &args_abs_list); 70 // Other 71 AbstractBasePtr InferImplReusing(const AnalysisEnginePtr &, const PrimitivePtr &primitive, 72 const AbstractBasePtrList &args_abs_list); 73 AbstractBasePtr InferImplTaylor(const AnalysisEnginePtr &, const PrimitivePtr &primitive, 74 const AbstractBasePtrList &args_abs_list); 75 AbstractBasePtr InferImplShard(const AnalysisEnginePtr &, const PrimitivePtr &primitive, 76 const AbstractBasePtrList &args_abs_list); 77 AbstractBasePtr InferImplVmap(const AnalysisEnginePtr &, const PrimitivePtr &primitive, 78 const AbstractBasePtrList &args_abs_list); 79 AbstractBasePtr InferImplConvertToAdapterTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive, 80 const AbstractBasePtrList &args_abs_list); 81 AbstractBasePtr InferImplConvertToMsTensor(const AnalysisEnginePtr &, const PrimitivePtr &primitive, 82 const AbstractBasePtrList &args_abs_list); 83 AbstractBasePtr InferImplDtypeToEnum(const AnalysisEnginePtr &, const PrimitivePtr &primitive, 84 const AbstractBasePtrList &args_abs_list); 85 86 // Delete this when the infer value can be mapped to the CPU backend operator. 87 bool PrimNeedFrontendInferValue(const PrimitivePtr &primitive); 88 89 const PrimitiveEvalImplMap &GetFrontendPrimitiveInferMap(); 90 PrimitiveEvalImplMap *GetFrontendPrimitiveInferMapPtr(); 91 // get prim infer from core/ops infer map or frontend infer map 92 std::optional<StandardPrimitiveImplReg> GetFrontendPrimitiveInferImpl(const PrimitivePtr &primitive); 93 #define REGISTER_PRIMITIVE_FRONT_EVAL_IMPL(name, primitive, infer_impl, infer_value_impl) \ 94 auto helper_##name = abstract::RegisterStandardPrimitiveEvalHelper(abstract::GetFrontendPrimitiveInferMapPtr(), \ 95 primitive, infer_impl, infer_value_impl, false); 96 } // namespace abstract 97 } // namespace mindspore 98 #endif // MINDSPORE_CCSRC_FRONTEND_OPERATE_OPS_FRONT_INFER_FUNCTION_H_ 99