"""Functions to validate inputs and parameters in skchange."""
from contextlib import contextmanager
from contextvars import ContextVar
from typing import TYPE_CHECKING
import numpy as np
from sklearn.base import BaseEstimator
from sklearn.utils.validation import check_array
from sklearn.utils.validation import check_is_fitted as _sklearn_check_is_fitted
from sklearn.utils.validation import validate_data as _sklearn_validate_data
from skchange.new_api.types import ArrayLike
if TYPE_CHECKING:
from skchange.new_api.interval_scorers._base import BaseIntervalScorer
# Internal flag toggled by ``skip_validation()``. While it is set:
#
# * ``validate_data`` skips sklearn's input checks; it only converts ``X`` with
#   ``np.asarray``, reshapes 1-D input to a single column (unless
#   ``ensure_2d`` is falsy) and still sets ``n_samples_in_`` /
#   ``n_features_in_`` on reset.
# * ``check_interval_specs`` only casts its input to an ``np.intp`` array.
# * ``check_is_fitted`` returns immediately.
#
# It is the caller's responsibility to guarantee that inputs are already valid;
# misuse will produce silent errors rather than informative validation messages.
# A ``ContextVar`` is used (rather than a plain module global) so that the flag
# is restored through the token returned by ``set``.
_skip_validation: ContextVar[bool] = ContextVar(
    "skchange_skip_validation", default=False
)
@contextmanager
def skip_validation():
    """Temporarily switch off input validation in this module.

    While the ``with`` block is active, the module-level ``_skip_validation``
    flag is ``True``. ``validate_data`` and ``check_interval_specs`` then skip
    ``check_array`` and the related sklearn checks (dtype/finite/2d/...), but
    still coerce inputs to numpy arrays and set the fitted attributes
    (``n_samples_in_``, ``n_features_in_``) on reset, so estimator state stays
    consistent.

    Intended for callers that can guarantee their inputs were validated
    already -- typically an inner loop that re-evaluates the same checked data
    many times.

    Yields
    ------
    None
        Control is handed to the body of the ``with`` statement.
    """
    restore_token = _skip_validation.set(True)
    try:
        yield
    finally:
        # Restore whatever value was active before entering, also when the
        # body raised, so nested or failing uses do not leak the flag.
        _skip_validation.reset(restore_token)
[docs]
def validate_data(
    _estimator: BaseEstimator,
    /,
    X: ArrayLike,
    **kwargs,
) -> np.ndarray:
    """Validate ``X`` and record its dimensions on the estimator.

    Wraps sklearn's ``validate_data`` and, when ``reset`` is truthy (the
    default, i.e. during fit), additionally stores the number of samples as
    ``_estimator.n_samples_in_``, which is required for default penalty
    computation.

    Inside a :func:`skip_validation` context sklearn is not called at all.
    ``X`` is converted with ``np.asarray``, 1-D input is reshaped to a single
    column unless ``ensure_2d`` is passed as a falsy value, and
    ``n_features_in_`` / ``n_samples_in_`` are still set when ``reset`` is
    truthy.

    Parameters
    ----------
    _estimator : BaseEstimator
        The estimator being fitted or applied. Modified in-place.
    X : array-like of shape (n_samples, n_features)
        Data to validate.
    **kwargs
        Forwarded to ``sklearn.utils.validation.validate_data``. Only
        ``reset`` and ``ensure_2d`` are consulted on the skip path.

    Returns
    -------
    X : ndarray of shape (n_samples, n_features)
        Validated (or, on the skip path, merely coerced) array.
    """
    is_reset = kwargs.get("reset", True)

    if not _skip_validation.get():
        # Regular path: sklearn validates and sets ``n_features_in_`` itself.
        validated = _sklearn_validate_data(_estimator, X, **kwargs)
        if is_reset:
            _estimator.n_samples_in_ = validated.shape[0]
        return validated

    # Skip path: no checks, only the minimal coercion downstream code needs.
    coerced = np.asarray(X)
    if coerced.ndim == 1 and kwargs.get("ensure_2d", True):
        coerced = coerced.reshape(-1, 1)
    if is_reset:
        _estimator.n_features_in_ = 1 if coerced.ndim < 2 else coerced.shape[1]
        _estimator.n_samples_in_ = coerced.shape[0]
    return coerced
def check_is_fitted(estimator, attributes=None, *, msg=None, all_or_any=all) -> None:
    """Check that ``estimator`` is fitted, unless validation is being skipped.

    Same signature as sklearn's ``check_is_fitted``. Inside a
    :func:`skip_validation` context this is a no-op (the caller guarantees
    the estimator is fitted); otherwise every argument is passed on to
    sklearn's implementation unchanged.

    Parameters
    ----------
    estimator : object
        Estimator instance to check.
    attributes : optional, default=None
        Forwarded to sklearn's ``check_is_fitted``.
    msg : optional, default=None
        Forwarded to sklearn's ``check_is_fitted``.
    all_or_any : callable, default=all
        Forwarded to sklearn's ``check_is_fitted``.
    """
    validation_enabled = not _skip_validation.get()
    if validation_enabled:
        _sklearn_check_is_fitted(
            estimator, attributes, msg=msg, all_or_any=all_or_any
        )
