1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_GRAPPLER_UTILS_GRAPPLER_TEST_H_ 17 #define TENSORFLOW_CORE_GRAPPLER_UTILS_GRAPPLER_TEST_H_ 18 19 #include <vector> 20 21 #include "absl/strings/string_view.h" 22 #include "tensorflow/cc/framework/scope.h" 23 #include "tensorflow/core/framework/attr_value.pb.h" 24 #include "tensorflow/core/framework/graph.pb.h" 25 #include "tensorflow/core/framework/types.h" 26 #include "tensorflow/core/grappler/grappler_item.h" 27 #include "tensorflow/core/grappler/utils.h" 28 #include "tensorflow/core/lib/random/random.h" 29 #include "tensorflow/core/platform/test.h" 30 #include "tensorflow/core/public/session_options.h" 31 32 namespace tensorflow { 33 namespace grappler { 34 35 class GrapplerTest : public ::testing::Test { 36 public: 37 GrapplerTest(); 38 39 protected: 40 void DisableAllOptimizers(); 41 void EnableAllOptimizers(); 42 43 std::vector<Tensor> EvaluateNodes( 44 const GraphDef& graph, const std::vector<string>& node_names) const; 45 46 std::vector<Tensor> EvaluateNodes( 47 const GraphDef& graph, const std::vector<string>& node_names, 48 const std::vector<std::pair<string, Tensor>>& inputs) const; 49 50 std::vector<Tensor> EvaluateFetchNodes(const GrapplerItem& item) const; 51 52 NodeDef* AddNode(const string& name, const string& op, 53 const std::vector<string>& inputs, 54 const std::vector<std::pair<string, AttrValue>>& attributes, 55 GraphDef* graph) const; 56 57 void DisableAllOptimizers(RewriterConfig* cfg); 58 59 // Checks if two graphs are equal. Both graphs must have the same set of nodes 60 // with the same inputs and attributes. Nodes can be in different order. 61 // 62 // NOTE: This function uses EXPECT/ASSERT macros to check node properties 63 // equality, and adds all failures to the current test. 64 void CompareGraphs(GraphDef want, GraphDef got) const; 65 66 // Checks if two nodes have the same name, op, inputs and attributes. 67 // 68 // NOTE: This function uses EXPECT/ASSERT macros to check node properties 69 // equality, and adds all failures to the current test. 70 void CompareNodes(const NodeDef& want, const NodeDef& got) const; 71 72 // Checks if two functions are equal. Both functions must have the same set of 73 // nodes with the same inputs and attributes. Nodes can be in different order. 74 // 75 // NOTE: This function uses EXPECT/ASSERT macros to check node properties 76 // equality, and adds all failures to the current test. 77 void CompareFunctions(FunctionDef want, FunctionDef got) const; 78 79 // Checks if node 'src' is directly connected to the input($position) of 80 // 'dst'. 81 bool IsNodesDirectlyConnected(const NodeMap& node_map, const string& src, 82 const string& dst, int position = 0); 83 84 // Counts nodes of the given op-type in a graph. 85 int CountOpNodes(const GraphDef& graph, const string& op); 86 87 // Get a random tensor with given shape. 88 template <DataType DTYPE> GenerateRandomTensor(const TensorShape & shape)89 Tensor GenerateRandomTensor(const TensorShape& shape) const { 90 typedef typename EnumToDataType<DTYPE>::Type T; 91 Tensor tensor(DTYPE, shape); 92 for (auto i = 0; i < tensor.NumElements(); i++) 93 tensor.flat<T>()(i) = i + random::New64() % 10; 94 return tensor; 95 } 96 97 // Creates a random tensor with given shape using `setRandom`. 98 template <DataType DTYPE> GenerateTensorWithSetRandom(const TensorShape & shape)99 Tensor GenerateTensorWithSetRandom(const TensorShape& shape) const { 100 typedef typename EnumToDataType<DTYPE>::Type T; 101 Tensor tensor(DTYPE, shape); 102 tensor.flat<T>().setRandom(); 103 return tensor; 104 } 105 106 // Get a constant tensor with given shape. 107 template <DataType DTYPE> GenerateConstantTensor(const TensorShape & shape,typename EnumToDataType<DTYPE>::Type value)108 Tensor GenerateConstantTensor( 109 const TensorShape& shape, 110 typename EnumToDataType<DTYPE>::Type value) const { 111 typedef typename EnumToDataType<DTYPE>::Type T; 112 Tensor tensor(DTYPE, shape); 113 for (auto i = 0; i < tensor.NumElements(); i++) tensor.flat<T>()(i) = value; 114 return tensor; 115 } 116 CreateScopeWithDevice(absl::string_view device)117 inline tensorflow::Scope CreateScopeWithDevice(absl::string_view device) { 118 return tensorflow::Scope::NewRootScope().WithDevice(string(device)); 119 } 120 121 private: 122 SessionOptions options_; 123 }; 124 125 } // end namespace grappler 126 } // end namespace tensorflow 127 128 #endif // TENSORFLOW_CORE_GRAPPLER_UTILS_GRAPPLER_TEST_H_ 129