• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Utilities for using the TensorFlow C API."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22from tensorflow.core.framework import api_def_pb2
23from tensorflow.core.framework import op_def_pb2
24from tensorflow.python.client import pywrap_tf_session as c_api
25from tensorflow.python.util import compat
26from tensorflow.python.util import tf_contextlib
27
28
class ScopedTFStatus(object):
  """Owns a TF_Status handle and frees it when this wrapper is collected."""

  __slots__ = ["status"]

  def __init__(self):
    self.status = c_api.TF_NewStatus()

  def __del__(self):
    # While the process is shutting down (global context being destructed),
    # other modules may already have been torn down; skip cleanup then.
    if c_api is None or c_api.TF_DeleteStatus is None:
      return
    c_api.TF_DeleteStatus(self.status)
42
43
class ScopedTFGraph(object):
  """Owns a TF_Graph handle and deletes it when this wrapper is collected."""

  __slots__ = ["graph", "deleter"]

  def __init__(self):
    self.graph = c_api.TF_NewGraph()
    # Keep a direct reference to the deletion routine instead of looking it
    # up on `c_api` inside __del__: while the process is terminating other
    # modules may already be gone, and holding the function here still lets
    # the graph be destroyed cleanly at shutdown (keeps leak checkers quiet).
    self.deleter = c_api.TF_DeleteGraph

  def __del__(self):
    delete_graph = self.deleter
    delete_graph(self.graph)
59
60
class ScopedTFImportGraphDefOptions(object):
  """Owns a TF_ImportGraphDefOptions handle and frees it on collection."""

  __slots__ = ["options"]

  def __init__(self):
    self.options = c_api.TF_NewImportGraphDefOptions()

  def __del__(self):
    # While the process is shutting down (global context being destructed),
    # other modules may already have been torn down; skip cleanup then.
    if c_api is None or c_api.TF_DeleteImportGraphDefOptions is None:
      return
    c_api.TF_DeleteImportGraphDefOptions(self.options)
74
75
class ScopedTFImportGraphDefResults(object):
  """Wrapper around TF_ImportGraphDefResults that handles deletion.

  Takes ownership of an already-created results handle and releases it with
  `TF_DeleteImportGraphDefResults` when this wrapper is garbage collected.

  Args:
    results: the wrapped TF_ImportGraphDefResults handle to own.
  """

  __slots__ = ["results"]

  def __init__(self, results):
    self.results = results

  def __del__(self):
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we can have already deleted other modules.
    if c_api is not None and c_api.TF_DeleteImportGraphDefResults is not None:
      c_api.TF_DeleteImportGraphDefResults(self.results)
89
90
class ScopedTFFunction(object):
  """Owns a TF_Function handle; deletes it once and then clears `func`."""

  __slots__ = ["func", "deleter"]

  def __init__(self, func):
    self.func = func
    # Keep a direct reference to the deletion routine: while the process is
    # terminating other modules may already be gone, and holding the function
    # here still lets the TF_Function be destroyed cleanly at shutdown (keeps
    # leak checkers quiet).
    self.deleter = c_api.TF_DeleteFunction

  @property
  def has_been_garbage_collected(self):
    """Whether the wrapped handle has already been released (`func` is None)."""
    return self.func is None

  def __del__(self):
    if self.has_been_garbage_collected:
      return
    self.deleter(self.func)
    # Clear the handle so a repeated call never deletes it twice.
    self.func = None
112
113
class ScopedTFBuffer(object):
  """An internal class to help manage the TF_Buffer lifetime.

  Creates a TF_Buffer from `buf_string` (converted with `compat.as_bytes`) and
  deletes it when this wrapper is garbage collected.

  Args:
    buf_string: data for the buffer; anything `compat.as_bytes` accepts.
  """

  __slots__ = ["buffer"]

  def __init__(self, buf_string):
    self.buffer = c_api.TF_NewBufferFromString(compat.as_bytes(buf_string))

  def __del__(self):
    # If __init__ raised (e.g. `buf_string` could not be converted to bytes),
    # the slot was never assigned and there is nothing to free.
    buf = getattr(self, "buffer", None)
    if buf is None:
      return
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we can have already deleted other modules. Guard the same
    # way the sibling wrappers do instead of failing during shutdown.
    if c_api is not None and c_api.TF_DeleteBuffer is not None:
      c_api.TF_DeleteBuffer(buf)
