1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/local_service.h"
17
18 #include <string>
19 #include <utility>
20 #include <vector>
21
22 #include "absl/memory/memory.h"
23 #include "absl/strings/str_cat.h"
24 #include "absl/strings/str_format.h"
25 #include "tensorflow/compiler/xla/client/executable_build_options.h"
26 #include "tensorflow/compiler/xla/client/xla_computation.h"
27 #include "tensorflow/compiler/xla/execution_options_util.h"
28 #include "tensorflow/compiler/xla/service/backend.h"
29 #include "tensorflow/compiler/xla/service/computation_layout.h"
30 #include "tensorflow/compiler/xla/service/executable.h"
31 #include "tensorflow/compiler/xla/service/hlo_computation.h"
32 #include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
33 #include "tensorflow/compiler/xla/service/hlo_module.h"
34 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
35 #include "tensorflow/compiler/xla/service/hlo_module_util.h"
36 #include "tensorflow/compiler/xla/service/platform_util.h"
37 #include "tensorflow/compiler/xla/shape_layout.h"
38 #include "tensorflow/compiler/xla/shape_util.h"
39 #include "tensorflow/compiler/xla/status_macros.h"
40 #include "tensorflow/compiler/xla/types.h"
41 #include "tensorflow/compiler/xla/util.h"
42 #include "tensorflow/core/lib/gtl/cleanup.h"
43 #include "tensorflow/core/platform/logging.h"
44 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
45
46 namespace xla {
47
NewService(const ServiceOptions & options)48 /* static */ StatusOr<std::unique_ptr<LocalService>> LocalService::NewService(
49 const ServiceOptions& options) {
50 se::Platform* platform = options.platform();
51 if (platform == nullptr) {
52 TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform());
53 }
54
55 BackendOptions backend_options;
56 backend_options.set_platform(platform)
57 .set_intra_op_parallelism_threads(options.intra_op_parallelism_threads())
58 .set_allowed_devices(options.allowed_devices());
59
60 TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> backend,
61 Backend::CreateBackend(backend_options));
62
63 std::unique_ptr<LocalService> service(
64 new LocalService(options, std::move(backend)));
65 return std::move(service);
66 }
67
// Delegates construction entirely to the base Service class; LocalService
// adds no state of its own beyond what Service holds.
LocalService::LocalService(const ServiceOptions& options,
                           std::unique_ptr<Backend> execute_backend)
    : Service(options, std::move(execute_backend)) {}
