Home
last modified time | relevance | path

Searched refs:arg_shape (Results 1 – 13 of 13) sorted by relevance

/external/tensorflow/tensorflow/compiler/jit/
Dshape_inference.cc83 const InferredShape& arg_shape = it->second; in PropagateShapes() local
87 if (arg_shape.handle_type != DT_INVALID) { in PropagateShapes()
90 arg_shape.handle_shape, &handle)); in PropagateShapes()
95 {handle, arg_shape.handle_type}}); in PropagateShapes()
100 context->MakeShapeFromPartialTensorShape(arg_shape.shape, &handle)); in PropagateShapes()
/external/tensorflow/tensorflow/compiler/mlir/tfrt/python_tests/
Dpython_test_attrs.cc88 const auto& arg_shape = arg_type.getShape(); in verifyRegionArgAttribute() local
90 if (!arg_type.isDynamicDim(i) && arg_shape[i] != attr_shape[i]) { in verifyRegionArgAttribute()
/external/tensorflow/tensorflow/compiler/xla/service/
Dshape_inference.cc372 const Shape* arg_shape = nullptr; in InferConcatOpShape() local
376 if (!arg_shape) { in InferConcatOpShape()
377 arg_shape = shape; in InferConcatOpShape()
378 element_type = arg_shape->element_type(); in InferConcatOpShape()
381 if (arg_shape->rank() != shape->rank()) { in InferConcatOpShape()
385 arg_shape->rank(), ShapeUtil::HumanString(*arg_shape), shape->rank(), in InferConcatOpShape()
388 if (!ShapeUtil::SameElementTypeIgnoringFpPrecision(*arg_shape, *shape)) { in InferConcatOpShape()
391 PrimitiveType_Name(arg_shape->element_type()), in InferConcatOpShape()
394 for (int64_t dimension_number = 0; dimension_number < arg_shape->rank(); in InferConcatOpShape()
396 if (arg_shape->dimensions(dimension_number) != in InferConcatOpShape()
[all …]
Dshape_inference_test.cc1263 Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); in TEST_F() local
1265 {&arg_shape, &f32_}, in TEST_F()
1274 Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); in TEST_F() local
1276 ShapeInference::InferReduceShape({&arg_shape, &f32_}, in TEST_F()
1285 Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); in TEST_F() local
1287 ShapeInference::InferReduceShape({&arg_shape, &f32_}, in TEST_F()
1296 Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); in TEST_F() local
1298 {&arg_shape, &f32_}, in TEST_F()
1663 Shape arg_shape = ShapeUtil::MakeShape(F32, {5, 3}); in TEST_F() local
1666 arg_shape, val_shape, /*dimension=*/0); in TEST_F()
[all …]
Dhlo_evaluator.cc786 const auto& arg_shape = arg_literals[i]->shape(); in Evaluate() local
788 arg_shape)) { in Evaluate()
793 ShapeUtil::HumanStringWithLayout(arg_shape)); in Evaluate()
3671 const Shape& arg_shape = input_args[0]->shape(); in GenerateReduceOutputElement() local
3672 absl::Span<const int64_t> arg_dimensions = arg_shape.dimensions(); in GenerateReduceOutputElement()
3692 arg_shape, base, arg_dim_counts, arg_dim_steps, reduction_step)); in GenerateReduceOutputElement()
3700 arg_shape, base, arg_dim_counts, arg_dim_steps, in GenerateReduceOutputElement()
3740 const Shape& arg_shape = input_args[0]->shape(); in HandleReduce() local
3747 absl::Span<const int64_t> arg_dimensions = arg_shape.dimensions(); in HandleReduce()
/external/tensorflow/tensorflow/compiler/mlir/tensorflow/utils/
Dcompile_mlir_util.cc88 TensorShape arg_shape; in GetTensorShapeFromXlaArgument() local
90 XLAShapeToTensorShape(std::get<xla::Shape>(arg.shape), &arg_shape)); in GetTensorShapeFromXlaArgument()
91 return arg_shape; in GetTensorShapeFromXlaArgument()
681 for (const auto& arg_shape : arg_shapes) in CompileSerializedMlirToXlaHlo() local
682 tensor_or_resource_shapes.push_back({arg_shape}); in CompileSerializedMlirToXlaHlo()
715 TF_ASSIGN_OR_RETURN(TensorShape arg_shape, in RewriteWithArgs()
717 auto resource_shape = arg_shape.dim_sizes(); in RewriteWithArgs()
772 TF_ASSIGN_OR_RETURN(TensorShape arg_shape, in CompileGraphSetup()
774 arg_shapes.push_back({arg_shape, in CompileGraphSetup()
/external/tensorflow/tensorflow/core/transforms/constant_folding/tests/
Dshape_materialization.mlir21 // CHECK: Const [%ArgWithShape.ctl] name("arg_shape") {{.*}} -> (tensor<2xi32>)
22 …%ArgShape, %ctl_10 = Shape(%ArgWithShape) name("arg_shape") {T = i32, out_type = i32} : (tensor<2x…
/external/tensorflow/tensorflow/dtensor/mlir/
Dspmd_expansion.cc224 llvm::ArrayRef<int64_t> arg_shape = ranked_type.getShape(); in UpdateFunctionArgsUsingLayout() local
226 arg_layout->LocalShapeFromGlobalShape(arg_shape); in UpdateFunctionArgsUsingLayout()
/external/tensorflow/tensorflow/compiler/xla/client/
Dclient.cc234 for (const auto& arg_shape : argument_shapes) { in Compile() local
235 *request.add_input_shape_with_layout() = arg_shape.ToProto(); in Compile()
/external/tensorflow/tensorflow/core/tpu/graph_rewrite/
Ddistributed_tpu_rewrite_pass.cc1981 InferredShape& arg_shape = in GetArgAndRetvalShapes() local
1985 arg_shape.shape = TensorShape(); // Variables are always scalars. in GetArgAndRetvalShapes()
1986 arg_shape.handle_shape = info->handle_shape; in GetArgAndRetvalShapes()
1987 arg_shape.handle_type = info->handle_type; in GetArgAndRetvalShapes()
1988 TF_RET_CHECK(arg_shape.handle_type != DT_INVALID) in GetArgAndRetvalShapes()
2648 const InferredShape& arg_shape = arg_shapes[i]; in BuildCompileNode() local
2651 TF_RET_CHECK(arg_shape.handle_type != DT_INVALID) << i; in BuildCompileNode()
2652 arg->set_dtype(arg_shape.handle_type); in BuildCompileNode()
2653 arg_shape.handle_shape.AsProto(arg->mutable_shape()); in BuildCompileNode()
2658 arg_shape.shape.AsProto(arg->mutable_shape()); in BuildCompileNode()
[all …]
/external/tensorflow/tensorflow/compiler/mlir/xla/
Dmlir_hlo_to_hlo.cc2328 xla::Shape& arg_shape = arg_shapes->back(); in SetEntryTupleShapesAndLeafReplication() local
2330 options_.layout_preference_fn ? options_.layout_preference_fn(arg_shape) in SetEntryTupleShapesAndLeafReplication()
2338 arg_shape, /*use_fast_memory=*/false, in SetEntryTupleShapesAndLeafReplication()
2340 : arg_shape; in SetEntryTupleShapesAndLeafReplication()
2345 arg_shape = std::move(arg_shape_status.ValueOrDie()); in SetEntryTupleShapesAndLeafReplication()
2348 for (int i = 0, e = xla::ShapeUtil::GetLeafCount(arg_shape); i < e; ++i) in SetEntryTupleShapesAndLeafReplication()
/external/tensorflow/tensorflow/compiler/mlir/tools/kernel_gen/tests/
Dbuffer_reuse.mlir448 %arg_shape : memref<?xindex>,
468 %result = memref.reshape %flat_result(%arg_shape)
/external/tensorflow/tensorflow/compiler/tf2xla/
Dxla_compiler_test.cc1700 xla::Shape arg_shape = xla::ShapeUtil::MakeTupleShape(shapes); in TEST_F() local
1701 args[0].shape = arg_shape; in TEST_F()