# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Utilities for using the TensorFlow C API."""

import contextlib
from tensorflow.core.framework import api_def_pb2
from tensorflow.core.framework import op_def_pb2
from tensorflow.python.client import pywrap_tf_session as c_api
from tensorflow.python.util import compat
from tensorflow.python.util import tf_contextlib


class AlreadyGarbageCollectedError(Exception):
  """Raised by `UniquePtr.get` when the managed C-API object was released."""

  def __init__(self, name, obj_type):
    super(AlreadyGarbageCollectedError,
          self).__init__(f"{name} of type {obj_type} has already been garbage "
                         "collected and cannot be called.")


# FIXME(b/235488206): Convert all Scoped objects to the context manager
# to protect against deletion during use when the object is attached to
# an attribute.
class UniquePtr(object):
  """Wrapper around single-ownership C-API objects that handles deletion.

  `__init__` arguments:
    name: Human-readable name, used in `AlreadyGarbageCollectedError`.
    obj: The C-API object this wrapper owns.
    deleter: Callable invoked as `deleter(obj)` when this wrapper is destroyed.
  """

  __slots__ = ["_obj", "deleter", "name", "type_name"]

  def __init__(self, name, obj, deleter):
    # '_' prefix marks _obj private, but unclear if it is required also to
    # maintain a special CPython destruction order.
    self._obj = obj
    self.name = name
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we may have already deleted other modules. By capturing the
    # deleter function here, we retain the ability to cleanly destroy the
    # object at shutdown, which satisfies leak checkers.
    self.deleter = deleter
    self.type_name = str(type(obj))

  @contextlib.contextmanager
  def get(self):
    """Yields the managed C-API Object, guaranteeing aliveness.

    This is a context manager. Inside the context the C-API object is
    guaranteed to be alive.

    Raises:
      AlreadyGarbageCollectedError: if the object is already deleted.
    """
    # Thread-safety: self.__del__ never runs during the call of this function
    # because there is a reference to self from the argument list.
    if self._obj is None:
      raise AlreadyGarbageCollectedError(self.name, self.type_name)
    yield self._obj

  def __del__(self):
    # `_obj` is unset if construction failed before `__init__` assigned it
    # (e.g. a subclass's `c_api.TF_NewGraph()` call raised); __slots__ would
    # otherwise turn that into an AttributeError here.
    obj = getattr(self, "_obj", None)
    if obj is not None:
      self._obj = None
      self.deleter(obj)


class ScopedTFStatus(object):
  """Wrapper around TF_Status that handles deletion."""

  __slots__ = ["status"]

  def __init__(self):
    self.status = c_api.TF_NewStatus()

  def __del__(self):
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we can have already deleted other modules.
    if c_api is not None and c_api.TF_DeleteStatus is not None:
      # `status` is unset if TF_NewStatus() raised inside __init__.
      status = getattr(self, "status", None)
      if status is not None:
        c_api.TF_DeleteStatus(status)


class ScopedTFGraph(UniquePtr):
  """Wrapper around TF_Graph that handles deletion."""

  def __init__(self, name):
    super(ScopedTFGraph, self).__init__(
        name, obj=c_api.TF_NewGraph(), deleter=c_api.TF_DeleteGraph)


class ScopedTFImportGraphDefOptions(object):
  """Wrapper around TF_ImportGraphDefOptions that handles deletion."""

  __slots__ = ["options"]

  def __init__(self):
    self.options = c_api.TF_NewImportGraphDefOptions()

  def __del__(self):
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we can have already deleted other modules.
    if c_api is not None and c_api.TF_DeleteImportGraphDefOptions is not None:
      # `options` is unset if TF_NewImportGraphDefOptions() raised in __init__.
      options = getattr(self, "options", None)
      if options is not None:
        c_api.TF_DeleteImportGraphDefOptions(options)


class ScopedTFImportGraphDefResults(object):
  """Wrapper around TF_ImportGraphDefResults that handles deletion.

  Takes ownership of an already-created `results` object passed by the caller.
  """

  __slots__ = ["results"]

  def __init__(self, results):
    self.results = results

  def __del__(self):
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we can have already deleted other modules.
    if c_api is not None and c_api.TF_DeleteImportGraphDefResults is not None:
      results = getattr(self, "results", None)
      if results is not None:
        c_api.TF_DeleteImportGraphDefResults(results)


class ScopedTFFunction(UniquePtr):
  """Wrapper around TF_Function that handles deletion."""

  def __init__(self, func, name):
    super(ScopedTFFunction, self).__init__(
        name=name, obj=func, deleter=c_api.TF_DeleteFunction)


class ScopedTFBuffer(object):
  """An internal class to help manage the TF_Buffer lifetime."""

  __slots__ = ["buffer"]

  def __init__(self, buf_string):
    self.buffer = c_api.TF_NewBufferFromString(compat.as_bytes(buf_string))

  def __del__(self):
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we can have already deleted other modules. Guard the same
    # way as the other Scoped* wrappers so teardown does not raise here.
    if c_api is not None and c_api.TF_DeleteBuffer is not None:
      # `buffer` is unset if TF_NewBufferFromString() raised inside __init__.
      buf = getattr(self, "buffer", None)
      if buf is not None:
        c_api.TF_DeleteBuffer(buf)


