Home
last modified time | relevance | path

Searched refs:while_op (Results 1 – 25 of 50) sorted by relevance

12

/external/tensorflow/tensorflow/compiler/mlir/lite/transforms/
Dwhile_loop_outline.cc51 void OutlineWhile(WhileOp while_op);
65 bool IsAlreadyOutlined(WhileOp while_op) { in IsAlreadyOutlined() argument
73 return just_call(while_op.body()) && just_call(while_op.cond()); in IsAlreadyOutlined()
99 void WhileOutlinePass::OutlineWhile(WhileOp while_op) { in OutlineWhile() argument
107 auto num_loop_carried = while_op.cond().getNumArguments(); in OutlineWhile()
109 while_op.getOperands().drop_front(num_loop_carried); in OutlineWhile()
114 llvm::SmallVector<Region*, 2> regions{&while_op.cond(), &while_op.body()}; in OutlineWhile()
146 if (extra_operands.empty() && IsAlreadyOutlined(while_op)) return; in OutlineWhile()
150 types.reserve(extra_operands.size() + while_op.getNumOperands()); in OutlineWhile()
151 for (Type type : while_op.cond().getArgumentTypes()) types.push_back(type); in OutlineWhile()
[all …]
Dlegalize_tf_while.cc49 void RunOnWhile(TF::WhileOp while_op) { in RunOnWhile() argument
50 Operation* op = while_op.getOperation(); in RunOnWhile()
54 while_op.is_stateless()); in RunOnWhile()
56 auto create_region_with_call = [&while_op](FuncOp func, Region& region) { in RunOnWhile()
62 auto call = builder.create<CallOp>(while_op.getLoc(), func, new_operands); in RunOnWhile()
63 builder.create<YieldOp>(while_op.getLoc(), call.getResults()); in RunOnWhile()
67 create_region_with_call(while_op.cond_function(), new_op.cond()); in RunOnWhile()
68 create_region_with_call(while_op.body_function(), new_op.body()); in RunOnWhile()
76 func.getBody().walk([](TF::WhileOp while_op) { RunOnWhile(while_op); }); in RunOnFunction() argument
/external/tensorflow/tensorflow/compiler/xla/service/
Dwhile_loop_simplifier.cc44 HloInstruction* while_op, absl::flat_hash_set<int64>& used_tuple_indices) { in RemoveDeadTupleIndices() argument
50 HloModule* module = while_op->GetModule(); in RemoveDeadTupleIndices()
51 HloComputation* computation = while_op->parent(); in RemoveDeadTupleIndices()
52 HloInstruction* while_init = while_op->mutable_operand(0); in RemoveDeadTupleIndices()
53 HloComputation* while_cond = while_op->while_condition(); in RemoveDeadTupleIndices()
54 HloComputation* while_body = while_op->while_body(); in RemoveDeadTupleIndices()
201 TF_RETURN_IF_ERROR(computation->ReplaceInstruction(while_op, new_tuple)); in RemoveDeadTupleIndices()
213 static StatusOr<bool> TryRemoveDeadWhileParams(HloInstruction* while_op) { in TryRemoveDeadWhileParams() argument
214 CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); in TryRemoveDeadWhileParams()
219 if (!while_op->parent()->IsSafelyRemovable(while_op)) { in TryRemoveDeadWhileParams()
[all …]
Dwhile_loop_analysis_test.cc55 HloInstruction* while_op = module->entry_computation()->root_instruction(); in TEST_F() local
56 EXPECT_EQ(*ComputeWhileLoopTripCountUpperBound(while_op), 1); in TEST_F()
86 HloInstruction* while_op = module->entry_computation()->root_instruction(); in TEST_F() local
87 EXPECT_EQ(ComputeWhileLoopTripCountUpperBound(while_op), absl::nullopt); in TEST_F()
119 HloInstruction* while_op = module->entry_computation()->root_instruction(); in TEST_F() local
120 EXPECT_EQ(*ComputeWhileLoopTripCountUpperBound(while_op), 42); in TEST_F()
154 HloInstruction* while_op = module->entry_computation()->root_instruction(); in TEST_F() local
156 GetAuxiliaryLoopInductionVars(while_op); in TEST_F()
193 HloInstruction* while_op = module->entry_computation()->root_instruction(); in TEST_F() local
195 GetAuxiliaryLoopInductionVars(while_op); in TEST_F()
[all …]
Dwhile_loop_analysis.cc97 const HloInstruction* while_op) { in GetAuxiliaryLoopInductionVars() argument
99 CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); in GetAuxiliaryLoopInductionVars()
100 auto* while_body = while_op->while_body(); in GetAuxiliaryLoopInductionVars()
102 VLOG(2) << "Aux Induction Variables for loop:" << while_op->ToShortString(); in GetAuxiliaryLoopInductionVars()
241 optional<int64> GetLoopInductionVarTupleIdx(const HloInstruction* while_op) { in GetLoopInductionVarTupleIdx() argument
242 CHECK_EQ(while_op->opcode(), HloOpcode::kWhile); in GetLoopInductionVarTupleIdx()
244 << while_op->ToShortString(); in GetLoopInductionVarTupleIdx()
252 auto* while_cond = while_op->while_condition(); in GetLoopInductionVarTupleIdx()
270 auto* while_body = while_op->while_body(); in GetLoopInductionVarTupleIdx()
298 auto* while_init = while_op->operand(0); in GetLoopInductionVarTupleIdx()
[all …]
Dwhile_loop_analysis.h32 HloInstruction *while_op, int64 max_brute_force_iters = 128);
37 HloInstruction *while_op);
43 const HloInstruction *while_op);
47 const HloInstruction *while_op);
Dwhile_loop_simplifier_test.cc168 auto* while_op = computation->root_instruction(); in TEST_F() local
169 ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); in TEST_F()
170 auto* true_op = while_op->while_body()->AddInstruction( in TEST_F()
173 while_op->while_body()->root_instruction())); in TEST_F()
185 auto* while_op = computation->root_instruction(); in TEST_F() local
186 ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); in TEST_F()
187 auto* while_body = while_op->while_body(); in TEST_F()
201 auto* while_op = computation->root_instruction(); in TEST_F() local
202 ASSERT_EQ(while_op->opcode(), HloOpcode::kWhile); in TEST_F()
203 auto* while_body = while_op->while_body(); in TEST_F()
[all …]
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/transforms/
Dfunctional_control_flow_to_regions.cc116 LogicalResult ConvertWhileOp(WhileOp while_op) { in ConvertWhileOp() argument
117 auto while_region = OpBuilder(while_op).create<TF::WhileRegionOp>( in ConvertWhileOp()
118 while_op.getLoc(), while_op.getResultTypes(), while_op.input(), in ConvertWhileOp()
119 while_op.parallel_iterations(), while_op.is_stateless(), in ConvertWhileOp()
120 while_op.shape_invariant()); in ConvertWhileOp()
121 CopyDeviceAndUnderscoredAttributes(while_op, while_region); in ConvertWhileOp()
124 CreateCall(while_op, while_op.cond_function(), in ConvertWhileOp()
125 /*caller_region=*/while_region.cond(), while_op.input(), in ConvertWhileOp()
131 CreateCall(while_op, while_op.body_function(), in ConvertWhileOp()
132 /*caller_region=*/while_region.body(), while_op.input(), in ConvertWhileOp()
[all …]
Dstack_ops_decomposition.cc160 TF::WhileOp while_op, ModuleOp module, in HandleWhileOp() argument
164 auto body = while_op.body_function(); in HandleWhileOp()
167 auto it = data_var_to_size_var.find(while_op.getOperand(index)); in HandleWhileOp()
188 auto cond = while_op.cond_function(); in HandleWhileOp()
197 auto new_while_operands = llvm::to_vector<8>(while_op.getOperands()); in HandleWhileOp()
198 OpBuilder builder(while_op); in HandleWhileOp()
199 assert(while_op.getNumOperands() == while_op.getNumResults()); in HandleWhileOp()
200 for (int64_t i = 0; i < while_op.getNumResults(); ++i) { in HandleWhileOp()
201 auto it = data_var_to_size_var.find(while_op.getOperand(i)); in HandleWhileOp()
206 builder.create<TF::WhileOp>(while_op.getLoc(), body.getType().getInputs(), in HandleWhileOp()
[all …]
Dtpu_variable_runtime_reformatting.cc146 TF::WhileRegionOp while_op, tf_device::ReplicateOp replicate, in AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping() argument
149 Region& body = while_op.body(); in AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping()
150 Region& cond = while_op.cond(); in AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping()
382 void HandleReplicateOp(TF::WhileRegionOp while_op, in HandleReplicateOp() argument
416 while_op, replicate, execute, compile_launch); in HandleReplicateOp()
440 builder.setInsertionPoint(while_op); in HandleReplicateOp()
446 CreateStateVars(devices, while_op.getLoc(), key_type, &builder); in HandleReplicateOp()
484 builder.setInsertionPointAfter(while_op); in HandleReplicateOp()
490 while_op.getLoc(), in HandleReplicateOp()
494 while_op.getLoc(), num_replicas, devices, unformat_replicate_operands, in HandleReplicateOp()
[all …]
Dtensor_list_ops_decomposition.cc164 TF::WhileOp while_op, ModuleOp module, in HandleWhileOp() argument
169 auto body = while_op.body_function(); in HandleWhileOp()
172 auto it = buffer_to_size->find(while_op.getOperand(index)); in HandleWhileOp()
177 return (*buffer_to_size)[while_op.getOperand(index)].fixed; in HandleWhileOp()
179 OpBuilder builder(while_op); in HandleWhileOp()
190 auto cond = while_op.cond_function(); in HandleWhileOp()
203 auto new_while_operands = llvm::to_vector<8>(while_op.getOperands()); in HandleWhileOp()
204 for (int64_t i = 0; i < while_op.getNumResults(); ++i) { in HandleWhileOp()
205 auto it = buffer_to_size->find(while_op.getOperand(i)); in HandleWhileOp()
210 builder.create<TF::WhileOp>(while_op.getLoc(), body.getType().getInputs(), in HandleWhileOp()
[all …]
Dtpu_extract_outside_compilation.cc373 if (auto while_op = llvm::dyn_cast<TF::WhileRegionOp>(op)) { in DecomposeControlFlow() local
375 OpBuilder builder(while_op); in DecomposeControlFlow()
376 auto host_while = CloneEmptyWhile(while_op.is_stateless(), in DecomposeControlFlow()
377 while_op.parallel_iterations(), in DecomposeControlFlow()
378 while_op.getLoc(), builder); in DecomposeControlFlow()
385 auto condition = while_op.cond().front().getTerminator()->getOperand(0); in DecomposeControlFlow()
386 builder.setInsertionPoint(while_op.cond().front().getTerminator()); in DecomposeControlFlow()
387 builder.create<TF::XlaSendToHostOp>(while_op.getLoc(), condition, in DecomposeControlFlow()
391 builder, while_op.getLoc(), TypeRange{condition.getType()}, in DecomposeControlFlow()
393 builder.create<TF::YieldOp>(while_op.getLoc(), in DecomposeControlFlow()
[all …]
Dtensor_array_ops_decomposition.cc492 } else if (auto while_op = llvm::dyn_cast<TF::WhileOp>(op)) { in AccessedGradients() local
494 {while_op.body_function(), while_op.cond_function()}, module)) in AccessedGradients()
496 insert(while_op.getOperand(entry.getFirst()), source, func_block); in AccessedGradients()
556 LogicalResult HandleWhileOp(TF::WhileOp while_op, ModuleOp module, in HandleWhileOp() argument
560 auto body = while_op.body_function(); in HandleWhileOp()
561 auto cond = while_op.cond_function(); in HandleWhileOp()
564 auto it = stats->find(while_op.getOperand(index)); in HandleWhileOp()
569 auto it = stats->find(while_op.getOperand(index)); in HandleWhileOp()
588 for (int64_t i = 0; i < while_op.getNumResults(); ++i) { in HandleWhileOp()
593 return while_op.emitOpError( in HandleWhileOp()
[all …]
Doptimize_global_tensors.cc116 if (auto while_op = dyn_cast<TF::WhileOp>(op)) { in AnalyzeFunc() local
118 {while_op.cond_function(), while_op.body_function()}) { in AnalyzeFunc()
119 PropagatePotentiallyWrittenUpFromCallee(callee, while_op.input()); in AnalyzeFunc()
Dshape_inference.cc1073 if (auto while_op = dyn_cast<WhileOp>(op)) in InferShapeForSingleOperation() local
1074 return InferShapeForWhile(while_op, in InferShapeForSingleOperation()
1075 while_op.body_function().getType().getResults()); in InferShapeForSingleOperation()
1335 } else if (auto while_op = dyn_cast<TF::WhileOp>(op)) { in PropagateShapeIntoAttachedFunctions() local
1340 if (while_op.shape_invariant()) { in PropagateShapeIntoAttachedFunctions()
1342 while_op.input().getTypes(), while_op.output().getTypes(), in PropagateShapeIntoAttachedFunctions()
1343 while_op.body_function().getType().getInputs()); in PropagateShapeIntoAttachedFunctions()
1346 {while_op.cond_function(), while_op.body_function()}, max_iteration); in PropagateShapeIntoAttachedFunctions()
1349 module, while_op.input().getTypes(), in PropagateShapeIntoAttachedFunctions()
1350 {while_op.cond_function(), while_op.body_function()}, max_iteration); in PropagateShapeIntoAttachedFunctions()
[all …]
Dresource_device_inference.cc283 if (auto while_op = dyn_cast<WhileOp>(op)) { in runOnOperation() local
285 while_op, while_op.getOperands(), in runOnOperation()
286 {while_op.body_function(), while_op.cond_function()}, in runOnOperation()
Dresource_op_lifting.cc662 if (auto while_op = dyn_cast<TF::WhileRegionOp>(op)) in ReplaceOpWithNewOp() local
663 return HoistResourcesOutOfWhileRegion(while_op); in ReplaceOpWithNewOp()
897 LogicalResult HandleWhileLoop(TF::WhileOp while_op, FuncOp body, FuncOp cond) { in HandleWhileLoop() argument
927 OpBuilder builder(while_op); in HandleWhileLoop()
931 while_op.getLoc(), body.getType().getResults(), in HandleWhileLoop()
932 FilterRange<Value, OperandRange>(while_op.getOperands(), in HandleWhileLoop()
934 while_op.getAttrs()); in HandleWhileLoop()
951 while_op.getResult(i).replaceAllUsesWith( in HandleWhileLoop()
955 while_op.erase(); in HandleWhileLoop()
1237 if (auto while_op = llvm::dyn_cast<TF::WhileOp>(&op)) { in HoistForControlFlow() local
[all …]
/external/tensorflow/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/
Dlegalize_control_flow.cc108 LogicalResult LowerWhileOp(mlir::mhlo::WhileOp while_op) { in LowerWhileOp() argument
116 auto* op_inst = while_op.getOperation(); in LowerWhileOp()
117 mlir::OpBuilder builder(while_op); in LowerWhileOp()
118 auto loc = while_op.getLoc(); in LowerWhileOp()
130 while_op.cond().cloneInto(orig_block->getParent(), in LowerWhileOp()
132 while_op.body().cloneInto(orig_block->getParent(), in LowerWhileOp()
136 auto* cond_block = mapper.lookup(&while_op.cond().front()); in LowerWhileOp()
137 auto* body_block = mapper.lookup(&while_op.body().front()); in LowerWhileOp()
143 builder.create<mlir::BranchOp>(loc, cond_block, while_op.getOperand()); in LowerWhileOp()
162 for (auto& block : while_op.cond()) { in LowerWhileOp()
[all …]
Dsink_constants_to_control_flow.cc47 if (auto while_op = llvm::dyn_cast<WhileOp>(op)) { in runOnFunction() local
48 SinkToRegion(&while_op.body()); in runOnFunction()
49 SinkToRegion(&while_op.cond()); in runOnFunction()
/external/tensorflow/tensorflow/python/ops/
Dwhile_v2.py317 while_op = op.outputs[0].op
318 cond_graph = _get_graph(while_op, "cond", "_cond_graph")
319 body_graph = _get_graph(while_op, "body", "_body_graph")
326 num_original_outputs = while_op.get_attr("_num_original_outputs")
328 num_original_outputs = len(while_op.outputs)
330 num_intermediates = len(while_op.outputs) - num_original_outputs
336 while_op.inputs[:num_original_outputs],
337 while_op.outputs[:num_original_outputs])
372 while_op._set_func_attr("cond", util.create_new_tf_function(cond_graph))
373 while_op._set_func_attr("body", util.create_new_tf_function(body_graph))
[all …]
/external/tensorflow/tensorflow/compiler/mlir/xla/transforms/
Dlegalize_tf_control_flow.cc162 auto while_op = builder.create<mhlo::WhileOp>( in LowerWhile() local
167 ImportXlaRegion(op.body_function(), &while_op.body(), loc); in LowerWhile()
168 ImportXlaRegion(op.cond_function(), &while_op.cond(), loc, in LowerWhile()
172 Detuple(while_op.getResult(), op.getResults(), &builder); in LowerWhile()
333 auto while_op = builder.create<mhlo::WhileOp>( in LowerWhileRegion() local
338 Region& cond = while_op.cond(); in LowerWhileRegion()
351 Region& body = while_op.body(); in LowerWhileRegion()
364 Detuple(while_op.getResult(), op.getResults(), &builder); in LowerWhileRegion()
371 if (auto while_op = dyn_cast<TF::WhileOp>(op)) { in runOnOperation() local
372 LowerWhile(while_op); in runOnOperation()
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/analysis/
Dresource_alias_analysis.cc316 } else if (auto while_op = dyn_cast<WhileOp>(op)) { in ResourceAliasAnalysisInfo() local
317 AnalyzeWhileLoop(while_op, backtrack_analysis.GetAnalysisForFunc( in ResourceAliasAnalysisInfo()
318 while_op.body_function())); in ResourceAliasAnalysisInfo()
399 Operation* while_op, const BacktrackAnalysisInfo& body_info) { in AnalyzeWhileLoop() argument
404 while_op->getNumResults()); in AnalyzeWhileLoop()
406 for (auto result : filter_resources(while_op->getResults())) { in AnalyzeWhileLoop()
411 PropagateInputToOutput(while_op->getOperand(passthru_index), result); in AnalyzeWhileLoop()
428 for (auto result : filter_resources(while_op->getResults())) { in AnalyzeWhileLoop()
435 PropagateInputToOutput(while_op->getResult(passthru_index), result) || in AnalyzeWhileLoop()
/external/tensorflow/tensorflow/python/kernel_tests/
Dwhile_v2_test.py859 def _assertNotAccumulated(self, while_op, index): argument
861 body_graph = while_v2._get_graph(while_op, "body", "_body_graph")
875 while_op = r.op.inputs[0].op
876 self._assertNotAccumulated(while_op, 0)
893 while_op = r.op.inputs[0].op
896 index = GetInputIndex(while_op, v)
897 self._assertNotAccumulated(while_op, index)
968 while_op = op
970 body_graph = while_v2._get_graph(while_op, "body", "_body_graph")
971 x_input_index = [i for i, inp in enumerate(while_op.inputs) if inp == x][0]
[all …]
/external/tensorflow/tensorflow/core/common_runtime/
Dlower_while_op.cc65 static Status Run(Node* while_op, const NameAttrList& cond_fn, in Run() argument
68 LowerWhileHelper helper(while_op, cond_fn, body_fn, parallel_iterations, in Run()
77 LowerWhileHelper(Node* while_op, const NameAttrList& cond_fn,
177 LowerWhileHelper::LowerWhileHelper(Node* while_op, const NameAttrList& cond_fn, in LowerWhileHelper() argument
181 : while_op_(while_op), in LowerWhileHelper()
183 name_(while_op->name()), in LowerWhileHelper()
/external/tensorflow/tensorflow/compiler/jit/
Drearrange_function_argument_pass_test.cc104 auto while_op = in TEST() local
107 auto ret2 = ops::_Retval(s.WithOpName("ret2"), while_op.output[0], 2); in TEST()
108 auto ret3 = ops::_Retval(s.WithOpName("ret3"), while_op.output[1], 3); in TEST()
217 auto while_op = ops::While(s.WithOpName("while"), in TEST() local

12