# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Helpers to manipulate a tensor graph in python."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import copy
import re

import six

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_util
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export

_VARIABLE_OPS = {
    "Assign",
    "AssignAdd",
    "AssignSub",
    "Queue",
    "ScatterAdd",
    "ScatterSub",
    "ScatterUpdate",
    "TruncatedNormal",
    "Variable",
    "VariableV2",
}


def _is_variable_op(op):
  """Returns true if 'op' refers to a Variable node."""
  return op in _VARIABLE_OPS


@deprecation.deprecated(
    date=None,
    instructions="Use `tf.compat.v1.graph_util.must_run_on_cpu`")
@tf_export(v1=["graph_util.must_run_on_cpu"])
def must_run_on_cpu(node, pin_variables_on_cpu=False):
  """Returns True if the given node_def must run on CPU, otherwise False.

  Args:
    node: The node to be assigned to a device. Could be either an ops.Operation
      or NodeDef.
    pin_variables_on_cpu: If True, this function will return True if node_def
      represents a variable-related op.

  Returns:
    True if the given node must run on CPU, otherwise False.
  """

  if isinstance(node, ops.Operation):
    node_def = node.node_def
  else:
    assert isinstance(node, node_def_pb2.NodeDef)
    node_def = node

  # If requested, pin variable-related ops to the CPU.
  if pin_variables_on_cpu and _is_variable_op(node_def.op):
    return True

  # Constant operations producing a string or int32 must run on CPU.
  if node_def.op == "Const":
    # Get the value of the 'dtype' attr.
    dtype = node_def.attr["dtype"].type
    if dtype == dtypes.string or dtype == dtypes.int32:
      return True

  if node_def.op in ["DynamicStitch", "ParallelDynamicStitch"]:
    dtype = node_def.attr["T"].type
    if dtype == dtypes.int32:
      # DynamicStitch on GPU only works for int32 values.
      return True

  if node_def.op in ["Cast"]:
    dtype = node_def.attr["SrcT"].type
    if dtype == dtypes.int32:
      # Cast on GPU does not work for int32 values.
      return True
  return False
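

# A hedged usage sketch (not part of the original module): builds a NodeDef by
# hand and asks `must_run_on_cpu` about it. The node name is illustrative.
def _example_must_run_on_cpu():
  """Sketch: a string Const is expected to be pinned to the CPU."""
  node = node_def_pb2.NodeDef()
  node.name = "example_const"  # Hypothetical name, for illustration only.
  node.op = "Const"
  node.attr["dtype"].type = dtypes.string.as_datatype_enum
  # String constants must run on CPU, so this should return True.
  return must_run_on_cpu(node)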


################################################################################
#
# device functions for use in with g.device(...)
#
################################################################################


def _node_name(n):
  if n.startswith("^"):
    return n[1:]
  else:
    return n.split(":")[0]


def _extract_graph_summary(graph_def):
  """Extracts useful information from the graph and returns it."""
  name_to_input_name = {}  # Keyed by the dest node name.
  name_to_node = {}  # Keyed by node name.

  # Keeps track of node sequences. It is important to still output the
  # operations in the original order.
  name_to_seq_num = {}  # Keyed by node name.
  seq = 0
  for node in graph_def.node:
    n = _node_name(node.name)
    name_to_node[n] = node
    name_to_input_name[n] = [_node_name(x) for x in node.input]
    name_to_seq_num[n] = seq
    seq += 1
  return name_to_input_name, name_to_node, name_to_seq_num


def _assert_nodes_are_present(name_to_node, nodes):
  """Assert that nodes are present in the graph."""
  for d in nodes:
    assert d in name_to_node, "%s is not in graph" % d


def _bfs_for_reachable_nodes(target_nodes, name_to_input_name):
  """Breadth first search for reachable nodes from target nodes."""
  nodes_to_keep = set()
  # Breadth first search to find all the nodes that we should keep.
  next_to_visit = target_nodes[:]
  while next_to_visit:
    node = next_to_visit[0]
    del next_to_visit[0]
    if node in nodes_to_keep:
      # Already visited this node.
      continue
    nodes_to_keep.add(node)
    if node in name_to_input_name:
      next_to_visit += name_to_input_name[node]
  return nodes_to_keep


@deprecation.deprecated(
    date=None,
    instructions="Use `tf.compat.v1.graph_util.extract_sub_graph`")
@tf_export(v1=["graph_util.extract_sub_graph"])
def extract_sub_graph(graph_def, dest_nodes):
  """Extract the subgraph that can reach any of the nodes in 'dest_nodes'.

  Args:
    graph_def: A graph_pb2.GraphDef proto.
    dest_nodes: A list of strings specifying the destination node names.

  Returns:
    The GraphDef of the sub-graph.

  Raises:
    TypeError: If 'graph_def' is not a graph_pb2.GraphDef proto.
  """

  if not isinstance(graph_def, graph_pb2.GraphDef):
    raise TypeError("graph_def must be a graph_pb2.GraphDef proto.")

  if isinstance(dest_nodes, six.string_types):
    raise TypeError("dest_nodes must be a list.")

  name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
      graph_def)
  _assert_nodes_are_present(name_to_node, dest_nodes)

  nodes_to_keep = _bfs_for_reachable_nodes(dest_nodes, name_to_input_name)

  nodes_to_keep_list = sorted(
      list(nodes_to_keep), key=lambda n: name_to_seq_num[n])
  # Now construct the output GraphDef.
  out = graph_pb2.GraphDef()
  for n in nodes_to_keep_list:
    out.node.extend([copy.deepcopy(name_to_node[n])])
  out.library.CopyFrom(graph_def.library)
  out.versions.CopyFrom(graph_def.versions)

  return out
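

# A hedged usage sketch (not part of the original module): builds a tiny graph
# and prunes it down to the ancestors of one node. The node names ("a", "b",
# "add", "unused") are illustrative.
def _example_extract_sub_graph():
  """Sketch: keeps only the nodes needed to compute 'add'."""
  # Local imports keep the sketch self-contained.
  from tensorflow.python.framework import constant_op
  from tensorflow.python.ops import math_ops

  with ops.Graph().as_default() as graph:
    a = constant_op.constant(1.0, name="a")
    b = constant_op.constant(2.0, name="b")
    math_ops.add(a, b, name="add")
    constant_op.constant(3.0, name="unused")  # Should be pruned away.
  sub_graph_def = extract_sub_graph(graph.as_graph_def(), ["add"])
  # Expected to contain only "a", "b", and "add".
  return [node.name for node in sub_graph_def.node]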


@deprecation.deprecated(
    date=None,
    instructions="Use `tf.compat.v1.graph_util.tensor_shape_from_node_def_name`"
)
@tf_export(v1=["graph_util.tensor_shape_from_node_def_name"])
def tensor_shape_from_node_def_name(graph, input_name):
  """Convenience function to get a shape from a NodeDef's input string."""
  # To get a tensor, the name must be in the form <input>:<port>, for example
  # 'Mul:0'. The GraphDef input strings don't always have the port specified
  # though, so if there isn't a colon we need to add a default ':0' to the end.
  if ":" not in input_name:
    canonical_name = input_name + ":0"
  else:
    canonical_name = input_name
  tensor = graph.get_tensor_by_name(canonical_name)
  shape = tensor.get_shape()
  return shape


@deprecation.deprecated(
    date=None,
    instructions="Use `tf.compat.v1.graph_util.convert_variables_to_constants`")
@tf_export(v1=["graph_util.convert_variables_to_constants"])
def convert_variables_to_constants(sess,
                                   input_graph_def,
                                   output_node_names,
                                   variable_names_whitelist=None,
                                   variable_names_blacklist=None):
  """Replaces all the variables in a graph with constants of the same values.

  If you have a trained graph containing Variable ops, it can be convenient to
  convert them all to Const ops holding the same values. This makes it possible
  to describe the network fully with a single GraphDef file, and allows the
  removal of a lot of ops related to loading and saving the variables.

  Args:
    sess: Active TensorFlow session containing the variables.
    input_graph_def: GraphDef object holding the network.
    output_node_names: List of name strings for the result nodes of the graph.
    variable_names_whitelist: The set of variable names to convert (by default,
      all variables are converted).
    variable_names_blacklist: The set of variable names to omit converting
      to constants.

  Returns:
    GraphDef containing a simplified version of the original.
  """
  # This graph only includes the nodes needed to evaluate the output nodes, and
  # removes unneeded nodes like those involved in saving and assignment.
  inference_graph = extract_sub_graph(input_graph_def, output_node_names)

  found_variables = {}
  variable_names = []
  variable_dict_names = []
  for node in inference_graph.node:
    if node.op in ["Variable", "VariableV2", "VarHandleOp"]:
      variable_name = node.name
      if ((variable_names_whitelist is not None and
           variable_name not in variable_names_whitelist) or
          (variable_names_blacklist is not None and
           variable_name in variable_names_blacklist)):
        continue
      variable_dict_names.append(variable_name)
      if node.op == "VarHandleOp":
        variable_names.append(variable_name + "/Read/ReadVariableOp:0")
      else:
        variable_names.append(variable_name + ":0")
  if variable_names:
    returned_variables = sess.run(variable_names)
  else:
    returned_variables = []
  found_variables = dict(zip(variable_dict_names, returned_variables))
  logging.info("Froze %d variables.", len(returned_variables))

  output_graph_def = graph_pb2.GraphDef()
  how_many_converted = 0
  for input_node in inference_graph.node:
    output_node = node_def_pb2.NodeDef()
    if input_node.name in found_variables:
      output_node.op = "Const"
      output_node.name = input_node.name
      dtype = input_node.attr["dtype"]
      data = found_variables[input_node.name]
      output_node.attr["dtype"].CopyFrom(dtype)
      output_node.attr["value"].CopyFrom(
          attr_value_pb2.AttrValue(
              tensor=tensor_util.make_tensor_proto(
                  data, dtype=dtype.type, shape=data.shape)))
      how_many_converted += 1
    elif input_node.op == "ReadVariableOp" and (
        input_node.input[0] in found_variables):
      # The preceding branch converts all VarHandleOps of ResourceVariables to
      # constants, so we need to convert the associated ReadVariableOps to
      # Identity ops.
      output_node.op = "Identity"
      output_node.name = input_node.name
      output_node.input.extend([input_node.input[0]])
      output_node.attr["T"].CopyFrom(input_node.attr["dtype"])
      if "_class" in input_node.attr:
        output_node.attr["_class"].CopyFrom(input_node.attr["_class"])
    else:
      output_node.CopyFrom(input_node)
    output_graph_def.node.extend([output_node])

  output_graph_def.library.CopyFrom(inference_graph.library)
  logging.info("Converted %d variables to const ops.", how_many_converted)
  return output_graph_def
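

# A hedged freezing sketch (not part of the original module): runs a session,
# then folds the variable's current value into a Const node. The variable and
# output names are illustrative.
def _example_convert_variables_to_constants():
  """Sketch: freezes one variable so the GraphDef is self-contained."""
  # Local imports keep the sketch self-contained.
  from tensorflow.python.client import session
  from tensorflow.python.ops import math_ops
  from tensorflow.python.ops import variables

  with ops.Graph().as_default() as graph:
    v = variables.Variable(1.0, name="v")
    math_ops.multiply(v, 2.0, name="output")
    init = variables.global_variables_initializer()
  with session.Session(graph=graph) as sess:
    sess.run(init)
    frozen = convert_variables_to_constants(
        sess, graph.as_graph_def(), ["output"])
  # The result has "v" baked in as a Const, so no checkpoint is needed.
  return frozen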


@deprecation.deprecated(
    date=None,
    instructions="Use `tf.compat.v1.graph_util.remove_training_nodes`")
@tf_export(v1=["graph_util.remove_training_nodes"])
def remove_training_nodes(input_graph, protected_nodes=None):
  """Prunes out nodes that aren't needed for inference.

  There are nodes like Identity and CheckNumerics that are only useful
  during training, and can be removed in graphs that will be used for
  nothing but inference. Here we identify and remove them, returning an
  equivalent graph. To be specific, CheckNumerics nodes are always removed, and
  Identity nodes that aren't involved in control edges are spliced out so that
  their inputs and outputs are directly connected.

  Args:
    input_graph: Model to analyze and prune.
    protected_nodes: An optional list of names of nodes to be kept
      unconditionally. This is for example useful to preserve Identity output
      nodes.

  Returns:
    A GraphDef with the unnecessary nodes removed.
  """
  if not protected_nodes:
    protected_nodes = []

  types_to_remove = {"CheckNumerics": True}

  input_nodes = input_graph.node
  names_to_remove = {}
  for node in input_nodes:
    if node.op in types_to_remove and node.name not in protected_nodes:
      names_to_remove[node.name] = True

  nodes_after_removal = []
  for node in input_nodes:
    if node.name in names_to_remove:
      continue
    new_node = node_def_pb2.NodeDef()
    new_node.CopyFrom(node)
    input_before_removal = node.input
    del new_node.input[:]
    for full_input_name in input_before_removal:
      input_name = re.sub(r"^\^", "", full_input_name)
      if input_name in names_to_remove:
        continue
      new_node.input.append(full_input_name)
    nodes_after_removal.append(new_node)

  types_to_splice = {"Identity": True}
  control_input_names = set()
  node_names_with_control_input = set()
  for node in nodes_after_removal:
    for node_input in node.input:
      if "^" in node_input:
        control_input_names.add(node_input.replace("^", ""))
        node_names_with_control_input.add(node.name)

  names_to_splice = {}
  for node in nodes_after_removal:
    if node.op in types_to_splice and node.name not in protected_nodes:
      # We don't want to remove nodes that have control edge inputs, because
      # they might be involved in subtle dependency issues that removing them
      # will jeopardize.
      if node.name not in node_names_with_control_input:
        names_to_splice[node.name] = node.input[0]

  # We also don't want to remove nodes which are used as control edge inputs.
  names_to_splice = {name: value for name, value in names_to_splice.items()
                     if name not in control_input_names}

  nodes_after_splicing = []
  for node in nodes_after_removal:
    if node.name in names_to_splice:
      continue
    new_node = node_def_pb2.NodeDef()
    new_node.CopyFrom(node)
    input_before_removal = node.input
    del new_node.input[:]
    for full_input_name in input_before_removal:
      input_name = re.sub(r"^\^", "", full_input_name)
      while input_name in names_to_splice:
        full_input_name = names_to_splice[input_name]
        input_name = re.sub(r"^\^", "", full_input_name)
      new_node.input.append(full_input_name)
    nodes_after_splicing.append(new_node)

  output_graph = graph_pb2.GraphDef()
  output_graph.node.extend(nodes_after_splicing)
  return output_graph
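

# A hedged usage sketch (not part of the original module): splices an Identity
# node out of a small graph. The node names are illustrative.
def _example_remove_training_nodes():
  """Sketch: splices 'identity' out so 'add' reads directly from 'a'."""
  # Local imports keep the sketch self-contained.
  from tensorflow.python.framework import constant_op
  from tensorflow.python.ops import array_ops
  from tensorflow.python.ops import math_ops

  with ops.Graph().as_default() as graph:
    a = constant_op.constant(1.0, name="a")
    identity = array_ops.identity(a, name="identity")
    math_ops.add(identity, 1.0, name="add")
  pruned = remove_training_nodes(graph.as_graph_def())
  # "identity" is spliced out; "add" now lists "a" as its input.
  return [node.name for node in pruned.node]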