Source code for viperleed.calc.classes.atom_containers

"""Module atom_collections of viperleed.tleemlib.classes.

Contains definitions of collections (i.e., containers) of Atoms.
"""

__authors__ = (
    'Michele Riva (@michele-riva)',
    )
__copyright__ = 'Copyright (c) 2019-2024 ViPErLEED developers'
__created__ = '2023-10-06'
__license__ = 'GPLv3+'

from abc import abstractmethod
from collections.abc import Container, Iterable, MutableSequence, Sequence

from viperleed.calc.classes import atom


class AtomContainerError(Exception):
    """Base exception for containers of Atom objects."""


class AtomListOutOfDateError(AtomContainerError):
    """AtomList needs to update_atoms_map."""

    def __init__(self, *args):
        """Initialize instance."""
        if not args:
            args = ('Call .update_atoms_map',)
        super().__init__(*args)


class DuplicateAtomsError(AtomContainerError):
    """There are multiple atoms with the same num."""

    def __init__(self, *args):
        """Initialize instance."""
        if not args:
            args = ('Multiple atoms with the same .num',)
        super().__init__(*args)


# pylint: disable=too-few-public-methods
# Disabled because AtomCollection is essentially only meant to
# be the base class for collections of Atom objects. For now,
# the only reasonable methods are n_atoms and __contains__.
[docs]class AtomContainer(Container): """Base class for all other collections of Atom objects.""" @property @abstractmethod def n_atoms(self): """Return the number of atoms in this AtomList."""
# pylint: enable=too-few-public-methods
[docs]class AtomList(AtomContainer, MutableSequence): """A sequence of Atom objects. Attributes ---------- strict : bool Whether this AtomList only allows unique .num values. This attribute can temporarily be turned off if atom numbers are made unique later. After .num(s) are unique, it is advisable to restore strict to True and to call update_atoms_map. A strict AtomList has some of its operations optimized for O(1) access. """
[docs] def __init__(self, *atoms, strict=True): """Initialize instance.""" if atoms and len(atoms) == 1 and isinstance(atoms[0], Iterable): atoms = atoms[0] self._atoms = list(atoms) if not all(isinstance(at, atom.Atom) for at in self): raise TypeError( f'{type(self).__name__}: only Atom objects allowed' ) self._map = {at.num: at for at in self._atoms} self.strict = strict # True == allow only unique num self._sort_map = {} # Atom: index, for restoring sorting if self.strict and len(self._map) != self.n_atoms: raise DuplicateAtomsError
def __contains__(self, value): """Return whether an atom is in this list.""" if not isinstance(value, atom.Atom): return False try: atom_in_self = self.get(value.num) except KeyError: return False return value is atom_in_self def __delitem__(self, index): """Remove atom(s) at index (or slice).""" try: to_remove = self._atoms[index] except IndexError: raise IndexError('list assignment index out of range') from None del self._atoms[index] if not isinstance(index, slice): to_remove = [to_remove] for atom_ in to_remove: del self._map[atom_.num] def __eq__(self, other): """Return whether this list is equal to another one.""" try: equal = self.n_atoms == len(other) except TypeError: equal = False if not equal: return NotImplemented equal = all(at_self is at_other for at_self, at_other in zip(self, other)) if not equal: return NotImplemented return True def __iter__(self): """Return an iterator of atoms in this AtomList.""" # Reimplemented as the Sequence ABC uses a for and __getitem__ return iter(self._atoms) def __getitem__(self, index): """Return the Atom(s) at index (or slice).""" return self._atoms[index] def __len__(self): """Return the number of atoms in this list.""" return len(self._atoms) def __repr__(self): """Return a string representation of this atom list.""" _repr_ = f'{type(self).__name__}(' _repr_ += ', '.join(str(at) for at in self) return _repr_ + f', strict={self.strict})' def __reversed__(self): """Return a reverse iterator of atoms in this AtomList.""" return reversed(self._atoms) def __setitem__(self, index, value): """Assign the atom(s) at index (or slice) to value.""" # When a slice, we expect an iterable, otherwise a single # element. For the case of a slice, make sure we consume a # possible iterator right away, as we'll need it twice if isinstance(value, Sequence): new_atoms = value elif isinstance(value, Iterable): new_atoms = tuple(value) elif not isinstance(index, slice): new_atoms = (value,) if isinstance(index, slice): replaced = set(self._atoms[index]) else: replaced = {self._atoms[index]} try: _invalid = {at for at in new_atoms if at not in replaced and at.num in self._map} except AttributeError: raise TypeError( f'{type(self).__name__}: only Atom objects allowed' ) from None if _invalid and self.strict: raise DuplicateAtomsError( f'Atom(s) {_invalid} have the same .num as others ' f'already present in this {type(self).__name__}' ) self._map.update((at.num, at) for at in new_atoms) if not isinstance(index, slice): new_atoms = new_atoms[0] self._atoms[index] = new_atoms @property def n_atoms(self): """Return the number of atoms in this AtomList.""" return len(self) def clear(self): """Remove all atoms.""" self._atoms.clear() self._map = {} self._sort_map = {} def copy(self): """Return a copy of this AtomList.""" return self.__class__(*self, strict=self.strict) def count(self, value): """Return the number of occurrences of value.""" try: present = self._map[value.num] is value except (KeyError, AttributeError): present = False if self.strict: return 1 if present else 0 if not present: return 0 return self._atoms.count(value) _no_default = object() def get(self, num, default=_no_default): """Return the atom with a given number.""" if not self.strict: raise NotImplementedError( f'Cannot use .get for non-strict {type(self).__name__}. You ' f'can use, e.g., next(at for at in atoms if at.num=={num}) ' '(for the first occurrence)' ) try: return self._map[num] except KeyError: pass if default is not self._no_default: return default if self.n_atoms != len(self._map): raise AtomListOutOfDateError raise KeyError(num) def index(self, value, *args): """Return the index of an atom.""" return self._atoms.index(value, *args) def insert(self, index, value): """Insert a single atom at a given index.""" if not isinstance(value, atom.Atom): raise TypeError( f'{type(self).__name__}: only Atom objects allowed' ) if value.num in self._map and self.strict: raise DuplicateAtomsError( f'{value} has the same .num as antoher one ' f'already present in this {type(self).__name__}' ) self._atoms.insert(index, value) self._map[value.num] = value def restore_sorting(self, sort_map=None): """Restore a sorting order (or the last one stored).""" if not sort_map: sort_map = self._sort_map if not sort_map: return try: self.sort(key=sort_map.get) except (TypeError, KeyError) as exc: raise RuntimeError( 'Impossible to restore sorting as some elements ' 'have been added since save_sorting() was called' ) from exc def reverse(self): """Reverse in place.""" self._atoms.reverse() def save_sorting(self): """Store the current sorting so it can be later restored.""" self._sort_map = dict(zip(self, range(self.n_atoms))) return self._sort_map def sort(self, *, key=None, reverse=False): """Sort this list of atoms.""" self._atoms.sort(key=key, reverse=reverse) def update_atoms_map(self): """Refresh the {num: atom} map. Typically called after atom numbers have been modified. Raises ------ DuplicateAtomsError If multiple atoms in self have the same .num. """ self._map = {at.num: at for at in self._atoms} if len(self) != len(self._map): raise DuplicateAtomsError