Skip to content

Commit a1cf591

Browse files
committed
Create a new transform function to overlay teams
1 parent 97ad27e commit a1cf591

File tree

3 files changed

+269
-3
lines changed

3 files changed

+269
-3
lines changed

kloppy/domain/services/transformers/dataset.py

Lines changed: 73 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def __init__(
3737
to_coordinate_system: Optional[CoordinateSystem] = None,
3838
to_pitch_dimensions: Optional[PitchDimensions] = None,
3939
to_orientation: Optional[Orientation] = None,
40+
overlay_teams: bool = False,
4041
):
4142
if (
4243
from_pitch_dimensions
@@ -73,9 +74,13 @@ def __init__(
7374
"You must specify the source CoordinateSystem when specifying the target CoordinateSystem"
7475
)
7576
self._to_pitch_dimensions = to_coordinate_system.pitch_dimensions
77+
else:
78+
self._to_pitch_dimensions = self._from_pitch_dimensions
7679

7780
self._from_orientation = from_orientation
7881
self._to_orientation = to_orientation
82+
83+
self._overlay_teams = overlay_teams
7984
if (
8085
from_orientation
8186
and not to_orientation
@@ -185,6 +190,9 @@ def transform_frame(self, frame: Frame) -> Frame:
185190
elif self._needs_pitch_dimensions_change:
186191
frame = self.__change_frame_dimensions(frame)
187192

193+
elif self._overlay_teams:
194+
frame = self.transform_frame_overlay_teams(frame)
195+
188196
# Flip frame based on orientation
189197
if self._needs_orientation_change:
190198
if self.__needs_flip(
@@ -308,6 +316,55 @@ def __flip_frame(self, frame: Frame):
308316
statistics=frame.statistics,
309317
)
310318

319+
def _get_overlay_players_coordinates(
320+
self,
321+
player_data: PlayerData,
322+
player_team: Team,
323+
ball_owning_team: Team,
324+
attacking_direction: AttackingDirection,
325+
):
326+
if attacking_direction == AttackingDirection.RTL:
327+
if player_team != ball_owning_team:
328+
player_data.coordinates = self.flip_point(
329+
player_data.coordinates
330+
)
331+
else:
332+
if player_team == ball_owning_team:
333+
player_data.coordinates = self.flip_point(
334+
player_data.coordinates
335+
)
336+
337+
return player_data
338+
339+
def transform_frame_overlay_teams(self, frame: Frame):
340+
players_data = {
341+
player: self._get_overlay_players_coordinates(
342+
player_data,
343+
player.team,
344+
frame.ball_owning_team,
345+
frame.attacking_direction,
346+
)
347+
for player, player_data in frame.players_data.items()
348+
}
349+
350+
ball_coordinates = frame.ball_coordinates
351+
if frame.attacking_direction != AttackingDirection.RTL:
352+
ball_coordinates = self.flip_point(ball_coordinates)
353+
354+
return Frame(
355+
# doesn't change
356+
timestamp=frame.timestamp,
357+
frame_id=frame.frame_id,
358+
ball_owning_team=frame.ball_owning_team,
359+
ball_state=frame.ball_state,
360+
period=frame.period,
361+
other_data=frame.other_data,
362+
statistics=frame.statistics,
363+
# changes
364+
ball_coordinates=ball_coordinates,
365+
players_data=players_data,
366+
)
367+
311368
def transform_event(self, event: Event) -> Event:
312369
# Change coordinate system
313370
if self._needs_coordinate_system_change:
@@ -375,11 +432,13 @@ def transform_dataset(
375432
to_pitch_dimensions: Optional[PitchDimensions] = None,
376433
to_orientation: Optional[Orientation] = None,
377434
to_coordinate_system: Optional[CoordinateSystem] = None,
435+
overlay_teams: bool = False,
378436
) -> Dataset:
379437
if (
380438
to_pitch_dimensions is None
381439
and to_orientation is None
382440
and to_coordinate_system is None
441+
and overlay_teams is False
383442
):
384443
return dataset
385444

@@ -391,8 +450,20 @@ def transform_dataset(
391450
"Cannot transform to BALL_OWNING_TEAM orientation when "
392451
"dataset doesn't contain ball owning team data"
393452
)
394-
395-
if to_pitch_dimensions is not None:
453+
if overlay_teams:
454+
transformer = cls(
455+
from_pitch_dimensions=dataset.metadata.pitch_dimensions,
456+
from_orientation=dataset.metadata.orientation,
457+
to_orientation=to_orientation,
458+
to_pitch_dimensions=to_pitch_dimensions,
459+
overlay_teams=overlay_teams,
460+
)
461+
metadata = replace(
462+
dataset.metadata,
463+
pitch_dimensions=to_pitch_dimensions,
464+
orientation=to_orientation,
465+
)
466+
elif to_pitch_dimensions is not None:
396467
# Transform the pitch dimensions and optionally the orientation
397468
transformer = cls(
398469
from_pitch_dimensions=dataset.metadata.pitch_dimensions,
@@ -418,7 +489,6 @@ def transform_dataset(
418489
dataset.metadata,
419490
coordinate_system=to_coordinate_system,
420491
pitch_dimensions=to_coordinate_system.pitch_dimensions,
421-
orientation=to_orientation,
422492
)
423493

424494
else:

kloppy/helpers.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def transform(
1818
to_coordinate_system: Optional[
1919
Union[CoordinateSystem, Provider, str]
2020
] = None,
21+
overlay_teams: bool = False,
2122
) -> Dataset:
2223
# convert raw orientation to object
2324
if to_orientation is not None and isinstance(to_orientation, str):
@@ -43,4 +44,5 @@ def transform(
4344
to_orientation=to_orientation,
4445
to_coordinate_system=to_coordinate_system,
4546
to_pitch_dimensions=to_pitch_dimensions,
47+
overlay_teams=overlay_teams,
4648
)

kloppy/tests/test_helpers.py

Lines changed: 194 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,88 @@
2929

3030

3131
class TestHelpers:
32+
33+
def _get_tracking_dataset_multiple_players(self):
34+
home_team = Team(team_id="home", name="home", ground=Ground.HOME)
35+
away_team = Team(team_id="away", name="away", ground=Ground.AWAY)
36+
teams = [home_team, away_team]
37+
38+
periods = [
39+
Period(
40+
id=1,
41+
start_timestamp=0.0,
42+
end_timestamp=10.0,
43+
),
44+
Period(
45+
id=2,
46+
start_timestamp=15.0,
47+
end_timestamp=25.0,
48+
),
49+
]
50+
metadata = Metadata(
51+
flags=(DatasetFlag.BALL_OWNING_TEAM),
52+
pitch_dimensions=NormalizedPitchDimensions(
53+
x_dim=Dimension(0, 100),
54+
y_dim=Dimension(-50, 50),
55+
pitch_length=105,
56+
pitch_width=68,
57+
),
58+
orientation=Orientation.HOME_AWAY,
59+
frame_rate=25,
60+
periods=periods,
61+
teams=teams,
62+
score=None,
63+
provider=None,
64+
coordinate_system=None,
65+
date="2024-05-19T13:30:00",
66+
game_week="35",
67+
game_id="2374516",
68+
)
69+
70+
tracking_data = TrackingDataset(
71+
metadata=metadata,
72+
records=[
73+
create_frame(
74+
frame_id=1,
75+
timestamp=0.1,
76+
ball_owning_team=teams[0],
77+
ball_state=None,
78+
period=periods[0],
79+
players_data={},
80+
other_data=None,
81+
ball_coordinates=Point3D(x=100, y=-50, z=0),
82+
),
83+
create_frame(
84+
frame_id=2,
85+
timestamp=0.2,
86+
ball_owning_team=teams[1],
87+
ball_state=None,
88+
period=periods[1],
89+
players_data={
90+
Player(
91+
team=home_team, player_id="home_1", jersey_no=1
92+
): PlayerData(
93+
coordinates=Point(x=15, y=35),
94+
distance=0.03,
95+
speed=10.5,
96+
other_data={"extra_data": 1},
97+
),
98+
Player(
99+
team=away_team, player_id="away_1", jersey_no=1
100+
): PlayerData(
101+
coordinates=Point(x=15, y=35),
102+
distance=0.03,
103+
speed=10.5,
104+
other_data={"extra_data": 1},
105+
),
106+
},
107+
other_data={"extra_data": 1},
108+
ball_coordinates=Point3D(x=0, y=50, z=1),
109+
),
110+
],
111+
)
112+
return tracking_data
113+
32114
def _get_tracking_dataset(self):
33115
home_team = Team(team_id="home", name="home", ground=Ground.HOME)
34116
away_team = Team(team_id="away", name="away", ground=Ground.AWAY)
@@ -371,6 +453,118 @@ def test_transform_event_data_freeze_frame(self, base_dir):
371453
assert coordinates.x == 1 - coordinates_transformed.x
372454
assert coordinates.y == 1 - coordinates_transformed.y
373455

456+
def test_transform_overlay_teams_frames(self, base_dir):
457+
dataset = tracab.load(
458+
meta_data=base_dir / "files/tracab_meta.xml",
459+
raw_data=base_dir / "files/tracab_raw.dat",
460+
only_alive=False,
461+
coordinates="tracab",
462+
)
463+
464+
transformed_frame = dataset.transform(overlay_teams=True)
465+
466+
## Assert frame with attacking direction right-to-left
467+
468+
home_players_coordinates = [
469+
coordinates
470+
for player, coordinates in dataset.frames[
471+
0
472+
].players_coordinates.items()
473+
if dataset.metadata.teams[0].get_player_by_id(player.player_id)
474+
]
475+
converted_home_players_coordinates = [
476+
coordinates
477+
for player, coordinates in transformed_frame.frames[
478+
0
479+
].players_coordinates.items()
480+
if dataset.metadata.teams[0].get_player_by_id(player.player_id)
481+
]
482+
483+
away_players_coordinates = [
484+
coordinates
485+
for player, coordinates in dataset.frames[
486+
0
487+
].players_coordinates.items()
488+
if dataset.metadata.teams[1].get_player_by_id(player.player_id)
489+
]
490+
converted_away_players_coordinates = [
491+
coordinates
492+
for player, coordinates in transformed_frame.frames[
493+
0
494+
].players_coordinates.items()
495+
if dataset.metadata.teams[1].get_player_by_id(player.player_id)
496+
]
497+
498+
assert all(
499+
[a == b]
500+
for a, b in zip(
501+
home_players_coordinates, converted_home_players_coordinates
502+
)
503+
)
504+
505+
assert all(
506+
[a != b]
507+
for a, b in zip(
508+
away_players_coordinates, converted_away_players_coordinates
509+
)
510+
)
511+
512+
assert (
513+
dataset.frames[0].ball_coordinates
514+
== transformed_frame.frames[0].ball_coordinates
515+
)
516+
517+
## Assert frame with attacking direction with left-to-right
518+
519+
home_players_coordinates = [
520+
coordinates
521+
for player, coordinates in dataset.frames[
522+
4
523+
].players_coordinates.items()
524+
if dataset.metadata.teams[0].get_player_by_id(player.player_id)
525+
]
526+
converted_home_players_coordinates = [
527+
coordinates
528+
for player, coordinates in transformed_frame.frames[
529+
4
530+
].players_coordinates.items()
531+
if dataset.metadata.teams[0].get_player_by_id(player.player_id)
532+
]
533+
534+
away_players_coordinates = [
535+
coordinates
536+
for player, coordinates in dataset.frames[
537+
4
538+
].players_coordinates.items()
539+
if dataset.metadata.teams[1].get_player_by_id(player.player_id)
540+
]
541+
converted_away_players_coordinates = [
542+
coordinates
543+
for player, coordinates in transformed_frame.frames[
544+
4
545+
].players_coordinates.items()
546+
if dataset.metadata.teams[1].get_player_by_id(player.player_id)
547+
]
548+
549+
assert all(
550+
[a != b]
551+
for a, b in zip(
552+
home_players_coordinates, converted_home_players_coordinates
553+
)
554+
)
555+
556+
assert all(
557+
[a == b]
558+
for a, b in zip(
559+
away_players_coordinates, converted_away_players_coordinates
560+
)
561+
)
562+
563+
assert (
564+
dataset.frames[4].ball_coordinates
565+
!= transformed_frame.frames[4].ball_coordinates
566+
)
567+
374568
def test_to_pandas(self):
375569
tracking_data = self._get_tracking_dataset()
376570

0 commit comments

Comments
 (0)