# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tools for selecting ops in a graph."""

from tensorflow.python.framework import ops
from tensorflow.python.util import object_identity


def is_differentiable(op):
  """Return True if a gradient function is registered for `op`'s type."""
  try:
    return ops._gradient_registry.lookup(op.op_def.name) is not None  # pylint: disable=protected-access
  except LookupError:
    return False


def is_iterable(obj):
  """Return true if the object is iterable."""
  if isinstance(obj, ops.Tensor):
    return False
  try:
    _ = iter(obj)
  except Exception:  # pylint: disable=broad-except
    return False
  return True


def concatenate_unique(la, lb):
  """Add all the elements of `lb` to `la` if they are not there already.

  The elements added to `la` maintain ordering with respect to `lb`.

  Args:
    la: List of Python objects.
    lb: List of Python objects.
  Returns:
    `la`: The list `la`, with the missing elements of `lb` appended in order.
  """
  la_set = set(la)
  for l in lb:
    if l not in la_set:
      la.append(l)
      la_set.add(l)
  return la
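
# Illustrative sketch of `concatenate_unique` (hypothetical values): the first
# argument is mutated and returned, with only the elements of `lb` that are
# not already present appended, in `lb` order.
#
#   la = ["a", "b"]
#   concatenate_unique(la, ["b", "c", "a", "d"])  # la is now ["a", "b", "c", "d"]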


def get_tensors(graph):
  """Get all the tensors which are input or output of an op in the graph.

  Args:
    graph: a `tf.Graph`.
  Returns:
    A list of `tf.Tensor`.
  Raises:
    TypeError: if graph is not a `tf.Graph`.
  """
  if not isinstance(graph, ops.Graph):
    raise TypeError("Expected a graph, got: {}".format(type(graph)))
  ts = []
  for op in graph.get_operations():
    ts += op.outputs
  return ts
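
# Usage sketch for `get_tensors` (a hypothetical graph built under
# `tf.Graph().as_default()`):
#
#   g = tf.Graph()
#   with g.as_default():
#     a = tf.constant(1.0, name="a")
#     b = tf.constant(2.0, name="b")
#     c = a + b
#   get_tensors(g)  # -> [a, b, c]: one output tensor per op in `g`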


def get_unique_graph(tops, check_types=None, none_if_empty=False):
  """Return the unique graph used by all the elements in tops.

  Args:
    tops: iterable of elements to check (usually a list of tf.Operation and/or
      tf.Tensor). Or a tf.Graph.
    check_types: check that the elements in tops are of given type(s). If None,
      the types (tf.Operation, tf.Tensor) are used.
    none_if_empty: don't raise an error if tops is an empty list, just return
      None.
  Returns:
    The unique graph used by all the tops.
  Raises:
    TypeError: if tops is not an iterable of tf.Operation.
    ValueError: if the graph is not unique.
  """
  if isinstance(tops, ops.Graph):
    return tops
  if not is_iterable(tops):
    raise TypeError("{} is not iterable".format(type(tops)))
  if check_types is None:
    check_types = (ops.Operation, ops.Tensor)
  elif not is_iterable(check_types):
    check_types = (check_types,)
  g = None
  for op in tops:
    if not isinstance(op, check_types):
      raise TypeError("Expected a type in ({}), got: {}".format(
          ", ".join(str(t) for t in check_types), type(op)))
    if g is None:
      g = op.graph
    elif g._graph_key != op.graph._graph_key:  # pylint: disable=protected-access
      raise ValueError("Operation {} does not belong to given graph".format(op))
  if g is None and not none_if_empty:
    raise ValueError("Can't find the unique graph of an empty list")
  return g
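
# Usage sketch for `get_unique_graph` (hypothetical tensors `a` and `b` built
# in the same graph):
#
#   get_unique_graph([a, b])                   # -> the common tf.Graph
#   get_unique_graph([a.op, b])                # ops and tensors may be mixed
#   get_unique_graph([], none_if_empty=True)   # -> None instead of raising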


def check_graphs(*args):
  """Check that all the elements in args belong to the same graph.

  Args:
    *args: a list of objects with an obj.graph property.
  Raises:
    ValueError: if all the elements do not belong to the same graph.
  """
  graph = None
  for i, sgv in enumerate(args):
    if graph is None and sgv.graph is not None:
      graph = sgv.graph
    elif sgv.graph is not None and sgv.graph is not graph:
      raise ValueError(f"args[{i}] does not belong to the same graph as "
                       "other arguments.")
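
# Usage sketch for `check_graphs` (hypothetical objects exposing a `.graph`
# attribute, e.g. tensors or operations):
#
#   check_graphs(a, b, c.op)   # passes silently if all share one graph,
#                              # raises ValueError otherwise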


def make_list_of_t(ts, check_graph=True, allow_graph=True, ignore_ops=False):
  """Convert `ts` to a list of `tf.Tensor`.

  Args:
    ts: can be an iterable of `tf.Tensor`, a `tf.Graph` or a single tensor.
    check_graph: if `True`, check that all the tensors belong to the same graph.
    allow_graph: if `False`, a `tf.Graph` cannot be converted.
    ignore_ops: if `True`, silently ignore `tf.Operation`.
  Returns:
    A newly created list of `tf.Tensor`.
  Raises:
    TypeError: if `ts` cannot be converted to a list of `tf.Tensor` or,
      if `check_graph` is `True`, if all the elements do not belong to the
      same graph.
  """
  if isinstance(ts, ops.Graph):
    if allow_graph:
      return get_tensors(ts)
    else:
      raise TypeError("allow_graph is False: cannot convert a tf.Graph.")
  else:
    if not is_iterable(ts):
      ts = [ts]
    if not ts:
      return []
    if check_graph:
      check_types = None if ignore_ops else ops.Tensor
      get_unique_graph(ts, check_types=check_types)
    return [t for t in ts if isinstance(t, ops.Tensor)]
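
# Usage sketch for `make_list_of_t` (hypothetical tensor `t` and graph `g`):
#
#   make_list_of_t(t)                           # -> [t]   (single tensor)
#   make_list_of_t([t, t.op], ignore_ops=True)  # -> [t]   (operations dropped)
#   make_list_of_t(g)                           # -> every output tensor in `g`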


def get_generating_ops(ts):
  """Return all the generating ops of the tensors in `ts`.

  Args:
    ts: a list of `tf.Tensor`
  Returns:
    A list of all the generating `tf.Operation` of the tensors in `ts`.
  Raises:
    TypeError: if `ts` cannot be converted to a list of `tf.Tensor`.
  """
  ts = make_list_of_t(ts, allow_graph=False)
  return [t.op for t in ts]
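
# Usage sketch for `get_generating_ops` (hypothetical tensors `a` and `b`):
#
#   get_generating_ops([a, b])   # -> [a.op, b.op], the ops producing them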


def get_consuming_ops(ts):
  """Return all the consuming ops of the tensors in `ts`.

  Args:
    ts: a list of `tf.Tensor`
  Returns:
    A list of all the consuming `tf.Operation` of the tensors in `ts`.
  Raises:
    TypeError: if `ts` cannot be converted to a list of `tf.Tensor`.
  """
  ts = make_list_of_t(ts, allow_graph=False)
  tops = []
  for t in ts:
    for op in t.consumers():
      if op not in tops:
        tops.append(op)
  return tops
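
# Usage sketch for `get_consuming_ops` (hypothetical tensor `a` feeding two
# downstream ops):
#
#   c = a + b
#   d = a * b
#   get_consuming_ops([a])   # -> [c.op, d.op], each consumer listed once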


def make_list_of_op(tops, check_graph=True, allow_graph=True, ignore_ts=False):
  """Convert `tops` to a list of `tf.Operation`.

  Args:
    tops: can be an iterable of `tf.Operation`, a `tf.Graph` or a single
      operation.
    check_graph: if `True`, check that all the operations belong to the same
      graph.
    allow_graph: if `False`, a `tf.Graph` cannot be converted.
    ignore_ts: if `True`, silently ignore `tf.Tensor`.
  Returns:
    A newly created list of `tf.Operation`.
  Raises:
    TypeError: if `tops` cannot be converted to a list of `tf.Operation` or,
      if `check_graph` is `True`, if all the ops do not belong to the
      same graph.
  """
  if isinstance(tops, ops.Graph):
    if allow_graph:
      return tops.get_operations()
    else:
      raise TypeError("allow_graph is False: cannot convert a tf.Graph.")
  else:
    if not is_iterable(tops):
      tops = [tops]
    if not tops:
      return []
    if check_graph:
      check_types = None if ignore_ts else ops.Operation
      get_unique_graph(tops, check_types=check_types)
    return [op for op in tops if isinstance(op, ops.Operation)]
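
# Usage sketch for `make_list_of_op` (hypothetical op `op`, tensor `t`, and
# graph `g`):
#
#   make_list_of_op(op)                       # -> [op]   (single operation)
#   make_list_of_op([op, t], ignore_ts=True)  # -> [op]   (tensors dropped)
#   make_list_of_op(g)                        # -> g.get_operations()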


def _get_inputs(op, only_differentiable):
  """Return the inputs of `op`, filtered for differentiability if requested."""
  op_inputs = op.inputs
  if only_differentiable:
    return op_inputs if is_differentiable(op) else []
  else:
    return op_inputs


def get_backward_walk_ops(seed_ops,
                          inclusive=True,
                          within_ops=None,
                          within_ops_fn=None,
                          stop_at_ts=(),
                          control_inputs=False,
                          only_differentiable=False):
  """Do a backward graph walk and return all the visited ops.

  Args:
    seed_ops: an iterable of operations from which the backward graph
      walk starts. If a list of tensors is given instead, the seed_ops are set
      to be the generators of those tensors.
    inclusive: if True, the given seed_ops are also part of the result.
    within_ops: an iterable of `tf.Operation` within which the search is
      restricted. If `within_ops` is `None`, the search is performed within
      the whole graph.
    within_ops_fn: if provided, a function on ops that should return True iff
      the op is within the graph traversal. This can be used together with
      within_ops, in which case an op is within if it is also in within_ops.
    stop_at_ts: an iterable of tensors at which the graph walk stops.
    control_inputs: if True, control inputs will be used while moving backward.
    only_differentiable: if True, only traverse ops which are differentiable.
      This includes natively differentiable ops, or ops with custom gradients.
  Returns:
    A Python list of all the `tf.Operation` behind `seed_ops`.
  Raises:
    TypeError: if `seed_ops` or `within_ops` cannot be converted to a list of
      `tf.Operation`.
  """
  control_inputs = control_inputs and (not only_differentiable)

  if not is_iterable(seed_ops):
    seed_ops = [seed_ops]

  try:
    first_seed_op = next(iter(seed_ops))
  except StopIteration:
    # Empty iterable.
    return []

  if isinstance(first_seed_op, ops.Tensor):
    ts = make_list_of_t(seed_ops, allow_graph=False)
    seed_ops = get_generating_ops(ts)
  else:
    seed_ops = make_list_of_op(seed_ops, allow_graph=False)

  stop_at_ts = object_identity.ObjectIdentitySet(make_list_of_t(stop_at_ts))
  seed_ops = object_identity.ObjectIdentitySet(make_list_of_op(seed_ops))
  if within_ops:
    within_ops = make_list_of_op(within_ops, allow_graph=False)
    within_ops = object_identity.ObjectIdentitySet(within_ops)
    seed_ops &= within_ops

  def is_within(op):
    return (within_ops is None or op in within_ops) and (
        within_ops_fn is None or within_ops_fn(op))

  result = list(seed_ops)
  wave = set(seed_ops)
  while wave:
    new_wave = set()
    for op in wave:
      for new_t in _get_inputs(op, only_differentiable=only_differentiable):
        if new_t in stop_at_ts:
          continue
        if new_t.op not in result and is_within(new_t.op):
          new_wave.add(new_t.op)
      if control_inputs:
        for new_op in op.control_inputs:
          if new_op not in result and is_within(new_op):
            new_wave.add(new_op)
    concatenate_unique(result, new_wave)
    wave = new_wave
  if not inclusive:
    result = [op for op in result if op not in seed_ops]
  return result
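
# Usage sketch for `get_backward_walk_ops` (hypothetical tensors built under
# `tf.Graph().as_default()`):
#
#   x = tf.compat.v1.placeholder(tf.float32, name="x")
#   y = x * 2.0
#   z = y + 1.0
#   get_backward_walk_ops([z])                  # ops for z, y, x and the constants
#   get_backward_walk_ops([z], stop_at_ts=[y])  # the walk stops at `y`, so its
#                                               # producer and everything behind
#                                               # it are skipped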


class UnliftableError(Exception):
  """Raised if a Tensor cannot be lifted from the graph."""

  # Prevent autograph from rewriting this error.
  ag_pass_through = True


def _as_operation(op_or_tensor):
  """Return `op_or_tensor` as a `tf.Operation` (a tensor's producing op)."""
  if isinstance(op_or_tensor, ops.Tensor):
    return op_or_tensor.op
  return op_or_tensor


def graph_inputs(op):
  """Return the ops feeding `op`, including its control inputs."""
  return [x.op for x in op.inputs] + list(op.control_inputs)


def show_path(from_op, tensors, sources):
  """Find one path from `from_op` to any of `tensors`, ignoring `sources`.

  Args:
    from_op: A `tf.Operation`.
    tensors: A `tf.Operation`, a `tf.Tensor`, or a list thereof.
    sources: A list of `tf.Tensor`.

  Returns:
    A Python string containing the path, or "??" if none is found.
  """
  if isinstance(from_op, ops.Tensor):
    from_op = from_op.op

  if not isinstance(tensors, list):
    tensors = [tensors]

  final_ops = [_as_operation(tensor) for tensor in tensors]

  visited_ops = set(x.op for x in sources)
  ops_to_visit = list(final_ops)
  some_op_output = {}
  while ops_to_visit:
    op = ops_to_visit.pop()
    if op in visited_ops:
      continue
    visited_ops.add(op)
    if op == from_op:
      path_op = op
      path = [path_op]
      while path_op not in final_ops:
        path_op = some_op_output[path_op]
        path.append(path_op)
      return " <- ".join("%s (%s)" % (x.name, x.type) for x in reversed(path))
    else:
      for inp in graph_inputs(op):
        if inp not in visited_ops and inp not in sources:
          some_op_output[inp] = op
          ops_to_visit.append(inp)
  return "??"
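
# Usage sketch for `show_path` (hypothetical ops; assumes `z` depends on `x`
# through a multiply and an add):
#
#   show_path(x.op, z, sources=[])
#   # -> e.g. "add (AddV2) <- mul (Mul) <- x (Placeholder)", one path from
#   #    `z`'s op back to `x.op`, or "??" if no path is found.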


# TODO(jmenick) - there is considerable duplication of functionality between
# this function and get_backward_walk_ops(). Need to deduplicate.
def map_subgraph(init_tensor, sources, disallowed_placeholders, visited_ops,
                 op_outputs, add_sources):
  """Walk a Graph and capture the subgraph between init_tensor and sources.

  Note: This function mutates visited_ops and op_outputs.

  Args:
    init_tensor: A Tensor or Operation where the subgraph terminates.
    sources: A set of Tensors where subgraph extraction should stop.
    disallowed_placeholders: An optional set of ops which may not appear in the
      lifted graph. Defaults to all placeholders.
    visited_ops: A set of operations which were visited in a prior pass.
    op_outputs: A defaultdict containing the outputs of an op which are to be
      copied into the new subgraph.
    add_sources: A boolean indicating whether placeholders which are not in
      sources should be allowed.

  Returns:
    The set of placeholders upon which init_tensor depends and are not in
    sources.

  Raises:
    UnliftableError: if init_tensor depends on a placeholder which is not in
      sources and add_sources is False.
  """
  ops_to_visit = [_as_operation(init_tensor)]
  extra_sources = object_identity.ObjectIdentitySet()
  while ops_to_visit:
    op = ops_to_visit.pop()
    if op in visited_ops:
      continue
    visited_ops.add(op)

    should_raise = False
    if disallowed_placeholders is not None and op in disallowed_placeholders:
      should_raise = True
    elif op.type == "Placeholder":
      if disallowed_placeholders is None and not add_sources:
        should_raise = True
      extra_sources.update(op.outputs)

    if should_raise:
      raise UnliftableError(
          "Unable to lift tensor %s because it depends transitively on "
          "placeholder %s via at least one path, e.g.: %s" %
          (repr(init_tensor), repr(op), show_path(op, init_tensor, sources)))
    for inp in graph_inputs(op):
      op_outputs[inp].add(op)
      if inp not in visited_ops and inp not in (sources or extra_sources):
        ops_to_visit.append(inp)

  return extra_sources
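
# Usage sketch for `map_subgraph` (hypothetical tensor `t`; note that
# `visited_ops` and `op_outputs` are mutated in place):
#
#   import collections
#   visited_ops = object_identity.ObjectIdentitySet()
#   op_outputs = collections.defaultdict(set)
#   extra_sources = map_subgraph(
#       init_tensor=t, sources=object_identity.ObjectIdentitySet(),
#       disallowed_placeholders=None, visited_ops=visited_ops,
#       op_outputs=op_outputs, add_sources=True)
#   # `extra_sources` now holds the placeholder outputs that `t` depends on.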