class ApiDefMap(object):
  """Wrapper around Tf_ApiDefMap that handles querying and deletion.

  The OpDef protos are also stored in this class so that they could
  be queried by op name.
  """

  __slots__ = ["_api_def_map", "_op_per_name"]

  def __init__(self):
    op_def_proto = op_def_pb2.OpList()
    buf = c_api.TF_GetAllOpList()
    try:
      op_def_proto.ParseFromString(c_api.TF_GetBuffer(buf))
      self._api_def_map = c_api.TF_NewApiDefMap(buf)
    finally:
      c_api.TF_DeleteBuffer(buf)

    # If two ops share a name, the later one wins (plain dict semantics).
    self._op_per_name = {op.name: op for op in op_def_proto.op}

  def __del__(self):
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we can have already deleted other modules.
    if c_api is not None and c_api.TF_DeleteApiDefMap is not None:
      # `_api_def_map` is unset if __init__ raised before TF_NewApiDefMap()
      # returned (e.g. ParseFromString failed).
      api_def_map = getattr(self, "_api_def_map", None)
      if api_def_map is not None:
        c_api.TF_DeleteApiDefMap(api_def_map)

  def put_api_def(self, text):
    """Passes `text` to `TF_ApiDefMapPut` for this map.

    Args:
      text: `str` or `bytes`; presumably ApiDefs in text-proto form (see the
        TF C API docs for `TF_ApiDefMapPut`).
    """
    # The C API takes (pointer, byte length). Encode first so len() counts
    # UTF-8 bytes, not characters, for non-ASCII `str` input.
    text = compat.as_bytes(text)
    c_api.TF_ApiDefMapPut(self._api_def_map, text, len(text))

  def get_api_def(self, op_name):
    """Returns the `api_def_pb2.ApiDef` proto stored for `op_name`."""
    # Same (pointer, byte length) contract as in put_api_def.
    op_name = compat.as_bytes(op_name)
    api_def_proto = api_def_pb2.ApiDef()
    buf = c_api.TF_ApiDefMapGet(self._api_def_map, op_name, len(op_name))
    try:
      api_def_proto.ParseFromString(c_api.TF_GetBuffer(buf))
    finally:
      c_api.TF_DeleteBuffer(buf)
    return api_def_proto

  def get_op_def(self, op_name):
    """Returns the `OpDef` proto for `op_name`.

    Raises:
      ValueError: if `op_name` is not among the ops loaded in `__init__`.
    """
    try:
      return self._op_per_name[op_name]
    except KeyError:
      raise ValueError(f"No op_def found for op name {op_name}.") from None

  def op_names(self):
    """Returns a live view of every op name loaded in `__init__`."""
    return self._op_per_name.keys()


@tf_contextlib.contextmanager
def tf_buffer(data=None):
  """Context manager that creates and deletes TF_Buffer.

  Example usage:
    with tf_buffer() as buf:
      # get serialized graph def into buf
      ...
      proto_data = c_api.TF_GetBuffer(buf)
      graph_def.ParseFromString(compat.as_bytes(proto_data))
    # buf has been deleted

    with tf_buffer(some_string) as buf:
      c_api.TF_SomeFunction(buf)
    # buf has been deleted

  Args:
    data: An optional `bytes`, `str`, or `unicode` object. If truthy, the
      yielded buffer will contain this data; if None or empty, a new empty
      buffer is yielded.

  Yields:
    Created TF_Buffer
  """
  if data:
    buf = c_api.TF_NewBufferFromString(compat.as_bytes(data))
  else:
    buf = c_api.TF_NewBuffer()
  try:
    yield buf
  finally:
    c_api.TF_DeleteBuffer(buf)


def tf_output(c_op, index):
  """Returns a wrapped TF_Output with specified operation and index.

  Args:
    c_op: wrapped TF_Operation
    index: integer

  Returns:
    Wrapped TF_Output
  """
  ret = c_api.TF_Output()
  ret.oper = c_op
  ret.index = index
  return ret


def tf_operations(graph):
  """Generator that yields every TF_Operation in `graph`.

  Args:
    graph: Graph

  Yields:
    wrapped TF_Operation
  """
  # pylint: disable=protected-access
  pos = 0
  # Hold the graph alive (via UniquePtr.get) for the whole iteration.
  with graph._c_graph.get() as c_graph:
    # TF_GraphNextOperation returns (op or None, next cursor position).
    c_op, pos = c_api.TF_GraphNextOperation(c_graph, pos)
    while c_op is not None:
      yield c_op
      c_op, pos = c_api.TF_GraphNextOperation(c_graph, pos)
  # pylint: enable=protected-access


def new_tf_operations(graph):
  """Generator that yields newly-added TF_Operations in `graph`.

  Specifically, yields TF_Operations that don't have associated Operations in
  `graph`. This is useful for processing nodes added by the C API.

  Args:
    graph: Graph

  Yields:
    wrapped TF_Operation
  """
  # TODO(b/69679162): do this more efficiently
  for c_op in tf_operations(graph):
    try:
      graph._get_operation_by_tf_operation(c_op)  # pylint: disable=protected-access
    except KeyError:
      # No Python Operation wraps this op yet, so it is "new".
      yield c_op