• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Utilities for using the TensorFlow C API."""
17
18import contextlib
19from tensorflow.core.framework import api_def_pb2
20from tensorflow.core.framework import op_def_pb2
21from tensorflow.python.client import pywrap_tf_session as c_api
22from tensorflow.python.util import compat
23from tensorflow.python.util import tf_contextlib
24
25
class AlreadyGarbageCollectedError(Exception):
  """Raised when a wrapped C-API object is accessed after it was released."""

  def __init__(self, name, obj_type):
    message = (f"{name} of type {obj_type} has already been garbage "
               "collected and cannot be called.")
    super().__init__(message)
32
33
34# FIXME(b/235488206): Convert all Scoped objects to the context manager
35# to protect against deletion during use when the object is attached to
36# an attribute.
class UniquePtr(object):
  """Wrapper around single-ownership C-API objects that handles deletion.

  The wrapped object is released at most once, by calling `deleter(obj)` from
  `__del__`. Access it through the `get()` context manager.
  """

  __slots__ = ["_obj", "deleter", "name", "type_name"]

  def __init__(self, name, obj, deleter):
    """Takes ownership of `obj`.

    Args:
      name: Human-readable name, used in AlreadyGarbageCollectedError messages.
      obj: The C-API object to own.
      deleter: Callable invoked as `deleter(obj)` to release `obj`.
    """
    # '_' prefix marks _obj private, but unclear if it is required also to
    # maintain a special CPython destruction order.
    self._obj = obj
    self.name = name
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we may have already deleted other modules. By capturing the
    # deleter function here (e.g. TF_DeleteGraph for graphs), we retain the
    # ability to cleanly destroy the object at shutdown, which satisfies leak
    # checkers.
    self.deleter = deleter
    self.type_name = str(type(obj))

  @contextlib.contextmanager
  def get(self):
    """Yields the managed C-API Object, guaranteeing aliveness.

    This is a context manager. Inside the context the C-API object is
    guaranteed to be alive.

    Raises:
      AlreadyGarbageCollectedError: if the object is already deleted.
    """
    # Thread-safety: self.__del__ never runs during the call of this function
    # because there is a reference to self from the argument list.
    if self._obj is None:
      raise AlreadyGarbageCollectedError(self.name, self.type_name)
    yield self._obj

  def __del__(self):
    # `_obj` is an unset slot when the instance was allocated but __init__
    # never ran (e.g. a subclass raised while building its constructor
    # arguments). There is nothing to release then, so don't raise
    # AttributeError from the finalizer.
    obj = getattr(self, "_obj", None)
    if obj is not None:
      # Clear first so a repeated __del__ call cannot double-free.
      self._obj = None
      self.deleter(obj)
75
76
class ScopedTFStatus(object):
  """Owns a newly created TF_Status and frees it when collected."""

  __slots__ = ["status"]

  def __init__(self):
    self.status = c_api.TF_NewStatus()

  def __del__(self):
    # While the interpreter is shutting down, the `c_api` module (or its
    # deleter) may already be gone; in that case skip the free.
    if c_api is None or c_api.TF_DeleteStatus is None:
      return
    c_api.TF_DeleteStatus(self.status)
90
91
class ScopedTFGraph(UniquePtr):
  """Owns a freshly created TF_Graph, released via TF_DeleteGraph."""

  def __init__(self, name):
    new_graph = c_api.TF_NewGraph()
    super().__init__(name, obj=new_graph, deleter=c_api.TF_DeleteGraph)
98
99
class ScopedTFImportGraphDefOptions(object):
  """Owns a newly created TF_ImportGraphDefOptions and frees it when collected."""

  __slots__ = ["options"]

  def __init__(self):
    self.options = c_api.TF_NewImportGraphDefOptions()

  def __del__(self):
    # During interpreter shutdown `c_api` or its deleter may already have
    # been torn down; skip the free in that case.
    if c_api is None or c_api.TF_DeleteImportGraphDefOptions is None:
      return
    c_api.TF_DeleteImportGraphDefOptions(self.options)
113
114
class ScopedTFImportGraphDefResults(object):
  """Wrapper around TF_ImportGraphDefResults that handles deletion.

  Unlike the other Scoped* wrappers, this one does not create the C object;
  it takes an existing `results` handle and frees it with
  TF_DeleteImportGraphDefResults when this wrapper is collected.
  """

  __slots__ = ["results"]

  def __init__(self, results):
    self.results = results

  def __del__(self):
    try:
      results = self.results
    except AttributeError:
      # __init__ never ran (slot unset), so there is nothing to free.
      return
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we can have already deleted other modules.
    if c_api is not None and c_api.TF_DeleteImportGraphDefResults is not None:
      c_api.TF_DeleteImportGraphDefResults(results)
128
129
130class ScopedTFFunction(UniquePtr):
131  """Wrapper around TF_Function that handles deletion."""
132
133  def __init__(self, func, name):
134    super(ScopedTFFunction, self).__init__(
135        name=name, obj=func, deleter=c_api.TF_DeleteFunction)
136
137
class ScopedTFBuffer(object):
  """An internal class to help manage the TF_Buffer lifetime.

  Creates a TF_Buffer from `buf_string` and frees it with TF_DeleteBuffer
  when this wrapper is collected.
  """

  __slots__ = ["buffer"]

  def __init__(self, buf_string):
    # If building the buffer raises (e.g. compat.as_bytes rejects the input),
    # `buffer` is never assigned; __del__ tolerates that.
    self.buffer = c_api.TF_NewBufferFromString(compat.as_bytes(buf_string))

  def __del__(self):
    try:
      buf = self.buffer
    except AttributeError:
      # __init__ failed before the buffer was created; nothing to free.
      return
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we can have already deleted other modules. Guard the same
    # way as the other Scoped* wrappers instead of raising from __del__.
    if c_api is not None and c_api.TF_DeleteBuffer is not None:
      c_api.TF_DeleteBuffer(buf)
148
149
class ApiDefMap(object):
  """Wrapper around Tf_ApiDefMap that handles querying and deletion.

  The OpDef protos are also stored in this class so that they could
  be queried by op name.
  """

  __slots__ = ["_api_def_map", "_op_per_name"]

  def __init__(self):
    """Builds the map from every registered op (via TF_GetAllOpList)."""
    op_def_proto = op_def_pb2.OpList()
    buf = c_api.TF_GetAllOpList()
    try:
      op_def_proto.ParseFromString(c_api.TF_GetBuffer(buf))
      self._api_def_map = c_api.TF_NewApiDefMap(buf)
    finally:
      # The temporary op-list buffer is freed whether or not parsing succeeds.
      c_api.TF_DeleteBuffer(buf)

    # Later entries with the same name win, as with sequential assignment.
    self._op_per_name = {op.name: op for op in op_def_proto.op}

  def __del__(self):
    try:
      api_def_map = self._api_def_map
    except AttributeError:
      # __init__ raised before the map was created (e.g. while parsing the
      # op list); there is nothing to free.
      return
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we can have already deleted other modules.
    if c_api is not None and c_api.TF_DeleteApiDefMap is not None:
      c_api.TF_DeleteApiDefMap(api_def_map)

  def put_api_def(self, text):
    """Adds ApiDef definitions, given as text, to the map.

    Args:
      text: `str` or `bytes` ApiDef text forwarded to TF_ApiDefMapPut.
    """
    # TF_ApiDefMapPut takes the payload length in bytes. len() of a `str`
    # counts code points, which under-counts non-ASCII text once it is
    # encoded to UTF-8 for the C call, truncating the payload.
    text_len = len(compat.as_bytes(text))
    c_api.TF_ApiDefMapPut(self._api_def_map, text, text_len)

  def get_api_def(self, op_name):
    """Returns the ApiDef proto stored for `op_name`.

    Args:
      op_name: `str` or `bytes` op name.

    Returns:
      An `api_def_pb2.ApiDef` parsed from the C map's serialized entry.
    """
    api_def_proto = api_def_pb2.ApiDef()
    # Byte length, for the same reason as in put_api_def.
    name_len = len(compat.as_bytes(op_name))
    buf = c_api.TF_ApiDefMapGet(self._api_def_map, op_name, name_len)
    try:
      api_def_proto.ParseFromString(c_api.TF_GetBuffer(buf))
    finally:
      c_api.TF_DeleteBuffer(buf)
    return api_def_proto

  def get_op_def(self, op_name):
    """Returns the stored OpDef proto for `op_name`.

    Raises:
      ValueError: if no OpDef with that name was registered.
    """
    if op_name in self._op_per_name:
      return self._op_per_name[op_name]
    raise ValueError(f"No op_def found for op name {op_name}.")

  def op_names(self):
    """Returns a view of all op names that have a stored OpDef."""
    return self._op_per_name.keys()
197
198
@tf_contextlib.contextmanager
def tf_buffer(data=None):
  """Creates a TF_Buffer for the duration of a `with` block, then deletes it.

  Example usage:
    with tf_buffer() as buf:
      # get serialized graph def into buf
      ...
      proto_data = c_api.TF_GetBuffer(buf)
      graph_def.ParseFromString(compat.as_bytes(proto_data))
    # buf has been deleted

    with tf_buffer(some_string) as buf:
      c_api.TF_SomeFunction(buf)
    # buf has been deleted

  Args:
    data: Optional `bytes` or `str`. When truthy, the buffer is created from
      `compat.as_bytes(data)`; otherwise a new buffer comes from
      TF_NewBuffer().

  Yields:
    The created TF_Buffer.
  """
  new_buf = (c_api.TF_NewBufferFromString(compat.as_bytes(data))
             if data else c_api.TF_NewBuffer())
  try:
    yield new_buf
  finally:
    # Runs even when the body of the `with` statement raises.
    c_api.TF_DeleteBuffer(new_buf)
230
231
def tf_output(c_op, index):
  """Builds a wrapped TF_Output referring to output `index` of `c_op`.

  Args:
    c_op: wrapped TF_Operation
    index: integer output index

  Returns:
    Wrapped TF_Output
  """
  output = c_api.TF_Output()
  output.oper, output.index = c_op, index
  return output
246
247
def tf_operations(graph):
  """Yields each TF_Operation in `graph`, as enumerated by TF_GraphNextOperation.

  The graph's C handle is held via `graph._c_graph.get()` for the whole
  iteration.

  Args:
    graph: Graph

  Yields:
    wrapped TF_Operation
  """
  # pylint: disable=protected-access
  with graph._c_graph.get() as c_graph:
    cursor = 0
    while True:
      c_op, cursor = c_api.TF_GraphNextOperation(c_graph, cursor)
      if c_op is None:
        break
      yield c_op
  # pylint: enable=protected-access
265
266
def new_tf_operations(graph):
  """Yields the TF_Operations in `graph` that have no Python Operation yet.

  An op counts as "new" when `graph._get_operation_by_tf_operation` raises
  KeyError for it. This is useful for processing nodes added by the C API.

  Args:
    graph: Graph

  Yields:
    wrapped TF_Operation
  """
  # TODO(b/69679162): do this more efficiently
  for tf_op in tf_operations(graph):
    try:
      # pylint: disable=protected-access
      graph._get_operation_by_tf_operation(tf_op)
      # pylint: enable=protected-access
    except KeyError:
      # No Python-side Operation wraps this TF_Operation.
      yield tf_op
285