• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // XLA-specific sequence and range Ops.
17 
#include <cmath>
#include <limits>
#include <type_traits>

#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/literal.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
29 
30 namespace tensorflow {
31 namespace {
32 
33 // The type-specific part of the implementation of Range.
34 template <typename T>
CreateRangeTensor(const xla::LiteralSlice & start_literal,const xla::LiteralSlice & limit_literal,const xla::LiteralSlice & delta_literal,xla::XlaBuilder * builder)35 xla::StatusOr<xla::XlaOp> CreateRangeTensor(
36     const xla::LiteralSlice& start_literal,
37     const xla::LiteralSlice& limit_literal,
38     const xla::LiteralSlice& delta_literal, xla::XlaBuilder* builder) {
39   T start = start_literal.Get<T>({});
40   T limit = limit_literal.Get<T>({});
41   T delta = delta_literal.Get<T>({});
42 
43   if (delta == 0) {
44     return errors::InvalidArgument("Requires delta != 0: ", delta);
45   }
46   if (delta > 0) {
47     if (start > limit) {
48       return errors::InvalidArgument(
49           "Requires start <= limit when delta > 0: ", start, "/", limit);
50     }
51   } else {
52     if (start < limit) {
53       return errors::InvalidArgument(
54           "Requires start >= limit when delta < 0: ", start, "/", limit);
55     }
56   }
57   int64 size =
58       (std::is_integral<T>::value
59            ? ((std::abs(limit - start) + std::abs(delta) - 1) / std::abs(delta))
60            : std::ceil(std::abs((limit - start) / delta)));
61 
62   return xla::ConstantR0(builder, start) +
63          xla::ConstantR0(builder, delta) *
64              xla::Iota(builder, xla::primitive_util::NativeToPrimitiveType<T>(),
65                        size);
66 }
67 
68 class RangeOp : public XlaOpKernel {
69  public:
RangeOp(OpKernelConstruction * ctx)70   explicit RangeOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
71 
Compile(XlaOpKernelContext * ctx)72   void Compile(XlaOpKernelContext* ctx) override {
73     const TensorShape start_in_shape = ctx->InputShape(0);
74     const TensorShape limit_in_shape = ctx->InputShape(1);
75     const TensorShape delta_in_shape = ctx->InputShape(2);
76     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(start_in_shape),
77                 errors::InvalidArgument("start must be a scalar, not shape ",
78                                         start_in_shape.DebugString()));
79     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(limit_in_shape),
80                 errors::InvalidArgument("limit must be a scalar, not shape ",
81                                         limit_in_shape.DebugString()));
82     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(delta_in_shape),
83                 errors::InvalidArgument("delta must be a scalar, not shape ",
84                                         delta_in_shape.DebugString()));
85     xla::Literal start, limit, delta;
86     OP_REQUIRES_OK(ctx, ctx->ConstantInput(0, &start));
87     OP_REQUIRES_OK(ctx, ctx->ConstantInput(1, &limit));
88     OP_REQUIRES_OK(ctx, ctx->ConstantInput(2, &delta));
89 
90     DataType type = input_type(0);
91     xla::StatusOr<xla::XlaOp> output;
92     switch (type) {
93       case DT_INT32:
94         output = CreateRangeTensor<int32>(start, limit, delta, ctx->builder());
95         break;
96       case DT_INT64:
97         output = CreateRangeTensor<int64>(start, limit, delta, ctx->builder());
98         break;
99       case DT_FLOAT:
100         output = CreateRangeTensor<float>(start, limit, delta, ctx->builder());
101         break;
102       case DT_DOUBLE:
103         output = CreateRangeTensor<double>(start, limit, delta, ctx->builder());
104         break;
105       default:
106         output = errors::InvalidArgument("Invalid type for Range ",
107                                          DataTypeString(type));
108     }
109     OP_REQUIRES_OK(ctx, output.status());
110     ctx->SetOutput(0, output.ValueOrDie());
111   }
112 };
113 
114 REGISTER_XLA_OP(Name("Range")
115                     .CompileTimeConstantInput("start")
116                     .CompileTimeConstantInput("limit")
117                     .CompileTimeConstantInput("delta"),
118                 RangeOp);
119 
120 class LinSpaceOp : public XlaOpKernel {
121  public:
LinSpaceOp(OpKernelConstruction * ctx)122   explicit LinSpaceOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
123 
Compile(XlaOpKernelContext * ctx)124   void Compile(XlaOpKernelContext* ctx) override {
125     const TensorShape start_in_shape = ctx->InputShape("start");
126     const TensorShape stop_in_shape = ctx->InputShape("stop");
127     const TensorShape num_in_shape = ctx->InputShape("num");
128     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(start_in_shape),
129                 errors::InvalidArgument("start must be a scalar, not shape ",
130                                         start_in_shape.DebugString()));
131     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(stop_in_shape),
132                 errors::InvalidArgument("stop must be a scalar, not shape ",
133                                         stop_in_shape.DebugString()));
134     OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(num_in_shape),
135                 errors::InvalidArgument("num must be a scalar, not shape ",
136                                         num_in_shape.DebugString()));
137 
138     DataType type = ctx->input_type(0);
139 
140     int64 num;
141     OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar("num", &num));
142     OP_REQUIRES(ctx, num > 0,
143                 errors::InvalidArgument("Requires num > 0: ", num));
144     Tensor out_constant(type, TensorShape({num}));
145 
146     xla::Literal start_literal;
147     OP_REQUIRES_OK(ctx, ctx->ConstantInput("start", &start_literal));
148     xla::Literal stop_literal;
149     OP_REQUIRES_OK(ctx, ctx->ConstantInput("stop", &stop_literal));
150 
151     switch (type) {
152       case DT_FLOAT: {
153         float start = start_literal.GetFirstElement<float>();
154         float stop = stop_literal.GetFirstElement<float>();
155         auto flat = out_constant.flat<float>();
156         if (num == 1) {
157           flat(0) = start;
158         } else {
159           const float step = (stop - start) / (num - 1);
160           for (int64 i = 0; i < num - 1; ++i) {
161             flat(i) = start + step * i;
162           }
163           // The last value in the sequence must be equal to stop.
164           flat(num - 1) = stop;
165         }
166         break;
167       }
168       case DT_DOUBLE: {
169         double start = start_literal.GetFirstElement<double>();
170         double stop = stop_literal.GetFirstElement<double>();
171         auto flat = out_constant.flat<double>();
172         if (num == 1) {
173           flat(0) = start;
174         } else {
175           const double step = (stop - start) / (num - 1);
176           for (int64 i = 0; i < num - 1; ++i) {
177             flat(i) = start + step * i;
178           }
179           // The last value in the sequence must be equal to stop.
180           flat(num - 1) = stop;
181         }
182         break;
183       }
184 
185       default:
186         ctx->SetStatus(errors::InvalidArgument("Invalid argument type ",
187                                                DataTypeString(type)));
188         return;
189     }
190     ctx->SetConstantOutput(0, out_constant);
191   }
192 };
193 
194 REGISTER_XLA_OP(Name("LinSpace")
195                     .CompileTimeConstantInput("start")
196                     .CompileTimeConstantInput("stop")
197                     .CompileTimeConstantInput("num"),
198                 LinSpaceOp);
199 
200 }  // namespace
201 }  // namespace tensorflow
202