1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16# pylint: disable=line-too-long
17"""Inputs and Readers.
18
19See the [Inputs and
20Readers](https://tensorflow.org/api_guides/python/io_ops) guide.
21"""
22
23from __future__ import absolute_import
24from __future__ import division
25from __future__ import print_function
26
27from tensorflow.python.eager import context
28from tensorflow.python.framework import dtypes
29from tensorflow.python.framework import ops
30from tensorflow.python.lib.io import python_io
31from tensorflow.python.ops import gen_data_flow_ops
32from tensorflow.python.ops import gen_io_ops
33from tensorflow.python.ops import gen_parsing_ops
34# go/tf-wildcard-import
35# pylint: disable=wildcard-import
36from tensorflow.python.ops.gen_io_ops import *
37# pylint: enable=wildcard-import
38from tensorflow.python.util import deprecation
39from tensorflow.python.util import dispatch as _dispatch
40from tensorflow.python.util.tf_export import tf_export
41
42
43# pylint: disable=protected-access
def _save(filename, tensor_names, tensors, tensor_slices=None, name="save"):
  """Save a list of tensors to a file with given names.

  Example usage without slice info:
    Save("/foo/bar", ["w", "b"], [w, b])

  Example usage with slices:
    Save("/foo/bar", ["w", "w"], [slice0, slice1],
         tensor_slices=["4 10 0,2:-", "4 10 2,2:-"])

  Args:
    filename: the file name of the sstable.
    tensor_names: a list of strings.
    tensors: the list of tensors to be saved.
    tensor_slices: Optional list of strings to specify the shape and slices of
      a larger virtual tensor that each tensor is a part of.  If not specified
      each tensor is saved as a full slice.
    name: string.  Optional name for the op.

  Requires:
    The length of tensors should match the size of tensor_names and of
    tensor_slices.

  Returns:
    An Operation that saves the tensors.
  """
  # Slice specs present: save each tensor as a slice of a larger virtual
  # tensor; otherwise save every tensor as a full slice.
  if tensor_slices is not None:
    return gen_io_ops.save_slices(
        filename, tensor_names, tensor_slices, tensors, name=name)
  return gen_io_ops.save(filename, tensor_names, tensors, name=name)
75
76
def _restore_slice(file_pattern, tensor_name, shape_and_slice, tensor_type,
                   name="restore_slice", preferred_shard=-1):
  """Restore a tensor slice from a set of files with a given pattern.

  Example usage:
    RestoreSlice("/foo/bar-?????-of-?????", "w", "10 10 0,2:-", DT_FLOAT)

  Args:
    file_pattern: the file pattern used to match a set of checkpoint files.
    tensor_name: the name of the tensor to restore.
    shape_and_slice: the shape-and-slice spec of the slice.
    tensor_type: the type of the tensor to restore.
    name: string.  Optional name for the op.
    preferred_shard: Int. Optional shard to open first in the checkpoint file.

  Returns:
    A tensor of type "tensor_type".
  """
  # Normalize the requested type (e.g. strip reference-ness) before handing
  # it to the restore kernel.
  restore_dtype = dtypes.as_dtype(tensor_type).base_dtype
  return gen_io_ops.restore_slice(
      file_pattern, tensor_name, shape_and_slice, restore_dtype,
      preferred_shard, name=name)
99
100
@_dispatch.add_dispatch_list
@tf_export("io.read_file", v1=["io.read_file", "read_file"])
def read_file(filename, name=None):
  """Reads the contents of file.

  This operation returns a tensor holding the entire contents of the given
  filename. No parsing of any kind is performed; the bytes are returned
  exactly as stored. Typically this is the first step of an input pipeline.

  Example:

  >>> with open("/tmp/file.txt", "w") as f:
  ...   f.write("asdf")
  ...
  4
  >>> tf.io.read_file("/tmp/file.txt")
  <tf.Tensor: shape=(), dtype=string, numpy=b'asdf'>

  Example of using the op in a function to read an image, decode it and reshape
  the tensor containing the pixel data:

  >>> @tf.function
  ... def load_image(filename):
  ...   raw = tf.io.read_file(filename)
  ...   image = tf.image.decode_png(raw, channels=3)
  ...   # the `print` executes during tracing.
  ...   print("Initial shape: ", image.shape)
  ...   image.set_shape([28, 28, 3])
  ...   print("Final shape: ", image.shape)
  ...   return image

  Args:
    filename: string. filename to read from.
    name: string.  Optional name for the op.

  Returns:
    A tensor of dtype "string", with the file contents.
  """
  # Thin wrapper: delegate directly to the generated kernel binding.
  return gen_io_ops.read_file(filename, name=name)
