Skip to content

Commit 3b6903a

Browse files
committed
Create a new transform function to overlay teams
1 parent 97ad27e commit 3b6903a

File tree

3 files changed

+187
-3
lines changed

3 files changed

+187
-3
lines changed

kloppy/domain/services/transformers/dataset.py

Lines changed: 73 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def __init__(
3737
to_coordinate_system: Optional[CoordinateSystem] = None,
3838
to_pitch_dimensions: Optional[PitchDimensions] = None,
3939
to_orientation: Optional[Orientation] = None,
40+
overlay_teams: bool = False,
4041
):
4142
if (
4243
from_pitch_dimensions
@@ -73,9 +74,13 @@ def __init__(
7374
"You must specify the source CoordinateSystem when specifying the target CoordinateSystem"
7475
)
7576
self._to_pitch_dimensions = to_coordinate_system.pitch_dimensions
77+
else:
78+
self._to_pitch_dimensions = self._from_pitch_dimensions
7679

7780
self._from_orientation = from_orientation
7881
self._to_orientation = to_orientation
82+
83+
self._overlay_teams = overlay_teams
7984
if (
8085
from_orientation
8186
and not to_orientation
@@ -185,6 +190,9 @@ def transform_frame(self, frame: Frame) -> Frame:
185190
elif self._needs_pitch_dimensions_change:
186191
frame = self.__change_frame_dimensions(frame)
187192

193+
elif self._overlay_teams:
194+
frame = self.transform_frame_overlay_teams(frame)
195+
188196
# Flip frame based on orientation
189197
if self._needs_orientation_change:
190198
if self.__needs_flip(
@@ -308,6 +316,55 @@ def __flip_frame(self, frame: Frame):
308316
statistics=frame.statistics,
309317
)
310318

319+
def _get_overlay_players_coordinates(
320+
self,
321+
player_data: PlayerData,
322+
player_team: Team,
323+
ball_owning_team: Team,
324+
attacking_direction: AttackingDirection,
325+
):
326+
if attacking_direction == AttackingDirection.RTL:
327+
if player_team != ball_owning_team:
328+
player_data.coordinates = self.flip_point(
329+
player_data.coordinates
330+
)
331+
else:
332+
if player_team == ball_owning_team:
333+
player_data.coordinates = self.flip_point(
334+
player_data.coordinates
335+
)
336+
337+
return player_data
338+
339+
def transform_frame_overlay_teams(self, frame: Frame):
340+
players_data = {
341+
player: self._get_overlay_players_coordinates(
342+
player_data,
343+
player.team,
344+
frame.ball_owning_team,
345+
frame.attacking_direction,
346+
)
347+
for player, player_data in frame.players_data.items()
348+
}
349+
350+
ball_coordinates = frame.ball_coordinates
351+
if frame.attacking_direction != AttackingDirection.RTL:
352+
ball_coordinates = self.flip_point(ball_coordinates)
353+
354+
return Frame(
355+
# doesn't change
356+
timestamp=frame.timestamp,
357+
frame_id=frame.frame_id,
358+
ball_owning_team=frame.ball_owning_team,
359+
ball_state=frame.ball_state,
360+
period=frame.period,
361+
other_data=frame.other_data,
362+
statistics=frame.statistics,
363+
# changes
364+
ball_coordinates=ball_coordinates,
365+
players_data=players_data,
366+
)
367+
311368
def transform_event(self, event: Event) -> Event:
312369
# Change coordinate system
313370
if self._needs_coordinate_system_change:
@@ -375,11 +432,13 @@ def transform_dataset(
375432
to_pitch_dimensions: Optional[PitchDimensions] = None,
376433
to_orientation: Optional[Orientation] = None,
377434
to_coordinate_system: Optional[CoordinateSystem] = None,
435+
overlay_teams: bool = False,
378436
) -> Dataset:
379437
if (
380438
to_pitch_dimensions is None
381439
and to_orientation is None
382440
and to_coordinate_system is None
441+
and overlay_teams is False
383442
):
384443
return dataset
385444

@@ -391,8 +450,20 @@ def transform_dataset(
391450
"Cannot transform to BALL_OWNING_TEAM orientation when "
392451
"dataset doesn't contain ball owning team data"
393452
)
394-
395-
if to_pitch_dimensions is not None:
453+
if overlay_teams:
454+
transformer = cls(
455+
from_pitch_dimensions=dataset.metadata.pitch_dimensions,
456+
from_orientation=dataset.metadata.orientation,
457+
to_orientation=to_orientation,
458+
to_pitch_dimensions=to_pitch_dimensions,
459+
overlay_teams=overlay_teams,
460+
)
461+
metadata = replace(
462+
dataset.metadata,
463+
pitch_dimensions=to_pitch_dimensions,
464+
orientation=to_orientation,
465+
)
466+
elif to_pitch_dimensions is not None:
396467
# Transform the pitch dimensions and optionally the orientation
397468
transformer = cls(
398469
from_pitch_dimensions=dataset.metadata.pitch_dimensions,
@@ -418,7 +489,6 @@ def transform_dataset(
418489
dataset.metadata,
419490
coordinate_system=to_coordinate_system,
420491
pitch_dimensions=to_coordinate_system.pitch_dimensions,
421-
orientation=to_orientation,
422492
)
423493

424494
else:

kloppy/helpers.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def transform(
1818
to_coordinate_system: Optional[
1919
Union[CoordinateSystem, Provider, str]
2020
] = None,
21+
overlay_teams: bool = False,
2122
) -> Dataset:
2223
# convert raw orientation to object
2324
if to_orientation is not None and isinstance(to_orientation, str):
@@ -43,4 +44,5 @@ def transform(
4344
to_orientation=to_orientation,
4445
to_coordinate_system=to_coordinate_system,
4546
to_pitch_dimensions=to_pitch_dimensions,
47+
overlay_teams=overlay_teams,
4648
)

kloppy/tests/test_helpers.py

Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,118 @@ def test_transform_event_data_freeze_frame(self, base_dir):
371371
assert coordinates.x == 1 - coordinates_transformed.x
372372
assert coordinates.y == 1 - coordinates_transformed.y
373373

374+
def test_transform_overlay_teams_frames(self, base_dir):
375+
dataset = tracab.load(
376+
meta_data=base_dir / "files/tracab_meta.xml",
377+
raw_data=base_dir / "files/tracab_raw.dat",
378+
only_alive=False,
379+
coordinates="tracab",
380+
)
381+
382+
transformed_frame = dataset.transform(overlay_teams=True)
383+
384+
## Assert frame with attacking direction right-to-left
385+
386+
home_players_coordinates = [
387+
coordinates
388+
for player, coordinates in dataset.frames[
389+
0
390+
].players_coordinates.items()
391+
if dataset.metadata.teams[0].get_player_by_id(player.player_id)
392+
]
393+
converted_home_players_coordinates = [
394+
coordinates
395+
for player, coordinates in transformed_frame.frames[
396+
0
397+
].players_coordinates.items()
398+
if dataset.metadata.teams[0].get_player_by_id(player.player_id)
399+
]
400+
401+
away_players_coordinates = [
402+
coordinates
403+
for player, coordinates in dataset.frames[
404+
0
405+
].players_coordinates.items()
406+
if dataset.metadata.teams[1].get_player_by_id(player.player_id)
407+
]
408+
converted_away_players_coordinates = [
409+
coordinates
410+
for player, coordinates in transformed_frame.frames[
411+
0
412+
].players_coordinates.items()
413+
if dataset.metadata.teams[1].get_player_by_id(player.player_id)
414+
]
415+
416+
assert all(
417+
[a == b]
418+
for a, b in zip(
419+
home_players_coordinates, converted_home_players_coordinates
420+
)
421+
)
422+
423+
assert all(
424+
[a != b]
425+
for a, b in zip(
426+
away_players_coordinates, converted_away_players_coordinates
427+
)
428+
)
429+
430+
assert (
431+
dataset.frames[0].ball_coordinates
432+
== transformed_frame.frames[0].ball_coordinates
433+
)
434+
435+
## Assert frame with attacking direction with left-to-right
436+
437+
home_players_coordinates = [
438+
coordinates
439+
for player, coordinates in dataset.frames[
440+
4
441+
].players_coordinates.items()
442+
if dataset.metadata.teams[0].get_player_by_id(player.player_id)
443+
]
444+
converted_home_players_coordinates = [
445+
coordinates
446+
for player, coordinates in transformed_frame.frames[
447+
4
448+
].players_coordinates.items()
449+
if dataset.metadata.teams[0].get_player_by_id(player.player_id)
450+
]
451+
452+
away_players_coordinates = [
453+
coordinates
454+
for player, coordinates in dataset.frames[
455+
4
456+
].players_coordinates.items()
457+
if dataset.metadata.teams[1].get_player_by_id(player.player_id)
458+
]
459+
converted_away_players_coordinates = [
460+
coordinates
461+
for player, coordinates in transformed_frame.frames[
462+
4
463+
].players_coordinates.items()
464+
if dataset.metadata.teams[1].get_player_by_id(player.player_id)
465+
]
466+
467+
assert all(
468+
[a != b]
469+
for a, b in zip(
470+
home_players_coordinates, converted_home_players_coordinates
471+
)
472+
)
473+
474+
assert all(
475+
[a == b]
476+
for a, b in zip(
477+
away_players_coordinates, converted_away_players_coordinates
478+
)
479+
)
480+
481+
assert (
482+
dataset.frames[4].ball_coordinates
483+
!= transformed_frame.frames[4].ball_coordinates
484+
)
485+
374486
def test_to_pandas(self):
375487
tracking_data = self._get_tracking_dataset()
376488

0 commit comments

Comments
 (0)