1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/framework/common_shape_fns.h"
17 #include "tensorflow/core/framework/numeric_op.h"
18 #include "tensorflow/core/framework/op.h"
19 #include "tensorflow/core/framework/shape_inference.h"
20 #include "tensorflow/core/util/mirror_pad_mode.h"
21 #include "tensorflow/core/util/padding.h"
22 #include "tensorflow/core/util/tensor_format.h"
23
24 namespace tensorflow {
25
26 using shape_inference::DimensionHandle;
27 using shape_inference::InferenceContext;
28 using shape_inference::ShapeHandle;
29
30 namespace {
31
FractionalPoolShapeFn(InferenceContext * c)32 Status FractionalPoolShapeFn(InferenceContext* c) {
33 ShapeHandle input;
34 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
35
36 std::vector<float> pooling_ratio;
37 TF_RETURN_IF_ERROR(c->GetAttr("pooling_ratio", &pooling_ratio));
38 if (pooling_ratio.size() != 4) {
39 return errors::InvalidArgument(
40 "pooling_ratio field must specify 4 dimensions");
41 }
42 std::vector<DimensionHandle> output_dims;
43 for (int i = 0; i < 4; ++i) {
44 DimensionHandle d = c->Dim(input, i);
45 if (c->ValueKnown(d)) {
46 // This must match the same logic in the kernel function in
47 // core/kernels/fractional_max_pool_op.cc.
48 auto val = static_cast<int64>(floor(c->Value(d) / pooling_ratio[i]));
49 if (val < 0) {
50 return errors::InvalidArgument("Size computed for dim ", i,
51 " is negative: ", val);
52 }
53 output_dims.push_back(c->MakeDim(val));
54 } else {
55 output_dims.push_back(c->UnknownDim());
56 }
57 }
58
59 c->set_output(0, c->MakeShape(output_dims));
60 c->set_output(1, c->Vector(output_dims[1]));
61 c->set_output(2, c->Vector(output_dims[2]));
62 return Status::OK();
63 }
64
65 } // namespace
66
67 // --------------------------------------------------------------------------
68
// AvgPool: single input 'value', single output of the same type T.
// 'ksize' and 'strides' each need at least 4 entries. The output shape is
// computed by shape_inference::AvgPoolShape from the padding and data_format
// attrs.
REGISTER_OP("AvgPool")
    .Input("value: T")
    .Output("output: T")
    .Attr("ksize: list(int) >= 4")
    .Attr("strides: list(int) >= 4")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("T: {half, bfloat16, float, double}")
    .SetShapeFn(shape_inference::AvgPoolShape);

// AvgPoolGrad: the output shape is not derived from 'grad'. It is built from
// the contents of the int32 'orig_input_shape' tensor (input 0).
REGISTER_OP("AvgPoolGrad")
    .Input("orig_input_shape: int32")
    .Input("grad: T")
    .Output("output: T")
    .Attr("ksize: list(int) >= 4")
    .Attr("strides: list(int) >= 4")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("T: {half, bfloat16, float, double}")
    .SetShapeFn([](InferenceContext* c) {
      // Output shape = values of input 0, which must describe a rank-4 shape.
      ShapeHandle s;
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
      TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
      c->set_output(0, s);
      return Status::OK();
    });
95
96 // --------------------------------------------------------------------------
97
98 REGISTER_OP("BatchNormWithGlobalNormalization")
99 .Input("t: T")
100 .Input("m: T")
101 .Input("v: T")
102 .Input("beta: T")
103 .Input("gamma: T")
104 .Output("result: T")
105 .Attr("T: numbertype")
106 .Attr("variance_epsilon: float")
107 .Attr("scale_after_normalization: bool")
108 .Deprecated(9, "Use tf.nn.batch_normalization()")
__anon3e672dd80302(InferenceContext* c) 109 .SetShapeFn([](InferenceContext* c) {
110 ShapeHandle input;
111 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
112
113 DimensionHandle last_dim = c->Dim(input, 3);
114 for (int i = 1; i < 5; ++i) { // covers m, v, beta, gamma
115 ShapeHandle vec;
116 TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
117 TF_RETURN_IF_ERROR(c->Merge(last_dim, c->Dim(vec, 0), &last_dim));
118 }
119
120 ShapeHandle out;
121 TF_RETURN_IF_ERROR(c->ReplaceDim(input, 3, last_dim, &out));
122 c->set_output(0, out);
123 return Status::OK();
124 });
125
126 REGISTER_OP("BatchNormWithGlobalNormalizationGrad")
127 .Input("t: T")
128 .Input("m: T")
129 .Input("v: T")
130 .Input("gamma: T")
131 .Input("backprop: T")
132 .Output("dx: T")
133 .Output("dm: T")
134 .Output("dv: T")
135 .Output("db: T")
136 .Output("dg: T")
137 .Attr("T: numbertype")
138 .Attr("variance_epsilon: float")
139 .Attr("scale_after_normalization: bool")
140 .Deprecated(9, "Use tf.nn.batch_normalization()")
__anon3e672dd80402(InferenceContext* c) 141 .SetShapeFn([](InferenceContext* c) {
142 ShapeHandle input;
143 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
144 TF_RETURN_IF_ERROR(
145 c->Merge(input, c->input(4), &input)); // with backprop
146
147 DimensionHandle last_dim = c->Dim(input, 3);
148 for (int i = 1; i < 4; ++i) { // covers m, v, gamma
149 ShapeHandle vec;
150 TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
151 TF_RETURN_IF_ERROR(c->Merge(last_dim, c->Dim(vec, 0), &last_dim));
152 }
153
154 ShapeHandle dx;
155 TF_RETURN_IF_ERROR(c->ReplaceDim(input, 3, last_dim, &dx));
156 c->set_output(0, dx);
157
158 ShapeHandle vector_shape = c->Vector(last_dim);
159 c->set_output(1, vector_shape);
160 c->set_output(2, vector_shape);
161 c->set_output(3, vector_shape);
162 c->set_output(4, vector_shape);
163 return Status::OK();
164 });
165
166 // --------------------------------------------------------------------------
167
// FusedBatchNorm: every input and output uses the single type T (float only).
// Output shapes come from shape_inference::FusedBatchNormShape.
REGISTER_OP("FusedBatchNorm")
    .Input("x: T")
    .Input("scale: T")
    .Input("offset: T")
    .Input("mean: T")
    .Input("variance: T")
    .Output("y: T")
    .Output("batch_mean: T")
    .Output("batch_variance: T")
    .Output("reserve_space_1: T")
    .Output("reserve_space_2: T")
    .Attr("T: {float}")
    .Attr("epsilon: float = 0.0001")
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("is_training: bool = true")
    .SetShapeFn(shape_inference::FusedBatchNormShape);

// FusedBatchNormV2: like FusedBatchNorm, but 'x'/'y' use T (half, bfloat16
// or float) while the parameter and statistics tensors use a separate type U.
REGISTER_OP("FusedBatchNormV2")
    .Input("x: T")
    .Input("scale: U")
    .Input("offset: U")
    .Input("mean: U")
    .Input("variance: U")
    .Output("y: T")
    .Output("batch_mean: U")
    .Output("batch_variance: U")
    .Output("reserve_space_1: U")
    .Output("reserve_space_2: U")
    .Attr("T: {half, bfloat16, float}")
    .Attr("U: {float}")
    .Attr("epsilon: float = 0.0001")
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("is_training: bool = true")
    .SetShapeFn(shape_inference::FusedBatchNormShape);

// FusedBatchNormGrad: gradient op with all tensors of type T (float only).
// Output shapes come from shape_inference::FusedBatchNormGradShape.
REGISTER_OP("FusedBatchNormGrad")
    .Input("y_backprop: T")
    .Input("x: T")
    .Input("scale: T")
    .Input("reserve_space_1: T")
    .Input("reserve_space_2: T")
    .Output("x_backprop: T")
    .Output("scale_backprop: T")
    .Output("offset_backprop: T")
    .Output("reserve_space_3: T")
    .Output("reserve_space_4: T")
    .Attr("T: {float}")
    .Attr("epsilon: float = 0.0001")
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("is_training: bool = true")
    .SetShapeFn(shape_inference::FusedBatchNormGradShape);

// FusedBatchNormGradV2: mixed-precision gradient op (T for the data tensors,
// U for the rest).
// NOTE(review): 'scale' is declared as plain float rather than U. Because U
// is restricted to {float}, this is currently the same type. It is left
// as-is because changing an input type alters the registered op signature.
REGISTER_OP("FusedBatchNormGradV2")
    .Input("y_backprop: T")
    .Input("x: T")
    .Input("scale: float")
    .Input("reserve_space_1: U")
    .Input("reserve_space_2: U")
    .Output("x_backprop: T")
    .Output("scale_backprop: U")
    .Output("offset_backprop: U")
    .Output("reserve_space_3: U")
    .Output("reserve_space_4: U")
    .Attr("T: {half, bfloat16, float}")
    .Attr("U: {float}")
    .Attr("epsilon: float = 0.0001")
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("is_training: bool = true")
    .SetShapeFn(shape_inference::FusedBatchNormGradShape);
237
238 // --------------------------------------------------------------------------
239
// BiasAdd: combines 'value' and 'bias' (both of type T). The output shape is
// computed by shape_inference::BiasAddShape, which also reads data_format.
REGISTER_OP("BiasAdd")
    .Attr("T: numbertype")
    .Input("value: T")
    .Input("bias: T")
    .Attr(GetConvnetDataFormatAttrString())
    .Output("output: T")
    .SetShapeFn(shape_inference::BiasAddShape);
// --------------------------------------------------------------------------

// BiasAddGrad: takes only 'out_backprop'. The output shape is computed by
// shape_inference::BiasAddGradShape.
REGISTER_OP("BiasAddGrad")
    .Attr("T: numbertype")
    .Input("out_backprop: T")
    .Attr(GetConvnetDataFormatAttrString())
    .Output("output: T")
    .SetShapeFn(shape_inference::BiasAddGradShape);
// --------------------------------------------------------------------------

// BiasAddV1: same inputs and outputs as BiasAdd but without a data_format
// attr. It shares shape_inference::BiasAddShape.
REGISTER_OP("BiasAddV1")
    .Attr("T: numbertype")
    .Input("value: T")
    .Input("bias: T")
    .Output("output: T")
    .SetShapeFn(shape_inference::BiasAddShape);
263 // --------------------------------------------------------------------------
264
// Conv2D: 'input' convolved with 'filter'. Padding may be explicit (see
// GetPaddingAttrStringWithExplicit / GetExplicitPaddingsAttrString), so the
// shape function is the explicit-padding-aware Conv2DShapeWithExplicitPadding.
REGISTER_OP("Conv2D")
    .Input("input: T")
    .Input("filter: T")
    .Output("output: T")
    .Attr("T: {half, bfloat16, float, double}")
    .Attr("strides: list(int)")
    .Attr("use_cudnn_on_gpu: bool = true")
    .Attr(GetPaddingAttrStringWithExplicit())
    .Attr(GetExplicitPaddingsAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .SetShapeFn(shape_inference::Conv2DShapeWithExplicitPadding);

// Conv2DBackpropInput: the output shape is built from the contents of the
// int32 'input_sizes' tensor (input 0) and must be rank 4.
REGISTER_OP("Conv2DBackpropInput")
    .Input("input_sizes: int32")
    .Input("filter: T")
    .Input("out_backprop: T")
    .Output("output: T")
    .Attr("T: {half, bfloat16, float, double}")
    .Attr("strides: list(int)")
    .Attr("use_cudnn_on_gpu: bool = true")
    .Attr(GetPaddingAttrStringWithExplicit())
    .Attr(GetExplicitPaddingsAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle s;
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
      TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
      c->set_output(0, s);
      return Status::OK();
    });
297
// TODO(jeff): Instead of 'use_cudnn_on_gpu', maybe we should have a
// more general string attribute ('kernel_impl'?) that can be used to
// select among several possible implementations.
//
// Conv2DBackpropFilter: the output shape is built from the contents of the
// int32 'filter_sizes' tensor (input 1) and must be rank 4.
REGISTER_OP("Conv2DBackpropFilter")
    .Input("input: T")
    .Input("filter_sizes: int32")
    .Input("out_backprop: T")
    .Output("output: T")
    .Attr("T: {half, bfloat16, float, double}")
    .Attr("strides: list(int)")
    .Attr("use_cudnn_on_gpu: bool = true")
    .Attr(GetPaddingAttrStringWithExplicit())
    .Attr(GetExplicitPaddingsAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle s;
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
      TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
      c->set_output(0, s);
      return Status::OK();
    });
