# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Utilities for using the TensorFlow C API."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.framework import api_def_pb2
from tensorflow.core.framework import op_def_pb2
from tensorflow.python.client import pywrap_tf_session as c_api
from tensorflow.python.util import compat
from tensorflow.python.util import tf_contextlib


class ScopedTFStatus(object):
  """Wrapper around TF_Status that handles deletion.

  Creates a status via `TF_NewStatus` on construction and releases it via
  `TF_DeleteStatus` when the wrapper is garbage collected.
  """

  __slots__ = ["status"]

  def __init__(self):
    self.status = c_api.TF_NewStatus()

  def __del__(self):
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we can have already deleted other modules.
    if c_api is not None and c_api.TF_DeleteStatus is not None:
      c_api.TF_DeleteStatus(self.status)


class ScopedTFGraph(object):
  """Wrapper around TF_Graph that handles deletion.

  Creates a graph via `TF_NewGraph` on construction and releases it through
  the deleter captured at construction time.
  """

  __slots__ = ["graph", "deleter"]

  def __init__(self):
    self.graph = c_api.TF_NewGraph()
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we may have already deleted other modules. By capturing the
    # DeleteGraph function here, we retain the ability to cleanly destroy the
    # graph at shutdown, which satisfies leak checkers.
    self.deleter = c_api.TF_DeleteGraph

  def __del__(self):
    self.deleter(self.graph)


class ScopedTFImportGraphDefOptions(object):
  """Wrapper around TF_ImportGraphDefOptions that handles deletion."""

  __slots__ = ["options"]

  def __init__(self):
    self.options = c_api.TF_NewImportGraphDefOptions()

  def __del__(self):
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we can have already deleted other modules.
    if c_api is not None and c_api.TF_DeleteImportGraphDefOptions is not None:
      c_api.TF_DeleteImportGraphDefOptions(self.options)


class ScopedTFImportGraphDefResults(object):
  """Wrapper around TF_ImportGraphDefResults that handles deletion.

  Unlike the other wrappers, the wrapped object is passed in by the caller
  rather than created here; this class only releases it.
  """

  __slots__ = ["results"]

  def __init__(self, results):
    """Takes over `results`, which is deleted when this wrapper is collected."""
    self.results = results

  def __del__(self):
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we can have already deleted other modules.
    if c_api is not None and c_api.TF_DeleteImportGraphDefResults is not None:
      c_api.TF_DeleteImportGraphDefResults(self.results)


class ScopedTFFunction(object):
  """Wrapper around TF_Function that handles deletion."""

  __slots__ = ["func", "deleter"]

  def __init__(self, func):
    """Takes over `func`, which is deleted when this wrapper is collected."""
    self.func = func
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we may have already deleted other modules. By capturing the
    # DeleteFunction function here, we retain the ability to cleanly destroy the
    # Function at shutdown, which satisfies leak checkers.
    self.deleter = c_api.TF_DeleteFunction

  def __del__(self):
    # `func` is reset to None after deletion so the deleter is never invoked
    # twice on the same handle.
    if self.func is not None:
      self.deleter(self.func)
      self.func = None


class ScopedTFBuffer(object):
  """An internal class to help manage the TF_Buffer lifetime.

  The buffer is created from `buf_string` via `TF_NewBufferFromString` and
  released via `TF_DeleteBuffer` when the wrapper is garbage collected.
  """

  __slots__ = ["buffer"]

  def __init__(self, buf_string):
    """Creates the buffer from `buf_string` (converted with `compat.as_bytes`)."""
    self.buffer = c_api.TF_NewBufferFromString(compat.as_bytes(buf_string))

  def __del__(self):
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we can have already deleted other modules. Guard the call
    # the same way the other wrappers in this file do, so teardown does not
    # raise from within __del__.
    if c_api is not None and c_api.TF_DeleteBuffer is not None:
      c_api.TF_DeleteBuffer(self.buffer)