140
141
@_dispatch.add_dispatch_list
@tf_export(
    "io.serialize_tensor", v1=["io.serialize_tensor", "serialize_tensor"])
def serialize_tensor(tensor, name=None):
  r"""Transforms a Tensor into a serialized TensorProto proto.

  This operation encodes the data of a `tf.Tensor` as a scalar `tf.Tensor` of
  type `tf.string` holding the binary `TensorProto` wire format. It works for
  scalars and linear arrays, but its main use is converting multidimensional
  arrays into a form accepted by binary storage formats such as a `TFRecord`
  or `tf.train.Example`.

  See also:
  - `tf.io.parse_tensor`: inverse operation of `tf.io.serialize_tensor` that
  transforms a scalar string containing a serialized Tensor into a Tensor of a
  specified type.
  - `tf.ensure_shape`: `parse_tensor` cannot statically determine the shape of
  the parsed tensor. Use `tf.ensure_shape` to set the static shape when running
  under a `tf.function`
  - `.SerializeToString`, serializes a proto to a binary-string

  Example of serializing scalar data:

  >>> t = tf.constant(1)
  >>> tf.io.serialize_tensor(t)
  <tf.Tensor: shape=(), dtype=string, numpy=b'\x08...\x00'>

  Example of storing non-scalar data into a `tf.train.Example`:

  >>> t1 = [[1, 2]]
  >>> t2 = [[7, 8]]
  >>> nonscalar = tf.concat([t1, t2], 0)
  >>> nonscalar
  <tf.Tensor: shape=(2, 2), dtype=int32, numpy=
  array([[1, 2],
         [7, 8]], dtype=int32)>

  Serialize the data using `tf.io.serialize_tensor`.

  >>> serialized_nonscalar = tf.io.serialize_tensor(nonscalar)
  >>> serialized_nonscalar
  <tf.Tensor: shape=(), dtype=string, numpy=b'\x08...\x00'>

  Store the data in a `tf.train.Feature`.

  >>> feature_of_bytes = tf.train.Feature(
  ...   bytes_list=tf.train.BytesList(value=[serialized_nonscalar.numpy()]))
  >>> feature_of_bytes
  bytes_list {
    value: "\010...\000"
  }

  Put the `tf.train.Feature` message into a `tf.train.Example`.

  >>> features_for_example = {
  ...   'feature0': feature_of_bytes
  ... }
  >>> example_proto = tf.train.Example(
  ...   features=tf.train.Features(feature=features_for_example))
  >>> example_proto
  features {
    feature {
      key: "feature0"
      value {
        bytes_list {
          value: "\010...\000"
        }
      }
    }
  }

  Args:
    tensor: A `tf.Tensor`.
    name: string.  Optional name for the op.

  Returns:
    A Tensor of dtype string.
  """
  # Thin wrapper: delegate directly to the generated kernel binding.
  return gen_parsing_ops.serialize_tensor(tensor, name=name)
