Skip to content

Commit 141dc97

Browse files
convert MS to 0-255
1 parent fbde5c2 commit 141dc97

File tree

1 file changed

+129
-10
lines changed

1 file changed

+129
-10
lines changed

detectree2/preprocessing/tiling.py

Lines changed: 129 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import shutil
1515
import warnings # noqa: F401
1616
from pathlib import Path
17-
from typing import Any, List, Tuple
17+
from typing import Any, Dict, List, Tuple
1818

1919
import cv2
2020
import geopandas as gpd
@@ -96,7 +96,8 @@ def process_tile(img_path: str,
9696
threshold: float = 0,
9797
nan_threshold: float = 0,
9898
mask_gdf: gpd.GeoDataFrame = None,
99-
additional_nodata: List[Any] = []):
99+
additional_nodata: List[Any] = [],
100+
image_statistics: List[Dict[str, float]] = None):
100101
"""Process a single tile for making predictions.
101102
102103
Args:
@@ -226,7 +227,8 @@ def process_tile_ms(img_path: str,
226227
crowns: gpd.GeoDataFrame = None,
227228
threshold: float = 0,
228229
nan_threshold: float = 0,
229-
additional_nodata: List[Any] = []):
230+
additional_nodata: List[Any] = [],
231+
image_statistics: List[Dict[str, float]] = None):
230232
"""Process a single tile for making predictions.
231233
232234
Args:
@@ -285,18 +287,34 @@ def process_tile_ms(img_path: str,
285287
if sumzero > nan_threshold * totalpix or sumnan > nan_threshold * totalpix:
286288
return None
287289

290+
# rescale image to 1-255 (0 is reserved for nodata)
291+
assert image_statistics is not None, "image_statistics must be provided for multispectral data"
292+
min_vals = np.array([stats['min'] for stats in image_statistics]).reshape(-1, 1, 1)
293+
max_vals = np.array([stats['max'] for stats in image_statistics]).reshape(-1, 1, 1)
294+
295+
# making it a bit safer for small numbers
296+
if max_vals.min() > 1:
297+
out_img = (out_img - min_vals) / (max_vals - min_vals) * 254 + 1
298+
else:
299+
out_img = (out_img - min_vals) * 254 / (max_vals - min_vals) + 1
300+
301+
# additional clip to make sure
302+
out_img = np.clip(out_img, 1, 255)
303+
288304
# Apply nan mask
289-
out_img[np.broadcast_to((nan_mask == 1)[None, :, :], out_img.shape)] = nodata
305+
out_img[np.broadcast_to((nan_mask == 1)[None, :, :], out_img.shape)] = 0
290306

291307
out_meta = data.meta.copy()
292308
out_meta.update({
293309
"driver": "GTiff",
294310
"height": out_img.shape[1],
295311
"width": out_img.shape[2],
296312
"transform": out_transform,
297-
"nodata": nodata,
313+
"nodata": 0,
298314
})
299315
if dtype_bool:
316+
raise NotImplementedError(
317+
"dtype_bool not implemented for multispectral data. Pretty sure dtype_bool should be False.")
300318
out_meta.update({"dtype": "uint8"})
301319

302320
out_tif = out_path_root.with_suffix(".tif")
@@ -333,7 +351,8 @@ def process_tile_train(
333351
mode: str = "rgb",
334352
class_column: str = None, # Allow user to specify class column
335353
mask_gdf: gpd.GeoDataFrame = None,
336-
additional_nodata: List[Any] = []) -> None:
354+
additional_nodata: List[Any] = [],
355+
image_statistics: List[Dict[str, float]] = None) -> None:
337356
"""Process a single tile for training data.
338357
339358
Args:
@@ -356,10 +375,10 @@ def process_tile_train(
356375
"""
357376
if mode == "rgb":
358377
result = process_tile(img_path, out_dir, buffer, tile_width, tile_height, dtype_bool, minx, miny, crs, tilename,
359-
crowns, threshold, nan_threshold, mask_gdf, additional_nodata)
378+
crowns, threshold, nan_threshold, mask_gdf, additional_nodata, image_statistics)
360379
elif mode == "ms":
361380
result = process_tile_ms(img_path, out_dir, buffer, tile_width, tile_height, dtype_bool, minx, miny, crs,
362-
tilename, crowns, threshold, nan_threshold, additional_nodata)
381+
tilename, crowns, threshold, nan_threshold, additional_nodata, image_statistics)
363382

