• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_OPS_TESTUTIL_H_
17 #define TENSORFLOW_CORE_KERNELS_OPS_TESTUTIL_H_
18 
19 #include <functional>
20 #include <initializer_list>
21 #include <memory>
22 #include <string>
23 #include <vector>
24 
25 #include "tensorflow/core/common_runtime/device.h"
26 #include "tensorflow/core/common_runtime/device_factory.h"
27 #include "tensorflow/core/common_runtime/device_mgr.h"
28 #include "tensorflow/core/common_runtime/process_function_library_runtime.h"
29 #include "tensorflow/core/framework/allocator.h"
30 #include "tensorflow/core/framework/device_base.h"
31 #include "tensorflow/core/framework/function.h"
32 #include "tensorflow/core/framework/graph.pb.h"
33 #include "tensorflow/core/framework/node_def.pb.h"
34 #include "tensorflow/core/framework/op_kernel.h"
35 #include "tensorflow/core/framework/resource_mgr.h"
36 #include "tensorflow/core/framework/tensor.h"
37 #include "tensorflow/core/framework/tensor_shape.h"
38 #include "tensorflow/core/framework/tensor_testutil.h"
39 #include "tensorflow/core/framework/type_index.h"
40 #include "tensorflow/core/framework/types.h"
41 #include "tensorflow/core/framework/types.pb.h"
42 #include "tensorflow/core/lib/core/status.h"
43 #include "tensorflow/core/lib/core/status_test_util.h"
44 #include "tensorflow/core/lib/gtl/array_slice.h"
45 #include "tensorflow/core/lib/gtl/inlined_vector.h"
46 #include "tensorflow/core/platform/env.h"
47 #include "tensorflow/core/platform/logging.h"
48 #include "tensorflow/core/platform/macros.h"
49 #include "tensorflow/core/platform/mutex.h"
50 #include "tensorflow/core/platform/status.h"
51 #include "tensorflow/core/platform/test.h"
52 #include "tensorflow/core/platform/threadpool.h"
53 #include "tensorflow/core/platform/types.h"
54 #include "tensorflow/core/public/session_options.h"
55 #include "tensorflow/core/public/version.h"
56 #include "tensorflow/core/util/tensor_slice_reader_cache.h"
57 
58 namespace tensorflow {
59 namespace test {
60 
61 void SetOutputAttrs(OpKernelContext::Params* params,
62                     std::vector<AllocatorAttributes>* attrs);
63 
64 }  // namespace test
65 
66 // Helpful functions to test operators.
67 //
68 // This class will eventually be replaced / heavily modified
69 // to use the BrainClient interface.
70 class OpsTestBase : public ::testing::Test {
71  public:
72   OpsTestBase();
73 
74   ~OpsTestBase() override;
75 
76   // Allow kernel unit tests to run on GPU
77   void SetDevice(const DeviceType& device_type, std::unique_ptr<Device> device);
78 
79   void set_node_def(const NodeDef& node_def);
80 
81   // Clients can manipulate the underlying NodeDef via this accessor.
82   NodeDef* node_def();
83 
84   // Initializes an operator that takes in 'input_types' as input
85   // and output types as output.
86   //
87   // Returns the status of initialization.
88   Status InitOp();
89 
90   // Only use this directly if you have a deprecated op that you need to test.
91   Status InitOpWithGraphVersion(int graph_def_version);
92 
93   // Adds an input for every element described by the shape.
94   // 'input_mapping' maps an index (0...NumElements(shape)) to a
95   // value.
96   //
97   // TODO(vrv): Replace with something like a BrainClient Feed.
98   template <typename T>
AddInput(const TensorShape & shape,std::function<T (int)> input_mapping)99   void AddInput(const TensorShape& shape, std::function<T(int)> input_mapping) {
100     test::FillFn(AddInput(DataTypeToEnum<T>::v(), shape), input_mapping);
101   }
102 
103   // Like AddInput but takes in an explicit arrayslice of data.
104   template <typename T>
AddInputFromArray(const TensorShape & shape,const gtl::ArraySlice<T> data)105   void AddInputFromArray(const TensorShape& shape,
106                          const gtl::ArraySlice<T> data) {
107     test::FillValues<T>(AddInput(DataTypeToEnum<T>::v(), shape), data);
108   }
109 
110   // Convenience function to add an input and populate it with the elements from
111   // an initializer list converting the types as needed.
112   template <typename T, typename SrcType>
AddInputFromList(const TensorShape & shape,std::initializer_list<SrcType> data)113   void AddInputFromList(const TensorShape& shape,
114                         std::initializer_list<SrcType> data) {
115     test::FillValues<T>(AddInput(DataTypeToEnum<T>::v(), shape), data);
116   }
117 
118   // Adds a Resource type as input. If <container> is empty, uses the default
119   // container name.
120   template <typename T>
AddResourceInput(const string & container,const string & name,T * resource)121   void AddResourceInput(const string& container, const string& name,
122                         T* resource) {
123     CHECK_GT(input_types_.size(), inputs_.size())
124         << "Adding more inputs than types; perhaps you need to call MakeOp";
125     ResourceMgr* rm = device_->resource_manager();
126     std::string container_name =
127         container.empty() ? rm->default_container() : container;
128     EXPECT_TRUE(rm->Create(container_name, name, resource).ok());
129     AddResourceInputInternal(container_name, name, TypeIndex::Make<T>());
130   }
131 
132   // Runs an operation producing 'num_outputs' outputs.
133   //
134   // Returns the context's status after running the operation.
135   Status RunOpKernel();
136 
137   // Returns the tensor input for 'input_index'.
138   //
139   // REQUIRES: 0 <= input_index < context_->num_inputs()
140   const Tensor& GetInput(int input_index) const;
141 
142   TensorValue mutable_input(int input_index);
143 
144   // Returns the tensor output for 'output_index'.
145   //
146   // REQUIRES: 0 <= output_index < context_->num_outputs()
147   Tensor* GetOutput(int output_index);
148 
149   Allocator* allocator();
150 
151   OpKernel* op_kernel();
152 
153   const DataTypeVector& output_types() const;
154 
155  protected:
156   Tensor* AddInput(DataType dtype, const TensorShape& shape);
157   void AddResourceInputInternal(const std::string& container_name,
158                                 const std::string& name,
159                                 const TypeIndex& type_index);
160 
161   // device_mgr_ owns device_.
162   std::unique_ptr<DeviceMgr> device_mgr_;
163   Device* device_;
164 
165   // The device allocator, or the managed_allocator_ below if running on GPU.
166   Allocator* allocator_;
167 
168   std::unique_ptr<OpKernel> kernel_;
169   std::unique_ptr<ScopedStepContainer> step_container_;
170   NodeDef node_def_;
171   DataTypeVector input_types_;
172   DeviceType device_type_;
173 
174   mutex lock_for_refs_;  // Used as the Mutex for inputs added as refs
175 
176   gtl::InlinedVector<TensorValue, 4> inputs_;
177   // Owns Tensors.
178   std::vector<Tensor*> tensors_;
179   // Copies of the outputs in unified memory (host and device accessible).
180   std::vector<Tensor*> managed_outputs_;
181 
182   std::unique_ptr<OpKernelContext::Params> params_;
183   std::unique_ptr<OpKernelContext> context_;
184   // Unified memory allocator, only used when running on GPU.
185   std::unique_ptr<Allocator> managed_allocator_;
186 
187   std::unique_ptr<FunctionLibraryDefinition> flib_def_;
188   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr_;
189   std::unique_ptr<thread::ThreadPool> thread_pool_;
190 
191  private:
192   TF_DISALLOW_COPY_AND_ASSIGN(OpsTestBase);
193 };
194 
195 }  // namespace tensorflow
196 
197 #endif  // TENSORFLOW_CORE_KERNELS_OPS_TESTUTIL_H_
198