Skip to content

Facet Grids

FacetGridPlot

Bases: BasePlot

Creates a facet grid plot.

This class creates a facet grid plot using seaborn's FacetGrid for tabular data, or xarray's FacetGrid for xarray input.

Source code in src/monet_plots/plots/facet_grid.py
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
class FacetGridPlot(BasePlot):
    """Creates a facet grid plot.

    Tabular input is converted to a DataFrame and laid out with seaborn's
    ``FacetGrid``. ``xr.DataArray``/``xr.Dataset`` input is kept as-is and,
    when ``row`` or ``col`` is given, laid out with xarray's ``FacetGrid``.
    """

    def __init__(
        self,
        data: Any,
        row: str | None = None,
        col: str | None = None,
        hue: str | None = None,
        col_wrap: int | None = None,
        size: float | None = None,
        aspect: float = 1,
        subplot_kws: dict[str, Any] | None = None,
        style: str | None = "wiley",
        **kwargs: Any,
    ) -> None:
        """Initializes the facet grid.

        Parameters
        ----------
        data : Any
            The data to plot. ``xr.DataArray``/``xr.Dataset`` objects are kept
            as-is; anything else is converted with ``to_dataframe`` and its
            index is reset.
        row : str, optional
            Variable to map to row facets, by default None.
        col : str, optional
            Variable to map to column facets, by default None.
        hue : str, optional
            Variable to map to color mapping, by default None. Only used by
            the seaborn (non-xarray) grid.
        col_wrap : int, optional
            Number of columns before wrapping, by default None.
        size : float, optional
            Height (in inches) of each facet, by default 3.
            Aligns with Xarray convention.
        aspect : float, optional
            Aspect ratio of each facet, by default 1.
        subplot_kws : dict, optional
            Keyword arguments for subplots (e.g. projection).
        style : str, optional
            Style name to apply, by default "wiley".
        **kwargs : Any
            Additional keyword arguments to pass to seaborn's `FacetGrid`
            (not used for xarray input). `height` is supported as an alias
            for `size` (Seaborn convention); `size` wins if both are given.
        """
        # Apply style
        if style:
            set_style(style)

        # Handle size/height alignment (Xarray vs Seaborn). ``height`` is
        # always removed from kwargs: leaving it in while also passing
        # ``height=self.size`` to seaborn would raise a duplicate-keyword
        # TypeError whenever a caller supplied both ``size`` and ``height``.
        height = kwargs.pop("height", None)
        self.size = size or (height if height is not None else 3)
        self.aspect = aspect

        # Store facet parameters
        self.row = row
        self.col = col
        self.hue = hue
        self.col_wrap = col_wrap

        # Aero Protocol: Preserve lazy Xarray objects
        self.raw_data = data
        self.is_xarray = isinstance(data, (xr.DataArray, xr.Dataset))

        if self.is_xarray:
            self.data = data  # Keep as Xarray
            # For Xarray, we can use xarray.plot.FacetGrid
            # We delay creation to SpatialFacetGridPlot if possible,
            # or initialize it here if we have enough info.
            self.grid = None
            if col or row:
                try:
                    from xarray.plot.facetgrid import FacetGrid as xrFacetGrid

                    self.grid = xrFacetGrid(
                        data,
                        col=col,
                        row=row,
                        col_wrap=col_wrap,
                        # Honour the documented per-facet size/aspect.
                        size=self.size,
                        aspect=self.aspect,
                        subplot_kws=subplot_kws,
                    )
                    # Initialize default titles for Xarray
                    self.grid.set_titles()
                except (ImportError, TypeError, AttributeError):
                    # Best effort: without a grid, a subclass may build one later.
                    pass
        else:
            self.data = to_dataframe(data).reset_index()
            # Create the Seaborn FacetGrid
            self.grid = sns.FacetGrid(
                self.data,
                row=self.row,
                col=self.col,
                hue=self.hue,
                col_wrap=self.col_wrap,
                height=self.size,
                aspect=self.aspect,
                subplot_kws=subplot_kws,
                **kwargs,
            )

        # Unified BasePlot initialization. ``axs`` is tried first and
        # ``axes`` only as a fallback. The fallback is evaluated lazily, so
        # xarray grids (which expose ``axs``) never touch their ``axes`` alias.
        axes = getattr(self.grid, "axs", None)
        if axes is None:
            axes = getattr(self.grid, "axes", None)
        if axes is not None:
            super().__init__(fig=self.grid.fig, ax=axes.flatten()[0], style=style)
        else:
            super().__init__(style=style)

        # For compatibility with tests, also store as 'g'
        self.g = self.grid

    def map_dataframe(self, plot_func: Callable, *args: Any, **kwargs: Any) -> None:
        """Maps a plotting function to the facet grid.

        Delegates to ``self.grid.map_dataframe``. The grid must provide that
        method (seaborn grids do); otherwise the attribute lookup fails.

        Args:
            plot_func (function): The plotting function to map.
            *args: Positional arguments to pass to the plotting function.
            **kwargs: Keyword arguments to pass to the plotting function.
        """
        self.grid.map_dataframe(plot_func, *args, **kwargs)

    def set_titles(self, *args, **kwargs):
        """Sets the titles of the facet grid.

        Args:
            *args: Positional arguments to pass to `set_titles`.
            **kwargs: Keyword arguments to pass to `set_titles`.
        """
        self.grid.set_titles(*args, **kwargs)

    def save(self, filename, **kwargs):
        """Saves the plot to a file.

        Args:
            filename (str): The name of the file to save the plot to.
            **kwargs: Additional keyword arguments to pass to `savefig`.
        """
        self.fig.savefig(filename, **kwargs)

    def plot(self, plot_func=None, *args, **kwargs):
        """Plots the data using the FacetGrid.

        Does nothing when ``plot_func`` is None; otherwise forwards to
        ``self.grid.map``.

        Args:
            plot_func (function, optional): The plotting function to use.
            *args: Positional arguments to pass to the plotting function.
            **kwargs: Keyword arguments to pass to the plotting function.
        """
        if plot_func is not None:
            self.grid.map(plot_func, *args, **kwargs)

    def close(self):
        """Closes the plot."""
        plt.close(self.fig)