364383
if result is None:
365384
# logger.warning(f"Skipping tile at ({minx}, {miny}) due to insufficient data.")
@@ -483,6 +502,104 @@ def _calculate_tile_placements(
483502
return coordinates
484503

485504

505+
def calculate_image_statistics(file_path, values_to_ignore=None, window_size=64, min_windows=100, mode="rgb"):
506+
"""
507+
Calculate statistics for a raster using either whole image or sampled windows.
508+
509+
Parameters:
510+
- file_path: str, path to the raster file.
511+
- values_to_ignore: list, values to ignore in statistics (e.g., NaN, custom values).
512+
- window_size: int, size of square window for sampling.
513+
- min_windows: int, minimum number of valid windows to include in statistics.
514+
515+
Returns:
516+
- List of dictionaries containing statistics for each band.
517+
"""
518+
if values_to_ignore is None:
519+
values_to_ignore = []
520+
with rasterio.open(file_path) as src:
521+
# Get image dimensions
522+
width, height = src.width, src.height
523+
524+
# If the image is smaller than 2000x2000, process the whole image
525+
if width * height <= 2000 * 2000:
526+
print("Processing entire image...")
527+
band_stats = []
528+
for band_idx in range(1, src.count + 1):
529+
band = src.read(band_idx).astype(float)
530+
# Mask out bad values
531+
mask = (np.isnan(band) | np.isin(band, values_to_ignore))
532+
valid_data = band[~mask]
533+
534+
if valid_data.size > 0:
535+
stats = {
536+
"mean": np.mean(valid_data),
537+
"min": np.min(valid_data),
538+
"max": np.max(valid_data),
539+
"std_dev": np.std(valid_data),
540+
}
541+
else:
542+
stats = {
543+
"mean": None,
544+
"min": None,
545+
"max": None,
546+
"std_dev": None,
547+
}
548+
band_stats.append(stats)
549+
return band_stats
550+
551+
windows_sampled = 0
552+
band_aggregates = {band: [] for band in range(1, src.count + 1)}
553+
554+
while windows_sampled < min_windows:
555+
# Randomly pick a top-left corner for the window
556+
row_start = np.random.randint(0, height - window_size)
557+
col_start = np.random.randint(0, width - window_size)
558+
559+
window = rasterio.windows.Window(col_start, row_start, window_size, window_size)
560+
561+
# Read the window for each band
562+
valid_window = True
563+
window_data = {}
564+
for band_idx in range(1, src.count + 1) if mode == "ms" else range(1, 4):
565+
band = src.read(band_idx, window=window).astype(float)
566+
# Mask out bad values
567+
mask = (np.isnan(band) | np.isin(band, values_to_ignore))
568+
valid_pixels = band[~mask]
569+
bad_pixel_ratio = mask.sum() / band.size
570+
571+
if bad_pixel_ratio > 0.05: # Exclude windows with >5% bad values
572+
valid_window = False
573+
break
574+
window_data[band_idx] = valid_pixels
575+
576+
if valid_window:
577+
for band_idx, valid_pixels in window_data.items():
578+
band_aggregates[band_idx].extend(valid_pixels)
579+
windows_sampled += 1
580+
581+
# Compute statistics for each band
582+
band_stats = []
583+
for band_idx in range(1, src.count + 1) if mode == "ms" else range(1, 4):
584+
valid_data = np.array(band_aggregates[band_idx])
585+
if valid_data.size > 0:
586+
stats = {
587+
"mean": np.mean(valid_data),
588+
"min": np.min(valid_data),
589+
"max": np.max(valid_data),
590+
"std_dev": np.std(valid_data),
591+
}
592+
else:
593+
stats = {
594+
"mean": None,
595+
"min": None,
596+
"max": None,
597+
"std_dev": None,
598+
}
599+
band_stats.append(stats)
600+
return band_stats
601+
602+
486603
def tile_data(
487604
img_path: str,
488605
out_dir: str,
@@ -537,10 +654,12 @@ def tile_data(
537654
crs = data.crs.to_epsg() # Update CRS handling to avoid deprecated syntax
538655

539656
tile_coordinates = _calculate_tile_placements(img_path, buffer, tile_width, tile_height, crowns, tile_placement)
657+
image_statistics = calculate_image_statistics(img_path, values_to_ignore=additional_nodata, mode=mode)
658+
540659
tile_args = [
541660
(img_path, out_dir, buffer, tile_width, tile_height, dtype_bool, minx, miny, crs, tilename, crowns, threshold,
542-
nan_threshold, mode, class_column, mask_gdf, additional_nodata) for minx, miny in tile_coordinates
543-
if mask_path is None or (mask_path is not None and mask_gdf.intersects(
661+
nan_threshold, mode, class_column, mask_gdf, additional_nodata, image_statistics)
662+
for minx, miny in tile_coordinates if mask_path is None or (mask_path is not None and mask_gdf.intersects(
544663
box(minx, miny, minx + tile_width, miny + tile_height) #TODO maybe add to_crs here
545664
).any())
546665
]

0 commit comments

Comments
 (0)