Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 16 additions & 13 deletions src/ptychi/api/options/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,9 @@ class ProbeOrthogonalizeIncoherentModesOptions(FeatureOptions):

method: enums.OrthogonalizationMethods = enums.OrthogonalizationMethods.SVD
"""The method to use for incoherent_mode orthogonalization."""


sort_by_occupancy: bool = False
"""Keep the probes sorted so that mode with highest occupancy is the 0th shared mode"""

@dataclasses.dataclass
class ProbeOrthogonalizeOPRModesOptions(FeatureOptions):
Expand Down Expand Up @@ -651,25 +653,26 @@ def get_non_data_fields(self) -> dict:

@dataclasses.dataclass
class SynthesisDictLearnProbeOptions(Options):

d_mat: Union[ndarray, Tensor] = None

enabled: bool = False
enabled_shared: bool = False

thresholding_type_shared: str = 'hard'
"""Choose between 'hard' or 'soft' thresholding."""

dictionary_matrix: Union[ndarray, Tensor] = None
"""The synthesis sparse dictionary matrix; contains the basis functions
that will be used to represent the probe via the sparse code weights."""

d_mat_conj_transpose: Union[ndarray, Tensor] = None
"""Conjugate transpose of the synthesis sparse dictionary matrix."""

d_mat_pinv: Union[ndarray, Tensor] = None
dictionary_matrix_pinv: Union[ndarray, Tensor] = None
"""Moore-Penrose pseudoinverse of the synthesis sparse dictionary matrix."""

probe_sparse_code: Union[ndarray, Tensor] = None
"""Sparse code weights vector."""
sparse_code_probe_shared: Union[ndarray, Tensor] = None
"""Sparse code weights vector for the shared modes."""

probe_sparse_code_nnz: float = None
sparse_code_probe_shared_nnz: float = None
"""Number of non-zeros we will keep when enforcing sparsity constraint on
the sparse code weights vector probe_sparse_code."""

enabled: bool = False
the SHARED sparse code weights vector sparse_code_probe_shared."""

@dataclasses.dataclass
class PositionCorrectionOptions(Options):
Expand Down
9 changes: 7 additions & 2 deletions src/ptychi/api/options/lsqml.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,15 +100,20 @@ class LSQMLObjectOptions(base.ObjectOptions):
propagation always uses all probe modes regardless of this option.
"""

@dataclasses.dataclass
class LSQMLProbeExperimentalOptions(base.Options):
    """Group of experimental options attached to the LSQML probe options."""

    sdl_probe_options: base.SynthesisDictLearnProbeOptions = dataclasses.field(
        default_factory=base.SynthesisDictLearnProbeOptions
    )
    """Options for representing the probe with a synthesis sparse dictionary."""


@dataclasses.dataclass
class LSQMLProbeOptions(base.ProbeOptions):
    """Probe options for LSQML, extending ``base.ProbeOptions``."""

    optimal_step_size_scaler: float = 0.9
    """A scaler for the solved optimal step size (beta_LSQ in PtychoShelves)."""

    experimental: LSQMLProbeExperimentalOptions = dataclasses.field(
        default_factory=LSQMLProbeExperimentalOptions
    )
    """Experimental probe options; a fresh default instance is created per object."""


@dataclasses.dataclass
class LSQMLProbePositionOptions(base.ProbePositionOptions):
    """LSQML probe-position options; adds no fields to ``base.ProbePositionOptions``."""

    pass
Expand Down
234 changes: 201 additions & 33 deletions src/ptychi/data_structures/probe.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,13 @@ def constrain_incoherent_modes_orthogonality(self):

probe = self.data

if self.options.orthogonalize_incoherent_modes.sort_by_occupancy:
shared_occupancy = torch.sum(torch.abs(probe[0, ...]) ** 2, (-2, -1)) / torch.sum(
torch.abs(probe[0, ...]) ** 2
)
shared_occupancy = torch.sort(shared_occupancy, dim=0, descending=True)
probe[0, ...] = probe[0, shared_occupancy[1], ...]

