Skip to content

Commit b9d347c

Browse files
authored
Merge pull request #19 from lombokai/feature/downloader
Feature/downloader
2 parents b98c4eb + 9a1f0c5 commit b9d347c

File tree

4 files changed

+39
-3
lines changed

4 files changed

+39
-3
lines changed

predict.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
def predict_image(
1010
image_path: str = typer.Argument(help="image path", show_default=True),
11-
model_path: str = typer.Argument("checkpoint/checkpoint_notebook.pth", help="model path (pth)", show_default=True),
11+
model_path: str = typer.Argument(None, help="path to your model (pth)", show_default=True),
1212
device: str = typer.Argument("cpu", help="use cuda if your device has cuda", show_default=True)
1313
):
1414
predictor = ImageRecognition(model_path=model_path, device=device)

src/onepiece_classify/infer/recognition.py

+23-2
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import os
2+
import sys
13
from pathlib import Path
24
from typing import Dict, Optional, Tuple
35

@@ -7,13 +9,23 @@
79

810
from onepiece_classify.models import image_recog
911
from onepiece_classify.transforms import get_test_transforms
12+
from onepiece_classify.utils import downloader
1013

1114
from .base import BaseInference
1215

1316

1417
class ImageRecognition(BaseInference):
15-
def __init__(self, model_path: str, device: str):
16-
self.model_path = Path(model_path)
18+
def __init__(self, device: str, model_path=None):
19+
20+
path_to_save = str(self._get_cache_dir()) + "/model.pth"
21+
if model_path is None:
22+
downloader(path_to_save)
23+
self.model_path = path_to_save
24+
else:
25+
self.model_path = Path(model_path)
26+
if not self.model_path.exists():
27+
raise FileNotFoundError("Model does not exist, check your model location and read README for more information")
28+
1729
self.device = device
1830
self.class_dict = {
1931
0: "Ace",
@@ -44,6 +56,15 @@ def _build_model(self):
4456
model_backbone = image_recog(self.nclass)
4557
model_backbone.load_state_dict(state_dict)
4658
return model_backbone
59+
60+
def _get_cache_dir(self):
61+
if sys.platform.startswith("win"):
62+
cache_dir = Path(os.getenv("LOCALAPPDATA", Path.home() / "AppData" / "Local")) / "OnepieceClassifyCache"
63+
else:
64+
cache_dir = Path.home() / ".cache" / "OnepieceClassifyCache"
65+
66+
cache_dir.mkdir(parents=True, exist_ok=True)
67+
return cache_dir
4768

4869
def pre_process(
4970
self, image: Optional[str | np.ndarray | Image.Image]
+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .downloader import *
+14
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import gdown
2+
from pathlib import Path
3+
4+
5+
def downloader(output_file):
6+
file_id = "1M1-1Hs198XDD6Xx-kSWLThv1elZBzJ0j"
7+
prefix = 'https://drive.google.com/uc?/export=download&id='
8+
9+
url_download = prefix+file_id
10+
11+
if not Path(output_file).exists():
12+
print("Downloading...")
13+
gdown.download(url_download, output_file)
14+
print("Download Finish...")

0 commit comments

Comments
 (0)