Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions async-openai/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ secrecy = { version = "0.10.3", features = ["serde"] }
bytes = "1.9.0"
eventsource-stream = "0.2.3"
tokio-tungstenite = { version = "0.26.1", optional = true, default-features = false }
byte_string = "1.0.0"

[dev-dependencies]
tokio-test = "0.4.4"
Expand Down
14 changes: 9 additions & 5 deletions async-openai/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
use std::pin::Pin;
use std::{any::type_name, bstr::ByteString, pin::Pin};

use bytes::Bytes;
use futures::{stream::StreamExt, Stream};
use reqwest::multipart::Form;
use reqwest_eventsource::{Event, EventSource, RequestBuilderExt};
use serde::{de::DeserializeOwned, Serialize};
use tracing::warn;

use crate::{
config::{Config, OpenAIConfig},
Expand Down Expand Up @@ -359,7 +360,7 @@ impl<C: Config> Client<C> {
// Deserialize response body from either error object or actual response object
if !status.is_success() {
let wrapped_error: WrappedError = serde_json::from_slice(bytes.as_ref())
.map_err(|e| map_deserialization_error(e, bytes.as_ref()))
.map_err(|e| map_deserialization_error(e, ByteString(bytes.to_vec())))
.map_err(backoff::Error::Permanent)?;

if status.as_u16() == 429
Expand Down Expand Up @@ -399,7 +400,7 @@ impl<C: Config> Client<C> {
let bytes = self.execute_raw(request_maker).await?;

let response: O = serde_json::from_slice(bytes.as_ref())
.map_err(|e| map_deserialization_error(e, bytes.as_ref()))?;
.map_err(|e| map_deserialization_error(e, ByteString(bytes.to_vec())))?;

Ok(response)
}
Expand Down Expand Up @@ -495,9 +496,12 @@ where
if message.data == "[DONE]" {
break;
}

let response = match serde_json::from_str::<O>(&message.data) {
Err(e) => Err(map_deserialization_error(e, message.data.as_bytes())),
Err(e) => {
warn!("Deserializing error {:?} to type {}", message.data, type_name::<O>());
Err(map_deserialization_error(e, ByteString(message.data.as_bytes().to_vec()))) // Convert String to Vec<u8>
},
Ok(output) => Ok(output),
};

Expand Down
16 changes: 8 additions & 8 deletions async-openai/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//! Errors originating from API calls, parsing responses, and reading-or-writing to the file system.
use serde::{Deserialize, Serialize};
use std::bstr::ByteString;

#[derive(Debug, thiserror::Error)]
pub enum OpenAIError {
Expand All @@ -11,7 +12,9 @@ pub enum OpenAIError {
ApiError(ApiError),
/// Error when a response cannot be deserialized into a Rust type
#[error("failed to deserialize api response: {0}")]
JSONDeserialize(serde_json::Error),
// ByteString is used Because if you use Vec<u8> You get spammed with numbers in terminal. Huge loss.
// there is https://crates.io/crates/byte_string if you dont want unstable features.
JSONDeserialize(serde_json::Error, ByteString),
/// Error on the client side when saving file to file system
#[error("failed to save file: {0}")]
FileSaveError(String),
Expand Down Expand Up @@ -67,10 +70,7 @@ pub struct WrappedError {
pub error: ApiError,
}

pub(crate) fn map_deserialization_error(e: serde_json::Error, bytes: &[u8]) -> OpenAIError {
tracing::error!(
"failed deserialization of: {}",
String::from_utf8_lossy(bytes)
);
OpenAIError::JSONDeserialize(e)
}
pub(crate) fn map_deserialization_error(e: serde_json::Error, string: ByteString) -> OpenAIError {
tracing::error!("failed deserialization {:?}", e);
OpenAIError::JSONDeserialize(e, string)
}
2 changes: 2 additions & 0 deletions async-openai/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
#![feature(bstr)]

//! Rust library for OpenAI
//!
//! ## Creating client
Expand Down
50 changes: 25 additions & 25 deletions async-openai/src/types/assistant_stream.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::pin::Pin;
use std::{bstr::ByteString, pin::Pin};

use futures::Stream;
use serde::Deserialize;
Expand Down Expand Up @@ -118,92 +118,92 @@ impl TryFrom<eventsource_stream::Event> for AssistantStreamEvent {
fn try_from(value: eventsource_stream::Event) -> Result<Self, Self::Error> {
match value.event.as_str() {
"thread.created" => serde_json::from_str::<ThreadObject>(value.data.as_str())
.map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
.map_err(|e| map_deserialization_error(e, ByteString(value.data.as_bytes().to_vec())))
.map(AssistantStreamEvent::ThreadCreated),
"thread.run.created" => serde_json::from_str::<RunObject>(value.data.as_str())
.map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
.map_err(|e| map_deserialization_error(e, ByteString(value.data.as_bytes().to_vec())))
.map(AssistantStreamEvent::ThreadRunCreated),
"thread.run.queued" => serde_json::from_str::<RunObject>(value.data.as_str())
.map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
.map_err(|e| map_deserialization_error(e, ByteString(value.data.as_bytes().to_vec())))
.map(AssistantStreamEvent::ThreadRunQueued),
"thread.run.in_progress" => serde_json::from_str::<RunObject>(value.data.as_str())
.map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
.map_err(|e| map_deserialization_error(e, ByteString(value.data.as_bytes().to_vec())))
.map(AssistantStreamEvent::ThreadRunInProgress),
"thread.run.requires_action" => serde_json::from_str::<RunObject>(value.data.as_str())
.map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
.map_err(|e| map_deserialization_error(e, ByteString(value.data.as_bytes().to_vec())))
.map(AssistantStreamEvent::ThreadRunRequiresAction),
"thread.run.completed" => serde_json::from_str::<RunObject>(value.data.as_str())
.map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
.map_err(|e| map_deserialization_error(e, ByteString(value.data.as_bytes().to_vec())))
.map(AssistantStreamEvent::ThreadRunCompleted),
"thread.run.incomplete" => serde_json::from_str::<RunObject>(value.data.as_str())
.map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
.map_err(|e| map_deserialization_error(e, ByteString(value.data.as_bytes().to_vec())))
.map(AssistantStreamEvent::ThreadRunIncomplete),
"thread.run.failed" => serde_json::from_str::<RunObject>(value.data.as_str())
.map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
.map_err(|e| map_deserialization_error(e, ByteString(value.data.as_bytes().to_vec())))
.map(AssistantStreamEvent::ThreadRunFailed),
"thread.run.cancelling" => serde_json::from_str::<RunObject>(value.data.as_str())
.map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
.map_err(|e| map_deserialization_error(e, ByteString(value.data.as_bytes().to_vec())))
.map(AssistantStreamEvent::ThreadRunCancelling),
"thread.run.cancelled" => serde_json::from_str::<RunObject>(value.data.as_str())
.map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
.map_err(|e| map_deserialization_error(e, ByteString(value.data.as_bytes().to_vec())))
.map(AssistantStreamEvent::ThreadRunCancelled),
"thread.run.expired" => serde_json::from_str::<RunObject>(value.data.as_str())
.map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
.map_err(|e| map_deserialization_error(e, ByteString(value.data.as_bytes().to_vec())))
.map(AssistantStreamEvent::ThreadRunExpired),
"thread.run.step.created" => serde_json::from_str::<RunStepObject>(value.data.as_str())
.map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
.map_err(|e| map_deserialization_error(e, ByteString(value.data.as_bytes().to_vec())))
.map(AssistantStreamEvent::ThreadRunStepCreated),
"thread.run.step.in_progress" => {
serde_json::from_str::<RunStepObject>(value.data.as_str())
.map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
.map_err(|e| map_deserialization_error(e, ByteString(value.data.as_bytes().to_vec())))
.map(AssistantStreamEvent::ThreadRunStepInProgress)
}
"thread.run.step.delta" => {
serde_json::from_str::<RunStepDeltaObject>(value.data.as_str())
.map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
.map_err(|e| map_deserialization_error(e, ByteString(value.data.as_bytes().to_vec())))
.map(AssistantStreamEvent::ThreadRunStepDelta)
}
"thread.run.step.completed" => {
serde_json::from_str::<RunStepObject>(value.data.as_str())
.map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
.map_err(|e| map_deserialization_error(e, ByteString(value.data.as_bytes().to_vec())))
.map(AssistantStreamEvent::ThreadRunStepCompleted)
}
"thread.run.step.failed" => serde_json::from_str::<RunStepObject>(value.data.as_str())
.map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
.map_err(|e| map_deserialization_error(e, ByteString(value.data.as_bytes().to_vec())))
.map(AssistantStreamEvent::ThreadRunStepFailed),
"thread.run.step.cancelled" => {
serde_json::from_str::<RunStepObject>(value.data.as_str())
.map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
.map_err(|e| map_deserialization_error(e, ByteString(value.data.as_bytes().to_vec())))
.map(AssistantStreamEvent::ThreadRunStepCancelled)
}
"thread.run.step.expired" => serde_json::from_str::<RunStepObject>(value.data.as_str())
.map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
.map_err(|e| map_deserialization_error(e, ByteString(value.data.as_bytes().to_vec())))
.map(AssistantStreamEvent::ThreadRunStepExpired),
"thread.message.created" => serde_json::from_str::<MessageObject>(value.data.as_str())
.map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
.map_err(|e| map_deserialization_error(e, ByteString(value.data.as_bytes().to_vec())))
.map(AssistantStreamEvent::ThreadMessageCreated),
"thread.message.in_progress" => {
serde_json::from_str::<MessageObject>(value.data.as_str())
.map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
.map_err(|e| map_deserialization_error(e, ByteString(value.data.as_bytes().to_vec())))
.map(AssistantStreamEvent::ThreadMessageInProgress)
}
"thread.message.delta" => {
serde_json::from_str::<MessageDeltaObject>(value.data.as_str())
.map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
.map_err(|e| map_deserialization_error(e, ByteString(value.data.as_bytes().to_vec())))
.map(AssistantStreamEvent::ThreadMessageDelta)
}
"thread.message.completed" => {
serde_json::from_str::<MessageObject>(value.data.as_str())
.map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
.map_err(|e| map_deserialization_error(e, ByteString(value.data.as_bytes().to_vec())))
.map(AssistantStreamEvent::ThreadMessageCompleted)
}
"thread.message.incomplete" => {
serde_json::from_str::<MessageObject>(value.data.as_str())
.map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
.map_err(|e| map_deserialization_error(e, ByteString(value.data.as_bytes().to_vec())))
.map(AssistantStreamEvent::ThreadMessageIncomplete)
}
"error" => serde_json::from_str::<ApiError>(value.data.as_str())
.map_err(|e| map_deserialization_error(e, value.data.as_bytes()))
.map_err(|e| map_deserialization_error(e, ByteString(value.data.as_bytes().to_vec())))
.map(AssistantStreamEvent::ErrorEvent),
"done" => Ok(AssistantStreamEvent::Done(value.data)),

Expand Down
1 change: 1 addition & 0 deletions async-openai/src/types/completion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ pub struct CreateCompletionRequest {
pub struct CreateCompletionResponse {
/// A unique identifier for the completion.
pub id: String,

pub choices: Vec<Choice>,
/// The Unix timestamp (in seconds) of when the completion was created.
pub created: u32,
Expand Down