Home
last modified time | relevance | path

Searched refs:PjRtClient (Results 1 – 22 of 22) sorted by relevance

/external/tensorflow/tensorflow/compiler/xla/python/
Dpy_client.h91 explicit PyClient(std::unique_ptr<PjRtClient> pjrt_client);
92 explicit PyClient(std::shared_ptr<PjRtClient> pjrt_client);
94 PjRtClient* pjrt_client() const { return pjrt_client_.get(); } in pjrt_client()
95 std::shared_ptr<PjRtClient> shared_pjrt_client() { return pjrt_client_; } in shared_pjrt_client()
130 PjRtClient::HostBufferSemantics host_buffer_semantics);
133 PjRtClient::HostBufferSemantics host_buffer_semantics);
144 std::shared_ptr<PjRtClient> pjrt_client_;
Doutfeed_receiver_test.cc35 PjRtClient* client) { in CompileAndExecute()
78 StatusOr<std::unique_ptr<PjRtClient>> GetCpuClientWithNonLocalDevice() { in GetCpuClientWithNonLocalDevice()
100 return std::unique_ptr<PjRtClient>(std::make_unique<PjRtStreamExecutorClient>( in GetCpuClientWithNonLocalDevice()
108 TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr<PjRtClient> cpu_client, in TEST()
110 std::vector<PjRtClient*> clients{cpu_client.get()}; in TEST()
141 TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr<PjRtClient> cpu_client, in TEST()
143 std::vector<PjRtClient*> clients{cpu_client.get()}; in TEST()
186 TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr<PjRtClient> cpu_client, in TEST()
188 std::vector<PjRtClient*> clients{cpu_client.get()}; in TEST()
229 TF_ASSERT_OK_AND_ASSIGN(std::shared_ptr<PjRtClient> cpu_client, in TEST()
[all …]
Dxla.cc196 py::enum_<PjRtClient::HostBufferSemantics>(m, "HostBufferSemantics") in PYBIND11_MODULE()
198 PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall) in PYBIND11_MODULE()
200 PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes) in PYBIND11_MODULE()
201 .value("ZERO_COPY", PjRtClient::HostBufferSemantics::kZeroCopy); in PYBIND11_MODULE()
225 PjRtClient::HostBufferSemantics::kZeroCopy) in PYBIND11_MODULE()
233 TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtClient> client, in PYBIND11_MODULE()
239 TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtClient> client, in PYBIND11_MODULE()
249 std::unique_ptr<PjRtClient> client, in PYBIND11_MODULE()
260 TF_ASSIGN_OR_RETURN(std::shared_ptr<PjRtClient> client, in PYBIND11_MODULE()
Dpy_client.cc33 PyClient::PyClient(std::unique_ptr<PjRtClient> pjrt_client) in PyClient()
35 PyClient::PyClient(std::shared_ptr<PjRtClient> pjrt_client) in PyClient()
104 PjRtClient::HostBufferSemantics host_buffer_semantics) { in PjRtBufferFromPyval()
128 PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall) { in PjRtBufferFromPyval()
148 PjRtClient::HostBufferSemantics host_buffer_semantics) { in BufferFromPyval()
Doutfeed_receiver.h47 OutfeedReceiver(Callback callback, absl::Span<PjRtClient* const> clients,
Doutfeed_receiver.cc155 absl::Span<PjRtClient* const> clients,
229 OutfeedReceiver::Callback callback, absl::Span<PjRtClient* const> clients, in OutfeedReceiverImpl()
457 absl::Span<PjRtClient* const> clients, in OutfeedReceiver()
Djax_jit.cc359 const py::handle& scalar, xla::PjRtClient* client, in ConvertToScalarBuffer()
369 xla::PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall, nullptr, in ConvertToScalarBuffer()
625 xla::PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall, in HandleComplex()
634 xla::PjRtClient::HostBufferSemantics::kImmutableOnlyDuringCall, in HandleComplex()
670 xla::PjRtClient::HostBufferSemantics::kZeroCopy)); in HandleBufferFromPyval()
694 xla::PjRtClient::HostBufferSemantics::kZeroCopy)); in HandleUint64()
705 xla::PjRtClient::HostBufferSemantics::kZeroCopy)); in HandleUint64()
734 xla::PjRtClient::HostBufferSemantics::kZeroCopy)); in HandleNdarray()
Doutfeed_receiver_py.cc55 std::vector<PjRtClient*> client_ptrs(clients_.size()); in OutfeedReceiverForPython()
Ddlpack.cc225 StatusOr<PjRtDevice*> DeviceForDLContext(const PjRtClient& client, in DeviceForDLContext()
/external/tensorflow/tensorflow/compiler/xla/pjrt/
Dgpu_multistream_test.cc31 std::unique_ptr<PjRtClient> client, in TEST()
76 PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes, in TEST()
82 PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes, in TEST()
88 PjRtClient::HostBufferSemantics::kImmutableUntilTransferCompletes, in TEST()
Dpjrt_client.h58 class PjRtClient; variable
65 virtual PjRtClient* client() const = 0;
139 class PjRtClient {
141 virtual ~PjRtClient() = default;
280 virtual PjRtClient* client() const = 0;
423 virtual PjRtClient* client() const = 0;
Dinterpreter_device.cc32 StatusOr<std::unique_ptr<PjRtClient>> GetInterpreterClient() { in GetInterpreterClient()
54 return std::unique_ptr<PjRtClient>(std::make_unique<PjRtStreamExecutorClient>( in GetInterpreterClient()
Dcpu_device.cc32 StatusOr<std::unique_ptr<PjRtClient>> GetCpuClient(bool asynchronous) { in GetCpuClient()
60 return std::unique_ptr<PjRtClient>(std::make_unique<PjRtStreamExecutorClient>( in GetCpuClient()
Dpjrt_stream_executor_client.h71 void SetClient(PjRtClient* client) { in SetClient()
84 PjRtClient* client() const override { return client_; } in client()
118 PjRtClient* client_ = nullptr;
121 class PjRtStreamExecutorClient : public PjRtClient {
461 PjRtClient* client, PjRtDevice* device);
545 friend class PjRtClient;
Dinterpreter_device.h32 StatusOr<std::unique_ptr<PjRtClient>> GetInterpreterClient();
Dcpu_device.h31 StatusOr<std::unique_ptr<PjRtClient>> GetCpuClient(bool asynchronous);
Dgpu_device.h57 StatusOr<std::unique_ptr<PjRtClient>> GetGpuClient(
Dtpu_client.h54 StatusOr<std::shared_ptr<PjRtClient>> GetTpuClient(
Dtpu_client.cc175 StatusOr<std::shared_ptr<PjRtClient>> GetTpuClient( in GetTpuClient()
219 return std::shared_ptr<PjRtClient>( in GetTpuClient()
Dgpu_device.cc313 StatusOr<std::unique_ptr<PjRtClient>> GetGpuClient( in GetGpuClient()
336 return std::unique_ptr<PjRtClient>(std::make_unique<GpuClient>( in GetGpuClient()
Dpjrt_stream_executor_client.cc366 bool is_uninitialized_create, PjRtClient* client, in AllocateDestinationBuffer()
999 PjRtClient* client, PjRtDevice* device) in PjRtStreamExecutorBuffer()
1481 PjRtClient* client, LocalDeviceState* local_device, in MakeTupleHelper()
1552 std::shared_ptr<BufferSequencingEvent> definition_event, PjRtClient* client, in OutputBufferHelper()
/external/tensorflow/tensorflow/compiler/xla/python/tpu_driver/client/
Dtpu_client.h56 PjRtClient* client() const override { return nullptr; } in client()