# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Utilities for using the TensorFlow C API."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.core.framework import api_def_pb2
from tensorflow.core.framework import op_def_pb2
from tensorflow.python import pywrap_tensorflow as c_api
from tensorflow.python.util import compat
from tensorflow.python.util import tf_contextlib


class ScopedTFStatus(object):
  """Wrapper around TF_Status that handles deletion."""

  def __init__(self):
    self.status = c_api.TF_NewStatus()

  def __del__(self):
    # Note: when we're destructing the global context (i.e. when the process is
    # terminating) we can have already deleted other modules.
    if c_api is not None and c_api.TF_DeleteStatus is not None:
      c_api.TF_DeleteStatus(self.status)


class ScopedTFGraph(object):
  """Wrapper around TF_Graph that handles deletion."""

  def __init__(self):
    self.graph = c_api.TF_NewGraph()

  def __del__(self):
    # Note: when we're destructing the global context (i.e. when the process is
    # terminating) we can have already deleted other modules.
    if c_api is not None and c_api.TF_DeleteGraph is not None:
      c_api.TF_DeleteGraph(self.graph)


class ScopedTFImportGraphDefOptions(object):
  """Wrapper around TF_ImportGraphDefOptions that handles deletion."""

  def __init__(self):
    self.options = c_api.TF_NewImportGraphDefOptions()

  def __del__(self):
    # Note: when we're destructing the global context (i.e. when the process is
    # terminating) we can have already deleted other modules.
    if c_api is not None and c_api.TF_DeleteImportGraphDefOptions is not None:
      c_api.TF_DeleteImportGraphDefOptions(self.options)


class ScopedTFImportGraphDefResults(object):
  """Wrapper around TF_ImportGraphDefResults that handles deletion."""

  def __init__(self, results):
    self.results = results

  def __del__(self):
    # Note: when we're destructing the global context (i.e. when the process is
    # terminating) we can have already deleted other modules.
    if c_api is not None and c_api.TF_DeleteImportGraphDefResults is not None:
      c_api.TF_DeleteImportGraphDefResults(self.results)


class ScopedTFFunction(object):
  """Wrapper around TF_Function that handles deletion."""

  def __init__(self, func):
    self.func = func

  def __del__(self):
    # Note: when we're destructing the global context (i.e. when the process is
    # terminating) we can have already deleted other modules.
    if c_api is not None and c_api.TF_DeleteFunction is not None:
      c_api.TF_DeleteFunction(self.func)


class ApiDefMap(object):
  """Wrapper around TF_ApiDefMap that handles querying and deletion.

  The OpDef protos are also stored in this class so that they could
  be queried by op name.
  """

  def __init__(self):
    # Initialize the handle first so that __del__ stays safe even when parsing
    # the op list or creating the map below raises.
    self._api_def_map = None
    self._op_per_name = {}

    op_def_proto = op_def_pb2.OpList()
    buf = c_api.TF_GetAllOpList()
    try:
      op_def_proto.ParseFromString(c_api.TF_GetBuffer(buf))
      self._api_def_map = c_api.TF_NewApiDefMap(buf)
    finally:
      # The buffer is only needed to build the proto and the map.
      c_api.TF_DeleteBuffer(buf)

    for op in op_def_proto.op:
      self._op_per_name[op.name] = op

  def __del__(self):
    # Nothing to release if construction failed before the map was created.
    api_def_map = getattr(self, "_api_def_map", None)
    if api_def_map is None:
      return
    # Note: when we're destructing the global context (i.e. when the process is
    # terminating) we can have already deleted other modules.
    if c_api is not None and c_api.TF_DeleteApiDefMap is not None:
      c_api.TF_DeleteApiDefMap(api_def_map)

  def put_api_def(self, text):
    """Passes `text` and its length to TF_ApiDefMapPut for this map.

    Args:
      text: The ApiDef text to add.
    """
    # NOTE(review): len(text) is the Python-level length; for a unicode string
    # with non-ASCII characters this may differ from the encoded byte length.
    # Confirm against the wrapper's string conversion.
    c_api.TF_ApiDefMapPut(self._api_def_map, text, len(text))

  def get_api_def(self, op_name):
    """Returns the ApiDef proto that the map holds for `op_name`.

    Args:
      op_name: Name of the op to look up.

    Returns:
      An `api_def_pb2.ApiDef` parsed from the buffer returned by
      TF_ApiDefMapGet.
    """
    api_def_proto = api_def_pb2.ApiDef()
    buf = c_api.TF_ApiDefMapGet(self._api_def_map, op_name, len(op_name))
    try:
      api_def_proto.ParseFromString(c_api.TF_GetBuffer(buf))
    finally:
      c_api.TF_DeleteBuffer(buf)
    return api_def_proto

  def get_op_def(self, op_name):
    """Returns the OpDef proto stored for `op_name`.

    Args:
      op_name: Name of the op to look up.

    Returns:
      The OpDef recorded for `op_name` when this object was constructed.

    Raises:
      ValueError: If no op with that name was recorded.
    """
    if op_name in self._op_per_name:
      return self._op_per_name[op_name]
    raise ValueError("No entry found for " + op_name + ".")

  def op_names(self):
    """Returns the names of all ops recorded at construction time."""
    return self._op_per_name.keys()


