/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
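
// Profiles a .tflite model on the TFLite GPU (OpenCL) backend: prints a
// per-operation timing report plus intermediate-tensor memory use, then
// measures sustained inference latency. Invocation (binary name illustrative):
//
//   performance_profiling <model.tflite>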

#include <algorithm>
#include <chrono>  // NOLINT(build/c++11)
#include <cstdint>  // uint64_t
#include <cstdlib>  // EXIT_SUCCESS
#include <iostream>
#include <string>
#include "absl/time/time.h"
#include "tensorflow/lite/delegates/gpu/cl/environment.h"
#include "tensorflow/lite/delegates/gpu/cl/inference_context.h"
#include "tensorflow/lite/delegates/gpu/cl/opencl_wrapper.h"  // LoadOpenCL()
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/model_builder.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/kernels/register.h"

namespace tflite {
namespace gpu {
namespace cl {

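// Loads the model, builds the GPU inference context, profiles each operation,
// and then benchmarks end-to-end latency.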
absl::Status RunModelSample(const std::string& model_name) {
  auto flatbuffer = tflite::FlatBufferModel::BuildFromFile(model_name.c_str());
  if (!flatbuffer) {  // BuildFromFile returns nullptr on failure.
    return absl::InvalidArgumentError("Failed to load model: " + model_name);
  }
  GraphFloat32 graph_cl;
  ops::builtin::BuiltinOpResolver op_resolver;
  RETURN_IF_ERROR(BuildFromFlatBuffer(*flatbuffer, op_resolver, &graph_cl,
                                      /*allow_quant_ops=*/true));

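  // Create the OpenCL environment: device, context, and command queues.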
  Environment env;
  RETURN_IF_ERROR(CreateEnvironment(&env));

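  // Prefer FP16 arithmetic when the device supports it, and pick the tensor
  // storage layout that is fastest for this GPU.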
  InferenceContext::CreateInferenceInfo create_info;
  create_info.precision = env.IsSupported(CalculationsPrecision::F16)
                              ? CalculationsPrecision::F16
                              : CalculationsPrecision::F32;
  create_info.storage_type = GetFastestStorageType(env.device().GetInfo());
  create_info.hints.Add(ModelHints::kAllowSpecialKernels);
  std::cout << "Precision: " << ToString(create_info.precision) << std::endl;
  std::cout << "Storage type: " << ToString(create_info.storage_type)
            << std::endl;
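  // Compile the graph into OpenCL kernels; graph-level transforms
  // (e.g. operation fusion) are applied first.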
  InferenceContext context;
  RETURN_IF_ERROR(
      context.InitFromGraphWithTransforms(create_info, &graph_cl, &env));

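  // One instrumented run on the profiling queue yields per-operation timings.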
  auto* queue = env.profiling_queue();
  ProfilingInfo profiling_info;
  RETURN_IF_ERROR(context.Profile(queue, &profiling_info));
  std::cout << profiling_info.GetDetailedReport() << std::endl;
  uint64_t mem_bytes = context.GetSizeOfMemoryAllocatedForIntermediateTensors();
  std::cout << "Memory for intermediate tensors - "
            << mem_bytes / 1024.0 / 1024.0 << " MB" << std::endl;

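  // Size each timed batch to roughly one second of work, estimated from the
  // total time measured during profiling (at least one run per batch).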
  const int num_runs_per_sec = std::max(
      1, static_cast<int>(1000.0f / absl::ToDoubleMilliseconds(
                                        profiling_info.GetTotalTime())));

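  // Benchmark: submit a batch of inferences, wait for the queue to drain, and
  // report the average wall-clock latency per inference.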
  const int kNumRuns = 10;
  for (int i = 0; i < kNumRuns; ++i) {
    const auto start = std::chrono::high_resolution_clock::now();
    for (int k = 0; k < num_runs_per_sec; ++k) {
      RETURN_IF_ERROR(context.AddToQueue(env.queue()));
    }
    RETURN_IF_ERROR(env.queue()->WaitForCompletion());
    const auto end = std::chrono::high_resolution_clock::now();
    // duration<double, std::milli> is portable; the raw tick period of
    // high_resolution_clock is implementation-defined.
    const double total_time_ms =
        std::chrono::duration<double, std::milli>(end - start).count();
    const double average_inference_time = total_time_ms / num_runs_per_sec;
    std::cout << "Average inference time - " << average_inference_time << " ms"
              << std::endl;
  }

  return absl::OkStatus();
}

}  // namespace cl
}  // namespace gpu
}  // namespace tflite

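// Entry point. argv[1] must be the path to a .tflite model file.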
int main(int argc, char** argv) {
  if (argc <= 1) {
    std::cerr << "Expected model path as command-line argument." << std::endl;
    return -1;
  }

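  // Dynamically load the OpenCL runtime; this fails on devices without one.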
  auto load_status = tflite::gpu::cl::LoadOpenCL();
  if (!load_status.ok()) {
    std::cerr << load_status.message() << std::endl;
    return -1;
  }

  auto run_status = tflite::gpu::cl::RunModelSample(argv[1]);
  if (!run_status.ok()) {
    std::cerr << run_status.message() << std::endl;
    return -1;
  }

  return EXIT_SUCCESS;
}