Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
import string
import logging
from fastchat.model import get_conversation_template
from langchain.chains import LLMChain
from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder
from langchain_core.messages import SystemMessage
from langchain.chains.conversation.memory import ConversationBufferMemory

# Metadata used to store our results
STORE_FOLDER = ''
Expand Down Expand Up @@ -60,4 +64,16 @@ def conv_template(template_name, self_id=None, parent_id=None):
template.self_id = self_id
template.parent_id = parent_id

return template
return template

def ChatConversation(system_prompt):
    """Build the chat prompt template used for a conversation.

    Args:
        system_prompt: Text used as the content of the leading system message.

    Returns:
        A ``ChatPromptTemplate`` composed, in order, of:
        a fixed ``SystemMessage`` holding ``system_prompt``, a
        ``MessagesPlaceholder`` bound to the ``chat_history`` variable, and a
        human message rendered from the ``{human_input}`` variable.
    """
    # No memory object is created here: the template only declares the
    # ``chat_history`` slot, and whoever formats the prompt supplies the
    # history messages for it.
    return ChatPromptTemplate.from_messages([
        SystemMessage(content=system_prompt),
        MessagesPlaceholder(variable_name="chat_history"),
        HumanMessagePromptTemplate.from_template("{human_input}"),
    ])

11 changes: 9 additions & 2 deletions conversers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@

import common
from language_models import GPT, PaLM, HuggingFace, APIModelLlama7B, APIModelVicuna13B, GeminiPro
from language_models import GPT, PaLM, HuggingFace, APIModelLlama7B, APIModelVicuna13B, GeminiPro, ChatGroqq
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from config import VICUNA_PATH, LLAMA_PATH, ATTACK_TEMP, TARGET_TEMP, ATTACK_TOP_P, TARGET_TOP_P, MAX_PARALLEL_STREAMS
Expand Down Expand Up @@ -88,11 +88,12 @@ def get_attack(self, convs_list, prompts_list):
for conv, prompt in zip(convs_list, prompts_list):
conv.append_message(conv.roles[0], prompt)
# Get prompts

if "gpt" in self.model_name:
full_prompts.append(conv.to_openai_api_messages())
else:
conv.append_message(conv.roles[1], init_message)
full_prompts.append(conv.get_prompt()[:-len(conv.sep2)])
full_prompts.append(conv.get_prompt())#[:-len(conv.sep2)])

