/external/tensorflow/tensorflow/compiler/xla/service/gpu/ |
D | conditional_thunk.cc | 65 int32 branch_index = -1; in ExecuteOnStream() local 72 stream.ThenMemcpy(&branch_index, branch_index_address, sizeof(int32)); in ExecuteOnStream() 82 branch_index = pred ? 0 : 1; in ExecuteOnStream() 85 if (branch_index < 0 || branch_index >= hlo_instruction()->branch_count()) { in ExecuteOnStream() 86 branch_index = hlo_instruction()->branch_count() - 1; in ExecuteOnStream() 92 TF_RETURN_IF_ERROR(branch_thunks_[branch_index]->ExecuteOnStream(params)); in ExecuteOnStream() 94 hlo_instruction()->branch_computation(branch_index)); in ExecuteOnStream()
|
/external/tensorflow/tensorflow/core/common_runtime/ |
D | lower_case_op_test.cc | 81 auto branch_index = in TEST() local 87 .Input(branch_index.node()) in TEST() 138 feeds.emplace(Output(branch_index.node()), Input::Initializer(-1)); in TEST() 147 feeds.emplace(Output(branch_index.node()), Input::Initializer(0)); in TEST() 156 feeds.emplace(Output(branch_index.node()), Input::Initializer(1)); in TEST() 165 feeds.emplace(Output(branch_index.node()), Input::Initializer(2)); in TEST() 174 feeds.emplace(Output(branch_index.node()), Input::Initializer(20)); in TEST() 228 auto branch_index = in TEST() local 239 .Input(branch_index.node()) in TEST() 277 feeds.emplace(Output(branch_index.node()), Input::Initializer(-5)); in TEST() [all …]
|
D | lower_case_op.cc | 122 Node* branch_index; in CreatePivotNodes() local 129 .Finalize(graph_, &branch_index)); in CreatePivotNodes() 130 control_predecessor_ = branch_index; in CreatePivotNodes() 136 .Input(branch_index, b) in CreatePivotNodes()
|
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/tests/graphdef2mlir/ |
D | switch_n.pbtxt | 6 # CHECK-SAME: loc("Case/branch_index/_3") 13 name: "Case/branch_index" 55 name: "Case/branch_index/_3" 57 input: "Case/branch_index" 58 input: "Case/branch_index" 76 input: "Case/branch_index" 101 input: "Case/branch_index/_3" 112 input: "Case/branch_index/_3:1" 123 input: "Case/branch_index/_3:2"
|
/external/tensorflow/tensorflow/compiler/tf2xla/ |
D | functionalize_cond.cc | 436 int branch_index = static_cast<int>(branch); in BuildArgumentNodes() local 442 .Finalize(bodies_[branch_index].get(), in BuildArgumentNodes() 443 &cond_arg_node.branch_copy[branch_index])); in BuildArgumentNodes() 448 int branch_index = e->src_output(); in BuildArgumentNodes() local 449 Node* src_copy = cond_arg_node.branch_copy[branch_index]; in BuildArgumentNodes() 450 Node* dst_copy = node_maps_[branch_index][e->dst()->id()]; in BuildArgumentNodes() 457 << " on branch " << Branch_Name(BranchType(branch_index)); in BuildArgumentNodes() 462 bodies_[branch_index]->AddEdge(src_copy, 0, dst_copy, dst_input); in BuildArgumentNodes() 559 int branch_index = static_cast<int>(branch); in ExtractBodies() local 560 auto output = bodies_[branch_index].get(); in ExtractBodies() [all …]
|
/external/tensorflow/tensorflow/python/ops/ |
D | cond_v2.py | 553 for branch_index, t in enumerate(branch_outs): 555 with grad_graphs[branch_index].as_default(): 557 forward_graphs[branch_index].inputs[output_idx]) 558 grad_graphs[branch_index].structured_outputs[output_idx] = zeros 943 def indexed_case(branch_index, branch_fns, name="indexed_case"): argument 945 if isinstance(branch_index, int): 946 raise TypeError("branch_index must not be a Python int", branch_index) 957 branch_index = ops.convert_to_tensor(branch_index, name="branch_index") 971 op_return_value=branch_index)) 975 branch_index, [all …]
|
D | control_flow_ops.py | 3200 branch_index): argument 3220 if not isinstance(branch_index, ops.Tensor): 3222 type(branch_index))) 3223 if not branch_index.dtype.is_integer: 3225 branch_index.dtype)) 3260 def _indexed_case_helper(branch_fns, default, branch_index, name): argument 3283 branch_fns, default, branch_index) 3284 with ops.name_scope(name, "case", [branch_index]): 3285 if context.executing_eagerly() and not hasattr(branch_index, "graph"): 3286 branch_index = array_ops.where( [all …]
|
D | control_flow_ops_test.py | 970 branch_index = array_ops.placeholder_with_default(bi, []) 971 case_out = control_flow_ops.switch_case(branch_index, branches) 982 branch_index = array_ops.placeholder_with_default(bi, []) 984 branch_index, branches, name=self.make_name()) 995 branch_index = array_ops.placeholder_with_default(bi, []) 997 branch_index, branches, default=make_func(6), name=self.make_name()) 1012 branch_index = array_ops.placeholder_with_default(bi, []) 1014 branch_index, branches, default=make_func(6), name=self.make_name()) 1034 branch_index = array_ops.placeholder_with_default(bi, []) 1038 case_out = control_flow_ops.switch_case(branch_index, branches) [all …]
|
/external/tensorflow/tensorflow/compiler/xla/service/ |
D | conditional_simplifier.cc | 77 int branch_index = 0; in TryRemoveConditional() local 79 branch_index = conditional->operand(0)->literal().Get<bool>({}) ? 0 : 1; in TryRemoveConditional() 81 branch_index = conditional->operand(0)->literal().Get<int32>({}); in TryRemoveConditional() 82 if (branch_index < 0 || branch_index >= conditional->branch_count()) { in TryRemoveConditional() 83 branch_index = conditional->branch_count() - 1; in TryRemoveConditional() 86 HloInstruction* call_op = create_call(branch_index); in TryRemoveConditional()
|
D | hlo_evaluator.cc | 1964 int branch_index; in HandleConditional() local 1966 branch_index = branch_index_literal.Get<bool>({}) ? 0 : 1; in HandleConditional() 1968 branch_index = branch_index_literal.Get<int32>({}); in HandleConditional() 1969 if (branch_index < 0 || branch_index >= conditional->branch_count()) { in HandleConditional() 1970 branch_index = conditional->branch_count() - 1; in HandleConditional() 1974 GetEvaluatedLiteralFor(conditional->operand(1 + branch_index)); in HandleConditional() 1981 *conditional->branch_computation(branch_index), in HandleConditional()
|
D | shape_inference.h | 217 const Shape& branch_index,
|
D | shape_inference.cc | 2673 const Shape& branch_index, in InferConditionalShape() argument 2676 if (!ShapeUtil::Compatible(branch_index, ShapeUtil::MakeShape(PRED, {})) && in InferConditionalShape() 2677 !ShapeUtil::Compatible(branch_index, ShapeUtil::MakeShape(S32, {}))) { in InferConditionalShape() 2679 ShapeUtil::HumanString(branch_index)); in InferConditionalShape() 2681 if (branch_index.element_type() == PRED) { in InferConditionalShape()
|
D | hlo_instruction.h | 868 const Shape& shape, HloInstruction* branch_index,
|
D | hlo_instruction.cc | 1057 const Shape& shape, HloInstruction* branch_index, in CreateConditional() argument 1062 instruction->AppendOperand(branch_index); in CreateConditional()
|
/external/tensorflow/tensorflow/core/kernels/data/experimental/ |
D | choose_fastest_branch_dataset_op.cc | 480 Status MakeCurrentIterator(IteratorContext* ctx, int64 branch_index, in MakeCurrentIterator() argument 483 DCHECK_GE(branch_index, 0); in MakeCurrentIterator() 484 DCHECK_LT(branch_index, histograms_.size()); in MakeCurrentIterator() 493 params.node_name = strings::StrCat(params.type_string, branch_index); in MakeCurrentIterator() 506 strings::StrCat(take_dataset_params.type_string, branch_index); in MakeCurrentIterator() 518 ctx, {*wrapper_dataset_tensor_}, branch_index, in MakeCurrentIterator() 519 *instantiated_captured_funcs_[branch_index], prefix(), in MakeCurrentIterator()
|
/external/tensorflow/tensorflow/core/api_def/base_api/ |
D | api_def_Case.pbtxt | 4 name: "branch_index" 28 switch (branch_index) {
|
/external/tensorflow/tensorflow/compiler/xla/tests/ |
D | conditional_test.cc | 197 XlaOp branch_index; in XLA_TEST_P() local 199 &builder, &branch_index); in XLA_TEST_P() 211 Conditional(branch_index, branches_p, operands); in XLA_TEST_P() 240 XlaOp branch_index; in XLA_TEST_P() local 242 &builder, &branch_index); in XLA_TEST_P() 264 Conditional(branch_index, branches_p, operands); in XLA_TEST_P() 415 XlaOp branch_index; in XLA_TEST_P() local 417 CreateR0Parameter<int32>(bi, 0, "pred", &builder, &branch_index); in XLA_TEST_P() 436 Conditional(branch_index, branches_p, in XLA_TEST_P()
|
/external/tensorflow/tensorflow/core/ops/compat/ops_history_v1/ |
D | Case.pbtxt | 4 name: "branch_index"
|
/external/tensorflow/tensorflow/core/kernels/ |
D | functional_ops.cc | 235 const Tensor& branch_index = ctx->input(0); in ComputeAsync() local 236 OP_REQUIRES_ASYNC(ctx, TensorShapeUtils::IsScalar(branch_index.shape()), in ComputeAsync() 239 int32 branch = branch_index.scalar<int32>()(); in ComputeAsync()
|
/external/tensorflow/tensorflow/compiler/xla/client/ |
D | xla_builder.h | 575 XlaOp Conditional(XlaOp branch_index, 995 XlaOp branch_index, 999 XlaOp branch_index, 1045 XlaOp branch_index, 1873 XlaOp Conditional(XlaOp branch_index,
|
D | xla_builder.cc | 1853 XlaOp branch_index, in Conditional() argument 1857 TF_ASSIGN_OR_RETURN(const xla::Shape* shape, GetShapePtr(branch_index)); in Conditional() 1864 return ConditionalImpl(branch_index, branch_computations, branch_operands); in Conditional() 1869 XlaOp branch_index, in ConditionalImpl() argument 1876 GetShapePtr(branch_index)); in ConditionalImpl() 1896 std::vector<XlaOp> operands(1, branch_index); in ConditionalImpl() 3477 XlaOp Conditional(const XlaOp branch_index, in Conditional() argument 3480 return branch_index.builder()->Conditional(branch_index, branch_computations, in Conditional()
|
/external/tensorflow/tensorflow/compiler/xla/service/cpu/ |
D | ir_emitter.cc | 2777 auto branch_index = conditional->operand(0); in HandleConditional() local 2779 TF_RET_CHECK(ShapeUtil::IsScalar(branch_index->shape()) && in HandleConditional() 2780 (branch_index->shape().element_type() == PRED || in HandleConditional() 2781 branch_index->shape().element_type() == S32)) in HandleConditional() 2783 << ShapeUtil::HumanString(branch_index->shape()); in HandleConditional() 2797 if (branch_index->shape().element_type() == PRED) { in HandleConditional() 2804 GetIrArrayFor(branch_index).GetBasePointer(), "load_predicate_value"); in HandleConditional() 2839 GetIrArrayFor(branch_index).GetBasePointer(), "load_branch_index_value"); in HandleConditional()
|
/external/tensorflow/tensorflow/compiler/xla/g3doc/ |
D | operation_semantics.md | 637 <b> `Conditional(branch_index, branch_computations, branch_operands)` </b> 641 | `branch_index` | `XlaOp` | Scalar of type `S32` | 648 Executes `branch_computations[branch_index]`, and returns the result. If 649 `branch_index` is an `S32` which is < 0 or >= N, then `branch_computations[N-1]` 657 the value of `branch_index`.
|
/external/tensorflow/tensorflow/tools/api/golden/v2/ |
D | tensorflow.pbtxt | 1053 …argspec: "args=[\'branch_index\', \'branch_fns\', \'default\', \'name\'], varargs=None, keywords=N…
|
/external/tensorflow/tensorflow/tools/api/golden/v1/ |
D | tensorflow.pbtxt | 2325 …argspec: "args=[\'branch_index\', \'branch_fns\', \'default\', \'name\'], varargs=None, keywords=N…
|