1# Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15r"""Convert checkpoints using RNNCells to new name convention. 16 17Usage: 18 19 python checkpoint_convert.py [--write_v1_checkpoint] \ 20 '/path/to/checkpoint' '/path/to/new_checkpoint' 21 22For example, if there is a V2 checkpoint to be converted and the files include: 23 /tmp/my_checkpoint/model.ckpt.data-00000-of-00001 24 /tmp/my_checkpoint/model.ckpt.index 25 /tmp/my_checkpoint/model.ckpt.meta 26 27use the following command: 28 mkdir /tmp/my_converted_checkpoint && 29 python checkpoint_convert.py \ 30 /tmp/my_checkpoint/model.ckpt /tmp/my_converted_checkpoint/model.ckpt 31 32This will generate three converted checkpoint files corresponding to the three 33old ones in the new directory: 34 /tmp/my_converted_checkpoint/model.ckpt.data-00000-of-00001 35 /tmp/my_converted_checkpoint/model.ckpt.index 36 /tmp/my_converted_checkpoint/model.ckpt.meta 37""" 38from __future__ import absolute_import 39from __future__ import division 40from __future__ import print_function 41 42import argparse 43import collections 44import re 45import sys 46 47from tensorflow.core.protobuf import saver_pb2 48from tensorflow.python import pywrap_tensorflow 49from tensorflow.python.client import session 50from tensorflow.python.framework import ops 51from tensorflow.python.ops import variables 
from tensorflow.python.platform import app
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import saver as saver_lib

# Mapping between old <=> new names. Externalized so that user scripts that
# may need to consume multiple checkpoint formats can use this metadata.
RNN_NAME_REPLACEMENTS = collections.OrderedDict([
    ############################################################################
    # contrib/rnn/python/ops/core_rnn_cell_impl.py
    # BasicRNNCell
    ('basic_rnn_cell/weights', 'basic_rnn_cell/kernel'),
    ('basic_rnn_cell/biases', 'basic_rnn_cell/bias'),
    # GRUCell
    ('gru_cell/weights', 'gru_cell/kernel'),
    ('gru_cell/biases', 'gru_cell/bias'),
    ('gru_cell/gates/weights', 'gru_cell/gates/kernel'),
    ('gru_cell/gates/biases', 'gru_cell/gates/bias'),
    ('gru_cell/candidate/weights', 'gru_cell/candidate/kernel'),
    ('gru_cell/candidate/biases', 'gru_cell/candidate/bias'),
    # BasicLSTMCell
    ('basic_lstm_cell/weights', 'basic_lstm_cell/kernel'),
    ('basic_lstm_cell/biases', 'basic_lstm_cell/bias'),
    # LSTMCell
    ('lstm_cell/weights', 'lstm_cell/kernel'),
    ('lstm_cell/biases', 'lstm_cell/bias'),
    ('lstm_cell/projection/weights', 'lstm_cell/projection/kernel'),
    ('lstm_cell/projection/biases', 'lstm_cell/projection/bias'),
    # OutputProjectionWrapper
    ('output_projection_wrapper/weights', 'output_projection_wrapper/kernel'),
    ('output_projection_wrapper/biases', 'output_projection_wrapper/bias'),
    # InputProjectionWrapper
    ('input_projection_wrapper/weights', 'input_projection_wrapper/kernel'),
    ('input_projection_wrapper/biases', 'input_projection_wrapper/bias'),
    ############################################################################
    # contrib/rnn/python/ops/lstm_ops.py
    # LSTMBlockFusedCell ??
    ('lstm_block_wrapper/weights', 'lstm_block_wrapper/kernel'),
    ('lstm_block_wrapper/biases', 'lstm_block_wrapper/bias'),
    ############################################################################
    # contrib/rnn/python/ops/rnn_cell.py
    # LayerNormBasicLSTMCell
    ('layer_norm_basic_lstm_cell/weights', 'layer_norm_basic_lstm_cell/kernel'),
    ('layer_norm_basic_lstm_cell/biases', 'layer_norm_basic_lstm_cell/bias'),
    # UGRNNCell, not found in g3, but still need it?
    ('ugrnn_cell/weights', 'ugrnn_cell/kernel'),
    ('ugrnn_cell/biases', 'ugrnn_cell/bias'),
    # NASCell
    ('nas_rnn/weights', 'nas_rnn/kernel'),
    ('nas_rnn/recurrent_weights', 'nas_rnn/recurrent_kernel'),
    # IntersectionRNNCell
    ('intersection_rnn_cell/weights', 'intersection_rnn_cell/kernel'),
    ('intersection_rnn_cell/biases', 'intersection_rnn_cell/bias'),
    ('intersection_rnn_cell/in_projection/weights',
     'intersection_rnn_cell/in_projection/kernel'),
    ('intersection_rnn_cell/in_projection/biases',
     'intersection_rnn_cell/in_projection/bias'),
    # PhasedLSTMCell
    ('phased_lstm_cell/mask_gates/weights',
     'phased_lstm_cell/mask_gates/kernel'),
    ('phased_lstm_cell/mask_gates/biases', 'phased_lstm_cell/mask_gates/bias'),
    ('phased_lstm_cell/new_input/weights', 'phased_lstm_cell/new_input/kernel'),
    ('phased_lstm_cell/new_input/biases', 'phased_lstm_cell/new_input/bias'),
    ('phased_lstm_cell/output_gate/weights',
     'phased_lstm_cell/output_gate/kernel'),
    ('phased_lstm_cell/output_gate/biases',
     'phased_lstm_cell/output_gate/bias'),
    # AttentionCellWrapper
    ('attention_cell_wrapper/weights', 'attention_cell_wrapper/kernel'),
    ('attention_cell_wrapper/biases', 'attention_cell_wrapper/bias'),
    ('attention_cell_wrapper/attn_output_projection/weights',
     'attention_cell_wrapper/attn_output_projection/kernel'),
    ('attention_cell_wrapper/attn_output_projection/biases',
     'attention_cell_wrapper/attn_output_projection/bias'),
    ('attention_cell_wrapper/attention/weights',
     'attention_cell_wrapper/attention/kernel'),
    ('attention_cell_wrapper/attention/biases',
     'attention_cell_wrapper/attention/bias'),
    ############################################################################
    # contrib/legacy_seq2seq/python/ops/seq2seq.py
    ('attention_decoder/weights', 'attention_decoder/kernel'),
    ('attention_decoder/biases', 'attention_decoder/bias'),
    ('attention_decoder/Attention_0/weights',
     'attention_decoder/Attention_0/kernel'),
    ('attention_decoder/Attention_0/biases',
     'attention_decoder/Attention_0/bias'),
    ('attention_decoder/AttnOutputProjection/weights',
     'attention_decoder/AttnOutputProjection/kernel'),
    ('attention_decoder/AttnOutputProjection/biases',
     'attention_decoder/AttnOutputProjection/bias'),
    # contrib/legacy_seq2seq/python/ops/seq2seq.py before cl/140060366
    ('attention_decoder/Attention_0/Linear/Bias',
     'attention_decoder/Attention_0/bias'),
    ('attention_decoder/Attention_0/Linear/Matrix',
     'attention_decoder/Attention_0/kernel'),
    ('attention_decoder/AttnOutputProjection/Linear/Bias',
     'attention_decoder/AttnOutputProjection/bias'),
    ('attention_decoder/AttnOutputProjection/Linear/Matrix',
     'attention_decoder/AttnOutputProjection/kernel'),
    ('attention_decoder/LSTMCell/B', 'attention_decoder/lstm_cell/bias'),
    ('attention_decoder/LSTMCell/W_0', 'attention_decoder/lstm_cell/kernel'),
    ('attention_decoder/Linear/Bias', 'attention_decoder/bias'),
    ('attention_decoder/Linear/Matrix', 'attention_decoder/kernel')
])

