1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/lite/tools/list_flex_ops.h"
16
17 #include <fstream>
18 #include <sstream>
19 #include <string>
20 #include <vector>
21
22 #include "flatbuffers/flexbuffers.h" // from @flatbuffers
23 #include "json/json.h"
24 #include "tensorflow/core/framework/node_def.pb.h"
25 #include "tensorflow/core/framework/node_def_util.h"
26 #include "tensorflow/core/framework/op.h"
27 #include "tensorflow/core/framework/op_kernel.h"
28 #include "tensorflow/core/platform/logging.h"
29 #include "tensorflow/core/util/device_name_utils.h"
30 #include "tensorflow/lite/schema/schema_utils.h"
31 #include "tensorflow/lite/util.h"
32
33 namespace tflite {
34 namespace flex {
35
OpListToJSONString(const OpKernelSet & flex_ops)36 std::string OpListToJSONString(const OpKernelSet& flex_ops) {
37 Json::Value result(Json::arrayValue);
38 for (const OpKernel& op : flex_ops) {
39 Json::Value op_kernel(Json::arrayValue);
40 op_kernel.append(Json::Value(op.op_name));
41 op_kernel.append(Json::Value(op.kernel_name));
42 result.append(op_kernel);
43 }
44 return Json::FastWriter().write(result);
45 }
46
47 // Find the class name of the op kernel described in the node_def from the pool
48 // of registered ops. If no kernel class is found, return an empty string.
FindTensorflowKernelClass(tensorflow::NodeDef * node_def)49 string FindTensorflowKernelClass(tensorflow::NodeDef* node_def) {
50 if (!node_def || node_def->op().empty()) {
51 LOG(FATAL) << "Invalid NodeDef";
52 }
53
54 const tensorflow::OpRegistrationData* op_reg_data;
55 auto status =
56 tensorflow::OpRegistry::Global()->LookUp(node_def->op(), &op_reg_data);
57 if (!status.ok()) {
58 LOG(FATAL) << "Op " << node_def->op() << " not found: " << status;
59 }
60 AddDefaultsToNodeDef(op_reg_data->op_def, node_def);
61
62 tensorflow::DeviceNameUtils::ParsedName parsed_name;
63 if (!tensorflow::DeviceNameUtils::ParseFullName(node_def->device(),
64 &parsed_name)) {
65 LOG(FATAL) << "Failed to parse device from node_def: "
66 << node_def->ShortDebugString();
67 }
68 string class_name;
69 if (!tensorflow::FindKernelDef(
70 tensorflow::DeviceType(parsed_name.type.c_str()), *node_def,
71 nullptr /* kernel_def */, &class_name)
72 .ok()) {
73 LOG(FATAL) << "Failed to find kernel class for op: " << node_def->op();
74 }
75 return class_name;
76 }
77
AddFlexOpsFromModel(const tflite::Model * model,OpKernelSet * flex_ops)78 void AddFlexOpsFromModel(const tflite::Model* model, OpKernelSet* flex_ops) {
79 // Read flex ops.
80 auto* subgraphs = model->subgraphs();
81 if (!subgraphs) return;
82 for (int subgraph_index = 0; subgraph_index < subgraphs->size();
83 ++subgraph_index) {
84 const tflite::SubGraph* subgraph = subgraphs->Get(subgraph_index);
85 auto* operators = subgraph->operators();
86 auto* opcodes = model->operator_codes();
87 if (!operators || !opcodes) continue;
88 for (int i = 0; i < operators->size(); ++i) {
89 const tflite::Operator* op = operators->Get(i);
90 const tflite::OperatorCode* opcode = opcodes->Get(op->opcode_index());
91 if (tflite::GetBuiltinCode(opcode) != tflite::BuiltinOperator_CUSTOM ||
92 !tflite::IsFlexOp(opcode->custom_code()->c_str())) {
93 continue;
94 }
95
96 // Remove the "Flex" prefix from op name.
97 std::string flex_op_name(opcode->custom_code()->c_str());
98 std::string tf_op_name =
99 flex_op_name.substr(strlen(tflite::kFlexCustomCodePrefix));
100
101 // Read NodeDef and find the op kernel class.
102 if (op->custom_options_format() !=
103 tflite::CustomOptionsFormat_FLEXBUFFERS) {
104 LOG(FATAL) << "Invalid CustomOptionsFormat";
105 }
106 const flatbuffers::Vector<uint8_t>* custom_opt_bytes =
107 op->custom_options();
108 if (custom_opt_bytes && custom_opt_bytes->size()) {
109 // NOLINTNEXTLINE: It is common to use references with flatbuffer.
110 const flexbuffers::Vector& v =
111 flexbuffers::GetRoot(custom_opt_bytes->data(),
112 custom_opt_bytes->size())
113 .AsVector();
114 std::string nodedef_str = v[1].AsString().str();
115 tensorflow::NodeDef nodedef;
116 if (nodedef_str.empty() || !nodedef.ParseFromString(nodedef_str)) {
117 LOG(FATAL) << "Failed to parse data into a valid NodeDef";
118 }
119 // Flex delegate only supports running flex ops with CPU.
120 *nodedef.mutable_device() = "/CPU:0";
121 std::string kernel_class = FindTensorflowKernelClass(&nodedef);
122 flex_ops->insert({tf_op_name, kernel_class});
123 }
124 }
125 }
126 }
127 } // namespace flex
128 } // namespace tflite
129