AskOverflow.Dev

AskOverflow.Dev Logo AskOverflow.Dev Logo

AskOverflow.Dev Navigation

  • Início
  • system&network
  • Ubuntu
  • Unix
  • DBA
  • Computer
  • Coding
  • LangChain

Mobile menu

Close
  • Início
  • system&network
    • Recentes
    • Highest score
    • tags
  • Ubuntu
    • Recentes
    • Highest score
    • tags
  • Unix
    • Recentes
    • tags
  • DBA
    • Recentes
    • tags
  • Computer
    • Recentes
    • tags
  • Coding
    • Recentes
    • tags
Início / coding / Perguntas / 78427847
Accepted
Qazi Fahim Farhan
Qazi Fahim Farhan
Asked: 2024-05-04 13:09:47 +0800 CST2024-05-04 13:09:47 +0800 CST 2024-05-04 13:09:47 +0800 CST

Como tornar a CNN invariante à posição de um padrão na sequência de DNA?

  • 772

Estou tentando fazer a classificação binária encontrando um padrão (digamos "CTCATGTCA") na sequência de DNA usando a CNN. Eu escrevi um modelo em pytorch. Quando o padrão está no centro da sequência, o modelo o detecta. Mas se o padrão estiver em locais aleatórios, o modelo não está funcionando. Como tornar a CNN invariante à posição do padrão?

Este é o meu código:

import logging

import torch
import torch.nn as nn
import torch.optim as optim
import pandas as pd
from sklearn import metrics
from skorch import NeuralNetClassifier
from skorch.callbacks import EpochScoring
from torch.utils.data import DataLoader, Dataset
import numpy as np

import constants

timber = logging.getLogger()
logging.basicConfig(level=logging.INFO)  # change to level=logging.DEBUG to print more logs...


# utils

def one_hot_e(dna_seq: str) -> np.ndarray:
  mydict = {'A': np.asarray([1.0, 0.0, 0.0, 0.0]), 'C': np.asarray([0.0, 1.0, 0.0, 0.0]),
            'G': np.asarray([0.0, 0.0, 1.0, 0.0]), 'T': np.asarray([0.0, 0.0, 0.0, 1.0]),
            'N': np.asarray([0.0, 0.0, 0.0, 0.0]), 'H': np.asarray([0.0, 0.0, 0.0, 0.0]),
            'a': np.asarray([1.0, 0.0, 0.0, 0.0]), 'c': np.asarray([0.0, 1.0, 0.0, 0.0]),
            'g': np.asarray([0.0, 0.0, 1.0, 0.0]), 't': np.asarray([0.0, 0.0, 0.0, 1.0]),
            'n': np.asarray([0.0, 0.0, 0.0, 0.0]), '-': np.asarray([0.0, 0.0, 0.0, 0.0])}

  size_of_a_seq: int = len(dna_seq)

  # forward = np.zeros(shape=(size_of_a_seq, 4))

  forward_list: list = [mydict[dna_seq[i]] for i in range(0, size_of_a_seq)]
  encoded = np.asarray(forward_list)
  return encoded


def one_hot_e_column(column: pd.Series) -> np.ndarray:
  tmp_list: list = [one_hot_e(seq) for seq in column]
  encoded_column = np.asarray(tmp_list)
  return encoded_column


def reverse_dna_seq(dna_seq: str) -> str:
  # m_reversed = ""
  # for i in range(0, len(dna_seq)):
  #     m_reversed = dna_seq[i] + m_reversed
  # return m_reversed
  return dna_seq[::-1]


def complement_dna_seq(dna_seq: str) -> str:
  comp_map = {"A": "T", "C": "G", "T": "A", "G": "C",
              "a": "t", "c": "g", "t": "a", "g": "c",
              "N": "N", "H": "H", "-": "-",
              "n": "n", "h": "h"
              }

  comp_dna_seq_list: list = [comp_map[nucleotide] for nucleotide in dna_seq]
  comp_dna_seq: str = "".join(comp_dna_seq_list)
  return comp_dna_seq


def reverse_complement_dna_seq(dna_seq: str) -> str:
  return reverse_dna_seq(complement_dna_seq(dna_seq))


