You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
I have been playing with an adaptation of jailbreaking models using ReFT, but when I chat with the model it tends to start aggressively repeating itself in its response and doesn't terminate before hitting max tokens. Only very rarely does it properly emit an EOS token. I added DeepSeek support, but that shouldn't be the issue, as it was relatively trivial.
Here's my training:
import json
import random
import torch
import transformers
import pyreft
from pyvene.models.deepseek.modelings_intervenable_deepseek import create_deepseek
# Training script: fits LoReFT interventions on a random 250-example sample
# of prompt/answer pairs read from harmful_behaviors.json.
device = "cuda" if torch.cuda.is_available() else "cpu"
config, tokenizer, model = create_deepseek()
# The EOS token doubles as the padding token.
tokenizer.pad_token = tokenizer.eos_token

# Intervention hyperparameters.
layers = [7]
rank = 4
share_weights = False
positions = "f1+l1"
if not share_weights and "+" in positions:
    # presumably one intervention per position group ("f1" and "l1") when
    # weights are not shared -- confirm against pyreft
    layers = layers * 2
first_n, last_n = pyreft.parse_positions(positions)

# One LoReFT intervention on the block output of every entry in `layers`.
representations = []
for layer_idx in layers:
    representations.append({
        "layer": layer_idx,
        "component": "block_output",
        "low_rank_dimension": rank,
        "intervention": pyreft.LoreftIntervention(
            embed_dim=2048, low_rank_dimension=rank, dtype=torch.bfloat16
        ),
    })
reft_config = pyreft.ReftConfig(representations=representations)

reft_model = pyreft.get_reft_model(model, reft_config)
reft_model.set_device(device)
reft_model.print_trainable_parameters()

# Load the dataset and keep a random subset of 250 rows.
with open("harmful_behaviors.json", "r") as data_file:
    jailbreak_data = json.load(data_file)
jailbreak_indices = random.sample(range(len(jailbreak_data)), 250)
jailbreak_data = [jailbreak_data[idx] for idx in jailbreak_indices]

# Each prompt is rendered through the chat template as a single user turn
# followed by the generation prompt; the answers are used verbatim.
prompts = [
    tokenizer.apply_chat_template(
        [{"role": "user", "content": row["prompt"]}],
        tokenize=False,
        add_generation_prompt=True,
    )
    for row in jailbreak_data
]
answers = [row["answer"] for row in jailbreak_data]

# NOTE(review): with nonstop=True the training examples reportedly carry no
# trailing EOS token -- confirm against the pyreft documentation.
jailbreak_data_module = pyreft.make_multiple_position_supervised_data_module(
    tokenizer,
    model,
    prompts,
    answers,
    positions=positions,
    num_interventions=len(reft_config.representations),
    share_weights=share_weights,
    nonstop=True,
)

jailbreak_training_args = transformers.TrainingArguments(
    num_train_epochs=2,
    output_dir="./tmp",
    per_device_train_batch_size=2,
    learning_rate=5e-5,
    report_to=[],
    logging_steps=20,
)

jailbreak_trainer = pyreft.ReftTrainerForCausalLM(
    model=reft_model,
    tokenizer=tokenizer,
    args=jailbreak_training_args,
    **jailbreak_data_module,
)
jailbreak_trainer.train()
And here's my inference
import json
import torch
import transformers
import pyreft
import sys
from pyvene.models.deepseek.modelings_intervenable_deepseek import create_deepseek
# Inference setup: rebuilds the same intervention configuration used for
# training and loads the saved interventions from ./coder-v0.
device = "cuda" if torch.cuda.is_available() else "cpu"
config, tokenizer, model = create_deepseek()
# The EOS token doubles as the padding token.
tokenizer.pad_token = tokenizer.eos_token

# Intervention hyperparameters.
layers = [7]
rank = 4
share_weights = False
positions = "f1+l1"
if not share_weights and "+" in positions:
    # presumably one intervention per position group ("f1" and "l1") when
    # weights are not shared -- confirm against pyreft
    layers = layers * 2
first_n, last_n = pyreft.parse_positions(positions)

# One LoReFT intervention on the block output of every entry in `layers`.
representations = []
for layer_idx in layers:
    representations.append({
        "layer": layer_idx,
        "component": "block_output",
        "low_rank_dimension": rank,
        "intervention": pyreft.LoreftIntervention(
            embed_dim=2048, low_rank_dimension=rank, dtype=torch.bfloat16
        ),
    })
reft_config = pyreft.ReftConfig(representations=representations)

