1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/common_shape_fns.h"
17 #include "tensorflow/core/framework/op.h"
18 #include "tensorflow/core/framework/shape_inference.h"
19
20 namespace tensorflow {
21
22 using shape_inference::DimensionHandle;
23 using shape_inference::InferenceContext;
24 using shape_inference::ShapeHandle;
25
26 namespace {
27
28 // Sets output[0] to shape [batch_dim,height,width,channel_dim], where
29 // height and width come from the size_tensor.
SetOutputToSizedImage(InferenceContext * c,DimensionHandle batch_dim,int size_input_idx,DimensionHandle channel_dim)30 Status SetOutputToSizedImage(InferenceContext* c, DimensionHandle batch_dim,
31 int size_input_idx, DimensionHandle channel_dim) {
32 // Verify shape of size input.
33 ShapeHandle size;
34 TF_RETURN_IF_ERROR(c->WithRank(c->input(size_input_idx), 1, &size));
35 DimensionHandle unused;
36 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 2, &unused));
37
38 // Get size values from the size tensor.
39 const Tensor* size_tensor = c->input_tensor(size_input_idx);
40 DimensionHandle width;
41 DimensionHandle height;
42 if (size_tensor == nullptr) {
43 width = c->UnknownDim();
44 height = c->UnknownDim();
45 } else {
46 // TODO(petewarden) - Remove once we have constant evaluation in C++ only.
47 if (size_tensor->dtype() != DT_INT32) {
48 return errors::InvalidArgument(
49 "Bad size input type for SetOutputToSizedImage: Expected DT_INT32 "
50 "but got ",
51 DataTypeString(size_tensor->dtype()), " for input #", size_input_idx,
52 " in ", c->DebugString());
53 }
54 auto vec = size_tensor->vec<int32>();
55 height = c->MakeDim(vec(0));
56 width = c->MakeDim(vec(1));
57 }
58 c->set_output(0, c->MakeShape({batch_dim, height, width, channel_dim}));
59 return Status::OK();
60 }
61
ResizeShapeFn(InferenceContext * c)62 Status ResizeShapeFn(InferenceContext* c) {
63 ShapeHandle input;
64 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
65 return SetOutputToSizedImage(c, c->Dim(input, 0), 1 /* size_input_idx */,
66 c->Dim(input, 3));
67 }
68
DecodeImageShapeFn(InferenceContext * c)69 Status DecodeImageShapeFn(InferenceContext* c) {
70 ShapeHandle unused;
71 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
72 DimensionHandle channels_dim;
73 int32 channels;
74 TF_RETURN_IF_ERROR(c->GetAttr("channels", &channels));
75 if (channels == 0) {
76 channels_dim = c->UnknownDim();
77 } else {
78 if (channels < 0) {
79 return errors::InvalidArgument("channels must be non-negative, got ",
80 channels);
81 }
82 channels_dim = c->MakeDim(channels);
83 }
84
85 c->set_output(0, c->MakeShape({InferenceContext::kUnknownDim,
86 InferenceContext::kUnknownDim, channels_dim}));
87 return Status::OK();
88 }
89
DecodeImageV2ShapeFn(InferenceContext * c)90 Status DecodeImageV2ShapeFn(InferenceContext* c) {
91 ShapeHandle unused;
92 int32 channels;
93 bool expand_animations;
94 DimensionHandle channels_dim;
95
96 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
97 TF_RETURN_IF_ERROR(c->GetAttr("channels", &channels));
98 TF_RETURN_IF_ERROR(c->GetAttr("expand_animations", &expand_animations));
99
100 if (channels == 0) {
101 channels_dim = c->UnknownDim();
102 } else {
103 if (channels < 0) {
104 return errors::InvalidArgument("channels must be non-negative, got ",
105 channels);
106 }
107 channels_dim = c->MakeDim(channels);
108 }
109
110 // `expand_animations` set to true will return 4-D shapes for GIF. 3-D shapes
111 // will be returned for jpg, png, and bmp. `expand_animations` set to false
112 // will always return 3-D shapes for all (jpg, png, bmp, gif).
113 if (expand_animations) {
114 c->set_output(0, c->UnknownShape());
115 return Status::OK();
116 } else {
117 c->set_output(0,
118 c->MakeShape({InferenceContext::kUnknownDim,
119 InferenceContext::kUnknownDim, channels_dim}));
120 return Status::OK();
121 }
122 }
123
EncodeImageShapeFn(InferenceContext * c)124 Status EncodeImageShapeFn(InferenceContext* c) {
125 ShapeHandle unused;
126 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &unused));
127 c->set_output(0, c->Scalar());
128 return Status::OK();
129 }
130
ColorspaceShapeFn(InferenceContext * c)131 Status ColorspaceShapeFn(InferenceContext* c) {
132 ShapeHandle input;
133 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input));
134
135 // The last dimension value is always 3.
136 DimensionHandle last_dim;
137 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(input, -1), 3, &last_dim));
138 ShapeHandle out;
139 TF_RETURN_IF_ERROR(c->ReplaceDim(input, -1, last_dim, &out));
140 c->set_output(0, out);
141
142 return Status::OK();
143 }
144
NMSShapeFn(InferenceContext * c)145 Status NMSShapeFn(InferenceContext* c) {
146 // Get inputs and validate ranks.
147 ShapeHandle boxes;
148 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &boxes));
149 ShapeHandle scores;
150 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores));
151 ShapeHandle max_output_size;
152 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size));
153 ShapeHandle iou_threshold;
154 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &iou_threshold));
155 ShapeHandle score_threshold;
156 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &score_threshold));
157 // The boxes is a 2-D float Tensor of shape [num_boxes, 4].
158 DimensionHandle unused;
159 // The boxes[0] and scores[0] are both num_boxes.
160 TF_RETURN_IF_ERROR(c->Merge(c->Dim(boxes, 0), c->Dim(scores, 0), &unused));
161 // The boxes[1] is 4.
162 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused));
163
164 c->set_output(0, c->Vector(c->UnknownDim()));
165 return Status::OK();
166 }
167
SoftNMSShapeFn(InferenceContext * c)168 Status SoftNMSShapeFn(InferenceContext* c) {
169 // Get inputs and validate ranks.
170 ShapeHandle boxes;
171 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &boxes));
172 ShapeHandle scores;
173 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores));
174 ShapeHandle max_output_size;
175 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size));
176 ShapeHandle iou_threshold;
177 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &iou_threshold));
178 ShapeHandle score_threshold;
179 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &score_threshold));
180 ShapeHandle soft_nms_sigma;
181 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &soft_nms_sigma));
182 // The boxes is a 2-D float Tensor of shape [num_boxes, 4].
183 DimensionHandle unused;
184 // The boxes[0] and scores[0] are both num_boxes.
185 TF_RETURN_IF_ERROR(c->Merge(c->Dim(boxes, 0), c->Dim(scores, 0), &unused));
186 // The boxes[1] is 4.
187 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused));
188
189 c->set_output(0, c->Vector(c->UnknownDim()));
190 c->set_output(1, c->Vector(c->UnknownDim()));
191 return Status::OK();
192 }
193
CombinedNMSShapeFn(InferenceContext * c)194 Status CombinedNMSShapeFn(InferenceContext* c) {
195 // Get inputs and validate ranks
196 ShapeHandle boxes;
197 // boxes is a tensor of Dimensions [batch_size, num_anchors, q, 4]
198 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &boxes));
199 ShapeHandle scores;
200 // scores is a tensor of Dimensions [batch_size, num_anchors, num_classes]
201 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &scores));
202 ShapeHandle max_output_size_per_class;
203 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size_per_class));
204 ShapeHandle max_total_size;
205 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &max_total_size));
206 ShapeHandle unused_shape;
207 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused_shape));
208 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused_shape));
209
210 DimensionHandle unused;
211 // boxes[0] and scores[0] are both batch_size
212 TF_RETURN_IF_ERROR(c->Merge(c->Dim(boxes, 0), c->Dim(scores, 0), &unused));
213 // boxes[1] and scores[1] are both num_anchors
214 TF_RETURN_IF_ERROR(c->Merge(c->Dim(boxes, 1), c->Dim(scores, 1), &unused));
215 // The boxes[3] is 4.
216 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 3), 4, &unused));
217
218 DimensionHandle d = c->Dim(boxes, 2);
219 DimensionHandle class_dim = c->Dim(scores, 2);
220 if (c->ValueKnown(d) && c->ValueKnown(class_dim)) {
221 if (c->Value(d) != 1 && c->Value(d) != c->Value(class_dim)) {
222 return errors::InvalidArgument(
223 "third dimension of boxes must be either "
224 "1 or equal to the third dimension of scores");
225 }
226 }
227 DimensionHandle output_dim;
228 DimensionHandle batch_dim = c->Dim(boxes, 0);
229
230 TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(3, &output_dim));
231 if (c->ValueKnown(output_dim) && c->Value(output_dim) <= 0) {
232 return errors::InvalidArgument("max_total_size should be > 0 ");
233 }
234 DimensionHandle size_per_class;
235 TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &size_per_class));
236
237 int64 output_size;
238 bool pad_per_class;
239 TF_RETURN_IF_ERROR(c->GetAttr("pad_per_class", &pad_per_class));
240 if (!pad_per_class) {
241 output_size = c->Value(output_dim);
242 } else {
243 if (c->ValueKnown(size_per_class) && c->Value(size_per_class) <= 0) {
244 return errors::InvalidArgument(
245 "max_output_size_per_class must be > 0 "
246 "if pad_per_class is set to true ");
247 }
248 output_size = std::min(c->Value(output_dim),
249 c->Value(size_per_class) * c->Value(class_dim));
250 }
251 c->set_output(0, c->MakeShape({batch_dim, output_size, 4}));
252 c->set_output(1, c->MakeShape({batch_dim, output_size}));
253 c->set_output(2, c->MakeShape({batch_dim, output_size}));
254 c->set_output(3, c->Vector(batch_dim));
255 return Status::OK();
256 }
257
258 } // namespace
259
// --------------------------------------------------------------------------
// ResizeArea: resizes the 4-D `images` input to the [height, width] given by
// the 2-element int32 `size` input (shape inference via ResizeShapeFn). The
// output is always float, whatever the input type T.
REGISTER_OP("ResizeArea")
    .Input("images: T")
    .Input("size: int32")
    .Output("resized_images: float")
    .Attr(
        "T: {int8, uint8, int16, uint16, int32, int64, half, float, double,"
        "bfloat16}")
    .Attr("align_corners: bool = false")
    .SetShapeFn(ResizeShapeFn);

