#!/usr/bin/env python
"""Extracts trainable parameters from Tensorflow models and stores them in numpy arrays.

Usage
    python tensorflow_data_extractor.py -m path_to_binary_checkpoint_file -n path_to_metagraph_file

Saves each variable to a {variable_name}.npy binary file in the current working directory.
Path separators in variable names (e.g. the '/' separating Tensorflow name scopes) are replaced by '_'.

Note that since Tensorflow version 0.11 the binary checkpoint file which contains the values for each parameter has
the format of:
    {model_name}.data-{shard}-of-{num_shards}
instead of:
    {model_name}.ckpt
When dealing with binary files with version >= 0.11, only pass {model_name} to the -m option;
when dealing with binary files with version < 0.11, pass the whole file name {model_name}.ckpt to the -m option.

Also note that this script relies on the parameters to be extracted being in the
'trainable_variables' collection. By default all variables are automatically added to this collection unless
specified otherwise by the user. Thus, should a user alter this default behavior and/or want to extract parameters
from other collections, tf.GraphKeys.TRAINABLE_VARIABLES should be replaced accordingly.

Tested with Tensorflow 1.2, 1.3 on Python 2.7.6 and Python 3.4.3.
"""
import argparse
import numpy as np
import os
import tensorflow as tf


def _to_file_stem(var_name):
    """Return *var_name* with path separators replaced by '_' so it can be used as a file name.

    Tensorflow separates name scopes with '/' on every platform, so '/' is always replaced;
    the platform's own separators (os.sep and, if defined, os.altsep) are replaced too so the
    resulting name never points into a sub-directory.
    """
    # NOTE(review): ':' (as in 'weights:0') is kept so existing output file names do not change;
    # it is not a regular file-name character on Windows -- confirm whether it should be replaced there.
    stem = var_name
    for sep in {'/', os.sep, os.altsep} - {None}:
        stem = stem.replace(sep, '_')
    return stem


def _parse_args():
    """Parse the command line: -m (checkpoint path) and -n (MetaGraph path), both required."""
    parser = argparse.ArgumentParser(description='Extract Tensorflow net parameters')
    parser.add_argument('-m', dest='modelFile', type=str, required=True,
                        help='Path to Tensorflow checkpoint binary file. '
                             'For Tensorflow version >= 0.11, only include model name; '
                             'for Tensorflow version < 0.11, include model name with ".ckpt" extension')
    parser.add_argument('-n', dest='netFile', type=str, required=True, help='Path to Tensorflow MetaGraph file')
    return parser.parse_args()


def main():
    """Restore the model and dump every trainable variable to {name}.npy in the working directory."""
    args = _parse_args()

    # Load Tensorflow Net: rebuild the graph from the MetaGraph and get a Saver for its variables.
    saver = tf.train.import_meta_graph(args.netFile)
    if saver is None:
        # import_meta_graph returns None when the MetaGraph contains no variables to restore.
        raise SystemExit('No variables found in MetaGraph {0}; nothing to extract.'.format(args.netFile))

    with tf.Session() as sess:
        # Restore session
        saver.restore(sess, args.modelFile)
        print('Model restored.')

        # Maps each output file stem to the variable name that produced it, to detect collisions.
        written = {}
        # Save trainable variables to numpy arrays
        for t in tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES):
            varname = _to_file_stem(t.name)
            if varname != t.name:
                print("Renaming variable {0} to {1}".format(t.name, varname))
            if varname in written:
                # Different names can sanitize to the same stem (e.g. 'a/b_c' and 'a_b/c');
                # np.save would silently replace the earlier file.
                print("WARNING: variable {0} overwrites {1}.npy already saved for variable {2}".format(
                    t.name, varname, written[varname]))
            written[varname] = t.name
            print("Saving variable {0} with shape {1} ...".format(varname, t.shape))
            # Dump as binary (np.save appends the .npy extension)
            np.save(varname, sess.run(t))


if __name__ == "__main__":
    main()