1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CONSTANTS_H_
17 #define TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CONSTANTS_H_
18
19 #include <type_traits>
20
21 #include "tensorflow/compiler/xla/client/xla_builder.h"
22 #include "tensorflow/compiler/xla/primitive_util.h"
23 #include "tensorflow/compiler/xla/types.h"
24 #include "tensorflow/compiler/xla/xla_data.pb.h"
25
26 namespace xla {
27
28 // Returns scalar 'value' as a scalar of 'type'. Unlike ConstantR0, 'type' is
29 // determined at C++ run-time, rather than C++ compile-time.
30 // If 'value' is floating point but 'type' is not, or if 'value' is complex but
31 // 'type' is not, an error will be returned. This is to catch accidental
32 // truncation; in such cases, use an explicit cast.
33 template <typename T>
ConstantR0WithType(XlaBuilder * builder,PrimitiveType type,T value)34 XlaOp ConstantR0WithType(XlaBuilder* builder, PrimitiveType type, T value) {
35 if (std::is_floating_point<T>::value &&
36 !(primitive_util::IsFloatingPointType(type) ||
37 primitive_util::IsComplexType(type))) {
38 return builder->ReportError(InvalidArgument(
39 "Invalid cast from floating point type to %s in ConstantR0WithType.",
40 PrimitiveType_Name(type)));
41 }
42 if (std::is_same<T, complex64>::value &&
43 !primitive_util::IsComplexType(type)) {
44 return builder->ReportError(InvalidArgument(
45 "Invalid cast from complex type to %s in ConstantR0WithType.",
46 PrimitiveType_Name(type)));
47 }
48 switch (type) {
49 case PRED:
50 return ConstantR0<bool>(builder, static_cast<bool>(value));
51 case F16:
52 return ConstantR0<half>(builder, static_cast<half>(value));
53 case BF16:
54 return ConstantR0<bfloat16>(builder, static_cast<bfloat16>(value));
55 case F32:
56 return ConstantR0<float>(builder, static_cast<float>(value));
57 case F64:
58 return ConstantR0<double>(builder, static_cast<double>(value));
59 case C64:
60 return ConstantR0<complex64>(builder, static_cast<complex64>(value));
61 case C128:
62 return ConstantR0<complex128>(builder, static_cast<complex128>(value));
63 case U8:
64 return ConstantR0<uint8_t>(builder, static_cast<uint8_t>(value));
65 case U16:
66 return ConstantR0<uint16_t>(builder, static_cast<uint16_t>(value));
67 case U32:
68 return ConstantR0<uint32_t>(builder, static_cast<uint32_t>(value));
69 case U64:
70 return ConstantR0<uint64_t>(builder, static_cast<uint64_t>(value));
71 case S8:
72 return ConstantR0<int8_t>(builder, static_cast<int8_t>(value));
73 case S16:
74 return ConstantR0<int16_t>(builder, static_cast<int16_t>(value));
75 case S32:
76 return ConstantR0<int32_t>(builder, static_cast<int32_t>(value));
77 case S64:
78 return ConstantR0<int64_t>(builder, static_cast<int64_t>(value));
79 default:
80 return builder->ReportError(
81 InvalidArgument("Invalid type for ConstantR0WithType (%s).",
82 PrimitiveType_Name(type)));
83 }
84 }
85
86 // Returns a scalar containing 'value' cast to the same run-time type as
87 // 'prototype'.
88 // If 'value' is floating point but 'prototype' is not, or if 'value' is complex
89 // 'prototype' is not, an error will be returned.
90 template <typename T>
ScalarLike(XlaOp prototype,T value)91 XlaOp ScalarLike(XlaOp prototype, T value) {
92 XlaBuilder* builder = prototype.builder();
93 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
94 TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(prototype));
95 return ConstantR0WithType(builder, shape.element_type(), value);
96 });
97 }
98
99 // Returns an array or scalar containing copies of `value` cast to the same
100 // run-type type as `prototype` and broadcast to the same dimensions as
101 // `prototype`.
102 //
103 // If `prototype` is not a scalar or array, returns an error.
104 template <typename T>
FullLike(XlaOp prototype,T value)105 XlaOp FullLike(XlaOp prototype, T value) {
106 XlaBuilder* builder = prototype.builder();
107 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
108 TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(prototype));
109 if (ShapeUtil::IsScalar(shape) || shape.IsArray()) {
110 return Broadcast(ScalarLike(prototype, value), shape.dimensions());
111 } else {
112 return InvalidArgument(
113 "Prototype shape for BroadcastConstantLike must be a scalar or "
114 "array, but was %s",
115 shape.ToString());
116 }
117 });
118 }
119
// Returns a scalar with value `0` of type `type`.
XlaOp Zero(XlaBuilder* builder, PrimitiveType type);

// Returns a zero-filled tensor with shape `shape`.
XlaOp Zeros(XlaBuilder* builder, const Shape& shape);

// Returns a zero-filled tensor with the same shape as `prototype`.
XlaOp ZerosLike(XlaOp prototype);

// Returns a scalar with value `1` of type `type`.
XlaOp One(XlaBuilder* builder, PrimitiveType type);

// Returns the machine epsilon for floating-point type `type`, i.e. the
// difference between 1.0 and the next representable value.
XlaOp Epsilon(XlaBuilder* builder, PrimitiveType type);

// Returns the minimum representable value for `type`, which may be infinite.
// For floating-point types this is `-inf`; see MinFiniteValue for the
// smallest finite value.
XlaOp MinValue(XlaBuilder* builder, PrimitiveType type);

// Returns the minimum representable finite value for `type`. For a
// floating-point type, this is equal to -MaxFiniteValue().
XlaOp MinFiniteValue(XlaBuilder* builder, PrimitiveType type);

// Returns the minimum positive normal value for floating-point type `type`.
XlaOp MinPositiveNormalValue(XlaBuilder* builder, PrimitiveType type);

// Returns the maximum representable value for `type`, which may be infinite.
// For floating-point types this is `inf`; see MaxFiniteValue for the largest
// finite value.
XlaOp MaxValue(XlaBuilder* builder, PrimitiveType type);

// Returns the maximum representable finite value for `type`.
XlaOp MaxFiniteValue(XlaBuilder* builder, PrimitiveType type);

// Returns a NaN of type `type`. Only valid for real-valued floating-point
// types.
// NOTE(review): what happens for other types is defined in the implementation
// file, not visible here — confirm whether an error is reported.
XlaOp NanValue(XlaBuilder* builder, PrimitiveType type);
156
157 } // namespace xla
158
159 #endif // TENSORFLOW_COMPILER_XLA_CLIENT_LIB_CONSTANTS_H_
160