221
222
@tf_export(v1=["ReaderBase"])
class ReaderBase(object):
  """Base class for different Reader types, that produce a record every step.

  Conceptually, Readers convert string 'work units' into records (key,
  value pairs).  Typically the 'work units' are filenames and the
  records are extracted from the contents of those files.  We want a
  single record produced per step, but a work unit can correspond to
  many records.

  Therefore we introduce some decoupling using a queue.  The queue
  contains the work units and the Reader dequeues from the queue when
  it is asked to produce a record (via Read()) but it has finished the
  last work unit.

  @compatibility(eager)
  Readers are not compatible with eager execution. Instead, please
  use `tf.data` to get data into your model.
  @end_compatibility
  """

  def __init__(self, reader_ref, supports_serialize=False):
    """Creates a new ReaderBase.

    Args:
      reader_ref: The operation that implements the reader.
      supports_serialize: True if the reader implementation can
        serialize its state.

    Raises:
      RuntimeError: If eager execution is enabled.
    """
    if context.executing_eagerly():
      raise RuntimeError(
          "Readers are not supported when eager execution is enabled. "
          "Instead, please use tf.data to get data into your model.")

    self._reader_ref = reader_ref
    self._supports_serialize = supports_serialize

  @property
  def reader_ref(self):
    """Op that implements the reader."""
    return self._reader_ref

  @staticmethod
  def _as_queue_ref(queue):
    """Accepts either a Queue object or a tensor holding a queue handle."""
    if isinstance(queue, ops.Tensor):
      return queue
    return queue.queue_ref

  def read(self, queue, name=None):
    """Returns the next record (key, value) pair produced by a reader.

    Will dequeue a work unit from queue if necessary (e.g. when the
    Reader needs to start reading from a new file since it has
    finished with the previous file).

    Args:
      queue: A Queue or a mutable string Tensor representing a handle
        to a Queue, with string work items.
      name: A name for the operation (optional).

    Returns:
      A tuple of Tensors (key, value).
      key: A string scalar Tensor.
      value: A string scalar Tensor.
    """
    queue_ref = self._as_queue_ref(queue)
    if self._reader_ref.dtype != dtypes.resource:
      # For compatibility with pre-resource queues, create a ref(string) tensor
      # which can be looked up as the same queue by a resource manager.
      old_queue_op = gen_data_flow_ops.fake_queue(queue_ref)
      return gen_io_ops.reader_read(self._reader_ref, old_queue_op, name=name)
    return gen_io_ops.reader_read_v2(self._reader_ref, queue_ref, name=name)

  def read_up_to(self, queue, num_records,  # pylint: disable=invalid-name
                 name=None):
    """Returns up to num_records (key, value) pairs produced by a reader.

    Will dequeue a work unit from queue if necessary (e.g., when the
    Reader needs to start reading from a new file since it has
    finished with the previous file).
    It may return less than num_records even before the last batch.

    Args:
      queue: A Queue or a mutable string Tensor representing a handle
        to a Queue, with string work items.
      num_records: Number of records to read.
      name: A name for the operation (optional).

    Returns:
      A tuple of Tensors (keys, values).
      keys: A 1-D string Tensor.
      values: A 1-D string Tensor.
    """
    queue_ref = self._as_queue_ref(queue)
    if self._reader_ref.dtype != dtypes.resource:
      # For compatibility with pre-resource queues, create a ref(string) tensor
      # which can be looked up as the same queue by a resource manager.
      old_queue_op = gen_data_flow_ops.fake_queue(queue_ref)
      return gen_io_ops.reader_read_up_to(
          self._reader_ref, old_queue_op, num_records, name=name)
    return gen_io_ops.reader_read_up_to_v2(
        self._reader_ref, queue_ref, num_records, name=name)

  def num_records_produced(self, name=None):
    """Returns the number of records this reader has produced.

    This is the same as the number of Read executions that have
    succeeded.

    Args:
      name: A name for the operation (optional).

    Returns:
      An int64 Tensor.
    """
    # Pick the resource (V2) or reference (V1) kernel to match the reader op.
    if self._reader_ref.dtype == dtypes.resource:
      op = gen_io_ops.reader_num_records_produced_v2
    else:
      op = gen_io_ops.reader_num_records_produced
    return op(self._reader_ref, name=name)

  def num_work_units_completed(self, name=None):
    """Returns the number of work units this reader has finished processing.

    Args:
      name: A name for the operation (optional).

    Returns:
      An int64 Tensor.
    """
    if self._reader_ref.dtype == dtypes.resource:
      op = gen_io_ops.reader_num_work_units_completed_v2
    else:
      op = gen_io_ops.reader_num_work_units_completed
    return op(self._reader_ref, name=name)

  def serialize_state(self, name=None):
    """Produce a string tensor that encodes the state of a reader.

    Not all Readers support being serialized, so this can produce an
    Unimplemented error.

    Args:
      name: A name for the operation (optional).

    Returns:
      A string Tensor.
    """
    if self._reader_ref.dtype == dtypes.resource:
      op = gen_io_ops.reader_serialize_state_v2
    else:
      op = gen_io_ops.reader_serialize_state
    return op(self._reader_ref, name=name)

  def restore_state(self, state, name=None):
    """Restore a reader to a previously saved state.

    Not all Readers support being restored, so this can produce an
    Unimplemented error.

    Args:
      state: A string Tensor.
        Result of a SerializeState of a Reader with matching type.
      name: A name for the operation (optional).

    Returns:
      The created Operation.
    """
    if self._reader_ref.dtype == dtypes.resource:
      op = gen_io_ops.reader_restore_state_v2
    else:
      op = gen_io_ops.reader_restore_state
    return op(self._reader_ref, state, name=name)

  @property
  def supports_serialize(self):
    """Whether the Reader implementation can serialize its state."""
    return self._supports_serialize

  def reset(self, name=None):
    """Restore a reader to its initial clean state.

    Args:
      name: A name for the operation (optional).

    Returns:
      The created Operation.
    """
    if self._reader_ref.dtype == dtypes.resource:
      op = gen_io_ops.reader_reset_v2
    else:
      op = gen_io_ops.reader_reset
    return op(self._reader_ref, name=name)
