TypeError:Only arrays of numeric types are supported by JAX

问题遇到的现象和发生背景

生成3d图像模型

问题相关代码,请勿粘贴截图

if use_gpu:

Monitor memory usage

!nvidia-smi

Compilation is slow on Colab GPU

contextmanager = chex.fake_pmap
else:

No-op context manager on TPU

contextmanager = contextlib.suppress

model = DreamField(config)
losses = defaultdict(list)
images = []
log_interval = 2

with contextmanager():
for step, (loss, image, origin, lr) in enumerate(model.run_train(
experiment_dir=experiment_dir, work_unit_dir=work_unit_dir, rng=rng,
yield_results=True)):

UnfilteredStackTrace Traceback (most recent call last)
in
18 experiment_dir=experiment_dir, work_unit_dir=work_unit_dir, rng=rng,

19 yield_results=True)):

运行结果及报错内容

TypeError:Value '<jaxlib.tpu_client_extension.PyTpuBuffer object at 0x7f7cd7afc890>' with dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.

我的解答思路和尝试过的方法

要降低jaxlib版本吗,求解答!

我想要达到的结果

求解答!

数组类型不只有数值型,而JAX 仅支持数值类型的数组。