diff --git a/dropshot/src/extractor/header.rs b/dropshot/src/extractor/header.rs index a0cc40ad..e8da8691 100644 --- a/dropshot/src/extractor/header.rs +++ b/dropshot/src/extractor/header.rs @@ -9,7 +9,7 @@ use schemars::JsonSchema; use serde::de::DeserializeOwned; use crate::{ - from_map::from_map, ApiEndpointBodyContentType, + from_map::from_map_insensitive, ApiEndpointBodyContentType, ApiEndpointParameterLocation, HttpError, RequestContext, RequestInfo, ServerContext, }; @@ -17,11 +17,21 @@ use crate::{ use super::{metadata::get_metadata, ExtractorMetadata, SharedExtractor}; /// `Header` is an extractor used to deserialize an instance of -/// `HeaderType` from an HTTP request's header values. `PathType` may be any +/// `HeaderType` from an HTTP request's header values. `HeaderType` may be any /// structure that implements [serde::Deserialize] and [schemars::JsonSchema]. /// While headers are accessible through [RequestInfo::headers], using this /// extractor in an entrypoint causes header inputs to be documented in /// OpenAPI output. See the crate documentation for more information. +/// +/// Note that (unlike the [`Query`] and [`Path`] extractors) headers are case- +/// insensitive. You may rename fields with mixed casing (e.g. by using +/// #[serde(rename = "X-Header-Foo")]) and that casing will appear in the +/// OpenAPI document output. Name conflicts (including names differentiated by +/// casing since headers are case-insensitive) may lead to unexpected behavior, +/// and should be avoided. For example, only one of the conflicting fields may +/// be deserialized, and therefore deserialization may fail if any conflicting +/// field is required (i.e. not an `Option` type) +#[derive(Debug)] pub struct Header { inner: HeaderType, } @@ -50,8 +60,13 @@ where .map_err(|message: http::header::ToStrError| { HttpError::for_bad_request(None, message.to_string()) })?; - let x: HeaderType = from_map(&headers).unwrap(); - Ok(Header { inner: x }) + let inner = from_map_insensitive(&headers).map_err(|message| { + HttpError::for_bad_request( + None, + format!("error processing headers: {message}"), + ) + })?; + Ok(Header { inner }) } #[async_trait] @@ -71,3 +86,74 @@ where get_metadata::(&ApiEndpointParameterLocation::Header) } } + +#[cfg(test)] +mod tests { + use schemars::JsonSchema; + use serde::Deserialize; + + use crate::{extractor::header::http_request_load_header, RequestInfo}; + + #[test] + fn test_header_parsing() { + #[allow(dead_code)] + #[derive(Debug, Deserialize, JsonSchema)] + pub struct TestHeaders { + header_a: String, + #[serde(rename = "X-Header-B")] + header_b: String, + } + + let addr = std::net::SocketAddr::new( + std::net::Ipv4Addr::LOCALHOST.into(), + 8080, + ); + + let request = + hyper::Request::builder().uri("http://localhost").body(()).unwrap(); + let info = RequestInfo::new(&request, addr); + + let parsed = http_request_load_header::(&info); + assert!(parsed.is_err()); + + let request = hyper::Request::builder() + .header("header_a", "header_a value") + .header("X-Header-B", "header_b value") + .uri("http://localhost") + .body(()) + .unwrap(); + let info = RequestInfo::new(&request, addr); + + let parsed = http_request_load_header::(&info); + + match parsed { + Ok(headers) => { + assert_eq!(headers.inner.header_a, "header_a value"); + assert_eq!(headers.inner.header_b, "header_b value"); + } + Err(e) => { + panic!("unexpected error: {}", e); + } + } + + let request = hyper::Request::builder() + .header("header_a", "header_a value") + .header("X-hEaDEr-b", "header_b value") + .uri("http://localhost") + .body(()) + .unwrap(); + let info = RequestInfo::new(&request, addr); + + let parsed = http_request_load_header::(&info); + + match parsed { + Ok(headers) => { + assert_eq!(headers.inner.header_a, "header_a value"); + assert_eq!(headers.inner.header_b, "header_b value"); + } + Err(e) => { + panic!("unexpected error: {}", e); + } + } + } +} diff --git a/dropshot/src/from_map.rs b/dropshot/src/from_map.rs index 3ac83cf9..bf8638ea 100644 --- a/dropshot/src/from_map.rs +++ b/dropshot/src/from_map.rs @@ -1,4 +1,4 @@ -// Copyright 2020 Oxide Computer Company +// Copyright 2025 Oxide Computer Company use paste::paste; use serde::de::DeserializeSeed; @@ -13,9 +13,10 @@ use std::any::type_name; use std::collections::BTreeMap; use std::fmt::Debug; use std::fmt::Display; +use std::marker::PhantomData; /// Deserialize a BTreeMap into a type, invoking -/// String::parse() for all values according to the required type. MapValue may +/// FromStr::parse() for all values according to the required type. MapValue may /// be either a single String or a sequence of Strings. pub(crate) fn from_map<'a, T, Z>( map: &'a BTreeMap, @@ -28,6 +29,18 @@ where T::deserialize(&mut deserializer).map_err(|e| e.0) } +/// Similar to [`from_map`], but case-insensitive. +pub(crate) fn from_map_insensitive<'a, T, Z>( + map: &'a BTreeMap, +) -> Result +where + T: Deserialize<'a>, + Z: MapValue + Debug + Clone + 'static, +{ + let mut deserializer = MapDeserializer::from_map_insensitive(map); + T::deserialize(&mut deserializer).map_err(|e| e.0) +} + pub(crate) trait MapValue { fn as_value(&self) -> Result<&str, MapError>; fn as_seq(&self) -> Result>, MapError>; @@ -46,22 +59,74 @@ impl MapValue for String { } } +// To handle headers, which are case-insensitive, we use a trait to adapt +// deserialization of structs to be case-insensitive. + +trait Casing { + /// Return the remapping of field names (if such a remapping is applicable). + /// Keys are the lowercase header names; values are the original struct + /// field names. + fn rename( + fields: &'static [&'static str], + ) -> BTreeMap; +} + +enum CaseSensitive {} +enum CaseInsensitive {} + +impl Casing for CaseSensitive { + fn rename( + _fields: &'static [&'static str], + ) -> BTreeMap { + Default::default() + } +} +impl Casing for CaseInsensitive { + fn rename( + fields: &'static [&'static str], + ) -> BTreeMap { + fields.into_iter().map(|field| (field.to_lowercase(), *field)).collect() + } +} + /// Deserializer for BTreeMap that interprets the values. It has /// two modes: about to iterate over the map or about to process a single value. #[derive(Debug)] -enum MapDeserializer<'de, Z: MapValue + Debug + Clone + 'static> { - Map(&'de BTreeMap), +enum MapDeserializer< + 'de, + Z: MapValue + Debug + Clone + 'static, + CaseSensitivity, +> { + Map(&'de BTreeMap, PhantomData), Value(Z), } -impl<'de, Z> MapDeserializer<'de, Z> +impl<'de, Z> MapDeserializer<'de, Z, CaseSensitive> where Z: MapValue + Debug + Clone + 'static, { fn from_map(input: &'de BTreeMap) -> Self { - MapDeserializer::Map(input) + MapDeserializer::Map(input, PhantomData) + } + + fn from_value(input: Z) -> Self { + MapDeserializer::Value(input) + } +} + +impl<'de, Z> MapDeserializer<'de, Z, CaseInsensitive> +where + Z: MapValue + Debug + Clone + 'static, +{ + fn from_map_insensitive(input: &'de BTreeMap) -> Self { + MapDeserializer::Map(input, PhantomData) } +} +impl<'de, Z, CaseSensitivity> MapDeserializer<'de, Z, CaseSensitivity> +where + Z: MapValue + Debug + Clone + 'static, +{ /// Helper function to extract pattern match for Value. Fail if we're /// expecting a Map or return the result of the provided function. fn value(&self, deserialize: F) -> Result @@ -70,7 +135,7 @@ where { match self { MapDeserializer::Value(ref raw_value) => deserialize(raw_value), - MapDeserializer::Map(_) => Err(MapError( + MapDeserializer::Map(..) => Err(MapError( "must be applied to a flattened struct rather than a raw type" .to_string(), )), @@ -137,9 +202,11 @@ macro_rules! de_value { }; } -impl<'de, Z> Deserializer<'de> for &mut MapDeserializer<'de, Z> +impl<'de, Z, CaseSensitivity> Deserializer<'de> + for &mut MapDeserializer<'de, Z, CaseSensitivity> where Z: MapValue + Debug + Clone + 'static, + CaseSensitivity: Casing, { type Error = MapError; @@ -181,13 +248,19 @@ where fn deserialize_struct( self, _name: &'static str, - _fields: &'static [&'static str], + fields: &'static [&'static str], visitor: V, ) -> Result where V: Visitor<'de>, { - self.deserialize_map(visitor) + match self { + MapDeserializer::Map(map, _) => visitor + .visit_map(MapMapAccess::new::(map, fields)), + MapDeserializer::Value(_) => Err(MapError( + "deserialization container must be fully flattened".to_string(), + )), + } } // This will only be called when deserializing a structure that contains a // flattened structure. See `deserialize_any` below for details. @@ -196,14 +269,10 @@ where V: Visitor<'de>, { match self { - MapDeserializer::Map(map) => { - let xx = map.clone(); - let x = Box::new(xx.into_iter()); - let m = MapMapAccess:: { iter: x, value: None }; - visitor.visit_map(m) - } + MapDeserializer::Map(map, _) => visitor + .visit_map(MapMapAccess::new::(map, &[])), MapDeserializer::Value(_) => Err(MapError( - "destination struct must be fully flattened".to_string(), + "deserialization container must be fully flattened".to_string(), )), } } @@ -294,9 +363,11 @@ where } // Deserializer component for processing enums. -impl<'de, Z> EnumAccess<'de> for &mut MapDeserializer<'de, Z> +impl<'de, Z, CaseSensitivity> EnumAccess<'de> + for &mut MapDeserializer<'de, Z, CaseSensitivity> where Z: MapValue + Debug + Clone + 'static, + CaseSensitivity: Casing, { type Error = MapError; type Variant = Self; @@ -313,7 +384,8 @@ where } // Deserializer component for processing enum variants. -impl<'de, Z> VariantAccess<'de> for &mut MapDeserializer<'de, Z> +impl<'de, Z, CaseSensitivity> VariantAccess<'de> + for &mut MapDeserializer<'de, Z, CaseSensitivity> where Z: MapValue + Clone + Debug + 'static, { @@ -359,6 +431,25 @@ struct MapMapAccess { iter: Box>, /// Pending value in a key-value pair value: Option, + /// Field renaming + rename: BTreeMap, +} + +impl MapMapAccess +where + Z: MapValue + Debug + Clone + 'static, +{ + fn new( + map: &mut &BTreeMap, + fields: &'static [&'static str], + ) -> Self + where + Z: MapValue + Debug + Clone + 'static, + { + let iter = Box::new(map.clone().into_iter()); + let rename = CaseSensitivity::rename(fields); + Self { iter, value: None, rename } + } } impl<'de, Z> MapAccess<'de> for MapMapAccess @@ -376,10 +467,16 @@ where { match self.iter.next() { Some((key, value)) => { + let key = if let Some(rename) = self.rename.get(&key) { + *rename + } else { + key.as_str() + }; // Save the value for later. self.value.replace(value); // Create a Deserializer for that single value. - let mut deserializer = MapDeserializer::Value(key); + let mut deserializer = + MapDeserializer::from_value(key.to_string()); seed.deserialize(&mut deserializer).map(Some) } None => Ok(None), @@ -391,7 +488,7 @@ where { match self.value.take() { Some(value) => { - let mut deserializer = MapDeserializer::Value(value); + let mut deserializer = MapDeserializer::from_value(value); seed.deserialize(&mut deserializer) } // This means we were called without a corresponding call to @@ -420,7 +517,7 @@ where { match self.iter.next() { Some(value) => { - let mut deserializer = MapDeserializer::Value(value); + let mut deserializer = MapDeserializer::from_value(value); seed.deserialize(&mut deserializer).map(Some) } None => Ok(None), @@ -430,7 +527,10 @@ where #[cfg(test)] mod test { + use crate::from_map::from_map_insensitive; + use super::from_map; + use core::panic; use serde::Deserialize; use std::collections::BTreeMap; @@ -462,7 +562,10 @@ mod test { map.insert("b".to_string(), "B".to_string()); match from_map::(&map) { Err(s) => { - assert_eq!(s, "destination struct must be fully flattened") + assert_eq!( + s, + "deserialization container must be fully flattened" + ) } Ok(_) => panic!("unexpected success"), } @@ -571,4 +674,35 @@ mod test { Ok(_) => panic!("unexpected success"), } } + + #[test] + fn test_case_insensitive() { + #[derive(Deserialize, Debug)] + struct A { + #[serde(rename = "X-Header-A")] + a: String, + #[serde(rename = "WhYwOuLdAnYoNeEvErDoThIs")] + b: String, + } + + let map: BTreeMap = [ + ("x-header-a".to_string(), "x-header-a value".to_string()), + ("whywouldanyoneeverdothis".to_string(), "other value".to_string()), + ] + .into_iter() + .collect(); + + match from_map::(&map) { + Ok(_) => panic!("unexpected success"), + Err(_) => (), + } + + match from_map_insensitive::(&map) { + Ok(a) => { + assert_eq!(a.a, "x-header-a value"); + assert_eq!(a.b, "other value"); + } + Err(s) => panic!("error: {}", s), + } + } }