@tf_contextlib.contextmanager
def tf_buffer(data=None):
  """Context manager that creates and deletes TF_Buffer.

  Example usage:
    with tf_buffer() as buf:
      # get serialized graph def into buf
      ...
      proto_data = c_api.TF_GetBuffer(buf)
      graph_def.ParseFromString(compat.as_bytes(proto_data))
    # buf has been deleted

    with tf_buffer(some_string) as buf:
      c_api.TF_SomeFunction(buf)
    # buf has been deleted

  Args:
    data: An optional `bytes`, `str`, or `unicode` object. If not None, the
      yielded buffer will contain this data.

  Yields:
    Created TF_Buffer
  """
  # Falsy data (None or an empty string) takes the TF_NewBuffer() path.
  if data:
    buf = c_api.TF_NewBufferFromString(compat.as_bytes(data))
  else:
    buf = c_api.TF_NewBuffer()
  try:
    yield buf
  finally:
    # Always release the buffer, even if the body of the `with` block raised.
    c_api.TF_DeleteBuffer(buf)


def tf_output(c_op, index):
  """Returns a wrapped TF_Output with specified operation and index.

  Args:
    c_op: wrapped TF_Operation
    index: integer

  Returns:
    Wrapped TF_Output
  """
  ret = c_api.TF_Output()
  ret.oper = c_op
  ret.index = index
  return ret


def tf_operations(graph):
  """Generator that yields every TF_Operation in `graph`.

  Args:
    graph: Graph

  Yields:
    wrapped TF_Operation
  """
  # pylint: disable=protected-access
  pos = 0
  c_op, pos = c_api.TF_GraphNextOperation(graph._c_graph, pos)
  # TF_GraphNextOperation returns the next operation together with the updated
  # position; iteration stops once the returned operation is None.
  while c_op is not None:
    yield c_op
    c_op, pos = c_api.TF_GraphNextOperation(graph._c_graph, pos)
  # pylint: enable=protected-access


def new_tf_operations(graph):
  """Generator that yields newly-added TF_Operations in `graph`.

  Specifically, yields TF_Operations that don't have associated Operations in
  `graph`. This is useful for processing nodes added by the C API.

  Args:
    graph: Graph

  Yields:
    wrapped TF_Operation
  """
  # TODO(b/69679162): do this more efficiently
  for c_op in tf_operations(graph):
    try:
      graph._get_operation_by_tf_operation(c_op)  # pylint: disable=protected-access
    except KeyError:
      # A KeyError from the lookup marks an operation `graph` does not know yet.
      yield c_op