/external/tensorflow/tensorflow/compiler/xla/service/gpu/ |
D | conditional_thunk.cc | 51 int32_t branch_index = -1; in ExecuteOnStream() local 58 stream.ThenMemcpy(&branch_index, branch_index_address, sizeof(int32_t)); in ExecuteOnStream() 68 branch_index = pred ? 0 : 1; in ExecuteOnStream() 71 if (branch_index < 0 || branch_index >= config_.branch_count) { in ExecuteOnStream() 72 branch_index = config_.branch_count - 1; in ExecuteOnStream() 78 config_.branch_thunks[branch_index]->ExecuteOnStream(params)); in ExecuteOnStream()
|
/external/tensorflow/tensorflow/core/common_runtime/ |
D | lower_case_op_test.cc | 79 auto branch_index = in TEST() local 85 .Input(branch_index.node()) in TEST() 136 feeds.emplace(Output(branch_index.node()), Input::Initializer(-1)); in TEST() 145 feeds.emplace(Output(branch_index.node()), Input::Initializer(0)); in TEST() 154 feeds.emplace(Output(branch_index.node()), Input::Initializer(1)); in TEST() 163 feeds.emplace(Output(branch_index.node()), Input::Initializer(2)); in TEST() 172 feeds.emplace(Output(branch_index.node()), Input::Initializer(20)); in TEST() 226 auto branch_index = in TEST() local 237 .Input(branch_index.node()) in TEST() 275 feeds.emplace(Output(branch_index.node()), Input::Initializer(-5)); in TEST() [all …]
|
D | lower_case_op.cc | 121 Node* branch_index; in CreatePivotNodes() local 128 .Finalize(graph_, &branch_index)); in CreatePivotNodes() 129 control_predecessor_ = branch_index; in CreatePivotNodes() 135 .Input(branch_index, b) in CreatePivotNodes()
|
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/ |
D | switch_n.pbtxt | 6 # CHECK-SAME: loc(fused["_SwitchN:", "Case/branch_index/_3"]) 13 name: "Case/branch_index" 55 name: "Case/branch_index/_3" 57 input: "Case/branch_index" 58 input: "Case/branch_index" 76 input: "Case/branch_index" 101 input: "Case/branch_index/_3" 112 input: "Case/branch_index/_3:1" 123 input: "Case/branch_index/_3:2"
|
/external/tensorflow/tensorflow/core/ir/importexport/tests/roundtrip/ |
D | switch_n.pbtxt | 2 name: "Case/branch_index" 44 name: "Case/branch_index/_3" 46 input: "Case/branch_index" 47 input: "Case/branch_index" 65 input: "Case/branch_index" 90 input: "Case/branch_index/_3" 101 input: "Case/branch_index/_3:1" 112 input: "Case/branch_index/_3:2"
|
/external/tensorflow/tensorflow/compiler/tests/ |
D | case_test.py | 32 def switch_case_test(branch_index): argument 44 branch_index, branch_fns={ 59 branch_index = array_ops.constant(0) 72 branch_index, branch_fns={
|
/external/tensorflow/tensorflow/core/transforms/cf_sink/tests/ |
D | sink_region_op.mlir | 35 %branch_index: tensor<i32> {tfg.name = "branch_index"}, 41 %Case, %ctl_Case = StatelessCaseRegion %branch_index { 54 %Case_0, %ctl_Case_0 = StatelessCaseRegion %branch_index { 63 yield(%branch_index) : tensor<i32>
|
/external/tensorflow/tensorflow/compiler/tf2xla/kernels/ |
D | case_op.cc | 60 int32_t branch_index = branch_index_literal.Get<int32>({}); in GetPrunedBranchesAndIndex() local 61 if (branch_index < 0 || branch_index >= unpruned_branches_.size()) { in GetPrunedBranchesAndIndex() 62 branch_index = unpruned_branches_.size() - 1; in GetPrunedBranchesAndIndex() 65 std::vector<NameAttrList> pruned_branch = {unpruned_branches_[branch_index]}; in GetPrunedBranchesAndIndex() 88 xla::XlaOp branch_index; in Compile() local 89 std::tie(branches, branch_index) = GetPrunedBranchesAndIndex(ctx); in Compile() 299 ctx->builder(), branch_index, absl::MakeSpan(result_computations), in Compile()
|
/external/tensorflow/tensorflow/compiler/tf2xla/ |
D | functionalize_cond.cc | 442 int branch_index = static_cast<int>(branch); in BuildArgumentNodes() local 448 .Finalize(bodies_[branch_index].get(), in BuildArgumentNodes() 449 &cond_arg_node.branch_copy[branch_index])); in BuildArgumentNodes() 454 int branch_index = e->src_output(); in BuildArgumentNodes() local 455 Node* src_copy = cond_arg_node.branch_copy[branch_index]; in BuildArgumentNodes() 456 Node* dst_copy = node_maps_[branch_index][e->dst()->id()]; in BuildArgumentNodes() 463 << " on branch " << Branch_Name(BranchType(branch_index)); in BuildArgumentNodes() 468 bodies_[branch_index]->AddEdge(src_copy, 0, dst_copy, dst_input); in BuildArgumentNodes() 565 int branch_index = static_cast<int>(branch); in ExtractBodies() local 566 auto output = bodies_[branch_index].get(); in ExtractBodies() [all …]
|
/external/tensorflow/tensorflow/core/kernels/data/experimental/ |
D | choose_fastest_branch_dataset_op.cc | 498 Status MakeCurrentIterator(IteratorContext* ctx, int64_t branch_index, in MakeCurrentIterator() argument 501 DCHECK_GE(branch_index, 0); in MakeCurrentIterator() 502 DCHECK_LT(branch_index, histograms_.size()); in MakeCurrentIterator() 511 params.node_name = strings::StrCat(params.type_string, branch_index); in MakeCurrentIterator() 524 strings::StrCat(take_dataset_params.type_string, branch_index); in MakeCurrentIterator() 537 ctx, this, {*wrapper_dataset_tensor_}, branch_index, in MakeCurrentIterator() 538 *instantiated_captured_funcs_[branch_index], prefix(), in MakeCurrentIterator() 543 ctx, this, {*wrapper_dataset_tensor_}, branch_index, in MakeCurrentIterator() 544 *instantiated_captured_funcs_[branch_index], prefix(), in MakeCurrentIterator()
|
/external/tensorflow/tensorflow/python/ops/ |
D | cond_v2.py | 605 for branch_index, t in enumerate(branch_outs): 607 with grad_graphs[branch_index].as_default(): 609 forward_graphs[branch_index].inputs[output_idx]) 610 grad_graphs[branch_index].structured_outputs[output_idx] = zeros 1024 def indexed_case(branch_index, argument 1029 if isinstance(branch_index, int): 1030 raise TypeError("branch_index must not be a Python int", branch_index) 1041 branch_index = ops.convert_to_tensor(branch_index, name="branch_index") 1055 op_return_value=branch_index)) 1059 branch_index, [all …]
|
D | control_flow_ops.py | 3246 branch_index): argument 3266 if not isinstance(branch_index, ops.Tensor): 3268 type(branch_index))) 3269 if not branch_index.dtype.is_integer: 3271 branch_index.dtype)) 3309 branch_index, argument 3335 branch_fns, default, branch_index) 3336 with ops.name_scope(name, "case", [branch_index]): 3337 if context.executing_eagerly() and not hasattr(branch_index, "graph"): 3338 branch_index = array_ops.where( [all …]
|
D | control_flow_ops_test.py | 1014 branch_index = array_ops.placeholder_with_default(bi, []) 1015 case_out = control_flow_ops.switch_case(branch_index, branches) 1026 branch_index = array_ops.placeholder_with_default(bi, []) 1028 branch_index, branches, name=self.make_name()) 1039 branch_index = array_ops.placeholder_with_default(bi, []) 1041 branch_index, branches, default=make_func(6), name=self.make_name()) 1056 branch_index = array_ops.placeholder_with_default(bi, []) 1058 branch_index, branches, default=make_func(6), name=self.make_name()) 1087 branch_index = array_ops.placeholder_with_default(bi, []) 1091 case_out = control_flow_ops.switch_case(branch_index, branches) [all …]
|
/external/tensorflow/tensorflow/compiler/xla/service/ |
D | conditional_simplifier.cc | 475 int branch_index = 0; in TryRemoveConditional() local 477 branch_index = conditional->operand(0)->literal().Get<bool>({}) ? 0 : 1; in TryRemoveConditional() 479 branch_index = conditional->operand(0)->literal().Get<int32_t>({}); in TryRemoveConditional() 480 if (branch_index < 0 || branch_index >= conditional->branch_count()) { in TryRemoveConditional() 481 branch_index = conditional->branch_count() - 1; in TryRemoveConditional() 484 HloInstruction* call_op = create_call(branch_index); in TryRemoveConditional()
|
D | dynamic_dimension_inference.cc | 1441 for (int64_t branch_index = 0; branch_index < hlo->branch_count(); in HandleConditional() local 1442 ++branch_index) { in HandleConditional() 1449 const int64_t operand_index = branch_index + 1; in HandleConditional() 1481 HloComputation* branch_computation = hlo->branch_computation(branch_index); in HandleConditional() 1545 for (int64_t branch_index = 0; branch_index < hlo->branch_count(); in HandleConditional() local 1546 ++branch_index) { in HandleConditional() 1558 new_branch_computations[branch_index]->root_instruction(), in HandleConditional() 1564 new_branch_computations[branch_index]->AddInstruction( in HandleConditional() 1578 new_branch_computations[branch_index]->root_instruction(), in HandleConditional() 1580 new_branch_computations[branch_index]->set_root_instruction( in HandleConditional()
|
/external/tensorflow/tensorflow/core/api_def/base_api/ |
D | api_def_Case.pbtxt | 4 name: "branch_index" 28 switch (branch_index) {
|
D | api_def_StatelessCase.pbtxt | 5 name: "branch_index" 29 switch (branch_index) {
|
/external/tensorflow/tensorflow/compiler/xla/tests/ |
D | conditional_test.cc | 199 XlaOp branch_index; in XLA_TEST_P() local 201 bi, 0, "branch_index_arg", &builder, &branch_index); in XLA_TEST_P() 213 Conditional(branch_index, branches_p, operands); in XLA_TEST_P() 242 XlaOp branch_index; in XLA_TEST_P() local 244 bi, 0, "branch_index_arg", &builder, &branch_index); in XLA_TEST_P() 266 Conditional(branch_index, branches_p, operands); in XLA_TEST_P() 416 XlaOp branch_index; in XLA_TEST_P() local 418 CreateR0Parameter<int32_t>(bi, 0, "pred", &builder, &branch_index); in XLA_TEST_P() 437 Conditional(branch_index, branches_p, in XLA_TEST_P()
|
/external/tensorflow/tensorflow/compiler/xla/client/lib/ |
D | dynamic_shaped_ops.cc | 167 XlaBuilder* builder, XlaOp branch_index, in DynamicConditional() argument 185 return xla::Conditional(branch_index, branch_computations, in DynamicConditional() 227 return xla::Conditional(branch_index, rewritten_computation_ptrs, in DynamicConditional()
|
D | dynamic_shaped_ops.h | 39 XlaBuilder* builder, XlaOp branch_index,
|
/external/tensorflow/tensorflow/core/ops/compat/ops_history_v2/ |
D | StatelessCase.pbtxt | 4 name: "branch_index"
|
D | Case.pbtxt | 4 name: "branch_index"
|
/external/tensorflow/tensorflow/core/ops/compat/ops_history_v1/ |
D | Case.pbtxt | 4 name: "branch_index"
|
/external/tensorflow/tensorflow/compiler/xla/client/ |
D | value_inference.cc | 577 int64_t branch_index = 0; in AnalyzeConstantValueFallback() local 580 branch_index = 0; in AnalyzeConstantValueFallback() 582 branch_index = 1; in AnalyzeConstantValueFallback() 585 branch_index = operands[0].GetIntegralAsS64({}).value(); in AnalyzeConstantValueFallback() 587 const int64_t branch_dynamism_index = 2 + branch_index; in AnalyzeConstantValueFallback() 1275 int64_t branch_index = 0; in AnalyzeIsDynamic() local 1278 branch_index = 0; in AnalyzeIsDynamic() 1280 branch_index = 1; in AnalyzeIsDynamic() 1283 branch_index = operands[0].GetIntegralAsS64({}).value(); in AnalyzeIsDynamic() 1285 const int64_t branch_dynamism_index = 2 + 2 * branch_index + 1; in AnalyzeIsDynamic()
|
/external/tensorflow/tensorflow/dtensor/mlir/expansions/ |
D | save_restore_spmd_expander.cc | 369 mlir::Value branch_index = DeviceIdToLocalBranchIndex( in ConditionalSave() local 376 /*branch_index=*/branch_index, in ConditionalSave() 642 mlir::Value branch_index = DeviceIdToLocalBranchIndex( in ExpandRestoreV2OpHelper() local 648 /*branch_index=*/branch_index, in ExpandRestoreV2OpHelper()
|