# Integrate audio classifiers

Audio classification is a common use case of machine learning for classifying
sound types. For example, it can identify bird species by their songs.

The Task Library `AudioClassifier` API can be used to deploy your custom audio
classifiers or pretrained ones into your mobile app.

## Key features of the AudioClassifier API

*   Input audio processing, e.g. converting PCM 16-bit encoding to PCM
    float encoding and manipulating the audio ring buffer.

*   Label map locale.

*   Support for multi-head classification models.

*   Support for both single-label and multi-label classification.

*   Score threshold to filter results.

*   Top-k classification results.

*   Label allowlist and denylist.
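
The score threshold, top-k, and allowlist/denylist features are post-processing
steps applied to the model's raw scores. As a rough sketch of that logic in
plain Python (the names `filter_results`, `raw_scores`, and `labels` are
illustrative, not part of the Task Library API):

```python
# Illustrative post-processing: allowlist, score threshold, and top-k,
# mirroring what the Task Library applies to a model's raw output scores.
def filter_results(raw_scores, labels, score_threshold=0.0,
                   max_results=-1, label_allowlist=None):
    # Pair each label with its score.
    results = list(zip(labels, raw_scores))
    # Keep only allowlisted labels, if an allowlist is given.
    if label_allowlist is not None:
        results = [r for r in results if r[0] in label_allowlist]
    # Drop results below the score threshold.
    results = [r for r in results if r[1] >= score_threshold]
    # Sort by descending score and keep the top-k results.
    results.sort(key=lambda r: r[1], reverse=True)
    if max_results > 0:
        results = results[:max_results]
    return results

scores = [0.1, 0.7, 0.15, 0.05]
labels = ["Silence", "Dog", "Cat", "Bird"]
print(filter_results(scores, labels, score_threshold=0.1, max_results=2))
# [('Dog', 0.7), ('Cat', 0.15)]
```
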

## Supported audio classifier models

The following models are guaranteed to be compatible with the `AudioClassifier`
API.

*   Models created by
    [TensorFlow Lite Model Maker for Audio Classification](https://www.tensorflow.org/lite/api_docs/python/tflite_model_maker/audio_classifier).

*   The
    [pretrained audio event classification models on TensorFlow Hub](https://tfhub.dev/google/lite-model/yamnet/classification/tflite/1).

*   Custom models that meet the
    [model compatibility requirements](#model-compatibility-requirements).

## Run inference in Java

See the
[Audio Classification reference app](https://github.com/tensorflow/examples/tree/master/lite/examples/sound_classification/android)
for an example using `AudioClassifier` in an Android app.

### Step 1: Import Gradle dependency and other settings

Copy the `.tflite` model file to the assets directory of the Android module
where the model will be run. Specify that the file should not be compressed, and
add the TensorFlow Lite library to the module’s `build.gradle` file:

```java
android {
    // Other settings

    // Specify that the tflite file should not be compressed when building the APK package.
    aaptOptions {
        noCompress "tflite"
    }
}

dependencies {
    // Other dependencies

    // Import the Audio Task Library dependency (NNAPI is included)
    implementation 'org.tensorflow:tensorflow-lite-task-audio:0.4.0'
    // Import the GPU delegate plugin Library for GPU inference
    implementation 'org.tensorflow:tensorflow-lite-gpu-delegate-plugin:0.4.0'
}
```

Note: Starting from version 4.1 of the Android Gradle plugin, `.tflite` is added
to the `noCompress` list by default, and the `aaptOptions` block above is no
longer needed.

### Step 2: Use the model

```java
// Initialization
AudioClassifierOptions options =
    AudioClassifierOptions.builder()
        .setBaseOptions(BaseOptions.builder().useGpu().build())
        .setMaxResults(1)
        .build();
AudioClassifier classifier =
    AudioClassifier.createFromFileAndOptions(context, modelFile, options);

// Start recording
AudioRecord record = classifier.createAudioRecord();
record.startRecording();

// Load latest audio samples
TensorAudio audioTensor = classifier.createInputTensorAudio();
audioTensor.load(record);

// Run inference
List<Classifications> results = classifier.classify(audioTensor);
```
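
`TensorAudio.load` performs the input audio processing for you. Purely as a
rough illustration (plain Python, not the Task Library implementation), the PCM
16-bit to float conversion and a ring buffer of the most recent samples look
roughly like this:

```python
from collections import deque

# Convert signed 16-bit PCM samples to floats in [-1.0, 1.0), similar to
# what the Task Library does before feeding audio to the model.
def pcm16_to_float(samples):
    return [s / 32768.0 for s in samples]

print(pcm16_to_float([0, 16384, -32768, 32767]))
# [0.0, 0.5, -1.0, 0.999969482421875]

# A fixed-size ring buffer keeps only the most recent audio samples:
# older samples are evicted as new ones arrive.
ring = deque(maxlen=4)
for sample in pcm16_to_float([0, 16384, -32768, 32767, 16384]):
    ring.append(sample)
print(list(ring))  # the 4 most recent samples
```
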

See the
[source code and javadoc](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/java/src/java/org/tensorflow/lite/task/audio/classifier/AudioClassifier.java)
for more options to configure `AudioClassifier`.

## Run inference in Python

### Step 1: Install the pip package

```
pip install tflite-support
```

Note: Task Library's Audio APIs rely on [PortAudio](http://www.portaudio.com/docs/v19-doxydocs/index.html)
to record audio from the device's microphone. If you intend to use Task
Library's [AudioRecord](/lite/api_docs/python/tflite_support/task/audio/AudioRecord)
for audio recording, you need to install PortAudio on your system.

*   Linux: Run `sudo apt-get update && sudo apt-get install libportaudio2`
*   Mac and Windows: PortAudio is installed automatically when installing the
    `tflite-support` pip package.

### Step 2: Use the model

```python
# Imports
from tflite_support.task import audio
from tflite_support.task import core
from tflite_support.task import processor

# Initialization
base_options = core.BaseOptions(file_name=model_path)
classification_options = processor.ClassificationOptions(max_results=2)
options = audio.AudioClassifierOptions(
    base_options=base_options, classification_options=classification_options)
classifier = audio.AudioClassifier.create_from_options(options)

# Alternatively, you can create an audio classifier in the following manner:
# classifier = audio.AudioClassifier.create_from_file(model_path)

# Run inference
audio_file = audio.TensorAudio.create_from_wav_file(
    audio_path, classifier.required_input_buffer_size)
audio_result = classifier.classify(audio_file)
```
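
`TensorAudio.create_from_wav_file` decodes the wav file for you. Purely for
illustration, here is a sketch with Python's standard `wave` and `struct`
modules of what loading 16-bit samples as floats involves (this is not Task
Library code, and the in-memory wav is fabricated for the example):

```python
import io
import struct
import wave

# Write a tiny 16-bit mono wav in memory, then read it back as floats,
# roughly what TensorAudio.create_from_wav_file does behind the scenes.
buf = io.BytesIO()
with wave.open(buf, "wb") as w:
    w.setnchannels(1)       # mono
    w.setsampwidth(2)       # 16-bit samples
    w.setframerate(16000)   # 16 kHz sample rate
    w.writeframes(struct.pack("<4h", 0, 16384, -16384, 32767))

buf.seek(0)
with wave.open(buf, "rb") as w:
    raw = w.readframes(w.getnframes())

# Scale int16 samples into floats in [-1.0, 1.0).
samples = [s / 32768.0 for s in struct.unpack("<4h", raw)]
print(samples)  # [0.0, 0.5, -0.5, 0.999969482421875]
```
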

See the
[source code](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/python/task/audio/audio_classifier.py)
for more options to configure `AudioClassifier`.

## Run inference in C++

```c++
// Initialization
AudioClassifierOptions options;
options.mutable_base_options()->mutable_model_file()->set_file_name(model_path);
std::unique_ptr<AudioClassifier> audio_classifier =
    AudioClassifier::CreateFromOptions(options).value();

// Create input audio buffer from your `audio_data` and `audio_format`.
// See more information here: tensorflow_lite_support/cc/task/audio/core/audio_buffer.h
int input_size = audio_classifier->GetRequiredInputBufferSize();
const std::unique_ptr<AudioBuffer> audio_buffer =
    AudioBuffer::Create(audio_data, input_size, audio_format).value();

// Run inference
const ClassificationResult result =
    audio_classifier->Classify(*audio_buffer).value();
```

See the
[source code](https://github.com/tensorflow/tflite-support/blob/master/tensorflow_lite_support/cc/task/audio/audio_classifier.h)
for more options to configure `AudioClassifier`.

## Model compatibility requirements

The `AudioClassifier` API expects a TFLite model with mandatory
[TFLite Model Metadata](../../models/convert/metadata.md). See examples of
creating metadata for audio classifiers using the
[TensorFlow Lite Metadata Writer API](../../models/convert/metadata_writer_tutorial.ipynb#audio_classifiers).

Compatible audio classifier models must meet the following requirements:

*   Input audio tensor (kTfLiteFloat32)

    -   audio clip of size `[batch x samples]`.
    -   batch inference is not supported (`batch` is required to be 1).
    -   for multi-channel models, the channels need to be interleaved.

*   Output score tensor (kTfLiteFloat32)

    -   `[1 x N]` array with `N` representing the number of classes.
    -   optional (but recommended) label map(s) as AssociatedFile-s with type
        TENSOR_AXIS_LABELS, containing one label per line. The first such
        AssociatedFile (if any) is used to fill the `label` field (named
        `class_name` in C++) of the results. The `display_name` field is filled
        from the AssociatedFile (if any) whose locale matches the
        `display_names_locale` field of the `AudioClassifierOptions` used at
        creation time ("en" by default, i.e. English). If none of these are
        available, only the `index` field of the results will be filled.
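
To illustrate the interleaving requirement: a multi-channel clip is laid out as
alternating per-channel samples, not as one channel block after another. A
small sketch in plain Python (the `interleave` helper is illustrative, not part
of any API):

```python
# Interleave per-channel sample lists into the flat layout the model expects:
# [L0, R0, L1, R1, ...] rather than [L0, L1, ..., R0, R1, ...].
def interleave(*channels):
    return [sample for frame in zip(*channels) for sample in frame]

left = [0.1, 0.2, 0.3]
right = [0.4, 0.5, 0.6]
print(interleave(left, right))
# [0.1, 0.4, 0.2, 0.5, 0.3, 0.6]
```
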