Home
last modified time | relevance | path

Searched refs:replicate_op (Results 1 – 7 of 7) sorted by relevance

/external/tensorflow/tensorflow/compiler/mlir/tensorflow/transforms/
hoist_replicate_invariant_resource_writes.cc 59 tf_device::ReplicateOp replicate_op, in MoveTailWritesAfterReplicate() argument
61 const auto num_replicas = replicate_op.n(); in MoveTailWritesAfterReplicate()
63 replicate_op.getRegion().front().getTerminator()); in MoveTailWritesAfterReplicate()
68 auto new_result_types = llvm::to_vector<4>(replicate_op->getResultTypes()); in MoveTailWritesAfterReplicate()
75 OpBuilder builder(replicate_op); in MoveTailWritesAfterReplicate()
78 replicate_op->getLoc(), new_result_types, replicate_op->getOperands(), in MoveTailWritesAfterReplicate()
79 replicate_op->getAttrs()); in MoveTailWritesAfterReplicate()
82 new_replicate_op.getRegion().takeBody(replicate_op.getRegion()); in MoveTailWritesAfterReplicate()
85 int old_num_results = replicate_op->getNumResults(); in MoveTailWritesAfterReplicate()
86 replicate_op->replaceAllUsesWith( in MoveTailWritesAfterReplicate()
[all …]
replicate_invariant_op_hoisting.cc 46 void MakeShapeOpInvariant(tf_device::ReplicateOp replicate_op, int num_replicas, in MakeShapeOpInvariant() argument
54 shape_op.setOperand(replicate_op.GetReplicaOperandForBlockArgument( in MakeShapeOpInvariant()
79 replicate_op.GetReplicaOperandForBlockArgument(block_arg, in MakeShapeOpInvariant()
124 void HoistReplicateInvariantOps(tf_device::ReplicateOp replicate_op) { in HoistReplicateInvariantOps() argument
125 const int num_replicas = replicate_op.n(); in HoistReplicateInvariantOps()
126 Block* replicate_block = &replicate_op.GetBody(); in HoistReplicateInvariantOps()
128 replicate_op.walk([&](TF::ShapeOp shape_op) { in HoistReplicateInvariantOps()
129 MakeShapeOpInvariant(replicate_op, num_replicas, replicate_block, shape_op); in HoistReplicateInvariantOps()
132 Region* replicate_region = &replicate_op.body(); in HoistReplicateInvariantOps()
133 Optional<DictionaryAttr> virtual_device_list = replicate_op.devices(); in HoistReplicateInvariantOps()
[all …]
replicate_to_island.cc 142 tf_executor::IslandOp island_op, tf_device::ReplicateOp replicate_op, in ExpandReplicateIntoReplicas() argument
145 auto devices = replicate_op.devices(); in ExpandReplicateIntoReplicas()
148 Operation& terminator = replicate_op.GetBody().back(); in ExpandReplicateIntoReplicas()
168 for (auto& block_arg : replicate_op.GetBody().getArguments()) in ExpandReplicateIntoReplicas()
170 replicate_op.GetReplicaOperandForBlockArgument(block_arg, i)); in ExpandReplicateIntoReplicas()
173 replicate_op.body().cloneInto(&replica.body(), mapping); in ExpandReplicateIntoReplicas()
175 if (failed(UpdateRegionReplicateVariantOps(builder, replicate_op.getLoc(), in ExpandReplicateIntoReplicas()
239 tf_device::ReplicateOp replicate_op) { in CreateIslandsFromReplicate() argument
241 const int num_replicas = replicate_op.n(); in CreateIslandsFromReplicate()
246 replicate_op, num_replicas, replicas))) in CreateIslandsFromReplicate()
[all …]
tpu_colocate_composite_resource_ops.cc 95 tf_device::ReplicateOp replicate_op, OpBuilder* builder) { in ColocateCompositeResourceOpsInReplicate() argument
96 auto devices = replicate_op.devices(); in ColocateCompositeResourceOpsInReplicate()
102 GetResourceOpsUsingCompositeArgsInReplicate(replicate_op); in ColocateCompositeResourceOpsInReplicate()
tpu_cluster_formation.cc 462 auto replicate_op = builder.create<tf_device::ReplicateOp>( in ReplicateCluster() local
468 replicate_op->setAttr(kMirroredVariableIndicesAttr, in ReplicateCluster()
476 std::next(replicate_op.result_begin(), idx * num_replicas), in ReplicateCluster()
477 std::next(replicate_op.result_begin(), (idx + 1) * num_replicas)); in ReplicateCluster()
502 replicate_op.GetBody().getArguments())) { in ReplicateCluster()
519 builder.setInsertionPointToEnd(&replicate_op.GetBody()); in ReplicateCluster()
520 auto return_op = builder.create<tf_device::ReturnOp>(replicate_op.getLoc(), in ReplicateCluster()
tpu_reorder_replicate_and_partitioned_inputs.cc 91 auto replicate_op = builder.create<TF::TPUReplicatedInputOp>( in ReorderReplicateAndPartitionedInputs() local
94 operands_per_core.push_back(replicate_op); in ReorderReplicateAndPartitionedInputs()
tpu_variable_runtime_reformatting.cc 479 while_op.body().walk([&](tf_device::ReplicateOp replicate_op) { in runOnOperation() argument
481 replicate = replicate_op; in runOnOperation()