Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
215 changes: 99 additions & 116 deletions R/provider-openai.R
Original file line number Diff line number Diff line change
Expand Up @@ -49,28 +49,17 @@ chat_openai <- function(
api_key = openai_key(),
model = NULL,
params = NULL,
seed = lifecycle::deprecated(),
api_args = list(),
echo = c("none", "output", "all")
) {
model <- set_default(model, "gpt-4.1")
echo <- check_echo(echo)

params <- params %||% params()
if (lifecycle::is_present(seed) && !is.null(seed)) {
lifecycle::deprecate_warn(
when = "0.2.0",
what = "chat_openai(seed)",
with = "chat_openai(params)"
)
params$seed <- seed
}

provider <- ProviderOpenAI(
name = "OpenAI",
base_url = base_url,
model = model,
params = params,
params = params %||% params(),
extra_args = api_args,
api_key = api_key
)
Expand All @@ -84,7 +73,6 @@ chat_openai_test <- function(
echo = "none"
) {
params <- params %||% params()
params$seed <- params$seed %||% 1014
params$temperature <- params$temperature %||% 0

chat_openai(
Expand Down Expand Up @@ -145,44 +133,52 @@ method(base_request_error, ProviderOpenAI) <- function(provider, req) {
# Chat endpoint ----------------------------------------------------------------

# Chat endpoint path for ProviderOpenAI.
# NOTE(review): this diff rendering shows both the replaced and the replacing
# line; evaluated as-is, only the last expression ("/responses") is returned.
method(chat_path, ProviderOpenAI) <- function(provider) {
"/chat/completions"
"/responses"
}

# https://platform.openai.com/docs/api-reference/chat/create
# https://platform.openai.com/docs/api-reference/responses
# Builds the request body for a chat call.
#
# Arguments:
#   provider: supplies @model and @params.
#   stream:   flag copied into the body as `stream`.
#   turns:    conversation turns, serialised with as_json() and flattened one
#             level.
#   tools:    tool definitions, unnamed and serialised with as_json().
#   type:     optional output type; when non-NULL a "json_schema" format named
#             "structured_data" is requested.
#
# Returns the body as a list with NULL entries removed by compact().
#
# NOTE(review): this span interleaves pre-change and post-change diff lines
# (e.g. `messages`/`input`, `response_format`/`text`), so it is not runnable
# as shown.
method(chat_body, ProviderOpenAI) <- function(
provider,
stream = TRUE,
turns = list(),
tools = list(),
type = NULL
) {
messages <- compact(unlist(as_json(provider, turns), recursive = FALSE))
input <- compact(unlist(as_json(provider, turns), recursive = FALSE))
tools <- as_json(provider, unname(tools))

if (!is.null(type)) {
response_format <- list(
type = "json_schema",
json_schema = list(
# https://platform.openai.com/docs/api-reference/responses/create#responses-create-text
text <- list(
format = list(
type = "json_schema",
name = "structured_data",
schema = as_json(provider, type),
strict = TRUE
)
)
} else {
response_format <- NULL
text <- NULL
}

# https://platform.openai.com/docs/api-reference/responses/create#responses-create-include
params <- chat_params(provider, provider@params)
params$seed <- params$seed %||% provider@seed
# Request log-probabilities via `include` only when the log_probs param is
# TRUE.
if (isTRUE(params$log_probs)) {
include <- list("message.output_text.logprobs")
} else {
include <- NULL
}
# log_probs has been consumed above; drop it so it is not spliced into the
# body by `!!!params` below.
params$log_probs <- NULL

compact(list2(
messages = messages,
input = input,
include = include,
model = provider@model,
!!!params,
stream = stream,
stream_options = if (stream) list(include_usage = TRUE),
tools = tools,
response_format = response_format
text = text,
store = FALSE
))
}

Expand All @@ -194,11 +190,8 @@ method(chat_params, ProviderOpenAI) <- function(provider, params) {
temperature = "temperature",
top_p = "top_p",
frequency_penalty = "frequency_penalty",
presence_penalty = "presence_penalty",
seed = "seed",
max_tokens = "max_completion_tokens",
logprobs = "log_probs",
stop = "stop_sequences"
max_tokens = "max_output_tokens",
log_probs = "log_probs"
)
)
}
Expand All @@ -213,67 +206,53 @@ method(stream_parse, ProviderOpenAI) <- function(provider, event) {
jsonlite::parse_json(event$data)
}
# Extracts the text fragment carried by a parsed streaming event.
# In the `event$type` branch, only "response.output_text.delta" events yield
# text (`event$delta`); the `if` has no `else`, so other events yield NULL.
# NOTE(review): this span interleaves pre-change (`event$choices`) and
# post-change (`event$type`) diff lines and its braces do not balance as shown.
method(stream_text, ProviderOpenAI) <- function(provider, event) {
if (length(event$choices) == 0) {
NULL
} else {
event$choices[[1]]$delta[["content"]]
# https://platform.openai.com/docs/api-reference/responses-streaming/response/output_text/delta
if (event$type == "response.output_text.delta") {
event$delta
}
}
# Combines a streamed chunk with the result accumulated so far.
# NOTE(review): this span interleaves pre-change (`merge_dicts`) and
# post-change (`response.completed`) diff lines.
# NOTE(review): in the `response.completed` variant every other chunk type
# falls through the `if` and yields NULL, and `result` is never consulted --
# confirm the caller only relies on the value produced by the final
# "response.completed" event.
method(stream_merge_chunks, ProviderOpenAI) <- function(
provider,
result,
chunk
) {
if (is.null(result)) {
chunk
} else {
merge_dicts(result, chunk)
# https://platform.openai.com/docs/api-reference/responses-streaming/response/completed
if (chunk$type == "response.completed") {
chunk$response
}
}

# Converts a provider response (`result`) into an assistant turn.
#
# Arguments:
#   provider: passed to tokens_log().
#   result:   parsed response; the `result$output` variant maps each output
#             item to a Content* object by its `type`.
#   has_type: when TRUE, message text is parsed as JSON and wrapped in
#             ContentJson instead of ContentText.
#
# Returns the value of assistant_turn().
#
# NOTE(review): this span interleaves pre-change (`result$choices`,
# `message$tool_calls`) and post-change (`result$output`) diff lines, so it is
# not runnable as shown.
method(value_turn, ProviderOpenAI) <- function(
provider,
result,
has_type = FALSE
) {
if (has_name(result$choices[[1]], "delta")) {
# streaming
message <- result$choices[[1]]$delta
} else {
message <- result$choices[[1]]$message
}

if (has_type) {
if (is_string(message$content)) {
json <- jsonlite::parse_json(message$content[[1]])
contents <- lapply(result$output, function(output) {
if (output$type == "message") {
# Only the first content entry of a message item is read.
if (has_type) {
ContentJson(jsonlite::parse_json(output$content[[1]]$text))
} else {
ContentText(output$content[[1]]$text)
}
} else if (output$type == "function_call") {
arguments <- jsonlite::parse_json(output$arguments)
# NOTE(review): the request id is taken from `output$id`, while the as_json()
# methods for tool requests and tool results in this file send it back as
# `call_id`; presumably `output$call_id` is the value that must round-trip --
# confirm against the Responses API reference.
ContentToolRequest(output$id, output$name, arguments)
} else {
json <- message$content
}
content <- list(ContentJson(json))
} else {
content <- lapply(message$content, as_content)
}
if (has_name(message, "tool_calls")) {
calls <- lapply(message$tool_calls, function(call) {
name <- call$`function`$name
# TODO: record parsing error
args <- tryCatch(
jsonlite::parse_json(call$`function`$arguments),
error = function(cnd) list()
# NOTE(review): browser() looks like a leftover debugging call; remove before
# merging.
# NOTE(review): the abort message below interpolates `content$type`, but the
# enclosing closure's argument is `output`; presumably `{output$type}` was
# intended.
browser()
cli::cli_abort(
"Unknown content type {.str {content$type}}.",
.internal = TRUE
)
ContentToolRequest(name = name, arguments = args, id = call$id)
})
content <- c(content, calls)
}
}
})

# cached_tokens <- result$usage$input_token_details$cached_tokens
tokens <- tokens_log(
provider,
input = result$usage$prompt_tokens,
output = result$usage$completion_tokens
)
assistant_turn(
content,
json = result,
tokens = tokens
input = result$usage$input_tokens,
output = result$usage$output_tokens
)
assistant_turn(contents = contents, json = result, tokens = tokens)
}

# ellmer -> OpenAI --------------------------------------------------------------
Expand All @@ -284,61 +263,56 @@ method(as_json, list(ProviderOpenAI, Turn)) <- function(provider, x) {
list(role = "system", content = x@contents[[1]]@text)
)
} else if (x@role == "user") {
# Each tool result needs to go in its own message with role "tool"
is_tool <- map_lgl(x@contents, S7_inherits, ContentToolResult)
content <- as_json(provider, x@contents[!is_tool])
if (length(content) > 0) {
user <- list(list(role = "user", content = content))
} else {
user <- list()
}

tools <- lapply(x@contents[is_tool], function(tool) {
list(
role = "tool",
content = tool_string(tool),
tool_call_id = tool@request@id
)
lapply(x@contents, function(x) {
if (S7_inherits(x, ContentText)) {
list(role = "user", content = x@text)
} else {
as_json(provider, x)
}
})

c(user, tools)
} else if (x@role == "assistant") {
# Tool requests come out of content and go into own argument
is_tool <- map_lgl(x@contents, is_tool_request)
content <- as_json(provider, x@contents[!is_tool])
tool_calls <- as_json(provider, x@contents[is_tool])

list(
compact(list(
role = "assistant",
content = content,
tool_calls = tool_calls
))
)
as_json(provider, x@contents)
} else {
cli::cli_abort("Unknown role {x@role}", .internal = TRUE)
}
}

# Serialises a ContentText for ProviderOpenAI.
# NOTE(review): this diff rendering shows both the replaced one-line list and
# the replacing list; evaluated as-is, only the last expression is returned.
method(as_json, list(ProviderOpenAI, ContentText)) <- function(provider, x) {
list(type = "text", text = x@text)
# OpenAI uses a different format depending on whether the text is provided
# by the user or generated by the assistant. Since ellmer content types don't
# distinguish, this method generates the assistant content and the Turn
# method special-cases user text (emitting role = "user" itself).
list(
role = "assistant",
content = x@text
)
}

# Serialises a remote image (referenced by `x@url`) for ProviderOpenAI.
# The final expression wraps the image in a user "message" item whose content
# is a single "input_image" entry pointing at the URL.
# NOTE(review): this diff rendering shows both the replaced one-line list and
# the replacing list; evaluated as-is, only the last expression is returned.
method(as_json, list(ProviderOpenAI, ContentImageRemote)) <- function(
provider,
x
) {
list(type = "image_url", image_url = list(url = x@url))
list(
type = "message",
role = "user",
content = list(
list(type = "input_image", image_url = x@url)
)
)
}

# Serialises an inline image for ProviderOpenAI as a data URL built from
# `x@type` and `x@data` ("data:<type>;base64,<data>").
# NOTE(review): this span interleaves pre-change (`type = "image_url"`) and
# post-change ("message" / "input_image") diff lines, so it is not runnable as
# shown.
method(as_json, list(ProviderOpenAI, ContentImageInline)) <- function(
provider,
x
) {
list(
type = "image_url",
image_url = list(
url = paste0("data:", x@type, ";base64,", x@data)
type = "message",
role = "user",
content = list(
list(
type = "input_image",
image_url = paste0("data:", x@type, ";base64,", x@data)
)
)
)
}
Expand All @@ -347,23 +321,32 @@ method(as_json, list(ProviderOpenAI, ContentToolRequest)) <- function(
provider,
x
) {
json_args <- jsonlite::toJSON(x@arguments)
list(
id = x@id,
`function` = list(name = x@name, arguments = json_args),
type = "function"
type = "function_call",
call_id = x@id,
name = x@name,
arguments = jsonlite::toJSON(x@arguments)
)
}

# Serialises a tool result as a "function_call_output" item, linked back to
# the originating request through its id and carrying the result rendered by
# tool_string().
method(as_json, list(ProviderOpenAI, ContentToolResult)) <- function(
  provider,
  x
) {
  request_id <- x@request@id
  rendered <- tool_string(x)

  list(type = "function_call_output", call_id = request_id, output = rendered)
}

# Serialises a tool definition (name, description, strict = TRUE and the
# as_json() form of its arguments) as a "function" tool for ProviderOpenAI.
# NOTE(review): this span interleaves the pre-change nested
# `"function" = compact(list(...))` form with the post-change flat fields, so
# it is not runnable as shown.
method(as_json, list(ProviderOpenAI, ToolDef)) <- function(provider, x) {
list(
type = "function",
"function" = compact(list(
name = x@name,
description = x@description,
strict = TRUE,
parameters = as_json(provider, x@arguments)
))
name = x@name,
description = x@description,
strict = TRUE,
parameters = as_json(provider, x@arguments)
)
}

Expand Down Expand Up @@ -422,7 +405,7 @@ method(batch_submit, ProviderOpenAI) <- function(
list(
custom_id = paste0("chat-", i),
method = "POST",
url = "/v1/chat/completions",
url = "/v1/responses",
body = body
)
})
Expand Down
2 changes: 2 additions & 0 deletions R/provider.R
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,8 @@ stream_parse <- new_generic(
S7_dispatch()
}
)

# Extract text that should be printed to the console
stream_text <- new_generic(
"stream_text",
"provider",
Expand Down
9 changes: 4 additions & 5 deletions man/chat_openai.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading