/*
 * Copyright (c) Qualcomm Innovation Center, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/examples/qualcomm/qaihub_scripts/stable_diffusion/runner/runner.h>
#include <executorch/runtime/platform/runtime.h>
#include <gflags/gflags.h>

#include <initializer_list>
#include <string>
#include <vector>

13 DEFINE_string(
14     text_encoder_path,
15     "qaihub_stable_diffusion_text_encoder.pte",
16     "Text Encoder Model serialized in flatbuffer format.");
17 DEFINE_string(
18     unet_path,
19     "qaihub_stable_diffusion_unet.pte",
20     "Unet Model serialized in flatbuffer format.");
21 DEFINE_string(
22     vae_path,
23     "qaihub_stable_diffusion_vae.pte",
24     "Vae Model serialized in flatbuffer format.");
25 DEFINE_string(
26     output_folder_path,
27     "outputs",
28     "Executorch inference data output path.");
29 DEFINE_string(
30     input_list_path,
31     "input_list.txt",
32     "Input list storing time embedding.");
33 DEFINE_string(
34     vocab_json,
35     "vocab.json",
36     "Json path to retrieve a list of vocabs.");
37 DEFINE_string(
38     prompt,
39     "a photo of an astronaut riding a horse on mars",
40     "User input prompt");
41 DEFINE_int32(num_time_steps, 20, "Number of time steps.");
42 DEFINE_double(guidance_scale, 7.5, "Guidance Scale");
43 
44 DEFINE_double(text_encoder_output_scale, 0.0, "Text encoder output scale");
45 DEFINE_int32(text_encoder_output_offset, 0, "Text encoder output offset");
46 DEFINE_double(unet_input_latent_scale, 0.0, "Unet input latent scale");
47 DEFINE_int32(unet_input_latent_offset, 0, "Unet input latent offset");
48 DEFINE_double(unet_input_text_emb_scale, 0.0, "Unet input text emb scale");
49 DEFINE_int32(unet_input_text_emb_offset, 0, "Unet input text emb offset");
50 DEFINE_double(unet_output_scale, 0.0, "Unet output scale");
51 DEFINE_int32(unet_output_offset, 0, "Unet output offset");
52 DEFINE_double(vae_input_scale, 0.0, "Vae input scale");
53 DEFINE_int32(vae_input_offset, 0, "Vae input offset");
54 DEFINE_double(vae_output_scale, 0.0, "Vae output scale");
55 DEFINE_int32(vae_output_offset, 0, "Vae output offset");
56 DEFINE_bool(
57     fix_latents,
58     false,
59     "Enable this option to fix the latents in the unet diffuse step.");
60 
usage_message()61 void usage_message() {
62   std::string usage_message =
63       "This is a sample executor runner capable of executing stable diffusion models."
64       "Users will need binary .pte program files for text_encoder, unet, and vae. Below are the options to retrieve required .pte program files:\n"
65       "For further information on how to generate the .pte program files and example command to execute this runner, please refer to qaihub_stable_diffsion.py.";
66   gflags::SetUsageMessage(usage_message);
67 }
68 
69 using executorch::runtime::Error;
70 
main(int argc,char ** argv)71 int main(int argc, char** argv) {
72   executorch::runtime::runtime_init();
73   usage_message();
74   gflags::ParseCommandLineFlags(&argc, &argv, true);
75   bool is_default =
76       gflags::GetCommandLineFlagInfoOrDie("text_encoder_output_scale")
77           .is_default ||
78       gflags::GetCommandLineFlagInfoOrDie("text_encoder_output_offset")
79           .is_default ||
80       gflags::GetCommandLineFlagInfoOrDie("unet_input_latent_scale")
81           .is_default ||
82       gflags::GetCommandLineFlagInfoOrDie("unet_input_latent_offset")
83           .is_default ||
84       gflags::GetCommandLineFlagInfoOrDie("unet_input_text_emb_scale")
85           .is_default ||
86       gflags::GetCommandLineFlagInfoOrDie("unet_input_text_emb_offset")
87           .is_default ||
88       gflags::GetCommandLineFlagInfoOrDie("unet_output_scale").is_default ||
89       gflags::GetCommandLineFlagInfoOrDie("unet_output_offset").is_default ||
90       gflags::GetCommandLineFlagInfoOrDie("vae_input_scale").is_default ||
91       gflags::GetCommandLineFlagInfoOrDie("vae_input_offset").is_default ||
92       gflags::GetCommandLineFlagInfoOrDie("vae_output_scale").is_default ||
93       gflags::GetCommandLineFlagInfoOrDie("vae_output_offset").is_default;
94 
95   ET_CHECK_MSG(
96       !is_default,
97       "Please provide scale and offset for unet latent input, unet output, and vae input/output."
98       "Please refer to qaihub_stable_diffusion.py if you are unsure how to retrieve these values.");
99 
100   ET_LOG(Info, "Stable Diffusion runner started");
101   std::vector<std::string> models_path = {
102       FLAGS_text_encoder_path, FLAGS_unet_path, FLAGS_vae_path};
103 
104   // Create stable_diffusion_runner
105   example::Runner runner(
106       models_path,
107       FLAGS_num_time_steps,
108       FLAGS_guidance_scale,
109       FLAGS_text_encoder_output_scale,
110       FLAGS_text_encoder_output_offset,
111       FLAGS_unet_input_latent_scale,
112       FLAGS_unet_input_latent_offset,
113       FLAGS_unet_input_text_emb_scale,
114       FLAGS_unet_input_text_emb_offset,
115       FLAGS_unet_output_scale,
116       FLAGS_unet_output_offset,
117       FLAGS_vae_input_scale,
118       FLAGS_vae_input_offset,
119       FLAGS_vae_output_scale,
120       FLAGS_vae_output_offset,
121       FLAGS_output_folder_path,
122       FLAGS_fix_latents);
123 
124   ET_CHECK_MSG(
125       runner.init_tokenizer(FLAGS_vocab_json) == Error::Ok,
126       "Runner failed to init tokenizer");
127 
128   ET_CHECK_MSG(runner.load() == Error::Ok, "Runner failed to load method");
129 
130   ET_CHECK_MSG(
131       runner.parse_input_list(FLAGS_input_list_path) == Error::Ok,
132       "Failed to parse time embedding input list");
133   ET_CHECK_MSG(
134       runner.generate(FLAGS_prompt) == Error::Ok, "Runner failed to generate");
135 
136   ET_CHECK_MSG(
137       runner.print_performance() == Error::Ok,
138       "Runner failed to print performance");
139 
140   return 0;
141 }
142