/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/get_compiler_ir.h"

#include <deque>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/algorithm/container.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
#include "tensorflow/compiler/jit/compilability_check_util.h"
#include "tensorflow/compiler/jit/defs.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/xla_launch_util.h"
#include "tensorflow/compiler/jit/xla_platform_info.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/xla/client/executable_build_options.h"
#include "tensorflow/compiler/xla/client/local_client.h"
#include "tensorflow/compiler/xla/service/hlo_graph_dumper.h"
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/util/ptr_util.h"

namespace tensorflow {

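// Builds a LocalExecutable from an already-compiled XLA computation. The
// OPTIMIZED_* export stages below use this to obtain post-optimization HLO.
// When `xla_dump_hlo` is true, the executable also carries an hlo_proto
// (HLO modules plus buffer assignment) that can be serialized afterwards.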
static StatusOr<std::unique_ptr<xla::LocalExecutable>> BuildExecutable(
    xla::LocalClient* local_client,
    const XlaCompiler::CompilationResult& result,
    const XlaCompiler::Options& options, const bool xla_dump_hlo = false) {
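  // XLA takes the input layouts as raw pointers; they point into `result`,
  // which outlives the Compile() call below.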
  std::vector<const xla::Shape*> argument_layouts(
      result.xla_input_shapes.size());
  for (int i = 0, end = result.xla_input_shapes.size(); i < end; ++i) {
    argument_layouts[i] = &result.xla_input_shapes[i];
  }
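  // Computations containing collective ops must be built with the replica
  // count baked in, so propagate the collective group size when present.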
  xla::ExecutableBuildOptions build_options;
  if (result.collective_reduce_info) {
    build_options.set_num_replicas(result.collective_reduce_info->group_size);
  }
  build_options.set_device_ordinal(
      options.device_ordinal != -1 ? options.device_ordinal
                                   : local_client->default_device_ordinal());
  build_options.set_result_layout(result.xla_output_shape);
  build_options.set_device_allocator(options.device_allocator.get());
  build_options.set_alias_passthrough_params(options.alias_passthrough_params);
  build_options.mutable_debug_options()->set_xla_detailed_logging_and_dumping(
      options.detailed_logging);
  // If any of the xla_dump_hlo_* flags are set, an hlo_proto (containing the
  // HLO modules and the buffer assignment) is embedded in the executable.
  build_options.mutable_debug_options()->set_xla_dump_hlo_as_text(xla_dump_hlo);
  TF_ASSIGN_OR_RETURN(
      std::vector<std::unique_ptr<xla::LocalExecutable>> executables,
      local_client->Compile(*result.computation, argument_layouts,
                            build_options));
  TF_RET_CHECK(executables.size() == 1);
  return std::move(executables[0]);
}

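// Returns the IR for `func_name` at the requested export stage, compiled for
// `dev` against the given concrete inputs. A minimal (hypothetical) call
// site, assuming `pflr`, `device`, `ctx`, and `handles` were obtained from a
// live EagerContext:
//
//   StatusOr<std::string> hlo_text = GetCompilerIr(
//       IrExportStage::HLO, pflr, "my_func", device, ctx, handles);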
StatusOr<std::string> GetCompilerIr(
    IrExportStage stage, ProcessFunctionLibraryRuntime* pflr,
    absl::string_view func_name, Device* dev, EagerContext* context,
    absl::Span<const TensorHandle* const> inputs_handles) {
  NameAttrList function;
  function.set_name(std::string{func_name});

  FunctionLibraryRuntime* flr = pflr->GetFLR(dev->name());
  ResourceMgr* rmgr = dev->resource_manager();

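  // Instantiate the function body and classify its arguments: which indices
  // must be compile-time constants and which are resource (DT_RESOURCE)
  // arguments. Both index vectors come back sorted, which the binary search
  // below relies on.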
  const FunctionBody* fbody = nullptr;
  std::vector<int> constant_arg_indices;
  std::vector<int> resource_arg_indices;
  TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
      flr, function, &fbody, &constant_arg_indices, &resource_arg_indices));

  MemoryTypeVector input_memory_types =
      GetInputMemoryTypes(fbody, constant_arg_indices, resource_arg_indices);
  MemoryTypeVector output_memory_types = GetOutputMemoryTypes(fbody);

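  // Gather the concrete input tensors. Constant arguments are copied to the
  // host; `inputs_storage` keeps those copies alive, and a deque is used so
  // the pointers stay stable as it grows.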
  std::deque<Tensor> inputs_storage;
  std::vector<const Tensor*> inputs;
  inputs.reserve(inputs_handles.size());
  for (int i = 0; i < inputs_handles.size(); i++) {
    const TensorHandle* th = inputs_handles[i];
    const Tensor* t;
    // The handle owns the tensor.
    TF_RETURN_IF_ERROR(th->Tensor(&t));
    if (absl::c_binary_search(constant_arg_indices, i)) {
      // The compiler reads constant arguments by value, so make sure the
      // tensor is on the host.
      inputs_storage.emplace_back(t->dtype(), t->shape());
      TF_RETURN_IF_ERROR(
          th->CopyToDevice(*context, /*d=*/nullptr, &inputs_storage.back()));
      inputs.push_back(&inputs_storage.back());
    } else {
      inputs.push_back(t);
    }
  }

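  // Snapshot the resource variables the function touches and hold their
  // locks so shapes and dtypes cannot change while we compile.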
  std::vector<VariableInfo> variable_infos;
  TF_RETURN_IF_ERROR(GetVariableInfosFromInputs(
      rmgr, dev, inputs, resource_arg_indices, &variable_infos));
  TF_RETURN_IF_ERROR(LockVariables(absl::MakeSpan(variable_infos)));

  XlaPlatformInfo platform_info = XlaPlatformInfoFromDevice(dev);

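  // Fetch (or lazily create) the per-device XLA compilation cache; it owns
  // the xla::LocalClient used to build executables below.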
  XlaCompilationCache* cache;
  TF_RETURN_IF_ERROR(rmgr->LookupOrCreate<XlaCompilationCache>(
      rmgr->default_container(), "xla_cache", &cache,
      [&](XlaCompilationCache** cache_write_into) {
        return BuildXlaCompilationCache(dev, flr, platform_info,
                                        cache_write_into);
      }));
  core::ScopedUnref cache_ref(cache);

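  // On GPU the compiler options need the device's compute stream; on CPU
  // there is none and `stream` stays null.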
  se::Stream* stream = nullptr;
  if (const DeviceBase::GpuDeviceInfo* gpu_device_info =
          dev->tensorflow_gpu_device_info()) {
    stream = gpu_device_info->stream;
  }

  XlaCompiler::Options options =
      GenerateCompilerOptions(*cache, *flr, dev, stream, platform_info,
                              /*has_ref_vars=*/false);

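  // `always_return_tuple = false` keeps multi-output computations un-tupled;
  // `alias_resource_update = true` lets variable updates alias their input
  // buffers, so the exported IR matches what a real launch would build.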
  XlaCompiler::CompileOptions compile_options;
  compile_options.always_return_tuple = false;
  compile_options.alias_resource_update = true;

  XlaCompiler compiler(options);

  StatusOr<std::vector<XlaCompiler::Argument>> args =
      XlaComputationLaunchContext::BuildXlaCompilerArguments(
          constant_arg_indices, inputs, variable_infos, dev);
  TF_RETURN_IF_ERROR(args.status());

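  // Run the tf2xla bridge: lower the TF function to an xla::XlaComputation.
  // This yields the unoptimized IR; the OPTIMIZED_* stages below additionally
  // run the full XLA compilation pipeline via BuildExecutable.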
  xla::LocalClient* local_client = cache->client();
  XlaCompiler::CompilationResult result;
  TF_RETURN_IF_ERROR(
      compiler.CompileFunction(compile_options, function, *args, &result));

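  // HLO / HLO_SERIALIZED come straight from the compilation result; the
  // OPTIMIZED_* stages rebuild the module through BuildExecutable so the
  // post-optimization HLO (or its proto / dot rendering) can be exported.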
  switch (stage) {
    case IrExportStage::HLO:
    case IrExportStage::HLO_SERIALIZED: {
      TF_ASSIGN_OR_RETURN(xla::ProgramShape program_shape,
                          result.computation->GetProgramShape());
      xla::HloModuleConfig config(program_shape);
      TF_ASSIGN_OR_RETURN(
          std::unique_ptr<xla::HloModule> new_module,
          xla::HloModule::CreateFromProto(result.computation->proto(), config));

      if (stage == IrExportStage::HLO_SERIALIZED) {
        return new_module->ToProto().SerializeAsString();
      } else {
        return new_module->ToString();
      }
    }
    case IrExportStage::OPTIMIZED_HLO:
    case IrExportStage::OPTIMIZED_HLO_SERIALIZED: {
      TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::LocalExecutable> executable,
                          BuildExecutable(local_client, result, options));
      xla::Executable* new_executable = executable->executable();
      if (stage == IrExportStage::OPTIMIZED_HLO_SERIALIZED) {
        return new_executable->module().ToProto().SerializeAsString();
      } else {
        return new_executable->module().ToString();
      }
    }
    case IrExportStage::OPTIMIZED_HLO_PROTO_SERIALIZED: {
      TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::LocalExecutable> executable,
                          BuildExecutable(local_client, result, options,
                                          /*xla_dump_hlo=*/true));
      return executable->executable()->hlo_proto()->SerializeAsString();
    }
    case IrExportStage::OPTIMIZED_HLO_DOT: {
      TF_ASSIGN_OR_RETURN(std::unique_ptr<xla::LocalExecutable> executable,
                          BuildExecutable(local_client, result, options));
      StatusOr<std::string> graph = xla::RenderGraph(
          *executable->executable()->module().entry_computation(),
          "Visualization",
          /*debug_options=*/{}, xla::RenderedGraphFormat::kDot,
          /*hlo_execution_profile=*/nullptr,
          /*hlo_render_options=*/{});
      TF_RETURN_IF_ERROR(graph.status());
      return *graph;
    }
  }
}

}  // namespace tensorflow