Skip to content

pelinker.plotting

plot_dbcv_vs_ari_from_grid(df_grid, output_path)

Scatter of mean DBCV vs mean ARI per (model, layer); shape = arity (△/□/○), fill colors = base encoder model(s); text = layer spec only (e.g. fusion 2+3). 95% covariance ellipses when n_sample ≥ 2.

Expects sample_best_dbcv, sample_ari on the grid export. Both axes are fixed to [0, _AXIS_MAX].

Returns:

| Type | Description |
| --- | --- |
| `bool` | True if a figure was written, False if required data were absent. |

Source code in pelinker/plotting.py
def plot_dbcv_vs_ari_from_grid(
    df_grid: pd.DataFrame,
    output_path: pathlib.Path,
) -> bool:
    """
    Scatter of mean DBCV vs mean ARI per (model, layer); shape = arity (△/□/○),
    fill colors = base encoder model(s); text = layer spec only (e.g. fusion ``2+3``).
    95% covariance ellipses when ``n_sample`` ≥ 2.

    Expects ``model``, ``layer``, ``sample_idx``, ``sample_best_dbcv`` and
    ``sample_ari`` on the grid export. Only the first row per
    ``(model, layer, sample_idx)`` is kept, and rows missing either metric are
    dropped so a single NaN cannot blank out a whole (model, layer) point.
    Both axes are fixed to ``[0, _AXIS_MAX]``.

    Args:
        df_grid: Grid export containing the columns listed above.
        output_path: Destination image path; parent directories are created.

    Returns:
        True if a figure was written, False if required columns were absent or
        no row had both metrics.
    """
    key_cols = ["model", "layer", "sample_idx"]
    metric_cols = [GRID_COL_SAMPLE_BEST_DBCV, GRID_COL_SAMPLE_ARI]
    if not set(key_cols + metric_cols).issubset(df_grid.columns):
        return False

    # Keep the first row per (model, layer, sample).
    df = df_grid.loc[:, key_cols + metric_cols].drop_duplicates(
        subset=key_cols, keep="first"
    )
    # Both coordinates are required: a NaN in either column would make the
    # group mean / covariance NaN and silently remove that point from the plot.
    df = df.dropna(subset=metric_cols)
    if df.empty:
        return False

    all_models: list[str] = []
    for m, lyr in df[["model", "layer"]].drop_duplicates().itertuples(index=False):
        all_models.extend(_base_models_in_row(str(m), str(lyr)))
    color_by_model = _model_color_map(all_models)

    fig, ax = plt.subplots(figsize=(8.5, 8.5))
    try:
        arities_present: set[str] = set()

        for (model, layer), g in df.groupby(["model", "layer"], sort=False):
            xy = np.column_stack(
                [
                    g[GRID_COL_SAMPLE_BEST_DBCV].to_numpy(dtype=np.float64),
                    g[GRID_COL_SAMPLE_ARI].to_numpy(dtype=np.float64),
                ]
            )
            mean = xy.mean(axis=0)
            n = xy.shape[0]
            arity = _arity_from_model(str(model))
            arities_present.add(arity)
            models_row = _base_models_in_row(str(model), str(layer))

            # Halo drawn first: light filled ellipse + clear dashed rim, slightly
            # inflated so it remains visible through transparent markers.
            if n >= 2:
                cov = np.cov(xy, rowvar=False, ddof=1)
                ell = _covariance_ellipse_95(cov)
                if ell is not None:
                    w, h, ang = ell
                    ax.add_patch(
                        Ellipse(
                            xy=(float(mean[0]), float(mean[1])),
                            width=w * _ELLIPSE_INFLATE,
                            height=h * _ELLIPSE_INFLATE,
                            angle=ang,
                            facecolor=(0.45, 0.48, 0.52, _ELLIPSE_FILL_ALPHA),
                            edgecolor=(0.12, 0.14, 0.18, _ELLIPSE_EDGE_ALPHA),
                            linewidth=1.45,
                            linestyle=(0, (4.5, 3.0)),
                            zorder=2,
                        )
                    )

            _draw_arity_marker(
                ax,
                float(mean[0]),
                float(mean[1]),
                arity=arity,
                models=models_row,
                color_by_model=color_by_model,
                zorder=5,
            )
            layer_code = _layer_spec_code(str(model), str(layer))
            if len(layer_code) > 22:
                layer_code = layer_code[:19] + "…"
            ax.annotate(
                layer_code,
                (mean[0], mean[1]),
                textcoords="offset points",
                # Shift left in proportion to the label length so the text
                # ends just left of the marker.
                xytext=(-3 - 1.5 * (len(layer_code) - 1), -3),
                fontsize=8,
                alpha=0.88,
                zorder=6,
            )

        ax.set_xlim(0.0, _AXIS_MAX)
        ax.set_ylim(0.0, _AXIS_MAX)
        ax.set_aspect("equal")
        ax.set_xlabel("DBCV (per-sample best, mean over samples)")
        ax.set_ylabel("Adjusted Rand Index (per-sample, mean over samples)")
        ax.set_title("DBCV vs ARI; dashed ellipse ≈95% (n_sample >= 2)")

        order_a = ["singleton", "fusion2", "fusion3"]
        arity_labels = {
            "singleton": "singleton",
            "fusion2": "pair fusion",
            "fusion3": "triple fusion",
        }
        edge_legend = (0.0, 0.0, 0.0, _MARKER_OUTLINE_ALPHA)
        legend_shapes = [
            Line2D(
                [0],
                [0],
                marker=ARITY_MARKER_SCATTER[a],
                color="none",
                label=arity_labels[a],
                markerfacecolor=_rgba("#bbbbbb", _MARKER_FACE_ALPHA),
                markeredgecolor=edge_legend,
                markersize=10,
            )
            for a in order_a
            if a in arities_present
        ]
        legend_colors = [
            Patch(
                facecolor=_rgba(color_by_model[m], _MARKER_FACE_ALPHA),
                edgecolor=edge_legend,
                linewidth=0.5,
                label=m,
            )
            for m in sorted(color_by_model.keys())
        ]
        # Only draw legends that actually have entries (no empty titled boxes).
        if legend_shapes:
            leg_shapes = ax.legend(
                handles=legend_shapes,
                title="Arity",
                loc="upper left",
                bbox_to_anchor=(1.02, 1.0),
                borderaxespad=0.0,
                frameon=True,
            )
            if legend_colors:
                # A second ax.legend() call replaces the first; keep it as an artist.
                ax.add_artist(leg_shapes)
        if legend_colors:
            ax.legend(
                handles=legend_colors,
                title="Base model",
                loc="lower left",
                bbox_to_anchor=(1.02, 0.0),
                borderaxespad=0.0,
                frameon=True,
            )

        ax.grid(True, alpha=0.28, linestyle="--", zorder=0)
        fig.tight_layout()
        output_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(output_path, dpi=300, bbox_inches="tight")
    finally:
        # Always release the figure, even if drawing or saving fails.
        plt.close(fig)
    return True

