1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_OP_BUILDER_H_ 16 #define TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_OP_BUILDER_H_ 17 18 #include <string> 19 20 #include "mlmodel/format/Model.pb.h" 21 #include "mlmodel/format/NeuralNetwork.pb.h" 22 #include "tensorflow/lite/c/common.h" 23 24 namespace tflite { 25 namespace delegates { 26 namespace coreml { 27 class OpBuilder; 28 29 // A class represents an ID in the coreML graph. 30 // A node is represented by a pair (node_id, and output_index) 31 // API is experimental and subject to change. 32 class TensorID { 33 public: TensorID()34 TensorID() {} TensorID(int node,int output_id)35 TensorID(int node, int output_id) : node_(node), output_id_(output_id) {} 36 37 std::string ToString() const; 38 39 int NodeID() const; 40 41 int OutputID() const; 42 43 private: 44 int node_ = -1; 45 int output_id_ = -1; 46 }; 47 48 // Builder for the whole graph. 49 // All op builders should be added using AddBuilder 50 // and then BuildModel should be called to return the CoreML generated. 51 // 52 // API is experimental and subject to change. 53 class GraphBuilder { 54 public: GraphBuilder(int coreml_version)55 explicit GraphBuilder(int coreml_version) : coreml_version_(coreml_version) {} 56 57 // Returns pointer to the created builder. Ownership still belongs 58 // to the GraphBuilder. 59 OpBuilder* AddBuilder(int builtin_code, const TfLiteNode* node); 60 61 // Returns pointer to the created builder with op builder function provided. 62 OpBuilder* AddBuilder(const std::function<OpBuilder*(GraphBuilder*)>& builder, 63 const TfLiteNode* node); 64 65 // Builds Model instance and returns it. 66 CoreML::Specification::Model* BuildModel(); 67 68 // Returns string representing tensor 'tensor_id' in coreML. 69 // tensor_id should have been added before calling this method. 70 std::string GetTensorName(int tensor_id); 71 72 // Returns Core ML Tensor ID for TFL 'tensor_id'. 73 // tensor_id should have been added before calling this method. 74 const TensorID GetTensorID(int tensor_id); 75 76 void AddTensorWithID(int tf_tensor_id, const TensorID& tensor_id); 77 78 // Return true if this tensor was added before to the graph. 79 bool HasTensor(int tflite_tensor_index); 80 // Return if this tensor is used in the graph (not as data). 81 // This information is used to mark constant tensors that are used as input. 82 bool IsTensorUsed(int tflite_tensor_index); 83 84 const int coreml_version_; 85 86 private: 87 std::vector<std::unique_ptr<OpBuilder>> builders_; 88 // Index in the vector is the tflite_tensor_index, the value 89 // is the ID in the coreml graph. 90 std::vector<TensorID> tensors_; 91 std::vector<bool> used_tensor_; 92 }; 93 94 // Interface for all op layers 95 // API is experimental and subject to change. 96 class OpBuilder { 97 public: OpBuilder(GraphBuilder * graph_builder)98 explicit OpBuilder(GraphBuilder* graph_builder) 99 : graph_builder_(graph_builder) {} ~OpBuilder()100 virtual ~OpBuilder() {} 101 102 // Returns the Layer this builder responsible for. 103 // Ownership is transferred to caller. 104 virtual CoreML::Specification::NeuralNetworkLayer* Build(); 105 106 // Associates TfLite input tensors to Core ML layer's inputs and properties. 107 // Verification for input constraints should happen here. 108 virtual TfLiteStatus RegisterInputs(const TfLiteIntArray* inputs, 109 TfLiteContext* context) = 0; 110 111 // Associates TFLite output tensor with the node's output. If the OpBuilder 112 // has subgraphs, The final output of that subgraph should be associated with 113 // the output tensor. 114 virtual TfLiteStatus RegisterOutputs(const TfLiteIntArray* outputs, 115 TfLiteContext* context) = 0; 116 117 // Adds additional required OpBuilders, and populate builder_output_ with 118 // Actual output that corresponds to output tensor of TFL Node. 119 // Clients need to override this in cases where the nodes can be used for 120 // composing other ops. For example, Relu6 in TfLite can be converted to 121 // Relu -> Threshold -> Neg. 122 // TODO(b/147211734): have this called automatically when necessary. 123 virtual TfLiteStatus PopulateSubgraph(TfLiteContext* context); 124 125 virtual const std::string& DebugName() = 0; 126 127 void SetBuiltinData(void* builtin_data); 128 129 void SetNodeID(int id); 130 131 void SetTfLiteNode(const TfLiteNode* node); 132 133 int GetID() const; 134 135 // Adds input with tensor name. 136 void AddInput(const std::string& input_name); 137 138 // Adds input with CoreML tensor ID. 139 void AddInput(const TensorID& input_id); 140 141 // Adds input with TF Lite tensor ID. 142 // TODO(taeheej): cleanup AddInput use cases and used tensor tracking. 143 void AddInput(int tf_input_id); 144 145 // Simply adds new output to the underlying layer. 146 TensorID AddOutput(); 147 148 // Should set builder_output_ (if unset) and return it as the output of 149 // this node. To be used by clients that needs the output of the node. 150 virtual TensorID GetOutput(TfLiteContext* context); 151 152 protected: 153 // Sets layer's name. 154 void SetDebugName(const char* layer_name, int id); 155 156 GraphBuilder* graph_builder_ = nullptr; 157 // Data needed by this node. 158 void* builtin_data_ = nullptr; 159 int node_id_ = -1; 160 int num_outputs_ = 0; 161 const TfLiteNode* tflite_node_ = nullptr; 162 TensorID builder_output_; 163 std::string debug_name_; 164 std::unique_ptr<CoreML::Specification::NeuralNetworkLayer> layer_; 165 }; 166 167 } // namespace coreml 168 } // namespace delegates 169 } // namespace tflite 170 171 #endif // TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_OP_BUILDER_H_ 172