1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #include <string> 17 #include <vector> 18 19 #include "absl/strings/str_split.h" 20 #include "tensorflow/core/framework/common_shape_fns.h" 21 #include "tensorflow/core/framework/op.h" 22 #include "tensorflow/core/framework/shape_inference.h" 23 #include "tensorflow/core/lib/core/errors.h" 24 #include "tensorflow/core/lib/core/status.h" 25 #include "tensorflow/core/lib/strings/strcat.h" 26 #include "tensorflow/core/platform/types.h" 27 28 namespace tensorflow { 29 30 namespace shape_inference { 31 class InferenceContext; 32 } // namespace shape_inference 33 34 using shape_inference::DimensionHandle; 35 using shape_inference::InferenceContext; 36 using shape_inference::ShapeHandle; 37 38 REGISTER_OP("RegexReplace") 39 .Input("input: string") 40 .Input("pattern: string") 41 .Input("rewrite: string") 42 .Output("output: string") 43 .Attr("replace_global: bool = true") __anonb4ab91930102(InferenceContext* c) 44 .SetShapeFn([](InferenceContext* c) { 45 ShapeHandle unused; 46 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 47 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused)); 48 c->set_output(0, c->input(0)); 49 return Status::OK(); 50 }); 51 52 REGISTER_OP("StaticRegexReplace") 53 .Input("input: string") 54 .Attr("pattern: string") 55 .Attr("rewrite: string") 56 .Output("output: string") 57 .Attr("replace_global: bool = true") 58 .SetShapeFn(shape_inference::UnchangedShape); 59 60 REGISTER_OP("RegexFullMatch") 61 .Input("input: string") 62 .Input("pattern: string") 63 .Output("output: bool") __anonb4ab91930202(InferenceContext* c) 64 .SetShapeFn([](InferenceContext* c) { 65 ShapeHandle unused; 66 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 67 c->set_output(0, c->input(0)); 68 return Status::OK(); 69 }); 70 71 REGISTER_OP("StaticRegexFullMatch") 72 .Input("input: string") 73 .Attr("pattern: string") 74 .Output("output: bool") 75 .SetShapeFn(shape_inference::UnchangedShape); 76 77 REGISTER_OP("StringToHashBucketFast") 78 .Input("input: string") 79 .Output("output: int64") 80 .Attr("num_buckets: int >= 1") 81 .SetShapeFn(shape_inference::UnchangedShape); 82 83 REGISTER_OP("StringToHashBucketStrong") 84 .Input("input: string") 85 .Output("output: int64") 86 .Attr("num_buckets: int >= 1") 87 .Attr("key: list(int)") 88 .SetShapeFn(shape_inference::UnchangedShape); 89 90 REGISTER_OP("StringToHashBucket") 91 .Input("string_tensor: string") 92 .Output("output: int64") 93 .Attr("num_buckets: int >= 1") 94 .SetShapeFn(shape_inference::UnchangedShape); 95 96 REGISTER_OP("ReduceJoin") 97 .Input("inputs: string") 98 .Input("reduction_indices: int32") 99 .Attr("keep_dims: bool = false") 100 .Attr("separator: string = ''") 101 .Output("output: string") 102 .SetShapeFn(shape_inference::ReductionShape); 103 104 REGISTER_OP("AsString") 105 .Input("input: T") 106 .Output("output: string") 107 .Attr( 108 "T: {int8, int16, int32, int64, complex64, complex128, float, double, " 109 "bool}") 110 .Attr("precision: int = -1") 111 .Attr("scientific: bool = false") 112 .Attr("shortest: bool = false") 113 .Attr("width: int = -1") 114 .Attr("fill: string = ''") 115 .SetShapeFn(shape_inference::UnchangedShape); 116 117 REGISTER_OP("StringFormat") 118 .Input("inputs: T") 119 .Output("output: string") 120 .Attr("T: list(type) >= 0") 121 .Attr("template: string = '%s'") 122 .Attr("placeholder: string = '%s'") 123 .Attr("summarize: int = 3") __anonb4ab91930302(InferenceContext* c) 124 .SetShapeFn([](InferenceContext* c) { 125 string template_; 126 string placeholder; 127 TF_RETURN_IF_ERROR(c->GetAttr("template", &template_)); 128 TF_RETURN_IF_ERROR(c->GetAttr("placeholder", &placeholder)); 129 130 std::vector<std::string> split_template; 131 split_template = absl::StrSplit(template_, placeholder); 132 int64 num_placeholders = split_template.size() - 1; 133 if (c->num_inputs() != num_placeholders) { 134 return errors::InvalidArgument(strings::StrCat( 135 "num placeholders in template and num inputs must match: ", 136 num_placeholders, " vs. ", c->num_inputs())); 137 } 138 139 c->set_output(0, c->Scalar()); 140 return Status::OK(); 141 }); 142 143 REGISTER_OP("StringJoin") 144 .Input("inputs: N * string") 145 .Attr("N: int") 146 .Attr("separator: string = ''") 147 .Output("output: string") __anonb4ab91930402(InferenceContext* c) 148 .SetShapeFn([](InferenceContext* c) { 149 // If all inputs are scalars, then return a scalar. 150 bool all_scalar = true; 151 for (int i = 0; i < c->num_inputs(); ++i) { 152 if (c->Rank(c->input(i)) != 0) all_scalar = false; 153 } 154 if (all_scalar) { 155 c->set_output(0, c->Scalar()); 156 return Status::OK(); 157 } 158 159 // At least one input is unknown or a scalar. 160 // Merge the non-scalars to find the output shape. 161 // Don't merge inputs with unknown rank, as they can actually be scalars 162 // or the output shape. 163 ShapeHandle out = c->UnknownShape(); 164 for (int i = 0; i < c->num_inputs(); ++i) { 165 if (c->RankKnown(c->input(i)) && c->Rank(c->input(i)) != 0) { 166 TF_RETURN_IF_ERROR(c->Merge(out, c->input(i), &out)); 167 } 168 } 169 c->set_output(0, out); 170 return Status::OK(); 171 }); 172 173 REGISTER_OP("StringSplit") 174 .Input("input: string") 175 .Input("delimiter: string") 176 .Output("indices: int64") 177 .Output("values: string") 178 .Output("shape: int64") 179 .Attr("skip_empty: bool = true") __anonb4ab91930502(InferenceContext* c) 180 .SetShapeFn([](InferenceContext* c) { 181 ShapeHandle unused; 182 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused)); 183 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 184 185 c->set_output(0, c->Matrix(InferenceContext::kUnknownDim, 2)); 186 c->set_output(1, c->Vector(InferenceContext::kUnknownDim)); 187 c->set_output(2, c->Vector(2)); 188 return Status::OK(); 189 }); 190 191 REGISTER_OP("StringSplitV2") 192 .Input("input: string") 193 .Input("sep: string") 194 .Output("indices: int64") 195 .Output("values: string") 196 .Output("shape: int64") 197 .Attr("maxsplit: int = -1") __anonb4ab91930602(InferenceContext* c) 198 .SetShapeFn([](InferenceContext* c) { 199 ShapeHandle unused; 200 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused)); 201 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); 202 203 c->set_output(0, c->Matrix(InferenceContext::kUnknownDim, 2)); 204 c->set_output(1, c->Vector(InferenceContext::kUnknownDim)); 205 c->set_output(2, c->Vector(2)); 206 return Status::OK(); 207 }); 208 209 REGISTER_OP("StringStrip") 210 .Input("input: string") 211 .Output("output: string") 212 .SetShapeFn(shape_inference::UnchangedShape); 213 214 REGISTER_OP("StringLength") 215 .Input("input: string") 216 .Output("output: int32") 217 .Attr("unit: {'BYTE', 'UTF8_CHAR'} = 'BYTE'") 218 .SetShapeFn(shape_inference::UnchangedShape); 219 220 REGISTER_OP("EncodeBase64") 221 .Input("input: string") 222 .Output("output: string") 223 .Attr("pad: bool = false") 224 .SetShapeFn(shape_inference::UnchangedShape); 225 226 REGISTER_OP("DecodeBase64") 227 .Input("input: string") 228 .Output("output: string") 229 .SetShapeFn(shape_inference::UnchangedShape); 230 231 REGISTER_OP("Substr") 232 .Input("input: string") 233 .Input("pos: T") 234 .Input("len: T") 235 .Output("output: string") 236 .Attr("T: {int32, int64}") 237 .Attr("unit: {'BYTE', 'UTF8_CHAR'} = 'BYTE'") __anonb4ab91930702(InferenceContext* c) 238 .SetShapeFn([](InferenceContext* c) { 239 ShapeHandle pos_shape = c->input(1); 240 ShapeHandle len_shape = c->input(2); 241 ShapeHandle unused; 242 // Check that pos/len have same rank 243 TF_RETURN_IF_ERROR(c->WithRank(pos_shape, c->Rank(len_shape), &unused)); 244 // Check that dimensions are equal 245 for (int32 i = 0; i < c->Rank(pos_shape); ++i) { 246 DimensionHandle pos_dim = c->Dim(pos_shape, i); 247 DimensionHandle len_dim = c->Dim(len_shape, i); 248 if (c->Value(pos_dim) != c->Value(len_dim)) { 249 return errors::InvalidArgument( 250 "pos and len shapes must match: ", c->DebugString(pos_shape), 251 " vs. ", c->DebugString(len_shape)); 252 } 253 } 254 // c->input(0) is the ShapeHandle to input strings 255 // BroadcastBinaryOpShapeFn infers shape from c->input(0) and c->input(1). 256 return shape_inference::BroadcastBinaryOpShapeFn(c); 257 }); 258 259 REGISTER_OP("UnicodeScript") 260 .Input("input: int32") 261 .Output("output: int32") 262 .SetShapeFn(shape_inference::UnchangedShape); 263 264 REGISTER_OP("UnicodeEncode") 265 .Input("input_values: int32") 266 .Input("input_splits: int64") 267 .Attr("errors: {'ignore', 'replace', 'strict'} = 'replace'") 268 .Attr("output_encoding: {'UTF-8', 'UTF-16-BE', 'UTF-32-BE'}") 269 .Attr("replacement_char: int = 65533") // 0xFFFD unicode replacement char 270 .Output("output: string") __anonb4ab91930802(InferenceContext* c) 271 .SetShapeFn([](InferenceContext* c) { 272 // Check rank of inner values 273 ShapeHandle input_inner_values_shape = c->input(0); 274 ShapeHandle unused; 275 TF_RETURN_IF_ERROR(c->WithRank(input_inner_values_shape, 1, &unused)); 276 277 // Check rank of input_splits 278 ShapeHandle splits_shape = c->input(1); 279 TF_RETURN_IF_ERROR(c->WithRank(splits_shape, 1, &unused)); 280 281 // Output shape is a 1-D tensor with size equal to number of splits. 282 std::vector<DimensionHandle> dims(1); 283 TF_RETURN_IF_ERROR(c->Subtract(c->Dim(splits_shape, 0), 1, &dims[0])); 284 c->set_output(0, c->MakeShape(dims)); 285 286 return Status::OK(); 287 }); 288 289 REGISTER_OP("UnicodeTranscode") 290 .Input("input: string") 291 .Output("output: string") 292 .Attr("input_encoding: string") 293 .Attr("output_encoding: {'UTF-8', 'UTF-16-BE', 'UTF-32-BE'}") 294 .Attr("errors: {'strict', 'replace', 'ignore'} = 'replace'") 295 .Attr("replacement_char: int = 65533") // 0xFFFD unicode replacement char 296 .Attr("replace_control_characters: bool = false") 297 .SetShapeFn(shape_inference::UnchangedShape); 298 299 REGISTER_OP("UnicodeDecode") 300 .Input("input: string") 301 .Output("row_splits: int64") 302 .Output("char_values: int32") 303 .Attr("input_encoding: string") 304 .Attr("errors: {'strict', 'replace', 'ignore'} = 'replace'") 305 .Attr("replacement_char: int = 65533") // 0xFFFD unicode replacement char 306 .Attr("replace_control_characters: bool = false") __anonb4ab91930902(InferenceContext* c) 307 .SetShapeFn([](InferenceContext* c) { 308 // row_splits.shape == [input.size() + 1] 309 DimensionHandle num_row_splits; 310 DimensionHandle input_size = c->NumElements(c->input(0)); 311 TF_RETURN_IF_ERROR(c->Add(input_size, 1, &num_row_splits)); 312 c->set_output(0, c->Vector(num_row_splits)); 313 314 // char_values.shape == [num_chars] 315 DimensionHandle num_chars = c->UnknownDim(); 316 c->set_output(1, c->Vector(num_chars)); 317 return Status::OK(); 318 }); 319 320 REGISTER_OP("UnicodeDecodeWithOffsets") 321 .Input("input: string") 322 .Output("row_splits: int64") 323 .Output("char_values: int32") 324 .Output("char_to_byte_starts: int64") 325 .Attr("input_encoding: string") 326 .Attr("errors: {'strict', 'replace', 'ignore'} = 'replace'") 327 .Attr("replacement_char: int = 65533") // 0xFFFD unicode replacement char 328 .Attr("replace_control_characters: bool = false") __anonb4ab91930a02(InferenceContext* c) 329 .SetShapeFn([](InferenceContext* c) { 330 // row_splits.shape == [input.size() + 1] 331 DimensionHandle num_row_splits; 332 DimensionHandle input_size = c->NumElements(c->input(0)); 333 TF_RETURN_IF_ERROR(c->Add(input_size, 1, &num_row_splits)); 334 c->set_output(0, c->Vector(num_row_splits)); 335 336 // char_values.shape == offset_values.shape == [num_chars] 337 DimensionHandle num_chars = c->UnknownDim(); 338 c->set_output(1, c->Vector(num_chars)); 339 c->set_output(2, c->Vector(num_chars)); 340 return Status::OK(); 341 }); 342 343 } // namespace tensorflow 344