• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Utilities for using the TensorFlow C API."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22from tensorflow.core.framework import api_def_pb2
23from tensorflow.core.framework import op_def_pb2
24from tensorflow.python.client import pywrap_tf_session as c_api
25from tensorflow.python.util import compat
26from tensorflow.python.util import tf_contextlib
27
28
class ScopedTFStatus(object):
  """Owns a TF_Status handle and frees it when this wrapper is collected."""

  __slots__ = ["status"]

  def __init__(self):
    self.status = c_api.TF_NewStatus()

  def __del__(self):
    # While the interpreter is shutting down, `c_api` (or its attributes) may
    # already have been cleared, so only free the handle if the deleter is
    # still reachable.
    if c_api is None:
      return
    if c_api.TF_DeleteStatus is None:
      return
    c_api.TF_DeleteStatus(self.status)
42
43
class ScopedTFGraph(object):
  """Owns a TF_Graph handle and destroys it when this wrapper is collected."""

  __slots__ = ["graph", "deleter"]

  def __init__(self):
    self.graph = c_api.TF_NewGraph()
    # Bind the destructor now instead of looking it up in __del__: at process
    # exit other modules (including `c_api`) may already be torn down, and
    # holding a direct reference still lets the graph be freed cleanly, which
    # keeps leak checkers satisfied.
    self.deleter = c_api.TF_DeleteGraph

  def __del__(self):
    free_graph = self.deleter
    free_graph(self.graph)
59
60
class ScopedTFImportGraphDefOptions(object):
  """Owns a TF_ImportGraphDefOptions handle and frees it on collection."""

  __slots__ = ["options"]

  def __init__(self):
    self.options = c_api.TF_NewImportGraphDefOptions()

  def __del__(self):
    # During interpreter teardown `c_api` or its deleter may already be gone;
    # in that case skip the explicit free.
    if c_api is None:
      return
    if c_api.TF_DeleteImportGraphDefOptions is None:
      return
    c_api.TF_DeleteImportGraphDefOptions(self.options)
74
75
class ScopedTFImportGraphDefResults(object):
  """Wrapper around TF_ImportGraphDefResults that handles deletion.

  Takes ownership of an already-created results handle passed to the
  constructor and deletes it when this wrapper is garbage collected.
  """

  __slots__ = ["results"]

  def __init__(self, results):
    self.results = results

  def __del__(self):
    # Note: when we're destructing the global context (i.e. when the process
    # is terminating) other modules may already have been deleted, so guard
    # against `c_api` or its deleter being gone.
    if c_api is not None and c_api.TF_DeleteImportGraphDefResults is not None:
      c_api.TF_DeleteImportGraphDefResults(self.results)
89
90
class ScopedTFFunction(object):
  """Owns a TF_Function handle and releases it at most once."""

  __slots__ = ["func", "deleter"]

  def __init__(self, func):
    self.func = func
    # Capture the destructor eagerly: at process exit other modules may have
    # been torn down already, and keeping a direct reference lets the function
    # still be destroyed cleanly at shutdown (which satisfies leak checkers).
    self.deleter = c_api.TF_DeleteFunction

  def __del__(self):
    handle = self.func
    if handle is None:
      # Already released (or never set); nothing to do.
      return
    self.deleter(handle)
    # Clear the reference so a repeated __del__ call is a no-op.
    self.func = None
108
109
class ScopedTFBuffer(object):
  """An internal class to help manage the TF_Buffer lifetime.

  Creates a TF_Buffer from `buf_string` (converted with `compat.as_bytes`) and
  deletes it when this wrapper is garbage collected.
  """

  __slots__ = ["buffer"]

  def __init__(self, buf_string):
    self.buffer = c_api.TF_NewBufferFromString(compat.as_bytes(buf_string))

  def __del__(self):
    # If __init__ raised, the slot was never assigned and there is nothing to
    # free; reading it directly would raise AttributeError here.
    buf = getattr(self, "buffer", None)
    if buf is None:
      return
    # Note: when we're destructing the global context (i.e. when the process
    # is terminating) other modules may already have been deleted, so guard
    # against `c_api` or its deleter being gone.
    if c_api is not None and c_api.TF_DeleteBuffer is not None:
      c_api.TF_DeleteBuffer(buf)
