/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/tpu/kernels/tpu_reshard_variables_op.h"

#include "tensorflow/compiler/jit/xla_device.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/jit/xla_tensor.h"
#include "tensorflow/compiler/tf2xla/shape_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/resource_var.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/platform/casts.h"
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h"
#include "tensorflow/core/tpu/kernels/tpu_op_consts.h"
#include "tensorflow/core/tpu/kernels/tpu_program_group.h"
#include "tensorflow/core/tpu/kernels/tpu_reshard_variables_op_util.h"
#include "tensorflow/core/tpu/tpu_configuration.h"
#include "tensorflow/core/tpu/tpu_defs.h"
#include "tensorflow/core/tpu/tpu_execute.h"
#include "tensorflow/core/util/stream_executor_util.h"
#include "tensorflow/stream_executor/device_memory_allocator.h"
#include "tensorflow/stream_executor/tpu/tpu_executor.h"
#include "tensorflow/stream_executor/tpu/tpu_executor_interface.h"
#include "tensorflow/stream_executor/tpu/tpu_node_context.h"

namespace tensorflow {

namespace reshard_util = ::tensorflow::tpu::reshard_variables;

TPUReshardVariablesOpKernel::TPUReshardVariablesOpKernel(
    OpKernelConstruction* context)
    : AsyncOpKernel(context, /* is_deferred = */ true) {
  OP_REQUIRES_OK(context, context->GetAttr("N", &num_vars_));
}

void TPUReshardVariablesOpKernel::ComputeAsync(OpKernelContext* context,
                                               DoneCallback done) {
  // If TPU launches are asynchronous, then perform the launch on this thread
  // to avoid a thread hop, which has an observable latency cost.
  OP_REQUIRES_OK_ASYNC(context, DoWork(context), done);
  done();
}

Status TPUReshardVariablesOpKernel::DoWork(OpKernelContext* context) {
  VLOG(1) << "Cloud TPU: TPUReshardVariablesOpKernel::DoWork";
  TF_RET_CHECK(context->input_dtype(num_vars_) == DT_STRING);
  const Tensor* new_format_key;
  TF_RETURN_IF_ERROR(context->input("new_format_key", &new_format_key));
  TF_RETURN_IF_ERROR(reshard_util::CheckIsValidKey(*new_format_key));

  TF_RET_CHECK(context->input_dtype(num_vars_ + 1) == DT_RESOURCE);
  const ResourceHandle& handle = HandleFromInput(context, num_vars_ + 1);
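  // The format state variable records the key of the sharding format
  // currently applied to the variables; it is created uninitialized if this
  // is the first reshard on this device.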
  core::RefCountPtr<Var> format_state_var;
  TF_RETURN_IF_ERROR(LookupOrCreateResource<Var>(
      context, handle, &format_state_var, [new_format_key](Var** ptr) {
        *ptr = new Var(new_format_key->dtype());
        return Status::OK();
      }));
  mutex_lock ml(*format_state_var->mu());
  const bool initialized = format_state_var->is_initialized;
  if (initialized) {
    TF_RETURN_IF_ERROR(
        reshard_util::CheckIsValidKey(*format_state_var->tensor()));
  }

  const bool state_is_default =
      !initialized || reshard_util::IsDefaultKey(*format_state_var->tensor());
  const bool new_format_is_default =
      reshard_util::IsDefaultKey(*new_format_key);

  if ((state_is_default && new_format_is_default) ||
      (initialized && format_state_var->tensor()->vec<tstring>()(2) ==
                          new_format_key->vec<tstring>()(2))) {
    VLOG(1) << "Sharding unchanged, nothing to do.";
    return Status::OK();
  }

  if (!state_is_default) {
    // Convert the current format to default (unsharded).
    VLOG(1) << "Unsharding with key: "
            << format_state_var->tensor()->vec<tstring>()(2);
    TF_RETURN_IF_ERROR(
        DoTpuExecute(context, *format_state_var->tensor(),
                     tpu::CompilationCacheFetchTarget::UNSHARDING));
  }

  if (!new_format_is_default) {
    // Convert to the new format.
    VLOG(1) << "Sharding with key: " << new_format_key->vec<tstring>()(2);
    TF_RETURN_IF_ERROR(DoTpuExecute(
        context, *new_format_key, tpu::CompilationCacheFetchTarget::SHARDING));
  }

  // Update the stored state to the new format key.
  *format_state_var->tensor() = *new_format_key;
  format_state_var->is_initialized = true;
  return Status::OK();
}

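// Looks up the sharding or unsharding program identified by `format_key` in
// the TPU compilation cache and runs it on this kernel's device, replacing
// the resource variables' buffers with the resharded result.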
Status TPUReshardVariablesOpKernel::DoTpuExecute(
    OpKernelContext* context, const Tensor& format_key,
    tpu::CompilationCacheFetchTarget fetch_target) {
  const XlaDevice::Metadata* metadata;
  TF_RETURN_IF_ERROR(XlaDevice::GetMetadata(context, &metadata));
  const int device_ordinal = metadata->device_ordinal();

  // We are guaranteed that the underlying object won't be deleted out from
  // under us.
  TF_ASSIGN_OR_RETURN(std::unique_ptr<tpu::TpuNodeContext> node_interfaces,
                      tpu::TpuNodeContext::Create(device_ordinal));

  profiler::TraceMe trace_me(
      [device_ordinal] {
        return profiler::TraceMeEncode("TPUReshardVariablesOpKernel",
                                       {{"device_ordinal", device_ordinal}});
      },
      /*level=*/2);
  profiler::TraceMe trace_me_init("TPUReshardVariablesOpKernel::Init",
                                  /*level=*/2);

  string rendezvous_key_base;
  std::unique_ptr<tpu::CompilationCacheEntryRef> entry_ref;
  TF_RETURN_IF_ERROR(reshard_util::GetComputationCacheEntry(
      format_key, &rendezvous_key_base, &entry_ref, fetch_target));
  tpu::TpuCompilationCacheEntry entry = entry_ref->get();
  if (entry.tpu_program_group() == nullptr) {
    VLOG(2) << "Sharding/unsharding program does not exist, so this is "
               "default sharding.";
    return Status::OK();
  }

  const tpu::TpuProgramGroupInterface* tpu_program_group =
      entry.tpu_program_group();
  const int core_index = entry.core_index();
  const TPUExecutableInfoProto& executable_info_proto =
      tpu_program_group->executable_info(core_index);
  const TPUExecutableInfoProto* executable = &executable_info_proto;

  xla::Backend* const backend = node_interfaces->backend();
  xla::TransferManager* const transfer_manager = backend->transfer_manager();

  CHECK(context->op_device_context());
  se::Stream* stream = context->op_device_context()->stream();

  TF_RET_CHECK(executable->input_shapes_size() == 1);
  xla::Shape host_shape(executable->input_shapes(0));
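  // Collect the resource variables passed as the first `num_vars_` inputs;
  // their current buffers form the program's input and are replaced with the
  // resharded result below.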
  std::vector<VariableInfo> variables;
  for (int i = 0; i < num_vars_; ++i) {
    TF_RET_CHECK(context->input_dtype(i) == DT_RESOURCE);
    const ResourceHandle& handle = HandleFromInput(context, i);
    Var* variable;
    TF_RETURN_IF_ERROR(LookupResource(context, handle, &variable));
    variables.push_back(VariableInfo(i, handle.name(), variable));
  }

  // Block for previous TPUExecute ops so that the memory used for them can be
  // freed.
  TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
  // Lock variables to prevent concurrent access.
  TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variables)));

  // Build input buffers.
  TF_ASSIGN_OR_RETURN(auto input_buffers, reshard_util::BuildInputBuffers(
                                              context, variables, host_shape,
                                              backend, device_ordinal, stream));
  xla::ShapedBuffer shaped_buffer(std::move(host_shape), input_buffers.shape(),
                                  device_ordinal);
  shaped_buffer.set_buffers(input_buffers.Map<se::DeviceMemoryBase>(
      [](xla::MaybeOwningDeviceMemory* buffer) {
        CHECK(buffer);
        return buffer->AsDeviceMemoryBase();
      }));

  // Write input root tuple.
  TF_ASSIGN_OR_RETURN(auto transfer_stream_ptr,
                      backend->BorrowStream(device_ordinal));
  if (transfer_manager->CanShapedBufferBeAccessedNow(stream->parent(),
                                                     shaped_buffer)) {
    TF_RETURN_IF_ERROR(transfer_manager->WriteRootTupleIndexTable(
        transfer_stream_ptr.get(), shaped_buffer));
    stream->ThenWaitFor(transfer_stream_ptr.get());
  } else {
    TF_RETURN_IF_ERROR(
        transfer_manager->WriteRootTupleIndexTable(stream, shaped_buffer));
  }
  VLOG(4) << "Input buffers: " << shaped_buffer.ToString();

  TF_RET_CHECK(!executable->has_session_module())
      << "session module not supported in sharding/unsharding program.";

  auto definition_event = std::make_shared<se::Event>(stream->parent());
  TF_RET_CHECK(definition_event->Init())
      << "TPU definition event initialization failed";

  trace_me_init.Stop();

  // Execute the program.
  std::unique_ptr<xla::DeviceAssignment> device_assignment;
  if (executable->has_device_assignment()) {
    TF_ASSIGN_OR_RETURN(
        device_assignment,
        xla::DeviceAssignment::Deserialize(executable->device_assignment()));
  }
  std::vector<xla::ExecutionInput> input;
  input.emplace_back(xla::ExecutionInput(std::move(input_buffers),
                                         shaped_buffer.on_host_shape()));

  const TPUHostTransferInfoProto& host_transfer_info =
      tpu_program_group->host_transfer_info(core_index);

  TF_ASSIGN_OR_RETURN(
      xla::ExecutionOutput output,
      TPUExecute(*executable, host_transfer_info,
                 *tpu_program_group->hlo_metadatas()[core_index],
                 std::move(input), rendezvous_key_base, GetXLARandomSeed(),
                 node_interfaces.get(), device_assignment.get(),
                 context->cancellation_manager(), context, stream,
                 transfer_stream_ptr.get(),
                 tpu_program_group->tpu_program(core_index)));

  stream->ThenRecordEvent(definition_event.get());

  // Assign the new buffers to the variables.
  xla::ScopedShapedBuffer result = output.ConsumeResult();

  // Only perform compaction when sharding.
  // NOTE: Compaction is not supported on some TPUs; see b/168322060 for
  // details.
  if (node_interfaces->CompactionSupported(device_ordinal) &&
      fetch_target == tpu::CompilationCacheFetchTarget::SHARDING) {
    // Block until program execution is done so that input, output, and
    // program cache memory can actually be released.
    TF_RETURN_IF_ERROR(transfer_stream_ptr->BlockHostUntilDone());
    TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
    {
      // Explicitly release any RAII objects owning on-device allocations.
      auto unused = output.ConsumeToBeReleased();
    }
    // Release variables holding inputs.
    for (int i = 0; i < variables.size(); ++i) {
      *variables[i].var()->tensor() =
          Tensor(variables[i].var()->tensor()->dtype());
    }
    // Flush the on-device program memory cache.
    TF_RETURN_IF_ERROR(
        reshard_util::FlushProgramMemory(backend->platform(), device_ordinal));
    TF_RETURN_IF_ERROR(reshard_util::PerformCompaction(stream));
  }
  return reshard_util::UpdateOutputVariables(
      context, std::move(result), executable->output_tensor_shapes(), backend,
      stream, device_ordinal, variables, definition_event);
}

TPUReshardVariablesOpKernel::~TPUReshardVariablesOpKernel() = default;

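// Kernel registration for the open-source build; the format state variable
// and the new format key are host-memory inputs.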
#if !defined(PLATFORM_GOOGLE)
REGISTER_KERNEL_BUILDER(Name("TPUReshardVariables")
                            .Device(DEVICE_TPU_NODE)
                            .HostMemory("format_state_var")
                            .HostMemory("new_format_key"),
                        TPUReshardVariablesOpKernel);
#endif

}  // namespace tensorflow