Source code for pyrtid.utils.means

"""Provide some classic means."""

# pylint: disable=C0103 # doesn't conform to snake_case naming style
from typing import Optional

import numpy as np
from scipy.stats import gmean, hmean

from pyrtid.utils.enum import StrEnum
from pyrtid.utils.types import NDArrayFloat


def arithmetic_mean(xi: NDArrayFloat, xj) -> NDArrayFloat:
    """
    Return the element-wise arithmetic mean of ``xi`` and ``xj``.

    Accepts scalars or arrays; for arrays the usual NumPy operator broadcasting
    applies.
    """
    pair_sum = xi + xj
    return pair_sum / 2.0
def dxi_arithmetic_mean(xi: NDArrayFloat, xj: NDArrayFloat) -> NDArrayFloat:
    """
    Return the first derivative of xi and xj arithmetic mean with respect to xi.

    ``(xi + xj) / 2`` is linear in ``xi``, so the derivative is the constant
    0.5. It is returned with the shape of ``xi`` so that it can be used
    element-wise with vectors.

    Parameters
    ----------
    xi : NDArrayFloat
        First operand (scalar, sequence or array); only its shape and dtype
        are used.
    xj : NDArrayFloat
        Second operand. Unused; kept so the signature matches the other
        mean-derivative functions.

    Returns
    -------
    NDArrayFloat
        Array filled with 0.5, shaped like ``xi`` (0-d for a scalar ``xi``).
    """
    # pylint: disable=W0613 # unused argument
    xi_arr = np.asarray(xi)
    # Build the constant directly rather than ``0.5 + xi * 0.0``: that trick
    # produces NaN wherever xi is +/-inf (inf * 0.0 == nan) and fails on lists.
    return np.full(xi_arr.shape, 0.5, dtype=np.result_type(xi_arr, 0.5))
def harmonic_mean(xi: NDArrayFloat, xj) -> NDArrayFloat:
    """
    Return the element-wise harmonic mean of ``xi`` and ``xj``.

    Evaluated as ``2 / (1 / xi + 1 / xj)``.
    """
    inverse_sum = 1.0 / xi + 1.0 / xj
    return 2.0 / inverse_sum
def dxi_harmonic_mean(xi: NDArrayFloat, xj: NDArrayFloat) -> NDArrayFloat:
    """
    Return the first derivative of xi and xj harmonic mean with respect to xi.

    With ``H = 2 / (1/xi + 1/xj) = 2 * xi * xj / (xi + xj)``, the partial
    derivative is ``dH/dxi = 2 * xj**2 / (xi + xj)**2``.

    Parameters
    ----------
    xi : NDArrayFloat
        Value(s) with respect to which the derivative is taken.
    xj : NDArrayFloat
        Other operand of the harmonic mean.

    Returns
    -------
    NDArrayFloat
        Element-wise derivative (operands broadcast as usual in NumPy).
    """
    return 2.0 * xj**2.0 / (xi + xj) ** 2.0
class MeanType(StrEnum):
    """Kind of mean used to average values (see ``get_mean_values_for_last_axis``)."""

    HARMONIC = "harmonic"
    ARITHMETIC = "arithmetic"
    GEOMETRIC = "geometric"
def get_mean_values_for_last_axis(
    arr: NDArrayFloat, mean_type: MeanType, weights: Optional[NDArrayFloat] = None
) -> NDArrayFloat:
    """
    Get the mean values for the last axis of the input array.

    All leading axes are flattened (C order) into a single axis of "grid
    cells", and the requested mean is taken over that axis independently for
    each index of the last axis.

    Parameters
    ----------
    arr : NDArrayFloat
        Array of values with shape (nx, ny, nt) or (npts, nt). A 1-D array of
        shape (npts,) is treated as (npts, 1).
    mean_type: MeanType
        Type of mean chosen to average the simulated value when the observed
        one is defined over several grid cells of the domain.
    weights: Optional[NDArrayFloat]
        Weights to apply, one per grid cell. Any shape holding that many
        elements is accepted (e.g. (nx, ny), or a plain sequence); it is
        flattened in C order so that it lines up with the flattened ``arr``.
        None means unweighted.

    Returns
    -------
    NDArrayFloat
        Averaged values, one per index of the last axis: shape (nt,), or (1,)
        for a 1-D input.

    Raises
    ------
    ValueError
        If the number of weights does not match the number of grid cells.
    """
    _arr = np.asarray(arr)
    if _arr.ndim == 1:
        # A single series of points: add a trailing axis of length 1.
        _arr = _arr[:, np.newaxis]
    else:
        # Collapse all leading axes into one "grid cell" axis.
        _arr = _arr.reshape(-1, _arr.shape[-1])

    if weights is not None:
        # Flatten so that e.g. (nx, ny) weights match the nx * ny rows of
        # _arr (same C order as the reshape above); also accepts lists.
        weights = np.ravel(weights)
        if _arr.shape[0] != weights.size:
            raise ValueError(
                "The number of weights must match the number of grid cells."
            )

    mean_func = {
        MeanType.ARITHMETIC: np.average,
        MeanType.GEOMETRIC: gmean,
        MeanType.HARMONIC: hmean,
    }[mean_type]
    # Average each column, i.e. one mean per index of the last axis.
    return np.apply_along_axis(mean_func, axis=0, arr=_arr, weights=weights)