__init__(data, row=None, col=None, hue=None, col_wrap=None, size=None, aspect=1, subplot_kws=None, style='wiley', **kwargs)

Initializes the facet grid.

Parameters

- data : Any — The data to plot.
- row : str, optional — Variable to map to row facets, by default None.
- col : str, optional — Variable to map to column facets, by default None.
- hue : str, optional — Variable to map to color mapping, by default None.
- col_wrap : int, optional — Number of columns before wrapping, by default None.
- size : float, optional — Height (in inches) of each facet, by default 3. Aligns with the Xarray convention.
- aspect : float, optional — Aspect ratio of each facet, by default 1.
- subplot_kws : dict, optional — Keyword arguments for subplots (e.g. projection).
- style : str, optional — Style name to apply, by default "wiley".
- **kwargs : Any — Additional keyword arguments to pass to FacetGrid. `height` is supported as an alias for `size` (Seaborn convention).

Source code in src/monet_plots/plots/facet_grid.py
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
def __init__(
    self,
    data: Any,
    row: str | None = None,
    col: str | None = None,
    hue: str | None = None,
    col_wrap: int | None = None,
    size: float | None = None,
    aspect: float = 1,
    subplot_kws: dict[str, Any] | None = None,
    style: str | None = "wiley",
    **kwargs: Any,
) -> None:
    """Initializes the facet grid.

    Parameters
    ----------
    data : Any
        The data to plot. ``xr.DataArray``/``xr.Dataset`` objects are kept
        as-is; anything else is converted with ``to_dataframe`` and its
        index is reset.
    row : str, optional
        Variable to map to row facets, by default None.
    col : str, optional
        Variable to map to column facets, by default None.
    hue : str, optional
        Variable to map to color mapping, by default None. Only used by
        the seaborn (non-xarray) grid.
    col_wrap : int, optional
        Number of columns before wrapping, by default None.
    size : float, optional
        Height (in inches) of each facet, by default 3.
        Aligns with Xarray convention.
    aspect : float, optional
        Aspect ratio of each facet, by default 1.
    subplot_kws : dict, optional
        Keyword arguments for subplots (e.g. projection).
    style : str, optional
        Style name to apply, by default "wiley".
    **kwargs : Any
        Additional keyword arguments to pass to seaborn's `FacetGrid`
        (not used for xarray input). `height` is supported as an alias
        for `size` (Seaborn convention); `size` wins if both are given.
    """
    # Apply style
    if style:
        set_style(style)

    # Handle size/height alignment (Xarray vs Seaborn). ``height`` is always
    # removed from kwargs: leaving it in while also passing
    # ``height=self.size`` to seaborn would raise a duplicate-keyword
    # TypeError whenever a caller supplied both ``size`` and ``height``.
    height = kwargs.pop("height", None)
    self.size = size or (height if height is not None else 3)
    self.aspect = aspect

    # Store facet parameters
    self.row = row
    self.col = col
    self.hue = hue
    self.col_wrap = col_wrap

    # Aero Protocol: Preserve lazy Xarray objects
    self.raw_data = data
    self.is_xarray = isinstance(data, (xr.DataArray, xr.Dataset))

    if self.is_xarray:
        self.data = data  # Keep as Xarray
        # For Xarray, we can use xarray.plot.FacetGrid
        # We delay creation to SpatialFacetGridPlot if possible,
        # or initialize it here if we have enough info.
        self.grid = None
        if col or row:
            try:
                from xarray.plot.facetgrid import FacetGrid as xrFacetGrid

                self.grid = xrFacetGrid(
                    data,
                    col=col,
                    row=row,
                    col_wrap=col_wrap,
                    # Honour the documented per-facet size/aspect.
                    size=self.size,
                    aspect=self.aspect,
                    subplot_kws=subplot_kws,
                )
                # Initialize default titles for Xarray
                self.grid.set_titles()
            except (ImportError, TypeError, AttributeError):
                # Best effort: without a grid, a subclass may build one later.
                pass
    else:
        self.data = to_dataframe(data).reset_index()
        # Create the Seaborn FacetGrid
        self.grid = sns.FacetGrid(
            self.data,
            row=self.row,
            col=self.col,
            hue=self.hue,
            col_wrap=self.col_wrap,
            height=self.size,
            aspect=self.aspect,
            subplot_kws=subplot_kws,
            **kwargs,
        )

    # Unified BasePlot initialization. ``axs`` is tried first and ``axes``
    # only as a fallback. The fallback is evaluated lazily, so xarray grids
    # (which expose ``axs``) never touch their ``axes`` alias.
    axes = getattr(self.grid, "axs", None)
    if axes is None:
        axes = getattr(self.grid, "axes", None)
    if axes is not None:
        super().__init__(fig=self.grid.fig, ax=axes.flatten()[0], style=style)
    else:
        super().__init__(style=style)

    # For compatibility with tests, also store as 'g'
    self.g = self.grid