def reverse_complement_dna_seqs(column: pd.Series) -> pd.Series:
  tmp_list: list = [reverse_complement_dna_seq(seq) for seq in column]
  rc_column = pd.Series(tmp_list)
  return rc_column


class CNN1D(nn.Module):
  def __init__(self,
               in_channel_num_of_nucleotides=4,
               kernel_size_k_mer_motif=4,
               dnn_size=256,
               num_filters=1,
               lstm_hidden_size=128,
               *args, **kwargs):
    super().__init__(*args, **kwargs)
    self.conv1d = nn.Conv1d(in_channels=in_channel_num_of_nucleotides, out_channels=num_filters,
                            kernel_size=kernel_size_k_mer_motif, stride=2)
    self.activation = nn.ReLU()
    self.pooling = nn.MaxPool1d(kernel_size=kernel_size_k_mer_motif, stride=2)

    self.flatten = nn.Flatten()
    # linear layer

    self.dnn2 = nn.Linear(in_features=14 * num_filters, out_features=dnn_size)
    self.act2 = nn.Sigmoid()
    self.dropout2 = nn.Dropout(p=0.2)

    self.out = nn.Linear(in_features=dnn_size, out_features=1)
    self.out_act = nn.Sigmoid()

    pass

  def forward(self, x):
    timber.debug(constants.magenta + f"h0: {x}")
    h = self.conv1d(x)
    timber.debug(constants.green + f"h1: {h}")
    h = self.activation(h)
    timber.debug(constants.magenta + f"h2: {h}")
    h = self.pooling(h)
    timber.debug(constants.blue + f"h3: {h}")
    timber.debug(constants.cyan + f"h4: {h}")

    h = self.flatten(h)
    timber.debug(constants.magenta + f"h5: {h},\n shape {h.shape}, size {h.size}")
    h = self.dnn2(h)
    timber.debug(constants.green + f"h6: {h}")

    h = self.act2(h)
    timber.debug(constants.blue + f"h7: {h}")

    h = self.dropout2(h)
    timber.debug(constants.cyan + f"h8: {h}")

    h = self.out(h)
    timber.debug(constants.magenta + f"h9: {h}")

    h = self.out_act(h)
    timber.debug(constants.green + f"h10: {h}")
    # h = (h > 0.5).float()  # <---- should this go here?
    # timber.debug(constants.green + f"h11: {h}")

    return h


class CustomDataset(Dataset):
  def __init__(self, dataframe):
    self.x = dataframe["Sequence"]
    self.y = dataframe["class"]

  def __len__(self):
    return len(self.y)

  def preprocessing(self, x1, y1) -> (torch.Tensor, torch.Tensor, torch.Tensor):
    forward_col = x1

    backward_col = reverse_complement_dna_seqs(forward_col)

    forward_one_hot_e_col: np.ndarray = one_hot_e_column(forward_col)
    backward_one_hot_e_col: np.ndarray = one_hot_e_column(backward_col)

    tr_xf_tensor = torch.Tensor(forward_one_hot_e_col).permute(1, 2, 0)
    tr_xb_tensor = torch.Tensor(backward_one_hot_e_col).permute(1, 2, 0)
    # timber.debug(f"y1 {y1}")
    tr_y1 = np.array([y1])  # <--- need to put it inside brackets

    return tr_xf_tensor, tr_xb_tensor, tr_y1

  def __getitem__(self, idx):
    m_seq = self.x.iloc[idx]
    labels = self.y.iloc[idx]
    xf, xb, y = self.preprocessing(m_seq, labels)
    timber.debug(f"xf -> {xf.shape}, xb -> {xb.shape}, y -> {y}")
    return xf, xb, y


def test_dataloader():
  df = pd.read_csv("todo.csv")
  X = df["Sequence"]
  y = df["class"]

  ds = CustomDataset(df)
  loader = DataLoader(ds, shuffle=True, batch_size=16)

  train_loader = loader

  for data in train_loader:
    timber.debug(data)
    # xf, xb, y = data[0], data[1], data[2]
    # timber.debug(f"xf -> {xf.shape}, xb -> {xb.shape}, y -> {y.shape}")
  pass


