Advanced PyMOL Visualization for Weighted Structural Ensembles (Part 2): Efficient Weighted SASA Surfaces

In Part 1, we covered reference state handling, RMSD-based coloring, and cluster visualization for weighted structural ensembles. Now we tackle a more ambitious goal: generating solvent-accessible surface area (SASA) surfaces that reflect the weighted conformational distribution of your ensemble.

Why surfaces? Because they show the accessible conformational space—where your protein can actually be found, weighted by population. This is particularly powerful when comparing different fitting methods or showing how experimental constraints reshape the ensemble.

The challenge? A typical ensemble might have 500+ frames, each generating thousands of surface points. Naive approaches choke on the computational and memory demands. This post shares the optimizations that make weighted SASA visualization practical.

Why SASA Surfaces for Ensemble Visualization

Standard ensemble visualization overlays structures as ribbons or cartoons. This works for small ensembles but becomes an unreadable mess beyond ~20 structures. Surfaces offer an alternative: instead of showing individual structures, show the envelope of accessible conformations.

With weighted ensembles, we can go further. Rather than treating all frames equally, we weight surface points by their frame’s population weight. High-weight regions appear denser; low-weight regions fade away. The result is a surface that reflects the experimentally-consistent conformational distribution.

(Figure) Left: PCA plots of MD at pH 4 (top, blue) and pH 7 (bottom, grey). Right: weighted-ensemble SASA envelopes. Teal: pH 4 (disease-like); grey: pH 7 (standard).

The workflow is:

  1. Compute per-atom SASA for each frame
  2. Generate surface points around exposed atoms
  3. Assign weights to points based on their source frame
  4. Filter by weighted local density to show the high-confidence envelope
  5. Render as mesh, spheres, or points

Computing Per-Atom SASA with FreeSASA

We use FreeSASA for fast SASA calculations, with a fallback to van der Waals estimates if FreeSASA isn’t available.

from Bio.PDB import PDBParser
import numpy as np

def compute_sasa_biopython(pdb_file, align_residues=None):
    """
    Compute per-atom SASA values for a structure.
    Returns coords, SASA values, and alignment coords.
    """
    parser = PDBParser(QUIET=True)
    structure = parser.get_structure('protein', pdb_file)
    
    # Get alignment coordinates (for later superposition)
    align_coords = []
    for model in structure:
        for chain in model:
            for residue in chain:
                resseq = residue.get_id()[1]
                if align_residues and resseq not in align_residues:
                    continue
                if 'CA' in residue:
                    align_coords.append(residue['CA'].get_coord())
    
    # Try FreeSASA first
    try:
        import freesasa
        import contextlib
        import os
        
        # Suppress FreeSASA's verbose output (and close the devnull
        # handle when done, instead of leaking it)
        with open(os.devnull, 'w') as devnull, \
             contextlib.redirect_stdout(devnull), \
             contextlib.redirect_stderr(devnull):
            sasa_result = freesasa.calc(freesasa.Structure(pdb_file))
        
        atom_sasa = []
        atom_coords = []
        # Result.atomArea takes an integer atom index (file order), so
        # walk the Bio.PDB atoms in the same order FreeSASA read them.
        # NOTE: this assumes both parsers see the same atoms (e.g. no
        # hydrogens or HETATMs skipped by only one of them).
        atom_index = 0
        for model in structure:
            for chain in model:
                for residue in chain:
                    for atom in residue:
                        atom_sasa.append(sasa_result.atomArea(atom_index))
                        atom_coords.append(atom.get_coord())
                        atom_index += 1
        
        return {
            'coords': np.array(atom_coords),
            'sasa': np.array(atom_sasa),
            'align_coords': np.array(align_coords)
        }
    
    except ImportError:
        # Fallback: estimate from van der Waals radii
        vdw_radii = {'C': 1.70, 'N': 1.55, 'O': 1.52, 'S': 1.80, 'H': 1.20}
        probe_radius = 1.4
        
        atom_sasa = []
        atom_coords = []
        for model in structure:
            for chain in model:
                for residue in chain:
                    for atom in residue:
                        elem = atom.element if atom.element else 'C'
                        radius = vdw_radii.get(elem, 1.70)
                        # Approximate SASA as sphere surface area
                        sasa = 4 * np.pi * (radius + probe_radius) ** 2
                        atom_sasa.append(sasa)
                        atom_coords.append(atom.get_coord())
        
        return {
            'coords': np.array(atom_coords),
            'sasa': np.array(atom_sasa),
            'align_coords': np.array(align_coords)
        }
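Note that the fallback treats every atom as fully exposed, so it is an upper bound rather than a real SASA estimate — fine for quick visual checks, misleading for anything quantitative. For a carbon atom with the standard 1.4 Å water probe, the bound works out to:

```python
import numpy as np

# Upper-bound "SASA" the fallback assigns to a fully exposed carbon atom
vdw_c, probe = 1.70, 1.4
sasa_carbon = 4 * np.pi * (vdw_c + probe) ** 2
print(round(sasa_carbon, 1))  # 120.8 (Å²); truly buried atoms are near zero
```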

For an ensemble, we process each frame:

def compute_sasa_ensemble(pdb_files):
    """Compute SASA for all structures in ensemble."""
    all_sasa_data = []
    
    for pdb_file in pdb_files:
        sasa_data = compute_sasa_biopython(pdb_file)
        if sasa_data:
            all_sasa_data.append(sasa_data)
    
    return all_sasa_data

Generating Surface Geometry from SASA Values

For each atom with significant SASA, we generate points on a sphere around it, with the point count scaled by its exposed surface area. This creates a point cloud approximating the molecular surface.

def get_raw_surface_geometry(sasa_data_list, probe_radius=1.4, 
                             sasa_threshold=0.0, frame_subsample=500):
    """
    Generate surface points from SASA data.
    
    Args:
        sasa_data_list: List of SASA data dicts from compute_sasa_ensemble
        probe_radius: Probe radius used in SASA calculation
        sasa_threshold: Minimum SASA to include an atom
        frame_subsample: Max points per frame (controls memory usage)
    
    Returns:
        surface_points: (N, 3) array of surface coordinates
        frame_indices: (N,) array mapping each point to its source frame
    """
    all_surface_points = []
    all_frame_indices = []
    
    for frame_idx, data in enumerate(sasa_data_list):
        coords = data['coords']
        sasa_values = data['sasa']
        frame_points = []
        
        # Only process atoms with significant SASA
        valid_mask = sasa_values >= max(sasa_threshold, 0.01)
        valid_coords = coords[valid_mask]
        valid_sasa = sasa_values[valid_mask]
        
        for coord, sasa in zip(valid_coords, valid_sasa):
            # Surface radius from SASA: A = 4πr² → r = √(A/4π)
            surface_radius = np.sqrt(sasa / (4 * np.pi)) + probe_radius
            
            # Grid resolution scales with surface area (the total point
            # count grows roughly quadratically with n_points)
            n_points = max(2, int(sasa / 5))
            
            # Generate spherical coordinates
            phi = np.linspace(0, 2 * np.pi, n_points)
            theta = np.linspace(0, np.pi, max(2, n_points // 2))
            P, T = np.meshgrid(phi, theta)
            
            # Convert to Cartesian
            X = coord[0] + surface_radius * np.sin(T) * np.cos(P)
            Y = coord[1] + surface_radius * np.sin(T) * np.sin(P)
            Z = coord[2] + surface_radius * np.cos(T)
            
            points = np.column_stack((X.flatten(), Y.flatten(), Z.flatten()))
            frame_points.extend(points)
        
        # Subsample to control memory
        if frame_points:
            frame_points = np.asarray(frame_points)
            if len(frame_points) > frame_subsample:
                indices = np.random.choice(len(frame_points), 
                                          frame_subsample, replace=False)
                frame_points = frame_points[indices]
            
            all_surface_points.extend(frame_points)
            all_frame_indices.extend([frame_idx] * len(frame_points))
    
    return (np.array(all_surface_points, dtype=np.float32),
            np.array(all_frame_indices, dtype=np.int32))

For a 500-frame ensemble with 500 points per frame, this generates ~250,000 surface points—manageable, but we need to be smart about what comes next.
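As a sanity check on those numbers (assuming float32 x/y/z per point, as in the code above):

```python
n_frames, frame_subsample = 500, 500
n_points = n_frames * frame_subsample      # 250,000 surface points
mem_mb = n_points * 3 * 4 / 1e6            # 3 float32 coords x 4 bytes each
print(n_points, mem_mb)  # 250000 3.0
```

So the raw point cloud itself is cheap; it's the pairwise density step later that needs care.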

Caching Intermediate Data

SASA computation is expensive. Computing it fresh every time you tweak visualization parameters is painful. The solution: cache the surface geometry and frame indices.

import os

def get_raw_surface_geometry_cached(sasa_data_list, cache_file=None, **kwargs):
    """
    Generate surface geometry with optional caching.
    """
    # Try loading from cache
    if cache_file and os.path.exists(cache_file):
        try:
            cached = np.load(cache_file)
            return (cached['surface_points'].astype(np.float32),
                    cached['frame_indices'].astype(np.int32))
        except Exception as e:
            print(f"Cache load failed: {e}")
    
    # Compute fresh
    points, indices = get_raw_surface_geometry(sasa_data_list, **kwargs)
    
    # Save cache
    if cache_file and points is not None:
        try:
            np.savez_compressed(cache_file, 
                               surface_points=points,
                               frame_indices=indices)
        except Exception as e:
            print(f"Cache save failed: {e}")
    
    return points, indices

I use a naming convention that encodes the parameters:

# Cache file includes ensemble name and trajectory base name
cache_file = f"{cluster_name}_{traj_base}_uniform_aligned_surface_cache.npz"

# For weighted surfaces, include the weights file identifier
weighted_cache = f"{cluster_name}_{traj_base}_weighted_aligned_surface_cache.npz"

This way, changing your weights file triggers a recompute, but re-running the same analysis loads instantly.
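One caveat: the names above only distinguish "uniform" from "weighted", so overwriting a weights file in place would silently reuse a stale cache. A hypothetical cache_name helper (not part of the pipeline above) that hashes the weights file contents closes that gap:

```python
import hashlib

def cache_name(cluster_name, traj_base, weights_file=None):
    """Build a cache filename that changes whenever the weights change.
    Hypothetical helper: hashes the weights file contents so modified
    weights can never silently reuse a stale cache."""
    if weights_file is None:
        tag = "uniform"
    else:
        with open(weights_file, "rb") as f:
            tag = "w" + hashlib.sha1(f.read()).hexdigest()[:8]
    return f"{cluster_name}_{traj_base}_{tag}_aligned_surface_cache.npz"
```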

Weighted Density Filtering: The Key Insight

Here’s where weighted ensembles get interesting. We have surface points, and each point has a weight from its source frame. But raw point clouds are noisy—we want to show the high-confidence regions where multiple high-weight frames agree.
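Mapping frame weights onto surface points is a single NumPy fancy-indexing step. A toy example with three frames and five points:

```python
import numpy as np

frame_weights = np.array([0.7, 0.2, 0.1])     # per-frame populations (sum to 1)
frame_indices = np.array([0, 0, 1, 2, 1])     # source frame of each surface point
point_weights = frame_weights[frame_indices]  # one weight per point
print(point_weights)  # [0.7 0.7 0.2 0.1 0.2]
```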

The approach: compute weighted local density for each point, then filter to keep only points above a percentile threshold.

def filter_surface_density(points, point_weights, ci_percentile=95, 
                           density_radius=2.0):
    """
    Filter surface points by weighted local density.
    
    Points in regions where many high-weight frames contribute
    will have high weighted density and be retained.
    
    Args:
        points: (N, 3) surface coordinates
        point_weights: (N,) weight for each point
        ci_percentile: Keep points above this density percentile
        density_radius: Radius for local density calculation (Angstroms)
    
    Returns:
        Filtered points array
    """
    # Dispatches to one of the implementations shown in the next section
    from sasa_workers import compute_weighted_densities
    
    densities = compute_weighted_densities(
        points, point_weights, 
        radius=density_radius
    )
    
    # Higher ci_percentile = higher density cutoff = tighter envelope
    threshold = np.percentile(densities, ci_percentile)
    return points[densities >= threshold]

The ci_percentile parameter controls how much of the surface you show:

  • 95%: Show only the highest-density core (tight envelope)
  • 85%: Show moderate density regions (broader coverage)
  • 50%: Show everything above median density

For comparing uniform vs weighted ensembles, I typically use 85% to see meaningful differences without too much noise.
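To make the semantics concrete — with the higher-is-tighter convention described above, ci_percentile=95 keeps roughly the top 5% of points by weighted density. A quick check on synthetic uniform densities:

```python
import numpy as np

rng = np.random.default_rng(0)
densities = rng.random(100_000)  # stand-in for weighted local densities

for ci in (95, 85, 50):
    threshold = np.percentile(densities, ci)
    kept = float(np.mean(densities >= threshold))
    print(ci, round(kept, 2))
# 95 -> ~0.05, 85 -> ~0.15, 50 -> ~0.5
```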

Computational Trade-offs: Voxel Grid vs KD-Tree

Computing local density for 250,000 points is expensive. You need to find all neighbors within density_radius for each point—a classic spatial query problem. I’ve implemented two approaches with different trade-offs.

KD-Tree with Query Pairs (Recommended)

SciPy’s cKDTree.query_pairs finds all point pairs within a radius in a single call, keeping the neighbor search in compiled code rather than a Python loop:

from scipy.spatial import cKDTree
import numpy as np

def compute_weighted_densities_kdtree(points, weights, radius):
    """
    Fast weighted density using KD-tree query_pairs.
    
    Time complexity: O(N log N) construction + O(N * k) for k neighbors
    Memory: O(N) for tree + O(pairs) for results
    """
    points = np.asarray(points, dtype=np.float32)
    weights = np.asarray(weights, dtype=np.float32)
    N = len(points)
    
    tree = cKDTree(points, compact_nodes=False, balanced_tree=False)
    
    # Get all pairs within radius - returns (M, 2) array
    pairs = tree.query_pairs(r=radius, output_type='ndarray')
    
    densities = np.zeros(N, dtype=np.float32)
    
    if pairs.size > 0:
        i, j = pairs[:, 0], pairs[:, 1]
        # Each point accumulates weights of its neighbors
        np.add.at(densities, i, weights[j])
        np.add.at(densities, j, weights[i])
    
    # Include self-weight
    densities += weights
    
    return densities

This is fast and memory-efficient for most ensemble sizes. The query_pairs approach is particularly elegant because it naturally handles the symmetry of neighbor relationships.
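When tuning fast spatial code like this, it pays to keep a dead-simple reference around. The O(N²) version below is only practical for a few thousand points, but it should agree with the KD-tree result to float tolerance (both include each point's own weight), which makes it a cheap regression test:

```python
import numpy as np

def weighted_density_bruteforce(points, weights, radius):
    """O(N^2) reference: each point sums the weights of every point
    within `radius`, including its own (mirrors the KD-tree version)."""
    points = np.asarray(points, dtype=np.float64)
    weights = np.asarray(weights, dtype=np.float64)
    # Full pairwise squared-distance matrix, then a masked weight sum
    d2 = np.sum((points[:, None, :] - points[None, :, :]) ** 2, axis=2)
    return (d2 <= radius * radius) @ weights
```

Comparing `np.allclose(weighted_density_bruteforce(pts, w, r), compute_weighted_densities_kdtree(pts, w, r))` on a few thousand random points catches most regressions.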

Voxel Grid (For Very Large Ensembles)

For ensembles with millions of points, even KD-tree construction becomes slow. A voxel grid approach provides better scaling:

def compute_weighted_densities_voxel(points, weights, radius):
    """
    Voxel-based weighted density for very large point sets.
    
    Time complexity: O(N) for binning + O(N * 27) for neighbor cells
    Memory: O(N) + O(unique_cells)
    """
    points = np.asarray(points, dtype=np.float32)
    weights = np.asarray(weights, dtype=np.float32)
    N = len(points)
    
    cell_size = radius
    coords_min = points.min(axis=0)
    
    # Compute cell indices for each point
    ijk = np.floor((points - coords_min) / cell_size).astype(np.int32)
    
    # Hash cells to unique keys
    keys = ((ijk[:, 0].astype(np.int64) << 42) ^ 
            (ijk[:, 1].astype(np.int64) << 21) ^ 
            ijk[:, 2].astype(np.int64))
    
    # Group points by cell
    sort_idx = np.argsort(keys)
    sorted_keys = keys[sort_idx]
    unique_keys, key_starts = np.unique(sorted_keys, return_index=True)
    
    # Build cell -> point range lookup
    cell_ends = np.empty_like(key_starts)
    cell_ends[:-1] = key_starts[1:]
    cell_ends[-1] = N
    
    key_to_range = {
        int(k): (int(s), int(e)) 
        for k, s, e in zip(unique_keys, key_starts, cell_ends)
    }
    
    densities = np.zeros(N, dtype=np.float32)
    r_sq = radius * radius
    
    # 3x3x3 neighbor offsets
    offsets = [(dx, dy, dz) 
               for dx in (-1, 0, 1) 
               for dy in (-1, 0, 1) 
               for dz in (-1, 0, 1)]
    
    # Process each cell
    for center_key in unique_keys:
        start_c, end_c = key_to_range[int(center_key)]
        center_idx = sort_idx[start_c:end_c]
        center_pts = points[center_idx]
        center_w = weights[center_idx]
        
        i0, j0, k0 = ijk[center_idx[0]]
        
        for dx, dy, dz in offsets:
            neighbor_key = ((int(i0 + dx) << 42) ^ 
                           (int(j0 + dy) << 21) ^ 
                           int(k0 + dz))
            
            if neighbor_key not in key_to_range:
                continue
            
            # Avoid double-counting pairs
            if neighbor_key < center_key:
                continue
            
            start_n, end_n = key_to_range[neighbor_key]
            neighbor_idx = sort_idx[start_n:end_n]
            neighbor_pts = points[neighbor_idx]
            neighbor_w = weights[neighbor_idx]
            
            # Compute pairwise distances
            diff = center_pts[:, None, :] - neighbor_pts[None, :, :]
            d2 = np.sum(diff * diff, axis=2)
            mask = d2 <= r_sq
            
            if neighbor_key == center_key:
                # Same cell: exclude self-pairs
                np.fill_diagonal(mask, False)
                densities[center_idx] += mask.astype(np.float32) @ center_w
            else:
                # Different cells: update both
                densities[center_idx] += mask.astype(np.float32) @ neighbor_w
                densities[neighbor_idx] += mask.T.astype(np.float32) @ center_w
    
    # Include self-weight so results match the KD-tree version
    densities += weights
    
    return densities

When to Use Which?

Method       Best For         Time   Memory
KD-Tree      < 500K points    ~10s   Moderate
Voxel Grid   > 500K points    ~30s   Lower

In practice, KD-tree handles most ensemble visualization tasks. I include the voxel approach for batch processing or when memory is constrained.
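The "O(pairs)" memory term is the one that actually bites in the KD-tree path: query_pairs materializes every pair as two integer indices (typically 16 bytes per pair). A rough estimate for uniformly scattered points — an idealizing assumption, since real surface points cluster on shells, so treat the result as a lower bound — helps decide before committing:

```python
import numpy as np

def expected_pairs(n_points, box_volume, radius):
    """Expected neighbor-pair count if points were uniform in box_volume."""
    density = n_points / box_volume
    neighbors_per_point = density * (4.0 / 3.0) * np.pi * radius**3
    return int(n_points * neighbors_per_point / 2)

# 250k points in a (100 A)^3 bounding box, 2 A density radius:
pairs = expected_pairs(250_000, 100.0**3, 2.0)
print(pairs, f"~{pairs * 16 / 1e6:.0f} MB of pair indices")
```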

Drawing Surfaces: Mesh, Spheres, and Points

Once filtered, we need to render the point cloud. PyMOL offers several approaches:

Point Visualization (Fastest)

Good for quick iteration:

from pymol import cmd, cgo

def draw_points(points, name, color, max_points=20000):
    """Draw surface as small crosses at each point."""
    if len(points) > max_points:
        indices = np.random.choice(len(points), max_points, replace=False)
        points = points[indices]
    
    obj = [cgo.BEGIN, cgo.LINES, cgo.COLOR] + list(color)
    
    for x, y, z in points:
        # Draw small cross at each point
        for dx, dy, dz in [(0.1,0,0), (-0.1,0,0), 
                           (0,0.1,0), (0,-0.1,0), 
                           (0,0,0.1), (0,0,-0.1)]:
            obj.extend([cgo.VERTEX, x+dx, y+dy, z+dz])
    
    obj.append(cgo.END)
    cmd.load_cgo(obj, name)

Sphere Visualization (Publication Quality)

Better looking but slower:

def draw_spheres(points, name, color, radius=0.3, max_spheres=3000):
    """Draw surface as small spheres."""
    if len(points) > max_spheres:
        indices = np.random.choice(len(points), max_spheres, replace=False)
        points = points[indices]
    
    obj = [cgo.BEGIN, cgo.TRIANGLES, cgo.COLOR] + list(color)
    
    for x, y, z in points:
        # Simple octahedron approximation of sphere
        vertices = [
            [x+radius, y, z], [x-radius, y, z],
            [x, y+radius, z], [x, y-radius, z],
            [x, y, z+radius], [x, y, z-radius]
        ]
        faces = [[0,2,4], [0,4,3], [0,3,5], [0,5,2],
                 [1,4,2], [1,3,4], [1,5,3], [1,2,5]]
        
        for face in faces:
            for v_idx in face:
                obj.extend([cgo.VERTEX] + vertices[v_idx])
    
    obj.append(cgo.END)
    cmd.load_cgo(obj, name)

Mesh Visualization (Best for Envelopes)

The mesh approach uses PyMOL’s built-in Gaussian map to create a smooth surface from the point cloud:

import tempfile
import os

def draw_mesh(points, name, color, max_mesh_points=500000):
    """Draw surface as smooth mesh via Gaussian density map."""
    if len(points) > max_mesh_points:
        points = points[np.random.choice(len(points), 
                                         max_mesh_points, replace=False)]
    
    # Write points as temporary PDB
    fd, temp_pdb = tempfile.mkstemp(suffix=".pdb")
    os.close(fd)
    
    temp_obj = f"{name}_pts"
    map_name = f"{name}_map"
    
    try:
        with open(temp_pdb, 'w') as f:
            for i, (x, y, z) in enumerate(points):
                f.write(f"ATOM  {(i%99999)+1:5d}  X   PTS A   1    "
                       f"{x:8.3f}{y:8.3f}{z:8.3f}  1.00  1.00           C\n")
        
        # Load and create Gaussian map
        cmd.load(temp_pdb, temp_obj)
        cmd.hide("everything", temp_obj)
        cmd.map_new(map_name, "gaussian", 1.0, temp_obj, 5.0)
        
        # Generate isomesh at appropriate contour level
        cmd.isomesh(name, map_name, level=0.8)
        
        # Color the mesh
        cmd.set_color(f"col_{name}", color)
        cmd.color(f"col_{name}", name)
        cmd.show("mesh", name)
        
        cmd.delete(temp_obj)
        
    finally:
        if os.path.exists(temp_pdb):
            os.remove(temp_pdb)

The Gaussian approach creates beautiful smooth envelopes that clearly show the accessible conformational space.

Putting It All Together

Here’s the complete workflow for generating weighted SASA surfaces:

def generate_weighted_surface(cluster_name, topology_file, trajectory_file,
                              weights_file, reference_pdb=None,
                              ci_percentile=85, density_radius=6.5):
    """
    Complete pipeline for weighted SASA surface generation.
    """
    from pymol import cmd
    
    # 1. Load topology and trajectory
    cmd.load(topology_file, cluster_name)
    cmd.load_traj(trajectory_file, cluster_name)
    
    # 2. Extract states to PDB files
    pdb_files = extract_states_to_pdb(cluster_name, output_dir="/tmp")
    
    # 3. Compute SASA for all frames
    sasa_data = compute_sasa_ensemble(pdb_files)
    
    # 4. Align to reference if provided
    if reference_pdb:
        target_coords = get_reference_target_coords(reference_pdb)
        sasa_data = align_sasa_data_to_reference(sasa_data, target_coords)
    
    # 5. Generate surface geometry (with caching)
    cache_file = f"{cluster_name}_surface_cache.npz"
    raw_points, frame_indices = get_raw_surface_geometry_cached(
        sasa_data, cache_file=cache_file
    )
    
    # 6. Load weights and map to points
    weights = np.load(weights_file)["weights"]
    point_weights = weights[frame_indices]
    
    # 7. Filter by weighted density
    filtered_points = filter_surface_density(
        raw_points, point_weights,
        ci_percentile=ci_percentile,
        density_radius=density_radius
    )
    
    # 8. Draw surfaces
    color_weighted = [0.95, 0.65, 0.30]  # Orange for weighted
    color_uniform = [0.60, 0.80, 1.00]   # Light blue for uniform
    
    draw_mesh(filtered_points, f"{cluster_name}_weighted_mesh", color_weighted)
    
    # Compare with uniform weighting
    uniform_points = filter_surface_density(
        raw_points, np.ones(len(raw_points)),
        ci_percentile=ci_percentile,
        density_radius=density_radius
    )
    draw_mesh(uniform_points, f"{cluster_name}_uniform_mesh", color_uniform)
    
    # 9. Clean up temp files
    for f in pdb_files:
        os.remove(f)
    
    # 10. Final rendering
    cmd.set("mesh_width", 0.5)
    cmd.set("transparency_mode", 1)
    cmd.zoom()

Parameter Tuning Guide

The key parameters to adjust:

Parameter         Effect                                 Typical Range
frame_subsample   Points per frame (memory vs detail)    300-1000
density_radius    Neighborhood size for density          4.0-8.0 Å
ci_percentile     Surface coverage (higher = tighter)    80-95%
sasa_threshold    Minimum exposure to include            0.0-1.0 Ų

For comparing methods, keep all parameters identical except the weights. This ensures differences come from the reweighting, not visualization artifacts.

Summary


Weighted SASA surfaces provide a powerful way to visualize ensemble distributions:

  • Generate surface points proportional to atomic SASA values
  • Cache intermediate data to enable rapid iteration
  • Filter by weighted local density to show high-confidence regions
  • Choose the right algorithm (KD-tree for most cases, voxel for very large ensembles)
  • Render as mesh for smooth, publication-quality envelopes

Combined with the reference handling and RMSD coloring from Part 1, you now have a complete toolkit for visualizing weighted structural ensembles in PyMOL.
