diff --git a/README.md b/README.md index 0fe607e2..c1ba396f 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,7 @@ Inference of Stable Diffusion and Flux in pure C/C++ - Plain C/C++ implementation based on [ggml](https://github.com/ggerganov/ggml), working in the same way as [llama.cpp](https://github.com/ggerganov/llama.cpp) - Super lightweight and without external dependencies -- SD1.x, SD2.x, SDXL and SD3 support +- SD1.x, SD2.x, SDXL and [SD3/SD3.5](./docs/sd3.md) support - !!!The VAE in SDXL encounters NaN issues under FP16, but unfortunately, the ggml_conv_2d only operates under FP16. Hence, a parameter is needed to specify the VAE that has fixed the FP16 NaN issue. You can find it here: [SDXL VAE FP16 Fix](https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl_vae.safetensors). - [Flux-dev/Flux-schnell Support](./docs/flux.md) @@ -197,23 +197,24 @@ usage: ./bin/sd [arguments] arguments: -h, --help show this help message and exit -M, --mode [MODEL] run mode (txt2img or img2img or convert, default: txt2img) - -t, --threads N number of threads to use during computation (default: -1). + -t, --threads N number of threads to use during computation (default: -1) If threads <= 0, then threads will be set to the number of CPU physical cores -m, --model [MODEL] path to full model --diffusion-model path to the standalone diffusion model --clip_l path to the clip-l text encoder - --t5xxl path to the the t5xxl text encoder. + --clip_g path to the clip-g text encoder + --t5xxl path to the t5xxl text encoder --vae [VAE] path to vae --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality) --control-net [CONTROL_PATH] path to control net model - --embd-dir [EMBEDDING_PATH] path to embeddings. - --stacked-id-embd-dir [DIR] path to PHOTOMAKER stacked id embeddings. - --input-id-images-dir [DIR] path to PHOTOMAKER input id images dir. 
+ --embd-dir [EMBEDDING_PATH] path to embeddings + --stacked-id-embd-dir [DIR] path to PHOTOMAKER stacked id embeddings + --input-id-images-dir [DIR] path to PHOTOMAKER input id images dir --normalize-input normalize PHOTOMAKER input id images - --upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now. + --upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now --upscale-repeats Run the ESRGAN upscaler this many times (default 1) --type [TYPE] weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_k, q3_k, q4_k) - If not specified, the default is the type of the weight file. + If not specified, the default is the type of the weight file --lora-model-dir [DIR] lora model directory -i, --init-img [IMAGE] path to the input image, required by img2img --control-image [IMAGE] path to image condition, control net @@ -232,13 +233,13 @@ arguments: --steps STEPS number of sample steps (default: 20) --rng {std_default, cuda} RNG (default: cuda) -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0) - -b, --batch-count COUNT number of images to generate. + -b, --batch-count COUNT number of images to generate --schedule {discrete, karras, exponential, ays, gits} Denoiser sigma schedule (default: discrete) --clip-skip N ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1) <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x --vae-tiling process vae in tiles to reduce memory usage --vae-on-cpu keep vae in cpu (for low vram) - --clip-on-cpu keep clip in cpu (for low vram). 
+ --clip-on-cpu keep clip in cpu (for low vram) --control-net-cpu keep controlnet in cpu (for low vram) --canny apply canny preprocessor (edge detection) --color Colors the logging tags according to level @@ -253,6 +254,7 @@ arguments: # ./bin/sd -m ../models/sd_xl_base_1.0.safetensors --vae ../models/sdxl_vae-fp16-fix.safetensors -H 1024 -W 1024 -p "a lovely cat" -v # ./bin/sd -m ../models/sd3_medium_incl_clips_t5xxlfp16.safetensors -H 1024 -W 1024 -p 'a lovely cat holding a sign says \"Stable Diffusion CPP\"' --cfg-scale 4.5 --sampling-method euler -v # ./bin/sd --diffusion-model ../models/flux1-dev-q3_k.gguf --vae ../models/ae.sft --clip_l ../models/clip_l.safetensors --t5xxl ../models/t5xxl_fp16.safetensors -p "a lovely cat holding a sign says 'flux.cpp'" --cfg-scale 1.0 --sampling-method euler -v +# ./bin/sd -m ../models/sd3.5_large.safetensors --clip_l ../models/clip_l.safetensors --clip_g ../models/clip_g.safetensors --t5xxl ../models/t5xxl_fp16.safetensors -H 1024 -W 1024 -p 'a lovely cat holding a sign says \"Stable diffusion 3.5 Large\"' --cfg-scale 4.5 --sampling-method euler -v ``` Using formats of different precisions will yield results of varying quality. 
diff --git a/assets/sd3.5_large.png b/assets/sd3.5_large.png new file mode 100644 index 00000000..b76b1322 Binary files /dev/null and b/assets/sd3.5_large.png differ diff --git a/conditioner.hpp b/conditioner.hpp index 43d0a6d5..ac2ab7eb 100644 --- a/conditioner.hpp +++ b/conditioner.hpp @@ -1001,8 +1001,8 @@ struct FluxCLIPEmbedder : public Conditioner { } void get_param_tensors(std::map& tensors) { - clip_l->get_param_tensors(tensors, "text_encoders.clip_l.text_model"); - t5->get_param_tensors(tensors, "text_encoders.t5xxl"); + clip_l->get_param_tensors(tensors, "text_encoders.clip_l.transformer.text_model"); + t5->get_param_tensors(tensors, "text_encoders.t5xxl.transformer"); } void alloc_params_buffer() { diff --git a/denoiser.hpp b/denoiser.hpp index 287b1093..975699d2 100644 --- a/denoiser.hpp +++ b/denoiser.hpp @@ -49,7 +49,7 @@ struct ExponentialSchedule : SigmaSchedule { // Calculate step size float log_sigma_min = std::log(sigma_min); float log_sigma_max = std::log(sigma_max); - float step = (log_sigma_max - log_sigma_min) / (n - 1); + float step = (log_sigma_max - log_sigma_min) / (n - 1); // Fill sigmas with exponential values for (uint32_t i = 0; i < n; ++i) { @@ -205,7 +205,7 @@ struct AYSSchedule : SigmaSchedule { /* * GITS Scheduler: https://github.com/zju-pi/diff-sampler/tree/main/gits-main -*/ + */ struct GITSSchedule : SigmaSchedule { std::vector get_sigmas(uint32_t n, float sigma_min, float sigma_max, t_to_sigma_t t_to_sigma) { if (sigma_max <= 0.0f) { @@ -221,7 +221,7 @@ struct GITSSchedule : SigmaSchedule { // Calculate the index based on the coefficient int index = static_cast((coeff - 0.80f) / 0.05f); // Ensure the index is within bounds - index = std::max(0, std::min(index, static_cast(GITS_NOISE.size() - 1))); + index = std::max(0, std::min(index, static_cast(GITS_NOISE.size() - 1))); const std::vector>& selected_noise = *GITS_NOISE[index]; if (n <= 20) { @@ -823,24 +823,24 @@ static void sample_k_diffusion(sample_method_t method, } break; 
case IPNDM: // iPNDM sampler from https://github.com/zju-pi/diff-sampler/tree/main/diff-solvers-main { - int max_order = 4; + int max_order = 4; ggml_tensor* x_next = x; std::vector buffer_model; for (int i = 0; i < steps; i++) { - float sigma = sigmas[i]; + float sigma = sigmas[i]; float sigma_next = sigmas[i + 1]; ggml_tensor* x_cur = x_next; - float* vec_x_cur = (float*)x_cur->data; - float* vec_x_next = (float*)x_next->data; + float* vec_x_cur = (float*)x_cur->data; + float* vec_x_next = (float*)x_next->data; // Denoising step ggml_tensor* denoised = model(x_cur, sigma, i + 1); - float* vec_denoised = (float*)denoised->data; + float* vec_denoised = (float*)denoised->data; // d_cur = (x_cur - denoised) / sigma struct ggml_tensor* d_cur = ggml_dup_tensor(work_ctx, x_cur); - float* vec_d_cur = (float*)d_cur->data; + float* vec_d_cur = (float*)d_cur->data; for (int j = 0; j < ggml_nelements(d_cur); j++) { vec_d_cur[j] = (vec_x_cur[j] - vec_denoised[j]) / sigma; @@ -857,34 +857,31 @@ static void sample_k_diffusion(sample_method_t method, break; case 2: // Use one history point - { - float* vec_d_prev1 = (float*)buffer_model.back()->data; - for (int j = 0; j < ggml_nelements(x_next); j++) { - vec_x_next[j] = vec_x_cur[j] + (sigma_next - sigma) * (3 * vec_d_cur[j] - vec_d_prev1[j]) / 2; - } + { + float* vec_d_prev1 = (float*)buffer_model.back()->data; + for (int j = 0; j < ggml_nelements(x_next); j++) { + vec_x_next[j] = vec_x_cur[j] + (sigma_next - sigma) * (3 * vec_d_cur[j] - vec_d_prev1[j]) / 2; } - break; + } break; case 3: // Use two history points - { - float* vec_d_prev1 = (float*)buffer_model.back()->data; - float* vec_d_prev2 = (float*)buffer_model[buffer_model.size() - 2]->data; - for (int j = 0; j < ggml_nelements(x_next); j++) { - vec_x_next[j] = vec_x_cur[j] + (sigma_next - sigma) * (23 * vec_d_cur[j] - 16 * vec_d_prev1[j] + 5 * vec_d_prev2[j]) / 12; - } + { + float* vec_d_prev1 = (float*)buffer_model.back()->data; + float* vec_d_prev2 = 
(float*)buffer_model[buffer_model.size() - 2]->data; + for (int j = 0; j < ggml_nelements(x_next); j++) { + vec_x_next[j] = vec_x_cur[j] + (sigma_next - sigma) * (23 * vec_d_cur[j] - 16 * vec_d_prev1[j] + 5 * vec_d_prev2[j]) / 12; } - break; + } break; case 4: // Use three history points - { - float* vec_d_prev1 = (float*)buffer_model.back()->data; - float* vec_d_prev2 = (float*)buffer_model[buffer_model.size() - 2]->data; - float* vec_d_prev3 = (float*)buffer_model[buffer_model.size() - 3]->data; - for (int j = 0; j < ggml_nelements(x_next); j++) { - vec_x_next[j] = vec_x_cur[j] + (sigma_next - sigma) * (55 * vec_d_cur[j] - 59 * vec_d_prev1[j] + 37 * vec_d_prev2[j] - 9 * vec_d_prev3[j]) / 24; - } + { + float* vec_d_prev1 = (float*)buffer_model.back()->data; + float* vec_d_prev2 = (float*)buffer_model[buffer_model.size() - 2]->data; + float* vec_d_prev3 = (float*)buffer_model[buffer_model.size() - 3]->data; + for (int j = 0; j < ggml_nelements(x_next); j++) { + vec_x_next[j] = vec_x_cur[j] + (sigma_next - sigma) * (55 * vec_d_cur[j] - 59 * vec_d_prev1[j] + 37 * vec_d_prev2[j] - 9 * vec_d_prev3[j]) / 24; } - break; + } break; } // Manage buffer_model @@ -906,27 +903,27 @@ static void sample_k_diffusion(sample_method_t method, ggml_tensor* x_next = x; for (int i = 0; i < steps; i++) { - float sigma = sigmas[i]; + float sigma = sigmas[i]; float t_next = sigmas[i + 1]; // Denoising step - ggml_tensor* denoised = model(x, sigma, i + 1); - float* vec_denoised = (float*)denoised->data; + ggml_tensor* denoised = model(x, sigma, i + 1); + float* vec_denoised = (float*)denoised->data; struct ggml_tensor* d_cur = ggml_dup_tensor(work_ctx, x); - float* vec_d_cur = (float*)d_cur->data; - float* vec_x = (float*)x->data; + float* vec_d_cur = (float*)d_cur->data; + float* vec_x = (float*)x->data; // d_cur = (x - denoised) / sigma for (int j = 0; j < ggml_nelements(d_cur); j++) { vec_d_cur[j] = (vec_x[j] - vec_denoised[j]) / sigma; } - int order = std::min(max_order, i + 1); - 
float h_n = t_next - sigma; + int order = std::min(max_order, i + 1); + float h_n = t_next - sigma; float h_n_1 = (i > 0) ? (sigma - sigmas[i - 1]) : h_n; switch (order) { - case 1: // First Euler step + case 1: // First Euler step for (int j = 0; j < ggml_nelements(x_next); j++) { vec_x[j] += vec_d_cur[j] * h_n; } @@ -941,7 +938,7 @@ static void sample_k_diffusion(sample_method_t method, } case 3: { - float h_n_2 = (i > 1) ? (sigmas[i - 1] - sigmas[i - 2]) : h_n_1; + float h_n_2 = (i > 1) ? (sigmas[i - 1] - sigmas[i - 2]) : h_n_1; float* vec_d_prev1 = (float*)buffer_model.back()->data; float* vec_d_prev2 = (buffer_model.size() > 1) ? (float*)buffer_model[buffer_model.size() - 2]->data : vec_d_prev1; for (int j = 0; j < ggml_nelements(x_next); j++) { @@ -951,8 +948,8 @@ static void sample_k_diffusion(sample_method_t method, } case 4: { - float h_n_2 = (i > 1) ? (sigmas[i - 1] - sigmas[i - 2]) : h_n_1; - float h_n_3 = (i > 2) ? (sigmas[i - 2] - sigmas[i - 3]) : h_n_2; + float h_n_2 = (i > 1) ? (sigmas[i - 1] - sigmas[i - 2]) : h_n_1; + float h_n_3 = (i > 2) ? (sigmas[i - 2] - sigmas[i - 3]) : h_n_2; float* vec_d_prev1 = (float*)buffer_model.back()->data; float* vec_d_prev2 = (buffer_model.size() > 1) ? (float*)buffer_model[buffer_model.size() - 2]->data : vec_d_prev1; float* vec_d_prev3 = (buffer_model.size() > 2) ? 
(float*)buffer_model[buffer_model.size() - 3]->data : vec_d_prev2; diff --git a/docs/sd3.md b/docs/sd3.md new file mode 100644 index 00000000..777511d4 --- /dev/null +++ b/docs/sd3.md @@ -0,0 +1,20 @@ +# How to Use + +## Download weights + +- Download sd3.5_large from https://huggingface.co/stabilityai/stable-diffusion-3.5-large/blob/main/sd3.5_large.safetensors +- Download clip_g from https://huggingface.co/Comfy-Org/stable-diffusion-3.5-fp8/blob/main/text_encoders/clip_g.safetensors +- Download clip_l from https://huggingface.co/Comfy-Org/stable-diffusion-3.5-fp8/blob/main/text_encoders/clip_l.safetensors +- Download t5xxl from https://huggingface.co/Comfy-Org/stable-diffusion-3.5-fp8/blob/main/text_encoders/t5xxl_fp16.safetensors + + +## Run + +### SD3.5 Large +For example: + +``` +.\bin\Release\sd.exe -m ..\models\sd3.5_large.safetensors --clip_l ..\models\clip_l.safetensors --clip_g ..\models\clip_g.safetensors --t5xxl ..\models\t5xxl_fp16.safetensors -H 1024 -W 1024 -p 'a lovely cat holding a sign says \"Stable diffusion 3.5 Large\"' --cfg-scale 4.5 --sampling-method euler -v +``` + +![](../assets/sd3.5_large.png) \ No newline at end of file diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index ceae27b8..f1bdc698 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -69,9 +69,9 @@ enum SDMode { struct SDParams { int n_threads = -1; SDMode mode = TXT2IMG; - std::string model_path; std::string clip_l_path; + std::string clip_g_path; std::string t5xxl_path; std::string diffusion_model_path; std::string vae_path; @@ -128,6 +128,7 @@ void print_params(SDParams params) { printf(" model_path: %s\n", params.model_path.c_str()); printf(" wtype: %s\n", params.wtype < SD_TYPE_COUNT ? 
sd_type_name(params.wtype) : "unspecified"); printf(" clip_l_path: %s\n", params.clip_l_path.c_str()); + printf(" clip_g_path: %s\n", params.clip_g_path.c_str()); printf(" t5xxl_path: %s\n", params.t5xxl_path.c_str()); printf(" diffusion_model_path: %s\n", params.diffusion_model_path.c_str()); printf(" vae_path: %s\n", params.vae_path.c_str()); @@ -171,23 +172,24 @@ void print_usage(int argc, const char* argv[]) { printf("arguments:\n"); printf(" -h, --help show this help message and exit\n"); printf(" -M, --mode [MODEL] run mode (txt2img or img2img or convert, default: txt2img)\n"); - printf(" -t, --threads N number of threads to use during computation (default: -1).\n"); + printf(" -t, --threads N number of threads to use during computation (default: -1)\n"); printf(" If threads <= 0, then threads will be set to the number of CPU physical cores\n"); printf(" -m, --model [MODEL] path to full model\n"); printf(" --diffusion-model path to the standalone diffusion model\n"); printf(" --clip_l path to the clip-l text encoder\n"); - printf(" --t5xxl path to the the t5xxl text encoder.\n"); + printf(" --clip_g path to the clip-l text encoder\n"); + printf(" --t5xxl path to the the t5xxl text encoder\n"); printf(" --vae [VAE] path to vae\n"); printf(" --taesd [TAESD_PATH] path to taesd. 
Using Tiny AutoEncoder for fast decoding (low quality)\n"); printf(" --control-net [CONTROL_PATH] path to control net model\n"); - printf(" --embd-dir [EMBEDDING_PATH] path to embeddings.\n"); - printf(" --stacked-id-embd-dir [DIR] path to PHOTOMAKER stacked id embeddings.\n"); - printf(" --input-id-images-dir [DIR] path to PHOTOMAKER input id images dir.\n"); + printf(" --embd-dir [EMBEDDING_PATH] path to embeddings\n"); + printf(" --stacked-id-embd-dir [DIR] path to PHOTOMAKER stacked id embeddings\n"); + printf(" --input-id-images-dir [DIR] path to PHOTOMAKER input id images dir\n"); printf(" --normalize-input normalize PHOTOMAKER input id images\n"); - printf(" --upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now.\n"); + printf(" --upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now\n"); printf(" --upscale-repeats Run the ESRGAN upscaler this many times (default 1)\n"); printf(" --type [TYPE] weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0, q2_k, q3_k, q4_k)\n"); - printf(" If not specified, the default is the type of the weight file.\n"); + printf(" If not specified, the default is the type of the weight file\n"); printf(" --lora-model-dir [DIR] lora model directory\n"); printf(" -i, --init-img [IMAGE] path to the input image, required by img2img\n"); printf(" --control-image [IMAGE] path to image condition, control net\n"); @@ -206,13 +208,13 @@ void print_usage(int argc, const char* argv[]) { printf(" --steps STEPS number of sample steps (default: 20)\n"); printf(" --rng {std_default, cuda} RNG (default: cuda)\n"); printf(" -s SEED, --seed SEED RNG seed (default: 42, use random seed for < 0)\n"); - printf(" -b, --batch-count COUNT number of images to generate.\n"); + printf(" -b, --batch-count COUNT number of images to generate\n"); printf(" --schedule {discrete, karras, exponential, ays, gits} 
Denoiser sigma schedule (default: discrete)\n"); printf(" --clip-skip N ignore last layers of CLIP network; 1 ignores none, 2 ignores one layer (default: -1)\n"); printf(" <= 0 represents unspecified, will be 1 for SD1.x, 2 for SD2.x\n"); printf(" --vae-tiling process vae in tiles to reduce memory usage\n"); printf(" --vae-on-cpu keep vae in cpu (for low vram)\n"); - printf(" --clip-on-cpu keep clip in cpu (for low vram).\n"); + printf(" --clip-on-cpu keep clip in cpu (for low vram)\n"); printf(" --control-net-cpu keep controlnet in cpu (for low vram)\n"); printf(" --canny apply canny preprocessor (edge detection)\n"); printf(" --color Colors the logging tags according to level\n"); @@ -262,6 +264,12 @@ void parse_args(int argc, const char** argv, SDParams& params) { break; } params.clip_l_path = argv[i]; + } else if (arg == "--clip_g") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.clip_g_path = argv[i]; } else if (arg == "--t5xxl") { if (++i >= argc) { invalid_arg = true; @@ -765,6 +773,7 @@ int main(int argc, const char* argv[]) { sd_ctx_t* sd_ctx = new_sd_ctx(params.model_path.c_str(), params.clip_l_path.c_str(), + params.clip_g_path.c_str(), params.t5xxl_path.c_str(), params.diffusion_model_path.c_str(), params.vae_path.c_str(), diff --git a/ggml_extend.hpp b/ggml_extend.hpp index 810f2b9e..e50137d5 100644 --- a/ggml_extend.hpp +++ b/ggml_extend.hpp @@ -368,8 +368,8 @@ __STATIC_INLINE__ void ggml_merge_tensor_2d(struct ggml_tensor* input, int64_t height = input->ne[1]; int64_t channels = input->ne[2]; - int64_t img_width = output->ne[0]; - int64_t img_height = output->ne[1]; + int64_t img_width = output->ne[0]; + int64_t img_height = output->ne[1]; GGML_ASSERT(input->type == GGML_TYPE_F32 && output->type == GGML_TYPE_F32); for (int iy = 0; iy < height; iy++) { @@ -380,7 +380,7 @@ __STATIC_INLINE__ void ggml_merge_tensor_2d(struct ggml_tensor* input, float old_value = ggml_tensor_get_f32(output, x + ix, y + iy, k); const float x_f_0 = (x > 
0) ? ix / float(overlap) : 1; - const float x_f_1 = (x < (img_width - width)) ? (width - ix) / float(overlap) : 1 ; + const float x_f_1 = (x < (img_width - width)) ? (width - ix) / float(overlap) : 1; const float y_f_0 = (y > 0) ? iy / float(overlap) : 1; const float y_f_1 = (y < (img_height - height)) ? (height - iy) / float(overlap) : 1; @@ -390,8 +390,7 @@ __STATIC_INLINE__ void ggml_merge_tensor_2d(struct ggml_tensor* input, ggml_tensor_set_f32( output, old_value + new_value * ggml_smootherstep_f32(y_f) * ggml_smootherstep_f32(x_f), - x + ix, y + iy, k - ); + x + ix, y + iy, k); } else { ggml_tensor_set_f32(output, new_value, x + ix, y + iy, k); } diff --git a/mmdit.hpp b/mmdit.hpp index 6f3a8a06..3a278dac 100644 --- a/mmdit.hpp +++ b/mmdit.hpp @@ -142,29 +142,77 @@ struct VectorEmbedder : public GGMLBlock { } }; +class RMSNorm : public UnaryBlock { +protected: + int64_t hidden_size; + float eps; + + void init_params(struct ggml_context* ctx, ggml_type wtype) { + params["weight"] = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, hidden_size); + } + +public: + RMSNorm(int64_t hidden_size, + float eps = 1e-06f) + : hidden_size(hidden_size), + eps(eps) {} + + struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_tensor* x) { + struct ggml_tensor* w = params["weight"]; + x = ggml_rms_norm(ctx, x, eps); + x = ggml_mul(ctx, x, w); + return x; + } +}; + class SelfAttention : public GGMLBlock { public: int64_t num_heads; bool pre_only; + std::string qk_norm; public: SelfAttention(int64_t dim, - int64_t num_heads = 8, - bool qkv_bias = false, - bool pre_only = false) - : num_heads(num_heads), pre_only(pre_only) { - // qk_norm is always None - blocks["qkv"] = std::shared_ptr(new Linear(dim, dim * 3, qkv_bias)); + int64_t num_heads = 8, + std::string qk_norm = "", + bool qkv_bias = false, + bool pre_only = false) + : num_heads(num_heads), pre_only(pre_only), qk_norm(qk_norm) { + int64_t d_head = dim / num_heads; + blocks["qkv"] = std::shared_ptr(new Linear(dim, dim * 
3, qkv_bias)); if (!pre_only) { blocks["proj"] = std::shared_ptr(new Linear(dim, dim)); } + if (qk_norm == "rms") { + blocks["ln_q"] = std::shared_ptr(new RMSNorm(d_head, 1.0e-6)); + blocks["ln_k"] = std::shared_ptr(new RMSNorm(d_head, 1.0e-6)); + } else if (qk_norm == "ln") { + blocks["ln_q"] = std::shared_ptr(new LayerNorm(d_head, 1.0e-6)); + blocks["ln_k"] = std::shared_ptr(new LayerNorm(d_head, 1.0e-6)); + } } std::vector pre_attention(struct ggml_context* ctx, struct ggml_tensor* x) { auto qkv_proj = std::dynamic_pointer_cast(blocks["qkv"]); - auto qkv = qkv_proj->forward(ctx, x); - return split_qkv(ctx, qkv); + auto qkv = qkv_proj->forward(ctx, x); + auto qkv_vec = split_qkv(ctx, qkv); + int64_t head_dim = qkv_vec[0]->ne[0] / num_heads; + auto q = ggml_reshape_4d(ctx, qkv_vec[0], head_dim, num_heads, qkv_vec[0]->ne[1], qkv_vec[0]->ne[2]); // [N, n_token, n_head, d_head] + auto k = ggml_reshape_4d(ctx, qkv_vec[1], head_dim, num_heads, qkv_vec[1]->ne[1], qkv_vec[1]->ne[2]); // [N, n_token, n_head, d_head] + auto v = qkv_vec[2]; // [N, n_token, n_head*d_head] + + if (qk_norm == "rms" || qk_norm == "ln") { + auto ln_q = std::dynamic_pointer_cast(blocks["ln_q"]); + auto ln_k = std::dynamic_pointer_cast(blocks["ln_k"]); + q = ln_q->forward(ctx, q); + k = ln_k->forward(ctx, k); + } + + q = ggml_reshape_3d(ctx, q, q->ne[0] * q->ne[1], q->ne[2], q->ne[3]); // [N, n_token, n_head*d_head] + k = ggml_reshape_3d(ctx, k, k->ne[0] * k->ne[1], k->ne[2], k->ne[3]); // [N, n_token, n_head*d_head] + + return {q, k, v}; } struct ggml_tensor* post_attention(struct ggml_context* ctx, struct ggml_tensor* x) { @@ -208,16 +256,16 @@ struct DismantledBlock : public GGMLBlock { public: DismantledBlock(int64_t hidden_size, int64_t num_heads, - float mlp_ratio = 4.0, - bool qkv_bias = false, - bool pre_only = false) + float mlp_ratio = 4.0, + std::string qk_norm = "", + bool qkv_bias = false, + bool pre_only = false) : num_heads(num_heads), pre_only(pre_only) { // rmsnorm is always Flase 
// scale_mod_only is always Flase // swiglu is always Flase - // qk_norm is always Flase blocks["norm1"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-06f, false)); - blocks["attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qkv_bias, pre_only)); + blocks["attn"] = std::shared_ptr(new SelfAttention(hidden_size, num_heads, qk_norm, qkv_bias, pre_only)); if (!pre_only) { blocks["norm2"] = std::shared_ptr(new LayerNorm(hidden_size, 1e-06f, false)); @@ -396,12 +444,12 @@ struct JointBlock : public GGMLBlock { public: JointBlock(int64_t hidden_size, int64_t num_heads, - float mlp_ratio = 4.0, - bool qkv_bias = false, - bool pre_only = false) { - // qk_norm is always Flase - blocks["context_block"] = std::shared_ptr(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qkv_bias, pre_only)); - blocks["x_block"] = std::shared_ptr(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qkv_bias, false)); + float mlp_ratio = 4.0, + std::string qk_norm = "", + bool qkv_bias = false, + bool pre_only = false) { + blocks["context_block"] = std::shared_ptr(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, pre_only)); + blocks["x_block"] = std::shared_ptr(new DismantledBlock(hidden_size, num_heads, mlp_ratio, qk_norm, qkv_bias, false)); } std::pair forward(struct ggml_context* ctx, @@ -455,18 +503,20 @@ struct FinalLayer : public GGMLBlock { struct MMDiT : public GGMLBlock { // Diffusion model with a Transformer backbone. 
protected: - SDVersion version = VERSION_SD3_2B; - int64_t input_size = -1; - int64_t patch_size = 2; - int64_t in_channels = 16; - int64_t depth = 24; - float mlp_ratio = 4.0f; - int64_t adm_in_channels = 2048; - int64_t out_channels = 16; - int64_t pos_embed_max_size = 192; - int64_t num_patchs = 36864; // 192 * 192 - int64_t context_size = 4096; + SDVersion version = VERSION_SD3_2B; + int64_t input_size = -1; + int64_t patch_size = 2; + int64_t in_channels = 16; + int64_t depth = 24; + float mlp_ratio = 4.0f; + int64_t adm_in_channels = 2048; + int64_t out_channels = 16; + int64_t pos_embed_max_size = 192; + int64_t num_patchs = 36864; // 192 * 192 + int64_t context_size = 4096; + int64_t context_embedder_out_dim = 1536; int64_t hidden_size; + std::string qk_norm; void init_params(struct ggml_context* ctx, ggml_type wtype) { params["pos_embed"] = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hidden_size, num_patchs, 1); @@ -481,23 +531,36 @@ struct MMDiT : public GGMLBlock { // rmsnorm is alwalys False // scale_mod_only is alwalys False // swiglu is alwalys False - // qk_norm is always None // qkv_bias is always True // context_processor_layers is always None // pos_embed_scaling_factor is not used // pos_embed_offset is not used // context_embedder_config is always {'target': 'torch.nn.Linear', 'params': {'in_features': 4096, 'out_features': 1536}} if (version == VERSION_SD3_2B) { - input_size = -1; - patch_size = 2; - in_channels = 16; - depth = 24; - mlp_ratio = 4.0f; - adm_in_channels = 2048; - out_channels = 16; - pos_embed_max_size = 192; - num_patchs = 36864; // 192 * 192 - context_size = 4096; + input_size = -1; + patch_size = 2; + in_channels = 16; + depth = 24; + mlp_ratio = 4.0f; + adm_in_channels = 2048; + out_channels = 16; + pos_embed_max_size = 192; + num_patchs = 36864; // 192 * 192 + context_size = 4096; + context_embedder_out_dim = 1536; + } else if (version == VERSION_SD3_5_8B) { + input_size = -1; + patch_size = 2; + in_channels = 16; + depth = 38; 
+ mlp_ratio = 4.0f; + adm_in_channels = 2048; + out_channels = 16; + pos_embed_max_size = 192; + num_patchs = 36864; // 192 * 192 + context_size = 4096; + context_embedder_out_dim = 2432; + qk_norm = "rms"; } int64_t default_out_channels = in_channels; hidden_size = 64 * depth; @@ -510,12 +573,13 @@ struct MMDiT : public GGMLBlock { blocks["y_embedder"] = std::shared_ptr(new VectorEmbedder(adm_in_channels, hidden_size)); } - blocks["context_embedder"] = std::shared_ptr(new Linear(4096, 1536, true, true)); + blocks["context_embedder"] = std::shared_ptr(new Linear(4096, context_embedder_out_dim, true, true)); for (int i = 0; i < depth; i++) { blocks["joint_blocks." + std::to_string(i)] = std::shared_ptr(new JointBlock(hidden_size, num_heads, mlp_ratio, + qk_norm, true, i == depth - 1)); } diff --git a/model.cpp b/model.cpp index b74a735f..26451cdc 100644 --- a/model.cpp +++ b/model.cpp @@ -430,6 +430,14 @@ std::string convert_tensor_name(std::string name) { if (starts_with(name, "diffusion_model")) { name = "model." + name; } + // size_t pos = name.find("lora_A"); + // if (pos != std::string::npos) { + // name.replace(pos, strlen("lora_A"), "lora_up"); + // } + // pos = name.find("lora_B"); + // if (pos != std::string::npos) { + // name.replace(pos, strlen("lora_B"), "lora_down"); + // } std::string new_name = name; if (starts_with(name, "cond_stage_model.") || starts_with(name, "conditioner.embedders.") || starts_with(name, "text_encoders.") || ends_with(name, ".vision_model.visual_projection.weight")) { new_name = convert_open_clip_to_hf_clip(name); @@ -466,6 +474,9 @@ std::string convert_tensor_name(std::string name) { if (pos != std::string::npos) { new_name.replace(pos, strlen(".processor"), ""); } + // if (starts_with(new_name, "transformer.transformer_blocks") || starts_with(new_name, "transformer.single_transformer_blocks")) { + // new_name = "model.diffusion_model." 
+ new_name; + // } pos = new_name.rfind("lora"); if (pos != std::string::npos) { std::string name_without_network_parts = new_name.substr(0, pos - 1); @@ -1354,6 +1365,7 @@ bool ModelLoader::init_from_ckpt_file(const std::string& file_path, const std::s SDVersion ModelLoader::get_sd_version() { TensorStorage token_embedding_weight; bool is_flux = false; + bool is_sd3 = false; for (auto& tensor_storage : tensor_storages) { if (tensor_storage.name.find("model.diffusion_model.guidance_in.in_layer.weight") != std::string::npos) { return VERSION_FLUX_DEV; @@ -1361,8 +1373,11 @@ SDVersion ModelLoader::get_sd_version() { if (tensor_storage.name.find("model.diffusion_model.double_blocks.") != std::string::npos) { is_flux = true; } + if (tensor_storage.name.find("joint_blocks.37.x_block.attn.ln_q.weight") != std::string::npos) { + return VERSION_SD3_5_8B; + } if (tensor_storage.name.find("model.diffusion_model.joint_blocks.23.") != std::string::npos) { - return VERSION_SD3_2B; + is_sd3 = true; } if (tensor_storage.name.find("conditioner.embedders.1") != std::string::npos) { return VERSION_SDXL; @@ -1387,6 +1402,9 @@ SDVersion ModelLoader::get_sd_version() { if (is_flux) { return VERSION_FLUX_SCHNELL; } + if (is_sd3) { + return VERSION_SD3_2B; + } if (token_embedding_weight.ne[0] == 768) { return VERSION_SD1; } else if (token_embedding_weight.ne[0] == 1024) { diff --git a/model.h b/model.h index 33f3fbcd..4efbdf81 100644 --- a/model.h +++ b/model.h @@ -25,6 +25,7 @@ enum SDVersion { VERSION_SD3_2B, VERSION_FLUX_DEV, VERSION_FLUX_SCHNELL, + VERSION_SD3_5_8B, VERSION_COUNT, }; diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 07b59bb8..4d28a147 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -31,7 +31,8 @@ const char* model_version_to_str[] = { "SVD", "SD3 2B", "Flux Dev", - "Flux Schnell"}; + "Flux Schnell", + "SD3.5 8B"}; const char* sampling_methods_str[] = { "Euler A", @@ -139,6 +140,7 @@ class StableDiffusionGGML { bool load_from_file(const 
std::string& model_path, const std::string& clip_l_path, + const std::string& clip_g_path, const std::string& t5xxl_path, const std::string& diffusion_model_path, const std::string& vae_path, @@ -167,7 +169,7 @@ class StableDiffusionGGML { for (int device = 0; device < ggml_backend_vk_get_device_count(); ++device) { backend = ggml_backend_vk_init(device); } - if(!backend) { + if (!backend) { LOG_WARN("Failed to initialize Vulkan backend"); } #endif @@ -181,7 +183,7 @@ class StableDiffusionGGML { backend = ggml_backend_cpu_init(); } #ifdef SD_USE_FLASH_ATTENTION -#if defined(SD_USE_CUBLAS) || defined(SD_USE_METAL) || defined (SD_USE_SYCL) || defined(SD_USE_VULKAN) +#if defined(SD_USE_CUBLAS) || defined(SD_USE_METAL) || defined(SD_USE_SYCL) || defined(SD_USE_VULKAN) LOG_WARN("Flash Attention not supported with GPU Backend"); #else LOG_INFO("Flash Attention enabled"); @@ -200,14 +202,21 @@ class StableDiffusionGGML { if (clip_l_path.size() > 0) { LOG_INFO("loading clip_l from '%s'", clip_l_path.c_str()); - if (!model_loader.init_from_file(clip_l_path, "text_encoders.clip_l.")) { + if (!model_loader.init_from_file(clip_l_path, "text_encoders.clip_l.transformer.")) { LOG_WARN("loading clip_l from '%s' failed", clip_l_path.c_str()); } } + if (clip_g_path.size() > 0) { + LOG_INFO("loading clip_g from '%s'", clip_g_path.c_str()); + if (!model_loader.init_from_file(clip_g_path, "text_encoders.clip_g.transformer.")) { + LOG_WARN("loading clip_g from '%s' failed", clip_g_path.c_str()); + } + } + if (t5xxl_path.size() > 0) { LOG_INFO("loading t5xxl from '%s'", t5xxl_path.c_str()); - if (!model_loader.init_from_file(t5xxl_path, "text_encoders.t5xxl.")) { + if (!model_loader.init_from_file(t5xxl_path, "text_encoders.t5xxl.transformer.")) { LOG_WARN("loading t5xxl from '%s' failed", t5xxl_path.c_str()); } } @@ -279,7 +288,7 @@ class StableDiffusionGGML { "try specifying SDXL VAE FP16 Fix with the --vae parameter. 
" "You can find it here: https://huggingface.co/madebyollin/sdxl-vae-fp16-fix/blob/main/sdxl_vae.safetensors"); } - } else if (version == VERSION_SD3_2B) { + } else if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B) { scale_factor = 1.5305f; } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { scale_factor = 0.3611; @@ -302,7 +311,7 @@ class StableDiffusionGGML { } else { clip_backend = backend; bool use_t5xxl = false; - if (version == VERSION_SD3_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { + if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { use_t5xxl = true; } if (!ggml_backend_is_cpu(backend) && use_t5xxl && conditioner_wtype != GGML_TYPE_F32) { @@ -313,7 +322,7 @@ class StableDiffusionGGML { LOG_INFO("CLIP: Using CPU backend"); clip_backend = ggml_backend_cpu_init(); } - if (version == VERSION_SD3_2B) { + if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B) { cond_stage_model = std::make_shared(clip_backend, conditioner_wtype); diffusion_model = std::make_shared(backend, diffusion_model_wtype, version); } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { @@ -511,7 +520,7 @@ class StableDiffusionGGML { is_using_v_parameterization = true; } - if (version == VERSION_SD3_2B) { + if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B) { LOG_INFO("running in FLOW mode"); denoiser = std::make_shared(); } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { @@ -939,7 +948,7 @@ class StableDiffusionGGML { if (use_tiny_autoencoder) { C = 4; } else { - if (version == VERSION_SD3_2B) { + if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B) { C = 32; } else if (version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { C = 32; @@ -1008,6 +1017,7 @@ struct sd_ctx_t { sd_ctx_t* new_sd_ctx(const char* model_path_c_str, const char* clip_l_path_c_str, + const 
char* clip_g_path_c_str, const char* t5xxl_path_c_str, const char* diffusion_model_path_c_str, const char* vae_path_c_str, @@ -1032,6 +1042,7 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str, } std::string model_path(model_path_c_str); std::string clip_l_path(clip_l_path_c_str); + std::string clip_g_path(clip_g_path_c_str); std::string t5xxl_path(t5xxl_path_c_str); std::string diffusion_model_path(diffusion_model_path_c_str); std::string vae_path(vae_path_c_str); @@ -1052,6 +1063,7 @@ sd_ctx_t* new_sd_ctx(const char* model_path_c_str, if (!sd_ctx->sd->load_from_file(model_path, clip_l_path, + clip_g_path, t5xxl_path_c_str, diffusion_model_path, vae_path, @@ -1269,7 +1281,7 @@ sd_image_t* generate_image(sd_ctx_t* sd_ctx, // Sample std::vector final_latents; // collect latents to decode int C = 4; - if (sd_ctx->sd->version == VERSION_SD3_2B) { + if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) { C = 16; } else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) { C = 16; @@ -1382,7 +1394,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, struct ggml_init_params params; params.mem_size = static_cast(10 * 1024 * 1024); // 10 MB - if (sd_ctx->sd->version == VERSION_SD3_2B) { + if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) { params.mem_size *= 3; } if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) { @@ -1408,7 +1420,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, std::vector sigmas = sd_ctx->sd->denoiser->get_sigmas(sample_steps); int C = 4; - if (sd_ctx->sd->version == VERSION_SD3_2B) { + if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) { C = 16; } else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) { C = 16; @@ -1416,7 +1428,7 @@ sd_image_t* txt2img(sd_ctx_t* sd_ctx, int W = width / 8; int H = height / 8; ggml_tensor* init_latent = 
ggml_new_tensor_4d(work_ctx, GGML_TYPE_F32, W, H, C, 1); - if (sd_ctx->sd->version == VERSION_SD3_2B) { + if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) { ggml_set_f32(init_latent, 0.0609f); } else if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) { ggml_set_f32(init_latent, 0.1159f); @@ -1477,7 +1489,7 @@ sd_image_t* img2img(sd_ctx_t* sd_ctx, struct ggml_init_params params; params.mem_size = static_cast(10 * 1024 * 1024); // 10 MB - if (sd_ctx->sd->version == VERSION_SD3_2B) { + if (sd_ctx->sd->version == VERSION_SD3_2B || sd_ctx->sd->version == VERSION_SD3_5_8B) { params.mem_size *= 2; } if (sd_ctx->sd->version == VERSION_FLUX_DEV || sd_ctx->sd->version == VERSION_FLUX_SCHNELL) { diff --git a/stable-diffusion.h b/stable-diffusion.h index 0d4cc1fd..812e8fc9 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -124,6 +124,7 @@ typedef struct sd_ctx_t sd_ctx_t; SD_API sd_ctx_t* new_sd_ctx(const char* model_path, const char* clip_l_path, + const char* clip_g_path, const char* t5xxl_path, const char* diffusion_model_path, const char* vae_path, diff --git a/vae.hpp b/vae.hpp index 85319fde..42b694cd 100644 --- a/vae.hpp +++ b/vae.hpp @@ -457,7 +457,7 @@ class AutoencodingEngine : public GGMLBlock { bool use_video_decoder = false, SDVersion version = VERSION_SD1) : decode_only(decode_only), use_video_decoder(use_video_decoder) { - if (version == VERSION_SD3_2B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { + if (version == VERSION_SD3_2B || version == VERSION_SD3_5_8B || version == VERSION_FLUX_DEV || version == VERSION_FLUX_SCHNELL) { dd_config.z_channels = 16; use_quant = false; }