# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Operations that generate constants.

See the [constants guide](https://tensorflow.org/api_guides/python/constant_op).
"""

# Must be separate from array_ops to avoid a cyclic dependency.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import six

from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.python.eager import context
from tensorflow.python.eager import execute
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.util.tf_export import tf_export


def _eager_reshape(tensor, shape, ctx):
  """Runs the `Reshape` kernel directly on an eager tensor.

  Eager-only: `tensor` must already be an eager Tensor.

  Args:
    tensor: the eager Tensor to reshape.
    shape: the target shape. It is converted through
      `execute.args_to_matching_eager`, with int32 as the default dtype.
    ctx: the eager context to execute in.

  Returns:
    The single output of the `Reshape` op.
  """
  value_dtype_enum = tensor._datatype_enum()  # pylint: disable=protected-access
  shape_dtype_enum, converted_args = execute.args_to_matching_eager(
      [shape], ctx, dtypes.int32)
  (shape_tensor,) = converted_args
  reshape_attrs = ("T", value_dtype_enum, "Tshape", shape_dtype_enum)
  (reshaped,) = execute.execute(
      b"Reshape",
      1,
      inputs=[tensor, shape_tensor],
      attrs=reshape_attrs,
      ctx=ctx)
  return reshaped
def _eager_fill(dims, value, ctx):
  """Runs the `Fill` kernel directly; requires `value` is an eager Tensor.

  Args:
    dims: the output shape; converted to an eager int32 tensor.
    value: eager Tensor holding the fill value; its dtype is the output dtype.
    ctx: the eager context to execute in.

  Returns:
    The single output of the `Fill` op.
  """
  attr_t = value.dtype.as_datatype_enum
  dims = convert_to_eager_tensor(dims, ctx, dtypes.int32)
  inputs_flat = [dims, value]
  attrs = ("T", attr_t, "index_type", types_pb2.DT_INT32)
  result, = execute.execute(
      b"Fill", 1, inputs=inputs_flat, attrs=attrs, ctx=ctx)
  return result


def _eager_identity(tensor, ctx):
  """Runs the `Identity` kernel directly; requires `tensor` is eager.

  Args:
    tensor: the eager Tensor to pass through.
    ctx: the eager context to execute in (its current device is used).

  Returns:
    The single output of the `Identity` op.
  """
  attrs = ("T", tensor.dtype.as_datatype_enum)
  result, = execute.execute(
      b"Identity", 1, inputs=[tensor], attrs=attrs, ctx=ctx)
  return result


def convert_to_eager_tensor(value, ctx, dtype=None):
  """Converts the given `value` to an `EagerTensor`.

  Note that this function could return cached copies of created constants for
  performance reasons.

  Args:
    value: value to convert to EagerTensor.
    ctx: value of context.context().
    dtype: optional desired dtype of the converted EagerTensor.

  Returns:
    EagerTensor created from value.

  Raises:
    TypeError: if `value` is already an EagerTensor whose dtype is not
      `dtype`.
  """
  if isinstance(value, ops.EagerTensor):
    if dtype is not None and value.dtype != dtype:
      raise TypeError("Expected tensor with type %r not %r" % (
          dtype, value.dtype))
    return value
  if dtype is not None:
    try:
      dtype = dtype.as_datatype_enum
    except AttributeError:
      dtype = dtypes.as_dtype(dtype).as_datatype_enum
  device = ctx.device_name
  handle = ctx._handle  # pylint: disable=protected-access
  if not isinstance(value, (float,) + six.integer_types):
    return ops.EagerTensor(value, handle, device, dtype)
  if isinstance(value, float) and value != value:
    # NaN never compares equal to itself, so a cache entry keyed on it could
    # never be hit again; caching it would only grow the cache on every call.
    return ops.EagerTensor(value, handle, device, dtype)
  # Use a scalar cache. This will put each scalar of each type only once on
  # each device. Scalars don't use much device memory but copying scalars can
  # trigger memcpys which are slow.
  cache_key = device, value, dtype, type(value)
  if isinstance(value, float) and value == 0 and np.signbit(value):
    # 0.0 == -0.0 and both hash alike, so without an extra key component -0.0
    # would reuse the tensor cached for +0.0 (and vice versa).
    cache_key += ("-0.0",)
  scalar_cache = ctx.scalar_cache()
  tensor = scalar_cache.get(cache_key, None)
  if tensor is not None:
    return ops.EagerTensor(value, handle, device, dtype, tensor)
  t = ops.EagerTensor(value, handle, device, dtype)
  scalar_cache[cache_key] = t
  return t


@tf_export(v1=["constant"])
def constant_v1(
    value, dtype=None, shape=None, name="Const", verify_shape=False):
  """Creates a constant tensor.

  The resulting tensor is populated with values of type `dtype`, as
  specified by arguments `value` and (optionally) `shape` (see examples
  below).

  The argument `value` can be a constant value, or a list of values of type
  `dtype`. If `value` is a list, then the length of the list must be less
  than or equal to the number of elements implied by the `shape` argument (if
  specified). In the case where the list length is less than the number of
  elements specified by `shape`, the last element in the list will be used
  to fill the remaining entries.

  The argument `shape` is optional. If present, it specifies the dimensions of
  the resulting tensor. If not present, the shape of `value` is used.

  If the argument `dtype` is not specified, then the type is inferred from
  the type of `value`.

  For example:

  ```python
  # Constant 1-D Tensor populated with value list.
  tensor = tf.constant([1, 2, 3, 4, 5, 6, 7]) => [1 2 3 4 5 6 7]

  # Constant 2-D tensor populated with scalar value -1.
  tensor = tf.constant(-1.0, shape=[2, 3]) => [[-1. -1. -1.]
                                               [-1. -1. -1.]]
  ```

  `tf.constant` differs from `tf.fill` in a few ways:

  *   `tf.constant` supports arbitrary constants, not just uniform scalar
      Tensors like `tf.fill`.
  *   `tf.constant` creates a `Const` node in the computation graph with the
      exact value at graph construction time. On the other hand, `tf.fill`
      creates an Op in the graph that is expanded at runtime.
  *   Because `tf.constant` only embeds constant values in the graph, it does
      not support dynamic shapes based on other runtime Tensors, whereas
      `tf.fill` does.

  Args:
    value:          A constant value (or list) of output type `dtype`.

    dtype:          The type of the elements of the resulting tensor.

    shape:          Optional dimensions of resulting tensor.

    name:           Optional name for the tensor.

    verify_shape:   Boolean that enables verification of a shape of values.

  Returns:
    A Constant Tensor.

  Raises:
    TypeError: if shape is incorrectly specified or unsupported.
  """
  return _constant_impl(value, dtype, shape, name, verify_shape=verify_shape,
                        allow_broadcast=False)


@tf_export("constant", v1=[])
def constant(value, dtype=None, shape=None, name="Const"):
  """Creates a constant tensor.

  The resulting tensor is populated with values of type `dtype`, as
  specified by arguments `value` and (optionally) `shape` (see examples
  below).

  The argument `value` can be a constant value, or a list of values of type
  `dtype`. If `shape` is also given, `value` must either have exactly as many
  elements as `shape` implies, in which case it is reshaped to `shape`, or
  consist of a single element, which is broadcast to fill `shape`.

  The argument `shape` is optional. If present, it specifies the dimensions of
  the resulting tensor. If not present, the shape of `value` is used.

  If the argument `dtype` is not specified, then the type is inferred from
  the type of `value`.

  For example:

  ```python
  # Constant 1-D Tensor populated with value list.
  tensor = tf.constant([1, 2, 3, 4, 5, 6]) => [1 2 3 4 5 6]

  # Constant 2-D Tensor populated with value list, reshaped to `shape`.
  tensor = tf.constant([1, 2, 3, 4, 5, 6], shape=(2,3))
           => [[1 2 3], [4 5 6]]

  # Constant 2-D tensor populated with scalar value -1.
  tensor = tf.constant(-1.0, shape=[2, 3]) => [[-1. -1. -1.]
                                               [-1. -1. -1.]]
  ```

  `tf.constant` differs from `tf.fill` in a few ways:

  *   `tf.constant` supports arbitrary constants, not just uniform scalar
      Tensors like `tf.fill`.
  *   `tf.constant` creates a `Const` node in the computation graph with the
      exact value at graph construction time. On the other hand, `tf.fill`
      creates an Op in the graph that is expanded at runtime.
  *   Because `tf.constant` only embeds constant values in the graph, it does
      not support dynamic shapes based on other runtime Tensors, whereas
      `tf.fill` does.

  Args:
    value:          A constant value (or list) of output type `dtype`.

    dtype:          The type of the elements of the resulting tensor.

    shape:          Optional dimensions of resulting tensor.

    name:           Optional name for the tensor.

  Returns:
    A Constant Tensor.

  Raises:
    TypeError: if shape is incorrectly specified or unsupported.
  """
  return _constant_impl(value, dtype, shape, name, verify_shape=False,
                        allow_broadcast=True)


def _constant_impl(
    value, dtype, shape, name, verify_shape, allow_broadcast):
  """Implementation of constant.

  In eager mode, `value` is converted to an eager tensor. It is then returned
  as-is, reshaped, or (for a single element) filled to `shape`. In graph mode
  a `Const` op holding `tensor_util.make_tensor_proto(...)` is added to the
  default graph.

  Raises:
    TypeError: in eager mode, if `verify_shape` is set and the shapes differ,
      or if the value cannot be reshaped or filled to `shape`.
  """
  ctx = context.context()
  if ctx.executing_eagerly():
    t = convert_to_eager_tensor(value, ctx, dtype)
    if shape is None:
      return t
    shape = tensor_shape.as_shape(shape)
    if shape == t.shape:
      return t
    if verify_shape:
      raise TypeError("Expected Tensor's shape: %s, got %s." % (tuple(shape),
                                                                tuple(t.shape)))
    num_t = t.shape.num_elements()
    # TODO(josh11b): Implement shape -> eager tensor conversion.
    if num_t == shape.num_elements():
      return _eager_reshape(t, shape.as_list(), ctx)
    if num_t == 1:
      if t.dtype == dtypes.bool:
        # We don't have a Fill kernel for bool dtype on GPU. So we first run
        # Fill on CPU and then copy to GPU if needed.
        with ops.device("/device:CPU:0"):
          x = _eager_fill(shape.as_list(), t.cpu(), ctx)
        return _eager_identity(x, ctx)
      else:
        return _eager_fill(shape.as_list(), t, ctx)
    # `%s` (not `%d`) for the shape's element count: it is None when `shape`
    # is only partially known, and `%d` would fail while formatting.
    raise TypeError("Eager execution of tf.constant with unsupported shape "
                    "(value has %d elements, shape is %s with %s elements)." %
                    (num_t, shape, shape.num_elements()))
  g = ops.get_default_graph()
  tensor_value = attr_value_pb2.AttrValue()
  tensor_value.tensor.CopyFrom(
      tensor_util.make_tensor_proto(
          value, dtype=dtype, shape=shape, verify_shape=verify_shape,
          allow_broadcast=allow_broadcast))
  dtype_value = attr_value_pb2.AttrValue(type=tensor_value.tensor.dtype)
  const_tensor = g.create_op(
      "Const", [], [dtype_value.type],
      attrs={"value": tensor_value,
             "dtype": dtype_value},
      name=name).outputs[0]
  return const_tensor


def is_constant(tensor_or_op):
  """Returns True if `tensor_or_op` is, or is the output of, a `Const` op."""
  if isinstance(tensor_or_op, ops.Tensor):
    op = tensor_or_op.op
  else:
    op = tensor_or_op
  return op.type == "Const"


def _constant_tensor_conversion_function(v, dtype=None, name=None,
                                         as_ref=False):
  """Tensor conversion function that builds a `constant`; ignores `as_ref`."""
  _ = as_ref
  return constant(v, dtype=dtype, name=name)


ops.register_tensor_conversion_function(
    (list, tuple), _constant_tensor_conversion_function, 100)
ops.register_tensor_conversion_function(
    np.ndarray, _constant_tensor_conversion_function, 100)
ops.register_tensor_conversion_function(
    np.generic, _constant_tensor_conversion_function, 100)
ops.register_tensor_conversion_function(
    object, _constant_tensor_conversion_function, 200)


def _tensor_shape_tensor_conversion_function(s,
                                             dtype=None,
                                             name=None,
                                             as_ref=False):
  """Function to convert TensorShape to Tensor.

  Args:
    s: a fully defined `tensor_shape.TensorShape`.
    dtype: optional `dtypes.int32` or `dtypes.int64`. If None, int32 is used
      unless some dimension is >= 2**31, in which case int64 is used.
    name: optional name; defaults to "shape_as_tensor".
    as_ref: ignored.

  Returns:
    A 1-D constant tensor holding the dimensions of `s`.

  Raises:
    ValueError: if `s` is not fully defined, or `dtype` is int32 and a
      dimension is too large for it.
    TypeError: if `dtype` is neither int32 nor int64.
  """
  _ = as_ref
  if not s.is_fully_defined():
    raise ValueError(
        "Cannot convert a partially known TensorShape to a Tensor: %s" % s)
  s_list = s.as_list()
  int64_value = 0
  for dim in s_list:
    if dim >= 2**31:
      int64_value = dim
      break

  if dtype is not None:
    if dtype not in (dtypes.int32, dtypes.int64):
      raise TypeError("Cannot convert a TensorShape to dtype: %s" % dtype)
    if dtype == dtypes.int32 and int64_value:
      raise ValueError("Cannot convert a TensorShape to dtype int32; "
                       "a dimension is too large (%s)" % int64_value)
  else:
    dtype = dtypes.int64 if int64_value else dtypes.int32
  if name is None:
    name = "shape_as_tensor"
  return constant(s_list, dtype=dtype, name=name)


ops.register_tensor_conversion_function(
    tensor_shape.TensorShape, _tensor_shape_tensor_conversion_function, 100)


def _dimension_tensor_conversion_function(d,
                                          dtype=None,
                                          name=None,
                                          as_ref=False):
  """Function to convert Dimension to Tensor.

  Mirrors `_tensor_shape_tensor_conversion_function` for dtype selection.

  Args:
    d: a `tensor_shape.Dimension` whose value is known.
    dtype: optional `dtypes.int32` or `dtypes.int64`. If None, int32 is used
      unless the value is >= 2**31, in which case int64 is used.
    name: optional name; defaults to "shape_as_tensor".
    as_ref: ignored.

  Returns:
    A scalar constant tensor holding the value of `d`.

  Raises:
    ValueError: if `d` is unknown, or `dtype` is int32 and the value is too
      large for it.
    TypeError: if `dtype` is neither int32 nor int64.
  """
  _ = as_ref
  if d.value is None:
    raise ValueError("Cannot convert an unknown Dimension to a Tensor: %s" % d)
  too_large_for_int32 = d.value >= 2**31
  if dtype is not None:
    if dtype not in (dtypes.int32, dtypes.int64):
      raise TypeError("Cannot convert a Dimension to dtype: %s" % dtype)
    if dtype == dtypes.int32 and too_large_for_int32:
      raise ValueError("Cannot convert a Dimension to dtype int32; "
                       "the value is too large (%s)" % d.value)
  else:
    dtype = dtypes.int64 if too_large_for_int32 else dtypes.int32
  if name is None:
    name = "shape_as_tensor"
  return constant(d.value, dtype=dtype, name=name)


ops.register_tensor_conversion_function(
    tensor_shape.Dimension, _dimension_tensor_conversion_function, 100)