
API Reference

Type Aliases

Messages

Messages = str | list[ChatMessage]

The primary message type. Either a plain string (completion mode) or a list of chat messages (chat mode).
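The two forms can be illustrated with plain data (a sketch; the chat field names follow OpenAI's chat format):

```python
# Completion mode: the prompt is a plain string.
completion_prompt: str = "The capital of France is"

# Chat mode: the prompt is a list of role/content message dicts.
chat_prompt: list[dict] = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is the capital of France?"},
]
```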

ChatMessage

ChatMessage = ChatCompletionMessageParam  # from openai.types.chat

OpenAI's chat message type with role, content, and optional tool_calls / tool_call_id fields.

Info

Info = dict[str, Any]

Arbitrary metadata dictionary from dataset rows.

SamplingArgs

SamplingArgs = dict[str, Any]

Generation parameters passed to the inference server (e.g., temperature, top_p, max_tokens).

RewardFunc

IndividualRewardFunc = Callable[..., float | Awaitable[float]]
GroupRewardFunc = Callable[..., list[float] | Awaitable[list[float]]]
RewardFunc = IndividualRewardFunc | GroupRewardFunc

Individual reward functions operate on single rollouts. Group reward functions operate on all rollouts for an example together (useful for relative scoring).
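A minimal sketch of the two shapes (toy logic; real reward functions receive the keyword arguments documented under Rubric below):

```python
def exact_match(completion: str, answer: str = "", **kwargs) -> float:
    """Individual: scores one rollout in isolation."""
    return 1.0 if completion.strip() == answer.strip() else 0.0

def relative_score(completions: list[str], answers: list[str], **kwargs) -> list[float]:
    """Group: scores all rollouts for one example together (relative scoring)."""
    raw = [1.0 if c.strip() == a.strip() else 0.0
           for c, a in zip(completions, answers)]
    mean = sum(raw) / len(raw)
    # Each rollout is scored relative to the group mean.
    return [r - mean for r in raw]
```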

ClientType

ClientType = Literal[
    "openai_completions",
    "openai_chat_completions",
    "openai_chat_completions_token",
    "anthropic_messages",
]

Selects which Client implementation to use. Set via ClientConfig.client_type.


Data Types

State

class State(dict):
    INPUT_FIELDS = ["prompt", "answer", "task", "info", "example_id"]

A dict subclass that tracks rollout information. Accessing keys in INPUT_FIELDS automatically forwards to the nested input object.

Fields set during initialization:

| Field | Type | Description |
| --- | --- | --- |
| input | RolloutInput | Nested input data |
| client | Client | Client instance |
| model | str | Model name |
| sampling_args | SamplingArgs \| None | Generation parameters |
| is_completed | bool | Whether rollout has ended |
| is_truncated | bool | Whether generation was truncated |
| tool_defs | list[Tool] \| None | Available tool definitions |
| trajectory | list[TrajectoryStep] | Multi-turn trajectory |
| trajectory_id | str | UUID for this rollout |
| timing | RolloutTiming | Timing information |

Fields set after scoring:

| Field | Type | Description |
| --- | --- | --- |
| completion | Messages \| None | Final completion |
| reward | float \| None | Final reward |
| advantage | float \| None | Advantage over group mean |
| metrics | dict[str, float] \| None | Per-function metrics |
| stop_condition | str \| None | Name of triggered stop condition |
| error | Error \| None | Error if rollout failed |
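The INPUT_FIELDS forwarding described above can be sketched as a dict subclass (illustrative only, not the actual implementation):

```python
class ForwardingState(dict):
    """Sketch of State's key forwarding: reads of INPUT_FIELDS
    fall through to the nested ``input`` mapping."""

    INPUT_FIELDS = ["prompt", "answer", "task", "info", "example_id"]

    def __getitem__(self, key):
        # Input fields not set directly on the state are read from input.
        if key in self.INPUT_FIELDS and key not in self.keys():
            return dict.__getitem__(self, "input")[key]
        return dict.__getitem__(self, key)

state = ForwardingState(input={"prompt": "2+2=?", "answer": "4",
                               "task": "math", "info": {}, "example_id": 0})
```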

RolloutInput

class RolloutInput(TypedDict):
    prompt: Messages        # Required
    example_id: int         # Required
    task: str               # Required
    answer: str             # Optional
    info: Info              # Optional

RolloutOutput

class RolloutOutput(dict):
    # Required fields
    example_id: int
    task: str
    prompt: Messages | None
    completion: Messages | None
    reward: float
    timing: RolloutTiming
    is_completed: bool
    is_truncated: bool
    metrics: dict[str, float]
    # Optional fields
    answer: str
    info: Info
    error: str | None
    stop_condition: str | None
    trajectory: list[TrajectoryStep]
    tool_defs: list[Tool] | None

Serialized output from a rollout. This is a dict subclass that provides typed access to known fields while supporting arbitrary additional fields from state_columns. All values must be JSON-serializable. Used in GenerateOutputs and for saving results to disk.

TrajectoryStep

class TrajectoryStep(TypedDict):
    prompt: Messages
    completion: Messages
    response: Response
    tokens: TrajectoryStepTokens | None
    reward: float | None
    advantage: float | None
    is_truncated: bool
    trajectory_id: str
    extras: dict[str, Any]

A single turn in a multi-turn rollout.

TrajectoryStepTokens

class TrajectoryStepTokens(TypedDict):
    prompt_ids: list[int]
    prompt_mask: list[int]
    completion_ids: list[int]
    completion_mask: list[int]
    completion_logprobs: list[float]
    overlong_prompt: bool
    is_truncated: bool
    routed_experts: list[list[list[int]]] | None  # [seq_len, layers, topk] to enable router replay

Token-level data for training.
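One common use of these fields, sketched below under the usual convention that mask value 1 marks trainable completion tokens (that convention is an assumption, not stated above):

```python
def assemble_sequence(tokens: dict) -> tuple[list[int], list[int]]:
    """Concatenate prompt and completion ids with their masks.

    Assumes the usual convention: positions with mask 0 (the prompt,
    plus any masked-out completion tokens) are excluded from the loss.
    """
    ids = tokens["prompt_ids"] + tokens["completion_ids"]
    mask = tokens["prompt_mask"] + tokens["completion_mask"]
    return ids, mask

# Toy token data in the TrajectoryStepTokens shape.
toks = {"prompt_ids": [1, 2], "prompt_mask": [0, 0],
        "completion_ids": [3, 4], "completion_mask": [1, 1]}
```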

RolloutTiming

class RolloutTiming(TypedDict, total=False):
    start_time: float
    generation_ms: float
    scoring_ms: float
    total_ms: float

GenerateOutputs

class GenerateOutputs(TypedDict):
    outputs: list[RolloutOutput]
    metadata: GenerateMetadata

Output from Environment.generate(). Contains a list of RolloutOutput objects (one per rollout) and generation metadata. Each RolloutOutput is a serialized, JSON-compatible dict containing the rollout's prompt, completion, answer, reward, metrics, timing, and other per-rollout data.
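Because each RolloutOutput is a plain JSON-compatible dict, post-processing is straightforward; for example (toy data with only the fields used):

```python
def summarize(outputs: list[dict]) -> dict:
    """Compute mean reward and completion rate over rollout outputs."""
    n = len(outputs)
    return {
        "avg_reward": sum(o["reward"] for o in outputs) / n,
        "completed_frac": sum(o["is_completed"] for o in outputs) / n,
    }

outputs = [
    {"reward": 1.0, "is_completed": True},
    {"reward": 0.0, "is_completed": True},
]
```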

GenerateMetadata

class VersionInfo(TypedDict):
    vf_version: str
    vf_commit: str | None
    env_version: str | None
    env_commit: str | None

class GenerateMetadata(TypedDict):
    env_id: str
    env_args: dict
    model: str
    base_url: str
    num_examples: int
    rollouts_per_example: int
    sampling_args: SamplingArgs
    date: str
    time_ms: float
    avg_reward: float
    avg_metrics: dict[str, float]
    version_info: VersionInfo
    state_columns: list[str]
    path_to_save: Path
    tools: list[Tool] | None

base_url is always serialized as a string. For multi-endpoint runs (e.g., using ClientConfig.endpoint_configs), it is stored as a comma-separated list of URLs.

version_info captures the verifiers framework version/commit and the environment package version/commit at generation time. Populated automatically by GenerateOutputsBuilder.

RolloutScore / RolloutScores

class RolloutScore(TypedDict):
    reward: float
    metrics: dict[str, float]

class RolloutScores(TypedDict):
    reward: list[float]
    metrics: dict[str, list[float]]

Classes

Environment Classes

Environment

class Environment(ABC):
    def __init__(
        self,
        dataset: Dataset | None = None,
        eval_dataset: Dataset | None = None,
        system_prompt: str | None = None,
        few_shot: list[ChatMessage] | None = None,
        parser: Parser | None = None,
        rubric: Rubric | None = None,
        sampling_args: SamplingArgs | None = None,
        message_type: MessageType = "chat",
        max_workers: int = 512,
        env_id: str | None = None,
        env_args: dict | None = None,
        max_seq_len: int | None = None,
        **kwargs,
    ): ...

Abstract base class for all environments.

Generation methods:

| Method | Returns | Description |
| --- | --- | --- |
| generate(inputs, client, model, ...) | GenerateOutputs | Run rollouts asynchronously; client accepts Client \| ClientConfig |
| generate_sync(inputs, client, ...) | GenerateOutputs | Synchronous wrapper |
| evaluate(client, model, ...) | GenerateOutputs | Evaluate on eval_dataset |
| evaluate_sync(client, model, ...) | GenerateOutputs | Synchronous evaluation |

Dataset methods:

| Method | Returns | Description |
| --- | --- | --- |
| get_dataset(n=-1, seed=None) | Dataset | Get training dataset (optionally first n, shuffled) |
| get_eval_dataset(n=-1, seed=None) | Dataset | Get evaluation dataset |
| make_dataset(...) | Dataset | Static method to create a dataset from inputs |

Rollout methods (used internally or by subclasses):

| Method | Returns | Description |
| --- | --- | --- |
| rollout(input, client, model, sampling_args) | State | Abstract: run a single rollout |
| init_state(input, client, model, sampling_args) | State | Create initial state from input |
| get_model_response(state, prompt, ...) | Response | Get model response for prompt |
| is_completed(state) | bool | Check all stop conditions |
| run_rollout(sem, input, client, model, sampling_args) | State | Run rollout with semaphore |
| run_group(group_inputs, client, model, ...) | list[State] | Generate and score one group |

Configuration methods:

| Method | Description |
| --- | --- |
| set_kwargs(**kwargs) | Set attributes using setter methods when available |
| add_rubric(rubric) | Add or merge a rubric |
| set_max_seq_len(max_seq_len) | Set maximum sequence length |
| set_score_rollouts(bool) | Enable/disable scoring |

SingleTurnEnv

Single-response Q&A tasks. Inherits from Environment.

MultiTurnEnv

class MultiTurnEnv(Environment):
    def __init__(self, max_turns: int = -1, **kwargs): ...

Multi-turn interactions. Subclasses must implement env_response.

Abstract method:

async def env_response(self, messages: Messages, state: State, **kwargs) -> Messages:
    """Generate environment feedback after model turn."""

Built-in stop conditions: has_error, prompt_too_long, max_turns_reached, has_final_env_response

Hooks:

| Method | Description |
| --- | --- |
| setup_state(state) | Initialize per-rollout state |
| get_prompt_messages(state) | Customize prompt construction |
| render_completion(state) | Customize completion rendering |
| add_trajectory_step(state, step) | Customize trajectory handling |

ToolEnv

class ToolEnv(MultiTurnEnv):
    def __init__(
        self,
        tools: list[Callable] | None = None,
        max_turns: int = 10,
        error_formatter: Callable[[Exception], str] = lambda e: f"{e}",
        stop_errors: list[type[Exception]] | None = None,
        **kwargs,
    ): ...

Tool calling with stateless Python functions. Automatically converts functions to OpenAI tool format.

Built-in stop condition: no_tools_called (ends when model responds without tool calls)

Methods:

| Method | Description |
| --- | --- |
| add_tool(tool) | Add a tool at runtime |
| remove_tool(tool) | Remove a tool at runtime |
| call_tool(name, args, id) | Override to customize tool execution |
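The function-to-tool conversion works from the Python signature; roughly as follows (a simplified sketch of the idea, not the library's converter, which likely also reads parameter docs and defaults):

```python
import inspect

def to_tool_schema(fn) -> dict:
    """Sketch: derive an OpenAI-style tool definition from a function."""
    type_map = {int: "integer", float: "number", str: "string", bool: "boolean"}
    props = {}
    for name, p in inspect.signature(fn).parameters.items():
        # Fall back to "string" for unannotated or unknown types.
        props[name] = {"type": type_map.get(p.annotation, "string")}
    return {
        "type": "function",
        "function": {
            "name": fn.__name__,
            "description": (fn.__doc__ or "").strip(),
            "parameters": {"type": "object", "properties": props,
                           "required": list(props)},
        },
    }

def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b
```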

StatefulToolEnv

Tools requiring per-rollout state. Override setup_state and update_tool_args to inject state.

SandboxEnv

class SandboxEnv(StatefulToolEnv):
    def __init__(
        self,
        sandbox_name: str = "sandbox-env",
        docker_image: str = "python:3.11-slim",
        start_command: str = "tail -f /dev/null",
        cpu_cores: int = 1,
        memory_gb: int = 2,
        disk_size_gb: int = 5,
        gpu_count: int = 0,
        timeout_minutes: int = 60,
        timeout_per_command_seconds: int = 30,
        environment_vars: dict[str, str] | None = None,
        team_id: str | None = None,
        advanced_configs: AdvancedConfigs | None = None,
        labels: list[str] | None = None,
        **kwargs,
    ): ...

Sandboxed container execution using Prime Sandboxes.

Key parameters:

| Parameter | Type | Description |
| --- | --- | --- |
| sandbox_name | str | Name prefix for sandbox instances |
| docker_image | str | Docker image to use for the sandbox |
| cpu_cores | int | Number of CPU cores |
| memory_gb | int | Memory allocation in GB |
| disk_size_gb | int | Disk size in GB |
| gpu_count | int | Number of GPUs |
| timeout_minutes | int | Sandbox timeout in minutes |
| timeout_per_command_seconds | int | Per-command execution timeout |
| environment_vars | dict[str, str] \| None | Environment variables to set in the sandbox |
| labels | list[str] \| None | Labels for sandbox categorization and filtering |

PythonEnv

Persistent Python REPL in sandbox. Extends SandboxEnv.

OpenEnvEnv

class OpenEnvEnv(MultiTurnEnv):
    def __init__(
        self,
        openenv_project: str | Path,
        num_train_examples: int = 100,
        num_eval_examples: int = 50,
        seed: int = 0,
        prompt_renderer: Callable[..., ChatMessages] | None = None,
        max_turns: int = -1,
        rubric: Rubric | None = None,
        **kwargs,
    ): ...

OpenEnv integration that runs OpenEnv projects in Prime Sandboxes using a prebuilt image manifest (.build.json). Supports both gym and MCP contracts; a prompt_renderer is required to convert observations into chat messages.

EnvGroup

env_group = vf.EnvGroup(
    envs=[env1, env2, env3],
    names=["math", "code", "qa"]  # optional
)

Combines multiple environments for mixed-task training.


Parser Classes

Parser

class Parser:
    def __init__(self, extract_fn: Callable[[str], str] = lambda x: x): ...
    
    def parse(self, text: str) -> Any: ...
    def parse_answer(self, completion: Messages) -> str | None: ...
    def get_format_reward_func(self) -> Callable: ...

Base parser. Default behavior returns text as-is.

XMLParser

class XMLParser(Parser):
    def __init__(
        self,
        fields: list[str | tuple[str, ...]],
        answer_field: str = "answer",
        extract_fn: Callable[[str], str] = lambda x: x,
    ): ...

Extracts structured fields from XML-tagged output.

parser = vf.XMLParser(fields=["reasoning", "answer"])
# Parses: <reasoning>...</reasoning><answer>...</answer>

# With alternatives:
parser = vf.XMLParser(fields=["reasoning", ("code", "answer")])
# Accepts either <code> or <answer> for second field

Methods:

| Method | Returns | Description |
| --- | --- | --- |
| parse(text) | SimpleNamespace | Parse XML into an object with field attributes |
| parse_answer(completion) | str \| None | Extract the answer field from a completion |
| get_format_str() | str | Get the format description string |
| get_fields() | list[str] | Get canonical field names |
| format(**kwargs) | str | Format kwargs into an XML string |
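The extraction behavior can be approximated with a regex sketch (illustrative only; the real parser also handles field alternatives and format rewards):

```python
import re
from types import SimpleNamespace

def parse_xml_fields(text: str, fields: list[str]) -> SimpleNamespace:
    """Sketch: pull <field>...</field> spans out of model output."""
    values = {}
    for field in fields:
        m = re.search(rf"<{field}>(.*?)</{field}>", text, re.DOTALL)
        # Missing fields come back as None rather than raising.
        values[field] = m.group(1).strip() if m else None
    return SimpleNamespace(**values)

out = parse_xml_fields("<reasoning>2+2</reasoning><answer>4</answer>",
                       ["reasoning", "answer"])
```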

ThinkParser

class ThinkParser(Parser):
    def __init__(self, extract_fn: Callable[[str], str] = lambda x: x): ...

Extracts the content after the closing </think> tag. Intended for models that always emit <think> reasoning blocks, so the reasoning is stripped and only the text that follows is parsed.

MaybeThinkParser

Handles optional <think> tags (for models that may or may not think).


Rubric Classes

Rubric

class Rubric:
    def __init__(
        self,
        funcs: list[RewardFunc] | None = None,
        weights: list[float] | None = None,
        parser: Parser | None = None,
    ): ...

Combines multiple reward functions with weights. Default weight is 1.0. Functions with weight=0.0 are tracked as metrics only.

Methods:

| Method | Description |
| --- | --- |
| add_reward_func(func, weight=1.0) | Add a reward function |
| add_metric(func, weight=0.0) | Add a metric (no reward contribution) |
| add_class_object(name, obj) | Add an object accessible in reward functions |
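The weighted combination can be sketched as follows (a toy version: weight-0.0 functions still appear as metrics but contribute nothing to the reward):

```python
def score(funcs, weights, **call_kwargs):
    """Sketch of Rubric scoring: weighted sum of per-function scores."""
    metrics = {f.__name__: f(**call_kwargs) for f in funcs}
    reward = sum(w * metrics[f.__name__] for f, w in zip(funcs, weights))
    return reward, metrics

def correct(completion="", answer="", **kw):
    return 1.0 if completion == answer else 0.0

def length_metric(completion="", **kw):
    return float(len(completion))

# length_metric has weight 0.0: tracked as a metric only.
reward, metrics = score([correct, length_metric], [1.0, 0.0],
                        completion="42", answer="42")
```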

Reward function signature:

def my_reward(
    completion: Messages,
    answer: str = "",
    prompt: Messages | None = None,
    state: State | None = None,
    parser: Parser | None = None,  # if rubric has parser
    task: str = "",
    info: Info | None = None,
    **kwargs
) -> float:
    ...

Group reward function signature:

def my_group_reward(
    completions: list[Messages],
    answers: list[str],
    states: list[State],
    # ... plural versions of individual args
    **kwargs
) -> list[float]:
    ...

JudgeRubric

LLM-as-judge evaluation.

MathRubric

Math-specific evaluation using math-verify.

RubricGroup

Combines rubrics for EnvGroup.


Client Classes

Client

class Client(ABC, Generic[ClientT, MessagesT, ResponseT, ToolT]):
    def __init__(self, client_or_config: ClientT | ClientConfig) -> None: ...

    @property
    def client(self) -> ClientT: ...

    async def get_response(
        self,
        prompt: Messages,
        model: str,
        sampling_args: SamplingArgs,
        tools: list[Tool] | None = None,
        **kwargs,
    ) -> Response: ...

    async def close(self) -> None: ...

Abstract base class for all model clients. Wraps a provider-specific SDK client and translates between provider-agnostic vf types (Messages, Tool, Response) and provider-native formats. The client property exposes the underlying SDK client (e.g., AsyncOpenAI, AsyncAnthropic).

get_response() is the main public method — it converts the prompt and tools to the native format, calls the provider API, validates the response, and converts it back to a vf.Response. Errors are wrapped in vf.ModelError unless they are already vf.Error or authentication errors.

Abstract methods (for subclass implementors):

| Method | Description |
| --- | --- |
| setup_client(config) | Create the native SDK client from ClientConfig |
| to_native_prompt(messages) | Convert Messages → native prompt format + extra kwargs |
| to_native_tool(tool) | Convert Tool → native tool format |
| get_native_response(prompt, model, ...) | Call the provider API |
| raise_from_native_response(response) | Raise ModelError for invalid responses |
| from_native_response(response) | Convert native response → vf.Response |
| close() | Close the underlying SDK client |

Built-in Client Implementations

| Class | client_type | SDK Client | Description |
| --- | --- | --- | --- |
| OpenAIChatCompletionsClient | "openai_chat_completions" | AsyncOpenAI | Chat Completions API (default) |
| OpenAICompletionsClient | "openai_completions" | AsyncOpenAI | Legacy Completions API |
| OpenAIChatCompletionsTokenClient | "openai_chat_completions_token" | AsyncOpenAI | Custom vLLM token route |
| AnthropicMessagesClient | "anthropic_messages" | AsyncAnthropic | Anthropic Messages API |

All built-in clients are available as vf.OpenAIChatCompletionsClient, vf.AnthropicMessagesClient, etc.

Response

class Response(BaseModel):
    id: str
    created: int
    model: str
    usage: Usage | None
    message: ResponseMessage

class ResponseMessage(BaseModel):
    content: str | None
    reasoning_content: str | None
    finish_reason: Literal["stop", "length", "tool_calls"] | None
    is_truncated: bool | None
    tokens: ResponseTokens | None
    tool_calls: list[ToolCall] | None

Provider-agnostic model response. All Client implementations return Response from get_response().

Tool

class Tool(BaseModel):
    name: str
    description: str
    parameters: dict[str, object]
    strict: bool | None = None

Provider-agnostic tool definition. Environments define tools using this type; each Client converts them to its native format via to_native_tool().


Configuration Types

ClientConfig

class ClientConfig(BaseModel):
    client_idx: int = 0
    client_type: ClientType = "openai_chat_completions"
    api_key_var: str = "PRIME_API_KEY"
    api_base_url: str = "https://api.pinference.ai/api/v1"
    endpoint_configs: list[EndpointClientConfig] = []
    timeout: float = 3600.0
    max_connections: int = 28000
    max_keepalive_connections: int = 28000
    max_retries: int = 10
    extra_headers: dict[str, str] = {}

client_type selects which Client implementation to instantiate (see Client Classes). Use endpoint_configs for multi-endpoint round-robin. In grouped scoring mode, groups are distributed round-robin across endpoint configs.

When api_key_var is "PRIME_API_KEY" (the default), credentials are loaded with the following precedence:

  • API key: PRIME_API_KEY env var > ~/.prime/config.json > "EMPTY"
  • Team ID: PRIME_TEAM_ID env var > ~/.prime/config.json > not set

This allows seamless use after running prime login.
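The documented precedence can be sketched as follows (illustrative; the `api_key` key name inside config.json is an assumption, not confirmed above):

```python
import json
import os
from pathlib import Path

def resolve_api_key(config_path: Path = Path.home() / ".prime/config.json") -> str:
    """Sketch of the documented precedence: env var > config file > "EMPTY"."""
    if os.environ.get("PRIME_API_KEY"):
        return os.environ["PRIME_API_KEY"]
    if config_path.exists():
        # Assumed key name; the real config layout may differ.
        key = json.loads(config_path.read_text()).get("api_key")
        if key:
            return key
    return "EMPTY"
```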

EndpointClientConfig

class EndpointClientConfig(BaseModel):
    client_idx: int = 0
    api_key_var: str = "PRIME_API_KEY"
    api_base_url: str = "https://api.pinference.ai/api/v1"
    timeout: float = 3600.0
    max_connections: int = 28000
    max_keepalive_connections: int = 28000
    max_retries: int = 10
    extra_headers: dict[str, str] = {}

Leaf endpoint configuration used inside ClientConfig.endpoint_configs. Has the same fields as ClientConfig except endpoint_configs itself, preventing recursive nesting.

EvalConfig

class EvalConfig(BaseModel):
    env_id: str
    env_args: dict
    env_dir_path: str
    endpoint_id: str | None = None
    model: str
    client_config: ClientConfig
    sampling_args: SamplingArgs
    num_examples: int
    rollouts_per_example: int
    max_concurrent: int
    independent_scoring: bool = False
    extra_env_kwargs: dict = {}
    max_retries: int = 0
    verbose: bool = False
    state_columns: list[str] | None = None
    save_results: bool = False
    resume_path: Path | None = None
    save_to_hf_hub: bool = False
    hf_hub_dataset_name: str | None = None

Endpoint

Endpoint = TypedDict("Endpoint", {"key": str, "url": str, "model": str})
Endpoints = dict[str, list[Endpoint]]

Endpoints maps an endpoint id to one or more endpoint variants. A single variant is represented as a one-item list.


Prime CLI Plugin

Verifiers exposes a plugin contract consumed by prime for command execution.

PRIME_PLUGIN_API_VERSION

PRIME_PLUGIN_API_VERSION = 1

API version for compatibility checks between prime and verifiers.

PrimeCLIPlugin

@dataclass(frozen=True)
class PrimeCLIPlugin:
    api_version: int = PRIME_PLUGIN_API_VERSION
    eval_module: str = "verifiers.cli.commands.eval"
    gepa_module: str = "verifiers.cli.commands.gepa"
    install_module: str = "verifiers.cli.commands.install"
    init_module: str = "verifiers.cli.commands.init"
    setup_module: str = "verifiers.cli.commands.setup"
    build_module: str = "verifiers.cli.commands.build"

    def build_module_command(
        self, module_name: str, args: Sequence[str] | None = None
    ) -> list[str]:
        ...

build_module_command returns a subprocess command list for python -m <module> ....

get_plugin

def get_plugin() -> PrimeCLIPlugin:
    ...

Returns the plugin instance consumed by prime.


Decorators

@vf.stop

@vf.stop
async def my_condition(self, state: State) -> bool:
    """Return True to end the rollout."""
    ...

@vf.stop(priority=10)  # Higher priority runs first
async def early_check(self, state: State) -> bool:
    ...

Mark a method as a stop condition. All stop conditions are checked by is_completed().
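The dual bare/parameterized form follows a standard decorator pattern; a sketch of the idea (illustrative attribute names, not the library's internals):

```python
def stop(fn=None, *, priority: int = 0):
    """Sketch: mark a function as a stop condition, usable bare or with args."""
    def mark(f):
        # Hypothetical marker attributes for illustration.
        f._is_stop_condition = True
        f._priority = priority
        return f
    if fn is not None:   # used bare: @stop
        return mark(fn)
    return mark          # used with args: @stop(priority=10)

@stop
def has_answer(state) -> bool:
    return False

@stop(priority=10)
def early_check(state) -> bool:
    return True
```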

@vf.cleanup

@vf.cleanup
async def my_cleanup(self, state: State) -> None:
    """Called after each rollout completes."""
    ...

@vf.cleanup(priority=10)
async def early_cleanup(self, state: State) -> None:
    ...

Mark a method as a rollout cleanup handler. Cleanup methods should be idempotent—safe to call multiple times—and handle errors gracefully to ensure cleanup completes even when resources are in unexpected states.

@vf.teardown

@vf.teardown
async def my_teardown(self) -> None:
    """Called when environment is destroyed."""
    ...

@vf.teardown(priority=10)
async def early_teardown(self) -> None:
    ...

Mark a method as an environment teardown handler.


Utility Functions

Data Utilities

vf.load_example_dataset(name: str) -> Dataset

Load a built-in example dataset.

vf.extract_boxed_answer(text: str) -> str | None

Extract answer from LaTeX \boxed{} format.

vf.extract_hash_answer(text: str) -> str | None

Extract answer after #### marker (GSM8K format).

Environment Utilities

vf.load_environment(env_id: str, **kwargs) -> Environment

Load an environment by ID (e.g., "primeintellect/gsm8k").

Configuration Utilities

vf.ensure_keys(keys: list[str]) -> None

Validate that required environment variables are set. Raises MissingKeyError (a ValueError subclass) with a clear message listing all missing keys and instructions for setting them.

class MissingKeyError(ValueError):
    keys: list[str]  # list of missing key names

Example:

def load_environment(api_key_var: str = "OPENAI_API_KEY") -> vf.Environment:
    vf.ensure_keys([api_key_var])
    # now safe to use os.environ[api_key_var]
    ...

Logging Utilities

vf.print_prompt_completions_sample(outputs: GenerateOutputs, n: int = 3)

Pretty-print sample rollouts.

vf.setup_logging(level: str = "INFO")

Configure verifiers logging. Set VF_LOG_LEVEL env var to change default.

vf.log_level(level: str | int)

Context manager to temporarily set the verifiers logger to a new log level. Useful for temporarily adjusting verbosity during specific operations.

with vf.log_level("DEBUG"):
    # verifiers logs at DEBUG level here
    ...
# reverts to previous level

vf.quiet_verifiers()

Context manager to temporarily silence verifiers logging by setting WARNING level. Shorthand for vf.log_level("WARNING").

with vf.quiet_verifiers():
    # verifiers logging is quieted here
    outputs = env.generate(...)
# logging restored