Emu3-practice

最近北京智源放了一个模型叫做emu3,是一个多模态生成模型,恰巧机器学习课程需要有个project,然后就打算微调一下,体验一下AR model 的魅力。

就causal transformer本身其实并不复杂,但是比较复杂或者我不太熟悉的应该是tokenizer,特别是视觉tokenizer是如何组织的,生成过程中如何保证图像规格,各类特殊token的作用,AR模型的CFG是如何在代码层面实现的等问题,所以这篇博客就是带着问题来寻找答案,并记录下来。

emu3的vision tokenizer

vqvae

本质上就是一个vqvae,但是要使得vae能够编解码视频,同时对视频在时间维度上的压缩做到4。对于原始的2Dvae的改进点在于,将res block中的2D conv换成causal 3Dconv,但是这些3Dconv并不同时压缩时间维度,只在2D维度完成压缩之后再单独进行两次的时间维度压缩,解码则是先扩展时间维度,然后再上采样空间维度。对于图片的处理,是将图片在时间维度重复四次作为视频进行压缩或者逆压缩, 重构后只取第一帧。

image token和text token的联合处理

这里我想弄明白的是,vision token和text token是如何share code book的。在emu3中,主要通过Emu3Processor这个类来实现,初始化传入图片预处理器,vqvae以及text tokenizer。

1
processor = Emu3Processor(image_processor, image_tokenizer, tokenizer)

其中最重要的函数是:

1
2
3
4
5
6
7
8
9
10
11
12
13
visual_template=("<|visual token {token_id:0>6d}|>", r"<\|visual token (\d+)\|>")
def to_imgstr(self, image_tokens):
image_tokens = image_tokens.cpu().numpy().tolist()
image_token_str = [
[
self.visual_template[0].format(token_id=token_id)
for token_id in token_row
]
for token_row in image_tokens
]
image_row_str = ["".join(token_row) for token_row in image_token_str]
imgstr = self.tokenizer.eol_token.join(image_row_str)
return imgstr

以上函数说明白了两个问题:

  1. 视觉token到文本token的转化。视觉token会被转化为’<|visual token 000000|>’ 这样的文本token,这在文本tokenizer中相当于第151854个token,依次增长,也就是前151853个是普通本文token,后面的依次是视觉token。完成了token的映射。
  2. 视觉token的编排。可见,图像的编码是通过换行来进行的,没有对图像token做2D的位置编码,这和文章表述一致。

如何保证图像生成规格

这里我想解决的问题是,AR模型生成token是估计token的概率分布,这意味着对于一个64X64的图片生成过程,模型可能在生成过程中第一行生成了64个visual token,第二行生成了67个visual token,这会导致图片并不规格。

对此,emu3的解决方案是约束当前token生成的范围,例如到了第65个token,则严格限制只能生成换行符号,具体的代码实现是通过传递一个PrefixConstrainedLogitsProcessor,并对这个类传递一个判断生成状态的函数。

1
2
3
4
5
constrained_fn = processor.build_prefix_constrained_fn(h, w)
PrefixConstrainedLogitsProcessor(
constrained_fn ,
num_beams=1,
)

最核心的判断逻辑在:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
class Emu3PrefixConstrainedLogitsHelper:

def __init__(
self,
height,
width,
img_token,
eoi_token,
eos_token,
eol_token,
eof_token,
pad_token,
visual_tokens,
):
self.height = height
self.width = width
self.img_token = img_token
self.eoi_token = eoi_token
self.eos_token = eos_token
self.eol_token = eol_token
self.eof_token = eof_token
self.pad_token = pad_token
self.visual_tokens = visual_tokens

self.offset_cache = {}

def __call__(self, batch_id, input_ids):
if batch_id not in self.offset_cache:
position = torch.nonzero(input_ids == self.img_token, as_tuple=True)[0][0]
self.offset_cache[batch_id] = position

height = self.height[batch_id] if self.height.shape[0] > 1 else self.height[0]
width = self.width[batch_id] if self.width.shape[0] > 1 else self.width[0]

offset = input_ids.shape[0] - self.offset_cache[batch_id]
height = height.to(offset.device)
width = width.to(offset.device)

if offset % (width + 1) == 0:
return (self.eol_token, )
elif offset == (width + 1) * height + 1:
return (self.eof_token, )
elif offset == (width + 1) * height + 2:
return (self.eoi_token, )
elif offset == (width + 1) * height + 3:
return (self.eos_token, )
elif offset > (width + 1) * height + 3:
return (self.pad_token, )
else:
return self.visual_tokens

可见的,这个prefixconstrain导致了当前的生成策略只能生成一张图片。

