/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_CONFIG_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_CONFIG_H_

#include <string>

#include "absl/types/optional.h"
#include "tensorflow/compiler/xla/service/computation_layout.h"
#include "tensorflow/compiler/xla/service/computation_placer.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"

namespace xla {

// This class gathers all settings and values which affect the compiled
// executable outside of the HLO code itself. This includes layouts of inputs
// and outputs to the module and settings such as HLO profiling. Together the
// HloModule and HloModuleConfig unambiguously determine a particular
// executable.
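//
// Example usage (a minimal sketch, not part of the original header; `shape`
// stands for a hypothetical ProgramShape of the entry computation):
//
//   HloModuleConfig config(shape);   // layouts are reset to defaults
//   config.set_replica_count(2);
//   config.set_seed(42);
//   DebugOptions opts = config.debug_options();
//   opts.set_xla_hlo_profile(true);  // makes hlo_profiling_enabled() true
//   config.set_debug_options(opts);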
class HloModuleConfig {
 public:
  // A configuration can be created either with or without an entry
  // ComputationLayout. The default ctor creates it without -- in this case
  // accessing entry_computation_layout will CHECK-fail. The ctor accepting a
  // ProgramShape creates a computation layout using this shape.
  // The layouts in the ProgramShape will be reset to default unless
  // ignore_layouts is set to false.
  HloModuleConfig() = default;

  explicit HloModuleConfig(const ProgramShape& program_shape,
                           bool ignore_layouts = true);

  // Checks if this config has an entry computation layout already.
  bool has_entry_computation_layout() const {
    return entry_computation_layout_.has_value();
  }

  // Sets the entry computation layout for this config. If the entry
  // computation layout already exists, it is silently replaced.
  void SetDefaultComputationLayout(const ProgramShape& program_shape);

  // Returns a constant reference to the layout of the entry computation.
  // Assumes the layout was set.
  const ComputationLayout& entry_computation_layout() const {
    CHECK(entry_computation_layout_.has_value());
    return *entry_computation_layout_;
  }

  // Returns a mutable pointer to the layout of the entry computation.
  // Assumes the layout was set.
  ComputationLayout* mutable_entry_computation_layout() {
    CHECK(entry_computation_layout_.has_value());
    return &(*entry_computation_layout_);
  }

  // Returns whether to enable HLO-level profiling.
  bool hlo_profiling_enabled() const {
    return debug_options_.xla_hlo_profile();
  }

  // Sets/returns the module seed set during execution.
  void set_seed(uint64 seed) { seed_ = seed; }
  uint64 seed() const { return seed_; }

  void set_replica_count(int64 replica_count) {
    replica_count_ = replica_count;
  }
  int64 replica_count() const { return replica_count_; }

  // Returns a string which unambiguously represents all the fields of this
  // data structure. Used for generating a cache key for storing the compiled
  // executable.
  string compilation_cache_key() const;

  const DebugOptions& debug_options() const { return debug_options_; }

  void set_debug_options(const DebugOptions& debug_options) {
    debug_options_ = debug_options;
  }

  // Sets/returns the number of intra-op threads for this module.
  void set_intra_op_parallelism_threads(
      const int intra_op_parallelism_threads) {
    intra_op_parallelism_threads_ = intra_op_parallelism_threads;
  }
  int64 intra_op_parallelism_threads() const {
    return intra_op_parallelism_threads_;
  }

  // Checks if this config has a static device assignment.
  bool has_static_device_assignment() const {
    return static_device_assignment_.has_value();
  }

  // Getter and setter of the compile-time known device assignment.
  const DeviceAssignment& static_device_assignment() const {
    CHECK(static_device_assignment_.has_value());
    return *static_device_assignment_;
  }
  void set_static_device_assignment(
      const DeviceAssignment& device_assignment) {
    static_device_assignment_ = device_assignment;
  }

 private:
  // If you add new members, be sure to update compilation_cache_key.

  absl::optional<ComputationLayout> entry_computation_layout_;

  // Module/graph-level seed handle.
  uint64 seed_ = 0;

  // The number of replicas to compile this binary for.
  int64 replica_count_ = 1;

  // The target maximum parallelism at which to partition HLOs for parallel
  // execution on the CPU backend.
  int64 intra_op_parallelism_threads_ = -1;

  DebugOptions debug_options_;

  // Compile-time known device assignment.
  absl::optional<DeviceAssignment> static_device_assignment_;
};

}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_HLO_MODULE_CONFIG_H_
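
// Usage sketch for the compile-time device assignment (not part of the
// original header; it relies on the DeviceAssignment interface declared in
// computation_placer.h, a (replica, computation) grid of device ids):
//
//   xla::DeviceAssignment assignment(/*replica_count=*/2,
//                                    /*computation_count=*/1);
//   assignment(0, 0) = 0;  // replica 0 runs on device 0
//   assignment(1, 0) = 1;  // replica 1 runs on device 1
//   config.set_static_device_assignment(assignment);
//   CHECK(config.has_static_device_assignment());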