|
| 1 | +import os |
| 2 | +import sys |
1 | 3 | from pathlib import Path
|
2 | 4 | from typing import Dict, Optional, Tuple
|
3 | 5 |
|
|
7 | 9 |
|
8 | 10 | from onepiece_classify.models import image_recog
|
9 | 11 | from onepiece_classify.transforms import get_test_transforms
|
| 12 | +from onepiece_classify.utils import downloader |
10 | 13 |
|
11 | 14 | from .base import BaseInference
|
12 | 15 |
|
13 | 16 |
|
14 | 17 | class ImageRecognition(BaseInference):
|
15 |
| - def __init__(self, model_path: str, device: str): |
16 |
| - self.model_path = Path(model_path) |
| 18 | + def __init__(self, device: str, model_path=None): |
| 19 | + |
| 20 | + path_to_save = str(self._get_cache_dir()) + "/model.pth" |
| 21 | + if model_path is None: |
| 22 | + downloader(path_to_save) |
| 23 | + self.model_path = path_to_save |
| 24 | + else: |
| 25 | + self.model_path = Path(model_path) |
| 26 | + if not self.model_path.exists(): |
| 27 | + raise FileNotFoundError("Model does not exist, check your model location and read README for more information") |
| 28 | + |
17 | 29 | self.device = device
|
18 | 30 | self.class_dict = {
|
19 | 31 | 0: "Ace",
|
@@ -44,6 +56,15 @@ def _build_model(self):
|
44 | 56 | model_backbone = image_recog(self.nclass)
|
45 | 57 | model_backbone.load_state_dict(state_dict)
|
46 | 58 | return model_backbone
|
| 59 | + |
| 60 | + def _get_cache_dir(self): |
| 61 | + if sys.platform.startswith("win"): |
| 62 | + cache_dir = Path(os.getenv("LOCALAPPDATA", Path.home() / "AppData" / "Local")) / "OnepieceClassifyCache" |
| 63 | + else: |
| 64 | + cache_dir = Path.home() / ".cache" / "OnepieceClassifyCache" |
| 65 | + |
| 66 | + cache_dir.mkdir(parents=True, exist_ok=True) |
| 67 | + return cache_dir |
47 | 68 |
|
48 | 69 | def pre_process(
|
49 | 70 | self, image: Optional[str | np.ndarray | Image.Image]
|
|
0 commit comments