Searched refs:cross_replica_sum (Results 1 – 7 of 7) sorted by relevance

/external/tensorflow/tensorflow/compiler/xla/tests/
all_reduce_test.cc
124 …%convert.11 = u8[] convert(f32[] %p0.1), metadata={op_type="xla::cross_replica_sum" source_file="a… in XLA_TEST_F()
125 …dynamic-update-slice.10, u8[] %convert.11), metadata={op_type="xla::cross_replica_sum" source_file… in XLA_TEST_F()
126 …ement((u8[8]{0}, u8[]) %tuple.12), index=0, metadata={op_type="xla::cross_replica_sum" source_file… in XLA_TEST_F()
127 …ement((u8[8]{0}, u8[]) %tuple.12), index=1, metadata={op_type="xla::cross_replica_sum" source_file… in XLA_TEST_F()
128 …n_layout=true, to_apply=%AddComputation.15, metadata={op_type="xla::cross_replica_sum" source_file… in XLA_TEST_F()
129 …((u8[8]{0}, u8[]) %all-reduce.19), index=1, metadata={op_type="xla::cross_replica_sum" source_file… in XLA_TEST_F()
130 … f32[] convert(u8[] %get-tuple-element.21), metadata={op_type="xla::cross_replica_sum" source_file… in XLA_TEST_F()
131 …((u8[8]{0}, u8[]) %all-reduce.19), index=0, metadata={op_type="xla::cross_replica_sum" source_file… in XLA_TEST_F()
166 …%convert.11 = s32[] convert(f32[] %p0.1), metadata={op_type="xla::cross_replica_sum" source_file="… in XLA_TEST_F()
167 …ynamic-update-slice.10, s32[] %convert.11), metadata={op_type="xla::cross_replica_sum" source_file… in XLA_TEST_F()
/external/tensorflow/tensorflow/compiler/xla/service/
ar_crs_combiner.h
97 ArCrsPair(HloInstruction* all_reduce, HloInstruction* cross_replica_sum, in ArCrsPair()
99 : ar(all_reduce), crs(cross_replica_sum), distance(dist) {} in ArCrsPair()
/external/tensorflow/tensorflow/python/tpu/ops/
tpu_ops.py
94 def cross_replica_sum(x, group_assignment=None, name=None): function
110 return gen_tpu_ops.cross_replica_sum(x, group_assignment, name=name)
152 return [gen_tpu_ops.cross_replica_sum(grad, op.inputs[1]), None]
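The tpu_ops.py hits show the public wrapper (line 94) and what appears to be its gradient rule (line 152), which applies the same cross-replica sum to the incoming gradient and returns None for the group_assignment input. The sketch below is a hypothetical pure-Python model of those semantics, not TensorFlow code: it assumes each replica holds one scalar and that group_assignment is a list of disjoint groups of replica ids, with None meaning a single group of all replicas. The names simulate_cross_replica_sum and simulate_cross_replica_sum_grad are illustrative only.

```python
from typing import List, Optional, Sequence


def simulate_cross_replica_sum(
    per_replica: Sequence[float],
    group_assignment: Optional[Sequence[Sequence[int]]] = None,
) -> List[float]:
    """Hypothetical model: each replica receives the sum over its group.

    per_replica[r] is the value held by replica r. group_assignment lists
    disjoint groups of replica ids; None means one group of all replicas.
    """
    n = len(per_replica)
    if group_assignment is None:
        group_assignment = [list(range(n))]
    # Assumption of this sketch: every replica appears in exactly one group.
    seen = sorted(r for group in group_assignment for r in group)
    if seen != list(range(n)):
        raise ValueError("group_assignment must partition the replica ids")
    result = [0.0] * n
    for group in group_assignment:
        group_sum = sum(per_replica[r] for r in group)
        for r in group:
            result[r] = group_sum
    return result


def simulate_cross_replica_sum_grad(
    upstream: Sequence[float],
    group_assignment: Optional[Sequence[Sequence[int]]] = None,
) -> List[float]:
    """Gradient w.r.t. each replica's input, given upstream gradients.

    out[r] = sum of x[s] over s in r's group, so dL/dx[s] = sum of
    upstream[r] over r in s's group: the same cross-replica sum.
    """
    return simulate_cross_replica_sum(upstream, group_assignment)


if __name__ == "__main__":
    values = [1.0, 2.0, 3.0, 4.0]
    print(simulate_cross_replica_sum(values))                    # [10.0, 10.0, 10.0, 10.0]
    print(simulate_cross_replica_sum(values, [[0, 1], [2, 3]]))  # [3.0, 3.0, 7.0, 7.0]
```

Because group membership is symmetric (replica r is in s's group exactly when s is in r's group), the gradient of a cross-replica sum is itself a cross-replica sum over the same groups, which is consistent with line 152 reusing gen_tpu_ops.cross_replica_sum on grad.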
/external/tensorflow/tensorflow/tools/api/golden/v1/
tensorflow.tpu.pbtxt
32 name: "cross_replica_sum"
/external/tensorflow/tensorflow/python/tpu/
tpu_optimizer.py
190 summed_grads_and_vars.append((tpu_ops.cross_replica_sum(
/external/tensorflow/tensorflow/python/distribute/
tpu_strategy.py
1225 return tpu_ops.cross_replica_sum(value)
/external/tensorflow/tensorflow/compiler/mlir/xla/tests/
legalize-tf.mlir
3864 // CHECK-LABEL: @cross_replica_sum
3865 func @cross_replica_sum(%input: tensor<10xf32>) -> tensor<10xf32> {