1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // A simple logging device to test custom device registration.
17 #include <memory>
18
19 #include "tensorflow/c/c_api.h"
20 #include "tensorflow/c/eager/c_api.h"
21 #include "tensorflow/c/eager/c_api_experimental.h"
22 #include "tensorflow/c/eager/c_api_test_util.h"
23 #include "tensorflow/c/tf_status.h"
24 #include "tensorflow/core/lib/gtl/cleanup.h"
25 #include "tensorflow/core/platform/test.h"
26
27 namespace {
28
29 struct LoggingDevice {
30 tensorflow::string device_name;
31 tensorflow::string underlying_device;
32 // Set to true whenever a TensorHandle is copied onto the device
33 bool* arrived_flag;
34 // Set to true whenever an operation is executed
35 bool* executed_flag;
36 // If true, only explicit op placements are accepted. If false, uses
37 // type-based dispatch.
38 bool strict_scope_placement;
39 };
40
41 struct LoggedTensor {
42 TFE_TensorHandle* tensor;
43 LoggedTensor() = delete;
LoggedTensor__anon239fa2ab0111::LoggedTensor44 explicit LoggedTensor(TFE_TensorHandle* tensor) : tensor(tensor) {}
~LoggedTensor__anon239fa2ab0111::LoggedTensor45 ~LoggedTensor() { TFE_DeleteTensorHandle(tensor); }
46 };
47
LoggedTensorDim(void * data,int dim_index,TF_Status * status)48 int64_t LoggedTensorDim(void* data, int dim_index, TF_Status* status) {
49 return TFE_TensorHandleDim(reinterpret_cast<LoggedTensor*>(data)->tensor,
50 dim_index, status);
51 }
52
LoggedTensorNumDims(void * data,TF_Status * status)53 int LoggedTensorNumDims(void* data, TF_Status* status) {
54 return TFE_TensorHandleNumDims(reinterpret_cast<LoggedTensor*>(data)->tensor,
55 status);
56 }
57
LoggedTensorDeallocator(void * data)58 void LoggedTensorDeallocator(void* data) {
59 delete reinterpret_cast<LoggedTensor*>(data);
60 }
61
MakeLoggedTensorHandle(TFE_Context * context,const tensorflow::string & logging_device_name,std::unique_ptr<LoggedTensor> t,TF_Status * status)62 TFE_TensorHandle* MakeLoggedTensorHandle(
63 TFE_Context* context, const tensorflow::string& logging_device_name,
64 std::unique_ptr<LoggedTensor> t, TF_Status* status) {
65 auto dtype = TFE_TensorHandleDataType(t->tensor);
66 TFE_CustomDeviceTensorHandleMethods handle_methods;
67 handle_methods.num_dims = &LoggedTensorNumDims;
68 handle_methods.dim = &LoggedTensorDim;
69 handle_methods.deallocator = &LoggedTensorDeallocator;
70 return TFE_NewCustomDeviceTensorHandle(context, logging_device_name.c_str(),
71 dtype, t.release(), handle_methods,
72 status);
73 }
74
CopyToLoggingDevice(TFE_Context * context,TFE_TensorHandle * tensor,TF_Status * status,void * device_info)75 TFE_TensorHandle* CopyToLoggingDevice(TFE_Context* context,
76 TFE_TensorHandle* tensor,
77 TF_Status* status, void* device_info) {
78 LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
79 TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
80 tensor, context, dev->underlying_device.c_str(), status);
81 if (TF_GetCode(status) != TF_OK) return nullptr;
82 auto dst = std::make_unique<LoggedTensor>(t);
83 *(dev->arrived_flag) = true;
84 return MakeLoggedTensorHandle(context, dev->device_name, std::move(dst),
85 status);
86 }
87
CopyTensorFromLoggingDevice(TFE_Context * context,TFE_TensorHandle * tensor,const char * target_device_name,TF_Status * status,void * device_info)88 TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_Context* context,
89 TFE_TensorHandle* tensor,
90 const char* target_device_name,
91 TF_Status* status,
92 void* device_info) {
93 TF_SetStatus(status, TF_INTERNAL,
94 "Trying to copy a tensor out of a logging device.");
95 return nullptr;
96 }
97
LoggingDeviceExecute(const TFE_Op * original_op,int * num_outputs,TFE_TensorHandle ** outputs,TF_Status * s,void * device_info)98 void LoggingDeviceExecute(const TFE_Op* original_op, int* num_outputs,
99 TFE_TensorHandle** outputs, TF_Status* s,
100 void* device_info) {
101 const char* requested_placement = TFE_OpGetDevice(original_op, s);
102 if (TF_GetCode(s) != TF_OK) return;
103
104 LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
105 if (dev->strict_scope_placement && *requested_placement == '\0') {
106 TF_SetStatus(s, TF_INTERNAL,
107 "Ops must be placed on the device explicitly, or their inputs "
108 "first copied to other devices.");
109 return;
110 }
111 TFE_Context* context = TFE_OpGetContext(original_op, s);
112 if (TF_GetCode(s) != TF_OK) return;
113 const char* operation_name = TFE_OpGetName(original_op, s);
114 if (TF_GetCode(s) != TF_OK) return;
115 const TFE_OpAttrs* attributes = TFE_OpGetAttrs(original_op);
116
117 TFE_Op* op(TFE_NewOp(context, operation_name, s));
118 if (TF_GetCode(s) != TF_OK) return;
119 TFE_OpAddAttrs(op, attributes);
120 TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
121 if (TF_GetCode(s) != TF_OK) return;
122 int num_inputs = TFE_OpGetFlatInputCount(original_op, s);
123 if (TF_GetCode(s) != TF_OK) return;
124 for (int j = 0; j < num_inputs; ++j) {
125 TFE_TensorHandle* input = TFE_OpGetFlatInput(original_op, j, s);
126 if (TF_GetCode(s) != TF_OK) return;
127 const char* input_device = TFE_TensorHandleDeviceName(input, s);
128 if (TF_GetCode(s) != TF_OK) return;
129 if (dev->device_name == input_device) {
130 LoggedTensor* t = reinterpret_cast<LoggedTensor*>(
131 TFE_TensorHandleDevicePointer(input, s));
132 if (TF_GetCode(s) != TF_OK) return;
133 TFE_OpAddInput(op, t->tensor, s);
134 } else {
135 TFE_OpAddInput(op, input, s);
136 }
137 if (TF_GetCode(s) != TF_OK) return;
138 }
139 std::vector<TFE_TensorHandle*> op_outputs(*num_outputs);
140 TFE_Execute(op, op_outputs.data(), num_outputs, s);
141 TFE_DeleteOp(op);
142 if (TF_GetCode(s) != TF_OK) return;
143 std::vector<TFE_TensorHandle*> unwrapped_outputs;
144 unwrapped_outputs.reserve(op_outputs.size());
145 for (auto* handle : op_outputs) {
146 unwrapped_outputs.push_back(handle);
147 }
148 for (int i = 0; i < *num_outputs; ++i) {
149 auto logged_tensor = std::make_unique<LoggedTensor>(unwrapped_outputs[i]);
150 outputs[i] = MakeLoggedTensorHandle(context, dev->device_name,
151 std::move(logged_tensor), s);
152 }
153 *(dev->executed_flag) = true;
154 }
155
DeleteLoggingDevice(void * device_info)156 void DeleteLoggingDevice(void* device_info) {
157 delete reinterpret_cast<LoggingDevice*>(device_info);
158 }
159
160 } // namespace
161
RegisterLoggingDevice(TFE_Context * context,const char * name,bool strict_scope_placement,bool * arrived_flag,bool * executed_flag,TF_Status * status)162 void RegisterLoggingDevice(TFE_Context* context, const char* name,
163 bool strict_scope_placement, bool* arrived_flag,
164 bool* executed_flag, TF_Status* status) {
165 TFE_CustomDevice custom_device;
166 custom_device.copy_tensor_to_device = &CopyToLoggingDevice;
167 custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice;
168 custom_device.delete_device = &DeleteLoggingDevice;
169 custom_device.execute = &LoggingDeviceExecute;
170 LoggingDevice* device = new LoggingDevice;
171 device->arrived_flag = arrived_flag;
172 device->executed_flag = executed_flag;
173 device->device_name = name;
174 device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0";
175 device->strict_scope_placement = strict_scope_placement;
176 TFE_RegisterCustomDevice(context, custom_device, name, device, status);
177 }
178
UnpackTensorHandle(TFE_TensorHandle * logged_tensor_handle,TF_Status * status)179 TFE_TensorHandle* UnpackTensorHandle(TFE_TensorHandle* logged_tensor_handle,
180 TF_Status* status) {
181 return reinterpret_cast<LoggedTensor*>(
182 TFE_TensorHandleDevicePointer(logged_tensor_handle, status))
183 ->tensor;
184 }
185
AllocateLoggingDevice(const char * name,bool * arrived_flag,bool * executed_flag,TFE_CustomDevice ** device,void ** device_info)186 void AllocateLoggingDevice(const char* name, bool* arrived_flag,
187 bool* executed_flag, TFE_CustomDevice** device,
188 void** device_info) {
189 TFE_CustomDevice* custom_device = new TFE_CustomDevice;
190 custom_device->copy_tensor_to_device = &CopyToLoggingDevice;
191 custom_device->copy_tensor_from_device = &CopyTensorFromLoggingDevice;
192 custom_device->delete_device = &DeleteLoggingDevice;
193 custom_device->execute = &LoggingDeviceExecute;
194 *device = custom_device;
195 LoggingDevice* logging_device = new LoggingDevice;
196 logging_device->arrived_flag = arrived_flag;
197 logging_device->executed_flag = executed_flag;
198 logging_device->device_name = name;
199 logging_device->underlying_device =
200 "/job:localhost/replica:0/task:0/device:CPU:0";
201 logging_device->strict_scope_placement = true;
202 *device_info = reinterpret_cast<void*>(logging_device);
203 }
204