class ApiDefMap(object):
  """Wrapper around TF_ApiDefMap that handles querying and deletion.

  The OpDef protos are also stored in this class so that they could
  be queried by op name.
  """

  __slots__ = ["_api_def_map", "_op_per_name"]

  def __init__(self):
    # Initialise the slot first: if anything below raises, __del__ still runs
    # and must find the attribute set rather than hit an unset slot.
    self._api_def_map = None

    op_def_proto = op_def_pb2.OpList()
    buf = c_api.TF_GetAllOpList()
    try:
      op_def_proto.ParseFromString(c_api.TF_GetBuffer(buf))
      self._api_def_map = c_api.TF_NewApiDefMap(buf)
    finally:
      # The op list buffer is only needed while parsing and building the map.
      c_api.TF_DeleteBuffer(buf)

    # Index the parsed OpDefs by op name for get_op_def() / op_names().
    self._op_per_name = {op.name: op for op in op_def_proto.op}

  def __del__(self):
    # Nothing to release if construction failed before the map was created.
    if self._api_def_map is None:
      return
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we can have already deleted other modules.
    if c_api is not None and c_api.TF_DeleteApiDefMap is not None:
      c_api.TF_DeleteApiDefMap(self._api_def_map)

  def put_api_def(self, text):
    """Passes `text` and its length to `TF_ApiDefMapPut` for this map."""
    c_api.TF_ApiDefMapPut(self._api_def_map, text, len(text))

  def get_api_def(self, op_name):
    """Returns the `ApiDef` proto for `op_name`.

    Args:
      op_name: Name of the op to look up.

    Returns:
      An `api_def_pb2.ApiDef` parsed from the buffer returned by
      `TF_ApiDefMapGet`.
    """
    api_def_proto = api_def_pb2.ApiDef()
    buf = c_api.TF_ApiDefMapGet(self._api_def_map, op_name, len(op_name))
    try:
      api_def_proto.ParseFromString(c_api.TF_GetBuffer(buf))
    finally:
      c_api.TF_DeleteBuffer(buf)
    return api_def_proto

  def get_op_def(self, op_name):
    """Returns the OpDef proto stored for `op_name`.

    Args:
      op_name: Name of the op to look up.

    Returns:
      The OpDef registered under `op_name`.

    Raises:
      ValueError: If no op named `op_name` is known.
    """
    if op_name not in self._op_per_name:
      raise ValueError("No entry found for " + op_name + ".")
    return self._op_per_name[op_name]

  def op_names(self):
    """Returns the names of all ops known to this map (the dict's keys)."""
    return self._op_per_name.keys()


@tf_contextlib.contextmanager
def tf_buffer(data=None):
  """Context manager that creates and deletes TF_Buffer.

  Example usage:
    with tf_buffer() as buf:
      # get serialized graph def into buf
      ...
      proto_data = c_api.TF_GetBuffer(buf)
      graph_def.ParseFromString(compat.as_bytes(proto_data))
    # buf has been deleted

    with tf_buffer(some_string) as buf:
      c_api.TF_SomeFunction(buf)
    # buf has been deleted

  Args:
    data: An optional `bytes`, `str`, or `unicode` object. If non-empty, the
      yielded buffer will contain this data. If None or empty, a new buffer
      created with `TF_NewBuffer` is yielded instead.

  Yields:
    Created TF_Buffer
  """
  if data:
    buf = c_api.TF_NewBufferFromString(compat.as_bytes(data))
  else:
    buf = c_api.TF_NewBuffer()
  try:
    yield buf
  finally:
    # Always release the buffer, even if the body of the `with` block raised.
    c_api.TF_DeleteBuffer(buf)


def tf_output(c_op, index):
  """Returns a wrapped TF_Output with specified operation and index.

  Args:
    c_op: wrapped TF_Operation
    index: integer

  Returns:
    Wrapped TF_Output
  """
  ret = c_api.TF_Output()
  ret.oper = c_op
  ret.index = index
  return ret


def tf_operations(graph):
  """Generator that yields every TF_Operation in `graph`.

  Args:
    graph: Graph

  Yields:
    wrapped TF_Operation
  """
  # pylint: disable=protected-access
  # TF_GraphNextOperation returns the next op together with the updated
  # iteration position; a None op marks the end of the graph.
  pos = 0
  c_op, pos = c_api.TF_GraphNextOperation(graph._c_graph, pos)
  while c_op is not None:
    yield c_op
    c_op, pos = c_api.TF_GraphNextOperation(graph._c_graph, pos)
  # pylint: enable=protected-access


def new_tf_operations(graph):
  """Generator that yields newly-added TF_Operations in `graph`.

  Specifically, yields TF_Operations that don't have associated Operations in
  `graph`. This is useful for processing nodes added by the C API.

  Args:
    graph: Graph

  Yields:
    wrapped TF_Operation
  """
  # TODO(b/69679162): do this more efficiently
  for c_op in tf_operations(graph):
    try:
      graph._get_operation_by_tf_operation(c_op)  # pylint: disable=protected-access
    except KeyError:
      # A KeyError from the lookup means `graph` has no Operation for this
      # TF_Operation yet, i.e. it is new.
      yield c_op