def get_callbacks() -> list:
  # metric.auc ( uses trapezoidal rule) gave an error: x is neither increasing, nor decreasing. so I had to remove it
  return [
    ("tr_acc", EpochScoring(
      metrics.accuracy_score,
      lower_is_better=False,
      on_train=True,
      name="train_acc",
    )),

    ("tr_recall", EpochScoring(
      metrics.recall_score,
      lower_is_better=False,
      on_train=True,
      name="train_recall",
    )),
    ("tr_precision", EpochScoring(
      metrics.precision_score,
      lower_is_better=False,
      on_train=True,
      name="train_precision",
    )),
    ("tr_roc_auc", EpochScoring(
      metrics.roc_auc_score,
      lower_is_better=False,
      on_train=False,
      name="tr_auc"
    )),
    ("tr_f1", EpochScoring(
      metrics.f1_score,
      lower_is_better=False,
      on_train=False,
      name="tr_f1"
    )),
    # ("valid_acc1", EpochScoring(
    #   metrics.accuracy_score,
    #   lower_is_better=False,
    #   on_train=False,
    #   name="valid_acc1",
    # )),
    ("valid_recall", EpochScoring(
      metrics.recall_score,
      lower_is_better=False,
      on_train=False,
      name="valid_recall",
    )),
    ("valid_precision", EpochScoring(
      metrics.precision_score,
      lower_is_better=False,
      on_train=False,
      name="valid_precision",
    )),
    ("valid_roc_auc", EpochScoring(
      metrics.roc_auc_score,
      lower_is_better=False,
      on_train=False,
      name="valid_auc"
    )),
    ("valid_f1", EpochScoring(
      metrics.f1_score,
      lower_is_better=False,
      on_train=False,
      name="valid_f1"
    ))
  ]


def start():

  # df = pd.read_csv("data64.csv")  # use this line
  df = pd.read_csv("data64random.csv")
  X = df["Sequence"]
  y = df["class"]

  npa = np.array([y.values])

  torch_tensor = torch.tensor(npa)  # [0, 1, 1, 0, ... ... ] a simple list
  print(f"torch_tensor: {torch_tensor}")
  # need to transpose it!

  yt = torch.transpose(torch_tensor, 0, 1)

  ds = CustomDataset(df)
  loader = DataLoader(ds, shuffle=True)

  # train_loader = loader
  # test_loader = loader  # todo: load another dataset later

  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  model = CNN1D().to(device)
  m_criterion = nn.BCEWithLogitsLoss
  # optimizer = optim.Adam(model.parameters(), lr=0.001)
  m_optimizer = optim.Adam

  net = NeuralNetClassifier(
    model,
    max_epochs=200,
    criterion=m_criterion,
    optimizer=m_optimizer,
    lr=0.01,
    # decay=0.01,
    # momentum=0.9,

    device=device,
    classes=["no_mqtl", "yes_mqtl"],
    verbose=True,
    callbacks=get_callbacks()
  )

  ohe_c = one_hot_e_column(X)
  print(f"ohe_c shape {ohe_c.shape}")
  ohe_c = torch.Tensor(ohe_c)
  ohe_c = ohe_c.permute(0, 2, 1)
  ohe_c = ohe_c.to(device)
  print(f"ohe_c shape {ohe_c.shape}")

  net.fit(X=ohe_c, y=yt)
  y_proba = net.predict_proba(ohe_c)
  # timber.info(f"y_proba = {y_proba}")
  pass


if __name__ == '__main__':
  start()
  # test_dataloader()
  pass

E você pode encontrar os 2 conjuntos de dados

  1. dna64random.csv (o modelo não funciona com isso)
  2. dna64.csv (o modelo funciona com ele)

Você pode baixar tudo rapidamente usando este link principal

scikit-learn
  • 1 1 respostas
  • 38 Views

