# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Utility to convert FunctionDef to GraphDef and Graph."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import itertools


from tensorflow.core.framework import function_pb2
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import tensor_shape_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.core.framework import versions_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import cpp_shape_inference_pb2
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import versions
from tensorflow.python.framework.func_graph import FuncGraph
from tensorflow.python.ops import resource_variable_ops


def function_def_to_graph(fdef, input_shapes=None):
  """Converts a FunctionDef to a FuncGraph (sub-class Graph).

  The returned FuncGraph's `name`, `inputs` and `outputs` fields will be set.
  The input tensors are represented as placeholders.

  Note: `FuncGraph.inputs` and `FuncGraph.captures` are not set and may be set
  by the caller.

  Args:
    fdef: FunctionDef.
    input_shapes: Optional. A list of TensorShape objects of the shapes of
      function inputs. Defaults to the function's "_input_shapes" attribute. If
      specified, its length must match length of `fdef.signature.input_arg`. If
      a shape is None, the corresponding input placeholder will have unknown
      shape.

  Returns:
    A FuncGraph.
  """
  func_graph = FuncGraph(fdef.signature.name)
  # Fall back to shapes recorded at function-definition time when the caller
  # does not supply any.
  if input_shapes is None:
    input_shapes_attr = fdef.attr.get("_input_shapes", None)
    if input_shapes_attr is not None:
      input_shapes = input_shapes_attr.list.shape
  graph_def, nested_to_flat_tensor_name = function_def_to_graph_def(
      fdef, input_shapes)

  with func_graph.as_default():
    # Add all function nodes to the graph.
    importer.import_graph_def_for_function(graph_def, name="")

    # Initialize fields specific to FuncGraph.

    # inputs
    input_tensor_names = [
        nested_to_flat_tensor_name[arg.name] for arg in fdef.signature.input_arg
    ]
    func_graph.inputs = [
        func_graph.get_tensor_by_name(name) for name in input_tensor_names
    ]

    # outputs
    output_tensor_names = [
        nested_to_flat_tensor_name[fdef.ret[arg.name]]
        for arg in fdef.signature.output_arg
    ]
    func_graph.outputs = [
        func_graph.get_tensor_by_name(name) for name in output_tensor_names
    ]
    # Ops named in `control_ret` become the graph's control outputs (ops that
    # must run even though no data output depends on them).
    func_graph.control_outputs = [
        func_graph.get_operation_by_name(fdef.control_ret[ret_name])
        for ret_name in fdef.signature.control_output
    ]

    # Propagate resource-handle shape/dtype info onto resource inputs/outputs.
    _set_handle_data(func_graph, fdef)

    # Restore per-op output shapes recorded in the "_output_shapes" attr.
    for node in graph_def.node:
      output_shapes = node.attr.get("_output_shapes", None)
      if output_shapes is not None:
        op = func_graph.get_operation_by_name(node.name)
        # _output_shapes for functions can sometimes be too long because the
        # output-intermediates-for-gradients version of the function was
        # substituted before saving. We'll accept that here. (See b/133666530).
        for output_index, shape in enumerate(
            output_shapes.list.shape[:len(op.outputs)]):
          op.outputs[output_index].set_shape(shape)
    # Map each output tensor's id back to its signature arg name so callers
    # can recover the user-visible output names.
    output_names = {}
    for ret_arg_def, tensor_name in zip(
        fdef.signature.output_arg, output_tensor_names):
      output_names[ops.tensor_id(
          func_graph.get_tensor_by_name(tensor_name))] = (
              ret_arg_def.name)
    func_graph._output_names = output_names  # pylint: disable=protected-access
  return func_graph


def is_function(fname):
  """Checks for a function definition with `fname` in the current context."""
  if context.executing_eagerly():
    return context.context().has_function(fname)
  else:
    # Walk the chain of enclosing graphs (FuncGraphs expose `outer_graph`)
    # until the function is found or there is no outer graph left.
    graph = ops.get_default_graph()
    while graph is not None:
      if graph._is_function(fname):  # pylint: disable=protected-access
        return True
      if hasattr(graph, "outer_graph"):
        graph = graph.outer_graph
      else:
        return False


