import numpy as np
import pandas as pd
from typing import List, Iterable
from canopy.core.field import Field
import canopy.core.frameops as frameops
[docs]
def check_field_contains_layers(field: Field, layers: str | Iterable[str], name: str = 'field', raise_exception: bool = False) -> bool:
    """Report whether a field provides all of the requested layers.

    Parameters
    ----------
    field : Field
        Field whose ``layers`` collection is searched
    layers : str | Iterable[str]
        Name of a single required layer, or an iterable of required layer names
    name : str, default 'field'
        Label for the field, used in the error message
    raise_exception : bool, default False
        When True, a failed check raises instead of returning False

    Returns
    -------
    bool
        True if every requested layer is present, False otherwise

    Raises
    ------
    ValueError
        If ``raise_exception`` is True and at least one layer is missing
    """
    # A bare string is a single layer name, not an iterable of characters
    required = [layers] if isinstance(layers, str) else layers
    not_found = [layer for layer in required if layer not in field.layers]
    if not_found and raise_exception:
        raise ValueError(f"Layers {not_found} not found in {name}'s layers ({field.layers}).")
    return not not_found
[docs]
def check_spatial_coords_match(field1: Field, field2: Field, atol: float = 1.e-7, rtol: float = 0., raise_exception: bool = False) -> bool:
    """Check if spatial coordinates of two fields match up to given tolerance

    The spatial coordinates (gridlist) of a field are the unique entries of its data
    index once the 'time' level is dropped, kept in order of first appearance. The two
    gridlists match only if they have the same shape and are element-wise close, so the
    locations must also appear in the same order in both fields.

    Parameters
    ----------
    field1 : Field
        The first of the two fields whose coordinates to compare
    field2 : Field
        The second of the two fields whose coordinates to compare
    atol : float
        Absolute tolerance to apply in the comparison
    rtol : float
        Relative tolerance to apply in the comparison
    raise_exception : bool
        If True, an exception is raised if the check does not pass (default: False)

    Returns
    -------
    bool
        True if the gridlists match, False otherwise

    Raises
    ------
    ValueError
        If ``raise_exception`` is True and the gridlists do not match

    Notes
    -----
    Absolute and relative tolerances are defined as in Numpy, i.e., two numbers a and b are equivalent if the following
    equation is fulfilled:
    absolute(a - b) <= (atol + rtol * absolute(b))
    See https://numpy.org/doc/stable/reference/generated/numpy.isclose.html#numpy.isclose
    """
    gridlist1 = np.array(field1.data.index.droplevel('time').drop_duplicates().to_frame())
    gridlist2 = np.array(field2.data.index.droplevel('time').drop_duplicates().to_frame())
    # Compare shapes explicitly. np.allclose broadcasts its arguments, so a gridlist with
    # a single location (shape (1, k)) would otherwise be compared against every row of
    # an (n, k) gridlist and could be reported as a match.
    if gridlist1.shape != gridlist2.shape:
        gridlists_match = False
    else:
        gridlists_match = bool(np.allclose(gridlist1, gridlist2, atol=atol, rtol=rtol))
    if not gridlists_match and raise_exception:
        raise ValueError("Gridlists do not match.")
    return gridlists_match
[docs]
def check_indices_match(field1: Field, field2: Field):
    """Check if the indices of the DataFrames of two fields match

    This is a thin wrapper that forwards the ``data`` DataFrames of both fields to
    ``frameops.check_indices_match``. Any tolerance handling is whatever that function
    applies. This wrapper does not accept ``atol``/``rtol`` arguments.

    Parameters
    ----------
    field1 : Field
        The first of the two fields whose indices to compare
    field2 : Field
        The second of the two fields whose indices to compare

    Returns
    -------
    The value returned by ``frameops.check_indices_match`` (presumably a bool,
    like the other ``check_*`` helpers; TODO confirm against frameops).
    """
    # Propagate the result instead of discarding it, so callers can act on it
    return frameops.check_indices_match(field1.data, field2.data)
