Home
last modified time | relevance | path

Searched refs:index_shape (Results 1 – 19 of 19) sorted by relevance

/external/tensorflow/tensorflow/compiler/xla/client/lib/
Dslicing.cc161 TF_ASSIGN_OR_RETURN(Shape index_shape, builder->GetShape(index)); in TorchGather()
163 if (ShapeUtil::ElementHasBitWidth(index_shape, 64) && in TorchGather()
166 index_shape.set_element_type(U32); in TorchGather()
168 if (index_shape.rank() == 1) { in TorchGather()
175 for (int64 i = 0; i < index_shape.rank(); ++i) { in TorchGather()
187 sizes.push_back(index_shape.dimensions(i)); in TorchGather()
191 Iota(builder, ShapeUtil::MakeShape(index_shape.element_type(), sizes), in TorchGather()
203 ShapeUtil::AppendMajorDimension(1, &index_shape); in TorchGather()
209 to_concat.push_back(Reshape(index, index_shape.dimensions())); in TorchGather()
211 to_concat.push_back(Iota(builder, index_shape, i)); in TorchGather()
[all …]
/external/tensorflow/tensorflow/core/tpu/kernels/xla/
Dget_item_op.cc37 const TensorShape& index_shape = ctx->InputShape(1); in Compile() local
41 OP_REQUIRES(ctx, index_shape.dims() == 1 && index_shape.dim_size(0) == 1, in Compile()
/external/tensorflow/tensorflow/lite/kernels/
Dembedding_lookup_test.cc41 BaseEmbeddingLookupOpModel(std::initializer_list<int> index_shape, in BaseEmbeddingLookupOpModel() argument
49 BuildInterpreter({index_shape, weight_shape}); in BaseEmbeddingLookupOpModel()
90 HybridEmbeddingLookupOpModel(std::initializer_list<int> index_shape, in HybridEmbeddingLookupOpModel() argument
93 : BaseEmbeddingLookupOpModel(index_shape, weight_shape, type) {} in HybridEmbeddingLookupOpModel()
/external/tensorflow/tensorflow/compiler/tf2xla/kernels/
Ddynamic_slice_ops.cc51 const TensorShape index_shape = ctx->InputShape("indices"); in Compile() local
56 TensorShapeUtils::IsVector(index_shape) && in Compile()
57 index_shape.num_elements() == rank, in Compile()
Dsplit_op.cc106 const TensorShape index_shape = ctx->InputShape(2); in Compile() local
108 OP_REQUIRES(ctx, index_shape.num_elements() == 1, in Compile()
/external/tensorflow/tensorflow/compiler/xla/service/
Dgather_expander.cc114 const Shape& index_shape = index_vector->shape(); in ExpandIndexVectorIntoOperandSpace() local
120 LiteralUtil::CreateFromDimensions(index_shape.element_type(), {0}))); in ExpandIndexVectorIntoOperandSpace()
125 LiteralUtil::CreateFromDimensions(index_shape.element_type(), {1}))); in ExpandIndexVectorIntoOperandSpace()
Dscatter_expander.cc136 const Shape& index_shape = index_vector->shape(); in ExpandIndexVectorIntoOperandSpace() local
141 LiteralUtil::CreateFromDimensions(index_shape.element_type(), {0}))); in ExpandIndexVectorIntoOperandSpace()
146 LiteralUtil::CreateFromDimensions(index_shape.element_type(), {1}))); in ExpandIndexVectorIntoOperandSpace()
Dshape_inference.cc2539 for (const Shape& index_shape : start_index_shapes) { in InferDynamicSliceShape() local
2540 if (!ShapeUtil::Compatible(first_index_shape, index_shape)) { in InferDynamicSliceShape()
2545 ShapeUtil::HumanString(index_shape)); in InferDynamicSliceShape()
2646 for (const Shape& index_shape : start_index_shapes) { in InferDynamicUpdateSliceShape() local
2647 if (!ShapeUtil::Compatible(first_index_shape, index_shape)) { in InferDynamicUpdateSliceShape()
2652 ShapeUtil::HumanString(index_shape)); in InferDynamicUpdateSliceShape()
Dhlo_alias_analysis_test.cc1068 Shape index_shape = ShapeUtil::MakeShape(S32, {}); in TEST_F() local
1075 HloInstruction::CreateParameter(2, index_shape, "param2")); in TEST_F()
Dalgebraic_simplifier.cc2530 const Shape& index_shape = gather->operand(1)->shape(); in HandleGather() local
2534 index_shape.rank() && in HandleGather()
2549 index_shape.element_type()); in HandleGather()
/external/tensorflow/tensorflow/core/kernels/batching_util/
Dbatch_resource_base.cc676 TensorShape index_shape({0, 3}); in ProcessBatch() local
680 task.context->allocate_output(num_input_edges, index_shape, &output), in ProcessBatch()
708 const TensorShape index_shape({batch.num_tasks(), 3}); in EmitIndexTensor() local
711 context->allocate_output(output_index, index_shape, &index)); in EmitIndexTensor()
/external/tensorflow/tensorflow/lite/kernels/parse_example/
Dparse_example.cc495 TfLiteIntArray* index_shape = TfLiteIntArrayCreate(2); in FastParseExampleLite() local
496 index_shape->data[0] = total_num_features; in FastParseExampleLite()
497 index_shape->data[1] = 2; in FastParseExampleLite()
498 context->ResizeTensor(context, indices, index_shape); in FastParseExampleLite()
/external/tensorflow/tensorflow/compiler/xla/service/spmd/
Dspmd_partitioner_util.cc1021 auto index_shape = ShapeUtil::ChangeElementType(valid_slice->shape(), S32); in ExchangeHaloAndGetValidData() local
1022 auto iota = b->AddInstruction(HloInstruction::CreateIota(index_shape, dim)); in ExchangeHaloAndGetValidData()
1025 index_shape, offset_on_padded_shape, {})); in ExchangeHaloAndGetValidData()
1027 HloInstruction::CreateBinary(index_shape, HloOpcode::kAdd, iota, in ExchangeHaloAndGetValidData()
1029 auto mask_shape = ShapeUtil::ChangeElementType(index_shape, PRED); in ExchangeHaloAndGetValidData()
1034 index_shape, in ExchangeHaloAndGetValidData()
1046 index_shape, in ExchangeHaloAndGetValidData()
Dspmd_partitioner.cc489 auto index_shape = ShapeUtil::ChangeElementType(shape, S32); in PadWithValue() local
490 auto mask_shape = ShapeUtil::ChangeElementType(index_shape, PRED); in PadWithValue()
494 state_.b->AddInstruction(HloInstruction::CreateIota(index_shape, dim)); in PadWithValue()
496 HloInstruction::CreateBroadcast(index_shape, start_index, {})); in PadWithValue()
499 index_shape, HloOpcode::kAdd, iota, broadcast_start_index)); in PadWithValue()
505 index_shape.dimensions(dim) * sharding.tile_assignment().dim(dim) - in PadWithValue()
511 HloInstruction::CreateBroadcast(index_shape, limit, {})); in PadWithValue()
/external/tensorflow/tensorflow/compiler/tf2xla/
Dxla_compiler_test.cc1618 xla::Shape index_shape; in TEST_F() local
1619 TF_ASSERT_OK(TensorShapeToXLAShape(DT_INT32, TensorShape{}, &index_shape)); in TEST_F()
1620 std::vector<xla::Shape> shapes{tensor_list_element_shape, index_shape}; in TEST_F()
/external/tensorflow/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/
Dhlo_ops.cc3001 auto index_shape = index_type.getShape().vec(); in fold() local
3002 index_shape.push_back(1); in fold()
3004 RankedTensorType::get(index_shape, index_type.getElementType()); in fold()
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/ir/
Dtf_ops_a_m.cc1979 ArrayRef<int64_t> index_shape = index_ty.getShape(); in Verify() local
1980 if (failed(mlir::verifyCompatibleShape(index_shape, in Verify()
/external/tensorflow/tensorflow/lite/delegates/nnapi/
Dnnapi_delegate_test.cc4752 BaseEmbeddingLookupOpModel(std::initializer_list<int> index_shape, in BaseEmbeddingLookupOpModel() argument
4759 BuildInterpreterWithNNAPI({index_shape, weight_shape}); in BaseEmbeddingLookupOpModel()
/external/tensorflow/tensorflow/python/ops/parallel_for/
Dpfor.py3931 index_shape = array_ops.shape(index)
3936 [index_shape, array_ops.shape(values)[1:]], axis=0)