plot_heatmap(df_results, output_path, metric='best_score', metric_label=None)

Create a heatmap with model (rows) and layer (columns). Color represents the specified metric; each cell's text shows best_size and the metric value.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `df_results` | `DataFrame` | DataFrame with columns: model, layer, best_size, and the metric column | *required* |
| `output_path` | `Path` | Path to save the heatmap figure | *required* |
| `metric` | `str` | Column name for the metric to display as color (default: "best_score") | `'best_score'` |
| `metric_label` | `str \| None` | Label for the metric (default: derived from the metric column name, with underscores replaced by spaces and title-cased) | `None` |
Source code in pelinker/plotting.py
def plot_heatmap(
    df_results: pd.DataFrame,
    output_path: pathlib.Path,
    metric: str = "best_score",
    metric_label: str | None = None,
):
    """
    Create a heatmap with model (rows) and layer (columns).
    Color represents the specified metric, text shows best_size and metric name.

    Args:
        df_results: DataFrame with columns: model, layer, best_size, and the metric column
        output_path: Path to save the heatmap figure
        metric: Column name for the metric to display as color (default: "best_score")
        metric_label: Label for the metric (default: uses metric column name)
    """
    if metric_label is None:
        metric_label = metric.replace("_", " ").title()

    # Create pivot tables
    score_pivot = df_results.pivot(index="model", columns="layer", values=metric)
    size_pivot = df_results.pivot(index="model", columns="layer", values="best_size")

    # Create figure
    fig, ax = plt.subplots(
        figsize=(
            max(8, len(score_pivot.columns) * 0.8),
            max(6, len(score_pivot.index) * 0.6),
        )
    )

    # Create heatmap with metric as color
    # Use RdBu_r (Red-Blue reversed) for clear visual distinction: red=high, blue=low
    sns.heatmap(
        score_pivot,
        annot=False,  # We'll add custom annotations
        fmt=".3f",
        cmap="RdBu_r",
        center=None,  # Center colormap at the median for better contrast
        cbar_kws={"label": metric_label, "shrink": 0.8},
        ax=ax,
        linewidths=0.5,
        linecolor="white",
        square=False,
    )

    # Add best_size and metric name as text annotations
    # Calculate mean score for text color threshold
    valid_scores = score_pivot.values[~pd.isna(score_pivot.values)]
    mean_score = valid_scores.mean() if len(valid_scores) > 0 else 0

    for i in range(len(score_pivot.index)):
        for j in range(len(score_pivot.columns)):
            score_val = score_pivot.iloc[i, j]
            size_val = size_pivot.iloc[i, j]

            if not pd.isna(score_val) and not pd.isna(size_val):
                # Use white text for darker cells (lower scores), black for lighter cells
                text_color = "white" if score_val < mean_score else "black"
                # Format metric value based on its magnitude
                if abs(score_val) < 0.01:
                    metric_str = f"{score_val:.2e}"
                elif abs(score_val) < 1:
                    metric_str = f"{score_val:.3f}"
                else:
                    metric_str = f"{score_val:.2f}"
                # Add text annotation with best_size and metric value
                ax.text(
                    j + 0.5,
                    i + 0.5,
                    f"{int(size_val)}\n{metric_str}",
                    ha="center",
                    va="center",
                    color=text_color,
                    fontweight="bold",
                    fontsize=8,
                    linespacing=1.2,
                )

    ax.set_title(f"Clustering Results: {metric_label} (color) and Best Size (text)")
    ax.set_xlabel("Layer")
    ax.set_ylabel("Model")

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches="tight")
    plt.close()

