Descrição
Estou trabalhando com LightningDataModule
e queria garantir que um método ( _after_init
) seja executado apenas uma vez após a inicialização completa , independentemente da subclassificação. Para isso, implementei uma metaclasse personalizada ( _InitMeta
) que substitui __call__
para invocar _after_init
após a instância ser totalmente criada.
Ao criar uma instância da subclasse final, encontro um KeyError: 'self' dentro de save_hyperparameters().
Criei um exemplo mínimo do código abaixo para ilustrar o problema:
Trecho de código
from typing import Any
from lightning import LightningDataModule
class _InitMeta(type):
def __call__(
cls: Any, *args: Any, **kwargs: Any
) -> Any:
instance = super().__call__(*args, **kwargs) # Create the instance
if hasattr(instance, "_after_init"):
instance._after_init(**kwargs) # Call the method if defined
return instance
class A(LightningDataModule, metaclass=_InitMeta):
def __init__(self, *args, **kwargs):
self.save_hyperparameters()
self.a = 1
self.b = 2
super().__init__(*args, **kwargs)
def print_ab(self, **kwargs: Any):
print("in print ab")
if kwargs.get("flag", False):
print("flag is set to False")
print("some other logic")
else:
print(self.a, self.b)
def _after_init(self, **kwargs):
"""Called only once after full initialization."""
self.print_ab(**kwargs)
class B(A):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.a += 1
self.b += 2
class C(B):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.a += 1
self.b += 2
if __name__ == "__main__":
print("Creating C instance:")
c = C() # Should print 3, 6 only once
print("\nCreating B instance:")
b = B() # Should print 2, 4 only once
print("\nCreating A instance:")
a = A() # Should print 1, 2 only once
Saída de erro
Creating C instance:
Traceback (most recent call last):
File "G:\github-aditya0by0\python-chebai\test.py", line 48, in <module>
c = C() # Should print 3, 6 only once
File "G:\github-aditya0by0\python-chebai\test.py", line 10, in __call__
instance = super().__call__(*args, **kwargs) # Create the instance
File "G:\github-aditya0by0\python-chebai\test.py", line 41, in __init__
super().__init__(**kwargs)
File "G:\github-aditya0by0\python-chebai\test.py", line 34, in __init__
super().__init__(**kwargs)
File "G:\github-aditya0by0\python-chebai\test.py", line 18, in __init__
self.save_hyperparameters()
File "G:\anaconda3\envs\env_chebai\lib\site-packages\lightning\pytorch\core\mixins\hparams_mixin.py", line 112, in save_hyperparameters
save_hyperparameters(self, *args, ignore=ignore, frame=frame)
File "G:\anaconda3\envs\env_chebai\lib\site-packages\lightning\pytorch\utilities\parsing.py", line 165, in save_hyperparameters
for local_args in collect_init_args(frame, [], classes=(HyperparametersMixin,)):
File "G:\anaconda3\envs\env_chebai\lib\site-packages\lightning\pytorch\utilities\parsing.py", line 135, in collect_init_args
return collect_init_args(frame.f_back, path_args, inside=True, classes=classes)
File "G:\anaconda3\envs\env_chebai\lib\site-packages\lightning\pytorch\utilities\parsing.py", line 135, in collect_init_args
return collect_init_args(frame.f_back, path_args, inside=True, classes=classes)
File "G:\anaconda3\envs\env_chebai\lib\site-packages\lightning\pytorch\utilities\parsing.py", line 135, in collect_init_args
return collect_init_args(frame.f_back, path_args, inside=True, classes=classes)
File "G:\anaconda3\envs\env_chebai\lib\site-packages\lightning\pytorch\utilities\parsing.py", line 131, in collect_init_args
local_self, local_args = _get_init_args(frame)
File "G:\anaconda3\envs\env_chebai\lib\site-packages\lightning\pytorch\utilities\parsing.py", line 97, in _get_init_args
local_args = {k: local_vars[k] for k in init_parameters}
File "G:\anaconda3\envs\env_chebai\lib\site-packages\lightning\pytorch\utilities\parsing.py", line 97, in <dictcomp>
local_args = {k: local_vars[k] for k in init_parameters}
KeyError: 'self'
Ambiente
- PyTorch Relâmpago: 2.1.2
- Python: 3.10.14
- Tocha: 2.5.1
Além disso, deixe-me apontar um problema semelhante, mas com origem de erro diferente: https://github.com/Lightning-AI/pytorch-lightning/issues/18405