Acho esse comportamento bastante contraintuitivo, embora eu suponha que haja uma razão para isso - o numba converte automaticamente meus tipos inteiros numpy diretamente em um int python:
import numba as nb
import numpy as np
print(f"Numba version: {nb.__version__}") # 0.59.0
print(f"NumPy version: {np.__version__}") # 1.23.5
# Explicitly define the signature
sig = nb.uint32(nb.uint32, nb.uint32)
@nb.njit(sig, cache=False)
def test_fn(a, b):
return a * b
res = test_fn(2, 10)
print(f"Result value: {res}") # returns 20
print(f"Result type: {type(res)}") # returns <class 'int'>
Este é um problema porque estou usando o retorno como uma entrada em outra função njit, então recebo um aviso de conversão (e também faço conversões desnecessárias entre as funções njit)
Existe alguma maneira de forçar o numba a me dar np.uint32
um resultado?
--- EDITAR ---
Este é o melhor que consegui fazer sozinho, mas me recuso a acreditar que esta seja a melhor implementação que existe:
# we manually define a return record and pass it as a parameter
res_type = np.dtype([('res', np.uint32)])
sig = nb.void(nb.uint32, nb.uint32, nb.from_dtype(res_type))
@nb.njit(sig, cache=False)
def test_fn(a:np.uint32, b:np.uint32, res: res_type):
res['res'] = a * b
# Call with Python ints (Numba should coerce based on signature)
res = np.recarray(1, dtype=res_type)[0]
res_py_in = test_fn(2, 10, res)
print(f"\nCalled with Python ints:")
print(f"Result value: {res['res']}") # 20
print(f"Result type: {type(res['res'])}") # <class 'numpy.uint32'>