Skip to content

Commit 4335317

Browse files
authored
Merge pull request #10 from pbower/shape
feat: Add Shape support
2 parents 29ee97d + bc5a9d2 commit 4335317

25 files changed

+518
-63
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,3 +19,5 @@ target
1919
# and can be added to the global gitignore or merged into this file. For a more nuclear
2020
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
2121
#.idea/
22+
23+
TODO.md

src/enums/shape_dim.rs

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
//! # ShapeDim Enum Module
2+
//!
3+
//! Companion to [crate::traits::shape::Shape];
4+
//!
5+
//! Contains all supported `Shape` variants.
6+
7+
use crate::traits::shape::Shape;
8+
9+
/// Recursively-describable dimensional rank for any `Value`.
10+
#[derive(Clone, PartialEq, Eq)]
11+
pub enum ShapeDim {
12+
/// Rank-0 - must always be `1`
13+
Rank0(usize),
14+
15+
/// Array row count
16+
Rank1(usize),
17+
18+
/// Relational table with row/column counts.
19+
Rank2 { rows: usize, cols: usize },
20+
21+
/// 3d object
22+
Rank3 { x: usize, y: usize, z: usize },
23+
24+
/// 4d Object
25+
Rank4 {
26+
a: usize,
27+
b: usize,
28+
c: usize,
29+
d: usize,
30+
},
31+
32+
/// N-dimensional tensor.
33+
RankN(Vec<usize>),
34+
35+
/// Dictionary shape
36+
Dictionary {
37+
// Number of keys
38+
n_keys: usize,
39+
// Number of values for each key
40+
n_values: Vec<usize>,
41+
},
42+
43+
/// Heterogeneous ordered collection.
44+
/// Covers lists, tuples, cubes (with varying row-counts) and user-custom chunked types.
45+
///
46+
/// The order is significant; for tuples it is the fixed arity order.
47+
Collection(Vec<ShapeDim>),
48+
49+
/// Shape could not be determined.
50+
Unknown,
51+
}
52+
53+
54+
/// Implement `Shape` for `ShapeDim` so recursive calls like `item.shape_3d()`
55+
/// compile when iterating `Collection(Vec<ShapeDim>)`.
56+
impl Shape for ShapeDim {
57+
fn shape(&self) -> ShapeDim {
58+
self.clone()
59+
}
60+
}

src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ pub mod enums {
100100
pub mod temporal_array;
101101
}
102102
pub mod operators;
103+
pub mod shape_dim;
103104
}
104105

105106
/// Contains SIMD-accelerated kernels for the 'essentials' that are highly coupled to this crate
@@ -179,6 +180,7 @@ pub mod traits {
179180
pub mod print;
180181
pub mod type_unions;
181182
pub mod custom_value;
183+
pub mod shape;
182184
}
183185

184186
pub mod aliases;

src/structs/bitmask.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ use std::fmt::{Debug, Display, Formatter, Result as FmtResult};
1818
use std::ops::{BitAnd, BitOr, Deref, DerefMut, Index, Not};
1919

2020
use crate::structs::vec64::Vec64;
21+
use crate::traits::shape::Shape;
22+
use crate::enums::shape_dim::ShapeDim;
2123
use crate::{BitmaskV, Buffer, Length, Offset};
2224

2325
/// TODO: Move bitmask kernels here
@@ -731,6 +733,11 @@ impl Display for Bitmask {
731733
}
732734
}
733735

736+
impl Shape for Bitmask {
737+
fn shape(&self) -> ShapeDim {
738+
ShapeDim::Rank1(self.len())
739+
}
740+
}
734741

735742
#[cfg(test)]
736743
mod tests {

src/structs/chunked/super_array.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@ use std::iter::FromIterator;
1313
#[cfg(feature = "views")]
1414
use std::sync::Arc;
1515

16+
use crate::traits::shape::Shape;
17+
use crate::enums::shape_dim::ShapeDim;
1618
#[cfg(feature = "views")]
1719
use crate::ArrayV;
1820
#[cfg(feature = "views")]
@@ -719,6 +721,12 @@ impl From<FieldArray> for SuperArray {
719721
}
720722
}
721723

