# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions called by the generated code to execute an eager-mode op."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from google.protobuf import text_format
from tensorflow.core.framework import tensor_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.eager import core
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.util import compat


def quick_execute(op_name, num_outputs, inputs, attrs, ctx, name=None):
  """Execute a TensorFlow operation.

  Args:
    op_name: Name of the TensorFlow operation (see REGISTER_OP in C++ code) to
      execute.
    num_outputs: The number of outputs of the operation to fetch.
                 (Explicitly provided instead of being inferred for performance
                 reasons).
    inputs: A list of inputs to the operation. Each entry should be a Tensor, or
      a value which can be passed to the Tensor constructor to create one.
    attrs: A tuple with alternating string attr names and attr values for this
      operation.
    ctx: The value of context.context().
    name: Customized name for the operation.

  Returns:
    List of output Tensor objects. The list is empty if there are no outputs

  Raises:
    The exception built by `core._status_to_exception` when execution reports
    a non-OK status (with " name: <name>" appended to the message if `name`
    is given); `core._SymbolicException` when a TypeError is raised and some
    input is a Keras symbolic tensor; otherwise the original TypeError.
  """
  device_name = ctx.device_name
  # pylint: disable=protected-access
  try:
    tensors = pywrap_tensorflow.TFE_Py_Execute(ctx._handle, device_name,
                                               op_name, inputs, attrs,
                                               num_outputs)
  except core._NotOkStatusException as e:
    if name is not None:
      message = e.message + " name: " + name
    else:
      message = e.message
    # `from None`-style chaining: hide the low-level status exception.
    six.raise_from(core._status_to_exception(e.code, message), None)
  except TypeError as e:
    if any(ops._is_keras_symbolic_tensor(x) for x in inputs):
      raise core._SymbolicException
    # Bare `raise` re-raises the active exception with its original
    # traceback (`raise e` drops it on Python 2).
    raise
  # pylint: enable=protected-access
  return tensors


def execute_with_callbacks(op_name, num_outputs, inputs, attrs, ctx, name=None):
  """Monkey-patch to execute to enable execution callbacks."""
  tensors = quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
  # Each callback sees the op name, its inputs/attrs and the produced tensors.
  for callback in ctx.post_execution_callbacks:
    callback(op_name, inputs, attrs, tensors, name)

  return tensors


# Default entry point. Presumably rebound to execute_with_callbacks elsewhere
# when post-execution callbacks are needed (see its docstring) -- confirm.
execute = quick_execute


def record_gradient(unused_op_name, unused_inputs, unused_attrs, unused_results,
                    unused_name):
  """Import backprop if you want gradients recorded."""
  # Intentional no-op placeholder.
  pass


def make_float(v, arg_name):
  """Returns `float(v)`; raises TypeError unless v is a `compat.real_types`."""
  if not isinstance(v, compat.real_types):
    raise TypeError("Expected float for argument '%s' not %s." %
                    (arg_name, repr(v)))
  return float(v)


def make_int(v, arg_name):
  """Returns `int(v)`; raises TypeError for strings or unconvertible values."""
  # Reject strings explicitly: int("3") would otherwise succeed.
  if isinstance(v, six.string_types):
    raise TypeError("Expected int for argument '%s' not %s." %
                    (arg_name, repr(v)))
  try:
    return int(v)
  except (ValueError, TypeError):
    raise TypeError("Expected int for argument '%s' not %s." %
                    (arg_name, repr(v)))


def make_str(v, arg_name):
  """Returns v as bytes; raises TypeError unless v is bytes or text."""
  if not isinstance(v, compat.bytes_or_text_types):
    raise TypeError("Expected string for argument '%s' not %s." %
                    (arg_name, repr(v)))
  return compat.as_bytes(v)  # Convert unicode strings to bytes.


def make_bool(v, arg_name):
  """Returns v unchanged; raises TypeError unless it is exactly a bool."""
  if not isinstance(v, bool):
    raise TypeError("Expected bool for argument '%s' not %s." %
                    (arg_name, repr(v)))
  return v


def make_type(v, arg_name):
  """Returns the DataType enum of v's base dtype (reference dtypes stripped).

  Raises:
    TypeError: if `dtypes.as_dtype` raises TypeError for v.
  """
  try:
    v = dtypes.as_dtype(v).base_dtype
  except TypeError:
    raise TypeError("Expected DataType for argument '%s' not %s." %
                    (arg_name, repr(v)))
  i = v.as_datatype_enum
  return i


def make_shape(v, arg_name):
  """Convert v into a list.

  Args:
    v: A TensorShapeProto, a list of ints, or a tensor_shape.TensorShape.
    arg_name: String, for error messages.

  Returns:
    None if the rank is unknown, otherwise a list of ints (or Nones in the
    position where the dimension is unknown).

  Raises:
    TypeError, ValueError: if `tensor_shape.as_shape` rejects v; the error is
      re-raised with the same type and `arg_name` added to the message.
  """
  try:
    shape = tensor_shape.as_shape(v)
  except TypeError as e:
    raise TypeError("Error converting %s to a TensorShape: %s." % (arg_name, e))
  except ValueError as e:
    raise ValueError("Error converting %s to a TensorShape: %s." % (arg_name,
                                                                    e))
  if shape.ndims is None:
    return None
  else:
    return shape.as_list()


