# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities to remove pruning-related ops and variables from a GraphDef.
"""

# pylint: disable=missing-docstring
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
from tensorflow.python.client import session
from tensorflow.python.framework import graph_util
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import saver as saver_lib


def _node_name(tensor_name):
  """Removes the trailing ':0' from the tensor name to get the op name."""
  if ':' not in tensor_name:
    return tensor_name

  return tensor_name.split(':')[0]


def _tensor_name(node_name):
  """Appends ':0' to the op name to get the canonical tensor name."""
  if ':' in node_name:
    return node_name

  return node_name + ':0'


def _get_masked_weights(input_graph_def):
  """Extracts masked_weights from the graph as a dict of {var_name:ndarray}."""
  input_graph = ops.Graph()
  with input_graph.as_default():
    importer.import_graph_def(input_graph_def, name='')

  with session.Session(graph=input_graph) as sess:
    masked_weights_dict = {}
    for node in input_graph_def.node:
      if 'masked_weight' in node.name:
        masked_weight_val = sess.run(
            sess.graph.get_tensor_by_name(_tensor_name(node.name)))
        # Log the percentage of entries that are exactly zero.
        logging.info(
            '%s has %d values, %1.2f%% zeros \n', node.name,
            np.size(masked_weight_val),
            100 - float(100 * np.count_nonzero(masked_weight_val)) /
            np.size(masked_weight_val))
        masked_weights_dict.update({node.name: masked_weight_val})
    return masked_weights_dict


def strip_pruning_vars_fn(input_graph_def, output_node_names):
  """Removes mask variables from the graph.

  Replaces each masked_weight tensor with a constant holding the element-wise
  product of the mask and the corresponding weight variable.

  Args:
    input_graph_def: A GraphDef in which the variables have been converted to
      constants. This is typically the output of
      tf.graph_util.convert_variables_to_constants().
    output_node_names: List of name strings for the result nodes of the graph.

  Returns:
    A GraphDef in which pruning-related variables have been removed.
  """
  masked_weights_dict = _get_masked_weights(input_graph_def)
  pruned_graph_def = graph_pb2.GraphDef()

  # Replace masked_weight with a const op containing the
  # result of tf.multiply(mask, weight).
  for node in input_graph_def.node:
    output_node = node_def_pb2.NodeDef()
    if 'masked_weight' in node.name:
      output_node.op = 'Const'
      output_node.name = node.name
      dtype = node.attr['T']
      data = masked_weights_dict[node.name]
      output_node.attr['dtype'].CopyFrom(dtype)
      output_node.attr['value'].CopyFrom(
          attr_value_pb2.AttrValue(tensor=tensor_util.make_tensor_proto(data)))
    else:
      output_node.CopyFrom(node)
    pruned_graph_def.node.extend([output_node])

  # Remove stranded nodes: mask and weights.
  return graph_util.extract_sub_graph(pruned_graph_def, output_node_names)


def graph_def_from_checkpoint(checkpoint_dir, output_node_names):
  """Converts checkpoint data to GraphDef.

  Reads the latest checkpoint data and produces a GraphDef in which the
  variables have been converted to constants.

  Args:
    checkpoint_dir: Path to the checkpoints.
    output_node_names: List of name strings for the result nodes of the graph.

  Returns:
    A GraphDef from the latest checkpoint.

  Raises:
    ValueError: if no checkpoint is found.
  """
  checkpoint_path = saver_lib.latest_checkpoint(checkpoint_dir)
  if checkpoint_path is None:
    raise ValueError('Could not find a checkpoint at: {0}.'
                     .format(checkpoint_dir))

  # Import the meta graph, restore the variables and freeze them into the
  # GraphDef as constants.
  saver_for_restore = saver_lib.import_meta_graph(
      checkpoint_path + '.meta', clear_devices=True)
  with session.Session() as sess:
    saver_for_restore.restore(sess, checkpoint_path)
    graph_def = ops.get_default_graph().as_graph_def()
    output_graph_def = graph_util.convert_variables_to_constants(
        sess, graph_def, output_node_names)

  return output_graph_def
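

# Example usage (a minimal sketch, not part of this module's API): freeze the
# latest checkpoint into a GraphDef with graph_def_from_checkpoint(), strip the
# pruning mask/threshold variables with strip_pruning_vars_fn(), and write the
# result to disk. The checkpoint directory, output node name and output file
# name below are hypothetical placeholders.
#
#   from tensorflow.python.framework import graph_io
#
#   output_node_names = ['softmax']  # hypothetical result node of the graph
#   frozen_graph_def = graph_def_from_checkpoint('/tmp/pruning_ckpts',
#                                                output_node_names)
#   pruned_graph_def = strip_pruning_vars_fn(frozen_graph_def,
#                                            output_node_names)
#   graph_io.write_graph(pruned_graph_def, '/tmp/pruning_ckpts',
#                        'pruned_graph.pb', as_text=False)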