• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/reader_base.h"
17 
18 #include "tensorflow/core/framework/reader_base.pb.h"
19 #include "tensorflow/core/framework/types.h"
20 #include "tensorflow/core/lib/core/coding.h"
21 #include "tensorflow/core/lib/core/errors.h"
22 #include "tensorflow/core/lib/core/notification.h"
23 #include "tensorflow/core/lib/core/stringpiece.h"
24 #include "tensorflow/core/lib/strings/str_util.h"
25 #include "tensorflow/core/lib/strings/strcat.h"
26 
27 namespace tensorflow {
28 
29 // ReaderBase ------------------------------------------------------
30 
ReaderBase(const string & name)31 ReaderBase::ReaderBase(const string& name) : name_(name) {}
32 
NumRecordsProduced()33 int64 ReaderBase::NumRecordsProduced() {
34   mutex_lock lock(mu_);
35   return num_records_produced_;
36 }
37 
NumWorkUnitsCompleted()38 int64 ReaderBase::NumWorkUnitsCompleted() {
39   mutex_lock lock(mu_);
40   return work_finished_;
41 }
42 
Reset()43 Status ReaderBase::Reset() {
44   mutex_lock lock(mu_);
45   return ResetLocked();
46 }
47 
ResetLocked()48 Status ReaderBase::ResetLocked() {
49   work_started_ = 0;
50   work_finished_ = 0;
51   num_records_produced_ = 0;
52   work_.clear();
53   return Status::OK();
54 }
55 
SerializeState(tstring * state)56 Status ReaderBase::SerializeState(tstring* state) {
57   mutex_lock lock(mu_);
58   return SerializeStateLocked(state);
59 }
60 
SerializeStateLocked(tstring * state)61 Status ReaderBase::SerializeStateLocked(tstring* state) {
62   return errors::Unimplemented("Reader SerializeState");
63 }
64 
RestoreState(const tstring & state)65 Status ReaderBase::RestoreState(const tstring& state) {
66   mutex_lock lock(mu_);
67   Status status = RestoreStateLocked(state);
68   if (!status.ok()) {
69     ResetLocked().IgnoreError();
70   }
71   return status;
72 }
73 
RestoreStateLocked(const tstring & state)74 Status ReaderBase::RestoreStateLocked(const tstring& state) {
75   return errors::Unimplemented("Reader RestoreState");
76 }
77 
// Reads up to `num_records` records, appending them to `*keys` / `*values`
// (via ReadUpToLocked()), and returns how many records this call produced.
//
// Whenever no work is in progress, the next work item is dequeued from `queue`
// and OnWorkStartedLocked() is invoked. The loop ends when:
//   * the requested number of records has been produced (or the request is
//     non-positive);
//   * the current work item reaches its end after this call has produced at
//     least one record; or
//   * an error occurs. Errors are reported through `context`; the return value
//     is still the number of records produced so far by this call.
int64 ReaderBase::ReadUpTo(const int64 num_records, QueueInterface* queue,
                           std::vector<tstring>* keys,
                           std::vector<tstring>* values,
                           OpKernelContext* context) {
  mutex_lock lock(mu_);
  int64 records_produced_this_call = 0;
  while (true) {
    // Records produced by this iteration of the ReadUpToLocked call.
    int64 num_records_produced = 0;
    const int64 remaining = num_records - records_produced_this_call;
    // Use `<= 0` rather than `== 0`: a negative request, or a ReadUpToLocked()
    // implementation that returns more than asked for, would otherwise step
    // past zero and keep the loop running.
    if (remaining <= 0) {
      return records_produced_this_call;
    }
    if (!work_in_progress()) {
      work_ = GetNextWorkLocked(queue, context);
      if (!context->status().ok()) {
        return records_produced_this_call;
      }
      Status status = OnWorkStartedLocked();
      if (status.ok()) {
        work_started_++;
      } else {
        context->SetStatus(status);
        return records_produced_this_call;
      }
    }
    bool at_end = false;

    Status status =
        ReadUpToLocked(remaining, keys, values, &num_records_produced, &at_end);
    // This call so far.
    records_produced_this_call += num_records_produced;

    // In total, over the lifetime of the ReaderBase.
    num_records_produced_ += num_records_produced;

    // A well-behaved ReadUpToLocked() must make progress, signal the end of
    // the work item, or fail; anything else would spin forever.
    if (!at_end && status.ok() && num_records_produced == 0) {
      status = errors::Internal(
          "ReadUpToLocked() for ", name(),
          " must set *at_end=true, *num_read > 0 or return an error.");
      context->SetStatus(status);
      return records_produced_this_call;
    }
    if (status.ok() && at_end) {
      status = OnWorkFinishedLocked();
      work_finished_ = work_started_;
      // Only take the early return when finishing the work item succeeded;
      // otherwise fall through so the failure is reported via `context`
      // instead of being silently dropped.
      if (status.ok() && records_produced_this_call > 0) {
        return records_produced_this_call;
      }
    }
    if (!status.ok()) {
      context->SetStatus(status);
      return records_produced_this_call;
    }
  }
}
134 
135 // Default implementation just reads one record at a time.
ReadUpToLocked(int64 num_records,std::vector<tstring> * keys,std::vector<tstring> * values,int64 * num_read,bool * at_end)136 Status ReaderBase::ReadUpToLocked(int64 num_records, std::vector<tstring>* keys,
137                                   std::vector<tstring>* values, int64* num_read,
138                                   bool* at_end) {
139   bool produced = false;
140   tstring key;
141   tstring value;
142   Status status = ReadLocked(&key, &value, &produced, at_end);
143   if (produced) {
144     keys->push_back(std::move(key));
145     values->push_back(std::move(value));
146     *num_read = 1;
147   } else {
148     *num_read = 0;
149   }
150   return status;
151 }
152 
Read(QueueInterface * queue,tstring * key,tstring * value,OpKernelContext * context)153 void ReaderBase::Read(QueueInterface* queue, tstring* key, tstring* value,
154                       OpKernelContext* context) {
155   mutex_lock lock(mu_);
156   while (true) {
157     if (!work_in_progress()) {
158       work_ = GetNextWorkLocked(queue, context);
159       if (!context->status().ok()) {
160         return;
161       }
162       Status status = OnWorkStartedLocked();
163       if (status.ok()) {
164         work_started_++;
165       } else {
166         context->SetStatus(status);
167         return;
168       }
169     }
170 
171     bool produced = false;
172     bool at_end = false;
173     Status status = ReadLocked(key, value, &produced, &at_end);
174 
175     if (!at_end && status.ok() && !produced) {
176       status = errors::Internal(
177           "ReadLocked() for ", name(),
178           " must set *at_end=true, *produced=true, or return an error.");
179     }
180     if (!status.ok() && produced) {
181       status = errors::Internal(
182           "ReadLocked() for ", name(),
183           " set *produced=true *and* returned an error: ", status.ToString());
184     }
185     if (status.ok() && at_end) {
186       status = OnWorkFinishedLocked();
187       work_finished_ = work_started_;
188     }
189     if (!status.ok()) {
190       context->SetStatus(status);
191       return;
192     }
193     if (produced) {
194       ++num_records_produced_;
195       return;
196     }
197   }
198 }
199 
GetNextWorkLocked(QueueInterface * queue,OpKernelContext * context) const200 string ReaderBase::GetNextWorkLocked(QueueInterface* queue,
201                                      OpKernelContext* context) const {
202   string work;
203   Notification n;
204   queue->TryDequeue(
205       context, [context, &n, &work](const QueueInterface::Tuple& tuple) {
206         if (context->status().ok()) {
207           if (tuple.size() != 1) {
208             context->SetStatus(
209                 errors::InvalidArgument("Expected single component queue"));
210           } else if (tuple[0].dtype() != DT_STRING) {
211             context->SetStatus(errors::InvalidArgument(
212                 "Expected queue with single string component"));
213           } else if (tuple[0].NumElements() != 1) {
214             context->SetStatus(errors::InvalidArgument(
215                 "Expected to dequeue a one-element string tensor"));
216           } else {
217             work = tuple[0].flat<tstring>()(0);
218           }
219         }
220         n.Notify();
221       });
222   n.WaitForNotification();
223   return work;
224 }
225 
SaveBaseState(ReaderBaseState * state) const226 void ReaderBase::SaveBaseState(ReaderBaseState* state) const {
227   state->Clear();
228   state->set_work_started(work_started_);
229   state->set_work_finished(work_finished_);
230   state->set_num_records_produced(num_records_produced_);
231   state->set_current_work(work_.data(), work_.size());
232 }
233 
// Builds a record key of the form "<current work>:<key>".
tstring ReaderBase::KeyName(const tstring& key) const {
  const auto& work_item = current_work();
  return strings::StrCat(work_item, ":", key);
}
237 
RestoreBaseState(const ReaderBaseState & state)238 Status ReaderBase::RestoreBaseState(const ReaderBaseState& state) {
239   work_started_ = state.work_started();
240   work_finished_ = state.work_finished();
241   num_records_produced_ = state.num_records_produced();
242   work_ = state.current_work();
243   if (work_started_ < 0 || work_finished_ < 0 || num_records_produced_ < 0) {
244 #if defined(__ANDROID__) || defined(__EMSCRIPTEN__)
245     const string debug_string = "<debug state not available>";
246 #else
247     const string debug_string = state.DebugString();
248 #endif
249     return errors::InvalidArgument(
250         "Unexpected negative value when restoring in ", name(), ": ",
251         debug_string);
252   }
253   if (work_started_ > work_finished_) {
254 #if defined(__ANDROID__) || (__EMSCRIPTEN__)
255     const string debug_string = "<debug state not available>";
256 #else
257     const string debug_string = state.DebugString();
258 #endif
259     return errors::InvalidArgument(
260         "Inconsistent work started vs. finished when restoring in ", name(),
261         ": ", debug_string);
262   }
263   return Status::OK();
264 }
265 
266 }  // namespace tensorflow
267