-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathtest_guardrails.py
More file actions
182 lines (145 loc) · 5.73 KB
/
test_guardrails.py
File metadata and controls
182 lines (145 loc) · 5.73 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
#!/usr/bin/env python3
"""
Test script for guardrails functionality in PraisonAI Agents.
"""
import os
import sys
import logging
from praisonaiagents import Agent, Task, TaskOutput
from praisonaiagents.guardrails import GuardrailResult, LLMGuardrail
from typing import Tuple, Any
def test_function_guardrail():
    """Check that a plain callable can be used as a task guardrail.

    The callable rejects outputs whose text mentions "error" (case-insensitive)
    or is shorter than 10 characters, and otherwise passes the output through
    unchanged. The test asserts that ``Task._process_guardrail`` reports
    success for a friendly greeting and failure for an error message.
    """
    print("Testing function-based guardrail...")

    def validate_output(task_output: TaskOutput) -> Tuple[bool, Any]:
        """Return ``(ok, payload)``: the output itself on success, a reason on failure."""
        text = task_output.raw
        if "error" in text.lower():
            return False, "Output contains errors"
        if len(text) < 10:
            return False, "Output is too short"
        return True, task_output

    def make_output(raw_text):
        # Every TaskOutput in this test shares the same description and agent name.
        return TaskOutput(description="Test task", raw=raw_text, agent="Test Agent")

    # The guardrail is attached to the task at construction time.
    tester = Agent(
        name="Test Agent",
        role="Tester",
        goal="Test guardrails",
        backstory="I am testing the guardrail functionality",
    )
    guarded_task = Task(
        description="Write a simple hello message",
        expected_output="A friendly greeting message",
        agent=tester,
        guardrail=validate_output,
        max_retries=2,
    )

    # Long, error-free text must be accepted.
    accepted = guarded_task._process_guardrail(
        make_output("Hello! This is a friendly greeting message from the agent.")
    )
    assert accepted.success, f"Good output should pass: {accepted.error}"
    print("✓ Good output passed guardrail")

    # Text mentioning "error" must be rejected.
    rejected = guarded_task._process_guardrail(make_output("Error occurred"))
    assert not rejected.success, "Bad output should fail guardrail"
    print("✓ Bad output failed guardrail as expected")

    print("Function-based guardrail test passed!\n")
def test_string_guardrail():
    """Check that a natural-language string can be used as a task guardrail.

    The agent's ``llm`` attribute is replaced with a deterministic mock so the
    test needs no real model. The mock answers "PASS" unless the text it
    reviews mentions "error". The test asserts that a professional greeting
    passes and an error message fails.
    """
    print("Testing string-based LLM guardrail...")

    class MockLLM:
        """Offline stand-in for the agent's LLM.

        NOTE(review): assumes the string guardrail ends up calling
        ``agent.llm.chat(prompt)`` with a single prompt string. Confirm this
        against praisonaiagents' LLMGuardrail.
        """

        def chat(self, prompt, **kwargs):
            marker = "Output to Validate:"
            if marker in prompt:
                # Judge only the text after the first marker (up to any second
                # marker). The guardrail wording used below itself mentions
                # "errors", and it must not trigger a false FAIL.
                reviewed = prompt.split(marker)[1].strip()
                if "error" in reviewed.lower():
                    return "FAIL: The output contains error messages"
                return "PASS"
            # Fallback when no marker is present. Prompts containing the
            # "check if" criteria wording are treated as passing, because that
            # wording mentions "errors".
            lowered = prompt.lower()
            if "error" in lowered and "check if" not in lowered:
                return "FAIL: The output contains error messages"
            return "PASS"

    def make_output(raw_text):
        # Every TaskOutput in this test shares the same description and agent name.
        return TaskOutput(description="Test task", raw=raw_text, agent="Test Agent")

    tester = Agent(
        name="Test Agent",
        role="Tester",
        goal="Test guardrails",
        backstory="I am testing the guardrail functionality",
    )
    tester.llm = MockLLM()

    guarded_task = Task(
        description="Write a simple hello message",
        expected_output="A friendly greeting message",
        agent=tester,
        guardrail="Check if the output is professional and does not contain errors",
        max_retries=2,
    )

    accepted = guarded_task._process_guardrail(
        make_output("Hello! This is a professional greeting message.")
    )
    assert accepted.success, f"Good output should pass: {accepted.error}"
    print("✓ Good output passed LLM guardrail")

    rejected = guarded_task._process_guardrail(
        make_output("There was an error in the system")
    )
    assert not rejected.success, "Bad output should fail LLM guardrail"
    print("✓ Bad output failed LLM guardrail as expected")

    print("String-based LLM guardrail test passed!\n")
def test_guardrail_result():
    """Check how ``GuardrailResult.from_tuple`` maps ``(ok, payload)`` tuples.

    A ``(True, value)`` tuple is expected to become a success with ``value``
    as its result and an empty error. A ``(False, message)`` tuple is expected
    to become a failure with no result and ``message`` as its error.
    """
    print("Testing GuardrailResult...")

    # Success: the payload becomes the result.
    passed = GuardrailResult.from_tuple((True, "Modified output"))
    assert passed.success
    assert passed.result == "Modified output"
    assert passed.error == ""
    print("✓ Success result created correctly")

    # Failure: the payload becomes the error message.
    rejected = GuardrailResult.from_tuple((False, "Validation failed"))
    assert not rejected.success
    assert rejected.result is None
    assert rejected.error == "Validation failed"
    print("✓ Failure result created correctly")

    print("GuardrailResult test passed!\n")
def main() -> bool:
    """Run every guardrail test and report the overall outcome.

    Each test runs even if an earlier one fails, so one run reports every
    broken test instead of stopping at the first failure. Failures are
    reported with the test's name and ``repr`` of the exception. A bare
    ``assert`` raises ``AssertionError`` with an empty message, so ``str(e)``
    alone would print nothing useful.

    Returns:
        True if all tests passed, False otherwise.
    """
    import traceback

    print("Running PraisonAI Agents Guardrails Tests...\n")

    tests = (test_guardrail_result, test_function_guardrail, test_string_guardrail)
    failed = []
    for test in tests:
        try:
            test()
        except Exception as e:  # top-level boundary: report and keep going
            failed.append(test.__name__)
            print(f"❌ Test failed: {test.__name__}: {e!r}")
            traceback.print_exc()

    if failed:
        print(
            f"\n{len(failed)} of {len(tests)} guardrail test(s) failed: "
            f"{', '.join(failed)}"
        )
        return False

    print("🎉 All guardrail tests passed!")
    print("\nGuardrails implementation is working correctly!")
    return True
if __name__ == "__main__":
    # Keep library logging quiet (warnings and above) while the tests print progress.
    logging.basicConfig(level=logging.WARNING)
    # Exit status 0 signals success to shells/CI; 1 signals at least one failure.
    exit_code = 0 if main() else 1
    sys.exit(exit_code)