• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15r"""Convert checkpoints using RNNCells to new name convention.
16
17Usage:
18
19  python checkpoint_convert.py [--write_v1_checkpoint] \
20      '/path/to/checkpoint' '/path/to/new_checkpoint'
21
22For example, if there is a V2 checkpoint to be converted and the files include:
23  /tmp/my_checkpoint/model.ckpt.data-00000-of-00001
24  /tmp/my_checkpoint/model.ckpt.index
25  /tmp/my_checkpoint/model.ckpt.meta
26
27use the following command:
28  mkdir /tmp/my_converted_checkpoint &&
29  python checkpoint_convert.py \
30      /tmp/my_checkpoint/model.ckpt /tmp/my_converted_checkpoint/model.ckpt
31
32This will generate three converted checkpoint files corresponding to the three
33old ones in the new directory:
34  /tmp/my_converted_checkpoint/model.ckpt.data-00000-of-00001
35  /tmp/my_converted_checkpoint/model.ckpt.index
36  /tmp/my_converted_checkpoint/model.ckpt.meta
37"""
38from __future__ import absolute_import
39from __future__ import division
40from __future__ import print_function
41
42import argparse
43import collections
44import re
45import sys
46
47from tensorflow.core.protobuf import saver_pb2
48from tensorflow.python import pywrap_tensorflow
49from tensorflow.python.client import session
50from tensorflow.python.framework import ops
51from tensorflow.python.ops import variables
52from tensorflow.python.platform import app
53from tensorflow.python.platform import tf_logging as logging
54from tensorflow.python.training import saver as saver_lib
55
# Mapping between old <=> new names. Externalized so that user scripts that
# may need to consume multiple checkpoint formats can use this metadata.
#
# Keys are plain substrings searched for anywhere in an old variable name;
# values are what that substring is replaced with. `_rnn_name_replacement`
# walks this dict in insertion order and applies ONLY the first key that
# occurs in a name, so a name containing several keys is rewritten for the
# earliest one alone.
RNN_NAME_REPLACEMENTS = collections.OrderedDict([
    ############################################################################
    # contrib/rnn/python/ops/core_rnn_cell_impl.py
    # BasicRNNCell
    ('basic_rnn_cell/weights', 'basic_rnn_cell/kernel'),
    ('basic_rnn_cell/biases', 'basic_rnn_cell/bias'),
    # GRUCell
    ('gru_cell/weights', 'gru_cell/kernel'),
    ('gru_cell/biases', 'gru_cell/bias'),
    ('gru_cell/gates/weights', 'gru_cell/gates/kernel'),
    ('gru_cell/gates/biases', 'gru_cell/gates/bias'),
    ('gru_cell/candidate/weights', 'gru_cell/candidate/kernel'),
    ('gru_cell/candidate/biases', 'gru_cell/candidate/bias'),
    # BasicLSTMCell
    ('basic_lstm_cell/weights', 'basic_lstm_cell/kernel'),
    ('basic_lstm_cell/biases', 'basic_lstm_cell/bias'),
    # LSTMCell
    ('lstm_cell/weights', 'lstm_cell/kernel'),
    ('lstm_cell/biases', 'lstm_cell/bias'),
    ('lstm_cell/projection/weights', 'lstm_cell/projection/kernel'),
    ('lstm_cell/projection/biases', 'lstm_cell/projection/bias'),
    # OutputProjectionWrapper
    ('output_projection_wrapper/weights', 'output_projection_wrapper/kernel'),
    ('output_projection_wrapper/biases', 'output_projection_wrapper/bias'),
    # InputProjectionWrapper
    ('input_projection_wrapper/weights', 'input_projection_wrapper/kernel'),
    ('input_projection_wrapper/biases', 'input_projection_wrapper/bias'),
    ############################################################################
    # contrib/rnn/python/ops/lstm_ops.py
    # Variables under the 'lstm_block_wrapper' scope.
    # NOTE(review): presumably LSTMBlockFusedCell -- confirm which cell
    # classes in lstm_ops.py create this scope.
    ('lstm_block_wrapper/weights', 'lstm_block_wrapper/kernel'),
    ('lstm_block_wrapper/biases', 'lstm_block_wrapper/bias'),
    ############################################################################
    # contrib/rnn/python/ops/rnn_cell.py
    # LayerNormBasicLSTMCell
    ('layer_norm_basic_lstm_cell/weights', 'layer_norm_basic_lstm_cell/kernel'),
    ('layer_norm_basic_lstm_cell/biases', 'layer_norm_basic_lstm_cell/bias'),
    # UGRNNCell
    # NOTE(review): kept although it is unconfirmed whether existing
    # checkpoints still use this scope.
    ('ugrnn_cell/weights', 'ugrnn_cell/kernel'),
    ('ugrnn_cell/biases', 'ugrnn_cell/bias'),
    # NASCell
    ('nas_rnn/weights', 'nas_rnn/kernel'),
    ('nas_rnn/recurrent_weights', 'nas_rnn/recurrent_kernel'),
    # IntersectionRNNCell
    ('intersection_rnn_cell/weights', 'intersection_rnn_cell/kernel'),
    ('intersection_rnn_cell/biases', 'intersection_rnn_cell/bias'),
    ('intersection_rnn_cell/in_projection/weights',
     'intersection_rnn_cell/in_projection/kernel'),
    ('intersection_rnn_cell/in_projection/biases',
     'intersection_rnn_cell/in_projection/bias'),
    # PhasedLSTMCell
    ('phased_lstm_cell/mask_gates/weights',
     'phased_lstm_cell/mask_gates/kernel'),
    ('phased_lstm_cell/mask_gates/biases', 'phased_lstm_cell/mask_gates/bias'),
    ('phased_lstm_cell/new_input/weights', 'phased_lstm_cell/new_input/kernel'),
    ('phased_lstm_cell/new_input/biases', 'phased_lstm_cell/new_input/bias'),
    ('phased_lstm_cell/output_gate/weights',
     'phased_lstm_cell/output_gate/kernel'),
    ('phased_lstm_cell/output_gate/biases',
     'phased_lstm_cell/output_gate/bias'),
    # AttentionCellWrapper
    ('attention_cell_wrapper/weights', 'attention_cell_wrapper/kernel'),
    ('attention_cell_wrapper/biases', 'attention_cell_wrapper/bias'),
    ('attention_cell_wrapper/attn_output_projection/weights',
     'attention_cell_wrapper/attn_output_projection/kernel'),
    ('attention_cell_wrapper/attn_output_projection/biases',
     'attention_cell_wrapper/attn_output_projection/bias'),
    ('attention_cell_wrapper/attention/weights',
     'attention_cell_wrapper/attention/kernel'),
    ('attention_cell_wrapper/attention/biases',
     'attention_cell_wrapper/attention/bias'),
    ############################################################################
    # contrib/legacy_seq2seq/python/ops/seq2seq.py
    ('attention_decoder/weights', 'attention_decoder/kernel'),
    ('attention_decoder/biases', 'attention_decoder/bias'),
    ('attention_decoder/Attention_0/weights',
     'attention_decoder/Attention_0/kernel'),
    ('attention_decoder/Attention_0/biases',
     'attention_decoder/Attention_0/bias'),
    ('attention_decoder/AttnOutputProjection/weights',
     'attention_decoder/AttnOutputProjection/kernel'),
    ('attention_decoder/AttnOutputProjection/biases',
     'attention_decoder/AttnOutputProjection/bias'),
    # contrib/legacy_seq2seq/python/ops/seq2seq.py names used before change
    # cl/140060366 (Linear/Matrix, Linear/Bias, LSTMCell/W_0, LSTMCell/B).
    ('attention_decoder/Attention_0/Linear/Bias',
     'attention_decoder/Attention_0/bias'),
    ('attention_decoder/Attention_0/Linear/Matrix',
     'attention_decoder/Attention_0/kernel'),
    ('attention_decoder/AttnOutputProjection/Linear/Bias',
     'attention_decoder/AttnOutputProjection/bias'),
    ('attention_decoder/AttnOutputProjection/Linear/Matrix',
     'attention_decoder/AttnOutputProjection/kernel'),
    ('attention_decoder/LSTMCell/B', 'attention_decoder/lstm_cell/bias'),
    ('attention_decoder/LSTMCell/W_0', 'attention_decoder/lstm_cell/kernel'),
    ('attention_decoder/Linear/Bias', 'attention_decoder/bias'),
    ('attention_decoder/Linear/Matrix', 'attention_decoder/kernel')
])

