
Searched refs:col_params (Results 1 – 16 of 16) sorted by relevance

/external/tensorflow/tensorflow/core/common_runtime/
hierarchical_tree_broadcaster.cc
77 CollectiveParams* col_params) { in InitializeCollectiveParams() argument
78 CHECK_EQ(col_params->instance.type, BROADCAST_COLLECTIVE); in InitializeCollectiveParams()
79 CHECK_EQ(col_params->instance.impl_details.collective_name, in InitializeCollectiveParams()
82 col_params->instance.device_names[col_params->default_rank]; in InitializeCollectiveParams()
87 << str_util::Join(col_params->instance.task_names, ", "); in InitializeCollectiveParams()
89 const string* prior_task_name = &col_params->instance.task_names[0]; in InitializeCollectiveParams()
91 for (int di = 1; di < col_params->group.group_size; ++di) { in InitializeCollectiveParams()
92 if (col_params->instance.task_names[di] != *prior_task_name) { in InitializeCollectiveParams()
95 prior_task_name = &col_params->instance.task_names[di]; in InitializeCollectiveParams()
101 CHECK_EQ(col_params->group.num_tasks, dev_per_task.size()); in InitializeCollectiveParams()
[all …]
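
The broadcaster hits above walk instance.task_names to count how many devices each task contributes and then check that count against group.num_tasks. Below is a minimal standalone sketch of that per-task grouping pass; FakeCollectiveParams and DevicesPerTask are hypothetical simplifications of the real TensorFlow types, and the loop body is partly reconstructed from the fragments shown, so treat it as an illustration rather than the actual implementation.

    #include <cassert>
    #include <string>
    #include <vector>

    // Hypothetical, simplified stand-in for the fields referenced in the hits above.
    struct FakeCollectiveParams {
      struct { int group_size; int num_tasks; } group;
      struct { std::vector<std::string> task_names; } instance;
    };

    // Count contiguous runs of identical task names, mirroring the loop shown
    // around hierarchical_tree_broadcaster.cc lines 89-101 (sketch only).
    std::vector<int> DevicesPerTask(const FakeCollectiveParams& cp) {
      std::vector<int> dev_per_task;
      const std::string* prior_task_name = &cp.instance.task_names[0];
      int dev_count = 1;
      for (int di = 1; di < cp.group.group_size; ++di) {
        if (cp.instance.task_names[di] != *prior_task_name) {
          dev_per_task.push_back(dev_count);
          dev_count = 1;
          prior_task_name = &cp.instance.task_names[di];
        } else {
          ++dev_count;
        }
      }
      dev_per_task.push_back(dev_count);
      // Mirrors the CHECK_EQ on group.num_tasks in the hit at line 101.
      assert(cp.group.num_tasks == static_cast<int>(dev_per_task.size()));
      return dev_per_task;
    }
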
ring_alg.cc
108 Status GenerateSubdivsInCollectiveParams(CollectiveParams* col_params) { in GenerateSubdivsInCollectiveParams() argument
109 if (col_params->instance.shape.num_elements() == 0) { in GenerateSubdivsInCollectiveParams()
113 col_params->group.group_size / col_params->group.num_tasks; in GenerateSubdivsInCollectiveParams()
118 col_params->instance.impl_details.collective_name); in GenerateSubdivsInCollectiveParams()
125 const size_t tensor_size = col_params->instance.shape.num_elements() * in GenerateSubdivsInCollectiveParams()
126 DataTypeSize(col_params->instance.data_type); in GenerateSubdivsInCollectiveParams()
130 int num_chunks = col_params->group.group_size * num_subdivs; in GenerateSubdivsInCollectiveParams()
137 col_params->instance.impl_details.collective_name); in GenerateSubdivsInCollectiveParams()
142 col_params->instance.impl_details.subdiv_offsets.reserve(num_subdivs); in GenerateSubdivsInCollectiveParams()
146 col_params->instance.impl_details.subdiv_offsets.push_back(subdiv_offset); in GenerateSubdivsInCollectiveParams()
[all …]
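
The ring_alg.cc hits size the ring subdivisions from the tensor's byte size (num_elements * DataTypeSize) and derive num_chunks = group_size * num_subdivs before filling impl_details.subdiv_offsets. The sketch below shows that arithmetic in isolation; FakeRingParams, max_chunk_bytes, the chunk-size heuristic, and the offset values pushed at the end are placeholders, not the real heuristic in GenerateSubdivsInCollectiveParams().

    #include <cstddef>
    #include <vector>

    // Hypothetical stand-in for the CollectiveParams fields used in the hits.
    struct FakeRingParams {
      int group_size;
      std::size_t num_elements;
      std::size_t data_type_size;       // e.g. 4 bytes for float
      std::vector<int> subdiv_offsets;  // filled in below
    };

    void GenerateSubdivs(FakeRingParams* p, std::size_t max_chunk_bytes) {
      if (p->num_elements == 0 || p->group_size <= 0) return;  // nothing to subdivide
      const std::size_t tensor_size = p->num_elements * p->data_type_size;
      // Hypothetical heuristic: add subdivisions until each chunk stays under
      // max_chunk_bytes. The real code also accounts for devices per task.
      int num_subdivs = 1;
      while (tensor_size /
                 (static_cast<std::size_t>(num_subdivs) * p->group_size) >
             max_chunk_bytes) {
        ++num_subdivs;
      }
      const int num_chunks = p->group_size * num_subdivs;
      (void)num_chunks;  // the ring splits the tensor into this many chunks
      p->subdiv_offsets.reserve(num_subdivs);
      for (int i = 0; i < num_subdivs; ++i) p->subdiv_offsets.push_back(i);
    }
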
base_collective_executor.cc
219 const CollectiveParams& col_params, in ExecuteAsync() argument
240 const Tensor* input = (col_params.instance.type == REDUCTION_COLLECTIVE || in ExecuteAsync()
241 col_params.instance.type == GATHER_COLLECTIVE || in ExecuteAsync()
242 (col_params.instance.type == BROADCAST_COLLECTIVE && in ExecuteAsync()
243 col_params.is_source)) in ExecuteAsync()
247 Status status = CreateCollective(col_params, &col_impl); in ExecuteAsync()
254 new CollectiveContext(this, dev_mgr_, ctx, CtxParams(ctx), col_params, in ExecuteAsync()
284 const CollectiveParams& col_params, in CreateCollective() argument
288 switch (col_params.instance.data_type) { in CreateCollective()
290 if (col_params.group.device_type == DEVICE_GPU) { in CreateCollective()
[all …]
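
The ExecuteAsync hits show the collective's input tensor being selected by type: reductions, gathers, and the broadcast source consume an input, while broadcast receivers do not. A self-contained sketch of that conditional follows; the enum values are taken from the hits, everything else is a simplified stand-in.

    // Stand-in types around the collective type values visible in the hits.
    enum CollectiveType { REDUCTION_COLLECTIVE, GATHER_COLLECTIVE, BROADCAST_COLLECTIVE };

    struct FakeTensor {};

    struct FakeExecParams {
      CollectiveType type;
      bool is_source;  // only meaningful for broadcasts
    };

    // Only reductions, gathers, and the broadcast *source* feed an input tensor
    // to the collective implementation; broadcast receivers only produce output.
    const FakeTensor* SelectCollectiveInput(const FakeExecParams& p,
                                            const FakeTensor* candidate) {
      const bool needs_input =
          p.type == REDUCTION_COLLECTIVE || p.type == GATHER_COLLECTIVE ||
          (p.type == BROADCAST_COLLECTIVE && p.is_source);
      return needs_input ? candidate : nullptr;
    }
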
collective_util.cc
54 string SubdivPermDebugString(const CollectiveParams& col_params) { in SubdivPermDebugString() argument
56 col_params.instance.impl_details.subdiv_permutations; in SubdivPermDebugString()
63 CHECK_GT(col_params.instance.device_names.size(), idx); in SubdivPermDebugString()
64 strings::StrAppend(&buf, col_params.instance.device_names[idx], "\n"); in SubdivPermDebugString()
68 for (auto o : col_params.instance.impl_details.subdiv_offsets) in SubdivPermDebugString()
71 for (auto d : col_params.subdiv_rank) strings::StrAppend(&buf, d, " "); in SubdivPermDebugString()
72 if (col_params.instance.type == BROADCAST_COLLECTIVE) { in SubdivPermDebugString()
74 for (auto src : col_params.instance.impl_details.subdiv_source_rank) in SubdivPermDebugString()
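
SubdivPermDebugString, per the hits, renders each subdivision permutation by indexing into instance.device_names and appending one device per line. A simplified, hypothetical version of that formatting loop:

    #include <string>
    #include <vector>

    // Simplified stand-in for the fields used by SubdivPermDebugString above.
    struct FakeDebugParams {
      std::vector<std::vector<int>> subdiv_permutations;  // device-index order per subdivision
      std::vector<std::string> device_names;
    };

    // Each permutation entry is an index into device_names; append one per line.
    std::string SubdivPermDebugString(const FakeDebugParams& p) {
      std::string buf;
      for (const auto& perm : p.subdiv_permutations) {
        for (int idx : perm) {
          if (idx >= 0 && idx < static_cast<int>(p.device_names.size())) {
            buf += p.device_names[idx];
            buf += "\n";
          }
        }
      }
      return buf;
    }
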
ring_gatherer.cc
43 Status RingGatherer::InitializeCollectiveParams(CollectiveParams* col_params) { in InitializeCollectiveParams() argument
44 DCHECK_EQ(col_params->instance.type, GATHER_COLLECTIVE); in InitializeCollectiveParams()
45 DCHECK_EQ(col_params->instance.impl_details.collective_name, "RingGather"); in InitializeCollectiveParams()
49 if (!col_params->instance.impl_details.subdiv_offsets.empty() && in InitializeCollectiveParams()
50 (col_params->instance.impl_details.subdiv_offsets.size() > 1 || in InitializeCollectiveParams()
51 col_params->instance.impl_details.subdiv_offsets[0] != 0)) { in InitializeCollectiveParams()
55 if (col_params->instance.impl_details.subdiv_offsets.empty()) { in InitializeCollectiveParams()
56 col_params->instance.impl_details.subdiv_offsets.push_back(0); in InitializeCollectiveParams()
58 return RingAlg::InitializeCollectiveParams(col_params); in InitializeCollectiveParams()
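
The RingGatherer hits reject any pre-populated subdiv_offsets other than a single 0 and default an empty vector to {0} before delegating to RingAlg::InitializeCollectiveParams. A small sketch of that validation, with a plain bool standing in for tensorflow::Status:

    #include <vector>

    bool ValidateGatherSubdivOffsets(std::vector<int>* subdiv_offsets) {
      // A gather supports exactly one subdivision at offset 0; anything else in
      // impl_details.subdiv_offsets is rejected (the real code returns an error
      // Status), and an empty vector defaults to {0}.
      if (!subdiv_offsets->empty() &&
          (subdiv_offsets->size() > 1 || (*subdiv_offsets)[0] != 0)) {
        return false;
      }
      if (subdiv_offsets->empty()) subdiv_offsets->push_back(0);
      return true;
    }
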
base_collective_executor.h
110 void ExecuteAsync(OpKernelContext* ctx, const CollectiveParams& col_params,
148 void WaitForDependencies(const CollectiveParams& col_params) override;
152 void Launched(const CollectiveParams& col_params) override;
166 Status CreateCollective(const CollectiveParams& col_params,
169 bool CheckDependencies(const CollectiveParams& col_params)
ring_reducer.cc
46 Status RingReducer::InitializeCollectiveParams(CollectiveParams* col_params) { in InitializeCollectiveParams() argument
48 CHECK_EQ(col_params->instance.type, REDUCTION_COLLECTIVE); in InitializeCollectiveParams()
49 CHECK_EQ(col_params->instance.impl_details.collective_name, "RingReduce"); in InitializeCollectiveParams()
50 return RingAlg::InitializeCollectiveParams(col_params); in InitializeCollectiveParams()
ring_gatherer.h
36 Status InitializeCollectiveParams(CollectiveParams* col_params) override;
ring_reducer.h
42 Status InitializeCollectiveParams(CollectiveParams* col_params) override;
collective_util.h
33 string SubdivPermDebugString(const CollectiveParams& col_params);
hierarchical_tree_broadcaster.h
38 Status InitializeCollectiveParams(CollectiveParams* col_params) override;
ring_alg.h
38 Status InitializeCollectiveParams(CollectiveParams* col_params) override;
/external/tensorflow/tensorflow/core/kernels/
collective_nccl_reducer.cc
31 Status NcclReducer::InitializeCollectiveParams(CollectiveParams* col_params) { in InitializeCollectiveParams() argument
32 if (col_params->instance.type != REDUCTION_COLLECTIVE || in InitializeCollectiveParams()
33 col_params->instance.impl_details.collective_name != "NcclReduce") { in InitializeCollectiveParams()
35 col_params->instance.type, " expected ", in InitializeCollectiveParams()
37 col_params->instance.impl_details.collective_name, in InitializeCollectiveParams()
46 col_params_ = &col_ctx->col_params; in InitializeCollectiveContext()
53 CollectiveParams* col_params) { in InitializeInstanceBeforeGroupDiscovery() argument
54 if (col_params->default_rank == 0 && col_params->group.num_tasks > 1) { in InitializeInstanceBeforeGroupDiscovery()
55 col_params->instance.communicator_key = in InitializeInstanceBeforeGroupDiscovery()
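
The NcclReducer hits show that only default_rank 0 of a multi-task group fills in instance.communicator_key before group discovery. The sketch below illustrates that guard; FakeNcclParams and GenerateUniqueKey are placeholders and do not reflect how the real NCCL id is obtained.

    #include <string>

    // Hypothetical stand-in for the fields referenced in the hits above.
    struct FakeNcclParams {
      int default_rank;
      int num_tasks;
      std::string communicator_key;
    };

    // Placeholder only; not a TensorFlow API. The real code obtains an NCCL
    // unique id here.
    std::string GenerateUniqueKey() { return "placeholder-nccl-key"; }

    // Only rank 0 of a multi-task group seeds the communicator key.
    void MaybeSetCommunicatorKey(FakeNcclParams* p) {
      if (p->default_rank == 0 && p->num_tasks > 1) {
        p->communicator_key = GenerateUniqueKey();
      }
    }
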
collective_nccl_reducer.h
29 Status InitializeCollectiveParams(CollectiveParams* col_params) override;
36 CollectiveParams* col_params) override;
/external/tensorflow/tensorflow/core/framework/
collective.h
262 const CollectiveParams& col_params, in ExecuteAsync() argument
288 virtual void WaitForDependencies(const CollectiveParams& col_params) {} in WaitForDependencies() argument
292 virtual void Launched(const CollectiveParams& col_params) {} in Launched() argument
343 const CollectiveParams& col_params, const string& exec_key,
352 const CollectiveParams& col_params; variable
378 virtual Status InitializeCollectiveParams(CollectiveParams* col_params) = 0;
392 CollectiveParams* col_params) = 0;
collective.cc
160 const CollectiveParams& col_params, in CollectiveContext() argument
167 col_params(col_params), in CollectiveContext()
173 device_name(col_params.instance.device_names[col_params.default_rank]) {} in CollectiveContext()
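
The CollectiveContext hits show the constructor caching col_params and deriving its device_name from instance.device_names[default_rank]. A standalone sketch of that member-initializer pattern, using simplified stand-in types rather than TensorFlow's:

    #include <string>
    #include <vector>

    struct FakeCtxParams {
      int default_rank;
      std::vector<std::string> device_names;
    };

    struct FakeCollectiveContext {
      const FakeCtxParams& col_params;  // kept by reference, as in the hits
      std::string device_name;

      // device_name is the entry for this context's own rank.
      explicit FakeCollectiveContext(const FakeCtxParams& p)
          : col_params(p), device_name(p.device_names[p.default_rank]) {}
    };
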