71
72 namespace {
73
74 // Retrieves the parameter metadata for the given computation and parameter
75 // number.
76 //
77 // If the parameter number is invalid for this computation, nullopt is
78 // returned. When the return value has_value(), nullptr will never be
79 // the held value.
ParameterMetadata(const XlaComputation & computation,int parameter_number)80 absl::optional<const OpMetadata*> ParameterMetadata(
81 const XlaComputation& computation, int parameter_number) {
82 for (const HloComputationProto& comp : computation.proto().computations()) {
83 if (comp.id() == computation.proto().entry_computation_id()) {
84 for (const HloInstructionProto& instr : comp.instructions()) {
85 if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter) &&
86 instr.parameter_number() == parameter_number) {
87 if (!instr.has_metadata()) {
88 return absl::nullopt;
89 }
90 return &instr.metadata();
91 }
92 }
93 }
94 }
95 return absl::nullopt;
96 }
97
98 } // namespace
99
// Compiles `computation` into one or more executables for this service's
// backend, after validating the caller-supplied argument layouts and the
// (optional) result layout against the computation's program shape.
StatusOr<std::vector<std::unique_ptr<Executable>>>
LocalService::CompileExecutables(
    const XlaComputation& computation,
    const absl::Span<const Shape* const> argument_layouts,
    const ExecutableBuildOptions& build_options) {
  const HloModuleProto& proto = computation.proto();
  TF_RET_CHECK(proto.has_host_program_shape());
  ProgramShape program_shape(proto.host_program_shape());

  // Validate incoming layouts: arity first, then per-argument shape
  // compatibility against the program shape's declared parameters.
  if (argument_layouts.size() != program_shape.parameters_size()) {
    return InvalidArgument(
        "Invalid number of arguments for computation: expected %d, got %u.",
        program_shape.parameters_size(), argument_layouts.size());
  }

  for (int i = 0; i < argument_layouts.size(); ++i) {
    const Shape& argument_shape = *argument_layouts[i];
    TF_RETURN_IF_ERROR(
        ShapeUtil::ValidateShapeWithOptionalLayout(argument_shape));
    if (!ShapeUtil::Compatible(argument_shape, program_shape.parameters(i))) {
      // Pull the parameter's source location (if recorded) into the error
      // message to help users find the offending argument.
      absl::optional<const OpMetadata*> metadata =
          ParameterMetadata(computation, /*parameter_number=*/i);
      auto metadata_string = [&metadata]() -> string {
        if (!metadata.has_value()) {
          return "";
        }
        CHECK(metadata.value() != nullptr);
        const OpMetadata& m = *metadata.value();
        if (!m.source_file().empty()) {
          return absl::StrFormat(" (%s:%d)", m.source_file(), m.source_line());
        }
        return "";
      };
      return InvalidArgument(
          "Invalid argument shape for argument %d%s, expected %s, got %s.", i,
          metadata_string(),
          ShapeUtil::HumanString(program_shape.parameters(i)),
          ShapeUtil::HumanString(argument_shape));
    }
  }
  // The result layout is optional; when given it must match the program
  // shape's result.
  if (build_options.result_layout() != nullptr) {
    TF_RETURN_IF_ERROR(ValidateResultShape(*build_options.result_layout(),
                                           program_shape.result()));
  }

  ExecutionOptions execution_options =
      CreateExecutionOptions(build_options, &program_shape);

  TF_ASSIGN_OR_RETURN(
      std::unique_ptr<HloModuleConfig> module_config,
      CreateModuleConfig(program_shape, argument_layouts, &execution_options));

  VLOG(3) << "Computation Layout: "
          << module_config->entry_computation_layout().ToString();

  TF_ASSIGN_OR_RETURN(
      se::StreamExecutor * executor,
      execute_backend_->stream_executor(build_options.device_ordinal()));

  // TODO(cjfj): Investigate why there are a couple of test failures when the
  // single partition computations are built using `BuildExecutables`, fix it,
  // and remove this special case (provided the performance is similar).
  if (build_options.num_partitions() == 1) {
    // Single-partition path: compile exactly one executable via
    // BuildExecutable and wrap it in a vector.
    TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
                        BuildExecutable(proto, std::move(module_config),
                                        execute_backend_.get(), executor,
                                        {build_options.device_allocator(),
                                         build_options.compile_thread_pool()},
                                        build_options.run_backend_only()));
    std::vector<std::unique_ptr<Executable>> executables;
    executables.push_back(std::move(executable));
    return executables;
  } else {
    std::vector<std::unique_ptr<HloModuleConfig>> module_configs;
    module_configs.push_back(std::move(module_config));
    // BuildExecutables uses the executors length to determine the number of
    // cores per module, but otherwise only uses the first executor.
    std::vector<se::StreamExecutor*> executors(build_options.num_partitions(),
                                               executor);

    return BuildExecutables(
        /*module_protos=*/{&proto}, std::move(module_configs),
        execute_backend_.get(), {executors},
        Compiler::CompileOptions{build_options.device_allocator(),
                                 build_options.compile_thread_pool()},
        build_options.run_backend_only());
  }
}
189
// Maps a replica number to a device ordinal by delegating to the backend's
// computation placer, modeling a single computation (computation 0 of 1)
// replicated options_.number_of_replicas() times.
StatusOr<int> LocalService::ReplicaNumberToDeviceOrdinal(int replica_number) {
  return backend().computation_placer()->DeviceId(
      replica_number, /*computation=*/0, options_.number_of_replicas(),
      /*computation_count=*/1);
}
195
GlobalDataToShapedBuffer(const GlobalDataHandle & data,int replica_number)196 StatusOr<const ShapedBuffer*> LocalService::GlobalDataToShapedBuffer(
197 const GlobalDataHandle& data, int replica_number) {
198 TF_ASSIGN_OR_RETURN(auto buffers, allocation_tracker_.Resolve(data));
199 if (replica_number >= buffers.size()) {
200 return InvalidArgument(
201 "replica_number %d out of range; must be less than num_replicas = %u.",
202 replica_number, buffers.size());
203 }
204 return buffers[replica_number];
205 }
206
// Registers a set of per-replica buffers with the allocation tracker,
// transferring ownership, and returns the resulting global data handle.
// `tag` is an identifying label passed through to the tracker.
StatusOr<GlobalDataHandle> LocalService::RegisterReplicatedBuffers(
    std::vector<ScopedShapedBuffer> replicated_buffers, const string& tag) {
  return allocation_tracker_.RegisterReplicatedBuffers(
      std::move(replicated_buffers), tag);
}
212
213 } // namespace xla
214