• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/client/lib/constants.h"
17 
18 #include "tensorflow/compiler/xla/literal_util.h"
19 #include "tensorflow/compiler/xla/util.h"
20 
21 namespace xla {
22 
Zero(XlaBuilder * builder,PrimitiveType type)23 XlaOp Zero(XlaBuilder* builder, PrimitiveType type) {
24   return ConstantLiteral(builder, LiteralUtil::Zero(type));
25 }
26 
Zeros(XlaBuilder * builder,const Shape & shape)27 XlaOp Zeros(XlaBuilder* builder, const Shape& shape) {
28   return Broadcast(Zero(builder, shape.element_type()), shape.dimensions());
29 }
30 
ZerosLike(XlaOp prototype)31 XlaOp ZerosLike(XlaOp prototype) {
32   XlaBuilder* builder = prototype.builder();
33   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
34     TF_ASSIGN_OR_RETURN(Shape shape, builder->GetShape(prototype));
35     return Zeros(builder, shape);
36   });
37 }
38 
One(XlaBuilder * builder,PrimitiveType type)39 XlaOp One(XlaBuilder* builder, PrimitiveType type) {
40   return ConstantLiteral(builder, LiteralUtil::One(type));
41 }
42 
Epsilon(XlaBuilder * builder,PrimitiveType type)43 XlaOp Epsilon(XlaBuilder* builder, PrimitiveType type) {
44   switch (type) {
45     case F16:
46       return ConstantR0<Eigen::half>(
47           builder,
48           static_cast<Eigen::half>(Eigen::NumTraits<Eigen::half>::epsilon()));
49     case BF16:
50       return ConstantR0<Eigen::bfloat16>(
51           builder, static_cast<Eigen::bfloat16>(
52                        Eigen::NumTraits<Eigen::bfloat16>::epsilon()));
53     case F32:
54       return ConstantR0<float>(builder, std::numeric_limits<float>::epsilon());
55     case F64:
56       return ConstantR0<double>(builder,
57                                 std::numeric_limits<double>::epsilon());
58     default:
59       return builder->ReportError(InvalidArgument(
60           "Invalid type for Epsilon (%s).", PrimitiveType_Name(type)));
61   }
62 }
63 
MinValue(XlaBuilder * builder,PrimitiveType type)64 XlaOp MinValue(XlaBuilder* builder, PrimitiveType type) {
65   return ConstantLiteral(builder, LiteralUtil::MinValue(type));
66 }
67 
MinFiniteValue(XlaBuilder * builder,PrimitiveType type)68 XlaOp MinFiniteValue(XlaBuilder* builder, PrimitiveType type) {
69   switch (type) {
70     case F16:
71       return ConstantR0<Eigen::half>(builder,
72                                      Eigen::NumTraits<Eigen::half>::lowest());
73     case BF16:
74       return ConstantR0<Eigen::bfloat16>(
75           builder, Eigen::NumTraits<Eigen::bfloat16>::lowest());
76     case F32:
77       return ConstantR0<float>(builder, -std::numeric_limits<float>::max());
78     case F64:
79       return ConstantR0<double>(builder, -std::numeric_limits<double>::max());
80     default:
81       return MinValue(builder, type);
82   }
83 }
84 
MinPositiveNormalValue(XlaBuilder * builder,PrimitiveType type)85 XlaOp MinPositiveNormalValue(XlaBuilder* builder, PrimitiveType type) {
86   switch (type) {
87     case F16:
88       return ConstantR0<Eigen::half>(builder,
89                                      std::numeric_limits<Eigen::half>::min());
90     case BF16:
91       return ConstantR0<Eigen::bfloat16>(
92           builder, std::numeric_limits<Eigen::bfloat16>::min());
93     case F32:
94       return ConstantR0<float>(builder, std::numeric_limits<float>::min());
95     case F64:
96       return ConstantR0<double>(builder, std::numeric_limits<double>::min());
97     default:
98       return builder->ReportError(
99           InvalidArgument("Invalid type for MinPositiveNormalValue (%s).",
100                           PrimitiveType_Name(type)));
101   }
102 }
103 
MaxValue(XlaBuilder * builder,PrimitiveType type)104 XlaOp MaxValue(XlaBuilder* builder, PrimitiveType type) {
105   return ConstantLiteral(builder, LiteralUtil::MaxValue(type));
106 }
107 
MaxFiniteValue(XlaBuilder * builder,PrimitiveType type)108 XlaOp MaxFiniteValue(XlaBuilder* builder, PrimitiveType type) {
109   switch (type) {
110     case F16:
111       return ConstantR0<Eigen::half>(builder,
112                                      Eigen::NumTraits<Eigen::half>::highest());
113     case BF16:
114       return ConstantR0<Eigen::bfloat16>(
115           builder, Eigen::NumTraits<Eigen::bfloat16>::highest());
116     case F32:
117       return ConstantR0<float>(builder, std::numeric_limits<float>::max());
118     case F64:
119       return ConstantR0<double>(builder, std::numeric_limits<double>::max());
120     default:
121       return MaxValue(builder, type);
122   }
123 }
124 
NanValue(XlaBuilder * builder,PrimitiveType type)125 XlaOp NanValue(XlaBuilder* builder, PrimitiveType type) {
126   return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> {
127     switch (type) {
128       case F16:
129         return ConstantR0<Eigen::half>(
130             builder, Eigen::NumTraits<Eigen::half>::quiet_NaN());
131       case BF16:
132         return ConstantR0<Eigen::bfloat16>(
133             builder, Eigen::NumTraits<Eigen::bfloat16>::quiet_NaN());
134       case F32:
135         return ConstantR0<float>(builder,
136                                  std::numeric_limits<float>::quiet_NaN());
137       case F64:
138         return ConstantR0<double>(builder,
139                                   std::numeric_limits<double>::quiet_NaN());
140       default:
141         return InvalidArgument(
142             "Operand to NanValue was %s, but must be a real-valued "
143             "floating-point type.",
144             PrimitiveType_Name(type));
145     }
146   });
147 }
148 
149 }  // namespace xla
150