/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/get_compiler_ir.h"

#include <deque>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/jit/compilability_check_util.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/jit/xla_platform_info.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/util/ptr_util.h"

namespace tensorflow {

static StatusOr<std::unique_ptr<xla::LocalExecutable>> BuildExecutable(
    xla::LocalClient* local_client,
    const XlaCompiler::CompilationResult& result,
    const XlaCompiler::Options& options, const bool xla_dump_hlo = false) {
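  // Point each argument layout at the corresponding compiled input shape.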
  std::vector<const xla::Shape*> argument_layouts(
      result.xla_input_shapes.size());
  for (int i = 0, end = result.xla_input_shapes.size(); i < end; ++i) {
    argument_layouts[i] = &result.xla_input_shapes[i];
  }
  xla::ExecutableBuildOptions build_options;
  if (result.collective_reduce_info) {
    build_options.set_num_replicas(result.collective_reduce_info->group_size);
  }
  build_options.set_device_ordinal(
      options.device_ordinal != -1 ? options.device_ordinal
                                   : local_client->default_device_ordinal());
  build_options.set_result_layout(result.xla_output_shape);
  build_options.set_device_allocator(options.device_allocator.get());
  build_options.set_alias_passthrough_params(options.alias_passthrough_params);
  build_options.mutable_debug_options()->set_xla_detailed_logging_and_dumping(
      options.detailed_logging);
  // If any of the xla_dump_hlo_* flags is set, the hlo_proto will be dumped
  // into the executable. The hlo_proto contains the HLO modules and the
  // buffer assignment.
  build_options.mutable_debug_options()->set_xla_dump_hlo_as_text(xla_dump_hlo);
  TF_ASSIGN_OR_RETURN(
      std::vector<std::unique_ptr<xla::LocalExecutable>> executables,
      local_client->Compile(*result.computation, argument_layouts,
                            build_options));
  TF_RET_CHECK(executables.size() == 1);
  return std::move(executables[0]);
}

StatusOr<std::string> GetCompilerIr(
    IrExportStage stage, ProcessFunctionLibraryRuntime* pflr,
    absl::string_view func_name, Device* dev, EagerContext* context,
    absl::Span<const TensorHandle* const> inputs_handles) {
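  // Compile the function through XlaCompiler, then export the requested IR
  // stage from either the unoptimized HLO computation or a freshly built
  // local executable.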
  NameAttrList function;
  function.set_name(std::string{func_name});

  FunctionLibraryRuntime* flr = pflr->GetFLR(dev->name());
  ResourceMgr* rmgr = dev->resource_manager();

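  // Look up the function body and determine which arguments must be
  // compile-time constants and which are resource variables.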
  const FunctionBody* fbody = nullptr;
  std::vector<int> constant_arg_indices;
  std::vector<int> resource_arg_indices;
  TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
      flr, function, &fbody, &constant_arg_indices, &resource_arg_indices));

  MemoryTypeVector input_memory_types =
      GetInputMemoryTypes(fbody, constant_arg_indices, resource_arg_indices);
  MemoryTypeVector output_memory_types = GetOutputMemoryTypes(fbody);

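  // Collect the input tensors. Compile-time constant arguments are copied to
  // host memory so the compiler can read their values.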
  std::deque<Tensor> inputs_storage;
  std::vector<const Tensor*> inputs;
  inputs.reserve(inputs_handles.size());
  for (int i = 0; i < inputs_handles.size(); i++) {
    const TensorHandle* th = inputs_handles[i];
    const Tensor* t;
    // Handle owns the tensor.
    TF_RETURN_IF_ERROR(th->Tensor(&t));
    if (absl::c_binary_search(constant_arg_indices, i)) {
      // Need to make sure it's on the host.
      inputs_storage.emplace_back(t->dtype(), t->shape());
      TF_RETURN_IF_ERROR(
          th->CopyToDevice(*context, /*d=*/nullptr, &inputs_storage.back()));
      inputs.push_back(&inputs_storage.back());
    } else {
      inputs.push_back(t);
    }
  }

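  // Gather and lock the resource variables referenced by the function's
  // resource arguments.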
  std::vector<VariableInfo> variable_infos;
  TF_RETURN_IF_ERROR(GetVariableInfosFromInputs(
      rmgr, dev, inputs, resource_arg_indices, &variable_infos));
  TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));

  XlaPlatformInfo platform_info = XlaPlatformInfoFromDevice(dev);

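  // Fetch (or create) the device's XLA compilation cache; its client is used
  // to build the optimized executables below.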
  XlaCompilationCache* cache;
  TF_RETURN_IF_ERROR(rmgr->LookupOrCreate<XlaCompilationCache>(
      rmgr->default_container(), "xla_cache", &cache,
      [&](XlaCompilationCache** cache_write_into) {
        return BuildXlaCompilationCache(dev, flr, platform_info,
                                        cache_write_into);
      }));
  core::ScopedUnref cache_ref(cache);

  se::Stream* stream = nullptr;
  if (const DeviceBase::GpuDeviceInfo* gpu_device_info =
          dev->tensorflow_gpu_device_info()) {
    stream = gpu_device_info->stream;
  }

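  // Assemble the XlaCompiler options for this device, stream, and platform.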
  XlaCompiler::Options options =
      GenerateCompilerOptions(*cache, *flr, dev, stream, platform_info,
                              /*has_ref_vars=*/false);

  XlaCompiler::CompileOptions compile_options;
  compile_options.always_return_tuple = false;
  compile_options.alias_resource_update = true;

  XlaCompiler compiler(options);

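  // Translate the runtime inputs and variable snapshots into XlaCompiler
  // arguments.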
  StatusOr<std::vector<XlaCompiler::Argument>> args =
      XlaComputationLaunchContext::BuildXlaCompilerArguments(
          constant_arg_indices, inputs, variable_infos, dev);
  TF_RETURN_IF_ERROR(args.status());

  xla::LocalClient* local_client = cache->client();
  XlaCompiler::CompilationResult result;
  TF_RETURN_IF_ERROR(
      compiler.CompileFunction(compile_options, function, *args, &result));

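  // The HLO stages are exported from the unoptimized XlaComputation; the
  // OPTIMIZED_* stages require building a LocalExecutable first.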
  switch (stage) {
    case IrExportStage::HLO:
    case IrExportStage::HLO_SERIALIZED: {
      TF_ASSIGN_OR_RETURN(xla::ProgramShape program_shape,
                          result.computation->GetProgramShape());
      xla::HloModuleConfig config(program_shape);
      TF_ASSIGN_OR_RETURN(
          std::unique_ptr<xla::HloModule> new_module,
          xla::HloModule::CreateFromProto(result.computation->proto(), config));

      if (stage == IrExportStage::HLO_SERIALIZED) {
        return new_module->ToProto().SerializeAsString();
      } else {
        return new_module->ToString();
      }
    }
    case IrExportStage::OPTIMIZED_HLO:
    case IrExportStage::OPTIMIZED_HLO_SERIALIZED: {
      TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::LocalExecutable> executable,
                          BuildExecutable(local_client, result, options));
      xla::Executable* new_executable = executable->executable();
      if (stage == IrExportStage::OPTIMIZED_HLO_SERIALIZED) {
        return new_executable->module().ToProto().SerializeAsString();
      } else {
        return new_executable->module().ToString();
      }
    }
    case IrExportStage::OPTIMIZED_HLO_PROTO_SERIALIZED: {
      TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::LocalExecutable> executable,
                          BuildExecutable(local_client, result, options,
                                          /*xla_dump_hlo=*/true));
      return executable->executable()->hlo_proto()->SerializeAsString();
    }
    case IrExportStage::OPTIMIZED_HLO_DOT: {
      TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::LocalExecutable> executable,
                          BuildExecutable(local_client, result, options));
      StatusOr<std::string> graph = xla::RenderGraph(
          *executable->executable()->module().entry_computation(),
          "Visualization",
          /*debug_options=*/{}, xla::RenderedGraphFormat::kDot,
          /*hlo_execution_profile=*/nullptr,
          /*hlo_render_options=*/{});
      TF_RETURN_IF_ERROR(graph.status());
      return *graph;
    }
  }
}

}  // namespace tensorflow