• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/c/experimental/ops/math_ops.h"
16 
17 #include "tensorflow/c/eager/abstract_context.h"
18 #include "tensorflow/c/eager/abstract_tensor_handle.h"
19 #include "tensorflow/c/eager/tracing_utils.h"
20 #include "tensorflow/c/experimental/ops/array_ops.h"
21 #include "tensorflow/core/framework/types.h"
22 #include "tensorflow/core/platform/errors.h"
23 
24 using tensorflow::tracing::MaybeSetOpName;
25 
26 namespace tensorflow {
27 namespace ops {
28 
Mul(AbstractContext * ctx,absl::Span<AbstractTensorHandle * const> inputs,absl::Span<AbstractTensorHandle * > outputs,const char * name)29 Status Mul(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
30            absl::Span<AbstractTensorHandle*> outputs, const char* name) {
31   AbstractOperationPtr mul_op(ctx->CreateOperation());
32   TF_RETURN_IF_ERROR(mul_op->Reset("Mul", /*raw_device_name=*/nullptr));
33   TF_RETURN_IF_ERROR(MaybeSetOpName(mul_op.get(), name));
34   TF_RETURN_IF_ERROR(mul_op->AddInput(inputs[0]));
35   TF_RETURN_IF_ERROR(mul_op->AddInput(inputs[1]));
36   int num_retvals = 1;
37   return mul_op->Execute(outputs, &num_retvals);
38 }
39 
Conj(AbstractContext * ctx,absl::Span<AbstractTensorHandle * const> inputs,absl::Span<AbstractTensorHandle * > outputs,const char * name)40 Status Conj(AbstractContext* ctx,
41             absl::Span<AbstractTensorHandle* const> inputs,
42             absl::Span<AbstractTensorHandle*> outputs, const char* name) {
43   auto dtype = inputs[0]->DataType();
44   if (DataTypeIsFloating(BaseType(dtype)) ||
45       DataTypeIsInteger(BaseType(dtype))) {
46     TF_RETURN_IF_ERROR(Identity(ctx, inputs, outputs, name));
47   } else if (DataTypeIsComplex(BaseType(dtype)) ||
48              BaseType(dtype) == DT_VARIANT) {
49     AbstractOperationPtr conj_op(ctx->CreateOperation());
50     TF_RETURN_IF_ERROR(conj_op->Reset("Conj", /*raw_device_name=*/nullptr));
51     TF_RETURN_IF_ERROR(MaybeSetOpName(conj_op.get(), name));
52     TF_RETURN_IF_ERROR(conj_op->AddInput(inputs[0]));
53 
54     int num_retvals = 1;
55     TF_RETURN_IF_ERROR(conj_op->Execute(outputs, &num_retvals));
56   } else {
57     return errors::InvalidArgument(
58         "Expected numeric or variant tensor, got dtype ", dtype);
59   }
60   return Status::OK();
61 }
62 
Add(AbstractContext * ctx,absl::Span<AbstractTensorHandle * const> inputs,absl::Span<AbstractTensorHandle * > outputs,const char * name)63 Status Add(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
64            absl::Span<AbstractTensorHandle*> outputs, const char* name) {
65   AbstractOperationPtr add_op(ctx->CreateOperation());
66   TF_RETURN_IF_ERROR(add_op->Reset("AddV2", /*raw_device_name=*/nullptr));
67   TF_RETURN_IF_ERROR(MaybeSetOpName(add_op.get(), name));
68   TF_RETURN_IF_ERROR(add_op->AddInput(inputs[0]));
69   TF_RETURN_IF_ERROR(add_op->AddInput(inputs[1]));
70 
71   int num_retvals = 1;
72   TF_RETURN_IF_ERROR(add_op->Execute(outputs, &num_retvals));
73   return Status::OK();
74 }
75 
Sub(AbstractContext * ctx,absl::Span<AbstractTensorHandle * const> inputs,absl::Span<AbstractTensorHandle * > outputs,const char * name)76 Status Sub(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
77            absl::Span<AbstractTensorHandle*> outputs, const char* name) {
78   AbstractOperationPtr sub_op(ctx->CreateOperation());
79   TF_RETURN_IF_ERROR(sub_op->Reset("Sub", /*raw_device_name=*/nullptr));
80   TF_RETURN_IF_ERROR(MaybeSetOpName(sub_op.get(), name));
81   TF_RETURN_IF_ERROR(sub_op->AddInput(inputs[0]));
82   TF_RETURN_IF_ERROR(sub_op->AddInput(inputs[1]));
83 
84   int num_retvals = 1;
85   TF_RETURN_IF_ERROR(sub_op->Execute(outputs, &num_retvals));
86   return Status::OK();
87 }
88 
MatMul(AbstractContext * ctx,absl::Span<AbstractTensorHandle * const> inputs,absl::Span<AbstractTensorHandle * > outputs,const char * name,bool transpose_a=false,bool transpose_b=false)89 Status MatMul(AbstractContext* ctx,
90               absl::Span<AbstractTensorHandle* const> inputs,
91               absl::Span<AbstractTensorHandle*> outputs, const char* name,
92               bool transpose_a = false, bool transpose_b = false) {
93   AbstractOperationPtr matmul_op(ctx->CreateOperation());
94   TF_RETURN_IF_ERROR(matmul_op->Reset("MatMul", /*raw_device_name=*/nullptr));
95   TF_RETURN_IF_ERROR(MaybeSetOpName(matmul_op.get(), name));
96   TF_RETURN_IF_ERROR(matmul_op->AddInput(inputs[0]));
97   TF_RETURN_IF_ERROR(matmul_op->AddInput(inputs[1]));
98 
99   TF_RETURN_IF_ERROR(matmul_op->SetAttrBool("transpose_a", transpose_a));
100   TF_RETURN_IF_ERROR(matmul_op->SetAttrBool("transpose_b", transpose_b));
101 
102   int num_retvals = 1;
103   TF_RETURN_IF_ERROR(matmul_op->Execute(outputs, &num_retvals));
104   return Status::OK();
105 }
106 
Neg(AbstractContext * ctx,absl::Span<AbstractTensorHandle * const> inputs,absl::Span<AbstractTensorHandle * > outputs,const char * name)107 Status Neg(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
108            absl::Span<AbstractTensorHandle*> outputs, const char* name) {
109   AbstractOperationPtr neg_op(ctx->CreateOperation());
110   TF_RETURN_IF_ERROR(neg_op->Reset("Neg", /*raw_device_name=*/nullptr));
111   TF_RETURN_IF_ERROR(MaybeSetOpName(neg_op.get(), name));
112   TF_RETURN_IF_ERROR(neg_op->AddInput(inputs[0]));
113 
114   int num_retvals = 1;
115   return neg_op->Execute(outputs, &num_retvals);
116 }
117 
Sum(AbstractContext * ctx,absl::Span<AbstractTensorHandle * const> inputs,absl::Span<AbstractTensorHandle * > outputs,const char * name)118 Status Sum(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
119            absl::Span<AbstractTensorHandle*> outputs, const char* name) {
120   AbstractOperationPtr sum_op(ctx->CreateOperation());
121   TF_RETURN_IF_ERROR(sum_op->Reset("Sum", /*raw_device_name=*/nullptr));
122   TF_RETURN_IF_ERROR(MaybeSetOpName(sum_op.get(), name));
123   TF_RETURN_IF_ERROR(sum_op->AddInput(inputs[0]));  // input_vals
124   TF_RETURN_IF_ERROR(sum_op->AddInput(inputs[1]));  // reduction_indices
125 
126   int num_retvals = 1;
127   TF_RETURN_IF_ERROR(sum_op->Execute(outputs, &num_retvals));
128   return Status::OK();
129 }
130 
Div(AbstractContext * ctx,absl::Span<AbstractTensorHandle * const> inputs,absl::Span<AbstractTensorHandle * > outputs,const char * name)131 Status Div(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
132            absl::Span<AbstractTensorHandle*> outputs, const char* name) {
133   AbstractOperationPtr div_op(ctx->CreateOperation());
134   TF_RETURN_IF_ERROR(div_op->Reset("Div", /*raw_device_name=*/nullptr));
135   TF_RETURN_IF_ERROR(MaybeSetOpName(div_op.get(), name));
136   TF_RETURN_IF_ERROR(div_op->AddInput(inputs[0]));  // x
137   TF_RETURN_IF_ERROR(div_op->AddInput(inputs[1]));  // y
138 
139   int num_retvals = 1;
140   TF_RETURN_IF_ERROR(div_op->Execute(outputs, &num_retvals));  // z = x / y
141   return Status::OK();
142 }
143 
DivNoNan(AbstractContext * ctx,absl::Span<AbstractTensorHandle * const> inputs,absl::Span<AbstractTensorHandle * > outputs,const char * name)144 Status DivNoNan(AbstractContext* ctx,
145                 absl::Span<AbstractTensorHandle* const> inputs,
146                 absl::Span<AbstractTensorHandle*> outputs, const char* name) {
147   AbstractOperationPtr div_op(ctx->CreateOperation());
148   TF_RETURN_IF_ERROR(div_op->Reset("DivNoNan", /*raw_device_name=*/nullptr));
149   TF_RETURN_IF_ERROR(MaybeSetOpName(div_op.get(), name));
150   TF_RETURN_IF_ERROR(div_op->AddInput(inputs[0]));  // x
151   TF_RETURN_IF_ERROR(div_op->AddInput(inputs[1]));  // y
152 
153   int num_retvals = 1;
154   TF_RETURN_IF_ERROR(div_op->Execute(
155       outputs, &num_retvals));  // z = x / y, (z_i = 0 if y_i = 0)
156   return Status::OK();
157 }
158 
Exp(AbstractContext * ctx,absl::Span<AbstractTensorHandle * const> inputs,absl::Span<AbstractTensorHandle * > outputs,const char * name)159 Status Exp(AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
160            absl::Span<AbstractTensorHandle*> outputs, const char* name) {
161   AbstractOperationPtr exp_op(ctx->CreateOperation());
162   TF_RETURN_IF_ERROR(exp_op->Reset("Exp", /*raw_device_name=*/nullptr));
163   TF_RETURN_IF_ERROR(MaybeSetOpName(exp_op.get(), name));
164   TF_RETURN_IF_ERROR(exp_op->AddInput(inputs[0]));
165 
166   int num_retvals = 1;
167   return exp_op->Execute(outputs, &num_retvals);
168 }
169 
Sqrt(AbstractContext * ctx,absl::Span<AbstractTensorHandle * const> inputs,absl::Span<AbstractTensorHandle * > outputs,const char * name)170 Status Sqrt(AbstractContext* ctx,
171             absl::Span<AbstractTensorHandle* const> inputs,
172             absl::Span<AbstractTensorHandle*> outputs, const char* name) {
173   AbstractOperationPtr sqrt_op(ctx->CreateOperation());
174   TF_RETURN_IF_ERROR(sqrt_op->Reset("Sqrt", /*raw_device_name=*/nullptr));
175   TF_RETURN_IF_ERROR(MaybeSetOpName(sqrt_op.get(), name));
176   TF_RETURN_IF_ERROR(sqrt_op->AddInput(inputs[0]));
177 
178   int num_retvals = 1;
179   Status s = sqrt_op->Execute(outputs, &num_retvals);
180   return s;
181 }
182 
SqrtGrad(AbstractContext * ctx,absl::Span<AbstractTensorHandle * const> inputs,absl::Span<AbstractTensorHandle * > outputs,const char * name)183 Status SqrtGrad(AbstractContext* ctx,
184                 absl::Span<AbstractTensorHandle* const> inputs,
185                 absl::Span<AbstractTensorHandle*> outputs, const char* name) {
186   AbstractOperationPtr sqrt_grad_op(ctx->CreateOperation());
187   TF_RETURN_IF_ERROR(
188       sqrt_grad_op->Reset("SqrtGrad", /*raw_device_name=*/nullptr));
189   TF_RETURN_IF_ERROR(MaybeSetOpName(sqrt_grad_op.get(), name));
190   TF_RETURN_IF_ERROR(sqrt_grad_op->AddInput(inputs[0]));
191   TF_RETURN_IF_ERROR(sqrt_grad_op->AddInput(inputs[1]));
192 
193   int num_retvals = 1;
194   Status s = sqrt_grad_op->Execute(outputs, &num_retvals);
195   return s;
196 }
197 
Log1p(AbstractContext * ctx,absl::Span<AbstractTensorHandle * const> inputs,absl::Span<AbstractTensorHandle * > outputs,const char * name)198 Status Log1p(AbstractContext* ctx,
199              absl::Span<AbstractTensorHandle* const> inputs,
200              absl::Span<AbstractTensorHandle*> outputs, const char* name) {
201   AbstractOperationPtr log1p_op(ctx->CreateOperation());
202   TF_RETURN_IF_ERROR(log1p_op->Reset("Log1p", /*raw_device_name=*/nullptr));
203   TF_RETURN_IF_ERROR(MaybeSetOpName(log1p_op.get(), name));
204   TF_RETURN_IF_ERROR(log1p_op->AddInput(inputs[0]));
205 
206   int num_retvals = 1;
207   Status s = log1p_op->Execute(outputs, &num_retvals);
208   return s;
209 }
210 
211 }  // namespace ops
212 }  // namespace tensorflow
213