• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include <functional>
#include <limits>
#include <memory>
#include <vector>

#include "tensorflow/c/c_api_macros.h"
#include "tensorflow/c/c_api_macros_internal.h"
#include "tensorflow/c/experimental/pluggable_profiler/pluggable_profiler_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/common_runtime/device/device_utils.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/profiler/lib/profiler_factory.h"
#include "tensorflow/core/profiler/lib/profiler_interface.h"
#include "tensorflow/core/profiler/profiler_options.pb.h"
29 
30 namespace tensorflow {
31 namespace profiler {
32 
33 namespace {
34 
ValidateTPProfilerRegistrationParams(const TF_ProfilerRegistrationParams & params)35 Status ValidateTPProfilerRegistrationParams(
36     const TF_ProfilerRegistrationParams& params) {
37   TF_VALIDATE_STRUCT_SIZE(TF_ProfilerRegistrationParams, params,
38                           TF_PROFILER_REGISTRATION_PARAMS_STRUCT_SIZE);
39   TF_VALIDATE_NOT_NULL(TF_ProfilerRegistrationParams, params, destroy_profiler);
40   TF_VALIDATE_NOT_NULL(TF_ProfilerRegistrationParams, params,
41                        destroy_profiler_fns);
42   return OkStatus();
43 }
44 
ValidateTPProfiler(const TP_Profiler & profiler)45 Status ValidateTPProfiler(const TP_Profiler& profiler) {
46   TF_VALIDATE_STRUCT_SIZE(TP_Profiler, profiler, TP_PROFILER_STRUCT_SIZE);
47   TF_VALIDATE_NOT_NULL(TP_Profiler, profiler, device_type);
48   TF_RETURN_IF_ERROR(
49       tensorflow::device_utils::ValidateDeviceType(profiler.device_type));
50   return OkStatus();
51 }
52 
ValidateTPProfilerFns(const TP_ProfilerFns & profiler_fns)53 Status ValidateTPProfilerFns(const TP_ProfilerFns& profiler_fns) {
54   TF_VALIDATE_STRUCT_SIZE(TP_ProfilerFns, profiler_fns,
55                           TF_PROFILER_FNS_STRUCT_SIZE);
56   TF_VALIDATE_NOT_NULL(TP_ProfilerFns, profiler_fns, start);
57   TF_VALIDATE_NOT_NULL(TP_ProfilerFns, profiler_fns, stop);
58   TF_VALIDATE_NOT_NULL(TP_ProfilerFns, profiler_fns, collect_data_xspace);
59   return OkStatus();
60 }
61 
62 class PluggableProfiler : public tensorflow::profiler::ProfilerInterface {
63  public:
64   // The caller must have validated profiler_fns and profiler.
65   static std::unique_ptr<tensorflow::profiler::ProfilerInterface>
CreatePluggableProfiler(const ProfileOptions & options,TP_Profiler profiler,TP_ProfilerFns profiler_fns)66   CreatePluggableProfiler(const ProfileOptions& options, TP_Profiler profiler,
67                           TP_ProfilerFns profiler_fns) {
68     if (options.device_tracer_level() == 0) {
69       return nullptr;
70     }
71     if (options.device_type() != ProfileOptions::PLUGGABLE_DEVICE &&
72         options.device_type() != ProfileOptions::UNSPECIFIED) {
73       return nullptr;
74     }
75     return absl::WrapUnique(new PluggableProfiler(profiler_fns, profiler));
76   }
77 
Start()78   Status Start() override {
79     tensorflow::TF_StatusPtr status(TF_NewStatus());
80     profiler_fns_.start(&profiler_, status.get());
81     return tensorflow::StatusFromTF_Status(status.get());
82   }
83 
Stop()84   Status Stop() override {
85     tensorflow::TF_StatusPtr status(TF_NewStatus());
86     profiler_fns_.stop(&profiler_, status.get());
87     return tensorflow::StatusFromTF_Status(status.get());
88   }
89 
CollectData(XSpace * space)90   Status CollectData(XSpace* space) override {
91     tensorflow::TF_StatusPtr status(TF_NewStatus());
92     // Get size of buffer required for Plugin to serialize XSpace into it.
93     size_t size_in_bytes;
94     profiler_fns_.collect_data_xspace(&profiler_, /*buffer=*/nullptr,
95                                       &size_in_bytes, status.get());
96 
97     if (size_in_bytes <= 0)
98       return tensorflow::StatusFromTF_Status(status.get());
99 
100     // Prepare an appropriately sized buffer.
101     std::vector<uint8_t> buffer(size_in_bytes);
102     profiler_fns_.collect_data_xspace(&profiler_, buffer.data(), &size_in_bytes,
103                                       status.get());
104     // Deserialize XSpace from the buffer and return it.
105     XSpace plugin_space;
106     plugin_space.ParseFromArray(buffer.data(), buffer.size());
107     for (XPlane& plugin_plane : *plugin_space.mutable_planes()) {
108       XPlane* plane = space->add_planes();
109       plane->Swap(&plugin_plane);
110     }
111     return tensorflow::StatusFromTF_Status(status.get());
112   }
113 
114  private:
PluggableProfiler(TP_ProfilerFns profiler_fns,TP_Profiler profiler)115   PluggableProfiler(TP_ProfilerFns profiler_fns, TP_Profiler profiler)
116       : profiler_fns_(profiler_fns), profiler_(profiler) {}
117   TP_ProfilerFns profiler_fns_;
118   TP_Profiler profiler_;
119 };
120 
121 class PluggableProfilerFactory {
122  public:
PluggableProfilerFactory(TP_Profiler profiler,void (* destroy_profiler)(TP_Profiler *),TP_ProfilerFns profiler_fns,void (* destroy_profiler_fns)(TP_ProfilerFns *))123   PluggableProfilerFactory(TP_Profiler profiler,
124                            void (*destroy_profiler)(TP_Profiler*),
125                            TP_ProfilerFns profiler_fns,
126                            void (*destroy_profiler_fns)(TP_ProfilerFns*))
127       : profiler_(std::move(profiler)),
128         destroy_profiler_(destroy_profiler),
129         profiler_fns_(std::move(profiler_fns)),
130         destroy_profiler_fns_(destroy_profiler_fns) {}
131 
~PluggableProfilerFactory()132   ~PluggableProfilerFactory() {
133     destroy_profiler_(&profiler_);
134     destroy_profiler_fns_(&profiler_fns_);
135   }
136 
137   std::unique_ptr<tensorflow::profiler::ProfilerInterface>
CreatePluggableProfiler(const ProfileOptions & options)138   CreatePluggableProfiler(const ProfileOptions& options) {
139     return PluggableProfiler::CreatePluggableProfiler(options, profiler_,
140                                                       profiler_fns_);
141   }
142 
143  private:
144   TP_Profiler profiler_{TP_PROFILER_STRUCT_SIZE};
145   void (*destroy_profiler_)(TP_Profiler*);
146   TP_ProfilerFns profiler_fns_{TP_PROFILER_FNS_STRUCT_SIZE};
147   void (*destroy_profiler_fns_)(TP_ProfilerFns*);
148 };
149 
150 }  // namespace
151 
InitPluginProfiler(TFInitProfilerFn init_fn)152 Status InitPluginProfiler(TFInitProfilerFn init_fn) {
153   TF_ProfilerRegistrationParams params{
154       TF_PROFILER_REGISTRATION_PARAMS_STRUCT_SIZE};
155   TP_Profiler profiler{TP_PROFILER_STRUCT_SIZE};
156   TP_ProfilerFns profiler_fns{TP_PROFILER_FNS_STRUCT_SIZE};
157   params.major_version = TP_MAJOR;
158   params.minor_version = TP_MINOR;
159   params.patch_version = TP_PATCH;
160   params.profiler = &profiler;
161   params.profiler_fns = &profiler_fns;
162   tensorflow::TF_StatusPtr status(TF_NewStatus());
163   init_fn(&params, status.get());
164   TF_RETURN_IF_ERROR(tensorflow::StatusFromTF_Status(status.get()));
165   TF_RETURN_IF_ERROR(ValidateTPProfilerRegistrationParams(params));
166   TF_RETURN_IF_ERROR(ValidateTPProfiler(profiler));
167   TF_RETURN_IF_ERROR(ValidateTPProfilerFns(profiler_fns));
168 
169   PluggableProfilerFactory factory(std::move(profiler), params.destroy_profiler,
170                                    std::move(profiler_fns),
171                                    params.destroy_profiler_fns);
172   std::function<std::unique_ptr<ProfilerInterface>(const ProfileOptions&)>
173       create_func = [factory = std::move(factory)](
174                         const ProfileOptions& options) mutable {
175         return factory.CreatePluggableProfiler(options);
176       };
177 
178   tensorflow::profiler::RegisterProfilerFactory(std::move(create_func));
179   return OkStatus();
180 }
181 
182 }  // namespace profiler
183 }  // namespace tensorflow
184