Home
last modified time | relevance | path

Searched refs:branch_index (Results 1 – 25 of 57) sorted by relevance

123

/external/tensorflow/tensorflow/compiler/xla/service/gpu/
Dconditional_thunk.cc51 int32_t branch_index = -1; in ExecuteOnStream() local
58 stream.ThenMemcpy(&branch_index, branch_index_address, sizeof(int32_t)); in ExecuteOnStream()
68 branch_index = pred ? 0 : 1; in ExecuteOnStream()
71 if (branch_index < 0 || branch_index >= config_.branch_count) { in ExecuteOnStream()
72 branch_index = config_.branch_count - 1; in ExecuteOnStream()
78 config_.branch_thunks[branch_index]->ExecuteOnStream(params)); in ExecuteOnStream()
/external/tensorflow/tensorflow/core/common_runtime/
Dlower_case_op_test.cc79 auto branch_index = in TEST() local
85 .Input(branch_index.node()) in TEST()
136 feeds.emplace(Output(branch_index.node()), Input::Initializer(-1)); in TEST()
145 feeds.emplace(Output(branch_index.node()), Input::Initializer(0)); in TEST()
154 feeds.emplace(Output(branch_index.node()), Input::Initializer(1)); in TEST()
163 feeds.emplace(Output(branch_index.node()), Input::Initializer(2)); in TEST()
172 feeds.emplace(Output(branch_index.node()), Input::Initializer(20)); in TEST()
226 auto branch_index = in TEST() local
237 .Input(branch_index.node()) in TEST()
275 feeds.emplace(Output(branch_index.node()), Input::Initializer(-5)); in TEST()
[all …]
Dlower_case_op.cc121 Node* branch_index; in CreatePivotNodes() local
128 .Finalize(graph_, &branch_index)); in CreatePivotNodes()
129 control_predecessor_ = branch_index; in CreatePivotNodes()
135 .Input(branch_index, b) in CreatePivotNodes()
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/
Dswitch_n.pbtxt6 # CHECK-SAME: loc(fused["_SwitchN:", "Case/branch_index/_3"])
13 name: "Case/branch_index"
55 name: "Case/branch_index/_3"
57 input: "Case/branch_index"
58 input: "Case/branch_index"
76 input: "Case/branch_index"
101 input: "Case/branch_index/_3"
112 input: "Case/branch_index/_3:1"
123 input: "Case/branch_index/_3:2"
/external/tensorflow/tensorflow/core/ir/importexport/tests/roundtrip/
Dswitch_n.pbtxt2 name: "Case/branch_index"
44 name: "Case/branch_index/_3"
46 input: "Case/branch_index"
47 input: "Case/branch_index"
65 input: "Case/branch_index"
90 input: "Case/branch_index/_3"
101 input: "Case/branch_index/_3:1"
112 input: "Case/branch_index/_3:2"
/external/tensorflow/tensorflow/compiler/tests/
Dcase_test.py32 def switch_case_test(branch_index): argument
44 branch_index, branch_fns={
59 branch_index = array_ops.constant(0)
72 branch_index, branch_fns={
/external/tensorflow/tensorflow/core/transforms/cf_sink/tests/
Dsink_region_op.mlir35 %branch_index: tensor<i32> {tfg.name = "branch_index"},
41 %Case, %ctl_Case = StatelessCaseRegion %branch_index {
54 %Case_0, %ctl_Case_0 = StatelessCaseRegion %branch_index {
63 yield(%branch_index) : tensor<i32>
/external/tensorflow/tensorflow/compiler/tf2xla/kernels/
Dcase_op.cc60 int32_t branch_index = branch_index_literal.Get<int32>({}); in GetPrunedBranchesAndIndex() local
61 if (branch_index < 0 || branch_index >= unpruned_branches_.size()) { in GetPrunedBranchesAndIndex()
62 branch_index = unpruned_branches_.size() - 1; in GetPrunedBranchesAndIndex()
65 std::vector<NameAttrList> pruned_branch = {unpruned_branches_[branch_index]}; in GetPrunedBranchesAndIndex()
88 xla::XlaOp branch_index; in Compile() local
89 std::tie(branches, branch_index) = GetPrunedBranchesAndIndex(ctx); in Compile()
299 ctx->builder(), branch_index, absl::MakeSpan(result_computations), in Compile()
/external/tensorflow/tensorflow/compiler/tf2xla/
Dfunctionalize_cond.cc442 int branch_index = static_cast<int>(branch); in BuildArgumentNodes() local
448 .Finalize(bodies_[branch_index].get(), in BuildArgumentNodes()
449 &cond_arg_node.branch_copy[branch_index])); in BuildArgumentNodes()
454 int branch_index = e->src_output(); in BuildArgumentNodes() local
455 Node* src_copy = cond_arg_node.branch_copy[branch_index]; in BuildArgumentNodes()
456 Node* dst_copy = node_maps_[branch_index][e->dst()->id()]; in BuildArgumentNodes()
463 << " on branch " << Branch_Name(BranchType(branch_index)); in BuildArgumentNodes()
468 bodies_[branch_index]->AddEdge(src_copy, 0, dst_copy, dst_input); in BuildArgumentNodes()
565 int branch_index = static_cast<int>(branch); in ExtractBodies() local
566 auto output = bodies_[branch_index].get(); in ExtractBodies()
[all …]
/external/tensorflow/tensorflow/core/kernels/data/experimental/
Dchoose_fastest_branch_dataset_op.cc498 Status MakeCurrentIterator(IteratorContext* ctx, int64_t branch_index, in MakeCurrentIterator() argument
501 DCHECK_GE(branch_index, 0); in MakeCurrentIterator()
502 DCHECK_LT(branch_index, histograms_.size()); in MakeCurrentIterator()
511 params.node_name = strings::StrCat(params.type_string, branch_index); in MakeCurrentIterator()
524 strings::StrCat(take_dataset_params.type_string, branch_index); in MakeCurrentIterator()
537 ctx, this, {*wrapper_dataset_tensor_}, branch_index, in MakeCurrentIterator()
538 *instantiated_captured_funcs_[branch_index], prefix(), in MakeCurrentIterator()
543 ctx, this, {*wrapper_dataset_tensor_}, branch_index, in MakeCurrentIterator()
544 *instantiated_captured_funcs_[branch_index], prefix(), in MakeCurrentIterator()
/external/tensorflow/tensorflow/python/ops/
Dcond_v2.py605 for branch_index, t in enumerate(branch_outs):
607 with grad_graphs[branch_index].as_default():
609 forward_graphs[branch_index].inputs[output_idx])
610 grad_graphs[branch_index].structured_outputs[output_idx] = zeros
1024 def indexed_case(branch_index, argument
1029 if isinstance(branch_index, int):
1030 raise TypeError("branch_index must not be a Python int", branch_index)
1041 branch_index = ops.convert_to_tensor(branch_index, name="branch_index")
1055 op_return_value=branch_index))
1059 branch_index,
[all …]
Dcontrol_flow_ops.py3246 branch_index): argument
3266 if not isinstance(branch_index, ops.Tensor):
3268 type(branch_index)))
3269 if not branch_index.dtype.is_integer:
3271 branch_index.dtype))
3309 branch_index, argument
3335 branch_fns, default, branch_index)
3336 with ops.name_scope(name, "case", [branch_index]):
3337 if context.executing_eagerly() and not hasattr(branch_index, "graph"):
3338 branch_index = array_ops.where(
[all …]
Dcontrol_flow_ops_test.py1014 branch_index = array_ops.placeholder_with_default(bi, [])
1015 case_out = control_flow_ops.switch_case(branch_index, branches)
1026 branch_index = array_ops.placeholder_with_default(bi, [])
1028 branch_index, branches, name=self.make_name())
1039 branch_index = array_ops.placeholder_with_default(bi, [])
1041 branch_index, branches, default=make_func(6), name=self.make_name())
1056 branch_index = array_ops.placeholder_with_default(bi, [])
1058 branch_index, branches, default=make_func(6), name=self.make_name())
1087 branch_index = array_ops.placeholder_with_default(bi, [])
1091 case_out = control_flow_ops.switch_case(branch_index, branches)
[all …]
/external/tensorflow/tensorflow/compiler/xla/service/
Dconditional_simplifier.cc475 int branch_index = 0; in TryRemoveConditional() local
477 branch_index = conditional->operand(0)->literal().Get<bool>({}) ? 0 : 1; in TryRemoveConditional()
479 branch_index = conditional->operand(0)->literal().Get<int32_t>({}); in TryRemoveConditional()
480 if (branch_index < 0 || branch_index >= conditional->branch_count()) { in TryRemoveConditional()
481 branch_index = conditional->branch_count() - 1; in TryRemoveConditional()
484 HloInstruction* call_op = create_call(branch_index); in TryRemoveConditional()
Ddynamic_dimension_inference.cc1441 for (int64_t branch_index = 0; branch_index < hlo->branch_count(); in HandleConditional() local
1442 ++branch_index) { in HandleConditional()
1449 const int64_t operand_index = branch_index + 1; in HandleConditional()
1481 HloComputation* branch_computation = hlo->branch_computation(branch_index); in HandleConditional()
1545 for (int64_t branch_index = 0; branch_index < hlo->branch_count(); in HandleConditional() local
1546 ++branch_index) { in HandleConditional()
1558 new_branch_computations[branch_index]->root_instruction(), in HandleConditional()
1564 new_branch_computations[branch_index]->AddInstruction( in HandleConditional()
1578 new_branch_computations[branch_index]->root_instruction(), in HandleConditional()
1580 new_branch_computations[branch_index]->set_root_instruction( in HandleConditional()
/external/tensorflow/tensorflow/core/api_def/base_api/
Dapi_def_Case.pbtxt4 name: "branch_index"
28 switch (branch_index) {
Dapi_def_StatelessCase.pbtxt5 name: "branch_index"
29 switch (branch_index) {
/external/tensorflow/tensorflow/compiler/xla/tests/
Dconditional_test.cc199 XlaOp branch_index; in XLA_TEST_P() local
201 bi, 0, "branch_index_arg", &builder, &branch_index); in XLA_TEST_P()
213 Conditional(branch_index, branches_p, operands); in XLA_TEST_P()
242 XlaOp branch_index; in XLA_TEST_P() local
244 bi, 0, "branch_index_arg", &builder, &branch_index); in XLA_TEST_P()
266 Conditional(branch_index, branches_p, operands); in XLA_TEST_P()
416 XlaOp branch_index; in XLA_TEST_P() local
418 CreateR0Parameter<int32_t>(bi, 0, "pred", &builder, &branch_index); in XLA_TEST_P()
437 Conditional(branch_index, branches_p, in XLA_TEST_P()
/external/tensorflow/tensorflow/compiler/xla/client/lib/
Ddynamic_shaped_ops.cc167 XlaBuilder* builder, XlaOp branch_index, in DynamicConditional() argument
185 return xla::Conditional(branch_index, branch_computations, in DynamicConditional()
227 return xla::Conditional(branch_index, rewritten_computation_ptrs, in DynamicConditional()
Ddynamic_shaped_ops.h39 XlaBuilder* builder, XlaOp branch_index,
/external/tensorflow/tensorflow/core/ops/compat/ops_history_v2/
DStatelessCase.pbtxt4 name: "branch_index"
DCase.pbtxt4 name: "branch_index"
/external/tensorflow/tensorflow/core/ops/compat/ops_history_v1/
DCase.pbtxt4 name: "branch_index"
/external/tensorflow/tensorflow/compiler/xla/client/
Dvalue_inference.cc577 int64_t branch_index = 0; in AnalyzeConstantValueFallback() local
580 branch_index = 0; in AnalyzeConstantValueFallback()
582 branch_index = 1; in AnalyzeConstantValueFallback()
585 branch_index = operands[0].GetIntegralAsS64({}).value(); in AnalyzeConstantValueFallback()
587 const int64_t branch_dynamism_index = 2 + branch_index; in AnalyzeConstantValueFallback()
1275 int64_t branch_index = 0; in AnalyzeIsDynamic() local
1278 branch_index = 0; in AnalyzeIsDynamic()
1280 branch_index = 1; in AnalyzeIsDynamic()
1283 branch_index = operands[0].GetIntegralAsS64({}).value(); in AnalyzeIsDynamic()
1285 const int64_t branch_dynamism_index = 2 + 2 * branch_index + 1; in AnalyzeIsDynamic()
/external/tensorflow/tensorflow/dtensor/mlir/expansions/
Dsave_restore_spmd_expander.cc369 mlir::Value branch_index = DeviceIdToLocalBranchIndex( in ConditionalSave() local
376 /*branch_index=*/branch_index, in ConditionalSave()
642 mlir::Value branch_index = DeviceIdToLocalBranchIndex( in ExpandRestoreV2OpHelper() local
648 /*branch_index=*/branch_index, in ExpandRestoreV2OpHelper()

123