AskOverflow.Dev

AskOverflow.Dev Logo AskOverflow.Dev Logo

AskOverflow.Dev Navigation

  • Início
  • system&network
  • Ubuntu
  • Unix
  • DBA
  • Computer
  • Coding
  • LangChain

Mobile menu

Close
  • Início
  • system&network
    • Recentes
    • Highest score
    • tags
  • Ubuntu
    • Recentes
    • Highest score
    • tags
  • Unix
    • Recentes
    • tags
  • DBA
    • Recentes
    • tags
  • Computer
    • Recentes
    • tags
  • Coding
    • Recentes
    • tags
Início / coding / Perguntas / 79583654
Accepted
Yousef
Yousef
Asked: 2025-04-21 02:46:03 +0800 CST2025-04-21 02:46:03 +0800 CST 2025-04-21 02:46:03 +0800 CST

Logits não mudam em uma reimplementação personalizada de um modelo CLIP [PyTorch]

  • 772

O problema

As pontuações de similaridade são quase as mesmas para textos que descrevem uma foto de um gato e de um cachorro (a foto é de um gato).

Cat similarity: tensor([[-3.5724]], grad_fn=<MulBackward0>)
Dog similarity: tensor([[-3.4155]], grad_fn=<MulBackward0>)

O código para o modelo CLIP

O código é baseado no ponto de verificação de openai/clip-vit-base-patch32 . A função encode_text recebe uma entrada bruta e a transforma em embeddings, que posteriormente são inseridos no método forward. Tenho certeza de que os nomes e tamanhos das camadas estão corretos, pois o ponto de verificação se ajusta ao modelo sem erros devido a camadas ausentes ou inesperadas.

class CLIP(nn.Module):
    def __init__(self, project_dim: int = 768, embed_dim: int = 512):
        super(CLIP, self).__init__()

        self.vision_model = ImageEncoder(project_dim = project_dim)
        self.text_model = TextEncoder(embed_dim = embed_dim)
        self.tokenizer = TorchTokenizer()
        
        self.logit_scale = nn.Parameter(torch.ones([]) * 0.7) 
        self.visual_projection = nn.Linear(project_dim, embed_dim, bias = False)
        self.text_projection = nn.Linear(embed_dim, embed_dim, bias = False)

        self.vision_model.eval()
        self.text_model.eval()

    def forward(self, image: torch.Tensor, text_embed: torch.Tensor) -> torch.Tensor:

        " Compute the relationship between image and text  "

        # get fixed size to comply with the checkpoint position_embeddings nn.Embedding(50, embed_dim)
        image = Resize(size=(224, 224))(image)

        image_features = self.vision_model(image)

        # projections
        text_features = self.text_projection(text_embed)
        image_features = self.visual_projection(image_features)
        
        # normalization
        text_features = F.normalize(text_features, dim = -1)
        image_features = F.normalize(image_features, dim = -1)

        logits = self.logit_scale.exp() * (image_features @ text_features.t())

        return logits
    
    def encode_text(self, input_ids, attention_mask = None):
        """ Tokenize (if needed) and encode texts, returning embeddings and mask. Function for ConditionalPromptNorm """

        # tokenize strings if raw text passed
        if attention_mask is None:
            input_ids, attention_mask = self.tokenizer.tokenize(input_ids)
        
        # ensure batch dim
        if input_ids.dim() == 1:
            input_ids = input_ids.unsqueeze(0)

        with torch.no_grad():
            text_emb = self.text_model(input_ids.long(), attention_mask)

        return text_emb

O código para o codificador de texto

Verifiquei se a obtenção do token EOS funciona corretamente. Além disso, os tipos de camadas, como nn.Embedding e nn.Parameter, estão corretos para cada camada, pois entraria em conflito com o ponto de verificação se não fossem do mesmo tipo.

