• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_PROFILER_INTERNAL_GPU_CUPTI_INTERFACE_H_
17 #define TENSORFLOW_CORE_PROFILER_INTERNAL_GPU_CUPTI_INTERFACE_H_
18 
19 #include <stddef.h>
20 #include <stdint.h>
21 
22 #include "third_party/gpus/cuda/extras/CUPTI/include/cupti.h"
23 #include "third_party/gpus/cuda/include/cuda.h"
24 #include "tensorflow/core/platform/macros.h"
25 #include "tensorflow/core/platform/types.h"
26 
27 namespace tensorflow {
28 namespace profiler {
29 
30 // Provides a wrapper interface to every single CUPTI API function. This class
31 // is needed to create an easy mock object for CUPTI API calls. All member
32 // functions are defined in the following order: activity related APIs, callback
33 // related APIs, Event APIs, and metric APIs. Within each category, we follow
34 // the order in the original CUPTI documentation.
35 class CuptiInterface {
36  public:
CuptiInterface()37   CuptiInterface() {}
38 
~CuptiInterface()39   virtual ~CuptiInterface() {}
40 
41   // CUPTI activity API
42   virtual CUptiResult ActivityDisable(CUpti_ActivityKind kind) = 0;
43 
44   virtual CUptiResult ActivityEnable(CUpti_ActivityKind kind) = 0;
45 
46   virtual CUptiResult ActivityFlushAll(uint32_t flag) = 0;
47 
48   virtual CUptiResult ActivityGetNextRecord(uint8_t* buffer,
49                                             size_t valid_buffer_size_bytes,
50                                             CUpti_Activity** record) = 0;
51 
52   virtual CUptiResult ActivityGetNumDroppedRecords(CUcontext context,
53                                                    uint32_t stream_id,
54                                                    size_t* dropped) = 0;
55 
56   virtual CUptiResult ActivityConfigureUnifiedMemoryCounter(
57       CUpti_ActivityUnifiedMemoryCounterConfig* config, uint32_t count) = 0;
58 
59   virtual CUptiResult ActivityRegisterCallbacks(
60       CUpti_BuffersCallbackRequestFunc func_buffer_requested,
61       CUpti_BuffersCallbackCompleteFunc func_buffer_completed) = 0;
62 
63   virtual CUptiResult GetDeviceId(CUcontext context, uint32* deviceId) = 0;
64 
65   virtual CUptiResult GetTimestamp(uint64_t* timestamp) = 0;
66 
67   virtual CUptiResult Finalize() = 0;
68 
69   // CUPTI callback API
70   virtual CUptiResult EnableCallback(uint32_t enable,
71                                      CUpti_SubscriberHandle subscriber,
72                                      CUpti_CallbackDomain domain,
73                                      CUpti_CallbackId cbid) = 0;
74 
75   virtual CUptiResult EnableDomain(uint32_t enable,
76                                    CUpti_SubscriberHandle subscriber,
77                                    CUpti_CallbackDomain domain) = 0;
78 
79   virtual CUptiResult Subscribe(CUpti_SubscriberHandle* subscriber,
80                                 CUpti_CallbackFunc callback,
81                                 void* userdata) = 0;
82 
83   virtual CUptiResult Unsubscribe(CUpti_SubscriberHandle subscriber) = 0;
84 
85   // CUPTI event API
86   virtual CUptiResult DeviceEnumEventDomains(
87       CUdevice device, size_t* array_size_bytes,
88       CUpti_EventDomainID* domain_array) = 0;
89 
90   virtual CUptiResult DeviceGetEventDomainAttribute(
91       CUdevice device, CUpti_EventDomainID event_domain,
92       CUpti_EventDomainAttribute attrib, size_t* value_size, void* value) = 0;
93 
94   virtual CUptiResult DisableKernelReplayMode(CUcontext context) = 0;
95 
96   virtual CUptiResult EnableKernelReplayMode(CUcontext context) = 0;
97 
98   virtual CUptiResult DeviceGetNumEventDomains(CUdevice device,
99                                                uint32_t* num_domains) = 0;
100 
101   virtual CUptiResult EventDomainEnumEvents(CUpti_EventDomainID event_domain,
102                                             size_t* array_size_bytes,
103                                             CUpti_EventID* event_array) = 0;
104 
105   virtual CUptiResult EventDomainGetNumEvents(CUpti_EventDomainID event_domain,
106                                               uint32_t* num_events) = 0;
107 
108   virtual CUptiResult EventGetAttribute(CUpti_EventID event,
109                                         CUpti_EventAttribute attrib,
110                                         size_t* value_size, void* value) = 0;
111 
112   virtual CUptiResult EventGetIdFromName(CUdevice device,
113                                          const char* event_name,
114                                          CUpti_EventID* event) = 0;
115 
116   virtual CUptiResult EventGroupDisable(CUpti_EventGroup event_group) = 0;
117 
118   virtual CUptiResult EventGroupEnable(CUpti_EventGroup event_group) = 0;
119 
120   virtual CUptiResult EventGroupGetAttribute(CUpti_EventGroup event_group,
121                                              CUpti_EventGroupAttribute attrib,
122                                              size_t* value_size,
123                                              void* value) = 0;
124 
125   virtual CUptiResult EventGroupReadEvent(CUpti_EventGroup event_group,
126                                           CUpti_ReadEventFlags flags,
127                                           CUpti_EventID event,
128                                           size_t* event_value_buffer_size_bytes,
129                                           uint64_t* eventValueBuffer) = 0;
130 
131   virtual CUptiResult EventGroupSetAttribute(CUpti_EventGroup event_group,
132                                              CUpti_EventGroupAttribute attrib,
133                                              size_t value_size,
134                                              void* value) = 0;
135 
136   virtual CUptiResult EventGroupSetsCreate(
137       CUcontext context, size_t event_id_array_size_bytes,
138       CUpti_EventID* event_id_array,
139       CUpti_EventGroupSets** event_group_passes) = 0;
140 
141   virtual CUptiResult EventGroupSetsDestroy(
142       CUpti_EventGroupSets* event_group_sets) = 0;
143 
144   // CUPTI metric API
145   virtual CUptiResult DeviceEnumMetrics(CUdevice device, size_t* arraySizeBytes,
146                                         CUpti_MetricID* metricArray) = 0;
147 
148   virtual CUptiResult DeviceGetNumMetrics(CUdevice device,
149                                           uint32_t* num_metrics) = 0;
150 
151   virtual CUptiResult MetricGetIdFromName(CUdevice device,
152                                           const char* metric_name,
153                                           CUpti_MetricID* metric) = 0;
154 
155   virtual CUptiResult MetricGetNumEvents(CUpti_MetricID metric,
156                                          uint32_t* num_events) = 0;
157 
158   virtual CUptiResult MetricEnumEvents(CUpti_MetricID metric,
159                                        size_t* event_id_array_size_bytes,
160                                        CUpti_EventID* event_id_array) = 0;
161 
162   virtual CUptiResult MetricGetAttribute(CUpti_MetricID metric,
163                                          CUpti_MetricAttribute attrib,
164                                          size_t* value_size, void* value) = 0;
165 
166   virtual CUptiResult MetricGetValue(CUdevice device, CUpti_MetricID metric,
167                                      size_t event_id_array_size_bytes,
168                                      CUpti_EventID* event_id_array,
169                                      size_t event_value_array_size_bytes,
170                                      uint64_t* event_value_array,
171                                      uint64_t time_duration,
172                                      CUpti_MetricValue* metric_value) = 0;
173 
174   virtual CUptiResult GetResultString(CUptiResult result, const char** str) = 0;
175 
176   virtual CUptiResult GetContextId(CUcontext context, uint32_t* context_id) = 0;
177 
178   virtual CUptiResult GetStreamIdEx(CUcontext context, CUstream stream,
179                                     uint8_t per_thread_stream,
180                                     uint32_t* stream_id) = 0;
181 
182   // Interface maintenance functions. Not directly related to CUPTI, but
183   // required for implementing an error resilient layer over CUPTI API.
184 
185   // Performance any clean up work that is required each time profile session
186   // is done. Therefore this can be called multiple times during process life
187   // time.
188   virtual void CleanUp() = 0;
189 
190   // Whether CUPTI API is currently disabled due to unrecoverable errors.
191   // All subsequent calls will fail immediately without forwarding calls to
192   // CUPTI library.
193   virtual bool Disabled() const = 0;
194 
195  private:
196   TF_DISALLOW_COPY_AND_ASSIGN(CuptiInterface);
197 };
198 
199 CuptiInterface* GetCuptiInterface();
200 
201 }  // namespace profiler
202 }  // namespace tensorflow
203 
204 #endif  // TENSORFLOW_CORE_PROFILER_INTERNAL_GPU_CUPTI_INTERFACE_H_
205