# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines all the new composite ops used in the mnist example."""

# pylint: disable=g-direct-tensorflow-import
# pylint: disable=missing-function-docstring

import os
import sys
from absl import app

import tensorflow as tf

from tensorflow.compiler.mlir.tfr.python import composite
from tensorflow.compiler.mlir.tfr.python.op_reg_gen import gen_register_op
from tensorflow.compiler.mlir.tfr.python.tfr_gen import tfr_gen_from_module
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_nn_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import flags

Composite = composite.Composite
FLAGS = flags.FLAGS

flags.DEFINE_string(
    'output', None,
    'Path to write the generated register op file and MLIR file.')

flags.DEFINE_bool('gen_register_op', True,
                  'Generate register op cc file or tfr mlir file.')


# The `_composite_*` functions below are collected by name prefix in `main`
# and handed to the code generators, so they are documented with `#` comments
# above the decorator and their bodies are left untouched.
#
# NewConv2D: Conv2D, followed by Add of `bias`, followed by the activation
# selected by `act` (no activation when `act` is "").
#
# NOTE(review): `strides` and `dilations` are built as [1, *_w, *_h, 1]. With
# the NHWC layout used by the gradient below, the first spatial slot is
# normally height -- confirm the width/height ordering is intended. The
# gradient function uses the same ordering, so the two stay consistent.
@Composite(
    'NewConv2D',
    inputs=['input_: T', 'filter_: T', 'bias: T'],
    attrs=[
        'stride_w: int', 'stride_h: int', 'dilation_w: int', 'dilation_h: int',
        'padding: {"SAME", "VALID"}', 'act: {"", "RELU", "RELU6", "TANH"} = ""'
    ],
    derived_attrs=['T: {float, int8}'],
    outputs=['o: T'])
def _composite_conv_add_relu(input_, filter_, bias, stride_w, stride_h,
                             dilation_w, dilation_h, padding, act):
  res = tf.raw_ops.Conv2D(
      input=input_,
      filter=filter_,
      strides=[1, stride_w, stride_h, 1],
      dilations=[1, dilation_w, dilation_h, 1],
      padding=padding)
  res = tf.raw_ops.Add(x=res, y=bias)
  if act == 'RELU':
    return tf.raw_ops.Relu(features=res)
  elif act == 'RELU6':
    return tf.raw_ops.Relu6(features=res)
  elif act == 'TANH':
    return tf.raw_ops.Tanh(x=res)
  else:
    return res


def _activation_grad(op, grad):
  """Back-propagates `grad` through the activation named by the `act` attr.

  Args:
    op: The forward operation; must carry an `act` attr and a single output.
    grad: Gradient with respect to `op.outputs[0]`.

  Returns:
    The gradient with respect to the pre-activation value. `grad` is returned
    unchanged when `act` is not one of RELU, RELU6 or TANH.
  """
  act = op.get_attr('act')
  # String attrs are returned as `bytes`, which never compare equal to `str`
  # in Python 3. Without this decode every branch below would be skipped and
  # the activation gradient silently dropped.
  if isinstance(act, bytes):
    act = act.decode('utf-8')

  y = op.outputs[0]
  if act == 'RELU':
    return gen_nn_ops.relu_grad(grad, y)
  if act == 'RELU6':
    return gen_nn_ops.relu6_grad(grad, y)
  if act == 'TANH':
    return gen_math_ops.tanh_grad(math_ops.conj(y), grad)
  return grad


def _bias_grad(op, grad):
  """Reduces `grad` over the broadcast axes so it matches the bias shape.

  Args:
    op: The forward operation; `op.inputs[2]` is the bias that was added to a
      value shaped like `op.outputs[0]`.
    grad: Gradient with respect to the result of the bias addition.

  Returns:
    The gradient with respect to the bias, reshaped to the bias shape.
  """
  output_shape = tf.shape(op.outputs[0])
  bias_shape = tf.shape(op.inputs[2])
  # Second result: the axes along which the bias was broadcast.
  _, reduction_axes = tf.raw_ops.BroadcastGradientArgs(
      s0=output_shape, s1=bias_shape)
  summed = tf.reduce_sum(grad, axis=reduction_axes, keepdims=True)
  return tf.reshape(summed, bias_shape)


@tf.RegisterGradient('NewConv2D')
def _conv_add_relu_grad(op, grad):
  """Gradient function registered for the `NewConv2D` composite op.

  Args:
    op: The `NewConv2D` operation being differentiated.
    grad: Gradient with respect to `op.outputs[0]`.

  Returns:
    A list `[input_grad, filter_grad, bias_grad]`, one entry per op input.
  """
  grad = _activation_grad(op, grad)
  bias_grad = _bias_grad(op, grad)

  # Same attr ordering as the forward Conv2D call in the composite definition.
  dilations = [1, op.get_attr('dilation_w'), op.get_attr('dilation_h'), 1]
  strides = [1, op.get_attr('stride_w'), op.get_attr('stride_h'), 1]
  padding = op.get_attr('padding')
  input_shape, filter_shape = tf.shape_n([op.inputs[0], op.inputs[1]])

  input_grad = tf.compat.v1.nn.conv2d_backprop_input(
      input_shape,
      op.inputs[1],
      grad,
      strides=strides,
      padding=padding,
      dilations=dilations,
      data_format='NHWC')
  filter_grad = tf.compat.v1.nn.conv2d_backprop_filter(
      op.inputs[0],
      filter_shape,
      grad,
      strides=strides,
      padding=padding,
      dilations=dilations,
      data_format='NHWC')
  return [input_grad, filter_grad, bias_grad]


