1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_KERNELS_SHAPE_OPS_H_
17 #define TENSORFLOW_CORE_KERNELS_SHAPE_OPS_H_
18
19 #include <limits>
20 #include <unordered_set>
21 #include <vector>
22
23 #include "tensorflow/core/framework/bounds_check.h"
24 #include "tensorflow/core/framework/op_kernel.h"
25 #include "tensorflow/core/framework/tensor.h"
26 #include "tensorflow/core/framework/tensor_shape.h"
27 #include "tensorflow/core/framework/variant_op_registry.h"
28
29 namespace tensorflow {
30
31 namespace shape_op_helpers {
GetShape(OpKernelContext * ctx,int input_index,TensorShape * shape)32 inline Status GetShape(OpKernelContext* ctx, int input_index,
33 TensorShape* shape) {
34 *shape = ctx->input(input_index).shape();
35 return Status::OK();
36 }
37 } // namespace shape_op_helpers
38
39 template <typename OutType>
40 class ShapeOp : public OpKernel {
41 public:
ShapeOp(OpKernelConstruction * ctx)42 explicit ShapeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
43
Compute(OpKernelContext * ctx)44 void Compute(OpKernelContext* ctx) override {
45 TensorShape shape;
46 OP_REQUIRES_OK(ctx, shape_op_helpers::GetShape(ctx, 0, &shape));
47 const int rank = shape.dims();
48 Tensor* out = nullptr;
49 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({rank}), &out));
50 auto vec = out->vec<OutType>();
51 for (int i = 0; i < rank; ++i) {
52 int64 dim_size = shape.dim_size(i);
53 if (out->dtype() == DT_INT32) {
54 OP_REQUIRES(
55 ctx, FastBoundsCheck(dim_size, std::numeric_limits<int32>::max()),
56 errors::InvalidArgument("Shape output type is 32-bit ", " but dim ",
57 i, " is ", dim_size));
58 }
59 vec(i) = static_cast<OutType>(dim_size);
60 }
61 }
62
IsExpensive()63 bool IsExpensive() override { return false; }
64 };
65
66 template <typename OutType>
67 class ShapeNOp : public OpKernel {
68 public:
ShapeNOp(OpKernelConstruction * ctx)69 explicit ShapeNOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
70
Compute(OpKernelContext * ctx)71 void Compute(OpKernelContext* ctx) override {
72 for (int i = 0; i < ctx->num_inputs(); ++i) {
73 TensorShape shape;
74 OP_REQUIRES_OK(ctx, shape_op_helpers::GetShape(ctx, i, &shape));
75 const int dims = shape.dims();
76 Tensor* out = nullptr;
77 OP_REQUIRES_OK(ctx, ctx->allocate_output(i, {dims}, &out));
78 auto vec = out->vec<OutType>();
79
80 for (int j = 0; j < dims; ++j) {
81 int64 dim_size = shape.dim_size(j);
82 if (out->dtype() == DT_INT32) {
83 OP_REQUIRES(
84 ctx, FastBoundsCheck(dim_size, std::numeric_limits<int32>::max()),
85 errors::InvalidArgument("ShapeN output type is 32-bit but shape ",
86 i, " dim ", j, " is ", dim_size));
87 }
88 vec(j) = static_cast<OutType>(dim_size);
89 }
90 }
91 }
92
IsExpensive()93 bool IsExpensive() override { return false; }
94 };
95
96 class RankOp : public OpKernel {
97 public:
RankOp(OpKernelConstruction * ctx)98 explicit RankOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
99
Compute(OpKernelContext * ctx)100 void Compute(OpKernelContext* ctx) override {
101 TensorShape shape;
102 OP_REQUIRES_OK(ctx, shape_op_helpers::GetShape(ctx, 0, &shape));
103 const int rank = shape.dims();
104 Tensor* out = nullptr;
105 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out));
106 out->scalar<int32>()() = rank;
107 }
108
IsExpensive()109 bool IsExpensive() override { return false; }
110 };
111
112 template <typename OutType>
113 class SizeOp : public OpKernel {
114 public:
SizeOp(OpKernelConstruction * ctx)115 explicit SizeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
116
Compute(OpKernelContext * ctx)117 void Compute(OpKernelContext* ctx) override {
118 TensorShape shape;
119 OP_REQUIRES_OK(ctx, shape_op_helpers::GetShape(ctx, 0, &shape));
120 const int64 size = shape.num_elements();
121 Tensor* out = nullptr;
122 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &out));
123 if (out->dtype() == DT_INT32) {
124 OP_REQUIRES(
125 ctx, FastBoundsCheck(size, std::numeric_limits<int32>::max()),
126 errors::InvalidArgument("Number of elements was larger than "
127 "representable by 32-bit output type"));
128 }
129 out->scalar<OutType>()() = static_cast<OutType>(size);
130 }
131
IsExpensive()132 bool IsExpensive() override { return false; }
133 };
134
135 template <typename Tdim>
136 class ExpandDimsOp : public OpKernel {
137 public:
ExpandDimsOp(OpKernelConstruction * ctx)138 explicit ExpandDimsOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
139
Compute(OpKernelContext * ctx)140 void Compute(OpKernelContext* ctx) override {
141 OP_REQUIRES(ctx, ctx->input(0).dtype() != DT_VARIANT,
142 errors::InvalidArgument("ExpandDims on Variant not supported"));
143
144 OP_REQUIRES(
145 ctx, (ctx->input(1).NumElements() == 1),
146 errors::InvalidArgument("'dim' must be a tensor with a single value"));
147 Tdim dim = ctx->input(1).flat<Tdim>()(0);
148 OP_REQUIRES(
149 ctx, (dim >= -1 - ctx->input(0).dims() && dim <= ctx->input(0).dims()),
150 errors::InvalidArgument("Tried to expand dim index ", dim,
151 " for tensor with ", ctx->input(0).dims(),
152 " dimensions."));
153
154 auto existing_dims = ctx->input(0).shape().dim_sizes();
155 // Safe - # elements in tensor dims bounded.
156 const int existing_dims_size = static_cast<int>(existing_dims.size());
157 std::vector<int64> new_shape(existing_dims_size);
158 for (size_t i = 0; i < new_shape.size(); ++i) {
159 new_shape[i] = existing_dims[i];
160 }
161
162 // We emulate numpy's interpretation of the dim axis when
163 // -input.dims() >= dim <= input.dims().
164 if (dim < 0) {
165 dim += existing_dims.size() + 1;
166 }
167
168 // Clamp to the end if needed.
169 dim = std::min<Tdim>(dim, existing_dims_size);
170 new_shape.emplace(new_shape.begin() + dim, 1);
171 const TensorShape output_shape(new_shape);
172
173 Tensor* output = nullptr;
174 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {0}, &output));
175 if (!output->CopyFrom(ctx->input(0), output_shape)) {
176 // This should never happen, since the sizes of the input and output
177 // should always be the same (we only expand the dimension with 1).
178 ctx->SetStatus(
179 errors::Internal("Could not expand dimension with input shape ",
180 ctx->input(0).shape().DebugString(),
181 " and output shape ", output_shape.DebugString()));
182 }
183 }
184
IsExpensive()185 bool IsExpensive() override { return false; }
186 };
187
188 class SqueezeOp : public OpKernel {
189 public:
SqueezeOp(OpKernelConstruction * ctx)190 explicit SqueezeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
191 std::vector<int32> squeeze_dims;
192 OP_REQUIRES_OK(ctx, ctx->GetAttr("squeeze_dims", &squeeze_dims));
193 squeeze_dims_.insert(squeeze_dims.begin(), squeeze_dims.end());
194 }
195
Compute(OpKernelContext * ctx)196 void Compute(OpKernelContext* ctx) override {
197 OP_REQUIRES(ctx, ctx->input(0).dtype() != DT_VARIANT,
198 errors::InvalidArgument("Squeeze on Variant not supported"));
199
200 auto existing_dims = ctx->input(0).shape().dim_sizes();
201 const int existing_dims_size = static_cast<int>(existing_dims.size());
202 std::vector<int64> new_shape;
203
204 std::unordered_set<int32> wrapped_squeeze_dims;
205 wrapped_squeeze_dims.reserve(squeeze_dims_.size());
206 // Validate squeeze dims against the input.
207 for (int32 dim : squeeze_dims_) {
208 OP_REQUIRES(
209 ctx, (dim >= -ctx->input(0).dims() && dim < ctx->input(0).dims()),
210 errors::InvalidArgument("Tried to squeeze dim index ", dim,
211 " for tensor with ", ctx->input(0).dims(),
212 " dimensions."));
213 // If dim is < 0, we wrap around (-1 means the last element).
214 if (dim < 0) {
215 dim = existing_dims_size + dim;
216 }
217
218 wrapped_squeeze_dims.insert(dim);
219 }
220
221 for (int i = 0; i < existing_dims_size; ++i) {
222 auto existing_dim = existing_dims[i];
223
224 // If squeeze_set is non-empty, only squeeze those dimensions.
225 if (!wrapped_squeeze_dims.empty()) {
226 if (wrapped_squeeze_dims.count(i) > 0) {
227 OP_REQUIRES(ctx, existing_dim == 1,
228 errors::InvalidArgument(
229 "Can not squeeze dim[", i,
230 "], expected a dimension of 1, got ", existing_dim));
231 } else {
232 // This dimension is not being squeezed.
233 new_shape.push_back(existing_dim);
234 }
235 } else {
236 // Copy over all non-1-length dimensions.
237 if (existing_dim != 1) {
238 new_shape.push_back(existing_dim);
239 }
240 }
241 }
242
243 const TensorShape output_shape(new_shape);
244 Tensor* output = nullptr;
245 OP_REQUIRES_OK(ctx, ctx->allocate_output(0, {0}, &output));
246 if (!output->CopyFrom(ctx->input(0), output_shape)) {
247 // This should never happen, since the sizes of the input and
248 // output should always be the same.
249 ctx->SetStatus(errors::Internal("Could not squeeze input with shape ",
250 ctx->input(0).shape().DebugString(),
251 " and output shape ",
252 output_shape.DebugString()));
253 }
254 }
255
IsExpensive()256 bool IsExpensive() override { return false; }
257
258 private:
259 std::unordered_set<int32> squeeze_dims_;
260 };
261
262 } // namespace tensorflow
263
264 #endif // TENSORFLOW_CORE_KERNELS_SHAPE_OPS_H_
265