Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/dkackman/chiaSWARM
Browse files Browse the repository at this point in the history
  • Loading branch information
dkackman committed Jan 10, 2024
2 parents 69e279a + 4d79814 commit 587df1e
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 7 deletions.
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ RUN python -m pip install --upgrade pip
RUN python -m pip install wheel setuptools

RUN pip install torch torchvision torchaudio
RUN pip install diffusers[torch] transformers accelerate scipy ftfy safetensors moviepy opencv-python sentencepiece
RUN pip install diffusers[torch] transformers accelerate scipy ftfy safetensors moviepy opencv-python sentencepiece peft
RUN pip install aiohttp concurrent-log-handler pydub controlnet_aux qrcode matplotlib PyWavelets
RUN pip install --no-deps invisible-watermark
RUN pip install git+https://github.com/suno-ai/bark.git@main
Expand Down
2 changes: 1 addition & 1 deletion Install.ps1
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ if ($LASTEXITCODE -ne 0) { exit $LASTEXITCODE }
python -m pip install torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu118
if ($LASTEXITCODE -ne 0) { exit $LASTEXITCODE }

python -m pip install diffusers[torch] transformers accelerate scipy ftfy safetensors moviepy opencv-python sentencepiece
python -m pip install diffusers[torch] transformers accelerate scipy ftfy safetensors moviepy opencv-python sentencepiece peft
if ($LASTEXITCODE -ne 0) { exit $LASTEXITCODE }

python -m pip install aiohttp concurrent-log-handler pydub controlnet_aux qrcode matplotlib PyWavelets
Expand Down
2 changes: 1 addition & 1 deletion install.sh
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,7 @@ python -m pip install --upgrade pip
python -m pip install wheel setuptools

pip install torch torchvision torchaudio
pip install diffusers[torch] transformers accelerate scipy ftfy safetensors moviepy opencv-python sentencepiece
pip install diffusers[torch] transformers accelerate scipy ftfy safetensors moviepy opencv-python sentencepiece peft
pip install aiohttp concurrent-log-handler pydub controlnet_aux qrcode matplotlib PyWavelets
pip install --no-deps invisible-watermark
pip install git+https://github.com/suno-ai/bark.git@main
Expand Down
9 changes: 5 additions & 4 deletions swarm/diffusion/diffusion_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,16 @@ def diffusion_callback(device_identifier, model_name, **kwargs):
kwargs.pop("content_type", "image/jpeg"),
)

torch_dtype = torch.bfloat16 if kwargs.pop("use_bfloat16", False) else torch.float16
load_pipeline_args = {}
load_pipeline_args["revision"] = kwargs.pop("revision", "main")
load_pipeline_args["variant"] = kwargs.pop("variant", None)
load_pipeline_args["torch_dtype"] = torch.float16
load_pipeline_args["torch_dtype"] = torch_dtype
load_pipeline_args["use_safe_tensors"] = kwargs.pop("use_safe_tensors", None)

if "vae" in kwargs:
load_pipeline_args["vae"] = AutoencoderKL.from_pretrained(
kwargs.pop("vae"), torch_dtype=torch.float16
kwargs.pop("vae"), torch_dtype=torch_dtype
).to(device_identifier)

# if there is a controlnet load and configure it
Expand All @@ -52,7 +53,7 @@ def diffusion_callback(device_identifier, model_name, **kwargs):
load_pipeline_args["controlnet"] = controlnet_model_type.from_pretrained(
kwargs.pop("controlnet_model_name"),
revision=kwargs.pop("controlnet_revision", "main"),
torch_dtype=torch.float16,
torch_dtype=torch_dtype,
).to(device_identifier)

if kwargs.pop("save_preprocessed_input", False):
Expand All @@ -78,7 +79,7 @@ def diffusion_callback(device_identifier, model_name, **kwargs):
model_name,
controlnet=load_pipeline_args["controlnet"],
vae=load_pipeline_args.get("vae", None),
torch_dtype=torch.float16,
torch_dtype=torch_dtype,
).to(device_identifier)
# take out the original control_image
control_image = kwargs.pop("control_image", None)
Expand Down

0 comments on commit 587df1e

Please sign in to comment.