426
427
# Reader ops are stateful bookkeeping ops; none of them has a meaningful
# gradient, so register them all as non-differentiable.
for _reader_op_name in ("ReaderRead", "ReaderReadUpTo",
                        "ReaderNumRecordsProduced",
                        "ReaderNumWorkUnitsCompleted",
                        "ReaderSerializeState", "ReaderRestoreState",
                        "ReaderReset"):
  ops.NotDifferentiable(_reader_op_name)
del _reader_op_name
435
436
@tf_export(v1=["WholeFileReader"])
class WholeFileReader(ReaderBase):
  """A Reader that outputs the entire contents of a file as a value.

  To use, enqueue filenames in a Queue.  The output of Read will
  be a filename (key) and the contents of that file (value).

  See ReaderBase for supported methods.

  @compatibility(eager)
  Readers are not compatible with eager execution. Instead, please
  use `tf.data` to get data into your model.
  @end_compatibility
  """

  @deprecation.deprecated(
      None, "Queue-based input pipelines have been replaced by `tf.data`. Use "
      "`tf.data.Dataset.map(tf.read_file)`.")
  def __init__(self, name=None):
    """Create a WholeFileReader.

    Args:
      name: A name for the operation (optional).
    """
    reader_ref = gen_io_ops.whole_file_reader_v2(name=name)
    # This reader can serialize/restore its position state.
    super(WholeFileReader, self).__init__(reader_ref, supports_serialize=True)
463
464
# Stateful reader op; gradients are not defined.
ops.NotDifferentiable("WholeFileReader")
466
467
@tf_export(v1=["TextLineReader"])
class TextLineReader(ReaderBase):
  """A Reader that outputs the lines of a file delimited by newlines.

  Newlines are stripped from the output.
  See ReaderBase for supported methods.

  @compatibility(eager)
  Readers are not compatible with eager execution. Instead, please
  use `tf.data` to get data into your model.
  @end_compatibility
  """
  # TODO(josh11b): Support serializing and restoring state.

  @deprecation.deprecated(
      None, "Queue-based input pipelines have been replaced by `tf.data`. Use "
      "`tf.data.TextLineDataset`.")
  def __init__(self, skip_header_lines=None, name=None):
    """Create a TextLineReader.

    Args:
      skip_header_lines: An optional int. Defaults to 0.  Number of lines
        to skip from the beginning of every file.
      name: A name for the operation (optional).
    """
    reader_ref = gen_io_ops.text_line_reader_v2(
        skip_header_lines=skip_header_lines, name=name)
    super(TextLineReader, self).__init__(reader_ref)
