• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /**
2  * Copyright 2022 Huawei Technologies Co., Ltd
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  * http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "runtime/device/tensors_queue.h"
17 
18 namespace mindspore {
19 namespace device {
CreateTensorsQueue()20 void TensorsQueue::CreateTensorsQueue() {
21   // Store one element tensors' size.
22   // The whole TensorsQueue is like: [[tensor1, tensor2], [tensor3, tensor4]].
23   // One element means [tensor1, tensor2].
24   std::vector<int64_t> element_size_list;
25   for (auto shape : shapes_) {
26     int64_t item_size =
27       std::accumulate(shape.begin(), shape.end(), SizeToLong(GetTypeByte(dtype_)), std::multiplies<int64_t>());
28     element_size_list.push_back(item_size);
29   }
30   // Create the elements in TensorsQueue when construct.
31   for (int64_t i = 0; i < size_; i++) {
32     mindspore::kernel::AddressPtrList element_addrs;
33     for (auto element_size : element_size_list) {
34       kernel::AddressPtr create_dev = std::make_shared<kernel::Address>();
35       create_dev->addr = AllocateMemory(LongToSize(element_size));
36       create_dev->size = LongToSize(element_size);
37       element_addrs.push_back(create_dev);
38       MS_LOG(DEBUG) << "Create  " << element_size << "bytes for " << name_;
39     }
40     tensors_q_.push_back(element_addrs);
41   }
42   MS_LOG(DEBUG) << "Create a TensorsQueue: " << name_ << ", Q size is " << size_ << ", elements num is "
43                 << elements_num_;
44 }
45 
CopyTensor(const mindspore::kernel::AddressPtr &,const mindspore::kernel::AddressPtr &)46 void TensorsQueue::CopyTensor(const mindspore::kernel::AddressPtr &, const mindspore::kernel::AddressPtr &) {
47   MS_LOG(EXCEPTION) << "This should be overridden by subclass !";
48 }
CopyTensor(const mindspore::kernel::AddressPtr &,const mindspore::kernel::AddressPtr &,void *)49 void TensorsQueue::CopyTensor(const mindspore::kernel::AddressPtr &, const mindspore::kernel::AddressPtr &, void *) {
50   MS_LOG(EXCEPTION) << "This should be overridden by subclass !";
51 }
52 
AvailableSize()53 size_t TensorsQueue::AvailableSize() {
54   return (rear_ > front_) ? (rear_ - front_) : (LongToSize(size_) - front_ + rear_);
55 }
IsFull()56 bool TensorsQueue::IsFull() {
57   if (size_ <= 0) {
58     return false;
59   } else {
60     return (rear_ + IntToSize(1)) % LongToSize(size_) == front_;
61   }
62 }
IsEmpty()63 bool TensorsQueue::IsEmpty() { return front_ == rear_; }
64 
Put(const mindspore::kernel::AddressPtrList & dev_addr)65 bool TensorsQueue::Put(const mindspore::kernel::AddressPtrList &dev_addr) {
66   // When the tensor_q is full, put will failed.
67   if (IsFull()) {
68     MS_LOG(WARNING) << "The " << name_ << " is full, total size is " << size_;
69     return false;
70   }
71   // Get the element in position rear_ and change the value by input, the we increase the rear_.
72   // We can get a effect like a circle queue and reuse the addrs.
73   MS_EXCEPTION_IF_CHECK_FAIL((tensors_q_.size() > rear_), "The index is out of range.");
74   mindspore::kernel::AddressPtrList element = tensors_q_[rear_];
75   for (int64_t i = 0; i < elements_num_; i++) {
76     CopyTensor(element[LongToSize(i)], dev_addr[LongToSize(i) + IntToSize(1)]);
77   }
78   if (size_ <= 0) {
79     return false;
80   }
81   rear_ = (rear_ + 1) % LongToSize(size_);
82   MS_LOG(DEBUG) << "Put an element into  " << name_ << ", now the avliable q size is [" << AvailableSize() << "/"
83                 << size_ << "]";
84   return true;
85 }
86 
Put(const mindspore::kernel::AddressPtrList & dev_addr,void * stream)87 bool TensorsQueue::Put(const mindspore::kernel::AddressPtrList &dev_addr, void *stream) {
88   if (IsFull()) {
89     MS_LOG(WARNING) << "The " << name_ << " is full, total size is " << size_;
90     return false;
91   }
92   MS_EXCEPTION_IF_CHECK_FAIL((tensors_q_.size() > rear_), "The index is out of range.");
93   mindspore::kernel::AddressPtrList element = tensors_q_[rear_];
94   for (int64_t i = 0; i < elements_num_; i++) {
95     CopyTensor(element[LongToSize(i)], dev_addr[LongToSize(i) + IntToSize(1)], stream);
96   }
97   if (size_ <= 0) {
98     return false;
99   }
100   rear_ = (rear_ + IntToSize(1)) % LongToSize(size_);
101   MS_LOG(DEBUG) << "Put an element into  " << name_ << ", now the avliable q size is [" << AvailableSize() << "/"
102                 << size_ << "]";
103   return true;
104 }
105 
Get(const mindspore::kernel::AddressPtrList & dev_addr,const bool & pop_after_get,void * stream)106 bool TensorsQueue::Get(const mindspore::kernel::AddressPtrList &dev_addr, const bool &pop_after_get, void *stream) {
107   // Get a tensor addrs list from the queue.
108   // If pop_after_get is true, we will pop the addrs from tensors_q_.
109   if (IsEmpty()) {
110     MS_LOG(WARNING) << "The TensorsQueue " << name_ << " is empty";
111     return false;
112   }
113   MS_EXCEPTION_IF_CHECK_FAIL((tensors_q_.size() > front_), "The index is out of range.");
114   mindspore::kernel::AddressPtrList element = tensors_q_[front_];
115   for (int64_t i = 0; i < elements_num_; i++) {
116     CopyTensor(dev_addr[LongToSize(i)], element[LongToSize(i)], stream);
117   }
118   if (pop_after_get) {
119     if (size_ <= 0) {
120       MS_LOG(ERROR) << "The size is zero.";
121       return false;
122     }
123     front_ = (front_ + IntToSize(1)) % LongToSize(size_);
124   }
125   MS_LOG(DEBUG) << "Get an element from  " << name_ << ", pop_after_get is " << pop_after_get
126                 << ", now the avliable q size is[" << AvailableSize() << " / " << size_ << "] ";
127   return true;
128 }
129 
Get(const mindspore::kernel::AddressPtrList & dev_addr,const bool & pop_after_get)130 bool TensorsQueue::Get(const mindspore::kernel::AddressPtrList &dev_addr, const bool &pop_after_get) {
131   if (IsEmpty()) {
132     MS_LOG(WARNING) << "The TensorsQueue " << name_ << " is empty";
133     return false;
134   }
135   MS_EXCEPTION_IF_CHECK_FAIL((tensors_q_.size() > front_), "The index is out of range.");
136   mindspore::kernel::AddressPtrList element = tensors_q_.front();
137   for (int64_t i = 0; i < elements_num_; i++) {
138     CopyTensor(dev_addr[LongToSize(i)], element[LongToSize(i)]);
139   }
140   if (pop_after_get) {
141     if (size_ <= 0) {
142       MS_LOG(ERROR) << "The size is zero.";
143       return false;
144     }
145     front_ = (front_ + IntToSize(1)) % LongToSize(size_);
146   }
147   MS_LOG(DEBUG) << "Get an element from  " << name_ << ", pop_after_get is " << pop_after_get
148                 << ", now the avliable q size is[" << AvailableSize() << " / " << size_ << "] ";
149   return true;
150 }
151 
Clear()152 void TensorsQueue::Clear() {
153   // Clear the tensors_q_ and return the element addr back to tensors_store.
154   if (IsEmpty()) {
155     MS_LOG(WARNING) << "The TensorsQueue " << name_ << " is already empty when execute Clear.";
156   }
157   rear_ = 0;
158   front_ = 0;
159   MS_LOG(DEBUG) << "Clear the elements for " << name_;
160 }
161 
Free()162 void TensorsQueue::Free() {
163   while (!IsEmpty()) {
164     MS_EXCEPTION_IF_CHECK_FAIL((tensors_q_.size() > front_), "The index is out of range.");
165     auto element = tensors_q_[front_];
166     for (const auto &addr : element) {
167       if (addr != nullptr) {
168         FreeMemory(static_cast<DeviceMemPtr>(addr->addr));
169       }
170     }
171 
172     if (size_ <= 0) {
173       MS_LOG(ERROR) << "The size is zero.";
174       return;
175     }
176     front_ = (front_ + IntToSize(1)) % LongToSize(size_);
177   }
178   MS_LOG(DEBUG) << "Free the TensorsQueue's memory for " << name_;
179 }
180 }  // namespace device
181 }  // namespace mindspore
182