/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"

#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
#include "tensorflow/stream_executor/tpu/proto_helper.h"

namespace tensorflow {
namespace tpu {
using ::stream_executor::port::Status;
using ::stream_executor::port::StatusOr;
using ::xla::ComputationLayout;
using ::xla::DebugOptions;
using ::xla::DeviceAssignment;
using ::xla::HloModuleConfig;
using ::xla::HloSharding;
using ::xla::InvalidArgument;
using ::xla::ProgramShape;
using ::xla::Shape;
using ::xla::ShapeTree;
using ::xla::ShapeUtil;

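// Checks that `client_shape` (which may carry a layout) is a valid shape and
// is compatible with the computation's `result_shape`.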
Status ValidateResultShape(const Shape& client_shape,
                           const Shape& result_shape) {
  TF_RETURN_IF_ERROR(
      xla::ShapeUtil::ValidateShapeWithOptionalLayout(client_shape));
  if (!xla::ShapeUtil::Compatible(client_shape, result_shape)) {
    return InvalidArgument(
        "Shape used to set computation result layout %s is not compatible "
        "with result shape %s",
        xla::ShapeUtil::HumanStringWithLayout(client_shape),
        xla::ShapeUtil::HumanString(result_shape));
  }
  return Status::OK();
}

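// Builds an HloModuleConfig for the given program shape, argument shapes, and
// compilation options, copying argument and result layouts into the entry
// computation layout.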
StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
    const ProgramShape& program_shape, absl::Span<const Shape> argument_shapes,
    absl::optional<const Shape> result_layout,
    absl::optional<const DeviceAssignment> device_assignment,
    int replica_count, int num_partitions, const DebugOptions* debug_options,
    const int* seed, const int* launch_id, const bool* alias_passthrough_params,
    const xla::FusionConfigCollection* fusion_config_collection,
    const std::vector<std::vector<bool>>* fusion_config) {
  auto config = absl::make_unique<HloModuleConfig>(program_shape);
  ComputationLayout* computation_layout =
      config->mutable_entry_computation_layout();
  if (program_shape.parameters_size() != argument_shapes.size()) {
    return InvalidArgument("computation takes %d parameters, but %u given",
                           program_shape.parameters_size(),
                           argument_shapes.size());
  }
  for (int i = 0; i < argument_shapes.size(); ++i) {
    // Verify that the shape of each argument matches the shape of the
    // corresponding parameter in the ProgramShape.
    if (!ShapeUtil::Compatible(argument_shapes[i],
                               program_shape.parameters(i))) {
      return InvalidArgument(
          "Argument does not match shape of computation parameter %d: want "
          "%s, got %s",
          i, ShapeUtil::HumanString(program_shape.parameters(i)),
          ShapeUtil::HumanString(argument_shapes[i]));
    }
    TF_RETURN_IF_ERROR(
        computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
            argument_shapes[i]));
  }

  if (result_layout.has_value()) {
    TF_RETURN_IF_ERROR(
        ValidateResultShape(result_layout.value(), program_shape.result()));
    TF_RETURN_IF_ERROR(
        computation_layout->mutable_result_layout()->CopyLayoutFromShape(
            result_layout.value()));
  } else {
    // If the result layout is not set, then choose the default.
    computation_layout->mutable_result_layout()->SetToDefaultLayout();
  }

  config->set_replica_count(replica_count);
  config->set_num_partitions(num_partitions);
  if (seed != nullptr) {
    config->set_seed(*seed);
  }
  if (launch_id != nullptr) {
    config->set_launch_id(*launch_id);
  }
  if (debug_options != nullptr) {
    config->set_debug_options(*debug_options);
  } else {
    config->set_debug_options(xla::GetDebugOptionsFromFlags());
  }

  // TODO(henrytan): set intra_op_parallelism_threads.
  // Reference:
  // tensorflow/compiler/xla/service/service.cc?l=324.

  if (device_assignment.has_value()) {
    config->set_static_device_assignment(device_assignment.value());
  }

  if (alias_passthrough_params != nullptr) {
    config->set_alias_passthrough_params(*alias_passthrough_params);
  }

  if (fusion_config_collection != nullptr && fusion_config != nullptr &&
      *fusion_config_collection != xla::FusionConfigCollection::kOff) {
    config->set_fusion_config_collection(*fusion_config_collection);
    *config->mutable_fusion_config() = *fusion_config;
  }

  return std::move(config);
}

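// Convenience overload that forwards to CreateModuleConfig() above with the
// optional seed, launch id, aliasing, and fusion-config arguments unset.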
StatusOr<std::unique_ptr<xla::HloModuleConfig>> CreateModuleConfig(
    const xla::ProgramShape& program_shape,
    absl::Span<const Shape> argument_shapes,
    absl::optional<const Shape> result_layout,
    absl::optional<const DeviceAssignment> device_assignment,
    int replica_count, int num_partitions, const DebugOptions* debug_options) {
  return CreateModuleConfig(program_shape, argument_shapes, result_layout,
                            device_assignment, replica_count, num_partitions,
                            debug_options, /*seed=*/nullptr,
                            /*launch_id=*/nullptr,
                            /*alias_passthrough_params=*/nullptr,
                            /*fusion_config_collection=*/nullptr,
                            /*fusion_config=*/nullptr);
}

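// Returns the sharding subtree rooted at tuple element `element_index` of
// `tuple_shape_tree`.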
ShapeTree<HloSharding> GetSubtree(
    const ShapeTree<HloSharding>& tuple_shape_tree, int element_index) {
  ShapeTree<HloSharding> element_shape_tree(
      xla::ShapeUtil::GetTupleElementShape(tuple_shape_tree.shape(),
                                           element_index),
      HloSharding::Replicate());

  xla::ShapeIndex src_index;
  src_index.push_back(element_index);
  element_shape_tree.CopySubtreeFrom(tuple_shape_tree, src_index, {});
  return element_shape_tree;
}

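// Returns the shape of the slice of `shape` that `sharding` assigns to
// `device`. Tuples are handled recursively; for tiled shardings the
// per-device dimensions are the tile extents (limit - offset).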
Shape GetPerDeviceShape(const Shape& shape, const HloSharding& sharding,
                        int64 device) {
  if (shape.IsTuple()) {
    ShapeTree<HloSharding> tuple_shape_tree = sharding.GetAsShapeTree(shape);
    std::vector<Shape> arg_shapes;
    for (int64 i = 0; i < xla::ShapeUtil::TupleElementCount(shape); ++i) {
      Shape element_shape = xla::ShapeUtil::GetTupleElementShape(shape, i);
      HloSharding element_sharding = tuple_shape_tree.element({i});
      if (element_shape.IsTuple()) {
        element_sharding = HloSharding::Tuple(GetSubtree(tuple_shape_tree, i));
      }
      if (element_sharding.UsesDevice(device)) {
        arg_shapes.push_back(
            GetPerDeviceShape(element_shape, element_sharding, device));
      }
    }
    return xla::ShapeUtil::MakeTupleShape(arg_shapes);
  }

  if (sharding.IsTileMaximal()) {
    return shape;
  }

  std::vector<int64> dimensions;
  std::vector<int64> offset = sharding.TileOffsetForDevice(shape, device);
  std::vector<int64> limit = sharding.TileLimitForDevice(shape, device);
  dimensions.resize(limit.size());
  for (int64 i = 0; i < limit.size(); ++i) {
    dimensions[i] = limit[i] - offset[i];
  }
  if (shape.has_layout()) {
    return xla::ShapeUtil::MakeShapeWithLayout(
        shape.element_type(), dimensions, shape.layout().minor_to_major());
  }
  return xla::ShapeUtil::MakeShape(shape.element_type(), dimensions);
}

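// For each VARIABLE argument, records on every core touched by the argument's
// sharding: the shape of its updated value (when the compiled program writes
// the variable back) and the per-core argument index together with whether
// the variable may be modified.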
Status AddVariableUpdatesToCores(
    const TPUCompileMetadataProto& metadata,
    const XlaCompiler::CompilationResult& compilation_result,
    const std::vector<ShardingAndIndex>& arg_core_mapping,
    std::vector<bool>* may_modify_variables,
    std::vector<std::vector<xla::Shape>>* per_core_output_shapes,
    std::vector<std::vector<std::pair<int, bool>>>* per_core_variable_indices) {
  // Add all variables to the corresponding core.
  may_modify_variables->resize(metadata.num_cores_per_replica(), false);
  int resource_update_pos = 0;
  for (int i = 0; i < metadata.args_size(); ++i) {
    const tpu::TPUCompileMetadataProto::Arg& proto_arg = metadata.args(i);
    if (proto_arg.kind() == tpu::TPUCompileMetadataProto::Arg::VARIABLE) {
      const auto& sharding = proto_arg.sharding();
      bool updated = false;
      if (resource_update_pos < compilation_result.resource_updates.size()) {
        const XlaCompiler::ResourceUpdate& update =
            compilation_result.resource_updates[resource_update_pos];
        if (update.input_index == i) {
          updated = true;
          int pos = compilation_result.outputs.size() + resource_update_pos;
          xla::Shape shape = xla::ShapeUtil::GetTupleElementShape(
              compilation_result.xla_output_shape, pos);
          auto add_to_core = [&](int64 core, const xla::Shape& per_core_shape) {
            (*per_core_output_shapes)[core].push_back(per_core_shape);
            (*may_modify_variables)[core] =
                (*may_modify_variables)[core] || update.modified;
          };
          if (sharding.type() == xla::OpSharding::MAXIMAL) {
            add_to_core(sharding.tile_assignment_devices(0), shape);
          } else if (sharding.type() == xla::OpSharding::OTHER) {
            auto sharding_or =
                xla::HloSharding::FromProto(proto_arg.sharding());
            TF_RET_CHECK(sharding_or.ok());
            for (int64 core : proto_arg.sharding().tile_assignment_devices()) {
              xla::Shape per_core_shape =
                  GetPerDeviceShape(shape, sharding_or.ValueOrDie(), core);
              add_to_core(core, per_core_shape);
            }
          } else {
            TF_RET_CHECK(sharding.type() == xla::OpSharding::REPLICATED);
            for (int64 core = 0; core < metadata.num_cores_per_replica();
                 ++core) {
              add_to_core(core, shape);
            }
          }
          ++resource_update_pos;
        }
      }
      if (sharding.type() == xla::OpSharding::MAXIMAL) {
        (*per_core_variable_indices)[sharding.tile_assignment_devices(0)]
            .push_back(
                std::pair<int, bool>(arg_core_mapping[i].indices[0], updated));
      } else if (sharding.type() == xla::OpSharding::OTHER) {
        for (int core : sharding.tile_assignment_devices()) {
          (*per_core_variable_indices)[core].push_back(
              std::pair<int, bool>(arg_core_mapping[i].indices[core], updated));
        }
      } else {
        TF_RET_CHECK(sharding.type() == xla::OpSharding::REPLICATED);
        for (int64 core = 0; core < metadata.num_cores_per_replica(); ++core) {
          (*per_core_variable_indices)[core].push_back(
              std::pair<int, bool>(arg_core_mapping[i].indices[core], updated));
        }
      }
    }
  }
  return Status::OK();
}

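// Appends each return value's shape to the output-shape list of every core
// that produces it, according to the retval's sharding.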
Status ComputeOutputShapesForEachCore(
    const tpu::TPUCompileMetadataProto& metadata,
    const XlaCompiler::CompilationResult& compilation_result,
    std::vector<std::vector<xla::Shape>>* per_core_output_shapes) {
  for (int i = 0; i < metadata.retvals_size(); ++i) {
    const tpu::TPUCompileMetadataProto::Retval& retval = metadata.retvals(i);
    TF_RET_CHECK(!compilation_result.outputs[i].is_constant)
        << "TPU compilation output " << i
        << " has a compile-time constant value. "
           "This should never happen.";

    xla::Shape shape = xla::ShapeUtil::GetTupleElementShape(
        compilation_result.xla_output_shape, i);
    auto add_shape_to_core = [&](int core, xla::Shape per_core_shape) {
      (*per_core_output_shapes)[core].push_back(std::move(per_core_shape));
    };
    if (retval.sharding().type() == xla::OpSharding::MAXIMAL) {
      add_shape_to_core(retval.sharding().tile_assignment_devices(0),
                        std::move(shape));
    } else if (retval.sharding().type() == xla::OpSharding::OTHER) {
      auto sharding_or = xla::HloSharding::FromProto(retval.sharding());
      TF_RET_CHECK(sharding_or.ok());
      for (int64 core : retval.sharding().tile_assignment_devices()) {
        xla::Shape per_core_shape =
            GetPerDeviceShape(shape, sharding_or.ValueOrDie(), core);
        add_shape_to_core(core, std::move(per_core_shape));
      }
    } else {
      TF_RET_CHECK(retval.sharding().type() == xla::OpSharding::REPLICATED)
          << "Not all of the constant tensors were consumed.";
      for (int core = 0; core < per_core_output_shapes->size(); ++core) {
        add_shape_to_core(core, shape);
      }
    }
  }
  return Status::OK();
}

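// Rebuilds an xla::HloModule from the compiled computation proto, using a
// module config derived from `metadata` and `device_assignment`, dumps it if
// dumping is enabled, and appends it to `hlo_modules`.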
Status CreateHloModules(
    const TPUCompileMetadataProto& metadata,
    const tensorflow::XlaCompiler::CompilationResult& compilation_result,
    const absl::optional<xla::DeviceAssignment>& device_assignment,
    std::vector<std::unique_ptr<xla::HloModule>>* hlo_modules) {
  TF_RET_CHECK(
      compilation_result.computation->proto().has_host_program_shape());

  auto debug_options = xla::DebugOptions();
  debug_options.set_xla_step_marker_location(metadata.step_marker_location());
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<xla::HloModuleConfig> module_config,
      CreateModuleConfig(
          xla::ProgramShape(
              compilation_result.computation->proto().host_program_shape()),
          compilation_result.xla_input_shapes,
          compilation_result.xla_output_shape, device_assignment,
          metadata.num_replicas(), metadata.num_cores_per_replica(),
          &debug_options));

  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<xla::HloModule> hlo_module,
      xla::HloModule::CreateFromProto(compilation_result.computation->proto(),
                                      *module_config));
  DumpHloModuleIfEnabled(*hlo_module, "before_optimizations");
  hlo_modules->push_back(std::move(hlo_module));

  return Status::OK();
}

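// Packages the computation (either an MLIR module or a function library),
// the compile metadata, and the argument shapes into a
// TpuCompilationRequestProto.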
StatusOr<TpuCompilationRequestProto> CreateTpuCompilationRequest(
    const absl::variant<MlirToHloArgs, FunctionToHloArgs>& computation,
    const TPUCompileMetadataProto& metadata,
    const std::vector<TensorShape>& arg_shapes) {
  VLOG(1) << "CreateTpuCompilationRequest.";
  TpuCompilationRequestProto compilation_request;
  bool use_mlir = computation.index() == 0;
  compilation_request.set_use_mlir(use_mlir);
  if (use_mlir) {
    VLOG(1) << "Serializing MlirModule";
    const MlirToHloArgs& mlir_computation = absl::get<0>(computation);
    *compilation_request.mutable_mlir_module() = mlir_computation.mlir_module;
  } else {
    VLOG(1) << "Serializing FunctionDefinitionLibrary";
    const FunctionToHloArgs& function_computation = absl::get<1>(computation);
    *compilation_request.mutable_fdef_lib() =
        function_computation.flib_def->ToProto();
    compilation_request.set_graph_def_version(
        function_computation.graph_def_version);
    *compilation_request.mutable_function() = *function_computation.function;
    // TODO(b/160937500): serializing and copying large guaranteed_constants
    // can be a performance hit. There is future work to refactor the
    // compilation layer to avoid passing guaranteed_constants over C_API.
    if (function_computation.guaranteed_constants.index() == 0) {
      absl::Span<const TensorProto* const> guaranteed_constants =
          absl::get<0>(function_computation.guaranteed_constants);
      for (const TensorProto* constant : guaranteed_constants) {
        *compilation_request.add_guaranteed_constants() = *constant;
      }
    } else {
      CHECK_EQ(function_computation.guaranteed_constants.index(), 1);
      const OpInputList& guaranteed_constants =
          *absl::get<1>(function_computation.guaranteed_constants);
      for (const Tensor& constant : guaranteed_constants) {
        constant.AsProtoTensorContent(
            compilation_request.add_guaranteed_constants());
      }
    }
  }

  for (const TensorShape& shape : arg_shapes) {
    shape.AsProto(compilation_request.add_arg_shapes());
  }

  *(compilation_request.mutable_metadata()) = metadata;

  VLOG(1) << "TpuCompilationRequest:\n" << compilation_request.DebugString();
  return compilation_request;
}

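// Parses the TPUCompileMetadataProto (and, when requested, the function name
// and MLIR module) from the op's attributes and validates it against
// `num_computations` and the serialized device assignment.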
Status CompileOpMetadataFromContext(OpKernelConstruction* ctx,
                                    TPUCompileMetadataProto* metadata,
                                    NameAttrList* function_name,
                                    std::string* mlir_module) {
  CHECK_NE(metadata, nullptr);

  int num_computations;
  TF_RETURN_IF_ERROR(ctx->GetAttr("num_computations", &num_computations));

  std::string metadata_string;
  TF_RETURN_IF_ERROR(ctx->GetAttr("metadata", &metadata_string));
  if (!metadata->ParsePartialFromString(metadata_string)) {
    return errors::InvalidArgument("Unable to parse TPUCompileMetadataProto");
  }

  if (function_name != nullptr) {
    TF_RETURN_IF_ERROR(ctx->GetAttr("function", function_name));
  }

  if (mlir_module != nullptr) {
    TF_RETURN_IF_ERROR(ctx->GetAttr("mlir_module", mlir_module));
  }

  if (num_computations != metadata->num_cores_per_replica()) {
    return errors::InvalidArgument(
        "num_computations must be equal to "
        "num_cores_per_replica in the 'metadata' "
        "attribute (",
        num_computations, " vs ", metadata->num_cores_per_replica(), ")");
  }

  if (metadata->has_device_assignment()) {
    StatusOr<std::unique_ptr<DeviceAssignment>> device_assignment_or_error =
        DeviceAssignment::Deserialize(metadata->device_assignment());
    TF_RETURN_IF_ERROR(device_assignment_or_error.status());
    const DeviceAssignment& device_assignment =
        *device_assignment_or_error.ValueOrDie();
    const int num_replicas = metadata->num_replicas();
    if (device_assignment.replica_count() != num_replicas) {
      return errors::InvalidArgument(
          "Device assignment replica_count != num_replicas; ",
          device_assignment.replica_count(), " vs ", num_replicas);
    }
    if (device_assignment.computation_count() !=
        metadata->num_cores_per_replica()) {
      return errors::InvalidArgument(
          "Device assignment computation_count != num_cores_per_replica; ",
          device_assignment.computation_count(), " vs ",
          metadata->num_cores_per_replica());
    }
  }
  return Status::OK();
}

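// Resolves the final shape of every non-constant argument: fully defined
// static shapes come from `metadata`, while partially defined shapes are
// filled in from `dynamic_shapes` and checked for compatibility.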
Status ComputeArgumentShapes(const tpu::TPUCompileMetadataProto& metadata,
                             const std::vector<TensorShape>& dynamic_shapes,
                             std::vector<TensorShape>* arg_shapes) {
  arg_shapes->resize(metadata.args_size());
  int dynamic_shape_pos = 0;
  for (int i = 0; i < metadata.args_size(); ++i) {
    const tpu::TPUCompileMetadataProto::Arg& arg = metadata.args(i);
    // The XLA compiler determines the shape of each constant by inspecting
    // the value of its corresponding host-memory tensor. As a result, we
    // don't need to give the compiler graph-inferred shapes for constant
    // arguments.
    if (arg.kind() == tpu::TPUCompileMetadataProto::Arg::GUARANTEED_CONSTANT) {
      continue;
    }
    TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(arg.shape()));
    PartialTensorShape static_shape(arg.shape());

    TensorShape& shape = (*arg_shapes)[i];
    if (static_shape.IsFullyDefined()) {
      TF_RET_CHECK(static_shape.AsTensorShape(&shape));
    } else {
      TF_RET_CHECK(dynamic_shape_pos < dynamic_shapes.size())
          << "Too few dynamic shapes";
      shape = dynamic_shapes[dynamic_shape_pos++];
      if (!static_shape.IsCompatibleWith(shape)) {
        return errors::InvalidArgument(
            "Mismatch between static and dynamic shape for argument. Static "
            "shape: ",
            static_shape.DebugString(),
            "; dynamic shape: ", shape.DebugString());
      }
    }
  }
  // Check that we consumed all of the dynamic shapes.
  TF_RET_CHECK(dynamic_shape_pos == dynamic_shapes.size())
      << "Too many dynamic shapes";
  return Status::OK();
}
}  // namespace tpu
}  // namespace tensorflow