"""
/* Copyright (c) 2022 Amazon
   Written by Jan Buethe */
/*
   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions
   are met:

   - Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.

   - Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
"""

# Export the weights of a trained RDOVAE checkpoint, either as C sources and
# headers (default, --format C) or in the numpy weight-exchange format.

import os
import argparse
import sys

# make the weight-exchange helpers (wexchange) importable
sys.path.append(os.path.join(os.path.dirname(__file__), '../weight-exchange'))

parser = argparse.ArgumentParser()

parser.add_argument('checkpoint', type=str, help='rdovae model checkpoint')
parser.add_argument('output_dir', type=str, help='output folder')
parser.add_argument('--format', choices=['C', 'numpy'], help='output format, default: C', default='C')

args = parser.parse_args()

# Heavy and project-local imports come after argument parsing (presumably so
# that --help and usage errors do not have to load torch).
import torch
import numpy as np

from rdovae import RDOVAE
from wexchange.torch import dump_torch_weights
from wexchange.c_export import CWriter, print_vector


def print_xml(xmlout, val, param, anchor, name):
    """Write one statistical-model parameter table as XML to ``xmlout``.

    ``val`` is indexed as ``val[level][k]``: every row of the emitted table
    is one dimension ``k`` and every column one quantization level
    ``Q<level>``. The column header is generated from ``val.shape[0]``, so it
    always matches the number of levels actually present.

    Args:
        xmlout: writable text file.
        val: 2-D array of quantized values, shape (levels, dims).
        param: human-readable parameter name used in the table caption.
        anchor: anchor prefix; the table anchor is ``f"{anchor}_{name}"``.
        name: model part the table belongs to (e.g. 'latent' or 'state').
    """
    num_levels, num_dims = val.shape[0], val.shape[1]
    level_headers = "".join(f"<th>Q{j}</th>" for j in range(num_levels))
    xmlout.write(
f"""
        <table anchor="{anchor}_{name}">
          <name>{param} values for {name}</name>
          <thead>
            <tr><th>k</th>{level_headers}</tr>
          </thead>
          <tbody>
""")
    for k in range(num_dims):
        cells = "".join(f"<th>{val[j][k]}</th>" for j in range(num_levels))
        xmlout.write(f"            <tr><th>{k}</th>{cells}</tr>\n")
    xmlout.write(
f"""
          </tbody>
        </table>
""")


