1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Helpers to manipulate a tensor graph in python. 16""" 17 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21import copy 22import re 23 24import six 25 26from tensorflow.core.framework import graph_pb2 27from tensorflow.core.framework import node_def_pb2 28from tensorflow.python import _proto_comparators 29from tensorflow.python.framework import dtypes 30from tensorflow.python.framework import ops 31from tensorflow.python.util import deprecation 32from tensorflow.python.util import lazy_loader 33from tensorflow.python.util.tf_export import tf_export 34 35# A normal import here would generate circular dependencies. 
convert_to_constants = lazy_loader.LazyLoader(
    "convert_to_constants", globals(),
    "tensorflow.python.framework.convert_to_constants")

# Op types that `must_run_on_cpu` pins to CPU when `pin_variables_on_cpu`
# is set.
_VARIABLE_OPS = {
    "Assign",
    "AssignAdd",
    "AssignSub",
    "Queue",
    "ScatterAdd",
    "ScatterSub",
    "ScatterUpdate",
    "TruncatedNormal",
    "Variable",
    "VariableV2",
}

# Control-flow op types plus Identity.
# NOTE(review): not referenced by the code visible here; presumably imported
# by other modules, so it is kept as-is (including being a list).
_CONTROL_FLOW_OP_NAMES_OR_IDENTITY = [
    "Switch",
    "Enter",
    "Exit",
    "Identity",
    "Merge",
    "NextIteration",
]


def _is_variable_op(op):
  """Returns True if the op type string `op` is one of `_VARIABLE_OPS`."""
  return op in _VARIABLE_OPS


@deprecation.deprecated(
    date=None,
    instructions="Use `tf.compat.v1.graph_util.must_run_on_cpu`")
@tf_export(v1=["graph_util.must_run_on_cpu"])
def must_run_on_cpu(node, pin_variables_on_cpu=False):
  """Returns True if the given node_def must run on CPU, otherwise False.

  Args:
    node: The node to be assigned to a device. Could be either an ops.Operation
      or NodeDef.
    pin_variables_on_cpu: If True, this function will return True if node_def
      represents a variable-related op (an op type in `_VARIABLE_OPS`).

  Returns:
    True if the given node must run on CPU, otherwise False.
  """
  if isinstance(node, ops.Operation):
    node_def = node.node_def
  else:
    assert isinstance(node, node_def_pb2.NodeDef)
    node_def = node

  op_type = node_def.op

  # Optionally pin variable-related ops on CPU.
  if pin_variables_on_cpu and _is_variable_op(op_type):
    return True

  # Constant operations producing a string or int32 must run on CPU.
  if op_type == "Const":
    dtype = node_def.attr["dtype"].type
    if dtype == dtypes.string or dtype == dtypes.int32:
      return True

  # int32 (Parallel)DynamicStitch is pinned to CPU.
  if op_type in ("DynamicStitch", "ParallelDynamicStitch"):
    if node_def.attr["T"].type == dtypes.int32:
      return True

  # Casts from int32 are pinned to CPU.
  if op_type == "Cast":
    if node_def.attr["SrcT"].type == dtypes.int32:
      return True

  return False


################################################################################
#
# device functions for use in with g.device(...)
#
################################################################################


def _node_name(n):
  """Returns the node name referenced by a NodeDef input string `n`.

  A leading "^" (control-input marker) is stripped; otherwise any ":<port>"
  suffix is stripped.
  """
  if n.startswith("^"):
    return n[1:]
  else:
    return n.split(":")[0]


def _get_colocated_node_name(colocated_node_name):
  """Decodes colocated node name and returns it without loc:@ prepended.

  Args:
    colocated_node_name: A bytes entry from a node's `_class` attr list.

  Returns:
    The UTF-8 decoded name with a leading "loc:@" removed, if present.
  """
  colocated_node_decoded = colocated_node_name.decode("utf-8")
  if colocated_node_decoded.startswith("loc:@"):
    return colocated_node_decoded[5:]
  return colocated_node_decoded


def _extract_graph_summary(graph_def):
  """Extracts useful information from the graph and returns them.

  Returns:
    A tuple `(name_to_input_name, name_to_node, name_to_seq_num)`:
      name_to_input_name: node name -> list of names it depends on (data and
        control inputs, plus colocation targets from the `_class` attr).
      name_to_node: node name -> NodeDef.
      name_to_seq_num: node name -> position of the node in `graph_def.node`,
        so the original output order can be restored.
  """
  name_to_input_name = {}  # Keyed by the dest node name.
  name_to_node = {}  # Keyed by node name.
  name_to_seq_num = {}  # Keyed by node name.

  for seq, node in enumerate(graph_def.node):
    n = _node_name(node.name)
    name_to_node[n] = node
    inputs = [_node_name(x) for x in node.input]
    # Prevent colocated nodes from being lost.
    if "_class" in node.attr:
      inputs.extend(
          _get_colocated_node_name(colocated_node_name)
          for colocated_node_name in node.attr["_class"].list.s)
    name_to_input_name[n] = inputs
    name_to_seq_num[n] = seq
  return name_to_input_name, name_to_node, name_to_seq_num


def _assert_nodes_are_present(name_to_node, nodes):
  """Assert that nodes are present in the graph.

  Raises:
    AssertionError: If a name in `nodes` is not a key of `name_to_node`.
  """
  for d in nodes:
    # Explicit raise (rather than `assert`) so the check survives `python -O`.
    if d not in name_to_node:
      raise AssertionError("%s is not in graph" % d)


def _bfs_for_reachable_nodes(target_nodes, name_to_input_name):
  """Breadth first search for reachable nodes from target nodes.

  Args:
    target_nodes: Iterable of node names to start the search from.
    name_to_input_name: Dict mapping a node name to the names it depends on,
      as built by `_extract_graph_summary`.

  Returns:
    The set of names reachable from `target_nodes` by following dependencies,
    including the target names themselves.
  """
  import collections  # pylint: disable=g-import-not-at-top

  nodes_to_keep = set()
  # A deque gives O(1) pops from the front; `del some_list[0]` is O(n).
  next_to_visit = collections.deque(target_nodes)
  while next_to_visit:
    node = next_to_visit.popleft()
    if node in nodes_to_keep:
      # Already visited this node.
      continue
    nodes_to_keep.add(node)
    if node in name_to_input_name:
      next_to_visit.extend(name_to_input_name[node])
  return nodes_to_keep


@deprecation.deprecated(
    date=None,
    instructions="Use `tf.compat.v1.graph_util.extract_sub_graph`")
@tf_export(v1=["graph_util.extract_sub_graph"])
def extract_sub_graph(graph_def, dest_nodes):
  """Extract the subgraph that can reach any of the nodes in 'dest_nodes'.

  Args:
    graph_def: A graph_pb2.GraphDef proto.
    dest_nodes: A list (or other non-string iterable) of strings specifying the
      destination node names.
  Returns:
    The GraphDef of the sub-graph. Nodes keep their original relative order,
    and `library` and `versions` are copied from `graph_def`.

  Raises:
    TypeError: If 'graph_def' is not a graph_pb2.GraphDef proto, or if
      'dest_nodes' is a single string.
    AssertionError: If a name in 'dest_nodes' is not a node of 'graph_def'.
  """

  if not isinstance(graph_def, graph_pb2.GraphDef):
    raise TypeError("graph_def must be a graph_pb2.GraphDef proto.")

  if isinstance(dest_nodes, six.string_types):
    raise TypeError("dest_nodes must be a list.")

  # dest_nodes is iterated twice below; materialize it once so a one-shot
  # iterator is not silently exhausted by the presence check.
  dest_nodes = list(dest_nodes)

  name_to_input_name, name_to_node, name_to_seq_num = _extract_graph_summary(
      graph_def)
  _assert_nodes_are_present(name_to_node, dest_nodes)

  nodes_to_keep = _bfs_for_reachable_nodes(dest_nodes, name_to_input_name)

  # Emit kept nodes in their original GraphDef order.
  nodes_to_keep_list = sorted(nodes_to_keep, key=lambda n: name_to_seq_num[n])

  # Now construct the output GraphDef.
  out = graph_pb2.GraphDef()
  out.node.extend([copy.deepcopy(name_to_node[n]) for n in nodes_to_keep_list])
  out.library.CopyFrom(graph_def.library)
  out.versions.CopyFrom(graph_def.versions)

  return out


@deprecation.deprecated(
    date=None,
    instructions="Use `tf.compat.v1.graph_util.tensor_shape_from_node_def_name`"
)
@tf_export(v1=["graph_util.tensor_shape_from_node_def_name"])
def tensor_shape_from_node_def_name(graph, input_name):
  """Convenience function to get a shape from a NodeDef's input string."""
  # To get a tensor, the name must be in the form <input>:<port>, for example
  # 'Mul:0'. The GraphDef input strings don't always have the port specified
  # though, so if there isn't a colon we need to add a default ':0' to the end.
  if ":" not in input_name:
    canonical_name = input_name + ":0"
  else:
    canonical_name = input_name
  tensor = graph.get_tensor_by_name(canonical_name)
  return tensor.get_shape()


