Generating a 3D image model
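# Colab cell; `!nvidia-smi` is IPython shell syntax, not plain Python.
import contextlib
from collections import defaultdict

import chex

# use_gpu, config, experiment_dir, work_unit_dir, rng and DreamField are
# assumed to be defined in earlier cells of the notebook.
if use_gpu:
    !nvidia-smi
    contextmanager = chex.fake_pmap
else:
    contextmanager = contextlib.suppress

model = DreamField(config)
losses = defaultdict(list)
images = []
log_interval = 2
with contextmanager():
    for step, (loss, image, origin, lr) in enumerate(model.run_train(
            experiment_dir=experiment_dir, work_unit_dir=work_unit_dir, rng=rng,
            yield_results=True)):
        ...  # loop body not included in the post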
UnfilteredStackTrace Traceback (most recent call last)
in
18 experiment_dir=experiment_dir, work_unit_dir=work_unit_dir, rng=rng,
19 yield_results=True)):
TypeError: Value '<jaxlib.tpu_client_extension.PyTpuBuffer object at 0x7f7cd7afc890>' with dtype object is not a valid JAX array type. Only arrays of numeric types are supported by JAX.
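For reference, the contextmanager switch in the code picks chex.fake_pmap when use_gpu is set (chex documents it as patching jax.pmap to run as jax.vmap) and otherwise contextlib.suppress. Called with no exception types, suppress() swallows nothing, so it is a do-nothing context equivalent to contextlib.nullcontext(), and the TypeError above passes straight through it. The sketch below reproduces that fallback with the standard library only; pick_context, run_with, raises_inside and the fake_pmap_factory parameter are hypothetical names, with fake_pmap_factory standing in for chex.fake_pmap so the sketch runs without chex.

```python
import contextlib


def pick_context(use_fake_pmap, fake_pmap_factory=None):
    """Mirror the snippet's if/else and return a context-manager factory.

    fake_pmap_factory is a hypothetical stand-in for chex.fake_pmap.
    """
    if use_fake_pmap and fake_pmap_factory is not None:
        return fake_pmap_factory
    # suppress() with no arguments suppresses no exceptions: a no-op context.
    return contextlib.suppress


def run_with(factory):
    """Enter the chosen context and report that the body executed."""
    with factory():
        return "body ran"


def raises_inside(factory):
    """Return True if an exception raised inside the context escapes it."""
    try:
        with factory():
            raise ValueError("boom")
    except ValueError:
        return True
    return False
```

Because suppress() with no arguments hides nothing, errors raised during training are reported exactly as they would be without any context manager.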
Do I need to downgrade jaxlib? Any help would be appreciated!
Please help!
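Before downgrading anything, it helps to see which jax and jaxlib versions are actually installed, since jax requires a minimum compatible jaxlib version. A minimal sketch using the standard-library importlib.metadata (Python 3.8+); installed_versions is a hypothetical helper name. In Colab, pip show jax jaxlib reports the same information.

```python
from importlib import metadata


def installed_versions(packages=("jax", "jaxlib")):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name}: {version or 'not installed'}")
```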
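Array types are not limited to numeric ones, but JAX only supports arrays of numeric types. Here the value JAX received has dtype object (it is a PyTpuBuffer object rather than a numeric array), so JAX rejects it.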
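The rule in the answer can be checked before data reaches JAX: NumPy reports non-numeric contents as dtype object (kind 'O'), while the standard numeric kinds are bool, signed int, unsigned int, float and complex ('b', 'i', 'u', 'f', 'c'). A minimal sketch assuming NumPy is available; is_jax_compatible and to_numeric are hypothetical helpers, and the kind check is a simplification that does not cover extended dtypes such as bfloat16. The conversion only works when an object array holds plain numbers.

```python
import numpy as np

# NumPy dtype kinds for bool, signed int, unsigned int, float and complex.
# Simplification: extended dtypes such as bfloat16 are not covered here.
NUMERIC_KINDS = frozenset("biufc")


def is_jax_compatible(x):
    """True if np.asarray(x) has a numeric (non-object) dtype."""
    return np.asarray(x).dtype.kind in NUMERIC_KINDS


def to_numeric(x, dtype=np.float32):
    """Convert x to a numeric array, raising a clear TypeError if impossible."""
    arr = np.asarray(x)
    if arr.dtype.kind in NUMERIC_KINDS:
        return arr
    try:
        # Succeeds when an object array holds plain Python numbers.
        return arr.astype(dtype)
    except (TypeError, ValueError) as err:
        raise TypeError(
            f"cannot convert array of dtype {arr.dtype} to {np.dtype(dtype)}"
        ) from err
```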