• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef MINDSPORE_CCSRC_INCLUDE_BACKEND_DATA_QUEUE_DATA_QUEUE_H
18 #define MINDSPORE_CCSRC_INCLUDE_BACKEND_DATA_QUEUE_DATA_QUEUE_H
19 
#include <cstddef>
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <vector>

#include "include/backend/visible.h"
25 
26 namespace mindspore {
27 namespace device {
28 class DeviceContext;
29 
/// Result codes returned by DataQueue operations.
/// The underlying type is int and every enumerator value is spelled out explicitly (0..5) so the numeric
/// mapping stays stable if code elsewhere casts or compares these values — keep them unchanged.
enum class DataQueueStatus : int {
  SUCCESS = 0,
  QUEUE_EXIST = 1,
  QUEUE_NOT_EXIST = 2,
  ERROR_INPUT = 3,
  INTERNAL_ERROR = 4,
  TIMEOUT = 5,
};
31 
32 struct DataQueueItem {
33   int32_t worker_id{0};
34   std::string data_type;
35   size_t data_len{0};
36   void *data_ptr{nullptr};
37   std::vector<int64_t> shapes;
38   void *device_addr{nullptr};
39   // add tensor type when tdt need more types than data and end-of-sequence
40 };
41 
42 class BACKEND_EXPORT DataQueue {
43  public:
44   explicit DataQueue(const std::string &channel_name, const size_t capacity);
45   virtual ~DataQueue() = default;
46 
RegisterRelease(const std::function<void (void *,int32_t)> & func)47   virtual void RegisterRelease(const std::function<void(void *, int32_t)> &func) { host_release_ = func; }
IsOpen()48   virtual bool IsOpen() const { return !closed_; }
Close()49   virtual void Close() { closed_ = true; }
IsEmpty()50   virtual bool IsEmpty() const { return size_ == 0; }
IsFull()51   virtual bool IsFull() const { return size_ == capacity_; }
FrontAsync(std::vector<DataQueueItem> * data)52   virtual DataQueueStatus FrontAsync(std::vector<DataQueueItem> *data) const { return DataQueueStatus::SUCCESS; }
53   virtual DataQueueStatus Push(std::vector<DataQueueItem> data) = 0;
54   virtual DataQueueStatus Front(std::vector<DataQueueItem> *data) const = 0;
55   virtual DataQueueStatus Pop() = 0;
SetThreadDevice()56   virtual void SetThreadDevice() {}
Size()57   virtual size_t Size() const { return size_; }
Capacity()58   virtual size_t Capacity() const { return capacity_; }
QueryQueueSize()59   virtual size_t QueryQueueSize() const { return 0; }
QueueType()60   virtual std::string QueueType() const { return "Unknown"; }
61 
62  protected:
63   const std::string channel_name_;
64   size_t head_;
65   size_t tail_;
66   size_t size_;
67   size_t capacity_;
68   bool closed_{false};
69   std::function<void(void *, int32_t)> host_release_;
70   DeviceContext *device_context_;
71 
72  private:
73   DataQueue(const DataQueue &) = delete;
74   DataQueue &operator=(const DataQueue &) = delete;
75 };
76 }  // namespace device
77 }  // namespace mindspore
78 #endif  // MINDSPORE_CCSRC_INCLUDE_BACKEND_DATA_QUEUE_DATA_QUEUE_H
79