1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_TESTLIB_H_
17 #define TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_TESTLIB_H_
18
19 #include <array>
20
21 #include "tensorflow/c/c_api.h"
22 #include "tensorflow/c/c_api_experimental.h"
23 #include "tensorflow/c/eager/c_api.h"
24 #include "tensorflow/c/eager/c_api_experimental.h"
25 #include "tensorflow/c/eager/parallel_device/parallel_device.h"
26 #include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"
27 #include "tensorflow/core/platform/test.h"
28
29 namespace tensorflow {
30 namespace parallel_device {
31
32 // A helper for performing common operations on variables. A much more
33 // restricted stand-in for tf.Variable in Python.
34 class Variable {
35 public:
36 // Construct a Variable from a resource-dtype TFE_TensorHandle and an
37 // indication of the dtype of the variable's value.
38 //
39 // Note that creating this resource-dtype handle can fail, so `Create` is a
40 // separate static method which returns a status.
Variable(TFE_TensorHandle * handle,TF_DataType type)41 Variable(TFE_TensorHandle* handle, TF_DataType type)
42 : handle_(handle), type_(type) {}
43
44 // Helper for constructing a resource handle and wrapping it in a `Variable`
45 // object.
46 static Variable* Create(TFE_Context* context, TF_DataType type,
47 const int64_t* dims, const int num_dims,
48 const char* device, TF_Status* status);
49 // Dereferences the backing buffer for the variable. Note that since this can
50 // fail (it runs operations), it must be called explicitly and the resulting
51 // `status` checked.
52 void Destroy(TFE_Context* context, TF_Status* status);
53
54 // Reads from the variable.
55 TensorHandlePtr Read(TFE_Context* context, TF_Status* status);
56 // Assigns a new value to the variable.
57 void Assign(TFE_Context* context, TFE_TensorHandle* value, TF_Status* status);
58 // Adds `value` to the existing value of the variable.
59 void AssignAdd(TFE_Context* context, TFE_TensorHandle* value,
60 TF_Status* status);
61
62 private:
63 // Helper for running any single-argument assignment ops (Assign, AssignAdd,
64 // AssignSub, ...).
65 void GeneralAssignment(const char* op_name, TFE_Context* context,
66 TFE_TensorHandle* value, TF_Status* status);
67
68 // The a handle for the resource-dtype tensor pointing to the variable's
69 // buffer.
70 TFE_TensorHandle* handle_;
71 // The dtype of the variable's buffer (input dtype for assignments, output
72 // dtype of read operations).
73 TF_DataType type_;
74 };
75
76 // Creates a TFE_TensorHandle with value `v`.
77 TensorHandlePtr FloatTensorHandle(float v, TF_Status* status);
78
79 // Creates a rank-one TFE_TensorHandle with value `v`.
80 TensorHandlePtr VectorFloatTensorHandle(const std::vector<float>& v,
81 TF_Status* status);
82
83 // Helper to un-pack `num_replicas` TFE_TensorHandles from one parallel handle.
84 template <std::size_t num_replicas>
85 void ExtractPerDeviceValues(
86 TFE_Context* context, TFE_TensorHandle* input,
87 std::array<TensorHandlePtr, num_replicas>* components, TF_Status* status);
88
89 // Helper to pack `num_replicas` TFE_TensorHandles into one parallel handle.
90 template <std::size_t num_replicas>
91 TensorHandlePtr CreatePerDeviceValues(
92 TFE_Context* context,
93 const std::array<TFE_TensorHandle*, num_replicas>& components,
94 const char* device, TF_Status* status);
95
96 TensorHandlePtr Multiply(TFE_Context* context, TFE_TensorHandle* first,
97 TFE_TensorHandle* second, TF_Status* status);
98
99 // Assert that `handle` is equal to `expected_value`.
100 template <typename value_type>
101 void ExpectScalarEq(TFE_TensorHandle* handle, value_type expected_value);
102
103 template <std::size_t num_devices>
104 void RegisterParallelDevice(
105 TFE_Context* context, const char* device_name,
106 const std::array<const char*, num_devices>& underlying_devices,
107 TF_Status* status);
108
109 // Create and modify a variable placed on a parallel device which composes
110 // `first_device` and `second_device`.
111 void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
112 const char* second_device);
113
114 // Implementations of templated functions ******************************
115
116 template <std::size_t num_replicas>
CreatePerDeviceValues(TFE_Context * context,const std::array<TFE_TensorHandle *,num_replicas> & components,const char * device,TF_Status * status)117 TensorHandlePtr CreatePerDeviceValues(
118 TFE_Context* context,
119 const std::array<TFE_TensorHandle*, num_replicas>& components,
120 const char* device, TF_Status* status) {
121 std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
122 TFE_NewOp(context, "TPUReplicatedInput", status), TFE_DeleteOp);
123 if (TF_GetCode(status) != TF_OK) return nullptr;
124 TFE_OpSetAttrInt(op.get(), "N", num_replicas);
125 for (int i = 0; i < num_replicas; ++i) {
126 TFE_OpAddInput(op.get(), components[i], status);
127 if (TF_GetCode(status) != TF_OK) return nullptr;
128 }
129 TFE_OpSetDevice(op.get(), device, status);
130 if (TF_GetCode(status) != TF_OK) return nullptr;
131
132 TFE_TensorHandle* result_handle;
133 int num_retvals = 1;
134 TFE_Execute(op.get(), &result_handle, &num_retvals, status);
135 if (TF_GetCode(status) != TF_OK) return nullptr;
136 return TensorHandlePtr(result_handle);
137 }
138
139 template <typename value_type>
ExpectScalarEq(TFE_TensorHandle * handle,value_type expected_value)140 void ExpectScalarEq(TFE_TensorHandle* handle, value_type expected_value) {
141 std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
142 TF_NewStatus(), TF_DeleteStatus);
143 std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> actual_value(
144 TFE_TensorHandleResolve(handle, status.get()), TF_DeleteTensor);
145 ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
146 ASSERT_EQ(TF_TensorType(actual_value.get()),
147 static_cast<TF_DataType>(DataTypeToEnum<value_type>().value));
148 EXPECT_EQ(expected_value,
149 *static_cast<value_type*>(TF_TensorData(actual_value.get())));
150 }
151
152 template <std::size_t num_devices>
RegisterParallelDevice(TFE_Context * context,const char * device_name,const std::array<const char *,num_devices> & underlying_devices,TF_Status * status)153 void RegisterParallelDevice(
154 TFE_Context* context, const char* device_name,
155 const std::array<const char*, num_devices>& underlying_devices,
156 TF_Status* status) {
157 TFE_CustomDevice device;
158 void* device_info;
159 tensorflow::parallel_device::AllocateParallelDevice(
160 device_name, underlying_devices.data(), underlying_devices.size(),
161 &device, &device_info);
162 TFE_RegisterCustomDevice(context, device, device_name, device_info, status);
163 }
164
165 } // namespace parallel_device
166 } // namespace tensorflow
167
168 #endif // TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_TESTLIB_H_
169