"""nres.models: the ``TransmissionModel`` lmfit model of sample transmission."""

from __future__ import annotations

from copy import deepcopy
from typing import List, Optional, Union

import lmfit
import matplotlib.pyplot as plt
import numpy as np
import pandas

import nres.utils as utils
from nres.cross_section import CrossSection
from nres.data import Data
from nres.response import Background, Response


class TransmissionModel(lmfit.Model):
    """lmfit model of sample transmission.

    The model function is :meth:`transmission`, which combines a cross
    section, a response function and a background function.
    """

    def __init__(
        self,
        cross_section,
        response: str = "expo_gauss",
        background: str = "polynomial3",
        tof_calibration: str = "linear",
        vary_weights: Optional[bool] = None,
        vary_background: Optional[bool] = None,
        vary_tof: Optional[bool] = None,
        vary_response: Optional[bool] = None,
        params: Optional[lmfit.Parameters] = None,
        **kwargs,
    ):
        """
        Initialize the TransmissionModel, a subclass of lmfit.Model.

        Parameters
        ----------
        cross_section : CrossSection
            Source cross section. Every entry of ``cross_section.materials``
            is copied into a new CrossSection owned by this model, keeping the
            material's ``"splitby"`` setting.
        response : str, optional
            Kind of response function, by default "expo_gauss".
        background : str, optional
            Kind of background function, by default "polynomial3".
        tof_calibration : str, optional
            Kind of TOF calibration, by default "linear"; "full" also includes
            the energy dependent corrections.
        vary_weights : bool or None, optional
            Isotope-weight parameters.
        vary_background : bool or None, optional
            Background parameters (b0, b1, ...).
        vary_tof : bool or None, optional
            TOF (time-of-flight) parameters (L0, t0, ...).
        vary_response : bool or None, optional
            Response-function parameters.

            For each of the four ``vary_*`` flags: ``None`` (default) leaves
            that parameter set out of ``self.params`` entirely; ``True`` or
            ``False`` adds the set, passing the flag on as ``vary`` to the
            helper that builds it. Every set whose flag is ``True`` also
            becomes a default fitting stage (``"basic"`` is always one).
        params : lmfit.Parameters, optional
            Parameter values from a previous fit. Only values are copied, and
            only for parameters that exist in this model; vary flags, bounds
            and expressions stay as defined by the ``vary_*`` arguments.
            Useful for seeding a new fit with earlier results.
        **kwargs
            Forwarded to ``lmfit.Model.__init__`` and, when ``vary_tof`` is not
            None, to ``_make_tof_params``.

        Notes
        -----
        The transmission combines the cross section, the response function
        and the background (see :meth:`transmission`).

        Examples
        --------
        Using fit results as initial guesses for a new model:

        >>> # First fit
        >>> model1 = TransmissionModel(xs, vary_background=True, vary_tof=True)
        >>> result1 = model1.fit(data1)
        >>>
        >>> # Use result1 parameters as initial guesses for a new fit
        >>> model2 = TransmissionModel(xs, vary_background=True, params=result1.params)
        >>> result2 = model2.fit(data2)
        """
        super().__init__(self.transmission, **kwargs)

        # Rebuild the cross section material by material so the model owns its
        # own CrossSection instance.
        materials = cross_section.materials
        self.cross_section = CrossSection()
        for name in materials:
            spec = materials[name]
            self.cross_section += CrossSection(**{name: spec}, splitby=spec["splitby"])

        self.params = self.make_params()
        # Thickness and normalisation must not go negative.
        for bounded in ("thickness", "norm"):
            if bounded in self.params:
                self.params[bounded].set(min=0.0)

        if vary_weights is not None:
            self.params += self._make_weight_params(vary=vary_weights)
        if vary_tof is not None:
            self.params += self._make_tof_params(
                vary=vary_tof, kind=tof_calibration, **kwargs
            )

        # The response object always exists; its parameters are only exposed
        # to the fit when vary_response was given.
        self.response = Response(
            kind=response, vary=vary_response, tstep=self.cross_section.tstep
        )
        if vary_response is not None:
            self.params += self.response.params

        self.background = Background(kind=background, vary=vary_background)
        if vary_background is not None:
            self.params += self.background.params

        # Total atomic density n [atoms/barn-cm]; 0.01 fallback when the
        # cross section is falsy.
        self.n = self.cross_section.n if self.cross_section else 0.01

        # Default fitting stages: "basic" plus every set explicitly set to vary.
        requested = {
            "basic": True,
            "background": vary_background,
            "tof": vary_tof,
            "response": vary_response,
            "weights": vary_weights,
        }
        self._stages = {
            stage: stage for stage, flag in requested.items() if flag is True
        }

        # Seed values (only values) from a previous fit, if given.
        if params is not None:
            self._load_param_values(params)
