# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Utilities for using the TensorFlow C API."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.framework import api_def_pb2
from tensorflow.core.framework import op_def_pb2
from tensorflow.python.client import pywrap_tf_session as c_api
from tensorflow.python.util import compat
from tensorflow.python.util import tf_contextlib


class ScopedTFStatus(object):
  """Wrapper around TF_Status that handles deletion."""

  __slots__ = ["status"]

  def __init__(self):
    self.status = c_api.TF_NewStatus()

  def __del__(self):
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we can have already deleted other modules.
    if c_api is not None and c_api.TF_DeleteStatus is not None:
      c_api.TF_DeleteStatus(self.status)


class ScopedTFGraph(object):
  """Wrapper around TF_Graph that handles deletion."""

  __slots__ = ["graph", "deleter"]

  def __init__(self):
    self.graph = c_api.TF_NewGraph()
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we may have already deleted other modules. By capturing the
    # DeleteGraph function here, we retain the ability to cleanly destroy the
    # graph at shutdown, which satisfies leak checkers.
    self.deleter = c_api.TF_DeleteGraph

  def __del__(self):
    self.deleter(self.graph)


class ScopedTFImportGraphDefOptions(object):
  """Wrapper around TF_ImportGraphDefOptions that handles deletion."""

  __slots__ = ["options"]

  def __init__(self):
    self.options = c_api.TF_NewImportGraphDefOptions()

  def __del__(self):
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we can have already deleted other modules.
    if c_api is not None and c_api.TF_DeleteImportGraphDefOptions is not None:
      c_api.TF_DeleteImportGraphDefOptions(self.options)


class ScopedTFImportGraphDefResults(object):
  """Wrapper around TF_ImportGraphDefResults that handles deletion.

  Takes ownership of an already-created results handle and frees it with
  `TF_DeleteImportGraphDefResults` when this wrapper is destroyed.
  """

  __slots__ = ["results"]

  def __init__(self, results):
    self.results = results

  def __del__(self):
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we can have already deleted other modules.
    if c_api is not None and c_api.TF_DeleteImportGraphDefResults is not None:
      c_api.TF_DeleteImportGraphDefResults(self.results)


class ScopedTFFunction(object):
  """Wrapper around TF_Function that handles deletion."""

  __slots__ = ["func", "deleter"]

  def __init__(self, func):
    self.func = func
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we may have already deleted other modules. By capturing the
    # DeleteFunction function here, we retain the ability to cleanly destroy the
    # Function at shutdown, which satisfies leak checkers.
    self.deleter = c_api.TF_DeleteFunction

  @property
  def has_been_garbage_collected(self):
    """True once the wrapped TF_Function has been released (func is None)."""
    return self.func is None

  def __del__(self):
    if not self.has_been_garbage_collected:
      self.deleter(self.func)
      # Clear the handle so a second __del__ call cannot double-free it.
      self.func = None


class ScopedTFBuffer(object):
  """An internal class to help manage the TF_Buffer lifetime.

  Args:
    buf_string: `bytes` or `str` contents copied into a new TF_Buffer
      (converted with `compat.as_bytes`).
  """

  __slots__ = ["buffer"]

  def __init__(self, buf_string):
    self.buffer = c_api.TF_NewBufferFromString(compat.as_bytes(buf_string))

  def __del__(self):
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we can have already deleted other modules. Guard the call
    # the same way the sibling wrappers do so teardown doesn't raise.
    if c_api is not None and c_api.TF_DeleteBuffer is not None:
      c_api.TF_DeleteBuffer(self.buffer)


