/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_SUPPORT_H_
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_SUPPORT_H_

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "absl/types/span.h"
#include "absl/types/variant.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/compiler/tf2xla/xla_compiler.h"
#include "tensorflow/compiler/xla/service/hlo_module_config.h"
#include "tensorflow/compiler/xla/service/hlo_module_group.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/shape.h"
#include "tensorflow/compiler/xla/shape_tree.h"
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compile.pb.h"

namespace tensorflow {
namespace tpu {

namespace se = ::stream_executor;

// List of parameters for lowering MLIR to HLO IR.
struct MlirToHloArgs {
  const std::string& mlir_module;
};

// Variant of guaranteed constant tensor types.
using GuaranteedConsts = absl::variant<absl::Span<const TensorProto* const>,
                                       const OpInputList* const>;

// List of parameters for lowering a function library definition to HLO IR.
struct FunctionToHloArgs {
  const NameAttrList* const function;
  const FunctionLibraryDefinition* const flib_def;
  int graph_def_version;
  GuaranteedConsts guaranteed_constants;
};
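
// Example of populating FunctionToHloArgs (illustrative sketch only; `fn`,
// `flib_def`, and `op_inputs` are hypothetical caller-side values):
//
//   GuaranteedConsts consts = &op_inputs;  // OpInputList holding the
//                                          // guaranteed constants
//   FunctionToHloArgs to_hlo_args{&fn, &flib_def, /*graph_def_version=*/0,
//                                 consts};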

// Persistent cache for compiled TPU programs and the related compiler metadata
// intended for TPU inference.
// TODO(henrytan): there is an opportunity to consolidate the interface with the
// `TpuCompilationCacheInterface` once `TpuPersistentCompilationCache` is
// converted into a ref count based class.
class TpuPersistentCompilationCacheInterface {
 public:
  virtual ~TpuPersistentCompilationCacheInterface() = default;

  // Returns the location where cache entries are stored.
  virtual std::string cache_location() const = 0;
};
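
// Example implementation of the interface (illustrative sketch only;
// `FilePersistentCache` is a hypothetical class, not part of TensorFlow):
//
//   class FilePersistentCache
//       : public TpuPersistentCompilationCacheInterface {
//    public:
//     explicit FilePersistentCache(std::string dir) : dir_(std::move(dir)) {}
//     std::string cache_location() const override { return dir_; }
//
//    private:
//     std::string dir_;
//   };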

// Describes the position of an argument or return value after the computation
// has been partitioned into cores.
struct ShardingAndIndex {
  // Sharding across cores.
  ::xla::OpSharding sharding;
  // Argument/return value number. If sharding is single-core, `indices` has a
  // single element; otherwise, it has num_cores elements.
  std::vector<int> indices;
};

// TODO(b/158279168): Dedup with internal version.
// Returns the per-device shape for a `shape` with a given `sharding`.
xla::Shape GetPerDeviceShape(const xla::Shape& shape,
                             const xla::HloSharding& sharding,
                             int64 device);

stream_executor::port::StatusOr<std::unique_ptr<xla::HloModuleConfig>>
CreateModuleConfig(
    const xla::ProgramShape& program_shape,
    absl::Span<const xla::Shape> argument_shapes,
    absl::optional<const xla::Shape> result_layout,
    absl::optional<const xla::DeviceAssignment> device_assignment,
    int replica_count, int num_partitions,
    const xla::DebugOptions* debug_options, const int* seed,
    const int* launch_id, const bool* alias_passthrough_params,
    const xla::FusionConfigCollection* fusion_config_collection,
    const std::vector<std::vector<bool>>* fusion_config);

stream_executor::port::StatusOr<std::unique_ptr<xla::HloModuleConfig>>
CreateModuleConfig(
    const xla::ProgramShape& program_shape,
    absl::Span<const xla::Shape> argument_shapes,
    absl::optional<const xla::Shape> result_layout,
    absl::optional<const xla::DeviceAssignment> device_assignment,
    int replica_count, int num_partitions,
    const xla::DebugOptions* debug_options);
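
// Example use of the shorter CreateModuleConfig overload (illustrative sketch
// only; `program_shape`, `arg_shapes`, and `debug_options` are hypothetical
// caller-side values):
//
//   auto config_or = CreateModuleConfig(
//       program_shape, arg_shapes, /*result_layout=*/absl::nullopt,
//       /*device_assignment=*/absl::nullopt, /*replica_count=*/1,
//       /*num_partitions=*/1, &debug_options);
//   if (config_or.ok()) {
//     std::unique_ptr<xla::HloModuleConfig> config =
//         std::move(config_or.ValueOrDie());
//   }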

xla::ShapeTree<xla::HloSharding> GetSubtree(
    const xla::ShapeTree<xla::HloSharding>& tuple_shape_tree,
    int element_index);

Status AddVariableUpdatesToCores(
    const TPUCompileMetadataProto& metadata,
    const XlaCompiler::CompilationResult& compilation_result,
    const std::vector<ShardingAndIndex>& arg_core_mapping,
    std::vector<bool>* may_modify_variables,
    std::vector<std::vector<xla::Shape>>* per_core_output_shapes,
    std::vector<std::vector<std::pair<int, bool>>>* per_core_variable_indices);

se::port::Status ComputeOutputShapesForEachCore(
    const tpu::TPUCompileMetadataProto& metadata,
    const XlaCompiler::CompilationResult& compilation_result,
    std::vector<std::vector<xla::Shape>>* per_core_output_shapes);

se::port::Status CreateHloModules(
    const TPUCompileMetadataProto& metadata,
    const XlaCompiler::CompilationResult& compilation_result,
    const absl::optional<xla::DeviceAssignment>& device_assignment,
    std::vector<std::unique_ptr<xla::HloModule>>* hlo_modules);
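
// Example use of CreateHloModules (illustrative sketch only; `metadata`,
// `compilation_result`, and `device_assignment` are hypothetical values from
// an earlier XLA compilation step):
//
//   std::vector<std::unique_ptr<xla::HloModule>> hlo_modules;
//   TF_RETURN_IF_ERROR(CreateHloModules(metadata, compilation_result,
//                                       device_assignment, &hlo_modules));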

se::port::StatusOr<TpuCompilationRequestProto> CreateTpuCompilationRequest(
    const absl::variant<MlirToHloArgs, FunctionToHloArgs>& computation,
    const TPUCompileMetadataProto& metadata,
    const std::vector<TensorShape>& arg_shapes);
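
// Example use of CreateTpuCompilationRequest with an MLIR computation
// (illustrative sketch only; `module_text`, `metadata`, and `arg_shapes` are
// hypothetical caller-side values):
//
//   MlirToHloArgs mlir_args{module_text};
//   absl::variant<MlirToHloArgs, FunctionToHloArgs> computation{mlir_args};
//   auto request_or =
//       CreateTpuCompilationRequest(computation, metadata, arg_shapes);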

se::port::Status CompileOpMetadataFromContext(OpKernelConstruction* ctx,
                                              TPUCompileMetadataProto* metadata,
                                              NameAttrList* function_name,
                                              std::string* mlir_module);

// Computes the shapes of each argument. Uses the static shapes from the
// metadata and the dynamic shapes where the static shapes are not fully
// defined. There must be one dynamic_shape for each argument with a
// partially defined shape, in index order.
Status ComputeArgumentShapes(const TPUCompileMetadataProto& metadata,
                             const std::vector<TensorShape>& dynamic_shapes,
                             std::vector<TensorShape>* arg_shapes);
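
// Example use of ComputeArgumentShapes (illustrative sketch only; `metadata`
// and `dynamic_shapes` are hypothetical caller-side values):
//
//   std::vector<TensorShape> arg_shapes;
//   TF_RETURN_IF_ERROR(
//       ComputeArgumentShapes(metadata, dynamic_shapes, &arg_shapes));
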
}  // namespace tpu
}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_SUPPORT_H_