// --------------------------------------------------------------------------
// ResizeBicubic: same inputs, float output and shape function as ResizeArea,
// plus the `half_pixel_centers` attr.
REGISTER_OP("ResizeBicubic")
    .Input("images: T")
    .Input("size: int32")
    .Output("resized_images: float")
    .Attr(
        "T: {int8, uint8, int16, uint16, int32, int64, half, float, double,"
        "bfloat16}")
    .Attr("align_corners: bool = false")
    .Attr("half_pixel_centers: bool = false")
    .SetShapeFn(ResizeShapeFn);
282
283 // --------------------------------------------------------------------------
284 REGISTER_OP("ResizeBicubicGrad")
285 .Input("grads: float")
286 .Input("original_image: T")
287 .Output("output: T")
288 .Attr("T: {float, double}")
289 .Attr("align_corners: bool = false")
290 .Attr("half_pixel_centers: bool = false")
__anon30b5031d0202(InferenceContext* c) 291 .SetShapeFn([](InferenceContext* c) {
292 c->set_output(0, c->input(1));
293 return Status::OK();
294 });
295
// --------------------------------------------------------------------------
// ResizeBilinear: resizes the 4-D `images` input to the [height, width] given
// by `size`; output is float (shape inference via ResizeShapeFn).
// NOTE(review): `bfloat16` appears twice in the allowed types for T. It is
// left as-is because the attr string is part of the registered OpDef.
REGISTER_OP("ResizeBilinear")
    .Input("images: T")
    .Input("size: int32")
    .Output("resized_images: float")
    .Attr(
        "T: {int8, uint8, int16, uint16, int32, int64, bfloat16, half, "
        "float, double, bfloat16}")
    .Attr("align_corners: bool = false")
    .Attr("half_pixel_centers: bool = false")
    .SetShapeFn(ResizeShapeFn);

// --------------------------------------------------------------------------
// ScaleAndTranslate: output spatial size comes from `size` (ResizeShapeFn).
// The `scale` and `translation` inputs are not shape-checked here, because
// ResizeShapeFn only inspects inputs 0 and 1.
REGISTER_OP("ScaleAndTranslate")
    .Input("images: T")
    .Input("size: int32")
    .Input("scale: float")
    .Input("translation: float")
    .Output("resized_images: float")
    .Attr(
        "T: {int8, uint8, int16, uint16, int32, int64, bfloat16, half, "
        "float, double}")
    .Attr("kernel_type: string = 'lanczos3'")
    .Attr("antialias: bool = true")
    .SetShapeFn(ResizeShapeFn);
