/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/unified_api_testutil.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace {
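// Test fixture parameterized over the tracing implementation ("graphdef" or
// "mlir"), whether to use TFRT, and whether to run the model as a function.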
class UnifiedAPI
    : public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
 protected:
  void SetUp() override {
    TF_StatusPtr status(TF_NewStatus());
    TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
    Status s = StatusFromTF_Status(status.get());
    CHECK_EQ(errors::OK, s.code()) << s.error_message();
  }

 public:
  bool UseMlir() const { return strcmp(std::get<0>(GetParam()), "mlir") == 0; }
  bool UseFunction() const { return std::get<2>(GetParam()); }
};

// Checks that inputs[0] is a scalar.
Status TestScalarShape(AbstractContext* ctx,
                       absl::Span<AbstractTensorHandle* const> inputs,
                       absl::Span<AbstractTensorHandle*> outputs) {
  PartialTensorShape shape;
  TF_RETURN_IF_ERROR(inputs[0]->Shape(&shape));
  if (shape.dims() != 0) {
    return errors::InvalidArgument(
        "Tensor expected to have scalar shape found rank: ", shape.dims());
  }
  return Status::OK();
}

TEST_P(UnifiedAPI, TestTensorShapeScalar) {
  if (UseFunction() && UseMlir()) {
    // TODO(b/173074167): Remove this.
    GTEST_SKIP() << "MlirTensor::Shape is not implemented yet.";
  }
  AbstractContextPtr ctx;
  {
    AbstractContext* ctx_raw = nullptr;
    Status s =
        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    ctx.reset(ctx_raw);
  }

  AbstractTensorHandlePtr x;
  {
    AbstractTensorHandle* x_raw = nullptr;
    Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    x.reset(x_raw);
  }

  Status s = RunModel(TestScalarShape, ctx.get(),
                      /*inputs=*/{x.get()},
                      /*outputs=*/{},
                      /*use_function=*/UseFunction());
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
}

// Checks that inputs[0] is a matrix with shape 2x4.
Status TestTensorShape2x4(AbstractContext* ctx,
                          absl::Span<AbstractTensorHandle* const> inputs,
                          absl::Span<AbstractTensorHandle*> outputs) {
  PartialTensorShape shape;
  TF_RETURN_IF_ERROR(inputs[0]->Shape(&shape));
  if (shape.dims() != 2) {
    return errors::InvalidArgument(
        "Tensor expected to have rank 2 found rank: ", shape.dims());
  }
  int64 dim_sizes[] = {2, 4};
  for (int i = 0; i < shape.dims(); i++) {
    if (shape.dim_size(i) != dim_sizes[i]) {
      return errors::InvalidArgument("Dim ", i, " expected to be of size ",
                                     dim_sizes[i],
                                     " found: ", shape.dim_size(i));
    }
  }
  return Status::OK();
}

TEST_P(UnifiedAPI, TestTensorShape2x4) {
  if (UseFunction() && UseMlir()) {
    // TODO(b/173074167): Remove this.
    GTEST_SKIP() << "MlirTensor::Shape is not implemented yet.";
  }
  AbstractContextPtr ctx;
  {
    AbstractContext* ctx_raw = nullptr;
    Status s =
        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    ctx.reset(ctx_raw);
  }

  AbstractTensorHandlePtr x;
  {
    AbstractTensorHandle* x_raw = nullptr;
    float data[] = {0., 0., 0., 0., 0., 0., 0., 0};
    int64_t dim_sizes[] = {2, 4};
    Status s =
        TestTensorHandleWithDimsFloat(ctx.get(), data, dim_sizes, 2, &x_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    x.reset(x_raw);
  }

  Status s = RunModel(TestTensorShape2x4, ctx.get(),
                      /*inputs=*/{x.get()},
                      /*outputs=*/{},
                      /*use_function=*/UseFunction());
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
}

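// A function parameter added without specifying a shape should yield a handle
// whose shape reports an unknown rank.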
TEST_P(UnifiedAPI, TestUnknownShapeTracing) {
  if (!UseFunction()) {
    GTEST_SKIP() << "Tracing only test.";
  }
  if (UseMlir()) {
    // TODO(b/173074167): Remove this.
    GTEST_SKIP() << "MlirTensor::Shape is not implemented yet.";
  }
  AbstractContextPtr ctx(BuildFunction("test_fn"));
  AbstractTensorHandlePtr x;
  {
    tracing::TracingTensorHandle* x_raw = nullptr;
    PartialTensorShape shape;
    Status s = dyn_cast<tracing::TracingContext>(ctx.get())->AddParameter(
        DT_FLOAT, shape, &x_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    x.reset(x_raw);
  }

  PartialTensorShape shape;
  Status s = x->Shape(&shape);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
  ASSERT_TRUE(shape.unknown_rank());
}

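// A function parameter added with a partially known shape (2 x unknown) should
// report that same partial shape back.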
TEST_P(UnifiedAPI, TestPartialShapeTracing) {
  if (!UseFunction()) {
    GTEST_SKIP() << "Tracing only test.";
  }
  if (UseMlir()) {
    GTEST_SKIP() << "MlirTensor::Shape is not implemented yet.";
  }
  AbstractContextPtr ctx(BuildFunction("test_fn"));
  AbstractTensorHandlePtr x;
  {
    tracing::TracingTensorHandle* x_raw = nullptr;
    PartialTensorShape shape;
    int64 dim_sizes[] = {2, -1};
    Status s = PartialTensorShape::MakePartialShape(dim_sizes, 2, &shape);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    s = dyn_cast<tracing::TracingContext>(ctx.get())->AddParameter(
        DT_FLOAT, shape, &x_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    x.reset(x_raw);
  }

  PartialTensorShape shape;
  Status s = x->Shape(&shape);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
  ASSERT_FALSE(shape.unknown_rank());

  ASSERT_EQ(2, shape.dim_size(0));
  ASSERT_EQ(-1, shape.dim_size(1));
}

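// TFRT-backed contexts are only exercised in PLATFORM_GOOGLE builds.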
#ifdef PLATFORM_GOOGLE
INSTANTIATE_TEST_SUITE_P(
    UnifiedCppAPI, UnifiedAPI,
    ::testing::Combine(::testing::Values("graphdef", "mlir"),
                       /*tfrt*/ ::testing::Values(true, false),
                       /*use_function*/ ::testing::Values(true, false)));
#else
INSTANTIATE_TEST_SUITE_P(
    UnifiedCppAPI, UnifiedAPI,
    ::testing::Combine(::testing::Values("graphdef", "mlir"),
                       /*tfrt*/ ::testing::Values(false),
                       /*use_function*/ ::testing::Values(true, false)));
#endif
}  // namespace
}  // namespace tensorflow