Home
last modified time | relevance | path

Searched refs:static_argnums (Results 1 – 3 of 3) sorted by relevance

/external/tensorflow/tensorflow/compiler/xla/python/
Djax_jit.cc185 absl::Span<int const> static_argnums, in ParseArguments() argument
187 if (static_argnums.size() > args.size()) { in ParseArguments()
192 static_argnums.size()); in ParseArguments()
194 args.size() - static_argnums.size()); in ParseArguments()
198 if (std::find(static_argnums.begin(), static_argnums.end(), i) == in ParseArguments()
199 static_argnums.end()) { in ParseArguments()
846 std::vector<int> static_argnums);
920 std::vector<int> static_argnums) in CompiledFunction() argument
923 static_argnums_(std::move(static_argnums)), in CompiledFunction()
1224 std::vector<int> static_argnums) -> std::unique_ptr<CompiledFunction> { in BuildJaxjitSubmodule() argument
[all …]
Dpmap_lib.cc135 py::function get_jax_enable_x64, std::vector<int> static_argnums) in PmapFunction() argument
138 static_argnums_(std::move(static_argnums)), in PmapFunction()
426 std::vector<int> static_argnums) -> std::unique_ptr<PmapFunction> { in BuildPmapSubmodule() argument
429 std::move(get_jax_enable_x64), std::move(static_argnums)); in BuildPmapSubmodule()
Djax_jit.h138 absl::Span<int const> static_argnums,