Open
Description
What happened?
I am trying to create a class that works with a custom local model client. Following response in a previous github repo I was able to create a class that works with a normal response but when using the Predict
module it is making calls to the default __call__
instead of the custom one.
Steps to reproduce
Here is the custom code
import requests
import os
import dspy
from dotenv import load_dotenv
from dspy import LM
# api_key = API_key
class AIMLAPI(LM):
def __init__(self, model, api_key, **kwargs):
super().__init__(model)
self.model = model
self.api_key = api_key
self.base_url = "http://some_ip_address/api/chat/completions"
self.kwargs.update(kwargs)
self.provider = "AIMLAPI"
def basic_request(self, prompt: str, **kwargs):
headers = {
'Content-Type': 'application/json',
'Authorization': f'Bearer {self.api_key}',
}
data = {
"model": self.model,
"messages": [{"role": "user", "content": prompt}],
**{k: v for k, v in kwargs.items() if k in self.kwargs}
}
response = requests.post(self.base_url, headers=headers, json=data)
response.raise_for_status()
response_json = response.json()
# Debugging 1: Print Json response from API: Good Response
print("API Response:", response_json)
self.history.append({
"prompt": prompt,
"response": response_json,
"kwargs": kwargs,
})
return response_json
def __call__(self, prompt, only_completed=True, return_sorted=False, **kwargs):
response = self.basic_request(prompt, **kwargs)
if 'choices' in response and len(response['choices']) > 0:
return [choice['message']['content'] for choice in response['choices']]
return "No valid response found"
# Load environment variables
load_dotenv()
# Set your API key and endpoint
# AIML_API_KEY = os.getenv("AIML_API_KEY")
# model = "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo"
model = "llama3.3:70b"
llama_lm = AIMLAPI(
model=model,
api_key=API_KEY,
temperature=0.7,
max_tokens=100,
n=1
)
dspy.configure(lm=llama_lm, temperature=0.7, max_tokens=1000, n=1)
prompt = "I want to ask about the universe"
response = llama_lm(prompt)
print(f"Full Model Response: {response}") ## Working properly
# Process with Predict
processed_response = dspy.Predict("question -> answer", n=1)(prompt=prompt, lm=llama_lm)
print("Full Response from dspy.Predict:", processed_response)
For the second one prompt with predict I am getting the following error
xpected dict_keys(['question']) but got dict_keys(['prompt'])
File "/Users/paula/Projects/GraphRAG/GraphRAG_venv/lib/python3.10/site-packages/dspy/adapters/chat_adapter.py", line 49, in __call__
return super().__call__(lm, lm_kwargs, signature, demos, inputs)
File "/Users/paula/Projects/GraphRAG/GraphRAG_venv/lib/python3.10/site-packages/dspy/adapters/base.py", line 31, in __call__
inputs_ = self.format(signature, demos, inputs)
File "/Users/paula/Projects/GraphRAG/GraphRAG_venv/lib/python3.10/site-packages/dspy/utils/callback.py", line 266, in wrapper
return fn(instance, *args, **kwargs)
File "/Users/paula/Projects/GraphRAG/GraphRAG_venv/lib/python3.10/site-packages/dspy/adapters/chat_adapter.py", line 87, in format
messages.append(self.format_turn(signature, inputs, role="user"))
File "/Users/paula/Projects/GraphRAG/GraphRAG_venv/lib/python3.10/site-packages/dspy/adapters/chat_adapter.py", line 156, in format_turn
return format_turn(signature, values, role, incomplete, is_conversation_history)
File "/Users/paula/Projects/GraphRAG/GraphRAG_venv/lib/python3.10/site-packages/dspy/adapters/chat_adapter.py", line 217, in format_turn
raise ValueError(f"Expected {fields.keys()} but got {values.keys()}")
ValueError: Expected dict_keys(['question']) but got dict_keys(['prompt'])
During handling of the above exception, another exception occurred:
File "/Users/paula/Projects/GraphRAG/GraphRAG_venv/lib/python3.10/site-packages/dspy/adapters/json_adapter.py", line 214, in format_turn
raise ValueError(f"Expected {field_names} but got {values.keys()}")
File "/Users/paula/Projects/GraphRAG/GraphRAG_venv/lib/python3.10/site-packages/dspy/adapters/json_adapter.py", line 140, in format_turn
return format_turn(signature, values, role, incomplete, is_conversation_history)
File "/Users/paula/Projects/GraphRAG/GraphRAG_venv/lib/python3.10/site-packages/dspy/adapters/json_adapter.py", line 103, in format
messages.append(self.format_turn(signature, inputs, role="user"))
File "/Users/paula/Projects/GraphRAG/GraphRAG_venv/lib/python3.10/site-packages/dspy/utils/callback.py", line 266, in wrapper
return fn(instance, *args, **kwargs)
File "/Users/paula/Projects/GraphRAG/GraphRAG_venv/lib/python3.10/site-packages/dspy/adapters/json_adapter.py", line 43, in __call__
inputs = self.format(signature, demos, inputs)
File "/Users/paula/Projects/GraphRAG/GraphRAG_venv/lib/python3.10/site-packages/dspy/adapters/chat_adapter.py", line 55, in __call__
return JSONAdapter()(lm, lm_kwargs, signature, demos, inputs)
File "/Users/paula/Projects/GraphRAG/GraphRAG_venv/lib/python3.10/site-packages/dspy/predict/predict.py", line 100, in forward
completions = adapter(
File "/Users/paula/Projects/GraphRAG/GraphRAG_venv/lib/python3.10/site-packages/dspy/predict/predict.py", line 73, in __call__
return self.forward(**kwargs)
File "/Users/paula/Projects/GraphRAG/GraphRAG_venv/lib/python3.10/site-packages/dspy/utils/callback.py", line 266, in wrapper
return fn(instance, *args, **kwargs)
File "/Users/paula/Projects/GraphRAG/kg-gen/CustomLM_test.py", line 78, in <module>
processed_response = dspy.Predict("question -> answer", n=1)(prompt=prompt, lm=llama_lm)
File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/runpy.py", line 196, in _run_module_as_main (Current frame)
return _run_code(code, main_globals, None,
ValueError: Expected dict_keys(['question']) but got dict_keys(['prompt'])
"llama3.3:70b",
When I am changing the prompt
to question
as in processed_response = dspy.Predict("question -> answer", n=1)(question=prompt, lm=llama_lm)
I am getting the following error
Exception has occurred: TypeError (note: full exception trace is shown but execution is paused at: _run_module_as_main)
AIMLAPI.__call__() missing 1 required positional argument: 'prompt'
File "/Users/paula/Projects/GraphRAG/GraphRAG_venv/lib/python3.10/site-packages/dspy/adapters/chat_adapter.py", line 49, in __call__
return super().__call__(lm, lm_kwargs, signature, demos, inputs)
File "/Users/paula/Projects/GraphRAG/GraphRAG_venv/lib/python3.10/site-packages/dspy/adapters/base.py", line 34, in __call__
outputs = lm(**inputs_, **lm_kwargs)
TypeError: AIMLAPI.__call__() missing 1 required positional argument: 'prompt'
During handling of the above exception, another exception occurred:
File "/Users/paula/Projects/GraphRAG/GraphRAG_venv/lib/python3.10/site-packages/dspy/adapters/json_adapter.py", line 61, in __call__
outputs = lm(**inputs, **lm_kwargs)
File "/Users/paula/Projects/GraphRAG/GraphRAG_venv/lib/python3.10/site-packages/dspy/adapters/chat_adapter.py", line 55, in __call__
return JSONAdapter()(lm, lm_kwargs, signature, demos, inputs)
File "/Users/paula/Projects/GraphRAG/GraphRAG_venv/lib/python3.10/site-packages/dspy/predict/predict.py", line 100, in forward
completions = adapter(
File "/Users/paula/Projects/GraphRAG/GraphRAG_venv/lib/python3.10/site-packages/dspy/predict/predict.py", line 73, in __call__
return self.forward(**kwargs)
File "/Users/paula/Projects/GraphRAG/GraphRAG_venv/lib/python3.10/site-packages/dspy/utils/callback.py", line 266, in wrapper
return fn(instance, *args, **kwargs)
File "/Users/paula/Projects/GraphRAG/kg-gen/CustomLM_test.py", line 78, in <module>
processed_response = dspy.Predict("question -> answer", n=1)(question=prompt, lm=llama_lm)
File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/runpy.py", line 86, in _run_code
exec(code, run_globals)
File "/Library/Frameworks/Python.framework/Versions/3.10/lib/python3.10/runpy.py", line 196, in _run_module_as_main (Current frame)
return _run_code(code, main_globals, None,
TypeError: AIMLAPI.__call__() missing 1 required positional argument: 'prompt'
I am no python wizard, but would be helpful if the above example can be modified into a working one. Thanks!
DSPy version
2.6.16