1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/framework/collective.h"
16
17 #include "absl/strings/escaping.h"
18 #include "tensorflow/core/framework/op_kernel.h"
19 #include "tensorflow/core/lib/core/errors.h"
20 #include "tensorflow/core/lib/hash/hash.h"
21 #include "tensorflow/core/lib/strings/str_util.h"
22 #include "tensorflow/core/lib/strings/strcat.h"
23
24 namespace tensorflow {
25
26 namespace {
27 // A RegistrationInfo object stores a collective implementation registration
28 // details. `factory` is used to create instances of the collective
29 // implementation.
30 struct RegistrationInfo {
31 // This constructor also creates, and stores in `param_resolver_instance`,
32 // what is effectively a static instance of the collective implementation.
33 // During param resolution of collective ops we return this static instance.
34 // The actual op execution gets a fresh instance using `factory`.
RegistrationInfotensorflow::__anond8d2d08f0111::RegistrationInfo35 RegistrationInfo(const string& n, CollectiveRegistry::Factory f)
36 : name(n),
37 factory(std::move(f)),
38 param_resolver_instance(this->factory()) {}
39 string name;
40 CollectiveRegistry::Factory factory;
41 CollectiveImplementationInterface* param_resolver_instance;
42 };
43
MutableCollectiveRegistry()44 std::vector<RegistrationInfo>* MutableCollectiveRegistry() {
45 static std::vector<RegistrationInfo>* registry =
46 new std::vector<RegistrationInfo>;
47 return registry;
48 }
49 } // namespace
50
ToString() const51 string CollGroupRuntimeDetails::ToString() const {
52 return strings::StrCat("CollGroupRuntimeDetails {communicator_key=",
53 absl::CEscape(communicator_key), "}");
54 }
55
ToString() const56 string CollGroupParams::ToString() const {
57 string v = strings::StrCat(
58 "CollGroupParams {group_key=", group_key, " group_size=", group_size,
59 " device_type=", device_type.type_string(), " num_tasks=", num_tasks,
60 " runtime_details=", runtime_details.ToString(), " devices {");
61 for (const auto& d : device_names) {
62 strings::StrAppend(&v, d, ",");
63 }
64 strings::StrAppend(&v, "} task_names={");
65 for (const auto& n : task_names) {
66 strings::StrAppend(&v, n, ", ");
67 }
68 strings::StrAppend(&v, "} num_devices_per_task={");
69 for (const auto& dpt : num_devices_per_task) {
70 strings::StrAppend(&v, dpt.first, ": ", dpt.second, ", ");
71 }
72 strings::StrAppend(&v, "}");
73 return v;
74 }
75
// Copies the instance-level fields of `other` into this object and returns
// *this. Self-assignment is a no-op.
//
// Only the members assigned below are copied; any other member is left as it
// was.
// NOTE(review): impl_details.collective_name (printed by ToString) is not
// copied, exactly as in the previous implementation. Confirm this is
// intentional before relying on it being carried over by assignment.
CollInstanceParams& CollInstanceParams::operator=(
    const CollInstanceParams& other) {
  if (this != &other) {
    instance_key = other.instance_key;
    type = other.type;
    data_type = other.data_type;
    shape = other.shape;
    // Container assignment performs the same element-wise deep copy as the
    // former assign()/clear()+push_back() code, without copying every inner
    // permutation vector twice (the old `for (auto p : ...)` loop made a
    // by-value copy of each one before copying it again into push_back).
    impl_details.subdiv_offsets = other.impl_details.subdiv_offsets;
    impl_details.subdiv_permutations = other.impl_details.subdiv_permutations;
    impl_details.subdiv_source_rank = other.impl_details.subdiv_source_rank;
    impl_details.dependencies = other.impl_details.dependencies;
    devices = other.devices;
    permutation = other.permutation;
  }
  return *this;
}
100
ToString() const101 string CollInstanceParams::ToString() const {
102 string v =
103 strings::StrCat("CollInstanceParams { instance_key=", instance_key,
104 " type=", type, " data_type=", DataTypeString(data_type),
105 " shape=", shape.DebugString(), " devices {");
106 strings::StrAppend(&v, "}, collective_name=", impl_details.collective_name,
107 ", subdiv_offsets={");
108 strings::StrAppend(&v, "}, subdiv_offsets={");
109 for (const auto& d : impl_details.subdiv_offsets) {
110 strings::StrAppend(&v, d, ",");
111 }
112 strings::StrAppend(&v, "}, subdiv_perms={");
113 for (const auto& p : impl_details.subdiv_permutations) {
114 strings::StrAppend(&v, "{");
115 for (const auto& i : p) {
116 strings::StrAppend(&v, i, ",");
117 }
118 strings::StrAppend(&v, "}"); // one subdiv
119 }
120 if (!impl_details.subdiv_source_rank.empty()) {
121 strings::StrAppend(&v, " subdiv_source_rank={");
122 for (const auto& r : impl_details.subdiv_source_rank) {
123 strings::StrAppend(&v, r, ",");
124 }
125 strings::StrAppend(&v, "}");
126 } // all subdivs
127 if (type == PERMUTE_COLLECTIVE) {
128 strings::StrAppend(&v, "}, permute_devices {");
129 for (const auto& d : devices) {
130 strings::StrAppend(&v, d, ",");
131 }
132 strings::StrAppend(&v, "}, permute_permutation {");
133 for (const auto& p : permutation) {
134 strings::StrAppend(&v, p, ",");
135 }
136 strings::StrAppend(&v, "}");
137 }
138 return v;
139 }
140
ToString() const141 string CollTaskParams::ToString() const {
142 string v = strings::StrCat("CollTaskParams {is_local={");
143 for (const auto& b : is_local) {
144 strings::StrAppend(&v, static_cast<int>(b), ",");
145 }
146 strings::StrAppend(&v, "}}");
147 return v;
148 }
149
ToString() const150 string CollectiveParams::ToString() const {
151 string v = strings::StrCat("CollectiveParams ", name, " {", group.ToString());
152 strings::StrAppend(&v, " ", instance.ToString());
153 strings::StrAppend(&v, " ", task.ToString());
154 strings::StrAppend(&v, " default_rank=", default_rank,
155 " is_source=", is_source, " source_rank=", source_rank,
156 " subdiv_rank={");
157 for (const auto& r : subdiv_rank) {
158 strings::StrAppend(&v, r, ",");
159 }
160 strings::StrAppend(&v, "}}");
161 return v;
162 }
163
CtxParams(OpKernelContext * ctx)164 /*static*/ OpKernelContext::Params* CollectiveExecutor::CtxParams(
165 OpKernelContext* ctx) {
166 return ctx->params_;
167 }
168
CollectiveContext(CollectiveExecutor * col_exec,NcclCommunicatorInterface * nccl_communicator,const DeviceMgr * dev_mgr,OpKernelContext * ctx,OpKernelContext::Params * op_params,const CollectiveParams * col_params,const string & exec_key,int64 step_id,const Tensor * input,Tensor * output)169 CollectiveContext::CollectiveContext(
170 CollectiveExecutor* col_exec, NcclCommunicatorInterface* nccl_communicator,
171 const DeviceMgr* dev_mgr, OpKernelContext* ctx,
172 OpKernelContext::Params* op_params, const CollectiveParams* col_params,
173 const string& exec_key, int64 step_id, const Tensor* input, Tensor* output)
174 : col_exec(col_exec),
175 nccl_communicator(nccl_communicator),
176 dev_mgr(dev_mgr),
177 op_ctx(ctx),
178 op_params(op_params),
179 col_params(col_params),
180 exec_key(exec_key),
181 step_id(step_id),
182 input(input),
183 output(output),
184 device(nullptr),
185 device_name(col_params->group.device_names[col_params->default_rank]) {}
186
187 /*static*/
188 int64 CollectiveExecutor::kInvalidId = -1;
189
190 /*static*/
Lookup(const string & collective_name,CollectiveImplementationInterface ** implementation)191 Status CollectiveRegistry::Lookup(
192 const string& collective_name,
193 CollectiveImplementationInterface** implementation) {
194 return LookupHelper(collective_name, implementation, false);
195 }
196
197 /*static*/
LookupParamResolverInstance(const string & collective_name,CollectiveImplementationInterface ** implementation)198 Status CollectiveRegistry::LookupParamResolverInstance(
199 const string& collective_name,
200 CollectiveImplementationInterface** implementation) {
201 return LookupHelper(collective_name, implementation, true);
202 }
203
204 /*static*/
GetAll(std::vector<CollectiveImplementationInterface * > * implementations)205 void CollectiveRegistry::GetAll(
206 std::vector<CollectiveImplementationInterface*>* implementations) {
207 std::vector<RegistrationInfo>* registry = MutableCollectiveRegistry();
208 for (const RegistrationInfo& reg_info : *registry)
209 implementations->emplace_back(reg_info.factory());
210 }
211
212 /*static*/
Register(const string & collective_name,Factory factory)213 Status CollectiveRegistry::Register(const string& collective_name,
214 Factory factory) {
215 std::vector<RegistrationInfo>* registry = MutableCollectiveRegistry();
216 for (const RegistrationInfo& reg_info : *registry) {
217 if (reg_info.name == collective_name)
218 return errors::Internal("Already registered collective ",
219 collective_name);
220 }
221 registry->emplace_back(collective_name, std::move(factory));
222 return Status::OK();
223 }
224
225 /*static*/
LookupHelper(const string & collective_name,CollectiveImplementationInterface ** implementation,bool param_resolver)226 Status CollectiveRegistry::LookupHelper(
227 const string& collective_name,
228 CollectiveImplementationInterface** implementation, bool param_resolver) {
229 std::vector<RegistrationInfo>* registry = MutableCollectiveRegistry();
230 for (const RegistrationInfo& reg_info : *registry) {
231 if (reg_info.name == collective_name) {
232 if (param_resolver) {
233 *implementation = reg_info.param_resolver_instance;
234 } else {
235 *implementation = reg_info.factory();
236 }
237 return Status::OK();
238 }
239 }
240 return errors::Internal(
241 "CollectiveRegistry::Lookup did not find collective implementation ",
242 collective_name);
243 }
244
245 } // namespace tensorflow
246