norm_first_mode_orig = pmath.norm(probe[0, 0], dim=(-2, -1))

if self.orthogonalize_incoherent_modes_method == "gs":
Expand Down Expand Up @@ -470,31 +477,60 @@ def __init__(self, name = "probe", options = None, *args, **kwargs):

super().__init__(name, options, build_optimizer=False, data_as_parameter=False, *args, **kwargs)

dictionary_matrix, dictionary_matrix_pinv, dictionary_matrix_H = self.get_dictionary()
dictionary_matrix, dictionary_matrix_pinv = self.get_dictionary()
self.register_buffer("dictionary_matrix", dictionary_matrix)
self.register_buffer("dictionary_matrix_pinv", dictionary_matrix_pinv)
self.register_buffer("dictionary_matrix_H", dictionary_matrix_H)

probe_sparse_code_nnz = torch.tensor( self.options.experimental.sdl_probe_options.probe_sparse_code_nnz, dtype=torch.uint32 )
self.register_buffer("probe_sparse_code_nnz", probe_sparse_code_nnz )

sparse_code_probe = self.get_sparse_code_weights()
self.register_parameter("sparse_code_probe", torch.nn.Parameter(sparse_code_probe))

sparse_code_probe_shared_nnz = torch.tensor(
self.options.experimental.sdl_probe_options.sparse_code_probe_shared_nnz,
dtype=torch.uint32,
)
sparse_code_probe_shared = self.get_sparse_code_probe_shared_weights()
self.register_buffer("sparse_code_probe_shared_nnz", sparse_code_probe_shared_nnz)
self.register_parameter(
"sparse_code_probe_shared", torch.nn.Parameter(sparse_code_probe_shared)
)

self.build_optimizer()

def get_dictionary(self):
dictionary_matrix = torch.tensor( self.options.experimental.sdl_probe_options.d_mat, dtype=torch.complex64 )
dictionary_matrix_pinv = torch.tensor( self.options.experimental.sdl_probe_options.d_mat_pinv, dtype=torch.complex64 )
dictionary_matrix_H = torch.tensor( self.options.experimental.sdl_probe_options.d_mat_conj_transpose, dtype=torch.complex64 )
return dictionary_matrix, dictionary_matrix_pinv, dictionary_matrix_H

def get_sparse_code_weights(self):
sz = self.data.shape
probe_vec = torch.reshape( self.data[0,...], (sz[1], sz[2] * sz[3]))
probe_vec = torch.swapaxes( probe_vec, 0, -1)
sparse_code_probe = self.dictionary_matrix_pinv @ probe_vec
return sparse_code_probe
dictionary_matrix = torch.tensor(
self.options.experimental.sdl_probe_options.dictionary_matrix, dtype=torch.complex64
)
dictionary_matrix_pinv = torch.tensor(
self.options.experimental.sdl_probe_options.dictionary_matrix_pinv,
dtype=torch.complex64,
)
return dictionary_matrix, dictionary_matrix_pinv

def get_sparse_code_weights_vs_scanpositions(self, probe_vs_scanpositions):
    """Project per-scan-position probes onto the dictionary pseudoinverse.

    Parameters
    ----------
    probe_vs_scanpositions : Tensor
        A (n_pos, n_modes, h, w) tensor giving the probe at each scan
        position. The last two axes are flattened before projection.

    Returns
    -------
    Tensor
        A (n_rows, n_pos, n_modes) tensor of sparse code weights, where
        ``n_rows`` is the first dimension of ``self.dictionary_matrix_pinv``.
    """
    n_pos, n_modes, height, width = probe_vs_scanpositions.shape[:4]
    flattened = torch.reshape(probe_vs_scanpositions, (n_pos, n_modes, height * width))

    # Contract the flattened pixel axis (j) against the pseudoinverse columns.
    return torch.einsum("ij,klj->ikl", self.dictionary_matrix_pinv, flattened)