close()

Closes the plot.

Source code in src/monet_plots/plots/facet_grid.py
169
170
171
def close(self):
    """Close the figure that backs this plot."""
    figure = self.fig
    plt.close(figure)

map_dataframe(plot_func, *args, **kwargs)

Maps a plotting function to the facet grid.

Parameters:

Name Type Description Default
plot_func function

The plotting function to map.

required
*args Any

Positional arguments to pass to the plotting function.

()
**kwargs Any

Keyword arguments to pass to the plotting function.

{}
Source code in src/monet_plots/plots/facet_grid.py
130
131
132
133
134
135
136
137
138
def map_dataframe(self, plot_func: Callable, *args: Any, **kwargs: Any) -> None:
    """Apply ``plot_func`` to every facet via the grid's ``map_dataframe``.

    Args:
        plot_func (function): The plotting function to map.
        *args: Positional arguments forwarded to the plotting function.
        **kwargs: Keyword arguments forwarded to the plotting function.
    """
    grid = self.grid
    grid.map_dataframe(plot_func, *args, **kwargs)

plot(plot_func=None, *args, **kwargs)

Plots the data using the FacetGrid.

Parameters:

Name Type Description Default
plot_func function

The plotting function to use.

None
*args

Positional arguments to pass to the plotting function.

()
**kwargs

Keyword arguments to pass to the plotting function.

{}
Source code in src/monet_plots/plots/facet_grid.py
158
159
160
161
162
163
164
165
166
167
def plot(self, plot_func=None, *args, **kwargs):
    """Map ``plot_func`` over the FacetGrid; a no-op when it is None.

    Args:
        plot_func (function, optional): The plotting function to use.
        *args: Positional arguments forwarded to the plotting function.
        **kwargs: Keyword arguments forwarded to the plotting function.
    """
    if plot_func is None:
        return
    self.grid.map(plot_func, *args, **kwargs)

save(filename, **kwargs)

Saves the plot to a file.

Parameters:

Name Type Description Default
filename str

The name of the file to save the plot to.

required
**kwargs

Additional keyword arguments to pass to savefig.

{}
Source code in src/monet_plots/plots/facet_grid.py
149
150
151
152
153
154
155
156
def save(self, filename, **kwargs):
    """Write the figure to ``filename`` using ``savefig``.

    Args:
        filename (str): Destination file name.
        **kwargs: Extra keyword arguments forwarded to `savefig`.
    """
    figure = self.fig
    figure.savefig(filename, **kwargs)

set_titles(*args, **kwargs)

Sets the titles of the facet grid.

Parameters:

Name Type Description Default
*args

Positional arguments to pass to set_titles.

()
**kwargs

Keyword arguments to pass to set_titles.

