/*
 * Copyright (c) Qualcomm Innovation Center, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/examples/qualcomm/qaihub_scripts/stable_diffusion/runner/runner.h>
#include <executorch/runtime/platform/runtime.h>
#include <gflags/gflags.h>

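// Command-line flags. The default .pte names match the artifacts generated by
// qaihub_stable_diffusion.py, which also documents how to produce them.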
DEFINE_string(
    text_encoder_path,
    "qaihub_stable_diffusion_text_encoder.pte",
    "Text encoder model serialized in flatbuffer format.");
DEFINE_string(
    unet_path,
    "qaihub_stable_diffusion_unet.pte",
    "UNet model serialized in flatbuffer format.");
DEFINE_string(
    vae_path,
    "qaihub_stable_diffusion_vae.pte",
    "VAE model serialized in flatbuffer format.");
DEFINE_string(
    output_folder_path,
    "outputs",
    "ExecuTorch inference data output path.");
DEFINE_string(
    input_list_path,
    "input_list.txt",
    "Input list storing the time embeddings.");
DEFINE_string(
    vocab_json,
    "vocab.json",
    "Path to the JSON file storing the tokenizer vocabulary.");
DEFINE_string(
    prompt,
    "a photo of an astronaut riding a horse on mars",
    "User input prompt.");
DEFINE_int32(num_time_steps, 20, "Number of time steps.");
DEFINE_double(guidance_scale, 7.5, "Guidance scale.");

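// Quantization parameters for the tensors passed between the three models.
// The 0.0/0 defaults are sentinels: main() below rejects them, so every scale
// and offset must be supplied explicitly with the values reported at export
// time (see qaihub_stable_diffusion.py).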
DEFINE_double(text_encoder_output_scale, 0.0, "Text encoder output scale.");
DEFINE_int32(text_encoder_output_offset, 0, "Text encoder output offset.");
DEFINE_double(unet_input_latent_scale, 0.0, "UNet input latent scale.");
DEFINE_int32(unet_input_latent_offset, 0, "UNet input latent offset.");
DEFINE_double(unet_input_text_emb_scale, 0.0, "UNet input text embedding scale.");
DEFINE_int32(unet_input_text_emb_offset, 0, "UNet input text embedding offset.");
DEFINE_double(unet_output_scale, 0.0, "UNet output scale.");
DEFINE_int32(unet_output_offset, 0, "UNet output offset.");
DEFINE_double(vae_input_scale, 0.0, "VAE input scale.");
DEFINE_int32(vae_input_offset, 0, "VAE input offset.");
DEFINE_double(vae_output_scale, 0.0, "VAE output scale.");
DEFINE_int32(vae_output_offset, 0, "VAE output offset.");
DEFINE_bool(
    fix_latents,
    false,
    "Enable this option to fix the latents in the UNet diffusion step.");

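// Illustrative invocation; the binary name is a placeholder, and the
// <scale>/<offset> values must be replaced with the numbers reported by
// qaihub_stable_diffusion.py at export time:
//
//   ./qaihub_stable_diffusion_runner \
//     --text_encoder_path=qaihub_stable_diffusion_text_encoder.pte \
//     --unet_path=qaihub_stable_diffusion_unet.pte \
//     --vae_path=qaihub_stable_diffusion_vae.pte \
//     --text_encoder_output_scale=<scale> \
//     --text_encoder_output_offset=<offset> \
//     --unet_input_latent_scale=<scale> --unet_input_latent_offset=<offset> \
//     --unet_input_text_emb_scale=<scale> \
//     --unet_input_text_emb_offset=<offset> \
//     --unet_output_scale=<scale> --unet_output_offset=<offset> \
//     --vae_input_scale=<scale> --vae_input_offset=<offset> \
//     --vae_output_scale=<scale> --vae_output_offset=<offset>
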
void usage_message() {
  std::string usage_message =
      "This is a sample executor runner capable of executing stable diffusion models. "
      "Users will need binary .pte program files for text_encoder, unet, and vae. Below are the options to retrieve the required .pte program files:\n"
      "For further information on how to generate the .pte program files and an example command to execute this runner, please refer to qaihub_stable_diffusion.py.";
  gflags::SetUsageMessage(usage_message);
}

using executorch::runtime::Error;

int main(int argc, char** argv) {
  executorch::runtime::runtime_init();
  usage_message();
  gflags::ParseCommandLineFlags(&argc, &argv, true);
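  // gflags reports is_default when a flag was not set on the command line.
  // Treat any quantization flag left at its sentinel default as missing and
  // abort below.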
  bool is_default =
      gflags::GetCommandLineFlagInfoOrDie("text_encoder_output_scale")
          .is_default ||
      gflags::GetCommandLineFlagInfoOrDie("text_encoder_output_offset")
          .is_default ||
      gflags::GetCommandLineFlagInfoOrDie("unet_input_latent_scale")
          .is_default ||
      gflags::GetCommandLineFlagInfoOrDie("unet_input_latent_offset")
          .is_default ||
      gflags::GetCommandLineFlagInfoOrDie("unet_input_text_emb_scale")
          .is_default ||
      gflags::GetCommandLineFlagInfoOrDie("unet_input_text_emb_offset")
          .is_default ||
      gflags::GetCommandLineFlagInfoOrDie("unet_output_scale").is_default ||
      gflags::GetCommandLineFlagInfoOrDie("unet_output_offset").is_default ||
      gflags::GetCommandLineFlagInfoOrDie("vae_input_scale").is_default ||
      gflags::GetCommandLineFlagInfoOrDie("vae_input_offset").is_default ||
      gflags::GetCommandLineFlagInfoOrDie("vae_output_scale").is_default ||
      gflags::GetCommandLineFlagInfoOrDie("vae_output_offset").is_default;

  ET_CHECK_MSG(
      !is_default,
      "Please provide the scales and offsets for the text encoder output, the UNet latent and text embedding inputs, the UNet output, and the VAE input/output. "
      "Please refer to qaihub_stable_diffusion.py if you are unsure how to retrieve these values.");

  ET_LOG(Info, "Stable Diffusion runner started");
  std::vector<std::string> models_path = {
      FLAGS_text_encoder_path, FLAGS_unet_path, FLAGS_vae_path};

  // Create stable_diffusion_runner
  example::Runner runner(
      models_path,
      FLAGS_num_time_steps,
      FLAGS_guidance_scale,
      FLAGS_text_encoder_output_scale,
      FLAGS_text_encoder_output_offset,
      FLAGS_unet_input_latent_scale,
      FLAGS_unet_input_latent_offset,
      FLAGS_unet_input_text_emb_scale,
      FLAGS_unet_input_text_emb_offset,
      FLAGS_unet_output_scale,
      FLAGS_unet_output_offset,
      FLAGS_vae_input_scale,
      FLAGS_vae_input_offset,
      FLAGS_vae_output_scale,
      FLAGS_vae_output_offset,
      FLAGS_output_folder_path,
      FLAGS_fix_latents);

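  // Initialize in order: tokenizer vocabulary, the three model methods, then
  // the precomputed time-embedding inputs, before generating from the prompt.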
  ET_CHECK_MSG(
      runner.init_tokenizer(FLAGS_vocab_json) == Error::Ok,
      "Runner failed to init tokenizer");

  ET_CHECK_MSG(runner.load() == Error::Ok, "Runner failed to load method");

  ET_CHECK_MSG(
      runner.parse_input_list(FLAGS_input_list_path) == Error::Ok,
      "Failed to parse time embedding input list");
  ET_CHECK_MSG(
      runner.generate(FLAGS_prompt) == Error::Ok, "Runner failed to generate");

  ET_CHECK_MSG(
      runner.print_performance() == Error::Ok,
      "Runner failed to print performance");

  return 0;
}