320
// _FusedConv2D: internal op (leading underscore; per its Doc string,
// Grappler creates it). Besides 'input' and 'filter' it takes 'num_args'
// extra T tensors for the operations named in 'fused_ops'. Its shape function
// is the plain Conv2DShape: unlike Conv2D it has no explicit-padding attrs.
REGISTER_OP("_FusedConv2D")
    .Input("input: T")
    .Input("filter: T")
    .Input("args: num_args * T")
    .Output("output: T")
    .Attr("T: {float, double}")
    .Attr("num_args: int >= 0")
    .Attr("strides: list(int)")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("use_cudnn_on_gpu: bool = true")
    .Attr("fused_ops: list(string) = []")
    // Attributes for the FusedBatchNorm ------------------------------------ //
    .Attr("epsilon: float = 0.0001")
    // ---------------------------------------------------------------------- //
    .SetShapeFn(shape_inference::Conv2DShape)
    .Doc(R"doc(
*NOTE*: Do not invoke this operator directly in Python. Grappler is
expected to create these operators.
)doc");
342
343 namespace {
344
CommonFusedConvCalculations(InferenceContext * c,bool has_resize)345 Status CommonFusedConvCalculations(InferenceContext* c, bool has_resize) {
346 ShapeHandle input;
347 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
348
349 ShapeHandle resized = input;
350 int paddings_index = 1;
351 int filter_index = 2;
352 if (has_resize) {
353 paddings_index = 2;
354 filter_index = 3;
355
356 ShapeHandle unused_size;
357 TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->Vector(2), &unused_size));
358
359 const Tensor* size = c->input_tensor(1);
360 DimensionHandle new_height = c->UnknownDim();
361 DimensionHandle new_width = c->UnknownDim();
362 if (size != nullptr) {
363 new_height = c->MakeDim(size->flat<int32>()(0));
364 new_width = c->MakeDim(size->flat<int32>()(1));
365 }
366 TF_RETURN_IF_ERROR(c->ReplaceDim(resized, 1, new_height, &resized));
367 TF_RETURN_IF_ERROR(c->ReplaceDim(resized, 2, new_width, &resized));
368 }
369
370 ShapeHandle paddings;
371 TF_RETURN_IF_ERROR(c->WithRank(c->input(paddings_index), 2, &paddings));
372 TF_RETURN_IF_ERROR(
373 c->WithRank(resized, c->Value(c->Dim(paddings, 0)), &resized));
374 TF_RETURN_IF_ERROR(
375 c->Merge(paddings, c->Matrix(c->Rank(resized), 2), &paddings));
376
377 const Tensor* paddings_t = c->input_tensor(paddings_index);
378 ShapeHandle padded;
379 if (paddings_t != nullptr) {
380 std::vector<DimensionHandle> output_dims;
381 for (int i = 0; i < 4; ++i) {
382 DimensionHandle dim = c->Dim(resized, i);
383 int64 p0 = static_cast<int64>(paddings_t->matrix<int32>()(i, 0));
384 int64 p1 = static_cast<int64>(paddings_t->matrix<int32>()(i, 1));
385 if (p0 < 0 || p1 < 0) {
386 return errors::InvalidArgument("Paddings must be non-negative");
387 }
388
389 TF_RETURN_IF_ERROR(c->Add(dim, p0 + p1, &dim));
390 output_dims.push_back(dim);
391 }
392 padded = c->MakeShape(output_dims);
393 } else {
394 padded = c->UnknownShapeOfRank(4);
395 }
396
397 // Work out the convolution's effect with 'padded' as the input.
398 ShapeHandle filter;
399 TF_RETURN_IF_ERROR(c->WithRank(c->input(filter_index), 4, &filter));
400 std::vector<int32> strides;
401 TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
402 if (strides.size() != 4) {
403 return errors::InvalidArgument(
404 "Operation requires the stride attribute to contain 4 values, but ",
405 "got: ", strides.size());
406 }
407
408 int32 stride_rows = strides[1];
409 int32 stride_cols = strides[2];
410
411 DimensionHandle batch_size_dim = c->Dim(padded, 0);
412 DimensionHandle in_rows_dim = c->Dim(padded, 1);
413 DimensionHandle in_cols_dim = c->Dim(padded, 2);
414 DimensionHandle filter_rows_dim = c->Dim(filter, 0);
415 DimensionHandle filter_cols_dim = c->Dim(filter, 1);
416 DimensionHandle output_depth_dim = c->Dim(filter, 3);
417
418 DimensionHandle unused;
419 TF_RETURN_IF_ERROR(c->Merge(c->Dim(padded, 3), c->Dim(filter, 2), &unused));
420
421 Padding padding;
422 TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
423
424 DimensionHandle output_rows, output_cols;
425 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
426 c, in_rows_dim, filter_rows_dim, stride_rows, padding, &output_rows));
427 TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
428 c, in_cols_dim, filter_cols_dim, stride_cols, padding, &output_cols));
429
430 ShapeHandle output_shape = c->MakeShape(
431 {batch_size_dim, output_rows, output_cols, output_depth_dim});
432 c->set_output(0, output_shape);
433 return Status::OK();
434 }
435
436 } // namespace
437
// DataFormatDimMap: maps values of 'x' between the src_format and dst_format
// layouts (default 'NHWC' -> 'NCHW'). The output has the same shape as 'x'.
REGISTER_OP("DataFormatDimMap")
    .Input("x: T")
    .Output("y: T")
    .Attr("T: {int32, int64} = DT_INT32")
    .Attr("src_format: string = 'NHWC'")
    .Attr("dst_format: string = 'NCHW'")
    .SetShapeFn(shape_inference::UnchangedShape);

// DataFormatVecPermute: same attrs as DataFormatDimMap. The output has the
// same shape as 'x'.
REGISTER_OP("DataFormatVecPermute")
    .Input("x: T")
    .Output("y: T")
    .Attr("T: {int32, int64} = DT_INT32")
    .Attr("src_format: string = 'NHWC'")
    .Attr("dst_format: string = 'NCHW'")
    .SetShapeFn(shape_inference::UnchangedShape);
453
// FusedResizeAndPadConv2D: inputs are (input, size, paddings, filter).
// Shape inference is delegated to CommonFusedConvCalculations with
// has_resize = true.
REGISTER_OP("FusedResizeAndPadConv2D")
    .Input("input: T")
    .Input("size: int32")
    .Input("paddings: int32")
    .Input("filter: T")
    .Output("output: T")
    .Attr("T: {half, float, double}")
    .Attr("resize_align_corners: bool = false")
    .Attr(GetMirrorPadModeAttrString())
    .Attr("strides: list(int)")
    .Attr(GetPaddingAttrString())
    .SetShapeFn([](InferenceContext* c) {
      return CommonFusedConvCalculations(c, true /* has_resize */);
    });

// FusedPadConv2D: same as above without the resize step. Inputs are
// (input, paddings, filter) and has_resize = false.
REGISTER_OP("FusedPadConv2D")
    .Input("input: T")
    .Input("paddings: int32")
    .Input("filter: T")
    .Output("output: T")
    .Attr("T: {half, float, double}")
    .Attr(GetMirrorPadModeAttrString())
    .Attr("strides: list(int)")
    .Attr(GetPaddingAttrString())
    .SetShapeFn([](InferenceContext* c) {
      return CommonFusedConvCalculations(c, false /* has_resize */);
    });
481
482 // --------------------------------------------------------------------------
483
// DepthwiseConv2dNative: output shape from
// shape_inference::DepthwiseConv2DNativeShape.
REGISTER_OP("DepthwiseConv2dNative")
    .Input("input: T")
    .Input("filter: T")
    .Output("output: T")
    .Attr("T: {half, bfloat16, float, double}")
    .Attr("strides: list(int)")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .SetShapeFn(shape_inference::DepthwiseConv2DNativeShape);

// DepthwiseConv2dNativeBackpropInput: the output shape is built from the
// contents of the int32 'input_sizes' tensor (input 0) and must be rank 4.
REGISTER_OP("DepthwiseConv2dNativeBackpropInput")
    .Input("input_sizes: int32")
    .Input("filter: T")
    .Input("out_backprop: T")
    .Output("output: T")
    .Attr("T: {half, bfloat16, float, double}")
    .Attr("strides: list(int)")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle s;
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
      TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
      c->set_output(0, s);
      return Status::OK();
    });

// DepthwiseConv2dNativeBackpropFilter: the output shape is built from the
// contents of the int32 'filter_sizes' tensor (input 1) and must be rank 4.
REGISTER_OP("DepthwiseConv2dNativeBackpropFilter")
    .Input("input: T")
    .Input("filter_sizes: int32")
    .Input("out_backprop: T")
    .Output("output: T")
    .Attr("T: {half, bfloat16, float, double}")
    .Attr("strides: list(int)")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle s;
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
      TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
      c->set_output(0, s);
      return Status::OK();
    });
530
531 // --------------------------------------------------------------------------
// Conv3D: 3-D counterpart of Conv2D ('strides' needs at least 5 entries).
// Output shape from shape_inference::Conv3DShape.
REGISTER_OP("Conv3D")
    .Input("input: T")
    .Input("filter: T")
    .Output("output: T")
    .Attr("T: {half, bfloat16, float, double}")
    .Attr("strides: list(int) >= 5")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnet3dDataFormatAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
    .SetShapeFn(shape_inference::Conv3DShape);

// Conv3DBackpropInput: deprecated at graph version 10 in favor of
// Conv3DBackpropInputV2. The output is given input 0's shape after a rank-5
// check by UnchangedShapeWithRank (presumably it forwards input 0 — confirm).
REGISTER_OP("Conv3DBackpropInput")
    .Input("input: T")
    .Input("filter: T")
    .Input("out_backprop: T")
    .Output("output: T")
    .Attr("T: {half, float, double}")
    .Attr("strides: list(int) >= 5")
    .Attr(GetPaddingAttrString())
    .Deprecated(10, "Use Conv3DBackpropInputV2")
    .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
    .SetShapeFn([](InferenceContext* c) {
      return UnchangedShapeWithRank(c, 5);
    });

// Conv3DBackpropFilter: deprecated at graph version 10 in favor of
// Conv3DBackpropFilterV2. The output shape is the shape of 'filter'
// (input 1), which must be rank 5.
REGISTER_OP("Conv3DBackpropFilter")
    .Input("input: T")
    .Input("filter: T")
    .Input("out_backprop: T")
    .Output("output: T")
    .Attr("T: {half, float, double}")
    .Attr("strides: list(int) >= 5")
    .Attr(GetPaddingAttrString())
    .Deprecated(10, "Use Conv3DBackpropFilterV2")
    .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle out;
      TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 5, &out));
      c->set_output(0, out);
      return Status::OK();
    });

// Conv3DBackpropInputV2: the output shape is built from the contents of
// 'input_sizes' (input 0, int32 or int64 via Tshape) and must be rank 5.
REGISTER_OP("Conv3DBackpropInputV2")
    .Input("input_sizes: Tshape")
    .Input("filter: T")
    .Input("out_backprop: T")
    .Output("output: T")
    .Attr("T: {half, bfloat16, float, double}")
    .Attr("strides: list(int) >= 5")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnet3dDataFormatAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
    .Attr("Tshape: {int32, int64} = DT_INT32")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle s;
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
      TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
      c->set_output(0, s);
      return Status::OK();
    });

// Conv3DBackpropFilterV2: the output shape is built from the contents of the
// int32 'filter_sizes' tensor (input 1) and must be rank 5. Unlike
// Conv3DBackpropInputV2, the sizes tensor here is int32 only.
REGISTER_OP("Conv3DBackpropFilterV2")
    .Input("input: T")
    .Input("filter_sizes: int32")
    .Input("out_backprop: T")
    .Output("output: T")
    .Attr("T: {half, bfloat16, float, double}")
    .Attr("strides: list(int) >= 5")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnet3dDataFormatAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle s;
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
      TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
      c->set_output(0, s);
      return Status::OK();
    });
610
611 // --------------------------------------------------------------------------
612
// AvgPool3D: 3-D counterpart of AvgPool ('ksize'/'strides' need at least 5
// entries). Output shape from shape_inference::Pool3DShape.
REGISTER_OP("AvgPool3D")
    .Input("input: T")
    .Output("output: T")
    .Attr("ksize: list(int) >= 5")
    .Attr("strides: list(int) >= 5")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnet3dDataFormatAttrString())
    .Attr("T: {half, bfloat16, float, double}")
    .SetShapeFn(shape_inference::Pool3DShape);

// AvgPool3DGrad: the output shape is built from the contents of the int32
// 'orig_input_shape' tensor (input 0) and must be rank 5.
REGISTER_OP("AvgPool3DGrad")
    .Input("orig_input_shape: int32")
    .Input("grad: T")
    .Output("output: T")
    .Attr("ksize: list(int) >= 5")
    .Attr("strides: list(int) >= 5")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnet3dDataFormatAttrString())
    .Attr("T: {half, bfloat16, float, double}")
    .SetShapeFn([](InferenceContext* c) {
      ShapeHandle s;
      TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
      TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
      c->set_output(0, s);
      return Status::OK();
    });
639
640 // --------------------------------------------------------------------------
641
// MaxPool3D: output shape from shape_inference::Pool3DShape.
REGISTER_OP("MaxPool3D")
    .Input("input: T")
    .Output("output: T")
    .Attr("ksize: list(int) >= 5")
    .Attr("strides: list(int) >= 5")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnet3dDataFormatAttrString())
    .Attr("T: {half, bfloat16, float}")
    .SetShapeFn(shape_inference::Pool3DShape);

// MaxPool3DGrad: 'orig_input'/'orig_output' use type TInput and 'grad' uses
// T. The output shape comes from UnchangedShapeWithRank(c, 5) (presumably
// input 0's shape after a rank-5 check — confirm).
REGISTER_OP("MaxPool3DGrad")
    .Input("orig_input: TInput")
    .Input("orig_output: TInput")
    .Input("grad: T")
    .Output("output: T")
    .Attr("ksize: list(int) >= 5")
    .Attr("strides: list(int) >= 5")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnet3dDataFormatAttrString())
    .Attr("T: {half, bfloat16, float} = DT_FLOAT")
    .Attr("TInput: {half, bfloat16, float} = DT_FLOAT")
    .SetShapeFn([](InferenceContext* c) {
      return UnchangedShapeWithRank(c, 5);
    });

// MaxPool3DGradGrad: output 0 is computed by Pool3DShape. The inputs are
// then cross-checked: orig_input must match grad, and orig_output must match
// the computed output.
// NOTE(review): the "ksize" attr string has a trailing space; it is kept
// byte-for-byte.
REGISTER_OP("MaxPool3DGradGrad")
    .Input("orig_input: T")
    .Input("orig_output: T")
    .Input("grad: T")
    .Output("output: T")
    .Attr("ksize: list(int) >= 5 ")
    .Attr("strides: list(int) >= 5")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnet3dDataFormatAttrString())
    .Attr("T: realnumbertype")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::Pool3DShape(c));
      ShapeHandle unused;
      // Validate 'orig_input' is the same shape as 'grad'
      TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(2), &unused));
      // Validate 'orig_output' is same shape as 'output'
      TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->output(0), &unused));
      return Status::OK();
    });
686
687 // --------------------------------------------------------------------------
688
// L2Loss: any-shaped input 't'. The output is a scalar (ScalarShape).
REGISTER_OP("L2Loss")
    .Input("t: T")
    .Output("output: T")
    .Attr("T: {half, bfloat16, float, double}")
    .SetShapeFn(shape_inference::ScalarShape);
694
695 // --------------------------------------------------------------------------
696
697 REGISTER_OP("LRN")
698 .Input("input: T")
699 .Output("output: T")
700 .Attr("depth_radius: int = 5")
701 .Attr("bias: float = 1.0")
702 .Attr("alpha: float = 1.0")
703 .Attr("beta: float = 0.5")
704 .Attr("T: {half, bfloat16, float} = DT_FLOAT")
__anon3e672dd81302(InferenceContext* c) 705 .SetShapeFn([](InferenceContext* c) {
706 return UnchangedShapeWithRank(c, 4);
707 });
708
709 REGISTER_OP("LRNGrad")
710 .Input("input_grads: T")
711 .Input("input_image: T")
712 .Input("output_image: T")
713 .Output("output: T")
714 .Attr("depth_radius: int = 5")
715 .Attr("bias: float = 1.0")
716 .Attr("alpha: float = 1.0")
717 .Attr("beta: float = 0.5")
718 .Attr("T: {half, bfloat16, float} = DT_FLOAT")
__anon3e672dd81402(InferenceContext* c) 719 .SetShapeFn([](InferenceContext* c) {
720 ShapeHandle s;
721 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &s)); // input_grads
722 TF_RETURN_IF_ERROR(c->Merge(s, c->input(1), &s)); // input_image
723 TF_RETURN_IF_ERROR(c->Merge(s, c->input(2), &s)); // output_image
724 c->set_output(0, s);
725 return Status::OK();
726 });
727
728 // --------------------------------------------------------------------------
729
// MaxPool: window sizes and strides are attrs. Output shape from
// shape_inference::MaxPoolShape. Accepts the 'NCHW_VECT_C' data_format in
// addition to NHWC/NCHW.
REGISTER_OP("MaxPool")
    .Attr(
        "T: {half, bfloat16, float, double, int32, int64, uint8, int16, int8, "
        "uint16, qint8} = DT_FLOAT")
    .Attr("ksize: list(int) >= 4")
    .Attr("strides: list(int) >= 4")
    .Attr(GetPaddingAttrString())
    .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'")
    .Input("input: T")
    .Output("output: T")
    .SetShapeFn(shape_inference::MaxPoolShape);

