/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/lite/micro/examples/micro_speech/micro_features/model.h"
#include "tensorflow/lite/micro/examples/micro_speech/micro_features/no_micro_features_data.h"
#include "tensorflow/lite/micro/examples/micro_speech/micro_features/yes_micro_features_data.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/micro/testing/micro_test.h"
#include "tensorflow/lite/schema/schema_generated.h"

TF_LITE_MICRO_TESTS_BEGIN

TF_LITE_MICRO_TEST(TestInvoke) {
  // Set up logging.
  tflite::MicroErrorReporter micro_error_reporter;

  // Map the model into a usable data structure. This doesn't involve any
  // copying or parsing; it's a very lightweight operation.
  const tflite::Model* model = ::tflite::GetModel(g_model);
  if (model->version() != TFLITE_SCHEMA_VERSION) {
    TF_LITE_REPORT_ERROR(&micro_error_reporter,
                         "Model provided is schema version %d not equal "
                         "to supported version %d.\n",
                         model->version(), TFLITE_SCHEMA_VERSION);
  }
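
  // Note that the test only logs a schema mismatch and then carries on. A
  // stricter harness could turn the check above into a hard failure instead,
  // along these lines (a sketch, not part of the original test):
  //
  //   TF_LITE_MICRO_EXPECT_EQ(static_cast<uint32_t>(TFLITE_SCHEMA_VERSION),
  //                           model->version());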

  // Pull in only the operation implementations we need.
  // This relies on a complete list of all the ops needed by this graph.
  // An easier approach is to just use the AllOpsResolver, but this will
  // incur some penalty in code space for op implementations that are not
  // needed by this graph.
  //
  // tflite::AllOpsResolver resolver;
  tflite::MicroMutableOpResolver<4> micro_op_resolver;
  micro_op_resolver.AddDepthwiseConv2D();
  micro_op_resolver.AddFullyConnected();
  micro_op_resolver.AddReshape();
  micro_op_resolver.AddSoftmax();
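
  // Each Add* call above returns a TfLiteStatus (kTfLiteOk on success, or an
  // error once the resolver's capacity of 4 registrations is exhausted). A
  // more defensive registration (a sketch, not needed for this test) would
  // check each result, e.g.:
  //
  //   TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, micro_op_resolver.AddSoftmax());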

  // Create an area of memory to use for input, output, and intermediate
  // arrays.
  const int tensor_arena_size = 10 * 1024;
  uint8_t tensor_arena[tensor_arena_size];
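
  // On some targets the kernels expect the arena to be 16-byte aligned; if
  // allocation fails or kernels misbehave on such hardware, declaring the
  // buffer with an explicit alignment (an assumption about the target, not
  // always required) usually helps:
  //
  //   alignas(16) uint8_t tensor_arena[tensor_arena_size];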

  // Build an interpreter to run the model with.
  tflite::MicroInterpreter interpreter(model, micro_op_resolver, tensor_arena,
                                       tensor_arena_size,
                                       &micro_error_reporter);
  interpreter.AllocateTensors();
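
  // AllocateTensors() returns a TfLiteStatus; a stricter variant would fail
  // fast if the arena above turns out to be too small, rather than relying on
  // the null-input check below (a sketch of how the call above could be
  // wrapped):
  //
  //   TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, interpreter.AllocateTensors());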

  // Get information about the memory area to use for the model's input.
  TfLiteTensor* input = interpreter.input(0);

  // Make sure the input has the properties we expect.
  TF_LITE_MICRO_EXPECT_NE(nullptr, input);
  TF_LITE_MICRO_EXPECT_EQ(2, input->dims->size);
  TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[0]);
  TF_LITE_MICRO_EXPECT_EQ(1960, input->dims->data[1]);
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteInt8, input->type);

  // Copy a spectrogram created from a .wav audio file of someone saying "Yes"
  // into the memory area used for the input.
  const int8_t* yes_features_data = g_yes_micro_f2e59fea_nohash_1_data;
  for (size_t i = 0; i < input->bytes; ++i) {
    input->data.int8[i] = yes_features_data[i];
  }
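
  // An equivalent way to fill the input, and often slightly faster on real
  // hardware, is a single memcpy (a sketch, assuming <cstring> is available
  // on the target):
  //
  //   memcpy(input->data.int8, yes_features_data, input->bytes);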

  // Run the model on this input and make sure it succeeds.
  TfLiteStatus invoke_status = interpreter.Invoke();
  if (invoke_status != kTfLiteOk) {
    TF_LITE_REPORT_ERROR(&micro_error_reporter, "Invoke failed\n");
  }
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, invoke_status);

  // Get the output from the model, and make sure it's the expected size and
  // type.
  TfLiteTensor* output = interpreter.output(0);
  TF_LITE_MICRO_EXPECT_EQ(2, output->dims->size);
  TF_LITE_MICRO_EXPECT_EQ(1, output->dims->data[0]);
  TF_LITE_MICRO_EXPECT_EQ(4, output->dims->data[1]);
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteInt8, output->type);

  // There are four possible classes in the output, each with a score.
  const int kSilenceIndex = 0;
  const int kUnknownIndex = 1;
  const int kYesIndex = 2;
  const int kNoIndex = 3;

  // Make sure that the expected "Yes" score is higher than the other classes.
  uint8_t silence_score = output->data.int8[kSilenceIndex] + 128;
  uint8_t unknown_score = output->data.int8[kUnknownIndex] + 128;
  uint8_t yes_score = output->data.int8[kYesIndex] + 128;
  uint8_t no_score = output->data.int8[kNoIndex] + 128;
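  // (The raw int8 scores lie in [-128, 127]; adding 128 shifts them into
  // [0, 255] so they can be held and compared as unsigned values, e.g. a raw
  // score of -128 maps to 0 and 127 maps to 255.)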
  TF_LITE_MICRO_EXPECT_GT(yes_score, silence_score);
  TF_LITE_MICRO_EXPECT_GT(yes_score, unknown_score);
  TF_LITE_MICRO_EXPECT_GT(yes_score, no_score);

  // Now test with a different input, from a recording of "No".
  const int8_t* no_features_data = g_no_micro_f9643d42_nohash_4_data;
  for (size_t i = 0; i < input->bytes; ++i) {
    input->data.int8[i] = no_features_data[i];
  }

  // Run the model on this "No" input.
  invoke_status = interpreter.Invoke();
  if (invoke_status != kTfLiteOk) {
    TF_LITE_REPORT_ERROR(&micro_error_reporter, "Invoke failed\n");
  }
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, invoke_status);

  // Get the output from the model, and make sure it's the expected size and
  // type.
  output = interpreter.output(0);
  TF_LITE_MICRO_EXPECT_EQ(2, output->dims->size);
  TF_LITE_MICRO_EXPECT_EQ(1, output->dims->data[0]);
  TF_LITE_MICRO_EXPECT_EQ(4, output->dims->data[1]);
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteInt8, output->type);

  // Make sure that the expected "No" score is higher than the other classes.
  silence_score = output->data.int8[kSilenceIndex] + 128;
  unknown_score = output->data.int8[kUnknownIndex] + 128;
  yes_score = output->data.int8[kYesIndex] + 128;
  no_score = output->data.int8[kNoIndex] + 128;
  TF_LITE_MICRO_EXPECT_GT(no_score, silence_score);
  TF_LITE_MICRO_EXPECT_GT(no_score, unknown_score);
  TF_LITE_MICRO_EXPECT_GT(no_score, yes_score);

  TF_LITE_REPORT_ERROR(&micro_error_reporter, "Ran successfully\n");
}

TF_LITE_MICRO_TESTS_END