• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/c/eager/abstract_context.h"
16 #include "tensorflow/c/eager/abstract_function.h"
17 #include "tensorflow/c/eager/abstract_tensor_handle.h"
18 #include "tensorflow/c/eager/graph_function.h"
19 #include "tensorflow/c/eager/unified_api_testutil.h"
20 #include "tensorflow/c/experimental/ops/resource_variable_ops.h"
21 #include "tensorflow/c/tf_status_helper.h"
22 #include "tensorflow/cc/experimental/libtf/function.h"
23 #include "tensorflow/cc/experimental/libtf/object.h"
24 #include "tensorflow/cc/experimental/libtf/value.h"
25 #include "tensorflow/core/framework/function.h"
26 #include "tensorflow/core/framework/tensor_shape.h"
27 #include "tensorflow/core/framework/types.pb.h"
28 #include "tensorflow/core/lib/core/status_test_util.h"
29 #include "tensorflow/core/platform/errors.h"
30 #include "tensorflow/core/platform/statusor.h"
31 #include "tensorflow/core/platform/test.h"
32 
33 namespace tf {
34 namespace libtf {
35 using tensorflow::AbstractContext;
36 using tensorflow::AbstractContextPtr;
37 using tensorflow::AbstractFunctionPtr;
38 using tensorflow::AbstractTensorHandle;
39 using tensorflow::DT_FLOAT;
40 using tensorflow::PartialTensorShape;
41 using tensorflow::Status;
42 using tensorflow::TF_StatusPtr;
43 
44 class VariableTest
45     : public ::testing::TestWithParam<std::tuple<const char*, bool>> {
46  public:
47   template <class T, TF_DataType datatype>
CreateScalarTensor(T val)48   impl::TaggedValueTensor CreateScalarTensor(T val) {
49     AbstractTensorHandle* raw = nullptr;
50     Status s = TestScalarTensorHandle<T, datatype>(ctx_.get(), val, &raw);
51     CHECK_EQ(tensorflow::errors::OK, s.code()) << s.error_message();
52     return impl::TaggedValueTensor(raw, /*add_ref=*/false);
53   }
54 
UseTfrt()55   bool UseTfrt() { return std::get<1>(GetParam()); }
56 
57   AbstractContextPtr ctx_;
58 
59  protected:
SetUp()60   void SetUp() override {
61     // Set the tracing impl, GraphDef vs MLIR.
62     TF_StatusPtr status(TF_NewStatus());
63     TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
64     Status s = tensorflow::StatusFromTF_Status(status.get());
65     CHECK_EQ(tensorflow::errors::OK, s.code()) << s.error_message();
66 
67     // Set the runtime impl, Core RT vs TFRT.
68     AbstractContext* ctx_raw = nullptr;
69     s = BuildImmediateExecutionContext(UseTfrt(), &ctx_raw);
70     CHECK_EQ(tensorflow::errors::OK, s.code()) << s.error_message();
71     ctx_.reset(ctx_raw);
72   }
73 };
74 
75 template <typename T>
ExpectEquals(AbstractTensorHandle * t,T expected)76 void ExpectEquals(AbstractTensorHandle* t, T expected) {
77   TF_Tensor* result_t;
78   Status s = tensorflow::GetValue(t, &result_t);
79   ASSERT_TRUE(s.ok()) << s.error_message();
80   auto value = static_cast<T*>(TF_TensorData(result_t));
81   EXPECT_EQ(*value, expected);
82   TF_DeleteTensor(result_t);
83 }
84 
// Full lifecycle of a resource variable: create a float scalar variable,
// assign 2.0f, read it back, and destroy it.
//
// Steps whose results feed later steps use TF_ASSERT_OK. If creating or
// reading the variable fails, the out-pointer stays null, and continuing would
// pass a null handle to the next op instead of reporting the failing step.
TEST_P(VariableTest, CreateAssignReadDestroy) {
  // Create an uninitialized float variable with a zero-dimension shape.
  tensorflow::AbstractTensorHandlePtr var;
  {
    AbstractTensorHandle* var_ptr = nullptr;
    PartialTensorShape scalar_shape;
    // Qualified: `int32` is declared in namespace tensorflow, not tf::libtf.
    TF_ASSERT_OK(PartialTensorShape::MakePartialShape<tensorflow::int32>(
        {}, 0, &scalar_shape));
    TF_ASSERT_OK(tensorflow::ops::VarHandleOp(ctx_.get(), &var_ptr, DT_FLOAT,
                                              scalar_shape));
    var.reset(var_ptr);
  }
  // Assign a value; the read below depends on this having succeeded.
  auto x = CreateScalarTensor<float, TF_FLOAT>(2.0f);
  TF_ASSERT_OK(
      tensorflow::ops::AssignVariableOp(ctx_.get(), var.get(), x.get()));
  // Read the variable back.
  tensorflow::AbstractTensorHandlePtr value;
  {
    AbstractTensorHandle* value_ptr = nullptr;
    TF_ASSERT_OK(tensorflow::ops::ReadVariableOp(ctx_.get(), var.get(),
                                                 &value_ptr, DT_FLOAT));
    value.reset(value_ptr);
  }
  ExpectEquals(value.get(), 2.0f);
  // Destroy the variable. Nothing runs afterwards, so a non-fatal check is
  // sufficient here.
  TF_EXPECT_OK(tensorflow::ops::DestroyResourceOp(ctx_.get(), var.get()));
}
113 
114 INSTANTIATE_TEST_SUITE_P(TF2CAPI, VariableTest,
115                          ::testing::Combine(::testing::Values("graphdef",
116                                                               "mlir"),
117                                             ::testing::Values(false, true)));
118 
119 }  // namespace libtf
120 }  // namespace tf
121