# Rust CNN

A high-performance Convolutional Neural Network (CNN) for garbage classification, built from scratch in Rust using the tch crate (PyTorch bindings).

This project classifies waste into 6 categories (cardboard, glass, metal, paper, plastic, and trash) and reaches 82.24% validation accuracy on unseen data, or higher depending on how long you train it.
- Model: Custom CNN with batch normalization, dropout, and data augmentation
- Classes: 6 garbage types (cardboard, glass, metal, paper, plastic, trash)
- Dataset Size: 13,901 images
- Performance: 82.24% validation accuracy
- Training Platform: Kaggle GPUs
## Dataset

- Source: Garbage Classification Dataset on Kaggle
- Total Images: 13,901
- Training Set: 11,120 images (80%)
- Validation Set: 2,781 images (20%)
- Image Size: 200x200 RGB
- Data Augmentation: Random horizontal/vertical flips, rotation, resizing
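The 80/20 split can be sketched as a shuffled index split. This is a minimal sketch, not the project's loader: it assumes the training size is `floor(n * 0.8)` and uses a small hypothetical `XorShift64` generator in place of a real RNG crate so that it runs with only the standard library. With 13,901 images it reproduces the 11,120 / 2,781 counts listed above.

```rust
/// Tiny xorshift64 PRNG (hypothetical stand-in for a real RNG crate).
struct XorShift64(u64);

impl XorShift64 {
    fn next_u64(&mut self) -> u64 {
        let mut x = self.0;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.0 = x;
        x
    }
}

/// Shuffles indices 0..n with Fisher-Yates, then splits them into
/// (training, validation). Assumption: training size = floor(n * train_frac).
fn train_val_split(n: usize, train_frac: f64, seed: u64) -> (Vec<usize>, Vec<usize>) {
    let mut idx: Vec<usize> = (0..n).collect();
    let mut rng = XorShift64(seed.max(1)); // xorshift state must be non-zero
    for i in (1..n).rev() {
        // Pick j from 0..=i (modulo bias is ignored in this sketch).
        let j = (rng.next_u64() % (i as u64 + 1)) as usize;
        idx.swap(i, j);
    }
    let n_train = (n as f64 * train_frac).floor() as usize;
    let val = idx.split_off(n_train); // idx keeps the first n_train indices
    (idx, val)
}

fn main() {
    let (train, val) = train_val_split(13_901, 0.8, 42);
    // prints: train = 11120, val = 2781
    println!("train = {}, val = {}", train.len(), val.len());
}
```

Shuffling before splitting keeps the class mix of both sets close to the full dataset; a stratified split per class would be a stricter alternative.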
## Model Architecture

- Feature Extraction: 5 convolutional blocks with increasing channels (32 → 64 → 128 → 256 → 512)
- Pooling: Max Pooling and Adaptive Average Pooling to capture spatial features
- Classification: Multi-layer fully connected layers with dropout (0.15 and 0.075) for regularization
- Regularization: Batch normalization, plus gradient clipping to prevent exploding gradients

## Training

- Hybrid Learning Rate Scheduling: Combines warmup, cosine annealing, and step decay for adaptive learning rate tuning
- Early Stopping: Validation-based, with configurable patience; stops training when validation performance no longer improves
- Optimizer: Adam optimizer with weight decay and epsilon tuning for stable convergence
- Loss Function: Cross-entropy with optional label smoothing (configurable)
- Batch Size: 32 (configurable per training session)
- Device Support: Automatic CUDA GPU detection with fallback to CPU
- Gradient Clipping: Clips gradients to a maximum norm to keep training stable
- Multiple Model Saving Formats: PyTorch native, named tensor format, individual tensor files, and JSON metadata for universal compatibility
- Interactive Model Testing: Command-line interface for manual image testing during development
- Training Statistics Export: Epoch-wise metrics saved in CSV for analysis and visualization
## Evaluation

- Comprehensive Metrics: Precision, recall, F1-score, and support per class
- Confusion Matrix: Detailed misclassification analysis and CSV export
- Loss Curves: Training and validation loss tracking included
- Export: CSV files and Python plotting scripts for detailed metric visualization
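The hybrid schedule can be sketched as linear warmup, then cosine annealing toward a floor, with a step-decay factor layered on top. How the project actually combines the three is not documented here, so the formula and every hyperparameter (`base_lr`, `min_lr`, `warmup_epochs`, `step_size`, `gamma`) are assumptions for illustration:

```rust
use std::f64::consts::PI;

/// Hypothetical hyperparameters; the project's real values may differ.
struct HybridSchedule {
    base_lr: f64,
    min_lr: f64,
    warmup_epochs: usize,
    total_epochs: usize,
    step_size: usize, // epochs between step-decay drops (counted after warmup)
    gamma: f64,       // multiplicative factor applied at each drop
}

impl HybridSchedule {
    /// Learning rate for a 0-based epoch index.
    fn lr(&self, epoch: usize) -> f64 {
        if epoch < self.warmup_epochs {
            // Linear warmup from base_lr / warmup_epochs up to base_lr.
            return self.base_lr * (epoch + 1) as f64 / self.warmup_epochs as f64;
        }
        let t = epoch - self.warmup_epochs;
        let span = self.total_epochs.saturating_sub(self.warmup_epochs).max(1);
        let progress = (t as f64 / span as f64).min(1.0);
        // Cosine annealing from base_lr down to min_lr.
        let cosine = self.min_lr
            + 0.5 * (self.base_lr - self.min_lr) * (1.0 + (PI * progress).cos());
        // Step decay layered on top: multiply by gamma every `step_size` epochs.
        let drops = t / self.step_size.max(1);
        (cosine * self.gamma.powi(drops as i32)).max(self.min_lr)
    }
}

fn main() {
    let sched = HybridSchedule {
        base_lr: 1e-3,
        min_lr: 1e-6,
        warmup_epochs: 5,
        total_epochs: 100,
        step_size: 30,
        gamma: 0.5,
    };
    for &epoch in &[0usize, 4, 5, 34, 35, 99] {
        println!("epoch {:>3}: lr = {:.3e}", epoch, sched.lr(epoch));
    }
}
```

Multiplying the cosine curve by the step factor gives a visible drop every `step_size` epochs while the rate still ends near `min_lr`; an alternative design applies step decay only after the cosine phase finishes.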
## Installation

Install Rust:

```bash
curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh
source ~/.cargo/env
```

Download LibTorch 2.1.0. For CPU:

```bash
curl -O https://download.pytorch.org/libtorch/cpu/libtorch-cxx11-abi-shared-with-deps-2.1.0%2Bcpu.zip
```

For GPU (Linux) with CUDA 12.1:

```bash
curl -O https://download.pytorch.org/libtorch/cu121/libtorch-cxx11-abi-shared-with-deps-2.1.0%2Bcu121.zip
```

Unzip the downloaded archive, then point the build and the dynamic linker at it:

```bash
export LIBTORCH=/path/to/libtorch
export LD_LIBRARY_PATH=$LIBTORCH/lib:$LD_LIBRARY_PATH
```

## Results

- Overall Accuracy: 83.35% (validation, best epoch: 52)
- Training Time: ~1 hour 3 minutes (57 epochs, early stopping, GPU)
- Final Training Loss: 0.0394
- Final Validation Loss: 0.7798
| Class | Precision | Recall | F1-Score | Support |
|---|---|---|---|---|
| Cardboard | 0.80 | 0.87 | 0.83 | 408 |
| Glass | 0.84 | 0.78 | 0.81 | 477 |
| Metal | 0.83 | 0.83 | 0.83 | 474 |
| Paper | 0.83 | 0.86 | 0.84 | 460 |
| Plastic | 0.85 | 0.84 | 0.85 | 508 |
| Trash | 0.83 | 0.81 | 0.82 | 454 |
Weighted Avg: P=0.83, R=0.83, F1=0.83 (all classes, support=2781)
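The weighted averages follow directly from the table: each class's metric is weighted by its support and divided by the total support (2,781). A short sketch using the table's own values; `f1` shows the harmonic-mean formula behind the F1 column. Recomputing F1 from the two-decimal precision and recall can differ from the table by 0.01 (Plastic gives 0.84 rather than 0.85) because the underlying values were rounded.

```rust
/// F1 score: the harmonic mean of precision and recall.
fn f1(precision: f64, recall: f64) -> f64 {
    if precision + recall == 0.0 {
        0.0
    } else {
        2.0 * precision * recall / (precision + recall)
    }
}

/// Support-weighted average of a per-class metric.
fn weighted_avg(values: &[f64], support: &[u32]) -> f64 {
    let total: u32 = support.iter().sum();
    let weighted: f64 = values
        .iter()
        .zip(support)
        .map(|(v, &s)| v * s as f64)
        .sum();
    weighted / total as f64
}

fn main() {
    // (class, precision, recall, F1, support), copied from the table above.
    let rows: [(&str, f64, f64, f64, u32); 6] = [
        ("Cardboard", 0.80, 0.87, 0.83, 408),
        ("Glass", 0.84, 0.78, 0.81, 477),
        ("Metal", 0.83, 0.83, 0.83, 474),
        ("Paper", 0.83, 0.86, 0.84, 460),
        ("Plastic", 0.85, 0.84, 0.85, 508),
        ("Trash", 0.83, 0.81, 0.82, 454),
    ];
    let support: Vec<u32> = rows.iter().map(|r| r.4).collect();
    let precision: Vec<f64> = rows.iter().map(|r| r.1).collect();
    let recall: Vec<f64> = rows.iter().map(|r| r.2).collect();
    let f1_col: Vec<f64> = rows.iter().map(|r| r.3).collect();

    // Each weighted average rounds to 0.83, matching the "Weighted Avg" line.
    println!("total support = {}", support.iter().sum::<u32>());
    println!("weighted P  = {:.2}", weighted_avg(&precision, &support));
    println!("weighted R  = {:.2}", weighted_avg(&recall, &support));
    println!("weighted F1 = {:.2}", weighted_avg(&f1_col, &support));
    // Harmonic mean for one row: Glass gives 0.81, as in the table.
    println!("Glass F1 from P/R = {:.2}", f1(0.84, 0.78));
}
```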
- The confusion matrix and detailed per-sample predictions are available in `confusion_matrix.csv` and `evaluation_results.csv`.
- For learning curves, see the output of the included plotting script.
- Model formats: PyTorch `.pt`, named tensors, individual tensors, JSON metadata.
- Challenge Areas: Glass classification (visually similar to other materials)
- Common Confusions: Glass with Paper/Plastic (expected, given their visual similarity)
- Learning Curve: Smooth convergence without overfitting
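A confusion matrix makes these confusions easy to quantify. The sketch below builds one from (true, predicted) class-index pairs, computes per-class recall, and reports the most frequent off-diagonal cell. The class order and the toy pairs in `main` are assumptions for illustration, not the project's data; the project itself writes its matrix to `confusion_matrix.csv`.

```rust
const CLASSES: [&str; 6] = ["cardboard", "glass", "metal", "paper", "plastic", "trash"];

/// Builds an n x n confusion matrix: rows = true class, columns = predicted class.
fn confusion_matrix(pairs: &[(usize, usize)], n: usize) -> Vec<Vec<u32>> {
    let mut m = vec![vec![0u32; n]; n];
    for &(t, p) in pairs {
        m[t][p] += 1;
    }
    m
}

/// Per-class recall: correct predictions divided by the row total.
fn recall(m: &[Vec<u32>], class: usize) -> f64 {
    let row_total: u32 = m[class].iter().sum();
    if row_total == 0 {
        0.0
    } else {
        m[class][class] as f64 / row_total as f64
    }
}

/// Most frequent off-diagonal cell as (true class, predicted class, count).
fn top_confusion(m: &[Vec<u32>]) -> Option<(usize, usize, u32)> {
    let mut best: Option<(usize, usize, u32)> = None;
    for (t, row) in m.iter().enumerate() {
        for (p, &count) in row.iter().enumerate() {
            if t != p && count > 0 && best.map_or(true, |b| count > b.2) {
                best = Some((t, p, count));
            }
        }
    }
    best
}

fn main() {
    // Toy (true, predicted) pairs for illustration only, not project results.
    let pairs = [(1, 1), (1, 3), (1, 3), (1, 4), (3, 3), (4, 4), (4, 1), (0, 0)];
    let m = confusion_matrix(&pairs, CLASSES.len());
    // prints: glass recall = 0.25
    println!("glass recall = {:.2}", recall(&m, 1));
    if let Some((t, p, c)) = top_confusion(&m) {
        // prints: most common confusion: glass -> paper (2x)
        println!("most common confusion: {} -> {} ({}x)", CLASSES[t], CLASSES[p], c);
    }
}
```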
## Kaggle Integration

Live Demo: Garbage Classifier Rust on Kaggle
This project was developed and tested on Kaggle Notebooks using free GPU access, demonstrating the feasibility of Rust-based deep learning on cloud platforms.
See the Kaggle notebook for the full example:
```
# Clone repository
!git clone https://github.com/Not-Buddy/garbage_classification.git
%cd garbage_classification
# Install Rust + dependencies (handled in notebook)
# Run training
!cargo run --release -- 3
```

## Project Structure

```
garbage_classification/
├── src/
│   ├── main.rs                 # Entry point with CLI argument parsing
│   ├── menu.rs                 # Configuration management + dataset loading
│   ├── relu.rs                 # CNN model architecture
│   ├── train_model.rs          # Training loop and optimization
│   ├── training_validation.rs  # Evaluation metrics and plotting
│   └── visualize_data.rs       # Data loading and batch visualization
├── Cargo.toml                  # Dependencies and project configuration
├── README.md                   # This file
└── plot_results.py             # Generated Python plotting script
```
## Technical Implementation

### Dataset Loading

- Parallel Processing: Rayon for multi-threaded image loading
- Augmentation Pipeline: Random flips, rotation, resizing
- Memory Efficiency: Tensor operations in CHW format
- Progress Tracking: Real-time processing counter
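Because images are stored channel-first (CHW), a flip is just an index permutation within each channel plane. A minimal sketch of the horizontal and vertical flips on a flat `f32` buffer; the `ChwImage` type is hypothetical, and in a real augmentation pipeline each flip would be applied at random (for example with probability 0.5) rather than unconditionally. Rotation and resizing are omitted.

```rust
/// Hypothetical channel-first image: data[ch * h * w + y * w + x].
struct ChwImage {
    c: usize,
    h: usize,
    w: usize,
    data: Vec<f32>,
}

impl ChwImage {
    fn idx(&self, ch: usize, y: usize, x: usize) -> usize {
        ch * self.h * self.w + y * self.w + x
    }

    /// Mirror left-right: swap columns x and w-1-x in every row of every channel.
    fn flip_horizontal(&mut self) {
        for ch in 0..self.c {
            for y in 0..self.h {
                for x in 0..self.w / 2 {
                    let (a, b) = (self.idx(ch, y, x), self.idx(ch, y, self.w - 1 - x));
                    self.data.swap(a, b);
                }
            }
        }
    }

    /// Mirror top-bottom: swap rows y and h-1-y in every channel.
    fn flip_vertical(&mut self) {
        for ch in 0..self.c {
            for y in 0..self.h / 2 {
                for x in 0..self.w {
                    let (a, b) = (self.idx(ch, y, x), self.idx(ch, self.h - 1 - y, x));
                    self.data.swap(a, b);
                }
            }
        }
    }
}

fn main() {
    // One channel, 2 rows x 3 columns: rows [1, 2, 3] and [4, 5, 6].
    let mut img = ChwImage { c: 1, h: 2, w: 3, data: vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0] };
    img.flip_horizontal(); // rows become [3, 2, 1] and [6, 5, 4]
    img.flip_vertical(); // row order is reversed
    // prints: [6.0, 5.0, 4.0, 3.0, 2.0, 1.0]
    println!("{:?}", img.data);
}
```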
### CNN Architecture

```
CNN {
    features: Conv2d(32)  -> BN -> ReLU -> Conv2d(32)  -> BN -> ReLU -> MaxPool
              Conv2d(64)  -> BN -> ReLU -> Conv2d(64)  -> BN -> ReLU -> MaxPool
              Conv2d(128) -> BN -> ReLU -> Conv2d(128) -> BN -> ReLU -> MaxPool,
    gap: GlobalAveragePooling2d,
    classifier: Linear(128) -> ReLU -> Dropout(0.15) -> Linear(6)
}
```

### Training Features

- Hybrid Learning Rate Scheduling: Combines warmup, cosine annealing, and step decay for smooth and adaptive learning rate adjustment
- Early Stopping: Validation-based early stopping with configurable patience and minimum improvement thresholds to prevent overfitting
- Gradient Clipping: Stable training by clipping gradients with max norm to prevent exploding gradients
- Multiple Model Saving Formats: Saves models in PyTorch native, named tensor, individual tensor, and metadata JSON formats for maximum usability and interoperability
- Automatic Device Selection: GPU-enabled training with fallback to CPU for seamless cross-platform compatibility
- Progress Monitoring: Real-time logging of training/validation losses, accuracies, and learning rates with epoch-wise summaries
- Comprehensive Evaluation: Detailed per-class precision, recall, F1-score, confusion matrix generation, and CSV export for analysis
- Interactive Model Testing: Manual image classification interface for fast inference and validation during development
- Efficient Data Loading: Custom DataLoader with batch handling and image preprocessing aligned with training specs
- Training Statistics Export: Saves complete training metrics in CSV supporting downstream visualization and analysis
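Early stopping with patience and a minimum-improvement threshold can be sketched as a small state machine. This version assumes the monitored quantity is validation loss (lower is better) and that `min_delta` is an absolute threshold; `EarlyStopping` and its fields are hypothetical names, not the project's API:

```rust
/// Validation-based early stopping with patience and a minimum improvement.
struct EarlyStopping {
    patience: usize,
    min_delta: f64,
    best_loss: f64,
    best_epoch: usize,
    epochs_without_improvement: usize,
}

impl EarlyStopping {
    fn new(patience: usize, min_delta: f64) -> Self {
        EarlyStopping {
            patience,
            min_delta,
            best_loss: f64::INFINITY,
            best_epoch: 0,
            epochs_without_improvement: 0,
        }
    }

    /// Records one epoch's validation loss; returns true when training should stop.
    fn step(&mut self, epoch: usize, val_loss: f64) -> bool {
        if val_loss < self.best_loss - self.min_delta {
            // Counts as an improvement only if it beats the best by more than min_delta.
            self.best_loss = val_loss;
            self.best_epoch = epoch;
            self.epochs_without_improvement = 0;
            // A real training loop would save a checkpoint of the model here.
        } else {
            self.epochs_without_improvement += 1;
        }
        self.epochs_without_improvement >= self.patience
    }
}

fn main() {
    // Toy validation losses for illustration only.
    let losses = [0.9, 0.8, 0.75, 0.76, 0.74995, 0.77, 0.78];
    let mut es = EarlyStopping::new(3, 1e-3);
    for (i, &loss) in losses.iter().enumerate() {
        let epoch = i + 1; // 1-based epoch numbering
        if es.step(epoch, loss) {
            // prints: stopped after epoch 6; best epoch 3 (val loss 0.7500)
            println!("stopped after epoch {}; best epoch {} (val loss {:.4})",
                     epoch, es.best_epoch, es.best_loss);
            break;
        }
    }
}
```

In the toy run, the drop from 0.75 to 0.74995 is smaller than `min_delta`, so it counts as no improvement and training stops three epochs after the best epoch.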
## Output Files

After training, the following files are generated automatically:

- `training_losses.csv`: Loss data for plotting training progress
- `confusion_matrix.csv`: Confusion matrix data for classification analysis
- `evaluation_results.csv`: Detailed predictions with probabilities for each validation sample
- `training_stats.csv`: Complete training metrics, including losses and accuracies per epoch
- `plot_results.py`: Python script for generating performance plots and visualizations
- `garbage_classifier_100_epochs_13901_samples.pt`: Saved PyTorch model weights in universal format
- `garbage_classifier_100_epochs_13901_samples_named.pt`: Named tensors format for Python compatibility
- `garbage_classifier_100_epochs_13901_samples_tensors/`: Directory containing individual tensor files
- `garbage_classifier_100_epochs_13901_samples_metadata.json`: JSON metadata file with training and model details
- Use `plot_results.py` to generate visualizations of training progress, losses, and metrics.
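A sketch of epoch-wise CSV export using only the standard library. The column names, number formatting, and `EpochStats` struct are assumptions; the real `training_stats.csv` layout may differ. The sketch prints the CSV to stdout; `std::fs::write("training_stats.csv", csv)` would write it to disk instead.

```rust
use std::fmt::Write;

/// One row of epoch-wise metrics (hypothetical fields).
struct EpochStats {
    epoch: usize,
    train_loss: f64,
    val_loss: f64,
    train_acc: f64,
    val_acc: f64,
    lr: f64,
}

/// Serializes the rows as CSV with a header line.
fn to_csv(rows: &[EpochStats]) -> String {
    let mut out = String::from("epoch,train_loss,val_loss,train_acc,val_acc,lr\n");
    for r in rows {
        // Writing into a String cannot fail, so unwrap() never panics here.
        writeln!(
            out,
            "{},{:.4},{:.4},{:.4},{:.4},{}",
            r.epoch, r.train_loss, r.val_loss, r.train_acc, r.val_acc, r.lr
        )
        .unwrap();
    }
    out
}

fn main() {
    // Placeholder values for illustration only; these are not project results.
    let rows = vec![
        EpochStats { epoch: 1, train_loss: 1.2, val_loss: 1.1, train_acc: 0.55, val_acc: 0.58, lr: 0.001 },
        EpochStats { epoch: 2, train_loss: 0.9, val_loss: 0.95, train_acc: 0.66, val_acc: 0.64, lr: 0.001 },
    ];
    print!("{}", to_csv(&rows));
}
```

A flat CSV with one row per epoch is easy to load from a Python plotting script, which keeps the Rust side free of plotting dependencies.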
## Future Improvements

- Attention Mechanisms: Focus on distinguishing features
- Ensemble Methods: Multiple model voting
- Advanced Augmentation: MixUp, CutMix, AutoAugment
- Early Stopping: Automatic best model selection
- Hyperparameter Tuning: Automated search
- Model Serving: REST API for inference
- Cross-Validation: More robust evaluation
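## Why Rust?

- Performance: Near-C++ speed with memory safety
- Concurrency: Safe, efficient parallelism (Rayon drives the multi-threaded image loading)
- Ecosystem: Growing ML/DL ecosystem with tch
- Deployment: Single compiled binary (the LibTorch shared libraries must still be available at runtime)
- Reliability: Compile-time guarantees rule out memory-safety bugs and data races (tensor shape mismatches still surface at runtime)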
## Contributing

- Create your feature branch (`git checkout -b feature/amazing-feature`)
- Commit your changes (`git commit -m 'Add amazing feature'`)
- Push to the branch (`git push origin feature/amazing-feature`)
- Open a Pull Request
## License

This project is open source and available under the MIT License.

## Acknowledgments

- tch crate for PyTorch bindings
- Kaggle for free GPU access
- Dataset contributors for the garbage classification dataset
- Rust community for excellent documentation and support
Star this repository if you found it helpful!
For questions or collaboration opportunities, feel free to open an issue or reach out directly.

