# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines all the new composite ops used in the mnist example."""

# pylint: disable=g-direct-tensorflow-import
# pylint: disable=missing-function-docstring

import os
import sys
from absl import app

import tensorflow as tf

from tensorflow.compiler.mlir.tfr.python import composite
from tensorflow.compiler.mlir.tfr.python.op_reg_gen import gen_register_op
from tensorflow.compiler.mlir.tfr.python.tfr_gen import tfr_gen_from_module
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import flags

Composite = composite.Composite
FLAGS = flags.FLAGS

flags.DEFINE_string(
    'output', None,
    'Path to write the generated register op file or the TFR MLIR file.')

flags.DEFINE_bool('gen_register_op', True,
                  'If true, generate the register op .cc file; otherwise '
                  'generate the TFR MLIR file.')
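
# Example invocations (the target name and paths below are illustrative):
#   mnist_ops_defs --output=/tmp/mnist_ops.cc --gen_register_op=true
#   mnist_ops_defs --output=/tmp/mnist_ops.mlir --gen_register_op=false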


@Composite(
    'NewConv2D',
    inputs=['input_: T', 'filter_: T', 'bias: T'],
    attrs=[
        'stride_w: int', 'stride_h: int', 'dilation_w: int', 'dilation_h: int',
        'padding: {"SAME", "VALID"}', 'act: {"", "RELU", "RELU6", "TANH"} = ""'
    ],
    derived_attrs=['T: {float, int8}'],
    outputs=['o: T'])
def _composite_conv_add_relu(input_, filter_, bias, stride_w, stride_h,
                             dilation_w, dilation_h, padding, act):
  res = tf.raw_ops.Conv2D(
      input=input_,
      filter=filter_,
      strides=[1, stride_w, stride_h, 1],
      dilations=[1, dilation_w, dilation_h, 1],
      padding=padding)
  res = tf.raw_ops.Add(x=res, y=bias)
  if act == 'RELU':
    return tf.raw_ops.Relu(features=res)
  elif act == 'RELU6':
    return tf.raw_ops.Relu6(features=res)
  elif act == 'TANH':
    return tf.raw_ops.Tanh(x=res)
  else:
    return res
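

# Once the generated registration is compiled into the runtime, NewConv2D is
# callable like any other op; the surrounding example uses a generated Python
# wrapper for this (e.g. gen_mnist_ops.new_conv2d; the wrapper name is
# illustrative here).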


@tf.RegisterGradient('NewConv2D')
def _conv_add_relu_grad(op, grad):
  act = op.get_attr('act')
  y = op.outputs[0]
  if act == 'RELU':
    grad = gen_nn_ops.relu_grad(grad, y)
  elif act == 'RELU6':
    grad = gen_nn_ops.relu6_grad(grad, y)
  elif act == 'TANH':
    y = math_ops.conj(y)
    grad = gen_math_ops.tanh_grad(y, grad)

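  # The bias was broadcast against the conv output in the forward pass, so its
  # gradient is the incoming gradient summed over the broadcast axes, which
  # BroadcastGradientArgs computes from the two shapes.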
  broadcast_shape = tf.shape(y)
  input_value_shape = tf.shape(op.inputs[2])
  _, reduction_axes = tf.raw_ops.BroadcastGradientArgs(
      s0=broadcast_shape, s1=input_value_shape)
  updates_grad_reshaped = tf.reduce_sum(
      grad, axis=reduction_axes, keepdims=True)
  bias_grad = tf.reshape(updates_grad_reshaped, input_value_shape)

  dilations = [1, op.get_attr('dilation_w'), op.get_attr('dilation_h'), 1]
  strides = [1, op.get_attr('stride_w'), op.get_attr('stride_h'), 1]
  padding = op.get_attr('padding')
  shape_0, shape_1 = tf.shape_n([op.inputs[0], op.inputs[1]])
  return [
      tf.compat.v1.nn.conv2d_backprop_input(
          shape_0,
          op.inputs[1],
          grad,
          strides=strides,
          padding=padding,
          dilations=dilations,
          data_format='NHWC'),
      tf.compat.v1.nn.conv2d_backprop_filter(
          op.inputs[0],
          shape_1,
          grad,
          strides=strides,
          padding=padding,
          dilations=dilations,
          data_format='NHWC'), bias_grad
  ]


@Composite(
    'NewFullyConnected',
    inputs=['input_: T', 'filter_: T', 'bias: T'],
    attrs=['act: {"", "RELU", "RELU6", "TANH"} = ""'],
    derived_attrs=['T: {float, int8}'],
    outputs=['o: T'])
def _composite_fully_connected(input_, filter_, bias, act):
  res = tf.raw_ops.MatMul(
      a=input_, b=filter_, transpose_a=False, transpose_b=True)
  res = tf.raw_ops.Add(x=res, y=bias)
  if act == 'RELU':
    return tf.raw_ops.Relu(features=res)
  elif act == 'RELU6':
    return tf.raw_ops.Relu6(features=res)
  elif act == 'TANH':
    return tf.raw_ops.Tanh(x=res)
  else:
    return res


@tf.RegisterGradient('NewFullyConnected')
def _fully_connected_grad(op, grad):
  act = op.get_attr('act')
  y = op.outputs[0]
  if act == 'RELU':
    grad = gen_nn_ops.relu_grad(grad, y)
  elif act == 'RELU6':
    grad = gen_nn_ops.relu6_grad(grad, y)
  elif act == 'TANH':
    y = math_ops.conj(y)
    grad = gen_math_ops.tanh_grad(y, grad)

  broadcast_shape = tf.shape(y)
  input_value_shape = tf.shape(op.inputs[2])
  _, reduction_axes = tf.raw_ops.BroadcastGradientArgs(
      s0=broadcast_shape, s1=input_value_shape)
  updates_grad_reshaped = tf.reduce_sum(
      grad, axis=reduction_axes, keepdims=True)
  bias_grad = tf.reshape(updates_grad_reshaped, input_value_shape)

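  # The forward pass computed y = a @ transpose(b), so grad_a = grad @ b and
  # grad_b = transpose(grad) @ a.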
  a = math_ops.conj(op.inputs[0])
  b = math_ops.conj(op.inputs[1])
  grad_a = gen_math_ops.mat_mul(grad, b)
  grad_b = gen_math_ops.mat_mul(grad, a, transpose_a=True)
  return [grad_a, grad_b, bias_grad]


@Composite(
    'NewMaxPool',
    inputs=['input_: T'],
    attrs=[
        'stride_w: int', 'stride_h: int', 'filter_width: int',
        'filter_height: int', 'padding: {"SAME", "VALID"}'
    ],
    derived_attrs=['T: {float, int8}'],
    outputs=['o: T'])
def _composite_max_pool(input_, stride_w, stride_h, filter_width, filter_height,
                        padding):
  ksize = [1, filter_width, filter_height, 1]
  strides = [1, stride_w, stride_h, 1]
  return tf.raw_ops.MaxPool(
      input=input_, ksize=ksize, strides=strides, padding=padding)


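# MaxPoolGrad needs both the original input and the original output to
# recompute which element of each pooling window was the maximum and route the
# incoming gradient back to that position.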
@tf.RegisterGradient('NewMaxPool')
def _max_pool_grad(op, grad):
  filter_width = op.get_attr('filter_width')
  filter_height = op.get_attr('filter_height')
  stride_w = op.get_attr('stride_w')
  stride_h = op.get_attr('stride_h')
  padding = op.get_attr('padding')
  return tf.raw_ops.MaxPoolGrad(
      orig_input=op.inputs[0],
      orig_output=op.outputs[0],
      grad=grad,
      ksize=[1, filter_width, filter_height, 1],
      strides=[1, stride_w, stride_h, 1],
      padding=padding,
      data_format='NHWC')


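# Emits one of two artifacts, selected by --gen_register_op: a .cc file with
# the op registrations for every '_composite_'-prefixed function above, or the
# corresponding TFR MLIR definitions.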
def main(_):
  if FLAGS.gen_register_op:
    assert FLAGS.output.endswith('.cc')
    generated_code = gen_register_op(sys.modules[__name__], '_composite_')
  else:
    assert FLAGS.output.endswith('.mlir')
    generated_code = tfr_gen_from_module(sys.modules[__name__], '_composite_')

  dirname = os.path.dirname(FLAGS.output)
  if not os.path.exists(dirname):
    os.makedirs(dirname)
  with open(FLAGS.output, 'w') as f:
    f.write(generated_code)


if __name__ == '__main__':
  app.run(main=main)