From 716610a0ce917ad2d7ae4a40a68e1558cafe6bfa Mon Sep 17 00:00:00 2001 From: Miguel de la Varga Date: Wed, 29 Oct 2025 11:53:39 +0100 Subject: [PATCH 1/4] Allow optional likelihood function in `make_gempy_pyro_model` and add Python version assertion in `__init__.py`. --- gempy_probability/__init__.py | 22 ++++++++++++++++++- .../model_definition/prob_model_factory.py | 7 ++++-- 2 files changed, 26 insertions(+), 3 deletions(-) diff --git a/gempy_probability/__init__.py b/gempy_probability/__init__.py index a695542..f0baa37 100644 --- a/gempy_probability/__init__.py +++ b/gempy_probability/__init__.py @@ -2,4 +2,24 @@ from .modules import likelihoods from .api.model_runner import run_predictive, run_mcmc_for_NUTS, run_nuts_inference -from ._version import __version__ +""" +Module initialisation for GemPy Probability +""" +import sys + +# * Assert at least python 3.10 +assert sys.version_info[0] >= 3 and sys.version_info[1] >= 10, "GemPy Probability requires Python 3.10 or higher" + +# Import version, with fallback if not generated yet +try: + from ._version import __version__ +except ImportError: + __version__ = "unknown" + +# =================== CORE =================== +# Import your core modules here + +# =================== API =================== + +if __name__ == '__main__': + pass diff --git a/gempy_probability/modules/model_definition/prob_model_factory.py b/gempy_probability/modules/model_definition/prob_model_factory.py index eb99915..e1b414a 100644 --- a/gempy_probability/modules/model_definition/prob_model_factory.py +++ b/gempy_probability/modules/model_definition/prob_model_factory.py @@ -1,7 +1,7 @@ import pyro import torch from pyro.distributions import Distribution -from typing import Callable, Dict +from typing import Callable, Dict, Optional import gempy as gp from gempy_engine.core.backend_tensor import BackendTensor @@ -17,7 +17,7 @@ def make_gempy_pyro_model( [Dict[str, Distribution], gp.data.GeoModel], gp.data.InterpolationInput ], - likelihood_fn: Callable[[gp.data.Solutions], Distribution], + likelihood_fn: Optional[Callable[[gp.data.Solutions], Distribution]], obs_name: str = "obs" ) -> GemPyPyroModel: """ @@ -115,6 +115,9 @@ def model(geo_model: gp.data.GeoModel, obs_data: torch.Tensor): ) # 4) Wrap in likelihood & observe + if likelihood_fn is None: + return + lik_dist = likelihood_fn(simulated) pyro.sample(obs_name, lik_dist, obs=obs_data) From 17ac3fd867bb2cbd8aebeff32ba6a7363dd79a90 Mon Sep 17 00:00:00 2001 From: Miguel de la Varga Date: Wed, 29 Oct 2025 13:58:04 +0100 Subject: [PATCH 2/4] Format code for PEP 8 compliance in `_pyro_runner.py`. --- gempy_probability/api/model_runner/_pyro_runner.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/gempy_probability/api/model_runner/_pyro_runner.py b/gempy_probability/api/model_runner/_pyro_runner.py index 1133957..9841856 100644 --- a/gempy_probability/api/model_runner/_pyro_runner.py +++ b/gempy_probability/api/model_runner/_pyro_runner.py @@ -11,8 +11,8 @@ from ...core.samplers_data import NUTSConfig -def run_predictive(prob_model: gpp.GemPyPyroModel, geo_model: gp.data.GeoModel, - y_obs_list: torch.Tensor, n_samples: int, plot_trace:bool=False) -> az.InferenceData: +def run_predictive(prob_model: gpp.GemPyPyroModel, geo_model: gp.data.GeoModel, + y_obs_list: torch.Tensor, n_samples: int, plot_trace: bool = False) -> az.InferenceData: predictive = Predictive( model=prob_model, num_samples=n_samples @@ -24,13 +24,13 @@ def run_predictive(prob_model: gpp.GemPyPyroModel, geo_model: gp.data.GeoModel, if plot_trace: az.plot_trace(data.prior) plt.show() - + return data -def run_nuts_inference(prob_model: gpp.GemPyPyroModel, geo_model: gp.data.GeoModel, - y_obs_list: torch.Tensor, config: NUTSConfig, plot_trace:bool=False, - run_posterior_predictive:bool=False) -> az.InferenceData: - + +def run_nuts_inference(prob_model: gpp.GemPyPyroModel, geo_model: gp.data.GeoModel, + y_obs_list: torch.Tensor, config: NUTSConfig, plot_trace: bool = False, + run_posterior_predictive: bool = False) -> az.InferenceData: nuts_kernel = NUTS( prob_model, step_size=config.step_size, From 755c89df2f260288f5cda3700fe3299203d91a46 Mon Sep 17 00:00:00 2001 From: Miguel de la Varga Date: Wed, 29 Oct 2025 14:47:57 +0100 Subject: [PATCH 3/4] Add `plot_gempy` function for probabilistic model visualization Introduced a utility function, `plot_gempy`, to streamline the visualization of GemPy models with uncertainty from prior/posterior samples. Refactored test code to utilize the new function, improving modularity and reducing redundancy. Added an adjustable update function for greater flexibility. --- gempy_probability/modules/plot/plot_gempy.py | 91 ++++++++++++++++++++ 1 file changed, 91 insertions(+) create mode 100644 gempy_probability/modules/plot/plot_gempy.py diff --git a/gempy_probability/modules/plot/plot_gempy.py b/gempy_probability/modules/plot/plot_gempy.py new file mode 100644 index 0000000..95f6954 --- /dev/null +++ b/gempy_probability/modules/plot/plot_gempy.py @@ -0,0 +1,91 @@ +from typing import Callable, Optional + +import gempy as gp +import gempy_viewer as gpv +import numpy as np +from gempy_viewer.modules.plot_2d.visualization_2d import Plot2D + + +def plot_gempy( + geo_model, # gp.data.GeoModel - avoiding import + n_samples: int, + samples: np.ndarray, + update_model_fn: Callable, + gempy_plot: Plot2D, + plot_kwargs: Optional[dict] = None +): + """ + General function to plot GemPy models with uncertainty from prior/posterior samples. + + Parameters + ---------- + geo_model : gp.data.GeoModel + The geological model to update and plot + n_samples : int + Number of samples to plot + samples : np.ndarray + Array of sample values to iterate through + update_model_fn : Callable + Function that takes (geo_model, sample_value, sample_idx) and updates the model. + Should return None and modify geo_model in place. + gempy_plot : Plot2D + GemPy Plot2D object containing the figure and section data to plot on + plot_kwargs : dict, optional + Additional plotting kwargs for boundaries, surface points, etc. + + Examples + -------- + >>> def update_model_fn(geo_model, sample_value, sample_idx): + ... # Transform sample value to world coordinates + ... xyz = np.zeros((1, 3)) + ... xyz[0, 2] = sample_value + ... world_coord = geo_model.input_transform.apply_inverse(xyz) + ... # Modify surface point + ... gp.modify_surface_points(geo_model, slice=0, Z=world_coord[0, 2]) + >>> + >>> p2d = gpv.plot_2d(geo_model, show_lith=False, show_data=False, show=False) + >>> samples = prior_inference_data.prior['$\\mu_{top}$'].values[0, :] + >>> plot_gempy(geo_model, n_samples=50, samples=samples, + ... update_model_fn=update_model_fn, gempy_plot=p2d) + """ + # Import here to avoid circular dependencies and to make gempy optional + import gempy as gp + from gempy_viewer.API._plot_2d_sections_api import plot_sections + from gempy_viewer.core.data_to_show import DataToShow + + plot_kwargs = plot_kwargs or {} + + # Iterate through samples + for i in np.linspace(0, n_samples - 1, n_samples).astype(int): + # Update model using the provided function + update_model_fn(geo_model, samples[i], i) + + # Compute the model + gp.compute_model(gempy_model=geo_model) + + # Plot the updated model + default_plot_kwargs = { + 'kwargs_boundaries': { + "linewidth": 0.5, + "alpha": 0.1, + }, + 'kwargs_surface_points': { + 'alpha': 0.1 + }, + } + # Merge with user-provided kwargs (user kwargs override defaults) + final_plot_kwargs = {**default_plot_kwargs, **plot_kwargs} + + plot_sections( + gempy_model=geo_model, + sections_data=gempy_plot.section_data_list, + data_to_show=DataToShow( + n_axis=1, + show_data=True, + show_surfaces=True, + show_lith=False + ), + **final_plot_kwargs + ) + + gempy_plot.fig.show() \ No newline at end of file From 589a9f988ae8307be8d6989450c2ef4c67b0de13 Mon Sep 17 00:00:00 2001 From: Miguel de la Varga Date: Wed, 29 Oct 2025 16:56:57 +0100 Subject: [PATCH 4/4] Add functionality to test error propagation for dips Introduced a new `TestErrorPropagationDips` class to evaluate error propagation related to orientation dips in geological models. Added methods for modifying dips, updating models, and performing prior inference with Pyro, ensuring proper handling of dip gradients and validations. --- gempy_probability/modules/plot/plot_gempy.py | 21 +++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/gempy_probability/modules/plot/plot_gempy.py b/gempy_probability/modules/plot/plot_gempy.py index 95f6954..19027c1 100644 --- a/gempy_probability/modules/plot/plot_gempy.py +++ b/gempy_probability/modules/plot/plot_gempy.py @@ -52,7 +52,7 @@ def plot_gempy( import gempy as gp from gempy_viewer.API._plot_2d_sections_api import plot_sections from gempy_viewer.core.data_to_show import DataToShow - + plot_kwargs = plot_kwargs or {} # Iterate through samples @@ -65,13 +65,16 @@ def plot_gempy( # Plot the updated model default_plot_kwargs = { - 'kwargs_boundaries': { - "linewidth": 0.5, - "alpha": 0.1, - }, - 'kwargs_surface_points': { - 'alpha': 0.1 - }, + 'kwargs_boundaries' : { + "linewidth": 0.5, + "alpha" : 0.1, + }, + 'kwargs_surface_points': { + 'alpha': 0.1 + }, + 'kwargs_orientations' : { + 'alpha': 0.02, + } } # Merge with user-provided kwargs (user kwargs override defaults) final_plot_kwargs = {**default_plot_kwargs, **plot_kwargs} @@ -88,4 +91,4 @@ def plot_gempy( **final_plot_kwargs ) - gempy_plot.fig.show() \ No newline at end of file + gempy_plot.fig.show()