# Integrate BERT natural language classifier

The Task Library `BertNLClassifier` API is very similar to the `NLClassifier`
that classifies input text into different categories, except that this API is
specially tailored for BERT-based models that require WordPiece or
SentencePiece tokenization outside the TFLite model.

## Key features of the BertNLClassifier API

*   Takes a single string as input, performs classification on the string, and
    outputs <Label, Score> pairs as classification results.

*   Performs out-of-graph
    [WordPiece](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/cc/text/tokenizers/bert_tokenizer.h)
    or
    [SentencePiece](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/cc/text/tokenizers/sentencepiece_tokenizer.h)
    tokenization on input text.

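The <Label, Score> pairs above map naturally onto a small value type. As a
minimal, self-contained sketch of consuming them, the class below is a
hypothetical stand-in (not the Task Library's actual `Category` class) used to
show how a caller might pick the top-scoring label:

```java
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;

// Hypothetical stand-in for the Task Library's Category (label + score),
// used here only to illustrate the <Label, Score> output shape.
class LabelScore {
    final String label;
    final float score;

    LabelScore(String label, float score) {
        this.label = label;
        this.score = score;
    }
}

public class TopLabelExample {
    // Pick the label with the highest score from classification results.
    static String topLabel(List<LabelScore> results) {
        return results.stream()
                .max(Comparator.comparingDouble(r -> r.score))
                .map(r -> r.label)
                .orElse("");
    }

    public static void main(String[] args) {
        List<LabelScore> results = Arrays.asList(
                new LabelScore("negative", 0.00006f),
                new LabelScore("positive", 0.99994f));
        System.out.println(topLabel(results));  // prints "positive"
    }
}
```
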
## Supported BertNLClassifier models

The following models are compatible with the `BertNLClassifier` API.

*   BERT models created by
    [TensorFlow Lite Model Maker for text classification](https://www.tensorflow.org/lite/models/modify/model_maker/text_classification).

*   Custom models that meet the
    [model compatibility requirements](#model-compatibility-requirements).

## Run inference in Java

### Step 1: Import Gradle dependency and other settings

Copy the `.tflite` model file to the assets directory of the Android module
where the model will be run. Specify that the file should not be compressed, and
add the TensorFlow Lite library to the module’s `build.gradle` file:

```groovy
android {
    // Other settings

    // Specify tflite file should not be compressed for the app apk
    aaptOptions {
        noCompress "tflite"
    }
}

dependencies {
    // Other dependencies

    // Import the Task Text Library dependency (NNAPI is included)
    implementation 'org.tensorflow:tensorflow-lite-task-text:0.3.0'
}
```

Note: starting from version 4.1 of the Android Gradle plugin, .tflite is added
to the noCompress list by default, and the aaptOptions above is no longer
needed.

### Step 2: Run inference using the API

```java
// Initialization
BertNLClassifierOptions options =
    BertNLClassifierOptions.builder()
        .setBaseOptions(BaseOptions.builder().setNumThreads(4).build())
        .build();
BertNLClassifier classifier =
    BertNLClassifier.createFromFileAndOptions(context, modelFile, options);

// Run inference
List<Category> results = classifier.classify(input);
```

See the
[source code](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/text/nlclassifier/BertNLClassifier.java)
for more details.

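Callers typically sort and print the returned categories. The sketch below is
self-contained: `Result` is a hypothetical stand-in for the Task Library's
`Category` class, and the formatting shown is one possible way to render the
results, not a Task Library API:

```java
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.Locale;

// Hypothetical stand-in for the Task Library's Category class, used only
// so this sketch compiles on its own; classify() returns List<Category>.
class Result {
    final String label;
    final float score;

    Result(String label, float score) {
        this.label = label;
        this.score = score;
    }
}

public class FormatResults {
    // Sort results by descending score and format them one per line.
    static String format(List<Result> results) {
        List<Result> sorted = new ArrayList<>(results);
        sorted.sort(Comparator.comparingDouble((Result r) -> r.score).reversed());
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < sorted.size(); i++) {
            sb.append(String.format(Locale.US, "category[%d]: '%s' : '%.5f'%n",
                    i, sorted.get(i).label, sorted.get(i).score));
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        List<Result> results = List.of(
                new Result("negative", 0.00006f),
                new Result("positive", 0.99994f));
        System.out.print(format(results));
    }
}
```
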
## Run inference in Swift

### Step 1: Import CocoaPods

Add the TensorFlowLiteTaskText pod to the Podfile:

```
target 'MySwiftAppWithTaskAPI' do
  use_frameworks!
  pod 'TensorFlowLiteTaskText', '~> 0.2.0'
end
```

### Step 2: Run inference using the API

```swift
// Initialization
let bertNLClassifier = TFLBertNLClassifier.bertNLClassifier(
      modelPath: bertModelPath)

// Run inference
let categories = bertNLClassifier.classify(text: input)
```

See the
[source code](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/ios/task/text/nlclassifier/Sources/TFLBertNLClassifier.h)
for more details.

## Run inference in C++

```c++
// Initialization
BertNLClassifierOptions options;
options.mutable_base_options()->mutable_model_file()->set_file_name(model_path);
std::unique_ptr<BertNLClassifier> classifier =
    BertNLClassifier::CreateFromOptions(options).value();

// Run inference with your input, `input_text`.
std::vector<core::Category> categories = classifier->Classify(input_text);
```

See the
[source code](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/cc/task/text/bert_nl_classifier.h)
for more details.

## Example results

Here is an example of the classification results of movie reviews using the
[MobileBert](https://www.tensorflow.org/lite/models/modify/model_maker/text_classification)
model from Model Maker.

Input: "it's a charming and often affecting journey"

Output:

```
category[0]: 'negative' : '0.00006'
category[1]: 'positive' : '0.99994'
```

Try out the simple
[CLI demo tool for BertNLClassifier](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/examples/task/text/desktop/README.md#bertnlclassifier)
with your own model and test data.

## Model compatibility requirements

The `BertNLClassifier` API expects a TFLite model with mandatory
[TFLite Model Metadata](../../models/convert/metadata.md).

The Metadata should meet the following requirements:

*   input_process_units for the WordPiece/SentencePiece tokenizer

*   3 input tensors with names "ids", "mask" and "segment_ids" for the output of
    the tokenizer

*   1 output tensor of type float32, with an optionally attached label file. If
    a label file is attached, it should be a plain text file with one label per
    line, and the number of labels should match the number of categories the
    model outputs.
158