PHinally PHunctionalising my PHigures with PHATE feat. Plotly Express.

After a friend recommended it, I really wanted to try Plotly Express, but I never had the inclination to read more documentation when matplotlib gives me enough grief. While experimenting with ChatGPT I finally decided to functionalise my figure-making scripts. With these scripts I managed to produce figures that made people question what I had actually been doing with my time, but I promise this will be worth your time.

I have been working with dimensionality reduction techniques recently and I came across this paper by Moon et al. PHATE is a technique that represents high-dimensional (e.g. biological) data in a way that aims to preserve connections rather than distances, and I knew I wanted to try it as soon as I saw it. Why should you care? In my tests, PHATE in 3D was faster than t-SNE in 2D. It would almost be rude not to try it out.

PHATE

In my opinion PHATE (or potential of heat diffusion for affinity-based transition embedding) does have a lot going on, but the choices at each stage feel quite sensible. It might not come as a surprise that it was primarily designed to make visual inspection of data easier on the eyes.
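
To make the stages a little more concrete, here is a toy numpy sketch of the core idea: turn distances into affinities, normalise them into a random walk, let the walk diffuse for a few steps, then compare points by the log of where the walk ends up. This is a simplification under stated assumptions and not the phate library's code: it uses a plain Gaussian kernel with a fixed `sigma` instead of PHATE's adaptive kernel, a hand-picked number of steps `t`, and the function names are made up for illustration. The real method then embeds the resulting distances with MDS.

```python
import numpy as np


def diffusion_potential(data, sigma=1.0, t=5, eps=1e-7):
    """Toy version of the early PHATE stages (illustrative, not the library code)."""
    # Pairwise squared Euclidean distances between all points
    sq_dists = ((data[:, None, :] - data[None, :, :]) ** 2).sum(axis=-1)
    # Gaussian affinities: close points get weights near 1, distant ones near 0
    affinity = np.exp(-sq_dists / (2 * sigma ** 2))
    # Row-normalise so each row is the probability of one random-walk step
    transition = affinity / affinity.sum(axis=1, keepdims=True)
    # Take t steps of the walk: this is the "diffusion" part
    diffused = np.linalg.matrix_power(transition, t)
    # Log transform so that small probabilities still make a difference
    return -np.log(diffused + eps)


def potential_distances(potential):
    """Euclidean distance between every pair of rows of the potential matrix."""
    diff = potential[:, None, :] - potential[None, :, :]
    return np.sqrt((diff ** 2).sum(axis=-1))
```

Two points end up close when a random walk started from either of them reaches the same places, which is one way to read "preserving connections" rather than raw distance.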

As mentioned, PHATE is way faster than t-SNE: on my data, PHATE 2D was about 40% faster than t-SNE 2D and PHATE 3D was about 10% faster than t-SNE 2D. I actually gave up on using t-SNE 3D because it was so slow.
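
Timings like these depend heavily on the dataset and the machine, so it is worth measuring them yourself. Below is a minimal sketch of a timing helper that works with any object exposing `fit_transform`. The helper names `timed_fit_transform` and `run_reducers` are hypothetical, and `time.perf_counter` is used here because it is intended for measuring short durations.

```python
import time


def timed_fit_transform(reducer, data):
    """Return (embedding, seconds) for any object with a fit_transform method."""
    start = time.perf_counter()
    embedding = reducer.fit_transform(data)
    return embedding, time.perf_counter() - start


def run_reducers(reducers, data):
    """Run each named reducer once; `reducers` maps a label to an unfitted object."""
    results = {}
    for name, reducer in reducers.items():
        embedding, seconds = timed_fit_transform(reducer, data)
        results[name] = (embedding, seconds)
    return results


# Possible usage, assuming scikit-learn and phate are installed:
# reducers = {
#     "T-SNE_2D": TSNE(n_components=2, random_state=42),
#     "PHATE_2D": phate.PHATE(n_components=2, random_state=42),
# }
# results = run_reducers(reducers, data)
```

A single run is only a rough guide; repeating each measurement a few times gives a fairer comparison.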

You can use it very similarly to t-SNE to generate a PHATE operator object:

import time

import numpy as np
import pandas as pd
import phate
from sklearn.manifold import TSNE

# Read in dataframe (model_name is assumed to be defined earlier in the script)
global_df = pd.read_parquet(model_name+"_global_df.parquet")

# Features for colour mapping
colour_indexes = ["Role", "CLASS", "DNN_Logits", "GAN_Logits", "Logit_Diff" ,"DNN_Distance", "GAN_Distance"]

# Create plotting dataframe
plotting_df = pd.DataFrame()
plotting_df["FP"] = global_df["FP"]
plotting_df["ROLE"] = global_df["ROLE"]
plotting_df["Similarity"] = global_df["Similarity"]

# Add columns for colour mapping features
for index in colour_indexes:
    plotting_df[index] = global_df[index]

# Define the training features and dimension reducers
features = ["FP",  "Similarity", "DNN_Features", "GAN_Features"]
dim_reducers = ["T-SNE_2D", "PHATE_2D", "PHATE_3D"]

# Compute TSNE and PHATE for each feature and dimension reducer
for feature in features:
    print(f"Computing {feature}")
    data = np.stack(global_df[feature].to_numpy(), axis=0)
    print(data.shape)
    
    for dim_reducer in dim_reducers[:]:
        print(f"Computing {dim_reducer} for {feature}")
        
        if dim_reducer == "T-SNE_2D":
            tsne_operator = TSNE(n_components=2, n_jobs=4, perplexity=50, verbose=1, random_state=42)
            start_time = time.time()
            transformed_data = tsne_operator.fit_transform(data)
            end_time = time.time()

        elif dim_reducer == "PHATE_2D":
            phate_operator = phate.PHATE(n_components=2, n_jobs=4, knn=5, decay=40, random_state=42)
            start_time = time.time()
            transformed_data = phate_operator.fit_transform(data)
            end_time = time.time()

        elif dim_reducer == "PHATE_3D":
            phate_operator = phate.PHATE(n_components=3, n_jobs=4, knn=5, decay=40, random_state=42)
            start_time = time.time()
            transformed_data = phate_operator.fit_transform(data)
            end_time = time.time()
        
        # Create a column name combining the method and features used
        column_name = f"{dim_reducer}_{feature}"
        
        plotting_df[column_name] = [row for row in transformed_data]
        print(f"Execution time for {column_name}: {end_time - start_time} seconds")

plotting_df.to_parquet(model_name + "_plotting_df.parquet")

Example of what this might look like:

With the help of ChatGPT I FINALLY decided to functionalise my plots instead of just having one long script:

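import matplotlib
import matplotlib.pyplot as plt
import numpy as np


def plot_scatter(ax, dataframe, cname, cmap, method, feature):
    """
    Plot a 2D scatter plot on the given axes.

    Args:
        ax (matplotlib.axes.Axes): The axes to plot on.
        dataframe (pd.DataFrame): The data to be plotted.
        cname (str): The column name for color data.
        cmap: The colormap for the scatter plot.
        method (str): The dimension reduction method.
        feature (str): The feature the embedding was computed on.

    Returns:
        None
    """
    c = dataframe[cname]

    column_name = f"{method}_{feature}"

    # Each cell holds one embedded point, so stack them into an (n, 2) array
    data = np.stack(dataframe[column_name].to_numpy(), axis=0)

    scatter = ax.scatter(data[:, 0], data[:, 1], c=c, cmap=cmap, s=3, alpha=0.5)
    ax.set_xlabel(f"{method} 1")
    ax.set_ylabel(f"{method} 2")
    ax.set_title(f"{method} plot on {feature} by {cname}")
    # No ax.legend() here: the points carry no labels, so the colorbar acts as the key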

    # Add colorbar
    cbar = plt.colorbar(scatter, ax=ax)
    cbar.set_label(cname)


def plot_scatter_grid(dataframe, cnames, method, feature):
    """
    Plot a grid of scatter plots.

    Args:
        dataframe (pd.DataFrame): The data to be plotted.
        cnames (list): List of column names for the scatter plots.
        method (str): The dimension reduction method.
        feature (str): The feature to be plotted.

    Returns:
        None
    """
    colours = ['green', 'orange', 'purple', 'pink']

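    # One column of plots, one row per colour feature; squeeze=False keeps axs
    # two-dimensional even when only one cname is passed
    fig, axs = plt.subplots(len(cnames), 1, figsize=(6, 4*len(cnames)), squeeze=False)

    for i, cname in enumerate(cnames):
        ax = axs[i, 0]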
        if cname == "Role":
            cmap = matplotlib.colors.ListedColormap(colours)
        else:
            cmap = 'viridis'
        plot_scatter(ax, dataframe, cname, cmap, method, feature)

    plt.tight_layout()
    plt.show()


Usage (to create the above plot):

colour_indexes = ["Role", "CLASS", "DNN_Logits", "GAN_Logits", "Logit_Diff" ,"DNN_Distance", "GAN_Distance" ]
features = ["FP", "Similarity", "DNN_Features", "GAN_Features"]
dim_reducers = ["T-SNE_2D", "PHATE_2D", "PHATE_3D"]

# Define the parameters
cnames = ["Role", "CLASS", "DNN_Logits", "DNN_Distance"]
method = dim_reducers[1]   # "PHATE_2D"
feature = features[2]      # "DNN_Features"

# Call the function to plot the scatter grid (coloured by DNN_Logits and DNN_Distance)
plot_scatter_grid(plotting_df, cnames[2:4], method, feature)
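
To try the plotting functions without the original parquet file, a small synthetic dataframe with the same column layout can stand in. This is only a sketch: the helper name `make_demo_plotting_df` is hypothetical, the values are random noise, and the only thing it mirrors from the code above is the `{method}_{feature}` column convention with one embedded point stored per cell.

```python
import numpy as np
import pandas as pd


def make_demo_plotting_df(n_points=200, n_components=2,
                          method="PHATE_2D", feature="DNN_Features", seed=0):
    """Build a random dataframe shaped like plotting_df (illustrative data only)."""
    rng = np.random.default_rng(seed)
    embedding = rng.normal(size=(n_points, n_components))
    df = pd.DataFrame({
        "DNN_Logits": rng.normal(size=n_points),
        "DNN_Distance": rng.uniform(size=n_points),
    })
    # Store one embedded point (a small array) per cell, as in the loop above
    df[f"{method}_{feature}"] = list(embedding)
    return df


demo_df = make_demo_plotting_df()
# np.stack turns the column of per-row arrays back into one 2D array
coords = np.stack(demo_df["PHATE_2D_DNN_Features"].to_numpy(), axis=0)
print(coords.shape)  # (200, 2)
```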


PHATE is way faster to compute, even in 3 dimensions. But matplotlib’s 3D plotting is a little fiddly: having to set the viewing angles by hand can make visualisation slow, especially when plotting large datasets:

def plot_scatter_3D(ax, dataframe, cname, cmap, method, feature, elev, azim):
    """
    Plot a scatter plot in 3D on the given axes.

    Args:
        ax (mpl_toolkits.mplot3d.axes3d.Axes3D): The axes to plot on.
        dataframe (pd.DataFrame): The data to be plotted.
        cname (str): The column name for color data.
        cmap: The colormap for the scatter plot.
        method (str): The method used for the plot.
        feature (str): The feature used for the plot.
        elev (float): The elevation angle for the viewing perspective.
        azim (float): The azimuth angle for the viewing perspective.

    Returns:
        None
    """
    c = dataframe[cname]
    column_name = f"{method}_{feature}"
    data = np.stack(dataframe[column_name].to_numpy(), axis=0)

    scatter = ax.scatter(data[:, 0], data[:, 1], data[:, 2], c=c, cmap=cmap, s=3, alpha=0.5)
    ax.set_xlabel(f"{method} 1")
    ax.set_ylabel(f"{method} 2")
    ax.set_zlabel(f"{method} 3")
    ax.set_title(f"{method} plot on {feature} by {cname}")
    # No ax.legend() here: the points carry no labels, so the colorbar acts as the key

    # Add colorbar
    cbar = plt.colorbar(scatter, ax=ax)
    cbar.set_label(cname)

    # Set the view angle
    ax.view_init(elev=elev, azim=azim)
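
Because a static 3D plot shows only one perspective, one workaround is to draw the same points from several angles side by side. The sketch below assumes an `(n, 3)` array of coordinates and a matching array of colour values; the function name `plot_views` and the default angles are arbitrary choices for illustration.

```python
import matplotlib.pyplot as plt
import numpy as np


def plot_views(coords, values, views=((20, 30), (20, 120), (70, 30)), cmap="viridis"):
    """Draw one 3D scatter per (elev, azim) pair and return the figure."""
    fig = plt.figure(figsize=(5 * len(views), 4))
    for i, (elev, azim) in enumerate(views, start=1):
        ax = fig.add_subplot(1, len(views), i, projection="3d")
        ax.scatter(coords[:, 0], coords[:, 1], coords[:, 2],
                   c=values, cmap=cmap, s=3, alpha=0.5)
        # Same data in every panel, only the camera position changes
        ax.view_init(elev=elev, azim=azim)
        ax.set_title(f"elev={elev}, azim={azim}")
    return fig
```

Each extra panel redraws every point, so this gets slower as the dataset grows.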


It also looks pretty ugly without some tinkering:

Plotly Express is a library designed for figures that are meant to be viewed on a computer: it is very interactive and incredibly fast to boot. Here is a simple example of a 3D plot whose perspective you can alter in VS Code:

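The matplotlib 3D scatter plot function from above was modified using ChatGPT to use Plotly. Note that the version it produced is built on plotly.graph_objects, the lower-level interface that Plotly Express wraps.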

import numpy as np
import plotly.graph_objects as go

def plot_scatter_3d(dataframe, cname, cmap, method, feature):
    c = dataframe[cname]

    column_name = f"{method}_{feature}"

    data = np.stack(dataframe[column_name].to_numpy(), axis=0)

    scatter = go.Scatter3d(
        x=data[:, 0],
        y=data[:, 1],
        z=data[:, 2],
        mode='markers',
        marker=dict(
            size=1,
            color=c,
            colorscale=cmap,
            opacity=0.5
        ),
        name=cname
    )

    fig = go.Figure(data=[scatter])
    fig.update_layout(
        title=f"{method} plot on {feature} by {cname}",
        scene=dict(
            xaxis_title=f"{method} 1",
            yaxis_title=f"{method} 2",
            zaxis_title=f"{method} 3"
        )
    )
    fig.show()



On the surface, Plotly Express can be very unfamiliar for those used to matplotlib. The problem is that you often want matplotlib for publication-quality figures, so why bother writing something in another library if it won’t end up in your final work?

I used ChatGPT to rewrite my plots. Since the code was already functionalised, I never had to worry about the syntax, which is especially important when experimenting with different ways to display your data. Much faster, way cleaner.

Here’s how I prompted it:

I have found ChatGPT to be really effective for code where the syntax and usage are unclear: asking it for some examples helps you better understand what’s going on.
