Description
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers_cfg.grammar_utils import IncrementalGrammarConstraint
from transformers_cfg.generation.logits_process import GrammarConstrainedLogitsProcessor

if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    model_id = "mistralai/Mistral-7B-v0.1"

    # Load model and tokenizer
    tokenizer = AutoTokenizer.from_pretrained(model_id)
    tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_id).to(device)  # Load model onto the chosen device
    model.generation_config.pad_token_id = model.generation_config.eos_token_id
grammar_str = """
# Grammar for subset of JSON
# String doesn't support unicode and escape yet
# If you don't need to generate unicode and escape, you can use this grammar
# We are working to support unicode and escape
root ::= object
object ::= "{" ws ( string ":" ws value ("," ws string ":" ws value)* )? "}"
value ::= object | array | string | number | ("true" | "false" | "null") ws
array ::= "[" ws ( value ("," ws value)* )? "]" ws
string ::= "\"" [ \t!#-\[\]-~]* "\"" ws
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? ws
ws ::= ([ \t\n] ws)?
"""
    # Create grammar constraint and logits processor
    grammar = IncrementalGrammarConstraint(grammar_str, "root", tokenizer)
    grammar_processor = GrammarConstrainedLogitsProcessor(grammar)

    # Generate
    prefix1 = "This is a valid json string for http request:"
    prefix2 = "This is a valid json string for shopping cart:"
    input_ids = tokenizer(
        [prefix1, prefix2], add_special_tokens=False, return_tensors="pt", padding=True
    )["input_ids"].to(device)  # Move input_ids to the same device as the model

    output = model.generate(
        input_ids,
        do_sample=False,
        max_new_tokens=60,
        logits_processor=[grammar_processor],
        repetition_penalty=1.1,
        num_return_sequences=1,
    )

    # Decode output
    generations = tokenizer.batch_decode(output, skip_special_tokens=True)
    print(generations)
"""
'This is a valid json string for http request:{ "request": { "method": "GET", "headers": [], "content": "Content","type": "application" }}
'This is a valid json string for shopping cart:This is a valid json string for shopping cart:{ "name": "MyCart", "price": 0, "value": 1 }
"""Reactions are currently unavailable