def check_time_col(
    X: np.ndarray,
    time_col: int,
    caller_name: str,
) -> None:
    """Validate the timestamp column of ``X`` for linear trend estimators.

    .. experimental::
        This helper supports the experimental ``time_col`` feature. Its
        interface may change as the feature stabilises.

    Checks, in this order, that ``time_col`` is in range, that at least one
    value column remains, and that the timestamp column contains finite,
    strictly monotonically increasing values.

    Parameters
    ----------
    X : ndarray of shape (n_samples, n_features)
        Already-validated data array.
    time_col : int
        Index of the timestamp column. Negative indices are rejected.
    caller_name : str
        Name of the calling estimator, used in error messages.

    Raises
    ------
    ValueError
        If any of the above conditions are violated.
    """
    n_features = X.shape[1]
    if not (0 <= time_col < n_features):
        raise ValueError(
            f"{caller_name}: time_col={time_col} is out of range for data "
            f"with {n_features} columns."
        )
    if n_features < 2:
        raise ValueError(
            f"{caller_name} with time_col={time_col} requires at least 2 "
            f"features (1 timestamp + 1 value column), but got "
            f"n_features={n_features}."
        )

    time_stamps = X[:, time_col]
    if not np.all(np.isfinite(time_stamps)):
        raise ValueError(
            f"{caller_name}: time_col contains non-finite values. "
            "Timestamps must be finite numbers."
        )
    # Compare neighbours directly instead of testing ``np.diff(...) > 0``:
    # for unsigned integer dtypes the difference wraps around, so a decreasing
    # pair such as (3, 1) would yield a large positive value and slip through.
    if not np.all(time_stamps[1:] > time_stamps[:-1]):
        raise ValueError(
            f"{caller_name}: time_col must be strictly monotonically increasing."
        )