def function_def_to_graph_def(fdef, input_shapes=None):
  """Convert a FunctionDef to a GraphDef.

  Steps:
  1. Creates placeholder nodes corresponding to inputs in
     `FunctionDef.signature.input_arg`.
  2. Adds NodeDefs in `FunctionDef.node_def` to `GraphDef.node`.
  3. Renames inputs of all nodes to use the convention of GraphDef instead of
     FunctionDef. See comment on `FunctionDef.node_def` on how the tensor naming
     in FunctionDefs is different from GraphDefs.

  Args:
    fdef: FunctionDef.
    input_shapes: Optional. A list of TensorShape objects of the shapes of
      function inputs. If specified, its length must match length of
      `fdef.signature.input_arg`. If a shape is None, the corresponding input
      placeholder will have unknown shape.

  Returns:
    A tuple of (GraphDef, dict<string, string>). The dict contains a mapping
    from nested tensor names (in FunctionDef) to flattened names (in GraphDef).

  Raises:
    ValueError: If the length of input_shapes does not match the number of
      input_args or if the FunctionDef is invalid.
  """
  graph_def = graph_pb2.GraphDef()
  graph_def.versions.CopyFrom(
      versions_pb2.VersionDef(
          producer=versions.GRAPH_DEF_VERSION,
          min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER))

  default_graph = ops.get_default_graph()

  # Tracks function names already copied into `graph_def.library` so each
  # referenced function is copied at most once.
  copied_functions = set()

  if input_shapes and len(input_shapes) != len(fdef.signature.input_arg):
    raise ValueError("Length of `input_shapes` must match the number "
                     f"of `input_arg`s in `fdef`. Got "
                     f"{len(input_shapes)} `input_shapes` and "
                     f"{len(fdef.signature.input_arg)} `input_arg`s.")

  # 1. Create placeholders for input nodes.
  for i, arg_def in enumerate(fdef.signature.input_arg):
    node_def = graph_def.node.add()
    node_def.name = arg_def.name
    node_def.op = "Placeholder"
    node_def.attr["dtype"].type = arg_def.type
    if input_shapes and input_shapes[i] is not None:
      input_shape = input_shapes[i]
      if not isinstance(input_shape, tensor_shape_pb2.TensorShapeProto):
        input_shape = input_shape.as_proto()
      node_def.attr["shape"].shape.CopyFrom(input_shape)
    arg_attrs = fdef.arg_attr[i].attr
    for k in arg_attrs:
      # Only copy internal attributes. Normal attributes for nodes cannot be
      # applied to these Placeholder nodes.
      if k == "_output_shapes":
        node_def.attr["shape"].shape.CopyFrom(arg_attrs[k].list.shape[0])
      elif k.startswith("_"):
        node_def.attr[k].CopyFrom(arg_attrs[k])

  # 2. Copy all body NodeDefs to the GraphDef.
  graph_def.node.extend(fdef.node_def)

  # 3. Perform the renaming.

  # Build the tensor name mapping then flatten the tensor names.
  # See comment on `FunctionDef.node_def` on how the tensor naming in
  # FunctionDefs is different from GraphDefs.
  nested_to_flat_tensor_name = {}

  for arg_def in fdef.signature.input_arg:
    nested_to_flat_tensor_name[arg_def.name] = "{}:0".format(arg_def.name)
    # Control inputs ("^name") keep their name unchanged.
    control_name = "^" + arg_def.name
    nested_to_flat_tensor_name[control_name] = control_name

  for node_def in fdef.node_def:
    # Resolve the node's op: it may itself be a function registered on the
    # default graph or one of its enclosing graphs.
    graph = default_graph
    while True:
      f = graph._functions.get(node_def.op, None)  # pylint: disable=protected-access
      if f is not None or not hasattr(graph, "outer_graph"):
        break
      graph = graph.outer_graph

    if f is not None:
      op_def = f.definition.signature
      if node_def.op not in copied_functions:
        # Since this function is referenced as an op type, we have no choice but
        # to copy it into the GraphDef if we want downstream tools to process
        # it.
        graph_def.library.function.add().CopyFrom(f.definition)
        copied_functions.add(node_def.op)
        if f.grad_func_name:
          grad_def = function_pb2.GradientDef()
          grad_def.function_name = f.name
          grad_def.gradient_func = f.grad_func_name
          graph_def.library.gradient.extend([grad_def])
    else:
      op_def = default_graph._get_op_def(node_def.op)  # pylint: disable=protected-access

    # Validate that any function-valued attrs reference known functions.
    for attr in op_def.attr:
      if attr.type == "func":
        fname = node_def.attr[attr.name].func.name
        if not is_function(fname):
          raise ValueError(f"Function {fname} was not found. Please make sure "
                           "the FunctionDef `fdef` is correct.")
      elif attr.type == "list(func)":
        for fn in node_def.attr[attr.name].list.func:
          fname = fn.name
          if not is_function(fname):
            raise ValueError(f"Function {fname} was not found. Please make "
                             "sure the FunctionDef `fdef` is correct.")

    # Iterate over output_args in op_def to build the map.
    # Index of the output tensor in the flattened list of *all* output
    # tensors of the op.
    flattened_index = 0
    for arg_def in op_def.output_arg:
      num_args = _get_num_args(arg_def, node_def)
      for i in range(num_args):
        # Map tensor names from "node_name:output_arg_name:index" to
        # "node_name:flattened_index".
        nested_name = "{}:{}:{}".format(node_def.name, arg_def.name, i)
        flat_name = "{}:{}".format(node_def.name, flattened_index)
        nested_to_flat_tensor_name[nested_name] = flat_name
        flattened_index += 1
    control_name = "^" + node_def.name
    nested_to_flat_tensor_name[control_name] = control_name

  # Update inputs of all nodes in graph.
  for node_def in graph_def.node:
    for i in range(len(node_def.input)):
      node_def.input[i] = nested_to_flat_tensor_name[node_def.input[i]]

  return graph_def, nested_to_flat_tensor_name


