1 /** 2 * Copyright 2023 Huawei Technologies Co., Ltd 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_HAL_DEVICE_MBUF_RECEIVE_MANAGER_H_ 18 #define MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_HAL_DEVICE_MBUF_RECEIVE_MANAGER_H_ 19 20 #include <atomic> 21 #include <condition_variable> 22 #include <cstdint> 23 #include <functional> 24 #include <future> 25 #include <map> 26 #include <memory> 27 #include <mutex> 28 #include <sstream> 29 #include <string> 30 #include <thread> 31 #include <utility> 32 #include <vector> 33 #include <variant> 34 #include "ir/tensor.h" 35 #include "transform/symbol/acl_tdt_symbol.h" 36 #include "transform/symbol/symbol_utils.h" 37 38 #ifndef SECUREC_MEM_MAX_LEN 39 #define SECUREC_MEM_MAX_LEN 0x7fffffffUL 40 #endif 41 42 namespace mindspore::device::ascend { 43 44 class ScopeAclTdtDataset; 45 46 using MbufFuncType = std::function<void(ScopeAclTdtDataset &)>; 47 48 const std::map<aclDataType, TypeId> kAclDataTypeMap = { 49 {ACL_INT8, TypeId::kNumberTypeInt8}, {ACL_UINT8, TypeId::kNumberTypeUInt8}, 50 {ACL_INT16, TypeId::kNumberTypeInt16}, {ACL_UINT16, TypeId::kNumberTypeUInt16}, 51 {ACL_INT32, TypeId::kNumberTypeInt32}, {ACL_UINT32, TypeId::kNumberTypeUInt32}, 52 {ACL_INT64, TypeId::kNumberTypeInt64}, {ACL_UINT64, TypeId::kNumberTypeUInt64}, 53 {ACL_FLOAT16, TypeId::kNumberTypeFloat16}, {ACL_FLOAT, TypeId::kNumberTypeFloat32}, 54 {ACL_DOUBLE, TypeId::kNumberTypeFloat64}, {ACL_BOOL, TypeId::kNumberTypeBool}}; 55 56 struct SlicedTensor { SlicedTensorSlicedTensor57 SlicedTensor(size_t slice_num, aclDataType type, const ShapeVector &shape) 58 : slice_id_(0), slice_num_(slice_num), data_type_(type), tensor_shape_(shape) {} 59 SlicedTensor(const SlicedTensor &) = delete; 60 SlicedTensor &operator=(const SlicedTensor &) = delete; 61 ~SlicedTensor() = default; 62 63 // the id of current slice of tensor 64 size_t slice_id_{0}; 65 // the number of total slices of tensor 66 size_t slice_num_{0}; 67 // tensor's data type and shape 68 aclDataType data_type_; 69 ShapeVector tensor_shape_; 70 // buffer for storing contents of sliced tensor 71 std::ostringstream buffer_; 72 }; 73 74 using DataItem = std::variant<std::string, mindspore::tensor::TensorPtr>; 75 76 class ScopeAclTdtDataset { 77 public: ScopeAclTdtDataset()78 ScopeAclTdtDataset() { 79 acl_dataset_ = CALL_ASCEND_API(acltdtCreateDataset); 80 Reset(); 81 } Get()82 acltdtDataset *Get() const { return acl_dataset_; } ~ScopeAclTdtDataset()83 ~ScopeAclTdtDataset() { 84 if (acl_dataset_ != nullptr && CALL_ASCEND_API(acltdtDestroyDataset, acl_dataset_) != ACL_SUCCESS) { 85 MS_LOG(ERROR) << "AcltdtDestroyDataset failed."; 86 } else { 87 MS_LOG(INFO) << "AcltdtDestroyDataset succeed."; 88 } 89 } 90 Reset()91 void Reset() { 92 sliced_tensor_ = nullptr; 93 sliced_string_ = nullptr; 94 dataset_name_ = ""; 95 tensor_type_ = ACL_TENSOR_DATA_UNDEFINED; 96 data_items_.clear(); 97 } 98 GetDataItems()99 const std::vector<DataItem> &GetDataItems() const { return data_items_; } 100 GetDatasetName()101 const std::string &GetDatasetName() const { return dataset_name_; } 102 103 // process full tensor(i.e. the content of tensor is in only one acltdtDataItem) 104 // return true when success, otherwise false 105 bool ProcessFullTensor(acltdtDataItem *item); 106 107 // process sliced tensor(i.e. the content of tensor spans multiple acltdtDataItems) 108 // return true when success, otherwise false 109 bool ProcessSliceTensor(acltdtDataItem *item); 110 111 // call this function when received last piece of slice tensor, return true when success, otherwise false 112 bool FinishSliceTensor(); 113 114 // return true when encounter the end of OutfeedEnqueueOpV2's output, otherwise false 115 bool ProcessDataset(acltdtDataset *acl_dataset); 116 117 // set and check consistency of tensor types of data items, return true when success, otherwise false 118 bool CheckAndSetTensorType(acltdtTensorType tensor_type); 119 120 private: 121 // acl tdt dataset for receiving data, created once, used many times 122 acltdtDataset *acl_dataset_{nullptr}; 123 124 // structure for connecting tensor slices to a full tensor 125 std::shared_ptr<SlicedTensor> sliced_tensor_{nullptr}; 126 // structure for connecting string slices to a full string 127 std::shared_ptr<std::ostringstream> sliced_string_{nullptr}; 128 129 // ONLY the FIRST dataset containing the dataset name when the outputs of OutfeedEnqueueOpV2 span multiple datasets 130 std::string dataset_name_; 131 // NOTE: the data items of output of one OutfeedEnqueueOpV2 must be all with type ACL_TENSOR_DATA_TENSOR, or all with 132 // type ACL_TENSOR_DATA_SLICE_TENSOR(ACL_TENSOR_DATA_END_TENSOR is also indicating type ACL_TENSOR_DATA_SLICE_TENSOR) 133 acltdtTensorType tensor_type_{ACL_TENSOR_DATA_UNDEFINED}; 134 // vector for buffering outputs of OutfeedEnqueueOpV2 at a time 135 std::vector<DataItem> data_items_; 136 }; 137 138 class MbufDataHandler { 139 public: 140 MbufDataHandler(MbufFuncType func, uint32_t device_id, string channel_name, string op_name = "", 141 size_t capacity = 128, int32_t timeout = 800); 142 ~MbufDataHandler(); GetChannelName()143 string GetChannelName() { return channel_name_; } GetDeviceId()144 uint32_t GetDeviceId() { return device_id_; } GetCapacity()145 size_t GetCapacity() { return capacity_; } StopReceive()146 void StopReceive() { stop_receive_.store(true, std::memory_order_acq_rel); } 147 148 private: 149 MbufFuncType func_; 150 uint32_t device_id_; 151 std::string channel_name_; 152 std::string prim_name_; 153 size_t capacity_; 154 int32_t timeout_; 155 std::mutex mutex_; 156 std::atomic_bool stop_receive_{false}; 157 std::thread thread_; 158 acltdtChannelHandle *acl_handle_; 159 160 void HandleData(); 161 bool ReceiveAndProcessData(ScopeAclTdtDataset *dataset); 162 bool QueryChannelSize(size_t *queue_size); 163 }; 164 165 class MbufDataHandlerManager { 166 public: GetInstance()167 static MbufDataHandlerManager &GetInstance() { 168 static MbufDataHandlerManager instance; 169 return instance; 170 } 171 ~MbufDataHandlerManager() = default; 172 MbufDataHandlerManager(const MbufDataHandlerManager &) = delete; 173 MbufDataHandlerManager &operator=(const MbufDataHandlerManager &) = delete; 174 AddHandler(std::unique_ptr<MbufDataHandler> handler)175 void AddHandler(std::unique_ptr<MbufDataHandler> handler) { handles_.push_back(std::move(handler)); } 176 DestoryPrintHandler()177 void DestoryPrintHandler() { 178 for (auto iter = handles_.begin(); iter != handles_.end(); iter++) { 179 if ((*iter)->GetChannelName() == kChannelNameNpuLog) { 180 (*iter)->StopReceive(); 181 handles_.erase(iter); 182 break; 183 } 184 } 185 } 186 DestoryHandler()187 void DestoryHandler() { 188 for (auto &handle : handles_) { 189 handle->StopReceive(); 190 } 191 while (!handles_.empty()) { 192 MS_LOG(INFO) << "The thread of " << handles_.back()->GetChannelName() << " channel is being destroyed."; 193 handles_.pop_back(); 194 } 195 } 196 197 private: 198 MbufDataHandlerManager() = default; 199 std::vector<std::unique_ptr<MbufDataHandler>> handles_; 200 }; 201 } // namespace mindspore::device::ascend 202 #endif // MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_HAL_DEVICE_TENSORDUMP_UTILS_H_ 203