class TextEncoder(nn.Module):
    def __init__(self, embed_dim: int = 512):
        super(TextEncoder, self).__init__()

        vocab_size = 49408

        self.embeddings = nn.Module()
        self.embeddings.token_embedding = nn.Embedding(vocab_size, embed_dim)
        # tokenizer's context_length must be set to 77 tokens
        self.embeddings.position_embedding = nn.Embedding(77, embed_dim) # 77 = context length

        self.encoder = Encoder(embed_size = embed_dim)

        self.final_layer_norm = nn.LayerNorm(embed_dim)

    def forward(self, text: torch.Tensor, attention_mask: torch.Tensor):

        x = self.embeddings.token_embedding(text.long())

        #                       seq_length
        positions = torch.arange(x.size(1))
        pos_embed = self.embeddings.position_embedding(positions)

        x += pos_embed.to(x.dtype).to(x.device)

        # obtain text embeddings
        x = x.permute(1, 0, 2)
        x = self.encoder(x, attention_mask)
        x = x.permute(1, 0, 2)

        # ensure batch dim
        if x.dim() == 2: x = x.unsqueeze(0)
        if attention_mask.dim() == 1: attention_mask = attention_mask.unsqueeze(0)

        # for each batch, get the last token (eos)
        x = x[torch.arange(x.size(0)), text.argmax(dim = -1)]

        return self.final_layer_norm(x)

A classe de atenção é de https://github.com/openai/CLIP/blob/main/clip/model.py#L58 com uma pequena modificação para permitir atenção própria e combinada (x e x[:1]).

O Codificador

Verifiquei se o código do tokenizador funciona corretamente. O MLP é o mesmo do código original do CLIP. Duas camadas lineares com proporção de 4 e um GELU no meio.

class EncoderLayer(nn.Module):
    def __init__(self, embed_size: int = 768, ratio: int = 4, num_heads: int = 8):
        super().__init__()

        self.layer_norm1 = nn.LayerNorm(embed_size)
        self.layer_norm2 = nn.LayerNorm(embed_size)
        self.mlp = MLP(embed_size = embed_size, ratio = ratio)

        self.self_attn = AttentionPool2d(num_heads = num_heads, embed_dim = embed_size)

    def forward(self, x: torch.Tensor, src_pad_key = None):

        x = self.layer_norm1(x)
        
        if src_pad_key is not None: attn_out = self.self_attn(x, src_pad_key = src_pad_key, use_self_attention = True)
        else: attn_out = self.self_attn(x)

        # normalize and apply residual connections
        x += attn_out
        x = self.layer_norm2(x)
        x += self.mlp(x)

        return x

class Encoder(nn.Module):
    def __init__(self, embed_size = 768):
        super().__init__()

        self.layers = nn.ModuleList([EncoderLayer(embed_size = embed_size) for _ in range(12)])

    def forward(self, x: torch.Tensor, attention_mask = None):

        if attention_mask is not None:
            src_key_mask = attention_mask == 0
            if src_key_mask.dim() == 1: src_key_mask = src_key_mask.unsqueeze(0)

            for layer in self.layers:
                x = layer(x, src_key_mask)
        
        else:
            for layer in self.layers:
                x = layer(x)

        return x
python
  • 1 1 respostas
  • 48 Views

1 respostas

  • Voted
  1. Best Answer
    Yousef
    2025-04-22T15:12:44+08:002025-04-22T15:12:44+08:00

    O problema estava no EncoderLayer, onde os cálculos residuais foram feitos incorretamente. A maneira correta de calcular:

        def forward(self, x: torch.Tensor, src_pad_key = None):
            
            residual = x
            x = self.layer_norm1(x)
            
            if src_pad_key is not None: x = self.self_attn(x, src_pad_key = src_pad_key, use_self_attention = True)
            else: x = self.self_attn(x)
    
            # normalize and apply residual connections
            x += residual
    
            residual = x
            x = self.layer_norm2(x)
            x = self.mlp(x)
            x += residual
    
            return x
    

    Outra mudança foi que devemos sempre usar a autoatenção (em vez da atenção combinada), caso contrário os cálculos não funcionarão com o codificador de imagem. [consulta = x]

    Os resultados são assim:

    Cat similarity: tensor([[25.4132]], grad_fn=<MulBackward0>)
    Dog similarity: tensor([[21.8544]], grad_fn=<MulBackward0>)
    cosine cat/dog: 0.8438754677772522
    
    • 1

