Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 77 additions & 48 deletions async-openai/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,13 @@ use std::pin::Pin;

use bytes::Bytes;
use futures::{stream::StreamExt, Stream};
use reqwest::multipart::Form;
use reqwest_eventsource::{Event, EventSource, RequestBuilderExt};
use reqwest::{multipart::Form, Response};
use reqwest_eventsource::{Error as EventSourceError, Event, EventSource, RequestBuilderExt};
use serde::{de::DeserializeOwned, Serialize};

use crate::{
config::{Config, OpenAIConfig},
error::{map_deserialization_error, ApiError, OpenAIError, WrappedError},
error::{map_deserialization_error, ApiError, OpenAIError, StreamError, WrappedError},
file::Files,
image::Images,
moderation::Moderations,
Expand Down Expand Up @@ -335,52 +335,34 @@ impl<C: Config> Client<C> {
.map_err(backoff::Error::Permanent)?;

let status = response.status();
let bytes = response
.bytes()
.await
.map_err(OpenAIError::Reqwest)
.map_err(backoff::Error::Permanent)?;

if status.is_server_error() {
// OpenAI does not guarantee server errors are returned as JSON so we cannot deserialize them.
let message: String = String::from_utf8_lossy(&bytes).into_owned();
tracing::warn!("Server error: {status} - {message}");
return Err(backoff::Error::Transient {
err: OpenAIError::ApiError(ApiError {
message,
r#type: None,
param: None,
code: None,
}),
retry_after: None,
});
}

// Deserialize response body from either error object or actual response object
if !status.is_success() {
let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
.map_err(|e| map_deserialization_error(e, bytes.as_ref()))
.map_err(backoff::Error::Permanent)?;

if status.as_u16() == 429
// API returns 429 also when:
// "You exceeded your current quota, please check your plan and billing details."
&& wrapped_error.error.r#type != Some("insufficient_quota".to_string())
{
// Rate limited retry...
tracing::warn!("Rate limited: {}", wrapped_error.error.message);
return Err(backoff::Error::Transient {
err: OpenAIError::ApiError(wrapped_error.error),
retry_after: None,
});
} else {
return Err(backoff::Error::Permanent(OpenAIError::ApiError(
wrapped_error.error,
)));
match read_response(response).await {
Ok(bytes) => Ok(bytes),
Err(e) => {
match e {
OpenAIError::ApiError(api_error) => {
if status.is_server_error() {
Err(backoff::Error::Transient {
err: OpenAIError::ApiError(api_error),
retry_after: None,
})
} else if status.as_u16() == 429
&& api_error.r#type != Some("insufficient_quota".to_string())
{
// Rate limited retry...
tracing::warn!("Rate limited: {}", api_error.message);
Err(backoff::Error::Transient {
err: OpenAIError::ApiError(api_error),
retry_after: None,
})
} else {
Err(backoff::Error::Permanent(OpenAIError::ApiError(api_error)))
}
}
_ => Err(backoff::Error::Permanent(e)),
}
}
}

Ok(bytes)
})
.await
}
Expand Down Expand Up @@ -471,6 +453,53 @@ impl<C: Config> Client<C> {
}
}

async fn read_response(response: Response) -> Result<Bytes, OpenAIError> {
let status = response.status();
let bytes = response.bytes().await.map_err(OpenAIError::Reqwest)?;

if status.is_server_error() {
// OpenAI does not guarantee server errors are returned as JSON so we cannot deserialize them.
let message: String = String::from_utf8_lossy(&bytes).into_owned();
tracing::warn!("Server error: {status} - {message}");
return Err(OpenAIError::ApiError(ApiError {
message,
r#type: None,
param: None,
code: None,
}));
}

// Deserialize response body from either error object or actual response object
if !status.is_success() {
let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
.map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;

return Err(OpenAIError::ApiError(wrapped_error.error));
}

Ok(bytes)
}

