1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_TOCO_TFLITE_EXPORT_H_
16 #define TENSORFLOW_LITE_TOCO_TFLITE_EXPORT_H_
17
#include <map>
#include <memory>
#include <string>
#include <tuple>
#include <unordered_map>

#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tflite/operator.h"
#include "tensorflow/lite/util.h"
21
22 namespace toco {
23
24 namespace tflite {
25
// Selects whether and how weights are quantized on export (see
// ExportParams::quantize_weights below).
enum class QuantizedBufferType {
  // No weight quantization. This is the default, and what the legacy
  // `quantize_weights = false` Export() overloads request.
  NONE = 0,
  // 8-bit integer weight quantization; what the legacy
  // `quantize_weights = true` Export() overloads request.
  INT8 = 1,
  // 16-bit float weights (inferred from the name — TODO confirm against the
  // exporter implementation).
  FLOAT16 = 2,
};

// Options that control how Export() turns a toco model into a TF Lite
// flatbuffer. The member initializers below define the defaults.
struct ExportParams {
  // Whether ops lacking a builtin TF Lite equivalent may be emitted as custom
  // ops (meaning inferred from the name; the legacy Export() overloads fill it
  // from their `allow_custom_ops` argument).
  bool allow_custom_ops{false};
  // Whether dynamically shaped tensors are permitted (inferred from the name).
  bool allow_dynamic_tensors{true};
  // Whether TensorFlow ops may be exported as "select TF" ops — presumably
  // the Flex ops that OperatorKey tracks; TODO confirm.
  bool enable_select_tf_ops{false};
  // Weight quantization mode applied to the exported buffers.
  QuantizedBufferType quantize_weights{QuantizedBufferType::NONE};
};
35
36 // Transform the given tf.mini model into a TF Lite flatbuffer and deposit the
37 // result in the given string.
38 tensorflow::Status Export(const Model& model, std::string* output_file_contents,
39 const ExportParams& params);
40
41 // Export API with custom TFLite operator mapping.
42 tensorflow::Status Export(
43 const Model& model, std::string* output_file_contents,
44 const ExportParams& params,
45 const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type);
46
47 // This is for backward-compatibility.
48 // TODO(ycling): Remove the deprecated entry functions.
Export(const Model & model,bool allow_custom_ops,bool quantize_weights,std::string * output_file_contents)49 inline void Export(const Model& model, bool allow_custom_ops,
50 bool quantize_weights, std::string* output_file_contents) {
51 ExportParams params;
52 params.allow_custom_ops = allow_custom_ops;
53 params.quantize_weights =
54 quantize_weights ? QuantizedBufferType::INT8 : QuantizedBufferType::NONE;
55 auto status = Export(model, output_file_contents, params);
56 if (!status.ok()) LOG(QFATAL) << status.error_message();
57 }
58
59 // This is for backward-compatibility.
60 // TODO(ycling): Remove the deprecated entry functions.
Export(const Model & model,bool allow_custom_ops,bool quantize_weights,std::string * output_file_contents,const std::map<OperatorType,std::unique_ptr<BaseOperator>> & ops_by_type)61 inline void Export(
62 const Model& model, bool allow_custom_ops, bool quantize_weights,
63 std::string* output_file_contents,
64 const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) {
65 ExportParams params;
66 params.allow_custom_ops = allow_custom_ops;
67 params.quantize_weights =
68 quantize_weights ? QuantizedBufferType::INT8 : QuantizedBufferType::NONE;
69 auto status = Export(model, output_file_contents, params, ops_by_type);
70 if (!status.ok()) LOG(QFATAL) << status.error_message();
71 }
72
73 // This is for backward-compatibility.
74 // TODO(ycling): Remove the deprecated entry functions.
Export(const Model & model,std::string * output_file_contents)75 inline void Export(const Model& model, std::string* output_file_contents) {
76 ExportParams params;
77 params.allow_custom_ops = true;
78 auto status = Export(model, output_file_contents, params);
79 if (!status.ok()) LOG(QFATAL) << status.error_message();
80 }
81
82 namespace details {
83
// Associates each tensor name with the final position that tensor occupies
// in the exported TF Lite buffer.
using TensorsMap =
    std::unordered_map<std::string /* tensor name */, int /* position */>;
86
87 // A key to identify an operator.
88 // Only when `type` is `kUnsupported`, `custom_code` is filled to
89 // identify which operation is used.
90 class OperatorKey {
91 public:
OperatorKey()92 OperatorKey() {}
93
94 // Construct OperatorKey by Toco op.
95 OperatorKey(
96 const ::toco::OperatorSignature& op_signature,
97 const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
98 bool enable_select_tf_ops);
99
100 // Construct OperatorKey by type, custom code and version.
101 // Note that this construct doesn't set the additional information including
102 // `is_custom_op`, `is_flex_op`, `is_unsupported_flex_op`.
OperatorKey(::tflite::BuiltinOperator type,const std::string & custom_code,int version)103 OperatorKey(::tflite::BuiltinOperator type, const std::string& custom_code,
104 int version)
105 : type_(type), custom_code_(custom_code), version_(version) {}
106
107 // Only `type`, `custom_code` and `version` is used to compute hash and
108 // identity.
type()109 ::tflite::BuiltinOperator type() const { return type_; }
custom_code()110 const std::string& custom_code() const { return custom_code_; }
version()111 int version() const { return version_; }
112
113 // The attributes below are not used to compute hash and identity.
114 //
115 // Return true if the op is a custom op. Note it will return false for Flex
116 // ops.
is_custom_op()117 bool is_custom_op() const { return is_custom_op_; }
118 // Return true if the op is a Flex op.
is_flex_op()119 bool is_flex_op() const { return is_flex_op_; }
120 // Return true if the op is a Flex op but it's knwon that the op is not
121 // supported by Flex runtime.
is_unsupported_flex_op()122 bool is_unsupported_flex_op() const { return is_unsupported_flex_op_; }
123 // Return the original TensorFlow op name for a Flex op.
flex_tensorflow_op()124 const std::string& flex_tensorflow_op() const { return flex_tensorflow_op_; }
125
126 bool operator<(const OperatorKey& other) const {
127 if (type_ < other.type_)
128 return true;
129 else if (type_ > other.type_)
130 return false;
131 else if (custom_code_ < other.custom_code_)
132 return true;
133 else if (custom_code_ > other.custom_code_)
134 return false;
135 else
136 return version_ < other.version_;
137 }
138
139 bool operator==(const OperatorKey& other) const {
140 return type_ == other.type_ && custom_code_ == other.custom_code_ &&
141 version_ == other.version_;
142 }
143
144 struct Hash {
operatorHash145 size_t operator()(const OperatorKey& key) const {
146 return ::tflite::CombineHashes(
147 {std::hash<size_t>()(static_cast<size_t>(key.type())),
148 std::hash<std::string>()(key.custom_code()),
149 std::hash<int>()(key.version())});
150 }
151 };
152
153 private:
154 ::tflite::BuiltinOperator type_ = ::tflite::BuiltinOperator_CUSTOM;
155 std::string custom_code_;
156 int version_ = 1;
157
158 bool is_custom_op_ = false;
159 bool is_flex_op_ = false;
160 bool is_unsupported_flex_op_ = false;
161 // The original TensorFlow op name for the flex op. Filled only when
162 // `is_flex_op` is true.
163 std::string flex_tensorflow_op_;
164 };
165
166 // A map from OperatorKey to its final position in the TF Lite buffer.
167 using OperatorsMap = std::unordered_map<OperatorKey, int, OperatorKey::Hash>;
168
169 void LoadTensorsMap(const Model& model, TensorsMap* tensors_map);
170 void LoadOperatorsMap(
171 const Model& model, OperatorsMap* operators_map,
172 const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
173 bool enable_select_tf_ops);
174
175 } // namespace details
176 } // namespace tflite
177 } // namespace toco
178
179 #endif // TENSORFLOW_LITE_TOCO_TFLITE_EXPORT_H_
180