@deprecation.deprecated(
    date=None,
    instructions="Use `tf.compat.v1.graph_util.convert_variables_to_constants`")
@tf_export(v1=["graph_util.convert_variables_to_constants"])
def convert_variables_to_constants(sess,
                                   input_graph_def,
                                   output_node_names,
                                   variable_names_whitelist=None,
                                   variable_names_blacklist=None):
  """Replaces all the variables in a graph with constants of the same values.

  If you have a trained graph containing Variable ops, it can be convenient to
  convert them all to Const ops holding the same values. This makes it possible
  to describe the network fully with a single GraphDef file, and allows the
  removal of a lot of ops related to loading and saving the variables.

  Args:
    sess: Active TensorFlow session containing the variables.
    input_graph_def: GraphDef object holding the network.
    output_node_names: List of name strings for the result nodes of the graph.
    variable_names_whitelist: The set of variable names to convert (by default,
                              all variables are converted).
    variable_names_blacklist: The set of variable names to omit converting
                              to constants.

  Returns:
    GraphDef containing a simplified version of the original. Its `versions`
    field is cleared.

  Raises:
    RuntimeError: if a DT_RESOURCE op is found whose ancestor Variables are both
                  blacklisted AND whitelisted for freezing.
  """
  ret = convert_to_constants.convert_variables_to_constants_from_session_graph(
      session=sess,
      graph_def=input_graph_def,
      output_node_names=output_node_names,
      variable_names_allowlist=variable_names_whitelist,
      variable_names_denylist=variable_names_blacklist)
  # The previous code logic generated an empty versions field, we clear it here
  # to maintain backwards compatibility.
  ret.versions.Clear()
  return ret


