Searched refs:pmap_lib (Results 1 – 2 of 2) sorted by relevance
/external/tensorflow/tensorflow/compiler/xla/python/ |
D | pmap_lib.cc | 340 py::module pmap_lib = m.def_submodule("pmap_lib", "Jax C++ pmap library"); in BuildPmapSubmodule() local 342 py::class_<NoSharding> no_sharding(pmap_lib, "NoSharding"); in BuildPmapSubmodule() 350 py::class_<Chunked> chunked(pmap_lib, "Chunked"); in BuildPmapSubmodule() 365 py::class_<Unstacked> unstacked(pmap_lib, "Unstacked"); in BuildPmapSubmodule() 379 py::class_<ShardedAxis> sharded_axis(pmap_lib, "ShardedAxis"); in BuildPmapSubmodule() 390 py::class_<Replicated> replicated(pmap_lib, "Replicated"); in BuildPmapSubmodule() 401 py::class_<ShardingSpec> sharding_spec(pmap_lib, "ShardingSpec"); in BuildPmapSubmodule() 409 py::class_<ShardedDeviceArray> sda(pmap_lib, "ShardedDeviceArray"); in BuildPmapSubmodule() 417 py::class_<PmapFunction, std::unique_ptr<PmapFunction>> cfun(pmap_lib, in BuildPmapSubmodule() 422 pmap_lib.def( in BuildPmapSubmodule()
|
D | BUILD | 326 name = "pmap_lib", 327 srcs = ["pmap_lib.cc"], 328 hdrs = ["pmap_lib.h"], 510 ":pmap_lib",
|