321
322 // --------------------------------------------------------------------------
323 REGISTER_OP("QuantizedResizeBilinear")
324 .Input("images: T")
325 .Input("size: int32")
326 .Input("min: float")
327 .Input("max: float")
328 .Output("resized_images: T")
329 .Output("out_min: float")
330 .Output("out_max: float")
331 .Attr("T: {quint8, qint32, float}")
332 .Attr("align_corners: bool = false")
333 .Attr("half_pixel_centers: bool = false")
__anon30b5031d0302(InferenceContext* c) 334 .SetShapeFn([](InferenceContext* c) {
335 TF_RETURN_IF_ERROR(ResizeShapeFn(c));
336 ShapeHandle min_shape;
337 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &min_shape));
338 ShapeHandle max_shape;
339 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &max_shape));
340 c->set_output(1, c->MakeShape({}));
341 c->set_output(2, c->MakeShape({}));
342 return Status::OK();
343 });
344
345 // --------------------------------------------------------------------------
346 REGISTER_OP("ResizeBilinearGrad")
347 .Input("grads: float")
348 .Input("original_image: T")
349 .Output("output: T")
350 .Attr("T: {float, bfloat16, half, double}")
351 .Attr("align_corners: bool = false")
352 .Attr("half_pixel_centers: bool = false")
__anon30b5031d0402(InferenceContext* c) 353 .SetShapeFn([](InferenceContext* c) {
354 c->set_output(0, c->input(1));
355 return Status::OK();
356 });
357
358 // --------------------------------------------------------------------------
359 REGISTER_OP("ScaleAndTranslateGrad")
360 .Input("grads: T")
361 .Input("original_image: T")
362 .Input("scale: float")
363 .Input("translation: float")
364 .Output("output: T")
365 .Attr("T: {float}")
366 .Attr("kernel_type: string = 'lanczos3'")
367 .Attr("antialias: bool = true")
__anon30b5031d0502(InferenceContext* c) 368 .SetShapeFn([](InferenceContext* c) {
369 c->set_output(0, c->input(1));
370 return Status::OK();
371 });
372
// --------------------------------------------------------------------------
// ResizeNearestNeighbor: resizes the 4-D `images` input to the [height, width]
// given by `size` (shape inference via ResizeShapeFn). Unlike the other
// resize ops here, the output keeps the input dtype T.
REGISTER_OP("ResizeNearestNeighbor")
    .Input("images: T")
    .Input("size: int32")
    .Output("resized_images: T")
    .Attr(
        "T: {int8, uint8, int16, uint16, int32, int64, half, float,"
        "double, bfloat16}")
    .Attr("align_corners: bool = false")
    .Attr("half_pixel_centers: bool = false")
    .SetShapeFn(ResizeShapeFn);
384
385 // --------------------------------------------------------------------------
386 REGISTER_OP("ResizeNearestNeighborGrad")
387 .Input("grads: T")
388 .Input("size: int32")
389 .Output("output: T")
390 .Attr("T: {uint8, int8, int32, half, float, double, bfloat16}")
391 .Attr("align_corners: bool = false")
392 .Attr("half_pixel_centers: bool = false")
__anon30b5031d0602(InferenceContext* c) 393 .SetShapeFn([](InferenceContext* c) {
394 ShapeHandle input;
395 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
396 ShapeHandle unused;
397 DimensionHandle unused_dim;
398 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
399 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(unused, 0), 2, &unused_dim));
400 const Tensor* size = c->input_tensor(1);
401 if (size == nullptr) {
402 TF_RETURN_IF_ERROR(c->ReplaceDim(input, 1, c->UnknownDim(), &input));
403 TF_RETURN_IF_ERROR(c->ReplaceDim(input, 2, c->UnknownDim(), &input));
404 } else {
405 auto size_vec = size->vec<int32>();
406 TF_RETURN_IF_ERROR(
407 c->ReplaceDim(input, 1, c->MakeDim(size_vec(0)), &input));
408 TF_RETURN_IF_ERROR(
409 c->ReplaceDim(input, 2, c->MakeDim(size_vec(1)), &input));
410 }
411 c->set_output(0, input);
412 return Status::OK();
413 });
414
415 // --------------------------------------------------------------------------
416 REGISTER_OP("RandomCrop")
417 .Input("image: T")
418 .Input("size: int64")
419 .Output("output: T")
420 .Attr("T: {uint8, int8, int16, int32, int64, float, double}")
421 .Attr("seed: int = 0")
422 .Attr("seed2: int = 0")
423 .SetIsStateful()
424 .Deprecated(8, "Random crop is now pure Python")
__anon30b5031d0702(InferenceContext* c) 425 .SetShapeFn([](InferenceContext* c) {
426 ShapeHandle image;
427 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &image));
428 DimensionHandle channels = c->Dim(image, -1);
429
430 ShapeHandle unused;
431 TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->Vector(2), &unused));
432
433 const Tensor* size = c->input_tensor(1);
434 DimensionHandle h;
435 DimensionHandle w;
436 if (size == nullptr) {
437 h = c->UnknownDim();
438 w = c->UnknownDim();
439 } else {
440 auto size_vec = size->vec<int64>();
441 h = c->MakeDim(size_vec(0));
442 w = c->MakeDim(size_vec(1));
443 }
444 c->set_output(0, c->MakeShape({h, w, channels}));
445 return Status::OK();
446 });
447 // TODO(shlens): Support variable rank in RandomCrop.
448
// --------------------------------------------------------------------------
// DecodeImage: decodes a scalar string of encoded image bytes. Shape
// inference via DecodeImageV2ShapeFn: the output rank is unknown when
// `expand_animations` is true, and 3-D otherwise.
REGISTER_OP("DecodeImage")
    .Input("contents: string")
    // Setting `channels` to 0 means using the inherent number of channels in
    // the image.
    .Attr("channels: int = 0")
    .Attr("dtype: {uint8, uint16, float32} = DT_UINT8")
    .Output("image: dtype")
    .Attr("expand_animations: bool = true")
    .SetShapeFn(DecodeImageV2ShapeFn);

