1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include <memory>
#include <tuple>

#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_function.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/graph_function.h"
#include "tensorflow/c/eager/unified_api_testutil.h"
#include "tensorflow/c/experimental/ops/resource_variable_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/cc/experimental/libtf/function.h"
#include "tensorflow/cc/experimental/libtf/object.h"
#include "tensorflow/cc/experimental/libtf/value.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/statusor.h"
#include "tensorflow/core/platform/test.h"
32
33 namespace tf {
34 namespace libtf {
35 using tensorflow::AbstractContext;
36 using tensorflow::AbstractContextPtr;
37 using tensorflow::AbstractFunctionPtr;
38 using tensorflow::AbstractTensorHandle;
39 using tensorflow::DT_FLOAT;
40 using tensorflow::PartialTensorShape;
41 using tensorflow::Status;
42 using tensorflow::TF_StatusPtr;
43
44 class VariableTest
45 : public ::testing::TestWithParam<std::tuple<const char*, bool>> {
46 public:
47 template <class T, TF_DataType datatype>
CreateScalarTensor(T val)48 impl::TaggedValueTensor CreateScalarTensor(T val) {
49 AbstractTensorHandle* raw = nullptr;
50 Status s = TestScalarTensorHandle<T, datatype>(ctx_.get(), val, &raw);
51 CHECK_EQ(tensorflow::errors::OK, s.code()) << s.error_message();
52 return impl::TaggedValueTensor(raw, /*add_ref=*/false);
53 }
54
UseTfrt()55 bool UseTfrt() { return std::get<1>(GetParam()); }
56
57 AbstractContextPtr ctx_;
58
59 protected:
SetUp()60 void SetUp() override {
61 // Set the tracing impl, GraphDef vs MLIR.
62 TF_StatusPtr status(TF_NewStatus());
63 TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
64 Status s = tensorflow::StatusFromTF_Status(status.get());
65 CHECK_EQ(tensorflow::errors::OK, s.code()) << s.error_message();
66
67 // Set the runtime impl, Core RT vs TFRT.
68 AbstractContext* ctx_raw = nullptr;
69 s = BuildImmediateExecutionContext(UseTfrt(), &ctx_raw);
70 CHECK_EQ(tensorflow::errors::OK, s.code()) << s.error_message();
71 ctx_.reset(ctx_raw);
72 }
73 };
74
75 template <typename T>
ExpectEquals(AbstractTensorHandle * t,T expected)76 void ExpectEquals(AbstractTensorHandle* t, T expected) {
77 TF_Tensor* result_t;
78 Status s = tensorflow::GetValue(t, &result_t);
79 ASSERT_TRUE(s.ok()) << s.error_message();
80 auto value = static_cast<T*>(TF_TensorData(result_t));
81 EXPECT_EQ(*value, expected);
82 TF_DeleteTensor(result_t);
83 }
84
// Exercises the lifecycle of a resource variable: create an uninitialized
// scalar float variable, assign 2.0f, read it back, and destroy it.
// Steps whose results feed later steps use TF_ASSERT_OK. A failure therefore
// ends the test cleanly instead of handing a null handle to the next op.
TEST_P(VariableTest, CreateAssignReadDestroy) {
  // Create uninitialized variable.
  tensorflow::AbstractTensorHandlePtr var;
  {
    AbstractTensorHandle* var_ptr = nullptr;
    // Rank-0 (scalar) shape built from an empty dimension list.
    PartialTensorShape scalar_shape;
    TF_ASSERT_OK(
        PartialTensorShape::MakePartialShape<int32>({}, 0, &scalar_shape));
    Status s = tensorflow::ops::VarHandleOp(ctx_.get(), &var_ptr, DT_FLOAT,
                                            scalar_shape);
    var.reset(var_ptr);  // Take ownership before checking `s`.
    TF_ASSERT_OK(s);
  }
  // Assign a value.
  auto x = CreateScalarTensor<float, TF_FLOAT>(2.0f);
  TF_ASSERT_OK(
      tensorflow::ops::AssignVariableOp(ctx_.get(), var.get(), x.get()));
  // Read variable.
  tensorflow::AbstractTensorHandlePtr value;
  {
    AbstractTensorHandle* value_ptr = nullptr;
    Status s = tensorflow::ops::ReadVariableOp(ctx_.get(), var.get(),
                                               &value_ptr, DT_FLOAT);
    value.reset(value_ptr);  // Take ownership before checking `s`.
    TF_ASSERT_OK(s);
  }
  ExpectEquals(value.get(), 2.0f);
  // Destroy variable.
  TF_EXPECT_OK(tensorflow::ops::DestroyResourceOp(ctx_.get(), var.get()));
}
113
114 INSTANTIATE_TEST_SUITE_P(TF2CAPI, VariableTest,
115 ::testing::Combine(::testing::Values("graphdef",
116 "mlir"),
117 ::testing::Values(false, true)));
118
119 } // namespace libtf
120 } // namespace tf
121