{}
Source code in src/monet_plots/plots/facet_grid.py
140
141
142
143
144
145
146
147
def set_titles(self, *args, **kwargs):
    """Forward title configuration to the underlying grid's `set_titles`.

    Args:
        *args: Positional arguments forwarded to `set_titles`.
        **kwargs: Keyword arguments forwarded to `set_titles`.
    """
    grid = self.grid
    grid.set_titles(*args, **kwargs)

SpatialFacetGridPlot

Bases: FacetGridPlot

Geospatial version of FacetGridPlot.

Source code in src/monet_plots/plots/facet_grid.py
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
class SpatialFacetGridPlot(FacetGridPlot):
    """Geospatial version of FacetGridPlot.

    Every facet is created with a Cartopy projection (via ``subplot_kws``).
    Helpers add map features, draw monet-plots spatial plotters on each
    facet, and format facet titles.
    """

    def __init__(
        self,
        data: Any,
        *,
        row: str | None = None,
        col: str | None = None,
        col_wrap: int | None = None,
        projection: ccrs.Projection | None = None,
        size: float | None = None,
        aspect: float = 1.2,
        style: str | None = "wiley",
        **kwargs: Any,
    ) -> None:
        """Initialize Spatial Facet Grid.

        Parameters
        ----------
        data : Any
            Geospatial data to plot. Preferred format is xr.DataArray or xr.Dataset.
        row : str, optional
            Dimension/variable to map to rows.
        col : str, optional
            Dimension/variable to map to columns.
        col_wrap : int, optional
            Wrap columns at this number.
        projection : ccrs.Projection, optional
            Cartopy projection for the maps. Defaults to PlateCarree.
        size : float, optional
            Height (in inches) of each facet, by default 4.
            Aligns with Xarray convention.
        aspect : float, optional
            Aspect ratio of each facet, by default 1.2.
        style : str, optional
            Style name to apply, by default "wiley".
        **kwargs : Any
            Additional arguments for FacetGrid. `height` is accepted as an
            alias for `size`; `size` wins if both are given.
        """
        self.original_data = data
        import cartopy.crs as ccrs

        self.projection = projection or ccrs.PlateCarree()

        # Handle xr.Dataset by converting to DataArray if faceting by variable
        # or if there is only one data variable.
        self.is_dataset = isinstance(data, xr.Dataset)
        if self.is_dataset:
            if row == "variable" or col == "variable":
                data = data.to_array(dim="variable", name="value")
            elif len(data.data_vars) == 1:
                # Auto-select the only variable to ensure map_dataarray works
                data = data[list(data.data_vars)[0]]

        # Aligns with Xarray's default size for maps. ``height`` is always
        # consumed here so it is never forwarded to the parent alongside
        # ``size`` (which would duplicate seaborn's ``height`` keyword).
        height = kwargs.pop("height", None)
        size = size or (height if height is not None else 4)

        # Call FacetGridPlot init which handles the two-track branching
        super().__init__(
            data,
            row=row,
            col=col,
            col_wrap=col_wrap,
            size=size,
            aspect=aspect,
            subplot_kws={"projection": self.projection},
            style=style,
            **kwargs,
        )

        # Format titles if the parent already built a grid (seaborn track,
        # or xarray track when row/col were given).
        if self.grid is not None:
            self._set_default_titles()

    def _grid_axes(self) -> Any:
        """Return the grid's array of axes, or ``None`` when there is no grid.

        ``axs`` is tried first and ``axes`` only as a lazily evaluated
        fallback. Newer xarray grids expose ``axs`` (their ``axes`` is a
        deprecated alias); seaborn grids expose ``axes``.
        """
        grid = getattr(self, "grid", None)
        if grid is None:
            return None
        axes = getattr(grid, "axs", None)
        if axes is None:
            axes = getattr(grid, "axes", None)
        return axes

    @staticmethod
    def _format_facet_value(val: str) -> str:
        """Render a datetime-like facet value as ``YYYY-mm-dd HH:MM``.

        Purely numeric labels (e.g. ``"2020"`` or ``"1013.25"``) are returned
        unchanged, because ``pd.to_datetime`` would otherwise turn a year or
        pressure level into a timestamp. Values that fail to parse are also
        returned unchanged.
        """
        try:
            float(val)
        except ValueError:
            pass
        else:
            return val
        try:
            return pd.to_datetime(val).strftime("%Y-%m-%d %H:%M")
        except (ValueError, TypeError, OverflowError):
            return val

    def _set_default_titles(self) -> None:
        """Format facet titles with metadata and date-time.

        Each title part of the form ``"dim = value"`` (parts joined by
        ``" | "``) is rewritten as follows:

        * datetime-like values are reformatted (see ``_format_facet_value``);
        * for ``dim == "variable"`` on Dataset input, the value becomes the
          variable's ``long_name`` plus ``units`` and the dim name is dropped;
        * for dims that are coordinates of the original xarray data, the dim
          name becomes the coordinate's ``long_name`` attribute.

        Non-xarray input (e.g. a DataFrame) has no ``coords``, so that
        lookup is skipped rather than raising ``AttributeError``.
        """
        axes = self._grid_axes()
        if axes is None:
            return

        if isinstance(self.original_data, (xr.DataArray, xr.Dataset)):
            coords = self.original_data.coords
        else:
            coords = None

        for ax in axes.flatten():
            if ax is None:
                continue

            # Handle titles that might have multiple facets (e.g. "row = val | col = val")
            new_parts = []
            for part in ax.get_title().split(" | "):
                # partition on the first " = " so values containing " = " survive
                dim, sep, val = part.partition(" = ")
                if not sep:
                    new_parts.append(part)
                    continue

                dim = dim.strip()
                val = self._format_facet_value(val.strip())

                # Handle long_name for dimensions/variables
                if dim == "variable" and self.is_dataset:
                    try:
                        # self.original_data is the original Dataset
                        var_obj = self.original_data[val]
                        long_name = var_obj.attrs.get("long_name", val)
                        units = var_obj.attrs.get("units", "")
                        val = f"{long_name} ({units})" if units else long_name
                        dim = ""
                    except (KeyError, AttributeError):
                        pass
                elif coords is not None and dim in coords:
                    try:
                        dim = coords[dim].attrs.get("long_name", dim)
                    except (KeyError, AttributeError):
                        pass

                new_parts.append(f"{dim} {val}".strip())

            ax.set_title(" | ".join(new_parts))

    def add_map_features(self, **kwargs: Any) -> None:
        """Add cartopy features to all facets.

        Parameters
        ----------
        **kwargs : Any
            Keyword arguments passed to SpatialPlot (one instance is built per
            facet axis with ``ax`` and this grid's projection).
            Default is coastlines=True.
        """
        from .spatial import SpatialPlot

        kwargs.setdefault("coastlines", True)

        axes = self._grid_axes()
        if axes is None:
            return

        for ax in axes.flatten():
            if ax is None:
                continue
            # Use SpatialPlot's feature logic on each axis
            SpatialPlot(ax=ax, projection=self.projection, **kwargs)

    def map_monet(
        self,
        plotter_class: type,
        *,
        x: str | None = None,
        y: str | None = None,
        var_name: str | None = None,
        **kwargs: Any,
    ) -> None:
        """Map a monet-plots spatial plotter to the grid.

        Parameters
        ----------
        plotter_class : type
            A class from monet_plots.plots (e.g., SpatialImshowPlot). For
            xarray input only its name matters: names containing "Contour"
            are drawn with xarray's ``contourf``, anything else with
            ``imshow``.
        x : str, optional
            Column name for longitude. For xarray input it is inferred with
            ``SpatialPlot._identify_coords`` when omitted; otherwise it
            defaults to "lon".
        y : str, optional
            Column name for latitude (inferred, or "lat", as for ``x``).
        var_name : str, optional
            The variable name to plot.
        **kwargs : Any
            Arguments passed to the plotter and map features. Map-feature
            keys (``coastlines``, ``states``, ...) are routed to
            ``add_map_features``; ``add_colorbar`` requests a colorbar.
        """
        # Separate feature kwargs
        feature_keys = [
            "coastlines",
            "states",
            "countries",
            "ocean",
            "land",
            "lakes",
            "rivers",
            "borders",
            "gridlines",
            "extent",
            "resolution",
        ]
        feature_kwargs = {k: kwargs.pop(k) for k in feature_keys if k in kwargs}
        add_shared_cb = kwargs.pop("add_colorbar", False)

        if self.is_xarray:
            # Track A: Xarray-native faceting (Lazy by Default)
            from .spatial import SpatialPlot

            # Identify coordinates if not provided
            if x is None or y is None:
                sp = SpatialPlot(style=None)
                x_id, y_id = sp._identify_coords(self.data)
                x = x or x_id
                y = y or y_id

            # Use Xarray's native plotting which handles faceting
            plot_type = "imshow"
            if "Contour" in plotter_class.__name__:
                plot_type = "contourf"

            # Prepare plotting arguments
            plot_kwargs = kwargs.copy()
            plot_kwargs.setdefault("transform", ccrs.PlateCarree())
            if add_shared_cb:
                plot_kwargs.setdefault("add_colorbar", True)
                plot_kwargs.setdefault("cbar_kwargs", {"orientation": "horizontal"})

            # Select data variable if it's a Dataset
            plot_data = self.data
            if isinstance(plot_data, xr.Dataset):
                if var_name:
                    plot_data = plot_data[var_name]
                else:
                    # Pick the first data variable that is not a facet dimension
                    for v in plot_data.data_vars:
                        if v not in [self.col, self.row]:
                            plot_data = plot_data[v]
                            break

            if self.grid is not None:
                import xarray.plot as xplt

                func = xplt.imshow if plot_type == "imshow" else xplt.contourf

                # xarray's FacetGrid.map passes ``func`` NumPy arrays of the
                # named variables, not a per-facet Dataset, so a Dataset-backed
                # grid cannot be mapped that way. Point the grid at the
                # DataArray selected above; map_dataarray then subsets it per
                # facet using the grid's existing facet selectors.
                grid_data = getattr(self.grid, "data", None)
                if isinstance(grid_data, xr.Dataset) and isinstance(
                    plot_data, xr.DataArray
                ):
                    self.grid.data = plot_data

                if hasattr(self.grid, "map_dataarray"):
                    self.grid.map_dataarray(func, x, y, **plot_kwargs)
                else:
                    self.grid.map(func, x, y, **plot_kwargs)
            else:
                # NOTE(review): without row/col, xarray plotting is expected to
                # return a single mappable rather than a FacetGrid, which the
                # attribute updates below do not handle — confirm.
                xr_plot_func = getattr(plot_data.plot, plot_type)
                self.grid = xr_plot_func(
                    x=x,
                    y=y,
                    col=self.col,
                    row=self.row,
                    col_wrap=self.col_wrap,
                    subplot_kws={"projection": self.projection},
                    **plot_kwargs,
                )

            # Update BasePlot attributes
            self.fig = self.grid.fig
            axes = self._grid_axes()
            self.ax = axes.flatten()[0]
            self.g = self.grid

            # Add features to all facets
            self.add_map_features(**feature_kwargs)
            self._set_default_titles()

        else:
            # Track B: Seaborn-based faceting (Eager/Pandas)
            x = x or "lon"
            y = y or "lat"
            if var_name is None:
                if "variable" in self.data.columns:
                    var_name = "value"
                elif isinstance(self.raw_data, xr.DataArray):
                    var_name = self.raw_data.name
                elif isinstance(self.raw_data, xr.Dataset):
                    var_name = list(self.raw_data.data_vars)[0]

            def _mapped_plot(*args, **kwargs_inner):
                # seaborn's map_dataframe supplies each facet's rows as ``data``
                data_df = kwargs_inner.pop("data")
                ax = plt.gca()
                temp_da = data_df.set_index([y, x]).to_xarray()[var_name]
                plotter = plotter_class(temp_da, ax=ax, **kwargs_inner)
                plotter.plot()

            self.map_dataframe(_mapped_plot, **kwargs)
            self.add_map_features(**feature_kwargs)
            if add_shared_cb:
                self._add_shared_colorbar(**kwargs)

    def _add_shared_colorbar(self, **kwargs: Any) -> None:
        """Add a shared colorbar to the figure.

        The colorbar is attached to the last non-None facet axis. Its
        mappable is the first collection or image found when scanning the
        facets from last to first. ``kwargs["label"]`` (if any) labels it.
        """
        axes = self._grid_axes()
        if axes is None:
            return

        mappable = None
        target_ax = None
        for ax in reversed(axes.flatten()):
            if ax is None:
                continue
            if target_ax is None:
                target_ax = ax
            if mappable is None and ax.collections:
                mappable = ax.collections[0]
            if mappable is None and ax.images:
                mappable = ax.images[0]

        if mappable is not None and target_ax is not None:
            # Add colorbar to the last valid axis
            self.add_colorbar(
                mappable,
                ax=target_ax,
                label=kwargs.get("label", ""),
            )

