Searched defs:get_jax_enable_x64 (Results 1 – 2 of 2) sorted by relevance
135 py::function get_jax_enable_x64, std::vector<int> static_argnums) in PmapFunction()
426 std::vector<int> static_argnums) -> std::unique_ptr<PmapFunction> { in BuildPmapSubmodule()
918 py::function get_jax_enable_x64, in CompiledFunction()
1224 std::vector<int> static_argnums) -> std::unique_ptr<CompiledFunction> { in BuildJaxjitSubmodule()