def dump_statistical_model(writer, w, name, xmlout):
    """Quantize and export the statistical model for one group of dimensions.

    Args:
        writer: CWriter; the quantized tables go to ``writer.source`` and the
            matching ``extern`` declarations to ``writer.header``.
        w: detached slice of the quantizer embedding, indexed as
            ``w[level, param, dim]``. Parameter index 0 yields the scale, 1 the
            dead zone, 5 the decay ``r`` and 4 enters the P(0) computation.
        name: suffix used in the exported C symbol names and XML anchors.
        xmlout: writable text file that receives XML copies of the tables.

    Returns:
        ``(N, mask, scale)``: the number of kept dimensions, the boolean numpy
        mask over the input dimensions selecting them, and a torch tensor with
        the per-kept-dimension scale normalization factors.
    """
    levels = w.shape[0]

    print("printing statistical model")
    quant_scales = torch.nn.functional.softplus(w[:, 0, :]).numpy()
    dead_zone = 0.05 * torch.nn.functional.softplus(w[:, 1, :]).numpy()
    r = torch.sigmoid(w[:, 5, :]).numpy()
    p0 = torch.sigmoid(w[:, 4, :]).numpy()
    p0 = 1 - r ** (0.5 + 0.5 * p0)

    # normalize each dimension so its largest scale maps to 255/256, which
    # keeps the Q8 representation within 0..255
    scales_norm = 255. / 256. / (1e-15 + np.max(quant_scales, axis=0))
    quant_scales = quant_scales * scales_norm
    quant_scales_q8 = np.round(quant_scales * 2**8).astype(np.uint16)
    dead_zone_q8 = np.clip(np.round(dead_zone * 2**8), 0, 255).astype(np.uint16)
    r_q8 = np.clip(np.round(r * 2**8), 0, 255).astype(np.uint8)
    p0_q8 = np.clip(np.round(p0 * 2**8), 0, 255).astype(np.uint16)

    # Keep only dimensions where some level has a nonzero quantized decay and
    # some level has a quantized P(0) below 255; the others are dropped
    # (presumably they would always decode to zero -- confirm).
    mask = (np.max(r_q8, axis=0) > 0) & (np.min(p0_q8, axis=0) < 255)
    quant_scales_q8 = quant_scales_q8[:, mask]
    dead_zone_q8 = dead_zone_q8[:, mask]
    r_q8 = r_q8[:, mask]
    p0_q8 = p0_q8[:, mask]
    N = r_q8.shape[-1]

    print_vector(writer.source, quant_scales_q8, f'dred_{name}_quant_scales_q8', dtype='opus_uint8', static=False)
    print_vector(writer.source, dead_zone_q8, f'dred_{name}_dead_zone_q8', dtype='opus_uint8', static=False)
    print_vector(writer.source, r_q8, f'dred_{name}_r_q8', dtype='opus_uint8', static=False)
    print_vector(writer.source, p0_q8, f'dred_{name}_p0_q8', dtype='opus_uint8', static=False)

    print_xml(xmlout, quant_scales_q8, "Scale", "scale", name)
    print_xml(xmlout, dead_zone_q8, "Dead zone", "deadzone", name)
    print_xml(xmlout, r_q8, "Decay (r)", "decay", name)
    print_xml(xmlout, p0_q8, "P(0)", "p0", name)

    writer.header.write(
f"""
extern const opus_uint8 dred_{name}_quant_scales_q8[{levels * N}];
extern const opus_uint8 dred_{name}_dead_zone_q8[{levels * N}];
extern const opus_uint8 dred_{name}_r_q8[{levels * N}];
extern const opus_uint8 dred_{name}_p0_q8[{levels * N}];

"""
    )
    return N, mask, torch.tensor(scales_norm[mask])


def _prune_encoder_output(layer, mask, scale, dim):
    """Restrict an encoder output layer to the kept dimensions.

    Output rows not selected by ``mask`` are dropped, the remaining weight
    rows and bias entries are divided by ``scale``, and zero rows are
    appended so the output size becomes a multiple of 8.

    Returns:
        The padded output size.
    """
    padded_dim = (dim + 7) // 8 * 8
    pad = padded_dim - dim
    keep = torch.as_tensor(mask)
    with torch.no_grad():
        weight = layer.weight[keep, :] / scale[:, None]
        weight = torch.cat([weight, torch.zeros(pad, weight.shape[1])], dim=0)
        bias = layer.bias[keep] / scale
        bias = torch.cat([bias, torch.zeros(pad)], dim=0)
    layer.weight = torch.nn.Parameter(weight)
    layer.bias = torch.nn.Parameter(bias)
    return padded_dim


def _prune_decoder_input(layer, mask, scale):
    """Restrict a decoder input layer to the kept dimensions.

    Input columns not selected by ``mask`` are dropped and the remaining
    columns are multiplied by ``scale``, undoing the division applied to the
    matching encoder output layer.
    """
    keep = torch.as_tensor(mask)
    with torch.no_grad():
        weight = layer.weight[:, keep] * scale
    layer.weight = torch.nn.Parameter(weight)


