Skip to content

API Reference

This is the API reference for MONET Plots.

Modules

Cartopy-based plotting utilities for MONET.

facet_time_map(da, time_dim='time', ncols=3, map_kws=None, projection=None, colorbar=True, figsize=None, cmap=None, vmin=None, vmax=None, norm=None, dpi=150, xlabel=None, ylabel=None, suptitle=None, cbar_label=None, xticks=None, yticks=None, annotations=None, export_path=None, export_formats=None, **kwargs)

Create a facet grid of map plots for each time slice in a DataArray using Cartopy.

Parameters

- `da` : xarray.DataArray — The data to plot (must have a time dimension).
- `time_dim` : str, default "time" — Name of the time dimension.
- `ncols` : int, default 3 — Number of columns in the facet grid.
- `map_kws` : dict, optional — Dictionary of keyword arguments for map features.
- `projection` : cartopy.crs.Projection, optional — Cartopy projection to use. Defaults to PlateCarree.
- `colorbar` : bool, default True — Whether to add a (shared) colorbar.
- `figsize` : tuple, optional — Figure size.
- `cmap` : str or Colormap, optional — Colormap to use.
- `vmin`, `vmax` : float, optional — Color limits.
- `norm` : Normalize, optional — Matplotlib normalization.
- `dpi` : int, optional — Dots per inch for export.
- `xlabel`, `ylabel`, `suptitle` : str, optional — Axis labels and super title.
- `cbar_label` : str, optional — Label for the colorbar.
- `xticks`, `yticks` : list, optional — Custom tick locations.
- `annotations` : list of dict, optional — List of annotation dicts, one per subplot.
- `export_path` : str, optional — Path to export the figure to (without extension).
- `export_formats` : list, optional — List of formats to export (e.g., ["png", "pdf"]).
- `**kwargs` : dict — Additional keyword arguments for plotting.

Returns

- `fig` : matplotlib.figure.Figure — The matplotlib figure object.
- `axes` : ndarray of matplotlib.axes.Axes — The matplotlib axes objects.

Source code in src/monet_plots/cartopy_utils.py (lines 520–635)
def facet_time_map(
    da,
    time_dim="time",
    ncols=3,
    map_kws=None,
    projection=None,
    colorbar=True,
    figsize=None,
    cmap=None,
    vmin=None,
    vmax=None,
    norm=None,
    dpi=150,
    xlabel=None,
    ylabel=None,
    suptitle=None,
    cbar_label=None,
    xticks=None,
    yticks=None,
    annotations=None,
    export_path=None,
    export_formats=None,
    **kwargs,
):
    """
    Create a facet grid of map plots for each time slice in a DataArray using Cartopy.

    All facets are drawn with common colour limits, so the single shared
    colorbar describes every panel.

    Parameters
    ----------
    da : xarray.DataArray
        The data to plot (must have a time dimension).
    time_dim : str, default: "time"
        Name of the time dimension.
    ncols : int, default: 3
        Number of columns in the facet grid.
    map_kws : dict, optional
        Dictionary of keyword arguments for map features.
    projection : cartopy.crs.Projection, optional
        Cartopy projection to use. Defaults to PlateCarree.
    colorbar : bool, default: True
        Whether to add a colorbar (shared).
    figsize : tuple, optional
        Figure size.
    cmap : str or Colormap, optional
        Colormap to use.
    vmin, vmax : float, optional
        Color limits. A limit left as None defaults to the minimum/maximum of
        ``da`` over *all* time slices, so every facet uses the same scale.
        This is skipped when ``norm`` is given or when ``robust=True`` or a
        ``center`` value is passed in ``kwargs``; xarray then handles scaling
        as usual. Because filling both limits disables xarray's automatic
        diverging-colormap choice, pass ``cmap`` explicitly if you want one.
    norm : Normalize, optional
        Matplotlib normalization.
    dpi : int, optional
        Dots per inch for export.
    xlabel, ylabel, suptitle : str, optional
        Axis labels and super title.
    cbar_label : str, optional
        Label for the colorbar.
    xticks, yticks : list, optional
        Custom tick locations.
    annotations : list of dict, optional
        List of annotation dicts for each subplot (passed to ``ax.annotate``).
    export_path : str, optional
        Path to export the figure (without extension).
    export_formats : list, optional
        List of formats to export (e.g., ["png", "pdf"]).
    **kwargs : dict
        Additional keyword arguments for xarray's plot method. ``ax`` and
        ``transform`` are set per facet by this function and are ignored.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The matplotlib figure object.
    axes : ndarray of matplotlib.axes.Axes
        The matplotlib axes objects.
    """
    # Setup
    projection = _setup_map_projection(projection)
    map_kws = map_kws or {}
    times = da[time_dim].values
    nt = len(times)

    # Create facet grid
    fig, axes, _, _ = _setup_facet_grid(nt, ncols, projection, figsize, dpi)

    # Prepare plot arguments. ``ax``/``transform`` are supplied per facet
    # below, so caller copies would cause a duplicate-keyword TypeError.
    plot_args = dict(cmap=cmap, vmin=vmin, vmax=vmax, norm=norm, add_colorbar=False)
    plot_args.update(
        {k: v for k, v in kwargs.items() if k not in ("ax", "transform")}
    )

    # Each xarray plot call autoscales independently; without common limits
    # the facets would use different colour scales while the shared colorbar
    # reflects only the last one. Fill missing limits from the full array.
    center = plot_args.get("center")
    use_global_limits = (
        nt > 0
        and norm is None
        and not plot_args.get("robust", False)
        and (center is None or center is False)
    )
    if use_global_limits:
        if plot_args["vmin"] is None:
            data_min = float(da.min())
            if np.isfinite(data_min):
                plot_args["vmin"] = data_min
        if plot_args["vmax"] is None:
            data_max = float(da.max())
            if np.isfinite(data_max):
                plot_args["vmax"] = data_max

    # Plot each time slice
    mesh = None
    for i, t in enumerate(times):
        ax = axes[i]
        # Positional selection: label-based ``sel`` returns several slices
        # when time stamps are duplicated.
        dat = da.isel({time_dim: i})
        mesh = dat.plot(ax=ax, transform=ccrs.PlateCarree(), **plot_args)

        # datetime64 stamps get ISO formatting; other coordinate types
        # (e.g. cftime objects or integers) fall back to ``str``.
        if isinstance(t, np.datetime64):
            title = str(np.datetime_as_string(t))
        else:
            title = str(t)
        _setup_single_facet_axis(ax, map_kws, xlabel, ylabel, xticks, yticks, title)

        # Add annotations if provided
        if annotations and i < len(annotations):
            ax.annotate(**annotations[i])

    # Clean up unused axes
    for j in range(nt, len(axes)):
        fig.delaxes(axes[j])

    # Add shared colorbar
    _add_shared_colorbar(fig, axes, mesh, colorbar, cbar_label)

    # Set super title
    if suptitle:
        fig.suptitle(suptitle, fontsize=14, fontweight="bold")

    # Export figure
    _export_figure(fig, export_path, export_formats, dpi)

    return fig, axes

plot_lines_map(df, lon_col='longitude', lat_col='latitude', group_col=None, projection=None, color='C0', linewidth=2, alpha=0.8, map_kws=None, figsize=(8, 6), dpi=150, title=None, export_path=None, export_formats=None, **kwargs)

Plot lines from a DataFrame on a Cartopy map. Optionally group by a column.

Parameters

- `df` : pandas.DataFrame — DataFrame with longitude and latitude columns.
- `lon_col`, `lat_col` : str — Column names for longitude and latitude.
- `group_col` : str, optional — Column used to group rows into separate lines (e.g., for trajectories).
- `projection` : cartopy.crs.Projection, optional — Cartopy projection to use. Defaults to PlateCarree.
- `color` : str or array-like, optional — Line color.
- `linewidth` : float, optional — Line width.
- `alpha` : float, optional — Line transparency.
- `map_kws` : dict, optional — Map feature keyword arguments.
- `figsize` : tuple, optional — Figure size.
- `dpi` : int, optional — Dots per inch for export.
- `title` : str, optional — Plot title.
- `export_path` : str, optional — Path to export the figure to (without extension).
- `export_formats` : list, optional — List of formats to export (e.g., ["png", "pdf"]).
- `**kwargs` : dict — Additional keyword arguments for plt.plot.

Returns

fig, ax : matplotlib Figure and Axes

Source code in src/monet_plots/cartopy_utils.py (lines 731–827)
def plot_lines_map(
    df,
    lon_col="longitude",
    lat_col="latitude",
    group_col=None,
    projection=None,
    color="C0",
    linewidth=2,
    alpha=0.8,
    map_kws=None,
    figsize=(8, 6),
    dpi=150,
    title=None,
    export_path=None,
    export_formats=None,
    **kwargs,
):
    """
    Plot lines from a DataFrame on a Cartopy map. Optionally group by a column.

    Parameters
    ----------
    df : pandas.DataFrame
        DataFrame with longitude and latitude columns.
    lon_col, lat_col : str
        Column names for longitude and latitude.
    group_col : str, optional
        Column to group lines (e.g., for trajectories). One line is drawn per
        group. Any column label other than None (including falsy labels such
        as ``0``) enables grouping.
    projection : cartopy.crs.Projection, optional
        Cartopy projection to use. Defaults to PlateCarree.
    color : str or array-like, optional
        Line color.
    linewidth : float, optional
        Line width.
    alpha : float, optional
        Line transparency.
    map_kws : dict, optional
        Map feature keyword arguments.
    figsize : tuple, optional
        Figure size.
    dpi : int, optional
        Dots per inch for export.
    title : str, optional
        Plot title.
    export_path : str, optional
        Path to export the figure (without extension).
    export_formats : list, optional
        List of formats to export (e.g., ["png", "pdf"]).
    **kwargs : dict
        Additional keyword arguments for plt.plot. A ``transform`` entry
        overrides the default ``ccrs.PlateCarree()`` coordinate system.

    Returns
    -------
    fig, ax : matplotlib Figure and Axes
    """
    # Setup
    projection = _setup_map_projection(projection)
    map_kws = map_kws or {}
    fig, ax = plt.subplots(
        subplot_kw={"projection": projection}, figsize=figsize, dpi=dpi
    )

    # Add map features
    _add_map_features(ax, map_kws)

    # Data are geographic lon/lat unless the caller provides another CRS.
    # Popping it avoids a duplicate ``transform`` keyword in ax.plot.
    transform = kwargs.pop("transform", ccrs.PlateCarree())
    line_kws = dict(
        color=color,
        linewidth=linewidth,
        alpha=alpha,
        transform=transform,
        **kwargs,
    )

    # Plot lines. ``is not None`` so falsy column labels (e.g. 0) still group.
    if group_col is not None:
        for _, group in df.groupby(group_col):
            ax.plot(group[lon_col], group[lat_col], **line_kws)
    else:
        ax.plot(df[lon_col], df[lat_col], **line_kws)

    # Set title
    if title:
        ax.set_title(title, fontsize=14, fontweight="bold")

    # Finalize
    fig.tight_layout()
    _export_figure(fig, export_path, export_formats, dpi)

    return fig, ax

plot_points_map(df, lon_col='longitude', lat_col='latitude', projection=None, color='C0', marker='o', size=40, edgecolor='k', alpha=0.8, map_kws=None, figsize=(8, 6), dpi=150, title=None, export_path=None, export_formats=None, **kwargs)

Plot points from a DataFrame on a Cartopy map.

Parameters

- `df` : pandas.DataFrame — DataFrame with longitude and latitude columns.
- `lon_col`, `lat_col` : str — Column names for longitude and latitude.
- `projection` : cartopy.crs.Projection, optional — Cartopy projection to use. Defaults to PlateCarree.
- `color` : str or array-like, optional — Color for points.
- `marker` : str, optional — Marker style.
- `size` : float or array-like, optional — Marker size.
- `edgecolor` : str, optional — Marker edge color.
- `alpha` : float, optional — Marker transparency.
- `map_kws` : dict, optional — Map feature keyword arguments.
- `figsize` : tuple, optional — Figure size.
- `dpi` : int, optional — Dots per inch for export.
- `title` : str, optional — Plot title.
- `export_path` : str, optional — Path to export the figure to (without extension).
- `export_formats` : list, optional — List of formats to export (e.g., ["png", "pdf"]).
- `**kwargs` : dict — Additional keyword arguments for plt.scatter.

Returns

fig, ax : matplotlib Figure and Axes

Source code in src/monet_plots/cartopy_utils.py (lines 639–728)
def plot_points_map(
    df,
    lon_col="longitude",
    lat_col="latitude",
    projection=None,
    color="C0",
    marker="o",
    size=40,
    edgecolor="k",
    alpha=0.8,
    map_kws=None,
    figsize=(8, 6),
    dpi=150,
    title=None,
    export_path=None,
    export_formats=None,
    **kwargs,
):
    """
    Scatter the rows of a DataFrame as points on a Cartopy map.

    Parameters
    ----------
    df : pandas.DataFrame
        Table holding one point per row, with longitude and latitude columns.
    lon_col, lat_col : str
        Names of the longitude and latitude columns in ``df``.
    projection : cartopy.crs.Projection, optional
        Map projection for the axes; PlateCarree when omitted.
    color : str or array-like, optional
        Point colour(s), forwarded to ``scatter`` as ``color``.
    marker : str, optional
        Matplotlib marker style.
    size : float or array-like, optional
        Marker area, forwarded to ``scatter`` as ``s``.
    edgecolor : str, optional
        Colour of the marker outline.
    alpha : float, optional
        Marker opacity.
    map_kws : dict, optional
        Options describing which map features to draw.
    figsize : tuple, optional
        Figure size in inches.
    dpi : int, optional
        Resolution of the figure and of any exported files.
    title : str, optional
        Axes title, drawn bold at 14 pt when given.
    export_path : str, optional
        Output path without extension; nothing is exported when omitted.
    export_formats : list, optional
        File formats to write (e.g., ["png", "pdf"]).
    **kwargs : dict
        Extra keyword arguments forwarded to ``ax.scatter``.

    Returns
    -------
    fig, ax : matplotlib Figure and Axes
    """
    proj = _setup_map_projection(projection)
    features = map_kws or {}
    fig, ax = plt.subplots(
        figsize=figsize, dpi=dpi, subplot_kw=dict(projection=proj)
    )

    # Background features are drawn before the points are scattered on top.
    _add_map_features(ax, features)

    # Point coordinates are always interpreted as geographic lon/lat.
    point_style = dict(
        color=color,
        marker=marker,
        s=size,
        edgecolor=edgecolor,
        alpha=alpha,
        transform=ccrs.PlateCarree(),
    )
    ax.scatter(df[lon_col], df[lat_col], **point_style, **kwargs)

    if title:
        ax.set_title(title, fontsize=14, fontweight="bold")

    fig.tight_layout()
    _export_figure(fig, export_path, export_formats, dpi)
    return fig, ax

plot_quick_contourf(da, map_kws=None, projection=None, colorbar=True, figsize=None, cmap=None, vmin=None, vmax=None, norm=None, dpi=150, xlabel=None, ylabel=None, title=None, cbar_label=None, cbar_inset=False, xticks=None, yticks=None, annotations=None, export_path=None, export_formats=None, **kwargs)

Create a publication-quality filled contour plot of the data on a map using Cartopy.

Parameters

- `da` : xarray.DataArray — The data to plot.
- `map_kws` : dict, optional — Dictionary of keyword arguments for map features (e.g., coastlines, gridlines, features, borders, land, ocean).
- `projection` : cartopy.crs.Projection, optional — Cartopy projection to use. Defaults to PlateCarree.
- `colorbar` : bool, default True — Whether to add a colorbar.
- `figsize` : tuple, optional — Figure size.
- `cmap` : str or Colormap, optional — Colormap to use (supports colorblind-friendly options).
- `vmin`, `vmax` : float, optional — Color limits.
- `norm` : Normalize, optional — Matplotlib normalization (e.g., LogNorm).
- `dpi` : int, optional — Dots per inch for export.
- `xlabel`, `ylabel`, `title` : str, optional — Axis labels and plot title.
- `cbar_label` : str, optional — Label for the colorbar.
- `cbar_inset` : bool, default False — If True, place the colorbar as an inset on the right.
- `xticks`, `yticks` : list, optional — Custom tick locations.
- `annotations` : list of dict, optional — List of annotation dicts (e.g., {"text": "A", "xy": (lon, lat)}).
- `export_path` : str, optional — Path to export the figure to (without extension).
- `export_formats` : list, optional — List of formats to export (e.g., ["png", "pdf"]).
- `**kwargs` : dict — Additional keyword arguments for contourf.

Returns

- `fig` : matplotlib.figure.Figure — The matplotlib figure object.
- `ax` : matplotlib.axes.Axes — The matplotlib axes object.

Source code in src/monet_plots/cartopy_utils.py (lines 414–517)
def plot_quick_contourf(
    da,
    map_kws=None,
    projection=None,
    colorbar=True,
    figsize=None,
    cmap=None,
    vmin=None,
    vmax=None,
    norm=None,
    dpi=150,
    xlabel=None,
    ylabel=None,
    title=None,
    cbar_label=None,
    cbar_inset=False,
    xticks=None,
    yticks=None,
    annotations=None,
    export_path=None,
    export_formats=None,
    **kwargs,
):
    """
    Create a publication-quality filled contour plot of the data on a map using Cartopy.

    Parameters
    ----------
    da : xarray.DataArray
        The data to plot.
    map_kws : dict, optional
        Dictionary of keyword arguments for map features (e.g., coastlines, gridlines, features, borders, land, ocean).
    projection : cartopy.crs.Projection, optional
        Cartopy projection to use. Defaults to PlateCarree.
    colorbar : bool, default: True
        Whether to add a colorbar.
    figsize : tuple, optional
        Figure size.
    cmap : str or Colormap, optional
        Colormap to use (supports colorblind-friendly options).
    vmin, vmax : float, optional
        Color limits.
    norm : Normalize, optional
        Matplotlib normalization (e.g., LogNorm).
    dpi : int, optional
        Dots per inch for export.
    xlabel, ylabel, title : str, optional
        Axis labels and plot title.
    cbar_label : str, optional
        Label for the colorbar.
    cbar_inset : bool, default: False
        Place colorbar as an inset (right) if True.
    xticks, yticks : list, optional
        Custom tick locations.
    annotations : list of dict, optional
        List of annotation dicts (e.g., {"text": "A", "xy": (lon, lat)}).
    export_path : str, optional
        Path to export the figure (without extension).
    export_formats : list, optional
        List of formats to export (e.g., ["png", "pdf"]).
    **kwargs : dict
        Additional keyword arguments for contourf. ``ax`` and ``transform``
        are set by this function and are ignored. xarray's own colorbar is
        disabled unless ``add_colorbar`` is passed explicitly.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The matplotlib figure object.
    ax : matplotlib.axes.Axes
        The matplotlib axes object.
    """
    # Setup
    projection = _setup_map_projection(projection)
    map_kws = map_kws or {}
    fig, ax = plt.subplots(
        subplot_kw={"projection": projection}, figsize=figsize, dpi=dpi
    )

    # Prepare plot arguments. xarray's 2-D plot methods add their own colorbar
    # by default, which ignores ``colorbar=False`` and may duplicate the one
    # requested via ``_add_colorbar``; turn it off (as facet_time_map does).
    # Callers can still re-enable it through ``add_colorbar`` in kwargs.
    plot_args = dict(cmap=cmap, vmin=vmin, vmax=vmax, norm=norm, add_colorbar=False)
    # ``ax``/``transform`` are fixed below; drop caller copies to avoid a
    # duplicate-keyword TypeError.
    plot_args.update({k: v for k, v in kwargs.items() if k not in ["ax", "transform"]})

    # Create the plot
    mesh = da.plot.contourf(ax=ax, transform=ccrs.PlateCarree(), **plot_args)

    # Add map features
    _add_map_features(ax, map_kws)

    # Set labels and title
    _set_axis_labels_and_title(ax, xlabel, ylabel, title)

    # Set custom ticks
    _set_custom_ticks(ax, xticks, yticks)

    # Add annotations
    _add_annotations(ax, annotations)

    # Add colorbar
    _add_colorbar(fig, ax, mesh, colorbar, cbar_label, cbar_inset)

    # Finalize
    fig.tight_layout()
    _export_figure(fig, export_path, export_formats, dpi)

    return fig, ax

plot_quick_imshow(da, map_kws=None, projection=None, colorbar=True, figsize=None, cmap=None, vmin=None, vmax=None, norm=None, dpi=150, xlabel=None, ylabel=None, title=None, cbar_label=None, cbar_inset=False, xticks=None, yticks=None, annotations=None, export_path=None, export_formats=None, **kwargs)

Create an imshow plot of the data on a map using Cartopy.

Parameters

- `da` : xarray.DataArray — The data to plot.
- `map_kws` : dict, optional — Dictionary of keyword arguments for map features (e.g., coastlines, gridlines, features, borders, land, ocean).
- `projection` : cartopy.crs.Projection, optional — Cartopy projection to use. Defaults to PlateCarree.
- `colorbar` : bool, default True — Whether to add a colorbar.
- `figsize` : tuple, optional — Figure size.
- `cmap` : str or Colormap, optional — Colormap to use (supports colorblind-friendly options).
- `vmin`, `vmax` : float, optional — Color limits.
- `norm` : Normalize, optional — Matplotlib normalization (e.g., LogNorm).
- `dpi` : int, optional — Dots per inch for export.
- `xlabel`, `ylabel`, `title` : str, optional — Axis labels and plot title.
- `cbar_label` : str, optional — Label for the colorbar.
- `cbar_inset` : bool, default False — If True, place the colorbar as an inset on the right.
- `xticks`, `yticks` : list, optional — Custom tick locations.
- `annotations` : list of dict, optional — List of annotation dicts (e.g., {"text": "A", "xy": (lon, lat)}).
- `export_path` : str, optional — Path to export the figure to (without extension).
- `export_formats` : list, optional — List of formats to export (e.g., ["png", "pdf"]).
- `**kwargs` : dict — Additional keyword arguments for imshow.

Returns

- `fig` : matplotlib.figure.Figure — The matplotlib figure object.
- `ax` : matplotlib.axes.Axes — The matplotlib axes object.

Source code in src/monet_plots/cartopy_utils.py (lines 201–304)
def plot_quick_imshow(
    da,
    map_kws=None,
    projection=None,
    colorbar=True,
    figsize=None,
    cmap=None,
    vmin=None,
    vmax=None,
    norm=None,
    dpi=150,
    xlabel=None,
    ylabel=None,
    title=None,
    cbar_label=None,
    cbar_inset=False,
    xticks=None,
    yticks=None,
    annotations=None,
    export_path=None,
    export_formats=None,
    **kwargs,
):
    """
    Create an imshow plot of the data on a map using Cartopy.

    Parameters
    ----------
    da : xarray.DataArray
        The data to plot.
    map_kws : dict, optional
        Dictionary of keyword arguments for map features (e.g., coastlines, gridlines, features, borders, land, ocean).
    projection : cartopy.crs.Projection, optional
        Cartopy projection to use. Defaults to PlateCarree.
    colorbar : bool, default: True
        Whether to add a colorbar.
    figsize : tuple, optional
        Figure size.
    cmap : str or Colormap, optional
        Colormap to use (supports colorblind-friendly options).
    vmin, vmax : float, optional
        Color limits.
    norm : Normalize, optional
        Matplotlib normalization (e.g., LogNorm).
    dpi : int, optional
        Dots per inch for export.
    xlabel, ylabel, title : str, optional
        Axis labels and plot title.
    cbar_label : str, optional
        Label for the colorbar.
    cbar_inset : bool, default: False
        Place colorbar as an inset (right) if True.
    xticks, yticks : list, optional
        Custom tick locations.
    annotations : list of dict, optional
        List of annotation dicts (e.g., {"text": "A", "xy": (lon, lat)}).
    export_path : str, optional
        Path to export the figure (without extension).
    export_formats : list, optional
        List of formats to export (e.g., ["png", "pdf"]).
    **kwargs : dict
        Additional keyword arguments for imshow. ``ax`` and ``transform``
        are set by this function and are ignored. xarray's own colorbar is
        disabled unless ``add_colorbar`` is passed explicitly.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The matplotlib figure object.
    ax : matplotlib.axes.Axes
        The matplotlib axes object.
    """
    # Setup
    projection = _setup_map_projection(projection)
    map_kws = map_kws or {}
    fig, ax = plt.subplots(
        subplot_kw={"projection": projection}, figsize=figsize, dpi=dpi
    )

    # Prepare plot arguments. xarray's 2-D plot methods add their own colorbar
    # by default, which ignores ``colorbar=False`` and may duplicate the one
    # requested via ``_add_colorbar``; turn it off (as facet_time_map does).
    # Callers can still re-enable it through ``add_colorbar`` in kwargs.
    plot_args = dict(cmap=cmap, vmin=vmin, vmax=vmax, norm=norm, add_colorbar=False)
    # ``ax``/``transform`` are fixed below; drop caller copies to avoid a
    # duplicate-keyword TypeError.
    plot_args.update({k: v for k, v in kwargs.items() if k not in ["ax", "transform"]})

    # Create the plot
    mesh = da.plot.imshow(ax=ax, transform=ccrs.PlateCarree(), **plot_args)

    # Add map features
    _add_map_features(ax, map_kws)

    # Set labels and title
    _set_axis_labels_and_title(ax, xlabel, ylabel, title)

    # Set custom ticks
    _set_custom_ticks(ax, xticks, yticks)

    # Add annotations
    _add_annotations(ax, annotations)

    # Add colorbar
    _add_colorbar(fig, ax, mesh, colorbar, cbar_label, cbar_inset)

    # Finalize
    fig.tight_layout()
    _export_figure(fig, export_path, export_formats, dpi)

    return fig, ax

plot_quick_map(da, map_kws=None, projection=None, colorbar=True, figsize=None, cmap=None, vmin=None, vmax=None, norm=None, dpi=150, xlabel=None, ylabel=None, title=None, cbar_label=None, cbar_inset=False, xticks=None, yticks=None, annotations=None, export_path=None, export_formats=None, **kwargs)

Create a publication-quality map plot of the data using Cartopy and xarray's default plot method.

Parameters

- `da` : xarray.DataArray — The data to plot.
- `map_kws` : dict, optional — Dictionary of keyword arguments for map features (e.g., coastlines, gridlines, features, borders, land, ocean).
- `projection` : cartopy.crs.Projection, optional — Cartopy projection to use. Defaults to PlateCarree.
- `colorbar` : bool, default True — Whether to add a colorbar.
- `figsize` : tuple, optional — Figure size.
- `cmap` : str or Colormap, optional — Colormap to use (supports colorblind-friendly options).
- `vmin`, `vmax` : float, optional — Color limits.
- `norm` : Normalize, optional — Matplotlib normalization (e.g., LogNorm).
- `dpi` : int, optional — Dots per inch for export.
- `xlabel`, `ylabel`, `title` : str, optional — Axis labels and plot title.
- `cbar_label` : str, optional — Label for the colorbar.
- `cbar_inset` : bool, default False — If True, place the colorbar as an inset on the right.
- `xticks`, `yticks` : list, optional — Custom tick locations.
- `annotations` : list of dict, optional — List of annotation dicts (e.g., {"text": "A", "xy": (lon, lat)}).
- `export_path` : str, optional — Path to export the figure to (without extension).
- `export_formats` : list, optional — List of formats to export (e.g., ["png", "pdf"]).
- `**kwargs` : dict — Additional keyword arguments for xarray's plot method.

Returns

- `fig` : matplotlib.figure.Figure — The matplotlib figure object.
- `ax` : matplotlib.axes.Axes — The matplotlib axes object.

Source code in src/monet_plots/cartopy_utils.py (lines 307–411)
def plot_quick_map(
    da,
    map_kws=None,
    projection=None,
    colorbar=True,
    figsize=None,
    cmap=None,
    vmin=None,
    vmax=None,
    norm=None,
    dpi=150,
    xlabel=None,
    ylabel=None,
    title=None,
    cbar_label=None,
    cbar_inset=False,
    xticks=None,
    yticks=None,
    annotations=None,
    export_path=None,
    export_formats=None,
    **kwargs,
):
    """
    Create a publication-quality map plot of the data using Cartopy and xarray's
    default plot method.

    Parameters
    ----------
    da : xarray.DataArray
        The data to plot.
    map_kws : dict, optional
        Dictionary of keyword arguments for map features (e.g., coastlines, gridlines, features, borders, land, ocean).
    projection : cartopy.crs.Projection, optional
        Cartopy projection to use. Defaults to PlateCarree.
    colorbar : bool, default: True
        Whether to add a colorbar.
    figsize : tuple, optional
        Figure size.
    cmap : str or Colormap, optional
        Colormap to use (supports colorblind-friendly options).
    vmin, vmax : float, optional
        Color limits.
    norm : Normalize, optional
        Matplotlib normalization (e.g., LogNorm).
    dpi : int, optional
        Dots per inch for export.
    xlabel, ylabel, title : str, optional
        Axis labels and plot title.
    cbar_label : str, optional
        Label for the colorbar.
    cbar_inset : bool, default: False
        Place colorbar as an inset (right) if True.
    xticks, yticks : list, optional
        Custom tick locations.
    annotations : list of dict, optional
        List of annotation dicts (e.g., {"text": "A", "xy": (lon, lat)}).
    export_path : str, optional
        Path to export the figure (without extension).
    export_formats : list, optional
        List of formats to export (e.g., ["png", "pdf"]).
    **kwargs : dict
        Additional keyword arguments for xarray's plot method. ``ax`` and
        ``transform`` are set by this function and are ignored. xarray's own
        colorbar is disabled unless ``add_colorbar`` is passed explicitly.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The matplotlib figure object.
    ax : matplotlib.axes.Axes
        The matplotlib axes object.
    """
    # Setup
    projection = _setup_map_projection(projection)
    map_kws = map_kws or {}
    fig, ax = plt.subplots(
        subplot_kw={"projection": projection}, figsize=figsize, dpi=dpi
    )

    # Prepare plot arguments. xarray's 2-D plot methods add their own colorbar
    # by default, which ignores ``colorbar=False`` and may duplicate the one
    # requested via ``_add_colorbar``; turn it off (as facet_time_map does).
    # Callers can still re-enable it through ``add_colorbar`` in kwargs.
    plot_args = dict(cmap=cmap, vmin=vmin, vmax=vmax, norm=norm, add_colorbar=False)
    # ``ax``/``transform`` are fixed below; drop caller copies to avoid a
    # duplicate-keyword TypeError.
    plot_args.update({k: v for k, v in kwargs.items() if k not in ["ax", "transform"]})

    # Create the plot
    mesh = da.plot(ax=ax, transform=ccrs.PlateCarree(), **plot_args)

    # Add map features
    _add_map_features(ax, map_kws)

    # Set labels and title
    _set_axis_labels_and_title(ax, xlabel, ylabel, title)

    # Set custom ticks
    _set_custom_ticks(ax, xticks, yticks)

    # Add annotations
    _add_annotations(ax, annotations)

    # Add colorbar
    _add_colorbar(fig, ax, mesh, colorbar, cbar_label, cbar_inset)

    # Finalize
    fig.tight_layout()
    _export_figure(fig, export_path, export_formats, dpi)

    return fig, ax

Colorbar helper functions

cmap_discretize(cmap, N)

Return a discrete colormap from a continuous colormap.

Creates a new colormap by discretizing an existing continuous colormap into N distinct colors while preserving the color transitions.

Parameters

- `cmap` : str or matplotlib.colors.Colormap — Colormap instance or registered colormap name to discretize (e.g., cm.jet, 'viridis').
- `N` : int — Number of discrete colors to use in the new colormap.

Returns

`matplotlib.colors.LinearSegmentedColormap` — A new colormap object with N discrete colors based on the input colormap. Its name is the original colormap name with "_N" appended.

Source code in src/monet_plots/colorbars.py (lines 263–295)
def cmap_discretize(cmap, N):
    """Return a discrete colormap from a continuous colormap.

    Creates a new colormap by discretizing an existing continuous colormap
    into N distinct colors while preserving the color transitions.

    Parameters
    ----------
    cmap : str or matplotlib.colors.Colormap
        Colormap instance or registered colormap name to discretize.
        Example: cm.jet, 'viridis', etc.
    N : int
        Number of discrete colors to use in the new colormap. Must be >= 1.

    Returns
    -------
    matplotlib.colors.LinearSegmentedColormap
        A new colormap object with N discrete colors based on the input colormap.
        The name will be the original colormap name with "_N" appended.
        The alpha (transparency) channel of the input colormap is preserved.

    Raises
    ------
    ValueError
        If ``N`` is less than 1.
    """
    if N < 1:
        raise ValueError(f"N must be a positive integer, got {N!r}")
    if isinstance(cmap, str):
        cmap = plt.get_cmap(cmap)
    # Sample N evenly spaced colours. The trailing zero padding means index -1
    # (read for the first anchor below) is a valid row; its value is unused
    # because the left-hand colour at x=0 is ignored by matplotlib.
    colors_i = np.concatenate((np.linspace(0, 1.0, N), (0.0, 0.0, 0.0, 0.0)))
    colors_rgba = cmap(colors_i)
    indices = np.linspace(0, 1.0, N + 1)
    # Anchor i jumps from colour i-1 (left side) to colour i (right side),
    # which yields N flat bands. Alpha is included so transparency survives;
    # for opaque colormaps this matches matplotlib's default alpha of 1.
    cdict = {
        key: [
            (indices[i], colors_rgba[i - 1, ki], colors_rgba[i, ki])
            for i in range(N + 1)
        ]
        for ki, key in enumerate(("red", "green", "blue", "alpha"))
    }
    return mcolors.LinearSegmentedColormap(f"{cmap.name}_{N:d}", cdict, 1024)

colorbar_index(ncolors, cmap, minval=None, maxval=None, dtype='int', basemap=None, ax=None, **kwargs)

Create a colorbar with discrete colors and custom tick labels.

Parameters

ncolors : int Number of discrete colors to use in the colorbar. cmap : str or matplotlib.colors.Colormap Colormap to discretize and use for the colorbar. minval : float, optional Minimum value for the colorbar tick labels. If None and maxval is None, tick labels will range from 0 to ncolors. If None and maxval is provided, tick labels will range from 0 to maxval. maxval : float, optional Maximum value for the colorbar tick labels. If None, tick labels will range from 0 or minval to ncolors. dtype : str or type, default "int" Data type for tick label values (e.g., "int", "float"). basemap : mpl_toolkits.basemap.Basemap, optional Basemap instance to attach the colorbar to. If None, uses plt.colorbar. ax : matplotlib.axes.Axes, optional Axes to attach the colorbar to. If None, uses plt.gca(). **kwargs : Any Additional keyword arguments for plt.colorbar.

Returns

tuple (colorbar, discretized_cmap) where: - colorbar is the matplotlib.colorbar.Colorbar instance - discretized_cmap is the discretized colormap

Source code in src/monet_plots/colorbars.py
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
def colorbar_index(
    ncolors,
    cmap,
    minval=None,
    maxval=None,
    dtype="int",
    basemap=None,
    ax=None,
    **kwargs,
):
    """Create a colorbar with discrete colors and custom tick labels.

    Parameters
    ----------
    ncolors : int
        Number of discrete colors to use in the colorbar.
    cmap : str or matplotlib.colors.Colormap
        Colormap to discretize and use for the colorbar.
    minval : float, optional
        Minimum value for the colorbar tick labels. If None, labels start at 0.
    maxval : float, optional
        Maximum value for the colorbar tick labels. If None, labels end at
        ``ncolors`` (starting from 0 or ``minval``).
    dtype : str or type, default "int"
        Data type for tick label values (e.g., "int", "float").
    basemap : mpl_toolkits.basemap.Basemap, optional
        Basemap instance to attach the colorbar to. If None, uses plt.colorbar.
    ax : matplotlib.axes.Axes, optional
        Axes to attach the colorbar to via an inset axes on its right side.
        If None, plt.colorbar attaches to the current axes.
    **kwargs : Any
        Additional keyword arguments for the colorbar call (forwarded on every
        path, including the basemap one). They override the defaults below.

    Returns
    -------
    tuple
        (colorbar, discretized_cmap) where:
        - colorbar is the matplotlib.colorbar.Colorbar instance
        - discretized_cmap is the discretized colormap
    """
    import matplotlib.cm as cm

    cmap = cmap_discretize(cmap, ncolors)
    mappable = cm.ScalarMappable(cmap=cmap)
    mappable.set_array([])
    mappable.set_clim(-0.5, ncolors + 0.5)

    if basemap is not None:
        cbar_kwargs = {"format": "%1.2g"}
        cbar_kwargs.update(kwargs)
        colorbar = basemap.colorbar(mappable, **cbar_kwargs)
    elif ax is not None:
        from mpl_toolkits.axes_grid1.inset_locator import inset_axes

        # Use inset_axes to ensure the colorbar height matches the axes perfectly
        # regardless of aspect ratio or projection.
        cax = inset_axes(
            ax,
            width="5%",
            height="100%",
            loc="lower left",
            bbox_to_anchor=(1.05, 0.0, 1.0, 1.0),
            bbox_transform=ax.transAxes,
            borderpad=0,
        )
        cbar_kwargs = {"format": "%1.2g", "cax": cax}
        cbar_kwargs.update(kwargs)
        colorbar = plt.colorbar(mappable, **cbar_kwargs)
    else:
        # Fallback for case where no axes is provided
        cbar_kwargs = {"format": "%1.2g", "fraction": 0.046, "pad": 0.04}
        cbar_kwargs.update(kwargs)
        colorbar = plt.colorbar(mappable, **cbar_kwargs)

    colorbar.set_ticks(np.linspace(0, ncolors, ncolors))
    # Missing bounds default to 0 and ncolors. Previously, passing only minval
    # called linspace(minval, None, ...) and raised TypeError.
    lower = 0 if minval is None else minval
    upper = ncolors if maxval is None else maxval
    labels = np.linspace(lower, upper, ncolors).astype(dtype)
    colorbar.set_ticklabels(np.around(labels, 2))

    return colorbar, cmap

get_discrete_scale(data, cmap='viridis', n_levels=10, vmin=None, vmax=None, extend='both')

Get a discrete colormap and BoundaryNorm with 'nice' numbers.

Parameters

data : array-like The data to scale. cmap : str or matplotlib.colors.Colormap, optional The colormap to use, by default "viridis". n_levels : int, optional Target number of discrete levels, by default 10. vmin : float, optional Minimum value for the scale. vmax : float, optional Maximum value for the scale. extend : str, optional Whether to extend the scale ('neither', 'both', 'min', 'max'), by default "both".

Returns

tuple (colormap, BoundaryNorm)

Source code in src/monet_plots/colorbars.py
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
def get_discrete_scale(
    data, cmap="viridis", n_levels=10, vmin=None, vmax=None, extend="both"
):
    """
    Build a discrete colormap plus a matching BoundaryNorm on 'nice' levels.

    Parameters
    ----------
    data : array-like
        Values used to derive missing bounds (NaNs are ignored).
    cmap : str or matplotlib.colors.Colormap, optional
        Colormap name or instance to discretize, by default "viridis".
    n_levels : int, optional
        Target number of bins handed to ``MaxNLocator``, by default 10.
    vmin : float, optional
        Lower bound; defaults to the NaN-ignoring minimum of ``data``.
    vmax : float, optional
        Upper bound; defaults to the NaN-ignoring maximum of ``data``.
    extend : str, optional
        Extension mode for the norm ('neither', 'both', 'min', 'max'),
        by default "both".

    Returns
    -------
    tuple
        (colormap, BoundaryNorm)
    """
    lo = np.nanmin(data) if vmin is None else vmin
    hi = np.nanmax(data) if vmax is None else vmax

    # Restricting steps keeps every boundary on a round multiple of a power of ten.
    nice = MaxNLocator(nbins=n_levels, steps=[1, 2, 2.5, 5, 10])
    boundaries = nice.tick_values(lo, hi)

    base = plt.get_cmap(cmap) if isinstance(cmap, str) else cmap
    # One discrete colour per interval between consecutive boundaries.
    colormap = cmap_discretize(base, len(boundaries) - 1)

    # ncolors is the discretized colormap's lookup-table size (colormap.N).
    norm = mcolors.BoundaryNorm(boundaries, ncolors=colormap.N, extend=extend)
    return colormap, norm

get_diverging_scale(data, cmap='RdBu_r', center=0, span=None, p_span=None)

Get a diverging colormap and normalization object centered at a value.

Parameters

data : array-like The data to scale. cmap : str or matplotlib.colors.Colormap, optional The colormap to use, by default "RdBu_r". center : float, optional The value to center the scale at, by default 0. span : float, optional The absolute range from the center (center +/- span). p_span : float, optional The percentile of absolute differences from center to use as span.

Returns

tuple (colormap, Normalize)

Source code in src/monet_plots/colorbars.py
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
def get_diverging_scale(data, cmap="RdBu_r", center=0, span=None, p_span=None):
    """
    Get a diverging colormap and normalization object centered at a value.

    Parameters
    ----------
    data : array-like
        The data to scale. Plain sequences (lists/tuples) are accepted.
    cmap : str or matplotlib.colors.Colormap, optional
        The colormap to use, by default "RdBu_r".
    center : float, optional
        The value to center the scale at, by default 0.
    span : float, optional
        The absolute range from the center (center +/- span). Takes
        precedence over ``p_span``.
    p_span : float, optional
        The percentile of absolute differences from center to use as span.
        If neither ``span`` nor ``p_span`` is given, the maximum absolute
        difference (NaNs ignored) is used.

    Returns
    -------
    tuple
        (colormap, Normalize)
    """
    if span is None:
        # Lists/tuples do not support ``data - center``; convert them. Objects
        # that already carry a shape (numpy, xarray, pandas) pass through as-is.
        if not hasattr(data, "shape"):
            data = np.asarray(data)
        abs_diff = np.abs(data - center)
        if p_span is not None:
            span = np.nanpercentile(abs_diff, p_span)
        else:
            span = np.nanmax(abs_diff)

    norm = mcolors.Normalize(vmin=center - span, vmax=center + span)
    if isinstance(cmap, str):
        cmap = plt.get_cmap(cmap)

    return cmap, norm

get_linear_scale(data, cmap='viridis', vmin=None, vmax=None, p_min=None, p_max=None)

Get a linear colormap and normalization object.

Parameters

data : array-like The data to scale. cmap : str or matplotlib.colors.Colormap, optional The colormap to use, by default "viridis". vmin : float, optional Minimum value for the scale. Ignored when p_min is given; if both are None, uses min(data). vmax : float, optional Maximum value for the scale. Ignored when p_max is given; if both are None, uses max(data). p_min : float, optional Percentile for minimum value (0-100). p_max : float, optional Percentile for maximum value (0-100).

Returns

tuple (colormap, Normalize)

Source code in src/monet_plots/colorbars.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
def get_linear_scale(
    data, cmap="viridis", vmin=None, vmax=None, p_min=None, p_max=None
):
    """
    Build a colormap and a linear Normalize spanning the data (or given bounds).

    Parameters
    ----------
    data : array-like
        Values used for percentile / min / max bounds (NaNs are ignored).
    cmap : str or matplotlib.colors.Colormap, optional
        Colormap name or instance, by default "viridis".
    vmin : float, optional
        Lower bound. Ignored when ``p_min`` is given; if both are None the
        NaN-ignoring minimum of ``data`` is used.
    vmax : float, optional
        Upper bound. Ignored when ``p_max`` is given; if both are None the
        NaN-ignoring maximum of ``data`` is used.
    p_min : float, optional
        Percentile (0-100) of ``data`` to use as the lower bound.
    p_max : float, optional
        Percentile (0-100) of ``data`` to use as the upper bound.

    Returns
    -------
    tuple
        (colormap, Normalize)
    """
    # A requested percentile overrides any explicit vmin/vmax.
    lower = vmin if p_min is None else np.nanpercentile(data, p_min)
    upper = vmax if p_max is None else np.nanpercentile(data, p_max)

    if lower is None:
        lower = np.nanmin(data)
    if upper is None:
        upper = np.nanmax(data)

    colormap = plt.get_cmap(cmap) if isinstance(cmap, str) else cmap
    return colormap, mcolors.Normalize(vmin=lower, vmax=upper)

get_log_scale(data, cmap='viridis', vmin=None, vmax=None)

Get a logarithmic colormap and normalization object.

Parameters

data : array-like The data to scale. cmap : str or matplotlib.colors.Colormap, optional The colormap to use, by default "viridis". vmin : float, optional Minimum value for the scale (>0). vmax : float, optional Maximum value for the scale.

Returns

tuple (colormap, LogNorm)

Source code in src/monet_plots/colorbars.py
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
def get_log_scale(data, cmap="viridis", vmin=None, vmax=None):
    """
    Get a logarithmic colormap and normalization object.

    Parameters
    ----------
    data : array-like
        The data to scale (lists, numpy/masked arrays, pandas and
        N-dimensional xarray objects are accepted).
    cmap : str or matplotlib.colors.Colormap, optional
        The colormap to use, by default "viridis".
    vmin : float, optional
        Minimum value for the scale (>0). Defaults to the smallest positive
        value in ``data``, or 0.1 if ``data`` has no positive values.
    vmax : float, optional
        Maximum value for the scale. Defaults to the NaN-ignoring maximum.

    Returns
    -------
    tuple
        (colormap, LogNorm)
    """
    # Mask on an ndarray view. Lists cannot be compared with ``> 0``, and
    # xarray does not accept N-D boolean masks as indexers. asanyarray keeps
    # numpy masked arrays (and their masks) intact.
    values = np.asanyarray(data)
    positive = values[values > 0]
    if vmin is None:
        vmin = np.nanmin(positive) if positive.size > 0 else 1e-1
    if vmax is None:
        vmax = np.nanmax(values)

    norm = mcolors.LogNorm(vmin=vmin, vmax=vmax)
    if isinstance(cmap, str):
        cmap = plt.get_cmap(cmap)

    return cmap, norm

get_logo_path(name='monet_plots.png')

Get the path to a bundled logo asset.

Parameters

name : str, optional The name of the logo file, by default "monet_plots.png".

Returns

str The full path to the logo file.

Source code in src/monet_plots/plot_utils.py
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
def get_logo_path(name: str = "monet_plots.png") -> str:
    """
    Build the path of a logo image inside this package's ``assets`` folder.

    Parameters
    ----------
    name : str, optional
        File name of the logo inside ``assets``, by default "monet_plots.png".

    Returns
    -------
    str
        ``<directory of this module>/assets/<name>``. The file's existence
        is not checked.
    """
    import os

    assets_dir = os.path.join(os.path.dirname(__file__), "assets")
    return os.path.join(assets_dir, name)

get_plot_kwargs(cmap=None, norm=None, **kwargs)

Helper to prepare keyword arguments for plotting functions.

This function handles cases where cmap might be a tuple of (colormap, norm) returned by the scaling tools in colorbars.py.

Parameters

cmap : Any, optional Colormap name, object, or (colormap, norm) tuple. norm : Any, optional Normalization object. **kwargs : Any Additional keyword arguments.

Returns

dict A dictionary of keyword arguments suitable for matplotlib plotting functions.

Source code in src/monet_plots/plot_utils.py
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
def get_plot_kwargs(cmap: Any = None, norm: Any = None, **kwargs: Any) -> dict:
    """
    Merge colormap/normalization options into a plotting kwargs dict.

    A 2-tuple ``cmap`` is treated as the ``(colormap, norm)`` pair that the
    scale helpers in ``colorbars.py`` return, and is split into both keys.

    Parameters
    ----------
    cmap : Any, optional
        Colormap name/object, or a ``(colormap, norm)`` tuple.
    norm : Any, optional
        Normalization object. When given, it wins over a norm packed in ``cmap``.
    **kwargs : Any
        Other plotting keyword arguments, passed through unchanged.

    Returns
    -------
    dict
        The keyword arguments, with ``cmap``/``norm`` set when provided.
    """
    is_scale_pair = isinstance(cmap, tuple) and len(cmap) == 2
    if is_scale_pair:
        kwargs["cmap"], kwargs["norm"] = cmap
    elif cmap is not None:
        kwargs["cmap"] = cmap

    # An explicit norm overrides whatever came bundled with cmap.
    if norm is not None:
        kwargs["norm"] = norm
    return kwargs

normalize_data(data, prefer_xarray=True)

Public API for normalizing data.

Parameters:

Name Type Description Default
data Any

Input data of various types

required
prefer_xarray bool

If True, attempts to convert non-pandas/non-xarray objects to xarray. If False, returns pandas objects as-is and converts others to DataFrame.

True

Returns:

Type Description
Any

Either an xarray DataArray, xarray Dataset, or pandas DataFrame

Source code in src/monet_plots/plot_utils.py
285
286
287
288
289
290
291
292
293
294
295
296
297
def normalize_data(data: Any, prefer_xarray: bool = True) -> Any:
    """
    Public entry point that delegates to the internal ``_normalize_data``.

    Args:
        data: Input data of various types.
        prefer_xarray: If True, attempts to convert non-pandas/non-xarray
            objects to xarray. If False, returns pandas objects as-is and
            converts others to DataFrame.

    Returns:
        Either an xarray DataArray, xarray Dataset, or pandas DataFrame.
    """
    normalized = _normalize_data(data, prefer_xarray=prefer_xarray)
    return normalized

to_dataframe(data)

Convert input data to a pandas DataFrame.

Parameters:

Name Type Description Default
data Any

Input data. Can be a pandas DataFrame, xarray DataArray, xarray Dataset, or numpy ndarray.

required

Returns:

Type Description
DataFrame

A pandas DataFrame.

Raises:

Type Description
TypeError

If the input data type is not supported.

ValueError

If a numpy array input is not one- or two-dimensional.

Source code in src/monet_plots/plot_utils.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
def to_dataframe(data: Any) -> pd.DataFrame:
    """
    Convert input data to a pandas DataFrame.

    Args:
        data: Input data. Can be a pandas DataFrame or Series, xarray
              DataArray, xarray Dataset, or a 1-D/2-D numpy ndarray.

    Returns:
        A pandas DataFrame. DataFrame input is returned as-is (not copied).

    Raises:
        TypeError: If the input data type is not supported.
        ValueError: If a numpy array input is not one- or two-dimensional.
    """
    if isinstance(data, pd.DataFrame):
        return data

    # A Series is a single column; to_frame keeps its name as the header.
    if isinstance(data, pd.Series):
        return data.to_frame()

    # Using hasattr to avoid direct dependency on xarray for users who don't have it
    # installed.
    if hasattr(data, "to_dataframe"):  # Works for both xarray DataArray and Dataset
        return data.to_dataframe()

    if isinstance(data, np.ndarray):
        if data.ndim == 1:
            return pd.DataFrame(data, columns=["col_0"])
        if data.ndim == 2:
            return pd.DataFrame(
                data, columns=[f"col_{i}" for i in range(data.shape[1])]
            )
        raise ValueError(f"numpy array with {data.ndim} dimensions not supported")

    raise TypeError(f"Unsupported data type: {type(data).__name__}")

validate_data_array(data, required_dims=None)

Validate data array parameters.

Parameters:

Name Type Description Default
data Any

Data to validate

required
required_dims Optional[list]

List of required dimension names

None

Raises:

Type Description
TypeError

If data type is invalid

ValueError

If data dimensions are invalid

Source code in src/monet_plots/plot_utils.py
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
def validate_data_array(data: Any, required_dims: Optional[list] = None) -> None:
    """
    Check that ``data`` is array-like and, optionally, has the named dimensions.

    Args:
        data: Object to check; must expose a ``shape`` attribute.
        required_dims: Dimension names that must appear in ``data.dims``.
            When empty or None, no dimension check is made.

    Raises:
        ValueError: If ``data`` is None or a required dimension is missing.
        TypeError: If ``data`` lacks ``shape``, or lacks ``dims`` while
            dimensions were requested.
    """
    if data is None:
        raise ValueError("data cannot be None")
    if not hasattr(data, "shape"):
        raise TypeError("data must have a shape attribute")

    # Nothing more to check unless specific dimensions were requested.
    if not required_dims:
        return
    if not hasattr(data, "dims"):
        raise TypeError("data must have dims attribute for dimension validation")

    dims = data.dims
    for wanted in required_dims:
        if wanted in dims:
            continue
        raise ValueError(
            f"required dimension '{wanted}' not found in data dimensions {dims}"
        )

validate_dataframe(df, required_columns=None)

Validate DataFrame parameters.

Parameters:

Name Type Description Default
df Any

DataFrame to validate

required
required_columns Optional[list]

List of required column names

None

Raises:

Type Description
TypeError

If DataFrame type is invalid

ValueError

If DataFrame structure is invalid

Source code in src/monet_plots/plot_utils.py
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
def validate_dataframe(df: Any, required_columns: Optional[list] = None) -> None:
    """
    Check that ``df`` is a non-empty, column-bearing table with given columns.

    Args:
        df: Object to check; must expose ``columns`` and support ``len``.
        required_columns: Column names that must all be present.

    Raises:
        ValueError: If ``df`` is None, required columns are missing, or it
            has no rows (checked in that order).
        TypeError: If ``df`` has no ``columns`` attribute.
    """
    if df is None:
        raise ValueError("DataFrame cannot be None")
    if not hasattr(df, "columns"):
        raise TypeError("object must have columns attribute")

    # Report every absent column at once, in the order they were requested.
    missing = [name for name in (required_columns or []) if name not in df.columns]
    if missing:
        raise ValueError(f"missing required columns: {missing}")

    if not len(df):
        raise ValueError("DataFrame cannot be empty")

validate_plot_parameters(plot_class, method, **kwargs)

Validate parameters for plot methods.

Parameters:

Name Type Description Default
plot_class str

The plot class name

required
method str

The method name

required
**kwargs

Parameters to validate

{}

Raises:

Type Description
TypeError

If parameter types are invalid

ValueError

If parameter values are invalid

Source code in src/monet_plots/plot_utils.py
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
def validate_plot_parameters(plot_class: str, method: str, **kwargs) -> None:
    """
    Dispatch keyword-argument validation for a plot class/method pair.

    Only ``SpatialPlot.plot`` and ``TimeSeriesPlot.plot`` have validators;
    any other combination is accepted without checks.

    Args:
        plot_class: Name of the plot class, e.g. "SpatialPlot".
        method: Name of the method being called, e.g. "plot".
        **kwargs: Parameters to validate, passed to the validator as a dict.

    Raises:
        TypeError: If parameter types are invalid (raised by the validator).
        ValueError: If parameter values are invalid (raised by the validator).
    """
    # Only the ``plot`` method currently has parameter validation.
    if method != "plot":
        return
    if plot_class == "SpatialPlot":
        _validate_spatial_plot_params(kwargs)
    elif plot_class == "TimeSeriesPlot":
        _validate_timeseries_plot_params(kwargs)

BasePlot

Base class for all plots.

Handles figure and axis creation, applies a consistent style, and provides a common interface for saving and closing plots.

Source code in src/monet_plots/plots/base.py
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
class BasePlot:
    """Base class for all plots.

    Handles figure and axis creation, applies a consistent style,
    and provides a common interface for saving and closing plots.
    """

    # Seconds to wait when downloading a logo from a URL in ``add_logo``.
    # Without a timeout, ``urlopen`` can block indefinitely on a dead host.
    _LOGO_URL_TIMEOUT = 10

    def __init__(self, fig=None, ax=None, style: str | None = "wiley", **kwargs):
        """Initializes the plot with a consistent style.

        If `fig` and `ax` are not provided, a new figure and axes
        are created. If only `fig` is provided, `self.ax` is set to None.

        Args:
            fig (matplotlib.figure.Figure, optional): Figure to plot on.
            ax (matplotlib.axes.Axes, optional): Axes to plot on.
            style (str, optional): Style name to apply (e.g., 'wiley', 'paper').
                If None, no style is applied. Defaults to 'wiley'.
            **kwargs: Additional keyword arguments for `plt.subplots`.
        """
        if style:
            set_style(style)

        if ax is not None:
            self.ax = ax
            if fig is not None:
                self.fig = fig
            else:
                self.fig = ax.figure
        elif fig is not None:
            self.fig = fig
            self.ax = None
        else:
            self.fig, self.ax = plt.subplots(**kwargs)

    def save(self, filename, **kwargs):
        """Saves the plot to a file.

        Args:
            filename (str): The name of the file to save the plot to.
            **kwargs: Additional keyword arguments for `savefig`.
        """
        self.fig.savefig(filename, **kwargs)

    def close(self):
        """Closes the plot figure."""
        plt.close(self.fig)

    def add_logo(
        self,
        logo: str | Any | None = None,
        *,
        ax: matplotlib.axes.Axes | None = None,
        loc: str | tuple[float, float] = "upper right",
        scale: float = 0.1,
        pad: float = 0.05,
        **kwargs: Any,
    ) -> Any:
        """Adds a logo to the plot.

        Parameters
        ----------
        logo : str or array-like, optional
            Path to the logo image, an ``http://``/``https://`` URL, or a numpy
            array. If None, the default MONET logo is used.
        ax : matplotlib.axes.Axes, optional
            The axes to add the logo to. Defaults to `self.ax`.
        loc : str or tuple of float, optional
            Location of the logo ('upper right', 'upper left', 'lower right',
            'lower left', 'center'), or an ``(x, y)`` pair in axes-fraction
            coordinates on which the logo is centred. Unrecognised values fall
            back to "upper right". Defaults to "upper right".
        scale : float, optional
            Scaling factor for the logo, by default 0.1.
        pad : float, optional
            Padding from the edge of the axes, by default 0.05.
        **kwargs : Any
            Additional keyword arguments passed to `AnnotationBbox`.

        Returns
        -------
        matplotlib.offsetbox.AnnotationBbox
            The added logo object.
        """
        import matplotlib.image as mpimg
        from matplotlib.offsetbox import AnnotationBbox, OffsetImage

        from ..plot_utils import get_logo_path

        if ax is None:
            ax = self.ax

        if logo is None:
            logo = get_logo_path()

        if isinstance(logo, str):
            # Match an actual URL scheme so local files whose names merely
            # begin with "http" are still read from disk.
            if logo.startswith(("http://", "https://")):
                import io
                import urllib.request

                with urllib.request.urlopen(
                    logo, timeout=self._LOGO_URL_TIMEOUT
                ) as url:
                    f = io.BytesIO(url.read())
                img = mpimg.imread(f)
            else:
                img = mpimg.imread(logo)
        else:
            img = logo

        imagebox = OffsetImage(img, zoom=scale)
        imagebox.image.axes = ax

        # Mapping of location strings to axes fraction coordinates and box alignment
        loc_map = {
            "upper right": ((1 - pad, 1 - pad), (1, 1)),
            "upper left": ((pad, 1 - pad), (0, 1)),
            "lower right": ((1 - pad, pad), (1, 0)),
            "lower left": ((pad, pad), (0, 0)),
            "center": ((0.5, 0.5), (0.5, 0.5)),
        }

        if isinstance(loc, str) and loc in loc_map:
            xy, box_alignment = loc_map[loc]
        elif isinstance(loc, (tuple, list)) and len(loc) == 2:
            # Explicit axes-fraction coordinates; centre the logo on them.
            xy = tuple(loc)
            box_alignment = (0.5, 0.5)
        else:
            # Unknown location names fall back to the upper-right corner.
            xy, box_alignment = loc_map["upper right"]

        ab = AnnotationBbox(
            imagebox,
            xy,
            xycoords="axes fraction",
            box_alignment=box_alignment,
            pad=0,
            frameon=False,
            **kwargs,
        )

        ax.add_artist(ab)
        return ab

    def add_colorbar(
        self,
        mappable: matplotlib.cm.ScalarMappable,
        *,
        ax: matplotlib.axes.Axes | None = None,
        label: str | None = None,
        loc: str = "right",
        size: str = "5%",
        pad: float = 0.05,
        **kwargs: Any,
    ) -> matplotlib.colorbar.Colorbar:
        """Add a colorbar that matches the axes size.

        This method uses `inset_axes` to ensure the colorbar height (or width)
        matches the axes dimensions exactly, which is particularly useful for
        geospatial plots with fixed aspects. Ticks and label are placed on the
        side of the colorbar facing away from the axes.

        Parameters
        ----------
        mappable : matplotlib.cm.ScalarMappable
            The mappable object (e.g., from imshow, scatter, contourf).
        ax : matplotlib.axes.Axes, optional
            The axes to attach the colorbar to. Defaults to `self.ax`.
        label : str, optional
            Label for the colorbar, by default None.
        loc : str, optional
            Location of the colorbar ('right', 'left', 'top', 'bottom'),
            by default "right".
        size : str, optional
            Width (if vertical) or height (if horizontal) of the colorbar,
            as a percentage of the axes, by default "5%".
        pad : float, optional
            Padding between the axes and the colorbar, by default 0.05.
        **kwargs : Any
            Additional keyword arguments passed to `fig.colorbar`
            (e.g. ``ticklocation`` to override the tick side).

        Returns
        -------
        matplotlib.colorbar.Colorbar
            The created colorbar object.

        Raises
        ------
        ValueError
            If ``loc`` is not one of 'right', 'left', 'top', 'bottom'.
        """
        from mpl_toolkits.axes_grid1.inset_locator import inset_axes

        if ax is None:
            ax = self.ax

        # Anchor box (x0, y0, w, h) in axes fraction. Left/bottom bars are
        # shifted outward by their own thickness plus the pad.
        if loc == "right":
            bbox_to_anchor = (1.0 + pad, 0.0, 1.0, 1.0)
            width, height = size, "100%"
        elif loc == "left":
            bbox_to_anchor = (-(float(size.strip("%")) / 100.0 + pad), 0.0, 1.0, 1.0)
            width, height = size, "100%"
        elif loc == "top":
            bbox_to_anchor = (0.0, 1.0 + pad, 1.0, 1.0)
            width, height = "100%", size
        elif loc == "bottom":
            bbox_to_anchor = (0.0, -(float(size.strip("%")) / 100.0 + pad), 1.0, 1.0)
            width, height = "100%", size
        else:
            raise ValueError(
                f"loc must be one of 'right', 'left', 'top' or 'bottom', got {loc!r}"
            )

        orientation = "vertical" if loc in ("right", "left") else "horizontal"
        # Put ticks/label on the outer side. For 'left'/'top' the defaults
        # ('right'/'bottom') would overlap the main axes. For 'right'/'bottom'
        # this equals matplotlib's default.
        kwargs.setdefault("ticklocation", loc)

        cax = inset_axes(
            ax,
            width=width,
            height=height,
            loc="lower left",
            bbox_to_anchor=bbox_to_anchor,
            bbox_transform=ax.transAxes,
            borderpad=0,
        )

        cb = self.fig.colorbar(mappable, cax=cax, orientation=orientation, **kwargs)

        if label:
            cb.set_label(label)

        return cb

__init__(fig=None, ax=None, style='wiley', **kwargs)

Initializes the plot with a consistent style.

If fig and ax are not provided, a new figure and axes are created.

Parameters:

Name Type Description Default
fig Figure

Figure to plot on.

None
ax Axes

Axes to plot on.

None
style str

Style name to apply (e.g., 'wiley', 'paper'). If None, no style is applied. Defaults to 'wiley'.

'wiley'
**kwargs

Additional keyword arguments for plt.subplots.

{}
Source code in src/monet_plots/plots/base.py
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
def __init__(self, fig=None, ax=None, style: str | None = "wiley", **kwargs):
    """Initializes the plot with a consistent style.

    If `fig` and `ax` are not provided, a new figure and axes
    are created.

    Args:
        fig (matplotlib.figure.Figure, optional): Figure to plot on.
        ax (matplotlib.axes.Axes, optional): Axes to plot on.
        style (str, optional): Style name to apply (e.g., 'wiley', 'paper').
            If None, no style is applied. Defaults to 'wiley'.
        **kwargs: Additional keyword arguments for `plt.subplots`.
    """
    if style:
        set_style(style)

    if ax is not None:
        self.ax = ax
        if fig is not None:
            self.fig = fig
        else:
            self.fig = ax.figure
    elif fig is not None:
        self.fig = fig
        self.ax = None
    else:
        self.fig, self.ax = plt.subplots(**kwargs)

add_colorbar(mappable, *, ax=None, label=None, loc='right', size='5%', pad=0.05, **kwargs)

Add a colorbar that matches the axes size.

This method uses inset_axes to ensure the colorbar height (or width) matches the axes dimensions exactly, which is particularly useful for geospatial plots with fixed aspects.

Parameters

mappable : matplotlib.cm.ScalarMappable The mappable object (e.g., from imshow, scatter, contourf). ax : matplotlib.axes.Axes, optional The axes to attach the colorbar to. Defaults to self.ax. label : str, optional Label for the colorbar, by default None. loc : str, optional Location of the colorbar ('right', 'left', 'top', 'bottom'), by default "right". size : str, optional Width (if vertical) or height (if horizontal) of the colorbar, as a percentage of the axes, by default "5%". pad : float, optional Padding between the axes and the colorbar, by default 0.05. **kwargs : Any Additional keyword arguments passed to fig.colorbar.

Returns

matplotlib.colorbar.Colorbar The created colorbar object.

Source code in src/monet_plots/plots/base.py
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
def add_colorbar(
    self,
    mappable: matplotlib.cm.ScalarMappable,
    *,
    ax: matplotlib.axes.Axes | None = None,
    label: str | None = None,
    loc: str = "right",
    size: str = "5%",
    pad: float = 0.05,
    **kwargs: Any,
) -> matplotlib.colorbar.Colorbar:
    """Add a colorbar that matches the axes size.

    This method uses `inset_axes` to ensure the colorbar height (or width)
    matches the axes dimensions exactly, which is particularly useful for
    geospatial plots with fixed aspects.

    Parameters
    ----------
    mappable : matplotlib.cm.ScalarMappable
        The mappable object (e.g., from imshow, scatter, contourf).
    ax : matplotlib.axes.Axes, optional
        The axes to attach the colorbar to. Defaults to `self.ax`.
    label : str, optional
        Label for the colorbar, by default None (no label is set).
    loc : str, optional
        Location of the colorbar ('right', 'left', 'top', 'bottom'),
        by default "right".
    size : str, optional
        Width (if vertical) or height (if horizontal) of the colorbar,
        as a percentage string of the axes (e.g. "5%"), by default "5%".
    pad : float, optional
        Gap between the axes and the colorbar, in axes-fraction units,
        by default 0.05.
    **kwargs : Any
        Additional keyword arguments passed to `fig.colorbar`.

    Returns
    -------
    matplotlib.colorbar.Colorbar
        The created colorbar object.

    Raises
    ------
    ValueError
        If `loc` is not one of the supported locations.
    """
    valid_locs = ("right", "left", "top", "bottom")
    # Reject unknown locations explicitly instead of silently drawing the
    # colorbar at the bottom.
    if loc not in valid_locs:
        raise ValueError(f"loc must be one of {valid_locs}, got {loc!r}")

    from mpl_toolkits.axes_grid1.inset_locator import inset_axes

    if ax is None:
        ax = self.ax

    vertical = loc in ("right", "left")
    orientation = "vertical" if vertical else "horizontal"

    # The anchor box is expressed in axes coordinates (bbox_transform below),
    # so an offset of 1.0 equals the full axes width/height.
    if loc == "right":
        bbox_to_anchor = (1.0 + pad, 0.0, 1.0, 1.0)
    elif loc == "top":
        bbox_to_anchor = (0.0, 1.0 + pad, 1.0, 1.0)
    else:
        # left/bottom: shift back by the bar thickness plus pad so the bar's
        # inner edge ends `pad` away from the axes.
        offset = -(float(size.strip("%")) / 100.0 + pad)
        if loc == "left":
            bbox_to_anchor = (offset, 0.0, 1.0, 1.0)
        else:
            bbox_to_anchor = (0.0, offset, 1.0, 1.0)

    width, height = (size, "100%") if vertical else ("100%", size)

    cax = inset_axes(
        ax,
        width=width,
        height=height,
        loc="lower left",
        bbox_to_anchor=bbox_to_anchor,
        bbox_transform=ax.transAxes,
        borderpad=0,
    )

    cb = self.fig.colorbar(mappable, cax=cax, orientation=orientation, **kwargs)

    if label:
        cb.set_label(label)

    return cb

add_logo(logo=None, *, ax=None, loc='upper right', scale=0.1, pad=0.05, **kwargs)

Adds a logo to the plot.

Parameters

logo : str or array-like, optional Path to the logo image, a URL, or a numpy array. If None, the default MONET logo is used. ax : matplotlib.axes.Axes, optional The axes to add the logo to. Defaults to self.ax. loc : str, optional Location of the logo ('upper right', 'upper left', 'lower right', 'lower left', 'center'). Defaults to "upper right". scale : float, optional Scaling factor for the logo, by default 0.1. pad : float, optional Padding from the edge of the axes, by default 0.05. **kwargs : Any Additional keyword arguments passed to AnnotationBbox.

Returns

matplotlib.offsetbox.AnnotationBbox The added logo object.

Source code in src/monet_plots/plots/base.py
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
def add_logo(
    self,
    logo: str | Any | None = None,
    *,
    ax: matplotlib.axes.Axes | None = None,
    loc: str = "upper right",
    scale: float = 0.1,
    pad: float = 0.05,
    **kwargs: Any,
) -> Any:
    """Adds a logo to the plot.

    Parameters
    ----------
    logo : str, os.PathLike or array-like, optional
        Path to the logo image, an ``http://``/``https://`` URL, or an image
        array. If None, the default MONET logo (``get_logo_path()``) is used.
    ax : matplotlib.axes.Axes, optional
        The axes to add the logo to. Defaults to `self.ax`.
    loc : str or tuple of float, optional
        Location of the logo ('upper right', 'upper left', 'lower right',
        'lower left', 'center'), or an ``(x, y)`` axes-fraction tuple at
        which the logo is centred. Unrecognised values fall back to
        "upper right". Defaults to "upper right".
    scale : float, optional
        Zoom factor for the logo image, by default 0.1.
    pad : float, optional
        Padding from the edge of the axes, in axes-fraction units,
        by default 0.05.
    **kwargs : Any
        Additional keyword arguments passed to `AnnotationBbox`.

    Returns
    -------
    matplotlib.offsetbox.AnnotationBbox
        The added logo object.
    """
    import os

    import matplotlib.image as mpimg
    from matplotlib.offsetbox import AnnotationBbox, OffsetImage

    if ax is None:
        ax = self.ax

    if logo is None:
        # Only needed when falling back to the bundled logo.
        from ..plot_utils import get_logo_path

        logo = get_logo_path()

    # Accept pathlib.Path and other path-like objects the same way as strings.
    if isinstance(logo, os.PathLike):
        logo = os.fspath(logo)

    if isinstance(logo, str):
        # Require a real URL scheme so local files whose names merely start
        # with "http" (e.g. "http_logo.png") are read from disk.
        if logo.startswith(("http://", "https://")):
            import io
            import urllib.request

            with urllib.request.urlopen(logo) as url:
                f = io.BytesIO(url.read())
            img = mpimg.imread(f)
        else:
            img = mpimg.imread(logo)
    else:
        img = logo

    imagebox = OffsetImage(img, zoom=scale)
    imagebox.image.axes = ax

    # Mapping of location strings to axes fraction coordinates and box alignment
    loc_map = {
        "upper right": ((1 - pad, 1 - pad), (1, 1)),
        "upper left": ((pad, 1 - pad), (0, 1)),
        "lower right": ((1 - pad, pad), (1, 0)),
        "lower left": ((pad, pad), (0, 0)),
        "center": ((0.5, 0.5), (0.5, 0.5)),
    }

    if loc in loc_map:
        xy, box_alignment = loc_map[loc]
    elif isinstance(loc, tuple) and len(loc) == 2:
        # Explicit axes-fraction coordinates: centre the logo on them.
        xy = loc
        box_alignment = (0.5, 0.5)
    else:
        xy, box_alignment = loc_map["upper right"]

    ab = AnnotationBbox(
        imagebox,
        xy,
        xycoords="axes fraction",
        box_alignment=box_alignment,
        pad=0,
        frameon=False,
        **kwargs,
    )

    ax.add_artist(ab)
    return ab

close()

Closes the plot figure.

Source code in src/monet_plots/plots/base.py
62
63
64
def close(self):
    """Close the figure owned by this plot via ``plt.close``."""
    figure = self.fig
    plt.close(figure)

save(filename, **kwargs)

Saves the plot to a file.

Parameters:

Name Type Description Default
filename str

The name of the file to save the plot to.

required
**kwargs

Additional keyword arguments for savefig.

{}
Source code in src/monet_plots/plots/base.py
53
54
55
56
57
58
59
60
def save(self, filename, **kwargs):
    """Write the plot's figure to disk.

    Args:
        filename (str): Destination passed straight to ``Figure.savefig``.
        **kwargs: Extra options forwarded unchanged to ``savefig``
            (e.g. ``dpi``, ``bbox_inches``).
    """
    figure = self.fig
    figure.savefig(filename, **kwargs)

BivariatePolarPlot

Bases: BasePlot

Bivariate polar plot.

Shows how a variable varies with wind speed and wind direction. Uses polar coordinates where the angle represents wind direction and the radius represents wind speed.

Source code in src/monet_plots/plots/polar.py
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
class BivariatePolarPlot(BasePlot):
    """Bivariate polar plot.

    Shows how a variable varies with wind speed and wind direction.
    Uses polar coordinates where the angle represents wind direction
    and the radius represents wind speed.
    """

    def __init__(
        self,
        data: Any,
        ws_col: str,
        wd_col: str,
        val_col: str,
        *,
        ws_max: Optional[float] = None,
        **kwargs,
    ):
        """
        Initialize Bivariate Polar Plot.

        Args:
            data: Input data (DataFrame, DataArray, etc.), converted with
                ``to_dataframe``. Rows with NaN in any of the three columns
                are dropped.
            ws_col: Column name for wind speed.
            wd_col: Column name for wind direction (degrees).
            val_col: Column name for the value to plot.
            ws_max: Maximum wind speed to show. Falls back to the maximum of
                ``ws_col`` when not given (or falsy).
            **kwargs: Arguments passed to BasePlot. ``subplot_kw`` receives
                ``projection="polar"`` unless it already sets a projection;
                the caller's ``subplot_kw`` dict is copied, never modified.
        """
        # Work on a copy so the polar default is never written into the
        # caller's dict; an explicit ``subplot_kw=None`` is treated as empty.
        subplot_kw = dict(kwargs.get("subplot_kw") or {})
        subplot_kw.setdefault("projection", "polar")
        kwargs["subplot_kw"] = subplot_kw

        super().__init__(**kwargs)
        self.df = to_dataframe(data).dropna(subset=[ws_col, wd_col, val_col])
        self.ws_col = ws_col
        self.wd_col = wd_col
        self.val_col = val_col
        self.ws_max = ws_max or self.df[ws_col].max()

    def plot(
        self, n_bins_ws: int = 10, n_bins_wd: int = 36, cmap: str = "viridis", **kwargs
    ):
        """
        Generate the bivariate polar plot.

        Samples are binned by wind direction and wind speed, and each cell is
        colored by the mean of ``val_col`` over the samples it contains.
        Cells without samples are NaN.

        Args:
            n_bins_ws: Number of equal-width wind-speed bins from 0 to ``ws_max``.
            n_bins_wd: Number of equal-width direction bins spanning 0-360 degrees.
            cmap: Colormap passed to ``pcolormesh``.
            **kwargs: Additional keyword arguments for ``pcolormesh``.

        Returns:
            matplotlib.axes.Axes: The polar axes containing the plot.
        """
        # Directions stay in compass convention (0 = North, clockwise); the
        # axes is reoriented to match instead of transforming the data.
        theta_rad = np.radians(self.df[self.wd_col])
        self.ax.set_theta_zero_location("N")
        self.ax.set_theta_direction(-1)  # Clockwise

        ws_bins = np.linspace(0, self.ws_max, n_bins_ws + 1)
        wd_bins = np.radians(np.linspace(0, 360, n_bins_wd + 1))
        ws = self.df[self.ws_col]

        # Weighted sum and sample count per (direction, speed) cell; samples
        # outside the bin ranges (e.g. ws > ws_max) are not counted.
        sums, _, _ = np.histogram2d(
            theta_rad,
            ws,
            bins=[wd_bins, ws_bins],
            weights=self.df[self.val_col],
        )
        counts, _, _ = np.histogram2d(theta_rad, ws, bins=[wd_bins, ws_bins])

        # Mean per cell; empty cells (0 / 0) become NaN.
        with np.errstate(divide="ignore", invalid="ignore"):
            cell_mean = sums / counts

        # pcolormesh takes cell edges; mesh rows are speeds, columns directions,
        # hence the transpose of the (direction, speed) histogram.
        theta_edges, r_edges = np.meshgrid(wd_bins, ws_bins)
        mappable = self.ax.pcolormesh(
            theta_edges, r_edges, cell_mean.T, cmap=cmap, **kwargs
        )
        self.add_colorbar(mappable, label=self.val_col)

        self.ax.set_ylim(0, self.ws_max)
        return self.ax

__init__(data, ws_col, wd_col, val_col, *, ws_max=None, **kwargs)

Initialize Bivariate Polar Plot.

Parameters:

Name Type Description Default
data Any

Input data (DataFrame, DataArray, etc.).

required
ws_col str

Column name for wind speed.

required
wd_col str

Column name for wind direction (degrees).

required
val_col str

Column name for the value to plot.

required
ws_max Optional[float]

Maximum wind speed to show.

None
**kwargs

Arguments passed to BasePlot. Note: 'subplot_kw={"projection": "polar"}' is added automatically if not provided.

{}
Source code in src/monet_plots/plots/polar.py
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
def __init__(
    self,
    data: Any,
    ws_col: str,
    wd_col: str,
    val_col: str,
    *,
    ws_max: Optional[float] = None,
    **kwargs,
):
    """
    Initialize Bivariate Polar Plot.

    Args:
        data: Input data (DataFrame, DataArray, etc.), converted with
            ``to_dataframe``. Rows with NaN in any of the three columns
            are dropped.
        ws_col: Column name for wind speed.
        wd_col: Column name for wind direction (degrees).
        val_col: Column name for the value to plot.
        ws_max: Maximum wind speed to show. Falls back to the maximum of
            ``ws_col`` when not given (or falsy).
        **kwargs: Arguments passed to BasePlot. ``subplot_kw`` receives
            ``projection="polar"`` unless it already sets a projection;
            the caller's ``subplot_kw`` dict is copied, never modified.
    """
    # Work on a copy so the polar default is never written into the caller's
    # dict; an explicit ``subplot_kw=None`` is treated as empty.
    subplot_kw = dict(kwargs.get("subplot_kw") or {})
    subplot_kw.setdefault("projection", "polar")
    kwargs["subplot_kw"] = subplot_kw

    super().__init__(**kwargs)
    self.df = to_dataframe(data).dropna(subset=[ws_col, wd_col, val_col])
    self.ws_col = ws_col
    self.wd_col = wd_col
    self.val_col = val_col
    self.ws_max = ws_max or self.df[ws_col].max()

plot(n_bins_ws=10, n_bins_wd=36, cmap='viridis', **kwargs)

Generate the bivariate polar plot.

Uses binning to aggregate data before plotting.

Source code in src/monet_plots/plots/polar.py
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
def plot(
    self, n_bins_ws: int = 10, n_bins_wd: int = 36, cmap: str = "viridis", **kwargs
):
    """
    Generate the bivariate polar plot.

    Samples are binned by wind direction and wind speed, and each cell is
    colored by the mean of ``val_col`` over the samples it contains.
    Cells without samples are NaN.

    Args:
        n_bins_ws: Number of equal-width wind-speed bins from 0 to ``ws_max``.
        n_bins_wd: Number of equal-width direction bins spanning 0-360 degrees.
        cmap: Colormap passed to ``pcolormesh``.
        **kwargs: Additional keyword arguments for ``pcolormesh``.

    Returns:
        matplotlib.axes.Axes: The polar axes containing the plot.
    """
    # Directions stay in compass convention (0 = North, clockwise); the axes
    # is reoriented to match instead of transforming the data.
    theta_rad = np.radians(self.df[self.wd_col])
    self.ax.set_theta_zero_location("N")
    self.ax.set_theta_direction(-1)  # Clockwise

    ws_bins = np.linspace(0, self.ws_max, n_bins_ws + 1)
    wd_bins = np.radians(np.linspace(0, 360, n_bins_wd + 1))
    ws = self.df[self.ws_col]

    # Weighted sum and sample count per (direction, speed) cell; samples
    # outside the bin ranges (e.g. ws > ws_max) are not counted.
    sums, _, _ = np.histogram2d(
        theta_rad,
        ws,
        bins=[wd_bins, ws_bins],
        weights=self.df[self.val_col],
    )
    counts, _, _ = np.histogram2d(theta_rad, ws, bins=[wd_bins, ws_bins])

    # Mean per cell; empty cells (0 / 0) become NaN.
    with np.errstate(divide="ignore", invalid="ignore"):
        cell_mean = sums / counts

    # pcolormesh takes cell edges; mesh rows are speeds, columns directions,
    # hence the transpose of the (direction, speed) histogram.
    theta_edges, r_edges = np.meshgrid(wd_bins, ws_bins)
    mappable = self.ax.pcolormesh(
        theta_edges, r_edges, cell_mean.T, cmap=cmap, **kwargs
    )
    self.add_colorbar(mappable, label=self.val_col)

    self.ax.set_ylim(0, self.ws_max)
    return self.ax

BrierScoreDecompositionPlot

Bases: BasePlot

Brier Score Decomposition Plot.

Visualizes the components of the Brier Score (BS): Reliability, Resolution, and Uncertainty, which combine as BS = Reliability - Resolution + Uncertainty.

Source code in src/monet_plots/plots/brier_decomposition.py
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
class BrierScoreDecompositionPlot(BasePlot):
    """
    Brier Score Decomposition Plot.

    Visualizes the components of the Brier Score: Reliability,
    Resolution, and Uncertainty.
    BS = Reliability - Resolution + Uncertainty
    """

    def __init__(self, fig=None, ax=None, **kwargs):
        """Create the plot on ``fig``/``ax``; kwargs are forwarded to BasePlot."""
        super().__init__(fig=fig, ax=ax, **kwargs)

    def plot(
        self,
        data: Any,
        reliability_col: str = "reliability",
        resolution_col: str = "resolution",
        uncertainty_col: str = "uncertainty",
        forecasts_col: Optional[str] = None,
        observations_col: Optional[str] = None,
        n_bins: int = 10,
        label_col: Optional[str] = None,
        **kwargs,
    ):
        """
        Main plotting method.

        Draws one group of three bars per row/model (reliability, negated
        resolution, uncertainty) and, when a ``brier_score`` column is
        present, a black line with the total score.

        Args:
            data: Input data, converted with ``to_dataframe``.
            reliability_col (str): Column with the reliability component.
            resolution_col (str): Column with the resolution component.
            uncertainty_col (str): Column with the uncertainty component.
            forecasts_col (str, optional): Raw forecast probabilities. When
                given together with ``observations_col``, the components are
                computed with ``compute_brier_score_components``.
            observations_col (str, optional): Raw binary observations.
            n_bins (int): Number of bins used when computing components
                from raw data.
            label_col (str, optional): In raw mode, rows are grouped by this
                column (one bar group per value); otherwise its values label
                the bar groups.
            **kwargs: Extra keyword arguments for ``Axes.bar``. ``title`` is
                popped and used as the axes title. Keys such as ``color``,
                ``alpha`` or ``label`` override the per-component defaults.

        Returns:
            matplotlib.axes.Axes: The axes containing the plot.
        """
        title = kwargs.pop("title", "Brier Score Decomposition")
        df = to_dataframe(data)

        if forecasts_col and observations_col:
            # Raw mode: compute the components per group, or once overall.
            # NOTE(review): assumes compute_brier_score_components returns keys
            # matching reliability_col/resolution_col/uncertainty_col; confirm.
            if label_col:
                groups = ((str(name), group) for name, group in df.groupby(label_col))
            else:
                groups = [("Model", df)]

            components_list = []
            for name, group in groups:
                comps = compute_brier_score_components(
                    np.asarray(group[forecasts_col]),
                    np.asarray(group[observations_col]),
                    n_bins,
                )
                row = pd.Series(comps)
                row["model"] = name
                components_list.append(row)

            df_plot = pd.DataFrame(components_list)
            plot_label_col = "model"
        else:
            required_cols = [reliability_col, resolution_col, uncertainty_col]
            validate_dataframe(df, required_columns=required_cols)
            df_plot = df
            plot_label_col = label_col

        # Copy before adding the helper column so the caller's frame is untouched.
        df_plot = df_plot.copy()
        # Resolution lowers the score (BS = Rel - Res + Unc), so draw it below zero.
        df_plot["resolution_plot"] = -df_plot[resolution_col]

        if plot_label_col:
            labels = df_plot[plot_label_col].astype(str)
        else:
            labels = df_plot.index.astype(str)

        x = np.arange(len(labels))
        width = 0.25

        bar_specs = (
            (x - width, reliability_col, "Reliability", "red"),
            (x, "resolution_plot", "Resolution (-)", "green"),
            (x + width, uncertainty_col, "Uncertainty", "blue"),
        )
        for positions, col, bar_label, color in bar_specs:
            # Defaults first so user-supplied keys override them instead of
            # raising a duplicate-keyword TypeError.
            bar_kwargs = {"label": bar_label, "color": color, "alpha": 0.8, **kwargs}
            self.ax.bar(positions, df_plot[col], width, **bar_kwargs)

        # Total Brier Score as line on top if available
        if "brier_score" in df_plot.columns:
            self.ax.plot(
                x,
                df_plot["brier_score"],
                "ko-",
                linewidth=2,
                markersize=6,
                label="Brier Score",
            )

        self.ax.set_xticks(x)
        self.ax.set_xticklabels(labels, rotation=45, ha="right")
        self.ax.legend(loc="best")
        self.ax.set_ylabel("Brier Score Components")
        self.ax.set_title(title)
        self.ax.grid(True, alpha=0.3)
        return self.ax

plot(data, reliability_col='reliability', resolution_col='resolution', uncertainty_col='uncertainty', forecasts_col=None, observations_col=None, n_bins=10, label_col=None, **kwargs)

Main plotting method.

Parameters:

Name Type Description Default
data Any

Input data.

required
reliability_col/resolution_col/uncertainty_col str

Pre-computed component columns.

'reliability' / 'resolution' / 'uncertainty'
forecasts_col/observations_col str

Raw forecast probabilities and binary observations.

None
n_bins int

Bins for decomposition if raw data.

10
label_col str

Grouping column.

None
**kwargs

Matplotlib kwargs.

{}
Source code in src/monet_plots/plots/brier_decomposition.py
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
def plot(
    self,
    data: Any,
    reliability_col: str = "reliability",
    resolution_col: str = "resolution",
    uncertainty_col: str = "uncertainty",
    forecasts_col: Optional[str] = None,
    observations_col: Optional[str] = None,
    n_bins: int = 10,
    label_col: Optional[str] = None,
    **kwargs,
):
    """
    Main plotting method.

    Draws one group of three bars per row/model (reliability, negated
    resolution, uncertainty) and, when a ``brier_score`` column is present,
    a black line with the total score.

    Args:
        data: Input data, converted with ``to_dataframe``.
        reliability_col (str): Column with the reliability component.
        resolution_col (str): Column with the resolution component.
        uncertainty_col (str): Column with the uncertainty component.
        forecasts_col (str, optional): Raw forecast probabilities. When given
            together with ``observations_col``, the components are computed
            with ``compute_brier_score_components``.
        observations_col (str, optional): Raw binary observations.
        n_bins (int): Number of bins used when computing components from
            raw data.
        label_col (str, optional): In raw mode, rows are grouped by this
            column (one bar group per value); otherwise its values label the
            bar groups.
        **kwargs: Extra keyword arguments for ``Axes.bar``. ``title`` is
            popped and used as the axes title. Keys such as ``color``,
            ``alpha`` or ``label`` override the per-component defaults.

    Returns:
        matplotlib.axes.Axes: The axes containing the plot.
    """
    title = kwargs.pop("title", "Brier Score Decomposition")
    df = to_dataframe(data)

    if forecasts_col and observations_col:
        # Raw mode: compute the components per group, or once overall.
        # NOTE(review): assumes compute_brier_score_components returns keys
        # matching reliability_col/resolution_col/uncertainty_col; confirm.
        if label_col:
            groups = ((str(name), group) for name, group in df.groupby(label_col))
        else:
            groups = [("Model", df)]

        components_list = []
        for name, group in groups:
            comps = compute_brier_score_components(
                np.asarray(group[forecasts_col]),
                np.asarray(group[observations_col]),
                n_bins,
            )
            row = pd.Series(comps)
            row["model"] = name
            components_list.append(row)

        df_plot = pd.DataFrame(components_list)
        plot_label_col = "model"
    else:
        required_cols = [reliability_col, resolution_col, uncertainty_col]
        validate_dataframe(df, required_columns=required_cols)
        df_plot = df
        plot_label_col = label_col

    # Copy before adding the helper column so the caller's frame is untouched.
    df_plot = df_plot.copy()
    # Resolution lowers the score (BS = Rel - Res + Unc), so draw it below zero.
    df_plot["resolution_plot"] = -df_plot[resolution_col]

    if plot_label_col:
        labels = df_plot[plot_label_col].astype(str)
    else:
        labels = df_plot.index.astype(str)

    x = np.arange(len(labels))
    width = 0.25

    bar_specs = (
        (x - width, reliability_col, "Reliability", "red"),
        (x, "resolution_plot", "Resolution (-)", "green"),
        (x + width, uncertainty_col, "Uncertainty", "blue"),
    )
    for positions, col, bar_label, color in bar_specs:
        # Defaults first so user-supplied keys override them instead of
        # raising a duplicate-keyword TypeError.
        bar_kwargs = {"label": bar_label, "color": color, "alpha": 0.8, **kwargs}
        self.ax.bar(positions, df_plot[col], width, **bar_kwargs)

    # Total Brier Score as line on top if available
    if "brier_score" in df_plot.columns:
        self.ax.plot(
            x,
            df_plot["brier_score"],
            "ko-",
            linewidth=2,
            markersize=6,
            label="Brier Score",
        )

    self.ax.set_xticks(x)
    self.ax.set_xticklabels(labels, rotation=45, ha="right")
    self.ax.legend(loc="best")
    self.ax.set_ylabel("Brier Score Components")
    self.ax.set_title(title)
    self.ax.grid(True, alpha=0.3)
    return self.ax

ConditionalBiasPlot

Bases: BasePlot

Conditional Bias Plot.

Visualizes the Bias (Forecast - Observation) as a function of the Observed Value. Supports native Xarray/Dask objects and interactive visualization.

Source code in src/monet_plots/plots/conditional_bias.py
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
class ConditionalBiasPlot(BasePlot):
    """
    Conditional Bias Plot.

    Visualizes the Bias (Forecast - Observation) as a function of the Observed Value.
    Supports native Xarray/Dask objects and interactive visualization.
    """

    def __init__(self, data: Optional[Any] = None, fig=None, ax=None, **kwargs):
        """
        Initializes the plot.

        Parameters
        ----------
        data : Any, optional
            The input data (Dataset, DataArray, DataFrame, or ndarray). When
            given it is passed through ``normalize_data`` and stored.
        fig : matplotlib.figure.Figure, optional
            Figure to plot on.
        ax : matplotlib.axes.Axes, optional
            Axes to plot on.
        **kwargs : Any
            Additional keyword arguments for the figure.
        """
        super().__init__(fig=fig, ax=ax, **kwargs)
        self.data = normalize_data(data) if data is not None else None

    def plot(
        self,
        data: Optional[Any] = None,
        obs_col: Optional[str] = None,
        fcst_col: Optional[str] = None,
        n_bins: int = 10,
        label: str = "Model",
        label_col: Optional[str] = None,
        **kwargs,
    ):
        """
        Generates the static Matplotlib plot.

        Parameters
        ----------
        data : Any, optional
            Input data, overrides self.data if provided.
        obs_col : str, optional
            Name of the observation variable. Required for Dataset/DataFrame.
        fcst_col : str, optional
            Name of the forecast variable. Required for Dataset/DataFrame.
        n_bins : int, optional
            Number of bins for observed values, by default 10.
        label : str, optional
            Label for the model data, by default "Model".
        label_col : str, optional
            Column name to group by for plotting multiple lines.
        **kwargs : Any
            Additional Matplotlib plotting arguments passed to `errorbar`.
            For DataArray input, the observations must be passed as ``obs=``.

        Returns
        -------
        matplotlib.axes.Axes
            The axes object containing the plot.

        Raises
        ------
        ValueError
            If no data is available, a required column is missing, or ``obs``
            is missing for DataArray input.
        TypeError
            If the normalized data is of an unsupported type.
        """
        plot_data = normalize_data(data) if data is not None else self.data
        if plot_data is None:
            raise ValueError("No data provided.")

        try:
            if label_col:
                # Handle grouping for multiple models/categories
                for name, group in plot_data.groupby(label_col):
                    obs = group[obs_col]
                    mod = group[fcst_col]
                    self._plot_single(obs, mod, n_bins, label=str(name), **kwargs)
            else:
                # Single model plot
                if isinstance(plot_data, xr.Dataset):
                    obs = plot_data[obs_col]
                    mod = plot_data[fcst_col]
                elif isinstance(plot_data, xr.DataArray):
                    mod = plot_data
                    obs = kwargs.pop("obs", None)
                    if obs is None:
                        raise ValueError("obs must be provided if data is a DataArray.")
                elif isinstance(plot_data, pd.DataFrame):
                    obs = plot_data[obs_col]
                    mod = plot_data[fcst_col]
                else:
                    # Should have been normalized
                    raise TypeError(f"Unsupported data type: {type(plot_data)}")

                self._plot_single(obs, mod, n_bins, label=label, **kwargs)
        except KeyError as e:
            raise ValueError(f"Required column not found: {e}") from e

        self.ax.axhline(0, color="k", linestyle="--", linewidth=1.5, alpha=0.7)
        xlabel = (
            plot_data[obs_col].attrs.get("long_name", obs_col)
            if obs_col
            else "Observed Value"
        )
        self.ax.set_xlabel(xlabel)
        self.ax.set_ylabel("Mean Bias (Forecast - Observation)")
        self.ax.legend()
        return self.ax

    def _plot_single(self, obs, mod, n_bins, label, **kwargs):
        """Helper to plot a single binned bias line with std-dev error bars."""
        stats = compute_binned_bias(obs, mod, n_bins=n_bins)
        pdf = stats.compute().dropna(dim="bin_center")

        # Filter for count > 1 to avoid showing bins with only one sample (no std dev)
        pdf = pdf.where(pdf.bias_count > 1, drop=True)

        if pdf.bin_center.size > 0:
            self.ax.errorbar(
                pdf.bin_center,
                pdf.bias_mean,
                yerr=pdf.bias_std,
                fmt="o-",
                capsize=5,
                label=label,
                **kwargs,
            )

    def hvplot(
        self,
        data: Optional[Any] = None,
        obs_col: Optional[str] = None,
        fcst_col: Optional[str] = None,
        n_bins: int = 10,
        label_col: Optional[str] = None,
        **kwargs: Any,
    ) -> Any:
        """
        Generates an interactive plot using hvPlot.

        Parameters
        ----------
        data : Any, optional
            Input data, overrides self.data if provided.
        obs_col : str, optional
            Name of the observation variable.
        fcst_col : str, optional
            Name of the forecast variable.
        n_bins : int, optional
            Number of bins, by default 10.
        label_col : str, optional
            Column name to group by.
        **kwargs : Any
            Additional hvPlot arguments. For DataArray input, the observations
            must be passed as ``obs=``.

        Returns
        -------
        holoviews.core.Element
            The interactive plot.

        Raises
        ------
        ImportError
            If hvplot/holoviews are not installed.
        ValueError
            If no data is available, or ``obs`` is missing for DataArray input.
        """
        try:
            import holoviews as hv
            import hvplot.xarray  # noqa: F401
        except ImportError as err:
            raise ImportError(
                "hvplot and holoviews are required for interactive plotting. Install them with 'pip install hvplot holoviews'."
            ) from err

        plot_data = normalize_data(data) if data is not None else self.data
        if plot_data is None:
            raise ValueError("No data provided.")

        if label_col:

            def get_stats(group):
                return compute_binned_bias(
                    group[obs_col], group[fcst_col], n_bins=n_bins
                ).compute()

            # We compute per group for the visualization summary
            stats_list = []
            for name, group in plot_data.groupby(label_col):
                s = get_stats(group)
                s = s.assign_coords({label_col: name}).expand_dims(label_col)
                stats_list.append(s)
            pdf = xr.concat(stats_list, dim=label_col).dropna(dim="bin_center")
            by = label_col
        else:
            if isinstance(plot_data, (xr.Dataset, pd.DataFrame)):
                obs = plot_data[obs_col]
                mod = plot_data[fcst_col]
            else:
                mod = plot_data
                # Same contract and error as ``plot`` for DataArray input.
                obs = kwargs.pop("obs", None)
                if obs is None:
                    raise ValueError("obs must be provided if data is a DataArray.")
            pdf = compute_binned_bias(obs, mod, n_bins=n_bins).compute()
            pdf = pdf.dropna(dim="bin_center")
            by = None

        xlabel = (
            plot_data[obs_col].attrs.get("long_name", obs_col)
            if obs_col
            else "Observed Value"
        )

        plot = pdf.hvplot.scatter(
            x="bin_center",
            y="bias_mean",
            by=by,
            xlabel=xlabel,
            ylabel="Mean Bias",
            **kwargs,
        ) * pdf.hvplot.errorbars(x="bin_center", y="bias_mean", yerr1="bias_std", by=by)

        # Add zero line
        plot *= hv.HLine(0).opts(color="black", line_dash="dashed")

        return plot

__init__(data=None, fig=None, ax=None, **kwargs)

Initializes the plot.

Parameters

data : Any, optional The input data (Dataset, DataArray, DataFrame, or ndarray). fig : matplotlib.figure.Figure, optional Figure to plot on. ax : matplotlib.axes.Axes, optional Axes to plot on. **kwargs : Any Additional keyword arguments for the figure.

Source code in src/monet_plots/plots/conditional_bias.py
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
def __init__(self, data: Optional[Any] = None, fig=None, ax=None, **kwargs):
    """
    Set up the plot and store its (optional) input data.

    Parameters
    ----------
    data : Any, optional
        The input data (Dataset, DataArray, DataFrame, or ndarray). When
        given, it is passed through ``normalize_data`` before being stored
        on ``self.data``; when omitted, ``self.data`` is None.
    fig : matplotlib.figure.Figure, optional
        Figure to plot on.
    ax : matplotlib.axes.Axes, optional
        Axes to plot on.
    **kwargs : Any
        Additional keyword arguments forwarded to the base class along with
        ``fig`` and ``ax``.
    """
    super().__init__(fig=fig, ax=ax, **kwargs)
    if data is None:
        self.data = None
    else:
        self.data = normalize_data(data)

hvplot(data=None, obs_col=None, fcst_col=None, n_bins=10, label_col=None, **kwargs)

Generates an interactive plot using hvPlot.

Parameters

- `data` : Any, optional — Input data; overrides self.data if provided.
- `obs_col` : str, optional — Name of the observation variable.
- `fcst_col` : str, optional — Name of the forecast variable.
- `n_bins` : int, optional — Number of bins, by default 10.
- `label_col` : str, optional — Column name to group by.
- `**kwargs` : Any — Additional hvPlot arguments.

Returns

holoviews.core.Element The interactive plot.

Source code in src/monet_plots/plots/conditional_bias.py (lines 136–227)
def hvplot(
    self,
    data: Optional[Any] = None,
    obs_col: Optional[str] = None,
    fcst_col: Optional[str] = None,
    n_bins: int = 10,
    label_col: Optional[str] = None,
    **kwargs: Any,
) -> Any:
    """
    Generates an interactive plot using hvPlot.

    Parameters
    ----------
    data : Any, optional
        Input data, overrides self.data if provided.
    obs_col : str, optional
        Name of the observation variable. Required for Dataset/DataFrame.
    fcst_col : str, optional
        Name of the forecast variable. Required for Dataset/DataFrame.
    n_bins : int, optional
        Number of bins, by default 10.
    label_col : str, optional
        Column name to group by.
    **kwargs : Any
        Additional hvPlot arguments passed to the scatter layer. When the
        data is neither a Dataset nor a DataFrame (e.g. a DataArray of
        forecasts), the observations must be supplied as ``obs=...``; that
        key is consumed and not forwarded to hvPlot.

    Returns
    -------
    holoviews.core.Element
        The interactive plot.

    Raises
    ------
    ImportError
        If hvplot or holoviews is not installed.
    ValueError
        If no data is available, or ``obs`` is missing for non-tabular input.
    """
    try:
        import holoviews as hv
        import hvplot.xarray  # noqa: F401
    except ImportError as err:
        raise ImportError(
            "hvplot and holoviews are required for interactive plotting. Install them with 'pip install hvplot holoviews'."
        ) from err

    plot_data = normalize_data(data) if data is not None else self.data
    if plot_data is None:
        raise ValueError("No data provided.")

    if label_col:
        # Compute the binned-bias summary separately for each group and stack
        # the per-group results along a new ``label_col`` dimension.
        stats_list = []
        for name, group in plot_data.groupby(label_col):
            stats = compute_binned_bias(
                group[obs_col], group[fcst_col], n_bins=n_bins
            ).compute()
            stats_list.append(
                stats.assign_coords({label_col: name}).expand_dims(label_col)
            )
        pdf = xr.concat(stats_list, dim=label_col).dropna(dim="bin_center")
        by = label_col
    else:
        if isinstance(plot_data, (xr.Dataset, pd.DataFrame)):
            obs = plot_data[obs_col]
            mod = plot_data[fcst_col]
        else:
            # Non-tabular input (e.g. a DataArray of forecasts): observations
            # come from kwargs. Fail with the same message as ``plot`` instead
            # of a bare KeyError('obs').
            mod = plot_data
            obs = kwargs.pop("obs", None)
            if obs is None:
                raise ValueError("obs must be provided if data is a DataArray.")
        pdf = compute_binned_bias(obs, mod, n_bins=n_bins).compute()
        pdf = pdf.dropna(dim="bin_center")
        by = None

    xlabel = (
        plot_data[obs_col].attrs.get("long_name", obs_col)
        if obs_col
        else "Observed Value"
    )

    plot = pdf.hvplot.scatter(
        x="bin_center",
        y="bias_mean",
        by=by,
        xlabel=xlabel,
        ylabel="Mean Bias",
        **kwargs,
    ) * pdf.hvplot.errorbars(x="bin_center", y="bias_mean", yerr1="bias_std", by=by)

    # Add zero line
    plot *= hv.HLine(0).opts(color="black", line_dash="dashed")

    return plot

plot(data=None, obs_col=None, fcst_col=None, n_bins=10, label='Model', label_col=None, **kwargs)

Generates the static Matplotlib plot.

Parameters

- `data` : Any, optional — Input data; overrides self.data if provided.
- `obs_col` : str, optional — Name of the observation variable. Required for Dataset/DataFrame.
- `fcst_col` : str, optional — Name of the forecast variable. Required for Dataset/DataFrame.
- `n_bins` : int, optional — Number of bins for observed values, by default 10.
- `label` : str, optional — Label for the model data, by default "Model".
- `label_col` : str, optional — Column name to group by for plotting multiple lines.
- `**kwargs` : Any — Additional Matplotlib plotting arguments passed to `errorbar`.

Returns

matplotlib.axes.Axes The axes object containing the plot.

Source code in src/monet_plots/plots/conditional_bias.py (lines 39–115)
def plot(
    self,
    data: Optional[Any] = None,
    obs_col: Optional[str] = None,
    fcst_col: Optional[str] = None,
    n_bins: int = 10,
    label: str = "Model",
    label_col: Optional[str] = None,
    **kwargs,
):
    """
    Generates the static Matplotlib plot.

    Parameters
    ----------
    data : Any, optional
        Input data, overrides self.data if provided.
    obs_col : str, optional
        Name of the observation variable. Required for Dataset/DataFrame.
    fcst_col : str, optional
        Name of the forecast variable. Required for Dataset/DataFrame.
    n_bins : int, optional
        Number of bins for observed values, by default 10.
    label : str, optional
        Label for the model data, by default "Model".
    label_col : str, optional
        Column name to group by for plotting multiple lines.
    **kwargs : Any
        Additional Matplotlib plotting arguments passed to `errorbar`.
        For DataArray input the observations must be given as ``obs=...``;
        that key is consumed and not forwarded.

    Returns
    -------
    matplotlib.axes.Axes
        The axes object containing the plot.

    Raises
    ------
    ValueError
        If no data is available, a required column/variable is missing, or
        ``obs`` is not supplied for DataArray input.
    TypeError
        If the (normalized) data type is not supported.
    """
    plot_data = normalize_data(data) if data is not None else self.data
    if plot_data is None:
        raise ValueError("No data provided.")

    # Resolve every (label, obs, mod) triple before drawing. Only these
    # lookups are translated into "Required column not found"; a KeyError
    # raised later while plotting is unrelated and propagates unchanged.
    try:
        if label_col:
            # Handle grouping for multiple models/categories
            series = [
                (str(name), group[obs_col], group[fcst_col])
                for name, group in plot_data.groupby(label_col)
            ]
        else:
            # Single model plot
            if isinstance(plot_data, xr.Dataset):
                obs = plot_data[obs_col]
                mod = plot_data[fcst_col]
            elif isinstance(plot_data, xr.DataArray):
                mod = plot_data
                obs = kwargs.pop("obs", None)
                if obs is None:
                    raise ValueError("obs must be provided if data is a DataArray.")
            elif isinstance(plot_data, pd.DataFrame):
                obs = plot_data[obs_col]
                mod = plot_data[fcst_col]
            else:
                # Should have been normalized
                raise TypeError(f"Unsupported data type: {type(plot_data)}")
            series = [(label, obs, mod)]
    except KeyError as e:
        raise ValueError(f"Required column not found: {e}") from e

    for series_label, obs, mod in series:
        self._plot_single(obs, mod, n_bins, label=series_label, **kwargs)

    self.ax.axhline(0, color="k", linestyle="--", linewidth=1.5, alpha=0.7)
    xlabel = (
        plot_data[obs_col].attrs.get("long_name", obs_col)
        if obs_col
        else "Observed Value"
    )
    self.ax.set_xlabel(xlabel)
    self.ax.set_ylabel("Mean Bias (Forecast - Observation)")
    self.ax.legend()
    return self.ax

ConditionalQuantilePlot

Bases: BasePlot

Conditional quantile plot.

Plots the distribution (quantiles) of modeled values as a function of binned observed values. This helps identify if the model's uncertainty or bias changes across the range of observations.

Source code in src/monet_plots/plots/conditional_quantile.py (lines 14–113)
class ConditionalQuantilePlot(BasePlot):
    """Conditional quantile plot.

    Plots the distribution (quantiles) of modeled values as a function
    of binned observed values. This helps identify if the model's
    uncertainty or bias changes across the range of observations.
    """

    def __init__(
        self,
        data: Any,
        obs_col: str,
        mod_col: str,
        *,
        bins: Union[int, List[float]] = 10,
        quantiles: List[float] = [0.25, 0.5, 0.75],
        **kwargs,
    ):
        """
        Initialize Conditional Quantile Plot.

        Args:
            data: Input data (DataFrame, DataArray, etc.).
            obs_col: Column name for observations.
            mod_col: Column name for model values.
            bins: Number of bins or bin edges for observations.
            quantiles: List of quantiles to calculate (0 to 1).
            **kwargs: Arguments passed to BasePlot.
        """
        super().__init__(**kwargs)
        # Rows missing either value cannot contribute to a bin's statistics.
        self.df = to_dataframe(data).dropna(subset=[obs_col, mod_col])
        self.obs_col = obs_col
        self.mod_col = mod_col
        self.bins = bins
        # sorted() returns a new list, so the default list is never mutated.
        self.quantiles = sorted(quantiles)

    @staticmethod
    def _percentile_label(q: float) -> str:
        """Format quantile ``q`` (0-1) as an ordinal legend label, e.g. "25th percentile"."""
        # Round away float noise first: int(0.29 * 100) would give 28.
        pct = round(q * 100, 6)
        if pct != int(pct):
            return f"{pct:g}th percentile"
        n = int(pct)
        if n % 100 in (11, 12, 13):
            suffix = "th"
        else:
            suffix = {1: "st", 2: "nd", 3: "rd"}.get(n % 10, "th")
        return f"{n}{suffix} percentile"

    def plot(self, show_points: bool = False, **kwargs):
        """Generate the conditional quantile plot.

        Args:
            show_points: If True, also scatter the raw (observed, modeled) pairs.
            **kwargs: Currently unused; accepted for API compatibility.

        Returns:
            The matplotlib axes containing the plot.
        """
        obs = self.df[self.obs_col]
        mod = self.df[self.mod_col]

        # Bin observations into a standalone Series instead of writing a
        # "bin" column into self.df, so user data is never overwritten and
        # repeated calls behave identically.
        obs_bins = pd.cut(obs, bins=self.bins)

        # Quantiles of the modeled values per bin; each bin is placed at its
        # midpoint on the x-axis. observed=True skips empty bins.
        bin_midpoints = []
        quantile_vals = {q: [] for q in self.quantiles}
        for interval, mod_in_bin in mod.groupby(obs_bins, observed=True):
            bin_midpoints.append(interval.mid)
            for q in self.quantiles:
                quantile_vals[q].append(mod_in_bin.quantile(q))

        if show_points:
            self.ax.scatter(
                obs,
                mod,
                alpha=0.3,
                s=10,
                color="grey",
                label="Data",
            )

        # 1:1 reference line spanning the combined range of both variables.
        lims = [min(obs.min(), mod.min()), max(obs.max(), mod.max())]
        self.ax.plot(lims, lims, "k--", alpha=0.5, label="1:1")

        # One line per quantile; the median is drawn solid and thicker.
        colors = plt.cm.Blues(np.linspace(0.4, 0.8, len(self.quantiles)))
        for color, q in zip(colors, self.quantiles):
            is_median = q == 0.5
            self.ax.plot(
                bin_midpoints,
                quantile_vals[q],
                label=self._percentile_label(q),
                color=color,
                linestyle="-" if is_median else "--",
                linewidth=2 if is_median else 1,
            )

        # Shade the interquartile range, only when both the 25th and 75th
        # quantiles were requested.
        if 0.25 in self.quantiles and 0.75 in self.quantiles:
            self.ax.fill_between(
                bin_midpoints,
                quantile_vals[0.25],
                quantile_vals[0.75],
                color="blue",
                alpha=0.1,
            )

        self.ax.set_xlabel(f"Observed: {self.obs_col}")
        self.ax.set_ylabel(f"Modeled: {self.mod_col}")
        self.ax.legend()
        self.ax.grid(True, linestyle=":", alpha=0.6)

        return self.ax

__init__(data, obs_col, mod_col, *, bins=10, quantiles=[0.25, 0.5, 0.75], **kwargs)

Initialize Conditional Quantile Plot.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| data | Any | Input data (DataFrame, DataArray, etc.). | required |
| obs_col | str | Column name for observations. | required |
| mod_col | str | Column name for model values. | required |
| bins | Union[int, List[float]] | Number of bins or bin edges for observations. | 10 |
| quantiles | List[float] | List of quantiles to calculate (0 to 1). | [0.25, 0.5, 0.75] |
| `**kwargs` | | Arguments passed to BasePlot. | {} |
Source code in src/monet_plots/plots/conditional_quantile.py (lines 22–48)
def __init__(
    self,
    data: Any,
    obs_col: str,
    mod_col: str,
    *,
    bins: Union[int, List[float]] = 10,
    quantiles: List[float] = [0.25, 0.5, 0.75],
    **kwargs,
):
    """
    Set up a conditional quantile plot.

    Args:
        data: Input data (DataFrame, DataArray, etc.), converted with
            ``to_dataframe``; rows missing ``obs_col`` or ``mod_col`` are dropped.
        obs_col: Column name for observations.
        mod_col: Column name for model values.
        bins: Number of bins or bin edges for observations.
        quantiles: List of quantiles to calculate (0 to 1); stored as a new
            sorted list.
        **kwargs: Arguments passed to BasePlot.
    """
    super().__init__(**kwargs)
    frame = to_dataframe(data)
    self.df = frame.dropna(subset=[obs_col, mod_col])
    self.obs_col, self.mod_col = obs_col, mod_col
    self.bins = bins
    self.quantiles = sorted(quantiles)

plot(show_points=False, **kwargs)

Generate the conditional quantile plot.

Source code in src/monet_plots/plots/conditional_quantile.py (lines 50–113)
def plot(self, show_points: bool = False, **kwargs):
    """Generate the conditional quantile plot.

    Args:
        show_points: If True, also scatter the raw (observed, modeled) pairs.
        **kwargs: Currently unused; accepted for API compatibility.

    Returns:
        The matplotlib axes containing the plot.
    """

    def _percentile_label(q):
        """Format quantile ``q`` (0-1) as an ordinal legend label, e.g. "25th percentile"."""
        # Round away float noise first: int(0.29 * 100) would give 28.
        pct = round(q * 100, 6)
        if pct != int(pct):
            return f"{pct:g}th percentile"
        n = int(pct)
        if n % 100 in (11, 12, 13):
            suffix = "th"
        else:
            suffix = {1: "st", 2: "nd", 3: "rd"}.get(n % 10, "th")
        return f"{n}{suffix} percentile"

    obs = self.df[self.obs_col]
    mod = self.df[self.mod_col]

    # Bin observations into a standalone Series instead of writing a "bin"
    # column into self.df, so user data is never overwritten and repeated
    # calls behave identically.
    obs_bins = pd.cut(obs, bins=self.bins)

    # Quantiles of the modeled values per bin; each bin is placed at its
    # midpoint on the x-axis. observed=True skips empty bins.
    bin_midpoints = []
    quantile_vals = {q: [] for q in self.quantiles}
    for interval, mod_in_bin in mod.groupby(obs_bins, observed=True):
        bin_midpoints.append(interval.mid)
        for q in self.quantiles:
            quantile_vals[q].append(mod_in_bin.quantile(q))

    if show_points:
        self.ax.scatter(
            obs,
            mod,
            alpha=0.3,
            s=10,
            color="grey",
            label="Data",
        )

    # 1:1 reference line spanning the combined range of both variables.
    lims = [min(obs.min(), mod.min()), max(obs.max(), mod.max())]
    self.ax.plot(lims, lims, "k--", alpha=0.5, label="1:1")

    # One line per quantile; the median is drawn solid and thicker.
    colors = plt.cm.Blues(np.linspace(0.4, 0.8, len(self.quantiles)))
    for color, q in zip(colors, self.quantiles):
        is_median = q == 0.5
        self.ax.plot(
            bin_midpoints,
            quantile_vals[q],
            label=_percentile_label(q),
            color=color,
            linestyle="-" if is_median else "--",
            linewidth=2 if is_median else 1,
        )

    # Shade the interquartile range, only when both the 25th and 75th
    # quantiles were requested.
    if 0.25 in self.quantiles and 0.75 in self.quantiles:
        self.ax.fill_between(
            bin_midpoints,
            quantile_vals[0.25],
            quantile_vals[0.75],
            color="blue",
            alpha=0.1,
        )

    self.ax.set_xlabel(f"Observed: {self.obs_col}")
    self.ax.set_ylabel(f"Modeled: {self.mod_col}")
    self.ax.legend()
    self.ax.grid(True, linestyle=":", alpha=0.6)

    return self.ax

CurtainPlot

Bases: BasePlot

Vertical curtain plot for cross-sectional data.

This plot shows a 2D variable (e.g., concentration) as a function of one horizontal dimension (time or distance) and one vertical dimension (altitude or pressure).

Source code in src/monet_plots/plots/curtain.py (lines 12–89)
class CurtainPlot(BasePlot):
    """Vertical curtain plot for cross-sectional data.

    This plot shows a 2D variable (e.g., concentration) as a function of
    one horizontal dimension (time or distance) and one vertical dimension
    (altitude or pressure).
    """

    def __init__(
        self,
        data: Any,
        *,
        x: Optional[str] = None,
        y: Optional[str] = None,
        **kwargs,
    ):
        """
        Initialize Curtain Plot.

        Args:
            data: Input data. Should be a 2D xarray.DataArray or similar.
            x: Name of the x-axis dimension/coordinate (e.g., 'time').
            y: Name of the y-axis dimension/coordinate (e.g., 'level').
            **kwargs: Arguments passed to BasePlot.
        """
        super().__init__(**kwargs)
        self.data = data
        self.x = x
        self.y = y

    def plot(self, kind: str = "pcolormesh", colorbar: bool = True, **kwargs):
        """
        Generate the curtain plot.

        If ``x``/``y`` were not given at construction, they default to the
        second and first dimension of the data, and the resolved names are
        stored on the instance.

        Args:
            kind: Type of plot ('pcolormesh' or 'contourf').
            colorbar: Whether to add a colorbar.
            **kwargs: Additional arguments for the plotting function
                (processed through ``get_plot_kwargs``).

        Returns:
            The matplotlib axes containing the plot.

        Raises:
            TypeError: If the data is not xarray-like.
            ValueError: If the data is not 2D or ``kind`` is unsupported.
        """
        plot_kwargs = get_plot_kwargs(**kwargs)

        # Ensure we have a DataArray
        if isinstance(self.data, xr.DataArray):
            da = self.data
        elif hasattr(self.data, "to_array"):
            # Dataset-like input: variables are stacked along a new dimension.
            da = self.data.to_array()
            # A single 2D variable comes back as (variable=1, a, b); drop the
            # length-1 "variable" dim so it is plotted instead of rejected as 3D.
            if da.ndim == 3 and "variable" in da.dims and da.sizes["variable"] == 1:
                da = da.squeeze("variable", drop=True)
        else:
            raise TypeError(
                "CurtainPlot requires xarray-like data with 2 dimensions."
            )

        if da.ndim != 2:
            raise ValueError(f"CurtainPlot requires 2D data, got {da.ndim}D.")

        # Determine x and y if not provided
        if self.x is None:
            self.x = da.dims[1]
        if self.y is None:
            self.y = da.dims[0]

        # pcolormesh/contourf expect C shaped (len(y), len(x)). When both axes
        # are dimensions, reorder so explicit x/y choices line up with C
        # (a no-op for the default dims[1]/dims[0] choice).
        if self.x != self.y and self.x in da.dims and self.y in da.dims:
            da = da.transpose(self.y, self.x)

        if kind == "pcolormesh":
            mappable = self.ax.pcolormesh(
                da[self.x], da[self.y], da, shading="auto", **plot_kwargs
            )
        elif kind == "contourf":
            mappable = self.ax.contourf(da[self.x], da[self.y], da, **plot_kwargs)
        else:
            raise ValueError("kind must be 'pcolormesh' or 'contourf'")

        if colorbar:
            self.add_colorbar(mappable)

        self.ax.set_xlabel(self.x)
        self.ax.set_ylabel(self.y)

        return self.ax

__init__(data, *, x=None, y=None, **kwargs)

Initialize Curtain Plot.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| data | Any | Input data. Should be a 2D xarray.DataArray or similar. | required |
| x | Optional[str] | Name of the x-axis dimension/coordinate (e.g., 'time'). | None |
| y | Optional[str] | Name of the y-axis dimension/coordinate (e.g., 'level'). | None |
| `**kwargs` | | Arguments passed to BasePlot. | {} |
Source code in src/monet_plots/plots/curtain.py (lines 20–40)
def __init__(
    self,
    data: Any,
    *,
    x: Optional[str] = None,
    y: Optional[str] = None,
    **kwargs,
):
    """
    Set up a curtain plot for 2D cross-sectional data.

    Args:
        data: Input data. Should be a 2D xarray.DataArray or similar; it is
            stored as-is and validated only when ``plot`` runs.
        x: Name of the x-axis dimension/coordinate (e.g., 'time'); None lets
            ``plot`` pick a default.
        y: Name of the y-axis dimension/coordinate (e.g., 'level'); None lets
            ``plot`` pick a default.
        **kwargs: Arguments passed to BasePlot.
    """
    super().__init__(**kwargs)
    self.data = data
    self.x, self.y = x, y

plot(kind='pcolormesh', colorbar=True, **kwargs)

Generate the curtain plot.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| kind | str | Type of plot ('pcolormesh' or 'contourf'). | 'pcolormesh' |
| colorbar | bool | Whether to add a colorbar. | True |
| `**kwargs` | | Additional arguments for the plotting function. | {} |
Source code in src/monet_plots/plots/curtain.py (lines 42–89)
def plot(self, kind: str = "pcolormesh", colorbar: bool = True, **kwargs):
    """
    Generate the curtain plot.

    If ``x``/``y`` were not given at construction, they default to the
    second and first dimension of the data, and the resolved names are
    stored on the instance.

    Args:
        kind: Type of plot ('pcolormesh' or 'contourf').
        colorbar: Whether to add a colorbar.
        **kwargs: Additional arguments for the plotting function
            (processed through ``get_plot_kwargs``).

    Returns:
        The matplotlib axes containing the plot.

    Raises:
        TypeError: If the data is not xarray-like.
        ValueError: If the data is not 2D or ``kind`` is unsupported.
    """
    plot_kwargs = get_plot_kwargs(**kwargs)

    # Ensure we have a DataArray
    if isinstance(self.data, xr.DataArray):
        da = self.data
    elif hasattr(self.data, "to_array"):
        # Dataset-like input: variables are stacked along a new dimension.
        da = self.data.to_array()
        # A single 2D variable comes back as (variable=1, a, b); drop the
        # length-1 "variable" dim so it is plotted instead of rejected as 3D.
        if da.ndim == 3 and "variable" in da.dims and da.sizes["variable"] == 1:
            da = da.squeeze("variable", drop=True)
    else:
        raise TypeError(
            "CurtainPlot requires xarray-like data with 2 dimensions."
        )

    if da.ndim != 2:
        raise ValueError(f"CurtainPlot requires 2D data, got {da.ndim}D.")

    # Determine x and y if not provided
    if self.x is None:
        self.x = da.dims[1]
    if self.y is None:
        self.y = da.dims[0]

    # pcolormesh/contourf expect C shaped (len(y), len(x)). When both axes are
    # dimensions, reorder so explicit x/y choices line up with C (a no-op for
    # the default dims[1]/dims[0] choice).
    if self.x != self.y and self.x in da.dims and self.y in da.dims:
        da = da.transpose(self.y, self.x)

    if kind == "pcolormesh":
        mappable = self.ax.pcolormesh(
            da[self.x], da[self.y], da, shading="auto", **plot_kwargs
        )
    elif kind == "contourf":
        mappable = self.ax.contourf(da[self.x], da[self.y], da, **plot_kwargs)
    else:
        raise ValueError("kind must be 'pcolormesh' or 'contourf'")

    if colorbar:
        self.add_colorbar(mappable)

    self.ax.set_xlabel(self.x)
    self.ax.set_ylabel(self.y)

    return self.ax

DiurnalErrorPlot

Bases: BasePlot

Diurnal error heat map.

Visualizes model error (bias) as a function of the hour of day and another temporal dimension (e.g., month, day of week, or date).

This class supports native Xarray and Dask objects for lazy evaluation and provenance tracking.

Attributes

- `data` : Union[xr.Dataset, xr.DataArray, pd.DataFrame] — The input data for the plot.
- `obs_col` : str — Column/variable name for observations.
- `mod_col` : str — Column/variable name for model values.
- `time_col` : str — Dimension/column name for timestamp.
- `second_dim` : str — The second dimension for the heatmap.
- `metric` : str — The metric to plot ('bias' or 'error').
- `aggregated` : xr.DataArray — The calculated aggregated data for the heatmap.
- `second_label` : str — The label for the second dimension on the y-axis.

Examples

```python
>>> import pandas as pd
>>> import numpy as np
>>> from monet_plots.plots import DiurnalErrorPlot
>>> dates = pd.date_range("2023-01-01", periods=100, freq="h")
>>> df = pd.DataFrame({
...     "time": dates,
...     "obs": np.random.rand(100),
...     "mod": np.random.rand(100)
... })
>>> plot = DiurnalErrorPlot(df, obs_col="obs", mod_col="mod")
>>> ax = plot.plot()
```

Source code in src/monet_plots/plots/diurnal_error.py (lines 22–339)
class DiurnalErrorPlot(BasePlot):
    """Diurnal error heat map.

    Visualizes model error (bias) as a function of the hour of day and another
    temporal dimension (e.g., month, day of week, or date).

    This class supports native Xarray and Dask objects for lazy evaluation
    and provenance tracking.

    Attributes
    ----------
    data : Union[xr.Dataset, xr.DataArray, pd.DataFrame]
        The input data for the plot.
    obs_col : str
        Column/variable name for observations.
    mod_col : str
        Column/variable name for model values.
    time_col : str
        Dimension/column name for timestamp.
    second_dim : str
        The second dimension for the heatmap.
    metric : str
        The metric to plot ('bias' or 'error').
    aggregated : xr.DataArray
        The calculated aggregated data for the heatmap.
    second_label : str
        The label for the second dimension on the y-axis.

    Examples
    --------
    >>> import pandas as pd
    >>> import numpy as np
    >>> from monet_plots.plots import DiurnalErrorPlot
    >>> dates = pd.date_range("2023-01-01", periods=100, freq="h")
    >>> df = pd.DataFrame({
    ...     "time": dates,
    ...     "obs": np.random.rand(100),
    ...     "mod": np.random.rand(100)
    ... })
    >>> plot = DiurnalErrorPlot(df, obs_col="obs", mod_col="mod")
    >>> ax = plot.plot()
    """

    def __init__(
        self,
        data: Any,
        obs_col: str,
        mod_col: str,
        *,
        time_col: str = "time",
        second_dim: str = "month",
        metric: str = "bias",
        fig: matplotlib.figure.Figure | None = None,
        ax: matplotlib.axes.Axes | None = None,
        **kwargs: Any,
    ) -> None:
        """
        Initialize Diurnal Error Plot.

        Parameters
        ----------
        data : Any
            Input data. Can be a pandas DataFrame, xarray DataArray,
            xarray Dataset, or dask-backed object.
        obs_col : str
            Column/variable name for observations.
        mod_col : str
            Column/variable name for model values.
        time_col : str, optional
            Dimension/column name for timestamp, by default "time".
        second_dim : str, optional
            The second dimension for the heatmap ('month', 'dayofweek', 'date',
            or a coordinate name), by default "month".
        metric : str, optional
            The metric to plot ('bias' or 'error'), by default "bias".
        fig : matplotlib.figure.Figure, optional
            Existing figure object, by default None.
        ax : matplotlib.axes.Axes, optional
            Existing axes object, by default None.
        **kwargs : Any
            Additional arguments passed to BasePlot.
        """
        super().__init__(fig=fig, ax=ax, **kwargs)

        # Normalize data to Xarray if possible
        self.data = normalize_data(data)
        self.obs_col = obs_col
        self.mod_col = mod_col
        self.time_col = time_col
        self.second_dim = second_dim
        self.metric = metric
        self.aggregated: xr.DataArray | None = None
        self.second_label: str = ""

        # Prepare the calculation
        self._calculate_metric()

    def _calculate_metric(self) -> None:
        """Calculates the aggregated metric for the heatmap.

        This method identifies the appropriate backend (Xarray/Dask or Pandas),
        calculates the specified metric (bias or absolute error), and aggregates
        it into a 2D grid indexed by 'second_val' and 'hour'. It maintains
        lazy evaluation for Dask-backed objects and uses vectorized grouping.

        Raises
        ------
        ValueError
            If the metric is not 'bias' or 'error', or if second_dim is not found.
        """
        # Convert to Dataset if it's a DataArray to handle multiple columns easily
        ds = self.data
        if isinstance(ds, xr.DataArray):
            ds = ds.to_dataset()

        if isinstance(ds, xr.Dataset):
            # Calculate individual error/bias lazily
            if self.metric == "bias":
                val = ds[self.mod_col] - ds[self.obs_col]
                val.name = "bias"
                msg = "Calculated diurnal bias"
            elif self.metric == "error":
                val = np.abs(ds[self.mod_col] - ds[self.obs_col])
                val.name = "error"
                msg = "Calculated diurnal absolute error"
            else:
                raise ValueError("metric must be 'bias' or 'error'")

            # Add temporal coordinates for grouping
            time_coord = ds[self.time_col]
            val = val.assign_coords(hour=time_coord.dt.hour)

            if self.second_dim == "month":
                val = val.assign_coords(second_val=time_coord.dt.month)
                self.second_label = "Month"
            elif self.second_dim == "dayofweek":
                val = val.assign_coords(second_val=time_coord.dt.dayofweek)
                self.second_label = "Day of Week"
            elif self.second_dim == "date":
                val = val.assign_coords(second_val=time_coord.dt.floor("D"))
                self.second_label = "Date"
            else:
                if self.second_dim in ds.coords or self.second_dim in ds.data_vars:
                    val = val.assign_coords(second_val=ds[self.second_dim])
                    self.second_label = self.second_dim
                else:
                    raise ValueError(
                        f"second_dim '{self.second_dim}' not found in data"
                    )

            # Aero Protocol: Vectorized multi-dimensional grouping.
            # This is significantly faster than the previous loop and maintains
            # lazy evaluation for Dask-backed objects.
            try:
                self.aggregated = val.groupby(["second_val", "hour"]).mean(
                    dim=self.time_col, keep_attrs=True
                )

                # Ensure consistent order (second_val, hour)
                if (
                    "second_val" in self.aggregated.dims
                    and "hour" in self.aggregated.dims
                ):
                    self.aggregated = self.aggregated.transpose("second_val", "hour")
            except Exception:
                # Fallback to eager if something goes wrong with complex Xarray ops
                # though modern xarray (with flox) should handle this lazily.
                df = val.to_dataframe(name=val.name).reset_index()
                pivot = df.pivot_table(
                    index="second_val", columns="hour", values=val.name, aggfunc="mean"
                )
                self.aggregated = xr.DataArray(
                    pivot.values,
                    coords={
                        "second_val": pivot.index.values,
                        "hour": pivot.columns.values,
                    },
                    dims=["second_val", "hour"],
                    name=val.name,
                )

            self.aggregated = _update_history(self.aggregated, msg)

        else:
            # Fallback for Pandas DataFrame (backward compatibility)
            df = self.data.copy()
            df[self.time_col] = pd.to_datetime(df[self.time_col])
            df["hour"] = df[self.time_col].dt.hour

            if self.second_dim == "month":
                df["second_val"] = df[self.time_col].dt.month
                self.second_label = "Month"
            elif self.second_dim == "dayofweek":
                df["second_val"] = df[self.time_col].dt.dayofweek
                self.second_label = "Day of Week"
            elif self.second_dim == "date":
                df["second_val"] = df[self.time_col].dt.date
                self.second_label = "Date"
            else:
                df["second_val"] = df[self.second_dim]
                self.second_label = self.second_dim

            if self.metric == "bias":
                df["val"] = df[self.mod_col] - df[self.obs_col]
                metric_name = "bias"
                msg = "Calculated diurnal bias"
            elif self.metric == "error":
                df["val"] = np.abs(df[self.mod_col] - df[self.obs_col])
                metric_name = "error"
                msg = "Calculated diurnal absolute error"
            else:
                df["val"] = df[self.mod_col]
                metric_name = "value"
                msg = "Calculated diurnal values"

            pivot = df.pivot_table(
                index="second_val", columns="hour", values="val", aggfunc="mean"
            )
            self.aggregated = xr.DataArray(
                pivot.values,
                coords={
                    "second_val": pivot.index.values,
                    "hour": pivot.columns.values,
                },
                dims=["second_val", "hour"],
                name=metric_name,
            )
            self.aggregated = _update_history(self.aggregated, msg)

    def plot(self, cmap: str = "RdBu_r", **kwargs: Any) -> matplotlib.axes.Axes:
        """
        Generate the diurnal error heatmap (Track A: Static).

        Parameters
        ----------
        cmap : str, optional
            Colormap to use, by default "RdBu_r".
        **kwargs : Any
            Additional arguments passed to sns.heatmap. ``center`` may be
            given here to override the default (0 for the "bias" metric,
            None otherwise).

        Returns
        -------
        matplotlib.axes.Axes
            The axes object.

        Raises
        ------
        ValueError
            If ``self.aggregated`` has not been populated.

        Examples
        --------
        >>> # Assuming 'plot' is a DiurnalErrorPlot instance
        >>> ax = plot.plot(cmap="viridis")
        """
        if self.aggregated is None:
            raise ValueError("Aggregated data not found. Call _calculate_metric first.")

        # Materialize chunked (Dask-backed) data before handing it to Seaborn.
        data_to_plot = self.aggregated
        if hasattr(data_to_plot.data, "chunks"):
            data_to_plot = data_to_plot.compute()

        # Convert the 2-D DataArray to a DataFrame for Seaborn; rows/columns
        # follow the DataArray's dimension order.
        plot_df = data_to_plot.to_pandas()

        # Bias is signed, so centre the colormap on zero by default. Using
        # setdefault lets callers override ``center`` through kwargs instead of
        # triggering a duplicate-keyword TypeError.
        kwargs.setdefault("center", 0 if self.metric == "bias" else None)

        sns.heatmap(plot_df, ax=self.ax, cmap=cmap, **kwargs)

        self.ax.set_xlabel("Hour of Day")
        self.ax.set_ylabel(self.second_label)
        self.ax.set_title(f"Diurnal {self.metric.capitalize()}")

        return self.ax

    def hvplot(self, cmap: str = "RdBu_r", **kwargs: Any) -> hv.Element:
        """
        Generate the diurnal error heatmap (Track B: Interactive).

        Parameters
        ----------
        cmap : str, optional
            Colormap to use, by default "RdBu_r".
        **kwargs : Any
            Additional arguments passed to hvplot.heatmap. ``title``,
            ``xlabel``, ``ylabel`` and ``rasterize`` may be given here to
            override their defaults.

        Returns
        -------
        holoviews.Element
            The interactive HoloViews object.

        Raises
        ------
        ImportError
            If hvplot is not installed.
        ValueError
            If ``self.aggregated`` has not been populated.

        Examples
        --------
        >>> # Assuming 'plot' is a DiurnalErrorPlot instance
        >>> interactive_plot = plot.hvplot()
        """
        try:
            import hvplot.xarray  # noqa: F401
        except ImportError as err:
            # Chain the original error so the real import failure stays visible.
            raise ImportError(
                "hvplot is required for interactive plotting. Install it with 'pip install hvplot'."
            ) from err

        if self.aggregated is None:
            raise ValueError("Aggregated data not found. Call _calculate_metric first.")

        # Defaults for the heatmap. Entries in kwargs take precedence instead
        # of raising a duplicate-keyword TypeError.
        options: dict[str, Any] = {
            "title": f"Diurnal {self.metric.capitalize()}",
            "xlabel": "Hour of Day",
            "ylabel": self.second_label,
            "rasterize": True,
        }
        options.update(kwargs)

        # Track B: Interactive
        return self.aggregated.hvplot.heatmap(
            x="hour",
            y="second_val",
            C=self.aggregated.name,
            cmap=cmap,
            **options,
        )

__init__(data, obs_col, mod_col, *, time_col='time', second_dim='month', metric='bias', fig=None, ax=None, **kwargs)

Initialize Diurnal Error Plot.

Parameters

data : Any
    Input data. Can be a pandas DataFrame, xarray DataArray, xarray Dataset, or dask-backed object.
obs_col : str
    Column/variable name for observations.
mod_col : str
    Column/variable name for model values.
time_col : str, optional
    Dimension/column name for timestamp, by default "time".
second_dim : str, optional
    The second dimension for the heatmap ('month', 'dayofweek', 'date', or a coordinate name), by default "month".
metric : str, optional
    The metric to plot ('bias' or 'error'), by default "bias".
fig : matplotlib.figure.Figure, optional
    Existing figure object, by default None.
ax : matplotlib.axes.Axes, optional
    Existing axes object, by default None.
**kwargs : Any
    Additional arguments passed to BasePlot.

Source code in src/monet_plots/plots/diurnal_error.py
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
def __init__(
    self,
    data: Any,
    obs_col: str,
    mod_col: str,
    *,
    time_col: str = "time",
    second_dim: str = "month",
    metric: str = "bias",
    fig: matplotlib.figure.Figure | None = None,
    ax: matplotlib.axes.Axes | None = None,
    **kwargs: Any,
) -> None:
    """
    Set up a Diurnal Error Plot and immediately aggregate its metric.

    Parameters
    ----------
    data : Any
        Input data: a pandas DataFrame, xarray DataArray, xarray Dataset,
        or dask-backed object. It is passed through ``normalize_data``.
    obs_col : str
        Name of the observation column/variable.
    mod_col : str
        Name of the model column/variable.
    time_col : str, optional
        Name of the timestamp dimension/column, by default "time".
    second_dim : str, optional
        Heatmap y-axis grouping: 'month', 'dayofweek', 'date', or a
        coordinate name, by default "month".
    metric : str, optional
        Metric to aggregate ('bias' or 'error'), by default "bias".
    fig : matplotlib.figure.Figure, optional
        Figure to draw on, by default None.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on, by default None.
    **kwargs : Any
        Forwarded unchanged to BasePlot.
    """
    super().__init__(fig=fig, ax=ax, **kwargs)

    # Caller-supplied configuration.
    self.obs_col, self.mod_col = obs_col, mod_col
    self.time_col = time_col
    self.second_dim, self.metric = second_dim, metric

    # Bring the input into the library's canonical (Xarray-preferred) form.
    self.data = normalize_data(data)

    # Outputs filled in by _calculate_metric().
    self.aggregated: xr.DataArray | None = None
    self.second_label: str = ""

    self._calculate_metric()

hvplot(cmap='RdBu_r', **kwargs)

Generate the diurnal error heatmap (Track B: Interactive).

Parameters

cmap : str, optional Colormap to use, by default "RdBu_r". **kwargs : Any Additional arguments passed to hvplot.heatmap.

Returns

holoviews.Element The interactive HoloViews object.

Examples

>>> # Assuming 'plot' is a DiurnalErrorPlot instance
>>> interactive_plot = plot.hvplot()

Source code in src/monet_plots/plots/diurnal_error.py
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
def hvplot(self, cmap: str = "RdBu_r", **kwargs: Any) -> hv.Element:
    """
    Generate the diurnal error heatmap (Track B: Interactive).

    Parameters
    ----------
    cmap : str, optional
        Colormap to use, by default "RdBu_r".
    **kwargs : Any
        Additional arguments passed to hvplot.heatmap. ``title``,
        ``xlabel``, ``ylabel`` and ``rasterize`` may be given here to
        override their defaults.

    Returns
    -------
    holoviews.Element
        The interactive HoloViews object.

    Raises
    ------
    ImportError
        If hvplot is not installed.
    ValueError
        If ``self.aggregated`` has not been populated.

    Examples
    --------
    >>> # Assuming 'plot' is a DiurnalErrorPlot instance
    >>> interactive_plot = plot.hvplot()
    """
    try:
        import hvplot.xarray  # noqa: F401
    except ImportError as err:
        # Chain the original error so the real import failure stays visible.
        raise ImportError(
            "hvplot is required for interactive plotting. Install it with 'pip install hvplot'."
        ) from err

    if self.aggregated is None:
        raise ValueError("Aggregated data not found. Call _calculate_metric first.")

    # Defaults for the heatmap. Entries in kwargs take precedence instead
    # of raising a duplicate-keyword TypeError.
    options: dict[str, Any] = {
        "title": f"Diurnal {self.metric.capitalize()}",
        "xlabel": "Hour of Day",
        "ylabel": self.second_label,
        "rasterize": True,
    }
    options.update(kwargs)

    # Track B: Interactive
    return self.aggregated.hvplot.heatmap(
        x="hour",
        y="second_val",
        C=self.aggregated.name,
        cmap=cmap,
        **options,
    )

plot(cmap='RdBu_r', **kwargs)

Generate the diurnal error heatmap (Track A: Static).

Parameters

cmap : str, optional Colormap to use, by default "RdBu_r". **kwargs : Any Additional arguments passed to sns.heatmap.

Returns

matplotlib.axes.Axes The axes object.

Examples

>>> # Assuming 'plot' is a DiurnalErrorPlot instance
>>> ax = plot.plot(cmap="viridis")

Source code in src/monet_plots/plots/diurnal_error.py
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
def plot(self, cmap: str = "RdBu_r", **kwargs: Any) -> matplotlib.axes.Axes:
    """
    Generate the diurnal error heatmap (Track A: Static).

    Parameters
    ----------
    cmap : str, optional
        Colormap to use, by default "RdBu_r".
    **kwargs : Any
        Additional arguments passed to sns.heatmap. ``center`` may be
        given here to override the default (0 for the "bias" metric,
        None otherwise).

    Returns
    -------
    matplotlib.axes.Axes
        The axes object.

    Raises
    ------
    ValueError
        If ``self.aggregated`` has not been populated.

    Examples
    --------
    >>> # Assuming 'plot' is a DiurnalErrorPlot instance
    >>> ax = plot.plot(cmap="viridis")
    """
    if self.aggregated is None:
        raise ValueError("Aggregated data not found. Call _calculate_metric first.")

    # Materialize chunked (Dask-backed) data before handing it to Seaborn.
    data_to_plot = self.aggregated
    if hasattr(data_to_plot.data, "chunks"):
        data_to_plot = data_to_plot.compute()

    # Convert the 2-D DataArray to a DataFrame for Seaborn; rows/columns
    # follow the DataArray's dimension order.
    plot_df = data_to_plot.to_pandas()

    # Bias is signed, so centre the colormap on zero by default. Using
    # setdefault lets callers override ``center`` through kwargs instead of
    # triggering a duplicate-keyword TypeError.
    kwargs.setdefault("center", 0 if self.metric == "bias" else None)

    sns.heatmap(plot_df, ax=self.ax, cmap=cmap, **kwargs)

    self.ax.set_xlabel("Hour of Day")
    self.ax.set_ylabel(self.second_label)
    self.ax.set_title(f"Diurnal {self.metric.capitalize()}")

    return self.ax

FacetGridPlot

Bases: BasePlot

Creates a facet grid plot.

This class creates a facet grid plot using seaborn's FacetGrid for tabular data, or xarray's FacetGrid for xarray objects faceted by row or column.

Source code in src/monet_plots/plots/facet_grid.py
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
class FacetGridPlot(BasePlot):
    """Creates a facet grid plot.

    Tabular input is laid out with seaborn's ``FacetGrid``. Xarray input is
    kept as-is (lazy) and, when ``row`` or ``col`` is given, laid out with
    xarray's ``FacetGrid``.
    """

    def __init__(
        self,
        data: Any,
        row: str | None = None,
        col: str | None = None,
        hue: str | None = None,
        col_wrap: int | None = None,
        size: float | None = None,
        aspect: float = 1,
        subplot_kws: dict[str, Any] | None = None,
        style: str | None = "wiley",
        **kwargs: Any,
    ) -> None:
        """Initializes the facet grid.

        Parameters
        ----------
        data : Any
            The data to plot.
        row : str, optional
            Variable to map to row facets, by default None.
        col : str, optional
            Variable to map to column facets, by default None.
        hue : str, optional
            Variable to map to color mapping, by default None.
        col_wrap : int, optional
            Number of columns before wrapping, by default None.
        size : float, optional
            Height (in inches) of each facet, by default 3.
            Aligns with Xarray convention; takes precedence over ``height``.
        aspect : float, optional
            Aspect ratio of each facet, by default 1.
        subplot_kws : dict, optional
            Keyword arguments for subplots (e.g. projection).
        style : str, optional
            Style name to apply, by default "wiley".
        **kwargs : Any
            Additional keyword arguments to pass to seaborn's `FacetGrid`
            (not forwarded for xarray input).
            `height` is supported as an alias for `size` (Seaborn convention).
        """
        # Apply style
        if style:
            set_style(style)

        # Handle size/height alignment (Xarray vs Seaborn). ``height`` is always
        # removed from kwargs so it can never reach seaborn a second time next
        # to the explicit ``height=self.size`` below.
        height = kwargs.pop("height", None)
        if size is None:
            size = 3 if height is None else height
        self.size = size
        self.aspect = aspect

        # Store facet parameters
        self.row = row
        self.col = col
        self.hue = hue
        self.col_wrap = col_wrap

        # Aero Protocol: Preserve lazy Xarray objects
        self.raw_data = data
        self.is_xarray = isinstance(data, (xr.DataArray, xr.Dataset))

        if self.is_xarray:
            self.data = data  # Keep as Xarray
            # Without a facet dimension the grid is left unset here; creation
            # may then be deferred (e.g. to SpatialFacetGridPlot).
            self.grid = None
            if col or row:
                try:
                    from xarray.plot.facetgrid import FacetGrid as xrFacetGrid

                    self.grid = xrFacetGrid(
                        data,
                        col=col,
                        row=row,
                        col_wrap=col_wrap,
                        size=self.size,
                        aspect=self.aspect,
                        subplot_kws=subplot_kws,
                    )
                    # Initialize default titles for Xarray
                    self.grid.set_titles()
                except (ImportError, TypeError, AttributeError) as err:
                    # Best effort: continue without a fully-initialized grid,
                    # but surface the reason instead of swallowing it silently.
                    import warnings

                    warnings.warn(
                        f"Could not initialize xarray FacetGrid: {err}",
                        stacklevel=2,
                    )
        else:
            self.data = to_dataframe(data).reset_index()
            # Create the Seaborn FacetGrid
            self.grid = sns.FacetGrid(
                self.data,
                row=self.row,
                col=self.col,
                hue=self.hue,
                col_wrap=self.col_wrap,
                height=self.size,
                aspect=self.aspect,
                subplot_kws=subplot_kws,
                **kwargs,
            )

        # Unified BasePlot initialization. Look up ``axs`` first and consult
        # ``axes`` only when it is missing; evaluating both eagerly touches
        # xarray's deprecated ``axes`` alias on every call.
        axes = getattr(self.grid, "axs", None)
        if axes is None:
            axes = getattr(self.grid, "axes", None)
        if axes is not None:
            # Prefer ``figure`` when the grid provides it; otherwise use ``fig``.
            figure = getattr(self.grid, "figure", None)
            if figure is None:
                figure = self.grid.fig
            super().__init__(fig=figure, ax=axes.flatten()[0], style=style)
        else:
            super().__init__(style=style)

        # For compatibility with tests, also store as 'g'
        self.g = self.grid

    def map_dataframe(self, plot_func: Callable, *args: Any, **kwargs: Any) -> None:
        """Maps a plotting function to the facet grid.

        Args:
            plot_func (function): The plotting function to map.
            *args: Positional arguments to pass to the plotting function.
            **kwargs: Keyword arguments to pass to the plotting function.
        """
        self.grid.map_dataframe(plot_func, *args, **kwargs)

    def set_titles(self, *args, **kwargs):
        """Sets the titles of the facet grid.

        Args:
            *args: Positional arguments to pass to `set_titles`.
            **kwargs: Keyword arguments to pass to `set_titles`.
        """
        self.grid.set_titles(*args, **kwargs)

    def save(self, filename, **kwargs):
        """Saves the plot to a file.

        Args:
            filename (str): The name of the file to save the plot to.
            **kwargs: Additional keyword arguments to pass to `savefig`.
        """
        self.fig.savefig(filename, **kwargs)

    def plot(self, plot_func=None, *args, **kwargs):
        """Plots the data using the FacetGrid.

        Does nothing when ``plot_func`` is None.

        Args:
            plot_func (function, optional): The plotting function to use.
            *args: Positional arguments to pass to the plotting function.
            **kwargs: Keyword arguments to pass to the plotting function.
        """
        if plot_func is not None:
            self.grid.map(plot_func, *args, **kwargs)

    def close(self):
        """Closes the plot."""
        plt.close(self.fig)

__init__(data, row=None, col=None, hue=None, col_wrap=None, size=None, aspect=1, subplot_kws=None, style='wiley', **kwargs)

Initializes the facet grid.

Parameters

data : Any The data to plot. row : str, optional Variable to map to row facets, by default None. col : str, optional Variable to map to column facets, by default None. hue : str, optional Variable to map to color mapping, by default None. col_wrap : int, optional Number of columns before wrapping, by default None. size : float, optional Height (in inches) of each facet, by default 3. Aligns with Xarray convention. aspect : float, optional Aspect ratio of each facet, by default 1. subplot_kws : dict, optional Keyword arguments for subplots (e.g. projection). style : str, optional Style name to apply, by default "wiley". **kwargs : Any Additional keyword arguments to pass to FacetGrid. height is supported as an alias for size (Seaborn convention).

Source code in src/monet_plots/plots/facet_grid.py
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
def __init__(
    self,
    data: Any,
    row: str | None = None,
    col: str | None = None,
    hue: str | None = None,
    col_wrap: int | None = None,
    size: float | None = None,
    aspect: float = 1,
    subplot_kws: dict[str, Any] | None = None,
    style: str | None = "wiley",
    **kwargs: Any,
) -> None:
    """Initializes the facet grid.

    Parameters
    ----------
    data : Any
        The data to plot.
    row : str, optional
        Variable to map to row facets, by default None.
    col : str, optional
        Variable to map to column facets, by default None.
    hue : str, optional
        Variable to map to color mapping, by default None.
    col_wrap : int, optional
        Number of columns before wrapping, by default None.
    size : float, optional
        Height (in inches) of each facet, by default 3.
        Aligns with Xarray convention; takes precedence over ``height``.
    aspect : float, optional
        Aspect ratio of each facet, by default 1.
    subplot_kws : dict, optional
        Keyword arguments for subplots (e.g. projection).
    style : str, optional
        Style name to apply, by default "wiley".
    **kwargs : Any
        Additional keyword arguments to pass to seaborn's `FacetGrid`
        (not forwarded for xarray input).
        `height` is supported as an alias for `size` (Seaborn convention).
    """
    # Apply style
    if style:
        set_style(style)

    # Handle size/height alignment (Xarray vs Seaborn). ``height`` is always
    # removed from kwargs so it can never reach seaborn a second time next
    # to the explicit ``height=self.size`` below.
    height = kwargs.pop("height", None)
    if size is None:
        size = 3 if height is None else height
    self.size = size
    self.aspect = aspect

    # Store facet parameters
    self.row = row
    self.col = col
    self.hue = hue
    self.col_wrap = col_wrap

    # Aero Protocol: Preserve lazy Xarray objects
    self.raw_data = data
    self.is_xarray = isinstance(data, (xr.DataArray, xr.Dataset))

    if self.is_xarray:
        self.data = data  # Keep as Xarray
        # Without a facet dimension the grid is left unset here; creation
        # may then be deferred (e.g. to SpatialFacetGridPlot).
        self.grid = None
        if col or row:
            try:
                from xarray.plot.facetgrid import FacetGrid as xrFacetGrid

                self.grid = xrFacetGrid(
                    data,
                    col=col,
                    row=row,
                    col_wrap=col_wrap,
                    size=self.size,
                    aspect=self.aspect,
                    subplot_kws=subplot_kws,
                )
                # Initialize default titles for Xarray
                self.grid.set_titles()
            except (ImportError, TypeError, AttributeError) as err:
                # Best effort: continue without a fully-initialized grid,
                # but surface the reason instead of swallowing it silently.
                import warnings

                warnings.warn(
                    f"Could not initialize xarray FacetGrid: {err}",
                    stacklevel=2,
                )
    else:
        self.data = to_dataframe(data).reset_index()
        # Create the Seaborn FacetGrid
        self.grid = sns.FacetGrid(
            self.data,
            row=self.row,
            col=self.col,
            hue=self.hue,
            col_wrap=self.col_wrap,
            height=self.size,
            aspect=self.aspect,
            subplot_kws=subplot_kws,
            **kwargs,
        )

    # Unified BasePlot initialization. Look up ``axs`` first and consult
    # ``axes`` only when it is missing; evaluating both eagerly touches
    # xarray's deprecated ``axes`` alias on every call.
    axes = getattr(self.grid, "axs", None)
    if axes is None:
        axes = getattr(self.grid, "axes", None)
    if axes is not None:
        # Prefer ``figure`` when the grid provides it; otherwise use ``fig``.
        figure = getattr(self.grid, "figure", None)
        if figure is None:
            figure = self.grid.fig
        super().__init__(fig=figure, ax=axes.flatten()[0], style=style)
    else:
        super().__init__(style=style)

    # For compatibility with tests, also store as 'g'
    self.g = self.grid

close()

Closes the plot.

Source code in src/monet_plots/plots/facet_grid.py
169
170
171
def close(self):
    """Release the matplotlib figure backing this facet grid."""
    figure = self.fig
    plt.close(figure)

map_dataframe(plot_func, *args, **kwargs)

Maps a plotting function to the facet grid.

Parameters:

Name Type Description Default
plot_func function

The plotting function to map.

required
*args Any

Positional arguments to pass to the plotting function.

()
**kwargs Any

Keyword arguments to pass to the plotting function.

{}
Source code in src/monet_plots/plots/facet_grid.py
130
131
132
133
134
135
136
137
138
def map_dataframe(self, plot_func: Callable, *args: Any, **kwargs: Any) -> None:
    """Apply a plotting callable to every facet via the grid's ``map_dataframe``.

    Args:
        plot_func (function): Callable invoked for each facet.
        *args: Positional arguments forwarded to ``plot_func``.
        **kwargs: Keyword arguments forwarded to ``plot_func``.
    """
    apply_to_facets = self.grid.map_dataframe
    apply_to_facets(plot_func, *args, **kwargs)

plot(plot_func=None, *args, **kwargs)

Plots the data using the FacetGrid.

Parameters:

Name Type Description Default
plot_func function

The plotting function to use.

None
*args

Positional arguments to pass to the plotting function.

()
**kwargs

Keyword arguments to pass to the plotting function.

{}
Source code in src/monet_plots/plots/facet_grid.py
158
159
160
161
162
163
164
165
166
167
def plot(self, plot_func=None, *args, **kwargs):
    """Map ``plot_func`` over the FacetGrid; a ``None`` function is a no-op.

    Args:
        plot_func (function, optional): Callable mapped over each facet.
        *args: Positional arguments forwarded to ``plot_func``.
        **kwargs: Keyword arguments forwarded to ``plot_func``.
    """
    if plot_func is None:
        return
    self.grid.map(plot_func, *args, **kwargs)

save(filename, **kwargs)

Saves the plot to a file.

Parameters:

Name Type Description Default
filename str

The name of the file to save the plot to.

required
**kwargs

Additional keyword arguments to pass to savefig.

{}
Source code in src/monet_plots/plots/facet_grid.py
149
150
151
152
153
154
155
156
def save(self, filename, **kwargs):
    """Write the figure to ``filename``.

    Args:
        filename (str): Destination path for the saved figure.
        **kwargs: Extra options forwarded to ``Figure.savefig``.
    """
    figure = self.fig
    figure.savefig(filename, **kwargs)

set_titles(*args, **kwargs)

Sets the titles of the facet grid.

Parameters:

Name Type Description Default
*args

Positional arguments to pass to set_titles.

()
**kwargs

Keyword arguments to pass to set_titles.

{}
Source code in src/monet_plots/plots/facet_grid.py
140
141
142
143
144
145
146
147
def set_titles(self, *args, **kwargs):
    """Delegate facet title configuration to the underlying grid.

    Args:
        *args: Positional arguments forwarded to the grid's `set_titles`.
        **kwargs: Keyword arguments forwarded to the grid's `set_titles`.
    """
    titler = self.grid.set_titles
    titler(*args, **kwargs)

FingerprintPlot

Bases: BasePlot

Fingerprint plot.

Displays a variable as a heatmap across two different temporal scales, such as hour of day vs. day of year, to reveal periodic patterns.

Source code in src/monet_plots/plots/fingerprint.py
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
class FingerprintPlot(BasePlot):
    """Fingerprint plot.

    Displays a variable as a heatmap across two different temporal scales,
    such as hour of day vs. day of year, to reveal periodic patterns.
    """

    def __init__(
        self,
        data: Any,
        val_col: str,
        *,
        time_col: str = "time",
        x_scale: str = "hour",
        y_scale: str = "dayofyear",
        **kwargs,
    ):
        """
        Initialize Fingerprint Plot.

        Args:
            data: Input data (DataFrame, DataArray, etc.), converted with
                ``to_dataframe``.
            val_col: Column name of the value to plot.
            time_col: Column (or index level) name for timestamp.
            x_scale: Temporal scale for the x-axis ('hour', 'month', 'dayofweek', etc.).
            y_scale: Temporal scale for the y-axis ('dayofyear', 'year', 'week', etc.).
            **kwargs: Arguments passed to BasePlot.

        Raises:
            KeyError: If ``time_col`` is neither a column nor an index level.
            ValueError: If a scale is neither a known temporal scale nor a column.
        """
        super().__init__(**kwargs)
        df = to_dataframe(data).copy()
        # The converted frame may hold the time coordinate as an index level
        # rather than a column; promote just that level so the column-based
        # access below works.
        if time_col not in df.columns and time_col in df.index.names:
            df = df.reset_index(level=time_col)
        self.df = df
        self.val_col = val_col
        self.time_col = time_col
        self.x_scale = x_scale
        self.y_scale = y_scale

        # Ensure time_col is datetime
        self.df[self.time_col] = pd.to_datetime(self.df[self.time_col])

        # Derived bin columns consumed by plot().
        self._extract_scale(self.x_scale, "x_val")
        self._extract_scale(self.y_scale, "y_val")

    def _extract_scale(self, scale: str, target_col: str):
        """Write the temporal feature ``scale`` of the time column into ``target_col``.

        Unknown scale names are treated as existing column names.

        Raises:
            ValueError: If ``scale`` is neither a known scale nor a column.
        """
        t = self.df[self.time_col].dt
        if scale == "hour":
            self.df[target_col] = t.hour
        elif scale == "month":
            self.df[target_col] = t.month
        elif scale == "dayofweek":
            self.df[target_col] = t.dayofweek
        elif scale == "dayofyear":
            self.df[target_col] = t.dayofyear
        elif scale == "week":
            self.df[target_col] = t.isocalendar().week
        elif scale == "year":
            self.df[target_col] = t.year
        elif scale == "date":
            self.df[target_col] = t.date
        else:
            # Try to use it as a direct column if not a known scale
            if scale in self.df.columns:
                self.df[target_col] = self.df[scale]
            else:
                raise ValueError(f"Unknown temporal scale: {scale}")

    def plot(self, cmap: str = "viridis", **kwargs):
        """Generate the fingerprint heatmap (mean of ``val_col`` per x/y bin)."""
        pivot_df = self.df.pivot_table(
            index="y_val", columns="x_val", values=self.val_col, aggfunc="mean"
        )

        sns.heatmap(pivot_df, ax=self.ax, cmap=cmap, **kwargs)

        self.ax.set_xlabel(self.x_scale.capitalize())
        self.ax.set_ylabel(self.y_scale.capitalize())
        self.ax.set_title(f"Fingerprint: {self.val_col}")

        return self.ax

__init__(data, val_col, *, time_col='time', x_scale='hour', y_scale='dayofyear', **kwargs)

Initialize Fingerprint Plot.

Parameters:

Name Type Description Default
data Any

Input data (DataFrame, DataArray, etc.).

required
val_col str

Column name of the value to plot.

required
time_col str

Column name for timestamp.

'time'
x_scale str

Temporal scale for the x-axis ('hour', 'month', 'dayofweek', etc.).

'hour'
y_scale str

Temporal scale for the y-axis ('dayofyear', 'year', 'week', etc.).

'dayofyear'
**kwargs

Arguments passed to BasePlot.

{}
Source code in src/monet_plots/plots/fingerprint.py
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
def __init__(
    self,
    data: Any,
    val_col: str,
    *,
    time_col: str = "time",
    x_scale: str = "hour",
    y_scale: str = "dayofyear",
    **kwargs,
):
    """
    Initialize Fingerprint Plot.

    Args:
        data: Input data (DataFrame, DataArray, etc.), converted with
            ``to_dataframe``.
        val_col: Column name of the value to plot.
        time_col: Column (or index level) name for timestamp.
        x_scale: Temporal scale for the x-axis ('hour', 'month', 'dayofweek', etc.).
        y_scale: Temporal scale for the y-axis ('dayofyear', 'year', 'week', etc.).
        **kwargs: Arguments passed to BasePlot.

    Raises:
        KeyError: If ``time_col`` is neither a column nor an index level.
        ValueError: If a scale is neither a known temporal scale nor a column.
    """
    super().__init__(**kwargs)
    df = to_dataframe(data).copy()
    # The converted frame may hold the time coordinate as an index level
    # rather than a column; promote just that level so the column-based
    # access below works.
    if time_col not in df.columns and time_col in df.index.names:
        df = df.reset_index(level=time_col)
    self.df = df
    self.val_col = val_col
    self.time_col = time_col
    self.x_scale = x_scale
    self.y_scale = y_scale

    # Ensure time_col is datetime
    self.df[self.time_col] = pd.to_datetime(self.df[self.time_col])

    # Derived bin columns consumed by plot().
    self._extract_scale(self.x_scale, "x_val")
    self._extract_scale(self.y_scale, "y_val")

plot(cmap='viridis', **kwargs)

Generate the fingerprint heatmap.

Source code in src/monet_plots/plots/fingerprint.py
78
79
80
81
82
83
84
85
86
87
88
89
90
def plot(self, cmap: str = "viridis", **kwargs):
    """Draw the fingerprint heatmap: mean ``val_col`` per (y_val, x_val) bin."""
    table = self.df.pivot_table(
        values=self.val_col,
        index="y_val",
        columns="x_val",
        aggfunc="mean",
    )
    sns.heatmap(table, ax=self.ax, cmap=cmap, **kwargs)

    axis = self.ax
    axis.set_xlabel(self.x_scale.capitalize())
    axis.set_ylabel(self.y_scale.capitalize())
    axis.set_title(f"Fingerprint: {self.val_col}")
    return axis

KDEPlot

Bases: BasePlot

Create a kernel density estimate plot.

This plot shows the distribution of a single variable or, when a second variable is given, the joint distribution of the two.

Source code in src/monet_plots/plots/kde.py
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
class KDEPlot(BasePlot):
    """Create a kernel density estimate plot.

    Shows the distribution of a single variable (``x``) or, when ``y`` is
    also given, the joint distribution of ``x`` and ``y``.
    """

    def __init__(self, df, x, y=None, title=None, label=None, *args, **kwargs):
        """
        Initialize the plot with data and plot settings.

        Args:
            df (pd.DataFrame, np.ndarray, xr.Dataset, xr.DataArray): Data passed
                to ``sns.kdeplot`` as ``data``.
            x (str): Column name for the x-axis.
            y (str, optional): Column name for the y-axis. Leave as None for a
                univariate KDE of ``x``.
            title (str, optional): Title for the plot.
            label (str, optional): Label for the plot.
            *args, **kwargs: Forwarded to BasePlot.
        """
        super().__init__(*args, **kwargs)
        self.df = df
        self.x = x
        self.y = y
        self.title = title
        self.label = label

    def plot(self, **kwargs):
        """Generate the KDE plot on ``self.ax`` and return the axes.

        Args:
            **kwargs: Additional arguments passed to ``sns.kdeplot``.
        """
        with sns.axes_style("ticks"):
            self.ax = sns.kdeplot(
                data=self.df, x=self.x, y=self.y, ax=self.ax, label=self.label, **kwargs
            )
            if self.title:
                self.ax.set_title(self.title)
            # Despine this plot's axes explicitly; without ``ax`` seaborn acts
            # on the current figure, which need not be the one holding self.ax.
            sns.despine(ax=self.ax)
        return self.ax

__init__(df, x, y, title=None, label=None, *args, **kwargs)

Initialize the plot with data and plot settings.

Parameters:

Name Type Description Default
df (DataFrame, ndarray, Dataset, DataArray)

DataFrame with the data to plot.

required
x str

Column name for the x-axis.

required
y str

Column name for the y-axis.

required
title str

Title for the plot.

None
label str

Label for the plot.

None
Source code in src/monet_plots/plots/kde.py
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
def __init__(self, df, x, y=None, title=None, label=None, *args, **kwargs):
    """
    Initialize the plot with data and plot settings.

    Args:
        df (pd.DataFrame, np.ndarray, xr.Dataset, xr.DataArray): DataFrame with the data to plot.
        x (str): Column name for the x-axis.
        y (str, optional): Column name for the y-axis. If None (default),
            a univariate KDE of ``x`` is drawn.
        title (str, optional): Title for the plot.
        label (str, optional): Label for the plot.
        *args, **kwargs: Forwarded to ``BasePlot.__init__``.
    """
    super().__init__(*args, **kwargs)
    self.df = df
    self.x = x
    self.y = y
    self.title = title
    self.label = label

plot(**kwargs)

Generate the KDE plot.

Source code in src/monet_plots/plots/kde.py
32
33
34
35
36
37
38
39
40
41
def plot(self, **kwargs):
    """Generate the KDE plot.

    Args:
        **kwargs: Extra keyword arguments forwarded to :func:`seaborn.kdeplot`.

    Returns:
        matplotlib.axes.Axes: The axes returned by ``seaborn.kdeplot``.
    """
    with sns.axes_style("ticks"):
        self.ax = sns.kdeplot(
            data=self.df, x=self.x, y=self.y, ax=self.ax, label=self.label, **kwargs
        )
        if self.title:
            self.ax.set_title(self.title)
        # Restrict despining to this plot's axes: without ``ax`` seaborn
        # despines every axes of the *current* figure.
        sns.despine(ax=self.ax)
    return self.ax

PerformanceDiagramPlot

Bases: BasePlot

Performance Diagram Plot (Roebber).

Visualizes the relationship between Probability of Detection (POD), Success Ratio (SR), Critical Success Index (CSI), and Bias.

Functional Requirements: 1. Plot POD (y-axis) vs Success Ratio (x-axis). 2. Draw background isolines for CSI and Bias. 3. Support input as pre-calculated metrics or contingency table counts. 4. Handle multiple models/configurations via grouping.

Edge Cases: - SR or POD being 0 or 1 (division by zero in bias/CSI calculations). - Empty DataFrame. - Missing required columns.

Source code in src/monet_plots/plots/performance_diagram.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
class PerformanceDiagramPlot(BasePlot):
    """
    Performance Diagram Plot (Roebber).

    Visualizes the relationship between Probability of Detection (POD),
    Success Ratio (SR),
    Critical Success Index (CSI), and Bias.

    Functional Requirements:
    1. Plot POD (y-axis) vs Success Ratio (x-axis).
    2. Draw background isolines for CSI and Bias.
    3. Support input as pre-calculated metrics or contingency table counts.
    4. Handle multiple models/configurations via grouping.

    Edge Cases:
    - SR or POD being 0 or 1 (division by zero in bias/CSI calculations).
    - Empty DataFrame.
    - Missing required columns.
    """

    def __init__(self, fig=None, ax=None, **kwargs):
        super().__init__(fig=fig, ax=ax, **kwargs)

    def plot(
        self,
        data: Any,
        x_col: str = "success_ratio",
        y_col: str = "pod",
        counts_cols: Optional[List[str]] = None,
        label_col: Optional[str] = None,
        **kwargs,
    ):
        """
        Main plotting method.

        Args:
            data (pd.DataFrame, np.ndarray, xr.Dataset, xr.DataArray): Input data,
                converted with ``to_dataframe``.
            x_col (str): Column name for Success Ratio (1-FAR).
            y_col (str): Column name for POD.
            counts_cols (list, optional): List of columns [hits, misses, fa, cn]
                                        to calculate metrics if x_col/y_col not present.
            label_col (str, optional): Column to use for legend labels.
            **kwargs: Matplotlib ``Axes.plot`` kwargs. ``marker`` and
                ``linestyle`` override the defaults (``"o"`` and ``"none"``).

        Returns:
            matplotlib.axes.Axes: The axes the diagram was drawn on.
        """
        df = to_dataframe(data)
        # TDD Anchor: Test validation raises error on missing cols
        self._validate_inputs(df, x_col, y_col, counts_cols)

        # Data Preparation
        df_plot = self._prepare_data(df, x_col, y_col, counts_cols)

        # Plot Background (Isolines)
        self._draw_background()

        # Scatter defaults; caller kwargs take precedence instead of raising
        # "got multiple values for keyword argument".
        plot_kws = {"marker": "o", "linestyle": "none", **kwargs}

        # Plot Data
        # TDD Anchor: Verify scatter points match input data coordinates
        if label_col:
            for name, group in df_plot.groupby(label_col):
                self.ax.plot(group[x_col], group[y_col], **{**plot_kws, "label": name})
            self.ax.legend(loc="best")
        else:
            self.ax.plot(df_plot[x_col], df_plot[y_col], **plot_kws)

        # Formatting
        self.ax.set_xlim(0, 1)
        self.ax.set_ylim(0, 1)
        self.ax.set_xlabel("Success Ratio (1-FAR)")
        self.ax.set_ylabel("Probability of Detection (POD)")
        self.ax.set_aspect("equal")
        return self.ax

    def _validate_inputs(self, data, x, y, counts):
        """Validate that the required columns are present.

        The count columns are required when ``counts`` is given; otherwise the
        pre-computed ``x`` (SR) and ``y`` (POD) columns are required.
        """
        if counts:
            validate_dataframe(data, required_columns=counts)
        else:
            validate_dataframe(data, required_columns=[x, y])

    def _prepare_data(self, data, x, y, counts):
        """
        Return a copy of ``data``, adding SR/POD columns computed from counts.

        When ``counts`` ([hits, misses, fa, cn]) is given, ``x`` is filled with
        the success ratio and ``y`` with POD; otherwise the copy is returned
        unchanged.
        TDD Anchor: Test calculation logic: SR = hits/(hits+fa), POD = hits/(hits+miss).
        """
        df = data.copy()
        if counts:
            # Correct negatives are not needed for SR or POD.
            hits_col, misses_col, fa_col, _ = counts
            df[x] = compute_success_ratio(df[hits_col], df[fa_col])
            df[y] = compute_pod(df[hits_col], df[misses_col])
        return df

    def _draw_background(self):
        """
        Draws CSI and Bias isolines.

        Pseudocode:
        1. Create meshgrid for x (SR) and y (POD) from 0.01 to 0.99.
        2. Calculate CSI = 1 / (1/SR + 1/POD - 1).
        3. Calculate Bias = POD / SR.
        4. Contour plot CSI (dashed).
        5. Contour plot Bias (dotted).
        6. Label contours.
        """
        # Avoid division by zero at boundaries
        xx, yy = np.meshgrid(np.linspace(0.01, 0.99, 50), np.linspace(0.01, 0.99, 50))
        # Algebraically equal to 1 / (1/SR + 1/POD - 1).
        csi = (xx * yy) / (xx + yy - xx * yy)
        bias = yy / xx

        # CSI contours (dashed, lightgray)
        cs_csi = self.ax.contour(
            xx,
            yy,
            csi,
            levels=np.arange(0.1, 0.95, 0.1),
            colors="lightgray",
            linestyles="--",
            alpha=0.6,
        )
        self.ax.clabel(cs_csi, inline=True, fontsize=8, fmt="%.1f")

        # Bias contours (dotted, darkgray)
        cs_bias = self.ax.contour(
            xx,
            yy,
            bias,
            levels=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0],
            colors="darkgray",
            linestyles=":",
            alpha=0.6,
        )
        self.ax.clabel(cs_bias, inline=True, fontsize=8, fmt="%.1f")

        # Perfect forecast line (the SR == POD diagonal, i.e. Bias = 1)
        self.ax.plot([0.01, 0.99], [0.01, 0.99], "k-", linewidth=1.5, alpha=0.8)

plot(data, x_col='success_ratio', y_col='pod', counts_cols=None, label_col=None, **kwargs)

Main plotting method.

Parameters:

Name Type Description Default
data (DataFrame, ndarray, Dataset, DataArray)

Input data.

required
x_col str

Column name for Success Ratio (1-FAR).

'success_ratio'
y_col str

Column name for POD.

'pod'
counts_cols list

List of columns [hits, misses, fa, cn] to calculate metrics if x_col/y_col not present.

None
label_col str

Column to use for legend labels.

None
**kwargs

Matplotlib kwargs.

{}
Source code in src/monet_plots/plots/performance_diagram.py
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
def plot(
    self,
    data: Any,
    x_col: str = "success_ratio",
    y_col: str = "pod",
    counts_cols: Optional[List[str]] = None,
    label_col: Optional[str] = None,
    **kwargs,
):
    """
    Main plotting method.

    Args:
        data (pd.DataFrame, np.ndarray, xr.Dataset, xr.DataArray): Input data,
            converted with ``to_dataframe``.
        x_col (str): Column name for Success Ratio (1-FAR).
        y_col (str): Column name for POD.
        counts_cols (list, optional): List of columns [hits, misses, fa, cn]
                                    to calculate metrics if x_col/y_col not present.
        label_col (str, optional): Column to use for legend labels.
        **kwargs: Matplotlib ``Axes.plot`` kwargs. ``marker`` and
            ``linestyle`` override the defaults (``"o"`` and ``"none"``).

    Returns:
        matplotlib.axes.Axes: The axes the diagram was drawn on.
    """
    df = to_dataframe(data)
    # TDD Anchor: Test validation raises error on missing cols
    self._validate_inputs(df, x_col, y_col, counts_cols)

    # Data Preparation
    df_plot = self._prepare_data(df, x_col, y_col, counts_cols)

    # Plot Background (Isolines)
    self._draw_background()

    # Scatter defaults; caller kwargs take precedence instead of raising
    # "got multiple values for keyword argument".
    plot_kws = {"marker": "o", "linestyle": "none", **kwargs}

    # Plot Data
    # TDD Anchor: Verify scatter points match input data coordinates
    if label_col:
        for name, group in df_plot.groupby(label_col):
            self.ax.plot(group[x_col], group[y_col], **{**plot_kws, "label": name})
        self.ax.legend(loc="best")
    else:
        self.ax.plot(df_plot[x_col], df_plot[y_col], **plot_kws)

    # Formatting
    self.ax.set_xlim(0, 1)
    self.ax.set_ylim(0, 1)
    self.ax.set_xlabel("Success Ratio (1-FAR)")
    self.ax.set_ylabel("Probability of Detection (POD)")
    self.ax.set_aspect("equal")
    return self.ax

ProfilePlot

Bases: BasePlot

Profile or cross-section plot.

Source code in src/monet_plots/plots/profile.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
class ProfilePlot(BasePlot):
    """Profile or cross-section plot."""

    def __init__(
        self,
        *,
        x: np.ndarray,
        y: np.ndarray,
        z: np.ndarray | None = None,
        alt_adjust: float | None = None,
        **kwargs: t.Any,
    ) -> None:
        """
        Store the profile data.

        Parameters
        ----------
        x
            Values along the horizontal axis.
        y
            Values along the vertical axis.
        z
            Optional field; when given, ``plot`` draws filled contours.
        alt_adjust
            Optional offset subtracted from ``y`` before storing it.
        **kwargs
            Forwarded unchanged to the ``BasePlot`` constructor.
        """
        super().__init__(**kwargs)
        self.x = x
        # Only shift the vertical coordinate when an offset was supplied.
        self.y = y if alt_adjust is None else y - alt_adjust
        self.z = z

    def plot(self, **kwargs: t.Any) -> None:
        """
        Draw the profile on ``self.ax``.

        Parameters
        ----------
        **kwargs
            Forwarded to `matplotlib.pyplot.contourf` when ``z`` is set,
            otherwise to `matplotlib.pyplot.plot`.
        """
        if self.ax is None:
            # Lazily create a figure/axes only when none were supplied.
            if self.fig is None:
                self.fig = plt.figure()
            self.ax = self.fig.add_subplot()

        if self.z is None:
            self.ax.plot(self.x, self.y, **kwargs)
        else:
            self.ax.contourf(self.x, self.y, self.z, **kwargs)

__init__(*, x, y, z=None, alt_adjust=None, **kwargs)

Parameters

x X-axis data. y Y-axis data. z Optional Z-axis data for contour plots. alt_adjust Value to subtract from the y-axis data for altitude adjustment. **kwargs Keyword arguments passed to the parent class.

Source code in src/monet_plots/plots/profile.py
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
def __init__(
    self,
    *,
    x: np.ndarray,
    y: np.ndarray,
    z: np.ndarray | None = None,
    alt_adjust: float | None = None,
    **kwargs: t.Any,
) -> None:
    """
    Store the profile data.

    Parameters
    ----------
    x
        Values along the horizontal axis.
    y
        Values along the vertical axis.
    z
        Optional field; when given, ``plot`` draws filled contours.
    alt_adjust
        Optional offset subtracted from ``y`` before storing it.
    **kwargs
        Forwarded unchanged to the ``BasePlot`` constructor.
    """
    super().__init__(**kwargs)
    self.x = x
    # Only shift the vertical coordinate when an offset was supplied.
    self.y = y if alt_adjust is None else y - alt_adjust
    self.z = z

plot(**kwargs)

Parameters

**kwargs Keyword arguments passed to matplotlib.pyplot.plot or matplotlib.pyplot.contourf.

Source code in src/monet_plots/plots/profile.py
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
def plot(self, **kwargs: t.Any) -> None:
    """
    Draw the profile on ``self.ax``.

    Parameters
    ----------
    **kwargs
        Forwarded to `matplotlib.pyplot.contourf` when ``z`` is set,
        otherwise to `matplotlib.pyplot.plot`.
    """
    if self.ax is None:
        # Lazily create a figure/axes only when none were supplied.
        if self.fig is None:
            self.fig = plt.figure()
        self.ax = self.fig.add_subplot()

    if self.z is None:
        self.ax.plot(self.x, self.y, **kwargs)
    else:
        self.ax.contourf(self.x, self.y, self.z, **kwargs)

ROCCurvePlot

Bases: BasePlot

Receiver Operating Characteristic (ROC) Curve Plot.

Visualizes the trade-off between Probability of Detection (POD) and Probability of False Detection (POFD).

Functional Requirements: 1. Plot POD (y-axis) vs POFD (x-axis). 2. Draw diagonal "no skill" line (0,0) to (1,1). 3. Calculate and display Area Under Curve (AUC) in legend. 4. Support multiple models/curves via grouping.

Edge Cases: - Non-monotonic data points (should sort by threshold/prob). - Single point provided (cannot calculate AUC properly, return NaN or handle gracefully). - Missing columns.

Source code in src/monet_plots/plots/roc_curve.py
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
class ROCCurvePlot(BasePlot):
    """
    Receiver Operating Characteristic (ROC) Curve Plot.

    Visualizes the trade-off between Probability of Detection (POD) and
    Probability of False Detection (POFD).

    Functional Requirements:
    1. Plot POD (y-axis) vs POFD (x-axis).
    2. Draw diagonal "no skill" line (0,0) to (1,1).
    3. Calculate and display Area Under Curve (AUC) in legend.
    4. Support multiple models/curves via grouping.

    Edge Cases:
    - Non-monotonic data points (should sort by threshold/prob).
    - Single point provided (cannot calculate AUC properly, return NaN or handle gracefully).
    - Missing columns.
    """

    def __init__(self, fig=None, ax=None, **kwargs):
        super().__init__(fig=fig, ax=ax, **kwargs)

    def plot(
        self,
        data: Any,
        x_col: str = "pofd",
        y_col: str = "pod",
        label_col: Optional[str] = None,
        show_auc: bool = True,
        **kwargs,
    ):
        """
        Main plotting method.

        Args:
            data (pd.DataFrame, np.ndarray, xr.Dataset, xr.DataArray): Input data containing ROC points.
            x_col (str): Column name for POFD (False Alarm Rate).
            y_col (str): Column name for POD (Hit Rate).
            label_col (str, optional): Column for grouping different curves.
            show_auc (bool): Whether to calculate and append AUC to labels.
            **kwargs: Matplotlib ``Axes.plot`` kwargs for the curve line(s).

        Returns:
            matplotlib.axes.Axes: The axes the curves were drawn on.
        """
        df = to_dataframe(data)
        # TDD Anchor: Test validation raises error on missing cols
        validate_dataframe(df, required_columns=[x_col, y_col])

        # Draw No Skill Line
        self.ax.plot([0, 1], [0, 1], "k--", label="No Skill", alpha=0.5)
        self.ax.grid(True, alpha=0.3)

        if label_col:
            for name, group in df.groupby(label_col):
                self._plot_single_curve(
                    group, x_col, y_col, label=str(name), show_auc=show_auc, **kwargs
                )
        else:
            self._plot_single_curve(
                df, x_col, y_col, label="Model", show_auc=show_auc, **kwargs
            )

        # Always draw the legend: it is where the AUC is reported, including
        # for a single, ungrouped curve.
        self.ax.legend(loc="lower right")

        # Formatting
        self.ax.set_xlim(0, 1)
        self.ax.set_ylim(0, 1)
        self.ax.set_xlabel("Probability of False Detection (POFD)")
        self.ax.set_ylabel("Probability of Detection (POD)")
        self.ax.set_aspect("equal")
        return self.ax

    def _plot_single_curve(self, df, x_col, y_col, label, show_auc, **kwargs):
        """
        Plot one ROC curve, shade the area beneath it and optionally label AUC.

        Rows with NaN in either column are dropped. The rest are sorted by
        POFD, with POD breaking ties so the curve is traversed monotonically.
        The AUC is only computed when at least two points remain.
        """
        # TDD Anchor: Test AUC calculation against sklearn.metrics.auc or manual known
        # values.
        # TDD Anchor: Ensure sorting is applied correctly.

        df_sorted = df.dropna(subset=[x_col, y_col]).sort_values(by=[x_col, y_col])
        x = df_sorted[x_col].values
        y = df_sorted[y_col].values

        auc_str = ""
        if show_auc and len(x) >= 2:
            auc = compute_auc(x, y)
            auc_str = f" (AUC={auc:.3f})"

        (line,) = self.ax.plot(x, y, label=label + auc_str, **kwargs)
        # Shade in the line's own colour. Line-only kwargs (e.g. ``marker``)
        # are not forwarded, because fill_between rejects them and a caller
        # ``alpha`` would collide with the fixed shading alpha.
        self.ax.fill_between(x, 0, y, color=line.get_color(), alpha=0.2)

plot(data, x_col='pofd', y_col='pod', label_col=None, show_auc=True, **kwargs)

Main plotting method.

Parameters:

Name Type Description Default
data (DataFrame, ndarray, Dataset, DataArray)

Input data containing ROC points.

required
x_col str

Column name for POFD (False Alarm Rate).

'pofd'
y_col str

Column name for POD (Hit Rate).

'pod'
label_col str

Column for grouping different curves.

None
show_auc bool

Whether to calculate and append AUC to labels.

True
**kwargs

Matplotlib kwargs.

{}
Source code in src/monet_plots/plots/roc_curve.py
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
def plot(
    self,
    data: Any,
    x_col: str = "pofd",
    y_col: str = "pod",
    label_col: Optional[str] = None,
    show_auc: bool = True,
    **kwargs,
):
    """
    Main plotting method.

    Args:
        data (pd.DataFrame, np.ndarray, xr.Dataset, xr.DataArray): Input data containing ROC points.
        x_col (str): Column name for POFD (False Alarm Rate).
        y_col (str): Column name for POD (Hit Rate).
        label_col (str, optional): Column for grouping different curves.
        show_auc (bool): Whether to calculate and append AUC to labels.
        **kwargs: Matplotlib kwargs.

    Returns:
        matplotlib.axes.Axes: The axes the curves were drawn on.
    """
    df = to_dataframe(data)
    # TDD Anchor: Test validation raises error on missing cols
    validate_dataframe(df, required_columns=[x_col, y_col])

    # Draw No Skill Line
    self.ax.plot([0, 1], [0, 1], "k--", label="No Skill", alpha=0.5)
    self.ax.grid(True, alpha=0.3)

    if label_col:
        for name, group in df.groupby(label_col):
            self._plot_single_curve(
                group, x_col, y_col, label=str(name), show_auc=show_auc, **kwargs
            )
    else:
        self._plot_single_curve(
            df, x_col, y_col, label="Model", show_auc=show_auc, **kwargs
        )

    # Always draw the legend: it is where the AUC is reported, including
    # for a single, ungrouped curve.
    self.ax.legend(loc="lower right")

    # Formatting
    self.ax.set_xlim(0, 1)
    self.ax.set_ylim(0, 1)
    self.ax.set_xlabel("Probability of False Detection (POFD)")
    self.ax.set_ylabel("Probability of Detection (POD)")
    self.ax.set_aspect("equal")
    return self.ax

RankHistogramPlot

Bases: BasePlot

Rank Histogram (Talagrand Diagram).

Visualizes the distribution of observation ranks within an ensemble.

Functional Requirements: 1. Plot bar chart of rank frequencies. 2. Draw horizontal line for "Perfect Flatness" (uniform distribution). 3. Support normalizing frequencies (relative frequency) or raw counts. 4. Interpret shapes: U-shape (underdispersed), A-shape (overdispersed), Bias (slope).

Edge Cases: - Unequal ensemble sizes (requires binning or normalization logic, but typically preprocessing handles this). - Missing ranks (should be 0 height bars).

Source code in src/monet_plots/plots/rank_histogram.py
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
class RankHistogramPlot(BasePlot):
    """
    Rank Histogram (Talagrand Diagram).

    Visualizes the distribution of observation ranks within an ensemble.

    Functional Requirements:
    1. Plot bar chart of rank frequencies.
    2. Draw horizontal line for "Perfect Flatness" (uniform distribution).
    3. Support normalizing frequencies (relative frequency) or raw counts.
    4. Interpret shapes: U-shape (underdispersed), A-shape (overdispersed), Bias (slope).

    Edge Cases:
    - Unequal ensemble sizes (requires binning or normalization logic, but typically preprocessing handles this).
    - Missing ranks (should be 0 height bars).
    """

    def __init__(self, fig=None, ax=None, **kwargs):
        super().__init__(fig=fig, ax=ax, **kwargs)

    def plot(
        self,
        data: Any,
        rank_col: str = "rank",
        n_members: Optional[int] = None,
        label_col: Optional[str] = None,
        normalize: bool = True,
        **kwargs,
    ):
        """
        Main plotting method.

        Args:
            data (pd.DataFrame, np.ndarray, xr.Dataset, xr.DataArray): Data containing ranks (0 to n_members).
            rank_col (str): Column containing the rank of the observation.
            n_members (Optional[int]): Number of ensemble members (defines n_bins = n_members + 1).
                                      Inferred from max(rank) if None.
            label_col (Optional[str]): Grouping for multiple histograms (e.g., lead times).
            normalize (bool): If True, plot relative frequency; else raw counts.
            **kwargs: Matplotlib ``Axes.bar`` kwargs; ``alpha`` overrides the
                default of 0.7.

        Returns:
            matplotlib.axes.Axes: The axes the histogram was drawn on.
        """
        df = to_dataframe(data)
        validate_dataframe(df, required_columns=[rank_col])

        if n_members is None:
            n_members = int(df[rank_col].max())

        num_bins = n_members + 1
        bins = np.arange(num_bins)
        # Defaults that caller kwargs may override (a duplicate ``alpha`` would
        # otherwise raise TypeError).
        bar_kws = {"alpha": 0.7, **kwargs}

        # TDD Anchor: Validate inputs

        # Number of ranks actually binned for each histogram drawn.
        sample_sizes = []
        if label_col:
            for name, group in df.groupby(label_col):
                counts = self._rank_counts(group[rank_col], bins)
                total = counts.sum()
                sample_sizes.append(total)
                freq = counts / total if normalize else counts
                self.ax.bar(counts.index, freq.values, **{**bar_kws, "label": str(name)})
        else:
            counts = self._rank_counts(df[rank_col], bins)
            total = counts.sum()
            sample_sizes.append(total)
            freq = counts / total if normalize else counts
            self.ax.bar(counts.index, freq.values, **bar_kws)

        if normalize:
            expected = 1.0 / num_bins
        else:
            # Flat-histogram count per bin. It is derived from the ranks that
            # were binned (NaN or out-of-range ranks excluded), so it matches
            # the bars. When grouped, it is averaged over groups, because each
            # group's bars reflect only that group's sample size.
            expected = float(np.mean(sample_sizes)) / num_bins if sample_sizes else 0.0

        # Expected uniform line
        self.ax.axhline(
            expected, color="k", linestyle="--", linewidth=2, label="Expected (Uniform)"
        )
        self.ax.legend()

        # Formatting
        self.ax.set_xlabel("Rank")
        self.ax.set_ylabel("Relative Frequency" if normalize else "Count")
        self.ax.set_xticks(bins)
        self.ax.set_xlim(-0.5, n_members + 0.5)
        self.ax.grid(True, alpha=0.3)
        return self.ax

    @staticmethod
    def _rank_counts(ranks, bins):
        """Count occurrences of each rank in ``bins``; absent ranks count as 0."""
        return ranks.value_counts().reindex(bins, fill_value=0)

plot(data, rank_col='rank', n_members=None, label_col=None, normalize=True, **kwargs)

Main plotting method.

Parameters:

Name Type Description Default
data (DataFrame, ndarray, Dataset, DataArray)

Data containing ranks (0 to n_members).

required
rank_col str

Column containing the rank of the observation.

'rank'
n_members Optional[int]

Number of ensemble members (defines n_bins = n_members + 1). Inferred from max(rank) if None.

None
label_col Optional[str]

Grouping for multiple histograms (e.g., lead times).

None
normalize bool

If True, plot relative frequency; else raw counts.

True
**kwargs

Matplotlib kwargs.

{}
Source code in src/monet_plots/plots/rank_histogram.py
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
def plot(
    self,
    data: Any,
    rank_col: str = "rank",
    n_members: Optional[int] = None,
    label_col: Optional[str] = None,
    normalize: bool = True,
    **kwargs,
):
    """
    Main plotting method.

    Args:
        data (pd.DataFrame, np.ndarray, xr.Dataset, xr.DataArray): Data containing ranks (0 to n_members).
        rank_col (str): Column containing the rank of the observation.
        n_members (Optional[int]): Number of ensemble members (defines n_bins = n_members + 1).
                                  Inferred from max(rank) if None.
        label_col (Optional[str]): Grouping for multiple histograms (e.g., lead times).
        normalize (bool): If True, plot relative frequency; else raw counts.
        **kwargs: Matplotlib ``Axes.bar`` kwargs; ``alpha`` overrides the
            default of 0.7.

    Returns:
        matplotlib.axes.Axes: The axes the histogram was drawn on.
    """
    df = to_dataframe(data)
    validate_dataframe(df, required_columns=[rank_col])

    if n_members is None:
        n_members = int(df[rank_col].max())

    num_bins = n_members + 1
    bins = np.arange(num_bins)
    # Defaults that caller kwargs may override (a duplicate ``alpha`` would
    # otherwise raise TypeError).
    bar_kws = {"alpha": 0.7, **kwargs}

    # TDD Anchor: Validate inputs

    # Number of ranks actually binned for each histogram drawn.
    sample_sizes = []
    if label_col:
        for name, group in df.groupby(label_col):
            counts = group[rank_col].value_counts().reindex(bins, fill_value=0)
            total = counts.sum()
            sample_sizes.append(total)
            freq = counts / total if normalize else counts
            self.ax.bar(counts.index, freq.values, **{**bar_kws, "label": str(name)})
    else:
        counts = df[rank_col].value_counts().reindex(bins, fill_value=0)
        total = counts.sum()
        sample_sizes.append(total)
        freq = counts / total if normalize else counts
        self.ax.bar(counts.index, freq.values, **bar_kws)

    if normalize:
        expected = 1.0 / num_bins
    else:
        # Flat-histogram count per bin. It is derived from the ranks that were
        # binned (NaN or out-of-range ranks excluded) and averaged over groups,
        # because each group's bars reflect only that group's sample size.
        expected = float(np.mean(sample_sizes)) / num_bins if sample_sizes else 0.0

    # Expected uniform line
    self.ax.axhline(
        expected, color="k", linestyle="--", linewidth=2, label="Expected (Uniform)"
    )
    self.ax.legend()

    # Formatting
    self.ax.set_xlabel("Rank")
    self.ax.set_ylabel("Relative Frequency" if normalize else "Count")
    self.ax.set_xticks(bins)
    self.ax.set_xlim(-0.5, n_members + 0.5)
    self.ax.grid(True, alpha=0.3)
    return self.ax

RelativeEconomicValuePlot

Bases: BasePlot

Relative Economic Value (REV) Plot.

Visualizes the potential economic value of a forecast system relative to climatology.

Functional Requirements: 1. Plot Value (y-axis) vs Cost/Loss Ratio (x-axis). 2. Calculate REV based on Hits, Misses, False Alarms, Correct Negatives. 3. Support multiple models. 4. X-axis usually logarithmic or specific range [0, 1].

Edge Cases: - C/L ratio 0 or 1 (value is 0). - No events observed (metrics undefined).

Source code in src/monet_plots/plots/rev.py
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
class RelativeEconomicValuePlot(BasePlot):
    """
    Relative Economic Value (REV) Plot.

    Visualizes the potential economic value of a forecast system relative to climatology.

    Functional Requirements:
    1. Plot Value (y-axis) vs Cost/Loss Ratio (x-axis).
    2. Calculate REV based on Hits, Misses, False Alarms, Correct Negatives.
    3. Support multiple models.
    4. X-axis usually logarithmic or specific range [0, 1].

    Edge Cases:
    - C/L ratio 0 or 1 (value is 0).
    - No events observed (metrics undefined).
    """

    def __init__(self, fig=None, ax=None, **kwargs):
        super().__init__(fig=fig, ax=ax, **kwargs)

    def plot(
        self,
        data: Any,
        counts_cols: Optional[List[str]] = None,
        climatology: Optional[float] = None,
        label_col: Optional[str] = None,
        cost_loss_ratios: Optional[np.ndarray] = None,
        **kwargs,
    ):
        """
        Main plotting method.

        Args:
            data (pd.DataFrame, np.ndarray, xr.Dataset, xr.DataArray): Input data with contingency table counts.
            counts_cols (Optional[List[str]]): Contingency table columns [hits, misses, fa, cn].
                Defaults to ["hits", "misses", "fa", "cn"].
            climatology (Optional[float]): Sample climatology (base rate). Computed
                from the pooled counts if None (0.5 when all counts are zero).
            label_col (Optional[str]): Grouping column for multiple curves.
            cost_loss_ratios (Optional[np.ndarray]): Array of C/L ratios. Default linspace(0.001,0.999,100).
            **kwargs: Matplotlib kwargs.

        Returns:
            matplotlib.axes.Axes: The axes the curves were drawn on.
        """
        # Build the default per call instead of sharing a mutable default list.
        if counts_cols is None:
            counts_cols = ["hits", "misses", "fa", "cn"]

        df = to_dataframe(data)
        validate_dataframe(df, required_columns=counts_cols)

        if climatology is None:
            total_events = df[counts_cols[0]].sum() + df[counts_cols[1]].sum()
            total = total_events + df[counts_cols[2]].sum() + df[counts_cols[3]].sum()
            climatology = total_events / total if total > 0 else 0.5

        if cost_loss_ratios is None:
            cost_loss_ratios = np.linspace(0.001, 0.999, 100)

        # TDD Anchor: Test REV calculation logic

        if label_col:
            for name, group in df.groupby(label_col):
                rev_values = self._calculate_rev(
                    group, counts_cols, cost_loss_ratios, climatology
                )
                self.ax.plot(cost_loss_ratios, rev_values, label=str(name), **kwargs)
        else:
            rev_values = self._calculate_rev(
                df, counts_cols, cost_loss_ratios, climatology
            )
            self.ax.plot(cost_loss_ratios, rev_values, label="Model", **kwargs)

        self.ax.set_xlabel("Cost/Loss Ratio")
        self.ax.set_ylabel("Relative Economic Value (REV)")
        self.ax.set_ylim(-0.2, 1.05)
        self.ax.axhline(0, color="k", linestyle="--", alpha=0.7, label="Climatology")
        self.ax.axhline(1, color="gray", linestyle=":", alpha=0.7, label="Perfect")
        # One legend after all labelled artists exist (curves and reference lines).
        self.ax.legend()
        self.ax.grid(True, alpha=0.3)
        return self.ax

    def _calculate_rev(self, df, cols, ratios, clim):
        """
        Sum the contingency-table columns of ``df`` and compute REV per C/L ratio.

        ``cols`` is ordered [hits, misses, fa, cn]; the computation is
        delegated to ``compute_rev``.
        """
        hits = df[cols[0]].sum()
        misses = df[cols[1]].sum()
        fa = df[cols[2]].sum()
        cn = df[cols[3]].sum()
        return compute_rev(hits, misses, fa, cn, ratios, clim)

plot(data, counts_cols=['hits', 'misses', 'fa', 'cn'], climatology=None, label_col=None, cost_loss_ratios=None, **kwargs)

Main plotting method.

Parameters:

Name Type Description Default
data (DataFrame, ndarray, Dataset, DataArray)

Input data with contingency table counts.

required
counts_cols List[str]

Contingency table columns [hits, misses, fa, cn].

['hits', 'misses', 'fa', 'cn']
climatology Optional[float]

Sample climatology (base rate). Computed if None.

None
label_col Optional[str]

Grouping column for multiple curves.

None
cost_loss_ratios Optional[ndarray]

Array of C/L ratios. Default linspace(0.001,0.999,100).

None
**kwargs

Matplotlib kwargs.

{}
Source code in src/monet_plots/plots/rev.py
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
def plot(
    self,
    data: Any,
    counts_cols: List[str] = ["hits", "misses", "fa", "cn"],
    climatology: Optional[float] = None,
    label_col: Optional[str] = None,
    cost_loss_ratios: Optional[np.ndarray] = None,
    **kwargs,
):
    """
    Main plotting method.

    Draws Relative Economic Value (REV) against the cost/loss ratio. The
    output is either a single curve or one curve per ``label_col`` group,
    plus reference lines at REV = 0 ("Climatology") and REV = 1 ("Perfect").

    Args:
        data (pd.DataFrame, np.ndarray, xr.Dataset, xr.DataArray): Input data with contingency table counts.
        counts_cols (List[str]): Contingency table columns [hits, misses, fa, cn].
            The list is only read, never mutated.
        climatology (Optional[float]): Sample climatology (base rate). If None it is
            computed as (hits + misses) / total, or 0.5 when the total is 0.
        label_col (Optional[str]): Grouping column for multiple curves.
        cost_loss_ratios (Optional[np.ndarray]): Array of C/L ratios. Default linspace(0.001,0.999,100).
        **kwargs: Matplotlib kwargs forwarded to ``Axes.plot``. A ``label``
            entry names the single ungrouped curve (default "Model"). It is
            ignored when ``label_col`` is set, because each group is labelled
            by its key.
    """
    df = to_dataframe(data)
    validate_dataframe(df, required_columns=counts_cols)

    if climatology is None:
        total_events = df[counts_cols[0]].sum() + df[counts_cols[1]].sum()
        total = total_events + df[counts_cols[2]].sum() + df[counts_cols[3]].sum()
        climatology = total_events / total if total > 0 else 0.5

    if cost_loss_ratios is None:
        cost_loss_ratios = np.linspace(0.001, 0.999, 100)

    # ``label`` is passed explicitly below. Leaving it in the forwarded kwargs
    # would raise "got multiple values for keyword argument 'label'".
    line_kwargs = dict(kwargs)
    single_label = line_kwargs.pop("label", "Model")

    # TDD Anchor: Test REV calculation logic

    if label_col:
        for name, group in df.groupby(label_col):
            rev_values = self._calculate_rev(
                group, counts_cols, cost_loss_ratios, climatology
            )
            self.ax.plot(
                cost_loss_ratios, rev_values, label=str(name), **line_kwargs
            )
    else:
        rev_values = self._calculate_rev(
            df, counts_cols, cost_loss_ratios, climatology
        )
        self.ax.plot(cost_loss_ratios, rev_values, label=single_label, **line_kwargs)

    self.ax.set_xlabel("Cost/Loss Ratio")
    self.ax.set_ylabel("Relative Economic Value (REV)")
    self.ax.set_ylim(-0.2, 1.05)
    self.ax.axhline(0, color="k", linestyle="--", alpha=0.7, label="Climatology")
    self.ax.axhline(1, color="gray", linestyle=":", alpha=0.7, label="Perfect")
    # One legend call after every artist exists. Any earlier legend would be
    # replaced by this one anyway.
    self.ax.legend()
    self.ax.grid(True, alpha=0.3)

ReliabilityDiagramPlot

Bases: BasePlot

Reliability Diagram Plot (Attributes Diagram).

Visualizes Observed Frequency vs Forecast Probability.

Functional Requirements:

1. Plot Observed Frequency (y-axis) vs Forecast Probability (x-axis).
2. Draw "Perfect Reliability" diagonal (1:1).
3. Draw "No Skill" line (horizontal at climatology/sample mean).
4. Shade "Skill" areas (where Brier Skill Score > 0).
5. Include inset histogram of forecast usage (Sharpness) if requested.

Edge Cases:

- Empty bins (no forecasts with that probability).
- Climatology not provided (cannot draw skill regions correctly).

Source code in `src/monet_plots/plots/reliability_diagram.py` (lines 11–129)
class ReliabilityDiagramPlot(BasePlot):
    """
    Reliability Diagram Plot (Attributes Diagram).

    Visualizes Observed Frequency vs Forecast Probability.

    Functional Requirements:
    1. Plot Observed Frequency (y-axis) vs Forecast Probability (x-axis).
    2. Draw "Perfect Reliability" diagonal (1:1).
    3. Draw the climatology (no-resolution) line, horizontal at the
       climatology/sample mean.
    4. Shade "Skill" areas, where a bin's contribution to the Brier Skill
       Score is > 0. This is the region on the diagonal's side of the
       no-skill line y = (x + climatology) / 2, bounded by the vertical
       x = climatology.
    5. Include inset histogram of forecast usage (Sharpness) if requested.

    Edge Cases:
    - Empty bins (no forecasts with that probability).
    - Climatology not provided (cannot draw skill regions correctly).
    """

    def __init__(self, fig=None, ax=None, **kwargs):
        super().__init__(fig=fig, ax=ax, **kwargs)

    def plot(
        self,
        data: Any,
        x_col: str = "prob",
        y_col: str = "freq",
        forecasts_col: Optional[str] = None,
        observations_col: Optional[str] = None,
        n_bins: int = 10,
        climatology: Optional[float] = None,
        label_col: Optional[str] = None,
        show_hist: bool = False,
        **kwargs,
    ):
        """
        Main plotting method.

        Two input modes are supported:

        * Raw: when both ``forecasts_col`` and ``observations_col`` are given,
          forecasts are binned with ``compute_reliability_curve``. Binning is
          done per group when ``label_col`` is set, and a ``count`` column is
          produced.
        * Pre-binned: otherwise ``x_col`` and ``y_col`` are read directly.

        Args:
            data: Input data.
            x_col (str): Forecast Probability bin center (for pre-binned).
            y_col (str): Observed Frequency in bin (for pre-binned).
            forecasts_col (str, optional): Column of raw forecast probabilities [0,1].
            observations_col (str, optional): Column of binary observations {0,1}.
            n_bins (int): Number of bins for reliability curve computation.
            climatology (Optional[float]): Sample climatology (mean(observations)).
                In raw mode it defaults to the mean of ``observations_col`` over
                the whole sample, shared by all groups. When it is None, no
                climatology line or skill region is drawn.
            label_col (str, optional): Grouping column; one curve per group.
            show_hist (bool): Whether to show frequency of usage histogram.
                It is only drawn when the plotted data has a ``count`` column.
            **kwargs: Matplotlib kwargs for the curve(s). A ``label`` entry
                names the single ungrouped curve and is ignored when
                ``label_col`` is set.
        """
        df = to_dataframe(data)
        # Compute if raw data provided
        if forecasts_col and observations_col:
            if climatology is None:
                climatology = float(df[observations_col].mean())

            def _bin(frame):
                # Bin one frame of raw forecasts/observations into a curve.
                bin_centers, obs_freq, bin_counts = compute_reliability_curve(
                    np.asarray(frame[forecasts_col]),
                    np.asarray(frame[observations_col]),
                    n_bins,
                )
                return pd.DataFrame(
                    {x_col: bin_centers, y_col: obs_freq, "count": bin_counts}
                )

            if label_col:
                # The binned frame has no label column of its own. Bin each
                # group separately and tag its rows; otherwise the groupby
                # below would raise KeyError.
                pieces = []
                for name, group in df.groupby(label_col):
                    binned = _bin(group)
                    binned[label_col] = name
                    pieces.append(binned)
                if pieces:
                    plot_data = pd.concat(pieces, ignore_index=True)
                else:
                    plot_data = pd.DataFrame(columns=[x_col, y_col, "count", label_col])
            else:
                plot_data = _bin(df)
        else:
            validate_dataframe(df, required_columns=[x_col, y_col])
            plot_data = df

        # Draw Reference Lines
        self.ax.plot([0, 1], [0, 1], "k--", label="Perfect Reliability")
        if climatology is not None:
            self.ax.axhline(
                climatology, color="gray", linestyle=":", label="Climatology"
            )
            self._draw_skill_regions(climatology)

        # Plot Data
        if label_col:
            for name, group in plot_data.groupby(label_col):
                # pop label from kwargs if it exists to avoid multiple values
                k = kwargs.copy()
                k.pop("label", None)
                self.ax.plot(group[x_col], group[y_col], marker="o", label=name, **k)
        else:
            k = kwargs.copy()
            label = k.pop("label", "Model")
            self.ax.plot(
                plot_data[x_col], plot_data[y_col], marker="o", label=label, **k
            )

        # Histogram Overlay (Sharpness)
        if show_hist and "count" in plot_data.columns:
            self._add_sharpness_histogram(plot_data, x_col)

        # Formatting
        self.ax.set_xlim(0, 1)
        self.ax.set_ylim(0, 1)
        self.ax.set_xlabel("Forecast Probability")
        self.ax.set_ylabel("Observed Relative Frequency")
        self.ax.set_aspect("equal")
        self.ax.grid(True, alpha=0.3)
        self.ax.legend()

    def _draw_skill_regions(self, clim):
        """
        Shade the attributes-diagram skill region (positive BSS contribution).

        A bin at forecast probability ``x`` with observed frequency ``y``
        adds to the Brier Skill Score when its reliability term
        ``(y - x)**2`` is smaller than its resolution term ``(y - clim)**2``.
        That holds when ``y`` lies on the diagonal's side of the no-skill line
        ``y = (x + clim) / 2``: above it for ``x >= clim`` and below it for
        ``x <= clim``.
        """
        c = float(np.clip(clim, 0.0, 1.0))
        # Every boundary is a straight line, so the end points suffice.
        x_low = np.array([0.0, c])
        x_high = np.array([c, 1.0])
        self.ax.fill_between(
            x_low, 0.0, (x_low + c) / 2.0, alpha=0.1, color="green", label="Skill Region"
        )
        # Second half is unlabeled so the legend shows a single entry.
        self.ax.fill_between(x_high, (x_high + c) / 2.0, 1.0, alpha=0.1, color="green")

    def _add_sharpness_histogram(self, data, x_col):
        """
        Adds a small inset axes for sharpness histogram.

        Each bar's width is 80% of the smallest spacing between distinct bin
        centers, which gives 0.08 for the default 10 bins. This keeps bars
        from overlapping when ``n_bins`` is larger. When fewer than two
        finite centers exist, the width falls back to 0.08.
        """
        from mpl_toolkits.axes_grid1.inset_locator import inset_axes

        centers = np.unique(np.asarray(data[x_col], dtype=float))
        centers = centers[np.isfinite(centers)]
        width = 0.8 * float(np.diff(centers).min()) if centers.size > 1 else 0.08

        inset_ax = inset_axes(self.ax, width=1.5, height=1.2, loc="upper right")
        inset_ax.bar(data[x_col], data["count"], alpha=0.5, color="blue", width=width)
        inset_ax.set_title("Sharpness")
        inset_ax.set_xlabel(x_col)
        inset_ax.set_ylabel("Count")

plot(data, x_col='prob', y_col='freq', forecasts_col=None, observations_col=None, n_bins=10, climatology=None, label_col=None, show_hist=False, **kwargs)

Main plotting method.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `data` | `Any` | Input data. | *required* |
| `x_col` | `str` | Forecast Probability bin center (for pre-binned data). | `'prob'` |
| `y_col` | `str` | Observed Frequency in bin (for pre-binned data). | `'freq'` |
| `forecasts_col` | `str` | Column of raw forecast probabilities in [0, 1]. | `None` |
| `observations_col` | `str` | Column of binary observations {0, 1}. | `None` |
| `n_bins` | `int` | Number of bins for reliability curve computation. | `10` |
| `climatology` | `Optional[float]` | Sample climatology (mean of the observations). | `None` |
| `label_col` | `str` | Grouping column. | `None` |
| `show_hist` | `bool` | Whether to show the frequency-of-usage histogram. | `False` |
| `**kwargs` | | Matplotlib kwargs. | `{}` |
Source code in `src/monet_plots/plots/reliability_diagram.py` (lines 32–108)
def plot(
    self,
    data: Any,
    x_col: str = "prob",
    y_col: str = "freq",
    forecasts_col: Optional[str] = None,
    observations_col: Optional[str] = None,
    n_bins: int = 10,
    climatology: Optional[float] = None,
    label_col: Optional[str] = None,
    show_hist: bool = False,
    **kwargs,
):
    """
    Main plotting method.

    Two input modes are supported:

    * Raw: when both ``forecasts_col`` and ``observations_col`` are given,
      forecasts are binned with ``compute_reliability_curve``. Binning is
      done per group when ``label_col`` is set, and a ``count`` column is
      produced.
    * Pre-binned: otherwise ``x_col`` and ``y_col`` are read directly.

    Args:
        data: Input data.
        x_col (str): Forecast Probability bin center (for pre-binned).
        y_col (str): Observed Frequency in bin (for pre-binned).
        forecasts_col (str, optional): Column of raw forecast probabilities [0,1].
        observations_col (str, optional): Column of binary observations {0,1}.
        n_bins (int): Number of bins for reliability curve computation.
        climatology (Optional[float]): Sample climatology (mean(observations)).
            In raw mode it defaults to the mean of ``observations_col`` over
            the whole sample, shared by all groups. When it is None, no
            climatology line or skill region is drawn.
        label_col (str, optional): Grouping column; one curve per group.
        show_hist (bool): Whether to show frequency of usage histogram.
            It is only drawn when the plotted data has a ``count`` column.
        **kwargs: Matplotlib kwargs for the curve(s). A ``label`` entry names
            the single ungrouped curve and is ignored when ``label_col`` is set.
    """
    df = to_dataframe(data)
    # Compute if raw data provided
    if forecasts_col and observations_col:
        if climatology is None:
            climatology = float(df[observations_col].mean())

        def _bin(frame):
            # Bin one frame of raw forecasts/observations into a curve.
            bin_centers, obs_freq, bin_counts = compute_reliability_curve(
                np.asarray(frame[forecasts_col]),
                np.asarray(frame[observations_col]),
                n_bins,
            )
            return pd.DataFrame(
                {x_col: bin_centers, y_col: obs_freq, "count": bin_counts}
            )

        if label_col:
            # The binned frame has no label column of its own. Bin each group
            # separately and tag its rows; otherwise the groupby below would
            # raise KeyError.
            pieces = []
            for name, group in df.groupby(label_col):
                binned = _bin(group)
                binned[label_col] = name
                pieces.append(binned)
            if pieces:
                plot_data = pd.concat(pieces, ignore_index=True)
            else:
                plot_data = pd.DataFrame(columns=[x_col, y_col, "count", label_col])
        else:
            plot_data = _bin(df)
    else:
        validate_dataframe(df, required_columns=[x_col, y_col])
        plot_data = df

    # Draw Reference Lines
    self.ax.plot([0, 1], [0, 1], "k--", label="Perfect Reliability")
    if climatology is not None:
        self.ax.axhline(
            climatology, color="gray", linestyle=":", label="Climatology"
        )
        self._draw_skill_regions(climatology)

    # Plot Data
    if label_col:
        for name, group in plot_data.groupby(label_col):
            # pop label from kwargs if it exists to avoid multiple values
            k = kwargs.copy()
            k.pop("label", None)
            self.ax.plot(group[x_col], group[y_col], marker="o", label=name, **k)
    else:
        k = kwargs.copy()
        label = k.pop("label", "Model")
        self.ax.plot(
            plot_data[x_col], plot_data[y_col], marker="o", label=label, **k
        )

    # Histogram Overlay (Sharpness)
    if show_hist and "count" in plot_data.columns:
        self._add_sharpness_histogram(plot_data, x_col)

    # Formatting
    self.ax.set_xlim(0, 1)
    self.ax.set_ylim(0, 1)
    self.ax.set_xlabel("Forecast Probability")
    self.ax.set_ylabel("Observed Relative Frequency")
    self.ax.set_aspect("equal")
    self.ax.grid(True, alpha=0.3)
    self.ax.legend()

RidgelinePlot

Bases: BasePlot

Creates a ridgeline plot (joyplot) from an xarray DataArray or pandas DataFrame.

A ridgeline plot shows the distribution of a numeric value for several groups. Each group has its own distribution curve, often overlapping with others.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| `data` | `DataArray \| Dataset \| DataFrame` | Normalized input data. |
| `group_dim` | `str` | The dimension or column to group by for the Y-axis. |
| `x` | `str \| None` | The column name for values if data is a DataFrame or Dataset. |
| `x_range` | `tuple \| None` | Tuple `(min, max)` for the x-axis limits. |
| `scale_factor` | `float` | Height scaling of the curves. |
| `overlap` | `float` | Vertical spacing between curves. |
| `cmap_name` | `str` | Colormap name for coloring curves. |
| `title` | `str \| None` | Plot title. |

Source code in `src/monet_plots/plots/ridgeline.py` (lines 20–308)
class RidgelinePlot(BasePlot):
    """
    Creates a ridgeline plot (joyplot) from an xarray DataArray or pandas DataFrame.

    A ridgeline plot shows the distribution of a numeric value for several groups.
    Each group has its own distribution curve, often overlapping with others.

    Attributes:
        data (xr.DataArray | xr.Dataset | pd.DataFrame): Normalized input data.
        group_dim (str): The dimension or column to group by for the Y-axis.
        x (str | None): The column name for values if data is a DataFrame or Dataset.
        x_range (tuple | None): Tuple (min, max) for the x-axis limits.
        scale_factor (float): Height scaling of the curves.
        overlap (float): Vertical spacing between curves.
        cmap_name (str): Colormap name for coloring curves.
        title (str | None): Plot title.
    """

    def __init__(
        self,
        data: Any,
        group_dim: str,
        x: Optional[str] = None,
        *,
        x_range: Optional[Tuple[float, float]] = None,
        scale_factor: float = 1.0,
        overlap: float = 0.5,
        cmap: str = "RdBu_r",
        title: Optional[str] = None,
        bw_method: Optional[Any] = None,
        alpha: float = 0.8,
        quantiles: Optional[list[float]] = None,
        **kwargs: Any,
    ):
        """
        Initializes the ridgeline plot with data and settings.

        Args:
            data (Any): The data to plot (xr.DataArray, xr.Dataset, or pd.DataFrame).
            group_dim (str): The dimension or column to group by for the Y-axis.
            x (str, optional): The variable/column to plot distributions of.
                Required if data is a Dataset or DataFrame with multiple variables.
            x_range (tuple[float, float], optional): Tuple (min, max) for the x-axis limits.
                If None, auto-calculated.
            scale_factor (float): Height scaling of the curves. Defaults to 1.0.
            overlap (float): Vertical spacing between curves. Higher values mean more overlap.
                Defaults to 0.5.
            cmap (str): Colormap name for coloring curves. Defaults to 'RdBu_r'.
            title (str, optional): Plot title.
            bw_method (Any, optional): KDE bandwidth method (passed to scipy.stats.gaussian_kde).
            alpha (float): Transparency of the ridges. Defaults to 0.8.
            quantiles (list[float], optional): List of quantiles to display (e.g., [0.5]).
            **kwargs: Additional keyword arguments for BasePlot (figure/axes creation).
        """
        super().__init__(**kwargs)
        if self.ax is None:
            self.ax = self.fig.add_subplot(1, 1, 1)

        self.data = normalize_data(data, prefer_xarray=False)
        self.group_dim = group_dim
        self.x = x
        self.x_range = x_range
        self.scale_factor = scale_factor
        self.overlap = overlap
        self.cmap_name = cmap
        self.title = title
        self.bw_method = bw_method
        self.alpha = alpha
        self.quantiles = quantiles

    def plot(
        self, gradient: bool = True, color_by_group: bool = False, **kwargs: Any
    ) -> matplotlib.axes.Axes:
        """
        Generate the ridgeline plot.

        One KDE ridge is drawn per group, ordered by descending group value
        from the top of the axes downwards. Non-finite values (NaN and ±inf)
        are ignored, both when estimating the automatic x-range and when
        fitting each ridge's KDE. A group with fewer than two finite values,
        or whose KDE cannot be fitted, is skipped but keeps its y-tick label.

        Args:
            gradient (bool): If True, fill curves with a gradient based on x-values.
            color_by_group (bool): If True, color each ridge by its group category.
                Takes precedence over gradient if True.
            **kwargs: Accepted for API compatibility; currently unused.

        Returns:
            matplotlib.axes.Axes: The axes object containing the plot.

        Raises:
            ValueError: If no finite data exist (with automatic range), if
                ``x_range`` contains NaN, or if a DataFrame has no numeric
                column to plot.
        """
        import matplotlib.pyplot as plt

        from ..verification_metrics import _update_history

        is_xarray = isinstance(self.data, (xr.DataArray, xr.Dataset))

        # 1. Prepare Data and Groups
        if is_xarray:
            if isinstance(self.data, xr.Dataset):
                if self.x is None:
                    self.x = list(self.data.data_vars)[0]
                da = self.data[self.x]
                fallback_name = self.x
            else:
                da = self.data
                fallback_name = "Value"

            da_sorted = da.sortby(self.group_dim, ascending=False)
            groups = da_sorted[self.group_dim].values
            data_name = str(da.name) if da.name else fallback_name

            if self.x_range is None:
                # Mask ±inf so a single infinite value cannot make the grid NaN.
                finite = da.where(np.isfinite(da))
                vmin = float(finite.min().compute())
                vmax = float(finite.max().compute())
            else:
                vmin, vmax = self.x_range
        else:
            # Pandas DataFrame
            df = self.data
            if self.x is None:
                # Try to find a numeric column that is not group_dim
                numeric_cols = df.select_dtypes(include=[np.number]).columns
                cols_to_use = [c for c in numeric_cols if c != self.group_dim]
                if not cols_to_use:
                    raise ValueError("No numeric columns found in DataFrame to plot.")
                self.x = cols_to_use[0]

            df_sorted = df.sort_values(self.group_dim, ascending=False)
            groups = df_sorted[self.group_dim].unique()
            data_name = str(self.x)

            if self.x_range is None:
                # Treat ±inf as missing so min/max see only finite values.
                finite = df[self.x].replace([np.inf, -np.inf], np.nan)
                vmin = float(finite.min())
                vmax = float(finite.max())
            else:
                vmin, vmax = self.x_range

        if np.isnan(vmin) or np.isnan(vmax):
            raise ValueError("No valid data points found to plot.")

        # Setup X-axis grid for density calculation (10% padding when auto-ranged)
        if self.x_range is None:
            pad = (vmax - vmin) * 0.1
            x_grid = np.linspace(vmin - pad, vmax + pad, 200)
        else:
            x_grid = np.linspace(vmin, vmax, 200)

        # Setup Colors
        cmap, norm = get_linear_scale(None, cmap=self.cmap_name, vmin=vmin, vmax=vmax)

        n_groups = len(groups)

        # 2. Iterate and Plot
        for i, val in enumerate(groups):
            if is_xarray:
                data_slice = da_sorted.sel({self.group_dim: val}).values.flatten()
            else:
                data_slice = df_sorted[df_sorted[self.group_dim] == val][
                    self.x
                ].values.flatten()

            # gaussian_kde cannot fit NaN or ±inf samples.
            data_slice = data_slice[np.isfinite(data_slice)]

            if len(data_slice) < 2:
                continue

            try:
                kde = gaussian_kde(data_slice, bw_method=self.bw_method)
                y_density = kde(x_grid)
            except (np.linalg.LinAlgError, ValueError):
                continue

            # Scale density and stack the ridge on its own baseline.
            baseline = -i * self.overlap
            y_final = baseline + y_density * self.scale_factor
            # The first (top) ridge gets the highest zorder.
            zorder = n_groups - i

            if color_by_group or not gradient:
                if color_by_group:
                    # Qualitative color per group index.
                    face = plt.get_cmap("tab10")(i % 10)
                else:
                    # Single color based on the mean of this slice.
                    face = cmap(norm(np.mean(data_slice)))
                self.ax.fill_between(
                    x_grid,
                    baseline,
                    y_final,
                    facecolor=face,
                    edgecolor="white",
                    linewidth=0.5,
                    alpha=self.alpha,
                    zorder=zorder,
                )
            else:
                # Plot in segments to create a gradient effect.
                for j, x_left in enumerate(x_grid[:-1]):
                    self.ax.fill_between(
                        x_grid[j : j + 2],
                        baseline,
                        y_final[j : j + 2],
                        facecolor=cmap(norm(x_left)),
                        edgecolor="none",
                        alpha=self.alpha,
                        zorder=zorder,
                    )

            # Clean top outline, shared by all fill modes.
            self.ax.plot(
                x_grid,
                y_final,
                color="black",
                linewidth=0.5,
                zorder=zorder + 0.1,
            )

            # 3. Add Quantiles
            if self.quantiles is not None:
                q_values = np.quantile(data_slice, self.quantiles)
                q_densities = kde(q_values) * self.scale_factor
                for qv, qd in zip(q_values, q_densities):
                    self.ax.vlines(
                        qv,
                        baseline,
                        baseline + qd,
                        color="black",
                        linestyle="--",
                        linewidth=0.8,
                        zorder=zorder + 0.2,
                    )

        # 4. Final Formatting
        self.ax.set_yticks([-i * self.overlap for i in range(n_groups)])
        self.ax.set_yticklabels(groups)
        self.ax.set_xlabel(data_name)
        if self.title:
            self.ax.set_title(self.title, pad=20)

        # Add Colorbar matching the x-axis scale if not coloring by group
        if not color_by_group:
            mappable = plt.cm.ScalarMappable(norm=norm, cmap=cmap)
            self.add_colorbar(mappable, label=data_name)

        # Add vertical gridlines as seen in the reference
        self.ax.xaxis.grid(True, linestyle="-", alpha=0.3)

        # Add a vertical zero line if the range crosses zero
        if x_grid.min() < 0 and x_grid.max() > 0:
            self.ax.axvline(0, color="black", alpha=0.3, linestyle="--", linewidth=1)

        # Remove unnecessary spines
        self.ax.spines["top"].set_visible(False)
        self.ax.spines["right"].set_visible(False)
        self.ax.spines["left"].set_visible(False)

        if is_xarray:
            _update_history(self.data, f"Created ridgeline plot for {data_name}")

        return self.ax

__init__(data, group_dim, x=None, *, x_range=None, scale_factor=1.0, overlap=0.5, cmap='RdBu_r', title=None, bw_method=None, alpha=0.8, quantiles=None, **kwargs)

Initializes the ridgeline plot with data and settings.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `data` | `Any` | The data to plot (`xr.DataArray`, `xr.Dataset`, or `pd.DataFrame`). | *required* |
| `group_dim` | `str` | The dimension or column to group by for the Y-axis. | *required* |
| `x` | `str` | The variable/column to plot distributions of. Required if data is a Dataset or DataFrame with multiple variables. | `None` |
| `x_range` | `tuple[float, float]` | Tuple `(min, max)` for the x-axis limits. If `None`, auto-calculated. | `None` |
| `scale_factor` | `float` | Height scaling of the curves. | `1.0` |
| `overlap` | `float` | Vertical spacing between curves. Higher values mean more overlap. | `0.5` |
| `cmap` | `str` | Colormap name for coloring curves. | `'RdBu_r'` |
| `title` | `str` | Plot title. | `None` |
| `bw_method` | `Any` | KDE bandwidth method (passed to `scipy.stats.gaussian_kde`). | `None` |
| `alpha` | `float` | Transparency of the ridges. | `0.8` |
| `quantiles` | `list[float]` | List of quantiles to display (e.g., `[0.5]`). | `None` |
| `**kwargs` | `Any` | Additional keyword arguments for `BasePlot` (figure/axes creation). | `{}` |
Source code in `src/monet_plots/plots/ridgeline.py` (lines 38–88)
def __init__(
    self,
    data: Any,
    group_dim: str,
    x: Optional[str] = None,
    *,
    x_range: Optional[Tuple[float, float]] = None,
    scale_factor: float = 1.0,
    overlap: float = 0.5,
    cmap: str = "RdBu_r",
    title: Optional[str] = None,
    bw_method: Optional[Any] = None,
    alpha: float = 0.8,
    quantiles: Optional[list[float]] = None,
    **kwargs: Any,
):
    """
    Set up the figure/axes and store the ridgeline configuration.

    Args:
        data (Any): Source data (xr.DataArray, xr.Dataset, or pd.DataFrame),
            normalized with ``normalize_data(..., prefer_xarray=False)``.
        group_dim (str): Dimension or column whose values become the ridges
            on the Y-axis.
        x (str, optional): Variable/column whose distribution is drawn.
            Needed for a Dataset or a DataFrame with several variables.
        x_range (tuple[float, float], optional): Explicit (min, max) for the
            x-axis. Computed from the data when None.
        scale_factor (float): Multiplier on ridge heights. Defaults to 1.0.
        overlap (float): Vertical distance between consecutive baselines.
            Larger values mean more overlap. Defaults to 0.5.
        cmap (str): Colormap name used for coloring. Defaults to 'RdBu_r'.
        title (str, optional): Title shown above the axes.
        bw_method (Any, optional): Bandwidth passed to scipy.stats.gaussian_kde.
        alpha (float): Ridge fill transparency. Defaults to 0.8.
        quantiles (list[float], optional): Quantiles to mark on each ridge,
            e.g. [0.5].
        **kwargs: Forwarded to BasePlot for figure/axes creation.
    """
    # BasePlot builds the figure; add a single subplot if no axes came back.
    super().__init__(**kwargs)
    if self.ax is None:
        self.ax = self.fig.add_subplot(1, 1, 1)

    self.data = normalize_data(data, prefer_xarray=False)

    # What to plot.
    self.group_dim, self.x = group_dim, x
    # Ridge geometry.
    self.x_range, self.scale_factor, self.overlap = x_range, scale_factor, overlap
    # Appearance and annotation.
    self.cmap_name, self.title = cmap, title
    self.bw_method, self.alpha, self.quantiles = bw_method, alpha, quantiles

plot(gradient=True, color_by_group=False, **kwargs)

Generate the ridgeline plot.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `gradient` | `bool` | If `True`, fill curves with a gradient based on x-values. | `True` |
| `color_by_group` | `bool` | If `True`, color each ridge by its group category. Takes precedence over `gradient`. | `False` |
| `**kwargs` | `Any` | Additional keyword arguments for formatting. | `{}` |

Returns:

| Type | Description |
| --- | --- |
| `Axes` | The `matplotlib.axes.Axes` object containing the plot. |

Source code in `src/monet_plots/plots/ridgeline.py` (lines 90–308)
def plot(
    self, gradient: bool = True, color_by_group: bool = False, **kwargs: Any
) -> matplotlib.axes.Axes:
    """
    Draw one kernel-density "ridge" per group, stacked vertically.

    Groups are visited in descending order of ``self.group_dim``. Ridge ``k``
    sits on the baseline ``-k * self.overlap`` and its density is multiplied
    by ``self.scale_factor``. Groups with fewer than two finite values, or
    whose KDE cannot be built, are skipped (their tick label is still drawn).

    Args:
        gradient (bool): If True, fill curves with a gradient based on x-values.
        color_by_group (bool): If True, color each ridge by its group category.
            Takes precedence over gradient if True.
        **kwargs: Additional keyword arguments for formatting (not currently
            read by this method).

    Returns:
        matplotlib.axes.Axes: The axes object containing the plot.

    Raises:
        ValueError: If a DataFrame has no numeric column to plot, or if the
            resolved value range contains NaN.
    """
    import matplotlib.pyplot as plt

    from ..verification_metrics import _update_history

    is_xarray = isinstance(self.data, (xr.DataArray, xr.Dataset))

    # 1. Resolve the plotted variable, its display name and the group order.
    if is_xarray:
        if isinstance(self.data, xr.DataArray):
            da = self.data
            data_name = str(da.name) if da.name else "Value"
        else:
            if self.x is None:
                self.x = list(self.data.data_vars)[0]
            da = self.data[self.x]
            data_name = str(da.name) if da.name else self.x
        da_sorted = da.sortby(self.group_dim, ascending=False)
        groups = da_sorted[self.group_dim].values
    else:
        # Pandas DataFrame
        df = self.data
        if self.x is None:
            # Pick the first numeric column that is not the grouping column.
            candidates = [
                col
                for col in df.select_dtypes(include=[np.number]).columns
                if col != self.group_dim
            ]
            if not candidates:
                raise ValueError("No numeric columns found in DataFrame to plot.")
            self.x = candidates[0]
        df_sorted = df.sort_values(self.group_dim, ascending=False)
        groups = df_sorted[self.group_dim].unique()
        data_name = str(self.x)

    # Value range: explicit x_range wins, otherwise the data extremes.
    if self.x_range is not None:
        vmin, vmax = self.x_range
    elif is_xarray:
        vmin = float(da.min().compute())
        vmax = float(da.max().compute())
    else:
        vmin = float(df[self.x].min())
        vmax = float(df[self.x].max())

    if np.isnan(vmin) or np.isnan(vmax):
        raise ValueError("No valid data points found to plot.")

    # Density evaluation grid; pad by 10% only when the range was inferred.
    if self.x_range is None:
        margin = (vmax - vmin) * 0.1
        grid_lo, grid_hi = vmin - margin, vmax + margin
    else:
        grid_lo, grid_hi = vmin, vmax
    x_grid = np.linspace(grid_lo, grid_hi, 200)

    cmap, norm = get_linear_scale(None, cmap=self.cmap_name, vmin=vmin, vmax=vmax)

    # 2. One ridge per group.
    n_groups = len(groups)
    for idx, group in enumerate(groups):
        if is_xarray:
            values = da_sorted.sel({self.group_dim: group}).values.flatten()
        else:
            rows = df_sorted[df_sorted[self.group_dim] == group]
            values = rows[self.x].values.flatten()

        values = values[~np.isnan(values)]
        if len(values) < 2:
            continue

        try:
            kde = gaussian_kde(values, bw_method=self.bw_method)
            density = kde(x_grid)
        except (np.linalg.LinAlgError, ValueError):
            continue

        baseline = -idx * self.overlap
        ridge = baseline + density * self.scale_factor
        layer = n_groups - idx  # earlier (higher) ridges are drawn on top

        if not color_by_group and gradient:
            # Fill segment-by-segment so each slice takes the color of its x.
            for j in range(len(x_grid) - 1):
                self.ax.fill_between(
                    x_grid[j : j + 2],
                    baseline,
                    ridge[j : j + 2],
                    facecolor=cmap(norm(x_grid[j])),
                    edgecolor="none",
                    alpha=self.alpha,
                    zorder=layer,
                )
        else:
            if color_by_group:
                # Qualitative palette, cycling every 10 groups.
                face = plt.get_cmap("tab10")(idx % 10)
            else:
                # Single color taken from the mean of this group's values.
                face = cmap(norm(np.mean(values)))
            self.ax.fill_between(
                x_grid,
                baseline,
                ridge,
                facecolor=face,
                edgecolor="white",
                linewidth=0.5,
                alpha=self.alpha,
                zorder=layer,
            )

        # Thin black outline on top of every ridge.
        self.ax.plot(
            x_grid,
            ridge,
            color="black",
            linewidth=0.5,
            zorder=layer + 0.1,
        )

        # Optional dashed quantile markers reaching up to the density curve.
        if self.quantiles is not None:
            q_values = np.quantile(values, self.quantiles)
            q_heights = kde(q_values) * self.scale_factor
            for q_x, q_h in zip(q_values, q_heights):
                self.ax.vlines(
                    q_x,
                    baseline,
                    baseline + q_h,
                    color="black",
                    linestyle="--",
                    linewidth=0.8,
                    zorder=layer + 0.2,
                )

    # 3. Axis formatting.
    self.ax.set_yticks([-k * self.overlap for k in range(n_groups)])
    self.ax.set_yticklabels(groups)
    self.ax.set_xlabel(data_name)
    if self.title:
        self.ax.set_title(self.title, pad=20)

    # A value colorbar only makes sense when colors encode x-values.
    if not color_by_group:
        self.add_colorbar(plt.cm.ScalarMappable(norm=norm, cmap=cmap), label=data_name)

    self.ax.xaxis.grid(True, linestyle="-", alpha=0.3)

    # Reference line at zero when the grid spans both signs.
    if x_grid.min() < 0 and x_grid.max() > 0:
        self.ax.axvline(0, color="black", alpha=0.3, linestyle="--", linewidth=1)

    for side in ("top", "right", "left"):
        self.ax.spines[side].set_visible(False)

    if is_xarray:
        _update_history(self.data, f"Created ridgeline plot for {data_name}")

    return self.ax

ScatterPlot

Bases: BasePlot

Create a scatter plot with a regression line.

This plot shows the relationship between two variables and includes a linear regression model fit. It supports lazy evaluation for large Xarray/Dask datasets by delaying computation until the plot call.

Attributes

- data (Union[xr.Dataset, xr.DataArray, pd.DataFrame]): The input data for the plot.
- x (str): The name of the variable for the x-axis.
- y (List[str]): The names of the variables for the y-axis.
- c (Optional[str]): The name of the variable used for colorizing points.
- colorbar (bool): Whether to add a colorbar to the plot.
- title (Optional[str]): The title for the plot.

Source code in src/monet_plots/plots/scatter.py
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
class ScatterPlot(BasePlot):
    """Create a scatter plot with a regression line.

    This plot shows the relationship between two variables and includes a
    linear regression model fit. It supports lazy evaluation for large
    Xarray/Dask datasets by delaying computation until the plot call.

    Attributes
    ----------
    data : Union[xr.Dataset, xr.DataArray, pd.DataFrame]
        The input data for the plot.
    x : str
        The name of the variable for the x-axis.
    y : List[str]
        The names of the variables for the y-axis.
    c : Optional[str]
        The name of the variable used for colorizing points.
    colorbar : bool
        Whether to add a colorbar to the plot.
    title : Optional[str]
        The title for the plot.
    """

    def __init__(
        self,
        data: Any = None,
        x: Optional[str] = None,
        y: Optional[Union[str, List[str]]] = None,
        c: Optional[str] = None,
        colorbar: bool = False,
        title: Optional[str] = None,
        fig: Optional[matplotlib.figure.Figure] = None,
        ax: Optional[matplotlib.axes.Axes] = None,
        df: Any = None,  # Backward compatibility alias
        **kwargs: Any,
    ) -> None:
        """Initialize the scatter plot.

        Parameters
        ----------
        data : Any, optional
            Input data. Can be a pandas DataFrame, xarray DataArray,
            xarray Dataset, or numpy ndarray, by default None.
        x : str, optional
            Variable name for the x-axis, by default None.
        y : Union[str, List[str]], optional
            Variable name(s) for the y-axis, by default None. Any iterable
            of names (list, tuple, array) is accepted and stored as a new list.
        c : str, optional
            Variable name for colorizing the points, by default None.
        colorbar : bool, optional
            Whether to add a colorbar, by default False.
        title : str, optional
            Title for the plot, by default None.
        fig : matplotlib.figure.Figure, optional
            An existing Figure object, by default None.
        ax : matplotlib.axes.Axes, optional
            An existing Axes object, by default None.
        df : Any, optional
            Alias for `data` for backward compatibility, by default None.
        **kwargs : Any
            Additional keyword arguments passed to BasePlot.

        Raises
        ------
        ValueError
            If ``x`` or ``y`` is missing or empty.
        """
        # Normalize ``y`` to a fresh list: a tuple would break the
        # ``[self.x] + self.y`` concatenation in ``plot``, an ndarray would
        # break the truthiness check below, and copying avoids aliasing the
        # caller's list.
        if isinstance(y, str):
            y_names = [y]
        elif y is None:
            y_names = []
        else:
            y_names = list(y)

        # Fail fast on missing arguments, before BasePlot / axes setup runs.
        if not x or not y_names:
            raise ValueError("Parameters 'x' and 'y' must be provided.")

        super().__init__(fig=fig, ax=ax, **kwargs)
        if self.ax is None:
            self.ax = self.fig.add_subplot(1, 1, 1)

        self.data = normalize_data(data if data is not None else df)
        self.x = x
        self.y = y_names
        self.c = c
        self.colorbar = colorbar
        self.title = title

        # Update history for provenance if Xarray
        if isinstance(self.data, (xr.DataArray, xr.Dataset)):
            history = self.data.attrs.get("history", "")
            self.data.attrs["history"] = f"Initialized ScatterPlot; {history}"

    def _get_regression_line(
        self, x_val: np.ndarray, y_val: np.ndarray
    ) -> tuple[np.ndarray, np.ndarray]:
        """Calculate regression line points using only endpoints.

        Only pairs where both ``x`` and ``y`` are non-NaN take part in the
        fit, and the returned line spans exactly the x-range of those pairs.

        Parameters
        ----------
        x_val : np.ndarray
            The concrete x-axis data.
        y_val : np.ndarray
            The concrete y-axis data.

        Returns
        -------
        tuple[np.ndarray, np.ndarray]
            The x and y values for the regression line endpoints, or two
            all-NaN arrays when no line can be fitted (fewer than two valid
            pairs, or all valid x values identical).
        """
        mask = ~np.isnan(x_val) & ~np.isnan(y_val)
        x_fit = x_val[mask]
        y_fit = y_val[mask]

        # A straight line needs two distinct x positions; otherwise the
        # least-squares problem is rank-deficient and np.polyfit only warns.
        if x_fit.size < 2 or np.ptp(x_fit) == 0:
            return np.array([np.nan, np.nan]), np.array([np.nan, np.nan])

        m, b = np.polyfit(x_fit, y_fit, 1)
        x_reg = np.array([x_fit.min(), x_fit.max()])
        y_reg = m * x_reg + b
        return x_reg, y_reg

    def plot(
        self,
        scatter_kws: Optional[dict[str, Any]] = None,
        line_kws: Optional[dict[str, Any]] = None,
        **kwargs: Any,
    ) -> matplotlib.axes.Axes:
        """Generate a static publication-quality scatter plot (Track A).

        One scatter series and one dashed regression line are drawn per
        variable in ``self.y``. When ``self.c`` is set and ``self.colorbar``
        is True, a single colorbar is added for the whole axes.

        Parameters
        ----------
        scatter_kws : dict, optional
            Additional keyword arguments for `ax.scatter`.
        line_kws : dict, optional
            Additional keyword arguments for the regression `ax.plot`.
        **kwargs : Any
            Secondary way to pass keyword arguments to `ax.scatter`.
            Merged with `scatter_kws` (``kwargs`` win on conflicts).

        Returns
        -------
        matplotlib.axes.Axes
            The axes object with the scatter plot.

        Notes
        -----
        For massive datasets (> RAM), consider using Track B (Exploration)
        tools like `hvplot` with `rasterize=True`.
        """
        from ..plot_utils import get_plot_kwargs

        # Combine scatter_kws and kwargs (copies: never mutate caller dicts).
        s_kws = scatter_kws.copy() if scatter_kws is not None else {}
        s_kws.update(kwargs)

        l_kws = line_kws.copy() if line_kws is not None else {}

        # Aero Protocol Requirement: Mandatory transform for GeoAxes
        if hasattr(self.ax, "projection"):
            s_kws.setdefault("transform", ccrs.PlateCarree())
            l_kws.setdefault("transform", ccrs.PlateCarree())

        transform = s_kws.get("transform")

        # Performance: compute the required variables once. Repeated names
        # (e.g. ``c`` equal to ``x``) are dropped while preserving order,
        # because a pandas-style selection with repeated labels returns
        # repeated columns.
        wanted = [self.x, *self.y]
        if self.c:
            wanted.append(self.c)
        cols = list(dict.fromkeys(wanted))

        if hasattr(self.data, "compute"):
            # Sub-selection before compute to minimize data transfer
            concrete_data = self.data[cols].compute()
        else:
            concrete_data = self.data

        x_plot = concrete_data[self.x].values.flatten()
        # The color variable is identical for every y series; extract it once.
        c_plot = (
            concrete_data[self.c].values.flatten() if self.c is not None else None
        )

        mappable = None
        for y_col in self.y:
            y_plot = concrete_data[y_col].values.flatten()

            if c_plot is not None:
                final_s_kwargs = get_plot_kwargs(c=c_plot, **s_kws)
                mappable = self.ax.scatter(x_plot, y_plot, **final_s_kwargs)
            else:
                final_s_kwargs = s_kws.copy()
                final_s_kwargs.setdefault("label", y_col)
                self.ax.scatter(x_plot, y_plot, **final_s_kwargs)

            # Add regression line using endpoints
            x_reg, y_reg = self._get_regression_line(x_plot, y_plot)

            final_l_kwargs = {
                "color": "red",
                "linestyle": "--",
                "label": "Fit" if (self.c is None and len(self.y) == 1) else None,
            }
            final_l_kwargs.update(l_kws)
            if transform:
                final_l_kwargs.setdefault("transform", transform)

            self.ax.plot(x_reg, y_reg, **final_l_kwargs)

        # Add the colorbar once: every series is colored by the same ``c``
        # values, so adding one per series would only stack duplicates.
        if self.colorbar and mappable is not None:
            self.add_colorbar(mappable)

        if len(self.y) > 1 and self.c is None:
            self.ax.legend()

        if self.title:
            self.ax.set_title(self.title)
        else:
            self.ax.set_title(f"Scatter: {self.x} vs {', '.join(self.y)}")

        self.ax.set_xlabel(self.x)
        self.ax.set_ylabel(", ".join(self.y))

        # Update history for provenance
        if isinstance(self.data, (xr.DataArray, xr.Dataset)):
            history = self.data.attrs.get("history", "")
            self.data.attrs["history"] = f"Generated ScatterPlot; {history}"

        return self.ax

    def hvplot(self, **kwargs: Any) -> Any:
        """Generate an interactive scatter plot using hvPlot (Track B).

        Parameters
        ----------
        **kwargs : Any
            Keyword arguments passed to `hvplot.scatter`.
            Common options include `cmap`, `title`, and `alpha`.
            `rasterize=True` is used by default for high performance.

        Returns
        -------
        holoviews.core.layout.Layout
            The interactive hvPlot object.

        Raises
        ------
        ImportError
            If hvplot (or one of its pandas/xarray accessors) cannot be
            imported; the original import error is chained as the cause.
        """
        try:
            import hvplot.pandas  # noqa: F401
            import hvplot.xarray  # noqa: F401
        except ImportError as err:
            raise ImportError(
                "hvplot is required for interactive plotting. Install it with 'pip install hvplot'."
            ) from err

        # Track B defaults; caller kwargs override them.
        plot_kwargs = {
            "x": self.x,
            "y": self.y[0] if len(self.y) == 1 else self.y,
            "rasterize": True,
        }
        if self.c:
            plot_kwargs["c"] = self.c

        plot_kwargs.update(kwargs)

        return self.data.hvplot.scatter(**plot_kwargs)
__init__(data=None, x=None, y=None, c=None, colorbar=False, title=None, fig=None, ax=None, df=None, **kwargs)

Initialize the scatter plot.

Parameters

- data (Any, optional): Input data. Can be a pandas DataFrame, xarray DataArray, xarray Dataset, or numpy ndarray. Default: None.
- x (str, optional): Variable name for the x-axis. Default: None.
- y (Union[str, List[str]], optional): Variable name(s) for the y-axis. Default: None.
- c (str, optional): Variable name for colorizing the points. Default: None.
- colorbar (bool, optional): Whether to add a colorbar. Default: False.
- title (str, optional): Title for the plot. Default: None.
- fig (matplotlib.figure.Figure, optional): An existing Figure object. Default: None.
- ax (matplotlib.axes.Axes, optional): An existing Axes object. Default: None.
- df (Any, optional): Alias for `data`, kept for backward compatibility. Default: None.
- **kwargs (Any): Additional keyword arguments passed to BasePlot.

Source code in src/monet_plots/plots/scatter.py
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
def __init__(
    self,
    data: Any = None,
    x: Optional[str] = None,
    y: Optional[Union[str, List[str]]] = None,
    c: Optional[str] = None,
    colorbar: bool = False,
    title: Optional[str] = None,
    fig: Optional[matplotlib.figure.Figure] = None,
    ax: Optional[matplotlib.axes.Axes] = None,
    df: Any = None,  # Backward compatibility alias
    **kwargs: Any,
) -> None:
    """Initialize the scatter plot.

    Parameters
    ----------
    data : Any, optional
        Input data. Can be a pandas DataFrame, xarray DataArray,
        xarray Dataset, or numpy ndarray, by default None.
    x : str, optional
        Variable name for the x-axis, by default None.
    y : Union[str, List[str]], optional
        Variable name(s) for the y-axis, by default None. Any iterable of
        names (list, tuple, array) is accepted and stored as a new list.
    c : str, optional
        Variable name for colorizing the points, by default None.
    colorbar : bool, optional
        Whether to add a colorbar, by default False.
    title : str, optional
        Title for the plot, by default None.
    fig : matplotlib.figure.Figure, optional
        An existing Figure object, by default None.
    ax : matplotlib.axes.Axes, optional
        An existing Axes object, by default None.
    df : Any, optional
        Alias for `data` for backward compatibility, by default None.
    **kwargs : Any
        Additional keyword arguments passed to BasePlot.

    Raises
    ------
    ValueError
        If ``x`` or ``y`` is missing or empty.
    """
    # Normalize ``y`` to a fresh list: a tuple would break the
    # ``[self.x] + self.y`` concatenation in ``plot``, an ndarray would break
    # the truthiness check below, and copying avoids aliasing the caller's list.
    if isinstance(y, str):
        y_names = [y]
    elif y is None:
        y_names = []
    else:
        y_names = list(y)

    # Fail fast on missing arguments, before BasePlot / axes setup runs.
    if not x or not y_names:
        raise ValueError("Parameters 'x' and 'y' must be provided.")

    super().__init__(fig=fig, ax=ax, **kwargs)
    if self.ax is None:
        self.ax = self.fig.add_subplot(1, 1, 1)

    self.data = normalize_data(data if data is not None else df)
    self.x = x
    self.y = y_names
    self.c = c
    self.colorbar = colorbar
    self.title = title

    # Update history for provenance if Xarray
    if isinstance(self.data, (xr.DataArray, xr.Dataset)):
        history = self.data.attrs.get("history", "")
        self.data.attrs["history"] = f"Initialized ScatterPlot; {history}"

hvplot(**kwargs)

Generate an interactive scatter plot using hvPlot (Track B).

Parameters

- **kwargs (Any): Keyword arguments passed to `hvplot.scatter`. Common options include `cmap`, `title`, and `alpha`. `rasterize=True` is used by default for high performance.

Returns

holoviews.core.layout.Layout The interactive hvPlot object.

Source code in src/monet_plots/plots/scatter.py
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
def hvplot(self, **kwargs: Any) -> Any:
    """Generate an interactive scatter plot using hvPlot (Track B).

    Parameters
    ----------
    **kwargs : Any
        Keyword arguments passed to `hvplot.scatter`.
        Common options include `cmap`, `title`, and `alpha`.
        `rasterize=True` is used by default for high performance.

    Returns
    -------
    holoviews.core.layout.Layout
        The interactive hvPlot object.

    Raises
    ------
    ImportError
        If hvplot (or one of its pandas/xarray accessors) cannot be
        imported; the original import error is chained as the cause.
    """
    try:
        import hvplot.pandas  # noqa: F401
        import hvplot.xarray  # noqa: F401
    except ImportError as err:
        raise ImportError(
            "hvplot is required for interactive plotting. Install it with 'pip install hvplot'."
        ) from err

    # Track B defaults; caller kwargs override them.
    plot_kwargs = {
        "x": self.x,
        "y": self.y[0] if len(self.y) == 1 else self.y,
        "rasterize": True,
    }
    if self.c:
        plot_kwargs["c"] = self.c

    plot_kwargs.update(kwargs)

    return self.data.hvplot.scatter(**plot_kwargs)

plot(scatter_kws=None, line_kws=None, **kwargs)

Generate a static publication-quality scatter plot (Track A).

Parameters

- scatter_kws (dict, optional): Additional keyword arguments for `ax.scatter`.
- line_kws (dict, optional): Additional keyword arguments for the regression-line `ax.plot`.
- **kwargs (Any): Secondary way to pass keyword arguments to `ax.scatter`; merged with `scatter_kws`.

Returns

matplotlib.axes.Axes The axes object with the scatter plot.

Notes

For massive datasets (> RAM), consider using Track B (Exploration) tools like hvplot with rasterize=True.

Source code in src/monet_plots/plots/scatter.py
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
def plot(
    self,
    scatter_kws: Optional[dict[str, Any]] = None,
    line_kws: Optional[dict[str, Any]] = None,
    **kwargs: Any,
) -> matplotlib.axes.Axes:
    """Generate a static publication-quality scatter plot (Track A).

    One scatter series and one dashed regression line are drawn per variable
    in ``self.y``. When ``self.c`` is set and ``self.colorbar`` is True, a
    single colorbar is added for the whole axes.

    Parameters
    ----------
    scatter_kws : dict, optional
        Additional keyword arguments for `ax.scatter`.
    line_kws : dict, optional
        Additional keyword arguments for the regression `ax.plot`.
    **kwargs : Any
        Secondary way to pass keyword arguments to `ax.scatter`.
        Merged with `scatter_kws` (``kwargs`` win on conflicts).

    Returns
    -------
    matplotlib.axes.Axes
        The axes object with the scatter plot.

    Notes
    -----
    For massive datasets (> RAM), consider using Track B (Exploration)
    tools like `hvplot` with `rasterize=True`.
    """
    from ..plot_utils import get_plot_kwargs

    # Combine scatter_kws and kwargs (copies: never mutate caller dicts).
    s_kws = scatter_kws.copy() if scatter_kws is not None else {}
    s_kws.update(kwargs)

    l_kws = line_kws.copy() if line_kws is not None else {}

    # Aero Protocol Requirement: Mandatory transform for GeoAxes
    if hasattr(self.ax, "projection"):
        s_kws.setdefault("transform", ccrs.PlateCarree())
        l_kws.setdefault("transform", ccrs.PlateCarree())

    transform = s_kws.get("transform")

    # Performance: compute the required variables once. Repeated names
    # (e.g. ``c`` equal to ``x``) are dropped while preserving order, because
    # a pandas-style selection with repeated labels returns repeated columns.
    wanted = [self.x, *self.y]
    if self.c:
        wanted.append(self.c)
    cols = list(dict.fromkeys(wanted))

    if hasattr(self.data, "compute"):
        # Sub-selection before compute to minimize data transfer
        concrete_data = self.data[cols].compute()
    else:
        concrete_data = self.data

    x_plot = concrete_data[self.x].values.flatten()
    # The color variable is identical for every y series; extract it once.
    c_plot = concrete_data[self.c].values.flatten() if self.c is not None else None

    mappable = None
    for y_col in self.y:
        y_plot = concrete_data[y_col].values.flatten()

        if c_plot is not None:
            final_s_kwargs = get_plot_kwargs(c=c_plot, **s_kws)
            mappable = self.ax.scatter(x_plot, y_plot, **final_s_kwargs)
        else:
            final_s_kwargs = s_kws.copy()
            final_s_kwargs.setdefault("label", y_col)
            self.ax.scatter(x_plot, y_plot, **final_s_kwargs)

        # Add regression line using endpoints
        x_reg, y_reg = self._get_regression_line(x_plot, y_plot)

        final_l_kwargs = {
            "color": "red",
            "linestyle": "--",
            "label": "Fit" if (self.c is None and len(self.y) == 1) else None,
        }
        final_l_kwargs.update(l_kws)
        if transform:
            final_l_kwargs.setdefault("transform", transform)

        self.ax.plot(x_reg, y_reg, **final_l_kwargs)

    # Add the colorbar once: every series is colored by the same ``c`` values,
    # so adding one per series would only stack duplicates.
    if self.colorbar and mappable is not None:
        self.add_colorbar(mappable)

    if len(self.y) > 1 and self.c is None:
        self.ax.legend()

    if self.title:
        self.ax.set_title(self.title)
    else:
        self.ax.set_title(f"Scatter: {self.x} vs {', '.join(self.y)}")

    self.ax.set_xlabel(self.x)
    self.ax.set_ylabel(", ".join(self.y))

    # Update history for provenance
    if isinstance(self.data, (xr.DataArray, xr.Dataset)):
        history = self.data.attrs.get("history", "")
        self.data.attrs["history"] = f"Generated ScatterPlot; {history}"

    return self.ax

ScorecardPlot

Bases: BasePlot

Scorecard Plot.

Heatmap table displaying performance metrics across multiple dimensions (e.g., Variables vs Lead Times), colored by performance relative to a baseline.

Functional Requirements:

1. Heatmap grid: Rows (Variables/Regions), Cols (Lead Times/Levels).
2. Color cells based on a statistic (e.g., Difference from Baseline, RMSE ratio).
3. Annotate cells with symbols (+/-) or values indicating significance.
4. Handle Green (Better) / Red (Worse) color schemes correctly.

Edge Cases:

- Missing data for some cells (show as white/gray).
- Infinite values (clip or mask).

Source code in src/monet_plots/plots/scorecard.py
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
class ScorecardPlot(BasePlot):
    """
    Scorecard Plot.

    Heatmap table displaying performance metrics across multiple dimensions
    (e.g., Variables vs Lead Times), colored by performance relative to a baseline.

    Functional Requirements:
    1. Heatmap grid: Rows (Variables/Regions), Cols (Lead Times/Levels).
    2. Color cells based on statistic (e.g., Difference from Baseline, RMSE ratio).
    3. Annotate cells with symbols (+/-) or values indicating significance.
    4. Handle Green (Better) / Red (Worse) color schemes correctly.

    Edge Cases:
    - Missing cells (NaN after pivoting) are masked by ``seaborn.heatmap``.
    - Infinite values are converted to NaN, so they are masked like missing cells.
    """

    def __init__(self, fig=None, ax=None, **kwargs):
        super().__init__(fig=fig, ax=ax, **kwargs)

    def plot(
        self,
        data: Any,
        x_col: str,
        y_col: str,
        val_col: str,
        sig_col: Optional[str] = None,
        cmap: str = "RdYlGn",
        center: float = 0.0,
        annot_cols: Optional[list[str]] = None,
        cbar_labels: Optional[tuple[str, str]] = None,
        key_text: Optional[str] = None,
        **kwargs,
    ):
        """
        Draw the scorecard heatmap on ``self.ax``.

        Args:
            data (pd.DataFrame, np.ndarray, xr.Dataset, xr.DataArray): Long-format data,
                converted with ``to_dataframe``.
            x_col (str): Column for x-axis (Columns).
            y_col (str): Column for y-axis (Rows).
            val_col (str): Column for cell values (color).
            sig_col (str, optional): Column for significance; cells whose value is
                non-missing and truthy get a bold "*" marker.
            cmap (str): Colormap.
            center (float): Center value for colormap divergence.
            annot_cols (list[str], optional): Columns to combine for cell annotations
                (e.g., ['mod', 'obs']). Each value is formatted with one decimal place
                and the pieces are joined with " | ".
            cbar_labels (tuple[str, str], optional): Labels for the left and right ends
                of the colorbar. Enables the compact layout with a small horizontal
                colorbar at the top left of the figure.
            key_text (str, optional): Text to display in a legend box at the top right.
                Also enables the compact layout.
            **kwargs: Seaborn heatmap kwargs. ``title`` is popped and used as the axes
                title (default "Performance Scorecard"). A caller-supplied ``cbar_kws``
                dict is copied, never modified in place.

        Returns:
            matplotlib.axes.Axes: The axes holding the heatmap.

        Raises:
            ValueError: From ``DataFrame.pivot`` when a (y_col, x_col) pair occurs
                more than once in the data.
        """
        df = to_dataframe(data).copy()
        validate_dataframe(df, required_columns=[x_col, y_col, val_col])

        # Extract title before passing kwargs to heatmap
        title = kwargs.pop("title", "Performance Scorecard")

        # Pivot long data into a grid: rows = y_col values, columns = x_col values.
        pivot_data = df.pivot(index=y_col, columns=x_col, values=val_col)
        # +/-inf would make seaborn's automatic vmin/vmax (and the range around
        # ``center``) infinite; turn them into NaN so they are masked instead.
        pivot_data = pivot_data.replace([float("inf"), float("-inf")], float("nan"))

        # Handle annotations
        if annot_cols:
            validate_dataframe(df, required_columns=annot_cols)

            def _fmt_cell(value):
                # One decimal place; missing values become an empty string.
                return f"{value:.1f}" if pd.notna(value) else ""

            # Build a combined "a | b | ..." annotation string per row.
            df["_combined_annot"] = df[annot_cols[0]].map(_fmt_cell)
            for col in annot_cols[1:]:
                df["_combined_annot"] += " | " + df[col].map(_fmt_cell)
            annot_data = df.pivot(index=y_col, columns=x_col, values="_combined_annot")
            kwargs["annot"] = annot_data
            kwargs["fmt"] = ""
        else:
            kwargs.setdefault("annot", True)
            kwargs.setdefault("fmt", ".2f")

        # Layout adjustments for WeatherMesh look
        cbar_ax = None
        is_weathermesh_layout = bool(cbar_labels or key_text)

        if is_weathermesh_layout:
            # Extra padding leaves room for the colorbar / key box above the grid.
            self.ax.set_title(title, pad=60)

            if cbar_labels:
                # Create a small axes for the colorbar at the top left
                cbar_ax = self.fig.add_axes([0.15, 0.85, 0.3, 0.02])
                kwargs["cbar_ax"] = cbar_ax
                # Copy so the caller's dict is not mutated between calls.
                cbar_kws = dict(kwargs.get("cbar_kws") or {})
                cbar_kws["orientation"] = "horizontal"
                kwargs["cbar_kws"] = cbar_kws

            if key_text:
                # Add a box at the top right
                self.fig.text(
                    0.85,
                    0.86,
                    key_text,
                    ha="right",
                    va="center",
                    bbox=dict(boxstyle="square", facecolor="white", edgecolor="black"),
                )
        else:
            self.ax.set_title(title)

        # Plot Heatmap
        kwargs.setdefault("linewidths", 0.5)
        kwargs.setdefault("linecolor", "lightgray")
        sns.heatmap(
            pivot_data,
            ax=self.ax,
            cmap=cmap,
            center=center,
            **kwargs,
        )

        # Replace numeric colorbar ticks with the two end labels.
        if cbar_ax is not None and cbar_labels:
            cbar_ax.set_xticks([])
            cbar_ax.set_yticks([])
            self.fig.text(0.15, 0.83, cbar_labels[0], ha="left", va="top", fontsize=9)
            self.fig.text(0.45, 0.83, cbar_labels[1], ha="right", va="top", fontsize=9)

        # Add Significance Markers
        if sig_col:
            pivot_sig = df.pivot(index=y_col, columns=x_col, values=sig_col)
            self._overlay_significance(pivot_data, pivot_sig)

        self.ax.set_xlabel(x_col.title())
        if is_weathermesh_layout:
            self.ax.set_ylabel("")
            self.ax.tick_params(axis="x", rotation=0)
        else:
            self.ax.set_ylabel(y_col.title())
            self.ax.tick_params(axis="x", rotation=45)

        return self.ax

    def _overlay_significance(self, data_grid, sig_grid):
        """
        Overlay a bold "*" on every cell whose significance flag is set.

        Args:
            data_grid (pd.DataFrame): The pivoted value grid that was drawn.
            sig_grid (pd.DataFrame): Pivoted significance flags. A cell is marked
                when its value is non-missing and truthy.
        """
        # Align to the drawn grid so marker (i, j) always refers to the same cell.
        sig_grid = sig_grid.reindex(index=data_grid.index, columns=data_grid.columns)
        n_rows, n_cols = data_grid.shape
        for i in range(n_rows):
            for j in range(n_cols):
                sig_val = sig_grid.iloc[i, j]
                if pd.notna(sig_val) and bool(sig_val):
                    # seaborn.heatmap draws row i over y in [i, i + 1] and inverts
                    # the y-axis (row 0 on top), so the cell centre is (j + .5, i + .5).
                    self.ax.text(
                        j + 0.5,
                        i + 0.5,
                        "*",
                        ha="center",
                        va="center",
                        fontweight="bold",
                        fontsize=12,
                        color="black",
                        zorder=5,
                    )

plot(data, x_col, y_col, val_col, sig_col=None, cmap='RdYlGn', center=0.0, annot_cols=None, cbar_labels=None, key_text=None, **kwargs)

Main plotting method.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `data` | `DataFrame`, `ndarray`, `Dataset`, `DataArray` | Long-format dataframe. | *required* |
| `x_col` | `str` | Column for x-axis (Columns). | *required* |
| `y_col` | `str` | Column for y-axis (Rows). | *required* |
| `val_col` | `str` | Column for cell values (color). | *required* |
| `sig_col` | `str` | Column for significance (marker). | `None` |
| `cmap` | `str` | Colormap. | `'RdYlGn'` |
| `center` | `float` | Center value for colormap divergence. | `0.0` |
| `annot_cols` | `list[str]` | Columns to combine for cell annotations (e.g., `['mod', 'obs']`). | `None` |
| `cbar_labels` | `tuple[str, str]` | Labels for the left and right ends of the colorbar. | `None` |
| `key_text` | `str` | Text to display in a legend box at the top right. | `None` |
| `**kwargs` | | Seaborn heatmap kwargs. | `{}` |
Source code in src/monet_plots/plots/scorecard.py
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
def plot(
    self,
    data: Any,
    x_col: str,
    y_col: str,
    val_col: str,
    sig_col: Optional[str] = None,
    cmap: str = "RdYlGn",
    center: float = 0.0,
    annot_cols: Optional[list[str]] = None,
    cbar_labels: Optional[tuple[str, str]] = None,
    key_text: Optional[str] = None,
    **kwargs,
):
    """
    Draw the scorecard heatmap on ``self.ax``.

    Args:
        data (pd.DataFrame, np.ndarray, xr.Dataset, xr.DataArray): Long-format data,
            converted with ``to_dataframe``.
        x_col (str): Column for x-axis (Columns).
        y_col (str): Column for y-axis (Rows).
        val_col (str): Column for cell values (color).
        sig_col (str, optional): Column for significance (marker).
        cmap (str): Colormap.
        center (float): Center value for colormap divergence.
        annot_cols (list[str], optional): Columns to combine for cell annotations
            (e.g., ['mod', 'obs']). Each value is formatted with one decimal place
            and the pieces are joined with " | ".
        cbar_labels (tuple[str, str], optional): Labels for the left and right ends
            of the colorbar. Enables the compact layout with a small horizontal
            colorbar at the top left of the figure.
        key_text (str, optional): Text to display in a legend box at the top right.
            Also enables the compact layout.
        **kwargs: Seaborn heatmap kwargs. ``title`` is popped and used as the axes
            title (default "Performance Scorecard"). A caller-supplied ``cbar_kws``
            dict is copied, never modified in place.

    Returns:
        matplotlib.axes.Axes: The axes holding the heatmap.

    Raises:
        ValueError: From ``DataFrame.pivot`` when a (y_col, x_col) pair occurs
            more than once in the data.
    """
    df = to_dataframe(data).copy()
    validate_dataframe(df, required_columns=[x_col, y_col, val_col])

    # Extract title before passing kwargs to heatmap
    title = kwargs.pop("title", "Performance Scorecard")

    # Pivot long data into a grid: rows = y_col values, columns = x_col values.
    pivot_data = df.pivot(index=y_col, columns=x_col, values=val_col)
    # +/-inf would make seaborn's automatic vmin/vmax (and the range around
    # ``center``) infinite; turn them into NaN so they are masked instead.
    pivot_data = pivot_data.replace([float("inf"), float("-inf")], float("nan"))

    # Handle annotations
    if annot_cols:
        validate_dataframe(df, required_columns=annot_cols)

        def _fmt_cell(value):
            # One decimal place; missing values become an empty string.
            return f"{value:.1f}" if pd.notna(value) else ""

        # Build a combined "a | b | ..." annotation string per row.
        df["_combined_annot"] = df[annot_cols[0]].map(_fmt_cell)
        for col in annot_cols[1:]:
            df["_combined_annot"] += " | " + df[col].map(_fmt_cell)
        annot_data = df.pivot(index=y_col, columns=x_col, values="_combined_annot")
        kwargs["annot"] = annot_data
        kwargs["fmt"] = ""
    else:
        kwargs.setdefault("annot", True)
        kwargs.setdefault("fmt", ".2f")

    # Layout adjustments for WeatherMesh look
    cbar_ax = None
    is_weathermesh_layout = bool(cbar_labels or key_text)

    if is_weathermesh_layout:
        # Extra padding leaves room for the colorbar / key box above the grid.
        self.ax.set_title(title, pad=60)

        if cbar_labels:
            # Create a small axes for the colorbar at the top left
            cbar_ax = self.fig.add_axes([0.15, 0.85, 0.3, 0.02])
            kwargs["cbar_ax"] = cbar_ax
            # Copy so the caller's dict is not mutated between calls.
            cbar_kws = dict(kwargs.get("cbar_kws") or {})
            cbar_kws["orientation"] = "horizontal"
            kwargs["cbar_kws"] = cbar_kws

        if key_text:
            # Add a box at the top right
            self.fig.text(
                0.85,
                0.86,
                key_text,
                ha="right",
                va="center",
                bbox=dict(boxstyle="square", facecolor="white", edgecolor="black"),
            )
    else:
        self.ax.set_title(title)

    # Plot Heatmap
    kwargs.setdefault("linewidths", 0.5)
    kwargs.setdefault("linecolor", "lightgray")
    sns.heatmap(
        pivot_data,
        ax=self.ax,
        cmap=cmap,
        center=center,
        **kwargs,
    )

    # Replace numeric colorbar ticks with the two end labels.
    if cbar_ax is not None and cbar_labels:
        cbar_ax.set_xticks([])
        cbar_ax.set_yticks([])
        self.fig.text(0.15, 0.83, cbar_labels[0], ha="left", va="top", fontsize=9)
        self.fig.text(0.45, 0.83, cbar_labels[1], ha="right", va="top", fontsize=9)

    # Add Significance Markers
    if sig_col:
        pivot_sig = df.pivot(index=y_col, columns=x_col, values=sig_col)
        self._overlay_significance(pivot_data, pivot_sig)

    self.ax.set_xlabel(x_col.title())
    if is_weathermesh_layout:
        self.ax.set_ylabel("")
        self.ax.tick_params(axis="x", rotation=0)
    else:
        self.ax.set_ylabel(y_col.title())
        self.ax.tick_params(axis="x", rotation=45)

    return self.ax

SoccerPlot

Bases: BasePlot

Soccer plot for model evaluation.

This plot shows model performance by plotting bias (x-axis) against error (y-axis). It typically includes 'goal' and 'criteria' zones to visually assess if the model meets specific performance standards.

Attributes

- **data** : `Union[pd.DataFrame, xr.Dataset, xr.DataArray]` — The input data for the plot.
- **bias_data** : `Union[pd.Series, xr.DataArray]` — Calculated or provided bias values.
- **error_data** : `Union[pd.Series, xr.DataArray]` — Calculated or provided error values.

Source code in src/monet_plots/plots/soccer.py
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
class SoccerPlot(BasePlot):
    """Soccer plot for model evaluation.

    This plot shows model performance by plotting bias (x-axis) against error (y-axis).
    It typically includes 'goal' and 'criteria' zones to visually assess if the
    model meets specific performance standards.

    Attributes
    ----------
    data : Union[pd.DataFrame, xr.Dataset, xr.DataArray]
        The input data for the plot.
    bias_data : Union[pd.Series, xr.DataArray]
        Calculated or provided bias values.
    error_data : Union[pd.Series, xr.DataArray]
        Calculated or provided error values.
    goal, criteria : Dict[str, float] or None
        Per-instance copies of the zone thresholds; a falsy value disables the zone.
    """

    def __init__(
        self,
        data: Any,
        *,
        obs_col: Optional[str] = None,
        mod_col: Optional[str] = None,
        bias_col: Optional[str] = None,
        error_col: Optional[str] = None,
        label_col: Optional[str] = None,
        metric: str = "fractional",
        goal: Optional[Dict[str, float]] = {"bias": 30.0, "error": 50.0},
        criteria: Optional[Dict[str, float]] = {"bias": 60.0, "error": 75.0},
        fig: Optional[matplotlib.figure.Figure] = None,
        ax: Optional[matplotlib.axes.Axes] = None,
        **kwargs: Any,
    ):
        """
        Initialize Soccer Plot.

        Parameters
        ----------
        data : Any
            Input data. Can be a pandas DataFrame, xarray DataArray,
            xarray Dataset, or numpy ndarray.
        obs_col : str, optional
            Column name for observations. Required if bias/error not provided.
        mod_col : str, optional
            Column name for model values. Required if bias/error not provided.
        bias_col : str, optional
            Column name for pre-calculated bias.
        error_col : str, optional
            Column name for pre-calculated error.
        label_col : str, optional
            Column name for labeling points.
        metric : str, optional
            Type of metric to calculate if obs/mod provided ('fractional' or 'normalized'),
            by default "fractional".
        goal : Dict[str, float], optional
            Dictionary with 'bias' and 'error' thresholds for the goal zone,
            by default {"bias": 30.0, "error": 50.0}. Copied on assignment;
            pass None to disable the zone.
        criteria : Dict[str, float], optional
            Dictionary with 'bias' and 'error' thresholds for the criteria zone,
            by default {"bias": 60.0, "error": 75.0}. Copied on assignment;
            pass None to disable the zone.
        fig : matplotlib.figure.Figure, optional
            An existing Figure object.
        ax : matplotlib.axes.Axes, optional
            An existing Axes object.
        **kwargs : Any
            Arguments passed to BasePlot.

        Raises
        ------
        ValueError
            If neither bias_col/error_col nor obs_col/mod_col are fully provided,
            or if ``metric`` is not 'fractional' or 'normalized'.
        """
        super().__init__(fig=fig, ax=ax, **kwargs)
        if self.ax is None:
            self.ax = self.fig.add_subplot(1, 1, 1)

        self.data = normalize_data(data, prefer_xarray=False)
        self.bias_col = bias_col
        self.error_col = error_col
        self.label_col = label_col
        self.metric = metric
        # The default dicts are single objects shared by every call; copy them so
        # mutating one instance's zones cannot leak into other instances.
        self.goal = dict(goal) if goal else goal
        self.criteria = dict(criteria) if criteria else criteria

        # Track provenance for Xarray
        if isinstance(self.data, (xr.DataArray, xr.Dataset)):
            history = self.data.attrs.get("history", "")
            self.data.attrs["history"] = f"Initialized SoccerPlot; {history}"

        if bias_col is None or error_col is None:
            if obs_col is None or mod_col is None:
                raise ValueError(
                    "Must provide either bias_col/error_col or obs_col/mod_col"
                )
            self._calculate_metrics(obs_col, mod_col)
        else:
            self.bias_data = self.data[bias_col]
            self.error_data = self.data[error_col]

    def _calculate_metrics(self, obs_col: str, mod_col: str) -> None:
        """Calculate MFB/MFE or NMB/NME using vectorized operations.

        Parameters
        ----------
        obs_col : str
            Column/variable name for observations.
        mod_col : str
            Column/variable name for model values.
        """
        obs = self.data[obs_col]
        mod = self.data[mod_col]

        if self.metric == "fractional":
            # For scatter plots, we want element-wise metrics (dim=[])
            self.bias_data = compute_mfb(obs, mod, dim=[])
            self.error_data = compute_mfe(obs, mod, dim=[])
            self.xlabel = "Mean Fractional Bias (%)"
            self.ylabel = "Mean Fractional Error (%)"
        elif self.metric == "normalized":
            # Element-wise normalized bias/error (dim=[])
            self.bias_data = compute_nmb(obs, mod, dim=[])
            self.error_data = compute_nme(obs, mod, dim=[])
            self.xlabel = "Normalized Mean Bias (%)"
            self.ylabel = "Normalized Mean Error (%)"
        else:
            raise ValueError("metric must be 'fractional' or 'normalized'")

        # Track provenance (Aero Protocol)
        if isinstance(self.bias_data, xr.DataArray):
            _update_history(self.bias_data, f"Calculated {self.metric} soccer metrics")
            _update_history(self.error_data, f"Calculated {self.metric} soccer metrics")

    def plot(self, **kwargs: Any) -> matplotlib.axes.Axes:
        """Generate the soccer plot.

        Draws the criteria (light grey) and goal (grey) rectangles spanning
        ``[-bias, +bias] x [0, error]``, scatters bias against error, optionally
        annotates each point with ``label_col``, and sizes the axes to cover both
        the zones and all finite data points.

        Parameters
        ----------
        **kwargs : Any
            Keyword arguments passed to `ax.scatter`.

        Returns
        -------
        matplotlib.axes.Axes
            The axes object with the soccer plot.
        """
        # Draw zones (criteria below goal via zorder)
        if self.criteria:
            rect_crit = patches.Rectangle(
                (-self.criteria["bias"], 0),
                2 * self.criteria["bias"],
                self.criteria["error"],
                linewidth=1,
                edgecolor="lightgrey",
                facecolor="lightgrey",
                alpha=0.3,
                label="Criteria",
                zorder=0,
            )
            self.ax.add_patch(rect_crit)

        if self.goal:
            rect_goal = patches.Rectangle(
                (-self.goal["bias"], 0),
                2 * self.goal["bias"],
                self.goal["error"],
                linewidth=1,
                edgecolor="grey",
                facecolor="grey",
                alpha=0.3,
                label="Goal",
                zorder=1,
            )
            self.ax.add_patch(rect_goal)

        # Plot points - compute simultaneously if lazy (Aero Protocol)
        bias = self.bias_data
        error = self.error_data

        if hasattr(bias, "compute") or hasattr(error, "compute"):
            try:
                import dask
            except ImportError:
                # In-memory xarray objects also expose ``compute``; without dask
                # installed, materialize each object on its own instead.
                bias = bias.compute() if hasattr(bias, "compute") else bias
                error = error.compute() if hasattr(error, "compute") else error
            else:
                bias, error = dask.compute(bias, error)

        scatter_kwargs = {"zorder": 5}
        scatter_kwargs.update(kwargs)
        self.ax.scatter(bias, error, **scatter_kwargs)

        # Labels
        if self.label_col is not None:
            labels = self.data[self.label_col]
            if hasattr(labels, "values"):
                labels = labels.values

            for i, txt in enumerate(labels):
                # Positional lookup: pandas objects may carry a non-default index.
                b_val = bias.iloc[i] if hasattr(bias, "iloc") else bias[i]
                e_val = error.iloc[i] if hasattr(error, "iloc") else error[i]
                self.ax.annotate(
                    str(txt),
                    (b_val, e_val),
                    xytext=(5, 5),
                    textcoords="offset points",
                )

        # Setup axes: cover every drawn zone *and* every finite data point, so
        # points outside the criteria zone are not silently clipped off the plot.
        zones = [zone for zone in (self.criteria, self.goal) if zone]
        bias_arr = np.asarray(bias, dtype=float)
        error_arr = np.asarray(error, dtype=float)
        finite_abs_bias = np.abs(bias_arr[np.isfinite(bias_arr)])
        finite_error = error_arr[np.isfinite(error_arr)]

        x_candidates = [float(zone["bias"]) for zone in zones]
        y_candidates = [float(zone["error"]) for zone in zones]
        if finite_abs_bias.size:
            x_candidates.append(float(finite_abs_bias.max()))
        if finite_error.size:
            y_candidates.append(float(finite_error.max()))

        # Fall back to a unit extent when nothing positive is available, which
        # would otherwise give singular (0, 0) or NaN axis limits.
        x_max = max(x_candidates, default=0.0)
        y_max = max(y_candidates, default=0.0)
        limit = (x_max if x_max > 0 else 1.0) * 1.1
        limit_y = (y_max if y_max > 0 else 1.0) * 1.1

        self.ax.set_xlim(-limit, limit)
        self.ax.set_ylim(0, limit_y)

        self.ax.axvline(0, color="k", linestyle="--", alpha=0.5)
        self.ax.set_xlabel(getattr(self, "xlabel", "Bias (%)"))
        self.ax.set_ylabel(getattr(self, "ylabel", "Error (%)"))
        self.ax.grid(True, linestyle=":", alpha=0.6)

        # Update history for provenance
        if isinstance(self.data, (xr.DataArray, xr.Dataset)):
            _update_history(self.data, "Generated SoccerPlot")

        return self.ax

    def hvplot(self, **kwargs: Any) -> Any:
        """Generate an interactive soccer plot using hvPlot (Track B).

        Parameters
        ----------
        **kwargs : Any
            Keyword arguments passed to `hvplot.scatter`.

        Returns
        -------
        holoviews.core.Element
            The interactive soccer plot.

        Raises
        ------
        ImportError
            If hvplot is not installed (the original import error is chained).
        """
        try:
            import hvplot.pandas  # noqa: F401
            import hvplot.xarray  # noqa: F401
        except ImportError as err:
            raise ImportError(
                "hvplot is required for interactive plotting. Install it with 'pip install hvplot'."
            ) from err

        # Combine into a single Dataset or DataFrame for plotting
        if isinstance(self.bias_data, xr.DataArray):
            ds = xr.Dataset({"bias": self.bias_data, "error": self.error_data})
            if self.label_col is not None:
                ds[self.label_col] = self.data[self.label_col]
            plot_obj = ds
        else:
            import pandas as pd

            plot_obj = pd.DataFrame({"bias": self.bias_data, "error": self.error_data})
            if self.label_col is not None:
                plot_obj[self.label_col] = self.data[self.label_col]

        plot_kwargs = {
            "x": "bias",
            "y": "error",
            "xlabel": getattr(self, "xlabel", "Bias (%)"),
            "ylabel": getattr(self, "ylabel", "Error (%)"),
            "title": "Soccer Plot",
        }
        if self.label_col:
            plot_kwargs["hover_cols"] = [self.label_col]

        plot_kwargs.update(kwargs)

        return plot_obj.hvplot.scatter(**plot_kwargs)

__init__(data, *, obs_col=None, mod_col=None, bias_col=None, error_col=None, label_col=None, metric='fractional', goal={'bias': 30.0, 'error': 50.0}, criteria={'bias': 60.0, 'error': 75.0}, fig=None, ax=None, **kwargs)

Initialize Soccer Plot.

Parameters

- **data** : `Any` — Input data. Can be a pandas DataFrame, xarray DataArray, xarray Dataset, or numpy ndarray.
- **obs_col** : `str`, optional — Column name for observations. Required if bias/error not provided.
- **mod_col** : `str`, optional — Column name for model values. Required if bias/error not provided.
- **bias_col** : `str`, optional — Column name for pre-calculated bias.
- **error_col** : `str`, optional — Column name for pre-calculated error.
- **label_col** : `str`, optional — Column name for labeling points.
- **metric** : `str`, optional — Type of metric to calculate if obs/mod provided (`'fractional'` or `'normalized'`), by default `"fractional"`.
- **goal** : `Dict[str, float]`, optional — Dictionary with `'bias'` and `'error'` thresholds for the goal zone, by default `{"bias": 30.0, "error": 50.0}`.
- **criteria** : `Dict[str, float]`, optional — Dictionary with `'bias'` and `'error'` thresholds for the criteria zone, by default `{"bias": 60.0, "error": 75.0}`.
- **fig** : `matplotlib.figure.Figure`, optional — An existing Figure object.
- **ax** : `matplotlib.axes.Axes`, optional — An existing Axes object.
- **\*\*kwargs** : `Any` — Arguments passed to `BasePlot`.

Source code in src/monet_plots/plots/soccer.py
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
def __init__(
    self,
    data: Any,
    *,
    obs_col: Optional[str] = None,
    mod_col: Optional[str] = None,
    bias_col: Optional[str] = None,
    error_col: Optional[str] = None,
    label_col: Optional[str] = None,
    metric: str = "fractional",
    goal: Optional[Dict[str, float]] = {"bias": 30.0, "error": 50.0},
    criteria: Optional[Dict[str, float]] = {"bias": 60.0, "error": 75.0},
    fig: Optional[matplotlib.figure.Figure] = None,
    ax: Optional[matplotlib.axes.Axes] = None,
    **kwargs: Any,
):
    """
    Initialize Soccer Plot.

    Parameters
    ----------
    data : Any
        Input data. Can be a pandas DataFrame, xarray DataArray,
        xarray Dataset, or numpy ndarray.
    obs_col : str, optional
        Column name for observations. Required if bias/error not provided.
    mod_col : str, optional
        Column name for model values. Required if bias/error not provided.
    bias_col : str, optional
        Column name for pre-calculated bias.
    error_col : str, optional
        Column name for pre-calculated error.
    label_col : str, optional
        Column name for labeling points.
    metric : str, optional
        Type of metric to calculate if obs/mod provided ('fractional' or 'normalized'),
        by default "fractional".
    goal : Dict[str, float], optional
        Dictionary with 'bias' and 'error' thresholds for the goal zone,
        by default {"bias": 30.0, "error": 50.0}. Copied on assignment;
        pass None to disable the zone.
    criteria : Dict[str, float], optional
        Dictionary with 'bias' and 'error' thresholds for the criteria zone,
        by default {"bias": 60.0, "error": 75.0}. Copied on assignment;
        pass None to disable the zone.
    fig : matplotlib.figure.Figure, optional
        An existing Figure object.
    ax : matplotlib.axes.Axes, optional
        An existing Axes object.
    **kwargs : Any
        Arguments passed to BasePlot.

    Raises
    ------
    ValueError
        If neither bias_col/error_col nor obs_col/mod_col are fully provided.
    """
    super().__init__(fig=fig, ax=ax, **kwargs)
    if self.ax is None:
        self.ax = self.fig.add_subplot(1, 1, 1)

    self.data = normalize_data(data, prefer_xarray=False)
    self.bias_col = bias_col
    self.error_col = error_col
    self.label_col = label_col
    self.metric = metric
    # The default dicts are single objects shared by every call; copy them so
    # mutating one instance's zones cannot leak into other instances.
    self.goal = dict(goal) if goal else goal
    self.criteria = dict(criteria) if criteria else criteria

    # Track provenance for Xarray
    if isinstance(self.data, (xr.DataArray, xr.Dataset)):
        history = self.data.attrs.get("history", "")
        self.data.attrs["history"] = f"Initialized SoccerPlot; {history}"

    if bias_col is None or error_col is None:
        if obs_col is None or mod_col is None:
            raise ValueError(
                "Must provide either bias_col/error_col or obs_col/mod_col"
            )
        self._calculate_metrics(obs_col, mod_col)
    else:
        self.bias_data = self.data[bias_col]
        self.error_data = self.data[error_col]

hvplot(**kwargs)

Generate an interactive soccer plot using hvPlot (Track B).

Parameters

- **\*\*kwargs** : `Any` — Keyword arguments passed to `hvplot.scatter`.

Returns

- `holoviews.core.Element` — The interactive soccer plot.

Source code in src/monet_plots/plots/soccer.py
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
def hvplot(self, **kwargs: Any) -> Any:
    """Generate an interactive soccer plot using hvPlot (Track B).

    Parameters
    ----------
    **kwargs : Any
        Keyword arguments passed to `hvplot.scatter`.

    Returns
    -------
    holoviews.core.Element
        The interactive soccer plot.

    Raises
    ------
    ImportError
        If hvplot is not installed (the original import error is chained).
    """
    try:
        import hvplot.pandas  # noqa: F401
        import hvplot.xarray  # noqa: F401
    except ImportError as err:
        raise ImportError(
            "hvplot is required for interactive plotting. Install it with 'pip install hvplot'."
        ) from err

    # Combine into a single Dataset or DataFrame for plotting
    if isinstance(self.bias_data, xr.DataArray):
        ds = xr.Dataset({"bias": self.bias_data, "error": self.error_data})
        if self.label_col is not None:
            ds[self.label_col] = self.data[self.label_col]
        plot_obj = ds
    else:
        import pandas as pd

        plot_obj = pd.DataFrame({"bias": self.bias_data, "error": self.error_data})
        if self.label_col is not None:
            plot_obj[self.label_col] = self.data[self.label_col]

    plot_kwargs = {
        "x": "bias",
        "y": "error",
        "xlabel": getattr(self, "xlabel", "Bias (%)"),
        "ylabel": getattr(self, "ylabel", "Error (%)"),
        "title": "Soccer Plot",
    }
    if self.label_col:
        plot_kwargs["hover_cols"] = [self.label_col]

    plot_kwargs.update(kwargs)

    return plot_obj.hvplot.scatter(**plot_kwargs)

plot(**kwargs)

Generate the soccer plot.

Parameters

- **\*\*kwargs** : `Any` — Keyword arguments passed to `ax.scatter`.

Returns

- `matplotlib.axes.Axes` — The axes object with the soccer plot.

Source code in src/monet_plots/plots/soccer.py
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
def plot(self, **kwargs: Any) -> matplotlib.axes.Axes:
    """Generate the soccer plot.

    Draws the optional "criteria" and "goal" zones as shaded rectangles
    centred on zero bias, scatters each (bias, error) pair, optionally
    annotates points with the values of ``self.label_col``, and sets
    symmetric x-limits around zero.

    Parameters
    ----------
    **kwargs : Any
        Keyword arguments passed to `ax.scatter`. ``zorder`` defaults to 5
        so the points are drawn above the zone rectangles.

    Returns
    -------
    matplotlib.axes.Axes
        The axes object with the soccer plot.

    Notes
    -----
    The axis limits cover every drawn zone *and* every finite data point
    (plus a 10% margin), so points lying outside the criteria zone are not
    clipped. Non-finite values are ignored when computing the limits and
    are not annotated.
    """
    # Draw zones
    if self.criteria:
        rect_crit = patches.Rectangle(
            (-self.criteria["bias"], 0),
            2 * self.criteria["bias"],
            self.criteria["error"],
            linewidth=1,
            edgecolor="lightgrey",
            facecolor="lightgrey",
            alpha=0.3,
            label="Criteria",
            zorder=0,
        )
        self.ax.add_patch(rect_crit)

    if self.goal:
        rect_goal = patches.Rectangle(
            (-self.goal["bias"], 0),
            2 * self.goal["bias"],
            self.goal["error"],
            linewidth=1,
            edgecolor="grey",
            facecolor="grey",
            alpha=0.3,
            label="Goal",
            zorder=1,
        )
        self.ax.add_patch(rect_goal)

    # Plot points - compute simultaneously if lazy (Aero Protocol)
    bias = self.bias_data
    error = self.error_data

    if hasattr(bias, "compute") or hasattr(error, "compute"):
        import dask

        bias, error = dask.compute(bias, error)

    scatter_kwargs = {"zorder": 5}
    scatter_kwargs.update(kwargs)
    self.ax.scatter(bias, error, **scatter_kwargs)

    # Plain float arrays give positional indexing and NaN-aware reductions
    # regardless of whether the inputs were ndarrays, Series or DataArrays.
    bias_arr = np.asarray(bias, dtype=float).ravel()
    error_arr = np.asarray(error, dtype=float).ravel()

    # Labels
    if self.label_col is not None:
        labels = self.data[self.label_col]
        if hasattr(labels, "values"):
            labels = labels.values

        for txt, b_val, e_val in zip(labels, bias_arr, error_arr):
            # A point with a NaN coordinate is not drawn; skip its label too.
            if not (np.isfinite(b_val) and np.isfinite(e_val)):
                continue
            self.ax.annotate(
                str(txt),
                (b_val, e_val),
                xytext=(5, 5),
                textcoords="offset points",
            )

    # Setup axes: include finite data extents and every drawn zone so that
    # nothing is clipped (previously the criteria box alone set the limits).
    finite_bias = np.abs(bias_arr[np.isfinite(bias_arr)])
    finite_error = error_arr[np.isfinite(error_arr)]
    extent_x = float(finite_bias.max()) if finite_bias.size else 0.0
    extent_y = float(finite_error.max()) if finite_error.size else 0.0
    for zone in (self.criteria, self.goal):
        if zone:
            extent_x = max(extent_x, float(zone["bias"]))
            extent_y = max(extent_y, float(zone["error"]))

    limit = extent_x * 1.1
    limit_y = extent_y * 1.1

    # A zero extent (e.g. empty or all-NaN data and no zones) would give
    # singular limits; leave autoscaling in charge in that case.
    if limit > 0:
        self.ax.set_xlim(-limit, limit)
    if limit_y > 0:
        self.ax.set_ylim(0, limit_y)

    self.ax.axvline(0, color="k", linestyle="--", alpha=0.5)
    self.ax.set_xlabel(getattr(self, "xlabel", "Bias (%)"))
    self.ax.set_ylabel(getattr(self, "ylabel", "Error (%)"))
    self.ax.grid(True, linestyle=":", alpha=0.6)

    # Update history for provenance
    if isinstance(self.data, (xr.DataArray, xr.Dataset)):
        _update_history(self.data, "Generated SoccerPlot")

    return self.ax

SpatialBiasScatterPlot

Bases: SpatialPlot

Create a spatial scatter plot showing bias between model and observations.

The scatter points are colored by the difference (model - observations) and sized by the absolute magnitude of this difference, making larger biases more visible. This class supports both Track A (publication) and Track B (interactive) visualization.

Source code in src/monet_plots/plots/spatial_bias_scatter.py
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
class SpatialBiasScatterPlot(SpatialPlot):
    """Create a spatial scatter plot showing bias between model and observations.

    The scatter points are colored by the difference (model - observations) and
    sized by the absolute magnitude of this difference, making larger biases
    more visible. This class supports both Track A (publication) and
    Track B (interactive) visualization.
    """

    def __init__(
        self,
        data: Any,
        col1: str,
        col2: str,
        vmin: float | None = None,
        vmax: float | None = None,
        ncolors: int = 15,
        fact: float = 1.5,
        cmap: str = "RdBu_r",
        **kwargs: Any,
    ) -> None:
        """Initialize the plot with data and map projection.

        Parameters
        ----------
        data : Any
            Input data. Preferred format is xarray.Dataset or xarray.DataArray
            with 'latitude' and 'longitude' (or 'lat' and 'lon') coordinates.
        col1 : str
            Name of the first variable (e.g., observations).
        col2 : str
            Name of the second variable (e.g., model). Bias is calculated
            as col2 - col1.
        vmin : float, optional
            Minimum for colorscale. When None (default), ``plot`` uses minus
            the rounded 95th percentile of ``|bias|``.
        vmax : float, optional
            Maximum for colorscale. When None (default), ``plot`` uses the
            rounded 95th percentile of ``|bias|``.
        ncolors : int, optional
            Number of discrete colors, by default 15.
        fact : float, optional
            Scaling factor for point sizes, by default 1.5.
        cmap : str, optional
            Colormap for bias values, by default "RdBu_r".
        **kwargs : Any
            Additional keyword arguments for map creation, passed to
            :class:`monet_plots.plots.spatial.SpatialPlot`.
        """
        super().__init__(**kwargs)
        self.data = normalize_data(data)
        self.col1 = col1
        self.col2 = col2
        self.vmin = vmin
        self.vmax = vmax
        self.ncolors = ncolors
        self.fact = fact
        self.cmap = cmap

        _update_history(self.data, "Initialized monet-plots.SpatialBiasScatterPlot")

    def plot(self, **kwargs: Any) -> matplotlib.axes.Axes:
        """Generate a static publication-quality spatial bias scatter plot (Track A).

        Points are coloured by ``col2 - col1`` on a discrete scale and sized by
        ``|bias| / top * 100 * fact`` (capped at 300). Here ``top`` is the 95th
        percentile of ``|bias|`` rounded to the nearest integer. If rounding
        gives 0, the unrounded value is used instead. If the result is still
        not a positive finite number, 1.0 is used.

        Parameters
        ----------
        **kwargs : Any
            Keyword arguments passed to the axes' ``scatter`` method
            (``matplotlib.axes.Axes.scatter``).
            Map features (e.g., `coastlines=True`) can also be passed here.

        Returns
        -------
        matplotlib.axes.Axes
            The matplotlib axes object containing the plot.
        """
        # Separate feature kwargs from scatter kwargs
        scatter_kwargs = self.add_features(**kwargs)

        if isinstance(self.data, (xr.Dataset, xr.DataArray)):
            # Vectorized calculation using Xarray/Dask
            diff = self.data[self.col2] - self.data[self.col1]

            # 95th percentile of |bias|, computed lazily where possible
            try:
                top_val = np.abs(diff).quantile(0.95)
                if hasattr(top_val, "compute"):
                    top_val = top_val.compute()
                raw_top = float(top_val)
            except (ImportError, AttributeError, ValueError):
                # e.g. a dask array chunked along the reduced dimension
                raw_top = float(np.nanquantile(np.abs(diff.values), 0.95))

            # Coordinates may live in coords, data variables or dims.
            # DataArrays have no ``data_vars``; hence the getattr default.
            available = (
                set(self.data.coords)
                | set(getattr(self.data, "data_vars", ()))
                | set(self.data.dims)
            )
            lat_name = next(
                (c for c in ("latitude", "lat") if c in available), "lat"
            )
            lon_name = next(
                (c for c in ("longitude", "lon") if c in available), "lon"
            )

            # Compute only what's necessary for plotting
            plot_ds = xr.Dataset(
                {"diff": diff, "lat": self.data[lat_name], "lon": self.data[lon_name]}
            )

            # Drop NaNs before compute to minimize transfer
            if plot_ds.dims:
                plot_ds = plot_ds.dropna(dim=list(plot_ds.dims)[0])

            concrete = plot_ds.compute()
            diff_vals = concrete["diff"].values
            lat_vals = concrete["lat"].values
            lon_vals = concrete["lon"].values
        else:
            # Fallback for Pandas
            df = self.data.dropna(subset=[self.col1, self.col2])
            diff_vals = (df[self.col2] - df[self.col1]).values
            lat_name = next((c for c in ["latitude", "lat"] if c in df.columns), "lat")
            lon_name = next((c for c in ["longitude", "lon"] if c in df.columns), "lon")
            lat_vals = df[lat_name].values
            lon_vals = df[lon_name].values
            raw_top = float(np.nanquantile(np.abs(diff_vals), 0.95))

        # Round the scale as before, but never let it collapse to 0 or NaN:
        # a zero scale would divide by zero in the marker sizes and produce
        # a degenerate colour range.
        top = float(np.around(raw_top))
        if top == 0.0:
            top = raw_top
        if not np.isfinite(top) or top <= 0.0:
            top = 1.0

        # Honour user-supplied colour limits; default to a symmetric +/- top.
        vmin = -top if self.vmin is None else self.vmin
        vmax = top if self.vmax is None else self.vmax

        # Use scaling tools
        cmap, norm = get_discrete_scale(
            diff_vals, cmap=self.cmap, n_levels=self.ncolors, vmin=vmin, vmax=vmax
        )

        # Create colorbar
        mappable = plt.cm.ScalarMappable(norm=norm, cmap=cmap)
        cbar = self.add_colorbar(mappable, format="%1.2g")
        cbar.ax.tick_params(labelsize=10)

        # Marker area grows with |bias| relative to the scale, capped at 300
        ss = np.minimum(np.abs(diff_vals) / top * 100.0 * self.fact, 300.0)

        # Prepare scatter kwargs
        final_scatter_kwargs = get_plot_kwargs(
            cmap=cmap,
            norm=norm,
            s=ss,
            c=diff_vals,
            transform=ccrs.PlateCarree(),
            edgecolors="k",
            linewidths=0.25,
            alpha=0.7,
            **scatter_kwargs,
        )

        self.ax.scatter(
            lon_vals,
            lat_vals,
            **final_scatter_kwargs,
        )
        return self.ax

    def hvplot(self, **kwargs: Any) -> Any:
        """Generate an interactive spatial bias scatter plot using hvPlot (Track B).

        Parameters
        ----------
        **kwargs : Any
            Keyword arguments passed to `hvplot.points`.
            Common options include `cmap`, `title`, and `alpha`.
            `rasterize=True` is used by default for high performance. When
            both ``vmin`` and ``vmax`` were given at construction they are
            forwarded as ``clim`` (overridable via kwargs).

        Returns
        -------
        Any
            The interactive HoloViews object produced by `hvplot.points`;
            its concrete type depends on the options used (e.g. `rasterize`).

        Raises
        ------
        ImportError
            If hvplot is not installed.
        """
        try:
            import hvplot.pandas  # noqa: F401
            import hvplot.xarray  # noqa: F401
        except ImportError as err:
            raise ImportError(
                "hvplot is required for interactive plotting. Install it with 'pip install hvplot'."
            ) from err

        import pandas as pd

        is_frame = isinstance(self.data, pd.DataFrame)

        # Same coordinate search as ``plot``: columns for pandas, otherwise
        # coords, data variables (Datasets only) and dims.
        if is_frame:
            available = set(self.data.columns)
        else:
            available = (
                set(self.data.coords)
                | set(getattr(self.data, "data_vars", ()))
                | set(self.data.dims)
            )
        lat_name = next((c for c in ("latitude", "lat") if c in available), "lat")
        lon_name = next((c for c in ("longitude", "lon") if c in available), "lon")

        # Work on a copy so the user's data is not mutated
        plot_target = self.data.copy()
        plot_target["bias"] = plot_target[self.col2] - plot_target[self.col1]
        if not is_frame:
            _update_history(plot_target, "Calculated bias for hvplot")

        # Track B defaults
        plot_kwargs = {
            "x": lon_name,
            "y": lat_name,
            "c": "bias",
            "geo": True,
            "rasterize": True,
            "cmap": self.cmap,
        }
        if self.vmin is not None and self.vmax is not None:
            plot_kwargs["clim"] = (self.vmin, self.vmax)

        plot_kwargs.update(kwargs)

        return plot_target.hvplot.points(**plot_kwargs)

__init__(data, col1, col2, vmin=None, vmax=None, ncolors=15, fact=1.5, cmap='RdBu_r', **kwargs)

Initialize the plot with data and map projection.

Parameters

- `data` : Any — Input data. Preferred format is xarray.Dataset or xarray.DataArray with 'latitude' and 'longitude' (or 'lat' and 'lon') coordinates.
- `col1` : str — Name of the first variable (e.g., observations).
- `col2` : str — Name of the second variable (e.g., model). Bias is calculated as col2 - col1.
- `vmin` : float, optional — Minimum for colorscale, by default None.
- `vmax` : float, optional — Maximum for colorscale, by default None.
- `ncolors` : int, optional — Number of discrete colors, by default 15.
- `fact` : float, optional — Scaling factor for point sizes, by default 1.5.
- `cmap` : str, optional — Colormap for bias values, by default "RdBu_r".
- `**kwargs` : Any — Additional keyword arguments for map creation, passed to :class:`monet_plots.plots.spatial.SpatialPlot`.

Source code in src/monet_plots/plots/spatial_bias_scatter.py
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
def __init__(
    self,
    data: Any,
    col1: str,
    col2: str,
    vmin: float | None = None,
    vmax: float | None = None,
    ncolors: int = 15,
    fact: float = 1.5,
    cmap: str = "RdBu_r",
    **kwargs: Any,
) -> None:
    """Set up a spatial bias scatter plot on a new map canvas.

    Parameters
    ----------
    data : Any
        Input data, normalized via ``normalize_data``. An xarray.Dataset or
        xarray.DataArray carrying 'latitude'/'longitude' (or 'lat'/'lon')
        is the preferred form.
    col1 : str
        First variable, typically the observations.
    col2 : str
        Second variable, typically the model; bias is ``col2 - col1``.
    vmin : float, optional
        Lower colour-scale bound; None by default.
    vmax : float, optional
        Upper colour-scale bound; None by default.
    ncolors : int, optional
        How many discrete colours to use; 15 by default.
    fact : float, optional
        Multiplier applied to marker sizes; 1.5 by default.
    cmap : str, optional
        Colormap name for the bias values; "RdBu_r" by default.
    **kwargs : Any
        Forwarded unchanged to :class:`monet_plots.plots.spatial.SpatialPlot`
        for map/projection setup.
    """
    # The base class owns figure, axes and projection creation.
    super().__init__(**kwargs)

    # Canonicalise the input container before anything else touches it.
    self.data = normalize_data(data)

    # Variable names (bias is evaluated as col2 - col1).
    self.col1, self.col2 = col1, col2

    # Colour-scale and marker-size configuration.
    self.vmin, self.vmax = vmin, vmax
    self.ncolors, self.fact, self.cmap = ncolors, fact, cmap

    # Provenance record on the normalized data.
    _update_history(self.data, "Initialized monet-plots.SpatialBiasScatterPlot")

hvplot(**kwargs)

Generate an interactive spatial bias scatter plot using hvPlot (Track B).

Parameters

`**kwargs` : Any — Keyword arguments passed to `hvplot.points`. Common options include `cmap`, `title`, and `alpha`. `rasterize=True` is used by default for high performance.

Returns

The interactive HoloViews object produced by `hvplot.points`; its concrete type depends on the options used (e.g. `rasterize`).

Source code in src/monet_plots/plots/spatial_bias_scatter.py
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
def hvplot(self, **kwargs: Any) -> Any:
    """Generate an interactive spatial bias scatter plot using hvPlot (Track B).

    Parameters
    ----------
    **kwargs : Any
        Keyword arguments passed to `hvplot.points`.
        Common options include `cmap`, `title`, and `alpha`.
        `rasterize=True` is used by default for high performance. When both
        ``self.vmin`` and ``self.vmax`` are set they are forwarded as
        ``clim`` (overridable via kwargs).

    Returns
    -------
    Any
        The interactive HoloViews object produced by `hvplot.points`; its
        concrete type depends on the options used (e.g. `rasterize`).

    Raises
    ------
    ImportError
        If hvplot is not installed.
    """
    try:
        import hvplot.pandas  # noqa: F401
        import hvplot.xarray  # noqa: F401
    except ImportError as err:
        raise ImportError(
            "hvplot is required for interactive plotting. Install it with 'pip install hvplot'."
        ) from err

    import pandas as pd

    is_frame = isinstance(self.data, pd.DataFrame)

    # Search the same places as ``plot``: columns for pandas, otherwise
    # coords, data variables (Datasets only) and dims. Previously data
    # variables were ignored here, so lat/lon stored as variables fell back
    # to the default names.
    if is_frame:
        available = set(self.data.columns)
    else:
        available = (
            set(self.data.coords)
            | set(getattr(self.data, "data_vars", ()))
            | set(self.data.dims)
        )
    lat_name = next((c for c in ("latitude", "lat") if c in available), "lat")
    lon_name = next((c for c in ("longitude", "lon") if c in available), "lon")

    # Work on a copy so the user's data is not mutated
    plot_target = self.data.copy()
    plot_target["bias"] = plot_target[self.col2] - plot_target[self.col1]
    if not is_frame:
        _update_history(plot_target, "Calculated bias for hvplot")

    # Track B defaults
    plot_kwargs = {
        "x": lon_name,
        "y": lat_name,
        "c": "bias",
        "geo": True,
        "rasterize": True,
        "cmap": self.cmap,
    }
    if self.vmin is not None and self.vmax is not None:
        plot_kwargs["clim"] = (self.vmin, self.vmax)

    plot_kwargs.update(kwargs)

    return plot_target.hvplot.points(**plot_kwargs)

plot(**kwargs)

Generate a static publication-quality spatial bias scatter plot (Track A).

Parameters

`**kwargs` : Any — Keyword arguments passed to the axes' `scatter` method (`matplotlib.axes.Axes.scatter`). Map features (e.g., `coastlines=True`) can also be passed here.

Returns

`matplotlib.axes.Axes` — The matplotlib axes object containing the plot.

Source code in src/monet_plots/plots/spatial_bias_scatter.py
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
def plot(self, **kwargs: Any) -> matplotlib.axes.Axes:
    """Generate a static publication-quality spatial bias scatter plot (Track A).

    Points are coloured by ``col2 - col1`` on a discrete scale and sized by
    ``|bias| / top * 100 * fact`` (capped at 300). Here ``top`` is the 95th
    percentile of ``|bias|`` rounded to the nearest integer. If rounding
    gives 0, the unrounded value is used instead. If the result is still not
    a positive finite number, 1.0 is used. ``self.vmin`` / ``self.vmax``,
    when set, override the default symmetric colour limits ``-top``/``top``.

    Parameters
    ----------
    **kwargs : Any
        Keyword arguments passed to the axes' ``scatter`` method
        (``matplotlib.axes.Axes.scatter``).
        Map features (e.g., `coastlines=True`) can also be passed here.

    Returns
    -------
    matplotlib.axes.Axes
        The matplotlib axes object containing the plot.
    """
    # Separate feature kwargs from scatter kwargs
    scatter_kwargs = self.add_features(**kwargs)

    if isinstance(self.data, (xr.Dataset, xr.DataArray)):
        # Vectorized calculation using Xarray/Dask
        diff = self.data[self.col2] - self.data[self.col1]

        # 95th percentile of |bias|, computed lazily where possible
        try:
            top_val = np.abs(diff).quantile(0.95)
            if hasattr(top_val, "compute"):
                top_val = top_val.compute()
            raw_top = float(top_val)
        except (ImportError, AttributeError, ValueError):
            # e.g. a dask array chunked along the reduced dimension
            raw_top = float(np.nanquantile(np.abs(diff.values), 0.95))

        # Coordinates may live in coords, data variables or dims.
        # DataArrays have no ``data_vars``; hence the getattr default.
        available = (
            set(self.data.coords)
            | set(getattr(self.data, "data_vars", ()))
            | set(self.data.dims)
        )
        lat_name = next((c for c in ("latitude", "lat") if c in available), "lat")
        lon_name = next((c for c in ("longitude", "lon") if c in available), "lon")

        # Compute only what's necessary for plotting
        plot_ds = xr.Dataset(
            {"diff": diff, "lat": self.data[lat_name], "lon": self.data[lon_name]}
        )

        # Drop NaNs before compute to minimize transfer
        if plot_ds.dims:
            plot_ds = plot_ds.dropna(dim=list(plot_ds.dims)[0])

        concrete = plot_ds.compute()
        diff_vals = concrete["diff"].values
        lat_vals = concrete["lat"].values
        lon_vals = concrete["lon"].values
    else:
        # Fallback for Pandas
        df = self.data.dropna(subset=[self.col1, self.col2])
        diff_vals = (df[self.col2] - df[self.col1]).values
        lat_name = next((c for c in ["latitude", "lat"] if c in df.columns), "lat")
        lon_name = next((c for c in ["longitude", "lon"] if c in df.columns), "lon")
        lat_vals = df[lat_name].values
        lon_vals = df[lon_name].values
        raw_top = float(np.nanquantile(np.abs(diff_vals), 0.95))

    # Round the scale as before, but never let it collapse to 0 or NaN:
    # a zero scale would divide by zero in the marker sizes and produce a
    # degenerate colour range.
    top = float(np.around(raw_top))
    if top == 0.0:
        top = raw_top
    if not np.isfinite(top) or top <= 0.0:
        top = 1.0

    # Honour user-supplied colour limits; default to a symmetric +/- top.
    vmin = -top if self.vmin is None else self.vmin
    vmax = top if self.vmax is None else self.vmax

    # Use scaling tools
    cmap, norm = get_discrete_scale(
        diff_vals, cmap=self.cmap, n_levels=self.ncolors, vmin=vmin, vmax=vmax
    )

    # Create colorbar
    mappable = plt.cm.ScalarMappable(norm=norm, cmap=cmap)
    cbar = self.add_colorbar(mappable, format="%1.2g")
    cbar.ax.tick_params(labelsize=10)

    # Marker area grows with |bias| relative to the scale, capped at 300
    ss = np.minimum(np.abs(diff_vals) / top * 100.0 * self.fact, 300.0)

    # Prepare scatter kwargs
    final_scatter_kwargs = get_plot_kwargs(
        cmap=cmap,
        norm=norm,
        s=ss,
        c=diff_vals,
        transform=ccrs.PlateCarree(),
        edgecolors="k",
        linewidths=0.25,
        alpha=0.7,
        **scatter_kwargs,
    )

    self.ax.scatter(
        lon_vals,
        lat_vals,
        **final_scatter_kwargs,
    )
    return self.ax

SpatialContourPlot

Bases: SpatialPlot

Create a contour plot on a map with an optional discrete colorbar.

This class provides an xarray-native interface for visualizing spatial data with continuous values. It supports both Track A (publication-quality static plots) and Track B (interactive exploration).

Source code in src/monet_plots/plots/spatial_contour.py
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
class SpatialContourPlot(SpatialPlot):
    """Create a contour plot on a map with an optional discrete colorbar.

    This class provides an xarray-native interface for visualizing spatial
    data with continuous values. It supports both Track A (publication-quality
    static plots) and Track B (interactive exploration).
    """

    def __new__(
        cls,
        modelvar: Any,
        gridobj: Any | None = None,
        date: Any | None = None,
        **kwargs: Any,
    ) -> Any:
        """Redirect to SpatialFacetGridPlot if faceting is requested.

        This enables a unified interface for both single-panel and multi-panel
        spatial plots, following Xarray's plotting conventions.

        Parameters
        ----------
        modelvar : Any
            The input data to contour.
        gridobj : Any, optional
            Object with LAT and LON variables, by default None.
        date : Any, optional
            Date/time for the plot title, by default None.
        **kwargs : Any
            Additional keyword arguments. If faceting arguments (e.g., `col`,
            `row`, or `col_wrap`) are provided, redirects to `SpatialFacetGridPlot`.

        Returns
        -------
        Any
            An instance of SpatialContourPlot or SpatialFacetGridPlot.
        """
        from .facet_grid import SpatialFacetGridPlot

        ax = kwargs.get("ax")

        # Aligns with Xarray's trigger for faceting
        facet_kwargs = ["col", "row", "col_wrap"]
        is_faceting = any(kwargs.get(k) is not None for k in facet_kwargs)

        # Redirect to FacetGrid if faceting requested and no existing axes
        if ax is None and is_faceting:
            return SpatialFacetGridPlot(modelvar, **kwargs)

        # Also redirect if input is a Dataset with multiple variables
        if (
            ax is None
            and isinstance(modelvar, xr.Dataset)
            and len(modelvar.data_vars) > 1
        ):
            # Default to faceting by variable if not specified
            kwargs.setdefault("col", "variable")
            return SpatialFacetGridPlot(modelvar, **kwargs)

        return super().__new__(cls)

    def __init__(
        self,
        modelvar: Any,
        gridobj: Any | None = None,
        date: datetime | None = None,
        discrete: bool = True,
        ncolors: int | None = None,
        dtype: str = "int",
        col: str | None = None,
        row: str | None = None,
        col_wrap: int | None = None,
        size: float | None = None,
        aspect: float | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the spatial contour plot.

        Parameters
        ----------
        modelvar : Any
            The input data to contour. Preferred format is an xarray DataArray.
        gridobj : Any, optional
            Object with LAT and LON variables to determine extent, by default None.
        date : datetime, optional
            Date/time for the plot title, by default None.
        discrete : bool, optional
            If True, use a discrete colorbar, by default True.
        ncolors : int, optional
            Number of discrete colors for the colorbar, by default None.
        dtype : str, optional
            Data type for colorbar tick labels, by default "int".
        col : str, optional
            Dimension name to facet by columns. Aligns with Xarray.
        row : str, optional
            Dimension name to facet by rows. Aligns with Xarray.
        col_wrap : int, optional
            Number of columns before wrapping. Aligns with Xarray.
        size : float, optional
            Height (in inches) of each facet. Aligns with Xarray.
        aspect : float, optional
            Aspect ratio of each facet. Aligns with Xarray.
        **kwargs : Any
            Keyword arguments passed to :class:`SpatialPlot` for map features
            and projection.
        """
        # Initialize the map canvas via SpatialPlot
        super().__init__(**kwargs)

        # Standardize data to Xarray for consistency and lazy evaluation
        self.modelvar = normalize_data(modelvar)
        if isinstance(self.modelvar, xr.Dataset) and len(self.modelvar.data_vars) == 1:
            self.modelvar = self.modelvar[list(self.modelvar.data_vars)[0]]

        self.gridobj = gridobj
        self.date = date
        self.discrete = discrete
        self.ncolors = ncolors
        self.dtype = dtype

        # Aero Protocol: Centralized coordinate identification
        try:
            self.lon_coord, self.lat_coord = self._identify_coords(self.modelvar)
        except ValueError:
            self.lon_coord = kwargs.get("lon_coord", "lon")
            self.lat_coord = kwargs.get("lat_coord", "lat")

        # Ensure coordinates are monotonic for correct plotting
        self.modelvar = self._ensure_monotonic(
            self.modelvar, self.lon_coord, self.lat_coord
        )

        _update_history(self.modelvar, "Initialized monet-plots.SpatialContourPlot")

    def plot(self, **kwargs: Any) -> Axes:
        """Generate a static publication-quality spatial contour plot (Track A).

        Parameters
        ----------
        **kwargs : Any
            Keyword arguments passed to `matplotlib.pyplot.contourf`.
            Common options include `cmap`, `levels`, `vmin`, `vmax`, and `alpha`.
            Map features (e.g., `coastlines=True`) can also be passed here.

        Returns
        -------
        matplotlib.axes.Axes
            The matplotlib axes object containing the plot.
        """
        lat = None
        lon = None

        if hasattr(self.gridobj, "variables"):
            # Handle legacy gridobj
            try:
                lat_var = self.gridobj.variables["LAT"]
                lon_var = self.gridobj.variables["LON"]

                # Flexible indexing based on dimension count
                if lat_var.ndim == 4:
                    lat = lat_var[0, 0, :, :].squeeze()
                    lon = lon_var[0, 0, :, :].squeeze()
                elif lat_var.ndim == 3:
                    lat = lat_var[0, :, :].squeeze()
                    lon = lon_var[0, :, :].squeeze()
                else:
                    lat = lat_var.squeeze()
                    lon = lon_var.squeeze()
            except (AttributeError, KeyError):
                pass
        elif self.gridobj is not None:
            # Assume it's already an array or similar
            try:
                lat = self.gridobj.LAT
                lon = self.gridobj.LON
            except AttributeError:
                pass

        # Automatically compute extent if not provided
        if "extent" not in kwargs:
            if lon is not None and lat is not None:
                kwargs["extent"] = [
                    float(lon.min()),
                    float(lon.max()),
                    float(lat.min()),
                    float(lat.max()),
                ]
            else:
                kwargs["extent"] = self._get_extent_from_data(
                    self.modelvar, self.lon_coord, self.lat_coord
                )

        # Draw map features and get remaining kwargs for contourf
        plot_kwargs = self.add_features(**kwargs)

        # Set default contour settings
        plot_kwargs.setdefault("cmap", "viridis")

        # Data is in lat/lon, so specify transform
        plot_kwargs.setdefault("transform", ccrs.PlateCarree())

        # For coordinates and values, we prefer passing the xarray objects directly.
        # This allows Matplotlib to handle the conversion and maintains
        # parity for Dask-backed arrays.
        if lon is None or lat is None:
            longitude = self.modelvar[self.lon_coord]
            latitude = self.modelvar[self.lat_coord]
        else:
            longitude = lon
            latitude = lat

        # Handle colormap and normalization
        final_kwargs = get_plot_kwargs(**plot_kwargs)

        mesh = self.ax.contourf(longitude, latitude, self.modelvar, **final_kwargs)

        # Handle colorbar
        if self.discrete:
            cmap = final_kwargs.get("cmap")
            levels = final_kwargs.get("levels")
            ncolors = self.ncolors
            if ncolors is None and levels is not None:
                if isinstance(levels, int):
                    ncolors = levels - 1
                    # Use a single compute call for efficiency (Aero Protocol)
                    try:
                        import dask

                        dmin, dmax = dask.compute(
                            self.modelvar.min(), self.modelvar.max()
                        )
                    except (ImportError, AttributeError):
                        dmin, dmax = self.modelvar.min(), self.modelvar.max()

                    # Handle pandas DataFrame where .min() returns a Series
                    try:
                        fmin, fmax = float(dmin), float(dmax)
                    except (TypeError, ValueError):
                        # dmin/dmax could be Series if modelvar was a DataFrame
                        fmin = (
                            float(dmin.min())
                            if hasattr(dmin, "min")
                            else float(np.min(dmin))
                        )
                        fmax = (
                            float(dmax.max())
                            if hasattr(dmax, "max")
                            else float(np.max(dmax))
                        )
                    levels_seq = np.linspace(fmin, fmax, levels)
                else:
                    ncolors = len(levels) - 1
                    levels_seq = levels
            else:
                levels_seq = levels

            if levels_seq is None:
                # Fallback: calculate from data to ensure a discrete colorbar
                # if requested but no levels were provided.
                try:
                    import dask

                    dmin, dmax = dask.compute(self.modelvar.min(), self.modelvar.max())
                except (ImportError, AttributeError):
                    dmin, dmax = self.modelvar.min(), self.modelvar.max()

                # Handle pandas DataFrame where .min() returns a Series
                try:
                    fmin, fmax = float(dmin), float(dmax)
                except (TypeError, ValueError):
                    fmin = (
                        float(dmin.min())
                        if hasattr(dmin, "min")
                        else float(np.min(dmin))
                    )
                    fmax = (
                        float(dmax.max())
                        if hasattr(dmax, "max")
                        else float(np.max(dmax))
                    )
                n_lev = self.ncolors if self.ncolors is not None else 10
                levels_seq = np.linspace(fmin, fmax, n_lev + 1)
                ncolors = n_lev

            if levels_seq is not None:
                colorbar_index(
                    ncolors,
                    cmap,
                    minval=levels_seq[0],
                    maxval=levels_seq[-1],
                    dtype=self.dtype,
                    ax=self.ax,
                )
        else:
            self.add_colorbar(mesh)

        if self.date:
            titstring = self.date.strftime("%B %d %Y %H")
            self.ax.set_title(titstring)

        self.fig.tight_layout()
        return self.ax

    def hvplot(self, **kwargs: Any) -> Any:
        """Generate an interactive spatial contour plot using hvPlot (Track B).

        Parameters
        ----------
        **kwargs : Any
            Keyword arguments forwarded to ``self.modelvar.hvplot``. They
            override the Track B defaults set below: ``x``/``y`` are the
            identified lon/lat coordinate names, ``geo=True``,
            ``rasterize=True``, ``cmap="viridis"`` and ``kind="contourf"``.
            Common options include `cmap`, `levels`, `title`, and `alpha`.

        Returns
        -------
        holoviews.core.layout.Layout
            The interactive hvPlot object.

        Raises
        ------
        ImportError
            If ``hvplot.xarray`` cannot be imported. The underlying import
            error is chained as ``__cause__``.
        """
        try:
            import hvplot.xarray  # noqa: F401
        except ImportError as err:
            # Chain explicitly so the traceback shows the real import failure
            # as the direct cause rather than as an error raised "during handling".
            raise ImportError(
                "hvplot is required for interactive plotting. Install it with 'pip install hvplot'."
            ) from err

        # Track B defaults; caller-supplied kwargs take precedence.
        plot_kwargs = {
            "x": self.lon_coord,
            "y": self.lat_coord,
            "geo": True,
            "rasterize": True,
            "cmap": "viridis",
            "kind": "contourf",
        }
        plot_kwargs.update(kwargs)

        return self.modelvar.hvplot(**plot_kwargs)

__init__(modelvar, gridobj=None, date=None, discrete=True, ncolors=None, dtype='int', col=None, row=None, col_wrap=None, size=None, aspect=None, **kwargs)

Initialize the spatial contour plot.

Parameters

modelvar : Any The input data to contour. Preferred format is an xarray DataArray. gridobj : Any, optional Object with LAT and LON variables to determine extent, by default None. date : datetime, optional Date/time for the plot title, by default None. discrete : bool, optional If True, use a discrete colorbar, by default True. ncolors : int, optional Number of discrete colors for the colorbar, by default None. dtype : str, optional Data type for colorbar tick labels, by default "int". col : str, optional Dimension name to facet by columns. Aligns with Xarray. row : str, optional Dimension name to facet by rows. Aligns with Xarray. col_wrap : int, optional Number of columns before wrapping. Aligns with Xarray. size : float, optional Height (in inches) of each facet. Aligns with Xarray. aspect : float, optional Aspect ratio of each facet. Aligns with Xarray. **kwargs : Any Keyword arguments passed to SpatialPlot for map features and projection.

Source code in src/monet_plots/plots/spatial_contour.py
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
def __init__(
    self,
    modelvar: Any,
    gridobj: Any | None = None,
    date: datetime | None = None,
    discrete: bool = True,
    ncolors: int | None = None,
    dtype: str = "int",
    col: str | None = None,
    row: str | None = None,
    col_wrap: int | None = None,
    size: float | None = None,
    aspect: float | None = None,
    **kwargs: Any,
) -> None:
    """Set up a single-panel spatial contour plot.

    Parameters
    ----------
    modelvar : Any
        Data to contour; an xarray DataArray is the preferred input. After
        normalization, a Dataset holding exactly one data variable is reduced
        to that variable.
    gridobj : Any, optional
        Object exposing LAT and LON variables used to derive the map extent.
        Defaults to None.
    date : datetime, optional
        Timestamp rendered as the plot title. Defaults to None.
    discrete : bool, optional
        Draw a discrete colorbar when True (the default).
    ncolors : int, optional
        Number of discrete colorbar colors. Defaults to None.
    dtype : str, optional
        Type used for colorbar tick labels. Defaults to "int".
    col : str, optional
        Dimension name to facet by columns (xarray convention).
    row : str, optional
        Dimension name to facet by rows (xarray convention).
    col_wrap : int, optional
        Number of columns before wrapping (xarray convention).
    size : float, optional
        Height in inches of each facet (xarray convention).
    aspect : float, optional
        Aspect ratio of each facet (xarray convention).
    **kwargs : Any
        Forwarded to :class:`SpatialPlot` for map features and projection.
        ``lon_coord`` / ``lat_coord`` entries are also read as fallback
        coordinate names when automatic identification raises ``ValueError``.

    Notes
    -----
    The faceting arguments are named parameters so that they are not
    forwarded to :class:`SpatialPlot`; this initializer does not otherwise
    use them.
    """
    # The map canvas comes from SpatialPlot and must exist before anything else.
    super().__init__(**kwargs)

    data = normalize_data(modelvar)
    if isinstance(data, xr.Dataset) and len(data.data_vars) == 1:
        # Unwrap a one-variable Dataset into its DataArray.
        data = data[next(iter(data.data_vars))]
    self.modelvar = data

    self.gridobj = gridobj
    self.date = date
    self.discrete = discrete
    self.ncolors = ncolors
    self.dtype = dtype

    try:
        lon_name, lat_name = self._identify_coords(self.modelvar)
    except ValueError:
        # Identification failed: use caller-supplied names or the usual defaults.
        lon_name = kwargs.get("lon_coord", "lon")
        lat_name = kwargs.get("lat_coord", "lat")
    self.lon_coord = lon_name
    self.lat_coord = lat_name

    # Make the coordinates monotonic (via _ensure_monotonic) for correct plotting.
    self.modelvar = self._ensure_monotonic(
        self.modelvar, self.lon_coord, self.lat_coord
    )

    _update_history(self.modelvar, "Initialized monet-plots.SpatialContourPlot")

__new__(modelvar, gridobj=None, date=None, **kwargs)

Redirect to SpatialFacetGridPlot if faceting is requested.

This enables a unified interface for both single-panel and multi-panel spatial plots, following Xarray's plotting conventions.

Parameters

modelvar : Any The input data to contour. gridobj : Any, optional Object with LAT and LON variables, by default None. date : Any, optional Date/time for the plot title, by default None. **kwargs : Any Additional keyword arguments. If faceting arguments (e.g., col, row, or col_wrap) are provided, redirects to SpatialFacetGridPlot.

Returns

Any An instance of SpatialContourPlot or SpatialFacetGridPlot.

Source code in src/monet_plots/plots/spatial_contour.py
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
def __new__(
    cls,
    modelvar: Any,
    gridobj: Any | None = None,
    date: Any | None = None,
    **kwargs: Any,
) -> Any:
    """Return a facet-grid plot when faceting applies, else a plain instance.

    This gives single-panel and multi-panel spatial plots one entry point,
    mirroring xarray's plotting conventions.

    Parameters
    ----------
    modelvar : Any
        The input data to contour.
    gridobj : Any, optional
        Object with LAT and LON variables, by default None.
    date : Any, optional
        Date/time for the plot title, by default None.
    **kwargs : Any
        Additional keyword arguments. When no ``ax`` is given and any of
        `col`, `row` or `col_wrap` is not None, a `SpatialFacetGridPlot` is
        built instead. The same happens, with ``col`` defaulting to
        ``"variable"``, for a multi-variable xarray Dataset.

    Returns
    -------
    Any
        An instance of SpatialContourPlot or SpatialFacetGridPlot.

    Notes
    -----
    On redirect only ``modelvar`` and ``kwargs`` are passed to
    `SpatialFacetGridPlot`; ``gridobj`` and ``date`` are not forwarded.
    """
    from .facet_grid import SpatialFacetGridPlot

    # An explicit target axes always means a single panel.
    if kwargs.get("ax") is None:
        wants_facets = any(
            kwargs.get(name) is not None for name in ("col", "row", "col_wrap")
        )
        if wants_facets:
            return SpatialFacetGridPlot(modelvar, **kwargs)

        if isinstance(modelvar, xr.Dataset) and len(modelvar.data_vars) > 1:
            # Several variables and no facet dimension given: one panel per variable.
            kwargs.setdefault("col", "variable")
            return SpatialFacetGridPlot(modelvar, **kwargs)

    return super().__new__(cls)

hvplot(**kwargs)

Generate an interactive spatial contour plot using hvPlot (Track B).

Parameters

**kwargs : Any Keyword arguments passed to hvplot.contourf. Common options include cmap, levels, title, and alpha.

Returns

holoviews.core.layout.Layout The interactive hvPlot object.

Source code in src/monet_plots/plots/spatial_contour.py
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
def hvplot(self, **kwargs: Any) -> Any:
    """Generate an interactive spatial contour plot using hvPlot (Track B).

    Parameters
    ----------
    **kwargs : Any
        Keyword arguments forwarded to ``self.modelvar.hvplot``. They
        override the Track B defaults set below: ``x``/``y`` are the
        identified lon/lat coordinate names, ``geo=True``,
        ``rasterize=True``, ``cmap="viridis"`` and ``kind="contourf"``.
        Common options include `cmap`, `levels`, `title`, and `alpha`.

    Returns
    -------
    holoviews.core.layout.Layout
        The interactive hvPlot object.

    Raises
    ------
    ImportError
        If ``hvplot.xarray`` cannot be imported. The underlying import
        error is chained as ``__cause__``.
    """
    try:
        import hvplot.xarray  # noqa: F401
    except ImportError as err:
        # Chain explicitly so the traceback shows the real import failure
        # as the direct cause rather than as an error raised "during handling".
        raise ImportError(
            "hvplot is required for interactive plotting. Install it with 'pip install hvplot'."
        ) from err

    # Track B defaults; caller-supplied kwargs take precedence.
    plot_kwargs = {
        "x": self.lon_coord,
        "y": self.lat_coord,
        "geo": True,
        "rasterize": True,
        "cmap": "viridis",
        "kind": "contourf",
    }
    plot_kwargs.update(kwargs)

    return self.modelvar.hvplot(**plot_kwargs)

plot(**kwargs)

Generate a static publication-quality spatial contour plot (Track A).

Parameters

**kwargs : Any Keyword arguments passed to the axes' contourf method (matplotlib.axes.Axes.contourf). Common options include cmap, levels, vmin, vmax, and alpha. Map features (e.g., coastlines=True) can also be passed here.

Returns

matplotlib.axes.Axes The matplotlib axes object containing the plot.

Source code in src/monet_plots/plots/spatial_contour.py
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
def plot(self, **kwargs: Any) -> Axes:
    """Generate a static publication-quality spatial contour plot (Track A).

    Parameters
    ----------
    **kwargs : Any
        Map-feature options (e.g. ``coastlines=True``, ``extent``) are handled
        by ``self.add_features``. The keywords it returns are forwarded to
        ``self.ax.contourf``. Common contour options include `cmap`,
        `levels`, `vmin`, `vmax`, and `alpha`.

    Returns
    -------
    matplotlib.axes.Axes
        The matplotlib axes object containing the plot.

    Notes
    -----
    With ``discrete=True`` the colorbar is drawn by ``colorbar_index`` from a
    level sequence resolved as follows:

    * integer ``levels`` -> that many evenly spaced edges across the data
      min/max; ``ncolors`` defaults to ``levels - 1``;
    * sequence ``levels`` -> used as-is; ``ncolors`` defaults to
      ``len(levels) - 1``;
    * no ``levels`` -> ``ncolors + 1`` edges (``ncolors`` defaults to 10)
      across the data min/max.
    """
    lat = None
    lon = None

    if hasattr(self.gridobj, "variables"):
        # Legacy gridobj exposing a netCDF-style ``variables`` mapping.
        try:
            lat_var = self.gridobj.variables["LAT"]
            lon_var = self.gridobj.variables["LON"]

            # Take the first slice of any leading dimensions to get a 2D grid.
            if lat_var.ndim == 4:
                lat = lat_var[0, 0, :, :].squeeze()
                lon = lon_var[0, 0, :, :].squeeze()
            elif lat_var.ndim == 3:
                lat = lat_var[0, :, :].squeeze()
                lon = lon_var[0, :, :].squeeze()
            else:
                lat = lat_var.squeeze()
                lon = lon_var.squeeze()
        except (AttributeError, KeyError):
            # Best effort: fall back to the data's own coordinates below.
            pass
    elif self.gridobj is not None:
        # Object with LAT/LON attributes (array-like).
        try:
            lat = self.gridobj.LAT
            lon = self.gridobj.LON
        except AttributeError:
            pass

    # Automatically compute extent if not provided
    if "extent" not in kwargs:
        if lon is not None and lat is not None:
            kwargs["extent"] = [
                float(lon.min()),
                float(lon.max()),
                float(lat.min()),
                float(lat.max()),
            ]
        else:
            kwargs["extent"] = self._get_extent_from_data(
                self.modelvar, self.lon_coord, self.lat_coord
            )

    # Draw map features and get remaining kwargs for contourf
    plot_kwargs = self.add_features(**kwargs)

    plot_kwargs.setdefault("cmap", "viridis")
    # Data is in lat/lon, so specify transform
    plot_kwargs.setdefault("transform", ccrs.PlateCarree())

    # Pass xarray objects straight to Matplotlib so it handles the conversion
    # (keeps parity for Dask-backed arrays).
    if lon is None or lat is None:
        longitude = self.modelvar[self.lon_coord]
        latitude = self.modelvar[self.lat_coord]
    else:
        longitude = lon
        latitude = lat

    # Handle colormap and normalization
    final_kwargs = get_plot_kwargs(**plot_kwargs)

    mesh = self.ax.contourf(longitude, latitude, self.modelvar, **final_kwargs)

    if self.discrete:
        cmap = final_kwargs.get("cmap")
        levels = final_kwargs.get("levels")
        ncolors = self.ncolors

        if isinstance(levels, (int, np.integer)):
            # An integer ``levels`` is a count, not edges. Always expand it to
            # an edge array. Previously this was skipped when ``ncolors`` was
            # set, so ``levels_seq[0]`` raised TypeError on a plain int.
            n_edges = int(levels)
            fmin, fmax = _contour_data_range(self.modelvar)
            levels_seq = np.linspace(fmin, fmax, n_edges)
            if ncolors is None:
                ncolors = n_edges - 1
        elif levels is not None:
            levels_seq = levels
            if ncolors is None:
                ncolors = len(levels) - 1
        else:
            # No levels given: derive evenly spaced edges from the data so a
            # discrete colorbar can still be drawn.
            fmin, fmax = _contour_data_range(self.modelvar)
            n_lev = ncolors if ncolors is not None else 10
            levels_seq = np.linspace(fmin, fmax, n_lev + 1)
            ncolors = n_lev

        colorbar_index(
            ncolors,
            cmap,
            minval=levels_seq[0],
            maxval=levels_seq[-1],
            dtype=self.dtype,
            ax=self.ax,
        )
    else:
        self.add_colorbar(mesh)

    if self.date:
        titstring = self.date.strftime("%B %d %Y %H")
        self.ax.set_title(titstring)

    self.fig.tight_layout()
    return self.ax


def _contour_data_range(data: Any) -> tuple[float, float]:
    """Return ``(min, max)`` of *data* as Python floats.

    When dask is importable, both reductions are evaluated together in one
    ``dask.compute`` call. If that raises ImportError/AttributeError, the
    reductions are used directly. If they cannot be converted with
    ``float`` (e.g. a pandas Series from DataFrame input), they are
    collapsed to scalars first.
    """
    try:
        import dask

        dmin, dmax = dask.compute(data.min(), data.max())
    except (ImportError, AttributeError):
        dmin, dmax = data.min(), data.max()

    try:
        return float(dmin), float(dmax)
    except (TypeError, ValueError):
        fmin = float(dmin.min()) if hasattr(dmin, "min") else float(np.min(dmin))
        fmax = float(dmax.max()) if hasattr(dmax, "max") else float(np.max(dmax))
        return fmin, fmax

SpatialImshowPlot

Bases: SpatialPlot

Create a basic spatial plot using imshow.

This class provides an xarray-native interface for visualizing 2D model data on a map. It supports both Track A (publication-quality static plots) and Track B (interactive exploration).

Source code in src/monet_plots/plots/spatial_imshow.py
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
class SpatialImshowPlot(SpatialPlot):
    """Create a basic spatial plot using imshow.

    This class provides an xarray-native interface for visualizing 2D model
    data on a map. It supports both Track A (publication-quality static plots)
    and Track B (interactive exploration).
    """

    def __new__(
        cls,
        modelvar: Any,
        gridobj: Any | None = None,
        plotargs: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> Any:
        """Redirect to SpatialFacetGridPlot if faceting is requested.

        This enables a unified interface for both single-panel and multi-panel
        spatial plots, following Xarray's plotting conventions.

        Parameters
        ----------
        modelvar : Any
            The input data to plot.
        gridobj : Any, optional
            Object with LAT and LON variables, by default None.
        plotargs : dict, optional
            Arguments for imshow, by default None.
        **kwargs : Any
            Additional keyword arguments. If faceting arguments (e.g., `col`,
            `row`, or `col_wrap`) are provided, redirects to `SpatialFacetGridPlot`.

        Returns
        -------
        Any
            An instance of SpatialImshowPlot or SpatialFacetGridPlot.

        Notes
        -----
        On redirect only ``modelvar`` and ``kwargs`` are passed to
        `SpatialFacetGridPlot`; ``gridobj`` and ``plotargs`` are not forwarded.
        """
        from .facet_grid import SpatialFacetGridPlot

        ax = kwargs.get("ax")

        # Aligns with Xarray's trigger for faceting
        facet_kwargs = ["col", "row", "col_wrap"]
        is_faceting = any(kwargs.get(k) is not None for k in facet_kwargs)

        # Redirect to FacetGrid if faceting requested and no existing axes
        if ax is None and is_faceting:
            return SpatialFacetGridPlot(modelvar, **kwargs)

        # Also redirect if input is a Dataset with multiple variables
        if (
            ax is None
            and isinstance(modelvar, xr.Dataset)
            and len(modelvar.data_vars) > 1
        ):
            # Default to faceting by variable if not specified
            kwargs.setdefault("col", "variable")
            return SpatialFacetGridPlot(modelvar, **kwargs)

        return super().__new__(cls)

    def __init__(
        self,
        modelvar: Any,
        gridobj: Any | None = None,
        plotargs: dict[str, Any] | None = None,
        ncolors: int = 15,
        discrete: bool = False,
        col: str | None = None,
        row: str | None = None,
        col_wrap: int | None = None,
        size: float | None = None,
        aspect: float | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the spatial imshow plot.

        Parameters
        ----------
        modelvar : Any
            The input data to plot. Preferred format is an xarray DataArray.
        gridobj : Any, optional
            Object with LAT and LON variables to determine extent, by default None.
        plotargs : dict, optional
            Arguments for imshow, by default None.
        ncolors : int, optional
            Number of discrete colors for the discrete colorbar, by default 15.
        discrete : bool, optional
            If True, use a discrete colorbar, by default False.
        col : str, optional
            Dimension name to facet by columns. Aligns with Xarray.
        row : str, optional
            Dimension name to facet by rows. Aligns with Xarray.
        col_wrap : int, optional
            Number of columns before wrapping. Aligns with Xarray.
        size : float, optional
            Height (in inches) of each facet. Aligns with Xarray.
        aspect : float, optional
            Aspect ratio of each facet. Aligns with Xarray.
        **kwargs : Any
            Keyword arguments passed to :class:`SpatialPlot` for map features
            and projection. ``lon_coord`` / ``lat_coord`` entries are also used
            as fallback coordinate names if identification raises ValueError.
        """
        # Initialize the map canvas via SpatialPlot
        super().__init__(**kwargs)

        # Standardize data to Xarray for consistency and lazy evaluation
        self.modelvar = normalize_data(modelvar)
        if isinstance(self.modelvar, xr.Dataset) and len(self.modelvar.data_vars) == 1:
            self.modelvar = self.modelvar[list(self.modelvar.data_vars)[0]]

        self.gridobj = gridobj
        self.plotargs = plotargs
        self.ncolors = ncolors
        self.discrete = discrete

        # Aero Protocol: Centralized coordinate identification
        try:
            self.lon_coord, self.lat_coord = self._identify_coords(self.modelvar)
        except ValueError:
            self.lon_coord = kwargs.get("lon_coord", "lon")
            self.lat_coord = kwargs.get("lat_coord", "lat")

        # Ensure coordinates are monotonic for correct plotting
        self.modelvar = self._ensure_monotonic(
            self.modelvar, self.lon_coord, self.lat_coord
        )

        _update_history(self.modelvar, "Initialized monet-plots.SpatialImshowPlot")

    def plot(self, **kwargs: Any) -> Axes:
        """Generate a static publication-quality spatial imshow plot (Track A).

        Parameters
        ----------
        **kwargs : Any
            Map-feature options (e.g. `coastlines=True`, `extent`) are handled
            by ``self.add_features``. The keywords it returns, updated with
            ``self.plotargs``, are forwarded to ``self.ax.imshow``.
            Common options include `cmap`, `vmin`, `vmax`, and `alpha`.

        Returns
        -------
        matplotlib.axes.Axes
            The matplotlib axes object containing the plot.
        """
        # Automatically compute extent if not provided
        if "extent" not in kwargs:
            if self.gridobj is not None:
                # Handle legacy gridobj
                try:
                    lat_var = self.gridobj.variables["LAT"]
                    lon_var = self.gridobj.variables["LON"]
                    kwargs["extent"] = [
                        float(lon_var.min()),
                        float(lon_var.max()),
                        float(lat_var.min()),
                        float(lat_var.max()),
                    ]
                except (AttributeError, KeyError):
                    kwargs["extent"] = self._get_extent_from_data(
                        self.modelvar, self.lon_coord, self.lat_coord
                    )
            else:
                kwargs["extent"] = self._get_extent_from_data(
                    self.modelvar, self.lon_coord, self.lat_coord
                )

        # Draw map features and get remaining kwargs for imshow
        imshow_kwargs = self.add_features(**kwargs)

        if self.plotargs:
            imshow_kwargs.update(self.plotargs)

        # Set default imshow settings
        imshow_kwargs.setdefault("cmap", "viridis")
        imshow_kwargs.setdefault("origin", "lower")
        imshow_kwargs.setdefault("transform", ccrs.PlateCarree())

        # imshow needs an explicit [left, right, bottom, top] extent to place
        # the image in lon/lat. Whether add_features passes "extent" through
        # is not visible here, so fall back to the extent resolved above
        # rather than None, which would position the image by pixel index.
        extent = imshow_kwargs.pop("extent", kwargs.get("extent"))

        # imshow needs concrete values; ``.values`` materializes the array
        # (computing it if it is dask-backed).
        model_values = self.modelvar.values

        # Handle colormap and normalization
        final_kwargs = get_plot_kwargs(**imshow_kwargs)

        img = self.ax.imshow(model_values, extent=extent, **final_kwargs)

        # Handle colorbar
        if self.discrete:
            vmin, vmax = img.get_clim()
            colorbar_index(
                self.ncolors,
                final_kwargs["cmap"],
                minval=vmin,
                maxval=vmax,
                ax=self.ax,
            )
        else:
            self.add_colorbar(img)

        return self.ax

    def hvplot(self, **kwargs: Any) -> Any:
        """Generate an interactive spatial plot using hvPlot (Track B).

        This method leverages Datashader for high-performance rendering of
        large geospatial grids.

        Parameters
        ----------
        **kwargs : Any
            Keyword arguments passed to `hvplot.quadmesh`, overriding the
            defaults set below. Common options include `cmap`, `title`, and
            `alpha`. `rasterize=True` is used by default for speed.

        Returns
        -------
        holoviews.core.layout.Layout
            The interactive hvPlot object.

        Raises
        ------
        ImportError
            If ``hvplot.xarray`` cannot be imported. The underlying import
            error is chained as ``__cause__``.
        """
        try:
            import hvplot.xarray  # noqa: F401
        except ImportError as err:
            # Chain explicitly so the real import failure is the direct cause.
            raise ImportError(
                "hvplot is required for interactive plotting. Install it with 'pip install hvplot'."
            ) from err

        # Track B defaults
        plot_kwargs = {
            "x": self.lon_coord,
            "y": self.lat_coord,
            "geo": True,
            "rasterize": True,
            "cmap": "viridis",
        }
        plot_kwargs.update(kwargs)

        return self.modelvar.hvplot.quadmesh(**plot_kwargs)

__init__(modelvar, gridobj=None, plotargs=None, ncolors=15, discrete=False, col=None, row=None, col_wrap=None, size=None, aspect=None, **kwargs)

Initialize the spatial imshow plot.

Parameters

modelvar : Any The input data to plot. Preferred format is an xarray DataArray. gridobj : Any, optional Object with LAT and LON variables to determine extent, by default None. plotargs : dict, optional Arguments for imshow, by default None. ncolors : int, optional Number of discrete colors for the discrete colorbar, by default 15. discrete : bool, optional If True, use a discrete colorbar, by default False. col : str, optional Dimension name to facet by columns. Aligns with Xarray. row : str, optional Dimension name to facet by rows. Aligns with Xarray. col_wrap : int, optional Number of columns before wrapping. Aligns with Xarray. size : float, optional Height (in inches) of each facet. Aligns with Xarray. aspect : float, optional Aspect ratio of each facet. Aligns with Xarray. **kwargs : Any Keyword arguments passed to SpatialPlot for map features and projection.

Source code in src/monet_plots/plots/spatial_imshow.py
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
def __init__(
    self,
    modelvar: Any,
    gridobj: Any | None = None,
    plotargs: dict[str, Any] | None = None,
    ncolors: int = 15,
    discrete: bool = False,
    col: str | None = None,
    row: str | None = None,
    col_wrap: int | None = None,
    size: float | None = None,
    aspect: float | None = None,
    **kwargs: Any,
) -> None:
    """Set up a single-panel spatial imshow plot.

    The faceting keywords (``col``, ``row``, ``col_wrap``, ``size``,
    ``aspect``) are accepted so the signature mirrors Xarray's plotting API;
    they are not used by this constructor (``__new__`` decides whether a
    facet grid is built instead).

    Parameters
    ----------
    modelvar : Any
        The input data to plot. Preferred format is an xarray DataArray.
        A Dataset holding exactly one data variable is unwrapped to that
        variable.
    gridobj : Any, optional
        Object with LAT and LON variables to determine extent, by default None.
    plotargs : dict, optional
        Arguments for imshow, by default None.
    ncolors : int, optional
        Number of discrete colors for the discrete colorbar, by default 15.
    discrete : bool, optional
        If True, use a discrete colorbar, by default False.
    col : str, optional
        Dimension name to facet by columns. Aligns with Xarray.
    row : str, optional
        Dimension name to facet by rows. Aligns with Xarray.
    col_wrap : int, optional
        Number of columns before wrapping. Aligns with Xarray.
    size : float, optional
        Height (in inches) of each facet. Aligns with Xarray.
    aspect : float, optional
        Aspect ratio of each facet. Aligns with Xarray.
    **kwargs : Any
        Keyword arguments passed to :class:`SpatialPlot` for map features
        and projection. ``lon_coord`` / ``lat_coord`` are consulted only
        when the coordinate names cannot be identified automatically.
    """
    # The map canvas (figure, GeoAxes, requested features) comes first.
    super().__init__(**kwargs)

    # Convert to Xarray so later steps can work lazily and uniformly.
    data = normalize_data(modelvar)
    if isinstance(data, xr.Dataset) and len(data.data_vars) == 1:
        # Only one variable present: plot it directly as a DataArray.
        data = data[next(iter(data.data_vars))]
    self.modelvar = data

    # Plain configuration values.
    self.gridobj, self.plotargs = gridobj, plotargs
    self.ncolors, self.discrete = ncolors, discrete

    # Aero Protocol: coordinate names are resolved centrally; explicit
    # kwargs (or the conventional 'lon'/'lat') act only as a fallback.
    try:
        lon_name, lat_name = self._identify_coords(self.modelvar)
    except ValueError:
        lon_name = kwargs.get("lon_coord", "lon")
        lat_name = kwargs.get("lat_coord", "lat")
    self.lon_coord, self.lat_coord = lon_name, lat_name

    # Sort descending coordinates so the raster is drawn the right way up.
    self.modelvar = self._ensure_monotonic(self.modelvar, lon_name, lat_name)

    _update_history(self.modelvar, "Initialized monet-plots.SpatialImshowPlot")

__new__(modelvar, gridobj=None, plotargs=None, **kwargs)

Redirect to SpatialFacetGridPlot if faceting is requested.

This enables a unified interface for both single-panel and multi-panel spatial plots, following Xarray's plotting conventions.

Parameters

modelvar : Any The input data to plot. gridobj : Any, optional Object with LAT and LON variables, by default None. plotargs : dict, optional Arguments for imshow, by default None. **kwargs : Any Additional keyword arguments. If faceting arguments (e.g., col, row, or col_wrap) are provided, redirects to SpatialFacetGridPlot.

Returns

Any An instance of SpatialImshowPlot or SpatialFacetGridPlot.

Source code in src/monet_plots/plots/spatial_imshow.py
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
def __new__(
    cls,
    modelvar: Any,
    gridobj: Any | None = None,
    plotargs: dict[str, Any] | None = None,
    **kwargs: Any,
) -> Any:
    """Return a facet grid instead of a single panel when appropriate.

    This gives one entry point for single-panel and multi-panel spatial
    plots, following Xarray's plotting conventions. Redirection happens
    only when no ``ax`` is supplied, and then when either a faceting
    keyword (``col``, ``row`` or ``col_wrap``) is not None, or
    ``modelvar`` is a Dataset with more than one data variable. In the
    latter case ``col`` defaults to ``"variable"`` if the key is absent.

    Note: ``gridobj`` and ``plotargs`` are not forwarded to
    ``SpatialFacetGridPlot``; only ``modelvar`` and ``**kwargs`` are.

    Parameters
    ----------
    modelvar : Any
        The input data to plot.
    gridobj : Any, optional
        Object with LAT and LON variables, by default None.
    plotargs : dict, optional
        Arguments for imshow, by default None.
    **kwargs : Any
        Additional keyword arguments. If faceting arguments (e.g., `col`,
        `row`, or `col_wrap`) are provided, redirects to `SpatialFacetGridPlot`.

    Returns
    -------
    Any
        An instance of SpatialImshowPlot or SpatialFacetGridPlot.
    """
    from .facet_grid import SpatialFacetGridPlot

    # With an explicit axes the caller wants a single panel drawn into it.
    if kwargs.get("ax") is None:
        # Same trigger Xarray uses for faceting.
        wants_facets = any(
            kwargs.get(key) is not None for key in ("col", "row", "col_wrap")
        )
        if wants_facets:
            return SpatialFacetGridPlot(modelvar, **kwargs)

        # Several variables cannot share one panel: facet over them.
        if isinstance(modelvar, xr.Dataset) and len(modelvar.data_vars) > 1:
            kwargs.setdefault("col", "variable")
            return SpatialFacetGridPlot(modelvar, **kwargs)

    return super().__new__(cls)

hvplot(**kwargs)

Generate an interactive spatial plot using hvPlot (Track B).

This method leverages Datashader for high-performance rendering of large geospatial grids.

Parameters

**kwargs : Any Keyword arguments passed to hvplot.quadmesh. Common options include cmap, title, and alpha. rasterize=True is used by default for speed.

Returns

holoviews.core.layout.Layout The interactive hvPlot object.

Source code in src/monet_plots/plots/spatial_imshow.py
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
def hvplot(self, **kwargs: Any) -> Any:
    """Generate an interactive spatial plot using hvPlot (Track B).

    This method leverages Datashader for high-performance rendering of
    large geospatial grids.

    Parameters
    ----------
    **kwargs : Any
        Keyword arguments passed to `hvplot.quadmesh`.
        Common options include `cmap`, `title`, and `alpha`.
        `rasterize=True` is used by default for speed. Any key given here
        (including `x`, `y`, `geo`) overrides the built-in default.

    Returns
    -------
    Any
        The interactive HoloViews object returned by `hvplot.quadmesh`.

    Raises
    ------
    ImportError
        If `hvplot.xarray` cannot be imported. The original import error is
        chained as ``__cause__``.
    """
    try:
        import hvplot.xarray  # noqa: F401  (registers the `.hvplot` accessor)
    except ImportError as err:
        # Chain explicitly so a broken transitive dependency is not hidden
        # behind the generic "install hvplot" hint.
        raise ImportError(
            "hvplot is required for interactive plotting. Install it with 'pip install hvplot'."
        ) from err

    # Track B defaults; caller-supplied kwargs take precedence.
    plot_kwargs = {
        "x": self.lon_coord,
        "y": self.lat_coord,
        "geo": True,
        "rasterize": True,
        "cmap": "viridis",
        **kwargs,
    }

    return self.modelvar.hvplot.quadmesh(**plot_kwargs)

plot(**kwargs)

Generate a static publication-quality spatial imshow plot (Track A).

Parameters

**kwargs : Any Keyword arguments passed to matplotlib.pyplot.imshow. Common options include cmap, vmin, vmax, and alpha. Map features (e.g., coastlines=True) can also be passed here.

Returns

matplotlib.axes.Axes The matplotlib axes object containing the plot.

Source code in src/monet_plots/plots/spatial_imshow.py
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
def plot(self, **kwargs: Any) -> Axes:
    """Generate a static publication-quality spatial imshow plot (Track A).

    Parameters
    ----------
    **kwargs : Any
        Keyword arguments passed to `matplotlib.pyplot.imshow`.
        Common options include `cmap`, `vmin`, `vmax`, and `alpha`.
        Map features (e.g., `coastlines=True`) can also be passed here.
        `extent` ([lon_min, lon_max, lat_min, lat_max]) sets the map view
        and is also used as the image footprint. When omitted it is taken
        from ``gridobj`` (its LAT/LON variables) if possible, otherwise from
        the data's coordinates. An `extent` inside ``plotargs`` overrides
        the image footprint only.

    Returns
    -------
    matplotlib.axes.Axes
        The matplotlib axes object containing the plot.
    """
    # Automatically compute extent if not provided.
    if "extent" not in kwargs:
        data_extent = None
        if self.gridobj is not None:
            # Legacy gridobj exposing LAT/LON variables.
            try:
                lat_var = self.gridobj.variables["LAT"]
                lon_var = self.gridobj.variables["LON"]
                data_extent = [
                    float(lon_var.min()),
                    float(lon_var.max()),
                    float(lat_var.min()),
                    float(lat_var.max()),
                ]
            except (AttributeError, KeyError):
                data_extent = None
        if data_extent is None:
            data_extent = self._get_extent_from_data(
                self.modelvar, self.lon_coord, self.lat_coord
            )
        kwargs["extent"] = data_extent

    # `add_features` pops `extent` (to set the map view) from the kwargs it
    # receives, so the returned dict never contains it. Capture it here so
    # the image can still be georeferenced below.
    extent = kwargs.get("extent")

    # Draw map features and get remaining kwargs for imshow.
    imshow_kwargs = self.add_features(**kwargs)

    if self.plotargs:
        imshow_kwargs.update(self.plotargs)

    # Default imshow settings. origin="lower" assumes the first row is the
    # southernmost; `__init__` sorts 1-D latitude ascending for this reason.
    imshow_kwargs.setdefault("cmap", "viridis")
    imshow_kwargs.setdefault("origin", "lower")
    imshow_kwargs.setdefault("transform", ccrs.PlateCarree())

    # Image footprint [left, right, bottom, top]; an extent supplied through
    # `plotargs` takes precedence over the map extent.
    extent = imshow_kwargs.pop("extent", extent)

    # imshow needs concrete values for Track A; `.values` also materializes
    # Dask-backed arrays.
    model_values = self.modelvar.values

    # Handle colormap and normalization.
    final_kwargs = get_plot_kwargs(**imshow_kwargs)

    img = self.ax.imshow(model_values, extent=extent, **final_kwargs)

    # Handle colorbar.
    if self.discrete:
        vmin, vmax = img.get_clim()
        colorbar_index(
            self.ncolors,
            final_kwargs["cmap"],
            minval=vmin,
            maxval=vmax,
            ax=self.ax,
        )
    else:
        self.add_colorbar(img)

    return self.ax

SpatialPlot

Bases: BasePlot

A base class for creating spatial plots with cartopy.

This class provides a high-level interface for geospatial plots, handling the setup of cartopy axes and the addition of common map features like coastlines, states, and gridlines.

Attributes

resolution : str The resolution of the cartopy features (e.g., '50m').

Source code in src/monet_plots/plots/spatial.py
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
class SpatialPlot(BasePlot):
    """A base class for creating spatial plots with cartopy.

    This class provides a high-level interface for geospatial plots, handling
    the setup of cartopy axes and the addition of common map features like
    coastlines, states, and gridlines.

    Attributes
    ----------
    resolution : str
        The resolution of the cartopy features (e.g., '50m').
    """

    def __init__(
        self,
        *,
        projection: ccrs.Projection = ccrs.PlateCarree(),
        fig: Figure | None = None,
        ax: Axes | None = None,
        figsize: tuple[float, float] | None = None,
        subplot_kw: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the spatial plot and draw map features.

        This constructor sets up the matplotlib Figure and cartopy GeoAxes,
        and provides a single interface to draw common map features like
        coastlines and states.

        Parameters
        ----------
        projection : ccrs.Projection, optional
            The cartopy projection for the map, by default ccrs.PlateCarree().
        fig : plt.Figure | None, optional
            An existing matplotlib Figure object. If None, a new one is
            created, by default None.
        ax : plt.Axes | None, optional
            An existing matplotlib Axes object. If None, a new one is created,
            by default None.
        figsize : tuple[float, float] | None, optional
             Width, height in inches. If not provided, the matplotlib default
             will be used.
        subplot_kw : dict[str, Any] | None, optional
            Keyword arguments passed to `fig.add_subplot`, by default None.
            The 'projection' is added to these keywords automatically.
        **kwargs : Any
            Keyword arguments for map features, passed to `add_features`.
            Common options include `coastlines`, `states`, `countries`,
            `ocean`, `land`, `lakes`, `rivers`, `borders`, `gridlines`,
            `extent`, and `resolution`. `style` (default "wiley") is
            forwarded to `BasePlot`. Keys not recognised by `add_features`
            are ignored.

        Attributes
        ----------
        fig : plt.Figure
            The matplotlib Figure object.
        ax : plt.Axes
            The matplotlib Axes (or GeoAxes) object.
        resolution : str
            The default resolution for cartopy features.
        """
        # Copy so the caller's subplot_kw is never mutated, then force the
        # projection in.
        current_subplot_kw = subplot_kw.copy() if subplot_kw else {}
        current_subplot_kw["projection"] = projection

        self.resolution = kwargs.pop("resolution", "50m")
        style = kwargs.pop("style", "wiley")

        # Coastlines are on unless the caller says otherwise.
        if "coastlines" not in kwargs:
            kwargs["coastlines"] = True

        # Initialize the base plot, which creates the figure and axes.
        super().__init__(
            fig=fig, ax=ax, figsize=figsize, style=style, subplot_kw=current_subplot_kw
        )

        # If BasePlot didn't create an axes (e.g. because fig was provided),
        # create one now.
        if self.ax is None:
            self.ax = self.fig.add_subplot(1, 1, 1, **current_subplot_kw)

        # Draw features; leftover (unrecognised) kwargs are discarded here.
        self.add_features(**kwargs)

    def _get_feature_registry(self, resolution: str) -> dict[str, dict[str, Any]]:
        """Return a registry of cartopy features and their default styles.

        This approach centralizes feature management, making it easier to
        add new features and maintain existing ones.

        Parameters
        ----------
        resolution : str
            The resolution for the cartopy features (e.g., '10m', '50m').
            Ignored for `counties`, which is always drawn at '10m'.

        Returns
        -------
        dict[str, dict[str, Any]]
            A dictionary mapping feature names to a specification dictionary
            containing the feature object and its default styling.
        """
        from cartopy.feature import (
            BORDERS,
            COASTLINE,
            LAKES,
            LAND,
            OCEAN,
            RIVERS,
            STATES,
        )

        # Define default styles, falling back to sane defaults if not in rcParams.
        coastline_defaults = {
            "linewidth": get_style_setting("coastline.width", 0.5),
            "edgecolor": get_style_setting("coastline.color", "black"),
            "facecolor": "none",
        }
        states_defaults = {
            "linewidth": get_style_setting("states.width", 0.5),
            "edgecolor": get_style_setting("states.color", "black"),
            "facecolor": "none",
        }
        borders_defaults = {
            "linewidth": get_style_setting("borders.width", 0.5),
            "edgecolor": get_style_setting("borders.color", "black"),
            "facecolor": "none",
        }

        feature_mapping = {
            "coastlines": {
                "feature": COASTLINE.with_scale(resolution),
                "defaults": coastline_defaults,
            },
            "countries": {
                "feature": BORDERS.with_scale(resolution),
                "defaults": borders_defaults,
            },
            "states": {
                "feature": STATES.with_scale(resolution),
                "defaults": states_defaults,
            },
            "borders": {
                "feature": BORDERS.with_scale(resolution),
                "defaults": borders_defaults,
            },
            "ocean": {"feature": OCEAN.with_scale(resolution), "defaults": {}},
            "land": {"feature": LAND.with_scale(resolution), "defaults": {}},
            "rivers": {"feature": RIVERS.with_scale(resolution), "defaults": {}},
            "lakes": {"feature": LAKES.with_scale(resolution), "defaults": {}},
            "counties": {
                # Natural Earth publishes `admin_2_counties` only at the
                # 1:10m scale, so the generic resolution (default '50m')
                # cannot be used for this layer.
                "feature": cfeature.NaturalEarthFeature(
                    category="cultural",
                    name="admin_2_counties",
                    scale="10m",
                    facecolor="none",
                ),
                "defaults": borders_defaults,
            },
        }
        return feature_mapping

    @staticmethod
    def _get_style(
        style: bool | dict[str, Any], defaults: dict[str, Any] | None = None
    ) -> dict[str, Any]:
        """Get a style dictionary for a feature.

        Parameters
        ----------
        style : bool or dict[str, Any]
            The user-provided style. If True, use defaults. If a dict, it is
            used as-is (it is not merged with `defaults`).
        defaults : dict[str, Any], optional
            The default style to apply if `style` is True.

        Returns
        -------
        dict[str, Any]
            The resolved keyword arguments for styling.
        """
        if isinstance(style, dict):
            return style
        if style and defaults:
            # Return a copy to prevent modifying the defaults dictionary in place
            return defaults.copy()
        return {}

    def _draw_single_feature(
        self, style_arg: bool | dict[str, Any], feature_spec: dict[str, Any]
    ) -> None:
        """Draw a single cartopy feature on the axes.

        Parameters
        ----------
        style_arg : bool or dict[str, Any]
            The user-provided style for the feature. Falsy values skip it.
        feature_spec : dict[str, Any]
            A dictionary containing the feature object and default styles.
        """
        if not style_arg:  # Allows for `coastlines=False`
            return

        style_kwargs = self._get_style(style_arg, feature_spec["defaults"])
        feature = feature_spec["feature"]
        self.ax.add_feature(feature, **style_kwargs)

    def add_features(self, **kwargs: Any) -> dict[str, Any]:
        """Add and style cartopy features on the map axes.

        This method provides a flexible, data-driven interface to add common
        map features. Features can be enabled with a boolean flag (e.g.,
        `coastlines=True`) or styled with a dictionary of keyword arguments
        (e.g., `states=dict(linewidth=2, edgecolor='red')`).

        The `extent` keyword is also supported to set the map boundaries;
        it is consumed here and is not part of the returned dictionary.

        Parameters
        ----------
        **kwargs : Any
            Keyword arguments controlling the features to add and their
            styles. Common options include `coastlines`, `states`,
            `countries`, `ocean`, `land`, `lakes`, `rivers`, `borders`,
            and `gridlines`.

        Returns
        -------
        dict[str, Any]
            A dictionary of the keyword arguments that were not used for
            adding features. This can be useful for passing remaining
            arguments to other functions.
        """
        # Note: The order of these calls is important.
        # Extent must be set before gridlines are drawn to ensure labels
        # are placed correctly.
        if "extent" in kwargs:
            extent = kwargs.pop("extent")
            self._set_extent(extent)

        if "gridlines" in kwargs:
            gridline_style = kwargs.pop("gridlines")
            self._draw_gridlines(gridline_style)

        # The rest of the kwargs are assumed to be for vector features.
        remaining_kwargs = self._draw_features(**kwargs)

        return remaining_kwargs

    def _set_extent(self, extent: tuple[float, float, float, float] | None) -> None:
        """Set the geographic extent of the map.

        Parameters
        ----------
        extent : tuple[float, float, float, float] | None
            The extent of the map as a tuple of (x_min, x_max, y_min, y_max).
            If None, the extent is not changed.
        """
        if extent is not None:
            self.ax.set_extent(extent)

    def _draw_features(self, **kwargs: Any) -> dict[str, Any]:
        """Draw vector features on the map.

        This is the primary feature-drawing loop, responsible for adding
        elements like coastlines, states, and borders.

        Parameters
        ----------
        **kwargs : Any
            Keyword arguments controlling the features to add and their
            styles. `resolution` overrides `self.resolution` for this call;
            `natural_earth=True` enables ocean, land, lakes and rivers unless
            they are given explicitly.

        Returns
        -------
        dict[str, Any]
            A dictionary of the keyword arguments that were not used for
            adding features.
        """
        resolution = kwargs.pop("resolution", self.resolution)
        feature_registry = self._get_feature_registry(resolution)

        # If natural_earth is True, enable a standard set of features
        if kwargs.pop("natural_earth", False):
            for feature in ["ocean", "land", "lakes", "rivers"]:
                kwargs.setdefault(feature, True)

        # Main feature-drawing loop (registry order determines draw order).
        for key, feature_spec in feature_registry.items():
            if key in kwargs:
                style_arg = kwargs.pop(key)
                self._draw_single_feature(style_arg, feature_spec)

        return kwargs

    def _draw_gridlines(self, style: bool | dict[str, Any]) -> None:
        """Draw gridlines on the map.

        Parameters
        ----------
        style : bool or dict[str, Any]
            The style for the gridlines. If True, use defaults. If a dict,
            use it as keyword arguments. If False, do nothing.
        """
        if not style:
            return

        gridline_defaults = {
            "draw_labels": True,
            "linestyle": "--",
            "color": "gray",
        }
        gridline_kwargs = self._get_style(style, gridline_defaults)
        self.ax.gridlines(**gridline_kwargs)

    def _identify_coords(self, data: Any) -> tuple[str, str]:
        """Identify longitude and latitude coordinate names in data.

        For xarray objects every coordinate is inspected. A name containing
        "lon" (case-insensitive) or CF east units marks longitude. A name
        containing "lat" or CF north units marks latitude. If several
        coordinates match, the last one wins. For objects with `columns`
        (e.g. pandas), exact names "lon"/"longitude" and "lat"/"latitude"
        are matched case-insensitively.

        Parameters
        ----------
        data : Any
            The input data to search for coordinates.

        Returns
        -------
        tuple[str, str]
            The identified (longitude, latitude) coordinate names.

        Raises
        ------
        ValueError
            If coordinates cannot be identified.
        """
        lon_name, lat_name = None, None
        lon_units = ("degrees_east", "degree_east", "degree_E")
        lat_units = ("degrees_north", "degree_north", "degree_N")

        # Handle xarray
        if hasattr(data, "coords"):
            for name, coord in data.coords.items():
                lowered = str(name).lower()
                units = coord.attrs.get("units")
                # "longitude"/"latitude" contain "lon"/"lat", so a single
                # substring test covers both spellings.
                if "lon" in lowered or units in lon_units:
                    lon_name = str(name)
                if "lat" in lowered or units in lat_units:
                    lat_name = str(name)

        # Handle pandas or other objects with columns
        elif hasattr(data, "columns"):
            cols = [str(c).lower() for c in data.columns]
            if "lon" in cols:
                lon_name = data.columns[cols.index("lon")]
            elif "longitude" in cols:
                lon_name = data.columns[cols.index("longitude")]

            if "lat" in cols:
                lat_name = data.columns[cols.index("lat")]
            elif "latitude" in cols:
                lat_name = data.columns[cols.index("latitude")]

        if lon_name and lat_name:
            return lon_name, lat_name

        raise ValueError(
            "Could not identify longitude and latitude coordinates. "
            "Please ensure they are named 'lon'/'lat' or 'longitude'/'latitude', "
            "or have CF-compliant units."
        )

    def _ensure_monotonic(self, data: Any, lon_name: str, lat_name: str) -> Any:
        """Ensure spatial dimensions are monotonically increasing.

        Some plotting backends (like GeoViews/hvPlot) require monotonically
        increasing coordinates for correct rendering and interpolation.
        Each coordinate is sorted when its first value exceeds its last.
        Only 1-D coordinates are considered; multi-dimensional (e.g.
        curvilinear 2-D lat/lon) coordinates are left unchanged, since they
        cannot be ordered along a single axis. Latitude is handled before
        longitude.

        Parameters
        ----------
        data : Any
            The input data to sort (xarray object or pandas-like table).
        lon_name : str
            Name of the longitude coordinate.
        lat_name : str
            Name of the latitude coordinate.

        Returns
        -------
        Any
            The data with sorted spatial dimensions.
        """
        is_xarray = hasattr(data, "sortby")

        for name in (lat_name, lon_name):
            try:
                values = data[name].values
                # Comparing rows of a 2-D coordinate would produce an array
                # whose truth value is ambiguous (ValueError); skip instead.
                if getattr(values, "ndim", 1) != 1:
                    continue
                if values[0] > values[-1]:
                    data = data.sortby(name) if is_xarray else data.sort_values(name)
            except (AttributeError, KeyError, IndexError):
                # Missing coordinate/column, empty coordinate, or an object
                # without sorting support: leave the data untouched.
                continue

        return data

    def _get_extent_from_data(
        self,
        data: xr.DataArray | xr.Dataset,
        lon_coord: str | None = None,
        lat_coord: str | None = None,
        buffer: float = 0.0,
    ) -> list[float]:
        """Calculate geographic extent from xarray data using Dask if available.

        Parameters
        ----------
        data : xr.DataArray or xr.Dataset
            The input data to calculate the extent from.
        lon_coord : str, optional
            Name of the longitude coordinate. If None, it is identified
            automatically using `_identify_coords`.
        lat_coord : str, optional
            Name of the latitude coordinate. If None, it is identified
            automatically using `_identify_coords`.
        buffer : float, optional
            Buffer to add to the extent as a fraction of the range,
            by default 0.0. A zero-width range is padded by 1.0 instead.

        Returns
        -------
        list[float]
            The calculated extent as [lon_min, lon_max, lat_min, lat_max].
        """
        if lon_coord is None or lat_coord is None:
            lon_id, lat_id = self._identify_coords(data)
            lon_coord = lon_coord or lon_id
            lat_coord = lat_coord or lat_id

        lon = data[lon_coord]
        lat = data[lat_coord]

        # Use dask.compute for efficient parallel calculation of min/max
        # if the data is chunked.
        try:
            import dask

            lon_min, lon_max, lat_min, lat_max = dask.compute(
                lon.min(), lon.max(), lat.min(), lat.max()
            )
        except (ImportError, AttributeError):
            lon_min, lon_max = lon.min(), lon.max()
            lat_min, lat_max = lat.min(), lat.max()

        # Ensure they are scalar values (handles both numpy and dask returns)
        lon_min, lon_max = float(lon_min), float(lon_max)
        lat_min, lat_max = float(lat_min), float(lat_max)

        if buffer > 0:
            lon_range = lon_max - lon_min
            lat_range = lat_max - lat_min
            lon_buf = lon_range * buffer if lon_range > 0 else 1.0
            lat_buf = lat_range * buffer if lat_range > 0 else 1.0
            lon_min -= lon_buf
            lon_max += lon_buf
            lat_min -= lat_buf
            lat_max += lat_buf

        return [lon_min, lon_max, lat_min, lat_max]

__init__(*, projection=ccrs.PlateCarree(), fig=None, ax=None, figsize=None, subplot_kw=None, **kwargs)

Initialize the spatial plot and draw map features.

This constructor sets up the matplotlib Figure and cartopy GeoAxes, and provides a single interface to draw common map features like coastlines and states.

Parameters

projection : ccrs.Projection, optional The cartopy projection for the map, by default ccrs.PlateCarree(). fig : plt.Figure | None, optional An existing matplotlib Figure object. If None, a new one is created, by default None. ax : plt.Axes | None, optional An existing matplotlib Axes object. If None, a new one is created, by default None. figsize : tuple[float, float] | None, optional Width, height in inches. If not provided, the matplotlib default will be used. subplot_kw : dict[str, Any] | None, optional Keyword arguments passed to fig.add_subplot, by default None. The 'projection' is added to these keywords automatically. **kwargs : Any Keyword arguments for map features, passed to add_features. Common options include coastlines, states, countries, ocean, land, lakes, rivers, borders, gridlines, extent, and resolution.

Attributes

fig : plt.Figure The matplotlib Figure object. ax : plt.Axes The matplotlib Axes (or GeoAxes) object. resolution : str The default resolution for cartopy features.

Source code in src/monet_plots/plots/spatial.py
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
def __init__(
    self,
    *,
    projection: ccrs.Projection = ccrs.PlateCarree(),
    fig: Figure | None = None,
    ax: Axes | None = None,
    figsize: tuple[float, float] | None = None,
    subplot_kw: dict[str, Any] | None = None,
    **kwargs: Any,
) -> None:
    """Build the figure and cartopy GeoAxes, then draw map features.

    Parameters
    ----------
    projection : ccrs.Projection, optional
        Projection for the map axes, by default ``ccrs.PlateCarree()``.
    fig : plt.Figure | None, optional
        Existing figure to draw into; a new one is created when None.
    ax : plt.Axes | None, optional
        Existing axes to draw into; a new one is created when None.
    figsize : tuple[float, float] | None, optional
        Width, height in inches; matplotlib's default when omitted.
    subplot_kw : dict[str, Any] | None, optional
        Extra keywords for ``fig.add_subplot``. A copy is used, and its
        ``projection`` entry is always overwritten with ``projection``.
    **kwargs : Any
        Map-feature options forwarded to ``add_features`` (e.g.
        ``coastlines``, ``states``, ``countries``, ``ocean``, ``land``,
        ``lakes``, ``rivers``, ``borders``, ``gridlines``, ``extent``).
        Two keys are consumed here instead:

        - ``resolution`` (default ``"50m"``) is stored on
          ``self.resolution``.
        - ``style`` (default ``"wiley"``) is forwarded to the base plot.

        ``coastlines`` defaults to True when not given.

    Attributes
    ----------
    fig : plt.Figure
        The matplotlib Figure object.
    ax : plt.Axes
        The matplotlib Axes (or GeoAxes) object.
    resolution : str
        The default resolution for cartopy features.
    """
    # Private copy so the caller's dict is never mutated; the projection
    # argument always wins over any "projection" key they supplied.
    axes_kw: dict[str, Any] = dict(subplot_kw) if subplot_kw else {}
    axes_kw["projection"] = projection

    # Options handled by this constructor rather than by add_features.
    self.resolution = kwargs.pop("resolution", "50m")
    style = kwargs.pop("style", "wiley")

    # Coastlines are on unless explicitly configured by the caller.
    kwargs.setdefault("coastlines", True)

    super().__init__(
        fig=fig, ax=ax, figsize=figsize, style=style, subplot_kw=axes_kw
    )

    # The base class may leave ``ax`` unset (e.g. when a figure was passed
    # in); attach a projected axes to the figure in that case.
    if self.ax is None:
        self.ax = self.fig.add_subplot(1, 1, 1, **axes_kw)

    self.add_features(**kwargs)

add_features(**kwargs)

Add and style cartopy features on the map axes.

This method provides a flexible, data-driven interface to add common map features. Features can be enabled with a boolean flag (e.g., coastlines=True) or styled with a dictionary of keyword arguments (e.g., states=dict(linewidth=2, edgecolor='red')).

The extent keyword is also supported to set the map boundaries.

Parameters

**kwargs : Any Keyword arguments controlling the features to add and their styles. Common options include coastlines, states, countries, ocean, land, lakes, rivers, borders, and gridlines.

Returns

dict[str, Any] A dictionary of the keyword arguments that were not used for adding features. This can be useful for passing remaining arguments to other functions.

Source code in src/monet_plots/plots/spatial.py
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
def add_features(self, **kwargs: Any) -> dict[str, Any]:
    """Draw and style cartopy features on the map axes.

    Each feature may be switched on with a boolean (``coastlines=True``)
    or styled with a dict of keyword arguments
    (``states=dict(linewidth=2, edgecolor='red')``). The ``extent`` key
    sets the map boundaries and ``gridlines`` draws a graticule.

    Parameters
    ----------
    **kwargs : Any
        Feature switches and styles. Common options include
        ``coastlines``, ``states``, ``countries``, ``ocean``, ``land``,
        ``lakes``, ``rivers``, ``borders``, ``gridlines`` and ``extent``.

    Returns
    -------
    dict[str, Any]
        The keyword arguments that were not used for features, so they
        can be forwarded to another call (e.g. a plotting function).
    """
    # Order matters: the extent has to be final before gridlines are drawn,
    # otherwise their labels are placed against the wrong boundaries.
    if "extent" in kwargs:
        self._set_extent(kwargs.pop("extent"))

    if "gridlines" in kwargs:
        self._draw_gridlines(kwargs.pop("gridlines"))

    # All remaining options are handed to the vector-feature drawer, and
    # its result (the unused options) goes straight back to the caller.
    return self._draw_features(**kwargs)

SpatialTrack

Bases: SpatialPlot

Plot a trajectory from an xarray.DataArray on a map.

This class provides an xarray-native interface for visualizing paths, such as flight trajectories or pollutant tracks, where a variable (e.g., altitude, concentration) is plotted along the path.

It inherits from :class:SpatialPlot to provide the underlying map canvas.

Attributes

data : xr.DataArray The trajectory data being plotted. lon_coord : str The name of the longitude coordinate in the DataArray. lat_coord : str The name of the latitude coordinate in the DataArray.

Source code in src/monet_plots/plots/spatial.py
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
class SpatialTrack(SpatialPlot):
    """Draw a trajectory stored in an ``xarray.DataArray`` on a map.

    Intended for paths such as flight tracks or pollutant plumes, where a
    variable (altitude, concentration, ...) is shown along the route. The
    map canvas itself is provided by :class:`SpatialPlot`.

    Attributes
    ----------
    data : xr.DataArray
        The trajectory values being plotted.
    lon_coord : str
        Name of the longitude coordinate in ``data``.
    lat_coord : str
        Name of the latitude coordinate in ``data``.
    """

    def __init__(
        self,
        data: xr.DataArray,
        *,
        lon_coord: str = "lon",
        lat_coord: str = "lat",
        **kwargs: Any,
    ):
        """Validate the track data, then build the map canvas.

        Parameters
        ----------
        data : xr.DataArray
            Trajectory values; must carry longitude and latitude
            coordinates.
        lon_coord : str, optional
            Longitude coordinate name, by default 'lon'.
        lat_coord : str, optional
            Latitude coordinate name, by default 'lat'.
        **kwargs : Any
            Forwarded to :class:`SpatialPlot` (projection, figure size,
            cartopy features), e.g. ``projection=ccrs.LambertConformal()``,
            ``figsize=(10, 8)``, ``states=True``,
            ``extent=[-125, -70, 25, 50]``.

        Raises
        ------
        TypeError
            If ``data`` is not an ``xarray.DataArray``.
        ValueError
            If ``lon_coord`` or ``lat_coord`` is not a coordinate of
            ``data``.
        """
        if not isinstance(data, xr.DataArray):
            raise TypeError("Input 'data' must be an xarray.DataArray.")
        # Longitude is checked before latitude, as callers may rely on
        # which error surfaces first.
        for coord_name, axis_label in (
            (lon_coord, "Longitude"),
            (lat_coord, "Latitude"),
        ):
            if coord_name not in data.coords:
                raise ValueError(
                    f"{axis_label} coordinate '{coord_name}' not found in DataArray."
                )

        # SpatialPlot creates the figure/GeoAxes and draws any map
        # features requested through kwargs.
        super().__init__(**kwargs)

        self.data, self.lon_coord, self.lat_coord = data, lon_coord, lat_coord
        # Record provenance in the data's history.
        _update_history(self.data, "Plotted with monet-plots.SpatialTrack")

    def plot(self, **kwargs: Any) -> plt.Artist:
        """Scatter the track points, colored by ``data``, onto the map.

        Parameters
        ----------
        **kwargs : Any
            Map-feature options (e.g. ``coastlines=True``, ``extent=...``)
            are applied via ``add_features``. Whatever it returns is passed
            to ``Axes.scatter`` (e.g. ``cmap``, ``s``, ``alpha``). ``cmap``
            may be a string, a Colormap, or a ``(colormap, norm)`` tuple
            from the scaling tools in ``colorbars.py``.

            - When ``extent`` is omitted, it is derived from the track
              coordinates with ``buffer=0.1``.
            - When ``transform`` is omitted, it defaults to
              ``ccrs.PlateCarree()``.

        Returns
        -------
        plt.Artist
            The artist returned by ``ax.scatter``.
        """
        from ..plot_utils import get_plot_kwargs

        # Only derive the extent from the track when none was supplied.
        if "extent" not in kwargs:
            kwargs["extent"] = self._get_extent_from_data(
                self.data, self.lon_coord, self.lat_coord, buffer=0.1
            )

        leftover = self.add_features(**kwargs)
        leftover.setdefault("transform", ccrs.PlateCarree())

        # The xarray coordinate objects are handed to Matplotlib as-is
        # (no eager conversion); existing tests check for lazy objects.
        lons = self.data[self.lon_coord]
        lats = self.data[self.lat_coord]

        # get_plot_kwargs takes care of (cmap, norm) tuples.
        scatter_kwargs = get_plot_kwargs(c=self.data, **leftover)
        return self.ax.scatter(lons, lats, **scatter_kwargs)

__init__(data, *, lon_coord='lon', lat_coord='lat', **kwargs)

Initialize the SpatialTrack plot.

This constructor validates the input data and sets up the map canvas by initializing the parent SpatialPlot and adding map features.

Parameters

data : xr.DataArray The input trajectory data. Must be an xarray DataArray with coordinates for longitude and latitude. lon_coord : str, optional Name of the longitude coordinate in the DataArray, by default 'lon'. lat_coord : str, optional Name of the latitude coordinate in the DataArray, by default 'lat'. **kwargs : Any Keyword arguments passed to :class:SpatialPlot. These control the map projection, figure size, and cartopy features. For example: projection=ccrs.LambertConformal(), figsize=(10, 8), states=True, extent=[-125, -70, 25, 50].

Source code in src/monet_plots/plots/spatial.py
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
def __init__(
    self,
    data: xr.DataArray,
    *,
    lon_coord: str = "lon",
    lat_coord: str = "lat",
    **kwargs: Any,
):
    """Validate the track data, then build the map canvas.

    Parameters
    ----------
    data : xr.DataArray
        Trajectory values; must carry longitude and latitude coordinates.
    lon_coord : str, optional
        Longitude coordinate name, by default 'lon'.
    lat_coord : str, optional
        Latitude coordinate name, by default 'lat'.
    **kwargs : Any
        Forwarded to :class:`SpatialPlot` (projection, figure size,
        cartopy features), e.g. ``projection=ccrs.LambertConformal()``,
        ``figsize=(10, 8)``, ``states=True``,
        ``extent=[-125, -70, 25, 50]``.

    Raises
    ------
    TypeError
        If ``data`` is not an ``xarray.DataArray``.
    ValueError
        If ``lon_coord`` or ``lat_coord`` is not a coordinate of ``data``.
    """
    if not isinstance(data, xr.DataArray):
        raise TypeError("Input 'data' must be an xarray.DataArray.")
    # Longitude is checked before latitude, as callers may rely on which
    # error surfaces first.
    for coord_name, axis_label in (
        (lon_coord, "Longitude"),
        (lat_coord, "Latitude"),
    ):
        if coord_name not in data.coords:
            raise ValueError(
                f"{axis_label} coordinate '{coord_name}' not found in DataArray."
            )

    # SpatialPlot creates the figure/GeoAxes and draws any map features
    # requested through kwargs.
    super().__init__(**kwargs)

    self.data, self.lon_coord, self.lat_coord = data, lon_coord, lat_coord
    # Record provenance in the data's history.
    _update_history(self.data, "Plotted with monet-plots.SpatialTrack")

plot(**kwargs)

Plot the trajectory on the map.

The track is rendered as a scatter plot, where each point is colored according to the data values.

Parameters

**kwargs : Any Keyword arguments passed to matplotlib.pyplot.scatter. Common options include cmap, s (size), and alpha. If no transform keyword is given, it defaults to transform=ccrs.PlateCarree(). The cmap argument can be a string, a Colormap object, or a (colormap, norm) tuple from the scaling tools in colorbars.py. Map features (e.g., coastlines=True) can also be passed here.

Returns

plt.Artist The scatter plot artist created by ax.scatter.

Source code in src/monet_plots/plots/spatial.py
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
def plot(self, **kwargs: Any) -> plt.Artist:
    """Scatter the track points, colored by ``data``, onto the map.

    Parameters
    ----------
    **kwargs : Any
        Map-feature options (e.g. ``coastlines=True``, ``extent=...``) are
        applied via ``add_features``. Whatever it returns is passed to
        ``Axes.scatter`` (e.g. ``cmap``, ``s``, ``alpha``). ``cmap`` may be
        a string, a Colormap, or a ``(colormap, norm)`` tuple from the
        scaling tools in ``colorbars.py``.

        - When ``extent`` is omitted, it is derived from the track
          coordinates with ``buffer=0.1``.
        - When ``transform`` is omitted, it defaults to
          ``ccrs.PlateCarree()``.

    Returns
    -------
    plt.Artist
        The artist returned by ``ax.scatter``.
    """
    from ..plot_utils import get_plot_kwargs

    # Only derive the extent from the track when none was supplied.
    if "extent" not in kwargs:
        kwargs["extent"] = self._get_extent_from_data(
            self.data, self.lon_coord, self.lat_coord, buffer=0.1
        )

    leftover = self.add_features(**kwargs)
    leftover.setdefault("transform", ccrs.PlateCarree())

    # The xarray coordinate objects are handed to Matplotlib as-is (no
    # eager conversion); existing tests check for lazy objects.
    lons = self.data[self.lon_coord]
    lats = self.data[self.lat_coord]

    # get_plot_kwargs takes care of (cmap, norm) tuples.
    scatter_kwargs = get_plot_kwargs(c=self.data, **leftover)
    return self.ax.scatter(lons, lats, **scatter_kwargs)

SpreadSkillPlot

Bases: BasePlot

Create a spread-skill plot to evaluate ensemble forecast reliability.

This plot compares the standard deviation of the ensemble spread to the root mean squared error (RMSE) of the ensemble mean. A reliable ensemble should have a spread that is proportional to the forecast error.

Source code in src/monet_plots/plots/ensemble.py
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
class SpreadSkillPlot(BasePlot):
    """Create a spread-skill plot to evaluate ensemble forecast reliability.

    This plot compares the standard deviation of the ensemble spread to the
    root mean squared error (RMSE) of the ensemble mean. A reliable ensemble
    should have a spread that is proportional to the forecast error.
    """

    def __init__(self, spread, skill, *args, **kwargs):
        """
        Initialize the plot with spread and skill data.

        Args:
            spread (array-like): The standard deviation of the ensemble forecast.
            skill (array-like): The root mean squared error of the ensemble mean.
            *args, **kwargs: Forwarded unchanged to ``BasePlot.__init__``.
        """
        super().__init__(*args, **kwargs)
        self.spread = np.asarray(spread)
        self.skill = np.asarray(skill)

    def plot(self, **kwargs):
        """Generate the spread-skill plot.

        Parameters
        ----------
        **kwargs
            Passed to ``Axes.scatter`` for the spread-skill points.

        Returns
        -------
        matplotlib.axes.Axes
            The axes the plot was drawn on (``self.ax``).
        """
        # Plot the spread-skill pairs
        self.ax.scatter(self.spread, self.skill, **kwargs)

        # 1:1 reference line, sized from finite values only. The previous
        # np.max() propagated any NaN (making the line invisible), and the
        # built-in max(nan, x) result depended on argument order.
        values = np.concatenate(
            [np.ravel(self.spread), np.ravel(self.skill)]
        ).astype(float)
        finite = values[np.isfinite(values)]
        if finite.size:
            max_val = finite.max()
            self.ax.plot([0, max_val], [0, max_val], "k--")

        # Set labels and title
        self.ax.set_xlabel("Ensemble Spread (Standard Deviation)")
        self.ax.set_ylabel("Ensemble Error (RMSE)")
        self.ax.set_title("Spread-Skill Plot")

        # Equal aspect so the 1:1 line is a true diagonal.
        self.ax.set_aspect("equal", "box")

        return self.ax

__init__(spread, skill, *args, **kwargs)

Initialize the plot with spread and skill data.

Parameters:

Name Type Description Default
spread array-like

The standard deviation of the ensemble forecast.

required
skill array-like

The root mean squared error of the ensemble mean.

required
Source code in src/monet_plots/plots/ensemble.py
14
15
16
17
18
19
20
21
22
23
24
def __init__(self, spread, skill, *args, **kwargs):
    """
    Store the spread and skill values that will be plotted.

    Args:
        spread (array-like): Standard deviation of the ensemble forecast;
            stored as ``np.asarray(spread)``.
        skill (array-like): Root mean squared error of the ensemble mean;
            stored as ``np.asarray(skill)``.
        *args, **kwargs: Forwarded unchanged to ``BasePlot.__init__``.
    """
    super().__init__(*args, **kwargs)
    self.spread, self.skill = np.asarray(spread), np.asarray(skill)

plot(**kwargs)

Generate the spread-skill plot.

Additional keyword arguments are passed to the scatter plot.

Source code in src/monet_plots/plots/ensemble.py
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
def plot(self, **kwargs):
    """Generate the spread-skill plot.

    Parameters
    ----------
    **kwargs
        Passed to ``Axes.scatter`` for the spread-skill points.

    Returns
    -------
    matplotlib.axes.Axes
        The axes the plot was drawn on (``self.ax``).
    """
    # Plot the spread-skill pairs
    self.ax.scatter(self.spread, self.skill, **kwargs)

    # 1:1 reference line, sized from finite values only. The previous
    # np.max() propagated any NaN (making the line invisible), and the
    # built-in max(nan, x) result depended on argument order.
    values = np.concatenate(
        [np.ravel(self.spread), np.ravel(self.skill)]
    ).astype(float)
    finite = values[np.isfinite(values)]
    if finite.size:
        max_val = finite.max()
        self.ax.plot([0, max_val], [0, max_val], "k--")

    # Set labels and title
    self.ax.set_xlabel("Ensemble Spread (Standard Deviation)")
    self.ax.set_ylabel("Ensemble Error (RMSE)")
    self.ax.set_title("Spread-Skill Plot")

    # Equal aspect so the 1:1 line is a true diagonal.
    self.ax.set_aspect("equal", "box")

    return self.ax

StickPlot

Bases: BasePlot

Vertical stick plot.

Source code in src/monet_plots/plots/profile.py
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
class StickPlot(BasePlot):
    """Vertical stick plot.

    Draws one arrow per entry of ``y``, all anchored at ``x = 0``, using the
    ``u``/``v`` vector components.
    """

    def __init__(self, u, v, y, *args, **kwargs):
        """
        Initialize the stick plot.

        Args:
            u (np.ndarray, pd.Series, xr.DataArray): U-component of the vector.
            v (np.ndarray, pd.Series, xr.DataArray): V-component of the vector.
            y (np.ndarray, pd.Series, xr.DataArray): Vertical coordinate.
            *args, **kwargs: Additional arguments passed to BasePlot.
        """
        super().__init__(*args, **kwargs)
        self.u = u
        self.v = v
        self.y = y
        # Every stick starts on the same vertical line, x = 0.
        self.x = np.zeros_like(self.y)

    def plot(self, **kwargs: t.Any) -> t.Any:
        """Draw the sticks with ``Axes.quiver``.

        The annotation was previously ``-> None`` even though the quiver
        artist is returned.

        Parameters
        ----------
        **kwargs
            Keyword arguments passed to `matplotlib.axes.Axes.quiver`.

        Returns
        -------
        matplotlib.quiver.Quiver
            The artist returned by ``self.ax.quiver``.
        """
        # Lazily create a figure/axes if BasePlot did not provide them.
        if self.ax is None:
            if self.fig is None:
                self.fig = plt.figure()
            self.ax = self.fig.add_subplot()

        return self.ax.quiver(self.x, self.y, self.u, self.v, **kwargs)

__init__(u, v, y, *args, **kwargs)

Initialize the stick plot. Args: u (np.ndarray, pd.Series, xr.DataArray): U-component of the vector. v (np.ndarray, pd.Series, xr.DataArray): V-component of the vector. y (np.ndarray, pd.Series, xr.DataArray): Vertical coordinate. **kwargs: Additional keyword arguments passed to BasePlot.

Source code in src/monet_plots/plots/profile.py
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
def __init__(self, u, v, y, *args, **kwargs):
    """
    Store the vector components and the vertical coordinate.

    Args:
        u (np.ndarray, pd.Series, xr.DataArray): U-component of the vector.
        v (np.ndarray, pd.Series, xr.DataArray): V-component of the vector.
        y (np.ndarray, pd.Series, xr.DataArray): Vertical coordinate.
        *args, **kwargs: Additional arguments passed to BasePlot.
    """
    super().__init__(*args, **kwargs)
    self.u, self.v, self.y = u, v, y
    # Every stick starts on the same vertical line, x = 0.
    self.x = np.zeros_like(y)

plot(**kwargs)

Parameters

**kwargs Keyword arguments passed to matplotlib.pyplot.quiver.

Source code in src/monet_plots/plots/profile.py
107
108
109
110
111
112
113
114
115
116
117
118
119
def plot(self, **kwargs: t.Any) -> t.Any:
    """Draw the sticks with ``Axes.quiver``.

    The annotation was previously ``-> None`` even though the quiver artist
    is returned.

    Parameters
    ----------
    **kwargs
        Keyword arguments passed to `matplotlib.axes.Axes.quiver`.

    Returns
    -------
    matplotlib.quiver.Quiver
        The artist returned by ``self.ax.quiver``.
    """
    # Lazily create a figure/axes if BasePlot did not provide them.
    if self.ax is None:
        if self.fig is None:
            self.fig = plt.figure()
        self.ax = self.fig.add_subplot()

    return self.ax.quiver(self.x, self.y, self.u, self.v, **kwargs)

TaylorDiagramPlot

Bases: BasePlot

Create a DataFrame-based Taylor diagram.

A convenience wrapper for easily creating Taylor diagrams from DataFrames.

Source code in src/monet_plots/plots/taylor_diagram.py
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
class TaylorDiagramPlot(BasePlot):
    """Create a DataFrame-based Taylor diagram.

    A convenience wrapper for easily creating Taylor diagrams from DataFrames.
    """

    def __init__(
        self,
        df: Any,
        col1: str = "obs",
        col2: Union[str, List[str]] = "model",
        label1: str = "OBS",
        scale: float = 1.5,
        dia=None,
        *args,
        **kwargs,
    ):
        """
        Initialize the plot with data and diagram settings.

        Args:
            df (pd.DataFrame, np.ndarray, xr.Dataset, xr.DataArray): DataFrame with observation and model data.
            col1 (str): Column name for observations.
            col2 (str or list): Column name(s) for model predictions. Any
                iterable of names (e.g. a tuple) is accepted.
            label1 (str): Label for observations.
            scale (float): Scale factor for diagram.
            dia (TaylorDiagram, optional): Existing diagram to add to.
        """
        super().__init__(*args, **kwargs)
        self.col1 = col1
        # Always store a fresh list: a tuple previously broke the list
        # concatenation below, and a list argument was aliased.
        self.col2 = [col2] if isinstance(col2, str) else list(col2)

        # dropna(subset=...) raises KeyError if a requested column is
        # missing, and keeps only rows where obs and every model are
        # present so all statistics use the same paired samples.
        required_cols = [self.col1] + self.col2
        self.df = to_dataframe(df).dropna(subset=required_cols)

        self.label1 = label1
        self.scale = scale
        self.dia = dia

    def plot(self, **kwargs):
        """Generate the Taylor diagram.

        Parameters
        ----------
        **kwargs
            Passed to ``dia.add_sample`` for every model column.

        Returns
        -------
        The Taylor diagram object (``self.dia``).
        """
        # If no diagram is provided, create a new one
        if self.dia is None:
            obsstd = self.df[self.col1].std()

            # Remove the default axes created by BasePlot to avoid an extra empty plot
            if hasattr(self, "ax") and self.ax is not None:
                self.fig.delaxes(self.ax)

            # Use self.fig which is created in BasePlot.__init__
            self.dia = td.TaylorDiagram(
                obsstd, scale=self.scale, fig=self.fig, rect=111, label=self.label1
            )
            # Update self.ax to the one created by TaylorDiagram
            self.ax = self.dia._ax

            # Add contours for the new diagram. Label them through the
            # contour set itself: plt.clabel routes through pyplot's
            # current axes (gca), which can create a stray figure/axes when
            # self.fig is not the active pyplot figure.
            contours = self.dia.add_contours(colors="0.5")
            contours.clabel(inline=1, fontsize=10)

        # Loop through each model column and add it to the diagram
        for model_col in self.col2:
            model_std = self.df[model_col].std()
            cc = corrcoef(self.df[self.col1].values, self.df[model_col].values)[0, 1]
            self.dia.add_sample(model_std, cc, label=model_col, **kwargs)

        self.fig.legend(
            self.dia.samplePoints,
            [p.get_label() for p in self.dia.samplePoints],
            numpoints=1,
            loc="upper right",
        )
        self.fig.tight_layout()
        return self.dia

__init__(df, col1='obs', col2='model', label1='OBS', scale=1.5, dia=None, *args, **kwargs)

Initialize the plot with data and diagram settings.

Parameters:

Name Type Description Default
df (DataFrame, ndarray, Dataset, DataArray)

DataFrame with observation and model data.

required
col1 str

Column name for observations.

'obs'
col2 str or list

Column name(s) for model predictions.

'model'
label1 str

Label for observations.

'OBS'
scale float

Scale factor for diagram.

1.5
dia TaylorDiagram

Existing diagram to add to.

None
Source code in src/monet_plots/plots/taylor_diagram.py
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
def __init__(
    self,
    df: Any,
    col1: str = "obs",
    col2: Union[str, List[str]] = "model",
    label1: str = "OBS",
    scale: float = 1.5,
    dia=None,
    *args,
    **kwargs,
):
    """
    Initialize the plot with data and diagram settings.

    Args:
        df (pd.DataFrame, np.ndarray, xr.Dataset, xr.DataArray): DataFrame with observation and model data.
        col1 (str): Column name for observations.
        col2 (str or list): Column name(s) for model predictions. Any
            iterable of names (e.g. a tuple) is accepted.
        label1 (str): Label for observations.
        scale (float): Scale factor for diagram.
        dia (TaylorDiagram, optional): Existing diagram to add to.
    """
    super().__init__(*args, **kwargs)
    self.col1 = col1
    # Always store a fresh list: a tuple previously broke the list
    # concatenation below, and a list argument was aliased.
    self.col2 = [col2] if isinstance(col2, str) else list(col2)

    # dropna(subset=...) raises KeyError if a requested column is missing,
    # and keeps only rows where obs and every model are present so all
    # statistics use the same paired samples.
    required_cols = [self.col1] + self.col2
    self.df = to_dataframe(df).dropna(subset=required_cols)

    self.label1 = label1
    self.scale = scale
    self.dia = dia

plot(**kwargs)

Generate the Taylor diagram.

Source code in src/monet_plots/plots/taylor_diagram.py
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
def plot(self, **kwargs):
    """Generate the Taylor diagram.

    Parameters
    ----------
    **kwargs
        Passed to ``dia.add_sample`` for every model column.

    Returns
    -------
    The Taylor diagram object (``self.dia``).
    """
    # If no diagram is provided, create a new one
    if self.dia is None:
        obsstd = self.df[self.col1].std()

        # Remove the default axes created by BasePlot to avoid an extra empty plot
        if hasattr(self, "ax") and self.ax is not None:
            self.fig.delaxes(self.ax)

        # Use self.fig which is created in BasePlot.__init__
        self.dia = td.TaylorDiagram(
            obsstd, scale=self.scale, fig=self.fig, rect=111, label=self.label1
        )
        # Update self.ax to the one created by TaylorDiagram
        self.ax = self.dia._ax

        # Add contours for the new diagram. Label them through the contour
        # set itself: plt.clabel routes through pyplot's current axes (gca),
        # which can create a stray figure/axes when self.fig is not the
        # active pyplot figure.
        contours = self.dia.add_contours(colors="0.5")
        contours.clabel(inline=1, fontsize=10)

    # Loop through each model column and add it to the diagram
    for model_col in self.col2:
        model_std = self.df[model_col].std()
        cc = corrcoef(self.df[self.col1].values, self.df[model_col].values)[0, 1]
        self.dia.add_sample(model_std, cc, label=model_col, **kwargs)

    self.fig.legend(
        self.dia.samplePoints,
        [p.get_label() for p in self.dia.samplePoints],
        numpoints=1,
        loc="upper right",
    )
    self.fig.tight_layout()
    return self.dia

TimeSeriesPlot

Bases: BasePlot

Create a timeseries plot with shaded error bounds.

This function groups the data by time, plots the mean values, and adds shading for ±1 standard deviation around the mean.

Source code in src/monet_plots/plots/timeseries.py
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
class TimeSeriesPlot(BasePlot):
    """Create a timeseries plot with shaded error bounds.

    The data are grouped by the x (time) coordinate, the mean of ``y`` is
    drawn as a line, and the band mean ± 1 standard deviation is shaded.
    """

    def __init__(
        self,
        df: Any,
        x: str = "time",
        y: str = "obs",
        plotargs: Optional[dict] = None,
        fillargs: Optional[dict] = None,
        title: str = "",
        ylabel: Optional[str] = None,
        label: Optional[str] = None,
        *args,
        **kwargs,
    ):
        """
        Initialize the plot with data and plot settings.

        Args:
            df (pd.DataFrame, np.ndarray, xr.Dataset, xr.DataArray):
                Data to plot; normalized with ``normalize_data``.
            x (str): Column/dimension name for the x-axis (time).
            y (str): Column/variable name for the y-axis (values).
            plotargs (dict, optional): Keyword arguments for the mean line's
                ``plot`` call. Defaults to an empty dict.
            fillargs (dict, optional): Keyword arguments for
                ``fill_between``. Defaults to ``{"alpha": 0.2}``.
            title (str): Title for the plot.
            ylabel (str, optional): Y-axis label. When omitted, a label of
                the form "<y> (<units>)" is built at plot time.
            label (str, optional): Legend label for the line; defaults to
                ``y``.
            *args, **kwargs: Arguments passed to BasePlot.
        """
        super().__init__(*args, **kwargs)
        if self.ax is None:
            self.ax = self.fig.add_subplot(1, 1, 1)

        self.df = normalize_data(df, prefer_xarray=False)
        self.x = x
        self.y = y
        # ``None`` sentinel rather than a ``{}`` default: a dict literal in
        # the signature is created once and shared by every instance.
        self.plotargs = plotargs if plotargs is not None else {}
        self.fillargs = fillargs if fillargs is not None else {"alpha": 0.2}
        self.title = title
        self.ylabel = ylabel
        self.label = label

    def plot(self, **kwargs: Any) -> plt.Axes:
        """
        Generate the timeseries plot.

        Parameters
        ----------
        **kwargs : Any
            Overrides for the stored settings ``x``, ``y``, ``title``,
            ``ylabel`` and ``label``. Any other keys are forwarded to the
            type-specific helper, which currently does not use them.

        Returns
        -------
        plt.Axes
            The matplotlib axes object containing the plot.

        Examples
        --------
        >>> plot = TimeSeriesPlot(df, x='time', y='obs')
        >>> ax = plot.plot(title='Observation Over Time')
        """
        # Update attributes from kwargs if provided
        for attr in ["x", "y", "title", "ylabel", "label"]:
            if attr in kwargs:
                setattr(self, attr, kwargs.pop(attr))

        import xarray as xr

        # Handle xarray objects differently from pandas DataFrames
        if isinstance(self.df, (xr.DataArray, xr.Dataset)):
            return self._plot_xarray(**kwargs)
        else:
            return self._plot_dataframe(**kwargs)

    def _plot_dataframe(self, **kwargs: Any) -> plt.Axes:
        """
        Generate the timeseries plot from a pandas DataFrame.

        Rows are grouped by ``self.x`` (a column or index-level name). The
        per-group mean of ``self.y`` is drawn as a line and the band
        mean ± std is shaded. If a ``units`` column exists, its first value
        is used in the default y-label.

        Parameters
        ----------
        **kwargs : Any
            Accepted for API symmetry; currently unused.

        Returns
        -------
        plt.Axes
            The matplotlib axes object.

        Examples
        --------
        >>> plot._plot_dataframe()
        """
        # Group once and reuse for both statistics.
        grouped = self.df.groupby(self.x)
        m = grouped.mean(numeric_only=True)
        e = grouped.std(numeric_only=True)

        unit = "None"
        if "units" in self.df.columns:
            unit = str(self.df["units"].iloc[0])

        upper = m[self.y] + e[self.y]
        lower = m[self.y] - e[self.y]
        # The lower bound is intentionally not clipped at zero: not every
        # variable is non-negative.

        plot_label = self.label if self.label is not None else self.y

        m[self.y].plot(ax=self.ax, label=plot_label, **self.plotargs)
        self.ax.fill_between(m.index, lower.values, upper.values, **self.fillargs)

        if self.ylabel is None:
            self.ax.set_ylabel(f"{self.y} ({unit})")
        else:
            self.ax.set_ylabel(self.ylabel)

        self.ax.set_xlabel(self.x)
        self.ax.legend()
        self.ax.set_title(self.title)
        self.fig.tight_layout()
        return self.ax

    def _plot_xarray(self, **kwargs: Any) -> plt.Axes:
        """
        Generate the timeseries plot from an xarray DataArray or Dataset.

        A named DataArray replaces ``self.y`` with its own name. All
        dimensions other than ``self.x`` are reduced with mean/std. For data
        that is already 1-D along ``self.x``, the band has zero width. The
        ``units`` attribute, if present, is used in the default y-label.

        Parameters
        ----------
        **kwargs : Any
            Accepted for API symmetry; currently unused.

        Returns
        -------
        plt.Axes
            The matplotlib axes object.

        Examples
        --------
        >>> plot._plot_xarray()
        """
        import xarray as xr

        # Ensure we have a Dataset keyed by self.y
        if isinstance(self.df, xr.DataArray):
            data = (
                self.df.to_dataset(name=self.y)
                if self.df.name is None
                else self.df.to_dataset()
            )
            if self.df.name is not None:
                self.y = self.df.name
        else:
            data = self.df

        # Reduce every dimension except the time axis.
        dims_to_reduce = [d for d in data[self.y].dims if d != self.x]

        if dims_to_reduce:
            mean_data = data[self.y].mean(dim=dims_to_reduce)
            std_data = data[self.y].std(dim=dims_to_reduce)
        else:
            mean_data = data[self.y]
            std_data = xr.zeros_like(mean_data)

        plot_label = self.label if self.label is not None else self.y
        mean_data.plot(ax=self.ax, label=plot_label, **self.plotargs)

        upper = mean_data + std_data
        lower = mean_data - std_data

        self.ax.fill_between(
            mean_data[self.x].values, lower.values, upper.values, **self.fillargs
        )

        unit = data[self.y].attrs.get("units", "None")

        if self.ylabel is None:
            self.ax.set_ylabel(f"{self.y} ({unit})")
        else:
            self.ax.set_ylabel(self.ylabel)

        self.ax.set_xlabel(self.x)
        self.ax.legend()
        self.ax.set_title(self.title)
        self.fig.tight_layout()
        return self.ax

__init__(df, x='time', y='obs', plotargs={}, fillargs=None, title='', ylabel=None, label=None, *args, **kwargs)

Initialize the plot with data and plot settings.

Parameters:

Name Type Description Default
df (DataFrame, ndarray, Dataset, DataArray)

DataFrame with the data to plot.

required
x str

Column name for the x-axis (time).

'time'
y str

Column name for the y-axis (values).

'obs'
plotargs dict

Arguments for the plot.

{}
fillargs dict

Arguments for fill_between.

None
title str

Title for the plot.

''
ylabel str

Y-axis label.

None
label str

Label for the plotted line.

None
*args, **kwargs

Arguments passed to BasePlot.

required
Source code in src/monet_plots/plots/timeseries.py
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
def __init__(
    self,
    df: Any,
    x: str = "time",
    y: str = "obs",
    plotargs: Optional[dict] = None,
    fillargs: Optional[dict] = None,
    title: str = "",
    ylabel: Optional[str] = None,
    label: Optional[str] = None,
    *args,
    **kwargs,
):
    """
    Initialize the plot with data and plot settings.

    Args:
        df (pd.DataFrame, np.ndarray, xr.Dataset, xr.DataArray):
            Data to plot; normalized with ``normalize_data``.
        x (str): Column/dimension name for the x-axis (time).
        y (str): Column/variable name for the y-axis (values).
        plotargs (dict, optional): Keyword arguments for the mean line's
            ``plot`` call. Defaults to an empty dict.
        fillargs (dict, optional): Keyword arguments for ``fill_between``.
            Defaults to ``{"alpha": 0.2}``.
        title (str): Title for the plot.
        ylabel (str, optional): Y-axis label.
        label (str, optional): Label for the plotted line.
        *args, **kwargs: Arguments passed to BasePlot.
    """
    super().__init__(*args, **kwargs)
    if self.ax is None:
        self.ax = self.fig.add_subplot(1, 1, 1)

    self.df = normalize_data(df, prefer_xarray=False)
    self.x = x
    self.y = y
    # ``None`` sentinel rather than a ``{}`` default: a dict literal in the
    # signature is created once and shared by every call.
    self.plotargs = plotargs if plotargs is not None else {}
    self.fillargs = fillargs if fillargs is not None else {"alpha": 0.2}
    self.title = title
    self.ylabel = ylabel
    self.label = label

plot(**kwargs)

Generate the timeseries plot.

Parameters

**kwargs : Any Overrides for plot settings (x, y, title, ylabel, label, etc.).

Returns

plt.Axes The matplotlib axes object containing the plot.

Examples

`plot = TimeSeriesPlot(df, x='time', y='obs')` followed by `ax = plot.plot(title='Observation Over Time')`

Source code in src/monet_plots/plots/timeseries.py
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
def plot(self, **kwargs: Any) -> plt.Axes:
    """
    Draw the timeseries, choosing the renderer from the stored data's type.

    Parameters
    ----------
    **kwargs : Any
        Values for ``x``, ``y``, ``title``, ``ylabel`` or ``label`` replace
        the stored settings; any remaining keys are handed to the renderer.

    Returns
    -------
    plt.Axes
        The axes the timeseries was drawn on.

    Examples
    --------
    >>> plot = TimeSeriesPlot(df, x='time', y='obs')
    >>> ax = plot.plot(title='Observation Over Time')
    """
    overridable = ("x", "y", "title", "ylabel", "label")
    for name in overridable:
        if name in kwargs:
            setattr(self, name, kwargs.pop(name))

    import xarray as xr

    # xarray containers and pandas frames need different reductions.
    is_xarray = isinstance(self.df, (xr.DataArray, xr.Dataset))
    renderer = self._plot_xarray if is_xarray else self._plot_dataframe
    return renderer(**kwargs)

TrajectoryPlot

Bases: BasePlot

Plot a trajectory on a map and a timeseries of a variable.

Source code in src/monet_plots/plots/trajectory.py
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
class TrajectoryPlot(BasePlot):
    """Plot a trajectory on a map and a timeseries of a variable."""

    def __init__(
        self,
        longitude: t.Any,
        latitude: t.Any,
        data: t.Any,
        time: t.Any,
        ts_data: t.Any,
        *args: t.Any,
        **kwargs: t.Any,
    ) -> None:
        """
        Initialize the trajectory plot.

        Args:
            longitude: Longitude values for the spatial track.
            latitude: Latitude values for the spatial track.
            data: Data to use for coloring the track.
            time: Time values for the timeseries or a DataFrame.
            ts_data: Data for the timeseries or column name if time is a DataFrame.
            *args: Additional positional arguments passed to BasePlot.
            **kwargs: Additional keyword arguments passed to BasePlot. A new
                figure is created when neither ``fig`` nor ``ax`` is given.
        """
        if "fig" not in kwargs and "ax" not in kwargs:
            kwargs["fig"] = plt.figure()
        super().__init__(*args, **kwargs)
        self.longitude = longitude
        self.latitude = latitude
        self.data = data
        self.time = time
        self.ts_data = ts_data

    def plot(self, **kwargs: t.Any) -> list:
        """Plot the trajectory map (top panel) and the timeseries (bottom panel).

        Args:
            **kwargs: Options for the two panels:
                figsize (tuple): Figure size in inches. Used when a new figure
                    is created here, and applied to an existing figure
                    otherwise.
                projection: Cartopy projection for the map panel
                    (default ``ccrs.PlateCarree()``).
                spatial_track_kwargs (dict): Passed to ``SpatialTrack.plot``.
                timeseries_kwargs (dict): Passed to ``TimeSeriesPlot.plot``.

        Returns:
            list: ``[map_axes, timeseries_axes]``; also stored in ``self.ax``.
        """
        if self.fig is None:
            self.fig = plt.figure(figsize=kwargs.get("figsize", (12, 6)))
        elif "figsize" in kwargs:
            # __init__ normally creates the figure already; without this an
            # explicitly requested figsize would be silently ignored.
            self.fig.set_size_inches(kwargs["figsize"])

        # Ensure constrained_layout to help with alignment
        self.fig.set_constrained_layout(True)

        # Map panel gets three times the height of the timeseries panel.
        gs = self.fig.add_gridspec(2, 1, height_ratios=[3, 1])

        # Spatial track plot
        import cartopy.crs as ccrs

        proj = kwargs.get("projection", ccrs.PlateCarree())
        ax0 = self.fig.add_subplot(gs[0, 0], projection=proj)

        # Set adjustable to 'datalim' to allow the map to fill the horizontal
        # space while maintaining equal aspect ratio by expanding the limits.
        ax0.set_adjustable("datalim")

        # Build a 1-D DataArray indexed by sample number, with lon/lat as
        # coordinates along that dimension.
        lon = np.asarray(self.longitude)
        lat = np.asarray(self.latitude)
        values = np.asarray(self.data)
        time_dim = np.arange(len(lon))
        coords = {"time": time_dim, "lon": ("time", lon), "lat": ("time", lat)}
        track_da = xr.DataArray(values, dims=["time"], coords=coords, name="track_data")

        # Pass the DataArray to SpatialTrack
        plot_kwargs = kwargs.get("spatial_track_kwargs", {})
        spatial_track = SpatialTrack(data=track_da, ax=ax0)
        spatial_track.plot(**plot_kwargs)

        # Timeseries plot
        ax1 = self.fig.add_subplot(gs[1, 0])

        timeseries_kwargs = kwargs.get("timeseries_kwargs", {}).copy()

        if isinstance(self.time, pd.DataFrame):
            # Already a DataFrame; ts_data names the column to plot.
            timeseries = TimeSeriesPlot(
                df=self.time, y=self.ts_data, ax=ax1, fig=self.fig
            )
        else:
            # Assume arrays
            ts_df = pd.DataFrame({"time": self.time, "value": np.asarray(self.ts_data)})
            timeseries = TimeSeriesPlot(
                df=ts_df, x="time", y="value", ax=ax1, fig=self.fig
            )

        timeseries.plot(**timeseries_kwargs)

        self.ax = [ax0, ax1]
        return self.ax

__init__(longitude, latitude, data, time, ts_data, *args, **kwargs)

Initialize the trajectory plot. Args: longitude: Longitude values for the spatial track. latitude: Latitude values for the spatial track. data: Data to use for coloring the track. time: Time values for the timeseries or a DataFrame. ts_data: Data for the timeseries or column name if time is a DataFrame. *args: Additional positional arguments. **kwargs: Additional keyword arguments.

Source code in src/monet_plots/plots/trajectory.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
def __init__(
    self,
    longitude: t.Any,
    latitude: t.Any,
    data: t.Any,
    time: t.Any,
    ts_data: t.Any,
    *args: t.Any,
    **kwargs: t.Any,
) -> None:
    """
    Store the track coordinates and timeseries inputs for later plotting.

    Args:
        longitude: Longitude values of the track points.
        latitude: Latitude values of the track points.
        data: Values used to colour the track.
        time: Timeseries time values, or a DataFrame holding the timeseries.
        ts_data: Timeseries values, or the column name when ``time`` is a
            DataFrame.
        *args: Extra positional arguments forwarded to BasePlot.
        **kwargs: Extra keyword arguments forwarded to BasePlot.
    """
    # Guarantee a figure to draw on when the caller supplied neither one.
    if not ("fig" in kwargs or "ax" in kwargs):
        kwargs["fig"] = plt.figure()
    super().__init__(*args, **kwargs)
    self.longitude, self.latitude = longitude, latitude
    self.data = data
    self.time, self.ts_data = time, ts_data

plot(**kwargs)

Plot the trajectory and timeseries.

Parameters:

Name Type Description Default
**kwargs Any

Keyword arguments passed to the plot methods.

{}
Source code in src/monet_plots/plots/trajectory.py
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
def plot(self, **kwargs: t.Any) -> list:
    """Plot the trajectory map (top panel) and the timeseries (bottom panel).

    Args:
        **kwargs: Options for the two panels:
            figsize (tuple): Figure size in inches. Used when a new figure is
                created here, and applied to an existing figure otherwise.
            projection: Cartopy projection for the map panel
                (default ``ccrs.PlateCarree()``).
            spatial_track_kwargs (dict): Passed to ``SpatialTrack.plot``.
            timeseries_kwargs (dict): Passed to ``TimeSeriesPlot.plot``.

    Returns:
        list: ``[map_axes, timeseries_axes]``; also stored in ``self.ax``.
    """
    if self.fig is None:
        self.fig = plt.figure(figsize=kwargs.get("figsize", (12, 6)))
    elif "figsize" in kwargs:
        # __init__ normally creates the figure already; without this an
        # explicitly requested figsize would be silently ignored.
        self.fig.set_size_inches(kwargs["figsize"])

    # Ensure constrained_layout to help with alignment
    self.fig.set_constrained_layout(True)

    # Map panel gets three times the height of the timeseries panel.
    gs = self.fig.add_gridspec(2, 1, height_ratios=[3, 1])

    # Spatial track plot
    import cartopy.crs as ccrs

    proj = kwargs.get("projection", ccrs.PlateCarree())
    ax0 = self.fig.add_subplot(gs[0, 0], projection=proj)

    # Set adjustable to 'datalim' to allow the map to fill the horizontal
    # space while maintaining equal aspect ratio by expanding the limits.
    ax0.set_adjustable("datalim")

    # Build a 1-D DataArray indexed by sample number, with lon/lat as
    # coordinates along that dimension.
    lon = np.asarray(self.longitude)
    lat = np.asarray(self.latitude)
    values = np.asarray(self.data)
    time_dim = np.arange(len(lon))
    coords = {"time": time_dim, "lon": ("time", lon), "lat": ("time", lat)}
    track_da = xr.DataArray(values, dims=["time"], coords=coords, name="track_data")

    # Pass the DataArray to SpatialTrack
    plot_kwargs = kwargs.get("spatial_track_kwargs", {})
    spatial_track = SpatialTrack(data=track_da, ax=ax0)
    spatial_track.plot(**plot_kwargs)

    # Timeseries plot
    ax1 = self.fig.add_subplot(gs[1, 0])

    timeseries_kwargs = kwargs.get("timeseries_kwargs", {}).copy()

    if isinstance(self.time, pd.DataFrame):
        # Already a DataFrame; ts_data names the column to plot.
        timeseries = TimeSeriesPlot(
            df=self.time, y=self.ts_data, ax=ax1, fig=self.fig
        )
    else:
        # Assume arrays
        ts_df = pd.DataFrame({"time": self.time, "value": np.asarray(self.ts_data)})
        timeseries = TimeSeriesPlot(
            df=ts_df, x="time", y="value", ax=ax1, fig=self.fig
        )

    timeseries.plot(**timeseries_kwargs)

    self.ax = [ax0, ax1]
    return self.ax

VerticalBoxPlot

Bases: BasePlot

Vertical box plot.

Source code in src/monet_plots/plots/profile.py
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
class VerticalBoxPlot(BasePlot):
    """Vertical box plot."""

    def __init__(self, data, y, thresholds, *args, **kwargs):
        """
        Initialize the vertical box plot.

        Args:
            data (np.ndarray, pd.Series, xr.DataArray): Data to plot.
            y (np.ndarray, pd.Series, xr.DataArray): Vertical coordinate.
            thresholds (list): List of thresholds to bin the data.
            *args, **kwargs: Additional arguments passed to BasePlot.
        """
        super().__init__(*args, **kwargs)
        self.data = data
        self.y = y
        self.thresholds = thresholds

    def plot(self, **kwargs: t.Any) -> dict:
        """
        Draw one horizontal box per threshold interval.

        The data are split with ``tools.split_by_threshold`` and, by default,
        each box is placed at the midpoint of consecutive thresholds.

        Parameters
        ----------
        **kwargs
            Keyword arguments passed to ``matplotlib.axes.Axes.boxplot``.
            ``vert`` defaults to ``False`` and ``positions`` to the threshold
            midpoints; either may be overridden.

        Returns
        -------
        dict
            The artist dictionary returned by ``Axes.boxplot``.
        """
        if self.ax is None:
            if self.fig is None:
                self.fig = plt.figure()
            self.ax = self.fig.add_subplot()

        output_list = tools.split_by_threshold(self.data, self.y, self.thresholds)
        lower_bounds = self.thresholds[:-1]
        upper_bounds = self.thresholds[1:]
        midpoints = [(lo + hi) / 2 for lo, hi in zip(lower_bounds, upper_bounds)]

        # Defaults rather than hard-wired arguments: passing either one
        # explicitly used to raise "got multiple values for keyword argument".
        kwargs.setdefault("vert", False)
        kwargs.setdefault("positions", midpoints)
        return self.ax.boxplot(output_list, **kwargs)

__init__(data, y, thresholds, *args, **kwargs)

Initialize the vertical box plot. Args: data (np.ndarray, pd.Series, xr.DataArray): Data to plot. y (np.ndarray, pd.Series, xr.DataArray): Vertical coordinate. thresholds (list): List of thresholds to bin the data. **kwargs: Additional keyword arguments passed to BasePlot.

Source code in src/monet_plots/plots/profile.py
125
126
127
128
129
130
131
132
133
134
135
136
137
def __init__(self, data, y, thresholds, *args, **kwargs):
    """
    Keep the data, vertical coordinate and bin thresholds for plotting.

    Args:
        data (np.ndarray, pd.Series, xr.DataArray): Values to summarise.
        y (np.ndarray, pd.Series, xr.DataArray): Vertical coordinate of
            each value.
        thresholds (list): Bin edges along the vertical coordinate.
        *args, **kwargs: Forwarded to BasePlot.
    """
    super().__init__(*args, **kwargs)
    self.data, self.y, self.thresholds = data, y, thresholds

plot(**kwargs)

Parameters

**kwargs: Keyword arguments passed to `matplotlib.axes.Axes.boxplot`.

Source code in src/monet_plots/plots/profile.py
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
def plot(self, **kwargs: t.Any) -> dict:
    """
    Draw one horizontal box per threshold interval.

    The data are split with ``tools.split_by_threshold`` and, by default,
    each box is placed at the midpoint of consecutive thresholds.

    Parameters
    ----------
    **kwargs
        Keyword arguments passed to ``matplotlib.axes.Axes.boxplot``.
        ``vert`` defaults to ``False`` and ``positions`` to the threshold
        midpoints; either may be overridden.

    Returns
    -------
    dict
        The artist dictionary returned by ``Axes.boxplot``.
    """
    if self.ax is None:
        if self.fig is None:
            self.fig = plt.figure()
        self.ax = self.fig.add_subplot()

    output_list = tools.split_by_threshold(self.data, self.y, self.thresholds)
    lower_bounds = self.thresholds[:-1]
    upper_bounds = self.thresholds[1:]
    midpoints = [(lo + hi) / 2 for lo, hi in zip(lower_bounds, upper_bounds)]

    # Defaults rather than hard-wired arguments: passing either one
    # explicitly used to raise "got multiple values for keyword argument".
    kwargs.setdefault("vert", False)
    kwargs.setdefault("positions", midpoints)
    return self.ax.boxplot(output_list, **kwargs)

VerticalSlice

Bases: ProfilePlot

Vertical cross-section plot.

Source code in src/monet_plots/plots/profile.py
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
class VerticalSlice(ProfilePlot):
    """Vertical cross-section plot."""

    def __init__(self, *args, **kwargs):
        """
        Initialize the vertical slice plot.

        All arguments are forwarded unchanged to ``ProfilePlot``.
        """
        super().__init__(*args, **kwargs)

    def plot(self, **kwargs: t.Any) -> t.Any:
        """
        Draw filled contours of ``self.z`` over ``self.x`` and ``self.y``.

        ``x``, ``y`` and ``z`` are instance attributes, presumably populated
        by ``ProfilePlot.__init__`` (not visible here).

        Parameters
        ----------
        **kwargs
            Keyword arguments passed to ``matplotlib.axes.Axes.contourf``.

        Returns
        -------
        matplotlib.contour.QuadContourSet
            The contour set, so callers can e.g. attach a colorbar.
        """
        if self.ax is None:
            if self.fig is None:
                self.fig = plt.figure()
            self.ax = self.fig.add_subplot()

        return self.ax.contourf(self.x, self.y, self.z, **kwargs)

__init__(*args, **kwargs)

Initialize the vertical slice plot.

Source code in src/monet_plots/plots/profile.py
68
69
70
71
72
def __init__(self, *args, **kwargs):
    """Set up the cross-section plot; every argument is handed to ProfilePlot."""
    super().__init__(*args, **kwargs)

plot(**kwargs)

Parameters

**kwargs: Keyword arguments passed to `matplotlib.axes.Axes.contourf`.

Source code in src/monet_plots/plots/profile.py
74
75
76
77
78
79
80
81
82
83
84
85
86
def plot(self, **kwargs: t.Any) -> t.Any:
    """
    Draw filled contours of ``self.z`` over ``self.x`` and ``self.y``.

    Parameters
    ----------
    **kwargs
        Keyword arguments passed to ``matplotlib.axes.Axes.contourf``.

    Returns
    -------
    matplotlib.contour.QuadContourSet
        The contour set, so callers can e.g. attach a colorbar.
    """
    if self.ax is None:
        if self.fig is None:
            self.fig = plt.figure()
        self.ax = self.fig.add_subplot()

    return self.ax.contourf(self.x, self.y, self.z, **kwargs)

WindBarbsPlot

Bases: SpatialPlot

Create a barbs plot of wind on a map.

This plot shows wind speed and direction using barbs.

Source code in src/monet_plots/plots/wind_barbs.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
class WindBarbsPlot(SpatialPlot):
    """Map of wind barbs, encoding wind speed and direction on a grid."""

    def __init__(self, ws: Any, wdir: Any, gridobj, *args, **kwargs):
        """
        Store the wind fields and the grid that locates them.

        Args:
            ws (np.ndarray, pd.DataFrame, pd.Series, xr.DataArray): Wind
                speed on a 2D grid.
            wdir (np.ndarray, pd.DataFrame, pd.Series, xr.DataArray): Wind
                direction on the same grid.
            gridobj (object): Object exposing ``variables["LAT"]`` and
                ``variables["LON"]``.
            **kwargs: Forwarded to SpatialPlot (projection, map features).
        """
        super().__init__(*args, **kwargs)
        self.ws, self.wdir = np.asarray(ws), np.asarray(wdir)
        self.gridobj = gridobj

    def plot(self, **kwargs):
        """Draw subsampled wind barbs on the map axes.

        ``kwargs`` first go through ``self.add_features``; the returned dict
        feeds ``Axes.barbs`` after ``transform`` is defaulted and ``skip``
        (stride along both grid axes, default 15) is removed from it.
        Returns the map axes.
        """
        opts = self.add_features(**kwargs)
        opts.setdefault("transform", ccrs.PlateCarree())

        grid_vars = self.gridobj.variables
        lat = grid_vars["LAT"][0, 0, :, :].squeeze()
        lon = grid_vars["LON"][0, 0, :, :].squeeze()
        u, v = tools.wsdir2uv(self.ws, self.wdir)

        # Thin the grid so the barbs stay readable.
        step = opts.pop("skip", 15)
        every = (slice(None, None, step), slice(None, None, step))
        self.ax.barbs(lon[every], lat[every], u[every], v[every], **opts)
        return self.ax

__init__(ws, wdir, gridobj, *args, **kwargs)

Initialize the plot with data and map projection.

Parameters:

Name Type Description Default
ws (ndarray, DataFrame, Series, DataArray)

2D array of wind speeds.

required
wdir (ndarray, DataFrame, Series, DataArray)

2D array of wind directions.

required
gridobj object

Object with LAT and LON variables.

required
**kwargs

Keyword arguments passed to SpatialPlot for projection and features.

{}
Source code in src/monet_plots/plots/wind_barbs.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
def __init__(self, ws: Any, wdir: Any, gridobj, *args, **kwargs):
    """
    Store the wind fields and the grid that locates them.

    Args:
        ws (np.ndarray, pd.DataFrame, pd.Series, xr.DataArray): Wind speed
            on a 2D grid.
        wdir (np.ndarray, pd.DataFrame, pd.Series, xr.DataArray): Wind
            direction on the same grid.
        gridobj (object): Object exposing ``variables["LAT"]`` and
            ``variables["LON"]``.
        **kwargs: Forwarded to SpatialPlot (projection, map features).
    """
    super().__init__(*args, **kwargs)
    self.ws, self.wdir = np.asarray(ws), np.asarray(wdir)
    self.gridobj = gridobj

plot(**kwargs)

Generate the wind barbs plot.

Source code in src/monet_plots/plots/wind_barbs.py
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
def plot(self, **kwargs):
    """Draw subsampled wind barbs on the map axes.

    ``kwargs`` first go through ``self.add_features``; the returned dict
    feeds ``Axes.barbs`` after ``transform`` is defaulted and ``skip``
    (stride along both grid axes, default 15) is removed from it.
    Returns the map axes.
    """
    opts = self.add_features(**kwargs)
    opts.setdefault("transform", ccrs.PlateCarree())

    grid_vars = self.gridobj.variables
    lat = grid_vars["LAT"][0, 0, :, :].squeeze()
    lon = grid_vars["LON"][0, 0, :, :].squeeze()
    u, v = tools.wsdir2uv(self.ws, self.wdir)

    # Thin the grid so the barbs stay readable.
    step = opts.pop("skip", 15)
    every = (slice(None, None, step), slice(None, None, step))
    self.ax.barbs(lon[every], lat[every], u[every], v[every], **opts)
    return self.ax

WindQuiverPlot

Bases: SpatialPlot

Create a quiver plot of wind vectors on a map.

This plot shows wind speed and direction using arrows.

Source code in src/monet_plots/plots/wind_quiver.py
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
class WindQuiverPlot(SpatialPlot):
    """Create a quiver plot of wind vectors on a map.

    This plot shows wind speed and direction using arrows.
    """

    def __init__(self, ws: Any, wdir: Any, gridobj, *args, **kwargs):
        """
        Initialize the plot with data and map projection.

        Args:
            ws (np.ndarray, pd.DataFrame, pd.Series, xr.DataArray): 2D array of wind speeds.
            wdir (np.ndarray, pd.DataFrame, pd.Series, xr.DataArray): 2D array of wind directions.
            gridobj (object): Object with LAT and LON variables.
            **kwargs: Keyword arguments passed to SpatialPlot for projection and features.
        """
        super().__init__(*args, **kwargs)
        self.ws = np.asarray(ws)
        self.wdir = np.asarray(wdir)
        self.gridobj = gridobj

    def plot(self, **kwargs):
        """Generate the wind quiver plot.

        Args:
            **kwargs: First passed through ``self.add_features``; the dict it
                returns is forwarded to ``Axes.quiver``. ``transform``
                defaults to ``ccrs.PlateCarree()``. ``skip`` (int, default
                15) sets the subsampling stride along both grid axes, as in
                WindBarbsPlot, and is not forwarded to ``quiver``.

        Returns:
            The Quiver artist created by ``Axes.quiver``.
        """
        quiver_kwargs = self.add_features(**kwargs)
        quiver_kwargs.setdefault("transform", ccrs.PlateCarree())

        # Only index 0 of the two leading dimensions of LAT/LON is used.
        lat = self.gridobj.variables["LAT"][0, 0, :, :].squeeze()
        lon = self.gridobj.variables["LON"][0, 0, :, :].squeeze()
        u, v = tools.wsdir2uv(self.ws, self.wdir)
        # Subsample the data for clarity; popping ``skip`` keeps it out of
        # quiver(), which would otherwise reject it.
        skip = quiver_kwargs.pop("skip", 15)
        quiv = self.ax.quiver(
            lon[::skip, ::skip],
            lat[::skip, ::skip],
            u[::skip, ::skip],
            v[::skip, ::skip],
            **quiver_kwargs,
        )
        return quiv

__init__(ws, wdir, gridobj, *args, **kwargs)

Initialize the plot with data and map projection.

Parameters:

Name Type Description Default
ws (ndarray, DataFrame, Series, DataArray)

2D array of wind speeds.

required
wdir (ndarray, DataFrame, Series, DataArray)

2D array of wind directions.

required
gridobj object

Object with LAT and LON variables.

required
**kwargs

Keyword arguments passed to SpatialPlot for projection and features.

{}
Source code in src/monet_plots/plots/wind_quiver.py
18
19
20
21
22
23
24
25
26
27
28
29
30
31
def __init__(self, ws: Any, wdir: Any, gridobj, *args, **kwargs):
    """
    Keep the wind-speed/direction grids and the object supplying LAT/LON.

    Args:
        ws (np.ndarray, pd.DataFrame, pd.Series, xr.DataArray): Wind speed
            on a 2D grid.
        wdir (np.ndarray, pd.DataFrame, pd.Series, xr.DataArray): Wind
            direction on the same grid.
        gridobj (object): Object exposing ``variables["LAT"]`` and
            ``variables["LON"]``.
        **kwargs: Forwarded to SpatialPlot (projection, map features).
    """
    super().__init__(*args, **kwargs)
    self.ws, self.wdir = np.asarray(ws), np.asarray(wdir)
    self.gridobj = gridobj

plot(**kwargs)

Generate the wind quiver plot.

Source code in src/monet_plots/plots/wind_quiver.py
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
def plot(self, **kwargs):
    """Generate the wind quiver plot.

    Args:
        **kwargs: First passed through ``self.add_features``; the dict it
            returns is forwarded to ``Axes.quiver``. ``transform`` defaults
            to ``ccrs.PlateCarree()``. ``skip`` (int, default 15) sets the
            subsampling stride along both grid axes, as in WindBarbsPlot,
            and is not forwarded to ``quiver``.

    Returns:
        The Quiver artist created by ``Axes.quiver``.
    """
    quiver_kwargs = self.add_features(**kwargs)
    quiver_kwargs.setdefault("transform", ccrs.PlateCarree())

    # Only index 0 of the two leading dimensions of LAT/LON is used.
    lat = self.gridobj.variables["LAT"][0, 0, :, :].squeeze()
    lon = self.gridobj.variables["LON"][0, 0, :, :].squeeze()
    u, v = tools.wsdir2uv(self.ws, self.wdir)
    # Subsample the data for clarity; popping ``skip`` keeps it out of
    # quiver(), which would otherwise reject it.
    skip = quiver_kwargs.pop("skip", 15)
    quiv = self.ax.quiver(
        lon[::skip, ::skip],
        lat[::skip, ::skip],
        u[::skip, ::skip],
        v[::skip, ::skip],
        **quiver_kwargs,
    )
    return quiv

get_available_styles()

Returns a list of available style context names.

Returns

list[str] List of style names.

Source code in src/monet_plots/style.py
198
199
200
201
202
203
204
205
206
207
def get_available_styles() -> list[str]:
    """
    List the names of every registered style context.

    Returns
    -------
    list[str]
        Style names, in registration order.
    """
    return [name for name in _styles]

get_style_setting(key, default=None)

Retrieves a style setting from the currently active style. Looks in both standard rcParams and custom style settings.

Parameters

key : str The name of the style setting. default : Any, optional The default value if the key is not found, by default None.

Returns

Any The style setting value.

Source code in src/monet_plots/style.py
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
def get_style_setting(key: str, default: Any = None) -> Any:
    """
    Look up a setting for the active style.

    The active style's own dictionary, which may hold custom non-rcParams
    keys, takes precedence. Otherwise the value comes from matplotlib's
    rcParams.

    Parameters
    ----------
    key : str
        Name of the setting.
    default : Any, optional
        Returned when neither source defines ``key``. Defaults to None.

    Returns
    -------
    Any
        The resolved setting value.
    """
    active = _styles.get(_current_style_name, {})
    return active[key] if key in active else plt.rcParams.get(key, default)

set_style(context='wiley')

Set the plotting style based on a predefined context.

Parameters

context : str, optional The name of the style context to apply. Available contexts: "wiley", "presentation", "paper", "web", "pivotal_weather", "default". Defaults to "wiley".

Raises

ValueError If an unknown context name is provided.

Source code in src/monet_plots/style.py
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
def set_style(context: str = "wiley"):
    """
    Activate one of the predefined plotting style contexts.

    Parameters
    ----------
    context : str, optional
        Name of the context to activate: "wiley", "presentation", "paper",
        "web", "pivotal_weather" or "default". Defaults to "wiley".

    Raises
    ------
    ValueError
        If ``context`` is not a registered style name.

    Notes
    -----
    Only the entries of the context that are valid matplotlib rcParams are
    passed to ``plt.style.use``; custom keys stay in the style registry and
    are read back through ``get_style_setting``. The "default" context is
    applied with ``plt.style.use("default")`` instead.
    """
    global _current_style_name

    if context not in _styles:
        known = ", ".join(_styles.keys())
        raise ValueError(
            f"Unknown style context: '{context}'. "
            f"Available contexts are: {known}"
        )

    if context == "default":
        plt.style.use("default")
    else:
        settings = _styles[context]
        # Drop custom (non-rcParams) keys; matplotlib would reject them.
        rc_only = {
            name: value for name, value in settings.items() if name in plt.rcParams
        }
        plt.style.use(rc_only)

    _current_style_name = context

Taylor diagram (Taylor, 2001) implementation.

A Taylor diagram is a graphical representation of how well a model simulates an observed pattern. It summarizes multiple aspects of model performance in a single plot, including:

- the correlation coefficient,
- the centered root-mean-square (RMS) difference, and
- the ratio of the standard deviations.

Reference: Taylor, K.E., 2001. Summarizing multiple aspects of model performance in a single diagram. Journal of Geophysical Research, 106(D7), 7183-7192.

TaylorDiagram

:no-index:

Taylor diagram for visualizing model performance metrics.

The Taylor diagram displays multiple statistical metrics in a single plot:

- the radial distance from the origin represents the standard deviation;
- the azimuthal position represents the correlation coefficient;
- the distance from the reference point represents the root-mean-square (RMS) difference.

This class creates a Taylor diagram in a polar plot, where:

- $r = \sigma$ (the standard deviation);
- $\theta = \arccos(R)$, with $R$ the correlation coefficient.

This provides a comprehensive view of how well a model pattern matches observations.

Source code in src/monet_plots/taylordiagram.py
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
class TaylorDiagram:
    """
    :no-index:

    Taylor diagram for visualizing model performance metrics.

    The Taylor diagram displays multiple statistical metrics in a single plot:

    - The radial distance from the origin represents the standard deviation
    - The azimuthal position represents the correlation coefficient
    - The distance from the reference point represents the centered
      root-mean-square (RMS) difference

    This class creates a Taylor diagram in a polar plot, where:

    - r = standard deviation
    - θ = arccos(correlation coefficient)

    Only the first quadrant (correlations from 0 to 1) is drawn.

    This provides a comprehensive view of how well a model pattern matches observations.

    Attributes
    ----------
    refstd : float
        Reference standard deviation.
    smin, smax : float
        Radial (standard deviation) extent of the diagram.
    ax
        Auxiliary axes from ``get_aux_axes``; plot on it in (theta, radius)
        polar coordinates.
    samplePoints : list of matplotlib.lines.Line2D
        The reference point followed by every sample added via ``add_sample``.
    """

    @_sns_context
    def __init__(self, refstd, scale=1.5, fig=None, rect=111, label="_"):
        """Initialize the Taylor diagram.

        Parameters
        ----------
        refstd : float
            The reference standard deviation (e.g., from observations or a reference model)
            that other models will be compared against.
        scale : float, default 1.5
            The maximum standard deviation shown on the plot, as a multiple of refstd.
            For example, if refstd=2 and scale=1.5, the maximum standard deviation
            displayed will be 3.0.
        fig : matplotlib.figure.Figure, optional
            Figure to use. If None, a new figure will be created.
        rect : int or tuple, default 111
            Subplot specification: either an ``(nrows, ncols, index)`` tuple or a
            3-digit integer whose digits are nrows, ncols, and index in order.
        label : str, default "_"
            Label for the reference point. An underscore prefix makes the label not
            appear in the legend.
        """

        import mpl_toolkits.axisartist.floating_axes as FA
        import mpl_toolkits.axisartist.grid_finder as GF
        from matplotlib.projections import PolarAxes

        self.refstd = refstd  # Reference standard deviation

        tr = PolarAxes.PolarTransform(apply_theta_transforms=False)

        # Correlation tick labels, placed at the matching polar angles.
        rlocs = np.concatenate((np.arange(10) / 10.0, [0.95, 0.99]))
        tlocs = np.arccos(rlocs)  # Conversion to polar angles
        gl1 = GF.FixedLocator(tlocs)  # Positions
        tf1 = GF.DictFormatter(dict(zip(tlocs, map(str, rlocs))))

        # Standard deviation axis extent
        self.smin = 0
        self.smax = scale * self.refstd
        ghelper = FA.GridHelperCurveLinear(
            tr,
            extremes=(0, np.pi / 2, self.smin, self.smax),
            grid_locator1=gl1,
            tick_formatter1=tf1,
        )  # 1st quadrant

        if fig is None:
            fig = plt.figure()

        # matplotlib takes the (nrows, ncols, index) form as three separate
        # positional arguments; passing the tuple as one argument is rejected.
        if isinstance(rect, (tuple, list)) and len(rect) == 3:
            subplot_args = tuple(rect)
        else:
            subplot_args = (rect,)
        ax = FA.FloatingSubplot(fig, *subplot_args, grid_helper=ghelper)
        fig.add_subplot(ax)

        # Adjust axes
        ax.axis["top"].set_axis_direction("bottom")  # "Angle axis"
        ax.axis["top"].toggle(ticklabels=True, label=True)
        ax.axis["top"].major_ticklabels.set_axis_direction("top")
        ax.axis["top"].label.set_axis_direction("top")
        ax.axis["top"].label.set_text("Correlation")

        ax.axis["left"].set_axis_direction("bottom")  # "X axis"
        ax.axis["left"].label.set_text("Standard deviation")

        ax.axis["right"].set_axis_direction("top")  # "Y axis"
        ax.axis["right"].toggle(ticklabels=True)
        ax.axis["right"].major_ticklabels.set_axis_direction("left")

        ax.axis["bottom"].set_visible(False)  # Unused in a first-quadrant diagram

        # Contours along standard deviations
        ax.grid(False)

        self._ax = ax  # Graphical axes
        self.ax = ax.get_aux_axes(tr)  # Polar coordinates

        # Add reference point (at correlation 1, i.e. theta = 0) and the
        # dashed arc of constant standard deviation equal to refstd.
        (line,) = self.ax.plot(
            [0], self.refstd, "r*", ls="", ms=14, label=label, zorder=10
        )
        t = np.linspace(0, np.pi / 2)
        r = np.zeros_like(t) + self.refstd
        self.ax.plot(t, r, "k--", label="_")

        # Collect sample points for later use (e.g. legend)
        self.samplePoints = [line]

    @property
    def samples(self):
        """Property to provide compatibility with tests expecting 'samples' attribute."""
        return self.samplePoints

    @_sns_context
    def add_sample(self, stddev, corrcoef, *args, **kwargs):
        """Add a sample point to the Taylor diagram.

        Parameters
        ----------
        stddev : float
            Standard deviation of the sample to add.
        corrcoef : float
            Correlation coefficient between the sample and reference (-1 to 1).
            Only the first quadrant (0 to 1) is drawn, so a negative
            correlation lands outside the plotted area and may not be visible.
        *args
            Additional positional arguments passed to matplotlib's plot function.
        **kwargs
            Additional keyword arguments passed to matplotlib's plot function.
            Common options include 'marker', 'markersize', 'color', and 'label'.

        Returns
        -------
        matplotlib.lines.Line2D
            The line object representing the sample in the plot.

        Notes
        -----
        Points closer to the reference point indicate better agreement with
        the reference dataset.
        """
        (line,) = self.ax.plot(
            np.arccos(corrcoef), stddev, *args, **kwargs
        )  # (theta, radius)
        self.samplePoints.append(line)

        return line

    @_sns_context
    def add_contours(self, levels=5, **kwargs):
        """Add constant RMS difference contours to the Taylor diagram.

        Parameters
        ----------
        levels : int or array-like, default 5
            If an integer, it defines the number of equally-spaced contour levels.
            If array-like, it explicitly defines the contour levels.
        **kwargs
            Additional keyword arguments passed to matplotlib's contour function.
            Common options include 'colors', 'linewidths', and 'linestyles'.

        Returns
        -------
        matplotlib.contour.QuadContourSet
            The contour set created by the function.

        Notes
        -----
        These contours represent lines of constant RMS difference between the
        reference and sample datasets. They help visualize the combined effect
        of differences in standard deviation and correlation.
        """

        rs, ts = np.meshgrid(
            np.linspace(self.smin, self.smax), np.linspace(0, np.pi / 2)
        )
        # Centered RMS difference via the law of cosines:
        # E'^2 = s_ref^2 + s^2 - 2 * s_ref * s * R, with R = cos(theta).
        rms = np.sqrt(self.refstd**2 + rs**2 - 2 * self.refstd * rs * np.cos(ts))

        contours = self.ax.contour(ts, rs, rms, levels, **kwargs)

        return contours

samples property

Property to provide compatibility with tests expecting 'samples' attribute.

__init__(refstd, scale=1.5, fig=None, rect=111, label='_')

Initialize the Taylor diagram.

Parameters

- `refstd` (float): The reference standard deviation (e.g., from observations or a reference model) that other models are compared against.
- `scale` (float, default 1.5): The maximum standard deviation shown on the plot, as a multiple of `refstd`. For example, if `refstd=2` and `scale=1.5`, the maximum standard deviation displayed is 3.0.
- `fig` (matplotlib.figure.Figure, optional): Figure to use. If None, a new figure is created.
- `rect` (int or tuple, default 111): Subplot specification, either a tuple `(nrows, ncols, index)` or a 3-digit integer whose digits are nrows, ncols, and index in order.
- `label` (str, default "_"): Label for the reference point. An underscore prefix keeps the label out of the legend.

Source code in src/monet_plots/taylordiagram.py
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
@_sns_context
def __init__(self, refstd, scale=1.5, fig=None, rect=111, label="_"):
    """Initialize the Taylor diagram.

    Parameters
    ----------
    refstd : float
        The reference standard deviation (e.g., from observations or a reference model)
        that other models will be compared against.
    scale : float, default 1.5
        The maximum standard deviation shown on the plot, as a multiple of refstd.
        For example, if refstd=2 and scale=1.5, the maximum standard deviation
        displayed will be 3.0.
    fig : matplotlib.figure.Figure, optional
        Figure to use. If None, a new figure will be created.
    rect : int or tuple, default 111
        Subplot specification: either an ``(nrows, ncols, index)`` tuple or a
        3-digit integer whose digits are nrows, ncols, and index in order.
    label : str, default "_"
        Label for the reference point. An underscore prefix makes the label not
        appear in the legend.
    """

    import mpl_toolkits.axisartist.floating_axes as FA
    import mpl_toolkits.axisartist.grid_finder as GF
    from matplotlib.projections import PolarAxes

    self.refstd = refstd  # Reference standard deviation

    tr = PolarAxes.PolarTransform(apply_theta_transforms=False)

    # Correlation tick labels, placed at the matching polar angles.
    rlocs = np.concatenate((np.arange(10) / 10.0, [0.95, 0.99]))
    tlocs = np.arccos(rlocs)  # Conversion to polar angles
    gl1 = GF.FixedLocator(tlocs)  # Positions
    tf1 = GF.DictFormatter(dict(zip(tlocs, map(str, rlocs))))

    # Standard deviation axis extent
    self.smin = 0
    self.smax = scale * self.refstd
    ghelper = FA.GridHelperCurveLinear(
        tr,
        extremes=(0, np.pi / 2, self.smin, self.smax),
        grid_locator1=gl1,
        tick_formatter1=tf1,
    )  # 1st quadrant

    if fig is None:
        fig = plt.figure()

    # matplotlib takes the (nrows, ncols, index) form as three separate
    # positional arguments; passing the tuple as one argument is rejected.
    if isinstance(rect, (tuple, list)) and len(rect) == 3:
        subplot_args = tuple(rect)
    else:
        subplot_args = (rect,)
    ax = FA.FloatingSubplot(fig, *subplot_args, grid_helper=ghelper)
    fig.add_subplot(ax)

    # Adjust axes
    ax.axis["top"].set_axis_direction("bottom")  # "Angle axis"
    ax.axis["top"].toggle(ticklabels=True, label=True)
    ax.axis["top"].major_ticklabels.set_axis_direction("top")
    ax.axis["top"].label.set_axis_direction("top")
    ax.axis["top"].label.set_text("Correlation")

    ax.axis["left"].set_axis_direction("bottom")  # "X axis"
    ax.axis["left"].label.set_text("Standard deviation")

    ax.axis["right"].set_axis_direction("top")  # "Y axis"
    ax.axis["right"].toggle(ticklabels=True)
    ax.axis["right"].major_ticklabels.set_axis_direction("left")

    ax.axis["bottom"].set_visible(False)  # Unused in a first-quadrant diagram

    # Contours along standard deviations
    ax.grid(False)

    self._ax = ax  # Graphical axes
    self.ax = ax.get_aux_axes(tr)  # Polar coordinates

    # Add reference point (at correlation 1, i.e. theta = 0) and the
    # dashed arc of constant standard deviation equal to refstd.
    (line,) = self.ax.plot(
        [0], self.refstd, "r*", ls="", ms=14, label=label, zorder=10
    )
    t = np.linspace(0, np.pi / 2)
    r = np.zeros_like(t) + self.refstd
    self.ax.plot(t, r, "k--", label="_")

    # Collect sample points for later use (e.g. legend)
    self.samplePoints = [line]

add_contours(levels=5, **kwargs)

Add constant RMS difference contours to the Taylor diagram.

Parameters

- `levels` (int or array-like, default 5): If an integer, the number of equally spaced contour levels. If array-like, the explicit contour levels.
- `**kwargs`: Additional keyword arguments passed to matplotlib's contour function. Common options include 'colors', 'linewidths', and 'linestyles'.

Returns

matplotlib.contour.QuadContourSet The contour set created by the function.

Notes

These contours represent lines of constant RMS difference between the reference and sample datasets. They help visualize the combined effect of differences in standard deviation and correlation.

Source code in src/monet_plots/taylordiagram.py
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
@_sns_context
def add_contours(self, levels=5, **kwargs):
    """Draw contours of constant centered RMS difference.

    Parameters
    ----------
    levels : int or array-like, default 5
        Either the number of equally spaced contour levels (int) or the
        explicit contour values (array-like).
    **kwargs
        Forwarded to matplotlib's ``contour`` (e.g. 'colors', 'linewidths',
        'linestyles').

    Returns
    -------
    matplotlib.contour.QuadContourSet
        The contour set that was drawn.

    Notes
    -----
    Each contour joins points whose centered RMS difference from the
    reference is equal, combining the effects of the standard-deviation
    mismatch and the correlation.
    """
    radii = np.linspace(self.smin, self.smax)
    angles = np.linspace(0, np.pi / 2)
    rs, ts = np.meshgrid(radii, angles)

    # Law of cosines: E'^2 = s_ref^2 + s^2 - 2 * s_ref * s * cos(theta).
    rms = np.sqrt(self.refstd**2 + rs**2 - 2 * self.refstd * rs * np.cos(ts))

    return self.ax.contour(ts, rs, rms, levels, **kwargs)

add_sample(stddev, corrcoef, *args, **kwargs)

Add a sample point to the Taylor diagram.

Parameters

- `stddev` (float): Standard deviation of the sample to add.
- `corrcoef` (float): Correlation coefficient between the sample and reference (-1 to 1).
- `*args`: Additional positional arguments passed to matplotlib's plot function.
- `**kwargs`: Additional keyword arguments passed to matplotlib's plot function. Common options include 'marker', 'markersize', 'color', and 'label'.

Returns

matplotlib.lines.Line2D The line object representing the sample in the plot.

Notes

Points closer to the reference point indicate better agreement with the reference dataset.

Source code in src/monet_plots/taylordiagram.py
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
@_sns_context
def add_sample(self, stddev, corrcoef, *args, **kwargs):
    """Plot one model sample on the Taylor diagram.

    Parameters
    ----------
    stddev : float
        Standard deviation of the sample (the radial coordinate).
    corrcoef : float
        Correlation of the sample with the reference (-1 to 1). It is
        mapped to the polar angle ``arccos(corrcoef)``. The diagram built in
        ``__init__`` spans only correlations 0 to 1, so negative values land
        outside the plotted quadrant and may not be visible.
    *args
        Extra positional arguments forwarded to matplotlib's ``plot``.
    **kwargs
        Extra keyword arguments forwarded to matplotlib's ``plot`` (e.g.
        'marker', 'markersize', 'color', 'label').

    Returns
    -------
    matplotlib.lines.Line2D
        The artist representing the sample.

    Notes
    -----
    The closer a sample sits to the reference point, the better it agrees
    with the reference dataset.
    """
    theta = np.arccos(corrcoef)
    [line] = self.ax.plot(theta, stddev, *args, **kwargs)
    self.samplePoints.append(line)
    return line

split_by_threshold(data_list, alt_list, threshold_list)

Splits data into bins based on altitude thresholds.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `data_list` | list | List of data values. | *required* |
| `alt_list` | list | List of altitude values corresponding to the data. | *required* |
| `threshold_list` | list | List of altitude thresholds to bin the data. | *required* |

Returns:

| Type | Description |
|------|-------------|
| list | A list of arrays, where each array contains the data values within an altitude bin. |

Source code in src/monet_plots/tools.py
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
def split_by_threshold(data_list, alt_list, threshold_list):
    """
    Group data values into altitude layers bounded by consecutive thresholds.

    Layer ``i`` holds the values whose altitude ``a`` satisfies
    ``threshold_list[i] < a <= threshold_list[i + 1]`` (open below, closed
    above). Values at or below the first threshold, or above the last one,
    are not returned.

    Args:
        data_list (list): Data values.
        alt_list (list): Altitude of each data value (same length).
        threshold_list (list): Layer boundaries; ``n`` thresholds give
            ``n - 1`` layers.

    Returns:
        list: One array of data values per layer, in threshold order.
    """
    frame = pd.DataFrame(data={"data": data_list, "alt": alt_list})
    layers = [
        (threshold_list[idx - 1], threshold_list[idx])
        for idx in range(1, len(threshold_list))
    ]
    return [
        frame.data.loc[(frame.alt > lower) & (frame.alt <= upper)].values
        for lower, upper in layers
    ]

uv2wsdir(u, v)

Converts u and v components to wind speed and direction.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `u` | ndarray | The u component of the wind. | *required* |
| `v` | ndarray | The v component of the wind. | *required* |

Returns:

| Type | Description |
|------|-------------|
| tuple | A tuple containing the wind speed and direction. |

Source code in src/monet_plots/tools.py
44
45
46
47
48
49
50
51
52
53
54
55
56
def uv2wsdir(u, v):
    """Converts u and v components to wind speed and direction.

    Assuming ``u`` is the eastward and ``v`` the northward component (the
    usual meteorological convention; confirm for your data), the returned
    direction is the one the wind blows *from*, in degrees clockwise from
    north. Up to a multiple of 360 this inverts ``wsdir2uv``.

    Args:
        u (array_like): The u component of the wind. Scalars, sequences and
            numpy arrays are accepted.
        v (array_like): The v component of the wind, broadcastable with ``u``.

    Returns:
        tuple: ``(ws, wdir)``, the wind speed and the wind direction in
        degrees. ``wdir`` lies in [0, 360]: a northerly wind normally maps to
        360, and calm air (``u == v == 0``) maps to 180.
    """
    # np.hypot avoids the intermediate overflow/underflow of u**2 + v**2 and,
    # unlike ``u**2``, also works for plain Python sequences.
    ws = np.hypot(u, v)
    # arctan2(u, v) is the bearing the wind blows *toward*; +180 flips it to
    # the bearing it comes *from*.
    wdir = 180.0 + np.degrees(np.arctan2(u, v))
    return ws, wdir

wsdir2uv(ws, wdir)

Converts wind speed and direction to u and v components.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `ws` | ndarray | The wind speed. | *required* |
| `wdir` | ndarray | The wind direction, in degrees. | *required* |

Returns:

| Type | Description |
|------|-------------|
| tuple | A tuple containing the u and v components of the wind. |

Source code in src/monet_plots/tools.py
28
29
30
31
32
33
34
35
36
37
38
39
40
41
def wsdir2uv(ws, wdir):
    """Split wind speed and direction into u and v components.

    The sign convention is ``u = -ws * sin(wdir)`` and
    ``v = -ws * cos(wdir)``, with ``wdir`` given in degrees.

    Args:
        ws (numpy.ndarray): Wind speed.
        wdir (numpy.ndarray): Wind direction in degrees.

    Returns:
        tuple: ``(u, v)``, the wind components.
    """
    theta = wdir * (np.pi / 180.0)  # degrees -> radians
    return -ws * np.sin(theta), -ws * np.cos(theta)