1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/delegates/flex/allowlisted_flex_ops.h"
16
17 #include <set>
18
19 #include "tensorflow/core/framework/op.h"
20 #include "tensorflow/lite/delegates/flex/allowlisted_flex_ops_internal.h"
21
22 namespace tflite {
23 namespace flex {
24
// Returns the set of TensorFlow op names held in the flex allowlist below.
// The set is allocated with `new` the first time the function runs and is not
// deleted here; every call returns a reference to that same object.
const std::set<std::string>& GetFlexAllowlist() {
  // LINT.IfChange
  static const std::set<std::string>& flex_op_names =
      *new std::set<std::string>{
          // go/keep-sorted start
          "Abort",
          "Abs",
          "Add",
          "AddN",
          "AddV2",
          "AdjustContrast",
          "AdjustContrastv2",
          "AdjustHue",
          "AdjustSaturation",
          "All",
          "Angle",
          "Any",
          "ApplyAdaMax",
          "ApplyAdadelta",
          "ApplyAdagrad",
          "ApplyAdagradDA",
          "ApplyAdagradV2",
          "ApplyAdam",
          "ApplyAddSign",
          "ApplyCenteredRMSProp",
          "ApplyFtrl",
          "ApplyFtrlV2",
          "ApplyGradientDescent",
          "ApplyMomentum",
          "ApplyPowerSign",
          "ApplyProximalAdagrad",
          "ApplyProximalGradientDescent",
          "ApplyRMSProp",
          "ApproximateEqual",
          "ArgMax",
          "ArgMin",
          "Assert",
          "Assign",
          "AssignAdd",
          "AssignAddVariableOp",
          "AssignSub",
          "AssignSubVariableOp",
          "AssignVariableOp",
          "Atan",
          "Atan2",
          "AudioSpectrogram",
          "AvgPool",
          "AvgPool3D",
          "AvgPool3DGrad",
          "AvgPoolGrad",
          "BatchCholesky",
          "BatchMatMul",
          "BatchMatMulV2",
          "BatchMatrixDeterminant",
          "BatchMatrixDiag",
          "BatchMatrixDiagPart",
          "BatchMatrixInverse",
          "BatchMatrixSetDiag",
          "BatchMatrixTriangularSolve",
          "BatchNormWithGlobalNormalization",
          "BatchNormWithGlobalNormalizationGrad",
          "BatchToSpace",
          "BatchToSpaceND",
          "BiasAdd",
          "BiasAddGrad",
          "BiasAddV1",
          "Bincount",
          "Bitcast",
          "BitwiseAnd",
          "BitwiseOr",
          "BitwiseXor",
          "BoostedTreesBucketize",
          "BoostedTreesCreateQuantileStreamResource",
          "BoostedTreesFlushQuantileSummaries",
          "BoostedTreesMakeQuantileSummaries",
          "BoostedTreesQuantileStreamResourceAddSummaries",
          "BoostedTreesQuantileStreamResourceDeserialize",
          "BoostedTreesQuantileStreamResourceFlush",
          "BoostedTreesQuantileStreamResourceGetBucketBoundaries",
          "BoostedTreesQuantileStreamResourceHandleOp",
          "BroadcastArgs",
          "BroadcastGradientArgs",
          "BroadcastTo",
          "Bucketize",
          "CTCBeamSearchDecoder",
          "CTCGreedyDecoder",
          "Cast",
          "Ceil",
          "CheckNumerics",
          "CheckNumericsV2",
          "Cholesky",
          "CombinedNonMaxSuppression",
          "Complex",
          "ComplexAbs",
          "Concat",
          "ConcatOffset",
          "ConcatV2",
          "Conj",
          "ConjugateTranspose",
          "Const",
          "ControlTrigger",
          "Conv2D",
          "Conv2DBackpropFilter",
          "Conv2DBackpropInput",
          "Conv3D",
          "Conv3DBackpropFilter",
          "Conv3DBackpropFilterV2",
          "Conv3DBackpropInput",
          "Conv3DBackpropInputV2",
          "Cos",
          "Cosh",
          "CropAndResize",
          "CropAndResizeGradBoxes",
          "CropAndResizeGradImage",
          "Cumprod",
          "Cumsum",
          "CumulativeLogsumexp",
          "DataFormatDimMap",
          "DataFormatVecPermute",
          "DebugGradientIdentity",
          "DebugGradientRefIdentity",
          "DecodeAndCropJpeg",
          "DecodeBase64",
          "DecodeBmp",
          "DecodeGif",
          "DecodeImage",
          "DecodeJpeg",
          "DecodePng",
          "DecodeRaw",
          "DecodeWav",
          "DeepCopy",
          "DeleteSessionTensor",
          "DenseBincount",
          "DenseToDenseSetOperation",
          "DenseToSparseSetOperation",
          "DepthToSpace",
          "DepthwiseConv2dNative",
          "DepthwiseConv2dNativeBackpropFilter",
          "DepthwiseConv2dNativeBackpropInput",
          "Dequantize",
          "DestroyResourceOp",
          "DestroyTemporaryVariable",
          "Diag",
          "DiagPart",
          "Dilation2D",
          "Dilation2DBackpropFilter",
          "Dilation2DBackpropInput",
          "Div",
          "DivNoNan",
          "DynamicPartition",
          "DynamicStitch",
          "Einsum",
          "Elu",
          "EluGrad",
          "Empty",
          "EmptyTensorList",
          "EmptyTensorMap",
          "EncodeBase64",
          "EncodeJpeg",
          "EncodeJpegVariableQuality",
          "EncodePng",
          "EncodeWav",
          "EnsureShape",
          "Enter",
          "Equal",
          "Erf",
          "Exit",
          "Exp",
          "ExpandDims",
          "ExtractImagePatches",
          "FFT",
          "FFT2D",
          "FFT3D",
          "FIFOQueue",
          "FIFOQueueV2",
          "FakeQuantWithMinMaxArgs",
          "FakeQuantWithMinMaxArgsGradient",
          "FakeQuantWithMinMaxVars",
          "FakeQuantWithMinMaxVarsGradient",
          "FakeQuantWithMinMaxVarsPerChannel",
          "FakeQuantWithMinMaxVarsPerChannelGradient",
          "FakeQueue",
          "Fill",
          "Fingerprint",
          "Floor",
          "FloorDiv",
          "FloorMod",
          "FusedBatchNorm",
          "FusedBatchNormGrad",
          "FusedBatchNormGradV2",
          "FusedBatchNormGradV3",
          "FusedBatchNormV2",
          "FusedBatchNormV3",
          "FusedPadConv2D",
          "FusedResizeAndPadConv2D",
          "Gather",
          "GatherNd",
          "GatherV2",
          "GetSessionHandle",
          "GetSessionHandleV2",
          "GetSessionTensor",
          "Greater",
          "GreaterEqual",
          "HistogramSummary",
          "IFFT",
          "IFFT2D",
          "IFFT3D",
          "IRFFT",
          "IRFFT2D",
          "IRFFT3D",
          "Identity",
          "IdentityN",
          "Imag",
          "ImageProjectiveTransformV2",
          "ImageProjectiveTransformV3",
          "ImmutableConst",
          "InTopK",
          "InTopKV2",
          "InplaceAdd",
          "InplaceSub",
          "InplaceUpdate",
          "Inv",
          "InvGrad",
          "Invert",
          "InvertPermutation",
          "IsBoostedTreesQuantileStreamResourceInitialized",
          "IsFinite",
          "IsNan",
          "IsVariableInitialized",
          "LRN",
          "LeakyRelu",
          "LeakyReluGrad",
          "LeftShift",
          "Less",
          "LessEqual",
          "LinSpace",
          "ListDiff",
          "Log",
          "LogMatrixDeterminant",
          "LogSoftmax",
          "LogicalAnd",
          "LogicalNot",
          "LogicalOr",
          "LoopCond",
          "MatMul",
          "MatrixDeterminant",
          "MatrixDiag",
          "MatrixDiagPart",
          "MatrixDiagPartV2",
          "MatrixDiagPartV3",
          "MatrixDiagV2",
          "MatrixDiagV3",
          "MatrixInverse",
          "MatrixSetDiag",
          "MatrixSetDiagV2",
          "MatrixSetDiagV3",
          "MatrixTriangularSolve",
          "Max",
          "MaxPool",
          "MaxPool3D",
          "MaxPool3DGrad",
          "MaxPool3DGradGrad",
          "MaxPoolGrad",
          "MaxPoolGradGrad",
          "MaxPoolGradGradV2",
          "MaxPoolGradV2",
          "MaxPoolGradWithArgmax",
          "MaxPoolV2",
          "MaxPoolWithArgmax",
          "Maximum",
          "Mean",
          "Merge",
          "MergeSummary",
          "MergeV2Checkpoints",
          "Mfcc",
          "Min",
          "Minimum",
          "MirrorPad",
          "MirrorPadGrad",
          "Mul",
          "MulNoNan",
          "Multinomial",
          "Neg",
          "NextIteration",
          "NoOp",
          "NonMaxSuppression",
          "NonMaxSuppressionV2",
          "NonMaxSuppressionV3",
          "NonMaxSuppressionV4",
          "NonMaxSuppressionV5",
          "NonMaxSuppressionWithOverlaps",
          "NotEqual",
          "OneHot",
          "OnesLike",
          "Pack",
          "Pad",
          "PadV2",
          "PaddingFIFOQueue",
          "PaddingFIFOQueueV2",
          "ParallelConcat",
          "ParallelDynamicStitch",
          "ParseExample",
          "ParseExampleV2",
          "ParseSequenceExample",
          "ParseSequenceExampleV2",
          "ParseSingleExample",
          "ParseSingleSequenceExample",
          "Placeholder",
          "PlaceholderV2",
          "PlaceholderWithDefault",
          "PopulationCount",
          "Pow",
          "PreventGradient",
          "Print",
          "PrintV2",
          "Prod",
          "QuantizeDownAndShrinkRange",
          "QuantizeV2",
          "QuantizedAdd",
          "QuantizedAvgPool",
          "QuantizedBatchNormWithGlobalNormalization",
          "QuantizedBiasAdd",
          "QuantizedConcat",
          "QuantizedConv2D",
          "QuantizedInstanceNorm",
          "QuantizedMatMul",
          "QuantizedMaxPool",
          "QuantizedMul",
          "QuantizedRelu",
          "QuantizedRelu6",
          "QuantizedReshape",
          "QuantizedResizeBilinear",
          "QueueClose",
          "QueueCloseV2",
          "QueueDequeue",
          "QueueDequeueMany",
          "QueueDequeueManyV2",
          "QueueDequeueUpTo",
          "QueueDequeueUpToV2",
          "QueueDequeueV2",
          "QueueEnqueue",
          "QueueEnqueueMany",
          "QueueEnqueueManyV2",
          "QueueEnqueueV2",
          "QueueIsClosed",
          "QueueIsClosedV2",
          "QueueSize",
          "QueueSizeV2",
          "RFFT",
          "RFFT2D",
          "RFFT3D",
          "RaggedBincount",
          "RaggedGather",
          "RaggedRange",
          "RaggedTensorToSparse",
          "RaggedTensorToTensor",
          "RandomGamma",
          "RandomPoisson",
          "RandomPoissonV2",
          "RandomStandardNormal",
          "RandomUniform",
          "RandomUniformInt",
          "Range",
          "Rank",
          "ReadVariableOp",
          "Real",
          "RealDiv",
          "Reciprocal",
          "ReciprocalGrad",
          "Recv",
          "ReduceJoin",
          "RefEnter",
          "RefExit",
          "RefIdentity",
          "RefMerge",
          "RefNextIteration",
          "RefSelect",
          "RefSwitch",
          "RegexFullMatch",
          "RegexReplace",
          "Relu",
          "Relu6",
          "Relu6Grad",
          "ReluGrad",
          "RemoteCall",
          "RequantizationRange",
          "Requantize",
          "Reshape",
          "ResizeBicubic",
          "ResizeBicubicGrad",
          "ResizeBilinear",
          "ResizeBilinearGrad",
          "ResizeNearestNeighbor",
          "ResizeNearestNeighborGrad",
          "ResourceApplyAdaMax",
          "ResourceApplyAdadelta",
          "ResourceApplyAdagrad",
          "ResourceApplyAdagradDA",
          "ResourceApplyAdagradV2",
          "ResourceApplyAdam",
          "ResourceApplyAdamWithAmsgrad",
          "ResourceApplyAddSign",
          "ResourceApplyCenteredRMSProp",
          "ResourceApplyFtrl",
          "ResourceApplyFtrlV2",
          "ResourceApplyGradientDescent",
          "ResourceApplyKerasMomentum",
          "ResourceApplyMomentum",
          "ResourceApplyPowerSign",
          "ResourceApplyProximalAdagrad",
          "ResourceApplyProximalGradientDescent",
          "ResourceApplyRMSProp",
          "ResourceGather",
          "ResourceGatherNd",
          "ResourceScatterAdd",
          "ResourceScatterDiv",
          "ResourceScatterMax",
          "ResourceScatterMin",
          "ResourceScatterMul",
          "ResourceScatterNdAdd",
          "ResourceScatterNdMax",
          "ResourceScatterNdMin",
          "ResourceScatterNdSub",
          "ResourceScatterNdUpdate",
          "ResourceScatterSub",
          "ResourceScatterUpdate",
          "ResourceSparseApplyAdadelta",
          "ResourceSparseApplyAdagrad",
          "ResourceSparseApplyAdagradDA",
          "ResourceSparseApplyAdagradV2",
          "ResourceSparseApplyCenteredRMSProp",
          "ResourceSparseApplyFtrl",
          "ResourceSparseApplyFtrlV2",
          "ResourceSparseApplyKerasMomentum",
          "ResourceSparseApplyMomentum",
          "ResourceSparseApplyProximalAdagrad",
          "ResourceSparseApplyProximalGradientDescent",
          "ResourceSparseApplyRMSProp",
          "ResourceStridedSliceAssign",
          "Restore",
          "RestoreSlice",
          "RestoreV2",
          "Reverse",
          "ReverseSequence",
          "ReverseV2",
          "RightShift",
          "Roll",
          "Round",
          "Rsqrt",
          "RsqrtGrad",
          "SampleDistortedBoundingBox",
          "SampleDistortedBoundingBoxV2",
          "Save",
          "SaveSlices",
          "SaveV2",
          "ScalarSummary",
          "ScatterNd",
          "ScatterNdAdd",
          "ScatterNdMax",
          "ScatterNdMin",
          "ScatterNdNonAliasingAdd",
          "ScatterNdSub",
          "ScatterNdUpdate",
          "SegmentMax",
          "SegmentMean",
          "SegmentMin",
          "SegmentProd",
          "SegmentSum",
          "Select",
          "SelectV2",
          "Selu",
          "SeluGrad",
          "Send",
          "Shape",
          "ShapeN",
          "ShardedFilename",
          "ShardedFilespec",
          "Sigmoid",
          "SigmoidGrad",
          "Sign",
          "Sin",
          "Sinh",
          "Size",
          "Slice",
          "Softmax",
          "SoftmaxCrossEntropyWithLogits",
          "Softplus",
          "SoftplusGrad",
          "Softsign",
          "SoftsignGrad",
          "SpaceToBatch",
          "SpaceToBatchND",
          "SpaceToDepth",
          "SparseApplyAdadelta",
          "SparseApplyAdagrad",
          "SparseApplyAdagradDA",
          "SparseApplyAdagradV2",
          "SparseApplyCenteredRMSProp",
          "SparseApplyFtrl",
          "SparseApplyFtrlV2",
          "SparseApplyMomentum",
          "SparseApplyProximalAdagrad",
          "SparseApplyProximalGradientDescent",
          "SparseApplyRMSProp",
          "SparseBincount",
          "SparseCross",
          "SparseCrossHashed",
          "SparseCrossV2",
          "SparseFillEmptyRows",
          "SparseFillEmptyRowsGrad",
          "SparseReshape",
          "SparseSegmentMean",
          "SparseSegmentMeanGrad",
          "SparseSegmentMeanWithNumSegments",
          "SparseSegmentSqrtN",
          "SparseSegmentSqrtNGrad",
          "SparseSegmentSqrtNWithNumSegments",
          "SparseSegmentSum",
          "SparseSegmentSumWithNumSegments",
          "SparseToDense",
          "SparseToSparseSetOperation",
          "Split",
          "SplitV",
          "Sqrt",
          "SqrtGrad",
          "Square",
          "SquaredDifference",
          "Squeeze",
          "Stack",
          "StackClose",
          "StackCloseV2",
          "StackPop",
          "StackPopV2",
          "StackPush",
          "StackPushV2",
          "StackV2",
          "StatelessMultinomial",
          "StatelessRandomGammaV2",
          "StatelessRandomGetKeyCounterAlg",
          "StatelessRandomNormal",
          "StatelessRandomNormalV2",
          "StatelessRandomPoisson",
          "StatelessRandomUniform",
          "StatelessRandomUniformFullInt",
          "StatelessRandomUniformFullIntV2",
          "StatelessRandomUniformInt",
          "StatelessRandomUniformIntV2",
          "StatelessRandomUniformV2",
          "StatelessSampleDistortedBoundingBox",
          "StatelessTruncatedNormal",
          "StatelessTruncatedNormalV2",
          "StaticRegexFullMatch",
          "StaticRegexReplace",
          "StopGradient",
          "StridedSlice",
          "StridedSliceAssign",
          "StridedSliceGrad",
          "StringFormat",
          "StringJoin",
          "StringLength",
          "StringLower",
          "StringSplit",
          "StringSplitV2",
          "StringStrip",
          "StringToHashBucket",
          "StringToHashBucketFast",
          "StringToHashBucketStrong",
          "StringToNumber",
          "Sub",
          "Substr",
          "Sum",
          "Switch",
          "SymbolicGradient",
          "Tan",
          "Tanh",
          "TanhGrad",
          "TemporaryVariable",
          "TensorArray",
          "TensorArrayClose",
          "TensorArrayCloseV2",
          "TensorArrayCloseV3",
          "TensorArrayConcat",
          "TensorArrayConcatV2",
          "TensorArrayConcatV3",
          "TensorArrayGather",
          "TensorArrayGatherV2",
          "TensorArrayGatherV3",
          "TensorArrayGrad",
          "TensorArrayGradV2",
          "TensorArrayGradV3",
          "TensorArrayGradWithShape",
          "TensorArrayPack",
          "TensorArrayRead",
          "TensorArrayReadV2",
          "TensorArrayReadV3",
          "TensorArrayScatter",
          "TensorArrayScatterV2",
          "TensorArrayScatterV3",
          "TensorArraySize",
          "TensorArraySizeV2",
          "TensorArraySizeV3",
          "TensorArraySplit",
          "TensorArraySplitV2",
          "TensorArraySplitV3",
          "TensorArrayUnpack",
          "TensorArrayV2",
          "TensorArrayV3",
          "TensorArrayWrite",
          "TensorArrayWriteV2",
          "TensorArrayWriteV3",
          "TensorListConcat",
          "TensorListConcatLists",
          "TensorListConcatV2",
          "TensorListElementShape",
          "TensorListFromTensor",
          "TensorListGather",
          "TensorListGetItem",
          "TensorListLength",
          "TensorListPopBack",
          "TensorListPushBack",
          "TensorListPushBackBatch",
          "TensorListReserve",
          "TensorListResize",
          "TensorListScatter",
          "TensorListScatterIntoExistingList",
          "TensorListScatterV2",
          "TensorListSetItem",
          "TensorListSplit",
          "TensorListStack",
          "TensorMapErase",
          "TensorMapHasKey",
          "TensorMapInsert",
          "TensorMapLookup",
          "TensorMapSize",
          "TensorMapStackKeys",
          "TensorScatterAdd",
          "TensorScatterMax",
          "TensorScatterMin",
          "TensorScatterSub",
          "TensorScatterUpdate",
          "TensorStridedSliceUpdate",
          "Tile",
          "TileGrad",
          "Timestamp",
          "TopK",
          "TopKV2",
          "Transpose",
          "TruncateDiv",
          "TruncatedNormal",
          "UnicodeDecode",
          "UnicodeDecodeWithOffsets",
          "UnicodeEncode",
          "UnicodeTranscode",
          "Unique",
          "UniqueV2",
          "UniqueWithCounts",
          "UniqueWithCountsV2",
          "Unpack",
          "UnsortedSegmentMax",
          "UnsortedSegmentMin",
          "UnsortedSegmentProd",
          "UnsortedSegmentSum",
          "UnwrapDatasetVariant",
          "VarHandleOp",
          "VarIsInitializedOp",
          "Variable",
          "VariableShape",
          "VariableV2",
          "Where",
          "WrapDatasetVariant",
          "Xdivy",
          "Xlog1py",
          "Xlogy",
          "ZerosLike",
          "_Arg",
          "_ArrayToList",
          "_DeviceArg",
          "_DeviceRetval",
          "_FusedConv2D",
          "_HostCast",
          "_HostRecv",
          "_HostSend",
          "_ListToArray",
          "_ParallelConcatStart",
          "_ParallelConcatUpdate",
          "_ReadVariablesOp",
          "_Recv",
          "_Retval",
          "_Send",
          "_SwitchN",
          "_VarHandlesOp",
          // go/keep-sorted end
      };
  // LINT.ThenChange(//tensorflow/lite/g3doc/guide/op_select_allowlist.md)

  return flex_op_names;
  // Suppress the lint error about this function being too long: the body is
  // just a list of op names, and splitting it up would not help readability.
  // The directive below is placed so that it covers the closing brace.
  // NOLINTNEXTLINE
}
725
// Returns the set of op names held in the tf.text flex allowlist below. The
// set is allocated with `new` the first time the function runs and is not
// deleted here; every call returns a reference to that same object.
const std::set<std::string>& GetTFTextFlexAllowlist() {
  // LINT.IfChange
  static const std::set<std::string>& text_op_names =
      *new std::set<std::string>{
          "CaseFoldUTF8",
          "ConstrainedSequence",
          "MaxSpanningTree",
          "NormalizeUTF8",
          "NormalizeUTF8WithOffsetsMap",
          "RegexSplitWithOffsets",
          "RougeL",
          "SentenceFragments",
          "SentencepieceDetokenizeOp",
          "SentencepieceOp",
          "SentencepieceTokenizeOp",
          "SentencepieceTokenizeWithOffsetsOp",
          "SentencepieceVocabSizeOp",
          "SplitMergeTokenizeWithOffsets",
          "UnicodeScriptTokenizeWithOffsets",
          "WhitespaceTokenizeWithOffsets",
          "WordpieceTokenizeWithOffsets",
      };
  // LINT.ThenChange(//tensorflow/lite/g3doc/guide/op_select_allowlist.md)

  return text_op_names;
}
752
753 // Allow the tf.text ops if they are registered in the global op registry.
IsAllowedTFTextOpForFlex(const std::string & op_name)754 bool IsAllowedTFTextOpForFlex(const std::string& op_name) {
755 if (GetTFTextFlexAllowlist().count(op_name) == 0) return false;
756 return tensorflow::OpRegistry::Global()->LookUp(op_name) != nullptr;
757 }
758
// Returns the set of op names held in the sentencepiece flex allowlist below.
// The set is allocated with `new` the first time the function runs and is not
// deleted here; every call returns a reference to that same object.
const std::set<std::string>& GetSentencePieceFlexAllowlist() {
  // LINT.IfChange
  static const std::set<std::string>& piece_op_names =
      *new std::set<std::string>{
          "SentencepieceDecode",
          "SentencepieceEncodeDense",
          "SentencepieceEncodeSparse",
          "SentencepieceGetPieceSize",
          "SentencepieceIdToPiece",
          "SentencepiecePieceToId",
      };
  // LINT.ThenChange(//tensorflow/lite/g3doc/guide/op_select_allowlist.md)

  return piece_op_names;
}
774
775 // Allow the sentencepiece ops if they are registered in the global op registry.
IsAllowedSentencePieceOpForFlex(const std::string & op_name)776 bool IsAllowedSentencePieceOpForFlex(const std::string& op_name) {
777 if (GetSentencePieceFlexAllowlist().count(op_name) == 0) return false;
778 return tensorflow::OpRegistry::Global()->LookUp(op_name) != nullptr;
779 }
780
IsAllowlistedFlexOp(const std::string & tensorflow_op_name)781 bool IsAllowlistedFlexOp(const std::string& tensorflow_op_name) {
782 if (GetFlexAllowlist().count(tensorflow_op_name) != 0) return true;
783
784 // Check if the op is an allowlisted tf.text or sentencepiece op.
785 return IsAllowedTFTextOpForFlex(tensorflow_op_name) ||
786 IsAllowedSentencePieceOpForFlex(tensorflow_op_name);
787 }
788
789 } // namespace flex
790 } // namespace tflite
791