#!/usr/bin/env python
# coding: utf8
#
# Copyright (c) 2026 Centre National d'Etudes Spatiales (CNES).
#
# This file is part of PANDORA
#
# https://github.com/CNES/Pandora
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
This module contains functions associated to the validity mask created in the cost volume step.
"""
from typing import Union, Tuple
from enum import IntFlag
import numpy as np
from numpy.typing import DTypeLike, ArrayLike
from scipy.ndimage import binary_dilation
import xarray as xr
from pandora.constants import Criteria
from pandora.profiler import profile
from .cpp import criteria_cpp
[docs]
class FlagArray(np.ndarray):
    """
    NDArray subclass that expects to be filled with Flags and with dedicated repr.

    The ``flag`` attribute stores the IntFlag class used by ``__repr__`` to render each element by its
    flag names. It is propagated to views and slices through ``__array_finalize__``.
    """

    def __new__(cls, input_array: ArrayLike, flag: type[IntFlag], dtype: DTypeLike = np.uint16):
        """
        Build a FlagArray from any array-like input.

        :param input_array: values to store (converted with np.asarray)
        :type input_array: ArrayLike
        :param flag: IntFlag class used to render the elements in repr
        :type flag: type[IntFlag]
        :param dtype: dtype of the stored values. Default np.uint16.
        :type dtype: DTypeLike
        :return: the new array
        :rtype: FlagArray
        """
        obj = np.asarray(input_array, dtype=dtype).view(cls)
        obj.flag = flag
        return obj

    def __array_finalize__(self, obj):
        # Views and slices inherit the flag class of the array they come from
        # (None when that array is a plain ndarray).
        if obj is None:
            return
        self.flag = getattr(obj, "flag", None)  # pylint: disable=attribute-defined-outside-init

    def __repr__(self) -> str:
        """
        Represent the array with each element shown as its flag names.

        Falls back to the ndarray repr when no flag class is attached.
        """
        if self.flag is None:
            return super().__repr__()
        flag = self.flag
        # Strip the "ClassName." qualifier so elements stay short.
        qualifier = flag.__name__ + "."
        max_line_width = np.get_printoptions()["linewidth"]
        prefix = f"{self.__class__.__name__}<{flag.__name__}>"
        suffix = f"dtype={self.dtype}"

        def format_element(value) -> str:
            # Formatted lazily: a table precomputed for range(sum(flag)) missed the value with every
            # flag bit set (IndexError) and cost one repr per possible bit combination.
            # int() is needed because Flag lookup rejects numpy integer scalars.
            return repr(flag(int(value))).replace(qualifier, "")

        array_repr = np.array2string(
            self,
            prefix=prefix,
            formatter={"int_kind": format_element},
            separator=", ",
            suffix=suffix,
            max_line_width=max_line_width,
        )
        return f"{prefix}({array_repr}, {suffix})"
[docs]
def binary_dilation_msk(img: xr.Dataset, window_size: int) -> np.ndarray:
    """
    Locate the no_data pixels of the image mask, grown by the aggregation window.

    Every pixel whose msk value equals the ``no_data_mask`` attribute is marked. The marks are then
    dilated once with a ``window_size`` x ``window_size`` square of ones (scipy binary_dilation).
    The result is a pixel whose aggregation window touches a no_data pixel is marked too.

    :param img: Dataset image containing :

            - im: 2D (row, col) or 3D (band_im, row, col) xarray.DataArray float32
            - disparity (optional): 3D (disp, row, col) xarray.DataArray float32
            - msk (optional): 2D (row, col) xarray.DataArray int16
            - classif (optional): 3D (band_classif, row, col) xarray.DataArray int16
            - segm (optional): 2D (row, col) xarray.DataArray int16
            - edges (optional): 2D (row, col) xarray.DataArray int16
    :type img: xarray.Dataset
    :param window_size: window size of the cost volume
    :type window_size: int
    :return: boolean array, True where the pixel is no_data according to the dilated image mask
    :rtype: np.ndarray
    """
    is_no_data = img["msk"].data == img.attrs["no_data_mask"]
    square_window = np.ones((window_size, window_size))
    return binary_dilation(is_no_data, structure=square_window, iterations=1)
[docs]
def allocate_validity_mask(
    cv: xr.Dataset,
    value: int | Criteria = Criteria.PANDORA_VALID,
    data_type: DTypeLike = np.uint16,
) -> xr.DataArray:
    """
    Allocate the validity_mask DataArray of the cost volume, filled with ``value``.

    The values are stored in a FlagArray tagged with Criteria, so that its repr shows flag names.

    :param cv: cost volume dataset (provides row, col sizes)
    :type cv: xarray.Dataset
    :param value: fill value (default Criteria.PANDORA_VALID = 0)
    :type value: int or Criteria
    :param data_type: dtype of the validity mask. Default np.uint16.
    :type data_type: DTypeLike
    :return: validity_mask 2D DataArray (row, col)
    :rtype: xarray.DataArray
    """
    shape = (cv.sizes["row"], cv.sizes["col"])
    filled = np.full(shape, value, dtype=data_type)
    flags = FlagArray(filled, Criteria, data_type)
    return xr.DataArray(flags, dims=["row", "col"])
@profile("validity_mask")
def validity_mask(
    img_left: xr.Dataset,
    img_right: xr.Dataset,
    cv: xr.Dataset,
) -> xr.Dataset:
    """
    Create the validity mask of the cost volume.

    The mask is allocated with allocate_validity_mask. Each column is then flagged according to where
    its disparity range falls relative to the usable columns of the right image,
    [col[0] + offset_row_col, col[-1] - offset_row_col]:

        - PANDORA_MSK_PIXEL_RIGHT_NODATA_OR_DISPARITY_RANGE_MISSING when the whole range lies outside
          (only evaluated for strictly negative or strictly positive disparity ranges),
        - PANDORA_MSK_PIXEL_RIGHT_INCOMPLETE_DISPARITY_RANGE when only part of it lies outside.

    The input masks ("msk") of both images are then applied. When the right mask exists, the variable
    disparity ranges ("disparity") of the left image are applied as well.

    :param img_left: left Dataset image containing :

            - im: 2D (row, col) or 3D (band_im, row, col) xarray.DataArray float32
            - disparity (optional): 3D (disp, row, col) xarray.DataArray float32
            - msk (optional): 2D (row, col) xarray.DataArray int16
            - classif (optional): 3D (band_classif, row, col) xarray.DataArray int16
            - segm (optional): 2D (row, col) xarray.DataArray int16
            - edges (optional): 2D (row, col) xarray.DataArray int16
    :type img_left: xarray.Dataset
    :param img_right: right Dataset image containing :

            - im: 2D (row, col) or 3D (band_im, row, col) xarray.DataArray float32
            - disparity (optional): 3D (disp, row, col) xarray.DataArray float32
            - msk (optional): 2D (row, col) xarray.DataArray int16
            - classif (optional): 3D (band_classif, row, col) xarray.DataArray int16
            - segm (optional): 2D (row, col) xarray.DataArray int16
            - edges (optional): 2D (row, col) xarray.DataArray int16
    :type img_right: xarray.Dataset
    :param cv: cost volume dataset with the data variables:

            - cost_volume 3D xarray.DataArray (row, col, disp)
            - confidence_measure (optional) 3D xarray.DataArray (row, col, indicator)
    :type cv: xarray.Dataset
    :return: Dataset with the cost volume and the validity_mask with the data variables :

            - cost_volume 3D xarray.DataArray (row, col, disp)
            - confidence_measure 3D xarray.DataArray (row, col, indicator)
            - validity_mask 2D xarray.DataArray (row, col)
    :rtype: xarray.Dataset
    """
    cv["validity_mask"] = allocate_validity_mask(cv)

    # Disparity bounds and column coordinates, as created by the grid_estimation function
    d_min, d_max = cv.coords["disp"].data[[0, -1]]
    col = cv.coords["col"].data
    offset = cv.attrs["offset_row_col"]
    first_usable_col = col[0] + offset
    last_usable_col = col[-1] - offset

    if d_max < 0:
        # Negative disparity range: only the left border of the right image can be reached
        bit_1 = np.where((col + d_max) < first_usable_col)  # type: ignore[var-annotated]
        incomplete = ((col + d_max) >= first_usable_col) & ((col + d_min) < first_usable_col)
    elif d_min > 0:
        # Positive disparity range: only the right border of the right image can be reached
        bit_1 = np.where((col + d_min) > last_usable_col)
        incomplete = ((col + d_min) <= last_usable_col) & ((col + d_max) > last_usable_col)
    else:
        # Disparity range contains 0: no column is flagged as missing, but either border may be reached
        bit_1 = ([],)
        incomplete = ((col + d_min) < first_usable_col) | ((col + d_max) > last_usable_col)

    mask_data = cv["validity_mask"].data
    # Information: the disparity interval is incomplete (border reached in the right image)
    mask_data[:, incomplete] |= np.uint16(Criteria.PANDORA_MSK_PIXEL_RIGHT_INCOMPLETE_DISPARITY_RANGE)
    # Invalid pixel: the disparity interval is missing in the right image (disparity range outside the image)
    mask_data[:, bit_1] |= np.uint16(Criteria.PANDORA_MSK_PIXEL_RIGHT_NODATA_OR_DISPARITY_RANGE_MISSING)

    if "msk" in img_left.data_vars:
        allocate_left_mask(cv, img_left)

    if "msk" in img_right.data_vars:
        allocate_right_mask(cv, img_right, bit_1)
        # img right contains masked values and img left disp ranges: get the pixels affected
        if "disparity" in img_left.data_vars:
            mask_partially_missing_variable_ranges(cv, img_left, img_right)

    return cv
[docs]
def mask_partially_missing_variable_ranges(cv, img_left, img_right):
    """
    Flag the pixels whose variable disparity range is partially missing in the right image.

    The left image disparity ranges and a boolean "invalid" mask of the right image are passed to
    criteria_cpp.partially_missing_variable_ranges. The pixels it selects get
    PANDORA_MSK_PIXEL_INCOMPLETE_VARIABLE_DISPARITY_RANGE in the cost volume validity mask, which is
    modified in place.

    :param cv: Cost volume dataset, with a validity_mask data variable
    :type cv: xarray.Dataset
    :param img_left: Left image dataset, with a disparity data variable
    :type img_left: xarray.Dataset
    :param img_right: Right image dataset, with a msk data variable and a valid_pixels attribute
    :type img_right: xarray.Dataset
    """
    left_disparity_ranges = img_left["disparity"].data
    # Right image mask with True = invalid, False = valid
    right_invalid = img_right["msk"].data != img_right.attrs["valid_pixels"]
    affected = criteria_cpp.partially_missing_variable_ranges(left_disparity_ranges, right_invalid)
    # NOTE(review): `affected` indexes the (row, col) validity mask directly; presumably a boolean
    # (row, col) array - confirm against criteria_cpp.
    cv["validity_mask"].data[affected] |= np.uint16(Criteria.PANDORA_MSK_PIXEL_INCOMPLETE_VARIABLE_DISPARITY_RANGE)
[docs]
def allocate_left_mask(cv: xr.Dataset, img_left: xr.Dataset) -> None:
    """
    Apply the left image mask to the validity mask of the cost volume (in place).

    - PANDORA_MSK_PIXEL_LEFT_NODATA_OR_BORDER is set where the no_data pixels of the left mask,
      dilated by the aggregation window, are found.
    - PANDORA_MSK_PIXEL_IN_VALIDITY_MASK_LEFT is set where the left mask is neither no_data nor valid.

    :param cv: cost volume dataset with the data variables:

            - cost_volume 3D xarray.DataArray (row, col, disp)
            - confidence_measure (optional) 3D xarray.DataArray (row, col, indicator)
    :type cv: xarray.Dataset
    :param img_left: left Dataset image containing :

            - im: 2D (row, col) or 3D (band_im, row, col) xarray.DataArray float32
            - disparity (optional): 3D (disp, row, col) xarray.DataArray float32
            - msk (optional): 2D (row, col) xarray.DataArray int16
            - classif (optional): 3D (band_classif, row, col) xarray.DataArray int16
            - segm (optional): 2D (row, col) xarray.DataArray int16
            - edges (optional): 2D (row, col) xarray.DataArray int16
    :type img_left: xarray.Dataset
    :return: None
    """
    # Align the left input mask with the validity mask
    _, left_msk = xr.align(cv["validity_mask"], img_left["msk"])  # pylint: disable=unbalanced-tuple-unpacking

    # Dilatation: pixels that contain no_data in their aggregation window become no_data
    no_data_dilated = binary_dilation_msk(img_left, cv.attrs["window_size"])
    left_nodata_flag = np.uint16(Criteria.PANDORA_MSK_PIXEL_LEFT_NODATA_OR_BORDER)
    # Invalid pixel: no_data in the left image
    cv["validity_mask"].data |= no_data_dilated.astype(np.uint16) * left_nodata_flag

    # Invalid pixel: invalidated by the validity mask of the left image given as input
    in_input_mask = (left_msk != img_left.attrs["no_data_mask"]) & (left_msk != img_left.attrs["valid_pixels"])
    masked_flags = xr.where(
        in_input_mask,
        np.uint16(Criteria.PANDORA_MSK_PIXEL_IN_VALIDITY_MASK_LEFT),
        np.uint16(0),
    )
    cv["validity_mask"].data |= masked_flags.data.astype(np.uint16)
[docs]
def allocate_right_mask(cv: xr.Dataset, img_right: xr.Dataset, bit_1: Union[np.ndarray, Tuple]) -> None:
    """
    Apply the right image mask to the validity mask of the cost volume (in place).

    For each left pixel, every disparity of [d_min, d_max] is followed to its right position. A right
    position outside [offset_row_col, n_col - 1 - offset_row_col] counts as invalid for both criteria
    below. Pixels already listed in ``bit_1`` are excluded.

    - PANDORA_MSK_PIXEL_IN_VALIDITY_MASK_RIGHT when every right position is out of range or
      invalidated by the right input mask (neither no_data nor valid),
    - PANDORA_MSK_PIXEL_RIGHT_NODATA_OR_DISPARITY_RANGE_MISSING when every right position is out of
      range or no_data (no_data dilated by the aggregation window).

    :param cv: cost volume dataset with the data variables:

            - cost_volume 3D xarray.DataArray (row, col, disp)
            - confidence_measure (optional) 3D xarray.DataArray (row, col, indicator)
    :type cv: xarray.Dataset
    :param img_right: right Dataset image containing :

            - im: 2D (row, col) or 3D (band_im, row, col) xarray.DataArray float32
            - disparity (optional): 3D (disp, row, col) xarray.DataArray float32
            - msk (optional): 2D (row, col) xarray.DataArray int16
            - classif (optional): 3D (band_classif, row, col) xarray.DataArray int16
            - segm (optional): 2D (row, col) xarray.DataArray int16
            - edges (optional): 2D (row, col) xarray.DataArray int16
    :type img_right: xarray.Dataset
    :param bit_1: where the disparity interval is missing in the right image ( disparity range outside the image );
        only ``bit_1[0]`` (column indices) is used
    :type bit_1: ndarray or Tuple
    :return: None
    """
    offset = cv.attrs["offset_row_col"]
    _, right_msk = xr.align(cv["validity_mask"], img_right["msk"])  # pylint: disable=unbalanced-tuple-unpacking
    d_min, d_max = cv.coords["disp"].data[[0, -1]].astype(int)
    disparities = range(d_min, d_max + 1)

    # Dilatation: pixels that contain no_data in their aggregation window become no_data
    no_data_dilated = binary_dilation_msk(img_right, cv.attrs["window_size"])
    # 1 where the right pixel is invalidated by the input mask (neither no_data nor valid), else 0
    in_input_mask = (
        xr.where(
            (right_msk != img_right.attrs["no_data_mask"]) & (right_msk != img_right.attrs["valid_pixels"]),
            1,
            0,
        )
        .data.astype(np.uint16)
    )

    shape = (cv.sizes["row"], cv.sizes["col"])
    # Per left pixel: number of disparities whose right position is masked by the input mask or out of range
    masked_count = np.zeros(shape, dtype=np.uint16)
    # Per left pixel: number of disparities whose right position is no_data or out of range
    no_data_count = np.zeros(shape, dtype=np.uint16)

    columns = np.arange(cv.sizes["col"])
    first_usable = columns[0] + offset
    last_usable = columns[-1] - offset
    for disp in disparities:
        # Diagonal in the cost volume: right column reached by each left column
        right_columns = columns + disp
        in_range = (right_columns >= first_usable) & (right_columns <= last_usable)
        # No_data and masked pixels do not raise the same flag, so they are counted separately;
        # out-of-range positions count against both.
        masked_count[:, in_range] += in_input_mask[:, right_columns[in_range]]
        masked_count[:, ~in_range] += 1
        no_data_count[:, in_range] += no_data_dilated[:, right_columns[in_range]]
        no_data_count[:, ~in_range] += 1

    # Exclusion of pixels that have flag 1 already enabled
    masked_count[:, bit_1[0]] = 0
    no_data_count[:, bit_1[0]] = 0

    n_disparities = len(disparities)
    # Invalid pixel: right positions invalidated by the mask of the right image given as input
    cv["validity_mask"].data[masked_count == n_disparities] |= np.uint16(
        Criteria.PANDORA_MSK_PIXEL_IN_VALIDITY_MASK_RIGHT
    )
    # Invalid pixel: the disparity interval is missing in the right image (disparity interval
    # is invalidated by no_data in the right image)
    cv["validity_mask"].data[no_data_count == n_disparities] |= np.uint16(
        Criteria.PANDORA_MSK_PIXEL_RIGHT_NODATA_OR_DISPARITY_RANGE_MISSING
    )
[docs]
def mask_invalid_variable_disparity_range(cv: xr.Dataset) -> None:
    """
    Flag the pixels whose whole disparity range is missing (in place).

    A pixel is flagged with PANDORA_MSK_PIXEL_RIGHT_NODATA_OR_DISPARITY_RANGE_MISSING when
    cost_volume(row, col, d) is NaN for every disparity d.

    :param cv: cost volume dataset with the data variables:

            - cost_volume 3D xarray.DataArray (row, col, disp)
            - confidence_measure (optional) 3D xarray.DataArray (row, col, indicator)
    :type cv: xarray.Dataset
    :return: None
    """
    # True where every cost along the disparity axis (axis 2) is NaN
    whole_range_missing = np.isnan(cv["cost_volume"].data).min(axis=2)
    # OR is idempotent, so pixels already carrying the flag keep their value unchanged
    cv["validity_mask"].data[whole_range_missing] |= np.uint16(
        Criteria.PANDORA_MSK_PIXEL_RIGHT_NODATA_OR_DISPARITY_RANGE_MISSING
    )
[docs]
def mask_border(dataset: xr.Dataset) -> xr.DataArray:
    """
    Mask border pixel which haven't been calculated because of the window's size.

    The ``offset_row_col`` first and last rows and columns of the validity mask are overwritten (not OR-ed)
    with PANDORA_MSK_PIXEL_LEFT_NODATA_OR_BORDER. Nothing is done when ``offset_row_col`` is 0.

    :param dataset: dataset that can be :

            - the cost volume, the confidence measure and the validity_mask with the data variables :

                - cost_volume 3D xarray.DataArray (row, col, disp)
                - confidence_measure (optional) 3D xarray.DataArray (row, col, indicator)
                - validity_mask 2D xarray.DataArray (row, col)

            - the disparity_map, the confidence measure and the validity mask with the data variables :

                - disparity_map 2D xarray.DataArray (row, col)
                - confidence_measure (optional) 3D xarray.DataArray (row, col, indicator)
                - validity_mask 2D xarray.DataArray (row, col)
    :type dataset: xarray.Dataset
    :return: DataArray with the updated validity_mask
    :rtype: xarray.DataArray
    """
    offset = dataset.attrs["offset_row_col"]

    # Without this guard, offset == 0 turns [-offset:] into [0:], which selects the whole axis and
    # would flag every pixel as border.
    if offset > 0:
        # Border pixels have invalid disparity, erase the potential previous values
        left_nodata_or_border_uint = np.uint16(Criteria.PANDORA_MSK_PIXEL_LEFT_NODATA_OR_BORDER)
        mask_data = dataset["validity_mask"].data
        mask_data[:offset, :] = left_nodata_or_border_uint
        mask_data[-offset:, :] = left_nodata_or_border_uint
        mask_data[offset:-offset, :offset] = left_nodata_or_border_uint
        mask_data[offset:-offset, -offset:] = left_nodata_or_border_uint

    return dataset["validity_mask"]