# NewFullyConnected: MatMul of `input_` with the transposed `filter_`,
# followed by Add of `bias`, followed by the activation selected by `act`.
@Composite(
    'NewFullyConnected',
    inputs=['input_: T', 'filter_: T', 'bias: T'],
    attrs=['act: {"", "RELU", "RELU6", "TANH"} = ""'],
    derived_attrs=['T: {float, int8}'],
    outputs=['o: T'])
def _composite_fully_connected(input_, filter_, bias, act):
  res = tf.raw_ops.MatMul(
      a=input_, b=filter_, transpose_a=False, transpose_b=True)
  res = tf.raw_ops.Add(x=res, y=bias)
  if act == 'RELU':
    return tf.raw_ops.Relu(features=res)
  elif act == 'RELU6':
    return tf.raw_ops.Relu6(features=res)
  elif act == 'TANH':
    return tf.raw_ops.Tanh(x=res)
  else:
    return res


@tf.RegisterGradient('NewFullyConnected')
def _fully_connected_grad(op, grad):
  """Gradient function registered for the `NewFullyConnected` composite op.

  Args:
    op: The `NewFullyConnected` operation being differentiated.
    grad: Gradient with respect to `op.outputs[0]`.

  Returns:
    A list `[input_grad, filter_grad, bias_grad]`, one entry per op input.
  """
  grad = _activation_grad(op, grad)
  bias_grad = _bias_grad(op, grad)

  # The forward pass computes input_ x transpose(filter_), hence:
  #   d/d input_  = grad x filter_
  #   d/d filter_ = transpose(grad) x input_
  a = math_ops.conj(op.inputs[0])
  b = math_ops.conj(op.inputs[1])
  grad_a = gen_math_ops.mat_mul(grad, b)
  grad_b = gen_math_ops.mat_mul(grad, a, transpose_a=True)
  return [grad_a, grad_b, bias_grad]


# NewMaxPool: thin wrapper around the MaxPool raw op.
#
# NOTE(review): `ksize` and `strides` are built as [1, width, height, 1]. With
# the NHWC layout used by the gradient below, the first spatial slot is
# normally height -- confirm the ordering is intended. The gradient function
# uses the same ordering, so the two stay consistent.
@Composite(
    'NewMaxPool',
    inputs=['input_: T'],
    attrs=[
        'stride_w: int', 'stride_h: int', 'filter_width: int',
        'filter_height: int', 'padding: {"SAME", "VALID"}'
    ],
    derived_attrs=['T: {float, int8}'],
    outputs=['o: T'])
def _composite_max_pool(input_, stride_w, stride_h, filter_width, filter_height,
                        padding):
  ksize = [1, filter_width, filter_height, 1]
  strides = [1, stride_w, stride_h, 1]
  return tf.raw_ops.MaxPool(
      input=input_, ksize=ksize, strides=strides, padding=padding)


@tf.RegisterGradient('NewMaxPool')
def _max_pool_grad(op, grad):
  """Gradient function registered for the `NewMaxPool` composite op.

  Args:
    op: The `NewMaxPool` operation being differentiated.
    grad: Gradient with respect to `op.outputs[0]`.

  Returns:
    The gradient with respect to the op's single input.
  """
  # Same attr ordering as the forward MaxPool call in the composite definition.
  ksize = [1, op.get_attr('filter_width'), op.get_attr('filter_height'), 1]
  strides = [1, op.get_attr('stride_w'), op.get_attr('stride_h'), 1]
  return tf.raw_ops.MaxPoolGrad(
      orig_input=op.inputs[0],
      orig_output=op.outputs[0],
      grad=grad,
      ksize=ksize,
      strides=strides,
      padding=op.get_attr('padding'),
      data_format='NHWC')


def main(_):
  """Generates the op registration (.cc) or TFR MLIR (.mlir) file.

  Reads the `--output` and `--gen_register_op` flags, runs the matching
  generator over every `_composite_*` function in this module and writes the
  result to `--output`, creating the parent directory when needed.

  Raises:
    app.UsageError: If `--output` is missing or its extension does not match
      the selected generator.
  """
  output = FLAGS.output
  if not output:
    raise app.UsageError('--output must be specified.')

  expected_suffix = '.cc' if FLAGS.gen_register_op else '.mlir'
  # Explicit check rather than `assert`, which is stripped under `python -O`.
  if not output.endswith(expected_suffix):
    raise app.UsageError(
        '--output must end with "%s" when --gen_register_op=%s, got: %s' %
        (expected_suffix, FLAGS.gen_register_op, output))

  module = sys.modules[__name__]
  if FLAGS.gen_register_op:
    generated_code = gen_register_op(module, '_composite_')
  else:
    generated_code = tfr_gen_from_module(module, '_composite_')

  # `dirname` is empty for a bare file name; os.makedirs('') would raise.
  dirname = os.path.dirname(output)
  if dirname:
    os.makedirs(dirname, exist_ok=True)
  with open(output, 'w') as f:
    f.write(generated_code)


if __name__ == '__main__':
  app.run(main=main)