Skip to content

Commit 85f987a

Browse files
committed
Create a new transform function to overlay teams
1 parent 97ad27e commit 85f987a

File tree

3 files changed

+178
-3
lines changed

3 files changed

+178
-3
lines changed

kloppy/domain/services/transformers/dataset.py

Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ def __init__(
3737
to_coordinate_system: Optional[CoordinateSystem] = None,
3838
to_pitch_dimensions: Optional[PitchDimensions] = None,
3939
to_orientation: Optional[Orientation] = None,
40+
overlay_teams: bool = False
4041
):
4142
if (
4243
from_pitch_dimensions
@@ -73,9 +74,13 @@ def __init__(
7374
"You must specify the source CoordinateSystem when specifying the target CoordinateSystem"
7475
)
7576
self._to_pitch_dimensions = to_coordinate_system.pitch_dimensions
77+
else:
78+
self._to_pitch_dimensions = self._from_pitch_dimensions
7679

7780
self._from_orientation = from_orientation
7881
self._to_orientation = to_orientation
82+
83+
self._overlay_teams = overlay_teams
7984
if (
8085
from_orientation
8186
and not to_orientation
@@ -184,6 +189,9 @@ def transform_frame(self, frame: Frame) -> Frame:
184189
# Change dimensions
185190
elif self._needs_pitch_dimensions_change:
186191
frame = self.__change_frame_dimensions(frame)
192+
193+
elif self._overlay_teams:
194+
frame = self.transform_frame_overlay_teams(frame)
187195

188196
# Flip frame based on orientation
189197
if self._needs_orientation_change:
@@ -308,6 +316,37 @@ def __flip_frame(self, frame: Frame):
308316
statistics=frame.statistics,
309317
)
310318

319+
def _get_overlay_players_coordinates(self, player_data:PlayerData, player_team:Team, ball_owning_team:Team, attacking_direction:AttackingDirection):
320+
if attacking_direction == AttackingDirection.RTL:
321+
if player_team != ball_owning_team:
322+
player_data.coordinates = self.flip_point(player_data.coordinates)
323+
else:
324+
if player_team == ball_owning_team:
325+
player_data.coordinates = self.flip_point(player_data.coordinates)
326+
327+
return player_data
328+
329+
def transform_frame_overlay_teams(self, frame: Frame):
330+
players_data = {player:self._get_overlay_players_coordinates(player_data, player.team, frame.ball_owning_team, frame.attacking_direction) for player, player_data in frame.players_data.items()}
331+
332+
ball_coordinates = frame.ball_coordinates
333+
if frame.attacking_direction != AttackingDirection.RTL:
334+
ball_coordinates = self.flip_point(ball_coordinates)
335+
336+
return Frame(
337+
# doesn't change
338+
timestamp=frame.timestamp,
339+
frame_id=frame.frame_id,
340+
ball_owning_team=frame.ball_owning_team,
341+
ball_state=frame.ball_state,
342+
period=frame.period,
343+
other_data=frame.other_data,
344+
statistics=frame.statistics,
345+
# changes
346+
ball_coordinates=ball_coordinates,
347+
players_data=players_data,
348+
)
349+
311350
def transform_event(self, event: Event) -> Event:
312351
# Change coordinate system
313352
if self._needs_coordinate_system_change:
@@ -375,11 +414,13 @@ def transform_dataset(
375414
to_pitch_dimensions: Optional[PitchDimensions] = None,
376415
to_orientation: Optional[Orientation] = None,
377416
to_coordinate_system: Optional[CoordinateSystem] = None,
417+
overlay_teams: bool = False
378418
) -> Dataset:
379419
if (
380420
to_pitch_dimensions is None
381421
and to_orientation is None
382422
and to_coordinate_system is None
423+
and overlay_teams is False
383424
):
384425
return dataset
385426

@@ -391,8 +432,20 @@ def transform_dataset(
391432
"Cannot transform to BALL_OWNING_TEAM orientation when "
392433
"dataset doesn't contain ball owning team data"
393434
)
394-
395-
if to_pitch_dimensions is not None:
435+
if overlay_teams:
436+
transformer = cls(
437+
from_pitch_dimensions=dataset.metadata.pitch_dimensions,
438+
from_orientation=dataset.metadata.orientation,
439+
to_orientation=to_orientation,
440+
to_pitch_dimensions=to_pitch_dimensions,
441+
overlay_teams=overlay_teams
442+
)
443+
metadata = replace(
444+
dataset.metadata,
445+
pitch_dimensions=to_pitch_dimensions,
446+
orientation=to_orientation,
447+
)
448+
elif to_pitch_dimensions is not None:
396449
# Transform the pitch dimensions and optionally the orientation
397450
transformer = cls(
398451
from_pitch_dimensions=dataset.metadata.pitch_dimensions,
@@ -418,7 +471,6 @@ def transform_dataset(
418471
dataset.metadata,
419472
coordinate_system=to_coordinate_system,
420473
pitch_dimensions=to_coordinate_system.pitch_dimensions,
421-
orientation=to_orientation,
422474
)
423475

424476
else:

kloppy/helpers.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def transform(
1818
to_coordinate_system: Optional[
1919
Union[CoordinateSystem, Provider, str]
2020
] = None,
21+
overlay_teams:bool = False
2122
) -> Dataset:
2223
# convert raw orientation to object
2324
if to_orientation is not None and isinstance(to_orientation, str):
@@ -43,4 +44,5 @@ def transform(
4344
to_orientation=to_orientation,
4445
to_coordinate_system=to_coordinate_system,
4546
to_pitch_dimensions=to_pitch_dimensions,
47+
overlay_teams=overlay_teams
4648
)

kloppy/tests/test_helpers.py

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,88 @@
2929

3030

3131
class TestHelpers:
32+
33+
def _get_tracking_dataset_multiple_players(self):
34+
home_team = Team(team_id="home", name="home", ground=Ground.HOME)
35+
away_team = Team(team_id="away", name="away", ground=Ground.AWAY)
36+
teams = [home_team, away_team]
37+
38+
periods = [
39+
Period(
40+
id=1,
41+
start_timestamp=0.0,
42+
end_timestamp=10.0,
43+
),
44+
Period(
45+
id=2,
46+
start_timestamp=15.0,
47+
end_timestamp=25.0,
48+
),
49+
]
50+
metadata = Metadata(
51+
flags=(DatasetFlag.BALL_OWNING_TEAM),
52+
pitch_dimensions=NormalizedPitchDimensions(
53+
x_dim=Dimension(0, 100),
54+
y_dim=Dimension(-50, 50),
55+
pitch_length=105,
56+
pitch_width=68,
57+
),
58+
orientation=Orientation.HOME_AWAY,
59+
frame_rate=25,
60+
periods=periods,
61+
teams=teams,
62+
score=None,
63+
provider=None,
64+
coordinate_system=None,
65+
date="2024-05-19T13:30:00",
66+
game_week="35",
67+
game_id="2374516",
68+
)
69+
70+
tracking_data = TrackingDataset(
71+
metadata=metadata,
72+
records=[
73+
create_frame(
74+
frame_id=1,
75+
timestamp=0.1,
76+
ball_owning_team=teams[0],
77+
ball_state=None,
78+
period=periods[0],
79+
players_data={},
80+
other_data=None,
81+
ball_coordinates=Point3D(x=100, y=-50, z=0),
82+
),
83+
create_frame(
84+
frame_id=2,
85+
timestamp=0.2,
86+
ball_owning_team=teams[1],
87+
ball_state=None,
88+
period=periods[1],
89+
players_data={
90+
Player(
91+
team=home_team, player_id="home_1", jersey_no=1
92+
): PlayerData(
93+
coordinates=Point(x=15, y=35),
94+
distance=0.03,
95+
speed=10.5,
96+
other_data={"extra_data": 1},
97+
),
98+
Player(
99+
team=away_team, player_id="away_1", jersey_no=1
100+
): PlayerData(
101+
coordinates=Point(x=15, y=35),
102+
distance=0.03,
103+
speed=10.5,
104+
other_data={"extra_data": 1},
105+
)
106+
},
107+
other_data={"extra_data": 1},
108+
ball_coordinates=Point3D(x=0, y=50, z=1),
109+
),
110+
],
111+
)
112+
return tracking_data
113+
32114
def _get_tracking_dataset(self):
33115
home_team = Team(team_id="home", name="home", ground=Ground.HOME)
34116
away_team = Team(team_id="away", name="away", ground=Ground.AWAY)
@@ -371,6 +453,45 @@ def test_transform_event_data_freeze_frame(self, base_dir):
371453
assert coordinates.x == 1 - coordinates_transformed.x
372454
assert coordinates.y == 1 - coordinates_transformed.y
373455

456+
def test_transform_overlay_teams_frames(self, base_dir):
457+
dataset = tracab.load(
458+
meta_data=base_dir / "files/tracab_meta.xml",
459+
raw_data=base_dir / "files/tracab_raw.dat",
460+
only_alive=False,
461+
coordinates="tracab",
462+
)
463+
464+
transformed_frame = dataset.transform(overlay_teams=True)
465+
466+
## Assert frame with attacking direction right-to-left
467+
468+
home_players_coordinates = [coordinates for player, coordinates in dataset.frames[0].players_coordinates.items() if dataset.metadata.teams[0].get_player_by_id(player.player_id)]
469+
converted_home_players_coordinates = [coordinates for player, coordinates in transformed_frame.frames[0].players_coordinates.items() if dataset.metadata.teams[0].get_player_by_id(player.player_id)]
470+
471+
away_players_coordinates = [coordinates for player, coordinates in dataset.frames[0].players_coordinates.items() if dataset.metadata.teams[1].get_player_by_id(player.player_id)]
472+
converted_away_players_coordinates = [coordinates for player, coordinates in transformed_frame.frames[0].players_coordinates.items() if dataset.metadata.teams[1].get_player_by_id(player.player_id)]
473+
474+
assert all([a==b] for a, b in zip(home_players_coordinates, converted_home_players_coordinates))
475+
476+
assert all([a!=b] for a, b in zip(away_players_coordinates, converted_away_players_coordinates))
477+
478+
assert dataset.frames[0].ball_coordinates == transformed_frame.frames[0].ball_coordinates
479+
480+
## Assert frame with attacking direction with left-to-right
481+
482+
home_players_coordinates = [coordinates for player, coordinates in dataset.frames[4].players_coordinates.items() if dataset.metadata.teams[0].get_player_by_id(player.player_id)]
483+
converted_home_players_coordinates = [coordinates for player, coordinates in transformed_frame.frames[4].players_coordinates.items() if dataset.metadata.teams[0].get_player_by_id(player.player_id)]
484+
485+
away_players_coordinates = [coordinates for player, coordinates in dataset.frames[4].players_coordinates.items() if dataset.metadata.teams[1].get_player_by_id(player.player_id)]
486+
converted_away_players_coordinates = [coordinates for player, coordinates in transformed_frame.frames[4].players_coordinates.items() if dataset.metadata.teams[1].get_player_by_id(player.player_id)]
487+
488+
assert all([a!=b] for a, b in zip(home_players_coordinates, converted_home_players_coordinates))
489+
490+
assert all([a==b] for a, b in zip(away_players_coordinates, converted_away_players_coordinates))
491+
492+
assert dataset.frames[4].ball_coordinates != transformed_frame.frames[4].ball_coordinates
493+
494+
374495
def test_to_pandas(self):
375496
tracking_data = self._get_tracking_dataset()
376497

0 commit comments

Comments
 (0)