|
29 | 29 |
|
30 | 30 |
|
31 | 31 | class TestHelpers:
|
| 32 | + |
| 33 | + def _get_tracking_dataset_multiple_players(self): |
| 34 | + home_team = Team(team_id="home", name="home", ground=Ground.HOME) |
| 35 | + away_team = Team(team_id="away", name="away", ground=Ground.AWAY) |
| 36 | + teams = [home_team, away_team] |
| 37 | + |
| 38 | + periods = [ |
| 39 | + Period( |
| 40 | + id=1, |
| 41 | + start_timestamp=0.0, |
| 42 | + end_timestamp=10.0, |
| 43 | + ), |
| 44 | + Period( |
| 45 | + id=2, |
| 46 | + start_timestamp=15.0, |
| 47 | + end_timestamp=25.0, |
| 48 | + ), |
| 49 | + ] |
| 50 | + metadata = Metadata( |
| 51 | + flags=(DatasetFlag.BALL_OWNING_TEAM), |
| 52 | + pitch_dimensions=NormalizedPitchDimensions( |
| 53 | + x_dim=Dimension(0, 100), |
| 54 | + y_dim=Dimension(-50, 50), |
| 55 | + pitch_length=105, |
| 56 | + pitch_width=68, |
| 57 | + ), |
| 58 | + orientation=Orientation.HOME_AWAY, |
| 59 | + frame_rate=25, |
| 60 | + periods=periods, |
| 61 | + teams=teams, |
| 62 | + score=None, |
| 63 | + provider=None, |
| 64 | + coordinate_system=None, |
| 65 | + date="2024-05-19T13:30:00", |
| 66 | + game_week="35", |
| 67 | + game_id="2374516", |
| 68 | + ) |
| 69 | + |
| 70 | + tracking_data = TrackingDataset( |
| 71 | + metadata=metadata, |
| 72 | + records=[ |
| 73 | + create_frame( |
| 74 | + frame_id=1, |
| 75 | + timestamp=0.1, |
| 76 | + ball_owning_team=teams[0], |
| 77 | + ball_state=None, |
| 78 | + period=periods[0], |
| 79 | + players_data={}, |
| 80 | + other_data=None, |
| 81 | + ball_coordinates=Point3D(x=100, y=-50, z=0), |
| 82 | + ), |
| 83 | + create_frame( |
| 84 | + frame_id=2, |
| 85 | + timestamp=0.2, |
| 86 | + ball_owning_team=teams[1], |
| 87 | + ball_state=None, |
| 88 | + period=periods[1], |
| 89 | + players_data={ |
| 90 | + Player( |
| 91 | + team=home_team, player_id="home_1", jersey_no=1 |
| 92 | + ): PlayerData( |
| 93 | + coordinates=Point(x=15, y=35), |
| 94 | + distance=0.03, |
| 95 | + speed=10.5, |
| 96 | + other_data={"extra_data": 1}, |
| 97 | + ), |
| 98 | + Player( |
| 99 | + team=away_team, player_id="away_1", jersey_no=1 |
| 100 | + ): PlayerData( |
| 101 | + coordinates=Point(x=15, y=35), |
| 102 | + distance=0.03, |
| 103 | + speed=10.5, |
| 104 | + other_data={"extra_data": 1}, |
| 105 | + ), |
| 106 | + }, |
| 107 | + other_data={"extra_data": 1}, |
| 108 | + ball_coordinates=Point3D(x=0, y=50, z=1), |
| 109 | + ), |
| 110 | + ], |
| 111 | + ) |
| 112 | + return tracking_data |
| 113 | + |
32 | 114 | def _get_tracking_dataset(self):
|
33 | 115 | home_team = Team(team_id="home", name="home", ground=Ground.HOME)
|
34 | 116 | away_team = Team(team_id="away", name="away", ground=Ground.AWAY)
|
@@ -371,6 +453,118 @@ def test_transform_event_data_freeze_frame(self, base_dir):
|
371 | 453 | assert coordinates.x == 1 - coordinates_transformed.x
|
372 | 454 | assert coordinates.y == 1 - coordinates_transformed.y
|
373 | 455 |
|
| 456 | + def test_transform_overlay_teams_frames(self, base_dir): |
| 457 | + dataset = tracab.load( |
| 458 | + meta_data=base_dir / "files/tracab_meta.xml", |
| 459 | + raw_data=base_dir / "files/tracab_raw.dat", |
| 460 | + only_alive=False, |
| 461 | + coordinates="tracab", |
| 462 | + ) |
| 463 | + |
| 464 | + transformed_frame = dataset.transform(overlay_teams=True) |
| 465 | + |
| 466 | + ## Assert frame with attacking direction right-to-left |
| 467 | + |
| 468 | + home_players_coordinates = [ |
| 469 | + coordinates |
| 470 | + for player, coordinates in dataset.frames[ |
| 471 | + 0 |
| 472 | + ].players_coordinates.items() |
| 473 | + if dataset.metadata.teams[0].get_player_by_id(player.player_id) |
| 474 | + ] |
| 475 | + converted_home_players_coordinates = [ |
| 476 | + coordinates |
| 477 | + for player, coordinates in transformed_frame.frames[ |
| 478 | + 0 |
| 479 | + ].players_coordinates.items() |
| 480 | + if dataset.metadata.teams[0].get_player_by_id(player.player_id) |
| 481 | + ] |
| 482 | + |
| 483 | + away_players_coordinates = [ |
| 484 | + coordinates |
| 485 | + for player, coordinates in dataset.frames[ |
| 486 | + 0 |
| 487 | + ].players_coordinates.items() |
| 488 | + if dataset.metadata.teams[1].get_player_by_id(player.player_id) |
| 489 | + ] |
| 490 | + converted_away_players_coordinates = [ |
| 491 | + coordinates |
| 492 | + for player, coordinates in transformed_frame.frames[ |
| 493 | + 0 |
| 494 | + ].players_coordinates.items() |
| 495 | + if dataset.metadata.teams[1].get_player_by_id(player.player_id) |
| 496 | + ] |
| 497 | + |
| 498 | + assert all( |
| 499 | + [a == b] |
| 500 | + for a, b in zip( |
| 501 | + home_players_coordinates, converted_home_players_coordinates |
| 502 | + ) |
| 503 | + ) |
| 504 | + |
| 505 | + assert all( |
| 506 | + [a != b] |
| 507 | + for a, b in zip( |
| 508 | + away_players_coordinates, converted_away_players_coordinates |
| 509 | + ) |
| 510 | + ) |
| 511 | + |
| 512 | + assert ( |
| 513 | + dataset.frames[0].ball_coordinates |
| 514 | + == transformed_frame.frames[0].ball_coordinates |
| 515 | + ) |
| 516 | + |
| 517 | + ## Assert frame with attacking direction with left-to-right |
| 518 | + |
| 519 | + home_players_coordinates = [ |
| 520 | + coordinates |
| 521 | + for player, coordinates in dataset.frames[ |
| 522 | + 4 |
| 523 | + ].players_coordinates.items() |
| 524 | + if dataset.metadata.teams[0].get_player_by_id(player.player_id) |
| 525 | + ] |
| 526 | + converted_home_players_coordinates = [ |
| 527 | + coordinates |
| 528 | + for player, coordinates in transformed_frame.frames[ |
| 529 | + 4 |
| 530 | + ].players_coordinates.items() |
| 531 | + if dataset.metadata.teams[0].get_player_by_id(player.player_id) |
| 532 | + ] |
| 533 | + |
| 534 | + away_players_coordinates = [ |
| 535 | + coordinates |
| 536 | + for player, coordinates in dataset.frames[ |
| 537 | + 4 |
| 538 | + ].players_coordinates.items() |
| 539 | + if dataset.metadata.teams[1].get_player_by_id(player.player_id) |
| 540 | + ] |
| 541 | + converted_away_players_coordinates = [ |
| 542 | + coordinates |
| 543 | + for player, coordinates in transformed_frame.frames[ |
| 544 | + 4 |
| 545 | + ].players_coordinates.items() |
| 546 | + if dataset.metadata.teams[1].get_player_by_id(player.player_id) |
| 547 | + ] |
| 548 | + |
| 549 | + assert all( |
| 550 | + [a != b] |
| 551 | + for a, b in zip( |
| 552 | + home_players_coordinates, converted_home_players_coordinates |
| 553 | + ) |
| 554 | + ) |
| 555 | + |
| 556 | + assert all( |
| 557 | + [a == b] |
| 558 | + for a, b in zip( |
| 559 | + away_players_coordinates, converted_away_players_coordinates |
| 560 | + ) |
| 561 | + ) |
| 562 | + |
| 563 | + assert ( |
| 564 | + dataset.frames[4].ball_coordinates |
| 565 | + != transformed_frame.frames[4].ball_coordinates |
| 566 | + ) |
| 567 | + |
374 | 568 | def test_to_pandas(self):
|
375 | 569 | tracking_data = self._get_tracking_dataset()
|
376 | 570 |
|
|
0 commit comments