• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Extending CheckpointReader for TensorFlow."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20from tensorflow.python.framework import dtypes
21from tensorflow.python.framework import errors_impl
22from tensorflow.python.util import compat
23from tensorflow.python.util._pywrap_checkpoint_reader import CheckpointReader
24from tensorflow.python.util.tf_export import tf_export
25
26
def error_translator(e):
  """Re-raises an exception from the native checkpoint reader as a TF error.

  The exception's message is compared against an ordered list of known
  substrings emitted by the native reader (e.g. tensor_slice_reader.cc). The
  first rule with a matching substring decides which `errors_impl` exception
  class is raised. A message that matches no rule is raised as a generic
  `errors_impl.OpError` with code `UNKNOWN`. This function never returns
  normally.

  Args:
    e: The exception caught from the reader bindings (callers in this module
      pass a `RuntimeError`). Only `str(e)` is inspected.

  Raises:
    errors_impl.NotFoundError: message contains "not found in checkpoint" or
      "Failed to find any matching files for".
    errors_impl.UnimplementedError: message contains "Sliced checkpoints are
      not supported" or "Data type not supported".
    errors_impl.InvalidArgumentError: message contains "Failed to get matching
      files on".
    errors_impl.DataLossError: message contains "Unable to open table file".
    errors_impl.InternalError: message contains "Failed to find the saved
      tensor slices" or "not convertible to numpy dtype".
    errors_impl.OpError: with code UNKNOWN, for any other message.
  """
  # TODO(b/143319754): Remove the RuntimeError casting logic once we resolve the
  # issue with throwing python exceptions from C++.
  message = str(e)
  # Rules are tried top to bottom and the first hit wins, so the order of this
  # tuple is significant.
  translation_rules = (
      (errors_impl.NotFoundError,
       ('not found in checkpoint', 'Failed to find any matching files for')),
      (errors_impl.UnimplementedError,
       ('Sliced checkpoints are not supported', 'Data type not supported')),
      (errors_impl.InvalidArgumentError, ('Failed to get matching files on',)),
      (errors_impl.DataLossError, ('Unable to open table file',)),
      (errors_impl.InternalError,
       ('Failed to find the saved tensor slices',
        'not convertible to numpy dtype')),
  )
  for error_cls, needles in translation_rules:
    if any(needle in message for needle in needles):
      raise error_cls(None, None, message)
  # Fallback for messages this module does not recognise.
  raise errors_impl.OpError(None, None, message, errors_impl.UNKNOWN)
50
51
def get_variable_to_dtype_map(self):
  """Returns a mapping from checkpoint variable names to `dtypes.DType`.

  The native reader's `_GetVariableToDataTypeMap()` reports each variable's
  type as a raw type enum. This wrapper converts every enum into a
  `dtypes.DType` and leaves the keys unchanged.

  Returns:
    A new `dict` with the same keys as `_GetVariableToDataTypeMap()` and
    `dtypes.DType` values.
  """
  enum_by_name = self._GetVariableToDataTypeMap()  # pylint: disable=protected-access
  return {var_name: dtypes.DType(enum_value)
          for var_name, enum_value in enum_by_name.items()}

# Attach the wrapper as a public method of the native `CheckpointReader` class
# imported from `_pywrap_checkpoint_reader`.
CheckpointReader.get_variable_to_dtype_map = get_variable_to_dtype_map
59
60
def has_tensor(self, tensor_str):
  """Asks the native reader whether the checkpoint holds `tensor_str`.

  Args:
    tensor_str: Name of the tensor, as text or bytes. It is encoded with
      `compat.as_bytes` before being passed to `_HasTensor`.

  Returns:
    The result of the native `_HasTensor` lookup (presumably a bool; confirm
    against the binding).
  """
  encoded_name = compat.as_bytes(tensor_str)
  return self._HasTensor(encoded_name)  # pylint: disable=protected-access

# Attach the wrapper as a public method of the native `CheckpointReader` class.
CheckpointReader.has_tensor = has_tensor
65
66
def get_tensor(self, tensor_str):
  """Reads the tensor named `tensor_str` from the checkpoint.

  Args:
    tensor_str: Name of the tensor, as text or bytes. It is encoded with
      `compat.as_bytes` before the lookup.

  Returns:
    Whatever `CheckpointReader.CheckpointReader_GetTensor` returns for that
    name (presumably a numpy array; confirm against the binding).

  Raises:
    errors_impl.OpError (or a subclass): any `RuntimeError` raised by the
      native reader is re-raised through `error_translator`.
  """
  encoded_name = compat.as_bytes(tensor_str)
  # TODO(b/143319754): Remove the RuntimeError casting logic once we resolve the
  # issue with throwing python exceptions from C++.
  try:
    return CheckpointReader.CheckpointReader_GetTensor(self, encoded_name)
  except RuntimeError as err:
    # `error_translator` always raises, so control never falls off the end.
    error_translator(err)


# Attach the wrapper as a public method of the native `CheckpointReader` class.
CheckpointReader.get_tensor = get_tensor
79
80
# NewCheckpointReader keeps its historical CamelCase name for backwards
# compatibility. It used to be exported from py_checkpoint_reader.i, which did
# not follow pylint naming rules, so the invalid-name check is disabled here.
# pylint: disable=invalid-name
@tf_export(v1=['train.NewCheckpointReader'])
def NewCheckpointReader(filepattern):
  """Creates a `CheckpointReader` for the checkpoint matching `filepattern`.

  Args:
    filepattern: Checkpoint file name or pattern, as text or bytes. It is
      encoded with `compat.as_bytes` before being passed to
      `CheckpointReader`.

  Returns:
    A `CheckpointReader` object.

  Raises:
    errors_impl.OpError (or a subclass): any `RuntimeError` raised while
      constructing the reader is re-raised through `error_translator`.
      For example, a "Failed to find any matching files for" message becomes
      `errors_impl.NotFoundError`.
  """
  encoded_pattern = compat.as_bytes(filepattern)
  # TODO(b/143319754): Remove the RuntimeError casting logic once we resolve the
  # issue with throwing python exceptions from C++.
  try:
    return CheckpointReader(encoded_pattern)
  except RuntimeError as err:
    # `error_translator` always raises; nothing is returned on this path.
    error_translator(err)
101