/**
 * Copyright 2021-2022 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CCSRC_RUNTIME_HARDWARE_DEVICE_CONTEXT_H_
#define MINDSPORE_CCSRC_RUNTIME_HARDWARE_DEVICE_CONTEXT_H_

#include <map>
#include <memory>
#include <string>
#include <vector>
#include <unordered_map>

#include "include/backend/device_type.h"
#include "include/backend/device_address.h"
#include "runtime/device/gsm/swap_manager.h"
#include "runtime/collective/collective_communication_lib.h"
#include "runtime/collective/collective_comm_lib_loader.h"
#include "include/backend/kernel_graph.h"
#include "include/backend/anf_runtime_algorithm.h"
#include "include/common/utils/anfalgo.h"
#include "runtime/hardware/deprecated_interface.h"
#include "runtime/device/auto_mem_offload.h"
#include "runtime/device/memory_manager.h"
#include "include/backend/optimizer/graph_optimizer.h"
#include "runtime/pipeline/task/task.h"
#include "ir/device_event.h"
#include "utils/ms_context.h"
#include "ir/tensor.h"
#ifdef __APPLE__
#include "mindrt/include/async/spinlock.h"
#endif

namespace mindspore {
namespace device {
using mindspore::kernel::AddressPtr;
using mindspore::kernel::KernelMod;
using mindspore::kernel::KernelTensor;

const size_t kDeviceContextsNumOne = 1;
const size_t kDeviceContextsNumTwo = 2;

struct DeviceContextKey {
  // Device type name, such as 'GPU', 'Ascend' or 'CPU'.
  std::string device_name_;
  uint32_t device_id_{0};

  // Use the result of ToString() as the key to look up a DeviceContext
  // in the cache map that maintains created DeviceContext objects.
  std::string ToString() const { return device_name_ + "_" + std::to_string(device_id_); }
};

class DeviceResManager;
class GraphExecutor;
class KernelExecutor;

// DeviceContext is the unified interface for interacting with devices.
class DeviceContext {
 public:
  explicit DeviceContext(const DeviceContextKey &device_context_key)
      : device_context_key_(device_context_key), initialized_(false) {}
  virtual ~DeviceContext() = default;

  // Initialize the device context.
  virtual void Initialize() = 0;

  // Destroy the device context and release device resources.
  virtual void Destroy() {}

  // Analyze the function graph to check whether all of its nodes are supported. If yes, return true; if no,
  // return false and mark the unsupported nodes as "NotSupport" through SetCNodeNotSupported().
  // For further usage, each device can add the attribute kAttrGraphSplitGroup to a node and assign a
  // group_name (the type must be std::string, default is 'DefaultGroup') to that attribute, which means that
  // consecutive nodes with the same group_name will be split into one subgraph.
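  // As a hedged illustration (the helper call and group name below are assumptions, not a fixed API),
  // a backend might tag nodes like this before partitioning:
  //   common::AnfAlgo::SetNodeAttr(kAttrGraphSplitGroup, MakeValue(std::string("CommGroup")), node);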
  virtual bool PartitionGraph(const FuncGraphPtr &func_graph) const { return false; }

  // Analyze the function graph and select the appropriate run mode for it.
  virtual RunMode GetRunMode(const FuncGraphPtr &func_graph) const = 0;

  // Get device_context_key_ to obtain the device name and device id.
  const DeviceContextKey &device_context_key() const { return device_context_key_; }

  // Get the device address type according to the device type, such as GPU or Ascend.
  DeviceType GetDeviceType() const { return GetDeviceTypeByName(device_context_key_.device_name_); }

  // Get the kernel executor according to whether dynamic shape is involved.
  std::shared_ptr<KernelExecutor> GetKernelExecutor(bool is_dynamic_shape) const {
    if (is_dynamic_shape) {
      return dyn_kernel_executor_;
    } else {
      return kernel_executor_;
    }
  }

  void SetKernelExecutor(const std::shared_ptr<KernelExecutor> &kernel_executor) { kernel_executor_ = kernel_executor; }

  void SetDynKernelExecutor(const std::shared_ptr<KernelExecutor> &kernel_executor) {
    dyn_kernel_executor_ = kernel_executor;
  }

  // TODO: delete.
  virtual DeprecatedInterface *GetDeprecatedInterface() { return nullptr; }

  // Return whether this device context is initialized.
  bool initialized() const {
#ifdef __APPLE__
    std::lock_guard<SpinLock> spin_lock(init_lock_);
#else
    std::lock_guard<std::mutex> lock(init_mutex_);
#endif
    return initialized_;
  }

  DeviceContextKey device_context_key_;
  std::unique_ptr<DeviceResManager> device_res_manager_;
  std::unique_ptr<GraphExecutor> graph_executor_;

 protected:
#ifdef __APPLE__
  // There are some problems with using std::mutex on macOS, so use a spinlock instead.
  inline static SpinLock init_lock_;
#else
  inline static std::mutex init_mutex_;
#endif
  bool initialized_;

 private:
  std::shared_ptr<KernelExecutor> kernel_executor_;
  std::shared_ptr<KernelExecutor> dyn_kernel_executor_;
};
using DeviceContextPtr = std::shared_ptr<DeviceContext>;

class BACKEND_EXPORT DeviceResManager {
 public:
  DeviceResManager() : collective_comm_lib_(nullptr), device_context_(nullptr) {
    offloaded_mem_pool_ = std::make_shared<device::OffloadedMemPool>();
  }
  virtual ~DeviceResManager() = default;

  // Initialize the device resource manager.
  virtual void Initialize() {}

  // Destroy the device resource manager and release device resources.
  virtual void Destroy() {}

  // Bind the device to the current thread to gain device control privileges.
  // If force_bind is true, bind the context to the current thread every time;
  // otherwise, bind the context to the current thread only the first time.
  virtual bool BindDeviceToCurrentThread(bool force_bind) const { return true; }
  virtual void ResetStreamAndCtx() {}
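
  // A minimal usage sketch (assuming a worker thread that launches kernels; 'device_context' is hypothetical):
  //   if (!device_context->device_res_manager_->BindDeviceToCurrentThread(false)) {
  //     MS_LOG(ERROR) << "Failed to bind device to the current thread.";
  //   }
  // Pass force_bind = true only when another component may have reset this thread's context.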

  // Relevant functions to allocate and free device memory via raw pointers.
  virtual void *AllocateMemory(size_t size, uint32_t stream_id = kDefaultStreamIndex) const = 0;
  virtual void FreeMemory(void *ptr) const = 0;
  virtual void FreePartMemorys(const std::vector<void *> &free_addrs, const std::vector<void *> &keep_addrs,
                               const std::vector<size_t> &keep_addr_sizes) const = 0;
  virtual void DefragMemory() {}

  virtual void SwapIn(const void *host_ptr, void *device_ptr, size_t mem_size, void *stream) {
    MS_LOG(EXCEPTION) << "Unimplemented interface.";
  }
  virtual void SwapOut(const void *device_ptr, void *host_ptr, size_t mem_size, void *stream) {
    MS_LOG(EXCEPTION) << "Unimplemented interface.";
  }

  // Relevant functions to allocate and free the device memory of a DeviceAddress.
  virtual bool AllocateMemory(DeviceAddress *const &address, uint32_t stream_id = UINT32_MAX) const;
  virtual void FreeMemory(DeviceAddress *const &address) const;
  virtual size_t GetMaxUsedMemorySize() const { return 0; }

  // Relevant functions to manage memory statistics.
  virtual size_t GetTotalMemStatistics() const { return 0; }
  virtual size_t GetTotalUsedMemStatistics() const { return 0; }
  virtual size_t GetTotalIdleMemStatistics() const { return 0; }
  virtual size_t GetTotalEagerFreeMemStatistics() const { return 0; }
  virtual size_t GetUsedMemPeakStatistics() const { return 0; }
  virtual size_t GetReservedMemPeakStatistics() const { return 0; }
  virtual std::unordered_map<std::string, std::size_t> GetBlockCountsStatistics() const { return {}; }
  virtual std::unordered_map<std::string, std::size_t> GetBlockUnitSizeStatistics() const { return {}; }
  virtual std::unordered_map<device::DeviceMemPtr, std::unordered_map<std::string, size_t>>
  GetCommonMemBlocksInfoStatistics() const {
    return {};
  }
  virtual std::unordered_map<device::DeviceMemPtr, std::unordered_map<std::string, size_t>>
  GetPersistentMemBlocksInfoStatistics() const {
    return {};
  }
  virtual void ResetMaxMemoryReserved() const {}
  virtual void ResetMaxMemoryAllocated() const {}

  // Allocate host memory with RAII and reference counting.
  virtual std::shared_ptr<void> AllocateHostMemory(size_t size) const {
    return std::shared_ptr<void>(::malloc(size), ::free);
  }
  // Allocate host memory for offloaded device memory.
  virtual void *AllocateOffloadMemory(size_t size) const;
  // Release host memory that was allocated by AllocateOffloadMemory back to the pool.
  // It will not be freed to the OS.
  virtual void FreeOffloadMemory(void *ptr) const;

  virtual size_t GetAvailableMemSize() const { return 0; }
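
  // A minimal allocation sketch (error handling elided; 'res_manager' is a hypothetical DeviceResManager *):
  //   void *dev_ptr = res_manager->AllocateMemory(size);
  //   if (dev_ptr != nullptr) {
  //     /* launch kernels that read and write dev_ptr */
  //     res_manager->FreeMemory(dev_ptr);
  //   }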

  // Allocate continuous device memory according to the size list.
  // Communication operators may need continuous memory for their inputs and outputs
  // to optimize communication performance.
  virtual std::vector<void *> AllocateContinuousMemory(const std::vector<size_t> &size_list,
                                                       uint32_t stream_id = kDefaultStreamIndex) const {
    MS_LOG(EXCEPTION) << "Unimplemented interface.";
  }

  // Create a concrete device address according to the device type, using a KernelTensor.
  virtual DeviceAddressPtr CreateDeviceAddress(const KernelTensorPtr &kernel_tensor) const {
    MS_LOG(EXCEPTION) << "Unimplemented interface.";
  }

  virtual void MoveTo(const tensor::TensorPtr &src_tensor, const tensor::TensorPtr &dst_tensor, const std::string &to,
                      bool blocking, bool *return_self) {
    MS_LOG(EXCEPTION) << "Unimplemented interface.";
  }

  virtual DeviceAddressPtr CreateDeviceAddress(void *ptr, size_t size, const ShapeVector &shape_vector,
                                               const Format &format, TypeId type_id, const std::string &device_name,
                                               uint32_t device_id, uint32_t stream_id) const {
    MS_LOG(EXCEPTION) << "Unimplemented interface.";
  }

  // Create a stream and assign it a stream id; the assigned id is written to the parameter '*stream_id'.
  virtual bool CreateStream(size_t *stream_id) const {
    MS_LOG(WARNING) << "Unimplemented interface: 'CreateStream'.";
    *stream_id = kSizeZero;
    return false;
  }

  // Create a stream with the given priority.
  virtual bool CreateStreamWithPriority(size_t *stream_id, int32_t priority) const {
    *stream_id = kSizeZero;
    return false;
  }

  virtual size_t QueryStreamSize() const { return 0L; }
  virtual std::vector<uint32_t> GetStreamIds() const { return {}; }

  // If multiple streams are used in PyNative mode, the other streams must be synchronized before the graph
  // is executed; otherwise out-of-order execution occurs. This flag exists for that purpose.
  // It is a temporary solution and will be removed once multi-stream is supported in graph mode.
  virtual bool single_op_multi_stream_enable() const { return false; }
  virtual void set_single_op_multi_stream_enable(bool single_op_multi_stream_enable) {}

  // Get the stream pointer by stream_id.
  virtual void *GetStream(size_t stream_id) const { return nullptr; }

  // Set the currently used stream id.
  virtual void SetCurrentStreamId(size_t stream_id) {}

  // Get the currently used stream id.
  virtual size_t GetCurrentStreamId() const { return kSizeZero; }

  virtual void *GetStream() const { return nullptr; }

  // Destroy the stream bound to the input parameter 'stream_id'.
  virtual bool DestroyStream(size_t stream_id) const { return false; }
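
  // Expected stream lifecycle, sketched against a backend that actually implements these interfaces
  // ('res_manager' is hypothetical):
  //   size_t stream_id = 0;
  //   if (res_manager->CreateStream(&stream_id)) {
  //     void *stream = res_manager->GetStream(stream_id);
  //     /* launch kernels on 'stream' */
  //     res_manager->SyncStream(stream_id);
  //     res_manager->DestroyStream(stream_id);
  //   }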

  // Query the completion status of a stream's tasks.
  virtual bool QueryStream(size_t stream_id) const { return true; }

  // Synchronize streams. Devices such as GPU and Ascend need streams to launch kernels asynchronously:
  // use 'SyncStream' to block the calling thread until all tasks on a specific stream complete,
  // and 'SyncAllStreams' to block until all tasks on all streams complete.
  // Devices without streams may ignore the implementation of these functions.
  // Since the entry point for creating streams is not yet unified, the 'SyncStream' and
  // 'SyncAllStreams' interfaces are implemented by subclasses.
  virtual bool SyncStream(size_t stream_id) const { return true; }

  virtual bool SyncAllStreams() const { return true; }

  virtual bool SyncNotDefaultStreams() const { return true; }

  // Return the default stream id. Normally it is 0.
  virtual size_t DefaultStream() const { return 0; }

  // Create a device event for the runtime.
  virtual DeviceEventPtr CreateRuntimeEvent(bool enable_blocking, bool enable_record_wait) { return nullptr; }

  // Create a device event with flags.
  virtual DeviceEventPtr CreateEventWithFlag(bool enable_timing, bool blocking) { return nullptr; }

  // Destroy the specified device event.
  virtual bool DestroyEvent(const DeviceEventPtr &event);

  // Destroy all device events.
  virtual bool DestroyAllEvents();

  // Dynamically load the collective communication library.
  // Currently, four types are supported: OpenMPI and a self-developed framework for CPU, NCCL for GPU,
  // and HCCL for Ascend.
  virtual bool LoadCollectiveCommLib() { return true; }

  // Return the collective communication object for callers to access.
  CollectiveCommunicationLib *collective_comm_lib() const { return collective_comm_lib_; }

  std::shared_ptr<SwapManager> swap_manager() const { return swap_manager_; }

  std::shared_ptr<MemoryManager> mem_manager() const { return mem_manager_; }
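
  // Event lifecycle sketch (assuming the backend overrides the event interfaces; the record/wait calls
  // belong to the DeviceEvent interface, not to this class):
  //   DeviceEventPtr event = res_manager->CreateEventWithFlag(/* enable_timing = */ false, /* blocking = */ true);
  //   /* record the event on one stream and wait for it on another via the DeviceEvent interface */
  //   res_manager->DestroyEvent(event);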

 protected:
  // Ensure thread safety when allocating device memory.
  mutable std::mutex alloc_mem_mutex_;

  // The collective communication library.
  CollectiveCommunicationLib *collective_comm_lib_;

  DeviceContext *device_context_{nullptr};

  std::shared_ptr<SwapManager> swap_manager_{nullptr};

  std::mutex device_events_mutex_;

  DeviceEventPtrList device_events_{};

  std::shared_ptr<MemoryManager> mem_manager_{nullptr};

 private:
  template <class... Args>
  friend class DeviceInterface;
  void SetDeviceContext(DeviceContext *device_context) { device_context_ = device_context; }
  std::shared_ptr<device::OffloadedMemPool> offloaded_mem_pool_;
};

class GraphExecutor {
 public:
  virtual ~GraphExecutor() = default;
  virtual bool CompileGraph(const FuncGraphPtr &graph, const std::map<string, string> &compile_options) {
    return true;
  }
  virtual bool RunGraph(const FuncGraphPtr &graph, const std::vector<tensor::Tensor> &inputs,
                        std::vector<tensor::Tensor> *outputs, const std::map<string, string> &compile_options) {
    MS_LOG(EXCEPTION) << "Unimplemented interface.";
  }
  virtual std::string GetRandomStatus(const std::vector<FuncGraphPtr> &graphs) { return ""; }
  virtual size_t GetGraphFeatureMemory(const FuncGraphPtr &graph) const { return 0; }
  virtual void InitGraphInfo(const FuncGraphPtr &graph) {}

 protected:
  DeviceContext *device_context_{nullptr};

 private:
  template <class... Args>
  friend class DeviceInterface;

  void SetDeviceContext(DeviceContext *device_context) { device_context_ = device_context; }
};

using CallbackFunc = std::function<void(void)>;

class BACKEND_EXPORT KernelExecutor {
 public:
  virtual ~KernelExecutor() = default;

  virtual void Initialize() {}
  virtual void Destroy() {}

  // Optimize the kernel graph for graph mode.
  virtual void OptimizeGraph(const FuncGraphPtr &graph) const {}

  // Generate a 'KernelMod' for each kernel and set it into the kernel node;
  // 'KernelMod' is the real executive object of the kernel.
  virtual void CreateKernel(const std::vector<CNodePtr> &nodes) const {}
  virtual kernel::KernelModPtr CreateKernelMod(const std::string &op_name) const { MS_LOG(EXCEPTION) << "Unrealized"; }

  // Adjust the kernel graph before running it.
  virtual void PreprocessBeforeRun(const FuncGraphPtr &graph) const {}

  // Launch a kernel via its 'KernelMod', using the KernelTensor input type.
  virtual bool LaunchKernel(const CNodePtr &kernel, const std::vector<KernelTensor *> &inputs,
                            const std::vector<KernelTensor *> &workspace, const std::vector<KernelTensor *> &outputs,
                            KernelMod *kernel_mod, void *stream) const {
    MS_LOG(EXCEPTION) << "Unimplemented interface.";
  }
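
  // Launch-flow sketch (hypothetical caller code; the KernelMod and the argument vectors are assumed to be
  // prepared by the runtime after CreateKernel()):
  //   if (!executor->LaunchKernel(kernel, inputs, workspace, outputs, kernel_mod, stream)) {
  //     MS_LOG(ERROR) << "Launch kernel failed: " << kernel->fullname_with_scope();
  //   }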

  // Launch a callback; the default implementation simply invokes it.
  virtual bool LaunchCallback(std::function<void(void)> callback_func, size_t stream_id) const {
    callback_func();
    return true;
  }
  // Unify the MindIR; the default behavior uses the common unified MindIR.
  virtual void UnifyMindIR(const KernelGraphPtr &graph) const;
  virtual void AddMindIRPass(const KernelGraphPtr &graph) const {}

  // Get the rank id for distributed training.
  virtual uint32_t GetRankID() const { return 0; }

  void SetDeviceContext(DeviceContext *device_context) { device_context_ = device_context; }

  virtual bool ExecuteKernelTask(const runtime::KernelTaskType &task_type,
                                 const device::DeviceAddressPtrList &input_addr_list,
                                 const device::DeviceAddressPtrList &output_addr_list,
                                 const size_t &stream_id) const {
    return false;
  }

 protected:
  DeviceContext *device_context_{nullptr};
};

template <class... Args>
class DeviceInterface : public DeviceContext {};

template <>
class DeviceInterface<> : public DeviceContext {
 public:
  explicit DeviceInterface(const DeviceContextKey &key) : DeviceContext(key) {}

 protected:
  // Raise an exception if 'ptr' is already set, to reject duplicate component registration.
  void CheckUnset(const void *ptr, const std::string &error_msg) const {
    if (ptr != nullptr) {
      MS_LOG(EXCEPTION) << error_msg;
    }
  }
};
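
// How a backend composes its components through DeviceInterface, as a sketch with hypothetical class names:
//   class MyResManager : public DeviceResManager { /* ... */ };
//   class MyKernelExecutor : public KernelExecutor { /* ... */ };
//   class MyDeviceContext : public DeviceInterface<MyResManager, MyKernelExecutor> {
//    public:
//     explicit MyDeviceContext(const DeviceContextKey &key) : DeviceInterface(key) {}
//   };
// Each level of the recursion below registers exactly one component type and rejects duplicate
// registration via CheckUnset().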

template <class T, class... Args>
class DeviceInterface<T, Args...> : public DeviceInterface<Args...> {
 public:
  explicit DeviceInterface(const DeviceContextKey &key) : DeviceInterface<Args...>(key) {
    if constexpr (std::is_base_of_v<DeviceResManager, T>) {
      DeviceInterface::CheckUnset(reinterpret_cast<void *>(DeviceContext::device_res_manager_.get()),
                                  "DeviceResManager has been registered!");
      DeviceContext::device_res_manager_ = std::make_unique<T>();
      DeviceContext::device_res_manager_->SetDeviceContext(this);
    } else if constexpr (std::is_base_of_v<GraphExecutor, T>) {
      DeviceInterface::CheckUnset(reinterpret_cast<void *>(DeviceContext::graph_executor_.get()),
                                  "GraphExecutor has been registered!");
      DeviceContext::graph_executor_ = std::make_unique<T>();
      DeviceContext::graph_executor_->SetDeviceContext(this);
    } else if constexpr (std::is_base_of_v<KernelExecutor, T>) {
      DeviceInterface::CheckUnset(reinterpret_cast<void *>(DeviceContext::GetKernelExecutor(false).get()),
                                  "KernelExecutor has been registered!");
      DeviceInterface::CheckUnset(reinterpret_cast<void *>(DeviceContext::GetKernelExecutor(true).get()),
                                  "Dyn KernelExecutor has been registered!");
      DeviceContext::SetKernelExecutor(std::make_shared<T>());
      DeviceContext::GetKernelExecutor(false)->SetDeviceContext(this);
      // For GPU/CPU, the dynamic-shape kernel executor defaults to the static one.
      DeviceContext::SetDynKernelExecutor(DeviceContext::GetKernelExecutor(false));
    }
  }

 private:
  template <typename = std::enable_if_t<std::is_base_of_v<DeviceResManager, T> ||
                                        std::is_base_of_v<GraphExecutor, T> || std::is_base_of_v<KernelExecutor, T>>>
  void Assert() const {}
};
}  // namespace device
}  // namespace mindspore
#endif  // MINDSPORE_CCSRC_RUNTIME_HARDWARE_DEVICE_CONTEXT_H_