def amean_gradient(
    values: NDArrayFloat, weights: Optional[NDArrayFloat] = None
) -> NDArrayFloat:
    """
    Return the gradient of the (weighted) arithmetic mean w.r.t. ``values``.

    For ``m = sum(w * x) / sum(w)`` the partial derivative is
    ``w_i / sum(w)``; without weights every entry gets ``1 / N``.
    """
    if weights is None:
        # Unweighted: each value contributes equally to the mean.
        n_values = values.size
        return np.ones(values.shape) / n_values
    total_weight = np.sum(weights)
    return weights / total_weight
def hmean_gradient(
    values: NDArrayFloat, weights: Optional[NDArrayFloat] = None
) -> NDArrayFloat:
    """
    Return the gradient of the (weighted) harmonic mean w.r.t. ``values``.

    With ``H = sum(w) / sum(w / x)``, the partial derivative is

        dH/dx_i = w_i * sum(w) / (x_i * sum(w / x)) ** 2

    which reduces to ``N / (x_i * sum(1 / x)) ** 2`` without weights
    (``N = values.size``).

    Parameters
    ----------
    values : NDArrayFloat
        Values whose harmonic mean is differentiated.
    weights : Optional[NDArrayFloat]
        Optional weights, one per value. None means unweighted.

    Returns
    -------
    NDArrayFloat
        Partial derivatives, one per value.
    """
    if weights is None:
        return values.size / (np.square(values * np.sum(1.0 / values)))
    return weights * np.sum(weights) / np.square(values * np.sum(weights / values))
def gmean_gradient(
    values: NDArrayFloat, weights: Optional[NDArrayFloat] = None
) -> NDArrayFloat:
    """
    Return the gradient of the (weighted) geometric mean w.r.t. ``values``.

    With ``G = exp(sum(w * log(x)) / sum(w))``, the partial derivative is
    ``dG/dx_i = w_i * G / (x_i * sum(w))``, which reduces to
    ``G / (N * x_i)`` without weights (``N = values.size``).

    Parameters
    ----------
    values : NDArrayFloat
        Values whose geometric mean is differentiated.
    weights : Optional[NDArrayFloat]
        Optional weights, one per value. None means unweighted.

    Returns
    -------
    NDArrayFloat
        Partial derivatives, one per value.
    """
    if weights is None:
        # Use scipy's log-based gmean over all entries instead of
        # ``np.prod(values) ** (1 / N)``: the raw product overflows to inf (or
        # underflows to 0) as soon as N is moderately large.
        return gmean(values, axis=None) / (values.size * values)
    return weights / (values * np.sum(weights)) * gmean(values, weights=weights)
def get_mean_values_gradient_for_last_axis(
    arr: NDArrayFloat, mean_type: MeanType, weights: Optional[NDArrayFloat] = None
) -> NDArrayFloat:
    """
    Get the gradient of the last-axis mean values with respect to ``arr``.

    Leading axes are flattened (C order) into "grid cells". For each index t
    of the last axis, the gradient of the mean of column t with respect to
    each cell value of that column is computed. A column's mean only depends
    on that column, so cross terms are zero and are not returned.

    Parameters
    ----------
    arr : NDArrayFloat
        Array of values with shape (nx, ny, nt) or (npts, nt). A 1-D array of
        shape (npts,) is treated as (npts, 1).
    mean_type: MeanType
        Type of mean chosen to average the simulated value when the observed
        one is defined over several grid cells of the domain.
    weights: Optional[NDArrayFloat]
        Weights to apply, one per grid cell. Any shape holding that many
        elements is accepted (e.g. (nx, ny), or a plain sequence); it is
        flattened in C order so that it lines up with the flattened ``arr``.
        None means unweighted.

    Returns
    -------
    NDArrayFloat
        Gradient with the same shape as ``arr``: entry ``[..., t]`` is the
        derivative of the t-th mean value with respect to ``arr[..., t]``.

    Raises
    ------
    ValueError
        If the number of weights does not match the number of grid cells.
    """
    _arr = np.asarray(arr)
    if _arr.ndim == 1:
        # A single series of points: add a trailing axis of length 1.
        arr_2d = _arr[:, np.newaxis]
    else:
        # Collapse all leading axes into one "grid cell" axis.
        arr_2d = _arr.reshape(-1, _arr.shape[-1])

    if weights is not None:
        # Flatten so that e.g. (nx, ny) weights match the nx * ny rows of
        # arr_2d (same C order as the reshape above); also accepts lists.
        weights = np.ravel(weights)
        if arr_2d.shape[0] != weights.size:
            raise ValueError(
                "The number of weights must match the number of grid cells."
            )

    gradient_func = {
        MeanType.ARITHMETIC: amean_gradient,
        MeanType.GEOMETRIC: gmean_gradient,
        MeanType.HARMONIC: hmean_gradient,
    }[mean_type]
    # One gradient column per index of the last axis, then restore the
    # caller's original layout.
    return np.apply_along_axis(
        gradient_func, axis=0, arr=arr_2d, weights=weights
    ).reshape(_arr.shape)