// --------------------------------------------------------------------------
// DecodeJpeg: decodes a scalar string into a 3-D uint8 image whose channel
// dimension is known only when the `channels` attr is non-zero
// (DecodeImageShapeFn).
REGISTER_OP("DecodeJpeg")
    .Input("contents: string")
    .Attr("channels: int = 0")
    .Attr("ratio: int = 1")
    .Attr("fancy_upscaling: bool = true")
    .Attr("try_recover_truncated: bool = false")
    .Attr("acceptable_fraction: float = 1.0")
    .Attr("dct_method: string = ''")
    .Output("image: uint8")
    .SetShapeFn(DecodeImageShapeFn);
471
472 // --------------------------------------------------------------------------
473 REGISTER_OP("DecodeAndCropJpeg")
474 .Input("contents: string")
475 .Input("crop_window: int32")
476 .Attr("channels: int = 0")
477 .Attr("ratio: int = 1")
478 .Attr("fancy_upscaling: bool = true")
479 .Attr("try_recover_truncated: bool = false")
480 .Attr("acceptable_fraction: float = 1.0")
481 .Attr("dct_method: string = ''")
482 .Output("image: uint8")
__anon30b5031d0802(InferenceContext* c) 483 .SetShapeFn([](InferenceContext* c) {
484 ShapeHandle unused;
485 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
486 DimensionHandle channels_dim = c->UnknownDim();
487 DimensionHandle h = c->UnknownDim();
488 DimensionHandle w = c->UnknownDim();
489
490 int32 channels;
491 TF_RETURN_IF_ERROR(c->GetAttr("channels", &channels));
492 if (channels != 0) {
493 if (channels < 0) {
494 return errors::InvalidArgument("channels must be non-negative, got ",
495 channels);
496 }
497 channels_dim = c->MakeDim(channels);
498 }
499
500 DimensionHandle unused_dim;
501 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
502 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(unused, 0), 4, &unused_dim));
503
504 const Tensor* crop_window = c->input_tensor(1);
505 if (crop_window != nullptr) {
506 auto crop_window_vec = crop_window->vec<int32>();
507 h = c->MakeDim(crop_window_vec(2));
508 w = c->MakeDim(crop_window_vec(3));
509 }
510 c->set_output(0, c->MakeShape({h, w, channels_dim}));
511 return Status::OK();
512 });
513
// --------------------------------------------------------------------------
// EncodeJpeg: encodes a 3-D uint8 image into a scalar string
// (EncodeImageShapeFn).
REGISTER_OP("EncodeJpeg")
    .Input("image: uint8")
    .Attr("format: {'', 'grayscale', 'rgb'} = ''")
    .Attr("quality: int = 95")
    .Attr("progressive: bool = false")
    .Attr("optimize_size: bool = false")
    .Attr("chroma_downsampling: bool = true")
    .Attr("density_unit: {'in', 'cm'} = 'in'")
    .Attr("x_density: int = 300")
    .Attr("y_density: int = 300")
    .Attr("xmp_metadata: string = ''")
    .Output("contents: string")
    .SetShapeFn(EncodeImageShapeFn);

// --------------------------------------------------------------------------
// EncodeJpegVariableQuality: like EncodeJpeg, but the quality is a tensor
// input instead of an attr.
// NOTE(review): EncodeImageShapeFn only checks input 0, so the shape of the
// `quality` input is not validated at graph-construction time.
REGISTER_OP("EncodeJpegVariableQuality")
    .Input("images: uint8")
    .Input("quality: int32")
    .Output("contents: string")
    .SetShapeFn(EncodeImageShapeFn);
535
536 // --------------------------------------------------------------------------
537 REGISTER_OP("ExtractJpegShape")
538 .Input("contents: string")
539 .Output("image_shape: output_type")
540 .Attr("output_type: {int32, int64} = DT_INT32")
__anon30b5031d0902(InferenceContext* c) 541 .SetShapeFn([](InferenceContext* c) {
542 ShapeHandle unused;
543 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
544 c->set_output(0, c->Vector(3));
545 return Status::OK();
546 });
547
548 // --------------------------------------------------------------------------
549 REGISTER_OP("AdjustContrast")
550 .Input("images: T")
551 .Input("contrast_factor: float")
552 .Input("min_value: float")
553 .Input("max_value: float")
554 .Output("output: float")
555 .Attr("T: {uint8, int8, int16, int32, int64, float, double}")
556 .Deprecated(2, "Use AdjustContrastv2 instead")
__anon30b5031d0a02(InferenceContext* c) 557 .SetShapeFn([](InferenceContext* c) {
558 // The contrast_factor, min_value, max_value should be scalar only.
559 ShapeHandle unused;
560 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
561 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
562 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
563 return shape_inference::UnchangedShapeWithRankAtLeast(c, 3);
564 });
565
566 // --------------------------------------------------------------------------
567 REGISTER_OP("AdjustContrastv2")
568 .Input("images: T")
569 .Input("contrast_factor: float")
570 .Output("output: T")
571 .Attr("T: {half, float} = DT_FLOAT")
__anon30b5031d0b02(InferenceContext* c) 572 .SetShapeFn([](InferenceContext* c) {
573 // The contrast_factor should be scalar only.
574 ShapeHandle unused;
575 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
576 return shape_inference::UnchangedShapeWithRankAtLeast(c, 3);
577 });
578
579 // --------------------------------------------------------------------------
580 REGISTER_OP("AdjustHue")
581 .Input("images: T")
582 .Input("delta: float")
583 .Output("output: T")
584 .Attr("T: {half, float} = DT_FLOAT")
__anon30b5031d0c02(InferenceContext* c) 585 .SetShapeFn([](InferenceContext* c) {
586 return shape_inference::UnchangedShapeWithRankAtLeast(c, 3);
587 });
588
589 // --------------------------------------------------------------------------
590 REGISTER_OP("AdjustSaturation")
591 .Input("images: T")
592 .Input("scale: float")
593 .Output("output: T")
594 .Attr("T: {half, float} = DT_FLOAT")
__anon30b5031d0d02(InferenceContext* c) 595 .SetShapeFn([](InferenceContext* c) {
596 return shape_inference::UnchangedShapeWithRankAtLeast(c, 3);
597 });
598
// --------------------------------------------------------------------------
// DecodePng: scalar string in, 3-D image out (DecodeImageShapeFn); the
// output dtype is selected by the `dtype` attr.
REGISTER_OP("DecodePng")
    .Input("contents: string")
    .Attr("channels: int = 0")
    .Attr("dtype: {uint8, uint16} = DT_UINT8")
    .Output("image: dtype")
    .SetShapeFn(DecodeImageShapeFn);