// MaxPoolV2: like MaxPool, but 'ksize' and 'strides' are int32 input tensors
// instead of attrs.
REGISTER_OP("MaxPoolV2")
    .Attr(
        "T: {half, bfloat16, float, double, int32, int64, uint8, int16, int8, "
        "uint16, qint8} = DT_FLOAT")
    .Attr(GetPaddingAttrString())
    .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'")
    .Input("input: T")
    .Input("ksize: int32")
    .Input("strides: int32")
    .Output("output: T")
    .SetShapeFn([](InferenceContext* c) {
      // 3 appears to be this op's input count (ksize/strides being the last
      // two inputs); NOTE(review): confirm against MaxPoolV2Shape.
      TF_RETURN_IF_ERROR(shape_inference::MaxPoolV2Shape(c, 3));
      return Status::OK();
    });
756
// MaxPoolGrad: output shape from UnchangedShapeWithRank(c, 4) (presumably
// input 0, 'orig_input', after a rank-4 check — confirm).
REGISTER_OP("MaxPoolGrad")
    .Attr("ksize: list(int) >= 4")
    .Attr("strides: list(int) >= 4")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Input("orig_input: T")
    .Input("orig_output: T")
    .Input("grad: T")
    .Output("output: T")
    .Attr("T: realnumbertype = DT_FLOAT")
    .SetShapeFn([](InferenceContext* c) {
      return UnchangedShapeWithRank(c, 4);
    });

// MaxPoolGradV2: like MaxPoolGrad, with 'ksize'/'strides' as int32 inputs.
REGISTER_OP("MaxPoolGradV2")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Input("orig_input: T")
    .Input("orig_output: T")
    .Input("grad: T")
    .Input("ksize: int32")
    .Input("strides: int32")
    .Output("output: T")
    .Attr("T: realnumbertype = DT_FLOAT")
    .SetShapeFn([](InferenceContext* c) {
      return UnchangedShapeWithRank(c, 4);
    });

// MaxPoolGradGrad: output 0 is computed by MaxPoolShape. The inputs are then
// cross-checked: orig_input must match grad, and orig_output must match the
// computed output.
REGISTER_OP("MaxPoolGradGrad")
    .Attr("ksize: list(int) >= 4")
    .Attr("strides: list(int) >= 4")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Input("orig_input: T")
    .Input("orig_output: T")
    .Input("grad: T")
    .Output("output: T")
    .Attr("T: realnumbertype")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::MaxPoolShape(c));
      ShapeHandle unused;
      // Validate 'orig_input' is the same shape as 'grad'
      TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(2), &unused));
      // Validate 'orig_output' is same shape as 'output'
      TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->output(0), &unused));
      return Status::OK();
    });

// MaxPoolGradGradV2: same checks as MaxPoolGradGrad, with 'ksize'/'strides'
// as int32 inputs. The 5 passed to MaxPoolV2Shape matches this op's input
// count (NOTE(review): confirm against MaxPoolV2Shape).
REGISTER_OP("MaxPoolGradGradV2")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Input("orig_input: T")
    .Input("orig_output: T")
    .Input("grad: T")
    .Input("ksize: int32")
    .Input("strides: int32")
    .Output("output: T")
    .Attr("T: realnumbertype")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::MaxPoolV2Shape(c, 5));
      ShapeHandle unused;
      // Validate 'orig_input' is the same shape as 'grad'
      TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(2), &unused));
      // Validate 'orig_output' is same shape as 'output'
      TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->output(0), &unused));
      return Status::OK();
    });
824
// MaxPoolWithArgmax: output 0 is computed by MaxPoolShape. The 'argmax'
// output (type Targmax, default int64) gets exactly the same shape.
REGISTER_OP("MaxPoolWithArgmax")
    .Attr("ksize: list(int) >= 4")
    .Attr("strides: list(int) >= 4")
    .Attr("Targmax: {int32, int64} = DT_INT64")
    .Attr(GetPaddingAttrString())
    .Attr("include_batch_in_index: bool = false")
    .Input("input: T")
    .Output("output: T")
    .Output("argmax: Targmax")
    .Attr("T: realnumbertype")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::MaxPoolShape(c));
      c->set_output(1, c->output(0));
      return Status::OK();
    });

// MaxPoolGradWithArgmax: output shape from UnchangedShapeWithRank(c, 4)
// (presumably input 0, 'input', after a rank-4 check — confirm).
REGISTER_OP("MaxPoolGradWithArgmax")
    .Attr("ksize: list(int) >= 4")
    .Attr("strides: list(int) >= 4")
    .Attr(GetPaddingAttrString())
    .Attr("include_batch_in_index: bool = false")
    .Attr("Targmax: {int32, int64}")
    .Input("input: T")
    .Input("grad: T")
    .Input("argmax: Targmax")
    .Output("output: T")
    .Attr("T: realnumbertype")
    .SetShapeFn([](InferenceContext* c) {
      return UnchangedShapeWithRank(c, 4);
    });

// MaxPoolGradGradWithArgmax: output 0 is computed by MaxPoolShape. The inputs
// are then cross-checked: input must match grad, and argmax must match the
// computed output.
REGISTER_OP("MaxPoolGradGradWithArgmax")
    .Attr("ksize: list(int) >= 4")
    .Attr("strides: list(int) >= 4")
    .Attr(GetPaddingAttrString())
    .Attr("include_batch_in_index: bool = false")
    .Attr("Targmax: {int32, int64}")
    .Input("input: T")
    .Input("grad: T")
    .Input("argmax: Targmax")
    .Output("output: T")
    .Attr("T: realnumbertype")
    .SetShapeFn([](InferenceContext* c) {
      TF_RETURN_IF_ERROR(shape_inference::MaxPoolShape(c));
      ShapeHandle unused;
      // Validate 'input' (input 0) is the same shape as 'grad' (input 1)
      TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &unused));
      // Validate 'argmax' is same shape as 'output'
      TF_RETURN_IF_ERROR(c->Merge(c->input(2), c->output(0), &unused));
      return Status::OK();
    });
876
877 // --------------------------------------------------------------------------
878
879 REGISTER_OP("Dilation2D")
880 .Input("input: T")
881 .Input("filter: T")
882 .Output("output: T")
883 .Attr("T: realnumbertype")
884 .Attr("strides: list(int) >= 4")
885 .Attr("rates: list(int) >= 4")
886 .Attr(GetPaddingAttrString())
__anon3e672dd81d02(InferenceContext* c) 887 .SetShapeFn([](InferenceContext* c) {
888 ShapeHandle input_shape;
889 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
890 ShapeHandle filter_shape;
891 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &filter_shape));
892
893 std::vector<int32> strides;
894 TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
895 if (strides.size() != 4) {
896 return errors::InvalidArgument(
897 "Dilation2D requires the stride attribute to contain 4 values, but "
898 "got: ",
899 strides.size());
900 }
901
902 std::vector<int32> rates;
903 TF_RETURN_IF_ERROR(c->GetAttr("rates", &rates));
904 if (rates.size() != 4) {
905 return errors::InvalidArgument(
906 "Dilation2D requires the rates attribute to contain 4 values, but "
907 "got: ",
908 rates.size());
909 }
910
911 int32 stride_rows = strides[1];
912 int32 stride_cols = strides[2];
913
914 int32 rate_rows = rates[1];
915 int32 rate_cols = rates[2];
916
917 DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
918 DimensionHandle in_rows_dim = c->Dim(input_shape, 1);
919 DimensionHandle in_cols_dim = c->Dim(input_shape, 2);
920 DimensionHandle filter_rows_dim = c->Dim(filter_shape, 0);
921 DimensionHandle filter_cols_dim = c->Dim(filter_shape, 1);
922 DimensionHandle output_depth_dim = c->Dim(filter_shape, 2);
923
924 if (!c->ValueKnown(in_rows_dim) || !c->ValueKnown(in_cols_dim) ||
925 !c->ValueKnown(filter_rows_dim) || !c->ValueKnown(filter_cols_dim)) {
926 ShapeHandle output_shape =
927 c->MakeShape({batch_size_dim, InferenceContext::kUnknownDim,
928 InferenceContext::kUnknownDim, output_depth_dim});
929 c->set_output(0, output_shape);
930 return Status::OK();
931 }
932 DimensionHandle unused;
933 TF_RETURN_IF_ERROR(
934 c->Merge(c->Dim(input_shape, 3), output_depth_dim, &unused));
935
936 auto in_rows = c->Value(in_rows_dim);
937 auto in_cols = c->Value(in_cols_dim);
938 auto filter_rows = c->Value(filter_rows_dim);
939 auto filter_cols = c->Value(filter_cols_dim);
940 auto filter_rows_eff = filter_rows + (filter_rows - 1) * (rate_rows - 1);
941 auto filter_cols_eff = filter_cols + (filter_cols - 1) * (rate_cols - 1);
942
943 Padding padding;
944 TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
945
946 int64 output_rows, output_cols;
947 int64 padding_before, padding_after;
948 TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
949 in_rows, filter_rows_eff, stride_rows, padding, &output_rows,
950 &padding_before, &padding_after));
951 TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
952 in_cols, filter_cols_eff, stride_cols, padding, &output_cols,
953 &padding_before, &padding_after));
954
955 ShapeHandle output_shape = c->MakeShape(
956 {batch_size_dim, output_rows, output_cols, output_depth_dim});
957 c->set_output(0, output_shape);
958 return Status::OK();
959 });
960
// Gradient of Dilation2D with respect to `input`.
// Shape function: shape_inference::UnchangedShape, i.e. `in_backprop` takes
// the shape of input 0 (`input`).
REGISTER_OP("Dilation2DBackpropInput")
    .Input("input: T")
    .Input("filter: T")
    .Input("out_backprop: T")
    .Output("in_backprop: T")
    .Attr("T: realnumbertype")
    .Attr("strides: list(int) >= 4")
    .Attr("rates: list(int) >= 4")
    .Attr(GetPaddingAttrString())
    .SetShapeFn(shape_inference::UnchangedShape);
971
972 REGISTER_OP("Dilation2DBackpropFilter")
973 .Input("input: T")
974 .Input("filter: T")
975 .Input("out_backprop: T")
976 .Output("filter_backprop: T")
977 .Attr("T: realnumbertype")
978 .Attr("strides: list(int) >= 4")
979 .Attr("rates: list(int) >= 4")
980 .Attr(GetPaddingAttrString())
__anon3e672dd81e02(InferenceContext* c) 981 .SetShapeFn([](InferenceContext* c) {
982 c->set_output(0, c->input(1));
983 return Status::OK();
984 });
985
986 // --------------------------------------------------------------------------
987
// Relu: `activations` has the same shape as `features`
// (shape_inference::UnchangedShape). Besides real types, qint8 is accepted.
REGISTER_OP("Relu")
    .Input("features: T")
    .Output("activations: T")
    .Attr("T: {realnumbertype, qint8}")
    .SetShapeFn(shape_inference::UnchangedShape);

// ReluGrad: shape from shape_inference::MergeBothInputsShapeFn, which
// presumably requires `gradients` and `features` to have compatible shapes
// and outputs the merged shape (see common_shape_fns).
REGISTER_OP("ReluGrad")
    .Input("gradients: T")
    .Input("features: T")
    .Output("backprops: T")
    .Attr("T: realnumbertype")
    .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
1000
// Relu6: `activations` has the same shape as `features`.
REGISTER_OP("Relu6")
    .Input("features: T")
    .Output("activations: T")
    .Attr("T: realnumbertype")
    .SetShapeFn(shape_inference::UnchangedShape);

// Relu6Grad: `gradients` and `features` shapes are combined by
// shape_inference::MergeBothInputsShapeFn (presumably they must be
// compatible).
REGISTER_OP("Relu6Grad")
    .Input("gradients: T")
    .Input("features: T")
    .Output("backprops: T")
    .Attr("T: realnumbertype")
    .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
1013
// LeakyRelu: `alpha` defaults to 0.2 and T defaults to float.
// `activations` has the same shape as `features`.
REGISTER_OP("LeakyRelu")
    .Input("features: T")
    .Output("activations: T")
    .Attr("alpha: float = 0.2")
    .Attr("T: {half, bfloat16, float, double} = DT_FLOAT")
    .SetShapeFn(shape_inference::UnchangedShape);

// LeakyReluGrad: takes the same `alpha` default as LeakyRelu. `gradients` and
// `features` shapes are combined by shape_inference::MergeBothInputsShapeFn.
REGISTER_OP("LeakyReluGrad")
    .Input("gradients: T")
    .Input("features: T")
    .Output("backprops: T")
    .Attr("alpha: float = 0.2")
    .Attr("T: {half, bfloat16, float, double} = DT_FLOAT")
    .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
