/external/tensorflow/tensorflow/compiler/mlir/lite/transforms/ |
D | while_loop_outline.cc | 51 void OutlineWhile(WhileOp while_op); 65 bool IsAlreadyOutlined(WhileOp while_op) { in IsAlreadyOutlined() argument 73 return just_call(while_op.body()) && just_call(while_op.cond()); in IsAlreadyOutlined() 99 void WhileOutlinePass::OutlineWhile(WhileOp while_op) { in OutlineWhile() argument 107 auto num_loop_carried = while_op.cond().getNumArguments(); in OutlineWhile() 109 while_op.getOperands().drop_front(num_loop_carried); in OutlineWhile() 114 llvm::SmallVector<Region*, 2> regions{&while_op.cond(), &while_op.body()}; in OutlineWhile() 146 if (extra_operands.empty() && IsAlreadyOutlined(while_op)) return; in OutlineWhile() 150 types.reserve(extra_operands.size() + while_op.getNumOperands()); in OutlineWhile() 151 for (Type type : while_op.cond().getArgumentTypes()) types.push_back(type); in OutlineWhile() [all …]
|
D | legalize_tf_while.cc | 49 void RunOnWhile(TF::WhileOp while_op) { in RunOnWhile() argument 50 Operation* op = while_op.getOperation(); in RunOnWhile() 54 while_op.is_stateless()); in RunOnWhile() 56 auto create_region_with_call = [&while_op](FuncOp func, Region& region) { in RunOnWhile() 62 auto call = builder.create<CallOp>(while_op.getLoc(), func, new_operands); in RunOnWhile() 63 builder.create<YieldOp>(while_op.getLoc(), call.getResults()); in RunOnWhile() 67 create_region_with_call(while_op.cond_function(), new_op.cond()); in RunOnWhile() 68 create_region_with_call(while_op.body_function(), new_op.body()); in RunOnWhile() 76 func.getBody().walk([](TF::WhileOp while_op) { RunOnWhile(while_op); }); in RunOnFunction() argument
|
/external/tensorflow/tensorflow/compiler/xla/service/ |
D | while_loop_simplifier.cc | 44 HloInstruction* while_op, absl::flat_hash_set<int64>& used_tuple_indices) { in RemoveDeadTupleIndices() argument 50 HloModule* module = while_op->GetModule(); in RemoveDeadTupleIndices() 51 HloComputation* computation = while_op->parent(); in RemoveDeadTupleIndices() 52 HloInstruction* while_init = while_op->mutable_operand(0); in RemoveDeadTupleIndices() 53 HloComputation* while_cond = while_op->while_condition(); in RemoveDeadTupleIndices() 54 HloComputation* while_body = while_op->while_body(); in RemoveDeadTupleIndices() 201 TF_RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, new_tuple)); in RemoveDeadTupleIndices() 213 static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) { in TryRemoveDeadWhileParams() argument 214 CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); in TryRemoveDeadWhileParams() 219 if (!while_op->parent()->IsSafelyRemovable(while_op)) { in TryRemoveDeadWhileParams() [all …]
|
D | while_loop_analysis_test.cc | 55 HloInstruction* while_op = module->entry_computation()->root_instruction(); in TEST_F() local 56 EXPECT_EQ(*ComputeWhileLoopTripCountUpperBound(while_op), 1); in TEST_F() 86 HloInstruction* while_op = module->entry_computation()->root_instruction(); in TEST_F() local 87 EXPECT_EQ(ComputeWhileLoopTripCountUpperBound(while_op), absl::nullopt); in TEST_F() 119 HloInstruction* while_op = module->entry_computation()->root_instruction(); in TEST_F() local 120 EXPECT_EQ(*ComputeWhileLoopTripCountUpperBound(while_op), 42); in TEST_F() 154 HloInstruction* while_op = module->entry_computation()->root_instruction(); in TEST_F() local 156 GetAuxiliaryLoopInductionVars(while_op); in TEST_F() 193 HloInstruction* while_op = module->entry_computation()->root_instruction(); in TEST_F() local 195 GetAuxiliaryLoopInductionVars(while_op); in TEST_F() [all …]
|
D | while_loop_analysis.cc | 97 const HloInstruction* while_op) { in GetAuxiliaryLoopInductionVars() argument 99 CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); in GetAuxiliaryLoopInductionVars() 100 auto* while_body = while_op->while_body(); in GetAuxiliaryLoopInductionVars() 102 VLOG(2) << "Aux Induction Variables for loop:" << while_op->ToShortString(); in GetAuxiliaryLoopInductionVars() 241 optional<int64> GetLoopInductionVarTupleIdx(const HloInstruction* while_op) { in GetLoopInductionVarTupleIdx() argument 242 CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); in GetLoopInductionVarTupleIdx() 244 << while_op->ToShortString(); in GetLoopInductionVarTupleIdx() 252 auto* while_cond = while_op->while_condition(); in GetLoopInductionVarTupleIdx() 270 auto* while_body = while_op->while_body(); in GetLoopInductionVarTupleIdx() 298 auto* while_init = while_op->operand(0); in GetLoopInductionVarTupleIdx() [all …]
|
D | while_loop_analysis.h | 32 HloInstruction *while_op, int64 max_brute_force_iters = 128); 37 HloInstruction *while_op); 43 const HloInstruction *while_op); 47 const HloInstruction *while_op);
|
D | while_loop_simplifier_test.cc | 168 auto* while_op = computation->root_instruction(); in TEST_F() local 169 ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); in TEST_F() 170 auto* true_op = while_op->while_body()->AddInstruction( in TEST_F() 173 while_op->while_body()->root_instruction())); in TEST_F() 185 auto* while_op = computation->root_instruction(); in TEST_F() local 186 ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); in TEST_F() 187 auto* while_body = while_op->while_body(); in TEST_F() 201 auto* while_op = computation->root_instruction(); in TEST_F() local 202 ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); in TEST_F() 203 auto* while_body = while_op->while_body(); in TEST_F() [all …]
|
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/transforms/ |
D | functional_control_flow_to_regions.cc | 116 LogicalResult ConvertWhileOp(WhileOp while_op) { in ConvertWhileOp() argument 117 auto while_region = OpBuilder(while_op).create<TF::WhileRegionOp>( in ConvertWhileOp() 118 while_op.getLoc(), while_op.getResultTypes(), while_op.input(), in ConvertWhileOp() 119 while_op.parallel_iterations(), while_op.is_stateless(), in ConvertWhileOp() 120 while_op.shape_invariant()); in ConvertWhileOp() 121 CopyDeviceAndUnderscoredAttributes(while_op, while_region); in ConvertWhileOp() 124 CreateCall(while_op, while_op.cond_function(), in ConvertWhileOp() 125 /*caller_region=*/while_region.cond(), while_op.input(), in ConvertWhileOp() 131 CreateCall(while_op, while_op.body_function(), in ConvertWhileOp() 132 /*caller_region=*/while_region.body(), while_op.input(), in ConvertWhileOp() [all …]
|
D | stack_ops_decomposition.cc | 160 TF::WhileOp while_op, ModuleOp module, in HandleWhileOp() argument 164 auto body = while_op.body_function(); in HandleWhileOp() 167 auto it = data_var_to_size_var.find(while_op.getOperand(index)); in HandleWhileOp() 188 auto cond = while_op.cond_function(); in HandleWhileOp() 197 auto new_while_operands = llvm::to_vector<8>(while_op.getOperands()); in HandleWhileOp() 198 OpBuilder builder(while_op); in HandleWhileOp() 199 assert(while_op.getNumOperands() == while_op.getNumResults()); in HandleWhileOp() 200 for (int64_t i = 0; i < while_op.getNumResults(); ++i) { in HandleWhileOp() 201 auto it = data_var_to_size_var.find(while_op.getOperand(i)); in HandleWhileOp() 206 builder.create<TF::WhileOp>(while_op.getLoc(), body.getType().getInputs(), in HandleWhileOp() [all …]
|
D | tpu_variable_runtime_reformatting.cc | 146 TF::WhileRegionOp while_op, tf_device::ReplicateOp replicate, in AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping() argument 149 Region& body = while_op.body(); in AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping() 150 Region& cond = while_op.cond(); in AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping() 382 void HandleReplicateOp(TF::WhileRegionOp while_op, in HandleReplicateOp() argument 416 while_op, replicate, execute, compile_launch); in HandleReplicateOp() 440 builder.setInsertionPoint(while_op); in HandleReplicateOp() 446 CreateStateVars(devices, while_op.getLoc(), key_type, &builder); in HandleReplicateOp() 484 builder.setInsertionPointAfter(while_op); in HandleReplicateOp() 490 while_op.getLoc(), in HandleReplicateOp() 494 while_op.getLoc(), num_replicas, devices, unformat_replicate_operands, in HandleReplicateOp() [all …]
|
D | tensor_list_ops_decomposition.cc | 164 TF::WhileOp while_op, ModuleOp module, in HandleWhileOp() argument 169 auto body = while_op.body_function(); in HandleWhileOp() 172 auto it = buffer_to_size->find(while_op.getOperand(index)); in HandleWhileOp() 177 return (*buffer_to_size)[while_op.getOperand(index)].fixed; in HandleWhileOp() 179 OpBuilder builder(while_op); in HandleWhileOp() 190 auto cond = while_op.cond_function(); in HandleWhileOp() 203 auto new_while_operands = llvm::to_vector<8>(while_op.getOperands()); in HandleWhileOp() 204 for (int64_t i = 0; i < while_op.getNumResults(); ++i) { in HandleWhileOp() 205 auto it = buffer_to_size->find(while_op.getOperand(i)); in HandleWhileOp() 210 builder.create<TF::WhileOp>(while_op.getLoc(), body.getType().getInputs(), in HandleWhileOp() [all …]
|
D | tpu_extract_outside_compilation.cc | 373 if (auto while_op = llvm::dyn_cast<TF::WhileRegionOp>(op)) { in DecomposeControlFlow() local 375 OpBuilder builder(while_op); in DecomposeControlFlow() 376 auto host_while = CloneEmptyWhile(while_op.is_stateless(), in DecomposeControlFlow() 377 while_op.parallel_iterations(), in DecomposeControlFlow() 378 while_op.getLoc(), builder); in DecomposeControlFlow() 385 auto condition = while_op.cond().front().getTerminator()->getOperand(0); in DecomposeControlFlow() 386 builder.setInsertionPoint(while_op.cond().front().getTerminator()); in DecomposeControlFlow() 387 builder.create<TF::XlaSendToHostOp>(while_op.getLoc(), condition, in DecomposeControlFlow() 391 builder, while_op.getLoc(), TypeRange{condition.getType()}, in DecomposeControlFlow() 393 builder.create<TF::YieldOp>(while_op.getLoc(), in DecomposeControlFlow() [all …]
|
D | tensor_array_ops_decomposition.cc | 492 } else if (auto while_op = llvm::dyn_cast<TF::WhileOp>(op)) { in AccessedGradients() local 494 {while_op.body_function(), while_op.cond_function()}, module)) in AccessedGradients() 496 insert(while_op.getOperand(entry.getFirst()), source, func_block); in AccessedGradients() 556 LogicalResult HandleWhileOp(TF::WhileOp while_op, ModuleOp module, in HandleWhileOp() argument 560 auto body = while_op.body_function(); in HandleWhileOp() 561 auto cond = while_op.cond_function(); in HandleWhileOp() 564 auto it = stats->find(while_op.getOperand(index)); in HandleWhileOp() 569 auto it = stats->find(while_op.getOperand(index)); in HandleWhileOp() 588 for (int64_t i = 0; i < while_op.getNumResults(); ++i) { in HandleWhileOp() 593 return while_op.emitOpError( in HandleWhileOp() [all …]
|
D | optimize_global_tensors.cc | 116 if (auto while_op = dyn_cast<TF::WhileOp>(op)) { in AnalyzeFunc() local 118 {while_op.cond_function(), while_op.body_function()}) { in AnalyzeFunc() 119 PropagatePotentiallyWrittenUpFromCallee(callee, while_op.input()); in AnalyzeFunc()
|
D | shape_inference.cc | 1073 if (auto while_op = dyn_cast<WhileOp>(op)) in InferShapeForSingleOperation() local 1074 return InferShapeForWhile(while_op, in InferShapeForSingleOperation() 1075 while_op.body_function().getType().getResults()); in InferShapeForSingleOperation() 1335 } else if (auto while_op = dyn_cast<TF::WhileOp>(op)) { in PropagateShapeIntoAttachedFunctions() local 1340 if (while_op.shape_invariant()) { in PropagateShapeIntoAttachedFunctions() 1342 while_op.input().getTypes(), while_op.output().getTypes(), in PropagateShapeIntoAttachedFunctions() 1343 while_op.body_function().getType().getInputs()); in PropagateShapeIntoAttachedFunctions() 1346 {while_op.cond_function(), while_op.body_function()}, max_iteration); in PropagateShapeIntoAttachedFunctions() 1349 module, while_op.input().getTypes(), in PropagateShapeIntoAttachedFunctions() 1350 {while_op.cond_function(), while_op.body_function()}, max_iteration); in PropagateShapeIntoAttachedFunctions() [all …]
|
D | resource_device_inference.cc | 283 if (auto while_op = dyn_cast<WhileOp>(op)) { in runOnOperation() local 285 while_op, while_op.getOperands(), in runOnOperation() 286 {while_op.body_function(), while_op.cond_function()}, in runOnOperation()
|
D | resource_op_lifting.cc | 662 if (auto while_op = dyn_cast<TF::WhileRegionOp>(op)) in ReplaceOpWithNewOp() local 663 return HoistResourcesOutOfWhileRegion(while_op); in ReplaceOpWithNewOp() 897 LogicalResult HandleWhileLoop(TF::WhileOp while_op, FuncOp body, FuncOp cond) { in HandleWhileLoop() argument 927 OpBuilder builder(while_op); in HandleWhileLoop() 931 while_op.getLoc(), body.getType().getResults(), in HandleWhileLoop() 932 FilterRange<Value, OperandRange>(while_op.getOperands(), in HandleWhileLoop() 934 while_op.getAttrs()); in HandleWhileLoop() 951 while_op.getResult(i).replaceAllUsesWith( in HandleWhileLoop() 955 while_op.erase(); in HandleWhileLoop() 1237 if (auto while_op = llvm::dyn_cast<TF::WhileOp>(&op)) { in HoistForControlFlow() local [all …]
|
/external/tensorflow/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/ |
D | legalize_control_flow.cc | 108 LogicalResult LowerWhileOp(mlir::mhlo::WhileOp while_op) { in LowerWhileOp() argument 116 auto* op_inst = while_op.getOperation(); in LowerWhileOp() 117 mlir::OpBuilder builder(while_op); in LowerWhileOp() 118 auto loc = while_op.getLoc(); in LowerWhileOp() 130 while_op.cond().cloneInto(orig_block->getParent(), in LowerWhileOp() 132 while_op.body().cloneInto(orig_block->getParent(), in LowerWhileOp() 136 auto* cond_block = mapper.lookup(&while_op.cond().front()); in LowerWhileOp() 137 auto* body_block = mapper.lookup(&while_op.body().front()); in LowerWhileOp() 143 builder.create<mlir::BranchOp>(loc, cond_block, while_op.getOperand()); in LowerWhileOp() 162 for (auto& block : while_op.cond()) { in LowerWhileOp() [all …]
|
D | sink_constants_to_control_flow.cc | 47 if (auto while_op = llvm::dyn_cast<WhileOp>(op)) { in runOnFunction() local 48 SinkToRegion(&while_op.body()); in runOnFunction() 49 SinkToRegion(&while_op.cond()); in runOnFunction()
|
/external/tensorflow/tensorflow/python/ops/ |
D | while_v2.py | 317 while_op = op.outputs[0].op 318 cond_graph = _get_graph(while_op, "cond", "_cond_graph") 319 body_graph = _get_graph(while_op, "body", "_body_graph") 326 num_original_outputs = while_op.get_attr("_num_original_outputs") 328 num_original_outputs = len(while_op.outputs) 330 num_intermediates = len(while_op.outputs) - num_original_outputs 336 while_op.inputs[:num_original_outputs], 337 while_op.outputs[:num_original_outputs]) 372 while_op._set_func_attr("cond", util.create_new_tf_function(cond_graph)) 373 while_op._set_func_attr("body", util.create_new_tf_function(body_graph)) [all …]
|
/external/tensorflow/tensorflow/compiler/mlir/xla/transforms/ |
D | legalize_tf_control_flow.cc | 162 auto while_op = builder.create<mhlo::WhileOp>( in LowerWhile() local 167 ImportXlaRegion(op.body_function(), &while_op.body(), loc); in LowerWhile() 168 ImportXlaRegion(op.cond_function(), &while_op.cond(), loc, in LowerWhile() 172 Detuple(while_op.getResult(), op.getResults(), &builder); in LowerWhile() 333 auto while_op = builder.create<mhlo::WhileOp>( in LowerWhileRegion() local 338 Region& cond = while_op.cond(); in LowerWhileRegion() 351 Region& body = while_op.body(); in LowerWhileRegion() 364 Detuple(while_op.getResult(), op.getResults(), &builder); in LowerWhileRegion() 371 if (auto while_op = dyn_cast<TF::WhileOp>(op)) { in runOnOperation() local 372 LowerWhile(while_op); in runOnOperation()
|
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/analysis/ |
D | resource_alias_analysis.cc | 316 } else if (auto while_op = dyn_cast<WhileOp>(op)) { in ResourceAliasAnalysisInfo() local 317 AnalyzeWhileLoop(while_op, backtrack_analysis.GetAnalysisForFunc( in ResourceAliasAnalysisInfo() 318 while_op.body_function())); in ResourceAliasAnalysisInfo() 399 Operation* while_op, const BacktrackAnalysisInfo& body_info) { in AnalyzeWhileLoop() argument 404 while_op->getNumResults()); in AnalyzeWhileLoop() 406 for (auto result : filter_resources(while_op->getResults())) { in AnalyzeWhileLoop() 411 PropagateInputToOutput(while_op->getOperand(passthru_index), result); in AnalyzeWhileLoop() 428 for (auto result : filter_resources(while_op->getResults())) { in AnalyzeWhileLoop() 435 PropagateInputToOutput(while_op->getResult(passthru_index), result) || in AnalyzeWhileLoop()
|
/external/tensorflow/tensorflow/python/kernel_tests/ |
D | while_v2_test.py | 859 def _assertNotAccumulated(self, while_op, index): argument 861 body_graph = while_v2._get_graph(while_op, "body", "_body_graph") 875 while_op = r.op.inputs[0].op 876 self._assertNotAccumulated(while_op, 0) 893 while_op = r.op.inputs[0].op 896 index = GetInputIndex(while_op, v) 897 self._assertNotAccumulated(while_op, index) 968 while_op = op 970 body_graph = while_v2._get_graph(while_op, "body", "_body_graph") 971 x_input_index = [i for i, inp in enumerate(while_op.inputs) if inp == x][0] [all …]
|
/external/tensorflow/tensorflow/core/common_runtime/ |
D | lower_while_op.cc | 65 static Status Run(Node* while_op, const NameAttrList& cond_fn, in Run() argument 68 LowerWhileHelper helper(while_op, cond_fn, body_fn, parallel_iterations, in Run() 77 LowerWhileHelper(Node* while_op, const NameAttrList& cond_fn, 177 LowerWhileHelper::LowerWhileHelper(Node* while_op, const NameAttrList& cond_fn, in LowerWhileHelper() argument 181 : while_op_(while_op), in LowerWhileHelper() 183 name_(while_op->name()), in LowerWhileHelper()
|
/external/tensorflow/tensorflow/compiler/jit/ |
D | rearrange_function_argument_pass_test.cc | 104 auto while_op = in TEST() local 107 auto ret2 = ops::_Retval(s.WithOpName("ret2"), while_op.output[0], 2); in TEST() 108 auto ret3 = ops::_Retval(s.WithOpName("ret3"), while_op.output[1], 3); in TEST() 217 auto while_op = ops::While(s.WithOpName("while"), in TEST() local
|