Home
last modified time | relevance | path

Searched refs:pmap_lib (Results 1 – 2 of 2) sorted by relevance

/external/tensorflow/tensorflow/compiler/xla/python/
Dpmap_lib.cc340 py::module pmap_lib = m.def_submodule("pmap_lib", "Jax C++ pmap library"); in BuildPmapSubmodule() local
342 py::class_<NoSharding> no_sharding(pmap_lib, "NoSharding"); in BuildPmapSubmodule()
350 py::class_<Chunked> chunked(pmap_lib, "Chunked"); in BuildPmapSubmodule()
365 py::class_<Unstacked> unstacked(pmap_lib, "Unstacked"); in BuildPmapSubmodule()
379 py::class_<ShardedAxis> sharded_axis(pmap_lib, "ShardedAxis"); in BuildPmapSubmodule()
390 py::class_<Replicated> replicated(pmap_lib, "Replicated"); in BuildPmapSubmodule()
401 py::class_<ShardingSpec> sharding_spec(pmap_lib, "ShardingSpec"); in BuildPmapSubmodule()
409 py::class_<ShardedDeviceArray> sda(pmap_lib, "ShardedDeviceArray"); in BuildPmapSubmodule()
417 py::class_<PmapFunction, std::unique_ptr<PmapFunction>> cfun(pmap_lib, in BuildPmapSubmodule()
422 pmap_lib.def( in BuildPmapSubmodule()
DBUILD326 name = "pmap_lib",
327 srcs = ["pmap_lib.cc"],
328 hdrs = ["pmap_lib.h"],
510 ":pmap_lib",