This is the repository for the 2021 Spring HPML class project.
- Motivation:
  - Accelerate clustering algorithms on large-scale datasets.
- Challenge:
  - Dataset: large-scale data points
  - Algorithm:
    - Parallel computation of the distance matrix (the bottleneck).
    - Parallel updates of parameters.
    - Specifically designed distance kernel functions.
  - System:
    - PyTorch implementation vs. the scikit-learn library.
    - CUDA-extended PyTorch implementation.
- Contribution:
  - PyTorch implementations of clustering algorithms (GPU/CPU).
  - CUDA extensions of specifically designed distance kernel functions.
  - CUDA extensions called from the PyTorch runtime.
  - Evaluated on large-scale clustering benchmarks.
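To make the "parallel distance matrix" and "parallel updates" items concrete, here is a minimal PyTorch sketch of one K-means step. It is an illustration only, not the code in this repository: the function names `pairwise_minkowski` and `kmeans_step` and the toy data are assumptions made for the example.

```python
import torch


def pairwise_minkowski(x: torch.Tensor, centers: torch.Tensor, p: float = 2.0) -> torch.Tensor:
    """Return the (n, k) matrix of Minkowski distances of order p (p=2 is L2)."""
    # Broadcast (n, 1, d) against (1, k, d) so all n*k distances are computed at once.
    diff = (x[:, None, :] - centers[None, :, :]).abs()
    return diff.pow(p).sum(dim=-1).pow(1.0 / p)


def kmeans_step(x: torch.Tensor, centers: torch.Tensor, p: float = 2.0):
    """One assignment + update step (hypothetical helper, not the repository's API)."""
    # Assignment: each point picks its nearest center.
    labels = pairwise_minkowski(x, centers, p).argmin(dim=1)
    k, d = centers.shape
    # Update: sum the points of each cluster in one scatter-add, then divide by counts.
    sums = torch.zeros(k, d, dtype=x.dtype, device=x.device).index_add_(0, labels, x)
    counts = torch.bincount(labels, minlength=k).to(x.dtype).unsqueeze(1)
    # Keep the old center for any cluster that received no points.
    new_centers = torch.where(counts > 0, sums / counts.clamp(min=1), centers)
    return labels, new_centers


# Toy data (assumed for illustration); the same code runs on CPU or GPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.tensor([[0.0, 0.0], [0.0, 1.0], [10.0, 10.0], [10.0, 11.0]], device=device)
centers = torch.tensor([[0.0, 0.0], [10.0, 10.0]], device=device)
labels, new_centers = kmeans_step(x, centers)
```

The broadcasted difference materialises an `n × k × d` tensor, which is simple but memory-hungry. `torch.cdist(x, centers, p=p)` computes the same matrix without that intermediate.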
- Datasets:
  - MNIST
  - CIFAR-10
- Algorithms:
  - K-means
    - Sklearn
    - PyTorch-cpu
    - PyTorch-gpu-l2
    - PyTorch-gpu-Minkowski
  - GMM
    - Sklearn
    - PyTorch-cpu
    - PyTorch-gpu
  - K-means
- Metric:
  - Total time
  - Speedup
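As a rough sketch of how the two metrics relate, speedup can be taken as the baseline's total time divided by the accelerated implementation's total time. The helpers `total_time` and `speedup` below are hypothetical names for illustration and are not part of this repository.

```python
import time


def total_time(fn, *args) -> float:
    """Wall-clock seconds taken by a single call of fn(*args)."""
    start = time.perf_counter()
    fn(*args)
    return time.perf_counter() - start


def speedup(baseline_seconds: float, candidate_seconds: float) -> float:
    """Speedup of the candidate over the baseline (values above 1 mean faster)."""
    return baseline_seconds / candidate_seconds


# Example with made-up timings: a 10 s baseline against a 2 s candidate.
example = speedup(10.0, 2.0)
```

When timing GPU code this way, call `torch.cuda.synchronize()` before reading the clock, since CUDA kernels are launched asynchronously and the call may return before the work has finished.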
The scripts are separate implementations based on different frameworks.
- `_kmeans_skl.py`: Sklearn implementation of KMeans.
- `_kmeans.py`: PyTorch CPU & GPU implementation of KMeans.
- `_gmm_skl.py`: Sklearn implementation of GMM.
- `_gaussian.py`: PyTorch CPU & GPU implementation of GMM.
- `/cuda/my_cuda_kernel.cu`: CUDA implementation of L2 & Minkowski kernels.
Run and profile the scripts using:
- Sklearn: Modify the script to test on a specific dataset, and run it with `time python _[kmeans/gmm]_skl.py`.
- PyTorch: Use `test.py` with options:
  - `--alg`: `kmeans++` / `gmm`
  - `--kernel`: `l2` / `m[N]`
  - `--kernel_cuda`: whether to use the customized CUDA kernel
  - `--device`: `cuda` / `cpu`
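For readers who want to see how options of this shape could be handled, the following is a hypothetical `argparse` sketch. It is not the actual parser in `test.py`. It assumes that `m[N]` denotes a Minkowski kernel of order `N` (for example `m3`), that `--kernel_cuda` is a boolean switch, and that the defaults shown are reasonable. None of these assumptions is confirmed by the repository.

```python
import argparse
import re


def parse_kernel(name: str) -> float:
    """Map 'l2' to order 2 and 'm<N>' (e.g. 'm3') to order N (assumed meaning)."""
    if name == "l2":
        return 2.0
    match = re.fullmatch(r"m(\d+)", name)
    if match is None:
        raise argparse.ArgumentTypeError(f"unknown kernel: {name!r}")
    return float(match.group(1))


def build_parser() -> argparse.ArgumentParser:
    # Hypothetical parser mirroring the documented flags; defaults are assumptions.
    parser = argparse.ArgumentParser(description="Sketch of the test.py options")
    parser.add_argument("--alg", choices=["kmeans++", "gmm"], default="kmeans++")
    parser.add_argument("--kernel", type=parse_kernel, default="l2")
    parser.add_argument("--kernel_cuda", action="store_true")
    parser.add_argument("--device", choices=["cuda", "cpu"], default="cpu")
    return parser


# Parse an example command line instead of sys.argv.
args = build_parser().parse_args(["--alg", "gmm", "--kernel", "m3", "--device", "cpu"])
```

Converting the kernel name to a numeric order at parse time means the rest of the program only needs to handle a single parameter `p`.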
- KMeans
  - Our PyTorch implementation on CPU is less efficient than sklearn.
  - Speedup of GPU (PyTorch) with the conventional kernel function (L2) is up to 5×.
  - Speedup of GPU (PyTorch) with special kernels is up to 13×.
  - Speedup of GPU (PyTorch) with the CUDA kernel extension can be up to 16×.
- GMM
  - GMM is relatively more computation-intensive.
  - Speedup of our PyTorch implementation on CPU can be up to 9×.
    - This speedup is likely due to better utilisation of multi-threaded computing.
  - Speedup of GPU (PyTorch) can be up to 122×.
- Partial results:
  - Accuracy: *