124
125
class ApiDefMap(object):
  """Wrapper around TF_ApiDefMap that handles querying and deletion.

  The OpDef protos are also stored in this class so that they could
  be queried by op name.
  """

  __slots__ = ["_api_def_map", "_op_per_name"]

  def __init__(self):
    op_def_proto = op_def_pb2.OpList()
    buf = c_api.TF_GetAllOpList()
    try:
      # Both the Python OpList and the C ApiDefMap are built from the same
      # buffer; free it even if parsing or map creation raises.
      op_def_proto.ParseFromString(c_api.TF_GetBuffer(buf))
      self._api_def_map = c_api.TF_NewApiDefMap(buf)
    finally:
      c_api.TF_DeleteBuffer(buf)

    self._op_per_name = {op.name: op for op in op_def_proto.op}

  def __del__(self):
    # If __init__ raised before the C map was created, the slot is unset and
    # there is nothing to free.
    api_def_map = getattr(self, "_api_def_map", None)
    if api_def_map is None:
      return
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we can have already deleted other modules.
    if c_api is not None and c_api.TF_DeleteApiDefMap is not None:
      c_api.TF_DeleteApiDefMap(api_def_map)

  def put_api_def(self, text):
    """Adds ApiDef text to the map via `TF_ApiDefMapPut`.

    Args:
      text: `str` or `bytes` ApiDef text passed through to the C API.
    """
    # The C API takes an explicit length that must match the encoded bytes.
    # For a `str`, len() counts code points, which is smaller than the UTF-8
    # byte length whenever the text contains non-ASCII characters.
    text_len = len(compat.as_bytes(text))
    c_api.TF_ApiDefMapPut(self._api_def_map, text, text_len)

  def get_api_def(self, op_name):
    """Returns the ApiDef proto stored in the map for `op_name`.

    Args:
      op_name: `str` or `bytes` name of the op to look up.

    Returns:
      An `api_def_pb2.ApiDef` parsed from the buffer `TF_ApiDefMapGet` returns.
    """
    api_def_proto = api_def_pb2.ApiDef()
    # Pass the encoded byte length, for the same reason as in put_api_def.
    name_len = len(compat.as_bytes(op_name))
    buf = c_api.TF_ApiDefMapGet(self._api_def_map, op_name, name_len)
    try:
      api_def_proto.ParseFromString(c_api.TF_GetBuffer(buf))
    finally:
      c_api.TF_DeleteBuffer(buf)
    return api_def_proto

  def get_op_def(self, op_name):
    """Returns the OpDef proto recorded for `op_name` when the map was built.

    Raises:
      ValueError: if no op named `op_name` was recorded.
    """
    if op_name in self._op_per_name:
      return self._op_per_name[op_name]
    raise ValueError("No entry found for " + op_name + ".")

  def op_names(self):
    """Returns the keys view of all op names recorded in this map."""
    return self._op_per_name.keys()
173
174
@tf_contextlib.contextmanager
def tf_buffer(data=None):
  """Context manager yielding a TF_Buffer that is deleted on exit.

  Example usage:
    with tf_buffer() as buf:
      # get serialized graph def into buf
      ...
      proto_data = c_api.TF_GetBuffer(buf)
      graph_def.ParseFromString(compat.as_bytes(proto_data))
    # buf has been deleted

    with tf_buffer(some_string) as buf:
      c_api.TF_SomeFunction(buf)
    # buf has been deleted

  Args:
    data: An optional `bytes`, `str`, or `unicode` object. When it is truthy,
      the yielded buffer is created from it (via `compat.as_bytes`); when it is
      None or empty, a fresh buffer from `TF_NewBuffer` is yielded instead.

  Yields:
    Created TF_Buffer
  """
  buf = (c_api.TF_NewBufferFromString(compat.as_bytes(data))
         if data else c_api.TF_NewBuffer())
  try:
    yield buf
  finally:
    # Runs on normal exit and when the `with` body raises.
    c_api.TF_DeleteBuffer(buf)
206
207
def tf_output(c_op, index):
  """Builds a wrapped TF_Output referring to output `index` of `c_op`.

  Args:
    c_op: wrapped TF_Operation
    index: integer

  Returns:
    Wrapped TF_Output
  """
  output = c_api.TF_Output()
  # Tuple assignment sets `oper` first, then `index`.
  output.oper, output.index = c_op, index
  return output
222
223
def tf_operations(graph):
  """Lazily yields each TF_Operation of `graph` in C-API iteration order.

  Args:
    graph: Graph

  Yields:
    wrapped TF_Operation
  """
  # pylint: disable=protected-access
  position = 0
  while True:
    # TF_GraphNextOperation hands back the next op plus the cursor to resume
    # from; a None op marks the end of the graph.
    c_op, position = c_api.TF_GraphNextOperation(graph._c_graph, position)
    if c_op is None:
      return
    yield c_op
  # pylint: enable=protected-access
240
241
def new_tf_operations(graph):
  """Lazily yields the TF_Operations in `graph` that lack a Python Operation.

  An op counts as "new" when `graph._get_operation_by_tf_operation` raises
  KeyError for it, which is useful for processing nodes added by the C API.

  Args:
    graph: Graph

  Yields:
    wrapped TF_Operation
  """

  def _is_untracked(c_op):
    # True when the graph has no Python-side Operation for `c_op`.
    try:
      graph._get_operation_by_tf_operation(c_op)  # pylint: disable=protected-access
    except KeyError:
      return True
    return False

  # TODO(b/69679162): do this more efficiently
  for c_op in tf_operations(graph):
    if _is_untracked(c_op):
      yield c_op
260