[docs] def transmission( self, E: np.ndarray, thickness: float = 1, norm: float = 1.0, **kwargs ): """ Transmission function model with background components. Parameters ---------- E : np.ndarray The energy values at which to calculate the transmission. thickness : float, optional The thickness of the material (in cm), by default 1. norm : float, optional Normalization factor, by default 1. kwargs : dict, optional Additional arguments for background, response, or cross-section. Returns ------- np.ndarray The calculated transmission values. Notes ----- This function combines the cross-section with the response and background models to compute the transmission, which is given by: .. math:: T(E) = \\text{norm} \\cdot e^{- \\sigma \\cdot \\text{thickness} \\cdot n} \\cdot (1 - \\text{bg}) + \\text{bg} where `sigma` is the cross-section, `bg` is the background function, and `n` is the total atomic weight. """ E = self._tof_correction(E, **kwargs) response = self.response.function(**kwargs) weights = deepcopy(self.cross_section.weights) weights = [ kwargs.pop(key.replace("_", ""), val) for key, val in weights.items() ] bg = self.background.function(E, **kwargs) k = kwargs.get( "k", 1.0 ) # background factor, relevant for some of the background models n = self.n # Transmission function xs = self.cross_section(E, weights=weights, response=response) T = norm * np.exp(-xs * thickness * n) * (1 - bg) + k * bg return T
def _load_param_values(self, source_params: lmfit.Parameters): """ Load parameter values from a source Parameters object. Only updates the values of existing parameters, preserving vary flags, bounds, and expressions as defined during model initialization. Parameters ---------- source_params : lmfit.Parameters Source parameters (e.g., from a previous fit result) to load values from. Notes ----- This method is called during __init__ if params argument is provided. It ensures that only parameters that exist in both source and target are updated, and only their values are changed. """ for param_name in self.params: if param_name in source_params: # Only update the value, preserve vary, min, max, expr self.params[param_name].value = source_params[param_name].value @property def stages(self): """Get the current fitting stages.""" return self._stages @stages.setter def stages(self, value): """ Set the fitting stages. Parameters ---------- value : str or dict If str, must be "all" to use all vary=True parameters. If dict, keys are stage names, values are stage definitions ("all", a valid group name, or a list of parameters/groups). 
""" import re # Define valid group names from group_map group_map = { "basic": ["norm", "thickness"], "background": [ p for p in self.params if re.compile(r"(b|bg)\d+").match(p) or p.startswith("b_") ], "tof": [p for p in ["L0", "t0", "t1", "t2"] if p in self.params], "response": [ p for p in self.params if self.response and p in self.response.params ], "weights": [p for p in self.params if re.compile(r"p\d+").match(p)], } if isinstance(value, str): if value != "all": raise ValueError("If stages is a string, it must be 'all'") self._stages = {"all": "all"} elif isinstance(value, dict): # Validate stage definitions for stage_name, stage_def in value.items(): if not isinstance(stage_name, str): raise ValueError( f"Stage names must be strings, got {type(stage_name)}" ) if isinstance(stage_def, str): if stage_def != "all" and stage_def not in group_map: raise ValueError( f"Stage definition for '{stage_name}' must be 'all' or a valid group name, got '{stage_def}'" ) elif isinstance(stage_def, list): for param in stage_def: if not isinstance(param, str): raise ValueError( f"Parameters in stage '{stage_name}' must be strings, got {type(param)}" ) else: raise ValueError( f"Stage definition for '{stage_name}' must be 'all', a valid group name, or a list, got {type(stage_def)}" ) self._stages = value else: raise ValueError( f"Stages must be a string ('all') or dict, got {type(value)}" )
[docs] def fit( self, data, params=None, emin: float = 0.5e6, emax: float = 20.0e6, method: str = "rietveld", xtol: float = None, ftol: float = None, gtol: float = None, verbose: bool = False, progress_bar: bool = True, param_groups: Optional[List[List[str]]] = None, **kwargs, ): """ Fit the model to data. This method supports both: - **Standard single-stage fitting** (default) - **Rietveld-style staged refinement** (`method="rietveld"`) with accumulative parameter refinement with accumulative parameter refinement Parameters ---------- data : pandas.DataFrame or Data or array-like The input data. - For `pandas.DataFrame` or `Data`: must have columns `"energy"`, `"trans"`, and `"err"`. - For array-like: will be passed directly to `lmfit.Model.fit`. params : lmfit.Parameters, optional Parameters to use for fitting. If None, uses the model's default parameters. emin, emax : float, optional Minimum and maximum energy for fitting (ignored for array-like input and overridden per stage if `param_groups` specify `"emin=..."` or `"emax=..."` strings). method : str, optional Fitting method. - `"rietveld"` (default) will run staged refinement via `_rietveld_fit`. - `"least-squares"` or any method supported by `lmfit` for single-stage fitting. xtol, ftol, gtol : float, optional Convergence tolerances (passed to `lmfit`). verbose : bool, optional If True, prints detailed fitting information. progress_bar : bool, optional If True, shows a progress bar for fitting: - For `"rietveld"`: shows stage name, energy range, and reduced chi² per stage. - For regular fits: shows overall fit progress. param_groups : list, dict, or None, optional Used only for `"rietveld"`. Groups of parameters to fit in each stage. Groups may also contain `"emin=..."` and/or `"emax=..."` strings to override the energy fitting range for that specific stage. 
For example: ```python param_groups = { "Basic": ["basic"], "Background": ["background", "emin=3", "emax=8"], "Extinction": ["extinction"], } ``` These per-stage overrides temporarily replace the global `emin`/`emax` only during the stage. **kwargs Additional keyword arguments passed to `lmfit.Model.fit`. For grouped data, additional parameters: - `n_jobs` (int): Number of parallel jobs (default: 10). Use -1 for all CPUs, but beware of memory issues. For threading, consider n_jobs=4 or less. - `max_nbytes` (str): Maximum memory per worker (default: '100M'). Prevents memory exhaustion. Increase for complex models or set to None to disable. Returns ------- lmfit.model.ModelResult The fit result object, with extra methods: - `.plot()` — plot the fit result. - `.plot_stage_progression()`, `.plot_chi2_progression()` for advanced diagnostics. - `.stages_summary` (for `"rietveld"`). Examples -------- **Basic fit:** ```python result = model.fit(data_df, emin=1.0, emax=5.0) result.plot() ``` **Rietveld-style staged refinement with per-stage energy overrides:** ```python param_groups = { "Norm/Thick": ["norm", "thickness"], "Background": ["b0", "b1", "emin=3", "emax=8"], "Extinction": ["ext_l", "ext_Gg"], } result = model.fit( data_df, method="rietveld", param_groups=param_groups, progress_bar=True ) print(result.stages_summary) ``` """ # Use self.stages if param_groups not provided if param_groups is None and hasattr(self, "stages") and self.stages is not None: param_groups = self.stages # Check if data is grouped and route to parallel fitting if hasattr(data, "is_grouped") and data.is_grouped: n_jobs = kwargs.pop("n_jobs", 10) max_nbytes = kwargs.pop("max_nbytes", "100M") return self._fit_grouped( data, params, emin, emax, method=method, xtol=xtol, ftol=ftol, gtol=gtol, verbose=verbose, progress_bar=progress_bar, param_groups=param_groups, n_jobs=n_jobs, max_nbytes=max_nbytes, **kwargs, ) # Route to Rietveld if requested (or if param_groups/stages provided) if method == 
"rietveld" or param_groups is not None: return self._rietveld_fit( data, params, emin, emax, verbose=verbose, progress_bar=progress_bar, param_groups=param_groups, **kwargs, ) # Prepare fit kwargs fit_kws = kwargs.pop("fit_kws", {}) if xtol is not None: fit_kws.setdefault("xtol", xtol) if ftol is not None: fit_kws.setdefault("ftol", ftol) if gtol is not None: fit_kws.setdefault("gtol", gtol) kwargs["fit_kws"] = fit_kws # Try tqdm for progress try: from tqdm.notebook import tqdm except ImportError: from tqdm.auto import tqdm # If progress_bar=True, wrap the fit in tqdm if progress_bar: pbar = tqdm(total=1, desc="Fitting", disable=not progress_bar) else: pbar = None # Prepare input data if isinstance(data, pandas.DataFrame): data = data.query(f"{emin} < energy < {emax}") weights = kwargs.get("weights", 1.0 / data["err"].values) fit_result = super().fit( data["trans"].values, params=params or self.params, weights=weights, E=data["energy"].values, method=method, **kwargs, ) elif isinstance(data, Data): data = data.table.query(f"{emin} < energy < {emax}") weights = kwargs.get("weights", 1.0 / data["err"].values) fit_result = super().fit( data["trans"].values, params=params or self.params, weights=weights, E=data["energy"].values, method=method, **kwargs, ) else: fit_result = super().fit( data, params=params or self.params, method=method, **kwargs ) if pbar: pbar.set_postfix({"redchi": f"{fit_result.redchi:.4g}"}) pbar.update(1) pbar.close() # Attach results self.fit_result = fit_result fit_result.plot = self.plot fit_result.show_available_params = self.show_available_params fit_result.save = lambda filename, include_model=True: self._save_result( fit_result, filename, include_model ) fit_result.save = lambda filename, include_model=True: self._save_result( fit_result, filename, include_model ) if self.response is not None: fit_result.response = self.response fit_result.response.params = fit_result.params if self.background is not None: fit_result.background = 
self.background return fit_result
def _rietveld_fit( self, data, params: lmfit.Parameters = None, emin: float = 0.5e6, emax: float = 20.0e6, verbose=False, progress_bar=True, param_groups=None, **kwargs, ): """Perform Rietveld-style staged fitting with accumulative parameter refinement. In this method, parameters accumulate across stages. When a new stage is added, all previously refined parameters remain vary=True, allowing for simultaneous refinement of all parameters introduced up to that stage. Parameters ---------- data : pandas.DataFrame or Data The input data containing energy and transmission values. params : lmfit.Parameters, optional Initial parameters for the fit. If None, uses the model's default parameters. emin : float, optional default=0.5e6 Default minimum energy for fitting. emax : float, optional default=20.e6 Default maximum energy for fitting. verbose : bool, optional If True, prints detailed information about each fitting stage. progress_bar : bool, optional If True, shows a progress bar for each fitting stage. param_groups : list, dict, or None, optional - only used for Rietveld fitting Groups of parameters to fit in each stage. Can contain special keywords: - "emin=<value>" or "emax=<value>": override energy bounds for that stage - "pick-one" or "pick_one": enable pick-one mode for isotope selection. In this mode, the fit tries each cross-section material individually with weight=1 (others at 0), then selects the isotope with the best fit quality (lowest reduced chi-squared). The selected isotope is fixed at weight=1 with all others at weight=0. kwargs : dict, optional Additional keyword arguments for the fit method, such as weights, method, etc. Returns ------- fit_result : lmfit.ModelResult The final fit result after all stages. fit_result.stages_summary : pandas.DataFrame Summary of each fitting stage, including parameter values and reduced chi-squared. 
""" import fnmatch import re import sys import warnings from copy import deepcopy import pandas try: from tqdm.notebook import tqdm except ImportError: from tqdm.auto import tqdm import pickle # Use original params to determine which were set to vary original_params = params or self.params # User-friendly group name mapping - only include parameters that have vary=True group_map = { "basic": [ p for p in ["norm", "thickness"] if p in original_params and original_params[p].vary ], "background": [ p for p in original_params if (re.compile(r"(b|bg)\d+").match(p) or p.startswith("b_")) and original_params[p].vary ], "tof": [ p for p in ["L0", "t0", "t1", "t2"] if p in original_params and original_params[p].vary ], "response": [ p for p in original_params if self.response and p in self.response.params and original_params[p].vary ], "weights": [ p for p in original_params if re.compile(r"p\d+").match(p) and original_params[p].vary ], } def resolve_single_param_or_group(item): """Resolve a single parameter name or group name to a list of parameters.""" if item in group_map: resolved = group_map[item] if verbose: print(f" Resolved group '{item}' to: {resolved}") return resolved if item in self.params: if verbose: print(f" Found parameter: {item}") return [item] matching_params = [ p for p in self.params.keys() if fnmatch.fnmatch(p, item) ] if matching_params: if verbose: print(f" Pattern '{item}' matched: {matching_params}") return matching_params warnings.warn( f"Unknown parameter or group: '{item}'. 
Available parameters: {list(self.params.keys())}" ) return [] def resolve_group(entry): """ Resolve a group entry (string, list, or nested structure) to: - A flat list of parameters - A dict of overrides like {'emin': float, 'emax': float} """ params_list = [] overrides = {} def process_item(item): nonlocal params_list, overrides if isinstance(item, str): if item.startswith("emin="): try: overrides["emin"] = float(item.split("=", 1)[1]) if verbose: print(f" Override emin detected: {overrides['emin']}") except ValueError: warnings.warn(f"Invalid emin value in group: {item}") elif item.startswith("emax="): try: overrides["emax"] = float(item.split("=", 1)[1]) if verbose: print(f" Override emax detected: {overrides['emax']}") except ValueError: warnings.warn(f"Invalid emax value in group: {item}") elif item == "pick-one" or item == "pick_one": overrides["pick_one"] = True if verbose: print( " Pick-one mode detected: will test each isotope individually" ) else: params_list.extend(resolve_single_param_or_group(item)) elif isinstance(item, list): for subitem in item: process_item(subitem) else: warnings.warn( f"Unexpected item type in group: {type(item)} - {item}" ) process_item(entry) return params_list, overrides # Handle different input formats for param_groups and parse overrides stage_names = [] resolved_param_groups = [] stage_overrides = [] if param_groups is None: # Default groups default_groups = [ "basic", "background", "tof", "response", "weights", ] for group in default_groups: params_list, overrides = resolve_group(group) if params_list: resolved_param_groups.append(params_list) stage_overrides.append(overrides) stage_names.append(f"Stage_{len(stage_names) + 1}") elif verbose: print(f"Skipping empty default group: {group}") elif isinstance(param_groups, dict): stage_names = list(param_groups.keys()) for stage in stage_names: params_list, overrides = resolve_group(param_groups[stage]) if params_list: resolved_param_groups.append(params_list) 
stage_overrides.append(overrides) else: if verbose: print(f"Skipping empty group: {stage}") elif isinstance(param_groups, list): for i, group in enumerate(param_groups): params_list, overrides = resolve_group(group) if params_list: resolved_param_groups.append(params_list) stage_overrides.append(overrides) stage_names.append(f"Stage_{i + 1}") else: if verbose: print(f"Skipping empty group at index {i}") else: raise ValueError("param_groups must be None, a list, or a dictionary") # Remove any empty groups that slipped through filtered = [ (n, g, o) for n, g, o in zip(stage_names, resolved_param_groups, stage_overrides) if g ] if not filtered: raise ValueError( "No valid parameter groups found. Check your parameter names." ) stage_names, resolved_param_groups, stage_overrides = zip(*filtered) if verbose: print("\nFitting stages with possible energy overrides:") for i, (name, group, ov) in enumerate( zip(stage_names, resolved_param_groups, stage_overrides) ): print(f" {name}: {group} overrides: {ov}") # Store for summary or introspection self._stage_param_groups = resolved_param_groups self._stage_names = stage_names params = deepcopy(params or self.params) # Setup tqdm iterator try: from tqdm.notebook import tqdm as notebook_tqdm if "ipykernel" in sys.modules: iterator = notebook_tqdm( zip(stage_names, resolved_param_groups, stage_overrides), desc="Rietveld Fit", disable=not progress_bar, total=len(stage_names), ) else: iterator = tqdm( zip(stage_names, resolved_param_groups, stage_overrides), desc="Rietveld Fit", disable=not progress_bar, total=len(stage_names), ) except ImportError: iterator = tqdm( zip(stage_names, resolved_param_groups, stage_overrides), desc="Rietveld Fit", disable=not progress_bar, total=len(stage_names), ) stage_results = [] stage_summaries = [] # Lists to collect final stages (including pick-one isotope tests) final_stage_results = [] final_stage_names = [] final_resolved_param_groups = [] cumulative_params = ( set() ) # Track parameters that 
have been refined (accumulative Rietveld) def extract_pickleable_attributes(fit_result): safe_attrs = [ "params", "success", "residual", "chisqr", "redchi", "aic", "bic", "nvarys", "ndata", "nfev", "message", "lmdif_message", "cov_x", "method", "flatchain", "errorbars", "ci_out", ] class PickleableResult: pass result = PickleableResult() for attr in safe_attrs: if hasattr(fit_result, attr): try: value = getattr(fit_result, attr) pickle.dumps(value) setattr(result, attr, value) except (TypeError, ValueError, AttributeError): if verbose: print(f"Skipping non-pickleable attribute: {attr}") continue return result for stage_idx, (stage_name, group, overrides) in enumerate(iterator): stage_num = stage_idx + 1 # Use overrides or fallback to global emin, emax stage_emin = overrides.get("emin", emin) stage_emax = overrides.get("emax", emax) if verbose: print( f"\n{stage_name}: Fitting parameters {group} with energy range [{stage_emin}, {stage_emax}]" ) # Filter data for this stage if isinstance(data, pandas.DataFrame): stage_data = data.query(f"{stage_emin} < energy < {stage_emax}") energies = stage_data["energy"].values trans = stage_data["trans"].values weights = kwargs.get("weights", 1.0 / stage_data["err"].values) elif isinstance(data, Data): stage_data = data.table.query(f"{stage_emin} < energy < {stage_emax}") energies = stage_data["energy"].values trans = stage_data["trans"].values weights = kwargs.get("weights", 1.0 / stage_data["err"].values) else: raise ValueError("Rietveld fitting requires energy-based input data.") # Check if pick-one mode is enabled for this stage if overrides.get("pick_one", False): if verbose: print("\n Pick-one mode: Testing each isotope individually...") # Get isotope names from cross-section weights isotope_names = list(self.cross_section.weights.index) isotope_names = [name.replace("-", "") for name in isotope_names] # Get the p parameters (free weight parameters) p_params = [p for p in params.keys() if re.compile(r"p\d+").match(p)] if 
len(isotope_names) <= 1: warnings.warn( f"Pick-one mode requires at least 2 isotopes, but found {len(isotope_names)}. Skipping pick-one." ) elif not p_params: warnings.warn( "Pick-one mode requires weight parameters (p1, p2, ...), but none found. Skipping pick-one." ) else: # Store results for each isotope test isotope_results = [] # Test each isotope for iso_idx, isotope_name in enumerate(isotope_names): if verbose: print(f" Testing {isotope_name}...") # Create a copy of params for this test test_params = deepcopy(params) # Set this isotope to weight=1, others to weight=0 # For isotope i (i < N-1): set p_i = 14 (max), others = -14 (min) # For isotope N-1 (last): set all p_j = -14 (min) for j, p_name in enumerate(p_params): if iso_idx < len(p_params): # One of the first N-1 isotopes if j == iso_idx: test_params[ p_name ].value = 14.0 # This isotope dominates else: test_params[ p_name ].value = -14.0 # Others suppressed else: # Last isotope (N-1) test_params[ p_name ].value = -14.0 # All p's minimal -> last weight = 1 # Set all parameters to not vary (fixed for this test) for p in test_params.values(): p.vary = False # Vary all cumulative parameters from previous stages (Rietveld methodology) # This ensures parameters like 'thickness' from earlier stages remain active for param_name in cumulative_params: if param_name in test_params and param_name not in p_params: test_params[param_name].vary = True # Also vary non-weight parameters from current group non_weight_params = [ p for p in group if p not in p_params and not re.compile(r"p\d+").match(p) ] for param_name in non_weight_params: if param_name in test_params: test_params[param_name].vary = True # Perform test fit # Filter out kwargs that lmfit doesn't understand lmfit_kwargs = { k: v for k, v in kwargs.items() if k not in ["n_cores", "n_jobs", "max_nbytes", "progress_bar"] } try: test_fit = super().fit( trans, params=test_params, E=energies, weights=weights, method="leastsq", **lmfit_kwargs, ) 
isotope_results.append( { "isotope": isotope_name, "index": iso_idx, "redchi": test_fit.redchi, "params": test_fit.params, } ) if verbose: print( f" {isotope_name}: χ²/dof = {test_fit.redchi:.4f}" ) # Add this isotope test as a separate stage in the final results final_stage_results.append(test_fit) final_resolved_param_groups.append(non_weight_params) final_stage_names.append( f"{stage_name} (test: {isotope_name})" ) except Exception as e: warnings.warn(f"Fitting failed for {isotope_name}: {e}") isotope_results.append( { "isotope": isotope_name, "index": iso_idx, "redchi": float("inf"), "params": None, } ) # Find the best isotope (lowest reduced chi-squared) best_result = min(isotope_results, key=lambda x: x["redchi"]) best_isotope = best_result["isotope"] best_idx = best_result["index"] if verbose: print( f"\n Best fit: {best_isotope} with χ²/dof = {best_result['redchi']:.4f}" ) # Update progress bar iterator.set_postfix( { "stage": stage_name, "best": best_isotope, "reduced χ²": f"{best_result['redchi']:.4g}", } ) # Set the weights to fix the best isotope at weight=1 if best_idx < len(p_params): # One of the first N-1 isotopes for j, p_name in enumerate(p_params): if j == best_idx: params[p_name].value = 14.0 else: params[p_name].value = -14.0 else: # Last isotope for p_name in p_params: params[p_name].value = -14.0 # Copy other fitted parameters from the best result if best_result["params"] is not None: non_weight_params = [ p for p in group if p not in p_params and not re.compile(r"p\d+").match(p) ] for param_name in non_weight_params: if ( param_name in params and param_name in best_result["params"] ): params[param_name].value = best_result["params"][ param_name ].value # The weights are now fixed, so we continue to the next stage # Add the stage to cumulative params (the non-weight params were fitted) cumulative_params.update( [p for p in group if not re.compile(r"p\d+").match(p)] ) # Create a fake fit_result for consistency class PickOneFitResult: def 
__init__(self, params, redchi): self.params = params self.redchi = redchi self.success = True self.residual = None self.chisqr = None self.aic = None self.bic = None self.nvarys = 0 self.ndata = len(energies) self.nfev = 0 self.message = f"Pick-one mode: selected {best_isotope}" self.lmdif_message = self.message self.cov_x = None self.method = "pick-one" self.flatchain = None self.errorbars = False self.ci_out = None iterator.set_description(f"Stage {stage_num}/{len(stage_names)}") if verbose: print( f" {stage_name} completed with pick-one. Selected {best_isotope}, χ²/dof = {best_result['redchi']:.4f}" ) # Skip the normal fitting for this stage (isotope tests already added to final lists) continue # Accumulate parameters across stages (True Rietveld approach) cumulative_params.update(group) # Freeze all parameters for p in params.values(): p.vary = False # Unfreeze current group # Note: group_map already filters out parameters with vary=False # Unfreeze all parameters that have been introduced so far unfrozen_count = 0 for name in cumulative_params: if name in params: params[name].vary = True unfrozen_count += 1 if verbose and name in group: print(f" New parameter: {name}") elif verbose: print(f" Continuing: {name}") else: if name in group: # Only warn for new parameters warnings.warn(f"Parameter '{name}' not found in params") if verbose: print(f" Total active parameters: {unfrozen_count}") if unfrozen_count == 0: warnings.warn( f"No parameters were unfrozen in {stage_name}. Skipping this stage." 
) continue # Perform fitting # Filter out kwargs that lmfit doesn't understand lmfit_kwargs = { k: v for k, v in kwargs.items() if k not in ["n_cores", "n_jobs", "max_nbytes", "progress_bar"] } try: with warnings.catch_warnings(): if not verbose: # Suppress lmfit warnings when not verbose warnings.filterwarnings( "ignore", category=UserWarning, module="lmfit" ) fit_result = super().fit( trans, params=params, E=energies, weights=weights, method="leastsq", **lmfit_kwargs, ) except Exception as e: if verbose: warnings.warn(f"Fitting failed in {stage_name}: {e}") continue # Extract pickleable part stripped_result = extract_pickleable_attributes(fit_result) stage_results.append(stripped_result) # Also add to final results (for stages_summary) final_stage_results.append(stripped_result) final_stage_names.append(stage_name) final_resolved_param_groups.append(group) # Build summary varied_params = list(cumulative_params) # Track cumulative parameters varied_params = list(cumulative_params) # Track cumulative parameters summary = { "stage": stage_num, "stage_name": stage_name, "fitted_params": group, "emin": stage_emin, "emax": stage_emax, "redchi": fit_result.redchi, } for name, par in fit_result.params.items(): summary[f"{name}_value"] = par.value summary[f"{name}_stderr"] = par.stderr summary[f"{name}_vary"] = ( name in varied_params ) # Mark as vary if in cumulative set stage_summaries.append(summary) iterator.set_description(f"Stage {stage_num}/{len(stage_names)}") iterator.set_postfix( {"stage": stage_name, "reduced χ²": f"{fit_result.redchi:.4g}"} ) # Update params for next stage params = fit_result.params if verbose: print(f" {stage_name} completed. 
χ²/dof = {fit_result.redchi:.4f}") if not stage_results: raise RuntimeError("No successful fitting stages completed") self.fit_result = fit_result self.fit_stages = stage_results # Use final lists (which include pick-one isotope tests) for stages_summary self.stages_summary = self._create_stages_summary_table_enhanced( final_stage_results if final_stage_results else stage_results, final_resolved_param_groups if final_resolved_param_groups else resolved_param_groups, final_stage_names if final_stage_names else stage_names, ) # Attach plotting methods and other attributes fit_result.plot = self.plot fit_result.plot_stage_progression = self.plot_stage_progression fit_result.plot_chi2_progression = self.plot_chi2_progression if self.response is not None: fit_result.response = self.response fit_result.response.params = fit_result.params if self.background is not None: fit_result.background = self.background fit_result.stages_summary = self.stages_summary fit_result.show_available_params = self.show_available_params fit_result.save = lambda filename, include_model=True: self._save_result( fit_result, filename, include_model ) return fit_result def _create_stages_summary_table_enhanced( self, stage_results, resolved_param_groups, stage_names=None, color=True ): import numpy as np import pandas as pd # --- Build the DataFrame --- all_param_names = list(stage_results[-1].params.keys()) stage_data = {} if stage_names is None: stage_names = [f"Stage_{i+1}" for i in range(len(stage_results))] cumulative_params = set() # Track cumulative parameters for Rietveld method for stage_idx, stage_result in enumerate(stage_results): stage_col = ( stage_names[stage_idx] if stage_idx < len(stage_names) else f"Stage_{stage_idx + 1}" ) stage_data[stage_col] = {"value": {}, "stderr": {}, "vary": {}} # Accumulate parameters across stages cumulative_params.update(resolved_param_groups[stage_idx]) varied_in_stage = cumulative_params.copy() for param_name in all_param_names: if param_name in 
stage_result.params: param = stage_result.params[param_name] stage_data[stage_col]["value"][param_name] = param.value stage_data[stage_col]["stderr"][param_name] = ( param.stderr if param.stderr is not None else np.nan ) stage_data[stage_col]["vary"][param_name] = ( param_name in varied_in_stage ) else: stage_data[stage_col]["value"][param_name] = np.nan stage_data[stage_col]["stderr"][param_name] = np.nan stage_data[stage_col]["vary"][param_name] = False redchi = stage_result.redchi if hasattr(stage_result, "redchi") else np.nan stage_data[stage_col]["value"]["redchi"] = redchi stage_data[stage_col]["stderr"]["redchi"] = np.nan stage_data[stage_col]["vary"]["redchi"] = np.nan # Create DataFrame data_for_df = {} for stage_col in stage_data: for metric in ["value", "stderr", "vary"]: data_for_df[(stage_col, metric)] = stage_data[stage_col][metric] df = pd.DataFrame(data_for_df) df.columns = pd.MultiIndex.from_tuples(df.columns, names=["Stage", "Metric"]) all_param_names_with_redchi = all_param_names + ["redchi"] df = df.reindex(all_param_names_with_redchi) # --- Add initial values column --- initial_values = {} for param_name in all_param_names: initial_values[param_name] = ( self.params[param_name].value if param_name in self.params else np.nan ) initial_values["redchi"] = np.nan initial_df = pd.DataFrame({("Initial", "value"): initial_values}) df = pd.concat([initial_df, df], axis=1) if not color: return df styler = df.style # 1) Highlight vary=True cells (light green for accumulative Rietveld) vary_cols = [col for col in df.columns if col[1] == "vary"] def highlight_vary(s): return ["background-color: lightgreen" if v is True else "" for v in s] for col in vary_cols: styler = styler.apply(highlight_vary, subset=[col], axis=0) # 2) Highlight redchi row's value cells (moccasin) def highlight_redchi_row(row): if row.name == "redchi": return [ "background-color: moccasin" if col[1] == "value" else "" for col in df.columns ] return ["" for _ in df.columns] styler = 
styler.apply(highlight_redchi_row, axis=1) # 3) Highlight value cells by fractional change with red hues (ignore <1%) value_cols = [col for col in df.columns if col[1] == "value"] # Calculate % absolute change between consecutive columns (Initial → Stage1 → Stage2 ...) changes = pd.DataFrame(index=df.index, columns=value_cols, dtype=float) prev_col = None for col in value_cols: if prev_col is None: # No previous for initial column, so zero changes here changes[col] = 0.0 else: prev_vals = df[prev_col].astype(float) curr_vals = df[col].astype(float) with np.errstate(divide="ignore", invalid="ignore"): pct_change = np.abs((curr_vals - prev_vals) / prev_vals) * 100 pct_change = pct_change.replace([np.inf, -np.inf], np.nan).fillna(0.0) changes[col] = pct_change prev_col = col max_change = changes.max().max() # Normalize by max change, to get values in [0,1] norm_changes = changes / max_change if max_change > 0 else changes def red_color(val): # Ignore changes less than 1% if pd.isna(val) or val < 1: return "" # val in [0,1], map to red intensity # 0 -> white (255,255,255) # 1 -> dark red (255,100,100) r = 255 g = int(255 - 155 * val) b = int(255 - 155 * val) return f"background-color: rgb({r},{g},{b})" for col in value_cols: styler = styler.apply( lambda s: [red_color(v) for v in norm_changes[col]], subset=[col], axis=0, ) return styler def _fit_grouped( self, data, params=None, emin: float = 0.5e6, emax: float = 20.0e6, method: str = "rietveld", xtol: float = None, ftol: float = None, gtol: float = None, verbose: bool = False, progress_bar: bool = True, param_groups: Optional[List[List[str]]] = None, n_jobs: int = 10, max_nbytes: str = "100M", **kwargs, ): """ Fit model to grouped data in parallel. Parameters: ----------- data : Data Grouped data object with is_grouped=True. params : lmfit.Parameters, optional Parameters to use for fitting. emin, emax : float Energy range for fitting. method : str Fitting method: "least-squares" or "rietveld". 
xtol, ftol, gtol : float, optional Convergence tolerances. verbose : bool Show progress for individual fits. progress_bar : bool Show overall progress bar. param_groups : list or dict, optional Fitting stages configuration for rietveld. n_jobs : int Number of parallel jobs (default: 10). Use -1 for all CPUs, but be aware this can cause memory issues with large datasets. For threading backend, consider n_jobs=4 or less for better performance. max_nbytes : str Maximum memory per worker (default: '100M'). Limits memory usage to prevent system freezes. Increase (e.g., '500M') for complex models, or set to None to disable memory limits. **kwargs Additional arguments passed to fit. Returns: -------- GroupedFitResult Container with fit results for each group. """ import time from joblib import Parallel, delayed from nres.grouped_fit import GroupedFitResult try: from tqdm.auto import tqdm except ImportError: from tqdm import tqdm # Prepare fit arguments fit_kwargs = { "params": params, "emin": emin, "emax": emax, "method": method, "xtol": xtol, "ftol": ftol, "gtol": gtol, "verbose": verbose if verbose else False, "progress_bar": verbose, # Show individual progress bars when verbose=True "param_groups": param_groups, **kwargs, } def fit_single_group(idx): """Fit a single group using threading.""" from nres.data import Data group_data = Data() group_data.table = data.groups[idx] group_data.L = data.L group_data.tstep = data.tstep try: result = self.fit(group_data, **fit_kwargs) except Exception as e: if verbose: print(f"Error fitting group {idx}: {e}") result = None return idx, result start_time = time.time() # Execute with threading (or multiprocessing if n_jobs != 1) backend = "threading" if n_jobs > 0 else "loky" # Warn about performance with high n_jobs in threading mode if backend == "threading" and n_jobs > 4 and verbose: print( f"Warning: Using {n_jobs} threads. Consider n_jobs=4 or less for better performance." ) print( f" High thread counts can cause memory issues. 
Current limit: {max_nbytes} per worker." ) # Execute parallel fitting with proper progress bar if progress_bar: import sys pbar = tqdm( total=len(data.indices), desc=f"Fitting {len(data.indices)} groups", mininterval=0.05, # Update display every 50ms minimum maxinterval=1.0, # Force update at least every second smoothing=0.05, # Less smoothing for more responsive updates file=sys.stderr, # Write to stderr (unbuffered) dynamic_ncols=True, # Adjust to terminal width leave=True, # Keep the bar after completion ) results = [] for result in Parallel( n_jobs=n_jobs, backend=backend, verbose=5 if verbose else 0, return_as="generator", max_nbytes=max_nbytes, )(delayed(fit_single_group)(idx) for idx in data.indices): results.append(result) pbar.update(1) # Force immediate display update pbar.refresh() sys.stderr.flush() pbar.close() else: results = Parallel( n_jobs=n_jobs, backend=backend, verbose=5 if verbose else 0, max_nbytes=max_nbytes, )(delayed(fit_single_group)(idx) for idx in data.indices) elapsed = time.time() - start_time if verbose: print( f"Completed in {elapsed:.2f}s using '{backend}' backend | {elapsed/len(data.indices):.3f}s per fit" ) # Collect results grouped_result = GroupedFitResult(group_shape=data.group_shape) failed_indices = [] for idx, result in results: if result is not None: grouped_result.add_result(idx, result) else: failed_indices.append(idx) if failed_indices and verbose: import warnings warnings.warn( f"Fitting failed for {len(failed_indices)}/{len(data.indices)} groups. " f"Failed indices: {failed_indices[:10]}{'...' if len(failed_indices) > 10 else ''}" ) return grouped_result
def show_available_params(self, show_groups=True, show_params=True):
    """
    Print the parameter groups and parameters usable for Rietveld fitting.

    Parameters
    ----------
    show_groups : bool, optional
        If True (default), print each predefined group ("basic",
        "background", "tof", "response", "weights") listing only its
        parameters that currently have ``vary=True``; groups with no such
        parameter are omitted.
    show_params : bool, optional
        If True (default), print every model parameter with its value and
        whether it is varied or fixed.
    """
    import re

    if show_groups:
        print("Available parameter groups:")
        print("=" * 30)

        # Compiled once instead of once per candidate parameter.
        background_regex = re.compile(r"(b|bg)\d+")
        weight_regex = re.compile(r"p\d+")

        def is_background(name):
            return bool(background_regex.match(name) or name.startswith("b_"))

        def is_response(name):
            return bool(self.response and name in self.response.params)

        candidates = {
            "basic": [n for n in ("norm", "thickness") if n in self.params],
            "background": [n for n in self.params if is_background(n)],
            "tof": [n for n in ("L0", "t0", "t1", "t2") if n in self.params],
            "response": [n for n in self.params if is_response(n)],
            "weights": [n for n in self.params if weight_regex.match(n)],
        }
        for group_name, names in candidates.items():
            free_names = [n for n in names if self.params[n].vary]
            if free_names:
                print(f" '{group_name}': {free_names}")

    if show_params:
        if show_groups:
            heading, rule = "\nAll individual parameters:", "=" * 30
        else:
            heading, rule = "Available parameters:", "=" * 20
        print(heading)
        print(rule)
        for name, par in self.params.items():
            status = "vary" if par.vary else "fixed"
            print(f" {name}: {par.value:.6g} ({status})")

    example_lines = (
        "\nExample usage:",
        "=" * 15,
        "# Using predefined groups:",
        'param_groups = ["basic", "background", "extinction"]',
        "\n# Using individual parameters:",
        'param_groups = [["norm", "thickness"], ["b0", "ext_l2"]]',
        "\n# Using named stages:",
        'param_groups = {"scale": ["norm"], "sample": ["thickness", "extinction"]}',
        "\n# Mixed approach:",
        'param_groups = ["basic", ["b0", "ext_l2"], "lattice"]',
    )
    for line in example_lines:
        print(line)
def plot(
    self,
    data: Optional[Data] = None,
    plot_bg: bool = True,
    correct_tof: bool = True,
    stage: Optional[int] = None,
    index=None,
    **kwargs,
):
    """
    Plot the results of the fit or model.

    Parameters
    ----------
    data : nres.Data, optional
        Data to show alongside the model. It is only plotted when no fit
        result is attached to the model (e.g. before performing the fit);
        once ``self.fit_result`` is set, the fitted data is plotted instead.
    plot_bg : bool, optional
        Whether to include the background in the plot, by default True.
        Ignored when the model has no background.
    correct_tof : bool, optional
        Apply the TOF correction to the energy axis when ``L0`` and ``t0``
        are present in the plotted parameters, by default True.
    stage : int, optional
        If provided, plot results from a specific Rietveld fitting stage
        (1-indexed). Only used when Rietveld stages are stored; otherwise the
        final fit (or the model) is plotted.
    index : int, tuple, or str, optional
        For grouped data, specify which group to plot.
        - For 2D grids: can use tuple (0, 0) or string "(0, 0)"
        - For 1D arrays: can use int 5 or string "5"
        - For named groups: use string "groupname"
        Required when ``data`` is grouped.
    kwargs : dict, optional
        Plot settings. ``color``, ``title``, ``ecolor`` and ``ms`` are
        consumed here; any remaining keywords are forwarded to
        ``self.background.plot``.

    Returns
    -------
    numpy.ndarray of matplotlib.axes.Axes
        The two axes: transmission (top) and residuals (bottom).

    Raises
    ------
    ValueError
        If grouped data is given without a valid ``index``, if ``stage`` is
        out of range, or if stage results are requested without the final
        fit result.
    """
    # Grouped data: reduce to a single, non-grouped Data object for `index`.
    if data is not None and getattr(data, "is_grouped", False):
        if index is None:
            raise ValueError(
                "Data is grouped. Please specify which group to plot using the 'index' parameter.\n"
                f"Available indices: {data.indices}"
            )
        from nres.data import Data

        normalized_index = data._normalize_index(index)
        if normalized_index not in data.groups:
            raise ValueError(
                f"Index {index} not found. Available indices: {data.indices}"
            )
        group_data = Data()
        group_data.table = data.groups[normalized_index]
        group_data.L = data.L
        group_data.tstep = data.tstep
        group_data.is_grouped = False
        data = group_data

    _, ax = plt.subplots(2, 1, sharex=True, height_ratios=[3.5, 1], figsize=(6, 5))
    data_object = data.table.dropna().copy() if data is not None else None
    # `fit_result` may be missing or explicitly None; both mean "no fit yet".
    fit_result = getattr(self, "fit_result", None)

    if stage is not None and getattr(self, "fit_stages", None):
        if stage < 1 or stage > len(self.fit_stages):
            raise ValueError(
                f"Stage {stage} not available. Available stages: 1-{len(self.fit_stages)}"
            )
        stage_result = self.fit_stages[stage - 1]  # stages are 1-indexed
        # Stage results only carry parameters; the data comes from the final fit.
        if fit_result is None:
            raise ValueError("Cannot plot stage results without original fit data")
        energy = fit_result.userkws["E"]
        data_values = fit_result.data
        err = 1.0 / fit_result.weights
        params = stage_result.params
        best_fit = self.eval(params=params, E=energy)
        residual = (data_values - best_fit) / err
        chi2 = (
            stage_result.redchi
            if hasattr(stage_result, "redchi")
            else np.sum(residual**2) / (len(data_values) - len(params))
        )
        fit_label = f"Stage {stage} fit"
    elif fit_result is not None:
        energy = fit_result.userkws["E"]
        data_values = fit_result.data
        err = 1.0 / fit_result.weights
        best_fit = fit_result.best_fit
        residual = fit_result.residual
        params = fit_result.params
        chi2 = fit_result.redchi
        fit_label = "Best fit"
    else:
        # No fit yet: evaluate the model on the data grid, or on the
        # cross-section energy grid when no data was given.
        fit_label = "Model"
        params = self.params
        if data_object is not None:
            energy = data_object["energy"]
            data_values = data_object["trans"]
            err = data_object["err"]
            best_fit = self.eval(params=params, E=energy.values)
            residual = (data_values - best_fit) / err
            chi2 = np.sum(residual**2) / (len(data_values) - len(params))
        else:
            energy = self.cross_section.table.dropna().index.values
            data_values = np.nan * np.ones_like(energy)
            err = np.nan * np.ones_like(energy)
            best_fit = self.eval(params=params, E=energy)
            residual = np.nan * np.ones_like(energy)
            chi2 = np.nan

    # Apply TOF correction if enabled and L0, t0 parameters are present
    if correct_tof and "L0" in params and "t0" in params:
        L0 = params["L0"].value
        t0 = params["t0"].value
        t1 = params["t1"].value if "t1" in params else 0.0
        t2 = params["t2"].value if "t2" in params else 0.0
        energy = self._tof_correction(energy, L0=L0, t0=t0, t1=t1, t2=t2)

    # Plot settings
    color = kwargs.pop("color", "seagreen")
    title = kwargs.pop("title", self.cross_section.name)
    ecolor = kwargs.pop("ecolor", "0.8")
    ms = kwargs.pop("ms", 2)

    # Plot data and best-fit/model
    ax[0].errorbar(
        energy,
        data_values,
        err,
        marker="o",
        color=color,
        ms=ms,
        zorder=-1,
        ecolor=ecolor,
        label="Data",
    )
    ax[0].plot(energy, best_fit, color="0.2", label=fit_label)
    ax[0].set_ylabel("Transmission")
    ax[0].set_title(title)

    # Plot residuals
    ax[1].plot(energy, residual, color=color)
    ax[1].set_ylabel("Residuals [1σ]")
    ax[1].set_xlabel("Energy [eV]")

    # The background is optional on this model; only plot it when present.
    background = getattr(self, "background", None)
    if plot_bg and background is not None and background.params:
        background.plot(E=energy, ax=ax[0], params=params, **kwargs)
        legend_labels = [fit_label, "Background", "Data"]
    else:
        legend_labels = [fit_label, "Data"]

    # Set legend with chi2 value
    ax[0].legend(
        legend_labels, fontsize=9, reverse=True, title=f"χ$^2$: {chi2:.2f}"
    )
    plt.subplots_adjust(hspace=0.05)
    return ax
def plot_stage_progression(self, stages: list = None, **kwargs):
    """
    Plot the model curve of each Rietveld refinement stage over the fitted data.

    Parameters
    ----------
    stages : list of int, optional
        1-indexed stages to draw. Defaults to all stored stages; entries
        outside the available range are skipped.
    kwargs : dict, optional
        Currently unused; accepted for API compatibility.

    Returns
    -------
    matplotlib.axes.Axes
        The axes of the plot.

    Raises
    ------
    ValueError
        If no Rietveld stages are stored, or the final fit result (which
        provides the data) is missing.
    """
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd

    if not getattr(self, "fit_stages", None):
        raise ValueError(
            "No Rietveld stages available. Run fit with method='rietveld' first."
        )
    if stages is None:
        stages = list(range(1, len(self.fit_stages) + 1))

    # Original data
    fit_result = getattr(self, "fit_result", None)
    if fit_result is None:
        raise ValueError("Cannot plot stage progression without original fit data")
    energy = fit_result.userkws["E"]
    data_values = fit_result.data
    err = 1.0 / fit_result.weights

    # stages_summary is a colored pandas Styler by default, which supports
    # neither `.loc` nor column indexing; use the wrapped DataFrame instead.
    summary = getattr(self, "stages_summary", None)
    if summary is not None and not isinstance(summary, pd.DataFrame):
        summary = getattr(summary, "data", None)
    if not isinstance(summary, pd.DataFrame):
        summary = None

    def stage_label(stage):
        """Label a stage by its first two varied parameters, else 'Stage N'."""
        label = f"Stage {stage}"
        key = (f"Stage_{stage}", "vary")
        if summary is None or key not in summary.columns:
            return label
        varied = summary.loc[summary[key] == True].index.tolist()  # noqa: E712
        varied = [p for p in varied if p != "redchi"]
        if not varied:
            return label
        extra = f" +{len(varied) - 2}" if len(varied) > 2 else ""
        return ", ".join(varied[:2]) + extra

    _, ax = plt.subplots(figsize=(6, 4))

    # Match style: light gray points for data
    ax.errorbar(
        energy,
        data_values,
        err,
        marker="o",
        color="0.6",
        ms=2,
        alpha=0.7,
        zorder=-1,
        ecolor="0.85",
        label="Data",
    )

    # Use consistent style palette
    colors = plt.cm.plasma(np.linspace(0, 0.85, len(stages)))

    for i, stage in enumerate(stages):
        if stage < 1 or stage > len(self.fit_stages):
            continue
        stage_result = self.fit_stages[stage - 1]
        best_fit = self.eval(params=stage_result.params, E=energy)
        chi2 = getattr(stage_result, "redchi", np.nan)
        if chi2 is None:  # np.isnan(None) would raise TypeError
            chi2 = np.nan
        name = stage_label(stage)
        ax.plot(
            energy,
            best_fit,
            color=colors[i],
            lw=1.2 + 0.4 * i,
            alpha=0.8,
            label=f"{name} (χ²={chi2:.3f})" if not np.isnan(chi2) else name,
        )

    ax.set_xlabel("Energy [eV]")
    ax.set_ylabel("Transmission")
    ax.set_title("Rietveld Refinement Stage Progression")
    ax.legend(fontsize=8, frameon=False)
    plt.tight_layout()
    return ax
def plot_chi2_progression(self, **kwargs):
    """
    Plot the reduced χ² of each Rietveld stage, with stage labels on the x-axis.

    Parameters
    ----------
    kwargs : dict, optional
        Currently unused; accepted for API compatibility.

    Returns
    -------
    matplotlib.axes.Axes
        The axes of the plot.

    Raises
    ------
    ValueError
        If no Rietveld stages are stored.
    """
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd

    if not getattr(self, "fit_stages", None):
        raise ValueError(
            "No Rietveld stages available. Run fit with method='rietveld' first."
        )

    # stages_summary is a colored pandas Styler by default, which supports
    # neither `.loc` nor column indexing; use the wrapped DataFrame instead.
    summary = getattr(self, "stages_summary", None)
    if summary is not None and not isinstance(summary, pd.DataFrame):
        summary = getattr(summary, "data", None)
    if not isinstance(summary, pd.DataFrame):
        summary = None

    def stage_label(stage):
        """Label a stage by its first two varied parameters, else 'Stage N'."""
        label = f"Stage {stage}"
        key = (f"Stage_{stage}", "vary")
        if summary is None or key not in summary.columns:
            return label
        varied = summary.loc[summary[key] == True].index.tolist()  # noqa: E712
        varied = [p for p in varied if p != "redchi"]
        if not varied:
            return label
        extra = f" +{len(varied) - 2}" if len(varied) > 2 else ""
        return ", ".join(varied[:2]) + extra

    stages = list(range(1, len(self.fit_stages) + 1))
    chi2_values = []
    for stage_result in self.fit_stages:
        chi2 = getattr(stage_result, "redchi", np.nan)
        # Normalize a missing value so the np.isnan check below cannot raise.
        chi2_values.append(np.nan if chi2 is None else chi2)
    stage_labels = [stage_label(stage) for stage in stages]

    _, ax = plt.subplots(figsize=(6, 3.5))
    ax.plot(stages, chi2_values, marker="o", lw=2, color="seagreen")

    # Annotate each point
    for stage, chi2 in zip(stages, chi2_values):
        if not np.isnan(chi2):
            ax.annotate(
                f"{chi2:.3f}",
                (stage, chi2),
                textcoords="offset points",
                xytext=(0, 8),
                ha="center",
                fontsize=8,
            )

    ax.set_xlabel("Refinement Stage")
    ax.set_ylabel("Reduced χ²")
    ax.set_title("Rietveld χ² Progression")

    # Stage names at bottom
    ax.set_xticks(stages)
    ax.set_xticklabels(stage_labels, rotation=30, ha="right", fontsize=8)
    plt.tight_layout()
    return ax
def get_stages_summary_table(self):
    """
    Return the stored table of parameter progression through refinement stages.

    Rows are the parameter names plus a final ``redchi`` row. Columns are a
    two-level index: an ``("Initial", "value")`` column followed by
    ``value``/``stderr``/``vary`` columns for each stage. The table is
    returned exactly as stored by the fit, which is a pandas ``Styler``
    (colored view) by default rather than a bare ``DataFrame``.

    Returns
    -------
    pandas.io.formats.style.Styler or pandas.DataFrame
        The stages summary table.

    Raises
    ------
    ValueError
        If no stages summary has been stored yet.
    """
    try:
        return self.stages_summary
    except AttributeError:
        raise ValueError(
            "No stages summary available. Run fit with method='rietveld' first."
        ) from None
def weighted_thickness(self, params=None):
    """Return the sample thickness in [cm] scaled by each isotope weight.

    The thickness is read, in order of preference, from ``params`` (when it
    is a non-empty mapping), from the last fit result, or from the model's
    own parameters.

    Args:
        params (lmfit.Parameters, optional): parameters object providing
            ``"thickness"``. Defaults to None.

    Returns:
        The product ``thickness * self.cross_section.weights``.
    """
    weights = self.cross_section.weights
    if params:
        return params["thickness"].value * weights
    if hasattr(self, "fit_result"):
        return self.fit_result.values["thickness"] * weights
    return self.params["thickness"].value * weights
def _make_tof_params(
    self,
    vary: bool = False,
    kind: str = "linear",
    L0: float = 1.0,
    t0: float = 0.0,
    t1: float = 0.0,
    t2: float = 0.0,
):
    """
    Create time-of-flight (TOF) parameters for the model.

    Parameters
    ----------
    vary : bool, optional
        Whether to allow these parameters to vary during fitting, by default False.
    kind : str, optional
        The type of TOF correction, by default "linear" (``L0`` and ``t0``
        only). ``"full"`` additionally adds the energy dependent
        corrections ``t1`` and ``t2``; any other value behaves like "linear".
    L0 : float, optional
        Initial flight path scale parameter, by default 1. Bounded to [0.5, 1.5].
    t0 : float, optional
        Initial time offset parameter, by default 0.
    t1 : float, optional
        Initial linear correction parameter, by default 0.
    t2 : float, optional
        Initial logarithmic correction parameter, by default 0.

    Returns
    -------
    lmfit.Parameters
        The TOF-related parameters.
    """
    params = lmfit.Parameters()
    params.add("L0", value=L0, min=0.5, max=1.5, vary=vary)
    params.add("t0", value=t0, vary=vary)
    if kind == "full":
        params.add("t1", value=t1, vary=vary)
        params.add("t2", value=t2, vary=vary)
    return params


def _make_weight_params(self, vary: bool = False):
    """
    Create lmfit parameters for the isotope weights of the cross-section.

    Dashes are stripped from the isotope names so they are valid lmfit
    parameter names. With N > 1 isotopes, the weights are parameterized by
    N-1 free log-ratios ``p_i = log(w_i / w_N)`` (bounded to [-14, 14]), and
    every weight is an expression of them::

        w_i = exp(p_i) / (1 + sum_j exp(p_j))   for i < N
        w_N = 1 / (1 + sum_j exp(p_j))

    so the weights stay positive and always sum to 1.

    Parameters
    ----------
    vary : bool, optional
        Whether the free log-ratio parameters ``p_i`` vary during fitting,
        by default False.

    Returns
    -------
    lmfit.Parameters
        The weight parameters. Empty when the cross-section has no weights;
        a single fixed weight of 1 when it has exactly one isotope.
    """
    params = lmfit.Parameters()
    weight_series = deepcopy(self.cross_section.weights)
    weight_series.index = weight_series.index.str.replace("-", "")
    param_names = weight_series.index
    n_isotopes = len(weight_series)

    if n_isotopes == 0:
        # Nothing to parameterize (previously failed with IndexError).
        return params

    # Normalize the input weights to sum to 1
    weights = np.array(weight_series / weight_series.sum(), dtype=np.float64)

    if n_isotopes == 1:
        # A single isotope always carries the full weight.
        params.add(f"{param_names[0]}", value=1.0, vary=False)
        return params

    last_weight = weights[-1]
    # A zero weight gives log(0) = -inf; clip explicitly to the parameter
    # bounds (exp(-14) is roughly 1 ppm relative to the reference isotope).
    with np.errstate(divide="ignore"):
        log_ratios = np.clip(np.log(weights[:-1] / last_weight), -14, 14)

    for i, log_ratio in enumerate(log_ratios):
        params.add(f"p{i+1}", value=log_ratio, min=-14, max=14, vary=vary)

    # Shared denominator term: sum of exp(p_i) over the N-1 free parameters.
    normalization_expr = " + ".join(f"exp(p{i+1})" for i in range(n_isotopes - 1))

    for i, name in enumerate(param_names[:-1]):
        params.add(f"{name}", expr=f"exp(p{i+1}) / (1 + {normalization_expr})")

    # The reference (last) isotope takes the remaining share.
    params.add(f"{param_names[-1]}", expr=f"1 / (1 + {normalization_expr})")
    return params
def set_cross_section(
    self, xs: CrossSection, inplace: bool = True
) -> TransmissionModel:
    """
    Replace the model's cross-section and merge in matching weight parameters.

    The weight parameters are rebuilt from the new cross-section with
    ``_make_weight_params()`` (default ``vary=False``) and merged into the
    model parameters with ``+=``; other existing parameters are not removed.

    Parameters
    ----------
    xs : CrossSection
        The new cross-section to apply.
    inplace : bool, optional
        If True (default), modify this model. If False, modify and return a
        deep copy, leaving this model unchanged.

    Returns
    -------
    TransmissionModel
        ``self`` when ``inplace`` is True, otherwise the modified copy.
    """
    model = self if inplace else deepcopy(self)
    model.cross_section = xs
    model.params += model._make_weight_params()
    return model
def update_params(
    self,
    params: Optional[dict] = None,
    values_only: bool = True,
    inplace: bool = True,
):
    """
    Update the parameters of the model.

    Parameters
    ----------
    params : dict or lmfit.Parameters, optional
        Mapping of parameter name to new value. Values may be objects with a
        ``.value`` attribute (e.g. ``lmfit.Parameter``) or plain numbers.
        ``None`` (the default) means there is nothing to update.
    values_only : bool, optional
        If True (default), only the values of the named parameters are set;
        their vary flags, bounds and expressions are left untouched. If
        False, ``params`` replaces the model's parameter set entirely.
    inplace : bool, optional
        If True (default), modify this model. If False, modify and return a
        deep copy, leaving this model unchanged.

    Returns
    -------
    TransmissionModel
        The updated model: ``self`` when ``inplace`` is True, otherwise the
        modified copy.

    Raises
    ------
    KeyError
        If ``values_only`` is True and ``params`` names a parameter the model
        does not have.
    """
    target = self if inplace else deepcopy(self)
    # A None default avoids the shared mutable `{}` default, which with
    # values_only=False used to replace the parameter set by an empty dict.
    if params is None:
        return target
    if values_only:
        for name, new in params.items():
            target.params[name].set(value=getattr(new, "value", new))
    else:
        target.params = params
    return target
def vary_all(
    self, vary: Optional[bool] = None, except_for: Optional[List[str]] = None
):
    """
    Set the ``vary`` flag of every model parameter at once.

    Parameters
    ----------
    vary : bool, optional
        The value to assign to each parameter's ``vary`` attribute. If None
        (default), nothing is changed.
    except_for : list of str or str, optional
        Parameter names to leave untouched. A single name may be given as a
        plain string; it is matched exactly, not as a substring.
    """
    if vary is None:
        return
    # Normalize to a set: exact-name matching with O(1) lookups. A bare
    # string used to be tested with `in`, i.e. as a substring match.
    if except_for is None:
        excluded = set()
    elif isinstance(except_for, str):
        excluded = {except_for}
    else:
        excluded = set(except_for)
    for name in self.params:
        if name not in excluded:
            self.params[name].set(vary=vary)
def _tof_correction(
    self,
    E,
    L0: float = 1.0,
    t0: float = 0.0,
    t1: float = 0.0,
    t2: float = 0.0,
    **kwargs,
):
    """
    Map energies through a time-of-flight (TOF) correction.

    Each energy is converted to a TOF over the flight path
    ``self.cross_section.L``. It is shifted by
    ``(1 - L0) * tof + t0 + t1 * E + t2 * log(E)`` and converted back to energy.

    Parameters
    ----------
    E : float or array-like
        The energy values to correct. ``t2 * log(E)`` is always evaluated,
        so non-positive energies yield NaN even when ``t2`` is 0.
    L0 : float, optional
        Flight path scale factor, by default 1.0.
    t0 : float, optional
        Constant time offset, by default 0.0.
    t1 : float, optional
        Coefficient of the term linear in energy, by default 0.0.
    t2 : float, optional
        Coefficient of the logarithmic energy term, by default 0.0.
    kwargs : dict, optional
        Ignored; accepted so parameter dictionaries can be splatted in.

    Returns
    -------
    np.ndarray
        The corrected energy values.
    """
    path_length = self.cross_section.L
    tof = utils.energy2time(E, path_length)
    delta = (1.0 - L0) * tof + t0 + t1 * E + t2 * np.log(E)
    return utils.time2energy(tof + delta, path_length)
def manually_calibrate_tof(
    self,
    inputs: Union[list, np.ndarray] = None,
    references: Union[list, np.ndarray] = None,
    input_type: str = "tof",
    reference_type: str = "energy",
    **kwargs,
):
    """
    Manually calibrate time-of-flight (TOF) correction parameters.

    Both ``inputs`` and ``references`` are first converted to times. A
    straight line ``inputs = L0 * references + t0`` is then fitted. The
    fitted ``L0`` and ``t0`` are written into ``self.params`` with
    ``vary=False``. With a single calibration point only ``t0`` is fitted
    (``L0`` stays at 1).

    Parameters
    ----------
    inputs : float, list or np.ndarray
        Input values for calibration (time-like values).
    references : float, list or np.ndarray
        Corresponding reference values, same length as ``inputs``.
    input_type : str, optional
        Type of input values. Options are:
        - 'tof': Direct time values in units of seconds
        - 'energy': Convert energy to time using utils.energy2time
        - 'slice': Convert slice indices to time by multiplying with tstep
        Default is 'tof'.
    reference_type : str, optional
        Type of reference values; same options as ``input_type``.
        Default is 'energy'.
    kwargs : dict, optional
        Currently unused.

    Returns
    -------
    lmfit ModelResult object
        Detailed linear regression result with fitting information.

    Raises
    ------
    ValueError
        If either argument is missing, the lengths differ, no points are
        given, or a type is not 'tof', 'energy' or 'slice'.
    """
    valid_types = ("tof", "energy", "slice")

    # Input validation (all of it before any conversion or fitting)
    if inputs is None or references is None:
        raise ValueError("Both inputs and references must be provided")

    # atleast_1d lets a single scalar calibration point be passed directly.
    inputs = np.atleast_1d(np.array(inputs, dtype=float))
    references = np.atleast_1d(np.array(references, dtype=float))

    if len(inputs) != len(references):
        raise ValueError(
            "Input values and reference values must have the same length"
        )
    if len(inputs) == 0:
        raise ValueError("At least one calibration point is required")
    if input_type not in valid_types:
        raise ValueError("Invalid input_type. Must be 'tof', 'energy', or 'slice'")
    if reference_type not in valid_types:
        raise ValueError(
            "Invalid reference_type. Must be 'tof', 'energy', or 'slice'"
        )

    def to_time(values, kind):
        """Convert `values` of the given kind into times in seconds."""
        if kind == "energy":
            return utils.energy2time(values, self.cross_section.L)
        if kind == "slice":
            return values * self.cross_section.tstep
        return values

    inputs = to_time(inputs, input_type)
    references = to_time(references, reference_type)

    # Define the linear model using lmfit
    def linear_tof_correction(x, L0=1.0, t0=0.0):
        return L0 * x + t0

    model = lmfit.Model(linear_tof_correction)
    params = model.make_params()
    if len(inputs) == 1:
        # One point cannot constrain both slope and offset.
        params["L0"].vary = False

    result = model.fit(inputs, params=params, x=references)

    # Update self.params with the calibration results
    self.params.set(t0=dict(value=result.params["t0"].value, vary=False))
    self.params.set(L0=dict(value=result.params["L0"].value, vary=False))

    return result
def save(self, filename: str): """ Save the model to a JSON file. Parameters ---------- filename : str Path to the JSON file where the model will be saved. Notes ----- The model is saved as JSON, which is portable and human-readable. The saved file can be loaded using the `TransmissionModel.load()` class method. Examples -------- >>> model = TransmissionModel(cross_section) >>> model.save("my_model.json") >>> loaded_model = TransmissionModel.load("my_model.json") """ import json # Serialize parameters params_dict = {} for name, param in self.params.items(): params_dict[name] = { "value": float(param.value), "vary": bool(param.vary), "min": float(param.min) if param.min is not None else None, "max": float(param.max) if param.max is not None else None, "expr": param.expr, } # Serialize cross-section xs_dict = { "name": self.cross_section.name, "materials": self.cross_section.materials, "L": float(self.cross_section.L), "tstep": float(self.cross_section.tstep), "tbins": int(self.cross_section.tbins), "first_tbin": int(self.cross_section.first_tbin), } # Serialize response parameters response_dict = None if self.response is not None: response_dict = { "params": { name: { "value": float(p.value), "vary": bool(p.vary), "min": float(p.min) if p.min is not None else None, "max": float(p.max) if p.max is not None else None, } for name, p in self.response.params.items() }, "tstep": float(self.response.tstep), "eps": float(self.response.eps), } # Serialize background parameters background_dict = None if self.background is not None: background_dict = { "params": { name: { "value": float(p.value), "vary": bool(p.vary), "min": float(p.min) if p.min is not None else None, "max": float(p.max) if p.max is not None else None, } for name, p in self.background.params.items() }, } # Create the model data dictionary model_data = { "version": "1.0", "type": "TransmissionModel", "cross_section": xs_dict, "response": response_dict, "background": background_dict, "params": params_dict, "n": 
float(self.n), } # Save to JSON file with open(filename, "w") as f: json.dump(model_data, f, indent=2) @classmethod def load(cls, filename: str) -> TransmissionModel: """ Load a model from a JSON file. Parameters ---------- filename : str Path to the JSON file containing the saved model. Returns ------- TransmissionModel The loaded model instance. Examples -------- >>> model = TransmissionModel.load("my_model.json") >>> result = model.fit(data) """ import json with open(filename) as f: model_data = json.load(f) # Reconstruct cross-section xs_data = model_data["cross_section"] xs = CrossSection() # Restore cross-section materials for mat_name, mat_info in xs_data["materials"].items(): xs.add_material( mat_name, mat_info, splitby=mat_info.get("splitby", "elements"), total_weight=mat_info.get("total_weight", 1.0), ) xs.name = xs_data["name"] xs.L = xs_data["L"] xs.tstep = xs_data["tstep"] xs.tbins = xs_data["tbins"] xs.first_tbin = xs_data["first_tbin"] # Determine response and background types from saved params response_kind = None background_kind = None if model_data["response"] is not None: # Infer response type from parameters response_kind = "expo_gauss" # Default, can be enhanced later if model_data["background"] is not None: # Infer background type from number of parameters n_bg_params = len(model_data["background"]["params"]) if n_bg_params == 3: background_kind = "polynomial3" elif n_bg_params == 5: background_kind = "polynomial5" # Create model instance model = cls( cross_section=xs, response=response_kind, background=background_kind, ) # Restore all parameter values for name, param_data in model_data["params"].items(): if name in model.params: model.params[name].set( value=param_data["value"], vary=param_data["vary"], min=param_data["min"], max=param_data["max"], expr=param_data["expr"], ) # Restore response parameters if model_data["response"] is not None and model.response is not None: for name, param_data in model_data["response"]["params"].items(): if 
name in model.response.params: model.response.params[name].set( value=param_data["value"], vary=param_data["vary"], min=param_data["min"], max=param_data["max"], ) # Restore background parameters if model_data["background"] is not None and model.background is not None: for name, param_data in model_data["background"]["params"].items(): if name in model.background.params: model.background.params[name].set( value=param_data["value"], vary=param_data["vary"], min=param_data["min"], max=param_data["max"], ) model.n = model_data["n"] return model def _save_result(self, result, filename: str, include_model: bool = True): """ Save a fit result to a JSON file. Parameters ---------- result : lmfit.model.ModelResult The fit result to save. filename : str Path to the JSON file where the result will be saved. include_model : bool, optional If True, saves the full model with the result. If False, saves only a compressed result with fit parameters. Default is True. """ import json import numpy as np # Serialize fit parameters params_dict = {} for name, param in result.params.items(): params_dict[name] = { "value": float(param.value), "stderr": float(param.stderr) if param.stderr is not None else None, "vary": bool(param.vary), "min": float(param.min) if param.min is not None else None, "max": float(param.max) if param.max is not None else None, "expr": param.expr, } # Serialize fit statistics result_dict = { "version": "1.0", "type": "FitResult", "params": params_dict, "success": bool(result.success), "chisqr": float(result.chisqr), "redchi": float(result.redchi), "aic": float(result.aic) if hasattr(result, "aic") else None, "bic": float(result.bic) if hasattr(result, "bic") else None, "nvarys": int(result.nvarys), "ndata": int(result.ndata), "nfev": int(result.nfev) if hasattr(result, "nfev") else None, "message": result.message if hasattr(result, "message") else None, } # Optionally include the model if include_model: # Temporarily save model to get its JSON representation 
import os import tempfile with tempfile.NamedTemporaryFile( mode="w", suffix=".json", delete=False ) as temp_f: temp_filename = temp_f.name try: self.save(temp_filename) with open(temp_filename) as f: model_dict = json.load(f) result_dict["model"] = model_dict finally: if os.path.exists(temp_filename): os.remove(temp_filename) # Save to JSON file with open(filename, "w") as f: json.dump(result_dict, f, indent=2) @classmethod def load_result(cls, filename: str, model: TransmissionModel = None): """ Load a fit result from a JSON file. Parameters ---------- filename : str Path to the JSON file containing the saved result. model : TransmissionModel, optional Model to use for the result. If None and the file contains a model, it will be loaded from the file. If the file doesn't contain a model, this parameter is required. Returns ------- tuple A tuple containing (model, params_dict) where model is the TransmissionModel instance and params_dict contains the fit parameters and statistics. Examples -------- >>> model, result_data = TransmissionModel.load_result("my_result.json") >>> print(result_data["redchi"]) >>> # Or with compressed result >>> model, result_data = TransmissionModel.load_result( ... "result.json", model=my_model ... ) """ import json import os import tempfile with open(filename) as f: result_data = json.load(f) # Load or use provided model if "model" in result_data: # Full result with embedded model with tempfile.NamedTemporaryFile( mode="w", suffix=".json", delete=False ) as temp_f: json.dump(result_data["model"], temp_f) temp_filename = temp_f.name try: model = cls.load(temp_filename) finally: if os.path.exists(temp_filename): os.remove(temp_filename) elif model is None: raise ValueError( "Model not found in file and no model provided. " "Either save with include_model=True or provide a model parameter." 
) # Update model parameters with fit results for name, param_data in result_data["params"].items(): if name in model.params: model.params[name].set( value=param_data["value"], vary=param_data.get("vary", True), min=param_data.get("min"), max=param_data.get("max"), expr=param_data.get("expr"), ) # Return model and result dictionary return model, result_data
def save(self, filename: str):
    """
    Write this model's configuration to a JSON file.

    The file records:

    * every model parameter: value, vary flag, bounds and constraint
      expression;
    * the cross-section setup: ``name``, ``materials``, ``L``, ``tstep``,
      ``tbins`` and ``first_tbin``;
    * the response parameters plus its ``tstep`` and ``eps``, when a
      response is present;
    * the background parameters, when a background is present;
    * the scalar ``n``.

    Parameters
    ----------
    filename : str
        Destination path of the JSON file. An existing file is overwritten.

    Notes
    -----
    Reload the file with the ``TransmissionModel.load()`` class method.

    Response and background parameters are stored without their ``expr``
    field. Non-finite floats (e.g. unbounded limits) are emitted by
    :mod:`json` as ``Infinity``/``NaN`` tokens. Python's :mod:`json` reads
    these back, but strict JSON parsers may reject them.

    Examples
    --------
    >>> model = TransmissionModel(cross_section)
    >>> model.save("my_model.json")
    >>> loaded_model = TransmissionModel.load("my_model.json")
    """
    import json

    def as_float(x):
        # Bounds may be None; keep that as JSON null instead of calling float().
        return None if x is None else float(x)

    def bounded(p):
        # Shared value/vary/bounds record used for all three parameter groups.
        return {
            "value": float(p.value),
            "vary": bool(p.vary),
            "min": as_float(p.min),
            "max": as_float(p.max),
        }

    # Model parameters additionally carry their constraint expression.
    model_params = {}
    for pname, par in self.params.items():
        record = bounded(par)
        record["expr"] = par.expr
        model_params[pname] = record

    xsec = self.cross_section
    xs_record = {
        "name": xsec.name,
        "materials": xsec.materials,
        "L": float(xsec.L),
        "tstep": float(xsec.tstep),
        "tbins": int(xsec.tbins),
        "first_tbin": int(xsec.first_tbin),
    }

    resp = self.response
    if resp is None:
        resp_record = None
    else:
        resp_record = {
            "params": {k: bounded(v) for k, v in resp.params.items()},
            "tstep": float(resp.tstep),
            "eps": float(resp.eps),
        }

    bg = self.background
    if bg is None:
        bg_record = None
    else:
        bg_record = {"params": {k: bounded(v) for k, v in bg.params.items()}}

    payload = {
        "version": "1.0",
        "type": "TransmissionModel",
        "cross_section": xs_record,
        "response": resp_record,
        "background": bg_record,
        "params": model_params,
        "n": float(self.n),
    }

    with open(filename, "w") as fh:
        json.dump(payload, fh, indent=2)
