• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 // Parses tflite example input data.
16 // Format is ASCII
17 // TODO(aselle): Switch to protobuf, but the android team requested a simple
18 // ASCII file.
19 #include "tensorflow/lite/testing/parse_testdata.h"
20 
21 #include <cinttypes>
22 #include <cmath>
23 #include <cstdint>
24 #include <cstdio>
25 #include <fstream>
26 #include <iostream>
27 #include <streambuf>
28 
29 #include "tensorflow/lite/error_reporter.h"
30 #include "tensorflow/lite/testing/message.h"
31 #include "tensorflow/lite/testing/split.h"
32 
33 namespace tflite {
34 namespace testing {
35 namespace {
36 
// Fatal error if a parse error occurs: prints where in THIS source file the
// check lives (__FILE__/__LINE__), which data file and 1-based data line
// failed, and the two stringified operands, then returns kTfLiteError from
// the enclosing function (or lambda).
// Wrapped in do { } while (0) so the macro expands to a single statement and
// is safe inside unbraced if/else chains (avoids the dangling-else hazard).
#define PARSE_CHECK_EQ(filename, current_line, x, y)                           \
  do {                                                                         \
    if ((x) != (y)) {                                                          \
      fprintf(stderr, "Parse Error @ %s:%d\n  File %s\n  Line %d, %s != %s\n", \
              __FILE__, __LINE__, filename, (current_line) + 1, #x, #y);       \
      return kTfLiteError;                                                     \
    }                                                                          \
  } while (0)
44 
// Splits a ","-delimited line into its fields. Always yields at least one
// element (a lone empty string for an empty line); leading/trailing/adjacent
// commas produce empty fields. Intentionally simple and copy-heavy — this is
// test-only code.
// TODO(aselle): replace with absl when we use it.
std::vector<std::string> ParseLine(const std::string& line) {
  std::vector<std::string> fields;
  size_t begin = 0;
  // Consume one field per delimiter found; the tail after the last comma
  // (or the whole line when there is none) is appended after the loop.
  for (size_t comma = line.find(','); comma != std::string::npos;
       comma = line.find(',', begin)) {
    fields.push_back(line.substr(begin, comma - begin));
    begin = comma + 1;
  }
  fields.push_back(line.substr(begin));
  return fields;
}
63 
64 }  // namespace
65 
// Given a `filename`, produce a vector of Examples corresponding
// to test cases that can be applied to a tflite model.
//
// Expected file layout (one CSV row per line):
//   test_cases,<N>
//   then N examples, each of the form
//     inputs,<num_inputs>   followed by num_inputs tensor blocks
//     outputs,<num_outputs> followed by num_outputs tensor blocks
//   where a tensor block is three rows: "dtype,...", "shape,d0,d1,...",
//   and a data row whose cell count equals the product of the dims.
// Returns kTfLiteError (via PARSE_CHECK_EQ) on any structural mismatch.
// NOTE(review): cell indices (csv[0][1], etc.) are not bounds-checked and
// std::stoi/std::stof throw on malformed numbers — assumes well-formed
// generated test data.
TfLiteStatus ParseExamples(const char* filename,
                           std::vector<Example>* examples) {
  std::ifstream fp(filename);
  if (!fp.good()) {
    fprintf(stderr, "Could not read '%s'\n", filename);
    return kTfLiteError;
  }
  // Slurp the entire file into memory; test files are small.
  std::string str((std::istreambuf_iterator<char>(fp)),
                  std::istreambuf_iterator<char>());
  size_t pos = 0;

  // Split the file on '\n' into lines, and each line on ',' into cells,
  // producing a 2-D grid of strings that the parser below walks by row.
  std::vector<std::vector<std::string>> csv;
  while (true) {
    size_t end = str.find('\n', pos);

    if (end == std::string::npos) {
      csv.emplace_back(ParseLine(str.substr(pos)));
      break;
    }
    csv.emplace_back(ParseLine(str.substr(pos, end - pos)));
    pos = end + 1;
  }

  // `current_line` is the cursor into `csv`; every parse step advances it.
  int current_line = 0;
  PARSE_CHECK_EQ(filename, current_line, csv[0][0], "test_cases");
  int example_count = std::stoi(csv[0][1]);
  current_line++;

  // Parses one tensor block (dtype row, shape row, data row) into
  // *tensor_ptr. PARSE_CHECK_EQ returns kTfLiteError from this lambda on
  // mismatch; `current_line` advances past the three rows on success.
  auto parse_tensor = [&filename, &current_line,
                       &csv](FloatTensor* tensor_ptr) {
    PARSE_CHECK_EQ(filename, current_line, csv[current_line][0], "dtype");
    current_line++;
    // parse shape
    PARSE_CHECK_EQ(filename, current_line, csv[current_line][0], "shape");
    size_t elements = 1;  // product of dims; element count of the data row
    FloatTensor& tensor = *tensor_ptr;

    for (size_t i = 1; i < csv[current_line].size(); i++) {
      const auto& shape_part_to_parse = csv[current_line][i];
      if (shape_part_to_parse.empty()) {
        // Case of a 0-dimensional shape: the row is just "shape," and the
        // lone empty cell means "no dims" (scalar with 1 element).
        break;
      }
      int shape_part = std::stoi(shape_part_to_parse);
      elements *= shape_part;
      tensor.shape.push_back(shape_part);
    }
    current_line++;
    // parse data: the row has one label cell plus `elements` value cells
    PARSE_CHECK_EQ(filename, current_line, csv[current_line].size() - 1,
                   elements);
    for (size_t i = 1; i < csv[current_line].size(); i++) {
      tensor.flat_data.push_back(std::stof(csv[current_line][i]));
    }
    current_line++;

    return kTfLiteOk;
  };

  for (int example_idx = 0; example_idx < example_count; example_idx++) {
    Example example;
    PARSE_CHECK_EQ(filename, current_line, csv[current_line][0], "inputs");
    int inputs = std::stoi(csv[current_line][1]);
    current_line++;
    // parse each input tensor block in order
    for (int input_index = 0; input_index < inputs; input_index++) {
      example.inputs.push_back(FloatTensor());
      TF_LITE_ENSURE_STATUS(parse_tensor(&example.inputs.back()));
    }

    PARSE_CHECK_EQ(filename, current_line, csv[current_line][0], "outputs");
    int outputs = std::stoi(csv[current_line][1]);
    current_line++;
    for (int input_index = 0; input_index < outputs; input_index++) {
      example.outputs.push_back(FloatTensor());
      TF_LITE_ENSURE_STATUS(parse_tensor(&example.outputs.back()));
    }
    examples->emplace_back(example);
  }
  return kTfLiteOk;
}
150 
FeedExample(tflite::Interpreter * interpreter,const Example & example)151 TfLiteStatus FeedExample(tflite::Interpreter* interpreter,
152                          const Example& example) {
153   // Resize inputs to match example & allocate.
154   for (size_t i = 0; i < interpreter->inputs().size(); i++) {
155     int input_index = interpreter->inputs()[i];
156 
157     TF_LITE_ENSURE_STATUS(
158         interpreter->ResizeInputTensor(input_index, example.inputs[i].shape));
159   }
160   TF_LITE_ENSURE_STATUS(interpreter->AllocateTensors());
161   // Copy data into tensors.
162   for (size_t i = 0; i < interpreter->inputs().size(); i++) {
163     int input_index = interpreter->inputs()[i];
164     if (float* data = interpreter->typed_tensor<float>(input_index)) {
165       for (size_t idx = 0; idx < example.inputs[i].flat_data.size(); idx++) {
166         data[idx] = example.inputs[i].flat_data[idx];
167       }
168     } else if (int32_t* data =
169                    interpreter->typed_tensor<int32_t>(input_index)) {
170       for (size_t idx = 0; idx < example.inputs[i].flat_data.size(); idx++) {
171         data[idx] = example.inputs[i].flat_data[idx];
172       }
173     } else if (int64_t* data =
174                    interpreter->typed_tensor<int64_t>(input_index)) {
175       for (size_t idx = 0; idx < example.inputs[i].flat_data.size(); idx++) {
176         data[idx] = example.inputs[i].flat_data[idx];
177       }
178     } else {
179       fprintf(stderr, "input[%zu] was not float or int data\n", i);
180       return kTfLiteError;
181     }
182   }
183   return kTfLiteOk;
184 }
185 
CheckOutputs(tflite::Interpreter * interpreter,const Example & example)186 TfLiteStatus CheckOutputs(tflite::Interpreter* interpreter,
187                           const Example& example) {
188   constexpr double kRelativeThreshold = 1e-2f;
189   constexpr double kAbsoluteThreshold = 1e-4f;
190 
191   ErrorReporter* context = DefaultErrorReporter();
192   int model_outputs = interpreter->outputs().size();
193   TF_LITE_ENSURE_EQ(context, model_outputs, example.outputs.size());
194   for (size_t i = 0; i < interpreter->outputs().size(); i++) {
195     bool tensors_differ = false;
196     int output_index = interpreter->outputs()[i];
197     if (const float* data = interpreter->typed_tensor<float>(output_index)) {
198       for (size_t idx = 0; idx < example.outputs[i].flat_data.size(); idx++) {
199         float computed = data[idx];
200         float reference = example.outputs[0].flat_data[idx];
201         float diff = std::abs(computed - reference);
202         // For very small numbers, try absolute error, otherwise go with
203         // relative.
204         bool local_tensors_differ =
205             std::abs(reference) < kRelativeThreshold
206                 ? diff > kAbsoluteThreshold
207                 : diff > kRelativeThreshold * std::abs(reference);
208         if (local_tensors_differ) {
209           fprintf(stdout, "output[%zu][%zu] did not match %f vs reference %f\n",
210                   i, idx, data[idx], reference);
211           tensors_differ = local_tensors_differ;
212         }
213       }
214     } else if (const int32_t* data =
215                    interpreter->typed_tensor<int32_t>(output_index)) {
216       for (size_t idx = 0; idx < example.outputs[i].flat_data.size(); idx++) {
217         int32_t computed = data[idx];
218         int32_t reference = example.outputs[0].flat_data[idx];
219         if (std::abs(computed - reference) > 0) {
220           fprintf(stderr, "output[%zu][%zu] did not match %d vs reference %d\n",
221                   i, idx, computed, reference);
222           tensors_differ = true;
223         }
224       }
225     } else if (const int64_t* data =
226                    interpreter->typed_tensor<int64_t>(output_index)) {
227       for (size_t idx = 0; idx < example.outputs[i].flat_data.size(); idx++) {
228         int64_t computed = data[idx];
229         int64_t reference = example.outputs[0].flat_data[idx];
230         if (std::abs(computed - reference) > 0) {
231           fprintf(stderr,
232                   "output[%zu][%zu] did not match %" PRId64
233                   " vs reference %" PRId64 "\n",
234                   i, idx, computed, reference);
235           tensors_differ = true;
236         }
237       }
238     } else {
239       fprintf(stderr, "output[%zu] was not float or int data\n", i);
240       return kTfLiteError;
241     }
242     fprintf(stderr, "\n");
243     if (tensors_differ) return kTfLiteError;
244   }
245   return kTfLiteOk;
246 }
247 
248 // Process an 'invoke' message, triggering execution of the test runner, as
249 // well as verification of outputs. An 'invoke' message looks like:
250 //   invoke {
251 //     id: xyz
252 //     input: 1,2,1,1,1,2,3,4
253 //     output: 4,5,6
254 //   }
255 class Invoke : public Message {
256  public:
Invoke(TestRunner * test_runner)257   explicit Invoke(TestRunner* test_runner) : test_runner_(test_runner) {
258     expected_inputs_ = test_runner->GetInputs();
259     expected_outputs_ = test_runner->GetOutputs();
260   }
261 
SetField(const std::string & name,const std::string & value)262   void SetField(const std::string& name, const std::string& value) override {
263     if (name == "id") {
264       test_runner_->SetInvocationId(value);
265     } else if (name == "input") {
266       if (parsed_input_count_ >= expected_inputs_.size()) {
267         return test_runner_->Invalidate("Too many inputs");
268       }
269       test_runner_->SetInput(expected_inputs_[parsed_input_count_], value);
270       ++parsed_input_count_;
271     } else if (name == "output") {
272       if (parsed_output_count_ >= expected_outputs_.size()) {
273         return test_runner_->Invalidate("Too many outputs");
274       }
275       test_runner_->SetExpectation(expected_outputs_[parsed_output_count_],
276                                    value);
277       ++parsed_output_count_;
278     } else if (name == "output_shape") {
279       if (parsed_output_shape_count_ >= expected_outputs_.size()) {
280         return test_runner_->Invalidate("Too many output shapes");
281       }
282       test_runner_->SetShapeExpectation(
283           expected_outputs_[parsed_output_shape_count_], value);
284       ++parsed_output_shape_count_;
285     }
286   }
Finish()287   void Finish() override {
288     test_runner_->Invoke();
289     test_runner_->CheckResults();
290   }
291 
292  private:
293   std::vector<int> expected_inputs_;
294   std::vector<int> expected_outputs_;
295 
296   int parsed_input_count_ = 0;
297   int parsed_output_count_ = 0;
298   int parsed_output_shape_count_ = 0;
299 
300   TestRunner* test_runner_;
301 };
302 
303 // Process an 'reshape' message, triggering resizing of the input tensors via
304 // the test runner. A 'reshape' message looks like:
305 //   reshape {
306 //     input: 1,2,1,1,1,2,3,4
307 //   }
308 class Reshape : public Message {
309  public:
Reshape(TestRunner * test_runner)310   explicit Reshape(TestRunner* test_runner) : test_runner_(test_runner) {
311     expected_inputs_ = test_runner->GetInputs();
312   }
313 
SetField(const std::string & name,const std::string & value)314   void SetField(const std::string& name, const std::string& value) override {
315     if (name == "input") {
316       if (expected_inputs_.empty()) {
317         return test_runner_->Invalidate("Too many inputs to reshape");
318       }
319       test_runner_->ReshapeTensor(*expected_inputs_.begin(), value);
320       expected_inputs_.erase(expected_inputs_.begin());
321     }
322   }
323 
324  private:
325   std::vector<int> expected_inputs_;
326   TestRunner* test_runner_;
327 };
328 
329 // This is the top-level message in a test file.
330 class TestData : public Message {
331  public:
TestData(TestRunner * test_runner)332   explicit TestData(TestRunner* test_runner)
333       : test_runner_(test_runner), num_invocations_(0), max_invocations_(-1) {}
SetMaxInvocations(int max)334   void SetMaxInvocations(int max) { max_invocations_ = max; }
SetField(const std::string & name,const std::string & value)335   void SetField(const std::string& name, const std::string& value) override {
336     if (name == "load_model") {
337       test_runner_->LoadModel(value);
338     } else if (name == "init_state") {
339       test_runner_->AllocateTensors();
340       for (int id : Split<int>(value, ",")) {
341         test_runner_->ResetTensor(id);
342       }
343     }
344   }
AddChild(const std::string & s)345   Message* AddChild(const std::string& s) override {
346     if (s == "invoke") {
347       test_runner_->AllocateTensors();
348       if (max_invocations_ == -1 || num_invocations_ < max_invocations_) {
349         ++num_invocations_;
350         return Store(new Invoke(test_runner_));
351       } else {
352         return nullptr;
353       }
354     } else if (s == "reshape") {
355       return Store(new Reshape(test_runner_));
356     }
357     return nullptr;
358   }
359 
360  private:
361   TestRunner* test_runner_;
362   int num_invocations_;
363   int max_invocations_;
364 };
365 
ParseAndRunTests(std::istream * input,TestRunner * test_runner,int max_invocations)366 bool ParseAndRunTests(std::istream* input, TestRunner* test_runner,
367                       int max_invocations) {
368   TestData test_data(test_runner);
369   test_data.SetMaxInvocations(max_invocations);
370   Message::Read(input, &test_data);
371   return test_runner->IsValid() && test_runner->GetOverallSuccess();
372 }
373 
374 }  // namespace testing
375 }  // namespace tflite
376