/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_OPS_TESTUTIL_H_
#define TENSORFLOW_CORE_KERNELS_OPS_TESTUTIL_H_

#include <memory>
#include <vector>

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/device_factory.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/stl_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/public/session_options.h"
#include "tensorflow/core/public/version.h"
#include "tensorflow/core/util/tensor_slice_reader_cache.h"

namespace tensorflow {
namespace test {

// Fills 'attrs' with one AllocatorAttributes entry per output of the kernel
// in 'params', marking outputs that live in host memory, and points
// params->output_attr_array at the result. 'attrs' must outlive any use of
// params->output_attr_array. Called by OpsTestBase::RunOpKernel() below.
inline void SetOutputAttrs(OpKernelContext::Params* params,
                           std::vector<AllocatorAttributes>* attrs) {
  attrs->clear();
  for (int index = 0; index < params->op_kernel->num_outputs(); index++) {
    AllocatorAttributes attr;
    const bool on_host =
        (params->op_kernel->output_memory_types()[index] == HOST_MEMORY);
    attr.set_on_host(on_host);
    attrs->push_back(attr);
  }
  params->output_attr_array = gtl::vector_as_array(attrs);
}

}  // namespace test

// A test fixture with helpful functions for testing operators; a usage
// sketch follows the class definition below.
//
// This class will eventually be replaced / heavily modified
// to use the BrainClient interface.
class OpsTestBase : public ::testing::Test {
 public:
  OpsTestBase()
      : device_(DeviceFactory::NewDevice("CPU", {}, "/job:a/replica:0/task:0")),
        device_type_(DEVICE_CPU) {
    CHECK(device_.get()) << "Could not create CPU device";
    allocator_ = device_->GetAllocator(AllocatorAttributes());
  }

  ~OpsTestBase() override {
    gtl::STLDeleteElements(&tensors_);
    gtl::STLDeleteElements(&managed_outputs_);
    context_.reset(nullptr);
    params_.reset(nullptr);
  }

  // Allows kernel unit tests to run on a device other than the default
  // CPU (e.g. a GPU).
  void SetDevice(const DeviceType& device_type, std::unique_ptr<Device> device);

  void set_node_def(const NodeDef& node_def) { node_def_.CopyFrom(node_def); }

  // Clients can manipulate the underlying NodeDef via this accessor.
  NodeDef* node_def() { return &node_def_; }

  // Initializes the kernel described by node_def() on the test device,
  // deriving the kernel's input and output types from the op registration.
  //
  // Returns the status of initialization.
  Status InitOp() { return InitOpWithGraphVersion(TF_GRAPH_DEF_VERSION); }

  // Only use this directly if you have a deprecated op that you need to test
  // against an older graph version.
  Status InitOpWithGraphVersion(int graph_def_version) {
    Status status;
    kernel_ = CreateOpKernel(device_type_, device_.get(), allocator(),
                             node_def_, graph_def_version, &status);
    if (kernel_ != nullptr) input_types_ = kernel_->input_types();
    return status;
  }
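
  // A sketch of such a call ('deprecated_at' is a hypothetical GraphDef
  // version at which the op under test was deprecated; any lower version
  // still accepts the op):
  //
  //   TF_ASSERT_OK(InitOpWithGraphVersion(deprecated_at - 1));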

  // Adds an input tensor with the given shape; 'input_mapping' maps each
  // flat element index (0 ... shape.num_elements() - 1) to its value.
  //
  // TODO(vrv): Replace with something like a BrainClient Feed.
  template <typename T>
  void AddInput(const TensorShape& shape, std::function<T(int)> input_mapping) {
    test::FillFn(AddInput(DataTypeToEnum<T>::v(), shape), input_mapping);
  }
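
  // For example, a sketch filling a 2x2 float input with 0, 1, 2, 3:
  //
  //   AddInput<float>(TensorShape({2, 2}),
  //                   [](int i) { return static_cast<float>(i); });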

  // Like AddInput, but takes an explicit ArraySlice of data.
  template <typename T>
  void AddInputFromArray(const TensorShape& shape,
                         const gtl::ArraySlice<T>& data) {
    test::FillValues<T>(AddInput(DataTypeToEnum<T>::v(), shape), data);
  }

  // Convenience function that adds an input and populates it with the
  // elements from an initializer list, converting the types as needed.
  template <typename T, typename SrcType>
  void AddInputFromList(const TensorShape& shape,
                        std::initializer_list<SrcType> data) {
    test::FillValues<T>(AddInput(DataTypeToEnum<T>::v(), shape), data);
  }
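
  // For example (a sketch): both calls create a two-element float vector
  // input holding {1.0f, 2.0f}; the second converts from int literals.
  //
  //   AddInputFromArray<float>(TensorShape({2}), {1.0f, 2.0f});
  //   AddInputFromList<float, int>(TensorShape({2}), {1, 2});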

  // Adds a Resource type as input. If <container> is empty, uses the default
  // container name.
  template <typename T>
  void AddResourceInput(const string& container, const string& name,
                        T* resource) {
    CHECK_GT(input_types_.size(), inputs_.size())
        << "Adding more inputs than types; perhaps you need to call MakeOp";
    ResourceMgr* rm = device_->resource_manager();
    std::string container_name =
        container == "" ? rm->default_container() : container;
    EXPECT_TRUE(rm->Create(container_name, name, resource).ok());
    TypeIndex type_index = MakeTypeIndex<T>();
    ResourceHandle handle;
    handle.set_device(device_->name());
    handle.set_container(container_name);
    handle.set_name(name);
    handle.set_hash_code(type_index.hash_code());
    handle.set_maybe_type_name(type_index.name());
    Tensor* input = new Tensor(allocator(), DT_RESOURCE, TensorShape({}));
    input->scalar<ResourceHandle>()() = handle;
    tensors_.push_back(input);
    inputs_.push_back({nullptr, input});
  }
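
  // A sketch ('MyResource' is hypothetical; resources must derive from
  // ResourceBase). The device's ResourceMgr takes over the initial reference:
  //
  //   AddResourceInput<MyResource>("", "my_resource", new MyResource());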

  // Runs the initialized kernel with the inputs added so far.
  //
  // Returns the context's status after running the operation.
  Status RunOpKernel() {
    // Make sure the old OpKernelContext is deleted before the Params
    // it was using.
    context_.reset(nullptr);

    params_.reset(new OpKernelContext::Params);
    params_.get()->device = device_.get();
    params_.get()->frame_iter = FrameAndIter(0, 0);
    params_.get()->inputs = &inputs_;
    params_.get()->op_kernel = kernel_.get();
    step_container_.reset(new ScopedStepContainer(0, [](const string&) {}));
    params_->step_container = step_container_.get();
    std::vector<AllocatorAttributes> attrs;
    test::SetOutputAttrs(params_.get(), &attrs);
    checkpoint::TensorSliceReaderCacheWrapper slice_reader_cache_wrapper;
    params_.get()->slice_reader_cache = &slice_reader_cache_wrapper;
    params_.get()->resource_manager = device_.get()->resource_manager();

    context_.reset(new OpKernelContext(params_.get()));
    device_->Compute(kernel_.get(), context_.get());
    return context_->status();
  }
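
  // For example (a sketch; assumes InitOp() succeeded and the kernel takes
  // one float scalar input):
  //
  //   AddInputFromArray<float>(TensorShape({}), {42.0f});
  //   TF_EXPECT_OK(RunOpKernel());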

  // Returns the tensor input for 'input_index'.
  //
  // REQUIRES: 0 <= input_index < context_->num_inputs()
  const Tensor& GetInput(int input_index) const {
    CHECK_LT(input_index, context_->num_inputs());
    CHECK(!IsRefType(context_->input_dtype(input_index)));
    return context_->input(input_index);
  }

  // Returns the mutable input (including its ref-lock, if any) for
  // 'input_index'.
  TensorValue mutable_input(int input_index) {
    CHECK_LT(input_index, inputs_.size());
    return inputs_[input_index];
  }

  // Returns the tensor output for 'output_index'.
  //
  // REQUIRES: 0 <= output_index < context_->num_outputs()
  Tensor* GetOutput(int output_index);

  Allocator* allocator() { return allocator_; }

  const DataTypeVector& output_types() const { return kernel_->output_types(); }

 private:
  // Allocates an input tensor of the given dtype and shape and returns it so
  // callers can fill in its values; used by the public AddInput* helpers.
  Tensor* AddInput(DataType dtype, const TensorShape& shape) {
    CHECK_GT(input_types_.size(), inputs_.size())
        << "Adding more inputs than types; perhaps you need to call MakeOp";
    bool is_ref = IsRefType(input_types_[inputs_.size()]);
    Tensor* input = new Tensor(allocator(), dtype, shape);
    tensors_.push_back(input);
    if (is_ref) {
      CHECK_EQ(RemoveRefType(input_types_[inputs_.size()]), dtype);
      inputs_.push_back({&lock_for_refs_, input});
    } else {
      CHECK_EQ(input_types_[inputs_.size()], dtype);
      inputs_.push_back({nullptr, input});
    }
    return input;
  }

 protected:
  std::unique_ptr<Device> device_;
  // The device allocator, or the managed_allocator_ below if running on GPU.
  Allocator* allocator_;

  std::unique_ptr<OpKernel> kernel_;
  std::unique_ptr<ScopedStepContainer> step_container_;
  NodeDef node_def_;
  DataTypeVector input_types_;
  DeviceType device_type_;

  mutex lock_for_refs_;  // Used as the Mutex for inputs added as refs.

  gtl::InlinedVector<TensorValue, 4> inputs_;
  // Owns Tensors.
  std::vector<Tensor*> tensors_;
  // Copies of the outputs in unified memory (host and device accessible).
  std::vector<Tensor*> managed_outputs_;

  std::unique_ptr<OpKernelContext::Params> params_;
  std::unique_ptr<OpKernelContext> context_;
  // Unified memory allocator, only used when running on GPU.
  std::unique_ptr<Allocator> managed_allocator_;

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(OpsTestBase);
};
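
// A minimal usage sketch (illustrative only, not part of this header). It
// assumes the test file also includes "tensorflow/core/framework/fake_input.h"
// and "tensorflow/core/framework/node_def_builder.h":
//
//   class IdentityOpTest : public OpsTestBase {};
//
//   TEST_F(IdentityOpTest, TwoElements) {
//     TF_ASSERT_OK(NodeDefBuilder("identity_op", "Identity")
//                      .Input(FakeInput(DT_FLOAT))
//                      .Finalize(node_def()));
//     TF_ASSERT_OK(InitOp());
//     AddInputFromArray<float>(TensorShape({2}), {1.0f, 2.0f});
//     TF_ASSERT_OK(RunOpKernel());
//     Tensor expected(allocator(), DT_FLOAT, TensorShape({2}));
//     test::FillValues<float>(&expected, {1.0f, 2.0f});
//     test::ExpectTensorEqual<float>(expected, *GetOutput(0));
//   }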

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_OPS_TESTUTIL_H_