1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 // A simple logging device to test custom device registration.
17 #include <memory>
18
19 #include "tensorflow/c/c_api.h"
20 #include "tensorflow/c/eager/c_api.h"
21 #include "tensorflow/c/eager/c_api_experimental.h"
22 #include "tensorflow/c/eager/c_api_test_util.h"
23 #include "tensorflow/c/tf_status.h"
24 #include "tensorflow/core/lib/gtl/cleanup.h"
25 #include "tensorflow/core/platform/test.h"
26
27 namespace {
28
// Per-device state shared by all callbacks of one logging device instance.
// Heap-allocated at registration time and released via DeleteLoggingDevice.
struct LoggingDevice {
  tensorflow::string device_name;
  tensorflow::string underlying_device;
  // Set to true whenever a TensorHandle is copied onto the device
  bool* arrived_flag;
  // Set to true whenever an operation is executed
  bool* executed_flag;
  // If true, only explicit op placements are accepted. If false, uses
  // type-based dispatch.
  bool strict_scope_placement;
};
40
41 struct LoggedTensor {
42 TFE_TensorHandle* tensor;
43 LoggedTensor() = delete;
LoggedTensor__anonb3a4e9dc0111::LoggedTensor44 explicit LoggedTensor(TFE_TensorHandle* tensor) : tensor(tensor) {}
~LoggedTensor__anonb3a4e9dc0111::LoggedTensor45 ~LoggedTensor() { TFE_DeleteTensorHandle(tensor); }
46 };
47
LoggedTensorDeallocator(void * data,size_t len,void * arg)48 void LoggedTensorDeallocator(void* data, size_t len, void* arg) {
49 delete reinterpret_cast<LoggedTensor*>(data);
50 }
51
MakeLoggedTensorHandle(TFE_Context * context,const tensorflow::string & logging_device_name,std::unique_ptr<LoggedTensor> t,TF_Status * status)52 TFE_TensorHandle* MakeLoggedTensorHandle(
53 TFE_Context* context, const tensorflow::string& logging_device_name,
54 std::unique_ptr<LoggedTensor> t, TF_Status* status) {
55 std::vector<int64_t> shape(TFE_TensorHandleNumDims(t->tensor, status));
56 if (TF_GetCode(status) != TF_OK) return nullptr;
57 for (int i = 0; i < shape.size(); ++i) {
58 shape[i] = TFE_TensorHandleDim(t->tensor, i, status);
59 if (TF_GetCode(status) != TF_OK) return nullptr;
60 }
61 auto dtype = TFE_TensorHandleDataType(t->tensor);
62 return TFE_NewTensorHandleFromDeviceMemory(
63 context, logging_device_name.c_str(), dtype, shape.data(), shape.size(),
64 t.release(), 1, &LoggedTensorDeallocator, nullptr, status);
65 }
66
CopyToLoggingDevice(TFE_Context * context,TFE_TensorHandle * tensor,TF_Status * status,void * device_info)67 TFE_TensorHandle* CopyToLoggingDevice(TFE_Context* context,
68 TFE_TensorHandle* tensor,
69 TF_Status* status, void* device_info) {
70 LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
71 TFE_TensorHandle* t = TFE_TensorHandleCopyToDevice(
72 tensor, context, dev->underlying_device.c_str(), status);
73 if (TF_GetCode(status) != TF_OK) return nullptr;
74 auto dst = std::make_unique<LoggedTensor>(t);
75 *(dev->arrived_flag) = true;
76 return MakeLoggedTensorHandle(context, dev->device_name, std::move(dst),
77 status);
78 }
79
CopyTensorFromLoggingDevice(TFE_Context * context,TFE_TensorHandle * tensor,const char * target_device_name,TF_Status * status,void * device_info)80 TFE_TensorHandle* CopyTensorFromLoggingDevice(TFE_Context* context,
81 TFE_TensorHandle* tensor,
82 const char* target_device_name,
83 TF_Status* status,
84 void* device_info) {
85 TF_SetStatus(status, TF_INTERNAL,
86 "Trying to copy a tensor out of a logging device.");
87 return nullptr;
88 }
89
LoggingDeviceExecute(const TFE_Op * original_op,int * num_outputs,TFE_TensorHandle ** outputs,TF_Status * s,void * device_info)90 void LoggingDeviceExecute(const TFE_Op* original_op, int* num_outputs,
91 TFE_TensorHandle** outputs, TF_Status* s,
92 void* device_info) {
93 const char* requested_placement = TFE_OpGetDevice(original_op, s);
94 if (TF_GetCode(s) != TF_OK) return;
95
96 LoggingDevice* dev = reinterpret_cast<LoggingDevice*>(device_info);
97 if (dev->strict_scope_placement && *requested_placement == '\0') {
98 TF_SetStatus(s, TF_INTERNAL,
99 "Ops must be placed on the device explicitly, or their inputs "
100 "first copied to other devices.");
101 return;
102 }
103 TFE_Context* context = TFE_OpGetContext(original_op, s);
104 if (TF_GetCode(s) != TF_OK) return;
105 const char* operation_name = TFE_OpGetName(original_op, s);
106 if (TF_GetCode(s) != TF_OK) return;
107 const TFE_OpAttrs* attributes = TFE_OpGetAttrs(original_op);
108
109 TFE_Op* op(TFE_NewOp(context, operation_name, s));
110 if (TF_GetCode(s) != TF_OK) return;
111 TFE_OpAddAttrs(op, attributes);
112 TFE_OpSetDevice(op, dev->underlying_device.c_str(), s);
113 if (TF_GetCode(s) != TF_OK) return;
114 int num_inputs = TFE_OpGetFlatInputCount(original_op, s);
115 if (TF_GetCode(s) != TF_OK) return;
116 for (int j = 0; j < num_inputs; ++j) {
117 TFE_TensorHandle* input = TFE_OpGetFlatInput(original_op, j, s);
118 if (TF_GetCode(s) != TF_OK) return;
119 const char* input_device = TFE_TensorHandleDeviceName(input, s);
120 if (TF_GetCode(s) != TF_OK) return;
121 if (dev->device_name == input_device) {
122 LoggedTensor* t = reinterpret_cast<LoggedTensor*>(
123 TFE_TensorHandleDevicePointer(input, s));
124 if (TF_GetCode(s) != TF_OK) return;
125 TFE_OpAddInput(op, t->tensor, s);
126 } else {
127 TFE_OpAddInput(op, input, s);
128 }
129 if (TF_GetCode(s) != TF_OK) return;
130 }
131 std::vector<TFE_TensorHandle*> op_outputs(*num_outputs);
132 TFE_Execute(op, op_outputs.data(), num_outputs, s);
133 TFE_DeleteOp(op);
134 if (TF_GetCode(s) != TF_OK) return;
135 std::vector<TFE_TensorHandle*> unwrapped_outputs;
136 unwrapped_outputs.reserve(op_outputs.size());
137 for (auto* handle : op_outputs) {
138 unwrapped_outputs.push_back(handle);
139 }
140 for (int i = 0; i < *num_outputs; ++i) {
141 auto logged_tensor = std::make_unique<LoggedTensor>(unwrapped_outputs[i]);
142 outputs[i] = MakeLoggedTensorHandle(context, dev->device_name,
143 std::move(logged_tensor), s);
144 }
145 *(dev->executed_flag) = true;
146 }
147
DeleteLoggingDevice(void * device_info)148 void DeleteLoggingDevice(void* device_info) {
149 delete reinterpret_cast<LoggingDevice*>(device_info);
150 }
151
152 } // namespace
153
RegisterLoggingDevice(TFE_Context * context,const char * name,bool strict_scope_placement,bool * arrived_flag,bool * executed_flag,TF_Status * status)154 void RegisterLoggingDevice(TFE_Context* context, const char* name,
155 bool strict_scope_placement, bool* arrived_flag,
156 bool* executed_flag, TF_Status* status) {
157 TFE_CustomDevice custom_device;
158 custom_device.copy_tensor_to_device = &CopyToLoggingDevice;
159 custom_device.copy_tensor_from_device = &CopyTensorFromLoggingDevice;
160 custom_device.delete_device = &DeleteLoggingDevice;
161 custom_device.execute = &LoggingDeviceExecute;
162 LoggingDevice* device = new LoggingDevice;
163 device->arrived_flag = arrived_flag;
164 device->executed_flag = executed_flag;
165 device->device_name = name;
166 device->underlying_device = "/job:localhost/replica:0/task:0/device:CPU:0";
167 device->strict_scope_placement = strict_scope_placement;
168 TFE_RegisterCustomDevice(context, custom_device, name, device, status);
169 }
170
UnpackTensorHandle(TFE_TensorHandle * logged_tensor_handle,TF_Status * status)171 TFE_TensorHandle* UnpackTensorHandle(TFE_TensorHandle* logged_tensor_handle,
172 TF_Status* status) {
173 return reinterpret_cast<LoggedTensor*>(
174 TFE_TensorHandleDevicePointer(logged_tensor_handle, status))
175 ->tensor;
176 }
177
AllocateLoggingDevice(const char * name,bool * arrived_flag,bool * executed_flag,TFE_CustomDevice ** device,void ** device_info)178 void AllocateLoggingDevice(const char* name, bool* arrived_flag,
179 bool* executed_flag, TFE_CustomDevice** device,
180 void** device_info) {
181 TFE_CustomDevice* custom_device = new TFE_CustomDevice;
182 custom_device->copy_tensor_to_device = &CopyToLoggingDevice;
183 custom_device->copy_tensor_from_device = &CopyTensorFromLoggingDevice;
184 custom_device->delete_device = &DeleteLoggingDevice;
185 custom_device->execute = &LoggingDeviceExecute;
186 *device = custom_device;
187 LoggingDevice* logging_device = new LoggingDevice;
188 logging_device->arrived_flag = arrived_flag;
189 logging_device->executed_flag = executed_flag;
190 logging_device->device_name = name;
191 logging_device->underlying_device =
192 "/job:localhost/replica:0/task:0/device:CPU:0";
193 logging_device->strict_scope_placement = true;
194 *device_info = reinterpret_cast<void*>(logging_device);
195 }
196