// --------------------------------------------------------------------------
// EncodePng: 3-D image in, scalar string out (EncodeImageShapeFn).
REGISTER_OP("EncodePng")
    .Attr("compression: int = -1")
    .Attr("T: {uint8, uint16} = DT_UINT8")
    .Input("image: T")
    .Output("contents: string")
    .SetShapeFn(EncodeImageShapeFn);

// --------------------------------------------------------------------------
// DecodeBmp: scalar string in, 3-D uint8 image out (DecodeImageShapeFn).
REGISTER_OP("DecodeBmp")
    .Input("contents: string")
    .Output("image: uint8")
    .Attr("channels: int = 0")
    .SetShapeFn(DecodeImageShapeFn);
621
622 // --------------------------------------------------------------------------
623 REGISTER_OP("DecodeGif")
624 .Input("contents: string")
625 .Output("image: uint8")
__anon30b5031d0e02(InferenceContext* c) 626 .SetShapeFn([](InferenceContext* c) {
627 ShapeHandle unused;
628 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
629 c->set_output(0, c->MakeShape({InferenceContext::kUnknownDim,
630 InferenceContext::kUnknownDim,
631 InferenceContext::kUnknownDim, 3}));
632 return Status::OK();
633 });
634
// --------------------------------------------------------------------------
// RGBToHSV: shape-preserving conversion; the last input dimension must be 3
// (ColorspaceShapeFn).
REGISTER_OP("RGBToHSV")
    .Input("images: T")
    .Output("output: T")
    .Attr("T: {half, bfloat16, float, double} = DT_FLOAT")
    .SetShapeFn(ColorspaceShapeFn);

// --------------------------------------------------------------------------
// HSVToRGB: inverse direction of RGBToHSV, with the same shape rules.
REGISTER_OP("HSVToRGB")
    .Input("images: T")
    .Output("output: T")
    .Attr("T: {half, bfloat16, float, double} = DT_FLOAT")
    .SetShapeFn(ColorspaceShapeFn);