def get_sparse_code_probe_shared_weights(self):
    """Compute sparse code weights for the shared probe modes.

    The shared modes are taken as ``self.data[0, ...]``; each mode is
    flattened to a vector and multiplied by ``self.dictionary_matrix_pinv``.

    Returns
    -------
    Tensor
        A (n_modes, n_rows) tensor, where ``n_rows`` is the first dimension
        of ``self.dictionary_matrix_pinv``.
    """
    shared_modes = self.data[0, ...]
    n_modes, height, width = shared_modes.shape[:3]
    flattened = torch.reshape(shared_modes, (n_modes, height * width))

    # Project with modes as columns, then return with modes as rows.
    weights = self.dictionary_matrix_pinv @ flattened.T
    return weights.T

def generate(self):
"""Generate the probe using the sparse code, and set the
Expand All @@ -505,27 +541,159 @@ def generate(self):
Tensor
A (n_opr_modes, n_modes, h, w) tensor giving the generated probe.
"""
probe_vec = self.dictionary_matrix @ self.sparse_code_probe
probe_vec = torch.swapaxes( probe_vec, 0, -1)
probe = torch.reshape(probe_vec, *[self.data[0,...].shape])
probe = probe[None,...]

# we only use sparse codes for the shared modes, not the OPRs
probe = torch.cat((probe, self.data[1:,...]), 0)

self.set_data(probe)
return probe


if self.options.experimental.sdl_probe_options.enabled_shared:
sz = self.data.shape
probe = torch.zeros(*[sz], dtype=torch.complex64)

probe_shared = self.dictionary_matrix @ self.sparse_code_probe_shared.T

probe[0, ...] = torch.reshape(probe_shared.T, *[sz[1:]])
probe[1:, 0, ...] = self.data[1:, 0, ...]

self.set_data(probe)

else:
probe = self.data

def build_optimizer(self):
    """Build the optimizer over the shared-probe sparse code weights.

    Does nothing when the parameter is not optimizable. The optimizer is
    constructed once, over ``self.sparse_code_probe_shared`` only.

    Raises
    ------
    ValueError
        If the parameter is optimizable but ``self.optimizer_class`` is None.
    """
    if not self.optimizable:
        return
    if self.optimizer_class is None:
        raise ValueError(
            "Parameter {} is optimizable but no optimizer is specified.".format(self.name)
        )
    self.optimizer = self.optimizer_class(
        [self.sparse_code_probe_shared], **self.optimizer_params
    )

def set_sparse_code_probe_shared(self, data):
    """
    Set the sparse code weights for the shared probe.

    Overwrites the ``.data`` of the ``sparse_code_probe_shared`` parameter
    in place of the existing tensor.

    Parameters
    ----------
    data : Tensor
        The new sparse code weights for the shared probe.
        NOTE(review): presumably shaped (n_modes, n_dict_bases), matching
        what ``get_sparse_code_probe_shared_weights`` returns -- confirm.
    """
    self.sparse_code_probe_shared.data = data

def set_sparse_code(self, data):
self.sparse_code_probe.data = data
def initialize_grad_sparse_code_probe_shared(self):
    """
    Initialize the gradient of the sparse code weights for the shared probe.

    Sets ``sparse_code_probe_shared.grad`` to a zero tensor shaped like the
    parameter's current data. Takes no arguments and returns nothing.
    """
    self.sparse_code_probe_shared.grad = torch.zeros_like(self.sparse_code_probe_shared.data)

def set_gradient_sparse_code_probe_shared(self, grad):
    """
    Set the gradient of the sparse code weights update for the shared probe.

    Parameters
    ----------
    grad : Tensor
        The gradient to store on ``sparse_code_probe_shared.grad``.
        NOTE(review): presumably the same shape as the sparse code weights
        themselves -- confirm against callers.
    """
    self.sparse_code_probe_shared.grad = grad

