Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 83 additions & 24 deletions openclaw-rl/openclaw_api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,17 +209,34 @@ def __init__(self, args, output_queue: queue.Queue, submission_enabled: threadin
self._session_effective: dict[str, int] = {} # session → count of samples with loss_mask=[1]

self._prm_enabled = getattr(args, "prm_enable", False)
self._prm_provider = str(getattr(args, "prm_provider", os.getenv("PRM_PROVIDER", "local"))).strip().lower()
self._prm_m = int(os.getenv("PRM_M", getattr(args, "prm_m", 3)))
self._prm_temperature = float(getattr(args, "prm_temperature", 0.6))
self._prm_max_tokens = int(getattr(args, "prm_max_new_tokens", 4096))
self._prm_api_base_url = (
str(getattr(args, "prm_api_base_url", None) or os.getenv("PRM_API_BASE_URL", "")).strip().rstrip("/")
)
self._prm_api_key = str(getattr(args, "prm_api_key", None) or os.getenv("PRM_API_KEY", "")).strip()
self._prm_api_model = str(getattr(args, "prm_api_model", None) or os.getenv("PRM_API_MODEL", "")).strip()
self._prm_api_timeout = float(getattr(args, "prm_api_timeout", 120.0))
prm_ip = getattr(args, "prm_router_ip", None)
prm_port = getattr(args, "prm_router_port", None)
self._prm_url = f"http://{prm_ip}:{prm_port}/generate" if prm_ip and prm_port else ""
self._prm_tokenizer = None
if self._prm_enabled:
prm_path = getattr(args, "prm_model_path", None) or args.hf_checkpoint
self._prm_tokenizer = load_tokenizer(prm_path, trust_remote_code=True)
logger.info("[OpenClaw] PRM enabled: url=%s m=%d", self._prm_url, self._prm_m)
if self._prm_provider == "api":
if not self._prm_api_base_url or not self._prm_api_model:
raise ValueError("PRM API mode requires prm_api_base_url and prm_api_model.")
logger.info(
"[OpenClaw] PRM enabled (api): base_url=%s model=%s m=%d",
self._prm_api_base_url,
self._prm_api_model,
self._prm_m,
)
else:
prm_path = getattr(args, "prm_model_path", None) or args.hf_checkpoint
self._prm_tokenizer = load_tokenizer(prm_path, trust_remote_code=True)
logger.info("[OpenClaw] PRM enabled (local): url=%s m=%d", self._prm_url, self._prm_m)

self._eval_scores: list[float] = []
self._eval_scores_lock = threading.Lock()
Expand Down Expand Up @@ -374,6 +391,8 @@ def purge_record_files(self):

# ---------------------------------------------------- PRM scoring
async def _query_prm_once(self, judge_prompt: str, vote_id: int) -> tuple[int | None, str]:
if self._prm_provider == "api":
return None, ""
if not self._prm_url:
return None, ""
payload = {
Expand Down Expand Up @@ -402,21 +421,52 @@ async def _query_prm_once(self, judge_prompt: str, vote_id: int) -> tuple[int |
logger.warning("[OpenClaw] PRM query failed (vote %d): %s", vote_id, e)
return None, ""

async def _query_prm_once_api(self, prm_messages: list[dict[str, str]], vote_id: int) -> tuple[int | None, str]:
    """Cast one PRM vote through an OpenAI-compatible chat-completions endpoint.

    Sends *prm_messages* to ``{base_url}/chat/completions`` and parses the
    first choice's content into a score. Any failure (network, HTTP status,
    malformed body) is logged with *vote_id* and reported as ``(None, "")``
    so a single bad vote never aborts the majority-vote round.
    """
    url = f"{self._prm_api_base_url}/chat/completions"
    request_headers = {"Content-Type": "application/json"}
    if self._prm_api_key:
        request_headers["Authorization"] = f"Bearer {self._prm_api_key}"
    body = {
        "model": self._prm_api_model,
        "messages": prm_messages,
        "temperature": self._prm_temperature,
        "max_tokens": self._prm_max_tokens,
    }
    try:
        async with httpx.AsyncClient(timeout=self._prm_api_timeout) as client:
            response = await client.post(url, json=body, headers=request_headers)
            response.raise_for_status()
            parsed = response.json()
        choice_list = parsed.get("choices", []) if isinstance(parsed, dict) else []
        message = choice_list[0].get("message", {}) if choice_list else {}
        content = message.get("content", "") if isinstance(message, dict) else ""
        # Some providers return content as a list of typed parts; keep only text parts.
        if isinstance(content, list):
            content = "".join(
                part.get("text", "")
                for part in content
                if isinstance(part, dict) and part.get("type") == "text"
            )
        text = str(content)
        return _parse_prm_score(text), text
    except Exception as e:
        logger.warning("[OpenClaw] PRM API query failed (vote %d): %s", vote_id, e)
        return None, ""

async def _prm_evaluate(self, session_id: str, turn_num: int,
response_text: str, next_state) -> dict:
ns_text = _flatten_message_content(next_state.get("content")) if next_state else ""
ns_role = next_state.get("role", "user") if next_state else "user"
msgs = _build_prm_judge_prompt(response_text, ns_text, ns_role)
if self._prm_tokenizer:
judge_prompt = self._prm_tokenizer.apply_chat_template(
msgs, tokenize=False, add_generation_prompt=True,
)
if self._prm_provider == "api":
results = await asyncio.gather(*[self._query_prm_once_api(msgs, i) for i in range(self._prm_m)])
else:
judge_prompt = "\n".join(m["content"] for m in msgs)

results = await asyncio.gather(
*[self._query_prm_once(judge_prompt, i) for i in range(self._prm_m)]
)
if self._prm_tokenizer:
judge_prompt = self._prm_tokenizer.apply_chat_template(
msgs, tokenize=False, add_generation_prompt=True,
)
else:
judge_prompt = "\n".join(m["content"] for m in msgs)
results = await asyncio.gather(
*[self._query_prm_once(judge_prompt, i) for i in range(self._prm_m)]
)
scores = [r[0] for r in results]
final = _majority_vote(scores)

Expand Down Expand Up @@ -697,22 +747,31 @@ def _wait_for_sglang_ready(self):
time.sleep(3)
logger.info("[OpenClaw] policy server ready")

if self._prm_enabled and self._prm_url:
prm_health = self._prm_url.rsplit("/", 1)[0] + "/health"
while True:
try:
r = httpx.get(prm_health, timeout=5)
if r.status_code == 200:
break
except Exception:
pass
time.sleep(3)
logger.info("[OpenClaw] PRM server ready")
if self._prm_enabled:
if self._prm_provider == "local" and self._prm_url:
prm_health = self._prm_url.rsplit("/", 1)[0] + "/health"
while True:
try:
r = httpx.get(prm_health, timeout=5)
if r.status_code == 200:
break
except Exception:
pass
time.sleep(3)
logger.info("[OpenClaw] PRM server ready")
elif self._prm_provider == "api":
logger.info("[OpenClaw] PRM API provider ready: %s", self._prm_api_base_url)

time.sleep(8)
prm_line = ""
if self._prm_enabled:
prm_line = f"\n PRM enabled: {self._prm_url} (m={self._prm_m})"
if self._prm_provider == "api":
prm_line = (
f"\n PRM enabled (api): {self._prm_api_base_url}"
f" model={self._prm_api_model} (m={self._prm_m})"
)
else:
prm_line = f"\n PRM enabled (local): {self._prm_url} (m={self._prm_m})"
banner = (
f"\n{'=' * 70}\n"
f" [OpenClaw] your model is fired up and ready to roll\n"
Expand Down
Loading