plot_metrics_with_error_bars(metrics_list, output_path, *, chosen_min_cluster_size=None)

Plot metrics across multiple runs with error bars using seaborn lineplot.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `metrics_list` | `list[DataFrame]` | List of DataFrames, each with columns: min_cluster_size, icm, n_clusters, dbcv, and optionally ari | *required* |
| `output_path` | `Path` | Path to save the figure | *required* |
| `chosen_min_cluster_size` | `float \| None` | Optional vertical marker for the selected grid value (e.g. from smoother / argmax). | `None` |
Source code in pelinker/plotting.py
def plot_metrics_with_error_bars(
    metrics_list: list[pd.DataFrame],
    output_path: pathlib.Path,
    *,
    chosen_min_cluster_size: float | None = None,
):
    """
    Plot metrics across multiple runs with error bars using seaborn lineplot.

    Args:
        metrics_list: List of DataFrames, each with columns: min_cluster_size, icm, n_clusters, dbcv, ari
        output_path: Path to save the figure
        chosen_min_cluster_size: Optional vertical marker for the selected grid value (e.g. from smoother / argmax).
    """
    # Combine all metrics DataFrames, adding a run_id column
    combined_metrics = []
    for run_id, df in enumerate(metrics_list):
        df_copy = df.copy()
        df_copy["run_id"] = run_id
        combined_metrics.append(df_copy)

    df_combined = pd.concat(combined_metrics, ignore_index=True)

    # Filter out trivial points where n_clusters <= 1
    df_combined = df_combined[df_combined["n_clusters"] > 1].copy()

    if len(df_combined) == 0:
        print(
            f"Warning: No valid data points after filtering (n_clusters > 1) for {output_path}"
        )
        return

    has_ari = "ari" in df_combined.columns and bool(df_combined["ari"].notna().any())
    ncols = 3 if has_ari else 2
    fig, axes = plt.subplots(1, ncols, figsize=(6 * ncols, 5))
    if ncols == 2:
        ax_dbcv, ax_k = axes[0], axes[1]
        ax_ari = None
    else:
        ax_dbcv, ax_k, ax_ari = axes[0], axes[1], axes[2]

    # Color palette for different plots
    colors = ["#2E86AB", "#A23B72", "#F18F01"]  # Blue, Purple, Orange

    def _maybe_vline(ax) -> None:
        if chosen_min_cluster_size is None:
            return
        ax.axvline(
            chosen_min_cluster_size,
            color="0.35",
            linestyle="--",
            linewidth=1.5,
            alpha=0.9,
            zorder=0,
        )

    # Plot DBCV score with error bars
    sns.lineplot(
        data=df_combined,
        x="min_cluster_size",
        y="dbcv",
        ax=ax_dbcv,
        errorbar="sd",  # Standard deviation error bars
        marker="o",
        color=colors[0],
        linewidth=2,
        markersize=8,
        err_kws={"alpha": 0.3, "linewidth": 1.5},
    )
    _maybe_vline(ax_dbcv)
    ax_dbcv.set_xlabel("min_cluster_size", fontsize=12, fontweight="bold")
    ax_dbcv.set_ylabel("DBCV Score", fontsize=12, fontweight="bold", color=colors[0])
    ax_dbcv.set_title("DBCV Score vs. min_cluster_size", fontsize=13, fontweight="bold")
    ax_dbcv.grid(True, alpha=0.3, linestyle="--")
    ax_dbcv.tick_params(axis="y", labelcolor=colors[0])
    ax_dbcv.spines["top"].set_visible(False)
    ax_dbcv.spines["right"].set_visible(False)

    # Plot n_clusters with error bars (log scale)
    sns.lineplot(
        data=df_combined,
        x="min_cluster_size",
        y="n_clusters",
        ax=ax_k,
        errorbar="sd",
        marker="^",
        color=colors[1],
        linewidth=2,
        markersize=8,
        err_kws={"alpha": 0.3, "linewidth": 1.5},
    )
    _maybe_vline(ax_k)
    ax_k.set_xlabel("min_cluster_size", fontsize=12, fontweight="bold")
    ax_k.set_ylabel("n clusters", fontsize=12, fontweight="bold", color=colors[1])
    ax_k.set_title(
        "Number of Clusters vs. min_cluster_size", fontsize=13, fontweight="bold"
    )
    ax_k.grid(True, alpha=0.3, linestyle="--")
    ax_k.tick_params(axis="y", labelcolor=colors[1])
    ax_k.spines["top"].set_visible(False)
    ax_k.spines["right"].set_visible(False)

    if ax_ari is not None:
        sns.lineplot(
            data=df_combined,
            x="min_cluster_size",
            y="ari",
            ax=ax_ari,
            errorbar="sd",
            marker="D",
            color=colors[2],
            linewidth=2,
            markersize=7,
            err_kws={"alpha": 0.3, "linewidth": 1.5},
        )
        _maybe_vline(ax_ari)
        ax_ari.set_xlabel("min_cluster_size", fontsize=12, fontweight="bold")
        ax_ari.set_ylabel("ARI", fontsize=12, fontweight="bold", color=colors[2])
        ax_ari.set_title("ARI vs. min_cluster_size", fontsize=13, fontweight="bold")
        ax_ari.grid(True, alpha=0.3, linestyle="--")
        ax_ari.tick_params(axis="y", labelcolor=colors[2])
        ax_ari.spines["top"].set_visible(False)
        ax_ari.spines["right"].set_visible(False)

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches="tight")
    plt.close()