Home
last modified time | relevance | path

Searched refs:gather_op (Results 1 – 9 of 9) sorted by relevance

/external/tensorflow/tensorflow/lite/toco/graph_transformations/
Dunpartition_embedding_lookup.cc132 for (auto* gather_op : gather_ops) { in Run() local
133 auto* op = GetOpWithOutput(*model, gather_op->inputs[1]); in Run()
134 CHECK(op) << "Source of " << gather_op->inputs[1] << " not found"; in Run()
139 LogName(*op), LogName(*gather_op)); in Run()
150 LogName(*op), LogName(*gather_op)); in Run()
178 for (const auto& gather_op : gather_ops) { in Run() local
179 gather_params_concat_op->inputs.push_back(gather_op->inputs[0]); in Run()
238 for (auto* gather_op : gather_ops) { in Run() local
239 DeleteOpAndArrays(model, gather_op); in Run()
Dresolve_gather_attributes.cc31 auto* gather_op = model->operators[op_index].get(); in Run() local
32 if (gather_op->type != OperatorType::kGather) in Run()
34 auto* op = static_cast<GatherOperator*>(gather_op); in Run()
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/transforms/
Dlegalize_hlo.cc1013 mhlo::GatherOp gather_op, ArrayRef<Value> args, in matchAndRewrite() argument
1015 Value operand = gather_op.operand(); in matchAndRewrite()
1016 Value start_indices = gather_op.start_indices(); in matchAndRewrite()
1021 ShapedType result_type = gather_op.getResult().getType().cast<ShapedType>(); in matchAndRewrite()
1029 gather_op.dimension_numbers().index_vector_dim().getInt(); in matchAndRewrite()
1030 if (failed(NormalizeIndexVector(gather_op, start_indices, in matchAndRewrite()
1038 auto start_index_map = gather_op.dimension_numbers().start_index_map(); in matchAndRewrite()
1040 gather_op.dimension_numbers().collapsed_slice_dims(); in matchAndRewrite()
1047 gather_op, "unsupported start index map and/or collapsed slice dims"); in matchAndRewrite()
1052 auto slice_sizes = gather_op.slice_sizes(); in matchAndRewrite()
[all …]
/external/tensorflow/tensorflow/python/kernel_tests/
Dgather_nd_op_test.py398 gather_op = array_ops.gather_nd(t_params, t_indices)
401 self.evaluate(gather_op)
404 self.evaluate(gather_op)
/external/tensorflow/tensorflow/lite/toco/
Dexport_tensorflow.cc1233 tensorflow::NodeDef* gather_op = tensorflow_graph->add_node(); in ConvertGatherOperator() local
1234 gather_op->set_op("GatherV2"); in ConvertGatherOperator()
1235 gather_op->set_name(src_op.outputs[0]); in ConvertGatherOperator()
1236 *gather_op->add_input() = src_op.inputs[0]; in ConvertGatherOperator()
1237 *gather_op->add_input() = src_op.inputs[1]; in ConvertGatherOperator()
1242 *gather_op->add_input() = src_op.inputs[2]; in ConvertGatherOperator()
1247 AvailableArrayName(model, gather_op->name() + "/axis"); in ConvertGatherOperator()
1250 *gather_op->add_input() = gather_axis; in ConvertGatherOperator()
1253 (*gather_op->mutable_attr())["Tindices"].set_type(DT_INT32); in ConvertGatherOperator()
1254 (*gather_op->mutable_attr())["Taxis"].set_type(DT_INT32); in ConvertGatherOperator()
[all …]
/external/tensorflow/tensorflow/compiler/tf2xla/kernels/
DBUILD51 "gather_op.cc",
/external/tensorflow/tensorflow/core/kernels/
Dreduction_gpu_kernels.cu.h959 GatherOp gather_op(extent_x, extent_y, extent_z, false);
963 gatherIterType gather_iter(counting_iter, gather_op);
DBUILD967 ":gather_op",
1081 name = "gather_op",
1082 prefix = "gather_op",
1945 ":gather_op",
5813 "gather_op.cc",
/external/llvm-project/mlir/test/Conversion/VectorToLLVM/
Dvector-to-llvm.mlir1092 func @gather_op(%arg0: memref<?xf32>, %arg1: vector<3xi32>, %arg2: vector<3xi1>, %arg3: vector<3xf3…
1097 // CHECK-LABEL: func @gather_op