Searched refs:ParallelTensor (Results 1 – 7 of 7) sorted by relevance
| /external/tensorflow/tensorflow/c/eager/parallel_device/ |
| D | parallel_device.cc | 44 absl::variant<std::unique_ptr<ParallelTensor>, TensorHandlePtr>; 47 absl::variant<ParallelTensor*, TFE_TensorHandle*>; 89 if (absl::holds_alternative<ParallelTensor*>(inputs[i])) { in ExecuteWithSpecialOps() 104 result_content.push_back(ParallelTensor::FromTensorHandles( in ExecuteWithSpecialOps() 131 ParallelTensor* t = absl::get<ParallelTensor*>(inputs[0]); in ExecuteWithSpecialOps() 143 std::vector<ParallelTensor*> parallel_inputs; in ExecuteWithSpecialOps() 144 std::vector<std::unique_ptr<ParallelTensor>> implicitly_broadcast_tensors; in ExecuteWithSpecialOps() 162 std::unique_ptr<ParallelTensor> parallel_tensor( in ExecuteWithSpecialOps() 181 parallel_inputs.push_back(absl::get<ParallelTensor*>(input)); in ExecuteWithSpecialOps() 184 absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> in ExecuteWithSpecialOps() [all …]
|
| D | parallel_device_lib.h | 49 class ParallelTensor; variable 68 std::unique_ptr<ParallelTensor> CopyToParallelDevice(TFE_Context* context, 74 std::unique_ptr<ParallelTensor> ScalarsFromSequence( 79 std::unique_ptr<ParallelTensor> DeviceIDs(TFE_Context* context, 101 absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> Execute( 102 TFE_Context* context, const std::vector<ParallelTensor*>& inputs, 122 const std::vector<ParallelTensor*>& inputs, 139 absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> Join( 175 class ParallelTensor { 180 static std::unique_ptr<ParallelTensor> FromTensorHandles( [all …]
|
| D | parallel_device_lib.cc | 298 std::unique_ptr<ParallelTensor> ParallelDevice::CopyToParallelDevice( in CopyToParallelDevice() 308 return ParallelTensor::FromTensorHandles(*this, std::move(components), in CopyToParallelDevice() 312 std::unique_ptr<ParallelTensor> ParallelDevice::DeviceIDs( in DeviceIDs() 322 absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> 324 const std::vector<ParallelTensor*>& inputs, in Execute() 347 const std::vector<ParallelTensor*>& inputs, in StartExecute() 412 absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> 416 absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> result; in Join() 457 std::vector<std::unique_ptr<ParallelTensor>> per_device_outputs; in Join() 466 per_device_outputs.push_back(ParallelTensor::FromTensorHandles( in Join() [all …]
|
| D | parallel_device_lib_test.cc | 64 parallel_device.Execute(context.get(), std::vector<ParallelTensor*>(), in TEST() 68 const std::vector<std::unique_ptr<ParallelTensor>>& handles = *outputs; in TEST() 69 std::vector<ParallelTensor*> handle_inputs; in TEST() 85 parallel_device.Execute(context.get(), std::vector<ParallelTensor*>(), in TEST() 120 parallel_device.StartExecute(context.get(), std::vector<ParallelTensor*>(), in TEST() 127 const std::vector<std::unique_ptr<ParallelTensor>>& handles = *outputs; in TEST() 188 std::unique_ptr<ParallelTensor> reduced_values = in TEST() 192 std::unique_ptr<ParallelTensor> run_collective = in TEST() 202 ParallelTensor* parallel_result = (*outputs)[0].get(); in TEST() 252 std::unique_ptr<ParallelTensor> unknown_length_vector = in TEST() [all …]
|
| /external/tensorflow/tensorflow/dtensor/cc/ |
| D | dtensor_device_util.h | 160 parallel_device::ParallelTensor* DeviceIDs(TFE_Context* context, 200 mutable std::unique_ptr<parallel_device::ParallelTensor> device_ids_tensor_; 220 std::unique_ptr<parallel_device::ParallelTensor> tensor, 292 virtual parallel_device::ParallelTensor* tensor() const { in tensor() 328 TensorWithLayout(std::unique_ptr<parallel_device::ParallelTensor> tensor, 340 std::unique_ptr<parallel_device::ParallelTensor> tensor_; 428 std::unique_ptr<parallel_device::ParallelTensor> tensor, in ResourceHandleWithLayout() 452 std::unique_ptr<parallel_device::ParallelTensor> indices_tensor, 453 std::unique_ptr<parallel_device::ParallelTensor> values_tensor, 454 std::unique_ptr<parallel_device::ParallelTensor> shapes_tensor, [all …]
|
| D | dtensor_device_util.cc | 56 std::unique_ptr<parallel_device::ParallelTensor> 83 std::unique_ptr<parallel_device::ParallelTensor> parallel_tensor = in BroadcastTensorHandleToParallelTensor() 84 parallel_device::ParallelTensor::FromTensorHandles( in BroadcastTensorHandleToParallelTensor() 149 std::unique_ptr<parallel_device::ParallelTensor> parallel_tensor = in BroadcastResourceTensor() 307 std::unique_ptr<parallel_device::ParallelTensor> parallel_tensor = in Broadcast() 332 std::unique_ptr<parallel_device::ParallelTensor> tensor, in Wrap() 434 std::unique_ptr<parallel_device::ParallelTensor> indices_tensor, in Wrap() 435 std::unique_ptr<parallel_device::ParallelTensor> values_tensor, in Wrap() 436 std::unique_ptr<parallel_device::ParallelTensor> shapes_tensor, in Wrap() 993 StatusOr<std::vector<parallel_device::ParallelTensor*>> PrepareEmbeddingInputs( in PrepareEmbeddingInputs() [all …]
|
| D | dtensor_device.cc | 353 const std::vector<parallel_device::ParallelTensor*>& parallel_inputs, 429 parallel_device::ParallelTensor* MeshWithParallelDevice::DeviceIDs( in DeviceIDs() 883 std::unique_ptr<parallel_device::ParallelTensor> parallel_tensor = in Pack() 884 parallel_device::ParallelTensor::FromTensorHandles( in Pack() 1028 std::unique_ptr<parallel_device::ParallelTensor> parallel_indices_tensor = in SparsePack() 1029 parallel_device::ParallelTensor::FromTensorHandles( in SparsePack() 1033 std::unique_ptr<parallel_device::ParallelTensor> parallel_values_tensor = in SparsePack() 1034 parallel_device::ParallelTensor::FromTensorHandles( in SparsePack() 1038 std::unique_ptr<parallel_device::ParallelTensor> in SparsePack() 1040 parallel_device::ParallelTensor::FromTensorHandles( in SparsePack() [all …]
|