1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_FRAMEWORK_FUNCTION_TESTLIB_H_ 17 #define TENSORFLOW_CORE_FRAMEWORK_FUNCTION_TESTLIB_H_ 18 19 #include <string> 20 21 #include "tensorflow/core/framework/attr_value_util.h" 22 #include "tensorflow/core/framework/function.h" 23 #include "tensorflow/core/framework/function.pb.h" 24 #include "tensorflow/core/framework/graph.pb.h" 25 #include "tensorflow/core/framework/node_def.pb.h" 26 #include "tensorflow/core/lib/gtl/array_slice.h" 27 #include "tensorflow/core/platform/types.h" 28 29 namespace tensorflow { 30 namespace test { 31 namespace function { 32 33 // A helper class to make AttrSlice from initializer lists 34 class Attrs { 35 public: Attrs(const std::initializer_list<std::pair<string,FunctionDefHelper::AttrValueWrapper>> & attrs)36 Attrs(const std::initializer_list< // NOLINT(runtime/explicit) 37 std::pair<string, FunctionDefHelper::AttrValueWrapper>>& attrs) { 38 for (const auto& aval : attrs) { 39 map_.insert({aval.first, aval.second.proto}); 40 } 41 } 42 Attrs(const std::vector<std::pair<string,FunctionDefHelper::AttrValueWrapper>> & attrs)43 Attrs( 44 const std::vector<std::pair<string, FunctionDefHelper::AttrValueWrapper>>& 45 attrs) { 46 for (const auto& aval : attrs) { 47 map_.insert({aval.first, aval.second.proto}); 48 } 49 } 50 AttrSlice()51 operator AttrSlice() { return AttrSlice(&map_); } // NOLINT(runtime/explicit) 52 53 private: 54 AttrValueMap map_; 55 }; 56 57 // Helper to construct a NodeDef. 58 NodeDef NDef( 59 StringPiece name, StringPiece op, gtl::ArraySlice<string> inputs, 60 gtl::ArraySlice<std::pair<string, FunctionDefHelper::AttrValueWrapper>> 61 attrs = {}, 62 const string& device = ""); 63 64 // Helper to construct a GraphDef proto. 65 GraphDef GDef(gtl::ArraySlice<NodeDef> nodes, 66 gtl::ArraySlice<FunctionDef> funcs = {}); 67 68 // For testing convenience, we provide a few simple functions that can 69 // be easily executed and tested. 70 71 // x: T -> x * 2. 72 FunctionDef XTimesTwo(); 73 74 // x: T -> cpu(x * 2) + cpu(x * 3). 75 FunctionDef TwoDeviceTimesFive(); 76 77 // x: T -> cpu(x * 2), gpu(x * 3). 78 FunctionDef TwoDeviceMult(); 79 80 // cpu(x): T, gpu(y): T -> cpu(x * 2), gpu(y * 3). 81 FunctionDef TwoDeviceInputOutput(); 82 83 // Function taking a list of Tensors as input. 84 FunctionDef FuncWithListInput(); 85 86 // Function returning a list of Tensors as output. 87 FunctionDef FuncWithListOutput(); 88 89 // x: T -> x + x. 90 FunctionDef XAddX(); 91 92 // x: T, y: T -> x + y. 93 FunctionDef XAddY(); 94 95 // x: T -> x * 2, where x is int32. 96 FunctionDef XTimesTwoInt32(); 97 98 // x: T -> (x * 2) * 2. 99 FunctionDef XTimesFour(); 100 101 // x: T -> ((x * 2) * 2) * 2. 102 FunctionDef XTimes16(); 103 104 // w: T, x: T, b: T -> MatMul(w, x) + b 105 FunctionDef WXPlusB(); 106 107 // x: T -> x: T, T is a type which we automatically converts to a bool. 108 FunctionDef NonZero(); 109 110 // x: T -> bool. 111 FunctionDef IsZero(); 112 113 // x: T -> int64 114 FunctionDef RandomUniform(); 115 116 // x: T, y:T -> y: T, x: T 117 FunctionDef Swap(); 118 119 // x: T, y: T -> y: T, x: T, the body has no nodes. 120 FunctionDef EmptyBodySwap(); 121 122 // x: float, y: resource -> y: resource, 2*x: float. 123 FunctionDef ResourceOutput(); 124 125 // x: resource -> x: resource 126 FunctionDef ResourceIdentity(); 127 128 // x: resource -> y: float. 129 FunctionDef ReadResourceVariable(); 130 131 // Contains malformed control flow which can't be run by the executor. 132 FunctionDef InvalidControlFlow(); 133 134 // x: T -> x <= N. 135 FunctionDef LessThanOrEqualToN(int64 N); 136 137 // x: T, y: T -> x + 1, x * y 138 FunctionDef XPlusOneXTimesY(); 139 140 // x: T, y: T -> x <= N 141 FunctionDef XYXLessThanOrEqualToN(int64 N); 142 143 // x: T -> bool 144 FunctionDef RandomUniformLess(); 145 146 // start: int64, stop: int64, step: int64 -> y: RangeDatasetOp::Dataset 147 FunctionDef MakeRangeDataset(); 148 149 // input_dataset: variant, batch_size: int64, drop_remainder: bool 150 // -> y: BatchDatasetV2::Dataset 151 FunctionDef MakeBatchDataset(); 152 153 // input_dataset: variant, other_arguments: Targuments, f: func, 154 // Targuments: list(type), output_types: list(type), output_shapes: list(shape), 155 // use_inter_op_parallelism: bool, preserve_cardinality: bool 156 // -> y: MapDatasetOp::Dataset 157 FunctionDef MakeMapDataset(bool has_other_args); 158 159 // input_dataset: variant, count: int64 -> y: TakeDataset::Dataset 160 FunctionDef MakeTakeDataset(); 161 162 // x: T -> y: TensorSliceDatasetOp::Dataset 163 FunctionDef MakeTensorSliceDataset(); 164 165 // x: T -> y: T, idx: out_idx 166 FunctionDef Unique(); 167 168 void FunctionTestSchedClosure(std::function<void()> fn); 169 170 } // end namespace function 171 } // end namespace test 172 } // end namespace tensorflow 173 174 #endif // TENSORFLOW_CORE_FRAMEWORK_FUNCTION_TESTLIB_H_ 175