diff --git a/async-openai/Cargo.toml b/async-openai/Cargo.toml index 99d8da8b..6b42a0ae 100644 --- a/async-openai/Cargo.toml +++ b/async-openai/Cargo.toml @@ -40,7 +40,7 @@ reqwest = { version = "0.12.12", features = [ reqwest-eventsource = "0.6.0" serde = { version = "1.0.217", features = ["derive", "rc"] } serde_json = "1.0.135" -thiserror = "2.0.11" +thiserror = "2.*" tokio = { version = "1.43.0", features = ["fs", "macros"] } tokio-stream = "0.1.17" tokio-util = { version = "0.7.13", features = ["codec", "io-util"] } @@ -50,6 +50,7 @@ secrecy = { version = "0.10.3", features = ["serde"] } bytes = "1.9.0" eventsource-stream = "0.2.3" tokio-tungstenite = { version = "0.26.1", optional = true, default-features = false } +byte_string = "1.0.0" [dev-dependencies] tokio-test = "0.4.4" diff --git a/async-openai/src/client.rs b/async-openai/src/client.rs index fe2ed232..8405896f 100644 --- a/async-openai/src/client.rs +++ b/async-openai/src/client.rs @@ -1,10 +1,11 @@ -use std::pin::Pin; +use std::{any::type_name, bstr::ByteString, pin::Pin}; use bytes::Bytes; use futures::{stream::StreamExt, Stream}; use reqwest::multipart::Form; use reqwest_eventsource::{Event, EventSource, RequestBuilderExt}; use serde::{de::DeserializeOwned, Serialize}; +use tracing::warn; use crate::{ config::{Config, OpenAIConfig}, @@ -359,7 +360,7 @@ impl Client { // Deserialize response body from either error object or actual response object if !status.is_success() { let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref()) - .map_err(|e| map_deserialization_error(e, bytes.as_ref())) + .map_err(|e| map_deserialization_error(e, ByteString(bytes.to_vec()))) .map_err(backoff::Error::Permanent)?; if status.as_u16() == 429 @@ -399,7 +400,7 @@ impl Client { let bytes = self.execute_raw(request_maker).await?; let response: O = serde_json::from_slice(bytes.as_ref()) - .map_err(|e| map_deserialization_error(e, bytes.as_ref()))?; + .map_err(|e| map_deserialization_error(e, ByteString(bytes.to_vec())))?; Ok(response) } @@ -495,9 +496,12 @@ where if message.data == "[DONE]" { break; } - + let response = match serde_json::from_str::(&message.data) { - Err(e) => Err(map_deserialization_error(e, message.data.as_bytes())), + Err(e) => { + warn!("Deserializing error {:?} to type {}", message.data, type_name::()); + Err(map_deserialization_error(e, ByteString(message.data.as_bytes().to_vec()))) // Convert String to Vec + }, Ok(output) => Ok(output), }; diff --git a/async-openai/src/config.rs b/async-openai/src/config.rs index 82ab043c..ce0085ab 100644 --- a/async-openai/src/config.rs +++ b/async-openai/src/config.rs @@ -1,7 +1,7 @@ //! Client configurations: [OpenAIConfig] for OpenAI, [AzureConfig] for Azure OpenAI Service. use reqwest::header::{HeaderMap, AUTHORIZATION}; -use secrecy::{ExposeSecret, SecretString}; -use serde::Deserialize; +use secrecy::{ExposeSecret, SecretString, SerializableSecret}; +use serde::{Deserialize, Serialize}; /// Default v1 API base url pub const OPENAI_API_BASE: &str = "https://api.openai.com/v1"; @@ -61,6 +61,36 @@ pub struct OpenAIConfig { project_id: String, } +#[derive(Clone, Debug, Serialize, Deserialize, Hash)] +pub struct OpenAIConfigSerde { + api_base: String, + api_key: String, + org_id: String, + project_id: String, +} + +impl From for OpenAIConfigSerde { + fn from(config: OpenAIConfig) -> Self { + Self { + api_base: config.api_base, + api_key: config.api_key.expose_secret().to_string(), + org_id: config.org_id, + project_id: config.project_id, + } + } +} + +impl From for OpenAIConfig { + fn from(config: OpenAIConfigSerde) -> Self { + Self { + api_base: config.api_base, + api_key: SecretString::from(config.api_key), + org_id: config.org_id, + project_id: config.project_id, + } + } +} + impl Default for OpenAIConfig { fn default() -> Self { Self { diff --git a/async-openai/src/error.rs b/async-openai/src/error.rs index a1139c9f..051325d3 100644 --- a/async-openai/src/error.rs +++ b/async-openai/src/error.rs @@ -1,5 +1,6 @@ //! Errors originating from API calls, parsing responses, and reading-or-writing to the file system. use serde::{Deserialize, Serialize}; +use std::bstr::ByteString; #[derive(Debug, thiserror::Error)] pub enum OpenAIError { @@ -11,7 +12,9 @@ pub enum OpenAIError { ApiError(ApiError), /// Error when a response cannot be deserialized into a Rust type #[error("failed to deserialize api response: {0}")] - JSONDeserialize(serde_json::Error), + // ByteString is used Because if you use Vec You get spammed with numbers in terminal. Huge loss. + // there is https://crates.io/crates/byte_string if you dont want unstable features. + JSONDeserialize(serde_json::Error, ByteString), /// Error on the client side when saving file to file system #[error("failed to save file: {0}")] FileSaveError(String), @@ -67,10 +70,7 @@ pub struct WrappedError { pub error: ApiError, } -pub(crate) fn map_deserialization_error(e: serde_json::Error, bytes: &[u8]) -> OpenAIError { - tracing::error!( - "failed deserialization of: {}", - String::from_utf8_lossy(bytes) - ); - OpenAIError::JSONDeserialize(e) -} +pub(crate) fn map_deserialization_error(e: serde_json::Error, string: ByteString) -> OpenAIError { + tracing::error!("failed deserialization {:?}", e); + OpenAIError::JSONDeserialize(e, string) +} \ No newline at end of file diff --git a/async-openai/src/lib.rs b/async-openai/src/lib.rs index c94bc495..b9c31c07 100644 --- a/async-openai/src/lib.rs +++ b/async-openai/src/lib.rs @@ -1,3 +1,5 @@ +#![feature(bstr)] + //! Rust library for OpenAI //! //! ## Creating client diff --git a/async-openai/src/types/assistant_stream.rs b/async-openai/src/types/assistant_stream.rs index fca835cf..2268d608 100644 --- a/async-openai/src/types/assistant_stream.rs +++ b/async-openai/src/types/assistant_stream.rs @@ -1,4 +1,4 @@ -use std::pin::Pin; +use std::{bstr::ByteString, pin::Pin}; use futures::Stream; use serde::Deserialize; @@ -118,92 +118,92 @@ impl TryFrom for AssistantStreamEvent { fn try_from(value: eventsource_stream::Event) -> Result { match value.event.as_str() { "thread.created" => serde_json::from_str::(value.data.as_str()) - .map_err(|e| map_deserialization_error(e, value.data.as_bytes())) + .map_err(|e| map_deserialization_error(e, ByteString(value.data.as_bytes().to_vec()))) .map(AssistantStreamEvent::ThreadCreated), "thread.run.created" => serde_json::from_str::(value.data.as_str()) - .map_err(|e| map_deserialization_error(e, value.data.as_bytes())) + .map_err(|e| map_deserialization_error(e, ByteString(value.data.as_bytes().to_vec()))) .map(AssistantStreamEvent::ThreadRunCreated), "thread.run.queued" => serde_json::from_str::(value.data.as_str()) - .map_err(|e| map_deserialization_error(e, value.data.as_bytes())) + .map_err(|e| map_deserialization_error(e, ByteString(value.data.as_bytes().to_vec()))) .map(AssistantStreamEvent::ThreadRunQueued), "thread.run.in_progress" => serde_json::from_str::(value.data.as_str()) - .map_err(|e| map_deserialization_error(e, value.data.as_bytes())) + .map_err(|e| map_deserialization_error(e, ByteString(value.data.as_bytes().to_vec()))) .map(AssistantStreamEvent::ThreadRunInProgress), "thread.run.requires_action" => serde_json::from_str::(value.data.as_str()) - .map_err(|e| map_deserialization_error(e, value.data.as_bytes())) + .map_err(|e| map_deserialization_error(e, ByteString(value.data.as_bytes().to_vec()))) .map(AssistantStreamEvent::ThreadRunRequiresAction), "thread.run.completed" => serde_json::from_str::(value.data.as_str()) - .map_err(|e| map_deserialization_error(e, value.data.as_bytes())) + .map_err(|e| map_deserialization_error(e, ByteString(value.data.as_bytes().to_vec()))) .map(AssistantStreamEvent::ThreadRunCompleted), "thread.run.incomplete" => serde_json::from_str::(value.data.as_str()) - .map_err(|e| map_deserialization_error(e, value.data.as_bytes())) + .map_err(|e| map_deserialization_error(e, ByteString(value.data.as_bytes().to_vec()))) .map(AssistantStreamEvent::ThreadRunIncomplete), "thread.run.failed" => serde_json::from_str::(value.data.as_str()) - .map_err(|e| map_deserialization_error(e, value.data.as_bytes())) + .map_err(|e| map_deserialization_error(e, ByteString(value.data.as_bytes().to_vec()))) .map(AssistantStreamEvent::ThreadRunFailed), "thread.run.cancelling" => serde_json::from_str::(value.data.as_str()) - .map_err(|e| map_deserialization_error(e, value.data.as_bytes())) + .map_err(|e| map_deserialization_error(e, ByteString(value.data.as_bytes().to_vec()))) .map(AssistantStreamEvent::ThreadRunCancelling), "thread.run.cancelled" => serde_json::from_str::(value.data.as_str()) - .map_err(|e| map_deserialization_error(e, value.data.as_bytes())) + .map_err(|e| map_deserialization_error(e, ByteString(value.data.as_bytes().to_vec()))) .map(AssistantStreamEvent::ThreadRunCancelled), "thread.run.expired" => serde_json::from_str::(value.data.as_str()) - .map_err(|e| map_deserialization_error(e, value.data.as_bytes())) + .map_err(|e| map_deserialization_error(e, ByteString(value.data.as_bytes().to_vec()))) .map(AssistantStreamEvent::ThreadRunExpired), "thread.run.step.created" => serde_json::from_str::(value.data.as_str()) - .map_err(|e| map_deserialization_error(e, value.data.as_bytes())) + .map_err(|e| map_deserialization_error(e, ByteString(value.data.as_bytes().to_vec()))) .map(AssistantStreamEvent::ThreadRunStepCreated), "thread.run.step.in_progress" => { serde_json::from_str::(value.data.as_str()) - .map_err(|e| map_deserialization_error(e, value.data.as_bytes())) + .map_err(|e| map_deserialization_error(e, ByteString(value.data.as_bytes().to_vec()))) .map(AssistantStreamEvent::ThreadRunStepInProgress) } "thread.run.step.delta" => { serde_json::from_str::(value.data.as_str()) - .map_err(|e| map_deserialization_error(e, value.data.as_bytes())) + .map_err(|e| map_deserialization_error(e, ByteString(value.data.as_bytes().to_vec()))) .map(AssistantStreamEvent::ThreadRunStepDelta) } "thread.run.step.completed" => { serde_json::from_str::(value.data.as_str()) - .map_err(|e| map_deserialization_error(e, value.data.as_bytes())) + .map_err(|e| map_deserialization_error(e, ByteString(value.data.as_bytes().to_vec()))) .map(AssistantStreamEvent::ThreadRunStepCompleted) } "thread.run.step.failed" => serde_json::from_str::(value.data.as_str()) - .map_err(|e| map_deserialization_error(e, value.data.as_bytes())) + .map_err(|e| map_deserialization_error(e, ByteString(value.data.as_bytes().to_vec()))) .map(AssistantStreamEvent::ThreadRunStepFailed), "thread.run.step.cancelled" => { serde_json::from_str::(value.data.as_str()) - .map_err(|e| map_deserialization_error(e, value.data.as_bytes())) + .map_err(|e| map_deserialization_error(e, ByteString(value.data.as_bytes().to_vec()))) .map(AssistantStreamEvent::ThreadRunStepCancelled) } "thread.run.step.expired" => serde_json::from_str::(value.data.as_str()) - .map_err(|e| map_deserialization_error(e, value.data.as_bytes())) + .map_err(|e| map_deserialization_error(e, ByteString(value.data.as_bytes().to_vec()))) .map(AssistantStreamEvent::ThreadRunStepExpired), "thread.message.created" => serde_json::from_str::(value.data.as_str()) - .map_err(|e| map_deserialization_error(e, value.data.as_bytes())) + .map_err(|e| map_deserialization_error(e, ByteString(value.data.as_bytes().to_vec()))) .map(AssistantStreamEvent::ThreadMessageCreated), "thread.message.in_progress" => { serde_json::from_str::(value.data.as_str()) - .map_err(|e| map_deserialization_error(e, value.data.as_bytes())) + .map_err(|e| map_deserialization_error(e, ByteString(value.data.as_bytes().to_vec()))) .map(AssistantStreamEvent::ThreadMessageInProgress) } "thread.message.delta" => { serde_json::from_str::(value.data.as_str()) - .map_err(|e| map_deserialization_error(e, value.data.as_bytes())) + .map_err(|e| map_deserialization_error(e, ByteString(value.data.as_bytes().to_vec()))) .map(AssistantStreamEvent::ThreadMessageDelta) } "thread.message.completed" => { serde_json::from_str::(value.data.as_str()) - .map_err(|e| map_deserialization_error(e, value.data.as_bytes())) + .map_err(|e| map_deserialization_error(e, ByteString(value.data.as_bytes().to_vec()))) .map(AssistantStreamEvent::ThreadMessageCompleted) } "thread.message.incomplete" => { serde_json::from_str::(value.data.as_str()) - .map_err(|e| map_deserialization_error(e, value.data.as_bytes())) + .map_err(|e| map_deserialization_error(e, ByteString(value.data.as_bytes().to_vec()))) .map(AssistantStreamEvent::ThreadMessageIncomplete) } "error" => serde_json::from_str::(value.data.as_str()) - .map_err(|e| map_deserialization_error(e, value.data.as_bytes())) + .map_err(|e| map_deserialization_error(e, ByteString(value.data.as_bytes().to_vec()))) .map(AssistantStreamEvent::ErrorEvent), "done" => Ok(AssistantStreamEvent::Done(value.data)), diff --git a/async-openai/src/types/completion.rs b/async-openai/src/types/completion.rs index 3c15dd6b..6c263f3d 100644 --- a/async-openai/src/types/completion.rs +++ b/async-openai/src/types/completion.rs @@ -119,6 +119,7 @@ pub struct CreateCompletionRequest { pub struct CreateCompletionResponse { /// A unique identifier for the completion. pub id: String, + pub choices: Vec, /// The Unix timestamp (in seconds) of when the completion was created. pub created: u32,