/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/tools/list_flex_ops.h"

#include <fstream>
#include <sstream>
#include <string>
#include <vector>

#include "flatbuffers/flexbuffers.h"  // from @flatbuffers
#include "json/json.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/lite/schema/schema_utils.h"
#include "tensorflow/lite/util.h"

namespace tflite {
namespace flex {

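// Serializes the given set of flex ops as a JSON array, where each entry is
// an [op_name, kernel_name] pair.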
std::string OpListToJSONString(const OpKernelSet& flex_ops) {
  Json::Value result(Json::arrayValue);
  for (const OpKernel& op : flex_ops) {
    Json::Value op_kernel(Json::arrayValue);
    op_kernel.append(Json::Value(op.op_name));
    op_kernel.append(Json::Value(op.kernel_name));
    result.append(op_kernel);
  }
  return Json::FastWriter().write(result);
}

// Finds the class name of the op kernel described in the node_def from the
// pool of registered ops. Fails with LOG(FATAL) if the op or its kernel class
// cannot be found.
string FindTensorflowKernelClass(tensorflow::NodeDef* node_def) {
  if (!node_def || node_def->op().empty()) {
    LOG(FATAL) << "Invalid NodeDef";
  }

  const tensorflow::OpRegistrationData* op_reg_data;
  auto status =
      tensorflow::OpRegistry::Global()->LookUp(node_def->op(), &op_reg_data);
  if (!status.ok()) {
    LOG(FATAL) << "Op " << node_def->op() << " not found: " << status;
  }
  AddDefaultsToNodeDef(op_reg_data->op_def, node_def);

  tensorflow::DeviceNameUtils::ParsedName parsed_name;
  if (!tensorflow::DeviceNameUtils::ParseFullName(node_def->device(),
                                                  &parsed_name)) {
    LOG(FATAL) << "Failed to parse device from node_def: "
               << node_def->ShortDebugString();
  }
  string class_name;
  if (!tensorflow::FindKernelDef(
           tensorflow::DeviceType(parsed_name.type.c_str()), *node_def,
           nullptr /* kernel_def */, &class_name)
           .ok()) {
    LOG(FATAL) << "Failed to find kernel class for op: " << node_def->op();
  }
  return class_name;
}

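// Walks every subgraph of the model, and for each flex (Select TF) op found,
// records its [op_name, kernel_class] pair in `flex_ops`.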
void AddFlexOpsFromModel(const tflite::Model* model, OpKernelSet* flex_ops) {
  // Read flex ops.
  auto* subgraphs = model->subgraphs();
  if (!subgraphs) return;
  for (int subgraph_index = 0; subgraph_index < subgraphs->size();
       ++subgraph_index) {
    const tflite::SubGraph* subgraph = subgraphs->Get(subgraph_index);
    auto* operators = subgraph->operators();
    auto* opcodes = model->operator_codes();
    if (!operators || !opcodes) continue;
    for (int i = 0; i < operators->size(); ++i) {
      const tflite::Operator* op = operators->Get(i);
      const tflite::OperatorCode* opcode = opcodes->Get(op->opcode_index());
      if (tflite::GetBuiltinCode(opcode) != tflite::BuiltinOperator_CUSTOM ||
          !tflite::IsFlexOp(opcode->custom_code()->c_str())) {
        continue;
      }

      // Remove the "Flex" prefix from op name.
      std::string flex_op_name(opcode->custom_code()->c_str());
      std::string tf_op_name =
          flex_op_name.substr(strlen(tflite::kFlexCustomCodePrefix));

      // Read NodeDef and find the op kernel class.
      if (op->custom_options_format() !=
          tflite::CustomOptionsFormat_FLEXBUFFERS) {
        LOG(FATAL) << "Invalid CustomOptionsFormat";
      }
      const flatbuffers::Vector<uint8_t>* custom_opt_bytes =
          op->custom_options();
      if (custom_opt_bytes && custom_opt_bytes->size()) {
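        // The custom options of a flex op hold a flexbuffer vector; the
        // serialized NodeDef is stored at index 1.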
        // NOLINTNEXTLINE: It is common to use references with flatbuffer.
        const flexbuffers::Vector& v =
            flexbuffers::GetRoot(custom_opt_bytes->data(),
                                 custom_opt_bytes->size())
                .AsVector();
        std::string nodedef_str = v[1].AsString().str();
        tensorflow::NodeDef nodedef;
        if (nodedef_str.empty() || !nodedef.ParseFromString(nodedef_str)) {
          LOG(FATAL) << "Failed to parse data into a valid NodeDef";
        }
        // Flex delegate only supports running flex ops with CPU.
        *nodedef.mutable_device() = "/CPU:0";
        std::string kernel_class = FindTensorflowKernelClass(&nodedef);
        flex_ops->insert({tf_op_name, kernel_class});
      }
    }
  }
}
}  // namespace flex
}  // namespace tflite
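
// Example usage (illustrative sketch only, not part of this file): assuming
// tflite::FlatBufferModel from "tensorflow/lite/model_builder.h" and a model
// path supplied by the caller:
//
//   tflite::flex::OpKernelSet flex_ops;
//   auto fb_model =
//       tflite::FlatBufferModel::BuildFromFile("/path/to/model.tflite");
//   if (fb_model != nullptr) {
//     tflite::flex::AddFlexOpsFromModel(fb_model->GetModel(), &flex_ops);
//     std::cout << tflite::flex::OpListToJSONString(flex_ops);
//   }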