/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUSTOM_CALL_THUNK_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUSTOM_CALL_THUNK_H_

#include <string>
#include <vector>

#include "tensorflow/compiler/xla/service/gpu/buffer_allocations.h"
#include "tensorflow/compiler/xla/service/gpu/hlo_execution_profiler.h"
#include "tensorflow/compiler/xla/service/gpu/thunk.h"
namespace xla {
namespace gpu {

// Thunk to run a GPU custom call.
//
// This thunk's `ExecuteOnStream` implementation executes a host function
// `call_target` which is expected to enqueue operations onto the GPU.
//
// For information about the calling convention, see xla/g3doc/custom_call.md.
//
// Note that not all kCustomCall HLOs in XLA:GPU end up being run by this thunk.
// XLA itself creates kCustomCall instructions when lowering kConvolution HLOs
// into calls to cudnn.  These internally-created custom-calls are run using
// ConvolutionThunk, not CustomCallThunk.  There's no ambiguity because they
// have special call target names (e.g. "__cudnn$convForward") that only the
// compiler is allowed to create.
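//
// As a rough sketch of the calling convention (names here are illustrative,
// not part of XLA; see custom_call.md for the authoritative description), a
// call target for the CUDA platform looks roughly like:
//
//   void MyCustomCall(CUstream stream, void** buffers,
//                     const char* opaque, size_t opaque_len) {
//     // `buffers` holds one device pointer per operand, followed by one per
//     // result, in the order the slices were passed to CustomCallThunk's
//     // constructor.  `opaque` carries the instruction's opaque string
//     // verbatim.
//     const float* in = reinterpret_cast<const float*>(buffers[0]);
//     float* out = reinterpret_cast<float*>(buffers[1]);
//     MyKernel<<</*grid=*/32, /*block=*/128, /*shared_mem=*/0, stream>>>(in,
//                                                                        out);
//   }
//
//   // Registered so the runtime can resolve the target by name:
//   XLA_REGISTER_CUSTOM_CALL_TARGET(MyCustomCall, "CUDA");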
class CustomCallThunk : public Thunk {
 public:
  CustomCallThunk(ThunkInfo thunk_info, void* call_target,
                  std::vector<BufferAllocation::Slice> operands,
                  std::vector<BufferAllocation::Slice> results,
                  const std::string& opaque);

  Status ExecuteOnStream(const ExecuteParams& params) override;

 private:
  void* call_target_;
  const std::vector<BufferAllocation::Slice> operands_;
  const std::vector<BufferAllocation::Slice> results_;
  const std::string opaque_;
};

}  // namespace gpu
}  // namespace xla

#endif  // TENSORFLOW_COMPILER_XLA_SERVICE_GPU_CUSTOM_CALL_THUNK_H_