• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/framework/collective.h"
16 
17 #include "absl/strings/escaping.h"
18 #include "tensorflow/core/framework/op_kernel.h"
19 #include "tensorflow/core/lib/core/errors.h"
20 #include "tensorflow/core/lib/hash/hash.h"
21 #include "tensorflow/core/lib/strings/str_util.h"
22 #include "tensorflow/core/lib/strings/strcat.h"
23 
24 namespace tensorflow {
25 
26 namespace {
27 // A RegistrationInfo object stores a collective implementation registration
28 // details.  `factory` is used to create instances of the collective
29 // implementation.
30 struct RegistrationInfo {
31   // This constructor also creates, and stores in `param_resolver_instance`,
32   // what is effectively a static instance of the collective implementation.
33   // During param resolution of collective ops we return this static instance.
34   // The actual op execution gets a fresh instance using `factory`.
RegistrationInfotensorflow::__anon237cdfb40111::RegistrationInfo35   RegistrationInfo(const string& n, CollectiveRegistry::Factory f)
36       : name(n),
37         factory(std::move(f)),
38         param_resolver_instance(this->factory()) {}
39   string name;
40   CollectiveRegistry::Factory factory;
41   CollectiveImplementationInterface* param_resolver_instance;
42 };
43 
MutableCollectiveRegistry()44 std::vector<RegistrationInfo>* MutableCollectiveRegistry() {
45   static std::vector<RegistrationInfo>* registry =
46       new std::vector<RegistrationInfo>;
47   return registry;
48 }
49 }  // namespace
50 
ToString() const51 string CollGroupRuntimeDetails::ToString() const {
52   return strings::StrCat("CollGroupRuntimeDetails {communicator_key=",
53                          absl::CEscape(communicator_key), "}");
54 }
55 
ToString() const56 string CollGroupParams::ToString() const {
57   string v = strings::StrCat(
58       "CollGroupParams {group_key=", group_key, " group_size=", group_size,
59       " device_type=", device_type.type_string(), " num_tasks=", num_tasks,
60       " runtime_details=", runtime_details.ToString(), " devices {");
61   for (const auto& d : devices) {
62     strings::StrAppend(&v, d.name(), ",");
63   }
64   strings::StrAppend(&v, "} task_names={");
65   for (const auto& n : task_names) {
66     strings::StrAppend(&v, n, ", ");
67   }
68   strings::StrAppend(&v, "} num_devices_per_task={");
69   for (const auto& dpt : num_devices_per_task) {
70     strings::StrAppend(&v, dpt.first, ": ", dpt.second, ", ");
71   }
72   strings::StrAppend(&v, "}");
73   return v;
74 }
75 
// Copy-assigns the scalar fields, shape, device list, permutation and a subset
// of `impl_details` from `other`. Self-assignment is a no-op.
//
// NOTE(review): only subdiv_offsets, subdiv_permutations, subdiv_source_rank
// and dependencies are copied out of `impl_details`; in particular
// `impl_details.collective_name` is left as-is. Confirm this is intentional.
CollInstanceParams& CollInstanceParams::operator=(
    const CollInstanceParams& other) {
  if (this == &other) return *this;

  instance_key = other.instance_key;
  type = other.type;
  data_type = other.data_type;
  shape = other.shape;
  impl_details.subdiv_offsets.assign(
      other.impl_details.subdiv_offsets.begin(),
      other.impl_details.subdiv_offsets.end());

  // Rebuild the permutations. Iterate by const reference: a by-value loop
  // variable would copy every inner permutation an extra time before it is
  // copied into the destination.
  impl_details.subdiv_permutations.clear();
  impl_details.subdiv_permutations.reserve(
      other.impl_details.subdiv_permutations.size());
  for (const auto& perm : other.impl_details.subdiv_permutations) {
    impl_details.subdiv_permutations.push_back(
        std::vector<int>(perm.begin(), perm.end()));
  }

  impl_details.subdiv_source_rank.assign(
      other.impl_details.subdiv_source_rank.begin(),
      other.impl_details.subdiv_source_rank.end());
  impl_details.dependencies = other.impl_details.dependencies;
  devices.assign(other.devices.begin(), other.devices.end());
  permutation.assign(other.permutation.begin(), other.permutation.end());
  return *this;
}
100 
ToString() const101 string CollInstanceParams::ToString() const {
102   string v =
103       strings::StrCat("CollInstanceParams { instance_key=", instance_key,
104                       " type=", type, " data_type=", DataTypeString(data_type),
105                       " shape=", shape.DebugString(), " devices {");
106   strings::StrAppend(&v, "}, collective_name=", impl_details.collective_name,
107                      ", subdiv_offsets={");
108   strings::StrAppend(&v, "}, subdiv_offsets={");
109   for (const auto& d : impl_details.subdiv_offsets) {
110     strings::StrAppend(&v, d, ",");
111   }
112   strings::StrAppend(&v, "}, subdiv_perms={");
113   for (const auto& p : impl_details.subdiv_permutations) {
114     strings::StrAppend(&v, "{");
115     for (const auto& i : p) {
116       strings::StrAppend(&v, i, ",");
117     }
118     strings::StrAppend(&v, "}");  // one subdiv
119   }
120   if (!impl_details.subdiv_source_rank.empty()) {
121     strings::StrAppend(&v, " subdiv_source_rank={");
122     for (const auto& r : impl_details.subdiv_source_rank) {
123       strings::StrAppend(&v, r, ",");
124     }
125     strings::StrAppend(&v, "}");
126   }  // all subdivs
127   if (type == PERMUTE_COLLECTIVE) {
128     strings::StrAppend(&v, "}, permute_devices {");
129     for (const auto& d : devices) {
130       strings::StrAppend(&v, d, ",");
131     }
132     strings::StrAppend(&v, "}, permute_permutation {");
133     for (const auto& p : permutation) {
134       strings::StrAppend(&v, p, ",");
135     }
136     strings::StrAppend(&v, "}");
137   }
138   return v;
139 }
140 
ToString() const141 string CollTaskParams::ToString() const {
142   string v = strings::StrCat("CollTaskParams {is_local={");
143   for (const auto& b : is_local) {
144     strings::StrAppend(&v, static_cast<int>(b), ",");
145   }
146   strings::StrAppend(&v, "}}");
147   return v;
148 }
149 
ToString() const150 string CollectiveParams::ToString() const {
151   string v = strings::StrCat("CollectiveParams ", name, " {", group.ToString());
152   strings::StrAppend(&v, " ", instance.ToString());
153   strings::StrAppend(&v, " ", task.ToString());
154   strings::StrAppend(&v, " default_rank=", default_rank,
155                      " is_source=", is_source, " source_rank=", source_rank,
156                      " subdiv_rank={");
157   for (const auto& r : subdiv_rank) {
158     strings::StrAppend(&v, r, ",");
159   }
160   strings::StrAppend(&v, "}}");
161   return v;
162 }
163 
CtxParams(OpKernelContext * ctx)164 /*static*/ OpKernelContext::Params* CollectiveExecutor::CtxParams(
165     OpKernelContext* ctx) {
166   return ctx->params_;
167 }
168 
CollectiveContext(CollectiveExecutor * col_exec,NcclCommunicatorInterface * nccl_communicator,const DeviceMgr * dev_mgr,OpKernelContext * ctx,OpKernelContext::Params * op_params,const CollectiveParams * col_params,const string & exec_key,int64_t step_id,const Tensor * input,Tensor * output)169 CollectiveContext::CollectiveContext(
170     CollectiveExecutor* col_exec, NcclCommunicatorInterface* nccl_communicator,
171     const DeviceMgr* dev_mgr, OpKernelContext* ctx,
172     OpKernelContext::Params* op_params, const CollectiveParams* col_params,
173     const string& exec_key, int64_t step_id, const Tensor* input,
174     Tensor* output)
175     : col_exec(col_exec),
176       nccl_communicator(nccl_communicator),
177       dev_mgr(dev_mgr),
178       op_ctx(ctx),
179       op_params(op_params),
180       col_params(col_params),
181       exec_key(exec_key),
182       step_id(step_id),
183       input(input),
184       output(output),
185       device(nullptr),
186       device_name(col_params->group.devices[col_params->default_rank].name()) {}
187 
188 /*static*/
189 int64_t CollectiveExecutor::kInvalidId = -1;
190 
191 /*static*/
Lookup(const string & collective_name,CollectiveImplementationInterface ** implementation)192 Status CollectiveRegistry::Lookup(
193     const string& collective_name,
194     CollectiveImplementationInterface** implementation) {
195   return LookupHelper(collective_name, implementation, false);
196 }
197 
198 /*static*/
LookupParamResolverInstance(const string & collective_name,CollectiveImplementationInterface ** implementation)199 Status CollectiveRegistry::LookupParamResolverInstance(
200     const string& collective_name,
201     CollectiveImplementationInterface** implementation) {
202   return LookupHelper(collective_name, implementation, true);
203 }
204 
205 /*static*/
GetAll(std::vector<CollectiveImplementationInterface * > * implementations)206 void CollectiveRegistry::GetAll(
207     std::vector<CollectiveImplementationInterface*>* implementations) {
208   std::vector<RegistrationInfo>* registry = MutableCollectiveRegistry();
209   for (const RegistrationInfo& reg_info : *registry)
210     implementations->emplace_back(reg_info.factory());
211 }
212 
213 /*static*/
Register(const string & collective_name,Factory factory)214 Status CollectiveRegistry::Register(const string& collective_name,
215                                     Factory factory) {
216   std::vector<RegistrationInfo>* registry = MutableCollectiveRegistry();
217   for (const RegistrationInfo& reg_info : *registry) {
218     if (reg_info.name == collective_name)
219       return errors::Internal("Already registered collective ",
220                               collective_name);
221   }
222   registry->emplace_back(collective_name, std::move(factory));
223   return Status::OK();
224 }
225 
226 /*static*/
LookupHelper(const string & collective_name,CollectiveImplementationInterface ** implementation,bool param_resolver)227 Status CollectiveRegistry::LookupHelper(
228     const string& collective_name,
229     CollectiveImplementationInterface** implementation, bool param_resolver) {
230   std::vector<RegistrationInfo>* registry = MutableCollectiveRegistry();
231   for (const RegistrationInfo& reg_info : *registry) {
232     if (reg_info.name == collective_name) {
233       if (param_resolver) {
234         *implementation = reg_info.param_resolver_instance;
235       } else {
236         *implementation = reg_info.factory();
237       }
238       return Status::OK();
239     }
240   }
241   return errors::Internal(
242       "CollectiveRegistry::Lookup did not find collective implementation ",
243       collective_name);
244 }
245 
246 }  // namespace tensorflow
247