• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
15"""Functions called by the generated code to execute an eager-mode op."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import six
22
23from google.protobuf import text_format
24from tensorflow.core.framework import tensor_pb2
25from tensorflow.python import pywrap_tfe
26from tensorflow.python.eager import core
27from tensorflow.python.framework import dtypes
28from tensorflow.python.framework import ops
29from tensorflow.python.framework import tensor_shape
30from tensorflow.python.util import compat
31
32
def quick_execute(op_name, num_outputs, inputs, attrs, ctx, name=None):
  """Run a single TensorFlow operation eagerly and return its outputs.

  Args:
    op_name: Name of the TensorFlow operation (see REGISTER_OP in C++ code) to
      execute.
    num_outputs: The number of outputs of the operation to fetch. Passed in
      explicitly instead of being inferred, for performance reasons.
    inputs: A list of inputs to the operation. Each entry should be a Tensor, or
      a value which can be passed to the Tensor constructor to create one.
    attrs: A tuple with alternating string attr names and attr values for this
      operation.
    ctx: The value of context.context().
    name: Customized name for the operation. Here it is only used to annotate
      the error message when execution fails with a non-OK status.

  Returns:
    List of output Tensor objects. The list is empty if there are no outputs.

  Raises:
    The exception built by `core._status_to_exception` when execution reports
    a non-OK status; `core._SymbolicException` when a TypeError occurs and at
    least one input is a Keras symbolic tensor; otherwise the TypeError itself.
  """
  # pylint: disable=protected-access
  device = ctx.device_name
  try:
    ctx.ensure_initialized()
    outputs = pywrap_tfe.TFE_Py_Execute(
        ctx._handle, device, op_name, inputs, attrs, num_outputs)
  except core._NotOkStatusException as status_error:
    detail = status_error.message
    if name is not None:
      detail = detail + " name: " + name
    # `from None` hides the low-level status exception from the traceback.
    six.raise_from(
        core._status_to_exception(status_error.code, detail), None)
  except TypeError as type_error:
    symbolic_inputs = [
        item for item in inputs if ops._is_keras_symbolic_tensor(item)
    ]
    if not symbolic_inputs:
      raise type_error
    raise core._SymbolicException(
        "Inputs to eager execution function cannot be Keras symbolic "
        "tensors, but found {}".format(symbolic_inputs))
  # pylint: enable=protected-access
  return outputs
78
79
def execute_with_cancellation(op_name,
                              num_outputs,
                              inputs,
                              attrs,
                              ctx,
                              cancellation_manager,
                              name=None):
  """Run a single TensorFlow operation eagerly, with cancellation support.

  Args:
    op_name: Name of the TensorFlow operation (see REGISTER_OP in C++ code) to
      execute.
    num_outputs: The number of outputs of the operation to fetch. Passed in
      explicitly instead of being inferred, for performance reasons.
    inputs: A list of inputs to the operation. Each entry should be a Tensor, or
      a value which can be passed to the Tensor constructor to create one.
    attrs: A tuple with alternating string attr names and attr values for this
      operation.
    ctx: The value of context.context().
    cancellation_manager: a `CancellationManager` object that can be used to
      cancel the operation; its `_impl` attribute is what gets handed to the
      native call.
    name: Customized name for the operation. Here it is only used to annotate
      the error message when execution fails with a non-OK status.

  Returns:
    List of output Tensor objects. The list is empty if there are no outputs.

  Raises:
    The exception built by `core._status_to_exception` when execution reports
    a non-OK status; `core._SymbolicException` when a TypeError occurs and at
    least one input is a Keras symbolic tensor; otherwise the TypeError itself.
  """
  # pylint: disable=protected-access
  device = ctx.device_name
  try:
    ctx.ensure_initialized()
    outputs = pywrap_tfe.TFE_Py_ExecuteCancelable(
        ctx._handle, device, op_name, inputs, attrs,
        cancellation_manager._impl, num_outputs)
  except core._NotOkStatusException as status_error:
    detail = status_error.message
    if name is not None:
      detail = detail + " name: " + name
    # `from None` hides the low-level status exception from the traceback.
    six.raise_from(
        core._status_to_exception(status_error.code, detail), None)
  except TypeError as type_error:
    symbolic_inputs = [
        item for item in inputs if ops._is_keras_symbolic_tensor(item)
    ]
    if not symbolic_inputs:
      raise type_error
    raise core._SymbolicException(
        "Inputs to eager execution function cannot be Keras symbolic "
        "tensors, but found {}".format(symbolic_inputs))
  # pylint: enable=protected-access
  return outputs
134
135
def execute_with_callbacks(op_name, num_outputs, inputs, attrs, ctx, name=None):
  """Monkey-patch for `execute` that also fires execution callbacks.

  Runs the op through `quick_execute`, then calls every entry of
  `ctx.op_callbacks` with `(op_name, tuple(inputs), attrs, outputs, name)`.

  Args:
    op_name: Name of the TensorFlow operation to execute.
    num_outputs: The number of outputs of the operation to fetch.
    inputs: A list of inputs to the operation.
    attrs: A tuple with alternating string attr names and attr values.
    ctx: The value of context.context().
    name: Customized name for the operation.

  Returns:
    The outputs returned by `quick_execute`.
  """
  outputs = quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
  for notify in ctx.op_callbacks:
    # Each callback receives its own tuple snapshot of the inputs.
    notify(op_name, tuple(inputs), attrs, outputs, name)
  return outputs
143
144
# Default op-execution entry point. `execute_with_callbacks` describes itself
# as a monkey-patch for this name, for when execution callbacks are enabled.
execute = quick_execute
146
147
def must_record_gradient():
  """Report whether gradients must be recorded; this default always says no.

  Import backprop if you want gradients recorded.

  Returns:
    False.
  """
  return False
151
152
def record_gradient(unused_op_name, unused_inputs, unused_attrs,
                    unused_results):
  """Do nothing: this default implementation records no gradient.

  Import backprop if you want gradients recorded.

  Args:
    unused_op_name: Ignored.
    unused_inputs: Ignored.
    unused_attrs: Ignored.
    unused_results: Ignored.
  """
157
158
def make_float(v, arg_name):
  """Validate that `v` is a real number and return it as a `float`.

  Args:
    v: Candidate value; must be an instance of `compat.real_types`.
    arg_name: String, name of the argument, used in the error message.

  Returns:
    `float(v)`.

  Raises:
    TypeError: If `v` is not an instance of `compat.real_types`.
  """
  if isinstance(v, compat.real_types):
    return float(v)
  raise TypeError("Expected float for argument '%s' not %s." %
                  (arg_name, repr(v)))
164
165
def make_int(v, arg_name):
  """Convert `v` to an `int`, rejecting strings outright.

  Args:
    v: Candidate value. Strings are refused even if they look numeric; any
      other value is handed to `int()`.
    arg_name: String, name of the argument, used in the error message.

  Returns:
    `int(v)`.

  Raises:
    TypeError: If `v` is a string, or if `int(v)` raises ValueError or
      TypeError.
  """

  def _not_an_int():
    # Built lazily so `repr(v)` is only evaluated on the failure path.
    return TypeError("Expected int for argument '%s' not %s." %
                     (arg_name, repr(v)))

  if isinstance(v, six.string_types):
    raise _not_an_int()
  try:
    return int(v)
  except (ValueError, TypeError):
    raise _not_an_int()
175
176
def make_str(v, arg_name):
  """Validate that `v` is a bytes or text string and return it as bytes.

  Args:
    v: Candidate value; must be an instance of `compat.bytes_or_text_types`.
    arg_name: String, name of the argument, used in the error message.

  Returns:
    The result of `compat.as_bytes(v)`.

  Raises:
    TypeError: If `v` is not an instance of `compat.bytes_or_text_types`.
  """
  if isinstance(v, compat.bytes_or_text_types):
    # Convert unicode strings to bytes.
    return compat.as_bytes(v)
  raise TypeError("Expected string for argument '%s' not %s." %
                  (arg_name, repr(v)))
182
183
def make_bool(v, arg_name):
  """Validate that `v` is a `bool` and return it unchanged.

  Args:
    v: Candidate value; must be an actual `bool` (ints such as 0/1 are
      rejected).
    arg_name: String, name of the argument, used in the error message.

  Returns:
    `v` itself.

  Raises:
    TypeError: If `v` is not a `bool`.
  """
  if isinstance(v, bool):
    return v
  raise TypeError("Expected bool for argument '%s' not %s." %
                  (arg_name, repr(v)))
189
190
def make_type(v, arg_name):
  """Convert `v` to the datatype enum of its base dtype.

  Args:
    v: Any value accepted by `dtypes.as_dtype`.
    arg_name: String, name of the argument, used in the error message.

  Returns:
    The `as_datatype_enum` value of `dtypes.as_dtype(v).base_dtype`.

  Raises:
    TypeError: If `dtypes.as_dtype` raises TypeError for `v`.
  """
  try:
    base = dtypes.as_dtype(v).base_dtype
  except TypeError:
    raise TypeError("Expected DataType for argument '%s' not %s." %
                    (arg_name, repr(v)))
  return base.as_datatype_enum
199
200
def make_shape(v, arg_name):
  """Convert `v` into a list of dimensions.

  Args:
    v: A TensorShapeProto, a list of ints, or a tensor_shape.TensorShape.
    arg_name: String, for error messages.

  Returns:
    None if the rank is unknown, otherwise a list of ints (or Nones in the
    position where the dimension is unknown).

  Raises:
    TypeError: If `tensor_shape.as_shape` raises TypeError for `v`.
    ValueError: If `tensor_shape.as_shape` raises ValueError for `v`.
  """
  try:
    converted = tensor_shape.as_shape(v)
  except TypeError as e:
    raise TypeError("Error converting %s to a TensorShape: %s." % (arg_name, e))
  except ValueError as e:
    raise ValueError("Error converting %s to a TensorShape: %s." % (arg_name,
                                                                    e))
  return None if converted.ndims is None else converted.as_list()
221
222
def make_tensor(v, arg_name):
  """Ensure `v` is a TensorProto, parsing it from text format if needed.

  Args:
    v: Either a `tensor_pb2.TensorProto` (returned as-is) or a string holding
      a text-format TensorProto.
    arg_name: String, name of the argument, used in the error message.

  Returns:
    A `tensor_pb2.TensorProto`.

  Raises:
    TypeError: If `v` is neither a TensorProto nor a string.
  """
  if isinstance(v, tensor_pb2.TensorProto):
    return v
  if not isinstance(v, six.string_types):
    raise TypeError(
        "Don't know how to convert %s to a TensorProto for argument '%s'." %
        (repr(v), arg_name))
  proto = tensor_pb2.TensorProto()
  text_format.Merge(v, proto)
  return proto
234
235
def args_to_matching_eager(l, ctx, allowed_dtypes, default_dtype=None):
  """Convert sequence `l` to eager Tensors that all share one dtype.

  Args:
    l: Sequence of values and/or EagerTensors.
    ctx: Context forwarded to `ops.convert_to_tensor`.
    allowed_dtypes: Collection of dtypes acceptable for the inferred dtype.
      May be empty (some ops like ConcatV2 may not list allowed dtypes), in
      which case the trial conversion is skipped.
    default_dtype: Returned as the dtype when `l` is empty, and passed as
      `preferred_dtype` when a dtype has to be inferred.

  Returns:
    A `(dtype, tensors)` pair. When `l` is empty and `default_dtype` is given,
    this is `(default_dtype, [])`. When every element already is an
    EagerTensor, it is the first element's `_datatype_enum()` and `l` itself.
    Otherwise it is the chosen dtype's `as_datatype_enum` and a new list of
    converted tensors.

  Raises:
    core._SymbolicException: If any converted tensor is a Keras symbolic
      tensor.
  """
  if (not l) and (default_dtype is not None):
    return default_dtype, []  # List is empty; assume default dtype.

  eager_cls = ops.EagerTensor
  if all(isinstance(item, eager_cls) for item in l):
    # Nothing to convert; hand the inputs back untouched.
    return l[0]._datatype_enum(), l  # pylint: disable=protected-access

  # Is some input already a Tensor with a dtype? Take the first one found.
  dtype = next(
      (item.dtype for item in l if isinstance(item, eager_cls)), None)

  if dtype is not None:
    ret = [ops.convert_to_tensor(item, dtype, ctx=ctx) for item in l]
  else:
    # Infer a dtype based on the first value, and use that dtype for the
    # remaining values.
    ret = []
    for item in l:
      converted = None
      # First see if we can get a valid dtype with the default conversion
      # and see if it matches an allowed dtype.
      if dtype is None and allowed_dtypes:
        candidate = ops.convert_to_tensor(item, ctx=ctx)
        # A mismatch could be because we have an empty tensor and thus picked
        # the wrong type; in that case fall through and retry below.
        if candidate.dtype in allowed_dtypes:
          converted = candidate

      if converted is None:
        converted = ops.convert_to_tensor(
            item, dtype, preferred_dtype=default_dtype, ctx=ctx)

      ret.append(converted)
      if dtype is None:
        dtype = converted.dtype

  # TODO(slebedev): consider removing this as it leaks a Keras concept.
  # pylint: disable=protected-access
  symbolic = [
      tensor for tensor in ret if ops._is_keras_symbolic_tensor(tensor)
  ]
  if symbolic:
    raise core._SymbolicException(
        "Using symbolic output of a Keras layer during eager execution "
        "{}".format(symbolic))
  # pylint: enable=protected-access
  return dtype.as_datatype_enum, ret
292
293
def convert_to_mixed_eager_tensors(values, ctx):
  """Convert each value to a tensor independently, keeping per-value dtypes.

  Args:
    values: Iterable of values accepted by `ops.convert_to_tensor`.
    ctx: Context forwarded to `ops.convert_to_tensor`.

  Returns:
    A `(types, tensors)` pair: the `_datatype_enum()` of every converted
    tensor, and the list of converted tensors, in input order.
  """
  tensors = []
  for value in values:
    tensors.append(ops.convert_to_tensor(value, ctx=ctx))
  # pylint: disable=protected-access
  dtype_enums = [tensor._datatype_enum() for tensor in tensors]
  # pylint: enable=protected-access
  return dtype_enums, tensors
298
299
def args_to_mixed_eager_tensors(lists, ctx):
  """Converts a list of same-length lists of values to eager tensors.

  Position `i` is converted across all lists with a single dtype: the dtype of
  the first EagerTensor found at position `i` in any list, or, if there is
  none, the dtype obtained by converting `lists[0][i]` with default inference.

  Args:
    lists: A sequence of at least two equally long sequences of values.
    ctx: Context forwarded to `ops.convert_to_tensor`.

  Returns:
    A `(types, lists_ret)` pair. `types[i]` is the datatype enum used for
    position `i`; `lists_ret` holds one list of converted tensors per input
    list, in the same order as `lists`.

  Raises:
    AssertionError: If fewer than two lists are given.
    ValueError: If the lists do not all have the same length.
  """
  assert len(lists) > 1

  # Generate an error if len(lists[i]) is not the same for all i.
  expected_len = len(lists[0])
  for l in lists[1:]:
    if len(l) != expected_len:
      raise ValueError(
          "Expected list arguments to be the same length: %d != %d (%r vs. %r)."
          % (expected_len, len(l), lists[0], l))

  # One output list per input list, including the first one. Seeding this
  # from lists[1:] only would leave it one entry short and make the indexing
  # below fail with IndexError.
  lists_ret = [[] for _ in lists]

  # Convert the first element of each list first, then the second element, etc.
  types = []
  for i in range(expected_len):
    dtype = None
    # If any list has a Tensor, use that dtype
    for l in lists:
      if isinstance(l[i], ops.EagerTensor):
        dtype = l[i].dtype
        break
    if dtype is None:
      # Convert the first one and use its dtype for the remaining lists.
      lists_ret[0].append(ops.convert_to_tensor(lists[0][i], ctx=ctx))
      dtype = lists_ret[0][i].dtype
      start = 1
    else:
      # Convert everything to the found dtype.
      start = 0
    for j in range(start, len(lists)):
      lists_ret[j].append(
          ops.convert_to_tensor(lists[j][i], dtype=dtype, ctx=ctx))
    types.append(dtype.as_datatype_enum)
  return types, lists_ret
336