• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CONSTANTS_H_
17 #define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CONSTANTS_H_
18 
19 #include <type_traits>
20 
21 #include "tensorflow/compiler/xla/client/xla_builder.h"
22 #include "tensorflow/compiler/xla/primitive_util.h"
23 #include "tensorflow/compiler/xla/types.h"
24 #include "tensorflow/compiler/xla/xla_data.pb.h"
25 
26 namespace xla {
27 
28 // Returns scalar 'value' as a scalar of 'type'. Unlike ConstantR0, 'type' is
29 // determined at C++ run-time, rather than C++ compile-time.
30 // If 'value' is floating point but 'type' is not, or if 'value' is complex but
31 // 'type' is not, an error will be returned. This is to catch accidental
32 // truncation; in such cases, use an explicit cast.
33 template <typename T>
ConstantR0WithType(XlaBuilder * builder,PrimitiveType type,T value)34 XlaOp ConstantR0WithType(XlaBuilder* builder, PrimitiveType type, T value) {
35   if (std::is_floating_point<T>::value &&
36       !(primitive_util::IsFloatingPointType(type) ||
37         primitive_util::IsComplexType(type))) {
38     return builder->ReportError(InvalidArgument(
39         "Invalid cast from floating point type to %s in ConstantR0WithType.",
40         PrimitiveType_Name(type)));
41   }
42   if (std::is_same<T, complex64>::value &&
43       !primitive_util::IsComplexType(type)) {
44     return builder->ReportError(InvalidArgument(
45         "Invalid cast from complex type to %s in ConstantR0WithType.",
46         PrimitiveType_Name(type)));
47   }
48   switch (type) {
49     case F16:
50       return ConstantR0<half>(builder, static_cast<half>(value));
51     case BF16:
52       return ConstantR0<bfloat16>(builder, static_cast<bfloat16>(value));
53     case F32:
54       return ConstantR0<float>(builder, static_cast<float>(value));
55     case F64:
56       return ConstantR0<double>(builder, static_cast<double>(value));
57     case C64:
58       return ConstantR0<complex64>(builder, static_cast<complex64>(value));
59     case C128:
60       return ConstantR0<complex128>(builder, static_cast<complex128>(value));
61     case U8:
62       return ConstantR0<uint8>(builder, static_cast<uint8>(value));
63     case U32:
64       return ConstantR0<uint32>(builder, static_cast<uint32>(value));
65     case U64:
66       return ConstantR0<uint64>(builder, static_cast<uint64>(value));
67     case S8:
68       return ConstantR0<int8>(builder, static_cast<int8>(value));
69     case S32:
70       return ConstantR0<int32>(builder, static_cast<int32>(value));
71     case S64:
72       return ConstantR0<int64>(builder, static_cast<int64>(value));
73     default:
74       return builder->ReportError(
75           InvalidArgument("Invalid type for ConstantR0WithType (%s).",
76                           PrimitiveType_Name(type)));
77   }
78 }
79 
80 // Returns a scalar containing 'value' cast to the same run-time type as
81 // 'prototype'.
82 // If 'value' is floating point but 'prototype' is not, or if 'value' is complex
83 // 'prototype' is not, an error will be returned.
84 template <typename T>
ScalarLike(XlaOp prototype,T value)85 XlaOp ScalarLike(XlaOp prototype, T value) {
86   XlaBuilder* builder = prototype.builder();
87   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
88     TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(prototype));
89     return ConstantR0WithType(builder, shape.element_type(), value);
90   });
91 }
92 
93 // Returns an array or scalar containing copies of `value` cast to the same
94 // run-type type as `prototype` and broadcast to the same dimensions as
95 // `prototype`.
96 //
97 // If `prototype` is not a scalar or array, returns an error.
98 template <typename T>
FullLike(XlaOp prototype,T value)99 XlaOp FullLike(XlaOp prototype, T value) {
100   XlaBuilder* builder = prototype.builder();
101   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
102     TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(prototype));
103     if (ShapeUtil::IsScalar(shape) || shape.IsArray()) {
104       return Broadcast(ScalarLike(prototype, value), shape.dimensions());
105     } else {
106       return InvalidArgument(
107           "Prototype shape for BroadcastConstantLike must be a scalar or "
108           "array, but was %s",
109           shape.ToString());
110     }
111   });
112 }
113 
114 // Returns a scalar with value '0' of 'type'.
115 XlaOp Zero(XlaBuilder* builder, PrimitiveType type);
116 
117 // Returns a zero-filled tensor with shape `shape`.
118 XlaOp Zeros(XlaBuilder* builder, const Shape& shape);
119 
120 // Returns a zero-filled tensor with the same shape as `prototype`.
121 XlaOp ZerosLike(XlaOp prototype);
122 
123 // Returns a scalar with value '1' of 'type'.
124 XlaOp One(XlaBuilder* builder, PrimitiveType type);
125 
126 // Returns the machine epsilon for floating-point type `type`, i.e.,
127 // the difference between 1.0 and the next representable value.
128 XlaOp Epsilon(XlaBuilder* builder, PrimitiveType type);
129 
130 // Returns the minimum representable finite or infinite value for 'type'.
131 // Returns '-inf' for floating-point types.
132 XlaOp MinValue(XlaBuilder* builder, PrimitiveType type);
133 
134 // Returns the minimum representable finite value for 'type'. For a floating
135 // point type, this is equal to -MaxFiniteValue().
136 XlaOp MinFiniteValue(XlaBuilder* builder, PrimitiveType type);
137 
138 // Returns the minimum positive normal value for floating-point type `type`.
139 XlaOp MinPositiveNormalValue(XlaBuilder* builder, PrimitiveType type);
140 
141 // Returns the maximum representable finite or infinite value for 'type'.
142 // Returns 'inf' for floating-point types.
143 XlaOp MaxValue(XlaBuilder* builder, PrimitiveType type);
144 
145 // Returns the maximum representable finite value for 'type'.
146 XlaOp MaxFiniteValue(XlaBuilder* builder, PrimitiveType type);
147 
148 // Returns a nan for the given type.  Only valid for real-valued fp types.
149 XlaOp NanValue(XlaBuilder* builder, PrimitiveType type);
150 
151 }  // namespace xla
152 
153 #endif  // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CONSTANTS_H_
154