• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_GRAPPLER_TEST_H_
17 #define TENSORFLOW_CORE_GRAPPLER_UTILS_GRAPPLER_TEST_H_
18 
19 #include <vector>
20 
21 #include "absl/strings/string_view.h"
22 #include "tensorflow/cc/framework/scope.h"
23 #include "tensorflow/core/framework/attr_value.pb.h"
24 #include "tensorflow/core/framework/graph.pb.h"
25 #include "tensorflow/core/framework/types.h"
26 #include "tensorflow/core/grappler/grappler_item.h"
27 #include "tensorflow/core/grappler/utils.h"
28 #include "tensorflow/core/lib/random/random.h"
29 #include "tensorflow/core/platform/test.h"
30 #include "tensorflow/core/public/session_options.h"
31 
32 namespace tensorflow {
33 namespace grappler {
34 
35 class GrapplerTest : public ::testing::Test {
36  public:
37   GrapplerTest();
38 
39  protected:
40   void DisableAllOptimizers();
41   void EnableAllOptimizers();
42 
43   std::vector<Tensor> EvaluateNodes(
44       const GraphDef& graph, const std::vector<string>& node_names) const;
45 
46   std::vector<Tensor> EvaluateNodes(
47       const GraphDef& graph, const std::vector<string>& node_names,
48       const std::vector<std::pair<string, Tensor>>& inputs) const;
49 
50   std::vector<Tensor> EvaluateFetchNodes(const GrapplerItem& item) const;
51 
52   NodeDef* AddNode(const string& name, const string& op,
53                    const std::vector<string>& inputs,
54                    const std::vector<std::pair<string, AttrValue>>& attributes,
55                    GraphDef* graph) const;
56 
57   void DisableAllOptimizers(RewriterConfig* cfg);
58 
59   // Checks if two graphs are equal. Both graphs must have the same set of nodes
60   // with the same inputs and attributes. Nodes can be in different order.
61   //
62   // NOTE: This function uses EXPECT/ASSERT macros to check node properties
63   // equality, and adds all failures to the current test.
64   void CompareGraphs(GraphDef want, GraphDef got) const;
65 
66   // Checks if two nodes have the same name, op, inputs and attributes.
67   //
68   // NOTE: This function uses EXPECT/ASSERT macros to check node properties
69   // equality, and adds all failures to the current test.
70   void CompareNodes(const NodeDef& want, const NodeDef& got) const;
71 
72   // Checks if two functions are equal. Both functions must have the same set of
73   // nodes with the same inputs and attributes. Nodes can be in different order.
74   //
75   // NOTE: This function uses EXPECT/ASSERT macros to check node properties
76   // equality, and adds all failures to the current test.
77   void CompareFunctions(FunctionDef want, FunctionDef got) const;
78 
79   // Checks if node 'src' is directly connected to the input($position) of
80   // 'dst'.
81   bool IsNodesDirectlyConnected(const NodeMap& node_map, const string& src,
82                                 const string& dst, int position = 0);
83 
84   // Counts nodes of the given op-type in a graph.
85   int CountOpNodes(const GraphDef& graph, const string& op);
86 
87   // Get a random tensor with given shape.
88   template <DataType DTYPE>
GenerateRandomTensor(const TensorShape & shape)89   Tensor GenerateRandomTensor(const TensorShape& shape) const {
90     typedef typename EnumToDataType<DTYPE>::Type T;
91     Tensor tensor(DTYPE, shape);
92     for (auto i = 0; i < tensor.NumElements(); i++)
93       tensor.flat<T>()(i) = i + random::New64() % 10;
94     return tensor;
95   }
96 
97   // Creates a random tensor with given shape using `setRandom`.
98   template <DataType DTYPE>
GenerateTensorWithSetRandom(const TensorShape & shape)99   Tensor GenerateTensorWithSetRandom(const TensorShape& shape) const {
100     typedef typename EnumToDataType<DTYPE>::Type T;
101     Tensor tensor(DTYPE, shape);
102     tensor.flat<T>().setRandom();
103     return tensor;
104   }
105 
106   // Get a constant tensor with given shape.
107   template <DataType DTYPE>
GenerateConstantTensor(const TensorShape & shape,typename EnumToDataType<DTYPE>::Type value)108   Tensor GenerateConstantTensor(
109       const TensorShape& shape,
110       typename EnumToDataType<DTYPE>::Type value) const {
111     typedef typename EnumToDataType<DTYPE>::Type T;
112     Tensor tensor(DTYPE, shape);
113     for (auto i = 0; i < tensor.NumElements(); i++) tensor.flat<T>()(i) = value;
114     return tensor;
115   }
116 
CreateScopeWithDevice(absl::string_view device)117   inline tensorflow::Scope CreateScopeWithDevice(absl::string_view device) {
118     return tensorflow::Scope::NewRootScope().WithDevice(string(device));
119   }
120 
121  private:
122   SessionOptions options_;
123 };
124 
125 }  // end namespace grappler
126 }  // end namespace tensorflow
127 
128 #endif  // TENSORFLOW_CORE_GRAPPLER_UTILS_GRAPPLER_TEST_H_
129