Commit 149c8f1

Compatibility with Completion task
1 parent b31a5c5 commit 149c8f1

File tree

lmeval/models/litellm.py
lmeval/models/lmmodel.py

2 files changed: +20 -17 lines changed


lmeval/models/litellm.py

Lines changed: 19 additions & 16 deletions
@@ -136,7 +136,6 @@ def complete(
                  temperature: float = 0.0,
                  completions: int = 1,
                  max_tokens: int = 4096,
-                 return_first: bool = True,
                  **generation_kwargs,
                  ) -> LMAnswer | list[LMAnswer]:
         # FIXME: finish multi-completion support
@@ -155,11 +154,8 @@ def complete(
             resp = None
             print("Can't get response from model:", traceback.format_exc())
 
-        answers = self._make_answer(resp)
-        if return_first:
-            return answers[0]
-        else:
-            return answers
+        answer = self._make_answer(resp)
+        return answer
 
     def _make_grouped_answer(self, answers: list[LMAnswer]) -> LMAnswer:
         is_error = any([a.iserror for a in answers])
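Taken together, the two hunks above change the contract of complete(): the return_first flag is gone and the method now always returns the single LMAnswer built by _make_answer(). A minimal usage sketch of the new contract, assuming an already-configured litellm-backed lmeval model object and OpenAI-style message dicts (both assumptions, neither appears in this diff):

    # `model` stands in for any configured litellm-backed lmeval model; the
    # message format is the usual role/content dict that litellm accepts.
    messages = [{"role": "user", "content": "What is 2 + 2?"}]

    # Before this commit, complete(..., return_first=False) returned a list of
    # LMAnswer objects; after it, complete() always returns exactly one LMAnswer.
    answer = model.complete(messages, temperature=0.0, completions=1, max_tokens=64)

    # iserror and raw_response are the LMAnswer fields visible in this diff;
    # the full schema is not shown here.
    print(answer.iserror, answer.raw_response)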
@@ -183,12 +179,13 @@ def _make_grouped_answer(self, answers: list[LMAnswer]) -> LMAnswer:
 
     def multi_complete(self, grouped_question: GroupedQuestion, temperature: float = 0.0,
                        completions: int = 1, max_tokens: int = 4096, **generation_kwargs) -> LMAnswer:
-        n_completions = grouped_question.metadata['n_completions']
+        n_completions = grouped_question.metadata.get('n_completions', 1)
+        temperature = grouped_question.metadata.get('temperature', None)
         grouped_answers = []
 
         for question in grouped_question.question_set:
-            answers = self.complete(question.messages, temperature, n_completions, 100, return_first=False, **generation_kwargs)
-            grouped_answers += answers
+            answer = self.complete(question.messages, temperature, n_completions, max_tokens, **generation_kwargs)
+            grouped_answers.append(answer)
 
 
         return self._make_grouped_answer(grouped_answers)
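With this change multi_complete() reads n_completions and temperature from the grouped question's metadata (falling back to defaults), passes the caller's max_tokens through instead of a hard-coded 100, and appends the single LMAnswer that complete() now returns. The aggregation pattern in isolation, as a sketch lifted out of the class (model and grouped_question are assumed to match the types used in the diff):

    def multi_complete_sketch(model, grouped_question, max_tokens=4096, **generation_kwargs):
        # Mirrors the post-commit logic: one completion call per question, then grouping.
        n_completions = grouped_question.metadata.get('n_completions', 1)
        temperature = grouped_question.metadata.get('temperature', None)
        grouped_answers = []
        for question in grouped_question.question_set:
            # complete() returns a single LMAnswer now, so append rather than extend
            answer = model.complete(question.messages, temperature, n_completions,
                                    max_tokens, **generation_kwargs)
            grouped_answers.append(answer)
        return model._make_grouped_answer(grouped_answers)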
@@ -236,22 +233,28 @@ def _make_answer(self,
                      prompt: str = "") -> LMAnswer:
         iserror = False
         error_reason = ""
-
+        raw_response = ""
         cost = total_tokens = prompt_tokens = completion_tokens = 0
         total_time = 0
         model_name = self.runtime_vars['litellm_version_string']
+        response_id = ""
 
         if isinstance(resp, ModelResponse):
             response = resp
+            response_id = resp.id
 
             log.debug("response: %s", response)
             try:
                 answer_contents = [c.message.content for c in response.choices]
                 tool_calls = [c.message.tool_calls for c in response.choices]
+
                 if all(r is None and tc is None for r, tc in zip(answer_contents, tool_calls)):
                     raise ValueError("No response from model")
-                if raw_response is None and tool_calls is not None:
-                    raw_response = ""
+
+                for answer in answer_contents:
+                    if answer is not None:
+                        raw_response = answer
+                        break
 
             except Exception as e:
                 try:
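The reworked _make_answer() initializes raw_response and response_id up front and then takes the first non-None choice content, so responses that carry only tool calls no longer reference an undefined raw_response. The selection step in isolation might look like the sketch below, assuming an OpenAI-style litellm ModelResponse whose choices expose message.content and message.tool_calls:

    def first_text_choice(response) -> str:
        # Same selection logic as the hunk above, pulled out for clarity.
        contents = [c.message.content for c in response.choices]
        tool_calls = [c.message.tool_calls for c in response.choices]
        if all(text is None and tc is None for text, tc in zip(contents, tool_calls)):
            raise ValueError("No response from model")
        for text in contents:
            if text is not None:
                return text
        return ""  # only tool calls came back, so the text answer stays empty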
@@ -284,7 +287,7 @@ def _make_answer(self,
             else:
                 iserror = True
                 error_reason = f'{resp}'
-                raw_responses = []
+
         elif isinstance(resp, Exception):
             iserror = True
             error_reason = repr(resp)
@@ -295,7 +298,7 @@ def _make_answer(self,
             iserror = True
             error_reason = "Not implemented"
 
-        answers = [self._build_answer(text=r,
+        answer = self._build_answer(text=raw_response,
                                       generation_time=total_time,
                                       iserror=iserror,
                                       error_reason=error_reason,
@@ -306,7 +309,8 @@ def _make_answer(self,
                                       isunsafe=self.isunsafe,
                                       prompt=prompt,
                                       id=response_id)
-        answer.raw_response = response.model_dump()
+        if isinstance(resp, ModelResponse):
+            answer.raw_response = resp.model_dump()
         return answer
 
     def _batch_completion(self,
@@ -368,7 +372,6 @@ def _completion(self,
         messages = self._replace_system_messages(messages)
         messages = self._merge_messages_by_role(messages)
 
-        print("Temperature: ", temperature)
         resp = completion(model=model,
                           messages=messages,
                           temperature=temperature,

lmeval/models/lmmodel.py

Lines changed: 1 addition & 1 deletion
@@ -133,7 +133,7 @@ def batch_execute(self, tasks: list["EvalTask"], temperature: float = 0.0,
         """ Execute a batch of prompts in parallel."""
         for i, etask in enumerate(tasks):
             if etask.task.type == TaskType.completion.value:
-                yield i, self.complete(etask.messages, temperature, completions, tools=tools)
+                yield i, self.complete(etask.messages, temperature, completions, tools=etask.question.tools)
             elif etask.task.type == TaskType.grouped_completion.value:
                 yield i, self.multi_complete(etask.question, temperature=temperature, completions=10)
             else:
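The one-line fix makes batch_execute() take the tool definitions from each task's own question (etask.question.tools) rather than from a tools variable in the enclosing scope, so completion and grouped-completion tasks are dispatched uniformly. A hedged consumer sketch, assuming model and tasks (a list of EvalTask objects, as typed in the signature) already exist and relying on the defaults for the remaining parameters:

    # batch_execute() is a generator yielding (index, LMAnswer) pairs, as in the
    # hunk above; only fields visible in this diff are touched here.
    for index, answer in model.batch_execute(tasks):
        status = "error" if answer.iserror else "ok"
        print(f"task {index}: {status}")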
