@@ -75,6 +75,12 @@ public class Sam2 extends AbstractSamJ {
75
75
+ "from skimage import measure" + System .lineSeparator ()
76
76
+ "measure.label(np.ones((10, 10)), connectivity=1)" + System .lineSeparator ()
77
77
+ "import torch" + System .lineSeparator ()
78
+ + "device = 'cpu'" + System .lineSeparator ()
79
+ + ((!IS_APPLE_SILICON ) ? ""
80
+ : "from torch.backends import mps" + System .lineSeparator ()
81
+ + "if mps.is_built() and mps.is_available():" + System .lineSeparator ()
82
+ + " device = 'mps'" + System .lineSeparator ())
83
+ + "print(device)" + System .lineSeparator ()
78
84
+ "from scipy.ndimage import binary_fill_holes" + System .lineSeparator ()
79
85
+ "from scipy.ndimage import label" + System .lineSeparator ()
80
86
+ "import sys" + System .lineSeparator ()
@@ -83,7 +89,7 @@ public class Sam2 extends AbstractSamJ {
83
89
+ "from sam2.build_sam import build_sam2" + System .lineSeparator ()
84
90
+ "from sam2.sam2_image_predictor import SAM2ImagePredictor" + System .lineSeparator ()
85
91
+ "from sam2.utils.misc import variant_to_config_mapping" + System .lineSeparator ()
86
- + "model = build_sam2(variant_to_config_mapping['%s'],r'%s')" + System .lineSeparator ()
92
+ + "model = build_sam2(variant_to_config_mapping['%s'],r'%s').to(device) " + System .lineSeparator ()
87
93
+ "predictor = SAM2ImagePredictor(model)" + System .lineSeparator ()
88
94
+ "task.update('created predictor')" + System .lineSeparator ()
89
95
+ "encodings_map = {}" + System .lineSeparator ()
@@ -94,7 +100,8 @@ public class Sam2 extends AbstractSamJ {
94
100
+ "globals()['torch'] = torch" + System .lineSeparator ()
95
101
+ "globals()['label'] = label" + System .lineSeparator ()
96
102
+ "globals()['binary_fill_holes'] = binary_fill_holes" + System .lineSeparator ()
97
- + "globals()['predictor'] = predictor" + System .lineSeparator ();
103
+ + "globals()['predictor'] = predictor" + System .lineSeparator ()
104
+ + "globals()['device'] = device" + System .lineSeparator ();
98
105
/**
99
106
* String containing the Python imports code after it has been formated with the correct
100
107
* paths and names
0 commit comments