Skip to content

Commit 43d4467

Browse files
committed
start using apple silicon when available
1 parent 8852ace commit 43d4467

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

src/main/java/ai/nets/samj/models/AbstractSamJ.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,9 @@ public abstract class AbstractSamJ implements AutoCloseable {
7979
protected static String UPDATE_ID_N_CONTOURS = "PROMPT_NUMBER_" + UUID.randomUUID().toString();
8080

8181
protected static String UPDATE_ID_CONTOUR = "FOUND_CONTOUR_" + UUID.randomUUID().toString();
82+
83+
protected static final boolean IS_APPLE_SILICON = PlatformDetection.isMacOS()
84+
&& PlatformDetection.getArch().equals(PlatformDetection.ARCH_ARM64);
8285

8386
public interface BatchCallback {
8487

src/main/java/ai/nets/samj/models/Sam2.java

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,12 @@ public class Sam2 extends AbstractSamJ {
7575
+ "from skimage import measure" + System.lineSeparator()
7676
+ "measure.label(np.ones((10, 10)), connectivity=1)" + System.lineSeparator()
7777
+ "import torch" + System.lineSeparator()
78+
+ "device = 'cpu'" + System.lineSeparator()
79+
+ ((!IS_APPLE_SILICON) ? ""
80+
: "from torch.backends import mps" + System.lineSeparator()
81+
+ "if mps.is_built() and mps.is_available():" + System.lineSeparator()
82+
+ " device = 'mps'" + System.lineSeparator())
83+
+ "print(device)" + System.lineSeparator()
7884
+ "from scipy.ndimage import binary_fill_holes" + System.lineSeparator()
7985
+ "from scipy.ndimage import label" + System.lineSeparator()
8086
+ "import sys" + System.lineSeparator()
@@ -83,7 +89,7 @@ public class Sam2 extends AbstractSamJ {
8389
+ "from sam2.build_sam import build_sam2" + System.lineSeparator()
8490
+ "from sam2.sam2_image_predictor import SAM2ImagePredictor" + System.lineSeparator()
8591
+ "from sam2.utils.misc import variant_to_config_mapping" + System.lineSeparator()
86-
+ "model = build_sam2(variant_to_config_mapping['%s'],r'%s')" + System.lineSeparator()
92+
+ "model = build_sam2(variant_to_config_mapping['%s'],r'%s').to(device)" + System.lineSeparator()
8793
+ "predictor = SAM2ImagePredictor(model)" + System.lineSeparator()
8894
+ "task.update('created predictor')" + System.lineSeparator()
8995
+ "encodings_map = {}" + System.lineSeparator()
@@ -94,7 +100,8 @@ public class Sam2 extends AbstractSamJ {
94100
+ "globals()['torch'] = torch" + System.lineSeparator()
95101
+ "globals()['label'] = label" + System.lineSeparator()
96102
+ "globals()['binary_fill_holes'] = binary_fill_holes" + System.lineSeparator()
97-
+ "globals()['predictor'] = predictor" + System.lineSeparator();
103+
+ "globals()['predictor'] = predictor" + System.lineSeparator()
104+
+ "globals()['device'] = device" + System.lineSeparator();
98105
/**
99106
* String containing the Python imports code after it has been formated with the correct
100107
* paths and names

0 commit comments

Comments
 (0)