1 respostas

  • Voted
  1. Best Answer
    MuhammedYunus
    2024-05-05T06:14:59+08:002024-05-05T06:14:59+08:00

    O modelo abaixo usa Conv1dcamadas e atinge uma precisão de validação de 99% a 100%. Não foram necessários muitos ajustes para chegar lá.

    Treinei em 60% dos dados, sendo 30% usados ​​para validação. O modelo é relativamente pequeno, com cerca de 700 parâmetros.

    O modelo possui um amplo campo receptivo na entrada e corresponde ao comprimento do padrão. O desempenho do modelo foi bastante sensível a esta configuração inicial. As camadas conv subsequentes dobram o tamanho do recurso para 8 e cortam ligeiramente a sequência. Uma camada final de maxpool divide pela metade o comprimento restante da sequência, antes de ser nivelada e mapeada para um escalar por meio de uma camada densa.

    Descobri que NAdam funcionava melhor que Adam e não explorei mais os otimizadores. Tamanhos de lote de 8 e abaixo funcionaram bem.

    ...
    [epoch  13] trn loss: 0.011 [acc: 100.000%] | val loss: 0.011 [acc: 100.000%]
    [epoch  14] trn loss: 0.014 [acc: 100.000%] | val loss: 0.009 [acc:  99.667%]
    [epoch  15] trn loss: 0.008 [acc: 100.000%] | val loss: 0.006 [acc: 100.000%]
    

    insira a descrição da imagem aqui

    Tentei camadas conv em vez de células recorrentes, pois elas tendem a ser mais rápidas e têm bom desempenho. Como o desempenho foi bom, não procurei LSTMs. Seu modelo parece bem grande - tente diminuí-lo e veja se isso melhora as coisas.


    Código para preparar os dados, definir e treinar o modelo e imprimir/plotar os resultados.

    import numpy as np
    from matplotlib import pyplot as plt
    import pandas as pd
    
    import torch
    from torch.utils.data import DataLoader
    from torch import nn
    
    np.random.seed(0)
    
    def one_hot_e(dna_seq: str) -> np.ndarray:
      mydict = {'A': np.asarray([1.0, 0.0, 0.0, 0.0]), 'C': np.asarray([0.0, 1.0, 0.0, 0.0]),
                'G': np.asarray([0.0, 0.0, 1.0, 0.0]), 'T': np.asarray([0.0, 0.0, 0.0, 1.0]),
                'N': np.asarray([0.0, 0.0, 0.0, 0.0]), 'H': np.asarray([0.0, 0.0, 0.0, 0.0]),
                'a': np.asarray([1.0, 0.0, 0.0, 0.0]), 'c': np.asarray([0.0, 1.0, 0.0, 0.0]),
                'g': np.asarray([0.0, 0.0, 1.0, 0.0]), 't': np.asarray([0.0, 0.0, 0.0, 1.0]),
                'n': np.asarray([0.0, 0.0, 0.0, 0.0]), '-': np.asarray([0.0, 0.0, 0.0, 0.0])}
    
      size_of_a_seq: int = len(dna_seq)
    
      # forward = np.zeros(shape=(size_of_a_seq, 4))
    
      forward_list: list = [mydict[dna_seq[i]] for i in range(0, size_of_a_seq)]
      encoded = np.asarray(forward_list)
      return encoded
    
    #
    #Load and prepare data    "CTCATGTCA"
    #
    df = pd.read_csv('data64random.csv')
    
    #To numpy arrays, and encode X
    X = np.stack([one_hot_e(row) for row in df.Sequence], axis=0)
    y = df['class'].values
    
    #Shuffle and split
    train_size = int(0.6 * len(X))
    val_size = int(0.3 * len(X))
    
    shuffle_ixs = np.random.permutation(len(X))
    X, y = [arr[shuffle_ixs] for arr in [X, y]]
    
    X_train, y_train = [arr[:train_size] for arr in [X, y]]
    X_val, y_val = [arr[train_size:train_size + val_size] for arr in [X, y]]
    
    #As tensors. Useful for passing directly to model as a single large batch.
    X_train_t, y_train_t = [torch.tensor(arr).float() for arr in [X_train, y_train]]
    X_val_t, y_val_t = [torch.tensor(arr).float() for arr in [X_val, y_val]]
    
    
    #
    #Define the model
    #
    
    #Lambda layer useful for simple manipulations
    class LambdaLayer(nn.Module):
        def __init__(self, func):
            super().__init__()
            self.func = func
        
        def forward(self, x):
            return self.func(x)
    
    #batch, 64, chan
    
    #The model
    seq_len = X[0].shape[0]    #64 characters long
    n_features = X[0].shape[1] #4-dim onehot encoding
    
    torch.manual_seed(0)
    model = nn.Sequential(
        #> (batch, seq_len, channels)
        
        LambdaLayer(lambda x: x.swapdims(1, 2)),
        #> (batch, channels, seq_len)
        
        #Initial wide receptive field (and it matches length of the pattern)
        nn.Conv1d(in_channels=n_features, out_channels=4, kernel_size=9, padding='same'),
        nn.ReLU(),
        nn.BatchNorm1d(num_features=4),
        
        #Conv block 1 doubles features
        nn.Conv1d(in_channels=4, out_channels=8, kernel_size=3),
        nn.ReLU(),
        nn.BatchNorm1d(num_features=8),
        
        #Conv block 2, then maxpool
        nn.Conv1d(in_channels=8, out_channels=8, kernel_size=3),
        nn.ReLU(),
        nn.BatchNorm1d(num_features=8),
        
        #Output layer: flatten, linear
        nn.MaxPool1d(kernel_size=2, stride=2), #batch, feat, seq
        nn.Flatten(start_dim=1), #batch, feat*seq
        nn.Linear(8 * 30, 1),
    )
    print(
        'Model size is',
        sum([p.numel() for p in model.parameters() if p.requires_grad]),
        'trainable parameters'
    )
    
    #Train loader for batchifying train data
    train_loader = DataLoader(list(zip(X_train_t, y_train_t)), shuffle=True, batch_size=8)
    
    optimiser = torch.optim.NAdam(model.parameters())
    loss_fn = nn.BCEWithLogitsLoss()
    
    from collections import defaultdict
    metrics_dict = defaultdict(list)
    
    for epoch in range(n_epochs := 15):
        model.train()
        cum_loss = 0
        
        for X_minibatch, y_minibatch in train_loader:
            logits = model(X_minibatch).ravel()
            loss = loss_fn(logits, y_minibatch)
            
            optimiser.zero_grad()
            loss.backward()
            optimiser.step()
            
            cum_loss += loss.item() * len(X_minibatch)
        #/end of epoch
        
        time_to_print = (epoch == 0) or ((epoch + 1) % 1) == 0
        if not time_to_print:
            continue
        
        model.eval()
        
        with torch.no_grad():
            val_logits = model(X_val_t).ravel()
            trn_logits = model(X_train_t).ravel()
        
        val_loss = loss_fn(val_logits, y_val_t).item()
        trn_loss = cum_loss / len(X_train_t)
        
        val_acc = ((nn.Sigmoid()(val_logits) > 0.5).int() == y_val_t.int()).float().mean().item()
        trn_acc = ((nn.Sigmoid()(trn_logits) > 0.5).int() == y_train_t.int()).float().mean().item()
        print(
            f'[epoch {epoch + 1:>3d}]',
            f'trn loss: {trn_loss:>5.3f} [acc: {trn_acc:>8.3%}] |',
            f'val loss: {val_loss:>5.3f} [acc: {val_acc:>8.3%}]'
        )
        
        #Record metrics
        metrics_dict['epoch'].append(epoch + 1)
        metrics_dict['trn_loss'].append(trn_loss)
        metrics_dict['val_loss'].append(val_loss)
        metrics_dict['trn_acc'].append(trn_acc)
        metrics_dict['val_acc'].append(val_acc)
    
    #View training curves
    metrics_df = pd.DataFrame(metrics_dict).set_index('epoch')
    ax = metrics_df.plot(
        use_index=True, y=['trn_loss', 'val_loss'], ylabel='loss',
        figsize=(8, 4), legend=False, linewidth=3, marker='s', markersize=7
    )
    
    metrics_df.mul(100).plot(
        use_index=True, y=['trn_acc', 'val_acc'], ylabel='acc', 
        linestyle='--', marker='o', ax=ax.twinx(), legend=False
    )
    ax.figure.legend(ncol=2)
    ax.set_title('training curves')
    
    • 1

