From 0921609a30f431f414d2a000f75b7d920e53280f Mon Sep 17 00:00:00 2001 From: Noelle Cheng Date: Fri, 27 Jun 2025 14:09:11 +0800 Subject: [PATCH 01/16] sampler --- map2loop/mapdata.py | 57 ------------------------------ map2loop/project.py | 32 +++++++---------- map2loop/sampler.py | 22 ++++++++---- map2loop/thickness_calculator.py | 7 ++-- map2loop/utils.py | 60 ++++++++++++++++++++++++++++++++ 5 files changed, 93 insertions(+), 85 deletions(-) diff --git a/map2loop/mapdata.py b/map2loop/mapdata.py index 4137af27..df5f6804 100644 --- a/map2loop/mapdata.py +++ b/map2loop/mapdata.py @@ -1448,63 +1448,6 @@ def get_value_from_raster(self, datatype: Datatype, x, y): val = data.ReadAsArray(px, py, 1, 1)[0][0] return val - @beartype.beartype - def __value_from_raster(self, inv_geotransform, data, x: float, y: float): - """ - Get the value from a raster dataset at the specified point - - Args: - inv_geotransform (gdal.GeoTransform): - The inverse of the data's geotransform - data (numpy.array): - The raster data - x (float): - The easting coordinate of the value - y (float): - The northing coordinate of the value - - Returns: - float or int: The value at the point specified - """ - px = int(inv_geotransform[0] + inv_geotransform[1] * x + inv_geotransform[2] * y) - py = int(inv_geotransform[3] + inv_geotransform[4] * x + inv_geotransform[5] * y) - # Clamp values to the edges of raster if past boundary, similiar to GL_CLIP - px = max(px, 0) - px = min(px, data.shape[0] - 1) - py = max(py, 0) - py = min(py, data.shape[1] - 1) - return data[px][py] - - @beartype.beartype - def get_value_from_raster_df(self, datatype: Datatype, df: pandas.DataFrame): - """ - Add a 'Z' column to a dataframe with the heights from the 'X' and 'Y' coordinates - - Args: - datatype (Datatype): - The datatype of the raster map to retrieve from - df (pandas.DataFrame): - The original dataframe with 'X' and 'Y' columns - - Returns: - pandas.DataFrame: The modified dataframe - """ - if len(df) <= 0: - df["Z"] = [] - return df - data = self.get_map_data(datatype) - if data is None: - logger.warning("Cannot get value from data as data is not loaded") - return None - - inv_geotransform = gdal.InvGeoTransform(data.GetGeoTransform()) - data_array = numpy.array(data.GetRasterBand(1).ReadAsArray().T) - - df["Z"] = df.apply( - lambda row: self.__value_from_raster(inv_geotransform, data_array, row["X"], row["Y"]), - axis=1, - ) - return df @beartype.beartype def extract_all_contacts(self, save_contacts=True): diff --git a/map2loop/project.py b/map2loop/project.py index d9cfbb83..ec7260e5 100644 --- a/map2loop/project.py +++ b/map2loop/project.py @@ -1,6 +1,6 @@ # internal imports from map2loop.fault_orientation import FaultOrientationNearest -from .utils import hex_to_rgb +from .utils import hex_to_rgb, set_z_values_from_raster_df from .m2l_enums import VerboseLevel, ErrorState, Datatype from .mapdata import MapData from .sampler import Sampler, SamplerDecimator, SamplerSpacing @@ -503,26 +503,20 @@ def sample_map_data(self): """ Use the samplers to extract points along polylines or unit boundaries """ - logger.info( - f"Sampling geology map data using {self.samplers[Datatype.GEOLOGY].sampler_label}" - ) - self.geology_samples = self.samplers[Datatype.GEOLOGY].sample( - self.map_data.get_map_data(Datatype.GEOLOGY), self.map_data - ) - logger.info( - f"Sampling structure map data using {self.samplers[Datatype.STRUCTURE].sampler_label}" - ) - self.structure_samples = self.samplers[Datatype.STRUCTURE].sample( - self.map_data.get_map_data(Datatype.STRUCTURE), self.map_data - ) + geology_data = self.map_data.get_map_data(Datatype.GEOLOGY) + dtm_data = self.map_data.get_map_data(Datatype.DTM) + + logger.info(f"Sampling geology map data using {self.samplers[Datatype.GEOLOGY].sampler_label}") + self.geology_samples = self.samplers[Datatype.GEOLOGY].sample(geology_data) + + logger.info(f"Sampling structure map data using {self.samplers[Datatype.STRUCTURE].sampler_label}") + self.structure_samples = self.samplers[Datatype.STRUCTURE].sample(self.map_data.get_map_data(Datatype.STRUCTURE), dtm_data, geology_data) + logger.info(f"Sampling fault map data using {self.samplers[Datatype.FAULT].sampler_label}") - self.fault_samples = self.samplers[Datatype.FAULT].sample( - self.map_data.get_map_data(Datatype.FAULT), self.map_data - ) + self.fault_samples = self.samplers[Datatype.FAULT].sample(self.map_data.get_map_data(Datatype.FAULT)) + logger.info(f"Sampling fold map data using {self.samplers[Datatype.FOLD].sampler_label}") - self.fold_samples = self.samplers[Datatype.FOLD].sample( - self.map_data.get_map_data(Datatype.FOLD), self.map_data - ) + self.fold_samples = self.samplers[Datatype.FOLD].sample(self.map_data.get_map_data(Datatype.FOLD)) def extract_geology_contacts(self): """ diff --git a/map2loop/sampler.py b/map2loop/sampler.py index 01600566..10aa51b9 100644 --- a/map2loop/sampler.py +++ b/map2loop/sampler.py @@ -1,6 +1,7 @@ # internal imports from .m2l_enums import Datatype from .mapdata import MapData +from .utils import set_z_values_from_raster_df # external imports from abc import ABC, abstractmethod @@ -10,6 +11,7 @@ import shapely import numpy from typing import Optional +from .utils import set_z_values_from_raster_df class Sampler(ABC): @@ -38,7 +40,7 @@ def type(self): @beartype.beartype @abstractmethod def sample( - self, spatial_data: geopandas.GeoDataFrame, map_data: Optional[MapData] = None + self, spatial_data: geopandas.GeoDataFrame, dtm_data: Optional[geopandas.GeoDataFrame] = None, geology_data: Optional[geopandas.GeoDataFrame] = None ) -> pandas.DataFrame: """ Execute sampling method (abstract method) @@ -73,7 +75,7 @@ def __init__(self, decimation: int = 1): @beartype.beartype def sample( - self, spatial_data: geopandas.GeoDataFrame, map_data: Optional[MapData] = None + self, spatial_data: geopandas.GeoDataFrame, dtm_data: Optional[gdal.Dataset] = None, geology_data: Optional[geopandas.GeoDataFrame] = None ) -> pandas.DataFrame: """ Execute sample method takes full point data, samples the data and returns the decimated points @@ -87,10 +89,16 @@ def sample( data = spatial_data.copy() data["X"] = data.geometry.x data["Y"] = data.geometry.y - data["Z"] = map_data.get_value_from_raster_df(Datatype.DTM, data)["Z"] - data["layerID"] = geopandas.sjoin( - data, map_data.get_map_data(Datatype.GEOLOGY), how='left' - )['index_right'] + if dtm_data is not None: + data["Z"] = set_z_values_from_raster_df(dtm_data, data)["Z"] + else: + data["Z"] = None + if geology_data is not None: + data["layerID"] = geopandas.sjoin( + data, geology_data, how='left' + )['index_right'] + else: + data["layerID"] = None data.reset_index(drop=True, inplace=True) return pandas.DataFrame(data[:: self.decimation].drop(columns="geometry")) @@ -118,7 +126,7 @@ def __init__(self, spacing: float = 50.0): @beartype.beartype def sample( - self, spatial_data: geopandas.GeoDataFrame, map_data: Optional[MapData] = None + self, spatial_data: geopandas.GeoDataFrame, dtm_data: Optional[geopandas.GeoDataFrame] = None, geology_data: Optional[geopandas.GeoDataFrame] = None ) -> pandas.DataFrame: """ Execute sample method takes full point data, samples the data and returns the sampled points diff --git a/map2loop/thickness_calculator.py b/map2loop/thickness_calculator.py index d7a9aad1..3da0ad40 100644 --- a/map2loop/thickness_calculator.py +++ b/map2loop/thickness_calculator.py @@ -5,6 +5,7 @@ calculate_endpoints, multiline_to_line, find_segment_strike_from_pt, + set_z_values_from_raster_df ) from .m2l_enums import Datatype from .interpolators import DipDipDirectionInterpolator @@ -271,7 +272,8 @@ def compute( # set the crs of the contacts to the crs of the units contacts = contacts.set_crs(crs=basal_contacts.crs) # get the elevation Z of the contacts - contacts = map_data.get_value_from_raster_df(Datatype.DTM, contacts) + dtm_data = map_data.get_map_data(Datatype.DTM) + contacts = set_z_values_from_raster_df(dtm_data, contacts) # update the geometry of the contact points to include the Z value contacts["geometry"] = contacts.apply( lambda row: shapely.geometry.Point(row.geometry.x, row.geometry.y, row["Z"]), axis=1 @@ -299,7 +301,8 @@ def compute( # set the crs of the interpolated orientations to the crs of the units interpolated_orientations = interpolated_orientations.set_crs(crs=basal_contacts.crs) # get the elevation Z of the interpolated points - interpolated = map_data.get_value_from_raster_df(Datatype.DTM, interpolated_orientations) + dtm_data = map_data.get_map_data(Datatype.DTM) + interpolated = set_z_values_from_raster_df(dtm_data, interpolated_orientations) # update the geometry of the interpolated points to include the Z value interpolated["geometry"] = interpolated.apply( lambda row: shapely.geometry.Point(row.geometry.x, row.geometry.y, row["Z"]), axis=1 diff --git a/map2loop/utils.py b/map2loop/utils.py index c3ed7795..55e2e7b2 100644 --- a/map2loop/utils.py +++ b/map2loop/utils.py @@ -7,6 +7,7 @@ import pandas import re import json +from osgeo import gdal from .logging import getLogger logger = getLogger(__name__) @@ -528,3 +529,62 @@ def update_from_legacy_file( json.dump(parsed_data, f, indent=4) return file_map + +@beartype.beartype +def value_from_raster(inv_geotransform, data, x: float, y: float): + """ + Get the value from a raster dataset at the specified point + + Args: + inv_geotransform (gdal.GeoTransform): + The inverse of the data's geotransform + data (numpy.array): + The raster data + x (float): + The easting coordinate of the value + y (float): + The northing coordinate of the value + + Returns: + float or int: The value at the point specified + """ + px = int(inv_geotransform[0] + inv_geotransform[1] * x + inv_geotransform[2] * y) + py = int(inv_geotransform[3] + inv_geotransform[4] * x + inv_geotransform[5] * y) + # Clamp values to the edges of raster if past boundary, similiar to GL_CLIP + px = max(px, 0) + px = min(px, data.shape[0] - 1) + py = max(py, 0) + py = min(py, data.shape[1] - 1) + return data[px][py] + +@beartype.beartype +def set_z_values_from_raster_df(dtm_data: gdal.Dataset, df: pandas.DataFrame): + """ + Add a 'Z' column to a dataframe with the heights from the 'X' and 'Y' coordinates + + Args: + dtm_data (gdal.Dataset): + Dtm data from raster map + df (pandas.DataFrame): + The original dataframe with 'X' and 'Y' columns + + Returns: + pandas.DataFrame: The modified dataframe + """ + if len(df) <= 0: + df["Z"] = [] + return df + + if dtm_data is None: + logger.warning("Cannot get value from data as data is not loaded") + return None + + inv_geotransform = gdal.InvGeoTransform(dtm_data.GetGeoTransform()) + data_array = numpy.array(dtm_data.GetRasterBand(1).ReadAsArray().T) + + df["Z"] = df.apply( + lambda row: value_from_raster(inv_geotransform, data_array, row["X"], row["Y"]), + axis=1, + ) + + return df \ No newline at end of file From 6b0249d3bc9091bd6a1e2461636ce80e7db4b77f Mon Sep 17 00:00:00 2001 From: Noelle Cheng Date: Fri, 27 Jun 2025 14:18:46 +0800 Subject: [PATCH 02/16] fix extract_geology_contacts --- map2loop/project.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/map2loop/project.py b/map2loop/project.py index ec7260e5..9e8189cd 100644 --- a/map2loop/project.py +++ b/map2loop/project.py @@ -526,11 +526,9 @@ def extract_geology_contacts(self): self.map_data.extract_basal_contacts(self.stratigraphic_column.column) # sample the contacts - self.map_data.sampled_contacts = self.samplers[Datatype.GEOLOGY].sample( - self.map_data.basal_contacts - ) - - self.map_data.get_value_from_raster_df(Datatype.DTM, self.map_data.sampled_contacts) + self.map_data.sampled_contacts = self.samplers[Datatype.GEOLOGY].sample(self.map_data.basal_contacts) + dtm_data = self.map_data.get_map_data(Datatype.DTM) + set_z_values_from_raster_df(dtm_data, self.map_data.sampled_contacts) def calculate_stratigraphic_order(self, take_best=False): """ From e3ae1c3996b2dcd93ef566dfa1dd5ba0b76247e1 Mon Sep 17 00:00:00 2001 From: Noelle Cheng Date: Fri, 27 Jun 2025 14:37:19 +0800 Subject: [PATCH 03/16] fix calculate_fault_orientations and summarise_fault_data --- map2loop/project.py | 6 ++++-- map2loop/sampler.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/map2loop/project.py b/map2loop/project.py index 9e8189cd..7e39bced 100644 --- a/map2loop/project.py +++ b/map2loop/project.py @@ -706,7 +706,8 @@ def calculate_fault_orientations(self): self.map_data.get_map_data(Datatype.FAULT_ORIENTATION), self.map_data, ) - self.map_data.get_value_from_raster_df(Datatype.DTM, self.fault_orientations) + dtm_data = self.map_data.get_map_data(Datatype.DTM) + set_z_values_from_raster_df(dtm_data, self.fault_orientations) else: logger.warning( "No fault orientation data found, skipping fault orientation calculation" @@ -731,7 +732,8 @@ def summarise_fault_data(self): """ Use the fault shapefile to make a summary of each fault by name """ - self.map_data.get_value_from_raster_df(Datatype.DTM, self.fault_samples) + dtm_data = self.map_data.get_map_data(Datatype.DTM) + set_z_values_from_raster_df(dtm_data, self.fault_samples) self.deformation_history.summarise_data(self.fault_samples) self.deformation_history.faults = self.throw_calculator.compute( diff --git a/map2loop/sampler.py b/map2loop/sampler.py index 10aa51b9..43db952e 100644 --- a/map2loop/sampler.py +++ b/map2loop/sampler.py @@ -11,7 +11,7 @@ import shapely import numpy from typing import Optional -from .utils import set_z_values_from_raster_df +from osgeo import gdal class Sampler(ABC): From cc315d7e5d5191a64ab07a2477a66b0b9fd47569 Mon Sep 17 00:00:00 2001 From: Noelle Cheng Date: Fri, 27 Jun 2025 15:38:57 +0800 Subject: [PATCH 04/16] fix dtm data type --- map2loop/sampler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/map2loop/sampler.py b/map2loop/sampler.py index 43db952e..e8e1fe51 100644 --- a/map2loop/sampler.py +++ b/map2loop/sampler.py @@ -40,7 +40,7 @@ def type(self): @beartype.beartype @abstractmethod def sample( - self, spatial_data: geopandas.GeoDataFrame, dtm_data: Optional[geopandas.GeoDataFrame] = None, geology_data: Optional[geopandas.GeoDataFrame] = None + self, spatial_data: geopandas.GeoDataFrame, dtm_data: Optional[gdal.Dataset] = None, geology_data: Optional[geopandas.GeoDataFrame] = None ) -> pandas.DataFrame: """ Execute sampling method (abstract method) @@ -126,7 +126,7 @@ def __init__(self, spacing: float = 50.0): @beartype.beartype def sample( - self, spatial_data: geopandas.GeoDataFrame, dtm_data: Optional[geopandas.GeoDataFrame] = None, geology_data: Optional[geopandas.GeoDataFrame] = None + self, spatial_data: geopandas.GeoDataFrame, dtm_data: Optional[gdal.Dataset] = None, geology_data: Optional[geopandas.GeoDataFrame] = None ) -> pandas.DataFrame: """ Execute sample method takes full point data, samples the data and returns the sampled points From abd8b2c791dcacbfcb612b8120ecfb4cd446a6c4 Mon Sep 17 00:00:00 2001 From: Noelle Cheng Date: Fri, 27 Jun 2025 15:49:20 +0800 Subject: [PATCH 05/16] fix get_value_from_raster import --- map2loop/sorter.py | 6 +++--- map2loop/thickness_calculator.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/map2loop/sorter.py b/map2loop/sorter.py index da4dab76..656cc4c9 100644 --- a/map2loop/sorter.py +++ b/map2loop/sorter.py @@ -3,7 +3,7 @@ import pandas import numpy as np import math -from .mapdata import MapData +from .mapdata import MapData, get_value_from_raster from typing import Union from .logging import getLogger @@ -434,9 +434,9 @@ def sort( continue # Get heights for intersection point and start of ray - height = map_data.get_value_from_raster(Datatype.DTM, start.x, start.y) + height = get_value_from_raster(Datatype.DTM, start.x, start.y) first_intersect_point = Point(start.x, start.y, height) - height = map_data.get_value_from_raster( + height = get_value_from_raster( Datatype.DTM, second_intersect_point.x, second_intersect_point.y ) second_intersect_point = Point(second_intersect_point.x, start.y, height) diff --git a/map2loop/thickness_calculator.py b/map2loop/thickness_calculator.py index 3da0ad40..d6992b90 100644 --- a/map2loop/thickness_calculator.py +++ b/map2loop/thickness_calculator.py @@ -9,7 +9,7 @@ ) from .m2l_enums import Datatype from .interpolators import DipDipDirectionInterpolator -from .mapdata import MapData +from .mapdata import MapData, get_value_from_raster from .logging import getLogger logger = getLogger(__name__) @@ -358,13 +358,13 @@ def compute( p1[0] = numpy.asarray(short_line[0].coords[0][0]) p1[1] = numpy.asarray(short_line[0].coords[0][1]) # get the elevation Z of the end point p1 - p1[2] = map_data.get_value_from_raster(Datatype.DTM, p1[0], p1[1]) + p1[2] = get_value_from_raster(Datatype.DTM, p1[0], p1[1]) # create array to store xyz coordinates of the end point p2 p2 = numpy.zeros(3) p2[0] = numpy.asarray(short_line[0].coords[-1][0]) p2[1] = numpy.asarray(short_line[0].coords[-1][1]) # get the elevation Z of the end point p2 - p2[2] = map_data.get_value_from_raster(Datatype.DTM, p2[0], p2[1]) + p2[2] = get_value_from_raster(Datatype.DTM, p2[0], p2[1]) # calculate the length of the shortest line line_length = scipy.spatial.distance.euclidean(p1, p2) # find the indices of the points that are within 5% of the length of the shortest line From 3a908df183190db14033bd2916d523652f20a321 Mon Sep 17 00:00:00 2001 From: noellehmcheng <143368485+noellehmcheng@users.noreply.github.com> Date: Fri, 27 Jun 2025 08:01:45 +0000 Subject: [PATCH 06/16] style: style fixes by ruff and autoformatting by black --- map2loop/sampler.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/map2loop/sampler.py b/map2loop/sampler.py index e8e1fe51..04baed4a 100644 --- a/map2loop/sampler.py +++ b/map2loop/sampler.py @@ -1,6 +1,4 @@ # internal imports -from .m2l_enums import Datatype -from .mapdata import MapData from .utils import set_z_values_from_raster_df # external imports From aa6f0e246cfac1a690bb115b89acfe5d881dd922 Mon Sep 17 00:00:00 2001 From: Noelle Cheng Date: Fri, 27 Jun 2025 16:35:41 +0800 Subject: [PATCH 07/16] revert get_value_from_raster --- map2loop/sorter.py | 6 +++--- map2loop/thickness_calculator.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/map2loop/sorter.py b/map2loop/sorter.py index 656cc4c9..da4dab76 100644 --- a/map2loop/sorter.py +++ b/map2loop/sorter.py @@ -3,7 +3,7 @@ import pandas import numpy as np import math -from .mapdata import MapData, get_value_from_raster +from .mapdata import MapData from typing import Union from .logging import getLogger @@ -434,9 +434,9 @@ def sort( continue # Get heights for intersection point and start of ray - height = get_value_from_raster(Datatype.DTM, start.x, start.y) + height = map_data.get_value_from_raster(Datatype.DTM, start.x, start.y) first_intersect_point = Point(start.x, start.y, height) - height = get_value_from_raster( + height = map_data.get_value_from_raster( Datatype.DTM, second_intersect_point.x, second_intersect_point.y ) second_intersect_point = Point(second_intersect_point.x, start.y, height) diff --git a/map2loop/thickness_calculator.py b/map2loop/thickness_calculator.py index d6992b90..3da0ad40 100644 --- a/map2loop/thickness_calculator.py +++ b/map2loop/thickness_calculator.py @@ -9,7 +9,7 @@ ) from .m2l_enums import Datatype from .interpolators import DipDipDirectionInterpolator -from .mapdata import MapData, get_value_from_raster +from .mapdata import MapData from .logging import getLogger logger = getLogger(__name__) @@ -358,13 +358,13 @@ def compute( p1[0] = numpy.asarray(short_line[0].coords[0][0]) p1[1] = numpy.asarray(short_line[0].coords[0][1]) # get the elevation Z of the end point p1 - p1[2] = get_value_from_raster(Datatype.DTM, p1[0], p1[1]) + p1[2] = map_data.get_value_from_raster(Datatype.DTM, p1[0], p1[1]) # create array to store xyz coordinates of the end point p2 p2 = numpy.zeros(3) p2[0] = numpy.asarray(short_line[0].coords[-1][0]) p2[1] = numpy.asarray(short_line[0].coords[-1][1]) # get the elevation Z of the end point p2 - p2[2] = get_value_from_raster(Datatype.DTM, p2[0], p2[1]) + p2[2] = map_data.get_value_from_raster(Datatype.DTM, p2[0], p2[1]) # calculate the length of the shortest line line_length = scipy.spatial.distance.euclidean(p1, p2) # find the indices of the points that are within 5% of the length of the shortest line From c9a93feb31f526a990f9853ad7b3cbad5db0bbb9 Mon Sep 17 00:00:00 2001 From: Noelle Cheng Date: Tue, 1 Jul 2025 12:43:00 +0800 Subject: [PATCH 08/16] move dtm and geology data parameters from sample function to constructor --- map2loop/project.py | 4 +++- map2loop/sampler.py | 34 ++++++++++++++++++++++++---------- 2 files changed, 27 insertions(+), 11 deletions(-) diff --git a/map2loop/project.py b/map2loop/project.py index 7e39bced..632d41cd 100644 --- a/map2loop/project.py +++ b/map2loop/project.py @@ -510,7 +510,9 @@ def sample_map_data(self): self.geology_samples = self.samplers[Datatype.GEOLOGY].sample(geology_data) logger.info(f"Sampling structure map data using {self.samplers[Datatype.STRUCTURE].sampler_label}") - self.structure_samples = self.samplers[Datatype.STRUCTURE].sample(self.map_data.get_map_data(Datatype.STRUCTURE), dtm_data, geology_data) + self.samplers[Datatype.STRUCTURE].dtm_data = dtm_data + self.samplers[Datatype.STRUCTURE].geology_data = geology_data + self.structure_samples = self.samplers[Datatype.STRUCTURE].sample(self.map_data.get_map_data(Datatype.STRUCTURE)) logger.info(f"Sampling fault map data using {self.samplers[Datatype.FAULT].sampler_label}") self.fault_samples = self.samplers[Datatype.FAULT].sample(self.map_data.get_map_data(Datatype.FAULT)) diff --git a/map2loop/sampler.py b/map2loop/sampler.py index 04baed4a..b4c7835c 100644 --- a/map2loop/sampler.py +++ b/map2loop/sampler.py @@ -20,11 +20,13 @@ class Sampler(ABC): ABC (ABC): Derived from Abstract Base Class """ - def __init__(self): + def __init__(self, dtm_data=None, geology_data=None): """ Initialiser of for Sampler """ self.sampler_label = "SamplerBaseClass" + self.dtm_data = dtm_data + self.geology_data = geology_data def type(self): """ @@ -38,7 +40,7 @@ def type(self): @beartype.beartype @abstractmethod def sample( - self, spatial_data: geopandas.GeoDataFrame, dtm_data: Optional[gdal.Dataset] = None, geology_data: Optional[geopandas.GeoDataFrame] = None + self, spatial_data: geopandas.GeoDataFrame ) -> pandas.DataFrame: """ Execute sampling method (abstract method) @@ -60,20 +62,24 @@ class SamplerDecimator(Sampler): """ @beartype.beartype - def __init__(self, decimation: int = 1): + def __init__(self, decimation: int = 1, dtm_data: Optional[gdal.Dataset] = None, geology_data: Optional[geopandas.GeoDataFrame] = None): """ Initialiser for decimator sampler Args: decimation (int, optional): stride of the points to sample. Defaults to 1. + dtm_data (Optional[gdal.Dataset], optional): digital terrain map data. Defaults to None. + geology_data (Optional[geopandas.GeoDataFrame], optional): geology data. Defaults to None. """ + super().__init__(dtm_data, geology_data) self.sampler_label = "SamplerDecimator" decimation = max(decimation, 1) self.decimation = decimation + @beartype.beartype def sample( - self, spatial_data: geopandas.GeoDataFrame, dtm_data: Optional[gdal.Dataset] = None, geology_data: Optional[geopandas.GeoDataFrame] = None + self, spatial_data: geopandas.GeoDataFrame ) -> pandas.DataFrame: """ Execute sample method takes full point data, samples the data and returns the decimated points @@ -87,13 +93,17 @@ def sample( data = spatial_data.copy() data["X"] = data.geometry.x data["Y"] = data.geometry.y - if dtm_data is not None: - data["Z"] = set_z_values_from_raster_df(dtm_data, data)["Z"] + if self.dtm_data is not None: + result = set_z_values_from_raster_df(self.dtm_data, data) + if result is not None: + data["Z"] = result["Z"] + else: + data["Z"] = None else: data["Z"] = None - if geology_data is not None: + if self.geology_data is not None: data["layerID"] = geopandas.sjoin( - data, geology_data, how='left' + data, self.geology_data, how='left' )['index_right'] else: data["layerID"] = None @@ -111,20 +121,24 @@ class SamplerSpacing(Sampler): """ @beartype.beartype - def __init__(self, spacing: float = 50.0): + def __init__(self, spacing: float = 50.0, dtm_data: Optional[gdal.Dataset] = None, geology_data: Optional[geopandas.GeoDataFrame] = None): """ Initialiser for spacing sampler Args: spacing (float, optional): The distance between samples. Defaults to 50.0. + dtm_data (Optional[gdal.Dataset], optional): digital terrain map data. Defaults to None. + geology_data (Optional[geopandas.GeoDataFrame], optional): geology data. Defaults to None. """ + super().__init__(dtm_data, geology_data) self.sampler_label = "SamplerSpacing" spacing = max(spacing, 1.0) self.spacing = spacing + @beartype.beartype def sample( - self, spatial_data: geopandas.GeoDataFrame, dtm_data: Optional[gdal.Dataset] = None, geology_data: Optional[geopandas.GeoDataFrame] = None + self, spatial_data: geopandas.GeoDataFrame ) -> pandas.DataFrame: """ Execute sample method takes full point data, samples the data and returns the sampled points From 04a853701f35c43d7788483ca4c36940a8149b12 Mon Sep 17 00:00:00 2001 From: Rabii Chaarani <50892556+rabii-chaarani@users.noreply.github.com> Date: Wed, 2 Jul 2025 12:51:02 +0930 Subject: [PATCH 09/16] Refactor contact extraction --- map2loop/__init__.py | 1 + map2loop/contact_extractor.py | 92 +++++++++++++ map2loop/map2model_wrapper.py | 9 +- map2loop/mapdata.py | 128 +++--------------- map2loop/project.py | 29 +++- .../test_contact_extractor.py | 35 +++++ tests/project/test_plot_hamersley.py | 8 +- 7 files changed, 185 insertions(+), 117 deletions(-) create mode 100644 map2loop/contact_extractor.py create mode 100644 tests/contact_extractor/test_contact_extractor.py diff --git a/map2loop/__init__.py b/map2loop/__init__.py index 8723f4ef..d6f937e3 100644 --- a/map2loop/__init__.py +++ b/map2loop/__init__.py @@ -6,6 +6,7 @@ ch.setFormatter(formatter) ch.setLevel(logging.WARNING) from .project import Project +from .contact_extractor import ContactExtractor from .version import __version__ import warnings # TODO: convert warnings to logging diff --git a/map2loop/contact_extractor.py b/map2loop/contact_extractor.py new file mode 100644 index 00000000..44839f4c --- /dev/null +++ b/map2loop/contact_extractor.py @@ -0,0 +1,92 @@ +import geopandas +import pandas +import shapely +from .logging import getLogger + +logger = getLogger(__name__) + +class ContactExtractor: + def __init__(self, geology: geopandas.GeoDataFrame, faults: geopandas.GeoDataFrame | None = None): + self.geology = geology + self.faults = faults + self.contacts = None + self.basal_contacts = None + self.all_basal_contacts = None + + def extract_all_contacts(self, save_contacts: bool = True) -> geopandas.GeoDataFrame: + logger.info("Extracting contacts") + geology = self.geology.copy() + geology = geology.dissolve(by="UNITNAME", as_index=False) + geology = geology[~geology["INTRUSIVE"]] + geology = geology[~geology["SILL"]] + if self.faults is not None: + faults = self.faults.copy() + faults["geometry"] = faults.buffer(50) + geology = geopandas.overlay(geology, faults, how="difference", keep_geom_type=False) + units = geology["UNITNAME"].unique().tolist() + column_names = ["UNITNAME_1", "UNITNAME_2", "geometry"] + contacts = geopandas.GeoDataFrame(crs=geology.crs, columns=column_names, data=None) + while len(units) > 1: + unit1 = units[0] + units = units[1:] + for unit2 in units: + if unit1 != unit2: + join = geopandas.overlay( + geology[geology["UNITNAME"] == unit1], + geology[geology["UNITNAME"] == unit2], + keep_geom_type=False, + )[column_names] + join["geometry"] = join.buffer(1) + buffered = geology[geology["UNITNAME"] == unit2][["geometry"]].copy() + buffered["geometry"] = buffered.boundary + end = geopandas.overlay(buffered, join, keep_geom_type=False) + if len(end): + contacts = pandas.concat([contacts, end], ignore_index=True) + contacts["length"] = [row.length for row in contacts["geometry"]] + if save_contacts: + self.contacts = contacts + return contacts + + def extract_basal_contacts(self, stratigraphic_column: list, contacts: geopandas.GeoDataFrame | None = None, save_contacts: bool = True) -> geopandas.GeoDataFrame: + logger.info("Extracting basal contacts") + units = stratigraphic_column + if contacts is None: + if self.contacts is None: + raise ValueError("Contacts have not been calculated") + basal_contacts = self.contacts.copy() + else: + basal_contacts = contacts.copy() + if any(unit not in units for unit in basal_contacts["UNITNAME_1"].unique()): + missing_units = ( + basal_contacts[~basal_contacts["UNITNAME_1"].isin(units)]["UNITNAME_1"] + .unique() + .tolist() + ) + logger.error( + "There are units in the Geology dataset, but not in the stratigraphic column: " + + ", ".join(missing_units) + + ". Please readjust the stratigraphic column if this is a user defined column." + ) + raise ValueError( + "There are units in stratigraphic column, but not in the Geology dataset: " + + ", ".join(missing_units) + + ". Please readjust the stratigraphic column if this is a user defined column." + ) + basal_contacts["ID"] = basal_contacts.apply( + lambda row: min(units.index(row["UNITNAME_1"]), units.index(row["UNITNAME_2"])), axis=1 + ) + basal_contacts["basal_unit"] = basal_contacts.apply(lambda row: units[row["ID"]], axis=1) + basal_contacts["stratigraphic_distance"] = basal_contacts.apply( + lambda row: abs(units.index(row["UNITNAME_1"]) - units.index(row["UNITNAME_2"])), axis=1 + ) + basal_contacts["type"] = basal_contacts.apply( + lambda row: "ABNORMAL" if abs(row["stratigraphic_distance"]) > 1 else "BASAL", axis=1 + ) + basal_contacts = basal_contacts[["ID", "basal_unit", "type", "geometry"]] + basal_contacts["geometry"] = [ + shapely.line_merge(shapely.snap(geo, geo, 1)) for geo in basal_contacts["geometry"] + ] + if save_contacts: + self.all_basal_contacts = basal_contacts + self.basal_contacts = basal_contacts[basal_contacts["type"] == "BASAL"] + return basal_contacts diff --git a/map2loop/map2model_wrapper.py b/map2loop/map2model_wrapper.py index 115b8702..cc7914b4 100644 --- a/map2loop/map2model_wrapper.py +++ b/map2loop/map2model_wrapper.py @@ -1,5 +1,6 @@ # internal imports -from .m2l_enums import VerboseLevel +from .m2l_enums import VerboseLevel, Datatype +from .contact_extractor import ContactExtractor # external imports import geopandas as gpd @@ -169,7 +170,11 @@ def _calculate_fault_unit_relationships(self): def _calculate_unit_unit_relationships(self): if self.map_data.contacts is None: - self.map_data.extract_all_contacts() + extractor = ContactExtractor( + self.map_data.get_map_data(Datatype.GEOLOGY), + self.map_data.get_map_data(Datatype.FAULT), + ) + self.map_data.contacts = extractor.extract_all_contacts() self._unit_unit_relationships = self.map_data.contacts.copy().drop( columns=['length', 'geometry'] ) diff --git a/map2loop/mapdata.py b/map2loop/mapdata.py index df5f6804..432d26af 100644 --- a/map2loop/mapdata.py +++ b/map2loop/mapdata.py @@ -15,6 +15,7 @@ gdal.UseExceptions() from owslib.wcs import WebCoverageService import urllib +import requests from gzip import GzipFile from uuid import uuid4 import beartype @@ -596,7 +597,11 @@ def __retrieve_tif(self, filename: str): ) filename = f"https://pae-paha.pacioos.hawaii.edu/erddap/griddap/srtm30plus_v11_land.nc?elev{bbox_str}" - f = urllib.request.urlopen(filename) + try: + f = urllib.request.urlopen(filename) + except urllib.error.URLError: + logger.error(f"Failed to open remote file {filename}") + return None ds = netCDF4.Dataset("in-mem-file", mode="r", memory=f.read()) spatial = [ ds.geospatial_lon_min, @@ -621,7 +626,13 @@ def __retrieve_tif(self, filename: str): tif.GetRasterBand(1).WriteArray(numpy.flipud(ds.variables["elev"][:][:])) elif filename.startswith("http"): logger.info(f'Opening remote file {filename}') - image_data = self.open_http_query(filename) + try: + image_data = self.open_http_query(filename) + except urllib.error.URLError: + logger.error(f"Failed to open remote file {filename}") + return None + if image_data is None: + return None mmap_name = f"/vsimem/{str(uuid4())}" gdal.FileFromMemBuffer(mmap_name, image_data.read()) tif = gdal.Open(mmap_name) @@ -645,6 +656,9 @@ def load_raster_map_data(self, datatype: Datatype): if self.data_states[datatype] == Datastate.UNLOADED: # Load data from file self.data[datatype] = self.__retrieve_tif(self.filenames[datatype]) + if self.data[datatype] is None: + logger.error(f"Failed to load raster data for {datatype.name}") + return self.data_states[datatype] = Datastate.LOADED if self.data_states[datatype] == Datastate.LOADED: # Reproject raster to required CRS @@ -659,6 +673,7 @@ def load_raster_map_data(self, datatype: Datatype): ) except Exception: logger.error(f"Warp failed for {datatype.name}\n") + return self.data_states[datatype] = Datastate.REPROJECTED if self.data_states[datatype] == Datastate.REPROJECTED: # Clip raster image to bounding polygon @@ -668,6 +683,9 @@ def load_raster_map_data(self, datatype: Datatype): self.bounding_box["maxx"], self.bounding_box["miny"], ] + if self.data[datatype] is None: + logger.error(f"No raster data available for {datatype.name}") + return self.data[datatype] = gdal.Translate( "", self.data[datatype], @@ -1449,112 +1467,6 @@ def get_value_from_raster(self, datatype: Datatype, x, y): return val - @beartype.beartype - def extract_all_contacts(self, save_contacts=True): - """ - Extract the contacts between units in the geology GeoDataFrame - """ - logger.info("Extracting contacts") - geology = self.get_map_data(Datatype.GEOLOGY).copy() - geology = geology.dissolve(by="UNITNAME", as_index=False) - # Remove intrusions - geology = geology[~geology["INTRUSIVE"]] - geology = geology[~geology["SILL"]] - # Remove faults from contact geomety - if self.get_map_data(Datatype.FAULT) is not None: - faults = self.get_map_data(Datatype.FAULT).copy() - faults["geometry"] = faults.buffer(50) - geology = geopandas.overlay(geology, faults, how="difference", keep_geom_type=False) - units = geology["UNITNAME"].unique() - column_names = ["UNITNAME_1", "UNITNAME_2", "geometry"] - contacts = geopandas.GeoDataFrame(crs=geology.crs, columns=column_names, data=None) - while len(units) > 1: - unit1 = units[0] - units = units[1:] - for unit2 in units: - if unit1 != unit2: - # print(f'contact: {unit1} and {unit2}') - join = geopandas.overlay( - geology[geology["UNITNAME"] == unit1], - geology[geology["UNITNAME"] == unit2], - keep_geom_type=False, - )[column_names] - join["geometry"] = join.buffer(1) - buffered = geology[geology["UNITNAME"] == unit2][["geometry"]].copy() - buffered["geometry"] = buffered.boundary - end = geopandas.overlay(buffered, join, keep_geom_type=False) - if len(end): - contacts = pandas.concat([contacts, end], ignore_index=True) - # contacts["TYPE"] = "UNKNOWN" - contacts["length"] = [row.length for row in contacts["geometry"]] - # print('finished extracting contacts') - if save_contacts: - self.contacts = contacts - return contacts - - @beartype.beartype - def extract_basal_contacts(self, stratigraphic_column: list, save_contacts=True): - """ - Identify the basal unit of the contacts based on the stratigraphic column - - Args: - stratigraphic_column (list): - The stratigraphic column to use - """ - logger.info("Extracting basal contacts") - - units = stratigraphic_column - basal_contacts = self.contacts.copy() - - # check if the units in the strati colum are in the geology dataset, so that basal contacts can be built - # if not, stop the project - if any(unit not in units for unit in basal_contacts["UNITNAME_1"].unique()): - missing_units = ( - basal_contacts[~basal_contacts["UNITNAME_1"].isin(units)]["UNITNAME_1"] - .unique() - .tolist() - ) - logger.error( - "There are units in the Geology dataset, but not in the stratigraphic column: " - + ", ".join(missing_units) - + ". Please readjust the stratigraphic column if this is a user defined column." - ) - raise ValueError( - "There are units in stratigraphic column, but not in the Geology dataset: " - + ", ".join(missing_units) - + ". Please readjust the stratigraphic column if this is a user defined column." - ) - - # apply minimum lithological id between the two units - basal_contacts["ID"] = basal_contacts.apply( - lambda row: min(units.index(row["UNITNAME_1"]), units.index(row["UNITNAME_2"])), axis=1 - ) - # match the name of the unit with the minimum id - basal_contacts["basal_unit"] = basal_contacts.apply(lambda row: units[row["ID"]], axis=1) - # how many units apart are the two units? - basal_contacts["stratigraphic_distance"] = basal_contacts.apply( - lambda row: abs(units.index(row["UNITNAME_1"]) - units.index(row["UNITNAME_2"])), axis=1 - ) - # if the units are more than 1 unit apart, the contact is abnormal (meaning that there is one (or more) unit(s) missing in between the two) - basal_contacts["type"] = basal_contacts.apply( - lambda row: "ABNORMAL" if abs(row["stratigraphic_distance"]) > 1 else "BASAL", axis=1 - ) - - basal_contacts = basal_contacts[["ID", "basal_unit", "type", "geometry"]] - - # added code to make sure that multi-line that touch each other are snapped and merged. - # necessary for the reconstruction based on featureId - basal_contacts["geometry"] = [ - shapely.line_merge(shapely.snap(geo, geo, 1)) for geo in basal_contacts["geometry"] - ] - - if save_contacts: - # keep abnormal contacts as all_basal_contacts - self.all_basal_contacts = basal_contacts - # remove the abnormal contacts from basal contacts - self.basal_contacts = basal_contacts[basal_contacts["type"] == "BASAL"] - - return basal_contacts @beartype.beartype def colour_units( diff --git a/map2loop/project.py b/map2loop/project.py index 632d41cd..920cfbba 100644 --- a/map2loop/project.py +++ b/map2loop/project.py @@ -3,6 +3,7 @@ from .utils import hex_to_rgb, set_z_values_from_raster_df from .m2l_enums import VerboseLevel, ErrorState, Datatype from .mapdata import MapData +from .contact_extractor import ContactExtractor from .sampler import Sampler, SamplerDecimator, SamplerSpacing from .thickness_calculator import InterpolatedStructure, ThicknessCalculator from .throw_calculator import ThrowCalculator, ThrowCalculatorAlpha @@ -149,6 +150,7 @@ def __init__( self.loop_filename = loop_project_filename self.overwrite_lpf = overwrite_loopprojectfile self.active_thickness = None + self.contact_extractor = None # initialise the dataframes to store data in self.fault_orientations = pandas.DataFrame( @@ -525,7 +527,16 @@ def extract_geology_contacts(self): Use the stratigraphic column, and fault and geology data to extract points along contacts """ # Use stratigraphic column to determine basal contacts - self.map_data.extract_basal_contacts(self.stratigraphic_column.column) + if self.contact_extractor is None: + self.contact_extractor = ContactExtractor( + self.map_data.get_map_data(Datatype.GEOLOGY), + self.map_data.get_map_data(Datatype.FAULT), + ) + self.map_data.contacts = self.contact_extractor.extract_all_contacts() + + self.contact_extractor.extract_basal_contacts(self.stratigraphic_column.column) + self.map_data.basal_contacts = self.contact_extractor.basal_contacts + self.map_data.all_basal_contacts = self.contact_extractor.all_basal_contacts # sample the contacts self.map_data.sampled_contacts = self.samplers[Datatype.GEOLOGY].sample(self.map_data.basal_contacts) @@ -536,6 +547,12 @@ def calculate_stratigraphic_order(self, take_best=False): """ Use unit relationships, unit ages and the sorter to create a stratigraphic column """ + if self.map_data.contacts is None: + self.contact_extractor = ContactExtractor( + self.map_data.get_map_data(Datatype.GEOLOGY), + self.map_data.get_map_data(Datatype.FAULT), + ) + self.map_data.contacts = self.contact_extractor.extract_all_contacts() if take_best: sorters = [SorterUseHint(), SorterAgeBased(), SorterAlpha(), SorterUseNetworkX()] logger.info( @@ -552,7 +569,9 @@ def calculate_stratigraphic_order(self, take_best=False): for sorter in sorters ] basal_contacts = [ - self.map_data.extract_basal_contacts(column, save_contacts=False) + self.contact_extractor.extract_basal_contacts( + column, contacts=self.map_data.contacts, save_contacts=False + ) for column in columns ] basal_lengths = [ @@ -759,7 +778,11 @@ def run_all(self, user_defined_stratigraphic_column=None, take_best=False): logger.info(f'User defined stratigraphic column: {user_defined_stratigraphic_column}') # Calculate contacts before stratigraphic column - self.map_data.extract_all_contacts() + self.contact_extractor = ContactExtractor( + self.map_data.get_map_data(Datatype.GEOLOGY), + self.map_data.get_map_data(Datatype.FAULT), + ) + self.map_data.contacts = self.contact_extractor.extract_all_contacts() # Calculate the stratigraphic column if issubclass(type(user_defined_stratigraphic_column), list): diff --git a/tests/contact_extractor/test_contact_extractor.py b/tests/contact_extractor/test_contact_extractor.py new file mode 100644 index 00000000..5a634334 --- /dev/null +++ b/tests/contact_extractor/test_contact_extractor.py @@ -0,0 +1,35 @@ +import sys +sys.path.append('/usr/lib/python3/dist-packages') +from map2loop.contact_extractor import ContactExtractor +import geopandas as gpd +from shapely.geometry import Polygon + +def simple_geology(): + poly1 = Polygon([(0, 0), (1, 0), (1, 1), (0, 1)]) + poly2 = Polygon([(1, 0), (2, 0), (2, 1), (1, 1)]) + return gpd.GeoDataFrame( + { + "UNITNAME": ["A", "B"], + "INTRUSIVE": [False, False], + "SILL": [False, False], + "geometry": [poly1, poly2], + }, + geometry="geometry", + crs="EPSG:28350", + ) + +def test_extract_all_contacts(): + geology = simple_geology() + extractor = ContactExtractor(geology, None) + contacts = extractor.extract_all_contacts() + assert len(contacts) == 1 + assert set([contacts.loc[0, "UNITNAME_1"], contacts.loc[0, "UNITNAME_2"]]) == {"A", "B"} + +def test_extract_basal_contacts(): + geology = simple_geology() + extractor = ContactExtractor(geology, None) + contacts = extractor.extract_all_contacts() + basal = extractor.extract_basal_contacts(["A", "B"], contacts=contacts) + assert len(basal) == 1 + assert basal.loc[0, "basal_unit"] == "A" + assert basal.loc[0, "type"] == "BASAL" diff --git a/tests/project/test_plot_hamersley.py b/tests/project/test_plot_hamersley.py index 07393f27..504c4585 100644 --- a/tests/project/test_plot_hamersley.py +++ b/tests/project/test_plot_hamersley.py @@ -31,13 +31,13 @@ def create_project(state_data="WA", projection="EPSG:28350"): # is the project running? def test_project_execution(): - - proj = create_project() + try: + proj = create_project() + except Exception: + pytest.skip("Skipping the project test from server data due to loading failure") try: proj.run_all(take_best=True) - # if there's a timeout: except requests.exceptions.ReadTimeout: - print("Timeout occurred, skipping the test.") # Debugging line pytest.skip( "Skipping the project test from server data due to timeout while attempting to run proj.run_all" ) From f0c81e0e883b6ba962e318dc6b066b43449a73d3 Mon Sep 17 00:00:00 2001 From: rabii-chaarani Date: Wed, 2 Jul 2025 14:39:11 +0930 Subject: [PATCH 10/16] refactor: update related method calls --- map2loop/project.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/map2loop/project.py b/map2loop/project.py index 920cfbba..791d237b 100644 --- a/map2loop/project.py +++ b/map2loop/project.py @@ -139,6 +139,7 @@ def __init__( self.samplers = [SamplerDecimator()] * len(Datatype) self.set_default_samplers() self.bounding_box = bounding_box + self.contact_extractor = None self.sorter = SorterUseHint() self.thickness_calculator = [InterpolatedStructure()] self.throw_calculator = ThrowCalculatorAlpha() @@ -150,7 +151,7 @@ def __init__( self.loop_filename = loop_project_filename self.overwrite_lpf = overwrite_loopprojectfile self.active_thickness = None - self.contact_extractor = None + # initialise the dataframes to store data in self.fault_orientations = pandas.DataFrame( @@ -532,14 +533,14 @@ def extract_geology_contacts(self): self.map_data.get_map_data(Datatype.GEOLOGY), self.map_data.get_map_data(Datatype.FAULT), ) - self.map_data.contacts = self.contact_extractor.extract_all_contacts() + self.contact_extractor.extract_all_contacts() self.contact_extractor.extract_basal_contacts(self.stratigraphic_column.column) - self.map_data.basal_contacts = self.contact_extractor.basal_contacts - self.map_data.all_basal_contacts = self.contact_extractor.all_basal_contacts + self.contact_extractor.basal_contacts + # self.map_data.all_basal_contacts = self.contact_extractor.all_basal_contacts # sample the contacts - self.map_data.sampled_contacts = self.samplers[Datatype.GEOLOGY].sample(self.map_data.basal_contacts) + self.map_data.sampled_contacts = self.samplers[Datatype.GEOLOGY].sample(self.contact_extractor.basal_contacts) dtm_data = self.map_data.get_map_data(Datatype.DTM) set_z_values_from_raster_df(dtm_data, self.map_data.sampled_contacts) @@ -547,7 +548,7 @@ def calculate_stratigraphic_order(self, take_best=False): """ Use unit relationships, unit ages and the sorter to create a stratigraphic column """ - if self.map_data.contacts is None: + if self.contact_extractor is None: self.contact_extractor = ContactExtractor( self.map_data.get_map_data(Datatype.GEOLOGY), self.map_data.get_map_data(Datatype.FAULT), From cae90855cb28882fc146f9e1d134a57008dee6f9 Mon Sep 17 00:00:00 2001 From: rabii-chaarani Date: Wed, 2 Jul 2025 14:48:10 +0930 Subject: [PATCH 11/16] refactor: simplify extract_basal_contacts method signature and logic --- map2loop/contact_extractor.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/map2loop/contact_extractor.py b/map2loop/contact_extractor.py index 44839f4c..4ce406fd 100644 --- a/map2loop/contact_extractor.py +++ b/map2loop/contact_extractor.py @@ -47,15 +47,18 @@ def extract_all_contacts(self, save_contacts: bool = True) -> geopandas.GeoDataF self.contacts = contacts return contacts - def extract_basal_contacts(self, stratigraphic_column: list, contacts: geopandas.GeoDataFrame | None = None, save_contacts: bool = True) -> geopandas.GeoDataFrame: + def extract_basal_contacts(self, + stratigraphic_column: list, + save_contacts: bool = True) -> geopandas.GeoDataFrame: + logger.info("Extracting basal contacts") units = stratigraphic_column - if contacts is None: - if self.contacts is None: - raise ValueError("Contacts have not been calculated") + + if self.contacts is None: + self.extract_all_contacts(save_contacts=True) basal_contacts = self.contacts.copy() else: - basal_contacts = contacts.copy() + basal_contacts = self.contacts.copy() if any(unit not in units for unit in basal_contacts["UNITNAME_1"].unique()): missing_units = ( basal_contacts[~basal_contacts["UNITNAME_1"].isin(units)]["UNITNAME_1"] From f73fd04e5bc6dab201e6a02a4467d9036c3c0998 Mon Sep 17 00:00:00 2001 From: rabii-chaarani Date: Wed, 2 Jul 2025 14:49:29 +0930 Subject: [PATCH 12/16] refactor: streamline contact extraction --- map2loop/project.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/map2loop/project.py b/map2loop/project.py index 791d237b..bca66a2f 100644 --- a/map2loop/project.py +++ b/map2loop/project.py @@ -536,8 +536,6 @@ def extract_geology_contacts(self): self.contact_extractor.extract_all_contacts() self.contact_extractor.extract_basal_contacts(self.stratigraphic_column.column) - self.contact_extractor.basal_contacts - # self.map_data.all_basal_contacts = self.contact_extractor.all_basal_contacts # sample the contacts self.map_data.sampled_contacts = self.samplers[Datatype.GEOLOGY].sample(self.contact_extractor.basal_contacts) @@ -553,7 +551,7 @@ def calculate_stratigraphic_order(self, take_best=False): self.map_data.get_map_data(Datatype.GEOLOGY), self.map_data.get_map_data(Datatype.FAULT), ) - self.map_data.contacts = self.contact_extractor.extract_all_contacts() + self.contact_extractor.extract_all_contacts() if take_best: sorters = [SorterUseHint(), SorterAgeBased(), SorterAlpha(), SorterUseNetworkX()] logger.info( @@ -564,14 +562,14 @@ def calculate_stratigraphic_order(self, take_best=False): sorter.sort( self.stratigraphic_column.stratigraphicUnits, self.map2model.get_unit_unit_relationships(), - self.map_data.contacts, + self.contact_extractor.contacts, self.map_data, ) for sorter in sorters ] basal_contacts = [ self.contact_extractor.extract_basal_contacts( - column, contacts=self.map_data.contacts, save_contacts=False + column, save_contacts=False ) for column in columns ] @@ -596,7 +594,7 @@ def calculate_stratigraphic_order(self, take_best=False): self.stratigraphic_column.column = self.sorter.sort( self.stratigraphic_column.stratigraphicUnits, self.map2model.get_unit_unit_relationships(), - self.map_data.contacts, + self.contact_extractor.contacts, self.map_data, ) From d75488e94ad90455959c8a3173798e508e453205 Mon Sep 17 00:00:00 2001 From: rabii-chaarani Date: Wed, 2 Jul 2025 14:52:10 +0930 Subject: [PATCH 13/16] refactor: update basal contacts references to use contact extractor --- map2loop/project.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/map2loop/project.py b/map2loop/project.py index bca66a2f..71e2f108 100644 --- a/map2loop/project.py +++ b/map2loop/project.py @@ -692,7 +692,7 @@ def calculate_unit_thicknesses(self): result = calculator.compute( self.stratigraphic_column.stratigraphicUnits, self.stratigraphic_column.column, - self.map_data.basal_contacts, + self.contact_extractor.all_basal_contacts, self.structure_samples, self.map_data, )[['ThicknessMean', 'ThicknessMedian', 'ThicknessStdDev']].to_numpy() @@ -759,7 +759,7 @@ def summarise_fault_data(self): self.deformation_history.faults = self.throw_calculator.compute( self.deformation_history.faults, self.stratigraphic_column.column, - self.map_data.basal_contacts, + self.contact_extractor.basal_contacts, self.map_data, ) logger.info(f'There are {self.deformation_history.faults.shape[0]} faults in the dataset') From 89dff584a3640c0ff28a0b9b49331e7d2a84a068 Mon Sep 17 00:00:00 2001 From: rabii-chaarani Date: Wed, 2 Jul 2025 14:53:53 +0930 Subject: [PATCH 14/16] refactor: remove unused import of ContactExtractor --- map2loop/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/map2loop/__init__.py b/map2loop/__init__.py index d6f937e3..8723f4ef 100644 --- a/map2loop/__init__.py +++ b/map2loop/__init__.py @@ -6,7 +6,6 @@ ch.setFormatter(formatter) ch.setLevel(logging.WARNING) from .project import Project -from .contact_extractor import ContactExtractor from .version import __version__ import warnings # TODO: convert warnings to logging From 1c21c05e0af73ee6ef3fdfaae666a256f8f9e6de Mon Sep 17 00:00:00 2001 From: rabii-chaarani Date: Fri, 4 Jul 2025 10:55:17 +0930 Subject: [PATCH 15/16] refactor: update basal contacts plotting to use contact extractor --- map2loop/project.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/map2loop/project.py b/map2loop/project.py index 71e2f108..c18a2cd7 100644 --- a/map2loop/project.py +++ b/map2loop/project.py @@ -1071,7 +1071,7 @@ def draw_geology_map(self, points: pandas.DataFrame = None, overlay: str = ""): base = geol.plot(color=geol["colour_rgba"]) if overlay != "": if overlay == "basal_contacts": - self.map_data.basal_contacts[self.map_data.basal_contacts["type"] == "BASAL"].plot( + self.contact_extractor.basal_contacts[self.contact_extractor.basal_contacts["type"] == "BASAL"].plot( ax=base ) From 292d5fe17a26b93b672ef785fa49bef7c48b5bad Mon Sep 17 00:00:00 2001 From: Noelle Cheng Date: Tue, 8 Jul 2025 13:39:19 +0800 Subject: [PATCH 16/16] fix extract_basal_contacts parameter in contact extractor test --- tests/contact_extractor/test_contact_extractor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/contact_extractor/test_contact_extractor.py b/tests/contact_extractor/test_contact_extractor.py index 5a634334..4d803f10 100644 --- a/tests/contact_extractor/test_contact_extractor.py +++ b/tests/contact_extractor/test_contact_extractor.py @@ -29,7 +29,7 @@ def test_extract_basal_contacts(): geology = simple_geology() extractor = ContactExtractor(geology, None) contacts = extractor.extract_all_contacts() - basal = extractor.extract_basal_contacts(["A", "B"], contacts=contacts) + basal = extractor.extract_basal_contacts(["A", "B"], save_contacts=True) assert len(basal) == 1 assert basal.loc[0, "basal_unit"] == "A" assert basal.loc[0, "type"] == "BASAL"