• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Utilities to remove pruning-related ops and variables from a GraphDef.
16"""
17
18# pylint: disable=missing-docstring
19from __future__ import absolute_import
20from __future__ import division
21from __future__ import print_function
22
23import numpy as np
24
25from tensorflow.core.framework import attr_value_pb2
26from tensorflow.core.framework import graph_pb2
27from tensorflow.core.framework import node_def_pb2
28from tensorflow.python.client import session
29from tensorflow.python.framework import graph_util
30from tensorflow.python.framework import importer
31from tensorflow.python.framework import ops
32from tensorflow.python.framework import tensor_util
33from tensorflow.python.platform import tf_logging as logging
34from tensorflow.python.training import saver as saver_lib
35
36
37def _node_name(tensor_name):
38  """Remove the trailing ':0' from the variable name."""
39  if ':' not in tensor_name:
40    return tensor_name
41
42  return tensor_name.split(':')[0]
43
44
45def _tensor_name(node_name):
46  """Appends the :0 in the op name to get the canonical tensor name."""
47  if ':' in node_name:
48    return node_name
49
50  return node_name + ':0'
51
52
def _get_masked_weights(input_graph_def):
  """Evaluates every masked_weight tensor in the graph.

  Imports `input_graph_def` into a fresh graph, runs each node whose name
  contains 'masked_weight', and logs the fraction of zeros in each result.

  Args:
    input_graph_def: A GraphDef containing masked_weight nodes.

  Returns:
    Dict mapping node name to the evaluated weight ndarray.
  """
  eval_graph = ops.Graph()
  with eval_graph.as_default():
    importer.import_graph_def(input_graph_def, name='')

    with session.Session(graph=eval_graph) as sess:
      masked_weights = {}
      for node in input_graph_def.node:
        if 'masked_weight' not in node.name:
          continue
        weight_val = sess.run(
            sess.graph.get_tensor_by_name(_tensor_name(node.name)))
        num_values = np.size(weight_val)
        zero_fraction = (
            100 - float(100 * np.count_nonzero(weight_val)) / num_values)
        logging.info('%s has %d values, %1.2f%% zeros \n', node.name,
                     num_values, zero_fraction)
        masked_weights[node.name] = weight_val
  return masked_weights
72
73
def strip_pruning_vars_fn(input_graph_def, output_node_names):
  """Removes mask variable from the graph.

  Replaces the masked_weight tensor with element-wise multiplication of mask
  and the corresponding weight variable, folded into a Const node.

  Args:
    input_graph_def: A GraphDef in which the variables have been converted to
      constants. This is typically the output of
      tf.graph_util.convert_variables_to_constant()
    output_node_names: List of name strings for the result nodes of the graph

  Returns:
    A GraphDef in which pruning-related variables have been removed
  """
  masked_weights = _get_masked_weights(input_graph_def)
  stripped_graph_def = graph_pb2.GraphDef()

  # Rebuild the graph node by node, folding each masked_weight into a
  # Const op that holds the precomputed tf.multiply(mask, weight) result.
  for node in input_graph_def.node:
    rebuilt_node = node_def_pb2.NodeDef()
    if 'masked_weight' not in node.name:
      rebuilt_node.CopyFrom(node)
    else:
      rebuilt_node.name = node.name
      rebuilt_node.op = 'Const'
      # The multiply op carries its element dtype in attr 'T'; the Const
      # replacement exposes the same dtype under 'dtype'.
      rebuilt_node.attr['dtype'].CopyFrom(node.attr['T'])
      rebuilt_node.attr['value'].CopyFrom(
          attr_value_pb2.AttrValue(
              tensor=tensor_util.make_tensor_proto(
                  masked_weights[node.name])))
    stripped_graph_def.node.extend([rebuilt_node])

  # Drop the now-stranded mask and weight nodes that nothing references.
  return graph_util.extract_sub_graph(stripped_graph_def, output_node_names)
111
112
def graph_def_from_checkpoint(checkpoint_dir, output_node_names):
  """Converts checkpoint data to GraphDef.

  Reads the latest checkpoint data and produces a GraphDef in which the
  variables have been converted to constants.

  Args:
    checkpoint_dir: Path to the checkpoints.
    output_node_names: List of name strings for the result nodes of the graph.

  Returns:
    A GraphDef from the latest checkpoint

  Raises:
    ValueError: if no checkpoint is found
  """
  latest_path = saver_lib.latest_checkpoint(checkpoint_dir)
  if latest_path is None:
    raise ValueError(
        'Could not find a checkpoint at: {0}.'.format(checkpoint_dir))

  # clear_devices lets the meta graph load on machines with a different
  # device layout than the one that produced the checkpoint.
  restorer = saver_lib.import_meta_graph(
      latest_path + '.meta', clear_devices=True)
  with session.Session() as sess:
    restorer.restore(sess, latest_path)
    frozen_graph_def = graph_util.convert_variables_to_constants(
        sess, ops.get_default_graph().as_graph_def(), output_node_names)

  return frozen_graph_def
143