1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/c/eager/c_api_unified_experimental.h"
16 #include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
17 #include "tensorflow/c/eager/unified_api_testutil.h"
18 #include "tensorflow/c/tf_status_helper.h"
19 #include "tensorflow/core/framework/tensor_shape.h"
20 #include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
21 #include "tensorflow/core/platform/errors.h"
22 #include "tensorflow/core/platform/test.h"
23
24 namespace tensorflow {
25 namespace {
26 class UnifiedAPI
27 : public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
28 protected:
SetUp()29 void SetUp() override {
30 TF_StatusPtr status(TF_NewStatus());
31 TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
32 Status s = StatusFromTF_Status(status.get());
33 CHECK_EQ(errors::OK, s.code()) << s.error_message();
34 }
35
36 public:
UseMlir() const37 bool UseMlir() const { return strcmp(std::get<0>(GetParam()), "mlir") == 0; }
UseFunction() const38 bool UseFunction() const { return std::get<2>(GetParam()); }
39 };
40
41 // Checks that inputs[0] is a scalar.
TestScalarShape(AbstractContext * ctx,absl::Span<AbstractTensorHandle * const> inputs,absl::Span<AbstractTensorHandle * > outputs)42 Status TestScalarShape(AbstractContext* ctx,
43 absl::Span<AbstractTensorHandle* const> inputs,
44 absl::Span<AbstractTensorHandle*> outputs) {
45 PartialTensorShape shape;
46 TF_RETURN_IF_ERROR(inputs[0]->Shape(&shape));
47 if (shape.dims() != 0) {
48 return errors::InvalidArgument(
49 "Tensor expected to have scalar shape found rank: ", shape.dims());
50 }
51 return Status::OK();
52 }
53
TEST_P(UnifiedAPI, TestTensorShapeScalar) {
  // TODO(b/173074167): Remove this.
  if (UseMlir() && UseFunction()) {
    GTEST_SKIP() << "MlirTensor::Shape is not implemented yet.";
  }

  // Immediate-execution context. The second test parameter is the tfrt flag.
  AbstractContextPtr ctx;
  {
    AbstractContext* raw_ctx = nullptr;
    const Status build_status =
        BuildImmediateExecutionContext(std::get<1>(GetParam()), &raw_ctx);
    ASSERT_EQ(errors::OK, build_status.code()) << build_status.error_message();
    ctx.reset(raw_ctx);
  }

  // Scalar input holding 2.0f.
  AbstractTensorHandlePtr x;
  {
    AbstractTensorHandle* raw_x = nullptr;
    const Status make_status = TestScalarTensorHandle(ctx.get(), 2.0f, &raw_x);
    ASSERT_EQ(errors::OK, make_status.code()) << make_status.error_message();
    x.reset(raw_x);
  }

  const Status run_status = RunModel(TestScalarShape, ctx.get(),
                                     /*inputs=*/{x.get()},
                                     /*outputs=*/{},
                                     /*use_function=*/UseFunction());
  ASSERT_EQ(errors::OK, run_status.code()) << run_status.error_message();
}
82
83 // Checks that inputs[0] is a matrix with shape 2x4.
TestTensorShape2x4(AbstractContext * ctx,absl::Span<AbstractTensorHandle * const> inputs,absl::Span<AbstractTensorHandle * > outputs)84 Status TestTensorShape2x4(AbstractContext* ctx,
85 absl::Span<AbstractTensorHandle* const> inputs,
86 absl::Span<AbstractTensorHandle*> outputs) {
87 PartialTensorShape shape;
88 TF_RETURN_IF_ERROR(inputs[0]->Shape(&shape));
89 if (shape.dims() != 2) {
90 return errors::InvalidArgument(
91 "Tensor expected to have rank 2 found rank: ", shape.dims());
92 }
93 int64 dim_sizes[] = {2, 4};
94 for (int i = 0; i < shape.dims(); i++) {
95 if (shape.dim_size(i) != dim_sizes[i]) {
96 return errors::InvalidArgument("Dim ", i, " expected to be of size ",
97 dim_sizes[i],
98 " found: ", shape.dim_size(i));
99 }
100 }
101 return Status::OK();
102 }
103
TEST_P(UnifiedAPI, TestTensorShape2x4) {
  // TODO(b/173074167): Remove this.
  if (UseMlir() && UseFunction()) {
    GTEST_SKIP() << "MlirTensor::Shape is not implemented yet.";
  }

  // Immediate-execution context. The second test parameter is the tfrt flag.
  AbstractContextPtr ctx;
  {
    AbstractContext* raw_ctx = nullptr;
    const Status build_status =
        BuildImmediateExecutionContext(std::get<1>(GetParam()), &raw_ctx);
    ASSERT_EQ(errors::OK, build_status.code()) << build_status.error_message();
    ctx.reset(raw_ctx);
  }

  // 2x4 float input; every element is zero.
  AbstractTensorHandlePtr x;
  {
    AbstractTensorHandle* raw_x = nullptr;
    float zeros[2 * 4] = {};
    int64_t shape_dims[] = {2, 4};
    const int num_dims = 2;
    const Status make_status = TestTensorHandleWithDimsFloat(
        ctx.get(), zeros, shape_dims, num_dims, &raw_x);
    ASSERT_EQ(errors::OK, make_status.code()) << make_status.error_message();
    x.reset(raw_x);
  }

  const Status run_status = RunModel(TestTensorShape2x4, ctx.get(),
                                     /*inputs=*/{x.get()},
                                     /*outputs=*/{},
                                     /*use_function=*/UseFunction());
  ASSERT_EQ(errors::OK, run_status.code()) << run_status.error_message();
}
135
// Tracing-only test. A function parameter added with a default-constructed
// PartialTensorShape must report an unknown rank through Shape().
TEST_P(UnifiedAPI, TestUnknownShapeTracing) {
  if (!UseFunction()) {
    GTEST_SKIP() << "Tracing only test.";
  }
  if (UseMlir()) {
    // TODO(b/173074167): Remove this.
    GTEST_SKIP() << "MlirTensor::Shape is not implemented yet.";
  }
  AbstractContextPtr ctx(BuildFunction("test_fn"));
  // Fail the test cleanly rather than crash if BuildFunction produced no
  // context, or a context that is not a TracingContext.
  ASSERT_NE(nullptr, ctx.get()) << "BuildFunction returned a null context.";
  auto* tracing_ctx = dyn_cast<tracing::TracingContext>(ctx.get());
  ASSERT_NE(nullptr, tracing_ctx)
      << "BuildFunction did not return a TracingContext.";

  AbstractTensorHandlePtr x;
  {
    tracing::TracingTensorHandle* x_raw = nullptr;
    PartialTensorShape shape;
    Status s = tracing_ctx->AddParameter(DT_FLOAT, shape, &x_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    x.reset(x_raw);
  }

  PartialTensorShape shape;
  Status s = x->Shape(&shape);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
  ASSERT_TRUE(shape.unknown_rank());
}
160
// Tracing-only test. A function parameter added with the partial shape
// [2, -1] must report that same shape through Shape().
TEST_P(UnifiedAPI, TestPartialShapeTracing) {
  if (!UseFunction()) {
    GTEST_SKIP() << "Tracing only test.";
  }
  if (UseMlir()) {
    // TODO(b/173074167): Remove this.
    GTEST_SKIP() << "MlirTensor::Shape is not implemented yet.";
  }
  AbstractContextPtr ctx(BuildFunction("test_fn"));
  // Fail the test cleanly rather than crash if BuildFunction produced no
  // context, or a context that is not a TracingContext.
  ASSERT_NE(nullptr, ctx.get()) << "BuildFunction returned a null context.";
  auto* tracing_ctx = dyn_cast<tracing::TracingContext>(ctx.get());
  ASSERT_NE(nullptr, tracing_ctx)
      << "BuildFunction did not return a TracingContext.";

  AbstractTensorHandlePtr x;
  {
    tracing::TracingTensorHandle* x_raw = nullptr;
    PartialTensorShape shape;
    // -1 marks the second dimension as unknown.
    int64 dim_sizes[] = {2, -1};
    Status s = PartialTensorShape::MakePartialShape(dim_sizes, 2, &shape);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    s = tracing_ctx->AddParameter(DT_FLOAT, shape, &x_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    x.reset(x_raw);
  }

  PartialTensorShape shape;
  Status s = x->Shape(&shape);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
  ASSERT_FALSE(shape.unknown_rank());
  // Check the rank before indexing, so a wrong rank is reported as a test
  // failure instead of dim_size() being asked for a nonexistent dimension.
  ASSERT_EQ(2, shape.dims());

  ASSERT_EQ(2, shape.dim_size(0));
  ASSERT_EQ(-1, shape.dim_size(1));
}
190
// Instantiates every TEST_P above over the cross product of:
//   tracing implementation: "graphdef" or "mlir",
//   tfrt flag,
//   use_function flag.
// tfrt=true is only instantiated when PLATFORM_GOOGLE is defined. Presumably
// this is because TFRT is unavailable in other builds (NOTE(review): confirm).
#ifdef PLATFORM_GOOGLE
INSTANTIATE_TEST_SUITE_P(
    UnifiedCppAPI, UnifiedAPI,
    ::testing::Combine(::testing::Values("graphdef", "mlir"),
                       /*tfrt*/ ::testing::Values(true, false),
                       /*use_function*/ ::testing::Values(true, false)));
#else
// Non-Google builds: only the non-tfrt variants.
INSTANTIATE_TEST_SUITE_P(
    UnifiedCppAPI, UnifiedAPI,
    ::testing::Combine(::testing::Values("graphdef", "mlir"),
                       /*tfrt*/ ::testing::Values(false),
                       /*use_function*/ ::testing::Values(true, false)));
#endif
204 } // namespace
205 } // namespace tensorflow
206