1414import shutil
1515import warnings # noqa: F401
1616from pathlib import Path
17- from typing import Any , List , Tuple
17+ from typing import Any , Dict , List , Tuple
1818
1919import cv2
2020import geopandas as gpd
@@ -96,7 +96,8 @@ def process_tile(img_path: str,
9696 threshold : float = 0 ,
9797 nan_threshold : float = 0 ,
9898 mask_gdf : gpd .GeoDataFrame = None ,
99- additional_nodata : List [Any ] = []):
99+ additional_nodata : List [Any ] = [],
100+ image_statistics : List [Dict [str , float ]] = None ):
100101 """Process a single tile for making predictions.
101102
102103 Args:
@@ -226,7 +227,8 @@ def process_tile_ms(img_path: str,
226227 crowns : gpd .GeoDataFrame = None ,
227228 threshold : float = 0 ,
228229 nan_threshold : float = 0 ,
229- additional_nodata : List [Any ] = []):
230+ additional_nodata : List [Any ] = [],
231+ image_statistics : List [Dict [str , float ]] = None ):
230232 """Process a single tile for making predictions.
231233
232234 Args:
@@ -285,18 +287,34 @@ def process_tile_ms(img_path: str,
285287 if sumzero > nan_threshold * totalpix or sumnan > nan_threshold * totalpix :
286288 return None
287289
290+ # rescale image to 1-255 (0 is reserved for nodata)
291+ assert image_statistics is not None , "image_statistics must be provided for multispectral data"
292+ min_vals = np .array ([stats ['min' ] for stats in image_statistics ]).reshape (- 1 , 1 , 1 )
293+ max_vals = np .array ([stats ['max' ] for stats in image_statistics ]).reshape (- 1 , 1 , 1 )
294+
295+ # making it a bit safer for small numbers
296+ if max_vals .min () > 1 :
297+ out_img = (out_img - min_vals ) / (max_vals - min_vals ) * 254 + 1
298+ else :
299+ out_img = (out_img - min_vals ) * 254 / (max_vals - min_vals ) + 1
300+
301+ # additional clip to make sure
302+ out_img = np .clip (out_img , 1 , 255 )
303+
288304 # Apply nan mask
289- out_img [np .broadcast_to ((nan_mask == 1 )[None , :, :], out_img .shape )] = nodata
305+ out_img [np .broadcast_to ((nan_mask == 1 )[None , :, :], out_img .shape )] = 0
290306
291307 out_meta = data .meta .copy ()
292308 out_meta .update ({
293309 "driver" : "GTiff" ,
294310 "height" : out_img .shape [1 ],
295311 "width" : out_img .shape [2 ],
296312 "transform" : out_transform ,
297- "nodata" : nodata ,
313+ "nodata" : 0 ,
298314 })
299315 if dtype_bool :
316+ raise NotImplementedError (
317+ "dtype_bool not implemented for multispectral data. Pretty sure dtype_bool should be False." )
300318 out_meta .update ({"dtype" : "uint8" })
301319
302320 out_tif = out_path_root .with_suffix (".tif" )
@@ -333,7 +351,8 @@ def process_tile_train(
333351 mode : str = "rgb" ,
334352 class_column : str = None , # Allow user to specify class column
335353 mask_gdf : gpd .GeoDataFrame = None ,
336- additional_nodata : List [Any ] = []) -> None :
354+ additional_nodata : List [Any ] = [],
355+ image_statistics : List [Dict [str , float ]] = None ) -> None :
337356 """Process a single tile for training data.
338357
339358 Args:
@@ -356,10 +375,10 @@ def process_tile_train(
356375 """
357376 if mode == "rgb" :
358377 result = process_tile (img_path , out_dir , buffer , tile_width , tile_height , dtype_bool , minx , miny , crs , tilename ,
359- crowns , threshold , nan_threshold , mask_gdf , additional_nodata )
378+ crowns , threshold , nan_threshold , mask_gdf , additional_nodata , image_statistics )
360379 elif mode == "ms" :
361380 result = process_tile_ms (img_path , out_dir , buffer , tile_width , tile_height , dtype_bool , minx , miny , crs ,
362- tilename , crowns , threshold , nan_threshold , additional_nodata )
381+ tilename , crowns , threshold , nan_threshold , additional_nodata , image_statistics )
363382
364383 if result is None :
365384 # logger.warning(f"Skipping tile at ({minx}, {miny}) due to insufficient data.")
@@ -483,6 +502,104 @@ def _calculate_tile_placements(
483502 return coordinates
484503
485504
505+ def calculate_image_statistics (file_path , values_to_ignore = None , window_size = 64 , min_windows = 100 , mode = "rgb" ):
506+ """
507+ Calculate statistics for a raster using either whole image or sampled windows.
508+
509+ Parameters:
510+ - file_path: str, path to the raster file.
511+ - values_to_ignore: list, values to ignore in statistics (e.g., NaN, custom values).
512+ - window_size: int, size of square window for sampling.
513+ - min_windows: int, minimum number of valid windows to include in statistics.
514+
515+ Returns:
516+ - List of dictionaries containing statistics for each band.
517+ """
518+ if values_to_ignore is None :
519+ values_to_ignore = []
520+ with rasterio .open (file_path ) as src :
521+ # Get image dimensions
522+ width , height = src .width , src .height
523+
524+ # If the image is smaller than 2000x2000, process the whole image
525+ if width * height <= 2000 * 2000 :
526+ print ("Processing entire image..." )
527+ band_stats = []
528+ for band_idx in range (1 , src .count + 1 ):
529+ band = src .read (band_idx ).astype (float )
530+ # Mask out bad values
531+ mask = (np .isnan (band ) | np .isin (band , values_to_ignore ))
532+ valid_data = band [~ mask ]
533+
534+ if valid_data .size > 0 :
535+ stats = {
536+ "mean" : np .mean (valid_data ),
537+ "min" : np .min (valid_data ),
538+ "max" : np .max (valid_data ),
539+ "std_dev" : np .std (valid_data ),
540+ }
541+ else :
542+ stats = {
543+ "mean" : None ,
544+ "min" : None ,
545+ "max" : None ,
546+ "std_dev" : None ,
547+ }
548+ band_stats .append (stats )
549+ return band_stats
550+
551+ windows_sampled = 0
552+ band_aggregates = {band : [] for band in range (1 , src .count + 1 )}
553+
554+ while windows_sampled < min_windows :
555+ # Randomly pick a top-left corner for the window
556+ row_start = np .random .randint (0 , height - window_size )
557+ col_start = np .random .randint (0 , width - window_size )
558+
559+ window = rasterio .windows .Window (col_start , row_start , window_size , window_size )
560+
561+ # Read the window for each band
562+ valid_window = True
563+ window_data = {}
564+ for band_idx in range (1 , src .count + 1 ) if mode == "ms" else range (1 , 4 ):
565+ band = src .read (band_idx , window = window ).astype (float )
566+ # Mask out bad values
567+ mask = (np .isnan (band ) | np .isin (band , values_to_ignore ))
568+ valid_pixels = band [~ mask ]
569+ bad_pixel_ratio = mask .sum () / band .size
570+
571+ if bad_pixel_ratio > 0.05 : # Exclude windows with >5% bad values
572+ valid_window = False
573+ break
574+ window_data [band_idx ] = valid_pixels
575+
576+ if valid_window :
577+ for band_idx , valid_pixels in window_data .items ():
578+ band_aggregates [band_idx ].extend (valid_pixels )
579+ windows_sampled += 1
580+
581+ # Compute statistics for each band
582+ band_stats = []
583+ for band_idx in range (1 , src .count + 1 ) if mode == "ms" else range (1 , 4 ):
584+ valid_data = np .array (band_aggregates [band_idx ])
585+ if valid_data .size > 0 :
586+ stats = {
587+ "mean" : np .mean (valid_data ),
588+ "min" : np .min (valid_data ),
589+ "max" : np .max (valid_data ),
590+ "std_dev" : np .std (valid_data ),
591+ }
592+ else :
593+ stats = {
594+ "mean" : None ,
595+ "min" : None ,
596+ "max" : None ,
597+ "std_dev" : None ,
598+ }
599+ band_stats .append (stats )
600+ return band_stats
601+
602+
486603def tile_data (
487604 img_path : str ,
488605 out_dir : str ,
@@ -537,10 +654,12 @@ def tile_data(
537654 crs = data .crs .to_epsg () # Update CRS handling to avoid deprecated syntax
538655
539656 tile_coordinates = _calculate_tile_placements (img_path , buffer , tile_width , tile_height , crowns , tile_placement )
657+ image_statistics = calculate_image_statistics (img_path , values_to_ignore = additional_nodata , mode = mode )
658+
540659 tile_args = [
541660 (img_path , out_dir , buffer , tile_width , tile_height , dtype_bool , minx , miny , crs , tilename , crowns , threshold ,
542- nan_threshold , mode , class_column , mask_gdf , additional_nodata ) for minx , miny in tile_coordinates
543- if mask_path is None or (mask_path is not None and mask_gdf .intersects (
661+ nan_threshold , mode , class_column , mask_gdf , additional_nodata , image_statistics )
662+ for minx , miny in tile_coordinates if mask_path is None or (mask_path is not None and mask_gdf .intersects (
544663 box (minx , miny , minx + tile_width , miny + tile_height ) #TODO maybe add to_crs here
545664 ).any ())
546665 ]
0 commit comments