[docs]
def check_time_series_match(fields: List[Field], raise_exception: bool = False):
    """Check if fields have the same time frequency and span

    Every field is compared against the first one. The ``time_freq`` values must be
    equal, and the unique values of the 'time' index level must be identical (same
    length, same values, same order).

    Parameters
    ----------
    fields : List[Field]
        A list of fields
    raise_exception : bool
        If True, an exception is raised if the check does not pass (default: False)

    Returns
    -------
    bool
        True if time series match, False otherwise. A list with fewer than two
        fields trivially matches.

    Raises
    ------
    ValueError
        If ``raise_exception`` is True and the time series do not match
    """
    # Nothing to compare against. Without this guard, a single field would leave
    # `match` unbound and an empty list would fail on fields[0].
    if len(fields) < 2:
        return True
    ref_freq = fields[0].time_freq
    ref_times = fields[0].data.index.get_level_values('time').unique()
    match = True
    msg = ""
    for field in fields[1:]:
        if field.time_freq != ref_freq:
            match = False
            msg = "Time series have different frequencies."
            break
        times = field.data.index.get_level_values('time').unique()
        # Check the length first so the element-wise comparison below only ever
        # sees indexes of equal length
        if len(times) != len(ref_times):
            match = False
            msg = "The time series of the supplied fields are not all equal."
            break
        if not bool((ref_times == times).all()):
            match = False
            msg = "Time series do not match."
            break
    if not match and raise_exception:
        raise ValueError(msg)
    return match
[docs]
def check_time_series_consistency(field: Field, raise_exception: bool = False) -> bool:
    """Verify that every spatial location of a field shares one common time series.

    The 'time' index level is moved into a column and pivoted into a table with one
    row per remaining (spatial) index entry and one column per time value. A missing
    entry in that table means some location lacks a time step that another location has.

    Parameters
    ----------
    field : Field
        The field whose time series to check
    raise_exception : bool
        If True, an exception is raised if the check does not pass (default: False)

    Returns
    -------
    bool
        True if the pivoted table has no missing entries, False otherwise

    Raises
    ------
    ValueError
        If ``raise_exception`` is True and the check does not pass
    """
    time_table = (
        field.data
        .reset_index('time')[['time']]
        .pivot(columns='time', values='time')
    )
    consistent = not time_table.isna().values.any()
    if not consistent and raise_exception:
        raise ValueError("The time series are not equal.")
    return consistent
[docs]
def check_grids_have_same_axes(fields: List[Field], raise_exception: bool = False) -> bool:
    """Check if the grids of several fields have the same axes

    Every field's grid is compared against the grid of the first field. The x and y
    axes must compare equal, and each axis (by the names in ``grid.axis_names``) must be
    reduced in both grids or in neither.

    Parameters
    ----------
    fields : List[Field]
        Fields whose grids to compare
    raise_exception : bool
        If True, an exception is raised if the check does not pass (default: False)

    Returns
    -------
    bool
        True if all grids have the same axes, False otherwise. Lists with fewer than
        two fields trivially pass.

    Raises
    ------
    ValueError
        If ``raise_exception`` is True and the grids differ
    """
    check = True
    if fields:
        ref_grid = fields[0].grid
        for field in fields[1:]:
            grid = field.grid
            if grid.xaxis != ref_grid.xaxis or grid.yaxis != ref_grid.yaxis:
                check = False
                break
            x, y = grid.axis_names
            if grid.is_reduced(x) != ref_grid.is_reduced(x) \
                    or grid.is_reduced(y) != ref_grid.is_reduced(y):
                check = False
                break
    if raise_exception and not check:
        raise ValueError("Grids do not have the same coordinate axes")
    return check
[docs]
def check_disjoint_coords(fields: List[Field], raise_exception: bool = False) -> bool:
    """Check if no coordinates overlap between fields

    The spatial coordinates of a field are the unique entries of its data index once the
    'time' level is dropped. They are compared for exact equality (no tolerance). Every
    pair of fields is checked, not just each field against the first one.

    Parameters
    ----------
    fields : List[Field]
        List of fields whose coordinates to check
    raise_exception : bool
        If True, an exception is raised if the check fails

    Returns
    -------
    bool
        True if no coordinates overlap, False otherwise

    Raises
    ------
    ValueError
        Raised by ``check_grids_have_same_axes`` if the grids' axes differ (regardless
        of ``raise_exception``), or if ``raise_exception`` is True and coordinates overlap
    """
    # Grids must have the same axes (names, units, etc)
    check_grids_have_same_axes(fields, raise_exception=True)
    check = True
    # Union of the coordinates of all fields visited so far, so that every pair of
    # fields is compared
    seen = set()
    for field in fields:
        coords = set(field.data.index.droplevel('time').unique())
        if not seen.isdisjoint(coords):
            check = False
            break
        seen |= coords
    if raise_exception and not check:
        raise ValueError("Fields do not have disjoint sets of coordinates")
    return check