Home
last modified time | relevance | path

Searched refs:ParallelTensor (Results 1 – 7 of 7) sorted by relevance

/external/tensorflow/tensorflow/c/eager/parallel_device/
Dparallel_device.cc44 absl::variant<std::unique_ptr<ParallelTensor>, TensorHandlePtr>;
47 absl::variant<ParallelTensor*, TFE_TensorHandle*>;
89 if (absl::holds_alternative<ParallelTensor*>(inputs[i])) { in ExecuteWithSpecialOps()
104 result_content.push_back(ParallelTensor::FromTensorHandles( in ExecuteWithSpecialOps()
131 ParallelTensor* t = absl::get<ParallelTensor*>(inputs[0]); in ExecuteWithSpecialOps()
143 std::vector<ParallelTensor*> parallel_inputs; in ExecuteWithSpecialOps()
144 std::vector<std::unique_ptr<ParallelTensor>> implicitly_broadcast_tensors; in ExecuteWithSpecialOps()
162 std::unique_ptr<ParallelTensor> parallel_tensor( in ExecuteWithSpecialOps()
181 parallel_inputs.push_back(absl::get<ParallelTensor*>(input)); in ExecuteWithSpecialOps()
184 absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> in ExecuteWithSpecialOps()
[all …]
Dparallel_device_lib.h49 class ParallelTensor; variable
68 std::unique_ptr<ParallelTensor> CopyToParallelDevice(TFE_Context* context,
74 std::unique_ptr<ParallelTensor> ScalarsFromSequence(
79 std::unique_ptr<ParallelTensor> DeviceIDs(TFE_Context* context,
101 absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> Execute(
102 TFE_Context* context, const std::vector<ParallelTensor*>& inputs,
122 const std::vector<ParallelTensor*>& inputs,
139 absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> Join(
175 class ParallelTensor {
180 static std::unique_ptr<ParallelTensor> FromTensorHandles(
[all …]
Dparallel_device_lib.cc298 std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice( in CopyToParallelDevice()
308 return ParallelTensor::FromTensorHandles(*this, std::move(components), in CopyToParallelDevice()
312 std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs( in DeviceIDs()
322 absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
324 const std::vector<ParallelTensor*>& inputs, in Execute()
347 const std::vector<ParallelTensor*>& inputs, in StartExecute()
412 absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
416 absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> result; in Join()
457 std::vector<std::unique_ptr<ParallelTensor>> per_device_outputs; in Join()
466 per_device_outputs.push_back(ParallelTensor::FromTensorHandles( in Join()
[all …]
Dparallel_device_lib_test.cc64 parallel_device.Execute(context.get(), std::vector<ParallelTensor*>(), in TEST()
68 const std::vector<std::unique_ptr<ParallelTensor>>& handles = *outputs; in TEST()
69 std::vector<ParallelTensor*> handle_inputs; in TEST()
85 parallel_device.Execute(context.get(), std::vector<ParallelTensor*>(), in TEST()
120 parallel_device.StartExecute(context.get(), std::vector<ParallelTensor*>(), in TEST()
127 const std::vector<std::unique_ptr<ParallelTensor>>& handles = *outputs; in TEST()
188 std::unique_ptr<ParallelTensor> reduced_values = in TEST()
192 std::unique_ptr<ParallelTensor> run_collective = in TEST()
202 ParallelTensor* parallel_result = (*outputs)[0].get(); in TEST()
252 std::unique_ptr<ParallelTensor> unknown_length_vector = in TEST()
[all …]
/external/tensorflow/tensorflow/dtensor/cc/
Ddtensor_device_util.h160 parallel_device::ParallelTensor* DeviceIDs(TFE_Context* context,
200 mutable std::unique_ptr<parallel_device::ParallelTensor> device_ids_tensor_;
220 std::unique_ptr<parallel_device::ParallelTensor> tensor,
292 virtual parallel_device::ParallelTensor* tensor() const { in tensor()
328 TensorWithLayout(std::unique_ptr<parallel_device::ParallelTensor> tensor,
340 std::unique_ptr<parallel_device::ParallelTensor> tensor_;
428 std::unique_ptr<parallel_device::ParallelTensor> tensor, in ResourceHandleWithLayout()
452 std::unique_ptr<parallel_device::ParallelTensor> indices_tensor,
453 std::unique_ptr<parallel_device::ParallelTensor> values_tensor,
454 std::unique_ptr<parallel_device::ParallelTensor> shapes_tensor,
[all …]
Ddtensor_device_util.cc56 std::unique_ptr<parallel_device::ParallelTensor>
83 std::unique_ptr<parallel_device::ParallelTensor> parallel_tensor = in BroadcastTensorHandleToParallelTensor()
84 parallel_device::ParallelTensor::FromTensorHandles( in BroadcastTensorHandleToParallelTensor()
149 std::unique_ptr<parallel_device::ParallelTensor> parallel_tensor = in BroadcastResourceTensor()
307 std::unique_ptr<parallel_device::ParallelTensor> parallel_tensor = in Broadcast()
332 std::unique_ptr<parallel_device::ParallelTensor> tensor, in Wrap()
434 std::unique_ptr<parallel_device::ParallelTensor> indices_tensor, in Wrap()
435 std::unique_ptr<parallel_device::ParallelTensor> values_tensor, in Wrap()
436 std::unique_ptr<parallel_device::ParallelTensor> shapes_tensor, in Wrap()
993 StatusOr<std::vector<parallel_device::ParallelTensor*>> PrepareEmbeddingInputs( in PrepareEmbeddingInputs()
[all …]
Ddtensor_device.cc353 const std::vector<parallel_device::ParallelTensor*>& parallel_inputs,
429 parallel_device::ParallelTensor* MeshWithParallelDevice::DeviceIDs( in DeviceIDs()
883 std::unique_ptr<parallel_device::ParallelTensor> parallel_tensor = in Pack()
884 parallel_device::ParallelTensor::FromTensorHandles( in Pack()
1028 std::unique_ptr<parallel_device::ParallelTensor> parallel_indices_tensor = in SparsePack()
1029 parallel_device::ParallelTensor::FromTensorHandles( in SparsePack()
1033 std::unique_ptr<parallel_device::ParallelTensor> parallel_values_tensor = in SparsePack()
1034 parallel_device::ParallelTensor::FromTensorHandles( in SparsePack()
1038 std::unique_ptr<parallel_device::ParallelTensor> in SparsePack()
1040 parallel_device::ParallelTensor::FromTensorHandles( in SparsePack()
[all …]