1"""
2/* Copyright (c) 2022 Amazon
3   Written by Jan Buethe */
4/*
5   Redistribution and use in source and binary forms, with or without
6   modification, are permitted provided that the following conditions
7   are met:
8
9   - Redistributions of source code must retain the above copyright
10   notice, this list of conditions and the following disclaimer.
11
12   - Redistributions in binary form must reproduce the above copyright
13   notice, this list of conditions and the following disclaimer in the
14   documentation and/or other materials provided with the distribution.
15
16   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
19   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
20   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
24   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
25   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27*/
28"""
29
30import os
31import argparse
32import sys
33
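# make the sibling weight-exchange package (wexchange) importable when this
# script is run directly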
sys.path.append(os.path.join(os.path.dirname(__file__), '../weight-exchange'))

parser = argparse.ArgumentParser()

parser.add_argument('checkpoint', type=str, help='rdovae model checkpoint')
parser.add_argument('output_dir', type=str, help='output folder')
parser.add_argument('--format', choices=['C', 'numpy'], help='output format, default: C', default='C')

args = parser.parse_args()

import torch
import numpy as np

from rdovae import RDOVAE
from wexchange.torch import dump_torch_weights
from wexchange.c_export import CWriter, print_vector

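# Write one quantized parameter table as an xml2rfc-style <table> element to the
# XML file opened in c_export below (stats.xml); columns Q0..Q15 are the
# quantization levels, rows are the vector indices k.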
def print_xml(xmlout, val, param, anchor, name):
    xmlout.write(
f"""
            <table anchor="{anchor}_{name}">
                <name>{param} values for {name}</name>
                <thead>
                    <tr><th>k</th><th>Q0</th><th>Q1</th><th>Q2</th><th>Q3</th><th>Q4</th><th>Q5</th><th>Q6</th><th>Q7</th><th>Q8</th><th>Q9</th><th>Q10</th><th>Q11</th><th>Q12</th><th>Q13</th><th>Q14</th><th>Q15</th></tr>
                </thead>
                <tbody>
""")
    for k in range(val.shape[1]):
        xmlout.write(f"        <tr><th>{k}</th>")
        for j in range(val.shape[0]):
            xmlout.write(f"<th>{val[j][k]}</th>")
        xmlout.write("</tr>\n")
    xmlout.write(
f"""
                </tbody>
            </table>
""")
def dump_statistical_model(writer, w, name, xmlout):
    levels = w.shape[0]

    print("printing statistical model")
    quant_scales    = torch.nn.functional.softplus(w[:, 0, :]).numpy()
    dead_zone       = 0.05 * torch.nn.functional.softplus(w[:, 1, :]).numpy()
    r               = torch.sigmoid(w[:, 5 , :]).numpy()
    p0              = torch.sigmoid(w[:, 4 , :]).numpy()
    p0              = 1 - r ** (0.5 + 0.5 * p0)

    scales_norm = 255./256./(1e-15+np.max(quant_scales,axis=0))
    quant_scales = quant_scales*scales_norm
    quant_scales_q8 = np.round(quant_scales * 2**8).astype(np.uint16)
    dead_zone_q8   = np.clip(np.round(dead_zone * 2**8), 0, 255).astype(np.uint16)
    r_q8           = np.clip(np.round(r * 2**8), 0, 255).astype(np.uint8)
    p0_q8          = np.clip(np.round(p0 * 2**8), 0, 255).astype(np.uint16)

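    # keep only dimensions that carry information: r_q8 must be non-zero for at
    # least one level and p0_q8 must stay below 255 for at least one level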
    mask = (np.max(r_q8,axis=0) > 0) * (np.min(p0_q8,axis=0) < 255)
    quant_scales_q8 = quant_scales_q8[:, mask]
    dead_zone_q8 = dead_zone_q8[:, mask]
    r_q8 = r_q8[:, mask]
    p0_q8 = p0_q8[:, mask]
    N = r_q8.shape[-1]

    print_vector(writer.source, quant_scales_q8, f'dred_{name}_quant_scales_q8', dtype='opus_uint8', static=False)
    print_vector(writer.source, dead_zone_q8, f'dred_{name}_dead_zone_q8', dtype='opus_uint8', static=False)
    print_vector(writer.source, r_q8, f'dred_{name}_r_q8', dtype='opus_uint8', static=False)
    print_vector(writer.source, p0_q8, f'dred_{name}_p0_q8', dtype='opus_uint8', static=False)

    print_xml(xmlout, quant_scales_q8, "Scale", "scale", name)
    print_xml(xmlout, dead_zone_q8, "Dead zone", "deadzone", name)
    print_xml(xmlout, r_q8, "Decay (r)", "decay", name)
    print_xml(xmlout, p0_q8, "P(0)", "p0", name)

    writer.header.write(
f"""
extern const opus_uint8 dred_{name}_quant_scales_q8[{levels * N}];
extern const opus_uint8 dred_{name}_dead_zone_q8[{levels * N}];
extern const opus_uint8 dred_{name}_r_q8[{levels * N}];
extern const opus_uint8 dred_{name}_p0_q8[{levels * N}];

"""
    )
    return N, mask, torch.tensor(scales_norm[mask])


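# Export the model as C source/header files for the DRED implementation: separate
# writers for the encoder, the decoder, the statistical model and the shared
# constants header, plus an XML dump of the statistical-model tables (stats.xml).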
def c_export(args, model):

    message = f"Auto generated from checkpoint {os.path.basename(args.checkpoint)}"

    enc_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_enc_data"), message=message, model_struct_name='RDOVAEEnc')
    dec_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_dec_data"), message=message, model_struct_name='RDOVAEDec')
    stats_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_stats_data"), message=message, enable_binary_blob=False)
    constants_writer = CWriter(os.path.join(args.output_dir, "dred_rdovae_constants"), message=message, header_only=True, enable_binary_blob=False)
    xmlout = open("stats.xml", "w")

    # some custom includes
    for writer in [enc_writer, dec_writer]:
        writer.header.write(
f"""
#include "opus_types.h"

#include "dred_rdovae.h"

#include "dred_rdovae_constants.h"

"""
        )

    stats_writer.header.write(
f"""
#include "opus_types.h"

#include "dred_rdovae_constants.h"

"""
    )

    latent_out = model.get_submodule('core_encoder.module.z_dense')
    state_out = model.get_submodule('core_encoder.module.state_dense_2')
    orig_latent_dim = latent_out.weight.shape[0]
    orig_state_dim = state_out.weight.shape[0]
    # statistical model
    qembedding = model.statistical_model.quant_embedding.weight.detach()
    levels = qembedding.shape[0]
    qembedding = torch.reshape(qembedding, (levels, 6, -1))

    latent_dim, latent_mask, latent_scale = dump_statistical_model(stats_writer, qembedding[:, :, :orig_latent_dim], 'latent', xmlout)
    state_dim, state_mask, state_scale = dump_statistical_model(stats_writer, qembedding[:, :, orig_latent_dim:], 'state', xmlout)

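    # prune the latent projection to the kept dimensions, divide out the scale
    # normalization returned by dump_statistical_model, and zero-pad the output
    # to a multiple of 8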
    padded_latent_dim = (latent_dim+7)//8*8
    latent_pad = padded_latent_dim - latent_dim
    w = latent_out.weight[latent_mask,:]
    w = w/latent_scale[:, None]
    w = torch.cat([w, torch.zeros(latent_pad, w.shape[1])], dim=0)
    b = latent_out.bias[latent_mask]
    b = b/latent_scale
    b = torch.cat([b, torch.zeros(latent_pad)], dim=0)
    latent_out.weight = torch.nn.Parameter(w)
    latent_out.bias = torch.nn.Parameter(b)

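    # same treatment for the encoder's state projection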
    padded_state_dim = (state_dim+7)//8*8
    state_pad = padded_state_dim - state_dim
    w = state_out.weight[state_mask,:]
    w = w/state_scale[:, None]
    w = torch.cat([w, torch.zeros(state_pad, w.shape[1])], dim=0)
    b = state_out.bias[state_mask]
    b = b/state_scale
    b = torch.cat([b, torch.zeros(state_pad)], dim=0)
    state_out.weight = torch.nn.Parameter(w)
    state_out.bias = torch.nn.Parameter(b)

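    # prune the matching decoder input layers to the kept dimensions and multiply
    # by the same scale, undoing the normalization applied on the encoder side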
    latent_in = model.get_submodule('core_decoder.module.dense_1')
    state_in = model.get_submodule('core_decoder.module.hidden_init')
    latent_in.weight = torch.nn.Parameter(latent_in.weight[:,latent_mask]*latent_scale)
    state_in.weight = torch.nn.Parameter(state_in.weight[:,state_mask]*state_scale)

    # encoder
    encoder_dense_layers = [
        ('core_encoder.module.dense_1'       , 'enc_dense1',   'TANH', False,),
        ('core_encoder.module.z_dense'       , 'enc_zdense',   'LINEAR', True,),
        ('core_encoder.module.state_dense_1' , 'gdense1'    ,   'TANH', True,),
        ('core_encoder.module.state_dense_2' , 'gdense2'    ,   'TANH', True)
    ]

    for name, export_name, _, quantize in encoder_dense_layers:
        layer = model.get_submodule(name)
        dump_torch_weights(enc_writer, layer, name=export_name, verbose=True, quantize=quantize, scale=None)


    encoder_gru_layers = [
        ('core_encoder.module.gru1'       , 'enc_gru1',   'TANH', True),
        ('core_encoder.module.gru2'       , 'enc_gru2',   'TANH', True),
        ('core_encoder.module.gru3'       , 'enc_gru3',   'TANH', True),
        ('core_encoder.module.gru4'       , 'enc_gru4',   'TANH', True),
        ('core_encoder.module.gru5'       , 'enc_gru5',   'TANH', True),
    ]

    enc_max_rnn_units = max([dump_torch_weights(enc_writer, model.get_submodule(name), export_name, verbose=True, input_sparse=True, quantize=quantize, scale=None, recurrent_scale=None)
                             for name, export_name, _, quantize in encoder_gru_layers])


    encoder_conv_layers = [
        ('core_encoder.module.conv1.conv'       , 'enc_conv1',   'TANH', True),
        ('core_encoder.module.conv2.conv'       , 'enc_conv2',   'TANH', True),
        ('core_encoder.module.conv3.conv'       , 'enc_conv3',   'TANH', True),
        ('core_encoder.module.conv4.conv'       , 'enc_conv4',   'TANH', True),
        ('core_encoder.module.conv5.conv'       , 'enc_conv5',   'TANH', True),
    ]

    enc_max_conv_inputs = max([dump_torch_weights(enc_writer, model.get_submodule(name), export_name, verbose=True, quantize=quantize, scale=None) for name, export_name, _, quantize in encoder_conv_layers])


    del enc_writer

    # decoder
    decoder_dense_layers = [
        ('core_decoder.module.dense_1'      , 'dec_dense1',  'TANH', False),
        ('core_decoder.module.glu1.gate'    , 'dec_glu1',    'TANH', True),
        ('core_decoder.module.glu2.gate'    , 'dec_glu2',    'TANH', True),
        ('core_decoder.module.glu3.gate'    , 'dec_glu3',    'TANH', True),
        ('core_decoder.module.glu4.gate'    , 'dec_glu4',    'TANH', True),
        ('core_decoder.module.glu5.gate'    , 'dec_glu5',    'TANH', True),
        ('core_decoder.module.output'       , 'dec_output',  'LINEAR', True),
        ('core_decoder.module.hidden_init'  , 'dec_hidden_init',        'TANH', False),
        ('core_decoder.module.gru_init'     , 'dec_gru_init','TANH', True),
    ]

    for name, export_name, _, quantize in decoder_dense_layers:
        layer = model.get_submodule(name)
        dump_torch_weights(dec_writer, layer, name=export_name, verbose=True, quantize=quantize, scale=None)


    decoder_gru_layers = [
        ('core_decoder.module.gru1'         , 'dec_gru1',    'TANH', True),
        ('core_decoder.module.gru2'         , 'dec_gru2',    'TANH', True),
        ('core_decoder.module.gru3'         , 'dec_gru3',    'TANH', True),
        ('core_decoder.module.gru4'         , 'dec_gru4',    'TANH', True),
        ('core_decoder.module.gru5'         , 'dec_gru5',    'TANH', True),
    ]

    dec_max_rnn_units = max([dump_torch_weights(dec_writer, model.get_submodule(name), export_name, verbose=True, input_sparse=True, quantize=quantize, scale=None, recurrent_scale=None)
                             for name, export_name, _, quantize in decoder_gru_layers])

    decoder_conv_layers = [
        ('core_decoder.module.conv1.conv'       , 'dec_conv1',   'TANH', True),
        ('core_decoder.module.conv2.conv'       , 'dec_conv2',   'TANH', True),
        ('core_decoder.module.conv3.conv'       , 'dec_conv3',   'TANH', True),
        ('core_decoder.module.conv4.conv'       , 'dec_conv4',   'TANH', True),
        ('core_decoder.module.conv5.conv'       , 'dec_conv5',   'TANH', True),
    ]

    dec_max_conv_inputs = max([dump_torch_weights(dec_writer, model.get_submodule(name), export_name, verbose=True, quantize=quantize, scale=None) for name, export_name, _, quantize in decoder_conv_layers])

    del dec_writer

    del stats_writer

    # constants
    constants_writer.header.write(
f"""
#define DRED_NUM_FEATURES {model.feature_dim}

#define DRED_LATENT_DIM {latent_dim}

#define DRED_STATE_DIM {state_dim}

#define DRED_PADDED_LATENT_DIM {padded_latent_dim}

#define DRED_PADDED_STATE_DIM {padded_state_dim}

#define DRED_NUM_QUANTIZATION_LEVELS {model.quant_levels}

#define DRED_MAX_RNN_NEURONS {max(enc_max_rnn_units, dec_max_rnn_units)}

#define DRED_MAX_CONV_INPUTS {max(enc_max_conv_inputs, dec_max_conv_inputs)}

#define DRED_ENC_MAX_RNN_NEURONS {enc_max_rnn_units}

#define DRED_ENC_MAX_CONV_INPUTS {enc_max_conv_inputs}

#define DRED_DEC_MAX_RNN_NEURONS {dec_max_rnn_units}

"""
    )

    del constants_writer


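# Export each layer via dump_torch_weights, using a fixed mapping from module
# names to weight-exchange file name prefixes under output_dir.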
def numpy_export(args, model):

    exchange_name_to_name = {
        'encoder_stack_layer1_dense'    : 'core_encoder.module.dense_1',
        'encoder_stack_layer3_dense'    : 'core_encoder.module.dense_2',
        'encoder_stack_layer5_dense'    : 'core_encoder.module.dense_3',
        'encoder_stack_layer7_dense'    : 'core_encoder.module.dense_4',
        'encoder_stack_layer8_dense'    : 'core_encoder.module.dense_5',
        'encoder_state_layer1_dense'    : 'core_encoder.module.state_dense_1',
        'encoder_state_layer2_dense'    : 'core_encoder.module.state_dense_2',
        'encoder_stack_layer2_gru'      : 'core_encoder.module.gru_1',
        'encoder_stack_layer4_gru'      : 'core_encoder.module.gru_2',
        'encoder_stack_layer6_gru'      : 'core_encoder.module.gru_3',
        'encoder_stack_layer9_conv'     : 'core_encoder.module.conv1',
        'statistical_model_embedding'   : 'statistical_model.quant_embedding',
        'decoder_state1_dense'          : 'core_decoder.module.gru_1_init',
        'decoder_state2_dense'          : 'core_decoder.module.gru_2_init',
        'decoder_state3_dense'          : 'core_decoder.module.gru_3_init',
        'decoder_stack_layer1_dense'    : 'core_decoder.module.dense_1',
        'decoder_stack_layer3_dense'    : 'core_decoder.module.dense_2',
        'decoder_stack_layer5_dense'    : 'core_decoder.module.dense_3',
        'decoder_stack_layer7_dense'    : 'core_decoder.module.dense_4',
        'decoder_stack_layer8_dense'    : 'core_decoder.module.dense_5',
        'decoder_stack_layer9_dense'    : 'core_decoder.module.output',
        'decoder_stack_layer2_gru'      : 'core_decoder.module.gru_1',
        'decoder_stack_layer4_gru'      : 'core_decoder.module.gru_2',
        'decoder_stack_layer6_gru'      : 'core_decoder.module.gru_3'
    }

    name_to_exchange_name = {value : key for key, value in exchange_name_to_name.items()}

    for name, exchange_name in name_to_exchange_name.items():
        print(f"printing layer {name}...")
        dump_torch_weights(os.path.join(args.output_dir, exchange_name), model.get_submodule(name))


if __name__ == "__main__":

    os.makedirs(args.output_dir, exist_ok=True)

    # load model from checkpoint
    checkpoint = torch.load(args.checkpoint, map_location='cpu')
    model = RDOVAE(*checkpoint['model_args'], **checkpoint['model_kwargs'])
    missing_keys, unmatched_keys = model.load_state_dict(checkpoint['state_dict'], strict=False)
    def _remove_weight_norm(m):
        try:
            torch.nn.utils.remove_weight_norm(m)
        except ValueError:  # this module didn't have weight norm
            return
    model.apply(_remove_weight_norm)

    if len(missing_keys) > 0:
        raise ValueError(f"error: missing keys in state dict: {missing_keys}")

    if len(unmatched_keys) > 0:
        print(f"warning: the following keys were unmatched: {unmatched_keys}")

    if args.format == 'C':
        c_export(args, model)
    elif args.format == 'numpy':
        numpy_export(args, model)
    else:
        raise ValueError(f'error: unknown export format {args.format}')

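# Example invocation (script and checkpoint names are illustrative):
#
#   python export_rdovae_weights.py rdovae_checkpoint.pth exported --format C
#
# This writes the dred_rdovae_* C sources and headers into exported/ and the
# statistical-model tables into stats.xml in the current working directory.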