__init__(data, *, row=None, col=None, col_wrap=None, projection=None, size=None, aspect=1.2, style='wiley', **kwargs)

Initialize Spatial Facet Grid.

Parameters

- data : Any — Geospatial data to plot. Preferred format is xr.DataArray or xr.Dataset.
- row : str, optional — Dimension/variable to map to rows.
- col : str, optional — Dimension/variable to map to columns.
- col_wrap : int, optional — Wrap columns at this number.
- projection : ccrs.Projection, optional — Cartopy projection for the maps. Defaults to PlateCarree.
- size : float, optional — Height (in inches) of each facet, by default 4. Aligns with the Xarray convention.
- aspect : float, optional — Aspect ratio of each facet, by default 1.2.
- style : str, optional — Style name to apply, by default "wiley".
- **kwargs : Any — Additional arguments for FacetGrid.

Source code in src/monet_plots/plots/facet_grid.py
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
def __init__(
    self,
    data: Any,
    *,
    row: str | None = None,
    col: str | None = None,
    col_wrap: int | None = None,
    projection: ccrs.Projection | None = None,
    size: float | None = None,
    aspect: float = 1.2,
    style: str | None = "wiley",
    **kwargs: Any,
) -> None:
    """Initialize Spatial Facet Grid.

    Parameters
    ----------
    data : Any
        Geospatial data to plot. Preferred format is xr.DataArray or xr.Dataset.
    row : str, optional
        Dimension/variable to map to rows.
    col : str, optional
        Dimension/variable to map to columns.
    col_wrap : int, optional
        Wrap columns at this number.
    projection : ccrs.Projection, optional
        Cartopy projection for the maps. Defaults to PlateCarree.
    size : float, optional
        Height (in inches) of each facet, by default 4.
        Aligns with Xarray convention.
    aspect : float, optional
        Aspect ratio of each facet, by default 1.2.
    style : str, optional
        Style name to apply, by default "wiley".
    **kwargs : Any
        Additional arguments for FacetGrid. `height` is accepted as an
        alias for `size`; `size` wins if both are given.
    """
    self.original_data = data
    import cartopy.crs as ccrs

    self.projection = projection or ccrs.PlateCarree()

    # Handle xr.Dataset by converting to DataArray if faceting by variable
    # or if there is only one data variable.
    self.is_dataset = isinstance(data, xr.Dataset)
    if self.is_dataset:
        if row == "variable" or col == "variable":
            data = data.to_array(dim="variable", name="value")
        elif len(data.data_vars) == 1:
            # Auto-select the only variable to ensure map_dataarray works
            data = data[list(data.data_vars)[0]]

    # Aligns with Xarray's default size for maps. ``height`` is always
    # consumed here so it is never forwarded to the parent alongside
    # ``size`` (which would duplicate seaborn's ``height`` keyword).
    height = kwargs.pop("height", None)
    size = size or (height if height is not None else 4)

    # Call FacetGridPlot init which handles the two-track branching
    super().__init__(
        data,
        row=row,
        col=col,
        col_wrap=col_wrap,
        size=size,
        aspect=aspect,
        subplot_kws={"projection": self.projection},
        style=style,
        **kwargs,
    )

    # Format titles if the parent already built a grid (seaborn track,
    # or xarray track when row/col were given).
    if self.grid is not None:
        self._set_default_titles()