# Based on implementation in core/framework/node_def_util.cc::ComputeArgRange.
def _get_num_args(arg_def, node_def):
  """Returns the number of flattened tensors `arg_def` expands to on `node_def`.

  An output arg may be repeated (`number_attr`), a list of varying types
  (`type_list_attr`), or a single tensor (`type_attr` / fixed `type`).
  """
  if arg_def.number_attr:
    return node_def.attr[arg_def.number_attr].i
  elif arg_def.type_list_attr:
    return len(node_def.attr[arg_def.type_list_attr].list.type)
  elif arg_def.type_attr or arg_def.type != types_pb2.DT_INVALID:
    return 1
  else:
    raise ValueError(f"Invalid arg_def:\n\n{arg_def}. Please make sure the "
                     "FunctionDef `fdef` is correct.")


def _set_handle_data(func_graph, fdef):
  """Adds handle data for resource type inputs and outputs."""
  for tensor, arg_def in itertools.chain(
      zip(func_graph.inputs, fdef.signature.input_arg),
      zip(func_graph.outputs, fdef.signature.output_arg)):
    if arg_def.handle_data:
      # Only the first (shape, dtype) pair recorded in the signature is used.
      shape_and_dtype = arg_def.handle_data[0]
      handle_data = cpp_shape_inference_pb2.CppShapeInferenceResult.HandleData()
      handle_data.is_set = True
      handle_data.shape_and_type.append(
          cpp_shape_inference_pb2.CppShapeInferenceResult.HandleShapeAndType(
              shape=shape_and_dtype.shape, dtype=shape_and_dtype.dtype))
      resource_variable_ops._set_handle_shapes_and_types(  # pylint: disable=protected-access
          tensor, handle_data, True)