for _ in range(self.max_n_attack_attempts):
# Subset conversations based on indices to regenerate
Expand Down Expand Up @@ -222,6 +223,8 @@ def load_indiv_model(model_name):
lm = APIModelLlama7B(model_name)
elif model_name == 'vicuna-api-model':
lm = APIModelVicuna13B(model_name)
elif model_name == 'chatgroq':
lm = ChatGroqq(model_name)
else:
model = AutoModelForCausalLM.from_pretrained(
model_path,
Expand Down Expand Up @@ -265,6 +268,10 @@ def get_model_path_and_template(model_name):
"path": "gpt-3.5-turbo",
"template":"gpt-3.5-turbo"
},
"chatgroq": {
"path": "chatgroq",
"template":"chatgroq"
},
"vicuna":{
"path": VICUNA_PATH,
"template":"vicuna_v1.1"
Expand Down
Empty file added env
Empty file.
50 changes: 49 additions & 1 deletion evaluators.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@

from system_prompts import get_evaluator_system_prompt_for_judge, get_evaluator_system_prompt_for_on_topic

from language_models import GPT
from language_models import GPT, ChatGroqq

def load_evaluator(args):
if "gpt" in args.evaluator_model:
return GPTEvaluator(args)
elif "chatgroq" in args.evaluator_model:
return chatgroqEvaluator(args)
elif args.evaluator_model == "no-evaluator":
return NoEvaluator(args)
else:
Expand Down Expand Up @@ -122,3 +124,49 @@ def on_topic_score(self, attack_prompt_list, original_prompt):
class OpenSourceEvaluator(EvaluatorBase):
    """Stub evaluator; constructing it always fails with NotImplementedError."""

    def __init__(self, evaluator_model, evaluator_tokenizer, args):
        """Reject construction: none of the arguments are used."""
        raise NotImplementedError

class chatgroqEvaluator(EvaluatorBase):
    """Evaluator that sends its judge and on-topic queries to a ``ChatGroqq`` model."""

    def __init__(self, args):
        """Run the base-class setup with ``args`` and create the ``ChatGroqq`` backend."""
        super().__init__(args)
        self.evaluator_model = ChatGroqq(model_name=self.evaluator_name)

    def create_conv(self, full_prompt, system_prompt=None):
        """Return OpenAI-API-format messages for a single evaluator query.

        Args:
            full_prompt: Text appended as the first-role message.
            system_prompt: System message to use; ``self.system_prompt`` is
                used when this is ``None``.
        """
        chosen_system = self.system_prompt if system_prompt is None else system_prompt

        template = get_conversation_template(self.evaluator_name)
        template.set_system_message(chosen_system)
        first_role = template.roles[0]
        template.append_message(first_role, full_prompt)

        return template.to_openai_api_messages()

    def judge_score(self, attack_prompt_list, target_response_list):
        """Query the evaluator once per (attack prompt, target response) pair.

        Returns:
            A list with one ``process_output_judge_score`` result per pair,
            in input order.
        """
        conversations = []
        for attack_prompt, target_response in zip(attack_prompt_list, target_response_list):
            evaluator_prompt = self.get_evaluator_prompt(attack_prompt, target_response)
            conversations.append(self.create_conv(evaluator_prompt))

        print(f'\tQuerying evaluator with {len(attack_prompt_list)} prompts (to evaluate judge scores)', flush=True)

        replies = self.evaluator_model.batched_generate(
            conversations,
            max_n_tokens=self.max_n_tokens,
            temperature=self.temperature,
        )

        return [self.process_output_judge_score(reply) for reply in replies]

    def on_topic_score(self, attack_prompt_list, original_prompt):
        """Query the evaluator once per attack prompt using the on-topic system prompt.

        ``original_prompt`` is accepted but not read by this method.

        Returns:
            A list with one ``process_output_on_topic_score`` result per
            attack prompt, in input order.
        """
        conversations = []
        for attack_prompt in attack_prompt_list:
            evaluator_prompt = self.get_evaluator_prompt_on_topic(attack_prompt)
            conversations.append(
                self.create_conv(evaluator_prompt, system_prompt=self.system_prompt_on_topic)
            )

        print(f'\tQuerying evaluator with {len(attack_prompt_list)} prompts (to evaluate on-topic scores)', flush=True)

        replies = self.evaluator_model.batched_generate(
            conversations,
            max_n_tokens=self.max_n_tokens,
            temperature=self.temperature,
        )

        return [self.process_output_on_topic_score(reply) for reply in replies]
57 changes: 56 additions & 1 deletion language_models.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,18 @@
import openai
import anthropic
import os
from dotenv import load_dotenv
import time
import torch
import gc
from typing import Dict, List
import google.generativeai as genai
import urllib3
from copy import deepcopy

from langchain_groq import ChatGroq
from config import LLAMA_API_LINK, VICUNA_API_LINK

load_dotenv(override=True)

class LanguageModel():
def __init__(self, model_name):
Expand Down Expand Up @@ -246,6 +248,58 @@ def batched_generate(self,
top_p: float = 1.0,):
return [self.generate(conv, max_n_tokens, temperature, top_p) for conv in convs_list]

class ChatGroqq():
    """Wrapper around LangChain's ``ChatGroq`` that retries failed calls.

    Exposes the ``generate`` / ``batched_generate`` pair used by the other
    API-backed model classes in this module.
    """
    API_RETRY_SLEEP = 10          # seconds slept after a failed attempt
    API_ERROR_OUTPUT = "$ERROR$"  # returned when every attempt fails
    API_QUERY_SLEEP = 0.5         # extra seconds slept before the next attempt
    API_MAX_RETRY = 20            # maximum number of attempts per conversation
    API_TIMEOUT = 20              # not read anywhere in this class
    API_KEY = os.getenv("GROQCLOUD_API_KEY")

    def __init__(self, model_name) -> None:
        """Store ``model_name`` and create the underlying ``ChatGroq`` client.

        ``model_name`` is kept only as a label; the Groq model actually
        queried is the fixed ``mixtral-8x7b-32768``.
        """
        self.model_name = model_name
        self.model = ChatGroq(
            groq_api_key=self.API_KEY,
            model_name="mixtral-8x7b-32768",
        )

    def generate(self, conv: List[Dict],
                 max_n_tokens: int,
                 temperature: float,
                 top_p: float):
        '''
        Args:
            conv: List of dictionaries, OpenAI API format
            max_n_tokens: int, max number of tokens to generate
            temperature: float, temperature for sampling
            top_p: float, top p for sampling
        Returns:
            str: generated response, or ``API_ERROR_OUTPUT`` if every
            attempt raised
        '''
        output = self.API_ERROR_OUTPUT

        for _ in range(self.API_MAX_RETRY):
            try:
                # Forward the sampling settings with the request; without
                # them the model runs with its own defaults and the caller's
                # temperature / top_p / length limit are silently ignored.
                response = self.model.invoke(
                    conv,
                    max_tokens=max_n_tokens,
                    temperature=temperature,
                    top_p=top_p,
                )
                output = response.content
                break
            except Exception as e:
                # Best-effort retry: report the failure, back off, try again.
                print(type(e), e)
                time.sleep(self.API_RETRY_SLEEP)

            time.sleep(self.API_QUERY_SLEEP)
        return output

    def batched_generate(self,
                         convs_list: List[List[Dict]],
                         max_n_tokens: int,
                         temperature: float,
                         top_p: float = 1.0,):
        """Run ``generate`` sequentially on each conversation and return the outputs in order."""
        return [self.generate(conv, max_n_tokens, temperature, top_p) for conv in convs_list]

class PaLM():
API_RETRY_SLEEP = 10
API_ERROR_OUTPUT = "$ERROR$"
Expand Down Expand Up @@ -318,6 +372,7 @@ def __init__(self, model_name) -> None:
self.model_name = model_name
genai.configure(api_key=self.API_KEY)


def generate(self, conv: List,
max_n_tokens: int,
temperature: float,
Expand Down
45 changes: 33 additions & 12 deletions main_TAP.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,8 @@ def main(args):
evaluator_llm = load_evaluator(args)
print('Done loading evaluator!', flush=True)

logger = WandBLogger(args, system_prompt)
print('Done logging!', flush=True)
#logger = WandBLogger(args, system_prompt)
#print('Done logging!', flush=True)

# Initialize conversations
batchsize = args.n_streams
Expand Down Expand Up @@ -189,11 +189,26 @@ def main(args):
############################################################
target_response_list = target_llm.get_response(adv_prompt_list)
print("Finished getting target responses.")

# Get judge-scores (i.e., likelihood of jailbreak) from Evaluator
judge_scores = evaluator_llm.judge_score(adv_prompt_list, target_response_list)
print("Finished getting judge scores from evaluator.")


print("=========================")
print("on_topic_scores", on_topic_scores)
print("=========================")
print("judge_scores", judge_scores)
print("=========================")
print("adv_prompt_list", adv_prompt_list)
print("=========================")
print("improv_list", improv_list)
print("=========================")
print("convs_list", convs_list)
print("=========================")
print("target_response_list", target_response_list)
print("=========================")
print("extracted_attack_list", extracted_attack_list)
print("=========================")
############################################################
# PRUNE: PHASE 2
############################################################
Expand All @@ -216,21 +231,24 @@ def main(args):
attack_params=attack_params)

# WandB log values
'''
logger.log(iteration,
extracted_attack_list,
#extracted_attack_list,
target_response_list,
judge_scores,
on_topic_scores,
conv_ids=[c.self_id for c in convs_list],
parent_conv_ids=[c.parent_id for c in convs_list])
'''

# Truncate conversation to avoid context length issues
for conv in convs_list:
# Note that this does not delete the conv.role (i.e., the system prompt)
conv.messages = conv.messages[-2*(args.keep_last_n):]

# Early stopping criterion
if any([score == 10 for score in judge_scores]):

print("Found a jailbreak. Exiting.")
break

Expand All @@ -246,7 +264,7 @@ def main(args):
for target_response, score in zip(target_response_list, judge_scores)
]

