AudioScience Review post on Quantasylum

@matt I don’t know whether you are inclined to post on Audioscience, but a mention was made of your product and your thinking on cross correlation, noise, etc. The post could really benefit from your ability to succinctly elucidate complicated issues.

Hi @Moto,

I think the post at the link below (by nanook) summarizes it correctly. The key is that cross-correlation can reduce noise inherent in the ADC, but it cannot reduce the noise in the DUT. More specifically, any signal that is common to the left and right channels will be preserved. Any signal that differs between the left and right channels will be averaged toward zero.

If you want to get rid of all noise, you can do that too by ensuring your acquired signals are time aligned and averaging in the time domain (aka synchronous or coherent averaging). That will leave you with just the harmonics (and fundamental) and can be useful for looking for harmonics far below the noise floor.
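To see why the time alignment matters, here is a minimal simulated sketch (no hardware involved; the tone, noise level, and buffer sizes are arbitrary assumptions for illustration). It averages 200 acquisitions twice: once with every acquisition starting at the same phase, and once with a random start phase per acquisition.

```python
import numpy as np

rng = np.random.default_rng(3)
num_samples = 4800
num_acquisitions = 200
t = np.arange(num_samples)
cycles = 100  # whole number of tone cycles per buffer

def acquire(phase):
    """One simulated acquisition: unit-amplitude tone at the given start phase plus noise."""
    tone = np.sin(2 * np.pi * cycles * t / num_samples + phase)
    return tone + 0.1 * rng.normal(size=num_samples)

# Aligned: every acquisition starts at the same phase
aligned = np.mean([acquire(0.0) for _ in range(num_acquisitions)], axis=0)

# Misaligned: every acquisition starts at a random phase
misaligned = np.mean(
    [acquire(rng.uniform(0, 2 * np.pi)) for _ in range(num_acquisitions)], axis=0
)

def tone_amplitude(x):
    """Amplitude of the tone bin from a rectangular-window FFT."""
    return 2 * np.abs(np.fft.rfft(x)[cycles]) / num_samples

print(f"aligned tone amplitude:    {tone_amplitude(aligned):.3f}")
print(f"misaligned tone amplitude: {tone_amplitude(misaligned):.3f}")
```

In the aligned case the tone survives the averaging at close to full amplitude. In the misaligned case the tone is averaged down along with the noise, which is why this only works when the acquisitions are aligned.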

In the plot below, I used PyQa40x to write a short program to average time-aligned data from the QA403. Note this is not the same as averaging in freq domain. Instead, this takes advantage of the fact that the QA403 input and output are precisely aligned on every cycle. And so, you can add acquisitions in the time domain, and that will make the signal stronger and the noise weaker. And then convert to freq domain. You can see how many more harmonics are revealed in the plot on the right.

And repeat for 128 acquisitions, and we get 21.3 dB reduction in noise (note how many harmonics are visible). The pattern is 2x avg buys you 3 dB, 4x 6 dB, 8x 9 dB…and 128 nets you 21 dB
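The 3 dB per doubling pattern is what you expect when the noise is uncorrelated from one acquisition to the next: averaging N acquisitions reduces the noise amplitude by 1/sqrt(N), or 10*log10(N) dB. Below is a small simulated sketch of that (not the QA403 code; the function name coherent_average_gain_db and the noise level are made up for illustration).

```python
import numpy as np

def coherent_average_gain_db(num_acquisitions, num_samples=4096, seed=0):
    """Measured noise reduction (dB) from time-domain averaging of aligned acquisitions."""
    rng = np.random.default_rng(seed)
    t = np.arange(num_samples)
    # 64 whole cycles per buffer, so each simulated acquisition is aligned
    tone = np.sqrt(2) * np.sin(2 * np.pi * 64 * t / num_samples)

    accum = np.zeros(num_samples)
    first = None
    for _ in range(num_acquisitions):
        acq = tone + 0.01 * rng.normal(size=num_samples)
        if first is None:
            first = acq
        accum += acq
    avg = accum / num_acquisitions

    # Residual noise is the acquisition minus the known tone
    noise_first = np.sqrt(np.mean((first - tone) ** 2))
    noise_avg = np.sqrt(np.mean((avg - tone) ** 2))
    return 20 * np.log10(noise_first / noise_avg)

for n in (2, 4, 8, 128):
    expected = 10 * np.log10(n)
    simulated = coherent_average_gain_db(n)
    print(f"{n:4d} averages: expected {expected:5.2f} dB, simulated {simulated:5.2f} dB")
```

The simulated numbers will wander a little around the expected values because the first acquisition is only one realization of the noise.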

And so, cross correlation can reduce the noise and distortion associated with the ADC. And coherent averaging will reduce all the noise (DUT and ADC) and leave behind the harmonics. But you can’t be sure of the harmonic contribution between DAC and ADC (that’s where a notch helps).

Source code:

import os
import sys
import numpy as np
import matplotlib.pyplot as plt

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', 'src')))

from PyQa40x.analyzer import Analyzer
from PyQa40x.wave_sine import WaveSine
from PyQa40x.series_plotter import SeriesPlotter
from PyQa40x.helpers import linear_to_dBV

# Create an Analyzer instance
analyzer = Analyzer()

# Initialize the analyzer with desired parameters
params = analyzer.init(sample_rate=48000, max_input_level=0, max_output_level=8, fft_size=2**16, window_type="flattop")

# Dump key parameters
print(params)

# Wave class for DAC output
wave_dac = WaveSine(params).gen_sine_dbv(1000, 0)

# Wave class for first ADC input and accumulated ADC input
wave_adc_first = WaveSine(params)
wave_adc_accum = WaveSine(params)

# We'll do 128 acquisitions
num_acquisitions = 128  # Number of acquisitions to average

# iterate
for i in range(num_acquisitions):
    # Send the DAC buffers to the hardware, and collect the left ADC buffer
    print(f"Acquisition {i + 1} of {num_acquisitions}")
    wave_adc, _ = analyzer.send_receive(wave_dac, wave_dac)

    # Special case: Save the first acquisition
    if i == 0:
        wave_adc_first.set_buffer(wave_adc.get_buffer())
       
    # Average in the time domain. If the input and output are not always precisely time aligned
    # by the same amount, the averaging in the time domain will converge to no signal. But, 
    # since QA403 acquisitions are precisely aligned in time, they can be averaged and the 
    # signal (and distortions) will get stronger and the noise will get weaker
    wave_adc_accum.set_buffer(wave_adc_accum.get_buffer() + wave_adc.get_buffer())
    
# Average
wave_adc_accum.set_buffer(wave_adc_accum.get_buffer()/num_acquisitions)

# Convert time-domain signals to freq domain, and specify dBV
freqs = wave_adc_first.get_frequency_array()
amplitude_first_dbv = wave_adc_first.get_amplitude_array("dbv")
amplitude_accumulated_dbv = wave_adc_accum.get_amplitude_array("dbv")

# Measure noise from 5.1 to 5.9 kHz of the first time series acquisitions
rms_first_5_to_6k = wave_adc_first.compute_rms_freq(5100, 5900)
print(f"RMS of First from 5100 to 5900 Hz: {rms_first_5_to_6k:0.2f}")

# measure noise from 5.1 to 5.9 kHz of the accumulated time series acquisitions
rms_accum_5_to_6k = wave_adc_accum.compute_rms_freq(5100, 5900)
print(f"RMS of accum from 5100 to 5900 Hz: {rms_accum_5_to_6k:0.2f}")

# Calc the noise win
reduction = rms_first_5_to_6k - rms_accum_5_to_6k

# Plotting
tsp = SeriesPlotter(num_columns=2)
tsp.add_freq_series(freqs, amplitude_first_dbv, "First Acquisition", logx=True, xmin=500, xmax = 10000)
tsp.add_freq_series(freqs, amplitude_accumulated_dbv, f"Accumulated Acquisition ({num_acquisitions} acquisitions, {reduction:0.1f} dB reduction in noise floor)", logx=True, xmin=500, xmax=10000)
tsp.plot()

analyzer.cleanup()

The last post is an example of coherent averaging.

Below is an example of using cross-correlation to eliminate uncorrelated noise in a two-channel ADC system.

For the plot below, we have three signals generated in the source:

dut has a 1 kHz tone at 0 dBV, a 2H at -90 dBV and noise around -100 dBV
left simulates one channel of an ADC, and it has a noise floor around -50 dBV
right simulates one channel of an ADC, and it has a comparable noise floor

In the plot below, you can see the time domain signal dut_signal. And then we can see the left and right ADC acquisitions we simulated. Note the noise level around -80 dBV on both the left and right channels (the fact that we specified the noise at -50 dBV doesn’t mean much in this context; ignore it for now).

And on the right side we see the cross-correlated left and right channels. We ran for a single iteration. And so we have a single vector to average which won’t buy us much.

In cross_correlation.py, you can see what is happening, though. The input to do_correlation is a pair of complex spectra, basically the raw output of an FFT for each channel. In the cross correlation, note we’re doing an element-by-element product of the left FFT and the complex conjugate of the right FFT. If an identical signal is presented to the left and right channels, the imaginary/phase part will become zero. But signals that are unique to just the left or right channel will have a non-zero phase. And once we have this complex cross-correlation, we can vector average those. That also happens in the cross_correlation.py file.
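The phase behavior described above can be checked with a few lines of plain NumPy. This is a standalone sketch using random test signals (the variable names are made up for illustration and it does not use the CrossCorrelation class).

```python
import numpy as np

rng = np.random.default_rng(1)
n = 1024

common = rng.normal(size=n)       # signal seen by both channels
left_only = rng.normal(size=n)    # noise unique to the left channel
right_only = rng.normal(size=n)   # noise unique to the right channel

# Identical input on both channels: X * conj(X) = |X|^2, which is purely real
spec_common = np.fft.fft(common)
cross_same = spec_common * np.conj(spec_common)
print("max |imag| for identical inputs:", np.max(np.abs(cross_same.imag)))

# Independent inputs: each bin of the product lands at an arbitrary phase
cross_diff = np.fft.fft(left_only) * np.conj(np.fft.fft(right_only))
print("mean |phase| for independent inputs (rad):", np.mean(np.abs(np.angle(cross_diff))))
```

Because the independent terms land at a different phase on every acquisition, summing the complex products over many acquisitions lets them partially cancel, while the common term always adds at zero phase.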

So, above we ran a single cross correlation. But that has no vector averaging. So, let’s run for 100 iterations and see what we get:

In the above, we can now see the 2H at -90 that was buried previously.

Now, ideally, the left and right channels of an ADC would be uncorrelated. And if that were the case, more iterations would net greater and greater gains. But in reality, there is a lot of correlation between the noise in stereo ADCs. And because of this, the technique can run out of steam.
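Here is a rough simulated sketch of that running-out-of-steam effect. It uses noise-only channels, and shared_fraction is a hypothetical knob for the fraction of each channel's noise power that is common to both channels. It says nothing about how much correlation a real stereo ADC has.

```python
import numpy as np

def averaged_cross_floor_db(iterations, shared_fraction, n=1024, seed=2):
    """Mean level (dB) of the vector-averaged cross-spectrum for noise-only channels."""
    rng = np.random.default_rng(seed)
    accum = np.zeros(n, dtype=np.complex128)
    for _ in range(iterations):
        shared = np.sqrt(shared_fraction) * rng.normal(size=n)
        left = shared + np.sqrt(1 - shared_fraction) * rng.normal(size=n)
        right = shared + np.sqrt(1 - shared_fraction) * rng.normal(size=n)
        # Element-wise product of left FFT and conjugate of right FFT, summed as vectors
        accum += np.fft.fft(left) * np.conj(np.fft.fft(right))
    # Divide by n so that a single channel's own noise power is about 1 (0 dB)
    avg = np.abs(accum / iterations) / n
    return 10 * np.log10(np.mean(avg))

for frac in (0.0, 0.1):
    levels = [averaged_cross_floor_db(k, frac) for k in (1, 10, 100, 1000)]
    print(f"shared fraction {frac}: " + ", ".join(f"{v:6.1f} dB" for v in levels))
```

With no shared noise the floor keeps dropping as iterations increase. With 10% of the noise power shared, the floor stalls near -10 dB relative to a single channel, no matter how many more iterations are run.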

Next time, we’ll look at how low we can get using Python on signals acquired from the QA403.

cross_correlation_test.py

import numpy as np
from cross_correlation import CrossCorrelation
from series_plotter import SeriesPlotter
from fft import *

def generate_signal(num_samples, sampling_rate, frequency, fund_amp_dbv, noise_floor_dbv, harmonic_dbv):
    """
    Generate a DUT signal with a specific noise floor and harmonic.
    
    Parameters:
    - num_samples (int): Number of samples.
    - sampling_rate (float): Sampling rate in Hz.
    - frequency (float): Frequency of the fundamental in Hz.
    - fund_amp_dbv (float): Level of the fundamental in dBV.
    - noise_floor_dbv (float): Nominal noise level in dBV.
    - harmonic_dbv (float): Level of the second harmonic in dBV.
    
    Returns:
    - dut_signal (np.ndarray): Generated DUT signal with noise and harmonic.
    - t (np.ndarray): Time array.
    """
    t = np.arange(num_samples) / sampling_rate

    # Fundamental signal (dBV is RMS, so scale by sqrt(2) to get peak amplitude)
    fund_signal = 10 ** (fund_amp_dbv / 20) * np.sqrt(2) * np.sin(2 * np.pi * frequency * t)
        
    # Harmonic signal
    harmonic_signal = 10**(harmonic_dbv / 20) * np.sqrt(2) * np.sin(2 * np.pi * 2 * frequency * t)
    
    # Random noise
    noise_signal = 10**(noise_floor_dbv / 20) * np.sqrt(2) * np.random.normal(size=t.shape)
    
    # DUT signal is the combination of carrier, harmonic, and noise
    dut_signal = fund_signal + harmonic_signal + noise_signal
    return dut_signal, t


cc = CrossCorrelation()

sample_rate = 48000
samples = 2**16
iterations = 1000
window_type = 'flattop' 

# Generate a 0 dBV signal at 1k, and a 2H at -90, with noise around -100
dut, _ = generate_signal(samples, sample_rate, 1000, 0, -100, -90)

window, acf = create_window(window_type, samples)

for _ in range(iterations):
    left, _ = generate_signal(samples, sample_rate, 1000, -1000, -50, -1000)
    right, _ = generate_signal(samples, sample_rate, 1500, -1000, -50, -1000)
    left = (left + dut) * window
    right = (right + dut) * window
    fft_left, _ = compute_fft_acf_dbv(left, sample_rate, acf)
    fft_right, _ = compute_fft_acf_dbv(right, sample_rate, acf)
    fft_complex_left = np.fft.fft(left)
    fft_complex_right = np.fft.fft(right)
    cc_result = cc.do_correlation(fft_complex_left, fft_complex_right)
    
cc_fft, freqs = compute_mag_acf_dbv(cc_result, sample_rate, acf)

#fft_left_max = np.max(fft_left)
#cc_fft_max = np.max(cc_fft)
#cc_fft = cc_fft - (cc_fft_max - fft_left_max)

sp = SeriesPlotter(num_columns=3)
sp.add_time_series(dut, "dut_signal", num_samples=2048)
sp.newrow()
sp.add_freq_series(freqs, fft_left, "last left acq", ymax=10, ymin=-120)
sp.add_freq_series(freqs, fft_right, "last right acq", ymax=10, ymin=-120)
sp.add_freq_series(freqs, cc_fft, f"cross correlation result: {iterations} iterations", ymax=10, ymin=-120)
sp.newrow()
sp.plot()

cross_correlation.py

import numpy as np
from typing import Optional

class Average:
    def __init__(self):
        self.waveform_sum: Optional[np.ndarray] = None
        self.count: int = 0

    def add_waveform(self, waveform: np.ndarray) -> None:
        """
        Add a new waveform to the current sum.
        
        :param waveform: A NumPy array representing the waveform to add.
        """
        if self.waveform_sum is None:
            self.waveform_sum = np.copy(waveform)
            self.count = 1
        else:
            self.waveform_sum += waveform
            self.count += 1

    def clear(self) -> None:
        """
        Clear the current waveform sum and reset the count.
        """
        self.waveform_sum = None
        self.count = 0

    def calc_average(self) -> np.ndarray:
        """
        Calculate the average waveform from the sum.

        :return: A NumPy array representing the average waveform.
        :raises ValueError: If no waveforms have been added.
        """
        if self.count > 0:
            result = self.waveform_sum / self.count
            return np.sqrt(result)
        
        raise ValueError("Average doesn't have any elements")


class CrossCorrelation:
    def __init__(self):
        self.avg = Average()

    def clear(self) -> None:
        """
        Clear the accumulated average data.
        """
        self.avg.clear()

    def do_correlation(self, a1: np.ndarray, a2: np.ndarray) -> np.ndarray:
        """
        Compute the element-wise product of a1 with the complex conjugate of a2.

        Parameters:
        - a1 (np.ndarray): First complex array.
        - a2 (np.ndarray): Second complex array.

        Returns:
        - np.ndarray: The element-wise product of a1 and the complex conjugate of a2.
        """
        # Ensure the input arrays are complex
        a1 = np.asarray(a1, dtype=np.complex128)
        a2 = np.asarray(a2, dtype=np.complex128)

        # Compute the complex conjugate of a2
        a2_conjugate = np.conj(a2)

        # Perform element-wise multiplication
        product = a1 * a2_conjugate
        
        self.avg.add_waveform(product)

        # Always return the current vector average.
        return self.avg.calc_average()



fft.py

import numpy as np
from typing import Tuple
from scipy.signal.windows import flattop, boxcar, hamming

from helpers import linear_array_to_dBV

def create_window(window_type, length):
    """
    Create a window function and compute its amplitude correction factor.
    
    Parameters:
    - window_type (str): Type of window to apply ('hann', 'boxcar', 'flattop', etc.).
    - length (int): Length of the window.
    
    Returns:
    - window (np.ndarray): The window array.
    - amplitude_correction_factor (float): The correction factor to apply to FFT amplitudes.
    """
    if window_type == 'hann':
        window = np.hanning(length)
    elif window_type == 'boxcar':
        window = np.ones(length)
    elif window_type == 'flattop':
        window = flattop(length)
    else:
        raise ValueError(f"Unknown window type: {window_type}")

    # Amplitude correction factor for the window
    amplitude_correction_factor = 1 / np.mean(window)
    
    return window, amplitude_correction_factor

def compute_fft(signal, sampling_rate):
    """Compute the FFT of the input signal and apply scaling."""
    # Compute FFT
    fft_result = np.fft.fft(signal)
    
    # Scale the FFT result
    N = len(signal)
    fft_magnitude = (np.abs(fft_result) / (N / 2)) / np.sqrt(2)  # Proper scaling for amplitude
    
    # Frequency bins
    freqs = np.fft.fftfreq(N, d=1/sampling_rate)
    
    # Take lower half
    freqs = freqs[:len(freqs)//2]
    fft_magnitude = fft_magnitude[:len(fft_magnitude)//2]
    
    return fft_magnitude, freqs

def compute_fft_acf_dbv(signal, sampling_rate, acf):
    """Compute the FFT of the input signal and apply scaling."""
    # Compute FFT
    fft_result = np.fft.fft(signal)
    
    # Scale the FFT result
    N = len(signal)
    fft_magnitude = (np.abs(fft_result) / (N / 2)) / np.sqrt(2)  # Proper scaling for amplitude
    
    # Frequency bins
    freqs = np.fft.fftfreq(N, d=1/sampling_rate)
    
    # Take lower half
    freqs = freqs[:len(freqs)//2]
    fft_magnitude = fft_magnitude[:len(fft_magnitude)//2]

    fft_magnitude = fft_magnitude * acf
    
    return linear_array_to_dBV(fft_magnitude), freqs

def compute_mag_acf_dbv(complex_fft: np.ndarray, sample_rate: float, acf: float) -> Tuple[np.ndarray, np.ndarray]:
    """
    Compute the magnitude of the FFT and apply amplitude correction factor (ACF) in dBV.
    
    Parameters:
    - complex_fft (np.ndarray): Complex FFT array.
    - sample_rate (float): Sampling rate in Hz.
    - acf (float): Amplitude correction factor.
    
    Returns:
    - Tuple[np.ndarray, np.ndarray]: Tuple containing the magnitude in dBV and frequency bins.
    """
    # Compute the number of positive frequency components
    N = len(complex_fft) // 2
    
    # Scale the FFT result to get the magnitude
    fft_magnitude = (np.abs(complex_fft) / N) / np.sqrt(2)  # Scaling for amplitude
    
    # Frequency bins for the positive frequencies
    freqs = np.fft.fftfreq(len(complex_fft), d=1/sample_rate)
    
    # Take only the positive frequencies
    freqs = freqs[:N]
    fft_magnitude = fft_magnitude[:N]

    # Apply amplitude correction factor
    fft_magnitude *= acf
    
    # Convert to dBV
    fft_magnitude_dbv = linear_array_to_dBV(fft_magnitude)
    
    return fft_magnitude_dbv, freqs   
    

def linear_to_dBV(value: np.float64) -> np.float64:
    """
    Convert a linear value to dBV.

    Parameters:
    - value: The linear value to convert (numpy float64).

    Returns:
    - The corresponding value in dBV (numpy float64).
    """
    dBV = 20 * np.log10(value)
    return dBV


helpers.py

import numpy as np

def remove_dc(array: np.ndarray) -> np.ndarray:
    """
    Removes the DC component from an array by subtracting the mean of the array.

    Args:
        array (np.ndarray): The input array from which to remove the DC component.

    Returns:
        np.ndarray: The array with the DC component removed.
    """
    return array - np.mean(array)

def dbv_to_vpk(dbv: float) -> float:
    """
    Convert dBV to peak voltage.

    Parameters:
    dbv (float): Amplitude in dBV.

    Returns:
    float: Peak voltage.
    """
    return 10 ** (dbv / 20) * np.sqrt(2)


def dbfs_to_dbv(dbfs: float) -> float:
    """
    Convert dBFS to dBV.

    Parameters:
    dbfs (float): Amplitude in dBFS.

    Returns:
    float: Amplitude in dBV.
    """
    return dbfs - 2.98


def linear_to_dBV(value: np.float64) -> np.float64:
    """
    Convert a linear value to dBV.

    Parameters:
    - value: The linear value to convert (numpy float64).

    Returns:
    - The corresponding value in dBV (numpy float64).
    """
    dBV = 20 * np.log10(value)
    return dBV

def linear_to_dBu(value: np.float64) -> np.float64:
    """
    Convert a linear value to dBu (decibels relative to 0.775 volts).

    Parameters:
    - value: The linear value to convert (numpy float64).

    Returns:
    - The corresponding value in dBu (numpy float64).
    """
    ref_voltage = 0.775
    dBu = 20 * np.log10(value / ref_voltage)
    return dBu

def dBV_to_linear(dBV: np.float64) -> np.float64:
    """
    Convert dBV value back to linear.

    Parameters:
    - dBV: The dBV value to convert (numpy float64).

    Returns:
    - The corresponding linear value (numpy float64).
    """
    linear_value = 10**(dBV / 20.0)
    return linear_value

def dbv_to_linear_pk(dbv: float) -> float:
    """
    Converts dBV to linear peak voltage.

    Args:
        dbv (float): Amplitude in dBV.

    Returns:
        float: Amplitude in linear peak voltage.
    """
    linear_rms = dBV_to_linear(np.float64(dbv))
    linear_peak = linear_rms * np.sqrt(2)
    return float(linear_peak)

def dBu_to_linear(dBu: np.float64) -> np.float64:
    """
    Convert dBu value back to linear.

    Parameters:
    - dBu: The dBu value to convert (numpy float64).

    Returns:
    - The corresponding linear value (numpy float64).
    """
    ref_voltage = 0.775
    linear_value = 10**((dBu / 20.0) + np.log10(ref_voltage))
    return linear_value

# Functions to handle float arrays

def linear_array_to_dBV(values: np.ndarray) -> np.ndarray:
    """
    Convert an array of linear values to dBV.

    Parameters:
    - values: The array of linear values to convert (numpy ndarray).

    Returns:
    - The corresponding array of values in dBV (numpy ndarray).
    """
    dBV = 20 * np.log10(values)
    return dBV

def linear_array_to_dBu(values: np.ndarray) -> np.ndarray:
    """
    Convert an array of linear values to dBu (decibels relative to 0.775 volts).

    Parameters:
    - values: The array of linear values to convert (numpy ndarray).

    Returns:
    - The corresponding array of values in dBu (numpy ndarray).
    """
    ref_voltage = 0.775
    dBu = 20 * np.log10(values / ref_voltage)
    return dBu

def dBV_array_to_linear(dBV_values: np.ndarray) -> np.ndarray:
    """
    Convert an array of dBV values back to linear.

    Parameters:
    - dBV_values: The array of dBV values to convert (numpy ndarray).

    Returns:
    - The corresponding array of linear values (numpy ndarray).
    """
    linear_values = 10**(dBV_values / 20.0)
    return linear_values

def dBu_array_to_linear(dBu_values: np.ndarray) -> np.ndarray:
    """
    Convert an array of dBu values back to linear.

    Parameters:
    - dBu_values: The array of dBu values to convert (numpy ndarray).

    Returns:
    - The corresponding array of linear values (numpy ndarray).
    """
    ref_voltage = 0.775
    linear_values = 10**((dBu_values / 20.0) + np.log10(ref_voltage))
    return linear_values

series_plotter.py

import numpy as np
import matplotlib.pyplot as plt

class SeriesPlotter:
    def __init__(self, num_columns: int = 2, main_title: str = "", main_title_fontsize: int = 16):
        """
        Initializes the SeriesPlotter class.

        Args:
            num_columns (int): Number of columns per row in the plot grid.
            main_title (str): Main title of the plot.
            main_title_fontsize (int): Font size of the main title.
        """
        self.num_columns: int = num_columns
        self.rows: list[list[dict]] = [[]]
        self.main_title: str = main_title
        self.main_title_fontsize: int = main_title_fontsize

    def add_time_series(self, signal: np.ndarray, label: str, signal_right: np.ndarray | None = None, num_samples: int = 0, 
                        units: str = "Volts", units_right: str = "Volts", ymin: float | None = None, ymax: float | None = None, 
                        ymin_right: float | None = None, ymax_right: float | None = None, xmin: float | None = None, 
                        xmax: float | None = None, logx: bool = False):
        """
        Adds a time series plot to the current row.

        Args:
            signal (np.ndarray): Array of time series data.
            label (str): Label for the plot.
            signal_right (np.ndarray | None): Array of right channel time series data, if any.
            num_samples (int): Number of samples to plot.
            units (str): Units for the left y-axis.
            units_right (str): Units for the right y-axis.
            ymin (float | None): Minimum value for the left y-axis.
            ymax (float | None): Maximum value for the left y-axis.
            ymin_right (float | None): Minimum value for the right y-axis.
            ymax_right (float | None): Maximum value for the right y-axis.
            xmin (float | None): Minimum value for the x-axis.
            xmax (float | None): Maximum value for the x-axis.
            logx (bool): Whether to use a logarithmic scale for the x-axis.
        """
        self.rows[-1].append({
            'type': 'time',
            'signal': signal,
            'signal_right': signal_right,
            'label': label,
            'num_samples': num_samples,
            'units': units,
            'units_right': units_right,
            'ymin': ymin,
            'ymax': ymax,
            'ymin_right': ymin_right,
            'ymax_right': ymax_right,
            'xmin': xmin,
            'xmax': xmax,
            'logx': logx
        })

    def add_freq_series(self, freqs: np.ndarray, magnitudes: np.ndarray, label: str, magnitudes_right: np.ndarray | None = None, 
                        num_samples: int = 0, units: str = "dBV", units_right: str = "dBV", ymin: float | None = None, 
                        ymax: float | None = None, ymin_right: float | None = None, ymax_right: float | None = None, 
                        xmin: float | None = None, xmax: float | None = None, logx: bool = False):
        """
        Adds a frequency series plot to the current row.

        Args:
            freqs (np.ndarray): Array of frequency data.
            magnitudes (np.ndarray): Array of magnitude data.
            label (str): Label for the plot.
            magnitudes_right (np.ndarray | None): Array of right channel magnitude data, if any.
            num_samples (int): Number of samples to plot.
            units (str): Units for the left y-axis.
            units_right (str): Units for the right y-axis.
            ymin (float | None): Minimum value for the left y-axis.
            ymax (float | None): Maximum value for the left y-axis.
            ymin_right (float | None): Minimum value for the right y-axis.
            ymax_right (float | None): Maximum value for the right y-axis.
            xmin (float | None): Minimum value for the x-axis.
            xmax (float | None): Maximum value for the x-axis.
            logx (bool): Whether to use a logarithmic scale for the x-axis.
        """
        if logx:
            xmin = xmin if xmin is not None else 20
            xmax = xmax if xmax is not None else 20000

        self.rows[-1].append({
            'type': 'freq',
            'freqs': freqs,
            'magnitudes': magnitudes,
            'magnitudes_right': magnitudes_right,
            'label': label,
            'num_samples': num_samples,
            'units': units,
            'units_right': units_right,
            'ymin': ymin,
            'ymax': ymax,
            'ymin_right': ymin_right,
            'ymax_right': ymax_right,
            'xmin': xmin,
            'xmax': xmax,
            'logx': logx
        })

    def newrow(self):
        """
        Starts a new row for the plots.
        """
        if self.rows[-1]:
            self.rows.append([])

    def plot(self, block: bool = True):
        """
        Plots the added time and frequency series.

        Args:
            block (bool): Whether to block the execution until the plot window is closed.
        """
        mosaic_layout: list[list[str | None]] = []
        for row in self.rows:
            if not row:  # Skip empty rows
                continue
            # Handle rows with more elements than num_columns
            while len(row) > self.num_columns:
                mosaic_layout.append([trace['label'] for trace in row[:self.num_columns]])
                row = row[self.num_columns:]
            # Evenly distribute the elements in the row
            num_elements = len(row)
            if num_elements < self.num_columns:
                span_each = self.num_columns // num_elements
                remainder = self.num_columns % num_elements
                new_row: list[str | None] = []
                for i in range(num_elements):
                    span = span_each + (1 if i < remainder else 0)
                    new_row.extend([row[i]['label']] * span)
                mosaic_layout.append(new_row)
            else:
                mosaic_layout.append([trace['label'] for trace in row])

        label_to_trace: dict[str, dict] = {trace['label']: trace for row in self.rows for trace in row}

        fig, axd = plt.subplot_mosaic(mosaic_layout, figsize=(5 * self.num_columns, 4 * len(mosaic_layout)))

        for label, ax in axd.items():
            if label is not None:
                trace = label_to_trace[label]
                if trace['type'] == 'time':
                    signal = trace['signal']
                    signal_right = trace['signal_right']
                    num_samples = trace['num_samples']
                    units = trace['units']
                    units_right = trace['units_right']
                    ymin = trace['ymin']
                    ymax = trace['ymax']
                    ymin_right = trace['ymin_right']
                    ymax_right = trace['ymax_right']
                    xmin = trace['xmin']
                    xmax = trace['xmax']
                    logx = trace['logx']
                    if num_samples > 0:
                        signal = signal[:num_samples]
                        if signal_right is not None:
                            signal_right = signal_right[:num_samples]

                    ax.plot(signal, label=label)
                    ax.set_title(label)
                    ax.set_xlabel('Sample Index')
                    ax.set_ylabel(f'Amplitude ({units})')
                    if ymin is not None and ymax is not None:
                        ax.set_ylim(ymin, ymax)
                    if xmin is not None and xmax is not None:
                        ax.set_xlim(xmin, xmax)
                    if logx:
                        ax.set_xscale('log')

                    if signal_right is not None:
                        ax_right = ax.twinx()
                        ax_right.plot(signal_right, 'r', label=f'{label} Right')
                        ax_right.set_ylabel(f'Amplitude ({units_right})', color='r')
                        if ymin_right is not None and ymax_right is not None:
                            ax_right.set_ylim(ymin_right, ymax_right)

                elif trace['type'] == 'freq':
                    freqs = trace['freqs']
                    num_samples = trace['num_samples']
                    magnitudes = trace['magnitudes']
                    magnitudes_right = trace['magnitudes_right']
                    units = trace['units']
                    units_right = trace['units_right']
                    ymin = trace['ymin']
                    ymax = trace['ymax']
                    ymin_right = trace['ymin_right']
                    ymax_right = trace['ymax_right']
                    xmin = trace['xmin']
                    xmax = trace['xmax']
                    logx = trace['logx']
                    if num_samples > 0:
                        freqs = freqs[:num_samples]
                        magnitudes = magnitudes[:num_samples]
                        if magnitudes_right is not None:
                            magnitudes_right = magnitudes_right[:num_samples]

                    ax.plot(freqs, magnitudes, label=label)
                    ax.set_title(label)
                    ax.set_xlabel('Frequency (Hz)')
                    ax.set_ylabel(f'Magnitude ({units})')
                    if ymin is not None and ymax is not None:
                        ax.set_ylim(ymin, ymax)
                    if xmin is not None and xmax is not None:
                        ax.set_xlim(xmin, xmax)
                    if logx:
                        ax.set_xscale('log')

                    if magnitudes_right is not None:
                        ax_right = ax.twinx()
                        ax_right.plot(freqs, magnitudes_right, 'r', label=f'{label} Right')
                        ax_right.set_ylabel(f'Magnitude ({units_right})', color='r')
                        if ymin_right is not None and ymax_right is not None:
                            ax_right.set_ylim(ymin_right, ymax_right)

        if self.main_title:
            fig.suptitle(self.main_title, fontsize=self.main_title_fontsize)
            fig.subplots_adjust(top=0.95)  # Adjust the top spacing to reduce whitespace

        plt.tight_layout(rect=[0, 0, 1, 0.95])  # Adjust layout to reduce whitespace
        plt.show(block=block)