• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_KERNELS_OPS_TESTUTIL_H_
17 #define TENSORFLOW_KERNELS_OPS_TESTUTIL_H_
18 
19 #include <memory>
20 #include <vector>
21 
22 #include "tensorflow/core/common_runtime/device.h"
23 #include "tensorflow/core/common_runtime/device_factory.h"
24 #include "tensorflow/core/framework/allocator.h"
25 #include "tensorflow/core/framework/device_base.h"
26 #include "tensorflow/core/framework/graph.pb.h"
27 #include "tensorflow/core/framework/node_def.pb.h"
28 #include "tensorflow/core/framework/op_kernel.h"
29 #include "tensorflow/core/framework/resource_mgr.h"
30 #include "tensorflow/core/framework/tensor.h"
31 #include "tensorflow/core/framework/tensor_testutil.h"
32 #include "tensorflow/core/framework/types.h"
33 #include "tensorflow/core/framework/types.pb.h"
34 #include "tensorflow/core/lib/core/status.h"
35 #include "tensorflow/core/lib/core/status_test_util.h"
36 #include "tensorflow/core/lib/gtl/array_slice.h"
37 #include "tensorflow/core/lib/gtl/inlined_vector.h"
38 #include "tensorflow/core/lib/gtl/stl_util.h"
39 #include "tensorflow/core/platform/env.h"
40 #include "tensorflow/core/platform/logging.h"
41 #include "tensorflow/core/platform/macros.h"
42 #include "tensorflow/core/platform/mutex.h"
43 #include "tensorflow/core/platform/test.h"
44 #include "tensorflow/core/platform/types.h"
45 #include "tensorflow/core/public/session_options.h"
46 #include "tensorflow/core/public/version.h"
47 #include "tensorflow/core/util/tensor_slice_reader_cache.h"
48 
49 namespace tensorflow {
50 namespace test {
51 
SetOutputAttrs(OpKernelContext::Params * params,std::vector<AllocatorAttributes> * attrs)52 inline void SetOutputAttrs(OpKernelContext::Params* params,
53                            std::vector<AllocatorAttributes>* attrs) {
54   attrs->clear();
55   for (int index = 0; index < params->op_kernel->num_outputs(); index++) {
56     AllocatorAttributes attr;
57     const bool on_host =
58         (params->op_kernel->output_memory_types()[index] == HOST_MEMORY);
59     attr.set_on_host(on_host);
60     attrs->push_back(attr);
61   }
62   params->output_attr_array = gtl::vector_as_array(attrs);
63 }
64 
65 }  // namespace test
66 
67 // Helpful functions to test operators.
68 //
69 // This class will eventually be replaced / heavily modified
70 // to use the BrainClient interface.
71 class OpsTestBase : public ::testing::Test {
72  public:
OpsTestBase()73   OpsTestBase()
74       : device_(DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0")),
75         device_type_(DEVICE_CPU) {
76     CHECK(device_.get()) << "Could not create CPU device";
77     allocator_ = device_->GetAllocator(AllocatorAttributes());
78   }
79 
~OpsTestBase()80   ~OpsTestBase() override {
81     gtl::STLDeleteElements(&tensors_);
82     gtl::STLDeleteElements(&managed_outputs_);
83     context_.reset(nullptr);
84     params_.reset(nullptr);
85   }
86 
87   // Allow kernel unit tests to run on GPU
88   void SetDevice(const DeviceType& device_type, std::unique_ptr<Device> device);
89 
set_node_def(const NodeDef & node_def)90   void set_node_def(const NodeDef& node_def) { node_def_.CopyFrom(node_def); }
91 
92   // Clients can manipulate the underlying NodeDef via this accessor.
node_def()93   NodeDef* node_def() { return &node_def_; }
94 
95   // Initializes an operator that takes in 'input_types' as input
96   // and output types as output.
97   //
98   // Returns the status of initialization.
InitOp()99   Status InitOp() { return InitOpWithGraphVersion(TF_GRAPH_DEF_VERSION); }
100 
101   // Only use this directly if you have a deprecated op that you need to test.
InitOpWithGraphVersion(int graph_def_version)102   Status InitOpWithGraphVersion(int graph_def_version) {
103     Status status;
104     kernel_ = CreateOpKernel(device_type_, device_.get(), allocator(),
105                              node_def_, graph_def_version, &status);
106     if (kernel_ != nullptr) input_types_ = kernel_->input_types();
107     return status;
108   }
109 
110   // Adds an input for every element described by the shape.
111   // 'input_mapping' maps an index (0...NumElements(shape)) to a
112   // value.
113   //
114   // TODO(vrv): Replace with something like a BrainClient Feed.
115   template <typename T>
AddInput(const TensorShape & shape,std::function<T (int)> input_mapping)116   void AddInput(const TensorShape& shape, std::function<T(int)> input_mapping) {
117     test::FillFn(AddInput(DataTypeToEnum<T>::v(), shape), input_mapping);
118   }
119 
120   // Like AddInput but takes in an explicit arrayslice of data.
121   template <typename T>
AddInputFromArray(const TensorShape & shape,const gtl::ArraySlice<T> & data)122   void AddInputFromArray(const TensorShape& shape,
123                          const gtl::ArraySlice<T>& data) {
124     test::FillValues<T>(AddInput(DataTypeToEnum<T>::v(), shape), data);
125   }
126 
127   // Convenience function to add an input and populate it with the elements from
128   // an initializer list converting the types as needed.
129   template <typename T, typename SrcType>
AddInputFromList(const TensorShape & shape,std::initializer_list<SrcType> data)130   void AddInputFromList(const TensorShape& shape,
131                         std::initializer_list<SrcType> data) {
132     test::FillValues<T>(AddInput(DataTypeToEnum<T>::v(), shape), data);
133   }
134 
135   // Adds a Resource type as input. If <container> is empty, uses the default
136   // container name.
137   template <typename T>
AddResourceInput(const string & container,const string & name,T * resource)138   void AddResourceInput(const string& container, const string& name,
139                         T* resource) {
140     CHECK_GT(input_types_.size(), inputs_.size())
141         << "Adding more inputs than types; perhaps you need to call MakeOp";
142     ResourceMgr* rm = device_->resource_manager();
143     EXPECT_TRUE(
144         rm->Create(container == "" ? rm->default_container() : container, name,
145                    resource)
146             .ok());
147     TypeIndex type_index = MakeTypeIndex<T>();
148     ResourceHandle handle;
149     handle.set_device(device_->name());
150     handle.set_container(container);
151     handle.set_name(name);
152     handle.set_hash_code(type_index.hash_code());
153     handle.set_maybe_type_name(type_index.name());
154     Tensor* input = new Tensor(allocator(), DT_RESOURCE, TensorShape({}));
155     input->scalar<ResourceHandle>()() = handle;
156     tensors_.push_back(input);
157     inputs_.push_back({nullptr, input});
158   }
159 
160   // Runs an operation producing 'num_outputs' outputs.
161   //
162   // Returns the context's status after running the operation.
RunOpKernel()163   Status RunOpKernel() {
164     // Make sure the old OpKernelContext is deleted before the Params
165     // it was using.
166     context_.reset(nullptr);
167 
168     params_.reset(new OpKernelContext::Params);
169     params_.get()->device = device_.get();
170     params_.get()->frame_iter = FrameAndIter(0, 0);
171     params_.get()->inputs = &inputs_;
172     params_.get()->op_kernel = kernel_.get();
173     step_container_.reset(new ScopedStepContainer(0, [](const string&) {}));
174     params_->step_container = step_container_.get();
175     std::vector<AllocatorAttributes> attrs;
176     test::SetOutputAttrs(params_.get(), &attrs);
177     checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_wrapper;
178     params_.get()->slice_reader_cache = &slice_reader_cache_wrapper;
179     params_.get()->resource_manager = device_.get()->resource_manager();
180 
181     context_.reset(new OpKernelContext(params_.get()));
182     device_->Compute(kernel_.get(), context_.get());
183     return context_->status();
184   }
185 
186   // Returns the tensor input for 'input_index'.
187   //
188   // REQUIRES: 0 <= input_index < context_->num_inputs()
GetInput(int input_index)189   const Tensor& GetInput(int input_index) const {
190     CHECK_LT(input_index, context_->num_inputs());
191     CHECK(!IsRefType(context_->input_dtype(input_index)));
192     return context_->input(input_index);
193   }
194 
mutable_input(int input_index)195   TensorValue mutable_input(int input_index) {
196     CHECK_LT(input_index, inputs_.size());
197     return inputs_[input_index];
198   }
199   // Returns the tensor output for 'output_index'.
200   //
201   // REQUIRES: 0 <= output_index < context_->num_outputs()
202   Tensor* GetOutput(int output_index);
203 
allocator()204   Allocator* allocator() { return allocator_; }
205 
output_types()206   const DataTypeVector& output_types() const { return kernel_->output_types(); }
207 
208  private:
AddInput(DataType dtype,const TensorShape & shape)209   Tensor* AddInput(DataType dtype, const TensorShape& shape) {
210     CHECK_GT(input_types_.size(), inputs_.size())
211         << "Adding more inputs than types; perhaps you need to call MakeOp";
212     bool is_ref = IsRefType(input_types_[inputs_.size()]);
213     Tensor* input = new Tensor(allocator(), dtype, shape);
214     tensors_.push_back(input);
215     if (is_ref) {
216       CHECK_EQ(RemoveRefType(input_types_[inputs_.size()]), dtype);
217       inputs_.push_back({&lock_for_refs_, input});
218     } else {
219       CHECK_EQ(input_types_[inputs_.size()], dtype);
220       inputs_.push_back({nullptr, input});
221     }
222     return input;
223   }
224 
225  protected:
226   std::unique_ptr<Device> device_;
227   // The device allocator, or the managed_allocator_ below if running on GPU.
228   Allocator* allocator_;
229 
230   std::unique_ptr<OpKernel> kernel_;
231   std::unique_ptr<ScopedStepContainer> step_container_;
232   NodeDef node_def_;
233   DataTypeVector input_types_;
234   DeviceType device_type_;
235 
236   mutex lock_for_refs_;  // Used as the Mutex for inputs added as refs
237 
238   gtl::InlinedVector<TensorValue, 4> inputs_;
239   // Owns Tensors.
240   std::vector<Tensor*> tensors_;
241   // Copies of the outputs in unified memory (host and device accessible).
242   std::vector<Tensor*> managed_outputs_;
243 
244   std::unique_ptr<OpKernelContext::Params> params_;
245   std::unique_ptr<OpKernelContext> context_;
246   // Unified memory allocator, only used when running on GPU.
247   std::unique_ptr<Allocator> managed_allocator_;
248 
249  private:
250   TF_DISALLOW_COPY_AND_ASSIGN(OpsTestBase);
251 };
252 
253 }  // namespace tensorflow
254 
255 #endif  // TENSORFLOW_KERNELS_OPS_TESTUTIL_H_
256