• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_LITE_TOCO_TFLITE_OPERATOR_H_
16 #define TENSORFLOW_LITE_TOCO_TFLITE_OPERATOR_H_
17 
#include <cstdint>
#include <map>
#include <memory>
#include <string>
#include <vector>

#include "flatbuffers/flatbuffers.h"
#include "flatbuffers/flexbuffers.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/toco/model.h"
#include "tensorflow/lite/tools/versioning/op_version.h"
23 
24 namespace toco {
25 
26 namespace tflite {
27 
28 class BaseOperator;
29 
30 // Return a map contained all know TF Lite Operators, keyed by their names.
31 // TODO(ycling): The pattern to propagate parameters (e.g. enable_select_tf_ops)
32 // is ugly here. Consider refactoring.
33 std::map<string, std::unique_ptr<BaseOperator>> BuildOperatorByNameMap(
34     bool enable_select_tf_ops = false);
35 
36 // Return a map contained all know TF Lite Operators, keyed by the type of
37 // their tf.mini counterparts.
38 std::map<OperatorType, std::unique_ptr<BaseOperator>> BuildOperatorByTypeMap(
39     bool enable_select_tf_ops = false);
40 
41 // Write the custom option FlexBuffer with a serialized TensorFlow NodeDef
42 // for a Flex op.
43 std::unique_ptr<flexbuffers::Builder> WriteFlexOpOptions(
44     const string& tensorflow_node_def);
45 
46 // These are the flatbuffer types for custom and builtin options.
47 using CustomOptions = flatbuffers::Vector<uint8_t>;
48 using BuiltinOptions = void;
49 
50 // A simple wrapper around the flatbuffer objects used to describe options that
51 // configure operators.
52 struct Options {
53   // Build custom options.
CustomOptions54   static Options Custom(flatbuffers::Offset<CustomOptions> offset) {
55     return {::tflite::BuiltinOptions_NONE, 0, offset};
56   }
57 
58   // Build builtin options of the given type.
BuiltinOptions59   static Options Builtin(::tflite::BuiltinOptions type,
60                          flatbuffers::Offset<BuiltinOptions> offset) {
61     return {type, offset, 0};
62   }
63 
64   ::tflite::BuiltinOptions type;
65   flatbuffers::Offset<BuiltinOptions> builtin;
66   flatbuffers::Offset<CustomOptions> custom;
67 };
68 
69 // A BaseOperator encapsulates the relationship between operators in tf.mini
70 // and TF lite, and provides methods for converting between those two formats.
71 class BaseOperator {
72  public:
73   // Build an operator with the given TF Lite name and tf.mini type.
BaseOperator(const string & name,OperatorType type)74   BaseOperator(const string& name, OperatorType type)
75       : name_(name), type_(type) {}
76   virtual ~BaseOperator() = default;
77 
name()78   string name() const { return name_; }
type()79   OperatorType type() const { return type_; }
80 
81   // Given a tf.mini operator, create the corresponding flatbuffer options and
82   // return their offsets.
83   virtual Options Serialize(const Operator& op,
84                             flatbuffers::FlatBufferBuilder* builder) const = 0;
85 
86   // Read TF Lite options and create the appropriate tf.mini operator.
87   virtual std::unique_ptr<Operator> Deserialize(
88       const BuiltinOptions* builtin_options,
89       const CustomOptions* custom_options) const = 0;
90 
91   // Get the op version using the OperatorSignature.
92   // The function needs to be overridden to return the op version based on the
93   // parameters. Note:
94   // * The first version for each op should be 1 (to be consistent with the
95   //   default value in Flatbuffer. `return 1;` is okay for newly implemented
96   //   ops.
97   // * When multiple versions are defined for an op, this function could be
98   //   overridden. (See example in `operator_test.cc` and
99   //   'tools/versioning/op_version.cc`)
100   virtual int GetVersion(const OperatorSignature& op_signature) const = 0;
101 
102   // Given a Toco `Operator`, return a list of booleans indicating the op
103   // mutates which input variables.
104   // * If the op mutates any input variables, it should return a list of bool
105   //   with the same length as inputs.
106   // * Otherwise, it will return an empty list.
GetMutatingInputVariables(const Operator & op)107   virtual std::vector<bool> GetMutatingInputVariables(
108       const Operator& op) const {
109     // Most ops don't have variable tensors. This function can be overridden.
110     return std::vector<bool>();
111   }
112 
113  private:
114   string name_;
115   OperatorType type_;
116 };
117 
118 // Helper function to create ::tflite::OpSignature from the given
119 // ::tflite::BuiltinOperator and OperatorSignature.
120 ::tflite::OpSignature GetVersioningOpSig(const ::tflite::BuiltinOperator op,
121                                          const OperatorSignature& op_signature);
122 
123 // Helper function to determine if a unsupported TensorFlow op should be
124 // exported as an Flex op or a regular custom op.
125 bool ShouldExportAsFlexOp(bool enable_select_tf_ops,
126                           const string& tensorflow_op_name);
127 
128 }  // namespace tflite
129 
130 }  // namespace toco
131 
132 #endif  // TENSORFLOW_LITE_TOCO_TFLITE_OPERATOR_H_
133