def c_export(args, model):
    """Export ``model`` as C sources/headers into ``args.output_dir``.

    Writes dred_rdovae_enc_data, dred_rdovae_dec_data, dred_rdovae_stats_data
    and the header-only dred_rdovae_constants, plus ``stats.xml`` in the
    current working directory. Note that this prunes and rescales the
    encoder output / decoder input layers of ``model`` in place.
    """
    message = f"Auto generated from checkpoint {os.path.basename(args.checkpoint)}"

    enc_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_enc_data"), message=message, model_struct_name='RDOVAEEnc')
    dec_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_dec_data"), message=message, model_struct_name='RDOVAEDec')
    stats_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_stats_data"), message=message, enable_binary_blob=False)
    constants_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_constants"), message=message, header_only=True, enable_binary_blob=False)

    # some custom includes; written without a loop variable so no extra
    # reference to a writer outlives the explicit `del` statements below
    enc_dec_includes = """
#include "opus_types.h"

#include "dred_rdovae.h"

#include "dred_rdovae_constants.h"

"""
    enc_writer.header.write(enc_dec_includes)
    dec_writer.header.write(enc_dec_includes)

    stats_writer.header.write(
"""
#include "opus_types.h"

#include "dred_rdovae_constants.h"

"""
    )

    latent_out = model.get_submodule('core_encoder.module.z_dense')
    state_out = model.get_submodule('core_encoder.module.state_dense_2')
    orig_latent_dim = latent_out.weight.shape[0]
    orig_state_dim = state_out.weight.shape[0]

    # statistical model: the embedding holds 6 parameters per level and
    # dimension, latent dimensions first, followed by state dimensions
    qembedding = model.statistical_model.quant_embedding.weight.detach()
    levels = qembedding.shape[0]
    qembedding = torch.reshape(qembedding, (levels, 6, -1))

    with open("stats.xml", "w", encoding="utf-8") as xmlout:
        latent_dim, latent_mask, latent_scale = dump_statistical_model(stats_writer, qembedding[:, :, :orig_latent_dim], 'latent', xmlout)
        state_dim, state_mask, state_scale = dump_statistical_model(stats_writer, qembedding[:, :, orig_latent_dim:], 'state', xmlout)

    # drop pruned dimensions from the encoder outputs / decoder inputs and
    # fold the scale normalization into the weights
    padded_latent_dim = _prune_encoder_output(latent_out, latent_mask, latent_scale, latent_dim)
    padded_state_dim = _prune_encoder_output(state_out, state_mask, state_scale, state_dim)

    latent_in = model.get_submodule('core_decoder.module.dense_1')
    state_in = model.get_submodule('core_decoder.module.hidden_init')
    _prune_decoder_input(latent_in, latent_mask, latent_scale)
    _prune_decoder_input(state_in, state_mask, state_scale)

    # encoder; tuples are (module name, export name, activation, quantize) --
    # the activation entry is informational only and not used here
    encoder_dense_layers = [
        ('core_encoder.module.dense_1'       , 'enc_dense1',   'TANH', False,),
        ('core_encoder.module.z_dense'       , 'enc_zdense',   'LINEAR', True,),
        ('core_encoder.module.state_dense_1' , 'gdense1'    ,   'TANH', True,),
        ('core_encoder.module.state_dense_2' , 'gdense2'    ,   'TANH', True)
    ]

    for name, export_name, _, quantize in encoder_dense_layers:
        layer = model.get_submodule(name)
        dump_torch_weights(enc_writer, layer, name=export_name, verbose=True, quantize=quantize, scale=None)

    encoder_gru_layers = [
        ('core_encoder.module.gru1'       , 'enc_gru1',   'TANH', True),
        ('core_encoder.module.gru2'       , 'enc_gru2',   'TANH', True),
        ('core_encoder.module.gru3'       , 'enc_gru3',   'TANH', True),
        ('core_encoder.module.gru4'       , 'enc_gru4',   'TANH', True),
        ('core_encoder.module.gru5'       , 'enc_gru5',   'TANH', True),
    ]

    enc_max_rnn_units = max([dump_torch_weights(enc_writer, model.get_submodule(name), export_name, verbose=True, input_sparse=True, quantize=quantize, scale=None, recurrent_scale=None)
                             for name, export_name, _, quantize in encoder_gru_layers])

    encoder_conv_layers = [
        ('core_encoder.module.conv1.conv'       , 'enc_conv1',   'TANH', True),
        ('core_encoder.module.conv2.conv'       , 'enc_conv2',   'TANH', True),
        ('core_encoder.module.conv3.conv'       , 'enc_conv3',   'TANH', True),
        ('core_encoder.module.conv4.conv'       , 'enc_conv4',   'TANH', True),
        ('core_encoder.module.conv5.conv'       , 'enc_conv5',   'TANH', True),
    ]

    enc_max_conv_inputs = max([dump_torch_weights(enc_writer, model.get_submodule(name), export_name, verbose=True, quantize=quantize, scale=None)
                               for name, export_name, _, quantize in encoder_conv_layers])

    # drop the writer (presumably CWriter finalizes its files on destruction
    # -- confirm in wexchange)
    del enc_writer

    # decoder
    decoder_dense_layers = [
        ('core_decoder.module.dense_1'      , 'dec_dense1',  'TANH', False),
        ('core_decoder.module.glu1.gate'    , 'dec_glu1',    'TANH', True),
        ('core_decoder.module.glu2.gate'    , 'dec_glu2',    'TANH', True),
        ('core_decoder.module.glu3.gate'    , 'dec_glu3',    'TANH', True),
        ('core_decoder.module.glu4.gate'    , 'dec_glu4',    'TANH', True),
        ('core_decoder.module.glu5.gate'    , 'dec_glu5',    'TANH', True),
        ('core_decoder.module.output'       , 'dec_output',  'LINEAR', True),
        ('core_decoder.module.hidden_init'  , 'dec_hidden_init', 'TANH', False),
        ('core_decoder.module.gru_init'     , 'dec_gru_init', 'TANH', True),
    ]

    for name, export_name, _, quantize in decoder_dense_layers:
        layer = model.get_submodule(name)
        dump_torch_weights(dec_writer, layer, name=export_name, verbose=True, quantize=quantize, scale=None)

    decoder_gru_layers = [
        ('core_decoder.module.gru1'         , 'dec_gru1',    'TANH', True),
        ('core_decoder.module.gru2'         , 'dec_gru2',    'TANH', True),
        ('core_decoder.module.gru3'         , 'dec_gru3',    'TANH', True),
        ('core_decoder.module.gru4'         , 'dec_gru4',    'TANH', True),
        ('core_decoder.module.gru5'         , 'dec_gru5',    'TANH', True),
    ]

    dec_max_rnn_units = max([dump_torch_weights(dec_writer, model.get_submodule(name), export_name, verbose=True, input_sparse=True, quantize=quantize, scale=None, recurrent_scale=None)
                             for name, export_name, _, quantize in decoder_gru_layers])

    decoder_conv_layers = [
        ('core_decoder.module.conv1.conv'       , 'dec_conv1',   'TANH', True),
        ('core_decoder.module.conv2.conv'       , 'dec_conv2',   'TANH', True),
        ('core_decoder.module.conv3.conv'       , 'dec_conv3',   'TANH', True),
        ('core_decoder.module.conv4.conv'       , 'dec_conv4',   'TANH', True),
        ('core_decoder.module.conv5.conv'       , 'dec_conv5',   'TANH', True),
    ]

    dec_max_conv_inputs = max([dump_torch_weights(dec_writer, model.get_submodule(name), export_name, verbose=True, quantize=quantize, scale=None)
                               for name, export_name, _, quantize in decoder_conv_layers])

    del dec_writer

    del stats_writer

    # constants; DRED_ENC_MAX_RNN_NEURONS previously used enc_max_conv_inputs
    # (copy-paste slip) -- it now reports the encoder's largest GRU, matching
    # DRED_MAX_RNN_NEURONS and DRED_DEC_MAX_RNN_NEURONS
    constants_writer.header.write(
f"""
#define DRED_NUM_FEATURES {model.feature_dim}

#define DRED_LATENT_DIM {latent_dim}

#define DRED_STATE_DIM {state_dim}

#define DRED_PADDED_LATENT_DIM {padded_latent_dim}

#define DRED_PADDED_STATE_DIM {padded_state_dim}

#define DRED_NUM_QUANTIZATION_LEVELS {model.quant_levels}

#define DRED_MAX_RNN_NEURONS {max(enc_max_rnn_units, dec_max_rnn_units)}

#define DRED_MAX_CONV_INPUTS {max(enc_max_conv_inputs, dec_max_conv_inputs)}

#define DRED_ENC_MAX_RNN_NEURONS {enc_max_rnn_units}

#define DRED_ENC_MAX_CONV_INPUTS {enc_max_conv_inputs}

#define DRED_DEC_MAX_RNN_NEURONS {dec_max_rnn_units}

"""
    )

    del constants_writer


