Searched refs:xla_sharding (Results 1 – 14 of 14) sorted by relevance
/external/tensorflow/tensorflow/python/training/ |
D | slot_creator_test.py | 23 from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding 184 v = xla_sharding.mesh_split( 189 xla_sharding.get_tensor_sharding(v), 190 xla_sharding.get_tensor_sharding(slot)) 196 v = xla_sharding.mesh_split( 202 xla_sharding.get_tensor_sharding(v), 203 xla_sharding.get_tensor_sharding(slot))
|
D | slot_creator.py | 42 from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding 112 slot = xla_sharding.copy_sharding(primary, slot, use_sharding_op=False)
|
D | adam_test.py | 23 from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding 394 xla_sharding.mesh_split( 410 self.assertIsNone(xla_sharding.get_tensor_sharding(beta1_power)) 411 self.assertIsNone(xla_sharding.get_tensor_sharding(beta2_power)) 415 self.assertIsNotNone(xla_sharding.get_tensor_sharding(v)) 418 self.assertIsNotNone(xla_sharding.get_tensor_sharding(slot))
|
D | moving_averages_test.py | 23 from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding 501 self.assertIsNone(xla_sharding.get_tensor_sharding(v)) 502 v = xla_sharding.mesh_split(v, np.array([0, 1]), [0], use_sharding_op=False) 503 self.assertIsNotNone(xla_sharding.get_tensor_sharding(v)) 508 xla_sharding.get_tensor_sharding(v), 509 xla_sharding.get_tensor_sharding(avg))
|
D | BUILD | 503 "//tensorflow/compiler/xla/experimental/xla_sharding",
|
/external/tensorflow/tensorflow/compiler/xla/experimental/xla_sharding/ |
D | BUILD | 10 name = "xla_sharding", 11 srcs = ["xla_sharding.py"],
|
/external/tensorflow/tensorflow/compiler/tf2xla/ |
D | sharding_util_test.cc | 67 AttrValue xla_sharding; in TEST_P() local 68 xla_sharding.set_s(""); in TEST_P() 74 {{"_XlaSharding", xla_sharding}, {"index", index}, {"T", type}}); in TEST_P()
|
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/transforms/ |
D | tpu_reorder_replicate_and_partitioned_inputs.cc | 48 llvm::Optional<::llvm::StringRef> xla_sharding = in ReorderReplicateAndPartitionedInputs() local 68 if (xla_sharding != op_xla_sharding) in ReorderReplicateAndPartitionedInputs()
|
/external/tensorflow/tensorflow/python/tpu/ |
D | tpu_feed.py | 28 from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding 99 return xla_sharding.replicate(tensor, assign_tuple_sharding=True) 101 return xla_sharding.assign_device(tensor, 0, assign_tuple_sharding=True) 104 return xla_sharding.tile(
|
D | BUILD | 235 "//tensorflow/compiler/xla/experimental/xla_sharding", 311 "//tensorflow/compiler/xla/experimental/xla_sharding",
|
/external/tensorflow/tensorflow/python/distribute/ |
D | tpu_strategy.py | 31 from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding 457 return xla_sharding.assign_device( 548 return xla_sharding.tile(tensor, tile_assignment, use_sharding_op=True) 596 return xla_sharding.replicate(tensor, use_sharding_op=True)
|
D | BUILD | 592 "//tensorflow/compiler/xla/experimental/xla_sharding",
|
/external/tensorflow/tensorflow/compiler/tf2xla/python/ |
D | xla.py | 442 sharding = gen_xla_ops.xla_sharding 448 grad_sharding = gen_xla_ops.xla_sharding(grad, sharding=sharding_attr)
|
/external/tensorflow/tensorflow/compiler/mlir/xla/tests/ |
D | legalize-tf.mlir | 4811 // CHECK-LABEL: xla_sharding 4812 func @xla_sharding(%arg0: tensor<4x16xf32>) -> tensor<4x16xf32> {
|