Skip to content

no_std support, infallible get/set/mask field operations #4

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,22 +11,22 @@ For example:
```rust
// Create a bitset of width 8, with three fields a, b and c.
let mut x = BitSet::<8>::from(0b1_0101_111);
// ^ ^ ^
// c b a
// ^ ^ ^
// c b a

// Extract individual fields. Note that the extracted field is statically
// typed according to width.
let a: BitSet<3> = x.get_field::<3, 0>().unwrap();
let b: BitSet<4> = x.get_field::<4, 3>().unwrap();
let c: BitSet<1> = x.get_field::<1, 7>().unwrap();
let a: BitSet<3> = x.get_field::<3, 0>();
let b: BitSet<4> = x.get_field::<4, 3>();
let c: BitSet<1> = x.get_field::<1, 7>();
assert_eq!(u8::from(a), 0b111);
assert_eq!(u8::from(b), 0b0101);
assert_eq!(u8::from(c), 0b1);

// Now set a feild. Note that setting a field requires a bitset with the
// correct width.
let b = BitSet::<4>::from_int(0b1010).unwrap();
x.set_field::<4, 3>(b).unwrap();
let b = bitset!(4, 0b1010);
x.set_field::<4, 3>(b);
assert_eq!(u8::from(a), 0b111);
assert_eq!(u8::from(b), 0b1010);
assert_eq!(u8::from(c), 0b1);
Expand Down
121 changes: 71 additions & 50 deletions lib/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
// Copyright Oxide Computer Company 2025

#![no_std]
#![feature(generic_const_exprs)]
#![allow(incomplete_features)]
#![allow(clippy::unusual_byte_groupings)]

extern crate alloc;

use alloc::vec::Vec;
use core::error::Error;
use core::fmt::Display;
use seq_macro::seq;
use std::error::Error;
use std::fmt::Display;

#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub struct BitSet<const BITS: usize>(pub [u8; ((BITS - 1) >> 3) + 1])
Expand All @@ -22,26 +25,26 @@ where
}
}

impl<const BITS: usize> std::fmt::LowerHex for BitSet<BITS>
impl<const BITS: usize> core::fmt::LowerHex for BitSet<BITS>
where
[(); ((BITS - 1) >> 3) + 1]:,
{
fn fmt(
&self,
f: &mut std::fmt::Formatter<'_>,
) -> Result<(), std::fmt::Error> {
f: &mut core::fmt::Formatter<'_>,
) -> Result<(), core::fmt::Error> {
write!(f, "{:x?}", self.0)
}
}

impl<const BITS: usize> std::fmt::UpperHex for BitSet<BITS>
impl<const BITS: usize> core::fmt::UpperHex for BitSet<BITS>
where
[(); ((BITS - 1) >> 3) + 1]:,
{
fn fmt(
&self,
f: &mut std::fmt::Formatter<'_>,
) -> Result<(), std::fmt::Error> {
f: &mut core::fmt::Formatter<'_>,
) -> Result<(), core::fmt::Error> {
write!(f, "{:X?}", self.0)
}
}
Expand Down Expand Up @@ -91,41 +94,60 @@ where
pub const ZERO: Self = Self([0; ((BITS - 1) >> 3) + 1]);
pub const WIDTH: usize = BITS;

pub fn max() -> Result<Self, OutOfBounds> {
pub fn max() -> Self {
let mut s = Self([0xff; ((BITS - 1) >> 3) + 1]);
s.mask()?;
Ok(s)
s.mask();
s
}

fn mask(&mut self) -> Result<(), OutOfBounds> {
fn mask(&mut self) {
let mask = ((1 << (BITS % 8)) - 1) as u8;
let pos = BITS >> 3;
if pos > self.0.len() {
return Err(OutOfBounds {});
}
if mask == 0 {
return Ok(());
return;
}

// Index safety:
//
// The array len is ((BITS-1) >> 3) + 1. Can we assume that
// pos = BITS >> 3 is always less than ((BITS-1) >> 3) + 1?
//
// The following analysis shows we can.
//
// Start with what we want to show.
//
// b >> 3 < ((b-1) >> 3) + 1
//
// Treat b>>3 as b/8 so we can use normal math. The result holds for >>
// as x>>3 <= x/3 because x>>3 is x/8 with truncation.
//
// b/8 < ((b-1)/8) + 1
//
// Now simplify.
//
// b/8 - (b-1)/8 < 1
// b - (b-1) < 8
// b - b + 1 < 8
// 1 < 8 ✓
//
self.0[pos] &= mask;

Ok(())
}

pub fn get_field<const FBITS: usize, const OFFSET: usize>(
&self,
) -> Result<BitSet<FBITS>, OutOfBounds>
) -> BitSet<FBITS>
where
[(); ((FBITS - 1) >> 3) + 1]:,
[(); BITS - FBITS]:,
[(); BITS - (OFFSET + FBITS)]:,
{
assert!(FBITS <= BITS);
let sub =
&self.0[(OFFSET >> 3)..(OFFSET >> 3) + (((FBITS - 1) >> 3) + 1)];
let mut result = BitSet::<FBITS>(sub.try_into().unwrap());
result = result.shr(OFFSET % 8);
result.mask()?;
Ok(result)
result.mask();
result
}

pub fn extend_right<const XBITS: usize>(&self) -> BitSet<{ BITS + XBITS }>
Expand All @@ -152,13 +174,13 @@ where
pub fn set_field<const FBITS: usize, const OFFSET: usize>(
&mut self,
value: BitSet<FBITS>,
) -> Result<(), OutOfBounds>
where
) where
[(); ((FBITS - 1) >> 3) + 1]:,
[(); (((FBITS + OFFSET) - 1) >> 3) + 1]:,
[(); BITS - FBITS]:,
[(); BITS - (OFFSET + FBITS)]:,
{
let m = BitSet::<FBITS>::max().unwrap();
let m = BitSet::<FBITS>::max();
let mut mask = BitSet::<BITS>::default();
for (i, x) in m.0.iter().enumerate() {
mask.0[i] = *x;
Expand All @@ -173,8 +195,6 @@ where
}
v = v.shl(OFFSET);
*self = self.or(v);

Ok(())
}

#[allow(clippy::should_implement_trait)]
Expand Down Expand Up @@ -240,7 +260,7 @@ where
pub struct OutOfBounds {}
impl Error for OutOfBounds {}
impl Display for OutOfBounds {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
write!(f, "out of bounds")
}
}
Expand Down Expand Up @@ -321,11 +341,11 @@ macro_rules! large_int_small_bitset {

macro_rules! display {
($width:expr) => {
impl std::fmt::Display for BitSet<$width> {
impl core::fmt::Display for BitSet<$width> {
fn fmt(
&self,
f: &mut std::fmt::Formatter<'_>,
) -> Result<(), std::fmt::Error> {
f: &mut core::fmt::Formatter<'_>,
) -> Result<(), core::fmt::Error> {
let i = u64::from(*self);
write!(f, "{i}/0x{i:x}/0b{i:b}")
}
Expand Down Expand Up @@ -372,6 +392,7 @@ seq!(N in 1..=64 { display!(N); });
#[cfg(test)]
mod test {
use super::*;
use alloc::vec;
use bitset_macro::bitset;

#[test]
Expand Down Expand Up @@ -529,52 +550,52 @@ mod test {
fn test_max() {
seq!(I in 1..=7 {{
let a = (1u8 << (I)) - 1;
let b = BitSet::<I>::max().unwrap();
let b = BitSet::<I>::max();
assert_eq!(a, u8::from(b), "max count={}", I);
}});
let a = u8::MAX;
let b = BitSet::<8>::max().unwrap();
let b = BitSet::<8>::max();
assert_eq!(a, u8::from(b), "max count={}", 8);

seq!(I in 9..=15 {{
let a = (1u16 << (I)) - 1;
let b = BitSet::<I>::max().unwrap();
let b = BitSet::<I>::max();
assert_eq!(a, u16::from(b), "max count={}", I);
}});
let a = u16::MAX;
let b = BitSet::<16>::max().unwrap();
let b = BitSet::<16>::max();
assert_eq!(a, u16::from(b), "max count={}", 16);

seq!(I in 17..=31 {{
let a = (1u32 << (I)) - 1;
let b = BitSet::<I>::max().unwrap();
let b = BitSet::<I>::max();
assert_eq!(a, u32::from(b), "max count={}", I);
}});
let a = u32::MAX;
let b = BitSet::<32>::max().unwrap();
let b = BitSet::<32>::max();
assert_eq!(a, u32::from(b), "max count={}", 32);

seq!(I in 33..=63 {{
let a = (1u64 << (I)) - 1;
let b = BitSet::<I>::max().unwrap();
let b = BitSet::<I>::max();
assert_eq!(a, u64::from(b), "max count={}", I);
}});
let a = u64::MAX;
let b = BitSet::<64>::max().unwrap();
let b = BitSet::<64>::max();
assert_eq!(a, u64::from(b), "max count={}", 64);
}

#[test]
fn test_extend_right() {
let x = BitSet::<47>::max().unwrap();
let x = BitSet::<47>::max();
let y = x.extend_right::<5>();
let z = BitSet::<52>::try_from((1u64 << 47) - 1).unwrap();
assert_eq!(y, z);
}

#[test]
fn test_extend_left() {
let x = BitSet::<47>::max().unwrap();
let x = BitSet::<47>::max();
let y = x.extend_left::<5>();
let z = BitSet::<52>::try_from(((1u64 << 47) - 1) >> 5).unwrap();
assert_eq!(y, z);
Expand Down Expand Up @@ -632,7 +653,7 @@ mod test {
let y = BitSet::<8>::from(b);

let mut z = x;
z.set_field::<8, 0>(y).unwrap();
z.set_field::<8, 0>(y);
let expected =
0b1111111100000000111111110000000011111111000000001111111110101010u64;

Expand All @@ -645,7 +666,7 @@ mod test {
);

let mut z = x;
z.set_field::<8, 48>(y).unwrap();
z.set_field::<8, 48>(y);
let expected =
0b1111111110101010111111110000000011111111000000001111111100000000u64;

Expand All @@ -667,17 +688,17 @@ mod test {

// Extract individual fields. Note that the extracted field is statically
// typed according to width.
let a: BitSet<3> = x.get_field::<3, 0>().unwrap();
let b: BitSet<4> = x.get_field::<4, 3>().unwrap();
let c: BitSet<1> = x.get_field::<1, 7>().unwrap();
let a: BitSet<3> = x.get_field::<3, 0>();
let b: BitSet<4> = x.get_field::<4, 3>();
let c: BitSet<1> = x.get_field::<1, 7>();
assert_eq!(u8::from(a), 0b111);
assert_eq!(u8::from(b), 0b0101);
assert_eq!(u8::from(c), 0b1);

// Now set a feild. Note that setting a field requires a bitset with the
// Now set a field. Note that setting a field requires a bitset with the
// correct width.
let b = bitset!(4, 0b1010);
x.set_field::<4, 3>(b).unwrap();
x.set_field::<4, 3>(b);
assert_eq!(u8::from(a), 0b111);
assert_eq!(u8::from(b), 0b1010);
assert_eq!(u8::from(c), 0b1);
Expand Down Expand Up @@ -783,8 +804,8 @@ pub mod legacy {
let x = bitset!(16, 0xabcd);
assert_eq!(x.to_int(), 0xabcd);

let x0 = x.get_field::<8, 0>().unwrap();
let x1 = x.get_field::<2, 8>().unwrap();
let x0 = x.get_field::<8, 0>();
let x1 = x.get_field::<2, 8>();

assert_eq!(x0.to_int(), 0xcd);
assert_eq!(x1.to_int(), 0x3);
Expand Down