496
497
# Stateful reader op; gradients are not defined.
ops.NotDifferentiable("TextLineReader")
499
500
@tf_export(v1=["FixedLengthRecordReader"])
class FixedLengthRecordReader(ReaderBase):
  """A Reader that outputs fixed-length records from a file.

  See ReaderBase for supported methods.

  @compatibility(eager)
  Readers are not compatible with eager execution. Instead, please
  use `tf.data` to get data into your model.
  @end_compatibility
  """
  # TODO(josh11b): Support serializing and restoring state.

  @deprecation.deprecated(
      None, "Queue-based input pipelines have been replaced by `tf.data`. Use "
      "`tf.data.FixedLengthRecordDataset`.")
  def __init__(self,
               record_bytes,
               header_bytes=None,
               footer_bytes=None,
               hop_bytes=None,
               name=None,
               encoding=None):
    """Create a FixedLengthRecordReader.

    Args:
      record_bytes: An int.
      header_bytes: An optional int. Defaults to 0.
      footer_bytes: An optional int. Defaults to 0.
      hop_bytes: An optional int. Defaults to 0.
      name: A name for the operation (optional).
      encoding: The type of encoding for the file. Defaults to none.
    """
    reader_ref = gen_io_ops.fixed_length_record_reader_v2(
        record_bytes=record_bytes,
        header_bytes=header_bytes,
        footer_bytes=footer_bytes,
        hop_bytes=hop_bytes,
        encoding=encoding,
        name=name)
    super(FixedLengthRecordReader, self).__init__(reader_ref)
542
543
# Stateful reader op; gradients are not defined.
ops.NotDifferentiable("FixedLengthRecordReader")
545
546
@tf_export(v1=["TFRecordReader"])
class TFRecordReader(ReaderBase):
  """A Reader that outputs the records from a TFRecords file.

  See ReaderBase for supported methods.

  @compatibility(eager)
  Readers are not compatible with eager execution. Instead, please
  use `tf.data` to get data into your model.
  @end_compatibility
  """
  # TODO(josh11b): Support serializing and restoring state.

  @deprecation.deprecated(
      None, "Queue-based input pipelines have been replaced by `tf.data`. Use "
      "`tf.data.TFRecordDataset`.")
  def __init__(self, name=None, options=None):
    """Create a TFRecordReader.

    Args:
      name: A name for the operation (optional).
      options: A TFRecordOptions object (optional).
    """
    # Resolve the (possibly None) options to a compression-type string.
    compression_type = python_io.TFRecordOptions.get_compression_type_string(
        options)
    reader_ref = gen_io_ops.tf_record_reader_v2(
        compression_type=compression_type, name=name)
    super(TFRecordReader, self).__init__(reader_ref)
576
577
# Stateful reader op; gradients are not defined.
ops.NotDifferentiable("TFRecordReader")
579
580
@tf_export(v1=["LMDBReader"])
class LMDBReader(ReaderBase):
  """A Reader that outputs the records from a LMDB file.

  See ReaderBase for supported methods.

  @compatibility(eager)
  Readers are not compatible with eager execution. Instead, please
  use `tf.data` to get data into your model.
  @end_compatibility
  """

  @deprecation.deprecated(
      None, "Queue-based input pipelines have been replaced by `tf.data`. Use "
      "`tf.contrib.data.LMDBDataset`.")
  def __init__(self, name=None, options=None):
    """Create a LMDBReader.

    Args:
      name: A name for the operation (optional).
      options: A LMDBRecordOptions object (optional).
    """
    del options  # Unused: the LMDB reader op takes no options.
    reader_ref = gen_io_ops.lmdb_reader(name=name)
    super(LMDBReader, self).__init__(reader_ref)
606
607
# Stateful reader op; gradients are not defined.
ops.NotDifferentiable("LMDBReader")
609
610
@tf_export(v1=["IdentityReader"])
class IdentityReader(ReaderBase):
  """A Reader that outputs the queued work as both the key and value.

  To use, enqueue strings in a Queue.  Read will take the front
  work string and output (work, work).

  See ReaderBase for supported methods.

  @compatibility(eager)
  Readers are not compatible with eager execution. Instead, please
  use `tf.data` to get data into your model.
  @end_compatibility
  """

  @deprecation.deprecated(
      None, "Queue-based input pipelines have been replaced by `tf.data`. Use "
      "`tf.data.Dataset.map(...)`.")
  def __init__(self, name=None):
    """Create a IdentityReader.

    Args:
      name: A name for the operation (optional).
    """
    reader_ref = gen_io_ops.identity_reader_v2(name=name)
    # This reader's state is trivially serializable.
    super(IdentityReader, self).__init__(reader_ref, supports_serialize=True)
637
638
# Stateful reader op; gradients are not defined.
ops.NotDifferentiable("IdentityReader")
640