1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <string> 17 #include <vector> 18 19 #include "absl/strings/str_split.h" 20 #include "tensorflow/core/framework/common_shape_fns.h" 21 #include "tensorflow/core/framework/op.h" 22 #include "tensorflow/core/framework/shape_inference.h" 23 #include "tensorflow/core/lib/core/errors.h" 24 #include "tensorflow/core/lib/core/status.h" 25 #include "tensorflow/core/lib/strings/strcat.h" 26 #include "tensorflow/core/platform/types.h" 27 28 namespace tensorflow { 29 30 namespace shape_inference { 31 class InferenceContext; 32 } // namespace shape_inference 33 34 using shape_inference::DimensionHandle; 35 using shape_inference::InferenceContext; 36 using shape_inference::ShapeHandle; 37 38 REGISTER_OP("RegexReplace") 39 .Input("input: string") 40 .Input("pattern: string") 41 .Input("rewrite: string") 42 .Output("output: string") 43 .Attr("replace_global: bool = true") __anon4354d5f10102(InferenceContext* c) 44 .SetShapeFn([](InferenceContext* c) { 45 ShapeHandle unused; 46 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 47 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 48 c->set_output(0, c->input(0)); 49 return Status::OK(); 50 }); 51 52 REGISTER_OP("StaticRegexReplace") 53 .Input("input: string") 54 .Attr("pattern: string") 55 .Attr("rewrite: string") 56 .Output("output: string") 57 .Attr("replace_global: bool = true") 58 .SetShapeFn(shape_inference::UnchangedShape); 59 60 REGISTER_OP("RegexFullMatch") 61 .Input("input: string") 62 .Input("pattern: string") 63 .Output("output: bool") __anon4354d5f10202(InferenceContext* c) 64 .SetShapeFn([](InferenceContext* c) { 65 ShapeHandle unused; 66 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 67 c->set_output(0, c->input(0)); 68 return Status::OK(); 69 }); 70 71 REGISTER_OP("StaticRegexFullMatch") 72 .Input("input: string") 73 .Attr("pattern: string") 74 .Output("output: bool") 75 .SetShapeFn(shape_inference::UnchangedShape); 76 77 REGISTER_OP("StringToHashBucketFast") 78 .Input("input: string") 79 .Output("output: int64") 80 .Attr("num_buckets: int >= 1") 81 .SetShapeFn(shape_inference::UnchangedShape); 82 83 REGISTER_OP("StringToHashBucketStrong") 84 .Input("input: string") 85 .Output("output: int64") 86 .Attr("num_buckets: int >= 1") 87 .Attr("key: list(int)") 88 .SetShapeFn(shape_inference::UnchangedShape); 89 90 REGISTER_OP("StringToHashBucket") 91 .Input("string_tensor: string") 92 .Output("output: int64") 93 .Attr("num_buckets: int >= 1") 94 .SetShapeFn(shape_inference::UnchangedShape); 95 96 REGISTER_OP("ReduceJoin") 97 .Input("inputs: string") 98 .Input("reduction_indices: int32") 99 .Attr("keep_dims: bool = false") 100 .Attr("separator: string = ''") 101 .Output("output: string") 102 .SetShapeFn(shape_inference::ReductionShape); 103 104 REGISTER_OP("UnsortedSegmentJoin") 105 .Input("inputs: string") 106 .Input("segment_ids: Tindices") 107 .Input("num_segments: Tnumsegments") 108 .Attr("separator: string = ''") 109 .Attr("Tindices: {int32,int64}") 110 .Attr("Tnumsegments: {int32,int64} = DT_INT32") 111 .Output("output: string") 112 .SetShapeFn(shape_inference::UnsortedSegmentReductionShapeFn); 113 114 REGISTER_OP("AsString") 115 .Input("input: T") 116 .Output("output: string") 117 .Attr( 118 "T: {int8, int16, int32, int64, complex64, complex128, float, double, " 119 "bool, variant}") 120 .Attr("precision: int = -1") 121 .Attr("scientific: bool = false") 122 .Attr("shortest: bool = false") 123 .Attr("width: int = -1") 124 .Attr("fill: string = ''") 125 .SetShapeFn(shape_inference::UnchangedShape); 126 127 REGISTER_OP("StringFormat") 128 .Input("inputs: T") 129 .Output("output: string") 130 .Attr("T: list(type) >= 0") 131 .Attr("template: string = '%s'") 132 .Attr("placeholder: string = '%s'") 133 .Attr("summarize: int = 3") __anon4354d5f10302(InferenceContext* c) 134 .SetShapeFn([](InferenceContext* c) { 135 string template_; 136 string placeholder; 137 TF_RETURN_IF_ERROR(c->GetAttr("template", &template_)); 138 TF_RETURN_IF_ERROR(c->GetAttr("placeholder", &placeholder)); 139 140 std::vector<std::string> split_template; 141 split_template = absl::StrSplit(template_, placeholder); 142 int64 num_placeholders = split_template.size() - 1; 143 if (c->num_inputs() != num_placeholders) { 144 return errors::InvalidArgument(strings::StrCat( 145 "num placeholders in template and num inputs must match: ", 146 num_placeholders, " vs. ", c->num_inputs())); 147 } 148 149 c->set_output(0, c->Scalar()); 150 return Status::OK(); 151 }); 152 153 REGISTER_OP("StringJoin") 154 .Input("inputs: N * string") 155 .Attr("N: int") 156 .Attr("separator: string = ''") 157 .Output("output: string") __anon4354d5f10402(InferenceContext* c) 158 .SetShapeFn([](InferenceContext* c) { 159 // If all inputs are scalars, then return a scalar. 160 bool all_scalar = true; 161 for (int i = 0; i < c->num_inputs(); ++i) { 162 if (c->Rank(c->input(i)) != 0) all_scalar = false; 163 } 164 if (all_scalar) { 165 c->set_output(0, c->Scalar()); 166 return Status::OK(); 167 } 168 169 // At least one input is unknown or a scalar. 170 // Merge the non-scalars to find the output shape. 171 // Don't merge inputs with unknown rank, as they can actually be scalars 172 // or the output shape. 173 ShapeHandle out = c->UnknownShape(); 174 for (int i = 0; i < c->num_inputs(); ++i) { 175 if (c->RankKnown(c->input(i)) && c->Rank(c->input(i)) != 0) { 176 TF_RETURN_IF_ERROR(c->Merge(out, c->input(i), &out)); 177 } 178 } 179 c->set_output(0, out); 180 return Status::OK(); 181 }); 182 183 REGISTER_OP("StringSplit") 184 .Input("input: string") 185 .Input("delimiter: string") 186 .Output("indices: int64") 187 .Output("values: string") 188 .Output("shape: int64") 189 .Attr("skip_empty: bool = true") __anon4354d5f10502(InferenceContext* c) 190 .SetShapeFn([](InferenceContext* c) { 191 ShapeHandle unused; 192 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused)); 193 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 194 195 c->set_output(0, c->Matrix(InferenceContext::kUnknownDim, 2)); 196 c->set_output(1, c->Vector(InferenceContext::kUnknownDim)); 197 c->set_output(2, c->Vector(2)); 198 return Status::OK(); 199 }); 200 201 REGISTER_OP("StringSplitV2") 202 .Input("input: string") 203 .Input("sep: string") 204 .Output("indices: int64") 205 .Output("values: string") 206 .Output("shape: int64") 207 .Attr("maxsplit: int = -1") __anon4354d5f10602(InferenceContext* c) 208 .SetShapeFn([](InferenceContext* c) { 209 ShapeHandle unused; 210 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused)); 211 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 212 213 c->set_output(0, c->Matrix(InferenceContext::kUnknownDim, 2)); 214 c->set_output(1, c->Vector(InferenceContext::kUnknownDim)); 215 c->set_output(2, c->Vector(2)); 216 return Status::OK(); 217 }); 218 219 REGISTER_OP("StringLower") 220 .Input("input: string") 221 .Output("output: string") 222 .Attr("encoding: string =''") 223 .SetShapeFn(shape_inference::UnchangedShape); 224 225 REGISTER_OP("StringUpper") 226 .Input("input: string") 227 .Output("output: string") 228 .Attr("encoding: string =''") 229 .SetShapeFn(shape_inference::UnchangedShape); 230 231 REGISTER_OP("StringStrip") 232 .Input("input: string") 233 .Output("output: string") 234 .SetShapeFn(shape_inference::UnchangedShape); 235 236 REGISTER_OP("StringLength") 237 .Input("input: string") 238 .Output("output: int32") 239 .Attr("unit: {'BYTE', 'UTF8_CHAR'} = 'BYTE'") 240 .SetShapeFn(shape_inference::UnchangedShape); 241 242 REGISTER_OP("EncodeBase64") 243 .Input("input: string") 244 .Output("output: string") 245 .Attr("pad: bool = false") 246 .SetShapeFn(shape_inference::UnchangedShape); 247 248 REGISTER_OP("DecodeBase64") 249 .Input("input: string") 250 .Output("output: string") 251 .SetShapeFn(shape_inference::UnchangedShape); 252 253 REGISTER_OP("Substr") 254 .Input("input: string") 255 .Input("pos: T") 256 .Input("len: T") 257 .Output("output: string") 258 .Attr("T: {int32, int64}") 259 .Attr("unit: {'BYTE', 'UTF8_CHAR'} = 'BYTE'") __anon4354d5f10702(InferenceContext* c) 260 .SetShapeFn([](InferenceContext* c) { 261 ShapeHandle pos_shape = c->input(1); 262 ShapeHandle len_shape = c->input(2); 263 ShapeHandle unused; 264 // If len rank is known, check that pos and len have the same rank 265 if (c->RankKnown(len_shape)) { 266 TF_RETURN_IF_ERROR(c->WithRank(pos_shape, c->Rank(len_shape), &unused)); 267 } 268 // Check that dimensions are equal 269 for (int32 i = 0; i < c->Rank(pos_shape); ++i) { 270 DimensionHandle pos_dim = c->Dim(pos_shape, i); 271 DimensionHandle len_dim = c->Dim(len_shape, i); 272 if (c->Value(pos_dim) != c->Value(len_dim)) { 273 return errors::InvalidArgument( 274 "pos and len shapes must match: ", c->DebugString(pos_shape), 275 " vs. ", c->DebugString(len_shape)); 276 } 277 } 278 // c->input(0) is the ShapeHandle to input strings 279 // BroadcastBinaryOpShapeFn infers shape from c->input(0) and c->input(1). 280 return shape_inference::BroadcastBinaryOpShapeFn(c); 281 }); 282 283 REGISTER_OP("UnicodeScript") 284 .Input("input: int32") 285 .Output("output: int32") 286 .SetShapeFn(shape_inference::UnchangedShape); 287 288 REGISTER_OP("UnicodeEncode") 289 .Input("input_values: int32") 290 .Input("input_splits: Tsplits") 291 .Attr("errors: {'ignore', 'replace', 'strict'} = 'replace'") 292 .Attr("output_encoding: {'UTF-8', 'UTF-16-BE', 'UTF-32-BE'}") 293 .Attr("replacement_char: int = 65533") // 0xFFFD unicode replacement char 294 .Attr("Tsplits: {int32, int64} = DT_INT64") 295 .Output("output: string") __anon4354d5f10802(InferenceContext* c) 296 .SetShapeFn([](InferenceContext* c) { 297 // Check rank of inner values 298 ShapeHandle input_inner_values_shape = c->input(0); 299 ShapeHandle unused; 300 TF_RETURN_IF_ERROR(c->WithRank(input_inner_values_shape, 1, &unused)); 301 302 // Check rank of input_splits 303 ShapeHandle splits_shape = c->input(1); 304 TF_RETURN_IF_ERROR(c->WithRank(splits_shape, 1, &unused)); 305 306 // Output shape is a 1-D tensor with size equal to number of splits. 307 std::vector<DimensionHandle> dims(1); 308 TF_RETURN_IF_ERROR(c->Subtract(c->Dim(splits_shape, 0), 1, &dims[0])); 309 c->set_output(0, c->MakeShape(dims)); 310 311 return Status::OK(); 312 }); 313 314 REGISTER_OP("UnicodeTranscode") 315 .Input("input: string") 316 .Output("output: string") 317 .Attr("input_encoding: string") 318 .Attr("output_encoding: {'UTF-8', 'UTF-16-BE', 'UTF-32-BE'}") 319 .Attr("errors: {'strict', 'replace', 'ignore'} = 'replace'") 320 .Attr("replacement_char: int = 65533") // 0xFFFD unicode replacement char 321 .Attr("replace_control_characters: bool = false") 322 .SetShapeFn(shape_inference::UnchangedShape); 323 324 REGISTER_OP("UnicodeDecode") 325 .Input("input: string") 326 .Output("row_splits: Tsplits") 327 .Output("char_values: int32") 328 .Attr("input_encoding: string") 329 .Attr("errors: {'strict', 'replace', 'ignore'} = 'replace'") 330 .Attr("replacement_char: int = 65533") // 0xFFFD unicode replacement char 331 .Attr("replace_control_characters: bool = false") 332 .Attr("Tsplits: {int32, int64} = DT_INT64") __anon4354d5f10902(InferenceContext* c) 333 .SetShapeFn([](InferenceContext* c) { 334 // row_splits.shape == [input.size() + 1] 335 DimensionHandle num_row_splits; 336 DimensionHandle input_size = c->NumElements(c->input(0)); 337 TF_RETURN_IF_ERROR(c->Add(input_size, 1, &num_row_splits)); 338 c->set_output(0, c->Vector(num_row_splits)); 339 340 // char_values.shape == [num_chars] 341 DimensionHandle num_chars = c->UnknownDim(); 342 c->set_output(1, c->Vector(num_chars)); 343 return Status::OK(); 344 }); 345 346 REGISTER_OP("UnicodeDecodeWithOffsets") 347 .Input("input: string") 348 .Output("row_splits: Tsplits") 349 .Output("char_values: int32") 350 .Output("char_to_byte_starts: int64") 351 .Attr("input_encoding: string") 352 .Attr("errors: {'strict', 'replace', 'ignore'} = 'replace'") 353 .Attr("replacement_char: int = 65533") // 0xFFFD unicode replacement char 354 .Attr("replace_control_characters: bool = false") 355 .Attr("Tsplits: {int32, int64} = DT_INT64") __anon4354d5f10a02(InferenceContext* c) 356 .SetShapeFn([](InferenceContext* c) { 357 // row_splits.shape == [input.size() + 1] 358 DimensionHandle num_row_splits; 359 DimensionHandle input_size = c->NumElements(c->input(0)); 360 TF_RETURN_IF_ERROR(c->Add(input_size, 1, &num_row_splits)); 361 c->set_output(0, c->Vector(num_row_splits)); 362 363 // char_values.shape == offset_values.shape == [num_chars] 364 DimensionHandle num_chars = c->UnknownDim(); 365 c->set_output(1, c->Vector(num_chars)); 366 c->set_output(2, c->Vector(num_chars)); 367 return Status::OK(); 368 }); 369 370 REGISTER_OP("StringNGrams") 371 .Attr("separator: string") 372 .Attr("ngram_widths: list(int) >= 0") 373 .Attr("left_pad: string") 374 .Attr("right_pad: string") 375 .Attr("pad_width: int") 376 .Attr("preserve_short_sequences: bool") 377 .Attr("Tsplits: {int32, int64} = DT_INT64") 378 .Input("data: string") 379 .Input("data_splits: Tsplits") 380 .Output("ngrams: string") 381 .Output("ngrams_splits: Tsplits") __anon4354d5f10b02(InferenceContext* c) 382 .SetShapeFn([](InferenceContext* c) { 383 c->set_output(0, c->UnknownShapeOfRank(1)); 384 ShapeHandle data = c->input(0); 385 TF_RETURN_IF_ERROR(c->WithRank(data, 1, &data)); 386 ShapeHandle data_splits = c->input(1); 387 TF_RETURN_IF_ERROR(c->WithRank(data_splits, 1, &data_splits)); 388 c->set_output(1, data_splits); 389 return Status::OK(); 390 }); 391 392 } // namespace tensorflow 393