# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions called by the generated code to execute an eager-mode op."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from google.protobuf import text_format
from tensorflow.core.framework import tensor_pb2
from tensorflow.python import pywrap_tfe
from tensorflow.python.eager import core
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.util import compat


def quick_execute(op_name, num_outputs, inputs, attrs, ctx, name=None):
  """Execute a TensorFlow operation.

  Args:
    op_name: Name of the TensorFlow operation (see REGISTER_OP in C++ code) to
      execute.
    num_outputs: The number of outputs of the operation to fetch.
      (Explicitly provided instead of being inferred for performance
      reasons).
    inputs: A list of inputs to the operation. Each entry should be a Tensor, or
      a value which can be passed to the Tensor constructor to create one.
    attrs: A tuple with alternating string attr names and attr values for this
      operation.
    ctx: The value of context.context().
    name: Customized name for the operation.

  Returns:
    List of output Tensor objects. The list is empty if there are no outputs.

  Raises:
    An exception on error. A `core._NotOkStatusException` from the runtime is
    replaced by the exception built by `core._status_to_exception` (with
    " name: <name>" appended to the message when `name` is given). A
    `TypeError` becomes a `core._SymbolicException` if any input is a Keras
    symbolic tensor, and is re-raised unchanged otherwise.
  """
  device_name = ctx.device_name
  # pylint: disable=protected-access
  try:
    ctx.ensure_initialized()
    tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
                                        inputs, attrs, num_outputs)
  except core._NotOkStatusException as e:
    if name is not None:
      message = e.message + " name: " + name
    else:
      message = e.message
    # Passing `None` as the cause suppresses chaining onto the status error.
    six.raise_from(core._status_to_exception(e.code, message), None)
  except TypeError:
    # A TypeError here is commonly caused by Keras symbolic tensors leaking
    # into eager code; report that case with a targeted message.
    keras_symbolic_tensors = [
        x for x in inputs if ops._is_keras_symbolic_tensor(x)
    ]
    if keras_symbolic_tensors:
      raise core._SymbolicException(
          "Inputs to eager execution function cannot be Keras symbolic "
          "tensors, but found {}".format(keras_symbolic_tensors))
    # Bare `raise` keeps the original traceback intact.
    raise
  # pylint: enable=protected-access
  return tensors


def execute_with_cancellation(op_name,
                              num_outputs,
                              inputs,
                              attrs,
                              ctx,
                              cancellation_manager,
                              name=None):
  """Execute a TensorFlow operation.

  Same as `quick_execute`, but the op is run through
  `TFE_Py_ExecuteCancelable` so that it can be cancelled.

  Args:
    op_name: Name of the TensorFlow operation (see REGISTER_OP in C++ code) to
      execute.
    num_outputs: The number of outputs of the operation to fetch. (Explicitly
      provided instead of being inferred for performance reasons).
    inputs: A list of inputs to the operation. Each entry should be a Tensor, or
      a value which can be passed to the Tensor constructor to create one.
    attrs: A tuple with alternating string attr names and attr values for this
      operation.
    ctx: The value of context.context().
    cancellation_manager: a `CancellationManager` object that can be used to
      cancel the operation.
    name: Customized name for the operation.

  Returns:
    List of output Tensor objects. The list is empty if there are no outputs.

  Raises:
    An exception on error, translated exactly as in `quick_execute`.
  """
  device_name = ctx.device_name
  # pylint: disable=protected-access
  try:
    ctx.ensure_initialized()
    tensors = pywrap_tfe.TFE_Py_ExecuteCancelable(ctx._handle, device_name,
                                                  op_name, inputs, attrs,
                                                  cancellation_manager._impl,
                                                  num_outputs)
  except core._NotOkStatusException as e:
    if name is not None:
      message = e.message + " name: " + name
    else:
      message = e.message
    # Passing `None` as the cause suppresses chaining onto the status error.
    six.raise_from(core._status_to_exception(e.code, message), None)
  except TypeError:
    keras_symbolic_tensors = [
        x for x in inputs if ops._is_keras_symbolic_tensor(x)
    ]
    if keras_symbolic_tensors:
      raise core._SymbolicException(
          "Inputs to eager execution function cannot be Keras symbolic "
          "tensors, but found {}".format(keras_symbolic_tensors))
    # Bare `raise` keeps the original traceback intact.
    raise
  # pylint: enable=protected-access
  return tensors


