
Commit 74f8bb0

committed
Merge branch 'main' of github.com:google/lmeval
2 parents 4640d4e + d32fe5d commit 74f8bb0


10 files changed: +290 −18 lines changed


lmeval/models/mock_model.py

Lines changed: 2 additions & 1 deletion
@@ -60,6 +60,7 @@ def generate_text(
             prompt: str,
             medias: List[Media] | Media = [],
             temperature: float | None = 0.0,
+            max_tokens: int = 4096,
             completions: int = 1) -> LMAnswer:
         # print(f"generate_text: {prompt}")
         id = "mock"
@@ -85,7 +86,7 @@ def generate_text(
 
     def batch_generate_text(
             self, prompts: list[str], medias: list[list[Media] | Media] = [],
-            temperature: float | None = 0.0,
+            temperature: float | None = 0.0, max_tokens:int = 4096,
             completions: int = 1) -> Generator[Tuple[int, LMAnswer], None, None]:
         log.info(f"mock-batch_generate_text: {len(prompts)} prompts")
         for i, prompt in enumerate(prompts):
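
For context, a minimal sketch of how the widened signature might be called. The class name MockModel and its no-argument constructor are assumptions (neither appears in this diff); only the generate_text keyword arguments come from the change itself.

# Hypothetical usage sketch: MockModel and its no-arg constructor are assumed, not shown in this diff.
from lmeval.models.mock_model import MockModel

model = MockModel()
answer = model.generate_text(prompt="What is the capital of France?",
                             temperature=0.0,
                             max_tokens=4096)  # keyword added by this commit
print(answer.answer)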

lmeval/prompts/multi_choices_prompts.py

Lines changed: 26 additions & 6 deletions
@@ -39,17 +39,21 @@
 
 
 class MultiChoicesMultiAnswersPrompt(Prompt):
+    use_original_letters: bool = False
+
     def __init__(self,
                  template: str = MULTI_ANSWER_TEMPLATE,
                  name: str = "Multi Choices Multi Answer Picker",
                  description: str = "Ask the model to return the letters associated with potentially multiple correct answers",
                  task_type = TaskType.multiple_choices_multiple_answers,
                  url: str = '',
-                 version: str = '1.0'):
+                 version: str = '1.0',
+                 use_original_letters: bool = False):
 
         super().__init__(name=name, description=description,
                          task_type=task_type, template=template, url=url,
                          version=version)
+        self.use_original_letters = use_original_letters
 
     def render(self, question: Question, task: Task) -> str:
         "Render prompt for a given question and task"
@@ -69,15 +73,18 @@ def render(self, question: Question, task: Task) -> str:
             question.letter_mapping = question.prompt_cache[version]['letter_mapping']
         else:
             possible_answers = [question.answer] + question.additional_answers + question.choices
-            random.shuffle(possible_answers)
+            if self.use_original_letters:
+                assert len(possible_answers) == len(question.original_letters), f"Original letters {question.original_letters} should match the number of possible answers {possible_answers}"
+            else:
+                random.shuffle(possible_answers)
 
             # Construct the list of possible answers
             choices_list = []
             letters_list = []
             letter_mapping = {}
             correct_letters = []
             for idx, answer in enumerate(possible_answers):
-                letter = ascii_uppercase[idx]
+                letter = question.original_letters[idx] if self.use_original_letters else ascii_uppercase[idx]
                 # don't put space between letter and answer it decrease accuracy...
                 choices_list.append(f"{letter}:{answer}")
                 letters_list.append(letter)
@@ -88,6 +95,10 @@ def render(self, question: Question, task: Task) -> str:
                     correct_letters.append(letter)
                 if answer in question.additional_answers:
                     correct_letters.append(letter)
+            if self.use_original_letters:
+                choices_list.sort()
+                letters_list.sort()
+                correct_letters.sort()
 
             question.answer_letter = ', '.join(correct_letters)
 
@@ -132,18 +143,21 @@ def render(self, question: Question, task: Task) -> str:
 
 
 class MultiChoicesPrompt(Prompt):
+    use_original_letters: bool = False
 
     def __init__(self,
                  template: str = TEMPLATE,
                  name: str = "Multi Choices Picker",
                  description: str = "Ask the model to return the letter associated with the correct answer",
                  task_type = TaskType.multiple_choices,
                  url: str = '',
-                 version: str = '1.0'):
+                 version: str = '1.0',
+                 use_original_letters: bool = False):
 
         super().__init__(name=name, description=description,
                          task_type=task_type, template=template, url=url,
                          version=version)
+        self.use_original_letters = use_original_letters
 
     def render(self, question: Question, task: Task) -> str:
         "Render prompt for a given question and task"
@@ -162,20 +176,26 @@ def render(self, question: Question, task: Task) -> str:
             question.letter_mapping = question.prompt_cache[version]['letter_mapping']
         else:
             possible_answers = [question.answer] + question.choices
-            random.shuffle(possible_answers)
+            if self.use_original_letters:
+                assert len(possible_answers) == len(question.original_letters), f"Original letters {question.original_letters} should match the number of possible answers {possible_answers}"
+            else:
+                random.shuffle(possible_answers)
 
             # Construct the list of possible answers
             choices_list = []
             letters_list = []
             letter_mapping = {}
             for idx, answer in enumerate(possible_answers):
-                letter = ascii_uppercase[idx]
+                letter = question.original_letters[idx] if self.use_original_letters else ascii_uppercase[idx]
                 # don't put space between letter and answer it decrease accuracy...
                 choices_list.append(f"{letter}:{answer}")
                 letters_list.append(letter)
                 letter_mapping[letter] = answer
                 if answer == question.answer:
                     question.answer_letter = letter
+            if self.use_original_letters:
+                choices_list.sort()
+                letters_list.sort()
 
             # flatten
             multi_choices = "\n".join(choices_list)
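
A minimal sketch of the new option, mirroring the tests added below (import paths are inferred from this commit's file layout): with use_original_letters=True the question must carry one entry in original_letters per element of [answer] + additional_answers + choices, each answer keeps its assigned letter, and the rendered list is sorted by letter instead of shuffled.

# Sketch of the new flag; import paths are inferred from the repository layout shown in this commit.
from lmeval.question import Question
from lmeval.prompts.multi_choices_prompts import MultiChoicesPrompt

question = Question(id=1, question="What is the capital of France?",
                    answer="Paris",
                    choices=["London", "Berlin", "Madrid"],
                    original_letters=["D", "B", "A", "C"])  # D->Paris, B->London, A->Berlin, C->Madrid

prompt = MultiChoicesPrompt(use_original_letters=True)
# rendered = prompt.render(question, task)  # with a Task as in the tests; choices render sorted:
#                                           # A:Berlin, B:London, C:Madrid, D:Paris

With the flag left at its default of False, behaviour is unchanged: answers are shuffled and relabelled A, B, C, ... as before.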

lmeval/prompts/multi_choices_prompts_test.py

Lines changed: 72 additions & 1 deletion
@@ -57,7 +57,50 @@ def test_multi_choices_multi_answers():
     for letter, answer in question.letter_mapping.items():
         assert f"{letter}:{answer}" in rendered_prompt
 
+def test_multi_choices_multi_answers_original_letters():
+    prompt = MultiChoicesMultiAnswersPrompt(use_original_letters=True)
+    question_text = "What is true about Paris"
+    question = Question(id=1,
+                        question=question_text,
+                        answer="It is the capital of France",
+                        additional_answers=["The Louvre is there",
+                                            "The effeil tower is there"],
+                        choices=["It is the capital of Portugal",
+                                 "It is the capital of Germany",
+                                 "The Guggenheim museum is there",
+                                 "THe MoMa is there"],
+                        original_letters=['G', 'F', 'E', 'D', 'C', 'B', 'A'])
+
+    task = Task(name="Paris Info", type=TaskType.multiple_choices_multiple_answers,
+                scorer=get_scorer(ScorerType.contains_answer_letters_insensitive))
+    rendered_prompt = prompt.render(question, task)
+    print(prompt.template)
+    print(rendered_prompt)
+
+    assert question_text in rendered_prompt
+    for choice in question.choices:
+        assert choice in rendered_prompt
+    for answer in question.answer:
+        assert answer in rendered_prompt
+    for c in ['A', 'B', 'C', 'D', 'E', 'F', 'G']:
+        assert f"\n{c}:" in rendered_prompt
+
+    # check that the answer letter is tied to the correct answer
+
+    assert f"{answer}" in rendered_prompt
+
+    # check that the additional answers are in the prompt
+    for additional_answer in question.additional_answers:
+        assert additional_answer in rendered_prompt
 
+    # check the mapping from letter to answer exist
+    for letter, answer in question.letter_mapping.items():
+        assert f"{letter}:{answer}" in rendered_prompt
+    # test the original order is preserved
+    for idx, answer in enumerate(
+        [question.answer] + question.additional_answers + question.choices
+    ):
+        assert f"{question.original_letters[idx]}:{answer}"in rendered_prompt
 
 def test_multi_choices():
     prompt = MultiChoicesPrompt()
@@ -83,6 +126,34 @@ def test_multi_choices():
     for letter, answer in question.letter_mapping.items():
         assert f"{letter}:{answer}" in rendered_prompt
 
+def test_multi_choices_original_letters():
+    prompt = MultiChoicesPrompt(use_original_letters=True)
+    question_text = "What is the capital of France?"
+    question = Question(id=1, question=question_text, answer="Paris",
+                        choices=["London", "Berlin", "Madrid"],
+                        original_letters=["D", "B", "A", "C"])
+    task = Task(name="City capital", type=TaskType.multiple_choices,
+                scorer=get_scorer(ScorerType.contain_text_insensitive))
+    rendered_prompt = prompt.render(question, task)
+    print(prompt.template)
+    print(rendered_prompt)
+
+
+    assert question_text in rendered_prompt
+    for choice in question.choices:
+        assert choice in rendered_prompt
+    assert question.answer in rendered_prompt
+    for c in ['A', 'B', 'C', 'D']:
+        assert f"\n{c}:" in rendered_prompt
+
+    # check that the answer letter is tied to the correct answer
+    assert f"{question.answer_letter}:{question.answer}" in rendered_prompt
+    for letter, answer in question.letter_mapping.items():
+        assert f"{letter}:{answer}" in rendered_prompt
+    # test the original order is preserved
+    for idx, answer in enumerate([question.answer] + question.choices):
+        assert f"{question.original_letters[idx]}:{answer}"in rendered_prompt
+
 def test_repeated_used_multi_choices():
     prompt = MultiChoicesPrompt()
     question_text = "What is the capital of France?"
@@ -122,4 +193,4 @@ def test_answer_in_choice_fail():
     task = Task(name="City capital", type=TaskType.multiple_choices,
                 scorer=get_scorer(ScorerType.contain_text_insensitive))
     with pytest.raises(AssertionError):
-        prompt.render(question, task)
+        prompt.render(question, task)

lmeval/question.py

Lines changed: 2 additions & 1 deletion
@@ -53,7 +53,8 @@ class Question(CustomModel):
     multi_choices: str = Field(default="")
     letter_mapping: dict = Field(default_factory=dict,
                                  description="Keep track of which letter is associated with which answer")
-
+    original_letters: List[str] = Field(default_factory=list,
+                                        description="Keep track of the original letters for: [anwer] + additional_answers + choices")
 
     # cache template rendering keyed by prompt version to ensure consistency
     # accross model evaluations.
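
The new field is purely positional, as its description states: original_letters[i] labels the i-th element of [answer] + additional_answers + choices. A small worked mapping, with values adapted from the new multi-answer test:

# Positional pairing used by the prompts; plain Python, values adapted from the test file in this commit.
answer = "It is the capital of France"
additional_answers = ["The Louvre is there", "The Eiffel tower is there"]
choices = ["It is the capital of Portugal", "It is the capital of Germany",
           "The Guggenheim museum is there", "The MoMA is there"]
original_letters = ['G', 'F', 'E', 'D', 'C', 'B', 'A']

for letter, text in zip(original_letters, [answer] + additional_answers + choices):
    print(f"{letter}:{text}")  # G labels the answer, F and E the additional answers, D..A the distractors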

lmeval/scorers/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 from .scorer import Scorer
+from .llm_rater import LLMRater
 from .loader import get_scorer, list_scorers
 from .dummy_scorer import Always0Scorer, Always1Scorer
 from .boolean_answer import BooleanAnswerScorer
@@ -38,4 +39,5 @@
     "ContainAnswerLetterInsensitive",
     "ContainAnswerLettersInsensitive",
     "PuntDetector",
-]
+    "LLMRater",
+]

lmeval/scorers/llm_rater.py

Lines changed: 117 additions & 0 deletions
@@ -0,0 +1,117 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+from string import Template
+from pydantic import Field
+from typing_extensions import override
+
+from ..enums import Modality, ScorerType
+from ..logger import log
+from ..models import LMAnswer
+from ..question import Question
+from .scorer import Scorer
+
+DEFAULT_RATER_TEMPLATE = Template('''
+You are an impartial evaluator whose job is to determine if two sets of answer to a question are equivalent.
+The question is this:
+
+<question>
+$question
+</question>
+
+Here are two sets of answers:
+
+<answer1>
+$expected
+</answer1>
+
+<answer2>
+$actual
+</answer2>
+
+Rate on the scale from 0.0 to 1.0 how similar answer1 is to answer2. Here 0.0 means they are completely different
+and 1.0 means they are semantically equivalent. Here are some rubrics to help you:
+
+1. Using the question as the context, list all the relevant facts from answer1 and compare them with the facts
+presented in answer2 to see if they are the equivalent.
+2. Do both answers come to the same conclusion?
+3. Do not consider stylistic differences such as the tone, the writing presentation (for instance bullet points vs paragraph).
+
+Write your rating and reasoning for the rateing in json format
+like this:
+
+{
+"score": the rating score between 0 and 1,
+"reasoning": explain how you arrived at this rating
+}
+
+''')
+
+
+def _parse_response_as_json(val:str):
+    jline = val.split('\n')
+    start = 1 if jline[0].startswith("```") else 0
+    end = -1 if jline[-1].startswith("```") else len(jline)
+    j = '\n'.join(jline[start:end])
+    return json.loads(j)
+
+
+class LLMRater(Scorer):
+    """A scorer using a LLM to rate the similiarity between the expected and actual answers.
+    """
+    class Config:
+        arbitrary_types_allowed = True  # to enable Template as an attribute
+    name: str = ScorerType.llm_rater.name
+    description: str = 'Calling a model to rate the answer on the scale from 0 to 1'
+    type: ScorerType = ScorerType.llm_rater
+    modality: Modality = Modality.text  # assume text for now
+    # The template is expect to have 3 parmeters: $question, $expectd, $actual. $question
+    # is the question asked, expected is the right answer and actual is the received answer.
+    # The prompt shoudl return JSON with a field "rating"
+    rater_prompt_template: Template = DEFAULT_RATER_TEMPLATE
+    temperature: float = Field(default=0.0)
+    max_tokens: int = Field(default=4096)
+
+    @override
+    def _score(self,
+               model_answer: LMAnswer,
+               question: Question,
+               task,
+               debug: bool = False) -> float:
+        # if model for the class is set, use it, else use the model from the answer
+        model = self.model if self.model else model_answer.model
+        assert model  # must have a model
+        prompt = self.rater_prompt_template.safe_substitute(
+            question=question.question,
+            expected=question.answer,
+            actual=model_answer.answer)
+        ans = model.generate_text(prompt=prompt,
+                                  temperature=self.temperature,
+                                  max_tokens=self.max_tokens)
+        if ans.iserror:
+            log.error('Rater failed with error %s', ans.error_reason)
+            return -1.0
+        if ans.ispunting:
+            log.error('Rater punted')
+            return 0.0
+        try:
+            jans = _parse_response_as_json(ans.answer)
+            score = jans.get('score', None)
+            assert score is not None
+            return score
+        except Exception as e:  # pylint: disable=broad-except
+            log.error('Rater json parsing failed: ans = %s, exception = %s',
+                      ans.answer, e)
+            return -1.0
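
Note that _parse_response_as_json strips an optional Markdown code fence before calling json.loads, so the rater tolerates judge models that wrap their JSON in ``` blocks. A quick illustration of both accepted shapes (the sample strings below are made up):

# Demo of the fence-stripping helper added in this file (module-private, imported here only for illustration).
from lmeval.scorers.llm_rater import _parse_response_as_json

fenced = '```json\n{"score": 0.9, "reasoning": "same conclusion"}\n```'
plain = '{"score": 0.9, "reasoning": "same conclusion"}'

assert _parse_response_as_json(fenced)["score"] == 0.9
assert _parse_response_as_json(plain)["score"] == 0.9

How an LLMRater instance obtains its judge model (the self.model checked in _score) is defined by the Scorer base class, which is not part of this diff.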
