/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This file defines helper routines for XLA compilation.

#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/lib/util.h"

#include "absl/types/span.h"
#include "tensorflow/compiler/tf2xla/literal_util.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/type_util.h"
#include "tensorflow/compiler/tf2xla/xla_context.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/xla/client/lib/arithmetic.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/client/xla_computation.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

xla::XlaOp XlaHelpers::Zero(xla::XlaBuilder* b, DataType data_type) {
  xla::PrimitiveType type;
  TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
  return xla::ConstantLiteral(b, xla::LiteralUtil::Zero(type));
}

xla::XlaOp XlaHelpers::One(xla::XlaBuilder* b, DataType data_type) {
  xla::PrimitiveType type;
  TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
  return xla::ConstantLiteral(b, xla::LiteralUtil::One(type));
}

xla::XlaOp XlaHelpers::IntegerLiteral(xla::XlaBuilder* b, DataType data_type,
                                      int64 value) {
  xla::PrimitiveType type;
  TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
  return ::tensorflow::IntegerLiteral(b, type, value);
}

xla::XlaOp XlaHelpers::FloatLiteral(xla::XlaBuilder* b, DataType data_type,
                                    double value) {
  xla::PrimitiveType type;
  TF_CHECK_OK(DataTypeToPrimitiveType(data_type, &type));
  return ::tensorflow::FloatLiteral(b, type, value);
}

/* static */ Status XlaHelpers::ReshapeLiteral(
    const xla::Literal& input, absl::Span<const int64> dimensions,
    xla::Literal* output) {
  if (input.shape().IsTuple()) {
    return errors::InvalidArgument("ReshapeLiteral does not support tuples.");
  }
  xla::Shape shape =
      xla::ShapeUtil::MakeShape(input.shape().element_type(), dimensions);
  int64 elements_before = xla::ShapeUtil::ElementsIn(input.shape());
  int64 elements_after = xla::ShapeUtil::ElementsIn(shape);
  if (elements_before != elements_after) {
    return errors::InvalidArgument(
        "Shapes before and after ReshapeLiteral have different numbers of "
        "elements.");
  }

  *output = input.Clone();
  output->mutable_shape_do_not_use()->Swap(&shape);
  return Status::OK();
}
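// Illustrative usage sketch of ReshapeLiteral above (the `input` literal and
// `reshaped` variable are hypothetical): a [2, 3] literal may be reshaped to
// [3, 2] or [6] because the element count is unchanged; the data is cloned
// and only the shape metadata is replaced.
//   xla::Literal reshaped;
//   TF_RETURN_IF_ERROR(XlaHelpers::ReshapeLiteral(input, {6}, &reshaped));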

Status XlaHelpers::OneHot(xla::XlaBuilder* builder, int64 depth, int axis,
                          DataType index_type, const TensorShape& indices_shape,
                          const xla::XlaOp& indices, const xla::XlaOp& on_value,
                          const xla::XlaOp& off_value, xla::XlaOp* one_hot) {
  // Broadcast the linspace constant across the indices along the new axis,
  // and test equality at each position.
  std::vector<int64> broadcast_dims(indices_shape.dims());
  std::iota(broadcast_dims.begin(), broadcast_dims.begin() + axis, 0);
  std::iota(broadcast_dims.begin() + axis, broadcast_dims.end(), axis + 1);
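  // For example, with rank-2 indices and axis == 1, broadcast_dims becomes
  // {0, 2}: indices dimension 0 maps to output dimension 0 and indices
  // dimension 1 maps to output dimension 2, leaving output dimension 1 (the
  // new depth dimension) to be filled by the iota comparison below.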

  TensorShape output_shape = indices_shape;
  output_shape.InsertDim(axis, depth);
  xla::Shape iota_shape;
  TF_RETURN_IF_ERROR(
      TensorShapeToXLAShape(index_type, output_shape, &iota_shape));

  // Select the user-provided on_value or off_value at each output position.
  *one_hot = xla::Select(
      xla::Eq(indices, xla::Iota(builder, iota_shape, axis), broadcast_dims),
      xla::Broadcast(on_value, output_shape.dim_sizes()),
      xla::Broadcast(off_value, output_shape.dim_sizes()));
  return Status::OK();
}
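// Illustrative example of the OneHot expansion above: indices = [1, 3],
// depth = 4, axis = 1 produces a [2, 4] result
//   [[off_value, on_value,  off_value, off_value],
//    [off_value, off_value, off_value, on_value ]].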

DataType XlaHelpers::SumAccumulationType(const DataType& dtype) {
  // Upcast 16 bit sum reductions to 32 bit to reduce the precision loss from
  // repeated floating point additions.
  if (dtype == DT_BFLOAT16 || dtype == DT_HALF) {
    return DT_FLOAT;
  }
  return dtype;
}
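// Illustrative usage sketch of SumAccumulationType above (hypothetical caller
// code; `input` and `input_dtype` are assumed names): callers summing
// DT_BFLOAT16 or DT_HALF values typically convert up to the accumulation
// type, reduce, then convert back down:
//   DataType accum_type = XlaHelpers::SumAccumulationType(input_dtype);
//   xla::XlaOp acc = XlaHelpers::ConvertElementType(input, accum_type);
//   // ... perform the sum reduction on `acc` ...
//   xla::XlaOp result = XlaHelpers::ConvertElementType(acc, input_dtype);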

xla::XlaOp XlaHelpers::ConvertElementType(const xla::XlaOp& operand,
                                          const DataType new_element_type) {
  xla::PrimitiveType convert_to;
  TF_CHECK_OK(DataTypeToPrimitiveType(new_element_type, &convert_to));
  return xla::ConvertElementType(operand, convert_to);
}

}  // end namespace tensorflow