1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // XLA-specific sequence and range Ops.
17
#include <cmath>
#include <cstdint>
#include <limits>
#include <type_traits>

#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
29
30 namespace tensorflow {
31 namespace {
32
33 // The type-specific part of the implementation of Range.
34 template <typename T>
CreateRangeTensor(const xla::LiteralSlice & start_literal,const xla::LiteralSlice & limit_literal,const xla::LiteralSlice & delta_literal,xla::XlaBuilder * builder)35 xla::StatusOr<xla::XlaOp> CreateRangeTensor(
36 const xla::LiteralSlice& start_literal,
37 const xla::LiteralSlice& limit_literal,
38 const xla::LiteralSlice& delta_literal, xla::XlaBuilder* builder) {
39 T start = start_literal.Get<T>({});
40 T limit = limit_literal.Get<T>({});
41 T delta = delta_literal.Get<T>({});
42
43 if (delta == 0) {
44 return errors::InvalidArgument("Requires delta != 0: ", delta);
45 }
46 if (delta > 0) {
47 if (start > limit) {
48 return errors::InvalidArgument(
49 "Requires start <= limit when delta > 0: ", start, "/", limit);
50 }
51 } else {
52 if (start < limit) {
53 return errors::InvalidArgument(
54 "Requires start >= limit when delta < 0: ", start, "/", limit);
55 }
56 }
57 int64 size =
58 (std::is_integral<T>::value
59 ? ((std::abs(limit - start) + std::abs(delta) - 1) / std::abs(delta))
60 : std::ceil(std::abs((limit - start) / delta)));
61
62 return xla::ConstantR0(builder, start) +
63 xla::ConstantR0(builder, delta) *
64 xla::Iota(builder, xla::primitive_util::NativeToPrimitiveType<T>(),
65 size);
66 }
67
68 class RangeOp : public XlaOpKernel {
69 public:
RangeOp(OpKernelConstruction * ctx)70 explicit RangeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
71
Compile(XlaOpKernelContext * ctx)72 void Compile(XlaOpKernelContext* ctx) override {
73 const TensorShape start_in_shape = ctx->InputShape(0);
74 const TensorShape limit_in_shape = ctx->InputShape(1);
75 const TensorShape delta_in_shape = ctx->InputShape(2);
76 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(start_in_shape),
77 errors::InvalidArgument("start must be a scalar, not shape ",
78 start_in_shape.DebugString()));
79 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(limit_in_shape),
80 errors::InvalidArgument("limit must be a scalar, not shape ",
81 limit_in_shape.DebugString()));
82 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(delta_in_shape),
83 errors::InvalidArgument("delta must be a scalar, not shape ",
84 delta_in_shape.DebugString()));
85 xla::Literal start, limit, delta;
86 OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &start));
87 OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &limit));
88 OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &delta));
89
90 DataType type = input_type(0);
91 xla::StatusOr<xla::XlaOp> output;
92 switch (type) {
93 case DT_INT32:
94 output = CreateRangeTensor<int32>(start, limit, delta, ctx->builder());
95 break;
96 case DT_INT64:
97 output = CreateRangeTensor<int64>(start, limit, delta, ctx->builder());
98 break;
99 case DT_FLOAT:
100 output = CreateRangeTensor<float>(start, limit, delta, ctx->builder());
101 break;
102 case DT_DOUBLE:
103 output = CreateRangeTensor<double>(start, limit, delta, ctx->builder());
104 break;
105 default:
106 output = errors::InvalidArgument("Invalid type for Range ",
107 DataTypeString(type));
108 }
109 OP_REQUIRES_OK(ctx, output.status());
110 ctx->SetOutput(0, output.ValueOrDie());
111 }
112 };
113
114 REGISTER_XLA_OP(Name("Range")
115 .CompileTimeConstantInput("start")
116 .CompileTimeConstantInput("limit")
117 .CompileTimeConstantInput("delta"),
118 RangeOp);
119
120 class LinSpaceOp : public XlaOpKernel {
121 public:
LinSpaceOp(OpKernelConstruction * ctx)122 explicit LinSpaceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
123
Compile(XlaOpKernelContext * ctx)124 void Compile(XlaOpKernelContext* ctx) override {
125 const TensorShape start_in_shape = ctx->InputShape("start");
126 const TensorShape stop_in_shape = ctx->InputShape("stop");
127 const TensorShape num_in_shape = ctx->InputShape("num");
128 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(start_in_shape),
129 errors::InvalidArgument("start must be a scalar, not shape ",
130 start_in_shape.DebugString()));
131 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(stop_in_shape),
132 errors::InvalidArgument("stop must be a scalar, not shape ",
133 stop_in_shape.DebugString()));
134 OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(num_in_shape),
135 errors::InvalidArgument("num must be a scalar, not shape ",
136 num_in_shape.DebugString()));
137
138 DataType type = ctx->input_type(0);
139
140 int64 num;
141 OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar("num", &num));
142 OP_REQUIRES(ctx, num > 0,
143 errors::InvalidArgument("Requires num > 0: ", num));
144 Tensor out_constant(type, TensorShape({num}));
145
146 xla::Literal start_literal;
147 OP_REQUIRES_OK(ctx, ctx->ConstantInput("start", &start_literal));
148 xla::Literal stop_literal;
149 OP_REQUIRES_OK(ctx, ctx->ConstantInput("stop", &stop_literal));
150
151 switch (type) {
152 case DT_FLOAT: {
153 float start = start_literal.GetFirstElement<float>();
154 float stop = stop_literal.GetFirstElement<float>();
155 auto flat = out_constant.flat<float>();
156 if (num == 1) {
157 flat(0) = start;
158 } else {
159 const float step = (stop - start) / (num - 1);
160 for (int64 i = 0; i < num - 1; ++i) {
161 flat(i) = start + step * i;
162 }
163 // The last value in the sequence must be equal to stop.
164 flat(num - 1) = stop;
165 }
166 break;
167 }
168 case DT_DOUBLE: {
169 double start = start_literal.GetFirstElement<double>();
170 double stop = stop_literal.GetFirstElement<double>();
171 auto flat = out_constant.flat<double>();
172 if (num == 1) {
173 flat(0) = start;
174 } else {
175 const double step = (stop - start) / (num - 1);
176 for (int64 i = 0; i < num - 1; ++i) {
177 flat(i) = start + step * i;
178 }
179 // The last value in the sequence must be equal to stop.
180 flat(num - 1) = stop;
181 }
182 break;
183 }
184
185 default:
186 ctx->SetStatus(errors::InvalidArgument("Invalid argument type ",
187 DataTypeString(type)));
188 return;
189 }
190 ctx->SetConstantOutput(0, out_constant);
191 }
192 };
193
194 REGISTER_XLA_OP(Name("LinSpace")
195 .CompileTimeConstantInput("start")
196 .CompileTimeConstantInput("stop")
197 .CompileTimeConstantInput("num"),
198 LinSpaceOp);
199
200 } // namespace
201 } // namespace tensorflow
202