1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15# pylint: disable=g-import-not-at-top 16"""Utilities related to disk I/O.""" 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21from collections import defaultdict 22 23import numpy as np 24import six 25from tensorflow.python.util.tf_export import keras_export 26 27 28try: 29 import h5py 30except ImportError: 31 h5py = None 32 33 34@keras_export('keras.utils.HDF5Matrix') 35class HDF5Matrix(object): 36 """Representation of HDF5 dataset to be used instead of a Numpy array. 37 38 Example: 39 40 ```python 41 x_data = HDF5Matrix('input/file.hdf5', 'data') 42 model.predict(x_data) 43 ``` 44 45 Providing `start` and `end` allows use of a slice of the dataset. 46 47 Optionally, a normalizer function (or lambda) can be given. This will 48 be called on every slice of data retrieved. 49 50 Arguments: 51 datapath: string, path to a HDF5 file 52 dataset: string, name of the HDF5 dataset in the file specified 53 in datapath 54 start: int, start of desired slice of the specified dataset 55 end: int, end of desired slice of the specified dataset 56 normalizer: function to be called on data when retrieved 57 58 Returns: 59 An array-like HDF5 dataset. 60 """ 61 refs = defaultdict(int) 62 63 def __init__(self, datapath, dataset, start=0, end=None, normalizer=None): 64 if h5py is None: 65 raise ImportError('The use of HDF5Matrix requires ' 66 'HDF5 and h5py installed.') 67 68 if datapath not in list(self.refs.keys()): 69 f = h5py.File(datapath) 70 self.refs[datapath] = f 71 else: 72 f = self.refs[datapath] 73 self.data = f[dataset] 74 self.start = start 75 if end is None: 76 self.end = self.data.shape[0] 77 else: 78 self.end = end 79 self.normalizer = normalizer 80 81 def __len__(self): 82 return self.end - self.start 83 84 def __getitem__(self, key): 85 if isinstance(key, slice): 86 start, stop = key.start, key.stop 87 if start is None: 88 start = 0 89 if stop is None: 90 stop = self.shape[0] 91 if stop + self.start <= self.end: 92 idx = slice(start + self.start, stop + self.start) 93 else: 94 raise IndexError 95 elif isinstance(key, (int, np.integer)): 96 if key + self.start < self.end: 97 idx = key + self.start 98 else: 99 raise IndexError 100 elif isinstance(key, np.ndarray): 101 if np.max(key) + self.start < self.end: 102 idx = (self.start + key).tolist() 103 else: 104 raise IndexError 105 else: 106 # Assume list/iterable 107 if max(key) + self.start < self.end: 108 idx = [x + self.start for x in key] 109 else: 110 raise IndexError 111 if self.normalizer is not None: 112 return self.normalizer(self.data[idx]) 113 else: 114 return self.data[idx] 115 116 @property 117 def shape(self): 118 """Gets a numpy-style shape tuple giving the dataset dimensions. 119 120 Returns: 121 A numpy-style shape tuple. 122 """ 123 return (self.end - self.start,) + self.data.shape[1:] 124 125 @property 126 def dtype(self): 127 """Gets the datatype of the dataset. 128 129 Returns: 130 A numpy dtype string. 131 """ 132 return self.data.dtype 133 134 @property 135 def ndim(self): 136 """Gets the number of dimensions (rank) of the dataset. 137 138 Returns: 139 An integer denoting the number of dimensions (rank) of the dataset. 140 """ 141 return self.data.ndim 142 143 @property 144 def size(self): 145 """Gets the total dataset size (number of elements). 146 147 Returns: 148 An integer denoting the number of elements in the dataset. 149 """ 150 return np.prod(self.shape) 151 152 153def ask_to_proceed_with_overwrite(filepath): 154 """Produces a prompt asking about overwriting a file. 155 156 Arguments: 157 filepath: the path to the file to be overwritten. 158 159 Returns: 160 True if we can proceed with overwrite, False otherwise. 161 """ 162 overwrite = six.moves.input('[WARNING] %s already exists - overwrite? ' 163 '[y/n]' % (filepath)).strip().lower() 164 while overwrite not in ('y', 'n'): 165 overwrite = six.moves.input('Enter "y" (overwrite) or "n" ' 166 '(cancel).').strip().lower() 167 if overwrite == 'n': 168 return False 169 print('[TIP] Next time specify overwrite=True!') 170 return True 171