Dynamic axes in Bokeh

Today I wanted to create an interactive scatter plot where I could choose the two axes from several candidate columns. The only example I could find used a Bokeh server, but I would prefer a standalone document that I can embed in HTML. Implementing dynamic behavior like this requires writing JavaScript callbacks, which I’m not stoked about, but it was pretty easy in this case.

Here, I simulated three columns of data (A, B, and C) and put them into a DataFrame. I also added a magnitude column, the Euclidean norm of each row, to color the points. To keep the original values intact, I added two more columns, x and y, to hold the plotted values. Whenever one of the dropdowns records a new value, the callback copies the selected column into x or y and updates the matching axis label.

import numpy as np
import pandas as pd

from bokeh.io import output_file, show
from bokeh.layouts import column, row
from bokeh.models import ColumnDataSource, Select, HoverTool, ColorBar
from bokeh.models.callbacks import CustomJS
from bokeh.palettes import Plasma256
from bokeh.plotting import figure
from bokeh.transform import linear_cmap

output_file("dynamic_axes.html", title="Dynamic Axes")

# Fixed seed so the simulated data (and therefore the saved HTML) is reproducible.
rng = np.random.default_rng(42)

n = 200
options = ["A", "B", "C"]
# One column per option, each with its own mean / std:
# A ~ N(-1, 0.5), B ~ N(0, 1), C ~ N(1, 0.5).
df = pd.DataFrame(
    rng.normal(loc=[-1, 0, 1], scale=[0.5, 1, 0.5], size=(n, 3)),
    columns=options,
)
# Euclidean norm of each (A, B, C) row; this column drives the point colour.
df["magnitude"] = np.linalg.norm(df[options].to_numpy(), axis=1)
mapper = linear_cmap(
    field_name="magnitude",
    palette=Plasma256,
    low=df["magnitude"].min(),
    high=df["magnitude"].max(),
)

# Shared axis limits must span only the selectable columns. "magnitude" is a
# derived, non-plotted column, and including it would stretch both axes.
min_val = df[options].min().min()
max_val = df[options].max().max()
val_range = max_val - min_val

# Both dropdowns offer the same choices at the same width; they differ only in
# their title and initial value. The titles are also read by the JS callback
# to tell the x-axis dropdown from the y-axis one.
selection_args = dict(options=options, width=100)
selection_1 = Select(title="x-axis", value=options[0], **selection_args)
selection_2 = Select(title="y-axis", value=options[1], **selection_args)

# "x" and "y" hold the values actually plotted. They start as copies of the
# initially selected columns, so the original A/B/C columns are never modified.
df = df.assign(x=df[selection_1.value], y=df[selection_2.value])

source = ColumnDataSource(df)

# One tooltip row per original column, shown to one decimal place, plus the
# magnitude. Bokeh's format syntax is "@field{format}", so the braces must be
# doubled inside the f-string to produce e.g. "@A{0.0}".
tooltips = [(col, f"@{col}{{0.0}}") for col in options]
tooltips.append(("Magnitude", "@magnitude"))
hover = HoverTool(mode="mouse", attachment="below", tooltips=tooltips)

tools = [hover, "box_zoom", "reset", "pan", "save"]

# Pad the shared limits by 5% on each side so edge points are not clipped.
axis_lims = [min_val - val_range * 0.05, max_val + val_range * 0.05]
plot = figure(tools=tools, x_range=axis_lims, y_range=axis_lims)
# `scatter(marker="circle")` avoids the Bokeh 3.4 deprecation of
# `circle(size=...)` and works on older releases too.
points = plot.scatter(
    source=source,
    x="x",
    y="y",
    marker="circle",
    fill_color=mapper,
    line_width=0.5,
    line_color="black",
    size=10,
    name="points",
)
# Restrict the hover tool to the scatter renderer. `renderers` exists on both
# Bokeh 2.x and 3.x, whereas the `names=` tool filter was removed in 3.0.
hover.renderers = [points]

# Light grey panel with slightly darker grid lines.
plot.background_fill_color = "#F3F3F3"
plot.grid.grid_line_color = "#CDCDCD"

# Colour bar for the magnitude mapping, placed in the area below the plot.
color_bar = ColorBar(color_mapper=mapper["transform"])
plot.add_layout(color_bar, "below")

# Initial axis labels mirror the dropdowns' starting values.
plot.xaxis.axis_label = selection_1.value
plot.yaxis.axis_label = selection_2.value

# Same font styling on both axes.
axis_style = {
    "axis_label_text_font_size": "14pt",
    "axis_label_text_font_style": "normal",
    "major_label_text_font_size": "10pt",
}
for axis in (plot.xaxis, plot.yaxis):
    for attr_name, attr_value in axis_style.items():
        setattr(axis, attr_name, attr_value)

# Shared JS callback for both dropdowns. It copies the selected original column
# into the plotted "x" or "y" column and relabels the matching axis.
# The axis models are passed in directly rather than looked up through
# `plot.below[0]`, because the colour bar also lives in `plot.below` and
# index-based lookup depends on layout order.
tool_callback = CustomJS(
    args=dict(source=source, xaxis=plot.xaxis[0], yaxis=plot.yaxis[0]),
    code="""
const data = source.data[cb_obj.value];
if (cb_obj.title == 'x-axis') {
    xaxis.axis_label = cb_obj.value;
    source.data.x = data;
} else {
    yaxis.axis_label = cb_obj.value;
    source.data.y = data;
}

source.change.emit();
""",
)
selection_1.js_on_change("value", tool_callback)
selection_2.js_on_change("value", tool_callback)

# Dropdowns stacked on the left, plot on the right.
layout = row(column(selection_1, selection_2), plot)

show(layout)