from __future__ import annotations
import os
from copy import deepcopy
from typing import Dict, List, Optional, Union
import numpy as np
import pandas as pd
import nres
import nres.utils as utils
from nres._integrate_xs import integrate_cross_section
class CrossSection:
    """
    A class representing neutron cross-sections for single or multiple isotopes.

    This class handles the loading, manipulation, and analysis of neutron cross-section data.
    It supports operations on individual isotopes as well as combinations of isotopes with
    different weights. The class provides functionality for interpolation and calculation
    of total cross-sections based on weighted sums.

    Attributes:
        isotopes (Dict[str, float]): Mapping of component name (an isotope, element or
            material name, depending on each material's ``splitby``) to its normalized weight.
        name (str): Identifier for this cross-section combination.
        weights (pd.Series): Normalized weights for each component (sum to 1 unless all
            weights are zero).
        table (pd.DataFrame): Cross-section data indexed by energy (index name ``"energy"``),
            one column per component plus a ``"total"`` column holding the weighted sum.
        L (float): Flight path length in meters.
        tstep (float): Time step for the simulation in seconds.
        tbins (int): Number of time bins for the simulation.
        first_tbin (int): Index of the first time bin (typically 1).
        tgrid (np.ndarray): Time grid ``arange(first_tbin, tbins + 1) * tstep``.
        n (float): Number density calculated based on component weights.
        materials (Dict): Dictionary containing material information and properties.
    """

    # Allowed values for a material's ``splitby`` mode.
    _SPLITBY_OPTIONS = ("isotopes", "elements", "materials")

    # Parsed cross-section library shared by all instances, so the data file is
    # read from disk at most once per process.
    _XSDATA_CACHE: Optional[Dict[str, pd.Series]] = None

    _XSDATA_URL = (
        "https://github.com/TsvikiHirsh/trinidi-data/raw/main/trinidi_data/xsdata.npy"
    )

    def __init__(
        self,
        isotopes: "Optional[Union[Dict[Union[str, CrossSection], float], CrossSection, str]]" = None,
        name: str = "",
        total_weight: float = 1.0,
        L: float = 10.59,
        tstep: float = 1.56255e-9,
        tbins: int = 640,
        first_tbin: int = 1,
        splitby: str = "elements",
        **materials,
    ):
        """Initialize a new CrossSection instance.

        Args:
            isotopes: Material data dict, a material/element/isotope name string, or a
                CrossSection object to copy.
            name: Identifier for this cross-section combination (also used as the material
                name for ``isotopes``; defaults to ``"material_1"`` when empty).
            total_weight: Overall scaling factor applied to every material added here.
            L: Flight path length in meters.
            tstep: Time step for simulation in seconds.
            tbins: Number of time bins.
            first_tbin: Index of the first time bin.
            splitby: How to split cross sections ("isotopes", "elements", "materials").
            **materials: Additional materials (name -> data dict or name string).

        Raises:
            ValueError: If ``splitby`` is not one of the supported modes.
        """
        self.name = name
        self.L = L
        self.tstep = tstep
        self.tbins = tbins
        self.first_tbin = first_tbin
        self.tgrid = np.arange(self.first_tbin, self.tbins + 1, 1) * self.tstep

        # Material definitions keyed by material name (see add_material).
        self.materials = {}
        self.__xsdata__ = None
        self._load_xsdata()

        # Empty state until the first material is added.
        self.table = pd.DataFrame()
        self.weights = pd.Series(dtype=float)
        self.isotopes = {}
        self.n = 0.0

        if isinstance(isotopes, CrossSection):
            self._init_from_cross_section(isotopes, name)
        elif isotopes:
            material_data = self._get_material_data(isotopes)
            self.add_material(name or "material_1", material_data, splitby, total_weight)

        for mat_name, mat_data in materials.items():
            material_data = self._get_material_data(mat_data)
            self.add_material(mat_name, material_data, splitby, total_weight)

    def _init_from_cross_section(self, other: "CrossSection", name: str = "") -> None:
        """Copy the material/table state of ``other`` into this freshly initialized instance.

        The copy is independent: later changes to ``other`` do not affect this object.
        """
        self.name = name or other.name
        self.materials = deepcopy(other.materials)
        self.table = other.table.copy()
        self.weights = other.weights.copy()
        self.isotopes = dict(other.isotopes)
        self.n = other.n
        energy_grid = getattr(other, "_energy_grid", None)
        if energy_grid is not None:
            self._energy_grid = energy_grid

    def __deepcopy__(self, memo):
        """Deep-copy the instance while sharing the read-only cross-section library.

        The library Series are never modified by this class, so copying them on every
        ``+``/``*`` is pure overhead. Each copy still gets its own dict so adding or
        removing keys on one instance does not affect another.
        """
        cls = type(self)
        clone = cls.__new__(cls)
        memo[id(self)] = clone
        for key, value in self.__dict__.items():
            if key == "__xsdata__" and value is not None:
                clone.__dict__[key] = dict(value)
            else:
                clone.__dict__[key] = deepcopy(value, memo)
        return clone

    def _load_xsdata(self):
        """Load cross-section data into ``self.__xsdata__``.

        First tries to use the trinidi_data package if installed. If not found,
        downloads xsdata.npy to a local cache directory (only when it is not
        already there). The parsed data is cached for the whole process.
        """
        if getattr(self, "__xsdata__", None) is not None:
            return
        if CrossSection._XSDATA_CACHE is None:
            CrossSection._XSDATA_CACHE = self._read_xsdata(self._locate_xsdata())
        # Per-instance dict (values are shared, cheap references).
        self.__xsdata__ = dict(CrossSection._XSDATA_CACHE)

    @classmethod
    def _locate_xsdata(cls) -> str:
        """Return the path of xsdata.npy, downloading it into the user cache if needed."""
        try:
            import trinidi_data
        except ImportError:
            # trinidi_data not installed, use local cache
            import platformdirs

            cache_dir = platformdirs.user_cache_dir("nres", "nres")
            os.makedirs(cache_dir, exist_ok=True)
            data_path = os.path.join(cache_dir, "xsdata.npy")
            if not os.path.exists(data_path):
                cls._download_xsdata(data_path)
            return data_path
        return os.path.join(os.path.dirname(trinidi_data.__file__), "xsdata.npy")

    @classmethod
    def _download_xsdata(cls, data_path: str) -> None:
        """Download the cross-section library to ``data_path``.

        The file is written to a temporary sibling and moved into place only after
        a complete download, so an interrupted transfer never leaves a truncated
        file that later runs would mistake for a valid cache entry.
        """
        import contextlib
        import tempfile

        import requests

        print(f"Downloading cross-section data to {data_path}...")
        fd, tmp_path = tempfile.mkstemp(
            dir=os.path.dirname(data_path) or ".", suffix=".part"
        )
        try:
            with os.fdopen(fd, "wb") as f, requests.get(
                cls._XSDATA_URL, stream=True, timeout=60
            ) as response:
                response.raise_for_status()
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
            os.replace(tmp_path, data_path)
        except BaseException:
            with contextlib.suppress(OSError):
                os.remove(tmp_path)
            raise
        print("Download complete!")

    @staticmethod
    def _read_xsdata(data_path: str) -> Dict[str, pd.Series]:
        """Parse xsdata.npy into ``{isotope_without_dashes: Series(xs, index=energy)}``."""
        # SECURITY NOTE(review): allow_pickle=True unpickles the file contents, which can
        # execute code. The file may come from a network download; only use trusted sources.
        xsdata = np.load(data_path, allow_pickle=True)[()]
        library = {}
        for i, isotope in enumerate(xsdata["isotopes"]):
            key = isotope.replace("-", "")
            library[key] = pd.Series(
                xsdata["cross_sections"][i], index=xsdata["energies"][i], name=key
            )
        return library

    def _get_material_data(self, material: Union[str, Dict]) -> Dict:
        """Convert a material string or dict to a full material data structure.

        Strings are looked up, in order, in ``nres.materials`` (by name or formula),
        ``nres.elements`` (by capitalized name or formula) and ``nres.isotopes``.
        Non-string input is returned unchanged.

        Raises:
            KeyError: If a string matches none of the databases.
        """
        if isinstance(material, str):
            formulas = {
                nres.materials[element]["formula"]: nres.materials[element]["name"]
                for element in nres.materials
            }
            try:  # Try materials database
                material = nres.materials[formulas.get(material, material)]
            except KeyError:  # Try elements database
                try:
                    formulas = {
                        nres.elements[element]["formula"]: nres.elements[element]["name"]
                        for element in nres.elements
                    }
                    material = nres.elements[
                        formulas.get(material.capitalize(), material.capitalize())
                    ]
                except KeyError:  # Try isotopes database
                    material = nres.isotopes[material.capitalize().replace("-", "")]
        return material

    def _set_weights(self, weights: Optional[List[float]] = None):
        """Set and normalize weights for all components, then refresh totals.

        Args:
            weights: Optional list of new weights, in the order of ``self.isotopes``.
                If omitted, the weights stored in ``self.isotopes`` are used.

        Raises:
            ValueError: If the number of provided weights doesn't match
                the number of components.

        Note:
            Zero-weight components are dropped from ``self.weights``. ``self.isotopes``
            and ``self.materials`` are not modified, so a later recalculation
            (e.g. ``add_material``) restores the material-derived weights.
        """
        if weights is not None:
            if len(weights) != len(self.isotopes):
                raise ValueError("Number of weights must match number of isotopes")
            self.weights = pd.Series(weights, index=list(self.isotopes))
        else:
            self.weights = pd.Series(self.isotopes)
        # Remove zero-weight components and normalize
        self.weights = self.weights[self.weights > 0]
        self.weights /= self.weights.sum()
        # Columns are matched to weights by label; columns without a weight are skipped.
        self.table["total"] = (
            (self.table.drop(columns="total", errors="ignore") * self.weights)
            .sum(axis=1)
            .astype(float)
        )
        self.n = self._update_atomic_density()

    def add_material(
        self,
        name: str,
        material_data: Dict,
        splitby: str = "elements",
        total_weight: float = 1.0,
    ):
        """
        Add a new material and recompute the combined cross-section table.

        Args:
            name: str, name of the material (an existing material of the same name is replaced)
            material_data: Dict, material composition data with an ``"elements"`` mapping
                of element -> {"weight": ..., "isotopes": {isotope: weight}}
            splitby: str, how to split the material ('elements', 'isotopes', or 'materials')
            total_weight: float, total weight of the material

        Raises:
            ValueError: If ``splitby`` is not a supported mode.
        """
        if splitby not in self._SPLITBY_OPTIONS:
            raise ValueError(
                f"splitby must be one of {self._SPLITBY_OPTIONS}, got {splitby!r}"
            )
        # Deep copy so later edits to the caller's dict don't leak into this object.
        self.materials[name] = deepcopy(material_data)
        self.materials[name]["splitby"] = splitby
        self.materials[name]["total_weight"] = total_weight

        # Energy grid = current table grid + grids of every isotope of the new material.
        energy_grids = []
        table = getattr(self, "table", None)
        if table is not None and len(table) > 0:
            energy_grids.append(table.index)
        for element_info in material_data["elements"].values():
            for isotope in element_info["isotopes"]:
                isotope_clean = isotope.replace("-", "")
                if isotope_clean in self.__xsdata__:
                    energy_grids.append(self.__xsdata__[isotope_clean].index)

        merged_grid = self._merge_energy_grids(energy_grids)
        if merged_grid is not None:
            self._energy_grid = merged_grid
        self._recalculate_cross_sections()

    @staticmethod
    def _merge_energy_grids(energy_grids: "List[pd.Index]") -> "Optional[pd.Index]":
        """Return the sorted union of the non-empty grids, or None if there are none."""
        arrays = [np.asarray(grid, dtype=float) for grid in energy_grids if len(grid) > 0]
        if not arrays:
            return None
        return pd.Index(np.unique(np.concatenate(arrays)))

    def _unique_material_name(self, name: str) -> str:
        """Return ``name``, suffixed with ``_2``, ``_3``, ... if already used by a material."""
        if name not in self.materials:
            return name
        suffix = 2
        while f"{name}_{suffix}" in self.materials:
            suffix += 1
        return f"{name}_{suffix}"

    def _update_atomic_density(self) -> float:
        """Calculate the total atomic density from the materials and current weights.

        For each material, ``n * weight * total_weight`` is summed over the
        material's components present in ``self.weights``.

        NOTE(review): the weights used are the *global* normalized weights, so a
        component shared by several materials contributes its combined weight once
        per material. Confirm this is the intended definition of ``n``.
        """
        new_n = 0
        for material_name, material_info in self.materials.items():
            material_weight = material_info["total_weight"]
            splitby = material_info["splitby"]
            if splitby == "isotopes":
                for element_info in material_info["elements"].values():
                    for isotope in element_info["isotopes"]:
                        isotope_clean = isotope.replace("-", "")
                        if isotope_clean in self.weights.index:
                            new_n += (
                                material_info["n"]
                                * self.weights[isotope_clean]
                                * material_weight
                            )
            elif splitby == "elements":
                for element in material_info["elements"]:
                    if element in self.weights.index:
                        new_n += (
                            material_info["n"] * self.weights[element] * material_weight
                        )
            elif splitby == "materials":
                if material_name in self.weights.index:
                    new_n += (
                        material_info["n"]
                        * self.weights[material_name]
                        * material_weight
                    )
        return new_n

    @staticmethod
    def _base_name(column) -> str:
        """Strip a legacy ``_iso_<n>`` suffix from a table column name."""
        return str(column).split("_iso_")[0]

    def _available_isotopes(
        self, element_info: Dict
    ) -> "tuple[Dict[str, pd.Series], Dict[str, float]]":
        """Return (cross sections, raw weights) of the element's isotopes present in the library."""
        xs_by_isotope: Dict[str, pd.Series] = {}
        isotope_weights: Dict[str, float] = {}
        for isotope, weight in element_info["isotopes"].items():
            isotope_clean = isotope.replace("-", "")
            if isotope_clean in self.__xsdata__:
                xs_by_isotope[isotope_clean] = self.__xsdata__[isotope_clean]
                isotope_weights[isotope_clean] = weight
        return xs_by_isotope, isotope_weights

    def _weighted_xs(
        self, xs_by_isotope: Dict[str, pd.Series], isotope_weights: Dict[str, float]
    ) -> pd.Series:
        """Interpolate isotope cross sections onto one grid and return their normalized weighted sum."""
        merged = self._interleave_xs_energies(
            [pd.DataFrame({name: xs}) for name, xs in xs_by_isotope.items()]
        )
        weights = pd.Series(isotope_weights, dtype=float)
        weight_sum = weights.sum()
        if weight_sum != 0:
            weights = weights / weight_sum
        total = pd.Series(0.0, index=merged.index)
        for column in merged.columns:
            isotope = self._base_name(column)
            if isotope in weights:
                total += merged[column] * weights[isotope]
        return total

    def _recalculate_cross_sections(self):
        """Rebuild ``table``, ``weights``, ``isotopes`` and ``n`` from ``self.materials``.

        Each material contributes components according to its ``splitby``:
        one per isotope, one per element (isotope-weighted sum), or one per
        material (isotope-weighted sum over all its elements). A component that
        appears in several materials gets the sum of their weights.
        """
        if not self.materials:
            return

        cross_sections: Dict[str, pd.Series] = {}
        combined_weights: Dict[str, float] = {}

        def add_component(key: str, xs: pd.Series, weight: float) -> None:
            # Weights accumulate across materials (previously the last one silently won).
            # NOTE(review): for a repeated key the last cross section is kept; this only
            # matters if the same element has different isotopic compositions in two materials.
            cross_sections[key] = xs
            combined_weights[key] = combined_weights.get(key, 0.0) + weight

        for material_name, material_info in self.materials.items():
            splitby = material_info["splitby"]
            total_weight = material_info["total_weight"]
            elements = material_info["elements"]

            if splitby == "isotopes":
                # NOTE(review): isotope weights are used without the element weight; this is
                # correct only if they are already material-level fractions — confirm schema.
                for element_info in elements.values():
                    xs_by_isotope, isotope_weights = self._available_isotopes(element_info)
                    for isotope, xs in xs_by_isotope.items():
                        add_component(isotope, xs, isotope_weights[isotope] * total_weight)
            elif splitby == "elements":
                for element, element_info in elements.items():
                    xs_by_isotope, isotope_weights = self._available_isotopes(element_info)
                    if xs_by_isotope:
                        add_component(
                            element,
                            self._weighted_xs(xs_by_isotope, isotope_weights),
                            element_info["weight"] * total_weight,
                        )
            elif splitby == "materials":
                # NOTE(review): same schema assumption as the "isotopes" branch above.
                xs_by_isotope, isotope_weights = {}, {}
                for element_info in elements.values():
                    element_xs, element_weights = self._available_isotopes(element_info)
                    xs_by_isotope.update(element_xs)
                    isotope_weights.update(element_weights)
                if xs_by_isotope:
                    add_component(
                        material_name,
                        self._weighted_xs(xs_by_isotope, isotope_weights),
                        total_weight,
                    )

        if not cross_sections:
            return

        table = self._interleave_xs_energies(
            [pd.DataFrame({name: xs}) for name, xs in cross_sections.items()]
        )
        energy_grid = getattr(self, "_energy_grid", None)
        if energy_grid is not None:
            # Re-sample on the stored (merged) grid, interpolating in energy.
            table = self._interleave_xs_energies([table.reindex(energy_grid)])

        weight_total = sum(combined_weights.values())
        if weight_total != 0:
            combined_weights = {k: v / weight_total for k, v in combined_weights.items()}

        self.table = table
        self.table.index.name = "energy"
        self.weights = pd.Series(combined_weights, dtype=float)
        column_weights = pd.Series(
            [self.weights.get(self._base_name(col), 0.0) for col in self.table.columns],
            index=self.table.columns,
            dtype=float,
        )
        self.table["total"] = (self.table * column_weights).sum(axis=1).astype(float)
        self.isotopes = self.weights.to_dict()
        self.n = self._update_atomic_density()

    def __add__(self, other: "CrossSection") -> "CrossSection":
        """Return a new CrossSection containing the materials of both operands.

        Materials of ``other`` whose names clash with materials of ``self`` are
        renamed (``name_2``, ``name_3``, ...) instead of overwriting them.
        """
        if not isinstance(other, CrossSection):
            return NotImplemented
        new_self = deepcopy(self)
        energy_grids = [
            cs.table.index for cs in (self, other) if getattr(cs, "table", None) is not None
        ]
        for mat_name, mat_info in other.materials.items():
            new_self.add_material(
                name=new_self._unique_material_name(mat_name),
                material_data=mat_info,  # add_material deep-copies it
                splitby=mat_info["splitby"],
                total_weight=mat_info["total_weight"],
            )
        merged_grid = self._merge_energy_grids(energy_grids)
        if merged_grid is not None:
            new_self._energy_grid = merged_grid
        new_self._recalculate_cross_sections()
        return new_self

    def __mul__(self, total_weight: float = 1.0) -> "CrossSection":
        """Return a copy with every material's total weight scaled by ``total_weight``."""
        new_self = deepcopy(self)
        new_self.total_weight = total_weight
        for material_name in new_self.materials:
            new_self.materials[material_name]["total_weight"] *= total_weight
        new_self._recalculate_cross_sections()
        return new_self

    def __rmul__(self, total_weight: float = 1.0) -> "CrossSection":
        """Right multiplication to support scalar * CrossSection."""
        return self.__mul__(total_weight)

    def __call__(
        self,
        E: np.ndarray,
        weights: Optional[np.ndarray] = None,
        response: Optional[np.ndarray] = None,
    ) -> np.ndarray:
        """Evaluate the total cross-section at energies ``E`` via ``integrate_cross_section``.

        Args:
            E: Energies at which to evaluate.
            weights: Optional new component weights; NOTE this permanently updates
                ``self.weights`` and ``self.table["total"]`` (see ``_set_weights``).
            response: Optional response kernel; defaults to ``[1.0]``.
        """
        if weights is not None:
            self._set_weights(weights=weights)
        if response is None:
            response = [1.0]
        return np.array(
            integrate_cross_section(
                self.table["total"].index.values,
                self.table["total"].values,
                E,
                response,
            )
        )

    def _interleave_xs_energies(self, xs_data):
        """
        Merge cross-section data onto the union of their energy grids.

        Each column is linearly interpolated *in energy* (``method="index"``) at the
        grid points it lacks; points below a column's first energy stay NaN and
        points above its last energy take its last value.

        Args:
            xs_data: DataFrame or Series with an energy index, or a list of them.

        Returns:
            DataFrame (or Series, for Series input) on the combined energy grid.
            Columns keep their original names; a repeated name keeps the last one.
        """
        if isinstance(xs_data, pd.Series):
            return self._interleave_xs_energies(pd.DataFrame(xs_data)).iloc[:, 0]
        if not isinstance(xs_data, list):
            xs_data = [xs_data]

        all_energies = sorted(set().union(*[df.index for df in xs_data]))

        columns = {}
        for i, xs_df in enumerate(xs_data):
            if isinstance(xs_df, pd.Series):
                xs_df = xs_df.to_frame(f"xs_{i}")
            xs_reindexed = xs_df.reindex(all_energies)
            for col in xs_reindexed.columns:
                # "linear" would treat rows as equally spaced; energy grids are not.
                columns[f"{col}" if len(xs_data) > 1 else col] = xs_reindexed[
                    col
                ].interpolate(method="index")
        return pd.DataFrame(columns, index=pd.Index(all_energies))

    def _weighted_components(self, table: pd.DataFrame) -> pd.DataFrame:
        """Scale each component column of ``table`` by its weight and relabel it with the percentage.

        Weights are matched to columns by label (components without a weight count
        as 0). The ``"total"`` column is appended last, unscaled.
        """
        components = table.drop(columns="total")
        weights = self.weights.reindex(components.columns).fillna(0.0)
        scaled = components.mul(weights, axis=1)
        scaled.columns = [
            f"{column}: {weight*100:>6.2f}%" for column, weight in weights.items()
        ]
        scaled["total"] = table["total"]
        return scaled

    def plot(self, **kwargs):
        """
        Create a matplotlib plot of the cross-section data.

        Generates a plot showing the total cross-section and individual
        component contributions, with weights displayed in the legend.

        Args:
            **kwargs: Additional keyword arguments for plot customization:
                - title: Plot title (default: self.name)
                - ylabel: Y-axis label (default: r"$\\sigma$ [barn]")
                - xlabel: X-axis label (default: "Energy [eV]")
                - lw: Line width (default: 1.0)
                - Additional arguments passed to pandas.DataFrame.plot

        Returns:
            matplotlib.axes.Axes: The axes object containing the plot

        Note:
            The total cross-section is plotted in dark grey with increased line width
            and on top. Individual contributions are plotted scaled by their weights,
            with the weights shown as percentages in the legend.
        """
        import matplotlib.pyplot as plt

        title = kwargs.pop("title", self.name)
        ylabel = kwargs.pop("ylabel", r"$\sigma$ [barn]")
        xlabel = kwargs.pop("xlabel", "Energy [eV]")
        lw = kwargs.pop("lw", 1.0)

        table = self._weighted_components(self.table)

        _, ax = plt.subplots()
        table.plot(y="total", linewidth=1.5, ax=ax, color="0.2", zorder=100, **kwargs)
        table.drop("total", axis=1).plot(
            ax=ax, title=title, xlabel=xlabel, ylabel=ylabel, linewidth=lw, **kwargs
        )
        return ax

    def iplot(self, **kwargs):
        """
        Create an interactive plotly plot of the cross-section data.

        Generates an interactive plot showing the total cross-section and
        individual component contributions, with customizable axes scales
        and energy range. The plotly backend is selected for this call only;
        the global pandas plotting backend is left unchanged.

        Args:
            **kwargs: Additional keyword arguments for plot customization:
                - title: Plot title (default: self.name)
                - ylabel: Y-axis label (default: "σ [barn]")
                - xlabel: X-axis label (default: "Energy [eV]")
                - emin: Minimum energy to plot (default: 0.1 eV)
                - emax: Maximum energy to plot (default: 2e7 eV)
                - scalex: X-axis scale ("log" or "linear", default: "log")
                - scaley: Y-axis scale ("log" or "linear", default: "log")
                - Additional arguments passed to pandas.DataFrame.plot

        Returns:
            plotly.graph_objects.Figure: Interactive figure object
        """
        title = kwargs.pop("title", self.name)
        ylabel = kwargs.pop("ylabel", "σ [barn]")
        xlabel = kwargs.pop("xlabel", "Energy [eV]")
        emin = kwargs.pop("emin", 0.1)
        emax = kwargs.pop("emax", 2e7)
        scalex = kwargs.pop("scalex", "log")
        scaley = kwargs.pop("scaley", "log")
        # Per-call backend instead of mutating pd.options (which broke plot()).
        kwargs.setdefault("backend", "plotly")

        energy = self.table.index
        in_range = np.asarray((energy >= emin) & (energy <= emax))
        table = self._weighted_components(self.table.loc[in_range])

        fig = table.plot(**kwargs)
        fig.update_layout(
            xaxis_type=scalex,
            yaxis_type=scaley,
            xaxis_title=xlabel,
            yaxis_title=ylabel,
            title_text=title,
        )
        return fig