1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include <cstdint>
#include <functional>
#include <limits>
#include <memory>
#include <vector>

#include "tensorflow/c/c_api_macros.h"
#include "tensorflow/c/c_api_macros_internal.h"
#include "tensorflow/c/experimental/pluggable_profiler/pluggable_profiler_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/common_runtime/device/device_utils.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/profiler/lib/profiler_factory.h"
#include "tensorflow/core/profiler/lib/profiler_interface.h"
#include "tensorflow/core/profiler/profiler_options.pb.h"
29
30 namespace tensorflow {
31 namespace profiler {
32
33 namespace {
34
ValidateTPProfilerRegistrationParams(const TF_ProfilerRegistrationParams & params)35 Status ValidateTPProfilerRegistrationParams(
36 const TF_ProfilerRegistrationParams& params) {
37 TF_VALIDATE_STRUCT_SIZE(TF_ProfilerRegistrationParams, params,
38 TF_PROFILER_REGISTRATION_PARAMS_STRUCT_SIZE);
39 TF_VALIDATE_NOT_NULL(TF_ProfilerRegistrationParams, params, destroy_profiler);
40 TF_VALIDATE_NOT_NULL(TF_ProfilerRegistrationParams, params,
41 destroy_profiler_fns);
42 return OkStatus();
43 }
44
ValidateTPProfiler(const TP_Profiler & profiler)45 Status ValidateTPProfiler(const TP_Profiler& profiler) {
46 TF_VALIDATE_STRUCT_SIZE(TP_Profiler, profiler, TP_PROFILER_STRUCT_SIZE);
47 TF_VALIDATE_NOT_NULL(TP_Profiler, profiler, device_type);
48 TF_RETURN_IF_ERROR(
49 tensorflow::device_utils::ValidateDeviceType(profiler.device_type));
50 return OkStatus();
51 }
52
ValidateTPProfilerFns(const TP_ProfilerFns & profiler_fns)53 Status ValidateTPProfilerFns(const TP_ProfilerFns& profiler_fns) {
54 TF_VALIDATE_STRUCT_SIZE(TP_ProfilerFns, profiler_fns,
55 TF_PROFILER_FNS_STRUCT_SIZE);
56 TF_VALIDATE_NOT_NULL(TP_ProfilerFns, profiler_fns, start);
57 TF_VALIDATE_NOT_NULL(TP_ProfilerFns, profiler_fns, stop);
58 TF_VALIDATE_NOT_NULL(TP_ProfilerFns, profiler_fns, collect_data_xspace);
59 return OkStatus();
60 }
61
62 class PluggableProfiler : public tensorflow::profiler::ProfilerInterface {
63 public:
64 // The caller must have validated profiler_fns and profiler.
65 static std::unique_ptr<tensorflow::profiler::ProfilerInterface>
CreatePluggableProfiler(const ProfileOptions & options,TP_Profiler profiler,TP_ProfilerFns profiler_fns)66 CreatePluggableProfiler(const ProfileOptions& options, TP_Profiler profiler,
67 TP_ProfilerFns profiler_fns) {
68 if (options.device_tracer_level() == 0) {
69 return nullptr;
70 }
71 if (options.device_type() != ProfileOptions::PLUGGABLE_DEVICE &&
72 options.device_type() != ProfileOptions::UNSPECIFIED) {
73 return nullptr;
74 }
75 return absl::WrapUnique(new PluggableProfiler(profiler_fns, profiler));
76 }
77
Start()78 Status Start() override {
79 tensorflow::TF_StatusPtr status(TF_NewStatus());
80 profiler_fns_.start(&profiler_, status.get());
81 return tensorflow::StatusFromTF_Status(status.get());
82 }
83
Stop()84 Status Stop() override {
85 tensorflow::TF_StatusPtr status(TF_NewStatus());
86 profiler_fns_.stop(&profiler_, status.get());
87 return tensorflow::StatusFromTF_Status(status.get());
88 }
89
CollectData(XSpace * space)90 Status CollectData(XSpace* space) override {
91 tensorflow::TF_StatusPtr status(TF_NewStatus());
92 // Get size of buffer required for Plugin to serialize XSpace into it.
93 size_t size_in_bytes;
94 profiler_fns_.collect_data_xspace(&profiler_, /*buffer=*/nullptr,
95 &size_in_bytes, status.get());
96
97 if (size_in_bytes <= 0)
98 return tensorflow::StatusFromTF_Status(status.get());
99
100 // Prepare an appropriately sized buffer.
101 std::vector<uint8_t> buffer(size_in_bytes);
102 profiler_fns_.collect_data_xspace(&profiler_, buffer.data(), &size_in_bytes,
103 status.get());
104 // Deserialize XSpace from the buffer and return it.
105 XSpace plugin_space;
106 plugin_space.ParseFromArray(buffer.data(), buffer.size());
107 for (XPlane& plugin_plane : *plugin_space.mutable_planes()) {
108 XPlane* plane = space->add_planes();
109 plane->Swap(&plugin_plane);
110 }
111 return tensorflow::StatusFromTF_Status(status.get());
112 }
113
114 private:
PluggableProfiler(TP_ProfilerFns profiler_fns,TP_Profiler profiler)115 PluggableProfiler(TP_ProfilerFns profiler_fns, TP_Profiler profiler)
116 : profiler_fns_(profiler_fns), profiler_(profiler) {}
117 TP_ProfilerFns profiler_fns_;
118 TP_Profiler profiler_;
119 };
120
121 class PluggableProfilerFactory {
122 public:
PluggableProfilerFactory(TP_Profiler profiler,void (* destroy_profiler)(TP_Profiler *),TP_ProfilerFns profiler_fns,void (* destroy_profiler_fns)(TP_ProfilerFns *))123 PluggableProfilerFactory(TP_Profiler profiler,
124 void (*destroy_profiler)(TP_Profiler*),
125 TP_ProfilerFns profiler_fns,
126 void (*destroy_profiler_fns)(TP_ProfilerFns*))
127 : profiler_(std::move(profiler)),
128 destroy_profiler_(destroy_profiler),
129 profiler_fns_(std::move(profiler_fns)),
130 destroy_profiler_fns_(destroy_profiler_fns) {}
131
~PluggableProfilerFactory()132 ~PluggableProfilerFactory() {
133 destroy_profiler_(&profiler_);
134 destroy_profiler_fns_(&profiler_fns_);
135 }
136
137 std::unique_ptr<tensorflow::profiler::ProfilerInterface>
CreatePluggableProfiler(const ProfileOptions & options)138 CreatePluggableProfiler(const ProfileOptions& options) {
139 return PluggableProfiler::CreatePluggableProfiler(options, profiler_,
140 profiler_fns_);
141 }
142
143 private:
144 TP_Profiler profiler_{TP_PROFILER_STRUCT_SIZE};
145 void (*destroy_profiler_)(TP_Profiler*);
146 TP_ProfilerFns profiler_fns_{TP_PROFILER_FNS_STRUCT_SIZE};
147 void (*destroy_profiler_fns_)(TP_ProfilerFns*);
148 };
149
150 } // namespace
151
InitPluginProfiler(TFInitProfilerFn init_fn)152 Status InitPluginProfiler(TFInitProfilerFn init_fn) {
153 TF_ProfilerRegistrationParams params{
154 TF_PROFILER_REGISTRATION_PARAMS_STRUCT_SIZE};
155 TP_Profiler profiler{TP_PROFILER_STRUCT_SIZE};
156 TP_ProfilerFns profiler_fns{TP_PROFILER_FNS_STRUCT_SIZE};
157 params.major_version = TP_MAJOR;
158 params.minor_version = TP_MINOR;
159 params.patch_version = TP_PATCH;
160 params.profiler = &profiler;
161 params.profiler_fns = &profiler_fns;
162 tensorflow::TF_StatusPtr status(TF_NewStatus());
163 init_fn(¶ms, status.get());
164 TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(status.get()));
165 TF_RETURN_IF_ERROR(ValidateTPProfilerRegistrationParams(params));
166 TF_RETURN_IF_ERROR(ValidateTPProfiler(profiler));
167 TF_RETURN_IF_ERROR(ValidateTPProfilerFns(profiler_fns));
168
169 PluggableProfilerFactory factory(std::move(profiler), params.destroy_profiler,
170 std::move(profiler_fns),
171 params.destroy_profiler_fns);
172 std::function<std::unique_ptr<ProfilerInterface>(const ProfileOptions&)>
173 create_func = [factory = std::move(factory)](
174 const ProfileOptions& options) mutable {
175 return factory.CreatePluggableProfiler(options);
176 };
177
178 tensorflow::profiler::RegisterProfilerFactory(std::move(create_func));
179 return OkStatus();
180 }
181
182 } // namespace profiler
183 } // namespace tensorflow
184