O problema
As pontuações de similaridade são quase as mesmas para textos que descrevem uma foto de um gato e de um cachorro (a foto é de um gato).
Cat similarity: tensor([[-3.5724]], grad_fn=<MulBackward0>)
Dog similarity: tensor([[-3.4155]], grad_fn=<MulBackward0>)
O código para o modelo CLIP
O código é baseado no ponto de verificação de openai/clip-vit-base-patch32 . A função encode_text recebe uma entrada bruta e a transforma em embeddings, que posteriormente são inseridos no método forward. Tenho certeza de que os nomes e tamanhos das camadas estão corretos, pois o ponto de verificação se ajusta ao modelo sem erros devido a camadas ausentes ou inesperadas.
class CLIP(nn.Module):
def __init__(self, project_dim: int = 768, embed_dim: int = 512):
super(CLIP, self).__init__()
self.vision_model = ImageEncoder(project_dim = project_dim)
self.text_model = TextEncoder(embed_dim = embed_dim)
self.tokenizer = TorchTokenizer()
self.logit_scale = nn.Parameter(torch.ones([]) * 0.7)
self.visual_projection = nn.Linear(project_dim, embed_dim, bias = False)
self.text_projection = nn.Linear(embed_dim, embed_dim, bias = False)
self.vision_model.eval()
self.text_model.eval()
def forward(self, image: torch.Tensor, text_embed: torch.Tensor) -> torch.Tensor:
" Compute the relationship between image and text "
# get fixed size to comply with the checkpoint position_embeddings nn.Embedding(50, embed_dim)
image = Resize(size=(224, 224))(image)
image_features = self.vision_model(image)
# projections
text_features = self.text_projection(text_embed)
image_features = self.visual_projection(image_features)
# normalization
text_features = F.normalize(text_features, dim = -1)
image_features = F.normalize(image_features, dim = -1)
logits = self.logit_scale.exp() * (image_features @ text_features.t())
return logits
def encode_text(self, input_ids, attention_mask = None):
""" Tokenize (if needed) and encode texts, returning embeddings and mask. Function for ConditionalPromptNorm """
# tokenize strings if raw text passed
if attention_mask is None:
input_ids, attention_mask = self.tokenizer.tokenize(input_ids)
# ensure batch dim
if input_ids.dim() == 1:
input_ids = input_ids.unsqueeze(0)
with torch.no_grad():
text_emb = self.text_model(input_ids.long(), attention_mask)
return text_emb
O código para o codificador de texto
Verifiquei se a obtenção do token EOS funciona corretamente. Além disso, os tipos de camadas, como nn.Embedding e nn.Parameter, estão corretos para cada camada, pois entraria em conflito com o ponto de verificação se não fossem do mesmo tipo.
class TextEncoder(nn.Module):
def __init__(self, embed_dim: int = 512):
super(TextEncoder, self).__init__()
vocab_size = 49408
self.embeddings = nn.Module()
self.embeddings.token_embedding = nn.Embedding(vocab_size, embed_dim)
# tokenizer's context_length must be set to 77 tokens
self.embeddings.position_embedding = nn.Embedding(77, embed_dim) # 77 = context length
self.encoder = Encoder(embed_size = embed_dim)
self.final_layer_norm = nn.LayerNorm(embed_dim)
def forward(self, text: torch.Tensor, attention_mask: torch.Tensor):
x = self.embeddings.token_embedding(text.long())
# seq_length
positions = torch.arange(x.size(1))
pos_embed = self.embeddings.position_embedding(positions)
x += pos_embed.to(x.dtype).to(x.device)
# obtain text embeddings
x = x.permute(1, 0, 2)
x = self.encoder(x, attention_mask)
x = x.permute(1, 0, 2)
# ensure batch dim
if x.dim() == 2: x = x.unsqueeze(0)
if attention_mask.dim() == 1: attention_mask = attention_mask.unsqueeze(0)
# for each batch, get the last token (eos)
x = x[torch.arange(x.size(0)), text.argmax(dim = -1)]
return self.final_layer_norm(x)
A classe de atenção é de https://github.com/openai/CLIP/blob/main/clip/model.py#L58 com uma pequena modificação para permitir atenção própria e combinada (x e x[:1]).
O Codificador
Verifiquei se o código do tokenizador funciona corretamente. O MLP é o mesmo do código original do CLIP. Duas camadas lineares com proporção de 4 e um GELU no meio.
class EncoderLayer(nn.Module):
def __init__(self, embed_size: int = 768, ratio: int = 4, num_heads: int = 8):
super().__init__()
self.layer_norm1 = nn.LayerNorm(embed_size)
self.layer_norm2 = nn.LayerNorm(embed_size)
self.mlp = MLP(embed_size = embed_size, ratio = ratio)
self.self_attn = AttentionPool2d(num_heads = num_heads, embed_dim = embed_size)
def forward(self, x: torch.Tensor, src_pad_key = None):
x = self.layer_norm1(x)
if src_pad_key is not None: attn_out = self.self_attn(x, src_pad_key = src_pad_key, use_self_attention = True)
else: attn_out = self.self_attn(x)
# normalize and apply residual connections
x += attn_out
x = self.layer_norm2(x)
x += self.mlp(x)
return x
class Encoder(nn.Module):
def __init__(self, embed_size = 768):
super().__init__()
self.layers = nn.ModuleList([EncoderLayer(embed_size = embed_size) for _ in range(12)])
def forward(self, x: torch.Tensor, attention_mask = None):
if attention_mask is not None:
src_key_mask = attention_mask == 0
if src_key_mask.dim() == 1: src_key_mask = src_key_mask.unsqueeze(0)
for layer in self.layers:
x = layer(x, src_key_mask)
else:
for layer in self.layers:
x = layer(x)
return x
O problema estava no EncoderLayer, onde os cálculos residuais foram feitos incorretamente. A maneira correta de calcular:
Outra mudança foi que devemos sempre usar a autoatenção (em vez da atenção combinada), caso contrário os cálculos não funcionarão com o codificador de imagem. [consulta = x]
Os resultados são assim: