# Custom operators

TensorFlow Lite currently supports a subset of TensorFlow operators. If a model
contains an operator that is not supported, TensorFlow Lite allows you to
provide your own implementation, also known as a custom operator. Providing
custom kernels is also a way of evaluating a series of TensorFlow operations as
a single fused TensorFlow Lite operation.

Using custom operators consists of three steps:

*   Making sure the TensorFlow GraphDef or SavedModel refers to the correctly
    named TensorFlow Lite operator.

*   Registering a custom kernel with TensorFlow Lite so that the runtime knows
    how to map the operator and parameters in your graph to executable C/C++
    code.

*   Testing and profiling your operator for correctness and performance,
    respectively. If you wish to test just your custom operator, it is best to
    create a model containing only that operator and to use the
    [benchmark_model](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/tools/benchmark/benchmark_model_test.cc)
    program.

Below we describe a complete example of defining `Sin`, along with links to
existing conversion processes involving custom operators.

## Making a custom operator for Sin

Let's walk through an example of supporting a TensorFlow operator that
TensorFlow Lite does not have. Assume we are using the `Sin` operator and that
we are building a very simple model for a function `y = sin(x + offset)`, where
`offset` is trainable.

### Generating the model from TensorFlow

The code to train the TensorFlow model will be something like:

```python
offset = tf.get_variable("offset", [1,], tf.float32)
x = tf.placeholder(tf.float32, shape=(None,))
y = tf.sin(x + offset)
y_ = tf.placeholder(tf.float32, shape=(None,))
loss = tf.reduce_sum(tf.square(y - y_))
optimizer = tf.train.GradientDescentOptimizer(0.001)
train = optimizer.minimize(loss)
```

If you convert this model to TensorFlow Lite format using the TensorFlow Lite
Optimizing Converter with the `--allow_custom_ops` argument and run it with the
default interpreter, the interpreter will raise the following error messages:

```
Didn't find custom op for name 'Sin'
Registration failed.
```

### Defining the kernel in the TensorFlow Lite runtime

All we need to do to use the op in TensorFlow Lite is define two functions
(`Prepare` and `Eval`) and construct a `TfLiteRegistration`.
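A `TfLiteRegistration` is essentially a bundle of function pointers that the
interpreter calls at different points in a node's lifetime. As a rough guide,
its core fields look something like the simplified sketch below; the actual
struct in the TensorFlow Lite C API headers contains additional members, so
treat this as illustrative rather than authoritative:

```cpp
// Simplified, illustrative sketch of the function pointers a custom op can
// provide via TfLiteRegistration; see the TensorFlow Lite C API headers for
// the authoritative definition.
typedef struct {
  // Called once per node to parse custom data and create any per-node state.
  void* (*init)(TfLiteContext* context, const char* buffer, size_t length);
  // Called to release whatever init() allocated.
  void (*free)(TfLiteContext* context, void* buffer);
  // Called whenever input tensors are (re)sized; resize the outputs here.
  TfLiteStatus (*prepare)(TfLiteContext* context, TfLiteNode* node);
  // Called for every inference; this is where the computation happens.
  TfLiteStatus (*invoke)(TfLiteContext* context, TfLiteNode* node);
} TfLiteRegistration;
```

In the `Sin` example that follows, only the `prepare` and `invoke` slots are
filled in; `init` and `free` stay `nullptr` because the op needs no per-node
state.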
This code would look something like the following:

```cpp
#include <math.h>

TfLiteStatus SinPrepare(TfLiteContext* context, TfLiteNode* node) {
  using namespace tflite;
  TF_LITE_ENSURE_EQ(context, NumInputs(node), 1);
  TF_LITE_ENSURE_EQ(context, NumOutputs(node), 1);

  const TfLiteTensor* input = GetInput(context, node, 0);
  TfLiteTensor* output = GetOutput(context, node, 0);

  int num_dims = NumDimensions(input);

  // The output has the same shape as the input.
  TfLiteIntArray* output_size = TfLiteIntArrayCreate(num_dims);
  for (int i = 0; i < num_dims; ++i) {
    output_size->data[i] = input->dims->data[i];
  }

  return context->ResizeTensor(context, output, output_size);
}

TfLiteStatus SinEval(TfLiteContext* context, TfLiteNode* node) {
  using namespace tflite;
  const TfLiteTensor* input = GetInput(context, node, 0);
  TfLiteTensor* output = GetOutput(context, node, 0);

  float* input_data = input->data.f;
  float* output_data = output->data.f;

  // Total number of elements across all dimensions.
  size_t count = 1;
  int num_dims = NumDimensions(input);
  for (int i = 0; i < num_dims; ++i) {
    count *= input->dims->data[i];
  }

  for (size_t i = 0; i < count; ++i) {
    output_data[i] = sin(input_data[i]);
  }
  return kTfLiteOk;
}

TfLiteRegistration* Register_SIN() {
  static TfLiteRegistration r = {nullptr, nullptr, SinPrepare, SinEval};
  return &r;
}
```

When initializing the `OpResolver`, add the custom op into the resolver. This
registers the operator with TensorFlow Lite so that TensorFlow Lite can use the
new implementation. Note that the last two arguments in `TfLiteRegistration`
correspond to the `SinPrepare` and `SinEval` functions you defined for the
custom op. If you defined `Init()` and `Free()` functions to initialize
variables used in the op and to free up space, they would be added as the first
two arguments of `TfLiteRegistration`; they are set to `nullptr` in this
example.

```cpp
tflite::ops::builtin::BuiltinOpResolver builtins;
builtins.AddCustom("Sin", Register_SIN());
```

If you want to use your custom operators from Java, you would currently need to
build your own custom JNI layer and compile your own AAR
[in this JNI code](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/java/src/main/native/builtin_ops_jni.cc).
Similarly, if you wish to make these operators available in Python, you can
place your registrations in the
[Python wrapper code](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.cc).

Note that a similar process can be followed to support a set of operations
instead of a single operator. Just add as many `AddCustom` calls as you need.
In addition, `BuiltinOpResolver` also allows you to override implementations of
builtins by using `AddBuiltin`.

## Best Practices

### Writing TensorFlow Lite kernels best practices

1.  Optimize memory allocations and de-allocations cautiously. It is more
    efficient to allocate memory in `Prepare()` instead of `Invoke()`, and to
    allocate memory before a loop instead of in every iteration. Use temporary
    tensor data rather than mallocing yourself (see item 2). Use
    pointers/references instead of copying as much as possible.

2.  If a data structure will persist during the entire operation, we advise
    pre-allocating the memory using temporary tensors. You may need to use an
    `OpData` struct to reference the tensor indices in other functions.
    See the example in the
    [kernel for convolution](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/conv.cc).
    A sample code snippet is below:

    ```cpp
    auto* op_data = reinterpret_cast<OpData*>(node->user_data);
    TfLiteIntArrayFree(node->temporaries);
    node->temporaries = TfLiteIntArrayCreate(1);
    node->temporaries->data[0] = op_data->temp_tensor_index;
    TfLiteTensor* temp_tensor = &context->tensors[op_data->temp_tensor_index];
    temp_tensor->type = kTfLiteFloat32;
    temp_tensor->allocation_type = kTfLiteArenaRw;
    ```

3.  If it doesn't cost too much wasted memory, prefer using a static fixed-size
    array (or a `std::vector` pre-allocated in `Resize()`) rather than
    dynamically allocating a `std::vector` on every iteration of execution.

4.  Avoid instantiating standard library container templates that don't already
    exist, because they affect binary size. For example, if you need a
    `std::map` in your operation that doesn't exist in other kernels, using a
    `std::vector` with direct index mapping could work while keeping the binary
    size small. See what other kernels use to gain insight (or ask).

5.  Check the pointer to the memory returned by `malloc`. If this pointer is
    `nullptr`, no operations should be performed using that pointer. If you
    `malloc()` in a function and have an error exit, deallocate the memory
    before you exit.

6.  Use `TF_LITE_ENSURE(context, condition)` to check for a specific condition.
    Your code must not leave memory hanging when `TF_LITE_ENSURE` is used,
    i.e., these checks should be done before any resources are allocated that
    would leak.

### Conversion best practices

The example above was easy to convert since it was a builtin operator in
TensorFlow. If you are defining a new operator that fuses many operators, or
you have complicated shapes or types, you might need to provide more
information and use graph transformations to rewrite an existing graph to use
your operator instead of the builtin TensorFlow one.

#### Rewriting TensorFlow graphs to use custom operators

In TensorFlow you can use the `tf.lite.OpHint` class to encapsulate groups of
operators when you create a TensorFlow graph. This then allows you to extract a
graph def that has references to those operators. This is currently
experimental and should only be used by advanced users. There is a full example
of how to use this in the
[OpHint code](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/python/op_hint.py).

In addition, you can also use a manual graph substitution approach to rewrite
TensorFlow graphs. There is an example of how this is done in the single-shot
object detection models'
[export script](https://github.com/tensorflow/models/blob/master/research/object_detection/export_tflite_ssd_graph.py).

### TF Graph Attributes

When `tflite_convert` converts a TensorFlow graph into TFLite format, it makes
some assumptions about custom operations that might not be correct. In this
case, the generated graph may not execute.

It is possible to add additional information about your custom op's outputs to
the TensorFlow graph before it is converted.
The following attributes are supported:

- **_output_quantized** a boolean attribute, true if the operation outputs are
  quantized
- **_output_types** a list of types for the output tensors
- **_output_shapes** a list of shapes for the output tensors

#### Setting the Attributes

This is an example of how the attributes can be set:

```python
frozen_graph_def = tf.graph_util.convert_variables_to_constants(...)
for node in frozen_graph_def.node:
  if node.op == 'Sin':
    node.attr['_output_types'].list.type.extend([
        types_pb2.DT_FLOAT,
    ])
    node.attr['_output_shapes'].list.shape.extend([
        tf.TensorShape([10]).as_proto(),
    ])
    node.attr['_output_quantized'].b = False
tflite_model = tf.lite.toco_convert(
    frozen_graph_def, ...)
```

**Note:** After the attributes are set, the graph can not be executed by
TensorFlow, so this should be done just before the conversion.
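Whether or not you set these attributes, the conversion itself must be told to
keep unsupported operations as custom ops, as mentioned earlier for the
`--allow_custom_ops` flag. If you drive the conversion from Python with the
TF 1.x `TFLiteConverter` API instead of `toco_convert`, a minimal sketch looks
roughly like this (the file name and tensor names below are placeholders for
this example, not values from this document):

```python
import tensorflow as tf

# Hypothetical frozen graph and tensor names for the y = sin(x + offset) model.
converter = tf.lite.TFLiteConverter.from_frozen_graph(
    'frozen_sin.pb', input_arrays=['x'], output_arrays=['y'])
# Keep ops that TensorFlow Lite does not support (such as Sin) as custom ops.
converter.allow_custom_ops = True
tflite_model = converter.convert()
```

The resulting model will then produce the "Didn't find custom op" error shown
above unless the corresponding kernel has been registered with the interpreter.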