最近北京智源放了一个模型叫做emu3,是一个多模态生成模型,恰巧机器学习课程需要有个project,然后就打算微调一下,体验一下AR model 的魅力。
就causal transformer本身其实并不复杂,但是比较复杂或者我不太熟悉的应该是tokenizer,特别是视觉tokenizer是如何组织的,生成过程中如何保证图像规格,各类特殊token的作用,AR模型的CFG是如何在代码层面实现的等问题,所以这篇博客就是带着问题来寻找答案,并记录下来。
emu3的vision tokenizer
vqvae
本质上就是一个vqvae,但是要使得vae能够编解码视频,同时对视频在时间维度上的压缩做到4。对于原始的2Dvae的改进点在于,将res block中的2D conv换成causal 3Dconv,但是这些3Dconv并不同时压缩时间维度,只在2D维度完成压缩之后再单独进行两次的时间维度压缩,解码则是先扩展时间维度,然后再上采样空间维度。对于图片的处理,是将图片在时间维度重复四次作为视频进行压缩或者逆压缩, 重构后只取第一帧。
image token和text token的联合处理
这里我想弄明白的是,vision token和text token是如何share code book的。在emu3中,主要通过Emu3Processor这个类来实现,初始化传入图片预处理器,vqvae以及text tokenizer。
1
| processor = Emu3Processor(image_processor, image_tokenizer, tokenizer)
|
其中最重要的函数是:
1 2 3 4 5 6 7 8 9 10 11 12 13
| visual_template=("<|visual token {token_id:0>6d}|>", r"<\|visual token (\d+)\|>") def to_imgstr(self, image_tokens): image_tokens = image_tokens.cpu().numpy().tolist() image_token_str = [ [ self.visual_template[0].format(token_id=token_id) for token_id in token_row ] for token_row in image_tokens ] image_row_str = ["".join(token_row) for token_row in image_token_str] imgstr = self.tokenizer.eol_token.join(image_row_str) return imgstr
|
以上函数说明白了两个问题:
- 视觉token到文本token的转化。视觉token会被转化为’<|visual token 000000|>’ 这样的文本token,这在文本tokenizer中相当于第151854个token,依次增长,也就是前151853个是普通本文token,后面的依次是视觉token。完成了token的映射。
- 视觉token的编排。可见,图像的编码是通过换行来进行的,没有对图像token做2D的位置编码,这和文章表述一致。
如何保证图像生成规格
这里我想解决的问题是,AR模型生成token是估计token的概率分布,这意味着对于一个64X64的图片生成过程,模型可能在生成过程中第一行生成了64个visual token,第二行生成了67个visual token,这会导致图片并不规格。
对此,emu3的解决方案是约束当前token生成的范围,例如到了第65个token,则严格限制只能生成换行符号,具体的代码实现是通过传递一个PrefixConstrainedLogitsProcessor,并对这个类传递一个判断生成状态的函数。
1 2 3 4 5
| constrained_fn = processor.build_prefix_constrained_fn(h, w) PrefixConstrainedLogitsProcessor( constrained_fn , num_beams=1, )
|
最核心的判断逻辑在:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50
| class Emu3PrefixConstrainedLogitsHelper:
def __init__( self, height, width, img_token, eoi_token, eos_token, eol_token, eof_token, pad_token, visual_tokens, ): self.height = height self.width = width self.img_token = img_token self.eoi_token = eoi_token self.eos_token = eos_token self.eol_token = eol_token self.eof_token = eof_token self.pad_token = pad_token self.visual_tokens = visual_tokens
self.offset_cache = {}
def __call__(self, batch_id, input_ids): if batch_id not in self.offset_cache: position = torch.nonzero(input_ids == self.img_token, as_tuple=True)[0][0] self.offset_cache[batch_id] = position
height = self.height[batch_id] if self.height.shape[0] > 1 else self.height[0] width = self.width[batch_id] if self.width.shape[0] > 1 else self.width[0]
offset = input_ids.shape[0] - self.offset_cache[batch_id] height = height.to(offset.device) width = width.to(offset.device)
if offset % (width + 1) == 0: return (self.eol_token, ) elif offset == (width + 1) * height + 1: return (self.eof_token, ) elif offset == (width + 1) * height + 2: return (self.eoi_token, ) elif offset == (width + 1) * height + 3: return (self.eos_token, ) elif offset > (width + 1) * height + 3: return (self.pad_token, ) else: return self.visual_tokens
|
可见的,这个prefixconstrain导致了当前的生成策略只能生成一张图片。
根据文章和初步的代码判断,最终的token排列情况应当如下:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
| [BOS] {caption text} [SOV{boi}] "{H}*{W}" [SOT{img_token}] [vs] [vs] [vs] [vs] [vs] [vs] [EOL{eol}] [vs] [vs] [vs] [vs] [vs] [vs] [EOL{eol}] [vs] [vs] [vs] [vs] [vs] [vs] [EOL{eol}] [vs] [vs] [vs] [vs] [vs] [vs] [EOL{eol}] [vs] [vs] [vs] [vs] [vs] [vs] [EOL{eol}] [vs] [vs] [vs] [vs] [vs] [vs] [EOL{eol}] [EOF] [vs] [vs] [vs] [vs] [vs] [vs] [EOL{eol}] [vs] [vs] [vs] [vs] [vs] [vs] [EOL{eol}] [vs] [vs] [vs] [vs] [vs] [vs] [EOL{eol}] [vs] [vs] [vs] [vs] [vs] [vs] [EOL{eol}] [vs] [vs] [vs] [vs] [vs] [vs] [EOL{eol}] [vs] [vs] [vs] [vs] [vs] [vs] [EOL{eol}] [EOF] [EOV{eoi}] [EOS]
|
为了能够让模型实现多帧的生成,可以简单对constrain做简单的修改,并增加新的帧数量的超参。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55
| class Emu3PrefixConstrainedLogitsHelper:
def __init__( self, height, width, img_token, eoi_token, eos_token, eol_token, eof_token, pad_token, visual_tokens, num_frame=1, ): self.height = height self.width = width self.img_token = img_token self.eoi_token = eoi_token self.eos_token = eos_token self.eol_token = eol_token self.eof_token = eof_token self.pad_token = pad_token self.visual_tokens = visual_tokens self.num_frame = num_frame
self.offset_cache = {} self.frame_index_cache = {}
def __call__(self, batch_id, input_ids): if batch_id not in self.offset_cache: position = torch.nonzero(input_ids == self.img_token, as_tuple=True)[0][0] self.offset_cache[batch_id] = position self.frame_index_cache[batch_id] = 0
height = self.height[batch_id] if self.height.shape[0] > 1 else self.height[0] width = self.width[batch_id] if self.width.shape[0] > 1 else self.width[0]
offset = input_ids.shape[0] - self.offset_cache[batch_id] height = height.to(offset.device) width = width.to(offset.device)
if (offset - self.frame_index_cache[batch_id]) % (width + 1) == 0: return (self.eol_token, ) elif offset % ((width + 1) * height + 1) == 0: self.frame_index_cache[batch_id] += 1 return (self.eof_token, ) elif offset == ((width + 1) * height + 1) * self.num_frame + 1: return (self.eoi_token, ) elif offset == ((width + 1) * height + 1) * self.num_frame + 2: return (self.eos_token, ) elif offset > ((width + 1) * height + 1) * self.num_frame + 3: return (self.pad_token, ) else: return self.visual_tokens
|
如何进行CFG
在扩散模型中,classifier free guidance(CFG)就是对于模型的条件和无条件预测进行线性加权外推,使得模型朝着条件方向进行前进,其中包含一个超参guidance scale,具体公式如下:
1
| new_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)
|
这在扩散模型中能够显著提高生成质量,在AR模型用于视觉生成也被证明能够提高生成质量。在AR模型中,比较重要的是在哪个空间进行CFG,是在logits(未归一化),还是在归一化(softmax)之后的空间,还是在对数归一化(log_softmax)空间。
在emu3的代码中,实现CFG的代码位于UnbatchedClassifierFreeGuidanceLogitsProcessor这个类中。核心代码如下:
1 2 3 4 5 6 7 8 9 10
| def __call__(self, input_ids, scores): scores = torch.nn.functional.log_softmax(scores, dim=-1) if self.guidance_scale == 1: return scores
logits = self.get_unconditional_logits(input_ids)
unconditional_logits = torch.nn.functional.log_softmax(logits[:, -1], dim=-1) scores_processed = self.guidance_scale * (scores - unconditional_logits) + unconditional_logits return scores_processed
|
可见,CFG在对数归一化空间进行。