diff --git a/async-openai/src/client.rs b/async-openai/src/client.rs index fe2ed232..877d9c26 100644 --- a/async-openai/src/client.rs +++ b/async-openai/src/client.rs @@ -2,13 +2,13 @@ use std::pin::Pin; use bytes::Bytes; use futures::{stream::StreamExt, Stream}; -use reqwest::multipart::Form; -use reqwest_eventsource::{Event, EventSource, RequestBuilderExt}; +use reqwest::{multipart::Form, Response}; +use reqwest_eventsource::{Error as EventSourceError, Event, EventSource, RequestBuilderExt}; use serde::{de::DeserializeOwned, Serialize}; use crate::{ config::{Config, OpenAIConfig}, - error::{map_deserialization_error, ApiError, OpenAIError, WrappedError}, + error::{map_deserialization_error, ApiError, OpenAIError, StreamError, WrappedError}, file::Files, image::Images, moderation::Moderations, @@ -335,52 +335,34 @@ impl Client { .map_err(backoff::Error::Permanent)?; let status = response.status(); - let bytes = response - .bytes() - .await - .map_err(OpenAIError::Reqwest) - .map_err(backoff::Error::Permanent)?; - if status.is_server_error() { - // OpenAI does not guarantee server errors are returned as JSON so we cannot deserialize them. - let message: String = String::from_utf8_lossy(&bytes).into_owned(); - tracing::warn!("Server error: {status} - {message}"); - return Err(backoff::Error::Transient { - err: OpenAIError::ApiError(ApiError { - message, - r#type: None, - param: None, - code: None, - }), - retry_after: None, - }); - } - - // Deserialize response body from either error object or actual response object - if !status.is_success() { - let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref()) - .map_err(|e| map_deserialization_error(e, bytes.as_ref())) - .map_err(backoff::Error::Permanent)?; - - if status.as_u16() == 429 - // API returns 429 also when: - // "You exceeded your current quota, please check your plan and billing details." - && wrapped_error.error.r#type != Some("insufficient_quota".to_string()) - { - // Rate limited retry... - tracing::warn!("Rate limited: {}", wrapped_error.error.message); - return Err(backoff::Error::Transient { - err: OpenAIError::ApiError(wrapped_error.error), - retry_after: None, - }); - } else { - return Err(backoff::Error::Permanent(OpenAIError::ApiError( - wrapped_error.error, - ))); + match read_response(response).await { + Ok(bytes) => Ok(bytes), + Err(e) => { + match e { + OpenAIError::ApiError(api_error) => { + if status.is_server_error() { + Err(backoff::Error::Transient { + err: OpenAIError::ApiError(api_error), + retry_after: None, + }) + } else if status.as_u16() == 429 + && api_error.r#type != Some("insufficient_quota".to_string()) + { + // Rate limited retry... + tracing::warn!("Rate limited: {}", api_error.message); + Err(backoff::Error::Transient { + err: OpenAIError::ApiError(api_error), + retry_after: None, + }) + } else { + Err(backoff::Error::Permanent(OpenAIError::ApiError(api_error))) + } + } + _ => Err(backoff::Error::Permanent(e)), + } } } - - Ok(bytes) }) .await } @@ -471,6 +453,53 @@ impl Client { } } +async fn read_response(response: Response) -> Result { + let status = response.status(); + let bytes = response.bytes().await.map_err(OpenAIError::Reqwest)?; + + if status.is_server_error() { + // OpenAI does not guarantee server errors are returned as JSON so we cannot deserialize them. + let message: String = String::from_utf8_lossy(&bytes).into_owned(); + tracing::warn!("Server error: {status} - {message}"); + return Err(OpenAIError::ApiError(ApiError { + message, + r#type: None, + param: None, + code: None, + })); + } + + // Deserialize response body from either error object or actual response object + if !status.is_success() { + let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref()) + .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?; + + return Err(OpenAIError::ApiError(wrapped_error.error)); + } + + Ok(bytes) +} + +async fn map_stream_error(value: EventSourceError) -> OpenAIError { + match value { + EventSourceError::Parser(e) => OpenAIError::StreamError(StreamError::Parser(e.to_string())), + EventSourceError::InvalidContentType(e, response) => { + OpenAIError::StreamError(StreamError::InvalidContentType(e, response)) + } + EventSourceError::InvalidLastEventId(e) => { + OpenAIError::StreamError(StreamError::InvalidLastEventId(e)) + } + EventSourceError::StreamEnded => OpenAIError::StreamError(StreamError::StreamEnded), + EventSourceError::Utf8(e) => OpenAIError::StreamError(StreamError::Utf8(e)), + EventSourceError::Transport(error) => OpenAIError::Reqwest(error), + EventSourceError::InvalidStatusCode(_status_code, response) => { + read_response(response).await.expect_err( + "Unreachable because read_response returns err when status_code is invalid", + ) + } + } +} + /// Request which responds with SSE. /// [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format) pub(crate) async fn stream( @@ -485,7 +514,7 @@ where while let Some(ev) = event_source.next().await { match ev { Err(e) => { - if let Err(_e) = tx.send(Err(OpenAIError::StreamError(e.to_string()))) { + if let Err(_e) = tx.send(Err(map_stream_error(e).await)) { // rx dropped break; } @@ -530,7 +559,7 @@ where while let Some(ev) = event_source.next().await { match ev { Err(e) => { - if let Err(_e) = tx.send(Err(OpenAIError::StreamError(e.to_string()))) { + if let Err(_e) = tx.send(Err(map_stream_error(e).await)) { // rx dropped break; } diff --git a/async-openai/src/error.rs b/async-openai/src/error.rs index a1139c9f..46b2fc1b 100644 --- a/async-openai/src/error.rs +++ b/async-openai/src/error.rs @@ -1,6 +1,10 @@ //! Errors originating from API calls, parsing responses, and reading-or-writing to the file system. +use std::string::FromUtf8Error; + +use reqwest::{header::HeaderValue, Response}; use serde::{Deserialize, Serialize}; + #[derive(Debug, thiserror::Error)] pub enum OpenAIError { /// Underlying error from reqwest library after an API call was made @@ -20,13 +24,35 @@ pub enum OpenAIError { FileReadError(String), /// Error on SSE streaming #[error("stream failed: {0}")] - StreamError(String), + StreamError(StreamError), /// Error from client side validation /// or when builder fails to build request before making API call #[error("invalid args: {0}")] InvalidArgument(String), } +#[derive(Debug, thiserror::Error)] +pub enum StreamError { + /// Source stream is not valid UTF8 + #[error(transparent)] + Utf8(FromUtf8Error), + /// Source stream is not a valid EventStream + #[error("Source stream is not a valid event stream: {0}")] + Parser(String), + /// The `Content-Type` returned by the server is invalid + #[error("Invalid content type for event stream: {0:?}")] + InvalidContentType(HeaderValue, Response), + /// The `Last-Event-ID` cannot be formed into a Header to be submitted to the server + #[error("Invalid `Last-Event-ID` for event stream: {0}")] + InvalidLastEventId(String), + /// The server sent an unrecognized event type + #[error("Unrecognized event type: {0}")] + UnrecognizedEventType(String), + /// The stream ended + #[error("Stream ended")] + StreamEnded, +} + /// OpenAI API returns error object on failure #[derive(Debug, Serialize, Deserialize, Clone)] pub struct ApiError { diff --git a/async-openai/src/types/assistant_stream.rs b/async-openai/src/types/assistant_stream.rs index 755a322d..d7bec396 100644 --- a/async-openai/src/types/assistant_stream.rs +++ b/async-openai/src/types/assistant_stream.rs @@ -3,7 +3,7 @@ use std::pin::Pin; use futures::Stream; use serde::Deserialize; -use crate::error::{map_deserialization_error, ApiError, OpenAIError}; +use crate::error::{map_deserialization_error, ApiError, OpenAIError, StreamError}; use super::{ MessageDeltaObject, MessageObject, RunObject, RunStepDeltaObject, RunStepObject, ThreadObject, @@ -208,7 +208,7 @@ impl TryFrom for AssistantStreamEvent { "done" => Ok(AssistantStreamEvent::Done(value.data)), _ => Err(OpenAIError::StreamError( - "Unrecognized event: {value:?#}".into(), + StreamError::UnrecognizedEventType(value.event), )), } }