/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h"

#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
#include "tensorflow/core/tpu/kernels/tpu_executable_info.pb.h"
#include "tensorflow/stream_executor/tpu/proto_helper.h"

namespace tensorflow {
namespace tpu {
using ::stream_executor::port::Status;
using ::stream_executor::port::StatusOr;
using ::xla::ComputationLayout;
using ::xla::DebugOptions;
using ::xla::DeviceAssignment;
using ::xla::HloModuleConfig;
using ::xla::HloSharding;
using ::xla::InvalidArgument;
using ::xla::ProgramShape;
using ::xla::Shape;
using ::xla::ShapeTree;
using ::xla::ShapeUtil;

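// Checks that the client-provided shape used to set the computation's result
// layout is compatible with the actual result shape.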
Status ValidateResultShape(const Shape& client_shape,
                           const Shape& result_shape) {
  TF_RETURN_IF_ERROR(
      xla::ShapeUtil::ValidateShapeWithOptionalLayout(client_shape));
  if (!xla::ShapeUtil::Compatible(client_shape, result_shape)) {
    return InvalidArgument(
        "Shape used to set computation result layout %s is not compatible "
        "with result shape %s",
        xla::ShapeUtil::HumanStringWithLayout(client_shape),
        xla::ShapeUtil::HumanString(result_shape));
  }
  return Status::OK();
}

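// Builds an xla::HloModuleConfig from the program shape and compilation
// options. Pointer parameters are optional: when null, the corresponding
// config field keeps its default (debug options fall back to the flag values).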
StatusOr<std::unique_ptr<HloModuleConfig>> CreateModuleConfig(
    const ProgramShape& program_shape, absl::Span<const Shape> argument_shapes,
    absl::optional<const Shape> result_layout,
    absl::optional<const DeviceAssignment> device_assignment, int replica_count,
    int num_partitions, const DebugOptions* debug_options, const int* seed,
    const int* launch_id, const bool* alias_passthrough_params,
    const xla::FusionConfigCollection* fusion_config_collection,
    const std::vector<std::vector<bool>>* fusion_config) {
  auto config = absl::make_unique<HloModuleConfig>(program_shape);
  ComputationLayout* computation_layout =
      config->mutable_entry_computation_layout();
  if (program_shape.parameters_size() != argument_shapes.size()) {
    return InvalidArgument("computation takes %d parameters, but %u given",
                           program_shape.parameters_size(),
                           argument_shapes.size());
  }
  for (int i = 0; i < argument_shapes.size(); ++i) {
    // Verify that the shape of each argument matches the shape of the
    // corresponding parameter in the ProgramShape.
    if (!ShapeUtil::Compatible(argument_shapes[i],
                               program_shape.parameters(i))) {
      return InvalidArgument(
          "Argument does not match shape of computation parameter %d: want "
          "%s, got %s",
          i, ShapeUtil::HumanString(program_shape.parameters(i)),
          ShapeUtil::HumanString(argument_shapes[i]));
    }
    TF_RETURN_IF_ERROR(
        computation_layout->mutable_parameter_layout(i)->CopyLayoutFromShape(
            argument_shapes[i]));
  }

  if (result_layout.has_value()) {
    TF_RETURN_IF_ERROR(
        ValidateResultShape(result_layout.value(), program_shape.result()));
    TF_RETURN_IF_ERROR(
        computation_layout->mutable_result_layout()->CopyLayoutFromShape(
            result_layout.value()));
  } else {
    // If the result layout is not set, then choose the default.
    computation_layout->mutable_result_layout()->SetToDefaultLayout();
  }

  config->set_replica_count(replica_count);
  config->set_num_partitions(num_partitions);
  if (seed != nullptr) {
    config->set_seed(*seed);
  }
  if (launch_id != nullptr) {
    config->set_launch_id(*launch_id);
  }
  if (debug_options != nullptr) {
    config->set_debug_options(*debug_options);
  } else {
    config->set_debug_options(xla::GetDebugOptionsFromFlags());
  }

  // TODO(henrytan): set intra_op_parallelism_threads.
  // Reference:
  // tensorflow/compiler/xla/service/service.cc?l=324.

  if (device_assignment.has_value()) {
    config->set_static_device_assignment(device_assignment.value());
  }

  if (alias_passthrough_params != nullptr) {
    config->set_alias_passthrough_params(*alias_passthrough_params);
  }

  if (fusion_config_collection != nullptr && fusion_config != nullptr &&
      *fusion_config_collection != xla::FusionConfigCollection::kOff) {
    config->set_fusion_config_collection(*fusion_config_collection);
    *config->mutable_fusion_config() = *fusion_config;
  }

  return std::move(config);
}

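// Convenience overload that forwards to the full CreateModuleConfig above with
// all optional compilation knobs left unset.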
StatusOr<std::unique_ptr<xla::HloModuleConfig>> CreateModuleConfig(
    const xla::ProgramShape& program_shape,
    absl::Span<const Shape> argument_shapes,
    absl::optional<const Shape> result_layout,
    absl::optional<const DeviceAssignment> device_assignment, int replica_count,
    int num_partitions, const DebugOptions* debug_options) {
  return CreateModuleConfig(program_shape, argument_shapes, result_layout,
                            device_assignment, replica_count, num_partitions,
                            debug_options, /*seed=*/nullptr,
                            /*launch_id=*/nullptr,
                            /*alias_passthrough_params=*/nullptr,
                            /*fusion_config_collection=*/nullptr,
                            /*fusion_config=*/nullptr);
}

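// Returns the sharding ShapeTree for tuple element `element_index`, copied out
// of the tuple's overall sharding tree.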
ShapeTree<HloSharding> GetSubtree(
    const ShapeTree<HloSharding>& tuple_shape_tree, int element_index) {
  ShapeTree<HloSharding> element_shape_tree(
      xla::ShapeUtil::GetTupleElementShape(tuple_shape_tree.shape(),
                                           element_index),
      HloSharding::Replicate());

  xla::ShapeIndex src_index;
  src_index.push_back(element_index);
  element_shape_tree.CopySubtreeFrom(tuple_shape_tree, src_index, {});
  return element_shape_tree;
}

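// Computes the portion of `shape` assigned to `device` under `sharding`:
// recurses through tuples, returns the full shape for tile-maximal shardings,
// and otherwise returns the device's tile extent (limit - offset) in each
// dimension.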
Shape GetPerDeviceShape(const Shape& shape, const HloSharding& sharding,
                        int64 device) {
  if (shape.IsTuple()) {
    ShapeTree<HloSharding> tuple_shape_tree = sharding.GetAsShapeTree(shape);
    std::vector<Shape> arg_shapes;
    for (int64 i = 0; i < xla::ShapeUtil::TupleElementCount(shape); ++i) {
      Shape element_shape = xla::ShapeUtil::GetTupleElementShape(shape, i);
      HloSharding element_sharding = tuple_shape_tree.element({i});
      if (element_shape.IsTuple()) {
        element_sharding = HloSharding::Tuple(GetSubtree(tuple_shape_tree, i));
      }
      if (element_sharding.UsesDevice(device)) {
        arg_shapes.push_back(
            GetPerDeviceShape(element_shape, element_sharding, device));
      }
    }
    return xla::ShapeUtil::MakeTupleShape(arg_shapes);
  }

  if (sharding.IsTileMaximal()) {
    return shape;
  }

  std::vector<int64> dimensions;
  std::vector<int64> offset = sharding.TileOffsetForDevice(shape, device);
  std::vector<int64> limit = sharding.TileLimitForDevice(shape, device);
  dimensions.resize(limit.size());
  for (int64 i = 0; i < limit.size(); ++i) {
    dimensions[i] = limit[i] - offset[i];
  }
  if (shape.has_layout()) {
    return xla::ShapeUtil::MakeShapeWithLayout(shape.element_type(), dimensions,
                                               shape.layout().minor_to_major());
  }
  return xla::ShapeUtil::MakeShape(shape.element_type(), dimensions);
}

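// For every VARIABLE argument in `metadata`, records which cores receive its
// updated value: appends the (possibly per-device) output shape to
// `per_core_output_shapes`, marks cores whose program may modify the variable
// in `may_modify_variables`, and appends the variable's argument index and
// updated flag to `per_core_variable_indices`.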
Status AddVariableUpdatesToCores(
    const TPUCompileMetadataProto& metadata,
    const XlaCompiler::CompilationResult& compilation_result,
    const std::vector<ShardingAndIndex>& arg_core_mapping,
    std::vector<bool>* may_modify_variables,
    std::vector<std::vector<xla::Shape>>* per_core_output_shapes,
    std::vector<std::vector<std::pair<int, bool>>>* per_core_variable_indices) {
  // Add all variables to the corresponding core.
  may_modify_variables->resize(metadata.num_cores_per_replica(), false);
  int resource_update_pos = 0;
  for (int i = 0; i < metadata.args_size(); ++i) {
    const tpu::TPUCompileMetadataProto::Arg& proto_arg = metadata.args(i);
    if (proto_arg.kind() == tpu::TPUCompileMetadataProto::Arg::VARIABLE) {
      const auto& sharding = proto_arg.sharding();
      bool updated = false;
      if (resource_update_pos < compilation_result.resource_updates.size()) {
        const XlaCompiler::ResourceUpdate& update =
            compilation_result.resource_updates[resource_update_pos];
        if (update.input_index == i) {
          updated = true;
          int pos = compilation_result.outputs.size() + resource_update_pos;
          xla::Shape shape = xla::ShapeUtil::GetTupleElementShape(
              compilation_result.xla_output_shape, pos);
          auto add_to_core = [&](int64 core, const xla::Shape& per_core_shape) {
            (*per_core_output_shapes)[core].push_back(per_core_shape);
            (*may_modify_variables)[core] =
                (*may_modify_variables)[core] || update.modified;
          };
          if (sharding.type() == xla::OpSharding::MAXIMAL) {
            add_to_core(sharding.tile_assignment_devices(0), shape);
          } else if (sharding.type() == xla::OpSharding::OTHER) {
            auto sharding_or =
                xla::HloSharding::FromProto(proto_arg.sharding());
            TF_RET_CHECK(sharding_or.ok());
            for (int64 core : proto_arg.sharding().tile_assignment_devices()) {
              xla::Shape per_core_shape =
                  GetPerDeviceShape(shape, sharding_or.ValueOrDie(), core);
              add_to_core(core, per_core_shape);
            }
          } else {
            TF_RET_CHECK(sharding.type() == xla::OpSharding::REPLICATED);
            for (int64 core = 0; core < metadata.num_cores_per_replica();
                 ++core) {
              add_to_core(core, shape);
            }
          }
          ++resource_update_pos;
        }
      }
      if (sharding.type() == xla::OpSharding::MAXIMAL) {
        (*per_core_variable_indices)[sharding.tile_assignment_devices(0)]
            .push_back(
                std::pair<int, bool>(arg_core_mapping[i].indices[0], updated));
      } else if (sharding.type() == xla::OpSharding::OTHER) {
        for (int core : sharding.tile_assignment_devices()) {
          (*per_core_variable_indices)[core].push_back(
              std::pair<int, bool>(arg_core_mapping[i].indices[core], updated));
        }
      } else {
        TF_RET_CHECK(sharding.type() == xla::OpSharding::REPLICATED);
        for (int64 core = 0; core < metadata.num_cores_per_replica(); ++core) {
          (*per_core_variable_indices)[core].push_back(
              std::pair<int, bool>(arg_core_mapping[i].indices[core], updated));
        }
      }
    }
  }
  return Status::OK();
}

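// Distributes each return value's shape to the cores that produce it,
// according to the retval's sharding (maximal, tiled, or replicated).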
Status ComputeOutputShapesForEachCore(
    const tpu::TPUCompileMetadataProto& metadata,
    const XlaCompiler::CompilationResult& compilation_result,
    std::vector<std::vector<xla::Shape>>* per_core_output_shapes) {
  for (int i = 0; i < metadata.retvals_size(); ++i) {
    const tpu::TPUCompileMetadataProto::Retval& retval = metadata.retvals(i);
    TF_RET_CHECK(!compilation_result.outputs[i].is_constant)
        << "TPU compilation output " << i
        << " has a compile-time constant value. "
           "This should never happen.";

    xla::Shape shape = xla::ShapeUtil::GetTupleElementShape(
        compilation_result.xla_output_shape, i);
    auto add_shape_to_core = [&](int core, xla::Shape per_core_shape) {
      (*per_core_output_shapes)[core].push_back(std::move(per_core_shape));
    };
    if (retval.sharding().type() == xla::OpSharding::MAXIMAL) {
      add_shape_to_core(retval.sharding().tile_assignment_devices(0),
                        std::move(shape));
    } else if (retval.sharding().type() == xla::OpSharding::OTHER) {
      auto sharding_or = xla::HloSharding::FromProto(retval.sharding());
      TF_RET_CHECK(sharding_or.ok());
      for (int64 core : retval.sharding().tile_assignment_devices()) {
        xla::Shape per_core_shape =
            GetPerDeviceShape(shape, sharding_or.ValueOrDie(), core);
        add_shape_to_core(core, std::move(per_core_shape));
      }
    } else {
      TF_RET_CHECK(retval.sharding().type() == xla::OpSharding::REPLICATED)
          << "Not all of the constant tensors were consumed.";
      for (int core = 0; core < per_core_output_shapes->size(); ++core) {
        add_shape_to_core(core, shape);
      }
    }
  }
  return Status::OK();
}

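// Rebuilds an xla::HloModule from the compiled computation's proto, dumping it
// (when dumping is enabled) before optimizations; the module config carries
// the step-marker location from `metadata` in its debug options.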
Status CreateHloModules(
    const TPUCompileMetadataProto& metadata,
    const tensorflow::XlaCompiler::CompilationResult& compilation_result,
    const absl::optional<xla::DeviceAssignment>& device_assignment,
    std::vector<std::unique_ptr<xla::HloModule>>* hlo_modules) {
  TF_RET_CHECK(
      compilation_result.computation->proto().has_host_program_shape());

  auto debug_options = xla::DebugOptions();
  debug_options.set_xla_step_marker_location(metadata.step_marker_location());
  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<xla::HloModuleConfig> module_config,
      CreateModuleConfig(
          xla::ProgramShape(
              compilation_result.computation->proto().host_program_shape()),
          compilation_result.xla_input_shapes,
          compilation_result.xla_output_shape, device_assignment,
          metadata.num_replicas(), metadata.num_cores_per_replica(),
          &debug_options));

  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<xla::HloModule> hlo_module,
      xla::HloModule::CreateFromProto(compilation_result.computation->proto(),
                                      *module_config));
  DumpHloModuleIfEnabled(*hlo_module, "before_optimizations");
  hlo_modules->push_back(std::move(hlo_module));

  return Status::OK();
}

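// Assembles a TpuCompilationRequestProto from either an MLIR module or a
// FunctionDef-based computation, together with the argument shapes and the
// compilation metadata.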
StatusOr<TpuCompilationRequestProto> CreateTpuCompilationRequest(
    const absl::variant<MlirToHloArgs, FunctionToHloArgs>& computation,
    const TPUCompileMetadataProto& metadata,
    const std::vector<TensorShape>& arg_shapes) {
  VLOG(1) << "CreateTpuCompilationRequest.";
  TpuCompilationRequestProto compilation_request;
  bool use_mlir = computation.index() == 0;
  compilation_request.set_use_mlir(use_mlir);
  if (use_mlir) {
    VLOG(1) << "Serializing MlirModule";
    const MlirToHloArgs& mlir_computation = absl::get<0>(computation);
    *compilation_request.mutable_mlir_module() = mlir_computation.mlir_module;
  } else {
    VLOG(1) << "Serializing FunctionDefinitionLibrary";
    const FunctionToHloArgs& function_computation = absl::get<1>(computation);
    *compilation_request.mutable_fdef_lib() =
        function_computation.flib_def->ToProto();
    compilation_request.set_graph_def_version(
        function_computation.graph_def_version);
    *compilation_request.mutable_function() = *function_computation.function;
    // TODO(b/160937500): serializing and copying large guaranteed_constants
    // can be a performance hit. Future work is planned to refactor the
    // compilation layer to avoid passing guaranteed_constants over the C_API.
    if (function_computation.guaranteed_constants.index() == 0) {
      absl::Span<const TensorProto* const> guaranteed_constants =
          absl::get<0>(function_computation.guaranteed_constants);
      for (const TensorProto* constant : guaranteed_constants) {
        *compilation_request.add_guaranteed_constants() = *constant;
      }
    } else {
      CHECK_EQ(function_computation.guaranteed_constants.index(), 1);
      const OpInputList& guaranteed_constants =
          *absl::get<1>(function_computation.guaranteed_constants);
      for (const Tensor& constant : guaranteed_constants) {
        constant.AsProtoTensorContent(
            compilation_request.add_guaranteed_constants());
      }
    }
  }

  for (const TensorShape& shape : arg_shapes) {
    shape.AsProto(compilation_request.add_arg_shapes());
  }

  *(compilation_request.mutable_metadata()) = metadata;

  VLOG(1) << "TpuCompilationRequest:\n" << compilation_request.DebugString();
  return compilation_request;
}

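// Parses the TPUCompile op's construction-time attributes: the serialized
// TPUCompileMetadataProto, and optionally the function name and MLIR module.
// Also validates num_computations and the embedded device assignment against
// the metadata.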
Status CompileOpMetadataFromContext(OpKernelConstruction* ctx,
                                    TPUCompileMetadataProto* metadata,
                                    NameAttrList* function_name,
                                    std::string* mlir_module) {
  CHECK_NE(metadata, nullptr);

  int num_computations;
  TF_RETURN_IF_ERROR(ctx->GetAttr("num_computations", &num_computations));

  std::string metadata_string;
  TF_RETURN_IF_ERROR(ctx->GetAttr("metadata", &metadata_string));
  if (!metadata->ParsePartialFromString(metadata_string)) {
    return errors::InvalidArgument("Unable to parse TPUCompileMetadataProto");
  }

  if (function_name != nullptr) {
    TF_RETURN_IF_ERROR(ctx->GetAttr("function", function_name));
  }

  if (mlir_module != nullptr) {
    TF_RETURN_IF_ERROR(ctx->GetAttr("mlir_module", mlir_module));
  }

  if (num_computations != metadata->num_cores_per_replica()) {
    return errors::InvalidArgument(
        "num_computations must be equal to "
        "num_cores_per_replica in the 'metadata' "
        "attribute (",
        num_computations, " vs ", metadata->num_cores_per_replica(), ")");
  }

  if (metadata->has_device_assignment()) {
    StatusOr<std::unique_ptr<DeviceAssignment>> device_assignment_or_error =
        DeviceAssignment::Deserialize(metadata->device_assignment());
    TF_RETURN_IF_ERROR(device_assignment_or_error.status());
    const DeviceAssignment& device_assignment =
        *device_assignment_or_error.ValueOrDie();
    const int num_replicas = metadata->num_replicas();
    if (device_assignment.replica_count() != num_replicas) {
      return errors::InvalidArgument(
          "Device assignment replica_count != num_replicas; ",
          device_assignment.replica_count(), " vs ", num_replicas);
    }
    if (device_assignment.computation_count() !=
        metadata->num_cores_per_replica()) {
      return errors::InvalidArgument(
          "Device assignment computation_count != num_cores_per_replica; ",
          device_assignment.computation_count(), " vs ",
          metadata->num_cores_per_replica());
    }
  }
  return Status::OK();
}

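// Resolves the compile-time shape of each non-constant argument: uses the
// metadata shape when it is fully defined, and otherwise consumes the next
// entry of `dynamic_shapes`, which must be compatible with the static shape.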
Status ComputeArgumentShapes(const tpu::TPUCompileMetadataProto& metadata,
                             const std::vector<TensorShape>& dynamic_shapes,
                             std::vector<TensorShape>* arg_shapes) {
  arg_shapes->resize(metadata.args_size());
  int dynamic_shape_pos = 0;
  for (int i = 0; i < metadata.args_size(); ++i) {
    const tpu::TPUCompileMetadataProto::Arg& arg = metadata.args(i);
    // The XLA compiler determines the shape of each constant by inspecting the
    // value of its corresponding host-memory tensor. As a result, we don't need
    // to give the compiler graph-inferred shapes for constant arguments.
    if (arg.kind() == tpu::TPUCompileMetadataProto::Arg::GUARANTEED_CONSTANT) {
      continue;
    }
    TF_RETURN_IF_ERROR(PartialTensorShape::IsValidShape(arg.shape()));
    PartialTensorShape static_shape(arg.shape());

    TensorShape& shape = (*arg_shapes)[i];
    if (static_shape.IsFullyDefined()) {
      TF_RET_CHECK(static_shape.AsTensorShape(&shape));
    } else {
      TF_RET_CHECK(dynamic_shape_pos < dynamic_shapes.size())
          << "Too few dynamic shapes";
      shape = dynamic_shapes[dynamic_shape_pos++];
      if (!static_shape.IsCompatibleWith(shape)) {
        return errors::InvalidArgument(
            "Mismatch between static and dynamic shape for argument. Static "
            "shape: ",
            static_shape.DebugString(),
            "; dynamic shape: ", shape.DebugString());
      }
    }
  }
  // Check that we consumed all of the dynamic shapes.
  TF_RET_CHECK(dynamic_shape_pos == dynamic_shapes.size())
      << "Too many dynamic shapes";
  return Status::OK();
}
}  // namespace tpu
}  // namespace tensorflow