# Source code for janus_core.processing.correlator

"""Module to correlate scalar data on-the-fly."""

from __future__ import annotations

from collections.abc import Iterable

from ase import Atoms
import numpy as np

from janus_core.helpers.janus_types import Observable


class Correlator:
    """
    Correlate scalar real values, <ab>, on-the-fly with a multiple-tau scheme.

    Block 0 correlates the raw values passed to :meth:`update`. Every
    ``averaging`` values received by a block are averaged and passed on to the
    next block, so lag index ``point`` in block ``k`` corresponds to a lag of
    ``point * averaging**k`` calls to :meth:`update`.

    Parameters
    ----------
    blocks : int
        Number of correlation blocks.
    points : int
        Number of points per block.
    averaging : int
        Averaging window per block level.

    Raises
    ------
    ValueError
        If any of ``blocks``, ``points`` or ``averaging`` is less than 1.
    """

    def __init__(self, *, blocks: int, points: int, averaging: int) -> None:
        """
        Initialise an empty Correlator.

        Parameters
        ----------
        blocks : int
            Number of correlation blocks.
        points : int
            Number of points per block.
        averaging : int
            Averaging window per block level.

        Raises
        ------
        ValueError
            If any of ``blocks``, ``points`` or ``averaging`` is less than 1.
        """
        for label, value in (
            ("blocks", blocks),
            ("points", points),
            ("averaging", averaging),
        ):
            if value < 1:
                raise ValueError(f"{label} must be at least 1, got {value}.")

        self._blocks = blocks
        self._points = points
        self._averaging = averaging
        # Deepest block that has received any data so far.
        self._max_block_used = 0
        # First lag index correlated in blocks > 0. It must be an int because it
        # is a ``range`` bound and is used in index arithmetic. Smaller indices
        # at a coarser level would overlap lags the previous block covers.
        self._min_dist = self._points // self._averaging
        # Running (a, b) sums waiting to be averaged into the next block.
        self._accumulator = np.zeros((self._blocks, 2))
        self._count_accumulated = np.zeros(self._blocks, dtype=int)
        # Circular shift registers holding the last ``points`` (a, b) pairs.
        self._shift_index = np.zeros(self._blocks, dtype=int)
        self._shift = np.zeros((self._blocks, self._points, 2))
        self._shift_not_null = np.zeros((self._blocks, self._points), dtype=bool)
        # Summed products and number of samples per (block, lag index).
        self._correlation = np.zeros((self._blocks, self._points))
        self._count_correlated = np.zeros((self._blocks, self._points), dtype=int)

    def update(self, a: float, b: float) -> None:
        """
        Update the correlation, <ab>, with new values a and b.

        Parameters
        ----------
        a : float
            Newly observed value of left correland.
        b : float
            Newly observed value of right correland.
        """
        self._propagate(a, b, 0)

    def _propagate(self, a: float, b: float, block: int) -> None:
        """
        Propagate update down block hierarchy.

        Parameters
        ----------
        a : float
            Newly observed value of left correland/average.
        b : float
            Newly observed value of right correland/average.
        block : int
            Block in the hierarchy being updated.
        """
        if block == self._blocks:
            # Averages produced by the last block are discarded.
            return

        shift = self._shift_index[block]
        self._max_block_used = max(self._max_block_used, block)

        # Record the new pair in this block's shift register and accumulator.
        self._shift[block, shift, :] = a, b
        self._accumulator[block, :] += a, b
        self._shift_not_null[block, shift] = True
        self._count_accumulated[block] += 1

        # After ``averaging`` values, hand their mean to the next block.
        if self._count_accumulated[block] == self._averaging:
            self._propagate(
                self._accumulator[block, 0] / self._averaging,
                self._accumulator[block, 1] / self._averaging,
                block + 1,
            )
            self._accumulator[block, :] = 0.0
            self._count_accumulated[block] = 0

        # Correlate the newest ``a`` (slot i) with older ``b`` values (slot j).
        # j walks backwards through the circular register, one lag per step.
        i = self._shift_index[block]
        if block == 0:
            j = i
            for point in range(self._points):
                if self._shifts_valid(block, i, j):
                    self._correlation[block, point] += (
                        self._shift[block, i, 0] * self._shift[block, j, 1]
                    )
                    self._count_correlated[block, point] += 1
                j -= 1
                if j < 0:
                    j += self._points
        else:
            # Start ``min_dist`` slots back so the slot distance equals ``point``.
            j = i - self._min_dist
            for point in range(self._min_dist, self._points):
                if j < 0:
                    j += self._points
                if self._shifts_valid(block, i, j):
                    self._correlation[block, point] += (
                        self._shift[block, i, 0] * self._shift[block, j, 1]
                    )
                    self._count_correlated[block, point] += 1
                j -= 1

        self._shift_index[block] = (self._shift_index[block] + 1) % self._points

    def _shifts_valid(self, block: int, p_i: int, p_j: int) -> bool:
        """
        Return True if the shift registers have data.

        Parameters
        ----------
        block : int
            Block to check the shift register of.
        p_i : int
            Index i in the shift (left correland).
        p_j : int
            Index j in the shift (right correland).

        Returns
        -------
        bool
            Whether the shift indices have data.
        """
        return self._shift_not_null[block, p_i] and self._shift_not_null[block, p_j]

    def get(self) -> tuple[Iterable[float], Iterable[float]]:
        """
        Obtain the correlation and lag times.

        Returns
        -------
        correlation : Iterable[float]
            The correlation values <a(t+t')b(t)>, i.e. ``a`` is sampled ``t'``
            updates after ``b`` (identical to <a(t)b(t+t')> when ``a`` is ``b``).
        lags : Iterable[float]
            The correlation lag times t', in units of calls to ``update``.
        """
        correlation = np.zeros(self._points * self._blocks)
        lags = np.zeros(self._points * self._blocks)

        lag = 0
        for i in range(self._points):
            if self._count_correlated[0, i] > 0:
                correlation[lag] = (
                    self._correlation[0, i] / self._count_correlated[0, i]
                )
                lags[lag] = i
                lag += 1
        # NOTE(review): the deepest block reached so far is excluded here
        # (range stops before _max_block_used), presumably because its
        # statistics are sparse. Confirm that this is intended.
        for k in range(1, self._max_block_used):
            for i in range(self._min_dist, self._points):
                if self._count_correlated[k, i] > 0:
                    correlation[lag] = (
                        self._correlation[k, i] / self._count_correlated[k, i]
                    )
                    lags[lag] = float(i) * float(self._averaging) ** k
                    lag += 1
        return (correlation[0:lag], lags[0:lag])
