/**
 * Copyright 2021-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_RUNTIME_HARDWARE_DEVICE_CONTEXT_H_
#define MINDSPORE_CCSRC_RUNTIME_HARDWARE_DEVICE_CONTEXT_H_

#include <map>
#include <memory>
#include <string>
#include <vector>
#include <unordered_map>

#include "include/backend/device_type.h"
#include "include/backend/device_address.h"
#include "runtime/device/gsm/swap_manager.h"
#include "runtime/collective/collective_communication_lib.h"
#include "runtime/collective/collective_comm_lib_loader.h"
#include "include/backend/kernel_graph.h"
#include "include/backend/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
#include "runtime/hardware/deprecated_interface.h"
#include "runtime/device/auto_mem_offload.h"
#include "runtime/device/memory_manager.h"
#include "include/backend/optimizer/graph_optimizer.h"
#include "runtime/pipeline/task/task.h"
#include "ir/device_event.h"
#include "utils/ms_context.h"
#include "ir/tensor.h"
#ifdef __APPLE__
#include "mindrt/include/async/spinlock.h"
#endif

namespace mindspore {
namespace device {
using mindspore::kernel::AddressPtr;
using mindspore::kernel::KernelMod;
using mindspore::kernel::KernelTensor;

const size_t kDeviceContextsNumOne = 1;
const size_t kDeviceContextsNumTwo = 2;

struct DeviceContextKey {
  // Device type name, such as 'GPU', 'Ascend' or 'CPU'.
  std::string device_name_;
  uint32_t device_id_{0};

  // Use the result of ToString() as the key to look up a DeviceContext in the
  // cache map that maintains created DeviceContext objects.
  std::string ToString() const { return device_name_ + "_" + std::to_string(device_id_); }
};
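// e.g. (from the definition above): the key for device 0 of the GPU backend
// serializes as "GPU_0":
//   DeviceContextKey key{"GPU", 0};
//   key.ToString();  // returns "GPU_0"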

class DeviceResManager;
class GraphExecutor;
class KernelExecutor;

// DeviceContext is the unified interface for interacting with a device.
class DeviceContext {
 public:
  explicit DeviceContext(const DeviceContextKey &device_context_key)
      : device_context_key_(device_context_key), initialized_(false) {}
  virtual ~DeviceContext() = default;

  // Initialize the device context.
  virtual void Initialize() = 0;

  // Destroy the device context and release device resources.
  virtual void Destroy() {}

  // Analyze the function graph to check whether all nodes are supported. If so, return true; otherwise return false
  // and mark the unsupported nodes as "NotSupport" through SetCNodeNotSupported().
  // For further usage, each device can add an attribute kAttrGraphSplitGroup to a node and assign a group_name
  // (the type must be std::string; the default is 'DefaultGroup') to the attribute, which means that consecutive
  // nodes with the same group_name will be split into one subgraph.
  virtual bool PartitionGraph(const FuncGraphPtr &func_graph) const { return false; }

  // Analyze the function graph and select the appropriate run mode for it.
  virtual RunMode GetRunMode(const FuncGraphPtr &func_graph) const = 0;

  // Get device_context_key_ to obtain the device name and device id.
  const DeviceContextKey &device_context_key() const { return device_context_key_; }

  // Get the device address type according to the device type, such as GPU or Ascend.
  DeviceType GetDeviceType() const { return GetDeviceTypeByName(device_context_key_.device_name_); }

  // Get the kernel executor according to whether the shape is dynamic.
  std::shared_ptr<KernelExecutor> GetKernelExecutor(bool is_dynamic_shape) const {
    if (is_dynamic_shape) {
      return dyn_kernel_executor_;
    } else {
      return kernel_executor_;
    }
  }

  void SetKernelExecutor(const std::shared_ptr<KernelExecutor> &kernel_executor) { kernel_executor_ = kernel_executor; }

  void SetDynKernelExecutor(const std::shared_ptr<KernelExecutor> &kernel_executor) {
    dyn_kernel_executor_ = kernel_executor;
  }

  // TODO: delete
  virtual DeprecatedInterface *GetDeprecatedInterface() { return nullptr; }

  // Return whether this device context is initialized.
  bool initialized() const {
#ifdef __APPLE__
    std::lock_guard<SpinLock> spin_lock(init_lock_);
#else
    std::lock_guard<std::mutex> lock(init_mutex_);
#endif
    return initialized_;
  }

  DeviceContextKey device_context_key_;
  std::unique_ptr<DeviceResManager> device_res_manager_;
  std::unique_ptr<GraphExecutor> graph_executor_;

 protected:
#ifdef __APPLE__
  // There are some problems with using a mutex on macOS, so spinlocks are used instead.
  inline static SpinLock init_lock_;
#else
  inline static std::mutex init_mutex_;
#endif
  bool initialized_;

 private:
  std::shared_ptr<KernelExecutor> kernel_executor_;
  std::shared_ptr<KernelExecutor> dyn_kernel_executor_;
};
using DeviceContextPtr = std::shared_ptr<DeviceContext>;

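// Minimal usage sketch (illustration only; how a concrete DeviceContext is
// created is backend-specific and outside this header):
//   DeviceContextPtr ctx = /* obtained from a backend factory */;
//   ctx->Initialize();
//   auto kernel_executor = ctx->GetKernelExecutor(/*is_dynamic_shape=*/false);
//   ctx->Destroy();
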
class BACKEND_EXPORT DeviceResManager {
 public:
  DeviceResManager() : collective_comm_lib_(nullptr), device_context_(nullptr) {
    offloaded_mem_pool_ = std::make_shared<device::OffloadedMemPool>();
  }
  virtual ~DeviceResManager() = default;

  // Initialize the device resource manager.
  virtual void Initialize() {}

  // Destroy the device resource manager and release device resources.
  virtual void Destroy() {}

  // Bind the device to the current thread to gain device control privileges.
  // If force_bind is true, bind the context to the current thread every time;
  // otherwise, bind it to the current thread only the first time.
  virtual bool BindDeviceToCurrentThread(bool force_bind) const { return true; }
  virtual void ResetStreamAndCtx() {}

  // Relevant functions to allocate and free device memory via raw pointers.
  virtual void *AllocateMemory(size_t size, uint32_t stream_id = kDefaultStreamIndex) const = 0;
  virtual void FreeMemory(void *ptr) const = 0;
  virtual void FreePartMemorys(const std::vector<void *> &free_addrs, const std::vector<void *> &keep_addrs,
                               const std::vector<size_t> &keep_addr_sizes) const = 0;
  virtual void DefragMemory() {}
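  // e.g. (sketch; 'res_manager' is an assumed DeviceResManager pointer):
  //   void *buf = res_manager->AllocateMemory(1024);  // default stream
  //   if (buf != nullptr) { /* ... use buf ... */ res_manager->FreeMemory(buf); }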

  virtual void SwapIn(const void *host_ptr, void *device_ptr, size_t mem_size, void *stream) {
    MS_LOG(EXCEPTION) << "Unimplemented interface.";
  }
  virtual void SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size, void *stream) {
    MS_LOG(EXCEPTION) << "Unimplemented interface.";
  }

  // Relevant functions to allocate and free device memory via DeviceAddress.
  virtual bool AllocateMemory(DeviceAddress *const &address, uint32_t stream_id = UINT32_MAX) const;
  virtual void FreeMemory(DeviceAddress *const &address) const;
  virtual size_t GetMaxUsedMemorySize() const { return 0; }

  // Relevant functions to manage memory statistics.
  virtual size_t GetTotalMemStatistics() const { return 0; }
  virtual size_t GetTotalUsedMemStatistics() const { return 0; }
  virtual size_t GetTotalIdleMemStatistics() const { return 0; }
  virtual size_t GetTotalEagerFreeMemStatistics() const { return 0; }
  virtual size_t GetUsedMemPeakStatistics() const { return 0; }
  virtual size_t GetReservedMemPeakStatistics() const { return 0; }
  virtual std::unordered_map<std::string, std::size_t> GetBlockCountsStatistics() const { return {}; }
  virtual std::unordered_map<std::string, std::size_t> GetBlockUnitSizeStatistics() const { return {}; }
  virtual std::unordered_map<device::DeviceMemPtr, std::unordered_map<std::string, size_t>>
  GetCommonMemBlocksInfoStatistics() const {
    return {};
  }
  virtual std::unordered_map<device::DeviceMemPtr, std::unordered_map<std::string, size_t>>
  GetPersistentMemBlocksInfoStatistics() const {
    return {};
  }
  virtual void ResetMaxMemoryReserved() const {}
  virtual void ResetMaxMemoryAllocated() const {}

  // Allocate host memory with RAII and reference counting.
  virtual std::shared_ptr<void> AllocateHostMemory(size_t size) const {
    return std::shared_ptr<void>(::malloc(size), ::free);
  }
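  // e.g. (sketch): the returned shared_ptr frees the host buffer automatically
  // when the last reference goes away:
  //   std::shared_ptr<void> host_buf = res_manager->AllocateHostMemory(4096);
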
  // Allocate host memory for offloading device memory.
  virtual void *AllocateOffloadMemory(size_t size) const;
  // Release host memory allocated by AllocateOffloadMemory back to the pool.
  // It will not be freed to the OS.
  virtual void FreeOffloadMemory(void *ptr) const;

  virtual size_t GetAvailableMemSize() const { return 0; }

  // Allocate continuous device memory according to the size list.
  // Communication operators may need continuous memory for input and output
  // to optimize communication performance.
  virtual std::vector<void *> AllocateContinuousMemory(const std::vector<size_t> &size_list,
                                                       uint32_t stream_id = kDefaultStreamIndex) const {
    MS_LOG(EXCEPTION) << "Unimplemented interface.";
  }
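  // e.g. (sketch): one continuous region carved into three buffers of the given sizes:
  //   std::vector<void *> ptrs = res_manager->AllocateContinuousMemory({512, 1024, 2048});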

  // Create a concrete device address, according to the device type, from a KernelTensor.
  virtual DeviceAddressPtr CreateDeviceAddress(const KernelTensorPtr &kernel_tensor) const {
    MS_LOG(EXCEPTION) << "Unimplemented interface.";
  }
  virtual void MoveTo(const tensor::TensorPtr &src_tensor, const tensor::TensorPtr &dst_tensor, const std::string &to,
                      bool blocking, bool *return_self) {
    MS_LOG(EXCEPTION) << "Unimplemented interface.";
  }

  virtual DeviceAddressPtr CreateDeviceAddress(void *ptr, size_t size, const ShapeVector &shape_vector,
                                               const Format &format, TypeId type_id, const std::string &device_name,
                                               uint32_t device_id, uint32_t stream_id) const {
    MS_LOG(EXCEPTION) << "Unimplemented interface.";
  }

  // Create a stream and assign it a stream id; the assigned id is written to the parameter '*stream_id'.
  virtual bool CreateStream(size_t *stream_id) const {
    MS_LOG(WARNING) << "Unimplemented interface: 'CreateStream'.";
    *stream_id = kSizeZero;
    return false;
  }

  // Create a stream with a priority.
  virtual bool CreateStreamWithPriority(size_t *stream_id, int32_t priority) const {
    *stream_id = kSizeZero;
    return false;
  }

  virtual size_t QueryStreamSize() const { return 0L; }
  virtual std::vector<uint32_t> GetStreamIds() const { return {}; }

  // If multiple streams are used in PyNative mode, the other streams must be synchronized before the graph
  // is executed; otherwise, out-of-order execution occurs. Hence this flag.
  // This is a temporary solution; the flag will be removed once multi-stream is
  // supported in graph mode.
  virtual bool single_op_multi_stream_enable() const { return false; }
  virtual void set_single_op_multi_stream_enable(bool single_op_multi_stream_enable) {}

  // Get the stream pointer by stream_id.
  virtual void *GetStream(size_t stream_id) const { return nullptr; }

  // Set the currently used stream id.
  virtual void SetCurrentStreamId(size_t stream_id) {}

  // Get the currently used stream id.
  virtual size_t GetCurrentStreamId() const { return kSizeZero; }

  virtual void *GetStream() const { return nullptr; }

  // Destroy the stream bound to the input parameter "stream_id".
  virtual bool DestroyStream(size_t stream_id) const { return false; }

  // Query the completion status of a stream's tasks.
  virtual bool QueryStream(size_t stream_id) const { return true; }

  // Synchronize streams. Devices such as GPU and Ascend need streams to launch kernels asynchronously.
  // Use 'SyncStream' to block the thread and wait for all tasks on a specific stream to complete.
  // Use 'SyncAllStreams' to block the thread and wait for all tasks on all streams to complete.
  // Devices without streams may ignore the implementation of these functions.
  // Since the current entry point for creating streams is not unified, the 'SyncStream' and
  // 'SyncAllStreams' interfaces are implemented by subclasses.
  virtual bool SyncStream(size_t stream_id) const { return true; }

  virtual bool SyncAllStreams() const { return true; }

  virtual bool SyncNotDefaultStreams() const { return true; }

  // Return the default stream id. Normally it is 0.
  virtual size_t DefaultStream() const { return 0; }
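  // Hedged lifecycle sketch using only the interfaces declared above ('res_manager'
  // is an assumed DeviceResManager pointer):
  //   size_t stream_id = 0;
  //   if (res_manager->CreateStream(&stream_id)) {
  //     void *stream = res_manager->GetStream(stream_id);  // launch work on this stream
  //     res_manager->SyncStream(stream_id);                 // block until the stream drains
  //     res_manager->DestroyStream(stream_id);
  //   }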

  // Create a device event for the runtime.
  virtual DeviceEventPtr CreateRuntimeEvent(bool enable_blocking, bool enable_record_wait) { return nullptr; }

  // Create a device event with flags.
  virtual DeviceEventPtr CreateEventWithFlag(bool enable_timing, bool blocking) { return nullptr; }

  // Destroy the specified device event.
  virtual bool DestroyEvent(const DeviceEventPtr &event);

  // Destroy all device events.
  virtual bool DestroyAllEvents();

  // Dynamically load the collective communication library.
  // Currently, four types are supported: OpenMPI and a self-developed framework for CPU, NCCL for GPU, and HCCL
  // for Ascend.
  virtual bool LoadCollectiveCommLib() { return true; }

  // Return the collective communication object for the caller to access.
  CollectiveCommunicationLib *collective_comm_lib() const { return collective_comm_lib_; }

  std::shared_ptr<SwapManager> swap_manager() const { return swap_manager_; }

  std::shared_ptr<MemoryManager> mem_manager() const { return mem_manager_; }

 protected:
  // Ensure thread safety when allocating device memory.
  mutable std::mutex alloc_mem_mutex_;

  // The collective communication library.
  CollectiveCommunicationLib *collective_comm_lib_;

  DeviceContext *device_context_{nullptr};

  std::shared_ptr<SwapManager> swap_manager_{nullptr};

  std::mutex device_events_mutex_;

  DeviceEventPtrList device_events_{};

  std::shared_ptr<MemoryManager> mem_manager_{nullptr};

 private:
  template <class... Args>
  friend class DeviceInterface;
  void SetDeviceContext(DeviceContext *device_context) { device_context_ = device_context; }
  std::shared_ptr<device::OffloadedMemPool> offloaded_mem_pool_;
};

class GraphExecutor {
 public:
  virtual ~GraphExecutor() = default;
  virtual bool CompileGraph(const FuncGraphPtr &graph, const std::map<string, string> &compile_options) { return true; }
  virtual bool RunGraph(const FuncGraphPtr &graph, const std::vector<tensor::Tensor> &inputs,
                        std::vector<tensor::Tensor> *outputs, const std::map<string, string> &compile_options) {
    MS_LOG(EXCEPTION) << "Unimplemented interface.";
  }
  virtual std::string GetRandomStatus(const std::vector<FuncGraphPtr> &graphs) { return ""; }
  virtual size_t GetGraphFeatureMemory(const FuncGraphPtr &graph) const { return 0; }
  virtual void InitGraphInfo(const FuncGraphPtr &graph) {}

 protected:
  DeviceContext *device_context_{nullptr};

 private:
  template <class... Args>
  friend class DeviceInterface;

  void SetDeviceContext(DeviceContext *device_context) { device_context_ = device_context; }
};
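
// Expected call-order sketch (illustration only; 'graph_executor', 'func_graph'
// and 'inputs' are assumed to exist):
//   std::map<string, string> options;
//   graph_executor->CompileGraph(func_graph, options);
//   std::vector<tensor::Tensor> outputs;
//   graph_executor->RunGraph(func_graph, inputs, &outputs, options);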

using CallbackFunc = std::function<void(void)>;

class BACKEND_EXPORT KernelExecutor {
 public:
  virtual ~KernelExecutor() = default;

  virtual void Initialize() {}
  virtual void Destroy() {}

  // Optimize the kernel graph for graph mode.
  virtual void OptimizeGraph(const FuncGraphPtr &graph) const {}

  // Generate a 'KernelMod' for every kernel and set it into the kernel node;
  // 'KernelMod' is the real executable object of a kernel.
  virtual void CreateKernel(const std::vector<CNodePtr> &nodes) const {}
  virtual kernel::KernelModPtr CreateKernelMod(const std::string &op_name) const { MS_LOG(EXCEPTION) << "Unrealized"; }

  // Adjust the kernel graph before running it.
  virtual void PreprocessBeforeRun(const FuncGraphPtr &graph) const {}

  // Launch a kernel via its 'KernelMod', using the KernelTensor input type.
  virtual bool LaunchKernel(const CNodePtr &kernel, const std::vector<KernelTensor *> &inputs,
                            const std::vector<KernelTensor *> &workspace, const std::vector<KernelTensor *> &outputs,
                            KernelMod *kernel_mod, void *stream) const {
    MS_LOG(EXCEPTION) << "Unimplemented interface.";
  }
  // Launch a callback on the given stream.
  virtual bool LaunchCallback(std::function<void(void)> callback_func, size_t stream_id) const {
    callback_func();
    return true;
  }
  // Unify the MindIR; the default behavior uses the common unified MindIR.
  virtual void UnifyMindIR(const KernelGraphPtr &graph) const;
  virtual void AddMindIRPass(const KernelGraphPtr &graph) const {}

  // Get the rank id for distributed training.
  virtual uint32_t GetRankID() const { return 0; }

  void SetDeviceContext(DeviceContext *device_context) { device_context_ = device_context; }

  virtual bool ExecuteKernelTask(const runtime::KernelTaskType &task_type,
                                 const device::DeviceAddressPtrList &input_addr_list,
                                 const device::DeviceAddressPtrList &output_addr_list, const size_t &stream_id) const {
    return false;
  }

 protected:
  DeviceContext *device_context_{nullptr};
};
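
// Typical graph-mode flow sketch (illustration only; all variables are assumed):
//   kernel_executor->OptimizeGraph(graph);
//   kernel_executor->CreateKernel(nodes);      // attach a KernelMod to each node
//   kernel_executor->PreprocessBeforeRun(graph);
//   kernel_executor->LaunchKernel(kernel, inputs, workspace, outputs, kernel_mod, stream);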

template <class... Args>
class DeviceInterface : public DeviceContext {};

template <>
class DeviceInterface<> : public DeviceContext {
 public:
  explicit DeviceInterface(const DeviceContextKey &key) : DeviceContext(key) {}

 protected:
  void CheckUnset(const void *ptr, const std::string &error_msg) const {
    if (ptr != nullptr) {
      MS_LOG(EXCEPTION) << error_msg;
    }
  }
};

template <class T, class... Args>
class DeviceInterface<T, Args...> : public DeviceInterface<Args...> {
 public:
  explicit DeviceInterface(const DeviceContextKey &key) : DeviceInterface<Args...>(key) {
    if constexpr (std::is_base_of_v<DeviceResManager, T>) {
      DeviceInterface::CheckUnset(reinterpret_cast<void *>(DeviceContext::device_res_manager_.get()),
                                  "DeviceResManager has been registered!");
      DeviceContext::device_res_manager_ = std::make_unique<T>();
      DeviceContext::device_res_manager_->SetDeviceContext(this);
    } else if constexpr (std::is_base_of_v<GraphExecutor, T>) {
      DeviceInterface::CheckUnset(reinterpret_cast<void *>(DeviceContext::graph_executor_.get()),
                                  "GraphExecutor has been registered!");
      DeviceContext::graph_executor_ = std::make_unique<T>();
      DeviceContext::graph_executor_->SetDeviceContext(this);
    } else if constexpr (std::is_base_of_v<KernelExecutor, T>) {
      DeviceInterface::CheckUnset(reinterpret_cast<void *>(DeviceContext::GetKernelExecutor(false).get()),
                                  "KernelExecutor has been registered!");
      DeviceInterface::CheckUnset(reinterpret_cast<void *>(DeviceContext::GetKernelExecutor(true).get()),
                                  "Dyn KernelExecutor has been registered!");
      DeviceContext::SetKernelExecutor(std::make_shared<T>());
      DeviceContext::GetKernelExecutor(false)->SetDeviceContext(this);
      // for GPU/CPU dynamic shape kernel executor
      DeviceContext::SetDynKernelExecutor(DeviceContext::GetKernelExecutor(false));
    }
  }

 private:
  template <typename = std::enable_if_t<std::is_base_of_v<DeviceResManager, T> || std::is_base_of_v<GraphExecutor, T> ||
                                        std::is_base_of_v<KernelExecutor, T>>>
  void Assert() const {}
};
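
// Registration sketch: a backend derives from DeviceInterface and lists its
// component types as template arguments; the constructor above instantiates each
// one and wires it into the DeviceContext. All 'My*' names are hypothetical:
//   class MyKernelExecutor : public KernelExecutor { /* overrides */ };
//   class MyDeviceResManager : public DeviceResManager { /* overrides */ };
//   class MyDeviceContext : public DeviceInterface<MyKernelExecutor, MyDeviceResManager> {
//    public:
//     explicit MyDeviceContext(const DeviceContextKey &key) : DeviceInterface(key) {}
//     void Initialize() override {}
//     RunMode GetRunMode(const FuncGraphPtr &func_graph) const override {
//       /* return a backend-appropriate RunMode */
//     }
//   };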
}  // namespace device
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_RUNTIME_HARDWARE_DEVICE_CONTEXT_H_