1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/op_types.h"
17
18 #include "tensorflow/core/framework/attr_value.pb.h"
19 #include "tensorflow/core/framework/op.h"
20 #include "tensorflow/core/framework/types.h"
21 #include "tensorflow/core/grappler/utils.h"
22 #include "tensorflow/core/lib/core/status.h"
23 #include "tensorflow/core/lib/gtl/flatset.h"
24 #include "tensorflow/core/lib/strings/str_util.h"
25 #include "tensorflow/core/platform/logging.h"
26
27 namespace tensorflow {
28 namespace grappler {
29
IsAdd(const NodeDef & node)30 bool IsAdd(const NodeDef& node) {
31 if (node.op() == "AddV2") {
32 return true;
33 }
34 if (node.op() == "Add") {
35 DataType type = node.attr().at("T").type();
36 return type != DT_STRING;
37 }
38 return false;
39 }
40
IsAddN(const NodeDef & node)41 bool IsAddN(const NodeDef& node) { return node.op() == "AddN"; }
42
IsAll(const NodeDef & node)43 bool IsAll(const NodeDef& node) { return node.op() == "All"; }
44
IsAngle(const NodeDef & node)45 bool IsAngle(const NodeDef& node) { return node.op() == "Angle"; }
46
IsAny(const NodeDef & node)47 bool IsAny(const NodeDef& node) { return node.op() == "Any"; }
48
IsAnyDiv(const NodeDef & node)49 bool IsAnyDiv(const NodeDef& node) {
50 return node.op() == "RealDiv" || node.op() == "Div" || node.op() == "Xdivy" ||
51 node.op() == "FloorDiv" || node.op() == "TruncateDiv";
52 }
53
IsAnyBatchMatMul(const NodeDef & node)54 bool IsAnyBatchMatMul(const NodeDef& node) {
55 return node.op() == "BatchMatMul" || node.op() == "BatchMatMulV2";
56 }
57
IsAnyMatMul(const NodeDef & node)58 bool IsAnyMatMul(const NodeDef& node) {
59 return node.op() == "MatMul" || node.op() == "SparseMatMul" ||
60 IsAnyBatchMatMul(node) || IsQuantizedMatMul(node);
61 }
62
IsAnyMax(const NodeDef & node)63 bool IsAnyMax(const NodeDef& node) {
64 const auto& op = node.op();
65 return op == "Max" || op == "SegmentMax" || op == "UnsortedSegmentMax";
66 }
67
IsAnyMaxPool(const NodeDef & node)68 bool IsAnyMaxPool(const NodeDef& node) {
69 const auto& op = node.op();
70 return op == "MaxPool" || op == "MaxPoolV2" || op == "MaxPool3D" ||
71 op == "MaxPoolWithArgmax" || op == "FractionalMaxPool";
72 }
73
IsAnyMin(const NodeDef & node)74 bool IsAnyMin(const NodeDef& node) {
75 const auto& op = node.op();
76 return op == "Min" || op == "SegmentMin" || op == "UnsortedSegmentMin";
77 }
78
IsAnySparseSegmentReduction(const NodeDef & node)79 bool IsAnySparseSegmentReduction(const NodeDef& node) {
80 const auto& op = node.op();
81 return op == "SparseSegmentSum" || op == "SparseSegmentSumWithNumSegments" ||
82 op == "SparseSegmentMean" ||
83 op == "SparseSegmentMeanWithNumSegments" ||
84 op == "SparseSegmentSqrtN" ||
85 op == "SparseSegmentSqrtNWithNumSegments";
86 }
87
IsApproximateEqual(const NodeDef & node)88 bool IsApproximateEqual(const NodeDef& node) {
89 return node.op() == "ApproximateEqual";
90 }
91
IsArg(const NodeDef & node)92 bool IsArg(const NodeDef& node) {
93 return node.op() == "_Arg" || node.op() == "_DeviceArg";
94 }
95
IsArgMax(const NodeDef & node)96 bool IsArgMax(const NodeDef& node) { return node.op() == "ArgMax"; }
97
IsArgMin(const NodeDef & node)98 bool IsArgMin(const NodeDef& node) { return node.op() == "ArgMin"; }
99
IsAvgPoolGrad(const NodeDef & node)100 bool IsAvgPoolGrad(const NodeDef& node) { return node.op() == "AvgPoolGrad"; }
101
IsAssign(const NodeDef & node)102 bool IsAssign(const NodeDef& node) {
103 return node.op() == "Assign" || node.op() == "AssignVariableOp";
104 }
105
IsAssert(const NodeDef & node)106 bool IsAssert(const NodeDef& node) { return node.op() == "Assert"; }
107
IsAtan2(const NodeDef & node)108 bool IsAtan2(const NodeDef& node) { return node.op() == "Atan2"; }
109
IsBetainc(const NodeDef & node)110 bool IsBetainc(const NodeDef& node) { return node.op() == "Betainc"; }
111
IsBiasAdd(const NodeDef & node)112 bool IsBiasAdd(const NodeDef& node) {
113 return node.op() == "BiasAdd" || node.op() == "BiasAddV1";
114 }
115
IsBiasAddV2(const NodeDef & node)116 bool IsBiasAddV2(const NodeDef& node) { return node.op() == "BiasAdd"; }
117
IsBiasAddGrad(const NodeDef & node)118 bool IsBiasAddGrad(const NodeDef& node) { return node.op() == "BiasAddGrad"; }
119
IsBitcast(const NodeDef & node)120 bool IsBitcast(const NodeDef& node) { return node.op() == "Bitcast"; }
121
IsBroadcastTo(const NodeDef & node)122 bool IsBroadcastTo(const NodeDef& node) { return node.op() == "BroadcastTo"; }
123
IsCast(const NodeDef & node)124 bool IsCast(const NodeDef& node) { return node.op() == "Cast"; }
125
IsCastLike(const NodeDef & node)126 bool IsCastLike(const NodeDef& node) {
127 static const gtl::FlatSet<string>* const kCastLikeOps =
128 CHECK_NOTNULL((new gtl::FlatSet<string>{
129 "Angle", "Bucketize", "Cast", "CompareAndBitpack", "Dequantize",
130 "HistogramFixedWidth", "Imag", "IsFinite", "IsInf", "IsNan",
131 "Quantize", "QuantizeDownAndShrinkRange", "QuantizeV2",
132 "QuantizedInstanceNorm", "QuantizedRelu", "QuantizedRelu6",
133 "QuantizedReluX", "Real", "Requantize"}));
134 return kCastLikeOps->count(node.op()) > 0;
135 }
136
IsCheckNumerics(const NodeDef & node)137 bool IsCheckNumerics(const NodeDef& node) {
138 return node.op() == "CheckNumerics";
139 }
140
IsCollective(const NodeDef & node)141 bool IsCollective(const NodeDef& node) {
142 return node.op() == "CollectiveReduce" ||
143 node.op() == "CollectiveBcastSend" ||
144 node.op() == "CollectiveBcastRecv";
145 }
146
IsComplex(const NodeDef & node)147 bool IsComplex(const NodeDef& node) { return node.op() == "Complex"; }
148
IsComplexAbs(const NodeDef & node)149 bool IsComplexAbs(const NodeDef& node) { return node.op() == "ComplexAbs"; }
150
IsConcat(const NodeDef & node)151 bool IsConcat(const NodeDef& node) {
152 return node.op() == "Concat" || node.op() == "ConcatV2";
153 }
154
IsConcatOffset(const NodeDef & node)155 bool IsConcatOffset(const NodeDef& node) { return node.op() == "ConcatOffset"; }
156
IsConstant(const NodeDef & node)157 bool IsConstant(const NodeDef& node) { return node.op() == "Const"; }
158
IsConj(const NodeDef & node)159 bool IsConj(const NodeDef& node) { return node.op() == "Conj"; }
160
IsConjugateTranspose(const NodeDef & node)161 bool IsConjugateTranspose(const NodeDef& node) {
162 return node.op() == "ConjugateTranspose";
163 }
164
IsControlFlow(const NodeDef & node)165 bool IsControlFlow(const NodeDef& node) {
166 // clang-format off
167 return node.op() == "ControlTrigger" ||
168 node.op() == "Enter" ||
169 node.op() == "Exit" ||
170 node.op() == "LoopCond" ||
171 node.op() == "Merge" ||
172 node.op() == "_XlaMerge" ||
173 node.op() == "NextIteration" ||
174 node.op() == "Switch" ||
175 node.op() == "_SwitchN";
176 // clang-format on
177 }
178
IsConv2D(const NodeDef & node)179 bool IsConv2D(const NodeDef& node) { return node.op() == "Conv2D"; }
180
IsConv2DBackpropFilter(const NodeDef & node)181 bool IsConv2DBackpropFilter(const NodeDef& node) {
182 return node.op() == "Conv2DBackpropFilter";
183 }
184
IsConv2DBackpropInput(const NodeDef & node)185 bool IsConv2DBackpropInput(const NodeDef& node) {
186 return node.op() == "Conv2DBackpropInput";
187 }
188
IsConv3D(const NodeDef & node)189 bool IsConv3D(const NodeDef& node) { return node.op() == "Conv3D"; }
190
IsConv3DBackpropFilterV2(const NodeDef & node)191 bool IsConv3DBackpropFilterV2(const NodeDef& node) {
192 return node.op() == "Conv3DBackpropFilterV2";
193 }
194
IsConv3DBackpropInputV2(const NodeDef & node)195 bool IsConv3DBackpropInputV2(const NodeDef& node) {
196 return node.op() == "Conv3DBackpropInputV2";
197 }
198
IsDepthwiseConv2dNative(const NodeDef & node)199 bool IsDepthwiseConv2dNative(const NodeDef& node) {
200 return node.op() == "DepthwiseConv2dNative";
201 }
202
IsDepthwiseConv2dNativeBackpropFilter(const NodeDef & node)203 bool IsDepthwiseConv2dNativeBackpropFilter(const NodeDef& node) {
204 return node.op() == "DepthwiseConv2dNativeBackpropFilter";
205 }
206
IsDepthwiseConv2dNativeBackpropInput(const NodeDef & node)207 bool IsDepthwiseConv2dNativeBackpropInput(const NodeDef& node) {
208 return node.op() == "DepthwiseConv2dNativeBackpropInput";
209 }
210
IsDequeueOp(const NodeDef & node)211 bool IsDequeueOp(const NodeDef& node) {
212 const auto& op = node.op();
213 return op == "QueueDequeueManyV2" || op == "QueueDequeueMany" ||
214 op == "QueueDequeueV2" || op == "QueueDequeue" ||
215 op == "QueueDequeueUpToV2" || op == "QueueDequeueUpTo";
216 }
217
IsDiv(const NodeDef & node)218 bool IsDiv(const NodeDef& node) { return node.op() == "Div"; }
219
IsDivNoNan(const NodeDef & node)220 bool IsDivNoNan(const NodeDef& node) { return node.op() == "DivNoNan"; }
221
222 // Returns true if node represents a unary elementwise function that is
223 // monotonic. If *is_non_decreasing is true, the function is non-decreasing,
224 // e.g. sqrt, exp. *is_non_decreasing is false, the function is non-increasing,
225 // e.g. inv.
IsElementWiseMonotonic(const NodeDef & node,bool * is_non_decreasing)226 bool IsElementWiseMonotonic(const NodeDef& node, bool* is_non_decreasing) {
227 static const gtl::FlatSet<string>* const kMonotonicNonDecreasingOps =
228 CHECK_NOTNULL((new gtl::FlatSet<string>{
229 "Acosh", "Asin", "Asinh", "Atan", "Atanh", "Ceil",
230 "Elu", "Erf", "Exp", "Expm1", "Floor", "Log",
231 "Log1p", "Relu", "Relu6", "Rint", "Selu", "Sigmoid",
232 "Sign", "Sinh", "Softsign", "Softplus", "Sqrt", "Tanh",
233 }));
234 static const gtl::FlatSet<string>* const kMonotonicNonIncreasingOps =
235 CHECK_NOTNULL((new gtl::FlatSet<string>{"Acos", "Erfc", "Neg", "Rsqrt"}));
236 if (kMonotonicNonDecreasingOps->count(node.op()) > 0) {
237 if (is_non_decreasing) {
238 *is_non_decreasing = true;
239 }
240 return true;
241 } else if (kMonotonicNonIncreasingOps->count(node.op()) > 0) {
242 if (is_non_decreasing) {
243 *is_non_decreasing = false;
244 }
245 return true;
246 }
247 return false;
248 }
249
IsElu(const NodeDef & node)250 bool IsElu(const NodeDef& node) { return node.op() == "Elu"; }
251
IsEluGrad(const NodeDef & node)252 bool IsEluGrad(const NodeDef& node) { return node.op() == "EluGrad"; }
253
IsQuantizationEmulation(const NodeDef & node)254 bool IsQuantizationEmulation(const NodeDef& node) {
255 const auto& op = node.op();
256 return absl::StartsWith(op, "QuantizeAndDequantize") ||
257 absl::StartsWith(op, "FakeQuantWithMinMax");
258 }
259
IsEnter(const NodeDef & node)260 bool IsEnter(const NodeDef& node) {
261 const auto& op = node.op();
262 return op == "Enter" || op == "RefEnter";
263 }
264
IsEqual(const NodeDef & node)265 bool IsEqual(const NodeDef& node) { return node.op() == "Equal"; }
266
IsExit(const NodeDef & node)267 bool IsExit(const NodeDef& node) {
268 const auto& op = node.op();
269 return op == "Exit" || op == "RefExit";
270 }
271
IsExp(const NodeDef & node)272 bool IsExp(const NodeDef& node) { return node.op() == "Exp"; }
273
IsFakeParam(const NodeDef & node)274 bool IsFakeParam(const NodeDef& node) { return node.op() == "FakeParam"; }
275
IsFill(const NodeDef & node)276 bool IsFill(const NodeDef& node) { return node.op() == "Fill"; }
277
IsFloorDiv(const NodeDef & node)278 bool IsFloorDiv(const NodeDef& node) { return node.op() == "FloorDiv"; }
279
IsFloorMod(const NodeDef & node)280 bool IsFloorMod(const NodeDef& node) { return node.op() == "FloorMod"; }
281
IsFusedBatchNorm(const NodeDef & node)282 bool IsFusedBatchNorm(const NodeDef& node) {
283 const auto& op = node.op();
284 return op == "FusedBatchNorm" || op == "FusedBatchNormV2" ||
285 op == "FusedBatchNormV3";
286 }
287
IsFusedBatchNormEx(const NodeDef & node)288 bool IsFusedBatchNormEx(const NodeDef& node) {
289 return node.op() == "_FusedBatchNormEx";
290 }
291
IsFusedBatchNormGrad(const NodeDef & node)292 bool IsFusedBatchNormGrad(const NodeDef& node) {
293 const auto& op = node.op();
294 return op == "FusedBatchNormGrad" || op == "FusedBatchNormGradV2" ||
295 op == "FusedBatchNormGradV3";
296 }
297
IsGather(const NodeDef & node)298 bool IsGather(const NodeDef& node) {
299 const auto& op = node.op();
300 return op == "Gather" || op == "GatherV2";
301 }
302
IsGreater(const NodeDef & node)303 bool IsGreater(const NodeDef& node) { return node.op() == "Greater"; }
304
IsGreaterEqual(const NodeDef & node)305 bool IsGreaterEqual(const NodeDef& node) { return node.op() == "GreaterEqual"; }
306
IsHostConstant(const NodeDef & node)307 bool IsHostConstant(const NodeDef& node) { return node.op() == "HostConst"; }
308
IsHistogramSummary(const NodeDef & node)309 bool IsHistogramSummary(const NodeDef& node) {
310 return node.op() == "HistogramSummary";
311 }
312
IsIdentity(const NodeDef & node)313 bool IsIdentity(const NodeDef& node) {
314 const auto& op = node.op();
315 return op == "Identity" || op == "RefIdentity";
316 }
317
IsIdentityN(const NodeDef & node)318 bool IsIdentityN(const NodeDef& node) {
319 const auto& op = node.op();
320 return op == "IdentityN";
321 }
322
IsIdentityNSingleInput(const NodeDef & node)323 bool IsIdentityNSingleInput(const NodeDef& node) {
324 return IsIdentityN(node) && node.attr().count("T") != 0 &&
325 node.attr().at("T").list().type_size() == 1;
326 }
327
IsIf(const NodeDef & node)328 bool IsIf(const NodeDef& node) {
329 const auto& op = node.op();
330 return op == "If" || op == "StatelessIf";
331 }
332
IsIgamma(const NodeDef & node)333 bool IsIgamma(const NodeDef& node) { return node.op() == "Igamma"; }
334
IsIgammac(const NodeDef & node)335 bool IsIgammac(const NodeDef& node) { return node.op() == "Igammac"; }
336
IsImag(const NodeDef & node)337 bool IsImag(const NodeDef& node) { return node.op() == "Imag"; }
338
IsImmutableConst(const NodeDef & node)339 bool IsImmutableConst(const NodeDef& node) {
340 return node.op() == "ImmutableConst";
341 }
342
IsInvGrad(const NodeDef & node)343 bool IsInvGrad(const NodeDef& node) { return node.op() == "InvGrad"; }
344
IsLeakyRelu(const NodeDef & node)345 bool IsLeakyRelu(const NodeDef& node) { return node.op() == "LeakyRelu"; }
346
IsLeakyReluGrad(const NodeDef & node)347 bool IsLeakyReluGrad(const NodeDef& node) {
348 return node.op() == "LeakyReluGrad";
349 }
350
IsLess(const NodeDef & node)351 bool IsLess(const NodeDef& node) { return node.op() == "Less"; }
352
IsLessEqual(const NodeDef & node)353 bool IsLessEqual(const NodeDef& node) { return node.op() == "LessEqual"; }
354
IsLog(const NodeDef & node)355 bool IsLog(const NodeDef& node) { return node.op() == "Log"; }
356
IsLogicalAnd(const NodeDef & node)357 bool IsLogicalAnd(const NodeDef& node) { return node.op() == "LogicalAnd"; }
358
IsLogicalNot(const NodeDef & node)359 bool IsLogicalNot(const NodeDef& node) { return node.op() == "LogicalNot"; }
360
IsLogicalOr(const NodeDef & node)361 bool IsLogicalOr(const NodeDef& node) { return node.op() == "LogicalOr"; }
362
IsLoopCond(const NodeDef & node)363 bool IsLoopCond(const NodeDef& node) { return node.op() == "LoopCond"; }
364
IsMatMul(const NodeDef & node)365 bool IsMatMul(const NodeDef& node) { return node.op() == "MatMul"; }
366
IsMax(const NodeDef & node)367 bool IsMax(const NodeDef& node) { return node.op() == "Max"; }
368
IsMaximum(const NodeDef & node)369 bool IsMaximum(const NodeDef& node) { return node.op() == "Maximum"; }
370
IsMaxPoolGrad(const NodeDef & node)371 bool IsMaxPoolGrad(const NodeDef& node) { return node.op() == "MaxPoolGrad"; }
372
IsMean(const NodeDef & node)373 bool IsMean(const NodeDef& node) { return node.op() == "Mean"; }
374
IsMerge(const NodeDef & node)375 bool IsMerge(const NodeDef& node) {
376 const auto& op = node.op();
377 return op == "Merge" || op == "RefMerge" || op == "_XlaMerge";
378 }
379
IsMin(const NodeDef & node)380 bool IsMin(const NodeDef& node) { return node.op() == "Min"; }
381
IsMinimum(const NodeDef & node)382 bool IsMinimum(const NodeDef& node) { return node.op() == "Minimum"; }
383
IsMirrorPad(const NodeDef & node)384 bool IsMirrorPad(const NodeDef& node) { return node.op() == "MirrorPad"; }
385
IsMirrorPadGrad(const NodeDef & node)386 bool IsMirrorPadGrad(const NodeDef& node) {
387 return node.op() == "MirrorPadGrad";
388 }
389
IsMod(const NodeDef & node)390 bool IsMod(const NodeDef& node) { return node.op() == "Mod"; }
391
IsMul(const NodeDef & node)392 bool IsMul(const NodeDef& node) { return node.op() == "Mul"; }
IsMulNoNan(const NodeDef & node)393 bool IsMulNoNan(const NodeDef& node) { return node.op() == "MulNoNan"; }
IsAnyMul(const NodeDef & node)394 bool IsAnyMul(const NodeDef& node) { return IsMul(node) || IsMulNoNan(node); }
395
IsNeg(const NodeDef & node)396 bool IsNeg(const NodeDef& node) { return node.op() == "Neg"; }
397
IsNoOp(const NodeDef & node)398 bool IsNoOp(const NodeDef& node) { return node.op() == "NoOp"; }
399
IsNotEqual(const NodeDef & node)400 bool IsNotEqual(const NodeDef& node) { return node.op() == "NotEqual"; }
401
IsNextIteration(const NodeDef & node)402 bool IsNextIteration(const NodeDef& node) {
403 const auto& op = node.op();
404 return op == "NextIteration" || op == "RefNextIteration";
405 }
406
IsOnesLike(const NodeDef & node)407 bool IsOnesLike(const NodeDef& node) { return node.op() == "OnesLike"; }
408
IsPack(const NodeDef & node)409 bool IsPack(const NodeDef& node) { return node.op() == "Pack"; }
410
IsPad(const NodeDef & node)411 bool IsPad(const NodeDef& node) {
412 const auto& op = node.op();
413 return op == "Pad" || op == "PadV2";
414 }
415
IsPartitionedCall(const NodeDef & node)416 bool IsPartitionedCall(const NodeDef& node) {
417 return node.op() == "PartitionedCall";
418 }
419
IsPlaceholder(const NodeDef & node)420 bool IsPlaceholder(const NodeDef& node) {
421 const auto& op = node.op();
422 return op == "Placeholder" || op == "PlaceholderV2" ||
423 op == "PlaceholderWithDefault";
424 }
425
IsPolygamma(const NodeDef & node)426 bool IsPolygamma(const NodeDef& node) { return node.op() == "Polygamma"; }
427
IsPow(const NodeDef & node)428 bool IsPow(const NodeDef& node) { return node.op() == "Pow"; }
429
IsPrint(const NodeDef & node)430 bool IsPrint(const NodeDef& node) {
431 return node.op() == "Print" || node.op() == "PrintV2";
432 }
433
IsProd(const NodeDef & node)434 bool IsProd(const NodeDef& node) { return node.op() == "Prod"; }
435
IsQuantizedMatMul(const NodeDef & node)436 bool IsQuantizedMatMul(const NodeDef& node) {
437 return node.op() == "QuantizedMatMul" || node.op() == "QuantizedMatMulV2";
438 }
439
IsQueue(const NodeDef & node)440 bool IsQueue(const NodeDef& node) {
441 return str_util::EndsWith(node.op(), "QueueV2");
442 }
443
IsRandomShuffle(const NodeDef & node)444 bool IsRandomShuffle(const NodeDef& node) {
445 return node.op() == "RandomShuffle";
446 }
447
IsRank(const NodeDef & node)448 bool IsRank(const NodeDef& node) { return node.op() == "Rank"; }
449
IsReadVariableOp(const NodeDef & node)450 bool IsReadVariableOp(const NodeDef& node) {
451 return node.op() == "ReadVariableOp";
452 }
453
IsReadVariablesOp(const NodeDef & node)454 bool IsReadVariablesOp(const NodeDef& node) {
455 return node.op() == "_ReadVariablesOp";
456 }
457
IsReal(const NodeDef & node)458 bool IsReal(const NodeDef& node) { return node.op() == "Real"; }
459
IsRealDiv(const NodeDef & node)460 bool IsRealDiv(const NodeDef& node) { return node.op() == "RealDiv"; }
461
IsReciprocalGrad(const NodeDef & node)462 bool IsReciprocalGrad(const NodeDef& node) {
463 return node.op() == "ReciprocalGrad";
464 }
465
IsRecv(const NodeDef & node)466 bool IsRecv(const NodeDef& node) {
467 return node.op() == "_Recv" || node.op() == "_HostRecv";
468 }
469
IsReduction(const NodeDef & node)470 bool IsReduction(const NodeDef& node) {
471 const auto& op = node.op();
472 return op == "Sum" || op == "Prod" || op == "Min" || op == "Max" ||
473 op == "Mean" || op == "Any" || op == "All";
474 }
475
IsRelu(const NodeDef & node)476 bool IsRelu(const NodeDef& node) { return node.op() == "Relu"; }
477
IsRelu6(const NodeDef & node)478 bool IsRelu6(const NodeDef& node) { return node.op() == "Relu6"; }
479
IsReluGrad(const NodeDef & node)480 bool IsReluGrad(const NodeDef& node) { return node.op() == "ReluGrad"; }
481
IsRelu6Grad(const NodeDef & node)482 bool IsRelu6Grad(const NodeDef& node) { return node.op() == "Relu6Grad"; }
483
IsReshape(const NodeDef & node)484 bool IsReshape(const NodeDef& node) { return (node.op() == "Reshape"); }
485
IsRestore(const NodeDef & node)486 bool IsRestore(const NodeDef& node) {
487 return (node.op() == "Restore" || node.op() == "RestoreV2" ||
488 node.op() == "RestoreSlice");
489 }
490
IsRetval(const NodeDef & node)491 bool IsRetval(const NodeDef& node) {
492 return node.op() == "_Retval" || node.op() == "_DeviceRetval";
493 }
494
IsReverse(const NodeDef & node)495 bool IsReverse(const NodeDef& node) {
496 return node.op() == "Reverse" || node.op() == "ReverseV2";
497 }
498
IsReverseV2(const NodeDef & node)499 bool IsReverseV2(const NodeDef& node) { return node.op() == "ReverseV2"; }
500
IsRsqrt(const NodeDef & node)501 bool IsRsqrt(const NodeDef& node) { return node.op() == "Rsqrt"; }
502
IsRsqrtGrad(const NodeDef & node)503 bool IsRsqrtGrad(const NodeDef& node) { return node.op() == "RsqrtGrad"; }
504
IsSelect(const NodeDef & node)505 bool IsSelect(const NodeDef& node) {
506 return node.op() == "Select" || node.op() == "SelectV2";
507 }
508
IsSeluGrad(const NodeDef & node)509 bool IsSeluGrad(const NodeDef& node) { return node.op() == "SeluGrad"; }
510
IsSend(const NodeDef & node)511 bool IsSend(const NodeDef& node) {
512 return node.op() == "_Send" || node.op() == "_HostSend";
513 }
514
IsShape(const NodeDef & node)515 bool IsShape(const NodeDef& node) { return node.op() == "Shape"; }
516
IsShapeN(const NodeDef & node)517 bool IsShapeN(const NodeDef& node) { return node.op() == "ShapeN"; }
518
IsShuffle(const NodeDef & node)519 bool IsShuffle(const NodeDef& node) { return node.op() == "Shuffle"; }
520
IsSigmoidGrad(const NodeDef & node)521 bool IsSigmoidGrad(const NodeDef& node) { return node.op() == "SigmoidGrad"; }
522
IsSize(const NodeDef & node)523 bool IsSize(const NodeDef& node) { return node.op() == "Size"; }
524
IsSlice(const NodeDef & node)525 bool IsSlice(const NodeDef& node) { return node.op() == "Slice"; }
526
IsSnapshot(const NodeDef & node)527 bool IsSnapshot(const NodeDef& node) { return node.op() == "Snapshot"; }
528
IsSoftmax(const NodeDef & node)529 bool IsSoftmax(const NodeDef& node) { return node.op() == "Softmax"; }
530
IsSoftplusGrad(const NodeDef & node)531 bool IsSoftplusGrad(const NodeDef& node) { return node.op() == "SoftplusGrad"; }
532
IsSoftsignGrad(const NodeDef & node)533 bool IsSoftsignGrad(const NodeDef& node) { return node.op() == "SoftsignGrad"; }
534
IsSplit(const NodeDef & node)535 bool IsSplit(const NodeDef& node) { return node.op() == "Split"; }
536
IsSplitV(const NodeDef & node)537 bool IsSplitV(const NodeDef& node) { return node.op() == "SplitV"; }
538
IsSqrt(const NodeDef & node)539 bool IsSqrt(const NodeDef& node) { return node.op() == "Sqrt"; }
540
IsSqrtGrad(const NodeDef & node)541 bool IsSqrtGrad(const NodeDef& node) { return node.op() == "SqrtGrad"; }
542
IsSquare(const NodeDef & node)543 bool IsSquare(const NodeDef& node) { return node.op() == "Square"; }
544
IsSquaredDifference(const NodeDef & node)545 bool IsSquaredDifference(const NodeDef& node) {
546 return node.op() == "SquaredDifference";
547 }
548
IsSqueeze(const NodeDef & node)549 bool IsSqueeze(const NodeDef& node) { return node.op() == "Squeeze"; }
550
IsStackOp(const NodeDef & node)551 bool IsStackOp(const NodeDef& node) {
552 return node.op() == "Stack" || node.op() == "StackV2";
553 }
IsStackCloseOp(const NodeDef & node)554 bool IsStackCloseOp(const NodeDef& node) {
555 return node.op() == "StackClose" || node.op() == "StackCloseV2";
556 }
IsStackPushOp(const NodeDef & node)557 bool IsStackPushOp(const NodeDef& node) {
558 return node.op() == "StackPush" || node.op() == "StackPushV2";
559 }
IsStackPopOp(const NodeDef & node)560 bool IsStackPopOp(const NodeDef& node) {
561 return node.op() == "StackPop" || node.op() == "StackPopV2";
562 }
563
IsStatefulPartitionedCall(const NodeDef & node)564 bool IsStatefulPartitionedCall(const NodeDef& node) {
565 return node.op() == "StatefulPartitionedCall";
566 }
567
IsStopGradient(const NodeDef & node)568 bool IsStopGradient(const NodeDef& node) {
569 const auto& op = node.op();
570 return op == "StopGradient" || op == "PreventGradient";
571 }
572
IsStridedSlice(const NodeDef & node)573 bool IsStridedSlice(const NodeDef& node) { return node.op() == "StridedSlice"; }
574
IsStridedSliceGrad(const NodeDef & node)575 bool IsStridedSliceGrad(const NodeDef& node) {
576 return node.op() == "StridedSliceGrad";
577 }
578
IsSub(const NodeDef & node)579 bool IsSub(const NodeDef& node) { return node.op() == "Sub"; }
580
IsSum(const NodeDef & node)581 bool IsSum(const NodeDef& node) { return node.op() == "Sum"; }
582
IsSwitch(const NodeDef & node)583 bool IsSwitch(const NodeDef& node) {
584 const auto& op = node.op();
585 return op == "_SwitchN" || op == "Switch" || op == "RefSwitch";
586 }
587
IsSymbolicGradient(const NodeDef & node)588 bool IsSymbolicGradient(const NodeDef& node) {
589 return node.op() == "SymbolicGradient";
590 }
591
IsTanh(const NodeDef & node)592 bool IsTanh(const NodeDef& node) { return node.op() == "Tanh"; }
593
IsTanhGrad(const NodeDef & node)594 bool IsTanhGrad(const NodeDef& node) { return node.op() == "TanhGrad"; }
595
IsTensorArray(const NodeDef & node)596 bool IsTensorArray(const NodeDef& node) {
597 static const gtl::FlatSet<string>* const kTensorArrayOps =
598 CHECK_NOTNULL((new gtl::FlatSet<string>{
599 "TensorArray",
600 "TensorArrayV2",
601 "TensorArrayV3",
602 "TensorArrayGrad",
603 "TensorArrayGradV2",
604 "TensorArrayGradV3",
605 "TensorArrayGradWithShape",
606 "TensorArrayWrite",
607 "TensorArrayWriteV2",
608 "TensorArrayWriteV3",
609 "TensorArrayRead",
610 "TensorArrayReadV2",
611 "TensorArrayReadV3",
612 "TensorArrayConcat",
613 "TensorArrayConcatV2",
614 "TensorArrayConcatV3",
615 "TensorArraySplit",
616 "TensorArraySplitV2",
617 "TensorArraySplitV3",
618 "TensorArraySize",
619 "TensorArraySizeV2",
620 "TensorArraySizeV3",
621 "TensorArrayClose",
622 "TensorArrayCloseV2",
623 "TensorArrayCloseV3",
624 }));
625 return kTensorArrayOps->count(node.op()) > 0;
626 }
627
IsTile(const NodeDef & node)628 bool IsTile(const NodeDef& node) { return node.op() == "Tile"; }
629
IsTranspose(const NodeDef & node)630 bool IsTranspose(const NodeDef& node) { return node.op() == "Transpose"; }
631
IsTruncateDiv(const NodeDef & node)632 bool IsTruncateDiv(const NodeDef& node) { return node.op() == "TruncateDiv"; }
633
IsTruncateMod(const NodeDef & node)634 bool IsTruncateMod(const NodeDef& node) { return node.op() == "TruncateMod"; }
635
IsUnique(const NodeDef & node)636 bool IsUnique(const NodeDef& node) {
637 const auto& op = node.op();
638 return op == "Unique" || op == "UniqueV2";
639 }
640
IsUnpack(const NodeDef & node)641 bool IsUnpack(const NodeDef& node) { return node.op() == "Unpack"; }
642
IsVariable(const NodeDef & node)643 bool IsVariable(const NodeDef& node) {
644 const auto& op = node.op();
645 return op == "Variable" || op == "VariableV2" || op == "AutoReloadVariable" ||
646 op == "VarHandleOp" || op == "ReadVariableOp" ||
647 op == "_VarHandlesOp" || op == "_ReadVariablesOp";
648 }
649
IsWhile(const NodeDef & node)650 bool IsWhile(const NodeDef& node) {
651 const auto& op = node.op();
652 return op == "While" || op == "StatelessWhile";
653 }
654
IsXdivy(const NodeDef & node)655 bool IsXdivy(const NodeDef& node) { return node.op() == "Xdivy"; }
656
IsZerosLike(const NodeDef & node)657 bool IsZerosLike(const NodeDef& node) { return node.op() == "ZerosLike"; }
658
IsZeta(const NodeDef & node)659 bool IsZeta(const NodeDef& node) { return node.op() == "Zeta"; }
660
661 namespace {
GetBoolAttr(const NodeDef & node,const string & name)662 bool GetBoolAttr(const NodeDef& node, const string& name) {
663 return node.attr().count(name) > 0 && node.attr().at(name).b();
664 }
665 } // namespace
666
IsPersistent(const NodeDef & node)667 bool IsPersistent(const NodeDef& node) {
668 return IsConstant(node) || IsVariable(node) || IsHostConstant(node);
669 }
670
HasRefInput(const NodeDef & node)671 bool HasRefInput(const NodeDef& node) {
672 const OpDef* op_def;
673 Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
674 if (!status.ok()) {
675 return false;
676 }
677 // Nodes such as Assign or AssignAdd modify one of their inputs.
678 for (const auto& input : op_def->input_arg()) {
679 if (input.is_ref()) {
680 return true;
681 }
682 }
683 return false;
684 }
685
IsDataset(const NodeDef & node)686 bool IsDataset(const NodeDef& node) {
687 const string& op = node.op();
688 // See `GetNodeClassForOp` in core/graph/graph.cc.
689 return op == "IteratorGetNext" || op == "IteratorGetNextSync" ||
690 op == "DatasetToSingleElement" || op == "ReduceDataset";
691 }
692
// Returns true if the op registered under node.op() in `op_registry` is
// marked stateful. Logs a warning and returns false when the op cannot be
// looked up.
// NOTE(review): `node` is passed by value, copying the whole NodeDef proto
// on every call. Presumably the header declares this same by-value
// signature, so fixing it requires a coordinated header change — TODO
// confirm before touching.
bool IsStateful(const NodeDef node, const OpRegistryInterface* op_registry) {
  const OpDef* op_def = nullptr;
  const string& op_name = node.op();
  Status status = op_registry->LookUpOpDef(op_name, &op_def);
  if (!status.ok()) {
    LOG(WARNING) << "Failed to lookup OpDef for " << op_name
                 << ". Error: " << status.error_message();
    return false;
  }
  return op_def->is_stateful();
}

// Convenience overload that consults the global op registry.
bool IsStateful(const NodeDef node) {
  return IsStateful(node, OpRegistry::Global());
}
708
IsFreeOfSideEffect(const NodeDef & node,const OpRegistryInterface * op_registry)709 bool IsFreeOfSideEffect(const NodeDef& node,
710 const OpRegistryInterface* op_registry) {
711 // Placeholders must be preserved to keep the graph feedable.
712 if (IsPlaceholder(node)) {
713 return false;
714 }
715 const OpDef* op_def = nullptr;
716 const string& op_name = node.op();
717 Status status = op_registry->LookUpOpDef(op_name, &op_def);
718 if (!status.ok()) {
719 return false;
720 }
721 if (op_def->is_stateful()) {
722 return false;
723 }
724 // Nodes such as Assign or AssignAdd modify one of their inputs.
725 for (const auto& input : op_def->input_arg()) {
726 if (input.is_ref()) {
727 return false;
728 }
729 }
730 // Queue ops modify the queue which is a side effect.
731 if (node.op().find("Queue") != string::npos) {
732 return false;
733 }
734 // Sending a tensor via a network is a side effect.
735 if (IsSend(node)) {
736 return false;
737 }
738 return !ModifiesInputsInPlace(node);
739 }
740
IsFreeOfSideEffect(const NodeDef & node)741 bool IsFreeOfSideEffect(const NodeDef& node) {
742 return IsFreeOfSideEffect(node, OpRegistry::Global());
743 }
744
ModifiesInputsInPlace(const NodeDef & node)745 bool ModifiesInputsInPlace(const NodeDef& node) {
746 // Some nodes do in-place updates on regular tensor inputs.
747 const string& op_name = node.op();
748
749 // Ops that modify resource variables effectively modify one of their inputs.
750 if (op_name == "AssignVariableOp" || op_name == "AssignAddVariableOp" ||
751 op_name == "AssignSubVariableOp" || op_name == "ResourceScatterUpdate" ||
752 op_name == "ResourceScatterAdd" || op_name == "ResourceScatterSub" ||
753 op_name == "ResourceScatterMul" || op_name == "ResourceScatterDiv" ||
754 op_name == "ResourceScatterMin" || op_name == "ResourceScatterMax") {
755 return false;
756 }
757
758 string lower_op_name = op_name;
759 std::transform(lower_op_name.begin(), lower_op_name.end(),
760 lower_op_name.begin(), ::tolower);
761 if (absl::StrContains(lower_op_name, "inplace")) {
762 return true;
763 }
764 return GetBoolAttr(node, "in_place") || GetBoolAttr(node, "inplace");
765 }
766
ModifiesFrameInfo(const NodeDef & node)767 bool ModifiesFrameInfo(const NodeDef& node) {
768 return IsEnter(node) || IsExit(node) || IsNextIteration(node);
769 }
770
// OPDEF_PROPERTY_HELPER(Foo, foo) expands to a predicate
//   bool IsFoo(const NodeDef& node)
// that looks the node's op up in the global registry and reports whether its
// OpDef has the is_foo() bit set. Lookup failures yield false.
#define OPDEF_PROPERTY_HELPER(PROPERTY_CAP, PROPERTY)                  \
  bool Is##PROPERTY_CAP(const NodeDef& node) {                         \
    if (node.op() == "Add") {                                          \
      /* Workaround for "Add" not being marked is_commutative and */   \
      /* is_aggregate. (See cl/173915048). */                          \
      const auto type = GetDataTypeFromAttr(node, "T");                \
      return type != DT_INVALID && type != DT_STRING;                  \
    }                                                                  \
    const OpDef* op_def = nullptr;                                     \
    Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def); \
    return status.ok() && op_def->is_##PROPERTY();                     \
  }