# Old-name prefixes of sharded variables (e.g. 'LSTMCell/W_0', 'LSTMCell/W_1')
# mapped to their new sharded-name prefixes.
_RNN_SHARDED_NAME_REPLACEMENTS = collections.OrderedDict([
    ('LSTMCell/W_', 'lstm_cell/weights/part_'),
    ('BasicLSTMCell/Linear/Matrix_', 'basic_lstm_cell/weights/part_'),
    ('GRUCell/W_', 'gru_cell/weights/part_'),
    ('MultiRNNCell/Cell', 'multi_rnn_cell/cell_'),
])


def _rnn_name_replacement(var_name):
  """Maps a non-sharded variable name to the new naming convention.

  Only the first matching pattern in RNN_NAME_REPLACEMENTS is applied
  (iteration order is the OrderedDict insertion order).

  Args:
    var_name: Old variable name.

  Returns:
    The converted name, or `var_name` unchanged if no pattern matched.
  """
  for pattern in RNN_NAME_REPLACEMENTS:
    if pattern in var_name:
      old_var_name = var_name
      var_name = var_name.replace(pattern, RNN_NAME_REPLACEMENTS[pattern])
      # Lazy %-args so formatting only happens if the message is emitted.
      logging.info('Converted: %s --> %s', old_var_name, var_name)
      break
  return var_name