1028
// Elu: `activations` has the same shape as `features`.
REGISTER_OP("Elu")
    .Input("features: T")
    .Output("activations: T")
    .Attr("T: {half, bfloat16, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);

// EluGrad: takes the forward `outputs` (not the features). Its shape is
// combined with `gradients` by shape_inference::MergeBothInputsShapeFn.
REGISTER_OP("EluGrad")
    .Input("gradients: T")
    .Input("outputs: T")
    .Output("backprops: T")
    .Attr("T: {half, bfloat16, float, double}")
    .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
1041
// Selu: `activations` has the same shape as `features`.
REGISTER_OP("Selu")
    .Input("features: T")
    .Output("activations: T")
    .Attr("T: {half, bfloat16, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);

// SeluGrad: like EluGrad, it takes the forward `outputs`. Its shape is
// combined with `gradients` by shape_inference::MergeBothInputsShapeFn.
REGISTER_OP("SeluGrad")
    .Input("gradients: T")
    .Input("outputs: T")
    .Output("backprops: T")
    .Attr("T: {half, bfloat16, float, double}")
    .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
1054
// Softplus: `activations` has the same shape as `features`.
REGISTER_OP("Softplus")
    .Input("features: T")
    .Output("activations: T")
    .Attr("T: {half, bfloat16, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);

// SoftplusGrad: `gradients` and `features` shapes are combined by
// shape_inference::MergeBothInputsShapeFn.
REGISTER_OP("SoftplusGrad")
    .Input("gradients: T")
    .Input("features: T")
    .Output("backprops: T")
    .Attr("T: {half, bfloat16, float, double}")
    .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
1067
// Softsign: `activations` has the same shape as `features`.
REGISTER_OP("Softsign")
    .Input("features: T")
    .Output("activations: T")
    .Attr("T: {half, bfloat16, float, double}")
    .SetShapeFn(shape_inference::UnchangedShape);

// SoftsignGrad: `gradients` and `features` shapes are combined by
// shape_inference::MergeBothInputsShapeFn.
REGISTER_OP("SoftsignGrad")
    .Input("gradients: T")
    .Input("features: T")
    .Output("backprops: T")
    .Attr("T: {half, bfloat16, float, double}")
    .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
1080
1081 // --------------------------------------------------------------------------
1082
// Softmax: `logits` must have rank >= 1, and `softmax` takes the shape of
// `logits` (shape_inference::UnchangedShapeWithRankAtLeast).
REGISTER_OP("Softmax")
    .Input("logits: T")
    .Output("softmax: T")
    .Attr("T: {half, bfloat16, float, double}")
    .SetShapeFn([](InferenceContext* c) {
      return shape_inference::UnchangedShapeWithRankAtLeast(c, 1);
    });
1090
1091 // --------------------------------------------------------------------------
1092
// LogSoftmax: `logits` must have rank >= 1, and `logsoftmax` takes the shape
// of `logits` (shape_inference::UnchangedShapeWithRankAtLeast).
REGISTER_OP("LogSoftmax")
    .Input("logits: T")
    .Output("logsoftmax: T")
    .Attr("T: {half, bfloat16, float, double}")
    .SetShapeFn([](InferenceContext* c) {
      return shape_inference::UnchangedShapeWithRankAtLeast(c, 1);
    });
1100
1101 // --------------------------------------------------------------------------
1102
1103 REGISTER_OP("SoftmaxCrossEntropyWithLogits")
1104 .Input("features: T")
1105 .Input("labels: T")
1106 .Output("loss: T")
1107 .Output("backprop: T")
1108 .Attr("T: {half, bfloat16, float, double}")
__anon3e672dd82102(InferenceContext* c) 1109 .SetShapeFn([](InferenceContext* c) {
1110 ShapeHandle input;
1111 if (c->WithRank(c->input(0), 2, &input) == Status::OK() &&
1112 c->Merge(input, c->input(1), &input) == Status::OK()) {
1113 DimensionHandle batch_size = c->Dim(input, 0);
1114 c->set_output(0, c->Vector(batch_size));
1115 c->set_output(1, input);
1116 return Status::OK();
1117 }
1118 TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFn(c, 1));
1119
1120 if (!c->RankKnown(c->output(1))) {
1121 return errors::InvalidArgument(
1122 "Shape must be broadcasted with rank 2, but is rank is unknown.");
1123 }
1124
1125 if (c->Rank(c->output(1)) != 2) {
1126 return errors::InvalidArgument(
1127 "Shape must be broadcasted with rank 2, but is rank ",
1128 c->Rank(c->output(1)));
1129 }
1130 DimensionHandle batch_size = c->Dim(c->output(1), 0);
1131 c->set_output(0, c->Vector(batch_size));
1132 return Status::OK();
1133 });
1134
1135 REGISTER_OP("SparseSoftmaxCrossEntropyWithLogits")
1136 .Input("features: T")
1137 .Input("labels: Tlabels")
1138 .Output("loss: T")
1139 .Output("backprop: T")
1140 .Attr("T: {half, bfloat16, float, double}")
1141 .Attr("Tlabels: {int32, int64} = DT_INT64")
__anon3e672dd82202(InferenceContext* c) 1142 .SetShapeFn([](InferenceContext* c) {
1143 ShapeHandle features;
1144 ShapeHandle labels;
1145 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &features));
1146 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &labels));
1147
1148 DimensionHandle batch_size;
1149 TF_RETURN_IF_ERROR(
1150 c->Merge(c->Dim(features, 0), c->Dim(labels, 0), &batch_size));
1151 TF_RETURN_IF_ERROR(c->ReplaceDim(features, 0, batch_size, &features));
1152
1153 c->set_output(0, c->Vector(batch_size));
1154 c->set_output(1, features);
1155 return Status::OK();
1156 });
1157
1158 // --------------------------------------------------------------------------
1159
1160 REGISTER_OP("InTopK")
1161 .Input("predictions: float")
1162 .Input("targets: T")
1163 .Output("precision: bool")
1164 .Attr("k: int")
1165 .Attr("T: {int32, int64} = DT_INT32")
__anon3e672dd82302(InferenceContext* c) 1166 .SetShapeFn([](InferenceContext* c) {
1167 ShapeHandle predictions;
1168 ShapeHandle targets;
1169 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &predictions));
1170 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &targets));
1171 DimensionHandle batch_size;
1172 TF_RETURN_IF_ERROR(
1173 c->Merge(c->Dim(predictions, 0), c->Dim(targets, 0), &batch_size));
1174 c->set_output(0, c->Vector(batch_size));
1175 return Status::OK();
1176 });
1177
// This is the same as `InTopK`, but takes `k` as an input rather than an attr.
1179 REGISTER_OP("InTopKV2")
1180 .Input("predictions: float")
1181 .Input("targets: T")
1182 .Input("k: T")
1183 .Output("precision: bool")
1184 .Attr("T: {int32, int64} = DT_INT32")
__anon3e672dd82402(InferenceContext* c) 1185 .SetShapeFn([](InferenceContext* c) {
1186 ShapeHandle predictions;
1187 ShapeHandle targets;
1188 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &predictions));
1189 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &targets));
1190 DimensionHandle batch_size;
1191 TF_RETURN_IF_ERROR(
1192 c->Merge(c->Dim(predictions, 0), c->Dim(targets, 0), &batch_size));
1193 c->set_output(0, c->Vector(batch_size));
1194 return Status::OK();
1195 });
1196
1197 namespace {
1198
TopKShapeFn(InferenceContext * c)1199 Status TopKShapeFn(InferenceContext* c) {
1200 ShapeHandle input;
1201 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input));
1202
1203 // Get the k value, either from input tensor or attribute.
1204 DimensionHandle k_dim;
1205 if (c->num_inputs() >= 2) {
1206 TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(1, &k_dim));
1207 } else {
1208 int32 k;
1209 TF_RETURN_IF_ERROR(c->GetAttr("k", &k));
1210 if (k < 0) {
1211 return errors::InvalidArgument("Need k >= 0, got ", k);
1212 }
1213 k_dim = c->MakeDim(k);
1214 }
1215
1216 DimensionHandle last_dim = c->Dim(input, -1);
1217 if (c->ValueKnown(last_dim) && c->ValueKnown(k_dim) &&
1218 c->Value(last_dim) < c->Value(k_dim)) {
1219 return errors::InvalidArgument(
1220 "input must have last dimension >= k = ", c->Value(k_dim), " but is ",
1221 c->Value(last_dim));
1222 }
1223
1224 // Replace last_dim with k_dim.
1225 ShapeHandle s;
1226 TF_RETURN_IF_ERROR(c->Subshape(input, 0, -1, &s));
1227 TF_RETURN_IF_ERROR(c->Concatenate(s, c->Vector(k_dim), &s));
1228 c->set_output(0, s);
1229 c->set_output(1, s);
1230 return Status::OK();
1231 }
1232
1233 } // namespace
1234
// TopK with `k` supplied as an attr (k >= 0). Deprecated at GraphDef version
// 7 in favor of TopKV2. Output shapes are computed by TopKShapeFn.
REGISTER_OP("TopK")
    .Input("input: T")
    .Output("values: T")
    .Output("indices: int32")
    .Attr("k: int >= 0")
    .Attr("sorted: bool = true")
    .Attr("T: realnumbertype")
    .Deprecated(7, "Use TopKV2 instead")
    .SetShapeFn(TopKShapeFn);
1244
// This is the same as `TopK`, but takes `k` as an input rather than an attr.
// Shape function: TopKShapeFn, which takes `k` from input 1 here because this
// op has two inputs.
REGISTER_OP("TopKV2")
    .Input("input: T")
    .Input("k: int32")
    .Output("values: T")
    .Output("indices: int32")
    .Attr("sorted: bool = true")
    .Attr("T: realnumbertype")
    .SetShapeFn(TopKShapeFn);
1254
1255 // --------------------------------------------------------------------------
1256
1257 REGISTER_OP("NthElement")
1258 .Input("input: T")
1259 .Input("n: int32")
1260 .Output("values: T")
1261 .Attr("reverse: bool = false")
1262 .Attr("T: realnumbertype")
__anon3e672dd82602(InferenceContext* c) 1263 .SetShapeFn([](InferenceContext* c) {
1264 ShapeHandle input;
1265 TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input));
1266
1267 // Get the n value from input tensor, and make sure which is a scalar.
1268 DimensionHandle n_dim;
1269 TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(1, &n_dim));
1270
1271 // The last dimension of input tensor must be greater than N.
1272 DimensionHandle last_dim = c->Dim(input, -1);
1273 if (c->ValueKnown(last_dim) && c->ValueKnown(n_dim) &&
1274 c->Value(last_dim) <= c->Value(n_dim)) {
1275 return errors::InvalidArgument(
1276 "Input must have last dimension > n = ", c->Value(n_dim),
1277 " but is ", c->Value(last_dim));
1278 }
1279
1280 // Reduce last_dim for output tensor
1281 ShapeHandle s;
1282 TF_RETURN_IF_ERROR(c->Subshape(input, 0, -1, &s));
1283 c->set_output(0, s);
1284 return Status::OK();
1285 });
1286
1287 // --------------------------------------------------------------------------
1288
// FractionalMaxPool. Output shape is computed by FractionalPoolShapeFn. That
// function requires a rank-4 `value` and a 4-entry `pooling_ratio`, and sizes
// each known output dim as floor(dim / pooling_ratio[i]).
// NOTE(review): the two pooling-sequence outputs are set in the part of
// FractionalPoolShapeFn that is not visible here.
REGISTER_OP("FractionalMaxPool")
    .Input("value: T")
    .Output("output: T")
    .Output("row_pooling_sequence: int64")
    .Output("col_pooling_sequence: int64")
    .Attr("pooling_ratio: list(float) >=4")
    .Attr("pseudo_random: bool = false")
    .Attr("overlapping: bool = false")
    .Attr("deterministic: bool = false")
    .Attr("seed: int = 0")
    .Attr("seed2: int = 0")
    .Attr("T: {float, double, int32, int64}")
    .SetShapeFn(FractionalPoolShapeFn);
1302
// Gradient of FractionalMaxPool.
// Shape function: `output` takes the shape of `orig_input`, which must be
// rank 4 (shape_inference::UnchangedShapeWithRank).
REGISTER_OP("FractionalMaxPoolGrad")
    .Input("orig_input: T")
    .Input("orig_output: T")
    .Input("out_backprop: T")
    .Input("row_pooling_sequence: int64")
    .Input("col_pooling_sequence: int64")
    .Output("output: T")
    .Attr("overlapping: bool = false")
    .Attr("T: {float, double, int32, int64}")
    .SetShapeFn([](InferenceContext* c) {
      return shape_inference::UnchangedShapeWithRank(c, 4);
    });
1315
1316 // --------------------------------------------------------------------------
1317
// FractionalAvgPool. It shares its attrs and shape function
// (FractionalPoolShapeFn) with FractionalMaxPool.
REGISTER_OP("FractionalAvgPool")
    .Input("value: T")
    .Output("output: T")
    .Output("row_pooling_sequence: int64")
    .Output("col_pooling_sequence: int64")
    .Attr("pooling_ratio: list(float) >=4")
    .Attr("pseudo_random: bool = false")
    .Attr("overlapping: bool = false")
    .Attr("deterministic: bool = false")
    .Attr("seed: int = 0")
    .Attr("seed2: int = 0")
    .Attr("T: {float, double, int32, int64}")
    .SetShapeFn(FractionalPoolShapeFn);