// Defines IsAggregate(node) and IsCommutative(node).
OPDEF_PROPERTY_HELPER(Aggregate, aggregate)
OPDEF_PROPERTY_HELPER(Commutative, commutative)
786
787 bool IsInvolution(const NodeDef& node) {
788 static const gtl::FlatSet<string>* const kInvolutionOps =
789 CHECK_NOTNULL((new gtl::FlatSet<string>{"Conj", "Reciprocal", "Invert",
790 "Neg", "LogicalNot"}));
791 return kInvolutionOps->count(node.op()) > 0;
792 }
793
IsValueAndOrderAndShapePreserving(const NodeDef & node)794 bool IsValueAndOrderAndShapePreserving(const NodeDef& node) {
795 if (NumNonControlInputs(node) == 1 && IsAggregate(node)) {
796 return true;
797 }
798 static const gtl::FlatSet<string>* const kValueAndOrderAndShapePreservingOps =
799 CHECK_NOTNULL((new const gtl::FlatSet<string>{
800 "CheckNumerics",
801 "DebugGradientIdentity",
802 "DeepCopy"
803 "Enter",
804 "Exit",
805 "PreventGradient",
806 "Print",
807 "Snapshot",
808 "StopGradient",
809 }));
810 return kValueAndOrderAndShapePreservingOps->count(node.op()) > 0 ||
811 IsIdentity(node);
812 }
813
IsValueAndOrderPreserving(const NodeDef & node)814 bool IsValueAndOrderPreserving(const NodeDef& node) {
815 if (NumNonControlInputs(node) == 1 && IsAggregate(node)) {
816 return true;
817 }
818 static const gtl::FlatSet<string>* const kValueAndOrderPreservingOps =
819 CHECK_NOTNULL((new const gtl::FlatSet<string>{
820 "ExpandDims",
821 "Reshape",
822 "Squeeze",
823 }));
824 return kValueAndOrderPreservingOps->count(node.op()) > 0 ||
825 IsValueAndOrderAndShapePreserving(node);
826 }
827
IsValuePreserving(const NodeDef & node)828 bool IsValuePreserving(const NodeDef& node) {
829 static const gtl::FlatSet<string>* const kValuePreservingOps =
830 CHECK_NOTNULL((new gtl::FlatSet<string>{
831 "InvertPermutation",
832 "Reverse",
833 "ReverseV2",
834 "Roll",
835 "Transpose",
836 "DepthToSpace",
837 "SpaceToDepth",
838 "BatchToSpace",
839 "BatchToSpaceND",
840 "SpaceToBatch",
841 "SpaceToBatchND",
842 }));
843 return IsValueAndOrderPreserving(node) ||
844 kValuePreservingOps->count(node.op()) > 0;
845 }
846
IsUnaryElementWise(const NodeDef & node)847 bool IsUnaryElementWise(const NodeDef& node) {
848 static const gtl::FlatSet<string>* const kElementWiseOps =
849 CHECK_NOTNULL((new gtl::FlatSet<string>{
850 "Abs",
851 "Acos",
852 "Acosh",
853 "Asin",
854 "Asinh",
855 "Atan",
856 "Atanh",
857 "Ceil",
858 "ComplexAbs",
859 "Conj",
860 "Cos",
861 "Cosh",
862 "Digamma",
863 "Elu"
864 "Erf",
865 "Erfc",
866 "Exp",
867 "Expm1",
868 "Floor",
869 "Inv",
870 "Invert",
871 "Isinf",
872 "Isnan",
873 "Isfinite",
874 "Lgamma",
875 "Log",
876 "Log1p",
877 "LogicalNot",
878 "Neg",
879 "Reciprocal",
880 "Relu",
881 "Relu6",
882 "Rint",
883 "Round",
884 "Selu",
885 "Rsqrt",
886 "Sigmoid",
887 "Sign",
888 "Sin",
889 "SinH",
890 "Softplus",
891 "Softsign",
892 "Sqrt",
893 "Square",
894 "Tan"
895 "Tanh",
896 }));
897 return kElementWiseOps->count(node.op()) > 0 ||
898 IsValueAndOrderAndShapePreserving(node);
899 }
900
HasOpDef(const NodeDef & node)901 bool HasOpDef(const NodeDef& node) {
902 const OpDef* op_def = nullptr;
903 return OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok();
904 }
905
IsIdempotent(const NodeDef & node)906 bool IsIdempotent(const NodeDef& node) {
907 return IsValueAndOrderAndShapePreserving(node) && IsFreeOfSideEffect(node) &&
908 !ModifiesFrameInfo(node);
909 }
910
NeverForwardsInputs(const NodeDef & node)911 bool NeverForwardsInputs(const NodeDef& node) {
912 static const gtl::FlatSet<string>* const kNonForwardingOps = CHECK_NOTNULL(
913 (new gtl::FlatSet<string>{"ArgMax",
914 "ArgMin",
915 "AudioSpectrogram",
916 "AvgPool",
917 "BatchMatMul",
918 "BatchMatMulV2",
919 "BatchNormWithGlobalNormalization",
920 "BatchToSpace",
921 "BatchToSpaceND",
922 "Bincount",
923 "BroadcastArgs",
924 "BroadcastGradientArgs",
925 "Bucketize",
926 "CTCBeamSearchDecoder",
927 "CTCGreedyDecoder",
928 "CTCLoss",
929 "CompareAndBitpack",
930 "ComplexAbs",
931 "Concat",
932 "ConcatOffset",
933 "ConcatV2",
934 "Conv2D",
935 "Copy",
936 "CopyHost",
937 "Cross",
938 "CudnnRNN",
939 "CudnnRNNBackprop",
940 "CudnnRNNBackpropV2",
941 "CudnnRNNBackpropV3",
942 "CudnnRNNCanonicalToParams",
943 "CudnnRNNCanonicalToParamsV2",
944 "CudnnRNNParamsSize",
945 "CudnnRNNParamsToCanonical",
946 "CudnnRNNParamsToCanonicalV2",
947 "CudnnRNNV2",
948 "CudnnRNNV3",
949 "CumProd",
950 "CumSum",
951 "DebugNanCount",
952 "DebugNumericSummary",
953 "DecodeProtoV2",
954 "DecodeWav",
955 "DeepCopy",
956 "DepthToSpace",
957 "Dequantize",
958 "Diag",
959 "DiagPart",
960 "EditDistance",
961 "Empty",
962 "EncodeProtoV2",
963 "EncodeWav",
964 "ExtractImagePatches",
965 "ExtractVolumePatches",
966 "Fill",
967 "Gather",
968 "GatherNd",
969 "GatherV2",
970 "HistogramFixedWidth",
971 "InvertPermutation",
972 "IsInf",
973 "IsNan",
974 "Isfinite",
975 "LinSpace",
976 "LowerBound",
977 "MatMul",
978 "MatrixDiag",
979 "MatrixDiagPart",
980 "MatrixDiagPartV2",
981 "MatrixDiagV2",
982 "Mfcc",
983 "Multinomial",
984 "OneHot",
985 "Pack",
986 "ParameterizedTruncatedNormal",
987 "PopulationCount",
988 "RandomGamma",
989 "RandomPoisson",
990 "RandomPoissonV2",
991 "RandomStandardNormal",
992 "RandomUniform",
993 "RandomUniformInt",
994 "Range",
995 "Rank",
996 "RequantizationRange",
997 "Requantize",
998 "ReverseSequence",
999 "Shape",
1000 "ShapeN",
1001 "Size",
1002 "SpaceToBatch",
1003 "SpaceToBatchND",
1004 "SpaceToDepth",
1005 "SparseMatMul",
1006 "Split",
1007 "SplitV",
1008 "TruncatedNormal",
1009 "Unique",
1010 "UniqueV2",
1011 "UniqueWithCounts",
1012 "UniqueWithCountsV2",
1013 "Unpack",
1014 "UnravelIndex",
1015 "UpperBound",
1016 "Where"}));
1017 const string& op_name = node.op();
1018 return kNonForwardingOps->count(op_name) > 0 ||
1019 absl::StrContains(op_name, "Segment") ||
1020 absl::StartsWith(op_name, "Quantize");
1021 }
1022
IsXlaLaunch(const NodeDef & node)1023 bool IsXlaLaunch(const NodeDef& node) { return node.op() == "XlaLaunch"; }
1024
1025 } // namespace grappler
1026 } // end namespace tensorflow
1027