• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2020-2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CORE_IR_DEVICE_SYNC_H_
18 #define MINDSPORE_CORE_IR_DEVICE_SYNC_H_
19 
20 #include <vector>
21 #include <memory>
22 #include <string>
23 
24 #include "ir/dtype/type.h"
25 #include "utils/shape_utils.h"
26 #include "ir/tensor_storage_info.h"
27 #include "ir/tensor_data.h"
28 
29 using std::string;
30 
31 namespace mindspore {
32 // Interface for data synchornize between device and host.
33 class DeviceSync {
34  public:
35   // Used to sync data between different device addresses, only need the data size and data ptr. The CPU device doesn't
36   // need use the interfaces, so need the default implementation.
SyncDeviceToHost(size_t,void *)37   virtual bool SyncDeviceToHost(size_t, void *) const { return true; }
SyncHostToDevice(size_t,const void *)38   virtual bool SyncHostToDevice(size_t, const void *) const { return true; }
39 
40   // Used to sync data between host tensor and device address, additional need the data shape and data type.
41   virtual bool SyncDeviceToHost(const ShapeVector &shape, size_t size, TypeId type, void *host_ptr) const = 0;
42   virtual bool SyncHostToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *host_ptr,
43                                 const std::string &format) const = 0;
SyncHostToDevice(const ShapeVector & shape,size_t size,TypeId type,const void * host_ptr)44   virtual bool SyncHostToDevice(const ShapeVector &shape, size_t size, TypeId type, const void *host_ptr) const {
45     return SyncHostToDevice(shape, size, type, host_ptr, "DefaultFormat");
46   }
47 
SyncHostToDevice(const ShapeVector & shape,size_t size,TypeId type,const std::string & format,const tensor::TensorDataPtr & tensor_data)48   virtual bool SyncHostToDevice(const ShapeVector &shape, size_t size, TypeId type, const std::string &format,
49                                 const tensor::TensorDataPtr &tensor_data) const {
50     MS_EXCEPTION_IF_NULL(tensor_data);
51     return SyncHostToDevice(shape, size, type, tensor_data->data(), format);
52   }
53 
54   virtual void *GetMutablePtr() const = 0;
55   virtual void ClearDeviceMemory() = 0;
56   virtual const TensorStorageInfoPtr GetTensorStorageInfo() const = 0;
57 
58   // The related interface of reference count operation.
59   virtual void set_original_ref_count(size_t original_ref_count) const = 0;
60   virtual size_t original_ref_count() const = 0;
61   virtual void set_ref_count(size_t ref_count) const = 0;
62   virtual size_t ref_count() const = 0;
63   virtual void ResetRefCount() = 0;
64 
~DeviceSync()65   virtual ~DeviceSync() {}
66 
user_data()67   virtual const UserDataPtr &user_data() const { MS_LOG(EXCEPTION) << "Not implement exception"; }
set_user_data(const UserDataPtr & user_data)68   virtual void set_user_data(const UserDataPtr &user_data) { MS_LOG(EXCEPTION) << "Not implement exception"; }
69 };
70 using DeviceSyncPtr = std::shared_ptr<DeviceSync>;
71 }  // namespace mindspore
72 #endif  // MINDSPORE_CORE_IR_DEVICE_SYNC_H_
73