axolotl: RuntimeError: stack expects each tensor to be equal size, but got...

Please check that this issue hasn’t been reported before.

  • I searched previous Bug Reports and didn’t find any similar reports.

Expected Behavior

Training should continue after evaluation and logging to wandb.

Current Behavior

I’m running WSL2 on Windows 11 and get the following traceback:

[2023-10-18 03:25:44,364] [INFO] [axolotl.monkeypatch.mistral._prepare_decoder_attention_mask:113] [PID:2711681] [RANK:0] skipping sliding window mask, not broadcastable with attention mask
  0%|                                                                                                                         | 0/184 [00:14<?, ?it/s]
Traceback (most recent call last):
  File "/home/richard/miniforge3/envs/axolotl/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/home/richard/miniforge3/envs/axolotl/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/home/richard/github/axolotl/src/axolotl/cli/train.py", line 51, in <module>
    fire.Fire(do_cli)
  File "/home/richard/miniforge3/envs/axolotl/lib/python3.10/site-packages/fire/core.py", line 141, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
  File "/home/richard/miniforge3/envs/axolotl/lib/python3.10/site-packages/fire/core.py", line 475, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
  File "/home/richard/miniforge3/envs/axolotl/lib/python3.10/site-packages/fire/core.py", line 691, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
  File "/home/richard/github/axolotl/src/axolotl/cli/train.py", line 47, in do_cli
    train(cfg=parsed_cfg, cli_args=parsed_cli_args, dataset_meta=dataset_meta)
  File "/home/richard/github/axolotl/src/axolotl/train.py", line 118, in train
    trainer.train(resume_from_checkpoint=resume_from_checkpoint)
  File "/home/richard/miniforge3/envs/axolotl/lib/python3.10/site-packages/transformers/trainer.py", line 1591, in train
    return inner_training_loop(
  File "/home/richard/miniforge3/envs/axolotl/lib/python3.10/site-packages/transformers/trainer.py", line 1984, in _inner_training_loop
    self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
  File "/home/richard/miniforge3/envs/axolotl/lib/python3.10/site-packages/transformers/trainer.py", line 2328, in _maybe_log_save_evaluate
    metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
  File "/home/richard/miniforge3/envs/axolotl/lib/python3.10/site-packages/transformers/trainer.py", line 3094, in evaluate
    self.control = self.callback_handler.on_evaluate(self.args, self.state, self.control, output.metrics)
  File "/home/richard/miniforge3/envs/axolotl/lib/python3.10/site-packages/transformers/trainer_callback.py", line 388, in on_evaluate
    return self.call_event("on_evaluate", args, state, control, metrics=metrics)
  File "/home/richard/miniforge3/envs/axolotl/lib/python3.10/site-packages/transformers/trainer_callback.py", line 406, in call_event
    result = getattr(callback, event)(
  File "/home/richard/github/axolotl/src/axolotl/utils/callbacks.py", line 512, in on_evaluate
    log_table_from_dataloader("Eval", eval_dataloader)
  File "/home/richard/github/axolotl/src/axolotl/utils/callbacks.py", line 472, in log_table_from_dataloader
    predictions = trainer.model.generate(
  File "/home/richard/miniforge3/envs/axolotl/lib/python3.10/site-packages/peft/peft_model.py", line 1022, in generate
    outputs = self.base_model.generate(**kwargs)
  File "/home/richard/miniforge3/envs/axolotl/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/richard/miniforge3/envs/axolotl/lib/python3.10/site-packages/transformers/generation/utils.py", line 1606, in generate
    return self.greedy_search(
  File "/home/richard/miniforge3/envs/axolotl/lib/python3.10/site-packages/transformers/generation/utils.py", line 2454, in greedy_search
    outputs = self(
  File "/home/richard/miniforge3/envs/axolotl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/richard/miniforge3/envs/axolotl/lib/python3.10/site-packages/accelerate/hooks.py", line 164, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/richard/miniforge3/envs/axolotl/lib/python3.10/site-packages/transformers/models/mistral/modeling_mistral.py", line 1045, in forward
    outputs = self.model(
  File "/home/richard/miniforge3/envs/axolotl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/richard/miniforge3/envs/axolotl/lib/python3.10/site-packages/accelerate/hooks.py", line 164, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/richard/github/axolotl/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py", line 536, in mistral_model_forward
    layer_outputs = decoder_layer(
  File "/home/richard/miniforge3/envs/axolotl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/richard/miniforge3/envs/axolotl/lib/python3.10/site-packages/accelerate/hooks.py", line 164, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/richard/github/axolotl/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py", line 614, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/home/richard/miniforge3/envs/axolotl/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/richard/miniforge3/envs/axolotl/lib/python3.10/site-packages/accelerate/hooks.py", line 164, in new_forward
    output = module._old_forward(*args, **kwargs)
  File "/home/richard/github/axolotl/src/axolotl/monkeypatch/mistral_attn_hijack_flash.py", line 206, in flashattn_forward
    qkv = torch.stack(
RuntimeError: stack expects each tensor to be equal size, but got [1, 32, 1, 128] at entry 0 and [1, 32, 4096, 128] at entry 1
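
The mismatched shapes suggest what is going wrong: during greedy decoding with a KV cache, the query covers only the newest token (sequence length 1), while the cached key/value tensors cover the full context, so the flash-attention monkeypatch’s torch.stack call cannot combine them. A minimal standalone sketch of that failure mode (not the axolotl code; shapes copied from the error above):

import torch

q = torch.randn(1, 32, 1, 128)     # query for the single new token
k = torch.randn(1, 32, 4096, 128)  # cached keys for the full context
v = torch.randn(1, 32, 4096, 128)  # cached values for the full context

# torch.stack requires all tensors to have identical shapes, so this raises:
# RuntimeError: stack expects each tensor to be equal size, but got
# [1, 32, 1, 128] at entry 0 and [1, 32, 4096, 128] at entry 1
qkv = torch.stack([q, k, v], dim=2)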

Steps to reproduce

Enable wandb logging together with a nonzero val_set_size and eval_table_size, then start training via the axolotl CLI (python -m axolotl.cli.train, per the traceback above). I am not certain this is the exact trigger; the most recent commit was supposed to fix this error, but it still occurs for me.

Config yaml

base_model: Open-Orca/Mistral-7B-OpenOrca
base_model_config: Open-Orca/Mistral-7B-OpenOrca
model_type: MistralForCausalLM
tokenizer_type: LlamaTokenizer
is_mistral_derived_model: true

load_in_8bit: false
load_in_4bit: true
strict: false

datasets:
  - path: json
    data_files: data/sharegpt/oa-train-32k-axolotl-sharegpt.json
    type: sharegpt
dataset_prepared_path: data/sharegpt/last_run_prepared
val_set_size: 0.01
output_dir: ./oa-mistral-openorca-ckpt

torch_compile: true

sequence_len: 32768
sample_packing: true
pad_to_sequence_len: true
group_by_length: true

adapter: qlora

lora_r: 16
lora_alpha: 32
lora_dropout: 0.05
lora_target_linear: true
lora_fan_in_fan_out:
lora_target_modules:
  - gate_proj
  - down_proj
  - up_proj
  - q_proj
  - v_proj
  - k_proj
  - o_proj

wandb_project: OA-Mistral-7B-OpenOrca
wandb_entity:
wandb_watch:
wandb_run_id:
wandb_log_model:

gradient_accumulation_steps: 32
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_bnb_8bit
lr_scheduler: cosine
learning_rate: 0.00005

# Augmentation techniques
# NEFT https://arxiv.org/abs/2310.05914, set this to a number (paper default is 5) to add noise to embeddings
# currently only supported on Llama and Mistral
noisy_embedding_alpha: 5

train_on_inputs: false
bf16: true
fp16: false
tf32: false

gradient_checkpointing: true
early_stopping_patience:
resume_from_checkpoint: 
local_rank:
logging_steps: 1
xformers_attention:
flash_attention: true

save_steps: 20
eval_batch_size: 1
warmup_steps: 10
eval_steps: 20
eval_table_size: 5
eval_table_max_new_tokens: 128
debug:
deepspeed:
weight_decay: 0.0
fsdp:
fsdp_config:
special_tokens:
  bos_token: "<s>"
  eos_token: "</s>"
  unk_token: "<unk>"

Possible solution

No response

Which Operating Systems are you using?

  • Linux
  • macOS
  • Windows

Python Version

3.10.12

axolotl branch-commit

main/a045db0

Acknowledgements

  • My issue title is concise, descriptive, and in title casing.
  • I have searched the existing issues to make sure this bug has not been reported yet.
  • I am using the latest version of axolotl.
  • I have provided enough information for the maintainers to reproduce and diagnose the issue.


Most upvoted comments

@dachenlian Also, don’t set group_by_length: true when using sample packing.

@winglian May I ask why?

They are incompatible features, and I expect that using both might lead to unexpected behavior.

@NanoCode012 I think we may need to add a validation: if sample_packing is true and the eval table is enabled, that particular configuration is likely invalid.
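
A minimal sketch of what that validation could look like (hypothetical, not axolotl’s actual implementation; it assumes cfg is the attribute-style config object built from the yaml):

def validate_config(cfg):
    # Hypothetical guard: the eval table path calls model.generate(),
    # which is where the traceback above blows up under sample packing.
    if cfg.sample_packing and cfg.eval_table_size:
        raise ValueError(
            "eval_table_size is not supported with sample_packing; "
            "disable one of them"
        )
    # Hypothetical guard for the incompatibility noted above.
    if cfg.sample_packing and cfg.group_by_length:
        raise ValueError("group_by_length is incompatible with sample_packing")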