# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Utilities for using the TensorFlow C API."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.framework import api_def_pb2
from tensorflow.core.framework import op_def_pb2
from tensorflow.python.client import pywrap_tf_session as c_api
from tensorflow.python.util import compat
from tensorflow.python.util import tf_contextlib


class ScopedTFStatus(object):
  """Wrapper around TF_Status that handles deletion.

  Attributes:
    status: the TF_Status handle created by `c_api.TF_NewStatus()`; it is
      released with `c_api.TF_DeleteStatus` when this wrapper is collected.
  """

  def __init__(self):
    self.status = c_api.TF_NewStatus()

  def __del__(self):
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we can have already deleted other modules.
    if c_api is not None and c_api.TF_DeleteStatus is not None:
      c_api.TF_DeleteStatus(self.status)


class ScopedTFGraph(object):
  """Wrapper around TF_Graph that handles deletion.

  Attributes:
    graph: the TF_Graph handle created by `c_api.TF_NewGraph()`.
    deleter: `c_api.TF_DeleteGraph`, captured at construction time.
  """

  def __init__(self):
    self.graph = c_api.TF_NewGraph()
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we may have already deleted other modules. By capturing the
    # DeleteGraph function here, we retain the ability to cleanly destroy the
    # graph at shutdown, which satisfies leak checkers.
    self.deleter = c_api.TF_DeleteGraph

  def __del__(self):
    self.deleter(self.graph)


class ScopedTFImportGraphDefOptions(object):
  """Wrapper around TF_ImportGraphDefOptions that handles deletion.

  Attributes:
    options: the handle created by `c_api.TF_NewImportGraphDefOptions()`.
  """

  def __init__(self):
    self.options = c_api.TF_NewImportGraphDefOptions()

  def __del__(self):
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we can have already deleted other modules.
    if c_api is not None and c_api.TF_DeleteImportGraphDefOptions is not None:
      c_api.TF_DeleteImportGraphDefOptions(self.options)


class ScopedTFImportGraphDefResults(object):
  """Wrapper around TF_ImportGraphDefResults that handles deletion.

  Takes ownership of an existing results handle and releases it with
  `c_api.TF_DeleteImportGraphDefResults` when this wrapper is collected.

  Attributes:
    results: the TF_ImportGraphDefResults handle passed to the constructor.
  """

  def __init__(self, results):
    self.results = results

  def __del__(self):
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we can have already deleted other modules.
    if c_api is not None and c_api.TF_DeleteImportGraphDefResults is not None:
      c_api.TF_DeleteImportGraphDefResults(self.results)


class ScopedTFFunction(object):
  """Wrapper around TF_Function that handles deletion.

  Attributes:
    func: the TF_Function handle passed to the constructor, or None once it
      has been deleted.
    deleter: `c_api.TF_DeleteFunction`, captured at construction time.
  """

  def __init__(self, func):
    self.func = func
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we may have already deleted other modules. By capturing the
    # DeleteFunction function here, we retain the ability to cleanly destroy the
    # Function at shutdown, which satisfies leak checkers.
    self.deleter = c_api.TF_DeleteFunction

  def __del__(self):
    if self.func is not None:
      self.deleter(self.func)
      # Clear the handle so a second __del__ call cannot delete it twice.
      self.func = None


