1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_JIT_COMPILED_CPU_FUNCTION_H_ 17 #define TENSORFLOW_COMPILER_TF2XLA_XLA_JIT_COMPILED_CPU_FUNCTION_H_ 18 19 #include <memory> 20 #include <vector> 21 22 #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" 23 #include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h" 24 #include "tensorflow/compiler/xla/client/local_client.h" 25 #include "tensorflow/compiler/xla/cpu_function_runtime.h" 26 #include "tensorflow/compiler/xla/statusor.h" 27 #include "tensorflow/core/framework/graph.pb.h" 28 #include "tensorflow/core/platform/types.h" 29 30 namespace tensorflow { 31 32 // Represents the result of JIT compilation by XLA down to a function. This 33 // class holds the state necessary to create XlaCompiledCpuFunction instances, 34 // which are used to actually invoke the compiled computation. 35 // 36 // XlaJitCompiledCpuFunction must outlive the XlaCompiledCpuFunctions that are 37 // created from it. It holds state shared by all of the functions, including the 38 // JIT-compiled function itself, along with buffer sizes and other metadata 39 // necessary for execution. 40 class XlaJitCompiledCpuFunction { 41 public: 42 // Compile a tensorflow::GraphDef into an XlaJitCompiledCpuFunction. The given 43 // `config` specifies the portion of the graph to compile, via feeds and 44 // fetches. Each feed is a positional input argument for the compiled 45 // function, while each fetch is a positional output argument. 46 static StatusOr<std::unique_ptr<XlaJitCompiledCpuFunction>> Compile( 47 const GraphDef& graph_def, const tf2xla::Config& config, 48 const xla::ExecutableBuildOptions& build_options); 49 50 XlaJitCompiledCpuFunction(const XlaJitCompiledCpuFunction&) = delete; 51 XlaJitCompiledCpuFunction& operator=(const XlaJitCompiledCpuFunction&) = 52 delete; 53 54 // Returns static data used to create an XlaCompiledCpuFunction instance, 55 // which represents the JIT-compiled function. The static data is unchanging 56 // across each instance. StaticData()57 const XlaCompiledCpuFunction::StaticData& StaticData() const { 58 return static_data_; 59 } 60 61 private: XlaJitCompiledCpuFunction()62 XlaJitCompiledCpuFunction() {} 63 64 // The executable holds the underlying function. 65 std::unique_ptr<xla::LocalExecutable> executable_; 66 67 // The static data is backed by the rest of the state in this class. 68 XlaCompiledCpuFunction::StaticData static_data_; 69 70 // The backing array for buffer infos. 71 std::vector<xla::cpu_function_runtime::BufferInfo> buffer_infos_; 72 73 // The backing array for the arg index table. 74 std::vector<int32> arg_index_table_; 75 76 // The backing arrays of arg and result names. We hold the actual strings in 77 // nonempty_*_names_, and hold arrays of pointers in *_names_ for the static 78 // data to refer to. 79 std::vector<string> nonempty_arg_names_; 80 std::vector<string> nonempty_variable_names_; 81 std::vector<string> nonempty_result_names_; 82 std::vector<const char*> arg_names_; 83 std::vector<const char*> variable_names_; 84 std::vector<const char*> result_names_; 85 86 // The backing data for the program shape. The proto form of program shape is 87 // used because the program shape is serialized and embedded in the object 88 // file. 89 std::unique_ptr<const xla::ProgramShapeProto> program_shape_; 90 }; 91 92 } // namespace tensorflow 93 94 #endif // TENSORFLOW_COMPILER_TF2XLA_XLA_JIT_COMPILED_CPU_FUNCTION_H_ 95