/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #ifndef TENSORFLOW_KERNELS_TRANSPOSE_OP_H_
17 #define TENSORFLOW_KERNELS_TRANSPOSE_OP_H_
18 
19 #include "tensorflow/core/framework/op_kernel.h"
20 #include "tensorflow/core/framework/tensor.h"
21 
22 namespace tensorflow {
23 
24 class TransposeOp : public OpKernel {
25  public:
TransposeOp(OpKernelConstruction * ctx)26   explicit TransposeOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
27 
28   void Compute(OpKernelContext* ctx) override;
29 
30  protected:
31   virtual Status DoTranspose(OpKernelContext* ctx, const Tensor& in,
32                              gtl::ArraySlice<int32> perm, Tensor* out) = 0;
IsConjugate()33   virtual bool IsConjugate() const { return false; }
34 };
35 
36 class TransposeCpuOp : public TransposeOp {
37  public:
TransposeCpuOp(OpKernelConstruction * ctx)38   explicit TransposeCpuOp(OpKernelConstruction* ctx) : TransposeOp(ctx) {}
39 
40  protected:
41   Status DoTranspose(OpKernelContext* ctx, const Tensor& in,
42                      gtl::ArraySlice<int32> perm, Tensor* out) override;
43 };
44 
#if defined(INTEL_MKL)
// Non-conjugating CPU transpose kernel that is compiled only when INTEL_MKL
// is defined.
// NOTE(review): presumably backed by Intel MKL routines; confirm in the .cc.
// IsConjugate() keeps the base default (false).
class MklTransposeCpuOp : public TransposeOp {
 public:
  explicit MklTransposeCpuOp(OpKernelConstruction* ctx) : TransposeOp(ctx) {}

 protected:
  Status DoTranspose(OpKernelContext* ctx, const Tensor& in,
                     gtl::ArraySlice<int32> perm, Tensor* out) override;
};
#endif  // INTEL_MKL
55 
56 class TransposeGpuOp : public TransposeOp {
57  public:
TransposeGpuOp(OpKernelConstruction * ctx)58   explicit TransposeGpuOp(OpKernelConstruction* ctx) : TransposeOp(ctx) {}
59 
60  protected:
61   Status DoTranspose(OpKernelContext* ctx, const Tensor& in,
62                      gtl::ArraySlice<int32> perm, Tensor* out) override;
63 };
64 
#ifdef TENSORFLOW_USE_SYCL
// Non-conjugating transpose kernel that is compiled only when
// TENSORFLOW_USE_SYCL is defined; by its name, the SYCL implementation.
// IsConjugate() keeps the base default (false).
class TransposeSyclOp : public TransposeOp {
 public:
  explicit TransposeSyclOp(OpKernelConstruction* ctx) : TransposeOp(ctx) {}

 protected:
  Status DoTranspose(OpKernelContext* ctx, const Tensor& in,
                     gtl::ArraySlice<int32> perm, Tensor* out) override;
};
#endif  // TENSORFLOW_USE_SYCL
75 
// Conjugating transpose ops: same shape as the kernels above, but each
// overrides IsConjugate() to return true.
77 class ConjugateTransposeCpuOp : public TransposeOp {
78  public:
ConjugateTransposeCpuOp(OpKernelConstruction * ctx)79   explicit ConjugateTransposeCpuOp(OpKernelConstruction* ctx)
80       : TransposeOp(ctx) {}
81 
82  protected:
83   Status DoTranspose(OpKernelContext* ctx, const Tensor& in,
84                      gtl::ArraySlice<int32> perm, Tensor* out) override;
IsConjugate()85   bool IsConjugate() const override { return true; }
86 };
87 
#if defined(INTEL_MKL)
// Conjugating CPU transpose kernel that is compiled only when INTEL_MKL is
// defined.
// NOTE(review): presumably backed by Intel MKL routines; confirm in the .cc.
class MklConjugateTransposeCpuOp : public TransposeOp {
 public:
  explicit MklConjugateTransposeCpuOp(OpKernelConstruction* ctx)
      : TransposeOp(ctx) {}

 protected:
  Status DoTranspose(OpKernelContext* ctx, const Tensor& in,
                     gtl::ArraySlice<int32> perm, Tensor* out) override;

  // Marks this kernel as conjugating (the base default is false).
  bool IsConjugate() const override { return true; }
};
#endif  // INTEL_MKL
100 
101 class ConjugateTransposeGpuOp : public TransposeOp {
102  public:
ConjugateTransposeGpuOp(OpKernelConstruction * ctx)103   explicit ConjugateTransposeGpuOp(OpKernelConstruction* ctx)
104       : TransposeOp(ctx) {}
105 
106  protected:
107   Status DoTranspose(OpKernelContext* ctx, const Tensor& in,
108                      gtl::ArraySlice<int32> perm, Tensor* out) override;
IsConjugate()109   bool IsConjugate() const override { return true; }
110 };
111 
#ifdef TENSORFLOW_USE_SYCL
// Conjugating transpose kernel that is compiled only when TENSORFLOW_USE_SYCL
// is defined; by its name, the SYCL implementation.
class ConjugateTransposeSyclOp : public TransposeOp {
 public:
  explicit ConjugateTransposeSyclOp(OpKernelConstruction* ctx)
      : TransposeOp(ctx) {}

 protected:
  Status DoTranspose(OpKernelContext* ctx, const Tensor& in,
                     gtl::ArraySlice<int32> perm, Tensor* out) override;

  // Marks this kernel as conjugating (the base default is false).
  bool IsConjugate() const override { return true; }
};
#endif  // TENSORFLOW_USE_SYCL
124 
125 }  // namespace tensorflow
126 
127 #endif  // TENSORFLOW_KERNELS_TRANSPOSE_OP_H_
128