120
121
class ApiDefMap(object):
  """Wrapper around TF_ApiDefMap that handles querying and deletion.

  The OpDef protos are also stored in this class so that they can be queried
  by op name.
  """

  __slots__ = ["_api_def_map", "_op_per_name"]

  def __init__(self):
    op_def_proto = op_def_pb2.OpList()
    buf = c_api.TF_GetAllOpList()
    try:
      op_def_proto.ParseFromString(c_api.TF_GetBuffer(buf))
      self._api_def_map = c_api.TF_NewApiDefMap(buf)
    finally:
      # The op-list buffer is only needed to build the map; always free it.
      c_api.TF_DeleteBuffer(buf)

    # Index OpDefs by name; later entries win on duplicate names.
    self._op_per_name = {op.name: op for op in op_def_proto.op}

  def __del__(self):
    # If __init__ raised before the map was created there is nothing to free
    # (and reading the unset slot would raise AttributeError).
    api_def_map = getattr(self, "_api_def_map", None)
    if api_def_map is None:
      return
    # Note: when we're destructing the global context (i.e. when the process
    # is terminating) other modules may already have been deleted.
    if c_api is not None and c_api.TF_DeleteApiDefMap is not None:
      c_api.TF_DeleteApiDefMap(api_def_map)

  def put_api_def(self, text):
    """Adds the ApiDef definitions in `text` to the map.

    Args:
      text: `bytes` or `str` in whatever format `TF_ApiDefMapPut` accepts.
    """
    # Encode first so the length handed to C is a byte count: len() of a str
    # counts code points, which under-reports non-ASCII UTF-8 input.
    text = compat.as_bytes(text)
    c_api.TF_ApiDefMapPut(self._api_def_map, text, len(text))

  def get_api_def(self, op_name):
    """Returns the `api_def_pb2.ApiDef` stored for `op_name`."""
    api_def_proto = api_def_pb2.ApiDef()
    buf = c_api.TF_ApiDefMapGet(self._api_def_map, op_name, len(op_name))
    try:
      api_def_proto.ParseFromString(c_api.TF_GetBuffer(buf))
    finally:
      c_api.TF_DeleteBuffer(buf)
    return api_def_proto

  def get_op_def(self, op_name):
    """Returns the OpDef proto for `op_name`.

    Raises:
      ValueError: if no op with that name was registered at construction time.
    """
    if op_name in self._op_per_name:
      return self._op_per_name[op_name]
    raise ValueError("No entry found for " + op_name + ".")

  def op_names(self):
    """Returns a view of all op names known to this map."""
    return self._op_per_name.keys()
169
170
@tf_contextlib.contextmanager
def tf_buffer(data=None):
  """Yields a TF_Buffer that is deleted when the `with` block exits.

  Example usage:
    with tf_buffer() as buf:
      # fill buf with a serialized graph def
      ...
      proto_data = c_api.TF_GetBuffer(buf)
      graph_def.ParseFromString(compat.as_bytes(proto_data))
    # buf has been deleted

    with tf_buffer(some_string) as buf:
      c_api.TF_SomeFunction(buf)
    # buf has been deleted

  Args:
    data: Optional `bytes` or `str` object. When truthy, it is converted with
      `compat.as_bytes` and used to create the buffer; when None or empty, an
      empty buffer is created instead.

  Yields:
    The newly created TF_Buffer.
  """
  if not data:
    buf = c_api.TF_NewBuffer()
  else:
    buf = c_api.TF_NewBufferFromString(compat.as_bytes(data))
  try:
    yield buf
  finally:
    # Runs even when the body of the `with` statement raises.
    c_api.TF_DeleteBuffer(buf)
202
203
def tf_output(c_op, index):
  """Builds a wrapped TF_Output pointing at output `index` of `c_op`.

  Args:
    c_op: wrapped TF_Operation
    index: integer output index

  Returns:
    Wrapped TF_Output
  """
  output = c_api.TF_Output()
  output.oper, output.index = c_op, index
  return output
218
219
def tf_operations(graph):
  """Generator over every TF_Operation in `graph`.

  Walks the underlying C graph by repeatedly calling TF_GraphNextOperation
  with a position cursor until it returns None.

  Args:
    graph: Graph

  Yields:
    wrapped TF_Operation
  """
  position = 0
  while True:
    c_op, position = c_api.TF_GraphNextOperation(
        graph._c_graph, position)  # pylint: disable=protected-access
    if c_op is None:
      return
    yield c_op
236
237
def new_tf_operations(graph):
  """Generator over TF_Operations in `graph` that have no Python Operation.

  An operation counts as "new" when `graph._get_operation_by_tf_operation`
  raises KeyError for it. This is useful for processing nodes added by the
  C API.

  Args:
    graph: Graph

  Yields:
    wrapped TF_Operation
  """
  # TODO(b/69679162): do this more efficiently
  for candidate in tf_operations(graph):
    try:
      graph._get_operation_by_tf_operation(candidate)  # pylint: disable=protected-access
    except KeyError:
      is_new = True
    else:
      is_new = False
    if is_new:
      yield candidate
256