Home
last modified time | relevance | path

Searched refs:branch_index (Results 1 – 25 of 30) sorted by relevance

12

/external/tensorflow/tensorflow/compiler/xla/service/gpu/
conditional_thunk.cc
65 int32 branch_index = -1; in ExecuteOnStream() local
72 stream.ThenMemcpy(&branch_index, branch_index_address, sizeof(int32)); in ExecuteOnStream()
82 branch_index = pred ? 0 : 1; in ExecuteOnStream()
85 if (branch_index < 0 || branch_index >= hlo_instruction()->branch_count()) { in ExecuteOnStream()
86 branch_index = hlo_instruction()->branch_count() - 1; in ExecuteOnStream()
92 TF_RETURN_IF_ERROR(branch_thunks_[branch_index]->ExecuteOnStream(params)); in ExecuteOnStream()
94 hlo_instruction()->branch_computation(branch_index)); in ExecuteOnStream()
/external/tensorflow/tensorflow/core/common_runtime/
lower_case_op_test.cc
81 auto branch_index = in TEST() local
87 .Input(branch_index.node()) in TEST()
138 feeds.emplace(Output(branch_index.node()), Input::Initializer(-1)); in TEST()
147 feeds.emplace(Output(branch_index.node()), Input::Initializer(0)); in TEST()
156 feeds.emplace(Output(branch_index.node()), Input::Initializer(1)); in TEST()
165 feeds.emplace(Output(branch_index.node()), Input::Initializer(2)); in TEST()
174 feeds.emplace(Output(branch_index.node()), Input::Initializer(20)); in TEST()
228 auto branch_index = in TEST() local
239 .Input(branch_index.node()) in TEST()
277 feeds.emplace(Output(branch_index.node()), Input::Initializer(-5)); in TEST()
[all …]
lower_case_op.cc
122 Node* branch_index; in CreatePivotNodes() local
129 .Finalize(graph_, &branch_index)); in CreatePivotNodes()
130 control_predecessor_ = branch_index; in CreatePivotNodes()
136 .Input(branch_index, b) in CreatePivotNodes()
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/
switch_n.pbtxt
6 # CHECK-SAME: loc("Case/branch_index/_3")
13 name: "Case/branch_index"
55 name: "Case/branch_index/_3"
57 input: "Case/branch_index"
58 input: "Case/branch_index"
76 input: "Case/branch_index"
101 input: "Case/branch_index/_3"
112 input: "Case/branch_index/_3:1"
123 input: "Case/branch_index/_3:2"
/external/tensorflow/tensorflow/compiler/tf2xla/
functionalize_cond.cc
436 int branch_index = static_cast<int>(branch); in BuildArgumentNodes() local
442 .Finalize(bodies_[branch_index].get(), in BuildArgumentNodes()
443 &cond_arg_node.branch_copy[branch_index])); in BuildArgumentNodes()
448 int branch_index = e->src_output(); in BuildArgumentNodes() local
449 Node* src_copy = cond_arg_node.branch_copy[branch_index]; in BuildArgumentNodes()
450 Node* dst_copy = node_maps_[branch_index][e->dst()->id()]; in BuildArgumentNodes()
457 << " on branch " << Branch_Name(BranchType(branch_index)); in BuildArgumentNodes()
462 bodies_[branch_index]->AddEdge(src_copy, 0, dst_copy, dst_input); in BuildArgumentNodes()
559 int branch_index = static_cast<int>(branch); in ExtractBodies() local
560 auto output = bodies_[branch_index].get(); in ExtractBodies()
[all …]
/external/tensorflow/tensorflow/python/ops/
cond_v2.py
553 for branch_index, t in enumerate(branch_outs):
555 with grad_graphs[branch_index].as_default():
557 forward_graphs[branch_index].inputs[output_idx])
558 grad_graphs[branch_index].structured_outputs[output_idx] = zeros
943 def indexed_case(branch_index, branch_fns, name="indexed_case"): argument
945 if isinstance(branch_index, int):
946 raise TypeError("branch_index must not be a Python int", branch_index)
957 branch_index = ops.convert_to_tensor(branch_index, name="branch_index")
971 op_return_value=branch_index))
975 branch_index,
[all …]
control_flow_ops.py
3200 branch_index): argument
3220 if not isinstance(branch_index, ops.Tensor):
3222 type(branch_index)))
3223 if not branch_index.dtype.is_integer:
3225 branch_index.dtype))
3260 def _indexed_case_helper(branch_fns, default, branch_index, name): argument
3283 branch_fns, default, branch_index)
3284 with ops.name_scope(name, "case", [branch_index]):
3285 if context.executing_eagerly() and not hasattr(branch_index, "graph"):
3286 branch_index = array_ops.where(
[all …]
control_flow_ops_test.py
970 branch_index = array_ops.placeholder_with_default(bi, [])
971 case_out = control_flow_ops.switch_case(branch_index, branches)
982 branch_index = array_ops.placeholder_with_default(bi, [])
984 branch_index, branches, name=self.make_name())
995 branch_index = array_ops.placeholder_with_default(bi, [])
997 branch_index, branches, default=make_func(6), name=self.make_name())
1012 branch_index = array_ops.placeholder_with_default(bi, [])
1014 branch_index, branches, default=make_func(6), name=self.make_name())
1034 branch_index = array_ops.placeholder_with_default(bi, [])
1038 case_out = control_flow_ops.switch_case(branch_index, branches)
[all …]
/external/tensorflow/tensorflow/compiler/xla/service/
conditional_simplifier.cc
77 int branch_index = 0; in TryRemoveConditional() local
79 branch_index = conditional->operand(0)->literal().Get<bool>({}) ? 0 : 1; in TryRemoveConditional()
81 branch_index = conditional->operand(0)->literal().Get<int32>({}); in TryRemoveConditional()
82 if (branch_index < 0 || branch_index >= conditional->branch_count()) { in TryRemoveConditional()
83 branch_index = conditional->branch_count() - 1; in TryRemoveConditional()
86 HloInstruction* call_op = create_call(branch_index); in TryRemoveConditional()
hlo_evaluator.cc
1964 int branch_index; in HandleConditional() local
1966 branch_index = branch_index_literal.Get<bool>({}) ? 0 : 1; in HandleConditional()
1968 branch_index = branch_index_literal.Get<int32>({}); in HandleConditional()
1969 if (branch_index < 0 || branch_index >= conditional->branch_count()) { in HandleConditional()
1970 branch_index = conditional->branch_count() - 1; in HandleConditional()
1974 GetEvaluatedLiteralFor(conditional->operand(1 + branch_index)); in HandleConditional()
1981 *conditional->branch_computation(branch_index), in HandleConditional()
shape_inference.h
217 const Shape& branch_index,
shape_inference.cc
2673 const Shape& branch_index, in InferConditionalShape() argument
2676 if (!ShapeUtil::Compatible(branch_index, ShapeUtil::MakeShape(PRED, {})) && in InferConditionalShape()
2677 !ShapeUtil::Compatible(branch_index, ShapeUtil::MakeShape(S32, {}))) { in InferConditionalShape()
2679 ShapeUtil::HumanString(branch_index)); in InferConditionalShape()
2681 if (branch_index.element_type() == PRED) { in InferConditionalShape()
hlo_instruction.h
868 const Shape& shape, HloInstruction* branch_index,
hlo_instruction.cc
1057 const Shape& shape, HloInstruction* branch_index, in CreateConditional() argument
1062 instruction->AppendOperand(branch_index); in CreateConditional()
/external/tensorflow/tensorflow/core/kernels/data/experimental/
choose_fastest_branch_dataset_op.cc
480 Status MakeCurrentIterator(IteratorContext* ctx, int64 branch_index, in MakeCurrentIterator() argument
483 DCHECK_GE(branch_index, 0); in MakeCurrentIterator()
484 DCHECK_LT(branch_index, histograms_.size()); in MakeCurrentIterator()
493 params.node_name = strings::StrCat(params.type_string, branch_index); in MakeCurrentIterator()
506 strings::StrCat(take_dataset_params.type_string, branch_index); in MakeCurrentIterator()
518 ctx, {*wrapper_dataset_tensor_}, branch_index, in MakeCurrentIterator()
519 *instantiated_captured_funcs_[branch_index], prefix(), in MakeCurrentIterator()
/external/tensorflow/tensorflow/core/api_def/base_api/
api_def_Case.pbtxt
4 name: "branch_index"
28 switch (branch_index) {
/external/tensorflow/tensorflow/compiler/xla/tests/
conditional_test.cc
197 XlaOp branch_index; in XLA_TEST_P() local
199 &builder, &branch_index); in XLA_TEST_P()
211 Conditional(branch_index, branches_p, operands); in XLA_TEST_P()
240 XlaOp branch_index; in XLA_TEST_P() local
242 &builder, &branch_index); in XLA_TEST_P()
264 Conditional(branch_index, branches_p, operands); in XLA_TEST_P()
415 XlaOp branch_index; in XLA_TEST_P() local
417 CreateR0Parameter<int32>(bi, 0, "pred", &builder, &branch_index); in XLA_TEST_P()
436 Conditional(branch_index, branches_p, in XLA_TEST_P()
/external/tensorflow/tensorflow/core/ops/compat/ops_history_v1/
Case.pbtxt
4 name: "branch_index"
/external/tensorflow/tensorflow/core/kernels/
functional_ops.cc
235 const Tensor& branch_index = ctx->input(0); in ComputeAsync() local
236 OP_REQUIRES_ASYNC(ctx, TensorShapeUtils::IsScalar(branch_index.shape()), in ComputeAsync()
239 int32 branch = branch_index.scalar<int32>()(); in ComputeAsync()
/external/tensorflow/tensorflow/compiler/xla/client/
xla_builder.h
575 XlaOp Conditional(XlaOp branch_index,
995 XlaOp branch_index,
999 XlaOp branch_index,
1045 XlaOp branch_index,
1873 XlaOp Conditional(XlaOp branch_index,
xla_builder.cc
1853 XlaOp branch_index, in Conditional() argument
1857 TF_ASSIGN_OR_RETURN(const xla::Shape* shape, GetShapePtr(branch_index)); in Conditional()
1864 return ConditionalImpl(branch_index, branch_computations, branch_operands); in Conditional()
1869 XlaOp branch_index, in ConditionalImpl() argument
1876 GetShapePtr(branch_index)); in ConditionalImpl()
1896 std::vector<XlaOp> operands(1, branch_index); in ConditionalImpl()
3477 XlaOp Conditional(const XlaOp branch_index, in Conditional() argument
3480 return branch_index.builder()->Conditional(branch_index, branch_computations, in Conditional()
/external/tensorflow/tensorflow/compiler/xla/service/cpu/
ir_emitter.cc
2777 auto branch_index = conditional->operand(0); in HandleConditional() local
2779 TF_RET_CHECK(ShapeUtil::IsScalar(branch_index->shape()) && in HandleConditional()
2780 (branch_index->shape().element_type() == PRED || in HandleConditional()
2781 branch_index->shape().element_type() == S32)) in HandleConditional()
2783 << ShapeUtil::HumanString(branch_index->shape()); in HandleConditional()
2797 if (branch_index->shape().element_type() == PRED) { in HandleConditional()
2804 GetIrArrayFor(branch_index).GetBasePointer(), "load_predicate_value"); in HandleConditional()
2839 GetIrArrayFor(branch_index).GetBasePointer(), "load_branch_index_value"); in HandleConditional()
/external/tensorflow/tensorflow/compiler/xla/g3doc/
operation_semantics.md
637 <b> `Conditional(branch_index, branch_computations, branch_operands)` </b>
641 | `branch_index` | `XlaOp` | Scalar of type `S32` |
648 Executes `branch_computations[branch_index]`, and returns the result. If
649 `branch_index` is an `S32` which is < 0 or >= N, then `branch_computations[N-1]`
657 the value of `branch_index`.
/external/tensorflow/tensorflow/tools/api/golden/v2/
tensorflow.pbtxt
1053 …argspec: "args=[\'branch_index\', \'branch_fns\', \'default\', \'name\'], varargs=None, keywords=N…
/external/tensorflow/tensorflow/tools/api/golden/v1/
tensorflow.pbtxt
2325 …argspec: "args=[\'branch_index\', \'branch_fns\', \'default\', \'name\'], varargs=None, keywords=N…

12