def numpy_export(args, model):
    """Export selected layers of ``model`` in the numpy weight-exchange format.

    Each entry maps an exchange name (the output path below
    ``args.output_dir``) to the submodule whose weights are dumped there.

    NOTE(review): several module names here (e.g. 'core_encoder.module.dense_2',
    'core_encoder.module.gru_1') differ from the ones c_export uses
    ('core_encoder.module.gru1', '...conv1.conv'); confirm they still exist
    in RDOVAE, otherwise get_submodule will raise.
    """
    exchange_name_to_name = {
        'encoder_stack_layer1_dense'    : 'core_encoder.module.dense_1',
        'encoder_stack_layer3_dense'    : 'core_encoder.module.dense_2',
        'encoder_stack_layer5_dense'    : 'core_encoder.module.dense_3',
        'encoder_stack_layer7_dense'    : 'core_encoder.module.dense_4',
        'encoder_stack_layer8_dense'    : 'core_encoder.module.dense_5',
        'encoder_state_layer1_dense'    : 'core_encoder.module.state_dense_1',
        'encoder_state_layer2_dense'    : 'core_encoder.module.state_dense_2',
        'encoder_stack_layer2_gru'      : 'core_encoder.module.gru_1',
        'encoder_stack_layer4_gru'      : 'core_encoder.module.gru_2',
        'encoder_stack_layer6_gru'      : 'core_encoder.module.gru_3',
        'encoder_stack_layer9_conv'     : 'core_encoder.module.conv1',
        'statistical_model_embedding'   : 'statistical_model.quant_embedding',
        'decoder_state1_dense'          : 'core_decoder.module.gru_1_init',
        'decoder_state2_dense'          : 'core_decoder.module.gru_2_init',
        'decoder_state3_dense'          : 'core_decoder.module.gru_3_init',
        'decoder_stack_layer1_dense'    : 'core_decoder.module.dense_1',
        'decoder_stack_layer3_dense'    : 'core_decoder.module.dense_2',
        'decoder_stack_layer5_dense'    : 'core_decoder.module.dense_3',
        'decoder_stack_layer7_dense'    : 'core_decoder.module.dense_4',
        'decoder_stack_layer8_dense'    : 'core_decoder.module.dense_5',
        'decoder_stack_layer9_dense'    : 'core_decoder.module.output',
        'decoder_stack_layer2_gru'      : 'core_decoder.module.gru_1',
        'decoder_stack_layer4_gru'      : 'core_decoder.module.gru_2',
        'decoder_stack_layer6_gru'      : 'core_decoder.module.gru_3'
    }

    # module names are unique, so iterating the mapping directly visits the
    # same (name, exchange_name) pairs in the same order as inverting it
    for exchange_name, name in exchange_name_to_name.items():
        print(f"printing layer {name}...")
        dump_torch_weights(os.path.join(args.output_dir, exchange_name), model.get_submodule(name))


if __name__ == "__main__":

    os.makedirs(args.output_dir, exist_ok=True)

    # load model from checkpoint
    # NOTE: torch.load unpickles the file -- only use trusted checkpoints.
    checkpoint = torch.load(args.checkpoint, map_location='cpu')
    model = RDOVAE(*checkpoint['model_args'], **checkpoint['model_kwargs'])
    missing_keys, unmatched_keys = model.load_state_dict(checkpoint['state_dict'], strict=False)

    if len(missing_keys) > 0:
        raise ValueError(f"error: missing keys in state dict: {missing_keys}")

    if len(unmatched_keys) > 0:
        print(f"warning: the following keys were unmatched {unmatched_keys}")

    def _remove_weight_norm(m):
        """Fold weight normalization back into ``m.weight`` where present."""
        try:
            torch.nn.utils.remove_weight_norm(m)
        except ValueError:  # this module didn't have weight norm
            return
    model.apply(_remove_weight_norm)

    if args.format == 'C':
        c_export(args, model)
    elif args.format == 'numpy':
        numpy_export(args, model)
    else:
        raise ValueError(f'error: unknown export format {args.format}')