# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Extending CheckpointReader for TensorFlow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.util import compat
from tensorflow.python.util._pywrap_checkpoint_reader import CheckpointReader
from tensorflow.python.util.tf_export import tf_export


def error_translator(e):
  """Convert a native checkpoint-reader error into a typed TF error and raise it.

  The text of `e` is matched against known substrings, in a fixed order; the
  first rule whose substrings appear selects the error class. Any message that
  matches no rule is raised as a plain `OpError` with code `UNKNOWN`. This
  function never returns normally.

  Args:
    e: The exception raised by the wrapped C++ reader.

  Raises:
    errors_impl.NotFoundError: Tensor or checkpoint files not found.
    errors_impl.UnimplementedError: Sliced checkpoints / unsupported dtype.
    errors_impl.InvalidArgumentError: File matching failed.
    errors_impl.DataLossError: The table file could not be opened.
    errors_impl.InternalError: Saved slices missing / dtype not convertible.
    errors_impl.OpError: Any other message (error code UNKNOWN).
  """
  # TODO(b/143319754): Remove the RuntimeError casting logic once we resolve the
  # issue with throwing python exceptions from C++.
  message = str(e)
  # Ordered (substrings, error class) pairs. Order matters: the first rule
  # with any matching substring wins.
  rules = (
      (('not found in checkpoint',
        'Failed to find any matching files for'),
       errors_impl.NotFoundError),
      (('Sliced checkpoints are not supported',
        'Data type not supported'),
       errors_impl.UnimplementedError),
      (('Failed to get matching files on',),
       errors_impl.InvalidArgumentError),
      (('Unable to open table file',),
       errors_impl.DataLossError),
      (('Failed to find the saved tensor slices',
        'not convertible to numpy dtype'),
       errors_impl.InternalError),
  )
  for needles, error_cls in rules:
    if any(needle in message for needle in needles):
      raise error_cls(None, None, message)
  raise errors_impl.OpError(None, None, message, errors_impl.UNKNOWN)


def get_variable_to_dtype_map(self):
  """Return a dict from variable name to `dtypes.DType`.

  Each value comes from the native `_GetVariableToDataTypeMap` and is
  wrapped with the `dtypes.DType` constructor.
  """
  raw_map = self._GetVariableToDataTypeMap()  # pylint: disable=protected-access
  name_to_dtype = {}
  for var_name, enum_value in raw_map.items():
    name_to_dtype[var_name] = dtypes.DType(enum_value)
  return name_to_dtype


def has_tensor(self, tensor_str):
  """Return the native `_HasTensor` result for the name `tensor_str`.

  The name is converted with `compat.as_bytes` before the lookup.
  """
  encoded_name = compat.as_bytes(tensor_str)
  return self._HasTensor(encoded_name)  # pylint: disable=protected-access


def get_tensor(self, tensor_str):
  """Get the tensor named `tensor_str` from the Checkpoint object.

  Raises:
    errors_impl.OpError (or a subclass): produced by `error_translator` when
      the native getter raises a RuntimeError.
  """
  native_get = CheckpointReader.CheckpointReader_GetTensor
  try:
    return native_get(self, compat.as_bytes(tensor_str))
  # TODO(b/143319754): Remove the RuntimeError casting logic once we resolve the
  # issue with throwing python exceptions from C++.
  except RuntimeError as err:
    error_translator(err)


# Attach the Python-side helpers to the pybind-exposed reader class.
CheckpointReader.get_variable_to_dtype_map = get_variable_to_dtype_map
CheckpointReader.has_tensor = has_tensor
CheckpointReader.get_tensor = get_tensor


# Disable invalid name to keep backwards compatibility with that function.
# It was previously exported from py_checkpoint_reader.i which did not conform
# to pylint checks.
# pylint: disable=invalid-name
@tf_export(v1=['train.NewCheckpointReader'])
def NewCheckpointReader(filepattern):
  """Construct a `CheckpointReader` for the given checkpoint.

  Args:
    filepattern: The checkpoint filename. It is converted with
      `compat.as_bytes` before being passed to the native constructor.

  Returns:
    A CheckpointReader object.

  Raises:
    errors_impl.OpError (or a subclass): produced by `error_translator` when
      the native constructor raises a RuntimeError.
  """
  # TODO(b/143319754): Remove the RuntimeError casting logic once we resolve the
  # issue with throwing python exceptions from C++.
  try:
    encoded_pattern = compat.as_bytes(filepattern)
    return CheckpointReader(encoded_pattern)
  except RuntimeError as err:
    error_translator(err)