meshgpt-pytorch: MeshTransformer.generate does not work with a prompt if kv cache is enabled
I wanted to see how well the MeshTransformer can complete a mesh when it is given the initial codes as a prompt, but I think the prompt codes are not passed down the line correctly. Here is a small snippet that reproduces the problem:
import torch

# Random batch of 2 meshes, each with 100 vertices and 100 triangular faces
vertices = torch.randn(2, 100, 3)
faces = torch.randint(0, 100, (2, 100, 3))
# gpt = ...  # load MeshTransformer from checkpoint
codes = gpt.autoencoder.tokenize(vertices=vertices, faces=faces)
generated = gpt.generate(prompt=codes)
It gives the following error:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
Cell In[421], line 1
----> 1 generated = gpt.generate(prompt=codes)
File c:\Users\Farid\anaconda3\envs\meshgpt\lib\site-packages\x_transformers\autoregressive_wrapper.py:27, in eval_decorator.<locals>.inner(self, *args, **kwargs)
     25 was_training = self.training
     26 self.eval()
---> 27 out = fn(self, *args, **kwargs)
     28 self.train(was_training)
     29 return out
File c:\Users\Farid\anaconda3\envs\meshgpt\lib\site-packages\torch\utils\_contextlib.py:115, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    112 @functools.wraps(func)
    113 def decorate_context(*args, **kwargs):
    114     with ctx_factory():
--> 115         return func(*args, **kwargs)
File <@beartype(meshgpt_pytorch.meshgpt_pytorch.MeshTransformer.generate) at 0x2239b61c9d0>:170, in generate(__beartype_func, __beartype_conf, __beartype_get_violation, __beartype_object_2352326455808, __beartype_object_2350005736832, __beartype_object_2349955153872, __beartype_object_140723080033008, __beartype_getrandbits, *args, **kwargs)
File c:\Users\Farid\anaconda3\envs\meshgpt\lib\site-packages\meshgpt_pytorch\meshgpt_pytorch.py:1238, in MeshTransformer.generate(self, prompt, batch_size, filter_logits_fn, filter_kwargs, temperature, return_codes, texts, text_embeds, cond_scale, cache_kv, face_coords_to_file)
   1233 for i in tqdm(range(curr_length, self.max_seq_len)):
   1234     # v1([q1] [q2] [q1] [q2] [q1] [q2]) v2([eos| q1] [q2] [q1] [q2] [q1] [q2]) -> 0 1 2 3 4 5 6 7 8 9 10 11 12 -> v1(F F F F F F) v2(T F F F F F) v3(T F F F F F)
   1236     can_eos = i != 0 and divisible_by(i, self.num_quantizers * 3) # only allow for eos to be decoded at the end of each face, defined as 3 vertices with D residual VQ codes
-> 1238     output = self.forward_on_codes(
   1239         codes,
   1240         text_embeds = text_embeds,
   1241         return_loss = False,
   1242         return_cache = cache_kv,
   1243         append_eos = False,
   1244         cond_scale = cond_scale,
   1245         cfg_routed_kwargs = dict(
   1246             cache = cache
   1247         )
   1248     )
   1250     if cache_kv:
   1251         logits, cache = output
File c:\Users\Farid\anaconda3\envs\meshgpt\lib\site-packages\classifier_free_guidance_pytorch\classifier_free_guidance_pytorch.py:152, in classifier_free_guidance.<locals>.inner(self, cond_scale, rescale_phi, cfg_routed_kwargs, *args, **kwargs)
    148 null_fn_kwargs = {k: v[1] for k, v in cfg_routed_kwargs.items()}
    150 # non-null forward
--> 152 outputs = fn_maybe_with_text(self, *args, **fn_kwargs, **kwargs_without_cond_dropout)
    154 if cond_scale == 1:
    155     return outputs
File c:\Users\Farid\anaconda3\envs\meshgpt\lib\site-packages\classifier_free_guidance_pytorch\classifier_free_guidance_pytorch.py:130, in classifier_free_guidance.<locals>.inner.<locals>.fn_maybe_with_text(self, *args, **kwargs)
    127 if 'raw_text_cond' in fn_params:
    128     kwargs.update(raw_text_cond = raw_text_cond)
--> 130 return fn(self, *args, **kwargs)
File c:\Users\Farid\anaconda3\envs\meshgpt\lib\site-packages\meshgpt_pytorch\meshgpt_pytorch.py:1514, in MeshTransformer.forward_on_codes(self, codes, return_loss, return_cache, append_eos, cache, texts, text_embeds, cond_drop_prob)
   1511 if one_face:
   1512     fine_vertex_codes = fine_vertex_codes[:, :(curr_vertex_pos + 1)]
-> 1514 attended_vertex_codes, fine_cache = self.fine_decoder(
   1515     fine_vertex_codes,
   1516     cache = fine_cache,
   1517     return_hiddens = True
   1518 )
   1520 if not should_cache_fine:
   1521     fine_cache = None
File c:\Users\Farid\anaconda3\envs\meshgpt\lib\site-packages\torch\nn\modules\module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)
File c:\Users\Farid\anaconda3\envs\meshgpt\lib\site-packages\torch\nn\modules\module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None
File c:\Users\Farid\anaconda3\envs\meshgpt\lib\site-packages\x_transformers\x_transformers.py:1299, in AttentionLayers.forward(self, x, context, mask, context_mask, attn_mask, self_attn_kv_mask, mems, seq_start_pos, cache, cache_age, return_hiddens, rotary_pos_emb)
   1296     x = pre_norm(x)
   1298 if layer_type == 'a':
-> 1299     out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, return_intermediates = True)
   1300 elif layer_type == 'c':
   1301     out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), return_intermediates = True)
File c:\Users\Farid\anaconda3\envs\meshgpt\lib\site-packages\torch\nn\modules\module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs)
   1516     return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc]
   1517 else:
-> 1518     return self._call_impl(*args, **kwargs)
File c:\Users\Farid\anaconda3\envs\meshgpt\lib\site-packages\torch\nn\modules\module.py:1527, in Module._call_impl(self, *args, **kwargs)
   1522 # If we don't have any hooks, we want to skip the rest of the logic in
   1523 # this function, and just call forward.
   1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1525         or _global_backward_pre_hooks or _global_backward_hooks
   1526         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1527     return forward_call(*args, **kwargs)
   1529 try:
   1530     result = None
File c:\Users\Farid\anaconda3\envs\meshgpt\lib\site-packages\x_transformers\x_transformers.py:832, in Attention.forward(self, x, context, mask, context_mask, attn_mask, rel_pos, rotary_pos_emb, prev_attn, mem, return_intermediates, cache)
    829     mk, k = unpack(k, mem_packed_shape, 'b h * d')
    830     mv, v = unpack(v, mem_packed_shape, 'b h * d')
--> 832 k = torch.cat((ck, k), dim = -2)
    833 v = torch.cat((cv, v), dim = -2)
    835 if exists(mem):
RuntimeError: Sizes of tensors must match except in dimension 2. Expected size 202 but got size 2 for tensor number 1 in the list.
However, if I disable the kv cache with generated = gpt.generate(prompt=codes, cache_kv=False), it works (albeit slowly).
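Until a fixed version is installed, the workaround can be wrapped in a small helper. This is only a sketch: `generate_with_cache_fallback` is a hypothetical name and not part of meshgpt-pytorch. It relies solely on the `generate(prompt=..., cache_kv=...)` call shown in this report and assumes the failure surfaces as a `RuntimeError`, as in the traceback above.

```python
def generate_with_cache_fallback(model, codes):
    """Try generation with the kv cache; retry without it on a RuntimeError.

    Hypothetical helper for illustration only. `model` is any object that
    exposes generate(prompt=..., cache_kv=...), as MeshTransformer does in
    the traceback above.
    """
    try:
        return model.generate(prompt=codes, cache_kv=True)
    except RuntimeError as err:
        # The shape mismatch reported in this issue is raised as a RuntimeError.
        print(f"kv cache path failed ({err}); retrying with cache_kv=False")
        return model.generate(prompt=codes, cache_kv=False)
```

Note that catching `RuntimeError` this broadly would also hide unrelated failures (for example out-of-memory errors), so it is better suited to debugging than to production use.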
With the cache enabled, in x_transformers > Attention > forward, ck.shape = [202, 16, 6, 64] and k.shape = [2, 16, 1, 64], which causes the shape mismatch error (cv and v have the same respective shapes).
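To make the mismatch concrete, the sketch below re-implements in plain Python the shape rule that `torch.cat` enforces: every dimension except the concatenation dimension must match. `cat_shape` is a hypothetical helper written for this illustration, not a PyTorch function. One plausible reading, not confirmed in this report, is that 202 is the batch size (2) multiplied by 101 per-face positions, so the cached keys carry a flattened batch dimension while the new key does not.

```python
def cat_shape(shapes, dim):
    """Return the shape produced by concatenating tensors of `shapes` along `dim`.

    Hypothetical pure-Python helper mirroring the shape rule of torch.cat:
    all dimensions except `dim` must be equal across inputs.
    """
    ndim = len(shapes[0])
    dim = dim % ndim  # normalise negative dims, e.g. -2 -> 2 for 4-D shapes
    out = list(shapes[0])
    for idx, shape in enumerate(shapes[1:], start=1):
        if len(shape) != ndim:
            raise ValueError("all shapes need the same number of dimensions")
        for d in range(ndim):
            if d != dim and shape[d] != out[d]:
                raise ValueError(
                    f"dimension {d}: expected size {out[d]} "
                    f"but got {shape[d]} for tensor number {idx}"
                )
        out[dim] += shape[dim]
    return tuple(out)


ck_shape = (202, 16, 6, 64)  # cached keys, as reported above
k_shape = (2, 16, 1, 64)     # key for the newly decoded token

try:
    cat_shape([ck_shape, k_shape], dim=-2)
except ValueError as err:
    print(err)  # the batch dimension (0) is the one that disagrees

# With matching batch sizes the same concatenation is valid.
print(cat_shape([(2, 16, 6, 64), k_shape], dim=-2))
```

Concatenating along dim -2 only lets the sequence dimension differ, so the 202 vs. 2 disagreement in the batch dimension is what triggers the error.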
About this issue
- State: closed
- Created 6 months ago
- Reactions: 1
- Comments: 15 (14 by maintainers)
Commits related to this issue
- handle hierarchical kv cache correctly when prompted with mesh codes, addressing https://github.com/lucidrains/meshgpt-pytorch/issues/48 — committed to lucidrains/meshgpt-pytorch by lucidrains 6 months ago
Thanks for looking into this! 😄
Some results on the training set with a gpt trained down to a loss of 3.293. The first row is the ground truth. The second row shows the first 100 faces of the ground truth, which are given to the gpt as the prompt. The last row is the gpt output.
@lucidrains I do 🚀
Kind of, I work at a company atm, but I’ll soon start a PhD
What’s the name of the company? I didn’t hear something about that anyway
It works perfectly now, thank you for fixing it so fast 🙏
@Kurokabe hey Farid, after a few hours of debugging (hierarchical transformers are confusing), i think i finally figured out the issue
do you want to try 0.5.7 and see if it fixes your original script?
I agree, it’s much better to give it a little push in the right direction. I’ve managed to create multiple different objects using text-conditioned generation, but it needs a very low loss. I haven’t tried prompting it, but I’m guessing it makes it possible to generate meshes even with a relatively high training loss.
I wonder how much the cross attention with text helps at the start; it might be interesting to test the impact of the prompt tokens when using text as well. If text + prompt gives much better results than text alone, it might be worth it for @lucidrains to revisit this and see if there’s any way to increase the impact of the text during mesh generation.
@Kurokabe not bad! thank you Farid! will let you know once i get this fixed tonight 🚀