@deprecation.deprecated(
    date=None,
    instructions="Use `tf.compat.v1.graph_util.remove_training_nodes`")
@tf_export(v1=["graph_util.remove_training_nodes"])
def remove_training_nodes(input_graph, protected_nodes=None):
  """Prunes out nodes that aren't needed for inference.

  There are nodes like Identity and CheckNumerics that are only useful
  during training, and can be removed in graphs that will be used for
  nothing but inference. Here we identify and remove them, returning an
  equivalent graph. To be specific, CheckNumerics nodes are always removed
  (together with every input edge pointing at them), and Identity nodes that
  aren't involved in control edges are spliced out so that their input and
  outputs are directly connected.

  Args:
    input_graph: Model to analyze and prune (its `node` field is read).
    protected_nodes: An optional list of names of nodes to be kept
      unconditionally. This is for example useful to preserve Identity output
      nodes.

  Returns:
    A `graph_pb2.GraphDef` holding the remaining nodes. Only its `node` field
    is populated; `library` and `versions` of `input_graph` are not copied.
  """
  if not protected_nodes:
    protected_nodes = []

  types_to_remove = {"CheckNumerics"}
  types_to_splice = {"Identity"}

  input_nodes = input_graph.node

  # Pass 1: drop CheckNumerics nodes and every input edge that refers to them.
  names_to_remove = {
      node.name
      for node in input_nodes
      if node.op in types_to_remove and node.name not in protected_nodes
  }

  nodes_after_removal = []
  for node in input_nodes:
    if node.name in names_to_remove:
      continue
    new_node = node_def_pb2.NodeDef()
    new_node.CopyFrom(node)
    del new_node.input[:]
    # _node_name strips both a "^" control marker and a ":<port>" suffix, so
    # "check", "^check" and "check:0" all match a removed node.
    new_node.input.extend([
        full_input_name for full_input_name in node.input
        if _node_name(full_input_name) not in names_to_remove
    ])
    nodes_after_removal.append(new_node)

  # Record which nodes take control inputs and which nodes are used as control
  # inputs; neither kind is spliced out below.
  control_input_names = set()
  node_names_with_control_input = set()
  for node in nodes_after_removal:
    for node_input in node.input:
      if node_input.startswith("^"):
        control_input_names.add(node_input[1:])
        node_names_with_control_input.add(node.name)

  # Pass 2: map each spliceable Identity node to its data input.
  names_to_splice = {}
  for node in nodes_after_removal:
    if node.op not in types_to_splice or node.name in protected_nodes:
      continue
    # We don't want to remove nodes that have control edge inputs, because
    # they might be involved in subtle dependency issues that removing them
    # will jeopardize.
    if node.name in node_names_with_control_input:
      continue
    # We also don't want to remove nodes which are used as control edge inputs.
    if node.name in control_input_names:
      continue
    # An Identity may have lost its only input in pass 1 (if it consumed a
    # removed CheckNumerics node); there is nothing to splice it to, so keep it.
    if not node.input:
      continue
    names_to_splice[node.name] = node.input[0]

  nodes_after_splicing = []
  for node in nodes_after_removal:
    if node.name in names_to_splice:
      continue
    new_node = node_def_pb2.NodeDef()
    new_node.CopyFrom(node)
    del new_node.input[:]
    for full_input_name in node.input:
      # Follow chains of spliced Identity nodes back to a surviving producer.
      # Matching on the bare node name also rewires port-qualified references
      # such as "identity:0", which would otherwise dangle once the node is
      # dropped. Control inputs never match: nodes used as control inputs were
      # excluded from names_to_splice above.
      input_name = _node_name(full_input_name)
      while input_name in names_to_splice:
        full_input_name = names_to_splice[input_name]
        input_name = _node_name(full_input_name)
      new_node.input.append(full_input_name)
    nodes_after_splicing.append(new_node)

  output_graph = graph_pb2.GraphDef()
  output_graph.node.extend(nodes_after_splicing)
  return output_graph


