Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion crates/chat-cli/src/cli/chat/server_messenger.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ use tokio::sync::mpsc::{
channel,
};

use super::tools::custom_tool::TransportType;
use crate::mcp_client::messenger::{
Messenger,
MessengerError,
Expand Down Expand Up @@ -50,6 +51,7 @@ pub enum UpdateEventMessage {
},
InitStart {
server_name: String,
transport_type: TransportType,
},
Deinit {
server_name: String,
Expand All @@ -70,9 +72,10 @@ impl ServerMessengerBuilder {
(rx, this)
}

pub fn build_with_name(&self, server_name: String) -> ServerMessenger {
pub fn build(&self, server_name: String, transport_type: TransportType) -> ServerMessenger {
ServerMessenger {
server_name,
transport_type,
update_event_sender: self.update_event_sender.clone(),
}
}
Expand All @@ -81,6 +84,7 @@ impl ServerMessengerBuilder {
#[derive(Clone, Debug)]
pub struct ServerMessenger {
pub server_name: String,
pub transport_type: TransportType,
pub update_event_sender: Sender<UpdateEventMessage>,
}

Expand Down Expand Up @@ -166,6 +170,7 @@ impl Messenger for ServerMessenger {
.update_event_sender
.send(UpdateEventMessage::InitStart {
server_name: self.server_name.clone(),
transport_type: self.transport_type,
})
.await
.map_err(|e| MessengerError::Custom(e.to_string()))?)
Expand Down
53 changes: 37 additions & 16 deletions crates/chat-cli/src/cli/chat/tool_manager.rs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,10 @@ use crate::cli::chat::server_messenger::{
ServerMessengerBuilder,
UpdateEventMessage,
};
use crate::cli::chat::tools::custom_tool::CustomTool;
use crate::cli::chat::tools::custom_tool::{
CustomTool,
TransportType,
};
use crate::cli::chat::tools::execute::ExecuteCommand;
use crate::cli::chat::tools::fs_read::FsRead;
use crate::cli::chat::tools::fs_write::FsWrite;
Expand Down Expand Up @@ -181,7 +184,7 @@ pub struct ToolManagerBuilder {
has_new_stuff: Arc<AtomicBool>,
mcp_load_record: Arc<Mutex<HashMap<String, Vec<LoadingRecord>>>>,
new_tool_specs: NewToolSpecs,
pending_clients: Option<Arc<RwLock<HashSet<String>>>>,
pending_clients: Option<Arc<RwLock<HashMap<String, TransportType>>>>,
is_first_launch: bool,
agent: Option<Arc<Mutex<Agent>>>,
}
Expand Down Expand Up @@ -322,8 +325,12 @@ impl ToolManagerBuilder {
let new_tool_specs = self.new_tool_specs;
let has_new_stuff = self.has_new_stuff;
let pending = self.pending_clients.unwrap_or(Arc::new(RwLock::new({
let mut pending = HashSet::<String>::new();
pending.extend(pre_initialized.iter().map(|(name, _)| name.clone()));
let mut pending = HashMap::<String, TransportType>::new();
pending.extend(
pre_initialized
.iter()
.map(|(name, config)| (name.clone(), config.r#type)),
);
pending
})));
let notify = Arc::new(Notify::new());
Expand Down Expand Up @@ -388,18 +395,20 @@ impl ToolManagerBuilder {
let pre_initialized = enabled_servers
.into_iter()
.map(|(server_name, server_config)| {
let transport_type = server_config.r#type;
(
server_name.clone(),
McpClientService::new(
server_name.clone(),
server_config,
messenger_builder.build_with_name(server_name),
messenger_builder.build(server_name, transport_type),
),
)
})
.collect::<Vec<_>>();

for (mut name, mcp_client) in pre_initialized {
let transport_type = mcp_client.config.r#type;
let init_res = mcp_client.init(os).await;
match init_res {
Ok(mut running_service) => {
Expand All @@ -417,6 +426,7 @@ impl ToolManagerBuilder {
&os.database,
conversation_id.clone(),
name.clone(),
transport_type,
Some(e.to_string()),
0,
Some("".to_string()),
Expand All @@ -426,7 +436,7 @@ impl ToolManagerBuilder {
.await
.ok();

let temp_messenger = messenger_builder.build_with_name(name);
let temp_messenger = messenger_builder.build(name, transport_type);
let _ = temp_messenger
.send_tools_list_result(Err(ServiceError::UnexpectedResponse), None)
.await;
Expand Down Expand Up @@ -554,7 +564,7 @@ pub struct ToolManager {
pub clients: HashMap<String, InitializedMcpClient>,

/// A list of client names that are still in the process of being initialized
pub pending_clients: Arc<RwLock<HashSet<String>>>,
pub pending_clients: Arc<RwLock<HashMap<String, TransportType>>>,

/// Flag indicating whether new tool specifications have been added since the last update.
/// When set to true, it signals that the tool manager needs to refresh its internal state
Expand Down Expand Up @@ -789,7 +799,7 @@ impl ToolManager {
tokio::select! {
_ = timeout_fut => {
if let Some(tx) = tx {
let still_loading = self.pending_clients.read().await.iter().cloned().collect::<Vec<_>>();
let still_loading = self.pending_clients.read().await.keys().cloned().collect::<Vec<_>>();
let _ = tx.send(LoadingMsg::Terminate { still_loading }).await;
if let Some(task) = loading_display_task {
let _ = tokio::time::timeout(
Expand All @@ -810,14 +820,14 @@ impl ToolManager {
},
_ = server_loading_fut => {
if let Some(tx) = tx {
let still_loading = self.pending_clients.read().await.iter().cloned().collect::<Vec<_>>();
let still_loading = self.pending_clients.read().await.keys().cloned().collect::<Vec<_>>();
let _ = tx.send(LoadingMsg::Terminate { still_loading }).await;
}
}
_ = ctrl_c() => {
if self.is_interactive {
if let Some(tx) = tx {
let still_loading = self.pending_clients.read().await.iter().cloned().collect::<Vec<_>>();
let still_loading = self.pending_clients.read().await.keys().cloned().collect::<Vec<_>>();
let _ = tx.send(LoadingMsg::Terminate { still_loading }).await;
}
} else {
Expand Down Expand Up @@ -1105,7 +1115,7 @@ impl ToolManager {
}

pub async fn pending_clients(&self) -> Vec<String> {
self.pending_clients.read().await.iter().cloned().collect::<Vec<_>>()
self.pending_clients.read().await.keys().cloned().collect::<Vec<_>>()
}
}

Expand Down Expand Up @@ -1257,7 +1267,7 @@ fn spawn_orchestrator_task(
mut msg_rx: tokio::sync::mpsc::Receiver<UpdateEventMessage>,
mut prompt_list_receiver: tokio::sync::broadcast::Receiver<PromptQuery>,
mut prompt_list_sender: tokio::sync::broadcast::Sender<PromptQueryResult>,
pending: Arc<RwLock<HashSet<String>>>,
pending: Arc<RwLock<HashMap<String, TransportType>>>,
agent: Arc<Mutex<Agent>>,
database: Database,
regex: Regex,
Expand Down Expand Up @@ -1344,7 +1354,7 @@ fn spawn_orchestrator_task(
msg: UpdateEventMessage,
loading_servers: &mut HashMap<String, Instant>,
record_temp_buf: &mut Vec<u8>,
pending: &Arc<RwLock<HashSet<String>>>,
pending: &Arc<RwLock<HashMap<String, TransportType>>>,
agent: &Arc<Mutex<Agent>>,
database: &Database,
conv_id: &str,
Expand Down Expand Up @@ -1376,7 +1386,12 @@ fn spawn_orchestrator_task(
let time_taken = (std::time::Instant::now() - init_time).as_secs_f64().abs();
format!("{:.2}", time_taken)
});
pending.write().await.remove(&server_name);
// We will never get a None. But even if we do we should not fatal here.
let transport_type = pending
.write()
.await
.remove(&server_name)
.unwrap_or(TransportType::Stdio);

let result_tools = match &result {
Ok(tools_result) => {
Expand Down Expand Up @@ -1463,6 +1478,7 @@ fn spawn_orchestrator_task(
database,
conv_id,
&server_name,
transport_type,
&mut specs,
&mut sanitized_mapping,
&alias_list,
Expand Down Expand Up @@ -1666,8 +1682,11 @@ fn spawn_orchestrator_task(
}
}
},
UpdateEventMessage::InitStart { server_name, .. } => {
pending.write().await.insert(server_name.clone());
UpdateEventMessage::InitStart {
server_name,
transport_type,
} => {
pending.write().await.insert(server_name.clone(), transport_type);
loading_servers.insert(server_name, std::time::Instant::now());
},
UpdateEventMessage::Deinit { server_name, .. } => {
Expand Down Expand Up @@ -1724,6 +1743,7 @@ async fn process_tool_specs(
database: &Database,
conversation_id: &str,
server_name: &str,
transport_type: TransportType,
specs: &mut Vec<ToolSpec>,
tn_map: &mut HashMap<ModelToolName, ToolInfo>,
alias_list: &HashMap<HostToolName, ModelToolName>,
Expand Down Expand Up @@ -1798,6 +1818,7 @@ async fn process_tool_specs(
database,
conversation_id,
server_name.to_string(),
transport_type,
None,
number_of_tools,
all_tool_names,
Expand Down
11 changes: 10 additions & 1 deletion crates/chat-cli/src/cli/chat/tools/custom_tool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use crate::os::Os;
use crate::util::MCP_SERVER_TOOL_DELIMITER;
use crate::util::pattern_matching::matches_any_pattern;

#[derive(Clone, Serialize, Deserialize, Debug, Eq, PartialEq, JsonSchema)]
#[derive(Copy, Clone, Serialize, Deserialize, Debug, Eq, PartialEq, JsonSchema)]
#[serde(rename_all = "camelCase")]
pub enum TransportType {
/// Standard input/output transport (default)
Expand All @@ -39,6 +39,15 @@ pub enum TransportType {
Http,
}

impl std::fmt::Display for TransportType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
TransportType::Stdio => write!(f, "Stdio"),
TransportType::Http => write!(f, "Http"),
}
}
}

impl Default for TransportType {
fn default() -> Self {
Self::Stdio
Expand Down
4 changes: 4 additions & 0 deletions crates/chat-cli/src/telemetry/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use strum::{

use super::definitions::metrics::CodewhispererterminalRecordUserTurnCompletion;
use super::definitions::types::CodewhispererterminalChatConversationType;
use crate::cli::chat::tools::custom_tool::TransportType;
use crate::telemetry::definitions::IntoMetricDatum;
use crate::telemetry::definitions::metrics::{
AmazonqDidSelectProfile,
Expand Down Expand Up @@ -385,6 +386,7 @@ impl Event {
EventType::McpServerInit {
conversation_id,
server_name,
transport_type,
init_failure_reason,
number_of_tools,
all_tool_names,
Expand All @@ -397,6 +399,7 @@ impl Event {
value: None,
amazonq_conversation_id: Some(conversation_id.into()),
codewhispererterminal_mcp_server_name: Some(server_name.into()),
codewhispererterminal_transport_type: Some(transport_type.to_string().into()),
codewhispererterminal_mcp_server_init_failure_reason: init_failure_reason
.map(CodewhispererterminalMcpServerInitFailureReason),
codewhispererterminal_tools_per_mcp_server: Some(CodewhispererterminalToolsPerMcpServer(
Expand Down Expand Up @@ -665,6 +668,7 @@ pub enum EventType {
McpServerInit {
conversation_id: String,
server_name: String,
transport_type: TransportType,
init_failure_reason: Option<String>,
number_of_tools: usize,
all_tool_names: Option<String>,
Expand Down
3 changes: 3 additions & 0 deletions crates/chat-cli/src/telemetry/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ use crate::api_client::{
use crate::auth::builder_id::get_start_url_and_region;
use crate::aws_common::app_name;
use crate::cli::RootSubcommand;
use crate::cli::chat::tools::custom_tool::TransportType;
use crate::database::settings::Setting;
use crate::database::{
Database,
Expand Down Expand Up @@ -382,6 +383,7 @@ impl TelemetryThread {
database: &Database,
conversation_id: String,
server_name: String,
transport_type: TransportType,
init_failure_reason: Option<String>,
number_of_tools: usize,
all_tool_names: Option<String>,
Expand All @@ -391,6 +393,7 @@ impl TelemetryThread {
let mut telemetry_event = Event::new(crate::telemetry::EventType::McpServerInit {
conversation_id,
server_name,
transport_type,
init_failure_reason,
number_of_tools,
all_tool_names,
Expand Down
6 changes: 6 additions & 0 deletions crates/chat-cli/telemetry_definitions.json
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,11 @@
"type": "string",
"description": "Name of the MCP server"
},
{
"name": "codewhispererterminal_transportType",
"type": "string",
"description": "Transport type used by the MCP server"
},
{
"name": "codewhispererterminal_mcpServerAllToolNames",
"type": "string",
Expand Down Expand Up @@ -464,6 +469,7 @@
{ "type": "credentialStartUrl" },
{ "type": "amazonqConversationId" },
{ "type": "codewhispererterminal_mcpServerName" },
{ "type": "codewhispererterminal_transportType" },
{
"type": "codewhispererterminal_mcpServerInitFailureReason",
"required": false
Expand Down
Loading