def set_sparse_code_weights_vs_scanpositions(
self, sparse_code_vs_scanpositions: Tensor, indices: tuple | Tensor = None
):
"""
Set the sparse code weights for a given probe vs scan positions.

Parameters
----------
sparse_code_vs_scanpositions : Tensor
A (n_pos, n_opr_modes, n_scpm) tensor giving the sparse code weights for the
given probe vs scan positions.
indices : tuple | Tensor
The indices to apply to the sparse code weights.
"""
raise NotImplementedError("This method is not implemented yet.")
if indices is None:
indices = slice(None)
self.sparse_code_weights_vs_scanpositions[indices] = sparse_code_vs_scanpositions

def get_probe_update_direction_sparse_code_probe_shared(self, delta_p_i, chi, obj_patches):
    """Convert a probe update direction into a sparsified sparse-code update.

    The probe update is projected onto the dictionary, scaled by a
    per-position, per-mode step length, thresholded to enforce sparsity,
    and mapped back to probe space.

    Parameters
    ----------
    delta_p_i : Tensor
        Probe update direction; must be reshapeable to
        (n_spos, n_scpm, nr * nc).
    chi : Tensor
        Tensor whose last four axes are (n_spos, n_scpm, nr, nc); all
        sizes used here are read from it.
        NOTE(review): presumably the exit-wave residual -- confirm.
    obj_patches : Tensor
        Object patches; must be reshapeable to (n_spos, nr * nc).

    Returns
    -------
    delta_p_i : Tensor
        A (n_spos, n_scpm, nr, nc) probe update rebuilt from the
        thresholded sparse code update.
    optimal_delta_sparse_code : Tensor
        The thresholded sparse code update, with the dictionary-basis axis
        first, followed by (n_spos, n_scpm).
    """
    nr = chi.shape[-2]
    nc = chi.shape[-1]
    nrnc = nr * nc
    n_scpm = chi.shape[-3]
    n_spos = chi.shape[-4]

    # Flatten the pixel axes; chi becomes (nrnc, n_spos, n_scpm).
    obj_patches = torch.reshape(obj_patches, (n_spos, nrnc))
    chi = torch.reshape(chi, (n_spos, n_scpm, nrnc)).permute(2, 0, 1)

    # get sparse code update direction: contract the pixel axis of the probe
    # update with the conjugated dictionary -> (n_bases, n_spos, n_scpm)
    delta_sparse_code = torch.einsum(
        "ijk,kl->lij",
        torch.reshape(delta_p_i, (n_spos, n_scpm, nrnc)),
        self.dictionary_matrix.conj(),
    )

    # compute optimal step length for sparse code update
    # map the code direction back to pixel space -> (nrnc, n_spos, n_scpm)
    dict_delta_sparse_code = torch.einsum(
        "ij,jkl->ikl", self.dictionary_matrix, delta_sparse_code
    )

    # denom: sum over pixels of conj(obj) * obj * |D @ delta|^2 -> (n_spos, n_scpm)
    denom = (torch.abs(dict_delta_sparse_code) ** 2) * obj_patches.swapaxes(0, -1)[..., None]
    denom = torch.einsum("ij,jik->ik", torch.conj(obj_patches), denom)

    # numer: sum over pixels of conj(obj) * conj(D @ delta) * chi -> (n_spos, n_scpm)
    numer = torch.conj(dict_delta_sparse_code) * chi
    numer = torch.einsum("ij,jik->ik", torch.conj(obj_patches), numer)

    # real is used to throw away small imag part due to numerical precision errors
    # NOTE(review): no guard against denom == 0 here.
    optimal_step_sparse_code = (numer / denom).real

    optimal_delta_sparse_code = optimal_step_sparse_code[None, ...] * delta_sparse_code

    # enforce sparsity constraint on sparse code
    abs_sparse_code = torch.abs(optimal_delta_sparse_code)
    abs_sparse_code_sorted = torch.sort(abs_sparse_code, dim=0, descending=True)

    # Threshold = magnitude at index `nnz` of the descending sort, per
    # (position, mode).
    # NOTE(review): index `nnz` combined with `>=` keeps up to nnz + 1
    # entries (barring ties); confirm whether index nnz - 1 was intended.
    sel = abs_sparse_code_sorted[0][self.sparse_code_probe_shared_nnz, ...]
    sparse_code_mask = abs_sparse_code >= sel[None, ...]

    # hard or soft thresholding; any other thresholding_type_shared value
    # leaves the update un-thresholded (there is no else branch)
    if self.options.experimental.sdl_probe_options.thresholding_type_shared == "hard":
        # hard: zero out entries below the threshold, keep the rest unchanged
        optimal_delta_sparse_code = optimal_delta_sparse_code * sparse_code_mask
    elif self.options.experimental.sdl_probe_options.thresholding_type_shared == "soft":
        # soft: shrink kept magnitudes by the threshold, preserving the phase
        optimal_delta_sparse_code = (
            (abs_sparse_code - sel[None, ...])
            * sparse_code_mask
            * torch.exp(1j * torch.angle(optimal_delta_sparse_code))
        )

    # map the thresholded code update back to probe space -> (n_spos, n_scpm, nrnc)
    delta_p_i = torch.einsum(
        "ij,jlk->ilk", self.dictionary_matrix, optimal_delta_sparse_code
    ).permute(1, 2, 0)

    delta_p_i = torch.reshape(delta_p_i, (n_spos, n_scpm, nr, nc))

    return delta_p_i, optimal_delta_sparse_code