relate perguntas

  • Como obter os nomes dos recursos de um OneHotEncoder incorporado em um ColumnTransformer?

Sidebar

Stats

  • Perguntas 205573
  • respostas 270741
  • best respostas 135370
  • utilizador 68524
  • Highest score
  • respostas
  • Marko Smith

    Vue 3: Erro na criação "Identificador esperado, mas encontrado 'import'" [duplicado]

    • 1 respostas
  • Marko Smith

    Por que esse código Java simples e pequeno roda 30x mais rápido em todas as JVMs Graal, mas não em nenhuma JVM Oracle?

    • 1 respostas
  • Marko Smith

    Qual é o propósito de `enum class` com um tipo subjacente especificado, mas sem enumeradores?

    • 1 respostas
  • Marko Smith

    Como faço para corrigir um erro MODULE_NOT_FOUND para um módulo que não importei manualmente?

    • 6 respostas
  • Marko Smith

    `(expression, lvalue) = rvalue` é uma atribuição válida em C ou C++? Por que alguns compiladores aceitam/rejeitam isso?

    • 3 respostas
  • Marko Smith

    Quando devo usar um std::inplace_vector em vez de um std::vector?

    • 3 respostas
  • Marko Smith

    Um programa vazio que não faz nada em C++ precisa de um heap de 204 KB, mas não em C

    • 1 respostas
  • Marko Smith

    PowerBI atualmente quebrado com BigQuery: problema de driver Simba com atualização do Windows

    • 2 respostas
  • Marko Smith

    AdMob: MobileAds.initialize() - "java.lang.Integer não pode ser convertido em java.lang.String" para alguns dispositivos

    • 1 respostas
  • Marko Smith

    Estou tentando fazer o jogo pacman usando apenas o módulo Turtle Random e Math

    • 1 respostas
  • Martin Hope
    Aleksandr Dubinsky Por que a correspondência de padrões com o switch no InetAddress falha com 'não cobre todos os valores de entrada possíveis'? 2024-12-23 06:56:21 +0800 CST
  • Martin Hope
    Phillip Borge Por que esse código Java simples e pequeno roda 30x mais rápido em todas as JVMs Graal, mas não em nenhuma JVM Oracle? 2024-12-12 20:46:46 +0800 CST
  • Martin Hope
    Oodini Qual é o propósito de `enum class` com um tipo subjacente especificado, mas sem enumeradores? 2024-12-12 06:27:11 +0800 CST
  • Martin Hope
    sleeptightAnsiC `(expression, lvalue) = rvalue` é uma atribuição válida em C ou C++? Por que alguns compiladores aceitam/rejeitam isso? 2024-11-09 07:18:53 +0800 CST
  • Martin Hope
    The Mad Gamer Quando devo usar um std::inplace_vector em vez de um std::vector? 2024-10-29 23:01:00 +0800 CST
  • Martin Hope
    Chad Feller O ponto e vírgula agora é opcional em condicionais bash com [[ .. ]] na versão 5.2? 2024-10-21 05:50:33 +0800 CST
  • Martin Hope
    Wrench Por que um traço duplo (--) faz com que esta cláusula MariaDB seja avaliada como verdadeira? 2024-05-05 13:37:20 +0800 CST
  • Martin Hope
    Waket Zheng Por que `dict(id=1, **{'id': 2})` às vezes gera `KeyError: 'id'` em vez de um TypeError? 2024-05-04 14:19:19 +0800 CST
  • Martin Hope
    user924 AdMob: MobileAds.initialize() - "java.lang.Integer não pode ser convertido em java.lang.String" para alguns dispositivos 2024-03-20 03:12:31 +0800 CST
  • Martin Hope
    MarkB Por que o GCC gera código que executa condicionalmente uma implementação SIMD? 2024-02-17 06:17:14 +0800 CST

Hot tag

python javascript c++ c# java typescript sql reactjs html

Explore

  • Início
  • Perguntas
    • Recentes
    • Highest score
  • tag
  • help

Footer

AskOverflow.Dev

About Us

  • About Us
  • Contact Us

Legal Stuff

  • Privacy Policy

Language

  • Pt
  • Server
  • Unix

© 2023 AskOverflow.DEV All Rights Reserve