• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
#include <cstddef>
#include <cstdlib>
#include <limits>
#include <string>
19 
20 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
21 #include "tensorflow/core/framework/bounds_check.h"
22 #include "tensorflow/core/framework/kernel_def_builder.h"
23 #include "tensorflow/core/framework/op.h"
24 #include "tensorflow/core/framework/op_kernel.h"
25 #include "tensorflow/core/framework/tensor.h"
26 #include "tensorflow/core/framework/tensor_shape.h"
27 #include "tensorflow/core/framework/tensor_types.h"
28 #include "tensorflow/core/framework/types.h"
29 #include "tensorflow/core/kernels/string_util.h"
30 #include "tensorflow/core/lib/core/errors.h"
31 #include "tensorflow/core/lib/core/stringpiece.h"
32 #include "tensorflow/core/platform/types.h"
33 #include "tensorflow/core/util/bcast.h"
34 
35 namespace tensorflow {
36 
37 // Position/length can be 32 or 64-bit integers
38 template <typename T>
39 class SubstrOp : public OpKernel {
40  public:
SubstrOp(OpKernelConstruction * ctx)41   explicit SubstrOp(OpKernelConstruction* ctx) : OpKernel(ctx) {
42     string unit;
43     OP_REQUIRES_OK(ctx, ctx->GetAttr("unit", &unit));
44     OP_REQUIRES_OK(ctx, ParseCharUnit(unit, &unit_));
45   }
46 
Compute(OpKernelContext * context)47   void Compute(OpKernelContext* context) override {
48     // Get inputs
49     const Tensor& input_tensor = context->input(0);
50     const Tensor& pos_tensor = context->input(1);
51     const Tensor& len_tensor = context->input(2);
52     const TensorShape& input_shape = input_tensor.shape();
53     const TensorShape& pos_shape = pos_tensor.shape();
54     const TensorShape& len_shape = len_tensor.shape();
55     OP_REQUIRES(context, (pos_shape == len_shape),
56                 errors::InvalidArgument(
57                     "pos and len should have the same shape, got: ",
58                     pos_shape.DebugString(), " vs. ", len_shape.DebugString()));
59 
60     bool is_scalar = TensorShapeUtils::IsScalar(pos_shape);
61 
62     if (is_scalar || input_shape == pos_shape) {
63       // pos/len are either scalar or match the shape of input_tensor
64       // Do not need to do broadcasting
65 
66       // Reshape input
67       auto input = input_tensor.flat<tstring>();
68       // Allocate output
69       Tensor* output_tensor = nullptr;
70       OP_REQUIRES_OK(context,
71                      context->allocate_output("output", input_tensor.shape(),
72                                               &output_tensor));
73       auto output = output_tensor->flat<tstring>();
74       if (is_scalar) {
75         // Perform Op with scalar pos/len
76         const T pos =
77             tensorflow::internal::SubtleMustCopy(pos_tensor.scalar<T>()());
78         const T len =
79             tensorflow::internal::SubtleMustCopy(len_tensor.scalar<T>()());
80         for (size_t i = 0; i < input_tensor.NumElements(); ++i) {
81           StringPiece in(input(i));
82           T byte_pos = pos;
83           T byte_len = len;
84           switch (unit_) {
85             case CharUnit::UTF8_CHAR:
86               OP_REQUIRES(
87                   context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len),
88                   errors::InvalidArgument("pos ", pos, " out of range for ",
89                                           "string at index ", i));
90               break;
91             case CharUnit::BYTE:
92               byte_pos = AdjustedPosIndex(byte_pos, in);
93               OP_REQUIRES(
94                   context, FastBoundsCheck(byte_pos, in.size() + 1),
95                   errors::InvalidArgument("pos ", pos, " out of range for ",
96                                           "string b'", in, "' at index ", i));
97           }
98           StringPiece sub_in = in.substr(byte_pos, byte_len);
99           output(i).assign(sub_in.data(), sub_in.size());
100         }
101       } else {
102         // Perform Op element-wise with tensor pos/len
103         auto pos_flat = pos_tensor.flat<T>();
104         auto len_flat = len_tensor.flat<T>();
105         for (size_t i = 0; i < input_tensor.NumElements(); ++i) {
106           StringPiece in(input(i));
107           const T pos = tensorflow::internal::SubtleMustCopy(pos_flat(i));
108           const T len = tensorflow::internal::SubtleMustCopy(len_flat(i));
109           T byte_pos = pos;
110           T byte_len = len;
111           switch (unit_) {
112             case CharUnit::UTF8_CHAR:
113               OP_REQUIRES(
114                   context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len),
115                   errors::InvalidArgument("pos ", pos, " out of range for ",
116                                           "string at index ", i));
117               break;
118             case CharUnit::BYTE:
119               byte_pos = AdjustedPosIndex(byte_pos, in);
120               OP_REQUIRES(
121                   context, FastBoundsCheck(byte_pos, in.size() + 1),
122                   errors::InvalidArgument("pos ", pos, " out of range for ",
123                                           "string b'", in, "' at index ", i));
124           }
125           StringPiece sub_in = in.substr(byte_pos, byte_len);
126           output(i).assign(sub_in.data(), sub_in.size());
127         }
128       }
129     } else {
130       // Perform op with broadcasting
131       // TODO: Use ternary broadcasting for once available in Eigen. Current
132       //       implementation iterates through broadcasted ops element-wise;
133       //       this should be parallelized.
134 
135       // Create BCast helper with shape of input and pos/len
136       BCast bcast(BCast::FromShape(input_shape), BCast::FromShape(pos_shape));
137       OP_REQUIRES(context, bcast.IsValid(),
138                   errors::InvalidArgument(
139                       "Incompatible shapes: ", input_shape.DebugString(),
140                       " vs. ", pos_shape.DebugString()));
141       TensorShape output_shape = BCast::ToShape(bcast.result_shape());
142       int ndims = output_shape.dims();
143       Tensor* output_tensor = nullptr;
144       OP_REQUIRES_OK(context, context->allocate_output("output", output_shape,
145                                                        &output_tensor));
146       switch (ndims) {
147         case 1: {
148           // Reshape tensors according to BCast results
149           auto input = input_tensor.shaped<tstring, 1>(bcast.x_reshape());
150           auto output = output_tensor->shaped<tstring, 1>(bcast.result_shape());
151           auto pos_shaped = pos_tensor.shaped<T, 1>(bcast.y_reshape());
152           auto len_shaped = len_tensor.shaped<T, 1>(bcast.y_reshape());
153 
154           // Allocate temporary buffer for broadcasted position tensor
155           Tensor pos_buffer;
156           OP_REQUIRES_OK(context,
157                          context->allocate_temp(DataTypeToEnum<T>::v(),
158                                                 output_shape, &pos_buffer));
159           typename TTypes<T, 1>::Tensor pos_bcast(
160               pos_buffer.shaped<T, 1>(bcast.result_shape()));
161           pos_bcast =
162               pos_shaped.broadcast(BCast::ToIndexArray<1>(bcast.y_bcast()));
163 
164           // Allocate temporary buffer for broadcasted length tensor
165           Tensor len_buffer;
166           OP_REQUIRES_OK(context,
167                          context->allocate_temp(DataTypeToEnum<T>::v(),
168                                                 output_shape, &len_buffer));
169           typename TTypes<T, 1>::Tensor len_bcast(
170               len_buffer.shaped<T, 1>(bcast.result_shape()));
171           len_bcast =
172               len_shaped.broadcast(BCast::ToIndexArray<1>(bcast.y_bcast()));
173 
174           // Iterate through broadcasted tensors and perform substr
175           for (int i = 0; i < output_shape.dim_size(0); ++i) {
176             StringPiece in(input(input.dimension(0) > 1 ? i : 0));
177             const T pos = tensorflow::internal::SubtleMustCopy(pos_bcast(i));
178             const T len = tensorflow::internal::SubtleMustCopy(len_bcast(i));
179             T byte_pos = pos;
180             T byte_len = len;
181             switch (unit_) {
182               case CharUnit::UTF8_CHAR:
183                 OP_REQUIRES(
184                     context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len),
185                     errors::InvalidArgument("pos ", pos, " out of range for ",
186                                             "string at index ", i));
187                 break;
188               case CharUnit::BYTE:
189                 byte_pos = AdjustedPosIndex(byte_pos, in);
190                 OP_REQUIRES(
191                     context, FastBoundsCheck(byte_pos, in.size() + 1),
192                     errors::InvalidArgument("pos ", pos, " out of range for ",
193                                             "string b'", in, "' at index ", i));
194             }
195             StringPiece sub_in = in.substr(byte_pos, byte_len);
196             output(i).assign(sub_in.data(), sub_in.size());
197           }
198           break;
199         }
200         case 2: {
201           // Reshape tensors according to BCast results
202           auto input = input_tensor.shaped<tstring, 2>(bcast.x_reshape());
203           auto output = output_tensor->shaped<tstring, 2>(bcast.result_shape());
204           auto pos_shaped = pos_tensor.shaped<T, 2>(bcast.y_reshape());
205           auto len_shaped = len_tensor.shaped<T, 2>(bcast.y_reshape());
206 
207           // Allocate temporary buffer for broadcasted position tensor
208           Tensor pos_buffer;
209           OP_REQUIRES_OK(context,
210                          context->allocate_temp(DataTypeToEnum<T>::v(),
211                                                 output_shape, &pos_buffer));
212           typename TTypes<T, 2>::Tensor pos_bcast(
213               pos_buffer.shaped<T, 2>(bcast.result_shape()));
214           pos_bcast =
215               pos_shaped.broadcast(BCast::ToIndexArray<2>(bcast.y_bcast()));
216 
217           // Allocate temporary buffer for broadcasted length tensor
218           Tensor len_buffer;
219           OP_REQUIRES_OK(context,
220                          context->allocate_temp(DataTypeToEnum<T>::v(),
221                                                 output_shape, &len_buffer));
222           typename TTypes<T, 2>::Tensor len_bcast(
223               len_buffer.shaped<T, 2>(bcast.result_shape()));
224           len_bcast =
225               len_shaped.broadcast(BCast::ToIndexArray<2>(bcast.y_bcast()));
226 
227           // Iterate through broadcasted tensors and perform substr
228           for (int i = 0; i < output_shape.dim_size(0); ++i) {
229             for (int j = 0; j < output_shape.dim_size(1); ++j) {
230               StringPiece in(input(input.dimension(0) > 1 ? i : 0,
231                                    input.dimension(1) > 1 ? j : 0));
232               const T pos =
233                   tensorflow::internal::SubtleMustCopy(pos_bcast(i, j));
234               const T len =
235                   tensorflow::internal::SubtleMustCopy(len_bcast(i, j));
236               T byte_pos = pos;
237               T byte_len = len;
238               switch (unit_) {
239                 case CharUnit::UTF8_CHAR:
240                   OP_REQUIRES(
241                       context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len),
242                       errors::InvalidArgument("pos ", pos, " out of range for ",
243                                               "string at index ", i));
244                   break;
245                 case CharUnit::BYTE:
246                   byte_pos = AdjustedPosIndex(byte_pos, in);
247                   OP_REQUIRES(
248                       context, FastBoundsCheck(byte_pos, in.size() + 1),
249                       errors::InvalidArgument("pos ", pos, " out of range for ",
250                                               "string b'", in, "' at index (",
251                                               i, ", ", j, ")"));
252               }
253               StringPiece sub_in = in.substr(byte_pos, byte_len);
254               output(i, j).assign(sub_in.data(), sub_in.size());
255             }
256           }
257           break;
258         }
259         default: {
260           context->SetStatus(errors::Unimplemented(
261               "Substr broadcast not implemented for ", ndims, " dimensions"));
262         }
263       }
264     }
265   }
266 
267  private:
268   // This adjusts the requested position. Note it does not perform any bound
269   // checks.
AdjustedPosIndex(const T pos_requested,const StringPiece s)270   static inline T AdjustedPosIndex(const T pos_requested, const StringPiece s) {
271     if (pos_requested < 0) {
272       return s.size() + pos_requested;
273     }
274     return pos_requested;
275   }
276 
277   // Return true if successful; otherwise, return false if the `pos` argument
278   // is out of range in the string.
UpdatePosAndLenForUtf8(const StringPiece in,T * pos,T * len)279   static inline bool UpdatePosAndLenForUtf8(const StringPiece in, T* pos,
280                                             T* len) {
281     if (*pos >= 0) {
282       return UpdatePositivePosAndLenForUtf8(in, *pos, *len, pos, len);
283     } else {
284       return UpdateNegativePosAndLenForUtf8(in, *pos, *len, pos, len);
285     }
286   }
287 
UpdatePositivePosAndLenForUtf8(const StringPiece in,const T pos,const T len,T * char_pos,T * char_len)288   static bool UpdatePositivePosAndLenForUtf8(const StringPiece in, const T pos,
289                                              const T len, T* char_pos,
290                                              T* char_len) {
291     *char_pos = 0;
292     // Determine byte position of the substring start.
293     if (!ForwardNUTF8CharPositions(in, pos, char_pos)) {
294       return false;
295     }
296     // Determine position of the end of the substring.
297     // The length will be capped at the end of the string, and we ignore whether
298     // the string had enough characters to handle it or not.
299     *char_len = *char_pos;
300     ForwardNUTF8CharPositions(in, len, char_len);
301     // The length in bytes is the position end of the substring less the start.
302     *char_len = *char_len - *char_pos;
303     return true;
304   }
305 
306   // This function expects a negative position relative to the end of the
307   // string, but will update the character position to a positive number
308   // relative to the beginning of the string.
UpdateNegativePosAndLenForUtf8(const StringPiece in,const T pos,const T len,T * char_pos,T * char_len)309   static bool UpdateNegativePosAndLenForUtf8(const StringPiece in, const T pos,
310                                              const T len, T* char_pos,
311                                              T* char_len) {
312     // Initially treat the length as position of the end of the substring.
313     *char_len = in.size();
314     // This is the number of character to skip from the end of the string to
315     // arrive at the position where the substring should end.
316     T utf8_chars_to_skip = -pos - len;
317     if (utf8_chars_to_skip < 0) {
318       utf8_chars_to_skip = 0;
319     }
320     // Find the byte position where the substring should end using the computed
321     // number of characters to skip.
322     if (!BackNUTF8CharPositions(in, utf8_chars_to_skip, char_len)) {
323       return false;
324     }
325     // Next, determine where the substring should begin. The number of chars to
326     // skip is the requested position minus the chars we've previously skipped.
327     *char_pos = *char_len;
328     if (!BackNUTF8CharPositions(in, -pos - utf8_chars_to_skip, char_pos)) {
329       return false;
330     }
331     // The length in bytes is the position end of the substring less the start.
332     *char_len = *char_len - *char_pos;
333     return true;
334   }
335 
336   CharUnit unit_ = CharUnit::BYTE;
337 };
338 
339 #define REGISTER_SUBSTR(type)                                      \
340   REGISTER_KERNEL_BUILDER(                                         \
341       Name("Substr").Device(DEVICE_CPU).TypeConstraint<type>("T"), \
342       SubstrOp<type>);
343 REGISTER_SUBSTR(int32);
344 REGISTER_SUBSTR(int64);
345 }  // namespace tensorflow
346