Home
last modified time | relevance | path

Searched refs:source_target_pairs (Results 1 – 25 of 42) sorted by relevance

12

/external/tensorflow/tensorflow/python/tpu/ops/
Dtpu_ops.py113 def collective_permute(x, source_target_pairs, name=None): argument
136 return gen_tpu_ops.collective_permute(x, source_target_pairs, name=name)
144 source_target_pairs = op.inputs[1][:, ::-1]
145 return [gen_tpu_ops.collective_permute(grad, source_target_pairs), None]
/external/tensorflow/tensorflow/core/tpu/kernels/
Dcross_replica_ops.cc114 std::vector<std::pair<int64, int64>> source_target_pairs(num_pairs); in Compile() local
116 source_target_pairs[i] = {source_target_literal.Get<int64>({i, 0}), in Compile()
120 xla::CollectivePermute(ctx->Input(0), source_target_pairs)); in Compile()
/external/tensorflow/tensorflow/core/api_def/base_api/
Dapi_def_CollectivePermute.pbtxt12 name: "source_target_pairs"
34 source_target_pairs=`[[0,1],[1,2],[2,3],[3,0]]` gets the outputs:
/external/tensorflow/tensorflow/compiler/xla/service/gpu/
Dcollective_permute_thunk.h32 std::vector<std::pair<int64, int64>> source_target_pairs,
Dcollective_permute_thunk.cc223 std::vector<std::pair<int64, int64>> source_target_pairs, in CollectivePermuteThunk() argument
226 source_target_pairs_(std::move(source_target_pairs)), in CollectivePermuteThunk()
/external/tensorflow/tensorflow/core/ops/compat/ops_history_v2/
DCollectivePermute.pbtxt8 name: "source_target_pairs"
/external/tensorflow/tensorflow/core/ops/compat/ops_history_v1/
DCollectivePermute.pbtxt8 name: "source_target_pairs"
/external/tensorflow/tensorflow/compiler/xla/service/cpu/
Dcpu_runtime.h183 void* output_buffer, const void* source_target_pairs,
Dcpu_runtime.cc722 void* output_buffer, const void* source_target_pairs, in __xla_cpu_runtime_CollectivePermute() argument
726 static_cast<const char*>(source_target_pairs), source_target_pairs_size); in __xla_cpu_runtime_CollectivePermute()
Dir_emitter.cc1266 std::string source_target_pairs = absl::StrJoin( in HandleCollectivePermute() local
1267 instr->source_target_pairs(), ",", absl::PairFormatter("=")); in HandleCollectivePermute()
1269 b_.CreateGlobalStringPtr(source_target_pairs); in HandleCollectivePermute()
1295 /*source_target_pairs_size=*/b_.getInt32(source_target_pairs.size())}, in HandleCollectivePermute()
/external/tensorflow/tensorflow/compiler/mlir/xla/
Dhlo_function_importer.h71 source_target_pairs,
Dhlo_function_importer.cc315 instruction->source_target_pairs(), builder_)); in ImportInstructionImpl()
939 source_target_pairs, in ConvertSourceTargetPairs() argument
941 std::vector<int64_t> attr(source_target_pairs.size() * 2); in ConvertSourceTargetPairs()
942 for (auto p : llvm::enumerate(source_target_pairs)) { in ConvertSourceTargetPairs()
Dmlir_hlo_to_hlo.cc172 llvm::Optional<mlir::DenseIntElementsAttr> source_target_pairs) { in Convert_source_target_pairs() argument
173 return xla::ConvertNx2Attribute(source_target_pairs).ValueOrDie(); in Convert_source_target_pairs()
/external/tensorflow/tensorflow/compiler/xla/service/
Dhlo_instruction.cc478 std::vector<std::pair<int64, int64>> source_target_pairs( in CreateFromProto() local
484 for (int i = 0; i < source_target_pairs.size(); i++) { in CreateFromProto()
485 source_target_pairs[i].first = proto.source_target_pairs(i).source(); in CreateFromProto()
486 source_target_pairs[i].second = proto.source_target_pairs(i).target(); in CreateFromProto()
491 source_target_pairs, channel_id); in CreateFromProto()
494 shape, operands(0), source_target_pairs, channel_id); in CreateFromProto()
1075 const std::vector<std::pair<int64, int64>>& source_target_pairs, in CreateCollectivePermute() argument
1078 HloOpcode::kCollectivePermute, shape, operand, source_target_pairs, in CreateCollectivePermute()
1085 const std::vector<std::pair<int64, int64>>& source_target_pairs, in CreateCollectivePermuteStart() argument
1088 HloOpcode::kCollectivePermuteStart, shape, operand, source_target_pairs, in CreateCollectivePermuteStart()
[all …]
Dhlo_instructions.cc770 const std::vector<std::pair<int64, int64>>& source_target_pairs, in HloCollectivePermuteInstruction() argument
773 source_target_pairs_(source_target_pairs) { in HloCollectivePermuteInstruction()
779 for (const auto& pair : source_target_pairs()) { in ToProto()
793 for (const auto& pair : source_target_pairs()) { in ExtraAttributesToStringImpl()
811 absl::c_equal(source_target_pairs(), in IdenticalSlowPathIgnoringChannelIdValues()
812 casted_other.source_target_pairs(), in IdenticalSlowPathIgnoringChannelIdValues()
822 opcode(), shape, new_operands[0], source_target_pairs(), channel_id()); in CloneWithNewOperandsImpl()
Dhlo_instruction.h698 const std::vector<std::pair<int64, int64>>& source_target_pairs,
705 const std::vector<std::pair<int64, int64>>& source_target_pairs,
1816 const std::vector<std::pair<int64, int64>>& source_target_pairs() const;
Dhlo_instructions.h504 const std::vector<std::pair<int64, int64>>& source_target_pairs,
507 const std::vector<std::pair<int64, int64>>& source_target_pairs() const { in source_target_pairs() function
Dhlo_parser_test.cc1665 ROOT root = f32[128,32]{0,1} collective-permute(input), source_target_pairs={{0,1},{1,2},{2,3}} in CreateTestCases()
1678 …, f32[128,32]{0,1}, u32[], u32[]) collective-permute-start(input), source_target_pairs={{0,1},{1,2… in CreateTestCases()
/external/tensorflow/tensorflow/compiler/xla/service/spmd/
Dspmd_partitioner_util.cc821 std::vector<std::pair<int64, int64>> source_target_pairs; in ExchangeHalo() local
827 source_target_pairs.emplace_back( in ExchangeHalo()
846 b, source_halo_slice, source_target_pairs, (*next_channel_id)++); in ExchangeHalo()
855 std::vector<std::pair<int64, int64>> source_target_pairs; in ExchangeHalo() local
861 source_target_pairs.emplace_back( in ExchangeHalo()
879 b, source_halo_slice, source_target_pairs, (*next_channel_id)++); in ExchangeHalo()
/external/tensorflow/tensorflow/compiler/mlir/hlo/tests/
Dops.mlir318 source_target_pairs = dense<[[0, 1], [0, 2], [2, 3]]> : tensor<3x2xi64>
328 source_target_pairs = dense<[[0, 1], [1, 2], [2, 1]]> : tensor<3x2xi64>
336 // expected-error@+1 {{expect source_target_pairs attribute to be of rank 2, but got rank 1}}
338 source_target_pairs = dense<[0, 1]> : tensor<2xi64>
346 // expected-error@+1 {{expect source_target_pairs attribute of shape (N, 2), but got (2, 3)}}
348 source_target_pairs = dense<[[0, 1, 2], [3, 4, 5]]> : tensor<2x3xi64>
Dlhlo_ops.mlir779 source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>
783 source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>,
/external/tensorflow/tensorflow/compiler/mlir/xla/transforms/
Dlegalize_tf_patterns.td281 def : Pat<(TF_CollectivePermuteOp $input, (ConstantLikeMatcher ElementsAttr:$source_target_pairs)),
283 (CastElementsToI64Elements $source_target_pairs))>;
/external/tensorflow/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/
Dhlo_text_to_lhlo_no_opt.hlotxt635 // CHECK-SAME{LITERAL}: source_target_pairs = dense<[[0, 1], [0, 2], [1, 0]]> : tensor<3x2xi64>
640 source_target_pairs={{0,1}, {0,2}, {1,0}}, channel_id=2
/external/tensorflow/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/
Dhlo_ops.cc483 auto type = op.source_target_pairs().getType().dyn_cast<RankedTensorType>(); in Verify()
495 for (auto i = op.source_target_pairs().begin(), in Verify()
496 e = op.source_target_pairs().end(); in Verify()
/external/tensorflow/tensorflow/compiler/xla/client/
Dxla_builder.h758 const std::vector<std::pair<int64, int64>>& source_target_pairs);
1317 const std::vector<std::pair<int64, int64>>& source_target_pairs);
2243 const std::vector<std::pair<int64, int64>>& source_target_pairs);

12