def _rnn_name_replacement_sharded(var_name):
  """Maps a sharded variable name to the new naming convention.

  Unlike `_rnn_name_replacement`, all matching patterns are applied (no
  `break`), preserving the original behavior.

  Args:
    var_name: Old sharded variable name.

  Returns:
    The converted name, or `var_name` unchanged if no pattern matched.
  """
  for pattern in _RNN_SHARDED_NAME_REPLACEMENTS:
    if pattern in var_name:
      old_var_name = var_name
      var_name = var_name.replace(pattern,
                                  _RNN_SHARDED_NAME_REPLACEMENTS[pattern])
      logging.info('Converted: %s --> %s', old_var_name, var_name)
  return var_name


def _split_sharded_vars(name_shape_map):
  """Splits sharded variables from non-sharded ones.

  A variable is considered sharded when its name ends in a '_<digits>'
  suffix and a sibling with the '_1' suffix also exists in the checkpoint.

  Args:
    name_shape_map: A dict from variable name to variable shape.

  Returns:
    not_sharded: Names of the non-sharded variables.
    sharded: Names of the sharded variables.
  """
  sharded = []
  not_sharded = []
  for name in name_shape_map:
    # BUGFIX: the original called re.match(name, '_[0-9]+$'), passing the
    # *name* as the pattern and the suffix regex as the string; and re.match
    # anchors at the start of the string anyway. Use re.search on the name so
    # a trailing shard suffix such as '_0' is actually detected.
    if re.search(r'_[0-9]+$', name):
      if re.sub(r'_[0-9]+$', '_1', name) in name_shape_map:
        sharded.append(name)
      else:
        not_sharded.append(name)
    else:
      not_sharded.append(name)
  return not_sharded, sharded


def convert_names(checkpoint_from_path,
                  checkpoint_to_path,
                  write_v1_checkpoint=False):
  """Migrates the names of variables within a checkpoint.

  Args:
    checkpoint_from_path: Path to source checkpoint to be read in.
    checkpoint_to_path: Path to checkpoint to be written out.
    write_v1_checkpoint: Whether the output checkpoint will be in V1 format.

  Returns:
    A dictionary that maps the new variable names to the Variable objects.
    A dictionary that maps the old variable names to the new variable names.
  """
  with ops.Graph().as_default():
    logging.info('Reading checkpoint_from_path %s', checkpoint_from_path)
    reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_from_path)
    name_shape_map = reader.get_variable_to_shape_map()
    not_sharded, sharded = _split_sharded_vars(name_shape_map)
    new_variable_map = {}
    conversion_map = {}
    for var_name in not_sharded:
      new_var_name = _rnn_name_replacement(var_name)
      tensor = reader.get_tensor(var_name)
      # The Variable keeps the old name; the saver map below associates it
      # with the new name, which is what ends up in the output checkpoint.
      var = variables.Variable(tensor, name=var_name)
      new_variable_map[new_var_name] = var
      if new_var_name != var_name:
        conversion_map[var_name] = new_var_name
    for var_name in sharded:
      new_var_name = _rnn_name_replacement_sharded(var_name)
      # BUGFIX: fetch this variable's own tensor. The original reused the
      # `tensor` left over from the last iteration of the non-sharded loop,
      # so every sharded variable was saved with the wrong value.
      tensor = reader.get_tensor(var_name)
      var = variables.Variable(tensor, name=var_name)
      new_variable_map[new_var_name] = var
      if new_var_name != var_name:
        conversion_map[var_name] = new_var_name

    write_version = (saver_pb2.SaverDef.V1
                     if write_v1_checkpoint else saver_pb2.SaverDef.V2)
    saver = saver_lib.Saver(new_variable_map, write_version=write_version)

    with session.Session() as sess:
      sess.run(variables.global_variables_initializer())
      logging.info('Writing checkpoint_to_path %s', checkpoint_to_path)
      saver.save(sess, checkpoint_to_path)

  logging.info('Summary:')
  logging.info('  Converted %d variable name(s).', len(new_variable_map))
  return new_variable_map, conversion_map


def main(_):
  """Entry point: converts the checkpoint named by the parsed FLAGS."""
  convert_names(
      FLAGS.checkpoint_from_path,
      FLAGS.checkpoint_to_path,
      write_v1_checkpoint=FLAGS.write_v1_checkpoint)


if __name__ == '__main__':
  parser = argparse.ArgumentParser()
  parser.register('type', 'bool', lambda v: v.lower() == 'true')
  parser.add_argument('checkpoint_from_path', type=str,
                      help='Path to source checkpoint to be read in.')
  parser.add_argument('checkpoint_to_path', type=str,
                      help='Path to checkpoint to be written out.')
  parser.add_argument('--write_v1_checkpoint', action='store_true',
                      help='Write v1 checkpoint')
  FLAGS, unparsed = parser.parse_known_args()

  app.run(main=main, argv=[sys.argv[0]] + unparsed)