@classmethod
def load(cls, filename: str) -> TransmissionModel:
    """
    Load a model from a JSON file written by ``save``.

    Parameters
    ----------
    filename : str
        Path to the JSON file containing the saved model.

    Returns
    -------
    TransmissionModel
        The loaded model instance.

    Warns
    -----
    UserWarning
        If a saved background has a parameter count that does not map to a
        known background type (only 3 -> ``"polynomial3"`` and
        5 -> ``"polynomial5"`` are recognised). The model is then rebuilt
        without a background, and the saved background parameters are not
        restored.

    Notes
    -----
    The file stores no explicit response/background kind, so both are
    inferred:

    * a saved response is always rebuilt as ``"expo_gauss"``;
    * the background kind is chosen from its parameter count.

    Other constructor options (e.g. ``tof_calibration``) take their default
    values. The saved response ``tstep``/``eps`` entries are not read.

    Parameters present in the file but absent from the rebuilt model are
    ignored.

    Examples
    --------
    >>> model = TransmissionModel.load("my_model.json")
    >>> result = model.fit(data)
    """
    import json
    import warnings

    with open(filename) as f:
        model_data = json.load(f)

    # Rebuild the cross-section from its saved materials and TOF settings.
    xs_data = model_data["cross_section"]
    xs = CrossSection()
    for mat_name, mat_info in xs_data["materials"].items():
        xs.add_material(
            mat_name,
            mat_info,
            splitby=mat_info.get("splitby", "elements"),
            total_weight=mat_info.get("total_weight", 1.0),
        )
    xs.name = xs_data["name"]
    xs.L = xs_data["L"]
    xs.tstep = xs_data["tstep"]
    xs.tbins = xs_data["tbins"]
    xs.first_tbin = xs_data["first_tbin"]

    # Infer the response kind; only "expo_gauss" is reconstructed for now.
    response_kind = None
    if model_data["response"] is not None:
        response_kind = "expo_gauss"

    # Infer the background kind from how many parameters were saved.
    background_kind = None
    if model_data["background"] is not None:
        n_bg_params = len(model_data["background"]["params"])
        background_kind = {3: "polynomial3", 5: "polynomial5"}.get(n_bg_params)
        if background_kind is None:
            # Previously this case silently produced a model with no
            # background; make the loss of the saved background visible.
            warnings.warn(
                f"Cannot infer the background type from {n_bg_params} saved "
                f"background parameters in {filename!r}; the model is loaded "
                "without a background.",
                UserWarning,
                stacklevel=2,
            )

    model = cls(
        cross_section=xs,
        response=response_kind,
        background=background_kind,
    )

    # Restore model parameters, including constraint expressions.
    for name, param_data in model_data["params"].items():
        if name in model.params:
            model.params[name].set(
                value=param_data["value"],
                vary=param_data["vary"],
                min=param_data["min"],
                max=param_data["max"],
                expr=param_data["expr"],
            )

    def _restore_bounded(target_params, saved_params):
        # Response/background entries are saved without "expr".
        for name, param_data in saved_params.items():
            if name in target_params:
                target_params[name].set(
                    value=param_data["value"],
                    vary=param_data["vary"],
                    min=param_data["min"],
                    max=param_data["max"],
                )

    if model_data["response"] is not None and model.response is not None:
        _restore_bounded(model.response.params, model_data["response"]["params"])

    if model_data["background"] is not None and model.background is not None:
        _restore_bounded(
            model.background.params, model_data["background"]["params"]
        )

    model.n = model_data["n"]
    return model