class ApiDefMap(object):
  """Wrapper around Tf_ApiDefMap that handles querying and deletion.

  The OpDef protos are also stored in this class so that they could
  be queried by op name.
  """

  def __init__(self):
    op_def_proto = op_def_pb2.OpList()
    buf = c_api.TF_GetAllOpList()
    try:
      op_def_proto.ParseFromString(c_api.TF_GetBuffer(buf))
      self._api_def_map = c_api.TF_NewApiDefMap(buf)
    finally:
      # The buffer is only needed to build the proto and the map; release it
      # whether or not either step succeeded.
      c_api.TF_DeleteBuffer(buf)

    # Index the registered OpDefs by op name for get_op_def / op_names.
    self._op_per_name = {op.name: op for op in op_def_proto.op}

  def __del__(self):
    # If __init__ raised before the map was created, there is nothing to free
    # (and the attribute does not exist), so bail out instead of raising
    # AttributeError from the finalizer.
    api_def_map = getattr(self, "_api_def_map", None)
    if api_def_map is None:
      return
    # Note: when we're destructing the global context (i.e when the process is
    # terminating) we can have already deleted other modules.
    if c_api is not None and c_api.TF_DeleteApiDefMap is not None:
      c_api.TF_DeleteApiDefMap(api_def_map)

  def put_api_def(self, text):
    """Adds ApiDef text to the underlying map via `c_api.TF_ApiDefMapPut`.

    Args:
      text: the ApiDef text to add.
    """
    # NOTE(review): for a `str` argument, len(text) counts characters, not
    # bytes; if the binding encodes `text` to UTF-8, non-ASCII input would be
    # passed with a too-short length — confirm against the pywrap signature.
    c_api.TF_ApiDefMapPut(self._api_def_map, text, len(text))

  def get_api_def(self, op_name):
    """Returns the ApiDef proto for `op_name`.

    Args:
      op_name: name of the op to look up.

    Returns:
      An `api_def_pb2.ApiDef` parsed from the buffer returned by
      `c_api.TF_ApiDefMapGet`.
    """
    api_def_proto = api_def_pb2.ApiDef()
    buf = c_api.TF_ApiDefMapGet(self._api_def_map, op_name, len(op_name))
    try:
      api_def_proto.ParseFromString(c_api.TF_GetBuffer(buf))
    finally:
      c_api.TF_DeleteBuffer(buf)
    return api_def_proto

  def get_op_def(self, op_name):
    """Returns the stored OpDef proto for `op_name`.

    Args:
      op_name: name of the op to look up.

    Returns:
      The OpDef proto registered under `op_name`.

    Raises:
      ValueError: if no op named `op_name` was found at construction time.
    """
    if op_name in self._op_per_name:
      return self._op_per_name[op_name]
    raise ValueError("No entry found for " + op_name + ".")

  def op_names(self):
    """Returns the keys of the name -> OpDef index (the known op names)."""
    return self._op_per_name.keys()


@tf_contextlib.contextmanager
def tf_buffer(data=None):
  """Context manager that creates and deletes TF_Buffer.

  Example usage:
    with tf_buffer() as buf:
      # get serialized graph def into buf
      ...
      proto_data = c_api.TF_GetBuffer(buf)
      graph_def.ParseFromString(compat.as_bytes(proto_data))
    # buf has been deleted

    with tf_buffer(some_string) as buf:
      c_api.TF_SomeFunction(buf)
    # buf has been deleted

  Args:
    data: An optional `bytes`, `str`, or `unicode` object. If truthy, the
      yielded buffer will contain this data (converted with
      `compat.as_bytes`). If None or empty, a fresh empty buffer is yielded.

  Yields:
    Created TF_Buffer
  """
  if data:
    buf = c_api.TF_NewBufferFromString(compat.as_bytes(data))
  else:
    buf = c_api.TF_NewBuffer()
  try:
    yield buf
  finally:
    # Always release the buffer, even if the body of the `with` raised.
    c_api.TF_DeleteBuffer(buf)


def tf_output(c_op, index):
  """Returns a wrapped TF_Output with specified operation and index.

  Args:
    c_op: wrapped TF_Operation
    index: integer

  Returns:
    Wrapped TF_Output
  """
  ret = c_api.TF_Output()
  ret.oper = c_op
  ret.index = index
  return ret


def tf_operations(graph):
  """Generator that yields every TF_Operation in `graph`.

  Args:
    graph: Graph

  Yields:
    wrapped TF_Operation
  """
  # pylint: disable=protected-access
  # TF_GraphNextOperation returns (operation, next_position); iteration ends
  # when it returns None for the operation.
  pos = 0
  c_op, pos = c_api.TF_GraphNextOperation(graph._c_graph, pos)
  while c_op is not None:
    yield c_op
    c_op, pos = c_api.TF_GraphNextOperation(graph._c_graph, pos)
  # pylint: enable=protected-access


def new_tf_operations(graph):
  """Generator that yields newly-added TF_Operations in `graph`.

  Specifically, yields TF_Operations that don't have associated Operations in
  `graph`. This is useful for processing nodes added by the C API.

  Args:
    graph: Graph

  Yields:
    wrapped TF_Operation
  """
  # TODO(b/69679162): do this more efficiently
  for c_op in tf_operations(graph):
    try:
      graph._get_operation_by_tf_operation(c_op)  # pylint: disable=protected-access
    except KeyError:
      # No Python Operation wraps this TF_Operation yet, so it is "new".
      yield c_op