/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <chrono>  // NOLINT(build/c++11)
#include <cstdlib>
#include <iostream>
#include <string>
#include "absl/status/status.h"
#include "absl/time/time.h"
#include "tensorflow/lite/delegates/gpu/cl/environment.h"
#include "tensorflow/lite/delegates/gpu/cl/inference_context.h"
#include "tensorflow/lite/delegates/gpu/common/model.h"
#include "tensorflow/lite/delegates/gpu/common/model_builder.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
#include "tensorflow/lite/kernels/register.h"

namespace tflite {
namespace gpu {
namespace cl {

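// Loads a TFLite model, builds an OpenCL inference context for it, and prints
// per-operation profiling info, intermediate tensor memory usage, and average
// inference latency.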
absl::Status RunModelSample(const std::string& model_name) {
  auto flatbuffer = tflite::FlatBufferModel::BuildFromFile(model_name.c_str());
  // BuildFromFile returns nullptr on failure.
  if (!flatbuffer) {
    return absl::InvalidArgumentError("Failed to load model: " + model_name);
  }
  GraphFloat32 graph_cl;
  ops::builtin::BuiltinOpResolver op_resolver;
  RETURN_IF_ERROR(BuildFromFlatBuffer(*flatbuffer, op_resolver, &graph_cl,
                                      /*allow_quant_ops=*/true));

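  // Set up the OpenCL environment (device, context, command queues).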
  Environment env;
  RETURN_IF_ERROR(CreateEnvironment(&env));

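  // Prefer FP16 compute when the device supports it, and pick the fastest
  // tensor storage layout (e.g. buffer vs. texture) for this GPU.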
  InferenceContext::CreateInferenceInfo create_info;
  create_info.precision = env.IsSupported(CalculationsPrecision::F16)
                              ? CalculationsPrecision::F16
                              : CalculationsPrecision::F32;
  create_info.storage_type = GetFastestStorageType(env.device().GetInfo());
  create_info.hints.Add(ModelHints::kAllowSpecialKernels);
  std::cout << "Precision: " << ToString(create_info.precision) << std::endl;
  std::cout << "Storage type: " << ToString(create_info.storage_type)
            << std::endl;
  InferenceContext context;
  RETURN_IF_ERROR(
      context.InitFromGraphWithTransforms(create_info, &graph_cl, &env));

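  // Run once on the profiling queue to collect per-operation timings and
  // report memory used by intermediate tensors.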
  auto* queue = env.profiling_queue();
  ProfilingInfo profiling_info;
  RETURN_IF_ERROR(context.Profile(queue, &profiling_info));
  std::cout << profiling_info.GetDetailedReport() << std::endl;
  uint64_t mem_bytes = context.GetSizeOfMemoryAllocatedForIntermediateTensors();
  std::cout << "Memory for intermediate tensors - "
            << mem_bytes / 1024.0 / 1024.0 << " MB" << std::endl;

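  // Size each timed batch so that it covers roughly one second of GPU work,
  // based on the total time measured during profiling.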
  const int num_runs_per_sec = std::max(
      1, static_cast<int>(1000.0f / absl::ToDoubleMilliseconds(
                                        profiling_info.GetTotalTime())));

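  // Benchmark: for each iteration, enqueue a batch of inferences, wait for the
  // queue to drain, and report the average latency per inference.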
  const int kNumRuns = 10;
  for (int i = 0; i < kNumRuns; ++i) {
    const auto start = std::chrono::high_resolution_clock::now();
    for (int k = 0; k < num_runs_per_sec; ++k) {
      RETURN_IF_ERROR(context.AddToQueue(env.queue()));
    }
    RETURN_IF_ERROR(env.queue()->WaitForCompletion());
    const auto end = std::chrono::high_resolution_clock::now();
    const double total_time_ms =
        std::chrono::duration<double, std::milli>(end - start).count();
    const double average_inference_time = total_time_ms / num_runs_per_sec;
    std::cout << "Average inference time - " << average_inference_time << "ms"
              << std::endl;
  }

  return absl::OkStatus();
}

}  // namespace cl
}  // namespace gpu
}  // namespace tflite

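// Entry point: expects the path to a .tflite model as the first command-line
// argument.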
int main(int argc, char** argv) {
  if (argc <= 1) {
    std::cerr << "Expected model path as command-line argument." << std::endl;
    return -1;
  }

  auto load_status = tflite::gpu::cl::LoadOpenCL();
  if (!load_status.ok()) {
    std::cerr << load_status.message() << std::endl;
    return -1;
  }

  auto run_status = tflite::gpu::cl::RunModelSample(argv[1]);
  if (!run_status.ok()) {
    std::cerr << run_status.message() << std::endl;
    return -1;
  }

  return EXIT_SUCCESS;
}