1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/common_shape_fns.h"
17 #include "tensorflow/core/framework/op.h"
18 #include "tensorflow/core/framework/shape_inference.h"
19
20 namespace tensorflow {
21
22 using shape_inference::DimensionHandle;
23 using shape_inference::InferenceContext;
24 using shape_inference::ShapeHandle;
25
26 namespace {
27
28 // Sets output[0] to shape [batch_dim,height,width,channel_dim], where
29 // height and width come from the size_tensor.
SetOutputToSizedImage(InferenceContext * c,DimensionHandle batch_dim,int size_input_idx,DimensionHandle channel_dim)30 Status SetOutputToSizedImage(InferenceContext* c, DimensionHandle batch_dim,
31 int size_input_idx, DimensionHandle channel_dim) {
32 // Verify shape of size input.
33 ShapeHandle size;
34 TF_RETURN_IF_ERROR(c->WithRank(c->input(size_input_idx), 1, &size));
35 DimensionHandle unused;
36 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(size, 0), 2, &unused));
37
38 // Get size values from the size tensor.
39 const Tensor* size_tensor = c->input_tensor(size_input_idx);
40 DimensionHandle width;
41 DimensionHandle height;
42 if (size_tensor == nullptr) {
43 width = c->UnknownDim();
44 height = c->UnknownDim();
45 } else {
46 // TODO(petewarden) - Remove once we have constant evaluation in C++ only.
47 if (size_tensor->dtype() != DT_INT32) {
48 return errors::InvalidArgument(
49 "Bad size input type for SetOutputToSizedImage: Expected DT_INT32 "
50 "but got ",
51 DataTypeString(size_tensor->dtype()), " for input #", size_input_idx,
52 " in ", c->DebugString());
53 }
54 auto vec = size_tensor->vec<int32>();
55 height = c->MakeDim(vec(0));
56 width = c->MakeDim(vec(1));
57 }
58 c->set_output(0, c->MakeShape({batch_dim, height, width, channel_dim}));
59 return Status::OK();
60 }
61
ResizeShapeFn(InferenceContext * c)62 Status ResizeShapeFn(InferenceContext* c) {
63 ShapeHandle input;
64 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
65 return SetOutputToSizedImage(c, c->Dim(input, 0), 1 /* size_input_idx */,
66 c->Dim(input, 3));
67 }
68
DecodeImageShapeFn(InferenceContext * c)69 Status DecodeImageShapeFn(InferenceContext* c) {
70 ShapeHandle unused;
71 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
72 DimensionHandle channels_dim;
73 int32 channels;
74 TF_RETURN_IF_ERROR(c->GetAttr("channels", &channels));
75 if (channels == 0) {
76 channels_dim = c->UnknownDim();
77 } else {
78 if (channels < 0) {
79 return errors::InvalidArgument("channels must be non-negative, got ",
80 channels);
81 }
82 channels_dim = c->MakeDim(channels);
83 }
84
85 c->set_output(0, c->MakeShape({InferenceContext::kUnknownDim,
86 InferenceContext::kUnknownDim, channels_dim}));
87 return Status::OK();
88 }
89
EncodeImageShapeFn(InferenceContext * c)90 Status EncodeImageShapeFn(InferenceContext* c) {
91 ShapeHandle unused;
92 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &unused));
93 c->set_output(0, c->Scalar());
94 return Status::OK();
95 }
96
ColorspaceShapeFn(InferenceContext * c)97 Status ColorspaceShapeFn(InferenceContext* c) {
98 ShapeHandle input;
99 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input));
100
101 // The last dimension value is always 3.
102 DimensionHandle last_dim;
103 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(input, -1), 3, &last_dim));
104 ShapeHandle out;
105 TF_RETURN_IF_ERROR(c->ReplaceDim(input, -1, last_dim, &out));
106 c->set_output(0, out);
107
108 return Status::OK();
109 }
110
NMSShapeFn(InferenceContext * c)111 Status NMSShapeFn(InferenceContext* c) {
112 // Get inputs and validate ranks.
113 ShapeHandle boxes;
114 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &boxes));
115 ShapeHandle scores;
116 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores));
117 ShapeHandle max_output_size;
118 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size));
119 ShapeHandle iou_threshold;
120 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &iou_threshold));
121 ShapeHandle score_threshold;
122 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &score_threshold));
123 // The boxes is a 2-D float Tensor of shape [num_boxes, 4].
124 DimensionHandle unused;
125 // The boxes[0] and scores[0] are both num_boxes.
126 TF_RETURN_IF_ERROR(c->Merge(c->Dim(boxes, 0), c->Dim(scores, 0), &unused));
127 // The boxes[1] is 4.
128 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused));
129
130 c->set_output(0, c->Vector(c->UnknownDim()));
131 return Status::OK();
132 }
133
CombinedNMSShapeFn(InferenceContext * c)134 Status CombinedNMSShapeFn(InferenceContext* c) {
135 // Get inputs and validate ranks
136 ShapeHandle boxes;
137 // boxes is a tensor of Dimensions [batch_size, num_anchors, q, 4]
138 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &boxes));
139 ShapeHandle scores;
140 // scores is a tensor of Dimensions [batch_size, num_anchors, num_classes]
141 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &scores));
142 ShapeHandle max_output_size_per_class;
143 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size_per_class));
144 ShapeHandle max_total_size;
145 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &max_total_size));
146 ShapeHandle unused_shape;
147 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused_shape));
148 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused_shape));
149
150 DimensionHandle unused;
151 // boxes[0] and scores[0] are both batch_size
152 TF_RETURN_IF_ERROR(c->Merge(c->Dim(boxes, 0), c->Dim(scores, 0), &unused));
153 // boxes[1] and scores[1] are both num_anchors
154 TF_RETURN_IF_ERROR(c->Merge(c->Dim(boxes, 1), c->Dim(scores, 1), &unused));
155 // The boxes[3] is 4.
156 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 3), 4, &unused));
157
158 DimensionHandle d = c->Dim(boxes, 2);
159 DimensionHandle class_dim = c->Dim(scores, 2);
160 if (c->ValueKnown(d) && c->ValueKnown(class_dim)) {
161 if (c->Value(d) != 1 && c->Value(d) != c->Value(class_dim)) {
162 return errors::InvalidArgument(
163 "third dimension of boxes must be either "
164 "1 or equal to the third dimension of scores");
165 }
166 }
167 DimensionHandle output_dim;
168 DimensionHandle batch_dim = c->Dim(boxes, 0);
169
170 TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(3, &output_dim));
171 if (c->ValueKnown(output_dim) && c->Value(output_dim) <= 0) {
172 return errors::InvalidArgument("max_total_size should be > 0 ");
173 }
174 DimensionHandle size_per_class;
175 TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &size_per_class));
176
177 int64 output_size;
178 bool pad_per_class;
179 TF_RETURN_IF_ERROR(c->GetAttr("pad_per_class", &pad_per_class));
180 if (!pad_per_class) {
181 output_size = c->Value(output_dim);
182 } else {
183 if (c->ValueKnown(size_per_class) && c->Value(size_per_class) <= 0) {
184 return errors::InvalidArgument(
185 "max_output_size_per_class must be > 0 "
186 "if pad_per_class is set to true ");
187 }
188 output_size = std::min(c->Value(output_dim),
189 c->Value(size_per_class) * c->Value(class_dim));
190 }
191 c->set_output(0, c->MakeShape({batch_dim, output_size, 4}));
192 c->set_output(1, c->MakeShape({batch_dim, output_size}));
193 c->set_output(2, c->MakeShape({batch_dim, output_size}));
194 c->set_output(3, c->Vector(batch_dim));
195 return Status::OK();
196 }
197
198 } // namespace
199
200 // --------------------------------------------------------------------------
201 REGISTER_OP("ResizeArea")
202 .Input("images: T")
203 .Input("size: int32")
204 .Output("resized_images: float")
205 .Attr("T: {int8, uint8, int16, uint16, int32, int64, half, float, double}")
206 .Attr("align_corners: bool = false")
207 .SetShapeFn(ResizeShapeFn);
208
209 // --------------------------------------------------------------------------
210 REGISTER_OP("ResizeBicubic")
211 .Input("images: T")
212 .Input("size: int32")
213 .Output("resized_images: float")
214 .Attr("T: {int8, uint8, int16, uint16, int32, int64, half, float, double}")
215 .Attr("align_corners: bool = false")
216 .Attr("half_pixel_centers: bool = false")
217 .SetShapeFn(ResizeShapeFn);
218
219 // --------------------------------------------------------------------------
220 REGISTER_OP("ResizeBicubicGrad")
221 .Input("grads: float")
222 .Input("original_image: T")
223 .Output("output: T")
224 .Attr("T: {float, double}")
225 .Attr("align_corners: bool = false")
226 .Attr("half_pixel_centers: bool = false")
__anon6a71d27f0202(InferenceContext* c) 227 .SetShapeFn([](InferenceContext* c) {
228 c->set_output(0, c->input(1));
229 return Status::OK();
230 });
231
232 // --------------------------------------------------------------------------
233 REGISTER_OP("ResizeBilinear")
234 .Input("images: T")
235 .Input("size: int32")
236 .Output("resized_images: float")
237 .Attr(
238 "T: {int8, uint8, int16, uint16, int32, int64, bfloat16, half, "
239 "float, double}")
240 .Attr("align_corners: bool = false")
241 .Attr("half_pixel_centers: bool = false")
242 .SetShapeFn(ResizeShapeFn);
243
244 // --------------------------------------------------------------------------
245 REGISTER_OP("ScaleAndTranslate")
246 .Input("images: T")
247 .Input("size: int32")
248 .Input("scale: float")
249 .Input("translation: float")
250 .Output("resized_images: float")
251 .Attr(
252 "T: {int8, uint8, int16, uint16, int32, int64, bfloat16, half, "
253 "float, double}")
254 .Attr("kernel_type: string = 'lanczos3'")
255 .Attr("antialias: bool = true")
256 .SetShapeFn(ResizeShapeFn);
257
258 // --------------------------------------------------------------------------
259 REGISTER_OP("QuantizedResizeBilinear")
260 .Input("images: T")
261 .Input("size: int32")
262 .Input("min: float")
263 .Input("max: float")
264 .Output("resized_images: T")
265 .Output("out_min: float")
266 .Output("out_max: float")
267 .Attr("T: {quint8, qint32, float}")
268 .Attr("align_corners: bool = false")
269 .Attr("half_pixel_centers: bool = false")
__anon6a71d27f0302(InferenceContext* c) 270 .SetShapeFn([](InferenceContext* c) {
271 TF_RETURN_IF_ERROR(ResizeShapeFn(c));
272 ShapeHandle min_shape;
273 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &min_shape));
274 ShapeHandle max_shape;
275 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &max_shape));
276 c->set_output(1, c->MakeShape({}));
277 c->set_output(2, c->MakeShape({}));
278 return Status::OK();
279 });
280
281 // --------------------------------------------------------------------------
282 REGISTER_OP("ResizeBilinearGrad")
283 .Input("grads: float")
284 .Input("original_image: T")
285 .Output("output: T")
286 .Attr("T: {float, bfloat16, half, double}")
287 .Attr("align_corners: bool = false")
288 .Attr("half_pixel_centers: bool = false")
__anon6a71d27f0402(InferenceContext* c) 289 .SetShapeFn([](InferenceContext* c) {
290 c->set_output(0, c->input(1));
291 return Status::OK();
292 });
293
294 // --------------------------------------------------------------------------
295 REGISTER_OP("ScaleAndTranslateGrad")
296 .Input("grads: T")
297 .Input("original_image: T")
298 .Input("scale: float")
299 .Input("translation: float")
300 .Output("output: T")
301 .Attr("T: {float}")
302 .Attr("kernel_type: string = 'lanczos3'")
303 .Attr("antialias: bool = true")
__anon6a71d27f0502(InferenceContext* c) 304 .SetShapeFn([](InferenceContext* c) {
305 c->set_output(0, c->input(1));
306 return Status::OK();
307 });
308
309 // --------------------------------------------------------------------------
310 REGISTER_OP("ResizeNearestNeighbor")
311 .Input("images: T")
312 .Input("size: int32")
313 .Output("resized_images: T")
314 .Attr("T: {int8, uint8, int16, uint16, int32, int64, half, float, double}")
315 .Attr("align_corners: bool = false")
316 .Attr("half_pixel_centers: bool = false")
317 .SetShapeFn(ResizeShapeFn);
318
319 // --------------------------------------------------------------------------
320 REGISTER_OP("ResizeNearestNeighborGrad")
321 .Input("grads: T")
322 .Input("size: int32")
323 .Output("output: T")
324 .Attr("T: {uint8, int8, int32, half, float, double}")
325 .Attr("align_corners: bool = false")
326 .Attr("half_pixel_centers: bool = false")
__anon6a71d27f0602(InferenceContext* c) 327 .SetShapeFn([](InferenceContext* c) {
328 ShapeHandle input;
329 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
330 ShapeHandle unused;
331 DimensionHandle unused_dim;
332 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
333 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(unused, 0), 2, &unused_dim));
334 const Tensor* size = c->input_tensor(1);
335 if (size == nullptr) {
336 TF_RETURN_IF_ERROR(c->ReplaceDim(input, 1, c->UnknownDim(), &input));
337 TF_RETURN_IF_ERROR(c->ReplaceDim(input, 2, c->UnknownDim(), &input));
338 } else {
339 auto size_vec = size->vec<int32>();
340 TF_RETURN_IF_ERROR(
341 c->ReplaceDim(input, 1, c->MakeDim(size_vec(0)), &input));
342 TF_RETURN_IF_ERROR(
343 c->ReplaceDim(input, 2, c->MakeDim(size_vec(1)), &input));
344 }
345 c->set_output(0, input);
346 return Status::OK();
347 });
348
349 // --------------------------------------------------------------------------
350 REGISTER_OP("RandomCrop")
351 .Input("image: T")
352 .Input("size: int64")
353 .Output("output: T")
354 .Attr("T: {uint8, int8, int16, int32, int64, float, double}")
355 .Attr("seed: int = 0")
356 .Attr("seed2: int = 0")
357 .SetIsStateful()
358 .Deprecated(8, "Random crop is now pure Python")
__anon6a71d27f0702(InferenceContext* c) 359 .SetShapeFn([](InferenceContext* c) {
360 ShapeHandle image;
361 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 3, &image));
362 DimensionHandle channels = c->Dim(image, -1);
363
364 ShapeHandle unused;
365 TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->Vector(2), &unused));
366
367 const Tensor* size = c->input_tensor(1);
368 DimensionHandle h;
369 DimensionHandle w;
370 if (size == nullptr) {
371 h = c->UnknownDim();
372 w = c->UnknownDim();
373 } else {
374 auto size_vec = size->vec<int64>();
375 h = c->MakeDim(size_vec(0));
376 w = c->MakeDim(size_vec(1));
377 }
378 c->set_output(0, c->MakeShape({h, w, channels}));
379 return Status::OK();
380 });
381 // TODO(shlens): Support variable rank in RandomCrop.
382
383 // --------------------------------------------------------------------------
384 REGISTER_OP("DecodeJpeg")
385 .Input("contents: string")
386 .Attr("channels: int = 0")
387 .Attr("ratio: int = 1")
388 .Attr("fancy_upscaling: bool = true")
389 .Attr("try_recover_truncated: bool = false")
390 .Attr("acceptable_fraction: float = 1.0")
391 .Attr("dct_method: string = ''")
392 .Output("image: uint8")
393 .SetShapeFn(DecodeImageShapeFn);
394
395 // --------------------------------------------------------------------------
396 REGISTER_OP("DecodeAndCropJpeg")
397 .Input("contents: string")
398 .Input("crop_window: int32")
399 .Attr("channels: int = 0")
400 .Attr("ratio: int = 1")
401 .Attr("fancy_upscaling: bool = true")
402 .Attr("try_recover_truncated: bool = false")
403 .Attr("acceptable_fraction: float = 1.0")
404 .Attr("dct_method: string = ''")
405 .Output("image: uint8")
__anon6a71d27f0802(InferenceContext* c) 406 .SetShapeFn([](InferenceContext* c) {
407 ShapeHandle unused;
408 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
409 DimensionHandle channels_dim = c->UnknownDim();
410 DimensionHandle h = c->UnknownDim();
411 DimensionHandle w = c->UnknownDim();
412
413 int32 channels;
414 TF_RETURN_IF_ERROR(c->GetAttr("channels", &channels));
415 if (channels != 0) {
416 if (channels < 0) {
417 return errors::InvalidArgument("channels must be non-negative, got ",
418 channels);
419 }
420 channels_dim = c->MakeDim(channels);
421 }
422
423 DimensionHandle unused_dim;
424 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &unused));
425 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(unused, 0), 4, &unused_dim));
426
427 const Tensor* crop_window = c->input_tensor(1);
428 if (crop_window != nullptr) {
429 auto crop_window_vec = crop_window->vec<int32>();
430 h = c->MakeDim(crop_window_vec(2));
431 w = c->MakeDim(crop_window_vec(3));
432 }
433 c->set_output(0, c->MakeShape({h, w, channels_dim}));
434 return Status::OK();
435 });
436
437 // --------------------------------------------------------------------------
438 REGISTER_OP("EncodeJpeg")
439 .Input("image: uint8")
440 .Attr("format: {'', 'grayscale', 'rgb'} = ''")
441 .Attr("quality: int = 95")
442 .Attr("progressive: bool = false")
443 .Attr("optimize_size: bool = false")
444 .Attr("chroma_downsampling: bool = true")
445 .Attr("density_unit: {'in', 'cm'} = 'in'")
446 .Attr("x_density: int = 300")
447 .Attr("y_density: int = 300")
448 .Attr("xmp_metadata: string = ''")
449 .Output("contents: string")
450 .SetShapeFn(EncodeImageShapeFn);
451
452 // --------------------------------------------------------------------------
453 REGISTER_OP("ExtractJpegShape")
454 .Input("contents: string")
455 .Output("image_shape: output_type")
456 .Attr("output_type: {int32, int64} = DT_INT32")
__anon6a71d27f0902(InferenceContext* c) 457 .SetShapeFn([](InferenceContext* c) {
458 ShapeHandle unused;
459 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
460 c->set_output(0, c->Vector(3));
461 return Status::OK();
462 });
463
464 // --------------------------------------------------------------------------
465 REGISTER_OP("AdjustContrast")
466 .Input("images: T")
467 .Input("contrast_factor: float")
468 .Input("min_value: float")
469 .Input("max_value: float")
470 .Output("output: float")
471 .Attr("T: {uint8, int8, int16, int32, int64, float, double}")
472 .Deprecated(2, "Use AdjustContrastv2 instead")
__anon6a71d27f0a02(InferenceContext* c) 473 .SetShapeFn([](InferenceContext* c) {
474 // The contrast_factor, min_value, max_value should be scalar only.
475 ShapeHandle unused;
476 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
477 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
478 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
479 return shape_inference::UnchangedShapeWithRankAtLeast(c, 3);
480 });
481
482 // --------------------------------------------------------------------------
483 REGISTER_OP("AdjustContrastv2")
484 .Input("images: T")
485 .Input("contrast_factor: float")
486 .Output("output: T")
487 .Attr("T: {half, float} = DT_FLOAT")
__anon6a71d27f0b02(InferenceContext* c) 488 .SetShapeFn([](InferenceContext* c) {
489 // The contrast_factor should be scalar only.
490 ShapeHandle unused;
491 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
492 return shape_inference::UnchangedShapeWithRankAtLeast(c, 3);
493 });
494
495 // --------------------------------------------------------------------------
496 REGISTER_OP("AdjustHue")
497 .Input("images: T")
498 .Input("delta: float")
499 .Output("output: T")
500 .Attr("T: {half, float} = DT_FLOAT")
__anon6a71d27f0c02(InferenceContext* c) 501 .SetShapeFn([](InferenceContext* c) {
502 return shape_inference::UnchangedShapeWithRankAtLeast(c, 3);
503 });
504
505 // --------------------------------------------------------------------------
506 REGISTER_OP("AdjustSaturation")
507 .Input("images: T")
508 .Input("scale: float")
509 .Output("output: T")
510 .Attr("T: {half, float} = DT_FLOAT")
__anon6a71d27f0d02(InferenceContext* c) 511 .SetShapeFn([](InferenceContext* c) {
512 return shape_inference::UnchangedShapeWithRankAtLeast(c, 3);
513 });
514
515 // --------------------------------------------------------------------------
516 REGISTER_OP("DecodePng")
517 .Input("contents: string")
518 .Attr("channels: int = 0")
519 .Attr("dtype: {uint8, uint16} = DT_UINT8")
520 .Output("image: dtype")
521 .SetShapeFn(DecodeImageShapeFn);
522
523 // --------------------------------------------------------------------------
524 REGISTER_OP("EncodePng")
525 .Attr("compression: int = -1")
526 .Attr("T: {uint8, uint16} = DT_UINT8")
527 .Input("image: T")
528 .Output("contents: string")
529 .SetShapeFn(EncodeImageShapeFn);
530
531 // --------------------------------------------------------------------------
532 REGISTER_OP("DecodeBmp")
533 .Input("contents: string")
534 .Output("image: uint8")
535 .Attr("channels: int = 0")
536 .SetShapeFn(DecodeImageShapeFn);
537
538 // --------------------------------------------------------------------------
539 REGISTER_OP("DecodeGif")
540 .Input("contents: string")
541 .Output("image: uint8")
__anon6a71d27f0e02(InferenceContext* c) 542 .SetShapeFn([](InferenceContext* c) {
543 ShapeHandle unused;
544 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 0, &unused));
545 c->set_output(0, c->MakeShape({InferenceContext::kUnknownDim,
546 InferenceContext::kUnknownDim,
547 InferenceContext::kUnknownDim, 3}));
548 return Status::OK();
549 });
550
551 // --------------------------------------------------------------------------
552 REGISTER_OP("RGBToHSV")
553 .Input("images: T")
554 .Output("output: T")
555 .Attr("T: {half, bfloat16, float, double} = DT_FLOAT")
556 .SetShapeFn(ColorspaceShapeFn);
557
558 // --------------------------------------------------------------------------
559 REGISTER_OP("HSVToRGB")
560 .Input("images: T")
561 .Output("output: T")
562 .Attr("T: {half, bfloat16, float, double} = DT_FLOAT")
563 .SetShapeFn(ColorspaceShapeFn);
564
565 // --------------------------------------------------------------------------
566 REGISTER_OP("DrawBoundingBoxes")
567 .Input("images: T")
568 .Input("boxes: float")
569 .Output("output: T")
570 .Attr("T: {float, half} = DT_FLOAT")
__anon6a71d27f0f02(InferenceContext* c) 571 .SetShapeFn([](InferenceContext* c) {
572 // The rank of images should be 4.
573 ShapeHandle images;
574 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &images));
575 // Channel depth should be either 1 (GRY), 3 (RGB), or 4 (RGBA).
576 if (c->ValueKnown(c->Dim(images, 3))) {
577 int64 depth = c->Value(c->Dim(images, 3));
578 if (!(depth == 1 || depth == 3 || depth == 4)) {
579 return errors::InvalidArgument(
580 "Channel depth should be either 1 (GRY), "
581 "3 (RGB), or 4 (RGBA)");
582 }
583 }
584
585 // The rank of boxes is 3: [batch, num_bounding_boxes, 4].
586 ShapeHandle boxes;
587 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &boxes));
588 // The last value of boxes shape is 4.
589 DimensionHandle unused;
590 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 2), 4, &unused));
591
592 // The rank of the input image (rank = 4) has already been restricted
593 // above, and the output is of the same shape as the input.
594 return shape_inference::UnchangedShape(c);
595 });
596
597 // --------------------------------------------------------------------------
598 REGISTER_OP("SampleDistortedBoundingBox")
599 .Input("image_size: T")
600 .Input("bounding_boxes: float")
601 .Output("begin: T")
602 .Output("size: T")
603 .Output("bboxes: float")
604 .Attr("T: {uint8, int8, int16, int32, int64}")
605 .Attr("seed: int = 0")
606 .Attr("seed2: int = 0")
607 .Attr("min_object_covered: float = 0.1")
608 .Attr("aspect_ratio_range: list(float) = [0.75, 1.33]")
609 .Attr("area_range: list(float) = [0.05, 1.0]")
610 .Attr("max_attempts: int = 100")
611 .Attr("use_image_if_no_bounding_boxes: bool = false")
612 .SetIsStateful()
__anon6a71d27f1002(InferenceContext* c) 613 .SetShapeFn([](InferenceContext* c) {
614 // Get inputs and validate ranks.
615 ShapeHandle image_size;
616 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &image_size));
617 ShapeHandle bounding_boxes;
618 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &bounding_boxes));
619 // image_size: 1-D with [height, width, channels]
620 // bounding_boxes: 3-D with shape [batch, N, 4]
621 DimensionHandle unused;
622 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(image_size, 0), 3, &unused));
623 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(bounding_boxes, 2), 4, &unused));
624
625 c->set_output(0, c->Vector(3));
626 c->set_output(1, c->Vector(3));
627 c->set_output(2, c->MakeShape({1, 1, 4}));
628 return Status::OK();
629 });
630
631 REGISTER_OP("SampleDistortedBoundingBoxV2")
632 .Input("image_size: T")
633 .Input("bounding_boxes: float")
634 .Input("min_object_covered: float")
635 .Output("begin: T")
636 .Output("size: T")
637 .Output("bboxes: float")
638 .Attr("T: {uint8, int8, int16, int32, int64}")
639 .Attr("seed: int = 0")
640 .Attr("seed2: int = 0")
641 .Attr("aspect_ratio_range: list(float) = [0.75, 1.33]")
642 .Attr("area_range: list(float) = [0.05, 1.0]")
643 .Attr("max_attempts: int = 100")
644 .Attr("use_image_if_no_bounding_boxes: bool = false")
645 .SetIsStateful()
__anon6a71d27f1102(InferenceContext* c) 646 .SetShapeFn([](InferenceContext* c) {
647 // Get inputs and validate ranks.
648 ShapeHandle image_size;
649 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 1, &image_size));
650 ShapeHandle bounding_boxes;
651 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &bounding_boxes));
652 ShapeHandle min_object_covered;
653 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &min_object_covered));
654 // image_size: 1-D with [height, width, channels]
655 // bounding_boxes: 3-D with shape [batch, N, 4]
656 DimensionHandle unused;
657 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(image_size, 0), 3, &unused));
658 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(bounding_boxes, 2), 4, &unused));
659
660 c->set_output(0, c->Vector(3));
661 c->set_output(1, c->Vector(3));
662 c->set_output(2, c->MakeShape({1, 1, 4}));
663 return Status::OK();
664 });
665
666 // --------------------------------------------------------------------------
667
668 // glimpse = extract_glimpse(input, size, offsets) extract the glimpse
669 // of size `size` centered at location `offsets` from the input tensor
670 // `input`.
671 //
672 // REQUIRES: input.dims() == 4
673 //
674 REGISTER_OP("ExtractGlimpse")
675 .Input("input: float")
676 .Input("size: int32")
677 .Input("offsets: float")
678 .Output("glimpse: float")
679 .Attr("centered: bool = true")
680 .Attr("normalized: bool = true")
681 .Attr("uniform_noise: bool = true")
682 .Attr("noise: string = 'uniform'")
__anon6a71d27f1202(InferenceContext* c) 683 .SetShapeFn([](InferenceContext* c) {
684 ShapeHandle input;
685 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
686 ShapeHandle offsets;
687 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 2, &offsets));
688
689 DimensionHandle batch_dim;
690 TF_RETURN_IF_ERROR(
691 c->Merge(c->Dim(input, 0), c->Dim(offsets, 0), &batch_dim));
692 DimensionHandle unused;
693 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(offsets, 1), 2, &unused));
694
695 bool uniform_noise = false;
696 TF_RETURN_IF_ERROR(c->GetAttr("uniform_noise", &uniform_noise));
697 string noise;
698 TF_RETURN_IF_ERROR(c->GetAttr("noise", &noise));
699 if (uniform_noise && (!noise.empty() && noise != "uniform")) {
700 return errors::InvalidArgument(
701 "The uniform_noise and noise should not be specified at the same "
702 "time");
703 }
704
705 return SetOutputToSizedImage(c, batch_dim, 1 /* size_input_idx */,
706 c->Dim(input, 3));
707 });
708
709 // --------------------------------------------------------------------------
710
711 REGISTER_OP("CropAndResize")
712 .Input("image: T")
713 .Input("boxes: float")
714 .Input("box_ind: int32")
715 .Input("crop_size: int32")
716 .Output("crops: float")
717 .Attr("T: {uint8, uint16, int8, int16, int32, int64, half, float, double}")
718 .Attr("method: {'bilinear', 'nearest'} = 'bilinear'")
719 .Attr("extrapolation_value: float = 0")
__anon6a71d27f1302(InferenceContext* c) 720 .SetShapeFn([](InferenceContext* c) {
721 // Get inputs and validate ranks.
722 ShapeHandle input;
723 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
724 ShapeHandle boxes;
725 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 2, &boxes));
726 ShapeHandle box_ind;
727 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &box_ind));
728
729 // boxes[0] and box_ind[0] are both num_boxes.
730 DimensionHandle num_boxes_dim;
731 TF_RETURN_IF_ERROR(
732 c->Merge(c->Dim(boxes, 0), c->Dim(box_ind, 0), &num_boxes_dim));
733
734 // boxes.dim(1) is 4.
735 DimensionHandle unused;
736 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused));
737
738 return SetOutputToSizedImage(c, num_boxes_dim, 3 /* size_input_idx */,
739 c->Dim(input, 3));
740 });
741
742 REGISTER_OP("CropAndResizeGradImage")
743 .Input("grads: float")
744 .Input("boxes: float")
745 .Input("box_ind: int32")
746 .Input("image_size: int32")
747 .Output("output: T")
748 .Attr("T: {float, half, double}")
749 .Attr("method: {'bilinear', 'nearest'} = 'bilinear'")
__anon6a71d27f1402(InferenceContext* c) 750 .SetShapeFn([](InferenceContext* c) {
751 ShapeHandle out;
752 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(3, &out));
753 TF_RETURN_IF_ERROR(c->WithRank(out, 4, &out));
754 c->set_output(0, out);
755 return Status::OK();
756 });
757
758 REGISTER_OP("CropAndResizeGradBoxes")
759 .Input("grads: float")
760 .Input("image: T")
761 .Input("boxes: float")
762 .Input("box_ind: int32")
763 .Output("output: float")
764 .Attr("T: {uint8, uint16, int8, int16, int32, int64, half, float, double}")
765 .Attr("method: {'bilinear'} = 'bilinear'")
__anon6a71d27f1502(InferenceContext* c) 766 .SetShapeFn([](InferenceContext* c) {
767 c->set_output(0, c->input(2));
768 return Status::OK();
769 });
770
771 // --------------------------------------------------------------------------
772
773 REGISTER_OP("NonMaxSuppression")
774 .Input("boxes: float")
775 .Input("scores: float")
776 .Input("max_output_size: int32")
777 .Output("selected_indices: int32")
778 .Attr("iou_threshold: float = 0.5")
__anon6a71d27f1602(InferenceContext* c) 779 .SetShapeFn([](InferenceContext* c) {
780 // Get inputs and validate ranks.
781 ShapeHandle boxes;
782 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &boxes));
783 ShapeHandle scores;
784 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores));
785 ShapeHandle max_output_size;
786 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size));
787 // The boxes is a 2-D float Tensor of shape [num_boxes, 4].
788 DimensionHandle unused;
789 // The boxes[0] and scores[0] are both num_boxes.
790 TF_RETURN_IF_ERROR(
791 c->Merge(c->Dim(boxes, 0), c->Dim(scores, 0), &unused));
792 // The boxes[1] is 4.
793 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused));
794
795 c->set_output(0, c->Vector(c->UnknownDim()));
796 return Status::OK();
797 });
798
799 REGISTER_OP("NonMaxSuppressionV2")
800 .Input("boxes: T")
801 .Input("scores: T")
802 .Input("max_output_size: int32")
803 .Input("iou_threshold: float")
804 .Output("selected_indices: int32")
805 .Attr("T: {half, float} = DT_FLOAT")
__anon6a71d27f1702(InferenceContext* c) 806 .SetShapeFn([](InferenceContext* c) {
807 // Get inputs and validate ranks.
808 ShapeHandle boxes;
809 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &boxes));
810 ShapeHandle scores;
811 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores));
812 ShapeHandle max_output_size;
813 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size));
814 ShapeHandle iou_threshold;
815 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &iou_threshold));
816 // The boxes is a 2-D float Tensor of shape [num_boxes, 4].
817 DimensionHandle unused;
818 // The boxes[0] and scores[0] are both num_boxes.
819 TF_RETURN_IF_ERROR(
820 c->Merge(c->Dim(boxes, 0), c->Dim(scores, 0), &unused));
821 // The boxes[1] is 4.
822 TF_RETURN_IF_ERROR(c->WithValue(c->Dim(boxes, 1), 4, &unused));
823
824 c->set_output(0, c->Vector(c->UnknownDim()));
825 return Status::OK();
826 });
827
828 REGISTER_OP("NonMaxSuppressionV3")
829 .Input("boxes: T")
830 .Input("scores: T")
831 .Input("max_output_size: int32")
832 .Input("iou_threshold: float")
833 .Input("score_threshold: float")
834 .Output("selected_indices: int32")
835 .Attr("T: {half, float} = DT_FLOAT")
836 .SetShapeFn(NMSShapeFn);
837
838 REGISTER_OP("NonMaxSuppressionV4")
839 .Input("boxes: T")
840 .Input("scores: T")
841 .Input("max_output_size: int32")
842 .Input("iou_threshold: float")
843 .Input("score_threshold: float")
844 .Output("selected_indices: int32")
845 .Output("valid_outputs: int32")
846 .Attr("T: {half, float} = DT_FLOAT")
847 .Attr("pad_to_max_output_size: bool = false")
__anon6a71d27f1802(InferenceContext* c) 848 .SetShapeFn([](InferenceContext* c) {
849 TF_RETURN_IF_ERROR(NMSShapeFn(c));
850
851 bool pad_to_max;
852 TF_RETURN_IF_ERROR(c->GetAttr("pad_to_max_output_size", &pad_to_max));
853 if (pad_to_max) {
854 // If padded, overwrite the shape of the output to be static.
855 DimensionHandle output_dim;
856 TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(2, &output_dim));
857 c->set_output(0, c->MakeShape({output_dim}));
858 }
859 c->set_output(1, c->MakeShape({}));
860 return Status::OK();
861 });
862
863 REGISTER_OP("NonMaxSuppressionWithOverlaps")
864 .Input("overlaps: float")
865 .Input("scores: float")
866 .Input("max_output_size: int32")
867 .Input("overlap_threshold: float")
868 .Input("score_threshold: float")
869 .Output("selected_indices: int32")
__anon6a71d27f1902(InferenceContext* c) 870 .SetShapeFn([](InferenceContext* c) {
871 // Get inputs and validate ranks.
872 ShapeHandle overlaps;
873 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &overlaps));
874 ShapeHandle scores;
875 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &scores));
876 ShapeHandle max_output_size;
877 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &max_output_size));
878 ShapeHandle overlap_threshold;
879 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &overlap_threshold));
880 ShapeHandle score_threshold;
881 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &score_threshold));
882 // The boxes is a 2-D float Tensor of shape [num_boxes, 4].
883 DimensionHandle unused;
884 // The boxes[0] and scores[0] are both num_boxes.
885 TF_RETURN_IF_ERROR(
886 c->Merge(c->Dim(overlaps, 0), c->Dim(scores, 0), &unused));
887 // The boxes[1] is 4.
888 TF_RETURN_IF_ERROR(
889 c->Merge(c->Dim(overlaps, 0), c->Dim(overlaps, 1), &unused));
890
891 c->set_output(0, c->Vector(c->UnknownDim()));
892 return Status::OK();
893 });
894
895 REGISTER_OP("CombinedNonMaxSuppression")
896 .Input("boxes: float")
897 .Input("scores: float")
898 .Input("max_output_size_per_class: int32")
899 .Input("max_total_size: int32")
900 .Input("iou_threshold: float")
901 .Input("score_threshold: float")
902 .Output("nmsed_boxes: float")
903 .Output("nmsed_scores: float")
904 .Output("nmsed_classes: float")
905 .Output("valid_detections: int32")
906 .Attr("pad_per_class: bool = false")
907 .SetShapeFn(CombinedNMSShapeFn);
908
909 } // namespace tensorflow
910