/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_CONFIG_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_CONFIG_H_

#include <string>

#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {

enum class FusionConfigCollection {
  kOff,      // Do not collect configuration.
  kPerEdge,  // Collect per-edge configuration.
  kPerNode,  // Collect per-node configuration.
};

// This class gathers all settings and values which affect the compiled
// executable outside of the HLO code itself. This includes layouts of inputs
// and outputs to the module and settings such as HLO profiling. Together the
// HloModule and HloModuleConfig unambiguously determine a particular
// executable.
class HloModuleConfig {
 public:
  // Represents a pair of input and output of the entry computation that can
  // be considered as the original and updated values of a variable maintained
  // by the caller, and that can be transparently sharded by XLA as an internal
  // optimization. If sharded, XLA will create separate sharding/unsharding
  // programs, and the caller is responsible for calling the XLA-generated
  // sharding/unsharding programs before and after the sharded main program.
  //
  // If the variable is not updated and there is no corresponding output, use
  // {-1} as the output_shape_index.
  //
  // The sharding/unsharding programs will include all the input/output pairs
  // in shardable_value_update_pairs() as a flat tuple in their
  // inputs/outputs, sorted by (input_parameter_number, parameter_shape_index).
  //
  // A typical usage pattern is to shard the variables first, then repeatedly
  // invoke the main program, and finally invoke the unsharding program before
  // they are used in full-shape.
  struct ShardableValueUpdatePair {
    int64 input_parameter_number;
    ShapeIndex parameter_shape_index;
    ShapeIndex output_shape_index;
  };
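  // A minimal sketch of how a caller might describe such a pair, assuming a
  // variable passed as parameter 2 whose updated value is element 0 of the
  // result tuple (the parameter number, indices, and `config` object are
  // illustrative, not taken from any particular model):
  //
  //   HloModuleConfig::ShardableValueUpdatePair pair;
  //   pair.input_parameter_number = 2;
  //   pair.parameter_shape_index = {};   // The whole parameter.
  //   pair.output_shape_index = {0};     // Element 0 of the result tuple.
  //   config.set_shardable_value_update_pairs({pair});
  //
  // For a variable that is read but never updated, output_shape_index would
  // be {-1}, as described above.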
  // A configuration can be created either with, or without an entry
  // ComputationLayout. The default ctor creates it without -- in this case
  // accessing entry_computation_layout will CHECK-fail. The ctor accepting a
  // ProgramShape creates a computation layout using this shape.
  // The layouts in the ProgramShape will be reset to default unless
  // ignore_layouts is set to false.
  HloModuleConfig() { debug_options_ = DefaultDebugOptionsIgnoringFlags(); }

  explicit HloModuleConfig(const ProgramShape& program_shape,
                           bool ignore_layouts = true);

  explicit HloModuleConfig(ComputationLayout entry_computation_layout);

  // Checks if this config has an entry computation layout already.
  bool has_entry_computation_layout() const {
    return entry_computation_layout_.has_value();
  }

  // Sets the entry_computation_layout's parameter and result shapes for this
  // config, according to the given program shape. The parameters and result
  // are set to default layout.
  void SetDefaultComputationLayout(const ProgramShape& program_shape);

  // Same as above but if the given program contains layout for parameters or
  // result, the entry_computation_layout's layout is updated accordingly.
  void SetComputationLayoutIfExists(const ProgramShape& program_shape);

  // Returns a constant reference to the layout of the entry computation.
  // Assumes the layout was set.
  const ComputationLayout& entry_computation_layout() const {
    CHECK(entry_computation_layout_.has_value());
    return *entry_computation_layout_;
  }

  // Returns a mutable pointer to the layout of the entry computation.
  // Assumes the layout was set.
  ComputationLayout* mutable_entry_computation_layout() {
    CHECK(entry_computation_layout_.has_value());
    return &(*entry_computation_layout_);
  }

  // Returns whether to enable HLO-level profiling.
  bool hlo_profiling_enabled() const {
    return debug_options_.xla_hlo_profile();
  }

  bool cpu_traceme_enabled() const {
    return debug_options_.xla_cpu_enable_xprof_traceme();
  }

  // Sets/returns the module seed set during execution.
  void set_seed(uint64 seed) { seed_ = seed; }
  uint64 seed() const { return seed_; }
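  // A minimal usage sketch of the constructors and accessors above; the
  // program shape is illustrative (it could come from, e.g., an
  // XlaComputation):
  //
  //   ProgramShape program_shape = ...;
  //   HloModuleConfig config(program_shape);
  //   config.set_seed(42);
  //   const ComputationLayout& layout = config.entry_computation_layout();
  //
  // Note that a default-constructed HloModuleConfig has no entry computation
  // layout, so calling entry_computation_layout() on it would CHECK-fail.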
  // Sets the launch id of the program. Launch id identifies a set of programs
  // that should be launched together.
  void set_launch_id(uint64 launch_id) { launch_id_ = launch_id; }

  int32 launch_id() const { return launch_id_; }

  void set_replica_count(int64_t replica_count) {
    replica_count_ = replica_count;
  }
  int64 replica_count() const { return replica_count_; }

  void set_num_partitions(int64_t num_partitions) {
    num_partitions_ = num_partitions;
  }
  int64 num_partitions() const { return num_partitions_; }

  const std::vector<bool> param_requires_broadcast_via_collectives() const {
    return param_requires_broadcast_via_collectives_;
  }
  void set_param_requires_broadcast_via_collectives(
      std::vector<bool> require_broadcast) {
    param_requires_broadcast_via_collectives_ = std::move(require_broadcast);
  }

  void set_use_spmd_partitioning(bool use_spmd_partitioning) {
    use_spmd_partitioning_ = use_spmd_partitioning;
  }
  bool use_spmd_partitioning() const { return use_spmd_partitioning_; }

  // If enabled, deduplicate equivalent hlos into function calls to reduce code
  // size.
  void set_deduplicate_hlo(bool deduplicate_hlo) {
    deduplicate_hlo_ = deduplicate_hlo;
  }

  void set_device_type(const string& device_type) {
    device_type_ = device_type;
  }

  bool deduplicate_hlo() const { return deduplicate_hlo_; }

  // Returns a string which unambiguously represents all the fields of this
  // data structure. Used for generating a cache key for storing the compiled
  // executable.
  string compilation_cache_key() const;

  string device_type() const { return device_type_; }

  const DebugOptions& debug_options() const { return debug_options_; }

  void set_debug_options(const DebugOptions& debug_options) {
    debug_options_ = debug_options;
  }

  // Sets/returns the number of intra op threads for this module.
  void set_intra_op_parallelism_threads(
      const int intra_op_parallelism_threads) {
    intra_op_parallelism_threads_ = intra_op_parallelism_threads;
  }
  int64 intra_op_parallelism_threads() const {
    return intra_op_parallelism_threads_;
  }

  // Checks if this config has a static device assignment.
  bool has_static_device_assignment() const {
    return static_device_assignment_.has_value();
  }
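  // A sketch of how a compile-time device assignment might be installed,
  // assuming two replicas of a single computation placed on devices 0 and 1.
  // The device ids are illustrative, and this assumes the DeviceAssignment
  // constructor from computation_placer.h that takes replica and computation
  // counts:
  //
  //   config.set_replica_count(2);
  //   DeviceAssignment assignment(/*replica_count=*/2,
  //                               /*computation_count=*/1);
  //   assignment(0, 0) = 0;  // Replica 0 runs on device 0.
  //   assignment(1, 0) = 1;  // Replica 1 runs on device 1.
  //   config.set_static_device_assignment(assignment);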
  // Getter and setter of the compile-time known device assignment.
  const DeviceAssignment& static_device_assignment() const {
    CHECK(static_device_assignment_.has_value());
    return *static_device_assignment_;
  }
  void set_static_device_assignment(
      const DeviceAssignment& device_assignment) {
    static_device_assignment_ = device_assignment;
  }

  const std::vector<ShardableValueUpdatePair> shardable_value_update_pairs()
      const {
    return shardable_value_update_pairs_;
  }
  void set_shardable_value_update_pairs(
      std::vector<ShardableValueUpdatePair> pairs) {
    shardable_value_update_pairs_ = std::move(pairs);
  }

  // Whether input and output buffers are aliased if the associated parameter
  // is passed through XLA modules without being changed.
  bool alias_passthrough_params() const { return alias_passthrough_params_; }
  void set_alias_passthrough_params(bool alias_passthrough_params) {
    alias_passthrough_params_ = alias_passthrough_params;
  }

  bool content_aware_computation_sorting() const {
    return content_aware_computation_sorting_;
  }
  void set_content_aware_computation_sorting(
      bool content_aware_computation_sorting) {
    content_aware_computation_sorting_ = content_aware_computation_sorting;
  }

  FusionConfigCollection fusion_config_collection() const {
    return fusion_config_collection_;
  }
  void set_fusion_config_collection(
      FusionConfigCollection fusion_config_collection) {
    fusion_config_collection_ = fusion_config_collection;
  }

  const std::vector<std::vector<bool>>& fusion_config() const {
    return fusion_config_;
  }
  std::vector<std::vector<bool>>* mutable_fusion_config() {
    return &fusion_config_;
  }

  const std::vector<std::vector<int64>>& dot_config() const {
    return dot_config_;
  }

  std::vector<std::vector<int64>>* mutable_dot_config() { return &dot_config_; }

  const std::vector<std::vector<std::vector<int64>>>& layout_config() const {
    return layout_config_;
  }

  std::vector<std::vector<std::vector<int64>>>* mutable_layout_config() {
    return &layout_config_;
  }

  const std::vector<std::vector<bool>>& phase_ordering_config() const {
    return phase_ordering_config_;
  }

  std::vector<std::vector<bool>>* mutable_phase_ordering_config() {
    return &phase_ordering_config_;
  }

  const int phase_index() const { return phase_index_; }
  void set_phase_index(const int phase_index) { phase_index_ = phase_index; }

 private:
  // If you add new members, be sure to update compilation_cache_key.

  absl::optional<ComputationLayout> entry_computation_layout_;

  // Module/graph-level seed handle.
  uint64 seed_ = 0;

  // Program id that identifies a set of programs to be launched together.
  int32 launch_id_ = 0;

  // The number of replicas (data parallelism) to compile this binary for.
  int64 replica_count_ = 1;

  // The number of partitions (model parallelism) to compile this binary for.
  int64 num_partitions_ = 1;

  // Whether to broadcast args across all replicas. One entry per arg.
  std::vector<bool> param_requires_broadcast_via_collectives_;

  // Whether to use SPMD (true) or MPMD (false) when num_partitions_ > 0 and
  // XLA needs to partition the module.
  bool use_spmd_partitioning_ = false;

  // If enabled, deduplicate equivalent hlos into function calls to reduce code
  // size.
  bool deduplicate_hlo_ = false;

  // The target maximum parallelism at which to partition HLOs for parallel
  // execution on the CPU backend.
  int64 intra_op_parallelism_threads_ = -1;

  string device_type_;

  DebugOptions debug_options_;

  // Compile-time known device assignment.
  absl::optional<DeviceAssignment> static_device_assignment_;

  std::vector<ShardableValueUpdatePair> shardable_value_update_pairs_;

  bool alias_passthrough_params_ = false;

  bool content_aware_computation_sorting_ = true;

  FusionConfigCollection fusion_config_collection_ =
      FusionConfigCollection::kOff;

  // TODO(b/155665133): Consolidate fusion, dot, and layout config into a proto
  // similar to backend config.

  // Custom fusion configuration, where fusion_config_[c][v] controls whether
  // node v in computation c must be fused to all its consumers (true) or not
  // (false).
  std::vector<std::vector<bool>> fusion_config_;

  // Custom dot canonicalization configuration, where dot_config_[v] controls
  // how to convert dot operation v (sorted topologically and by computation)
  // to convolution.
  std::vector<std::vector<int64>> dot_config_;

  // Layout configuration, where layout_config_[v][i] controls the layout
  // decision i of operation v.
  std::vector<std::vector<std::vector<int64>>> layout_config_;

  // Phase ordering configuration, where phase_ordering_config[v][i] controls
  // whether a specific pass with index i (e.g. 0 = DCE, 1 = CSE, etc.) is
  // inserted after pass v in the pipeline. See tuning::PhaseOrderingConfig for
  // details on which indices (i) correspond to which passes.
  std::vector<std::vector<bool>> phase_ordering_config_;
  // Index (v) corresponding to the current passes being added for phase
  // ordering. This is the variable that stores state to allow us to use the
  // same config across functions during compilation.
  int phase_index_ = 0;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_CONFIG_H_