add_map_features(**kwargs)

Add cartopy features to all facets.

Parameters

- **kwargs : Any — Keyword arguments passed to SpatialPlot.add_features. Defaults to coastlines=True.

Source code in src/monet_plots/plots/facet_grid.py
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
def add_map_features(self, **kwargs: Any) -> None:
    """Add cartopy features to all facets.

    A ``SpatialPlot`` is constructed on every non-empty facet axis of
    ``self.grid`` using ``self.projection`` and the given keyword
    arguments. Nothing happens when the grid exposes no axes (for example
    when ``self.grid`` is None).

    Parameters
    ----------
    **kwargs : Any
        Keyword arguments passed to the ``SpatialPlot`` constructor for each
        facet (presumably forwarded to ``SpatialPlot.add_features`` — confirm
        in ``spatial.py``). ``coastlines`` defaults to True.
    """
    kwargs.setdefault("coastlines", True)

    # Look up ``axs`` first and only fall back to ``axes`` when it is
    # missing. A nested ``getattr`` default would be evaluated eagerly and
    # always touch the ``axes`` alias, which recent xarray FacetGrids
    # deprecate in favour of ``axs``.
    grid_axes = getattr(self.grid, "axs", None)
    if grid_axes is None:
        grid_axes = getattr(self.grid, "axes", None)
    if grid_axes is None:
        return

    facets = [ax for ax in grid_axes.flatten() if ax is not None]
    if not facets:
        return

    from .spatial import SpatialPlot

    for ax in facets:
        # Reuse SpatialPlot's feature logic on each facet axis.
        SpatialPlot(ax=ax, projection=self.projection, **kwargs)

