1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <cstddef> 17 #include <cstdlib> 18 #include <string> 19 20 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" 21 #include "tensorflow/core/framework/bounds_check.h" 22 #include "tensorflow/core/framework/kernel_def_builder.h" 23 #include "tensorflow/core/framework/op.h" 24 #include "tensorflow/core/framework/op_kernel.h" 25 #include "tensorflow/core/framework/tensor.h" 26 #include "tensorflow/core/framework/tensor_shape.h" 27 #include "tensorflow/core/framework/tensor_types.h" 28 #include "tensorflow/core/framework/types.h" 29 #include "tensorflow/core/kernels/string_util.h" 30 #include "tensorflow/core/lib/core/errors.h" 31 #include "tensorflow/core/lib/core/stringpiece.h" 32 #include "tensorflow/core/platform/types.h" 33 #include "tensorflow/core/util/bcast.h" 34 35 namespace tensorflow { 36 37 // Position/length can be 32 or 64-bit integers 38 template <typename T> 39 class SubstrOp : public OpKernel { 40 public: SubstrOp(OpKernelConstruction * ctx)41 explicit SubstrOp(OpKernelConstruction* ctx) : OpKernel(ctx) { 42 string unit; 43 OP_REQUIRES_OK(ctx, ctx->GetAttr("unit", &unit)); 44 OP_REQUIRES_OK(ctx, ParseCharUnit(unit, &unit_)); 45 } 46 Compute(OpKernelContext * context)47 void Compute(OpKernelContext* context) override { 48 // Get inputs 49 const Tensor& input_tensor = context->input(0); 50 const Tensor& pos_tensor = context->input(1); 51 const Tensor& len_tensor = context->input(2); 52 const TensorShape& input_shape = input_tensor.shape(); 53 const TensorShape& pos_shape = pos_tensor.shape(); 54 const TensorShape& len_shape = len_tensor.shape(); 55 OP_REQUIRES(context, (pos_shape == len_shape), 56 errors::InvalidArgument( 57 "pos and len should have the same shape, got: ", 58 pos_shape.DebugString(), " vs. ", len_shape.DebugString())); 59 60 bool is_scalar = TensorShapeUtils::IsScalar(pos_shape); 61 62 if (is_scalar || input_shape == pos_shape) { 63 // pos/len are either scalar or match the shape of input_tensor 64 // Do not need to do broadcasting 65 66 // Reshape input 67 auto input = input_tensor.flat<tstring>(); 68 // Allocate output 69 Tensor* output_tensor = nullptr; 70 OP_REQUIRES_OK(context, 71 context->allocate_output("output", input_tensor.shape(), 72 &output_tensor)); 73 auto output = output_tensor->flat<tstring>(); 74 if (is_scalar) { 75 // Perform Op with scalar pos/len 76 const T pos = 77 tensorflow::internal::SubtleMustCopy(pos_tensor.scalar<T>()()); 78 const T len = 79 tensorflow::internal::SubtleMustCopy(len_tensor.scalar<T>()()); 80 for (size_t i = 0; i < input_tensor.NumElements(); ++i) { 81 StringPiece in(input(i)); 82 T byte_pos = pos; 83 T byte_len = len; 84 switch (unit_) { 85 case CharUnit::UTF8_CHAR: 86 OP_REQUIRES( 87 context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len), 88 errors::InvalidArgument("pos ", pos, " out of range for ", 89 "string at index ", i)); 90 break; 91 case CharUnit::BYTE: 92 byte_pos = AdjustedPosIndex(byte_pos, in); 93 OP_REQUIRES( 94 context, FastBoundsCheck(byte_pos, in.size() + 1), 95 errors::InvalidArgument("pos ", pos, " out of range for ", 96 "string b'", in, "' at index ", i)); 97 } 98 StringPiece sub_in = in.substr(byte_pos, byte_len); 99 output(i).assign(sub_in.data(), sub_in.size()); 100 } 101 } else { 102 // Perform Op element-wise with tensor pos/len 103 auto pos_flat = pos_tensor.flat<T>(); 104 auto len_flat = len_tensor.flat<T>(); 105 for (size_t i = 0; i < input_tensor.NumElements(); ++i) { 106 StringPiece in(input(i)); 107 const T pos = tensorflow::internal::SubtleMustCopy(pos_flat(i)); 108 const T len = tensorflow::internal::SubtleMustCopy(len_flat(i)); 109 T byte_pos = pos; 110 T byte_len = len; 111 switch (unit_) { 112 case CharUnit::UTF8_CHAR: 113 OP_REQUIRES( 114 context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len), 115 errors::InvalidArgument("pos ", pos, " out of range for ", 116 "string at index ", i)); 117 break; 118 case CharUnit::BYTE: 119 byte_pos = AdjustedPosIndex(byte_pos, in); 120 OP_REQUIRES( 121 context, FastBoundsCheck(byte_pos, in.size() + 1), 122 errors::InvalidArgument("pos ", pos, " out of range for ", 123 "string b'", in, "' at index ", i)); 124 } 125 StringPiece sub_in = in.substr(byte_pos, byte_len); 126 output(i).assign(sub_in.data(), sub_in.size()); 127 } 128 } 129 } else { 130 // Perform op with broadcasting 131 // TODO: Use ternary broadcasting for once available in Eigen. Current 132 // implementation iterates through broadcasted ops element-wise; 133 // this should be parallelized. 134 135 // Create BCast helper with shape of input and pos/len 136 BCast bcast(BCast::FromShape(input_shape), BCast::FromShape(pos_shape)); 137 OP_REQUIRES(context, bcast.IsValid(), 138 errors::InvalidArgument( 139 "Incompatible shapes: ", input_shape.DebugString(), 140 " vs. ", pos_shape.DebugString())); 141 TensorShape output_shape = BCast::ToShape(bcast.result_shape()); 142 int ndims = output_shape.dims(); 143 Tensor* output_tensor = nullptr; 144 OP_REQUIRES_OK(context, context->allocate_output("output", output_shape, 145 &output_tensor)); 146 switch (ndims) { 147 case 1: { 148 // Reshape tensors according to BCast results 149 auto input = input_tensor.shaped<tstring, 1>(bcast.x_reshape()); 150 auto output = output_tensor->shaped<tstring, 1>(bcast.result_shape()); 151 auto pos_shaped = pos_tensor.shaped<T, 1>(bcast.y_reshape()); 152 auto len_shaped = len_tensor.shaped<T, 1>(bcast.y_reshape()); 153 154 // Allocate temporary buffer for broadcasted position tensor 155 Tensor pos_buffer; 156 OP_REQUIRES_OK(context, 157 context->allocate_temp(DataTypeToEnum<T>::v(), 158 output_shape, &pos_buffer)); 159 typename TTypes<T, 1>::Tensor pos_bcast( 160 pos_buffer.shaped<T, 1>(bcast.result_shape())); 161 pos_bcast = 162 pos_shaped.broadcast(BCast::ToIndexArray<1>(bcast.y_bcast())); 163 164 // Allocate temporary buffer for broadcasted length tensor 165 Tensor len_buffer; 166 OP_REQUIRES_OK(context, 167 context->allocate_temp(DataTypeToEnum<T>::v(), 168 output_shape, &len_buffer)); 169 typename TTypes<T, 1>::Tensor len_bcast( 170 len_buffer.shaped<T, 1>(bcast.result_shape())); 171 len_bcast = 172 len_shaped.broadcast(BCast::ToIndexArray<1>(bcast.y_bcast())); 173 174 // Iterate through broadcasted tensors and perform substr 175 for (int i = 0; i < output_shape.dim_size(0); ++i) { 176 StringPiece in(input(input.dimension(0) > 1 ? i : 0)); 177 const T pos = tensorflow::internal::SubtleMustCopy(pos_bcast(i)); 178 const T len = tensorflow::internal::SubtleMustCopy(len_bcast(i)); 179 T byte_pos = pos; 180 T byte_len = len; 181 switch (unit_) { 182 case CharUnit::UTF8_CHAR: 183 OP_REQUIRES( 184 context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len), 185 errors::InvalidArgument("pos ", pos, " out of range for ", 186 "string at index ", i)); 187 break; 188 case CharUnit::BYTE: 189 byte_pos = AdjustedPosIndex(byte_pos, in); 190 OP_REQUIRES( 191 context, FastBoundsCheck(byte_pos, in.size() + 1), 192 errors::InvalidArgument("pos ", pos, " out of range for ", 193 "string b'", in, "' at index ", i)); 194 } 195 StringPiece sub_in = in.substr(byte_pos, byte_len); 196 output(i).assign(sub_in.data(), sub_in.size()); 197 } 198 break; 199 } 200 case 2: { 201 // Reshape tensors according to BCast results 202 auto input = input_tensor.shaped<tstring, 2>(bcast.x_reshape()); 203 auto output = output_tensor->shaped<tstring, 2>(bcast.result_shape()); 204 auto pos_shaped = pos_tensor.shaped<T, 2>(bcast.y_reshape()); 205 auto len_shaped = len_tensor.shaped<T, 2>(bcast.y_reshape()); 206 207 // Allocate temporary buffer for broadcasted position tensor 208 Tensor pos_buffer; 209 OP_REQUIRES_OK(context, 210 context->allocate_temp(DataTypeToEnum<T>::v(), 211 output_shape, &pos_buffer)); 212 typename TTypes<T, 2>::Tensor pos_bcast( 213 pos_buffer.shaped<T, 2>(bcast.result_shape())); 214 pos_bcast = 215 pos_shaped.broadcast(BCast::ToIndexArray<2>(bcast.y_bcast())); 216 217 // Allocate temporary buffer for broadcasted length tensor 218 Tensor len_buffer; 219 OP_REQUIRES_OK(context, 220 context->allocate_temp(DataTypeToEnum<T>::v(), 221 output_shape, &len_buffer)); 222 typename TTypes<T, 2>::Tensor len_bcast( 223 len_buffer.shaped<T, 2>(bcast.result_shape())); 224 len_bcast = 225 len_shaped.broadcast(BCast::ToIndexArray<2>(bcast.y_bcast())); 226 227 // Iterate through broadcasted tensors and perform substr 228 for (int i = 0; i < output_shape.dim_size(0); ++i) { 229 for (int j = 0; j < output_shape.dim_size(1); ++j) { 230 StringPiece in(input(input.dimension(0) > 1 ? i : 0, 231 input.dimension(1) > 1 ? j : 0)); 232 const T pos = 233 tensorflow::internal::SubtleMustCopy(pos_bcast(i, j)); 234 const T len = 235 tensorflow::internal::SubtleMustCopy(len_bcast(i, j)); 236 T byte_pos = pos; 237 T byte_len = len; 238 switch (unit_) { 239 case CharUnit::UTF8_CHAR: 240 OP_REQUIRES( 241 context, UpdatePosAndLenForUtf8(in, &byte_pos, &byte_len), 242 errors::InvalidArgument("pos ", pos, " out of range for ", 243 "string at index ", i)); 244 break; 245 case CharUnit::BYTE: 246 byte_pos = AdjustedPosIndex(byte_pos, in); 247 OP_REQUIRES( 248 context, FastBoundsCheck(byte_pos, in.size() + 1), 249 errors::InvalidArgument("pos ", pos, " out of range for ", 250 "string b'", in, "' at index (", 251 i, ", ", j, ")")); 252 } 253 StringPiece sub_in = in.substr(byte_pos, byte_len); 254 output(i, j).assign(sub_in.data(), sub_in.size()); 255 } 256 } 257 break; 258 } 259 default: { 260 context->SetStatus(errors::Unimplemented( 261 "Substr broadcast not implemented for ", ndims, " dimensions")); 262 } 263 } 264 } 265 } 266 267 private: 268 // This adjusts the requested position. Note it does not perform any bound 269 // checks. AdjustedPosIndex(const T pos_requested,const StringPiece s)270 static inline T AdjustedPosIndex(const T pos_requested, const StringPiece s) { 271 if (pos_requested < 0) { 272 return s.size() + pos_requested; 273 } 274 return pos_requested; 275 } 276 277 // Return true if successful; otherwise, return false if the `pos` argument 278 // is out of range in the string. UpdatePosAndLenForUtf8(const StringPiece in,T * pos,T * len)279 static inline bool UpdatePosAndLenForUtf8(const StringPiece in, T* pos, 280 T* len) { 281 if (*pos >= 0) { 282 return UpdatePositivePosAndLenForUtf8(in, *pos, *len, pos, len); 283 } else { 284 return UpdateNegativePosAndLenForUtf8(in, *pos, *len, pos, len); 285 } 286 } 287 UpdatePositivePosAndLenForUtf8(const StringPiece in,const T pos,const T len,T * char_pos,T * char_len)288 static bool UpdatePositivePosAndLenForUtf8(const StringPiece in, const T pos, 289 const T len, T* char_pos, 290 T* char_len) { 291 *char_pos = 0; 292 // Determine byte position of the substring start. 293 if (!ForwardNUTF8CharPositions(in, pos, char_pos)) { 294 return false; 295 } 296 // Determine position of the end of the substring. 297 // The length will be capped at the end of the string, and we ignore whether 298 // the string had enough characters to handle it or not. 299 *char_len = *char_pos; 300 ForwardNUTF8CharPositions(in, len, char_len); 301 // The length in bytes is the position end of the substring less the start. 302 *char_len = *char_len - *char_pos; 303 return true; 304 } 305 306 // This function expects a negative position relative to the end of the 307 // string, but will update the character position to a positive number 308 // relative to the beginning of the string. UpdateNegativePosAndLenForUtf8(const StringPiece in,const T pos,const T len,T * char_pos,T * char_len)309 static bool UpdateNegativePosAndLenForUtf8(const StringPiece in, const T pos, 310 const T len, T* char_pos, 311 T* char_len) { 312 // Initially treat the length as position of the end of the substring. 313 *char_len = in.size(); 314 // This is the number of character to skip from the end of the string to 315 // arrive at the position where the substring should end. 316 T utf8_chars_to_skip = -pos - len; 317 if (utf8_chars_to_skip < 0) { 318 utf8_chars_to_skip = 0; 319 } 320 // Find the byte position where the substring should end using the computed 321 // number of characters to skip. 322 if (!BackNUTF8CharPositions(in, utf8_chars_to_skip, char_len)) { 323 return false; 324 } 325 // Next, determine where the substring should begin. The number of chars to 326 // skip is the requested position minus the chars we've previously skipped. 327 *char_pos = *char_len; 328 if (!BackNUTF8CharPositions(in, -pos - utf8_chars_to_skip, char_pos)) { 329 return false; 330 } 331 // The length in bytes is the position end of the substring less the start. 332 *char_len = *char_len - *char_pos; 333 return true; 334 } 335 336 CharUnit unit_ = CharUnit::BYTE; 337 }; 338 339 #define REGISTER_SUBSTR(type) \ 340 REGISTER_KERNEL_BUILDER( \ 341 Name("Substr").Device(DEVICE_CPU).TypeConstraint<type>("T"), \ 342 SubstrOp<type>); 343 REGISTER_SUBSTR(int32); 344 REGISTER_SUBSTR(int64); 345 } // namespace tensorflow 346