• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <string>
17 #include <vector>
18 
19 #include "absl/strings/str_split.h"
20 #include "tensorflow/core/framework/common_shape_fns.h"
21 #include "tensorflow/core/framework/op.h"
22 #include "tensorflow/core/framework/shape_inference.h"
23 #include "tensorflow/core/lib/core/errors.h"
24 #include "tensorflow/core/lib/core/status.h"
25 #include "tensorflow/core/lib/strings/strcat.h"
26 #include "tensorflow/core/platform/types.h"
27 
28 namespace tensorflow {
29 
30 namespace shape_inference {
31 class InferenceContext;
32 }  // namespace shape_inference
33 
34 using shape_inference::DimensionHandle;
35 using shape_inference::InferenceContext;
36 using shape_inference::ShapeHandle;
37 
38 REGISTER_OP("RegexReplace")
39     .Input("input: string")
40     .Input("pattern: string")
41     .Input("rewrite: string")
42     .Output("output: string")
43     .Attr("replace_global: bool = true")
__anonb4ab91930102(InferenceContext* c) 44     .SetShapeFn([](InferenceContext* c) {
45       ShapeHandle unused;
46       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
47       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
48       c->set_output(0, c->input(0));
49       return Status::OK();
50     });
51 
52 REGISTER_OP("StaticRegexReplace")
53     .Input("input: string")
54     .Attr("pattern: string")
55     .Attr("rewrite: string")
56     .Output("output: string")
57     .Attr("replace_global: bool = true")
58     .SetShapeFn(shape_inference::UnchangedShape);
59 
60 REGISTER_OP("RegexFullMatch")
61     .Input("input: string")
62     .Input("pattern: string")
63     .Output("output: bool")
__anonb4ab91930202(InferenceContext* c) 64     .SetShapeFn([](InferenceContext* c) {
65       ShapeHandle unused;
66       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
67       c->set_output(0, c->input(0));
68       return Status::OK();
69     });
70 
71 REGISTER_OP("StaticRegexFullMatch")
72     .Input("input: string")
73     .Attr("pattern: string")
74     .Output("output: bool")
75     .SetShapeFn(shape_inference::UnchangedShape);
76 
77 REGISTER_OP("StringToHashBucketFast")
78     .Input("input: string")
79     .Output("output: int64")
80     .Attr("num_buckets: int >= 1")
81     .SetShapeFn(shape_inference::UnchangedShape);
82 
83 REGISTER_OP("StringToHashBucketStrong")
84     .Input("input: string")
85     .Output("output: int64")
86     .Attr("num_buckets: int >= 1")
87     .Attr("key: list(int)")
88     .SetShapeFn(shape_inference::UnchangedShape);
89 
90 REGISTER_OP("StringToHashBucket")
91     .Input("string_tensor: string")
92     .Output("output: int64")
93     .Attr("num_buckets: int >= 1")
94     .SetShapeFn(shape_inference::UnchangedShape);
95 
96 REGISTER_OP("ReduceJoin")
97     .Input("inputs: string")
98     .Input("reduction_indices: int32")
99     .Attr("keep_dims: bool = false")
100     .Attr("separator: string = ''")
101     .Output("output: string")
102     .SetShapeFn(shape_inference::ReductionShape);
103 
104 REGISTER_OP("AsString")
105     .Input("input: T")
106     .Output("output: string")
107     .Attr(
108         "T: {int8, int16, int32, int64, complex64, complex128, float, double, "
109         "bool}")
110     .Attr("precision: int = -1")
111     .Attr("scientific: bool = false")
112     .Attr("shortest: bool = false")
113     .Attr("width: int = -1")
114     .Attr("fill: string = ''")
115     .SetShapeFn(shape_inference::UnchangedShape);
116 
117 REGISTER_OP("StringFormat")
118     .Input("inputs: T")
119     .Output("output: string")
120     .Attr("T: list(type) >= 0")
121     .Attr("template: string = '%s'")
122     .Attr("placeholder: string = '%s'")
123     .Attr("summarize: int = 3")
__anonb4ab91930302(InferenceContext* c) 124     .SetShapeFn([](InferenceContext* c) {
125       string template_;
126       string placeholder;
127       TF_RETURN_IF_ERROR(c->GetAttr("template", &template_));
128       TF_RETURN_IF_ERROR(c->GetAttr("placeholder", &placeholder));
129 
130       std::vector<std::string> split_template;
131       split_template = absl::StrSplit(template_, placeholder);
132       int64 num_placeholders = split_template.size() - 1;
133       if (c->num_inputs() != num_placeholders) {
134         return errors::InvalidArgument(strings::StrCat(
135             "num placeholders in template and num inputs must match: ",
136             num_placeholders, " vs. ", c->num_inputs()));
137       }
138 
139       c->set_output(0, c->Scalar());
140       return Status::OK();
141     });
142 
143 REGISTER_OP("StringJoin")
144     .Input("inputs: N * string")
145     .Attr("N: int")
146     .Attr("separator: string = ''")
147     .Output("output: string")
__anonb4ab91930402(InferenceContext* c) 148     .SetShapeFn([](InferenceContext* c) {
149       // If all inputs are scalars, then return a scalar.
150       bool all_scalar = true;
151       for (int i = 0; i < c->num_inputs(); ++i) {
152         if (c->Rank(c->input(i)) != 0) all_scalar = false;
153       }
154       if (all_scalar) {
155         c->set_output(0, c->Scalar());
156         return Status::OK();
157       }
158 
159       // At least one input is unknown or a scalar.
160       // Merge the non-scalars to find the output shape.
161       // Don't merge inputs with unknown rank, as they can actually be scalars
162       // or the output shape.
163       ShapeHandle out = c->UnknownShape();
164       for (int i = 0; i < c->num_inputs(); ++i) {
165         if (c->RankKnown(c->input(i)) && c->Rank(c->input(i)) != 0) {
166           TF_RETURN_IF_ERROR(c->Merge(out, c->input(i), &out));
167         }
168       }
169       c->set_output(0, out);
170       return Status::OK();
171     });
172 
173 REGISTER_OP("StringSplit")
174     .Input("input: string")
175     .Input("delimiter: string")
176     .Output("indices: int64")
177     .Output("values: string")
178     .Output("shape: int64")
179     .Attr("skip_empty: bool = true")
__anonb4ab91930502(InferenceContext* c) 180     .SetShapeFn([](InferenceContext* c) {
181       ShapeHandle unused;
182       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
183       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
184 
185       c->set_output(0, c->Matrix(InferenceContext::kUnknownDim, 2));
186       c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
187       c->set_output(2, c->Vector(2));
188       return Status::OK();
189     });
190 
191 REGISTER_OP("StringSplitV2")
192     .Input("input: string")
193     .Input("sep: string")
194     .Output("indices: int64")
195     .Output("values: string")
196     .Output("shape: int64")
197     .Attr("maxsplit: int = -1")
__anonb4ab91930602(InferenceContext* c) 198     .SetShapeFn([](InferenceContext* c) {
199       ShapeHandle unused;
200       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &unused));
201       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
202 
203       c->set_output(0, c->Matrix(InferenceContext::kUnknownDim, 2));
204       c->set_output(1, c->Vector(InferenceContext::kUnknownDim));
205       c->set_output(2, c->Vector(2));
206       return Status::OK();
207     });
208 
209 REGISTER_OP("StringStrip")
210     .Input("input: string")
211     .Output("output: string")
212     .SetShapeFn(shape_inference::UnchangedShape);
213 
214 REGISTER_OP("StringLength")
215     .Input("input: string")
216     .Output("output: int32")
217     .Attr("unit: {'BYTE', 'UTF8_CHAR'} = 'BYTE'")
218     .SetShapeFn(shape_inference::UnchangedShape);
219 
220 REGISTER_OP("EncodeBase64")
221     .Input("input: string")
222     .Output("output: string")
223     .Attr("pad: bool = false")
224     .SetShapeFn(shape_inference::UnchangedShape);
225 
226 REGISTER_OP("DecodeBase64")
227     .Input("input: string")
228     .Output("output: string")
229     .SetShapeFn(shape_inference::UnchangedShape);
230 
231 REGISTER_OP("Substr")
232     .Input("input: string")
233     .Input("pos: T")
234     .Input("len: T")
235     .Output("output: string")
236     .Attr("T: {int32, int64}")
237     .Attr("unit: {'BYTE', 'UTF8_CHAR'} = 'BYTE'")
__anonb4ab91930702(InferenceContext* c) 238     .SetShapeFn([](InferenceContext* c) {
239       ShapeHandle pos_shape = c->input(1);
240       ShapeHandle len_shape = c->input(2);
241       ShapeHandle unused;
242       // Check that pos/len have same rank
243       TF_RETURN_IF_ERROR(c->WithRank(pos_shape, c->Rank(len_shape), &unused));
244       // Check that dimensions are equal
245       for (int32 i = 0; i < c->Rank(pos_shape); ++i) {
246         DimensionHandle pos_dim = c->Dim(pos_shape, i);
247         DimensionHandle len_dim = c->Dim(len_shape, i);
248         if (c->Value(pos_dim) != c->Value(len_dim)) {
249           return errors::InvalidArgument(
250               "pos and len shapes must match: ", c->DebugString(pos_shape),
251               " vs. ", c->DebugString(len_shape));
252         }
253       }
254       // c->input(0) is the ShapeHandle to input strings
255       // BroadcastBinaryOpShapeFn infers shape from c->input(0) and c->input(1).
256       return shape_inference::BroadcastBinaryOpShapeFn(c);
257     });
258 
259 REGISTER_OP("UnicodeScript")
260     .Input("input: int32")
261     .Output("output: int32")
262     .SetShapeFn(shape_inference::UnchangedShape);
263 
264 REGISTER_OP("UnicodeEncode")
265     .Input("input_values: int32")
266     .Input("input_splits: int64")
267     .Attr("errors: {'ignore', 'replace', 'strict'} = 'replace'")
268     .Attr("output_encoding: {'UTF-8', 'UTF-16-BE', 'UTF-32-BE'}")
269     .Attr("replacement_char: int = 65533")  // 0xFFFD unicode replacement char
270     .Output("output: string")
__anonb4ab91930802(InferenceContext* c) 271     .SetShapeFn([](InferenceContext* c) {
272       // Check rank of inner values
273       ShapeHandle input_inner_values_shape = c->input(0);
274       ShapeHandle unused;
275       TF_RETURN_IF_ERROR(c->WithRank(input_inner_values_shape, 1, &unused));
276 
277       // Check rank of input_splits
278       ShapeHandle splits_shape = c->input(1);
279       TF_RETURN_IF_ERROR(c->WithRank(splits_shape, 1, &unused));
280 
281       // Output shape is a 1-D tensor with size equal to number of splits.
282       std::vector<DimensionHandle> dims(1);
283       TF_RETURN_IF_ERROR(c->Subtract(c->Dim(splits_shape, 0), 1, &dims[0]));
284       c->set_output(0, c->MakeShape(dims));
285 
286       return Status::OK();
287     });
288 
289 REGISTER_OP("UnicodeTranscode")
290     .Input("input: string")
291     .Output("output: string")
292     .Attr("input_encoding: string")
293     .Attr("output_encoding: {'UTF-8', 'UTF-16-BE', 'UTF-32-BE'}")
294     .Attr("errors: {'strict', 'replace', 'ignore'} = 'replace'")
295     .Attr("replacement_char: int = 65533")  // 0xFFFD unicode replacement char
296     .Attr("replace_control_characters: bool = false")
297     .SetShapeFn(shape_inference::UnchangedShape);
298 
299 REGISTER_OP("UnicodeDecode")
300     .Input("input: string")
301     .Output("row_splits: int64")
302     .Output("char_values: int32")
303     .Attr("input_encoding: string")
304     .Attr("errors: {'strict', 'replace', 'ignore'} = 'replace'")
305     .Attr("replacement_char: int = 65533")  // 0xFFFD unicode replacement char
306     .Attr("replace_control_characters: bool = false")
__anonb4ab91930902(InferenceContext* c) 307     .SetShapeFn([](InferenceContext* c) {
308       // row_splits.shape == [input.size() + 1]
309       DimensionHandle num_row_splits;
310       DimensionHandle input_size = c->NumElements(c->input(0));
311       TF_RETURN_IF_ERROR(c->Add(input_size, 1, &num_row_splits));
312       c->set_output(0, c->Vector(num_row_splits));
313 
314       // char_values.shape == [num_chars]
315       DimensionHandle num_chars = c->UnknownDim();
316       c->set_output(1, c->Vector(num_chars));
317       return Status::OK();
318     });
319 
320 REGISTER_OP("UnicodeDecodeWithOffsets")
321     .Input("input: string")
322     .Output("row_splits: int64")
323     .Output("char_values: int32")
324     .Output("char_to_byte_starts: int64")
325     .Attr("input_encoding: string")
326     .Attr("errors: {'strict', 'replace', 'ignore'} = 'replace'")
327     .Attr("replacement_char: int = 65533")  // 0xFFFD unicode replacement char
328     .Attr("replace_control_characters: bool = false")
__anonb4ab91930a02(InferenceContext* c) 329     .SetShapeFn([](InferenceContext* c) {
330       // row_splits.shape == [input.size() + 1]
331       DimensionHandle num_row_splits;
332       DimensionHandle input_size = c->NumElements(c->input(0));
333       TF_RETURN_IF_ERROR(c->Add(input_size, 1, &num_row_splits));
334       c->set_output(0, c->Vector(num_row_splits));
335 
336       // char_values.shape == offset_values.shape == [num_chars]
337       DimensionHandle num_chars = c->UnknownDim();
338       c->set_output(1, c->Vector(num_chars));
339       c->set_output(2, c->Vector(num_chars));
340       return Status::OK();
341     });
342 
343 }  // namespace tensorflow
344