/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_SUPPORT_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_SUPPORT_H_

#include <string>
#include <vector>

#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/hlo_module_group.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compile.pb.h"

namespace tensorflow {
namespace tpu {

namespace se = ::stream_executor;

// Parameters for lowering MLIR to HLO IR.
struct MlirToHloArgs {
  const std::string& mlir_module;
};

// Variant of guaranteed constant tensor types.
using GuaranteedConsts = absl::variant<absl::Span<const TensorProto* const>,
                                       const OpInputList* const>;

// Parameters for lowering a function library definition to HLO IR.
struct FunctionToHloArgs {
  const NameAttrList* const function;
  const FunctionLibraryDefinition* const flib_def;
  int graph_def_version;
  GuaranteedConsts guaranteed_constants;
};

// Persistent cache for compiled TPU programs and the related compiler
// metadata, intended for TPU inference.
// TODO(henrytan): there is an opportunity to consolidate the interface with the
// `TpuCompilationCacheInterface` once `TpuPersistentCompilationCache` is
// converted into a ref count based class.
class TpuPersistentCompilationCacheInterface {
 public:
  virtual ~TpuPersistentCompilationCacheInterface() = default;

  // Returns the location where cache entries are stored.
  virtual std::string cache_location() const = 0;
};

// Describes the position of an argument or return value after the computation
// has been partitioned into cores.
struct ShardingAndIndex {
  // Sharding across cores.
  ::xla::OpSharding sharding;
  // Argument/return value number. If the sharding is single-core, `indices`
  // has a single element; otherwise, it has num_cores elements.
  std::vector<int> indices;
};

// TODO(b/158279168): Dedup with internal version.
// Returns the per-device shape for a `shape` with a given `sharding`.
xla::Shape GetPerDeviceShape(const xla::Shape& shape,
                             const xla::HloSharding& sharding, int64 device);

stream_executor::port::StatusOr<std::unique_ptr<xla::HloModuleConfig>>
CreateModuleConfig(
    const xla::ProgramShape& program_shape,
    absl::Span<const xla::Shape> argument_shapes,
    absl::optional<const xla::Shape> result_layout,
    absl::optional<const xla::DeviceAssignment> device_assignment,
    int replica_count, int num_partitions,
    const xla::DebugOptions* debug_options, const int* seed,
    const int* launch_id, const bool* alias_passthrough_params,
    const xla::FusionConfigCollection* fusion_config_collection,
    const std::vector<std::vector<bool>>* fusion_config);

stream_executor::port::StatusOr<std::unique_ptr<xla::HloModuleConfig>>
CreateModuleConfig(
    const xla::ProgramShape& program_shape,
    absl::Span<const xla::Shape> argument_shapes,
    absl::optional<const xla::Shape> result_layout,
    absl::optional<const xla::DeviceAssignment> device_assignment,
    int replica_count, int num_partitions,
    const xla::DebugOptions* debug_options);
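// A minimal usage sketch for the shorter `CreateModuleConfig` overload above,
// kept in comment form so it does not affect this header. `program_shape` and
// `argument_shapes` are hypothetical local variables assumed to have been
// derived from the compiled computation; the optional and pointer parameters
// are simply left unset here.
//
//   stream_executor::port::StatusOr<std::unique_ptr<xla::HloModuleConfig>>
//       module_config = CreateModuleConfig(
//           program_shape, argument_shapes,
//           /*result_layout=*/absl::nullopt,
//           /*device_assignment=*/absl::nullopt,
//           /*replica_count=*/1, /*num_partitions=*/1,
//           /*debug_options=*/nullptr);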
xla::ShapeTree<xla::HloSharding> GetSubtree(
    const xla::ShapeTree<xla::HloSharding>& tuple_shape_tree,
    int element_index);

Status AddVariableUpdatesToCores(
    const TPUCompileMetadataProto& metadata,
    const XlaCompiler::CompilationResult& compilation_result,
    const std::vector<ShardingAndIndex>& arg_core_mapping,
    std::vector<bool>* may_modify_variables,
    std::vector<std::vector<xla::Shape>>* per_core_output_shapes,
    std::vector<std::vector<std::pair<int, bool>>>* per_core_variable_indices);

se::port::Status ComputeOutputShapesForEachCore(
    const tpu::TPUCompileMetadataProto& metadata,
    const XlaCompiler::CompilationResult& compilation_result,
    std::vector<std::vector<xla::Shape>>* per_core_output_shapes);

se::port::Status CreateHloModules(
    const TPUCompileMetadataProto& metadata,
    const XlaCompiler::CompilationResult& compilation_result,
    const absl::optional<xla::DeviceAssignment>& device_assignment,
    std::vector<std::unique_ptr<xla::HloModule>>* hlo_modules);

se::port::StatusOr<TpuCompilationRequestProto> CreateTpuCompilationRequest(
    const absl::variant<MlirToHloArgs, FunctionToHloArgs>& computation,
    const TPUCompileMetadataProto& metadata,
    const std::vector<TensorShape>& arg_shapes);

se::port::Status CompileOpMetadataFromContext(OpKernelConstruction* ctx,
                                              TPUCompileMetadataProto* metadata,
                                              NameAttrList* function_name,
                                              std::string* mlir_module);

// Computes the shape of each argument. Uses the static shape from the
// metadata, and the dynamic shapes where the static shape is not fully
// defined. There must be one entry in `dynamic_shapes` for each argument with
// a partially defined shape, in index order.
Status ComputeArgumentShapes(const TPUCompileMetadataProto& metadata,
                             const std::vector<TensorShape>& dynamic_shapes,
                             std::vector<TensorShape>* arg_shapes);

}  // namespace tpu
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_SUPPORT_H_