Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 82 additions & 17 deletions beets/metadata_plugins.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,21 @@
import abc
import re
import warnings
from typing import TYPE_CHECKING, Generic, Literal, Sequence, TypedDict, TypeVar
from typing import (
TYPE_CHECKING,
Callable,
Generic,
Iterator,
Literal,
Sequence,
TypedDict,
TypeVar,
)

import unidecode
from typing_extensions import NotRequired
from typing_extensions import NotRequired, ParamSpec

from beets import logging
from beets.util import cached_classproperty
from beets.util.id_extractors import extract_release_id

Expand All @@ -25,8 +35,13 @@

from confuse import ConfigView

from .autotag import Distance
from .autotag.hooks import AlbumInfo, Item, TrackInfo
from .autotag.hooks import AlbumInfo, Distance, Item, TrackInfo

P = ParamSpec("P")
R = TypeVar("R")

# Global logger.
log = logging.getLogger("beets")


def find_metadata_source_plugins() -> list[MetadataSourcePlugin]:
Expand Down Expand Up @@ -56,17 +71,17 @@ def find_metadata_source_plugins() -> list[MetadataSourcePlugin]:


@notify_info_yielded("albuminfo_received")
def candidates(*args, **kwargs) -> Iterable[AlbumInfo]:
def candidates(*args, **kwargs) -> Iterator[AlbumInfo]:
"""Return matching album candidates from all metadata source plugins."""
for plugin in find_metadata_source_plugins():
yield from plugin.candidates(*args, **kwargs)
yield from _safe_yield_from(plugin.candidates, *args, **kwargs)


@notify_info_yielded("trackinfo_received")
def item_candidates(*args, **kwargs) -> Iterable[TrackInfo]:
"""Return matching track candidates fromm all metadata source plugins."""
"""Return matching track candidates from all metadata source plugins."""
for plugin in find_metadata_source_plugins():
yield from plugin.item_candidates(*args, **kwargs)
yield from _safe_yield_from(plugin.item_candidates, *args, **kwargs)


def album_for_id(_id: str) -> AlbumInfo | None:
Expand All @@ -75,7 +90,7 @@ def album_for_id(_id: str) -> AlbumInfo | None:
A single ID can yield just a single album, so we return the first match.
"""
for plugin in find_metadata_source_plugins():
if info := plugin.album_for_id(album_id=_id):
if info := _safe_call(plugin.album_for_id, _id):
send("albuminfo_received", info=info)
return info

Expand All @@ -88,7 +103,7 @@ def track_for_id(_id: str) -> TrackInfo | None:
A single ID can yield just a single track, so we return the first match.
"""
for plugin in find_metadata_source_plugins():
if info := plugin.track_for_id(_id):
if info := _safe_call(plugin.track_for_id, _id):
send("trackinfo_received", info=info)
return info

Expand All @@ -105,7 +120,8 @@ def track_distance(item: Item, info: TrackInfo) -> Distance:

dist = Distance()
for plugin in find_metadata_source_plugins():
dist.update(plugin.track_distance(item, info))
if distance := _safe_call(plugin.track_distance, item, info):
dist.update(distance)
return dist


