Home
last modified time | relevance | path

Searched refs:shape_info (Results 1 – 23 of 23) sorted by relevance

/external/tensorflow/tensorflow/compiler/jit/
Dshape_inference_test.cc51 GraphShapeInfo shape_info; in TEST() local
53 /*fnlib_def=*/nullptr, &shape_info)); in TEST()
61 TF_EXPECT_OK(ShapeAnnotationsMatch(*graph, shape_info, expected)); in TEST()
115 GraphShapeInfo shape_info; in TEST() local
117 &shape_info)); in TEST()
122 TF_EXPECT_OK(ShapeAnnotationsMatch(graph, shape_info, expected)); in TEST()
161 GraphShapeInfo shape_info; in TEST() local
163 &shape_info)); in TEST()
164 auto iter = shape_info.find("sink"); in TEST()
165 EXPECT_NE(iter, shape_info.end()); in TEST()
Dtest_util.cc24 const Graph& graph, const GraphShapeInfo& shape_info, in ShapeAnnotationsMatch() argument
27 auto sit = shape_info.find(node->name()); in ShapeAnnotationsMatch()
28 TF_RET_CHECK(sit != shape_info.end()) in ShapeAnnotationsMatch()
Dshape_inference.cc210 GraphShapeInfo* shape_info) { in StoreOutputShapes() argument
215 auto& outputs = (*shape_info)[node->name()]; in StoreOutputShapes()
248 GraphShapeInfo* shape_info) { in InferShapes() argument
270 return StoreOutputShapes(*graph, shape_refiner, shape_info); in InferShapes()
Dshape_inference.h45 GraphShapeInfo* shape_info);
Dtest_util.h41 const Graph& graph, const GraphShapeInfo& shape_info,
Dencapsulate_util.cc326 GraphShapeInfo shape_info; in PerformStaticShapeInferenceBeforeEncapsulation() local
328 InferShapes(g, arg_shapes, /*fnlib_def=*/nullptr, &shape_info)); in PerformStaticShapeInferenceBeforeEncapsulation()
332 for (auto iter : shape_info) { in PerformStaticShapeInferenceBeforeEncapsulation()
/external/tensorflow/tensorflow/python/kernel_tests/linalg/
Dlinear_operator_block_diag_test.py76 shape_info = linear_operator_test_util.OperatorShapesInfo
78 shape_info((0, 0)),
79 shape_info((1, 1)),
80 shape_info((1, 3, 3)),
81 shape_info((5, 5), blocks=[(2, 2), (3, 3)]),
82 shape_info((3, 7, 7), blocks=[(1, 2, 2), (3, 2, 2), (1, 3, 3)]),
83 shape_info((2, 1, 5, 5), blocks=[(2, 1, 2, 2), (1, 3, 3)]),
91 self, shape_info, dtype, use_placeholder, argument
93 shape = list(shape_info.shape)
95 shape_info.__dict__["blocks"] if "blocks" in shape_info.__dict__
[all …]
Dlinear_operator_block_lower_triangular_test.py84 shape_info = linear_operator_test_util.OperatorShapesInfo
86 shape_info((0, 0)),
87 shape_info((1, 1)),
88 shape_info((1, 3, 3)),
89 shape_info((5, 5), blocks=[[(2, 2)], [(3, 2), (3, 3)]]),
90 shape_info((3, 7, 7),
93 shape_info((2, 4, 6, 6),
98 self, shape_info, dtype, use_placeholder, argument
102 shape_info.__dict__["blocks"] if "blocks" in shape_info.__dict__
103 else [[list(shape_info.shape)]])
[all …]
Dlinear_operator_householder_test.py41 shape_info = linear_operator_test_util.OperatorShapesInfo
43 shape_info((1, 1)),
44 shape_info((1, 3, 3)),
45 shape_info((3, 4, 4)),
46 shape_info((2, 1, 4, 4))]
Dlinear_operator_permutation_test.py43 shape_info = linear_operator_test_util.OperatorShapesInfo
45 shape_info((1, 1)),
46 shape_info((1, 3, 3)),
47 shape_info((3, 4, 4)),
48 shape_info((2, 1, 4, 4))]
Dlinear_operator_tridiag_test.py94 shape_info = linear_operator_test_util.OperatorShapesInfo
97 shape_info((3, 3)),
98 shape_info((1, 6, 6)),
99 shape_info((3, 4, 4)),
100 shape_info((2, 1, 3, 3))
Dlinear_operator_circulant_test.py118 shape_info, argument
122 shape = shape_info.shape
180 shape_info, argument
184 shape = shape_info.shape
265 shape_info, argument
270 shape = shape_info.shape
418 shape_info = linear_operator_test_util.OperatorShapesInfo
421 shape_info((0, 0)),
422 shape_info((1, 1)),
423 shape_info((1, 6, 6)),
[all …]
Dlinear_operator_toeplitz_test.py74 shape_info = linear_operator_test_util.OperatorShapesInfo
77 shape_info((1, 1)),
78 shape_info((1, 6, 6)),
79 shape_info((3, 4, 4)),
80 shape_info((2, 1, 3, 3))
Dlinear_operator_low_rank_update_test.py53 shape_info = linear_operator_test_util.OperatorShapesInfo
59 shape_info((0, 0)),
60 shape_info((1, 1)),
61 shape_info((1, 3, 3)),
62 shape_info((3, 4, 4)),
63 shape_info((2, 1, 4, 4))]
74 def operator_and_matrix(self, shape_info, dtype, use_placeholder, argument
77 shape = list(shape_info.shape)
Dlinear_operator_kronecker_test.py95 shape_info = linear_operator_test_util.OperatorShapesInfo
97 shape_info((1, 1), factors=[(1, 1), (1, 1)]),
98 shape_info((8, 8), factors=[(2, 2), (2, 2), (2, 2)]),
99 shape_info((12, 12), factors=[(2, 2), (3, 3), (2, 2)]),
100 shape_info((1, 3, 3), factors=[(1, 1), (1, 3, 3)]),
101 shape_info((3, 6, 6), factors=[(3, 1, 1), (1, 2, 2), (1, 3, 3)]),
/external/tensorflow/tensorflow/python/ops/ragged/
Dragged_to_tensor_op_test.py660 for shape_info in ['known', 'unknown_dims', 'unknown_rank']:
662 rt_val = self.wrap_in_placeholder(rt_val, shape_info)
663 default_val = self.wrap_in_placeholder(default_value, shape_info)
664 shape_val = self.wrap_in_placeholder(shape, shape_info)
690 def wrap_in_placeholder(self, arg, shape_info): argument
708 if shape_info == 'known':
712 self.wrap_in_placeholder(arg.flat_values, shape_info))
717 if shape_info == 'unknown_rank':
719 if shape_info == 'unknown_dims':
721 raise AssertionError('Unexpected shape_info %r' % shape_info)
/external/tensorflow/tensorflow/compiler/mlir/xla/experimental/conv_emitter/
Dconv_emitter.cc70 ShapeInfo shape_info; in GetShapeInfo() local
85 shape_info.nchw_dimensions.push_back(shape.dimensions(dim)); in GetShapeInfo()
89 shape_info.physical_dimensions.push_back(shape.dimensions(dim)); in GetShapeInfo()
99 shape_info.affine_map = mlir::AffineMap::get( in GetShapeInfo()
103 shape_info.element_type = [&] { in GetShapeInfo()
115 return shape_info; in GetShapeInfo()
/external/tensorflow/tensorflow/core/tpu/kernels/
Dtpu_compile_op_common.cc458 FunctionLibraryRuntime* flr, GraphShapeInfo* shape_info) { in RunShapeInferenceOnComputation() argument
484 shape_info); in RunShapeInferenceOnComputation()
513 GraphShapeInfo shape_info; in OptimizeGraph() local
515 metadata, arg_shapes, graph->get(), flr, &shape_info)); in OptimizeGraph()
519 ConvertGraphShapeInfoToShapeMap(**graph, shape_info, &shape_map); in OptimizeGraph()
526 GraphShapeInfo shape_info; in OptimizeGraph() local
528 metadata, arg_shapes, graph->get(), flr, &shape_info)); in OptimizeGraph()
530 ConvertGraphShapeInfoToShapeMap(**graph, shape_info, &shape_map); in OptimizeGraph()
Dtpu_compile_op_common.h112 FunctionLibraryRuntime* flr, GraphShapeInfo* shape_info);
/external/tensorflow/tensorflow/core/tpu/graph_rewrite/
Ddistributed_tpu_rewrite_pass.h294 const GraphShapeInfo& shape_info, const Node& node,
557 const GraphShapeInfo& shape_info,
Ddistributed_tpu_rewrite_pass.cc1679 static Status GetEdgeShape(const GraphShapeInfo& shape_info, const Edge& edge, in GetEdgeShape() argument
1681 auto it = shape_info.find(edge.src()->name()); in GetEdgeShape()
1682 if (it == shape_info.end()) { in GetEdgeShape()
1693 const GraphShapeInfo& shape_info, const Node& node, in GetArgAndRetvalShapes() argument
1712 TF_RETURN_IF_ERROR(GetEdgeShape(shape_info, *input_edges[edge_pos], &info)); in GetArgAndRetvalShapes()
1755 TF_RETURN_IF_ERROR(GetEdgeShape(shape_info, *input_edges[edge_pos], &info)); in GetArgAndRetvalShapes()
1767 TF_RETURN_IF_ERROR(GetEdgeShape(shape_info, *input_edges[edge_pos], &info)); in GetArgAndRetvalShapes()
1787 TF_RETURN_IF_ERROR(GetEdgeShape(shape_info, *input_edges[edge_pos], &info)); in GetArgAndRetvalShapes()
1796 auto it = shape_info.find(node.name()); in GetArgAndRetvalShapes()
1798 if (it != shape_info.end()) { in GetArgAndRetvalShapes()
[all …]
/external/tensorflow/tensorflow/python/ops/linalg/
Dlinear_operator_test_util.py728 for dtype, use_placeholder, shape_info in itertools.product(
734 shape_info.shape, dtype, use_placeholder)])
747 shape_info,
759 use_placeholder, shape_info, dtype)))
/external/tensorflow/tensorflow/compiler/tf2xla/
Dxla_compiler.cc641 GraphShapeInfo shape_info; in GetGraph() local
643 flib_runtime_->GetFunctionLibraryDefinition(), &shape_info) in GetGraph()
647 for (const auto& node_shape_info : shape_info) { in GetGraph()
664 GraphShapeInfo shape_info; in GetGraph() local
666 flib_runtime_->GetFunctionLibraryDefinition(), &shape_info) in GetGraph()
670 for (const auto& node_shape_info : shape_info) { in GetGraph()