Source code for pandora.margins.margins

# Copyright (c) 2024 Centre National d'Etudes Spatiales (CNES).
#
# This file is part of PANDORA
#
#     https://github.com/CNES/Pandora
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""
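This module defines the Margins dataclass, the MarginDict and GlobalMargins containers,
and the max_margins helper used to compute margins.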
"""

from __future__ import annotations

from collections import UserDict
from dataclasses import dataclass, astuple, asdict
import operator
from functools import reduce
from typing import Sequence, Tuple, Dict

__all__ = ["Margins", "max_margins", "MarginDict", "GlobalMargins"]


@dataclass(order=True, frozen=True)
class Margins:
    """
    Tuple of margins.

    Frozen, orderable dataclass holding one integer margin per direction, in the
    order (left, up, right, down). Negative values are rejected at construction;
    zero is accepted.

    :raises ValueError: if any of the four margins is negative.
    """

    left: int
    up: int
    right: int
    down: int

    def __post_init__(self):
        """
        Validate the margins right after the dataclass-generated ``__init__``.

        :raises ValueError: if any margin is strictly negative.
        """
        values = self.astuple()
        if any(m < 0 for m in values):
            raise ValueError(f"Margins values should be positive. Got {values}")

    def __add__(self, other: Margins) -> Margins:
        """
        Add two Margins direction by direction.

        :param other: Margins to add to self.
        :type other: Margins
        :return: a new Margins whose fields are the pairwise sums, or ``NotImplemented``
            when ``other`` is not a Margins.
        :rtype: Margins
        """
        # Let Python try the reflected operation / raise the usual TypeError instead
        # of failing with an unrelated AttributeError on ``other.astuple``.
        if not isinstance(other, Margins):
            return NotImplemented
        return Margins(*map(operator.add, self.astuple(), other.astuple()))

    def astuple(self) -> Tuple:
        """Convert self to a tuple of (left, up, right, down)."""
        # Inside a method body the bare name resolves to ``dataclasses.astuple``,
        # not to this method.
        return astuple(self)

    def asdict(self) -> Dict:
        """Convert self to a dictionary."""
        # Same remark as above: this calls ``dataclasses.asdict``.
        return asdict(self)
class MarginDict(UserDict):
    """
    A dictionary that stores Margins.

    Every assignment is type-checked, so values are always Margins instances.
    """

    def __setitem__(self, key, value: Margins):
        """
        Store ``value`` under ``key`` after checking its type.

        :param key: key under which the margins are stored.
        :param value: margins to store.
        :type value: Margins
        :raises ValueError: if ``value`` is not a Margins instance.
        """
        if isinstance(value, Margins):
            super().__setitem__(key, value)
            return
        raise ValueError(f"MarginDict only accept values of type Margins. Got {type(value)} instead.")

    def sum(self) -> Margins:
        """
        Compute the sum of margins on each direction.

        :return: direction-wise sum of every stored Margins; null margins when empty.
        :rtype: Margins
        """
        # Start from null margins so that an empty dictionary sums to Margins(0, 0, 0, 0).
        total = Margins(0, 0, 0, 0)
        for stored in self.data.values():
            total = total + stored
        return total
class GlobalMargins:
    """
    Class to store Margins and compute the global.

    Margins are kept in two exclusive groups: cumulative margins, which are summed
    together, and non-cumulative margins, of which only the maximum matters. A key
    cannot be registered in both groups.
    """

    def __init__(self):
        # Margins that cumulates:
        self._cumulatives = MarginDict()
        # Margins that takes maximum values
        self._non_cumulatives = MarginDict()

    def add_cumulative(self, key, value):
        """
        Add a margins that cumulates.

        :param key: identifier of the margins.
        :param value: margins to store under ``key``.
        :raises KeyError: if ``key`` is already registered as a non-cumulative margins.
        """
        if key in self._non_cumulatives:
            raise KeyError(
                f"{key} is already a non-cumulative margins. "
                "Cumulative margins and non-cumulative margins are exclusive."
            )
        self._cumulatives[key] = value

    def add_non_cumulative(self, key, value):
        """
        Add a margins that does not cumulate.

        :param key: identifier of the margins.
        :param value: margins to store under ``key``.
        :raises KeyError: if ``key`` is already registered as a cumulative margins.
        """
        if key in self._cumulatives:
            raise KeyError(
                f"{key} is already a cumulative margins. "
                "Cumulative margins and non-cumulative margins are exclusive."
            )
        self._non_cumulatives[key] = value

    def remove_cumulative(self, key):
        """
        Remove a margin that cumulates.

        :param key: identifier of the margins to remove.
        :raises KeyError: if ``key`` is not a cumulative margins.
        """
        # The deletion already raises KeyError(key) on a missing key; catching it
        # only to raise an identical error would add a misleading exception chain.
        del self._cumulatives[key]

    def remove_non_cumulative(self, key):
        """
        Remove a margin that does not cumulate.

        :param key: identifier of the margins to remove.
        :raises KeyError: if ``key`` is not a non-cumulative margins.
        """
        # See remove_cumulative: the deletion itself raises KeyError(key).
        del self._non_cumulatives[key]

    @property
    def cumulatives(self):
        """MarginDict of margins that cumulates (a copy of the internal storage)."""
        return MarginDict(self._cumulatives)

    @property
    def non_cumulatives(self):
        """MarginDict of margins that does not cumulate (a copy of the internal storage)."""
        return MarginDict(self._non_cumulatives)

    @property
    def global_margins(self):
        """
        Computed global margins.

        Element-wise maximum between the sum of the cumulative margins and each
        non-cumulative margins.
        """
        # Read the private storage directly: going through the public property
        # would copy the whole dictionary just to iterate over its values.
        return max_margins([self._cumulatives.sum(), *self._non_cumulatives.values()])

    def to_dict(self):
        """
        Convert self to a dictionary in order to be json serializable.

        :return: dictionary with the keys "cumulative margins", "non-cumulative margins"
            and "global margins", every Margins being converted with ``asdict``.
        """
        return {
            "cumulative margins": {s: m.asdict() for s, m in self._cumulatives.items()},
            "non-cumulative margins": {s: m.asdict() for s, m in self._non_cumulatives.items()},
            "global margins": self.global_margins.asdict(),
        }

    def get(self, key):
        """
        Find key in cumulative margins or non-cumulative margins and return corresponding margins

        :param key: identifier of the margins to look for.
        :return: the stored margins, or None when ``key`` is in neither group.
        """
        # Cumulative margins are searched first; groups are exclusive so order
        # does not change the result.
        for group in (self._cumulatives, self._non_cumulatives):
            if key in group:
                return group[key]
        return None
def max_margins(margins: Sequence[Margins]) -> Margins:
    """
    Return a Margins which is the max of margins element wise.

    An empty sequence yields ``Margins(0, 0, 0, 0)``: Margins reject negative values,
    so null margins are the neutral element of the element-wise maximum.

    :param margins: sequence of Margins to compute max of.
    :type margins: Sequence[Margins]
    :return: Maximum Margins
    :rtype: Margins
    """
    as_tuple_margins = [astuple(element) for element in margins]
    if not as_tuple_margins:
        # Nothing to compare: previously this crashed with an unrelated
        # "map() must have at least two arguments" TypeError.
        return Margins(0, 0, 0, 0)
    # zip(*...) regroups the values per direction, which works for a single
    # Margins as well as for many, so no special case is needed.
    return Margins(*(max(direction) for direction in zip(*as_tuple_margins)))