• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_EXECUTABLE_RUN_OPTIONS_H_
17 #define TENSORFLOW_COMPILER_XLA_EXECUTABLE_RUN_OPTIONS_H_
18 
#include <string>
#include <utility>

#include "tensorflow/compiler/xla/types.h"
22 
// These classes are forward declared so that ExecutableRunOptions can be linked
// into an XLA-compiled binary without having to link all of the pointed-to
// objects (e.g., an ahead-of-time compiled CPU binary does not need the GPU
// tools linked in).
namespace stream_executor {
// Declared but not defined here: this header only ever names these types
// through pointers, so their full definitions are not required.
class DeviceMemoryAllocator;  // Used by set_allocator()/allocator().
class Platform;               // Not referenced elsewhere in this header.
class Stream;                 // Used by the compute and host-to-device streams.
}  // namespace stream_executor
32 
namespace Eigen {
// Only pointers to this type appear in this header, so a declaration
// suffices. It must stay `struct` to match Eigen's own declaration.
struct ThreadPoolDevice;
}  // namespace Eigen
36 
37 namespace xla {
38 
// Forward declarations for the pointer-typed options held by
// ExecutableRunOptions; their definitions are not needed in this header.
class ExecutionProfile;
class DeviceAssignment;
41 
42 // A unique identifier for a particular "logical execution" of an XLA model.
43 //
44 // A logical execution might encompass multiple executions of one or more
45 // HloModules.  Runs that are part of the same logical execution can
46 // communicate via collective ops (e.g. kAllToAll), whereas runs that are part
47 // of different logical executions are isolated.
48 class RunId {
49  public:
50   // Creates a new, unique RunId.
51   RunId();
52 
53   RunId(const RunId&) = default;
54   RunId& operator=(const RunId&) = default;
55   friend bool operator==(const RunId& a, const RunId& b);
56   std::string ToString() const;
57 
58   template <typename H>
AbslHashValue(H h,const RunId & id)59   friend H AbslHashValue(H h, const RunId& id) {
60     return H::combine(std::move(h), id.data_);
61   }
62 
63  private:
64   int64 data_;
65 };
66 
67 // Class containing options for running a LocalExecutable.
68 class ExecutableRunOptions {
69  public:
70   // Specifies the allocator to use during execution.
71   ExecutableRunOptions& set_allocator(
72       stream_executor::DeviceMemoryAllocator* allocator);
73   stream_executor::DeviceMemoryAllocator* allocator() const;
74 
75   // If set, this is the device to run the computation on. Valid device_ordinal
76   // values are: 0 to # of devices - 1. These values are identical to the device
77   // ordinal values used by StreamExecutor. The device must be of the same type
78   // as the executable was compiled for. A value of -1 indicates this option has
79   // not been set.
80   ExecutableRunOptions& set_device_ordinal(int device_ordinal);
81   int device_ordinal() const;
82 
83   // If set, this is the stream to run the computation on. The platform of the
84   // stream must match the platform the executable was built for.  A value of
85   // nullptr indicates the option has not been set.
86   ExecutableRunOptions& set_stream(stream_executor::Stream* stream);
87   stream_executor::Stream* stream() const;
88 
89   // If set, this is the stream to perform any pre-computation transfers on.
90   // The platform of the stream must match the platform the executable was
91   // built for.  A value of nullptr indicates the option has not been set.
92   ExecutableRunOptions& set_host_to_device_stream(
93       stream_executor::Stream* stream);
94   stream_executor::Stream* host_to_device_stream() const;
95 
96   // Sets the thread pool device on which to run Eigen subcomputations.
97   //
98   // This field must be set for XLA:CPU models that call Eigen routines, but may
99   // be null otherwise.  Routines that use this field should always CHECK (or
100   // TF_RET_CHECK) that it's not null before dereferencing it, so that users get
101   // a clean crash rather than a segfault.
102   //
103   // Does not take ownership.
104   ExecutableRunOptions& set_intra_op_thread_pool(
105       const Eigen::ThreadPoolDevice* intra_op_thread_pool);
106   const Eigen::ThreadPoolDevice* intra_op_thread_pool() const;
107 
108   // If set, profiling information is written to 'profile'.
109   ExecutionProfile* execution_profile() const;
110   ExecutableRunOptions& set_execution_profile(ExecutionProfile* profile);
111 
112   ExecutableRunOptions& set_device_assignment(
113       const DeviceAssignment* device_assignment);
114   const DeviceAssignment* device_assignment() const;
115 
116   ExecutableRunOptions& set_rng_seed(int rng_seed);
117   int rng_seed() const;
118 
119   ExecutableRunOptions& set_run_id(RunId id);
120   RunId run_id() const;
121 
122  private:
123   stream_executor::DeviceMemoryAllocator* allocator_ = nullptr;
124   int device_ordinal_ = -1;
125   const DeviceAssignment* device_assignment_ = nullptr;
126   stream_executor::Stream* stream_ = nullptr;
127   const Eigen::ThreadPoolDevice* intra_op_thread_pool_ = nullptr;
128   ExecutionProfile* execution_profile_ = nullptr;
129   int rng_seed_ = 0;
130   stream_executor::Stream* host_to_device_stream_ = nullptr;
131   RunId run_id_;
132 };
133 
134 }  // namespace xla
135 
136 #endif  // TENSORFLOW_COMPILER_XLA_EXECUTABLE_RUN_OPTIONS_H_
137