1
1
from __future__ import annotations
2
2
3
+ import copy
3
4
import json
4
5
import os
5
6
from unittest import TestCase
@@ -1481,7 +1482,7 @@ def test_get_bandstructure(self):
1481
1482
1482
1483
class TestBandoverlaps (TestCase ):
1483
1484
def setUp (self ):
1484
- # test spin-polarized calc and non spinpolarized calc
1485
+ # test spin-polarized calc and non spin-polarized calc
1485
1486
1486
1487
self .band_overlaps1 = Bandoverlaps (f"{ TEST_DIR } /bandOverlaps.lobster.1" )
1487
1488
self .band_overlaps2 = Bandoverlaps (f"{ TEST_DIR } /bandOverlaps.lobster.2" )
@@ -1515,9 +1516,18 @@ def test_attributes(self):
1515
1516
assert self .band_overlaps2 .max_deviation [- 1 ] == approx (1.48451e-05 )
1516
1517
assert self .band_overlaps2_new .max_deviation [- 1 ] == approx (0.45154 )
1517
1518
1518
- def test_has_good_quality (self ):
1519
+ def test_has_good_quality_maxDeviation (self ):
1519
1520
assert not self .band_overlaps1 .has_good_quality_maxDeviation (limit_maxDeviation = 0.1 )
1520
1521
assert not self .band_overlaps1_new .has_good_quality_maxDeviation (limit_maxDeviation = 0.1 )
1522
+
1523
+ assert self .band_overlaps1 .has_good_quality_maxDeviation (limit_maxDeviation = 100 )
1524
+ assert self .band_overlaps1_new .has_good_quality_maxDeviation (limit_maxDeviation = 100 )
1525
+ assert self .band_overlaps2 .has_good_quality_maxDeviation ()
1526
+ assert not self .band_overlaps2_new .has_good_quality_maxDeviation ()
1527
+ assert not self .band_overlaps2 .has_good_quality_maxDeviation (limit_maxDeviation = 0.0000001 )
1528
+ assert not self .band_overlaps2_new .has_good_quality_maxDeviation (limit_maxDeviation = 0.0000001 )
1529
+
1530
+ def test_has_good_quality_check_occupied_bands (self ):
1521
1531
assert not self .band_overlaps1 .has_good_quality_check_occupied_bands (
1522
1532
number_occ_bands_spin_up = 9 ,
1523
1533
number_occ_bands_spin_down = 5 ,
@@ -1545,65 +1555,58 @@ def test_has_good_quality(self):
1545
1555
assert not self .band_overlaps1 .has_good_quality_check_occupied_bands (
1546
1556
number_occ_bands_spin_up = 1 ,
1547
1557
number_occ_bands_spin_down = 1 ,
1548
- limit_deviation = 0.000001 ,
1558
+ limit_deviation = 1e-6 ,
1549
1559
spin_polarized = True ,
1550
1560
)
1551
1561
assert not self .band_overlaps1_new .has_good_quality_check_occupied_bands (
1552
1562
number_occ_bands_spin_up = 1 ,
1553
1563
number_occ_bands_spin_down = 1 ,
1554
- limit_deviation = 0.000001 ,
1564
+ limit_deviation = 1e-6 ,
1555
1565
spin_polarized = True ,
1556
1566
)
1557
1567
assert not self .band_overlaps1 .has_good_quality_check_occupied_bands (
1558
1568
number_occ_bands_spin_up = 1 ,
1559
1569
number_occ_bands_spin_down = 0 ,
1560
- limit_deviation = 0.000001 ,
1570
+ limit_deviation = 1e-6 ,
1561
1571
spin_polarized = True ,
1562
1572
)
1563
1573
assert not self .band_overlaps1_new .has_good_quality_check_occupied_bands (
1564
1574
number_occ_bands_spin_up = 1 ,
1565
1575
number_occ_bands_spin_down = 0 ,
1566
- limit_deviation = 0.000001 ,
1576
+ limit_deviation = 1e-6 ,
1567
1577
spin_polarized = True ,
1568
1578
)
1569
1579
assert not self .band_overlaps1 .has_good_quality_check_occupied_bands (
1570
1580
number_occ_bands_spin_up = 0 ,
1571
1581
number_occ_bands_spin_down = 1 ,
1572
- limit_deviation = 0.000001 ,
1582
+ limit_deviation = 1e-6 ,
1573
1583
spin_polarized = True ,
1574
1584
)
1575
1585
assert not self .band_overlaps1_new .has_good_quality_check_occupied_bands (
1576
1586
number_occ_bands_spin_up = 0 ,
1577
1587
number_occ_bands_spin_down = 1 ,
1578
- limit_deviation = 0.000001 ,
1588
+ limit_deviation = 1e-6 ,
1579
1589
spin_polarized = True ,
1580
1590
)
1581
1591
assert not self .band_overlaps1 .has_good_quality_check_occupied_bands (
1582
1592
number_occ_bands_spin_up = 4 ,
1583
1593
number_occ_bands_spin_down = 4 ,
1584
- limit_deviation = 0.001 ,
1594
+ limit_deviation = 1e-3 ,
1585
1595
spin_polarized = True ,
1586
1596
)
1587
1597
assert not self .band_overlaps1_new .has_good_quality_check_occupied_bands (
1588
1598
number_occ_bands_spin_up = 4 ,
1589
1599
number_occ_bands_spin_down = 4 ,
1590
- limit_deviation = 0.001 ,
1600
+ limit_deviation = 1e-3 ,
1591
1601
spin_polarized = True ,
1592
1602
)
1593
-
1594
- assert self .band_overlaps1 .has_good_quality_maxDeviation (limit_maxDeviation = 100 )
1595
- assert self .band_overlaps1_new .has_good_quality_maxDeviation (limit_maxDeviation = 100 )
1596
- assert self .band_overlaps2 .has_good_quality_maxDeviation ()
1597
- assert not self .band_overlaps2_new .has_good_quality_maxDeviation ()
1598
- assert not self .band_overlaps2 .has_good_quality_maxDeviation (limit_maxDeviation = 0.0000001 )
1599
- assert not self .band_overlaps2_new .has_good_quality_maxDeviation (limit_maxDeviation = 0.0000001 )
1600
1603
assert not self .band_overlaps2 .has_good_quality_check_occupied_bands (
1601
- number_occ_bands_spin_up = 10 , limit_deviation = 0.0000001
1604
+ number_occ_bands_spin_up = 10 , limit_deviation = 1e-7
1602
1605
)
1603
1606
assert not self .band_overlaps2_new .has_good_quality_check_occupied_bands (
1604
- number_occ_bands_spin_up = 10 , limit_deviation = 0.0000001
1607
+ number_occ_bands_spin_up = 10 , limit_deviation = 1e-7
1605
1608
)
1606
- assert not self .band_overlaps2 .has_good_quality_check_occupied_bands (
1609
+ assert self .band_overlaps2 .has_good_quality_check_occupied_bands (
1607
1610
number_occ_bands_spin_up = 1 , limit_deviation = 0.1
1608
1611
)
1609
1612
@@ -1614,14 +1617,86 @@ def test_has_good_quality(self):
1614
1617
number_occ_bands_spin_up = 1 , limit_deviation = 1e-8
1615
1618
)
1616
1619
assert self .band_overlaps2 .has_good_quality_check_occupied_bands (number_occ_bands_spin_up = 10 , limit_deviation = 1 )
1617
- assert not self .band_overlaps2_new .has_good_quality_check_occupied_bands (
1620
+ assert self .band_overlaps2_new .has_good_quality_check_occupied_bands (
1618
1621
number_occ_bands_spin_up = 2 , limit_deviation = 0.1
1619
1622
)
1620
1623
assert self .band_overlaps2 .has_good_quality_check_occupied_bands (number_occ_bands_spin_up = 1 , limit_deviation = 1 )
1621
1624
assert self .band_overlaps2_new .has_good_quality_check_occupied_bands (
1622
1625
number_occ_bands_spin_up = 1 , limit_deviation = 2
1623
1626
)
1624
1627
1628
+ def test_has_good_quality_check_occupied_bands_patched (self ):
1629
+ """Test with patched data."""
1630
+
1631
+ limit_deviation = 0.1
1632
+
1633
+ rng = np .random .default_rng (42 ) # set seed for reproducibility
1634
+
1635
+ band_overlaps = copy .deepcopy (self .band_overlaps1_new )
1636
+
1637
+ number_occ_bands_spin_up_all = list (range (band_overlaps .band_overlaps_dict [Spin .up ]["matrices" ][0 ].shape [0 ]))
1638
+ number_occ_bands_spin_down_all = list (
1639
+ range (band_overlaps .band_overlaps_dict [Spin .down ]["matrices" ][0 ].shape [0 ])
1640
+ )
1641
+
1642
+ for actual_deviation in [0.05 , 0.1 , 0.2 , 0.5 , 1.0 ]:
1643
+ for spin in (Spin .up , Spin .down ):
1644
+ for number_occ_bands_spin_up , number_occ_bands_spin_down in zip (
1645
+ number_occ_bands_spin_up_all , number_occ_bands_spin_down_all , strict = False
1646
+ ):
1647
+ for i_arr , array in enumerate (band_overlaps .band_overlaps_dict [spin ]["matrices" ]):
1648
+ number_occ_bands = number_occ_bands_spin_up if spin is Spin .up else number_occ_bands_spin_down
1649
+
1650
+ shape = array .shape
1651
+ assert np .all (np .array (shape ) >= number_occ_bands )
1652
+ assert len (shape ) == 2
1653
+ assert shape [0 ] == shape [1 ]
1654
+
1655
+ # Generate a noisy background array
1656
+ patch_array = rng .uniform (0 , 10 , shape )
1657
+
1658
+ # Patch the top-left sub-array (the part that would be checked)
1659
+ patch_array [:number_occ_bands , :number_occ_bands ] = np .identity (number_occ_bands ) + rng .uniform (
1660
+ 0 , actual_deviation , (number_occ_bands , number_occ_bands )
1661
+ )
1662
+
1663
+ band_overlaps .band_overlaps_dict [spin ]["matrices" ][i_arr ] = patch_array
1664
+
1665
+ result = band_overlaps .has_good_quality_check_occupied_bands (
1666
+ number_occ_bands_spin_up = number_occ_bands_spin_up ,
1667
+ number_occ_bands_spin_down = number_occ_bands_spin_down ,
1668
+ spin_polarized = True ,
1669
+ limit_deviation = limit_deviation ,
1670
+ )
1671
+ # Assert for expected results
1672
+ if (
1673
+ actual_deviation == 0.05
1674
+ and number_occ_bands_spin_up <= 7
1675
+ and number_occ_bands_spin_down <= 7
1676
+ and spin is Spin .up
1677
+ or actual_deviation == 0.05
1678
+ and spin is Spin .down
1679
+ or actual_deviation == 0.1
1680
+ or actual_deviation in [0.2 , 0.5 , 1.0 ]
1681
+ and number_occ_bands_spin_up == 0
1682
+ and number_occ_bands_spin_down == 0
1683
+ ):
1684
+ assert result
1685
+ else :
1686
+ assert not result
1687
+
1688
+ def test_exceptions (self ):
1689
+ with pytest .raises (ValueError , match = "number_occ_bands_spin_down has to be specified" ):
1690
+ self .band_overlaps1 .has_good_quality_check_occupied_bands (
1691
+ number_occ_bands_spin_up = 4 ,
1692
+ spin_polarized = True ,
1693
+ )
1694
+ with pytest .raises (ValueError , match = "number_occ_bands_spin_down has to be specified" ):
1695
+ self .band_overlaps1_new .has_good_quality_check_occupied_bands (
1696
+ number_occ_bands_spin_up = 4 ,
1697
+ spin_polarized = True ,
1698
+ )
1699
+
1625
1700
def test_msonable (self ):
1626
1701
dict_data = self .band_overlaps2_new .as_dict ()
1627
1702
bandoverlaps_from_dict = Bandoverlaps .from_dict (dict_data )
0 commit comments