From 976a1bc8f58f392a82fb135c5d306511faacdd44 Mon Sep 17 00:00:00 2001 From: "Adam H. Leventhal" Date: Thu, 21 Aug 2025 15:16:28 -0700 Subject: [PATCH 1/2] header marshaling should be case-insensitive and more robust --- dropshot/src/extractor/header.rs | 75 +++++++++++++- dropshot/src/from_map.rs | 173 ++++++++++++++++++++++++++----- 2 files changed, 220 insertions(+), 28 deletions(-) diff --git a/dropshot/src/extractor/header.rs b/dropshot/src/extractor/header.rs index a0cc40adb..f3294dc4c 100644 --- a/dropshot/src/extractor/header.rs +++ b/dropshot/src/extractor/header.rs @@ -9,7 +9,7 @@ use schemars::JsonSchema; use serde::de::DeserializeOwned; use crate::{ - from_map::from_map, ApiEndpointBodyContentType, + from_map::from_map_insensitive, ApiEndpointBodyContentType, ApiEndpointParameterLocation, HttpError, RequestContext, RequestInfo, ServerContext, }; @@ -17,11 +17,20 @@ use crate::{ use super::{metadata::get_metadata, ExtractorMetadata, SharedExtractor}; /// `Header` is an extractor used to deserialize an instance of -/// `HeaderType` from an HTTP request's header values. `PathType` may be any +/// `HeaderType` from an HTTP request's header values. `HeaderType` may be any /// structure that implements [serde::Deserialize] and [schemars::JsonSchema]. /// While headers are accessible through [RequestInfo::headers], using this /// extractor in an entrypoint causes header inputs to be documented in /// OpenAPI output. See the crate documentation for more information. +/// +/// Note that (unlike the [`Query`] and [`Path`] extractors) headers are case- +/// insensitive. You may rename fields with mixed casing (e.g. by using +/// #[serde(rename = "X-Header-Foo")]) and that casing will appear in the +/// OpenAPI document output. Case-insensitive name conflicts may lead to +/// unexpected behavior, and should be avoided. For example, only one of the +/// conflicting fields may be deserialized, and therefore deserialization may +/// fail if any conflicting field is required (i.e. not an `Option` type) +#[derive(Debug)] pub struct Header { inner: HeaderType, } @@ -50,8 +59,14 @@ where .map_err(|message: http::header::ToStrError| { HttpError::for_bad_request(None, message.to_string()) })?; - let x: HeaderType = from_map(&headers).unwrap(); - Ok(Header { inner: x }) + println!("headers: {headers:?}"); + let inner = from_map_insensitive(&headers).map_err(|message| { + HttpError::for_bad_request( + None, + format!("error processing headers: {message}"), + ) + })?; + Ok(Header { inner }) } #[async_trait] @@ -71,3 +86,55 @@ where get_metadata::(&ApiEndpointParameterLocation::Header) } } + +#[cfg(test)] +mod tests { + use schemars::JsonSchema; + use serde::Deserialize; + + use crate::{extractor::header::http_request_load_header, RequestInfo}; + + #[test] + fn test_header_parsing() { + #[allow(dead_code)] + #[derive(Debug, Deserialize, JsonSchema)] + pub struct TestHeaders { + header_a: String, + #[serde(rename = "X-Header-B")] + header_b: String, + } + + let addr = std::net::SocketAddr::new( + std::net::Ipv4Addr::LOCALHOST.into(), + 8080, + ); + + let request = + hyper::Request::builder().uri("http://localhost").body(()).unwrap(); + let info = RequestInfo::new(&request, addr); + + let parsed = http_request_load_header::(&info); + assert!(parsed.is_err()); + + let request = hyper::Request::builder() + .header("header_a", "header_a value") + .header("X-Header-B", "header_b value") + .uri("http://localhost") + .body(()) + .unwrap(); + println!("request: {request:?}"); + let info = RequestInfo::new(&request, addr); + + let parsed = http_request_load_header::(&info); + + match parsed { + Ok(headers) => { + assert_eq!(headers.inner.header_a, "header_a value"); + assert_eq!(headers.inner.header_b, "header_b value"); + } + Err(e) => { + panic!("unexpected error: {}", e); + } + } + } +} diff --git a/dropshot/src/from_map.rs b/dropshot/src/from_map.rs index 3ac83cf99..dc0a94fa8 100644 --- a/dropshot/src/from_map.rs +++ b/dropshot/src/from_map.rs @@ -1,4 +1,4 @@ -// Copyright 2020 Oxide Computer Company +// Copyright 2025 Oxide Computer Company use paste::paste; use serde::de::DeserializeSeed; @@ -13,9 +13,10 @@ use std::any::type_name; use std::collections::BTreeMap; use std::fmt::Debug; use std::fmt::Display; +use std::marker::PhantomData; /// Deserialize a BTreeMap into a type, invoking -/// String::parse() for all values according to the required type. MapValue may +/// FromStr::parse() for all values according to the required type. MapValue may /// be either a single String or a sequence of Strings. pub(crate) fn from_map<'a, T, Z>( map: &'a BTreeMap, @@ -28,6 +29,18 @@ where T::deserialize(&mut deserializer).map_err(|e| e.0) } +/// Similar to [`from_map`], but case-insensitive. +pub(crate) fn from_map_insensitive<'a, T, Z>( + map: &'a BTreeMap, +) -> Result +where + T: Deserialize<'a>, + Z: MapValue + Debug + Clone + 'static, +{ + let mut deserializer = MapDeserializer::from_map_insensitive(map); + T::deserialize(&mut deserializer).map_err(|e| e.0) +} + pub(crate) trait MapValue { fn as_value(&self) -> Result<&str, MapError>; fn as_seq(&self) -> Result>, MapError>; @@ -46,22 +59,68 @@ impl MapValue for String { } } +// To handle headers, which are case-insensitive, we use a trait to adapt +// deserialization of structs to be case-insensitive. + +trait Casing { + fn rename(fields: &'static [&'static str]) -> BTreeMap; +} + +enum CaseSensitive {} +enum CaseInsensitive {} + +impl Casing for CaseSensitive { + fn rename(_fields: &'static [&'static str]) -> BTreeMap { + Default::default() + } +} +impl Casing for CaseInsensitive { + fn rename(fields: &'static [&'static str]) -> BTreeMap { + fields + .iter() + .map(|field| (field.to_lowercase(), field.to_string())) + .collect() + } +} + /// Deserializer for BTreeMap that interprets the values. It has /// two modes: about to iterate over the map or about to process a single value. #[derive(Debug)] -enum MapDeserializer<'de, Z: MapValue + Debug + Clone + 'static> { - Map(&'de BTreeMap), +enum MapDeserializer< + 'de, + Z: MapValue + Debug + Clone + 'static, + CaseSensitivity, +> { + Map(&'de BTreeMap, PhantomData), Value(Z), } -impl<'de, Z> MapDeserializer<'de, Z> +impl<'de, Z> MapDeserializer<'de, Z, CaseSensitive> where Z: MapValue + Debug + Clone + 'static, { fn from_map(input: &'de BTreeMap) -> Self { - MapDeserializer::Map(input) + MapDeserializer::Map(input, PhantomData) + } + + fn from_value(input: Z) -> Self { + MapDeserializer::Value(input) } +} +impl<'de, Z> MapDeserializer<'de, Z, CaseInsensitive> +where + Z: MapValue + Debug + Clone + 'static, +{ + fn from_map_insensitive(input: &'de BTreeMap) -> Self { + MapDeserializer::Map(input, PhantomData) + } +} + +impl<'de, Z, CaseSensitivity> MapDeserializer<'de, Z, CaseSensitivity> +where + Z: MapValue + Debug + Clone + 'static, +{ /// Helper function to extract pattern match for Value. Fail if we're /// expecting a Map or return the result of the provided function. fn value(&self, deserialize: F) -> Result @@ -70,7 +129,7 @@ where { match self { MapDeserializer::Value(ref raw_value) => deserialize(raw_value), - MapDeserializer::Map(_) => Err(MapError( + MapDeserializer::Map(..) => Err(MapError( "must be applied to a flattened struct rather than a raw type" .to_string(), )), @@ -137,9 +196,11 @@ macro_rules! de_value { }; } -impl<'de, Z> Deserializer<'de> for &mut MapDeserializer<'de, Z> +impl<'de, Z, CaseSensitivity> Deserializer<'de> + for &mut MapDeserializer<'de, Z, CaseSensitivity> where Z: MapValue + Debug + Clone + 'static, + CaseSensitivity: Casing, { type Error = MapError; @@ -181,13 +242,19 @@ where fn deserialize_struct( self, _name: &'static str, - _fields: &'static [&'static str], + fields: &'static [&'static str], visitor: V, ) -> Result where V: Visitor<'de>, { - self.deserialize_map(visitor) + match self { + MapDeserializer::Map(map, _) => visitor + .visit_map(MapMapAccess::new::(map, fields)), + MapDeserializer::Value(_) => Err(MapError( + "deserialization container must be fully flattened".to_string(), + )), + } } // This will only be called when deserializing a structure that contains a // flattened structure. See `deserialize_any` below for details. @@ -196,14 +263,10 @@ where V: Visitor<'de>, { match self { - MapDeserializer::Map(map) => { - let xx = map.clone(); - let x = Box::new(xx.into_iter()); - let m = MapMapAccess:: { iter: x, value: None }; - visitor.visit_map(m) - } + MapDeserializer::Map(map, _) => visitor + .visit_map(MapMapAccess::new::(map, &[])), MapDeserializer::Value(_) => Err(MapError( - "destination struct must be fully flattened".to_string(), + "deserialization container must be fully flattened".to_string(), )), } } @@ -294,9 +357,11 @@ where } // Deserializer component for processing enums. -impl<'de, Z> EnumAccess<'de> for &mut MapDeserializer<'de, Z> +impl<'de, Z, CaseSensitivity> EnumAccess<'de> + for &mut MapDeserializer<'de, Z, CaseSensitivity> where Z: MapValue + Debug + Clone + 'static, + CaseSensitivity: Casing, { type Error = MapError; type Variant = Self; @@ -313,7 +378,8 @@ where } // Deserializer component for processing enum variants. -impl<'de, Z> VariantAccess<'de> for &mut MapDeserializer<'de, Z> +impl<'de, Z, CaseSensitivity> VariantAccess<'de> + for &mut MapDeserializer<'de, Z, CaseSensitivity> where Z: MapValue + Clone + Debug + 'static, { @@ -359,6 +425,25 @@ struct MapMapAccess { iter: Box>, /// Pending value in a key-value pair value: Option, + /// Field renaming + rename: BTreeMap, +} + +impl MapMapAccess +where + Z: MapValue + Debug + Clone + 'static, +{ + fn new( + map: &mut &BTreeMap, + fields: &'static [&'static str], + ) -> Self + where + Z: MapValue + Debug + Clone + 'static, + { + let iter = Box::new(map.clone().into_iter()); + let rename = CaseSensitivity::rename(fields); + Self { iter, value: None, rename } + } } impl<'de, Z> MapAccess<'de> for MapMapAccess @@ -375,11 +460,14 @@ where K: DeserializeSeed<'de>, { match self.iter.next() { - Some((key, value)) => { + Some((mut key, value)) => { + if let Some(rename) = self.rename.get(&key) { + key = rename.clone(); + } // Save the value for later. self.value.replace(value); // Create a Deserializer for that single value. - let mut deserializer = MapDeserializer::Value(key); + let mut deserializer = MapDeserializer::from_value(key); seed.deserialize(&mut deserializer).map(Some) } None => Ok(None), @@ -391,7 +479,7 @@ where { match self.value.take() { Some(value) => { - let mut deserializer = MapDeserializer::Value(value); + let mut deserializer = MapDeserializer::from_value(value); seed.deserialize(&mut deserializer) } // This means we were called without a corresponding call to @@ -420,7 +508,7 @@ where { match self.iter.next() { Some(value) => { - let mut deserializer = MapDeserializer::Value(value); + let mut deserializer = MapDeserializer::from_value(value); seed.deserialize(&mut deserializer).map(Some) } None => Ok(None), @@ -430,7 +518,10 @@ where #[cfg(test)] mod test { + use crate::from_map::from_map_insensitive; + use super::from_map; + use core::panic; use serde::Deserialize; use std::collections::BTreeMap; @@ -462,7 +553,10 @@ mod test { map.insert("b".to_string(), "B".to_string()); match from_map::(&map) { Err(s) => { - assert_eq!(s, "destination struct must be fully flattened") + assert_eq!( + s, + "deserialization container must be fully flattened" + ) } Ok(_) => panic!("unexpected success"), } @@ -571,4 +665,35 @@ mod test { Ok(_) => panic!("unexpected success"), } } + + #[test] + fn test_case_insensitive() { + #[derive(Deserialize, Debug)] + struct A { + #[serde(rename = "X-Header-A")] + a: String, + #[serde(rename = "WhYwOuLdAnYoNeEvErDoThIs")] + b: String, + } + + let map: BTreeMap = [ + ("x-header-a".to_string(), "x-header-a value".to_string()), + ("whywouldanyoneeverdothis".to_string(), "other value".to_string()), + ] + .into_iter() + .collect(); + + match from_map::(&map) { + Ok(_) => panic!("unexpected success"), + Err(_) => (), + } + + match from_map_insensitive::(&map) { + Ok(a) => { + assert_eq!(a.a, "x-header-a value"); + assert_eq!(a.b, "other value"); + } + Err(s) => panic!("error: {}", s), + } + } } From dca85d152210bb57b5da52fad3d75fc872fc79be Mon Sep 17 00:00:00 2001 From: "Adam H. Leventhal" Date: Mon, 25 Aug 2025 14:31:19 -0700 Subject: [PATCH 2/2] review feedback --- dropshot/src/extractor/header.rs | 31 ++++++++++++++++++++++------ dropshot/src/from_map.rs | 35 ++++++++++++++++++++------------ 2 files changed, 47 insertions(+), 19 deletions(-) diff --git a/dropshot/src/extractor/header.rs b/dropshot/src/extractor/header.rs index f3294dc4c..e8da8691b 100644 --- a/dropshot/src/extractor/header.rs +++ b/dropshot/src/extractor/header.rs @@ -26,10 +26,11 @@ use super::{metadata::get_metadata, ExtractorMetadata, SharedExtractor}; /// Note that (unlike the [`Query`] and [`Path`] extractors) headers are case- /// insensitive. You may rename fields with mixed casing (e.g. by using /// #[serde(rename = "X-Header-Foo")]) and that casing will appear in the -/// OpenAPI document output. Case-insensitive name conflicts may lead to -/// unexpected behavior, and should be avoided. For example, only one of the -/// conflicting fields may be deserialized, and therefore deserialization may -/// fail if any conflicting field is required (i.e. not an `Option` type) +/// OpenAPI document output. Name conflicts (including names differentiated by +/// casing since headers are case-insensitive) may lead to unexpected behavior, +/// and should be avoided. For example, only one of the conflicting fields may +/// be deserialized, and therefore deserialization may fail if any conflicting +/// field is required (i.e. not an `Option` type) #[derive(Debug)] pub struct Header { inner: HeaderType, @@ -59,7 +60,6 @@ where .map_err(|message: http::header::ToStrError| { HttpError::for_bad_request(None, message.to_string()) })?; - println!("headers: {headers:?}"); let inner = from_map_insensitive(&headers).map_err(|message| { HttpError::for_bad_request( None, @@ -122,7 +122,26 @@ mod tests { .uri("http://localhost") .body(()) .unwrap(); - println!("request: {request:?}"); + let info = RequestInfo::new(&request, addr); + + let parsed = http_request_load_header::(&info); + + match parsed { + Ok(headers) => { + assert_eq!(headers.inner.header_a, "header_a value"); + assert_eq!(headers.inner.header_b, "header_b value"); + } + Err(e) => { + panic!("unexpected error: {}", e); + } + } + + let request = hyper::Request::builder() + .header("header_a", "header_a value") + .header("X-hEaDEr-b", "header_b value") + .uri("http://localhost") + .body(()) + .unwrap(); let info = RequestInfo::new(&request, addr); let parsed = http_request_load_header::(&info); diff --git a/dropshot/src/from_map.rs b/dropshot/src/from_map.rs index dc0a94fa8..bf8638eae 100644 --- a/dropshot/src/from_map.rs +++ b/dropshot/src/from_map.rs @@ -63,23 +63,29 @@ impl MapValue for String { // deserialization of structs to be case-insensitive. trait Casing { - fn rename(fields: &'static [&'static str]) -> BTreeMap; + /// Return the remapping of field names (if such a remapping is applicable). + /// Keys are the lowercase header names; values are the original struct + /// field names. + fn rename( + fields: &'static [&'static str], + ) -> BTreeMap; } enum CaseSensitive {} enum CaseInsensitive {} impl Casing for CaseSensitive { - fn rename(_fields: &'static [&'static str]) -> BTreeMap { + fn rename( + _fields: &'static [&'static str], + ) -> BTreeMap { Default::default() } } impl Casing for CaseInsensitive { - fn rename(fields: &'static [&'static str]) -> BTreeMap { - fields - .iter() - .map(|field| (field.to_lowercase(), field.to_string())) - .collect() + fn rename( + fields: &'static [&'static str], + ) -> BTreeMap { + fields.into_iter().map(|field| (field.to_lowercase(), *field)).collect() } } @@ -426,7 +432,7 @@ struct MapMapAccess { /// Pending value in a key-value pair value: Option, /// Field renaming - rename: BTreeMap, + rename: BTreeMap, } impl MapMapAccess @@ -460,14 +466,17 @@ where K: DeserializeSeed<'de>, { match self.iter.next() { - Some((mut key, value)) => { - if let Some(rename) = self.rename.get(&key) { - key = rename.clone(); - } + Some((key, value)) => { + let key = if let Some(rename) = self.rename.get(&key) { + *rename + } else { + key.as_str() + }; // Save the value for later. self.value.replace(value); // Create a Deserializer for that single value. - let mut deserializer = MapDeserializer::from_value(key); + let mut deserializer = + MapDeserializer::from_value(key.to_string()); seed.deserialize(&mut deserializer).map(Some) } None => Ok(None),