1 /** 2 * Copyright 2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_MEMORY_ALLOC_ACTOR_H_ 18 #define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_MEMORY_ALLOC_ACTOR_H_ 19 20 #include <string> 21 #include <memory> 22 #include "runtime/graph_scheduler/actor/memory_aware_actor.h" 23 24 namespace mindspore { 25 namespace runtime { 26 using mindspore::session::SomasInfo; 27 28 // The memory alloc actor is used to alloc memory of the whole graph at the begin of graph running. 29 class MemoryAllocActor : public MemoryAwareActor { 30 public: MemoryAllocActor(const std::string & name,const AID & memory_manager_aid,SomasInfo * somas_info,const DeviceContext * device_context)31 MemoryAllocActor(const std::string &name, const AID &memory_manager_aid, SomasInfo *somas_info, 32 const DeviceContext *device_context) 33 : MemoryAwareActor(name, KernelTransformType::kMemoryAllocActor, nullptr, memory_manager_aid), 34 somas_info_(somas_info) { 35 (void)device_contexts_.emplace_back(device_context); 36 } 37 ~MemoryAllocActor() override = default; 38 39 // The memory related operation interface. 40 void SendMemoryAllocReq(OpContext<DeviceTensor> *const context) override; 41 // The processing after memory alloc finished. 42 void OnMemoryAllocFinish(OpContext<DeviceTensor> *const context) override; 43 44 // Get the member. somas_info()45 SomasInfo *somas_info() const { return somas_info_; } 46 47 protected: 48 void Init() override; Run(OpContext<DeviceTensor> * const context)49 void Run(OpContext<DeviceTensor> *const context) override { SendMemoryAllocReq(context); } 50 51 private: 52 friend class SchedulerHelper; 53 54 SomasInfo *somas_info_; 55 }; 56 57 using MemoryAllocActorPtr = std::shared_ptr<MemoryAllocActor>; 58 } // namespace runtime 59 } // namespace mindspore 60 61 #endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_MEMORY_ALLOC_ACTOR_H_ 62