# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Various functions for graph rerouting."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.graph_editor import subgraph as _subgraph
from tensorflow.contrib.graph_editor import util as _util
from tensorflow.python.framework import ops as _tf_ops

from tensorflow.python.util.all_util import remove_undocumented

_allowed_symbols = [
    "swap_ts",
    "reroute_ts",
    "swap_inputs",
    "reroute_inputs",
    "swap_outputs",
    "reroute_outputs",
    "swap_ios",
    "reroute_ios",
    "remove_control_inputs",
    "add_control_inputs",
]


def _check_ts_compatibility(ts0, ts1):
  """Make sure the shapes and dtypes of the two lists of tensors are compatible.

  Args:
    ts0: an object convertible to a list of `tf.Tensor`.
    ts1: an object convertible to a list of `tf.Tensor`.
  Raises:
    ValueError: if any pair of tensors (same index in ts0 and ts1) has
      a dtype or a shape which is not compatible.
  """
  ts0 = _util.make_list_of_t(ts0)
  ts1 = _util.make_list_of_t(ts1)
  if len(ts0) != len(ts1):
    raise ValueError("ts0 and ts1 have different sizes: {} != {}".format(
        len(ts0), len(ts1)))
  for t0, t1 in zip(ts0, ts1):
    # check dtype
    dtype0, dtype1 = t0.dtype, t1.dtype
    if not dtype0.is_compatible_with(dtype1):
      raise ValueError("Dtypes {} and {} are not compatible.".format(dtype0,
                                                                     dtype1))
    # check shape
    shape0, shape1 = t0.get_shape(), t1.get_shape()
    if not shape0.is_compatible_with(shape1):
      raise ValueError("Shapes {} and {} are not compatible.".format(shape0,
                                                                     shape1))


class _RerouteMode(object):
  """Enums for reroute's mode.

  swap: the ends of tensors a and b are swapped.
  a2b: the end of tensor a is also rerouted to the former end of tensor b
    (the end of b is left dangling).
  b2a: the end of tensor b is also rerouted to the former end of tensor a
    (the end of a is left dangling).
  """
  swap, a2b, b2a = range(3)

  @classmethod
  def check(cls, mode):
    """Check the swap mode.

    Args:
      mode: an integer representing one of the modes.
    Returns:
      A tuple `(a2b, b2a)` of booleans indicating which rerouting needs
      to be done.
    Raises:
      ValueError: if mode is outside the enum range.
    """
    if mode == cls.swap:
      return True, True
    elif mode == cls.b2a:
      return False, True
    elif mode == cls.a2b:
      return True, False
    else:
      raise ValueError("Unknown _RerouteMode: {}".format(mode))
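

# Illustrative sketch (not part of the public API): how `_RerouteMode.check`
# maps each mode onto the `(a2b, b2a)` pair of booleans consumed by
# `_reroute_ts` below. The name `_example_reroute_mode_check` is only for
# illustration.
def _example_reroute_mode_check():
  """Returns the (a2b, b2a) decision for each of the three reroute modes."""
  return {
      "swap": _RerouteMode.check(_RerouteMode.swap),  # (True, True)
      "a2b": _RerouteMode.check(_RerouteMode.a2b),    # (True, False)
      "b2a": _RerouteMode.check(_RerouteMode.b2a),    # (False, True)
  }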


def _reroute_t(t0, t1, consumers1, can_modify=None, cannot_modify=None):
  """Reroute the consumers of t1 (given in `consumers1`) to the tensor t0.

  Warning: this function is directly manipulating the internals of the
  `tf.Graph`.

  Args:
    t0: a tf.Tensor.
    t1: a tf.Tensor.
    consumers1: the consumers of t1 which need to be rerouted.
    can_modify: iterable of operations which can be modified. Any operation
      outside of can_modify will be left untouched by this function.
    cannot_modify: iterable of operations which cannot be modified.
      Any operation within cannot_modify will be left untouched by this
      function.
  Returns:
    The number of individual modifications made by the function.
  """
  nb_update_inputs = 0
  if can_modify is not None:
    consumers1 &= can_modify
  if cannot_modify is not None:
    consumers1 -= cannot_modify
  consumers1_indices = {}
  for consumer1 in consumers1:
    consumers1_indices[consumer1] = [i for i, t in enumerate(consumer1.inputs)
                                     if t is t1]
  for consumer1 in consumers1:
    for i in consumers1_indices[consumer1]:
      consumer1._update_input(i, t0)  # pylint: disable=protected-access
      nb_update_inputs += 1
  return nb_update_inputs


def _reroute_ts(ts0, ts1, mode, can_modify=None, cannot_modify=None):
  """Reroute the ends of each pair of tensors (t0, t1) in ts0 x ts1.

  This function is the backbone of the Graph Editor. It is essentially a thin
  wrapper on top of tf.Operation._update_input.

  Given a pair of tensors (t0, t1) in ts0 x ts1, this function reroutes the
  ends of t0 and t1 in one of three possible ways:
  1) The reroute mode is "a<->b" or "b<->a": the tensors' ends are swapped.
  After this operation, the previous consumers of t0 are now consumers of t1
  and vice-versa.
  2) The reroute mode is "a->b": the end of t0 is rerouted to the former end
  of t1 (which is left dangling). After this operation, the previous consumers
  of t0 are still consuming t0 but the previous consumers of t1 are now also
  consuming t0. The tensor t1 has no consumer.
  3) The reroute mode is "b->a": this mode is the symmetric of the "a->b" mode.

  Note that this function is rerouting the end of two tensors, not the start.
  Rerouting the start of two tensors is not supported by this library. The
  reason for that is the following: TensorFlow, by design, creates a strong
  bond between an op and its output tensors. This Graph Editor follows this
  design and treats an operation A and its generating tensors {t_i} as an
  entity which cannot be broken. In other words, an op cannot be detached from
  any of its output tensors, ever. But it is possible to detach an op from its
  input tensors, which is what this function concerns itself with.

  Warning: this function is directly manipulating the internals of the
  `tf.Graph`.

  Args:
    ts0: an object convertible to a list of `tf.Tensor`.
    ts1: an object convertible to a list of `tf.Tensor`.
    mode: what to do with those tensors: "a<->b" or "b<->a" for swapping and
      "a->b" or "b->a" for one-direction rerouting.
    can_modify: iterable of operations which can be modified. Any operation
      outside of can_modify will be left untouched by this function.
    cannot_modify: iterable of operations which cannot be modified.
      Any operation within cannot_modify will be left untouched by this
      function.
  Returns:
    The number of individual modifications made by the function.
  Raises:
    TypeError: if `ts0` or `ts1` cannot be converted to a list of `tf.Tensor`.
    TypeError: if `can_modify` or `cannot_modify` is not `None` and cannot be
      converted to a list of `tf.Operation`.
  """
  a2b, b2a = _RerouteMode.check(mode)
  ts0 = _util.make_list_of_t(ts0)
  ts1 = _util.make_list_of_t(ts1)
  _check_ts_compatibility(ts0, ts1)
  if cannot_modify is not None:
    cannot_modify = frozenset(_util.make_list_of_op(cannot_modify))
  if can_modify is not None:
    can_modify = frozenset(_util.make_list_of_op(can_modify))
  nb_update_inputs = 0
  precomputed_consumers = []
  # precompute the consumers to avoid issues with repeated tensors:
  for t0, t1 in zip(ts0, ts1):
    consumers0 = set(t0.consumers())
    consumers1 = set(t1.consumers())
    precomputed_consumers.append((consumers0, consumers1))
  for t0, t1, consumers in zip(ts0, ts1, precomputed_consumers):
    if t0 is t1:
      continue  # Silently ignore identical tensors.
    consumers0, consumers1 = consumers
    if a2b:
      nb_update_inputs += _reroute_t(t0, t1, consumers1, can_modify,
                                     cannot_modify)
    if b2a:
      nb_update_inputs += _reroute_t(t1, t0, consumers0, can_modify,
                                     cannot_modify)
  return nb_update_inputs
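

# Illustrative sketch (not part of the public API): `_reroute_ts` in "swap"
# mode exchanges the consumers of two tensors. The function name
# `_example_reroute_ts_swap` and the op names below are for illustration only;
# the lazy imports are standard TF 1.x internal modules and keep the sketch
# self-contained.
def _example_reroute_ts_swap():
  """Swaps the ends of two tensors and returns the resulting values."""
  from tensorflow.python.client import session
  from tensorflow.python.framework import constant_op
  from tensorflow.python.ops import math_ops

  with _tf_ops.Graph().as_default():
    a = constant_op.constant(1.0, name="a")
    b = constant_op.constant(10.0, name="b")
    c0 = math_ops.add(a, a, name="c0")  # originally consumes `a` twice
    c1 = math_ops.add(b, b, name="c1")  # originally consumes `b` twice
    # After the swap, the consumers of `a` consume `b` and vice-versa.
    _reroute_ts([a], [b], _RerouteMode.swap)
    with session.Session() as sess:
      return sess.run([c0, c1])  # [20.0, 2.0]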
179 """ 180 a2b, b2a = _RerouteMode.check(mode) 181 ts0 = _util.make_list_of_t(ts0) 182 ts1 = _util.make_list_of_t(ts1) 183 _check_ts_compatibility(ts0, ts1) 184 if cannot_modify is not None: 185 cannot_modify = frozenset(_util.make_list_of_op(cannot_modify)) 186 if can_modify is not None: 187 can_modify = frozenset(_util.make_list_of_op(can_modify)) 188 nb_update_inputs = 0 189 precomputed_consumers = [] 190 # precompute consumers to avoid issue with repeated tensors: 191 for t0, t1 in zip(ts0, ts1): 192 consumers0 = set(t0.consumers()) 193 consumers1 = set(t1.consumers()) 194 precomputed_consumers.append((consumers0, consumers1)) 195 for t0, t1, consumers in zip(ts0, ts1, precomputed_consumers): 196 if t0 is t1: 197 continue # Silently ignore identical tensors. 198 consumers0, consumers1 = consumers 199 if a2b: 200 nb_update_inputs += _reroute_t(t0, t1, consumers1, can_modify, 201 cannot_modify) 202 if b2a: 203 nb_update_inputs += _reroute_t(t1, t0, consumers0, can_modify, 204 cannot_modify) 205 return nb_update_inputs 206 207 208def swap_ts(ts0, ts1, can_modify=None, cannot_modify=None): 209 """For each tensor's pair, swap the end of (t0,t1). 210 211 B0 B1 B0 B1 212 | | => X 213 A0 A1 A0 A1 214 215 Args: 216 ts0: an object convertible to a list of `tf.Tensor`. 217 ts1: an object convertible to a list of `tf.Tensor`. 218 can_modify: iterable of operations which can be modified. Any operation 219 outside within_ops will be left untouched by this function. 220 cannot_modify: iterable of operations which cannot be modified. 221 Any operation within cannot_modify will be left untouched by this 222 function. 223 Returns: 224 The number of individual modifications made by the function. 225 Raises: 226 TypeError: if ts0 or ts1 cannot be converted to a list of tf.Tensor. 227 TypeError: if can_modify or cannot_modify is not None and cannot be 228 converted to a list of tf.Operation. 229 """ 230 return _reroute_ts(ts0, ts1, _RerouteMode.swap, can_modify, cannot_modify) 231 232 233def reroute_ts(ts0, ts1, can_modify=None, cannot_modify=None): 234 """For each tensor's pair, replace the end of t1 by the end of t0. 235 236 B0 B1 B0 B1 237 | | => |/ 238 A0 A1 A0 A1 239 240 The end of the tensors in ts1 are left dangling. 241 242 Args: 243 ts0: an object convertible to a list of `tf.Tensor`. 244 ts1: an object convertible to a list of `tf.Tensor`. 245 can_modify: iterable of operations which can be modified. Any operation 246 outside within_ops will be left untouched by this function. 247 cannot_modify: iterable of operations which cannot be modified. Any 248 operation within cannot_modify will be left untouched by this function. 249 Returns: 250 The number of individual modifications made by the function. 251 Raises: 252 TypeError: if ts0 or ts1 cannot be converted to a list of tf.Tensor. 253 TypeError: if can_modify or cannot_modify is not None and cannot be 254 converted to a list of tf.Operation. 255 """ 256 return _reroute_ts(ts0, ts1, _RerouteMode.a2b, can_modify, cannot_modify) 257 258 259def _reroute_sgv_remap(sgv0, sgv1, mode): 260 """Remap in place the inputs of two subgraph views to mimic the reroute. 261 262 This function is meant to used by reroute_inputs only. 263 264 Args: 265 sgv0: the first subgraph to have its inputs remapped. 266 sgv1: the second subgraph to have its inputs remapped. 267 mode: reroute mode, see _reroute_ts(...). 268 Raises: 269 TypeError: if svg0 or svg1 are not SubGraphView. 270 ValueError: if sgv0 and sgv1 do not belong to the same graph. 
271 """ 272 a2b, b2a = _RerouteMode.check(mode) 273 if not isinstance(sgv0, _subgraph.SubGraphView): 274 raise TypeError("Expected a SubGraphView, got {}".format(type(sgv0))) 275 if not isinstance(sgv1, _subgraph.SubGraphView): 276 raise TypeError("Expected a SubGraphView, got {}".format(type(sgv1))) 277 _util.check_graphs(sgv0, sgv1) 278 sgv0_ = sgv0.copy() 279 sgv1_ = sgv1.copy() 280 # pylint: disable=protected-access 281 if a2b and b2a: 282 (sgv0_._input_ts, sgv1_._input_ts) = (sgv1_._input_ts, sgv0_._input_ts) 283 (sgv0_._passthrough_ts, sgv1_._passthrough_ts) = (sgv1_._passthrough_ts, 284 sgv0_._passthrough_ts) 285 elif a2b: 286 sgv1_._input_ts = sgv0_._input_ts[:] 287 sgv1_._passthrough_ts = sgv0_._passthrough_ts[:] 288 elif b2a: 289 sgv0_._input_ts = sgv1_._input_ts[:] 290 sgv0_._passthrough_ts = sgv1_._passthrough_ts[:] 291 # pylint: enable=protected-access 292 293 # Update the passthrough outputs as well. 294 def update_passthrough_outputs(a, b): 295 # pylint: disable=protected-access 296 for i, t in enumerate(b._output_ts): 297 if t in a._passthrough_ts: 298 ii = a._input_ts.index(t) 299 b._output_ts[i] = b._input_ts[ii] 300 # pylint: enable=protected-access 301 302 if a2b: 303 update_passthrough_outputs(sgv0_, sgv1_) 304 if b2a: 305 update_passthrough_outputs(sgv1_, sgv0_) 306 307 # in-place 308 # pylint: disable=protected-access 309 sgv0._assign_from(sgv0_) 310 sgv1._assign_from(sgv1_) 311 # pylint: enable=protected-access 312 313 314def _reroute_sgv_inputs(sgv0, sgv1, mode): 315 """Re-route all the inputs of two subgraphs. 316 317 Args: 318 sgv0: the first subgraph to have its inputs swapped. This argument is 319 converted to a subgraph using the same rules than the function 320 subgraph.make_view. 321 sgv1: the second subgraph to have its inputs swapped. This argument is 322 converted to a subgraph using the same rules than the function 323 subgraph.make_view. 324 mode: reroute mode, see _reroute_ts(...). 325 Returns: 326 A tuple `(sgv0, sgv1)` of subgraph views with their inputs swapped. 327 Note that the function argument sgv0 and sgv1 are also modified in place. 328 Raises: 329 StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView using 330 the same rules than the function subgraph.make_view. 331 """ 332 sgv0 = _subgraph.make_view(sgv0) 333 sgv1 = _subgraph.make_view(sgv1) 334 _util.check_graphs(sgv0, sgv1) 335 can_modify = sgv0.ops + sgv1.ops 336 # also allow consumers of passthrough to be modified: 337 can_modify += _util.get_consuming_ops(sgv0.passthroughs) 338 can_modify += _util.get_consuming_ops(sgv1.passthroughs) 339 _reroute_ts(sgv0.inputs, sgv1.inputs, mode, can_modify=can_modify) 340 _reroute_sgv_remap(sgv0, sgv1, mode) 341 return sgv0, sgv1 342 343 344def _reroute_sgv_outputs(sgv0, sgv1, mode): 345 """Re-route all the outputs of two operations. 346 347 Args: 348 sgv0: the first subgraph to have its outputs swapped. This argument is 349 converted to a subgraph using the same rules than the function 350 subgraph.make_view. 351 sgv1: the second subgraph to have its outputs swapped. This argument is 352 converted to a subgraph using the same rules than the function 353 subgraph.make_view. 354 mode: reroute mode, see _reroute_ts(...). 355 Returns: 356 A tuple `(sgv0, sgv1)` of subgraph views with their outputs swapped. 357 Note that the function argument sgv0 and sgv1 are also modified in place. 358 Raises: 359 StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView using 360 the same rules than the function subgraph.make_view. 
361 """ 362 sgv0 = _subgraph.make_view(sgv0) 363 sgv1 = _subgraph.make_view(sgv1) 364 _util.check_graphs(sgv0, sgv1) 365 cannot_modify = sgv0.ops + sgv1.ops 366 _reroute_ts(sgv0.outputs, sgv1.outputs, mode, cannot_modify=cannot_modify) 367 return sgv0, sgv1 368 369 370def _reroute_sgv(sgv0, sgv1, mode): 371 """Re-route both the inputs and the outputs of the two subgraph views. 372 373 This involves swapping all the inputs/outputs of the two subgraph views. 374 375 Args: 376 sgv0: the first subgraph to be swapped. This argument is converted to a 377 subgraph using the same rules than the function subgraph.make_view. 378 sgv1: the second subgraph to be swapped. This argument is converted to a 379 subgraph using the same rules than the function subgraph.make_view. 380 mode: reroute mode, see _reroute_ts(...). 381 Returns: 382 A tuple `(sgv0, sgv1)` of subgraph views with their outputs and inputs 383 swapped. 384 Note that the function argument sgv0 and sgv1 are also modified in place. 385 Raises: 386 StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView using 387 the same rules than the function subgraph.make_view. 388 """ 389 _reroute_sgv_outputs(sgv0, sgv1, mode) 390 _reroute_sgv_inputs(sgv0, sgv1, mode) 391 return sgv0, sgv1 392 393 394def swap_inputs(sgv0, sgv1): 395 """Swap all the inputs of sgv0 and sgv1 (see reroute_inputs).""" 396 return _reroute_sgv_inputs(sgv0, sgv1, _RerouteMode.swap) 397 398 399def reroute_inputs(sgv0, sgv1): 400 """Re-route all the inputs of two subgraphs. 401 402 Args: 403 sgv0: the first subgraph to have its inputs swapped. This argument is 404 converted to a subgraph using the same rules than the function 405 subgraph.make_view. 406 sgv1: the second subgraph to have its inputs swapped. This argument is 407 converted to a subgraph using the same rules than the function 408 subgraph.make_view. 409 Returns: 410 A tuple `(sgv0, sgv1)` of subgraph views with their inputs swapped. 411 Note that the function argument sgv0 and sgv1 are also modified in place. 412 Raises: 413 StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView using 414 the same rules than the function subgraph.make_view. 415 """ 416 return _reroute_sgv_inputs(sgv0, sgv1, _RerouteMode.a2b) 417 418 419def swap_outputs(sgv0, sgv1): 420 """Swap all the outputs of sgv0 and sgv1 (see reroute_outputs).""" 421 return _reroute_sgv_outputs(sgv0, sgv1, _RerouteMode.swap) 422 423 424def reroute_outputs(sgv0, sgv1): 425 """Re-route all the outputs of two operations. 426 427 Args: 428 sgv0: the first subgraph to have its outputs swapped. This argument is 429 converted to a subgraph using the same rules than the function 430 subgraph.make_view. 431 sgv1: the second subgraph to have its outputs swapped. This argument is 432 converted to a subgraph using the same rules than the function 433 subgraph.make_view. 434 Returns: 435 A tuple `(sgv0, sgv1)` of subgraph views with their outputs swapped. 436 Note that the function argument sgv0 and sgv1 are also modified in place. 437 Raises: 438 StandardError: if sgv0 or sgv1 cannot be converted to a SubGraphView using 439 the same rules than the function subgraph.make_view. 
440 """ 441 return _reroute_sgv_outputs(sgv0, sgv1, _RerouteMode.a2b) 442 443 444def swap_ios(sgv0, sgv1): 445 """Swap the inputs and outputs of sgv1 to sgv0 (see _reroute_sgv).""" 446 return _reroute_sgv(sgv0, sgv1, _RerouteMode.swap) 447 448 449def reroute_ios(sgv0, sgv1): 450 """Re-route the inputs and outputs of sgv0 to sgv1 (see _reroute_sgv).""" 451 return _reroute_sgv(sgv0, sgv1, _RerouteMode.a2b) 452 453 454def remove_control_inputs(op, cops): 455 """Remove the control inputs cops from co. 456 457 Warning: this function is directly manipulating the internals of the 458 `tf.Graph`. 459 460 Args: 461 op: a `tf.Operation` from which to remove the control inputs. 462 cops: an object convertible to a list of `tf.Operation`. 463 Raises: 464 TypeError: if op is not a `tf.Operation`. 465 ValueError: if any cop in cops is not a control input of op. 466 """ 467 if not isinstance(op, _tf_ops.Operation): 468 raise TypeError("Expected a tf.Operation, got: {}", type(op)) 469 cops = _util.make_list_of_op(cops, allow_graph=False) 470 for cop in cops: 471 if cop not in op.control_inputs: 472 raise ValueError("{} is not a control_input of {}".format(op.name, 473 cop.name)) 474 control_inputs = [cop for cop in op.control_inputs if cop not in cops] 475 # pylint: disable=protected-access 476 op._remove_all_control_inputs() 477 op._add_control_inputs(control_inputs) 478 # pylint: enable=protected-access 479 480 481def add_control_inputs(op, cops): 482 """Add the control inputs cops to op. 483 484 Warning: this function is directly manipulating the internals of the tf.Graph. 485 486 Args: 487 op: a tf.Operation to which the control inputs are added. 488 cops: an object convertible to a list of `tf.Operation`. 489 Raises: 490 TypeError: if op is not a tf.Operation 491 ValueError: if any cop in cops is already a control input of op. 492 """ 493 if not isinstance(op, _tf_ops.Operation): 494 raise TypeError("Expected a tf.Operation, got: {}", type(op)) 495 cops = _util.make_list_of_op(cops, allow_graph=False) 496 for cop in cops: 497 if cop in op.control_inputs: 498 raise ValueError("{} is already a control_input of {}".format(cop.name, 499 op.name)) 500 op._add_control_inputs(cops) # pylint: disable=protected-access 501 502remove_undocumented(__name__, _allowed_symbols) 503