1# Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15"""Various ways of selecting operations and tensors in a graph.""" 16 17from __future__ import absolute_import 18from __future__ import division 19from __future__ import print_function 20 21import re 22 23from six import iteritems 24from six import string_types 25 26from tensorflow.contrib.graph_editor import util 27from tensorflow.python.framework import ops as tf_ops 28 29__all__ = [ 30 "can_be_regex", 31 "make_regex", 32 "filter_ts", 33 "filter_ts_from_regex", 34 "filter_ops", 35 "filter_ops_from_regex", 36 "get_name_scope_ops", 37 "check_cios", 38 "get_ops_ios", 39 "compute_boundary_ts", 40 "get_within_boundary_ops", 41 "get_forward_walk_ops", 42 "get_backward_walk_ops", 43 "get_walks_intersection_ops", 44 "get_walks_union_ops", 45 "select_ops", 46 "select_ts", 47 "select_ops_and_ts", 48] 49 50_RE_TYPE = type(re.compile("")) 51 52 53def can_be_regex(obj): 54 """Return True if obj can be turned into a regular expression.""" 55 return isinstance(obj, string_types + (_RE_TYPE,)) 56 57 58def make_regex(obj): 59 """Return a compiled regular expression. 60 61 Args: 62 obj: a string or a regular expression. 63 Returns: 64 A compiled regular expression. 65 Raises: 66 ValueError: if obj could not be converted to a regular expression. 67 """ 68 if not can_be_regex(obj): 69 raise ValueError("Expected a string or a regex, got: {}".format(type(obj))) 70 71 if isinstance(obj, string_types): 72 return re.compile(obj) 73 else: 74 return obj 75 76 77def _get_input_ts(ops): 78 """Compute the list of unique input tensors of all the op in ops. 79 80 Args: 81 ops: an object convertible to a list of `tf.Operation`. 82 Returns: 83 The list of unique input tensors of all the op in ops. 84 Raises: 85 TypeError: if ops cannot be converted to a list of `tf.Operation`. 86 """ 87 ops = util.make_list_of_op(ops) 88 ts = [] 89 ts_set = set() 90 for op in ops: 91 for t in op.inputs: 92 if t not in ts_set: 93 ts.append(t) 94 ts_set.add(t) 95 return ts 96 97 98def _get_output_ts(ops): 99 """Compute the list of unique output tensors of all the op in ops. 100 101 Args: 102 ops: an object convertible to a list of tf.Operation. 103 Returns: 104 The list of unique output tensors of all the op in ops. 105 Raises: 106 TypeError: if ops cannot be converted to a list of tf.Operation. 107 """ 108 ops = util.make_list_of_op(ops) 109 ts = [] 110 for op in ops: 111 ts += op.outputs 112 return ts 113 114 115def filter_ts(ops, positive_filter): 116 """Get all the tensors which are input or output of an op in ops. 117 118 Args: 119 ops: an object convertible to a list of `tf.Operation`. 120 positive_filter: a function deciding whether to keep a tensor or not. 121 If `True`, all the tensors are returned. 122 Returns: 123 A list of `tf.Tensor`. 124 Raises: 125 TypeError: if ops cannot be converted to a list of `tf.Operation`. 126 """ 127 ops = util.make_list_of_op(ops) 128 ts = _get_input_ts(ops) 129 util.concatenate_unique(ts, _get_output_ts(ops)) 130 if positive_filter is not True: 131 ts = [t for t in ts if positive_filter(t)] 132 return ts 133 134 135def filter_ts_from_regex(ops, regex): 136 r"""Get all the tensors linked to ops that match the given regex. 137 138 Args: 139 ops: an object convertible to a list of tf.Operation. 140 regex: a regular expression matching the tensors' name. 141 For example, "^foo(/.*)?:\d+$" will match all the tensors in the "foo" 142 scope. 143 Returns: 144 A list of tf.Tensor. 145 Raises: 146 TypeError: if ops cannot be converted to a list of tf.Operation. 147 """ 148 ops = util.make_list_of_op(ops) 149 regex_obj = make_regex(regex) 150 return filter_ts(ops, positive_filter=lambda op: regex_obj.search(op.name)) 151 152 153def filter_ops(ops, positive_filter): 154 """Get the ops passing the given filter. 155 156 Args: 157 ops: an object convertible to a list of tf.Operation. 158 positive_filter: a function deciding where to keep an operation or not. 159 If True, all the operations are returned. 160 Returns: 161 A list of selected tf.Operation. 162 Raises: 163 TypeError: if ops cannot be converted to a list of tf.Operation. 164 """ 165 ops = util.make_list_of_op(ops) 166 if positive_filter is not True: # pylint: disable=g-explicit-bool-comparison 167 ops = [op for op in ops if positive_filter(op)] 168 return ops 169 170 171def filter_ops_from_regex(ops, regex): 172 """Get all the operations that match the given regex. 173 174 Args: 175 ops: an object convertible to a list of `tf.Operation`. 176 regex: a regular expression matching the operation's name. 177 For example, `"^foo(/.*)?$"` will match all the operations in the "foo" 178 scope. 179 Returns: 180 A list of `tf.Operation`. 181 Raises: 182 TypeError: if ops cannot be converted to a list of `tf.Operation`. 183 """ 184 ops = util.make_list_of_op(ops) 185 regex_obj = make_regex(regex) 186 return filter_ops(ops, lambda op: regex_obj.search(op.name)) 187 188 189def get_name_scope_ops(ops, scope): 190 """Get all the operations under the given scope path. 191 192 Args: 193 ops: an object convertible to a list of tf.Operation. 194 scope: a scope path. 195 Returns: 196 A list of tf.Operation. 197 Raises: 198 TypeError: if ops cannot be converted to a list of tf.Operation. 199 """ 200 if scope and scope[-1] == "/": 201 scope = scope[:-1] 202 return filter_ops_from_regex(ops, "^{}(/.*)?$".format(scope)) 203 204 205def check_cios(control_inputs=False, control_outputs=None, control_ios=None): 206 """Do various check on control_inputs and control_outputs. 207 208 Args: 209 control_inputs: A boolean indicating whether control inputs are enabled. 210 control_outputs: An instance of util.ControlOutputs or None. If not None, 211 control outputs are enabled. 212 control_ios: An instance of util.ControlOutputs or None. If not None, both 213 control inputs and control outputs are enabled. This is equivalent to set 214 control_inputs to True and control_outputs to the util.ControlOutputs 215 instance. 216 Returns: 217 A tuple `(control_inputs, control_outputs)` where: 218 `control_inputs` is a boolean indicating whether to use control inputs. 219 `control_outputs` is an instance of util.ControlOutputs or None 220 Raises: 221 ValueError: if control_inputs is an instance of util.ControlOutputs but 222 control_outputs is not None 223 TypeError: if control_outputs is not None and is not a util.ControlOutputs. 224 """ 225 if control_ios is not None: 226 if not isinstance(control_ios, util.ControlOutputs): 227 raise TypeError("Expected a util.ControlOutputs, got: {}".format( 228 type(control_ios))) 229 if control_outputs is not None: 230 raise ValueError("control_outputs should be None when using control_ios.") 231 control_inputs = True 232 control_outputs = control_ios 233 elif control_outputs is not None: 234 if not isinstance(control_outputs, util.ControlOutputs): 235 raise TypeError("Expected a util.ControlOutputs, got: {}".format( 236 type(control_outputs))) 237 238 if control_outputs is not None: 239 control_outputs.update() 240 return control_inputs, control_outputs 241 242 243def get_ops_ios(ops, control_inputs=False, control_outputs=None, 244 control_ios=None): 245 """Return all the `tf.Operation` which are connected to an op in ops. 246 247 Args: 248 ops: an object convertible to a list of `tf.Operation`. 249 control_inputs: A boolean indicating whether control inputs are enabled. 250 control_outputs: An instance of `util.ControlOutputs` or `None`. If not 251 `None`, control outputs are enabled. 252 control_ios: An instance of `util.ControlOutputs` or `None`. If not `None`, 253 both control inputs and control outputs are enabled. This is equivalent to 254 set `control_inputs` to `True` and `control_outputs` to the 255 `util.ControlOutputs` instance. 256 Returns: 257 All the `tf.Operation` surrounding the given ops. 258 Raises: 259 TypeError: if `ops` cannot be converted to a list of `tf.Operation`. 260 """ 261 control_inputs, control_outputs = check_cios(control_inputs, control_outputs, 262 control_ios) 263 ops = util.make_list_of_op(ops) 264 res = [] 265 for op in ops: 266 util.concatenate_unique(res, [t.op for t in op.inputs]) 267 for t in op.outputs: 268 util.concatenate_unique(res, t.consumers()) 269 if control_outputs is not None: 270 util.concatenate_unique(res, control_outputs.get(op)) 271 if control_inputs: 272 util.concatenate_unique(res, op.control_inputs) 273 return res 274 275 276def compute_boundary_ts(ops): 277 """Compute the tensors at the boundary of a set of ops. 278 279 This function looks at all the tensors connected to the given ops (in/out) 280 and classify them into three categories: 281 1) input tensors: tensors whose generating operation is not in ops. 282 2) output tensors: tensors whose consumer operations are not in ops 283 3) inside tensors: tensors which are neither input nor output tensors. 284 285 Note that a tensor can be both an inside tensor and an output tensor if it is 286 consumed by operations both outside and inside of `ops`. 287 288 Args: 289 ops: an object convertible to a list of tf.Operation. 290 Returns: 291 A tuple `(outside_input_ts, outside_output_ts, inside_ts)` where: 292 `outside_input_ts` is a Python list of input tensors; 293 `outside_output_ts` is a python list of output tensors; 294 `inside_ts` is a python list of inside tensors. 295 Since a tensor can be both an inside tensor and an output tensor, 296 `outside_output_ts` and `inside_ts` might intersect. 297 Raises: 298 TypeError: if ops cannot be converted to a list of tf.Operation. 299 """ 300 ops = util.make_list_of_op(ops) 301 input_ts = _get_input_ts(ops) 302 output_ts = _get_output_ts(ops) 303 output_ts_set = frozenset(output_ts) 304 ops_set = frozenset(ops) 305 306 # Compute inside tensors. 307 inside_ts = [] 308 only_inside_ts = [] 309 for t in input_ts: 310 # Skip if the input tensor is not also an output tensor. 311 if t not in output_ts_set: 312 continue 313 # Mark as "inside". 314 inside_ts.append(t) 315 # Mark as "only inside" if the tensor is not both inside and output. 316 consumers = frozenset(t.consumers()) 317 if consumers - ops_set: 318 continue 319 only_inside_ts.append(t) 320 321 inside_ts_set = frozenset(inside_ts) 322 only_inside_ts_set = frozenset(only_inside_ts) 323 outside_output_ts = [t for t in output_ts if t not in only_inside_ts_set] 324 outside_input_ts = [t for t in input_ts if t not in inside_ts_set] 325 return outside_input_ts, outside_output_ts, inside_ts 326 327 328def get_within_boundary_ops(ops, 329 seed_ops, 330 boundary_ops=(), 331 inclusive=True, 332 control_inputs=False, 333 control_outputs=None, 334 control_ios=None): 335 """Return all the `tf.Operation` within the given boundary. 336 337 Args: 338 ops: an object convertible to a list of `tf.Operation`. those ops define the 339 set in which to perform the operation (if a `tf.Graph` is given, it 340 will be converted to the list of all its operations). 341 seed_ops: the operations from which to start expanding. 342 boundary_ops: the ops forming the boundary. 343 inclusive: if `True`, the result will also include the boundary ops. 344 control_inputs: A boolean indicating whether control inputs are enabled. 345 control_outputs: An instance of `util.ControlOutputs` or `None`. If not 346 `None`, control outputs are enabled. 347 control_ios: An instance of `util.ControlOutputs` or `None`. If not 348 `None`, both control inputs and control outputs are enabled. This is 349 equivalent to set control_inputs to True and control_outputs to 350 the `util.ControlOutputs` instance. 351 Returns: 352 All the `tf.Operation` surrounding the given ops. 353 Raises: 354 TypeError: if `ops` or `seed_ops` cannot be converted to a list of 355 `tf.Operation`. 356 ValueError: if the boundary is intersecting with the seeds. 357 """ 358 control_inputs, control_outputs = check_cios(control_inputs, control_outputs, 359 control_ios) 360 ops = util.make_list_of_op(ops) 361 seed_ops = util.make_list_of_op(seed_ops, allow_graph=False) 362 boundary_ops = set(util.make_list_of_op(boundary_ops)) 363 res = set(seed_ops) 364 if boundary_ops & res: 365 raise ValueError("Boundary is intersecting with the seeds.") 366 wave = set(seed_ops) 367 while wave: 368 new_wave = set() 369 ops_io = get_ops_ios(wave, control_inputs, control_outputs) 370 for op in ops_io: 371 if op in res: 372 continue 373 if op in boundary_ops: 374 if inclusive: 375 res.add(op) 376 else: 377 new_wave.add(op) 378 res.update(new_wave) 379 wave = new_wave 380 return [op for op in ops if op in res] 381 382 383def get_forward_walk_ops(seed_ops, 384 inclusive=True, 385 within_ops=None, 386 within_ops_fn=None, 387 stop_at_ts=(), 388 control_outputs=None): 389 """Do a forward graph walk and return all the visited ops. 390 391 Args: 392 seed_ops: an iterable of operations from which the forward graph 393 walk starts. If a list of tensors is given instead, the seed_ops are set 394 to be the consumers of those tensors. 395 inclusive: if True the given seed_ops are also part of the resulting set. 396 within_ops: an iterable of `tf.Operation` within which the search is 397 restricted. If `within_ops` is `None`, the search is performed within 398 the whole graph. 399 within_ops_fn: if provided, a function on ops that should return True iff 400 the op is within the graph traversal. This can be used along within_ops, 401 in which case an op is within if it is also in within_ops. 402 stop_at_ts: an iterable of tensors at which the graph walk stops. 403 control_outputs: a `util.ControlOutputs` instance or None. 404 If not `None`, it will be used while walking the graph forward. 405 Returns: 406 A Python set of all the `tf.Operation` ahead of `seed_ops`. 407 Raises: 408 TypeError: if `seed_ops` or `within_ops` cannot be converted to a list of 409 `tf.Operation`. 410 """ 411 _, control_outputs = check_cios(False, control_outputs) 412 if not util.is_iterable(seed_ops): 413 seed_ops = [seed_ops] 414 if not seed_ops: 415 return [] 416 if isinstance(seed_ops[0], tf_ops.Tensor): 417 ts = util.make_list_of_t(seed_ops, allow_graph=False) 418 seed_ops = util.get_consuming_ops(ts) 419 else: 420 seed_ops = util.make_list_of_op(seed_ops, allow_graph=False) 421 422 seed_ops = frozenset(seed_ops) 423 stop_at_ts = frozenset(util.make_list_of_t(stop_at_ts)) 424 if within_ops: 425 within_ops = util.make_list_of_op(within_ops, allow_graph=False) 426 within_ops = frozenset(within_ops) 427 seed_ops &= within_ops 428 429 def is_within(op): 430 return (within_ops is None or op in within_ops) and ( 431 within_ops_fn is None or within_ops_fn(op)) 432 433 result = list(seed_ops) 434 wave = set(seed_ops) 435 while wave: 436 new_wave = set() 437 for op in wave: 438 for new_t in op.outputs: 439 if new_t in stop_at_ts: 440 continue 441 for new_op in new_t.consumers(): 442 if new_op not in result and is_within(new_op): 443 new_wave.add(new_op) 444 if control_outputs is not None: 445 for new_op in control_outputs.get(op): 446 if new_op not in result and is_within(new_op): 447 new_wave.add(new_op) 448 util.concatenate_unique(result, new_wave) 449 wave = new_wave 450 if not inclusive: 451 result = [op for op in result if op not in seed_ops] 452 return result 453 454 455def get_backward_walk_ops(seed_ops, 456 inclusive=True, 457 within_ops=None, 458 within_ops_fn=None, 459 stop_at_ts=(), 460 control_inputs=False): 461 """Do a backward graph walk and return all the visited ops. 462 463 Args: 464 seed_ops: an iterable of operations from which the backward graph 465 walk starts. If a list of tensors is given instead, the seed_ops are set 466 to be the generators of those tensors. 467 inclusive: if True the given seed_ops are also part of the resulting set. 468 within_ops: an iterable of `tf.Operation` within which the search is 469 restricted. If `within_ops` is `None`, the search is performed within 470 the whole graph. 471 within_ops_fn: if provided, a function on ops that should return True iff 472 the op is within the graph traversal. This can be used along within_ops, 473 in which case an op is within if it is also in within_ops. 474 stop_at_ts: an iterable of tensors at which the graph walk stops. 475 control_inputs: if True, control inputs will be used while moving backward. 476 Returns: 477 A Python set of all the `tf.Operation` behind `seed_ops`. 478 Raises: 479 TypeError: if `seed_ops` or `within_ops` cannot be converted to a list of 480 `tf.Operation`. 481 """ 482 if not util.is_iterable(seed_ops): 483 seed_ops = [seed_ops] 484 if not seed_ops: 485 return [] 486 if isinstance(seed_ops[0], tf_ops.Tensor): 487 ts = util.make_list_of_t(seed_ops, allow_graph=False) 488 seed_ops = util.get_generating_ops(ts) 489 else: 490 seed_ops = util.make_list_of_op(seed_ops, allow_graph=False) 491 492 stop_at_ts = frozenset(util.make_list_of_t(stop_at_ts)) 493 seed_ops = frozenset(util.make_list_of_op(seed_ops)) 494 if within_ops: 495 within_ops = util.make_list_of_op(within_ops, allow_graph=False) 496 within_ops = frozenset(within_ops) 497 seed_ops &= within_ops 498 499 def is_within(op): 500 return (within_ops is None or op in within_ops) and ( 501 within_ops_fn is None or within_ops_fn(op)) 502 503 result = list(seed_ops) 504 wave = set(seed_ops) 505 while wave: 506 new_wave = set() 507 for op in wave: 508 for new_t in op.inputs: 509 if new_t in stop_at_ts: 510 continue 511 if new_t.op not in result and is_within(new_t.op): 512 new_wave.add(new_t.op) 513 if control_inputs: 514 for new_op in op.control_inputs: 515 if new_op not in result and is_within(new_op): 516 new_wave.add(new_op) 517 util.concatenate_unique(result, new_wave) 518 wave = new_wave 519 if not inclusive: 520 result = [op for op in result if op not in seed_ops] 521 return result 522 523 524def get_walks_intersection_ops(forward_seed_ops, 525 backward_seed_ops, 526 forward_inclusive=True, 527 backward_inclusive=True, 528 within_ops=None, 529 within_ops_fn=None, 530 control_inputs=False, 531 control_outputs=None, 532 control_ios=None): 533 """Return the intersection of a forward and a backward walk. 534 535 Args: 536 forward_seed_ops: an iterable of operations from which the forward graph 537 walk starts. If a list of tensors is given instead, the seed_ops are set 538 to be the consumers of those tensors. 539 backward_seed_ops: an iterable of operations from which the backward graph 540 walk starts. If a list of tensors is given instead, the seed_ops are set 541 to be the generators of those tensors. 542 forward_inclusive: if True the given forward_seed_ops are also part of the 543 resulting set. 544 backward_inclusive: if True the given backward_seed_ops are also part of the 545 resulting set. 546 within_ops: an iterable of tf.Operation within which the search is 547 restricted. If within_ops is None, the search is performed within 548 the whole graph. 549 within_ops_fn: if provided, a function on ops that should return True iff 550 the op is within the graph traversal. This can be used along within_ops, 551 in which case an op is within if it is also in within_ops. 552 control_inputs: A boolean indicating whether control inputs are enabled. 553 control_outputs: An instance of util.ControlOutputs or None. If not None, 554 control outputs are enabled. 555 control_ios: An instance of util.ControlOutputs or None. If not None, both 556 control inputs and control outputs are enabled. This is equivalent to set 557 control_inputs to True and control_outputs to the util.ControlOutputs 558 instance. 559 Returns: 560 A Python set of all the tf.Operation in the intersection of a forward and a 561 backward walk. 562 Raises: 563 TypeError: if `forward_seed_ops` or `backward_seed_ops` or `within_ops` 564 cannot be converted to a list of `tf.Operation`. 565 """ 566 control_inputs, control_outputs = check_cios(control_inputs, control_outputs, 567 control_ios) 568 forward_ops = get_forward_walk_ops( 569 forward_seed_ops, 570 inclusive=forward_inclusive, 571 within_ops=within_ops, 572 within_ops_fn=within_ops_fn, 573 control_outputs=control_outputs) 574 backward_ops = get_backward_walk_ops( 575 backward_seed_ops, 576 inclusive=backward_inclusive, 577 within_ops=within_ops, 578 within_ops_fn=within_ops_fn, 579 control_inputs=control_inputs) 580 return [op for op in forward_ops if op in backward_ops] 581 582 583def get_walks_union_ops(forward_seed_ops, 584 backward_seed_ops, 585 forward_inclusive=True, 586 backward_inclusive=True, 587 within_ops=None, 588 within_ops_fn=None, 589 control_inputs=False, 590 control_outputs=None, 591 control_ios=None): 592 """Return the union of a forward and a backward walk. 593 594 Args: 595 forward_seed_ops: an iterable of operations from which the forward graph 596 walk starts. If a list of tensors is given instead, the seed_ops are set 597 to be the consumers of those tensors. 598 backward_seed_ops: an iterable of operations from which the backward graph 599 walk starts. If a list of tensors is given instead, the seed_ops are set 600 to be the generators of those tensors. 601 forward_inclusive: if True the given forward_seed_ops are also part of the 602 resulting set. 603 backward_inclusive: if True the given backward_seed_ops are also part of the 604 resulting set. 605 within_ops: restrict the search within those operations. If within_ops is 606 None, the search is done within the whole graph. 607 within_ops_fn: if provided, a function on ops that should return True iff 608 the op is within the graph traversal. This can be used along within_ops, 609 in which case an op is within if it is also in within_ops. 610 control_inputs: A boolean indicating whether control inputs are enabled. 611 control_outputs: An instance of util.ControlOutputs or None. If not None, 612 control outputs are enabled. 613 control_ios: An instance of util.ControlOutputs or None. If not None, both 614 control inputs and control outputs are enabled. This is equivalent to set 615 control_inputs to True and control_outputs to the util.ControlOutputs 616 instance. 617 Returns: 618 A Python set of all the tf.Operation in the union of a forward and a 619 backward walk. 620 Raises: 621 TypeError: if forward_seed_ops or backward_seed_ops or within_ops cannot be 622 converted to a list of tf.Operation. 623 """ 624 control_inputs, control_outputs = check_cios(control_inputs, control_outputs, 625 control_ios) 626 forward_ops = get_forward_walk_ops( 627 forward_seed_ops, 628 inclusive=forward_inclusive, 629 within_ops=within_ops, 630 within_ops_fn=within_ops_fn, 631 control_outputs=control_outputs) 632 backward_ops = get_backward_walk_ops( 633 backward_seed_ops, 634 inclusive=backward_inclusive, 635 within_ops=within_ops, 636 within_ops_fn=within_ops_fn, 637 control_inputs=control_inputs) 638 return util.concatenate_unique(forward_ops, backward_ops) 639 640 641def select_ops(*args, **kwargs): 642 """Helper to select operations. 643 644 Args: 645 *args: list of 1) regular expressions (compiled or not) or 2) (array of) 646 `tf.Operation`. `tf.Tensor` instances are silently ignored. 647 **kwargs: 'graph': `tf.Graph` in which to perform the regex query.This is 648 required when using regex. 649 'positive_filter': an elem if selected only if `positive_filter(elem)` is 650 `True`. This is optional. 651 'restrict_ops_regex': a regular expression is ignored if it doesn't start 652 with the substring "(?#ops)". 653 Returns: 654 A list of `tf.Operation`. 655 Raises: 656 TypeError: if the optional keyword argument graph is not a `tf.Graph` 657 or if an argument in args is not an (array of) `tf.Operation` 658 or an (array of) `tf.Tensor` (silently ignored) or a string 659 or a regular expression. 660 ValueError: if one of the keyword arguments is unexpected or if a regular 661 expression is used without passing a graph as a keyword argument. 662 """ 663 # get keywords arguments 664 graph = None 665 positive_filter = None 666 restrict_ops_regex = False 667 for k, v in iteritems(kwargs): 668 if k == "graph": 669 graph = v 670 if graph is not None and not isinstance(graph, tf_ops.Graph): 671 raise TypeError("Expected a tf.Graph, got: {}".format(type(graph))) 672 elif k == "positive_filter": 673 positive_filter = v 674 elif k == "restrict_ops_regex": 675 restrict_ops_regex = v 676 elif k == "restrict_ts_regex": 677 pass 678 else: 679 raise ValueError("Wrong keywords argument: {}.".format(k)) 680 681 ops = [] 682 683 for arg in args: 684 if can_be_regex(arg): 685 if graph is None: 686 raise ValueError("Use the keyword argument 'graph' to use regex.") 687 regex = make_regex(arg) 688 if regex.pattern.startswith("(?#ts)"): 689 continue 690 if restrict_ops_regex and not regex.pattern.startswith("(?#ops)"): 691 continue 692 ops_ = filter_ops_from_regex(graph, regex) 693 for op_ in ops_: 694 if op_ not in ops: 695 if positive_filter is None or positive_filter(op_): 696 ops.append(op_) 697 else: 698 ops_aux = util.make_list_of_op(arg, ignore_ts=True) 699 if positive_filter is not None: 700 ops_aux = [op for op in ops_aux if positive_filter(op)] 701 ops_aux = [op for op in ops_aux if op not in ops] 702 ops += ops_aux 703 704 return ops 705 706 707def select_ts(*args, **kwargs): 708 """Helper to select tensors. 709 710 Args: 711 *args: list of 1) regular expressions (compiled or not) or 2) (array of) 712 `tf.Tensor`. `tf.Operation` instances are silently ignored. 713 **kwargs: 'graph': `tf.Graph` in which to perform the regex query.This is 714 required when using regex. 715 'positive_filter': an elem if selected only if `positive_filter(elem)` is 716 `True`. This is optional. 717 'restrict_ts_regex': a regular expression is ignored if it doesn't start 718 with the substring "(?#ts)". 719 Returns: 720 A list of `tf.Tensor`. 721 Raises: 722 TypeError: if the optional keyword argument graph is not a `tf.Graph` 723 or if an argument in args is not an (array of) `tf.Tensor` 724 or an (array of) `tf.Operation` (silently ignored) or a string 725 or a regular expression. 726 ValueError: if one of the keyword arguments is unexpected or if a regular 727 expression is used without passing a graph as a keyword argument. 728 """ 729 # get keywords arguments 730 graph = None 731 positive_filter = None 732 restrict_ts_regex = False 733 for k, v in iteritems(kwargs): 734 if k == "graph": 735 graph = v 736 if graph is not None and not isinstance(graph, tf_ops.Graph): 737 raise TypeError("Expected a tf.Graph, got {}".format(type(graph))) 738 elif k == "positive_filter": 739 positive_filter = v 740 elif k == "restrict_ts_regex": 741 restrict_ts_regex = v 742 elif k == "restrict_ops_regex": 743 pass 744 else: 745 raise ValueError("Wrong keywords argument: {}.".format(k)) 746 747 ts = [] 748 749 for arg in args: 750 if can_be_regex(arg): 751 if graph is None: 752 raise ValueError("Use the keyword argument 'graph' to use regex.") 753 regex = make_regex(arg) 754 if regex.pattern.startswith("(?#ops)"): 755 continue 756 if restrict_ts_regex and not regex.pattern.startswith("(?#ts)"): 757 continue 758 ts_ = filter_ts_from_regex(graph, regex) 759 for t_ in ts_: 760 if t_ not in ts: 761 if positive_filter is None or positive_filter(t_): 762 ts.append(t_) 763 else: 764 ts_aux = util.make_list_of_t(arg, ignore_ops=True) 765 if positive_filter is not None: 766 ts_aux = [t for t in ts_aux if positive_filter(t)] 767 ts_aux = [t for t in ts_aux if t not in ts] 768 ts += ts_aux 769 770 return ts 771 772 773def select_ops_and_ts(*args, **kwargs): 774 """Helper to select operations and tensors. 775 776 Args: 777 *args: list of 1) regular expressions (compiled or not) or 2) (array of) 778 `tf.Operation` 3) (array of) tf.Tensor. Regular expressions matching 779 tensors must start with the comment `"(?#ts)"`, for instance: 780 `"(?#ts)^foo/.*"`. 781 **kwargs: 'graph': `tf.Graph` in which to perform the regex query.This is 782 required when using regex. 783 'positive_filter': an elem if selected only if `positive_filter(elem)` is 784 `True`. This is optional. 785 Returns: 786 A tuple `(ops, ts)` where: 787 `ops` is a list of `tf.Operation`, and 788 `ts` is a list of `tf.Tensor` 789 Raises: 790 TypeError: if the optional keyword argument graph is not a `tf.Graph` 791 or if an argument in args is not an (array of) `tf.Tensor` 792 or an (array of) `tf.Operation` or a string or a regular expression. 793 ValueError: if one of the keyword arguments is unexpected or if a regular 794 expression is used without passing a graph as a keyword argument. 795 """ 796 ops = select_ops(*args, restrict_ops_regex=False, **kwargs) 797 ts = select_ts(*args, restrict_ts_regex=True, **kwargs) 798 return ops, ts 799