1"""
2/* Copyright (c) 2022 Amazon
3   Written by Jan Buethe */
4/*
5   Redistribution and use in source and binary forms, with or without
6   modification, are permitted provided that the following conditions
7   are met:
8
9   - Redistributions of source code must retain the above copyright
10   notice, this list of conditions and the following disclaimer.
11
12   - Redistributions in binary form must reproduce the above copyright
13   notice, this list of conditions and the following disclaimer in the
14   documentation and/or other materials provided with the distribution.
15
16   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
17   ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
18   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
19   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER
20   OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21   EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22   PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23   PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
24   LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
25   NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
26   SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27*/
28"""

import os
# hide all GPUs from torch; the weight import runs entirely on the CPU
os.environ['CUDA_VISIBLE_DEVICES'] = ""

import argparse


parser = argparse.ArgumentParser()

parser.add_argument('exchange_folder', type=str, help='exchange folder path')
parser.add_argument('output', type=str, help='path to output model checkpoint')

model_group = parser.add_argument_group(title="model parameters")
model_group.add_argument('--num-features', type=int, help="number of features, default: 20", default=20)
model_group.add_argument('--latent-dim', type=int, help="number of symbols produced by the encoder, default: 80", default=80)
model_group.add_argument('--cond-size', type=int, help="first conditioning size, default: 256", default=256)
model_group.add_argument('--cond-size2', type=int, help="second conditioning size, default: 256", default=256)
model_group.add_argument('--state-dim', type=int, help="dimensionality of the transferred state, default: 24", default=24)
model_group.add_argument('--quant-levels', type=int, help="number of quantization levels, default: 40", default=40)

args = parser.parse_args()

import torch
from rdovae import RDOVAE
from wexchange.torch import load_torch_weights

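# map the layer names used in the exchange folder to the corresponding
# submodule paths inside the PyTorch RDOVAE model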
exchange_name_to_name = {
    'encoder_stack_layer1_dense'    : 'core_encoder.module.dense_1',
    'encoder_stack_layer3_dense'    : 'core_encoder.module.dense_2',
    'encoder_stack_layer5_dense'    : 'core_encoder.module.dense_3',
    'encoder_stack_layer7_dense'    : 'core_encoder.module.dense_4',
    'encoder_stack_layer8_dense'    : 'core_encoder.module.dense_5',
    'encoder_state_layer1_dense'    : 'core_encoder.module.state_dense_1',
    'encoder_state_layer2_dense'    : 'core_encoder.module.state_dense_2',
    'encoder_stack_layer2_gru'      : 'core_encoder.module.gru_1',
    'encoder_stack_layer4_gru'      : 'core_encoder.module.gru_2',
    'encoder_stack_layer6_gru'      : 'core_encoder.module.gru_3',
    'encoder_stack_layer9_conv'     : 'core_encoder.module.conv1',
    'statistical_model_embedding'   : 'statistical_model.quant_embedding',
    'decoder_state1_dense'          : 'core_decoder.module.gru_1_init',
    'decoder_state2_dense'          : 'core_decoder.module.gru_2_init',
    'decoder_state3_dense'          : 'core_decoder.module.gru_3_init',
    'decoder_stack_layer1_dense'    : 'core_decoder.module.dense_1',
    'decoder_stack_layer3_dense'    : 'core_decoder.module.dense_2',
    'decoder_stack_layer5_dense'    : 'core_decoder.module.dense_3',
    'decoder_stack_layer7_dense'    : 'core_decoder.module.dense_4',
    'decoder_stack_layer8_dense'    : 'core_decoder.module.dense_5',
    'decoder_stack_layer9_dense'    : 'core_decoder.module.output',
    'decoder_stack_layer2_gru'      : 'core_decoder.module.gru_1',
    'decoder_stack_layer4_gru'      : 'core_decoder.module.gru_2',
    'decoder_stack_layer6_gru'      : 'core_decoder.module.gru_3'
}

if __name__ == "__main__":
    checkpoint = dict()

    # model parameters from the command line
    num_features    = args.num_features
    latent_dim      = args.latent_dim
    quant_levels    = args.quant_levels
    cond_size       = args.cond_size
    cond_size2      = args.cond_size2
    state_dim       = args.state_dim


    # instantiate the model; storing its constructor arguments in the
    # checkpoint makes the checkpoint self-describing
    checkpoint['model_args']    = (num_features, latent_dim, quant_levels, cond_size, cond_size2)
    checkpoint['model_kwargs']  = {'state_dim': state_dim}
    model = RDOVAE(*checkpoint['model_args'], **checkpoint['model_kwargs'])

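    # layer names present in the exchange folder, grouped by layer type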
    dense_layer_names = [
        'encoder_stack_layer1_dense',
        'encoder_stack_layer3_dense',
        'encoder_stack_layer5_dense',
        'encoder_stack_layer7_dense',
        'encoder_stack_layer8_dense',
        'encoder_state_layer1_dense',
        'encoder_state_layer2_dense',
        'decoder_state1_dense',
        'decoder_state2_dense',
        'decoder_state3_dense',
        'decoder_stack_layer1_dense',
        'decoder_stack_layer3_dense',
        'decoder_stack_layer5_dense',
        'decoder_stack_layer7_dense',
        'decoder_stack_layer8_dense',
        'decoder_stack_layer9_dense'
    ]

    gru_layer_names = [
        'encoder_stack_layer2_gru',
        'encoder_stack_layer4_gru',
        'encoder_stack_layer6_gru',
        'decoder_stack_layer2_gru',
        'decoder_stack_layer4_gru',
        'decoder_stack_layer6_gru'
    ]

    conv1d_layer_names = [
        'encoder_stack_layer9_conv'
    ]

    embedding_layer_names = [
        'statistical_model_embedding'
    ]

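    # copy the weights of every listed layer from the exchange folder into
    # the matching submodule of the freshly constructed model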
    for name in dense_layer_names + gru_layer_names + conv1d_layer_names + embedding_layer_names:
        print(f"loading weights for layer {exchange_name_to_name[name]}")
        layer = model.get_submodule(exchange_name_to_name[name])
        load_torch_weights(os.path.join(args.exchange_folder, name), layer)

    checkpoint['state_dict'] = model.state_dict()

    torch.save(checkpoint, args.output)
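
# Example usage (hypothetical script and path names):
#
#   python import_rdovae_weights.py ./exchange ./rdovae.pth
#
# Because the checkpoint stores 'model_args' and 'model_kwargs' alongside the
# weights, it can be restored without repeating the hyperparameters. A minimal
# sketch, assuming the RDOVAE constructor signature is unchanged:
#
#   checkpoint = torch.load('rdovae.pth', map_location='cpu')
#   model = RDOVAE(*checkpoint['model_args'], **checkpoint['model_kwargs'])
#   model.load_state_dict(checkpoint['state_dict'])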