AnglE: RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

I am trying to train my model using LLAMA-v2-nli. I was able to do so with the bert-nli model, but when I run the same code with LLAMA I get the following error:

RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

```python
from angle_emb import AnglE, AngleDataTokenizer

angle = AnglE.from_pretrained(
    'NousResearch/Llama-2-7b-hf',
    pretrained_lora_path='SeanLee97/angle-llama-7b-nli-v2',
).cuda()

# `ds` is the dataset with 'train', 'valid' and 'test' splits; loading it is not shown in the post.
train_ds = ds['train'].shuffle().map(AngleDataTokenizer(angle.tokenizer, angle.max_length), num_proc=8)
valid_ds = ds['valid'].map(AngleDataTokenizer(angle.tokenizer, angle.max_length), num_proc=8)
test_ds = ds['test'].map(AngleDataTokenizer(angle.tokenizer, angle.max_length), num_proc=8)

angle.fit(
    train_ds=train_ds,
    valid_ds=test_ds,  # as posted: the test split is passed for validation
    output_dir='ckpts/sts-b',
    batch_size=16,
    epochs=5,
    learning_rate=2e-5,
    save_steps=100,
    eval_steps=1000,
    warmup_steps=0,
    gradient_accumulation_steps=1,
    loss_kwargs={
        'w1': 1.0,
        'w2': 1.0,
        'w3': 1.0,
        'cosine_tau': 20,
        'ibn_tau': 20,
        'angle_tau': 1.0,
    },
    fp16=True,
    logging_steps=100,
)
```

I used the same code for BERT (loading the BERT model instead), and it works without issues.
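For anyone hitting the same message: in PyTorch this error generally means that `backward()` was called on a loss that is not connected to any parameter with `requires_grad=True`. The thread does not confirm the root cause here, so the following is only a diagnostic sketch. `summarize_trainable` is a hypothetical helper, not part of angle_emb, and it assumes you can reach the underlying `torch.nn.Module` of the model.

```python
def summarize_trainable(model):
    """Return (trainable, total) parameter counts.

    `model` is assumed to behave like a torch.nn.Module: `parameters()`
    yields objects exposing `numel()` and `requires_grad`.
    """
    trainable = 0
    total = 0
    for param in model.parameters():
        count = param.numel()
        total += count
        if param.requires_grad:
            trainable += count
    return trainable, total
```

If the trainable count comes back as 0, every weight is frozen, the loss has no `grad_fn`, and `backward()` fails with exactly this message. A non-zero count points to a different cause.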

About this issue

  • State: closed
  • Created 6 months ago
  • Comments: 28 (14 by maintainers)

Most upvoted comments

Hello, I get the error “RuntimeError: No GPU found. A GPU is needed for quantization.” Could you share your environment details, such as the Python, PyTorch, and CUDA versions?

Hi @rickeyhhh, may I know which GPU you are using?

Below is my environment:

GPU: 3090 Ti
CUDA: 12.2

python libraries:

bitsandbytes                  0.41.1
torch                         2.0.1
torchvision                   0.15.2a0
transformers                  4.34.0
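
A list like the one above can be collected with a small script, which makes it easier to compare environments in a thread. This is a minimal sketch using only the standard library. The package names in `PACKAGES` are assumptions about what is installed under those distribution names (for example, `angle-emb` is assumed to be the pip name of the library), and `collect_versions` is a hypothetical helper.

```python
import platform
from importlib import metadata

# Distribution names to report; adjust to match your own setup (assumed names).
PACKAGES = ("torch", "torchvision", "transformers", "bitsandbytes", "angle-emb")


def collect_versions(packages=PACKAGES):
    """Map each package name to its installed version, or None if missing."""
    versions = {"python": platform.python_version()}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in collect_versions().items():
        print(f"{name:<16}{version or 'not installed'}")
```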

Thanks for your reply. I have already solved this problem by matching the torch and CUDA versions.
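
One way to check whether torch and CUDA line up is to ask torch which CUDA version it was built for and whether it can see a GPU. The sketch below assumes nothing beyond a possibly installed `torch`. `cuda_report` is a hypothetical helper, and which exact mismatch caused the error in this thread is not stated.

```python
def cuda_report():
    """Summarize how the installed torch build relates to CUDA."""
    try:
        import torch
    except ImportError:
        # torch is not installed in this environment
        return {"torch": None}

    report = {
        "torch": torch.__version__,
        # CUDA version torch was compiled against; None for CPU-only builds
        "built_for_cuda": torch.version.cuda,
        # False if no usable GPU/driver is visible to this torch build
        "cuda_available": torch.cuda.is_available(),
    }
    if report["cuda_available"]:
        report["device"] = torch.cuda.get_device_name(0)
    return report


if __name__ == "__main__":
    for key, value in cuda_report().items():
        print(f"{key}: {value}")
```

If `built_for_cuda` is `None` or `cuda_available` is `False`, torch cannot use the GPU, and libraries that need a GPU for quantization will likely refuse to load the model.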

I then ran into another problem, “ZeroDivisionError: integer division or modulo by zero”, which I solved by using a lower version of bitsandbytes (0.39.0).