1 /** 2 * This is the C++ adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/). 3 * 4 * Copyright 2019-2022 Huawei Technologies Co., Ltd 5 * 6 * Licensed under the Apache License, Version 2.0 (the "License"); 7 * you may not use this file except in compliance with the License. 8 * You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, software 13 * distributed under the License is distributed on an "AS IS" BASIS, 14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 * See the License for the specific language governing permissions and 16 * limitations under the License. 17 */ 18 19 #ifndef MINDSPORE_CORE_ABSTRACT_ABSTRACT_FUNCTION_H_ 20 #define MINDSPORE_CORE_ABSTRACT_ABSTRACT_FUNCTION_H_ 21 22 #include <cstdint> 23 #include <memory> 24 #include <string> 25 26 #include "abstract/abstract_value.h" 27 #include "abstract/analysis_context.h" 28 #include "ir/meta_func_graph.h" 29 30 namespace mindspore { 31 namespace abstract { 32 /// \brief AbstractFuncAtom defines interface for abstract of atom function. 33 class MS_CORE_API AbstractFuncAtom : public AbstractFunction { 34 public: 35 /// \brief Constructor of AbstractFuncAtom. 36 AbstractFuncAtom() = default; 37 38 /// \brief Destructor of AbstractFuncAtom. 39 ~AbstractFuncAtom() override = default; MS_DECLARE_PARENT(AbstractFuncAtom,AbstractFunction)40 MS_DECLARE_PARENT(AbstractFuncAtom, AbstractFunction) 41 42 AbstractFunctionPtr GetUnique() override { return shared_from_base<AbstractFuncAtom>(); } 43 44 AbstractFunctionPtr Join(const AbstractFunctionPtr &other) final; 45 46 void Visit(std::function<void(const AbstractFuncAtomPtr &)> visit_func) const final; 47 48 bool operator==(const AbstractFunction &other) const override; 49 hash()50 std::size_t hash() const override { return tid(); } 51 }; 52 53 /// \brief AbstractFuncUnion defines interface for abstract of union function. 54 class MS_CORE_API AbstractFuncUnion final : public AbstractFunction { 55 public: 56 /// \brief Constructor AbstractFuncUnion from AbstractFuncAtom list. 57 /// 58 /// \param[in] func_list The AbstractFuncAtom list for AbstractFuncUnion. 59 explicit AbstractFuncUnion(const AbstractFuncAtomPtrList &func_list); 60 61 /// \brief Constructor AbstractFuncUnion from two AbstractFunction. 62 /// 63 /// \param[in] first The first AbstractFunction for AbstractFuncUnion. 64 /// \param[in] second The second AbstractFunction for AbstractFuncUnion. 65 AbstractFuncUnion(const AbstractFunctionPtr &first, const AbstractFunctionPtr &second); 66 67 /// \brief Destructor for AbstractFunction. 68 ~AbstractFuncUnion() override = default; 69 MS_DECLARE_PARENT(AbstractFuncUnion, AbstractFunction) 70 71 std::string ToString() const override; 72 73 std::string ToString(bool verbose) const override; 74 GetUnique()75 AbstractFunctionPtr GetUnique() override { MS_LOG(INTERNAL_EXCEPTION) << "Cannot get unique from AbstractFuncUnion"; } 76 77 /// \brief Check whether the input AbstractFunction is in AbstractFuncUnion. 78 /// 79 /// \param[in] other The input AbstractFunction for check. 80 /// 81 /// \return Return true if other is in AbstractFuncUnion, otherwise return False. 82 bool IsSuperSet(const AbstractFunctionPtr &other); 83 84 AbstractFunctionPtr Join(const AbstractFunctionPtr &other) final; 85 86 void Visit(std::function<void(const AbstractFuncAtomPtr &)> visit_func) const final; 87 88 bool operator==(const AbstractFunction &other) const override; 89 90 std::size_t hash() const override; 91 Copy()92 AbstractFunctionPtr Copy() const override { MS_LOG(INTERNAL_EXCEPTION) << "Cannot Copy from AbstractFuncUnion"; } 93 94 private: 95 AbstractFuncAtomPtrList func_list_; 96 }; 97 98 /// \brief PrimitiveAbstractClosure defines interface for abstract of Primitive. 99 class MS_CORE_API PrimitiveAbstractClosure final : public AbstractFuncAtom { 100 public: 101 /// \brief Constructor of PrimitiveAbstractClosure 102 /// 103 /// \param[in] prim The primitive that this PrimitiveAbstractClosure corresponding to. 104 /// \param[in] tracking_node A Node identifies different uses of the prim. 105 explicit PrimitiveAbstractClosure(const PrimitivePtr &prim, const AnfNodePtr &tracking_node = nullptr) PrimitiveAbstractClosure(prim,ToTrackingId (tracking_node))106 : PrimitiveAbstractClosure(prim, ToTrackingId(tracking_node)) {} 107 108 // For internal usage only, make it public so that make_shared can work on it. PrimitiveAbstractClosure(const PrimitivePtr & prim,std::uintptr_t tracking_id)109 PrimitiveAbstractClosure(const PrimitivePtr &prim, std::uintptr_t tracking_id) 110 : prim_(prim), tracking_id_(tracking_id) {} 111 112 /// \brief Destructor of PrimitiveAbstractClosure 113 ~PrimitiveAbstractClosure() override = default; MS_DECLARE_PARENT(PrimitiveAbstractClosure,AbstractFuncAtom)114 MS_DECLARE_PARENT(PrimitiveAbstractClosure, AbstractFuncAtom) 115 116 /// \brief Get the Primitive that this PrimitiveAbstractClosure corresponding to. 117 /// 118 /// \return The Primitive that this PrimitiveAbstractClosure corresponding to. 119 const PrimitivePtr &prim() const { return prim_; } 120 tracking_id()121 std::uintptr_t tracking_id() const override { return tracking_id_; } 122 Copy()123 AbstractFunctionPtr Copy() const override { return std::make_shared<PrimitiveAbstractClosure>(prim_, tracking_id_); } 124 CopyWithoutTrackingId()125 AbstractFunctionPtr CopyWithoutTrackingId() const override { 126 return std::make_shared<PrimitiveAbstractClosure>(prim_, 0); 127 } 128 129 bool operator==(const AbstractFunction &other) const override; 130 131 std::size_t hash() const override; 132 ToString()133 std::string ToString() const override { return "PrimitiveAbstractClosure: " + prim_->name(); } 134 135 std::string ToString(bool verbose) const override; 136 RealBuildValue()137 ValuePtr RealBuildValue() const override { return prim_; } 138 139 private: 140 PrimitivePtr prim_; 141 // To discriminate different usage of same primitive calls, 142 // store it as the memory address of the user node. 143 std::uintptr_t tracking_id_; 144 }; 145 using PrimitiveAbstractClosurePtr = std::shared_ptr<PrimitiveAbstractClosure>; 146 147 /// \brief FuncGraphAbstractClosure defines interface for abstract of FuncGraph. 148 class MS_CORE_API FuncGraphAbstractClosure final : public AbstractFuncAtom { 149 public: 150 /// \brief Constructor of FuncGraphAbstractClosure. 151 /// 152 /// \param[in] func_graph The function graph that this PrimitiveAbstractClosure corresponding to. 153 /// \param[in] context The context that func_graph corresponding to. 154 /// \param[in] tracking_node A Node identifies different uses of the func_graph. 155 FuncGraphAbstractClosure(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context, 156 const AnfNodePtr &tracking_node = nullptr, bool specialized = false) FuncGraphAbstractClosure(func_graph,context,ToTrackingId (tracking_node),specialized)157 : FuncGraphAbstractClosure(func_graph, context, ToTrackingId(tracking_node), specialized) {} 158 159 // For internal usage only, make it public so that make_shared can work on it. FuncGraphAbstractClosure(const FuncGraphPtr & func_graph,const AnalysisContextPtr & context,std::uintptr_t tracking_id,bool specialized)160 FuncGraphAbstractClosure(const FuncGraphPtr &func_graph, const AnalysisContextPtr &context, 161 std::uintptr_t tracking_id, bool specialized) 162 : func_graph_(FuncGraphWeakPtr(func_graph)), 163 context_(context), 164 tracking_id_(tracking_id), 165 specialized_(specialized) { 166 MS_EXCEPTION_IF_NULL(func_graph); 167 MS_EXCEPTION_IF_NULL(context); 168 } 169 170 /// \brief Destructor of FuncGraphAbstractClosure. 171 ~FuncGraphAbstractClosure() override = default; MS_DECLARE_PARENT(FuncGraphAbstractClosure,AbstractFuncAtom)172 MS_DECLARE_PARENT(FuncGraphAbstractClosure, AbstractFuncAtom) 173 174 /// \brief Get the FuncGraph that this FuncGraphAbstractClosure corresponding to. 175 /// 176 /// \return The FuncGraph that this FuncGraphAbstractClosure corresponding to. 177 FuncGraphPtr func_graph() const { return func_graph_.lock(); } 178 context()179 AnalysisContextPtr context() const override { return context_; } 180 tracking_id()181 std::uintptr_t tracking_id() const override { return tracking_id_; } 182 specialized()183 bool specialized() const { return specialized_; } 184 Copy()185 AbstractFunctionPtr Copy() const override { 186 return std::make_shared<FuncGraphAbstractClosure>(func_graph(), context_, tracking_id_, specialized_); 187 } 188 CopyWithoutTrackingId()189 AbstractFunctionPtr CopyWithoutTrackingId() const override { 190 return std::make_shared<FuncGraphAbstractClosure>(func_graph(), context_, 0, specialized_); 191 } 192 193 bool operator==(const AbstractFunction &other) const override; 194 195 std::size_t hash() const override; 196 197 std::string ToString() const override; 198 199 std::string ToString(bool verbose) const override; 200 201 bool IsEqualExceptTrackingId(const FuncGraphAbstractClosure &other) const; 202 203 std::size_t HashWithoutTrackingId() const; 204 205 private: 206 FuncGraphWeakPtr func_graph_; 207 AnalysisContextPtr context_; 208 // To discriminate different usage of same graph by using this tracking_id, 209 // so different tracking_id will produce different FuncGraphAbstractClosure, 210 // different FuncGraphEvaluator. 211 // Especially useful for recursive func graph call, so it will not mess up 212 // the `context_` in FuncGraphEvaluator. 213 // store it as the memory address of the user node. 214 std::uintptr_t tracking_id_; 215 // If the func_graph_ member is the specialized func_graph_ in current IR or 216 // it's a old func_graph of IR before renormalized. 217 bool specialized_{false}; 218 }; 219 using FuncGraphAbstractClosurePtr = std::shared_ptr<FuncGraphAbstractClosure>; 220 221 /// \brief MetaFuncGraphAbstractClosure defines interface for abstract of MetaFuncGraph. 222 class MS_CORE_API MetaFuncGraphAbstractClosure final : public AbstractFuncAtom { 223 public: 224 /// \brief Constructor of FuncGraphAbstractClosure. 225 /// 226 /// \param[in] meta_func_graph The function graph that this MetaFuncGraphAbstractClosure corresponding to. 227 /// \param[in] tracking_node A Node identifies different uses of the meta_func_graph. 228 /// \param[in] scope The scope to which the tracking_id belong to. 229 explicit MetaFuncGraphAbstractClosure(const MetaFuncGraphPtr &meta_func_graph, 230 const AnfNodePtr &tracking_node = nullptr, 231 const ScopePtr &scope = kDefaultScope) MetaFuncGraphAbstractClosure(meta_func_graph,ToTrackingId (tracking_node),scope)232 : MetaFuncGraphAbstractClosure(meta_func_graph, ToTrackingId(tracking_node), scope) {} 233 234 // For internal usage only, make it public so that make_shared can work on it. MetaFuncGraphAbstractClosure(const MetaFuncGraphPtr & meta_func_graph,std::uintptr_t tracking_id,const ScopePtr & scope)235 MetaFuncGraphAbstractClosure(const MetaFuncGraphPtr &meta_func_graph, std::uintptr_t tracking_id, 236 const ScopePtr &scope) 237 : meta_func_graph_(meta_func_graph), tracking_id_(tracking_id), scope_(scope) {} 238 239 /// \brief Destructor of MetaFuncGraphAbstractClosure. 240 ~MetaFuncGraphAbstractClosure() override = default; MS_DECLARE_PARENT(MetaFuncGraphAbstractClosure,AbstractFuncAtom)241 MS_DECLARE_PARENT(MetaFuncGraphAbstractClosure, AbstractFuncAtom) 242 243 /// \brief Get the MetaFuncGraph that this MetaFuncGraphAbstractClosure corresponding to. 244 /// 245 /// \return The MetaFuncGraph that this MetaFuncGraphAbstractClosure corresponding to. 246 const MetaFuncGraphPtr &meta_func_graph() const { return meta_func_graph_; } 247 context()248 AnalysisContextPtr context() const override { return AnalysisContext::DummyContext(); } 249 250 /// \brief Get the Scope that this MetaFuncGraphAbstractClosure corresponding to. 251 /// 252 /// \return The Scope that this MetaFuncGraphAbstractClosure corresponding to. GetScope()253 const ScopePtr &GetScope() const { return scope_; } 254 tracking_id()255 std::uintptr_t tracking_id() const override { return tracking_id_; } 256 Copy()257 AbstractFunctionPtr Copy() const override { 258 return std::make_shared<MetaFuncGraphAbstractClosure>(meta_func_graph_, tracking_id_, kDefaultScope); 259 } 260 CopyWithoutTrackingId()261 AbstractFunctionPtr CopyWithoutTrackingId() const override { 262 return std::make_shared<MetaFuncGraphAbstractClosure>(meta_func_graph_, 0, kDefaultScope); 263 } 264 265 bool operator==(const AbstractFunction &other) const override; 266 267 std::size_t hash() const override; 268 269 std::string ToString() const override; 270 271 private: 272 MetaFuncGraphPtr meta_func_graph_; 273 // Refer the comment in FuncGraphAbstractClosure; 274 // Store it as memory address of the user node. 275 std::uintptr_t tracking_id_; 276 ScopePtr scope_; 277 }; 278 using MetaFuncGraphAbstractClosurePtr = std::shared_ptr<MetaFuncGraphAbstractClosure>; 279 280 /// \brief PartialAbstractClosure defines the abstract AbstractFuncAtom interface provided by some args in advance. 281 class MS_CORE_API PartialAbstractClosure final : public AbstractFuncAtom { 282 public: 283 /// \brief Constructor of PartialAbstractClosure. 284 /// 285 /// \param[in] fn The AbstractFuncAtom this PartialAbstractClosure corresponding to. 286 /// \param[in] args_abs_list The first few parameters provided for fn in advance. 287 /// \param[in] node The CNode which this PartialAbstractClosure evaluated from. 288 PartialAbstractClosure(const AbstractFuncAtomPtr &fn, const AbstractBasePtrList &args_abs_list, 289 const AnfNodePtr &node = nullptr) fn_(fn)290 : fn_(fn), args_abs_list_(args_abs_list), node_(AnfNodePtr(node)) {} 291 292 /// \brief Destructor of PartialAbstractClosure. 293 ~PartialAbstractClosure() override = default; MS_DECLARE_PARENT(PartialAbstractClosure,AbstractFuncAtom)294 MS_DECLARE_PARENT(PartialAbstractClosure, AbstractFuncAtom) 295 296 /// \brief Get the AbstractFuncAtom that this PartialAbstractClosure corresponding to. 297 /// 298 /// \return The AbstractFuncAtom that this PartialAbstractClosure corresponding to. 299 AbstractFuncAtomPtr fn() { return fn_; } 300 301 /// \brief Set the AbstractFuncAtom that this PartialAbstractClosure corresponding to. 302 /// 303 /// \param[in] fn The AbstractFuncAtom that this PartialAbstractClosure corresponding to. set_fn(const AbstractFuncAtomPtr & fn)304 void set_fn(const AbstractFuncAtomPtr &fn) { fn_ = fn; } 305 306 /// \brief Get the pre-provided arguments. 307 /// 308 /// \return The pre-provided arguments. args()309 const AbstractBasePtrList &args() const { return args_abs_list_; } 310 311 /// \brief Get the CNode this PartialAbstractClosure evaluated from. 312 /// 313 /// \return The CNode this PartialAbstractClosure evaluated from. node()314 AnfNodePtr node() const { return node_.lock(); } 315 316 /// \brief Set the CNode this PartialAbstractClosure evaluated from. 317 /// 318 /// \param[in] node The CNode this PartialAbstractClosure evaluated from. set_node(const AnfNodePtr & node)319 void set_node(const AnfNodePtr &node) { node_ = AnfNodeWeakPtr(node); } 320 321 /// \brief Get whether the args need to be appended to the end. 322 /// 323 /// \return Whether the args need to be appended to the end. need_append_to_end()324 bool need_append_to_end() const { return need_append_to_end_; } 325 326 /// \brief Set whether the args need to be appended to the end. 327 /// 328 /// \param[in] flag Whether the args need to be appended to the end. set_need_append_to_end(bool flag)329 void set_need_append_to_end(bool flag) { need_append_to_end_ = flag; } 330 Copy()331 AbstractFunctionPtr Copy() const override { 332 auto abs = std::make_shared<PartialAbstractClosure>(fn_, args_abs_list_, node_.lock()); 333 abs->set_need_append_to_end(need_append_to_end_); 334 return abs; 335 } 336 337 bool operator==(const AbstractFunction &other) const override; 338 339 std::size_t hash() const override; 340 341 std::string ToString() const override; 342 343 std::string ToString(bool verbose) const override; 344 345 protected: RealBuildValue()346 ValuePtr RealBuildValue() const override { return fn_->BuildValue(); } 347 348 private: 349 AbstractFuncAtomPtr fn_; 350 AbstractBasePtrList args_abs_list_; 351 // The ANFNode which this PartialAbstractClosure evaluated from. 352 AnfNodeWeakPtr node_; 353 bool need_append_to_end_{false}; 354 }; 355 using PartialAbstractClosurePtr = std::shared_ptr<PartialAbstractClosure>; 356 357 /// \brief JTransformedAbstractClosure defines interface for abstract of Function 358 /// transformed through the application of J. 359 class MS_CORE_API JTransformedAbstractClosure final : public AbstractFuncAtom { 360 public: 361 /// \brief Constructor of JTransformedAbstractClosure 362 /// 363 /// \param[in] fn The AbstractFuncAtom transformed through the application of J. JTransformedAbstractClosure(const AbstractFuncAtomPtr & fn)364 explicit JTransformedAbstractClosure(const AbstractFuncAtomPtr &fn) : fn_(fn) {} 365 366 /// \brief Destructor of JTransformedAbstractClosure 367 ~JTransformedAbstractClosure() override = default; MS_DECLARE_PARENT(JTransformedAbstractClosure,AbstractFuncAtom)368 MS_DECLARE_PARENT(JTransformedAbstractClosure, AbstractFuncAtom) 369 370 /// \brief Get the AbstractFuncAtom JTransformedAbstractClosure corresponding to. 371 /// 372 /// \return The AbstractFuncAtom JTransformedAbstractClosure corresponding to. 373 const AbstractFuncAtomPtr &fn() const { return fn_; } 374 Copy()375 AbstractFunctionPtr Copy() const override { return std::make_shared<JTransformedAbstractClosure>(fn_); } 376 377 bool operator==(const AbstractFunction &other) const override; 378 379 std::size_t hash() const override; 380 ToString()381 std::string ToString() const override { return "J(" + fn_->ToString() + ")"; } 382 383 private: 384 AbstractFuncAtomPtr fn_; 385 }; 386 387 /// \brief TaylorTransformedAbstractClosure defines interface for abstract of Function 388 /// transformed through the application of Taylor. 389 class MS_CORE_API TaylorTransformedAbstractClosure final : public AbstractFuncAtom { 390 public: 391 /// \brief Constructor of TaylorTransformedAbstractClosure 392 /// 393 /// \param[in] fn The AbstractFuncAtom transformed through the application of Taylor. TaylorTransformedAbstractClosure(const AbstractFuncAtomPtr & fn)394 explicit TaylorTransformedAbstractClosure(const AbstractFuncAtomPtr &fn) : fn_(fn) {} 395 396 /// \brief Destructor of TaylorTransformedAbstractClosure 397 ~TaylorTransformedAbstractClosure() override = default; MS_DECLARE_PARENT(TaylorTransformedAbstractClosure,AbstractFuncAtom)398 MS_DECLARE_PARENT(TaylorTransformedAbstractClosure, AbstractFuncAtom) 399 400 /// \brief Get the AbstractFuncAtom TaylorTransformedAbstractClosure corresponding to. 401 /// 402 /// \return The AbstractFuncAtom TaylorTransformedAbstractClosure corresponding to. 403 const AbstractFuncAtomPtr &fn() const { return fn_; } 404 Copy()405 AbstractFunctionPtr Copy() const override { return std::make_shared<TaylorTransformedAbstractClosure>(fn_); } 406 407 bool operator==(const AbstractFunction &other) const override; 408 409 std::size_t hash() const override; 410 ToString()411 std::string ToString() const override { return "Taylor(" + fn_->ToString() + ")"; } 412 413 private: 414 AbstractFuncAtomPtr fn_; 415 }; 416 417 /// \brief ShardTransformedAbstractClosure defines interface for abstract of Function 418 /// transformed through the application of Shard. 419 class MS_CORE_API ShardTransformedAbstractClosure final : public AbstractFuncAtom { 420 public: 421 /// \brief Constructor of ShardTransformedAbstractClosure 422 /// 423 /// \param[in] fn The AbstractFuncAtom transformed through the application of Shard. ShardTransformedAbstractClosure(const AbstractFuncAtomPtr & fn)424 explicit ShardTransformedAbstractClosure(const AbstractFuncAtomPtr &fn) : fn_(fn) {} 425 426 /// \brief Destructor of ShardTransformedAbstractClosure 427 ~ShardTransformedAbstractClosure() override = default; MS_DECLARE_PARENT(ShardTransformedAbstractClosure,AbstractFuncAtom)428 MS_DECLARE_PARENT(ShardTransformedAbstractClosure, AbstractFuncAtom) 429 430 /// \brief Get the AbstractFuncAtom ShardTransformedAbstractClosure corresponding to. 431 /// 432 /// \return The AbstractFuncAtom ShardTransformedAbstractClosure corresponding to. 433 const AbstractFuncAtomPtr &fn() const { return fn_; } 434 Copy()435 AbstractFunctionPtr Copy() const override { return std::make_shared<ShardTransformedAbstractClosure>(fn_); } 436 437 bool operator==(const AbstractFunction &other) const override; 438 439 std::size_t hash() const override; 440 ToString()441 std::string ToString() const override { return "Shard(" + fn_->ToString() + ")"; } 442 443 private: 444 AbstractFuncAtomPtr fn_; 445 }; 446 447 /// \brief VmapTransformedAbstractClosure defines interface for abstract of Function 448 /// transformed through the application of Vmap. 449 class MS_CORE_API VmapTransformedAbstractClosure final : public AbstractFuncAtom { 450 public: 451 /// \brief Constructor of VmapTransformedAbstractClosure 452 /// 453 /// \param[in] fn The AbstractFuncAtom transformed through the application of Vmap. VmapTransformedAbstractClosure(const AbstractFuncAtomPtr & fn,const ValuePtr & in_axes,const ValuePtr & out_axes,size_t cell_size)454 explicit VmapTransformedAbstractClosure(const AbstractFuncAtomPtr &fn, const ValuePtr &in_axes, 455 const ValuePtr &out_axes, size_t cell_size) 456 : fn_(fn), in_axes_(in_axes), out_axes_(out_axes), cell_size_(cell_size) {} 457 458 /// \brief Destructor of VmapTransformedAbstractClosure 459 ~VmapTransformedAbstractClosure() override = default; MS_DECLARE_PARENT(VmapTransformedAbstractClosure,AbstractFuncAtom)460 MS_DECLARE_PARENT(VmapTransformedAbstractClosure, AbstractFuncAtom) 461 462 /// \brief Get the AbstractFuncAtom VmapTransformedAbstractClosure corresponding to. 463 /// 464 /// \return The AbstractFuncAtom VmapTransformedAbstractClosure corresponding to. 465 const AbstractFuncAtomPtr &fn() const { return fn_; } 466 in_axes()467 const ValuePtr &in_axes() const { return in_axes_; } 468 out_axes()469 const ValuePtr &out_axes() const { return out_axes_; } 470 cell_size()471 size_t cell_size() const { return cell_size_; } 472 Copy()473 AbstractFunctionPtr Copy() const override { 474 return std::make_shared<VmapTransformedAbstractClosure>(fn_, in_axes_, out_axes_, cell_size_); 475 } 476 477 bool operator==(const AbstractFunction &other) const override; 478 479 std::size_t hash() const override; 480 ToString()481 std::string ToString() const override { return "Vmap(" + fn_->ToString() + ")"; } 482 483 private: 484 AbstractFuncAtomPtr fn_; 485 ValuePtr in_axes_; 486 ValuePtr out_axes_; 487 size_t cell_size_; 488 }; 489 490 /// \brief VirtualAbstractClosure defines interface for function with an explicitly 491 /// fixed type signature. 492 class MS_CORE_API VirtualAbstractClosure final : public AbstractFuncAtom { 493 public: 494 /// \brief Constructor of VirtualAbstractClosure. 495 /// 496 /// \param[in] args_abs_list The abstract values of the arguments to the function. 497 /// \param[in] output_spec The abstract value of output. VirtualAbstractClosure(const AbstractBasePtrList & args_abs_list,const AbstractBasePtr & output_spec)498 VirtualAbstractClosure(const AbstractBasePtrList &args_abs_list, const AbstractBasePtr &output_spec) 499 : args_abs_list_(args_abs_list), output_(output_spec) {} 500 501 /// \brief Constructor of VirtualAbstractClosure. 502 /// 503 /// \param[in] args_abs The abstract value of argument to the function. 504 /// \param[in] output_spec The abstract value of output. VirtualAbstractClosure(const AbstractBasePtr & args_abs,const AbstractBasePtr & output_spec)505 VirtualAbstractClosure(const AbstractBasePtr &args_abs, const AbstractBasePtr &output_spec) 506 : args_abs_list_({args_abs}), output_(output_spec) {} 507 508 /// \brief Destructor of VirtualAbstractClosure. 509 ~VirtualAbstractClosure() override = default; MS_DECLARE_PARENT(VirtualAbstractClosure,AbstractFuncAtom)510 MS_DECLARE_PARENT(VirtualAbstractClosure, AbstractFuncAtom) 511 512 /// \brief Get the abstract values of arguments. 513 /// 514 /// \return The abstract values of arguments. 515 const AbstractBasePtrList &args_abs_list() const { return args_abs_list_; } 516 517 /// \brief Get the abstract value of output. 518 /// 519 /// \return The abstract value of output. output()520 const AbstractBasePtr &output() const { return output_; } 521 Copy()522 AbstractFunctionPtr Copy() const override { 523 return std::make_shared<VirtualAbstractClosure>(args_abs_list_, output_); 524 } 525 526 bool operator==(const AbstractFunction &other) const override; 527 528 std::size_t hash() const override; 529 530 std::string ToString() const override; 531 532 private: 533 AbstractBasePtrList args_abs_list_; 534 AbstractBasePtr output_; 535 }; 536 using VirtualAbstractClosurePtr = std::shared_ptr<VirtualAbstractClosure>; 537 538 /// \brief TypedPrimitiveAbstractClosure defines interface for Primitive with an explicitly 539 /// fixed type signature. 540 class MS_CORE_API TypedPrimitiveAbstractClosure final : public AbstractFuncAtom { 541 public: 542 /// \brief Constructor of TypedPrimitiveAbstractClosure. 543 /// 544 /// \param[in] prim The Primitive with an explicitly fixed type signature. 545 /// \param[in] args_abs_list The abstract values of arguments to the Primitive. 546 /// \param[in] output_spec The abstract value of output. TypedPrimitiveAbstractClosure(const PrimitivePtr prim,const AbstractBasePtrList & args_abs_list,const AbstractBasePtr & output_spec)547 TypedPrimitiveAbstractClosure(const PrimitivePtr prim, const AbstractBasePtrList &args_abs_list, 548 const AbstractBasePtr &output_spec) 549 : prim_(prim), args_abs_list_(args_abs_list), output_(output_spec) {} 550 551 /// \brief Destructor of TypedPrimitiveAbstractClosure. 552 ~TypedPrimitiveAbstractClosure() override = default; MS_DECLARE_PARENT(TypedPrimitiveAbstractClosure,AbstractFuncAtom)553 MS_DECLARE_PARENT(TypedPrimitiveAbstractClosure, AbstractFuncAtom) 554 555 /// \brief Get the Primitive that this TypedPrimitiveAbstractClosure corresponding to. 556 /// 557 /// \return The Primitive that this TypedPrimitiveAbstractClosure corresponding to. 558 const PrimitivePtr &prim() const { return prim_; } 559 560 /// \brief Get the abstract values of arguments this TypedPrimitiveAbstractClosure corresponding to. 561 /// 562 /// \return The abstract values of arguments this TypedPrimitiveAbstractClosure corresponding to. args_abs_list()563 const AbstractBasePtrList &args_abs_list() const { return args_abs_list_; } 564 565 /// \brief Get the abstract value of output this TypedPrimitiveAbstractClosure corresponding to. 566 /// 567 /// \return The abstract value of output this TypedPrimitiveAbstractClosure corresponding to. output()568 const AbstractBasePtr &output() const { return output_; } 569 Copy()570 AbstractFunctionPtr Copy() const override { 571 return std::make_shared<TypedPrimitiveAbstractClosure>(prim_, args_abs_list_, output_); 572 } 573 574 bool operator==(const AbstractFunction &other) const override; 575 576 std::size_t hash() const override; 577 578 std::string ToString() const override; 579 580 private: 581 PrimitivePtr prim_; 582 AbstractBasePtrList args_abs_list_; 583 AbstractBasePtr output_; 584 }; 585 586 /// \brief Hash operator for AbstractFunction. 587 struct MS_CORE_API AbstractFunctionHasher { 588 /// \brief Implementation of hash operation. 589 /// 590 /// \param[in] t The AbstractFunction needs to hash. 591 /// 592 /// \return The hash result. operatorAbstractFunctionHasher593 std::size_t operator()(const AbstractFunctionPtr &t) const { 594 std::size_t hash = t->hash(); 595 return hash; 596 } 597 }; 598 599 /// \brief Equal operator for AbstractFunction. 600 struct MS_CORE_API AbstractFunctionEqual { 601 /// \brief Implementation of Equal operation. 602 /// 603 /// \param[in] lhs The left AbstractFunction for compare. 604 /// \param[in] rhs The right AbstractFunction for compare. 605 /// 606 /// \return Return True if the comparison result is equal, otherwise return False. operatorAbstractFunctionEqual607 bool operator()(const AbstractFunctionPtr &lhs, const AbstractFunctionPtr &rhs) const { return *lhs == *rhs; } 608 }; 609 } // namespace abstract 610 } // namespace mindspore 611 #endif // MINDSPORE_CORE_ABSTRACT_ABSTRACT_FUNCTION_H_ 612