Descrição
Estou trabalhando com LightningDataModule
e queria garantir que um método ( _after_init
) seja executado apenas uma vez após a inicialização completa , independentemente da subclassificação. Para isso, implementei uma metaclasse personalizada ( _InitMeta
) que substitui __call__
para invocar _after_init
após a instância ser totalmente criada.
Ao criar uma instância da subclasse final, encontro um KeyError: 'self' dentro de save_hyperparameters().
Criei um exemplo mínimo do código abaixo para ilustrar o problema:
Trecho de código
from typing import Any
from lightning import LightningDataModule
class _InitMeta(type):
def __call__(
cls: Any, *args: Any, **kwargs: Any
) -> Any:
instance = super().__call__(*args, **kwargs) # Create the instance
if hasattr(instance, "_after_init"):
instance._after_init(**kwargs) # Call the method if defined
return instance
class A(LightningDataModule, metaclass=_InitMeta):
def __init__(self, *args, **kwargs):
self.save_hyperparameters()
self.a = 1
self.b = 2
super().__init__(*args, **kwargs)
def print_ab(self, **kwargs: Any):
print("in print ab")
if kwargs.get("flag", False):
print("flag is set to False")
print("some other logic")
else:
print(self.a, self.b)
def _after_init(self, **kwargs):
"""Called only once after full initialization."""
self.print_ab(**kwargs)
class B(A):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.a += 1
self.b += 2
class C(B):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.a += 1
self.b += 2
if __name__ == "__main__":
print("Creating C instance:")
c = C() # Should print 3, 6 only once
print("\nCreating B instance:")
b = B() # Should print 2, 4 only once
print("\nCreating A instance:")
a = A() # Should print 1, 2 only once
Saída de erro
Creating C instance:
Traceback (most recent call last):
File "G:\github-aditya0by0\python-chebai\test.py", line 48, in <module>
c = C() # Should print 3, 6 only once
File "G:\github-aditya0by0\python-chebai\test.py", line 10, in __call__
instance = super().__call__(*args, **kwargs) # Create the instance
File "G:\github-aditya0by0\python-chebai\test.py", line 41, in __init__
super().__init__(**kwargs)
File "G:\github-aditya0by0\python-chebai\test.py", line 34, in __init__
super().__init__(**kwargs)
File "G:\github-aditya0by0\python-chebai\test.py", line 18, in __init__
self.save_hyperparameters()
File "G:\anaconda3\envs\env_chebai\lib\site-packages\lightning\pytorch\core\mixins\hparams_mixin.py", line 112, in save_hyperparameters
save_hyperparameters(self, *args, ignore=ignore, frame=frame)
File "G:\anaconda3\envs\env_chebai\lib\site-packages\lightning\pytorch\utilities\parsing.py", line 165, in save_hyperparameters
for local_args in collect_init_args(frame, [], classes=(HyperparametersMixin,)):
File "G:\anaconda3\envs\env_chebai\lib\site-packages\lightning\pytorch\utilities\parsing.py", line 135, in collect_init_args
return collect_init_args(frame.f_back, path_args, inside=True, classes=classes)
File "G:\anaconda3\envs\env_chebai\lib\site-packages\lightning\pytorch\utilities\parsing.py", line 135, in collect_init_args
return collect_init_args(frame.f_back, path_args, inside=True, classes=classes)
File "G:\anaconda3\envs\env_chebai\lib\site-packages\lightning\pytorch\utilities\parsing.py", line 135, in collect_init_args
return collect_init_args(frame.f_back, path_args, inside=True, classes=classes)
File "G:\anaconda3\envs\env_chebai\lib\site-packages\lightning\pytorch\utilities\parsing.py", line 131, in collect_init_args
local_self, local_args = _get_init_args(frame)
File "G:\anaconda3\envs\env_chebai\lib\site-packages\lightning\pytorch\utilities\parsing.py", line 97, in _get_init_args
local_args = {k: local_vars[k] for k in init_parameters}
File "G:\anaconda3\envs\env_chebai\lib\site-packages\lightning\pytorch\utilities\parsing.py", line 97, in <dictcomp>
local_args = {k: local_vars[k] for k in init_parameters}
KeyError: 'self'
Ambiente
- PyTorch Relâmpago: 2.1.2
- Python: 3.10.14
- Tocha: 2.5.1
Além disso, deixe-me apontar um problema semelhante, mas com origem de erro diferente: https://github.com/Lightning-AI/pytorch-lightning/issues/18405
Embora seu código esteja correto, o problema parece ser mais um bug na função save_hyperparameters(), na biblioteca lightning.
Ele usa introspecção Python para tentar voltar para a
__init__
função, onde foi chamado, e coletar os parâmetros. Fazer uso de uma metaclasse personalizada__call__
aparentemente faz com que ele atinja um caso extremo para o qual eles não estavam preparados - então eles tentam buscar umaself
variável local em um quadro onde ela não existe - provavelmente está inspecionando o quadro da__Call__
própria metaclasse, mas comparando-a com os nomes dos parâmetros para a classe__init__
.Ao analisar o código na biblioteca que aciona o erro ( https://github.com/Lightning-AI/pytorch-lightning/blob/master/src/lightning/pytorch/utilities/parsing.py ), eu diria que eles realmente tentam cobrir todos os casos - e também o arquivo no github é uma versão mais recente (no seu exemplo, o erro é gerado na linha 97, enquanto esse código está na linha 102 na fonte atual) - então eles ficariam felizes em corrigir esse caso extremo extra.
Por favor, informe-os sobre esse exemplo como um bug. Enquanto isso, para que seu trabalho possa continuar, você deve tentar encontrar uma solução alternativa para a
save hyperparameters
chamada.solução alternativa
Então - ao olhar o código deles, mas sem ter tempo ou vontade de depurar todos os casos de inspeção que eles têm lá, eu encontrei esse caminho para você: a função ofensiva procura pela existência de uma
__class__
variável no escopo, caso contrário, ela apenas passa. Qualquer método que use a versão sem parâmetros desuper
(ou mesmo a versão parametrizada), terá__class__
injetado como uma variável não local - esse é o mecanismo usado porsuper()
. Como os problemas estão ocorrendo, provavelmente, na sua metaclasse__call__
, é possível "esconder" o__Class__
não local usando a forma parametrizada de super, além de definir e desconfigurar uma variável local chamada__class__
que deve fazer a função com erro pular o quadro para sua metaclasse.Estes são os testes que fiz aqui:
Como você pode ver, a
__class__
variável está ausente dasB.__init__
localvars no momento da introspecção. Não é assim com C`s.Dito isso, tente alterar seu
__call__
código para:(veja que super() agora usa parâmetros explícitos)
Diga-me sua quilometragem - se não funcionar, talvez eu possa dar uma olhada nisso novamente esta tarde.