Skip to content

Commit deee1fa

Browse files
committed
Merge branch 'master' of github.com:materialsproject/pymatgen
2 parents 729fce8 + f8f2d21 commit deee1fa

File tree

5 files changed

+148
-60
lines changed

5 files changed

+148
-60
lines changed

src/pymatgen/analysis/xas/spectrum.py

Lines changed: 25 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,16 @@
1010
from scipy.interpolate import interp1d
1111

1212
from pymatgen.analysis.structure_matcher import StructureMatcher
13+
from pymatgen.core import Element
1314
from pymatgen.core.spectrum import Spectrum
1415
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
1516

1617
if TYPE_CHECKING:
18+
from collections.abc import Sequence
1719
from typing import Literal
1820

21+
from pymatgen.core import Structure
22+
1923
__author__ = "Chen Zheng, Yiming Chen"
2024
__copyright__ = "Copyright 2012, The Materials Project"
2125
__version__ = "3.0"
@@ -42,29 +46,31 @@ class XAS(Spectrum):
4246
Attributes:
4347
x (Sequence[float]): The sequence of energies.
4448
y (Sequence[float]): The sequence of mu(E).
45-
absorbing_element (str): The absorbing element of the spectrum.
49+
absorbing_element (str or .Element): The absorbing element of the spectrum.
4650
edge (str): The edge of the spectrum.
4751
spectrum_type (str): The type of the spectrum (XANES or EXAFS).
4852
absorbing_index (int): The absorbing index of the spectrum.
53+
zero_negative_intensity (bool) : Whether to set unphysical negative intensities to zero
4954
"""
5055

5156
XLABEL = "Energy"
5257
YLABEL = "Intensity"
5358

5459
def __init__(
5560
self,
56-
x,
57-
y,
58-
structure,
59-
absorbing_element,
60-
edge="K",
61-
spectrum_type="XANES",
62-
absorbing_index=None,
61+
x: Sequence,
62+
y: Sequence,
63+
structure: Structure,
64+
absorbing_element: str | Element,
65+
edge: str = "K",
66+
spectrum_type: str = "XANES",
67+
absorbing_index: int | None = None,
68+
zero_negative_intensity: bool = False,
6369
):
6470
"""Initialize a spectrum object."""
6571
super().__init__(x, y, structure, absorbing_element, edge)
6672
self.structure = structure
67-
self.absorbing_element = absorbing_element
73+
self.absorbing_element = Element(absorbing_element)
6874
self.edge = edge
6975
self.spectrum_type = spectrum_type
7076
self.e0 = self.x[np.argmax(np.gradient(self.y) / np.gradient(self.x))]
@@ -75,8 +81,16 @@ def __init__(
7581
]
7682
self.absorbing_index = absorbing_index
7783
# check for empty spectra and negative intensities
78-
if sum(1 for i in self.y if i <= 0) / len(self.y) > 0.05:
79-
raise ValueError("Double check the intensities. Most of them are non-positive.")
84+
neg_intens_mask = self.y < 0.0
85+
if len(self.y[neg_intens_mask]) / len(self.y) > 0.05:
86+
warnings.warn(
87+
"Double check the intensities. More than 5% of them are negative.",
88+
UserWarning,
89+
stacklevel=2,
90+
)
91+
self.zero_negative_intensity = zero_negative_intensity
92+
if self.zero_negative_intensity:
93+
self.y[neg_intens_mask] = 0.0
8094

8195
def __str__(self):
8296
return (

src/pymatgen/io/lobster/outputs.py

Lines changed: 11 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1710,29 +1710,17 @@ def has_good_quality_check_occupied_bands(
17101710
Returns:
17111711
bool: True if the quality of the projection is good.
17121712
"""
1713-
for matrix in self.band_overlaps_dict[Spin.up]["matrices"]:
1714-
for iband1, band1 in enumerate(matrix):
1715-
for iband2, band2 in enumerate(band1):
1716-
if iband1 < number_occ_bands_spin_up and iband2 < number_occ_bands_spin_up:
1717-
if iband1 == iband2:
1718-
if abs(band2 - 1.0).all() > limit_deviation:
1719-
return False
1720-
elif band2.all() > limit_deviation:
1721-
return False
1722-
1723-
if spin_polarized:
1724-
for matrix in self.band_overlaps_dict[Spin.down]["matrices"]:
1725-
for iband1, band1 in enumerate(matrix):
1726-
for iband2, band2 in enumerate(band1):
1727-
if number_occ_bands_spin_down is None:
1728-
raise ValueError("number_occ_bands_spin_down has to be specified")
1729-
1730-
if iband1 < number_occ_bands_spin_down and iband2 < number_occ_bands_spin_down:
1731-
if iband1 == iband2:
1732-
if abs(band2 - 1.0).all() > limit_deviation:
1733-
return False
1734-
elif band2.all() > limit_deviation:
1735-
return False
1713+
if spin_polarized and number_occ_bands_spin_down is None:
1714+
raise ValueError("number_occ_bands_spin_down has to be specified")
1715+
1716+
for spin in (Spin.up, Spin.down) if spin_polarized else (Spin.up,):
1717+
num_occ_bands = number_occ_bands_spin_up if spin is Spin.up else number_occ_bands_spin_down
1718+
1719+
for overlap_matrix in self.band_overlaps_dict[spin]["matrices"]:
1720+
sub_array = np.asarray(overlap_matrix)[:num_occ_bands, :num_occ_bands]
1721+
1722+
if not np.allclose(sub_array, np.identity(num_occ_bands), atol=limit_deviation, rtol=0):
1723+
return False
17361724