根据文章和初步的代码判断,最终的token排列情况应当如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
[BOS] 
{caption text} // caption text tokens
[SOV{boi}] "{H}*{W}" // h * w are meta info tokens
[SOT{img_token}]
[vs] [vs] [vs] [vs] [vs] [vs] [EOL{eol}] // end of line
[vs] [vs] [vs] [vs] [vs] [vs] [EOL{eol}] // vs are vision tokens
[vs] [vs] [vs] [vs] [vs] [vs] [EOL{eol}]
[vs] [vs] [vs] [vs] [vs] [vs] [EOL{eol}]
[vs] [vs] [vs] [vs] [vs] [vs] [EOL{eol}]
[vs] [vs] [vs] [vs] [vs] [vs] [EOL{eol}]
[EOF] // end of frame
[vs] [vs] [vs] [vs] [vs] [vs] [EOL{eol}] // end of line
[vs] [vs] [vs] [vs] [vs] [vs] [EOL{eol}] // vs are vision tokens
[vs] [vs] [vs] [vs] [vs] [vs] [EOL{eol}]
[vs] [vs] [vs] [vs] [vs] [vs] [EOL{eol}]
[vs] [vs] [vs] [vs] [vs] [vs] [EOL{eol}]
[vs] [vs] [vs] [vs] [vs] [vs] [EOL{eol}]
[EOF] // end of frame
[EOV{eoi}]
[EOS]

为了能够让模型实现多帧的生成,可以简单对constrain做简单的修改,并增加新的帧数量的超参。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
class Emu3PrefixConstrainedLogitsHelper:

def __init__(
self,
height,
width,
img_token,
eoi_token,
eos_token,
eol_token,
eof_token,
pad_token,
visual_tokens,
num_frame=1,
):
self.height = height
self.width = width
self.img_token = img_token
self.eoi_token = eoi_token
self.eos_token = eos_token
self.eol_token = eol_token
self.eof_token = eof_token
self.pad_token = pad_token
self.visual_tokens = visual_tokens
self.num_frame = num_frame

self.offset_cache = {}
self.frame_index_cache = {}

def __call__(self, batch_id, input_ids):
if batch_id not in self.offset_cache:
position = torch.nonzero(input_ids == self.img_token, as_tuple=True)[0][0]
self.offset_cache[batch_id] = position
self.frame_index_cache[batch_id] = 0

height = self.height[batch_id] if self.height.shape[0] > 1 else self.height[0]
width = self.width[batch_id] if self.width.shape[0] > 1 else self.width[0]

offset = input_ids.shape[0] - self.offset_cache[batch_id]
height = height.to(offset.device)
width = width.to(offset.device)

if (offset - self.frame_index_cache[batch_id]) % (width + 1) == 0:
return (self.eol_token, )
elif offset % ((width + 1) * height + 1) == 0:
self.frame_index_cache[batch_id] += 1
return (self.eof_token, )
elif offset == ((width + 1) * height + 1) * self.num_frame + 1:
return (self.eoi_token, )
elif offset == ((width + 1) * height + 1) * self.num_frame + 2:
return (self.eos_token, )
elif offset > ((width + 1) * height + 1) * self.num_frame + 3:
return (self.pad_token, )
else:
return self.visual_tokens

如何进行CFG

在扩散模型中,classifier free guidance(CFG)就是对于模型的条件和无条件预测进行线性加权外推,使得模型朝着条件方向进行前进,其中包含一个超参guidance scale,具体公式如下:

1
new_pred = pred_uncond + guidance_scale * (pred_cond - pred_uncond)

这在扩散模型中能够显著提高生成质量,在AR模型用于视觉生成也被证明能够提高生成质量。在AR模型中,比较重要的是在哪个空间进行CFG,是在logits(未归一化),还是在归一化(softmax)之后的空间,还是在对数归一化(log_softmax)空间。

在emu3的代码中,实现CFG的代码位于UnbatchedClassifierFreeGuidanceLogitsProcessor这个类中。核心代码如下:

1
2
3
4
5
6
7
8
9
10
def __call__(self, input_ids, scores):
scores = torch.nn.functional.log_softmax(scores, dim=-1)
if self.guidance_scale == 1:
return scores

logits = self.get_unconditional_logits(input_ids)

unconditional_logits = torch.nn.functional.log_softmax(logits[:, -1], dim=-1)
scores_processed = self.guidance_scale * (scores - unconditional_logits) + unconditional_logits
return scores_processed

可见,CFG在对数归一化空间进行。

Author

Chendong Xiang

Posted on

2024-11-28

Updated on

2025-03-12

Licensed under

Comments