[docs]
def check_interval_specs(
interval_specs: ArrayLike,
n_cols: int,
*,
n_samples: int | None = None,
check_sorted: bool = False,
min_size: int | None = None,
caller_name: str | None = None,
arg_name: str = "interval_specs",
) -> np.ndarray:
"""Validate an interval_specs array.
Always checks that the input is a 2D integer array with exactly ``n_cols``
columns. Heavier checks are opt-in.
Parameters
----------
interval_specs : array-like of shape (n_interval_specs, n_cols)
Interval specifications to validate.
n_cols : int
Required number of columns.
n_samples : int or None, default=None
If given, checks that all entries are in ``[0, n_samples]``.
check_sorted : bool, default=False
If ``True``, checks that each row is strictly increasing, i.e.
``interval_specs[i, 0] < interval_specs[i, 1] < ...`` for every row.
min_size : int or None, default=None
If given, checks that adjacent entries in each row differ by at least
``min_size``, i.e. ``interval_specs[i, j+1] - interval_specs[i, j] >= min_size``
for every row and column pair. Implies strict ordering when ``min_size >= 1``.
caller_name : str or None, default=None
Name of the calling function or class. Used in error messages.
arg_name : str, default="interval_specs"
Name of the argument being validated. Used in error messages.
Returns
-------
interval_specs : ndarray of shape (n_interval_specs, n_cols)
Validated array with dtype ``np.intp``.
Raises
------
ValueError
If any check fails.
"""
interval_specs = np.asarray(interval_specs, dtype=np.intp)
# Empty inputs may occur. E.g. in MovingWindow for large window sizes and
# short series.
if interval_specs.size == 0:
return interval_specs
if _skip_validation.get():
# Caller guarantees the array is well-formed (2D, correct shape, dtype,
# ordered, in-range). Skip all sklearn-level checks.
return interval_specs
interval_specs = check_array(interval_specs, ensure_2d=True, dtype=np.intp)
if interval_specs.shape[1] != n_cols:
raise ValueError(
f"`{arg_name}` must have {n_cols} columns, "
f"got {interval_specs.shape[1]} in {caller_name}."
)
if interval_specs.size > 0 and (check_sorted or min_size is not None):
diffs = np.diff(interval_specs, axis=1)
if check_sorted and min_size is None and np.any(diffs <= 0):
raise ValueError(
f"Each row in `{arg_name}` must be strictly increasing "
f"(i.e. {arg_name}[i, 0] < {arg_name}[i, 1] < ...) in {caller_name}."
)
if min_size is not None and np.any(diffs < min_size):
raise ValueError(
f"Adjacent entries in each row of `{arg_name}` must differ by at "
f"least {min_size} "
f"(i.e. {arg_name}[i, j+1] - {arg_name}[i, j] >= {min_size}) "
f"in {caller_name}."
)
if n_samples is not None and interval_specs.size > 0:
out_of_range = interval_specs[
(interval_specs < 0) | (interval_specs > n_samples)
]
if out_of_range.size > 0:
raise ValueError(
f"`{arg_name}` entries must be in [0, {n_samples}], "
f"got e.g. {out_of_range[0]} in {caller_name}."
)
return interval_specs
[docs]
def check_interval_scorer(
scorer: "BaseIntervalScorer",
*,
ensure_score_type: list | None = None,
ensure_penalised: bool = False,
allow_penalised: bool = True,
caller_name: str | None = None,
arg_name: str = "",
) -> None:
"""Check if the given scorer is a valid interval scorer.
Parameters
----------
scorer : BaseIntervalScorer
The scorer to check.
ensure_score_type : list of str or None, default=None
If specified, the scorer's score_type tag must be one of these.
ensure_penalised : bool, default=False
If True, raises an error if the scorer is not penalised.
allow_penalised : bool, default=True
If False, raises an error if the scorer is penalised.
caller_name : str or None, default=None
Name of the caller for error messages.
arg_name : str, default=""
Name of the argument for error messages.
Raises
------
ValueError
If any of the checks fail.
"""
score_type = scorer.__sklearn_tags__().interval_scorer_tags.score_type
if ensure_score_type and score_type not in ensure_score_type:
_required_tasks = [f'"{task}"' for task in ensure_score_type]
tasks_str = (
", ".join(_required_tasks[:-1]) + " or " + _required_tasks[-1]
if len(_required_tasks) > 1
else _required_tasks[0]
)
raise ValueError(
f"{caller_name} requires `{arg_name}` to have score_type {tasks_str}"
f" ({arg_name}.__sklearn_tags__().interval_scorer_tags.score_type "
f"in {ensure_score_type}). "
f'Got {scorer.__class__.__name__}, which has score_type "{score_type}".'
)
if (
ensure_penalised
and not scorer.__sklearn_tags__().interval_scorer_tags.penalised
):
raise ValueError(
f"{caller_name} requires `{arg_name}` to be a penalised scorer "
f"({arg_name}.__sklearn_tags__().interval_scorer_tags.penalised == True). "
f"Got {scorer.__class__.__name__}, which is not penalised."
)
if not allow_penalised and scorer.__sklearn_tags__().interval_scorer_tags.penalised:
raise ValueError(
f"{caller_name} requires `{arg_name}` to be an unpenalised scorer "
f"({arg_name}.__sklearn_tags__().interval_scorer_tags.penalised == False). "
f"Got {scorer.__class__.__name__}, which is penalised."
)
[docs]
def check_penalty(
    penalty: float | ArrayLike,
    *,
    ensure_non_decreasing: bool = True,
    copy: bool = True,
    caller_name: str | None = None,
    arg_name: str = "penalty",
) -> float | np.ndarray:
    """Validate a scalar or array-valued penalty.

    The input is squeezed, passed through sklearn's ``check_array`` (float64,
    finite values only) and then required to be one-dimensional and
    non-negative.

    Parameters
    ----------
    penalty : float | ArrayLike
        The penalty to check.
    ensure_non_decreasing : bool, default=True
        If True, the penalty must be non-decreasing.
    copy : bool, default=True
        Whether to copy the penalty array. Ignored if penalty is a scalar.
    caller_name : str | None, default=None
        The name of the caller. Used for error messages.
    arg_name : str, default="penalty"
        The name of the argument. Used for error messages.

    Returns
    -------
    penalty : float or ndarray
        A Python float when the penalty holds a single value, otherwise the
        validated 1D ``float64`` array.

    Raises
    ------
    ValueError
        If the penalty is not 1D after squeezing, contains negative values,
        or decreases while ``ensure_non_decreasing`` is True.
    """
    squeezed = np.asarray(penalty).squeeze()
    # A scalar squeezes to a 0-d array; lift it to one element so that it
    # goes through the same 1D checks as an array penalty.
    as_vector = squeezed.reshape(1) if squeezed.ndim == 0 else squeezed

    values = check_array(
        as_vector,
        ensure_2d=False,  # penalty should be 1D after squeezing.
        dtype=np.float64,
        copy=copy,
        ensure_all_finite=True,
    )

    if values.ndim != 1:
        raise ValueError(
            f"`{arg_name}` must be a 1D array in {caller_name}."
            f" Got {values.ndim}D array."
        )

    if not (values >= 0.0).all():
        raise ValueError(f"`{arg_name}` must be non-negative in {caller_name}")

    if ensure_non_decreasing and values.size > 1:
        if (np.diff(values) < 0).any():
            raise ValueError(f"`{arg_name}` must be non-decreasing in {caller_name}")

    return values if values.size != 1 else float(values[0])