• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/framework/collective.h"
16 
17 #include "absl/strings/escaping.h"
18 #include "tensorflow/core/framework/op_kernel.h"
19 #include "tensorflow/core/lib/core/errors.h"
20 #include "tensorflow/core/lib/hash/hash.h"
21 #include "tensorflow/core/lib/strings/str_util.h"
22 #include "tensorflow/core/lib/strings/strcat.h"
23 
24 namespace tensorflow {
25 
26 namespace {
27 // A RegistrationInfo object stores a collective implementation registration
28 // details.  `factory` is used to create instances of the collective
29 // implementation.
30 struct RegistrationInfo {
31   // This constructor also creates, and stores in `param_resolver_instance`,
32   // what is effectively a static instance of the collective implementation.
33   // During param resolution of collective ops we return this static instance.
34   // The actual op execution gets a fresh instance using `factory`.
RegistrationInfotensorflow::__anonbe41636d0111::RegistrationInfo35   RegistrationInfo(const string& n, CollectiveRegistry::Factory f)
36       : name(n),
37         factory(std::move(f)),
38         param_resolver_instance(this->factory()) {}
39   string name;
40   CollectiveRegistry::Factory factory;
41   CollectiveImplementationInterface* param_resolver_instance;
42 };
43 
MutableCollectiveRegistry()44 std::vector<RegistrationInfo>* MutableCollectiveRegistry() {
45   static std::vector<RegistrationInfo>* registry =
46       new std::vector<RegistrationInfo>;
47   return registry;
48 }
49 }  // namespace
50 
ToString() const51 string CollGroupRuntimeDetails::ToString() const {
52   return strings::StrCat("CollGroupRuntimeDetails {communicator_key=",
53                          absl::CEscape(communicator_key), "}");
54 }
55 
ToString() const56 string CollGroupParams::ToString() const {
57   string v = strings::StrCat(
58       "CollGroupParams {group_key=", group_key, " group_size=", group_size,
59       " device_type=", device_type.type_string(), " num_tasks=", num_tasks,
60       " runtime_details=", runtime_details.ToString(), " devices {");
61   for (const auto& m : members) {
62     strings::StrAppend(&v, m.device.name(), ",");
63   }
64   strings::StrAppend(&v, "} num_devices_per_task={");
65   for (const auto& dpt : num_devices_per_task) {
66     strings::StrAppend(&v, dpt.first, ": ", dpt.second, ", ");
67   }
68   strings::StrAppend(&v, "}");
69   return v;
70 }
71 
// Deep-copies the instance parameters from `other` (self-assignment is a
// no-op). Only the fields assigned below are copied; any other members of
// CollInstanceParams / impl_details (e.g. impl_details.collective_name) keep
// their current values.
// NOTE(review): that field selection is unchanged from the original — confirm
// against collective.h whether the omissions are intentional.
CollInstanceParams& CollInstanceParams::operator=(
    const CollInstanceParams& other) {
  if (this != &other) {
    instance_key = other.instance_key;
    type = other.type;
    data_type = other.data_type;
    shape = other.shape;
    impl_details.subdiv_offsets.assign(
        other.impl_details.subdiv_offsets.begin(),
        other.impl_details.subdiv_offsets.end());
    impl_details.subdiv_permutations.clear();
    // Bind by const reference: iterating by value would deep-copy every
    // permutation vector only to copy it a second time below.
    for (const auto& p : other.impl_details.subdiv_permutations) {
      impl_details.subdiv_permutations.push_back(
          std::vector<int>(p.begin(), p.end()));
    }
    impl_details.subdiv_source_rank.assign(
        other.impl_details.subdiv_source_rank.begin(),
        other.impl_details.subdiv_source_rank.end());
    impl_details.dependencies = other.impl_details.dependencies;
    devices.assign(other.devices.begin(), other.devices.end());
    permutation.assign(other.permutation.begin(), other.permutation.end());
  }
  return *this;
}
96 
ToString() const97 string CollInstanceParams::ToString() const {
98   string v =
99       strings::StrCat("CollInstanceParams { instance_key=", instance_key,
100                       " type=", type, " data_type=", DataTypeString(data_type),
101                       " shape=", shape.DebugString(), " devices {");
102   strings::StrAppend(&v, "}, collective_name=", impl_details.collective_name,
103                      ", subdiv_offsets={");
104   strings::StrAppend(&v, "}, subdiv_offsets={");
105   for (const auto& d : impl_details.subdiv_offsets) {
106     strings::StrAppend(&v, d, ",");
107   }
108   strings::StrAppend(&v, "}, subdiv_perms={");
109   for (const auto& p : impl_details.subdiv_permutations) {
110     strings::StrAppend(&v, "{");
111     for (const auto& i : p) {
112       strings::StrAppend(&v, i, ",");
113     }
114     strings::StrAppend(&v, "}");  // one subdiv
115   }
116   if (!impl_details.subdiv_source_rank.empty()) {
117     strings::StrAppend(&v, " subdiv_source_rank={");
118     for (const auto& r : impl_details.subdiv_source_rank) {
119       strings::StrAppend(&v, r, ",");
120     }
121     strings::StrAppend(&v, "}");
122   }  // all subdivs
123   if (type == PERMUTE_COLLECTIVE) {
124     strings::StrAppend(&v, "}, permute_devices {");
125     for (const auto& d : devices) {
126       strings::StrAppend(&v, d, ",");
127     }
128     strings::StrAppend(&v, "}, permute_permutation {");
129     for (const auto& p : permutation) {
130       strings::StrAppend(&v, p, ",");
131     }
132     strings::StrAppend(&v, "}");
133   }
134   return v;
135 }
136 
ToString() const137 string CollectiveParams::ToString() const {
138   string v = strings::StrCat("CollectiveParams ", name, " {", group.ToString());
139   strings::StrAppend(&v, " ", instance.ToString());
140   strings::StrAppend(&v, " default_rank=", default_rank,
141                      " is_source=", is_source, " source_rank=", source_rank,
142                      " subdiv_rank={");
143   for (const auto& r : subdiv_rank) {
144     strings::StrAppend(&v, r, ",");
145   }
146   strings::StrAppend(&v, "}}");
147   return v;
148 }
149 
CtxParams(OpKernelContext * ctx)150 /*static*/ OpKernelContext::Params* CollectiveExecutor::CtxParams(
151     OpKernelContext* ctx) {
152   return ctx->params_;
153 }
154 
CollectiveContext(CollectiveExecutor * col_exec,NcclCommunicatorInterface * nccl_communicator,const DeviceMgr * dev_mgr,OpKernelContext * ctx,OpKernelContext::Params * op_params,const CollectiveParams * col_params,const string & exec_key,int64_t step_id,const Tensor * input,Tensor * output)155 CollectiveContext::CollectiveContext(
156     CollectiveExecutor* col_exec, NcclCommunicatorInterface* nccl_communicator,
157     const DeviceMgr* dev_mgr, OpKernelContext* ctx,
158     OpKernelContext::Params* op_params, const CollectiveParams* col_params,
159     const string& exec_key, int64_t step_id, const Tensor* input,
160     Tensor* output)
161     : col_exec(col_exec),
162       nccl_communicator(nccl_communicator),
163       dev_mgr(dev_mgr),
164       op_ctx(ctx),
165       op_params(op_params),
166       col_params(col_params, /*add_ref=*/true),
167       exec_key(exec_key),
168       step_id(step_id),
169       input(input),
170       output(output),
171       device(nullptr),
172       device_name(
173           col_params->group.members[col_params->default_rank].device.name()) {}
174 
175 /*static*/
176 int64_t CollectiveExecutor::kInvalidId = -1;
177 
178 /*static*/
Lookup(const string & collective_name,CollectiveImplementationInterface ** implementation)179 Status CollectiveRegistry::Lookup(
180     const string& collective_name,
181     CollectiveImplementationInterface** implementation) {
182   return LookupHelper(collective_name, implementation, false);
183 }
184 
185 /*static*/
LookupParamResolverInstance(const string & collective_name,CollectiveImplementationInterface ** implementation)186 Status CollectiveRegistry::LookupParamResolverInstance(
187     const string& collective_name,
188     CollectiveImplementationInterface** implementation) {
189   return LookupHelper(collective_name, implementation, true);
190 }
191 
192 /*static*/
GetAll(std::vector<CollectiveImplementationInterface * > * implementations)193 void CollectiveRegistry::GetAll(
194     std::vector<CollectiveImplementationInterface*>* implementations) {
195   std::vector<RegistrationInfo>* registry = MutableCollectiveRegistry();
196   for (const RegistrationInfo& reg_info : *registry)
197     implementations->emplace_back(reg_info.factory());
198 }
199 
200 /*static*/
Register(const string & collective_name,Factory factory)201 Status CollectiveRegistry::Register(const string& collective_name,
202                                     Factory factory) {
203   std::vector<RegistrationInfo>* registry = MutableCollectiveRegistry();
204   for (const RegistrationInfo& reg_info : *registry) {
205     if (reg_info.name == collective_name)
206       return errors::Internal("Already registered collective ",
207                               collective_name);
208   }
209   registry->emplace_back(collective_name, std::move(factory));
210   return OkStatus();
211 }
212 
213 /*static*/
LookupHelper(const string & collective_name,CollectiveImplementationInterface ** implementation,bool param_resolver)214 Status CollectiveRegistry::LookupHelper(
215     const string& collective_name,
216     CollectiveImplementationInterface** implementation, bool param_resolver) {
217   std::vector<RegistrationInfo>* registry = MutableCollectiveRegistry();
218   for (const RegistrationInfo& reg_info : *registry) {
219     if (reg_info.name == collective_name) {
220       if (param_resolver) {
221         *implementation = reg_info.param_resolver_instance;
222       } else {
223         *implementation = reg_info.factory();
224       }
225       return OkStatus();
226     }
227   }
228   return errors::Internal(
229       "CollectiveRegistry::Lookup did not find collective implementation ",
230       collective_name);
231 }
232 
233 }  // namespace tensorflow
234