• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_
17 #define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_
18 
19 #include <cassert>
20 #include <string>
21 
22 #include "tensorflow/compiler/xla/cpu_function_runtime.h"
23 #include "tensorflow/compiler/xla/executable_run_options.h"
24 #include "tensorflow/core/platform/types.h"
25 
26 // Forward-declare, rather than include, to reduce code size for users that
27 // never use this functionality.
28 namespace xla {
29 class ProgramShapeProto;
30 class HloProfilePrinterData;
31 }
32 
33 namespace tensorflow {
34 
35 // Represents a function compiled by XLA, produced via either JIT or AOT.
36 //
37 // The Run method invokes the actual computation, with inputs read from arg
38 // buffers, and outputs written to result buffers. Each Run call may also use a
39 // set of temporary buffers for the computation.
40 //
41 // By default each instance of this class manages its own arg, result and temp
42 // buffers. The AllocMode constructor parameter may be used to modify the buffer
43 // allocation strategy.
44 //
45 // Under the default allocation strategy, this class is thread-compatible:
46 // o Calls to non-const methods require exclusive access to the object.
47 // o Concurrent calls to const methods are OK, if those calls are made while it
48 //   is guaranteed that no thread may call a non-const method.
49 class XlaCompiledCpuFunction {
50  public:
51   // Type of the raw function, produced by either JIT or AOT.
52   using RawFunction = void (*)(void* result,
53                                const xla::ExecutableRunOptions* run_options,
54                                const void** args, void** temps,
55                                int64* profile_counters);
56 
57   // StaticData represents the state necessary to run an XLA-compiled
58   // function. For JIT this is backed by data in XlaJitCompiledCpuFunction; for
59   // AOT this is backed by data compiled into the object file.
60   //
61   // The contents of StaticData are XLA-internal implementation details and
62   // should not be relied on by clients (and therefore are private).
63   class StaticData {
64    private:
65     // The raw function to call.
66     RawFunction raw_function_;
67 
68     // Contains information about the buffers used by the XLA computation.
69     const xla::cpu_function_runtime::BufferInfo* buffer_infos_ = nullptr;
70     size_t num_buffers_ = 0;
71 
72     // Entry parameter i is described by
73     // buffer_infos[arg_index_table[i]].
74     const int32* arg_index_table_ = nullptr;
75 
76     // There are num_args entry parameters.
77     int64 num_args_ = 0;
78 
79     // There are num_variables variables.
80     int64 num_variables_ = 0;
81 
82     // The 0-based index of the result tuple, in the temp buffers.
83     size_t result_index_ = 0;
84 
85     // [Optional] Arrays of arg and result names. These are arrays of C-style
86     // strings, where the array is terminated by nullptr.
87     const char** arg_names_ = nullptr;
88     const char** variable_names_ = nullptr;
89     const char** result_names_ = nullptr;
90 
91     // [Optional] Arg and result shapes.
92     const xla::ProgramShapeProto* program_shape_ = nullptr;
93 
94     // [Optional] Profile printer data.  Null if profiling is disabled.
95     const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr;
96 
97     // [Optional] The number of profile counters expected in the profile counter
98     // buffer by the generated code and hlo_profile_printer.  0 if profiling is
99     // disabled.  This information is already present in
100     // hlo_profile_printer_data but xla::HloProfilePrinterData is forward
101     // declared so we don't have access to that information here.
102     int64 profile_counters_size_ = 0;
103 
104     // Only XlaCompiledCpuFunction is allowed to read and write the above
105     // fields.
106     friend class XlaCompiledCpuFunction;
107   };
108 
109   // AllocMode controls the buffer allocation mode.
110   enum class AllocMode {
111     // Allocate all buffers - args, results, profile and temps.
112     ARGS_VARIABLES_RESULTS_PROFILES_AND_TEMPS,
113 
114     // Only allocate result, profile and temp buffers.
115     // Use set_arg_data to set argument buffers before Run is called.
116     RESULTS_PROFILES_AND_TEMPS_ONLY,
117   };
118 
119   explicit XlaCompiledCpuFunction(
120       const StaticData& static_data,
121       AllocMode alloc_mode =
122           AllocMode::ARGS_VARIABLES_RESULTS_PROFILES_AND_TEMPS);
123   virtual ~XlaCompiledCpuFunction();
124 
125   XlaCompiledCpuFunction(const XlaCompiledCpuFunction&) = delete;
126   XlaCompiledCpuFunction& operator=(const XlaCompiledCpuFunction&) = delete;
127 
128   // Sets the intra-op thread pool used to run individual ops concurrently.
set_thread_pool(const Eigen::ThreadPoolDevice * pool)129   void set_thread_pool(const Eigen::ThreadPoolDevice* pool) {
130     run_options_.set_intra_op_thread_pool(pool);
131   }
132 
133   // Runs the computation, with inputs read from arg buffers, and outputs
134   // written to result buffers. Returns true on success and false on failure.
135   bool Run();
136 
137   // Returns the error message from the previous failed Run call.
138   //
139   // TODO(fschneider): For now this always returns an empty string because there
140   // is no support for error reporting in XLA. Remove this once all callers are
141   // updated.
error_msg()142   string error_msg() const { return {}; }
143 
144   // ------------------------------
145   // Arg methods for managing input buffers. Buffers are in row-major order.
146 
147   // Returns the buffer for the positional argument at the given `index`.
arg_data(size_t index)148   void* arg_data(size_t index) {
149     return buffer_table_[arg_index_table_[index]];
150   }
arg_data(size_t index)151   const void* arg_data(size_t index) const {
152     return buffer_table_[arg_index_table_[index]];
153   }
154 
num_args()155   int num_args() const { return num_args_; }
156 
num_variables()157   int num_variables() const { return num_variables_; }
158 
159   // Returns the size of entry parameter `idx`.
160   //
161   // There is a static version of this method on tfcompile generated subclasses
162   // of XlaCompiledCpuFunction, but try to prefer this when possible since it
163   // works both for XlaJitCompiledCpuFunction and AOT compiled subclasses.
arg_size(int idx)164   int arg_size(int idx) const {
165     assert(idx < num_args());
166     return buffer_infos_[arg_index_table_[idx]].size();
167   }
168 
169   // Sets the buffer for the positional argument at the given `index` to `data`.
170   // Must be called before Run to have an effect. May be called under any
171   // AllocMode; if the AllocMode is RESULTS_AND_TEMPS_ONLY, this method must be
172   // called for each positional argument, in order to set the argument buffers.
173   //
174   // Allocated memory must be aligned to the size specified by
175   // xla::cpu_function_runtime::kMinAlign. If possible, use the functions in
176   // tensorflow/compiler/tf2xla/cpu_function_runtime.h to ensure correct
177   // alignment.
178   //
179   // Aliasing of argument and result buffers is not allowed, and results in
180   // undefined behavior.
set_arg_data(size_t index,const void * data)181   void set_arg_data(size_t index, const void* data) {
182     assert((arg_size(index) < xla::cpu_function_runtime::kMinAlign ||
183             (uintptr_t)data % xla::cpu_function_runtime::kMinAlign == 0) &&
184            "Underaligned pointer!");
185     // The const_cast is safe because the generated code does not write to arg
186     // buffers.
187     //
188     // buffer_table_ contains pointers to buffers that _will_ be written to by
189     // generated code so it would be misleading to make buffer_table_ a `const
190     // void**`.
191     buffer_table_[arg_index_table_[index]] = const_cast<void*>(data);
192   }
193 
194   // ------------------------------
195   // Result methods for managing output buffers. Buffers are in row-major order.
196   // Must only be called after a successful Run call. Unlike the arg methods,
197   // there is no set_resultN_data method. The result buffers are managed
198   // internally, and may change after each call to Run.
199 
200   // Returns the underlying array of result buffers, where results()[I] is the
201   // buffer for the positional result at index I.
results()202   void** results() { return static_cast<void**>(buffer_table_[result_index_]); }
results()203   const void* const* results() const {
204     return static_cast<const void* const*>(buffer_table_[result_index_]);
205   }
206 
207   // Profile counters for this XLA computation.
208   //
209   // When Hlo profiling is enabled (`hlo_profiling_enabled()` return true in
210   // this case) these counters are non-null and are automatically populated by
211   // `Run`.  The counters can then be pretty-printed using
212   // `hlo_profile_printer()`.
213   //
214   // When Hlo profiling is disabled, this accessor returns null.
profile_counters()215   const int64* profile_counters() const { return profile_counters_; }
216 
217   // Returns the buffer for the positional result at the given `index`.
result_data(size_t index)218   void* result_data(size_t index) { return results()[index]; }
result_data(size_t index)219   const void* result_data(size_t index) const { return results()[index]; }
220 
221   // ------------------------------
222   // Methods for extracting optional metadata.
223 
224   // Returns true iff data is available for the Lookup{Arg,Variable,Result}Index
225   // methods. E.g. the data might not be compiled into the binary for AOT.
HasNameIndices()226   bool HasNameIndices() const {
227     return arg_names_ != nullptr && variable_names_ != nullptr &&
228            result_names_ != nullptr;
229   }
230 
231   // Returns the 0-based index for the argument with the given `name`.
232   // Returns -1 if the name wasn't found, or data isn't available.
233   //
234   // The index remains constant for every instance of XlaCompiledCpuFunction
235   // generated from the same static data, and might not be cheap to determine.
236   // Recommended usage is to capture this in a variable for re-use.
237   int LookupArgIndex(const string& name) const;
238 
239   // Returns the 0-based index for the variable with the given `name`.
240   // Returns -1 if the name wasn't found, or data isn't available.
241   //
242   // The index remains constant for every instance of XlaCompiledCpuFunction
243   // generated from the same static data, and might not be cheap to determine.
244   // Recommended usage is to capture this in a variable for re-use.
245   int LookupVariableIndex(const string& name) const;
246 
247   // Returns the 0-based index for the result with the given `name`.
248   // Returns -1 if the name wasn't found, or data isn't available.
249   //
250   // The index remains constant for every instance of XlaCompiledCpuFunction
251   // generated from the same static data, and might not be cheap to determine.
252   // Recommended usage is to capture this in a variable for re-use.
253   int LookupResultIndex(const string& name) const;
254 
255   // Returns the shape of the args and results. May return nullptr if the
256   // program shape isn't available.
ProgramShape()257   const xla::ProgramShapeProto* ProgramShape() const { return program_shape_; }
258 
hlo_profiling_enabled()259   bool hlo_profiling_enabled() const {
260     return hlo_profile_printer_data_ != nullptr;
261   }
hlo_profile_printer_data()262   const xla::HloProfilePrinterData& hlo_profile_printer_data() const {
263     assert(hlo_profiling_enabled());
264     return *hlo_profile_printer_data_;
265   }
266 
267  protected:
268   // ---------------------------------------------------------------------------
269   // Accessors for reading from and writing to instances of `StaticData`.
270   //
271   // Classes generated by tfcompile can call these because the generated classes
272   // inherit from `XlaCompiledCpuFunction`.  `XlaJitCompiledCpuFunction` can
273   // call these because it is explicitly added as a friend.
274 
set_static_data_raw_function(StaticData * static_data,RawFunction raw_function)275   static void set_static_data_raw_function(StaticData* static_data,
276                                            RawFunction raw_function) {
277     static_data->raw_function_ = raw_function;
278   }
279 
set_static_data_buffer_infos(StaticData * static_data,const xla::cpu_function_runtime::BufferInfo * buffer_infos)280   static void set_static_data_buffer_infos(
281       StaticData* static_data,
282       const xla::cpu_function_runtime::BufferInfo* buffer_infos) {
283     static_data->buffer_infos_ = buffer_infos;
284   }
285 
set_static_data_num_buffers(StaticData * static_data,size_t num_buffers)286   static void set_static_data_num_buffers(StaticData* static_data,
287                                           size_t num_buffers) {
288     static_data->num_buffers_ = num_buffers;
289   }
290 
set_static_data_arg_index_table(StaticData * static_data,const int32 * arg_index_table)291   static void set_static_data_arg_index_table(StaticData* static_data,
292                                               const int32* arg_index_table) {
293     static_data->arg_index_table_ = arg_index_table;
294   }
295 
set_static_data_num_args(StaticData * static_data,int64 num_args)296   static void set_static_data_num_args(StaticData* static_data,
297                                        int64 num_args) {
298     static_data->num_args_ = num_args;
299   }
300 
set_static_data_num_variables(StaticData * static_data,int64 num_variables)301   static void set_static_data_num_variables(StaticData* static_data,
302                                             int64 num_variables) {
303     static_data->num_variables_ = num_variables;
304   }
305 
set_static_data_result_index(StaticData * static_data,size_t result_index)306   static void set_static_data_result_index(StaticData* static_data,
307                                            size_t result_index) {
308     static_data->result_index_ = result_index;
309   }
310 
set_static_data_arg_names(StaticData * static_data,const char ** arg_names)311   static void set_static_data_arg_names(StaticData* static_data,
312                                         const char** arg_names) {
313     static_data->arg_names_ = arg_names;
314   }
315 
set_static_data_variable_names(StaticData * static_data,const char ** variable_names)316   static void set_static_data_variable_names(StaticData* static_data,
317                                              const char** variable_names) {
318     static_data->variable_names_ = variable_names;
319   }
320 
set_static_data_result_names(StaticData * static_data,const char ** result_names)321   static void set_static_data_result_names(StaticData* static_data,
322                                            const char** result_names) {
323     static_data->result_names_ = result_names;
324   }
325 
set_static_data_program_shape(StaticData * static_data,const xla::ProgramShapeProto * program_shape)326   static void set_static_data_program_shape(
327       StaticData* static_data, const xla::ProgramShapeProto* program_shape) {
328     static_data->program_shape_ = program_shape;
329   }
330 
set_static_data_hlo_profile_printer_data(StaticData * static_data,const xla::HloProfilePrinterData * hlo_profile_printer_data)331   static void set_static_data_hlo_profile_printer_data(
332       StaticData* static_data,
333       const xla::HloProfilePrinterData* hlo_profile_printer_data) {
334     static_data->hlo_profile_printer_data_ = hlo_profile_printer_data;
335   }
336 
337   static const xla::HloProfilePrinterData*
get_static_data_hlo_profile_printer_data(StaticData * static_data)338   get_static_data_hlo_profile_printer_data(StaticData* static_data) {
339     return static_data->hlo_profile_printer_data_;
340   }
341 
set_static_data_profile_counters_size(StaticData * static_data,int64 profile_counters_size)342   static void set_static_data_profile_counters_size(
343       StaticData* static_data, int64 profile_counters_size) {
344     static_data->profile_counters_size_ = profile_counters_size;
345   }
346 
347  private:
348   const RawFunction raw_function_;
349   const size_t result_index_;
350 
351   // Array containing pointers to argument and temp buffers (slots corresponding
352   // to constant and on-stack buffers are null).
353   void** const buffer_table_;
354 
355   // Describes the buffers used by the XLA computation.
356   const xla::cpu_function_runtime::BufferInfo* const buffer_infos_;
357 
358   // Argument i needs to be placed in buffer_table_[arg_index_to_temp_index_[i]]
359   // for XLA generated code to be able to find it.
360   const int32* const arg_index_table_;
361 
362   // The number of incoming arguments.
363   const int32 num_args_;
364 
365   // The number of incoming variables.
366   const int32 num_variables_;
367 
368   // Backing memory for buffer_table_ and args_, the latter depending on
369   // AllocMode.
370   void* alloc_buffer_table_ = nullptr;
371 
372   // Backing memory for profiling counters.
373   int64* profile_counters_ = nullptr;
374 
375   // Options and context passed to the compiled function.
376   xla::ExecutableRunOptions run_options_;
377 
378   // Optional metadata.
379   const char** arg_names_ = nullptr;
380   const char** variable_names_ = nullptr;
381   const char** result_names_ = nullptr;
382   const xla::ProgramShapeProto* program_shape_ = nullptr;
383   const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr;
384 
385   // Add `XlaJitCompiledCpuFunction` as a friend so that it can access the
386   // `set_static_data_*` static methods above.
387   friend class XlaJitCompiledCpuFunction;
388 };
389 
390 }  // namespace tensorflow
391 
392 #endif  // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_
393