648
649 // --------------------------------------------------------------------------
650 REGISTER_OP("DrawBoundingBoxes")
651 .Input("images: T")
652 .Input("boxes: float")
653 .Output("output: T")
654 .Attr("T: {float, half} = DT_FLOAT")
__anon30b5031d0f02(InferenceContext* c) 655 .SetShapeFn([](InferenceContext* c) {
656 // The rank of images should be 4.
657 ShapeHandle images;
658 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &images));
659 // Channel depth should be either 1 (GRY), 3 (RGB), or 4 (RGBA).
660 if (c->ValueKnown(c->Dim(images, 3))) {
661 int64 depth = c->Value(c->Dim(images, 3));
662 if (!(depth == 1 || depth == 3 || depth == 4)) {
663 return errors::InvalidArgument(
664 "Channel depth should be either 1 (GRY), "
665 "3 (RGB), or 4 (RGBA)");
666 }
667 }
668
669 // The rank of boxes is 3: [batch, num_bounding_boxes, 4].
670 ShapeHandle boxes;
671 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &boxes));
672 // The last value of boxes shape is 4.
673 DimensionHandle unused;
674 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 2), 4, &unused));
675
676 // The rank of the input image (rank = 4) has already been restricted
677 // above, and the output is of the same shape as the input.
678 return shape_inference::UnchangedShape(c);
679 });
680
681 // --------------------------------------------------------------------------
682 REGISTER_OP("DrawBoundingBoxesV2")
683 .Input("images: T")
684 .Input("boxes: float")
685 .Input("colors: float")
686 .Output("output: T")
687 .Attr("T: {float, half} = DT_FLOAT")
__anon30b5031d1002(InferenceContext* c) 688 .SetShapeFn([](InferenceContext* c) {
689 return shape_inference::UnchangedShapeWithRankAtLeast(c, 3);
690 });
691
692 // --------------------------------------------------------------------------
693 REGISTER_OP("SampleDistortedBoundingBox")
694 .Input("image_size: T")
695 .Input("bounding_boxes: float")
696 .Output("begin: T")
697 .Output("size: T")
698 .Output("bboxes: float")
699 .Attr("T: {uint8, int8, int16, int32, int64}")
700 .Attr("seed: int = 0")
701 .Attr("seed2: int = 0")
702 .Attr("min_object_covered: float = 0.1")
703 .Attr("aspect_ratio_range: list(float) = [0.75, 1.33]")
704 .Attr("area_range: list(float) = [0.05, 1.0]")
705 .Attr("max_attempts: int = 100")
706 .Attr("use_image_if_no_bounding_boxes: bool = false")
707 .SetIsStateful()
__anon30b5031d1102(InferenceContext* c) 708 .SetShapeFn([](InferenceContext* c) {
709 // Get inputs and validate ranks.
710 ShapeHandle image_size;
711 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &image_size));
712 ShapeHandle bounding_boxes;
713 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &bounding_boxes));
714 // image_size: 1-D with [height, width, channels]
715 // bounding_boxes: 3-D with shape [batch, N, 4]
716 DimensionHandle unused;
717 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(image_size, 0), 3, &unused));
718 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(bounding_boxes, 2), 4, &unused));
719
720 c->set_output(0, c->Vector(3));
721 c->set_output(1, c->Vector(3));
722 c->set_output(2, c->MakeShape({1, 1, 4}));
723 return Status::OK();
724 });
725
726 REGISTER_OP("SampleDistortedBoundingBoxV2")
727 .Input("image_size: T")
728 .Input("bounding_boxes: float")
729 .Input("min_object_covered: float")
730 .Output("begin: T")
731 .Output("size: T")
732 .Output("bboxes: float")
733 .Attr("T: {uint8, int8, int16, int32, int64}")
734 .Attr("seed: int = 0")
735 .Attr("seed2: int = 0")
736 .Attr("aspect_ratio_range: list(float) = [0.75, 1.33]")
737 .Attr("area_range: list(float) = [0.05, 1.0]")
738 .Attr("max_attempts: int = 100")
739 .Attr("use_image_if_no_bounding_boxes: bool = false")
740 .SetIsStateful()
__anon30b5031d1202(InferenceContext* c) 741 .SetShapeFn([](InferenceContext* c) {
742 // Get inputs and validate ranks.
743 ShapeHandle image_size;
744 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &image_size));
745 ShapeHandle bounding_boxes;
746 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &bounding_boxes));
747 ShapeHandle min_object_covered;
748 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &min_object_covered));
749 // image_size: 1-D with [height, width, channels]
750 // bounding_boxes: 3-D with shape [batch, N, 4]
751 DimensionHandle unused;
752 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(image_size, 0), 3, &unused));
753 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(bounding_boxes, 2), 4, &unused));
754
755 c->set_output(0, c->Vector(3));
756 c->set_output(1, c->Vector(3));
757 c->set_output(2, c->MakeShape({1, 1, 4}));
758 return Status::OK();
759 });
760
761 REGISTER_OP("StatelessSampleDistortedBoundingBox")
762 .Input("image_size: T")
763 .Input("bounding_boxes: float")
764 .Input("min_object_covered: float")
765 .Input("seed: Tseed")
766 .Output("begin: T")
767 .Output("size: T")
768 .Output("bboxes: float")
769 .Attr("T: {uint8, int8, int16, int32, int64}")
770 .Attr("Tseed: {int32, int64}")
771 .Attr("aspect_ratio_range: list(float) = [0.75, 1.33]")
772 .Attr("area_range: list(float) = [0.05, 1.0]")
773 .Attr("max_attempts: int = 100")
774 .Attr("use_image_if_no_bounding_boxes: bool = false")
__anon30b5031d1302(InferenceContext* c) 775 .SetShapeFn([](InferenceContext* c) {
776 // Get inputs and validate ranks.
777 ShapeHandle image_size;
778 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &image_size));
779 ShapeHandle bounding_boxes;
780 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &bounding_boxes));
781 ShapeHandle min_object_covered;
782 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &min_object_covered));
783 ShapeHandle seed;
784 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &seed));
785 // image_size: 1-D with [height, width, channels]
786 // bounding_boxes: 3-D with shape [batch, N, 4]
787 DimensionHandle unused;
788 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(image_size, 0), 3, &unused));
789 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(bounding_boxes, 2), 4, &unused));
790 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(seed, 0), 2, &unused));
791
792 c->set_output(0, c->Vector(3));
793 c->set_output(1, c->Vector(3));
794 c->set_output(2, c->MakeShape({1, 1, 4}));
795
796 return Status::OK();
797 });
798
799 // --------------------------------------------------------------------------
800
801 // glimpse = extract_glimpse(input, size, offsets) extract the glimpse
802 // of size `size` centered at location `offsets` from the input tensor
803 // `input`.
804 //
805 // REQUIRES: input.dims() == 4
806 //
807 REGISTER_OP("ExtractGlimpse")
808 .Input("input: float")
809 .Input("size: int32")
810 .Input("offsets: float")
811 .Output("glimpse: float")
812 .Attr("centered: bool = true")
813 .Attr("normalized: bool = true")
814 .Attr("uniform_noise: bool = true")
815 .Attr("noise: string = 'uniform'")
__anon30b5031d1402(InferenceContext* c) 816 .SetShapeFn([](InferenceContext* c) {
817 ShapeHandle input;
818 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
819 ShapeHandle offsets;
820 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &offsets));
821
822 DimensionHandle batch_dim;
823 TF_RETURN_IF_ERROR(
824 c->Merge(c->Dim(input, 0), c->Dim(offsets, 0), &batch_dim));
825 DimensionHandle unused;
826 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(offsets, 1), 2, &unused));
827
828 bool uniform_noise = false;
829 TF_RETURN_IF_ERROR(c->GetAttr("uniform_noise", &uniform_noise));
830 string noise;
831 TF_RETURN_IF_ERROR(c->GetAttr("noise", &noise));
832 if (uniform_noise && (!noise.empty() && noise != "uniform")) {
833 return errors::InvalidArgument(
834 "The uniform_noise and noise should not be specified at the same "
835 "time");
836 }
837
838 return SetOutputToSizedImage(c, batch_dim, 1 /* size_input_idx */,
839 c->Dim(input, 3));
840 });
841
842 REGISTER_OP("ExtractGlimpseV2")
843 .Input("input: float")
844 .Input("size: int32")
845 .Input("offsets: float")
846 .Output("glimpse: float")
847 .Attr("centered: bool = true")
848 .Attr("normalized: bool = true")
849 .Attr("uniform_noise: bool = true")
850 .Attr("noise: string = 'uniform'")
__anon30b5031d1502(InferenceContext* c) 851 .SetShapeFn([](InferenceContext* c) {
852 ShapeHandle input;
853 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
854 ShapeHandle offsets;
855 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &offsets));
856
857 DimensionHandle batch_dim;
858 TF_RETURN_IF_ERROR(
859 c->Merge(c->Dim(input, 0), c->Dim(offsets, 0), &batch_dim));
860 DimensionHandle unused;
861 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(offsets, 1), 2, &unused));
862
863 bool uniform_noise = false;
864 TF_RETURN_IF_ERROR(c->GetAttr("uniform_noise", &uniform_noise));
865 string noise;
866 TF_RETURN_IF_ERROR(c->GetAttr("noise", &noise));
867 if (uniform_noise && (!noise.empty() && noise != "uniform")) {
868 return errors::InvalidArgument(
869 "The uniform_noise and noise should not be specified at the same "
870 "time");
871 }
872
873 return SetOutputToSizedImage(c, batch_dim, 1 /* size_input_idx */,
874 c->Dim(input, 3));
875 });
876
877 // --------------------------------------------------------------------------
878
879 REGISTER_OP("CropAndResize")
880 .Input("image: T")
881 .Input("boxes: float")
882 .Input("box_ind: int32")
883 .Input("crop_size: int32")
884 .Output("crops: float")
885 .Attr("T: {uint8, uint16, int8, int16, int32, int64, half, float, double}")
886 .Attr("method: {'bilinear', 'nearest'} = 'bilinear'")
887 .Attr("extrapolation_value: float = 0")
__anon30b5031d1602(InferenceContext* c) 888 .SetShapeFn([](InferenceContext* c) {
889 // Get inputs and validate ranks.
890 ShapeHandle input;
891 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
892 ShapeHandle boxes;
893 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &boxes));
894 ShapeHandle box_ind;
895 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &box_ind));
896
897 // boxes[0] and box_ind[0] are both num_boxes.
898 DimensionHandle num_boxes_dim;
899 TF_RETURN_IF_ERROR(
900 c->Merge(c->Dim(boxes, 0), c->Dim(box_ind, 0), &num_boxes_dim));
901
902 // boxes.dim(1) is 4.
903 DimensionHandle unused;
904 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused));
905
906 return SetOutputToSizedImage(c, num_boxes_dim, 3 /* size_input_idx */,
907 c->Dim(input, 3));
908 });
909
910 REGISTER_OP("CropAndResizeGradImage")
911 .Input("grads: float")
912 .Input("boxes: float")
913 .Input("box_ind: int32")
914 .Input("image_size: int32")
915 .Output("output: T")
916 .Attr("T: {float, half, double}")
917 .Attr("method: {'bilinear', 'nearest'} = 'bilinear'")
__anon30b5031d1702(InferenceContext* c) 918 .SetShapeFn([](InferenceContext* c) {
919 ShapeHandle out;
920 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(3, &out));
921 TF_RETURN_IF_ERROR(c->WithRank(out, 4, &out));
922 c->set_output(0, out);
923 return Status::OK();
924 });
925
926 REGISTER_OP("CropAndResizeGradBoxes")
927 .Input("grads: float")
928 .Input("image: T")
929 .Input("boxes: float")
930 .Input("box_ind: int32")
931 .Output("output: float")
932 .Attr("T: {uint8, uint16, int8, int16, int32, int64, half, float, double}")
933 .Attr("method: {'bilinear'} = 'bilinear'")
__anon30b5031d1802(InferenceContext* c) 934 .SetShapeFn([](InferenceContext* c) {
935 c->set_output(0, c->input(2));
936 return Status::OK();
937 });
938
939 // --------------------------------------------------------------------------
940
941 REGISTER_OP("NonMaxSuppression")
942 .Input("boxes: float")
943 .Input("scores: float")
944 .Input("max_output_size: int32")
945 .Output("selected_indices: int32")
946 .Attr("iou_threshold: float = 0.5")
__anon30b5031d1902(InferenceContext* c) 947 .SetShapeFn([](InferenceContext* c) {
948 // Get inputs and validate ranks.
949 ShapeHandle boxes;
950 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &boxes));
951 ShapeHandle scores;
952 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores));
953 ShapeHandle max_output_size;
954 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size));
955 // The boxes is a 2-D float Tensor of shape [num_boxes, 4].
956 DimensionHandle unused;
957 // The boxes[0] and scores[0] are both num_boxes.
958 TF_RETURN_IF_ERROR(
959 c->Merge(c->Dim(boxes, 0), c->Dim(scores, 0), &unused));
960 // The boxes[1] is 4.
961 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused));
962
963 c->set_output(0, c->Vector(c->UnknownDim()));
964 return Status::OK();
965 });
966
967 REGISTER_OP("NonMaxSuppressionV2")
968 .Input("boxes: T")
969 .Input("scores: T")
970 .Input("max_output_size: int32")
971 .Input("iou_threshold: T_threshold")
972 .Output("selected_indices: int32")
973 .Attr("T: {half, float} = DT_FLOAT")
974 .Attr("T_threshold: {half, float} = DT_FLOAT")
__anon30b5031d1a02(InferenceContext* c) 975 .SetShapeFn([](InferenceContext* c) {
976 // Get inputs and validate ranks.
977 ShapeHandle boxes;
978 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &boxes));
979 ShapeHandle scores;
980 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores));
981 ShapeHandle max_output_size;
982 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size));
983 ShapeHandle iou_threshold;
984 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &iou_threshold));
985 // The boxes is a 2-D float Tensor of shape [num_boxes, 4].
986 DimensionHandle unused;
987 // The boxes[0] and scores[0] are both num_boxes.
988 TF_RETURN_IF_ERROR(
989 c->Merge(c->Dim(boxes, 0), c->Dim(scores, 0), &unused));
990 // The boxes[1] is 4.
991 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused));
992
993 c->set_output(0, c->Vector(c->UnknownDim()));
994 return Status::OK();
995 });
996
997 REGISTER_OP("NonMaxSuppressionV3")
998 .Input("boxes: T")
999 .Input("scores: T")
1000 .Input("max_output_size: int32")
1001 .Input("iou_threshold: T_threshold")
1002 .Input("score_threshold: T_threshold")
1003 .Output("selected_indices: int32")
1004 .Attr("T: {half, float} = DT_FLOAT")
1005 .Attr("T_threshold: {half, float} = DT_FLOAT")
1006 .SetShapeFn(NMSShapeFn);
1007
1008 REGISTER_OP("NonMaxSuppressionV4")
1009 .Input("boxes: T")
1010 .Input("scores: T")
1011 .Input("max_output_size: int32")
1012 .Input("iou_threshold: T_threshold")
1013 .Input("score_threshold: T_threshold")
1014 .Output("selected_indices: int32")
1015 .Output("valid_outputs: int32")
1016 .Attr("T: {half, float} = DT_FLOAT")
1017 .Attr("T_threshold: {half, float} = DT_FLOAT")
1018 .Attr("pad_to_max_output_size: bool = false")
__anon30b5031d1b02(InferenceContext* c) 1019 .SetShapeFn([](InferenceContext* c) {
1020 TF_RETURN_IF_ERROR(NMSShapeFn(c));
1021
1022 bool pad_to_max;
1023 TF_RETURN_IF_ERROR(c->GetAttr("pad_to_max_output_size", &pad_to_max));
1024 if (pad_to_max) {
1025 // If padded, overwrite the shape of the output to be static.
1026 DimensionHandle output_dim;
1027 TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &output_dim));
1028 c->set_output(0, c->MakeShape({output_dim}));
1029 }
1030 c->set_output(1, c->MakeShape({}));
1031 return Status::OK();
1032 });
1033
1034 REGISTER_OP("NonMaxSuppressionV5")
1035 .Input("boxes: T")
1036 .Input("scores: T")
1037 .Input("max_output_size: int32")
1038 .Input("iou_threshold: T")
1039 .Input("score_threshold: T")
1040 .Input("soft_nms_sigma: T")
1041 .Output("selected_indices: int32")
1042 .Output("selected_scores: T")
1043 .Output("valid_outputs: int32")
1044 .Attr("T: {half, float} = DT_FLOAT")
1045 .Attr("pad_to_max_output_size: bool = false")
__anon30b5031d1c02(InferenceContext* c) 1046 .SetShapeFn([](InferenceContext* c) {
1047 TF_RETURN_IF_ERROR(SoftNMSShapeFn(c));
1048
1049 bool pad_to_max;
1050 TF_RETURN_IF_ERROR(c->GetAttr("pad_to_max_output_size", &pad_to_max));
1051 if (pad_to_max) {
1052 // If padded, overwrite the shape of the output to be static.
1053 DimensionHandle output_dim;
1054 TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &output_dim));
1055 c->set_output(0, c->MakeShape({output_dim}));
1056 c->set_output(1, c->MakeShape({output_dim}));
1057 }
1058
1059 c->set_output(2, c->MakeShape({}));
1060 return Status::OK();
1061 });
1062
1063 REGISTER_OP("NonMaxSuppressionWithOverlaps")
1064 .Input("overlaps: float")
1065 .Input("scores: float")
1066 .Input("max_output_size: int32")
1067 .Input("overlap_threshold: float")
1068 .Input("score_threshold: float")
1069 .Output("selected_indices: int32")
__anon30b5031d1d02(InferenceContext* c) 1070 .SetShapeFn([](InferenceContext* c) {
1071 // Get inputs and validate ranks.
1072 ShapeHandle overlaps;
1073 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &overlaps));
1074 ShapeHandle scores;
1075 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores));
1076 ShapeHandle max_output_size;
1077 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size));
1078 ShapeHandle overlap_threshold;
1079 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &overlap_threshold));
1080 ShapeHandle score_threshold;
1081 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &score_threshold));
1082 // The boxes is a 2-D float Tensor of shape [num_boxes, 4].
1083 DimensionHandle unused;
1084 // The boxes[0] and scores[0] are both num_boxes.
1085 TF_RETURN_IF_ERROR(
1086 c->Merge(c->Dim(overlaps, 0), c->Dim(scores, 0), &unused));
1087 // The boxes[1] is 4.
1088 TF_RETURN_IF_ERROR(
1089 c->Merge(c->Dim(overlaps, 0), c->Dim(overlaps, 1), &unused));
1090
1091 c->set_output(0, c->Vector(c->UnknownDim()));
1092 return Status::OK();
1093 });
1094
1095 REGISTER_OP("CombinedNonMaxSuppression")
1096 .Input("boxes: float")
1097 .Input("scores: float")
1098 .Input("max_output_size_per_class: int32")
1099 .Input("max_total_size: int32")
1100 .Input("iou_threshold: float")
1101 .Input("score_threshold: float")
1102 .Output("nmsed_boxes: float")
1103 .Output("nmsed_scores: float")
1104 .Output("nmsed_classes: float")
1105 .Output("valid_detections: int32")
1106 .Attr("pad_per_class: bool = false")
1107 .Attr("clip_boxes: bool = true")
1108 .SetShapeFn(CombinedNMSShapeFn);
1109
1110 REGISTER_OP("GenerateBoundingBoxProposals")
1111 .Input("scores: float")
1112 .Input("bbox_deltas: float")
1113 .Input("image_info: float")
1114 .Input("anchors: float")
1115 .Input("nms_threshold: float")
1116 .Input("pre_nms_topn: int32")
1117 .Input("min_size: float")
1118 .Output("rois: float")
1119 .Output("roi_probabilities: float")
1120 .Attr("post_nms_topn: int = 300")
__anon30b5031d1e02(InferenceContext* c) 1121 .SetShapeFn([](InferenceContext* c) -> Status {
1122 // make sure input tensors have are correct rank
1123 ShapeHandle scores, images, bounding_boxes, anchors, nms_threshold,
1124 n_pre_nms, min_box_size;
1125 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &scores)); //(N, H, W, A)
1126 TF_RETURN_IF_ERROR(
1127 c->WithRank(c->input(1), 4, &bounding_boxes)); //(N,H,W,A4)
1128 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &images)); // (N,5)
1129 auto im_info = c->Dim(images, 1);
1130 TF_RETURN_IF_ERROR(c->WithValue(im_info, 5, &im_info));
1131 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 3, &anchors)); // (A4)
1132 // check scalar tensors
1133 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &nms_threshold));
1134 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &n_pre_nms));
1135 TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &min_box_size));
1136
1137 // TODO(skama): verify that the inputs are compatible
1138 int post_nms_top_n;
1139 TF_RETURN_IF_ERROR(c->GetAttr("post_nms_topn", &post_nms_top_n));
1140 auto roi_shape = c->MakeShape(
1141 {c->Dim(scores, 0), post_nms_top_n, 4}); //(N,post_nms_top_n,4)
1142 auto prob_shape = c->MakeShape(
1143 {c->Dim(scores, 0), post_nms_top_n}); // (N,post_nms_top_n)
1144 c->set_output(0, roi_shape);
1145 c->set_output(1, prob_shape);
1146 return Status::OK();
1147 });
1148
1149 // V3 op supports fill_value.
1150 // V2 op supports output_shape.
1151 // V1 op is in contrib.
1152 REGISTER_OP("ImageProjectiveTransformV2")
1153 .Input("images: dtype")
1154 .Input("transforms: float32")
1155 .Input("output_shape: int32")
1156 .Attr("dtype: {uint8, int32, int64, float16, float32, float64}")
1157 .Attr("interpolation: string")
1158 .Attr("fill_mode: string = 'CONSTANT'")
1159 .Output("transformed_images: dtype")
__anon30b5031d1f02(InferenceContext* c) 1160 .SetShapeFn([](InferenceContext* c) {
1161 ShapeHandle input;
1162 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
1163 return SetOutputToSizedImage(c, c->Dim(input, 0), 2 /* size_input_idx */,
1164 c->Dim(input, 3));
1165 });
1166
1167 REGISTER_OP("ImageProjectiveTransformV3")
1168 .Input("images: dtype")
1169 .Input("transforms: float32")
1170 .Input("output_shape: int32")
1171 .Input("fill_value: float32")
1172 .Attr("dtype: {uint8, int32, int64, float16, float32, float64}")
1173 .Attr("interpolation: string")
1174 .Attr("fill_mode: string = 'CONSTANT'")
1175 .Output("transformed_images: dtype")
__anon30b5031d2002(InferenceContext* c) 1176 .SetShapeFn([](InferenceContext* c) {
1177 ShapeHandle input;
1178 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
1179 return SetOutputToSizedImage(c, c->Dim(input, 0), 2 /* size_input_idx */,
1180 c->Dim(input, 3));
1181 });
1182
1183 } // namespace tensorflow
1184