/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

16 #include "tensorflow/lite/micro/examples/micro_speech/micro_features/no_micro_features_data.h"
17 #include "tensorflow/lite/micro/examples/micro_speech/micro_features/tiny_conv_micro_features_model_data.h"
18 #include "tensorflow/lite/micro/examples/micro_speech/micro_features/yes_micro_features_data.h"
19 #include "tensorflow/lite/micro/kernels/micro_ops.h"
20 #include "tensorflow/lite/micro/micro_error_reporter.h"
21 #include "tensorflow/lite/micro/micro_interpreter.h"
22 #include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
23 #include "tensorflow/lite/micro/testing/micro_test.h"
24 #include "tensorflow/lite/schema/schema_generated.h"
25 #include "tensorflow/lite/version.h"
26
27 TF_LITE_MICRO_TESTS_BEGIN
28
TF_LITE_MICRO_TEST(TestInvoke)29 TF_LITE_MICRO_TEST(TestInvoke) {
30 // Set up logging.
31 tflite::MicroErrorReporter micro_error_reporter;
32 tflite::ErrorReporter* error_reporter = µ_error_reporter;
33
34 // Map the model into a usable data structure. This doesn't involve any
35 // copying or parsing, it's a very lightweight operation.
36 const tflite::Model* model =
37 ::tflite::GetModel(g_tiny_conv_micro_features_model_data);
38 if (model->version() != TFLITE_SCHEMA_VERSION) {
39 error_reporter->Report(
40 "Model provided is schema version %d not equal "
41 "to supported version %d.\n",
42 model->version(), TFLITE_SCHEMA_VERSION);
43 }
44
45 // Pull in only the operation implementations we need.
46 // This relies on a complete list of all the ops needed by this graph.
47 // An easier approach is to just use the AllOpsResolver, but this will
48 // incur some penalty in code space for op implementations that are not
49 // needed by this graph.
50 //
51 // tflite::ops::micro::AllOpsResolver resolver;
52 tflite::MicroOpResolver<3> micro_op_resolver;
53 micro_op_resolver.AddBuiltin(
54 tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
55 tflite::ops::micro::Register_DEPTHWISE_CONV_2D());
56 micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_FULLY_CONNECTED,
57 tflite::ops::micro::Register_FULLY_CONNECTED());
58 micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_SOFTMAX,
59 tflite::ops::micro::Register_SOFTMAX());
60
61 // Create an area of memory to use for input, output, and intermediate arrays.
62 const int tensor_arena_size = 10 * 1024;
63 uint8_t tensor_arena[tensor_arena_size];
64
65 // Build an interpreter to run the model with.
66 tflite::MicroInterpreter interpreter(model, micro_op_resolver, tensor_arena,
67 tensor_arena_size, error_reporter);
68 interpreter.AllocateTensors();
69
70 // Get information about the memory area to use for the model's input.
71 TfLiteTensor* input = interpreter.input(0);
72
73 // Make sure the input has the properties we expect.
74 TF_LITE_MICRO_EXPECT_NE(nullptr, input);
75 TF_LITE_MICRO_EXPECT_EQ(4, input->dims->size);
76 TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[0]);
77 TF_LITE_MICRO_EXPECT_EQ(49, input->dims->data[1]);
78 TF_LITE_MICRO_EXPECT_EQ(40, input->dims->data[2]);
79 TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[3]);
80 TF_LITE_MICRO_EXPECT_EQ(kTfLiteUInt8, input->type);
81
82 // Copy a spectrogram created from a .wav audio file of someone saying "Yes",
83 // into the memory area used for the input.
84 const uint8_t* yes_features_data = g_yes_micro_f2e59fea_nohash_1_data;
85 for (int i = 0; i < input->bytes; ++i) {
86 input->data.uint8[i] = yes_features_data[i];
87 }
88
89 // Run the model on this input and make sure it succeeds.
90 TfLiteStatus invoke_status = interpreter.Invoke();
91 if (invoke_status != kTfLiteOk) {
92 error_reporter->Report("Invoke failed\n");
93 }
94 TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, invoke_status);
95
96 // Get the output from the model, and make sure it's the expected size and
97 // type.
98 TfLiteTensor* output = interpreter.output(0);
99 TF_LITE_MICRO_EXPECT_EQ(2, output->dims->size);
100 TF_LITE_MICRO_EXPECT_EQ(1, output->dims->data[0]);
101 TF_LITE_MICRO_EXPECT_EQ(4, output->dims->data[1]);
102 TF_LITE_MICRO_EXPECT_EQ(kTfLiteUInt8, output->type);
103
104 // There are four possible classes in the output, each with a score.
105 const int kSilenceIndex = 0;
106 const int kUnknownIndex = 1;
107 const int kYesIndex = 2;
108 const int kNoIndex = 3;
109
110 // Make sure that the expected "Yes" score is higher than the other classes.
111 uint8_t silence_score = output->data.uint8[kSilenceIndex];
112 uint8_t unknown_score = output->data.uint8[kUnknownIndex];
113 uint8_t yes_score = output->data.uint8[kYesIndex];
114 uint8_t no_score = output->data.uint8[kNoIndex];
115 TF_LITE_MICRO_EXPECT_GT(yes_score, silence_score);
116 TF_LITE_MICRO_EXPECT_GT(yes_score, unknown_score);
117 TF_LITE_MICRO_EXPECT_GT(yes_score, no_score);
118
119 // Now test with a different input, from a recording of "No".
120 const uint8_t* no_features_data = g_no_micro_f9643d42_nohash_4_data;
121 for (int i = 0; i < input->bytes; ++i) {
122 input->data.uint8[i] = no_features_data[i];
123 }
124
125 // Run the model on this "No" input.
126 invoke_status = interpreter.Invoke();
127 if (invoke_status != kTfLiteOk) {
128 error_reporter->Report("Invoke failed\n");
129 }
130 TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, invoke_status);
131
132 // Get the output from the model, and make sure it's the expected size and
133 // type.
134 output = interpreter.output(0);
135 TF_LITE_MICRO_EXPECT_EQ(2, output->dims->size);
136 TF_LITE_MICRO_EXPECT_EQ(1, output->dims->data[0]);
137 TF_LITE_MICRO_EXPECT_EQ(4, output->dims->data[1]);
138 TF_LITE_MICRO_EXPECT_EQ(kTfLiteUInt8, output->type);
139
140 // Make sure that the expected "No" score is higher than the other classes.
141 silence_score = output->data.uint8[kSilenceIndex];
142 unknown_score = output->data.uint8[kUnknownIndex];
143 yes_score = output->data.uint8[kYesIndex];
144 no_score = output->data.uint8[kNoIndex];
145 TF_LITE_MICRO_EXPECT_GT(no_score, silence_score);
146 TF_LITE_MICRO_EXPECT_GT(no_score, unknown_score);
147 TF_LITE_MICRO_EXPECT_GT(no_score, yes_score);
148
149 error_reporter->Report("Ran successfully\n");
150 }
151
152 TF_LITE_MICRO_TESTS_END
153