Home
last modified time | relevance | path

Searched refs:xla_sharding (Results 1 – 14 of 14) sorted by relevance

/external/tensorflow/tensorflow/python/training/
Dslot_creator_test.py23 from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
184 v = xla_sharding.mesh_split(
189 xla_sharding.get_tensor_sharding(v),
190 xla_sharding.get_tensor_sharding(slot))
196 v = xla_sharding.mesh_split(
202 xla_sharding.get_tensor_sharding(v),
203 xla_sharding.get_tensor_sharding(slot))
Dslot_creator.py42 from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
112 slot = xla_sharding.copy_sharding(primary, slot, use_sharding_op=False)
Dadam_test.py23 from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
394 xla_sharding.mesh_split(
410 self.assertIsNone(xla_sharding.get_tensor_sharding(beta1_power))
411 self.assertIsNone(xla_sharding.get_tensor_sharding(beta2_power))
415 self.assertIsNotNone(xla_sharding.get_tensor_sharding(v))
418 self.assertIsNotNone(xla_sharding.get_tensor_sharding(slot))
Dmoving_averages_test.py23 from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
501 self.assertIsNone(xla_sharding.get_tensor_sharding(v))
502 v = xla_sharding.mesh_split(v, np.array([0, 1]), [0], use_sharding_op=False)
503 self.assertIsNotNone(xla_sharding.get_tensor_sharding(v))
508 xla_sharding.get_tensor_sharding(v),
509 xla_sharding.get_tensor_sharding(avg))
DBUILD503 "//tensorflow/compiler/xla/experimental/xla_sharding",
/external/tensorflow/tensorflow/compiler/xla/experimental/xla_sharding/
DBUILD10 name = "xla_sharding",
11 srcs = ["xla_sharding.py"],
/external/tensorflow/tensorflow/compiler/tf2xla/
Dsharding_util_test.cc67 AttrValue xla_sharding; in TEST_P() local
68 xla_sharding.set_s(""); in TEST_P()
74 {{"_XlaSharding", xla_sharding}, {"index", index}, {"T", type}}); in TEST_P()
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/transforms/
Dtpu_reorder_replicate_and_partitioned_inputs.cc48 llvm::Optional<::llvm::StringRef> xla_sharding = in ReorderReplicateAndPartitionedInputs() local
68 if (xla_sharding != op_xla_sharding) in ReorderReplicateAndPartitionedInputs()
/external/tensorflow/tensorflow/python/tpu/
Dtpu_feed.py28 from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
99 return xla_sharding.replicate(tensor, assign_tuple_sharding=True)
101 return xla_sharding.assign_device(tensor, 0, assign_tuple_sharding=True)
104 return xla_sharding.tile(
DBUILD235 "//tensorflow/compiler/xla/experimental/xla_sharding",
311 "//tensorflow/compiler/xla/experimental/xla_sharding",
/external/tensorflow/tensorflow/python/distribute/
Dtpu_strategy.py31 from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding
457 return xla_sharding.assign_device(
548 return xla_sharding.tile(tensor, tile_assignment, use_sharding_op=True)
596 return xla_sharding.replicate(tensor, use_sharding_op=True)
DBUILD592 "//tensorflow/compiler/xla/experimental/xla_sharding",
/external/tensorflow/tensorflow/compiler/tf2xla/python/
Dxla.py442 sharding = gen_xla_ops.xla_sharding
448 grad_sharding = gen_xla_ops.xla_sharding(grad, sharding=sharding_attr)
/external/tensorflow/tensorflow/compiler/mlir/xla/tests/
Dlegalize-tf.mlir4811 // CHECK-LABEL: xla_sharding
4812 func @xla_sharding(%arg0: tensor<4x16xf32>) -> tensor<4x16xf32> {