Spatial clustering¶
This notebook shows how to cluster geographic areas based on the similarity of their attributes.
Specifically, it shows how to take an xr.Dataset with dimensions (x, y, t) and potentially multiple variables and generate a clustered raster over the x and y dimensions. The same can be done with an (x, y) dataset with multiple variables.
Imports¶
%load_ext autoreload
%autoreload 2
import os
import numpy as np
import xarray as xr
import matplotlib.pyplot as plt
from hip.analysis.ops.clustering import xrMiniBatchKMeans, xrKMeans, explore_num_clusters
from hip.analysis import AnalysisArea
Define an area and get data¶
BBOX = (6.328125,37.718590,21.972656,46.800059)
area = AnalysisArea(
    bbox=BBOX,
    datetime_range="1970-01-01/1970-12-31"
)
lta = area.get_dataset(["CHIRPS","r1h_dekad_lta"])
tda = area.get_dataset(["MODIS","myd11c2_tda_dekad_lta"])
ds = xr.Dataset({"rfh":lta,"tda": tda})
ds = ds.persist()
ds
GDAL_DATA = /Users/paolo/code/hip-analysis/.pixi/envs/default/share/gdal GDAL_DISABLE_READDIR_ON_OPEN = EMPTY_DIR GDAL_HTTP_MAX_RETRY = 10 GDAL_HTTP_RETRY_DELAY = 0.5
/Users/paolo/code/hip-analysis/.pixi/envs/default/lib/python3.10/site-packages/rasterio/warp.py:387: NotGeoreferencedWarning: Dataset has no geotransform, gcps, or rpcs. The identity matrix will be returned. dest = _reproject(
<xarray.Dataset> Size: 33MB
Dimensions: (latitude: 182, longitude: 314, time: 36)
Coordinates:
* latitude (latitude) float64 1kB 46.78 46.73 46.68 ... 37.83 37.78 37.73
* longitude (longitude) float64 3kB 6.325 6.375 6.425 ... 21.88 21.93 21.98
spatial_ref int32 4B 4326
* time (time) datetime64[ns] 288B 1970-01-01 1970-01-11 ... 1970-12-21
Data variables:
rfh (time, latitude, longitude) float64 16MB dask.array<chunksize=(1, 182, 314), meta=np.ndarray>
tda (time, latitude, longitude) float64 16MB dask.array<chunksize=(1, 182, 314), meta=np.ndarray>
KMeans clustering¶
The clustering is done using the k-means algorithm. The implementation relies on scikit-learn's KMeans and accepts exactly the same arguments upon instantiation.
m = xrKMeans(n_clusters=20, random_state=0)
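Since the constructor mirrors scikit-learn's KMeans, its standard keyword arguments should pass through unchanged. A minimal sketch, assuming full pass-through (the init, n_init, and max_iter values below are illustrative, not from the original run):
# Hypothetical: any sklearn.cluster.KMeans keyword argument should be forwarded
m_custom = xrKMeans(
    n_clusters=20,
    init="k-means++",  # scikit-learn's default initialization
    n_init=10,         # number of random restarts
    max_iter=300,      # iteration cap per run
    random_state=0,
)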
The one step that differs is that the input xr.Dataset needs to be explicitly preprocessed into a format that is compatible with scikit-learn. This is done with preprocess_data. As a consequence, the fit method does NOT expect the data as an argument.
m.preprocess_data(ds)
m.fit()
xrKMeans(n_clusters=20, random_state=0)
m.labels_da.plot()
<matplotlib.collections.QuadMesh at 0x168923fd0>
The number of clusters can be changed on the fly.
m.n_clusters = 3
m.fit()
m.cluster_centers_da
<xarray.Dataset> Size: 2kB
Dimensions: (cluster: 3, time: 36)
Coordinates:
spatial_ref int32 4B 4326
* time (time) datetime64[ns] 288B 1970-01-01 1970-01-11 ... 1970-12-21
* cluster (cluster) int64 24B 0 1 2
Data variables:
rfh (cluster, time) float64 864B 77.47 71.5 66.21 ... 71.63 69.36
tda (cluster, time) float64 864B 9.662 9.836 10.35 ... 4.968 4.007
m.cluster_centers_da.rfh.plot.line(x='time')
plt.title('Rainfall')
plt.show()
m.cluster_centers_da.tda.plot.line(x='time')
plt.title('Daytime temperature')
plt.show()
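Beyond the centroid patterns, it can be useful to know how many pixels each cluster covers. A minimal sketch using only numpy on the labels raster (assuming masked pixels, if any, are stored as NaN):
# Tally pixels per cluster from the labels raster
labels = m.labels_da.values.ravel()
labels = labels[np.isfinite(labels)]  # drop NaN-masked pixels, if present
vals, counts = np.unique(labels.astype(int), return_counts=True)
for v, c in zip(vals, counts):
    print(f"cluster {v}: {c} pixels")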
MiniBatchKMeans clustering¶
The package also includes the MiniBatchKMeans algorithm. This is faster on large datasets, with only minimal differences in the final results.
m = xrMiniBatchKMeans(n_clusters=20, random_state=0)
m.preprocess_data(ds)
m.fit()
m.labels_da.plot()
<matplotlib.collections.QuadMesh at 0x167fc0a90>
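As the instantiation arguments again mirror the scikit-learn class, MiniBatchKMeans-specific parameters such as batch_size should presumably pass through as well (a hypothetical sketch; the value 4096 is illustrative):
m_fast = xrMiniBatchKMeans(
    n_clusters=20,
    batch_size=4096,  # assumed pass-through of sklearn's MiniBatchKMeans parameter
    random_state=0,
)
m_fast.preprocess_data(ds)
m_fast.fit()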
Finding the optimal number of clusters¶
You can calculate the silhouette_score of a given clustering like this:
m.calculate_silhouette_score()
0.243443716382229
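For reference, this should be equivalent to calling scikit-learn directly on the preprocessed feature matrix. A sketch, assuming the flattened samples are exposed as m.array (shown later in this notebook) and the flat labels as m.labels_ (the scikit-learn convention; both attribute names are assumptions here):
from sklearn.metrics import silhouette_score
# Assumed attributes: m.array (samples x features), m.labels_ (flat labels)
silhouette_score(m.array, m.labels_)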
To look at the score across a range of values for the number of clusters k:
m = xrMiniBatchKMeans(n_clusters=3, random_state=0) # Initial number of clusters does not matter
m.preprocess_data(ds)
results = explore_num_clusters(m, k_range=np.arange(3,27,4))
results
[(3, 0.3324703751956397), (7, 0.2645626361698667), (11, 0.23987158819742835), (15, 0.24314358204657888), (19, 0.22568763393746244), (23, 0.21964563733134013)]
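To visualize these scores and pick a k, the returned list of (k, score) tuples can be plotted directly:
# Plot silhouette score as a function of k; higher is better
ks, scores = zip(*results)
plt.plot(ks, scores, marker="o")
plt.xlabel("Number of clusters k")
plt.ylabel("Silhouette score")
plt.show()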
Clustering across seasons¶
You can also cluster temporal patterns across space and across multiple seasons to identify location-season combinations with similar temporal patterns. This involves restructuring the data to have the dimensions (latitude, longitude, season, time_of_year).
Step 1: Get multi-year data¶
First, we'll get data spanning multiple years:
import pandas as pd
# Define area and date range spanning multiple years
BBOX = (6.328125, 37.718590, 21.972656, 46.800059)
area_multiyear = AnalysisArea(
    bbox=BBOX,
    resolution=0.05,  # Using coarser resolution for faster computation
    datetime_range="2018-06-01/2021-05-31"
)
# Get rainfall data
rfh = area_multiyear.get_dataset(["CHIRPS", "RFH_DEKAD"]).compute()
rfh.sizes
Frozen({'time': 108, 'latitude': 182, 'longitude': 314})
Step 2: Create season coordinate¶
We'll create a "season" coordinate that assigns each timestamp to an agricultural year (here starting in June, matching the June start of the data):
# Create season coordinate (agricultural year starting in June)
def adjusted_year(time_array, month_start=6):
    """Assign each date to a season/year based on the agricultural calendar."""
    return [
        y - 1 * (m < month_start)
        for m, y in zip(time_array.dt.month.values, time_array.dt.year.values)
    ]
rfh = rfh.assign_coords({"season": ("time", adjusted_year(rfh["time"]))})
print(f"Seasons: {sorted(set(rfh['season'].values))}")
Seasons: [2018, 2019, 2020]
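A quick sanity check of the assignment logic: with month_start=6, May 2019 still belongs to season 2018, while June 2019 opens season 2019.
# Verify the season assignment on two hand-picked dates
dates = xr.DataArray(pd.to_datetime(["2019-05-21", "2019-06-01"]), dims="time")
print(adjusted_year(dates))  # [2018, 2019]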
Step 3: Restructure data with season and time_of_year dimensions¶
Now we'll restructure the dataset to separate season and time_of_year dimensions:
# Create time_of_year coordinate as the index within each season
# Group by season and assign sequential indices (0, 1, 2, ...) to each time step
time_of_year = []
for season, group in rfh.groupby("season"):
    time_of_year.extend(range(len(group.time)))
time_of_year = np.array(time_of_year)
# Add time_of_year as a coordinate
rfh_with_toy = rfh.assign_coords({"time_of_year": ("time", time_of_year)})
# Set multi-index and unstack to create the new structure
rfh_restructured = (
    rfh_with_toy
    .set_index(time=["season", "time_of_year"])
    .unstack("time")
)
print(f"Original dimensions: {rfh.dims}")
print(f"Restructured dimensions: {rfh_restructured.dims}")
print(f"Restructured sizes: {rfh_restructured.sizes}")
Original dimensions: ('time', 'latitude', 'longitude')
Restructured dimensions: ('latitude', 'longitude', 'season', 'time_of_year')
Restructured sizes: Frozen({'latitude': 182, 'longitude': 314, 'season': 3, 'time_of_year': 36})
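As a sanity check, selecting a single season should yield one year's worth of dekads per pixel:
# Selecting a scalar season drops the season dimension
print(rfh_restructured.sel(season=2018).sizes)
# Expected: Frozen({'latitude': 182, 'longitude': 314, 'time_of_year': 36})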
Step 4: Perform clustering on restructured data¶
Now we can cluster the spatial locations based on their multi-year seasonal patterns:
# Create clustering model
m_seasonal = xrKMeans(n_clusters=5, random_state=0)
# Preprocess the restructured data
m_seasonal.preprocess_data(rfh_restructured, clustering_dims=["season","latitude", "longitude"])
# Fit the model
m_seasonal.fit()
print(f"Feature dimensions used for clustering: {m_seasonal.feature_dims}")
print(f"Number of features per location: {m_seasonal.n_features}")
print(f"Array shape for clustering: {m_seasonal.array.shape}")
Feature dimensions used for clustering: ['time_of_year']
Number of features per location: 36
Array shape for clustering: (99525, 36)
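Here each (season, pixel) pair is one sample and time_of_year supplies the 36 features. To instead cluster pixels on their full multi-season trajectory, one would presumably drop season from clustering_dims so that it joins the feature dimensions (a hypothetical sketch, assuming preprocess_data flattens all non-clustering dimensions into features):
# Hypothetical: each pixel becomes one sample with season x time_of_year features
m_pixels = xrKMeans(n_clusters=5, random_state=0)
m_pixels.preprocess_data(rfh_restructured, clustering_dims=["latitude", "longitude"])
m_pixels.fit()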
Step 5: Visualize the clustering results¶
Let's visualize the spatial clusters:
# Plot the cluster labels - one plot per season
n_seasons = len(m_seasonal.labels_da.season)
n_cols = 3
n_rows = int(np.ceil(n_seasons / n_cols))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows))
axes = np.atleast_1d(axes).flatten()  # works whether subplots returns one Axes or an array
for idx, season in enumerate(m_seasonal.labels_da.season.values):
    ax = axes[idx]
    m_seasonal.labels_da.sel(season=season).plot(ax=ax, cmap='tab10', add_colorbar=True)
    ax.set_title(f'Clusters for Season {season}')
    ax.set_xlabel('Longitude')
    ax.set_ylabel('Latitude')
# Hide extra subplots if any
for idx in range(n_seasons, len(axes)):
    axes[idx].axis('off')
plt.suptitle('Spatial Clusters Based on Seasonal Rainfall Patterns',
             fontsize=14, y=1.02)
plt.tight_layout()
plt.show()
Step 6: Analyze cluster centroids¶
We can examine the characteristic rainfall patterns for each cluster across seasons:
# Get cluster centers
centers = m_seasonal.cluster_centers_da
# Plot rainfall patterns for each cluster
fig, axes = plt.subplots(m_seasonal.n_clusters, 1, figsize=(14, 3 * m_seasonal.n_clusters))
axes = axes if m_seasonal.n_clusters > 1 else [axes]
for cluster_id in range(m_seasonal.n_clusters):
    ax = axes[cluster_id]
    # Plot the rainfall pattern across time_of_year for this cluster
    cluster_data = centers.band.sel(cluster=cluster_id)
    ax.plot(cluster_data.time_of_year.values, cluster_data.values, linewidth=2)
    ax.set_title(f'Cluster {cluster_id} - Average Rainfall Pattern')
    ax.set_xlabel('Time Step in Season')
    ax.set_ylabel('Rainfall (mm)')
    ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()