def execute_with_callbacks(op_name, num_outputs, inputs, attrs, ctx, name=None):
  """Monkey-patch to execute to enable execution callbacks.

  Runs `quick_execute`, then calls every callback in `ctx.op_callbacks` as
  `callback(op_name, tuple(inputs), attrs, tensors, name)`. It returns the
  tensors produced by `quick_execute`; callback return values are ignored.
  """
  tensors = quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
  for callback in ctx.op_callbacks:
    callback(op_name, tuple(inputs), attrs, tensors, name)

  return tensors


# `execute` defaults to `quick_execute`; `execute_with_callbacks` above is
# described as a monkey-patch for this name.
execute = quick_execute


def must_record_gradient():
  """Import backprop if you want gradients recorded."""
  return False


def record_gradient(unused_op_name, unused_inputs, unused_attrs,
                    unused_results):
  """Import backprop if you want gradients recorded."""


def make_float(v, arg_name):
  """Returns `float(v)`, requiring `v` to be one of `compat.real_types`.

  Args:
    v: The attr value to convert.
    arg_name: Attr name, used only in the error message.

  Raises:
    TypeError: if `v` is not an instance of `compat.real_types`.
  """
  if not isinstance(v, compat.real_types):
    raise TypeError("Expected float for argument '%s' not %s." %
                    (arg_name, repr(v)))
  return float(v)


def make_int(v, arg_name):
  """Returns `int(v)`, rejecting strings even if they look numeric.

  Args:
    v: The attr value to convert.
    arg_name: Attr name, used only in the error message.

  Raises:
    TypeError: if `v` is a string, or if `int(v)` raises ValueError or
      TypeError.
  """
  if isinstance(v, six.string_types):
    raise TypeError("Expected int for argument '%s' not %s." %
                    (arg_name, repr(v)))
  try:
    return int(v)
  except (ValueError, TypeError):
    raise TypeError("Expected int for argument '%s' not %s." %
                    (arg_name, repr(v)))


def make_str(v, arg_name):
  """Returns `v` as bytes; `v` must be a bytes or text string.

  Args:
    v: The attr value to convert.
    arg_name: Attr name, used only in the error message.

  Raises:
    TypeError: if `v` is not one of `compat.bytes_or_text_types`.
  """
  if not isinstance(v, compat.bytes_or_text_types):
    raise TypeError("Expected string for argument '%s' not %s." %
                    (arg_name, repr(v)))
  return compat.as_bytes(v)  # Convert unicode strings to bytes.


def make_bool(v, arg_name):
  """Returns `v` unchanged if it is a `bool` (ints such as 0/1 are rejected).

  Args:
    v: The attr value to check.
    arg_name: Attr name, used only in the error message.

  Raises:
    TypeError: if `v` is not a `bool`.
  """
  if not isinstance(v, bool):
    raise TypeError("Expected bool for argument '%s' not %s." %
                    (arg_name, repr(v)))
  return v


def make_type(v, arg_name):
  """Returns the datatype enum of `dtypes.as_dtype(v).base_dtype`.

  Args:
    v: Anything accepted by `dtypes.as_dtype`.
    arg_name: Attr name, used only in the error message.

  Raises:
    TypeError: if `dtypes.as_dtype` raises TypeError for `v`.
  """
  try:
    v = dtypes.as_dtype(v).base_dtype
  except TypeError:
    raise TypeError("Expected DataType for argument '%s' not %s." %
                    (arg_name, repr(v)))
  return v.as_datatype_enum


