From ab8874f5922a9a568cb1f4ca9fe8bd155449ba1c Mon Sep 17 00:00:00 2001 From: Claudio Stanzione Date: Mon, 23 Sep 2024 17:12:06 +0200 Subject: [PATCH 1/2] added ChatGroq model --- conversers.py | 11 +++++++-- env | 0 evaluators.py | 50 +++++++++++++++++++++++++++++++++++++++- language_models.py | 57 +++++++++++++++++++++++++++++++++++++++++++++- main_TAP.py | 44 ++++++++++++++++++++++++++--------- 5 files changed, 147 insertions(+), 15 deletions(-) create mode 100644 env diff --git a/conversers.py b/conversers.py index c90a6c4..4fb0ab6 100644 --- a/conversers.py +++ b/conversers.py @@ -1,6 +1,6 @@ import common -from language_models import GPT, PaLM, HuggingFace, APIModelLlama7B, APIModelVicuna13B, GeminiPro +from language_models import GPT, PaLM, HuggingFace, APIModelLlama7B, APIModelVicuna13B, GeminiPro, ChatGroqq import torch from transformers import AutoModelForCausalLM, AutoTokenizer from config import VICUNA_PATH, LLAMA_PATH, ATTACK_TEMP, TARGET_TEMP, ATTACK_TOP_P, TARGET_TOP_P, MAX_PARALLEL_STREAMS @@ -88,11 +88,12 @@ def get_attack(self, convs_list, prompts_list): for conv, prompt in zip(convs_list, prompts_list): conv.append_message(conv.roles[0], prompt) # Get prompts + if "gpt" in self.model_name: full_prompts.append(conv.to_openai_api_messages()) else: conv.append_message(conv.roles[1], init_message) - full_prompts.append(conv.get_prompt()[:-len(conv.sep2)]) + full_prompts.append(conv.get_prompt())#[:-len(conv.sep2)]) for _ in range(self.max_n_attack_attempts): # Subset conversations based on indices to regenerate @@ -222,6 +223,8 @@ def load_indiv_model(model_name): lm = APIModelLlama7B(model_name) elif model_name == 'vicuna-api-model': lm = APIModelVicuna13B(model_name) + elif model_name == 'chatgroq': + lm = ChatGroqq(model_name) else: model = AutoModelForCausalLM.from_pretrained( model_path, @@ -265,6 +268,10 @@ def get_model_path_and_template(model_name): "path": "gpt-3.5-turbo", "template":"gpt-3.5-turbo" }, + "chatgroq": { + "path": "chatgroq", + "template":"chatgroq" + }, "vicuna":{ "path": VICUNA_PATH, "template":"vicuna_v1.1" diff --git a/env b/env new file mode 100644 index 0000000..e69de29 diff --git a/evaluators.py b/evaluators.py index 4e9cb52..0a5bb44 100644 --- a/evaluators.py +++ b/evaluators.py @@ -6,11 +6,13 @@ from system_prompts import get_evaluator_system_prompt_for_judge, get_evaluator_system_prompt_for_on_topic -from language_models import GPT +from language_models import GPT, ChatGroqq def load_evaluator(args): if "gpt" in args.evaluator_model: return GPTEvaluator(args) + elif "chatgroq" in args.evaluator_model: + return chatgroqEvaluator(args) elif args.evaluator_model == "no-evaluator": return NoEvaluator(args) else: @@ -122,3 +124,49 @@ def on_topic_score(self, attack_prompt_list, original_prompt): class OpenSourceEvaluator(EvaluatorBase): def __init__(self, evaluator_model, evaluator_tokenizer, args): raise NotImplementedError + +class chatgroqEvaluator(EvaluatorBase): + def __init__(self, args): + super(chatgroqEvaluator, self).__init__(args) + self.evaluator_model = ChatGroqq(model_name = self.evaluator_name) + + def create_conv(self, full_prompt, system_prompt=None): + if system_prompt is None: + system_prompt = self.system_prompt + + conv = get_conversation_template(self.evaluator_name) + conv.set_system_message(system_prompt) + conv.append_message(conv.roles[0], full_prompt) + + return conv.to_openai_api_messages() + + def judge_score(self, attack_prompt_list, target_response_list): + convs_list = [ + self.create_conv(self.get_evaluator_prompt(prompt, response)) + for prompt, response in zip(attack_prompt_list, target_response_list) + ] + + print(f'\tQuerying evaluator with {len(attack_prompt_list)} prompts (to evaluate judge scores)', flush=True) + + raw_outputs = self.evaluator_model.batched_generate(convs_list, + max_n_tokens = self.max_n_tokens, + temperature = self.temperature) + + outputs = [self.process_output_judge_score(raw_output) for raw_output in raw_outputs] + return outputs + + def on_topic_score(self, attack_prompt_list, original_prompt): + + + convs_list = [ + self.create_conv(self.get_evaluator_prompt_on_topic(prompt), system_prompt=self.system_prompt_on_topic) + for prompt in attack_prompt_list + ] + + print(f'\tQuerying evaluator with {len(attack_prompt_list)} prompts (to evaluate on-topic scores)', flush=True) + + raw_outputs = self.evaluator_model.batched_generate(convs_list, + max_n_tokens = self.max_n_tokens, + temperature = self.temperature) + outputs = [self.process_output_on_topic_score(raw_output) for raw_output in raw_outputs] + return outputs \ No newline at end of file diff --git a/language_models.py b/language_models.py index f640021..306747e 100644 --- a/language_models.py +++ b/language_models.py @@ -1,6 +1,7 @@ import openai import anthropic import os +from dotenv import load_dotenv import time import torch import gc @@ -8,9 +9,10 @@ import google.generativeai as genai import urllib3 from copy import deepcopy - +from langchain_groq import ChatGroq from config import LLAMA_API_LINK, VICUNA_API_LINK +load_dotenv(override=True) class LanguageModel(): def __init__(self, model_name): @@ -246,6 +248,58 @@ def batched_generate(self, top_p: float = 1.0,): return [self.generate(conv, max_n_tokens, temperature, top_p) for conv in convs_list] +class ChatGroqq(): + API_RETRY_SLEEP = 10 + API_ERROR_OUTPUT = "$ERROR$" + API_QUERY_SLEEP = 0.5 + API_MAX_RETRY = 20 + API_TIMEOUT = 20 + API_KEY = os.getenv("GROQCLOUD_API_KEY") + + def __init__(self, model_name) -> None: + self.model_name = model_name + self.model= ChatGroq( + groq_api_key=self.API_KEY, + model_name="mixtral-8x7b-32768" + ) + + + def generate(self, conv: List[Dict], + max_n_tokens: int, + temperature: float, + top_p: float): + ''' + Args: + conv: List of dictionaries, OpenAI API format + max_n_tokens: int, max number of tokens to generate + temperature: float, temperature for sampling + top_p: float, top p for sampling + Returns: + str: generated response + ''' + output = self.API_ERROR_OUTPUT + + + for _ in range(self.API_MAX_RETRY): + try: + + response = self.model.invoke(conv) + output = response.content + break + except Exception as e: + print(type(e), e) + time.sleep(self.API_RETRY_SLEEP) + + time.sleep(self.API_QUERY_SLEEP) + return output + + def batched_generate(self, + convs_list: List[List[Dict]], + max_n_tokens: int, + temperature: float, + top_p: float = 1.0,): + return [self.generate(conv, max_n_tokens, temperature, top_p) for conv in convs_list] + class PaLM(): API_RETRY_SLEEP = 10 API_ERROR_OUTPUT = "$ERROR$" @@ -318,6 +372,7 @@ def __init__(self, model_name) -> None: self.model_name = model_name genai.configure(api_key=self.API_KEY) + def generate(self, conv: List, max_n_tokens: int, temperature: float, diff --git a/main_TAP.py b/main_TAP.py index 00be593..68429a4 100644 --- a/main_TAP.py +++ b/main_TAP.py @@ -108,8 +108,8 @@ def main(args): evaluator_llm = load_evaluator(args) print('Done loading evaluator!', flush=True) - logger = WandBLogger(args, system_prompt) - print('Done logging!', flush=True) + #logger = WandBLogger(args, system_prompt) + #print('Done logging!', flush=True) # Initialize conversations batchsize = args.n_streams @@ -194,6 +194,22 @@ def main(args): judge_scores = evaluator_llm.judge_score(adv_prompt_list, target_response_list) print("Finished getting judge scores from evaluator.") + + print("=========================") + print("on_topic_scores", on_topic_scores) + print("=========================") + print("judge_scores", judge_scores) + print("=========================") + print("adv_prompt_list", adv_prompt_list) + print("=========================") + print("improv_list", improv_list) + print("=========================") + print("convs_list", convs_list) + print("=========================") + print("target_response_list", target_response_list) + print("=========================") + print("extracted_attack_list", extracted_attack_list) + print("=========================") ############################################################ # PRUNE: PHASE 2 ############################################################ @@ -216,21 +232,24 @@ def main(args): attack_params=attack_params) # WandB log values + ''' logger.log(iteration, - extracted_attack_list, + #extracted_attack_list, target_response_list, judge_scores, on_topic_scores, conv_ids=[c.self_id for c in convs_list], parent_conv_ids=[c.parent_id for c in convs_list]) + ''' # Truncate conversation to avoid context length issues for conv in convs_list: # Note that this does not delete the conv.role (i.e., the system prompt) conv.messages = conv.messages[-2*(args.keep_last_n):] - + # Early stopping criterion if any([score == 10 for score in judge_scores]): + print("Found a jailbreak. Exiting.") break @@ -246,7 +265,7 @@ def main(args): for target_response, score in zip(target_response_list, judge_scores) ] - logger.finish() + #logger.finish() if __name__ == '__main__': @@ -256,7 +275,7 @@ def main(args): ########### Attack model parameters ########## parser.add_argument( "--attack-model", - default = "vicuna", + default = "chatgroq", help = "Name of attacking model.", choices=["vicuna", "vicuna-api-model", @@ -264,7 +283,8 @@ def main(args): "gpt-4", "gpt-4-turbo", "gpt-4-1106-preview", # This is same as gpt-4-turbo - 'llama-2-api-model'] + 'llama-2-api-model', + "chatgroq"] ) parser.add_argument( "--attack-max-n-tokens", @@ -283,7 +303,7 @@ def main(args): ########### Target model parameters ########## parser.add_argument( "--target-model", - default = "vicuna", + default = "chatgroq", help = "Name of target model.", choices=["llama-2", 'llama-2-api-model', @@ -295,6 +315,7 @@ def main(args): 'gpt-4-1106-preview', # This is same as gpt-4-turbo "palm-2", "gemini-pro", + "chatgroq", ] ) parser.add_argument( @@ -308,12 +329,13 @@ def main(args): ############ Evaluator model parameters ########## parser.add_argument( "--evaluator-model", - default="gpt-3.5-turbo", + default="chatgroq", help="Name of evaluator model.", choices=["gpt-3.5-turbo", "gpt-4", "gpt-4-turbo", - "gpt-4-1106-preview", + "gpt-4-1106-preview", + "chatgroq", "no-evaluator"] ) parser.add_argument( @@ -409,4 +431,4 @@ def main(args): args = parser.parse_args() - main(args) + main(args) \ No newline at end of file From 4f461e77573f11a4347a82c71ddcda370db11b90 Mon Sep 17 00:00:00 2001 From: Claudio Stanzione Date: Tue, 24 Sep 2024 14:59:54 +0200 Subject: [PATCH 2/2] update --- common.py | 18 +++++++++++++++++- main_TAP.py | 1 - requirements.txt | 1 + 3 files changed, 18 insertions(+), 2 deletions(-) diff --git a/common.py b/common.py index 05652f8..0559c91 100644 --- a/common.py +++ b/common.py @@ -3,6 +3,10 @@ import string import logging from fastchat.model import get_conversation_template +from langchain.chains import LLMChain +from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate, MessagesPlaceholder +from langchain_core.messages import SystemMessage +from langchain.chains.conversation.memory import ConversationBufferMemory # Metadata used to store our results STORE_FOLDER = '' @@ -60,4 +64,16 @@ def conv_template(template_name, self_id=None, parent_id=None): template.self_id = self_id template.parent_id = parent_id - return template \ No newline at end of file + return template + +def ChatConversation(system_prompt): + memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True) + + prompt = ChatPromptTemplate.from_messages([ + SystemMessage(content=system_prompt), + MessagesPlaceholder(variable_name="chat_history"), + HumanMessagePromptTemplate.from_template("{human_input}") + ]) + + return prompt + diff --git a/main_TAP.py b/main_TAP.py index 68429a4..2a69aa4 100644 --- a/main_TAP.py +++ b/main_TAP.py @@ -189,7 +189,6 @@ def main(args): ############################################################ target_response_list = target_llm.get_response(adv_prompt_list) print("Finished getting target responses.") - # Get judge-scores (i.e., likelihood of jailbreak) from Evaluator judge_scores = evaluator_llm.judge_score(adv_prompt_list, target_response_list) print("Finished getting judge scores from evaluator.") diff --git a/requirements.txt b/requirements.txt index 0a4f2bc..1e14729 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,3 +13,4 @@ wandb==0.16.1 accelerate==0.25.0 pyarrow fastparquet +langchain \ No newline at end of file