• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/common_shape_fns.h"
17 #include "tensorflow/core/framework/op.h"
18 #include "tensorflow/core/framework/shape_inference.h"
19 
20 namespace tensorflow {
21 
22 using shape_inference::DimensionHandle;
23 using shape_inference::InferenceContext;
24 using shape_inference::ShapeHandle;
25 
26 namespace {
27 
28 // Sets output[0] to shape [batch_dim,height,width,channel_dim], where
29 // height and width come from the size_tensor.
SetOutputToSizedImage(InferenceContext * c,DimensionHandle batch_dim,int size_input_idx,DimensionHandle channel_dim)30 Status SetOutputToSizedImage(InferenceContext* c, DimensionHandle batch_dim,
31                              int size_input_idx, DimensionHandle channel_dim) {
32   // Verify shape of size input.
33   ShapeHandle size;
34   TF_RETURN_IF_ERROR(c->WithRank(c->input(size_input_idx), 1, &size));
35   DimensionHandle unused;
36   TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 2, &unused));
37 
38   // Get size values from the size tensor.
39   const Tensor* size_tensor = c->input_tensor(size_input_idx);
40   DimensionHandle width;
41   DimensionHandle height;
42   if (size_tensor == nullptr) {
43     width = c->UnknownDim();
44     height = c->UnknownDim();
45   } else {
46     // TODO(petewarden) - Remove once we have constant evaluation in C++ only.
47     if (size_tensor->dtype() != DT_INT32) {
48       return errors::InvalidArgument(
49           "Bad size input type for SetOutputToSizedImage: Expected DT_INT32 "
50           "but got ",
51           DataTypeString(size_tensor->dtype()), " for input #", size_input_idx,
52           " in ", c->DebugString());
53     }
54     auto vec = size_tensor->vec<int32>();
55     height = c->MakeDim(vec(0));
56     width = c->MakeDim(vec(1));
57   }
58   c->set_output(0, c->MakeShape({batch_dim, height, width, channel_dim}));
59   return Status::OK();
60 }
61 
ResizeShapeFn(InferenceContext * c)62 Status ResizeShapeFn(InferenceContext* c) {
63   ShapeHandle input;
64   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
65   return SetOutputToSizedImage(c, c->Dim(input, 0), 1 /* size_input_idx */,
66                                c->Dim(input, 3));
67 }
68 
DecodeImageShapeFn(InferenceContext * c)69 Status DecodeImageShapeFn(InferenceContext* c) {
70   ShapeHandle unused;
71   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
72   DimensionHandle channels_dim;
73   int32 channels;
74   TF_RETURN_IF_ERROR(c->GetAttr("channels", &channels));
75   if (channels == 0) {
76     channels_dim = c->UnknownDim();
77   } else {
78     if (channels < 0) {
79       return errors::InvalidArgument("channels must be non-negative, got ",
80                                      channels);
81     }
82     channels_dim = c->MakeDim(channels);
83   }
84 
85   c->set_output(0, c->MakeShape({InferenceContext::kUnknownDim,
86                                  InferenceContext::kUnknownDim, channels_dim}));
87   return Status::OK();
88 }
89 
DecodeImageV2ShapeFn(InferenceContext * c)90 Status DecodeImageV2ShapeFn(InferenceContext* c) {
91   ShapeHandle unused;
92   int32 channels;
93   bool expand_animations;
94   DimensionHandle channels_dim;
95 
96   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
97   TF_RETURN_IF_ERROR(c->GetAttr("channels", &channels));
98   TF_RETURN_IF_ERROR(c->GetAttr("expand_animations", &expand_animations));
99 
100   if (channels == 0) {
101     channels_dim = c->UnknownDim();
102   } else {
103     if (channels < 0) {
104       return errors::InvalidArgument("channels must be non-negative, got ",
105                                      channels);
106     }
107     channels_dim = c->MakeDim(channels);
108   }
109 
110   // `expand_animations` set to true will return 4-D shapes for GIF. 3-D shapes
111   // will be returned for jpg, png, and bmp. `expand_animations` set to false
112   // will always return 3-D shapes for all (jpg, png, bmp, gif).
113   if (expand_animations) {
114     c->set_output(0, c->UnknownShape());
115     return Status::OK();
116   } else {
117     c->set_output(0,
118                   c->MakeShape({InferenceContext::kUnknownDim,
119                                 InferenceContext::kUnknownDim, channels_dim}));
120     return Status::OK();
121   }
122 }
123 
EncodeImageShapeFn(InferenceContext * c)124 Status EncodeImageShapeFn(InferenceContext* c) {
125   ShapeHandle unused;
126   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &unused));
127   c->set_output(0, c->Scalar());
128   return Status::OK();
129 }
130 
ColorspaceShapeFn(InferenceContext * c)131 Status ColorspaceShapeFn(InferenceContext* c) {
132   ShapeHandle input;
133   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input));
134 
135   // The last dimension value is always 3.
136   DimensionHandle last_dim;
137   TF_RETURN_IF_ERROR(c->WithValue(c->Dim(input, -1), 3, &last_dim));
138   ShapeHandle out;
139   TF_RETURN_IF_ERROR(c->ReplaceDim(input, -1, last_dim, &out));
140   c->set_output(0, out);
141 
142   return Status::OK();
143 }
144 
NMSShapeFn(InferenceContext * c)145 Status NMSShapeFn(InferenceContext* c) {
146   // Get inputs and validate ranks.
147   ShapeHandle boxes;
148   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &boxes));
149   ShapeHandle scores;
150   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores));
151   ShapeHandle max_output_size;
152   TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size));
153   ShapeHandle iou_threshold;
154   TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &iou_threshold));
155   ShapeHandle score_threshold;
156   TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &score_threshold));
157   // The boxes is a 2-D float Tensor of shape [num_boxes, 4].
158   DimensionHandle unused;
159   // The boxes[0] and scores[0] are both num_boxes.
160   TF_RETURN_IF_ERROR(c->Merge(c->Dim(boxes, 0), c->Dim(scores, 0), &unused));
161   // The boxes[1] is 4.
162   TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused));
163 
164   c->set_output(0, c->Vector(c->UnknownDim()));
165   return Status::OK();
166 }
167 
SoftNMSShapeFn(InferenceContext * c)168 Status SoftNMSShapeFn(InferenceContext* c) {
169   // Get inputs and validate ranks.
170   ShapeHandle boxes;
171   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &boxes));
172   ShapeHandle scores;
173   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores));
174   ShapeHandle max_output_size;
175   TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size));
176   ShapeHandle iou_threshold;
177   TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &iou_threshold));
178   ShapeHandle score_threshold;
179   TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &score_threshold));
180   ShapeHandle soft_nms_sigma;
181   TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &soft_nms_sigma));
182   // The boxes is a 2-D float Tensor of shape [num_boxes, 4].
183   DimensionHandle unused;
184   // The boxes[0] and scores[0] are both num_boxes.
185   TF_RETURN_IF_ERROR(c->Merge(c->Dim(boxes, 0), c->Dim(scores, 0), &unused));
186   // The boxes[1] is 4.
187   TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused));
188 
189   c->set_output(0, c->Vector(c->UnknownDim()));
190   c->set_output(1, c->Vector(c->UnknownDim()));
191   return Status::OK();
192 }
193 
CombinedNMSShapeFn(InferenceContext * c)194 Status CombinedNMSShapeFn(InferenceContext* c) {
195   // Get inputs and validate ranks
196   ShapeHandle boxes;
197   // boxes is a tensor of Dimensions [batch_size, num_anchors, q, 4]
198   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &boxes));
199   ShapeHandle scores;
200   // scores is a tensor of Dimensions [batch_size, num_anchors, num_classes]
201   TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &scores));
202   ShapeHandle max_output_size_per_class;
203   TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size_per_class));
204   ShapeHandle max_total_size;
205   TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &max_total_size));
206   ShapeHandle unused_shape;
207   TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused_shape));
208   TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused_shape));
209 
210   DimensionHandle unused;
211   // boxes[0] and scores[0] are both batch_size
212   TF_RETURN_IF_ERROR(c->Merge(c->Dim(boxes, 0), c->Dim(scores, 0), &unused));
213   // boxes[1] and scores[1] are both num_anchors
214   TF_RETURN_IF_ERROR(c->Merge(c->Dim(boxes, 1), c->Dim(scores, 1), &unused));
215   // The boxes[3] is 4.
216   TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 3), 4, &unused));
217 
218   DimensionHandle d = c->Dim(boxes, 2);
219   DimensionHandle class_dim = c->Dim(scores, 2);
220   if (c->ValueKnown(d) && c->ValueKnown(class_dim)) {
221     if (c->Value(d) != 1 && c->Value(d) != c->Value(class_dim)) {
222       return errors::InvalidArgument(
223           "third dimension of boxes must be either "
224           "1 or equal to the third dimension of scores");
225     }
226   }
227   DimensionHandle output_dim;
228   DimensionHandle batch_dim = c->Dim(boxes, 0);
229 
230   TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(3, &output_dim));
231   if (c->ValueKnown(output_dim) && c->Value(output_dim) <= 0) {
232     return errors::InvalidArgument("max_total_size should be > 0 ");
233   }
234   DimensionHandle size_per_class;
235   TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &size_per_class));
236 
237   int64 output_size;
238   bool pad_per_class;
239   TF_RETURN_IF_ERROR(c->GetAttr("pad_per_class", &pad_per_class));
240   if (!pad_per_class) {
241     output_size = c->Value(output_dim);
242   } else {
243     if (c->ValueKnown(size_per_class) && c->Value(size_per_class) <= 0) {
244       return errors::InvalidArgument(
245           "max_output_size_per_class must be > 0 "
246           "if pad_per_class is set to true ");
247     }
248     output_size = std::min(c->Value(output_dim),
249                            c->Value(size_per_class) * c->Value(class_dim));
250   }
251   c->set_output(0, c->MakeShape({batch_dim, output_size, 4}));
252   c->set_output(1, c->MakeShape({batch_dim, output_size}));
253   c->set_output(2, c->MakeShape({batch_dim, output_size}));
254   c->set_output(3, c->Vector(batch_dim));
255   return Status::OK();
256 }
257 
258 }  // namespace
259 
260 // --------------------------------------------------------------------------
// Registers the "ResizeArea" op.
//   Inputs:  images (type T), size (int32).
//   Output:  resized_images, always float regardless of T.
//   Attrs:   T (allowed input types), align_corners (bool, default false).
// Shape inference is delegated to ResizeShapeFn.
261 REGISTER_OP("ResizeArea")
262     .Input("images: T")
263     .Input("size: int32")
264     .Output("resized_images: float")
    // The two adjacent string literals below are concatenated by the compiler
    // into a single type-list spec.
