/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/math_ops.cc.
#include <cmath>
#include <type_traits>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
26 namespace tensorflow {
27 
GetValue(int32 v)28 int32 GetValue(int32 v) { return v; }
29 
30 template <typename T>
31 class RangeOp : public OpKernel {
32  public:
RangeOp(OpKernelConstruction * context)33   explicit RangeOp(OpKernelConstruction* context) : OpKernel(context) {}
34 
Compute(OpKernelContext * context)35   void Compute(OpKernelContext* context) override {
36     const Tensor& start_in = context->input(0);
37     const Tensor& limit_in = context->input(1);
38     const Tensor& delta_in = context->input(2);
39     OP_REQUIRES(context, IsLegacyScalar(start_in.shape()),
40                 errors::InvalidArgument("start must be a scalar, not shape ",
41                                         start_in.shape().DebugString()));
42     OP_REQUIRES(context, IsLegacyScalar(limit_in.shape()),
43                 errors::InvalidArgument("limit must be a scalar, not shape ",
44                                         limit_in.shape().DebugString()));
45     OP_REQUIRES(context, IsLegacyScalar(delta_in.shape()),
46                 errors::InvalidArgument("delta must be a scalar, not shape ",
47                                         delta_in.shape().DebugString()));
48     const T start = start_in.scalar<T>()();
49     const T limit = limit_in.scalar<T>()();
50     const T delta = delta_in.scalar<T>()();
51     OP_REQUIRES(context, delta != 0,
52                 errors::InvalidArgument("Requires delta != 0: ", delta));
53     if (delta > 0) {
54       OP_REQUIRES(
55           context, start <= limit,
56           errors::InvalidArgument(
57               "Requires start <= limit when delta > 0: ", start, "/", limit));
58     } else {
59       OP_REQUIRES(
60           context, start >= limit,
61           errors::InvalidArgument(
62               "Requires start >= limit when delta < 0: ", start, "/", limit));
63     }
64     int64 size = (std::is_integral<T>::value
65                       ? ((std::abs(limit - start) + std::abs(delta) - 1) /
66                          std::abs(delta))
67                       : std::ceil(std::abs((limit - start) / delta)));
68     Tensor* out = nullptr;
69     OP_REQUIRES_OK(context,
70                    context->allocate_output(0, TensorShape({size}), &out));
71     auto flat = out->flat<T>();
72     T val = start;
73     for (int64 i = 0; i < size; ++i) {
74       flat(i) = T(val);
75       val += delta;
76     }
77   }
78 };
79 
80 #define REGISTER_KERNEL(DEV, TYPE)                           \
81   REGISTER_KERNEL_BUILDER(Name("Range")                      \
82                               .Device(DEV)                   \
83                               .HostMemory("start")           \
84                               .HostMemory("limit")           \
85                               .HostMemory("delta")           \
86                               .HostMemory("output")          \
87                               .TypeConstraint<TYPE>("Tidx"), \
88                           RangeOp<TYPE>);
89 
90 #define REGISTER_CPU_KERNEL(T) REGISTER_KERNEL(DEVICE_CPU, T)
91 #define REGISTER_GPU_KERNEL(T) REGISTER_KERNEL(DEVICE_GPU, T)
92 #ifdef TENSORFLOW_USE_SYCL
93 #define REGISTER_SYCL_KERNEL(T) REGISTER_KERNEL(DEVICE_SYCL, T)
94 TF_CALL_float(REGISTER_SYCL_KERNEL);
95 TF_CALL_double(REGISTER_SYCL_KERNEL);
96 TF_CALL_int32(REGISTER_SYCL_KERNEL);
97 TF_CALL_int64(REGISTER_SYCL_KERNEL);
98 #undef REGISTER_SYCL_KERNEL
99 #endif  // TENSORFLOW_USE_SYCL
100 
101 TF_CALL_float(REGISTER_CPU_KERNEL);
102 TF_CALL_double(REGISTER_CPU_KERNEL);
103 TF_CALL_int32(REGISTER_CPU_KERNEL);
104 TF_CALL_int64(REGISTER_CPU_KERNEL);
105 
106 #if GOOGLE_CUDA
107 
108 TF_CALL_float(REGISTER_GPU_KERNEL);
109 TF_CALL_double(REGISTER_GPU_KERNEL);
110 TF_CALL_int32(REGISTER_GPU_KERNEL);
111 TF_CALL_int64(REGISTER_GPU_KERNEL);
112 
113 #endif  // GOOGLE_CUDA
114 
115 #undef REGISTER_KERNEL
116 #undef REGISTER_CPU_KERNEL
117 #undef REGISTER_GPU_KERNEL
118 
119 template <typename T, typename Tnum>
120 class LinSpaceOp : public OpKernel {
121  public:
LinSpaceOp(OpKernelConstruction * context)122   explicit LinSpaceOp(OpKernelConstruction* context) : OpKernel(context) {}
123 
Compute(OpKernelContext * context)124   void Compute(OpKernelContext* context) override {
125     const Tensor& start_in = context->input(0);
126     const Tensor& stop_in = context->input(1);
127     const Tensor& num_in = context->input(2);
128     OP_REQUIRES(context, TensorShapeUtils::IsScalar(start_in.shape()),
129                 errors::InvalidArgument("start must be a scalar, not shape ",
130                                         start_in.shape().DebugString()));
131     OP_REQUIRES(context, TensorShapeUtils::IsScalar(stop_in.shape()),
132                 errors::InvalidArgument("stop must be a scalar, not shape ",
133                                         stop_in.shape().DebugString()));
134     OP_REQUIRES(context, TensorShapeUtils::IsScalar(num_in.shape()),
135                 errors::InvalidArgument("num must be a scalar, not shape ",
136                                         num_in.shape().DebugString()));
137     const T start = start_in.scalar<T>()();
138     const T stop = stop_in.scalar<T>()();
139     const Tnum num = num_in.scalar<Tnum>()();
140     OP_REQUIRES(context, num > 0,
141                 errors::InvalidArgument("Requires num > 0: ", num));
142     Tensor* out = nullptr;
143     OP_REQUIRES_OK(context,
144                    context->allocate_output(0, TensorShape({num}), &out));
145     auto flat = out->flat<T>();
146     flat(0) = start;
147     if (num > 1) {
148       const T step = (stop - start) / (num - 1);
149       for (Tnum i = 1; i < num - 1; ++i) flat(i) = start + step * i;
150       // Ensure final value == stop; float arithmetic won't guarantee this.
151       flat(num - 1) = stop;
152     }
153   }
154 };
155 
156 #define REGISTER_KERNEL(DEV, T, Tidx)                       \
157   REGISTER_KERNEL_BUILDER(Name("LinSpace")                  \
158                               .Device(DEV)                  \
159                               .TypeConstraint<T>("T")       \
160                               .TypeConstraint<Tidx>("Tidx") \
161                               .HostMemory("start")          \
162                               .HostMemory("stop")           \
163                               .HostMemory("num")            \
164                               .HostMemory("output"),        \
165                           LinSpaceOp<T, Tidx>);
166 
167 #define REGISTER_KERNEL_ALL_NUMS(dev, T) \
168   REGISTER_KERNEL(dev, T, int32);        \
169   REGISTER_KERNEL(dev, T, int64)
170 
171 #define REGISTER_CPU_KERNEL(T) REGISTER_KERNEL_ALL_NUMS(DEVICE_CPU, T)
172 TF_CALL_float(REGISTER_CPU_KERNEL);
173 TF_CALL_double(REGISTER_CPU_KERNEL);
174 
175 // NOTE(touts): We register the op on GPU but it still runs on CPU
176 // because its inputs and outputs are tagged as HostMemory.
177 #define REGISTER_GPU_KERNEL(T) REGISTER_KERNEL_ALL_NUMS(DEVICE_GPU, T)
178 TF_CALL_float(REGISTER_GPU_KERNEL);
179 TF_CALL_double(REGISTER_GPU_KERNEL);
180 #undef REGISTER_GPU_KERNEL
181 
182 #ifdef TENSORFLOW_USE_SYCL
183 #define REGISTER_SYCL_KERNEL(T) REGISTER_KERNEL_ALL_NUMS(DEVICE_SYCL, T)
184 TF_CALL_float(REGISTER_SYCL_KERNEL);
185 TF_CALL_double(REGISTER_SYCL_KERNEL);
186 #undef REGISTER_SYCL_KERNEL
187 #endif  // TENSORFLOW_USE_SYCL
188 
189 #undef REGISTER_CPU_KERNEL
190 #undef REGISTER_KERNEL_ALL_NUMS
191 #undef REGISTER_KERNEL
192 
193 }  // namespace tensorflow
194