1 /** 2 * Copyright 2022 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CUSTOM_ACTOR_H_ 18 #define MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CUSTOM_ACTOR_H_ 19 20 #include <string> 21 #include <memory> 22 #include <vector> 23 #include "runtime/graph_scheduler/actor/memory_aware_actor.h" 24 #include "runtime/hardware/device_context.h" 25 #include "ir/anf.h" 26 27 namespace mindspore { 28 namespace runtime { 29 class CustomActor : public MemoryAwareActor { 30 public: CustomActor(const std::string & name,const AnfNodePtr & kernel,const device::DeviceContext * device_context,const AID & memory_manager_aid)31 CustomActor(const std::string &name, const AnfNodePtr &kernel, const device::DeviceContext *device_context, 32 const AID &memory_manager_aid) 33 : MemoryAwareActor(name, KernelTransformType::kCustomActor, nullptr, memory_manager_aid), kernel_(kernel) { 34 device_contexts_.push_back(device_context); 35 } 36 ~CustomActor() override = default; 37 38 // The memory related operation interface. 39 void SendMemoryFreeReq(OpContext<DeviceTensor> *const context) override; 40 kernel()41 const AnfNodeWeakPtr &kernel() const { return kernel_; } 42 43 protected: 44 void Run(OpContext<DeviceTensor> *const context) override; 45 void Init() override; 46 47 private: 48 friend class GraphScheduler; 49 friend class ControlNodeScheduler; 50 51 // The info of kernel. 52 AnfNodeWeakPtr kernel_; 53 AnfUtils::CustomActorCallback custom_func_ = {}; 54 GraphExecutionStrategy strategy_{GraphExecutionStrategy::kPipeline}; 55 // The device tensors for launch. 56 std::vector<DeviceTensor *> input_device_tensors_; 57 // The device tensors for memory free. 58 std::vector<DeviceTensor *> memory_free_list_; 59 60 std::string custom_type_; 61 }; 62 63 using CustomActorPtr = std::shared_ptr<CustomActor>; 64 } // namespace runtime 65 } // namespace mindspore 66 #endif // MINDSPORE_CCSRC_RUNTIME_FRAMEWORK_ACTOR_CUSTOM_ACTOR_H_ 67