724+
impl Shape for SuperArray {
725+
fn shape(&self) -> ShapeDim {
726+
ShapeDim::Rank1(self.len())
727+
}
728+
}
729+
722730
impl Display for SuperArray {
723731
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
724732
writeln!(

src/structs/chunked/super_table.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ use std::sync::Arc;
2727
use crate::structs::field::Field;
2828
use crate::structs::field_array::FieldArray;
2929
use crate::structs::table::Table;
30+
use crate::traits::shape::Shape;
31+
use crate::enums::shape_dim::ShapeDim;
3032
#[cfg(feature = "views")]
3133
use crate::{SuperTableV, TableV};
3234

@@ -154,6 +156,12 @@ impl SuperTable {
154156
pub fn n_cols(&self) -> usize {
155157
self.schema.len()
156158
}
159+
160+
#[inline]
161+
pub fn n_rows(&self) -> usize {
162+
self.n_rows
163+
}
164+
157165
#[inline]
158166
pub fn n_batches(&self) -> usize {
159167
self.batches.len()
@@ -223,6 +231,7 @@ impl SuperTable {
223231
name
224232
}
225233
}
234+
226235
}
227236

228237
impl Default for SuperTable {
@@ -238,6 +247,12 @@ impl FromIterator<Table> for SuperTable {
238247
}
239248
}
240249

250+
impl Shape for SuperTable {
251+
fn shape(&self) -> ShapeDim {
252+
ShapeDim::Rank2 { rows: self.n_rows(), cols: self.n_cols() }
253+
}
254+
}
255+
241256
impl Display for SuperTable {
242257
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
243258
writeln!(

src/structs/cube.rs

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ use super::field_array::FieldArray;
3232
#[cfg(feature = "views")]
3333
use crate::aliases::CubeV;
3434
use crate::ffi::arrow_dtype::ArrowType;
35+
use crate::traits::shape::Shape;
36+
use crate::enums::shape_dim::ShapeDim;
3537
use crate::{Field, Table};
3638
#[cfg(feature = "views")]
3739
use crate::TableV;
@@ -64,6 +66,7 @@ pub struct Cube {
6466
pub tables: Vec<Table>,
6567
/// Number of rows in each table
6668
pub n_rows: Vec<usize>,
69+
6770
/// Cube name
6871
pub name: String,
6972
// Third-dimensional index column names
@@ -160,6 +163,11 @@ impl Cube {
160163
self.tables.len()
161164
}
162165

166+
/// Returns the number of rows
167+
pub fn n_rows(&self) -> Vec<usize> {
168+
self.n_rows.clone()
169+
}
170+
163171
/// Returns the number of columns.
164172
pub fn n_cols(&self) -> usize {
165173
self.tables[0].n_cols()
@@ -439,6 +447,13 @@ impl IntoIterator for Cube {
439447
}
440448
}
441449

450+
impl Shape for Cube {
451+
fn shape(&self) -> ShapeDim {
452+
ShapeDim::Collection(self.tables.iter().map(|t| t.shape()).collect())
453+
}
454+
}
455+
456+
442457
#[cfg(test)]
443458
mod tests {
444459
use super::*;

src/structs/field_array.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ use polars::series::Series;
2121
#[cfg(feature = "views")]
2222
use crate::aliases::FieldAVT;
2323
use crate::ffi::arrow_dtype::ArrowType;
24+
use crate::traits::shape::Shape;
25+
use crate::enums::shape_dim::ShapeDim;
2426
use crate::{Array, Field};
2527

2628

@@ -245,6 +247,12 @@ impl Display for FieldArray {
245247
}
246248
}
247249

250+
impl Shape for FieldArray {
251+
fn shape(&self) -> ShapeDim {
252+
ShapeDim::Rank1(self.len())
253+
}
254+
}
255+
248256
#[cfg(test)]
249257
mod tests {
250258
use super::*;

0 commit comments

Comments
 (0)