Skip to content

Commit eb378cf

Browse files
committed
Address reviewer feedback previously not addressed (model-name/target-name/self.llm)
Signed-off-by: s-nrajpal <66713174+Nakul-Rajpal@users.noreply.github.com>
1 parent 74eb9ad commit eb378cf

File tree

2 files changed

+17
-15
lines changed

2 files changed

+17
-15
lines changed

garak/generators/llm.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
99
API keys and provider configuration are managed by ``llm`` itself
1010
(e.g. ``llm keys set openai``). Pass the ``llm`` model id or alias
11-
as ``--model_name``:
11+
as ``--target_name``:
1212
1313
.. code-block:: bash
1414
@@ -102,11 +102,11 @@ def _call_model(
102102
text_prompt = prompt.last_message("user").text
103103

104104
prompt_kwargs = {}
105-
if self.temperature is not None:
105+
if self.temperature:
106106
prompt_kwargs["temperature"] = self.temperature
107-
if self.max_tokens is not None:
107+
if self.max_tokens:
108108
prompt_kwargs["max_tokens"] = self.max_tokens
109-
if self.top_p is not None:
109+
if self.top_p:
110110
prompt_kwargs["top_p"] = self.top_p
111111
if self.stop:
112112
prompt_kwargs["stop"] = self.stop

tests/generators/test_llm.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,10 @@ def fake_llm(monkeypatch):
4444

4545

4646
def test_instantiation_resolves_model(cfg, fake_llm):
47-
gen = LLMGenerator(name="my-alias", config_root=cfg)
48-
assert gen.name == "my-alias"
49-
assert hasattr(gen, "target")
47+
test_name = "my-alias"
48+
gen = LLMGenerator(name=test_name, config_root=cfg)
49+
assert gen.name == test_name
50+
assert isinstance(gen.target, FakeModel)
5051

5152

5253
def test_generate_returns_message(cfg, fake_llm):
@@ -61,19 +62,20 @@ def test_generate_returns_message(cfg, fake_llm):
6162

6263
def test_param_passthrough(cfg, fake_llm):
6364
gen = LLMGenerator(name="alias", config_root=cfg)
64-
gen.temperature = 0.2
65-
gen.max_tokens = 64
66-
gen.top_p = 0.9
67-
gen.stop = ["\n\n"]
65+
temperature, max_tokens, top_p, stop = 0.2, 64, 0.9, ["\n\n"]
66+
gen.temperature = temperature
67+
gen.max_tokens = max_tokens
68+
gen.top_p = top_p
69+
gen.stop = stop
6870

6971
conv = Conversation([Turn("user", Message(text="hello"))])
7072
gen._call_model(conv)
7173

7274
_, kwargs = fake_llm.calls[0]
73-
assert kwargs["temperature"] == 0.2
74-
assert kwargs["max_tokens"] == 64
75-
assert kwargs["top_p"] == 0.9
76-
assert kwargs["stop"] == ["\n\n"]
75+
assert kwargs["temperature"] == temperature
76+
assert kwargs["max_tokens"] == max_tokens
77+
assert kwargs["top_p"] == top_p
78+
assert kwargs["stop"] == stop
7779

7880

7981
def test_handles_llm_exception(cfg, monkeypatch):

0 commit comments

Comments (0)