• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_CONFIG_H_
17 #define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_CONFIG_H_
18 
19 #include <string>
20 
21 #include "absl/types/optional.h"
22 #include "tensorflow/compiler/xla/service/computation_layout.h"
23 #include "tensorflow/compiler/xla/service/computation_placer.h"
24 #include "tensorflow/compiler/xla/types.h"
25 #include "tensorflow/compiler/xla/xla.pb.h"
26 #include "tensorflow/compiler/xla/xla_data.pb.h"
27 
28 namespace xla {
29 
// Selects whether fusion configuration is collected and, if so, at which
// granularity (per edge or per node).
enum class FusionConfigCollection {
  kOff = 0,      // Do not collect configuration.
  kPerEdge = 1,  // Collect per-edge configuration.
  kPerNode = 2,  // Collect per-node configuration.
};
35 
36 // This class gathers all settings and values which affect the compiled
37 // executable outside of the HLO code itself. This include layouts of inputs and
38 // outputs to the module and settings such as HLO profiling. Together the
39 // HloModule and HloModuleConfig unambiguously determine a particular
40 // executable.
41 class HloModuleConfig {
42  public:
43   // Represents a pair of input and output of the entry computation that can be
44   // considered as the original and updated values of a variable maintained by
45   // the caller, and that can be transparently sharded by XLA as an internal
46   // optimization. If sharded, XLA will create separate sharding/unsharding
47   // programs, and the caller is responsible to call the XLA-generated
48   // sharding/unsharding programs before and after the sharded main program.
49   //
50   // If the variable is not updated and there is not a corresponding output, use
51   // {-1} as the output_shape_index.
52   //
53   // The sharding/unsharding programs will include all the input/output pairs in
54   // shardable_value_update_pairs() as a flat tuple in their inputs/outputs,
55   // sorted by (input_parameter_number, parameter_shape_index).
56   //
57   // A typical usage pattern is to shard the variables first, then repeatedly
58   // invoke the main program, and finally invoke the unsharding program before
59   // they are used in full-shape.
60   struct ShardableValueUpdatePair {
61     int64 input_parameter_number;
62     ShapeIndex parameter_shape_index;
63     ShapeIndex output_shape_index;
64   };
65 
66   // A configuration can be created either with, or without an entry
67   // ComputationLayout. The default ctor creates it without -- in this case
68   // accessing entry_computation_layout will CHECK-fail. The ctor accepting a
69   // ProgramShape creates a computation layout using this shape.
70   // The layouts in the ProgramShape will be reset to default unless
71   // ignore_layouts is set to false.
72   HloModuleConfig() = default;
73 
74   explicit HloModuleConfig(const ProgramShape& program_shape,
75                            bool ignore_layouts = true);
76 
77   explicit HloModuleConfig(ComputationLayout entry_computation_layout);
78 
79   // Checks if this config has an entry computation layout already.
has_entry_computation_layout()80   bool has_entry_computation_layout() const {
81     return entry_computation_layout_.has_value();
82   }
83 
84   // Sets the entry_computation_layout's parameter and result shapes for this
85   // config, according to the given program shape. The parameters and result
86   // are set to default layout.
87   void SetDefaultComputationLayout(const ProgramShape& program_shape);
88 
89   // Same as above but if the given program contains layout for parameters or
90   // result, the entry_computation_layout's layout is updated accordingly.
91   void SetComputationLayoutIfExists(const ProgramShape& program_shape);
92 
93   // Returns a constant reference to the layout of the entry computation.
94   // Assumes the layout was set.
entry_computation_layout()95   const ComputationLayout& entry_computation_layout() const {
96     CHECK(entry_computation_layout_.has_value());
97     return *entry_computation_layout_;
98   }
99 
100   // Returns a mutable pointer to the layout of the entry computation.
101   // Assumes the layout was set.
mutable_entry_computation_layout()102   ComputationLayout* mutable_entry_computation_layout() {
103     CHECK(entry_computation_layout_.has_value());
104     return &(*entry_computation_layout_);
105   }
106 
107   // Returns whether to enable HLO-level profiling.
hlo_profiling_enabled()108   bool hlo_profiling_enabled() const {
109     return debug_options_.xla_hlo_profile();
110   }
111 
cpu_traceme_enabled()112   bool cpu_traceme_enabled() const {
113     return debug_options_.xla_cpu_enable_xprof_traceme();
114   }
115 
116   // Sets/returns the module seed set during execution.
set_seed(uint64 seed)117   void set_seed(uint64 seed) { seed_ = seed; }
seed()118   uint64 seed() const { return seed_; }
119 
120   // Set the launch id of the program. Launch id identifies a set of programs
121   // that should be launched together.
set_launch_id(uint64 launch_id)122   void set_launch_id(uint64 launch_id) { launch_id_ = launch_id; }
123 
launch_id()124   int32 launch_id() const { return launch_id_; }
125 
set_replica_count(int64 replica_count)126   void set_replica_count(int64 replica_count) {
127     replica_count_ = replica_count;
128   }
replica_count()129   int64 replica_count() const { return replica_count_; }
130 
set_num_partitions(int64 num_partitions)131   void set_num_partitions(int64 num_partitions) {
132     num_partitions_ = num_partitions;
133   }
num_partitions()134   int64 num_partitions() const { return num_partitions_; }
135 
set_broadcast_replicated_params(bool broadcast_replicated_params)136   void set_broadcast_replicated_params(bool broadcast_replicated_params) {
137     broadcast_replicated_params_ = broadcast_replicated_params;
138   }
broadcast_replicated_params()139   bool broadcast_replicated_params() const {
140     return broadcast_replicated_params_;
141   }
142 
set_use_spmd_partitioning(bool use_spmd_partitioning)143   void set_use_spmd_partitioning(bool use_spmd_partitioning) {
144     use_spmd_partitioning_ = use_spmd_partitioning;
145   }
use_spmd_partitioning()146   bool use_spmd_partitioning() const { return use_spmd_partitioning_; }
147 
148   // If enabled, deduplicate equivalent hlos into function calls to reduce code
149   // size.
set_deduplicate_hlo(bool deduplicate_hlo)150   void set_deduplicate_hlo(bool deduplicate_hlo) {
151     deduplicate_hlo_ = deduplicate_hlo;
152   }
deduplicate_hlo()153   bool deduplicate_hlo() const { return deduplicate_hlo_; }
154 
155   // Return a string which unambiguously represents all the fields of this data
156   // structure. Used for generating a cache key for storing the compiled
157   // executable.
158   string compilation_cache_key() const;
159 
debug_options()160   const DebugOptions& debug_options() const { return debug_options_; }
161 
set_debug_options(const DebugOptions & debug_options)162   void set_debug_options(const DebugOptions& debug_options) {
163     debug_options_ = debug_options;
164   }
165 
166   // Sets/returns the number of intra op threads for this module.
set_intra_op_parallelism_threads(const int intra_op_parallelism_threads)167   void set_intra_op_parallelism_threads(
168       const int intra_op_parallelism_threads) {
169     intra_op_parallelism_threads_ = intra_op_parallelism_threads;
170   }
intra_op_parallelism_threads()171   int64 intra_op_parallelism_threads() const {
172     return intra_op_parallelism_threads_;
173   }
174 
175   // Checks if this config has a static device assignment.
has_static_device_assignment()176   bool has_static_device_assignment() const {
177     return static_device_assignment_.has_value();
178   }
179 
180   // Getter and setter of the compile-time known device assignment.
static_device_assignment()181   const DeviceAssignment& static_device_assignment() const {
182     CHECK(static_device_assignment_.has_value());
183     return *static_device_assignment_;
184   }
set_static_device_assignment(const DeviceAssignment & device_assignment)185   void set_static_device_assignment(const DeviceAssignment& device_assignment) {
186     static_device_assignment_ = device_assignment;
187   }
188 
shardable_value_update_pairs()189   const std::vector<ShardableValueUpdatePair> shardable_value_update_pairs()
190       const {
191     return shardable_value_update_pairs_;
192   }
set_shardable_value_update_pairs(std::vector<ShardableValueUpdatePair> pairs)193   void set_shardable_value_update_pairs(
194       std::vector<ShardableValueUpdatePair> pairs) {
195     shardable_value_update_pairs_ = std::move(pairs);
196   }
197 
198   // Whether input and output buffers are aliased if the associated parameter is
199   // passed-through XLA modules without being changed.
alias_passthrough_params()200   bool alias_passthrough_params() const { return alias_passthrough_params_; }
set_alias_passthrough_params(bool alias_passthrough_params)201   void set_alias_passthrough_params(bool alias_passthrough_params) {
202     alias_passthrough_params_ = alias_passthrough_params;
203   }
204 
content_aware_computation_sorting()205   bool content_aware_computation_sorting() const {
206     return content_aware_computation_sorting_;
207   }
set_content_aware_computation_sorting(bool content_aware_computation_sorting)208   void set_content_aware_computation_sorting(
209       bool content_aware_computation_sorting) {
210     content_aware_computation_sorting_ = content_aware_computation_sorting;
211   }
212 
fusion_config_collection()213   FusionConfigCollection fusion_config_collection() const {
214     return fusion_config_collection_;
215   }
set_fusion_config_collection(FusionConfigCollection fusion_config_collection)216   void set_fusion_config_collection(
217       FusionConfigCollection fusion_config_collection) {
218     fusion_config_collection_ = fusion_config_collection;
219   }
220 
fusion_config()221   const std::vector<std::vector<bool>>& fusion_config() const {
222     return fusion_config_;
223   }
mutable_fusion_config()224   std::vector<std::vector<bool>>* mutable_fusion_config() {
225     return &fusion_config_;
226   }
227 
dot_config()228   const std::vector<std::vector<int64>>& dot_config() const {
229     return dot_config_;
230   }
231 
mutable_dot_config()232   std::vector<std::vector<int64>>* mutable_dot_config() { return &dot_config_; }
233 
layout_config()234   const std::vector<std::vector<std::vector<int64>>>& layout_config() const {
235     return layout_config_;
236   }
237 
mutable_layout_config()238   std::vector<std::vector<std::vector<int64>>>* mutable_layout_config() {
239     return &layout_config_;
240   }
241 
242  private:
243   // If you add new members, be sure to update compilation_cache_key.
244 
245   absl::optional<ComputationLayout> entry_computation_layout_;
246 
247   // Module/graph-level seed handle.
248   uint64 seed_ = 0;
249 
250   // Program id that identifies a set of program to be launched together.
251   int32 launch_id_ = 0;
252 
253   // The number of replicas (data parallelism) to compile this binary for.
254   int64 replica_count_ = 1;
255 
256   // The number of partitions (model parallelism) to compile this binary for.
257   int64 num_partitions_ = 1;
258 
259   // Whether to use XLA collectives to broadcast params to all replicas.
260   bool broadcast_replicated_params_ = false;
261 
262   // Whether to use SPMD (true) or MPMD (false) when num_partitions_ > 0 and XLA
263   // needs to partition the module.
264   bool use_spmd_partitioning_ = false;
265 
266   // If enabled, deduplicate equivalent hlos into function calls to reduce code
267   // size.
268   bool deduplicate_hlo_ = false;
269 
270   // The target maximum parallelism at which to partition HLOs for parallel
271   // execution on the CPU backend.
272   int64 intra_op_parallelism_threads_ = -1;
273 
274   DebugOptions debug_options_;
275 
276   // Compile-time known device assignment.
277   absl::optional<DeviceAssignment> static_device_assignment_;
278 
279   std::vector<ShardableValueUpdatePair> shardable_value_update_pairs_;
280 
281   bool alias_passthrough_params_ = false;
282 
283   bool content_aware_computation_sorting_ = false;
284 
285   FusionConfigCollection fusion_config_collection_ =
286       FusionConfigCollection::kOff;
287 
288   // TODO(b/155665133): Consolidate fusion, dot, and layout config into a proto
289   // similar to backend config.
290 
291   // Custom fusion configuration, where fusion_config_[c][v] control if node v
292   // in computation c must be fused to all its consumers (true) or not (false).
293   std::vector<std::vector<bool>> fusion_config_;
294 
295   // Custom dot canonicalization configuration, where dot_config_[v] control
296   // how to convert dot operation v (sorted topologically and by computation) to
297   // convolution.
298   std::vector<std::vector<int64>> dot_config_;
299 
300   // Layout configuration, where layout_config_[v][i] controls the layout
301   // decision i of operation v.
302   std::vector<std::vector<std::vector<int64>>> layout_config_;
303 };
304 
305 }  // namespace xla
306 
307 #endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_CONFIG_H_
308