17371725
return True
17381726

src/pymatgen/io/vasp/inputs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -970,6 +970,7 @@ def proc_val(key: str, val: str) -> list | bool | float | int | str:
970970
"PARAM1",
971971
"PARAM2",
972972
"ENCUT",
973+
"NUPDOWN",
973974
)
974975
int_keys = (
975976
"NSW",
@@ -987,7 +988,6 @@ def proc_val(key: str, val: str) -> list | bool | float | int | str:
987988
"LMAXMIX",
988989
"NSIM",
989990
"NKRED",
990-
"NUPDOWN",
991991
"ISPIND",
992992
"LDAUTYPE",
993993
"IVDW",

tests/analysis/xas/test_spectrum.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,10 @@ def test_str(self):
6767
assert str(self.k_xanes) == "Co K Edge XANES for LiCoO2: <super: <class 'XAS'>, <XAS object>>"
6868

6969
def test_validate(self):
70-
y_zeros = np.zeros(len(self.k_xanes.x))
71-
with pytest.raises(
72-
ValueError,
73-
match="Double check the intensities. Most of them are non-positive",
70+
y_zeros = -np.ones(len(self.k_xanes.x))
71+
with pytest.warns(
72+
UserWarning,
73+
match="Double check the intensities. More than 5% of them are negative.",
7474
):
7575
XAS(
7676
self.k_xanes.x,
@@ -79,6 +79,17 @@ def test_validate(self):
7979
self.k_xanes.absorbing_element,
8080
)
8181

82+
def test_zero_negative_intensity(self):
83+
y_w_neg_intens = [(-1) ** i * v for i, v in enumerate(self.k_xanes.y)]
84+
spectrum = XAS(
85+
self.k_xanes.x,
86+
y_w_neg_intens,
87+
self.k_xanes.structure,
88+
self.k_xanes.absorbing_element,
89+
zero_negative_intensity=True,
90+
)
91+
assert all(v == 0.0 for i, v in enumerate(spectrum.y) if i % 2 == 1)
92+
8293
def test_stitch_xafs(self):
8394
with pytest.raises(ValueError, match="Invalid mode. Only XAFS and L23 are supported"):
8495
XAS.stitch(self.k_xanes, self.k_exafs, mode="invalid")

tests/io/lobster/test_outputs.py

Lines changed: 96 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import copy
34
import json
45
import os
56
from unittest import TestCase
@@ -1481,7 +1482,7 @@ def test_get_bandstructure(self):
14811482

14821483
class TestBandoverlaps(TestCase):
14831484
def setUp(self):
1484-
# test spin-polarized calc and non spinpolarized calc
1485+
# test spin-polarized calc and non spin-polarized calc
14851486

14861487
self.band_overlaps1 = Bandoverlaps(f"{TEST_DIR}/bandOverlaps.lobster.1")
14871488
self.band_overlaps2 = Bandoverlaps(f"{TEST_DIR}/bandOverlaps.lobster.2")
@@ -1515,9 +1516,18 @@ def test_attributes(self):
15151516
assert self.band_overlaps2.max_deviation[-1] == approx(1.48451e-05)
15161517
assert self.band_overlaps2_new.max_deviation[-1] == approx(0.45154)
15171518

1518-
def test_has_good_quality(self):
1519+
def test_has_good_quality_maxDeviation(self):
15191520
assert not self.band_overlaps1.has_good_quality_maxDeviation(limit_maxDeviation=0.1)
15201521
assert not self.band_overlaps1_new.has_good_quality_maxDeviation(limit_maxDeviation=0.1)
1522+
1523+
assert self.band_overlaps1.has_good_quality_maxDeviation(limit_maxDeviation=100)
1524+
assert self.band_overlaps1_new.has_good_quality_maxDeviation(limit_maxDeviation=100)
1525+
assert self.band_overlaps2.has_good_quality_maxDeviation()
1526+
assert not self.band_overlaps2_new.has_good_quality_maxDeviation()
1527+
assert not self.band_overlaps2.has_good_quality_maxDeviation(limit_maxDeviation=0.0000001)
1528+
assert not self.band_overlaps2_new.has_good_quality_maxDeviation(limit_maxDeviation=0.0000001)
1529+
1530+
def test_has_good_quality_check_occupied_bands(self):
15211531
assert not self.band_overlaps1.has_good_quality_check_occupied_bands(
15221532
number_occ_bands_spin_up=9,
15231533
number_occ_bands_spin_down=5,
@@ -1545,65 +1555,58 @@ def test_has_good_quality(self):
15451555
assert not self.band_overlaps1.has_good_quality_check_occupied_bands(
15461556
number_occ_bands_spin_up=1,
15471557
number_occ_bands_spin_down=1,
1548-
limit_deviation=0.000001,
1558+
limit_deviation=1e-6,
15491559
spin_polarized=True,
15501560
)
15511561
assert not self.band_overlaps1_new.has_good_quality_check_occupied_bands(
15521562
number_occ_bands_spin_up=1,
15531563
number_occ_bands_spin_down=1,
1554-
limit_deviation=0.000001,
1564+
limit_deviation=1e-6,
15551565
spin_polarized=True,
15561566
)
15571567
assert not self.band_overlaps1.has_good_quality_check_occupied_bands(
15581568
number_occ_bands_spin_up=1,
15591569
number_occ_bands_spin_down=0,
1560-
limit_deviation=0.000001,
1570+
limit_deviation=1e-6,
15611571
spin_polarized=True,
15621572
)
15631573
assert not self.band_overlaps1_new.has_good_quality_check_occupied_bands(
15641574
number_occ_bands_spin_up=1,
15651575
number_occ_bands_spin_down=0,
1566-
limit_deviation=0.000001,
1576+
limit_deviation=1e-6,
15671577
spin_polarized=True,
15681578
)
15691579
assert not self.band_overlaps1.has_good_quality_check_occupied_bands(
15701580
number_occ_bands_spin_up=0,
15711581
number_occ_bands_spin_down=1,
1572-
limit_deviation=0.000001,
1582+
limit_deviation=1e-6,
15731583
spin_polarized=True,
15741584
)
15751585
assert not self.band_overlaps1_new.has_good_quality_check_occupied_bands(
15761586
number_occ_bands_spin_up=0,
15771587
number_occ_bands_spin_down=1,
1578-
limit_deviation=0.000001,
1588+
limit_deviation=1e-6,
15791589
spin_polarized=True,
15801590
)
15811591
assert not self.band_overlaps1.has_good_quality_check_occupied_bands(
15821592
number_occ_bands_spin_up=4,
15831593
number_occ_bands_spin_down=4,
1584-
limit_deviation=0.001,
1594+
limit_deviation=1e-3,
15851595
spin_polarized=True,
15861596
)
15871597
assert not self.band_overlaps1_new.has_good_quality_check_occupied_bands(
15881598
number_occ_bands_spin_up=4,
15891599
number_occ_bands_spin_down=4,
1590-
limit_deviation=0.001,
1600+
limit_deviation=1e-3,
15911601
spin_polarized=True,
15921602
)
1593-
1594-
assert self.band_overlaps1.has_good_quality_maxDeviation(limit_maxDeviation=100)
1595-
assert self.band_overlaps1_new.has_good_quality_maxDeviation(limit_maxDeviation=100)
1596-
assert self.band_overlaps2.has_good_quality_maxDeviation()
1597-
assert not self.band_overlaps2_new.has_good_quality_maxDeviation()
1598-
assert not self.band_overlaps2.has_good_quality_maxDeviation(limit_maxDeviation=0.0000001)
1599-
assert not self.band_overlaps2_new.has_good_quality_maxDeviation(limit_maxDeviation=0.0000001)
16001603
assert not self.band_overlaps2.has_good_quality_check_occupied_bands(
1601-
number_occ_bands_spin_up=10, limit_deviation=0.0000001
1604+
number_occ_bands_spin_up=10, limit_deviation=1e-7
16021605
)
16031606
assert not self.band_overlaps2_new.has_good_quality_check_occupied_bands(
1604-
number_occ_bands_spin_up=10, limit_deviation=0.0000001
1607+
number_occ_bands_spin_up=10, limit_deviation=1e-7
16051608
)
1606-
assert not self.band_overlaps2.has_good_quality_check_occupied_bands(
1609+
assert self.band_overlaps2.has_good_quality_check_occupied_bands(
16071610
number_occ_bands_spin_up=1, limit_deviation=0.1
16081611
)
16091612

@@ -1614,14 +1617,86 @@ def test_has_good_quality(self):
16141617
number_occ_bands_spin_up=1, limit_deviation=1e-8
16151618
)
16161619
assert self.band_overlaps2.has_good_quality_check_occupied_bands(number_occ_bands_spin_up=10, limit_deviation=1)
1617-
assert not self.band_overlaps2_new.has_good_quality_check_occupied_bands(
1620+
assert self.band_overlaps2_new.has_good_quality_check_occupied_bands(
16181621
number_occ_bands_spin_up=2, limit_deviation=0.1
16191622
)
16201623
assert self.band_overlaps2.has_good_quality_check_occupied_bands(number_occ_bands_spin_up=1, limit_deviation=1)
16211624
assert self.band_overlaps2_new.has_good_quality_check_occupied_bands(
16221625
number_occ_bands_spin_up=1, limit_deviation=2
16231626
)
16241627

1628+
def test_has_good_quality_check_occupied_bands_patched(self):
1629+
"""Test with patched data."""
1630+
1631+
limit_deviation = 0.1
1632+
1633+
rng = np.random.default_rng(42) # set seed for reproducibility
1634+
1635+
band_overlaps = copy.deepcopy(self.band_overlaps1_new)
1636+
1637+
number_occ_bands_spin_up_all = list(range(band_overlaps.band_overlaps_dict[Spin.up]["matrices"][0].shape[0]))
1638+
number_occ_bands_spin_down_all = list(
1639+
range(band_overlaps.band_overlaps_dict[Spin.down]["matrices"][0].shape[0])
1640+
)
1641+
1642+
for actual_deviation in [0.05, 0.1, 0.2, 0.5, 1.0]:
1643+
for spin in (Spin.up, Spin.down):
1644+
for number_occ_bands_spin_up, number_occ_bands_spin_down in zip(
1645+
number_occ_bands_spin_up_all, number_occ_bands_spin_down_all, strict=False
1646+
):
1647+
for i_arr, array in enumerate(band_overlaps.band_overlaps_dict[spin]["matrices"]):
1648+
number_occ_bands = number_occ_bands_spin_up if spin is Spin.up else number_occ_bands_spin_down
1649+
1650+
shape = array.shape
1651+
assert np.all(np.array(shape) >= number_occ_bands)
1652+
assert len(shape) == 2
1653+
assert shape[0] == shape[1]
1654+
1655+
# Generate a noisy background array
1656+
patch_array = rng.uniform(0, 10, shape)
1657+
1658+
# Patch the top-left sub-array (the part that would be checked)
1659+
patch_array[:number_occ_bands, :number_occ_bands] = np.identity(number_occ_bands) + rng.uniform(
1660+
0, actual_deviation, (number_occ_bands, number_occ_bands)
1661+
)
1662+
1663+
band_overlaps.band_overlaps_dict[spin]["matrices"][i_arr] = patch_array
1664+
1665+
result = band_overlaps.has_good_quality_check_occupied_bands(
1666+
number_occ_bands_spin_up=number_occ_bands_spin_up,
1667+
number_occ_bands_spin_down=number_occ_bands_spin_down,
1668+
spin_polarized=True,
1669+
limit_deviation=limit_deviation,
1670+
)
1671+
# Assert for expected results
1672+
if (
1673+
actual_deviation == 0.05
1674+
and number_occ_bands_spin_up <= 7
1675+
and number_occ_bands_spin_down <= 7
1676+
and spin is Spin.up
1677+
or actual_deviation == 0.05
1678+
and spin is Spin.down
1679+
or actual_deviation == 0.1
1680+
or actual_deviation in [0.2, 0.5, 1.0]
1681+
and number_occ_bands_spin_up == 0
1682+
and number_occ_bands_spin_down == 0
1683+
):
1684+
assert result
1685+
else:
1686+
assert not result
1687+
1688+
def test_exceptions(self):
1689+
with pytest.raises(ValueError, match="number_occ_bands_spin_down has to be specified"):
1690+
self.band_overlaps1.has_good_quality_check_occupied_bands(
1691+
number_occ_bands_spin_up=4,
1692+
spin_polarized=True,
1693+
)
1694+
with pytest.raises(ValueError, match="number_occ_bands_spin_down has to be specified"):
1695+
self.band_overlaps1_new.has_good_quality_check_occupied_bands(
1696+
number_occ_bands_spin_up=4,
1697+
spin_polarized=True,
1698+
)
1699+
16251700
def test_msonable(self):
16261701
dict_data = self.band_overlaps2_new.as_dict()
16271702
bandoverlaps_from_dict = Bandoverlaps.from_dict(dict_data)

0 commit comments

Comments
 (0)