Searched refs:optimizer_weight_shard_size (Results 1 – 8 of 8) sorted by relevance
549 def set_optimizer_weight_shard_size(self, optimizer_weight_shard_size): argument558 …if not isinstance(optimizer_weight_shard_size, int) or isinstance(optimizer_weight_shard_size, boo…560 but got {type(optimizer_weight_shard_size)}.")561 if optimizer_weight_shard_size <= 1:565 self._context_handle.set_optimizer_weight_shard_size(optimizer_weight_shard_size)685 communi_parallel_mode=str, optimizer_weight_shard_size=int,
436 …int64_t optimizer_weight_shard_size = ParallelContext::GetInstance()->optimizer_weight_shard_size(… in GenerateOptShardSliceShape() local437 if (optimizer_weight_shard_size != -1) { in GenerateOptShardSliceShape()438 repeated_num = optimizer_weight_shard_size; in GenerateOptShardSliceShape()
102 void set_optimizer_weight_shard_size(int64_t optimizer_weight_shard_size);103 int64_t optimizer_weight_shard_size() const { return optimizer_weight_shard_size_; } in optimizer_weight_shard_size() function
145 void ParallelContext::set_optimizer_weight_shard_size(int64_t optimizer_weight_shard_size) { in set_optimizer_weight_shard_size() argument146 optimizer_weight_shard_size_ = optimizer_weight_shard_size; in set_optimizer_weight_shard_size()
565 …int64_t optimizer_weight_shard_size = ParallelContext::GetInstance()->optimizer_weight_shard_size(… in CreateGroupForOptShard() local566 if (optimizer_weight_shard_size != -1) { in CreateGroupForOptShard()570 if (repeated_size % optimizer_weight_shard_size != 0) { in CreateGroupForOptShard()571 …S_LOG(WARNING) << "Parallel optimizer: optimizer_weight_shard_size " << optimizer_weight_shard_size in CreateGroupForOptShard()575 repeated_size = repeated_size / optimizer_weight_shard_size; in CreateGroupForOptShard()579 group_devices.begin() + index / optimizer_weight_shard_size * optimizer_weight_shard_size, in CreateGroupForOptShard()580 … group_devices.begin() + (index / optimizer_weight_shard_size + 1) * optimizer_weight_shard_size); in CreateGroupForOptShard()609 tensor_layout->set_opt_weight_shard_size(LongToInt(optimizer_weight_shard_size)); in CreateGroupForOptShard()
100 context.set_auto_parallel_context(optimizer_weight_shard_size=2)
137 context.set_auto_parallel_context(optimizer_weight_shard_size=2)
193 .def("get_optimizer_weight_shard_size", &ParallelContext::optimizer_weight_shard_size, in PYBIND11_MODULE()