1331
1332 REGISTER_OP("FractionalAvgPoolGrad")
1333 .Input("orig_input_tensor_shape: int64")
1334 .Input("out_backprop: T")
1335 .Input("row_pooling_sequence: int64")
1336 .Input("col_pooling_sequence: int64")
1337 .Output("output: T")
1338 .Attr("overlapping: bool = false")
1339 .Attr("T: {float, double, int32, int64}")
__anon3e672dd82802(InferenceContext* c) 1340 .SetShapeFn([](InferenceContext* c) {
1341 if (c->input_tensor(0) != nullptr) {
1342 ShapeHandle out;
1343 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
1344 c->set_output(0, out);
1345 } else {
1346 c->set_output(0, c->UnknownShapeOfRank(4));
1347 }
1348 return Status::OK();
1349 });
1350
1351 REGISTER_OP("QuantizedAvgPool")
1352 .Input("input: T")
1353 .Input("min_input: float")
1354 .Input("max_input: float")
1355 .Output("output: T")
1356 .Output("min_output: float")
1357 .Output("max_output: float")
1358 .Attr("T: quantizedtype")
1359 .Attr("ksize: list(int)")
1360 .Attr("strides: list(int)")
1361 .Attr(GetPaddingAttrString())
__anon3e672dd82902(InferenceContext* c) 1362 .SetShapeFn([](InferenceContext* c) {
1363 TF_RETURN_IF_ERROR(shape_inference::AvgPoolShape(c));
1364 ShapeHandle unused;
1365 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1366 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1367 c->set_output(1, c->Scalar());
1368 c->set_output(2, c->Scalar());
1369 return Status::OK();
1370 });
1371
1372 REGISTER_OP("QuantizedBiasAdd")
1373 .Input("input: T1")
1374 .Input("bias: T2")
1375 .Input("min_input: float")
1376 .Input("max_input: float")
1377 .Input("min_bias: float")
1378 .Input("max_bias: float")
1379 .Output("output: out_type")
1380 .Output("min_out: float")
1381 .Output("max_out: float")
1382 .Attr("T1: quantizedtype")
1383 .Attr("T2: quantizedtype")
1384 .Attr("out_type: quantizedtype")
__anon3e672dd82a02(InferenceContext* c) 1385 .SetShapeFn([](InferenceContext* c) {
1386 TF_RETURN_IF_ERROR(shape_inference::BiasAddShape(c));
1387 ShapeHandle unused;
1388 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1389 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
1390 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
1391 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
1392 c->set_output(1, c->Scalar());
1393 c->set_output(2, c->Scalar());
1394 return Status::OK();
1395 });
1396
1397 REGISTER_OP("QuantizedConv2D")
1398 .Input("input: Tinput")
1399 .Input("filter: Tfilter")
1400 .Input("min_input: float")
1401 .Input("max_input: float")
1402 .Input("min_filter: float")
1403 .Input("max_filter: float")
1404 .Output("output: out_type")
1405 .Output("min_output: float")
1406 .Output("max_output: float")
1407 .Attr("Tinput: quantizedtype")
1408 .Attr("Tfilter: quantizedtype")
1409 .Attr("out_type: quantizedtype = DT_QINT32")
1410 .Attr("strides: list(int)")
1411 .Attr(GetPaddingAttrString())
1412 .Attr("dilations: list(int) = [1, 1, 1, 1]")
__anon3e672dd82b02(InferenceContext* c) 1413 .SetShapeFn([](InferenceContext* c) {
1414 TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
1415 ShapeHandle unused;
1416 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1417 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
1418 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
1419 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
1420 c->set_output(1, c->Scalar());
1421 c->set_output(2, c->Scalar());
1422 return Status::OK();
1423 });
1424
1425 REGISTER_OP("QuantizedMaxPool")
1426 .Input("input: T")
1427 .Input("min_input: float")
1428 .Input("max_input: float")
1429 .Output("output: T")
1430 .Output("min_output: float")
1431 .Output("max_output: float")
1432 .Attr("T: quantizedtype")
1433 .Attr("ksize: list(int)")
1434 .Attr("strides: list(int)")
1435 .Attr(GetPaddingAttrString())
__anon3e672dd82c02(InferenceContext* c) 1436 .SetShapeFn([](InferenceContext* c) {
1437 TF_RETURN_IF_ERROR(shape_inference::MaxPoolShape(c));
1438 ShapeHandle unused;
1439 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1440 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1441 c->set_output(1, c->Scalar());
1442 c->set_output(2, c->Scalar());
1443 return Status::OK();
1444 });
1445
1446 REGISTER_OP("QuantizedRelu")
1447 .Input("features: Tinput")
1448 .Input("min_features: float")
1449 .Input("max_features: float")
1450 .Output("activations: out_type")
1451 .Output("min_activations: float")
1452 .Output("max_activations: float")
1453 .Attr("Tinput: quantizedtype")
1454 .Attr("out_type: quantizedtype = DT_QUINT8")
__anon3e672dd82d02(InferenceContext* c) 1455 .SetShapeFn([](InferenceContext* c) {
1456 TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
1457 ShapeHandle unused;
1458 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1459 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1460 c->set_output(1, c->Scalar());
1461 c->set_output(2, c->Scalar());
1462 return Status::OK();
1463 });
1464
1465 REGISTER_OP("QuantizedRelu6")
1466 .Input("features: Tinput")
1467 .Input("min_features: float")
1468 .Input("max_features: float")
1469 .Output("activations: out_type")
1470 .Output("min_activations: float")
1471 .Output("max_activations: float")
1472 .Attr("Tinput: quantizedtype")
1473 .Attr("out_type: quantizedtype = DT_QUINT8")
__anon3e672dd82e02(InferenceContext* c) 1474 .SetShapeFn([](InferenceContext* c) {
1475 TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
1476 ShapeHandle unused;
1477 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1478 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1479 c->set_output(1, c->Scalar());
1480 c->set_output(2, c->Scalar());
1481 return Status::OK();
1482 });
1483
1484 REGISTER_OP("QuantizedReluX")
1485 .Input("features: Tinput")
1486 .Input("max_value: float")
1487 .Input("min_features: float")
1488 .Input("max_features: float")
1489 .Output("activations: out_type")
1490 .Output("min_activations: float")
1491 .Output("max_activations: float")
1492 .Attr("Tinput: quantizedtype")
1493 .Attr("out_type: quantizedtype = DT_QUINT8")
__anon3e672dd82f02(InferenceContext* c) 1494 .SetShapeFn([](InferenceContext* c) {
1495 TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
1496 ShapeHandle unused;
1497 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1498 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1499 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
1500 c->set_output(1, c->Scalar());
1501 c->set_output(2, c->Scalar());
1502 return Status::OK();
1503 });
1504
1505 REGISTER_OP("QuantizedBatchNormWithGlobalNormalization")
1506 .Input("t: Tinput")
1507 .Input("t_min: float")
1508 .Input("t_max: float")
1509 .Input("m: Tinput")
1510 .Input("m_min: float")
1511 .Input("m_max: float")
1512 .Input("v: Tinput")
1513 .Input("v_min: float")
1514 .Input("v_max: float")
1515 .Input("beta: Tinput")
1516 .Input("beta_min: float")
1517 .Input("beta_max: float")
1518 .Input("gamma: Tinput")
1519 .Input("gamma_min: float")
1520 .Input("gamma_max: float")
1521 .Output("result: out_type")
1522 .Output("result_min: float")
1523 .Output("result_max: float")
1524 .Attr("Tinput: quantizedtype")
1525 .Attr("out_type: quantizedtype")
1526 .Attr("variance_epsilon: float")
1527 .Attr("scale_after_normalization: bool")
__anon3e672dd83002(InferenceContext* c) 1528 .SetShapeFn([](InferenceContext* c) {
1529 ShapeHandle input;
1530 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
1531
1532 DimensionHandle last_dim = c->Dim(input, 3);
1533 for (int i = 1; i < 5; ++i) { // covers m, v, beta, gamma
1534 ShapeHandle vec;
1535 TF_RETURN_IF_ERROR(c->WithRank(c->input(i * 3), 1, &vec));
1536 TF_RETURN_IF_ERROR(c->Merge(last_dim, c->Dim(vec, 0), &last_dim));
1537 }
1538
1539 ShapeHandle out;
1540 TF_RETURN_IF_ERROR(c->ReplaceDim(input, 3, last_dim, &out));
1541 c->set_output(0, out);
1542 c->set_output(1, c->Scalar());
1543 c->set_output(2, c->Scalar());
1544
1545 return Status::OK();
1546 });
1547
#ifdef INTEL_MKL
// MKL variant of DepthwiseConv2dNative. Output shape follows
// shape_inference::DepthwiseConv2DNativeShape.
// NOTE(review): the extra uint8 mkl_* inputs/outputs and `filter_output`
// presumably carry MKL layout metadata / a reordered filter; the shape
// function below sets only what DepthwiseConv2DNativeShape sets. Confirm
// against the MKL kernels.
REGISTER_OP("_MklDepthwiseConv2dNative")
    .Input("input: T")
    .Input("filter: T")
    .Input("mkl_input: uint8")
    .Input("mkl_filter: uint8")
    .Output("output: T")
    .Output("filter_output: T")
    .Output("mkl_output: uint8")
    .Output("mkl_filter_output: uint8")
    .Attr("T: {half, bfloat16, float, double}")
    .Attr("strides: list(int)")
    .Attr("is_filter_const: bool = false")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .SetShapeFn(shape_inference::DepthwiseConv2DNativeShape);
