• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2023 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_HAL_DEVICE_MBUF_RECEIVE_MANAGER_H_
18 #define MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_HAL_DEVICE_MBUF_RECEIVE_MANAGER_H_
19 
20 #include <atomic>
21 #include <condition_variable>
22 #include <cstdint>
23 #include <functional>
24 #include <future>
25 #include <map>
26 #include <memory>
27 #include <mutex>
28 #include <sstream>
29 #include <string>
30 #include <thread>
31 #include <utility>
32 #include <vector>
33 #include <variant>
34 #include "ir/tensor.h"
35 #include "transform/symbol/acl_tdt_symbol.h"
36 #include "transform/symbol/symbol_utils.h"
37 
38 #ifndef SECUREC_MEM_MAX_LEN
39 #define SECUREC_MEM_MAX_LEN 0x7fffffffUL
40 #endif
41 
42 namespace mindspore::device::ascend {
43 
44 class ScopeAclTdtDataset;
45 
46 using MbufFuncType = std::function<void(ScopeAclTdtDataset &)>;
47 
48 const std::map<aclDataType, TypeId> kAclDataTypeMap = {
49   {ACL_INT8, TypeId::kNumberTypeInt8},       {ACL_UINT8, TypeId::kNumberTypeUInt8},
50   {ACL_INT16, TypeId::kNumberTypeInt16},     {ACL_UINT16, TypeId::kNumberTypeUInt16},
51   {ACL_INT32, TypeId::kNumberTypeInt32},     {ACL_UINT32, TypeId::kNumberTypeUInt32},
52   {ACL_INT64, TypeId::kNumberTypeInt64},     {ACL_UINT64, TypeId::kNumberTypeUInt64},
53   {ACL_FLOAT16, TypeId::kNumberTypeFloat16}, {ACL_FLOAT, TypeId::kNumberTypeFloat32},
54   {ACL_DOUBLE, TypeId::kNumberTypeFloat64},  {ACL_BOOL, TypeId::kNumberTypeBool}};
55 
56 struct SlicedTensor {
SlicedTensorSlicedTensor57   SlicedTensor(size_t slice_num, aclDataType type, const ShapeVector &shape)
58       : slice_id_(0), slice_num_(slice_num), data_type_(type), tensor_shape_(shape) {}
59   SlicedTensor(const SlicedTensor &) = delete;
60   SlicedTensor &operator=(const SlicedTensor &) = delete;
61   ~SlicedTensor() = default;
62 
63   // the id of current slice of tensor
64   size_t slice_id_{0};
65   // the number of total slices of tensor
66   size_t slice_num_{0};
67   // tensor's data type and shape
68   aclDataType data_type_;
69   ShapeVector tensor_shape_;
70   // buffer for storing contents of sliced tensor
71   std::ostringstream buffer_;
72 };
73 
74 using DataItem = std::variant<std::string, mindspore::tensor::TensorPtr>;
75 
76 class ScopeAclTdtDataset {
77  public:
ScopeAclTdtDataset()78   ScopeAclTdtDataset() {
79     acl_dataset_ = CALL_ASCEND_API(acltdtCreateDataset);
80     Reset();
81   }
Get()82   acltdtDataset *Get() const { return acl_dataset_; }
~ScopeAclTdtDataset()83   ~ScopeAclTdtDataset() {
84     if (acl_dataset_ != nullptr && CALL_ASCEND_API(acltdtDestroyDataset, acl_dataset_) != ACL_SUCCESS) {
85       MS_LOG(ERROR) << "AcltdtDestroyDataset failed.";
86     } else {
87       MS_LOG(INFO) << "AcltdtDestroyDataset succeed.";
88     }
89   }
90 
Reset()91   void Reset() {
92     sliced_tensor_ = nullptr;
93     sliced_string_ = nullptr;
94     dataset_name_ = "";
95     tensor_type_ = ACL_TENSOR_DATA_UNDEFINED;
96     data_items_.clear();
97   }
98 
GetDataItems()99   const std::vector<DataItem> &GetDataItems() const { return data_items_; }
100 
GetDatasetName()101   const std::string &GetDatasetName() const { return dataset_name_; }
102 
103   // process full tensor(i.e. the content of tensor is in only one acltdtDataItem)
104   // return true when success, otherwise false
105   bool ProcessFullTensor(acltdtDataItem *item);
106 
107   // process sliced tensor(i.e. the content of tensor spans multiple acltdtDataItems)
108   // return true when success, otherwise false
109   bool ProcessSliceTensor(acltdtDataItem *item);
110 
111   // call this function when received last piece of slice tensor, return true when success, otherwise false
112   bool FinishSliceTensor();
113 
114   // return true when encounter the end of OutfeedEnqueueOpV2's output, otherwise false
115   bool ProcessDataset(acltdtDataset *acl_dataset);
116 
117   // set and check consistency of tensor types of data items, return true when success, otherwise false
118   bool CheckAndSetTensorType(acltdtTensorType tensor_type);
119 
120  private:
121   // acl tdt dataset for receiving data, created once, used many times
122   acltdtDataset *acl_dataset_{nullptr};
123 
124   // structure for connecting tensor slices to a full tensor
125   std::shared_ptr<SlicedTensor> sliced_tensor_{nullptr};
126   // structure for connecting string slices to a full string
127   std::shared_ptr<std::ostringstream> sliced_string_{nullptr};
128 
129   // ONLY the FIRST dataset containing the dataset name when the outputs of OutfeedEnqueueOpV2 span multiple datasets
130   std::string dataset_name_;
131   // NOTE: the data items of output of one OutfeedEnqueueOpV2 must be all with type ACL_TENSOR_DATA_TENSOR, or all with
132   // type ACL_TENSOR_DATA_SLICE_TENSOR(ACL_TENSOR_DATA_END_TENSOR is also indicating type ACL_TENSOR_DATA_SLICE_TENSOR)
133   acltdtTensorType tensor_type_{ACL_TENSOR_DATA_UNDEFINED};
134   // vector for buffering outputs of OutfeedEnqueueOpV2 at a time
135   std::vector<DataItem> data_items_;
136 };
137 
138 class MbufDataHandler {
139  public:
140   MbufDataHandler(MbufFuncType func, uint32_t device_id, string channel_name, string op_name = "",
141                   size_t capacity = 128, int32_t timeout = 800);
142   ~MbufDataHandler();
GetChannelName()143   string GetChannelName() { return channel_name_; }
GetDeviceId()144   uint32_t GetDeviceId() { return device_id_; }
GetCapacity()145   size_t GetCapacity() { return capacity_; }
StopReceive()146   void StopReceive() { stop_receive_.store(true, std::memory_order_acq_rel); }
147 
148  private:
149   MbufFuncType func_;
150   uint32_t device_id_;
151   std::string channel_name_;
152   std::string prim_name_;
153   size_t capacity_;
154   int32_t timeout_;
155   std::mutex mutex_;
156   std::atomic_bool stop_receive_{false};
157   std::thread thread_;
158   acltdtChannelHandle *acl_handle_;
159 
160   void HandleData();
161   bool ReceiveAndProcessData(ScopeAclTdtDataset *dataset);
162   bool QueryChannelSize(size_t *queue_size);
163 };
164 
165 class MbufDataHandlerManager {
166  public:
GetInstance()167   static MbufDataHandlerManager &GetInstance() {
168     static MbufDataHandlerManager instance;
169     return instance;
170   }
171   ~MbufDataHandlerManager() = default;
172   MbufDataHandlerManager(const MbufDataHandlerManager &) = delete;
173   MbufDataHandlerManager &operator=(const MbufDataHandlerManager &) = delete;
174 
AddHandler(std::unique_ptr<MbufDataHandler> handler)175   void AddHandler(std::unique_ptr<MbufDataHandler> handler) { handles_.push_back(std::move(handler)); }
176 
DestoryPrintHandler()177   void DestoryPrintHandler() {
178     for (auto iter = handles_.begin(); iter != handles_.end(); iter++) {
179       if ((*iter)->GetChannelName() == kChannelNameNpuLog) {
180         (*iter)->StopReceive();
181         handles_.erase(iter);
182         break;
183       }
184     }
185   }
186 
DestoryHandler()187   void DestoryHandler() {
188     for (auto &handle : handles_) {
189       handle->StopReceive();
190     }
191     while (!handles_.empty()) {
192       MS_LOG(INFO) << "The thread of " << handles_.back()->GetChannelName() << " channel is being destroyed.";
193       handles_.pop_back();
194     }
195   }
196 
197  private:
198   MbufDataHandlerManager() = default;
199   std::vector<std::unique_ptr<MbufDataHandler>> handles_;
200 };
201 }  // namespace mindspore::device::ascend
#endif  // MINDSPORE_CCSRC_PLUGIN_DEVICE_ASCEND_HAL_DEVICE_MBUF_RECEIVE_MANAGER_H_
203