def _save_result(self, result, filename: str, include_model: bool = True):
    """
    Save a fit result to a JSON file.

    Parameters
    ----------
    result : lmfit.model.ModelResult
        The fit result to save. Two groups of data are written:

        * each parameter's value, stderr, vary flag, bounds and expression;
        * the fit statistics ``success``, ``chisqr``, ``redchi``, ``nvarys``
          and ``ndata``, plus ``aic``, ``bic``, ``nfev`` and ``message``.
          The last four are written as ``None`` when the result lacks them.
    filename : str
        Path to the JSON file where the result will be saved.
    include_model : bool, optional
        If True (default), the JSON produced by ``self.save`` is embedded
        under the ``"model"`` key, so the result can be reloaded without
        supplying a model. If False, only the parameters and statistics are
        written.
    """
    import json
    import os
    import tempfile

    # Serialize fit parameters, including their standard errors.
    params_dict = {}
    for name, param in result.params.items():
        params_dict[name] = {
            "value": float(param.value),
            "stderr": float(param.stderr) if param.stderr is not None else None,
            "vary": bool(param.vary),
            "min": float(param.min) if param.min is not None else None,
            "max": float(param.max) if param.max is not None else None,
            "expr": param.expr,
        }

    # Serialize fit statistics; optional attributes fall back to None.
    result_dict = {
        "version": "1.0",
        "type": "FitResult",
        "params": params_dict,
        "success": bool(result.success),
        "chisqr": float(result.chisqr),
        "redchi": float(result.redchi),
        "aic": float(result.aic) if hasattr(result, "aic") else None,
        "bic": float(result.bic) if hasattr(result, "bic") else None,
        "nvarys": int(result.nvarys),
        "ndata": int(result.ndata),
        "nfev": int(result.nfev) if hasattr(result, "nfev") else None,
        "message": result.message if hasattr(result, "message") else None,
    }

    if include_model:
        # ``save`` only writes to a path, so obtain the model's JSON by
        # round-tripping through a scratch directory that is always removed.
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_path = os.path.join(tmp_dir, "model.json")
            self.save(tmp_path)
            with open(tmp_path) as f:
                result_dict["model"] = json.load(f)

    with open(filename, "w") as f:
        json.dump(result_dict, f, indent=2)
