• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_TESTLIB_H_
17 #define TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_TESTLIB_H_
18 
#include <array>
#include <cstddef>
#include <cstdint>
#include <memory>
#include <vector>

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/parallel_device/parallel_device.h"
#include "tensorflow/c/eager/parallel_device/parallel_device_lib.h"
#include "tensorflow/core/platform/test.h"
28 
29 namespace tensorflow {
30 namespace parallel_device {
31 
32 // A helper for performing common operations on variables. A much more
33 // restricted stand-in for tf.Variable in Python.
34 class Variable {
35  public:
36   // Construct a Variable from a resource-dtype TFE_TensorHandle and an
37   // indication of the dtype of the variable's value.
38   //
39   // Note that creating this resource-dtype handle can fail, so `Create` is a
40   // separate static method which returns a status.
Variable(TFE_TensorHandle * handle,TF_DataType type)41   Variable(TFE_TensorHandle* handle, TF_DataType type)
42       : handle_(handle), type_(type) {}
43 
44   // Helper for constructing a resource handle and wrapping it in a `Variable`
45   // object.
46   static Variable* Create(TFE_Context* context, TF_DataType type,
47                           const int64_t* dims, const int num_dims,
48                           const char* device, TF_Status* status);
49   // Dereferences the backing buffer for the variable. Note that since this can
50   // fail (it runs operations), it must be called explicitly and the resulting
51   // `status` checked.
52   void Destroy(TFE_Context* context, TF_Status* status);
53 
54   // Reads from the variable.
55   TensorHandlePtr Read(TFE_Context* context, TF_Status* status);
56   // Assigns a new value to the variable.
57   void Assign(TFE_Context* context, TFE_TensorHandle* value, TF_Status* status);
58   // Adds `value` to the existing value of the variable.
59   void AssignAdd(TFE_Context* context, TFE_TensorHandle* value,
60                  TF_Status* status);
61 
62  private:
63   // Helper for running any single-argument assignment ops (Assign, AssignAdd,
64   // AssignSub, ...).
65   void GeneralAssignment(const char* op_name, TFE_Context* context,
66                          TFE_TensorHandle* value, TF_Status* status);
67 
68   // The a handle for the resource-dtype tensor pointing to the variable's
69   // buffer.
70   TFE_TensorHandle* handle_;
71   // The dtype of the variable's buffer (input dtype for assignments, output
72   // dtype of read operations).
73   TF_DataType type_;
74 };
75 
76 // Creates a TFE_TensorHandle with value `v`.
77 TensorHandlePtr FloatTensorHandle(float v, TF_Status* status);
78 
79 // Creates a rank-one TFE_TensorHandle with value `v`.
80 TensorHandlePtr VectorFloatTensorHandle(const std::vector<float>& v,
81                                         TF_Status* status);
82 
83 // Helper to un-pack `num_replicas` TFE_TensorHandles from one parallel handle.
84 template <std::size_t num_replicas>
85 void ExtractPerDeviceValues(
86     TFE_Context* context, TFE_TensorHandle* input,
87     std::array<TensorHandlePtr, num_replicas>* components, TF_Status* status);
88 
89 // Helper to pack `num_replicas` TFE_TensorHandles into one parallel handle.
90 template <std::size_t num_replicas>
91 TensorHandlePtr CreatePerDeviceValues(
92     TFE_Context* context,
93     const std::array<TFE_TensorHandle*, num_replicas>& components,
94     const char* device, TF_Status* status);
95 
96 TensorHandlePtr Multiply(TFE_Context* context, TFE_TensorHandle* first,
97                          TFE_TensorHandle* second, TF_Status* status);
98 
99 // Assert that `handle` is equal to `expected_value`.
100 template <typename value_type>
101 void ExpectScalarEq(TFE_TensorHandle* handle, value_type expected_value);
102 
103 template <std::size_t num_devices>
104 void RegisterParallelDevice(
105     TFE_Context* context, const char* device_name,
106     const std::array<const char*, num_devices>& underlying_devices,
107     TF_Status* status);
108 
109 // Create and modify a variable placed on a parallel device which composes
110 // `first_device` and `second_device`.
111 void BasicTestsForTwoDevices(TFE_Context* context, const char* first_device,
112                              const char* second_device);
113 
114 // Implementations of templated functions ******************************
115 
116 template <std::size_t num_replicas>
CreatePerDeviceValues(TFE_Context * context,const std::array<TFE_TensorHandle *,num_replicas> & components,const char * device,TF_Status * status)117 TensorHandlePtr CreatePerDeviceValues(
118     TFE_Context* context,
119     const std::array<TFE_TensorHandle*, num_replicas>& components,
120     const char* device, TF_Status* status) {
121   std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> op(
122       TFE_NewOp(context, "TPUReplicatedInput", status), TFE_DeleteOp);
123   if (TF_GetCode(status) != TF_OK) return nullptr;
124   TFE_OpSetAttrInt(op.get(), "N", num_replicas);
125   for (int i = 0; i < num_replicas; ++i) {
126     TFE_OpAddInput(op.get(), components[i], status);
127     if (TF_GetCode(status) != TF_OK) return nullptr;
128   }
129   TFE_OpSetDevice(op.get(), device, status);
130   if (TF_GetCode(status) != TF_OK) return nullptr;
131 
132   TFE_TensorHandle* result_handle;
133   int num_retvals = 1;
134   TFE_Execute(op.get(), &result_handle, &num_retvals, status);
135   if (TF_GetCode(status) != TF_OK) return nullptr;
136   return TensorHandlePtr(result_handle);
137 }
138 
139 template <typename value_type>
ExpectScalarEq(TFE_TensorHandle * handle,value_type expected_value)140 void ExpectScalarEq(TFE_TensorHandle* handle, value_type expected_value) {
141   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
142       TF_NewStatus(), TF_DeleteStatus);
143   std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> actual_value(
144       TFE_TensorHandleResolve(handle, status.get()), TF_DeleteTensor);
145   ASSERT_EQ(TF_GetCode(status.get()), TF_OK) << TF_Message(status.get());
146   ASSERT_EQ(TF_TensorType(actual_value.get()),
147             static_cast<TF_DataType>(DataTypeToEnum<value_type>().value));
148   EXPECT_EQ(expected_value,
149             *static_cast<value_type*>(TF_TensorData(actual_value.get())));
150 }
151 
152 template <std::size_t num_devices>
RegisterParallelDevice(TFE_Context * context,const char * device_name,const std::array<const char *,num_devices> & underlying_devices,TF_Status * status)153 void RegisterParallelDevice(
154     TFE_Context* context, const char* device_name,
155     const std::array<const char*, num_devices>& underlying_devices,
156     TF_Status* status) {
157   TFE_CustomDevice device;
158   void* device_info;
159   tensorflow::parallel_device::AllocateParallelDevice(
160       device_name, underlying_devices.data(), underlying_devices.size(),
161       &device, &device_info);
162   TFE_RegisterCustomDevice(context, device, device_name, device_info, status);
163 }
164 
165 }  // namespace parallel_device
166 }  // namespace tensorflow
167 
168 #endif  // TENSORFLOW_C_EAGER_PARALLEL_DEVICE_PARALLEL_DEVICE_TESTLIB_H_
169