• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/lite/micro/examples/micro_speech/micro_features/model.h"
17 #include "tensorflow/lite/micro/examples/micro_speech/micro_features/no_micro_features_data.h"
18 #include "tensorflow/lite/micro/examples/micro_speech/micro_features/yes_micro_features_data.h"
19 #include "tensorflow/lite/micro/micro_error_reporter.h"
20 #include "tensorflow/lite/micro/micro_interpreter.h"
21 #include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
22 #include "tensorflow/lite/micro/testing/micro_test.h"
23 #include "tensorflow/lite/schema/schema_generated.h"
24 
25 TF_LITE_MICRO_TESTS_BEGIN
26 
TF_LITE_MICRO_TEST(TestInvoke)27 TF_LITE_MICRO_TEST(TestInvoke) {
28   // Set up logging.
29   tflite::MicroErrorReporter micro_error_reporter;
30 
31   // Map the model into a usable data structure. This doesn't involve any
32   // copying or parsing, it's a very lightweight operation.
33   const tflite::Model* model = ::tflite::GetModel(g_model);
34   if (model->version() != TFLITE_SCHEMA_VERSION) {
35     TF_LITE_REPORT_ERROR(&micro_error_reporter,
36                          "Model provided is schema version %d not equal "
37                          "to supported version %d.\n",
38                          model->version(), TFLITE_SCHEMA_VERSION);
39   }
40 
41   // Pull in only the operation implementations we need.
42   // This relies on a complete list of all the ops needed by this graph.
43   // An easier approach is to just use the AllOpsResolver, but this will
44   // incur some penalty in code space for op implementations that are not
45   // needed by this graph.
46   //
47   // tflite::AllOpsResolver resolver;
48   tflite::MicroMutableOpResolver<4> micro_op_resolver;
49   micro_op_resolver.AddDepthwiseConv2D();
50   micro_op_resolver.AddFullyConnected();
51   micro_op_resolver.AddReshape();
52   micro_op_resolver.AddSoftmax();
53 
54   // Create an area of memory to use for input, output, and intermediate arrays.
55   const int tensor_arena_size = 10 * 1024;
56   uint8_t tensor_arena[tensor_arena_size];
57 
58   // Build an interpreter to run the model with.
59   tflite::MicroInterpreter interpreter(model, micro_op_resolver, tensor_arena,
60                                        tensor_arena_size,
61                                        &micro_error_reporter);
62   interpreter.AllocateTensors();
63 
64   // Get information about the memory area to use for the model's input.
65   TfLiteTensor* input = interpreter.input(0);
66 
67   // Make sure the input has the properties we expect.
68   TF_LITE_MICRO_EXPECT_NE(nullptr, input);
69   TF_LITE_MICRO_EXPECT_EQ(2, input->dims->size);
70   TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[0]);
71   TF_LITE_MICRO_EXPECT_EQ(1960, input->dims->data[1]);
72   TF_LITE_MICRO_EXPECT_EQ(kTfLiteInt8, input->type);
73 
74   // Copy a spectrogram created from a .wav audio file of someone saying "Yes",
75   // into the memory area used for the input.
76   const int8_t* yes_features_data = g_yes_micro_f2e59fea_nohash_1_data;
77   for (size_t i = 0; i < input->bytes; ++i) {
78     input->data.int8[i] = yes_features_data[i];
79   }
80 
81   // Run the model on this input and make sure it succeeds.
82   TfLiteStatus invoke_status = interpreter.Invoke();
83   if (invoke_status != kTfLiteOk) {
84     TF_LITE_REPORT_ERROR(&micro_error_reporter, "Invoke failed\n");
85   }
86   TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, invoke_status);
87 
88   // Get the output from the model, and make sure it's the expected size and
89   // type.
90   TfLiteTensor* output = interpreter.output(0);
91   TF_LITE_MICRO_EXPECT_EQ(2, output->dims->size);
92   TF_LITE_MICRO_EXPECT_EQ(1, output->dims->data[0]);
93   TF_LITE_MICRO_EXPECT_EQ(4, output->dims->data[1]);
94   TF_LITE_MICRO_EXPECT_EQ(kTfLiteInt8, output->type);
95 
96   // There are four possible classes in the output, each with a score.
97   const int kSilenceIndex = 0;
98   const int kUnknownIndex = 1;
99   const int kYesIndex = 2;
100   const int kNoIndex = 3;
101 
102   // Make sure that the expected "Yes" score is higher than the other classes.
103   uint8_t silence_score = output->data.uint8[kSilenceIndex] + 128;
104   uint8_t unknown_score = output->data.uint8[kUnknownIndex] + 128;
105   uint8_t yes_score = output->data.int8[kYesIndex] + 128;
106   uint8_t no_score = output->data.int8[kNoIndex] + 128;
107   TF_LITE_MICRO_EXPECT_GT(yes_score, silence_score);
108   TF_LITE_MICRO_EXPECT_GT(yes_score, unknown_score);
109   TF_LITE_MICRO_EXPECT_GT(yes_score, no_score);
110 
111   // Now test with a different input, from a recording of "No".
112   const int8_t* no_features_data = g_no_micro_f9643d42_nohash_4_data;
113   for (size_t i = 0; i < input->bytes; ++i) {
114     input->data.int8[i] = no_features_data[i];
115   }
116 
117   // Run the model on this "No" input.
118   invoke_status = interpreter.Invoke();
119   if (invoke_status != kTfLiteOk) {
120     TF_LITE_REPORT_ERROR(&micro_error_reporter, "Invoke failed\n");
121   }
122   TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, invoke_status);
123 
124   // Get the output from the model, and make sure it's the expected size and
125   // type.
126   output = interpreter.output(0);
127   TF_LITE_MICRO_EXPECT_EQ(2, output->dims->size);
128   TF_LITE_MICRO_EXPECT_EQ(1, output->dims->data[0]);
129   TF_LITE_MICRO_EXPECT_EQ(4, output->dims->data[1]);
130   TF_LITE_MICRO_EXPECT_EQ(kTfLiteInt8, output->type);
131 
132   // Make sure that the expected "No" score is higher than the other classes.
133   silence_score = output->data.int8[kSilenceIndex] + 128;
134   unknown_score = output->data.int8[kUnknownIndex] + 128;
135   yes_score = output->data.int8[kYesIndex] + 128;
136   no_score = output->data.int8[kNoIndex] + 128;
137   TF_LITE_MICRO_EXPECT_GT(no_score, silence_score);
138   TF_LITE_MICRO_EXPECT_GT(no_score, unknown_score);
139   TF_LITE_MICRO_EXPECT_GT(no_score, yes_score);
140 
141   TF_LITE_REPORT_ERROR(&micro_error_reporter, "Ran successfully\n");
142 }
143 
144 TF_LITE_MICRO_TESTS_END
145