265     .Attr(
266         "T: {int8, uint8, int16, uint16, int32, int64, half, float, double,"
267         "bfloat16}")
268     .Attr("align_corners: bool = false")
269     .SetShapeFn(ResizeShapeFn);
270 
271 // --------------------------------------------------------------------------
// Registers the "ResizeBicubic" op.
//   Inputs:  images (type T), size (int32).
//   Output:  resized_images, always float regardless of T.
//   Attrs:   T (allowed input types), align_corners and half_pixel_centers
//            (both bool, default false).
// Shape inference is delegated to ResizeShapeFn.
272 REGISTER_OP("ResizeBicubic")
273     .Input("images: T")
274     .Input("size: int32")
275     .Output("resized_images: float")
    // The two adjacent string literals below are concatenated by the compiler
    // into a single type-list spec.
276     .Attr(
277         "T: {int8, uint8, int16, uint16, int32, int64, half, float, double,"
278         "bfloat16}")
279     .Attr("align_corners: bool = false")
280     .Attr("half_pixel_centers: bool = false")
281     .SetShapeFn(ResizeShapeFn);
282 
283 // --------------------------------------------------------------------------
284 REGISTER_OP("ResizeBicubicGrad")
285     .Input("grads: float")
286     .Input("original_image: T")
287     .Output("output: T")
288     .Attr("T: {float, double}")
289     .Attr("align_corners: bool = false")
290     .Attr("half_pixel_centers: bool = false")
__anon30b5031d0202(InferenceContext* c) 291     .SetShapeFn([](InferenceContext* c) {
292       c->set_output(0, c->input(1));
293       return Status::OK();
294     });
295 
296 // --------------------------------------------------------------------------
297 REGISTER_OP("ResizeBilinear")
298     .Input("images: T")
299     .Input("size: int32")
300     .Output("resized_images: float")
301     .Attr(
302         "T: {int8, uint8, int16, uint16, int32, int64, bfloat16, half, "
303         "float, double, bfloat16}")
304     .Attr("align_corners: bool = false")
305     .Attr("half_pixel_centers: bool = false")
306     .SetShapeFn(ResizeShapeFn);
307 
308 // --------------------------------------------------------------------------
// Registers the "ScaleAndTranslate" op.
//   Inputs:  images (type T), size (int32), scale (float),
//            translation (float).
//   Output:  resized_images, always float regardless of T.
//   Attrs:   T (allowed input types), kernel_type (string, default
//            'lanczos3'), antialias (bool, default true).
// Shape inference is delegated to ResizeShapeFn, which only inspects inputs
// 0 and 1 (images and size).
309 REGISTER_OP("ScaleAndTranslate")
310     .Input("images: T")
311     .Input("size: int32")
312     .Input("scale: float")
313     .Input("translation: float")
314     .Output("resized_images: float")
315     .Attr(
316         "T: {int8, uint8, int16, uint16, int32, int64, bfloat16, half, "
317         "float, double}")
318     .Attr("kernel_type: string = 'lanczos3'")
319     .Attr("antialias: bool = true")
320     .SetShapeFn(ResizeShapeFn);
321 
322 // --------------------------------------------------------------------------
323 REGISTER_OP("QuantizedResizeBilinear")
324     .Input("images: T")
325     .Input("size: int32")
326     .Input("min: float")
327     .Input("max: float")
328     .Output("resized_images: T")
329     .Output("out_min: float")
330     .Output("out_max: float")
331     .Attr("T: {quint8, qint32, float}")
332     .Attr("align_corners: bool = false")
333     .Attr("half_pixel_centers: bool = false")
__anon30b5031d0302(InferenceContext* c) 334     .SetShapeFn([](InferenceContext* c) {
335       TF_RETURN_IF_ERROR(ResizeShapeFn(c));
336       ShapeHandle min_shape;
337       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &min_shape));
338       ShapeHandle max_shape;
339       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &max_shape));
340       c->set_output(1, c->MakeShape({}));
341       c->set_output(2, c->MakeShape({}));
342       return Status::OK();
343     });
344 
345 // --------------------------------------------------------------------------
346 REGISTER_OP("ResizeBilinearGrad")
347     .Input("grads: float")
348     .Input("original_image: T")
349     .Output("output: T")
350     .Attr("T: {float, bfloat16, half, double}")
351     .Attr("align_corners: bool = false")
352     .Attr("half_pixel_centers: bool = false")
__anon30b5031d0402(InferenceContext* c) 353     .SetShapeFn([](InferenceContext* c) {
354       c->set_output(0, c->input(1));
355       return Status::OK();
356     });
357 
358 // --------------------------------------------------------------------------
359 REGISTER_OP("ScaleAndTranslateGrad")
360     .Input("grads: T")
361     .Input("original_image: T")
362     .Input("scale: float")
363     .Input("translation: float")
364     .Output("output: T")
365     .Attr("T: {float}")
366     .Attr("kernel_type: string = 'lanczos3'")
367     .Attr("antialias: bool = true")
__anon30b5031d0502(InferenceContext* c) 368     .SetShapeFn([](InferenceContext* c) {
369       c->set_output(0, c->input(1));
370       return Status::OK();
371     });
372 
373 // --------------------------------------------------------------------------
// Registers the "ResizeNearestNeighbor" op.
//   Inputs:  images (type T), size (int32).
//   Output:  resized_images, of the same type T as the input.
//   Attrs:   T (allowed types), align_corners and half_pixel_centers
//            (both bool, default false).
// Shape inference is delegated to ResizeShapeFn.
374 REGISTER_OP("ResizeNearestNeighbor")
375     .Input("images: T")
376     .Input("size: int32")
377     .Output("resized_images: T")
    // The two adjacent string literals below are concatenated by the compiler
    // into a single type-list spec.
