/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <condition_variable>
#include <cstddef>
#include <deque>
#include <mutex>
#include <numeric>
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"

namespace tensorflow {
namespace {

class Buffer : public ResourceBase {
 public:
  using Tuple = std::vector<Tensor>;

  explicit Buffer(std::size_t capacity, std::size_t memory_limit)
      : capacity_(capacity), memory_limit_(memory_limit), current_bytes_(0) {}

  // The Buffer takes ownership of the Tuple.
  Status Put(Tuple* tuple) {
    std::unique_lock<std::mutex> lock(mu_);

    std::size_t tuple_bytes = GetTupleBytes(*tuple);

    // Sanity check so that we don't block forever below
    if (memory_limit_ > 0 && tuple_bytes > memory_limit_) {
      return Status(
          errors::ResourceExhausted("Attempted to insert "
                                    "tensors with combined size of '",
                                    tuple_bytes,
                                    "' bytes into "
                                    "Staging Area with a memory limit of '",
                                    memory_limit_, "'."));
    }

    // If the buffer capacity is bounded, wait until elements have been removed
    if (IsBounded()) {
      full_cond_var_.wait(lock, [tuple_bytes, this]() {
        // If there's a memory limit, check if there's space for insertion
        bool memory_limit_valid =
            memory_limit_ > 0 ? !WouldExceedMemoryLimit(tuple_bytes) : true;
        // If we're configured for capacity, check if there's space for
        // insertion
        bool capacity_valid = capacity_ > 0 ? !IsCapacityFull() : true;

        // Stop waiting upon success for both conditions
        return capacity_valid && memory_limit_valid;
      });
    }

    // Update bytes in the Staging Area
    current_bytes_ += tuple_bytes;

    // Store tuple
    buf_.push_back(std::move(*tuple));

    lock.unlock();
    // Notify all removers. Removers may be peeking at a specific element
    // or waiting for the element at the front of the deque. As we don't
    // know the appropriate one to wake up, we should wake them all.
    non_empty_cond_var_.notify_all();

    return Status::OK();
  }

  // Get tuple at front of the buffer
  void Get(Tuple* tuple) {  // TODO(zhifengc): Support cancellation.
    std::unique_lock<std::mutex> lock(mu_);

    // Wait for data if the buffer is empty
    non_empty_cond_var_.wait(lock, [this]() { return !buf_.empty(); });

    // Move data into the output tuple
    *tuple = std::move(buf_.front());
    buf_.pop_front();

    // Update bytes in the Staging Area
    current_bytes_ -= GetTupleBytes(*tuple);

    notify_inserters_if_bounded(&lock);
  }

  // Return tuple at index
  Status Peek(std::size_t index, Tuple* tuple) {
    std::unique_lock<std::mutex> lock(mu_);

    // Wait if the requested index is not available
    non_empty_cond_var_.wait(
        lock, [index, this]() { return index < this->buf_.size(); });

    // Place tensors in the output tuple
    for (const auto& tensor : buf_[index]) {
      tuple->push_back(tensor);
    }

    return Status::OK();
  }

  // Buffer size
  size_t Size() {
    std::unique_lock<std::mutex> lock(mu_);
    return buf_.size();
  }

  void Clear() {
    std::unique_lock<std::mutex> lock(mu_);
    buf_.clear();
    current_bytes_ = 0;

    notify_inserters_if_bounded(&lock);
  }

  string DebugString() const override {
    std::unique_lock<std::mutex> lock(mu_);
    return strings::StrCat("Staging size: ", buf_.size());
  }

 private:
  // If the buffer is configured for bounded capacity, notify
  // waiting inserters that space is now available
  void notify_inserters_if_bounded(std::unique_lock<std::mutex>* lock) {
    if (IsBounded()) {
      lock->unlock();
      // Notify all inserters. The removal of an element
      // may make memory available for many inserters
      // to insert new elements
      full_cond_var_.notify_all();
    }
  }

  // Is there a limit on the number of elements or a memory limit
  // configured on this buffer?
  bool IsBounded() const { return capacity_ > 0 || memory_limit_ > 0; }

  bool IsCapacityFull() const { return buf_.size() >= capacity_; }

  bool WouldExceedMemoryLimit(std::size_t bytes) const {
    return bytes + current_bytes_ > memory_limit_;
  }

  // Total size of all tensors in the tuple, in bytes. Accumulate with a
  // std::size_t initial value so that large tuples don't overflow an int.
  std::size_t GetTupleBytes(const Tuple& tuple) {
    return std::accumulate(tuple.begin(), tuple.end(), std::size_t{0},
                           [](const std::size_t& lhs, const Tensor& rhs) {
                             return lhs + rhs.TotalBytes();
                           });
  }

  std::size_t capacity_;
  std::size_t memory_limit_;
  std::size_t current_bytes_;
  mutable std::mutex mu_;
  std::condition_variable non_empty_cond_var_;
  std::condition_variable full_cond_var_;
  std::deque<Tuple> buf_;
};
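// Illustrative sketch, not part of this kernel: how the Buffer's blocking
// contract is expected to behave for a bounded staging area. The tensor
// value below is a hypothetical placeholder.
//
//   Buffer buffer(/*capacity=*/2, /*memory_limit=*/0);
//
//   // Producer side: Put() blocks once two tuples are staged, until a
//   // consumer removes one; with a non-zero memory_limit it also rejects
//   // any single tuple larger than that limit.
//   Buffer::Tuple staged = {Tensor(DT_FLOAT, TensorShape({8}))};
//   TF_CHECK_OK(buffer.Put(&staged));
//
//   // Consumer side: Get() blocks while the buffer is empty, moves the
//   // front tuple out, and wakes any blocked producers.
//   Buffer::Tuple out;
//   buffer.Get(&out);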

Status GetBuffer(OpKernelContext* ctx, const NodeDef& ndef, Buffer** buf) {
  auto rm = ctx->resource_manager();
  ContainerInfo cinfo;

  // Lambda for creating the Staging Area
  auto create_fn = [&ndef](Buffer** ret) -> Status {
    int64 capacity;
    int64 memory_limit;
    TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "capacity", &capacity));
    TF_RETURN_IF_ERROR(GetNodeAttr(ndef, "memory_limit", &memory_limit));
    *ret = new Buffer(capacity, memory_limit);
    return Status::OK();
  };

  TF_RETURN_IF_ERROR(cinfo.Init(rm, ndef, true /* use name() */));
  TF_RETURN_IF_ERROR(rm->LookupOrCreate<Buffer>(cinfo.container(), cinfo.name(),
                                                buf, create_fn));
  return Status::OK();
}
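// Illustrative note, not part of this kernel: GetBuffer() resolves the
// staging Buffer through the ResourceMgr, so op instances whose ContainerInfo
// (container/shared_name, defaulting to the node name) matches resolve to the
// same Buffer, which is what lets Stage feed Unstage. A hedged sketch of a
// second lookup reusing the resource created by LookupOrCreate above:
//
//   Buffer* same_buf = nullptr;
//   TF_CHECK_OK(
//       rm->Lookup<Buffer>(cinfo.container(), cinfo.name(), &same_buf));
//   core::ScopedUnref unref(same_buf);  // same Buffer staged into by StageOp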

}  // namespace

class StageOp : public OpKernel {
 public:
  explicit StageOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    Buffer* buf = nullptr;
    OP_REQUIRES_OK(ctx, GetBuffer(ctx, def(), &buf));
    core::ScopedUnref scope(buf);
    Buffer::Tuple tuple;
    tuple.reserve(ctx->num_inputs());
    for (int i = 0; i < ctx->num_inputs(); ++i) {
      tuple.push_back(ctx->input(i));
    }
    OP_REQUIRES_OK(ctx, buf->Put(&tuple));
  }
};

REGISTER_KERNEL_BUILDER(Name("Stage").Device(DEVICE_CPU), StageOp);
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
REGISTER_KERNEL_BUILDER(Name("Stage").Device(DEVICE_GPU), StageOp);
#endif

class UnstageOp : public OpKernel {
 public:
  explicit UnstageOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  // Using this op in such a way that it blocks forever
  // is an error.  As such, cancellation is not handled.
  void Compute(OpKernelContext* ctx) override {
    Buffer* buf = nullptr;
    OP_REQUIRES_OK(ctx, GetBuffer(ctx, def(), &buf));
    core::ScopedUnref scope(buf);
    Buffer::Tuple tuple;

    buf->Get(&tuple);

    OP_REQUIRES(
        ctx, tuple.size() == (size_t)ctx->num_outputs(),
        errors::InvalidArgument("Mismatch stage/unstage: ", tuple.size(),
                                " vs. ", ctx->num_outputs()));

    for (size_t i = 0; i < tuple.size(); ++i) {
      ctx->set_output(i, tuple[i]);
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("Unstage").Device(DEVICE_CPU), UnstageOp);
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
REGISTER_KERNEL_BUILDER(Name("Unstage").Device(DEVICE_GPU), UnstageOp);
#endif

class StagePeekOp : public OpKernel {
 public:
  explicit StagePeekOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  // Using this op in such a way that it blocks forever
  // is an error.  As such, cancellation is not handled.
  void Compute(OpKernelContext* ctx) override {
    Buffer* buf = nullptr;
    OP_REQUIRES_OK(ctx, GetBuffer(ctx, def(), &buf));
    core::ScopedUnref scope(buf);
    Buffer::Tuple tuple;

    std::size_t index = ctx->input(0).scalar<int>()();

    OP_REQUIRES_OK(ctx, buf->Peek(index, &tuple));

    OP_REQUIRES(
        ctx, tuple.size() == (size_t)ctx->num_outputs(),
        errors::InvalidArgument("Mismatch stage/unstage: ", tuple.size(),
                                " vs. ", ctx->num_outputs()));

    for (size_t i = 0; i < tuple.size(); ++i) {
      ctx->set_output(i, tuple[i]);
    }
  }
};

REGISTER_KERNEL_BUILDER(Name("StagePeek").Device(DEVICE_CPU), StagePeekOp);
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
REGISTER_KERNEL_BUILDER(
    Name("StagePeek").HostMemory("index").Device(DEVICE_GPU), StagePeekOp);
#endif

class StageSizeOp : public OpKernel {
 public:
  explicit StageSizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  // Using this op in such a way that it blocks forever
  // is an error.  As such, cancellation is not handled.
  void Compute(OpKernelContext* ctx) override {
    Buffer* buf = nullptr;
    OP_REQUIRES_OK(ctx, GetBuffer(ctx, def(), &buf));
    core::ScopedUnref scope(buf);

    // Allocate size output tensor
    Tensor* size = nullptr;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &size));

    // Set it to the actual size
    size->scalar<int32>().setConstant(buf->Size());
  }
};

REGISTER_KERNEL_BUILDER(Name("StageSize").Device(DEVICE_CPU), StageSizeOp);
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
REGISTER_KERNEL_BUILDER(Name("StageSize").HostMemory("size").Device(DEVICE_GPU),
                        StageSizeOp);
#endif

class StageClearOp : public OpKernel {
 public:
  explicit StageClearOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}

  // Using this op in such a way that it blocks forever
  // is an error.  As such, cancellation is not handled.
  void Compute(OpKernelContext* ctx) override {
    Buffer* buf = nullptr;
    OP_REQUIRES_OK(ctx, GetBuffer(ctx, def(), &buf));
    core::ScopedUnref scope(buf);

    buf->Clear();
  }
};

REGISTER_KERNEL_BUILDER(Name("StageClear").Device(DEVICE_CPU), StageClearOp);
#if (defined(GOOGLE_CUDA) && GOOGLE_CUDA) || \
    (defined(TENSORFLOW_USE_ROCM) && TENSORFLOW_USE_ROCM)
REGISTER_KERNEL_BUILDER(Name("StageClear").Device(DEVICE_GPU), StageClearOp);
#endif

}  // namespace tensorflow