1 /** 2 * Copyright 2019-2021 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_DEVICE_ADDRESS_H_ 18 #define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_DEVICE_ADDRESS_H_ 19 20 #include <string> 21 #include <vector> 22 #include <memory> 23 #include "runtime/device/device_address.h" 24 #include "runtime/device/ascend/ascend_memory_pool.h" 25 #include "ir/dtype.h" 26 #include "backend/kernel_compiler/kernel.h" 27 #include "utils/shape_utils.h" 28 29 namespace mindspore { 30 #ifdef ENABLE_DEBUGGER 31 class Debugger; 32 #endif 33 namespace device { 34 class LaunchKernel; 35 namespace ascend { 36 class AscendDeviceAddress : public DeviceAddress { 37 public: AscendDeviceAddress(void * ptr,size_t size)38 explicit AscendDeviceAddress(void *ptr, size_t size) : DeviceAddress(ptr, size) {} AscendDeviceAddress(void * ptr,size_t size,const std::string & format,TypeId type_id)39 explicit AscendDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id) 40 : DeviceAddress(ptr, size, format, type_id) {} AscendDeviceAddress(void * ptr,size_t size,const std::string & format,TypeId type_id,const KernelWithIndex & node_index)41 explicit AscendDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id, 42 const KernelWithIndex &node_index) 43 : DeviceAddress(ptr, size, format, type_id, node_index) {} 44 ~AscendDeviceAddress() override; 45 bool SyncDeviceToHost(size_t size, void *const host_ptr) const override; 46 bool SyncHostToDevice(size_t size, const void *host_ptr) const override; 47 bool SyncDeviceToHost(const ShapeVector &shape, size_t size, TypeId type, void *host_ptr) const override; 48 bool SyncHostToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *host_ptr, 49 const std::string &format = "DefaultFormat") const override; 50 bool SyncDeviceToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *src_ptr, 51 const std::string &format) const override; 52 void ClearDeviceMemory() override; DeviceType()53 DeviceAddressType DeviceType() const override { return DeviceAddressType::kAscend; } 54 bool DumpMemToFile(const std::string &filepath, const std::string &host_fmt, const ShapeVector &host_shape, 55 TypeId host_type, bool trans_flag) const override; 56 #ifdef ENABLE_DEBUGGER 57 bool LoadMemToHost(const std::string &tensor_name, int execution_order, const std::string &host_fmt, 58 const ShapeVector &host_shape, TypeId host_type, size_t slot, bool keep_prev) const override; 59 #endif 60 61 private: 62 bool SyncDeviceToHostAndConvertFormat(const ShapeVector &shape, size_t size, TypeId type, void *host_ptr) const; 63 bool ConvertFormatAndSyncHostToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *host_ptr) const; 64 bool SyncDeviceToHostAndConvertFormatBasedOnTransData(const std::vector<size_t> &host_shape, size_t size, 65 mindspore::TypeId type, void *host_ptr) const; 66 void SyncStream() const; 67 std::vector<size_t> GetDeviceShape(std::vector<size_t> *host_shape) const; 68 std::shared_ptr<LaunchKernel> CreateLaunchTransData(const std::vector<size_t> &host_shape, 69 const std::string &ori_format, 70 const std::string &dst_format) const; 71 mutable std::shared_ptr<LaunchKernel> launch_transdata_{nullptr}; 72 }; 73 using AscendDeviceAddressPtr = std::shared_ptr<AscendDeviceAddress>; 74 } // namespace ascend 75 } // namespace device 76 } // namespace mindspore 77 #endif // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_DEVICE_ADDRESS_H_ 78