Expand All @@ -119,10 +135,59 @@ def album_distance(

dist = Distance()
for plugin in find_metadata_source_plugins():
dist.update(plugin.album_distance(items, album_info, mapping))
if distance := _safe_call(
plugin.album_distance, items, album_info, mapping
):
dist.update(distance)
return dist


def _safe_call(
func: Callable[P, R], *arg: P.args, **kwargs: P.kwargs
) -> R | None:
"""Helper function to safely call plugin functions.

Wraps the function call in a try/except block and logs any exceptions
that occur.
"""

try:
return func(*arg, **kwargs)
except Exception as e:
log.error(
"Error in '{}': {}",
_class_name_from_method(func),
e,
)
log.debug("Exception details:", exc_info=True)

return None


def _safe_yield_from(
func: Callable[P, Iterable[R]], *arg: P.args, **kwargs: P.kwargs
) -> Iterable[R]:
"""Helper function to safely yield from plugin functions."""
try:
yield from func(*arg, **kwargs)
except Exception as e:
log.error(
"Error in '{}': {}",
_class_name_from_method(func),
e,
)
log.debug("Exception details:", exc_info=True)


def _class_name_from_method(func: Callable[P, R]) -> str:
"""Helper function to get the class name from a method."""
return (
func.__qualname__.split(".")[0]
if "." in func.__qualname__
else "Unknown"
)


def _get_distance(
config: ConfigView, data_source: str, info: AlbumInfo | TrackInfo
) -> Distance:
Expand Down Expand Up @@ -202,7 +267,7 @@ def item_candidates(
"""
raise NotImplementedError

def albums_for_ids(self, ids: Sequence[str]) -> Iterable[AlbumInfo | None]:
def albums_for_ids(self, ids: Iterable[str]) -> Iterable[AlbumInfo | None]:
"""Batch lookup of album metadata for a list of album IDs.

Given a list of album identifiers, yields corresponding AlbumInfo objects.
Expand All @@ -213,7 +278,7 @@ def albums_for_ids(self, ids: Sequence[str]) -> Iterable[AlbumInfo | None]:

return (self.album_for_id(id) for id in ids)

def tracks_for_ids(self, ids: Sequence[str]) -> Iterable[TrackInfo | None]:
def tracks_for_ids(self, ids: Iterable[str]) -> Iterable[TrackInfo | None]:
"""Batch lookup of track metadata for a list of track IDs.

Given a list of track identifiers, yields corresponding TrackInfo objects.
Expand Down Expand Up @@ -324,11 +389,11 @@ class SearchFilter(TypedDict):
album: NotRequired[str]


R = TypeVar("R", bound=IDResponse)
Res = TypeVar("Res", bound=IDResponse)


class SearchApiMetadataSourcePlugin(
Generic[R], MetadataSourcePlugin, metaclass=abc.ABCMeta
Generic[Res], MetadataSourcePlugin, metaclass=abc.ABCMeta
):
"""Helper class to implement a metadata source plugin with an API.

Expand All @@ -353,7 +418,7 @@ def _search_api(
query_type: Literal["album", "track"],
filters: SearchFilter,
query_string: str = "",
) -> Sequence[R]:
) -> Sequence[Res]:
"""Perform a search on the API.

:param query_type: The type of query to perform.
Expand Down
2 changes: 2 additions & 0 deletions docs/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ Bug fixes:
the config option ``deezer.search_query_ascii: yes``. :bug:`5860`
- Fixed regression with :doc:`/plugins/listenbrainz` where the plugin could not
be loaded :bug:`5975`
- Errors in metadata plugins during candidates lookup will now be logged but
won't crash beets anymore. :bug:`5903`, :bug:`4789`

For packagers:

Expand Down
70 changes: 70 additions & 0 deletions test/test_metadata_plugins.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from typing import Iterable

import pytest
from confuse import AttrDict

from beets import metadata_plugins
from beets.test.helper import PluginMixin


class ErrorMetadataMockPlugin(metadata_plugins.MetadataSourcePlugin):
"""A metadata source plugin that raises errors in all its methods."""

def candidates(self, *args, **kwargs):
raise ValueError("Mocked error")

def item_candidates(self, *args, **kwargs):
raise ValueError("Mocked error")

def album_for_id(self, *args, **kwargs):
raise ValueError("Mocked error")

def track_for_id(self, *args, **kwargs):
raise ValueError("Mocked error")

def track_distance(self, *args, **kwargs):
raise ValueError("Mocked error")

def album_distance(self, *args, **kwargs):
raise ValueError("Mocked error")


class TestMetadataPluginsException(PluginMixin):
"""Check that errors during the metadata plugins do not crash beets.
They should be logged as errors instead.
"""

@pytest.fixture(autouse=True)
def setup(self):
self.register_plugin(ErrorMetadataMockPlugin)
yield
self.unload_plugins()

@pytest.mark.parametrize(
"method_name,args",
[
("candidates", ()),
("item_candidates", ()),
("album_for_id", ("some_id",)),
("track_for_id", ("some_id",)),
("track_distance", (None, AttrDict({"data_source": "mock"}))),
("album_distance", (None, AttrDict({"data_source": "mock"}), None)),
],
)
def test_error_handling_candidates(
self,
caplog,
method_name,
args,
):
with caplog.at_level("ERROR"):
# Call the method to trigger the error
ret = getattr(metadata_plugins, method_name)(*args)
if isinstance(ret, Iterable):
list(ret)

# Check that an error was logged
assert len(caplog.records) == 1
logs = [record.getMessage() for record in caplog.records]
assert logs == ["Error in 'ErrorMetadataMockPlugin': Mocked error"]
caplog.clear()
Loading