• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1#!/usr/bin/env python
2"""Extracts trainable parameters from Caffe models and stores them in numpy arrays.
3Usage
4    python caffe_data_extractor -m path_to_caffe_model_file -n path_to_caffe_netlist
5
6Saves each variable to a {variable_name}.npy binary file.
7
8Tested with Caffe 1.0 on Python 2.7
9"""
10import argparse
11import caffe
12import os
13import numpy as np
14
15
# Suffixes for a layer's first two parameter blobs: blob 0 is exported as the
# weights ("_w") and blob 1 as the bias ("_b"). Any further blobs are skipped.
BLOB_SUFFIXES = ("_w", "_b")


def sanitize_name(name):
    """Return *name* with every path separator replaced by ``_``.

    A layer name containing a separator would otherwise be treated by
    ``np.save`` as a directory component. Both ``os.path.sep`` and
    ``os.path.altsep`` are replaced, so ``/`` is also handled on Windows,
    where the primary separator is ``\\``.
    """
    for sep in (os.path.sep, os.path.altsep):
        if sep:  # altsep is None on POSIX systems
            name = name.replace(sep, '_')
    return name


def dump_params(params, out_dir=os.curdir):
    """Save the first two parameter blobs of every layer as ``.npy`` files.

    :param params: mapping of layer name -> sequence of blobs (for example
        ``caffe.Net.params``); each blob must expose a ``data`` array.
    :param out_dir: directory the files are written into; defaults to the
        current working directory.

    Blob 0 of layer ``L`` is written to ``L_w.npy`` and blob 1 to ``L_b.npy``,
    after path separators in the name have been replaced by ``_``.
    """
    for name, blobs in params.items():
        print('Name: {0}, Blobs: {1}'.format(name, len(blobs)))
        # zip() stops at the shorter input: layers with a single blob only get
        # "_w", and blobs beyond the bias are ignored.
        for suffix, blob in zip(BLOB_SUFFIXES, blobs):
            outname = name + suffix
            varname = sanitize_name(outname)
            if varname != outname:
                print("Renaming variable {0} to {1}".format(outname, varname))
            print("Saving variable {0} with shape {1} ...".format(varname, blob.data.shape))
            # Dump as binary; np.save appends the ".npy" extension itself.
            np.save(os.path.join(out_dir, varname), blob.data)


if __name__ == "__main__":
    # Parse arguments. The text must be passed as ``description``: the first
    # positional parameter of ArgumentParser is ``prog``, which would replace
    # the program name in the usage line.
    parser = argparse.ArgumentParser(description='Extract Caffe net parameters')
    parser.add_argument('-m', dest='modelFile', type=str, required=True, help='Path to Caffe model file')
    parser.add_argument('-n', dest='netFile', type=str, required=True, help='Path to Caffe netlist')
    parser.add_argument('-o', dest='outDir', type=str, default=os.curdir,
                        help='Directory to write the .npy files to (default: current directory)')
    args = parser.parse_args()

    # Create the Caffe net in TEST phase with the trained weights loaded.
    net = caffe.Net(args.netFile, caffe.TEST, weights=args.modelFile)

    # Read and dump blobs
    dump_params(net.params, args.outDir)
46