def make_shape(v, arg_name):
  """Convert v into a list.

  Args:
    v: A TensorShapeProto, a list of ints, or a tensor_shape.TensorShape.
    arg_name: String, for error messages.

  Returns:
    None if the rank is unknown, otherwise a list of ints (or Nones in the
    position where the dimension is unknown).

  Raises:
    TypeError, ValueError: if `tensor_shape.as_shape` raises the same type for
      `v`; the original message is included.
  """
  try:
    shape = tensor_shape.as_shape(v)
  except TypeError as e:
    raise TypeError("Error converting %s to a TensorShape: %s." % (arg_name, e))
  except ValueError as e:
    raise ValueError("Error converting %s to a TensorShape: %s." % (arg_name,
                                                                     e))
  if shape.ndims is None:
    return None
  else:
    return shape.as_list()


def make_tensor(v, arg_name):
  """Ensure v is a TensorProto.

  Args:
    v: A `TensorProto` (returned as-is) or a string holding a TensorProto in
      protobuf text format (parsed with `text_format.Merge`).
    arg_name: Attr name, used only in the error message.

  Raises:
    TypeError: if `v` is neither a `TensorProto` nor a string.
  """
  if isinstance(v, tensor_pb2.TensorProto):
    return v
  elif isinstance(v, six.string_types):
    pb = tensor_pb2.TensorProto()
    text_format.Merge(v, pb)
    return pb
  raise TypeError(
      "Don't know how to convert %s to a TensorProto for argument '%s'." %
      (repr(v), arg_name))


def args_to_matching_eager(l, ctx, allowed_dtypes, default_dtype=None):
  """Convert sequence `l` to eager same-type Tensors.

  How the common dtype is chosen:
  - If every element of `l` is already an `EagerTensor`, `l` is returned
    unchanged, together with the first element's dtype.
  - Otherwise, if some element is an `EagerTensor`, its dtype is used.
  - Otherwise the dtype is inferred from the first element. A default
    conversion is tried first and kept only if its dtype is in
    `allowed_dtypes`. If that fails, the element is converted with
    `preferred_dtype=default_dtype`.

  All remaining elements are then converted to the chosen dtype.

  Args:
    l: A sequence of values (tensors or values convertible to tensors).
    ctx: Passed through to `ops.convert_to_tensor`.
    allowed_dtypes: Collection the inferred dtype is checked against; when it
      is empty the check is skipped.
    default_dtype: Returned when `l` is empty; otherwise used as the
      `preferred_dtype` hint while inferring a dtype.

  Returns:
    A `(dtype, tensors)` pair. `dtype` is a datatype enum, except in the
    empty-`l` shortcut, where `default_dtype` is returned exactly as passed.

  Raises:
    core._SymbolicException: if any converted value is a Keras symbolic
      tensor.
  """
  if (not l) and (default_dtype is not None):
    return default_dtype, []  # List is empty; assume default dtype.
  # NOTE(review): an empty `l` with `default_dtype=None` falls into the
  # for-else below and evaluates `l[0]`, which raises IndexError for a
  # list/tuple.
  EagerTensor = ops.EagerTensor  # pylint: disable=invalid-name
  for x in l:
    if not isinstance(x, EagerTensor):
      break
  else:  # note: intentional for-else
    # Fast path: everything is already an EagerTensor.
    return l[0]._datatype_enum(), l  # pylint: disable=protected-access

  # Is some input already a Tensor with a dtype?
  dtype = None
  for t in l:
    if isinstance(t, EagerTensor):
      dtype = t.dtype
      break

  if dtype is None:
    # Infer a dtype based on the first value, and use that dtype for the
    # remaining values.

    ret = []
    for t in l:
      tensor = None
      # First see if we can get a valid dtype with the default conversion
      # and see if it matches an allowed dtypes. Some ops like ConcatV2 may
      # not list allowed dtypes, in which case we should skip this.
      if dtype is None and allowed_dtypes:
        tensor = ops.convert_to_tensor(t, ctx=ctx)
        # If we did not match an allowed dtype, try again with the default
        # dtype. This could be because we have an empty tensor and thus we
        # picked the wrong type.
        if tensor.dtype not in allowed_dtypes:
          tensor = None

      if tensor is None:
        tensor = ops.convert_to_tensor(
            t, dtype, preferred_dtype=default_dtype, ctx=ctx)

      ret.append(tensor)
      if dtype is None:
        dtype = tensor.dtype
  else:
    ret = [ops.convert_to_tensor(t, dtype, ctx=ctx) for t in l]

  # TODO(slebedev): consider removing this as it leaks a Keras concept.
  # pylint: disable=protected-access
  keras_symbolic_tensors = [x for x in ret if
                            ops._is_keras_symbolic_tensor(x)]
  if keras_symbolic_tensors:
    raise core._SymbolicException(
        "Using symbolic output of a Keras layer during eager execution "
        "{}".format(keras_symbolic_tensors))
  # pylint: enable=protected-access
  return dtype.as_datatype_enum, ret