1565
// MKL variant of Conv2D. Output shape follows shape_inference::Conv2DShape.
// The Doc string below is runtime text and is kept verbatim.
REGISTER_OP("_MklConv2D")
    .Input("input: T")
    .Input("filter: T")
    .Input("mkl_input: uint8")
    .Input("mkl_filter: uint8")
    .Output("output: T")
    .Output("filter_output: T")
    .Output("mkl_output: uint8")
    .Output("mkl_filter_output: uint8")
    .Attr("T: {half, float, double}")
    .Attr("strides: list(int)")
    .Attr("use_cudnn_on_gpu: bool = true")
    .Attr("is_filter_const: bool = false")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .SetShapeFn(shape_inference::Conv2DShape)
    .Doc(R"doc(
MKL version of Conv2D operator. Uses MKL DNN APIs to perform 2D convolution.

NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");
1589
// Placeholder used by the MKL graph rewrite when fusing Conv2D + BiasAdd.
// Output shape follows shape_inference::Conv2DShape; `bias` is not
// shape-checked here.
REGISTER_OP("__MklDummyConv2DWithBias")
    .Input("input: T")
    .Input("filter: T")
    .Input("bias: T")
    .Output("output: T")
    .Attr("T: {half, float, double}")
    .Attr("strides: list(int)")
    .Attr("use_cudnn_on_gpu: bool = true")
    .Attr("is_filter_const: bool = false")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .SetShapeFn(shape_inference::Conv2DShape)
    .Doc(R"doc(
Dummy node that enables fusing Conv2D and BiasAdd operator for MKL. This node
does not perform anything. It is just created as an intermediate output of
merging Conv2D and BiasAdd.

NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");
1611
// MKL fused Conv2D + BiasAdd. Output shape follows
// shape_inference::Conv2DShape; `bias` is not shape-checked here.
REGISTER_OP("_MklConv2DWithBias")
    .Input("input: T")
    .Input("filter: T")
    .Input("bias: T")
    .Input("mkl_input: uint8")
    .Input("mkl_filter: uint8")
    .Input("mkl_bias: uint8")
    .Output("output: T")
    .Output("filter_output: T")
    .Output("mkl_output: uint8")
    .Output("mkl_filter_output: uint8")
    .Attr("T: {half, float, double}")
    .Attr("strides: list(int)")
    .Attr("use_cudnn_on_gpu: bool = true")
    .Attr("is_filter_const: bool = false")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .SetShapeFn(shape_inference::Conv2DShape)
    .Doc(R"doc(
MKL version of Conv2D and BiasAdd operator. Uses MKL DNN APIs to perform
2D convolution and add Bias to the output of convolution.

NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");
1638
// Placeholder used by the MKL graph rewrite when fusing Pad + Conv2D.
// Output shape follows shape_inference::Conv2DShape.
// NOTE(review): Conv2DShape does not appear to account for `paddings`, so the
// inferred spatial sizes may not reflect the fused Pad. Confirm whether this
// is intended.
REGISTER_OP("__MklDummyPadWithConv2D")
    .Input("input: T")
    .Input("filter: T")
    .Input("paddings: Tpaddings")
    .Output("output: T")
    .Attr("T: {half, float, double}")
    .Attr("strides: list(int)")
    .Attr("use_cudnn_on_gpu: bool = true")
    .Attr("is_filter_const: bool = false")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("Tpaddings: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::Conv2DShape)
    .Doc(R"doc(
Dummy node that enables fusing Pad and Conv2D operator for MKL. This node
does not perform anything. It is just created as an intermediate output of
merging Pad and Conv2D.

NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");
1661
// MKL fused Pad + Conv2D. Output shape follows shape_inference::Conv2DShape
// (see the review note on __MklDummyPadWithConv2D regarding `paddings`).
REGISTER_OP("_MklPadWithConv2D")
    .Input("input: T")
    .Input("filter: T")
    .Input("paddings: Tpaddings")
    .Input("mkl_input: uint8")
    .Input("mkl_filter: uint8")
    .Input("mkl_paddings: uint8")
    .Output("output: T")
    .Output("filter_output: T")
    .Output("mkl_output: uint8")
    .Output("mkl_filter_output: uint8")
    .Attr("T: {half, float, double}")
    .Attr("strides: list(int)")
    .Attr("use_cudnn_on_gpu: bool = true")
    .Attr(GetPaddingAttrString())
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("is_filter_const: bool = false")
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .Attr("Tpaddings: {int32, int64} = DT_INT32")
    .SetShapeFn(shape_inference::Conv2DShape)
    .Doc(R"doc(
MKL version of Pad and Conv2D operator. Uses MKL DNN APIs to perform
Pad and 2D convolution to the output of convolution.

NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");
1689
1690 REGISTER_OP("_MklConv2DBackpropFilter")
1691 .Input("input: T")
1692 .Input("filter_sizes: int32")
1693 .Input("out_backprop: T")
1694 .Input("mkl_input: uint8")
1695 .Input("mkl_filter_size: uint8")
1696 .Input("mkl_out_backprop: uint8")
1697 .Output("output: T")
1698 .Output("mkl_output: uint8")
1699 .Attr("T: {half, float, double}")
1700 .Attr("strides: list(int)")
1701 .Attr("use_cudnn_on_gpu: bool = true")
1702 .Attr(GetPaddingAttrString())
1703 .Attr(GetConvnetDataFormatAttrString())
1704 .Attr("dilations: list(int) = [1, 1, 1, 1]")
__anon3e672dd83102(InferenceContext* c) 1705 .SetShapeFn([](InferenceContext* c) {
1706 ShapeHandle s;
1707 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
1708 TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
1709 c->set_output(0, s);
1710 return Status::OK();
1711 })
1712 .Doc(R"doc(
1713 MKL version of Conv2DBackpropFilter. Uses MKL DNN APIs to compute the
1714 gradients of convolution with respect to the filter.
1715
1716 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
1717 expected to invoke these operators.
1718 )doc");
1719
1720 REGISTER_OP("__MklDummyConv2DBackpropFilterWithBias")
1721 .Input("input: T")
1722 .Input("filter_sizes: int32")
1723 .Input("out_backprop: T")
1724 .Output("output: T")
1725 .Output("bias_grad: T")
1726 .Attr("T: {half, float, double}")
1727 .Attr("strides: list(int)")
1728 .Attr("use_cudnn_on_gpu: bool = true")
1729 .Attr(GetPaddingAttrString())
1730 .Attr(GetConvnetDataFormatAttrString())
1731 .Attr("dilations: list(int) = [1, 1, 1, 1]")
__anon3e672dd83202(InferenceContext* c) 1732 .SetShapeFn([](InferenceContext* c) {
1733 ShapeHandle input_shape;
1734 // Fetch the data_format attribute, which may not exist.
1735 string data_format;
1736 Status s = c->GetAttr("data_format", &data_format);
1737
1738 if (s.ok() && data_format == "NCHW") {
1739 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
1740 c->set_output(1, c->Vector(c->Dim(input_shape, -3)));
1741 } else {
1742 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
1743 c->set_output(1, c->Vector(c->Dim(input_shape, -1)));
1744 }
1745 ShapeHandle sh;
1746 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &sh));
1747 TF_RETURN_IF_ERROR(c->WithRank(sh, 4, &sh));
1748 c->set_output(0, sh);
1749 return Status::OK();
1750 })
1751 .Doc(R"doc(
1752 Dummy node that enables fusing Conv2DBackpropFilter and BiasAddGrad operator
1753 for MKL. This node does not perform anything. It is just created as an
1754 intermediate output of merging Conv2DBackpropFilter and BiasAddGrad.
1755
1756 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
1757 expected to invoke these operators.
1758 )doc");
1759
1760 REGISTER_OP("_MklConv2DBackpropFilterWithBias")
1761 .Input("input: T")
1762 .Input("filter_sizes: int32")
1763 .Input("out_backprop: T")
1764 .Input("mkl_input: uint8")
1765 .Input("mkl_filter_size: uint8")
1766 .Input("mkl_out_backprop: uint8")
1767 .Output("output: T")
1768 .Output("bias_grad: T")
1769 .Output("mkl_output: uint8")
1770 .Output("mkl_bias_grad: uint8")
1771 .Attr("T: {half, float, double}")
1772 .Attr("strides: list(int)")
1773 .Attr("use_cudnn_on_gpu: bool = true")
1774 .Attr(GetPaddingAttrString())
1775 .Attr(GetConvnetDataFormatAttrString())
1776 .Attr("dilations: list(int) = [1, 1, 1, 1]")
__anon3e672dd83302(InferenceContext* c) 1777 .SetShapeFn([](InferenceContext* c) {
1778 ShapeHandle input_shape;
1779 // Fetch the data_format attribute, which may not exist.
1780 string data_format;
1781 Status s = c->GetAttr("data_format", &data_format);
1782
1783 if (s.ok() && data_format == "NCHW") {
1784 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
1785 c->set_output(1, c->Vector(c->Dim(input_shape, -3)));
1786 } else {
1787 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
1788 c->set_output(1, c->Vector(c->Dim(input_shape, -1)));
1789 }
1790 ShapeHandle sh;
1791 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &sh));
1792 TF_RETURN_IF_ERROR(c->WithRank(sh, 4, &sh));
1793 c->set_output(0, sh);
1794 return Status::OK();
1795 })
1796 .Doc(R"doc(
1797 MKL version of Conv2DBackpropFilterWithBias. Uses MKL DNN APIs to compute the
1798 gradients of convolution with respect to the filter.
1799
1800 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
1801 expected to invoke these operators.
1802 )doc");
1803
#ifdef INTEL_MKL_ML_ONLY
// MKL-ML-only bias gradient for a fused Conv2D+BiasAdd.
//
// This op was previously registered without any shape function, so graph-level
// shape inference had nothing to run for it. It now uses the stock
// BiasAddGrad shape function from common_shape_fns. That function is expected
// to pick the channel axis of `out_backprop` from `data_format`, mirroring the
// `bias_grad` output of _MklConv2DBackpropFilterWithBias; confirm against
// common_shape_fns.cc. Output 1 (mkl_output) is left unset.
REGISTER_OP("_MklConv2DWithBiasBackpropBias")
    .Input("out_backprop: T")
    .Input("mkl_out_backprop: uint8")
    .Output("output: T")
    .Output("mkl_output: uint8")
    .Attr("T: {half, float, double}")
    .Attr("strides: list(int)")
    .Attr(GetConvnetDataFormatAttrString())
    .Attr("dilations: list(int) = [1, 1, 1, 1]")
    .SetShapeFn(shape_inference::BiasAddGradShape)
    .Doc(R"doc(
MKL version of Conv2DBackpropBias. Uses MKL DNN APIs to compute the
gradients of convolution with respect to the bias.

NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
expected to invoke these operators.
)doc");
#endif
1822
1823 REGISTER_OP("_MklConv2DBackpropInput")
1824 .Input("input_sizes: int32")
1825 .Input("filter: T")
1826 .Input("out_backprop: T")
1827 .Input("mkl_input_sizes: uint8")
1828 .Input("mkl_filter: uint8")
1829 .Input("mkl_out_backprop: uint8")
1830 .Output("output: T")
1831 .Output("mkl_output: uint8")
1832 .Attr("T: {half, float, double}")
1833 .Attr("strides: list(int)")
1834 .Attr("use_cudnn_on_gpu: bool = true")
1835 .Attr(GetPaddingAttrString())
1836 .Attr(GetConvnetDataFormatAttrString())
1837 .Attr("dilations: list(int) = [1, 1, 1, 1]")
__anon3e672dd83402(InferenceContext* c) 1838 .SetShapeFn([](InferenceContext* c) {
1839 ShapeHandle s;
1840 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
1841 TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
1842 c->set_output(0, s);
1843 return Status::OK();
1844 })
1845 .Doc(R"doc(
1846 MKL version of Convolution2D backward input. Uses MKL DNN APIs to compute the
1847 gradients of convolution with respect to the input.
1848
1849 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
1850 expected to invoke these operators.
1851 )doc");
1852
1853 REGISTER_OP("_MklConv3D")
1854 .Input("input: T")
1855 .Input("filter: T")
1856 .Input("mkl_input: uint8")
1857 .Input("mkl_filter: uint8")
1858 .Output("output: T")
1859 .Output("filter_output: T")
1860 .Output("mkl_output: uint8")
1861 .Output("mkl_filter_output: uint8")
1862 .Attr("T: {half, float, double}")
1863 .Attr("strides: list(int) >= 5")
1864 .Attr("is_filter_const: bool = false")
1865 .Attr(GetPaddingAttrString())
1866 .Attr(GetConvnet3dDataFormatAttrString())
1867 .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
1868 .SetShapeFn(shape_inference::Conv3DShape)
1869 .Doc(R"doc(
1870 MKL version of Conv3D operator. Uses MKL DNN APIs to perform 3D convolution.
1871
1872 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
1873 expected to invoke these operators.
1874 )doc");
1875
1876 REGISTER_OP("_MklConv3DBackpropInputV2")
1877 .Input("input_sizes: Tshape")
1878 .Input("filter: T")
1879 .Input("out_backprop: T")
1880 .Input("mkl_input_sizes: uint8")
1881 .Input("mkl_filter: uint8")
1882 .Input("mkl_out_backprop: uint8")
1883 .Output("output: T")
1884 .Output("mkl_output: uint8")
1885 .Attr("T: {half, float, double}")
1886 .Attr("strides: list(int) >= 5")
1887 .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
1888 .Attr("Tshape: {int32, int64} = DT_INT32")
1889 .Attr(GetPaddingAttrString())
1890 .Attr(GetConvnet3dDataFormatAttrString())
__anon3e672dd83502(InferenceContext* c) 1891 .SetShapeFn([](InferenceContext* c) {
1892 ShapeHandle s;
1893 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
1894 TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
1895 c->set_output(0, s);
1896 return Status::OK();
1897 })
1898 .Doc(R"doc(
1899 MKL version of Convolution3D backward input. Uses MKL DNN APIs to compute the
1900 gradients of convolution with respect to the input.
1901
1902 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
1903 expected to invoke these operators.
1904 )doc");
1905
1906 REGISTER_OP("_MklConv3DBackpropFilterV2")
1907 .Input("input: T")
1908 .Input("filter_sizes: int32")
1909 .Input("out_backprop: T")
1910 .Input("mkl_input: uint8")
1911 .Input("mkl_filter_size: uint8")
1912 .Input("mkl_out_backprop: uint8")
1913 .Output("output: T")
1914 .Output("mkl_output: uint8")
1915 .Attr("T: {half, float, double}")
1916 .Attr("strides: list(int)")
1917 .Attr(GetPaddingAttrString())
1918 .Attr(GetConvnet3dDataFormatAttrString())
1919 .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
__anon3e672dd83602(InferenceContext* c) 1920 .SetShapeFn([](InferenceContext* c) {
1921 ShapeHandle s;
1922 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
1923 TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
1924 c->set_output(0, s);
1925 return Status::OK();
1926 })
1927 .Doc(R"doc(
1928 MKL version of Conv3DBackpropFilter. Uses MKL DNN APIs to compute the
1929 gradients of convolution with respect to the filter.
1930
1931 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
1932 expected to invoke these operators.
1933 )doc");
1934
1935 REGISTER_OP("_MklRelu")
1936 .Input("features: T")
1937 .Input("mkl_features: uint8")
1938 .Output("activations: T")
1939 .Output("mkl_activations: uint8")
1940 .Attr("T: realnumbertype")
1941 .SetShapeFn(shape_inference::UnchangedShape)
1942 .Doc(R"doc(
1943 MKL version of Relu operator. Uses MKL DNN APIs to implement Relu operator.
1944
1945 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
1946 expected to invoke these operators.
1947 )doc");
1948
1949 REGISTER_OP("_MklReluGrad")
1950 .Input("gradients: T")
1951 .Input("features: T")
1952 .Input("mkl_gradients: uint8")
1953 .Input("mkl_features: uint8")
1954 .Output("backprops: T")
1955 .Output("mkl_backprops: uint8")
1956 .Attr("T: realnumbertype")
1957 .SetShapeFn(shape_inference::MergeBothInputsShapeFn)
1958 .Doc(R"doc(
1959 MKL version of ReluGrad operator. Uses MKL DNN APIs to compute rectified
1960 linear gradients for Relu operation.
1961
1962 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
1963 expected to invoke these operators.
1964 )doc");
1965
1966 REGISTER_OP("_MklRelu6")
1967 .Input("features: T")
1968 .Input("mkl_features: uint8")
1969 .Output("activations: T")
1970 .Output("mkl_activations: uint8")
1971 .Attr("T: realnumbertype")
1972 .SetShapeFn(shape_inference::UnchangedShape)
1973 .Doc(R"doc(
1974 MKL version of Relu6 operator. Uses MKL DNN APIs to implement Relu6 operator.
1975
1976 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
1977 expected to invoke these operators.
1978 )doc");
1979
1980 REGISTER_OP("_MklRelu6Grad")
1981 .Input("gradients: T")
1982 .Input("features: T")
1983 .Input("mkl_gradients: uint8")
1984 .Input("mkl_features: uint8")
1985 .Output("backprops: T")
1986 .Output("mkl_backprops: uint8")
1987 .Attr("T: realnumbertype")
1988 .SetShapeFn(shape_inference::MergeBothInputsShapeFn)
1989 .Doc(R"doc(
1990 MKL version of Relu6Grad operator. Uses MKL DNN APIs to compute rectified
1991 linear gradients for Relu6 operation.
1992
1993 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
1994 expected to invoke these operators.
1995 )doc");
1996
1997 REGISTER_OP("_MklLeakyRelu")
1998 .Input("features: T")
1999 .Input("mkl_features: uint8")
2000 .Output("activations: T")
2001 .Output("mkl_activations: uint8")
2002 .Attr("T: {half, float, double} = DT_FLOAT")
2003 .Attr("alpha: float = 0.2")
2004 .SetShapeFn(shape_inference::UnchangedShape)
2005 .Doc(R"doc(
2006 MKL version of LeakyRelu operator. Uses MKL DNN APIs to implement
2007 LeakyRelu operator.
2008
2009 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2010 expected to invoke these operators.
2011 )doc");
2012
2013 REGISTER_OP("_MklLeakyReluGrad")
2014 .Input("gradients: T")
2015 .Input("features: T")
2016 .Input("mkl_gradients: uint8")
2017 .Input("mkl_features: uint8")
2018 .Output("backprops: T")
2019 .Output("mkl_backprops: uint8")
2020 .Attr("T: {half, float, double} = DT_FLOAT")
2021 .Attr("alpha: float = 0.2")
2022 .SetShapeFn(shape_inference::MergeBothInputsShapeFn)
2023 .Doc(R"doc(
2024 MKL version of LeakyReluGrad operator. Uses MKL DNN APIs to compute rectified
2025 linear gradients for LeakyReluGrad operation.
2026
2027 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2028 expected to invoke these operators.
2029 )doc");
2030
2031 REGISTER_OP("_MklElu")
2032 .Input("features: T")
2033 .Input("mkl_features: uint8")
2034 .Output("activations: T")
2035 .Output("mkl_activations: uint8")
2036 .Attr("T: realnumbertype")
2037 .SetShapeFn(shape_inference::UnchangedShape)
2038 .Doc(R"doc(
2039 MKL version of Elu operator. Uses MKL DNN APIs to implement Elu operator.
2040 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2041 expected to invoke these operators.
2042 )doc");
2043
2044 REGISTER_OP("_MklEluGrad")
2045 .Input("gradients: T")
2046 .Input("features: T")
2047 .Input("mkl_gradients: uint8")
2048 .Input("mkl_features: uint8")
2049 .Output("backprops: T")
2050 .Output("mkl_backprops: uint8")
2051 .Attr("T: realnumbertype")
2052 .SetShapeFn(shape_inference::MergeBothInputsShapeFn)
2053 .Doc(R"doc(
2054 MKL version of EluGrad operator. Uses MKL DNN APIs to compute Elu
2055 gradients for Elu operation.
2056 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2057 expected to invoke these operators.
2058 )doc");
2059
2060 REGISTER_OP("_MklSoftmax")
2061 .Input("logits: T")
2062 .Input("mkl_logits: uint8")
2063 .Output("softmax: T")
2064 .Output("mkl_softmax: uint8")
2065 .Attr("T: {half, float, double}")
__anon3e672dd83702(InferenceContext* c) 2066 .SetShapeFn([](InferenceContext* c) {
2067 return shape_inference::UnchangedShapeWithRankAtLeast(c, 1);
2068 })
2069 .Doc(R"doc(
2070 MKL version of ReluGrad operator. Uses MKL DNN APIs to compute rectified
2071 linear gradients for Relu operation.
2072 )doc");
2073
2074 REGISTER_OP("_MklTanh")
2075 .Input("features: T")
2076 .Input("mkl_features: uint8")
2077 .Output("activations: T")
2078 .Output("mkl_activations: uint8")
2079 .Attr("T: realnumbertype")
2080 .SetShapeFn(shape_inference::UnchangedShape)
2081 .Doc(R"doc(
2082 MKL version of Tanh operator. Uses MKL DNN APIs to implement Tanh operator.
2083 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2084 expected to invoke these operators.
2085 )doc");
2086
2087 REGISTER_OP("_MklTanhGrad")
2088 .Input("gradients: T")
2089 .Input("features: T")
2090 .Input("mkl_gradients: uint8")
2091 .Input("mkl_features: uint8")
2092 .Output("backprops: T")
2093 .Output("mkl_backprops: uint8")
2094 .Attr("T: realnumbertype")
2095 .SetShapeFn(shape_inference::MergeBothInputsShapeFn)
2096 .Doc(R"doc(
2097 MKL version of TanhGrad operator. Uses MKL DNN APIs to compute tanh
2098 gradients for Tanh operation.
2099 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2100 expected to invoke these operators.
2101 )doc");
2102
2103 REGISTER_OP("_MklMaxPool")
2104 .Attr("T: {float, half} = DT_FLOAT")
2105 .Attr("ksize: list(int) >= 4")
2106 .Attr("strides: list(int) >= 4")
2107 .Attr(GetPaddingAttrString())
2108 .Attr(GetConvnetDataFormatAttrString())
2109 .Attr("workspace_enabled: bool = false")
2110 .Input("input: T")
2111 .Input("mkl_input: uint8")
2112 .Output("output: T")
2113 #ifdef INTEL_MKL_ML_ONLY
2114 .Output("workspace: T")
2115 #else
2116 .Output("workspace: uint8")
2117 #endif
2118 .Output("mkl_output: uint8")
2119 .Output("mkl_workspace: uint8")
2120 .SetShapeFn(shape_inference::MaxPoolShape)
2121 .Doc(R"doc(
2122 MKL version of MaxPool operator. Uses MKL DNN APIs to perform max pooling
2123 on the input.
2124
2125 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2126 expected to invoke these operators.
2127 )doc");
2128
2129 REGISTER_OP("_MklMaxPoolGrad")
2130 .Attr("T: {float, half} = DT_FLOAT")
2131 .Attr("ksize: list(int) >= 4")
2132 .Attr("strides: list(int) >= 4")
2133 .Attr("workspace_enabled: bool = false")
2134 .Attr(GetPaddingAttrString())
2135 .Attr(GetConvnetDataFormatAttrString())
2136 .Input("orig_input: T")
2137 .Input("orig_output: T")
2138 .Input("grad: T")
2139 #ifdef INTEL_MKL_ML_ONLY
2140 .Input("workspace: T")
2141 #else
2142 .Input("workspace: uint8")
2143 #endif
2144 .Input("mkl_orig_input: uint8")
2145 .Input("mkl_orig_output: uint8")
2146 .Input("mkl_grad: uint8")
2147 .Input("mkl_workspace: uint8")
2148 .Output("output: T")
2149 .Output("mkl_output: uint8")
__anon3e672dd83802(InferenceContext* c) 2150 .SetShapeFn([](InferenceContext* c) {
2151 return UnchangedShapeWithRank(c, 4);
2152 })
2153 .Doc(R"doc(
2154 MKL version of MaxPoolGrad. Uses MKL DNN APIs to compute gradients of
2155 MaxPool operator.
2156
2157 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2158 expected to invoke these operators.
2159 )doc");
2160
2161 REGISTER_OP("_MklAvgPool")
2162 .Input("value: T")
2163 .Input("mkl_input: uint8")
2164 .Output("output: T")
2165 .Output("mkl_output: uint8")
2166 .Attr("ksize: list(int) >= 4")
2167 .Attr("strides: list(int) >= 4")
2168 .Attr(GetPaddingAttrString())
2169 .Attr(GetConvnetDataFormatAttrString())
2170 .Attr("T: {float, half, double}")
2171 .SetShapeFn(shape_inference::AvgPoolShape)
2172 .Doc(R"doc(
2173 MKL version of AvgPool operator. Uses MKL DNN APIs to perform average pooling
2174 on the input.
2175
2176 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2177 expected to invoke these operators.
2178 )doc");
2179
2180 REGISTER_OP("_MklAvgPoolGrad")
2181 .Input("orig_input_shape: int32")
2182 .Input("grad: T")
2183 .Input("mkl_orig_input: uint8")
2184 .Input("mkl_grad: uint8")
2185 .Output("output: T")
2186 .Output("mkl_output: uint8")
2187 .Attr("ksize: list(int) >= 4")
2188 .Attr("strides: list(int) >= 4")
2189 .Attr(GetPaddingAttrString())
2190 .Attr(GetConvnetDataFormatAttrString())
2191 .Attr("T: {float, half, double}")
__anon3e672dd83902(InferenceContext* c) 2192 .SetShapeFn([](InferenceContext* c) {
2193 ShapeHandle s;
2194 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
2195 TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
2196 c->set_output(0, s);
2197 return Status::OK();
2198 })
2199 .Doc(R"doc(
2200 MKL version of AvgPoolGrad operator. Uses MKL DNN APIs to compute gradients
2201 of AvgPool function.
2202
2203 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2204 expected to invoke these operators.
2205 )doc");
2206
2207 REGISTER_OP("_MklAvgPool3D")
2208 .Input("value: T")
2209 .Input("mkl_input: uint8")
2210 .Output("output: T")
2211 .Output("mkl_output: uint8")
2212 .Attr("ksize: list(int) >= 5")
2213 .Attr("strides: list(int) >= 5")
2214 .Attr(GetPaddingAttrString())
2215 .Attr(GetConvnet3dDataFormatAttrString())
2216 .Attr("T: {float, half, double}")
2217 .SetShapeFn(shape_inference::Pool3DShape)
2218 .Doc(R"doc(
2219 MKL version of AvgPool3D operator. Uses MKL DNN APIs to perform average pooling
2220 on the input.
2221
2222 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2223 expected to invoke these operators.
2224 )doc");
2225
2226 REGISTER_OP("_MklAvgPool3DGrad")
2227 .Input("orig_input_shape: int32")
2228 .Input("grad: T")
2229 .Input("mkl_orig_input: uint8")
2230 .Input("mkl_grad: uint8")
2231 .Output("output: T")
2232 .Output("mkl_output: uint8")
2233 .Attr("ksize: list(int) >= 5")
2234 .Attr("strides: list(int) >= 5")
2235 .Attr(GetPaddingAttrString())
2236 .Attr(GetConvnet3dDataFormatAttrString())
2237 .Attr("T: {float, half, double}")
__anon3e672dd83a02(InferenceContext* c) 2238 .SetShapeFn([](InferenceContext* c) {
2239 ShapeHandle s;
2240 TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
2241 TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
2242 c->set_output(0, s);
2243 return Status::OK();
2244 })
2245 .Doc(R"doc(
2246 MKL version of AvgPool3DGrad operator. Uses MKL DNN APIs to compute gradients
2247 of AvgPool function.
2248
2249 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2250 expected to invoke these operators.
2251 )doc");
2252
2253 REGISTER_OP("_MklMaxPool3D")
2254 .Input("input: T")
2255 .Input("mkl_input: uint8")
2256 .Output("output: T")
2257 .Output("workspace: uint8")
2258 .Output("mkl_output: uint8")
2259 .Output("mkl_workspace: uint8")
2260 .Attr("ksize: list(int) >= 5")
2261 .Attr("strides: list(int) >= 5")
2262 .Attr(GetPaddingAttrString())
2263 .Attr(GetConvnet3dDataFormatAttrString())
2264 .Attr("T: {half, bfloat16, float}")
2265 .Attr("workspace_enabled: bool = false")
2266 .SetShapeFn(shape_inference::Pool3DShape)
2267 .Doc(R"doc(
2268 MKL version of MaxPool3D operator. Uses MKL DNN APIs to perform average pooling
2269 on the input.
2270
2271 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2272 expected to invoke these operators.
2273 )doc");
2274
2275 REGISTER_OP("_MklMaxPool3DGrad")
2276 .Input("orig_input: TInput")
2277 .Input("orig_output: TInput")
2278 .Input("grad: T")
2279 .Input("workspace: uint8")
2280 .Input("mkl_orig_input: uint8")
2281 .Input("mkl_orig_output: uint8")
2282 .Input("mkl_grad: uint8")
2283 .Input("mkl_workspace: uint8")
2284 .Output("output: T")
2285 .Output("mkl_output: uint8")
2286 .Attr("ksize: list(int) >= 5")
2287 .Attr("strides: list(int) >= 5")
2288 .Attr(GetPaddingAttrString())
2289 .Attr(GetConvnet3dDataFormatAttrString())
2290 .Attr("T: {half, bfloat16, float} = DT_FLOAT")
2291 .Attr("TInput: {half, bfloat16, float} = DT_FLOAT")
2292 .Attr("workspace_enabled: bool = false")
__anon3e672dd83b02(InferenceContext* c) 2293 .SetShapeFn([](InferenceContext* c) {
2294 return UnchangedShapeWithRank(c, 5);
2295 })
2296 .Doc(R"doc(
2297 MKL version of MklPool3DGrad operator. Uses MKL DNN APIs to compute gradients
2298 of MklPool function.
2299
2300 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2301 expected to invoke these operators.
2302 )doc");
2303
2304 REGISTER_OP("_MklLRN")
2305 .Input("input: T")
2306 .Input("mkl_input: uint8")
2307 .Output("output: T")
2308 .Output("workspace: uint8")
2309 .Output("mkl_output: uint8")
2310 .Output("mkl_workspace: uint8")
2311 .Attr("depth_radius: int = 5")
2312 .Attr("bias: float = 1.0")
2313 .Attr("alpha: float = 1.0")
2314 .Attr("beta: float = 0.5")
2315 .Attr("workspace_enabled: bool = false")
2316 .Attr("T: {float, half} = DT_FLOAT")
__anon3e672dd83c02(InferenceContext* c) 2317 .SetShapeFn([](InferenceContext* c) {
2318 return UnchangedShapeWithRank(c, 4);
2319 })
2320 .Doc(R"doc(
2321 MKL version of LRN operator. Uses MKL DNN APIs to perform local response
2322 normalization.
2323
2324 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2325 expected to invoke these operators.
2326 )doc");
2327
2328 REGISTER_OP("_MklLRNGrad")
2329 .Input("input_grads: T")
2330 .Input("input_image: T")
2331 .Input("output_image: T")
2332 .Input("workspace: uint8")
2333 .Input("mkl_input_grads: uint8")
2334 .Input("mkl_input_image: uint8")
2335 .Input("mkl_output_image: uint8")
2336 .Input("mkl_workspace: uint8")
2337 .Output("output: T")
2338 .Output("mkl_output: uint8")
2339 .Attr("depth_radius: int = 5")
2340 .Attr("bias: float = 1.0")
2341 .Attr("alpha: float = 1.0")
2342 .Attr("beta: float = 0.5")
2343 .Attr("workspace_enabled: bool = false")
2344 .Attr("T: {float, half} = DT_FLOAT")
__anon3e672dd83d02(InferenceContext* c) 2345 .SetShapeFn([](InferenceContext* c) {
2346 ShapeHandle s;
2347 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &s)); // input_grads
2348 TF_RETURN_IF_ERROR(c->Merge(s, c->input(1), &s)); // input_image
2349 TF_RETURN_IF_ERROR(c->Merge(s, c->input(2), &s)); // output_image
2350 c->set_output(0, s);
2351 return Status::OK();
2352 })
2353 .Doc(R"doc(
2354 MKL version of LRNGrad operator. Uses MKL DNN APIs to compute gradient for
2355 local response normalization.
2356
2357 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2358 expected to invoke these operators.
2359 )doc");
2360
2361 REGISTER_OP("_MklFusedBatchNorm")
2362 .Input("x: T")
2363 .Input("scale: T")
2364 .Input("offset: T")
2365 .Input("mean: T")
2366 .Input("variance: T")
2367 .Input("mkl_x: uint8")
2368 .Input("mkl_scale: uint8")
2369 .Input("mkl_offset: uint8")
2370 .Input("mkl_mean: uint8")
2371 .Input("mkl_variance: uint8")
2372 .Output("y: T")
2373 .Output("batch_mean: T")
2374 .Output("batch_variance: T")
2375 .Output("reserve_space_1: T")
2376 .Output("reserve_space_2: T")
2377 .Output("mkl_y: uint8")
2378 .Output("mkl_batch_mean: uint8")
2379 .Output("mkl_batch_variance: uint8")
2380 .Output("mkl_reserve_space_1: uint8")
2381 .Output("mkl_reserve_space_2: uint8")
2382 .Attr("T: numbertype")
2383 .Attr("epsilon: float = 0.0001")
2384 .Attr("data_format: string = 'NHWC'")
2385 .Attr("is_training: bool = true")
__anon3e672dd83e02(InferenceContext* c) 2386 .SetShapeFn([](InferenceContext* c) {
2387 ShapeHandle x;
2388 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &x));
2389
2390 bool is_training;
2391 TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
2392 int number_inputs = (is_training) ? 3 : 5;
2393 string data_format;
2394 TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format));
2395 DimensionHandle channel_dim =
2396 (data_format == "NHWC") ? c->Dim(x, 3) : c->Dim(x, 1);
2397
2398 // covers scale, offset, and if is_training is false, mean, variance
2399 for (int i = 1; i < number_inputs; ++i) {
2400 ShapeHandle vec;
2401 TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
2402 TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(vec, 0), &channel_dim));
2403 }
2404
2405 ShapeHandle y;
2406 if (data_format == "NHWC") {
2407 TF_RETURN_IF_ERROR(c->ReplaceDim(x, 3, channel_dim, &y));
2408 } else {
2409 TF_RETURN_IF_ERROR(c->ReplaceDim(x, 1, channel_dim, &y));
2410 }
2411 c->set_output(0, y);
2412 ShapeHandle vector_shape = c->Vector(channel_dim);
2413 c->set_output(1, vector_shape);
2414 c->set_output(2, vector_shape);
2415 c->set_output(3, vector_shape);
2416 c->set_output(4, vector_shape);
2417 return Status::OK();
2418 })
2419 .Doc(R"doc(
2420 MKL version of FusedBatchNorm operator. Uses MKL DNN APIs to perform fused
2421 batch normalization.
2422
2423 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2424 expected to invoke these operators.
2425 )doc");
2426
2427 REGISTER_OP("_MklFusedBatchNormGrad")
2428 .Input("y_backprop: T")
2429 .Input("x: T")
2430 .Input("scale: T")
2431 .Input("reserve_space_1: T")
2432 .Input("reserve_space_2: T")
2433 .Input("mkl_y_backprop: uint8")
2434 .Input("mkl_x: uint8")
2435 .Input("mkl_scale: uint8")
2436 .Input("mkl_reserve_space_1: uint8")
2437 .Input("mkl_reserve_space_2: uint8")
2438 .Output("x_backprop: T")
2439 .Output("scale_backprop: T")
2440 .Output("offset_backprop: T")
2441 .Output("reserve_space_3: T")
2442 .Output("reserve_space_4: T")
2443 .Output("mkl_x_backprop: uint8")
2444 .Output("mkl_scale_backprop: uint8")
2445 .Output("mkl_offset_backprop: uint8")
2446 .Output("mkl_reserve_space_3: uint8")
2447 .Output("mkl_reserve_space_4: uint8")
2448 .Attr("T: numbertype")
2449 .Attr("epsilon: float = 0.0001")
2450 .Attr("data_format: string = 'NHWC'")
2451 .Attr("is_training: bool = true")
__anon3e672dd83f02(InferenceContext* c) 2452 .SetShapeFn([](InferenceContext* c) {
2453 ShapeHandle y_backprop;
2454 TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &y_backprop));
2455 ShapeHandle x;
2456 TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &x));
2457
2458 bool is_training;
2459 string data_format;
2460 TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
2461 TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format));
2462 DimensionHandle channel_dim = (data_format == "NHWC")
2463 ? c->Dim(y_backprop, 3)
2464 : c->Dim(y_backprop, 1);
2465 if (data_format == "NHWC") {
2466 TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(x, 3), &channel_dim));
2467 } else {
2468 TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(x, 1), &channel_dim));
2469 }
2470
2471 // covers scale, mean (reserve_space_1), variance (reserve_space_2)
2472 for (int i = 2; i < 5; ++i) {
2473 ShapeHandle vec;
2474 TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
2475 TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(vec, 0), &channel_dim));
2476 }
2477
2478 ShapeHandle x_backprop;
2479 if (data_format == "NHWC") {
2480 TF_RETURN_IF_ERROR(
2481 c->ReplaceDim(y_backprop, 3, channel_dim, &x_backprop));
2482 } else {
2483 TF_RETURN_IF_ERROR(
2484 c->ReplaceDim(y_backprop, 1, channel_dim, &x_backprop));
2485 }
2486 c->set_output(0, x_backprop);
2487 c->set_output(1, c->Vector(channel_dim));
2488 c->set_output(2, c->Vector(channel_dim));
2489 // Set the correct shapes for reserve_spaces
2490 // so that gradients can be performed when
2491 // the op is in a symbolic condition.
2492 if (is_training) {
2493 c->set_output(3, c->Vector(0));
2494 c->set_output(4, c->Vector(0));
2495 } else {
2496 c->set_output(3, c->Vector(channel_dim));
2497 c->set_output(4, c->Vector(channel_dim));
2498 }
2499 return Status::OK();
2500 })
2501 .Doc(R"doc(
2502 MKL version of FusedBatchNormGrad operator. Uses MKL DNN APIs to compute
2503 gradients for fused batch normalization.
2504
2505 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2506 expected to invoke these operators.
2507 )doc");
2508
2509 REGISTER_OP("_MklToTf")
2510 .Input("input: T")
2511 .Input("mkl_input: uint8")
2512 .Output("output: T")
2513 .Attr("T: {half, float, double, qint8, quint8, qint32}")
2514 .Attr(GetConvnetDataFormat2D3DAttrString())
2515 .SetShapeFn(shape_inference::UnknownShape)
2516 .Doc(R"doc(
2517 MKL operator to convert a tensor from MKL layout to TensorFlow layout.
2518
2519 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2520 expected to invoke these operators.
2521 )doc");
2522
2523 REGISTER_OP("_MklInputConversion")
2524 .Input("input_0: T")
2525 .Input("input_1: T")
2526 .Input("mkl_input_0: uint8")
2527 .Input("mkl_input_1: uint8")
2528 .Output("output_0: T")
2529 .Output("output_1: T")
2530 .Output("mkl_output_0: uint8")
2531 .Output("mkl_output_1: uint8")
2532 // All datatypes supported by element-wise ops
2533 .Attr(
2534 "T: {half, float, double, uint8, int8, uint16, int16, int32, int64, "
2535 "complex64, complex128}")
2536 .Attr(GetConvnetDataFormat2D3DAttrString())
2537 .SetShapeFn(shape_inference::UnknownShape)
2538 .Doc(R"doc(
2539 MKL operator to process the inputs to an elementwise MKL op. Both inputs
2540 need to be either in TF or in MKL format. This op is added before every
2541 element-wise MKL op.
2542
2543 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2544 expected to invoke these operators.
2545 )doc");
2546
2547 #endif // INTEL_MKL
2548 REGISTER_OP("QuantizedConv2DAndRequantize")
2549 .Input("input: Tinput")
2550 .Input("filter: Tfilter")
2551 .Input("min_input: float")
2552 .Input("max_input: float")
2553 .Input("min_filter: float")
2554 .Input("max_filter: float")
2555 .Input("min_freezed_output: float")
2556 .Input("max_freezed_output: float")
2557 .Output("output: out_type")
2558 .Output("min_output: float")
2559 .Output("max_output: float")
2560 .Attr("Tinput: quantizedtype")
2561 .Attr("Tfilter: quantizedtype")
2562 .Attr("out_type: quantizedtype = DT_QINT8")
2563 .Attr("strides: list(int)")
2564 .Attr(GetPaddingAttrString())
2565 .Attr("dilations: list(int) = [1, 1, 1, 1]")
2566 .Attr("padding_list: list(int) = []")
__anon3e672dd84002(InferenceContext* c) 2567 .SetShapeFn([](InferenceContext* c) {
2568 TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
2569 ShapeHandle unused;
2570 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
2571 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
2572 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
2573 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
2574 TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
2575 TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
2576 c->set_output(1, c->Scalar());
2577 c->set_output(2, c->Scalar());
2578 return Status::OK();
2579 });
2580
2581 // Fusion of Quantized Conv2D and BiasAdd.
2582 REGISTER_OP("QuantizedConv2DWithBias")
2583 .Input("input: Tinput")
2584 .Input("filter: Tfilter")
2585 .Input("bias: float")
2586 .Input("min_input: float")
2587 .Input("max_input: float")
2588 .Input("min_filter: float")
2589 .Input("max_filter: float")
2590 .Output("output: out_type")
2591 .Output("min_output: float")
2592 .Output("max_output: float")
2593 .Attr("Tinput: quantizedtype")
2594 .Attr("Tfilter: quantizedtype")
2595 .Attr("out_type: quantizedtype = DT_QINT32")
2596 .Attr("strides: list(int)")
2597 .Attr(GetPaddingAttrString())
2598 .Attr("dilations: list(int) = [1, 1, 1, 1]")
2599 .Attr("padding_list: list(int) = []")
__anon3e672dd84102(InferenceContext* c) 2600 .SetShapeFn([](InferenceContext* c) {
2601 TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
2602 ShapeHandle unused;
2603 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
2604 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
2605 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
2606 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
2607 TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
2608 c->set_output(1, c->Scalar());
2609 c->set_output(2, c->Scalar());
2610 return Status::OK();
2611 });
2612
2613 REGISTER_OP("QuantizedConv2DWithBiasAndRequantize")
2614 .Input("input: Tinput")
2615 .Input("filter: Tfilter")
2616 .Input("bias: Tbias")
2617 .Input("min_input: float")
2618 .Input("max_input: float")
2619 .Input("min_filter: float")
2620 .Input("max_filter: float")
2621 .Input("min_freezed_output: float")
2622 .Input("max_freezed_output: float")
2623 .Output("output: out_type")
2624 .Output("min_output: float")
2625 .Output("max_output: float")
2626 .Attr("Tinput: quantizedtype")
2627 .Attr("Tfilter: quantizedtype")
2628 .Attr("Tbias: {float, qint32}")
2629 .Attr("out_type: quantizedtype = DT_QINT8")
2630 .Attr("strides: list(int)")
2631 .Attr(GetPaddingAttrString())
2632 .Attr("dilations: list(int) = [1, 1, 1, 1]")
2633 .Attr("padding_list: list(int) = []")
__anon3e672dd84202(InferenceContext* c) 2634 .SetShapeFn([](InferenceContext* c) {
2635 TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
2636 ShapeHandle unused;
2637 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
2638 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
2639 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
2640 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
2641 TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
2642 TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
2643 TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused));
2644 c->set_output(1, c->Scalar());
2645 c->set_output(2, c->Scalar());
2646 return Status::OK();
2647 });
2648
2649 // Fusion of Quantized Conv2D and Relu.
2650 REGISTER_OP("QuantizedConv2DAndRelu")
2651 .Input("input: Tinput")
2652 .Input("filter: Tfilter")
2653 .Input("min_input: float")
2654 .Input("max_input: float")
2655 .Input("min_filter: float")
2656 .Input("max_filter: float")
2657 .Output("output: out_type")
2658 .Output("min_output: float")
2659 .Output("max_output: float")
2660 .Attr("Tinput: quantizedtype")
2661 .Attr("Tfilter: quantizedtype")
2662 .Attr("out_type: quantizedtype = DT_QINT32")
2663 .Attr("strides: list(int)")
2664 .Attr(GetPaddingAttrString())
2665 .Attr("dilations: list(int) = [1, 1, 1, 1]")
2666 .Attr("padding_list: list(int) = []")
__anon3e672dd84302(InferenceContext* c) 2667 .SetShapeFn([](InferenceContext* c) {
2668 TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
2669 ShapeHandle unused;
2670 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
2671 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
2672 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
2673 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
2674 c->set_output(1, c->Scalar());
2675 c->set_output(2, c->Scalar());
2676 return Status::OK();
2677 });
2678
2679 REGISTER_OP("QuantizedConv2DAndReluAndRequantize")
2680 .Input("input: Tinput")
2681 .Input("filter: Tfilter")
2682 .Input("min_input: float")
2683 .Input("max_input: float")
2684 .Input("min_filter: float")
2685 .Input("max_filter: float")
2686 .Input("min_freezed_output: float")
2687 .Input("max_freezed_output: float")
2688 .Output("output: out_type")
2689 .Output("min_output: float")
2690 .Output("max_output: float")
2691 .Attr("Tinput: quantizedtype")
2692 .Attr("Tfilter: quantizedtype")
2693 .Attr("out_type: quantizedtype = DT_QUINT8")
2694 .Attr("strides: list(int)")
2695 .Attr(GetPaddingAttrString())
2696 .Attr("dilations: list(int) = [1, 1, 1, 1]")
2697 .Attr("padding_list: list(int) = []")
__anon3e672dd84402(InferenceContext* c) 2698 .SetShapeFn([](InferenceContext* c) {
2699 TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
2700 ShapeHandle unused;
2701 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
2702 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
2703 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
2704 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
2705 TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
2706 TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
2707 c->set_output(1, c->Scalar());
2708 c->set_output(2, c->Scalar());
2709 return Status::OK();
2710 });
2711
2712 // Fusion of Quantized Conv2D, BiasAdd and Relu.
2713 REGISTER_OP("QuantizedConv2DWithBiasAndRelu")
2714 .Input("input: Tinput")
2715 .Input("filter: Tfilter")
2716 .Input("bias: float")
2717 .Input("min_input: float")
2718 .Input("max_input: float")
2719 .Input("min_filter: float")
2720 .Input("max_filter: float")
2721 .Output("output: out_type")
2722 .Output("min_output: float")
2723 .Output("max_output: float")
2724 .Attr("Tinput: quantizedtype")
2725 .Attr("Tfilter: quantizedtype")
2726 .Attr("out_type: quantizedtype = DT_QINT32")
2727 .Attr("strides: list(int)")
2728 .Attr(GetPaddingAttrString())
2729 .Attr("dilations: list(int) = [1, 1, 1, 1]")
2730 .Attr("padding_list: list(int) = []")
__anon3e672dd84502(InferenceContext* c) 2731 .SetShapeFn([](InferenceContext* c) {
2732 TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
2733 ShapeHandle unused;
2734 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
2735 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
2736 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
2737 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
2738 TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
2739 c->set_output(1, c->Scalar());
2740 c->set_output(2, c->Scalar());
2741 return Status::OK();
2742 });
2743
2744 // Fusion of Quantized Conv2D, BiasAdd, Relu, and Requantize.
2745 REGISTER_OP("QuantizedConv2DWithBiasAndReluAndRequantize")
2746 .Input("input: Tinput")
2747 .Input("filter: Tfilter")
2748 .Input("bias: Tbias")
2749 .Input("min_input: float")
2750 .Input("max_input: float")
2751 .Input("min_filter: float")
2752 .Input("max_filter: float")
2753 .Input("min_freezed_output: float")
2754 .Input("max_freezed_output: float")
2755 .Output("output: out_type")
2756 .Output("min_output: float")
2757 .Output("max_output: float")
2758 .Attr("Tinput: quantizedtype")
2759 .Attr("Tfilter: quantizedtype")
2760 .Attr("Tbias: {float, qint32}")
2761 .Attr("out_type: quantizedtype = DT_QUINT8")
2762 .Attr("strides: list(int)")
2763 .Attr(GetPaddingAttrString())
2764 .Attr("dilations: list(int) = [1, 1, 1, 1]")
2765 .Attr("padding_list: list(int) = []")
__anon3e672dd84602(InferenceContext* c) 2766 .SetShapeFn([](InferenceContext* c) {
2767 TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
2768 ShapeHandle unused;
2769 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
2770 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
2771 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
2772 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
2773 TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
2774 TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
2775 TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused));
2776 c->set_output(1, c->Scalar());
2777 c->set_output(2, c->Scalar());
2778 return Status::OK();
2779 });
2780
2781 // Fusion of Quantized Conv2D, BiasAdd, Sum, and Relu.
2782 REGISTER_OP("QuantizedConv2DWithBiasSumAndRelu")
2783 .Input("input: Tinput")
2784 .Input("filter: Tfilter")
2785 .Input("bias: float")
2786 .Input("min_input: float")
2787 .Input("max_input: float")
2788 .Input("min_filter: float")
2789 .Input("max_filter: float")
2790 .Input("summand: float")
2791 .Output("output: out_type")
2792 .Output("min_output: float")
2793 .Output("max_output: float")
2794 .Attr("Tinput: quantizedtype")
2795 .Attr("Tfilter: quantizedtype")
2796 .Attr("out_type: quantizedtype = DT_QINT32")
2797 .Attr("strides: list(int)")
2798 .Attr(GetPaddingAttrString())
2799 .Attr("dilations: list(int) = [1, 1, 1, 1]")
2800 .Attr("padding_list: list(int) = []")
__anon3e672dd84702(InferenceContext* c) 2801 .SetShapeFn([](InferenceContext* c) {
2802 TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
2803 ShapeHandle unused;
2804 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
2805 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
2806 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
2807 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
2808 TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
2809 c->set_output(1, c->Scalar());
2810 c->set_output(2, c->Scalar());
2811 return Status::OK();
2812 });
2813
2814 REGISTER_OP("QuantizedConv2DWithBiasSumAndReluAndRequantize")
2815 .Input("input: Tinput")
2816 .Input("filter: Tfilter")
2817 .Input("bias: Tbias")
2818 .Input("min_input: float")
2819 .Input("max_input: float")
2820 .Input("min_filter: float")
2821 .Input("max_filter: float")
2822 .Input("min_freezed_output: float")
2823 .Input("max_freezed_output: float")
2824 .Input("summand: Tsummand")
2825 .Input("min_summand: float")
2826 .Input("max_summand: float")
2827 .Output("output: out_type")
2828 .Output("min_output: float")
2829 .Output("max_output: float")
2830 .Attr("Tinput: quantizedtype")
2831 .Attr("Tfilter: quantizedtype")
2832 .Attr("Tbias: {float, qint32}")
2833 .Attr("Tsummand: quantizedtype")
2834 .Attr("out_type: quantizedtype = DT_QUINT8")
2835 .Attr("strides: list(int)")
2836 .Attr(GetPaddingAttrString())
2837 .Attr("dilations: list(int) = [1, 1, 1, 1]")
2838 .Attr("padding_list: list(int) = []")
__anon3e672dd84802(InferenceContext* c) 2839 .SetShapeFn([](InferenceContext* c) {
2840 TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
2841 ShapeHandle unused;
2842 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
2843 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
2844 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
2845 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
2846 TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
2847 TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
2848 TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused));
2849 c->set_output(1, c->Scalar());
2850 c->set_output(2, c->Scalar());
2851 return Status::OK();
2852 });
2853
2854 REGISTER_OP("QuantizedConv2DWithBiasSignedSumAndReluAndRequantize")
2855 .Input("input: Tinput")
2856 .Input("filter: Tfilter")
2857 .Input("bias: Tbias")
2858 .Input("min_input: float")
2859 .Input("max_input: float")
2860 .Input("min_filter: float")
2861 .Input("max_filter: float")
2862 .Input("min_freezed_output: float")
2863 .Input("max_freezed_output: float")
2864 .Input("summand: Tsummand")
2865 .Input("min_summand: float")
2866 .Input("max_summand: float")
2867 .Output("output: out_type")
2868 .Output("min_output: float")
2869 .Output("max_output: float")
2870 .Attr("Tinput: quantizedtype")
2871 .Attr("Tfilter: quantizedtype")
2872 .Attr("Tbias: {float, qint32}")
2873 .Attr("Tsummand: quantizedtype")
2874 .Attr("out_type: quantizedtype = DT_QUINT8")
2875 .Attr("strides: list(int)")
2876 .Attr(GetPaddingAttrString())
2877 .Attr("dilations: list(int) = [1, 1, 1, 1]")
2878 .Attr("padding_list: list(int) = []")
__anon3e672dd84902(InferenceContext* c) 2879 .SetShapeFn([](InferenceContext* c) {
2880 TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
2881 ShapeHandle unused;
2882 TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
2883 TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
2884 TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
2885 TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
2886 TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
2887 TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
2888 TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused));
2889 c->set_output(1, c->Scalar());
2890 c->set_output(2, c->Scalar());
2891 return Status::OK();
2892 });
2893
2894 } // namespace tensorflow
2895