1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/framework/collective.h"
16
17 #include "absl/strings/escaping.h"
18 #include "tensorflow/core/framework/op_kernel.h"
19 #include "tensorflow/core/lib/core/errors.h"
20 #include "tensorflow/core/lib/hash/hash.h"
21 #include "tensorflow/core/lib/strings/str_util.h"
22 #include "tensorflow/core/lib/strings/strcat.h"
23
24 namespace tensorflow {
25
26 namespace {
27 // A RegistrationInfo object stores a collective implementation registration
28 // details. `factory` is used to create instances of the collective
29 // implementation.
30 struct RegistrationInfo {
31 // This constructor also creates, and stores in `param_resolver_instance`,
32 // what is effectively a static instance of the collective implementation.
33 // During param resolution of collective ops we return this static instance.
34 // The actual op execution gets a fresh instance using `factory`.
RegistrationInfotensorflow::__anon237cdfb40111::RegistrationInfo35 RegistrationInfo(const string& n, CollectiveRegistry::Factory f)
36 : name(n),
37 factory(std::move(f)),
38 param_resolver_instance(this->factory()) {}
39 string name;
40 CollectiveRegistry::Factory factory;
41 CollectiveImplementationInterface* param_resolver_instance;
42 };
43
MutableCollectiveRegistry()44 std::vector<RegistrationInfo>* MutableCollectiveRegistry() {
45 static std::vector<RegistrationInfo>* registry =
46 new std::vector<RegistrationInfo>;
47 return registry;
48 }
49 } // namespace
50
ToString() const51 string CollGroupRuntimeDetails::ToString() const {
52 return strings::StrCat("CollGroupRuntimeDetails {communicator_key=",
53 absl::CEscape(communicator_key), "}");
54 }
55
ToString() const56 string CollGroupParams::ToString() const {
57 string v = strings::StrCat(
58 "CollGroupParams {group_key=", group_key, " group_size=", group_size,
59 " device_type=", device_type.type_string(), " num_tasks=", num_tasks,
60 " runtime_details=", runtime_details.ToString(), " devices {");
61 for (const auto& d : devices) {
62 strings::StrAppend(&v, d.name(), ",");
63 }
64 strings::StrAppend(&v, "} task_names={");
65 for (const auto& n : task_names) {
66 strings::StrAppend(&v, n, ", ");
67 }
68 strings::StrAppend(&v, "} num_devices_per_task={");
69 for (const auto& dpt : num_devices_per_task) {
70 strings::StrAppend(&v, dpt.first, ": ", dpt.second, ", ");
71 }
72 strings::StrAppend(&v, "}");
73 return v;
74 }
75
// Copy-assigns the instance key, type, data type, shape, the subdivision
// vectors and dependencies of `impl_details`, and the `devices` /
// `permutation` lists from `other`. Self-assignment is a no-op.
//
// Containers are copied with plain copy-assignment; the earlier element-wise
// loop over `subdiv_permutations` copied every inner vector twice.
//
// NOTE(review): `impl_details.collective_name` (read by ToString()) is not
// copied here. That matches the previous behavior of this operator; confirm
// whether it is intentional before changing it.
CollInstanceParams& CollInstanceParams::operator=(
    const CollInstanceParams& other) {
  if (this != &other) {
    instance_key = other.instance_key;
    type = other.type;
    data_type = other.data_type;
    shape = other.shape;
    impl_details.subdiv_offsets = other.impl_details.subdiv_offsets;
    impl_details.subdiv_permutations = other.impl_details.subdiv_permutations;
    impl_details.subdiv_source_rank = other.impl_details.subdiv_source_rank;
    impl_details.dependencies = other.impl_details.dependencies;
    devices = other.devices;
    permutation = other.permutation;
  }
  return *this;
}
100
ToString() const101 string CollInstanceParams::ToString() const {
102 string v =
103 strings::StrCat("CollInstanceParams { instance_key=", instance_key,
104 " type=", type, " data_type=", DataTypeString(data_type),
105 " shape=", shape.DebugString(), " devices {");
106 strings::StrAppend(&v, "}, collective_name=", impl_details.collective_name,
107 ", subdiv_offsets={");
108 strings::StrAppend(&v, "}, subdiv_offsets={");
109 for (const auto& d : impl_details.subdiv_offsets) {
110 strings::StrAppend(&v, d, ",");
111 }
112 strings::StrAppend(&v, "}, subdiv_perms={");
113 for (const auto& p : impl_details.subdiv_permutations) {
114 strings::StrAppend(&v, "{");
115 for (const auto& i : p) {
116 strings::StrAppend(&v, i, ",");
117 }
118 strings::StrAppend(&v, "}"); // one subdiv
119 }
120 if (!impl_details.subdiv_source_rank.empty()) {
121 strings::StrAppend(&v, " subdiv_source_rank={");
122 for (const auto& r : impl_details.subdiv_source_rank) {
123 strings::StrAppend(&v, r, ",");
124 }
125 strings::StrAppend(&v, "}");
126 } // all subdivs
127 if (type == PERMUTE_COLLECTIVE) {
128 strings::StrAppend(&v, "}, permute_devices {");
129 for (const auto& d : devices) {
130 strings::StrAppend(&v, d, ",");
131 }
132 strings::StrAppend(&v, "}, permute_permutation {");
133 for (const auto& p : permutation) {
134 strings::StrAppend(&v, p, ",");
135 }
136 strings::StrAppend(&v, "}");
137 }
138 return v;
139 }
140
ToString() const141 string CollTaskParams::ToString() const {
142 string v = strings::StrCat("CollTaskParams {is_local={");
143 for (const auto& b : is_local) {
144 strings::StrAppend(&v, static_cast<int>(b), ",");
145 }
146 strings::StrAppend(&v, "}}");
147 return v;
148 }
149
ToString() const150 string CollectiveParams::ToString() const {
151 string v = strings::StrCat("CollectiveParams ", name, " {", group.ToString());
152 strings::StrAppend(&v, " ", instance.ToString());
153 strings::StrAppend(&v, " ", task.ToString());
154 strings::StrAppend(&v, " default_rank=", default_rank,
155 " is_source=", is_source, " source_rank=", source_rank,
156 " subdiv_rank={");
157 for (const auto& r : subdiv_rank) {
158 strings::StrAppend(&v, r, ",");
159 }
160 strings::StrAppend(&v, "}}");
161 return v;
162 }
163
CtxParams(OpKernelContext * ctx)164 /*static*/ OpKernelContext::Params* CollectiveExecutor::CtxParams(
165 OpKernelContext* ctx) {
166 return ctx->params_;
167 }
168
CollectiveContext(CollectiveExecutor * col_exec,NcclCommunicatorInterface * nccl_communicator,const DeviceMgr * dev_mgr,OpKernelContext * ctx,OpKernelContext::Params * op_params,const CollectiveParams * col_params,const string & exec_key,int64_t step_id,const Tensor * input,Tensor * output)169 CollectiveContext::CollectiveContext(
170 CollectiveExecutor* col_exec, NcclCommunicatorInterface* nccl_communicator,
171 const DeviceMgr* dev_mgr, OpKernelContext* ctx,
172 OpKernelContext::Params* op_params, const CollectiveParams* col_params,
173 const string& exec_key, int64_t step_id, const Tensor* input,
174 Tensor* output)
175 : col_exec(col_exec),
176 nccl_communicator(nccl_communicator),
177 dev_mgr(dev_mgr),
178 op_ctx(ctx),
179 op_params(op_params),
180 col_params(col_params),
181 exec_key(exec_key),
182 step_id(step_id),
183 input(input),
184 output(output),
185 device(nullptr),
186 device_name(col_params->group.devices[col_params->default_rank].name()) {}
187
188 /*static*/
189 int64_t CollectiveExecutor::kInvalidId = -1;
190
191 /*static*/
Lookup(const string & collective_name,CollectiveImplementationInterface ** implementation)192 Status CollectiveRegistry::Lookup(
193 const string& collective_name,
194 CollectiveImplementationInterface** implementation) {
195 return LookupHelper(collective_name, implementation, false);
196 }
197
198 /*static*/
LookupParamResolverInstance(const string & collective_name,CollectiveImplementationInterface ** implementation)199 Status CollectiveRegistry::LookupParamResolverInstance(
200 const string& collective_name,
201 CollectiveImplementationInterface** implementation) {
202 return LookupHelper(collective_name, implementation, true);
203 }
204
205 /*static*/
GetAll(std::vector<CollectiveImplementationInterface * > * implementations)206 void CollectiveRegistry::GetAll(
207 std::vector<CollectiveImplementationInterface*>* implementations) {
208 std::vector<RegistrationInfo>* registry = MutableCollectiveRegistry();
209 for (const RegistrationInfo& reg_info : *registry)
210 implementations->emplace_back(reg_info.factory());
211 }
212
213 /*static*/
Register(const string & collective_name,Factory factory)214 Status CollectiveRegistry::Register(const string& collective_name,
215 Factory factory) {
216 std::vector<RegistrationInfo>* registry = MutableCollectiveRegistry();
217 for (const RegistrationInfo& reg_info : *registry) {
218 if (reg_info.name == collective_name)
219 return errors::Internal("Already registered collective ",
220 collective_name);
221 }
222 registry->emplace_back(collective_name, std::move(factory));
223 return Status::OK();
224 }
225
226 /*static*/
LookupHelper(const string & collective_name,CollectiveImplementationInterface ** implementation,bool param_resolver)227 Status CollectiveRegistry::LookupHelper(
228 const string& collective_name,
229 CollectiveImplementationInterface** implementation, bool param_resolver) {
230 std::vector<RegistrationInfo>* registry = MutableCollectiveRegistry();
231 for (const RegistrationInfo& reg_info : *registry) {
232 if (reg_info.name == collective_name) {
233 if (param_resolver) {
234 *implementation = reg_info.param_resolver_instance;
235 } else {
236 *implementation = reg_info.factory();
237 }
238 return Status::OK();
239 }
240 }
241 return errors::Internal(
242 "CollectiveRegistry::Lookup did not find collective implementation ",
243 collective_name);
244 }
245
246 } // namespace tensorflow
247