1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/local_service.h"
17
18 #include <string>
19 #include <utility>
20 #include <vector>
21
22 #include "absl/memory/memory.h"
23 #include "absl/strings/str_cat.h"
24 #include "absl/strings/str_format.h"
25 #include "tensorflow/compiler/xla/client/executable_build_options.h"
26 #include "tensorflow/compiler/xla/client/xla_computation.h"
27 #include "tensorflow/compiler/xla/execution_options_util.h"
28 #include "tensorflow/compiler/xla/service/backend.h"
29 #include "tensorflow/compiler/xla/service/computation_layout.h"
30 #include "tensorflow/compiler/xla/service/executable.h"
31 #include "tensorflow/compiler/xla/service/hlo_computation.h"
32 #include "tensorflow/compiler/xla/service/hlo_execution_profile.h"
33 #include "tensorflow/compiler/xla/service/hlo_module.h"
34 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
35 #include "tensorflow/compiler/xla/service/platform_util.h"
36 #include "tensorflow/compiler/xla/shape_layout.h"
37 #include "tensorflow/compiler/xla/shape_util.h"
38 #include "tensorflow/compiler/xla/status_macros.h"
39 #include "tensorflow/compiler/xla/types.h"
40 #include "tensorflow/compiler/xla/util.h"
41 #include "tensorflow/core/lib/gtl/cleanup.h"
42 #include "tensorflow/core/platform/logging.h"
43 #include "tensorflow/core/platform/stream_executor_no_cuda.h"
44
45 namespace xla {
46
NewService(const ServiceOptions & options)47 /* static */ StatusOr<std::unique_ptr<LocalService>> LocalService::NewService(
48 const ServiceOptions& options) {
49 se::Platform* platform = options.platform();
50 if (platform == nullptr) {
51 TF_ASSIGN_OR_RETURN(platform, PlatformUtil::GetDefaultPlatform());
52 }
53
54 BackendOptions backend_options;
55 backend_options.set_platform(platform)
56 .set_intra_op_parallelism_threads(options.intra_op_parallelism_threads())
57 .set_allowed_devices(options.allowed_devices());
58
59 TF_ASSIGN_OR_RETURN(std::unique_ptr<Backend> backend,
60 Backend::CreateBackend(backend_options));
61
62 std::unique_ptr<LocalService> service(
63 new LocalService(options, std::move(backend)));
64 return std::move(service);
65 }
66
LocalService(const ServiceOptions & options,std::unique_ptr<Backend> execute_backend)67 LocalService::LocalService(const ServiceOptions& options,
68 std::unique_ptr<Backend> execute_backend)
69 : Service(options, std::move(execute_backend)) {}
70
71 namespace {
72
73 // Retrieves the parameter metadata for the given computation and parameter
74 // number.
75 //
76 // If the parameter number is invalid for this computation, nullopt is
77 // returned. When the return value has_value(), nullptr will never be
78 // the held value.
ParameterMetadata(const XlaComputation & computation,int parameter_number)79 absl::optional<const OpMetadata*> ParameterMetadata(
80 const XlaComputation& computation, int parameter_number) {
81 for (const HloComputationProto& comp : computation.proto().computations()) {
82 if (comp.id() == computation.proto().entry_computation_id()) {
83 for (const HloInstructionProto& instr : comp.instructions()) {
84 if (instr.opcode() == HloOpcodeString(HloOpcode::kParameter) &&
85 instr.parameter_number() == parameter_number) {
86 if (!instr.has_metadata()) {
87 return absl::nullopt;
88 }
89 return &instr.metadata();
90 }
91 }
92 }
93 }
94 return absl::nullopt;
95 }
96
CreateExecutionOptions(const ExecutableBuildOptions & build_options,const ProgramShape * program_shape)97 ExecutionOptions CreateExecutionOptions(
98 const ExecutableBuildOptions& build_options,
99 const ProgramShape* program_shape) {
100 ExecutionOptions execution_options = CreateDefaultExecutionOptions();
101 if (build_options.has_debug_options()) {
102 *execution_options.mutable_debug_options() = build_options.debug_options();
103 }
104 if (build_options.result_layout() != nullptr) {
105 *execution_options.mutable_shape_with_output_layout() =
106 build_options.result_layout()->ToProto();
107 } else {
108 Shape result_shape(program_shape->result());
109 LayoutUtil::SetToDefaultLayout(&result_shape);
110 *execution_options.mutable_shape_with_output_layout() =
111 result_shape.ToProto();
112 }
113 execution_options.set_num_replicas(build_options.num_replicas());
114 return execution_options;
115 }
116
117 } // namespace
118
CompileExecutable(const XlaComputation & computation,const absl::Span<const Shape * const> argument_layouts,const ExecutableBuildOptions & build_options)119 StatusOr<std::unique_ptr<Executable>> LocalService::CompileExecutable(
120 const XlaComputation& computation,
121 const absl::Span<const Shape* const> argument_layouts,
122 const ExecutableBuildOptions& build_options) {
123 const HloModuleProto& proto = computation.proto();
124 TF_RET_CHECK(proto.has_host_program_shape());
125 ProgramShape program_shape(proto.host_program_shape());
126
127 // Validate incoming layouts.
128 if (argument_layouts.size() != program_shape.parameters_size()) {
129 return InvalidArgument(
130 "Invalid number of arguments for computation: expected %d, got %u.",
131 program_shape.parameters_size(), argument_layouts.size());
132 }
133
134 for (int i = 0; i < argument_layouts.size(); ++i) {
135 const Shape& argument_shape = *argument_layouts[i];
136 TF_RETURN_IF_ERROR(
137 ShapeUtil::ValidateShapeWithOptionalLayout(argument_shape));
138 if (!ShapeUtil::Compatible(argument_shape, program_shape.parameters(i))) {
139 absl::optional<const OpMetadata*> metadata =
140 ParameterMetadata(computation, /*parameter_number=*/i);
141 auto metadata_string = [&metadata]() -> string {
142 if (!metadata.has_value()) {
143 return "";
144 }
145 CHECK(metadata.value() != nullptr);
146 const OpMetadata& m = *metadata.value();
147 if (!m.source_file().empty()) {
148 return absl::StrFormat(" (%s:%d)", m.source_file(), m.source_line());
149 }
150 return "";
151 };
152 return InvalidArgument(
153 "Invalid argument shape for argument %d%s, expected %s, got %s.", i,
154 metadata_string(),
155 ShapeUtil::HumanString(program_shape.parameters(i)),
156 ShapeUtil::HumanString(argument_shape));
157 }
158 }
159 if (build_options.result_layout() != nullptr) {
160 TF_RETURN_IF_ERROR(ValidateResultShape(*build_options.result_layout(),
161 program_shape.result()));
162 }
163
164 ExecutionOptions execution_options =
165 CreateExecutionOptions(build_options, &program_shape);
166
167 TF_ASSIGN_OR_RETURN(
168 std::unique_ptr<HloModuleConfig> module_config,
169 CreateModuleConfig(program_shape, argument_layouts, &execution_options));
170
171 VLOG(3) << "Computation Layout: "
172 << module_config->entry_computation_layout().ToString();
173
174 TF_ASSIGN_OR_RETURN(
175 se::StreamExecutor * executor,
176 execute_backend_->stream_executor(build_options.device_ordinal()));
177
178 return BuildExecutable(proto, std::move(module_config),
179 execute_backend_.get(), executor,
180 build_options.device_allocator());
181 }
182
ReplicaNumberToDeviceOrdinal(int replica_number)183 StatusOr<int> LocalService::ReplicaNumberToDeviceOrdinal(int replica_number) {
184 return backend().computation_placer()->DeviceId(
185 replica_number, /*computation=*/0, options_.number_of_replicas(),
186 /*computation_count=*/1);
187 }
188
GlobalDataToShapedBuffer(const GlobalDataHandle & data,int replica_number)189 StatusOr<const ShapedBuffer*> LocalService::GlobalDataToShapedBuffer(
190 const GlobalDataHandle& data, int replica_number) {
191 TF_ASSIGN_OR_RETURN(auto buffers, allocation_tracker_.Resolve(data));
192 if (replica_number >= buffers.size()) {
193 return InvalidArgument(
194 "replica_number %d out of range; must be less than num_replicas = %u.",
195 replica_number, buffers.size());
196 }
197 return buffers[replica_number];
198 }
199
RegisterReplicatedBuffers(std::vector<ScopedShapedBuffer> replicated_buffers,const string & tag)200 StatusOr<GlobalDataHandle> LocalService::RegisterReplicatedBuffers(
201 std::vector<ScopedShapedBuffer> replicated_buffers, const string& tag) {
202 return allocation_tracker_.RegisterReplicatedBuffers(
203 std::move(replicated_buffers), tag);
204 }
205
206 } // namespace xla
207