From 94ca2fad18b9c5eb4ee0c1f095b799fd969e6025 Mon Sep 17 00:00:00 2001 From: pipiPdesu Date: Thu, 19 Sep 2024 21:12:46 +0800 Subject: [PATCH] fix: optim_string_condense --- nanogcg/gcg.py | 10 +++++++--- nanogcg/utils.py | 2 +- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/nanogcg/gcg.py b/nanogcg/gcg.py index 1125769..3478345 100644 --- a/nanogcg/gcg.py +++ b/nanogcg/gcg.py @@ -322,10 +322,14 @@ def init_buffer(self) -> AttackBuffer: if isinstance(config.optim_str_init, str): init_optim_ids = tokenizer(config.optim_str_init, add_special_tokens=False, return_tensors="pt")["input_ids"].to(model.device) + optim_ids_len = init_optim_ids.shape[1] + logger.info(f"Optimization string length is {optim_ids_len}") if config.buffer_size > 1: - init_buffer_ids = tokenizer(INIT_CHARS, add_special_tokens=False, return_tensors="pt")["input_ids"].squeeze().to(model.device) - init_indices = torch.randint(0, init_buffer_ids.shape[0], (config.buffer_size - 1, init_optim_ids.shape[1])) - init_buffer_ids = torch.cat([init_optim_ids, init_buffer_ids[init_indices]], dim=0) + init_indices = torch.randint(0, len(INIT_CHARS), ((config.buffer_size - 1)*optim_ids_len, )) + init_str = [INIT_CHARS[i] for i in init_indices] + init_str = [" ".join(init_str[i:i+optim_ids_len]) for i in range(0, len(init_str), optim_ids_len)] + init_buffer_ids = tokenizer(init_str, add_special_tokens=False, return_tensors="pt")["input_ids"].to(model.device) + init_buffer_ids = torch.cat([init_buffer_ids, init_optim_ids], dim=0) else: init_buffer_ids = init_optim_ids diff --git a/nanogcg/utils.py b/nanogcg/utils.py index 3ce6d7c..9354f68 100644 --- a/nanogcg/utils.py +++ b/nanogcg/utils.py @@ -5,7 +5,7 @@ from torch import Tensor INIT_CHARS = [ - ".", ",", "!", "?", ";", ":", "(", ")", "[", "]", "{", "}", + ",", ";", ":", "(", ")", "[", "]", "{", "}", "@", "#", "$", "%", "&", "*", "w", "x", "y", "z", ]