/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This program runs the tflite model specified in --model with random inputs.
// For string type, the input is filled with a fixed string.
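//
// Example invocation (a sketch; the binary name below is illustrative and
// depends on how this target is built in your setup):
//
//   ./run_tflite_model_with_random_inputs --model=/path/to/model.tflite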

#include <cstdlib>
#include <functional>
#include <memory>
#include <string>
#include <vector>

#include <glog/logging.h>
#include "tensorflow/core/platform/init_main.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/interpreter.h"
#include "tensorflow/lite/kernels/register.h"
#include "tensorflow/lite/model.h"
#include "tensorflow/lite/model_builder.h"
#include "tensorflow/lite/string_util.h"
#include "tensorflow/lite/tools/command_line_flags.h"

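// Appends one string produced by `random_func` to `buffer` for every element
// implied by `dim_array` (the product of its dimensions).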
void FillRandomString(tflite::DynamicBuffer* buffer,
                      const TfLiteIntArray* dim_array,
                      const std::function<std::string()>& random_func) {
  int num_elements = 1;
  for (int i = 0; i < dim_array->size; i++) {
    num_elements *= dim_array->data[i];
  }
  for (int i = 0; i < num_elements; ++i) {
    auto str = random_func();
    buffer->AddString(str.data(), str.length());
  }
}

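// Loads the flatbuffer model at `filename`, builds an interpreter, fills every
// input tensor with random data (a fixed sentence for string tensors), runs
// inference once, and logs the type of each output tensor.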
void RunWithRandomInputs(const std::string& filename) {
  std::unique_ptr<tflite::FlatBufferModel> model =
      tflite::FlatBufferModel::BuildFromFile(filename.c_str());
  if (!model) {
    LOG(FATAL) << "Could not load the TFLite model from: " << filename;
  }

  // Build the interpreter.
  tflite::ops::builtin::BuiltinOpResolver resolver;
  std::unique_ptr<tflite::Interpreter> interpreter;
  if (tflite::InterpreterBuilder(*model, resolver)(&interpreter) != kTfLiteOk) {
    LOG(FATAL) << "Could not initialize interpreter for TFLite model.";
  }

  // Allocate memory for all tensors.
  if (interpreter->AllocateTensors() != kTfLiteOk) {
    LOG(FATAL) << "Could not allocate tensors.";
  }

  // Fill the inputs with random data.
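  // `sample` owns the randomly generated byte buffers; each non-string input
  // tensor is pointed at one of these buffers below, so they must stay alive
  // until Invoke() has finished.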
  std::vector<std::vector<uint8_t>> sample;
  for (int tensor_idx : interpreter->inputs()) {
    auto tensor = interpreter->tensor(tensor_idx);
    if (tensor->type == kTfLiteString) {
      tflite::DynamicBuffer buffer;
      FillRandomString(&buffer, tensor->dims, []() {
        return "we're have some friends over saturday to hang out in the "
               "yard";
      });
      buffer.WriteToTensor(tensor, /*new_shape=*/nullptr);
    } else {
      std::vector<uint8_t> data(tensor->bytes);
      for (auto it = data.begin(); it != data.end(); ++it) {
        *it = random();
      }
      sample.push_back(data);
      tensor->data.raw = reinterpret_cast<char*>(sample.rbegin()->data());
    }
  }

  // Running inference.
  if (interpreter->Invoke() != kTfLiteOk) {
    LOG(FATAL) << "Failed to run the model.";
  }

  // Get the output.
  for (int tensor_idx : interpreter->outputs()) {
    auto tensor = interpreter->tensor(tensor_idx);
    LOG(INFO) << "Output type: " << TfLiteTypeGetName(tensor->type);
  }
}

int main(int argc, char** argv) {
  // Parse flags to get the filename.
  std::string filename;
  std::vector<tflite::Flag> flag_list{tflite::Flag::CreateFlag(
      "model", &filename, "The tflite model to run sample inference.",
      tflite::Flag::kRequired)};
  if (!tflite::Flags::Parse(&argc, const_cast<const char**>(argv),
                            flag_list)) {
    LOG(FATAL) << "Failed to parse flags; --model is required.";
  }
  tensorflow::port::InitMain(argv[0], &argc, &argv);

  // Run the model with random inputs.
  RunWithRandomInputs(filename);
  return 0;
}