def get_grad(self) -> torch.Tensor:
    """Get the gradient of the sparse code weights for the shared probe.

    This method overrides the method in the base class, which returns
    the `.grad` attribute of the tensor. Here the gradient is read from
    the `sparse_code_probe_shared` parameter instead.

    Returns
    -------
    Tensor
        The gradient of the sparse code weights for the shared probe.
    """
    return self.sparse_code_probe_shared.grad

def set_grad(self, grad: torch.Tensor):
    """Set the gradient of the sparse code weights for the shared probe.

    This method overrides the method in the base class, which sets the `.grad`
    attribute of the tensor. Here the call is forwarded to
    `set_gradient_sparse_code_probe_shared`.

    Parameters
    ----------
    grad : Tensor
        The gradient to store for the shared-probe sparse code weights.
    """
    self.set_gradient_sparse_code_probe_shared(grad)


class DIPProbe(Probe):
Expand Down
4 changes: 4 additions & 0 deletions src/ptychi/maths.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,10 @@ def orthogonalize_svd(
def project(a, b, dim=None):
    """Return complex vector projection of a onto b for along given axis.

    The projection coefficient is ``inner(a, b) / inner(b, b)``; NaN
    coefficients are replaced by 0 before scaling ``b``.
    """
    numerator = inner(a, b, dim=dim, keepdims=True)
    denominator = inner(b, b, dim=dim, keepdims=True)

    # if the inner product of b with itself has any zeros, the ratio is NaN
    # there; treat the projection onto such a b as zero.
    coefficient = torch.nan_to_num(numerator / denominator, nan=0.0)

    return coefficient * b

def inner(x, y, dim=None, keepdims=False):
Expand Down
7 changes: 7 additions & 0 deletions src/ptychi/reconstructors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,6 +642,13 @@ def __init__(
)
self.forward_model = None
self.build_forward_model()

@property
def use_sparse_probe_shared_update(self):
    """Whether the shared probe modes are updated through their sparse code.

    True only when the probe's representation is ``"sparse_code"`` and its
    ``enabled_shared`` option is set. The option is read only after the
    representation check passes.
    """
    if self.parameter_group.probe.representation != "sparse_code":
        return False
    return self.parameter_group.probe.options.experimental.sdl_probe_options.enabled_shared

def build_forward_model(self):
self.forward_model = fm.PlanarPtychographyForwardModel(
Expand Down
Loading
Loading