/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_CONFIG_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_CONFIG_H_

#include <string>
#include <vector>

#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/debug_options_flags.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {

enum class FusionConfigCollection {
  kOff,      // Do not collect configuration.
  kPerEdge,  // Collect per-edge configuration.
  kPerNode,  // Collect per-node configuration.
};

// This class gathers all settings and values which affect the compiled
// executable outside of the HLO code itself. These include layouts of inputs
// and outputs to the module and settings such as HLO profiling. Together the
// HloModule and HloModuleConfig unambiguously determine a particular
// executable.
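//
// A minimal usage sketch (illustrative only; the shape and option values are
// assumptions, not taken from any real caller):
//
//   ProgramShape program_shape = ...;        // shape of the entry computation
//   HloModuleConfig config(program_shape);   // layouts reset to default
//   config.set_replica_count(2);             // data-parallel replicas
//   config.set_seed(42);                     // module-level RNG seed
//   config.set_debug_options(GetDebugOptionsFromFlags());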
class HloModuleConfig {
 public:
  // Represents a pair of input and output of the entry computation that can be
  // considered as the original and updated values of a variable maintained by
  // the caller, and that can be transparently sharded by XLA as an internal
  // optimization. If sharded, XLA will create separate sharding/unsharding
  // programs, and the caller is responsible for calling the XLA-generated
  // sharding/unsharding programs before and after the sharded main program.
  //
  // If the variable is not updated and there is no corresponding output, use
  // {-1} as the output_shape_index.
  //
  // The sharding/unsharding programs will include all the input/output pairs in
  // shardable_value_update_pairs() as a flat tuple in their inputs/outputs,
  // sorted by (input_parameter_number, parameter_shape_index).
  //
  // A typical usage pattern is to shard the variables first, then repeatedly
  // invoke the main program, and finally invoke the unsharding program before
  // they are used in full-shape.
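  //
  // For example (hypothetical values): a variable passed as element {0} of
  // parameter 1 and returned as element {1} of the entry computation's output
  // tuple would be described as
  //   ShardableValueUpdatePair{/*input_parameter_number=*/1,
  //                            /*parameter_shape_index=*/{0},
  //                            /*output_shape_index=*/{1}};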
  struct ShardableValueUpdatePair {
    int64 input_parameter_number;
    ShapeIndex parameter_shape_index;
    ShapeIndex output_shape_index;
  };

  // A configuration can be created either with or without an entry
  // ComputationLayout. The default ctor creates it without -- in this case
  // accessing entry_computation_layout will CHECK-fail. The ctor accepting a
  // ProgramShape creates a computation layout using this shape.
  // The layouts in the ProgramShape will be reset to default unless
  // ignore_layouts is set to false.
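  //
  // Illustrative sketch (not from any real call site): a default-constructed
  // config must be guarded before reading the entry layout, while a config
  // built from a ProgramShape can be read directly.
  //
  //   HloModuleConfig empty_config;
  //   if (empty_config.has_entry_computation_layout()) {  // false here
  //     const ComputationLayout& layout = empty_config.entry_computation_layout();
  //   }
  //   HloModuleConfig shaped_config(some_program_shape);  // layout is set
  //   const ComputationLayout& layout = shaped_config.entry_computation_layout();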
  HloModuleConfig() { debug_options_ = DefaultDebugOptionsIgnoringFlags(); }

  explicit HloModuleConfig(const ProgramShape& program_shape,
                           bool ignore_layouts = true);

  explicit HloModuleConfig(ComputationLayout entry_computation_layout);

  // Checks if this config has an entry computation layout already.
  bool has_entry_computation_layout() const {
    return entry_computation_layout_.has_value();
  }

  // Sets the entry_computation_layout's parameter and result shapes for this
  // config, according to the given program shape. The parameters and result
  // are set to default layout.
  void SetDefaultComputationLayout(const ProgramShape& program_shape);

  // Same as above but if the given program contains layout for parameters or
  // result, the entry_computation_layout's layout is updated accordingly.
  void SetComputationLayoutIfExists(const ProgramShape& program_shape);

  // Returns a constant reference to the layout of the entry computation.
  // Assumes the layout was set.
  const ComputationLayout& entry_computation_layout() const {
    CHECK(entry_computation_layout_.has_value());
    return *entry_computation_layout_;
  }

  // Returns a mutable pointer to the layout of the entry computation.
  // Assumes the layout was set.
  ComputationLayout* mutable_entry_computation_layout() {
    CHECK(entry_computation_layout_.has_value());
    return &(*entry_computation_layout_);
  }

  // Returns whether to enable HLO-level profiling.
  bool hlo_profiling_enabled() const {
    return debug_options_.xla_hlo_profile();
  }

  bool cpu_traceme_enabled() const {
    return debug_options_.xla_cpu_enable_xprof_traceme();
  }

  // Sets/returns the module seed set during execution.
  void set_seed(uint64 seed) { seed_ = seed; }
  uint64 seed() const { return seed_; }

  // Set the launch id of the program. Launch id identifies a set of programs
  // that should be launched together.
  void set_launch_id(uint64 launch_id) { launch_id_ = launch_id; }

  int32 launch_id() const { return launch_id_; }

  void set_replica_count(int64_t replica_count) {
    replica_count_ = replica_count;
  }
  int64 replica_count() const { return replica_count_; }

  void set_num_partitions(int64_t num_partitions) {
    num_partitions_ = num_partitions;
  }
  int64 num_partitions() const { return num_partitions_; }

  const std::vector<bool> param_requires_broadcast_via_collectives() const {
    return param_requires_broadcast_via_collectives_;
  }
  void set_param_requires_broadcast_via_collectives(
      std::vector<bool> require_broadcast) {
    param_requires_broadcast_via_collectives_ = std::move(require_broadcast);
  }

  void set_use_spmd_partitioning(bool use_spmd_partitioning) {
    use_spmd_partitioning_ = use_spmd_partitioning;
  }
  bool use_spmd_partitioning() const { return use_spmd_partitioning_; }

  // If enabled, deduplicate equivalent hlos into function calls to reduce code
  // size.
  void set_deduplicate_hlo(bool deduplicate_hlo) {
    deduplicate_hlo_ = deduplicate_hlo;
  }

  void set_device_type(const string& device_type) {
    device_type_ = device_type;
  }

  bool deduplicate_hlo() const { return deduplicate_hlo_; }

  // Return a string which unambiguously represents all the fields of this data
  // structure. Used for generating a cache key for storing the compiled
  // executable.
  string compilation_cache_key() const;

  string device_type() const { return device_type_; }

  const DebugOptions& debug_options() const { return debug_options_; }

  void set_debug_options(const DebugOptions& debug_options) {
    debug_options_ = debug_options;
  }

  // Sets/returns the number of intra op threads for this module.
  void set_intra_op_parallelism_threads(
      const int intra_op_parallelism_threads) {
    intra_op_parallelism_threads_ = intra_op_parallelism_threads;
  }
  int64 intra_op_parallelism_threads() const {
    return intra_op_parallelism_threads_;
  }

  // Checks if this config has a static device assignment.
  bool has_static_device_assignment() const {
    return static_device_assignment_.has_value();
  }

  // Getter and setter of the compile-time known device assignment.
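  //
  // Illustrative sketch (the replica and device ids are arbitrary assumptions,
  // not from any real deployment):
  //   DeviceAssignment assignment(/*replica_count=*/2, /*computation_count=*/1);
  //   assignment(0, 0) = 0;  // replica 0 runs on device 0
  //   assignment(1, 0) = 1;  // replica 1 runs on device 1
  //   config.set_static_device_assignment(assignment);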
  const DeviceAssignment& static_device_assignment() const {
    CHECK(static_device_assignment_.has_value());
    return *static_device_assignment_;
  }
  void set_static_device_assignment(const DeviceAssignment& device_assignment) {
    static_device_assignment_ = device_assignment;
  }

  const std::vector<ShardableValueUpdatePair> shardable_value_update_pairs()
      const {
    return shardable_value_update_pairs_;
  }
  void set_shardable_value_update_pairs(
      std::vector<ShardableValueUpdatePair> pairs) {
    shardable_value_update_pairs_ = std::move(pairs);
  }

  // Whether input and output buffers are aliased if the associated parameter
  // is passed through XLA modules without being changed.
  bool alias_passthrough_params() const { return alias_passthrough_params_; }
  void set_alias_passthrough_params(bool alias_passthrough_params) {
    alias_passthrough_params_ = alias_passthrough_params;
  }

  bool content_aware_computation_sorting() const {
    return content_aware_computation_sorting_;
  }
  void set_content_aware_computation_sorting(
      bool content_aware_computation_sorting) {
    content_aware_computation_sorting_ = content_aware_computation_sorting;
  }

  FusionConfigCollection fusion_config_collection() const {
    return fusion_config_collection_;
  }
  void set_fusion_config_collection(
      FusionConfigCollection fusion_config_collection) {
    fusion_config_collection_ = fusion_config_collection;
  }

  const std::vector<std::vector<bool>>& fusion_config() const {
    return fusion_config_;
  }
  std::vector<std::vector<bool>>* mutable_fusion_config() {
    return &fusion_config_;
  }

  const std::vector<std::vector<int64>>& dot_config() const {
    return dot_config_;
  }

  std::vector<std::vector<int64>>* mutable_dot_config() { return &dot_config_; }

  const std::vector<std::vector<std::vector<int64>>>& layout_config() const {
    return layout_config_;
  }

  std::vector<std::vector<std::vector<int64>>>* mutable_layout_config() {
    return &layout_config_;
  }

  const std::vector<std::vector<bool>>& phase_ordering_config() const {
    return phase_ordering_config_;
  }

  std::vector<std::vector<bool>>* mutable_phase_ordering_config() {
    return &phase_ordering_config_;
  }

  const int phase_index() const { return phase_index_; }
  void set_phase_index(const int phase_index) { phase_index_ = phase_index; }

 private:
  // If you add new members, be sure to update compilation_cache_key.

  absl::optional<ComputationLayout> entry_computation_layout_;

  // Module/graph-level seed handle.
  uint64 seed_ = 0;

  // Program id that identifies a set of programs to be launched together.
  int32 launch_id_ = 0;

  // The number of replicas (data parallelism) to compile this binary for.
  int64 replica_count_ = 1;

  // The number of partitions (model parallelism) to compile this binary for.
  int64 num_partitions_ = 1;

  // Whether to broadcast args across all replicas. One entry per arg.
  std::vector<bool> param_requires_broadcast_via_collectives_;

  // Whether to use SPMD (true) or MPMD (false) when num_partitions_ > 0 and XLA
  // needs to partition the module.
  bool use_spmd_partitioning_ = false;

  // If enabled, deduplicate equivalent hlos into function calls to reduce code
  // size.
  bool deduplicate_hlo_ = false;

  // The target maximum parallelism at which to partition HLOs for parallel
  // execution on the CPU backend.
  int64 intra_op_parallelism_threads_ = -1;

  string device_type_;

  DebugOptions debug_options_;

  // Compile-time known device assignment.
  absl::optional<DeviceAssignment> static_device_assignment_;

  std::vector<ShardableValueUpdatePair> shardable_value_update_pairs_;

  bool alias_passthrough_params_ = false;

  bool content_aware_computation_sorting_ = true;

  FusionConfigCollection fusion_config_collection_ =
      FusionConfigCollection::kOff;

  // TODO(b/155665133): Consolidate fusion, dot, and layout config into a proto
  // similar to backend config.

  // Custom fusion configuration, where fusion_config_[c][v] controls whether
  // node v in computation c must be fused to all its consumers (true) or not
  // (false).
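  //
  // For instance (hypothetical values, purely to show the indexing): with two
  // computations, forcing only the second node of computation 0 to fuse into
  // all its consumers would look like
  //   fusion_config_ = {{false, true, false}, {false, false}};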
  std::vector<std::vector<bool>> fusion_config_;

  // Custom dot canonicalization configuration, where dot_config_[v] controls
  // how to convert dot operation v (sorted topologically and by computation)
  // to convolution.
  std::vector<std::vector<int64>> dot_config_;

  // Layout configuration, where layout_config_[v][i] controls the layout
  // decision i of operation v.
  std::vector<std::vector<std::vector<int64>>> layout_config_;

  // Phase ordering configuration, where phase_ordering_config[v][i] controls
  // whether a specific pass with index i (e.g. 0 = DCE, 1 = CSE, etc.) is
  // inserted after pass v in the pipeline. See tuning::PhaseOrderingConfig for
  // details on what indices (i) correspond to which passes.
  std::vector<std::vector<bool>> phase_ordering_config_;
  // Index (v) corresponding to the current passes being added for phase
  // ordering. This variable stores state that lets us use the same config
  // across functions during compilation.
  int phase_index_ = 0;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_CONFIG_H_