def make_tensor(v, arg_name):
  """Ensure v is a TensorProto.

  A TensorProto is returned as-is; a string is parsed as a text-format
  TensorProto. Anything else raises TypeError.
  """
  if isinstance(v, tensor_pb2.TensorProto):
    return v
  elif isinstance(v, six.string_types):
    pb = tensor_pb2.TensorProto()
    text_format.Merge(v, pb)
    return pb
  raise TypeError(
      "Don't know how to convert %s to a TensorProto for argument '%s'." %
      (repr(v), arg_name))


def args_to_matching_eager(l, ctx, default_dtype=None):
  """Convert sequence `l` to eager same-type Tensors.

  Args:
    l: Sequence of values (EagerTensors or convertible values).
    ctx: The eager context, forwarded to tensor conversion.
    default_dtype: Preferred dtype when no element of `l` is already an
      EagerTensor; also used as the result dtype when `l` is empty.

  Returns:
    A `(datatype_enum, tensors)` pair.
  """
  if not l and default_dtype is not None:
    # Empty list: there is nothing to infer from, so use the default dtype
    # instead of failing on `l[0]` below.
    return dtypes.as_dtype(default_dtype).as_datatype_enum, []

  EagerTensor = ops.EagerTensor  # pylint: disable=invalid-name
  for x in l:
    if not isinstance(x, EagerTensor):
      break
  else:  # note: intentional for-else
    # Fast path: everything is already an EagerTensor. Dtypes are not
    # cross-checked here; l[0]'s dtype is reported.
    return l[0]._datatype_enum(), l  # pylint: disable=protected-access
  # TODO(josh11b): Could we do a better job if we also passed in the
  # allowed dtypes when that was known?

  # Is some input already a Tensor with a dtype?
  dtype = None
  for t in l:
    if isinstance(t, EagerTensor):
      dtype = t.dtype
      break

  internal_convert_to_tensor = ops.internal_convert_to_tensor
  if dtype is None:
    # Infer a dtype based on the first value, and use that dtype for the
    # remaining values.
    ret = []
    for t in l:
      ret.append(internal_convert_to_tensor(
          t, dtype,
          preferred_dtype=default_dtype,
          ctx=ctx,
          accept_symbolic_tensors=False))
      if dtype is None:
        dtype = ret[-1].dtype
  else:
    ret = [internal_convert_to_tensor(t, dtype, ctx=ctx) for t in l]

  return dtype.as_datatype_enum, ret


def convert_to_mixed_eager_tensors(values, ctx):
  """Converts each value independently (no shared dtype) to a tensor.

  Returns:
    A `(datatype_enums, tensors)` pair, one enum per converted tensor.
  """
  v = [ops.internal_convert_to_tensor(t, ctx=ctx) for t in values]
  types = [t._datatype_enum() for t in v]  # pylint: disable=protected-access
  return types, v


def args_to_mixed_eager_tensors(lists, ctx):
  """Converts a list of same-length lists of values to eager tensors.

  Position `i` of every list is converted to a single shared dtype, but
  different positions may have different dtypes.

  Args:
    lists: Two or more sequences of equal length.
    ctx: The eager context, forwarded to tensor conversion.

  Returns:
    A `(types, lists_ret)` pair: `types[i]` is the datatype enum used for
    position `i`, and `lists_ret[j]` holds the converted tensors of
    `lists[j]`.

  Raises:
    ValueError: if the lists do not all have the same length.
  """
  assert len(lists) > 1

  # Generate an error if len(lists[i]) is not the same for all i.
  # One output list per input list: seed with the slot for lists[0], then
  # add one per remaining list (previously lists[0]'s slot was missing,
  # causing an IndexError for the last list).
  lists_ret = [[]]
  for l in lists[1:]:
    if len(l) != len(lists[0]):
      raise ValueError(
          "Expected list arguments to be the same length: %d != %d (%r vs. %r)."
          % (len(lists[0]), len(l), lists[0], l))
    lists_ret.append([])

  # Convert the first element of each list first, then the second element, etc.
  types = []
  for i in range(len(lists[0])):
    dtype = None
    # If any list has a Tensor, use that dtype
    for l in lists:
      if isinstance(l[i], ops.EagerTensor):
        dtype = l[i].dtype
        break
    if dtype is None:
      # Convert the first one and use its dtype.
      lists_ret[0].append(ops.internal_convert_to_tensor(lists[0][i], ctx=ctx))
      dtype = lists_ret[0][i].dtype
      for j in range(1, len(lists)):
        lists_ret[j].append(
            ops.internal_convert_to_tensor(lists[j][i], dtype=dtype, ctx=ctx))
    else:
      # Convert everything to the found dtype.
      for j in range(len(lists)):
        lists_ret[j].append(
            ops.internal_convert_to_tensor(lists[j][i], dtype=dtype, ctx=ctx))
    types.append(dtype.as_datatype_enum)
  return types, lists_ret