/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_CONFIG_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_CONFIG_H_

#include <string>

#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {

enum class FusionConfigCollection {
  kOff,      // Do not collect configuration.
  kPerEdge,  // Collect per-edge configuration.
  kPerNode,  // Collect per-node configuration.
};

// This class gathers all settings and values which affect the compiled
// executable outside of the HLO code itself. This includes layouts of inputs
// and outputs to the module and settings such as HLO profiling. Together the
// HloModule and HloModuleConfig unambiguously determine a particular
// executable.
class HloModuleConfig {
 public:
  // Represents a pair of input and output of the entry computation that can
  // be considered as the original and updated values of a variable maintained
  // by the caller, and that can be transparently sharded by XLA as an internal
  // optimization. If sharded, XLA will create separate sharding/unsharding
  // programs, and the caller is responsible for calling the XLA-generated
  // sharding/unsharding programs before and after the sharded main program.
  //
  // If the variable is not updated and there is no corresponding output, use
  // {-1} as the output_shape_index.
  //
  // The sharding/unsharding programs will include all the input/output pairs
  // in shardable_value_update_pairs() as a flat tuple in their
  // inputs/outputs, sorted by (input_parameter_number, parameter_shape_index).
  //
  // A typical usage pattern is to shard the variables first, then repeatedly
  // invoke the main program, and finally invoke the unsharding program before
  // the variables are used in full shape.
  struct ShardableValueUpdatePair {
    int64 input_parameter_number;
    ShapeIndex parameter_shape_index;
    ShapeIndex output_shape_index;
  };
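
  // Illustrative sketch, not part of the original header: a caller-maintained
  // variable passed as entry parameter 1 at tuple index {0}, whose updated
  // value appears at output tuple index {2}, could be described as
  //
  //   HloModuleConfig::ShardableValueUpdatePair pair{
  //       /*input_parameter_number=*/1,
  //       /*parameter_shape_index=*/{0},
  //       /*output_shape_index=*/{2}};
  //   config.set_shardable_value_update_pairs({pair});
  //
  // where `config` is a hypothetical HloModuleConfig instance.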

  // A configuration can be created either with, or without an entry
  // ComputationLayout. The default ctor creates it without -- in this case
  // accessing entry_computation_layout will CHECK-fail. The ctor accepting a
  // ProgramShape creates a computation layout using this shape.
  // The layouts in the ProgramShape will be reset to default unless
  // ignore_layouts is set to false.
  HloModuleConfig() = default;

  explicit HloModuleConfig(const ProgramShape& program_shape,
                           bool ignore_layouts = true);

  explicit HloModuleConfig(ComputationLayout entry_computation_layout);

  // Checks if this config has an entry computation layout already.
  bool has_entry_computation_layout() const {
    return entry_computation_layout_.has_value();
  }

  // Sets the entry_computation_layout's parameter and result shapes for this
  // config, according to the given program shape. The parameters and result
  // are set to default layout.
  void SetDefaultComputationLayout(const ProgramShape& program_shape);

  // Same as above but if the given program contains layout for parameters or
  // result, the entry_computation_layout's layout is updated accordingly.
  void SetComputationLayoutIfExists(const ProgramShape& program_shape);

  // Returns a constant reference to the layout of the entry computation.
  // Assumes the layout was set.
  const ComputationLayout& entry_computation_layout() const {
    CHECK(entry_computation_layout_.has_value());
    return *entry_computation_layout_;
  }

  // Returns a mutable pointer to the layout of the entry computation.
  // Assumes the layout was set.
  ComputationLayout* mutable_entry_computation_layout() {
    CHECK(entry_computation_layout_.has_value());
    return &(*entry_computation_layout_);
  }

  // Returns whether to enable HLO-level profiling.
  bool hlo_profiling_enabled() const {
    return debug_options_.xla_hlo_profile();
  }

  bool cpu_traceme_enabled() const {
    return debug_options_.xla_cpu_enable_xprof_traceme();
  }

  // Sets/returns the module seed set during execution.
  void set_seed(uint64 seed) { seed_ = seed; }
  uint64 seed() const { return seed_; }

  // Set the launch id of the program. Launch id identifies a set of programs
  // that should be launched together.
  void set_launch_id(uint64 launch_id) { launch_id_ = launch_id; }

  int32 launch_id() const { return launch_id_; }

  void set_replica_count(int64 replica_count) {
    replica_count_ = replica_count;
  }
  int64 replica_count() const { return replica_count_; }

  void set_num_partitions(int64 num_partitions) {
    num_partitions_ = num_partitions;
  }
  int64 num_partitions() const { return num_partitions_; }

  void set_broadcast_replicated_params(bool broadcast_replicated_params) {
    broadcast_replicated_params_ = broadcast_replicated_params;
  }
  bool broadcast_replicated_params() const {
    return broadcast_replicated_params_;
  }

  void set_use_spmd_partitioning(bool use_spmd_partitioning) {
    use_spmd_partitioning_ = use_spmd_partitioning;
  }
  bool use_spmd_partitioning() const { return use_spmd_partitioning_; }
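
  // Illustrative sketch, not part of the original header: a module compiled
  // for 8-way data parallelism and 2-way model parallelism with the SPMD
  // partitioner might be configured as
  //
  //   HloModuleConfig config(program_shape);  // given some ProgramShape
  //   config.set_replica_count(8);
  //   config.set_num_partitions(2);
  //   config.set_use_spmd_partitioning(true);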

  // If enabled, deduplicate equivalent hlos into function calls to reduce code
  // size.
  void set_deduplicate_hlo(bool deduplicate_hlo) {
    deduplicate_hlo_ = deduplicate_hlo;
  }
  bool deduplicate_hlo() const { return deduplicate_hlo_; }

  // Return a string which unambiguously represents all the fields of this data
  // structure. Used for generating a cache key for storing the compiled
  // executable.
  string compilation_cache_key() const;

  const DebugOptions& debug_options() const { return debug_options_; }

  void set_debug_options(const DebugOptions& debug_options) {
    debug_options_ = debug_options;
  }

  // Sets/returns the number of intra op threads for this module.
  void set_intra_op_parallelism_threads(
      const int intra_op_parallelism_threads) {
    intra_op_parallelism_threads_ = intra_op_parallelism_threads;
  }
  int64 intra_op_parallelism_threads() const {
    return intra_op_parallelism_threads_;
  }

  // Checks if this config has a static device assignment.
  bool has_static_device_assignment() const {
    return static_device_assignment_.has_value();
  }

  // Getter and setter of the compile-time known device assignment.
  const DeviceAssignment& static_device_assignment() const {
    CHECK(static_device_assignment_.has_value());
    return *static_device_assignment_;
  }
  void set_static_device_assignment(
      const DeviceAssignment& device_assignment) {
    static_device_assignment_ = device_assignment;
  }

  const std::vector<ShardableValueUpdatePair> shardable_value_update_pairs()
      const {
    return shardable_value_update_pairs_;
  }
  void set_shardable_value_update_pairs(
      std::vector<ShardableValueUpdatePair> pairs) {
    shardable_value_update_pairs_ = std::move(pairs);
  }
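
  // Illustrative sketch, not part of the original header: HLO-level profiling
  // is driven by the embedded DebugOptions proto, so it could be toggled as
  //
  //   DebugOptions opts = config.debug_options();
  //   opts.set_xla_hlo_profile(true);
  //   config.set_debug_options(opts);
  //   CHECK(config.hlo_profiling_enabled());
  //
  // where `config` is a hypothetical HloModuleConfig instance and
  // set_xla_hlo_profile is the proto-generated setter for the field read by
  // hlo_profiling_enabled() above.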

  // Whether input and output buffers are aliased if the associated parameter
  // is passed through XLA modules without being changed.
  bool alias_passthrough_params() const { return alias_passthrough_params_; }
  void set_alias_passthrough_params(bool alias_passthrough_params) {
    alias_passthrough_params_ = alias_passthrough_params;
  }

  bool content_aware_computation_sorting() const {
    return content_aware_computation_sorting_;
  }
  void set_content_aware_computation_sorting(
      bool content_aware_computation_sorting) {
    content_aware_computation_sorting_ = content_aware_computation_sorting;
  }

  FusionConfigCollection fusion_config_collection() const {
    return fusion_config_collection_;
  }
  void set_fusion_config_collection(
      FusionConfigCollection fusion_config_collection) {
    fusion_config_collection_ = fusion_config_collection;
  }

  const std::vector<std::vector<bool>>& fusion_config() const {
    return fusion_config_;
  }
  std::vector<std::vector<bool>>* mutable_fusion_config() {
    return &fusion_config_;
  }

  const std::vector<std::vector<int64>>& dot_config() const {
    return dot_config_;
  }

  std::vector<std::vector<int64>>* mutable_dot_config() {
    return &dot_config_;
  }

  const std::vector<std::vector<std::vector<int64>>>& layout_config() const {
    return layout_config_;
  }

  std::vector<std::vector<std::vector<int64>>>* mutable_layout_config() {
    return &layout_config_;
  }
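
  // Illustrative sketch, not part of the original header: a hypothetical
  // caller replaying per-node fusion decisions for a module with a single
  // computation of three instructions might write
  //
  //   config.set_fusion_config_collection(FusionConfigCollection::kPerNode);
  //   *config.mutable_fusion_config() = {{true, false, true}};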

 private:
  // If you add new members, be sure to update compilation_cache_key.

  absl::optional<ComputationLayout> entry_computation_layout_;

  // Module/graph-level seed handle.
  uint64 seed_ = 0;

  // Program id that identifies a set of programs to be launched together.
  int32 launch_id_ = 0;

  // The number of replicas (data parallelism) to compile this binary for.
  int64 replica_count_ = 1;

  // The number of partitions (model parallelism) to compile this binary for.
  int64 num_partitions_ = 1;

  // Whether to use XLA collectives to broadcast params to all replicas.
  bool broadcast_replicated_params_ = false;

  // Whether to use SPMD (true) or MPMD (false) when num_partitions_ > 0 and
  // XLA needs to partition the module.
  bool use_spmd_partitioning_ = false;

  // If enabled, deduplicate equivalent hlos into function calls to reduce code
  // size.
  bool deduplicate_hlo_ = false;

  // The target maximum parallelism at which to partition HLOs for parallel
  // execution on the CPU backend.
  int64 intra_op_parallelism_threads_ = -1;

  DebugOptions debug_options_;

  // Compile-time known device assignment.
  absl::optional<DeviceAssignment> static_device_assignment_;

  std::vector<ShardableValueUpdatePair> shardable_value_update_pairs_;

  bool alias_passthrough_params_ = false;

  bool content_aware_computation_sorting_ = false;

  FusionConfigCollection fusion_config_collection_ =
      FusionConfigCollection::kOff;

  // TODO(b/155665133): Consolidate fusion, dot, and layout config into a proto
  // similar to backend config.

  // Custom fusion configuration, where fusion_config_[c][v] controls whether
  // node v in computation c must be fused to all its consumers (true) or not
  // (false).
  std::vector<std::vector<bool>> fusion_config_;

  // Custom dot canonicalization configuration, where dot_config_[v] controls
  // how to convert dot operation v (sorted topologically and by computation)
  // to convolution.
  std::vector<std::vector<int64>> dot_config_;

  // Layout configuration, where layout_config_[v][i] controls the layout
  // decision i of operation v.
  std::vector<std::vector<std::vector<int64>>> layout_config_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_CONFIG_H_