378     .Attr(
379         "T: {int8, uint8, int16, uint16, int32, int64, half, float,"
380         "double, bfloat16}")
381     .Attr("align_corners: bool = false")
382     .Attr("half_pixel_centers: bool = false")
383     .SetShapeFn(ResizeShapeFn);
384 
385 // --------------------------------------------------------------------------
386 REGISTER_OP("ResizeNearestNeighborGrad")
387     .Input("grads: T")
388     .Input("size: int32")
389     .Output("output: T")
390     .Attr("T: {uint8, int8, int32, half, float, double, bfloat16}")
391     .Attr("align_corners: bool = false")
392     .Attr("half_pixel_centers: bool = false")
__anon30b5031d0602(InferenceContext* c) 393     .SetShapeFn([](InferenceContext* c) {
394       ShapeHandle input;
395       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
396       ShapeHandle unused;
397       DimensionHandle unused_dim;
398       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
399       TF_RETURN_IF_ERROR(c->WithValue(c->Dim(unused, 0), 2, &unused_dim));
400       const Tensor* size = c->input_tensor(1);
401       if (size == nullptr) {
402         TF_RETURN_IF_ERROR(c->ReplaceDim(input, 1, c->UnknownDim(), &input));
403         TF_RETURN_IF_ERROR(c->ReplaceDim(input, 2, c->UnknownDim(), &input));
404       } else {
405         auto size_vec = size->vec<int32>();
406         TF_RETURN_IF_ERROR(
407             c->ReplaceDim(input, 1, c->MakeDim(size_vec(0)), &input));
408         TF_RETURN_IF_ERROR(
409             c->ReplaceDim(input, 2, c->MakeDim(size_vec(1)), &input));
410       }
411       c->set_output(0, input);
412       return Status::OK();
413     });
414 
415 // --------------------------------------------------------------------------
416 REGISTER_OP("RandomCrop")
417     .Input("image: T")
418     .Input("size: int64")
419     .Output("output: T")
420     .Attr("T: {uint8, int8, int16, int32, int64, float, double}")
421     .Attr("seed: int = 0")
422     .Attr("seed2: int = 0")
423     .SetIsStateful()
424     .Deprecated(8, "Random crop is now pure Python")
__anon30b5031d0702(InferenceContext* c) 425     .SetShapeFn([](InferenceContext* c) {
426       ShapeHandle image;
427       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &image));
428       DimensionHandle channels = c->Dim(image, -1);
429 
430       ShapeHandle unused;
431       TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->Vector(2), &unused));
432 
433       const Tensor* size = c->input_tensor(1);
434       DimensionHandle h;
435       DimensionHandle w;
436       if (size == nullptr) {
437         h = c->UnknownDim();
438         w = c->UnknownDim();
439       } else {
440         auto size_vec = size->vec<int64>();
441         h = c->MakeDim(size_vec(0));
442         w = c->MakeDim(size_vec(1));
443       }
444       c->set_output(0, c->MakeShape({h, w, channels}));
445       return Status::OK();
446     });
447 // TODO(shlens): Support variable rank in RandomCrop.
448 
449 // --------------------------------------------------------------------------
// Registers the "DecodeImage" op.
//   Input:   contents (string).
//   Output:  image, whose element type is given by the `dtype` attr.
//   Attrs:   channels (int, default 0), dtype (uint8, uint16 or float32;
//            default DT_UINT8), expand_animations (bool, default true).
// Shape inference is delegated to DecodeImageV2ShapeFn.
450 REGISTER_OP("DecodeImage")
451     .Input("contents: string")
452     // Setting `channels` to 0 means using the inherent number of channels in
453     // the image.
454     .Attr("channels: int = 0")
455     .Attr("dtype: {uint8, uint16, float32} = DT_UINT8")
456     .Output("image: dtype")
457     .Attr("expand_animations: bool = true")
458     .SetShapeFn(DecodeImageV2ShapeFn);
459 
460 // --------------------------------------------------------------------------
// Registers the "DecodeJpeg" op.
//   Input:   contents (string).
//   Output:  image (uint8).
//   Attrs:   channels (int, default 0), ratio (int, default 1),
//            fancy_upscaling (bool, default true), try_recover_truncated
//            (bool, default false), acceptable_fraction (float, default 1.0),
//            dct_method (string, default '').
// Shape inference is delegated to DecodeImageShapeFn, which reads only the
// `channels` attr.
461 REGISTER_OP("DecodeJpeg")
462     .Input("contents: string")
463     .Attr("channels: int = 0")
464     .Attr("ratio: int = 1")
465     .Attr("fancy_upscaling: bool = true")
466     .Attr("try_recover_truncated: bool = false")
467     .Attr("acceptable_fraction: float = 1.0")
468     .Attr("dct_method: string = ''")
469     .Output("image: uint8")
470     .SetShapeFn(DecodeImageShapeFn);
471 
472 // --------------------------------------------------------------------------
473 REGISTER_OP("DecodeAndCropJpeg")
474     .Input("contents: string")
475     .Input("crop_window: int32")
476     .Attr("channels: int = 0")
477     .Attr("ratio: int = 1")
478     .Attr("fancy_upscaling: bool = true")
479     .Attr("try_recover_truncated: bool = false")
480     .Attr("acceptable_fraction: float = 1.0")
481     .Attr("dct_method: string = ''")
482     .Output("image: uint8")
__anon30b5031d0802(InferenceContext* c) 483     .SetShapeFn([](InferenceContext* c) {
484       ShapeHandle unused;
485       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
486       DimensionHandle channels_dim = c->UnknownDim();
487       DimensionHandle h = c->UnknownDim();
488       DimensionHandle w = c->UnknownDim();
489 
490       int32 channels;
491       TF_RETURN_IF_ERROR(c->GetAttr("channels", &channels));
492       if (channels != 0) {
493         if (channels < 0) {
494           return errors::InvalidArgument("channels must be non-negative, got ",
495                                          channels);
496         }
497         channels_dim = c->MakeDim(channels);
498       }
499 
500       DimensionHandle unused_dim;
501       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
502       TF_RETURN_IF_ERROR(c->WithValue(c->Dim(unused, 0), 4, &unused_dim));
503 
504       const Tensor* crop_window = c->input_tensor(1);
505       if (crop_window != nullptr) {
506         auto crop_window_vec = crop_window->vec<int32>();
507         h = c->MakeDim(crop_window_vec(2));
508         w = c->MakeDim(crop_window_vec(3));
509       }
510       c->set_output(0, c->MakeShape({h, w, channels_dim}));
511       return Status::OK();
512     });
513 
514 // --------------------------------------------------------------------------
// Registers the "EncodeJpeg" op.
//   Input:   image (uint8).
//   Output:  contents (string).
//   Attrs:   format ('', 'grayscale' or 'rgb'; default ''), quality (int,
//            default 95), progressive (bool, default false), optimize_size
//            (bool, default false), chroma_downsampling (bool, default true),
//            density_unit ('in' or 'cm'; default 'in'), x_density and
//            y_density (int, default 300), xmp_metadata (string, default '').
// Shape inference is delegated to EncodeImageShapeFn.
515 REGISTER_OP("EncodeJpeg")
516     .Input("image: uint8")
517     .Attr("format: {'', 'grayscale', 'rgb'} = ''")
518     .Attr("quality: int = 95")
519     .Attr("progressive: bool = false")
520     .Attr("optimize_size: bool = false")
521     .Attr("chroma_downsampling: bool = true")
522     .Attr("density_unit: {'in', 'cm'} = 'in'")
523     .Attr("x_density: int = 300")
524     .Attr("y_density: int = 300")
525     .Attr("xmp_metadata: string = ''")
526     .Output("contents: string")
527     .SetShapeFn(EncodeImageShapeFn);
528 
529 // --------------------------------------------------------------------------
// Registers the "EncodeJpegVariableQuality" op.
//   Inputs:  images (uint8), quality (int32).
//   Output:  contents (string).
// Shape inference is delegated to EncodeImageShapeFn, which only inspects
// input 0; the shape of `quality` is not validated here.
530 REGISTER_OP("EncodeJpegVariableQuality")
531     .Input("images: uint8")
532     .Input("quality: int32")
533     .Output("contents: string")
534     .SetShapeFn(EncodeImageShapeFn);
535 
536 // --------------------------------------------------------------------------
537 REGISTER_OP("ExtractJpegShape")
538     .Input("contents: string")
539     .Output("image_shape: output_type")
540     .Attr("output_type: {int32, int64} = DT_INT32")
__anon30b5031d0902(InferenceContext* c) 541     .SetShapeFn([](InferenceContext* c) {
542       ShapeHandle unused;
543       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
544       c->set_output(0, c->Vector(3));
545       return Status::OK();
546     });
547 
548 // --------------------------------------------------------------------------
549 REGISTER_OP("AdjustContrast")
550     .Input("images: T")
551     .Input("contrast_factor: float")
552     .Input("min_value: float")
553     .Input("max_value: float")
554     .Output("output: float")
555     .Attr("T: {uint8, int8, int16, int32, int64, float, double}")
556     .Deprecated(2, "Use AdjustContrastv2 instead")
__anon30b5031d0a02(InferenceContext* c) 557     .SetShapeFn([](InferenceContext* c) {
558       // The contrast_factor, min_value, max_value should be scalar only.
559       ShapeHandle unused;
560       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
561       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
562       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
563       return shape_inference::UnchangedShapeWithRankAtLeast(c, 3);
564     });
565 
566 // --------------------------------------------------------------------------
567 REGISTER_OP("AdjustContrastv2")
568     .Input("images: T")
569     .Input("contrast_factor: float")
570     .Output("output: T")
571     .Attr("T: {half, float} = DT_FLOAT")
__anon30b5031d0b02(InferenceContext* c) 572     .SetShapeFn([](InferenceContext* c) {
573       // The contrast_factor should be scalar only.
574       ShapeHandle unused;
575       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
576       return shape_inference::UnchangedShapeWithRankAtLeast(c, 3);
577     });
578 
579 // --------------------------------------------------------------------------
580 REGISTER_OP("AdjustHue")
581     .Input("images: T")
582     .Input("delta: float")
583     .Output("output: T")
584     .Attr("T: {half, float} = DT_FLOAT")
__anon30b5031d0c02(InferenceContext* c) 585     .SetShapeFn([](InferenceContext* c) {
586       return shape_inference::UnchangedShapeWithRankAtLeast(c, 3);
587     });
588 
589 // --------------------------------------------------------------------------
590 REGISTER_OP("AdjustSaturation")
591     .Input("images: T")
592     .Input("scale: float")
593     .Output("output: T")
594     .Attr("T: {half, float} = DT_FLOAT")
__anon30b5031d0d02(InferenceContext* c) 595     .SetShapeFn([](InferenceContext* c) {
596       return shape_inference::UnchangedShapeWithRankAtLeast(c, 3);
597     });
598 
599 // --------------------------------------------------------------------------
// Registers the "DecodePng" op.
//   Input:   contents (string).
//   Output:  image, whose element type is given by the `dtype` attr.
//   Attrs:   channels (int, default 0), dtype (uint8 or uint16; default
//            DT_UINT8).
// Shape inference is delegated to DecodeImageShapeFn.
600 REGISTER_OP("DecodePng")
601     .Input("contents: string")
602     .Attr("channels: int = 0")
603     .Attr("dtype: {uint8, uint16} = DT_UINT8")
604     .Output("image: dtype")
605     .SetShapeFn(DecodeImageShapeFn);
606 
607 // --------------------------------------------------------------------------
// Registers the "EncodePng" op.
//   Input:   image (type T).
//   Output:  contents (string).
//   Attrs:   compression (int, default -1), T (uint8 or uint16; default
//            DT_UINT8).
// Shape inference is delegated to EncodeImageShapeFn.
608 REGISTER_OP("EncodePng")
609     .Attr("compression: int = -1")
610     .Attr("T: {uint8, uint16} = DT_UINT8")
611     .Input("image: T")
612     .Output("contents: string")
613     .SetShapeFn(EncodeImageShapeFn);
614 
615 // --------------------------------------------------------------------------
// Registers the "DecodeBmp" op.
//   Input:   contents (string).
//   Output:  image (uint8).
//   Attr:    channels (int, default 0).
// Shape inference is delegated to DecodeImageShapeFn.
616 REGISTER_OP("DecodeBmp")
617     .Input("contents: string")
618     .Output("image: uint8")
619     .Attr("channels: int = 0")
620     .SetShapeFn(DecodeImageShapeFn);
621 
622 // --------------------------------------------------------------------------
623 REGISTER_OP("DecodeGif")
624     .Input("contents: string")
625     .Output("image: uint8")
__anon30b5031d0e02(InferenceContext* c) 626     .SetShapeFn([](InferenceContext* c) {
627       ShapeHandle unused;
628       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
629       c->set_output(0, c->MakeShape({InferenceContext::kUnknownDim,
630                                      InferenceContext::kUnknownDim,
631                                      InferenceContext::kUnknownDim, 3}));
632       return Status::OK();
633     });
634 
635 // --------------------------------------------------------------------------
// Registers the "RGBToHSV" op.
//   Input:   images (type T).
//   Output:  output (type T).
//   Attr:    T (half, bfloat16, float or double; default DT_FLOAT).
// Shape inference is delegated to ColorspaceShapeFn.
636 REGISTER_OP("RGBToHSV")
637     .Input("images: T")
638     .Output("output: T")
639     .Attr("T: {half, bfloat16, float, double} = DT_FLOAT")
640     .SetShapeFn(ColorspaceShapeFn);
641 
642 // --------------------------------------------------------------------------
// Registers the "HSVToRGB" op.
//   Input:   images (type T).
//   Output:  output (type T).
//   Attr:    T (half, bfloat16, float or double; default DT_FLOAT).
// Shape inference is delegated to ColorspaceShapeFn.
643 REGISTER_OP("HSVToRGB")
644     .Input("images: T")
645     .Output("output: T")
646     .Attr("T: {half, bfloat16, float, double} = DT_FLOAT")
647     .SetShapeFn(ColorspaceShapeFn);
648 
649 // --------------------------------------------------------------------------
650 REGISTER_OP("DrawBoundingBoxes")
651     .Input("images: T")
652     .Input("boxes: float")
653     .Output("output: T")
654     .Attr("T: {float, half} = DT_FLOAT")
__anon30b5031d0f02(InferenceContext* c) 655     .SetShapeFn([](InferenceContext* c) {
656       // The rank of images should be 4.
657       ShapeHandle images;
658       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &images));
659       // Channel depth should be either 1 (GRY), 3 (RGB), or 4 (RGBA).
660       if (c->ValueKnown(c->Dim(images, 3))) {
661         int64 depth = c->Value(c->Dim(images, 3));
662         if (!(depth == 1 || depth == 3 || depth == 4)) {
663           return errors::InvalidArgument(
664               "Channel depth should be either 1 (GRY), "
665               "3 (RGB), or 4 (RGBA)");
666         }
667       }
668 
669       // The rank of boxes is 3: [batch, num_bounding_boxes, 4].
670       ShapeHandle boxes;
671       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &boxes));
672       // The last value of boxes shape is 4.
673       DimensionHandle unused;
674       TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 2), 4, &unused));
675 
676       // The rank of the input image (rank = 4) has already been restricted
677       // above, and the output is of the same shape as the input.
678       return shape_inference::UnchangedShape(c);
679     });
680 
681 // --------------------------------------------------------------------------
682 REGISTER_OP("DrawBoundingBoxesV2")
683     .Input("images: T")
684     .Input("boxes: float")
685     .Input("colors: float")
686     .Output("output: T")
687     .Attr("T: {float, half} = DT_FLOAT")
__anon30b5031d1002(InferenceContext* c) 688     .SetShapeFn([](InferenceContext* c) {
689       return shape_inference::UnchangedShapeWithRankAtLeast(c, 3);
690     });
691 
692 // --------------------------------------------------------------------------
693 REGISTER_OP("SampleDistortedBoundingBox")
694     .Input("image_size: T")
695     .Input("bounding_boxes: float")
696     .Output("begin: T")
697     .Output("size: T")
698     .Output("bboxes: float")
699     .Attr("T: {uint8, int8, int16, int32, int64}")
700     .Attr("seed: int = 0")
701     .Attr("seed2: int = 0")
702     .Attr("min_object_covered: float = 0.1")
703     .Attr("aspect_ratio_range: list(float) = [0.75, 1.33]")
704     .Attr("area_range: list(float) = [0.05, 1.0]")
705     .Attr("max_attempts: int = 100")
706     .Attr("use_image_if_no_bounding_boxes: bool = false")
707     .SetIsStateful()
__anon30b5031d1102(InferenceContext* c) 708     .SetShapeFn([](InferenceContext* c) {
709       // Get inputs and validate ranks.
710       ShapeHandle image_size;
711       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &image_size));
712       ShapeHandle bounding_boxes;
713       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &bounding_boxes));
714       // image_size: 1-D with [height, width, channels]
715       // bounding_boxes: 3-D with shape [batch, N, 4]
716       DimensionHandle unused;
717       TF_RETURN_IF_ERROR(c->WithValue(c->Dim(image_size, 0), 3, &unused));
718       TF_RETURN_IF_ERROR(c->WithValue(c->Dim(bounding_boxes, 2), 4, &unused));
719 
720       c->set_output(0, c->Vector(3));
721       c->set_output(1, c->Vector(3));
722       c->set_output(2, c->MakeShape({1, 1, 4}));
723       return Status::OK();
724     });
725 
726 REGISTER_OP("SampleDistortedBoundingBoxV2")
727     .Input("image_size: T")
728     .Input("bounding_boxes: float")
729     .Input("min_object_covered: float")
730     .Output("begin: T")
731     .Output("size: T")
732     .Output("bboxes: float")
733     .Attr("T: {uint8, int8, int16, int32, int64}")
734     .Attr("seed: int = 0")
735     .Attr("seed2: int = 0")
736     .Attr("aspect_ratio_range: list(float) = [0.75, 1.33]")
737     .Attr("area_range: list(float) = [0.05, 1.0]")
738     .Attr("max_attempts: int = 100")
739     .Attr("use_image_if_no_bounding_boxes: bool = false")
740     .SetIsStateful()
__anon30b5031d1202(InferenceContext* c) 741     .SetShapeFn([](InferenceContext* c) {
742       // Get inputs and validate ranks.
743       ShapeHandle image_size;
744       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &image_size));
745       ShapeHandle bounding_boxes;
746       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &bounding_boxes));
747       ShapeHandle min_object_covered;
748       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &min_object_covered));
749       // image_size: 1-D with [height, width, channels]
750       // bounding_boxes: 3-D with shape [batch, N, 4]
751       DimensionHandle unused;
752       TF_RETURN_IF_ERROR(c->WithValue(c->Dim(image_size, 0), 3, &unused));
753       TF_RETURN_IF_ERROR(c->WithValue(c->Dim(bounding_boxes, 2), 4, &unused));
754 
755       c->set_output(0, c->Vector(3));
756       c->set_output(1, c->Vector(3));
757       c->set_output(2, c->MakeShape({1, 1, 4}));
758       return Status::OK();
759     });
760 
761 REGISTER_OP("StatelessSampleDistortedBoundingBox")
762     .Input("image_size: T")
763     .Input("bounding_boxes: float")
764     .Input("min_object_covered: float")
765     .Input("seed: Tseed")
766     .Output("begin: T")
767     .Output("size: T")
768     .Output("bboxes: float")
769     .Attr("T: {uint8, int8, int16, int32, int64}")
770     .Attr("Tseed: {int32, int64}")
771     .Attr("aspect_ratio_range: list(float) = [0.75, 1.33]")
772     .Attr("area_range: list(float) = [0.05, 1.0]")
773     .Attr("max_attempts: int = 100")
774     .Attr("use_image_if_no_bounding_boxes: bool = false")
__anon30b5031d1302(InferenceContext* c) 775     .SetShapeFn([](InferenceContext* c) {
776       // Get inputs and validate ranks.
777       ShapeHandle image_size;
778       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &image_size));
779       ShapeHandle bounding_boxes;
780       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &bounding_boxes));
781       ShapeHandle min_object_covered;
782       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &min_object_covered));
783       ShapeHandle seed;
784       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 1, &seed));
785       // image_size: 1-D with [height, width, channels]
786       // bounding_boxes: 3-D with shape [batch, N, 4]
787       DimensionHandle unused;
788       TF_RETURN_IF_ERROR(c->WithValue(c->Dim(image_size, 0), 3, &unused));
789       TF_RETURN_IF_ERROR(c->WithValue(c->Dim(bounding_boxes, 2), 4, &unused));
790       TF_RETURN_IF_ERROR(c->WithValue(c->Dim(seed, 0), 2, &unused));
791 
792       c->set_output(0, c->Vector(3));
793       c->set_output(1, c->Vector(3));
794       c->set_output(2, c->MakeShape({1, 1, 4}));
795 
796       return Status::OK();
797     });
798 
799 // --------------------------------------------------------------------------
800 
801 // glimpse = extract_glimpse(input, size, offsets) extract the glimpse
802 // of size `size` centered at location `offsets` from the input tensor
803 // `input`.
804 //
805 // REQUIRES: input.dims() == 4
806 //
807 REGISTER_OP("ExtractGlimpse")
808     .Input("input: float")
809     .Input("size: int32")
810     .Input("offsets: float")
811     .Output("glimpse: float")
812     .Attr("centered: bool = true")
813     .Attr("normalized: bool = true")
814     .Attr("uniform_noise: bool = true")
815     .Attr("noise: string = 'uniform'")
__anon30b5031d1402(InferenceContext* c) 816     .SetShapeFn([](InferenceContext* c) {
817       ShapeHandle input;
818       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
819       ShapeHandle offsets;
820       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &offsets));
821 
822       DimensionHandle batch_dim;
823       TF_RETURN_IF_ERROR(
824           c->Merge(c->Dim(input, 0), c->Dim(offsets, 0), &batch_dim));
825       DimensionHandle unused;
826       TF_RETURN_IF_ERROR(c->WithValue(c->Dim(offsets, 1), 2, &unused));
827 
828       bool uniform_noise = false;
829       TF_RETURN_IF_ERROR(c->GetAttr("uniform_noise", &uniform_noise));
830       string noise;
831       TF_RETURN_IF_ERROR(c->GetAttr("noise", &noise));
832       if (uniform_noise && (!noise.empty() && noise != "uniform")) {
833         return errors::InvalidArgument(
834             "The uniform_noise and noise should not be specified at the same "
835             "time");
836       }
837 
838       return SetOutputToSizedImage(c, batch_dim, 1 /* size_input_idx */,
839                                    c->Dim(input, 3));
840     });
841 
842 REGISTER_OP("ExtractGlimpseV2")
843     .Input("input: float")
844     .Input("size: int32")
845     .Input("offsets: float")
846     .Output("glimpse: float")
847     .Attr("centered: bool = true")
848     .Attr("normalized: bool = true")
849     .Attr("uniform_noise: bool = true")
850     .Attr("noise: string = 'uniform'")
__anon30b5031d1502(InferenceContext* c) 851     .SetShapeFn([](InferenceContext* c) {
852       ShapeHandle input;
853       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
854       ShapeHandle offsets;
855       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &offsets));
856 
857       DimensionHandle batch_dim;
858       TF_RETURN_IF_ERROR(
859           c->Merge(c->Dim(input, 0), c->Dim(offsets, 0), &batch_dim));
860       DimensionHandle unused;
861       TF_RETURN_IF_ERROR(c->WithValue(c->Dim(offsets, 1), 2, &unused));
862 
863       bool uniform_noise = false;
864       TF_RETURN_IF_ERROR(c->GetAttr("uniform_noise", &uniform_noise));
865       string noise;
866       TF_RETURN_IF_ERROR(c->GetAttr("noise", &noise));
867       if (uniform_noise && (!noise.empty() && noise != "uniform")) {
868         return errors::InvalidArgument(
869             "The uniform_noise and noise should not be specified at the same "
870             "time");
871       }
872 
873       return SetOutputToSizedImage(c, batch_dim, 1 /* size_input_idx */,
874                                    c->Dim(input, 3));
875     });
876 
877 // --------------------------------------------------------------------------
878 
879 REGISTER_OP("CropAndResize")
880     .Input("image: T")
881     .Input("boxes: float")
882     .Input("box_ind: int32")
883     .Input("crop_size: int32")
884     .Output("crops: float")
885     .Attr("T: {uint8, uint16, int8, int16, int32, int64, half, float, double}")
886     .Attr("method: {'bilinear', 'nearest'} = 'bilinear'")
887     .Attr("extrapolation_value: float = 0")
__anon30b5031d1602(InferenceContext* c) 888     .SetShapeFn([](InferenceContext* c) {
889       // Get inputs and validate ranks.
890       ShapeHandle input;
891       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
892       ShapeHandle boxes;
893       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &boxes));
894       ShapeHandle box_ind;
895       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &box_ind));
896 
897       // boxes[0] and box_ind[0] are both num_boxes.
898       DimensionHandle num_boxes_dim;
899       TF_RETURN_IF_ERROR(
900           c->Merge(c->Dim(boxes, 0), c->Dim(box_ind, 0), &num_boxes_dim));
901 
902       // boxes.dim(1) is 4.
903       DimensionHandle unused;
904       TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused));
905 
906       return SetOutputToSizedImage(c, num_boxes_dim, 3 /* size_input_idx */,
907                                    c->Dim(input, 3));
908     });
909 
910 REGISTER_OP("CropAndResizeGradImage")
911     .Input("grads: float")
912     .Input("boxes: float")
913     .Input("box_ind: int32")
914     .Input("image_size: int32")
915     .Output("output: T")
916     .Attr("T: {float, half, double}")
917     .Attr("method: {'bilinear', 'nearest'} = 'bilinear'")
__anon30b5031d1702(InferenceContext* c) 918     .SetShapeFn([](InferenceContext* c) {
919       ShapeHandle out;
920       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(3, &out));
921       TF_RETURN_IF_ERROR(c->WithRank(out, 4, &out));
922       c->set_output(0, out);
923       return Status::OK();
924     });
925 
926 REGISTER_OP("CropAndResizeGradBoxes")
927     .Input("grads: float")
928     .Input("image: T")
929     .Input("boxes: float")
930     .Input("box_ind: int32")
931     .Output("output: float")
932     .Attr("T: {uint8, uint16, int8, int16, int32, int64, half, float, double}")
933     .Attr("method: {'bilinear'} = 'bilinear'")
__anon30b5031d1802(InferenceContext* c) 934     .SetShapeFn([](InferenceContext* c) {
935       c->set_output(0, c->input(2));
936       return Status::OK();
937     });
938 
939 // --------------------------------------------------------------------------
940 
941 REGISTER_OP("NonMaxSuppression")
942     .Input("boxes: float")
943     .Input("scores: float")
944     .Input("max_output_size: int32")
945     .Output("selected_indices: int32")
946     .Attr("iou_threshold: float = 0.5")
__anon30b5031d1902(InferenceContext* c) 947     .SetShapeFn([](InferenceContext* c) {
948       // Get inputs and validate ranks.
949       ShapeHandle boxes;
950       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &boxes));
951       ShapeHandle scores;
952       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores));
953       ShapeHandle max_output_size;
954       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size));
955       // The boxes is a 2-D float Tensor of shape [num_boxes, 4].
956       DimensionHandle unused;
957       // The boxes[0] and scores[0] are both num_boxes.
958       TF_RETURN_IF_ERROR(
959           c->Merge(c->Dim(boxes, 0), c->Dim(scores, 0), &unused));
960       // The boxes[1] is 4.
961       TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused));
962 
963       c->set_output(0, c->Vector(c->UnknownDim()));
964       return Status::OK();
965     });
966 
967 REGISTER_OP("NonMaxSuppressionV2")
968     .Input("boxes: T")
969     .Input("scores: T")
970     .Input("max_output_size: int32")
971     .Input("iou_threshold: T_threshold")
972     .Output("selected_indices: int32")
973     .Attr("T: {half, float} = DT_FLOAT")
974     .Attr("T_threshold: {half, float} = DT_FLOAT")
__anon30b5031d1a02(InferenceContext* c) 975     .SetShapeFn([](InferenceContext* c) {
976       // Get inputs and validate ranks.
977       ShapeHandle boxes;
978       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &boxes));
979       ShapeHandle scores;
980       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores));
981       ShapeHandle max_output_size;
982       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size));
983       ShapeHandle iou_threshold;
984       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &iou_threshold));
985       // The boxes is a 2-D float Tensor of shape [num_boxes, 4].
986       DimensionHandle unused;
987       // The boxes[0] and scores[0] are both num_boxes.
988       TF_RETURN_IF_ERROR(
989           c->Merge(c->Dim(boxes, 0), c->Dim(scores, 0), &unused));
990       // The boxes[1] is 4.
991       TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused));
992 
993       c->set_output(0, c->Vector(c->UnknownDim()));
994       return Status::OK();
995     });
996 
997 REGISTER_OP("NonMaxSuppressionV3")
998     .Input("boxes: T")
999     .Input("scores: T")
1000     .Input("max_output_size: int32")
1001     .Input("iou_threshold: T_threshold")
1002     .Input("score_threshold: T_threshold")
1003     .Output("selected_indices: int32")
1004     .Attr("T: {half, float} = DT_FLOAT")
1005     .Attr("T_threshold: {half, float} = DT_FLOAT")
1006     .SetShapeFn(NMSShapeFn);
1007 
1008 REGISTER_OP("NonMaxSuppressionV4")
1009     .Input("boxes: T")
1010     .Input("scores: T")
1011     .Input("max_output_size: int32")
1012     .Input("iou_threshold: T_threshold")
1013     .Input("score_threshold: T_threshold")
1014     .Output("selected_indices: int32")
1015     .Output("valid_outputs: int32")
1016     .Attr("T: {half, float} = DT_FLOAT")
1017     .Attr("T_threshold: {half, float} = DT_FLOAT")
1018     .Attr("pad_to_max_output_size: bool = false")
__anon30b5031d1b02(InferenceContext* c) 1019     .SetShapeFn([](InferenceContext* c) {
1020       TF_RETURN_IF_ERROR(NMSShapeFn(c));
1021 
1022       bool pad_to_max;
1023       TF_RETURN_IF_ERROR(c->GetAttr("pad_to_max_output_size", &pad_to_max));
1024       if (pad_to_max) {
1025         // If padded, overwrite the shape of the output to be static.
1026         DimensionHandle output_dim;
1027         TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &output_dim));
1028         c->set_output(0, c->MakeShape({output_dim}));
1029       }
1030       c->set_output(1, c->MakeShape({}));
1031       return Status::OK();
1032     });
1033 
1034 REGISTER_OP("NonMaxSuppressionV5")
1035     .Input("boxes: T")
1036     .Input("scores: T")
1037     .Input("max_output_size: int32")
1038     .Input("iou_threshold: T")
1039     .Input("score_threshold: T")
1040     .Input("soft_nms_sigma: T")
1041     .Output("selected_indices: int32")
1042     .Output("selected_scores: T")
1043     .Output("valid_outputs: int32")
1044     .Attr("T: {half, float} = DT_FLOAT")
1045     .Attr("pad_to_max_output_size: bool = false")
__anon30b5031d1c02(InferenceContext* c) 1046     .SetShapeFn([](InferenceContext* c) {
1047       TF_RETURN_IF_ERROR(SoftNMSShapeFn(c));
1048 
1049       bool pad_to_max;
1050       TF_RETURN_IF_ERROR(c->GetAttr("pad_to_max_output_size", &pad_to_max));
1051       if (pad_to_max) {
1052         // If padded, overwrite the shape of the output to be static.
1053         DimensionHandle output_dim;
1054         TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &output_dim));
1055         c->set_output(0, c->MakeShape({output_dim}));
1056         c->set_output(1, c->MakeShape({output_dim}));
1057       }
1058 
1059       c->set_output(2, c->MakeShape({}));
1060       return Status::OK();
1061     });
1062 
1063 REGISTER_OP("NonMaxSuppressionWithOverlaps")
1064     .Input("overlaps: float")
1065     .Input("scores: float")
1066     .Input("max_output_size: int32")
1067     .Input("overlap_threshold: float")
1068     .Input("score_threshold: float")
1069     .Output("selected_indices: int32")
__anon30b5031d1d02(InferenceContext* c) 1070     .SetShapeFn([](InferenceContext* c) {
1071       // Get inputs and validate ranks.
1072       ShapeHandle overlaps;
1073       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &overlaps));
1074       ShapeHandle scores;
1075       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores));
1076       ShapeHandle max_output_size;
1077       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size));
1078       ShapeHandle overlap_threshold;
1079       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &overlap_threshold));
1080       ShapeHandle score_threshold;
1081       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &score_threshold));
1082       // The boxes is a 2-D float Tensor of shape [num_boxes, 4].
1083       DimensionHandle unused;
1084       // The boxes[0] and scores[0] are both num_boxes.
1085       TF_RETURN_IF_ERROR(
1086           c->Merge(c->Dim(overlaps, 0), c->Dim(scores, 0), &unused));
1087       // The boxes[1] is 4.
1088       TF_RETURN_IF_ERROR(
1089           c->Merge(c->Dim(overlaps, 0), c->Dim(overlaps, 1), &unused));
1090 
1091       c->set_output(0, c->Vector(c->UnknownDim()));
1092       return Status::OK();
1093     });
1094 
1095 REGISTER_OP("CombinedNonMaxSuppression")
1096     .Input("boxes: float")
1097     .Input("scores: float")
1098     .Input("max_output_size_per_class: int32")
1099     .Input("max_total_size: int32")
1100     .Input("iou_threshold: float")
1101     .Input("score_threshold: float")
1102     .Output("nmsed_boxes: float")
1103     .Output("nmsed_scores: float")
1104     .Output("nmsed_classes: float")
1105     .Output("valid_detections: int32")
1106     .Attr("pad_per_class: bool = false")
1107     .Attr("clip_boxes: bool = true")
1108     .SetShapeFn(CombinedNMSShapeFn);
1109 
1110 REGISTER_OP("GenerateBoundingBoxProposals")
1111     .Input("scores: float")
1112     .Input("bbox_deltas: float")
1113     .Input("image_info: float")
1114     .Input("anchors: float")
1115     .Input("nms_threshold: float")
1116     .Input("pre_nms_topn: int32")
1117     .Input("min_size: float")
1118     .Output("rois: float")
1119     .Output("roi_probabilities: float")
1120     .Attr("post_nms_topn: int = 300")
__anon30b5031d1e02(InferenceContext* c) 1121     .SetShapeFn([](InferenceContext* c) -> Status {
1122       // make sure input tensors have are correct rank
1123       ShapeHandle scores, images, bounding_boxes, anchors, nms_threshold,
1124           n_pre_nms, min_box_size;
1125       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &scores));  //(N, H, W, A)
1126       TF_RETURN_IF_ERROR(
1127           c->WithRank(c->input(1), 4, &bounding_boxes));         //(N,H,W,A4)
1128       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &images));  // (N,5)
1129       auto im_info = c->Dim(images, 1);
1130       TF_RETURN_IF_ERROR(c->WithValue(im_info, 5, &im_info));
1131       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 3, &anchors));  // (A4)
1132       // check scalar tensors
1133       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &nms_threshold));
1134       TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &n_pre_nms));
1135       TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &min_box_size));
1136 
1137       // TODO(skama): verify that the inputs are compatible
1138       int post_nms_top_n;
1139       TF_RETURN_IF_ERROR(c->GetAttr("post_nms_topn", &post_nms_top_n));
1140       auto roi_shape = c->MakeShape(
1141           {c->Dim(scores, 0), post_nms_top_n, 4});  //(N,post_nms_top_n,4)
1142       auto prob_shape = c->MakeShape(
1143           {c->Dim(scores, 0), post_nms_top_n});  // (N,post_nms_top_n)
1144       c->set_output(0, roi_shape);
1145       c->set_output(1, prob_shape);
1146       return Status::OK();
1147     });
1148 
1149 // V3 op supports fill_value.
1150 // V2 op supports output_shape.
1151 // V1 op is in contrib.
1152 REGISTER_OP("ImageProjectiveTransformV2")
1153     .Input("images: dtype")
1154     .Input("transforms: float32")
1155     .Input("output_shape: int32")
1156     .Attr("dtype: {uint8, int32, int64, float16, float32, float64}")
1157     .Attr("interpolation: string")
1158     .Attr("fill_mode: string = 'CONSTANT'")
1159     .Output("transformed_images: dtype")
__anon30b5031d1f02(InferenceContext* c) 1160     .SetShapeFn([](InferenceContext* c) {
1161       ShapeHandle input;
1162       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
1163       return SetOutputToSizedImage(c, c->Dim(input, 0), 2 /* size_input_idx */,
1164                                    c->Dim(input, 3));
1165     });
1166 
1167 REGISTER_OP("ImageProjectiveTransformV3")
1168     .Input("images: dtype")
1169     .Input("transforms: float32")
1170     .Input("output_shape: int32")
1171     .Input("fill_value: float32")
1172     .Attr("dtype: {uint8, int32, int64, float16, float32, float64}")
1173     .Attr("interpolation: string")
1174     .Attr("fill_mode: string = 'CONSTANT'")
1175     .Output("transformed_images: dtype")
__anon30b5031d2002(InferenceContext* c) 1176     .SetShapeFn([](InferenceContext* c) {
1177       ShapeHandle input;
1178       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
1179       return SetOutputToSizedImage(c, c->Dim(input, 0), 2 /* size_input_idx */,
1180                                    c->Dim(input, 3));
1181     });
1182 
1183 }  // namespace tensorflow
1184