• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 // XLA-specific Ops for split.
17 
18 #include "tensorflow/compiler/tf2xla/type_util.h"
19 #include "tensorflow/compiler/tf2xla/xla_helpers.h"
20 #include "tensorflow/compiler/tf2xla/xla_op_kernel.h"
21 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
22 #include "tensorflow/compiler/xla/client/xla_builder.h"
23 #include "tensorflow/compiler/xla/literal.h"
24 #include "tensorflow/core/framework/op_kernel.h"
25 #include "tensorflow/core/framework/register_types.h"
26 #include "tensorflow/core/framework/tensor.h"
27 #include "tensorflow/core/framework/tensor_shape.h"
28 
29 namespace tensorflow {
30 namespace {
31 
32 class SplitOp : public XlaOpKernel {
33  public:
SplitOp(OpKernelConstruction * ctx)34   explicit SplitOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
35 
Compile(XlaOpKernelContext * ctx)36   void Compile(XlaOpKernelContext* ctx) override {
37     const int32 num_split = num_outputs();
38     const TensorShape split_dim_shape = ctx->InputShape("split_dim");
39     const TensorShape input_shape = ctx->InputShape(1);
40 
41     OP_REQUIRES(
42         ctx, TensorShapeUtils::IsScalar(split_dim_shape),
43         errors::InvalidArgument("split_dim must be a scalar but has rank ",
44                                 split_dim_shape.dims()));
45     int64 split_dim_orig;
46     OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(0, &split_dim_orig));
47 
48     int32 split_dim = split_dim_orig < 0 ? split_dim_orig + input_shape.dims()
49                                          : split_dim_orig;
50     OP_REQUIRES(ctx, 0 <= split_dim && split_dim < input_shape.dims(),
51                 errors::InvalidArgument("-input rank(-", input_shape.dims(),
52                                         ") <= split_dim < input rank (",
53                                         input_shape.dims(), "), but got ",
54                                         split_dim_orig));
55 
56     OP_REQUIRES(
57         ctx, num_split > 0,
58         errors::InvalidArgument(
59             "Number of ways to split should be > 0, but got ", num_split));
60 
61     OP_REQUIRES(
62         ctx, input_shape.dim_size(split_dim) % num_split == 0,
63         errors::InvalidArgument(
64             "Number of ways to split should evenly divide the split "
65             "dimension, but got split_dim ",
66             split_dim_orig, " (size = ", input_shape.dim_size(split_dim), ") ",
67             "and num_split ", num_split));
68 
69     // All the slices are the same size: this is the size along the
70     // split dimension.
71     const int32 slice_size = input_shape.dim_size(split_dim) / num_split;
72 
73     // The vectors we will use to define the slice. The entry for the
74     // split dimensions varies for each output.
75     std::vector<int64> begin(input_shape.dims(), 0);
76     std::vector<int64> limits(input_shape.dims());
77     std::vector<int64> strides(input_shape.dims(), 1);
78     for (int i = 0; i < input_shape.dims(); ++i) {
79       // Initially set up the limits to be the full size of the input:
80       // the split dimension is filled in below.
81       int64 dim = input_shape.dim_size(i);
82       limits[i] = dim;
83     }
84 
85     auto input = ctx->Input(1);
86 
87     // Create each of the outputs.
88     for (int i = 0; i < num_split; ++i) {
89       // Slice out the ith split from the split dimension.
90       begin[split_dim] = i * slice_size;
91       limits[split_dim] = (i + 1) * slice_size;
92       ctx->SetOutput(i, xla::Slice(input, begin, limits, strides));
93     }
94   }
95 };
96 
97 REGISTER_XLA_OP(Name("Split").CompileTimeConstantInput("split_dim"), SplitOp);
98 
99 class SplitVOp : public XlaOpKernel {
100  public:
SplitVOp(OpKernelConstruction * ctx)101   explicit SplitVOp(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
102 
Compile(XlaOpKernelContext * ctx)103   void Compile(XlaOpKernelContext* ctx) override {
104     const int32 num_split = num_outputs();
105     const TensorShape input_shape = ctx->InputShape(0);
106     const TensorShape index_shape = ctx->InputShape(2);
107 
108     int64 split_dim_orig;
109     OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntScalar(2, &split_dim_orig));
110     int64 split_dim = split_dim_orig < 0 ? split_dim_orig + input_shape.dims()
111                                          : split_dim_orig;
112     OP_REQUIRES(ctx, 0 <= split_dim && split_dim < input_shape.dims(),
113                 errors::InvalidArgument("-input rank(-", input_shape.dims(),
114                                         ") <= split_dim < input rank (",
115                                         input_shape.dims(), "), but got ",
116                                         split_dim_orig));
117 
118     xla::XlaOp input = ctx->Input(0);
119 
120     OP_REQUIRES(ctx, input_shape.dims() > 0,
121                 errors::InvalidArgument("Can't split a 0 dimensional input"));
122 
123     OP_REQUIRES(
124         ctx, num_split > 0,
125         errors::InvalidArgument(
126             "Number of ways to split should be > 0, but got ", num_split));
127 
128     // Check that sizes are correct.
129     int total_split_size = 0;
130     int neg_one_dim = -1;
131     const TensorShape split_size_shape = ctx->InputShape(1);
132     OP_REQUIRES(ctx,
133                 split_size_shape.dims() == 1 &&
134                     split_size_shape.num_elements() == num_split,
135                 errors::InvalidArgument(
136                     "shape of tensor describing "
137                     " the output must have dimension 1 and the same "
138                     " number of elements as the output. Got ",
139                     split_size_shape.dims(), "-D and ",
140                     split_size_shape.num_elements(), " elements"));
141     // Get the dimension of this split.
142     std::vector<int64> split_sizes;
143     OP_REQUIRES_OK(ctx, ctx->ConstantInputAsIntVector(1, &split_sizes));
144 
145     for (int i = 0; i < num_split; ++i) {
146       int64 slice_size = split_sizes[i];
147       if (slice_size == -1) {
148         OP_REQUIRES(
149             ctx, neg_one_dim == -1,
150             errors::InvalidArgument("Only one dimensions can have a value of"
151                                     "-1. Second one found at dimension ",
152                                     i));
153         neg_one_dim = i;
154       } else {
155         total_split_size += slice_size;
156       }
157     }
158 
159     OP_REQUIRES(
160         ctx,
161         (neg_one_dim == -1 &&
162          total_split_size == input_shape.dim_size(split_dim)) ||
163             (neg_one_dim >= 0 &&
164              total_split_size <= input_shape.dim_size(split_dim)),
165         errors::InvalidArgument("Determined shape must either match "
166                                 "input shape along split_dim exactly if "
167                                 "fully specified, or be less than the size of "
168                                 "the input along split_dim if not fully "
169                                 "specified.  Got: ",
170                                 total_split_size));
171 
172     if (neg_one_dim >= 0) {
173       split_sizes[neg_one_dim] =
174           input_shape.dim_size(split_dim) - total_split_size;
175     }
176 
177     // The vectors we will use to define the slice. The entry for the
178     // split dimensions varies for each output.
179     std::vector<int64> begin(input_shape.dims(), 0);
180     auto dim_sizes = input_shape.dim_sizes();
181     std::vector<int64> limits(dim_sizes.begin(), dim_sizes.end());
182     std::vector<int64> strides(input_shape.dims(), 1);
183     for (int i = 0; i < num_split; ++i) {
184       TensorShape output_shape(input_shape);
185       int slice_size = split_sizes[i];
186       output_shape.set_dim(split_dim, slice_size);
187 
188       // Slice out the ith split from the split dimension.
189       limits[split_dim] = begin[split_dim] + slice_size;
190       ctx->SetOutput(i, xla::Slice(input, begin, limits, strides));
191       begin[split_dim] = limits[split_dim];
192     }
193   }
194 };
195 
196 REGISTER_XLA_OP(Name("SplitV")
197                     .CompileTimeConstantInput("split_dim")
198                     .CompileTimeConstantInput("size_splits"),
199                 SplitVOp);
200 
201 }  // namespace
202 }  // namespace tensorflow
203