Memory-Efficient Clustering of Large Protein Trajectory Ensembles

Molecular dynamics simulations have grown increasingly ambitious, with researchers routinely generating trajectories containing hundreds of thousands or even millions of frames. While this wealth of data offers unprecedented insights into protein dynamics, it also presents a formidable computational challenge: how do you extract meaningful conformational clusters from datasets that can easily exceed available system memory?

Traditional approaches to trajectory clustering often stumble when faced with large ensembles. Loading all pairwise distances into memory simultaneously can quickly consume tens or hundreds of gigabytes of RAM, while conventional PCA implementations require the entire dataset to fit in memory before decomposition can begin. For many researchers, this means either downsampling their precious simulation data or investing in expensive high-memory computing resources.

The solution lies in recognizing that we don’t actually need to hold all our data in memory simultaneously. By leveraging incremental algorithms and smart memory management, we can perform sophisticated dimensionality reduction and clustering on arbitrarily large trajectory datasets using modest computational resources. Let’s explore how three key strategies—incremental PCA, mini-batch clustering, and intelligent memory management—can transform your approach to analyzing large protein ensembles.

Incremental PCA: Building Understanding One Chunk at a Time

The heart of trajectory clustering often involves dimensionality reduction of high-dimensional coordinate or distance data. For protein trajectories, pairwise distances between alpha carbons provide an alignment-free representation that captures the essential conformational relationships between frames. However, for a protein with 200 residues, each frame generates nearly 20,000 pairwise distances—multiply this by 100,000 frames, and you’re looking at 2 billion distance values that need to fit in memory simultaneously.
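To make these numbers concrete, the sketch below builds the same condensed distance vector that scipy's pdist returns, using only NumPy. It then checks the two properties the paragraph relies on. A 200-atom selection yields 200 × 199 / 2 = 19,900 distances, and the vector does not change when the frame is rotated or translated. The helper name `pairwise_distance_features` and the random coordinates are illustrative assumptions, not part of the script.

```python
import numpy as np


def pairwise_distance_features(positions):
    """Condensed upper-triangle distance vector, same pair ordering as scipy's pdist."""
    n_atoms = positions.shape[0]
    i, j = np.triu_indices(n_atoms, k=1)  # all pairs with i < j, row-major order
    return np.linalg.norm(positions[i] - positions[j], axis=1)


rng = np.random.default_rng(0)
frame = rng.normal(scale=10.0, size=(200, 3))  # stand-in for 200 CA positions

# Random proper rotation (QR of a random matrix, sign-fixed to det = +1) plus a shift
q, _ = np.linalg.qr(rng.normal(size=(3, 3)))
if np.linalg.det(q) < 0:
    q[:, 0] = -q[:, 0]
moved = frame @ q.T + np.array([5.0, -3.0, 12.0])

d_original = pairwise_distance_features(frame)
d_moved = pairwise_distance_features(moved)

print(d_original.size)                   # 19900 distances per frame
print(np.allclose(d_original, d_moved))  # True: no alignment needed
# Memory for 100,000 such frames stored as float64 (8 bytes per value)
print(d_original.size * 100_000 * 8 / 1e9, "GB")
```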

Incremental PCA offers an elegant solution by building the principal component decomposition iteratively, processing data in manageable chunks while closely approximating the decomposition of the full dataset. The key insight is that PCA can be formulated as an online learning problem, where we update our understanding of the data structure with each new batch of observations.
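The online idea is easiest to see through sufficient statistics. The mean and covariance of the full dataset can be assembled exactly from per-chunk sums, and the principal axes are the eigenvectors of that covariance. The NumPy sketch below accumulates those sums chunk by chunk and checks the result against a direct computation. It only illustrates the principle: scikit-learn's IncrementalPCA instead updates a truncated SVD batch by batch, and the covariance route needs an n_features × n_features matrix, which for 19,900 distances is about 3.2 GB in float64. The toy data and chunk size are assumptions.

```python
import numpy as np

rng = np.random.default_rng(1)
data = rng.normal(size=(1_000, 6)) @ rng.normal(size=(6, 6))  # correlated toy features
chunk_size = 128

# Accumulate count, sum, and sum of outer products one chunk at a time
n_seen = 0
feature_sum = np.zeros(data.shape[1])
outer_sum = np.zeros((data.shape[1], data.shape[1]))
for start in range(0, len(data), chunk_size):
    chunk = data[start:start + chunk_size]
    n_seen += len(chunk)
    feature_sum += chunk.sum(axis=0)
    outer_sum += chunk.T @ chunk

mean = feature_sum / n_seen
# Sample covariance from the streamed sums: (sum of x x^T - n * mean mean^T) / (n - 1)
cov_streamed = (outer_sum - n_seen * np.outer(mean, mean)) / (n_seen - 1)
print(np.allclose(cov_streamed, np.cov(data, rowvar=False)))  # True

# Principal axes: eigenvectors of the covariance, largest eigenvalue first
eigvals, eigvecs = np.linalg.eigh(cov_streamed)
order = np.argsort(eigvals)[::-1]
components = eigvecs[:, order].T  # one principal axis per row
```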

Here’s how the trajectory clustering script implements this approach:

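import numpy as np
from scipy.spatial.distance import pdist
from sklearn.decomposition import IncrementalPCA
from tqdm import tqdm

def calculate_distances_and_perform_pca(universe, selection, num_components, chunk_size):
    """Calculate pairwise distances and perform incremental PCA in chunks"""
    atoms = universe.select_atoms(selection)
    n_frames = len(universe.trajectory)
    n_distances = atoms.n_atoms * (atoms.n_atoms - 1) // 2

    # Initialize IncrementalPCA with appropriate batch size
    ipca_batch_size = max(num_components * 10, 100)
    pca = IncrementalPCA(n_components=num_components, batch_size=ipca_batch_size)
    pca_coords = np.zeros((n_frames, num_components))

    # Process frames in chunks to manage memory
    for chunk_start in tqdm(range(0, n_frames, chunk_size), desc="Processing frame chunks"):
        chunk_end = min(chunk_start + chunk_size, n_frames)
        chunk_distances = np.zeros((chunk_end - chunk_start, n_distances))

        # Calculate pairwise distances for frames in this chunk
        for i, frame_idx in enumerate(range(chunk_start, chunk_end)):
            universe.trajectory[frame_idx]  # move the trajectory to this frame
            chunk_distances[i] = pdist(atoms.positions, metric="euclidean")

        # Update PCA model with this chunk
        if chunk_start == 0:
            chunk_pca_coords = pca.fit_transform(chunk_distances)
        else:
            # partial_fit needs at least num_components samples, so a short
            # final chunk is projected but not used to update the model
            if len(chunk_distances) >= num_components:
                pca.partial_fit(chunk_distances)
            chunk_pca_coords = pca.transform(chunk_distances)
        pca_coords[chunk_start:chunk_end] = chunk_pca_coords

    return pca_coords, pca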

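The beauty of this approach is that it closely approximates batch PCA while holding only one chunk of distances in memory at a time. The partial_fit method updates the model’s running mean, variance, and leading singular vectors with each new chunk, so the final principal components capture essentially the same variance relationships as a decomposition of the entire dataset. One caveat: because each chunk is transformed immediately after it is fitted, frames in early chunks are projected onto the components as they stood at that point, not onto the final ones.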

Think of incremental PCA like learning to recognize patterns in a large collection of photographs. Rather than spreading all photos across a warehouse floor to study them simultaneously, you examine them in small, manageable stacks. Each stack teaches you something about the overall patterns, and you update your understanding as you progress. By the end, your knowledge of the complete collection is just as thorough as if you had somehow managed to view everything at once.
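If every frame should be projected onto the final components, the fit and the projection can be split into two passes over the data. The cost is computing (or reading) each chunk twice. The sketch below uses a random array as a stand-in for the per-chunk distance matrices, and `iter_chunks` is a hypothetical helper. In the real script, each chunk would come from the pdist loop. The sketch also skips partial_fit for a final chunk smaller than n_components, because scikit-learn's IncrementalPCA.partial_fit raises an error when a batch has fewer samples than components.

```python
import numpy as np
from sklearn.decomposition import IncrementalPCA

rng = np.random.default_rng(2)
features = rng.normal(size=(1_005, 40))  # stand-in for per-frame distance vectors
n_components, chunk_size = 10, 100


def iter_chunks(n_rows, size):
    """Hypothetical helper yielding (start, end) row bounds for each chunk."""
    for start in range(0, n_rows, size):
        yield start, min(start + size, n_rows)


pca = IncrementalPCA(n_components=n_components)

# Pass 1: fit only. The trailing 5-row chunk is skipped because
# partial_fit needs at least n_components samples per call.
for start, end in iter_chunks(len(features), chunk_size):
    if end - start >= n_components:
        pca.partial_fit(features[start:end])

# Pass 2: project every frame with the same, final components
coords = np.empty((len(features), n_components))
for start, end in iter_chunks(len(features), chunk_size):
    coords[start:end] = pca.transform(features[start:end])

print(np.allclose(coords, pca.transform(features)))  # True
```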

Mini-Batch K-Means: Clustering Without Memory Constraints

Once we’ve reduced our trajectory data to a manageable number of principal components, we face the clustering challenge. Standard k-means revisits every data point at each iteration and computes its distance to every cluster center. With hundreds of thousands of frames and hundreds of clusters, this becomes slow, and it assumes that all data points are available in memory at once.

Mini-batch k-means addresses this limitation by updating cluster centers using small, randomly sampled subsets of the data at each iteration. Because each iteration touches only one mini-batch, it typically converges much faster than standard k-means at the cost of a usually small reduction in cluster quality. scikit-learn’s implementation can also be fed data incrementally through its partial_fit method.
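Because scikit-learn's MiniBatchKMeans exposes partial_fit, the clustering step could consume the reduced coordinates one chunk at a time if even those did not fit in memory. This is a sketch of that option, not what the script does (the script calls fit_predict on the full array). The synthetic blobs, cluster count, and chunk size are assumptions. Labels are assigned in a second pass with predict, so every frame is labeled by the final centers.

```python
import numpy as np
from sklearn.cluster import MiniBatchKMeans

rng = np.random.default_rng(3)
true_centers = np.array([[0.0, 0.0], [8.0, 0.0], [0.0, 8.0]])
# Stand-in for PCA coordinates: three well-separated Gaussian blobs, shuffled
coords = np.concatenate([c + rng.normal(size=(2_000, 2)) for c in true_centers])
coords = coords[rng.permutation(len(coords))]

kmeans = MiniBatchKMeans(n_clusters=3, random_state=42, n_init=3)
chunk_size = 500

# Pass 1: update the centers one chunk at a time
for start in range(0, len(coords), chunk_size):
    kmeans.partial_fit(coords[start:start + chunk_size])

# Pass 2: label every frame with the final centers, again chunk by chunk
labels = np.concatenate([
    kmeans.predict(coords[start:start + chunk_size])
    for start in range(0, len(coords), chunk_size)
])
print(kmeans.cluster_centers_.round(1))
```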

The trajectory clustering script intelligently selects the appropriate clustering algorithm based on dataset size:

def perform_kmeans_clustering(pca_coords, n_clusters):
    """Perform k-means clustering on PCA coordinates"""
    
    # Use MiniBatchKMeans for large datasets to manage memory
    if len(pca_coords) > 10000:
        kmeans = MiniBatchKMeans(n_clusters=n_clusters, random_state=42, batch_size=1000)
    else:
        kmeans = KMeans(n_clusters=n_clusters, random_state=42)
    
    cluster_labels = kmeans.fit_predict(pca_coords)
    cluster_centers = kmeans.cluster_centers_
    
    return cluster_labels, cluster_centers, kmeans

The mini-batch approach works by maintaining running estimates of cluster centers and updating them with each batch of data points. Instead of recalculating centers using all data points at each iteration, the algorithm updates centers based on the current batch, weighted by the number of samples previously assigned to each cluster. This creates a natural learning rate that allows the algorithm to adapt quickly to new information while maintaining stability as more data is processed.
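The per-center learning rate described above fits in a few lines. This NumPy sketch follows the mini-batch update rule of Sculley (2010), which scikit-learn's MiniBatchKMeans documentation cites. Each center keeps a running count, and every point in the batch pulls its nearest center toward it with step size 1 / count. It is a simplified illustration with no sample weights and no reassignment of starved centers, and the function name `minibatch_kmeans_step` is hypothetical.

```python
import numpy as np


def minibatch_kmeans_step(centers, counts, batch):
    """One mini-batch update, modifying centers and counts in place."""
    # Assign every point in the batch to its nearest current center
    sq_dists = ((batch[:, None, :] - centers[None, :, :]) ** 2).sum(axis=2)
    nearest = sq_dists.argmin(axis=1)
    for point, k in zip(batch, nearest):
        counts[k] += 1
        eta = 1.0 / counts[k]  # learning rate shrinks as the center sees more data
        centers[k] = (1.0 - eta) * centers[k] + eta * point
    return nearest


rng = np.random.default_rng(4)
data = np.concatenate([rng.normal(loc=m, size=(1_000, 2)) for m in (-5.0, 5.0)])
centers = data[[0, -1]].copy()  # assumption: one starting point from each blob
counts = np.zeros(len(centers))

for _ in range(50):
    batch = data[rng.choice(len(data), size=64, replace=False)]
    minibatch_kmeans_step(centers, counts, batch)

print(centers.round(1))
```

With 1 / count as the step size, each center ends up at the running mean of every point ever assigned to it, so early batches move centers quickly and later batches only fine-tune them.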

Intelligent Memory Management: Processing What You Need, When You Need It

Beyond algorithmic improvements, careful memory management can significantly impact the feasibility of large-scale trajectory analysis. The key principle is to avoid loading more data into memory than absolutely necessary at any given moment, while ensuring that data access patterns remain efficient.

The trajectory clustering script demonstrates this principle through its chunked processing approach:

    for chunk_start in tqdm(range(0, n_frames, chunk_size), desc="Processing frame chunks"):
        chunk_end = min(chunk_start + chunk_size, n_frames)
        chunk_indices = frame_indices[chunk_start:chunk_end]
        chunk_size_actual = len(chunk_indices)

        # Pre-allocate array just for this chunk of frames
        chunk_distances = np.zeros((chunk_size_actual, n_distances))

        # Process each frame in the chunk
        for i, frame_idx in enumerate(chunk_indices):
            # Go to the specific frame
            universe.trajectory[frame_idx]

            # Get atom positions for this frame
            positions = atoms.positions

            # Calculate pairwise distances for this frame
            frame_distances = pdist(positions, metric="euclidean")

            # Store the distances for this frame in the chunk array
            chunk_distances[i] = frame_distances

        # Partial fit PCA with this chunk
        if chunk_start == 0:
            # For the first chunk, we need to fit and transform
            chunk_pca_coords = pca.fit_transform(chunk_distances)
        else:
            # For subsequent chunks, we partial_fit and transform; partial_fit needs
            # at least num_components samples, so a short final chunk is only projected
            if chunk_size_actual >= num_components:
                pca.partial_fit(chunk_distances)
            chunk_pca_coords = pca.transform(chunk_distances)

        # Store the PCA coordinates for this chunk
        pca_coords[chunk_start:chunk_end] = chunk_pca_coords

        # Free memory by deleting the chunk distances
        del chunk_distances

This approach processes trajectory frames in small groups, calculates the required distances, feeds them to the incremental PCA algorithm, and then explicitly releases the chunk before moving to the next one. Because the del statement removes the last reference to the array, CPython frees it immediately through reference counting. The old chunk is therefore released before the next one is allocated, so only one chunk of distances is held in memory at a time.
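One edge case in any chunked loop is the final chunk. With a chunk size of 100, a 1,005-frame trajectory leaves a 5-frame tail, and scikit-learn's IncrementalPCA.partial_fit rejects batches with fewer samples than n_components. Instead of skipping that tail, one option is to fold it into the previous chunk when computing chunk boundaries. The helper below is a hypothetical sketch, not part of the script.

```python
def chunk_bounds(n_frames, chunk_size, min_size):
    """Return (start, end) pairs covering range(n_frames).

    A trailing chunk shorter than min_size is merged into the chunk before it,
    so every chunk except a lone first chunk has at least min_size frames.
    """
    bounds = [(start, min(start + chunk_size, n_frames))
              for start in range(0, n_frames, chunk_size)]
    if len(bounds) > 1 and bounds[-1][1] - bounds[-1][0] < min_size:
        _, last_end = bounds.pop()
        prev_start, _ = bounds.pop()
        bounds.append((prev_start, last_end))
    return bounds


print(chunk_bounds(1_005, 100, min_size=10)[-2:])  # [(800, 900), (900, 1005)]
print(chunk_bounds(1_050, 100, min_size=10)[-1])   # (1000, 1050)
```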

For even larger datasets, memory-mapped files offer another powerful strategy. NumPy’s memmap functionality allows you to work with arrays that appear to be in memory but are actually stored on disk, with the operating system handling data transfer as needed. While not implemented in this particular script, memory mapping can be invaluable for trajectories that exceed available RAM:

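# Example of memory-mapped array usage (not from the script)
import numpy as np
n_frames, n_features = 500_000, 11_175  # example sizes (~22 GB on disk as float32)
# mode="w+" creates the file; mode="r+" reopens an existing one for reading and writing
large_array = np.memmap("trajectory_data.dat", dtype="float32",
                        mode="w+", shape=(n_frames, n_features))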

Memory mapping works particularly well for trajectory analysis because access patterns are often sequential or localized, allowing the operating system to efficiently cache relevant portions of the data.
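A memory-mapped array could back a two-pass workflow along these lines. Distances are written to disk one chunk at a time and flushed, then the file is reopened read-only and streamed back in chunks (for example into partial_fit and then transform). The file name, the sizes, and the random stand-in data are assumptions. Storing float32 halves the footprint compared with the float64 arrays the script allocates.

```python
import os
import tempfile

import numpy as np

n_frames, n_features, chunk_size = 1_000, 300, 128  # small illustrative sizes
path = os.path.join(tempfile.mkdtemp(), "distances.dat")  # hypothetical file name
rng = np.random.default_rng(5)

# Write pass: mode="w+" creates (or overwrites) the file with the given shape
store = np.memmap(path, dtype="float32", mode="w+", shape=(n_frames, n_features))
for start in range(0, n_frames, chunk_size):
    end = min(start + chunk_size, n_frames)
    store[start:end] = rng.random((end - start, n_features))  # stand-in for pdist output
store.flush()  # push pending writes to disk
del store

# Read pass: mode="r" maps the existing file read-only; the OS pages data in on access
reader = np.memmap(path, dtype="float32", mode="r", shape=(n_frames, n_features))
column_sums = np.zeros(n_features)
for start in range(0, n_frames, chunk_size):
    chunk = np.array(reader[start:start + chunk_size])  # explicit in-RAM copy of one chunk
    column_sums += chunk.sum(axis=0)

print(os.path.getsize(path))  # 1200000 (1,000 x 300 x 4 bytes)
```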

An Illustrative Example

To illustrate the power of these combined approaches, consider a typical use case: clustering a 500,000-frame trajectory of a 150-residue protein. Storing all pairwise CA distances at once would require approximately 45 GB of memory (500,000 frames × 11,175 distances per frame × 8 bytes per double). The trajectory clustering script instead holds one chunk at a time. With the default chunk size of 100 frames, each chunk of distances needs only about 9 MB, and the PCA coordinates for all frames (500,000 × 10 doubles) take about 40 MB, while the incremental fit closely approximates a full PCA.
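The arithmetic behind these figures is easy to reproduce. The sketch below assumes float64 storage (the script allocates with np.zeros, which defaults to float64), the script's default chunk size of 100 frames, and 10 principal components. It ignores the comparatively small working memory used inside IncrementalPCA and by the plotting step.

```python
BYTES_PER_FLOAT64 = 8


def n_pairwise(n_atoms):
    """Number of unique atom pairs, i.e. the length of pdist's output."""
    return n_atoms * (n_atoms - 1) // 2


n_frames, n_residues = 500_000, 150  # one CA atom per residue
chunk_size, n_components = 100, 10   # script defaults

per_frame = n_pairwise(n_residues)
full_matrix = n_frames * per_frame * BYTES_PER_FLOAT64
one_chunk = chunk_size * per_frame * BYTES_PER_FLOAT64
pca_coords = n_frames * n_components * BYTES_PER_FLOAT64

print(per_frame)          # 11175
print(full_matrix / 1e9)  # 44.7 (GB for the full distance matrix)
print(one_chunk / 1e6)    # 8.94 (MB per 100-frame chunk)
print(pca_coords / 1e6)   # 40.0 (MB for all PCA coordinates)
```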

Conclusions and Future Directions

Memory-efficient trajectory clustering represents more than just a technical optimization—it democratizes access to sophisticated conformational analysis for researchers working with large simulation datasets. By combining incremental PCA, mini-batch clustering, and intelligent memory management, we can perform rigorous analysis on trajectories that would otherwise require prohibitively expensive computational resources.

This workflow can be adapted to nonlinear dimensionality reduction techniques such as t-SNE and PHATE. The difference for these approaches is that you would first compute a linear PCA and then run the nonlinear method on the resulting coordinates. For most protein ensembles it is acceptable to transform each chunk during fitting, as the script does. For proteins whose structural variance and covariance properties change over the course of the trajectory, however, the components drift as fitting proceeds, so one should fit across the whole dataset first and then transform in a second pass. The memory-mapped files described above would be ideal for this task.
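A minimal sketch of the PCA-then-nonlinear route uses scikit-learn's TSNE on the saved PCA coordinates. t-SNE becomes slow for hundreds of thousands of points, so the sketch embeds a random subsample of frames. The subsample size, the perplexity, and the random stand-in for the script's data/pca_coordinates.npy output are assumptions. PHATE would slot in the same way through its own package, which is not shown here.

```python
import numpy as np
from sklearn.manifold import TSNE

rng = np.random.default_rng(6)
# Stand-in for np.load("data/pca_coordinates.npy"): 5,000 frames x 10 PCs
pca_coords = rng.normal(size=(5_000, 10))

# Embed a random subsample of frames to keep t-SNE tractable
subsample = rng.choice(len(pca_coords), size=500, replace=False)
embedding = TSNE(n_components=2, perplexity=30.0, random_state=0).fit_transform(
    pca_coords[subsample]
)
print(embedding.shape)  # (500, 2)
```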

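Drift in the components during fitting can also be reduced by increasing the chunk (batch) size or by ensuring adequate mixing of the training data, for example by shuffling the frame order so that every chunk samples the whole ensemble.

Here, we show mini-batch k-means, but HDBSCAN is an alternative that does not need the number of clusters in advance and typically requires fewer rounds of parameter tuning to give robust clusters. This may not reduce memory further, but it may reduce the total computational load.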
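For the incremental PCA fit, the data mixing mentioned above can be done without loading everything at once. Shuffle the frame indices before chunking, so that each chunk passed to partial_fit draws frames from across the whole trajectory rather than from one contiguous stretch. The sketch below uses a synthetic feature array whose mean drifts over "time" as a stand-in for a trajectory whose conformational sampling changes. The shuffle seed, sizes, and chunk size are assumptions.

```python
import numpy as np
from sklearn.decomposition import IncrementalPCA

rng = np.random.default_rng(7)
n_frames, n_features, n_components, chunk_size = 2_000, 20, 3, 200

# Synthetic "trajectory" whose mean drifts along feature 0 over time
time = np.linspace(0.0, 1.0, n_frames)
features = rng.normal(size=(n_frames, n_features))
features[:, 0] += 10.0 * time

# Shuffle frame order once, then chunk the shuffled indices
order = rng.permutation(n_frames)
pca = IncrementalPCA(n_components=n_components)
for start in range(0, n_frames, chunk_size):
    idx = np.sort(order[start:start + chunk_size])  # visit each chunk's frames in trajectory order
    pca.partial_fit(features[idx])

# Projection is row-by-row, so frames can be transformed in their original order
coords = pca.transform(features)
print(coords.shape)  # (2000, 3)
```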

As molecular simulations continue to grow in scale and ambition, these memory-efficient approaches will become increasingly essential. The principles demonstrated in this trajectory clustering script—processing data incrementally, using online algorithms, and managing memory explicitly—provide a foundation for tackling even larger challenges in computational structural biology.

The next time you face a trajectory that seems too large to analyze, remember that size doesn’t have to be a barrier. With the right algorithmic tools and memory management strategies, even the most ambitious datasets can yield their secrets, one manageable chunk at a time.

Complete script:

"""
This script takes in a topology and a list of trajectories and k-means clusters the ensemble down to the specified number of clusters.

The script takes in the following args:

- topology_path (str)
- trajectory_paths (list[str])
- atom_selection (str, default: "name CA")
- chunk_size (int, default: 100) # chunk size for memory efficiency when calculating pairwise distances and running PCA
- number_of_clusters (int, default: 500)
- num_components (int, default: 10)
- output_dir (str, default: next to script inside a directory labelled with the topology name, number of clusters and the time)
- save_pdbs (bool, default: False) # Whether to save individual PDB files for each cluster
- log (bool, default: True)

The script performs a memory-efficient incremental PCA on the pairwise distances computed with pdist. Using these reduced coordinates, the clusters are picked with k-means.

The PCA projection is then plotted as a density contour map overlaid with a scatter plot of the clusters. The frame closest to each cluster center is saved to a single XTC file in the output directory, together with a CSV mapping every frame to its cluster.

The script also saves the arguments and logging information using a logger.

"""

import argparse
import datetime
import logging
import os

import matplotlib.pyplot as plt
import MDAnalysis as mda
import numpy as np
import seaborn as sns
from matplotlib.gridspec import GridSpec
from scipy.spatial.distance import pdist
from sklearn.cluster import KMeans, MiniBatchKMeans
from sklearn.decomposition import IncrementalPCA
from tqdm import tqdm


def setup_logger(log_file):
    """Set up the logger to output to both file and console"""
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    # Clear existing handlers if any
    for handler in logger.handlers[:]:
        logger.removeHandler(handler)

    # File handler
    file_handler = logging.FileHandler(log_file)
    file_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
    file_handler.setFormatter(file_formatter)
    logger.addHandler(file_handler)

    # Console handler
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(file_formatter)
    logger.addHandler(console_handler)

    return logger


def parse_arguments():
    """Parse command line arguments"""
    parser = argparse.ArgumentParser(description="Cluster molecular dynamics trajectories")
    parser.add_argument("--topology_path", type=str, required=True, help="Path to topology file")
    parser.add_argument(
        "--trajectory_paths", nargs="+", type=str, required=True, help="Paths to trajectory files"
    )
    parser.add_argument(
        "--atom_selection",
        type=str,
        default="name CA",
        help='Atom selection string (default: "name CA")',
    )
    parser.add_argument(
        "--chunk_size",
        type=int,
        default=100,
        help="Chunk size for memory efficiency (default: 100)",
    )
    parser.add_argument(
        "--number_of_clusters",
        type=int,
        default=500,
        help="Number of clusters for k-means (default: 500)",
    )
    parser.add_argument(
        "--num_components", type=int, default=10, help="Number of PCA components (default: 10)"
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="Output directory (default: auto-generated based on topology name and time)",
    )
    parser.add_argument(
        "--save_pdbs",
        action="store_true",
        help="Save individual PDB files for each cluster (default: False)",
    )
    parser.add_argument(
        "--log", action="store_true", default=True, help="Enable logging (default: True)"
    )

    return parser.parse_args()


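def calculate_pairwise_rmsd(universe, selection, chunk_size):
    """Non-chunked variant, not called by main(): holds every frame's distances in memory.

    Despite its name, this computes per-frame pairwise atomic distances, not RMSD.
    """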
    logger.info("Calculating pairwise atomic coordinate distances within frames...")

    # Select atoms for distance calculation
    atoms = universe.select_atoms(selection)
    n_frames = len(universe.trajectory)
    n_atoms = atoms.n_atoms
    n_distances = n_atoms * (n_atoms - 1) // 2
    logger.info(f"Selected {n_atoms} atoms, processing {n_frames} frames")

    # Pre-allocate array for all pairwise distances across all frames
    all_distances = np.zeros((n_frames, n_distances))

    # Process each frame
    for i, ts in enumerate(tqdm(universe.trajectory, desc="Processing frames")):
        # Get atom positions for this frame
        positions = atoms.positions

        # Calculate pairwise distances for this frame only
        frame_distances = pdist(positions, metric="euclidean")

        # Store the distances for this frame
        all_distances[i] = frame_distances

    logger.info(f"Generated distances for {all_distances.shape[0]} frames")

    return all_distances


def perform_pca_on_distances(distances, num_components):
    """Perform PCA on the distance matrix using IncrementalPCA with appropriate batch size"""
    logger.info(f"Performing PCA with {num_components} components...")

    # Set batch size to be at least 10 times the number of components as suggested
    ipca_batch_size = max(num_components * 10, 100)

    # Initialize IncrementalPCA with proper batch size
    pca = IncrementalPCA(n_components=num_components, batch_size=ipca_batch_size)

    # Fit PCA model and transform the data
    pca_coords = pca.fit_transform(distances)

    logger.info(f"Explained variance ratio: {pca.explained_variance_ratio_}")
    logger.info(f"Total variance explained: {sum(pca.explained_variance_ratio_):.2%}")

    return pca_coords, pca


def perform_kmeans_clustering(pca_coords, n_clusters):
    """Perform k-means clustering on PCA coordinates"""
    logger.info(f"Performing k-means clustering with {n_clusters} clusters...")

    # Use MiniBatchKMeans for large datasets
    if len(pca_coords) > 10000:
        kmeans = MiniBatchKMeans(n_clusters=n_clusters, random_state=42, batch_size=1000)
    else:
        kmeans = KMeans(n_clusters=n_clusters, random_state=42)

    cluster_labels = kmeans.fit_predict(pca_coords)
    cluster_centers = kmeans.cluster_centers_

    # Count frames per cluster
    unique_labels, counts = np.unique(cluster_labels, return_counts=True)

    logger.info(f"Clustering complete: {len(unique_labels)} clusters")
    logger.info(f"Average frames per cluster: {np.mean(counts):.1f}")
    logger.info(f"Min frames per cluster: {np.min(counts)}")
    logger.info(f"Max frames per cluster: {np.max(counts)}")

    return cluster_labels, cluster_centers, kmeans


def create_publication_plots(pca_coords, cluster_labels, cluster_centers, pca, output_dir):
    """Create publication-quality plots of the PCA results and clustering"""
    plots_dir = os.path.join(output_dir, "plots")
    os.makedirs(plots_dir, exist_ok=True)

    # Set style for publication-quality plots
    plt.style.use("default")  # Use the default style or choose another valid style
    plt.rcParams["font.family"] = "sans-serif"
    plt.rcParams["font.sans-serif"] = ["Arial", "Helvetica", "DejaVu Sans"]
    plt.rcParams["font.size"] = 12
    plt.rcParams["axes.linewidth"] = 1.5
    plt.rcParams["axes.edgecolor"] = "black"

    # 1. Create main PCA plot with clusters
    logger.info("Creating PCA projection plot with clusters...")

    # Figure setup with gridspec for complex layout
    fig = plt.figure(figsize=(18, 15))
    gs = GridSpec(3, 3, figure=fig, height_ratios=[1, 3, 1])

    # Main PCA scatter plot
    ax_main = fig.add_subplot(gs[1, :2])

    # Calculate point density for contour plot
    x, y = pca_coords[:, 0], pca_coords[:, 1]
    sns.kdeplot(x=x, y=y, ax=ax_main, levels=20, cmap="Blues", fill=True, alpha=0.5, zorder=0)

    # Scatter plot of frames colored by cluster
    scatter = ax_main.scatter(
        x, y, c=cluster_labels, cmap="viridis", s=20, alpha=0.7, zorder=10, edgecolor="none"
    )

    # Plot cluster centers
    ax_main.scatter(
        cluster_centers[:, 0],
        cluster_centers[:, 1],
        c="red",
        s=80,
        marker="X",
        edgecolors="black",
        linewidths=1.5,
        zorder=20,
        label="Cluster Centers",
    )

    # Labels and title
    variance_pc1 = pca.explained_variance_ratio_[0] * 100
    variance_pc2 = pca.explained_variance_ratio_[1] * 100
    ax_main.set_xlabel(f"PC1 ({variance_pc1:.1f}% variance)", fontsize=14)
    ax_main.set_ylabel(f"PC2 ({variance_pc2:.1f}% variance)", fontsize=14)
    ax_main.set_title("PCA Projection with K-means Clustering", fontsize=16, pad=20)

    # Add histograms for PC1 distribution
    ax_top = fig.add_subplot(gs[0, :2], sharex=ax_main)
    sns.histplot(x=x, stat="density", kde=True, ax=ax_top, color="darkblue", alpha=0.6)
    ax_top.set_ylabel("Density", fontsize=12)
    ax_top.set_title("PC1 Distribution", fontsize=14)
    ax_top.tick_params(labelbottom=False)

    # Add histograms for PC2 distribution (passing y= already draws horizontal bars)
    ax_right = fig.add_subplot(gs[1, 2], sharey=ax_main)
    sns.histplot(y=y, stat="density", kde=True, ax=ax_right, color="darkblue", alpha=0.6)
    ax_right.set_xlabel("Density", fontsize=12)
    ax_right.set_title("PC2 Distribution", fontsize=14)
    ax_right.tick_params(labelleft=False)

    # Create explained variance ratio plot
    ax_var = fig.add_subplot(gs[2, :])
    components = range(1, len(pca.explained_variance_ratio_) + 1)
    cumulative = np.cumsum(pca.explained_variance_ratio_)

    # Plot individual and cumulative explained variance
    bars = ax_var.bar(
        components, pca.explained_variance_ratio_, color="steelblue", alpha=0.7, label="Individual"
    )

    ax_var2 = ax_var.twinx()
    line = ax_var2.plot(
        components,
        cumulative,
        "o-",
        color="firebrick",
        linewidth=2.5,
        markersize=8,
        label="Cumulative",
    )

    # Add explained variance labels
    ax_var.set_xlabel("Principal Component", fontsize=14)
    ax_var.set_ylabel("Explained Variance Ratio", fontsize=14)
    ax_var2.set_ylabel("Cumulative Explained Variance", fontsize=14)
    ax_var.set_title("Explained Variance by Principal Components", fontsize=16, pad=20)

    # Set x-axis to integers
    ax_var.set_xticks(components)
    ax_var2.set_ylim([0, 1.05])

    # Combine legends
    lines, labels = ax_var.get_legend_handles_labels()
    lines2, labels2 = ax_var2.get_legend_handles_labels()
    ax_var.legend(lines + lines2, labels + labels2, loc="upper left", fontsize=12)

    # Add colorbar for cluster labels
    cbar_ax = fig.add_axes([0.92, 0.15, 0.02, 0.7])
    cbar = plt.colorbar(scatter, cax=cbar_ax)
    cbar.set_label("Cluster Label", fontsize=14, labelpad=15)

    # Save the figure
    plt.tight_layout()
    pca_plot_path = os.path.join(plots_dir, "pca_clusters.png")
    plt.savefig(pca_plot_path, dpi=300, bbox_inches="tight")
    plt.close()

    # 2. Create 3D PCA plot if we have at least 3 components
    if pca_coords.shape[1] >= 3:
        logger.info("Creating 3D PCA plot...")
        fig = plt.figure(figsize=(12, 10))
        ax = fig.add_subplot(111, projection="3d")

        scatter = ax.scatter(
            pca_coords[:, 0],
            pca_coords[:, 1],
            pca_coords[:, 2],
            c=cluster_labels,
            cmap="viridis",
            s=30,
            alpha=0.7,
        )

        ax.scatter(
            cluster_centers[:, 0],
            cluster_centers[:, 1],
            cluster_centers[:, 2],
            c="red",
            s=100,
            marker="X",
            edgecolors="black",
            linewidths=1.5,
        )

        variance_pc3 = pca.explained_variance_ratio_[2] * 100
        ax.set_xlabel(f"PC1 ({variance_pc1:.1f}% variance)", fontsize=12)
        ax.set_ylabel(f"PC2 ({variance_pc2:.1f}% variance)", fontsize=12)
        ax.set_zlabel(f"PC3 ({variance_pc3:.1f}% variance)", fontsize=12)
        ax.set_title("3D PCA Projection with Clusters", fontsize=16)

        plt.colorbar(scatter, ax=ax, label="Cluster Label")
        plt.tight_layout()

        pca_3d_path = os.path.join(plots_dir, "pca_3d.png")
        plt.savefig(pca_3d_path, dpi=300, bbox_inches="tight")
        plt.close()

    # 3. Create cluster size distribution plot
    logger.info("Creating cluster size distribution plot...")
    unique_labels, counts = np.unique(cluster_labels, return_counts=True)

    plt.figure(figsize=(14, 8))
    sns.histplot(counts, kde=True, color="steelblue")
    plt.axvline(np.mean(counts), color="red", linestyle="--", label=f"Mean: {np.mean(counts):.1f}")
    plt.axvline(
        np.median(counts), color="green", linestyle="--", label=f"Median: {np.median(counts):.1f}"
    )

    plt.xlabel("Frames per Cluster", fontsize=14)
    plt.ylabel("Frequency", fontsize=14)
    plt.title("Distribution of Cluster Sizes", fontsize=16)
    plt.legend(fontsize=12)
    plt.tight_layout()

    cluster_dist_path = os.path.join(plots_dir, "cluster_distribution.png")
    plt.savefig(cluster_dist_path, dpi=300)
    plt.close()

    return {
        "pca_plot": pca_plot_path,
        "pca_3d_plot": pca_3d_path if pca_coords.shape[1] >= 3 else None,
        "cluster_dist": cluster_dist_path,
    }


def save_cluster_trajectories(
    universe, cluster_labels, pca_coords, cluster_centers, output_dir, save_pdbs=False
):
    """Save cluster trajectories to a single XTC file and optionally save representative PDB files

    Parameters:
    -----------
    universe : MDAnalysis.Universe
        The universe containing all frames
    cluster_labels : numpy.ndarray
        Array of cluster labels for each frame
    pca_coords : numpy.ndarray
        PCA coordinates for each frame
    cluster_centers : numpy.ndarray
        K-means computed cluster centers in PCA space
    output_dir : str
        Output directory path
    save_pdbs : bool, optional
        Whether to save individual PDB files for each cluster representative
    """
    logger.info("Saving cluster trajectories...")

    clusters_dir = os.path.join(output_dir, "clusters")
    os.makedirs(clusters_dir, exist_ok=True)

    # Get unique clusters
    unique_clusters = np.unique(cluster_labels)
    n_clusters = len(unique_clusters)

    # Find the frame in each cluster that is closest to the true cluster center
    representative_frames = {}

    for cluster_idx in unique_clusters:
        # Get indices of frames in this cluster
        cluster_mask = cluster_labels == cluster_idx
        if not np.any(cluster_mask):
            continue

        cluster_frame_indices = np.where(cluster_mask)[0]

        # Get PCA coordinates for frames in this cluster
        cluster_pca_coords = pca_coords[cluster_mask]

        # Get the center for this cluster
        center = cluster_centers[cluster_idx]

        # Calculate distances from each frame to the cluster center
        distances = np.sqrt(np.sum((cluster_pca_coords - center) ** 2, axis=1))

        # Find the frame with minimum distance to center
        min_dist_idx = np.argmin(distances)

        # Get the original frame index
        representative_frame_idx = cluster_frame_indices[min_dist_idx]

        # Store the representative frame
        representative_frames[cluster_idx] = representative_frame_idx

    # Create a single trajectory file with cluster centers only
    all_clusters_file = os.path.join(clusters_dir, "all_clusters.xtc")
    with mda.Writer(all_clusters_file, universe.atoms.n_atoms) as writer:
        # Go through each cluster and save only the representative frame
        for cluster_idx in tqdm(unique_clusters, desc="Saving cluster centers"):
            if cluster_idx in representative_frames:
                frame_idx = representative_frames[cluster_idx]
                universe.trajectory[frame_idx]
                writer.write(universe.atoms)

    # Optionally save representative frames as PDB files
    if save_pdbs:
        for cluster_idx, frame_idx in tqdm(representative_frames.items(), desc="Saving PDB files"):
            universe.trajectory[frame_idx]
            with mda.Writer(
                os.path.join(clusters_dir, f"cluster_{cluster_idx}_rep.pdb")
            ) as pdb_writer:
                pdb_writer.write(universe.atoms)

    # Also save a CSV file mapping frame to cluster
    frame_to_cluster = np.column_stack((np.arange(len(cluster_labels)), cluster_labels))
    np.savetxt(
        os.path.join(clusters_dir, "frame_to_cluster.csv"),
        frame_to_cluster,
        delimiter=",",
        header="frame_index,cluster_label",
        fmt="%d",
        comments="",
    )

    logger.info(
        f"Saved {n_clusters} true cluster centers to a single trajectory file: {all_clusters_file}"
    )
    if save_pdbs:
        logger.info(f"Saved {len(representative_frames)} representative PDB files")
    logger.info("Saved frame-to-cluster mapping to frame_to_cluster.csv")


def calculate_distances_and_perform_pca(universe, selection, num_components, chunk_size):
    """Calculate pairwise distances and perform incremental PCA in chunks"""
    logger.info("Calculating pairwise distances and performing incremental PCA...")

    # Select atoms for distance calculation
    atoms = universe.select_atoms(selection)
    n_frames = len(universe.trajectory)
    n_atoms = atoms.n_atoms
    n_distances = n_atoms * (n_atoms - 1) // 2

    logger.info(f"Selected {n_atoms} atoms, processing {n_frames} frames")
    logger.info(f"Each frame will generate {n_distances} pairwise distances")

    # Initialize IncrementalPCA
    ipca_batch_size = max(num_components * 10, 100)
    pca = IncrementalPCA(n_components=num_components, batch_size=ipca_batch_size)

    # Process frames in chunks
    frame_indices = np.arange(n_frames)

    # Store PCA coordinates for all frames
    pca_coords = np.zeros((n_frames, num_components))

    for chunk_start in tqdm(range(0, n_frames, chunk_size), desc="Processing frame chunks"):
        chunk_end = min(chunk_start + chunk_size, n_frames)
        chunk_indices = frame_indices[chunk_start:chunk_end]
        chunk_size_actual = len(chunk_indices)

        # Pre-allocate array just for this chunk of frames
        chunk_distances = np.zeros((chunk_size_actual, n_distances))

        # Process each frame in the chunk
        for i, frame_idx in enumerate(chunk_indices):
            # Go to the specific frame
            universe.trajectory[frame_idx]

            # Get atom positions for this frame
            positions = atoms.positions

            # Calculate pairwise distances for this frame
            frame_distances = pdist(positions, metric="euclidean")

            # Store the distances for this frame in the chunk array
            chunk_distances[i] = frame_distances

        # Partial fit PCA with this chunk
        if chunk_start == 0:
            # For the first chunk, we need to fit and transform
            chunk_pca_coords = pca.fit_transform(chunk_distances)
        else:
            # For subsequent chunks, we partial_fit and transform; partial_fit needs
            # at least num_components samples, so a short final chunk is only projected
            if chunk_size_actual >= num_components:
                pca.partial_fit(chunk_distances)
            chunk_pca_coords = pca.transform(chunk_distances)

        # Store the PCA coordinates for this chunk
        pca_coords[chunk_start:chunk_end] = chunk_pca_coords

        # Free memory by deleting the chunk distances
        del chunk_distances

    logger.info(f"Completed incremental PCA with {num_components} components")
    logger.info(f"Explained variance ratio: {pca.explained_variance_ratio_}")
    logger.info(f"Total variance explained: {sum(pca.explained_variance_ratio_):.2%}")

    return pca_coords, pca


def main():
    # Parse command line arguments
    args = parse_arguments()

    # Start timer
    start_time = datetime.datetime.now()

    # Set up output directory
    if args.output_dir is None:
        topology_name = os.path.splitext(os.path.basename(args.topology_path))[0]
        timestamp = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
        args.output_dir = os.path.join(
            os.path.dirname(os.path.abspath(__file__)),
            f"{topology_name}_clusters{args.number_of_clusters}_{timestamp}",
        )

    # Create the output directory
    os.makedirs(args.output_dir, exist_ok=True)

    # Set up logging
    global logger
    log_file = os.path.join(args.output_dir, "cluster_trajectory.log")
    logger = setup_logger(log_file)

    logger.info("=" * 80)
    logger.info("Starting trajectory clustering")
    logger.info("=" * 80)
    logger.info(f"Arguments: {vars(args)}")

    # Load universe
    logger.info(f"Loading topology from {args.topology_path}")
    logger.info(f"Loading trajectories: {args.trajectory_paths}")
    universe = mda.Universe(args.topology_path, *args.trajectory_paths)
    logger.info(
        f"Loaded universe with {len(universe.atoms)} atoms and {len(universe.trajectory)} frames"
    )

    # Calculate distances and perform PCA
    pca_coords, pca = calculate_distances_and_perform_pca(
        universe, args.atom_selection, args.num_components, args.chunk_size
    )

    # Perform k-means clustering
    cluster_labels, cluster_centers, kmeans = perform_kmeans_clustering(
        pca_coords, args.number_of_clusters
    )

    # Create publication-quality plots
    plot_paths = create_publication_plots(
        pca_coords, cluster_labels, cluster_centers, pca, args.output_dir
    )

    # Save cluster trajectories
    save_cluster_trajectories(
        universe,
        cluster_labels,
        pca_coords,
        cluster_centers,
        args.output_dir,
        args.save_pdbs,
    )

    # Save analysis data
    data_dir = os.path.join(args.output_dir, "data")
    os.makedirs(data_dir, exist_ok=True)

    np.save(os.path.join(data_dir, "pca_coordinates.npy"), pca_coords)
    np.save(os.path.join(data_dir, "cluster_labels.npy"), cluster_labels)
    np.save(os.path.join(data_dir, "cluster_centers.npy"), cluster_centers)

    # End timer and report
    end_time = datetime.datetime.now()
    elapsed = end_time - start_time
    logger.info("=" * 80)
    logger.info(f"Clustering complete. Results saved to {args.output_dir}")
    logger.info(f"Total execution time: {elapsed}")
    logger.info("=" * 80)


if __name__ == "__main__":
    main()
