• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_TOCO_TFLITE_EXPORT_H_
16 #define TENSORFLOW_LITE_TOCO_TFLITE_EXPORT_H_
17 
#include <cstddef>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <tuple>
#include <unordered_map>

#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/toco/tflite/operator.h"
#include "tensorflow/lite/util.h"
21 
22 namespace toco {
23 
24 namespace tflite {
25 
// Selects whether, and how, weight buffers are quantized during export.
enum class QuantizedBufferType {
  NONE,     // No weight quantization (the ExportParams default).
  INT8,     // Chosen by the legacy `quantize_weights == true` entry points.
  FLOAT16,  // NOTE(review): presumably float16 weight quantization; confirm.
};

// Options that control how a toco `Model` is turned into a TF Lite flatbuffer.
// Each field has an in-class default, so a default-constructed ExportParams is
// a complete configuration.
struct ExportParams {
  // Whether custom ops may be emitted. Off by default; the legacy
  // `Export(model, output_file_contents)` overload switches it on.
  bool allow_custom_ops{false};
  // NOTE(review): judging by the name, permits tensors whose shape is only
  // known at runtime; confirm against the exporter implementation.
  bool allow_dynamic_tensors{true};
  // NOTE(review): presumably enables emitting "select TF" (Flex) ops; it is
  // forwarded to OperatorKey/LoadOperatorsMap, which classify Flex ops.
  bool enable_select_tf_ops{false};
  // Weight quantization mode; NONE leaves weights untouched.
  QuantizedBufferType quantize_weights{QuantizedBufferType::NONE};
};
35 
36 // Transform the given tf.mini model into a TF Lite flatbuffer and deposit the
37 // result in the given string.
38 tensorflow::Status Export(const Model& model, std::string* output_file_contents,
39                           const ExportParams& params);
40 
41 // Export API with custom TFLite operator mapping.
42 tensorflow::Status Export(
43     const Model& model, std::string* output_file_contents,
44     const ExportParams& params,
45     const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type);
46 
47 // This is for backward-compatibility.
48 // TODO(ycling): Remove the deprecated entry functions.
Export(const Model & model,bool allow_custom_ops,bool quantize_weights,std::string * output_file_contents)49 inline void Export(const Model& model, bool allow_custom_ops,
50                    bool quantize_weights, std::string* output_file_contents) {
51   ExportParams params;
52   params.allow_custom_ops = allow_custom_ops;
53   params.quantize_weights =
54       quantize_weights ? QuantizedBufferType::INT8 : QuantizedBufferType::NONE;
55   auto status = Export(model, output_file_contents, params);
56   if (!status.ok()) LOG(QFATAL) << status.error_message();
57 }
58 
59 // This is for backward-compatibility.
60 // TODO(ycling): Remove the deprecated entry functions.
Export(const Model & model,bool allow_custom_ops,bool quantize_weights,std::string * output_file_contents,const std::map<OperatorType,std::unique_ptr<BaseOperator>> & ops_by_type)61 inline void Export(
62     const Model& model, bool allow_custom_ops, bool quantize_weights,
63     std::string* output_file_contents,
64     const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type) {
65   ExportParams params;
66   params.allow_custom_ops = allow_custom_ops;
67   params.quantize_weights =
68       quantize_weights ? QuantizedBufferType::INT8 : QuantizedBufferType::NONE;
69   auto status = Export(model, output_file_contents, params, ops_by_type);
70   if (!status.ok()) LOG(QFATAL) << status.error_message();
71 }
72 
73 // This is for backward-compatibility.
74 // TODO(ycling): Remove the deprecated entry functions.
Export(const Model & model,std::string * output_file_contents)75 inline void Export(const Model& model, std::string* output_file_contents) {
76   ExportParams params;
77   params.allow_custom_ops = true;
78   auto status = Export(model, output_file_contents, params);
79   if (!status.ok()) LOG(QFATAL) << status.error_message();
80 }
81 
82 namespace details {
83 
// Lookup table: tensor name -> the final index of that tensor in the exported
// TF Lite buffer.
using TensorsMap = std::unordered_map<std::string, int>;
86 
87 // A key to identify an operator.
88 // Only when `type` is `kUnsupported`, `custom_code` is filled to
89 // identify which operation is used.
90 class OperatorKey {
91  public:
OperatorKey()92   OperatorKey() {}
93 
94   // Construct OperatorKey by Toco op.
95   OperatorKey(
96       const ::toco::OperatorSignature& op_signature,
97       const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
98       bool enable_select_tf_ops);
99 
100   // Construct OperatorKey by type, custom code and version.
101   // Note that this construct doesn't set the additional information including
102   // `is_custom_op`, `is_flex_op`, `is_unsupported_flex_op`.
OperatorKey(::tflite::BuiltinOperator type,const std::string & custom_code,int version)103   OperatorKey(::tflite::BuiltinOperator type, const std::string& custom_code,
104               int version)
105       : type_(type), custom_code_(custom_code), version_(version) {}
106 
107   // Only `type`, `custom_code` and `version` is used to compute hash and
108   // identity.
type()109   ::tflite::BuiltinOperator type() const { return type_; }
custom_code()110   const std::string& custom_code() const { return custom_code_; }
version()111   int version() const { return version_; }
112 
113   // The attributes below are not used to compute hash and identity.
114   //
115   // Return true if the op is a custom op. Note it will return false for Flex
116   // ops.
is_custom_op()117   bool is_custom_op() const { return is_custom_op_; }
118   // Return true if the op is a Flex op.
is_flex_op()119   bool is_flex_op() const { return is_flex_op_; }
120   // Return true if the op is a Flex op but it's knwon that the op is not
121   // supported by Flex runtime.
is_unsupported_flex_op()122   bool is_unsupported_flex_op() const { return is_unsupported_flex_op_; }
123   // Return the original TensorFlow op name for a Flex op.
flex_tensorflow_op()124   const std::string& flex_tensorflow_op() const { return flex_tensorflow_op_; }
125 
126   bool operator<(const OperatorKey& other) const {
127     if (type_ < other.type_)
128       return true;
129     else if (type_ > other.type_)
130       return false;
131     else if (custom_code_ < other.custom_code_)
132       return true;
133     else if (custom_code_ > other.custom_code_)
134       return false;
135     else
136       return version_ < other.version_;
137   }
138 
139   bool operator==(const OperatorKey& other) const {
140     return type_ == other.type_ && custom_code_ == other.custom_code_ &&
141            version_ == other.version_;
142   }
143 
144   struct Hash {
operatorHash145     size_t operator()(const OperatorKey& key) const {
146       return ::tflite::CombineHashes(
147           {std::hash<size_t>()(static_cast<size_t>(key.type())),
148            std::hash<std::string>()(key.custom_code()),
149            std::hash<int>()(key.version())});
150     }
151   };
152 
153  private:
154   ::tflite::BuiltinOperator type_ = ::tflite::BuiltinOperator_CUSTOM;
155   std::string custom_code_;
156   int version_ = 1;
157 
158   bool is_custom_op_ = false;
159   bool is_flex_op_ = false;
160   bool is_unsupported_flex_op_ = false;
161   // The original TensorFlow op name for the flex op. Filled only when
162   // `is_flex_op` is true.
163   std::string flex_tensorflow_op_;
164 };
165 
166 // A map from OperatorKey to its final position in the TF Lite buffer.
167 using OperatorsMap = std::unordered_map<OperatorKey, int, OperatorKey::Hash>;
168 
169 void LoadTensorsMap(const Model& model, TensorsMap* tensors_map);
170 void LoadOperatorsMap(
171     const Model& model, OperatorsMap* operators_map,
172     const std::map<OperatorType, std::unique_ptr<BaseOperator>>& ops_by_type,
173     bool enable_select_tf_ops);
174 
175 }  // namespace details
176 }  // namespace tflite
177 }  // namespace toco
178 
179 #endif  // TENSORFLOW_LITE_TOCO_TFLITE_EXPORT_H_
180