# Substring replacements for old-style sharded variable names such as
# 'LSTMCell/W_0'. Used by `_rnn_name_replacement_sharded`, which, unlike the
# first-match lookup above, applies EVERY key that occurs in the name, in
# insertion order, against the progressively rewritten name. One name can
# therefore receive several rewrites, e.g. 'MultiRNNCell/Cell0/LSTMCell/W_0'
# -> 'multi_rnn_cell/cell_0/lstm_cell/weights/part_0'.
_RNN_SHARDED_NAME_REPLACEMENTS = collections.OrderedDict([
    ('LSTMCell/W_', 'lstm_cell/weights/part_'),
    ('BasicLSTMCell/Linear/Matrix_', 'basic_lstm_cell/weights/part_'),
    ('GRUCell/W_', 'gru_cell/weights/part_'),
    ('MultiRNNCell/Cell', 'multi_rnn_cell/cell_'),
])
162
163
def _rnn_name_replacement(var_name):
  """Applies the first matching RNN_NAME_REPLACEMENTS rule to `var_name`.

  Rules are tried in the dict's insertion order and matched as plain
  substrings. Only the first rule found is used; every occurrence of its key
  is replaced via `str.replace`. Each conversion is logged.

  Args:
    var_name: A variable name read from the source checkpoint.

  Returns:
    The converted name, or `var_name` unchanged when no rule matches.
  """
  for old_fragment, new_fragment in RNN_NAME_REPLACEMENTS.items():
    if old_fragment not in var_name:
      continue
    renamed = var_name.replace(old_fragment, new_fragment)
    logging.info('Converted: %s --> %s' % (var_name, renamed))
    return renamed
  return var_name
