1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_DEVICE_TENSOR_ASCEND910B_H_ 18 #define MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_DEVICE_TENSOR_ASCEND910B_H_ 19 20 #include <memory> 21 #include <utility> 22 #include <vector> 23 24 #include "include/api/status.h" 25 #include "minddata/dataset/include/dataset/constants.h" 26 #include "minddata/dataset/core/data_type.h" 27 #include "minddata/dataset/core/tensor.h" 28 #include "minddata/dataset/util/status.h" 29 #include "runtime/hardware/device_context.h" 30 #include "runtime/hardware/device_context_manager.h" 31 32 namespace mindspore { 33 namespace dataset { 34 class Tensor; 35 class DATASET_API DeviceTensorAscend910B { 36 public: 37 DeviceTensorAscend910B(const TensorShape &shape, const DataType &type, device::DeviceContext *device_context, 38 const size_t &stream_id, bool is_hwc = true); 39 40 // create device_tensor by empty 41 static Status CreateDeviceTensor(const TensorShape &shape, const DataType &type, 42 device::DeviceContext *device_context, const size_t &stream_id, 43 std::shared_ptr<DeviceTensorAscend910B> *out, bool is_hwc = true, 44 std::vector<int> channels = {1, 3}); 45 46 // create device_tensor by host tensor 47 static Status CreateDeviceTensor(std::shared_ptr<Tensor> tensor, device::DeviceContext *device_context, 48 const size_t &stream_id, std::shared_ptr<DeviceTensorAscend910B> *out, 49 bool is_hwc = true, std::vector<int> channels = {1, 3}); 50 51 ~DeviceTensorAscend910B(); 52 GetDeviceContext()53 device::DeviceContext *GetDeviceContext() { return device_context_; } 54 GetStreamID()55 size_t GetStreamID() { return stream_id_; } 56 SetDeviceAddress(void * device_address)57 void SetDeviceAddress(void *device_address) { device_address_ = device_address; } 58 GetDeviceAddress()59 void *GetDeviceAddress() { return device_address_; } 60 SetDeviceTensor(void * tensor)61 void SetDeviceTensor(void *tensor) { tensor_ = tensor; } 62 GetShape()63 TensorShape &GetShape() { return tensor_shape_; } 64 GetType()65 DataType GetType() { return data_type_; } 66 GetDeviceTensor()67 void *GetDeviceTensor() { return tensor_; } 68 69 Status ToHostTensor(std::shared_ptr<Tensor> *host_tensor); 70 71 bool AddWorkSpace(void *workspace); 72 73 bool AddMaintenFloatArrayMemory(void *float_array); 74 75 bool AddMaintenIntArrayMemory(void *int_array); 76 77 bool ReleaseDeviceMemory(); 78 79 private: 80 // Ascend910B resource 81 device::DeviceContext *device_context_; 82 size_t stream_id_; 83 void *device_address_; 84 void *tensor_; // aclTensor which point to device_address_ 85 void *workspace_; // used by step1 with dvpp HostAPI 86 std::vector<void *> float_arrays_; // used by dvpp in execution 87 std::vector<void *> int_arrays_; // used by dvpp in execution 88 TensorShape tensor_shape_; 89 DataType data_type_; 90 bool is_hwc_; 91 }; 92 93 } // namespace dataset 94 } // namespace mindspore 95 #endif // MINDSPORE_CCSRC_MINDDATA_DATASET_CORE_DEVICE_TENSOR_ASCEND910B_H_ 96