# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

# pylint: disable=line-too-long
"""Inputs and Readers.

See the [Inputs and
Readers](https://tensorflow.org/api_guides/python/io_ops) guide.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.lib.io import python_io
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import gen_io_ops
from tensorflow.python.ops import gen_parsing_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_io_ops import *
# pylint: enable=wildcard-import
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch as _dispatch
from tensorflow.python.util.tf_export import tf_export


# pylint: disable=protected-access
def _save(filename, tensor_names, tensors, tensor_slices=None, name="save"):
  """Save a list of tensors to a file with given names.

  Example usage without slice info:
    Save("/foo/bar", ["w", "b"], [w, b])

  Example usage with slices:
    Save("/foo/bar", ["w", "w"], [slice0, slice1],
         tensor_slices=["4 10 0,2:-", "4 10 2,2:-"])

  Args:
    filename: the file name of the sstable.
    tensor_names: a list of strings.
    tensors: the list of tensors to be saved.
    tensor_slices: Optional list of strings to specify the shape and slices of
      a larger virtual tensor that each tensor is a part of. If not specified
      each tensor is saved as a full slice.
    name: string. Optional name for the op.

  Requires:
    The length of tensors should match the size of tensor_names and of
    tensor_slices.

  Returns:
    An Operation that saves the tensors.
  """
  if tensor_slices is None:
    return gen_io_ops.save(filename, tensor_names, tensors, name=name)
  else:
    return gen_io_ops.save_slices(filename, tensor_names, tensor_slices,
                                  tensors, name=name)


def _restore_slice(file_pattern, tensor_name, shape_and_slice, tensor_type,
                   name="restore_slice", preferred_shard=-1):
  """Restore a tensor slice from a set of files with a given pattern.

  Example usage:
    RestoreSlice("/foo/bar-?????-of-?????", "w", "10 10 0,2:-", DT_FLOAT)

  Args:
    file_pattern: the file pattern used to match a set of checkpoint files.
    tensor_name: the name of the tensor to restore.
    shape_and_slice: the shape-and-slice spec of the slice.
    tensor_type: the type of the tensor to restore.
    name: string. Optional name for the op.
    preferred_shard: Int. Optional shard to open first in the checkpoint file.

  Returns:
    A tensor of type "tensor_type".
  """
  base_type = dtypes.as_dtype(tensor_type).base_dtype
  return gen_io_ops.restore_slice(
      file_pattern, tensor_name, shape_and_slice, base_type,
      preferred_shard, name=name)
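

# The shape-and-slice strings above follow the format used by
# `gen_io_ops.save_slices`/`restore_slice`: the dimensions of the full virtual
# tensor, a space, then one "offset,length" pair per dimension joined by ":",
# where "-" means the whole dimension. A minimal illustrative sketch; the
# helper below is hypothetical and not part of this module's API:
def _example_shape_and_slice_spec(full_shape, offsets, lengths):
  """Builds a spec string, e.g. "4 10 0,2:-" for rows [0, 2) of a 4x10 tensor."""
  dims = " ".join(str(dim) for dim in full_shape)
  # Use None for offset/length to select an entire dimension ("-").
  slices = ":".join(
      "-" if length is None else "%d,%d" % (offset, length)
      for offset, length in zip(offsets, lengths))
  return "%s %s" % (dims, slices)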


@_dispatch.add_dispatch_list
@tf_export("io.read_file", v1=["io.read_file", "read_file"])
def read_file(filename, name=None):
  """Reads the contents of a file.

  This operation returns a tensor with the entire contents of the input
  filename. It does not do any parsing; it just returns the contents as
  they are. Usually, this is the first step in the input pipeline.

  Example:

  >>> with open("/tmp/file.txt", "w") as f:
  ...   f.write("asdf")
  ...
  4
  >>> tf.io.read_file("/tmp/file.txt")
  <tf.Tensor: shape=(), dtype=string, numpy=b'asdf'>

  Example of using the op in a function to read an image, decode it and reshape
  the tensor containing the pixel data:

  >>> @tf.function
  ... def load_image(filename):
  ...   raw = tf.io.read_file(filename)
  ...   image = tf.image.decode_png(raw, channels=3)
  ...   # the `print` executes during tracing.
  ...   print("Initial shape: ", image.shape)
  ...   image.set_shape([28, 28, 3])
  ...   print("Final shape: ", image.shape)
  ...   return image

  Args:
    filename: string. The filename to read from.
    name: string. Optional name for the op.

  Returns:
    A tensor of dtype "string", with the file contents.
  """
  return gen_io_ops.read_file(filename, name)


@_dispatch.add_dispatch_list
@tf_export(
    "io.serialize_tensor", v1=["io.serialize_tensor", "serialize_tensor"])
def serialize_tensor(tensor, name=None):
  r"""Transforms a Tensor into a serialized `TensorProto`.

  This operation transforms data in a `tf.Tensor` into a `tf.Tensor` of type
  `tf.string` containing the data in a binary string format. This operation can
  transform scalar data and linear arrays, but it is most useful in converting
  multidimensional arrays into a format accepted by binary storage formats such
  as a `TFRecord` or `tf.train.Example`.

  See also:
  - `tf.io.parse_tensor`: inverse operation of `tf.io.serialize_tensor` that
  transforms a scalar string containing a serialized Tensor into a Tensor of a
  specified type.
  - `tf.ensure_shape`: `parse_tensor` cannot statically determine the shape of
  the parsed tensor. Use `tf.ensure_shape` to set the static shape when running
  under a `tf.function`.
  - `.SerializeToString`: serializes a proto to a binary string.

  Example of serializing scalar data:

  >>> t = tf.constant(1)
  >>> tf.io.serialize_tensor(t)
  <tf.Tensor: shape=(), dtype=string, numpy=b'\x08...\x00'>

  Example of storing non-scalar data into a `tf.train.Example`:

  >>> t1 = [[1, 2]]
  >>> t2 = [[7, 8]]
  >>> nonscalar = tf.concat([t1, t2], 0)
  >>> nonscalar
  <tf.Tensor: shape=(2, 2), dtype=int32, numpy=
  array([[1, 2],
         [7, 8]], dtype=int32)>

  Serialize the data using `tf.io.serialize_tensor`.

  >>> serialized_nonscalar = tf.io.serialize_tensor(nonscalar)
  >>> serialized_nonscalar
  <tf.Tensor: shape=(), dtype=string, numpy=b'\x08...\x00'>

  Store the data in a `tf.train.Feature`.

  >>> feature_of_bytes = tf.train.Feature(
  ...   bytes_list=tf.train.BytesList(value=[serialized_nonscalar.numpy()]))
  >>> feature_of_bytes
  bytes_list {
    value: "\010...\000"
  }

  Put the `tf.train.Feature` message into a `tf.train.Example`.

  >>> features_for_example = {
  ...   'feature0': feature_of_bytes
  ... }
  >>> example_proto = tf.train.Example(
  ...   features=tf.train.Features(feature=features_for_example))
  >>> example_proto
  features {
    feature {
      key: "feature0"
      value {
        bytes_list {
          value: "\010...\000"
        }
      }
    }
  }

  Args:
    tensor: A `tf.Tensor`.
    name: string. Optional name for the op.

  Returns:
    A Tensor of dtype string.
  """
  return gen_parsing_ops.serialize_tensor(tensor, name)
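

# A minimal sketch of the round trip described above: `tf.io.parse_tensor` is
# the inverse of `tf.io.serialize_tensor`, and `tf.ensure_shape` restores the
# static shape that parsing cannot infer. Illustrative only; the values are
# hypothetical:
#
#   serialized = tf.io.serialize_tensor(tf.constant([[1, 2], [7, 8]]))
#   parsed = tf.io.parse_tensor(serialized, out_type=tf.int32)
#   parsed = tf.ensure_shape(parsed, [2, 2])  # pin the shape under tf.function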


@tf_export(v1=["ReaderBase"])
class ReaderBase(object):
  """Base class for different Reader types that produce a record every step.

  Conceptually, Readers convert string 'work units' into records (key,
  value pairs). Typically the 'work units' are filenames and the
  records are extracted from the contents of those files. We want a
  single record produced per step, but a work unit can correspond to
  many records.

  Therefore we introduce some decoupling using a queue. The queue
  contains the work units and the Reader dequeues from the queue when
  it is asked to produce a record (via Read()) but has finished the
  last work unit.

  @compatibility(eager)
  Readers are not compatible with eager execution. Instead, please
  use `tf.data` to get data into your model.
  @end_compatibility
  """

  def __init__(self, reader_ref, supports_serialize=False):
    """Creates a new ReaderBase.

    Args:
      reader_ref: The operation that implements the reader.
      supports_serialize: True if the reader implementation can
        serialize its state.

    Raises:
      RuntimeError: If eager execution is enabled.
    """
    if context.executing_eagerly():
      raise RuntimeError(
          "Readers are not supported when eager execution is enabled. "
          "Instead, please use tf.data to get data into your model.")

    self._reader_ref = reader_ref
    self._supports_serialize = supports_serialize

  @property
  def reader_ref(self):
    """Op that implements the reader."""
    return self._reader_ref

  def read(self, queue, name=None):
    """Returns the next record (key, value) pair produced by a reader.

    Will dequeue a work unit from queue if necessary (e.g. when the
    Reader needs to start reading from a new file since it has
    finished with the previous file).

    Args:
      queue: A Queue or a mutable string Tensor representing a handle
        to a Queue, with string work items.
      name: A name for the operation (optional).

    Returns:
      A tuple of Tensors (key, value).
      key: A string scalar Tensor.
      value: A string scalar Tensor.
    """
    if isinstance(queue, ops.Tensor):
      queue_ref = queue
    else:
      queue_ref = queue.queue_ref
    if self._reader_ref.dtype == dtypes.resource:
      return gen_io_ops.reader_read_v2(self._reader_ref, queue_ref, name=name)
    else:
      # For compatibility with pre-resource queues, create a ref(string) tensor
      # which can be looked up as the same queue by a resource manager.
      old_queue_op = gen_data_flow_ops.fake_queue(queue_ref)
      return gen_io_ops.reader_read(self._reader_ref, old_queue_op, name=name)

  def read_up_to(self, queue, num_records,  # pylint: disable=invalid-name
                 name=None):
    """Returns up to num_records (key, value) pairs produced by a reader.

    Will dequeue a work unit from queue if necessary (e.g., when the
    Reader needs to start reading from a new file since it has
    finished with the previous file).
    It may return fewer than num_records even before the last batch.

    Args:
      queue: A Queue or a mutable string Tensor representing a handle
        to a Queue, with string work items.
      num_records: Number of records to read.
      name: A name for the operation (optional).

    Returns:
      A tuple of Tensors (keys, values).
      keys: A 1-D string Tensor.
      values: A 1-D string Tensor.
    """
    if isinstance(queue, ops.Tensor):
      queue_ref = queue
    else:
      queue_ref = queue.queue_ref
    if self._reader_ref.dtype == dtypes.resource:
      return gen_io_ops.reader_read_up_to_v2(self._reader_ref,
                                             queue_ref,
                                             num_records,
                                             name=name)
    else:
      # For compatibility with pre-resource queues, create a ref(string) tensor
      # which can be looked up as the same queue by a resource manager.
      old_queue_op = gen_data_flow_ops.fake_queue(queue_ref)
      return gen_io_ops.reader_read_up_to(self._reader_ref,
                                          old_queue_op,
                                          num_records,
                                          name=name)

  def num_records_produced(self, name=None):
    """Returns the number of records this reader has produced.

    This is the same as the number of Read executions that have
    succeeded.

    Args:
      name: A name for the operation (optional).

    Returns:
      An int64 Tensor.
    """
    if self._reader_ref.dtype == dtypes.resource:
      return gen_io_ops.reader_num_records_produced_v2(self._reader_ref,
                                                       name=name)
    else:
      return gen_io_ops.reader_num_records_produced(self._reader_ref,
                                                    name=name)

  def num_work_units_completed(self, name=None):
    """Returns the number of work units this reader has finished processing.

    Args:
      name: A name for the operation (optional).

    Returns:
      An int64 Tensor.
    """
    if self._reader_ref.dtype == dtypes.resource:
      return gen_io_ops.reader_num_work_units_completed_v2(self._reader_ref,
                                                           name=name)
    else:
      return gen_io_ops.reader_num_work_units_completed(self._reader_ref,
                                                        name=name)

  def serialize_state(self, name=None):
    """Produce a string tensor that encodes the state of a reader.

    Not all Readers support being serialized, so this can produce an
    Unimplemented error.

    Args:
      name: A name for the operation (optional).

    Returns:
      A string Tensor.
    """
    if self._reader_ref.dtype == dtypes.resource:
      return gen_io_ops.reader_serialize_state_v2(self._reader_ref, name=name)
    else:
      return gen_io_ops.reader_serialize_state(self._reader_ref, name=name)

  def restore_state(self, state, name=None):
    """Restore a reader to a previously saved state.

    Not all Readers support being restored, so this can produce an
    Unimplemented error.

    Args:
      state: A string Tensor.
        Result of a SerializeState of a Reader with matching type.
      name: A name for the operation (optional).

    Returns:
      The created Operation.
    """
    if self._reader_ref.dtype == dtypes.resource:
      return gen_io_ops.reader_restore_state_v2(
          self._reader_ref, state, name=name)
    else:
      return gen_io_ops.reader_restore_state(self._reader_ref, state, name=name)

  @property
  def supports_serialize(self):
    """Whether the Reader implementation can serialize its state."""
    return self._supports_serialize

  def reset(self, name=None):
    """Restore a reader to its initial clean state.

    Args:
      name: A name for the operation (optional).

    Returns:
      The created Operation.
    """
    if self._reader_ref.dtype == dtypes.resource:
      return gen_io_ops.reader_reset_v2(self._reader_ref, name=name)
    else:
      return gen_io_ops.reader_reset(self._reader_ref, name=name)
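

# A minimal graph-mode sketch of the queue/Reader decoupling described in
# ReaderBase: filenames go into a queue, and each Read() pulls the next record,
# dequeuing a new work unit only when the previous one is exhausted. The
# filenames are hypothetical and the TF1 public API is assumed, so this is
# illustrative only:
#
#   filename_queue = tf.train.string_input_producer(["file0.csv", "file1.csv"])
#   reader = tf.TextLineReader()
#   key, value = reader.read(filename_queue)
#   with tf.Session() as sess:
#     coord = tf.train.Coordinator()
#     threads = tf.train.start_queue_runners(sess=sess, coord=coord)
#     print(sess.run([key, value]))  # first line of the first file
#     coord.request_stop()
#     coord.join(threads)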


ops.NotDifferentiable("ReaderRead")
ops.NotDifferentiable("ReaderReadUpTo")
ops.NotDifferentiable("ReaderNumRecordsProduced")
ops.NotDifferentiable("ReaderNumWorkUnitsCompleted")
ops.NotDifferentiable("ReaderSerializeState")
ops.NotDifferentiable("ReaderRestoreState")
ops.NotDifferentiable("ReaderReset")


@tf_export(v1=["WholeFileReader"])
class WholeFileReader(ReaderBase):
  """A Reader that outputs the entire contents of a file as a value.

  To use, enqueue filenames in a Queue. The output of Read will
  be a filename (key) and the contents of that file (value).

  See ReaderBase for supported methods.

  @compatibility(eager)
  Readers are not compatible with eager execution. Instead, please
  use `tf.data` to get data into your model.
  @end_compatibility
  """

  @deprecation.deprecated(
      None, "Queue-based input pipelines have been replaced by `tf.data`. Use "
      "`tf.data.Dataset.map(tf.read_file)`.")
  def __init__(self, name=None):
    """Create a WholeFileReader.

    Args:
      name: A name for the operation (optional).
    """
    rr = gen_io_ops.whole_file_reader_v2(name=name)
    super(WholeFileReader, self).__init__(rr, supports_serialize=True)


ops.NotDifferentiable("WholeFileReader")
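

# A sketch of the `tf.data` replacement suggested in the deprecation message
# above; the "*.png" file pattern is hypothetical:
#
#   dataset = tf.data.Dataset.list_files("*.png")
#   contents = dataset.map(tf.io.read_file)  # elements are raw file contents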


@tf_export(v1=["TextLineReader"])
class TextLineReader(ReaderBase):
  """A Reader that outputs the lines of a file delimited by newlines.

  Newlines are stripped from the output.
  See ReaderBase for supported methods.

  @compatibility(eager)
  Readers are not compatible with eager execution. Instead, please
  use `tf.data` to get data into your model.
  @end_compatibility
  """
  # TODO(josh11b): Support serializing and restoring state.

  @deprecation.deprecated(
      None, "Queue-based input pipelines have been replaced by `tf.data`. Use "
      "`tf.data.TextLineDataset`.")
  def __init__(self, skip_header_lines=None, name=None):
    """Create a TextLineReader.

    Args:
      skip_header_lines: An optional int. Defaults to 0. Number of lines
        to skip from the beginning of every file.
      name: A name for the operation (optional).
    """
    rr = gen_io_ops.text_line_reader_v2(skip_header_lines=skip_header_lines,
                                        name=name)
    super(TextLineReader, self).__init__(rr)


ops.NotDifferentiable("TextLineReader")


@tf_export(v1=["FixedLengthRecordReader"])
class FixedLengthRecordReader(ReaderBase):
  """A Reader that outputs fixed-length records from a file.

  See ReaderBase for supported methods.

  @compatibility(eager)
  Readers are not compatible with eager execution. Instead, please
  use `tf.data` to get data into your model.
  @end_compatibility
  """
  # TODO(josh11b): Support serializing and restoring state.

  @deprecation.deprecated(
      None, "Queue-based input pipelines have been replaced by `tf.data`. Use "
      "`tf.data.FixedLengthRecordDataset`.")
  def __init__(self,
               record_bytes,
               header_bytes=None,
               footer_bytes=None,
               hop_bytes=None,
               name=None,
               encoding=None):
    """Create a FixedLengthRecordReader.

    Args:
      record_bytes: An int.
      header_bytes: An optional int. Defaults to 0.
      footer_bytes: An optional int. Defaults to 0.
      hop_bytes: An optional int. Defaults to 0.
      name: A name for the operation (optional).
      encoding: The type of encoding for the file. Defaults to None.
    """
    rr = gen_io_ops.fixed_length_record_reader_v2(
        record_bytes=record_bytes,
        header_bytes=header_bytes,
        footer_bytes=footer_bytes,
        hop_bytes=hop_bytes,
        encoding=encoding,
        name=name)
    super(FixedLengthRecordReader, self).__init__(rr)


ops.NotDifferentiable("FixedLengthRecordReader")
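

# Sketches of the `tf.data` replacements named in the deprecation messages
# above; the filenames and byte counts are hypothetical:
#
#   lines = tf.data.TextLineDataset("data.txt")  # one element per line
#   records = tf.data.FixedLengthRecordDataset(
#       "data.bin", record_bytes=16, header_bytes=4)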


@tf_export(v1=["TFRecordReader"])
class TFRecordReader(ReaderBase):
  """A Reader that outputs the records from a TFRecords file.

  See ReaderBase for supported methods.

  @compatibility(eager)
  Readers are not compatible with eager execution. Instead, please
  use `tf.data` to get data into your model.
  @end_compatibility
  """
  # TODO(josh11b): Support serializing and restoring state.

  @deprecation.deprecated(
      None, "Queue-based input pipelines have been replaced by `tf.data`. Use "
      "`tf.data.TFRecordDataset`.")
  def __init__(self, name=None, options=None):
    """Create a TFRecordReader.

    Args:
      name: A name for the operation (optional).
      options: A TFRecordOptions object (optional).
    """
    compression_type = python_io.TFRecordOptions.get_compression_type_string(
        options)

    rr = gen_io_ops.tf_record_reader_v2(
        name=name, compression_type=compression_type)
    super(TFRecordReader, self).__init__(rr)


ops.NotDifferentiable("TFRecordReader")


@tf_export(v1=["LMDBReader"])
class LMDBReader(ReaderBase):
  """A Reader that outputs the records from an LMDB file.

  See ReaderBase for supported methods.

  @compatibility(eager)
  Readers are not compatible with eager execution. Instead, please
  use `tf.data` to get data into your model.
  @end_compatibility
  """

  @deprecation.deprecated(
      None, "Queue-based input pipelines have been replaced by `tf.data`. Use "
      "`tf.contrib.data.LMDBDataset`.")
  def __init__(self, name=None, options=None):
    """Create an LMDBReader.

    Args:
      name: A name for the operation (optional).
      options: A LMDBRecordOptions object (optional).
    """
    del options
    rr = gen_io_ops.lmdb_reader(name=name)
    super(LMDBReader, self).__init__(rr)


ops.NotDifferentiable("LMDBReader")
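

# A sketch of the `tf.data` replacement named in the deprecation message above;
# the filename and compression type are hypothetical:
#
#   dataset = tf.data.TFRecordDataset("data.tfrecords",
#                                     compression_type="GZIP")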


@tf_export(v1=["IdentityReader"])
class IdentityReader(ReaderBase):
  """A Reader that outputs the queued work as both the key and value.

  To use, enqueue strings in a Queue. Read will take the front
  work string and output (work, work).

  See ReaderBase for supported methods.

  @compatibility(eager)
  Readers are not compatible with eager execution. Instead, please
  use `tf.data` to get data into your model.
  @end_compatibility
  """

  @deprecation.deprecated(
      None, "Queue-based input pipelines have been replaced by `tf.data`. Use "
      "`tf.data.Dataset.map(...)`.")
  def __init__(self, name=None):
    """Create an IdentityReader.

    Args:
      name: A name for the operation (optional).
    """
    rr = gen_io_ops.identity_reader_v2(name=name)
    super(IdentityReader, self).__init__(rr, supports_serialize=True)


ops.NotDifferentiable("IdentityReader")
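

# A sketch of the `tf.data` equivalent of IdentityReader's (work, work) output;
# the input strings are hypothetical:
#
#   dataset = tf.data.Dataset.from_tensor_slices(["unit0", "unit1"])
#   pairs = dataset.map(lambda work: (work, work))  # (key, value) pairs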