/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstdint>

#include "tensorflow/lite/micro/examples/micro_speech/micro_features/no_micro_features_data.h"
#include "tensorflow/lite/micro/examples/micro_speech/micro_features/tiny_conv_micro_features_model_data.h"
#include "tensorflow/lite/micro/examples/micro_speech/micro_features/yes_micro_features_data.h"
#include "tensorflow/lite/micro/kernels/micro_ops.h"
#include "tensorflow/lite/micro/micro_error_reporter.h"
#include "tensorflow/lite/micro/micro_interpreter.h"
#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
#include "tensorflow/lite/micro/testing/micro_test.h"
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/version.h"

27 TF_LITE_MICRO_TESTS_BEGIN
28 
TF_LITE_MICRO_TEST(TestInvoke)29 TF_LITE_MICRO_TEST(TestInvoke) {
30   // Set up logging.
31   tflite::MicroErrorReporter micro_error_reporter;
32   tflite::ErrorReporter* error_reporter = &micro_error_reporter;
33 
34   // Map the model into a usable data structure. This doesn't involve any
35   // copying or parsing, it's a very lightweight operation.
36   const tflite::Model* model =
37       ::tflite::GetModel(g_tiny_conv_micro_features_model_data);
38   if (model->version() != TFLITE_SCHEMA_VERSION) {
39     error_reporter->Report(
40         "Model provided is schema version %d not equal "
41         "to supported version %d.\n",
42         model->version(), TFLITE_SCHEMA_VERSION);
43   }
44 
45   // Pull in only the operation implementations we need.
46   // This relies on a complete list of all the ops needed by this graph.
47   // An easier approach is to just use the AllOpsResolver, but this will
48   // incur some penalty in code space for op implementations that are not
49   // needed by this graph.
50   //
51   // tflite::ops::micro::AllOpsResolver resolver;
52   tflite::MicroOpResolver<3> micro_op_resolver;
53   micro_op_resolver.AddBuiltin(
54       tflite::BuiltinOperator_DEPTHWISE_CONV_2D,
55       tflite::ops::micro::Register_DEPTHWISE_CONV_2D());
56   micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_FULLY_CONNECTED,
57                                tflite::ops::micro::Register_FULLY_CONNECTED());
58   micro_op_resolver.AddBuiltin(tflite::BuiltinOperator_SOFTMAX,
59                                tflite::ops::micro::Register_SOFTMAX());
60 
61   // Create an area of memory to use for input, output, and intermediate arrays.
62   const int tensor_arena_size = 10 * 1024;
63   uint8_t tensor_arena[tensor_arena_size];
64 
65   // Build an interpreter to run the model with.
66   tflite::MicroInterpreter interpreter(model, micro_op_resolver, tensor_arena,
67                                        tensor_arena_size, error_reporter);
68   interpreter.AllocateTensors();
69 
70   // Get information about the memory area to use for the model's input.
71   TfLiteTensor* input = interpreter.input(0);
72 
73   // Make sure the input has the properties we expect.
74   TF_LITE_MICRO_EXPECT_NE(nullptr, input);
75   TF_LITE_MICRO_EXPECT_EQ(4, input->dims->size);
76   TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[0]);
77   TF_LITE_MICRO_EXPECT_EQ(49, input->dims->data[1]);
78   TF_LITE_MICRO_EXPECT_EQ(40, input->dims->data[2]);
79   TF_LITE_MICRO_EXPECT_EQ(1, input->dims->data[3]);
80   TF_LITE_MICRO_EXPECT_EQ(kTfLiteUInt8, input->type);
81 
82   // Copy a spectrogram created from a .wav audio file of someone saying "Yes",
83   // into the memory area used for the input.
84   const uint8_t* yes_features_data = g_yes_micro_f2e59fea_nohash_1_data;
85   for (int i = 0; i < input->bytes; ++i) {
86     input->data.uint8[i] = yes_features_data[i];
87   }
88 
89   // Run the model on this input and make sure it succeeds.
90   TfLiteStatus invoke_status = interpreter.Invoke();
91   if (invoke_status != kTfLiteOk) {
92     error_reporter->Report("Invoke failed\n");
93   }
94   TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, invoke_status);
95 
96   // Get the output from the model, and make sure it's the expected size and
97   // type.
98   TfLiteTensor* output = interpreter.output(0);
99   TF_LITE_MICRO_EXPECT_EQ(2, output->dims->size);
100   TF_LITE_MICRO_EXPECT_EQ(1, output->dims->data[0]);
101   TF_LITE_MICRO_EXPECT_EQ(4, output->dims->data[1]);
102   TF_LITE_MICRO_EXPECT_EQ(kTfLiteUInt8, output->type);
103 
104   // There are four possible classes in the output, each with a score.
105   const int kSilenceIndex = 0;
106   const int kUnknownIndex = 1;
107   const int kYesIndex = 2;
108   const int kNoIndex = 3;
109 
110   // Make sure that the expected "Yes" score is higher than the other classes.
111   uint8_t silence_score = output->data.uint8[kSilenceIndex];
112   uint8_t unknown_score = output->data.uint8[kUnknownIndex];
113   uint8_t yes_score = output->data.uint8[kYesIndex];
114   uint8_t no_score = output->data.uint8[kNoIndex];
115   TF_LITE_MICRO_EXPECT_GT(yes_score, silence_score);
116   TF_LITE_MICRO_EXPECT_GT(yes_score, unknown_score);
117   TF_LITE_MICRO_EXPECT_GT(yes_score, no_score);
118 
119   // Now test with a different input, from a recording of "No".
120   const uint8_t* no_features_data = g_no_micro_f9643d42_nohash_4_data;
121   for (int i = 0; i < input->bytes; ++i) {
122     input->data.uint8[i] = no_features_data[i];
123   }
124 
125   // Run the model on this "No" input.
126   invoke_status = interpreter.Invoke();
127   if (invoke_status != kTfLiteOk) {
128     error_reporter->Report("Invoke failed\n");
129   }
130   TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, invoke_status);
131 
132   // Get the output from the model, and make sure it's the expected size and
133   // type.
134   output = interpreter.output(0);
135   TF_LITE_MICRO_EXPECT_EQ(2, output->dims->size);
136   TF_LITE_MICRO_EXPECT_EQ(1, output->dims->data[0]);
137   TF_LITE_MICRO_EXPECT_EQ(4, output->dims->data[1]);
138   TF_LITE_MICRO_EXPECT_EQ(kTfLiteUInt8, output->type);
139 
140   // Make sure that the expected "No" score is higher than the other classes.
141   silence_score = output->data.uint8[kSilenceIndex];
142   unknown_score = output->data.uint8[kUnknownIndex];
143   yes_score = output->data.uint8[kYesIndex];
144   no_score = output->data.uint8[kNoIndex];
145   TF_LITE_MICRO_EXPECT_GT(no_score, silence_score);
146   TF_LITE_MICRO_EXPECT_GT(no_score, unknown_score);
147   TF_LITE_MICRO_EXPECT_GT(no_score, yes_score);
148 
149   error_reporter->Report("Ran successfully\n");
150 }
151 
152 TF_LITE_MICRO_TESTS_END
153