• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <string>
17 #include <vector>
18 
19 #include "absl/strings/str_split.h"
20 #include "tensorflow/core/framework/common_shape_fns.h"
21 #include "tensorflow/core/framework/op.h"
22 #include "tensorflow/core/framework/shape_inference.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/lib/core/status.h"
25 #include "tensorflow/core/lib/strings/strcat.h"
26 #include "tensorflow/core/platform/types.h"
27 
28 namespace tensorflow {
29 
30 namespace shape_inference {
31 class InferenceContext;
32 }  // namespace shape_inference
33 
34 using shape_inference::DimensionHandle;
35 using shape_inference::InferenceContext;
36 using shape_inference::ShapeHandle;
37 
38 REGISTER_OP("RegexReplace")
39     .Input("input: string")
40     .Input("pattern: string")
41     .Input("rewrite: string")
42     .Output("output: string")
43     .Attr("replace_global: bool = true")
__anon4354d5f10102(InferenceContext* c) 44     .SetShapeFn([](InferenceContext* c) {
45       ShapeHandle unused;
46       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
47       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
48       c->set_output(0, c->input(0));
49       return Status::OK();
50     });
51 
52 REGISTER_OP("StaticRegexReplace")
53     .Input("input: string")
54     .Attr("pattern: string")
55     .Attr("rewrite: string")
56     .Output("output: string")
57     .Attr("replace_global: bool = true")
58     .SetShapeFn(shape_inference::UnchangedShape);
59 
60 REGISTER_OP("RegexFullMatch")
61     .Input("input: string")
62     .Input("pattern: string")
63     .Output("output: bool")
__anon4354d5f10202(InferenceContext* c) 64     .SetShapeFn([](InferenceContext* c) {
65       ShapeHandle unused;
66       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
67       c->set_output(0, c->input(0));
68       return Status::OK();
69     });
70 
71 REGISTER_OP("StaticRegexFullMatch")
72     .Input("input: string")
73     .Attr("pattern: string")
74     .Output("output: bool")
75     .SetShapeFn(shape_inference::UnchangedShape);
76 
77 REGISTER_OP("StringToHashBucketFast")
78     .Input("input: string")
79     .Output("output: int64")
80     .Attr("num_buckets: int >= 1")
81     .SetShapeFn(shape_inference::UnchangedShape);
82 
83 REGISTER_OP("StringToHashBucketStrong")
84     .Input("input: string")
85     .Output("output: int64")
86     .Attr("num_buckets: int >= 1")
87     .Attr("key: list(int)")
88     .SetShapeFn(shape_inference::UnchangedShape);
89 
90 REGISTER_OP("StringToHashBucket")
91     .Input("string_tensor: string")
92     .Output("output: int64")
93     .Attr("num_buckets: int >= 1")
94     .SetShapeFn(shape_inference::UnchangedShape);
95 
96 REGISTER_OP("ReduceJoin")
97     .Input("inputs: string")
98     .Input("reduction_indices: int32")
99     .Attr("keep_dims: bool = false")
100     .Attr("separator: string = ''")
101     .Output("output: string")
102     .SetShapeFn(shape_inference::ReductionShape);
103 
104 REGISTER_OP("UnsortedSegmentJoin")
105     .Input("inputs: string")
106     .Input("segment_ids: Tindices")
107     .Input("num_segments: Tnumsegments")
108     .Attr("separator: string = ''")
109     .Attr("Tindices: {int32,int64}")
110     .Attr("Tnumsegments: {int32,int64} = DT_INT32")
111     .Output("output: string")
112     .SetShapeFn(shape_inference::UnsortedSegmentReductionShapeFn);
113 
114 REGISTER_OP("AsString")
115     .Input("input: T")
116     .Output("output: string")
117     .Attr(
118         "T: {int8, int16, int32, int64, complex64, complex128, float, double, "
119         "bool, variant}")
120     .Attr("precision: int = -1")
121     .Attr("scientific: bool = false")
122     .Attr("shortest: bool = false")
123     .Attr("width: int = -1")
124     .Attr("fill: string = ''")
125     .SetShapeFn(shape_inference::UnchangedShape);
126 
127 REGISTER_OP("StringFormat")
128     .Input("inputs: T")
129     .Output("output: string")
130     .Attr("T: list(type) >= 0")
131     .Attr("template: string = '%s'")
132     .Attr("placeholder: string = '%s'")
133     .Attr("summarize: int = 3")
__anon4354d5f10302(InferenceContext* c) 134     .SetShapeFn([](InferenceContext* c) {
135       string template_;
136       string placeholder;
137       TF_RETURN_IF_ERROR(c->GetAttr("template", &template_));
138       TF_RETURN_IF_ERROR(c->GetAttr("placeholder", &placeholder));
139 
140       std::vector<std::string> split_template;
141       split_template = absl::StrSplit(template_, placeholder);
142       int64 num_placeholders = split_template.size() - 1;
143       if (c->num_inputs() != num_placeholders) {
144         return errors::InvalidArgument(strings::StrCat(
145             "num placeholders in template and num inputs must match: ",
146             num_placeholders, " vs. ", c->num_inputs()));
147       }
148 
149       c->set_output(0, c->Scalar());
150       return Status::OK();
151     });
152 
153 REGISTER_OP("StringJoin")
154     .Input("inputs: N * string")
155     .Attr("N: int")
156     .Attr("separator: string = ''")
157     .Output("output: string")
__anon4354d5f10402(InferenceContext* c) 158     .SetShapeFn([](InferenceContext* c) {
159       // If all inputs are scalars, then return a scalar.
160       bool all_scalar = true;
161       for (int i = 0; i < c->num_inputs(); ++i) {
162         if (c->Rank(c->input(i)) != 0) all_scalar = false;
163       }
164       if (all_scalar) {
165         c->set_output(0, c->Scalar());
166         return Status::OK();
167       }
168 
169       // At least one input is unknown or a scalar.
170       // Merge the non-scalars to find the output shape.
171       // Don't merge inputs with unknown rank, as they can actually be scalars
172       // or the output shape.
173       ShapeHandle out = c->UnknownShape();
174       for (int i = 0; i < c->num_inputs(); ++i) {
175         if (c->RankKnown(c->input(i)) && c->Rank(c->input(i)) != 0) {
176           TF_RETURN_IF_ERROR(c->Merge(out, c->input(i), &out));
177         }
178       }
179       c->set_output(0, out);
180       return Status::OK();
181     });
182 
183 REGISTER_OP("StringSplit")
184     .Input("input: string")
185     .Input("delimiter: string")
186     .Output("indices: int64")
187     .Output("values: string")
188     .Output("shape: int64")
189     .Attr("skip_empty: bool = true")
__anon4354d5f10502(InferenceContext* c) 190     .SetShapeFn([](InferenceContext* c) {
191       ShapeHandle unused;
192       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
193       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
194 
195       c->set_output(0, c->Matrix(InferenceContext::kUnknownDim, 2));
196       c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
197       c->set_output(2, c->Vector(2));
198       return Status::OK();
199     });
200 
201 REGISTER_OP("StringSplitV2")
202     .Input("input: string")
203     .Input("sep: string")
204     .Output("indices: int64")
205     .Output("values: string")
206     .Output("shape: int64")
207     .Attr("maxsplit: int = -1")
__anon4354d5f10602(InferenceContext* c) 208     .SetShapeFn([](InferenceContext* c) {
209       ShapeHandle unused;
210       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
211       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
212 
213       c->set_output(0, c->Matrix(InferenceContext::kUnknownDim, 2));
214       c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
215       c->set_output(2, c->Vector(2));
216       return Status::OK();
217     });
218 
219 REGISTER_OP("StringLower")
220     .Input("input: string")
221     .Output("output: string")
222     .Attr("encoding: string =''")
223     .SetShapeFn(shape_inference::UnchangedShape);
224 
225 REGISTER_OP("StringUpper")
226     .Input("input: string")
227     .Output("output: string")
228     .Attr("encoding: string =''")
229     .SetShapeFn(shape_inference::UnchangedShape);
230 
231 REGISTER_OP("StringStrip")
232     .Input("input: string")
233     .Output("output: string")
234     .SetShapeFn(shape_inference::UnchangedShape);
235 
236 REGISTER_OP("StringLength")
237     .Input("input: string")
238     .Output("output: int32")
239     .Attr("unit: {'BYTE', 'UTF8_CHAR'} = 'BYTE'")
240     .SetShapeFn(shape_inference::UnchangedShape);
241 
242 REGISTER_OP("EncodeBase64")
243     .Input("input: string")
244     .Output("output: string")
245     .Attr("pad: bool = false")
246     .SetShapeFn(shape_inference::UnchangedShape);
247 
248 REGISTER_OP("DecodeBase64")
249     .Input("input: string")
250     .Output("output: string")
251     .SetShapeFn(shape_inference::UnchangedShape);
252 
253 REGISTER_OP("Substr")
254     .Input("input: string")
255     .Input("pos: T")
256     .Input("len: T")
257     .Output("output: string")
258     .Attr("T: {int32, int64}")
259     .Attr("unit: {'BYTE', 'UTF8_CHAR'} = 'BYTE'")
__anon4354d5f10702(InferenceContext* c) 260     .SetShapeFn([](InferenceContext* c) {
261       ShapeHandle pos_shape = c->input(1);
262       ShapeHandle len_shape = c->input(2);
263       ShapeHandle unused;
264       // If len rank is known, check that pos and len have the same rank
265       if (c->RankKnown(len_shape)) {
266         TF_RETURN_IF_ERROR(c->WithRank(pos_shape, c->Rank(len_shape), &unused));
267       }
268       // Check that dimensions are equal
269       for (int32 i = 0; i < c->Rank(pos_shape); ++i) {
270         DimensionHandle pos_dim = c->Dim(pos_shape, i);
271         DimensionHandle len_dim = c->Dim(len_shape, i);
272         if (c->Value(pos_dim) != c->Value(len_dim)) {
273           return errors::InvalidArgument(
274               "pos and len shapes must match: ", c->DebugString(pos_shape),
275               " vs. ", c->DebugString(len_shape));
276         }
277       }
278       // c->input(0) is the ShapeHandle to input strings
279       // BroadcastBinaryOpShapeFn infers shape from c->input(0) and c->input(1).
280       return shape_inference::BroadcastBinaryOpShapeFn(c);
281     });
282 
283 REGISTER_OP("UnicodeScript")
284     .Input("input: int32")
285     .Output("output: int32")
286     .SetShapeFn(shape_inference::UnchangedShape);
287 
288 REGISTER_OP("UnicodeEncode")
289     .Input("input_values: int32")
290     .Input("input_splits: Tsplits")
291     .Attr("errors: {'ignore', 'replace', 'strict'} = 'replace'")
292     .Attr("output_encoding: {'UTF-8', 'UTF-16-BE', 'UTF-32-BE'}")
293     .Attr("replacement_char: int = 65533")  // 0xFFFD unicode replacement char
294     .Attr("Tsplits: {int32, int64} = DT_INT64")
295     .Output("output: string")
__anon4354d5f10802(InferenceContext* c) 296     .SetShapeFn([](InferenceContext* c) {
297       // Check rank of inner values
298       ShapeHandle input_inner_values_shape = c->input(0);
299       ShapeHandle unused;
300       TF_RETURN_IF_ERROR(c->WithRank(input_inner_values_shape, 1, &unused));
301 
302       // Check rank of input_splits
303       ShapeHandle splits_shape = c->input(1);
304       TF_RETURN_IF_ERROR(c->WithRank(splits_shape, 1, &unused));
305 
306       // Output shape is a 1-D tensor with size equal to number of splits.
307       std::vector<DimensionHandle> dims(1);
308       TF_RETURN_IF_ERROR(c->Subtract(c->Dim(splits_shape, 0), 1, &dims[0]));
309       c->set_output(0, c->MakeShape(dims));
310 
311       return Status::OK();
312     });
313 
314 REGISTER_OP("UnicodeTranscode")
315     .Input("input: string")
316     .Output("output: string")
317     .Attr("input_encoding: string")
318     .Attr("output_encoding: {'UTF-8', 'UTF-16-BE', 'UTF-32-BE'}")
319     .Attr("errors: {'strict', 'replace', 'ignore'} = 'replace'")
320     .Attr("replacement_char: int = 65533")  // 0xFFFD unicode replacement char
321     .Attr("replace_control_characters: bool = false")
322     .SetShapeFn(shape_inference::UnchangedShape);
323 
324 REGISTER_OP("UnicodeDecode")
325     .Input("input: string")
326     .Output("row_splits: Tsplits")
327     .Output("char_values: int32")
328     .Attr("input_encoding: string")
329     .Attr("errors: {'strict', 'replace', 'ignore'} = 'replace'")
330     .Attr("replacement_char: int = 65533")  // 0xFFFD unicode replacement char
331     .Attr("replace_control_characters: bool = false")
332     .Attr("Tsplits: {int32, int64} = DT_INT64")
__anon4354d5f10902(InferenceContext* c) 333     .SetShapeFn([](InferenceContext* c) {
334       // row_splits.shape == [input.size() + 1]
335       DimensionHandle num_row_splits;
336       DimensionHandle input_size = c->NumElements(c->input(0));
337       TF_RETURN_IF_ERROR(c->Add(input_size, 1, &num_row_splits));
338       c->set_output(0, c->Vector(num_row_splits));
339 
340       // char_values.shape == [num_chars]
341       DimensionHandle num_chars = c->UnknownDim();
342       c->set_output(1, c->Vector(num_chars));
343       return Status::OK();
344     });
345 
346 REGISTER_OP("UnicodeDecodeWithOffsets")
347     .Input("input: string")
348     .Output("row_splits: Tsplits")
349     .Output("char_values: int32")
350     .Output("char_to_byte_starts: int64")
351     .Attr("input_encoding: string")
352     .Attr("errors: {'strict', 'replace', 'ignore'} = 'replace'")
353     .Attr("replacement_char: int = 65533")  // 0xFFFD unicode replacement char
354     .Attr("replace_control_characters: bool = false")
355     .Attr("Tsplits: {int32, int64} = DT_INT64")
__anon4354d5f10a02(InferenceContext* c) 356     .SetShapeFn([](InferenceContext* c) {
357       // row_splits.shape == [input.size() + 1]
358       DimensionHandle num_row_splits;
359       DimensionHandle input_size = c->NumElements(c->input(0));
360       TF_RETURN_IF_ERROR(c->Add(input_size, 1, &num_row_splits));
361       c->set_output(0, c->Vector(num_row_splits));
362 
363       // char_values.shape == offset_values.shape == [num_chars]
364       DimensionHandle num_chars = c->UnknownDim();
365       c->set_output(1, c->Vector(num_chars));
366       c->set_output(2, c->Vector(num_chars));
367       return Status::OK();
368     });
369 
370 REGISTER_OP("StringNGrams")
371     .Attr("separator: string")
372     .Attr("ngram_widths: list(int) >= 0")
373     .Attr("left_pad: string")
374     .Attr("right_pad: string")
375     .Attr("pad_width: int")
376     .Attr("preserve_short_sequences: bool")
377     .Attr("Tsplits: {int32, int64} = DT_INT64")
378     .Input("data: string")
379     .Input("data_splits: Tsplits")
380     .Output("ngrams: string")
381     .Output("ngrams_splits: Tsplits")
__anon4354d5f10b02(InferenceContext* c) 382     .SetShapeFn([](InferenceContext* c) {
383       c->set_output(0, c->UnknownShapeOfRank(1));
384       ShapeHandle data = c->input(0);
385       TF_RETURN_IF_ERROR(c->WithRank(data, 1, &data));
386       ShapeHandle data_splits = c->input(1);
387       TF_RETURN_IF_ERROR(c->WithRank(data_splits, 1, &data_splits));
388       c->set_output(1, data_splits);
389       return Status::OK();
390     });
391 
392 }  // namespace tensorflow
393