1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"
17
18 #include <cassert>
19 #include "tensorflow/compiler/xla/cpu_function_runtime.h"
20
21 namespace tensorflow {
22
XlaCompiledCpuFunction(const StaticData & static_data,AllocMode alloc_mode)23 XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data,
24 AllocMode alloc_mode)
25 : raw_function_(static_data.raw_function_),
26 result_index_(static_data.result_index_),
27 buffer_table_(new void*[static_data.num_buffers_]),
28 buffer_infos_(static_data.buffer_infos_),
29 arg_index_table_(static_data.arg_index_table_),
30 num_args_(static_data.num_args_),
31 num_variables_(static_data.num_variables_),
32 arg_names_(static_data.arg_names_),
33 variable_names_(static_data.variable_names_),
34 result_names_(static_data.result_names_),
35 program_shape_(static_data.program_shape_),
36 hlo_profile_printer_data_(static_data.hlo_profile_printer_data_) {
37 bool allocate_entry_params =
38 alloc_mode == AllocMode::ARGS_VARIABLES_RESULTS_PROFILES_AND_TEMPS;
39 // Allocate arg and temp buffers.
40 alloc_buffer_table_ = xla::cpu_function_runtime::MallocContiguousBuffers(
41 static_data.buffer_infos_, static_data.num_buffers_,
42 /*allocate_entry_params=*/allocate_entry_params, buffer_table_,
43 /*annotate_initialized=*/true);
44 // If Hlo profiling is enabled the generated code expects an appropriately
45 // sized buffer to be passed in as the last argument. If Hlo profiling is
46 // disabled the last function argument is still present in the function
47 // signature, but it is ignored by the generated code and we pass in null for
48 // it.
49 if (hlo_profiling_enabled()) {
50 profile_counters_ = new int64[static_data.profile_counters_size_]();
51 }
52 }
53
Run()54 bool XlaCompiledCpuFunction::Run() {
55 raw_function_(buffer_table_[result_index_], &run_options_, nullptr,
56 buffer_table_, profile_counters_);
57 return true;
58 }
59
~XlaCompiledCpuFunction()60 XlaCompiledCpuFunction::~XlaCompiledCpuFunction() {
61 xla::cpu_function_runtime::FreeContiguous(alloc_buffer_table_);
62 delete[] buffer_table_;
63 delete[] profile_counters_;
64 }
65
namespace {

// Sentinel returned by the lookup helpers when no matching index exists.
constexpr int kNotFound = -1;

// Linear search through `names` looking for an exact match with `name`.
// Returns the 0-based position of the first match, or kNotFound (-1) if the
// name is empty, is not present, or no name table is available.
//
// REQUIRES: `names`, when non-null, is a nullptr-terminated array.
int LookupNameIndex(const string& name, const char** names) {
  // Hitting this assert means that there is no name-to-index data available;
  // for AOT try setting the tfcompile --gen_name_to_index flag.
  assert(names != nullptr);

  // The assert above compiles away under NDEBUG; without this guard the loop
  // below would dereference a null pointer in release builds.
  if (names == nullptr || name.empty()) {
    return kNotFound;
  }
  for (int index = 0; names[index] != nullptr; ++index) {
    if (name == names[index]) {
      return index;
    }
  }
  return kNotFound;
}

}  // namespace
91
LookupArgIndex(const string & name) const92 int XlaCompiledCpuFunction::LookupArgIndex(const string& name) const {
93 return LookupNameIndex(name, arg_names_);
94 }
95
LookupVariableIndex(const string & name) const96 int XlaCompiledCpuFunction::LookupVariableIndex(const string& name) const {
97 int index = LookupNameIndex(name, variable_names_);
98 if (index == kNotFound) {
99 return kNotFound;
100 }
101 return num_args_ - num_variables_ + index;
102 }
103
LookupResultIndex(const string & name) const104 int XlaCompiledCpuFunction::LookupResultIndex(const string& name) const {
105 return LookupNameIndex(name, result_names_);
106 }
107
108 } // namespace tensorflow
109