1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/framework/common_shape_fns.h"
17 #include "tensorflow/core/framework/numeric_op.h"
18 #include "tensorflow/core/framework/op.h"
19 #include "tensorflow/core/framework/shape_inference.h"
20 #include "tensorflow/core/util/mirror_pad_mode.h"
21 #include "tensorflow/core/util/padding.h"
22 #include "tensorflow/core/util/tensor_format.h"
23 
24 namespace tensorflow {
25 
26 using shape_inference::DimensionHandle;
27 using shape_inference::InferenceContext;
28 using shape_inference::ShapeHandle;
29 
30 namespace {
31 
32 Status FractionalPoolShapeFn(InferenceContext* c) {
33   ShapeHandle input;
34   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
35 
36   std::vector<float> pooling_ratio;
37   TF_RETURN_IF_ERROR(c->GetAttr("pooling_ratio", &pooling_ratio));
38   if (pooling_ratio.size() != 4) {
39     return errors::InvalidArgument(
40         "pooling_ratio field must specify 4 dimensions");
41   }
42   std::vector<DimensionHandle> output_dims;
43   for (int i = 0; i < 4; ++i) {
44     DimensionHandle d = c->Dim(input, i);
45     if (c->ValueKnown(d)) {
46       // This must match the same logic in the kernel function in
47       // core/kernels/fractional_max_pool_op.cc.
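      // Illustrative: with c->Value(d) == 32 and pooling_ratio[i] == 1.44,
      // val = floor(32 / 1.44) = 22.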
48       auto val = static_cast<int64>(floor(c->Value(d) / pooling_ratio[i]));
49       if (val < 0) {
50         return errors::InvalidArgument("Size computed for dim ", i,
51                                        " is negative: ", val);
52       }
53       output_dims.push_back(c->MakeDim(val));
54     } else {
55       output_dims.push_back(c->UnknownDim());
56     }
57   }
58 
59   c->set_output(0, c->MakeShape(output_dims));
60   c->set_output(1, c->Vector(output_dims[1]));
61   c->set_output(2, c->Vector(output_dims[2]));
62   return Status::OK();
63 }
64 
65 }  // namespace
66 
67 // --------------------------------------------------------------------------
68 
69 REGISTER_OP("AvgPool")
70     .Input("value: T")
71     .Output("output: T")
72     .Attr("ksize: list(int) >= 4")
73     .Attr("strides: list(int) >= 4")
74     .Attr(GetPaddingAttrString())
75     .Attr(GetConvnetDataFormatAttrString())
76     .Attr("T: {half, bfloat16, float, double}")
77     .SetShapeFn(shape_inference::AvgPoolShape);
78 
79 REGISTER_OP("AvgPoolGrad")
80     .Input("orig_input_shape: int32")
81     .Input("grad: T")
82     .Output("output: T")
83     .Attr("ksize: list(int) >= 4")
84     .Attr("strides: list(int) >= 4")
85     .Attr(GetPaddingAttrString())
86     .Attr(GetConvnetDataFormatAttrString())
87     .Attr("T: {half, bfloat16, float, double}")
88     .SetShapeFn([](InferenceContext* c) {
89       ShapeHandle s;
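      // The output gradient takes the 4-D shape held in the orig_input_shape
      // tensor (input 0).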
90       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
91       TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
92       c->set_output(0, s);
93       return Status::OK();
94     });
95 
96 // --------------------------------------------------------------------------
97 
98 REGISTER_OP("BatchNormWithGlobalNormalization")
99     .Input("t: T")
100     .Input("m: T")
101     .Input("v: T")
102     .Input("beta: T")
103     .Input("gamma: T")
104     .Output("result: T")
105     .Attr("T: numbertype")
106     .Attr("variance_epsilon: float")
107     .Attr("scale_after_normalization: bool")
108     .Deprecated(9, "Use tf.nn.batch_normalization()")
109     .SetShapeFn([](InferenceContext* c) {
110       ShapeHandle input;
111       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
112 
113       DimensionHandle last_dim = c->Dim(input, 3);
114       for (int i = 1; i < 5; ++i) {  // covers m, v, beta, gamma
115         ShapeHandle vec;
116         TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
117         TF_RETURN_IF_ERROR(c->Merge(last_dim, c->Dim(vec, 0), &last_dim));
118       }
119 
120       ShapeHandle out;
121       TF_RETURN_IF_ERROR(c->ReplaceDim(input, 3, last_dim, &out));
122       c->set_output(0, out);
123       return Status::OK();
124     });
125 
126 REGISTER_OP("BatchNormWithGlobalNormalizationGrad")
127     .Input("t: T")
128     .Input("m: T")
129     .Input("v: T")
130     .Input("gamma: T")
131     .Input("backprop: T")
132     .Output("dx: T")
133     .Output("dm: T")
134     .Output("dv: T")
135     .Output("db: T")
136     .Output("dg: T")
137     .Attr("T: numbertype")
138     .Attr("variance_epsilon: float")
139     .Attr("scale_after_normalization: bool")
140     .Deprecated(9, "Use tf.nn.batch_normalization()")
141     .SetShapeFn([](InferenceContext* c) {
142       ShapeHandle input;
143       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
144       TF_RETURN_IF_ERROR(
145           c->Merge(input, c->input(4), &input));  // with backprop
146 
147       DimensionHandle last_dim = c->Dim(input, 3);
148       for (int i = 1; i < 4; ++i) {  // covers m, v, gamma
149         ShapeHandle vec;
150         TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
151         TF_RETURN_IF_ERROR(c->Merge(last_dim, c->Dim(vec, 0), &last_dim));
152       }
153 
154       ShapeHandle dx;
155       TF_RETURN_IF_ERROR(c->ReplaceDim(input, 3, last_dim, &dx));
156       c->set_output(0, dx);
157 
158       ShapeHandle vector_shape = c->Vector(last_dim);
159       c->set_output(1, vector_shape);
160       c->set_output(2, vector_shape);
161       c->set_output(3, vector_shape);
162       c->set_output(4, vector_shape);
163       return Status::OK();
164     });
165 
166 // --------------------------------------------------------------------------
167 
168 REGISTER_OP("FusedBatchNorm")
169     .Input("x: T")
170     .Input("scale: T")
171     .Input("offset: T")
172     .Input("mean: T")
173     .Input("variance: T")
174     .Output("y: T")
175     .Output("batch_mean: T")
176     .Output("batch_variance: T")
177     .Output("reserve_space_1: T")
178     .Output("reserve_space_2: T")
179     .Attr("T: {float}")
180     .Attr("epsilon: float = 0.0001")
181     .Attr(GetConvnetDataFormatAttrString())
182     .Attr("is_training: bool = true")
183     .SetShapeFn(shape_inference::FusedBatchNormShape);
184 
185 REGISTER_OP("FusedBatchNormV2")
186     .Input("x: T")
187     .Input("scale: U")
188     .Input("offset: U")
189     .Input("mean: U")
190     .Input("variance: U")
191     .Output("y: T")
192     .Output("batch_mean: U")
193     .Output("batch_variance: U")
194     .Output("reserve_space_1: U")
195     .Output("reserve_space_2: U")
196     .Attr("T: {half, bfloat16, float}")
197     .Attr("U: {float}")
198     .Attr("epsilon: float = 0.0001")
199     .Attr(GetConvnetDataFormatAttrString())
200     .Attr("is_training: bool = true")
201     .SetShapeFn(shape_inference::FusedBatchNormShape);
202 
203 REGISTER_OP("FusedBatchNormGrad")
204     .Input("y_backprop: T")
205     .Input("x: T")
206     .Input("scale: T")
207     .Input("reserve_space_1: T")
208     .Input("reserve_space_2: T")
209     .Output("x_backprop: T")
210     .Output("scale_backprop: T")
211     .Output("offset_backprop: T")
212     .Output("reserve_space_3: T")
213     .Output("reserve_space_4: T")
214     .Attr("T: {float}")
215     .Attr("epsilon: float = 0.0001")
216     .Attr(GetConvnetDataFormatAttrString())
217     .Attr("is_training: bool = true")
218     .SetShapeFn(shape_inference::FusedBatchNormGradShape);
219 
220 REGISTER_OP("FusedBatchNormGradV2")
221     .Input("y_backprop: T")
222     .Input("x: T")
223     .Input("scale: float")
224     .Input("reserve_space_1: U")
225     .Input("reserve_space_2: U")
226     .Output("x_backprop: T")
227     .Output("scale_backprop: U")
228     .Output("offset_backprop: U")
229     .Output("reserve_space_3: U")
230     .Output("reserve_space_4: U")
231     .Attr("T: {half, bfloat16, float}")
232     .Attr("U: {float}")
233     .Attr("epsilon: float = 0.0001")
234     .Attr(GetConvnetDataFormatAttrString())
235     .Attr("is_training: bool = true")
236     .SetShapeFn(shape_inference::FusedBatchNormGradShape);
237 
238 // --------------------------------------------------------------------------
239 
240 REGISTER_OP("BiasAdd")
241     .Attr("T: numbertype")
242     .Input("value: T")
243     .Input("bias: T")
244     .Attr(GetConvnetDataFormatAttrString())
245     .Output("output: T")
246     .SetShapeFn(shape_inference::BiasAddShape);
247 // --------------------------------------------------------------------------
248 
249 REGISTER_OP("BiasAddGrad")
250     .Attr("T: numbertype")
251     .Input("out_backprop: T")
252     .Attr(GetConvnetDataFormatAttrString())
253     .Output("output: T")
254     .SetShapeFn(shape_inference::BiasAddGradShape);
255 // --------------------------------------------------------------------------
256 
257 REGISTER_OP("BiasAddV1")
258     .Attr("T: numbertype")
259     .Input("value: T")
260     .Input("bias: T")
261     .Output("output: T")
262     .SetShapeFn(shape_inference::BiasAddShape);
263 // --------------------------------------------------------------------------
264 
265 REGISTER_OP("Conv2D")
266     .Input("input: T")
267     .Input("filter: T")
268     .Output("output: T")
269     .Attr("T: {half, bfloat16, float, double}")
270     .Attr("strides: list(int)")
271     .Attr("use_cudnn_on_gpu: bool = true")
272     .Attr(GetPaddingAttrStringWithExplicit())
273     .Attr(GetExplicitPaddingsAttrString())
274     .Attr(GetConvnetDataFormatAttrString())
275     .Attr("dilations: list(int) = [1, 1, 1, 1]")
276     .SetShapeFn(shape_inference::Conv2DShapeWithExplicitPadding);
277 
278 REGISTER_OP("Conv2DBackpropInput")
279     .Input("input_sizes: int32")
280     .Input("filter: T")
281     .Input("out_backprop: T")
282     .Output("output: T")
283     .Attr("T: {half, bfloat16, float, double}")
284     .Attr("strides: list(int)")
285     .Attr("use_cudnn_on_gpu: bool = true")
286     .Attr(GetPaddingAttrStringWithExplicit())
287     .Attr(GetExplicitPaddingsAttrString())
288     .Attr(GetConvnetDataFormatAttrString())
289     .Attr("dilations: list(int) = [1, 1, 1, 1]")
290     .SetShapeFn([](InferenceContext* c) {
291       ShapeHandle s;
292       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
293       TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
294       c->set_output(0, s);
295       return Status::OK();
296     });
297 
298 // TODO(jeff): Instead of 'use_cudnn_on_gpu', maybe we should have a
299 // more general string attribute ('kernel_impl'?) that can be used to
300 // select among several possible implementations.
301 REGISTER_OP("Conv2DBackpropFilter")
302     .Input("input: T")
303     .Input("filter_sizes: int32")
304     .Input("out_backprop: T")
305     .Output("output: T")
306     .Attr("T: {half, bfloat16, float, double}")
307     .Attr("strides: list(int)")
308     .Attr("use_cudnn_on_gpu: bool = true")
309     .Attr(GetPaddingAttrStringWithExplicit())
310     .Attr(GetExplicitPaddingsAttrString())
311     .Attr(GetConvnetDataFormatAttrString())
312     .Attr("dilations: list(int) = [1, 1, 1, 1]")
313     .SetShapeFn([](InferenceContext* c) {
314       ShapeHandle s;
315       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
316       TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
317       c->set_output(0, s);
318       return Status::OK();
319     });
320 
321 REGISTER_OP("_FusedConv2D")
322     .Input("input: T")
323     .Input("filter: T")
324     .Input("args: num_args * T")
325     .Output("output: T")
326     .Attr("T: {float, double}")
327     .Attr("num_args: int >= 0")
328     .Attr("strides: list(int)")
329     .Attr(GetPaddingAttrString())
330     .Attr(GetConvnetDataFormatAttrString())
331     .Attr("dilations: list(int) = [1, 1, 1, 1]")
332     .Attr("use_cudnn_on_gpu: bool = true")
333     .Attr("fused_ops: list(string) = []")
334     // Attributes for the FusedBatchNorm ------------------------------------ //
335     .Attr("epsilon: float = 0.0001")
336     // ---------------------------------------------------------------------- //
337     .SetShapeFn(shape_inference::Conv2DShape)
338     .Doc(R"doc(
339 *NOTE*: Do not invoke this operator directly in Python. Grappler is
340 expected to create these operators.
341 )doc");
342 
343 namespace {
344 
345 Status CommonFusedConvCalculations(InferenceContext* c, bool has_resize) {
346   ShapeHandle input;
347   TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
348 
349   ShapeHandle resized = input;
350   int paddings_index = 1;
351   int filter_index = 2;
352   if (has_resize) {
353     paddings_index = 2;
354     filter_index = 3;
355 
356     ShapeHandle unused_size;
357     TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->Vector(2), &unused_size));
358 
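    // 'size' is the resize target: a length-2 vector of [new_height, new_width].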
359     const Tensor* size = c->input_tensor(1);
360     DimensionHandle new_height = c->UnknownDim();
361     DimensionHandle new_width = c->UnknownDim();
362     if (size != nullptr) {
363       new_height = c->MakeDim(size->flat<int32>()(0));
364       new_width = c->MakeDim(size->flat<int32>()(1));
365     }
366     TF_RETURN_IF_ERROR(c->ReplaceDim(resized, 1, new_height, &resized));
367     TF_RETURN_IF_ERROR(c->ReplaceDim(resized, 2, new_width, &resized));
368   }
369 
370   ShapeHandle paddings;
371   TF_RETURN_IF_ERROR(c->WithRank(c->input(paddings_index), 2, &paddings));
372   TF_RETURN_IF_ERROR(
373       c->WithRank(resized, c->Value(c->Dim(paddings, 0)), &resized));
374   TF_RETURN_IF_ERROR(
375       c->Merge(paddings, c->Matrix(c->Rank(resized), 2), &paddings));
376 
377   const Tensor* paddings_t = c->input_tensor(paddings_index);
378   ShapeHandle padded;
379   if (paddings_t != nullptr) {
380     std::vector<DimensionHandle> output_dims;
381     for (int i = 0; i < 4; ++i) {
382       DimensionHandle dim = c->Dim(resized, i);
383       int64 p0 = static_cast<int64>(paddings_t->matrix<int32>()(i, 0));
384       int64 p1 = static_cast<int64>(paddings_t->matrix<int32>()(i, 1));
385       if (p0 < 0 || p1 < 0) {
386         return errors::InvalidArgument("Paddings must be non-negative");
387       }
388 
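      // Illustrative: a dim of 10 with paddings (2, 3) becomes 10 + 2 + 3 = 15.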
389       TF_RETURN_IF_ERROR(c->Add(dim, p0 + p1, &dim));
390       output_dims.push_back(dim);
391     }
392     padded = c->MakeShape(output_dims);
393   } else {
394     padded = c->UnknownShapeOfRank(4);
395   }
396 
397   // Work out the convolution's effect with 'padded' as the input.
398   ShapeHandle filter;
399   TF_RETURN_IF_ERROR(c->WithRank(c->input(filter_index), 4, &filter));
400   std::vector<int32> strides;
401   TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
402   if (strides.size() != 4) {
403     return errors::InvalidArgument(
404         "Operation requires the stride attribute to contain 4 values, but ",
405         "got: ", strides.size());
406   }
407 
408   int32 stride_rows = strides[1];
409   int32 stride_cols = strides[2];
410 
411   DimensionHandle batch_size_dim = c->Dim(padded, 0);
412   DimensionHandle in_rows_dim = c->Dim(padded, 1);
413   DimensionHandle in_cols_dim = c->Dim(padded, 2);
414   DimensionHandle filter_rows_dim = c->Dim(filter, 0);
415   DimensionHandle filter_cols_dim = c->Dim(filter, 1);
416   DimensionHandle output_depth_dim = c->Dim(filter, 3);
417 
418   DimensionHandle unused;
419   TF_RETURN_IF_ERROR(c->Merge(c->Dim(padded, 3), c->Dim(filter, 2), &unused));
420 
421   Padding padding;
422   TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
423 
424   DimensionHandle output_rows, output_cols;
425   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
426       c, in_rows_dim, filter_rows_dim, stride_rows, padding, &output_rows));
427   TF_RETURN_IF_ERROR(GetWindowedOutputSizeFromDims(
428       c, in_cols_dim, filter_cols_dim, stride_cols, padding, &output_cols));
429 
430   ShapeHandle output_shape = c->MakeShape(
431       {batch_size_dim, output_rows, output_cols, output_depth_dim});
432   c->set_output(0, output_shape);
433   return Status::OK();
434 }
435 
436 }  // namespace
437 
438 REGISTER_OP("DataFormatDimMap")
439     .Input("x: T")
440     .Output("y: T")
441     .Attr("T: {int32, int64} = DT_INT32")
442     .Attr("src_format: string = 'NHWC'")
443     .Attr("dst_format: string = 'NCHW'")
444     .SetShapeFn(shape_inference::UnchangedShape);
445 
446 REGISTER_OP("DataFormatVecPermute")
447     .Input("x: T")
448     .Output("y: T")
449     .Attr("T: {int32, int64} = DT_INT32")
450     .Attr("src_format: string = 'NHWC'")
451     .Attr("dst_format: string = 'NCHW'")
452     .SetShapeFn(shape_inference::UnchangedShape);
453 
454 REGISTER_OP("FusedResizeAndPadConv2D")
455     .Input("input: T")
456     .Input("size: int32")
457     .Input("paddings: int32")
458     .Input("filter: T")
459     .Output("output: T")
460     .Attr("T: {half, float, double}")
461     .Attr("resize_align_corners: bool = false")
462     .Attr(GetMirrorPadModeAttrString())
463     .Attr("strides: list(int)")
464     .Attr(GetPaddingAttrString())
465     .SetShapeFn([](InferenceContext* c) {
466       return CommonFusedConvCalculations(c, true /* has_resize */);
467     });
468 
469 REGISTER_OP("FusedPadConv2D")
470     .Input("input: T")
471     .Input("paddings: int32")
472     .Input("filter: T")
473     .Output("output: T")
474     .Attr("T: {half, float, double}")
475     .Attr(GetMirrorPadModeAttrString())
476     .Attr("strides: list(int)")
477     .Attr(GetPaddingAttrString())
478     .SetShapeFn([](InferenceContext* c) {
479       return CommonFusedConvCalculations(c, false /* has_resize */);
480     });
481 
482 // --------------------------------------------------------------------------
483 
484 REGISTER_OP("DepthwiseConv2dNative")
485     .Input("input: T")
486     .Input("filter: T")
487     .Output("output: T")
488     .Attr("T: {half, bfloat16, float, double}")
489     .Attr("strides: list(int)")
490     .Attr(GetPaddingAttrString())
491     .Attr(GetConvnetDataFormatAttrString())
492     .Attr("dilations: list(int) = [1, 1, 1, 1]")
493     .SetShapeFn(shape_inference::DepthwiseConv2DNativeShape);
494 
495 REGISTER_OP("DepthwiseConv2dNativeBackpropInput")
496     .Input("input_sizes: int32")
497     .Input("filter: T")
498     .Input("out_backprop: T")
499     .Output("output: T")
500     .Attr("T: {half, bfloat16, float, double}")
501     .Attr("strides: list(int)")
502     .Attr(GetPaddingAttrString())
503     .Attr(GetConvnetDataFormatAttrString())
504     .Attr("dilations: list(int) = [1, 1, 1, 1]")
505     .SetShapeFn([](InferenceContext* c) {
506       ShapeHandle s;
507       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
508       TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
509       c->set_output(0, s);
510       return Status::OK();
511     });
512 
513 REGISTER_OP("DepthwiseConv2dNativeBackpropFilter")
514     .Input("input: T")
515     .Input("filter_sizes: int32")
516     .Input("out_backprop: T")
517     .Output("output: T")
518     .Attr("T: {half, bfloat16, float, double}")
519     .Attr("strides: list(int)")
520     .Attr(GetPaddingAttrString())
521     .Attr(GetConvnetDataFormatAttrString())
522     .Attr("dilations: list(int) = [1, 1, 1, 1]")
523     .SetShapeFn([](InferenceContext* c) {
524       ShapeHandle s;
525       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
526       TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
527       c->set_output(0, s);
528       return Status::OK();
529     });
530 
531 // --------------------------------------------------------------------------
532 REGISTER_OP("Conv3D")
533     .Input("input: T")
534     .Input("filter: T")
535     .Output("output: T")
536     .Attr("T: {half, bfloat16, float, double}")
537     .Attr("strides: list(int) >= 5")
538     .Attr(GetPaddingAttrString())
539     .Attr(GetConvnet3dDataFormatAttrString())
540     .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
541     .SetShapeFn(shape_inference::Conv3DShape);
542 
543 REGISTER_OP("Conv3DBackpropInput")
544     .Input("input: T")
545     .Input("filter: T")
546     .Input("out_backprop: T")
547     .Output("output: T")
548     .Attr("T: {half, float, double}")
549     .Attr("strides: list(int) >= 5")
550     .Attr(GetPaddingAttrString())
551     .Deprecated(10, "Use Conv3DBackpropInputV2")
552     .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
553     .SetShapeFn([](InferenceContext* c) {
554       return UnchangedShapeWithRank(c, 5);
555     });
556 
557 REGISTER_OP("Conv3DBackpropFilter")
558     .Input("input: T")
559     .Input("filter: T")
560     .Input("out_backprop: T")
561     .Output("output: T")
562     .Attr("T: {half, float, double}")
563     .Attr("strides: list(int) >= 5")
564     .Attr(GetPaddingAttrString())
565     .Deprecated(10, "Use Conv3DBackpropFilterV2")
566     .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
567     .SetShapeFn([](InferenceContext* c) {
568       ShapeHandle out;
569       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 5, &out));
570       c->set_output(0, out);
571       return Status::OK();
572     });
573 
574 REGISTER_OP("Conv3DBackpropInputV2")
575     .Input("input_sizes: Tshape")
576     .Input("filter: T")
577     .Input("out_backprop: T")
578     .Output("output: T")
579     .Attr("T: {half, bfloat16, float, double}")
580     .Attr("strides: list(int) >= 5")
581     .Attr(GetPaddingAttrString())
582     .Attr(GetConvnet3dDataFormatAttrString())
583     .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
584     .Attr("Tshape: {int32, int64} = DT_INT32")
585     .SetShapeFn([](InferenceContext* c) {
586       ShapeHandle s;
587       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
588       TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
589       c->set_output(0, s);
590       return Status::OK();
591     });
592 
593 REGISTER_OP("Conv3DBackpropFilterV2")
594     .Input("input: T")
595     .Input("filter_sizes: int32")
596     .Input("out_backprop: T")
597     .Output("output: T")
598     .Attr("T: {half, bfloat16, float, double}")
599     .Attr("strides: list(int) >= 5")
600     .Attr(GetPaddingAttrString())
601     .Attr(GetConvnet3dDataFormatAttrString())
602     .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
603     .SetShapeFn([](InferenceContext* c) {
604       ShapeHandle s;
605       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
606       TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
607       c->set_output(0, s);
608       return Status::OK();
609     });
610 
611 // --------------------------------------------------------------------------
612 
613 REGISTER_OP("AvgPool3D")
614     .Input("input: T")
615     .Output("output: T")
616     .Attr("ksize: list(int) >= 5")
617     .Attr("strides: list(int) >= 5")
618     .Attr(GetPaddingAttrString())
619     .Attr(GetConvnet3dDataFormatAttrString())
620     .Attr("T: {half, bfloat16, float, double}")
621     .SetShapeFn(shape_inference::Pool3DShape);
622 
623 REGISTER_OP("AvgPool3DGrad")
624     .Input("orig_input_shape: int32")
625     .Input("grad: T")
626     .Output("output: T")
627     .Attr("ksize: list(int) >= 5")
628     .Attr("strides: list(int) >= 5")
629     .Attr(GetPaddingAttrString())
630     .Attr(GetConvnet3dDataFormatAttrString())
631     .Attr("T: {half, bfloat16, float, double}")
632     .SetShapeFn([](InferenceContext* c) {
633       ShapeHandle s;
634       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
635       TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
636       c->set_output(0, s);
637       return Status::OK();
638     });
639 
640 // --------------------------------------------------------------------------
641 
642 REGISTER_OP("MaxPool3D")
643     .Input("input: T")
644     .Output("output: T")
645     .Attr("ksize: list(int) >= 5")
646     .Attr("strides: list(int) >= 5")
647     .Attr(GetPaddingAttrString())
648     .Attr(GetConvnet3dDataFormatAttrString())
649     .Attr("T: {half, bfloat16, float}")
650     .SetShapeFn(shape_inference::Pool3DShape);
651 
652 REGISTER_OP("MaxPool3DGrad")
653     .Input("orig_input: TInput")
654     .Input("orig_output: TInput")
655     .Input("grad: T")
656     .Output("output: T")
657     .Attr("ksize: list(int) >= 5")
658     .Attr("strides: list(int) >= 5")
659     .Attr(GetPaddingAttrString())
660     .Attr(GetConvnet3dDataFormatAttrString())
661     .Attr("T: {half, bfloat16, float} = DT_FLOAT")
662     .Attr("TInput: {half, bfloat16, float} = DT_FLOAT")
663     .SetShapeFn([](InferenceContext* c) {
664       return UnchangedShapeWithRank(c, 5);
665     });
666 
667 REGISTER_OP("MaxPool3DGradGrad")
668     .Input("orig_input: T")
669     .Input("orig_output: T")
670     .Input("grad: T")
671     .Output("output: T")
672     .Attr("ksize: list(int) >= 5 ")
673     .Attr("strides: list(int) >= 5")
674     .Attr(GetPaddingAttrString())
675     .Attr(GetConvnet3dDataFormatAttrString())
676     .Attr("T: realnumbertype")
677     .SetShapeFn([](InferenceContext* c) {
678       TF_RETURN_IF_ERROR(shape_inference::Pool3DShape(c));
679       ShapeHandle unused;
680       // Validate 'orig_input' is the same shape as 'grad'
681       TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(2), &unused));
683       // Validate 'orig_output' is the same shape as 'output'
683       TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->output(0), &unused));
684       return Status::OK();
685     });
686 
687 // --------------------------------------------------------------------------
688 
689 REGISTER_OP("L2Loss")
690     .Input("t: T")
691     .Output("output: T")
692     .Attr("T: {half, bfloat16, float, double}")
693     .SetShapeFn(shape_inference::ScalarShape);
694 
695 // --------------------------------------------------------------------------
696 
697 REGISTER_OP("LRN")
698     .Input("input: T")
699     .Output("output: T")
700     .Attr("depth_radius: int = 5")
701     .Attr("bias: float = 1.0")
702     .Attr("alpha: float = 1.0")
703     .Attr("beta: float = 0.5")
704     .Attr("T: {half, bfloat16, float} = DT_FLOAT")
705     .SetShapeFn([](InferenceContext* c) {
706       return UnchangedShapeWithRank(c, 4);
707     });
708 
709 REGISTER_OP("LRNGrad")
710     .Input("input_grads: T")
711     .Input("input_image: T")
712     .Input("output_image: T")
713     .Output("output: T")
714     .Attr("depth_radius: int = 5")
715     .Attr("bias: float = 1.0")
716     .Attr("alpha: float = 1.0")
717     .Attr("beta: float = 0.5")
718     .Attr("T: {half, bfloat16, float} = DT_FLOAT")
719     .SetShapeFn([](InferenceContext* c) {
720       ShapeHandle s;
721       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &s));  // input_grads
722       TF_RETURN_IF_ERROR(c->Merge(s, c->input(1), &s));     // input_image
723       TF_RETURN_IF_ERROR(c->Merge(s, c->input(2), &s));     // output_image
724       c->set_output(0, s);
725       return Status::OK();
726     });
727 
728 // --------------------------------------------------------------------------
729 
730 REGISTER_OP("MaxPool")
731     .Attr(
732         "T: {half, bfloat16, float, double, int32, int64, uint8, int16, int8, "
733         "uint16, qint8} = DT_FLOAT")
734     .Attr("ksize: list(int) >= 4")
735     .Attr("strides: list(int) >= 4")
736     .Attr(GetPaddingAttrString())
737     .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'")
738     .Input("input: T")
739     .Output("output: T")
740     .SetShapeFn(shape_inference::MaxPoolShape);
741 
742 REGISTER_OP("MaxPoolV2")
743     .Attr(
744         "T: {half, bfloat16, float, double, int32, int64, uint8, int16, int8, "
745         "uint16, qint8} = DT_FLOAT")
746     .Attr(GetPaddingAttrString())
747     .Attr("data_format: {'NHWC', 'NCHW', 'NCHW_VECT_C'} = 'NHWC'")
748     .Input("input: T")
749     .Input("ksize: int32")
750     .Input("strides: int32")
751     .Output("output: T")
752     .SetShapeFn([](InferenceContext* c) {
753       TF_RETURN_IF_ERROR(shape_inference::MaxPoolV2Shape(c, 3));
754       return Status::OK();
755     });
756 
757 REGISTER_OP("MaxPoolGrad")
758     .Attr("ksize: list(int) >= 4")
759     .Attr("strides: list(int) >= 4")
760     .Attr(GetPaddingAttrString())
761     .Attr(GetConvnetDataFormatAttrString())
762     .Input("orig_input: T")
763     .Input("orig_output: T")
764     .Input("grad: T")
765     .Output("output: T")
766     .Attr("T: realnumbertype = DT_FLOAT")
767     .SetShapeFn([](InferenceContext* c) {
768       return UnchangedShapeWithRank(c, 4);
769     });
770 
771 REGISTER_OP("MaxPoolGradV2")
772     .Attr(GetPaddingAttrString())
773     .Attr(GetConvnetDataFormatAttrString())
774     .Input("orig_input: T")
775     .Input("orig_output: T")
776     .Input("grad: T")
777     .Input("ksize: int32")
778     .Input("strides: int32")
779     .Output("output: T")
780     .Attr("T: realnumbertype = DT_FLOAT")
781     .SetShapeFn([](InferenceContext* c) {
782       return UnchangedShapeWithRank(c, 4);
783     });
784 
785 REGISTER_OP("MaxPoolGradGrad")
786     .Attr("ksize: list(int) >= 4")
787     .Attr("strides: list(int) >= 4")
788     .Attr(GetPaddingAttrString())
789     .Attr(GetConvnetDataFormatAttrString())
790     .Input("orig_input: T")
791     .Input("orig_output: T")
792     .Input("grad: T")
793     .Output("output: T")
794     .Attr("T: realnumbertype")
795     .SetShapeFn([](InferenceContext* c) {
796       TF_RETURN_IF_ERROR(shape_inference::MaxPoolShape(c));
797       ShapeHandle unused;
798       // Validate 'orig_input' is the same shape as 'grad'
799       TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(2), &unused));
801       // Validate 'orig_output' is the same shape as 'output'
801       TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->output(0), &unused));
802       return Status::OK();
803     });
804 
805 REGISTER_OP("MaxPoolGradGradV2")
806     .Attr(GetPaddingAttrString())
807     .Attr(GetConvnetDataFormatAttrString())
808     .Input("orig_input: T")
809     .Input("orig_output: T")
810     .Input("grad: T")
811     .Input("ksize: int32")
812     .Input("strides: int32")
813     .Output("output: T")
814     .Attr("T: realnumbertype")
815     .SetShapeFn([](InferenceContext* c) {
816       TF_RETURN_IF_ERROR(shape_inference::MaxPoolV2Shape(c, 5));
817       ShapeHandle unused;
818       // Validate 'orig_input' is the same shape as 'grad'
819       TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(2), &unused));
821       // Validate 'orig_output' is the same shape as 'output'
821       TF_RETURN_IF_ERROR(c->Merge(c->input(1), c->output(0), &unused));
822       return Status::OK();
823     });
824 
825 REGISTER_OP("MaxPoolWithArgmax")
826     .Attr("ksize: list(int) >= 4")
827     .Attr("strides: list(int) >= 4")
828     .Attr("Targmax: {int32, int64} = DT_INT64")
829     .Attr(GetPaddingAttrString())
830     .Attr("include_batch_in_index: bool = false")
831     .Input("input: T")
832     .Output("output: T")
833     .Output("argmax: Targmax")
834     .Attr("T: realnumbertype")
835     .SetShapeFn([](InferenceContext* c) {
836       TF_RETURN_IF_ERROR(shape_inference::MaxPoolShape(c));
837       c->set_output(1, c->output(0));
838       return Status::OK();
839     });
840 
841 REGISTER_OP("MaxPoolGradWithArgmax")
842     .Attr("ksize: list(int) >= 4")
843     .Attr("strides: list(int) >= 4")
844     .Attr(GetPaddingAttrString())
845     .Attr("include_batch_in_index: bool = false")
846     .Attr("Targmax: {int32, int64}")
847     .Input("input: T")
848     .Input("grad: T")
849     .Input("argmax: Targmax")
850     .Output("output: T")
851     .Attr("T: realnumbertype")
852     .SetShapeFn([](InferenceContext* c) {
853       return UnchangedShapeWithRank(c, 4);
854     });
855 
856 REGISTER_OP("MaxPoolGradGradWithArgmax")
857     .Attr("ksize: list(int) >= 4")
858     .Attr("strides: list(int) >= 4")
859     .Attr(GetPaddingAttrString())
860     .Attr("include_batch_in_index: bool = false")
861     .Attr("Targmax: {int32, int64}")
862     .Input("input: T")
863     .Input("grad: T")
864     .Input("argmax: Targmax")
865     .Output("output: T")
866     .Attr("T: realnumbertype")
867     .SetShapeFn([](InferenceContext* c) {
868       TF_RETURN_IF_ERROR(shape_inference::MaxPoolShape(c));
869       ShapeHandle unused;
870       // Validate 'orig_input' is the same shape as 'grad'
871       TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &unused));
873       // Validate 'argmax' is the same shape as 'output'
873       TF_RETURN_IF_ERROR(c->Merge(c->input(2), c->output(0), &unused));
874       return Status::OK();
875     });
876 
877 // --------------------------------------------------------------------------
878 
879 REGISTER_OP("Dilation2D")
880     .Input("input: T")
881     .Input("filter: T")
882     .Output("output: T")
883     .Attr("T: realnumbertype")
884     .Attr("strides: list(int) >= 4")
885     .Attr("rates: list(int) >= 4")
886     .Attr(GetPaddingAttrString())
887     .SetShapeFn([](InferenceContext* c) {
888       ShapeHandle input_shape;
889       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
890       ShapeHandle filter_shape;
891       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 3, &filter_shape));
892 
893       std::vector<int32> strides;
894       TF_RETURN_IF_ERROR(c->GetAttr("strides", &strides));
895       if (strides.size() != 4) {
896         return errors::InvalidArgument(
897             "Dilation2D requires the stride attribute to contain 4 values, but "
898             "got: ",
899             strides.size());
900       }
901 
902       std::vector<int32> rates;
903       TF_RETURN_IF_ERROR(c->GetAttr("rates", &rates));
904       if (rates.size() != 4) {
905         return errors::InvalidArgument(
906             "Dilation2D requires the rates attribute to contain 4 values, but "
907             "got: ",
908             rates.size());
909       }
910 
911       int32 stride_rows = strides[1];
912       int32 stride_cols = strides[2];
913 
914       int32 rate_rows = rates[1];
915       int32 rate_cols = rates[2];
916 
917       DimensionHandle batch_size_dim = c->Dim(input_shape, 0);
918       DimensionHandle in_rows_dim = c->Dim(input_shape, 1);
919       DimensionHandle in_cols_dim = c->Dim(input_shape, 2);
920       DimensionHandle filter_rows_dim = c->Dim(filter_shape, 0);
921       DimensionHandle filter_cols_dim = c->Dim(filter_shape, 1);
922       DimensionHandle output_depth_dim = c->Dim(filter_shape, 2);
923 
924       if (!c->ValueKnown(in_rows_dim) || !c->ValueKnown(in_cols_dim) ||
925           !c->ValueKnown(filter_rows_dim) || !c->ValueKnown(filter_cols_dim)) {
926         ShapeHandle output_shape =
927             c->MakeShape({batch_size_dim, InferenceContext::kUnknownDim,
928                           InferenceContext::kUnknownDim, output_depth_dim});
929         c->set_output(0, output_shape);
930         return Status::OK();
931       }
932       DimensionHandle unused;
933       TF_RETURN_IF_ERROR(
934           c->Merge(c->Dim(input_shape, 3), output_depth_dim, &unused));
935 
936       auto in_rows = c->Value(in_rows_dim);
937       auto in_cols = c->Value(in_cols_dim);
938       auto filter_rows = c->Value(filter_rows_dim);
939       auto filter_cols = c->Value(filter_cols_dim);
940       auto filter_rows_eff = filter_rows + (filter_rows - 1) * (rate_rows - 1);
941       auto filter_cols_eff = filter_cols + (filter_cols - 1) * (rate_cols - 1);
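      // Illustrative: filter_rows = 3 with rate_rows = 2 gives an effective
      // extent of 3 + (3 - 1) * (2 - 1) = 5.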
942 
943       Padding padding;
944       TF_RETURN_IF_ERROR(c->GetAttr("padding", &padding));
945 
946       int64 output_rows, output_cols;
947       int64 padding_before, padding_after;
948       TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
949           in_rows, filter_rows_eff, stride_rows, padding, &output_rows,
950           &padding_before, &padding_after));
951       TF_RETURN_IF_ERROR(GetWindowedOutputSizeVerbose(
952           in_cols, filter_cols_eff, stride_cols, padding, &output_cols,
953           &padding_before, &padding_after));
954 
955       ShapeHandle output_shape = c->MakeShape(
956           {batch_size_dim, output_rows, output_cols, output_depth_dim});
957       c->set_output(0, output_shape);
958       return Status::OK();
959     });
960 
961 REGISTER_OP("Dilation2DBackpropInput")
962     .Input("input: T")
963     .Input("filter: T")
964     .Input("out_backprop: T")
965     .Output("in_backprop: T")
966     .Attr("T: realnumbertype")
967     .Attr("strides: list(int) >= 4")
968     .Attr("rates: list(int) >= 4")
969     .Attr(GetPaddingAttrString())
970     .SetShapeFn(shape_inference::UnchangedShape);
971 
972 REGISTER_OP("Dilation2DBackpropFilter")
973     .Input("input: T")
974     .Input("filter: T")
975     .Input("out_backprop: T")
976     .Output("filter_backprop: T")
977     .Attr("T: realnumbertype")
978     .Attr("strides: list(int) >= 4")
979     .Attr("rates: list(int) >= 4")
980     .Attr(GetPaddingAttrString())
981     .SetShapeFn([](InferenceContext* c) {
982       c->set_output(0, c->input(1));
983       return Status::OK();
984     });
985 
986 // --------------------------------------------------------------------------
987 
988 REGISTER_OP("Relu")
989     .Input("features: T")
990     .Output("activations: T")
991     .Attr("T: {realnumbertype, qint8}")
992     .SetShapeFn(shape_inference::UnchangedShape);
993 
994 REGISTER_OP("ReluGrad")
995     .Input("gradients: T")
996     .Input("features: T")
997     .Output("backprops: T")
998     .Attr("T: realnumbertype")
999     .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
1000 
1001 REGISTER_OP("Relu6")
1002     .Input("features: T")
1003     .Output("activations: T")
1004     .Attr("T: realnumbertype")
1005     .SetShapeFn(shape_inference::UnchangedShape);
1006 
1007 REGISTER_OP("Relu6Grad")
1008     .Input("gradients: T")
1009     .Input("features: T")
1010     .Output("backprops: T")
1011     .Attr("T: realnumbertype")
1012     .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
1013 
1014 REGISTER_OP("LeakyRelu")
1015     .Input("features: T")
1016     .Output("activations: T")
1017     .Attr("alpha: float = 0.2")
1018     .Attr("T: {half, bfloat16, float, double} = DT_FLOAT")
1019     .SetShapeFn(shape_inference::UnchangedShape);
1020 
1021 REGISTER_OP("LeakyReluGrad")
1022     .Input("gradients: T")
1023     .Input("features: T")
1024     .Output("backprops: T")
1025     .Attr("alpha: float = 0.2")
1026     .Attr("T: {half, bfloat16, float, double} = DT_FLOAT")
1027     .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
1028 
1029 REGISTER_OP("Elu")
1030     .Input("features: T")
1031     .Output("activations: T")
1032     .Attr("T: {half, bfloat16, float, double}")
1033     .SetShapeFn(shape_inference::UnchangedShape);
1034 
1035 REGISTER_OP("EluGrad")
1036     .Input("gradients: T")
1037     .Input("outputs: T")
1038     .Output("backprops: T")
1039     .Attr("T: {half, bfloat16, float, double}")
1040     .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
1041 
1042 REGISTER_OP("Selu")
1043     .Input("features: T")
1044     .Output("activations: T")
1045     .Attr("T: {half, bfloat16, float, double}")
1046     .SetShapeFn(shape_inference::UnchangedShape);
1047 
1048 REGISTER_OP("SeluGrad")
1049     .Input("gradients: T")
1050     .Input("outputs: T")
1051     .Output("backprops: T")
1052     .Attr("T: {half, bfloat16, float, double}")
1053     .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
1054 
1055 REGISTER_OP("Softplus")
1056     .Input("features: T")
1057     .Output("activations: T")
1058     .Attr("T: {half, bfloat16, float, double}")
1059     .SetShapeFn(shape_inference::UnchangedShape);
1060 
1061 REGISTER_OP("SoftplusGrad")
1062     .Input("gradients: T")
1063     .Input("features: T")
1064     .Output("backprops: T")
1065     .Attr("T: {half, bfloat16, float, double}")
1066     .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
1067 
1068 REGISTER_OP("Softsign")
1069     .Input("features: T")
1070     .Output("activations: T")
1071     .Attr("T: {half, bfloat16, float, double}")
1072     .SetShapeFn(shape_inference::UnchangedShape);
1073 
1074 REGISTER_OP("SoftsignGrad")
1075     .Input("gradients: T")
1076     .Input("features: T")
1077     .Output("backprops: T")
1078     .Attr("T: {half, bfloat16, float, double}")
1079     .SetShapeFn(shape_inference::MergeBothInputsShapeFn);
1080 
1081 // --------------------------------------------------------------------------
1082 
1083 REGISTER_OP("Softmax")
1084     .Input("logits: T")
1085     .Output("softmax: T")
1086     .Attr("T: {half, bfloat16, float, double}")
1087     .SetShapeFn([](InferenceContext* c) {
1088       return shape_inference::UnchangedShapeWithRankAtLeast(c, 1);
1089     });
1090 
1091 // --------------------------------------------------------------------------
1092 
1093 REGISTER_OP("LogSoftmax")
1094     .Input("logits: T")
1095     .Output("logsoftmax: T")
1096     .Attr("T: {half, bfloat16, float, double}")
1097     .SetShapeFn([](InferenceContext* c) {
1098       return shape_inference::UnchangedShapeWithRankAtLeast(c, 1);
1099     });
1100 
1101 // --------------------------------------------------------------------------
1102 
1103 REGISTER_OP("SoftmaxCrossEntropyWithLogits")
1104     .Input("features: T")
1105     .Input("labels: T")
1106     .Output("loss: T")
1107     .Output("backprop: T")
1108     .Attr("T: {half, bfloat16, float, double}")
1109     .SetShapeFn([](InferenceContext* c) {
1110       ShapeHandle input;
1111       if (c->WithRank(c->input(0), 2, &input) == Status::OK() &&
1112           c->Merge(input, c->input(1), &input) == Status::OK()) {
1113         DimensionHandle batch_size = c->Dim(input, 0);
1114         c->set_output(0, c->Vector(batch_size));
1115         c->set_output(1, input);
1116         return Status::OK();
1117       }
1118       TF_RETURN_IF_ERROR(BroadcastBinaryOpOutputShapeFn(c, 1));
1119 
1120       if (!c->RankKnown(c->output(1))) {
1121         return errors::InvalidArgument(
1122             "Shape must be broadcasted with rank 2, but is rank is unknown.");
1123       }
1124 
1125       if (c->Rank(c->output(1)) != 2) {
1126         return errors::InvalidArgument(
1127             "Shape must be broadcasted with rank 2, but is rank ",
1128             c->Rank(c->output(1)));
1129       }
1130       DimensionHandle batch_size = c->Dim(c->output(1), 0);
1131       c->set_output(0, c->Vector(batch_size));
1132       return Status::OK();
1133     });
1134 
1135 REGISTER_OP("SparseSoftmaxCrossEntropyWithLogits")
1136     .Input("features: T")
1137     .Input("labels: Tlabels")
1138     .Output("loss: T")
1139     .Output("backprop: T")
1140     .Attr("T: {half, bfloat16, float, double}")
1141     .Attr("Tlabels: {int32, int64} = DT_INT64")
1142     .SetShapeFn([](InferenceContext* c) {
1143       ShapeHandle features;
1144       ShapeHandle labels;
1145       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &features));
1146       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &labels));
1147 
1148       DimensionHandle batch_size;
1149       TF_RETURN_IF_ERROR(
1150           c->Merge(c->Dim(features, 0), c->Dim(labels, 0), &batch_size));
1151       TF_RETURN_IF_ERROR(c->ReplaceDim(features, 0, batch_size, &features));
1152 
1153       c->set_output(0, c->Vector(batch_size));
1154       c->set_output(1, features);
1155       return Status::OK();
1156     });
1157 
1158 // --------------------------------------------------------------------------
1159 
1160 REGISTER_OP("InTopK")
1161     .Input("predictions: float")
1162     .Input("targets: T")
1163     .Output("precision: bool")
1164     .Attr("k: int")
1165     .Attr("T: {int32, int64} = DT_INT32")
1166     .SetShapeFn([](InferenceContext* c) {
1167       ShapeHandle predictions;
1168       ShapeHandle targets;
1169       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &predictions));
1170       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &targets));
1171       DimensionHandle batch_size;
1172       TF_RETURN_IF_ERROR(
1173           c->Merge(c->Dim(predictions, 0), c->Dim(targets, 0), &batch_size));
1174       c->set_output(0, c->Vector(batch_size));
1175       return Status::OK();
1176     });
1177 
1178 // This is the same as `InTopK`, but takes `k` as an input rather than an attr.
1179 REGISTER_OP("InTopKV2")
1180     .Input("predictions: float")
1181     .Input("targets: T")
1182     .Input("k: T")
1183     .Output("precision: bool")
1184     .Attr("T: {int32, int64} = DT_INT32")
1185     .SetShapeFn([](InferenceContext* c) {
1186       ShapeHandle predictions;
1187       ShapeHandle targets;
1188       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 2, &predictions));
1189       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 1, &targets));
1190       DimensionHandle batch_size;
1191       TF_RETURN_IF_ERROR(
1192           c->Merge(c->Dim(predictions, 0), c->Dim(targets, 0), &batch_size));
1193       c->set_output(0, c->Vector(batch_size));
1194       return Status::OK();
1195     });
1196 
1197 namespace {
1198 
1199 Status TopKShapeFn(InferenceContext* c) {
1200   ShapeHandle input;
1201   TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input));
1202 
1203   // Get the k value, either from input tensor or attribute.
1204   DimensionHandle k_dim;
1205   if (c->num_inputs() >= 2) {
1206     TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(1, &k_dim));
1207   } else {
1208     int32 k;
1209     TF_RETURN_IF_ERROR(c->GetAttr("k", &k));
1210     if (k < 0) {
1211       return errors::InvalidArgument("Need k >= 0, got ", k);
1212     }
1213     k_dim = c->MakeDim(k);
1214   }
1215 
1216   DimensionHandle last_dim = c->Dim(input, -1);
1217   if (c->ValueKnown(last_dim) && c->ValueKnown(k_dim) &&
1218       c->Value(last_dim) < c->Value(k_dim)) {
1219     return errors::InvalidArgument(
1220         "input must have last dimension >= k = ", c->Value(k_dim), " but is ",
1221         c->Value(last_dim));
1222   }
1223 
1224   // Replace last_dim with k_dim.
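  // Illustrative: an input of shape [?, 10] with k = 3 yields values and
  // indices of shape [?, 3].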
1225   ShapeHandle s;
1226   TF_RETURN_IF_ERROR(c->Subshape(input, 0, -1, &s));
1227   TF_RETURN_IF_ERROR(c->Concatenate(s, c->Vector(k_dim), &s));
1228   c->set_output(0, s);
1229   c->set_output(1, s);
1230   return Status::OK();
1231 }
1232 
1233 }  // namespace
1234 
1235 REGISTER_OP("TopK")
1236     .Input("input: T")
1237     .Output("values: T")
1238     .Output("indices: int32")
1239     .Attr("k: int >= 0")
1240     .Attr("sorted: bool = true")
1241     .Attr("T: realnumbertype")
1242     .Deprecated(7, "Use TopKV2 instead")
1243     .SetShapeFn(TopKShapeFn);
1244 
1245 // This is the same as `TopK`, but takes `k` as an input rather than an attr.
1246 REGISTER_OP("TopKV2")
1247     .Input("input: T")
1248     .Input("k: int32")
1249     .Output("values: T")
1250     .Output("indices: int32")
1251     .Attr("sorted: bool = true")
1252     .Attr("T: realnumbertype")
1253     .SetShapeFn(TopKShapeFn);
1254 
1255 // --------------------------------------------------------------------------
1256 
1257 REGISTER_OP("NthElement")
1258     .Input("input: T")
1259     .Input("n: int32")
1260     .Output("values: T")
1261     .Attr("reverse: bool = false")
1262     .Attr("T: realnumbertype")
1263     .SetShapeFn([](InferenceContext* c) {
1264       ShapeHandle input;
1265       TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &input));
1266 
1267       // Get the n value from the input tensor and make sure it is a scalar.
1268       DimensionHandle n_dim;
1269       TF_RETURN_IF_ERROR(c->MakeDimForScalarInput(1, &n_dim));
1270 
1271       // The last dimension of the input tensor must be greater than n.
1272       DimensionHandle last_dim = c->Dim(input, -1);
1273       if (c->ValueKnown(last_dim) && c->ValueKnown(n_dim) &&
1274           c->Value(last_dim) <= c->Value(n_dim)) {
1275         return errors::InvalidArgument(
1276             "Input must have last dimension > n = ", c->Value(n_dim),
1277             " but is ", c->Value(last_dim));
1278       }
1279 
1280       // Reduce last_dim for output tensor
1281       ShapeHandle s;
1282       TF_RETURN_IF_ERROR(c->Subshape(input, 0, -1, &s));
1283       c->set_output(0, s);
1284       return Status::OK();
1285     });
1286 
1287 // --------------------------------------------------------------------------
1288 
1289 REGISTER_OP("FractionalMaxPool")
1290     .Input("value: T")
1291     .Output("output: T")
1292     .Output("row_pooling_sequence: int64")
1293     .Output("col_pooling_sequence: int64")
1294     .Attr("pooling_ratio: list(float) >=4")
1295     .Attr("pseudo_random: bool = false")
1296     .Attr("overlapping: bool = false")
1297     .Attr("deterministic: bool = false")
1298     .Attr("seed: int = 0")
1299     .Attr("seed2: int = 0")
1300     .Attr("T: {float, double, int32, int64}")
1301     .SetShapeFn(FractionalPoolShapeFn);
1302 
1303 REGISTER_OP("FractionalMaxPoolGrad")
1304     .Input("orig_input: T")
1305     .Input("orig_output: T")
1306     .Input("out_backprop: T")
1307     .Input("row_pooling_sequence: int64")
1308     .Input("col_pooling_sequence: int64")
1309     .Output("output: T")
1310     .Attr("overlapping: bool = false")
1311     .Attr("T: {float, double, int32, int64}")
1312     .SetShapeFn([](InferenceContext* c) {
1313       return shape_inference::UnchangedShapeWithRank(c, 4);
1314     });
1315 
1316 // --------------------------------------------------------------------------
1317 
1318 REGISTER_OP("FractionalAvgPool")
1319     .Input("value: T")
1320     .Output("output: T")
1321     .Output("row_pooling_sequence: int64")
1322     .Output("col_pooling_sequence: int64")
1323     .Attr("pooling_ratio: list(float) >=4")
1324     .Attr("pseudo_random: bool = false")
1325     .Attr("overlapping: bool = false")
1326     .Attr("deterministic: bool = false")
1327     .Attr("seed: int = 0")
1328     .Attr("seed2: int = 0")
1329     .Attr("T: {float, double, int32, int64}")
1330     .SetShapeFn(FractionalPoolShapeFn);
1331 
1332 REGISTER_OP("FractionalAvgPoolGrad")
1333     .Input("orig_input_tensor_shape: int64")
1334     .Input("out_backprop: T")
1335     .Input("row_pooling_sequence: int64")
1336     .Input("col_pooling_sequence: int64")
1337     .Output("output: T")
1338     .Attr("overlapping: bool = false")
1339     .Attr("T: {float, double, int32, int64}")
1340     .SetShapeFn([](InferenceContext* c) {
1341       if (c->input_tensor(0) != nullptr) {
1342         ShapeHandle out;
1343         TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &out));
1344         c->set_output(0, out);
1345       } else {
1346         c->set_output(0, c->UnknownShapeOfRank(4));
1347       }
1348       return Status::OK();
1349     });
1350 
1351 REGISTER_OP("QuantizedAvgPool")
1352     .Input("input: T")
1353     .Input("min_input: float")
1354     .Input("max_input: float")
1355     .Output("output: T")
1356     .Output("min_output: float")
1357     .Output("max_output: float")
1358     .Attr("T: quantizedtype")
1359     .Attr("ksize: list(int)")
1360     .Attr("strides: list(int)")
1361     .Attr(GetPaddingAttrString())
1362     .SetShapeFn([](InferenceContext* c) {
1363       TF_RETURN_IF_ERROR(shape_inference::AvgPoolShape(c));
1364       ShapeHandle unused;
1365       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1366       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1367       c->set_output(1, c->Scalar());
1368       c->set_output(2, c->Scalar());
1369       return Status::OK();
1370     });
1371 
1372 REGISTER_OP("QuantizedBiasAdd")
1373     .Input("input: T1")
1374     .Input("bias: T2")
1375     .Input("min_input: float")
1376     .Input("max_input: float")
1377     .Input("min_bias: float")
1378     .Input("max_bias: float")
1379     .Output("output: out_type")
1380     .Output("min_out: float")
1381     .Output("max_out: float")
1382     .Attr("T1: quantizedtype")
1383     .Attr("T2: quantizedtype")
1384     .Attr("out_type: quantizedtype")
1385     .SetShapeFn([](InferenceContext* c) {
1386       TF_RETURN_IF_ERROR(shape_inference::BiasAddShape(c));
1387       ShapeHandle unused;
1388       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1389       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
1390       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
1391       TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
1392       c->set_output(1, c->Scalar());
1393       c->set_output(2, c->Scalar());
1394       return Status::OK();
1395     });
1396 
1397 REGISTER_OP("QuantizedConv2D")
1398     .Input("input: Tinput")
1399     .Input("filter: Tfilter")
1400     .Input("min_input: float")
1401     .Input("max_input: float")
1402     .Input("min_filter: float")
1403     .Input("max_filter: float")
1404     .Output("output: out_type")
1405     .Output("min_output: float")
1406     .Output("max_output: float")
1407     .Attr("Tinput: quantizedtype")
1408     .Attr("Tfilter: quantizedtype")
1409     .Attr("out_type: quantizedtype = DT_QINT32")
1410     .Attr("strides: list(int)")
1411     .Attr(GetPaddingAttrString())
1412     .Attr("dilations: list(int) = [1, 1, 1, 1]")
1413     .SetShapeFn([](InferenceContext* c) {
1414       TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
1415       ShapeHandle unused;
1416       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1417       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
1418       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
1419       TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
1420       c->set_output(1, c->Scalar());
1421       c->set_output(2, c->Scalar());
1422       return Status::OK();
1423     });
1424 
1425 REGISTER_OP("QuantizedMaxPool")
1426     .Input("input: T")
1427     .Input("min_input: float")
1428     .Input("max_input: float")
1429     .Output("output: T")
1430     .Output("min_output: float")
1431     .Output("max_output: float")
1432     .Attr("T: quantizedtype")
1433     .Attr("ksize: list(int)")
1434     .Attr("strides: list(int)")
1435     .Attr(GetPaddingAttrString())
1436     .SetShapeFn([](InferenceContext* c) {
1437       TF_RETURN_IF_ERROR(shape_inference::MaxPoolShape(c));
1438       ShapeHandle unused;
1439       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1440       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1441       c->set_output(1, c->Scalar());
1442       c->set_output(2, c->Scalar());
1443       return Status::OK();
1444     });
1445 
1446 REGISTER_OP("QuantizedRelu")
1447     .Input("features: Tinput")
1448     .Input("min_features: float")
1449     .Input("max_features: float")
1450     .Output("activations: out_type")
1451     .Output("min_activations: float")
1452     .Output("max_activations: float")
1453     .Attr("Tinput: quantizedtype")
1454     .Attr("out_type: quantizedtype = DT_QUINT8")
1455     .SetShapeFn([](InferenceContext* c) {
1456       TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
1457       ShapeHandle unused;
1458       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1459       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1460       c->set_output(1, c->Scalar());
1461       c->set_output(2, c->Scalar());
1462       return Status::OK();
1463     });
1464 
1465 REGISTER_OP("QuantizedRelu6")
1466     .Input("features: Tinput")
1467     .Input("min_features: float")
1468     .Input("max_features: float")
1469     .Output("activations: out_type")
1470     .Output("min_activations: float")
1471     .Output("max_activations: float")
1472     .Attr("Tinput: quantizedtype")
1473     .Attr("out_type: quantizedtype = DT_QUINT8")
1474     .SetShapeFn([](InferenceContext* c) {
1475       TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
1476       ShapeHandle unused;
1477       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1478       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1479       c->set_output(1, c->Scalar());
1480       c->set_output(2, c->Scalar());
1481       return Status::OK();
1482     });
1483 
1484 REGISTER_OP("QuantizedReluX")
1485     .Input("features: Tinput")
1486     .Input("max_value: float")
1487     .Input("min_features: float")
1488     .Input("max_features: float")
1489     .Output("activations: out_type")
1490     .Output("min_activations: float")
1491     .Output("max_activations: float")
1492     .Attr("Tinput: quantizedtype")
1493     .Attr("out_type: quantizedtype = DT_QUINT8")
1494     .SetShapeFn([](InferenceContext* c) {
1495       TF_RETURN_IF_ERROR(shape_inference::UnchangedShape(c));
1496       ShapeHandle unused;
1497       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused));
1498       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
1499       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
1500       c->set_output(1, c->Scalar());
1501       c->set_output(2, c->Scalar());
1502       return Status::OK();
1503     });
1504 
1505 REGISTER_OP("QuantizedBatchNormWithGlobalNormalization")
1506     .Input("t: Tinput")
1507     .Input("t_min: float")
1508     .Input("t_max: float")
1509     .Input("m: Tinput")
1510     .Input("m_min: float")
1511     .Input("m_max: float")
1512     .Input("v: Tinput")
1513     .Input("v_min: float")
1514     .Input("v_max: float")
1515     .Input("beta: Tinput")
1516     .Input("beta_min: float")
1517     .Input("beta_max: float")
1518     .Input("gamma: Tinput")
1519     .Input("gamma_min: float")
1520     .Input("gamma_max: float")
1521     .Output("result: out_type")
1522     .Output("result_min: float")
1523     .Output("result_max: float")
1524     .Attr("Tinput: quantizedtype")
1525     .Attr("out_type: quantizedtype")
1526     .Attr("variance_epsilon: float")
1527     .Attr("scale_after_normalization: bool")
1528     .SetShapeFn([](InferenceContext* c) {
1529       ShapeHandle input;
1530       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input));
1531 
1532       DimensionHandle last_dim = c->Dim(input, 3);
1533       for (int i = 1; i < 5; ++i) {  // covers m, v, beta, gamma
1534         ShapeHandle vec;
1535         TF_RETURN_IF_ERROR(c->WithRank(c->input(i * 3), 1, &vec));
1536         TF_RETURN_IF_ERROR(c->Merge(last_dim, c->Dim(vec, 0), &last_dim));
1537       }
1538 
1539       ShapeHandle out;
1540       TF_RETURN_IF_ERROR(c->ReplaceDim(input, 3, last_dim, &out));
1541       c->set_output(0, out);
1542       c->set_output(1, c->Scalar());
1543       c->set_output(2, c->Scalar());
1544 
1545       return Status::OK();
1546     });
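// In the shape function above, the rank-1 inputs m, v, beta and gamma sit at
// input indices 3, 6, 9 and 12 (i * 3 for i = 1..4). Each is merged with the
// last (channel) dimension of t, so for an illustrative t of shape
// [8, 32, 32, 16] each of those vectors must have shape [16], the result
// keeps shape [8, 32, 32, 16], and result_min/result_max are scalars.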
1547 
1548 #ifdef INTEL_MKL
1549 REGISTER_OP("_MklDepthwiseConv2dNative")
1550     .Input("input: T")
1551     .Input("filter: T")
1552     .Input("mkl_input: uint8")
1553     .Input("mkl_filter: uint8")
1554     .Output("output: T")
1555     .Output("filter_output: T")
1556     .Output("mkl_output: uint8")
1557     .Output("mkl_filter_output: uint8")
1558     .Attr("T: {half, bfloat16, float, double}")
1559     .Attr("strides: list(int)")
1560     .Attr("is_filter_const: bool = false")
1561     .Attr(GetPaddingAttrString())
1562     .Attr(GetConvnetDataFormatAttrString())
1563     .Attr("dilations: list(int) = [1, 1, 1, 1]")
1564     .SetShapeFn(shape_inference::DepthwiseConv2DNativeShape);
1565 
1566 REGISTER_OP("_MklConv2D")
1567     .Input("input: T")
1568     .Input("filter: T")
1569     .Input("mkl_input: uint8")
1570     .Input("mkl_filter: uint8")
1571     .Output("output: T")
1572     .Output("filter_output: T")
1573     .Output("mkl_output: uint8")
1574     .Output("mkl_filter_output: uint8")
1575     .Attr("T: {half, float, double}")
1576     .Attr("strides: list(int)")
1577     .Attr("use_cudnn_on_gpu: bool = true")
1578     .Attr("is_filter_const: bool = false")
1579     .Attr(GetPaddingAttrString())
1580     .Attr(GetConvnetDataFormatAttrString())
1581     .Attr("dilations: list(int) = [1, 1, 1, 1]")
1582     .SetShapeFn(shape_inference::Conv2DShape)
1583     .Doc(R"doc(
1584 MKL version of Conv2D operator. Uses MKL DNN APIs to perform 2D convolution.
1585 
1586 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
1587 expected to invoke these operators.
1588 )doc");
1589 
1590 REGISTER_OP("__MklDummyConv2DWithBias")
1591     .Input("input: T")
1592     .Input("filter: T")
1593     .Input("bias: T")
1594     .Output("output: T")
1595     .Attr("T: {half, float, double}")
1596     .Attr("strides: list(int)")
1597     .Attr("use_cudnn_on_gpu: bool = true")
1598     .Attr("is_filter_const: bool = false")
1599     .Attr(GetPaddingAttrString())
1600     .Attr(GetConvnetDataFormatAttrString())
1601     .Attr("dilations: list(int) = [1, 1, 1, 1]")
1602     .SetShapeFn(shape_inference::Conv2DShape)
1603     .Doc(R"doc(
1604 Dummy node that enables fusing Conv2D and BiasAdd operator for MKL. This node
1605 does not perform anything. It is just created as an intermediate output of
1606 merging Conv2D and BiasAdd.
1607 
1608 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
1609 expected to invoke these operators.
1610 )doc");
1611 
1612 REGISTER_OP("_MklConv2DWithBias")
1613     .Input("input: T")
1614     .Input("filter: T")
1615     .Input("bias: T")
1616     .Input("mkl_input: uint8")
1617     .Input("mkl_filter: uint8")
1618     .Input("mkl_bias: uint8")
1619     .Output("output: T")
1620     .Output("filter_output: T")
1621     .Output("mkl_output: uint8")
1622     .Output("mkl_filter_output: uint8")
1623     .Attr("T: {half, float, double}")
1624     .Attr("strides: list(int)")
1625     .Attr("use_cudnn_on_gpu: bool = true")
1626     .Attr("is_filter_const: bool = false")
1627     .Attr(GetPaddingAttrString())
1628     .Attr(GetConvnetDataFormatAttrString())
1629     .Attr("dilations: list(int) = [1, 1, 1, 1]")
1630     .SetShapeFn(shape_inference::Conv2DShape)
1631     .Doc(R"doc(
1632 MKL version of Conv2D and BiasAdd operator. Uses MKL DNN APIs to perform
1633 2D convolution and add Bias to the output of convolution.
1634 
1635 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
1636 expected to invoke these operators.
1637 )doc");
1638 
1639 REGISTER_OP("__MklDummyPadWithConv2D")
1640     .Input("input: T")
1641     .Input("filter: T")
1642     .Input("paddings: Tpaddings")
1643     .Output("output: T")
1644     .Attr("T: {half, float, double}")
1645     .Attr("strides: list(int)")
1646     .Attr("use_cudnn_on_gpu: bool = true")
1647     .Attr("is_filter_const: bool = false")
1648     .Attr(GetPaddingAttrString())
1649     .Attr(GetConvnetDataFormatAttrString())
1650     .Attr("dilations: list(int) = [1, 1, 1, 1]")
1651     .Attr("Tpaddings: {int32, int64} = DT_INT32")
1652     .SetShapeFn(shape_inference::Conv2DShape)
1653     .Doc(R"doc(
1654 Dummy node that enables fusing Pad and Conv2D operator for MKL. This node
1655 does not perform anything. It is just created as an intermediate output of
1656 merging Pad and Conv2D.
1657 
1658 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
1659 expected to invoke these operators.
1660 )doc");
1661 
1662 REGISTER_OP("_MklPadWithConv2D")
1663     .Input("input: T")
1664     .Input("filter: T")
1665     .Input("paddings: Tpaddings")
1666     .Input("mkl_input: uint8")
1667     .Input("mkl_filter: uint8")
1668     .Input("mkl_paddings: uint8")
1669     .Output("output: T")
1670     .Output("filter_output: T")
1671     .Output("mkl_output: uint8")
1672     .Output("mkl_filter_output: uint8")
1673     .Attr("T: {half, float, double}")
1674     .Attr("strides: list(int)")
1675     .Attr("use_cudnn_on_gpu: bool = true")
1676     .Attr(GetPaddingAttrString())
1677     .Attr(GetConvnetDataFormatAttrString())
1678     .Attr("is_filter_const: bool = false")
1679     .Attr("dilations: list(int) = [1, 1, 1, 1]")
1680     .Attr("Tpaddings: {int32, int64} = DT_INT32")
1681     .SetShapeFn(shape_inference::Conv2DShape)
1682     .Doc(R"doc(
1683 MKL version of Pad and Conv2D operator. Uses MKL DNN APIs to perform
1684 Pad followed by 2D convolution.
1685 
1686 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
1687 expected to invoke these operators.
1688 )doc");
1689 
1690 REGISTER_OP("_MklConv2DBackpropFilter")
1691     .Input("input: T")
1692     .Input("filter_sizes: int32")
1693     .Input("out_backprop: T")
1694     .Input("mkl_input: uint8")
1695     .Input("mkl_filter_size: uint8")
1696     .Input("mkl_out_backprop: uint8")
1697     .Output("output: T")
1698     .Output("mkl_output: uint8")
1699     .Attr("T: {half, float, double}")
1700     .Attr("strides: list(int)")
1701     .Attr("use_cudnn_on_gpu: bool = true")
1702     .Attr(GetPaddingAttrString())
1703     .Attr(GetConvnetDataFormatAttrString())
1704     .Attr("dilations: list(int) = [1, 1, 1, 1]")
1705     .SetShapeFn([](InferenceContext* c) {
1706       ShapeHandle s;
1707       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
1708       TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
1709       c->set_output(0, s);
1710       return Status::OK();
1711     })
1712     .Doc(R"doc(
1713 MKL version of Conv2DBackpropFilter. Uses MKL DNN APIs to compute the
1714 gradients of convolution with respect to the filter.
1715 
1716 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
1717 expected to invoke these operators.
1718 )doc");
1719 
1720 REGISTER_OP("__MklDummyConv2DBackpropFilterWithBias")
1721     .Input("input: T")
1722     .Input("filter_sizes: int32")
1723     .Input("out_backprop: T")
1724     .Output("output: T")
1725     .Output("bias_grad: T")
1726     .Attr("T: {half, float, double}")
1727     .Attr("strides: list(int)")
1728     .Attr("use_cudnn_on_gpu: bool = true")
1729     .Attr(GetPaddingAttrString())
1730     .Attr(GetConvnetDataFormatAttrString())
1731     .Attr("dilations: list(int) = [1, 1, 1, 1]")
1732     .SetShapeFn([](InferenceContext* c) {
1733       ShapeHandle input_shape;
1734       // Fetch the data_format attribute, which may not exist.
1735       string data_format;
1736       Status s = c->GetAttr("data_format", &data_format);
1737 
1738       if (s.ok() && data_format == "NCHW") {
1739         TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
1740         c->set_output(1, c->Vector(c->Dim(input_shape, -3)));
1741       } else {
1742         TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
1743         c->set_output(1, c->Vector(c->Dim(input_shape, -1)));
1744       }
1745       ShapeHandle sh;
1746       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &sh));
1747       TF_RETURN_IF_ERROR(c->WithRank(sh, 4, &sh));
1748       c->set_output(0, sh);
1749       return Status::OK();
1750     })
1751     .Doc(R"doc(
1752 Dummy node that enables fusing Conv2DBackpropFilter and BiasAddGrad operator
1753 for MKL. This node does not perform anything. It is just created as an
1754 intermediate output of merging Conv2DBackpropFilter and BiasAddGrad.
1755 
1756 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
1757 expected to invoke these operators.
1758 )doc");
1759 
1760 REGISTER_OP("_MklConv2DBackpropFilterWithBias")
1761     .Input("input: T")
1762     .Input("filter_sizes: int32")
1763     .Input("out_backprop: T")
1764     .Input("mkl_input: uint8")
1765     .Input("mkl_filter_size: uint8")
1766     .Input("mkl_out_backprop: uint8")
1767     .Output("output: T")
1768     .Output("bias_grad: T")
1769     .Output("mkl_output: uint8")
1770     .Output("mkl_bias_grad: uint8")
1771     .Attr("T: {half, float, double}")
1772     .Attr("strides: list(int)")
1773     .Attr("use_cudnn_on_gpu: bool = true")
1774     .Attr(GetPaddingAttrString())
1775     .Attr(GetConvnetDataFormatAttrString())
1776     .Attr("dilations: list(int) = [1, 1, 1, 1]")
1777     .SetShapeFn([](InferenceContext* c) {
1778       ShapeHandle input_shape;
1779       // Fetch the data_format attribute, which may not exist.
1780       string data_format;
1781       Status s = c->GetAttr("data_format", &data_format);
1782 
1783       if (s.ok() && data_format == "NCHW") {
1784         TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
1785         c->set_output(1, c->Vector(c->Dim(input_shape, -3)));
1786       } else {
1787         TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &input_shape));
1788         c->set_output(1, c->Vector(c->Dim(input_shape, -1)));
1789       }
1790       ShapeHandle sh;
1791       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &sh));
1792       TF_RETURN_IF_ERROR(c->WithRank(sh, 4, &sh));
1793       c->set_output(0, sh);
1794       return Status::OK();
1795     })
1796     .Doc(R"doc(
1797 MKL version of Conv2DBackpropFilterWithBias. Uses MKL DNN APIs to compute the
1798 gradients of convolution with respect to the filter and the bias.
1799 
1800 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
1801 expected to invoke these operators.
1802 )doc");
1803 
1804 #ifdef INTEL_MKL_ML_ONLY
1805 REGISTER_OP("_MklConv2DWithBiasBackpropBias")
1806     .Input("out_backprop: T")
1807     .Input("mkl_out_backprop: uint8")
1808     .Output("output: T")
1809     .Output("mkl_output: uint8")
1810     .Attr("T: {half, float, double}")
1811     .Attr("strides: list(int)")
1812     .Attr(GetConvnetDataFormatAttrString())
1813     .Attr("dilations: list(int) = [1, 1, 1, 1]")
1814     .Doc(R"doc(
1815 MKL version of Conv2DBackpropBias. Uses MKL DNN APIs to compute the
1816 gradients of convolution with respect to the bias.
1817 
1818 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
1819 expected to invoke these operators.
1820 )doc");
1821 #endif
1822 
1823 REGISTER_OP("_MklConv2DBackpropInput")
1824     .Input("input_sizes: int32")
1825     .Input("filter: T")
1826     .Input("out_backprop: T")
1827     .Input("mkl_input_sizes: uint8")
1828     .Input("mkl_filter: uint8")
1829     .Input("mkl_out_backprop: uint8")
1830     .Output("output: T")
1831     .Output("mkl_output: uint8")
1832     .Attr("T: {half, float, double}")
1833     .Attr("strides: list(int)")
1834     .Attr("use_cudnn_on_gpu: bool = true")
1835     .Attr(GetPaddingAttrString())
1836     .Attr(GetConvnetDataFormatAttrString())
1837     .Attr("dilations: list(int) = [1, 1, 1, 1]")
1838     .SetShapeFn([](InferenceContext* c) {
1839       ShapeHandle s;
1840       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
1841       TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
1842       c->set_output(0, s);
1843       return Status::OK();
1844     })
1845     .Doc(R"doc(
1846 MKL version of Convolution2D backward input. Uses MKL DNN APIs to compute the
1847 gradients of convolution with respect to the input.
1848 
1849 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
1850 expected to invoke these operators.
1851 )doc");
1852 
1853 REGISTER_OP("_MklConv3D")
1854     .Input("input: T")
1855     .Input("filter: T")
1856     .Input("mkl_input: uint8")
1857     .Input("mkl_filter: uint8")
1858     .Output("output: T")
1859     .Output("filter_output: T")
1860     .Output("mkl_output: uint8")
1861     .Output("mkl_filter_output: uint8")
1862     .Attr("T: {half, float, double}")
1863     .Attr("strides: list(int) >= 5")
1864     .Attr("is_filter_const: bool = false")
1865     .Attr(GetPaddingAttrString())
1866     .Attr(GetConvnet3dDataFormatAttrString())
1867     .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
1868     .SetShapeFn(shape_inference::Conv3DShape)
1869     .Doc(R"doc(
1870 MKL version of Conv3D operator. Uses MKL DNN APIs to perform 3D convolution.
1871 
1872 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
1873 expected to invoke these operators.
1874 )doc");
1875 
1876 REGISTER_OP("_MklConv3DBackpropInputV2")
1877     .Input("input_sizes: Tshape")
1878     .Input("filter: T")
1879     .Input("out_backprop: T")
1880     .Input("mkl_input_sizes: uint8")
1881     .Input("mkl_filter: uint8")
1882     .Input("mkl_out_backprop: uint8")
1883     .Output("output: T")
1884     .Output("mkl_output: uint8")
1885     .Attr("T: {half, float, double}")
1886     .Attr("strides: list(int) >= 5")
1887     .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
1888     .Attr("Tshape: {int32, int64} = DT_INT32")
1889     .Attr(GetPaddingAttrString())
1890     .Attr(GetConvnet3dDataFormatAttrString())
1891     .SetShapeFn([](InferenceContext* c) {
1892       ShapeHandle s;
1893       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
1894       TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
1895       c->set_output(0, s);
1896       return Status::OK();
1897     })
1898     .Doc(R"doc(
1899 MKL version of Convolution3D backward input. Uses MKL DNN APIs to compute the
1900 gradients of convolution with respect to the input.
1901 
1902 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
1903 expected to invoke these operators.
1904 )doc");
1905 
1906 REGISTER_OP("_MklConv3DBackpropFilterV2")
1907     .Input("input: T")
1908     .Input("filter_sizes: int32")
1909     .Input("out_backprop: T")
1910     .Input("mkl_input: uint8")
1911     .Input("mkl_filter_size: uint8")
1912     .Input("mkl_out_backprop: uint8")
1913     .Output("output: T")
1914     .Output("mkl_output: uint8")
1915     .Attr("T: {half, float, double}")
1916     .Attr("strides: list(int)")
1917     .Attr(GetPaddingAttrString())
1918     .Attr(GetConvnet3dDataFormatAttrString())
1919     .Attr("dilations: list(int) = [1, 1, 1, 1, 1]")
1920     .SetShapeFn([](InferenceContext* c) {
1921       ShapeHandle s;
1922       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
1923       TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
1924       c->set_output(0, s);
1925       return Status::OK();
1926     })
1927     .Doc(R"doc(
1928 MKL version of Conv3DBackpropFilter. Uses MKL DNN APIs to compute the
1929 gradients of convolution with respect to the filter.
1930 
1931 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
1932 expected to invoke these operators.
1933 )doc");
1934 
1935 REGISTER_OP("_MklRelu")
1936     .Input("features: T")
1937     .Input("mkl_features: uint8")
1938     .Output("activations: T")
1939     .Output("mkl_activations: uint8")
1940     .Attr("T: realnumbertype")
1941     .SetShapeFn(shape_inference::UnchangedShape)
1942     .Doc(R"doc(
1943 MKL version of Relu operator. Uses MKL DNN APIs to implement Relu operator.
1944 
1945 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
1946 expected to invoke these operators.
1947 )doc");
1948 
1949 REGISTER_OP("_MklReluGrad")
1950     .Input("gradients: T")
1951     .Input("features: T")
1952     .Input("mkl_gradients: uint8")
1953     .Input("mkl_features: uint8")
1954     .Output("backprops: T")
1955     .Output("mkl_backprops: uint8")
1956     .Attr("T: realnumbertype")
1957     .SetShapeFn(shape_inference::MergeBothInputsShapeFn)
1958     .Doc(R"doc(
1959 MKL version of ReluGrad operator. Uses MKL DNN APIs to compute rectified
1960 linear gradients for Relu operation.
1961 
1962 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
1963 expected to invoke these operators.
1964 )doc");
1965 
1966 REGISTER_OP("_MklRelu6")
1967     .Input("features: T")
1968     .Input("mkl_features: uint8")
1969     .Output("activations: T")
1970     .Output("mkl_activations: uint8")
1971     .Attr("T: realnumbertype")
1972     .SetShapeFn(shape_inference::UnchangedShape)
1973     .Doc(R"doc(
1974 MKL version of Relu6 operator. Uses MKL DNN APIs to implement Relu6 operator.
1975 
1976 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
1977 expected to invoke these operators.
1978 )doc");
1979 
1980 REGISTER_OP("_MklRelu6Grad")
1981     .Input("gradients: T")
1982     .Input("features: T")
1983     .Input("mkl_gradients: uint8")
1984     .Input("mkl_features: uint8")
1985     .Output("backprops: T")
1986     .Output("mkl_backprops: uint8")
1987     .Attr("T: realnumbertype")
1988     .SetShapeFn(shape_inference::MergeBothInputsShapeFn)
1989     .Doc(R"doc(
1990 MKL version of Relu6Grad operator. Uses MKL DNN APIs to compute rectified
1991 linear gradients for Relu6 operation.
1992 
1993 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
1994 expected to invoke these operators.
1995 )doc");
1996 
1997 REGISTER_OP("_MklLeakyRelu")
1998     .Input("features: T")
1999     .Input("mkl_features: uint8")
2000     .Output("activations: T")
2001     .Output("mkl_activations: uint8")
2002     .Attr("T: {half, float, double} = DT_FLOAT")
2003     .Attr("alpha: float = 0.2")
2004     .SetShapeFn(shape_inference::UnchangedShape)
2005     .Doc(R"doc(
2006 MKL version of LeakyRelu operator. Uses MKL DNN APIs to implement
2007 LeakyRelu operator.
2008 
2009 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2010 expected to invoke these operators.
2011 )doc");
2012 
2013 REGISTER_OP("_MklLeakyReluGrad")
2014     .Input("gradients: T")
2015     .Input("features: T")
2016     .Input("mkl_gradients: uint8")
2017     .Input("mkl_features: uint8")
2018     .Output("backprops: T")
2019     .Output("mkl_backprops: uint8")
2020     .Attr("T: {half, float, double} = DT_FLOAT")
2021     .Attr("alpha: float = 0.2")
2022     .SetShapeFn(shape_inference::MergeBothInputsShapeFn)
2023     .Doc(R"doc(
2024 MKL version of LeakyReluGrad operator. Uses MKL DNN APIs to compute rectified
2025 linear gradients for the LeakyRelu operation.
2026 
2027 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2028 expected to invoke these operators.
2029 )doc");
2030 
2031 REGISTER_OP("_MklElu")
2032     .Input("features: T")
2033     .Input("mkl_features: uint8")
2034     .Output("activations: T")
2035     .Output("mkl_activations: uint8")
2036     .Attr("T: realnumbertype")
2037     .SetShapeFn(shape_inference::UnchangedShape)
2038     .Doc(R"doc(
2039 MKL version of Elu operator. Uses MKL DNN APIs to implement Elu operator.
2040 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2041 expected to invoke these operators.
2042 )doc");
2043 
2044 REGISTER_OP("_MklEluGrad")
2045     .Input("gradients: T")
2046     .Input("features: T")
2047     .Input("mkl_gradients: uint8")
2048     .Input("mkl_features: uint8")
2049     .Output("backprops: T")
2050     .Output("mkl_backprops: uint8")
2051     .Attr("T: realnumbertype")
2052     .SetShapeFn(shape_inference::MergeBothInputsShapeFn)
2053     .Doc(R"doc(
2054 MKL version of EluGrad operator. Uses MKL DNN APIs to compute Elu
2055 gradients for Elu operation.
2056 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2057 expected to invoke these operators.
2058 )doc");
2059 
2060 REGISTER_OP("_MklSoftmax")
2061     .Input("logits: T")
2062     .Input("mkl_logits: uint8")
2063     .Output("softmax: T")
2064     .Output("mkl_softmax: uint8")
2065     .Attr("T: {half, float, double}")
2066     .SetShapeFn([](InferenceContext* c) {
2067       return shape_inference::UnchangedShapeWithRankAtLeast(c, 1);
2068     })
2069     .Doc(R"doc(
2070 MKL version of Softmax operator. Uses MKL DNN APIs to implement the Softmax
2071 operator.
2072 )doc");
2073 
2074 REGISTER_OP("_MklTanh")
2075     .Input("features: T")
2076     .Input("mkl_features: uint8")
2077     .Output("activations: T")
2078     .Output("mkl_activations: uint8")
2079     .Attr("T: realnumbertype")
2080     .SetShapeFn(shape_inference::UnchangedShape)
2081     .Doc(R"doc(
2082 MKL version of Tanh operator. Uses MKL DNN APIs to implement Tanh operator.
2083 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2084 expected to invoke these operators.
2085 )doc");
2086 
2087 REGISTER_OP("_MklTanhGrad")
2088     .Input("gradients: T")
2089     .Input("features: T")
2090     .Input("mkl_gradients: uint8")
2091     .Input("mkl_features: uint8")
2092     .Output("backprops: T")
2093     .Output("mkl_backprops: uint8")
2094     .Attr("T: realnumbertype")
2095     .SetShapeFn(shape_inference::MergeBothInputsShapeFn)
2096     .Doc(R"doc(
2097 MKL version of TanhGrad operator. Uses MKL DNN APIs to compute tanh
2098 gradients for Tanh operation.
2099 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2100 expected to invoke these operators.
2101 )doc");
2102 
2103 REGISTER_OP("_MklMaxPool")
2104     .Attr("T: {float, half} = DT_FLOAT")
2105     .Attr("ksize: list(int) >= 4")
2106     .Attr("strides: list(int) >= 4")
2107     .Attr(GetPaddingAttrString())
2108     .Attr(GetConvnetDataFormatAttrString())
2109     .Attr("workspace_enabled: bool = false")
2110     .Input("input: T")
2111     .Input("mkl_input: uint8")
2112     .Output("output: T")
2113 #ifdef INTEL_MKL_ML_ONLY
2114     .Output("workspace: T")
2115 #else
2116     .Output("workspace: uint8")
2117 #endif
2118     .Output("mkl_output: uint8")
2119     .Output("mkl_workspace: uint8")
2120     .SetShapeFn(shape_inference::MaxPoolShape)
2121     .Doc(R"doc(
2122 MKL version of MaxPool operator. Uses MKL DNN APIs to perform max pooling
2123 on the input.
2124 
2125 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2126 expected to invoke these operators.
2127 )doc");
2128 
2129 REGISTER_OP("_MklMaxPoolGrad")
2130     .Attr("T: {float, half} = DT_FLOAT")
2131     .Attr("ksize: list(int) >= 4")
2132     .Attr("strides: list(int) >= 4")
2133     .Attr("workspace_enabled: bool = false")
2134     .Attr(GetPaddingAttrString())
2135     .Attr(GetConvnetDataFormatAttrString())
2136     .Input("orig_input: T")
2137     .Input("orig_output: T")
2138     .Input("grad: T")
2139 #ifdef INTEL_MKL_ML_ONLY
2140     .Input("workspace: T")
2141 #else
2142     .Input("workspace: uint8")
2143 #endif
2144     .Input("mkl_orig_input: uint8")
2145     .Input("mkl_orig_output: uint8")
2146     .Input("mkl_grad: uint8")
2147     .Input("mkl_workspace: uint8")
2148     .Output("output: T")
2149     .Output("mkl_output: uint8")
2150     .SetShapeFn([](InferenceContext* c) {
2151       return UnchangedShapeWithRank(c, 4);
2152     })
2153     .Doc(R"doc(
2154 MKL version of MaxPoolGrad. Uses MKL DNN APIs to compute gradients of
2155 MaxPool operator.
2156 
2157 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2158 expected to invoke these operators.
2159 )doc");
2160 
2161 REGISTER_OP("_MklAvgPool")
2162     .Input("value: T")
2163     .Input("mkl_input: uint8")
2164     .Output("output: T")
2165     .Output("mkl_output: uint8")
2166     .Attr("ksize: list(int) >= 4")
2167     .Attr("strides: list(int) >= 4")
2168     .Attr(GetPaddingAttrString())
2169     .Attr(GetConvnetDataFormatAttrString())
2170     .Attr("T: {float, half, double}")
2171     .SetShapeFn(shape_inference::AvgPoolShape)
2172     .Doc(R"doc(
2173 MKL version of AvgPool operator. Uses MKL DNN APIs to perform average pooling
2174 on the input.
2175 
2176 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2177 expected to invoke these operators.
2178 )doc");
2179 
2180 REGISTER_OP("_MklAvgPoolGrad")
2181     .Input("orig_input_shape: int32")
2182     .Input("grad: T")
2183     .Input("mkl_orig_input: uint8")
2184     .Input("mkl_grad: uint8")
2185     .Output("output: T")
2186     .Output("mkl_output: uint8")
2187     .Attr("ksize: list(int) >= 4")
2188     .Attr("strides: list(int) >= 4")
2189     .Attr(GetPaddingAttrString())
2190     .Attr(GetConvnetDataFormatAttrString())
2191     .Attr("T: {float, half, double}")
2192     .SetShapeFn([](InferenceContext* c) {
2193       ShapeHandle s;
2194       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
2195       TF_RETURN_IF_ERROR(c->WithRank(s, 4, &s));
2196       c->set_output(0, s);
2197       return Status::OK();
2198     })
2199     .Doc(R"doc(
2200 MKL version of AvgPoolGrad operator. Uses MKL DNN APIs to compute gradients
2201 of AvgPool function.
2202 
2203 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2204 expected to invoke these operators.
2205 )doc");
2206 
2207 REGISTER_OP("_MklAvgPool3D")
2208     .Input("value: T")
2209     .Input("mkl_input: uint8")
2210     .Output("output: T")
2211     .Output("mkl_output: uint8")
2212     .Attr("ksize: list(int) >= 5")
2213     .Attr("strides: list(int) >= 5")
2214     .Attr(GetPaddingAttrString())
2215     .Attr(GetConvnet3dDataFormatAttrString())
2216     .Attr("T: {float, half, double}")
2217     .SetShapeFn(shape_inference::Pool3DShape)
2218     .Doc(R"doc(
2219 MKL version of AvgPool3D operator. Uses MKL DNN APIs to perform average pooling
2220 on the input.
2221 
2222 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2223 expected to invoke these operators.
2224 )doc");
2225 
2226 REGISTER_OP("_MklAvgPool3DGrad")
2227     .Input("orig_input_shape: int32")
2228     .Input("grad: T")
2229     .Input("mkl_orig_input: uint8")
2230     .Input("mkl_grad: uint8")
2231     .Output("output: T")
2232     .Output("mkl_output: uint8")
2233     .Attr("ksize: list(int) >= 5")
2234     .Attr("strides: list(int) >= 5")
2235     .Attr(GetPaddingAttrString())
2236     .Attr(GetConvnet3dDataFormatAttrString())
2237     .Attr("T: {float, half, double}")
2238     .SetShapeFn([](InferenceContext* c) {
2239       ShapeHandle s;
2240       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
2241       TF_RETURN_IF_ERROR(c->WithRank(s, 5, &s));
2242       c->set_output(0, s);
2243       return Status::OK();
2244     })
2245     .Doc(R"doc(
2246 MKL version of AvgPool3DGrad operator. Uses MKL DNN APIs to compute gradients
2247 of the AvgPool3D function.
2248 
2249 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2250 expected to invoke these operators.
2251 )doc");
2252 
2253 REGISTER_OP("_MklMaxPool3D")
2254     .Input("input: T")
2255     .Input("mkl_input: uint8")
2256     .Output("output: T")
2257     .Output("workspace: uint8")
2258     .Output("mkl_output: uint8")
2259     .Output("mkl_workspace: uint8")
2260     .Attr("ksize: list(int) >= 5")
2261     .Attr("strides: list(int) >= 5")
2262     .Attr(GetPaddingAttrString())
2263     .Attr(GetConvnet3dDataFormatAttrString())
2264     .Attr("T: {half, bfloat16, float}")
2265     .Attr("workspace_enabled: bool = false")
2266     .SetShapeFn(shape_inference::Pool3DShape)
2267     .Doc(R"doc(
2268 MKL version of MaxPool3D operator. Uses MKL DNN APIs to perform max pooling
2269 on the input.
2270 
2271 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2272 expected to invoke these operators.
2273 )doc");
2274 
2275 REGISTER_OP("_MklMaxPool3DGrad")
2276     .Input("orig_input: TInput")
2277     .Input("orig_output: TInput")
2278     .Input("grad: T")
2279     .Input("workspace: uint8")
2280     .Input("mkl_orig_input: uint8")
2281     .Input("mkl_orig_output: uint8")
2282     .Input("mkl_grad: uint8")
2283     .Input("mkl_workspace: uint8")
2284     .Output("output: T")
2285     .Output("mkl_output: uint8")
2286     .Attr("ksize: list(int) >= 5")
2287     .Attr("strides: list(int) >= 5")
2288     .Attr(GetPaddingAttrString())
2289     .Attr(GetConvnet3dDataFormatAttrString())
2290     .Attr("T: {half, bfloat16, float} = DT_FLOAT")
2291     .Attr("TInput: {half, bfloat16, float} = DT_FLOAT")
2292     .Attr("workspace_enabled: bool = false")
2293     .SetShapeFn([](InferenceContext* c) {
2294       return UnchangedShapeWithRank(c, 5);
2295     })
2296     .Doc(R"doc(
2297 MKL version of MaxPool3DGrad operator. Uses MKL DNN APIs to compute gradients
2298 of the MaxPool3D function.
2299 
2300 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2301 expected to invoke these operators.
2302 )doc");
2303 
2304 REGISTER_OP("_MklLRN")
2305     .Input("input: T")
2306     .Input("mkl_input: uint8")
2307     .Output("output: T")
2308     .Output("workspace: uint8")
2309     .Output("mkl_output: uint8")
2310     .Output("mkl_workspace: uint8")
2311     .Attr("depth_radius: int = 5")
2312     .Attr("bias: float = 1.0")
2313     .Attr("alpha: float = 1.0")
2314     .Attr("beta: float = 0.5")
2315     .Attr("workspace_enabled: bool = false")
2316     .Attr("T: {float, half} = DT_FLOAT")
2317     .SetShapeFn([](InferenceContext* c) {
2318       return UnchangedShapeWithRank(c, 4);
2319     })
2320     .Doc(R"doc(
2321 MKL version of LRN operator. Uses MKL DNN APIs to perform local response
2322 normalization.
2323 
2324 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2325 expected to invoke these operators.
2326 )doc");
2327 
2328 REGISTER_OP("_MklLRNGrad")
2329     .Input("input_grads: T")
2330     .Input("input_image: T")
2331     .Input("output_image: T")
2332     .Input("workspace: uint8")
2333     .Input("mkl_input_grads: uint8")
2334     .Input("mkl_input_image: uint8")
2335     .Input("mkl_output_image: uint8")
2336     .Input("mkl_workspace: uint8")
2337     .Output("output: T")
2338     .Output("mkl_output: uint8")
2339     .Attr("depth_radius: int = 5")
2340     .Attr("bias: float = 1.0")
2341     .Attr("alpha: float = 1.0")
2342     .Attr("beta: float = 0.5")
2343     .Attr("workspace_enabled: bool = false")
2344     .Attr("T: {float, half} = DT_FLOAT")
2345     .SetShapeFn([](InferenceContext* c) {
2346       ShapeHandle s;
2347       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &s));  // input_grads
2348       TF_RETURN_IF_ERROR(c->Merge(s, c->input(1), &s));     // input_image
2349       TF_RETURN_IF_ERROR(c->Merge(s, c->input(2), &s));     // output_image
2350       c->set_output(0, s);
2351       return Status::OK();
2352     })
2353     .Doc(R"doc(
2354 MKL version of LRNGrad operator. Uses MKL DNN APIs to compute gradient for
2355 local response normalization.
2356 
2357 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2358 expected to invoke these operators.
2359 )doc");
2360 
2361 REGISTER_OP("_MklFusedBatchNorm")
2362     .Input("x: T")
2363     .Input("scale: T")
2364     .Input("offset: T")
2365     .Input("mean: T")
2366     .Input("variance: T")
2367     .Input("mkl_x: uint8")
2368     .Input("mkl_scale: uint8")
2369     .Input("mkl_offset: uint8")
2370     .Input("mkl_mean: uint8")
2371     .Input("mkl_variance: uint8")
2372     .Output("y: T")
2373     .Output("batch_mean: T")
2374     .Output("batch_variance: T")
2375     .Output("reserve_space_1: T")
2376     .Output("reserve_space_2: T")
2377     .Output("mkl_y: uint8")
2378     .Output("mkl_batch_mean: uint8")
2379     .Output("mkl_batch_variance: uint8")
2380     .Output("mkl_reserve_space_1: uint8")
2381     .Output("mkl_reserve_space_2: uint8")
2382     .Attr("T: numbertype")
2383     .Attr("epsilon: float = 0.0001")
2384     .Attr("data_format: string = 'NHWC'")
2385     .Attr("is_training: bool = true")
2386     .SetShapeFn([](InferenceContext* c) {
2387       ShapeHandle x;
2388       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &x));
2389 
2390       bool is_training;
2391       TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
2392       int number_inputs = (is_training) ? 3 : 5;
2393       string data_format;
2394       TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format));
2395       DimensionHandle channel_dim =
2396           (data_format == "NHWC") ? c->Dim(x, 3) : c->Dim(x, 1);
2397 
2398       // covers scale, offset, and if is_training is false, mean, variance
2399       for (int i = 1; i < number_inputs; ++i) {
2400         ShapeHandle vec;
2401         TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
2402         TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(vec, 0), &channel_dim));
2403       }
2404 
2405       ShapeHandle y;
2406       if (data_format == "NHWC") {
2407         TF_RETURN_IF_ERROR(c->ReplaceDim(x, 3, channel_dim, &y));
2408       } else {
2409         TF_RETURN_IF_ERROR(c->ReplaceDim(x, 1, channel_dim, &y));
2410       }
2411       c->set_output(0, y);
2412       ShapeHandle vector_shape = c->Vector(channel_dim);
2413       c->set_output(1, vector_shape);
2414       c->set_output(2, vector_shape);
2415       c->set_output(3, vector_shape);
2416       c->set_output(4, vector_shape);
2417       return Status::OK();
2418     })
2419     .Doc(R"doc(
2420 MKL version of FusedBatchNorm operator. Uses MKL DNN APIs to perform fused
2421 batch normalization.
2422 
2423 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2424 expected to invoke these operators.
2425 )doc");
2426 
2427 REGISTER_OP("_MklFusedBatchNormGrad")
2428     .Input("y_backprop: T")
2429     .Input("x: T")
2430     .Input("scale: T")
2431     .Input("reserve_space_1: T")
2432     .Input("reserve_space_2: T")
2433     .Input("mkl_y_backprop: uint8")
2434     .Input("mkl_x: uint8")
2435     .Input("mkl_scale: uint8")
2436     .Input("mkl_reserve_space_1: uint8")
2437     .Input("mkl_reserve_space_2: uint8")
2438     .Output("x_backprop: T")
2439     .Output("scale_backprop: T")
2440     .Output("offset_backprop: T")
2441     .Output("reserve_space_3: T")
2442     .Output("reserve_space_4: T")
2443     .Output("mkl_x_backprop: uint8")
2444     .Output("mkl_scale_backprop: uint8")
2445     .Output("mkl_offset_backprop: uint8")
2446     .Output("mkl_reserve_space_3: uint8")
2447     .Output("mkl_reserve_space_4: uint8")
2448     .Attr("T: numbertype")
2449     .Attr("epsilon: float = 0.0001")
2450     .Attr("data_format: string = 'NHWC'")
2451     .Attr("is_training: bool = true")
2452     .SetShapeFn([](InferenceContext* c) {
2453       ShapeHandle y_backprop;
2454       TF_RETURN_IF_ERROR(c->WithRank(c->input(0), 4, &y_backprop));
2455       ShapeHandle x;
2456       TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 4, &x));
2457 
2458       bool is_training;
2459       string data_format;
2460       TF_RETURN_IF_ERROR(c->GetAttr("is_training", &is_training));
2461       TF_RETURN_IF_ERROR(c->GetAttr("data_format", &data_format));
2462       DimensionHandle channel_dim = (data_format == "NHWC")
2463                                         ? c->Dim(y_backprop, 3)
2464                                         : c->Dim(y_backprop, 1);
2465       if (data_format == "NHWC") {
2466         TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(x, 3), &channel_dim));
2467       } else {
2468         TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(x, 1), &channel_dim));
2469       }
2470 
2471       // covers scale, mean (reserve_space_1), variance (reserve_space_2)
2472       for (int i = 2; i < 5; ++i) {
2473         ShapeHandle vec;
2474         TF_RETURN_IF_ERROR(c->WithRank(c->input(i), 1, &vec));
2475         TF_RETURN_IF_ERROR(c->Merge(channel_dim, c->Dim(vec, 0), &channel_dim));
2476       }
2477 
2478       ShapeHandle x_backprop;
2479       if (data_format == "NHWC") {
2480         TF_RETURN_IF_ERROR(
2481             c->ReplaceDim(y_backprop, 3, channel_dim, &x_backprop));
2482       } else {
2483         TF_RETURN_IF_ERROR(
2484             c->ReplaceDim(y_backprop, 1, channel_dim, &x_backprop));
2485       }
2486       c->set_output(0, x_backprop);
2487       c->set_output(1, c->Vector(channel_dim));
2488       c->set_output(2, c->Vector(channel_dim));
2489       // Set the correct shapes for the reserve spaces so that
2490       // gradients can still be computed when the op is
2491       // evaluated symbolically.
2492       if (is_training) {
2493         c->set_output(3, c->Vector(0));
2494         c->set_output(4, c->Vector(0));
2495       } else {
2496         c->set_output(3, c->Vector(channel_dim));
2497         c->set_output(4, c->Vector(channel_dim));
2498       }
2499       return Status::OK();
2500     })
2501     .Doc(R"doc(
2502 MKL version of FusedBatchNormGrad operator. Uses MKL DNN APIs to compute
2503 gradients for fused batch normalization.
2504 
2505 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2506 expected to invoke these operators.
2507 )doc");
2508 
2509 REGISTER_OP("_MklToTf")
2510     .Input("input: T")
2511     .Input("mkl_input: uint8")
2512     .Output("output: T")
2513     .Attr("T: {half, float, double, qint8, quint8, qint32}")
2514     .Attr(GetConvnetDataFormat2D3DAttrString())
2515     .SetShapeFn(shape_inference::UnknownShape)
2516     .Doc(R"doc(
2517 MKL operator to convert a tensor from MKL layout to TensorFlow layout.
2518 
2519 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2520 expected to invoke these operators.
2521 )doc");
2522 
2523 REGISTER_OP("_MklInputConversion")
2524     .Input("input_0: T")
2525     .Input("input_1: T")
2526     .Input("mkl_input_0: uint8")
2527     .Input("mkl_input_1: uint8")
2528     .Output("output_0: T")
2529     .Output("output_1: T")
2530     .Output("mkl_output_0: uint8")
2531     .Output("mkl_output_1: uint8")
2532     // All datatypes supported by element-wise ops
2533     .Attr(
2534         "T: {half, float, double, uint8, int8, uint16, int16, int32, int64, "
2535         "complex64, complex128}")
2536     .Attr(GetConvnetDataFormat2D3DAttrString())
2537     .SetShapeFn(shape_inference::UnknownShape)
2538     .Doc(R"doc(
2539 MKL operator to process the inputs to an elementwise MKL op. Both inputs
2540 need to be either in TF or in MKL format. This op is added before every
2541 element-wise MKL op.
2542 
2543 NOTE Do not invoke this operator directly in Python. Graph rewrite pass is
2544 expected to invoke these operators.
2545 )doc");
2546 
2547 #endif  // INTEL_MKL
2548 REGISTER_OP("QuantizedConv2DAndRequantize")
2549     .Input("input: Tinput")
2550     .Input("filter: Tfilter")
2551     .Input("min_input: float")
2552     .Input("max_input: float")
2553     .Input("min_filter: float")
2554     .Input("max_filter: float")
2555     .Input("min_freezed_output: float")
2556     .Input("max_freezed_output: float")
2557     .Output("output: out_type")
2558     .Output("min_output: float")
2559     .Output("max_output: float")
2560     .Attr("Tinput: quantizedtype")
2561     .Attr("Tfilter: quantizedtype")
2562     .Attr("out_type: quantizedtype = DT_QINT8")
2563     .Attr("strides: list(int)")
2564     .Attr(GetPaddingAttrString())
2565     .Attr("dilations: list(int) = [1, 1, 1, 1]")
2566     .Attr("padding_list: list(int) = []")
2567     .SetShapeFn([](InferenceContext* c) {
2568       TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
2569       ShapeHandle unused;
2570       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
2571       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
2572       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
2573       TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
2574       TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
2575       TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
2576       c->set_output(1, c->Scalar());
2577       c->set_output(2, c->Scalar());
2578       return Status::OK();
2579     });
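// Relative to QuantizedConv2D, the *AndRequantize fusions take two extra
// scalar inputs, min_freezed_output and max_freezed_output, which give the
// pre-computed range to requantize into, and they default out_type to a
// narrow quantized type (qint8 here, quint8 for the Relu variants) rather
// than the qint32 produced by the non-requantizing ops.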
2580 
2581 // Fusion of Quantized Conv2D and BiasAdd.
2582 REGISTER_OP("QuantizedConv2DWithBias")
2583     .Input("input: Tinput")
2584     .Input("filter: Tfilter")
2585     .Input("bias: float")
2586     .Input("min_input: float")
2587     .Input("max_input: float")
2588     .Input("min_filter: float")
2589     .Input("max_filter: float")
2590     .Output("output: out_type")
2591     .Output("min_output: float")
2592     .Output("max_output: float")
2593     .Attr("Tinput: quantizedtype")
2594     .Attr("Tfilter: quantizedtype")
2595     .Attr("out_type: quantizedtype = DT_QINT32")
2596     .Attr("strides: list(int)")
2597     .Attr(GetPaddingAttrString())
2598     .Attr("dilations: list(int) = [1, 1, 1, 1]")
2599     .Attr("padding_list: list(int) = []")
2600     .SetShapeFn([](InferenceContext* c) {
2601       TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
2602       ShapeHandle unused;
2603       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
2604       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
2605       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
2606       TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
2607       TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
2608       c->set_output(1, c->Scalar());
2609       c->set_output(2, c->Scalar());
2610       return Status::OK();
2611     });
2612 
2613 REGISTER_OP("QuantizedConv2DWithBiasAndRequantize")
2614     .Input("input: Tinput")
2615     .Input("filter: Tfilter")
2616     .Input("bias: Tbias")
2617     .Input("min_input: float")
2618     .Input("max_input: float")
2619     .Input("min_filter: float")
2620     .Input("max_filter: float")
2621     .Input("min_freezed_output: float")
2622     .Input("max_freezed_output: float")
2623     .Output("output: out_type")
2624     .Output("min_output: float")
2625     .Output("max_output: float")
2626     .Attr("Tinput: quantizedtype")
2627     .Attr("Tfilter: quantizedtype")
2628     .Attr("Tbias: {float, qint32}")
2629     .Attr("out_type: quantizedtype = DT_QINT8")
2630     .Attr("strides: list(int)")
2631     .Attr(GetPaddingAttrString())
2632     .Attr("dilations: list(int) = [1, 1, 1, 1]")
2633     .Attr("padding_list: list(int) = []")
2634     .SetShapeFn([](InferenceContext* c) {
2635       TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
2636       ShapeHandle unused;
2637       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
2638       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
2639       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
2640       TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
2641       TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
2642       TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
2643       TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused));
2644       c->set_output(1, c->Scalar());
2645       c->set_output(2, c->Scalar());
2646       return Status::OK();
2647     });
2648 
2649 // Fusion of Quantized Conv2D and Relu.
2650 REGISTER_OP("QuantizedConv2DAndRelu")
2651     .Input("input: Tinput")
2652     .Input("filter: Tfilter")
2653     .Input("min_input: float")
2654     .Input("max_input: float")
2655     .Input("min_filter: float")
2656     .Input("max_filter: float")
2657     .Output("output: out_type")
2658     .Output("min_output: float")
2659     .Output("max_output: float")
2660     .Attr("Tinput: quantizedtype")
2661     .Attr("Tfilter: quantizedtype")
2662     .Attr("out_type: quantizedtype = DT_QINT32")
2663     .Attr("strides: list(int)")
2664     .Attr(GetPaddingAttrString())
2665     .Attr("dilations: list(int) = [1, 1, 1, 1]")
2666     .Attr("padding_list: list(int) = []")
2667     .SetShapeFn([](InferenceContext* c) {
2668       TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
2669       ShapeHandle unused;
2670       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
2671       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
2672       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
2673       TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
2674       c->set_output(1, c->Scalar());
2675       c->set_output(2, c->Scalar());
2676       return Status::OK();
2677     });
2678 
2679 REGISTER_OP("QuantizedConv2DAndReluAndRequantize")
2680     .Input("input: Tinput")
2681     .Input("filter: Tfilter")
2682     .Input("min_input: float")
2683     .Input("max_input: float")
2684     .Input("min_filter: float")
2685     .Input("max_filter: float")
2686     .Input("min_freezed_output: float")
2687     .Input("max_freezed_output: float")
2688     .Output("output: out_type")
2689     .Output("min_output: float")
2690     .Output("max_output: float")
2691     .Attr("Tinput: quantizedtype")
2692     .Attr("Tfilter: quantizedtype")
2693     .Attr("out_type: quantizedtype = DT_QUINT8")
2694     .Attr("strides: list(int)")
2695     .Attr(GetPaddingAttrString())
2696     .Attr("dilations: list(int) = [1, 1, 1, 1]")
2697     .Attr("padding_list: list(int) = []")
2698     .SetShapeFn([](InferenceContext* c) {
2699       TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
2700       ShapeHandle unused;
2701       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &unused));
2702       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
2703       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
2704       TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
2705       TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
2706       TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
2707       c->set_output(1, c->Scalar());
2708       c->set_output(2, c->Scalar());
2709       return Status::OK();
2710     });
2711 
2712 // Fusion of Quantized Conv2D, BiasAdd and Relu.
2713 REGISTER_OP("QuantizedConv2DWithBiasAndRelu")
2714     .Input("input: Tinput")
2715     .Input("filter: Tfilter")
2716     .Input("bias: float")
2717     .Input("min_input: float")
2718     .Input("max_input: float")
2719     .Input("min_filter: float")
2720     .Input("max_filter: float")
2721     .Output("output: out_type")
2722     .Output("min_output: float")
2723     .Output("max_output: float")
2724     .Attr("Tinput: quantizedtype")
2725     .Attr("Tfilter: quantizedtype")
2726     .Attr("out_type: quantizedtype = DT_QINT32")
2727     .Attr("strides: list(int)")
2728     .Attr(GetPaddingAttrString())
2729     .Attr("dilations: list(int) = [1, 1, 1, 1]")
2730     .Attr("padding_list: list(int) = []")
2731     .SetShapeFn([](InferenceContext* c) {
2732       TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
2733       ShapeHandle unused;
2734       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
2735       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
2736       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
2737       TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
2738       TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
2739       c->set_output(1, c->Scalar());
2740       c->set_output(2, c->Scalar());
2741       return Status::OK();
2742     });
2743 
2744 // Fusion of Quantized Conv2D, BiasAdd, Relu, and Requantize.
2745 REGISTER_OP("QuantizedConv2DWithBiasAndReluAndRequantize")
2746     .Input("input: Tinput")
2747     .Input("filter: Tfilter")
2748     .Input("bias: Tbias")
2749     .Input("min_input: float")
2750     .Input("max_input: float")
2751     .Input("min_filter: float")
2752     .Input("max_filter: float")
2753     .Input("min_freezed_output: float")
2754     .Input("max_freezed_output: float")
2755     .Output("output: out_type")
2756     .Output("min_output: float")
2757     .Output("max_output: float")
2758     .Attr("Tinput: quantizedtype")
2759     .Attr("Tfilter: quantizedtype")
2760     .Attr("Tbias: {float, qint32}")
2761     .Attr("out_type: quantizedtype = DT_QUINT8")
2762     .Attr("strides: list(int)")
2763     .Attr(GetPaddingAttrString())
2764     .Attr("dilations: list(int) = [1, 1, 1, 1]")
2765     .Attr("padding_list: list(int) = []")
2766     .SetShapeFn([](InferenceContext* c) {
2767       TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
2768       ShapeHandle unused;
2769       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
2770       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
2771       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
2772       TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
2773       TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
2774       TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
2775       TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused));
2776       c->set_output(1, c->Scalar());
2777       c->set_output(2, c->Scalar());
2778       return Status::OK();
2779     });
2780 
2781 // Fusion of Quantized Conv2D, BiasAdd, Sum, and Relu.
2782 REGISTER_OP("QuantizedConv2DWithBiasSumAndRelu")
2783     .Input("input: Tinput")
2784     .Input("filter: Tfilter")
2785     .Input("bias: float")
2786     .Input("min_input: float")
2787     .Input("max_input: float")
2788     .Input("min_filter: float")
2789     .Input("max_filter: float")
2790     .Input("summand: float")
2791     .Output("output: out_type")
2792     .Output("min_output: float")
2793     .Output("max_output: float")
2794     .Attr("Tinput: quantizedtype")
2795     .Attr("Tfilter: quantizedtype")
2796     .Attr("out_type: quantizedtype = DT_QINT32")
2797     .Attr("strides: list(int)")
2798     .Attr(GetPaddingAttrString())
2799     .Attr("dilations: list(int) = [1, 1, 1, 1]")
2800     .Attr("padding_list: list(int) = []")
2801     .SetShapeFn([](InferenceContext* c) {
2802       TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
2803       ShapeHandle unused;
2804       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
2805       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
2806       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
2807       TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
2808       TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
2809       c->set_output(1, c->Scalar());
2810       c->set_output(2, c->Scalar());
2811       return Status::OK();
2812     });
2813 
2814 REGISTER_OP("QuantizedConv2DWithBiasSumAndReluAndRequantize")
2815     .Input("input: Tinput")
2816     .Input("filter: Tfilter")
2817     .Input("bias: Tbias")
2818     .Input("min_input: float")
2819     .Input("max_input: float")
2820     .Input("min_filter: float")
2821     .Input("max_filter: float")
2822     .Input("min_freezed_output: float")
2823     .Input("max_freezed_output: float")
2824     .Input("summand: Tsummand")
2825     .Input("min_summand: float")
2826     .Input("max_summand: float")
2827     .Output("output: out_type")
2828     .Output("min_output: float")
2829     .Output("max_output: float")
2830     .Attr("Tinput: quantizedtype")
2831     .Attr("Tfilter: quantizedtype")
2832     .Attr("Tbias: {float, qint32}")
2833     .Attr("Tsummand: quantizedtype")
2834     .Attr("out_type: quantizedtype = DT_QUINT8")
2835     .Attr("strides: list(int)")
2836     .Attr(GetPaddingAttrString())
2837     .Attr("dilations: list(int) = [1, 1, 1, 1]")
2838     .Attr("padding_list: list(int) = []")
2839     .SetShapeFn([](InferenceContext* c) {
2840       TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
2841       ShapeHandle unused;
2842       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
2843       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
2844       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
2845       TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
2846       TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
2847       TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
2848       TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused));
2849       c->set_output(1, c->Scalar());
2850       c->set_output(2, c->Scalar());
2851       return Status::OK();
2852     });
2853 
2854 REGISTER_OP("QuantizedConv2DWithBiasSignedSumAndReluAndRequantize")
2855     .Input("input: Tinput")
2856     .Input("filter: Tfilter")
2857     .Input("bias: Tbias")
2858     .Input("min_input: float")
2859     .Input("max_input: float")
2860     .Input("min_filter: float")
2861     .Input("max_filter: float")
2862     .Input("min_freezed_output: float")
2863     .Input("max_freezed_output: float")
2864     .Input("summand: Tsummand")
2865     .Input("min_summand: float")
2866     .Input("max_summand: float")
2867     .Output("output: out_type")
2868     .Output("min_output: float")
2869     .Output("max_output: float")
2870     .Attr("Tinput: quantizedtype")
2871     .Attr("Tfilter: quantizedtype")
2872     .Attr("Tbias: {float, qint32}")
2873     .Attr("Tsummand: quantizedtype")
2874     .Attr("out_type: quantizedtype = DT_QUINT8")
2875     .Attr("strides: list(int)")
2876     .Attr(GetPaddingAttrString())
2877     .Attr("dilations: list(int) = [1, 1, 1, 1]")
2878     .Attr("padding_list: list(int) = []")
2879     .SetShapeFn([](InferenceContext* c) {
2880       TF_RETURN_IF_ERROR(shape_inference::Conv2DShape(c));
2881       ShapeHandle unused;
2882       TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &unused));
2883       TF_RETURN_IF_ERROR(c->WithRank(c->input(3), 0, &unused));
2884       TF_RETURN_IF_ERROR(c->WithRank(c->input(4), 0, &unused));
2885       TF_RETURN_IF_ERROR(c->WithRank(c->input(5), 0, &unused));
2886       TF_RETURN_IF_ERROR(c->WithRank(c->input(6), 0, &unused));
2887       TF_RETURN_IF_ERROR(c->WithRank(c->input(7), 0, &unused));
2888       TF_RETURN_IF_ERROR(c->WithRank(c->input(8), 0, &unused));
2889       c->set_output(1, c->Scalar());
2890       c->set_output(2, c->Scalar());
2891       return Status::OK();
2892     });
2893 
2894 }  // namespace tensorflow
2895