class ApiDefMap(object):
  """Wrapper around Tf_ApiDefMap that handles querying and deletion.

  The OpDef protos are also stored in this class so that they could
  be queried by op name.
  """

  __slots__ = ["_api_def_map", "_op_per_name"]

  def __init__(self):
    op_def_proto = op_def_pb2.OpList()
    buf = c_api.TF_GetAllOpList()
    try:
      op_def_proto.ParseFromString(c_api.TF_GetBuffer(buf))
      self._api_def_map = c_api.TF_NewApiDefMap(buf)
    finally:
      # The buffer is only needed while parsing / building the map.
      c_api.TF_DeleteBuffer(buf)

    # Index every registered OpDef by its name; later duplicates win, as with
    # the original sequential assignment.
    self._op_per_name = {op.name: op for op in op_def_proto.op}

  def __del__(self):
    # If __init__ raised before the map was created, the `_api_def_map` slot
    # is unset and there is nothing to free.
    api_def_map = getattr(self, "_api_def_map", None)
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we can have already deleted other modules.
    if (api_def_map is not None and c_api is not None and
        c_api.TF_DeleteApiDefMap is not None):
      c_api.TF_DeleteApiDefMap(api_def_map)

  def put_api_def(self, text):
    """Adds ApiDef entries described by `text` to the underlying map.

    Args:
      text: ApiDef text passed straight through to `TF_ApiDefMapPut`.
    """
    # NOTE(review): len(text) counts code points for `str` input; if text can
    # contain non-ASCII characters this may differ from the encoded byte
    # length the C API receives — confirm against callers.
    c_api.TF_ApiDefMapPut(self._api_def_map, text, len(text))

  def get_api_def(self, op_name):
    """Returns the `ApiDef` proto stored for `op_name`.

    Args:
      op_name: Name of the op to look up.

    Returns:
      An `api_def_pb2.ApiDef` parsed from the buffer returned by the C API.
    """
    api_def_proto = api_def_pb2.ApiDef()
    buf = c_api.TF_ApiDefMapGet(self._api_def_map, op_name, len(op_name))
    try:
      api_def_proto.ParseFromString(c_api.TF_GetBuffer(buf))
    finally:
      c_api.TF_DeleteBuffer(buf)
    return api_def_proto

  def get_op_def(self, op_name):
    """Returns the cached `OpDef` for `op_name`.

    Raises:
      ValueError: if no op with that name was registered.
    """
    if op_name in self._op_per_name:
      return self._op_per_name[op_name]
    raise ValueError("No entry found for " + op_name + ".")

  def op_names(self):
    """Returns a view of the names of all cached OpDefs."""
    return self._op_per_name.keys()


@tf_contextlib.contextmanager
def tf_buffer(data=None):
  """Context manager that creates and deletes TF_Buffer.

  Example usage:
    with tf_buffer() as buf:
      # get serialized graph def into buf
      ...
      proto_data = c_api.TF_GetBuffer(buf)
      graph_def.ParseFromString(compat.as_bytes(proto_data))
    # buf has been deleted

    with tf_buffer(some_string) as buf:
      c_api.TF_SomeFunction(buf)
    # buf has been deleted

  Args:
    data: An optional `bytes`, `str`, or `unicode` object. If non-empty, the
      yielded buffer will contain this data; otherwise an empty buffer is
      created.

  Yields:
    Created TF_Buffer
  """
  # Truthiness test: None and empty data both produce a fresh empty buffer.
  if data:
    buf = c_api.TF_NewBufferFromString(compat.as_bytes(data))
  else:
    buf = c_api.TF_NewBuffer()
  try:
    yield buf
  finally:
    c_api.TF_DeleteBuffer(buf)


def tf_output(c_op, index):
  """Returns a wrapped TF_Output with specified operation and index.

  Args:
    c_op: wrapped TF_Operation
    index: integer

  Returns:
    Wrapped TF_Output
  """
  ret = c_api.TF_Output()
  ret.oper = c_op
  ret.index = index
  return ret


def tf_operations(graph):
  """Generator that yields every TF_Operation in `graph`.

  Args:
    graph: Graph

  Yields:
    wrapped TF_Operation
  """
  # pylint: disable=protected-access
  pos = 0
  # TF_GraphNextOperation returns the next op plus the updated cursor; a None
  # op marks the end of iteration.
  c_op, pos = c_api.TF_GraphNextOperation(graph._c_graph, pos)
  while c_op is not None:
    yield c_op
    c_op, pos = c_api.TF_GraphNextOperation(graph._c_graph, pos)
  # pylint: enable=protected-access


def new_tf_operations(graph):
  """Generator that yields newly-added TF_Operations in `graph`.

  Specifically, yields TF_Operations that don't have associated Operations in
  `graph`. This is useful for processing nodes added by the C API.

  Args:
    graph: Graph

  Yields:
    wrapped TF_Operation
  """
  # TODO(b/69679162): do this more efficiently
  for c_op in tf_operations(graph):
    try:
      graph._get_operation_by_tf_operation(c_op)  # pylint: disable=protected-access
    except KeyError:
      # No Python-side Operation exists yet, so this op is "new".
      yield c_op