@classmethod
def load_result(cls, filename: str, model: Optional[TransmissionModel] = None):
    """
    Load a fit result from a JSON file.

    Parameters
    ----------
    filename : str
        Path to the JSON file containing the saved result.
    model : TransmissionModel, optional
        Model to use for the result. When given, it is used (and its
        parameters are updated in place), even if the file also embeds a
        model. If None, the model embedded in the file is loaded. If the
        file does not embed a model, this parameter is required.

    Returns
    -------
    tuple
        A tuple containing (model, params_dict). ``model`` is the
        TransmissionModel instance with its parameters set from the file.
        ``params_dict`` is the full decoded JSON: the fit parameters and
        statistics, plus the ``"model"`` entry if one was saved.

    Raises
    ------
    ValueError
        If ``model`` is None and the file contains no embedded model.

    Examples
    --------
    >>> model, result_data = TransmissionModel.load_result("my_result.json")
    >>> print(result_data["redchi"])
    >>> # Or with compressed result
    >>> model, result_data = TransmissionModel.load_result(
    ...     "result.json", model=my_model
    ... )
    """
    import json
    import os
    import tempfile

    with open(filename) as f:
        result_data = json.load(f)

    # Only fall back to the embedded model when the caller supplied none;
    # previously a provided model was silently discarded here.
    if model is None:
        if "model" not in result_data:
            raise ValueError(
                "Model not found in file and no model provided. "
                "Either save with include_model=True or provide a model parameter."
            )
        # ``load`` reads from a path, so write the embedded model to a
        # scratch directory that is always removed afterwards.
        with tempfile.TemporaryDirectory() as tmp_dir:
            tmp_path = os.path.join(tmp_dir, "model.json")
            with open(tmp_path, "w") as f:
                json.dump(result_data["model"], f)
            model = cls.load(tmp_path)

    # Update model parameters with the fitted values; names unknown to the
    # model are skipped.
    for name, param_data in result_data["params"].items():
        if name in model.params:
            model.params[name].set(
                value=param_data["value"],
                vary=param_data.get("vary", True),
                min=param_data.get("min"),
                max=param_data.get("max"),
                expr=param_data.get("expr"),
            )

    return model, result_data