Skip to content

Spatial Plots

SpatialPlot

Bases: BasePlot

A base class for creating spatial plots with cartopy.

This class provides a high-level interface for geospatial plots, handling the setup of cartopy axes and the addition of common map features like coastlines, states, and gridlines.

Attributes

- `resolution` : str — The resolution of the cartopy features (e.g., '50m').

Source code in src/monet_plots/plots/spatial.py
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
class SpatialPlot(BasePlot):
    """A base class for creating spatial plots with cartopy.

    This class provides a high-level interface for geospatial plots, handling
    the setup of cartopy axes and the addition of common map features like
    coastlines, states, and gridlines.

    Attributes
    ----------
    resolution : str
        The resolution of the cartopy features (e.g., '50m').
    """

    # Unit strings accepted by the CF conventions for longitude / latitude.
    _LON_UNITS = frozenset(
        {
            "degrees_east",
            "degree_east",
            "degree_E",
            "degrees_E",
            "degreeE",
            "degreesE",
        }
    )
    _LAT_UNITS = frozenset(
        {
            "degrees_north",
            "degree_north",
            "degree_N",
            "degrees_N",
            "degreeN",
            "degreesN",
        }
    )

    def __init__(
        self,
        *,
        projection: ccrs.Projection = ccrs.PlateCarree(),
        fig: Figure | None = None,
        ax: Axes | None = None,
        figsize: tuple[float, float] | None = None,
        subplot_kw: dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the spatial plot and draw map features.

        This constructor sets up the matplotlib Figure and cartopy GeoAxes,
        and provides a single interface to draw common map features like
        coastlines and states.

        Parameters
        ----------
        projection : ccrs.Projection, optional
            The cartopy projection for the map, by default ccrs.PlateCarree().
        fig : plt.Figure | None, optional
            An existing matplotlib Figure object. If None, a new one is
            created, by default None.
        ax : plt.Axes | None, optional
            An existing matplotlib Axes object. If None, a new one is created,
            by default None.
        figsize : tuple[float, float] | None, optional
            Width, height in inches. If not provided, the matplotlib default
            will be used.
        subplot_kw : dict[str, Any] | None, optional
            Keyword arguments passed to `fig.add_subplot`, by default None.
            The 'projection' is always set from `projection`, overriding any
            value supplied here. The caller's dict is not modified.
        **kwargs : Any
            Keyword arguments for map features, passed to `add_features`.
            Common options include `coastlines`, `states`, `countries`,
            `ocean`, `land`, `lakes`, `rivers`, `borders`, `gridlines`,
            `extent`, and `resolution`. `style` (default "wiley") is
            forwarded to the base class. Coastlines are drawn unless
            `coastlines=False` is given.

        Attributes
        ----------
        fig : plt.Figure
            The matplotlib Figure object.
        ax : plt.Axes
            The matplotlib Axes (or GeoAxes) object.
        resolution : str
            The default resolution for cartopy features.
        """
        # Private copy: the caller's dict is never mutated, and the requested
        # projection always wins over any 'projection' key inside it.
        axes_kw = dict(subplot_kw) if subplot_kw else {}
        axes_kw["projection"] = projection

        self.resolution = kwargs.pop("resolution", "50m")
        style = kwargs.pop("style", "wiley")

        # Coastlines are on unless the caller explicitly opts out.
        kwargs.setdefault("coastlines", True)

        # BasePlot is responsible for creating the figure (and usually axes).
        super().__init__(
            fig=fig, ax=ax, figsize=figsize, style=style, subplot_kw=axes_kw
        )

        # BasePlot may leave ``ax`` unset (e.g. when only ``fig`` was given).
        if self.ax is None:
            self.ax = self.fig.add_subplot(1, 1, 1, **axes_kw)

        self.add_features(**kwargs)

    def _get_feature_registry(self, resolution: str) -> dict[str, dict[str, Any]]:
        """Return a registry of cartopy features and their default styles.

        This approach centralizes feature management, making it easier to
        add new features and maintain existing ones.

        Parameters
        ----------
        resolution : str
            The resolution for the cartopy features (e.g., '10m', '50m').

        Returns
        -------
        dict[str, dict[str, Any]]
            A dictionary mapping feature names to a specification dictionary
            containing the feature object and its default styling.
        """
        from cartopy.feature import (
            BORDERS,
            COASTLINE,
            LAKES,
            LAND,
            OCEAN,
            RIVERS,
            STATES,
        )

        # Define default styles, falling back to sane defaults if not in rcParams.
        coastline_defaults = {
            "linewidth": get_style_setting("coastline.width", 0.5),
            "edgecolor": get_style_setting("coastline.color", "black"),
            "facecolor": "none",
        }
        states_defaults = {
            "linewidth": get_style_setting("states.width", 0.5),
            "edgecolor": get_style_setting("states.color", "black"),
            "facecolor": "none",
        }
        borders_defaults = {
            "linewidth": get_style_setting("borders.width", 0.5),
            "edgecolor": get_style_setting("borders.color", "black"),
            "facecolor": "none",
        }

        # Note: "countries" and "borders" are aliases for the same feature.
        feature_mapping = {
            "coastlines": {
                "feature": COASTLINE.with_scale(resolution),
                "defaults": coastline_defaults,
            },
            "countries": {
                "feature": BORDERS.with_scale(resolution),
                "defaults": borders_defaults,
            },
            "states": {
                "feature": STATES.with_scale(resolution),
                "defaults": states_defaults,
            },
            "borders": {
                "feature": BORDERS.with_scale(resolution),
                "defaults": borders_defaults,
            },
            "ocean": {"feature": OCEAN.with_scale(resolution), "defaults": {}},
            "land": {"feature": LAND.with_scale(resolution), "defaults": {}},
            "rivers": {"feature": RIVERS.with_scale(resolution), "defaults": {}},
            "lakes": {"feature": LAKES.with_scale(resolution), "defaults": {}},
            "counties": {
                "feature": cfeature.NaturalEarthFeature(
                    category="cultural",
                    name="admin_2_counties",
                    scale=resolution,
                    facecolor="none",
                ),
                "defaults": borders_defaults,
            },
        }
        return feature_mapping

    @staticmethod
    def _get_style(
        style: bool | dict[str, Any], defaults: dict[str, Any] | None = None
    ) -> dict[str, Any]:
        """Get a style dictionary for a feature.

        Parameters
        ----------
        style : bool or dict[str, Any]
            The user-provided style. If True, use defaults. If a dict, use it
            as-is (it replaces, rather than merges with, the defaults).
        defaults : dict[str, Any], optional
            The default style to apply if `style` is True.

        Returns
        -------
        dict[str, Any]
            The resolved keyword arguments for styling.
        """
        if isinstance(style, dict):
            return style
        if style and defaults:
            # Return a copy to prevent modifying the defaults dictionary in place
            return defaults.copy()
        return {}

    def _draw_single_feature(
        self, style_arg: bool | dict[str, Any], feature_spec: dict[str, Any]
    ) -> None:
        """Draw a single cartopy feature on the axes.

        Parameters
        ----------
        style_arg : bool or dict[str, Any]
            The user-provided style for the feature. Falsy values skip it.
        feature_spec : dict[str, Any]
            A dictionary containing the feature object and default styles.
        """
        if not style_arg:  # Allows for `coastlines=False`
            return

        style_kwargs = self._get_style(style_arg, feature_spec["defaults"])
        feature = feature_spec["feature"]
        self.ax.add_feature(feature, **style_kwargs)

    def add_features(self, **kwargs: Any) -> dict[str, Any]:
        """Add and style cartopy features on the map axes.

        This method provides a flexible, data-driven interface to add common
        map features. Features can be enabled with a boolean flag (e.g.,
        `coastlines=True`) or styled with a dictionary of keyword arguments
        (e.g., `states=dict(linewidth=2, edgecolor='red')`).

        The `extent` keyword is also supported to set the map boundaries.

        Parameters
        ----------
        **kwargs : Any
            Keyword arguments controlling the features to add and their
            styles. Common options include `coastlines`, `states`,
            `countries`, `ocean`, `land`, `lakes`, `rivers`, `borders`,
            and `gridlines`.

        Returns
        -------
        dict[str, Any]
            A dictionary of the keyword arguments that were not used for
            adding features. This can be useful for passing remaining
            arguments to other functions.
        """
        options = kwargs

        # Order matters: the extent must be fixed before gridlines are drawn
        # so that gridline labels are placed against the final map bounds.
        if "extent" in options:
            self._set_extent(options.pop("extent"))

        if "gridlines" in options:
            self._draw_gridlines(options.pop("gridlines"))

        # Everything left is handed to the vector-feature loop, which returns
        # whatever it did not consume.
        return self._draw_features(**options)

    def _set_extent(self, extent: tuple[float, float, float, float] | None) -> None:
        """Set the geographic extent of the map.

        Parameters
        ----------
        extent : tuple[float, float, float, float] | None
            The extent of the map as a tuple of (x_min, x_max, y_min, y_max).
            If None, the extent is not changed.
        """
        if extent is not None:
            self.ax.set_extent(extent)

    def _draw_features(self, **kwargs: Any) -> dict[str, Any]:
        """Draw vector features on the map.

        This is the primary feature-drawing loop, responsible for adding
        elements like coastlines, states, and borders.

        Parameters
        ----------
        **kwargs : Any
            Keyword arguments controlling the features to add and their
            styles. `resolution` overrides `self.resolution` for this call;
            `natural_earth=True` enables ocean, land, lakes and rivers unless
            they are given explicitly.

        Returns
        -------
        dict[str, Any]
            A dictionary of the keyword arguments that were not used for
            adding features.
        """
        resolution = kwargs.pop("resolution", self.resolution)
        feature_registry = self._get_feature_registry(resolution)

        # If natural_earth is True, enable a standard set of features
        if kwargs.pop("natural_earth", False):
            for feature in ["ocean", "land", "lakes", "rivers"]:
                kwargs.setdefault(feature, True)

        # Main feature-drawing loop (registry order defines drawing order).
        for key, feature_spec in feature_registry.items():
            if key in kwargs:
                style_arg = kwargs.pop(key)
                self._draw_single_feature(style_arg, feature_spec)

        return kwargs

    def _draw_gridlines(self, style: bool | dict[str, Any]) -> None:
        """Draw gridlines on the map.

        Parameters
        ----------
        style : bool or dict[str, Any]
            The style for the gridlines. If True, use defaults. If a dict,
            use it as keyword arguments. If False, do nothing.
        """
        if not style:
            return

        gridline_defaults = {
            "draw_labels": True,
            "linestyle": "--",
            "color": "gray",
        }
        gridline_kwargs = self._get_style(style, gridline_defaults)
        self.ax.gridlines(**gridline_kwargs)

    @staticmethod
    def _match_xarray_coord(
        coords: dict[str, Any],
        names: tuple[str, ...],
        units: frozenset[str],
        standard_name: str,
    ) -> str | None:
        """Return the best-matching coordinate name for one spatial axis.

        Candidates are ranked so that a more specific match always wins:

        1. exact (case-insensitive) name match, e.g. ``lon``/``longitude``;
        2. CF metadata: a recognised ``units`` string or ``standard_name``;
        3. loose substring match on the name (e.g. WRF-style ``XLONG``).

        Parameters
        ----------
        coords : dict[str, Any]
            Mapping of coordinate name to coordinate object (with ``.attrs``).
        names : tuple[str, ...]
            Lower-case canonical names for the axis.
        units : frozenset[str]
            CF unit strings identifying the axis.
        standard_name : str
            CF ``standard_name`` identifying the axis.

        Returns
        -------
        str or None
            The chosen coordinate name, or None if nothing matches.
        """
        for name in coords:
            if name.lower() in names:
                return name

        for name, coord in coords.items():
            attrs = getattr(coord, "attrs", None) or {}
            coord_units = attrs.get("units")
            coord_std = attrs.get("standard_name")
            # isinstance guards avoid ambiguous comparisons on array-valued attrs.
            if (isinstance(coord_units, str) and coord_units in units) or (
                isinstance(coord_std, str) and coord_std == standard_name
            ):
                return name

        for name in coords:
            lowered = name.lower()
            if any(pattern in lowered for pattern in names):
                return name

        return None

    def _identify_coords(self, data: Any) -> tuple[str, str]:
        """Identify longitude and latitude coordinate names in data.

        For xarray objects, an exact name match (``lon``/``longitude``,
        ``lat``/``latitude``, case-insensitive) takes precedence over CF
        metadata (``units`` / ``standard_name``), which takes precedence over
        a loose substring match. The ranking prevents coordinates such as
        ``along_track`` or ``lat_bnds`` from shadowing the real ones. For
        pandas-like objects, only exact (case-insensitive) column names are
        recognised.

        Parameters
        ----------
        data : Any
            The input data to search for coordinates.

        Returns
        -------
        tuple[str, str]
            The identified (longitude, latitude) coordinate names.

        Raises
        ------
        ValueError
            If coordinates cannot be identified.
        """
        lon_name, lat_name = None, None

        # Handle xarray
        if hasattr(data, "coords"):
            coords = {str(name): coord for name, coord in data.coords.items()}
            lon_name = self._match_xarray_coord(
                coords, ("lon", "longitude"), self._LON_UNITS, "longitude"
            )
            lat_name = self._match_xarray_coord(
                coords, ("lat", "latitude"), self._LAT_UNITS, "latitude"
            )

        # Handle pandas or other objects with columns
        elif hasattr(data, "columns"):
            cols = [str(c).lower() for c in data.columns]
            if "lon" in cols:
                lon_name = data.columns[cols.index("lon")]
            elif "longitude" in cols:
                lon_name = data.columns[cols.index("longitude")]

            if "lat" in cols:
                lat_name = data.columns[cols.index("lat")]
            elif "latitude" in cols:
                lat_name = data.columns[cols.index("latitude")]

        if lon_name and lat_name:
            return lon_name, lat_name

        raise ValueError(
            "Could not identify longitude and latitude coordinates. "
            "Please ensure they are named 'lon'/'lat' or 'longitude'/'latitude', "
            "or have CF-compliant units."
        )

    def _ensure_monotonic(self, data: Any, lon_name: str, lat_name: str) -> Any:
        """Ensure spatial dimensions are monotonically increasing.

        Some plotting backends (like GeoViews/hvPlot) require monotonically
        increasing coordinates for correct rendering and interpolation. A
        coordinate is sorted when its first value exceeds its last. Only 1-D
        coordinates are considered; 2-D (curvilinear) coordinates, missing
        names, and empty arrays are left untouched. This is a best-effort
        operation and never raises for unsortable input.

        Parameters
        ----------
        data : Any
            The input data to sort (xarray object or pandas DataFrame).
        lon_name : str
            Name of the longitude coordinate.
        lat_name : str
            Name of the latitude coordinate.

        Returns
        -------
        Any
            The data with sorted spatial dimensions.
        """
        is_xarray = hasattr(data, "sortby")

        # Latitude first, then longitude (same order as before).
        for name in (lat_name, lon_name):
            try:
                values = data[name].values
                # Comparing the ends of a 2-D array raises "truth value is
                # ambiguous"; such grids have no single sort order anyway.
                if getattr(values, "ndim", 1) != 1 or len(values) < 2:
                    continue
                if values[0] > values[-1]:
                    data = data.sortby(name) if is_xarray else data.sort_values(name)
            except (AttributeError, KeyError, IndexError, ValueError):
                continue

        return data

    def _get_extent_from_data(
        self,
        data: xr.DataArray | xr.Dataset,
        lon_coord: str | None = None,
        lat_coord: str | None = None,
        buffer: float = 0.0,
    ) -> list[float]:
        """Calculate geographic extent from xarray data using Dask if available.

        Parameters
        ----------
        data : xr.DataArray or xr.Dataset
            The input data to calculate the extent from.
        lon_coord : str, optional
            Name of the longitude coordinate. If None, it is identified
            automatically using `_identify_coords`.
        lat_coord : str, optional
            Name of the latitude coordinate. If None, it is identified
            automatically using `_identify_coords`.
        buffer : float, optional
            Buffer to add to the extent as a fraction of the range,
            by default 0.0. A degenerate (zero) range gets a 1-degree buffer.
            The buffered latitude is clamped to [-90, 90] so the extent stays
            valid near the poles (data already outside that range is kept).

        Returns
        -------
        list[float]
            The calculated extent as [lon_min, lon_max, lat_min, lat_max].
        """
        if lon_coord is None or lat_coord is None:
            lon_id, lat_id = self._identify_coords(data)
            lon_coord = lon_coord or lon_id
            lat_coord = lat_coord or lat_id

        lon = data[lon_coord]
        lat = data[lat_coord]

        # Use dask.compute for efficient parallel calculation of min/max
        # if the data is chunked.
        try:
            import dask

            lon_min, lon_max, lat_min, lat_max = dask.compute(
                lon.min(), lon.max(), lat.min(), lat.max()
            )
        except (ImportError, AttributeError):
            lon_min, lon_max = lon.min(), lon.max()
            lat_min, lat_max = lat.min(), lat.max()

        # Ensure they are scalar values (handles both numpy and dask returns)
        lon_min, lon_max = float(lon_min), float(lon_max)
        lat_min, lat_max = float(lat_min), float(lat_max)

        if buffer > 0:
            lon_range = lon_max - lon_min
            lat_range = lat_max - lat_min
            lon_buf = lon_range * buffer if lon_range > 0 else 1.0
            lat_buf = lat_range * buffer if lat_range > 0 else 1.0
            lon_min -= lon_buf
            lon_max += lon_buf
            # Never let the buffer push latitude past the poles; cartopy's
            # set_extent cannot project latitudes beyond +/-90.
            lat_min = max(lat_min - lat_buf, min(lat_min, -90.0))
            lat_max = min(lat_max + lat_buf, max(lat_max, 90.0))

        return [lon_min, lon_max, lat_min, lat_max]

__init__(*, projection=ccrs.PlateCarree(), fig=None, ax=None, figsize=None, subplot_kw=None, **kwargs)

Initialize the spatial plot and draw map features.

This constructor sets up the matplotlib Figure and cartopy GeoAxes, and provides a single interface to draw common map features like coastlines and states.

Parameters

- `projection` : ccrs.Projection, optional — The cartopy projection for the map, by default `ccrs.PlateCarree()`.
- `fig` : plt.Figure | None, optional — An existing matplotlib Figure object. If None, a new one is created, by default None.
- `ax` : plt.Axes | None, optional — An existing matplotlib Axes object. If None, a new one is created, by default None.
- `figsize` : tuple[float, float] | None, optional — Width and height in inches. If not provided, the matplotlib default is used.
- `subplot_kw` : dict[str, Any] | None, optional — Keyword arguments passed to `fig.add_subplot`, by default None. The `projection` is set in these keywords automatically.
- `**kwargs` : Any — Keyword arguments for map features, passed to `add_features`. Common options include `coastlines`, `states`, `countries`, `ocean`, `land`, `lakes`, `rivers`, `borders`, `gridlines`, `extent`, and `resolution`.

Attributes

- `fig` : plt.Figure — The matplotlib Figure object.
- `ax` : plt.Axes — The matplotlib Axes (or GeoAxes) object.
- `resolution` : str — The default resolution for cartopy features.

Source code in src/monet_plots/plots/spatial.py
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
def __init__(
    self,
    *,
    projection: ccrs.Projection = ccrs.PlateCarree(),
    fig: Figure | None = None,
    ax: Axes | None = None,
    figsize: tuple[float, float] | None = None,
    subplot_kw: dict[str, Any] | None = None,
    **kwargs: Any,
) -> None:
    """Create the figure and cartopy GeoAxes, then draw the requested map features.

    Parameters
    ----------
    projection : ccrs.Projection, optional
        Cartopy projection used for the map axes, by default ccrs.PlateCarree().
    fig : plt.Figure | None, optional
        Existing matplotlib Figure to draw into; a new one is created when
        None (the default).
    ax : plt.Axes | None, optional
        Existing matplotlib Axes to draw into; a new one is created when
        None (the default).
    figsize : tuple[float, float] | None, optional
        Figure width and height in inches; the matplotlib default is used
        when omitted.
    subplot_kw : dict[str, Any] | None, optional
        Extra keywords for `fig.add_subplot`. A private copy is made and its
        'projection' entry is always set from `projection`.
    **kwargs : Any
        Map-feature options forwarded to `add_features` (e.g. `coastlines`,
        `states`, `countries`, `ocean`, `land`, `lakes`, `rivers`,
        `borders`, `gridlines`, `extent`). `resolution` (default "50m") and
        `style` (default "wiley") are consumed here first. Coastlines are
        drawn unless `coastlines=False` is passed.

    Attributes
    ----------
    fig : plt.Figure
        The matplotlib Figure object.
    ax : plt.Axes
        The matplotlib Axes (or GeoAxes) object.
    resolution : str
        The default resolution for cartopy features.
    """
    # Private copy: the caller's dict is never mutated, and the requested
    # projection always wins over any 'projection' key inside it.
    axes_kw = dict(subplot_kw) if subplot_kw else {}
    axes_kw["projection"] = projection

    self.resolution = kwargs.pop("resolution", "50m")
    style = kwargs.pop("style", "wiley")

    # Coastlines are on unless the caller explicitly opts out.
    kwargs.setdefault("coastlines", True)

    # BasePlot is responsible for creating the figure (and usually axes).
    super().__init__(
        fig=fig, ax=ax, figsize=figsize, style=style, subplot_kw=axes_kw
    )

    # BasePlot may leave ``ax`` unset (e.g. when only ``fig`` was given).
    if self.ax is None:
        self.ax = self.fig.add_subplot(1, 1, 1, **axes_kw)

    self.add_features(**kwargs)

add_features(**kwargs)

Add and style cartopy features on the map axes.

This method provides a flexible, data-driven interface to add common map features. Features can be enabled with a boolean flag (e.g., coastlines=True) or styled with a dictionary of keyword arguments (e.g., states=dict(linewidth=2, edgecolor='red')).

The extent keyword is also supported to set the map boundaries.

Parameters

- `**kwargs` : Any — Keyword arguments controlling the features to add and their styles. Common options include `coastlines`, `states`, `countries`, `ocean`, `land`, `lakes`, `rivers`, `borders`, and `gridlines`.

Returns

- dict[str, Any] — A dictionary of the keyword arguments that were not used for adding features. This can be useful for passing the remaining arguments to other functions.

Source code in src/monet_plots/plots/spatial.py
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
def add_features(self, **kwargs: Any) -> dict[str, Any]:
    """Draw and style cartopy map features, returning the unused keywords.

    Features are switched on with a boolean flag (e.g. `coastlines=True`)
    or styled with a dict of keyword arguments
    (e.g. `states=dict(linewidth=2, edgecolor='red')`). The `extent`
    keyword sets the map boundaries and `gridlines` draws a graticule.

    Parameters
    ----------
    **kwargs : Any
        Feature switches / styles. Common options include `coastlines`,
        `states`, `countries`, `ocean`, `land`, `lakes`, `rivers`,
        `borders`, `gridlines`, and `extent`.

    Returns
    -------
    dict[str, Any]
        Whatever keyword arguments were not consumed as feature options,
        ready to be forwarded to another call (e.g. a plotting function).
    """
    options = kwargs

    # Order matters: the extent must be fixed before gridlines are drawn
    # so that gridline labels are placed against the final map bounds.
    if "extent" in options:
        self._set_extent(options.pop("extent"))

    if "gridlines" in options:
        self._draw_gridlines(options.pop("gridlines"))

    # Everything left is handed to the vector-feature loop, which returns
    # whatever it did not consume.
    return self._draw_features(**options)

SpatialTrack

Bases: SpatialPlot

Plot a trajectory from an xarray.DataArray on a map.

This class provides an xarray-native interface for visualizing paths, such as flight trajectories or pollutant tracks, where a variable (e.g., altitude, concentration) is plotted along the path.

It inherits from `SpatialPlot` to provide the underlying map canvas.

Attributes

- `data` : xr.DataArray — The trajectory data being plotted.
- `lon_coord` : str — The name of the longitude coordinate in the DataArray.
- `lat_coord` : str — The name of the latitude coordinate in the DataArray.

Source code in src/monet_plots/plots/spatial.py
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
class SpatialTrack(SpatialPlot):
    """Plot a trajectory from an xarray.DataArray on a map.

    This class provides an xarray-native interface for visualizing paths,
    such as flight trajectories or pollutant tracks, where a variable
    (e.g., altitude, concentration) is plotted along the path.

    It inherits from :class:`SpatialPlot` to provide the underlying map canvas.

    Attributes
    ----------
    data : xr.DataArray
        The trajectory data being plotted.
    lon_coord : str
        The name of the longitude coordinate in the DataArray.
    lat_coord : str
        The name of the latitude coordinate in the DataArray.
    """

    def __init__(
        self,
        data: xr.DataArray,
        *,
        lon_coord: str = "lon",
        lat_coord: str = "lat",
        **kwargs: Any,
    ):
        """Initialize the SpatialTrack plot.

        This constructor validates the input data and sets up the map canvas
        by initializing the parent `SpatialPlot` and adding map features.
        Validation happens before any figure is created.

        Parameters
        ----------
        data : xr.DataArray
            The input trajectory data. Must be an xarray DataArray with
            coordinates for longitude and latitude. The object itself (not a
            copy) is stored and passed to `_update_history`.
        lon_coord : str, optional
            Name of the longitude coordinate in the DataArray, by default 'lon'.
        lat_coord : str, optional
            Name of the latitude coordinate in the DataArray, by default 'lat'.
        **kwargs : Any
            Keyword arguments passed to :class:`SpatialPlot`. These control
            the map projection, figure size, and cartopy features. For example:
            `projection=ccrs.LambertConformal()`, `figsize=(10, 8)`,
            `states=True`, `extent=[-125, -70, 25, 50]`.

        Raises
        ------
        TypeError
            If `data` is not an xarray.DataArray.
        ValueError
            If `lon_coord` or `lat_coord` is not a coordinate of `data`.
        """
        if not isinstance(data, xr.DataArray):
            raise TypeError("Input 'data' must be an xarray.DataArray.")
        # Validate both coordinate names before any figure is created.
        for coord_name, label in ((lon_coord, "Longitude"), (lat_coord, "Latitude")):
            if coord_name not in data.coords:
                raise ValueError(
                    f"{label} coordinate '{coord_name}' not found in DataArray."
                )

        # The parent builds the map canvas and draws features from kwargs.
        super().__init__(**kwargs)

        self.data, self.lon_coord, self.lat_coord = data, lon_coord, lat_coord
        # Provenance note; presumably recorded in data.attrs (helper defined
        # elsewhere) — applied to the caller's object since no copy is made.
        _update_history(self.data, "Plotted with monet-plots.SpatialTrack")

    def plot(self, **kwargs: Any) -> plt.Artist:
        """Plot the trajectory on the map.

        The track is rendered as a scatter plot. By default each point is
        colored according to the `data` values. If no `extent` is given, the
        map is fitted to the track with a 10% margin.

        Parameters
        ----------
        **kwargs : Any
            Keyword arguments passed to `matplotlib.pyplot.scatter`.
            Common options include `cmap`, `s` (size), and `alpha`.
            `transform` defaults to `ccrs.PlateCarree()` (coordinates in
            degrees lon/lat). `c` defaults to `data`; pass it explicitly to
            color the points by something else. The `cmap` argument can be a
            string, a Colormap object, or a (colormap, norm) tuple from the
            scaling tools in `colorbars.py`. Map features (e.g.,
            `coastlines=True`, `extent=[...]`) can also be passed here.

        Returns
        -------
        plt.Artist
            The scatter plot artist created by `ax.scatter`.
        """
        from ..plot_utils import get_plot_kwargs

        # Fit the map to the track (10% margin) unless told otherwise.
        if "extent" not in kwargs:
            kwargs["extent"] = self._get_extent_from_data(
                self.data, self.lon_coord, self.lat_coord, buffer=0.1
            )

        # Map-feature keywords are consumed here; the rest go to scatter.
        scatter_kwargs = self.add_features(**kwargs)

        scatter_kwargs.setdefault("transform", ccrs.PlateCarree())
        # Default colouring by the trajectory values; using setdefault means a
        # caller-supplied ``c`` no longer collides with it (duplicate keyword).
        scatter_kwargs.setdefault("c", self.data)

        # For coordinates and values, we pass the xarray objects directly.
        # This allows Matplotlib to handle the conversion, maintaining
        # compatibility with existing tests that check for lazy objects.
        longitude = self.data[self.lon_coord]
        latitude = self.data[self.lat_coord]

        # Use get_plot_kwargs to handle (cmap, norm) tuples
        final_kwargs = get_plot_kwargs(**scatter_kwargs)

        return self.ax.scatter(longitude, latitude, **final_kwargs)

__init__(data, *, lon_coord='lon', lat_coord='lat', **kwargs)

Initialize the SpatialTrack plot.

This constructor validates the input data and sets up the map canvas by initializing the parent SpatialPlot and adding map features.

Parameters

- `data` : xr.DataArray — The input trajectory data. Must be an xarray DataArray with coordinates for longitude and latitude.
- `lon_coord` : str, optional — Name of the longitude coordinate in the DataArray, by default 'lon'.
- `lat_coord` : str, optional — Name of the latitude coordinate in the DataArray, by default 'lat'.
- `**kwargs` : Any — Keyword arguments passed to `SpatialPlot`. These control the map projection, figure size, and cartopy features. For example: `projection=ccrs.LambertConformal()`, `figsize=(10, 8)`, `states=True`, `extent=[-125, -70, 25, 50]`.

Source code in src/monet_plots/plots/spatial.py
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
def __init__(
    self,
    data: xr.DataArray,
    *,
    lon_coord: str = "lon",
    lat_coord: str = "lat",
    **kwargs: Any,
):
    """Initialize the SpatialTrack plot.

    Checks that ``data`` is an :class:`xarray.DataArray` that carries the
    requested longitude and latitude coordinates. It then delegates map-canvas
    setup to the parent :class:`SpatialPlot` and stores the track data.

    Parameters
    ----------
    data : xr.DataArray
        The input trajectory data. Must be an xarray DataArray with
        coordinates for longitude and latitude.
    lon_coord : str, optional
        Name of the longitude coordinate in the DataArray, by default 'lon'.
    lat_coord : str, optional
        Name of the latitude coordinate in the DataArray, by default 'lat'.
    **kwargs : Any
        Keyword arguments passed to :class:`SpatialPlot`. These control
        the map projection, figure size, and cartopy features. For example:
        `projection=ccrs.LambertConformal()`, `figsize=(10, 8)`,
        `states=True`, `extent=[-125, -70, 25, 50]`.

    Raises
    ------
    TypeError
        If ``data`` is not an ``xarray.DataArray``.
    ValueError
        If ``lon_coord`` or ``lat_coord`` is not a coordinate of ``data``.
    """
    if not isinstance(data, xr.DataArray):
        raise TypeError("Input 'data' must be an xarray.DataArray.")

    # Longitude is checked before latitude, so the first missing one is reported.
    required_coords = (("Longitude", lon_coord), ("Latitude", lat_coord))
    for label, coord_name in required_coords:
        if coord_name not in data.coords:
            raise ValueError(
                f"{label} coordinate '{coord_name}' not found in DataArray."
            )

    # Projection, figure size and map features are handled by SpatialPlot.
    super().__init__(**kwargs)

    self.data, self.lon_coord, self.lat_coord = data, lon_coord, lat_coord
    # Record provenance on the data.
    _update_history(self.data, "Plotted with monet-plots.SpatialTrack")

plot(**kwargs)

Plot the trajectory on the map.

The track is rendered as a scatter plot, where each point is colored according to the data values.

Parameters

**kwargs : Any Keyword arguments passed to matplotlib.pyplot.scatter. Common options include cmap, s (size), and alpha. If no transform keyword is given, transform=ccrs.PlateCarree() is used, so the track coordinates are treated as plain longitude/latitude; pass a different CRS if they are not. The cmap argument can be a string, a Colormap object, or a (colormap, norm) tuple from the scaling tools in colorbars.py. Map features (e.g., coastlines=True) can also be passed here.

Returns

plt.Artist The scatter plot artist created by ax.scatter.

Source code in src/monet_plots/plots/spatial.py
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
def plot(self, **kwargs: Any) -> plt.Artist:
    """Plot the trajectory on the map.

    The track is rendered as a scatter plot, where each point is colored
    according to the `data` values.

    Parameters
    ----------
    **kwargs : Any
        Keyword arguments passed to `matplotlib.pyplot.scatter`.
        Common options include `cmap`, `s` (size), and `alpha`.
        If `transform` is not given, `ccrs.PlateCarree()` is used, so the
        track coordinates are treated as plain longitude/latitude; pass a
        different CRS if they are not.
        `c` defaults to the `data` values; pass it explicitly to color the
        points by something else.
        If `extent` is not given, it is computed from the track
        coordinates (via `_get_extent_from_data` with `buffer=0.1`).
        The `cmap` argument can be a string, a Colormap object, or a
        (colormap, norm) tuple from the scaling tools in `colorbars.py`.
        Map features (e.g., `coastlines=True`) can also be passed here.

    Returns
    -------
    plt.Artist
        The scatter plot artist created by `ax.scatter`.
    """
    from ..plot_utils import get_plot_kwargs

    # Automatically compute extent if not provided
    if "extent" not in kwargs:
        kwargs["extent"] = self._get_extent_from_data(
            self.data, self.lon_coord, self.lat_coord, buffer=0.1
        )

    # Draw map features; whatever add_features returns is meant for scatter.
    scatter_kwargs = self.add_features(**kwargs)

    scatter_kwargs.setdefault("transform", ccrs.PlateCarree())
    # Color by the track values unless the caller supplied their own `c`.
    # Passing `c=` as a separate keyword would raise a duplicate-keyword
    # TypeError whenever `c` is also in scatter_kwargs.
    scatter_kwargs.setdefault("c", self.data)

    # For coordinates and values, we pass the xarray objects directly.
    # This allows Matplotlib to handle the conversion, maintaining
    # compatibility with existing tests that check for lazy objects.
    longitude = self.data[self.lon_coord]
    latitude = self.data[self.lat_coord]

    # Use get_plot_kwargs to handle (cmap, norm) tuples
    final_kwargs = get_plot_kwargs(**scatter_kwargs)

    return self.ax.scatter(longitude, latitude, **final_kwargs)

SpatialTrack

Bases: SpatialPlot

Plot a trajectory from an xarray.DataArray on a map.

This class provides an xarray-native interface for visualizing paths, such as flight trajectories or pollutant tracks, where a variable (e.g., altitude, concentration) is plotted along the path.

It inherits from :class:SpatialPlot to provide the underlying map canvas.

Attributes

data : xr.DataArray The trajectory data being plotted. lon_coord : str The name of the longitude coordinate in the DataArray. lat_coord : str The name of the latitude coordinate in the DataArray.

Source code in src/monet_plots/plots/spatial.py
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
class SpatialTrack(SpatialPlot):
    """Plot a trajectory from an xarray.DataArray on a map.

    This class provides an xarray-native interface for visualizing paths,
    such as flight trajectories or pollutant tracks, where a variable
    (e.g., altitude, concentration) is plotted along the path.

    It inherits from :class:`SpatialPlot` to provide the underlying map canvas.

    Attributes
    ----------
    data : xr.DataArray
        The trajectory data being plotted.
    lon_coord : str
        The name of the longitude coordinate in the DataArray.
    lat_coord : str
        The name of the latitude coordinate in the DataArray.
    """

    def __init__(
        self,
        data: xr.DataArray,
        *,
        lon_coord: str = "lon",
        lat_coord: str = "lat",
        **kwargs: Any,
    ):
        """Initialize the SpatialTrack plot.

        This constructor validates the input data and sets up the map canvas
        by initializing the parent `SpatialPlot`.

        Parameters
        ----------
        data : xr.DataArray
            The input trajectory data. Must be an xarray DataArray with
            coordinates for longitude and latitude.
        lon_coord : str, optional
            Name of the longitude coordinate in the DataArray, by default 'lon'.
        lat_coord : str, optional
            Name of the latitude coordinate in the DataArray, by default 'lat'.
        **kwargs : Any
            Keyword arguments passed to :class:`SpatialPlot`. These control
            the map projection, figure size, and cartopy features. For example:
            `projection=ccrs.LambertConformal()`, `figsize=(10, 8)`,
            `states=True`, `extent=[-125, -70, 25, 50]`.

        Raises
        ------
        TypeError
            If ``data`` is not an ``xarray.DataArray``.
        ValueError
            If ``lon_coord`` or ``lat_coord`` is not a coordinate of ``data``.
        """
        if not isinstance(data, xr.DataArray):
            raise TypeError("Input 'data' must be an xarray.DataArray.")
        if lon_coord not in data.coords:
            raise ValueError(
                f"Longitude coordinate '{lon_coord}' not found in DataArray."
            )
        if lat_coord not in data.coords:
            raise ValueError(
                f"Latitude coordinate '{lat_coord}' not found in DataArray."
            )

        # Initialize the parent SpatialPlot, which creates the map canvas
        # from the keyword arguments.
        super().__init__(**kwargs)

        # Set data and update history for provenance
        self.data = data
        self.lon_coord = lon_coord
        self.lat_coord = lat_coord
        _update_history(self.data, "Plotted with monet-plots.SpatialTrack")

    def plot(self, **kwargs: Any) -> plt.Artist:
        """Plot the trajectory on the map.

        The track is rendered as a scatter plot, where each point is colored
        according to the `data` values.

        Parameters
        ----------
        **kwargs : Any
            Keyword arguments passed to `matplotlib.pyplot.scatter`.
            Common options include `cmap`, `s` (size), and `alpha`.
            If `transform` is not given, `ccrs.PlateCarree()` is used, so the
            track coordinates are treated as plain longitude/latitude; pass a
            different CRS if they are not.
            `c` defaults to the `data` values; pass it explicitly to color
            the points by something else.
            If `extent` is not given, it is computed from the track
            coordinates (via `_get_extent_from_data` with `buffer=0.1`).
            The `cmap` argument can be a string, a Colormap object, or a
            (colormap, norm) tuple from the scaling tools in `colorbars.py`.
            Map features (e.g., `coastlines=True`) can also be passed here.

        Returns
        -------
        plt.Artist
            The scatter plot artist created by `ax.scatter`.
        """
        from ..plot_utils import get_plot_kwargs

        # Automatically compute extent if not provided
        if "extent" not in kwargs:
            kwargs["extent"] = self._get_extent_from_data(
                self.data, self.lon_coord, self.lat_coord, buffer=0.1
            )

        # Draw map features; whatever add_features returns is meant for scatter.
        scatter_kwargs = self.add_features(**kwargs)

        scatter_kwargs.setdefault("transform", ccrs.PlateCarree())
        # Color by the track values unless the caller supplied their own `c`.
        # Passing `c=` as a separate keyword would raise a duplicate-keyword
        # TypeError whenever `c` is also in scatter_kwargs.
        scatter_kwargs.setdefault("c", self.data)

        # For coordinates and values, we pass the xarray objects directly.
        # This allows Matplotlib to handle the conversion, maintaining
        # compatibility with existing tests that check for lazy objects.
        longitude = self.data[self.lon_coord]
        latitude = self.data[self.lat_coord]

        # Use get_plot_kwargs to handle (cmap, norm) tuples
        final_kwargs = get_plot_kwargs(**scatter_kwargs)

        return self.ax.scatter(longitude, latitude, **final_kwargs)

__init__(data, *, lon_coord='lon', lat_coord='lat', **kwargs)

Initialize the SpatialTrack plot.

This constructor validates the input data and sets up the map canvas by initializing the parent SpatialPlot and adding map features.

Parameters

data : xr.DataArray The input trajectory data. Must be an xarray DataArray with coordinates for longitude and latitude. lon_coord : str, optional Name of the longitude coordinate in the DataArray, by default 'lon'. lat_coord : str, optional Name of the latitude coordinate in the DataArray, by default 'lat'. **kwargs : Any Keyword arguments passed to :class:SpatialPlot. These control the map projection, figure size, and cartopy features. For example: projection=ccrs.LambertConformal(), figsize=(10, 8), states=True, extent=[-125, -70, 25, 50].

Source code in src/monet_plots/plots/spatial.py
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
def __init__(
    self,
    data: xr.DataArray,
    *,
    lon_coord: str = "lon",
    lat_coord: str = "lat",
    **kwargs: Any,
):
    """Initialize the SpatialTrack plot.

    Checks that ``data`` is an :class:`xarray.DataArray` that carries the
    requested longitude and latitude coordinates. It then delegates map-canvas
    setup to the parent :class:`SpatialPlot` and stores the track data.

    Parameters
    ----------
    data : xr.DataArray
        The input trajectory data. Must be an xarray DataArray with
        coordinates for longitude and latitude.
    lon_coord : str, optional
        Name of the longitude coordinate in the DataArray, by default 'lon'.
    lat_coord : str, optional
        Name of the latitude coordinate in the DataArray, by default 'lat'.
    **kwargs : Any
        Keyword arguments passed to :class:`SpatialPlot`. These control
        the map projection, figure size, and cartopy features. For example:
        `projection=ccrs.LambertConformal()`, `figsize=(10, 8)`,
        `states=True`, `extent=[-125, -70, 25, 50]`.

    Raises
    ------
    TypeError
        If ``data`` is not an ``xarray.DataArray``.
    ValueError
        If ``lon_coord`` or ``lat_coord`` is not a coordinate of ``data``.
    """
    if not isinstance(data, xr.DataArray):
        raise TypeError("Input 'data' must be an xarray.DataArray.")

    # Longitude is checked before latitude, so the first missing one is reported.
    required_coords = (("Longitude", lon_coord), ("Latitude", lat_coord))
    for label, coord_name in required_coords:
        if coord_name not in data.coords:
            raise ValueError(
                f"{label} coordinate '{coord_name}' not found in DataArray."
            )

    # Projection, figure size and map features are handled by SpatialPlot.
    super().__init__(**kwargs)

    self.data, self.lon_coord, self.lat_coord = data, lon_coord, lat_coord
    # Record provenance on the data.
    _update_history(self.data, "Plotted with monet-plots.SpatialTrack")

plot(**kwargs)

Plot the trajectory on the map.

The track is rendered as a scatter plot, where each point is colored according to the data values.

Parameters

**kwargs : Any Keyword arguments passed to matplotlib.pyplot.scatter. Common options include cmap, s (size), and alpha. If no transform keyword is given, transform=ccrs.PlateCarree() is used, so the track coordinates are treated as plain longitude/latitude; pass a different CRS if they are not. The cmap argument can be a string, a Colormap object, or a (colormap, norm) tuple from the scaling tools in colorbars.py. Map features (e.g., coastlines=True) can also be passed here.

Returns

plt.Artist The scatter plot artist created by ax.scatter.

Source code in src/monet_plots/plots/spatial.py
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
def plot(self, **kwargs: Any) -> plt.Artist:
    """Plot the trajectory on the map.

    The track is rendered as a scatter plot, where each point is colored
    according to the `data` values.

    Parameters
    ----------
    **kwargs : Any
        Keyword arguments passed to `matplotlib.pyplot.scatter`.
        Common options include `cmap`, `s` (size), and `alpha`.
        If `transform` is not given, `ccrs.PlateCarree()` is used, so the
        track coordinates are treated as plain longitude/latitude; pass a
        different CRS if they are not.
        `c` defaults to the `data` values; pass it explicitly to color the
        points by something else.
        If `extent` is not given, it is computed from the track
        coordinates (via `_get_extent_from_data` with `buffer=0.1`).
        The `cmap` argument can be a string, a Colormap object, or a
        (colormap, norm) tuple from the scaling tools in `colorbars.py`.
        Map features (e.g., `coastlines=True`) can also be passed here.

    Returns
    -------
    plt.Artist
        The scatter plot artist created by `ax.scatter`.
    """
    from ..plot_utils import get_plot_kwargs

    # Automatically compute extent if not provided
    if "extent" not in kwargs:
        kwargs["extent"] = self._get_extent_from_data(
            self.data, self.lon_coord, self.lat_coord, buffer=0.1
        )

    # Draw map features; whatever add_features returns is meant for scatter.
    scatter_kwargs = self.add_features(**kwargs)

    scatter_kwargs.setdefault("transform", ccrs.PlateCarree())
    # Color by the track values unless the caller supplied their own `c`.
    # Passing `c=` as a separate keyword would raise a duplicate-keyword
    # TypeError whenever `c` is also in scatter_kwargs.
    scatter_kwargs.setdefault("c", self.data)

    # For coordinates and values, we pass the xarray objects directly.
    # This allows Matplotlib to handle the conversion, maintaining
    # compatibility with existing tests that check for lazy objects.
    longitude = self.data[self.lon_coord]
    latitude = self.data[self.lat_coord]

    # Use get_plot_kwargs to handle (cmap, norm) tuples
    final_kwargs = get_plot_kwargs(**scatter_kwargs)

    return self.ax.scatter(longitude, latitude, **final_kwargs)

Example

import numpy as np
import xarray as xr
from monet_plots.plots import SpatialTrack

# Create sample data: 100 points along a track, one value per point
lon = np.linspace(-120, -80, 100)
lat = np.linspace(30, 40, 100)
values = np.random.rand(100)

# SpatialTrack takes a single xarray.DataArray whose longitude/latitude
# coordinates (named "lon"/"lat" by default) run along the track dimension.
data = xr.DataArray(
    values,
    dims=("time",),
    coords={"lon": ("time", lon), "lat": ("time", lat)},
    name="concentration",
)

# Create the plot
track_plot = SpatialTrack(data)
track_plot.plot()

SpatialContourPlot

Bases: SpatialPlot

Create a contour plot on a map with an optional discrete colorbar.

This class provides an xarray-native interface for visualizing spatial data with continuous values. It supports both Track A (publication-quality static plots) and Track B (interactive exploration).

Source code in src/monet_plots/plots/spatial_contour.py
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
class SpatialContourPlot(SpatialPlot):
    """Create a contour plot on a map with an optional discrete colorbar.

    This class provides an xarray-native interface for visualizing spatial
    data with continuous values. It supports both Track A (publication-quality
    static plots) and Track B (interactive exploration).
    """

    def __new__(
        cls,
        modelvar: Any,
        gridobj: Any | None = None,
        date: Any | None = None,
        **kwargs: Any,
    ) -> Any:
        """Redirect to SpatialFacetGridPlot if faceting is requested.

        This enables a unified interface for both single-panel and multi-panel
        spatial plots, following Xarray's plotting conventions.

        Parameters
        ----------
        modelvar : Any
            The input data to contour.
        gridobj : Any, optional
            Object with LAT and LON variables, by default None.
        date : Any, optional
            Date/time for the plot title, by default None.
        **kwargs : Any
            Additional keyword arguments. If faceting arguments (e.g., `col`,
            `row`, or `col_wrap`) are provided, redirects to `SpatialFacetGridPlot`.

        Returns
        -------
        Any
            An instance of SpatialContourPlot or SpatialFacetGridPlot.
        """
        from .facet_grid import SpatialFacetGridPlot

        ax = kwargs.get("ax")

        # Aligns with Xarray's trigger for faceting
        facet_kwargs = ["col", "row", "col_wrap"]
        is_faceting = any(kwargs.get(k) is not None for k in facet_kwargs)

        # Redirect to FacetGrid if faceting requested and no existing axes.
        # The returned object is not an instance of `cls`, so Python does not
        # call SpatialContourPlot.__init__ on it.
        # NOTE(review): `gridobj` and `date` are not forwarded on this path.
        if ax is None and is_faceting:
            return SpatialFacetGridPlot(modelvar, **kwargs)

        # Also redirect if input is a Dataset with multiple variables
        if (
            ax is None
            and isinstance(modelvar, xr.Dataset)
            and len(modelvar.data_vars) > 1
        ):
            # Default to faceting by variable if not specified
            kwargs.setdefault("col", "variable")
            return SpatialFacetGridPlot(modelvar, **kwargs)

        return super().__new__(cls)

    def __init__(
        self,
        modelvar: Any,
        gridobj: Any | None = None,
        date: datetime | None = None,
        discrete: bool = True,
        ncolors: int | None = None,
        dtype: str = "int",
        col: str | None = None,
        row: str | None = None,
        col_wrap: int | None = None,
        size: float | None = None,
        aspect: float | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the spatial contour plot.

        Parameters
        ----------
        modelvar : Any
            The input data to contour. Preferred format is an xarray DataArray.
        gridobj : Any, optional
            Object with LAT and LON variables to determine extent, by default None.
        date : datetime, optional
            Date/time for the plot title, by default None.
        discrete : bool, optional
            If True, use a discrete colorbar, by default True.
        ncolors : int, optional
            Number of discrete colors for the colorbar, by default None.
        dtype : str, optional
            Data type for colorbar tick labels, by default "int".
        col : str, optional
            Dimension name to facet by columns. Aligns with Xarray.
        row : str, optional
            Dimension name to facet by rows. Aligns with Xarray.
        col_wrap : int, optional
            Number of columns before wrapping. Aligns with Xarray.
        size : float, optional
            Height (in inches) of each facet. Aligns with Xarray.
        aspect : float, optional
            Aspect ratio of each facet. Aligns with Xarray.
        **kwargs : Any
            Keyword arguments passed to :class:`SpatialPlot` for map features
            and projection.
        """
        # col/row/col_wrap/size/aspect are not used here. Faceting is decided
        # in __new__. Naming them keeps them out of the kwargs forwarded to
        # SpatialPlot.

        # Initialize the map canvas via SpatialPlot
        super().__init__(**kwargs)

        # Standardize data to Xarray for consistency and lazy evaluation
        self.modelvar = normalize_data(modelvar)
        if isinstance(self.modelvar, xr.Dataset) and len(self.modelvar.data_vars) == 1:
            self.modelvar = self.modelvar[list(self.modelvar.data_vars)[0]]

        self.gridobj = gridobj
        self.date = date
        self.discrete = discrete
        self.ncolors = ncolors
        self.dtype = dtype

        # Aero Protocol: Centralized coordinate identification
        try:
            self.lon_coord, self.lat_coord = self._identify_coords(self.modelvar)
        except ValueError:
            self.lon_coord = kwargs.get("lon_coord", "lon")
            self.lat_coord = kwargs.get("lat_coord", "lat")

        # Ensure coordinates are monotonic for correct plotting
        self.modelvar = self._ensure_monotonic(
            self.modelvar, self.lon_coord, self.lat_coord
        )

        _update_history(self.modelvar, "Initialized monet-plots.SpatialContourPlot")

    @staticmethod
    def _data_range(data: Any) -> tuple[float, float]:
        """Return ``(min, max)`` of ``data`` as Python floats.

        When dask is importable, both reductions are evaluated in a single
        ``dask.compute`` call. If ``.min()``/``.max()`` return non-scalar
        results (e.g. a pandas DataFrame yields a Series), those results are
        reduced again to a single value.
        """
        # Use a single compute call for efficiency (Aero Protocol)
        try:
            import dask

            dmin, dmax = dask.compute(data.min(), data.max())
        except (ImportError, AttributeError):
            dmin, dmax = data.min(), data.max()

        try:
            return float(dmin), float(dmax)
        except (TypeError, ValueError):
            # dmin/dmax could be Series if data was a DataFrame
            fmin = float(dmin.min()) if hasattr(dmin, "min") else float(np.min(dmin))
            fmax = float(dmax.max()) if hasattr(dmax, "max") else float(np.max(dmax))
            return fmin, fmax

    def plot(self, **kwargs: Any) -> Axes:
        """Generate a static publication-quality spatial contour plot (Track A).

        Parameters
        ----------
        **kwargs : Any
            Keyword arguments passed to `matplotlib.pyplot.contourf`.
            Common options include `cmap`, `levels`, `vmin`, `vmax`, and `alpha`.
            Map features (e.g., `coastlines=True`) can also be passed here.

        Returns
        -------
        matplotlib.axes.Axes
            The matplotlib axes object containing the plot.
        """
        lat = None
        lon = None

        if hasattr(self.gridobj, "variables"):
            # Handle legacy gridobj
            try:
                lat_var = self.gridobj.variables["LAT"]
                lon_var = self.gridobj.variables["LON"]

                # Flexible indexing based on dimension count
                if lat_var.ndim == 4:
                    lat = lat_var[0, 0, :, :].squeeze()
                    lon = lon_var[0, 0, :, :].squeeze()
                elif lat_var.ndim == 3:
                    lat = lat_var[0, :, :].squeeze()
                    lon = lon_var[0, :, :].squeeze()
                else:
                    lat = lat_var.squeeze()
                    lon = lon_var.squeeze()
            except (AttributeError, KeyError):
                # Best effort: fall back to the coordinates of modelvar.
                pass
        elif self.gridobj is not None:
            # Assume it's already an array or similar
            try:
                lat = self.gridobj.LAT
                lon = self.gridobj.LON
            except AttributeError:
                # Best effort: fall back to the coordinates of modelvar.
                pass

        # Automatically compute extent if not provided
        if "extent" not in kwargs:
            if lon is not None and lat is not None:
                kwargs["extent"] = [
                    float(lon.min()),
                    float(lon.max()),
                    float(lat.min()),
                    float(lat.max()),
                ]
            else:
                kwargs["extent"] = self._get_extent_from_data(
                    self.modelvar, self.lon_coord, self.lat_coord
                )

        # Draw map features and get remaining kwargs for contourf
        plot_kwargs = self.add_features(**kwargs)

        # Set default contour settings
        plot_kwargs.setdefault("cmap", "viridis")

        # Data is in lat/lon, so specify transform
        plot_kwargs.setdefault("transform", ccrs.PlateCarree())

        # For coordinates and values, we prefer passing the xarray objects directly.
        # This allows Matplotlib to handle the conversion and maintains
        # parity for Dask-backed arrays.
        if lon is None or lat is None:
            longitude = self.modelvar[self.lon_coord]
            latitude = self.modelvar[self.lat_coord]
        else:
            longitude = lon
            latitude = lat

        # Handle colormap and normalization
        final_kwargs = get_plot_kwargs(**plot_kwargs)

        mesh = self.ax.contourf(longitude, latitude, self.modelvar, **final_kwargs)

        # Handle colorbar
        if self.discrete:
            cmap = final_kwargs.get("cmap")
            levels = final_kwargs.get("levels")
            ncolors = self.ncolors

            if levels is None:
                # No levels requested: build evenly spaced levels from the data
                # so that a discrete colorbar can still be drawn.
                n_lev = ncolors if ncolors is not None else 10
                fmin, fmax = self._data_range(self.modelvar)
                levels_seq = np.linspace(fmin, fmax, n_lev + 1)
                ncolors = n_lev
            elif isinstance(levels, (int, np.integer)):
                # An integer level count: span the data range. This runs even
                # when ncolors was given explicitly, because indexing an int
                # below would fail.
                # NOTE(review): matplotlib picks "nice" levels for an int
                # `levels`, so these may not match the contour boundaries exactly.
                fmin, fmax = self._data_range(self.modelvar)
                levels_seq = np.linspace(fmin, fmax, int(levels))
                if ncolors is None:
                    ncolors = int(levels) - 1
            else:
                # Explicit sequence of level boundaries.
                levels_seq = levels
                if ncolors is None:
                    ncolors = len(levels) - 1

            colorbar_index(
                ncolors,
                cmap,
                minval=levels_seq[0],
                maxval=levels_seq[-1],
                dtype=self.dtype,
                ax=self.ax,
            )
        else:
            self.add_colorbar(mesh)

        if self.date:
            titstring = self.date.strftime("%B %d %Y %H")
            self.ax.set_title(titstring)

        self.fig.tight_layout()
        return self.ax

    def hvplot(self, **kwargs: Any) -> Any:
        """Generate an interactive spatial contour plot using hvPlot (Track B).

        Parameters
        ----------
        **kwargs : Any
            Keyword arguments passed to `hvplot.contourf`.
            Common options include `cmap`, `levels`, `title`, and `alpha`.
            They override the Track B defaults set below.

        Returns
        -------
        holoviews.core.layout.Layout
            The interactive hvPlot object.

        Raises
        ------
        ImportError
            If hvplot is not installed.
        """
        try:
            import hvplot.xarray  # noqa: F401
        except ImportError as exc:
            raise ImportError(
                "hvplot is required for interactive plotting. Install it with 'pip install hvplot'."
            ) from exc

        # Track B defaults
        plot_kwargs = {
            "x": self.lon_coord,
            "y": self.lat_coord,
            "geo": True,
            "rasterize": True,
            "cmap": "viridis",
            "kind": "contourf",
        }
        plot_kwargs.update(kwargs)

        return self.modelvar.hvplot(**plot_kwargs)

__init__(modelvar, gridobj=None, date=None, discrete=True, ncolors=None, dtype='int', col=None, row=None, col_wrap=None, size=None, aspect=None, **kwargs)

Initialize the spatial contour plot.

Parameters

modelvar : Any The input data to contour. Preferred format is an xarray DataArray. gridobj : Any, optional Object with LAT and LON variables to determine extent, by default None. date : datetime, optional Date/time for the plot title, by default None. discrete : bool, optional If True, use a discrete colorbar, by default True. ncolors : int, optional Number of discrete colors for the colorbar, by default None. dtype : str, optional Data type for colorbar tick labels, by default "int". col : str, optional Dimension name to facet by columns. Aligns with Xarray. row : str, optional Dimension name to facet by rows. Aligns with Xarray. col_wrap : int, optional Number of columns before wrapping. Aligns with Xarray. size : float, optional Height (in inches) of each facet. Aligns with Xarray. aspect : float, optional Aspect ratio of each facet. Aligns with Xarray. **kwargs : Any Keyword arguments passed to :class:SpatialPlot for map features and projection.

Source code in src/monet_plots/plots/spatial_contour.py
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
def __init__(
    self,
    modelvar: Any,
    gridobj: Any | None = None,
    date: datetime | None = None,
    discrete: bool = True,
    ncolors: int | None = None,
    dtype: str = "int",
    col: str | None = None,
    row: str | None = None,
    col_wrap: int | None = None,
    size: float | None = None,
    aspect: float | None = None,
    **kwargs: Any,
) -> None:
    """Set up a single-panel spatial contour plot.

    The faceting arguments (``col``, ``row``, ``col_wrap``, ``size``,
    ``aspect``) are accepted so that the signature mirrors Xarray's
    plotting API. This initializer does not use them itself.

    Parameters
    ----------
    modelvar : Any
        Data to contour; an xarray DataArray is the preferred input. A
        Dataset containing exactly one data variable is unwrapped to that
        variable.
    gridobj : Any, optional
        Object exposing LAT and LON used later to derive the map extent,
        by default None.
    date : datetime, optional
        Timestamp rendered as the plot title, by default None.
    discrete : bool, optional
        Draw a discrete colorbar when True, by default True.
    ncolors : int, optional
        Number of colors in the discrete colorbar, by default None.
    dtype : str, optional
        Data type used for colorbar tick labels, by default "int".
    col : str, optional
        Dimension to facet across columns (Xarray-compatible).
    row : str, optional
        Dimension to facet across rows (Xarray-compatible).
    col_wrap : int, optional
        Number of columns before wrapping (Xarray-compatible).
    size : float, optional
        Height in inches of each facet (Xarray-compatible).
    aspect : float, optional
        Aspect ratio of each facet (Xarray-compatible).
    **kwargs : Any
        Forwarded to :class:`SpatialPlot` (map features and projection).
        ``lon_coord`` / ``lat_coord`` in here are used as fallback
        coordinate names when automatic detection fails.
    """
    # The map canvas and projection come from SpatialPlot.
    super().__init__(**kwargs)

    # Normalize to Xarray; a single-variable Dataset collapses to its DataArray.
    self.modelvar = normalize_data(modelvar)
    data_vars = getattr(self.modelvar, "data_vars", None)
    if isinstance(self.modelvar, xr.Dataset) and len(data_vars) == 1:
        self.modelvar = self.modelvar[next(iter(data_vars))]

    self.gridobj = gridobj
    self.date = date
    self.discrete, self.ncolors, self.dtype = discrete, ncolors, dtype

    # Prefer the automatically detected coordinate names; fall back to the
    # caller-provided names (or "lon"/"lat") when detection fails.
    try:
        lon_name, lat_name = self._identify_coords(self.modelvar)
    except ValueError:
        lon_name = kwargs.get("lon_coord", "lon")
        lat_name = kwargs.get("lat_coord", "lat")
    self.lon_coord, self.lat_coord = lon_name, lat_name

    # Monotonic coordinates are required for correct contouring.
    self.modelvar = self._ensure_monotonic(
        self.modelvar, self.lon_coord, self.lat_coord
    )

    _update_history(self.modelvar, "Initialized monet-plots.SpatialContourPlot")

__new__(modelvar, gridobj=None, date=None, **kwargs)

Redirect to SpatialFacetGridPlot if faceting is requested.

This enables a unified interface for both single-panel and multi-panel spatial plots, following Xarray's plotting conventions.

Parameters

modelvar : Any The input data to contour. gridobj : Any, optional Object with LAT and LON variables, by default None. date : Any, optional Date/time for the plot title, by default None. **kwargs : Any Additional keyword arguments. If faceting arguments (e.g., col, row, or col_wrap) are provided, redirects to SpatialFacetGridPlot.

Returns

Any An instance of SpatialContourPlot or SpatialFacetGridPlot.

Source code in src/monet_plots/plots/spatial_contour.py
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
def __new__(
    cls,
    modelvar: Any,
    gridobj: Any | None = None,
    date: Any | None = None,
    **kwargs: Any,
) -> Any:
    """Return a faceted grid plot when faceting is requested, else a single panel.

    Gives single-panel and multi-panel spatial contour plots one entry point,
    following Xarray's plotting conventions. When an explicit ``ax`` is
    supplied, a single-panel instance is always created.

    Parameters
    ----------
    modelvar : Any
        The input data to contour.
    gridobj : Any, optional
        Object with LAT and LON variables, by default None.
    date : Any, optional
        Date/time for the plot title, by default None.
    **kwargs : Any
        Additional keyword arguments. A non-None ``col``, ``row`` or
        ``col_wrap`` triggers a :class:`SpatialFacetGridPlot`. A Dataset with
        more than one data variable also triggers one, faceted by
        ``"variable"`` unless ``col`` is given.

    Returns
    -------
    Any
        An instance of SpatialContourPlot or SpatialFacetGridPlot.
    """
    from .facet_grid import SpatialFacetGridPlot

    # An existing axes pins the plot to a single panel.
    if kwargs.get("ax") is not None:
        return super().__new__(cls)

    # Same faceting triggers as Xarray.
    if any(kwargs.get(name) is not None for name in ("col", "row", "col_wrap")):
        return SpatialFacetGridPlot(modelvar, **kwargs)

    # Multi-variable Datasets are faceted by variable by default.
    if isinstance(modelvar, xr.Dataset) and len(modelvar.data_vars) > 1:
        kwargs.setdefault("col", "variable")
        return SpatialFacetGridPlot(modelvar, **kwargs)

    return super().__new__(cls)

hvplot(**kwargs)

Generate an interactive spatial contour plot using hvPlot (Track B).

Parameters

**kwargs : Any Keyword arguments passed to hvplot.contourf. Common options include cmap, levels, title, and alpha.

Returns

holoviews.core.layout.Layout The interactive hvPlot object.

Source code in src/monet_plots/plots/spatial_contour.py
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
def hvplot(self, **kwargs: Any) -> Any:
    """Build an interactive spatial contour plot with hvPlot (Track B).

    The call is ``modelvar.hvplot(kind="contourf", ...)`` with geographic
    axes and rasterization on by default. Any keyword given here overrides
    the corresponding default.

    Parameters
    ----------
    **kwargs : Any
        Options forwarded to hvPlot (e.g. `cmap`, `levels`, `title`,
        `alpha`, or `kind` to change the plot type).

    Returns
    -------
    holoviews.core.layout.Layout
        The interactive hvPlot object.

    Raises
    ------
    ImportError
        If ``hvplot`` (and its xarray accessor) is not installed.
    """
    try:
        import hvplot.xarray  # noqa: F401
    except ImportError:
        raise ImportError(
            "hvplot is required for interactive plotting. Install it with 'pip install hvplot'."
        )

    # Track B defaults; caller-supplied options take precedence.
    defaults = dict(
        x=self.lon_coord,
        y=self.lat_coord,
        geo=True,
        rasterize=True,
        cmap="viridis",
        kind="contourf",
    )
    return self.modelvar.hvplot(**{**defaults, **kwargs})

plot(**kwargs)

Generate a static publication-quality spatial contour plot (Track A).

Parameters

**kwargs : Any Keyword arguments passed to matplotlib.pyplot.contourf. Common options include cmap, levels, vmin, vmax, and alpha. Map features (e.g., coastlines=True) can also be passed here.

Returns

matplotlib.axes.Axes The matplotlib axes object containing the plot.

Source code in src/monet_plots/plots/spatial_contour.py
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
def plot(self, **kwargs: Any) -> Axes:
    """Generate a static publication-quality spatial contour plot (Track A).

    Parameters
    ----------
    **kwargs : Any
        Keyword arguments passed to `matplotlib.pyplot.contourf`.
        Common options include `cmap`, `levels`, `vmin`, `vmax`, and `alpha`.
        Map features (e.g., `coastlines=True`) can also be passed here.
        If `extent` is omitted, it is derived from ``gridobj`` LAT/LON when
        available, otherwise from the data coordinates.

    Returns
    -------
    matplotlib.axes.Axes
        The matplotlib axes object containing the plot.

    Notes
    -----
    With ``discrete=True``, the colorbar is drawn by ``colorbar_index``. Its
    range is set as follows:

    - An integer ``levels`` spreads that many boundaries across the data range.
    - A sequence ``levels`` is used as given.
    - When ``levels`` is absent, ``ncolors`` equal intervals (default 10)
      are used over the data range.

    When ``ncolors`` is set, it always fixes the number of colorbar colors.
    """

    def _data_range() -> tuple[float, float]:
        """Return (min, max) of ``self.modelvar`` as Python floats."""
        # A single dask.compute call evaluates both reductions in one pass
        # when dask is available.
        try:
            import dask

            dmin, dmax = dask.compute(self.modelvar.min(), self.modelvar.max())
        except (ImportError, AttributeError):
            dmin, dmax = self.modelvar.min(), self.modelvar.max()

        try:
            return float(dmin), float(dmax)
        except (TypeError, ValueError):
            # pandas DataFrame reductions return a Series (one value per column).
            fmin = float(dmin.min()) if hasattr(dmin, "min") else float(np.min(dmin))
            fmax = float(dmax.max()) if hasattr(dmax, "max") else float(np.max(dmax))
            return fmin, fmax

    lat = None
    lon = None

    if hasattr(self.gridobj, "variables"):
        # Legacy gridobj exposing a ``variables`` mapping with LAT/LON
        # (presumably a netCDF-style object). Best effort: on failure, fall
        # back to the data coordinates.
        try:
            lat_var = self.gridobj.variables["LAT"]
            lon_var = self.gridobj.variables["LON"]

            # Strip leading (e.g. time/layer) axes according to rank.
            if lat_var.ndim == 4:
                lat = lat_var[0, 0, :, :].squeeze()
                lon = lon_var[0, 0, :, :].squeeze()
            elif lat_var.ndim == 3:
                lat = lat_var[0, :, :].squeeze()
                lon = lon_var[0, :, :].squeeze()
            else:
                lat = lat_var.squeeze()
                lon = lon_var.squeeze()
        except (AttributeError, KeyError):
            pass
    elif self.gridobj is not None:
        # Object exposing LAT/LON directly as attributes.
        try:
            lat = self.gridobj.LAT
            lon = self.gridobj.LON
        except AttributeError:
            pass

    # Automatically compute extent if not provided.
    if "extent" not in kwargs:
        if lon is not None and lat is not None:
            kwargs["extent"] = [
                float(lon.min()),
                float(lon.max()),
                float(lat.min()),
                float(lat.max()),
            ]
        else:
            kwargs["extent"] = self._get_extent_from_data(
                self.modelvar, self.lon_coord, self.lat_coord
            )

    # Draw map features; the remaining kwargs go to contourf.
    plot_kwargs = self.add_features(**kwargs)
    plot_kwargs.setdefault("cmap", "viridis")
    # Data coordinates are lat/lon, so declare the source CRS.
    plot_kwargs.setdefault("transform", ccrs.PlateCarree())

    # Pass xarray objects straight to Matplotlib so Dask-backed arrays are
    # handled by Matplotlib's own conversion.
    if lon is None or lat is None:
        longitude = self.modelvar[self.lon_coord]
        latitude = self.modelvar[self.lat_coord]
    else:
        longitude = lon
        latitude = lat

    # Handle colormap and normalization.
    final_kwargs = get_plot_kwargs(**plot_kwargs)

    mesh = self.ax.contourf(longitude, latitude, self.modelvar, **final_kwargs)

    if self.discrete:
        cmap = final_kwargs.get("cmap")
        levels = final_kwargs.get("levels")
        ncolors = self.ncolors

        if isinstance(levels, (int, np.integer)):
            # Integer count: that many evenly spaced boundaries over the data.
            fmin, fmax = _data_range()
            levels_seq = np.linspace(fmin, fmax, int(levels))
            if ncolors is None:
                ncolors = int(levels) - 1
        elif levels is not None:
            # Explicit boundaries.
            levels_seq = levels
            if ncolors is None:
                ncolors = len(levels) - 1
        else:
            # No levels given: still build a discrete colorbar from the data.
            n_lev = ncolors if ncolors is not None else 10
            fmin, fmax = _data_range()
            levels_seq = np.linspace(fmin, fmax, n_lev + 1)
            ncolors = n_lev

        colorbar_index(
            ncolors,
            cmap,
            minval=levels_seq[0],
            maxval=levels_seq[-1],
            dtype=self.dtype,
            ax=self.ax,
        )
    else:
        self.add_colorbar(mesh)

    if self.date:
        titstring = self.date.strftime("%B %d %Y %H")
        self.ax.set_title(titstring)

    self.fig.tight_layout()
    return self.ax

SpatialBiasScatterPlot

Bases: SpatialPlot

Create a spatial scatter plot showing bias between model and observations.

The scatter points are colored by the difference (model - observations) and sized by the absolute magnitude of this difference, making larger biases more visible. This class supports both Track A (publication) and Track B (interactive) visualization.

Source code in src/monet_plots/plots/spatial_bias_scatter.py
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
class SpatialBiasScatterPlot(SpatialPlot):
    """Create a spatial scatter plot showing bias between model and observations.

    The scatter points are colored by the difference (model - observations) and
    sized by the absolute magnitude of this difference, making larger biases
    more visible. This class supports both Track A (publication) and
    Track B (interactive) visualization.
    """

    def __init__(
        self,
        data: Any,
        col1: str,
        col2: str,
        vmin: float | None = None,
        vmax: float | None = None,
        ncolors: int = 15,
        fact: float = 1.5,
        cmap: str = "RdBu_r",
        **kwargs: Any,
    ) -> None:
        """Initialize the plot with data and map projection.

        Parameters
        ----------
        data : Any
            Input data. Preferred format is xarray.Dataset or xarray.DataArray
            with 'latitude' and 'longitude' (or 'lat' and 'lon') coordinates.
        col1 : str
            Name of the first variable (e.g., observations).
        col2 : str
            Name of the second variable (e.g., model). Bias is calculated
            as col2 - col1.
        vmin : float, optional
            Minimum for colorscale, by default None. When None, the static
            plot uses the negative of its symmetric bias limit.
        vmax : float, optional
            Maximum for colorscale, by default None. When None, the static
            plot uses its symmetric bias limit.
        ncolors : int, optional
            Number of discrete colors, by default 15.
        fact : float, optional
            Scaling factor for point sizes, by default 1.5.
        cmap : str, optional
            Colormap for bias values, by default "RdBu_r".
        **kwargs : Any
            Additional keyword arguments for map creation, passed to
            :class:`monet_plots.plots.spatial.SpatialPlot`.
        """
        super().__init__(**kwargs)
        self.data = normalize_data(data)
        self.col1 = col1
        self.col2 = col2
        self.vmin = vmin
        self.vmax = vmax
        self.ncolors = ncolors
        self.fact = fact
        self.cmap = cmap

        _update_history(self.data, "Initialized monet-plots.SpatialBiasScatterPlot")

    @staticmethod
    def _bias_limit(raw: float) -> float:
        """Turn the raw 95th-percentile |bias| into a usable symmetric limit.

        The value is rounded to a whole number for tidy colorbar ticks, with
        two exceptions:

        - If rounding would collapse it to zero (|bias| below 0.5), the
          unrounded value is kept.
        - If there is no usable spread (zero or non-finite), 1.0 is used.

        This keeps the colour scale and the marker-size division well defined.
        """
        rounded = float(np.around(raw))
        if rounded > 0:
            return rounded
        if np.isfinite(raw) and raw > 0:
            return float(raw)
        return 1.0

    def plot(self, **kwargs: Any) -> matplotlib.axes.Axes:
        """Generate a static publication-quality spatial bias scatter plot (Track A).

        Parameters
        ----------
        **kwargs : Any
            Keyword arguments passed to `matplotlib.pyplot.scatter`.
            Map features (e.g., `coastlines=True`) can also be passed here.
            The styling defaults (`transform`, `edgecolors`, `linewidths`,
            `alpha`) may be overridden. `c`, `s`, `cmap` and `norm` are
            derived from the bias and cannot be passed.

        Returns
        -------
        matplotlib.axes.Axes
            The matplotlib axes object containing the plot.
        """
        # Draw map features; whatever add_features returns goes to scatter.
        scatter_kwargs = self.add_features(**kwargs)

        if isinstance(self.data, (xr.Dataset, xr.DataArray)):
            # Vectorized (possibly Dask-backed) bias.
            diff = self.data[self.col2] - self.data[self.col1]

            # The 95th percentile of |bias| sets the symmetric scale.
            try:
                top_val = np.abs(diff).quantile(0.95)
                if hasattr(top_val, "compute"):
                    top = float(top_val.compute())
                else:
                    top = float(top_val)
            except (ImportError, AttributeError, ValueError):
                # Fallback: evaluate eagerly with NumPy.
                top = float(np.nanquantile(np.abs(diff.values), 0.95))

            # A DataArray has no ``data_vars``; treat it as empty.
            data_vars = getattr(self.data, "data_vars", ())
            lat_name = next(
                (
                    c
                    for c in ["latitude", "lat"]
                    if c in self.data.coords or c in data_vars or c in self.data.dims
                ),
                "lat",
            )
            lon_name = next(
                (
                    c
                    for c in ["longitude", "lon"]
                    if c in self.data.coords or c in data_vars or c in self.data.dims
                ),
                "lon",
            )

            # Keep only what is needed for plotting.
            plot_ds = xr.Dataset(
                {"diff": diff, "lat": self.data[lat_name], "lon": self.data[lon_name]}
            )

            # Drop NaNs before compute to minimize transfer.
            if plot_ds.dims:
                plot_ds = plot_ds.dropna(dim=list(plot_ds.dims)[0])

            concrete = plot_ds.compute()
            diff_vals = concrete["diff"].values
            lat_vals = concrete["lat"].values
            lon_vals = concrete["lon"].values
        else:
            # Fallback for pandas.
            df = self.data.dropna(subset=[self.col1, self.col2])
            diff_vals = (df[self.col2] - df[self.col1]).values
            lat_name = next((c for c in ["latitude", "lat"] if c in df.columns), "lat")
            lon_name = next((c for c in ["longitude", "lon"] if c in df.columns), "lon")
            lat_vals = df[lat_name].values
            lon_vals = df[lon_name].values
            top = float(np.nanquantile(np.abs(diff_vals), 0.95))

        # Avoid a zero/NaN limit (degenerate norm, divide-by-zero in sizes).
        top = self._bias_limit(top)

        # Honour user-supplied colour limits; default to a symmetric scale.
        vmin = -top if self.vmin is None else self.vmin
        vmax = top if self.vmax is None else self.vmax

        cmap, norm = get_discrete_scale(
            diff_vals, cmap=self.cmap, n_levels=self.ncolors, vmin=vmin, vmax=vmax
        )

        mappable = plt.cm.ScalarMappable(norm=norm, cmap=cmap)
        cbar = self.add_colorbar(mappable, format="%1.2g")
        cbar.ax.tick_params(labelsize=10)

        # Marker size grows with |bias| relative to the limit, capped at 300.
        ss = np.abs(diff_vals) / top * 100.0 * self.fact
        ss[ss > 300] = 300.0

        # Styling defaults that callers may override via **kwargs.
        style_kwargs = {
            "transform": ccrs.PlateCarree(),
            "edgecolors": "k",
            "linewidths": 0.25,
            "alpha": 0.7,
        }
        style_kwargs.update(scatter_kwargs)

        final_scatter_kwargs = get_plot_kwargs(
            cmap=cmap,
            norm=norm,
            s=ss,
            c=diff_vals,
            **style_kwargs,
        )

        self.ax.scatter(
            lon_vals,
            lat_vals,
            **final_scatter_kwargs,
        )
        return self.ax

    def hvplot(self, **kwargs: Any) -> Any:
        """Generate an interactive spatial bias scatter plot using hvPlot (Track B).

        Parameters
        ----------
        **kwargs : Any
            Keyword arguments passed to `hvplot.points`.
            Common options include `cmap`, `title`, and `alpha`.
            `rasterize=True` is used by default for high performance.

        Returns
        -------
        holoviews.core.layout.Layout
            The interactive hvPlot object.
        """
        try:
            import hvplot.pandas  # noqa: F401
            import hvplot.xarray  # noqa: F401
        except ImportError:
            raise ImportError(
                "hvplot is required for interactive plotting. Install it with 'pip install hvplot'."
            )

        import pandas as pd

        if isinstance(self.data, pd.DataFrame):
            lat_name = next(
                (c for c in ["latitude", "lat"] if c in self.data.columns), "lat"
            )
            lon_name = next(
                (c for c in ["longitude", "lon"] if c in self.data.columns), "lon"
            )

            ds_plot = self.data.copy()
            ds_plot["bias"] = ds_plot[self.col2] - ds_plot[self.col1]
            plot_target = ds_plot
        else:
            lat_name = next(
                (
                    c
                    for c in ["latitude", "lat"]
                    if c in self.data.coords or c in self.data.dims
                ),
                "lat",
            )
            lon_name = next(
                (
                    c
                    for c in ["longitude", "lon"]
                    if c in self.data.coords or c in self.data.dims
                ),
                "lon",
            )

            # Work on a copy so the bias column does not leak into self.data.
            ds_plot = self.data.copy()
            ds_plot["bias"] = ds_plot[self.col2] - ds_plot[self.col1]
            _update_history(ds_plot, "Calculated bias for hvplot")
            plot_target = ds_plot

        # Track B defaults; caller kwargs take precedence.
        plot_kwargs = {
            "x": lon_name,
            "y": lat_name,
            "c": "bias",
            "geo": True,
            "rasterize": True,
            "cmap": self.cmap,
        }

        plot_kwargs.update(kwargs)

        return plot_target.hvplot.points(**plot_kwargs)

__init__(data, col1, col2, vmin=None, vmax=None, ncolors=15, fact=1.5, cmap='RdBu_r', **kwargs)

Initialize the plot with data and map projection.

Parameters

data : Any Input data. Preferred format is xarray.Dataset or xarray.DataArray with 'latitude' and 'longitude' (or 'lat' and 'lon') coordinates. col1 : str Name of the first variable (e.g., observations). col2 : str Name of the second variable (e.g., model). Bias is calculated as col2 - col1. vmin : float, optional Minimum for colorscale, by default None. vmax : float, optional Maximum for colorscale, by default None. ncolors : int, optional Number of discrete colors, by default 15. fact : float, optional Scaling factor for point sizes, by default 1.5. cmap : str, optional Colormap for bias values, by default "RdBu_r". **kwargs : Any Additional keyword arguments for map creation, passed to monet_plots.plots.spatial.SpatialPlot.

Source code in src/monet_plots/plots/spatial_bias_scatter.py
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
def __init__(
    self,
    data: Any,
    col1: str,
    col2: str,
    vmin: float | None = None,
    vmax: float | None = None,
    ncolors: int = 15,
    fact: float = 1.5,
    cmap: str = "RdBu_r",
    **kwargs: Any,
) -> None:
    """Store the bias-plot configuration and build the map canvas.

    Parameters
    ----------
    data : Any
        Input data; an xarray Dataset/DataArray with 'latitude'/'longitude'
        (or 'lat'/'lon') coordinates is preferred. It is normalized via
        ``normalize_data``.
    col1 : str
        Name of the reference variable (e.g., observations).
    col2 : str
        Name of the compared variable (e.g., model); bias is col2 - col1.
    vmin : float, optional
        Lower colour-scale bound, by default None.
    vmax : float, optional
        Upper colour-scale bound, by default None.
    ncolors : int, optional
        Number of discrete colours, by default 15.
    fact : float, optional
        Multiplier applied to marker sizes, by default 1.5.
    cmap : str, optional
        Colormap name for bias values, by default "RdBu_r".
    **kwargs : Any
        Map-creation options forwarded to
        :class:`monet_plots.plots.spatial.SpatialPlot`.
    """
    # Map canvas / projection are created by the SpatialPlot base class.
    super().__init__(**kwargs)
    self.data = normalize_data(data)

    # Variables being compared.
    self.col1, self.col2 = col1, col2
    # Colour-scale limits.
    self.vmin, self.vmax = vmin, vmax
    # Discretization, marker scaling and colormap.
    self.ncolors, self.fact, self.cmap = ncolors, fact, cmap

    _update_history(self.data, "Initialized monet-plots.SpatialBiasScatterPlot")

hvplot(**kwargs)

Generate an interactive spatial bias scatter plot using hvPlot (Track B).

Parameters

**kwargs : Any Keyword arguments passed to hvplot.points. Common options include cmap, title, and alpha. rasterize=True is used by default for high performance.

Returns

holoviews.core.layout.Layout The interactive hvPlot object.

Source code in src/monet_plots/plots/spatial_bias_scatter.py
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
def hvplot(self, **kwargs: Any) -> Any:
    """Build an interactive bias scatter plot with ``hvplot.points`` (Track B).

    A ``bias`` column (col2 - col1) is added to a copy of the data and used
    as the colour channel. Coordinate names are detected as
    'latitude'/'lat' and 'longitude'/'lon', defaulting to 'lat'/'lon'.

    Parameters
    ----------
    **kwargs : Any
        Options forwarded to `hvplot.points` (e.g. `cmap`, `title`, `alpha`).
        They override the defaults, which include `rasterize=True` and
        `geo=True`.

    Returns
    -------
    holoviews.core.layout.Layout
        The interactive hvPlot object.

    Raises
    ------
    ImportError
        If ``hvplot`` (pandas and xarray accessors) is not installed.
    """
    try:
        import hvplot.pandas  # noqa: F401
        import hvplot.xarray  # noqa: F401
    except ImportError:
        raise ImportError(
            "hvplot is required for interactive plotting. Install it with 'pip install hvplot'."
        )

    import pandas as pd

    source = self.data
    tabular = isinstance(source, pd.DataFrame)

    if tabular:
        def _present(name):
            return name in source.columns
    else:
        def _present(name):
            return name in source.coords or name in source.dims

    lat_name = next((c for c in ["latitude", "lat"] if _present(c)), "lat")
    lon_name = next((c for c in ["longitude", "lon"] if _present(c)), "lon")

    # Add the bias on a copy so the stored data is left untouched.
    plot_target = source.copy()
    plot_target["bias"] = plot_target[self.col2] - plot_target[self.col1]
    if not tabular:
        _update_history(plot_target, "Calculated bias for hvplot")

    # Track B defaults, then caller overrides.
    plot_kwargs = {
        "x": lon_name,
        "y": lat_name,
        "c": "bias",
        "geo": True,
        "rasterize": True,
        "cmap": self.cmap,
    }
    plot_kwargs.update(kwargs)

    return plot_target.hvplot.points(**plot_kwargs)

plot(**kwargs)

Generate a static publication-quality spatial bias scatter plot (Track A).

Parameters

**kwargs : Any Keyword arguments passed to matplotlib.pyplot.scatter. Map features (e.g., coastlines=True) can also be passed here.

Returns

matplotlib.axes.Axes The matplotlib axes object containing the plot.

Source code in src/monet_plots/plots/spatial_bias_scatter.py
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
def plot(self, **kwargs: Any) -> matplotlib.axes.Axes:
    """Generate a static publication-quality spatial bias scatter plot (Track A).

    The bias is computed as ``col2 - col1``. Marker colour encodes the signed
    bias on a discrete scale spanning ``[-top, top]``, where ``top`` is the
    95th percentile of ``|bias|`` rounded to a whole number (falling back to
    the unrounded value if rounding gives 0, and to 1.0 if no finite positive
    value exists). Marker size grows with ``|bias|`` and is capped at 300.

    Parameters
    ----------
    **kwargs : Any
        Keyword arguments for the axes' scatter call
        (`matplotlib.axes.Axes.scatter`). Map features (e.g.,
        `coastlines=True`) can also be passed here; they are split off by
        ``self.add_features`` and whatever it returns is forwarded to scatter.
        Forwarded values take precedence over this method's scatter defaults
        (``transform``, ``edgecolors``, ``linewidths``, ``alpha``, ...).

    Returns
    -------
    matplotlib.axes.Axes
        The matplotlib axes object containing the plot.
    """
    # Separate feature kwargs from scatter kwargs
    scatter_kwargs = self.add_features(**kwargs)

    # Handle different data types for bias calculation
    if isinstance(self.data, (xr.Dataset, xr.DataArray)):
        # Vectorized calculation using Xarray/Dask
        diff = self.data[self.col2] - self.data[self.col1]

        # 95th percentile of |bias| over all dimensions.
        try:
            top_val = np.abs(diff).quantile(0.95)
            if hasattr(top_val, "compute"):
                top_val = top_val.compute()
            raw_top = float(top_val)
        except (ImportError, AttributeError, ValueError):
            # Fall back to NumPy on the materialized values when the lazy
            # quantile is not available for this backend/chunking.
            raw_top = float(np.nanquantile(np.abs(diff.values), 0.95))

        # ``data_vars`` exists only on Datasets; a DataArray input must not
        # raise AttributeError during the coordinate lookup.
        data_vars = getattr(self.data, "data_vars", {})

        def _find_coord(candidates, default):
            # First candidate present as a coord, data variable or dimension.
            return next(
                (
                    c
                    for c in candidates
                    if c in self.data.coords
                    or c in data_vars
                    or c in self.data.dims
                ),
                default,
            )

        lat_name = _find_coord(("latitude", "lat"), "lat")
        lon_name = _find_coord(("longitude", "lon"), "lon")

        # Compute only what's necessary for plotting
        plot_ds = xr.Dataset(
            {"diff": diff, "lat": self.data[lat_name], "lon": self.data[lon_name]}
        )

        # Drop NaNs before compute to minimize transfer
        if plot_ds.dims:
            plot_ds = plot_ds.dropna(dim=list(plot_ds.dims)[0])

        concrete = plot_ds.compute()
        diff_vals = concrete["diff"].values
        lat_vals = concrete["lat"].values
        lon_vals = concrete["lon"].values
    else:
        # Fallback for Pandas
        df = self.data.dropna(subset=[self.col1, self.col2])
        diff_vals = (df[self.col2] - df[self.col1]).values
        lat_name = next((c for c in ["latitude", "lat"] if c in df.columns), "lat")
        lon_name = next((c for c in ["longitude", "lon"] if c in df.columns), "lon")
        lat_vals = df[lat_name].values
        lon_vals = df[lon_name].values
        raw_top = float(np.nanquantile(np.abs(diff_vals), 0.95))

    # Symmetric colour limit, rounded to a whole number. Rounding must not
    # collapse it to 0 (all |bias| < 0.5), and an empty/all-NaN input must not
    # leave it NaN: either would give vmin == vmax and 0/0 marker sizes.
    top = float(np.around(raw_top))
    if not (np.isfinite(top) and top > 0):
        top = raw_top if (np.isfinite(raw_top) and raw_top > 0) else 1.0

    # Use scaling tools
    cmap, norm = get_discrete_scale(
        diff_vals, cmap=self.cmap, n_levels=self.ncolors, vmin=-top, vmax=top
    )

    # Create colorbar
    mappable = plt.cm.ScalarMappable(norm=norm, cmap=cmap)
    cbar = self.add_colorbar(mappable, format="%1.2g")
    cbar.ax.tick_params(labelsize=10)

    # Marker size proportional to |bias| relative to the colour limit, capped.
    ss = np.minimum(np.abs(diff_vals) / top * 100.0 * self.fact, 300.0)

    # Caller-supplied kwargs override these defaults instead of colliding
    # with them (passing e.g. ``alpha`` twice would raise TypeError).
    scatter_defaults = {
        "cmap": cmap,
        "norm": norm,
        "s": ss,
        "c": diff_vals,
        "transform": ccrs.PlateCarree(),
        "edgecolors": "k",
        "linewidths": 0.25,
        "alpha": 0.7,
    }
    final_scatter_kwargs = get_plot_kwargs(**{**scatter_defaults, **scatter_kwargs})

    self.ax.scatter(
        lon_vals,
        lat_vals,
        **final_scatter_kwargs,
    )
    return self.ax