@@ -37,6 +37,7 @@ def __init__(
37
37
to_coordinate_system : Optional [CoordinateSystem ] = None ,
38
38
to_pitch_dimensions : Optional [PitchDimensions ] = None ,
39
39
to_orientation : Optional [Orientation ] = None ,
40
+ overlay_teams : bool = False
40
41
):
41
42
if (
42
43
from_pitch_dimensions
@@ -73,9 +74,13 @@ def __init__(
73
74
"You must specify the source CoordinateSystem when specifying the target CoordinateSystem"
74
75
)
75
76
self ._to_pitch_dimensions = to_coordinate_system .pitch_dimensions
77
+ else :
78
+ self ._to_pitch_dimensions = self ._from_pitch_dimensions
76
79
77
80
self ._from_orientation = from_orientation
78
81
self ._to_orientation = to_orientation
82
+
83
+ self ._overlay_teams = overlay_teams
79
84
if (
80
85
from_orientation
81
86
and not to_orientation
@@ -184,6 +189,9 @@ def transform_frame(self, frame: Frame) -> Frame:
184
189
# Change dimensions
185
190
elif self ._needs_pitch_dimensions_change :
186
191
frame = self .__change_frame_dimensions (frame )
192
+
193
+ elif self ._overlay_teams :
194
+ frame = self .transform_frame_overlay_teams (frame )
187
195
188
196
# Flip frame based on orientation
189
197
if self ._needs_orientation_change :
@@ -308,6 +316,37 @@ def __flip_frame(self, frame: Frame):
308
316
statistics = frame .statistics ,
309
317
)
310
318
319
+ def _get_overlay_players_coordinates (self , player_data :PlayerData , player_team :Team , ball_owning_team :Team , attacking_direction :AttackingDirection ):
320
+ if attacking_direction == AttackingDirection .RTL :
321
+ if player_team != ball_owning_team :
322
+ player_data .coordinates = self .flip_point (player_data .coordinates )
323
+ else :
324
+ if player_team == ball_owning_team :
325
+ player_data .coordinates = self .flip_point (player_data .coordinates )
326
+
327
+ return player_data
328
+
329
+ def transform_frame_overlay_teams (self , frame : Frame ):
330
+ players_data = {player :self ._get_overlay_players_coordinates (player_data , player .team , frame .ball_owning_team , frame .attacking_direction ) for player , player_data in frame .players_data .items ()}
331
+
332
+ ball_coordinates = frame .ball_coordinates
333
+ if frame .attacking_direction != AttackingDirection .RTL :
334
+ ball_coordinates = self .flip_point (ball_coordinates )
335
+
336
+ return Frame (
337
+ # doesn't change
338
+ timestamp = frame .timestamp ,
339
+ frame_id = frame .frame_id ,
340
+ ball_owning_team = frame .ball_owning_team ,
341
+ ball_state = frame .ball_state ,
342
+ period = frame .period ,
343
+ other_data = frame .other_data ,
344
+ statistics = frame .statistics ,
345
+ # changes
346
+ ball_coordinates = ball_coordinates ,
347
+ players_data = players_data ,
348
+ )
349
+
311
350
def transform_event (self , event : Event ) -> Event :
312
351
# Change coordinate system
313
352
if self ._needs_coordinate_system_change :
@@ -375,11 +414,13 @@ def transform_dataset(
375
414
to_pitch_dimensions : Optional [PitchDimensions ] = None ,
376
415
to_orientation : Optional [Orientation ] = None ,
377
416
to_coordinate_system : Optional [CoordinateSystem ] = None ,
417
+ overlay_teams : bool = False
378
418
) -> Dataset :
379
419
if (
380
420
to_pitch_dimensions is None
381
421
and to_orientation is None
382
422
and to_coordinate_system is None
423
+ and overlay_teams is False
383
424
):
384
425
return dataset
385
426
@@ -391,8 +432,20 @@ def transform_dataset(
391
432
"Cannot transform to BALL_OWNING_TEAM orientation when "
392
433
"dataset doesn't contain ball owning team data"
393
434
)
394
-
395
- if to_pitch_dimensions is not None :
435
+ if overlay_teams :
436
+ transformer = cls (
437
+ from_pitch_dimensions = dataset .metadata .pitch_dimensions ,
438
+ from_orientation = dataset .metadata .orientation ,
439
+ to_orientation = to_orientation ,
440
+ to_pitch_dimensions = to_pitch_dimensions ,
441
+ overlay_teams = overlay_teams
442
+ )
443
+ metadata = replace (
444
+ dataset .metadata ,
445
+ pitch_dimensions = to_pitch_dimensions ,
446
+ orientation = to_orientation ,
447
+ )
448
+ elif to_pitch_dimensions is not None :
396
449
# Transform the pitch dimensions and optionally the orientation
397
450
transformer = cls (
398
451
from_pitch_dimensions = dataset .metadata .pitch_dimensions ,
@@ -418,7 +471,6 @@ def transform_dataset(
418
471
dataset .metadata ,
419
472
coordinate_system = to_coordinate_system ,
420
473
pitch_dimensions = to_coordinate_system .pitch_dimensions ,
421
- orientation = to_orientation ,
422
474
)
423
475
424
476
else :
0 commit comments