map_monet(plotter_class, *, x=None, y=None, var_name=None, **kwargs)

Map a monet-plots spatial plotter to the grid.

Parameters

plotter_class : type — A class from monet_plots.plots (e.g., SpatialImshowPlot).
x : str, optional — Column name for longitude.
y : str, optional — Column name for latitude.
var_name : str, optional — The variable name to plot.
**kwargs : Any — Arguments passed to the plotter and map features.

Source code in src/monet_plots/plots/facet_grid.py
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
def map_monet(
    self,
    plotter_class: type,
    *,
    x: str | None = None,
    y: str | None = None,
    var_name: str | None = None,
    **kwargs: Any,
) -> None:
    """Map a monet-plots spatial plotter to the grid.

    Parameters
    ----------
    plotter_class : type
        A class from monet_plots.plots (e.g., SpatialImshowPlot).
    x : str, optional
        Column name for longitude.
    y : str, optional
        Column name for latitude.
    var_name : str, optional
        The variable name to plot.
    **kwargs : Any
        Arguments passed to the plotter and map features.
    """
    # Separate feature kwargs
    feature_keys = [
        "coastlines",
        "states",
        "countries",
        "ocean",
        "land",
        "lakes",
        "rivers",
        "borders",
        "gridlines",
        "extent",
        "resolution",
    ]
    feature_kwargs = {k: kwargs.pop(k) for k in feature_keys if k in kwargs}
    add_shared_cb = kwargs.pop("add_colorbar", False)

    if self.is_xarray:
        # Track A: Xarray-native faceting (Lazy by Default)
        from .spatial import SpatialPlot

        # Identify coordinates if not provided
        if x is None or y is None:
            sp = SpatialPlot(style=None)
            x_id, y_id = sp._identify_coords(self.data)
            x = x or x_id
            y = y or y_id

        # Use Xarray's native plotting which handles faceting
        plot_type = "imshow"
        if "Contour" in plotter_class.__name__:
            plot_type = "contourf"

        # Prepare plotting arguments
        plot_kwargs = kwargs.copy()
        plot_kwargs.setdefault("transform", ccrs.PlateCarree())
        if add_shared_cb:
            plot_kwargs.setdefault("add_colorbar", True)
            plot_kwargs.setdefault("cbar_kwargs", {"orientation": "horizontal"})

        # Select data variable if it's a Dataset
        plot_data = self.data
        if isinstance(plot_data, xr.Dataset):
            if var_name:
                plot_data = plot_data[var_name]
            else:
                # Pick the first data variable that is not a facet dimension
                for v in plot_data.data_vars:
                    if v not in [self.col, self.row]:
                        plot_data = plot_data[v]
                        break

        # Trigger Xarray's facet grid
        # If we already have a grid, we can use it
        if self.grid is not None:
            import xarray.plot as xplt

            func = xplt.imshow if plot_type == "imshow" else xplt.contourf

            # Check if the grid's data is a Dataset.
            # If so, map_dataarray might fail in some xarray versions.
            # We use map with a selection wrapper instead.
            grid_data = getattr(self.grid, "data", None)
            if isinstance(grid_data, xr.Dataset):
                # var_name was selected above from self.data or passed in
                def _mapped_xr_plot(ds_subset, x_coord, y_coord, **inner_kwargs):
                    func(ds_subset[var_name], x=x_coord, y=y_coord, **inner_kwargs)

                self.grid.map(_mapped_xr_plot, x, y, **plot_kwargs)
            elif hasattr(self.grid, "map_dataarray"):
                self.grid.map_dataarray(func, x, y, **plot_kwargs)
            else:
                self.grid.map(func, x, y, **plot_kwargs)
        else:
            xr_plot_func = getattr(plot_data.plot, plot_type)
            self.grid = xr_plot_func(
                x=x,
                y=y,
                col=self.col,
                row=self.row,
                col_wrap=self.col_wrap,
                subplot_kws={"projection": self.projection},
                **plot_kwargs,
            )

        # Update BasePlot attributes
        self.fig = self.grid.fig
        axes = getattr(self.grid, "axs", getattr(self.grid, "axes", None))
        self.ax = axes.flatten()[0]
        self.g = self.grid

        # Add features to all facets
        self.add_map_features(**feature_kwargs)
        self._set_default_titles()

    else:
        # Track B: Seaborn-based faceting (Eager/Pandas)
        x = x or "lon"
        y = y or "lat"
        if var_name is None:
            if "variable" in self.data.columns:
                var_name = "value"
            elif isinstance(self.raw_data, xr.DataArray):
                var_name = self.raw_data.name
            elif isinstance(self.raw_data, xr.Dataset):
                var_name = list(self.raw_data.data_vars)[0]

        def _mapped_plot(*args, **kwargs_inner):
            data_df = kwargs_inner.pop("data")
            ax = plt.gca()
            temp_da = data_df.set_index([y, x]).to_xarray()[var_name]
            plotter = plotter_class(temp_da, ax=ax, **kwargs_inner)
            plotter.plot()

        self.map_dataframe(_mapped_plot, **kwargs)
        self.add_map_features(**feature_kwargs)
        if add_shared_cb:
            self._add_shared_colorbar(**kwargs)