Home
last modified time | relevance | path

Searched refs:body_fn (Results 1 – 22 of 22) sorted by relevance

/external/tensorflow/tensorflow/cc/framework/
Dwhile_gradients.cc75 BodyGraphBuilderFn body_fn = [](const Scope& scope, in AddForwardLoopCounter() local
85 TF_RETURN_IF_ERROR(BuildWhileLoop(scope, {zero}, cond_fn, body_fn, in AddForwardLoopCounter()
114 BodyGraphBuilderFn body_fn = [](const Scope& scope, in AddBackPropLoopCounter() local
125 scope, {loop_count}, cond_fn, body_fn, frame_name, &outputs, in AddBackPropLoopCounter()
160 BodyGraphBuilderFn body_fn = [while_ctx](const Scope& scope, in AddWhileGradientLoop() local
171 TF_RETURN_IF_ERROR(BuildWhileLoop(scope, grad_inputs, cond_fn, body_fn, in AddWhileGradientLoop()
/external/tensorflow/tensorflow/compiler/jit/
Drearrange_function_argument_pass_test.cc101 NameAttrList cond_fn, body_fn; in TEST() local
103 body_fn.set_name("f2"); in TEST()
106 std::initializer_list<Input>{arg0, arg1}, cond_fn, body_fn); in TEST()
214 NameAttrList cond_fn, body_fn; in TEST() local
216 body_fn.set_name("f2"); in TEST()
219 cond_fn, body_fn); in TEST()
Dextract_outside_compilation_pass_test.cc620 NameAttrList body_fn; in TEST_F() local
621 body_fn.set_name("body_fn"); in TEST_F()
624 cond_fn, body_fn); in TEST_F()
/external/tensorflow/tensorflow/core/common_runtime/
Dlower_while_op.cc66 const NameAttrList& body_fn, int parallel_iterations, in Run() argument
69 LowerWhileHelper helper(while_op, cond_fn, body_fn, parallel_iterations, in Run()
79 const NameAttrList& body_fn, int parallel_iterations,
181 const NameAttrList& body_fn, in LowerWhileHelper() argument
194 body_call_builder_(NewName("body"), body_fn.name(), flib_def, in LowerWhileHelper()
202 for (const auto& i : body_fn.attr()) { in LowerWhileHelper()
/external/tensorflow/tensorflow/compiler/tf2xla/
Dtf2xla_util_test.cc359 NameAttrList cond_fn, body_fn; in TEST() local
361 body_fn.set_name("body"); in TEST()
364 std::initializer_list<Input>{pred, input}, cond_fn, body_fn); in TEST()
393 NameAttrList cond_fn, body_fn; in TEST() local
395 body_fn.set_name("body"); in TEST()
398 std::initializer_list<Input>{pred, input}, cond_fn, body_fn); in TEST()
409 TF_ASSERT_OK(GetNodeAttr(while_node->def(), "body", &body_fn)); in TEST()
410 const FunctionDef* rewritten_body_fn = fld.Find(body_fn.name()); in TEST()
Dfunctionalize_control_flow_test.cc388 NameAttrList cond_fn, body_fn; in TEST() local
389 TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); in TEST()
397 std::initializer_list<Input>{source}, cond_fn, body_fn); in TEST()
440 InstantiateFunctionForTest(body_fn.name(), library, &result)); in TEST()
538 NameAttrList cond_fn, body_fn; in TEST() local
539 TF_ASSERT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); in TEST()
547 std::initializer_list<Input>{source}, cond_fn, body_fn); in TEST()
576 InstantiateFunctionForTest(body_fn.name(), library, &result)); in TEST()
656 NameAttrList cond_fn, body_fn; in TEST() local
657 TF_EXPECT_OK(FindWhileCondAndBody(graph_def, &cond_fn, &body_fn)); in TEST()
[all …]
Dxla_compiler_test.cc1681 NameAttrList cond_fn, body_fn; in TEST_F() local
1683 body_fn.set_name("body"); in TEST_F()
1685 ops::While(scope, std::initializer_list<Input>{arg}, cond_fn, body_fn); in TEST_F()
1762 NameAttrList cond_fn, body_fn; in TEST_F() local
1764 body_fn.set_name("body"); in TEST_F()
1766 scope, std::initializer_list<Input>{arg0, arg1, arg2}, cond_fn, body_fn); in TEST_F()
/external/tensorflow/tensorflow/lite/testing/op_tests/
Dwhile_loop.py63 def body_fn(counter, value, increment_value): function
74 cond_fn, body_fn, loop_vars=[1, increment_value, increment_value])
/external/tensorflow/tensorflow/lite/experimental/mlir/testing/op_tests/
Dwhile_loop.py67 def body_fn(counter, value, increment_value): function
78 cond_fn, body_fn, loop_vars=[1, increment_value, increment_value])
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/transforms/
Dfunctional_control_flow_to_cfg.cc197 auto body_fn = op.body_function(); in LowerWhileOp() local
231 for (Type type : body_fn.getFunctionType().getInputs()) { in LowerWhileOp()
262 Operation* body_call_op = CallFn(loc, get_body_arg, body_fn, &builder); in LowerWhileOp()
/external/tensorflow/tensorflow/python/keras/
Dconstraints.py293 def body_fn(i, array): function
301 body_fn,
/external/tensorflow/tensorflow/compiler/xla/service/
Dcholesky_expander.cc76 auto body_fn = [&](XlaOp i, absl::Span<const XlaOp> loop_vars, in CholeskyUnblocked() local
119 n, S32, body_fn, in CholeskyUnblocked()
Dqr_expander.cc327 auto body_fn = [&](XlaOp j, absl::Span<const XlaOp> values, in CompactWYRepresentation() local
356 ForEachIndex(n, S32, body_fn, {t, vtv}, "wy", builder)); in CompactWYRepresentation()
/external/tensorflow/tensorflow/python/autograph/operators/
Dcontrol_flow_test.py574 def _basic_loop(self, init_value, body_fn): argument
577 s = body_fn(i, s)
1000 def _basic_loop(self, init_value, body_fn): argument
1004 s = body_fn(i, s)
1244 def _basic_cond(self, body_fn, else_fn): argument
1247 x = body_fn()
/external/tensorflow/tensorflow/compiler/mlir/tfrt/transforms/
Dtf_to_tfrt.cc1201 mlir::FlatSymbolRefAttr body_fn = op.bodyAttr(); in matchAndRewrite() local
1246 op, body_fn, pred_fn, while_arg_result_types, rewriter); in matchAndRewrite()
1287 mlir::TF::WhileOp op, mlir::FlatSymbolRefAttr body_fn,
1380 if (auto body_fn = symbol_table_.lookup<mlir::func::FuncOp>(body_fn_name)) { in GetWhileBodyFunction() local
1381 return body_fn; in GetWhileBodyFunction()
1397 auto body_fn = in GetWhileBodyFunction() local
1414 body_fn->setAttr("tfrt.cost_threshold", rewriter.getI64IntegerAttr(1)); in GetWhileBodyFunction()
1417 auto *block = body_fn.addEntryBlock(); in GetWhileBodyFunction()
1447 symbol_table_.insert(body_fn); in GetWhileBodyFunction()
1449 return body_fn; in GetWhileBodyFunction()
/external/tensorflow/tensorflow/python/kernel_tests/control_flow/
Dwhile_v2_test.py111 def body_fn(v): # pylint: disable=invalid-name function
124 lambda v: v < 8., body_fn, [x], return_same_structure=False)
338 def body_fn(i): # pylint: disable=invalid-name function
342 loop = while_loop_v2(lambda i: i < 1, body_fn, [0])
Dcontrol_flow_ops_py_test.py1566 def body_fn(i): function
1571 body=body_fn, loop_vars=[1])
1583 def body_fn(unused_i): function
1588 body=body_fn, loop_vars=[0])
/external/tensorflow/tensorflow/python/distribute/
Dmirrored_strategy_test.py460 def body_fn(i): function
463 return control_flow_ops.while_loop_v2(lambda i: i < 2, body_fn, [0])
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/ir/
Dtf_ops_n_z.cc3117 auto body_fn = in verifySymbolUses() local
3122 if (!body_fn) { in verifySymbolUses()
3127 auto body_fn_type = body_fn.getFunctionType(); in verifySymbolUses()
/external/tensorflow/tensorflow/c/
Dc_api.cc2038 tensorflow::ops::BodyGraphBuilderFn body_fn = in TF_FinishWhileHelper() local
2062 body_fn, params->name, &loop_outputs); in TF_FinishWhileHelper()
/external/tensorflow/tensorflow/python/ops/parallel_for/
Dcontrol_flow_ops_test.py1857 def body_fn(t, state, ta): function
1880 _, state, ta = control_flow_ops.while_loop(condition_fn, body_fn, [
/external/tensorflow/tensorflow/compiler/mlir/xla/transforms/
Dlegalize_tf.cc520 WhileBodyFnType body_fn, ArrayRef<Value> init_values, in CreateWhile32() argument
577 body_fn(loc, block->getArgument(0), in CreateWhile32()
6912 auto body_fn = [&](Location loc, Value j, ArrayRef<Value> old_values, in ComputeWYRepresentation() local
6987 CreateWhile32(loc, n - 1, body_fn, {w, vs, taus}, &while_output, rewriter); in ComputeWYRepresentation()