Commit f112432

Merge pull request #87 from noskill/ref

refactor pipe weightsharing for quantised models

Necr0x0Der authored Dec 4, 2024
2 parents a56ae2e + 11913a0 commit f112432

Showing 10 changed files with 508 additions and 255 deletions.
9 changes: 9 additions & 0 deletions .github/workflows/python-app.yml
@@ -53,9 +53,18 @@ jobs:
        env:
          METAFUSION_MODELS_DIR: models-full
        run: |
          cd tests && python pipe_test.py
      - name: Test loader
        env:
          METAFUSION_MODELS_DIR: models-full
        run: |
          cd tests && python test_loader.py
      - name: Test worker
        env:
          METAFUSION_MODELS_DIR: models-full
        run: |
          cd tests && python test_worker.py
      - name: Test worker flux
        env:
          METAFUSION_MODELS_DIR: models-full
        run: |
          cd tests && python test_worker_flux.py
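Each step points METAFUSION_MODELS_DIR at a local directory of checkpoints. A minimal sketch of how a test can pick the variable up, assuming the tests resolve it via os.environ (the 'models' fallback here is hypothetical):

```python
import os

# Resolve the models directory the same way the CI steps configure it;
# 'models' is a hypothetical fallback for local runs without the env var.
MODELS_DIR = os.environ.get('METAFUSION_MODELS_DIR', 'models')
```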
143 changes: 59 additions & 84 deletions multigen/loader.py
@@ -1,4 +1,5 @@
-from typing import Type, List
+from typing import Type, List, Union, Optional, Any
+from dataclasses import dataclass
import random
import copy as cp
from contextlib import nullcontext
@@ -10,44 +11,32 @@
import diffusers

from diffusers import DiffusionPipeline, StableDiffusionControlNetPipeline, StableDiffusionXLControlNetPipeline
-from diffusers.utils import is_accelerate_available
-if is_accelerate_available():
-    from accelerate import init_empty_weights
-else:
-    init_empty_weights = nullcontext

+from .util import get_model_size, awailable_ram, quantize, weightshare_copy


logger = logging.getLogger(__file__)


-def weightshare_copy(pipe):
-    """
-    Create a new pipe object then assign weights using load_state_dict from passed 'pipe'
-    """
-    copy = pipe.__class__(**pipe.components)
-    ctx = init_empty_weights if is_accelerate_available() else nullcontext
-    with ctx():
-        for key, component in copy.components.items():
-            if getattr(copy, key) is None:
-                continue
-            if key in ('tokenizer', 'tokenizer_2', 'feature_extractor'):
-                setattr(copy, key, cp.deepcopy(getattr(copy, key)))
-                continue
-            cls = getattr(copy, key).__class__
-            if hasattr(cls, 'from_config'):
-                setattr(copy, key, cls.from_config(getattr(copy, key).config))
-            else:
-                setattr(copy, key, cls(getattr(copy, key).config))
-    # assign=True is needed since our copy is on "meta" device, i.e. weights are empty
-    for key, component in copy.components.items():
-        if key == 'tokenizer' or key == 'tokenizer_2':
-            continue
-        obj = getattr(copy, key)
-        if hasattr(obj, 'load_state_dict'):
-            obj.load_state_dict(getattr(pipe, key).state_dict(), assign=True)
-    # some buffers might not be transferred from pipe to copy
-    copy.to(pipe.device)
-    return copy
+@dataclass(frozen=True)
+class ModelDescriptor:
+    """
+    Descriptor class for model identification that includes quantization information
+    """
+    model_id: str
+    quantize_dtype: Optional[Any] = None
+
+    def __hash__(self):
+        return hash((self.model_id, str(self.quantize_dtype)))
+
+    def __eq__(self, other):
+        if isinstance(other, str):
+            return self.model_id == other
+
+        if not isinstance(other, ModelDescriptor):
+            return False
+        return (self.model_id == other.model_id and
+                self.quantize_dtype == other.quantize_dtype)


class Loader:
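Taken on its own, the new descriptor gives both caches dtype-aware keys while still answering comparisons against a bare model path. A quick sketch of those semantics; the model id and dtype values are illustrative stand-ins:

```python
import torch

fp = ModelDescriptor('stabilityai/sdxl')              # unquantized entry
q8 = ModelDescriptor('stabilityai/sdxl', torch.qint8)  # quantized variant

cache = {fp: 'bf16 pipe', q8: 'qint8 pipe'}
assert fp != q8              # dtype participates in __eq__ ...
assert len(cache) == 2       # ... and in __hash__, so both variants coexist
assert fp == 'stabilityai/sdxl'   # comparing against a bare path still matches
# Note: a dict lookup by the bare string would miss (different hash), which is
# why the Loader below scans items() and compares keys explicitly.
```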
@@ -56,9 +45,8 @@ class for loading diffusion pipelines from files.
"""
def __init__(self):
self._lock = threading.RLock()
self._cpu_pipes = dict()
# idx -> list of (model_id, pipe) pairs
self._gpu_pipes = dict()
self._cpu_pipes = dict() # ModelDescriptor -> pipe
self._gpu_pipes = dict() # gpu idx -> list of (ModelDescriptor, pipe) pairs

def get_gpu(self, model_id) -> List[int]:
"""
@@ -73,24 +61,29 @@ def get_gpu(self, model_id) -> List[int]:
        return result

    def load_pipeline(self, cls: Type[DiffusionPipeline], path, torch_dtype=torch.bfloat16,
-                      device=None, offload_device=None, **additional_args):
+                      device=None, offload_device=None, quantize_dtype=None, **additional_args):
        with self._lock:
            logger.debug(f'looking for pipeline {cls} from {path} on {device}')
            result = None
+            descriptor = ModelDescriptor(path, quantize_dtype)
+            found_quantized = False
            if device is None:
                device = torch.device('cpu', 0)
            if device.type == 'cuda':
                idx = device.index
                gpu_pipes = self._gpu_pipes.get(idx, [])
                for (key, value) in gpu_pipes:
-                    if key == path:
+                    if key == descriptor:
                        logger.debug(f'found pipe in gpu cache {key}')
                        result = self.from_pipe(cls, value, additional_args)
                        logger.debug(f'created pipe from gpu cache {key} on {device}')
                        return result
            for (key, pipe) in self._cpu_pipes.items():
-                if key == path:
+                if key == descriptor:
+                    found_quantized = True
                    logger.debug(f'found pipe in cpu cache {key} {pipe.device}')
+                    if device.type == 'cuda':
+                        pipe = cp.deepcopy(pipe)
                    result = self.from_pipe(cls, pipe, additional_args)
                    break
            if result is None:
@@ -106,16 +99,26 @@ def load_pipeline(self, cls: Type[DiffusionPipeline], path, torch_dtype=torch.bf
logger.debug("prepare pipe before returning from loader")
logger.debug(f"{path} on {result.device} {result.dtype}")

# Add quantization if specified
if (not found_quantized) and quantize_dtype is not None:
logger.debug(f'Quantizing pipeline to {quantize_dtype}')
quantize(result, dtype=quantize_dtype)

if result.device != device:
logger.debug(f"move pipe to {device}")
result = result.to(dtype=torch_dtype, device=device)
if result.dtype != torch_dtype:
result = result.to(dtype=torch_dtype)

self.cache_pipeline(result, path)
logger.debug(f'result device before weightshare_copy {result.device}')
result = weightshare_copy(result)
logger.debug(f'result device after weightshare_copy {result.device}')
assert result.device.type == device.type
if device.type == 'cuda':
assert result.device.index == device.index
logger.debug(f'returning {type(result)} from {path} on {result.device}')
logger.debug(f'returning {type(result)} from {path} \
on {result.device} scheduler {id(result.scheduler)}')
return result

def from_pipe(self, cls, pipe, additional_args):
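A usage sketch of the refactored entry point; the model path, device, and quantization dtype are illustrative stand-ins, and DiffusionPipeline stands in for whatever pipeline class a caller requests:

```python
import torch
from diffusers import DiffusionPipeline
from multigen.loader import Loader

loader = Loader()
# First call: loads from disk, quantizes once, and caches the result under
# ModelDescriptor('models-full/sdxl-base', torch.qint8).
pipe = loader.load_pipeline(DiffusionPipeline, 'models-full/sdxl-base',
                            device=torch.device('cuda', 0),
                            quantize_dtype=torch.qint8)
# Second call: hits the cache (found_quantized is True, so no re-quantization)
# and returns a weight-sharing copy with its own scheduler instance.
pipe2 = loader.load_pipeline(DiffusionPipeline, 'models-full/sdxl-base',
                             device=torch.device('cuda', 0),
                             quantize_dtype=torch.qint8)
assert pipe.scheduler is not pipe2.scheduler
```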
@@ -131,86 +134,58 @@ def from_pipe(self, cls, pipe, additional_args):
            components.pop('controlnet')
        return cls(**components, **additional_args)

-    def cache_pipeline(self, pipe: DiffusionPipeline, model_id):
+    def cache_pipeline(self, pipe: DiffusionPipeline, descriptor: ModelDescriptor):
+        logger.debug(f'caching pipeline {descriptor} {pipe}')
        with self._lock:
            device = pipe.device
-            if model_id not in self._cpu_pipes:
+            if descriptor not in self._cpu_pipes:
                # deepcopy is needed since Module.to is an inplace operation
                size = get_model_size(pipe)
                ram = awailable_ram()
-                logger.debug(f'{model_id} has size {size}, ram {ram}')
+                logger.debug(f'{descriptor} has size {size}, ram {ram}')
                if ram < size * 2.5 and self._cpu_pipes:
                    key_to_delete = random.choice(list(self._cpu_pipes.keys()))
                    self._cpu_pipes.pop(key_to_delete)
                item = pipe
                if pipe.device.type == 'cuda':
-                    item = cp.deepcopy(pipe).to('cpu')
-                self._cpu_pipes[model_id] = item
-                logger.debug(f'storing {model_id} on cpu')
+                    device = pipe.device
+                    logger.debug("deepcopy pipe from gpu to save it in cpu cache")
+                    item = cp.deepcopy(pipe.to('cpu'))
+                    pipe.to(device)
+                self._cpu_pipes[descriptor] = item
+                logger.debug(f'storing {descriptor} on cpu')
+                assert pipe.device == device
            if pipe.device.type == 'cuda':
-                self._store_gpu_pipe(pipe, model_id)
-                logger.debug(f'storing {model_id} on {pipe.device}')
+                self._store_gpu_pipe(pipe, descriptor)
+                logger.debug(f'storing {descriptor} on {pipe.device}')

    def clear_cache(self, device):
        logger.debug(f'clear_cache pipelines from {device}')
        with self._lock:
            if device.type == 'cuda':
                self._gpu_pipes[device.index] = []

-    def _store_gpu_pipe(self, pipe, model_id):
+    def _store_gpu_pipe(self, pipe, descriptor: ModelDescriptor):
        idx = pipe.device.index
        assert idx is not None
        # for now just clear all other pipelines
-        self._gpu_pipes[idx] = [(model_id, pipe)]
+        self._gpu_pipes[idx] = [(descriptor, pipe)]

    def remove_pipeline(self, model_id):
        self._cpu_pipes.pop(model_id)

-    def get_pipeline(self, model_id, device=None):
+    def get_pipeline(self, descriptor: Union[ModelDescriptor, str], device=None):
        with self._lock:
            if device is None:
                device = torch.device('cuda' if torch.cuda.is_available() else 'cpu', 0)
            if device.type == 'cuda':
                idx = device.index
                gpu_pipes = self._gpu_pipes.get(idx, ())
                for (key, value) in gpu_pipes:
-                    if key == model_id:
+                    if key == descriptor:
                        return value
            for (key, pipe) in self._cpu_pipes.items():
-                if key == model_id:
+                if key == descriptor:
                    return pipe

            return None


-def count_params(model):
-    total_size = sum(param.numel() for param in model.parameters())
-    mul = 2
-    if model.dtype in (torch.float16, torch.bfloat16):
-        mul = 2
-    elif model.dtype == torch.float32:
-        mul = 4
-    return total_size * mul
-
-
-def get_size(obj):
-    return sys.getsizeof(obj)
-
-
-def get_model_size(pipeline):
-    total_size = 0
-    for name, component in pipeline.components.items():
-        if isinstance(component, torch.nn.Module):
-            total_size += count_params(component)
-        elif hasattr(component, 'tokenizer'):
-            total_size += count_params(component.tokenizer)
-        else:
-            total_size += get_size(component)
-    return total_size / (1024 * 1024)
-
-
-def awailable_ram():
-    mem = psutil.virtual_memory()
-    available_ram = mem.available
-    return available_ram / (1024 * 1024)
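The removed weightshare_copy helper (along with get_model_size and awailable_ram) now appears to live in multigen/util.py, judging by the new import at the top of loader.py. Its core trick deserves a standalone illustration: build an uninitialized skeleton on the meta device, then hand it the source module's tensors with assign=True so both share storage. A minimal sketch on a plain nn.Module, assuming PyTorch 2.1+ for load_state_dict(..., assign=True):

```python
import torch

src = torch.nn.Linear(4, 4)

# Allocate an empty skeleton: meta tensors carry shape/dtype but no storage.
with torch.device('meta'):
    clone = torch.nn.Linear(4, 4)

# assign=True swaps the meta parameters for src's actual tensors instead of
# trying to copy values into (nonexistent) meta storage.
clone.load_state_dict(src.state_dict(), assign=True)

assert clone.weight.data_ptr() == src.weight.data_ptr()  # storage is shared
```

Because only tensor storage is shared, each copy keeps its own component objects, which is why the loader can hand every caller a pipe with a distinct scheduler (hence the scheduler id in the debug log above).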
19 changes: 11 additions & 8 deletions multigen/pipes.py
@@ -149,8 +149,10 @@ def _get_model_type(self):
            return ModelType.SD
        elif module.startswith('diffusers.pipelines.flux.pipeline_flux'):
            return ModelType.FLUX
+        elif 'masked_stable_diffusion_xl_img2img' in module:
+            return ModelType.SDXL
        else:
-            raise RuntimeError("unsuported model type {self.pipe.__class__}")
+            raise RuntimeError(f"unsuported model type {self.pipe.__class__}")

    def _initialize_pipe(self, device, offload_device):
        # sometimes text encoder is on a different device
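The dispatch above keys off the defining module of the wrapped pipeline class. A quick sketch of what it inspects; FluxPipeline is a real diffusers class, and the module path is assumed from the startswith check above:

```python
from diffusers import FluxPipeline

# __module__ names the defining module, even for a re-exported class.
module = FluxPipeline.__module__
assert module.startswith('diffusers.pipelines.flux.pipeline_flux')
```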
@@ -744,7 +746,8 @@ def __init__(self, model_id, pipe: Optional[StableDiffusionControlNetPipeline] =
            else:
                raise RuntimeError(f"Unexpected model type {type(self.pipe)}")
            self.model_type = t_model_type
-            logging.debug(f"from_pipe source dtype {self.pipe.dtype}")
+            device = self.pipe.device
+            logging.debug(f"from_pipe source dtype {self.pipe.dtype} {device}")
            cnets = self._load_cnets(cnets, cnet_ids, args.get('offload_device', None), self.pipe.dtype)
            prev_dtype = self.pipe.dtype
            if self.model_type == ModelType.SDXL:
@@ -754,11 +757,11 @@ def __init__(self, model_id, pipe: Optional[StableDiffusionControlNetPipeline] =
            else:
                self.pipe = self._class.from_pipe(self.pipe, controlnet=cnets)
            logging.debug(f"after from_pipe result dtype {self.pipe.dtype}")
-            for cnet in cnets:
-                cnet.to(prev_dtype)
-                logging.debug(f'moving cnet {id(cnet)} to self.pipe.dtype {prev_dtype}')
-                if 'offload_device' not in args:
-                    cnet.to(self.pipe.device)
+            for cnet in cnets:
+                cnet.to(prev_dtype)
+                logging.debug(f'moving cnet {id(cnet)} to self.pipe.dtype {prev_dtype}')
+                if 'offload_device' not in args:
+                    cnet.to(device)
        else:
            # don't load anything, just reuse pipe
            super().__init__(model_id=model_id, pipe=pipe, **args)
@@ -1052,7 +1055,7 @@ def __init__(self, model_id, pipe: Optional[StableDiffusionControlNetPipeline] =
"""
dtype = torch.float32
if torch.cuda.is_available():
dtype = torch.float16
dtype = torch.bfloat16
dtype = args.get('torch_type', dtype)
cnet = ControlNetModel.from_pretrained(
Cond2ImPipe.cpath+Cond2ImPipe.cmodels["inpaint"], torch_dtype=dtype)
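The GPU default moves from float16 to bfloat16, while an explicit torch_type argument still wins. A small sketch of that precedence, assuming only that args is a plain dict as in the constructor above:

```python
import torch

def pick_dtype(args: dict) -> torch.dtype:
    dtype = torch.float32
    if torch.cuda.is_available():
        dtype = torch.bfloat16       # new GPU default (was torch.float16)
    return args.get('torch_type', dtype)

# An explicit caller choice overrides the default entirely.
assert pick_dtype({'torch_type': torch.float16}) == torch.float16
```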