async fn map_stream_error(value: EventSourceError) -> OpenAIError {
match value {
EventSourceError::Parser(e) => OpenAIError::StreamError(StreamError::Parser(e.to_string())),
EventSourceError::InvalidContentType(e, response) => {
OpenAIError::StreamError(StreamError::InvalidContentType(e, response))
}
EventSourceError::InvalidLastEventId(e) => {
OpenAIError::StreamError(StreamError::InvalidLastEventId(e))
}
EventSourceError::StreamEnded => OpenAIError::StreamError(StreamError::StreamEnded),
EventSourceError::Utf8(e) => OpenAIError::StreamError(StreamError::Utf8(e)),
EventSourceError::Transport(error) => OpenAIError::Reqwest(error),
EventSourceError::InvalidStatusCode(_status_code, response) => {
read_response(response).await.expect_err(
"Unreachable because read_response returns err when status_code is invalid",
)
}
}
}

/// Request which responds with SSE.
/// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format)
pub(crate) async fn stream<O>(
Expand All @@ -485,7 +514,7 @@ where
while let Some(ev) = event_source.next().await {
match ev {
Err(e) => {
if let Err(_e) = tx.send(Err(OpenAIError::StreamError(e.to_string()))) {
if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
// rx dropped
break;
}
Expand Down Expand Up @@ -530,7 +559,7 @@ where
while let Some(ev) = event_source.next().await {
match ev {
Err(e) => {
if let Err(_e) = tx.send(Err(OpenAIError::StreamError(e.to_string()))) {
if let Err(_e) = tx.send(Err(map_stream_error(e).await)) {
// rx dropped
break;
}
Expand Down
28 changes: 27 additions & 1 deletion async-openai/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
//! Errors originating from API calls, parsing responses, and reading-or-writing to the file system.
use std::string::FromUtf8Error;

use reqwest::{header::HeaderValue, Response};
use serde::{Deserialize, Serialize};


#[derive(Debug, thiserror::Error)]
pub enum OpenAIError {
/// Underlying error from reqwest library after an API call was made
Expand All @@ -20,13 +24,35 @@ pub enum OpenAIError {
FileReadError(String),
/// Error on SSE streaming
#[error("stream failed: {0}")]
StreamError(String),
StreamError(StreamError),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! Since you always map the inner eventsource error anyway, It would simplify the code quite a bit I think to wrap StreamError(EventSourceError) instead with #[from].

If I'm not mistaken, that would remove map_stream_error instead as we can use Into, remove the intermediate error, and avoid reading out the bytes of the response early.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can't because we need to determine what OpenAIError is wrapped inside the EventSourceError::InvalidStatusCode and as a bonus we also map the EventSourceError::Transport to OpenAIError::Reqwest.

/// Error from client side validation
/// or when builder fails to build request before making API call
#[error("invalid args: {0}")]
InvalidArgument(String),
}

#[derive(Debug, thiserror::Error)]
pub enum StreamError {
/// Source stream is not valid UTF8
#[error(transparent)]
Utf8(FromUtf8Error),
/// Source stream is not a valid EventStream
#[error("Source stream is not a valid event stream: {0}")]
Parser(String),
/// The `Content-Type` returned by the server is invalid
#[error("Invalid content type for event stream: {0:?}")]
InvalidContentType(HeaderValue, Response),
/// The `Last-Event-ID` cannot be formed into a Header to be submitted to the server
#[error("Invalid `Last-Event-ID` for event stream: {0}")]
InvalidLastEventId(String),
/// The server sent an unrecognized event type
#[error("Unrecognized event type: {0}")]
UnrecognizedEventType(String),
/// The stream ended
#[error("Stream ended")]
StreamEnded,
}

/// OpenAI API returns error object on failure
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ApiError {
Expand Down
4 changes: 2 additions & 2 deletions async-openai/src/types/assistant_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::pin::Pin;
use futures::Stream;
use serde::Deserialize;

use crate::error::{map_deserialization_error, ApiError, OpenAIError};
use crate::error::{map_deserialization_error, ApiError, OpenAIError, StreamError};

use super::{
MessageDeltaObject, MessageObject, RunObject, RunStepDeltaObject, RunStepObject, ThreadObject,
Expand Down Expand Up @@ -208,7 +208,7 @@ impl TryFrom<eventsource_stream::Event> for AssistantStreamEvent {
"done" => Ok(AssistantStreamEvent::Done(value.data)),

_ => Err(OpenAIError::StreamError(
"Unrecognized event: {value:?#}".into(),
StreamError::UnrecognizedEventType(value.event),
)),
}
}
Expand Down