class Correlation:
    """
    Represents a user correlation, <ab>.

    Each call to :meth:`update` evaluates the two observables on an ``Atoms``
    object and feeds the resulting pair into an internal ``Correlator``.

    Parameters
    ----------
    a : Observable | tuple[Observable, tuple, dict]
        Getter for a, or a ``(getter, args, kwargs)`` tuple whose args and
        kwargs are forwarded to the getter after the ``Atoms`` object.
    b : Observable | tuple[Observable, tuple, dict]
        Getter for b, in the same form as ``a``.
    name : str
        Name of correlation.
    blocks : int
        Number of correlation blocks.
    points : int
        Number of points per block.
    averaging : int
        Averaging window per block level.
    update_frequency : int
        Frequency to update the correlation, md steps.
    """

    def __init__(
        self,
        a: Observable | tuple[Observable, tuple, dict],
        b: Observable | tuple[Observable, tuple, dict],
        name: str,
        blocks: int,
        points: int,
        averaging: int,
        update_frequency: int,
    ) -> None:
        """
        Initialise a correlation.

        Parameters
        ----------
        a : Observable | tuple[Observable, tuple, dict]
            Getter for a, or a ``(getter, args, kwargs)`` tuple.
        b : Observable | tuple[Observable, tuple, dict]
            Getter for b, or a ``(getter, args, kwargs)`` tuple.
        name : str
            Name of correlation.
        blocks : int
            Number of correlation blocks.
        points : int
            Number of points per block.
        averaging : int
            Averaging window per block level.
        update_frequency : int
            Frequency to update the correlation, md steps.

        Raises
        ------
        ValueError
            If ``a`` or ``b`` is a tuple without exactly three elements.
        """
        self.name = name
        self._get_a, self._a_args, self._a_kwargs = self._split_getter(a)
        self._get_b, self._b_args, self._b_kwargs = self._split_getter(b)
        self._correlator = Correlator(blocks=blocks, points=points, averaging=averaging)
        self._update_frequency = update_frequency

    @staticmethod
    def _split_getter(
        getter: Observable | tuple[Observable, tuple, dict],
    ) -> tuple[Observable, tuple, dict]:
        """
        Normalise a getter specification to ``(getter, args, kwargs)``.

        Parameters
        ----------
        getter : Observable | tuple[Observable, tuple, dict]
            A bare getter, or a ``(getter, args, kwargs)`` tuple.

        Returns
        -------
        tuple[Observable, tuple, dict]
            The getter with its positional args and kwargs. A bare getter gets
            empty args and kwargs.

        Raises
        ------
        ValueError
            If ``getter`` is a tuple without exactly three elements.
        """
        if isinstance(getter, tuple):
            if len(getter) != 3:
                raise ValueError(
                    "Expected an observable or an (observable, args, kwargs) "
                    f"tuple, got a tuple of length {len(getter)}."
                )
            observable, args, kwargs = getter
            return observable, args, kwargs
        return getter, (), {}

    @property
    def update_frequency(self) -> int:
        """
        Get update frequency.

        Returns
        -------
        int
            Correlation update frequency.
        """
        return self._update_frequency

    def update(self, atoms: Atoms) -> None:
        """
        Update a correlation.

        Both getters are called with ``atoms`` followed by their stored args
        and kwargs. The two results are then added to the correlator.

        Parameters
        ----------
        atoms : Atoms
            Atoms object to observe values from.
        """
        self._correlator.update(
            self._get_a(atoms, *self._a_args, **self._a_kwargs),
            self._get_b(atoms, *self._b_args, **self._b_kwargs),
        )

    def get(self) -> tuple[Iterable[float], Iterable[float]]:
        """
        Get the correlation value and lags.

        Returns
        -------
        correlation : Iterable[float]
            The correlation values reported by the underlying correlator.
        lags : Iterable[float]
            The correlation lag times t'.
        """
        return self._correlator.get()

    def __str__(self) -> str:
        """
        Return string representation of correlation.

        Returns
        -------
        str
            String representation.
        """
        return self.name