# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tools for selecting ops in a graph."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import ops
from tensorflow.python.util import object_identity


def is_differentiable(op):
  """Return True if a non-None gradient is registered for `op`'s type."""
  try:
    return ops._gradient_registry.lookup(op.op_def.name) is not None  # pylint: disable=protected-access
  except LookupError:
    return False


def is_iterable(obj):
  """Return True if the object is iterable; `tf.Tensor` objects are not."""
  if isinstance(obj, ops.Tensor):
    return False
  try:
    _ = iter(obj)
  except Exception:  # pylint: disable=broad-except
    return False
  return True


def concatenate_unique(la, lb):
  """Add all the elements of `lb` to `la` if they are not there already.

  The elements added to `la` maintain ordering with respect to `lb`.

  Args:
    la: List of Python objects.
    lb: List of Python objects.
  Returns:
    `la`: The list `la`, extended in place with the elements of `lb` that
      were missing.
  """
  la_set = set(la)
  for l in lb:
    if l not in la_set:
      la.append(l)
      la_set.add(l)
  return la
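# Illustrative sketch (not part of the original module): `concatenate_unique`
# mutates and returns its first argument, and the newly appended elements keep
# the relative order they had in `lb`.
#
#   merged = concatenate_unique([1, 2], [2, 3, 4])
#   # merged is [1, 2, 3, 4], and it is the same list object passed as `la`.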


def get_tensors(graph):
  """Get all the tensors which are input or output of an op in the graph.

  Args:
    graph: a `tf.Graph`.
  Returns:
    A list of `tf.Tensor`.
  Raises:
    TypeError: if graph is not a `tf.Graph`.
  """
  if not isinstance(graph, ops.Graph):
    raise TypeError("Expected a graph, got: {}".format(type(graph)))
  ts = []
  for op in graph.get_operations():
    ts += op.outputs
  return ts


def get_unique_graph(tops, check_types=None, none_if_empty=False):
  """Return the unique graph used by all the elements in tops.

  Args:
    tops: list of elements to check (usually a list of tf.Operation and/or
      tf.Tensor). Or a tf.Graph.
    check_types: check that the elements in tops are of the given type(s). If
      None, the types (tf.Operation, tf.Tensor) are used.
    none_if_empty: don't raise an error if tops is an empty list, just return
      None.
  Returns:
    The unique graph used by all the tops.
  Raises:
    TypeError: if tops is not an iterable of tf.Operation.
    ValueError: if the graph is not unique.
  """
  if isinstance(tops, ops.Graph):
    return tops
  if not is_iterable(tops):
    raise TypeError("{} is not iterable".format(type(tops)))
  if check_types is None:
    check_types = (ops.Operation, ops.Tensor)
  elif not is_iterable(check_types):
    check_types = (check_types,)
  g = None
  for op in tops:
    if not isinstance(op, check_types):
      raise TypeError("Expected a type in ({}), got: {}".format(
          ", ".join(str(t) for t in check_types), type(op)))
    if g is None:
      g = op.graph
    elif g._graph_key != op.graph._graph_key:  # pylint: disable=protected-access
      raise ValueError("Operation {} does not belong to given graph".format(op))
  if g is None and not none_if_empty:
    raise ValueError("Can't find the unique graph of an empty list")
  return g
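# Illustrative sketch (assumed usage, not part of the original module): a mix
# of tensors and operations built in a single graph maps back to that graph,
# while elements from different graphs raise ValueError. `constant_op` is an
# assumed import used only for this sketch.
#
#   from tensorflow.python.framework import constant_op
#
#   with ops.Graph().as_default() as g:
#     a = constant_op.constant(1.0)
#     b = constant_op.constant(2.0)
#   assert get_unique_graph([a, b.op]) is g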


def check_graphs(*args):
  """Check that all the elements in args belong to the same graph.

  Args:
    *args: a list of objects with a `graph` property.
  Raises:
    ValueError: if the elements do not all belong to the same graph.
  """
  graph = None
  for i, sgv in enumerate(args):
    if graph is None and sgv.graph is not None:
      graph = sgv.graph
    elif sgv.graph is not None and sgv.graph is not graph:
      raise ValueError(f"args[{i}] does not belong to the same graph as "
                       "other arguments.")


def make_list_of_t(ts, check_graph=True, allow_graph=True, ignore_ops=False):
  """Convert `ts` to a list of `tf.Tensor`.

  Args:
    ts: can be an iterable of `tf.Tensor`, a `tf.Graph` or a single tensor.
    check_graph: if `True` check if all the tensors belong to the same graph.
    allow_graph: if `False` a `tf.Graph` cannot be converted.
    ignore_ops: if `True`, silently ignore `tf.Operation`.
  Returns:
    A newly created list of `tf.Tensor`.
  Raises:
    TypeError: if `ts` cannot be converted to a list of `tf.Tensor` or,
      if `check_graph` is `True`, if the tensors do not all belong to the
      same graph.
  """
  if isinstance(ts, ops.Graph):
    if allow_graph:
      return get_tensors(ts)
    else:
      raise TypeError("allow_graph is False: cannot convert a tf.Graph.")
  else:
    if not is_iterable(ts):
      ts = [ts]
    if not ts:
      return []
    if check_graph:
      check_types = None if ignore_ops else ops.Tensor
      get_unique_graph(ts, check_types=check_types)
    return [t for t in ts if isinstance(t, ops.Tensor)]
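# Illustrative sketch (assumed usage, not part of the original module):
# `make_list_of_t` accepts a single tensor, an iterable (optionally mixing in
# operations when `ignore_ops=True`), or a whole graph, and always returns a
# plain Python list of tensors. `constant_op` is an assumed import used only
# for this sketch.
#
#   from tensorflow.python.framework import constant_op
#
#   with ops.Graph().as_default() as g:
#     t = constant_op.constant(1.0)
#   assert make_list_of_t(t) == [t]
#   assert make_list_of_t([t, t.op], ignore_ops=True) == [t]
#   assert make_list_of_t(g) == [t]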


def get_generating_ops(ts):
  """Return all the generating ops of the tensors in `ts`.

  Args:
    ts: a list of `tf.Tensor`
  Returns:
    A list of all the generating `tf.Operation` of the tensors in `ts`.
  Raises:
    TypeError: if `ts` cannot be converted to a list of `tf.Tensor`.
  """
  ts = make_list_of_t(ts, allow_graph=False)
  return [t.op for t in ts]


def get_consuming_ops(ts):
  """Return all the consuming ops of the tensors in `ts`.

  Args:
    ts: a list of `tf.Tensor`
  Returns:
    A list of all the consuming `tf.Operation` of the tensors in `ts`.
  Raises:
    TypeError: if `ts` cannot be converted to a list of `tf.Tensor`.
  """
  ts = make_list_of_t(ts, allow_graph=False)
  tops = []
  for t in ts:
    for op in t.consumers():
      if op not in tops:
        tops.append(op)
  return tops
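# Illustrative sketch (assumed usage, not part of the original module):
# `get_generating_ops` maps tensors to the ops that produce them, while
# `get_consuming_ops` collects the distinct ops that take them as inputs.
# `constant_op` and `math_ops` are assumed imports used only for this sketch.
#
#   from tensorflow.python.framework import constant_op
#   from tensorflow.python.ops import math_ops
#
#   with ops.Graph().as_default():
#     a = constant_op.constant(1.0)
#     b = math_ops.square(a)
#   assert get_generating_ops([b]) == [b.op]
#   assert get_consuming_ops([a]) == [b.op]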


def make_list_of_op(tops, check_graph=True, allow_graph=True, ignore_ts=False):
  """Convert `tops` to a list of `tf.Operation`.

  Args:
    tops: can be an iterable of `tf.Operation`, a `tf.Graph` or a single
      operation.
    check_graph: if `True` check if all the operations belong to the same graph.
    allow_graph: if `False` a `tf.Graph` cannot be converted.
    ignore_ts: if `True`, silently ignore `tf.Tensor`.
  Returns:
    A newly created list of `tf.Operation`.
  Raises:
    TypeError: if tops cannot be converted to a list of `tf.Operation` or,
      if `check_graph` is `True`, if the ops do not all belong to the
      same graph.
  """
  if isinstance(tops, ops.Graph):
    if allow_graph:
      return tops.get_operations()
    else:
      raise TypeError("allow_graph is False: cannot convert a tf.Graph.")
  else:
    if not is_iterable(tops):
      tops = [tops]
    if not tops:
      return []
    if check_graph:
      check_types = None if ignore_ts else ops.Operation
      get_unique_graph(tops, check_types=check_types)
    return [op for op in tops if isinstance(op, ops.Operation)]


def _get_inputs(op, only_differentiable):
  """Return the inputs of `op`.

  If `only_differentiable` is set and `op` has no registered gradient, an
  empty list is returned instead.
  """
  op_inputs = op.inputs
  if only_differentiable:
    return op_inputs if is_differentiable(op) else []
  else:
    return op_inputs


def get_backward_walk_ops(seed_ops,
                          inclusive=True,
                          within_ops=None,
                          within_ops_fn=None,
                          stop_at_ts=(),
                          control_inputs=False,
                          only_differentiable=False):
  """Do a backward graph walk and return all the visited ops.

  Args:
    seed_ops: an iterable of operations from which the backward graph
      walk starts. If a list of tensors is given instead, the seed_ops are set
      to be the generators of those tensors.
    inclusive: if True the given seed_ops are also part of the resulting list.
    within_ops: an iterable of `tf.Operation` within which the search is
      restricted. If `within_ops` is `None`, the search is performed within
      the whole graph.
    within_ops_fn: if provided, a function on ops that should return True iff
      the op is within the graph traversal. This can be used along with
      `within_ops`, in which case an op is within if it is also in
      `within_ops`.
    stop_at_ts: an iterable of tensors at which the graph walk stops.
    control_inputs: if True, control inputs will be used while moving backward.
      This is ignored when `only_differentiable` is True.
    only_differentiable: if True, only traverse ops which are differentiable.
      This includes natively differentiable ops, or ops with custom gradients.
  Returns:
    A Python list of all the `tf.Operation` behind `seed_ops`.
  Raises:
    TypeError: if `seed_ops` or `within_ops` cannot be converted to a list of
      `tf.Operation`.
  """
  control_inputs = control_inputs and (not only_differentiable)

  if not is_iterable(seed_ops):
    seed_ops = [seed_ops]
  if not seed_ops:
    return []
  if isinstance(seed_ops[0], ops.Tensor):
    ts = make_list_of_t(seed_ops, allow_graph=False)
    seed_ops = get_generating_ops(ts)
  else:
    seed_ops = make_list_of_op(seed_ops, allow_graph=False)

  stop_at_ts = object_identity.ObjectIdentitySet(make_list_of_t(stop_at_ts))
  seed_ops = object_identity.ObjectIdentitySet(make_list_of_op(seed_ops))
  if within_ops:
    within_ops = make_list_of_op(within_ops, allow_graph=False)
    within_ops = object_identity.ObjectIdentitySet(within_ops)
    seed_ops &= within_ops

  def is_within(op):
    return (within_ops is None or op in within_ops) and (
        within_ops_fn is None or within_ops_fn(op))

  result = list(seed_ops)
  wave = set(seed_ops)
  while wave:
    new_wave = set()
    for op in wave:
      for new_t in _get_inputs(op, only_differentiable=only_differentiable):
        if new_t in stop_at_ts:
          continue
        if new_t.op not in result and is_within(new_t.op):
          new_wave.add(new_t.op)
      if control_inputs:
        for new_op in op.control_inputs:
          if new_op not in result and is_within(new_op):
            new_wave.add(new_op)
    concatenate_unique(result, new_wave)
    wave = new_wave
  if not inclusive:
    result = [op for op in result if op not in seed_ops]
  return result
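# Illustrative sketch (assumed usage, not part of the original module): the
# walk starts at the generating ops of the seed tensors and visits everything
# they transitively depend on, unless cut off by `stop_at_ts` or filtered by
# `within_ops`/`within_ops_fn`. `constant_op` and `math_ops` are assumed
# imports used only for this sketch.
#
#   from tensorflow.python.framework import constant_op
#   from tensorflow.python.ops import math_ops
#
#   with ops.Graph().as_default():
#     a = constant_op.constant(1.0, name="a")
#     b = constant_op.constant(2.0, name="b")
#     c = math_ops.add(a, b, name="c")
#   assert set(get_backward_walk_ops([c])) == {a.op, b.op, c.op}
#   assert get_backward_walk_ops([c], stop_at_ts=[a]) == [c.op, b.op]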


class UnliftableError(Exception):
  """Raised if a Tensor cannot be lifted from the graph."""

  # Prevent autograph from rewriting this error.
  ag_pass_through = True


def _as_operation(op_or_tensor):
  """Return `op_or_tensor` if it is an operation, else its producing op."""
  if isinstance(op_or_tensor, ops.Tensor):
    return op_or_tensor.op
  return op_or_tensor


def graph_inputs(op):
  """Return the ops producing `op`'s inputs, plus its control inputs."""
  return [x.op for x in op.inputs] + list(op.control_inputs)


def _path_from(from_op, tensor, sources):
  """Find one path from `from_op` to `tensor`, ignoring `sources`.

  Args:
    from_op: A `tf.Operation`.
    tensor: A `tf.Operation` or `tf.Tensor`.
    sources: A list of `tf.Tensor`.

  Returns:
    A Python string describing the path as "name (type)" entries joined by
    " <- ", running from the op of `tensor` back to `from_op`, or "??" if no
    path is found.
  """
  if isinstance(from_op, ops.Tensor):
    from_op = from_op.op

  visited_ops = set(x.op for x in sources)
  ops_to_visit = [_as_operation(tensor)]
  some_op_output = {}  # op -> downstream consumer found during the walk.
  while ops_to_visit:
    op = ops_to_visit.pop()
    if op in visited_ops:
      continue
    visited_ops.add(op)
    if op == from_op:
      path_op = op
      path = [path_op]
      final_op = _as_operation(tensor)
      while path_op != final_op:
        path_op = some_op_output[path_op]
        path.append(path_op)
      return " <- ".join("%s (%s)" % (x.name, x.type) for x in reversed(path))
    else:
      for inp in graph_inputs(op):
        if inp not in visited_ops and inp not in sources:
          some_op_output[inp] = op
          ops_to_visit.append(inp)
  return "??"
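# Illustrative sketch (assumed usage, not part of the original module): the
# returned string lists "name (type)" entries joined by " <- ", running from
# the op of `tensor` back to `from_op`. `array_ops`, `math_ops` and `dtypes`
# are assumed imports used only for this sketch.
#
#   from tensorflow.python.framework import dtypes
#   from tensorflow.python.ops import array_ops
#   from tensorflow.python.ops import math_ops
#
#   with ops.Graph().as_default():
#     p = array_ops.placeholder(dtypes.float32, name="p")
#     c = math_ops.square(p, name="c")
#   _path_from(p.op, c, sources=[])  # -> "c (Square) <- p (Placeholder)"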


# TODO(jmenick): there is considerable duplication of functionality between
# this function and get_backward_walk_ops(). Need to deduplicate.
def map_subgraph(init_tensor, sources, disallowed_placeholders, visited_ops,
                 op_outputs, add_sources):
  """Walk a Graph and capture the subgraph between init_tensor and sources.

  Note: This function mutates visited_ops and op_outputs.

  Args:
    init_tensor: A Tensor or Operation where the subgraph terminates.
    sources: A set of Tensors where subgraph extraction should stop.
    disallowed_placeholders: An optional set of ops which may not appear in the
      lifted graph. Defaults to all placeholders.
    visited_ops: A set of operations which were visited in a prior pass.
    op_outputs: A defaultdict containing the outputs of an op which are to be
      copied into the new subgraph.
    add_sources: A boolean indicating whether placeholders which are not in
      sources should be allowed.

  Returns:
    The set of placeholder tensors upon which init_tensor depends and that are
    not in sources.

  Raises:
    UnliftableError: if init_tensor depends on a placeholder which is not in
      sources and add_sources is False.
  """
  ops_to_visit = [_as_operation(init_tensor)]
  extra_sources = object_identity.ObjectIdentitySet()
  while ops_to_visit:
    op = ops_to_visit.pop()
    if op in visited_ops:
      continue
    visited_ops.add(op)

    should_raise = False
    if disallowed_placeholders is not None and op in disallowed_placeholders:
      should_raise = True
    elif op.type == "Placeholder":
      if disallowed_placeholders is None and not add_sources:
        should_raise = True
      extra_sources.update(op.outputs)

    if should_raise:
      raise UnliftableError(
          "Unable to lift tensor %s because it depends transitively on "
          "placeholder %s via at least one path, e.g.: %s"
          % (repr(init_tensor), repr(op), _path_from(op, init_tensor, sources)))
    for inp in graph_inputs(op):
      op_outputs[inp].add(op)
      # `sources or extra_sources` checks against `sources` when it is
      # non-empty and falls back to `extra_sources` otherwise.
      if inp not in visited_ops and inp not in (sources or extra_sources):
        ops_to_visit.append(inp)

  return extra_sources
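# Illustrative sketch (assumed usage, not part of the original module):
# walking back from a tensor with an empty `sources` set returns the
# placeholder outputs it depends on (when `add_sources=True`), and would raise
# `UnliftableError` with `add_sources=False`. `collections`, `dtypes`,
# `array_ops` and `math_ops` are assumed imports used only for this sketch.
#
#   import collections
#   from tensorflow.python.framework import dtypes
#   from tensorflow.python.ops import array_ops
#   from tensorflow.python.ops import math_ops
#
#   with ops.Graph().as_default():
#     p = array_ops.placeholder(dtypes.float32, name="p")
#     y = math_ops.square(p, name="y")
#   visited_ops = set()
#   op_outputs = collections.defaultdict(set)
#   extra = map_subgraph(y, sources=set(), disallowed_placeholders=None,
#                        visited_ops=visited_ops, op_outputs=op_outputs,
#                        add_sources=True)
#   assert p in extra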