172
173
def _rnn_name_replacement_sharded(var_name):
  """Applies every matching _RNN_SHARDED_NAME_REPLACEMENTS rule to `var_name`.

  Rules are tried in insertion order against the progressively rewritten
  name, so the output of one rule can be matched by a later rule. Each
  applied rule is logged.

  Args:
    var_name: A variable name classified as one shard of a partitioned
      variable.

  Returns:
    The name after all matching rules have been applied.
  """
  current = var_name
  for old_fragment, new_fragment in _RNN_SHARDED_NAME_REPLACEMENTS.items():
    if old_fragment in current:
      rewritten = current.replace(old_fragment, new_fragment)
      logging.info('Converted: %s --> %s' % (current, rewritten))
      current = rewritten
  return current
182
183
184def _split_sharded_vars(name_shape_map):
185  """Split shareded variables.
186
187  Args:
188    name_shape_map: A dict from variable name to variable shape.
189
190  Returns:
191    not_sharded: Names of the non-sharded variables.
192    sharded: Names of the sharded variables.
193  """
194  sharded = []
195  not_sharded = []
196  for name in name_shape_map:
197    if re.match(name, '_[0-9]+$'):
198      if re.sub('_[0-9]+$', '_1', name) in name_shape_map:
199        sharded.append(name)
200      else:
201        not_sharded.append(name)
202    else:
203      not_sharded.append(name)
204  return not_sharded, sharded
205
206
207def convert_names(checkpoint_from_path,
208                  checkpoint_to_path,
209                  write_v1_checkpoint=False):
210  """Migrates the names of variables within a checkpoint.
211
212  Args:
213    checkpoint_from_path: Path to source checkpoint to be read in.
214    checkpoint_to_path: Path to checkpoint to be written out.
215    write_v1_checkpoint: Whether the output checkpoint will be in V1 format.
216
217  Returns:
218    A dictionary that maps the new variable names to the Variable objects.
219    A dictionary that maps the old variable names to the new variable names.
220  """
221  with ops.Graph().as_default():
222    logging.info('Reading checkpoint_from_path %s' % checkpoint_from_path)
223    reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_from_path)
224    name_shape_map = reader.get_variable_to_shape_map()
225    not_sharded, sharded = _split_sharded_vars(name_shape_map)
226    new_variable_map = {}
227    conversion_map = {}
228    for var_name in not_sharded:
229      new_var_name = _rnn_name_replacement(var_name)
230      tensor = reader.get_tensor(var_name)
231      var = variables.Variable(tensor, name=var_name)
232      new_variable_map[new_var_name] = var
233      if new_var_name != var_name:
234        conversion_map[var_name] = new_var_name
235    for var_name in sharded:
236      new_var_name = _rnn_name_replacement_sharded(var_name)
237      var = variables.Variable(tensor, name=var_name)
238      new_variable_map[new_var_name] = var
239      if new_var_name != var_name:
240        conversion_map[var_name] = new_var_name
241
242    write_version = (saver_pb2.SaverDef.V1
243                     if write_v1_checkpoint else saver_pb2.SaverDef.V2)
244    saver = saver_lib.Saver(new_variable_map, write_version=write_version)
245
246    with session.Session() as sess:
247      sess.run(variables.global_variables_initializer())
248      logging.info('Writing checkpoint_to_path %s' % checkpoint_to_path)
249      saver.save(sess, checkpoint_to_path)
250
251  logging.info('Summary:')
252  logging.info('  Converted %d variable name(s).' % len(new_variable_map))
253  return new_variable_map, conversion_map
254
255
def main(_):
  """Runs `convert_names` with the paths and format chosen on the command line.

  Reads the module-level FLAGS parsed by argparse in the `__main__` block.
  The positional argument (supplied by `app.run`) is unused.
  """
  source_path = FLAGS.checkpoint_from_path
  target_path = FLAGS.checkpoint_to_path
  convert_names(source_path, target_path,
                write_v1_checkpoint=FLAGS.write_v1_checkpoint)
261
262
if __name__ == '__main__':
  # The unused `parser.register('type', 'bool', ...)` was removed: no argument
  # declares type='bool', and --write_v1_checkpoint is a store_true flag.
  parser = argparse.ArgumentParser()
  parser.add_argument('checkpoint_from_path', type=str,
                      help='Path to source checkpoint to be read in.')
  parser.add_argument('checkpoint_to_path', type=str,
                      help='Path to checkpoint to be written out.')
  parser.add_argument('--write_v1_checkpoint', action='store_true',
                      help='Write v1 checkpoint')
  # FLAGS is read as a module global by main(). Arguments argparse does not
  # recognise are forwarded to app.run together with the program name.
  FLAGS, unparsed = parser.parse_known_args()

  app.run(main=main, argv=[sys.argv[0]] + unparsed)
275