def convert_to_mixed_eager_tensors(values, ctx):
  """Converts each value to a tensor independently (no shared dtype).

  Args:
    values: Iterable of values accepted by `ops.convert_to_tensor`.
    ctx: Passed through to `ops.convert_to_tensor`.

  Returns:
    A `(types, tensors)` pair: the datatype enum of each converted tensor and
    the converted tensors, in input order.
  """
  v = [ops.convert_to_tensor(t, ctx=ctx) for t in values]
  types = [t._datatype_enum() for t in v]  # pylint: disable=protected-access
  return types, v


def args_to_mixed_eager_tensors(lists, ctx):
  """Converts a list of same-length lists of values to eager tensors.

  Position `i` of every list is converted to one shared dtype. That dtype is
  taken from the first `EagerTensor` found at position `i` across `lists`. If
  there is none, it is inferred by converting `lists[0][i]`. Different
  positions may end up with different dtypes.

  Args:
    lists: A sequence of at least two sequences, all of the same length.
    ctx: Passed through to `ops.convert_to_tensor`.

  Returns:
    A `(types, lists_ret)` pair:
    - `types[i]` is the datatype enum used for position `i`.
    - `lists_ret` holds one list of tensors per input list, in the same order
      as `lists`.

  Raises:
    ValueError: if the lists do not all have the same length.
  """
  assert len(lists) > 1

  # Generate an error if len(lists[i]) is not the same for all i.
  for l in lists[1:]:
    if len(l) != len(lists[0]):
      raise ValueError(
          "Expected list arguments to be the same length: %d != %d (%r vs. %r)."
          % (len(lists[0]), len(l), lists[0], l))

  # One output list per input list, including lists[0]. (Previously only
  # len(lists) - 1 were allocated, so indexing the last one raised IndexError.)
  lists_ret = [[] for _ in lists]

  # Convert the first element of each list first, then the second element, etc.
  types = []
  for column in zip(*lists):
    # If any list has a Tensor at this position, use that dtype.
    dtype = None
    for value in column:
      if isinstance(value, ops.EagerTensor):
        dtype = value.dtype
        break
    if dtype is None:
      # Convert the first one and use its dtype for the rest.
      first = ops.convert_to_tensor(column[0], ctx=ctx)
      lists_ret[0].append(first)
      dtype = first.dtype
      start = 1
    else:
      # Convert everything to the found dtype.
      start = 0
    for out, value in zip(lists_ret[start:], column[start:]):
      out.append(ops.convert_to_tensor(value, dtype=dtype, ctx=ctx))
    types.append(dtype.as_datatype_enum)
  return types, lists_ret