Skip to content

Support for LoRA-tuned models #110

@JTWang2000

Description

@JTWang2000

May I ask whether this library supports LoRA-tuned models?

I followed examples/mnist/compute_influences_manual.py to set up influence computation on a LoRA-tuned model. However, I came across the following error:

Traceback (most recent call last):
  File "/proj/myenv/lib/python3.10/runpy.py", line 196, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/proj/myenv/lib/python3.10/runpy.py", line 86, in _run_code
    exec(code, run_globals)
  File "/proj/get_logra_influence.py", line 216, in <module>
    loss = model(**tmp).loss
  File "/proj/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/proj/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/proj/myenv/lib/python3.10/site-packages/peft/peft_model.py", line 1644, in forward
    return self.base_model(
  File "/proj/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/proj/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/proj/myenv/lib/python3.10/site-packages/peft/tuners/tuners_utils.py", line 197, in forward
    return self.model.forward(*args, **kwargs)
  File "/proj/myenv/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1189, in forward
    outputs = self.model(
  File "/proj/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/proj/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/proj/myenv/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 1000, in forward
    layer_outputs = decoder_layer(
  File "/proj/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/proj/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/proj/myenv/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 729, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/proj/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/proj/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/proj/myenv/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 612, in forward
    query_states = self.q_proj(hidden_states)
  File "/proj/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/proj/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/proj/myenv/lib/python3.10/site-packages/peft/tuners/lora/layer.py", line 572, in forward
    result = self.base_layer(x, *args, **kwargs)
  File "/proj/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/proj/myenv/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1581, in _call_impl
    hook_result = hook(self, args, result)
  File "/proj/myenv/lib/python3.10/site-packages/logix/logging/logger.py", line 221, in _grad_hook_fn
    tensor_hook = outputs.register_hook(_grad_backward_hook_fn)
  File "/proj/myenv/lib/python3.10/site-packages/torch/_tensor.py", line 532, in register_hook

Specifically, the code snippet looks like

# --- Model setup ---------------------------------------------------------
# Load the base model, then wrap it with the trained LoRA adapter weights.
# NOTE(review): device_map='auto' may shard weights across devices — confirm
# that LogIX's module hooks behave correctly under such sharding.
base_model = AutoModelForCausalLM.from_pretrained(config.base_model_name_or_path, device_map='auto')
model = PeftModel.from_pretrained(base_model, model_name_or_path, device_map='auto')

# Initialize LogIX and its scheduler.
# NOTE(review): presumably lora="none" disables LogIX's own low-rank gradient
# projection and save='grad' stores raw per-sample gradients — verify the
# kwarg semantics against the installed LogIX version's docs.
logix = LogIX(project="test", config="./config.yaml")
scheduler = LogIXScheduler(logix, lora="none",  save='grad')

# Gradient & Hessian logging
logix.watch(model)  # registers hooks on the model's modules for logging
id_gen = DataIDGenerator()  # generates a data id per batch for the log store

for epoch in scheduler:
    for batch in tqdm(train_loader):
        # Start
        inputs, targets, attention_mask = batch["input_ids"].to(DEVICE), batch['labels'].to(DEVICE), batch['attention_mask'].to(DEVICE)
        logix.start(data_id=id_gen(inputs))  # begin logging for this batch
        tmp={}
        tmp["input_ids"] = inputs
        tmp["attention_mask"] = attention_mask
        tmp["labels"] = targets
        model.zero_grad()
        # Forward + backward. The traceback above fires inside this forward
        # call: LogIX's _grad_hook_fn calls outputs.register_hook(...) on the
        # output of a LoRA base_layer (see the peft/lora/layer.py frame).
        loss = model(**tmp).loss
        loss.backward()

        # End
        logix.end()  # close logging for this batch
    logix.finalize()  # flush the logged gradients/statistics for this epoch

# Influence Analysis
# Build a dataloader over the training-phase logs saved by the loop above.
log_loader = logix.build_log_dataloader(
    batch_size=64, num_workers=0, flatten=args.flatten
)
merged_test_logs = []
# NOTE(review): presumably this switches LogIX to log raw gradients only for
# the test pass — confirm against the LogIX docs for version 0.1.1.
logix.setup({"grad": ["log"]})
logix.eval()
for batch in test_loader:
    # Start
    inputs, targets, attention_mask = batch["input_ids"].to(DEVICE), batch['labels'].to(DEVICE), batch[
        'attention_mask'].to(DEVICE)
    logix.start(data_id=id_gen(inputs))  # begin logging for this test batch
    model.zero_grad()
    tmp = {}
    tmp["input_ids"] = inputs
    tmp["attention_mask"] = attention_mask
    tmp["labels"] = targets
    loss = model(**tmp).loss
    loss.backward()

    test_log = logix.get_log()  # per-sample gradient log for this batch
    # Deep-copy because get_log()'s buffers may be reused on the next batch.
    # NOTE(review): assumption — confirm whether LogIX reuses log buffers.
    merged_test_logs.append(copy.deepcopy(test_log))
    # End
    # NOTE(review): logix.end() is commented out here, unlike the training
    # loop above which calls it every batch — confirm the asymmetry is
    # intentional.
    # logix.end()
logix.finalize()
merged_test_log = merge_logs(merged_test_logs)
# Influence of every logged training example on the merged test gradients.
result = logix.influence.compute_influence_all(merged_test_log, log_loader, damping=args.damping)

My LogIX version is 0.1.1.

May I ask if the current version supports LoRA-tuned models? If it does, could you tell me whether my code snippet has anything wrong? Thanks.

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions