1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include "tensorflow/core/framework/reader_base.h"

#include <utility>

#include "tensorflow/core/framework/reader_base.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/coding.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
26
27 namespace tensorflow {
28
29 // ReaderBase ------------------------------------------------------
30
ReaderBase(const string & name)31 ReaderBase::ReaderBase(const string& name) : name_(name) {}
32
NumRecordsProduced()33 int64 ReaderBase::NumRecordsProduced() {
34 mutex_lock lock(mu_);
35 return num_records_produced_;
36 }
37
NumWorkUnitsCompleted()38 int64 ReaderBase::NumWorkUnitsCompleted() {
39 mutex_lock lock(mu_);
40 return work_finished_;
41 }
42
Reset()43 Status ReaderBase::Reset() {
44 mutex_lock lock(mu_);
45 return ResetLocked();
46 }
47
ResetLocked()48 Status ReaderBase::ResetLocked() {
49 work_started_ = 0;
50 work_finished_ = 0;
51 num_records_produced_ = 0;
52 work_.clear();
53 return Status::OK();
54 }
55
SerializeState(string * state)56 Status ReaderBase::SerializeState(string* state) {
57 mutex_lock lock(mu_);
58 return SerializeStateLocked(state);
59 }
60
SerializeStateLocked(string * state)61 Status ReaderBase::SerializeStateLocked(string* state) {
62 return errors::Unimplemented("Reader SerializeState");
63 }
64
RestoreState(const string & state)65 Status ReaderBase::RestoreState(const string& state) {
66 mutex_lock lock(mu_);
67 Status status = RestoreStateLocked(state);
68 if (!status.ok()) {
69 ResetLocked().IgnoreError();
70 }
71 return status;
72 }
73
RestoreStateLocked(const string & state)74 Status ReaderBase::RestoreStateLocked(const string& state) {
75 return errors::Unimplemented("Reader RestoreState");
76 }
77
// Reads up to `num_records` records, appending them to `keys` / `values`.
// When there is no work in progress, a new work unit is dequeued from `queue`
// and OnWorkStartedLocked() is invoked. When a work unit ends, it is counted
// as finished via OnWorkFinishedLocked().
//
// Returns the number of records appended by this call. The call returns early,
// with fewer than `num_records` records, once a work unit ends after at least
// one record was produced. Errors are reported via `context->SetStatus()`;
// the count of records already appended is still returned in that case.
int64 ReaderBase::ReadUpTo(const int64 num_records, QueueInterface* queue,
                           std::vector<string>* keys,
                           std::vector<string>* values,
                           OpKernelContext* context) {
  mutex_lock lock(mu_);
  int64 records_produced_this_call = 0;
  while (true) {
    // Records produced by this iteration's ReadUpToLocked() call.
    int64 num_records_produced = 0;
    const int64 remaining = num_records - records_produced_this_call;
    // `<= 0` rather than `== 0`: a non-positive request, or an override that
    // produces more than it was asked for, must still terminate the loop.
    if (remaining <= 0) {
      return records_produced_this_call;
    }
    if (!work_in_progress()) {
      work_ = GetNextWorkLocked(queue, context);
      if (!context->status().ok()) {
        return records_produced_this_call;
      }
      Status status = OnWorkStartedLocked();
      if (status.ok()) {
        work_started_++;
      } else {
        context->SetStatus(status);
        return records_produced_this_call;
      }
    }
    bool at_end = false;

    Status status =
        ReadUpToLocked(remaining, keys, values, &num_records_produced, &at_end);
    // This call so far.
    records_produced_this_call += num_records_produced;

    // In total, over the lifetime of the ReaderBase.
    num_records_produced_ += num_records_produced;

    if (!at_end && status.ok() && num_records_produced == 0) {
      status = errors::Internal(
          "ReadUpToLocked() for ", name(),
          " must set *at_end=true, *num_read > 0 or return an error.");
      context->SetStatus(status);
      return records_produced_this_call;
    }
    if (status.ok() && at_end) {
      status = OnWorkFinishedLocked();
      work_finished_ = work_started_;
      // Only return the partial batch early if finishing the work unit
      // succeeded; otherwise fall through so the error reaches `context`
      // instead of being lost.
      if (status.ok() && records_produced_this_call > 0) {
        return records_produced_this_call;
      }
    }
    if (!status.ok()) {
      context->SetStatus(status);
      return records_produced_this_call;
    }
  }
}
134
135 // Default implementation just reads one record at a time.
ReadUpToLocked(int64 num_records,std::vector<string> * keys,std::vector<string> * values,int64 * num_read,bool * at_end)136 Status ReaderBase::ReadUpToLocked(int64 num_records, std::vector<string>* keys,
137 std::vector<string>* values, int64* num_read,
138 bool* at_end) {
139 bool produced = false;
140 string key;
141 string value;
142 Status status = ReadLocked(&key, &value, &produced, at_end);
143 if (produced) {
144 keys->emplace_back(key);
145 values->emplace_back(value);
146 *num_read = 1;
147 } else {
148 *num_read = 0;
149 }
150 return status;
151 }
152
Read(QueueInterface * queue,string * key,string * value,OpKernelContext * context)153 void ReaderBase::Read(QueueInterface* queue, string* key, string* value,
154 OpKernelContext* context) {
155 mutex_lock lock(mu_);
156 while (true) {
157 if (!work_in_progress()) {
158 work_ = GetNextWorkLocked(queue, context);
159 if (!context->status().ok()) {
160 return;
161 }
162 Status status = OnWorkStartedLocked();
163 if (status.ok()) {
164 work_started_++;
165 } else {
166 context->SetStatus(status);
167 return;
168 }
169 }
170
171 bool produced = false;
172 bool at_end = false;
173 Status status = ReadLocked(key, value, &produced, &at_end);
174
175 if (!at_end && status.ok() && !produced) {
176 status = errors::Internal(
177 "ReadLocked() for ", name(),
178 " must set *at_end=true, *produced=true, or return an error.");
179 }
180 if (!status.ok() && produced) {
181 status = errors::Internal(
182 "ReadLocked() for ", name(),
183 " set *produced=true *and* returned an error: ", status.ToString());
184 }
185 if (status.ok() && at_end) {
186 status = OnWorkFinishedLocked();
187 work_finished_ = work_started_;
188 }
189 if (!status.ok()) {
190 context->SetStatus(status);
191 return;
192 }
193 if (produced) {
194 ++num_records_produced_;
195 return;
196 }
197 }
198 }
199
GetNextWorkLocked(QueueInterface * queue,OpKernelContext * context) const200 string ReaderBase::GetNextWorkLocked(QueueInterface* queue,
201 OpKernelContext* context) const {
202 string work;
203 Notification n;
204 queue->TryDequeue(
205 context, [this, context, &n, &work](const QueueInterface::Tuple& tuple) {
206 if (context->status().ok()) {
207 if (tuple.size() != 1) {
208 context->SetStatus(
209 errors::InvalidArgument("Expected single component queue"));
210 } else if (tuple[0].dtype() != DT_STRING) {
211 context->SetStatus(errors::InvalidArgument(
212 "Expected queue with single string component"));
213 } else if (tuple[0].NumElements() != 1) {
214 context->SetStatus(errors::InvalidArgument(
215 "Expected to dequeue a one-element string tensor"));
216 } else {
217 work = tuple[0].flat<string>()(0);
218 }
219 }
220 n.Notify();
221 });
222 n.WaitForNotification();
223 return work;
224 }
225
SaveBaseState(ReaderBaseState * state) const226 void ReaderBase::SaveBaseState(ReaderBaseState* state) const {
227 state->Clear();
228 state->set_work_started(work_started_);
229 state->set_work_finished(work_finished_);
230 state->set_num_records_produced(num_records_produced_);
231 state->set_current_work(work_);
232 }
233
// Qualifies `key` with the current work unit, producing
// "<current_work()>:<key>".
string ReaderBase::KeyName(const string& key) const {
  string qualified(current_work());
  strings::StrAppend(&qualified, ":", key);
  return qualified;
}
237
RestoreBaseState(const ReaderBaseState & state)238 Status ReaderBase::RestoreBaseState(const ReaderBaseState& state) {
239 work_started_ = state.work_started();
240 work_finished_ = state.work_finished();
241 num_records_produced_ = state.num_records_produced();
242 work_ = state.current_work();
243 if (work_started_ < 0 || work_finished_ < 0 || num_records_produced_ < 0) {
244 #if defined(__ANDROID__) || defined(__EMSCRIPTEN__)
245 const string debug_string = "<debug state not available>";
246 #else
247 const string debug_string = state.DebugString();
248 #endif
249 return errors::InvalidArgument(
250 "Unexpected negative value when restoring in ", name(), ": ",
251 debug_string);
252 }
253 if (work_started_ > work_finished_) {
254 #if defined(__ANDROID__) || (__EMSCRIPTEN__)
255 const string debug_string = "<debug state not available>";
256 #else
257 const string debug_string = state.DebugString();
258 #endif
259 return errors::InvalidArgument(
260 "Inconsistent work started vs. finished when restoring in ", name(),
261 ": ", debug_string);
262 }
263 return Status::OK();
264 }
265
266 } // namespace tensorflow
267