• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Utilities for using the TensorFlow C API."""
17
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22from tensorflow.core.framework import api_def_pb2
23from tensorflow.core.framework import op_def_pb2
24from tensorflow.python.client import pywrap_tf_session as c_api
25from tensorflow.python.util import compat
26from tensorflow.python.util import tf_contextlib
27
28
class ScopedTFStatus(object):
  """Owns a TF_Status handle and frees it when this object is collected."""

  def __init__(self):
    # Allocate the underlying status object through the C API wrapper.
    self.status = c_api.TF_NewStatus()

  def __del__(self):
    # While the process is shutting down, other modules (including `c_api`)
    # may already have been torn down; in that case skip the release.
    if c_api is None or c_api.TF_DeleteStatus is None:
      return
    c_api.TF_DeleteStatus(self.status)
40
41
class ScopedTFGraph(object):
  """Owns a TF_Graph handle and frees it when this object is collected."""

  def __init__(self):
    self.graph = c_api.TF_NewGraph()
    # Bind the deleter now, while `c_api` is known to be usable. At process
    # exit the module may be gone before this object is collected; holding a
    # direct reference still lets the graph be destroyed cleanly then, which
    # keeps leak checkers satisfied.
    self.deleter = c_api.TF_DeleteGraph

  def __del__(self):
    release = self.deleter
    release(self.graph)
55
56
class ScopedTFImportGraphDefOptions(object):
  """Owns a TF_ImportGraphDefOptions handle and frees it on collection."""

  def __init__(self):
    # Create a fresh options object via the C API wrapper.
    self.options = c_api.TF_NewImportGraphDefOptions()

  def __del__(self):
    # At interpreter teardown `c_api` or its functions may already have been
    # cleared, so only call the deleter if it is still reachable.
    if c_api is None or c_api.TF_DeleteImportGraphDefOptions is None:
      return
    c_api.TF_DeleteImportGraphDefOptions(self.options)
68
69
class ScopedTFImportGraphDefResults(object):
  """Wrapper around TF_ImportGraphDefResults that handles deletion.

  Takes an already-created results handle and frees it with
  `TF_DeleteImportGraphDefResults` when this wrapper is garbage-collected.

  Args:
    results: the TF_ImportGraphDefResults handle to own.
  """

  def __init__(self, results):
    self.results = results

  def __del__(self):
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we can have already deleted other modules.
    if c_api is not None and c_api.TF_DeleteImportGraphDefResults is not None:
      c_api.TF_DeleteImportGraphDefResults(self.results)
81
82
class ScopedTFFunction(object):
  """Owns a TF_Function handle and releases it at most once on collection."""

  def __init__(self, func):
    self.func = func
    # Grab the deleter eagerly: at process shutdown other modules (including
    # `c_api`) may be gone by the time this object is collected. Keeping our
    # own reference lets the function still be destroyed cleanly then, which
    # satisfies leak checkers.
    self.deleter = c_api.TF_DeleteFunction

  def __del__(self):
    if self.func is None:
      return
    self.deleter(self.func)
    # Drop the handle so a repeated call does not release it again.
    self.func = None
98
99
class ApiDefMap(object):
  """Wrapper around TF_ApiDefMap that handles querying and deletion.

  The OpDef protos returned by `TF_GetAllOpList` are also stored in this class
  so that they can be queried by op name.
  """

  def __init__(self):
    # Assign first so that __del__ always finds the attribute, even if one of
    # the C API calls below raises part-way through construction.
    self._api_def_map = None
    op_def_proto = op_def_pb2.OpList()
    buf = c_api.TF_GetAllOpList()
    try:
      op_def_proto.ParseFromString(c_api.TF_GetBuffer(buf))
      self._api_def_map = c_api.TF_NewApiDefMap(buf)
    finally:
      # The op-list buffer is only needed during construction.
      c_api.TF_DeleteBuffer(buf)

    # Index OpDefs by name; as with sequential assignment, a later op with a
    # duplicate name replaces an earlier one.
    self._op_per_name = {op.name: op for op in op_def_proto.op}

  def __del__(self):
    # getattr guards against instances whose __init__ never ran or failed
    # before the handle was created; there is nothing to free in that case.
    api_def_map = getattr(self, "_api_def_map", None)
    if api_def_map is None:
      return
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we can have already deleted other modules.
    if c_api is not None and c_api.TF_DeleteApiDefMap is not None:
      c_api.TF_DeleteApiDefMap(api_def_map)
      # Clear the handle so a repeated __del__ call cannot free it twice.
      self._api_def_map = None

  def put_api_def(self, text):
    """Passes ApiDef `text` to `TF_ApiDefMapPut` for this map.

    Args:
      text: `str` or `bytes` holding the ApiDef definitions to add.
    """
    # The length accompanying `text` is a byte count for the C API. A `str`
    # is handed over UTF-8 encoded, so `len(text)` would under-count any
    # non-ASCII characters; measure the encoded form instead.
    c_api.TF_ApiDefMapPut(self._api_def_map, text, len(compat.as_bytes(text)))

  def get_api_def(self, op_name):
    """Looks up the ApiDef for `op_name` via `TF_ApiDefMapGet`.

    Args:
      op_name: `str` or `bytes` name of the op.

    Returns:
      An `api_def_pb2.ApiDef` parsed from the buffer returned by the C API.
    """
    api_def_proto = api_def_pb2.ApiDef()
    # Same byte-length reasoning as in put_api_def.
    buf = c_api.TF_ApiDefMapGet(self._api_def_map, op_name,
                                len(compat.as_bytes(op_name)))
    try:
      api_def_proto.ParseFromString(c_api.TF_GetBuffer(buf))
    finally:
      # Always release the buffer, even if parsing fails.
      c_api.TF_DeleteBuffer(buf)
    return api_def_proto

  def get_op_def(self, op_name):
    """Returns the stored OpDef proto for `op_name`.

    Raises:
      ValueError: if no op named `op_name` was collected at construction.
    """
    op_def = self._op_per_name.get(op_name)
    if op_def is None:
      raise ValueError("No entry found for " + op_name + ".")
    return op_def

  def op_names(self):
    """Returns the names of all ops whose OpDefs are stored in this map."""
    return self._op_per_name.keys()
145
146
@tf_contextlib.contextmanager
def tf_buffer(data=None):
  """Yields a TF_Buffer and deletes it when the `with` block is left.

  Example usage:
    with tf_buffer() as buf:
      # get serialized graph def into buf
      ...
      proto_data = c_api.TF_GetBuffer(buf)
      graph_def.ParseFromString(compat.as_bytes(proto_data))
    # buf has been deleted

    with tf_buffer(some_string) as buf:
      c_api.TF_SomeFunction(buf)
    # buf has been deleted

  Args:
    data: An optional `bytes`, `str`, or `unicode` object. When it is truthy
      (neither None nor empty), the buffer is built from its bytes via
      `TF_NewBufferFromString`; otherwise `TF_NewBuffer` is used.

  Yields:
    The newly created TF_Buffer.
  """
  buf = (c_api.TF_NewBufferFromString(compat.as_bytes(data)) if data
         else c_api.TF_NewBuffer())
  try:
    yield buf
  finally:
    # Free the buffer no matter how the `with` body exits.
    c_api.TF_DeleteBuffer(buf)
178
179
def tf_output(c_op, index):
  """Builds a wrapped TF_Output referring to output `index` of `c_op`.

  Args:
    c_op: wrapped TF_Operation whose output is referenced.
    index: integer position of that output.

  Returns:
    A wrapped TF_Output with its `oper` and `index` fields populated.
  """
  output = c_api.TF_Output()
  output.oper, output.index = c_op, index
  return output
194
195
def tf_operations(graph):
  """Iterates over every TF_Operation in `graph`.

  Walks the underlying C graph with `TF_GraphNextOperation`, which returns
  the next operation together with an updated cursor, until it reports None.

  Args:
    graph: Graph whose `_c_graph` is walked.

  Yields:
    wrapped TF_Operation
  """
  cursor = 0
  while True:
    c_op, cursor = c_api.TF_GraphNextOperation(
        graph._c_graph, cursor)  # pylint: disable=protected-access
    if c_op is None:
      return
    yield c_op
212
213
def new_tf_operations(graph):
  """Iterates over TF_Operations in `graph` that have no Python Operation.

  An op counts as "new" when `graph._get_operation_by_tf_operation` raises
  KeyError for it, e.g. nodes that were added through the C API directly.

  Args:
    graph: Graph

  Yields:
    wrapped TF_Operation
  """
  # TODO(b/69679162): do this more efficiently

  def _lacks_python_op(c_op):
    # True when `graph` has no Operation registered for this TF_Operation.
    try:
      graph._get_operation_by_tf_operation(c_op)  # pylint: disable=protected-access
    except KeyError:
      return True
    return False

  for c_op in tf_operations(graph):
    if _lacks_python_op(c_op):
      yield c_op
232