1 /** 2 * Copyright 2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 #ifndef MINDSPORE_CORE_ABSTRACT_OPS_OP_INFER_H 17 #define MINDSPORE_CORE_ABSTRACT_OPS_OP_INFER_H 18 19 #include <vector> 20 #include <set> 21 #include <memory> 22 #include "ir/primitive.h" 23 #include "abstract/abstract_value.h" 24 #include "ir/anf.h" 25 26 namespace mindspore { 27 namespace abstract { 28 class OpInferBase { 29 public: 30 OpInferBase() = default; 31 virtual ~OpInferBase() = default; 32 33 /// \brief Infer the output shape for target operator. 34 /// 35 /// \param[in] primitive Operator's primitive. 36 /// \param[in] input_args Operator's inputs. 37 /// 38 /// \return The inferred shape. 39 virtual BaseShapePtr InferShape(const PrimitivePtr &primitive, 40 const std::vector<AbstractBasePtr> &input_args) const = 0; 41 42 /// \brief Infer the output type for target operator. 43 /// 44 /// \param[in] primitive Operator's primitive. 45 /// \param[in] input_args Operator's inputs. 46 /// 47 /// \return The inferred type. 48 virtual TypePtr InferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const = 0; 49 50 /// \brief Infer the output value for target operator. Only override when needed. 51 /// 52 /// \param[in] primitive Operator's primitive. 53 /// \param[in] input_args Operator's inputs. 54 /// 55 /// \return Inferred Value based on given inputs. InferValue(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)56 virtual ValuePtr InferValue(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) const { 57 return kValueAny; 58 } 59 60 /// \brief Get the indices of infer-depend value. 61 /// 62 /// \return Set with indices of infer-depend value. GetValueDependArgIndices()63 virtual std::set<int64_t> GetValueDependArgIndices() const { return {}; } 64 65 /// \brief Infer the related shape and type for target operator. 66 /// 67 /// \param[in] engine 68 /// \param[in] primitive Operator's primitive. 69 /// \param[in] input_args Operator's inputs. 70 /// 71 /// \return AbstractBasePtr with inferred shape and inferred type. InferShapeAndType(const abstract::AnalysisEnginePtr &,const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)72 virtual AbstractBasePtr InferShapeAndType(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive, 73 const std::vector<AbstractBasePtr> &input_args) const { 74 auto type = InferType(primitive, input_args); 75 auto shape = InferShape(primitive, input_args); 76 return MakeAbstract(shape, type); 77 } 78 79 /// \brief Infer the related Abstract for target operator. 80 /// 81 /// \param[in] primitive Operator's primitive. 82 /// \param[in] input_args Operator's inputs. 83 /// 84 /// \return AbstractBasePtr with inferred shape and inferred type. InferAbstract(const PrimitivePtr & primitive,const std::vector<AbstractBasePtr> & input_args)85 virtual AbstractBasePtr InferAbstract(const PrimitivePtr &primitive, 86 const std::vector<AbstractBasePtr> &input_args) const { 87 return nullptr; 88 } 89 }; 90 91 using OpInferBasePtr = std::shared_ptr<OpInferBase>; 92 } // namespace abstract 93 } // namespace mindspore 94 #endif // MINDSPORE_CORE_ABSTRACT_OPS_OP_INFER_H 95