1 /* Copyright 2016 Google Inc. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_CORE_KERNELS_FUZZING_FUZZ_SESSION_H_ 17 #define TENSORFLOW_CORE_KERNELS_FUZZING_FUZZ_SESSION_H_ 18 19 #include "tensorflow/cc/framework/scope.h" 20 #include "tensorflow/core/graph/graph.h" 21 #include "tensorflow/core/public/session.h" 22 23 // Standard invoking function macro to dispatch to a fuzzer class. 24 #ifndef PLATFORM_WINDOWS 25 #define STANDARD_TF_FUZZ_FUNCTION(FuzzerClass) \ 26 extern "C" int LLVMFuzzerTestOneInput(const uint8_t* data, size_t size) { \ 27 static FuzzerClass* fuzzer = new FuzzerClass(); \ 28 return fuzzer->Fuzz(data, size); \ 29 } 30 #else 31 // We don't compile this for Windows, MSVC doesn't like it as pywrap in Windows 32 // links all the code into one big object file and there are conflicting 33 // function names. 34 #define STANDARD_TF_FUZZ_FUNCTION(FuzzerClass) 35 #endif 36 37 // Standard builder for hooking one placeholder to one op. 38 #define SINGLE_INPUT_OP_BUILDER(dtype, opName) \ 39 void BuildGraph(const Scope& scope) override { \ 40 auto op_node = \ 41 tensorflow::ops::Placeholder(scope.WithOpName("input"), dtype); \ 42 (void)tensorflow::ops::opName(scope.WithOpName("output"), op_node); \ 43 } 44 45 namespace tensorflow { 46 namespace fuzzing { 47 48 // Create a TensorFlow session using a specific GraphDef created 49 // by BuildGraph(), and make it available for fuzzing. 50 // Users must override BuildGraph and FuzzImpl to specify 51 // (1) which operations are being fuzzed; and 52 // (2) How to translate the uint8_t* buffer from the fuzzer 53 // to a Tensor or Tensors that are semantically appropriate 54 // for the op under test. 55 // For the simple cases of testing a single op that takes a single 56 // input Tensor, use the SINGLE_INPUT_OP_BUILDER(dtype, opName) macro in place 57 // of defining BuildGraphDef. 58 // 59 // Typical use: 60 // class FooFuzzer : public FuzzSession { 61 // SINGLE_INPUT_OP_BUILDER(DT_INT8, Identity); 62 // void FuzzImpl(const uint8_t* data, size_t size) { 63 // ... convert data and size to a Tensor, pass it to: 64 // RunInputs({{"input", input_tensor}}); 65 // 66 class FuzzSession { 67 public: FuzzSession()68 FuzzSession() : initialized_(false) {} ~FuzzSession()69 virtual ~FuzzSession() {} 70 71 // Constructs a Graph using the supplied Scope. 72 // By convention, the graph should have inputs named "input1", ... 73 // "inputN", and one output node, named "output". 74 // Users of FuzzSession should override this method to create their graph. 75 virtual void BuildGraph(const Scope& scope) = 0; 76 77 // Implements the logic that converts an opaque byte buffer 78 // from the fuzzer to Tensor inputs to the graph. Users must override. 79 virtual void FuzzImpl(const uint8_t* data, size_t size) = 0; 80 81 // Initializes the FuzzSession. Not safe for multithreading. 82 // Separate init function because the call to virtual BuildGraphDef 83 // can't be put into the constructor. InitIfNeeded()84 Status InitIfNeeded() { 85 if (initialized_) { 86 return Status::OK(); 87 } 88 initialized_ = true; 89 90 Scope root = Scope::DisabledShapeInferenceScope().ExitOnError(); 91 SessionOptions options; 92 session_ = std::unique_ptr<Session>(NewSession(options)); 93 94 BuildGraph(root); 95 96 GraphDef graph_def; 97 TF_CHECK_OK(root.ToGraphDef(&graph_def)); 98 99 Status status = session_->Create(graph_def); 100 if (!status.ok()) { 101 // This is FATAL, because this code is designed to fuzz an op 102 // within a session. Failure to create the session means we 103 // can't send any data to the op. 104 LOG(FATAL) << "Could not create session: " << status.error_message(); 105 } 106 return status; 107 } 108 109 // Runs the TF session by pulling on the "output" node, attaching 110 // the supplied input_tensor to the input node(s), and discarding 111 // any returned output. 112 // Note: We are ignoring Status from Run here since fuzzers don't need to 113 // check it (as that will slow them down and printing/logging is useless). RunInputs(const std::vector<std::pair<string,Tensor>> & inputs)114 void RunInputs(const std::vector<std::pair<string, Tensor> >& inputs) { 115 RunInputsWithStatus(inputs).IgnoreError(); 116 } 117 118 // Same as RunInputs but don't ignore status RunInputsWithStatus(const std::vector<std::pair<string,Tensor>> & inputs)119 Status RunInputsWithStatus( 120 const std::vector<std::pair<string, Tensor> >& inputs) { 121 return session_->Run(inputs, {}, {"output"}, nullptr); 122 } 123 124 // Dispatches to FuzzImpl; small amount of sugar to keep the code 125 // of the per-op fuzzers tiny. Fuzz(const uint8_t * data,size_t size)126 int Fuzz(const uint8_t* data, size_t size) { 127 Status status = InitIfNeeded(); 128 TF_CHECK_OK(status) << "Fuzzer graph initialization failed: " 129 << status.error_message(); 130 // No return value from fuzzing: Success is defined as "did not 131 // crash". The actual application results are irrelevant. 132 FuzzImpl(data, size); 133 return 0; 134 } 135 136 private: 137 bool initialized_; 138 std::unique_ptr<Session> session_; 139 }; 140 141 // A specialized fuzz implementation for ops that take 142 // a single string. Caller must still define the op 143 // to plumb by overriding BuildGraph or using 144 // a plumbing macro. 145 class FuzzStringInputOp : public FuzzSession { FuzzImpl(const uint8_t * data,size_t size)146 void FuzzImpl(const uint8_t* data, size_t size) final { 147 Tensor input_tensor(tensorflow::DT_STRING, TensorShape({})); 148 input_tensor.scalar<tstring>()() = 149 string(reinterpret_cast<const char*>(data), size); 150 RunInputs({{"input", input_tensor}}); 151 } 152 }; 153 154 } // end namespace fuzzing 155 } // end namespace tensorflow 156 157 #endif // TENSORFLOW_CORE_KERNELS_FUZZING_FUZZ_SESSION_H_ 158