Stable Diffusion Rough Draft #12

Open · wants to merge 1 commit into main
265 changes: 265 additions & 0 deletions poet/architectures/diffusion.py
@@ -0,0 +1,265 @@
from poet.power_computation import *
from poet.power_computation_transformer import *


def stablediffusion_256(training, iterations, text_encoding, noise=None, image=(3, 256, 256)):
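    """Build a Stable-Diffusion-style layer graph for POET's cost estimation.

    Training mode starts from an RGB image and runs the VAE encoder before the
    `iterations` U-Net denoising passes; inference mode starts from a noise latent,
    runs the U-Net passes, and finishes with the VAE decoder. Returns the flat
    list of layers."""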
    # InputLayer takes the full shape tuple (as in unet_256 below), so pass the tuple directly
    layers = [InputLayer(image)] if training else [InputLayer(noise)]
    if training:
        vaeencodeblock(layers)
    for _ in range(iterations):
        unet(layers, text_encoding)
    if not training:
        vaedecodeblock(layers)
    return layers


def vaeencodeblock(layers):
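    """VAE encoder: 3 -> 128 channels through stacked residual blocks, four stride-2
    downsampling convolutions, and one attention block before the output convolution.
    Appends to `layers` in place."""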
    layers.append(Conv2dLayer(3, 3, (3,3), 1, (1, 1), layers[-1]))
    layers.extend(vaeresnetshort(3, 32, 1, (1,1), layers[-1]))
    layers.extend(vaeresnet(32, 32, 1, (1,1), layers[-1]))
    layers.extend(vaeresnet(32, 32, 1, (1,1), layers[-1]))
    layers.append(Conv2dLayer(32, 32, (3,3), 2, (0, 0), layers[-1]))
    layers.extend(vaeresnetshort(32, 64, 1, (1,1), layers[-1]))
    layers.extend(vaeresnet(64, 64, 1, (1,1), layers[-1]))
    layers.extend(vaeresnet(64, 64, 1, (1,1), layers[-1]))
    layers.append(Conv2dLayer(64, 64, (3,3), 2, (0, 0), layers[-1]))
    layers.extend(vaeresnetshort(64, 128, 1, (1,1), layers[-1]))
    layers.extend(vaeresnet(128, 128, 1, (1,1), layers[-1]))
    layers.extend(vaeresnet(128, 128, 1, (1,1), layers[-1]))
    layers.append(Conv2dLayer(128, 128, (3,3), 2, (0, 0), layers[-1]))
    layers.extend(vaeresnet(128, 128, 1, (1,1), layers[-1]))
    layers.extend(vaeresnet(128, 128, 1, (1,1), layers[-1]))
    layers.extend(vaeresnet(128, 128, 1, (1,1), layers[-1]))
    layers.append(Conv2dLayer(128, 128, (3,3), 2, (0, 0), layers[-1]))
    layers.extend(vaeresnet(128, 128, 1, (1,1), layers[-1]))
    attention_block(layers)
    layers.extend(vaeresnet(128, 128, 1, (1,1), layers[-1]))
    layers.append(BatchNorm2d(layers[-1]))
    layers.append(Conv2dLayer(128, 128, (3,3), 1, (1, 1), layers[-1]))
    return

def vaedecodeblock(layers):
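    """VAE decoder: mirrors the encoder's channel schedule, going from 128 back down to
    3 channels, with an attention block after the first residual block. Appends to
    `layers` in place."""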
    layers.append(Conv2dLayer(128, 128, (3,3), 1, (1, 1), layers[-1]))
    layers.extend(vaeresnet(128, 128, 1, (1,1), layers[-1]))
    attention_block(layers)
    layers.extend(vaeresnet(128, 128, 1, (1,1), layers[-1]))
    layers.extend(vaeresnet(128, 128, 1, (1,1), layers[-1]))
    layers.extend(vaeresnet(128, 128, 1, (1,1), layers[-1]))
    layers.append(Conv2dLayer(128, 128, (3,3), 2, (0, 0), layers[-1]))
    layers.extend(vaeresnetshort(128, 64, 1, (1,1), layers[-1]))
    layers.extend(vaeresnet(64, 64, 1, (1,1), layers[-1]))
    layers.extend(vaeresnet(64, 64, 1, (1,1), layers[-1]))
    layers.append(Conv2dLayer(64, 64, (3,3), 2, (0, 0), layers[-1]))
    layers.extend(vaeresnetshort(64, 32, 1, (1,1), layers[-1]))
    layers.extend(vaeresnet(32, 32, 1, (1,1), layers[-1]))
    layers.extend(vaeresnet(32, 32, 1, (1,1), layers[-1]))
    layers.append(Conv2dLayer(32, 32, (3,3), 2, (0, 0), layers[-1]))
    layers.extend(vaeresnetshort(32, 3, 1, (1,1), layers[-1]))
    layers.extend(vaeresnet(3, 3, 1, (1,1), layers[-1]))
    layers.extend(vaeresnet(3, 3, 1, (1,1), layers[-1]))
    layers.append(Conv2dLayer(3, 3, (3,3), 2, (0, 0), layers[-1]))
    layers.append(BatchNorm2d(layers[-1]))
    layers.append(Conv2dLayer(3, 3, (3,3), 1, (1, 1), layers[-1]))
    return

def unet(layers, attention_input):
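    """One U-Net denoising pass appended to `layers`: a small MLP for the timestep
    embedding, a downsampling path that records 128/256/512-channel skip connections,
    a middle segment with attention, and an upsampling path that rejoins the skips
    through CropConcatLayer. `attention_input` (the text encoding) is not wired into
    the attention blocks yet."""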
    # to add: proper timestep embedding structure
    layers.append(LinearLayer(128, 512, None))
    layers.append(SigmoidLayer(layers[-1]))
    layers.append(LinearLayer(512, 512, layers[-1]))
    # Downsampling
    timestepembed = layers[-1]
    layers.extend(resnetblock(128, 128, 1, (1,1), layers[-4], timestepembed))
    layers.extend(resnetblock(128, 128, 1, (1,1), layers[-1], timestepembed))
    layers.append(Conv2dLayer(128, 128, (3,3), 2, (0, 0), layers[-1]))
    layers.extend(resnetblock(128, 128, 1, (1,1), layers[-1], timestepembed))
    layers.extend(resnetblock(128, 128, 1, (1,1), layers[-1], timestepembed))
    layers.append(skip128 := Conv2dLayer(128, 128, (3,3), 2, (0, 0), layers[-1]))
    layers.extend(resnetshortblock(128, 256, 1, (1,1), layers[-1], timestepembed))
    layers.extend(resnetblock(256, 256, 1, (1,1), layers[-1], timestepembed))
    layers.append(Conv2dLayer(256, 256, (3,3), 2, (0, 0), layers[-1]))
    layers.extend(resnetblock(256, 256, 1, (1,1), layers[-1], timestepembed))
    layers.extend(resnetblock(256, 256, 1, (1,1), layers[-1], timestepembed))
    layers.append(Conv2dLayer(256, 256, (3,3), 2, (0, 0), layers[-1]))
    layers.extend(resnetblock(256, 256, 1, (1,1), layers[-1], timestepembed))
    layers.extend(resnetblock(256, 256, 1, (1,1), layers[-1], timestepembed))
    layers.append(skip256 := Conv2dLayer(256, 256, (3,3), 2, (0, 0), layers[-1]))
    attention_block(layers)
    layers.extend(resnetshortblock(256, 512, 1, (1,1), layers[-1], timestepembed))
    attention_block(layers)
    layers.extend(resnetblock(512, 512, 1, (1,1), layers[-1], timestepembed))
    layers.append(Conv2dLayer(512, 512, (3,3), 2, (0, 0), layers[-1]))
    layers.extend(resnetblock(512, 512, 1, (1,1), layers[-1], timestepembed))
    # capture the block's final layer (not the whole list) as the 512-channel skip connection
    layers.extend(resnetblock(512, 512, 1, (1,1), layers[-1], timestepembed))
    skip512 = layers[-1]

    # Middle segment
    layers.extend(resnetblock(512, 512, 1, (1,1), layers[-1], timestepembed))
    attention_block(layers)
    layers.extend(resnetblock(512, 512, 1, (1,1), layers[-1], timestepembed))

    # Upsampling
    layers.append(CropConcatLayer(layers[-1], skip512))
    layers.extend(resnetshortblock(1024, 512, 1, (1,1), layers[-1], timestepembed))
    layers.append(CropConcatLayer(layers[-1], skip512))
    layers.extend(resnetshortblock(1024, 512, 1, (1,1), layers[-1], timestepembed))
    layers.append(CropConcatLayer(layers[-1], skip512))
    layers.extend(resnetshortblock(1024, 512, 1, (1,1), layers[-1], timestepembed))
    layers.append(Conv2dLayer(512, 512, (3,3), 1, (1, 1), layers[-1]))
    attention_block(layers)
    layers.append(CropConcatLayer(layers[-1], skip512))
    layers.extend(resnetshortblock(1024, 512, 1, (1,1), layers[-1], timestepembed))
    attention_block(layers)
    layers.append(CropConcatLayer(layers[-1], skip512))
    layers.extend(resnetshortblock(1024, 512, 1, (1,1), layers[-1], timestepembed))
    attention_block(layers)
    layers.append(CropConcatLayer(layers[-1], skip256))
    layers.extend(resnetshortblock(768, 512, 1, (1,1), layers[-1], timestepembed))
    layers.append(Conv2dLayer(512, 512, (3,3), 1, (1, 1), layers[-1]))
    layers.append(CropConcatLayer(layers[-1], skip256))
    layers.extend(resnetshortblock(768, 256, 1, (1,1), layers[-1], timestepembed))
    layers.append(CropConcatLayer(layers[-1], skip256))
    layers.extend(resnetshortblock(512, 256, 1, (1,1), layers[-1], timestepembed))
    layers.append(CropConcatLayer(layers[-1], skip256))
    layers.extend(resnetshortblock(512, 256, 1, (1,1), layers[-1], timestepembed))
    layers.append(Conv2dLayer(256, 256, (3,3), 1, (1, 1), layers[-1]))
    layers.append(CropConcatLayer(layers[-1], skip256))
    layers.extend(resnetshortblock(512, 256, 1, (1,1), layers[-1], timestepembed))
    layers.append(CropConcatLayer(layers[-1], skip256))
    layers.extend(resnetshortblock(512, 256, 1, (1,1), layers[-1], timestepembed))
    layers.append(CropConcatLayer(layers[-1], skip128))
    layers.extend(resnetshortblock(384, 256, 1, (1,1), layers[-1], timestepembed))
    layers.append(Conv2dLayer(256, 256, (3,3), 1, (1, 1), layers[-1]))
    layers.append(CropConcatLayer(layers[-1], skip128))
    layers.extend(resnetshortblock(384, 128, 1, (1,1), layers[-1], timestepembed))
    layers.append(CropConcatLayer(layers[-1], skip128))
    layers.extend(resnetshortblock(256, 128, 1, (1,1), layers[-1], timestepembed))
    layers.append(CropConcatLayer(layers[-1], skip128))
    layers.extend(resnetshortblock(256, 128, 1, (1,1), layers[-1], timestepembed))
    layers.append(Conv2dLayer(128, 128, (3,3), 1, (1, 1), layers[-1]))
    layers.append(CropConcatLayer(layers[-1], skip128))
    layers.extend(resnetshortblock(256, 128, 1, (1,1), layers[-1], timestepembed))
    layers.append(CropConcatLayer(layers[-1], skip128))
    layers.extend(resnetshortblock(256, 128, 1, (1,1), layers[-1], timestepembed))
    layers.append(CropConcatLayer(layers[-1], skip128))
    layers.extend(resnetshortblock(256, 128, 1, (1,1), layers[-1], timestepembed))
    layers.append(BatchNorm2d(layers[-1]))
    layers.append(SigmoidLayer(layers[-1]))
    return

def attention_block(layers, SEQ_LEN=512, HIDDEN_DIM=768, I=64, HEADS=12):
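    """Transformer-style block over the current layer: QKV projection, the QK^T and
    (QK^T)V matrix products, an output projection with a residual add, then a
    two-layer feed-forward network with a second residual add."""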
    input_layer = layers[-1]
    layers.append(QueryKeyValueMatrix(SEQ_LEN, HIDDEN_DIM, I, HEADS, layers[-1]))  # QKV projection
    layers.append(QKTMatrix(SEQ_LEN=SEQ_LEN, HIDDEN_DIM=I, I=SEQ_LEN, ATTN_HEADS=HEADS, input=layers[-1]))  # QK^T
    layers.append(QKTVMatrix(SEQ_LEN, SEQ_LEN, I, HEADS, layers[-1]))  # QK^TV
    layers.append(LinearLayer(I * HEADS, HIDDEN_DIM, layers[-1]))
    # Residual
    layers.append(SkipAddLayer(input_layer, layers[-1]))
    # FFNs
    layers.append(LinearLayerReLU(HIDDEN_DIM, HIDDEN_DIM * 4, layers[-1]))
    layers.append(LinearLayer(HIDDEN_DIM * 4, HIDDEN_DIM, layers[-1]))
    layers.append(SkipAddLayer(layers[-4], layers[-1]))
    return

#Resnet with shorting built in
def resnetshortblock(in_planes, planes, stride, padding, x, time_emb):
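    """Residual block with a 1x1 projection shortcut and an additive timestep-embedding
    branch: BatchNorm -> Conv -> add time embedding -> BatchNorm -> Dropout -> Conv ->
    shortcut add -> activation. Returns the new layers as a list."""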
    kernel_size = (3, 3)
    bn1 = BatchNorm2d(x)
    conv1 = Conv2dLayer(in_planes, planes, kernel_size, stride, padding, bn1)
    skipconv = Conv2dLayer(in_planes, planes, (1, 1), stride, padding, x)
    linear = LinearLayer(in_planes, in_planes//2, time_emb)
    skip1 = SkipAddLayer(conv1, linear)
    bn2 = BatchNorm2d(skip1)
    dropout = DropoutLayer(bn2)
    conv2 = Conv2dLayer(planes, planes, kernel_size, 1, padding, dropout)
    skip2 = SkipAddLayer(skipconv, conv2)
    silu = SigmoidLayer(skip2)
    return [bn1, conv1, skipconv, linear, skip1, bn2, dropout, conv2, skip2, silu]


def resnetblock(in_planes, planes, stride, padding, x, time_emb):
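    """Time-conditioned conv block without a shortcut: BatchNorm -> Conv -> add time
    embedding -> BatchNorm -> Dropout -> Conv -> activation. Returns the new layers
    as a list."""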
    kernel_size = (3, 3)
    bn1 = BatchNorm2d(x)
    conv1 = Conv2dLayer(in_planes, planes, kernel_size, stride, padding, bn1)
    linear = LinearLayer(in_planes, in_planes//2, time_emb)
    skip1 = SkipAddLayer(conv1, linear)
    bn2 = BatchNorm2d(skip1)
    dropout = DropoutLayer(bn2)
    conv2 = Conv2dLayer(planes, planes, kernel_size, 1, padding, dropout)
    silu = SigmoidLayer(conv2)
    return [bn1, conv1, linear, skip1, bn2, dropout, conv2, silu]

def vaeresnet(in_planes, planes, stride, padding, x):
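    """Plain VAE conv block (no skip add): BatchNorm -> Conv -> BatchNorm -> Dropout ->
    Conv -> activation. Returns the new layers as a list."""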
    kernel_size = (3, 3)
    bn1 = BatchNorm2d(x)
    conv1 = Conv2dLayer(in_planes, planes, kernel_size, stride, padding, bn1)
    bn2 = BatchNorm2d(conv1)
    dropout = DropoutLayer(bn2)
    conv2 = Conv2dLayer(planes, planes, kernel_size, 1, padding, dropout)
    silu = SigmoidLayer(conv2)
    return [bn1, conv1, bn2, dropout, conv2, silu]

def vaeresnetshort(in_planes, planes, stride, padding, x):
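    """VAE residual block with a 1x1 projection shortcut: the shortcut output is added
    to the second convolution's output before the activation. Returns the new layers
    as a list."""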
    kernel_size = (3, 3)
    bn1 = BatchNorm2d(x)
    conv1 = Conv2dLayer(in_planes, planes, kernel_size, stride, padding, bn1)
    skipconv = Conv2dLayer(in_planes, planes, (1, 1), stride, padding, x)
    bn2 = BatchNorm2d(conv1)
    dropout = DropoutLayer(bn2)
    conv2 = Conv2dLayer(planes, planes, kernel_size, 1, padding, dropout)
    skip2 = SkipAddLayer(skipconv, conv2)
    silu = SigmoidLayer(skip2)
    return [bn1, conv1, skipconv, bn2, dropout, conv2, skip2, silu]
'''
# Block for the general UNet node operation of two 3x3 Conv2d interleaved with two ReLUs
def unet_block(in_planes, planes, x):
    kernel_size = (3, 3)
    conv1 = Conv2dLayer(in_planes, planes, kernel_size, 1, (1,1), x)
    relu1 = ReLULayer(conv1)
    conv2 = Conv2dLayer(planes, planes, kernel_size, 1, (1,1), relu1)
    relu2 = ReLULayer(conv2)
    return [conv1, relu1, conv2, relu2]


def unet_256(batch_size, input_shape=(3, 256, 256)):
    layers = [InputLayer((batch_size, *input_shape))]

    # Encoding layers all the way to the bottom, alternating between UNet blocks and MaxPools
    # Layer 1 1x572x572 -> 64x284x284
    layers.extend(block_64 := unet_block(1, 64, layers[-1]))
    layers.append(MaxPool2d((2,2), 2, layers[-1]))
    # Layer 2 64x284x284 -> 128x140x140
    layers.extend(block_128 := unet_block(64, 128, layers[-1]))
    layers.append(MaxPool2d((2,2), 2, layers[-1]))
    # Layer 3 128x140x140 -> 256x68x68
    layers.extend(block_256 := unet_block(128, 256, layers[-1]))
    layers.append(MaxPool2d((2,2), 2, layers[-1]))
    # Layer 4 256x68x68 -> 512x32x32
    layers.extend(block_512 := unet_block(256, 512, layers[-1]))
    layers.append(MaxPool2d((2,2), 2, layers[-1]))

    # Bottom layer 512x32x32 -> 1024x28x28
    layers.extend(unet_block(512, 1024, layers[-1]))

    # Decoding layers back to the top, cycling between upconv, concatenation, and UNet blocks
    # Layer 4 1024x28x28 -> 512x56x56 + 512x56x56 -> 512x52x52
    layers.append(Conv2dLayer(1024, 512, (2,2), 2, (1,1), layers[-1]))
    layers.append(CropConcatLayer(block_512, layers[-1]))
    layers.extend(unet_block(1024, 512, layers[-1]))
    # Layer 3 512x52x52 -> 256x104x104 + 256x104x104 -> 256x100x100
    layers.append(Conv2dLayer(512, 256, (2,2), 2, (1,1), layers[-1]))
    layers.append(CropConcatLayer(block_256, layers[-1]))
    layers.extend(unet_block(512, 256, layers[-1]))
    # Layer 2 256x100x100 -> 128x200x200 + 128x200x200 -> 128x196x196
    layers.append(Conv2dLayer(256, 128, (2,2), 2, (1,1), layers[-1]))
    layers.append(CropConcatLayer(block_128, layers[-1]))
    layers.extend(unet_block(256, 128, layers[-1]))
    # Layer 1 128x196x196 -> 64x392x392 + 64x392x392 -> 64x388x388
    layers.append(Conv2dLayer(128, 64, (2,2), 2, (1,1), layers[-1]))
    layers.append(CropConcatLayer(block_64, layers[-1]))
    layers.extend(unet_block(128, 64, layers[-1]))

    # Final 1x1 conv
    layers.append(Conv2dLayer(64, 2, (1,1), 1, (1,1), layers[-1]))

    return layers
'''
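
# A minimal usage sketch (illustrative only, not part of this draft): build the inference
# graph from a 128-channel latent (matching the encoder's 128-channel output) and feed it
# to get_net_costs from poet.power_computation. The latent size and iteration count below
# are placeholder values, and `device` stands for whatever chipset/device spec the caller
# normally passes to get_net_costs; note the timestep-embedding wiring above is still a TODO.
#
#   net = stablediffusion_256(training=False, iterations=50, text_encoding=None, noise=(128, 32, 32))
#   costs = get_net_costs(net, device)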
7 changes: 7 additions & 0 deletions poet/power_computation.py
@@ -240,6 +240,13 @@ def find_outshape(self, inputs):
            assert inputs[_index].out_shape == inputs[_index - 1].out_shape
        return inputs[0].out_shape

class CropConcatLayer(DNNLayer):
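    """Channel-wise concatenation of a copied/cropped skip connection with the current
    feature map (U-Net skip connections); the output has twice the input's channels."""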
    def __init__(self, copycrop, input):
        super().__init__(out_shape=self.find_outshape(input), depends_on=[input, copycrop])

    def find_outshape(self, input):
        B, C, H, W = input.out_shape
        return (B, 2*C, H, W)

def get_net_costs(net, device):
    compute_energy_list, compute_runtime_list, ram_list, param_ram_list, pagein_cost, pageout_cost = [[] for _ in range(6)]