How do I convert the TensorFlow 1 "BigGAN" model on Kaggle to TensorFlow Lite format?
https://www.tensorflow.org/hub/tutorials/biggan_generation_with_tf_hub?hl=ja
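I am trying to run the source code below, which defines a tf.keras model and converts it to TensorFlow Lite format. However, when I define the model with the Functional API using the layer created by hub.KerasLayer, the following exception is raised:
I also don't know how to specify the arguments to model.build().
Because this model takes multiple inputs, I can't use the solution at the following URL, which builds the model with the Sequential API. How can I save a model whose input shape is (1, None, None, 3), with the None dimensions fixed to 256?
Is there a good way to solve this?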
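For the shape part of the question, here is a minimal sketch of the arithmetic only, assuming the goal is simply to replace every None dimension with 256. `fix_dims` is a hypothetical helper, not a TensorFlow API. In tf.keras the fixed shape would normally be declared up front on the input itself, for example keras.Input(shape=(256, 256, 3), batch_size=1), which gives a (1, 256, 256, 3) input, rather than being set afterwards through model.build().

```python
from typing import Optional, Sequence, Tuple


def fix_dims(shape: Sequence[Optional[int]], size: int) -> Tuple[int, ...]:
    """Hypothetical helper: return `shape` with every unknown (None) dimension set to `size`."""
    return tuple(size if dim is None else dim for dim in shape)


# The shape from the question, with the unknown spatial dimensions pinned to 256.
fixed = fix_dims((1, None, None, 3), 256)
print(fixed)  # (1, 256, 256, 3)
```

Known dimensions (the batch of 1 and the 3 channels) are left untouched. Only the None entries change.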
Code that does not work
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
from tensorflow.compat.v1 import keras
from tensorflow.compat.v1.keras import layers
import tensorflow_hub as hub
input_truncation = keras.Input(shape=(), name='truncation')
input_y = keras.Input(shape=(1000, ), name='y')
input_z = keras.Input(shape=(128, ), name='z')
hub_layer = hub.KerasLayer(
"https://www.kaggle.com/models/deepmind/biggan/TensorFlow1/128/2",
trainable=False,
signature="default",
signature_outputs_as_dict=True,
input_shape = [[], [1000], [128]],
output_shape = [128, 128, 3],
)
### -->> TypeError occurs HERE
output = hub_layer([input_truncation, input_y, input_z])
### <<-- TypeError occurs HERE
model = tf.keras.models.Model(inputs=[input_truncation, input_y, input_z], outputs=[output])
### -->> How to build this model
model.build([1], [1, 1000], [1, 128])
### <<-- How to build this model
model.summary()
model.save("biggan-128")
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)
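On the model.build() question: Keras's Model.build takes a single input_shape argument, so a multi-input model expects one list of shapes (one per input, batch dimension included) rather than three separate arguments. model.build([1], [1, 1000], [1, 128]) passes three positional arguments and would itself fail with a TypeError. A Functional model created from keras.Input tensors is also already built at construction time, so the call can most likely be dropped. The sketch below only assembles that single list for a batch of 1, using the per-input shapes from the keras.Input definitions above in the order [truncation, y, z]; `INPUT_SPECS` and `build_shapes` are hypothetical names.

```python
from typing import Dict, List, Tuple

# Per-example input shapes (batch dimension excluded), copied from the question's
# keras.Input definitions and kept in the same order as [truncation, y, z].
INPUT_SPECS: Dict[str, Tuple[int, ...]] = {
    "truncation": (),  # scalar per example
    "y": (1000,),      # class vector
    "z": (128,),       # latent noise vector
}


def build_shapes(batch: int, specs: Dict[str, Tuple[int, ...]]) -> List[Tuple[int, ...]]:
    """Prepend the batch dimension to every per-example shape, keeping dict order."""
    return [(batch,) + shape for shape in specs.values()]


shapes = build_shapes(1, INPUT_SPECS)
print(shapes)  # [(1,), (1, 1000), (1, 128)]
# For a model that actually needs building, this list would be the one argument:
# model.build(shapes)
```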
Exception raised
TypeError
in user code:
File "/home/shino/anaconda3/envs/movenet/lib/python3.10/site-packages/tensorflow_hub/keras_layer.py", line 242, in call *
result = f()
TypeError: pruned(truncation, y, z) takes 0 positional arguments, got 1.
File "/tmp/__autograph_generated_fileedgegq8b.py", line 74, in tf__call
ag__.if_stmt(ag__.not_(ag__.ld(self)._has_training_argument), if_body_3, else_body_3, get_state_3, set_state_3, ('result', 'training'), 1)
File "/tmp/__autograph_generated_fileedgegq8b.py", line 37, in if_body_3
result = ag__.converted_call(ag__.ld(f), (), None, fscope)
TypeError: pruned(truncation, y, z) takes 0 positional arguments, got 1.
During handling of the above exception, another exception occurred:
File "/tmp/__autograph_generated_fileedgegq8b.py", line 37, in if_body_3
result = ag__.converted_call(ag__.ld(f), (), None, fscope)
File "/tmp/__autograph_generated_fileedgegq8b.py", line 74, in tf__call
ag__.if_stmt(ag__.not_(ag__.ld(self)._has_training_argument), if_body_3, else_body_3, get_state_3, set_state_3, ('result', 'training'), 1)
File "/home/shino/sandbox/python/biggan/biggan_export.py", line 21, in <module>
output = hub_layer([input_truncation, input_y, input_z])
TypeError: in user code:
File "/home/shino/anaconda3/envs/movenet/lib/python3.10/site-packages/tensorflow_hub/keras_layer.py", line 242, in call *
result = f()
TypeError: pruned(truncation, y, z) takes 0 positional arguments, got 1.
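The message `pruned(truncation, y, z) takes 0 positional arguments, got 1` suggests that the signature function accepts its inputs only as keywords. As far as I can tell from tensorflow_hub/keras_layer.py, KerasLayer.call unpacks its input into keyword arguments only when a signature is set and the input is a dict; anything else, including a list, is passed through as a single positional argument. The pure-Python sketch below imitates that dispatch with stand-in functions (`pruned` and `call_layer` are hypothetical simplifications, not the library code) to show why the list fails while a dict keyed by the signature's input names does not.

```python
import functools


def pruned(*, truncation, y, z):
    """Stand-in for the signature function: keyword-only inputs, like the real one."""
    return {"default": (truncation, len(y), len(z))}


def call_layer(fn, inputs, signature="default"):
    """Simplified imitation of how KerasLayer.call forwards its inputs."""
    args, kwargs = [], {}
    if signature and isinstance(inputs, dict):
        kwargs.update(inputs)  # dict -> keyword arguments
    else:
        args.append(inputs)    # anything else -> one positional argument
    return functools.partial(fn, *args, **kwargs)()


y, z = [0.0] * 1000, [0.0] * 128

try:
    call_layer(pruned, [0.5, y, z])  # list input, as in the question
except TypeError as err:
    print("list input:", err)

out = call_layer(pruned, {"truncation": 0.5, "y": y, "z": z})
print(out["default"])  # (0.5, 1000, 128)
```

Applied to the question's code, this would mean calling hub_layer({'truncation': input_truncation, 'y': input_y, 'z': input_z}) and, because signature_outputs_as_dict=True, taking the image tensor out of the returned dict (for this module's default signature the key is probably 'default'). This has not been run against the Kaggle model, so treat it as a direction to try rather than a verified fix.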
requirements.txt
absl-py==2.2.1
astunparse==1.6.3
cachetools==5.5.2
certifi==2025.1.31
charset-normalizer==3.4.1
coloredlogs==15.0.1
flatbuffers==1.12
gast==0.4.0
google-auth==2.38.0
google-auth-oauthlib==0.4.6
google-pasta==0.2.0
grpcio==1.71.0
h5py==3.13.0
humanfriendly==10.0
idna==3.10
keras==2.9.0
Keras-Preprocessing==1.1.2
libclang==18.1.1
Markdown==3.7
MarkupSafe==3.0.2
mpmath==1.3.0
numpy==1.26.4
oauthlib==3.2.2
onnx==1.14.1
onnx-graphsurgeon==0.5.7
onnx2tf==1.26.9
onnxruntime==1.21.0
opt_einsum==3.4.0
packaging==24.2
protobuf==3.20.3
psutil==7.0.0
pyasn1==0.6.1
pyasn1_modules==0.4.2
requests==2.32.3
requests-oauthlib==2.0.0
rsa==4.9
six==1.17.0
sng4onnx==1.0.4
sympy==1.13.3
tensorboard==2.9.0
tensorboard-data-server==0.6.1
tensorboard-plugin-wit==1.8.1
tensorflow==2.9.0
tensorflow-estimator==2.9.0
tensorflow-hub==0.16.1
tensorflow-io-gcs-filesystem==0.37.1
tensorflow-neuron==1.0
termcolor==3.0.0
tf-keras==2.14.1
tf2onnx==1.13.0
typing_extensions==4.13.0
urllib3==2.3.0
Werkzeug==3.1.3
wrapt==1.17.2