Modified model from Semi-Supervised Semantic Segmentation Using Unreliable Pseudo Labels, CVPR 2022.
Refer to U2PL GitHub for the official U2PL model code.
Clone the repository into a folder named U2PL, then create and activate the conda environment and install the dependencies:
git clone https://github.com/ponoma1202/U2PL_copy.git U2PL && cd U2PL
conda create -n u2pl python=3.8.16
conda activate u2pl
pip install -r requirements.txt
conda install pytorch torchvision pytorch-cuda=11.7 -c pytorch -c nvidia
This project used PyTorch 2.0.1 and torchvision 0.15.2.
The U2PL model is trained on both the Cityscapes and PASCAL VOC 2012 datasets.
For Cityscapes
- Download leftImg8bit_trainvaltest.zip
- Download gtFine.zip from Google Drive
- Unzip gtFine and leftImg8bit_trainvaltest into a new folder named cityscapes.
- Move the cityscapes folder into the data folder.
Note: both gtFine and leftImg8bit_trainvaltest contain train, test, and val folders.
For PASCAL VOC 2012
- Download the dataset from Kaggle
- Download and unzip SegmentationClassAug.zip
- Unzip the archive.zip file from the Kaggle download into data. The unzipped archive folder should contain the VOC2012 folder. Delete the extra archive/VOC2012/VOC2012 folder.
- Move the unzipped SegmentationClassAug/SegmentationClassAug folder into the VOC2012 folder.
Move VOC2012 into the data folder of the U2PL project. The path should be U2PL/data/VOC2012. The directory should look like this:
data/VOC2012
Annotations
ImageSets
JPEGImages
SegmentationClass
SegmentationClassAug
SegmentationObject
The data folder should now contain at least one of the following folders in addition to the splits folder:
- cityscapes
- VOC2012
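As a quick sanity check of the layout described above, a short script can report which expected subfolders are missing. This is a minimal sketch, not part of the repository: the function name `missing_dataset_folders` is hypothetical, and the Cityscapes subfolder names are taken from the note above, so the names on disk may differ after unzipping.

```python
from pathlib import Path

# Expected subfolders, copied from the setup steps above (assumed names).
EXPECTED = {
    "cityscapes": ["gtFine", "leftImg8bit_trainvaltest"],
    "VOC2012": [
        "Annotations",
        "ImageSets",
        "JPEGImages",
        "SegmentationClass",
        "SegmentationClassAug",
        "SegmentationObject",
    ],
}


def missing_dataset_folders(data_root):
    """Return {dataset: [missing subfolders]} for each dataset found under data_root."""
    root = Path(data_root)
    report = {}
    for dataset, subfolders in EXPECTED.items():
        dataset_dir = root / dataset
        if not dataset_dir.is_dir():
            continue  # only one of the two datasets is required
        report[dataset] = [
            name for name in subfolders if not (dataset_dir / name).is_dir()
        ]
    return report


if __name__ == "__main__":
    # Run from the U2PL project folder.
    print(missing_dataset_folders("data"))
```

An empty list for a dataset means every folder listed above was found.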
- Download the resnet101.pth file.
- Replace /path/to/resnet101.pth at the top of the u2pl/models/resnet.py file, in the model_urls = {"resnet101": "/path/to/resnet101.pth"} variable, with the file path of resnet101.pth.
- Replace the relative paths for the data_root and data_list values in the configuration files with explicit paths.
- Edit the TODOs in each sbatch shell script.
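The path replacement in the configuration files can also be scripted. The sketch below is illustrative only: it assumes the configuration has already been loaded into a nested dictionary (for example with a YAML loader), and the function name `make_paths_absolute` and the `base_dir` argument are hypothetical. Only the keys `data_root` and `data_list` come from the steps above.

```python
import os

# Keys named in the setup steps whose values should become explicit paths.
PATH_KEYS = ("data_root", "data_list")


def make_paths_absolute(config, base_dir):
    """Return a copy of a nested config with data_root / data_list made absolute.

    Relative values are resolved against base_dir; absolute values are kept.
    """
    if isinstance(config, dict):
        result = {}
        for key, value in config.items():
            if key in PATH_KEYS and isinstance(value, str):
                result[key] = os.path.abspath(os.path.join(base_dir, value))
            else:
                result[key] = make_paths_absolute(value, base_dir)
        return result
    if isinstance(config, list):
        return [make_paths_absolute(item, base_dir) for item in config]
    return config  # numbers, strings under other keys, None, ...


if __name__ == "__main__":
    # Hypothetical config fragment; the real files may be structured differently.
    example = {"dataset": {"train": {"data_root": "data/VOC2012"}}}
    print(make_paths_absolute(example, os.getcwd()))
```

The result would still need to be written back to the .yaml file; editing the two values by hand works just as well.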
The original U2PL model used a batch size of 16. The replicated U2PL model could use a batch size of at most 14 before running out of memory on an 80 GB A100 GPU.
U2PL
├───data
│   └───splits
│       ├───cityscapes
│       └───pascal
├───experiments
│   ├───cityscapes
│   │   └───744
│   │       ├───ours
│   │       └───suponly
│   └───pascal
│       └───1464
│           ├───ours
│           └───suponly
├───pytorch_utils
└───u2pl
    ├───dataset
    ├───models
    └───utils
- data/splits contains all labeled.txt and unlabeled.txt splits.
- experiments/pascal/1464/ours/config.yaml contains the config file for the semi-supervised model using the PASCAL VOC dataset. Follow a similar structure to access the config files for Cityscapes.
- pytorch_utils/lr_scheduler contains the learning rate scheduler with early stopping.
- pytorch_utils/metadata.py is a tracker for metadata such as training accuracy, learning rate, loss, etc.
- u2pl/dataset/pascal_voc.py is the Dataset class for the PASCAL VOC dataset. The Cityscapes dataset follows a similar structure.
- config: specify file path for configuration file (.yaml)
- seed: set to 2 in original U2PL model
- output_dirpath: specify file path for the output directory, which receives the plots of tracked parameters and a copy of the trained model's state dictionary
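For illustration, the three arguments above could be declared with argparse roughly as follows. This is a sketch, not the repository's training script: the `--` flag spellings, the function name `build_train_parser`, and the default seed of 2 are assumptions based on the list above.

```python
import argparse


def build_train_parser():
    """Build a parser for the three training arguments (illustrative only)."""
    parser = argparse.ArgumentParser(
        description="Train the modified U2PL model (sketch)."
    )
    parser.add_argument(
        "--config", type=str, required=True,
        help="Path to the .yaml configuration file",
    )
    parser.add_argument(
        "--seed", type=int, default=2,
        help="Random seed; the original U2PL model used 2 (default is an assumption)",
    )
    parser.add_argument(
        "--output_dirpath", type=str, required=True,
        help="Directory for metric plots and the saved model state dictionary",
    )
    return parser


if __name__ == "__main__":
    args = build_train_parser().parse_args()
    print(args.config, args.seed, args.output_dirpath)
```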
The original U2PL model uses intersection over union (IoU) as its benchmark for accuracy. In the modified U2PL model, IoU remains
the benchmark for accuracy, but other accuracy metrics are also tracked. All plots for the accuracy metrics, along with a csv file
of all tracked metrics, are generated in the folder specified by the output_dirpath argument.
- IoU: intersection over union. Tracked only for validation.
- accuracy: number of pixels classified correctly / total number of pixels in image. Tracked for both training and validation.
- ARI: adjusted Rand index
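To make the three metrics concrete, the sketch below computes each of them on flattened label lists in plain Python. It is not the repository's implementation: the function names are hypothetical, ignore labels are not handled, and the project itself may accumulate intersections and unions over the whole validation set instead of per image.

```python
from collections import Counter
from math import comb


def pixel_accuracy(pred, target):
    """Fraction of pixels whose predicted label equals the ground truth."""
    correct = sum(p == t for p, t in zip(pred, target))
    return correct / len(target)


def mean_iou(pred, target, num_classes):
    """Mean over classes of intersection / union, skipping classes absent from both."""
    ious = []
    for c in range(num_classes):
        intersection = sum(p == c and t == c for p, t in zip(pred, target))
        union = sum(p == c or t == c for p, t in zip(pred, target))
        if union > 0:
            ious.append(intersection / union)
    return sum(ious) / len(ious)


def adjusted_rand_index(pred, target):
    """Adjusted Rand index computed from the contingency table of two labelings."""
    total_pairs = comb(len(target), 2)
    if total_pairs == 0:
        return 1.0  # fewer than two pixels: labelings agree trivially
    sum_cells = sum(comb(n, 2) for n in Counter(zip(pred, target)).values())
    sum_pred = sum(comb(n, 2) for n in Counter(pred).values())
    sum_target = sum(comb(n, 2) for n in Counter(target).values())
    expected = sum_pred * sum_target / total_pairs
    max_index = (sum_pred + sum_target) / 2
    if max_index == expected:
        return 1.0  # degenerate case, e.g. a single cluster in both labelings
    return (sum_cells - expected) / (max_index - expected)


if __name__ == "__main__":
    pred = [0, 0, 1, 1]
    target = [0, 0, 1, 2]
    print(pixel_accuracy(pred, target))       # 0.75
    print(mean_iou(pred, target, 3))          # 0.5
    print(adjusted_rand_index(pred, target))  # 4/7, about 0.571
```

Unlike accuracy and IoU, the ARI does not depend on which integer is assigned to which class, so a prediction with permuted labels still scores 1.0.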
Each sbatch file is located in the main U2PL project directory. Edit the TODOs before running.
Use the infer.py file with the following arguments:
- config: specify file path for configuration file (.yaml) used during training
- model_path: path to the model-state-dict.pt file located in the output_dirpath folder (3rd input argument for model training)
- save_folder: path to the folder into which inference images will be saved
To compare to the original U2PL model results, download the model checkpoints from the U2PL GitHub README file.
The pytorch_utils folder, which contains plateau_scheduler (used to stop training early if the learning rate plateaus) and training_stats (used to track metadata), was taken from Michael Majursky's https://github.com/usnistgov/semantic-segmentation-unet/tree/pytorch.
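The general idea of plateau-based early stopping can be sketched in a few lines. The class below is a hypothetical illustration and not the code in pytorch_utils: it lowers the learning rate after `patience` epochs without improvement in validation loss and signals a stop once the rate has already been lowered `max_reductions` times. All names and default values are assumptions.

```python
class PlateauEarlyStopper:
    """Hypothetical helper: reduce the learning rate on plateau, then stop."""

    def __init__(self, lr, factor=0.1, patience=2, max_reductions=2):
        self.lr = lr
        self.factor = factor                  # multiplier applied on each reduction
        self.patience = patience              # epochs to wait without improvement
        self.max_reductions = max_reductions  # reductions allowed before stopping
        self.best_loss = float("inf")
        self.epochs_without_improvement = 0
        self.reductions = 0

    def step(self, val_loss):
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best_loss:
            self.best_loss = val_loss
            self.epochs_without_improvement = 0
            return False
        self.epochs_without_improvement += 1
        if self.epochs_without_improvement >= self.patience:
            if self.reductions >= self.max_reductions:
                return True  # plateaued again after the final reduction
            self.lr *= self.factor
            self.reductions += 1
            self.epochs_without_improvement = 0
        return False


if __name__ == "__main__":
    stopper = PlateauEarlyStopper(lr=0.01, patience=1, max_reductions=1)
    for epoch, loss in enumerate([1.0, 0.8, 0.8, 0.8]):
        if stopper.step(loss):
            print(f"stopping at epoch {epoch} with lr={stopper.lr}")
            break
```

In a real training loop, the updated `lr` would still need to be copied into the optimizer's parameter groups after each reduction.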