relate perguntas

  • Como divido o loop for em 3 quadros de dados individuais?

  • Como verificar se todas as colunas flutuantes em um Pandas DataFrame são aproximadamente iguais ou próximas

  • Como funciona o "load_dataset", já que não está detectando arquivos de exemplo?

  • Por que a comparação de string pandas.eval() retorna False

  • Python tkinter/ ttkboostrap dateentry não funciona quando no estado somente leitura

Sidebar

Stats

  • Perguntas 205573
  • respostas 270741
  • best respostas 135370
  • utilizador 68524
  • Highest score
  • respostas
  • Marko Smith

    Reformatar números, inserindo separadores em posições fixas

    • 6 respostas
  • Marko Smith

    Por que os conceitos do C++20 causam erros de restrição cíclica, enquanto o SFINAE antigo não?

    • 2 respostas
  • Marko Smith

    Problema com extensão desinstalada automaticamente do VScode (tema Material)

    • 2 respostas
  • Marko Smith

    Vue 3: Erro na criação "Identificador esperado, mas encontrado 'import'" [duplicado]

    • 1 respostas
  • Marko Smith

    Qual é o propósito de `enum class` com um tipo subjacente especificado, mas sem enumeradores?

    • 1 respostas
  • Marko Smith

    Como faço para corrigir um erro MODULE_NOT_FOUND para um módulo que não importei manualmente?

    • 6 respostas
  • Marko Smith

    `(expression, lvalue) = rvalue` é uma atribuição válida em C ou C++? Por que alguns compiladores aceitam/rejeitam isso?

    • 3 respostas
  • Marko Smith

    Um programa vazio que não faz nada em C++ precisa de um heap de 204 KB, mas não em C

    • 1 respostas
  • Marko Smith

    PowerBI atualmente quebrado com BigQuery: problema de driver Simba com atualização do Windows

    • 2 respostas
  • Marko Smith

    AdMob: MobileAds.initialize() - "java.lang.Integer não pode ser convertido em java.lang.String" para alguns dispositivos

    • 1 respostas
  • Martin Hope
    Fantastic Mr Fox Somente o tipo copiável não é aceito na implementação std::vector do MSVC 2025-04-23 06:40:49 +0800 CST
  • Martin Hope
    Howard Hinnant Encontre o próximo dia da semana usando o cronógrafo 2025-04-21 08:30:25 +0800 CST
  • Martin Hope
    Fedor O inicializador de membro do construtor pode incluir a inicialização de outro membro? 2025-04-15 01:01:44 +0800 CST
  • Martin Hope
    Petr Filipský Por que os conceitos do C++20 causam erros de restrição cíclica, enquanto o SFINAE antigo não? 2025-03-23 21:39:40 +0800 CST
  • Martin Hope
    Catskul O C++20 mudou para permitir a conversão de `type(&)[N]` de matriz de limites conhecidos para `type(&)[]` de matriz de limites desconhecidos? 2025-03-04 06:57:53 +0800 CST
  • Martin Hope
    Stefan Pochmann Como/por que {2,3,10} e {x,3,10} com x=2 são ordenados de forma diferente? 2025-01-13 23:24:07 +0800 CST
  • Martin Hope
    Chad Feller O ponto e vírgula agora é opcional em condicionais bash com [[ .. ]] na versão 5.2? 2024-10-21 05:50:33 +0800 CST
  • Martin Hope
    Wrench Por que um traço duplo (--) faz com que esta cláusula MariaDB seja avaliada como verdadeira? 2024-05-05 13:37:20 +0800 CST
  • Martin Hope
    Waket Zheng Por que `dict(id=1, **{'id': 2})` às vezes gera `KeyError: 'id'` em vez de um TypeError? 2024-05-04 14:19:19 +0800 CST
  • Martin Hope
    user924 AdMob: MobileAds.initialize() - "java.lang.Integer não pode ser convertido em java.lang.String" para alguns dispositivos 2024-03-20 03:12:31 +0800 CST

Hot tag

python javascript c++ c# java typescript sql reactjs html

Explore

  • Início
  • Perguntas
    • Recentes
    • Highest score
  • tag
  • help

Footer

AskOverflow.Dev

About Us

  • About Us
  • Contact Us

Legal Stuff

  • Privacy Policy

Language

  • Pt
  • Server
  • Unix

© 2023 AskOverflow.DEV All Rights Reserve