Skip to content

hamigualisingl/Styled_vae_vq

Repository files navigation

Styled_vae_vq

Author: lidehu [email protected] 在soul app 实习期间的工作

动机:

  • 目标:需要根据聊天上下文,自发的生成图片,比如:a:我去过黄山,秋天去的,可漂亮了.b:巧了,我冬天去的,我还滑雪了呢,(发送对应生成的照片),我之前在长白山滑过雪(发送生成的照片,此时注意人物画像要一致),比黄山还要好玩些呢.a:哇,冬天的黄山真好看,还有云海哎(根据生成的图片做理解).所以我选取了自回归式生成与理解方案,完全的端到端多模态大模型(图像视频理解和生成都做的).
  • 注意:此方案是用来总结图像,对于生成而言,可以理解为它提供了很详细的'图像语言'描述.很多时候,人类是很难按照某种固定规律还不能太过琐碎方式去描述图像的,而且还容易根据简单的文字描述去生成图像描述,详情见重建部分说明.
  • 理解:目前视觉encoder输出token语义弱(cls除外),规则弱,llm注意力重点偏移,出现幻觉;token数量多,多图和视频理解时候增加计算负担(目前减少token的做法都是间接还原图像,bad case很多,图像信息损失多); 通俗来说,说话没有重点,没有规则逻辑,增加理解难度.(经过实验:titok也未观察出规律)
  • 生成:Vq-vae(内容的投影,语义弱,不好学习,类似于把man分词成m,a,n的感觉):自回归按展开顺序预测下一个token,缺乏整体布局观和纠错机制(画错一俩错能纠正回来),而且token语义弱,隐状态对应太多输出,预测难度较大.不论是理解还是生成,哪有像素级别生成的(vq-vae可以理解成像素级别表征),都是素材级别.目前语义信息低,可能无法利用大语言模型自身token与视觉token呼应,统一任务下,可能会出现不同模态能力抢夺问题(参数有限,能力有上限)
  • 解决方案:提出图像语言(语义,规则):Styled_vae_vq,将图像按固定规则离散成36个(根据压缩比确定)从高级到低级的属性token,编码器提供易对齐(高级到低级属性)的有固定规则的序列,提供更优的输入输出映射关系.使用时候和其他视觉编码器拼接.
  • 与titok出发点不同,所以网络结构是不一样的,最初的解码器直接是扩大俩倍的stylegan2,所以这份工作起名style_vae_vq,与titok效果也不同.此份工作在5月底展开,万事靠自己摸索,加上还有其他任务,所以进展缓慢.

下一步改进:

  • 图像复原(目前方案)加图像文字描述还原(用0.5b大小的),像素还原保证了信息的提取质量,确保图像信息都提取出来了(主loss),文字描述的还原,打通图文模态(辅助loss,俩loss量纲不一致,估计要调).我心目中的图像预训练范式(高质量图文数据收集中,连续版本,只为理解而生,敬请期待)

关键部分代码:

  • 编码器部分:
    for r in self.resblocks_exper:
        x = checkpoint(r, x, attn_mask)#最后6层使用了专家,每层37个专家,一个属性一个专家
    mu=self.progject_mean(x[257:]) #降维度到64/128俩个版本,制造信息瓶颈,第二阶段需要量化的值也是这个,计算余弦距离
    ################################################################################
    mu_flattened = mu.view(-1, self.emb_dim)
    similarity = cosine_similarity(mu_flattened.float().unsqueeze(1), self.V2.float(), dim=2)+1
    #其中 self.V2 = nn.Parameter(scale * torch.randn(self.emb_dim,self.emb_dim))#可以看做一个连接层
    similarity= similarity/torch.sum(similarity, dim=-1, keepdim=True)#线性加权,未使用softmax暂时没想好温控策略,这边也可以减少误差的影响,但是操作不当会增加误差
    weighted_sum = self.ln_cosin(torch.matmul(similarity, self.V2))
    output = weighted_sum.view(mu.shape)
    ###############这边是为了减轻第二阶段量化影响的, 保证值连续含义是连续的,这样存在误差也无妨,本来就是总结图像,不指望他还原,目前追求还原质量是担心链路太长,每个环节都差一些,不好找原因.这边图像也会被表示成一个向量36*128/64,用作其他任务.
    
  • 解码器部分-条件(编码器输出)添加方式如下(主要是通过条件利用方式来迫使编码器将图像按固定规则编码成成36个从高级到低级的属性token):
    for index, r in enumerate(self.condtionTransformer):
        x = checkpoint(r, x,conditon[index*2:(index+1)*2], index)#每层添加俩个条件
    
    norm_hidden_states = self.norm1(x,self.ln_1(condation[0]))#自适应归一化
    x=torch.cat([torch.zeros(2, *x.shape[1:], device=x.device, dtype=x.dtype),x], dim=0)##条件先norm,然后参与交互
    norm_hidden_states=torch.cat([self.ln_11(condation[0:1]),self.ln_22(condation[1:2]),norm_hidden_states], dim=0)#条件先交互后norm,
    x = x + self.attn1(norm_hidden_states, norm_hidden_states, norm_hidden_states, need_weights=False, attn_mask=attn_mask)[0]#
    x = x + self.mlp(self.norm2(x,self.ln_2(condation[1])))
    