reft_model = pyreft.ReftModel.load("./coder-v0", model)
reft_model.set_device(device)
def get_model_response(prompt):
    """Generate a reply to *prompt* with the ReFT model and return it as text.

    The prompt is echoed, wrapped in the chat template as a single user
    turn, echoed again in templated form, tokenized, and passed to
    ``reft_model.generate`` with interventions applied on the prompt.

    Args:
        prompt: The raw user message.

    Returns:
        The first generated sequence decoded to a string, with special
        tokens kept in the output.
    """
    print(prompt)
    messages = [{"role": "user", "content": prompt}]
    text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
    print(text)

    # The chat template already produced the full text, so the tokenizer is
    # told not to add special tokens of its own.
    inputs = tokenizer(text, return_tensors="pt", add_special_tokens=False)
    inputs = inputs.to(device)

    # Intervention positions are computed from the tokenized prompt length.
    prompt_length = inputs["input_ids"].shape[-1]
    locations = pyreft.get_intervention_locations(
        last_position=prompt_length,
        first_n=first_n,
        last_n=last_n,
        pad_mode="last",
        num_interventions=len(reft_config.representations),
        share_weights=share_weights,
    )
    # Wrap in a batch of one, then swap the first two axes before handing
    # the nested list to generate().
    unit_locations = torch.IntTensor([locations]).permute(1, 0, 2).tolist()

    generation_kwargs = {
        "max_new_tokens": 512,
        "do_sample": True,
        "eos_token_id": tokenizer.eos_token_id,
    }
    _, reft_response = reft_model.generate(
        inputs,
        unit_locations={"sources->base": (None, unit_locations)},
        intervene_on_prompt=True,
        **generation_kwargs,
    )
    return tokenizer.decode(reft_response[0], skip_special_tokens=False)
def main():
    """Run an interactive loop that sends each typed message to the model.

    Reads lines from standard input until the user types ``quit``
    (case-insensitive), presses Ctrl-C, or standard input reaches
    end-of-file. Any other exception raised while handling a message is
    reported and the loop continues with the next prompt.
    """
    print("Enter your message (or 'quit' to exit):")
    while True:
        try:
            user_input = input("> ")
            if user_input.lower() == 'quit':
                break
            response = get_model_response(user_input)
            print("\nModel response:")
            print(response)
            print("\nEnter your message (or 'quit' to exit):")
        except (KeyboardInterrupt, EOFError):
            # EOFError has to be caught before the generic handler: once
            # stdin is exhausted every further input() raises it again, so
            # reporting it and continuing would loop forever.
            print("\nExiting...")
            break
        except Exception as e:
            # Top-level boundary: report the failure and keep the session
            # alive for the next message.
            print(f"\nError: {e}")
            print("\nEnter your message (or 'quit' to exit):")


if __name__ == "__main__":
    main()
The text was updated successfully, but these errors were encountered:
@mfirth-truffle Hey, sorry for this BIG delay.. Have you solved this problem at all? If not, here are some late pointers:
EOS issue: right, the training above actually explicitly removes the EOS token from every input/output pair because of:
jailbreak_data_module = pyreft.make_multiple_position_supervised_data_module(
tokenizer,
model,
[
tokenizer.apply_chat_template([
{"role": "user", "content": row["prompt"]}
], tokenize=False, add_generation_prompt=True)
for row in jailbreak_data
],
[
row["answer"] for row in jailbreak_data
],
positions=positions,
num_interventions=len(reft_config.representations),
share_weights=share_weights,
nonstop=True <============================= this means no EOS token right-padding for any example.
)
repeating tokens: interesting. The hyperparameters you are using are a bit outside the range we explored for cases like jailbreaking or model adaptation:
jailbreak_training_args = transformers.TrainingArguments(
num_train_epochs=2, <======= too small? try much larger
output_dir="./tmp",
per_device_train_batch_size=2,
learning_rate=5e-5, <========= lr can also be too small?
report_to=[],
logging_steps=20
)
what does your training loss look like at the end?
Thanks! Hope these are still helpful.
frankaging
changed the title
Model responses repeat a lot
[P2] Model responses repeat a lot
Dec 19, 2024
I have been playing with an adaptation of jailbreaking models using ReFT, but when I chat with the model it tends to start aggressively repeating itself in its response and doesn't terminate before hitting max tokens. Only very rarely does it properly emit an EOS token. I added DeepSeek support, but that shouldn't be the issue, as it was relatively trivial.
Here's my training:
And here's my inference
The text was updated successfully, but these errors were encountered: