1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 16 #ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUSTOM_CALL_THUNK_H_ 17 #define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUSTOM_CALL_THUNK_H_ 18 19 #include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h" 20 #include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h" 21 #include "tensorflow/compiler/xla/service/gpu/thunk.h" 22 23 namespace xla { 24 namespace gpu { 25 26 // Thunk to run a GPU custom call. 27 // 28 // This thunk's `ExecuteOnStream` implementation executes a host function 29 // `call_target` which is expected to enqueue operations onto the GPU. 30 // 31 // For information about the calling convention, see xla/g3doc/custom_call.md 32 // 33 // Note that not all kCustomCall HLOs in XLA:GPU end up being run by this thunk. 34 // XLA itself creates kCustomCall instructions when lowering kConvolution HLOs 35 // into calls to cudnn. These internally-created custom-calls are run using 36 // ConvolutionThunk, not CustomCallThunk. There's no ambiguity because they 37 // have special call target names (e.g. "__cudnn$convForward") that only the 38 // compiler is allowed to create. 39 class CustomCallThunk : public Thunk { 40 public: 41 CustomCallThunk(ThunkInfo thunk_info, void* call_target, 42 std::vector<BufferAllocation::Slice> operands, 43 std::vector<BufferAllocation::Slice> results, 44 const std::string& opaque); 45 46 Status ExecuteOnStream(const ExecuteParams& params) override; 47 48 private: 49 void* call_target_; 50 const std::vector<BufferAllocation::Slice> operands_; 51 const std::vector<BufferAllocation::Slice> results_; 52 const std::string opaque_; 53 }; 54 55 } // namespace gpu 56 } // namespace xla 57 58 #endif // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUSTOM_CALL_THUNK_H_ 59