# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helpers to manipulate a tensor graph in Python."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import re

import six

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export

_VARIABLE_OPS = {
    "Assign",
    "AssignAdd",
    "AssignSub",
    "Queue",
    "ScatterAdd",
    "ScatterSub",
    "ScatterUpdate",
    "TruncatedNormal",
    "Variable",
    "VariableV2",
}


def _is_variable_op(op):
  """Returns True if 'op' is a type string for a variable-related node."""
  return op in _VARIABLE_OPS


@deprecation.deprecated(
    date=None,
    instructions="Use `tf.compat.v1.graph_util.must_run_on_cpu`")
@tf_export(v1=["graph_util.must_run_on_cpu"])
def must_run_on_cpu(node, pin_variables_on_cpu=False):
  """Returns True if the given node_def must run on CPU, otherwise False.

  Args:
    node: The node to be assigned to a device. Could be either an ops.Operation
      or NodeDef.
    pin_variables_on_cpu: If True, this function will return True if node_def
      represents a variable-related op.

  Returns:
    True if the given node must run on CPU, otherwise False.
  """

  if isinstance(node, ops.Operation):
    node_def = node.node_def
  else:
    assert isinstance(node, node_def_pb2.NodeDef)
    node_def = node

  # If requested, pin variable-related ops on CPU.
  if pin_variables_on_cpu and _is_variable_op(node_def.op):
    return True

  # Constant operations producing a string or int32 must run on CPU.
  if node_def.op == "Const":
    # Get the value of the 'dtype' attr.
    dtype = node_def.attr["dtype"].type
    if dtype == dtypes.string or dtype == dtypes.int32:
      return True

  if node_def.op in ["DynamicStitch", "ParallelDynamicStitch"]:
    dtype = node_def.attr["T"].type
    if dtype == dtypes.int32:
      # DynamicStitch on GPU only works for int32 values.
      return True

  if node_def.op in ["Cast"]:
    dtype = node_def.attr["SrcT"].type
    if dtype == dtypes.int32:
      # Cast on GPU does not work for int32 values.
      return True
  return False
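

# A minimal sketch of `must_run_on_cpu` (the `_example_` helper and the node
# it builds are illustrative only, not part of the original module): a string
# `Const` node is reported as CPU-only.
def _example_must_run_on_cpu():
  node = node_def_pb2.NodeDef()
  node.op = "Const"
  node.attr["dtype"].type = dtypes.string.as_datatype_enum
  return must_run_on_cpu(node)  # Expected: True, string Consts run on CPU.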


################################################################################
#
# device functions for use with g.device(...)
#
################################################################################


def _node_name(n):
  """Strips the control prefix '^' or an output port suffix from a name."""
  if n.startswith("^"):
    return n[1:]
  else:
    return n.split(":")[0]
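

# A quick sketch of the normalization above (the input strings below are
# hypothetical): control inputs drop the leading "^", and tensor inputs drop
# the ":port" suffix.
def _example_node_name():
  assert _node_name("^init") == "init"
  assert _node_name("MatMul:0") == "MatMul"
  assert _node_name("plain_name") == "plain_name"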


def _extract_graph_summary(graph_def):
  """Extracts useful information from the graph and returns it."""
  name_to_input_name = {}  # Keyed by the dest node name.
  name_to_node = {}  # Keyed by node name.

  # Keeps track of each node's position so that the output can preserve the
  # original operation order.
  name_to_seq_num = {}  # Keyed by node name.
  seq = 0
  for node in graph_def.node:
    n = _node_name(node.name)
    name_to_node[n] = node
    name_to_input_name[n] = [_node_name(x) for x in node.input]
    name_to_seq_num[n] = seq
    seq += 1
  return name_to_input_name, name_to_node, name_to_seq_num


def _assert_nodes_are_present(name_to_node, nodes):
  """Asserts that the given nodes are present in the graph."""
  for d in nodes:
    assert d in name_to_node, "%s is not in graph" % d


def _bfs_for_reachable_nodes(target_nodes, name_to_input_name):
  """Breadth first search for reachable nodes from target nodes."""
  nodes_to_keep = set()
  # Breadth first search to find all the nodes that we should keep.
  next_to_visit = target_nodes[:]
  while next_to_visit:
    node = next_to_visit[0]
    del next_to_visit[0]
    if node in nodes_to_keep:
      # Already visited this node.
      continue
    nodes_to_keep.add(node)
    if node in name_to_input_name:
      next_to_visit += name_to_input_name[node]
  return nodes_to_keep
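

# A small sketch of the traversal (node names hypothetical): with edges
# a -> b -> c, walking backwards from "c" through the input map keeps all
# three nodes.
def _example_bfs_for_reachable_nodes():
  name_to_input_name = {"a": [], "b": ["a"], "c": ["b"]}
  return _bfs_for_reachable_nodes(["c"], name_to_input_name)  # {"a","b","c"}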


@deprecation.deprecated(
    date=None,
    instructions="Use `tf.compat.v1.graph_util.extract_sub_graph`")
@tf_export(v1=["graph_util.extract_sub_graph"])
def extract_sub_graph(graph_def, dest_nodes):
  """Extract the subgraph that can reach any of the nodes in 'dest_nodes'.

  Args:
    graph_def: A graph_pb2.GraphDef proto.
    dest_nodes: A list of strings specifying the destination node names.

  Returns:
    The GraphDef of the sub-graph.

  Raises:
    TypeError: If 'graph_def' is not a graph_pb2.GraphDef proto.
  """

  if not isinstance(graph_def, graph_pb2.GraphDef):
    raise TypeError("graph_def must be a graph_pb2.GraphDef proto.")

  if isinstance(dest_nodes, six.string_types):
    raise TypeError("dest_nodes must be a list.")

  name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
      graph_def)
  _assert_nodes_are_present(name_to_node, dest_nodes)

  nodes_to_keep = _bfs_for_reachable_nodes(dest_nodes, name_to_input_name)

  nodes_to_keep_list = sorted(nodes_to_keep, key=lambda n: name_to_seq_num[n])
  # Now construct the output GraphDef.
  out = graph_pb2.GraphDef()
  for n in nodes_to_keep_list:
    out.node.extend([copy.deepcopy(name_to_node[n])])
  out.library.CopyFrom(graph_def.library)
  out.versions.CopyFrom(graph_def.versions)

  return out
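

# A minimal sketch of `extract_sub_graph` (the four-node graph below is
# hypothetical): only the ancestors of "c" survive, so the unrelated node "d"
# is pruned from the returned GraphDef.
def _example_extract_sub_graph():
  graph_def = graph_pb2.GraphDef()
  for name, inputs in [("a", []), ("b", ["a"]), ("c", ["b"]), ("d", [])]:
    node = graph_def.node.add()
    node.name = name
    node.op = "Identity" if inputs else "Placeholder"
    node.input.extend(inputs)
  sub = extract_sub_graph(graph_def, ["c"])
  return [n.name for n in sub.node]  # Expected: ["a", "b", "c"].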


@deprecation.deprecated(
    date=None,
    instructions="Use `tf.compat.v1.graph_util.tensor_shape_from_node_def_name`"
)
@tf_export(v1=["graph_util.tensor_shape_from_node_def_name"])
def tensor_shape_from_node_def_name(graph, input_name):
  """Convenience function to get a shape from a NodeDef's input string."""
  # To get a tensor, the name must be in the form <input>:<port>, for example
  # 'Mul:0'. The GraphDef input strings don't always have the port specified
  # though, so if there isn't a colon we need to add a default ':0' to the end.
  if ":" not in input_name:
    canonical_name = input_name + ":0"
  else:
    canonical_name = input_name
  tensor = graph.get_tensor_by_name(canonical_name)
  shape = tensor.get_shape()
  return shape
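

# A minimal sketch (assumes graph mode; the local `constant_op` import and the
# node name "c" are illustrative only): the bare name "c" is canonicalized to
# "c:0" before the lookup.
def _example_tensor_shape_from_node_def_name():
  from tensorflow.python.framework import constant_op
  g = ops.Graph()
  with g.as_default():
    constant_op.constant([[1, 2], [3, 4]], name="c")
  return tensor_shape_from_node_def_name(g, "c")  # TensorShape([2, 2])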


@deprecation.deprecated(
    date=None,
    instructions="Use `tf.compat.v1.graph_util.convert_variables_to_constants`")
@tf_export(v1=["graph_util.convert_variables_to_constants"])
def convert_variables_to_constants(sess,
                                   input_graph_def,
                                   output_node_names,
                                   variable_names_whitelist=None,
                                   variable_names_blacklist=None):
  """Replaces all the variables in a graph with constants of the same values.

  If you have a trained graph containing Variable ops, it can be convenient to
  convert them all to Const ops holding the same values. This makes it possible
  to describe the network fully with a single GraphDef file, and allows the
  removal of a lot of ops related to loading and saving the variables.

  Args:
    sess: Active TensorFlow session containing the variables.
    input_graph_def: GraphDef object holding the network.
    output_node_names: List of name strings for the result nodes of the graph.
    variable_names_whitelist: The set of variable names to convert (by default,
                              all variables are converted).
    variable_names_blacklist: The set of variable names to omit converting
                              to constants.

  Returns:
    GraphDef containing a simplified version of the original.
  """
  # This graph only includes the nodes needed to evaluate the output nodes, and
  # removes unneeded nodes like those involved in saving and assignment.
  inference_graph = extract_sub_graph(input_graph_def, output_node_names)

  variable_names = []
  variable_dict_names = []
  for node in inference_graph.node:
    if node.op in ["Variable", "VariableV2", "VarHandleOp"]:
      variable_name = node.name
      if ((variable_names_whitelist is not None and
           variable_name not in variable_names_whitelist) or
          (variable_names_blacklist is not None and
           variable_name in variable_names_blacklist)):
        continue
      variable_dict_names.append(variable_name)
      if node.op == "VarHandleOp":
        variable_names.append(variable_name + "/Read/ReadVariableOp:0")
      else:
        variable_names.append(variable_name + ":0")
  if variable_names:
    returned_variables = sess.run(variable_names)
  else:
    returned_variables = []
  found_variables = dict(zip(variable_dict_names, returned_variables))
  logging.info("Froze %d variables.", len(returned_variables))

  output_graph_def = graph_pb2.GraphDef()
  how_many_converted = 0
  for input_node in inference_graph.node:
    output_node = node_def_pb2.NodeDef()
    if input_node.name in found_variables:
      output_node.op = "Const"
      output_node.name = input_node.name
      dtype = input_node.attr["dtype"]
      data = found_variables[input_node.name]
      output_node.attr["dtype"].CopyFrom(dtype)
      output_node.attr["value"].CopyFrom(
          attr_value_pb2.AttrValue(
              tensor=tensor_util.make_tensor_proto(
                  data, dtype=dtype.type, shape=data.shape)))
      how_many_converted += 1
    elif input_node.op == "ReadVariableOp" and (
        input_node.input[0] in found_variables):
      # The preceding branch converts all VarHandleOps of ResourceVariables to
      # constants, so we need to convert the associated ReadVariableOps to
      # Identity ops.
      output_node.op = "Identity"
      output_node.name = input_node.name
      output_node.input.extend([input_node.input[0]])
      output_node.attr["T"].CopyFrom(input_node.attr["dtype"])
      if "_class" in input_node.attr:
        output_node.attr["_class"].CopyFrom(input_node.attr["_class"])
    else:
      output_node.CopyFrom(input_node)
    output_graph_def.node.extend([output_node])

  output_graph_def.library.CopyFrom(inference_graph.library)
  logging.info("Converted %d variables to const ops.", how_many_converted)
  return output_graph_def
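

# A minimal end-to-end sketch (assumes TF1-style graph mode; the local imports
# and the names "v" and "out" are illustrative only): freeze one variable into
# a constant.
def _example_convert_variables_to_constants():
  from tensorflow.python.client import session as session_lib
  from tensorflow.python.ops import math_ops
  from tensorflow.python.ops import variables as variables_lib
  g = ops.Graph()
  with g.as_default():
    v = variables_lib.VariableV1(3.0, name="v")
    math_ops.multiply(v, 2.0, name="out")
    with session_lib.Session() as sess:
      sess.run(variables_lib.global_variables_initializer())
      frozen = convert_variables_to_constants(sess, g.as_graph_def(), ["out"])
  return [n.op for n in frozen.node]  # "v" now appears as a Const op.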


@deprecation.deprecated(
    date=None,
    instructions="Use `tf.compat.v1.graph_util.remove_training_nodes`")
@tf_export(v1=["graph_util.remove_training_nodes"])
def remove_training_nodes(input_graph, protected_nodes=None):
  """Prunes out nodes that aren't needed for inference.

  There are nodes like Identity and CheckNumerics that are only useful
  during training, and can be removed in graphs that will be used for
  nothing but inference. Here we identify and remove them, returning an
  equivalent graph. To be specific, CheckNumerics nodes are always removed, and
  Identity nodes that aren't involved in control edges are spliced out so that
  their inputs and outputs are directly connected.

  Args:
    input_graph: Model to analyze and prune.
    protected_nodes: An optional list of names of nodes to be kept
      unconditionally. This is useful, for example, to preserve Identity output
      nodes.

  Returns:
    A GraphDef with the unnecessary nodes removed.
  """
  if not protected_nodes:
    protected_nodes = []

  types_to_remove = {"CheckNumerics": True}

  input_nodes = input_graph.node
  names_to_remove = {}
  for node in input_nodes:
    if node.op in types_to_remove and node.name not in protected_nodes:
      names_to_remove[node.name] = True

  nodes_after_removal = []
  for node in input_nodes:
    if node.name in names_to_remove:
      continue
    new_node = node_def_pb2.NodeDef()
    new_node.CopyFrom(node)
    input_before_removal = node.input
    del new_node.input[:]
    for full_input_name in input_before_removal:
      input_name = re.sub(r"^\^", "", full_input_name)
      if input_name in names_to_remove:
        continue
      new_node.input.append(full_input_name)
    nodes_after_removal.append(new_node)

  types_to_splice = {"Identity": True}
  control_input_names = set()
  node_names_with_control_input = set()
  for node in nodes_after_removal:
    for node_input in node.input:
      if "^" in node_input:
        control_input_names.add(node_input.replace("^", ""))
        node_names_with_control_input.add(node.name)

  names_to_splice = {}
  for node in nodes_after_removal:
    if node.op in types_to_splice and node.name not in protected_nodes:
      # We don't want to remove nodes that have control edge inputs, because
      # they might be involved in subtle dependency issues that removing them
      # will jeopardize.
      if node.name not in node_names_with_control_input:
        names_to_splice[node.name] = node.input[0]

  # We also don't want to remove nodes which are used as control edge inputs.
  names_to_splice = {name: value for name, value in names_to_splice.items()
                     if name not in control_input_names}

  nodes_after_splicing = []
  for node in nodes_after_removal:
    if node.name in names_to_splice:
      continue
    new_node = node_def_pb2.NodeDef()
    new_node.CopyFrom(node)
    input_before_removal = node.input
    del new_node.input[:]
    for full_input_name in input_before_removal:
      input_name = re.sub(r"^\^", "", full_input_name)
      while input_name in names_to_splice:
        full_input_name = names_to_splice[input_name]
        input_name = re.sub(r"^\^", "", full_input_name)
      new_node.input.append(full_input_name)
    nodes_after_splicing.append(new_node)

  output_graph = graph_pb2.GraphDef()
  output_graph.node.extend(nodes_after_splicing)
  return output_graph
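

# A minimal sketch of the pruning (the four-node graph below is hypothetical):
# the CheckNumerics node is dropped and the Identity node is spliced out, so
# "out" ends up reading directly from "in".
def _example_remove_training_nodes():
  graph_def = graph_pb2.GraphDef()
  for name, op, inputs in [("in", "Placeholder", []),
                           ("check", "CheckNumerics", ["in"]),
                           ("id", "Identity", ["in"]),
                           ("out", "Square", ["id"])]:
    node = graph_def.node.add()
    node.name = name
    node.op = op
    node.input.extend(inputs)
  pruned = remove_training_nodes(graph_def)
  return [(n.name, list(n.input)) for n in pruned.node]
  # Expected: [("in", []), ("out", ["in"])].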