logger.finish()
#logger.finish()


if __name__ == '__main__':
Expand All @@ -256,15 +274,16 @@ def main(args):
########### Attack model parameters ##########
parser.add_argument(
"--attack-model",
default = "vicuna",
default = "chatgroq",
help = "Name of attacking model.",
choices=["vicuna",
"vicuna-api-model",
"gpt-3.5-turbo",
"gpt-4",
"gpt-4-turbo",
"gpt-4-1106-preview", # This is same as gpt-4-turbo
'llama-2-api-model']
'llama-2-api-model',
"chatgroq"]
)
parser.add_argument(
"--attack-max-n-tokens",
Expand All @@ -283,7 +302,7 @@ def main(args):
########### Target model parameters ##########
parser.add_argument(
"--target-model",
default = "vicuna",
default = "chatgroq",
help = "Name of target model.",
choices=["llama-2",
'llama-2-api-model',
Expand All @@ -295,6 +314,7 @@ def main(args):
'gpt-4-1106-preview', # This is same as gpt-4-turbo
"palm-2",
"gemini-pro",
"chatgroq",
]
)
parser.add_argument(
Expand All @@ -308,12 +328,13 @@ def main(args):
############ Evaluator model parameters ##########
parser.add_argument(
"--evaluator-model",
default="gpt-3.5-turbo",
default="chatgroq",
help="Name of evaluator model.",
choices=["gpt-3.5-turbo",
"gpt-4",
"gpt-4-turbo",
"gpt-4-1106-preview",
"gpt-4-1106-preview",
"chatgroq",
"no-evaluator"]
)
parser.add_argument(
Expand Down Expand Up @@ -409,4 +430,4 @@ def main(args):

args = parser.parse_args()

main(args)
main(args)
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ wandb==0.16.1
accelerate==0.25.0
pyarrow
fastparquet
langchain