• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_OPS_TESTUTIL_H_
17 #define TENSORFLOW_CORE_KERNELS_OPS_TESTUTIL_H_
18 
19 #include <functional>
20 #include <initializer_list>
21 #include <memory>
22 #include <string>
23 #include <vector>
24 
25 #include "tensorflow/core/common_runtime/device.h"
26 #include "tensorflow/core/common_runtime/device_factory.h"
27 #include "tensorflow/core/common_runtime/device_mgr.h"
28 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
29 #include "tensorflow/core/framework/allocator.h"
30 #include "tensorflow/core/framework/device_base.h"
31 #include "tensorflow/core/framework/function.h"
32 #include "tensorflow/core/framework/graph.pb.h"
33 #include "tensorflow/core/framework/node_def.pb.h"
34 #include "tensorflow/core/framework/op_kernel.h"
35 #include "tensorflow/core/framework/resource_mgr.h"
36 #include "tensorflow/core/framework/tensor.h"
37 #include "tensorflow/core/framework/tensor_shape.h"
38 #include "tensorflow/core/framework/tensor_testutil.h"
39 #include "tensorflow/core/framework/type_index.h"
40 #include "tensorflow/core/framework/types.h"
41 #include "tensorflow/core/framework/types.pb.h"
42 #include "tensorflow/core/lib/core/status.h"
43 #include "tensorflow/core/lib/core/status_test_util.h"
44 #include "tensorflow/core/lib/gtl/array_slice.h"
45 #include "tensorflow/core/lib/gtl/inlined_vector.h"
46 #include "tensorflow/core/platform/env.h"
47 #include "tensorflow/core/platform/logging.h"
48 #include "tensorflow/core/platform/macros.h"
49 #include "tensorflow/core/platform/mutex.h"
50 #include "tensorflow/core/platform/status.h"
51 #include "tensorflow/core/platform/test.h"
52 #include "tensorflow/core/platform/threadpool.h"
53 #include "tensorflow/core/platform/types.h"
54 #include "tensorflow/core/public/session_options.h"
55 #include "tensorflow/core/public/version.h"
56 #include "tensorflow/core/util/tensor_slice_reader_cache.h"
57 
58 namespace tensorflow {
59 namespace test {
60 
61 void SetOutputAttrs(OpKernelContext::Params* params,
62                     std::vector<AllocatorAttributes>* attrs);
63 
64 }  // namespace test
65 
66 // Helpful functions to test operators.
67 //
68 // This class will eventually be replaced / heavily modified
69 // to use the BrainClient interface.
70 class OpsTestBase : public ::testing::Test {
71  public:
72   OpsTestBase();
73 
74   ~OpsTestBase() override;
75 
76   // Allow kernel unit tests to run on GPU
77   void SetDevice(const DeviceType& device_type, std::unique_ptr<Device> device);
78 
79   void set_node_def(const NodeDef& node_def);
80 
81   // Clients can manipulate the underlying NodeDef via this accessor.
82   NodeDef* node_def();
83 
84   // Initializes an operator that takes in 'input_types' as input
85   // and output types as output.
86   //
87   // Returns the status of initialization.
88   Status InitOp();
89 
90   // Only use this directly if you have a deprecated op that you need to test.
91   Status InitOpWithGraphVersion(int graph_def_version);
92 
93   // Adds an input for every element described by the shape.
94   // 'input_mapping' maps an index (0...NumElements(shape)) to a
95   // value.
96   //
97   // TODO(vrv): Replace with something like a BrainClient Feed.
98   template <typename T>
AddInput(const TensorShape & shape,std::function<T (int)> input_mapping)99   void AddInput(const TensorShape& shape, std::function<T(int)> input_mapping) {
100     test::FillFn(AddInput(DataTypeToEnum<T>::v(), shape), input_mapping);
101   }
102 
103   // Like AddInput but takes in an explicit arrayslice of data.
104   template <typename T>
AddInputFromArray(const TensorShape & shape,const gtl::ArraySlice<T> & data)105   void AddInputFromArray(const TensorShape& shape,
106                          const gtl::ArraySlice<T>& data) {
107     test::FillValues<T>(AddInput(DataTypeToEnum<T>::v(), shape), data);
108   }
109 
110   // Convenience function to add an input and populate it with the elements from
111   // an initializer list converting the types as needed.
112   template <typename T, typename SrcType>
AddInputFromList(const TensorShape & shape,std::initializer_list<SrcType> data)113   void AddInputFromList(const TensorShape& shape,
114                         std::initializer_list<SrcType> data) {
115     test::FillValues<T>(AddInput(DataTypeToEnum<T>::v(), shape), data);
116   }
117 
118   // Adds a Resource type as input. If <container> is empty, uses the default
119   // container name.
120   template <typename T>
AddResourceInput(const string & container,const string & name,T * resource)121   void AddResourceInput(const string& container, const string& name,
122                         T* resource) {
123     CHECK_GT(input_types_.size(), inputs_.size())
124         << "Adding more inputs than types; perhaps you need to call MakeOp";
125     ResourceMgr* rm = device_->resource_manager();
126     std::string container_name =
127         container.empty() ? rm->default_container() : container;
128     EXPECT_TRUE(rm->Create(container_name, name, resource).ok());
129     AddResourceInputInternal(container_name, name, TypeIndex::Make<T>());
130   }
131 
132   // Runs an operation producing 'num_outputs' outputs.
133   //
134   // Returns the context's status after running the operation.
135   Status RunOpKernel();
136 
137   // Returns the tensor input for 'input_index'.
138   //
139   // REQUIRES: 0 <= input_index < context_->num_inputs()
140   const Tensor& GetInput(int input_index) const;
141 
142   TensorValue mutable_input(int input_index);
143 
144   // Returns the tensor output for 'output_index'.
145   //
146   // REQUIRES: 0 <= output_index < context_->num_outputs()
147   Tensor* GetOutput(int output_index);
148 
149   Allocator* allocator();
150 
151   const DataTypeVector& output_types() const;
152 
153  protected:
154   Tensor* AddInput(DataType dtype, const TensorShape& shape);
155   void AddResourceInputInternal(const std::string& container_name,
156                                 const std::string& name,
157                                 const TypeIndex& type_index);
158 
159   // device_mgr_ owns device_.
160   std::unique_ptr<DeviceMgr> device_mgr_;
161   Device* device_;
162 
163   // The device allocator, or the managed_allocator_ below if running on GPU.
164   Allocator* allocator_;
165 
166   std::unique_ptr<OpKernel> kernel_;
167   std::unique_ptr<ScopedStepContainer> step_container_;
168   NodeDef node_def_;
169   DataTypeVector input_types_;
170   DeviceType device_type_;
171 
172   mutex lock_for_refs_;  // Used as the Mutex for inputs added as refs
173 
174   gtl::InlinedVector<TensorValue, 4> inputs_;
175   // Owns Tensors.
176   std::vector<Tensor*> tensors_;
177   // Copies of the outputs in unified memory (host and device accessible).
178   std::vector<Tensor*> managed_outputs_;
179 
180   std::unique_ptr<OpKernelContext::Params> params_;
181   std::unique_ptr<OpKernelContext> context_;
182   // Unified memory allocator, only used when running on GPU.
183   std::unique_ptr<Allocator> managed_allocator_;
184 
185   std::unique_ptr<FunctionLibraryDefinition> flib_def_;
186   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
187   std::unique_ptr<thread::ThreadPool> thread_pool_;
188 
189  private:
190   TF_DISALLOW_COPY_AND_ASSIGN(OpsTestBase);
191 };
192 
193 }  // namespace tensorflow
194 
195 #endif  // TENSORFLOW_CORE_KERNELS_OPS_TESTUTIL_H_
196