训练流程

  • 由于任务比较困难,采取俩阶段训练策略,先获取合理的特征,然后量化,合理的特征会增强量化后序列间的逻辑.避免vqvae量化序列跳转困难(fsq更加突出,这个任务下,loss会出现nan),词表崩塌等问题-我们是要通过编码器得到一个合理的序列和特征,不能是依靠训一个强大的解码器拟合回去. 其实目前探索出直接训codebook,且维持大词表下高利用率的方案,但是机器数量不支持这样的大规模实验,只在128张卡训练了8000步,loss一切正常.

  • 阶段一:编码器输出连续值,但是添加容错机制.目前没找到合适的vqgan作为代理模型,后面不直接回归像素,而是预测vqgan编出来的序列,增加容错机会.

  • 阶段二:通过k-means聚类,得到词表(n,128/64).

  • 阶段三:训练llm,预期方案(最近忙着秋招,之后会单起一个项目介绍):使用俩个预测头和俩套词表.

  • 其中生成部分:styled_vae_vq(离散)部分串行预测,后面的vqvae并行预测.

  • 理解部分:使用连续值/连续

  • Environment installation

    pip install -r requirments.txt
    
  • Pretrained Model Weight

    数据集: YFCC15M(随机挑选7.3M)+3.3M(混杂数据集)(最初实验在330w数据规模训练,在YFCC15M测试,重建效果挺不错,所以最终版本只在1060w数据规模训练).超参数:lr 1e-4;12epoch(资源限制);bs 16*64;optim.AdamW,lr_scheduler "cosine".有多个版本,解码器有俩个版本(768/512),编码器有俩个版本(输出128/64).Google Drive(版本:768,128) (768,64通过网盘分享的文件:model_11.pt 链接: https://pan.baidu.com/s/152D-JIHdavImpTLElc0XuQ 提取码: qimi --来自百度网盘超级会员v1的分享)

  • Training

    Start training by run

    bash Styled_vae_vq/run.sh 64 1e-4 /mnt/data/user/lidehu/vae/ALIP/out_put_stage1_6expert_std_noise_1_pect_1  1024 200#注意数据集路径更换!注意最新版本是from modelV2,需要做相应修改
    
  • Use

    python  how_to_use.py#注意图片路径和模型路径更换!注意最新版本是from modelV2,需要做相应修改
    

Acknowledgement

感谢soul app自然语言算法组.

重建效果

  • 注意!没有添加对抗损失.因为图像的特性,模型要把位置对应的vqvae方案(改变其中一个token,解码后对应位置内容也会改变)作为最后一公里,提供了几乎一对一的映射关系(任务难度非常小)后由普通vqvae去还原.
  • 分析:非自然图像,属性提取困难,还原有难度-河马(sd生成的图像),力扣界面.

图像示例,64维度重建效果类似

原始图像 重建图像 原始图像 重建图像
原始图像 1 重建图像 1 原始图像 2 重建图像 2
原始图像 3 重建图像 3 原始图像 4 重建图像 4
原始图像 7 重建图像 7 原始图像 8 重建图像 8
原始图像 9 重建图像 9 原始图像 10 重建图像 10
原始图像 11 重建图像 11 原始图像 12 重建图像 12

About

No description, website, or topics provided.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published