1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/framework/collective.h"
16
17 #include "absl/strings/escaping.h"
18 #include "tensorflow/core/framework/op_kernel.h"
19 #include "tensorflow/core/lib/core/errors.h"
20 #include "tensorflow/core/lib/hash/hash.h"
21 #include "tensorflow/core/lib/strings/str_util.h"
22 #include "tensorflow/core/lib/strings/strcat.h"
23
24 namespace tensorflow {
25
26 namespace {
27 // A RegistrationInfo object stores a collective implementation registration
28 // details. `factory` is used to create instances of the collective
29 // implementation.
30 struct RegistrationInfo {
31 // This constructor also creates, and stores in `param_resolver_instance`,
32 // what is effectively a static instance of the collective implementation.
33 // During param resolution of collective ops we return this static instance.
34 // The actual op execution gets a fresh instance using `factory`.
RegistrationInfotensorflow::__anonbe41636d0111::RegistrationInfo35 RegistrationInfo(const string& n, CollectiveRegistry::Factory f)
36 : name(n),
37 factory(std::move(f)),
38 param_resolver_instance(this->factory()) {}
39 string name;
40 CollectiveRegistry::Factory factory;
41 CollectiveImplementationInterface* param_resolver_instance;
42 };
43
MutableCollectiveRegistry()44 std::vector<RegistrationInfo>* MutableCollectiveRegistry() {
45 static std::vector<RegistrationInfo>* registry =
46 new std::vector<RegistrationInfo>;
47 return registry;
48 }
49 } // namespace
50
ToString() const51 string CollGroupRuntimeDetails::ToString() const {
52 return strings::StrCat("CollGroupRuntimeDetails {communicator_key=",
53 absl::CEscape(communicator_key), "}");
54 }
55
ToString() const56 string CollGroupParams::ToString() const {
57 string v = strings::StrCat(
58 "CollGroupParams {group_key=", group_key, " group_size=", group_size,
59 " device_type=", device_type.type_string(), " num_tasks=", num_tasks,
60 " runtime_details=", runtime_details.ToString(), " devices {");
61 for (const auto& m : members) {
62 strings::StrAppend(&v, m.device.name(), ",");
63 }
64 strings::StrAppend(&v, "} num_devices_per_task={");
65 for (const auto& dpt : num_devices_per_task) {
66 strings::StrAppend(&v, dpt.first, ": ", dpt.second, ", ");
67 }
68 strings::StrAppend(&v, "}");
69 return v;
70 }
71
// Copy-assigns the instance-level fields listed below from `other` and returns
// `*this`. Self-assignment is a no-op.
//
// NOTE(review): only the fields named in this function are copied. In
// particular `impl_details.collective_name` (and any other member of
// CollInstanceParams / its impl_details not named here) keeps its previous
// value on `*this`. Confirm whether that is intentional before relying on
// this operator for a full copy.
CollInstanceParams& CollInstanceParams::operator=(
    const CollInstanceParams& other) {
  if (this != &other) {
    instance_key = other.instance_key;
    type = other.type;
    data_type = other.data_type;
    shape = other.shape;
    // Source and destination members have identical types (same struct), so
    // whole-container assignment yields the same elements as an element-wise
    // rebuild. It also avoids copying every inner permutation vector twice:
    // once into a by-value loop variable, and again into the destination.
    impl_details.subdiv_offsets = other.impl_details.subdiv_offsets;
    impl_details.subdiv_permutations = other.impl_details.subdiv_permutations;
    impl_details.subdiv_source_rank = other.impl_details.subdiv_source_rank;
    impl_details.dependencies = other.impl_details.dependencies;
    devices = other.devices;
    permutation = other.permutation;
  }
  return *this;
}
96
ToString() const97 string CollInstanceParams::ToString() const {
98 string v =
99 strings::StrCat("CollInstanceParams { instance_key=", instance_key,
100 " type=", type, " data_type=", DataTypeString(data_type),
101 " shape=", shape.DebugString(), " devices {");
102 strings::StrAppend(&v, "}, collective_name=", impl_details.collective_name,
103 ", subdiv_offsets={");
104 strings::StrAppend(&v, "}, subdiv_offsets={");
105 for (const auto& d : impl_details.subdiv_offsets) {
106 strings::StrAppend(&v, d, ",");
107 }
108 strings::StrAppend(&v, "}, subdiv_perms={");
109 for (const auto& p : impl_details.subdiv_permutations) {
110 strings::StrAppend(&v, "{");
111 for (const auto& i : p) {
112 strings::StrAppend(&v, i, ",");
113 }
114 strings::StrAppend(&v, "}"); // one subdiv
115 }
116 if (!impl_details.subdiv_source_rank.empty()) {
117 strings::StrAppend(&v, " subdiv_source_rank={");
118 for (const auto& r : impl_details.subdiv_source_rank) {
119 strings::StrAppend(&v, r, ",");
120 }
121 strings::StrAppend(&v, "}");
122 } // all subdivs
123 if (type == PERMUTE_COLLECTIVE) {
124 strings::StrAppend(&v, "}, permute_devices {");
125 for (const auto& d : devices) {
126 strings::StrAppend(&v, d, ",");
127 }
128 strings::StrAppend(&v, "}, permute_permutation {");
129 for (const auto& p : permutation) {
130 strings::StrAppend(&v, p, ",");
131 }
132 strings::StrAppend(&v, "}");
133 }
134 return v;
135 }
136
ToString() const137 string CollectiveParams::ToString() const {
138 string v = strings::StrCat("CollectiveParams ", name, " {", group.ToString());
139 strings::StrAppend(&v, " ", instance.ToString());
140 strings::StrAppend(&v, " default_rank=", default_rank,
141 " is_source=", is_source, " source_rank=", source_rank,
142 " subdiv_rank={");
143 for (const auto& r : subdiv_rank) {
144 strings::StrAppend(&v, r, ",");
145 }
146 strings::StrAppend(&v, "}}");
147 return v;
148 }
149
CtxParams(OpKernelContext * ctx)150 /*static*/ OpKernelContext::Params* CollectiveExecutor::CtxParams(
151 OpKernelContext* ctx) {
152 return ctx->params_;
153 }
154
CollectiveContext(CollectiveExecutor * col_exec,NcclCommunicatorInterface * nccl_communicator,const DeviceMgr * dev_mgr,OpKernelContext * ctx,OpKernelContext::Params * op_params,const CollectiveParams * col_params,const string & exec_key,int64_t step_id,const Tensor * input,Tensor * output)155 CollectiveContext::CollectiveContext(
156 CollectiveExecutor* col_exec, NcclCommunicatorInterface* nccl_communicator,
157 const DeviceMgr* dev_mgr, OpKernelContext* ctx,
158 OpKernelContext::Params* op_params, const CollectiveParams* col_params,
159 const string& exec_key, int64_t step_id, const Tensor* input,
160 Tensor* output)
161 : col_exec(col_exec),
162 nccl_communicator(nccl_communicator),
163 dev_mgr(dev_mgr),
164 op_ctx(ctx),
165 op_params(op_params),
166 col_params(col_params, /*add_ref=*/true),
167 exec_key(exec_key),
168 step_id(step_id),
169 input(input),
170 output(output),
171 device(nullptr),
172 device_name(
173 col_params->group.members[col_params->default_rank].device.name()) {}
174
175 /*static*/
176 int64_t CollectiveExecutor::kInvalidId = -1;
177
178 /*static*/
Lookup(const string & collective_name,CollectiveImplementationInterface ** implementation)179 Status CollectiveRegistry::Lookup(
180 const string& collective_name,
181 CollectiveImplementationInterface** implementation) {
182 return LookupHelper(collective_name, implementation, false);
183 }
184
185 /*static*/
LookupParamResolverInstance(const string & collective_name,CollectiveImplementationInterface ** implementation)186 Status CollectiveRegistry::LookupParamResolverInstance(
187 const string& collective_name,
188 CollectiveImplementationInterface** implementation) {
189 return LookupHelper(collective_name, implementation, true);
190 }
191
192 /*static*/
GetAll(std::vector<CollectiveImplementationInterface * > * implementations)193 void CollectiveRegistry::GetAll(
194 std::vector<CollectiveImplementationInterface*>* implementations) {
195 std::vector<RegistrationInfo>* registry = MutableCollectiveRegistry();
196 for (const RegistrationInfo& reg_info : *registry)
197 implementations->emplace_back(reg_info.factory());
198 }
199
200 /*static*/
Register(const string & collective_name,Factory factory)201 Status CollectiveRegistry::Register(const string& collective_name,
202 Factory factory) {
203 std::vector<RegistrationInfo>* registry = MutableCollectiveRegistry();
204 for (const RegistrationInfo& reg_info : *registry) {
205 if (reg_info.name == collective_name)
206 return errors::Internal("Already registered collective ",
207 collective_name);
208 }
209 registry->emplace_back(collective_name, std::move(factory));
210 return OkStatus();
211 }
212
213 /*static*/
LookupHelper(const string & collective_name,CollectiveImplementationInterface ** implementation,bool param_resolver)214 Status CollectiveRegistry::LookupHelper(
215 const string& collective_name,
216 CollectiveImplementationInterface** implementation, bool param_resolver) {
217 std::vector<RegistrationInfo>* registry = MutableCollectiveRegistry();
218 for (const RegistrationInfo& reg_info : *registry) {
219 if (reg_info.name == collective_name) {
220 if (param_resolver) {
221 *implementation = reg_info.param_resolver_instance;
222 } else {
223 *implementation = reg_info.factory();
224 }
225 return OkStatus();
226 }
227 }
228 return errors::Internal(
229 "CollectiveRegistry::Lookup did not find collective implementation ",
230 collective_name);
231 }
232
233 } // namespace tensorflow
234