/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {
namespace grappler {

bool IsAdd(const NodeDef& node) {
  if (node.op() == "AddV2" || node.op() == "Add") {
    DataType type = node.attr().at("T").type();
    return type != DT_STRING;
  }
  return false;
}

bool IsAddN(const NodeDef& node) { return node.op() == "AddN"; }

bool IsAll(const NodeDef& node) { return node.op() == "All"; }

bool IsAngle(const NodeDef& node) { return node.op() == "Angle"; }

bool IsAny(const NodeDef& node) { return node.op() == "Any"; }

bool IsAnyDiv(const NodeDef& node) {
  return node.op() == "RealDiv" || node.op() == "Div" ||
         node.op() == "FloorDiv" || node.op() == "TruncateDiv";
}

bool IsAnyMax(const NodeDef& node) {
  const auto& op = node.op();
  return op == "Max" || op == "SegmentMax" || op == "UnsortedSegmentMax";
}

bool IsAnyMaxPool(const NodeDef& node) {
  const auto& op = node.op();
  return op == "MaxPool" || op == "MaxPoolV2" || op == "MaxPool3D" ||
         op == "MaxPoolWithArgmax" || op == "FractionalMaxPool";
}

bool IsAnyMin(const NodeDef& node) {
  const auto& op = node.op();
  return op == "Min" || op == "SegmentMin" || op == "UnsortedSegmentMin";
}

bool IsApproximateEqual(const NodeDef& node) {
  return node.op() == "ApproximateEqual";
}

bool IsArgMax(const NodeDef& node) { return node.op() == "ArgMax"; }

bool IsArgMin(const NodeDef& node) { return node.op() == "ArgMin"; }

bool IsAvgPoolGrad(const NodeDef& node) { return node.op() == "AvgPoolGrad"; }

bool IsAssign(const NodeDef& node) {
  return node.op() == "Assign" || node.op() == "AssignVariableOp";
}

bool IsAssert(const NodeDef& node) { return node.op() == "Assert"; }

bool IsAtan2(const NodeDef& node) { return node.op() == "Atan2"; }

bool IsBetainc(const NodeDef& node) { return node.op() == "Betainc"; }

bool IsBiasAdd(const NodeDef& node) {
  return node.op() == "BiasAdd" || node.op() == "BiasAddV1";
}

bool IsBiasAddGrad(const NodeDef& node) { return node.op() == "BiasAddGrad"; }

bool IsBitcast(const NodeDef& node) { return node.op() == "Bitcast"; }

bool IsCast(const NodeDef& node) { return node.op() == "Cast"; }

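// Ops that, like Cast, produce an output whose type can differ from the type
// of their input.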
bool IsCastLike(const NodeDef& node) {
  static const gtl::FlatSet<string>* const kCastLikeOps =
      CHECK_NOTNULL((new gtl::FlatSet<string>{
          "Angle", "Bucketize", "Cast", "CompareAndBitpack", "Dequantize",
          "HistogramFixedWidth", "Imag", "IsFinite", "IsInf", "IsNan",
          "Quantize", "QuantizeDownAndShrinkRange", "QuantizeV2",
          "QuantizedInstanceNorm", "QuantizedRelu", "QuantizedRelu6",
          "QuantizedReluX", "Real", "Requantize"}));
  return kCastLikeOps->count(node.op()) > 0;
}

bool IsCheckNumerics(const NodeDef& node) {
  return node.op() == "CheckNumerics";
}

bool IsCollective(const NodeDef& node) {
  return node.op() == "CollectiveReduce" ||
         node.op() == "CollectiveBcastSend" ||
         node.op() == "CollectiveBcastRecv";
}

bool IsComplex(const NodeDef& node) { return node.op() == "Complex"; }

bool IsComplexAbs(const NodeDef& node) { return node.op() == "ComplexAbs"; }

bool IsConcat(const NodeDef& node) {
  return node.op() == "Concat" || node.op() == "ConcatV2";
}

bool IsConcatOffset(const NodeDef& node) { return node.op() == "ConcatOffset"; }

bool IsConstant(const NodeDef& node) { return node.op() == "Const"; }

bool IsConj(const NodeDef& node) { return node.op() == "Conj"; }

bool IsConjugateTranspose(const NodeDef& node) {
  return node.op() == "ConjugateTranspose";
}

bool IsControlFlow(const NodeDef& node) {
  // clang-format off
  return node.op() == "ControlTrigger" ||
         node.op() == "Enter" ||
         node.op() == "Exit" ||
         node.op() == "LoopCond" ||
         node.op() == "Merge" ||
         node.op() == "NextIteration" ||
         node.op() == "Switch";
  // clang-format on
}

bool IsConv2D(const NodeDef& node) { return node.op() == "Conv2D"; }

bool IsConv2DBackpropFilter(const NodeDef& node) {
  return node.op() == "Conv2DBackpropFilter";
}

bool IsConv2DBackpropInput(const NodeDef& node) {
  return node.op() == "Conv2DBackpropInput";
}

bool IsConv3D(const NodeDef& node) { return node.op() == "Conv3D"; }

bool IsDepthwiseConv2dNative(const NodeDef& node) {
  return node.op() == "DepthwiseConv2dNative";
}

bool IsDepthwiseConv2dNativeBackpropFilter(const NodeDef& node) {
  return node.op() == "DepthwiseConv2dNativeBackpropFilter";
}

bool IsDepthwiseConv2dNativeBackpropInput(const NodeDef& node) {
  return node.op() == "DepthwiseConv2dNativeBackpropInput";
}

bool IsDequeueOp(const NodeDef& node) {
  const auto& op = node.op();
  return op == "QueueDequeueManyV2" || op == "QueueDequeueMany" ||
         op == "QueueDequeueV2" || op == "QueueDequeue" ||
         op == "QueueDequeueUpToV2" || op == "QueueDequeueUpTo";
}

bool IsDiv(const NodeDef& node) { return node.op() == "Div"; }

// Returns true if node represents a unary elementwise function that is
// monotonic. If the function returns true and is_non_decreasing is non-null,
// *is_non_decreasing is set to true for non-decreasing functions (e.g. sqrt,
// exp) and to false for non-increasing functions (e.g. neg, rsqrt).
bool IsElementWiseMonotonic(const NodeDef& node, bool* is_non_decreasing) {
  static const gtl::FlatSet<string>* const kMonotonicNonDecreasingOps =
      CHECK_NOTNULL((new gtl::FlatSet<string>{
          "Acosh", "Asin", "Asinh", "Atan", "Atanh", "Ceil",
          "Elu", "Erf", "Exp", "Expm1", "Floor", "Log",
          "Log1p", "Relu", "Relu6", "Rint", "Selu", "Sigmoid",
          "Sign", "Sinh", "Softsign", "Softplus", "Sqrt", "Tanh",
      }));
  static const gtl::FlatSet<string>* const kMonotonicNonIncreasingOps =
      CHECK_NOTNULL((new gtl::FlatSet<string>{"Acos", "Erfc", "Neg", "Rsqrt"}));
  if (kMonotonicNonDecreasingOps->count(node.op()) > 0) {
    if (is_non_decreasing) {
      *is_non_decreasing = true;
    }
    return true;
  } else if (kMonotonicNonIncreasingOps->count(node.op()) > 0) {
    if (is_non_decreasing) {
      *is_non_decreasing = false;
    }
    return true;
  }
  return false;
}
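
// Illustrative usage sketch (not part of the API surface): callers typically
// check both the return value and the reported direction, e.g.
//   bool non_decreasing = false;
//   if (IsElementWiseMonotonic(node, &non_decreasing) && non_decreasing) {
//     // e.g. consider reordering `node` with a Max-style reduction.
//   }
// Passing nullptr is fine when only monotonicity itself matters.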

bool IsEluGrad(const NodeDef& node) { return node.op() == "EluGrad"; }

bool IsEnter(const NodeDef& node) {
  const auto& op = node.op();
  return op == "Enter" || op == "RefEnter";
}

bool IsEqual(const NodeDef& node) { return node.op() == "Equal"; }

bool IsExit(const NodeDef& node) {
  const auto& op = node.op();
  return op == "Exit" || op == "RefExit";
}

bool IsExp(const NodeDef& node) { return node.op() == "Exp"; }

bool IsFakeParam(const NodeDef& node) { return node.op() == "FakeParam"; }

bool IsFill(const NodeDef& node) { return node.op() == "Fill"; }

bool IsFloorDiv(const NodeDef& node) { return node.op() == "FloorDiv"; }

bool IsFloorMod(const NodeDef& node) { return node.op() == "FloorMod"; }

bool IsFusedBatchNorm(const NodeDef& node) {
  const auto& op = node.op();
  return op == "FusedBatchNorm" || op == "FusedBatchNormV2";
}

bool IsFusedBatchNormGrad(const NodeDef& node) {
  const auto& op = node.op();
  return op == "FusedBatchNormGrad" || op == "FusedBatchNormGradV2";
}

bool IsGreater(const NodeDef& node) { return node.op() == "Greater"; }

bool IsGreaterEqual(const NodeDef& node) { return node.op() == "GreaterEqual"; }

bool IsHostConstant(const NodeDef& node) { return node.op() == "HostConst"; }

bool IsHistogramSummary(const NodeDef& node) {
  return node.op() == "HistogramSummary";
}

bool IsIdentity(const NodeDef& node) {
  const auto& op = node.op();
  return op == "Identity" || op == "RefIdentity";
}

bool IsIdentityN(const NodeDef& node) {
  const auto& op = node.op();
  return op == "IdentityN";
}

bool IsIdentityNSingleInput(const NodeDef& node) {
  return IsIdentityN(node) && node.attr().count("T") != 0 &&
         node.attr().at("T").list().type_size() == 1;
}

bool IsIf(const NodeDef& node) {
  const auto& op = node.op();
  return op == "If" || op == "StatelessIf";
}

bool IsIgamma(const NodeDef& node) { return node.op() == "Igamma"; }

bool IsIgammac(const NodeDef& node) { return node.op() == "Igammac"; }

bool IsImag(const NodeDef& node) { return node.op() == "Imag"; }

bool IsImmutableConst(const NodeDef& node) {
  return node.op() == "ImmutableConst";
}

bool IsInvGrad(const NodeDef& node) { return node.op() == "InvGrad"; }

bool IsLess(const NodeDef& node) { return node.op() == "Less"; }

bool IsLessEqual(const NodeDef& node) { return node.op() == "LessEqual"; }

bool IsLog(const NodeDef& node) { return node.op() == "Log"; }

bool IsLogicalAnd(const NodeDef& node) { return node.op() == "LogicalAnd"; }

bool IsLogicalNot(const NodeDef& node) { return node.op() == "LogicalNot"; }

bool IsLogicalOr(const NodeDef& node) { return node.op() == "LogicalOr"; }

bool IsMatMul(const NodeDef& node) {
  const auto& op = node.op();
  return op == "MatMul" || op == "BatchMatMul" || op == "SparseMatMul" ||
         IsQuantizedMatMul(node);
}

bool IsMax(const NodeDef& node) { return node.op() == "Max"; }

bool IsMaximum(const NodeDef& node) { return node.op() == "Maximum"; }

bool IsMaxPoolGrad(const NodeDef& node) { return node.op() == "MaxPoolGrad"; }

bool IsMean(const NodeDef& node) { return node.op() == "Mean"; }

bool IsMerge(const NodeDef& node) {
  const auto& op = node.op();
  return op == "Merge" || op == "RefMerge";
}

bool IsMin(const NodeDef& node) { return node.op() == "Min"; }

bool IsMinimum(const NodeDef& node) { return node.op() == "Minimum"; }

bool IsMirrorPad(const NodeDef& node) { return node.op() == "MirrorPad"; }

bool IsMirrorPadGrad(const NodeDef& node) {
  return node.op() == "MirrorPadGrad";
}

bool IsMod(const NodeDef& node) { return node.op() == "Mod"; }

bool IsMul(const NodeDef& node) { return node.op() == "Mul"; }

bool IsNeg(const NodeDef& node) { return node.op() == "Neg"; }

bool IsNoOp(const NodeDef& node) { return node.op() == "NoOp"; }

bool IsNotEqual(const NodeDef& node) { return node.op() == "NotEqual"; }

bool IsNextIteration(const NodeDef& node) {
  const auto& op = node.op();
  return op == "NextIteration" || op == "RefNextIteration";
}

bool IsOnesLike(const NodeDef& node) { return node.op() == "OnesLike"; }

bool IsPack(const NodeDef& node) { return node.op() == "Pack"; }

bool IsPad(const NodeDef& node) {
  const auto& op = node.op();
  return op == "Pad" || op == "PadV2";
}

bool IsPartitionedCall(const NodeDef& node) {
  return node.op() == "PartitionedCall";
}

bool IsPlaceholder(const NodeDef& node) {
  const auto& op = node.op();
  return op == "Placeholder" || op == "PlaceholderV2" ||
         op == "PlaceholderWithDefault";
}

bool IsPolygamma(const NodeDef& node) { return node.op() == "Polygamma"; }

bool IsPow(const NodeDef& node) { return node.op() == "Pow"; }

bool IsPrint(const NodeDef& node) {
  return node.op() == "Print" || node.op() == "PrintV2";
}

bool IsProd(const NodeDef& node) { return node.op() == "Prod"; }

bool IsQuantizedMatMul(const NodeDef& node) {
  return node.op() == "QuantizedMatMul" || node.op() == "QuantizedMatMulV2";
}

bool IsQueue(const NodeDef& node) {
  return str_util::EndsWith(node.op(), "QueueV2");
}

bool IsRandomShuffle(const NodeDef& node) {
  return node.op() == "RandomShuffle";
}

bool IsRank(const NodeDef& node) { return node.op() == "Rank"; }

bool IsReadVariableOp(const NodeDef& node) {
  return node.op() == "ReadVariableOp";
}

bool IsReal(const NodeDef& node) { return node.op() == "Real"; }

bool IsRealDiv(const NodeDef& node) { return node.op() == "RealDiv"; }

bool IsReciprocalGrad(const NodeDef& node) {
  return node.op() == "ReciprocalGrad";
}

bool IsRecv(const NodeDef& node) {
  return node.op() == "_Recv" || node.op() == "_HostRecv";
}

bool IsReduction(const NodeDef& node) {
  const auto& op = node.op();
  return op == "Sum" || op == "Prod" || op == "Min" || op == "Max" ||
         op == "Mean" || op == "Any" || op == "All";
}

bool IsRelu(const NodeDef& node) { return node.op() == "Relu"; }

bool IsReluGrad(const NodeDef& node) { return node.op() == "ReluGrad"; }

bool IsRelu6Grad(const NodeDef& node) { return node.op() == "Relu6Grad"; }

bool IsReshape(const NodeDef& node) { return (node.op() == "Reshape"); }

bool IsRestore(const NodeDef& node) {
  return (node.op() == "Restore" || node.op() == "RestoreV2" ||
          node.op() == "RestoreSlice");
}

bool IsReverse(const NodeDef& node) {
  return node.op() == "Reverse" || node.op() == "ReverseV2";
}

bool IsReverseV2(const NodeDef& node) { return node.op() == "ReverseV2"; }

bool IsRsqrt(const NodeDef& node) { return node.op() == "Rsqrt"; }

bool IsRsqrtGrad(const NodeDef& node) { return node.op() == "RsqrtGrad"; }

bool IsSelect(const NodeDef& node) { return node.op() == "Select"; }

bool IsSeluGrad(const NodeDef& node) { return node.op() == "SeluGrad"; }

bool IsSend(const NodeDef& node) {
  return node.op() == "_Send" || node.op() == "_HostSend";
}

bool IsShape(const NodeDef& node) { return node.op() == "Shape"; }

bool IsShapeN(const NodeDef& node) { return node.op() == "ShapeN"; }

bool IsShuffle(const NodeDef& node) { return node.op() == "Shuffle"; }

bool IsSigmoidGrad(const NodeDef& node) { return node.op() == "SigmoidGrad"; }

bool IsSize(const NodeDef& node) { return node.op() == "Size"; }

bool IsSlice(const NodeDef& node) { return node.op() == "Slice"; }

bool IsSnapshot(const NodeDef& node) { return node.op() == "Snapshot"; }

bool IsSoftmax(const NodeDef& node) { return node.op() == "Softmax"; }

bool IsSoftplusGrad(const NodeDef& node) { return node.op() == "SoftplusGrad"; }

bool IsSoftsignGrad(const NodeDef& node) { return node.op() == "SoftsignGrad"; }

bool IsSplit(const NodeDef& node) { return node.op() == "Split"; }

bool IsSplitV(const NodeDef& node) { return node.op() == "SplitV"; }

bool IsSqrt(const NodeDef& node) { return node.op() == "Sqrt"; }

bool IsSqrtGrad(const NodeDef& node) { return node.op() == "SqrtGrad"; }

bool IsSquare(const NodeDef& node) { return node.op() == "Square"; }

bool IsSquaredDifference(const NodeDef& node) {
  return node.op() == "SquaredDifference";
}

bool IsSqueeze(const NodeDef& node) { return node.op() == "Squeeze"; }

bool IsStackOp(const NodeDef& node) {
  return node.op() == "Stack" || node.op() == "StackV2";
}
bool IsStackCloseOp(const NodeDef& node) {
  return node.op() == "StackClose" || node.op() == "StackCloseV2";
}
bool IsStackPushOp(const NodeDef& node) {
  return node.op() == "StackPush" || node.op() == "StackPushV2";
}
bool IsStackPopOp(const NodeDef& node) {
  return node.op() == "StackPop" || node.op() == "StackPopV2";
}

bool IsStatefulPartitionedCall(const NodeDef& node) {
  return node.op() == "StatefulPartitionedCall";
}

bool IsStopGradient(const NodeDef& node) {
  const auto& op = node.op();
  return op == "StopGradient" || op == "PreventGradient";
}

bool IsStridedSlice(const NodeDef& node) { return node.op() == "StridedSlice"; }

bool IsStridedSliceGrad(const NodeDef& node) {
  return node.op() == "StridedSliceGrad";
}

bool IsSub(const NodeDef& node) { return node.op() == "Sub"; }

bool IsSum(const NodeDef& node) { return node.op() == "Sum"; }

bool IsSwitch(const NodeDef& node) {
  const auto& op = node.op();
  return op == "Switch" || op == "RefSwitch";
}

bool IsSymbolicGradient(const NodeDef& node) {
  return node.op() == "SymbolicGradient";
}

bool IsTanhGrad(const NodeDef& node) { return node.op() == "TanhGrad"; }

bool IsTensorArray(const NodeDef& node) {
  static const gtl::FlatSet<string>* const kTensorArrayOps =
      CHECK_NOTNULL((new gtl::FlatSet<string>{
          "TensorArray",
          "TensorArrayV2",
          "TensorArrayV3",
          "TensorArrayGrad",
          "TensorArrayGradV2",
          "TensorArrayGradV3",
          "TensorArrayGradWithShape",
          "TensorArrayWrite",
          "TensorArrayWriteV2",
          "TensorArrayWriteV3",
          "TensorArrayRead",
          "TensorArrayReadV2",
          "TensorArrayReadV3",
          "TensorArrayConcat",
          "TensorArrayConcatV2",
          "TensorArrayConcatV3",
          "TensorArraySplit",
          "TensorArraySplitV2",
          "TensorArraySplitV3",
          "TensorArraySize",
          "TensorArraySizeV2",
          "TensorArraySizeV3",
          "TensorArrayClose",
          "TensorArrayCloseV2",
          "TensorArrayCloseV3",
      }));
  return kTensorArrayOps->count(node.op()) > 0;
}

bool IsTile(const NodeDef& node) { return node.op() == "Tile"; }

bool IsTranspose(const NodeDef& node) { return node.op() == "Transpose"; }

bool IsTruncateDiv(const NodeDef& node) { return node.op() == "TruncateDiv"; }

bool IsTruncateMod(const NodeDef& node) { return node.op() == "TruncateMod"; }

bool IsUnpack(const NodeDef& node) { return node.op() == "Unpack"; }

bool IsVariable(const NodeDef& node) {
  const auto& op = node.op();
  return op == "Variable" || op == "VariableV2" || op == "AutoReloadVariable" ||
         op == "VarHandleOp" || op == "ReadVariableOp";
}

bool IsWhile(const NodeDef& node) {
  const auto& op = node.op();
  return op == "While" || op == "StatelessWhile";
}

bool IsZerosLike(const NodeDef& node) { return node.op() == "ZerosLike"; }

bool IsZeta(const NodeDef& node) { return node.op() == "Zeta"; }

namespace {
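// Returns true iff the boolean attribute `name` is present on the node and set
// to true.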
bool GetBoolAttr(const NodeDef& node, const string& name) {
  return node.attr().count(name) > 0 && node.attr().at(name).b();
}
}  // namespace

bool IsPersistent(const NodeDef& node) {
  return IsConstant(node) || IsVariable(node) || IsHostConstant(node);
}

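// Conservatively returns true if the node might take a ref-typed input; if the
// op cannot be found in the global registry we assume it does.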
bool MaybeHasRefInput(const NodeDef& node) {
  const OpDef* op_def;
  Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
  if (!status.ok()) {
    return true;
  }
  // Nodes such as Assign or AssignAdd modify one of their inputs.
  for (const auto& input : op_def->input_arg()) {
    if (input.is_ref()) {
      return true;
    }
  }
  return false;
}

bool IsDataset(const NodeDef& node) {
  const string& op = node.op();
  // See `GetNodeClassForOp` in core/graph/graph.cc.
  return op == "IteratorGetNext" || op == "IteratorGetNextSync" ||
         op == "DatasetToSingleElement" || op == "ReduceDataset";
}

bool IsStateful(const NodeDef node, const OpRegistryInterface* op_registry) {
  const OpDef* op_def = nullptr;
  const string& op_name = node.op();
  Status status = op_registry->LookUpOpDef(op_name, &op_def);
  if (!status.ok()) {
    LOG(WARNING) << "Failed to lookup OpDef for " << op_name
                 << ". Error: " << status.error_message();
    return false;
  }
  return op_def->is_stateful();
}

bool IsStateful(const NodeDef node) {
  return IsStateful(node, OpRegistry::Global());
}

bool IsFreeOfSideEffect(const NodeDef& node,
                        const OpRegistryInterface* op_registry) {
  // Placeholders must be preserved to keep the graph feedable.
  if (IsPlaceholder(node)) {
    return false;
  }
  const OpDef* op_def = nullptr;
  const string& op_name = node.op();
  Status status = op_registry->LookUpOpDef(op_name, &op_def);
  if (!status.ok()) {
    return false;
  }
  if (op_def->is_stateful()) {
    return false;
  }
  // Nodes such as Assign or AssignAdd modify one of their inputs.
  for (const auto& input : op_def->input_arg()) {
    if (input.is_ref()) {
      return false;
    }
  }
  // Queue ops modify the queue which is a side effect.
  if (node.op().find("Queue") != string::npos) {
    return false;
  }
  // Sending a tensor via a network is a side effect.
  if (IsSend(node)) {
    return false;
  }
  return !ModifiesInputsInPlace(node);
}

bool IsFreeOfSideEffect(const NodeDef& node) {
  return IsFreeOfSideEffect(node, OpRegistry::Global());
}

bool ModifiesInputsInPlace(const NodeDef& node) {
  // Some nodes do in-place updates on regular tensor inputs.
  string op_name = node.op();

  // Ops that modify resource variables effectively modify one of their inputs.
  if (op_name == "AssignVariableOp" || op_name == "AssignAddVariableOp" ||
      op_name == "AssignSubVariableOp" || op_name == "ResourceScatterUpdate" ||
      op_name == "ResourceScatterAdd" || op_name == "ResourceScatterSub" ||
      op_name == "ResourceScatterMul" || op_name == "ResourceScatterDiv" ||
      op_name == "ResourceScatterMin" || op_name == "ResourceScatterMax") {
    return false;
  }

  std::transform(op_name.begin(), op_name.end(), op_name.begin(), ::tolower);
  if (str_util::StrContains(op_name, "inplace")) {
    return true;
  }
  return GetBoolAttr(node, "in_place") || GetBoolAttr(node, "inplace");
}

bool ModifiesFrameInfo(const NodeDef& node) {
  return IsEnter(node) || IsExit(node) || IsNextIteration(node);
}

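// Expands to an Is<Property>() helper (used below for IsAggregate and
// IsCommutative) that consults the OpDef in the global registry, with a
// special case for "Add", which is treated as commutative and aggregate for
// all types except strings.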
#define OPDEF_PROPERTY_HELPER(PROPERTY_CAP, PROPERTY)                      \
  bool Is##PROPERTY_CAP(const NodeDef& node) {                             \
    if (node.op() == "Add") {                                              \
      /* Workaround for "Add" not being marked is_commutative and */       \
      /* is_aggregate. (See cl/173915048). */                              \
      const auto type = GetDataTypeFromAttr(node, "T");                    \
      return type != DT_INVALID && type != DT_STRING;                      \
    }                                                                      \
    const OpDef* op_def = nullptr;                                         \
    Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def); \
    return status.ok() && op_def->is_##PROPERTY();                         \
  }

OPDEF_PROPERTY_HELPER(Aggregate, aggregate)
OPDEF_PROPERTY_HELPER(Commutative, commutative)

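// An involution is an op that is its own inverse, i.e. f(f(x)) == x.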
bool IsInvolution(const NodeDef& node) {
  static const gtl::FlatSet<string>* const kInvolutionOps =
      CHECK_NOTNULL((new gtl::FlatSet<string>{"Conj", "Reciprocal", "Invert",
                                              "Neg", "LogicalNot"}));
  return kInvolutionOps->count(node.op()) > 0;
}

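// The three "preserving" predicates below form a hierarchy: ops that preserve
// value, order and shape are a subset of those that preserve value and order,
// which in turn are a subset of those that merely preserve values.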
bool IsValueAndOrderAndShapePreserving(const NodeDef& node) {
  if (NumNonControlInputs(node) == 1 && IsAggregate(node)) {
    return true;
  }
  static const gtl::FlatSet<string>* const kValueAndOrderAndShapePreservingOps =
      CHECK_NOTNULL((new const gtl::FlatSet<string>{
          "CheckNumerics",
          "DebugGradientIdentity",
          "DeepCopy",
          "Enter",
          "Exit",
          "PreventGradient",
          "Print",
          "Snapshot",
          "StopGradient",
      }));
  return kValueAndOrderAndShapePreservingOps->count(node.op()) > 0 ||
         IsIdentity(node);
}

bool IsValueAndOrderPreserving(const NodeDef& node) {
  if (NumNonControlInputs(node) == 1 && IsAggregate(node)) {
    return true;
  }
  static const gtl::FlatSet<string>* const kValueAndOrderPreservingOps =
      CHECK_NOTNULL((new const gtl::FlatSet<string>{
          "ExpandDims",
          "Reshape",
          "Squeeze",
      }));
  return kValueAndOrderPreservingOps->count(node.op()) > 0 ||
         IsValueAndOrderAndShapePreserving(node);
}

bool IsValuePreserving(const NodeDef& node) {
  static const gtl::FlatSet<string>* const kValuePreservingOps =
      CHECK_NOTNULL((new gtl::FlatSet<string>{
          "InvertPermutation",
          "Reverse",
          "ReverseV2",
          "Roll",
          "Transpose",
          "DepthToSpace",
          "SpaceToDepth",
          "BatchToSpace",
          "BatchToSpaceND",
          "SpaceToBatch",
          "SpaceToBatchND",
      }));
  return IsValueAndOrderPreserving(node) ||
         kValuePreservingOps->count(node.op()) > 0;
}

bool IsUnaryElementWise(const NodeDef& node) {
  static const gtl::FlatSet<string>* const kElementWiseOps =
      CHECK_NOTNULL((new gtl::FlatSet<string>{
          "Abs",
          "Acos",
          "Acosh",
          "Asin",
          "Asinh",
          "Atan",
          "Atanh",
          "Ceil",
          "ComplexAbs",
          "Conj",
          "Cos",
          "Cosh",
          "Digamma",
          "Elu",
          "Erf",
          "Erfc",
          "Exp",
          "Expm1",
          "Floor",
          "Inv",
          "Invert",
          "IsInf",
          "IsNan",
          "IsFinite",
          "Lgamma",
          "Log",
          "Log1p",
          "LogicalNot",
          "Neg",
          "Reciprocal",
          "Relu",
          "Relu6",
          "Rint",
          "Round",
          "Selu",
          "Rsqrt",
          "Sigmoid",
          "Sign",
          "Sin",
          "Sinh",
          "Softplus",
          "Softsign",
          "Sqrt",
          "Square",
          "Tan",
          "Tanh",
      }));
  return kElementWiseOps->count(node.op()) > 0 ||
         IsValueAndOrderAndShapePreserving(node);
}

bool HasOpDef(const NodeDef& node) {
  const OpDef* op_def = nullptr;
  return OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok();
}

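// Heuristic: a node is treated as idempotent if it preserves its input's
// value, order and shape, has no side effects, and does not modify frame info.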
bool IsIdempotent(const NodeDef& node) {
  return IsValueAndOrderAndShapePreserving(node) && IsFreeOfSideEffect(node) &&
         !ModifiesFrameInfo(node);
}

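// Returns true for ops that never forward (alias) any of their input buffers
// to an output, either because they are in the explicit list below or because
// their name indicates a segment or quantize op.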
bool NeverForwardsInputs(const NodeDef& node) {
  static const gtl::FlatSet<string>* const kNonForwardingOps = CHECK_NOTNULL(
      (new gtl::FlatSet<string>{"ArgMax",
                                "ArgMin",
                                "AudioSpectrogram",
                                "BatchMatMul",
                                "BatchToSpace",
                                "BatchToSpaceND",
                                "Bincount",
                                "BroadcastArgs",
                                "BroadcastGradientArgs",
                                "CTCBeamSearchDecoder",
                                "CTCGreedyDecoder",
                                "CTCLoss",
                                "ComplexAbs",
                                "Concat",
                                "ConcatOffset",
                                "ConcatV2",
                                "Copy",
                                "CopyHost",
                                "Cross",
                                "CudnnRNN",
                                "CudnnRNNBackprop",
                                "CudnnRNNBackpropV2",
                                "CudnnRNNBackpropV3",
                                "CudnnRNNCanonicalToParams",
                                "CudnnRNNParamsSize",
                                "CudnnRNNParamsToCanonical",
                                "CudnnRNNV2",
                                "CudnnRNNV3",
                                "Cumsum",
                                "Cumprod",
                                "DebugNanCount",
                                "DebugNumericSummary",
                                "DecodeProtoV2",
                                "DecodeWav",
                                "DeepCopy",
                                "DepthToSpace",
                                "Dequantize",
                                "Diag",
                                "DiagPart",
                                "EditDistance",
                                "Empty",
                                "EncodeProtoV2",
                                "EncodeWav",
                                "ExtractImagePatches",
                                "ExtractVolumePatches",
                                "Fill",
                                "Gather",
                                "GatherNd",
                                "GatherV2",
                                "HistogramFixedWidth",
                                "InvertPermutation",
                                "IsInf",
                                "IsNan",
                                "IsFinite",
                                "LinSpace",
                                "LowerBound",
                                "MatMul",
                                "MatrixDiag",
                                "MatrixDiagPart",
                                "Mfcc",
                                "OneHot",
                                "Pack",
                                "PopulationCount",
                                "Range",
                                "Rank",
                                "ReverseSequence",
                                "Shape",
                                "ShapeN",
                                "Size",
                                "SpaceToBatch",
                                "SpaceToBatchND",
                                "SpaceToDepth",
                                "SparseMatMul",
                                "Split",
                                "SplitV",
                                "Unique",
                                "UniqueV2",
                                "UniqueWithCounts",
                                "UniqueWithCountsV2",
                                "Unpack",
                                "UnravelIndex",
                                "UpperBound",
                                "Where",
                                "CompareAndBitpack",
                                "Requantize",
                                "RequantizationRange",
                                "Bucketize",
                                "AvgPool",
                                "BatchNormWithGlobalNormalization",
                                "FusedBatchNorm",
                                "FusedBatchNormV2",
                                "Conv2D",
                                "RandomUniform",
                                "RandomUniformInt",
                                "RandomStandardNormal",
                                "ParameterizedTruncatedNormal",
                                "TruncatedNormal",
                                "Multinomial",
                                "RandomGamma",
                                "RandomPoisson",
                                "RandomPoissonV2"}));
  const string& op_name = node.op();
  return kNonForwardingOps->count(op_name) > 0 ||
         str_util::StrContains(op_name, "Segment") ||
         str_util::StartsWith(op_name, "Quantize");
}

}  // namespace grappler
}  // end namespace tensorflow