removed hard coded class count #15

Open · wants to merge 4 commits into base: main
3 changes: 2 additions & 1 deletion training/example_train_script.sh
@@ -8,4 +8,5 @@ python main_2pt5d.py --max_epochs 100 --val_every 1 --optim_lr 0.000005 \
 --logdir finetune_ckpt_example --point_prompt --label_prompt --distributed --seed 12346 \
 --iterative_training_warm_up_epoch 50 --reuse_img_embedding \
 --label_prompt_warm_up_epoch 25 \
---checkpoint ./runs/9s_2dembed_model.pt
+--checkpoint ./runs/9s_2dembed_model.pt \
+--num_classes 105
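Note on semantics: the loops this flag feeds iterate over range(1, num_classes), so --num_classes appears to include the background class; the previous hard-coded values (105 in the loops, 104 in the metric denominator) correspond to 104 foreground labels plus background. A dataset with 14 foreground labels, for example, would pass --num_classes 15.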
8 changes: 7 additions & 1 deletion training/main_2pt5d.py
@@ -113,6 +113,8 @@
 parser.add_argument("--skip_bk", action="store_true", help="skip background (0) during training")
 parser.add_argument("--patch_embed_3d", action="store_true", help="using 3d patch embedding layer")
 
+parser.add_argument("--num_classes", default=105, type=int, help="number of output classes")
+
 
 def start_tb(log_dir):
     cmd = ["tensorboard", "--logdir", log_dir]
@@ -123,6 +125,10 @@ def main():
     args = parser.parse_args()
     args.amp = not args.noamp
     args.logdir = "./runs/" + args.logdir
+
+    if args.num_classes == 0:
+        warnings.warn("consider setting the correct number of classes")
+
     # start_tb(args.logdir)
     if args.seed > -1:
         set_determinism(seed=args.seed)
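A minimal, self-contained sketch of the new parse-and-validate flow (a toy example, not the PR's code verbatim; it assumes main_2pt5d.py already imports warnings, otherwise that import still needs to be added):

import argparse
import warnings

parser = argparse.ArgumentParser()
parser.add_argument("--num_classes", default=105, type=int, help="number of output classes")

args = parser.parse_args(["--num_classes", "0"])  # simulate a misconfigured run
if args.num_classes == 0:
    # mirrors the guard added in main(): warn loudly rather than train
    # with an impossible class count
    warnings.warn("consider setting the correct number of classes")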
@@ -162,7 +168,7 @@ def main_worker(gpu, args):
 
     dice_loss = DiceCELoss(sigmoid=True)
 
-    post_label = AsDiscrete(to_onehot=105)
+    post_label = AsDiscrete(to_onehot=args.num_classes)
     post_pred = Compose([Activations(sigmoid=True), AsDiscrete(threshold=0.5)])
     dice_acc = DiceMetric(include_background=False, reduction=MetricReduction.MEAN, get_not_nans=True)
 
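For context, MONAI's AsDiscrete(to_onehot=N) expands a channel-first label map into N binary channels, which is why the channel count must track the dataset's class count. A toy sketch (shapes invented for illustration, not the PR's data):

import torch
from monai.transforms import AsDiscrete

post_label = AsDiscrete(to_onehot=3)      # 3 classes for the toy example
label = torch.tensor([[[0, 1], [2, 1]]])  # (1, H, W) channel-first label map
onehot = post_label(label)
print(onehot.shape)  # torch.Size([3, 2, 2]): one binary channel per class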
15 changes: 9 additions & 6 deletions training/trainer_2pt5d.py
@@ -129,7 +129,7 @@ def prepare_sam_training_input(inputs, labels, args, model):
     unique_labels = unique_labels[: args.num_prompt]
 
     # add 4 background labels to every batch
-    background_labels = list(set([i for i in range(1, 105)]) - set(unique_labels.cpu().numpy()))
+    background_labels = list(set([i for i in range(1, args.num_classes)]) - set(unique_labels.cpu().numpy()))
     random.shuffle(background_labels)
     unique_labels = torch.cat([unique_labels, torch.tensor(background_labels[:4]).cuda(args.rank)])
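The set difference picks prompt IDs for classes absent from the current batch, and with the flag it now covers exactly the configured label range. A standalone sketch with small, hypothetical numbers (CPU-only, no .cuda() call):

import random
import torch

num_classes = 6                       # 5 foreground labels (1..5) plus background
unique_labels = torch.tensor([2, 4])  # classes present in this batch

# classes 1..num_classes-1 that are NOT in the batch become candidate
# "background" prompts; up to 4 of them are appended after shuffling
background_labels = list(set(range(1, num_classes)) - set(unique_labels.numpy()))
random.shuffle(background_labels)
unique_labels = torch.cat([unique_labels, torch.tensor(background_labels[:4])])
print(sorted(unique_labels.tolist()))  # [1, 2, 3, 4, 5]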

@@ -375,7 +375,7 @@ def train_epoch_iterative(model, loader, optimizer, scaler, epoch, loss_func, ar
 
 
 def prepare_sam_test_input(inputs, labels, args, previous_pred=None):
-    unique_labels = torch.tensor([i for i in range(1, 105)]).cuda(args.rank)
+    unique_labels = torch.tensor([i for i in range(1, args.num_classes)]).cuda(args.rank)
 
     # preprocess make the size of lable same as high_res_logit
     batch_labels = torch.stack([labels == unique_labels[i] for i in range(len(unique_labels))], dim=0).float()
@@ -400,7 +400,7 @@ def prepare_sam_test_input(inputs, labels, args, previous_pred=None):
 
 def prepare_sam_val_input_cp_only(inputs, labels, args):
     # Don't exclude background in val but will ignore it in metric calculation
-    unique_labels = torch.tensor([i for i in range(1, 105)]).cuda(args.rank)
+    unique_labels = torch.tensor([i for i in range(1, args.num_classes)]).cuda(args.rank)
 
     # preprocess make the size of lable same as high_res_logit
     batch_labels = torch.stack([labels == unique_labels[i] for i in range(len(unique_labels))], dim=0).float()
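Both the test and validation paths then expand the label map into one binary mask per class. A toy sketch of that stacking (shapes invented for illustration):

import torch

labels = torch.tensor([[0, 1], [3, 2]])  # (H, W) integer label map
num_classes = 4
unique_labels = torch.arange(1, num_classes)  # foreground classes 1..3

# one (H, W) binary mask per foreground class, stacked on a new leading dim
batch_labels = torch.stack([labels == c for c in unique_labels], dim=0).float()
print(batch_labels.shape)  # torch.Size([3, 2, 2])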
@@ -457,15 +457,18 @@ def val_epoch(model, loader, epoch, acc_func, args, iterative=False, post_label=
             y_pred = torch.stack(post_pred(decollate_batch(logit)), 0)
 
             # TODO: we compute metric for each prompt for simplicity in validation.
-            acc_batch = compute_dice(y_pred=y_pred, y=target)
+            acc_batch = compute_dice(y_pred=y_pred[None,], y=target[None,])
             acc_sum, not_nans = (
                 torch.nansum(acc_batch).item(),
-                104 - torch.sum(torch.isnan(acc_batch).float()).item(),
+                (args.num_classes - 1) - torch.sum(torch.isnan(acc_batch).float()).item(),
             )
             acc_sum_total += acc_sum
             not_nans_total += not_nans
 
-        acc, not_nans = acc_sum_total / not_nans_total, not_nans_total
+        if not_nans_total > 0:
+            acc, not_nans = acc_sum_total / not_nans_total, not_nans_total
+        else:
+            acc, not_nans = 0, 0
         f_name = batch_data["image"].meta["filename_or_obj"]
         print(f"Rank: {args.rank}, Case: {f_name}, Acc: {acc:.4f}, N_prompts: {int(not_nans)} ")
