1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_EXECUTABLE_RUN_OPTIONS_H_ 17 #define TENSORFLOW_COMPILER_XLA_EXECUTABLE_RUN_OPTIONS_H_ 18 19 #include <string> 20 21 #include "tensorflow/compiler/xla/types.h" 22 23 // These classes are forward declared so that ExecutableRunOptions can be linked 24 // into an XLA-compiled binary without having to link all of the pointed-to 25 // objects (e.g., for an ahead-of-time compiled CPU binary, the gpu tools don't 26 // need to be linked). 27 namespace stream_executor { 28 class Stream; 29 class Platform; 30 class DeviceMemoryAllocator; 31 } // namespace stream_executor 32 33 namespace Eigen { 34 struct ThreadPoolDevice; 35 } // namespace Eigen 36 37 namespace xla { 38 39 class DeviceAssignment; 40 class ExecutionProfile; 41 42 // A unique identifier for a particular "logical execution" of an XLA model. 43 // 44 // A logical execution might encompass multiple executions of one or more 45 // HloModules. Runs that are part of the same logical execution can 46 // communicate via collective ops (e.g. kAllToAll), whereas runs that are part 47 // of different logical executions are isolated. 48 class RunId { 49 public: 50 // Creates a new, unique RunId. 51 RunId(); 52 53 RunId(const RunId&) = default; 54 RunId& operator=(const RunId&) = default; 55 friend bool operator==(const RunId& a, const RunId& b); 56 std::string ToString() const; 57 58 template <typename H> AbslHashValue(H h,const RunId & id)59 friend H AbslHashValue(H h, const RunId& id) { 60 return H::combine(std::move(h), id.data_); 61 } 62 63 private: 64 int64 data_; 65 }; 66 67 // Class containing options for running a LocalExecutable. 68 class ExecutableRunOptions { 69 public: 70 // Specifies the allocator to use during execution. 71 ExecutableRunOptions& set_allocator( 72 stream_executor::DeviceMemoryAllocator* allocator); 73 stream_executor::DeviceMemoryAllocator* allocator() const; 74 75 // If set, this is the device to run the computation on. Valid device_ordinal 76 // values are: 0 to # of devices - 1. These values are identical to the device 77 // ordinal values used by StreamExecutor. The device must be of the same type 78 // as the executable was compiled for. A value of -1 indicates this option has 79 // not been set. 80 ExecutableRunOptions& set_device_ordinal(int device_ordinal); 81 int device_ordinal() const; 82 83 // If set, this is the stream to run the computation on. The platform of the 84 // stream must match the platform the executable was built for. A value of 85 // nullptr indicates the option has not been set. 86 ExecutableRunOptions& set_stream(stream_executor::Stream* stream); 87 stream_executor::Stream* stream() const; 88 89 // If set, this is the stream to perform any pre-computation transfers on. 90 // The platform of the stream must match the platform the executable was 91 // built for. A value of nullptr indicates the option has not been set. 92 ExecutableRunOptions& set_host_to_device_stream( 93 stream_executor::Stream* stream); 94 stream_executor::Stream* host_to_device_stream() const; 95 96 // Sets the thread pool device on which to run Eigen subcomputations. 97 // 98 // This field must be set for XLA:CPU models that call Eigen routines, but may 99 // be null otherwise. Routines that use this field should always CHECK (or 100 // TF_RET_CHECK) that it's not null before dereferencing it, so that users get 101 // a clean crash rather than a segfault. 102 // 103 // Does not take ownership. 104 ExecutableRunOptions& set_intra_op_thread_pool( 105 const Eigen::ThreadPoolDevice* intra_op_thread_pool); 106 const Eigen::ThreadPoolDevice* intra_op_thread_pool() const; 107 108 // If set, profiling information is written to 'profile'. 109 ExecutionProfile* execution_profile() const; 110 ExecutableRunOptions& set_execution_profile(ExecutionProfile* profile); 111 112 ExecutableRunOptions& set_device_assignment( 113 const DeviceAssignment* device_assignment); 114 const DeviceAssignment* device_assignment() const; 115 116 ExecutableRunOptions& set_rng_seed(int rng_seed); 117 int rng_seed() const; 118 119 ExecutableRunOptions& set_run_id(RunId id); 120 RunId run_id() const; 121 122 private: 123 stream_executor::DeviceMemoryAllocator* allocator_ = nullptr; 124 int device_ordinal_ = -1; 125 const DeviceAssignment* device_assignment_ = nullptr; 126 stream_executor::Stream* stream_ = nullptr; 127 const Eigen::ThreadPoolDevice* intra_op_thread_pool_ = nullptr; 128 ExecutionProfile* execution_profile_ = nullptr; 129 int rng_seed_ = 0; 130 stream_executor::Stream* host_to_device_stream_ = nullptr; 131 RunId run_id_; 132 }; 133 134 } // namespace xla 135 136 #endif // TENSORFLOW_COMPILER_XLA_EXECUTABLE_RUN_OPTIONS_H_ 137