def graph_defs_equal(graph_def_1: graph_pb2.GraphDef,
                     graph_def_2: graph_pb2.GraphDef,
                     treat_nan_as_equal: bool = False) -> bool:
  """Returns True iff the graph def arguments are structurally equivalent.

  The notion of equivalence encoded here checks that the set of NodeDefs in
  the GraphDef's function library and main graph body are identical.
  Additionally, it checks that the functions in the function library are equal
  as sets.

  Args:
    graph_def_1: Instance of `graph_pb2.GraphDef` to compare.
    graph_def_2: Instance of `graph_pb2.GraphDef` to compare.
    treat_nan_as_equal: Boolean indicating whether or not to treat nan
      floating-point values as equal. This is crucial for any equivalence
      relation defined over GraphDefs, to ensure symmetry.

  Returns:
    Boolean indicating structural equivalence as described above.

  Raises:
    TypeError: If either of the GraphDefs are not instances of
      `graph_pb2.GraphDef`.
  """
  if not isinstance(graph_def_1, graph_pb2.GraphDef):
    raise TypeError("graph_def_1 must be a graph_pb2.GraphDef proto.")
  if not isinstance(graph_def_2, graph_pb2.GraphDef):
    raise TypeError("graph_def_2 must be a graph_pb2.GraphDef proto.")
  # The comparison itself runs in native code on the serialized protos.
  options = _proto_comparators.ProtoComparisonOptions(treat_nan_as_equal)
  return _proto_comparators.EqualsGraphDef(graph_def_1.SerializeToString(),
                                           graph_def_2.SerializeToString(),
                                           options)