@@ -37,6 +37,7 @@ def __init__(
37
37
to_coordinate_system : Optional [CoordinateSystem ] = None ,
38
38
to_pitch_dimensions : Optional [PitchDimensions ] = None ,
39
39
to_orientation : Optional [Orientation ] = None ,
40
+ overlay_teams : bool = False ,
40
41
):
41
42
if (
42
43
from_pitch_dimensions
@@ -73,9 +74,13 @@ def __init__(
73
74
"You must specify the source CoordinateSystem when specifying the target CoordinateSystem"
74
75
)
75
76
self ._to_pitch_dimensions = to_coordinate_system .pitch_dimensions
77
+ else :
78
+ self ._to_pitch_dimensions = self ._from_pitch_dimensions
76
79
77
80
self ._from_orientation = from_orientation
78
81
self ._to_orientation = to_orientation
82
+
83
+ self ._overlay_teams = overlay_teams
79
84
if (
80
85
from_orientation
81
86
and not to_orientation
@@ -185,6 +190,9 @@ def transform_frame(self, frame: Frame) -> Frame:
185
190
elif self ._needs_pitch_dimensions_change :
186
191
frame = self .__change_frame_dimensions (frame )
187
192
193
+ elif self ._overlay_teams :
194
+ frame = self .transform_frame_overlay_teams (frame )
195
+
188
196
# Flip frame based on orientation
189
197
if self ._needs_orientation_change :
190
198
if self .__needs_flip (
@@ -308,6 +316,55 @@ def __flip_frame(self, frame: Frame):
308
316
statistics = frame .statistics ,
309
317
)
310
318
319
+ def _get_overlay_players_coordinates (
320
+ self ,
321
+ player_data : PlayerData ,
322
+ player_team : Team ,
323
+ ball_owning_team : Team ,
324
+ attacking_direction : AttackingDirection ,
325
+ ):
326
+ if attacking_direction == AttackingDirection .RTL :
327
+ if player_team != ball_owning_team :
328
+ player_data .coordinates = self .flip_point (
329
+ player_data .coordinates
330
+ )
331
+ else :
332
+ if player_team == ball_owning_team :
333
+ player_data .coordinates = self .flip_point (
334
+ player_data .coordinates
335
+ )
336
+
337
+ return player_data
338
+
339
+ def transform_frame_overlay_teams (self , frame : Frame ):
340
+ players_data = {
341
+ player : self ._get_overlay_players_coordinates (
342
+ player_data ,
343
+ player .team ,
344
+ frame .ball_owning_team ,
345
+ frame .attacking_direction ,
346
+ )
347
+ for player , player_data in frame .players_data .items ()
348
+ }
349
+
350
+ ball_coordinates = frame .ball_coordinates
351
+ if frame .attacking_direction != AttackingDirection .RTL :
352
+ ball_coordinates = self .flip_point (ball_coordinates )
353
+
354
+ return Frame (
355
+ # doesn't change
356
+ timestamp = frame .timestamp ,
357
+ frame_id = frame .frame_id ,
358
+ ball_owning_team = frame .ball_owning_team ,
359
+ ball_state = frame .ball_state ,
360
+ period = frame .period ,
361
+ other_data = frame .other_data ,
362
+ statistics = frame .statistics ,
363
+ # changes
364
+ ball_coordinates = ball_coordinates ,
365
+ players_data = players_data ,
366
+ )
367
+
311
368
def transform_event (self , event : Event ) -> Event :
312
369
# Change coordinate system
313
370
if self ._needs_coordinate_system_change :
@@ -375,11 +432,13 @@ def transform_dataset(
375
432
to_pitch_dimensions : Optional [PitchDimensions ] = None ,
376
433
to_orientation : Optional [Orientation ] = None ,
377
434
to_coordinate_system : Optional [CoordinateSystem ] = None ,
435
+ overlay_teams : bool = False ,
378
436
) -> Dataset :
379
437
if (
380
438
to_pitch_dimensions is None
381
439
and to_orientation is None
382
440
and to_coordinate_system is None
441
+ and overlay_teams is False
383
442
):
384
443
return dataset
385
444
@@ -391,8 +450,20 @@ def transform_dataset(
391
450
"Cannot transform to BALL_OWNING_TEAM orientation when "
392
451
"dataset doesn't contain ball owning team data"
393
452
)
394
-
395
- if to_pitch_dimensions is not None :
453
+ if overlay_teams :
454
+ transformer = cls (
455
+ from_pitch_dimensions = dataset .metadata .pitch_dimensions ,
456
+ from_orientation = dataset .metadata .orientation ,
457
+ to_orientation = to_orientation ,
458
+ to_pitch_dimensions = to_pitch_dimensions ,
459
+ overlay_teams = overlay_teams ,
460
+ )
461
+ metadata = replace (
462
+ dataset .metadata ,
463
+ pitch_dimensions = to_pitch_dimensions ,
464
+ orientation = to_orientation ,
465
+ )
466
+ elif to_pitch_dimensions is not None :
396
467
# Transform the pitch dimensions and optionally the orientation
397
468
transformer = cls (
398
469
from_pitch_dimensions = dataset .metadata .pitch_dimensions ,
@@ -418,7 +489,6 @@ def transform_dataset(
418
489
dataset .metadata ,
419
490
coordinate_system = to_coordinate_system ,
420
491
pitch_dimensions = to_coordinate_system .pitch_dimensions ,
421
- orientation = to_orientation ,
422
492
)
423
493
424
494
else :
0 commit comments