# TFLite Task Library - C++

A flexible and ready-to-use library for common machine learning model types,
such as classification and detection.

## Text Task Libraries

### QuestionAnswerer

The `QuestionAnswerer` API can load
[Mobile BERT](https://tfhub.dev/tensorflow/mobilebert/1) or
[ALBERT](https://tfhub.dev/tensorflow/albert_lite_base/1) TFLite models and
answer questions based on a given context.

Use the C++ API to answer questions as follows:

```cc
using tflite::task::text::qa::BertQuestionAnswerer;
using tflite::task::text::qa::QaAnswer;
// Create an API handler with the Mobile BERT model.
auto qa_client = BertQuestionAnswerer::CreateBertQuestionAnswererFromFile(
    "/path/to/mobileBertModel", "/path/to/vocab");
// Or create an API handler with the ALBERT model.
// auto qa_client = BertQuestionAnswerer::CreateAlbertQuestionAnswererFromFile(
//     "/path/to/alBertModel", "/path/to/sentencePieceModel");

std::string context =
    "Nikola Tesla (Serbian Cyrillic: Никола Тесла; 10 "
    "July 1856 – 7 January 1943) was a Serbian American inventor, electrical "
    "engineer, mechanical engineer, physicist, and futurist best known for his "
    "contributions to the design of the modern alternating current (AC) "
    "electricity supply system.";
std::string question = "When was Nikola Tesla born?";
// Run inference on the given `context` and `question`, and get the top-k
// answers ranked by logits.
const std::vector<QaAnswer> answers = qa_client->Answer(context, question);
// Access the QaAnswer results.
for (const QaAnswer& item : answers) {
  std::cout << absl::StrFormat("Text: %s logit=%f start=%d end=%d", item.text,
                               item.pos.logit, item.pos.start, item.pos.end)
            << std::endl;
}
// Output:
// Text: 10 July 1856 logit=16.8527 start=17 end=19
// ... (and more)
//
// So the top-1 answer is: "10 July 1856".
```

In the code above, `item.text` is the text content of an answer. The closed
interval `[item.pos.start, item.pos.end]` denotes the span of predicted tokens
in the answer, and `item.pos.logit`, the sum of the span logits, serves as the
confidence score.

### NLClassifier

The `NLClassifier` API can load any TFLite model for natural language
classification tasks such as language detection or sentiment detection.

The API expects a TFLite model with the following input/output tensors:

 - Input tensor 0 (`kTfLiteString`): input of the model; accepts a string.
 - Output tensor 0 (`kTfLiteUInt8` / `kTfLiteInt8` / `kTfLiteInt16` /
   `kTfLiteFloat32` / `kTfLiteFloat64`): output scores for each class. If the
   type is one of the integer types, the scores are dequantized to double.
 - Output tensor 1 (optional, `kTfLiteString`): output class name for each
   class; should be of the same length as the scores. If this tensor is not
   present, the API uses the score indices as class names.

By default, the API tries to find the input/output tensors using the default
configuration in `NLClassifierOptions`, with tensor names prioritized over
tensor indices. The options are configurable for different TFLite models.

Use the C++ API to perform language ID classification as follows:

```cc
using tflite::task::text::nlclassifier::NLClassifier;
using tflite::task::core::Category;
auto classifier = NLClassifier::CreateFromFileAndOptions("/path/to/model");
// Or create a customized NLClassifierOptions:
// NLClassifierOptions options =
// {
//   .output_score_tensor_name = myOutputScoreTensorName,
//   .output_label_tensor_name = myOutputLabelTensorName,
// }
// auto classifier =
//     NLClassifier::CreateFromFileAndOptions("/path/to/model", options);
std::string context = "What language is this?";
std::vector<Category> categories = classifier->Classify(context);
// Access the category results.
for (const Category& category : categories) {
  std::cout << absl::StrFormat("Language: %s Probability: %f",
                               category.class_name, category.score)
            << std::endl;
}
// Output:
// Language: en Probability: 0.9
// ... (and more)
//
// So the top-1 answer is 'en'.
```

## Vision Task Libraries

### Image Classifier

`ImageClassifier` accepts any TFLite image classification model (with optional,
but strongly recommended, TFLite Model Metadata) that conforms to the following
spec:

Input tensor (type: `kTfLiteUInt8` / `kTfLiteFloat32`):

 - image input of size `[batch x height x width x channels]`.
 - batch inference is not supported (`batch` is required to be 1).
 - only RGB inputs are supported (`channels` is required to be 3).
 - if type is `kTfLiteFloat32`, `NormalizationOptions` are required to be
   attached to the metadata for input normalization.

At least one output tensor (type: `kTfLiteUInt8` / `kTfLiteFloat32`) with:

 - `N` classes and either 2 or 4 dimensions, i.e. `[1 x N]` or
   `[1 x 1 x 1 x N]`
 - optional (but recommended) label map(s) as AssociatedFile-s with type
   TENSOR_AXIS_LABELS, containing one label per line. The first such
   AssociatedFile (if any) is used to fill the `class_name` field of the
   results. The `display_name` field is filled from the AssociatedFile (if
   any) whose locale matches the `display_names_locale` field of the
   `ImageClassifierOptions` used at creation time ("en" by default, i.e.
   English). If none of these are available, only the `index` field of the
   results will be filled.

An example of such a model can be found at:
https://tfhub.dev/bohemian-visual-recognition-alliance/lite-model/models/mushroom-identification_v1/1

Example usage:

```cc
// More options are available (e.g. max number of results to return). At the
// very least, the model must be specified:
ImageClassifierOptions options;
options.mutable_model_file_with_metadata()->set_file_name(
    "/path/to/model.tflite");

// Create an ImageClassifier instance from the options.
StatusOr<std::unique_ptr<ImageClassifier>> image_classifier_or =
    ImageClassifier::CreateFromOptions(options);
// Check if an error occurred.
if (!image_classifier_or.ok()) {
  std::cerr << "An error occurred during ImageClassifier creation: "
            << image_classifier_or.status().message();
  return;
}
std::unique_ptr<ImageClassifier> image_classifier =
    std::move(image_classifier_or.value());

// Prepare FrameBuffer input from e.g. image RGBA data, width and height:
std::unique_ptr<FrameBuffer> frame_buffer =
    CreateFromRgbaRawBuffer(image_rgba_data, {image_width, image_height});

// Run inference:
StatusOr<ClassificationResult> result_or =
    image_classifier->Classify(*frame_buffer);
// Check if an error occurred.
if (!result_or.ok()) {
  std::cerr << "An error occurred during classification: "
            << result_or.status().message();
  return;
}
ClassificationResult result = result_or.value();

// Example value for 'result':
//
// classifications {
//   classes { index: 934 score: 0.95 class_name: "cat" }
//   classes { index: 948 score: 0.007 class_name: "dog" }
//   classes { index: 927 score: 0.003 class_name: "fox" }
//   head_index: 0
// }
```

A CLI demo tool is also available [here][1] for easily trying out this API.

### Object Detector

`ObjectDetector` accepts any object detection TFLite model (with mandatory
TFLite Model Metadata) that conforms to the following spec (e.g. Single Shot
Detectors):

Input tensor (type: `kTfLiteUInt8` / `kTfLiteFloat32`):

 - image input of size `[batch x height x width x channels]`.
 - batch inference is not supported (`batch` is required to be 1).
 - only RGB inputs are supported (`channels` is required to be 3).
 - if type is `kTfLiteFloat32`, `NormalizationOptions` are required to be
   attached to the metadata for input normalization.

Output tensors must be the 4 outputs (type: `kTfLiteFloat32`) of a
[`DetectionPostProcess`][2] op, i.e.:

* Locations:

  - of size `[num_results x 4]`, the inner array representing bounding boxes
    in the form `[top, left, right, bottom]`.
  - BoundingBoxProperties are required to be attached to the metadata and
    must specify type=BOUNDARIES and coordinate_type=RATIO.

* Classes:

  - of size `[num_results]`, each value representing the integer index of a
    class.
  - optional (but recommended) label map(s) can be attached as
    AssociatedFile-s with type TENSOR_VALUE_LABELS, containing one label per
    line. The first such AssociatedFile (if any) is used to fill the
    `class_name` field of the results. The `display_name` field is filled
    from the AssociatedFile (if any) whose locale matches the
    `display_names_locale` field of the `ObjectDetectorOptions` used at
    creation time ("en" by default, i.e. English). If none of these are
    available, only the `index` field of the results will be filled.

* Scores:

  - of size `[num_results]`, each value representing the score of the
    detected object.

* Number of results:

  - integer `num_results` as a tensor of size `[1]`.

An example of such a model can be found at:
https://tfhub.dev/google/lite-model/object_detection/mobile_object_localizer_v1/1/metadata/1

Example usage:

```cc
// More options are available (e.g. max number of results to return). At the
// very least, the model must be specified:
ObjectDetectorOptions options;
options.mutable_model_file_with_metadata()->set_file_name(
    "/path/to/model.tflite");

// Create an ObjectDetector instance from the options.
StatusOr<std::unique_ptr<ObjectDetector>> object_detector_or =
    ObjectDetector::CreateFromOptions(options);
// Check if an error occurred.
if (!object_detector_or.ok()) {
  std::cerr << "An error occurred during ObjectDetector creation: "
            << object_detector_or.status().message();
  return;
}
std::unique_ptr<ObjectDetector> object_detector =
    std::move(object_detector_or.value());

// Prepare FrameBuffer input from e.g. image RGBA data, width and height:
std::unique_ptr<FrameBuffer> frame_buffer =
    CreateFromRgbaRawBuffer(image_rgba_data, {image_width, image_height});

// Run inference:
StatusOr<DetectionResult> result_or = object_detector->Detect(*frame_buffer);
// Check if an error occurred.
if (!result_or.ok()) {
  std::cerr << "An error occurred during detection: "
            << result_or.status().message();
  return;
}
DetectionResult result = result_or.value();

// Example value for 'result':
//
// detections {
//   bounding_box {
//     origin_x: 54
//     origin_y: 398
//     width: 393
//     height: 196
//   }
//   classes { index: 16 score: 0.65 class_name: "cat" }
// }
// detections {
//   bounding_box {
//     origin_x: 602
//     origin_y: 157
//     width: 394
//     height: 447
//   }
//   classes { index: 17 score: 0.45 class_name: "dog" }
// }
```

A CLI demo tool is available [here][3] for easily trying out this API.

### Image Segmenter

`ImageSegmenter` accepts any TFLite model (with optional, but strongly
recommended, TFLite Model Metadata) that conforms to the following spec:

Input tensor (type: `kTfLiteUInt8` / `kTfLiteFloat32`):

 - image input of size `[batch x height x width x channels]`.
 - batch inference is not supported (`batch` is required to be 1).
 - only RGB inputs are supported (`channels` is required to be 3).
 - if type is `kTfLiteFloat32`, `NormalizationOptions` are required to be
   attached to the metadata for input normalization.

Output tensor (type: `kTfLiteUInt8` / `kTfLiteFloat32`):

 - tensor of size `[batch x mask_height x mask_width x num_classes]`, where
   `batch` is required to be 1, `mask_width` and `mask_height` are the
   dimensions of the segmentation masks produced by the model, and
   `num_classes` is the number of classes supported by the model.
 - optional (but recommended) label map(s) can be attached as
   AssociatedFile-s with type TENSOR_AXIS_LABELS, containing one label per
   line. The first such AssociatedFile (if any) is used to fill the
   `class_name` field of the results. The `display_name` field is filled
   from the AssociatedFile (if any) whose locale matches the
   `display_names_locale` field of the `ImageSegmenterOptions` used at
   creation time ("en" by default, i.e. English). If none of these are
   available, only the `index` field of the results will be filled.

An example of such a model can be found at:
https://tfhub.dev/tensorflow/lite-model/deeplabv3/1/metadata/1

Example usage:

```cc
// More options are available, e.g. to select between returning a single
// category mask or multiple confidence masks during post-processing. At the
// very least, the model must be specified:
ImageSegmenterOptions options;
options.mutable_model_file_with_metadata()->set_file_name(
    "/path/to/model.tflite");

// Create an ImageSegmenter instance from the options.
StatusOr<std::unique_ptr<ImageSegmenter>> image_segmenter_or =
    ImageSegmenter::CreateFromOptions(options);
// Check if an error occurred.
if (!image_segmenter_or.ok()) {
  std::cerr << "An error occurred during ImageSegmenter creation: "
            << image_segmenter_or.status().message();
  return;
}
std::unique_ptr<ImageSegmenter> image_segmenter =
    std::move(image_segmenter_or.value());

// Prepare FrameBuffer input from e.g. image RGBA data, width and height:
std::unique_ptr<FrameBuffer> frame_buffer =
    CreateFromRgbaRawBuffer(image_rgba_data, {image_width, image_height});

// Run inference:
StatusOr<SegmentationResult> result_or =
    image_segmenter->Segment(*frame_buffer);
// Check if an error occurred.
if (!result_or.ok()) {
  std::cerr << "An error occurred during segmentation: "
            << result_or.status().message();
  return;
}
SegmentationResult result = result_or.value();

// Example value for 'result':
//
// segmentation {
//   width: 257
//   height: 257
//   category_mask: "\x00\x01..."
//   colored_labels { r: 0 g: 0 b: 0 class_name: "background" }
//   colored_labels { r: 128 g: 0 b: 0 class_name: "aeroplane" }
//   ...
//   colored_labels { r: 128 g: 192 b: 0 class_name: "train" }
//   colored_labels { r: 0 g: 64 b: 128 class_name: "tv" }
// }
//
// Where 'category_mask' is a byte buffer of size 'width' x 'height', with the
// value of each pixel representing the class this pixel belongs to (e.g.
// '\x00' means "background", '\x01' means "aeroplane", etc.).
// 'colored_labels' provides the label for each possible value, as well as
// suggested RGB components to optionally transform the result into a more
// human-friendly colored image.
```

A CLI demo tool is available [here][4] for easily trying out this API.

[1]: https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/examples/task/vision/desktop/image_classifier_demo.cc
[2]: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/lite/kernels/detection_postprocess.cc
[3]: https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/examples/task/vision/desktop/object_detector_demo.cc
[4]: https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/examples/task/vision/desktop/image_segmenter_demo.cc