attoworld.personal.vlad

A collection of utility functions and a class for physics calculations.

This module provides various numerical helper functions, including FFT utilities, windowing functions, and data analysis tools, along with a class defining atomic unit constants.

   1# -*- coding: utf-8 -*-
   2# Author: Vladislav S. Yakovlev
   3"""A collection of utility functions and a class for physics calculations.
   4
   5This module provides various numerical helper functions, including FFT
   6utilities, windowing functions, and data analysis tools, along with a
   7class defining atomic unit constants.
   8"""
   9
  10import numpy as np
  11import scipy
  12import scipy.linalg  # Explicitly import submodule used
  13from typing import Tuple, Optional, Union, List, Any
  14
  15
# Type alias used in the signatures below for NumPy array arguments/results.
# Both type parameters of ``np.ndarray`` (shape and dtype) are ``Any``, so the
# alias constrains neither the number of dimensions nor the dtype; several
# functions accept complex data. Not to be confused with
# ``numpy.typing.ArrayLike``, which additionally admits lists and scalars.
ArrayLike = np.ndarray[Any, Any]
  19
  20
  21def nextpow2(number: float) -> int:
  22    """Computes the exponent for the smallest power of 2 >= number.
  23
  24    Args:
  25      number: The input number.
  26
  27    Returns:
  28      The smallest integer exponent `exp` such that 2**exp >= number.
  29    """
  30    return int(np.ceil(np.log2(number)))
  31
  32
  33def soft_window(x_grid: ArrayLike, x_begin: float, x_end: float) -> ArrayLike:
  34    """Computes a soft window function.
  35
  36    The window smoothly transitions from 1 to 0 over the interval
  37    [min(x_begin, x_end), max(x_begin, x_end)].
  38    If x_begin <= x_end, it's a cosine-squared fall-off.
  39    If x_begin > x_end, it's a sine-squared rise-on.
  40
  41    Args:
  42      x_grid: A 1D array of x-coordinates.
  43      x_begin: The x-coordinate where the transition starts (value is 1).
  44      x_end: The x-coordinate where the transition ends (value is 0).
  45
  46    Returns:
  47      A 1D array of the same size as x_grid, containing the window values.
  48    """
  49    window = np.zeros_like(x_grid)
  50    x_min_transition = min(x_begin, x_end)
  51    x_max_transition = max(x_begin, x_end)
  52
  53    # Determine indices for different parts of the window
  54    indices_before_transition = np.where(x_grid < x_min_transition)[0]
  55    idx_transition_start = (
  56        indices_before_transition[-1] + 1 if indices_before_transition.size > 0 else 0
  57    )
  58
  59    indices_after_transition = np.where(x_grid > x_max_transition)[0]
  60    idx_transition_end = (
  61        indices_after_transition[0]
  62        if indices_after_transition.size > 0
  63        else len(x_grid)
  64    )
  65
  66    # Define the transition region
  67    x_transition = x_grid[idx_transition_start:idx_transition_end]
  68
  69    if x_begin <= x_end:  # Window goes from 1 down to 0
  70        window[:idx_transition_start] = 1.0
  71        if (
  72            idx_transition_end > idx_transition_start
  73            and x_max_transition > x_min_transition
  74        ):
  75            window[idx_transition_start:idx_transition_end] = (
  76                np.cos(
  77                    np.pi
  78                    / 2.0
  79                    * (x_transition - x_min_transition)
  80                    / (x_max_transition - x_min_transition)
  81                )
  82                ** 2
  83            )
  84        # Values after x_end remain 0 (initialized)
  85    else:  # Window goes from 0 up to 1 (x_begin > x_end, so x_min_transition = x_end)
  86        window[idx_transition_end:] = 1.0
  87        if (
  88            idx_transition_end > idx_transition_start
  89            and x_max_transition > x_min_transition
  90        ):
  91            window[idx_transition_start:idx_transition_end] = (
  92                np.sin(
  93                    np.pi
  94                    / 2.0
  95                    * (x_transition - x_min_transition)
  96                    / (x_max_transition - x_min_transition)
  97                )
  98                ** 2
  99            )
 100        # Values before x_begin (which is x_max_transition) remain 0 (initialized)
 101    return window
 102
 103
 104def get_significant_part_indices_v1(
 105    array_data: ArrayLike, threshold: float = 1e-8
 106) -> Tuple[int, int]:
 107    """Returns indices (i1, i2) for the slice A[i1:i2] containing the
 108    significant part of the array.
 109
 110    The significant part is defined relative to the maximum absolute value
 111    in the array. Elements A[:i1] and A[i2:] are considered "small".
 112
 113    Args:
 114      array_data: The input 1D array.
 115      threshold: The relative threshold to determine significance.
 116                 Elements are significant if abs(element) >= threshold * max(abs(array_data)).
 117
 118    Returns:
 119      A tuple (i1, i2) representing the start (inclusive) and
 120      end (exclusive) indices of the significant part.
 121    """
 122    abs_array = np.abs(array_data)
 123    if abs_array.size == 0:
 124        return 0, 0
 125
 126    idx_max = np.argmax(abs_array)
 127    array_max_val = abs_array[idx_max]
 128
 129    if array_max_val == 0:  # All elements are zero
 130        return 0, len(array_data)
 131
 132    significant_indices_before_max = np.where(
 133        abs_array[:idx_max] >= threshold * array_max_val
 134    )[0]
 135    i1 = (
 136        significant_indices_before_max[0]
 137        if significant_indices_before_max.size > 0
 138        else idx_max
 139    )
 140
 141    significant_indices_from_max = np.where(
 142        abs_array[idx_max:] >= threshold * array_max_val
 143    )[0]
 144    i2 = (
 145        idx_max + significant_indices_from_max[-1] + 1
 146        if significant_indices_from_max.size > 0
 147        else idx_max + 1
 148    )
 149    return i1, i2
 150
 151
 152def get_significant_part_indices_v2(
 153    array_data: ArrayLike, threshold: float = 1e-8
 154) -> Tuple[int, int]:
 155    """Returns indices (i1, i2) based on parts of the array that are "small".
 156
 157    The interpretation of (i1, i2) from the original code is:
 158    `i1` is the index of the last element *before* the peak region that is
 159    considered small (abs(element) < threshold * max_val).
 160    `i2` is the index *after* the first element *after* the peak region that
 161    is considered small.
 162    The docstring of the original function was "Return a tuple (i1,i2) such
 163    that none of the elements A[i1:i2] is small", which might be misleading
 164    as A[i1] and A[i2-1] could themselves be small by this definition.
 165    A slice like A[i1+1 : i2-1] or A[i1+1 : idx_first_small_after_peak]
 166    might better correspond to "all elements are not small".
 167
 168    Args:
 169      array_data: The input 1D array.
 170      threshold: The relative threshold to determine smallness.
 171                 Elements are small if abs(element) < threshold * max(abs(array_data)).
 172
 173    Returns:
 174      A tuple (i1, i2).
 175    """
 176    abs_array = np.abs(array_data)
 177    if abs_array.size == 0:
 178        return 0, 0
 179
 180    idx_max = np.argmax(abs_array)
 181    array_max_val = abs_array[idx_max]
 182
 183    if array_max_val == 0:  # All elements are zero
 184        return 0, len(array_data)
 185
 186    small_indices_before_max = np.where(
 187        abs_array[:idx_max] < threshold * array_max_val
 188    )[0]
 189    i1 = small_indices_before_max[-1] if small_indices_before_max.size > 0 else 0
 190
 191    small_indices_from_max = np.where(abs_array[idx_max:] < threshold * array_max_val)[
 192        0
 193    ]
 194    # small_indices_from_max are relative to idx_max
 195    i2 = (
 196        idx_max + small_indices_from_max[0] + 1
 197        if small_indices_from_max.size > 0
 198        else len(array_data)
 199    )
 200    return i1, i2
 201
 202
 203def Fourier_filter(
 204    data: ArrayLike,
 205    time_step: float,
 206    spectral_window: ArrayLike,
 207    periodic: bool = False,
 208) -> ArrayLike:
 209    """Applies a Fourier filter to time-series data.
 210
 211    The function performs a Fourier transform, multiplies by a spectral
 212    window, and then performs an inverse Fourier transform.
 213
 214    Args:
 215      data: A 1D array or a 2D array (num_time_points, num_series)
 216            where each column is a time series.
 217      time_step: The time step between data points.
 218      spectral_window: A 2D array (num_window_points, 2) where the first
 219                       column contains circular frequencies (ascending) and
 220                       the second column contains the window function W(omega).
 221      periodic: If True, data is assumed to be periodic. Otherwise, data is
 222                mirrored and padded to reduce edge effects.
 223
 224    Returns:
 225      The filtered data, with the same shape as the input `data`.
 226    """
 227    if spectral_window.shape[0] == 0:
 228        return data.copy()  # Return a copy to match behavior when filtering occurs
 229
 230    original_shape = data.shape
 231    if data.ndim == 1:
 232        # Reshape 1D array to 2D for consistent processing
 233        current_data = data.reshape(-1, 1)
 234    else:
 235        current_data = data.copy()  # Work on a copy
 236
 237    num_time_points = current_data.shape[0]
 238
 239    if not periodic:
 240        # Mirror and concatenate data for non-periodic signals
 241        # Effectively doubles the number of time points for FFT
 242        current_data = np.vstack((current_data, current_data[::-1, :]))
 243        num_time_points_fft = current_data.shape[0]
 244    else:
 245        num_time_points_fft = num_time_points
 246
 247    # Fourier transform
 248    Fourier_transformed_data = np.fft.fftshift(np.fft.fft(current_data, axis=0), axes=0)
 249
 250    # Create frequency grid for the spectral window
 251    delta_omega = 2 * np.pi / (num_time_points_fft * time_step)
 252    # np.arange needs to create num_time_points_fft points
 253    # fftshift moves 0 frequency to the center.
 254    # The indices for fftshift range from -N/2 to N/2-1 (approx)
 255    omega_grid_indices = np.arange(num_time_points_fft) - np.floor(
 256        num_time_points_fft / 2.0
 257    )
 258    omega_grid = delta_omega * omega_grid_indices
 259
 260    # Interpolate spectral window onto the data's frequency grid
 261    # Ensure omega_grid is 1D for interpolation if it became 2D due to reshape
 262    window_values = np.interp(
 263        np.abs(omega_grid.ravel()),  # Use absolute frequencies
 264        spectral_window[:, 0],
 265        spectral_window[:, 1],
 266        left=1.0,  # Value for frequencies below spectral_window range
 267        right=0.0,  # Value for frequencies above spectral_window range
 268    )
 269
 270    # Apply spectral filter
 271    # window_values needs to be (num_time_points_fft, 1) to broadcast
 272    Fourier_transformed_data *= window_values.reshape(-1, 1)
 273
 274    # Inverse Fourier transform
 275    filtered_data_full = np.fft.ifft(
 276        np.fft.ifftshift(Fourier_transformed_data, axes=0), axis=0
 277    )
 278
 279    if not periodic:
 280        # Truncate to original length if data was mirrored
 281        filtered_data_final = filtered_data_full[:num_time_points, :]
 282    else:
 283        filtered_data_final = filtered_data_full
 284
 285    # Ensure output is real if input was real
 286    if np.all(np.isreal(data)):
 287        filtered_data_final = filtered_data_final.real
 288
 289    return filtered_data_final.reshape(original_shape)
 290
 291
 292def polyfit_with_weights(
 293    x_coords: ArrayLike, y_values: ArrayLike, weights: ArrayLike, degree: int
 294) -> ArrayLike:
 295    """Performs a weighted least-squares polynomial fit.
 296
 297    Args:
 298      x_coords: 1D array of sample point x-coordinates.
 299      y_values: 1D array of sample point y-values to fit.
 300      weights: 1D array of weights for each sample point.
 301      degree: Degree of the fitting polynomial.
 302
 303    Returns:
 304      A 1D array of polynomial coefficients [p_degree, ..., p_1, p_0].
 305    """
 306    num_coeffs = degree + 1
 307    matrix_a = np.empty((num_coeffs, num_coeffs), dtype=np.float64)
 308    weights_squared = weights * weights
 309
 310    # Construct the Vandermonde-like matrix A for the normal equations
 311    # A[i,j] = sum(w_k^2 * x_k^(i+j))
 312    for i in range(num_coeffs):
 313        for j in range(i, num_coeffs):
 314            matrix_a[i, j] = np.sum(weights_squared * (x_coords ** (i + j)))
 315            if i != j:
 316                matrix_a[j, i] = matrix_a[i, j]  # Symmetric matrix
 317
 318    # Construct the vector b for the normal equations
 319    # b[i] = sum(w_k^2 * y_k * x_k^i)
 320    vector_b = np.empty(num_coeffs, dtype=np.float64)
 321    for i in range(num_coeffs):
 322        vector_b[i] = np.sum(weights_squared * y_values * (x_coords**i))
 323
 324    # Solve the linear system A * p = b for coefficients p
 325    solution_coeffs = scipy.linalg.solve(matrix_a, vector_b)
 326    return solution_coeffs[::-1]  # Return in conventional order (highest power first)
 327
 328
 329def Fourier_transform(
 330    time_points: np.ndarray,
 331    y_data: np.ndarray,
 332    target_frequencies: Optional[np.ndarray] = None,
 333    is_periodic: bool = False,
 334    pulse_center_times: Optional[Union[float, np.ndarray]] = None,
 335) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
 336    r"""Applies the Fast Fourier Transform (FFT) in an easy-to-use way.
 337
 338    This function computes the Fourier transform of time-dependent data Y(t),
 339    defined as:
 340        $F[Y](\omega) = \int_{-\infty}^\infty dt Y(t) \exp(i \omega t)$
 341
 342    Args:
 343        time_points: A 1D NumPy array of shape (N_t,) representing the time
 344            discretization of Y(t). Must be sorted in ascending order. N_t must
 345            be greater than 1.
 346        y_data: A NumPy array of shape (N_t, N_data) or (N_t,) containing the
 347            time-dependent data to Fourier transform. If 1D, it's treated as
 348            a single data series.
 349        target_frequencies: An optional 1D NumPy array of shape (N_omega,)
 350            specifying the circular frequencies at which to compute the
 351            transform. Must be sorted in ascending order. If None, the transform
 352            is returned on an optimal internal frequency grid.
 353        is_periodic: A boolean indicating if Y(t) is assumed to be periodic.
 354            If True, the time step is expected to be constant, and
 355            Y(t_end + dt) = Y(t_start). If False, precautions are taken to
 356            minimize artifacts from the DFT's implicit periodicity assumption.
 357        pulse_center_times: An optional scalar or 1D NumPy array of length
 358            N_data. If the input data represents a pulse not centered in its
 359            time window, specifying the center of each pulse (t0) helps avoid
 360            interpolation artifacts. If None, the center of the time window is
 361            used.
 362
 363    Returns:
 364        If `target_frequencies` is None:
 365            A tuple `(transformed_y, fft_omega_grid)`, where:
 366            - `transformed_y` (np.ndarray): The Fourier-transformed data,
 367              shape (N_fft, N_data) or (N_fft,) if input `y_data` was 1D.
 368            - `fft_omega_grid` (np.ndarray): The array of circular frequencies
 369              (shape (N_fft,)) corresponding to `transformed_y`.
 370        If `target_frequencies` is not None:
 371            An np.ndarray of shape (N_omega, N_data) or (N_omega,) containing
 372            the Fourier transform of `y_data` evaluated at the specified
 373            `target_frequencies`.
 374
 375    Raises:
 376        ValueError: If `time_points` has fewer than 2 elements.
 377    """
 378    _MAX_FFT_POINTS = 2**20
 379    num_time_points = len(time_points)
 380    if num_time_points <= 1:
 381        raise ValueError("Input 'time_points' array must have more than one element.")
 382
 383    original_y_ndim = y_data.ndim
 384    if original_y_ndim == 1:
 385        # Reshape 1D y_data to (N_t, 1) for consistent processing
 386        y_data_processed = y_data.reshape((num_time_points, 1))
 387    else:
 388        y_data_processed = y_data
 389
 390    num_data_series = y_data_processed.shape[1]
 391
 392    # Determine the minimum time step from the input time_points
 393    min_time_step = np.min(np.diff(time_points))
 394    # This will be the effective time step for FFT. It might be modified later
 395    # for non-periodic cases to meet Nyquist criteria for target_frequencies.
 396    effective_time_step = min_time_step
 397
 398    period_duration: Optional[float] = None
 399    if is_periodic:
 400        # For periodic signals, the period is the total duration assuming
 401        # a constant step.
 402        period_duration = (
 403            time_points[-1] - time_points[0] + np.mean(np.diff(time_points))
 404        )
 405
 406    # Handle the case where only one target frequency is given (no FFT needed)
 407    if target_frequencies is not None and len(target_frequencies) == 1:
 408        # Direct integration using trapezoidal rule
 409        # integrand shape: (num_time_points, num_data_series)
 410        integrand = y_data_processed * np.exp(
 411            1j * target_frequencies * time_points.reshape((num_time_points, 1))
 412        )
 413        # result shape: (num_data_series,)
 414        output_spectrum = np.trapezoid(integrand, time_points, axis=0)
 415        if is_periodic:
 416            # Correction for periodic functions with trapezoidal rule
 417            output_spectrum += (
 418                0.5 * min_time_step * (integrand[0, :] + integrand[-1, :])
 419            )
 420        # Reshape to (1, num_data_series) to match expected output shape
 421        output_spectrum = output_spectrum.reshape(1, num_data_series)
 422        if original_y_ndim == 1:
 423            return output_spectrum.reshape(len(target_frequencies))  # or .flatten()
 424        return output_spectrum
 425
 426    # Determine the target frequency resolution for FFT grid calculation
 427    dw_target_for_n_fft_calc: float
 428    if target_frequencies is None:
 429        # If no target frequencies, FFT resolution is based on input time window
 430        dw_target_for_n_fft_calc = 2 * np.pi / (min_time_step * num_time_points)
 431    else:
 432        # If target frequencies are given, use their minimum spacing
 433        dw_target_for_n_fft_calc = np.min(np.diff(target_frequencies))
 434
 435    # Determine the number of points for FFT (num_fft_points)
 436    num_fft_points: int
 437    if is_periodic:
 438        num_fft_points = num_time_points
 439        # If target frequencies are specified, may need to upsample N_fft
 440        # to achieve finer resolution than the default N_t points.
 441        # effective_time_step remains min_time_step for periodic case.
 442        if target_frequencies is not None:
 443            while (
 444                2 * np.pi / (effective_time_step * num_fft_points)
 445                > 1.1 * dw_target_for_n_fft_calc
 446                and num_fft_points < _MAX_FFT_POINTS
 447            ):
 448                num_fft_points *= 2
 449    else:  # Not periodic
 450        # Initial estimate for num_fft_points based on desired resolution
 451        num_fft_points = 2 ** int(
 452            round(
 453                np.log(2 * np.pi / (min_time_step * dw_target_for_n_fft_calc))
 454                / np.log(2.0)
 455            )
 456        )
 457        # Ensure FFT time window is large enough to cover the original data span
 458        while (num_fft_points - 1) * min_time_step < time_points[-1] - time_points[
 459            0
 460        ] and num_fft_points < _MAX_FFT_POINTS:
 461            num_fft_points *= 2
 462
 463        # For non-periodic signals with specified target_frequencies,
 464        # adjust effective_time_step and num_fft_points to satisfy Nyquist
 465        # and resolution requirements for the target_frequencies.
 466        if target_frequencies is not None:
 467            # Use temporary variables for this iterative adjustment
 468            current_dt = min_time_step
 469            current_n_fft = num_fft_points
 470            while current_n_fft < _MAX_FFT_POINTS and (
 471                target_frequencies[-1]
 472                > np.pi / current_dt * (1.0 - 2.0 / current_n_fft)
 473                or -target_frequencies[0] > np.pi / current_dt
 474            ):
 475                current_dt /= 2.0
 476                current_n_fft *= 2
 477            effective_time_step = current_dt
 478            num_fft_points = current_n_fft
 479
 480    # FFT time grid generation
 481    # This grid is centered around 0 after shifting by pulse_center_times
 482    fft_time_grid = effective_time_step * np.fft.ifftshift(
 483        np.arange(num_fft_points) - num_fft_points // 2
 484    )
 485
 486    # Determine effective pulse center times for interpolation and phase shift
 487    # pulse_centers_for_interp: used to shift time_points before interpolation
 488    # pulse_centers_for_phase: used for final phase correction, shape (1, N_data)
 489
 490    _pulse_centers_for_interp: Union[float, np.ndarray]
 491    _pulse_centers_for_phase: np.ndarray
 492
 493    if pulse_center_times is None:
 494        # If no t0 provided, use the center of the time window
 495        time_window_center = 0.5 * (time_points[0] + time_points[-1])
 496        # Find the closest point in time_points to this center
 497        idx_center = np.argmin(np.abs(time_points - time_window_center))
 498        # This becomes the t0 for all data series
 499        calculated_t0 = time_points[idx_center]
 500        _pulse_centers_for_interp = calculated_t0  # Scalar
 501        _pulse_centers_for_phase = calculated_t0 * np.ones((1, num_data_series))
 502    else:
 503        _pulse_centers_for_interp = pulse_center_times  # Can be scalar or array
 504        if np.isscalar(pulse_center_times):
 505            _pulse_centers_for_phase = pulse_center_times * np.ones(
 506                (1, num_data_series)
 507            )
 508        else:
 509            # Ensure it's a 1D array before reshaping
 510            _pulse_centers_for_phase = np.asarray(pulse_center_times).reshape(
 511                (1, num_data_series)
 512            )
 513
 514    # Interpolate y_data onto the FFT time grid
 515    # y_interpolated_on_fft_grid shape: (num_fft_points, num_data_series)
 516    y_interpolated_on_fft_grid = np.zeros(
 517        (num_fft_points, num_data_series), dtype=y_data_processed.dtype
 518    )
 519
 520    for j_col in range(num_data_series):
 521        current_t0_interp: float
 522        if np.isscalar(_pulse_centers_for_interp):
 523            current_t0_interp = float(_pulse_centers_for_interp)
 524        else:
 525            current_t0_interp = float(np.asarray(_pulse_centers_for_interp)[j_col])
 526
 527        # Shift original time_points relative to the current pulse center
 528        shifted_time_points = time_points - current_t0_interp
 529
 530        if is_periodic:
 531            # For periodic data, use period in interpolation
 532            y_interpolated_on_fft_grid[:, j_col] = np.interp(
 533                fft_time_grid,
 534                shifted_time_points,
 535                y_data_processed[:, j_col],
 536                period=period_duration,
 537            )
 538        else:
 539            # For non-periodic, pad with zeros outside original time range
 540            y_interpolated_on_fft_grid[:, j_col] = np.interp(
 541                fft_time_grid,
 542                shifted_time_points,
 543                y_data_processed[:, j_col],
 544                left=0.0,
 545                right=0.0,
 546            )
 547
 548    # Perform FFT
 549    # The result of ifft is scaled by (1/N). We multiply by (N*dt) to approximate
 550    # the integral definition F(omega) = integral Y(t)exp(i*omega*t) dt.
 551    # So, overall scaling is dt.
 552    y_fft = np.fft.fftshift(np.fft.ifft(y_interpolated_on_fft_grid, axis=0), axes=0) * (
 553        num_fft_points * effective_time_step
 554    )
 555
 556    # FFT omega grid
 557    fft_omega_grid_spacing = 2 * np.pi / (effective_time_step * num_fft_points)
 558    fft_omega_grid = fft_omega_grid_spacing * (
 559        np.arange(num_fft_points) - num_fft_points // 2
 560    )
 561
 562    # Apply phase correction due to pulse_center_times (t0)
 563    # This accounts for the shift Y(t) -> Y(t-t0) in time domain,
 564    # which corresponds to F(omega) -> F(omega) * exp(i*omega*t0)
 565    # if the FFT was performed on data effectively centered at t'=0.
 566    # The interpolation shifted data by -t0, so Z was Y(t').
 567    # The FFT of Y(t') is F[Y(t')] = integral Y(t')exp(iwt')dt'.
 568    # We want F[Y(t)] = integral Y(t)exp(iwt)dt.
 569    # F[Y(t)] = exp(iw t0_effective) * F[Y(t')].
 570    phase_correction = np.exp(
 571        1j * _pulse_centers_for_phase * fft_omega_grid.reshape((num_fft_points, 1))
 572    )
 573    y_fft_corrected = y_fft * phase_correction
 574
 575    if target_frequencies is None:
 576        # Return FFT result on its own grid
 577        if original_y_ndim == 1:
 578            return y_fft_corrected.flatten(), fft_omega_grid
 579        return y_fft_corrected, fft_omega_grid
 580    else:
 581        # Interpolate FFT result onto the target_frequencies grid
 582        output_spectrum = np.zeros(
 583            (len(target_frequencies), num_data_series), dtype=np.complex128
 584        )
 585        for j_col in range(num_data_series):
 586            # Note: y_fft_corrected already includes the phase shift based on _pulse_centers_for_phase
 587            # and fft_omega_grid. When interpolating to target_frequencies, this phase is implicitly
 588            # interpolated as well.
 589            output_spectrum[:, j_col] = np.interp(
 590                target_frequencies,
 591                fft_omega_grid,
 592                y_fft_corrected[:, j_col],  # Use the phase-corrected FFT result
 593                left=0.0,
 594                right=0.0,
 595            )
 596
 597        # The phase correction was already applied to y_fft before interpolation.
 598        # If we were to apply it *after* interpolation, it would be:
 599        # phase_correction_on_target_freq = np.exp(
 600        #    1j * _pulse_centers_for_phase * target_frequencies.reshape((len(target_frequencies), 1))
 601        # )
 602        # output_spectrum = interpolated_unphased_result * phase_correction_on_target_freq
 603        # However, the original code applies the phase correction *before* this final interpolation step
 604        # if omega is None, and *after* if omega is not None.
 605        # Let's re-check original logic for omega not None:
 606        # Z = np.fft.fftshift(np.fft.ifft(Z, axis=0), axes=0) * (N_fft * dt) <-- y_fft (unphased by t0 yet for this path)
 607        # ...
 608        # result[:,j] = np.interp(omega, w_grid, Z[:,j], left=0.0, right=0.0) <-- interpolation of unphased
 609        # result = result * np.exp(1j * t0 * omega.reshape((len(omega), 1))) <-- phase correction
 610        # This means my current y_fft_corrected (which has phase) should NOT be used for interpolation here.
 611        # I should interpolate 'y_fft' (before t0 correction) and then apply t0 correction using target_frequencies.
 612
 613        # Reverting to match original logic for target_frequencies path:
 614        # Interpolate the raw FFT result (before t0 correction)
 615        interpolated_raw_fft = np.zeros(
 616            (len(target_frequencies), num_data_series), dtype=np.complex128
 617        )
 618        for j_col in range(num_data_series):
 619            interpolated_raw_fft[:, j_col] = np.interp(
 620                target_frequencies,
 621                fft_omega_grid,
 622                y_fft[
 623                    :, j_col
 624                ],  # Use y_fft (before _pulse_centers_for_phase correction)
 625                left=0.0,
 626                right=0.0,
 627            )
 628
 629        # Now apply phase correction using _pulse_centers_for_phase and target_frequencies
 630        phase_correction_final = np.exp(
 631            1j
 632            * _pulse_centers_for_phase
 633            * target_frequencies.reshape((len(target_frequencies), 1))
 634        )
 635        output_spectrum = interpolated_raw_fft * phase_correction_final
 636
 637        if original_y_ndim == 1:
 638            return output_spectrum.flatten()
 639        return output_spectrum
 640
 641
 642def inverse_Fourier_transform(
 643    omega_points: ArrayLike,
 644    data_series: ArrayLike,
 645    time_points_target: Optional[ArrayLike] = None,
 646    is_periodic: bool = False,
 647    frequency_offset: Optional[Union[float, ArrayLike]] = None,
 648) -> Union[Tuple[ArrayLike, ArrayLike], ArrayLike]:
 649    r"""Applies inverse FFT to frequency-dependent data.
 650
 651    Computes $ F^{-1}[Y](t) = 1 / (2 \pi) \int_{-\infty}^\infty d\omega Y(\omega) \exp(-i t \omega) $.
 652
 653    Args:
 654      omega_points: 1D array of circular frequencies (N_omega), sorted.
 655      data_series: Frequency-dependent data (N_omega) or (N_omega, N_data_series).
 656      time_points_target: Optional 1D array of time points (N_t), sorted.
 657                          If None, times are determined by IFFT.
 658      is_periodic: If True, data_series is assumed periodic in frequency.
 659      frequency_offset: Scalar or array (N_data_series). Central frequency
 660                        offset(s) if data is not centered at omega=0.
 661
 662    Returns:
 663      If time_points_target is None:
 664        A tuple (transformed_data, time_grid).
 665      If time_points_target is provided:
 666        Transformed data interpolated at the given time_points_target.
 667    """
 668    # IFFT(Y(w)) = 1/(2pi) FT(Y(w))_at_-t = 1/(2pi) conj(FT(conj(Y(w)))_at_t)
 669    # The provided Fourier_transform computes FT[Y(t)](omega) = integral Y(t) exp(iwt) dt
 670    # We want IFT[Y(w)](t) = 1/(2pi) integral Y(w) exp(-iwt) dw
 671    # Let w' = -w, dw' = -dw.
 672    # = -1/(2pi) integral Y(-w') exp(iw't) dw' (from -inf to +inf, so limits flip)
 673    # = 1/(2pi) integral Y(-w') exp(iw't) dw' (from -inf to +inf)
 674    # So, call Fourier_transform with omega -> -omega (reversed), Y(omega) -> Y(-omega) (reversed)
 675    # and then scale by 1/(2pi). The 't' in Fourier_transform becomes our 'omega_points',
 676    # and 'omega' in Fourier_transform becomes our '-time_points_target'.
 677
 678    if time_points_target is None:
 679        # Transform Y(omega) as if it's a time signal Y(t=omega)
 680        # The 'omega' output of Fourier_transform will correspond to '-t'
 681        transformed_data, neg_time_grid = Fourier_transform(
 682            time_points=omega_points,
 683            y_data=data_series,
 684            target_frequencies=None,  # Let FT determine output grid
 685            is_periodic=is_periodic,
 686            pulse_center_times=frequency_offset,  # This is omega0, an offset in the input "time" (omega) domain
 687        )
 688        # Result is FT[Y](k), where k is frequency. Here k corresponds to -t.
 689        # So, FT[Y(omega)](-t). We need to flip t and scale.
 690        return transformed_data[::-1] / (2 * np.pi), -neg_time_grid[::-1]
 691    else:
 692        # Target 'omega' for Fourier_transform is -time_points_target
 693        neg_target_times = -time_points_target[::-1]  # Ensure it's sorted for FT
 694
 695        result_at_neg_t = Fourier_transform(
 696            time_points=omega_points,
 697            y_data=data_series,
 698            target_frequencies=neg_target_times,
 699            is_periodic=is_periodic,
 700            pulse_center_times=frequency_offset,
 701        )
 702        # result_at_neg_t is FT[Y(omega)](-t_target_sorted)
 703        # We want values at t_target, so reverse the order back.
 704        return result_at_neg_t[::-1] / (2 * np.pi)
 705
 706
 707def find_zero_crossings(x_values: ArrayLike, y_values: ArrayLike) -> ArrayLike:
 708    """Finds all x-values where linearly interpolated y(x) = 0.
 709
 710    Args:
 711      x_values: 1D array of x-coordinates, sorted ascending, no duplicates.
 712      y_values: 1D array of y-coordinates, same shape as x_values.
 713
 714    Returns:
 715      A 1D array of x-values where y(x) crosses zero. Empty if no crossings.
 716    """
 717    if x_values.size == 0 or y_values.size == 0:
 718        return np.array([])
 719    if x_values.size != y_values.size:
 720        raise ValueError("x_values and y_values must have the same length.")
 721
 722    # Product of y[i] and y[i+1]
 723    product_adjacent_y = y_values[:-1] * y_values[1:]
 724    crossings_x_coords: List[float] = []
 725
 726    # Find indices where product is <= 0 (indicates a zero crossing or y[i]=0)
 727    for i in np.where(product_adjacent_y <= 0)[0]:
 728        # Instead of: if product_adjacent_y[i] == 0:
 729        #                 if y_values[i] == 0:
 730        # Use np.isclose for checking if y_values[i] or y_values[i+1] are zero
 731        y1_is_zero = np.isclose(y_values[i], 0.0)
 732        y2_is_zero = np.isclose(y_values[i + 1], 0.0)
 733
 734        if y1_is_zero and y2_is_zero:  # segment is [0,0]
 735            crossings_x_coords.append(x_values[i])
 736            # To avoid double adding x_values[i+1] if it's processed as y1_is_zero in next iter
 737        elif y1_is_zero:
 738            crossings_x_coords.append(x_values[i])
 739        elif (
 740            y2_is_zero and product_adjacent_y[i] < 0
 741        ):  # Crosses and lands on zero at y2
 742            # The interpolation formula will give x_values[i+1]
 743            x1, x2_pt = x_values[i], x_values[i + 1]
 744            y1_pt, y2_pt = y_values[i], y_values[i + 1]  # y2_pt is close to 0
 745            crossings_x_coords.append((x1 * y2_pt - x2_pt * y1_pt) / (y2_pt - y1_pt))
 746        elif product_adjacent_y[i] < 0:  # Definite crossing, neither is zero
 747            x1, x2 = x_values[i], x_values[i + 1]
 748            y1_val, y2_val = y_values[i], y_values[i + 1]
 749            crossings_x_coords.append((x1 * y2_val - x2 * y1_val) / (y2_val - y1_val))
 750
 751    # Handle case where the last point itself is a zero not caught by pair product
 752    # This also needs np.isclose
 753    if y_values.size > 0 and np.isclose(y_values[-1], 0.0):
 754        # Avoid adding if it's already part of a segment ending in zero
 755        # that was captured by product_adjacent_y[i]=0 logic (where y[i+1]=0)
 756        already_found = False
 757        if crossings_x_coords and np.isclose(crossings_x_coords[-1], x_values[-1]):
 758            already_found = True
 759
 760        if not already_found:
 761            # If y_values[-1] is zero, and y_values[-2]*y_values[-1] was not <=0 (e.g. y_values[-2] also zero)
 762            # or it was handled by interpolation which might be slightly off.
 763            # We want to ensure grid points that are zero are included.
 764            # A simpler way: collect all interpolated, then add all x where y is zero, then unique.
 765            pass  # The unique call later should handle it if x_values[-1] was added by main loop
 766
 767    # A more robust approach for points exactly on the grid:
 768    # After interpolation, add all x_values where corresponding y_values are close to zero.
 769    if x_values.size > 0:  # Ensure x_values is not empty
 770        grid_zeros = x_values[np.isclose(y_values, 0.0)]
 771        crossings_x_coords.extend(list(grid_zeros))
 772
 773    return np.unique(np.array(crossings_x_coords))
 774
 775
 776def find_extrema_positions(x_values: ArrayLike, y_values: ArrayLike) -> ArrayLike:
 777    """Finds x-positions of local extrema in y(x).
 778
 779    Extrema are found where the derivative y'(x) (approximated by finite
 780    differences) crosses zero.
 781
 782    Args:
 783      x_values: 1D array of x-coordinates, sorted ascending, no duplicates.
 784      y_values: 1D array of y-coordinates, same shape as x_values.
 785
 786    Returns:
 787      A 1D array of x-values where y(x) has local extrema. Empty if none.
 788    """
 789    if (
 790        len(x_values) < 2 or len(y_values) < 2
 791    ):  # Need at least two points for a derivative
 792        return np.array([])
 793    if len(x_values) != len(y_values):
 794        raise ValueError("x_values and y_values must have the same length.")
 795
 796    # Approximate derivative y'(x)
 797    delta_y = y_values[1:] - y_values[:-1]
 798    delta_x = x_values[1:] - x_values[:-1]
 799    # Avoid division by zero if x_values have duplicates (though pre-condition says no duplicates)
 800    # However, if delta_x is extremely small, derivative can be huge.
 801    # For robustness, filter out zero delta_x if they somehow occur.
 802    valid_dx = delta_x != 0
 803    if not np.all(valid_dx):  # Should not happen given preconditions
 804        delta_y = delta_y[valid_dx]
 805        delta_x = delta_x[valid_dx]
 806        mid_points_x_for_derivative = (x_values[1:] + x_values[:-1])[valid_dx] / 2.0
 807    else:
 808        mid_points_x_for_derivative = (x_values[1:] + x_values[:-1]) / 2.0
 809
 810    if delta_x.size == 0:  # Not enough points after filtering
 811        return np.array([])
 812
 813    derivative_y = delta_y / delta_x
 814
 815    # Find where the derivative crosses zero
 816    extrema_x_coords = find_zero_crossings(mid_points_x_for_derivative, derivative_y)
 817    return extrema_x_coords
 818
 819
 820def minimize_imaginary_parts(complex_array: ArrayLike) -> ArrayLike:
 821    """Rotates a complex array by a phase to make it as close as possible to being real-valued
 822
 823    Multiplies `complex_array` by `exp(1j*phi)` where `phi` is chosen to
 824    minimize `sum(imag(exp(1j*phi) * complex_array)**2)`.
 825
 826    Args:
 827      complex_array: A NumPy array of complex numbers.
 828
 829    Returns:
 830      The phase-rotated complex NumPy array.
 831    """
 832    if complex_array.size == 0:
 833        return complex_array.copy()
 834
 835    # Z = X + iY. We want to minimize sum( (X sin(phi) + Y cos(phi))^2 )
 836    # d/dphi (sum(...)) = 0 leads to tan(2*phi) = 2*sum(XY) / sum(Y^2 - X^2)
 837    real_part = complex_array.real
 838    imag_part = complex_array.imag
 839
 840    numerator = 2 * np.sum(real_part * imag_part)
 841    denominator = np.sum(imag_part**2 - real_part**2)
 842
 843    # arctan2 handles signs and denominator being zero correctly
 844    phi = 0.5 * np.arctan2(numerator, denominator)
 845
 846    # The arctan2 gives phi in (-pi, pi], so 0.5*phi is in (-pi/2, pi/2].
 847    # This finds one extremum. The other is phi + pi/2. We need the minimum.
 848    rotated_z1 = complex_array * np.exp(1j * phi)
 849    imag_energy1 = np.sum(rotated_z1.imag**2)
 850
 851    rotated_z2 = complex_array * np.exp(1j * (phi + 0.5 * np.pi))
 852    imag_energy2 = np.sum(rotated_z2.imag**2)
 853
 854    if imag_energy2 < imag_energy1:
 855        phi += 0.5 * np.pi
 856
 857    # Normalize phi to be in (-pi/2, pi/2] or a similar principal range if desired,
 858    # though for exp(1j*phi) it doesn't strictly matter beyond 2pi periodicity.
 859    # The original code maps phi to (-pi/2, pi/2] effectively.
 860    phi -= np.pi * np.round(phi / np.pi)  # This maps to (-pi/2, pi/2]
 861    # Let's test the original normalization:
 862    # If phi = 0.6*pi, round(0.6) = 1. phi = 0.6pi - pi = -0.4pi. Correct.
 863    # If phi = 0.4*pi, round(0.4) = 0. phi = 0.4pi. Correct.
 864    # If phi = -0.6*pi, round(-0.6) = -1. phi = -0.6pi + pi = 0.4pi. Correct.
 865    # This normalization is fine.
 866
 867    return complex_array * np.exp(1j * phi)
 868
 869
 870def integrate_oscillating_function(
 871    x_values: ArrayLike,
 872    func_values: ArrayLike,
 873    phase_values: ArrayLike,
 874    phase_step_threshold: float = 1e-3,
 875) -> ArrayLike:
 876    r"""Integrates f(x) * exp(i * phi(x)) for quickly oscillating functions.
 877
 878    Uses an algorithm suitable for integrating f(x) * exp(i * phi(x)) dx
 879    over small intervals, particularly when phi(x) changes rapidly.
 880
 881    Args:
 882      x_values: 1D array of sorted x-coordinates.
 883      func_values: Array of function values f(x). Can be 1D (N_x) or 2D
 884                   (N_x, N_series).
 885      phase_values: Array of real-valued phase phi(x). Same shape as func_values.
 886      phase_step_threshold: Small positive number. Prevents division by
 887                            small d_phi in the integration formula.
 888
 889    Returns:
 890      A scalar or 1D array (N_series) of integral results.
 891    """
 892    # Input validation
 893    if not (x_values.shape[0] == func_values.shape[0] == phase_values.shape[0]):
 894        raise ValueError(
 895            "x_values, func_values, and phase_values must have "
 896            "the same length along the integration axis (axis 0)."
 897        )
 898    if not np.allclose(np.imag(phase_values), 0):  # Ensure phase is real
 899        raise ValueError("phase_values must be real-valued.")
 900    if func_values.ndim > 1 and func_values.shape != phase_values.shape:
 901        raise ValueError(
 902            "If func_values is 2D, phase_values must have the exact same shape."
 903        )
 904
 905    delta_x = x_values[1:] - x_values[:-1]
 906
 907    # Prepare f and phi for interval calculations
 908    f1 = func_values[:-1, ...]  # f(x_i)
 909    f2 = func_values[1:, ...]  # f(x_{i+1})
 910    delta_f = f2 - f1
 911
 912    phi1 = phase_values[:-1, ...]  # phi(x_i)
 913    phi2 = phase_values[1:, ...]  # phi(x_{i+1})
 914    delta_phi = phi2 - phi1
 915
 916    # Reshape delta_x to broadcast with f1, f2, etc.
 917    # If func_values is (N_x, N_series), delta_x needs to be (N_x-1, 1)
 918    reshape_dims = (-1,) + (1,) * (func_values.ndim - 1)
 919    delta_x_reshaped = delta_x.reshape(reshape_dims)
 920
 921    # Common factor for the integral segments
 922    common_factor_z = delta_x_reshaped * np.exp(0.5j * (phi1 + phi2))
 923
 924    integral_segments = np.zeros_like(common_factor_z, dtype=complex)
 925
 926    # Mask for small phase changes (use simpler approximation)
 927    is_small_delta_phi = np.abs(delta_phi) < phase_step_threshold
 928
 929    # Case 1: Small delta_phi (dphi is small)
 930    if np.any(is_small_delta_phi):
 931        # Approximation: integral \approx dx * exp(i*phi_avg) * (f_avg + i/8 * dphi * df)
 932        # This seems to be a higher-order trapezoidal rule for oscillating functions.
 933        # Original: Z[s] = Z[s] * (0.5 * (f1[s] + f2[s]) + 0.125j * dphi[s] * df[s])
 934        # where Z[s] was common_factor_z[is_small_delta_phi]
 935        term_small_dphi = (
 936            0.5 * (f1[is_small_delta_phi] + f2[is_small_delta_phi])
 937            + 0.125j * delta_phi[is_small_delta_phi] * delta_f[is_small_delta_phi]
 938        )
 939        integral_segments[is_small_delta_phi] = (
 940            common_factor_z[is_small_delta_phi] * term_small_dphi
 941        )
 942
 943    # Case 2: Large delta_phi (use formula for oscillating part)
 944    is_large_delta_phi = ~is_small_delta_phi
 945    if np.any(is_large_delta_phi):
 946        # This is likely an approximation based on integration by parts or steepest descent.
 947        # Original: Z[s] = Z[s] / dphi[s]**2 * (exp_term * (df[s] - 1j*f2[s]*dphi[s]) -
 948        #                                     (df[s] - 1j*f1[s]*dphi[s]) / exp_term)
 949        # where Z[s] was common_factor_z[is_large_delta_phi] and exp_term = exp(0.5j * dphi[s])
 950
 951        dphi_large = delta_phi[is_large_delta_phi]
 952        exp_half_j_dphi = np.exp(0.5j * dphi_large)
 953
 954        term1 = exp_half_j_dphi * (
 955            delta_f[is_large_delta_phi] - 1j * f2[is_large_delta_phi] * dphi_large
 956        )
 957        term2 = (
 958            delta_f[is_large_delta_phi] - 1j * f1[is_large_delta_phi] * dphi_large
 959        ) / exp_half_j_dphi
 960
 961        integral_segments[is_large_delta_phi] = (
 962            common_factor_z[is_large_delta_phi] / (dphi_large**2) * (term1 - term2)
 963        )
 964
 965    return np.sum(integral_segments, axis=0)
 966
 967
 968def calculate_permittivity_from_delta_polarization(
 969    time_step: float,
 970    polarization_delta_response: ArrayLike,  # P_delta
 971    omega_array: ArrayLike,
 972    momentum_relaxation_rate: float = 0.0,
 973    dephasing_time: Optional[float] = None,
 974    disregard_drift_current: bool = False,
 975    allow_for_linear_displacement: bool = True,
 976) -> ArrayLike:
 977    r"""Evaluates permittivity from polarization induced by E(t) = delta(t).
 978
 979    Handles drift currents and coherent oscillations in the polarization response.
 980    The relationship is $\epsilon(\omega) = 1 + 4 \pi \chi(\omega)$, where
 981    $P(\omega) = \chi(\omega) E(\omega)$, and for $E(t)=\delta(t)$, $E(\omega)=1$.
 982    So $\chi(\omega) = P_{\delta}(\omega)$.
 983
 984    Args:
 985      time_step: Time step (atomic units) of the polarization grid.
 986      polarization_delta_response: 1D array of polarization response P(t)
 987                                   (atomic units) induced by E(t)=delta(t).
 988                                   P_delta[0] corresponds to t=time_step.
 989      omega_array: 1D array of circular frequencies (a.u.) for permittivity calculation.
 990                   All frequencies must be non-zero.
 991      momentum_relaxation_rate: If non-zero, models Drude-like momentum relaxation
 992                                (gamma in 1/(omega*(omega+i*gamma))).
 993      dephasing_time: If not None/zero, an exponential decay (rate 1/dephasing_time)
 994                      is applied to coherent dipole oscillations in P_delta.
 995      disregard_drift_current: If True, the J_drift component of polarization
 996                               has no effect on the result.
 997      allow_for_linear_displacement: If True, fits P(t) ~ J_drift*t + P_offset.
 998                                     If False, fits P(t) ~ J_drift*t (P_offset=0).
 999
1000    Returns:
1001      A complex array (same shape as omega_array) of permittivity values.
1002    """
1003    if not np.all(omega_array != 0):
1004        raise ValueError("All elements in omega_array must be non-zero.")
1005
1006    # Construct time grid and full polarization array P(t), P(0)=0
1007    num_p_delta_points = polarization_delta_response.size
1008    # P_delta starts at t=dt, so P has N_t = num_p_delta_points + 1 points
1009    num_time_points = num_p_delta_points + 1
1010    time_grid = time_step * np.arange(num_time_points)
1011    time_max = time_grid[-1]
1012
1013    polarization_full = np.zeros(num_time_points)
1014    polarization_full[1:] = polarization_delta_response  # P(0)=0
1015
1016    # Fit and subtract linear trend (drift current and offset)
1017    # Fit is done on the latter half of the data
1018    fit_start_index = num_time_points // 2
1019    if (
1020        fit_start_index < 2 and num_time_points >= 2
1021    ):  # Need at least 2 points for polyfit(deg=1)
1022        fit_start_index = 0  # Use all data if too short for half
1023    elif num_time_points < 2:
1024        # Handle very short P_delta (e.g. 0 or 1 point)
1025        # If P_delta has 0 points, N_t=1, P_full=[0]. J_drift=0, P_offset=0.
1026        # If P_delta has 1 point, N_t=2, P_full=[0, P_d[0]].
1027        # polyfit needs at least deg+1 points.
1028        if num_time_points < 2:  # Cannot do polyfit
1029            J_drift = 0.0
1030            P_offset = 0.0
1031        elif allow_for_linear_displacement:  # N_t >= 2
1032            poly_coeffs = np.polyfit(
1033                time_grid[fit_start_index:], polarization_full[fit_start_index:], 1
1034            )
1035            J_drift = poly_coeffs[0]
1036            P_offset = poly_coeffs[1]
1037        else:  # N_t >= 2, P_offset = 0
1038            # P(t) = J_drift * t => J_drift = sum(P*t) / sum(t^2)
1039            # Ensure denominator is not zero if time_grid[fit_start_index:] is all zeros
1040            # (e.g., if fit_start_index is past end, or time_grid is [0,0,...])
1041            t_fit = time_grid[fit_start_index:]
1042            sum_t_squared = np.sum(t_fit**2)
1043            if sum_t_squared == 0:
1044                J_drift = 0.0
1045            else:
1046                J_drift = (
1047                    np.sum(polarization_full[fit_start_index:] * t_fit) / sum_t_squared
1048                )
1049            P_offset = 0.0
1050    else:  # Standard case N_t >= 2 and fit_start_index allows polyfit
1051        if allow_for_linear_displacement:
1052            poly_coeffs = np.polyfit(
1053                time_grid[fit_start_index:], polarization_full[fit_start_index:], 1
1054            )
1055            J_drift = poly_coeffs[0]
1056            P_offset = poly_coeffs[1]
1057        else:
1058            t_fit = time_grid[fit_start_index:]
1059            sum_t_squared = np.sum(t_fit**2)
1060            if sum_t_squared == 0:
1061                J_drift = 0.0
1062            else:
1063                J_drift = (
1064                    np.sum(polarization_full[fit_start_index:] * t_fit) / sum_t_squared
1065                )
1066            P_offset = 0.0
1067
1068    # Subtract the J_drift * t part from polarization_full. P_offset remains for now.
1069    polarization_oscillating = polarization_full - J_drift * time_grid
1070
1071    # Apply dephasing/windowing to the oscillating part (P - J_drift*t)
1072    # P_offset is part of the "DC" or very slow component, window it too.
1073    # The original code did: P_offset + window * (P - P_offset)
1074    # where P was (P_original - J_drift*t).
1075    # So, effectively: P_offset + window * (P_original - J_drift*t - P_offset)
1076
1077    if dephasing_time is None or dephasing_time == 0:
1078        # Soft window to zero if no explicit dephasing
1079        time_window = soft_window(time_grid, 0.5 * time_max, time_max)
1080    else:
1081        time_window = np.exp(-time_grid / dephasing_time) * soft_window(
1082            time_grid, 0.5 * time_max, time_max
1083        )
1084
1085    # Windowed polarization: P_offset is the value it decays from/to at t=0,
1086    # and the oscillating part (P_orig - J_drift*t - P_offset) is damped.
1087    processed_polarization = P_offset + time_window * (
1088        polarization_oscillating - P_offset
1089    )
1090
1091    permittivity_results = np.zeros_like(omega_array, dtype=complex)
1092
1093    for i, omega_val in enumerate(omega_array):
1094        # chi(omega) = FT[P_processed(t)](omega)
1095        # P_processed = P_offset_non_windowed + window * (P_osc - P_offset_non_windowed)
1096        # FT[P_processed] = FT[P_offset] + FT[window * (P_osc - P_offset)]
1097        # The original code integrated `processed_polarization` which is
1098        # P_offset + window * (P_original - J_drift*t - P_offset)
1099
1100        chi_omega = integrate_oscillating_function(
1101            time_grid, processed_polarization, omega_val * time_grid
1102        )
1103
1104        # Add analytical FT of the P_offset tail (if P_offset was not windowed to zero)
1105        # The `processed_polarization` already includes P_offset, partly windowed.
1106        # The original code adds: P_offset * 1j * np.exp(1j * omega * t_max) / omega
1107        # This looks like the FT of P_offset * Heaviside(t) if it extended from 0 to t_max
1108        # and then was abruptly cut, or FT of P_offset for t>t_max if window brought it to P_offset at t_max.
1109        # If processed_polarization(t_max) -> P_offset (due to window(t_max)=1),
1110        # and we assume P(t) = P_offset for t > t_max, its FT is P_offset * exp(i*omega*t_max) * (pi*delta(omega) + 1/(i*omega))
1111        # This term is tricky. The original `integrate_oscillating_function` handles up to t_max.
1112        # If the window makes processed_polarization(t_max) close to P_offset,
1113        # and we assume P(t) = P_offset for t > t_max, the integral from t_max to inf is
1114        # P_offset * integral_{t_max to inf} exp(i*omega*t) dt
1115        # = P_offset * [exp(i*omega*t) / (i*omega)]_{t_max to inf}
1116        # For convergence, Im(omega) > 0 or add damping. Assuming real omega, this diverges.
1117        # The term P_offset * 1j * np.exp(1j * omega * t_max) / omega
1118        # is -P_offset * exp(1j*omega*t_max) / (1j*omega). This is the upper limit of the integral.
1119        # It implies the lower limit (at infinity) is taken as zero.
1120        # This is the FT of P_offset * step_function(t_max - t) if integrated from -inf to t_max.
1121        # Or FT of P_offset for t > t_max, i.e. integral from t_max to infinity of P_offset*exp(iwt)
1122        # = P_offset * [exp(iwt)/(iw)]_{t_max to inf}. For this to be -P_offset*exp(iw*t_max)/(iw),
1123        # the exp(iw*inf) term must vanish, e.g. by small positive Im(w).
1124        # Let's assume this term correctly accounts for the P_offset tail beyond t_max.
1125        if not np.isclose(P_offset, 0.0):  # Only add if P_offset is significant
1126            chi_omega += P_offset * 1j * np.exp(1j * omega_val * time_max) / omega_val
1127
1128        # Add contribution from drift current J_drift / (omega * (omega + i*gamma_relax))
1129        if not disregard_drift_current and not np.isclose(J_drift, 0.0):
1130            denominator_drift = omega_val * (omega_val + 1j * momentum_relaxation_rate)
1131            if not np.isclose(
1132                denominator_drift, 0.0
1133            ):  # Avoid division by zero if omega=0 (already checked) or omega=-i*gamma
1134                chi_omega -= J_drift / denominator_drift
1135            # If denominator is zero (e.g. omega_val=0 or omega_val = -1j*gamma), this term is singular.
1136            # omega_val != 0 is asserted. If omega_val = -1j*gamma, it's a resonant condition.
1137
1138        permittivity_results[i] = 1.0 + 4 * np.pi * chi_omega
1139
1140    return permittivity_results
ArrayLike = numpy.ndarray[typing.Any, typing.Any]
def nextpow2(number: float) -> int:
def nextpow2(number: float) -> int:
    """Computes the exponent for the smallest power of 2 >= number.

    The exponent is computed exactly: integers via ``int.bit_length`` and
    floats via ``np.frexp``. A ``ceil(log2(number))`` formulation is off by
    one once log2 rounds to an integer, e.g. for 2**52 + 1.

    Args:
      number: A positive, finite number (int or float).

    Returns:
      The smallest integer exponent `exp` such that 2**exp >= number.

    Raises:
      ValueError: If number is not positive and finite (zero, negative, NaN
        or infinite).
    """
    if isinstance(number, (int, np.integer)):
        if number < 1:
            raise ValueError(f"nextpow2 requires a positive number, got {number!r}.")
        # For n >= 1: 2**k >= n  <=>  2**k > n - 1  <=>  k >= bit_length(n - 1).
        return (int(number) - 1).bit_length()

    value = float(number)
    if not (np.isfinite(value) and value > 0):
        raise ValueError(
            f"nextpow2 requires a positive finite number, got {number!r}."
        )
    # value = mantissa * 2**exponent with 0.5 <= mantissa < 1. An exact power
    # of two has mantissa 0.5 and needs one less than `exponent`.
    mantissa, exponent = np.frexp(value)
    return int(exponent) - 1 if mantissa == 0.5 else int(exponent)

Computes the exponent for the smallest power of 2 >= number.

Arguments:
  • number: The input number.
Returns:

The smallest integer exponent exp such that 2**exp >= number.

def soft_window( x_grid: numpy.ndarray[typing.Any, typing.Any], x_begin: float, x_end: float) -> numpy.ndarray[typing.Any, typing.Any]:
 34def soft_window(x_grid: ArrayLike, x_begin: float, x_end: float) -> ArrayLike:
 35    """Computes a soft window function.
 36
 37    The window smoothly transitions from 1 to 0 over the interval
 38    [min(x_begin, x_end), max(x_begin, x_end)].
 39    If x_begin <= x_end, it's a cosine-squared fall-off.
 40    If x_begin > x_end, it's a sine-squared rise-on.
 41
 42    Args:
 43      x_grid: A 1D array of x-coordinates.
 44      x_begin: The x-coordinate where the transition starts (value is 1).
 45      x_end: The x-coordinate where the transition ends (value is 0).
 46
 47    Returns:
 48      A 1D array of the same size as x_grid, containing the window values.
 49    """
 50    window = np.zeros_like(x_grid)
 51    x_min_transition = min(x_begin, x_end)
 52    x_max_transition = max(x_begin, x_end)
 53
 54    # Determine indices for different parts of the window
 55    indices_before_transition = np.where(x_grid < x_min_transition)[0]
 56    idx_transition_start = (
 57        indices_before_transition[-1] + 1 if indices_before_transition.size > 0 else 0
 58    )
 59
 60    indices_after_transition = np.where(x_grid > x_max_transition)[0]
 61    idx_transition_end = (
 62        indices_after_transition[0]
 63        if indices_after_transition.size > 0
 64        else len(x_grid)
 65    )
 66
 67    # Define the transition region
 68    x_transition = x_grid[idx_transition_start:idx_transition_end]
 69
 70    if x_begin <= x_end:  # Window goes from 1 down to 0
 71        window[:idx_transition_start] = 1.0
 72        if (
 73            idx_transition_end > idx_transition_start
 74            and x_max_transition > x_min_transition
 75        ):
 76            window[idx_transition_start:idx_transition_end] = (
 77                np.cos(
 78                    np.pi
 79                    / 2.0
 80                    * (x_transition - x_min_transition)
 81                    / (x_max_transition - x_min_transition)
 82                )
 83                ** 2
 84            )
 85        # Values after x_end remain 0 (initialized)
 86    else:  # Window goes from 0 up to 1 (x_begin > x_end, so x_min_transition = x_end)
 87        window[idx_transition_end:] = 1.0
 88        if (
 89            idx_transition_end > idx_transition_start
 90            and x_max_transition > x_min_transition
 91        ):
 92            window[idx_transition_start:idx_transition_end] = (
 93                np.sin(
 94                    np.pi
 95                    / 2.0
 96                    * (x_transition - x_min_transition)
 97                    / (x_max_transition - x_min_transition)
 98                )
 99                ** 2
100            )
101        # Values before x_begin (which is x_max_transition) remain 0 (initialized)
102    return window

Computes a soft window function.

The window transitions smoothly from 1 at x_begin to 0 at x_end, i.e. over the interval [min(x_begin, x_end), max(x_begin, x_end)]. If x_begin <= x_end, it is a cosine-squared fall-off (1 before the interval, 0 after it). If x_begin > x_end, it is a sine-squared rise (0 before the interval, 1 after it).

Arguments:
  • x_grid: A 1D array of x-coordinates.
  • x_begin: The x-coordinate where the transition starts (value is 1).
  • x_end: The x-coordinate where the transition ends (value is 0).
Returns:

A 1D array of the same size as x_grid, containing the window values.

def get_significant_part_indices_v1( array_data: numpy.ndarray[typing.Any, typing.Any], threshold: float = 1e-08) -> Tuple[int, int]:
def get_significant_part_indices_v1(
    array_data: ArrayLike, threshold: float = 1e-8
) -> Tuple[int, int]:
    """Returns the slice bounds (i1, i2) of the significant part of an array.

    An element is significant if
    abs(element) >= threshold * max(abs(array_data)). i1 is the first
    significant index before the peak (or the peak index if there is none).
    i2 is one past the last significant index at or after the peak (or one
    past the peak if there is none). Elements A[:i1] and A[i2:] are
    therefore "small".

    Args:
      array_data: The input 1D array.
      threshold: Relative significance threshold.

    Returns:
      A tuple (i1, i2) with i1 inclusive and i2 exclusive. An empty array
      gives (0, 0); an all-zero array gives (0, len(array_data)).
    """
    magnitudes = np.abs(array_data)
    if magnitudes.size == 0:
        return 0, 0

    peak = np.argmax(magnitudes)
    peak_value = magnitudes[peak]
    if peak_value == 0:
        return 0, len(array_data)

    is_significant = magnitudes >= threshold * peak_value
    leading = np.nonzero(is_significant[:peak])[0]
    trailing = np.nonzero(is_significant[peak:])[0]

    start = leading[0] if leading.size > 0 else peak
    if trailing.size > 0:
        stop = peak + trailing[-1] + 1
    else:
        stop = peak + 1
    return start, stop

Returns indices (i1, i2) for the slice A[i1:i2] containing the significant part of the array.

The significant part is defined relative to the maximum absolute value in the array. Elements A[:i1] and A[i2:] are considered "small".

Arguments:
  • array_data: The input 1D array.
  • threshold: The relative threshold to determine significance. Elements are significant if abs(element) >= threshold * max(abs(array_data)).
Returns:

A tuple (i1, i2) representing the start (inclusive) and end (exclusive) indices of the significant part.

def get_significant_part_indices_v2( array_data: numpy.ndarray[typing.Any, typing.Any], threshold: float = 1e-08) -> Tuple[int, int]:
def get_significant_part_indices_v2(
    array_data: ArrayLike, threshold: float = 1e-8
) -> Tuple[int, int]:
    """Returns indices (i1, i2) of the "small" elements that bracket the peak.

    An element is small if abs(element) < threshold * max(abs(array_data)).
    With idx_max the position of the largest magnitude:

    * i1 is the index of the last small element before idx_max, or 0 if
      there is none;
    * i2 is one past the index of the first small element at or after
      idx_max, or len(array_data) if there is none.

    A[i1] and A[i2 - 1] may themselves be small. The older description
    "none of the elements A[i1:i2] is small" is therefore inaccurate.
    A[i1 + 1 : i2 - 1] is the run of non-small elements when small elements
    exist on both sides of the peak.

    Args:
      array_data: The input 1D array.
      threshold: Relative smallness threshold.

    Returns:
      A tuple (i1, i2). An empty array gives (0, 0); an all-zero array gives
      (0, len(array_data)).
    """
    magnitudes = np.abs(array_data)
    if magnitudes.size == 0:
        return 0, 0

    peak = np.argmax(magnitudes)
    peak_value = magnitudes[peak]
    if peak_value == 0:
        return 0, len(array_data)

    is_small = magnitudes < threshold * peak_value
    small_before = np.nonzero(is_small[:peak])[0]
    small_after = np.nonzero(is_small[peak:])[0]  # offsets relative to peak

    start = small_before[-1] if small_before.size > 0 else 0
    if small_after.size > 0:
        stop = peak + small_after[0] + 1
    else:
        stop = len(array_data)
    return start, stop

Returns indices (i1, i2) based on parts of the array that are "small".

The returned indices are interpreted as follows. i1 is the index of the last element before the peak that is considered small (abs(element) < threshold * max_val), or 0 if there is none. i2 is one past the index of the first small element at or after the peak, or len(A) if there is none. The docstring of the original function was "Return a tuple (i1,i2) such that none of the elements A[i1:i2] is small". This can be misleading, because A[i1] and A[i2-1] may themselves be small by this definition. A slice such as A[i1+1 : i2-1] (equivalently, A[i1+1 : idx_first_small_after_peak]) better corresponds to "all elements are not small".

Arguments:
  • array_data: The input 1D array.
  • threshold: The relative threshold to determine smallness. Elements are small if abs(element) < threshold * max(abs(array_data)).
Returns:

A tuple (i1, i2).

def Fourier_filter( data: numpy.ndarray[typing.Any, typing.Any], time_step: float, spectral_window: numpy.ndarray[typing.Any, typing.Any], periodic: bool = False) -> numpy.ndarray[typing.Any, typing.Any]:
def Fourier_filter(
    data: ArrayLike,
    time_step: float,
    spectral_window: ArrayLike,
    periodic: bool = False,
) -> ArrayLike:
    """Applies a Fourier filter to time-series data.

    The data are Fourier transformed along the first axis, the spectrum is
    multiplied by a window W(|omega|), and the result is transformed back.

    Args:
      data: A 1D array, or an array of shape (num_time_points, ...) whose
            trailing axes index independent time series (e.g. a 2D array
            (num_time_points, num_series) with one series per column).
      time_step: The time step between data points.
      spectral_window: A 2D array (num_window_points, 2) where the first
                       column contains circular frequencies (ascending) and
                       the second column contains the window function W(omega).
                       W is linearly interpolated at |omega|; it is taken as 1
                       below the first tabulated frequency and as 0 above the
                       last one. An empty window disables filtering.
      periodic: If True, data is assumed to be periodic. Otherwise, a
                time-reversed copy of the data is appended before the FFT to
                reduce edge effects.

    Returns:
      The filtered data, with the same shape as the input `data`. The result
      is real unless `data` has a complex dtype.
    """
    data = np.asarray(data)
    spectral_window = np.asarray(spectral_window)
    if spectral_window.shape[0] == 0:
        return data.copy()  # Always hand back a new array, as the filtering path does

    original_shape = data.shape
    num_time_points = original_shape[0]
    # Collapse all trailing axes into columns. A reshape (possibly a view) is
    # enough because nothing below modifies `columns` in place.
    num_columns = int(np.prod(original_shape[1:], dtype=int))
    columns = data.reshape(num_time_points, num_columns)

    if not periodic:
        # Appending the time-reversed signal makes the extended signal
        # continuous across the DFT's implicit wrap-around.
        columns = np.concatenate((columns, columns[::-1]), axis=0)
    num_fft_points = columns.shape[0]

    spectrum = np.fft.fftshift(np.fft.fft(columns, axis=0), axes=0)

    # Circular frequencies matching the fftshift-ed spectrum (zero in the middle).
    delta_omega = 2 * np.pi / (num_fft_points * time_step)
    omega_grid = delta_omega * (np.arange(num_fft_points) - num_fft_points // 2)

    window_values = np.interp(
        np.abs(omega_grid),
        spectral_window[:, 0],
        spectral_window[:, 1],
        left=1.0,  # frequencies below the tabulated range pass unchanged
        right=0.0,  # frequencies above the tabulated range are removed
    )
    spectrum *= window_values[:, np.newaxis]

    filtered = np.fft.ifft(np.fft.ifftshift(spectrum, axes=0), axis=0)
    filtered = filtered[:num_time_points]  # drop the mirrored half, if any

    # Decide by dtype, not by values, so the output type is predictable.
    if not np.iscomplexobj(data):
        filtered = filtered.real

    return filtered.reshape(original_shape)

Applies a Fourier filter to time-series data.

The function performs a Fourier transform, multiplies by a spectral window, and then performs an inverse Fourier transform.

Arguments:
  • data: A 1D array or a 2D array (num_time_points, num_series) where each column is a time series.
  • time_step: The time step between data points.
  • spectral_window: A 2D array (num_window_points, 2) where the first column contains circular frequencies (ascending) and the second column contains the window function W(omega).
  • periodic: If True, data is assumed to be periodic. Otherwise, a time-reversed copy of the data is appended before the FFT to reduce edge effects.
Returns:

The filtered data, with the same shape as the input data.

def polyfit_with_weights( x_coords: numpy.ndarray[typing.Any, typing.Any], y_values: numpy.ndarray[typing.Any, typing.Any], weights: numpy.ndarray[typing.Any, typing.Any], degree: int) -> numpy.ndarray[typing.Any, typing.Any]:
def polyfit_with_weights(
    x_coords: ArrayLike, y_values: ArrayLike, weights: ArrayLike, degree: int
) -> ArrayLike:
    """Performs a weighted least-squares polynomial fit.

    Finds the polynomial p of the given degree that minimizes
    sum_k (w_k * (y_k - p(x_k)))**2, i.e. the same objective as
    ``np.polyfit(x_coords, y_values, degree, w=weights)``.

    Args:
      x_coords: 1D array of sample point x-coordinates.
      y_values: 1D array of sample point y-values to fit (real or complex).
      weights: 1D array of weights for each sample point (or a scalar).
      degree: Degree of the fitting polynomial.

    Returns:
      A 1D array of polynomial coefficients [p_degree, ..., p_1, p_0].

    Raises:
      numpy.linalg.LinAlgError: If the weighted design matrix is rank
        deficient (e.g. fewer distinct points with nonzero weight than
        coefficients), so that the fit is not unique.
    """
    # Float abscissae avoid silent integer overflow in high powers of x.
    x_coords = np.asarray(x_coords, dtype=np.float64)
    y_values = np.asarray(y_values)
    weights = np.broadcast_to(np.asarray(weights), x_coords.shape)
    num_coeffs = degree + 1

    # Columns are x**0, x**1, ..., x**degree.
    vandermonde = np.vander(x_coords, num_coeffs, increasing=True)
    # Scaling each row by its weight reduces the weighted problem to ordinary
    # least squares. Solving it directly, instead of forming the normal
    # equations A^T A p = A^T b, avoids squaring the condition number.
    design_matrix = weights[:, np.newaxis] * vandermonde
    rhs = weights * y_values

    solution_coeffs, _, rank, _ = np.linalg.lstsq(design_matrix, rhs, rcond=None)
    if rank < num_coeffs:
        raise np.linalg.LinAlgError(
            "Weighted polynomial fit is rank deficient; the coefficients are not unique."
        )
    return solution_coeffs[::-1]  # Conventional order (highest power first)

Performs a weighted least-squares polynomial fit.

Arguments:
  • x_coords: 1D array of sample point x-coordinates.
  • y_values: 1D array of sample point y-values to fit.
  • weights: 1D array of weights for each sample point.
  • degree: Degree of the fitting polynomial.
Returns:

A 1D array of polynomial coefficients [p_degree, ..., p_1, p_0].

def Fourier_transform( time_points: numpy.ndarray, y_data: numpy.ndarray, target_frequencies: Optional[numpy.ndarray] = None, is_periodic: bool = False, pulse_center_times: Union[float, numpy.ndarray, NoneType] = None) -> Union[numpy.ndarray, Tuple[numpy.ndarray, numpy.ndarray]]:
def Fourier_transform(
    time_points: np.ndarray,
    y_data: np.ndarray,
    target_frequencies: Optional[np.ndarray] = None,
    is_periodic: bool = False,
    pulse_center_times: Optional[Union[float, np.ndarray]] = None,
) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
    r"""Applies the Fast Fourier Transform (FFT) in an easy-to-use way.

    This function computes the Fourier transform of time-dependent data Y(t),
    defined as:
        $F[Y](\omega) = \int_{-\infty}^\infty dt Y(t) \exp(i \omega t)$

    The data are linearly interpolated (in floating point, also for integer
    input) onto a uniform FFT time grid. The FFT result is then either
    returned on its own frequency grid or linearly interpolated onto
    `target_frequencies`.

    Args:
        time_points: A 1D NumPy array of shape (N_t,) representing the time
            discretization of Y(t). Must be sorted in ascending order. N_t must
            be greater than 1.
        y_data: A NumPy array of shape (N_t, N_data) or (N_t,) containing the
            time-dependent data to Fourier transform. If 1D, it's treated as
            a single data series.
        target_frequencies: An optional 1D NumPy array of shape (N_omega,)
            specifying the circular frequencies at which to compute the
            transform. Must be sorted in ascending order. If None, the transform
            is returned on an optimal internal frequency grid. A single
            frequency is handled by direct (trapezoidal) integration.
        is_periodic: A boolean indicating if Y(t) is assumed to be periodic.
            If True, the time step is expected to be constant, and
            Y(t_end + dt) = Y(t_start). If False, precautions are taken to
            minimize artifacts from the DFT's implicit periodicity assumption.
        pulse_center_times: An optional scalar, or a 1D NumPy array of length
            N_data (or 1). If the input data represents a pulse not centered
            in its time window, specifying the center of each pulse (t0) helps
            avoid interpolation artifacts. If None, the sample closest to the
            center of the time window is used.

    Returns:
        If `target_frequencies` is None:
            A tuple `(transformed_y, fft_omega_grid)`, where:
            - `transformed_y` (np.ndarray): The Fourier-transformed data,
              shape (N_fft, N_data) or (N_fft,) if input `y_data` was 1D.
            - `fft_omega_grid` (np.ndarray): The array of circular frequencies
              (shape (N_fft,)) corresponding to `transformed_y`.
        If `target_frequencies` is not None:
            An np.ndarray of shape (N_omega, N_data) or (N_omega,) containing
            the Fourier transform of `y_data` evaluated at the specified
            `target_frequencies` (empty if `target_frequencies` is empty).

    Raises:
        ValueError: If `time_points` has fewer than 2 elements, or if
            `pulse_center_times` cannot be broadcast to N_data entries.
    """
    _MAX_FFT_POINTS = 2**20
    time_points = np.asarray(time_points)
    y_data = np.asarray(y_data)
    num_time_points = len(time_points)
    if num_time_points <= 1:
        raise ValueError("Input 'time_points' array must have more than one element.")

    original_y_ndim = y_data.ndim
    if original_y_ndim == 1:
        # Treat 1D input as a single column so every path below works on 2D data.
        y_data_processed = y_data.reshape((num_time_points, 1))
    else:
        y_data_processed = y_data
    num_data_series = y_data_processed.shape[1]

    if target_frequencies is not None:
        target_frequencies = np.asarray(target_frequencies).reshape(-1)
        if target_frequencies.size == 0:
            empty_spectrum = np.zeros((0, num_data_series), dtype=np.complex128)
            return empty_spectrum.reshape(0) if original_y_ndim == 1 else empty_spectrum

    min_time_step = np.min(np.diff(time_points))

    # A single target frequency: integrate directly, no FFT needed.
    if target_frequencies is not None and target_frequencies.size == 1:
        integrand = y_data_processed * np.exp(
            1j * target_frequencies[0] * time_points.reshape((num_time_points, 1))
        )
        # np.trapezoid is the NumPy >= 2.0 name of np.trapz.
        trapezoid = getattr(np, "trapezoid", None) or np.trapz
        output_spectrum = trapezoid(integrand, time_points, axis=0)
        if is_periodic:
            # Adding half of both end points turns the trapezoidal rule into a
            # plain sum over all N_t samples, i.e. an integral over one period.
            output_spectrum = output_spectrum + 0.5 * min_time_step * (
                integrand[0, :] + integrand[-1, :]
            )
        output_spectrum = output_spectrum.reshape(1, num_data_series)
        return output_spectrum.reshape(1) if original_y_ndim == 1 else output_spectrum

    period_duration: Optional[float] = None
    if is_periodic:
        # With a constant step, one period is the sampled span plus one step.
        period_duration = time_points[-1] - time_points[0] + np.mean(np.diff(time_points))

    # Frequency resolution that the FFT grid should approximately reach.
    if target_frequencies is None:
        dw_target = 2 * np.pi / (min_time_step * num_time_points)
    else:
        dw_target = np.min(np.diff(target_frequencies))

    effective_time_step = min_time_step
    if is_periodic:
        num_fft_points = num_time_points
        if target_frequencies is not None:
            # Same step, longer grid, until the resolution is close to dw_target.
            while (
                2 * np.pi / (effective_time_step * num_fft_points) > 1.1 * dw_target
                and num_fft_points < _MAX_FFT_POINTS
            ):
                num_fft_points *= 2
    else:
        exponent = int(
            round(np.log(2 * np.pi / (min_time_step * dw_target)) / np.log(2.0))
        )
        # A very coarse target grid gives a negative exponent; 2**negative would
        # be a float and break the array construction below.
        num_fft_points = 2 ** max(exponent, 0)
        # The FFT time window must cover the span of the input data.
        while (
            (num_fft_points - 1) * min_time_step < time_points[-1] - time_points[0]
            and num_fft_points < _MAX_FFT_POINTS
        ):
            num_fft_points *= 2
        if target_frequencies is not None:
            # Refine the step (keeping the window length) until the targets lie
            # inside the Nyquist range.
            while num_fft_points < _MAX_FFT_POINTS and (
                target_frequencies[-1]
                > np.pi / effective_time_step * (1.0 - 2.0 / num_fft_points)
                or -target_frequencies[0] > np.pi / effective_time_step
            ):
                effective_time_step /= 2.0
                num_fft_points *= 2

    # FFT time grid in ifftshift order: t = 0 sits at index 0.
    fft_time_grid = effective_time_step * np.fft.ifftshift(
        np.arange(num_fft_points) - num_fft_points // 2
    )

    # One pulse center t0 per data series.
    if pulse_center_times is None:
        time_window_center = 0.5 * (time_points[0] + time_points[-1])
        idx_center = np.argmin(np.abs(time_points - time_window_center))
        pulse_centers = np.full(num_data_series, float(time_points[idx_center]))
    else:
        pulse_centers = np.broadcast_to(
            np.asarray(pulse_center_times, dtype=float).reshape(-1),
            (num_data_series,),
        )

    # Interpolate each series, shifted by its t0, onto the FFT grid. The
    # buffer is at least float64 so integer data are not truncated.
    grid_dtype = np.result_type(y_data_processed.dtype, np.float64)
    y_on_fft_grid = np.empty((num_fft_points, num_data_series), dtype=grid_dtype)
    for j_col in range(num_data_series):
        shifted_time_points = time_points - pulse_centers[j_col]
        if is_periodic:
            y_on_fft_grid[:, j_col] = np.interp(
                fft_time_grid,
                shifted_time_points,
                y_data_processed[:, j_col],
                period=period_duration,
            )
        else:
            # Zero outside the original time range.
            y_on_fft_grid[:, j_col] = np.interp(
                fft_time_grid,
                shifted_time_points,
                y_data_processed[:, j_col],
                left=0.0,
                right=0.0,
            )

    # ifft uses exp(+i...) and a 1/N factor; multiplying by N*dt turns the sum
    # into an approximation of the integral in the definition above.
    y_fft = np.fft.fftshift(np.fft.ifft(y_on_fft_grid, axis=0), axes=0) * (
        num_fft_points * effective_time_step
    )
    fft_omega_grid = (2 * np.pi / (effective_time_step * num_fft_points)) * (
        np.arange(num_fft_points) - num_fft_points // 2
    )

    # The FFT transformed Y(t' + t0); F[Y](w) = exp(i w t0) * F[Y(. + t0)](w).
    if target_frequencies is None:
        phase_correction = np.exp(
            1j * pulse_centers[np.newaxis, :] * fft_omega_grid[:, np.newaxis]
        )
        y_fft_corrected = y_fft * phase_correction
        if original_y_ndim == 1:
            return y_fft_corrected.flatten(), fft_omega_grid
        return y_fft_corrected, fft_omega_grid

    # Interpolate the un-phased spectrum (smooth for a well-chosen t0), then
    # apply the phase at the exact target frequencies.
    interpolated_fft = np.empty(
        (len(target_frequencies), num_data_series), dtype=np.complex128
    )
    for j_col in range(num_data_series):
        interpolated_fft[:, j_col] = np.interp(
            target_frequencies,
            fft_omega_grid,
            y_fft[:, j_col],
            left=0.0,
            right=0.0,
        )
    phase_correction = np.exp(
        1j * pulse_centers[np.newaxis, :] * target_frequencies[:, np.newaxis]
    )
    output_spectrum = interpolated_fft * phase_correction

    if original_y_ndim == 1:
        return output_spectrum.flatten()
    return output_spectrum

Applies the Fast Fourier Transform (FFT) in an easy-to-use way.

This function computes the Fourier transform of time-dependent data Y(t), defined as: $F[Y](\omega) = \int_{-\infty}^\infty dt\, Y(t) \exp(i \omega t)$

Arguments:
  • time_points: A 1D NumPy array of shape (N_t,) representing the time discretization of Y(t). Must be sorted in ascending order. N_t must be greater than 1.
  • y_data: A NumPy array of shape (N_t, N_data) or (N_t,) containing the time-dependent data to Fourier transform. If 1D, it's treated as a single data series.
  • target_frequencies: An optional 1D NumPy array of shape (N_omega,) specifying the circular frequencies at which to compute the transform. Must be sorted in ascending order. If None, the transform is returned on an optimal internal frequency grid.
  • is_periodic: A boolean indicating if Y(t) is assumed to be periodic. If True, the time step is expected to be constant, and Y(t_end + dt) = Y(t_start). If False, precautions are taken to minimize artifacts from the DFT's implicit periodicity assumption.
  • pulse_center_times: An optional scalar or 1D NumPy array of length N_data. If the input data represents a pulse not centered in its time window, specifying the center of each pulse (t0) helps avoid interpolation artifacts. If None, the center of the time window is used.
Returns:

If target_frequencies is None: a tuple (transformed_y, fft_omega_grid), where
  • transformed_y (np.ndarray): the Fourier-transformed data, of shape (N_fft, N_data), or (N_fft,) if the input y_data was 1D;
  • fft_omega_grid (np.ndarray): the array of circular frequencies, of shape (N_fft,), corresponding to transformed_y.

If target_frequencies is not None: an np.ndarray of shape (N_omega, N_data), or (N_omega,) if y_data was 1D, containing the Fourier transform of y_data evaluated at the specified target_frequencies.

Raises:
  • ValueError: If time_points has fewer than 2 elements.
def inverse_Fourier_transform( omega_points: numpy.ndarray[typing.Any, typing.Any], data_series: numpy.ndarray[typing.Any, typing.Any], time_points_target: Optional[numpy.ndarray[Any, Any]] = None, is_periodic: bool = False, frequency_offset: Union[float, numpy.ndarray[Any, Any], NoneType] = None) -> Union[Tuple[numpy.ndarray[Any, Any], numpy.ndarray[Any, Any]], numpy.ndarray[Any, Any]]:
def inverse_Fourier_transform(
    omega_points: ArrayLike,
    data_series: ArrayLike,
    time_points_target: Optional[ArrayLike] = None,
    is_periodic: bool = False,
    frequency_offset: Optional[Union[float, ArrayLike]] = None,
) -> Union[Tuple[ArrayLike, ArrayLike], ArrayLike]:
    r"""Applies inverse FFT to frequency-dependent data.

    Computes $F^{-1}[Y](t) = \frac{1}{2 \pi} \int_{-\infty}^\infty d\omega\, Y(\omega) \exp(-i \omega t)$.

    The work is delegated to `Fourier_transform`, with the frequency axis
    playing the role of its "time" axis. That function returns
    $G(s) = \int d\omega\, Y(\omega) \exp(i s \omega)$, and
    $F^{-1}[Y](t) = G(-t) / (2 \pi)$.

    Args:
      omega_points: 1D array of circular frequencies (N_omega), sorted ascending.
      data_series: Frequency-dependent data (N_omega) or (N_omega, N_data_series).
      time_points_target: Optional 1D array of time points (N_t), sorted
                          ascending. If None, times are determined by the FFT grid.
      is_periodic: If True, data_series is assumed periodic in frequency
                   (forwarded as `is_periodic` to `Fourier_transform`).
      frequency_offset: Scalar or array (N_data_series). Central frequency
                        offset(s) if data is not centered at omega=0
                        (forwarded as `pulse_center_times`).

    Returns:
      If time_points_target is None:
        A tuple (transformed_data, time_grid), with time_grid ascending.
      If time_points_target is provided:
        Transformed data at the given time_points_target, in the same order.
    """
    if time_points_target is None:
        # G is returned on its own ascending grid s; the time grid is t = -s.
        g_values, s_grid = Fourier_transform(
            time_points=omega_points,
            y_data=data_series,
            target_frequencies=None,
            is_periodic=is_periodic,
            pulse_center_times=frequency_offset,
        )
        # Reversing both arrays keeps the returned time grid ascending.
        return g_values[::-1] / (2 * np.pi), -s_grid[::-1]

    target_times = np.asarray(time_points_target)
    # Fourier_transform needs ascending target "frequencies": s = -t reversed.
    g_at_neg_t = Fourier_transform(
        time_points=omega_points,
        y_data=data_series,
        target_frequencies=-target_times[::-1],
        is_periodic=is_periodic,
        pulse_center_times=frequency_offset,
    )
    # Undo the reversal so that row i corresponds to target_times[i].
    return g_at_neg_t[::-1] / (2 * np.pi)

Applies inverse FFT to frequency-dependent data.

Computes $F^{-1}[Y](t) = \frac{1}{2\pi} \int_{-\infty}^\infty d\omega\, Y(\omega) \exp(-i \omega t)$.

Arguments:
  • omega_points: 1D array of circular frequencies (N_omega), sorted.
  • data_series: Frequency-dependent data (N_omega) or (N_omega, N_data_series).
  • time_points_target: Optional 1D array of time points (N_t), sorted. If None, times are determined by IFFT.
  • is_periodic: If True, data_series is assumed periodic in frequency.
  • frequency_offset: Scalar or array (N_data_series). Central frequency offset(s) if data is not centered at omega=0.
Returns:

If time_points_target is None: a tuple (transformed_data, time_grid).

If time_points_target is provided: the transformed data interpolated at the given time_points_target.

def find_zero_crossings( x_values: numpy.ndarray[typing.Any, typing.Any], y_values: numpy.ndarray[typing.Any, typing.Any]) -> numpy.ndarray[typing.Any, typing.Any]:
def find_zero_crossings(x_values: ArrayLike, y_values: ArrayLike) -> ArrayLike:
    """Finds all x-values where linearly interpolated y(x) = 0.

    Two kinds of zeros are reported:
      * sample points where y is exactly zero, and
      * interior points of intervals where y strictly changes sign, located by
        linear interpolation between the two samples.
    Comparisons with zero are exact, so the result does not depend on the
    overall scale of y.

    Args:
      x_values: 1D array of x-coordinates, sorted ascending, no duplicates.
      y_values: 1D array of y-coordinates, same shape as x_values.

    Returns:
      A sorted 1D array of unique x-values where y(x) is zero. Empty if there
      are none.

    Raises:
      ValueError: If x_values and y_values are non-empty with different lengths.
    """
    x_values = np.asarray(x_values)
    y_values = np.asarray(y_values)
    if x_values.size == 0 or y_values.size == 0:
        return np.array([])
    if x_values.size != y_values.size:
        raise ValueError("x_values and y_values must have the same length.")

    y_left, y_right = y_values[:-1], y_values[1:]
    # Compare signs instead of testing y_left * y_right < 0: the product of two
    # tiny values can underflow to zero and hide a crossing.
    strict_crossing = np.sign(y_left) * np.sign(y_right) < 0

    x1 = x_values[:-1][strict_crossing]
    x2 = x_values[1:][strict_crossing]
    y1 = y_left[strict_crossing]
    y2 = y_right[strict_crossing]
    # Root of the straight line through (x1, y1) and (x2, y2).
    interpolated_zeros = x1 - y1 * (x2 - x1) / (y2 - y1)

    # Samples lying exactly on zero (also covers segments that touch zero).
    grid_zeros = x_values[y_values == 0]

    return np.unique(np.concatenate((interpolated_zeros, grid_zeros)))

Finds all x-values where linearly interpolated y(x) = 0.

Arguments:
  • x_values: 1D array of x-coordinates, sorted ascending, no duplicates.
  • y_values: 1D array of y-coordinates, same shape as x_values.
Returns:

A 1D array of x-values where y(x) crosses zero. Empty if no crossings.

def find_extrema_positions( x_values: numpy.ndarray[typing.Any, typing.Any], y_values: numpy.ndarray[typing.Any, typing.Any]) -> numpy.ndarray[typing.Any, typing.Any]:
def find_extrema_positions(x_values: ArrayLike, y_values: ArrayLike) -> ArrayLike:
    """Finds x-positions of local extrema in y(x).

    Extrema are located where the finite-difference derivative y'(x),
    sampled at the midpoints between adjacent grid points, crosses zero
    (the crossings are found with ``find_zero_crossings``).

    Args:
      x_values: 1D array of x-coordinates, sorted ascending, no duplicates.
      y_values: 1D array of y-coordinates, same length as x_values.

    Returns:
      A 1D array of x-values where y(x) has local extrema. Empty if none.

    Raises:
      ValueError: If x_values and y_values have different lengths.
    """
    x_values = np.asarray(x_values)
    y_values = np.asarray(y_values)
    # Validate lengths before the short-input early return, so mismatched
    # inputs are always reported instead of silently yielding "no extrema".
    if len(x_values) != len(y_values):
        raise ValueError("x_values and y_values must have the same length.")
    if len(x_values) < 2:  # a finite-difference derivative needs two points
        return np.array([])

    delta_x = np.diff(x_values)
    delta_y = np.diff(y_values)
    midpoints = 0.5 * (x_values[1:] + x_values[:-1])

    # Duplicate x-values violate the precondition; drop zero-width intervals
    # rather than dividing by zero.
    nonzero_dx = delta_x != 0
    if not np.all(nonzero_dx):
        delta_x = delta_x[nonzero_dx]
        delta_y = delta_y[nonzero_dx]
        midpoints = midpoints[nonzero_dx]
        if delta_x.size == 0:  # nothing left to differentiate
            return np.array([])

    return find_zero_crossings(midpoints, delta_y / delta_x)

Finds x-positions of local extrema in y(x).

Extrema are found where the derivative y'(x) (approximated by finite differences) crosses zero.

Arguments:
  • x_values: 1D array of x-coordinates, sorted ascending, no duplicates.
  • y_values: 1D array of y-coordinates, same shape as x_values.
Returns:

A 1D array of x-values where y(x) has local extrema. Empty if none.

def minimize_imaginary_parts( complex_array: numpy.ndarray[typing.Any, typing.Any]) -> numpy.ndarray[typing.Any, typing.Any]:
def minimize_imaginary_parts(complex_array: ArrayLike) -> ArrayLike:
    """Rotates a complex array by a global phase to make it as real as possible.

    Multiplies ``complex_array`` by ``exp(1j*phi)``, where ``phi`` minimizes
    ``sum(imag(exp(1j*phi) * complex_array)**2)``. Rotating by an extra
    ``pi`` leaves that sum unchanged, so ``phi`` is reduced to the range
    [-pi/2, pi/2], i.e. the smallest rotation that attains the minimum.

    Args:
      complex_array: Complex values; a NumPy array or anything
        ``np.asarray`` accepts.

    Returns:
      The phase-rotated complex NumPy array (a copy of the input if empty).
    """
    z = np.asarray(complex_array)
    if z.size == 0:
        return z.copy()

    # With z = X + iY the imaginary "energy" is
    #   E(phi) = sum((X sin(phi) + Y cos(phi))**2),
    # whose stationary points satisfy tan(2 phi) = 2 sum(XY) / sum(Y^2 - X^2).
    # arctan2 handles the signs and a zero denominator.
    x_part = z.real
    y_part = z.imag
    stationary_phi = 0.5 * np.arctan2(
        2 * np.sum(x_part * y_part), np.sum(y_part**2 - x_part**2)
    )

    def imag_energy(angle):
        # Sum of squared imaginary parts after rotating z by `angle`.
        return np.sum((z * np.exp(1j * angle)).imag ** 2)

    # phi and phi + pi/2 are the two stationary points (one maximum, one
    # minimum, in some order); keep whichever gives the smaller energy.
    phi = stationary_phi
    if imag_energy(stationary_phi + 0.5 * np.pi) < imag_energy(stationary_phi):
        phi = stationary_phi + 0.5 * np.pi

    # Shift by a multiple of pi into [-pi/2, pi/2]; np.round rounds half to
    # even, so both endpoints are attainable.
    phi -= np.pi * np.round(phi / np.pi)
    return z * np.exp(1j * phi)

Rotates a complex array by a phase to make it as close as possible to being real-valued.

Multiplies complex_array by exp(1j*phi) where phi is chosen to minimize sum(imag(exp(1j*phi) * complex_array)**2).

Arguments:
  • complex_array: A NumPy array of complex numbers.
Returns:

The phase-rotated complex NumPy array.

def integrate_oscillating_function( x_values: numpy.ndarray[typing.Any, typing.Any], func_values: numpy.ndarray[typing.Any, typing.Any], phase_values: numpy.ndarray[typing.Any, typing.Any], phase_step_threshold: float = 0.001) -> numpy.ndarray[typing.Any, typing.Any]:
def integrate_oscillating_function(
    x_values: ArrayLike,
    func_values: ArrayLike,
    phase_values: ArrayLike,
    phase_step_threshold: float = 1e-3,
) -> ArrayLike:
    r"""Integrates f(x) * exp(i * phi(x)) for quickly oscillating functions.

    On every interval [x_k, x_{k+1}] both f and phi are interpolated
    linearly, and the integral of f(x) * exp(i * phi(x)) is evaluated in
    closed form. The result therefore stays accurate even when phi changes
    by much more than 2*pi between neighbouring samples. Define
    dx = x_{k+1} - x_k, df = f_{k+1} - f_k, dphi = phi_{k+1} - phi_k and
    phi_mid = (phi_k + phi_{k+1}) / 2. The exact contribution of one
    interval is

        dx * exp(i*phi_mid) / dphi**2 * (
            exp(i*dphi/2) * (df - i*f_{k+1}*dphi)
            - exp(-i*dphi/2) * (df - i*f_k*dphi))

    For |dphi| < phase_step_threshold, the division by dphi**2 suffers from
    cancellation. In that case the Taylor expansion of the same expression
    is used instead:

        dx * exp(i*phi_mid) * (
            (f_k + f_{k+1}) / 2 * (1 - dphi**2 / 24) + i * dphi * df / 12)

    Args:
      x_values: 1D array of sorted x-coordinates.
      func_values: Array of function values f(x). Can be 1D (N_x) or 2D
                   (N_x, N_series).
      phase_values: Array of real-valued phase phi(x). Same shape as func_values.
      phase_step_threshold: Small positive number. Intervals with
                            |dphi| below it use the Taylor expansion instead
                            of the exact formula (avoids dividing by a tiny
                            dphi**2).

    Returns:
      A complex scalar (1D func_values) or a 1D array of length N_series
      (2D func_values) of integral results.

    Raises:
      ValueError: If the inputs differ in length along axis 0, if
        phase_values has a non-negligible imaginary part, or if a 2D
        func_values and phase_values differ in shape.
    """
    x_values = np.asarray(x_values)
    func_values = np.asarray(func_values)
    phase_values = np.asarray(phase_values)

    # Input validation
    if not (x_values.shape[0] == func_values.shape[0] == phase_values.shape[0]):
        raise ValueError(
            "x_values, func_values, and phase_values must have "
            "the same length along the integration axis (axis 0)."
        )
    if not np.allclose(np.imag(phase_values), 0):  # Ensure phase is real
        raise ValueError("phase_values must be real-valued.")
    if func_values.ndim > 1 and func_values.shape != phase_values.shape:
        raise ValueError(
            "If func_values is 2D, phase_values must have the exact same shape."
        )
    # Discard a (negligible) imaginary part so the formulas below stay real in phi.
    phase_values = np.real(phase_values)

    delta_x = np.diff(x_values)
    f1, f2 = func_values[:-1], func_values[1:]  # f(x_k), f(x_{k+1})
    delta_f = f2 - f1
    phi1, phi2 = phase_values[:-1], phase_values[1:]  # phi(x_k), phi(x_{k+1})
    delta_phi = phi2 - phi1

    # dx has shape (N_x - 1,); give it trailing unit axes so it broadcasts
    # against (N_x - 1, N_series) when func_values is 2D.
    reshape_dims = (-1,) + (1,) * (func_values.ndim - 1)
    prefactor = delta_x.reshape(reshape_dims) * np.exp(0.5j * (phi1 + phi2))

    integral_segments = np.zeros(prefactor.shape, dtype=complex)

    # Small phase steps: second-order Taylor expansion of the exact formula.
    # The 1/12 coefficient comes from int_{-1/2}^{1/2} s^2 ds.
    small = np.abs(delta_phi) < phase_step_threshold
    if np.any(small):
        dphi = delta_phi[small]
        integral_segments[small] = prefactor[small] * (
            0.5 * (f1[small] + f2[small]) * (1.0 - dphi**2 / 24.0)
            + (1j / 12.0) * dphi * delta_f[small]
        )

    # Large phase steps: exact integral of linear f times exp(i * linear phi).
    large = ~small
    if np.any(large):
        dphi = delta_phi[large]
        half_step = np.exp(0.5j * dphi)
        bracket = half_step * (delta_f[large] - 1j * f2[large] * dphi) - (
            delta_f[large] - 1j * f1[large] * dphi
        ) / half_step
        integral_segments[large] = prefactor[large] / dphi**2 * bracket

    return np.sum(integral_segments, axis=0)

Integrates f(x) * exp(i * phi(x)) for quickly oscillating functions.

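Approximates f(x) and phi(x) as linear on each interval between neighbouring grid points and integrates f(x) * exp(i * phi(x)) dx analytically on each interval, so the result remains accurate even when phi(x) changes rapidly between samples.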

Arguments:
  • x_values: 1D array of sorted x-coordinates.
  • func_values: Array of function values f(x). Can be 1D (N_x) or 2D (N_x, N_series).
  • phase_values: Array of real-valued phase phi(x). Same shape as func_values.
  • phase_step_threshold: Small positive number. Prevents division by small d_phi in the integration formula.
Returns:

A scalar or 1D array (N_series) of integral results.

def calculate_permittivity_from_delta_polarization( time_step: float, polarization_delta_response: numpy.ndarray[typing.Any, typing.Any], omega_array: numpy.ndarray[typing.Any, typing.Any], momentum_relaxation_rate: float = 0.0, dephasing_time: Optional[float] = None, disregard_drift_current: bool = False, allow_for_linear_displacement: bool = True) -> numpy.ndarray[typing.Any, typing.Any]:
 969def calculate_permittivity_from_delta_polarization(
 970    time_step: float,
 971    polarization_delta_response: ArrayLike,  # P_delta
 972    omega_array: ArrayLike,
 973    momentum_relaxation_rate: float = 0.0,
 974    dephasing_time: Optional[float] = None,
 975    disregard_drift_current: bool = False,
 976    allow_for_linear_displacement: bool = True,
 977) -> ArrayLike:
 978    r"""Evaluates permittivity from polarization induced by E(t) = delta(t).
 979
 980    Handles drift currents and coherent oscillations in the polarization response.
 981    The relationship is $\epsilon(\omega) = 1 + 4 \pi \chi(\omega)$, where
 982    $P(\omega) = \chi(\omega) E(\omega)$, and for $E(t)=\delta(t)$, $E(\omega)=1$.
 983    So $\chi(\omega) = P_{\delta}(\omega)$.
 984
 985    Args:
 986      time_step: Time step (atomic units) of the polarization grid.
 987      polarization_delta_response: 1D array of polarization response P(t)
 988                                   (atomic units) induced by E(t)=delta(t).
 989                                   P_delta[0] corresponds to t=time_step.
 990      omega_array: 1D array of circular frequencies (a.u.) for permittivity calculation.
 991                   All frequencies must be non-zero.
 992      momentum_relaxation_rate: If non-zero, models Drude-like momentum relaxation
 993                                (gamma in 1/(omega*(omega+i*gamma))).
 994      dephasing_time: If not None/zero, an exponential decay (rate 1/dephasing_time)
 995                      is applied to coherent dipole oscillations in P_delta.
 996      disregard_drift_current: If True, the J_drift component of polarization
 997                               has no effect on the result.
 998      allow_for_linear_displacement: If True, fits P(t) ~ J_drift*t + P_offset.
 999                                     If False, fits P(t) ~ J_drift*t (P_offset=0).
1000
1001    Returns:
1002      A complex array (same shape as omega_array) of permittivity values.
1003    """
1004    if not np.all(omega_array != 0):
1005        raise ValueError("All elements in omega_array must be non-zero.")
1006
1007    # Construct time grid and full polarization array P(t), P(0)=0
1008    num_p_delta_points = polarization_delta_response.size
1009    # P_delta starts at t=dt, so P has N_t = num_p_delta_points + 1 points
1010    num_time_points = num_p_delta_points + 1
1011    time_grid = time_step * np.arange(num_time_points)
1012    time_max = time_grid[-1]
1013
1014    polarization_full = np.zeros(num_time_points)
1015    polarization_full[1:] = polarization_delta_response  # P(0)=0
1016
1017    # Fit and subtract linear trend (drift current and offset)
1018    # Fit is done on the latter half of the data
1019    fit_start_index = num_time_points // 2
1020    if (
1021        fit_start_index < 2 and num_time_points >= 2
1022    ):  # Need at least 2 points for polyfit(deg=1)
1023        fit_start_index = 0  # Use all data if too short for half
1024    elif num_time_points < 2:
1025        # Handle very short P_delta (e.g. 0 or 1 point)
1026        # If P_delta has 0 points, N_t=1, P_full=[0]. J_drift=0, P_offset=0.
1027        # If P_delta has 1 point, N_t=2, P_full=[0, P_d[0]].
1028        # polyfit needs at least deg+1 points.
1029        if num_time_points < 2:  # Cannot do polyfit
1030            J_drift = 0.0
1031            P_offset = 0.0
1032        elif allow_for_linear_displacement:  # N_t >= 2
1033            poly_coeffs = np.polyfit(
1034                time_grid[fit_start_index:], polarization_full[fit_start_index:], 1
1035            )
1036            J_drift = poly_coeffs[0]
1037            P_offset = poly_coeffs[1]
1038        else:  # N_t >= 2, P_offset = 0
1039            # P(t) = J_drift * t => J_drift = sum(P*t) / sum(t^2)
1040            # Ensure denominator is not zero if time_grid[fit_start_index:] is all zeros
1041            # (e.g., if fit_start_index is past end, or time_grid is [0,0,...])
1042            t_fit = time_grid[fit_start_index:]
1043            sum_t_squared = np.sum(t_fit**2)
1044            if sum_t_squared == 0:
1045                J_drift = 0.0
1046            else:
1047                J_drift = (
1048                    np.sum(polarization_full[fit_start_index:] * t_fit) / sum_t_squared
1049                )
1050            P_offset = 0.0
1051    else:  # Standard case N_t >= 2 and fit_start_index allows polyfit
1052        if allow_for_linear_displacement:
1053            poly_coeffs = np.polyfit(
1054                time_grid[fit_start_index:], polarization_full[fit_start_index:], 1
1055            )
1056            J_drift = poly_coeffs[0]
1057            P_offset = poly_coeffs[1]
1058        else:
1059            t_fit = time_grid[fit_start_index:]
1060            sum_t_squared = np.sum(t_fit**2)
1061            if sum_t_squared == 0:
1062                J_drift = 0.0
1063            else:
1064                J_drift = (
1065                    np.sum(polarization_full[fit_start_index:] * t_fit) / sum_t_squared
1066                )
1067            P_offset = 0.0
1068
1069    # Subtract the J_drift * t part from polarization_full. P_offset remains for now.
1070    polarization_oscillating = polarization_full - J_drift * time_grid
1071
1072    # Apply dephasing/windowing to the oscillating part (P - J_drift*t)
1073    # P_offset is part of the "DC" or very slow component, window it too.
1074    # The original code did: P_offset + window * (P - P_offset)
1075    # where P was (P_original - J_drift*t).
1076    # So, effectively: P_offset + window * (P_original - J_drift*t - P_offset)
1077
1078    if dephasing_time is None or dephasing_time == 0:
1079        # Soft window to zero if no explicit dephasing
1080        time_window = soft_window(time_grid, 0.5 * time_max, time_max)
1081    else:
1082        time_window = np.exp(-time_grid / dephasing_time) * soft_window(
1083            time_grid, 0.5 * time_max, time_max
1084        )
1085
1086    # Windowed polarization: P_offset is the value it decays from/to at t=0,
1087    # and the oscillating part (P_orig - J_drift*t - P_offset) is damped.
1088    processed_polarization = P_offset + time_window * (
1089        polarization_oscillating - P_offset
1090    )
1091
1092    permittivity_results = np.zeros_like(omega_array, dtype=complex)
1093
1094    for i, omega_val in enumerate(omega_array):
1095        # chi(omega) = FT[P_processed(t)](omega)
1096        # P_processed = P_offset_non_windowed + window * (P_osc - P_offset_non_windowed)
1097        # FT[P_processed] = FT[P_offset] + FT[window * (P_osc - P_offset)]
1098        # The original code integrated `processed_polarization` which is
1099        # P_offset + window * (P_original - J_drift*t - P_offset)
1100
1101        chi_omega = integrate_oscillating_function(
1102            time_grid, processed_polarization, omega_val * time_grid
1103        )
1104
1105        # Add analytical FT of the P_offset tail (if P_offset was not windowed to zero)
1106        # The `processed_polarization` already includes P_offset, partly windowed.
1107        # The original code adds: P_offset * 1j * np.exp(1j * omega * t_max) / omega
1108        # This looks like the FT of P_offset * Heaviside(t) if it extended from 0 to t_max
1109        # and then was abruptly cut, or FT of P_offset for t>t_max if window brought it to P_offset at t_max.
1110        # If processed_polarization(t_max) -> P_offset (due to window(t_max)=1),
1111        # and we assume P(t) = P_offset for t > t_max, its FT is P_offset * exp(i*omega*t_max) * (pi*delta(omega) + 1/(i*omega))
1112        # This term is tricky. The original `integrate_oscillating_function` handles up to t_max.
1113        # If the window makes processed_polarization(t_max) close to P_offset,
1114        # and we assume P(t) = P_offset for t > t_max, the integral from t_max to inf is
1115        # P_offset * integral_{t_max to inf} exp(i*omega*t) dt
1116        # = P_offset * [exp(i*omega*t) / (i*omega)]_{t_max to inf}
1117        # For convergence, Im(omega) > 0 or add damping. Assuming real omega, this diverges.
1118        # The term P_offset * 1j * np.exp(1j * omega * t_max) / omega
1119        # is -P_offset * exp(1j*omega*t_max) / (1j*omega). This is the upper limit of the integral.
1120        # It implies the lower limit (at infinity) is taken as zero.
1121        # This is the FT of P_offset * step_function(t_max - t) if integrated from -inf to t_max.
1122        # Or FT of P_offset for t > t_max, i.e. integral from t_max to infinity of P_offset*exp(iwt)
1123        # = P_offset * [exp(iwt)/(iw)]_{t_max to inf}. For this to be -P_offset*exp(iw*t_max)/(iw),
1124        # the exp(iw*inf) term must vanish, e.g. by small positive Im(w).
1125        # Let's assume this term correctly accounts for the P_offset tail beyond t_max.
1126        if not np.isclose(P_offset, 0.0):  # Only add if P_offset is significant
1127            chi_omega += P_offset * 1j * np.exp(1j * omega_val * time_max) / omega_val
1128
1129        # Add contribution from drift current J_drift / (omega * (omega + i*gamma_relax))
1130        if not disregard_drift_current and not np.isclose(J_drift, 0.0):
1131            denominator_drift = omega_val * (omega_val + 1j * momentum_relaxation_rate)
1132            if not np.isclose(
1133                denominator_drift, 0.0
1134            ):  # Avoid division by zero if omega=0 (already checked) or omega=-i*gamma
1135                chi_omega -= J_drift / denominator_drift
1136            # If denominator is zero (e.g. omega_val=0 or omega_val = -1j*gamma), this term is singular.
1137            # omega_val != 0 is asserted. If omega_val = -1j*gamma, it's a resonant condition.
1138
1139        permittivity_results[i] = 1.0 + 4 * np.pi * chi_omega
1140
1141    return permittivity_results

Evaluates permittivity from polarization induced by E(t) = delta(t).

Handles drift currents and coherent oscillations in the polarization response. The relationship is $\epsilon(\omega) = 1 + 4 \pi \chi(\omega)$, where $P(\omega) = \chi(\omega) E(\omega)$, and for $E(t)=\delta(t)$, $E(\omega)=1$. So $\chi(\omega) = P_{\delta}(\omega)$.

Arguments:
  • time_step: Time step (atomic units) of the polarization grid.
  • polarization_delta_response: 1D array of polarization response P(t) (atomic units) induced by E(t)=delta(t). P_delta[0] corresponds to t=time_step.
  • omega_array: 1D array of circular frequencies (a.u.) for permittivity calculation. All frequencies must be non-zero.
  • momentum_relaxation_rate: If non-zero, models Drude-like momentum relaxation ($\gamma$ in $1/(\omega(\omega+i\gamma))$).
  • dephasing_time: If not None/zero, an exponential decay (rate 1/dephasing_time) is applied to coherent dipole oscillations in P_delta.
  • disregard_drift_current: If True, the J_drift component of polarization has no effect on the result.
  • allow_for_linear_displacement: If True, fits P(t) ~ J_drift*t + P_offset. If False, fits P(t) ~ J_drift*t (P_offset=0).
Returns:

A complex array (same shape as omega_array) of permittivity values.