• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2019-2021 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_DEVICE_ADDRESS_H_
18 #define MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_DEVICE_ADDRESS_H_
19 
20 #include <string>
21 #include <vector>
22 #include <memory>
23 #include "runtime/device/device_address.h"
24 #include "runtime/device/ascend/ascend_memory_pool.h"
25 #include "ir/dtype.h"
26 #include "backend/kernel_compiler/kernel.h"
27 #include "utils/shape_utils.h"
28 
29 namespace mindspore {
30 #ifdef ENABLE_DEBUGGER
31 class Debugger;
32 #endif
33 namespace device {
34 class LaunchKernel;
35 namespace ascend {
36 class AscendDeviceAddress : public DeviceAddress {
37  public:
AscendDeviceAddress(void * ptr,size_t size)38   explicit AscendDeviceAddress(void *ptr, size_t size) : DeviceAddress(ptr, size) {}
AscendDeviceAddress(void * ptr,size_t size,const std::string & format,TypeId type_id)39   explicit AscendDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id)
40       : DeviceAddress(ptr, size, format, type_id) {}
AscendDeviceAddress(void * ptr,size_t size,const std::string & format,TypeId type_id,const KernelWithIndex & node_index)41   explicit AscendDeviceAddress(void *ptr, size_t size, const std::string &format, TypeId type_id,
42                                const KernelWithIndex &node_index)
43       : DeviceAddress(ptr, size, format, type_id, node_index) {}
44   ~AscendDeviceAddress() override;
45   bool SyncDeviceToHost(size_t size, void *const host_ptr) const override;
46   bool SyncHostToDevice(size_t size, const void *host_ptr) const override;
47   bool SyncDeviceToHost(const ShapeVector &shape, size_t size, TypeId type, void *host_ptr) const override;
48   bool SyncHostToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *host_ptr,
49                         const std::string &format = "DefaultFormat") const override;
50   bool SyncDeviceToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *src_ptr,
51                           const std::string &format) const override;
52   void ClearDeviceMemory() override;
DeviceType()53   DeviceAddressType DeviceType() const override { return DeviceAddressType::kAscend; }
54   bool DumpMemToFile(const std::string &filepath, const std::string &host_fmt, const ShapeVector &host_shape,
55                      TypeId host_type, bool trans_flag) const override;
56 #ifdef ENABLE_DEBUGGER
57   bool LoadMemToHost(const std::string &tensor_name, int execution_order, const std::string &host_fmt,
58                      const ShapeVector &host_shape, TypeId host_type, size_t slot, bool keep_prev) const override;
59 #endif
60 
61  private:
62   bool SyncDeviceToHostAndConvertFormat(const ShapeVector &shape, size_t size, TypeId type, void *host_ptr) const;
63   bool ConvertFormatAndSyncHostToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *host_ptr) const;
64   bool SyncDeviceToHostAndConvertFormatBasedOnTransData(const std::vector<size_t> &host_shape, size_t size,
65                                                         mindspore::TypeId type, void *host_ptr) const;
66   void SyncStream() const;
67   std::vector<size_t> GetDeviceShape(std::vector<size_t> *host_shape) const;
68   std::shared_ptr<LaunchKernel> CreateLaunchTransData(const std::vector<size_t> &host_shape,
69                                                       const std::string &ori_format,
70                                                       const std::string &dst_format) const;
71   mutable std::shared_ptr<LaunchKernel> launch_transdata_{nullptr};
72 };
73 using AscendDeviceAddressPtr = std::shared_ptr<AscendDeviceAddress>;
74 }  // namespace ascend
75 }  // namespace device
76 }  // namespace mindspore
77 #endif  // MINDSPORE_CCSRC_RUNTIME_DEVICE_ASCEND_ASCEND_DEVICE_ADDRESS_H_
78