Commit 40e2a09

Propagates llmclient changes to pqa (#839)
1 parent 3dee112 commit 40e2a09

20 files changed: +3184 -3144 lines

.mailmap

Lines changed: 6 additions & 5 deletions
@@ -1,12 +1,13 @@
-James Braza <[email protected]> <jamesbraza@gmail.com>
+Anush008 <[email protected]> Anush <anushshetty90@gmail.com>
+Harry Vu <[email protected]> harryvu-futurehouse
+Mayk Caldas <[email protected]> maykcaldas
 Michael Skarlinski <[email protected]> mskarlin <[email protected]>
 Odhran O'Donoghue <[email protected]> odhran-o-d <[email protected]>
-Mayk Caldas <[email protected]> maykcaldas <[email protected]>
-Harry Vu <[email protected]> harryvu-futurehouse

.pre-commit-config.yaml

Lines changed: 2 additions & 2 deletions
@@ -84,9 +84,9 @@ repos:
         - aiohttp>=3.10.6 # Match pyproject.toml
         - PyMuPDF>=1.24.12
         - anyio
-        - fh-llm-client[deepseek]<0.1.0 # Match pyproject.toml
+        - fhlmi>=0.0.1
         - fhaviary[llm]>=0.18.2 # Match pyproject.toml
-        - ldp>=0.20 # Match pyproject.toml
+        - ldp>=0.25.0 # Match pyproject.toml
         - html2text
         - httpx
         - pybtex

paperqa/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 import warnings
 
-from llmclient import (
+from lmi import (
     EmbeddingModel,
     HybridEmbeddingModel,
     LiteLLMEmbeddingModel,
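
For downstream users, the visible change in this file is only the package name: the same symbols now come from lmi instead of llmclient. A minimal before/after sketch, using only the names shown in this diff:

# Old (fh-llm-client era):
# from llmclient import EmbeddingModel, HybridEmbeddingModel, LiteLLMEmbeddingModel

# New (fhlmi era) -- identical class names, new top-level package:
from lmi import EmbeddingModel, HybridEmbeddingModel, LiteLLMEmbeddingModel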

paperqa/agents/env.py

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@
     ToolResponseMessage,
 )
 from aviary.utils import MultipleChoiceQuestion
-from llmclient import EmbeddingModel, LiteLLMModel
+from lmi import EmbeddingModel, LiteLLMModel
 
 from paperqa.docs import Docs
 from paperqa.settings import Settings

paperqa/agents/helpers.py

Lines changed: 8 additions & 6 deletions
@@ -5,7 +5,8 @@
 from datetime import datetime
 from typing import cast
 
-from llmclient import LiteLLMModel, LLMModel
+from aviary.core import Message
+from lmi import LiteLLMModel, LLMModel
 from rich.table import Table
 
 from paperqa.docs import Docs
@@ -60,12 +61,13 @@ async def litellm_get_search_query(
         )
     else:
         model = llm
-    result = await model.run_prompt(
-        prompt=search_prompt,
-        data={"question": question, "count": count},
-        system_prompt=None,
+    messages = [
+        Message(content=search_prompt.format(question=question, count=count)),
+    ]
+    result = await model.call_single(
+        messages=messages,
     )
-    search_query = result.text
+    search_query = cast(str, result.text)
     queries = [s for s in search_query.split("\n") if len(s) > 3]  # noqa: PLR2004
     # remove "2.", "3.", etc. -- https://regex101.com/r/W2f7F1/1
     queries = [re.sub(r"^\d+\.\s*", "", q) for q in queries]
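
This hunk shows the core calling-convention change of the migration: run_prompt() took a template plus a data dict and formatted it internally, while lmi's call_single() consumes pre-built Message objects, so the caller formats the template itself. A standalone sketch of the pattern, using only the signatures visible in this diff (the model name and prompt string are illustrative placeholders):

import asyncio

from aviary.core import Message
from lmi import LiteLLMModel

SEARCH_PROMPT = "Write {count} search queries that would help answer: {question}"


async def main() -> None:
    model = LiteLLMModel(name="gpt-4o-mini")
    # The caller now formats the template itself...
    content = SEARCH_PROMPT.format(question="What is retrieval augmentation?", count=3)
    # ...and wraps it in a Message, since call_single() takes messages.
    result = await model.call_single(messages=[Message(content=content)])
    print(result.text)


asyncio.run(main())

The new cast(str, result.text) above also hints that result.text is typed as optional in lmi, which is why casts appear throughout this commit.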

paperqa/agents/models.py

Lines changed: 16 additions & 7 deletions
@@ -5,10 +5,11 @@
 import time
 from contextlib import asynccontextmanager
 from enum import StrEnum
-from typing import Any, ClassVar, Protocol
+from typing import Any, ClassVar, Protocol, cast
 from uuid import UUID, uuid4
 
-from llmclient import LiteLLMModel, LLMModel
+from aviary.core import Message
+from lmi import LiteLLMModel, LLMModel
 from pydantic import (
     BaseModel,
     ConfigDict,
@@ -79,12 +80,20 @@ async def get_summary(self, llm_model: LLMModel | str = "gpt-4o") -> str:
         model = (
             LiteLLMModel(name=llm_model) if isinstance(llm_model, str) else llm_model
         )
-        result = await model.run_prompt(
-            prompt="{question}\n\n{answer}",
-            data={"question": self.session.question, "answer": self.session.answer},
-            system_prompt=sys_prompt,
+        prompt_template = "{question}\n\n{answer}"
+        messages = [
+            Message(role="system", content=sys_prompt),
+            Message(
+                role="user",
+                content=prompt_template.format(
+                    question=self.session.question, answer=self.session.answer
+                ),
+            ),
+        ]
+        result = await model.call_single(
+            messages=messages,
         )
-        return result.text.strip()
+        return cast(str, result.text).strip()
 
 
 class TimerData(BaseModel):
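
The same migration applies here, with one extra wrinkle: run_prompt()'s system_prompt= keyword has no direct counterpart, so the system prompt becomes the leading message in the list. A reduced sketch of just that mapping (the prompt strings are illustrative):

from aviary.core import Message

sys_prompt = "Answer in one concise sentence."
prompt_template = "{question}\n\n{answer}"

# Old: system_prompt=sys_prompt was a separate run_prompt() keyword.
# New: the system prompt is simply the first Message passed to call_single().
messages = [
    Message(role="system", content=sys_prompt),
    Message(role="user", content=prompt_template.format(question="Q?", answer="A.")),
]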

paperqa/agents/tools.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@
 from typing import ClassVar, Self, cast
 
 from aviary.core import ToolRequestMessage
-from llmclient import Embeddable, EmbeddingModel, LiteLLMModel
+from lmi import Embeddable, EmbeddingModel, LiteLLMModel
 from pydantic import BaseModel, ConfigDict, Field, computed_field
 
 from paperqa.docs import Docs

paperqa/clients/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -6,10 +6,10 @@
 from typing import Any, cast
 
 import aiohttp
+from lmi.utils import gather_with_concurrency
 from pydantic import BaseModel, ConfigDict
 
 from paperqa.types import Doc, DocDetails
-from paperqa.utils import gather_with_concurrency
 
 from .client_models import MetadataPostProcessor, MetadataProvider
 from .crossref import CrossrefProvider
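
Here gather_with_concurrency moves from paperqa.utils into lmi.utils; call sites are otherwise unchanged. A usage sketch, assuming the helper keeps the (limit, coroutines) signature it had in paperqa.utils:

import asyncio

from lmi.utils import gather_with_concurrency


async def fetch(i: int) -> int:
    await asyncio.sleep(0.1)  # stand-in for a metadata-provider request
    return i


async def main() -> None:
    # At most 4 of the 10 coroutines run concurrently.
    print(await gather_with_concurrency(4, [fetch(i) for i in range(10)]))


asyncio.run(main())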

paperqa/core.py

Lines changed: 28 additions & 16 deletions
@@ -3,9 +3,11 @@
 import json
 import re
 from collections.abc import Callable, Sequence
-from typing import Any
+from typing import Any, cast
+
+from aviary.core import Message
+from lmi import LLMModel
 
-from paperqa.llms import PromptRunner
 from paperqa.types import Context, LLMResult, Text
 from paperqa.utils import extract_score, strip_citations
@@ -102,7 +104,8 @@ def fraction_replacer(match: re.Match) -> str:
 async def map_fxn_summary(
     text: Text,
     question: str,
-    prompt_runner: PromptRunner | None,
+    summary_llm_model: LLMModel | None,
+    prompt_templates: tuple[str, str] | None,
     extra_prompt_data: dict[str, str] | None = None,
     parser: Callable[[str], dict[str, Any]] | None = None,
     callbacks: Sequence[Callable[[str], None]] | None = None,
@@ -115,12 +118,14 @@ async def map_fxn_summary(
 
     Args:
         text: The text to parse.
-        question: The question to use for the chain.
-        prompt_runner: The prompt runner to call - should have question, citation,
-            summary_length, and text fields.
-        extra_prompt_data: Optional extra kwargs to pass to the prompt runner's data.
-        parser: The parser to use for parsing - return empty dict on Failure to fallback to text parsing.
-        callbacks: LLM callbacks to execute in the prompt runner.
+        question: The question to use for summarization.
+        summary_llm_model: The LLM model to use for generating summaries.
+        prompt_templates: Optional two-elements tuple containing templates for the user and system prompts.
+            prompt_templates = (user_prompt_template, system_prompt_template)
+        extra_prompt_data: Optional extra data to pass to the prompt template.
+        parser: Optional parser function to parse LLM output into structured data.
+            Should return dict with at least 'summary' field.
+        callbacks: Optional sequence of callback functions to execute during LLM calls.
 
     Returns:
         The context object and LLMResult to get info about the LLM execution.
@@ -131,14 +136,21 @@
     citation = text.name + ": " + text.doc.formatted_citation
     success = False
 
-    if prompt_runner:
-        llm_result = await prompt_runner(
-            {"question": question, "citation": citation, "text": text.text}
-            | (extra_prompt_data or {}),
-            callbacks,
-            "evidence:" + text.name,
+    if summary_llm_model and prompt_templates:
+        data = {"question": question, "citation": citation, "text": text.text} | (
+            extra_prompt_data or {}
+        )
+        message_prompt, system_prompt = prompt_templates
+        messages = [
+            Message(role="system", content=system_prompt.format(**data)),
+            Message(role="user", content=message_prompt.format(**data)),
+        ]
+        llm_result = await summary_llm_model.call_single(
+            messages=messages,
+            callbacks=callbacks,
+            name="evidence:" + text.name,
         )
-        context = llm_result.text
+        context = cast(str, llm_result.text)
         result_data = parser(context) if parser else {}
         success = bool(result_data)
         if success:
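
Callers of map_fxn_summary() now pass the summary model plus a template pair instead of a prebuilt PromptRunner. A hypothetical call under the new signature: the prompt strings are illustrative, the Doc/Text construction is simplified, and since both templates are rendered with .format(**data), they may only reference keys present in the data dict (question, citation, text, plus any extra_prompt_data):

import asyncio

from lmi import LiteLLMModel

from paperqa.core import map_fxn_summary
from paperqa.types import Doc, Text

USER_TEMPLATE = (
    "Answer the question using this text.\n\n"
    "Question: {question}\nCitation: {citation}\nText: {text}"
)
SYSTEM_TEMPLATE = "You are a careful scientific summarizer."


async def main() -> None:
    doc = Doc(docname="example2024", citation="Example et al. (2024).", dockey="example2024")
    context, llm_result = await map_fxn_summary(
        text=Text(text="Some chunk of a paper.", name="example2024 chunk 1", doc=doc),
        question="What was studied?",
        summary_llm_model=LiteLLMModel(name="gpt-4o-mini"),
        # Order per the docstring: (user_prompt_template, system_prompt_template).
        prompt_templates=(USER_TEMPLATE, SYSTEM_TEMPLATE),
    )
    print(context, llm_result)


asyncio.run(main())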
