Module robofish.io.validation

Expand source code
from robofish.io.file import File
from robofish.io.entity import Entity
import re
import h5py
import numpy as np
import logging
import warnings


def assert_validate(
    statement: bool, message: str, location: str = None, strict_validate=True
) -> None:
    """Check a validation condition and report a failure.

    Nothing happens when the statement is truthy. Otherwise the message,
    extended by the location if one is given, is either raised as an
    AssertionError or only logged as a warning.

    Args:
        statement: The condition which is expected to hold.
        message: The message describing the failed validation.
        location: optional location of the validation (e.g. entity 'fish_1', sampling '25hz')
        strict_validate: Raise an AssertionError on failure if True, otherwise
            only log the message as a warning.
    Throws:
        AssertionError: If the statement is false and strict_validate is True
    """
    if statement:
        return

    # Only a non-empty location is attached to the message.
    full_message = "%s in %s" % (message, location) if location else message

    if not strict_validate:
        logging.warning(full_message)
        return
    raise AssertionError(full_message)


def assert_validate_type(
    object, expected_type, object_name: str, location: str = None
) -> None:
    """Assert that an object has the expected type or dtype.

    For objects with a ``dtype`` attribute (e.g. numpy arrays and scalars) the
    dtype is compared, for all other objects the exact type.

    Args:
        object: The object whose type is checked.
        expected_type: The type (or dtype) the object should have.
        object_name: The name of the object, used in the error message.
        location: optional location of the validation, used in the error message.
    Throws:
        AssertionError: If the type does not equal the expected type
    """
    actual_type = object.dtype if hasattr(object, "dtype") else type(object)

    if not (actual_type == expected_type):
        # The location is only mentioned if one was given.
        where = "" if location is None else " in %s" % location
        raise AssertionError(
            ("The type of %s was wrong" % object_name)
            + where
            + (". The type was %s, when %s was expected." % (actual_type, expected_type))
        )


def validate(iofile: File, strict_validate: bool = True) -> (bool, str):
    """Validate a file to the specification.

    The function compares a given file to the robofish track format specification:
    https://git.imp.fu-berlin.de/bioroboticslab/robofish/track_format
    First the root attributes are checked, then the samplings (types, shapes,
    monotonic time points and ISO8601 calendar points) and lastly the entities
    (category, positions, orientations and outlines).

    Args:
        iofile: The file which should be validated.
        strict_validate: Throw an exception instead of just returning false.
    Returns:
        A tuple of validity and a message. For an invalid file (and
        strict_validate False) the message is the error message, for a valid
        file it states the common shape of positions / orientations.
    Throws:
        AssertionError: When the file is invalid and strict_validate is True
        Exception: When an entity still contains the deprecated poses dataset
            (independent of strict_validate)
    """

    # NOTE(review): The deprecated poses dataset is never read, so poses stays
    # None and the outline length check below currently always passes.
    poses = None
    common_poses_shape = None
    try:
        # validate attributes
        expected_dtypes = {
            "format_version": np.int32,
            "format_url": str,
            "world_size_cm": np.float32,
        }
        for a, a_type in expected_dtypes.items():
            assert_validate(a in iofile.attrs, f'Attribute "{a}" missing', "root")
            assert_validate_type(iofile.attrs[a], a_type, a, "root")

        # validate samplings
        assert_validate("samplings" in iofile, "samplings not found")
        for s_name, sampling in iofile["samplings"].items():
            s_location = f"sampling {s_name}"

            assert_validate(
                "frequency_hz" in sampling.attrs
                or "monotonic_time_points_us" in sampling,
                "Neither frequency nor monotonic_time_points_us was defined",
                s_location,
            )

            # NOTE(review): The calendar points are read from
            # "calendar_time_points" below, so the "monotonic_calendar_points"
            # entry is presumably never found in a file. Confirm and remove.
            expected_dtypes = {
                "frequency_hz": np.float32,
                "monotonic_time_points_us": np.int64,
                "monotonic_calendar_points": "S",
            }

            # The values can be stored as attribute or as dataset.
            for a, a_type in expected_dtypes.items():
                if a in sampling.attrs:
                    assert_validate_type(sampling.attrs[a], a_type, a, s_location)
                if a in sampling:
                    assert_validate_type(sampling[a], a_type, a, s_location)

            if "monotonic_time_points_us" in sampling:
                time_points = sampling["monotonic_time_points_us"]
                # 1 dimensional array
                assert_validate(
                    time_points.ndim == 1,
                    "Dimensionality of monotonic_time_points_us should be 1",
                    s_location,
                )

                assert_validate(
                    np.all(np.diff(time_points) >= 0),
                    "monotonic_time_points_us is not monotonic",
                    s_location,
                )

            # calendar points
            if "calendar_time_points" in sampling:
                calendar_points = sampling["calendar_time_points"]
                assert_validate(
                    calendar_points.ndim == 1,
                    "Dimensionality of calendar_time_points should be 1",
                    s_location,
                )
                if "monotonic_time_points_us" in sampling:
                    assert_validate(
                        calendar_points.shape[0] == time_points.shape[0],
                        "The length of calendar points (%d) does not match the length of monotonic points (%d)"
                        % (calendar_points.shape[0], time_points.shape[0]),
                        s_location,
                    )

                # validate iso8601, this validates the dtype implicitly
                for c in calendar_points.asstr(encoding="utf-8"):
                    assert_validate(
                        validate_iso8601(c),
                        "%s does not match iso8601" % c,
                        s_location,
                    )

        # validate entities
        assert_validate("entities" in iofile, "entities not found")
        for entity in iofile.entities:

            e_name = entity.name

            assert_validate(
                isinstance(entity, Entity),
                "Entity group was not a robofish.io.Entity object",
                e_name,
            )
            assert_validate(
                "category" in entity.attrs
                and isinstance(entity.attrs["category"], str),
                'Attribute "category" not found',
                e_name,
            )

            expected_dtypes = {"poses": np.float32, "outlines": np.float32}

            for a, a_type in expected_dtypes.items():
                if a in entity:
                    # The location (entity name) is appended by assert_validate.
                    assert_validate(
                        entity[a].dtype == a_type,
                        f'The type of dataset "{a}" should be "{a_type}" but was "{entity[a].dtype}"',
                        e_name,
                    )
            if "poses" in entity:
                # NOTE(review): This is not an AssertionError, so it is raised
                # even if strict_validate is False. Confirm this is intended.
                raise Exception(
                    "The poses dataset is deprecated. Please use positions and orientations."
                )
            if "positions" in entity:
                assert_validate(
                    isinstance(entity["positions"], h5py.Dataset),
                    'Dataset "positions" not found',
                    e_name,
                )

                positions = entity["positions"]
                assert_validate(
                    positions.ndim == 2,
                    "Dimensionality of positions should be 2",
                    e_name,
                )

                assert_validate(
                    positions.shape[1] == 2,
                    "The second dimension of positions should have the length 2",
                    e_name,
                )

                if positions.shape[0] > 0:
                    # validate range of positions
                    validate_positions_range(
                        iofile.attrs["world_size_cm"], positions, e_name
                    )

                # Differing shapes are only reported as a warning. The common
                # shape always follows the most recently checked positions.
                if (
                    common_poses_shape is not None
                    and positions.shape != common_poses_shape
                ):
                    warnings.warn(
                        f"The shape of positions for {e_name} was different than the common shape {common_poses_shape}."
                    )
                common_poses_shape = positions.shape

            if "orientations" in entity:
                assert_validate(
                    "positions" in entity,
                    "orientations cannot exist without positions",
                    e_name,
                )

                assert_validate(
                    isinstance(entity["orientations"], h5py.Dataset),
                    'Dataset "orientations" not found',
                    e_name,
                )

                orientations = entity["orientations"]
                assert_validate(
                    orientations.ndim == 2,
                    "Dimensionality of orientations should be 2",
                    e_name,
                )

                assert_validate(
                    orientations.shape[1] == 2,
                    "The second dimension of orientations should have the length 2",
                    e_name,
                )

                if (
                    common_poses_shape is not None
                    and orientations.shape != common_poses_shape
                ):
                    warnings.warn(
                        f"The shape of orientations for {e_name} was different than the common shape {common_poses_shape}."
                    )

                if strict_validate:
                    validate_orientations_length(orientations, e_name)

            # outlines
            if "outlines" in entity:
                outlines = entity["outlines"]

                assert_validate(
                    outlines.ndim == 3, "Dimensionality of outlines should be 3", e_name
                )

                # Either fixed outline or same length with poses
                assert_validate(
                    outlines.shape[0] == 1
                    or poses is None
                    or outlines.shape[0] == poses.shape[0],
                    "The outline has to be either fixed or it has to have the same length as poses",
                    e_name,
                )

                # Outline from two dimensional points
                assert_validate(
                    outlines.shape[2] == 2,
                    "The third dimension of outlines should have the length 2",
                    e_name,
                )

            # time
            # TODO: Implement and test the temporal validation of entities:
            #   - With time points: positions have to be fixed in place
            #     (length 1) or have the same length as the time points.
            #   - With time points: outlines have to be absent, fixed
            #     (length 1) or have the same length as the time points.
            #   - Without temporal definition: the entity has to be fixed in
            #     place and the outline has to be absent or fixed.

    except AssertionError as e:
        if strict_validate:
            # Re-raise unchanged to keep the original traceback.
            raise
        logging.warning(e)
        return (False, str(e))
    return (True, f"Common positions/ orientations shape: {common_poses_shape}")


def validate_positions_range(world_size, positions, e_name):
    """Validate that all positions are within the world.

    The allowed range of each axis is symmetric around 0 and spans the world
    size of that axis, extended by an allowance of 1 %. Rows which contain any
    NaN are ignored. If no row is left, nothing is validated.

    Args:
        world_size: The world size, indexable as (x, y).
        positions: The positions with x in column 0 and y in column 1.
        e_name: The name of the entity, attached to the error message.
    Throws:
        AssertionError: If a position is outside of the allowed range
    """
    # positions which are just a bit over the world edge are fine
    error_allowance = 1.01

    # Remove rows where there is any nan
    positions = np.asarray(positions)
    positions = positions[~np.isnan(positions).any(axis=1)]

    # Without any valid row there is nothing to check and min() / max() would
    # raise a ValueError for the empty array.
    if positions.shape[0] == 0:
        return

    for axis, axis_name in enumerate(("x", "y")):
        allowed_min = -1 * world_size[axis] * error_allowance / 2
        allowed_max = world_size[axis] * error_allowance / 2
        real_min = positions[:, axis].min()
        real_max = positions[:, axis].max()

        assert_validate(
            allowed_min <= real_min and real_max <= allowed_max,
            "Positions of %s axis were not in range. The allowed range is [%.1f, %.1f], which was [%.1f, %.1f] in the Positions"
            % (axis_name, allowed_min, allowed_max, real_min, real_max),
            e_name,
        )


def validate_orientations_length(orientations, e_name):
    """Log a warning if the orientation vectors are not unit vectors.

    Rows which contain any NaN are ignored. If no row is left, nothing is
    validated.

    Args:
        orientations: The orientation vectors, one vector per row.
        e_name: The name of the entity, attached to the warning message.
    """
    # Remove rows where there is any nan
    orientations = np.asarray(orientations)
    orientations = orientations[~np.isnan(orientations).any(axis=1)]

    # The message below is built before the check is evaluated, so min() and
    # max() would raise a ValueError for empty orientations.
    if orientations.shape[0] == 0:
        return

    ori_lengths = np.linalg.norm(orientations, axis=1)

    # Check if all orientation lengths are all 1. Different lengths cause warnings.
    assert_validate(
        np.isclose(ori_lengths, 1).all(),
        "The orientation vectors were not unit vectors. Their length was in the range [%.4f, %.4f] when it should be 1"
        % (ori_lengths.min(), ori_lengths.max()),
        e_name,
        strict_validate=False,
    )


def validate_iso8601(str_val: str) -> bool:
    """Check whether a string is a datetime in the ISO8601 format.

    The accepted form is YYYY-MM-DDThh:mm:ss followed by a fraction of a
    second and a time zone designator (Z or an offset like +hh:mm / -hh:mm).
    Both the fraction and the time zone designator are required.

    The source of the regex is https://stackoverflow.com/questions/41129921/validate-an-iso-8601-datetime-string-in-python

    Args:
        str_val: A string to be validated
    Returns:
        bool: validity of the string to iso8601
    """
    pattern = r"^(-?(?:[1-9][0-9]*)?[0-9]{4})-(1[0-2]|0[1-9])-(3[01]|0[1-9]|[12][0-9])T(2[0-3]|[01][0-9]):([0-5][0-9]):([0-5][0-9])(\.[0-9]+)(Z|[+-](?:2[0-3]|[01][0-9]):[0-5][0-9])$"
    found = re.match(pattern, str_val)
    return found is not None

Functions

def assert_validate(statement: bool, message: str, location: str = None, strict_validate=True) ‑> None

Assert the statement and attach the entity name to the error message.

Args

statement
The statement, which should be tested.
message
The assertion message
location
optional location of the validation (e.g. entity 'fish_1', sampling '25hz')

Throws

AssertionError: If the statement is false

Expand source code
def assert_validate(
    statement: bool, message: str, location: str = None, strict_validate=True
) -> None:
    """Check a validation condition and report a failure.

    Nothing happens when the statement is truthy. Otherwise the message,
    extended by the location if one is given, is either raised as an
    AssertionError or only logged as a warning.

    Args:
        statement: The condition which is expected to hold.
        message: The message describing the failed validation.
        location: optional location of the validation (e.g. entity 'fish_1', sampling '25hz')
        strict_validate: Raise an AssertionError on failure if True, otherwise
            only log the message as a warning.
    Throws:
        AssertionError: If the statement is false and strict_validate is True
    """
    if statement:
        return

    # Only a non-empty location is attached to the message.
    full_message = "%s in %s" % (message, location) if location else message

    if not strict_validate:
        logging.warning(full_message)
        return
    raise AssertionError(full_message)
def assert_validate_type(object, expected_type, object_name: str, location: str = None) ‑> None

Assert the statement and attach the entity name to the error message.

Args

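object
The object whose type (or dtype, if it has one) is checked.
expected_type
The type or dtype the object is expected to have.
object_name
The name of the object
location
optional location of the validation, appended to the error message.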

Throws

AssertionError: If the statement is false

Expand source code
def assert_validate_type(
    object, expected_type, object_name: str, location: str = None
) -> None:
    """Assert that an object has the expected type or dtype.

    For objects with a ``dtype`` attribute (e.g. numpy arrays and scalars) the
    dtype is compared, for all other objects the exact type.

    Args:
        object: The object whose type is checked.
        expected_type: The type (or dtype) the object should have.
        object_name: The name of the object, used in the error message.
        location: optional location of the validation, used in the error message.
    Throws:
        AssertionError: If the type does not equal the expected type
    """
    actual_type = object.dtype if hasattr(object, "dtype") else type(object)

    if not (actual_type == expected_type):
        # The location is only mentioned if one was given.
        where = "" if location is None else " in %s" % location
        raise AssertionError(
            ("The type of %s was wrong" % object_name)
            + where
            + (". The type was %s, when %s was expected." % (actual_type, expected_type))
        )
def validate(iofile: File, strict_validate: bool = True) ‑> ()

Validate a file to the specification.

The function compares a given file to the robofish track format specification: https://git.imp.fu-berlin.de/bioroboticslab/robofish/track_format First all specified arrays are formatted to be numpy arrays with the specified datatype. Then all specified shapes are validated. Lastly calendar points are validated to be datetimes according to ISO8601.

Args

iofile
The file which should be validated.
strict_validate
Throw an exception instead of just returning false.

Returns

The function returns a tuple of validity and an error message

Throws

AssertionError: When the file is invalid and strict_validate is True

Expand source code
def validate(iofile: File, strict_validate: bool = True) -> (bool, str):
    """Validate a file to the specification.

    The function compares a given file to the robofish track format specification:
    https://git.imp.fu-berlin.de/bioroboticslab/robofish/track_format
    First the root attributes are checked, then the samplings (types, shapes,
    monotonic time points and ISO8601 calendar points) and lastly the entities
    (category, positions, orientations and outlines).

    Args:
        iofile: The file which should be validated.
        strict_validate: Throw an exception instead of just returning false.
    Returns:
        A tuple of validity and a message. For an invalid file (and
        strict_validate False) the message is the error message, for a valid
        file it states the common shape of positions / orientations.
    Throws:
        AssertionError: When the file is invalid and strict_validate is True
        Exception: When an entity still contains the deprecated poses dataset
            (independent of strict_validate)
    """

    # NOTE(review): The deprecated poses dataset is never read, so poses stays
    # None and the outline length check below currently always passes.
    poses = None
    common_poses_shape = None
    try:
        # validate attributes
        expected_dtypes = {
            "format_version": np.int32,
            "format_url": str,
            "world_size_cm": np.float32,
        }
        for a, a_type in expected_dtypes.items():
            assert_validate(a in iofile.attrs, f'Attribute "{a}" missing', "root")
            assert_validate_type(iofile.attrs[a], a_type, a, "root")

        # validate samplings
        assert_validate("samplings" in iofile, "samplings not found")
        for s_name, sampling in iofile["samplings"].items():
            s_location = f"sampling {s_name}"

            assert_validate(
                "frequency_hz" in sampling.attrs
                or "monotonic_time_points_us" in sampling,
                "Neither frequency nor monotonic_time_points_us was defined",
                s_location,
            )

            # NOTE(review): The calendar points are read from
            # "calendar_time_points" below, so the "monotonic_calendar_points"
            # entry is presumably never found in a file. Confirm and remove.
            expected_dtypes = {
                "frequency_hz": np.float32,
                "monotonic_time_points_us": np.int64,
                "monotonic_calendar_points": "S",
            }

            # The values can be stored as attribute or as dataset.
            for a, a_type in expected_dtypes.items():
                if a in sampling.attrs:
                    assert_validate_type(sampling.attrs[a], a_type, a, s_location)
                if a in sampling:
                    assert_validate_type(sampling[a], a_type, a, s_location)

            if "monotonic_time_points_us" in sampling:
                time_points = sampling["monotonic_time_points_us"]
                # 1 dimensional array
                assert_validate(
                    time_points.ndim == 1,
                    "Dimensionality of monotonic_time_points_us should be 1",
                    s_location,
                )

                assert_validate(
                    np.all(np.diff(time_points) >= 0),
                    "monotonic_time_points_us is not monotonic",
                    s_location,
                )

            # calendar points
            if "calendar_time_points" in sampling:
                calendar_points = sampling["calendar_time_points"]
                assert_validate(
                    calendar_points.ndim == 1,
                    "Dimensionality of calendar_time_points should be 1",
                    s_location,
                )
                if "monotonic_time_points_us" in sampling:
                    assert_validate(
                        calendar_points.shape[0] == time_points.shape[0],
                        "The length of calendar points (%d) does not match the length of monotonic points (%d)"
                        % (calendar_points.shape[0], time_points.shape[0]),
                        s_location,
                    )

                # validate iso8601, this validates the dtype implicitly
                for c in calendar_points.asstr(encoding="utf-8"):
                    assert_validate(
                        validate_iso8601(c),
                        "%s does not match iso8601" % c,
                        s_location,
                    )

        # validate entities
        assert_validate("entities" in iofile, "entities not found")
        for entity in iofile.entities:

            e_name = entity.name

            assert_validate(
                isinstance(entity, Entity),
                "Entity group was not a robofish.io.Entity object",
                e_name,
            )
            assert_validate(
                "category" in entity.attrs
                and isinstance(entity.attrs["category"], str),
                'Attribute "category" not found',
                e_name,
            )

            expected_dtypes = {"poses": np.float32, "outlines": np.float32}

            for a, a_type in expected_dtypes.items():
                if a in entity:
                    # The location (entity name) is appended by assert_validate.
                    assert_validate(
                        entity[a].dtype == a_type,
                        f'The type of dataset "{a}" should be "{a_type}" but was "{entity[a].dtype}"',
                        e_name,
                    )
            if "poses" in entity:
                # NOTE(review): This is not an AssertionError, so it is raised
                # even if strict_validate is False. Confirm this is intended.
                raise Exception(
                    "The poses dataset is deprecated. Please use positions and orientations."
                )
            if "positions" in entity:
                assert_validate(
                    isinstance(entity["positions"], h5py.Dataset),
                    'Dataset "positions" not found',
                    e_name,
                )

                positions = entity["positions"]
                assert_validate(
                    positions.ndim == 2,
                    "Dimensionality of positions should be 2",
                    e_name,
                )

                assert_validate(
                    positions.shape[1] == 2,
                    "The second dimension of positions should have the length 2",
                    e_name,
                )

                if positions.shape[0] > 0:
                    # validate range of positions
                    validate_positions_range(
                        iofile.attrs["world_size_cm"], positions, e_name
                    )

                # Differing shapes are only reported as a warning. The common
                # shape always follows the most recently checked positions.
                if (
                    common_poses_shape is not None
                    and positions.shape != common_poses_shape
                ):
                    warnings.warn(
                        f"The shape of positions for {e_name} was different than the common shape {common_poses_shape}."
                    )
                common_poses_shape = positions.shape

            if "orientations" in entity:
                assert_validate(
                    "positions" in entity,
                    "orientations cannot exist without positions",
                    e_name,
                )

                assert_validate(
                    isinstance(entity["orientations"], h5py.Dataset),
                    'Dataset "orientations" not found',
                    e_name,
                )

                orientations = entity["orientations"]
                assert_validate(
                    orientations.ndim == 2,
                    "Dimensionality of orientations should be 2",
                    e_name,
                )

                assert_validate(
                    orientations.shape[1] == 2,
                    "The second dimension of orientations should have the length 2",
                    e_name,
                )

                if (
                    common_poses_shape is not None
                    and orientations.shape != common_poses_shape
                ):
                    warnings.warn(
                        f"The shape of orientations for {e_name} was different than the common shape {common_poses_shape}."
                    )

                if strict_validate:
                    validate_orientations_length(orientations, e_name)

            # outlines
            if "outlines" in entity:
                outlines = entity["outlines"]

                assert_validate(
                    outlines.ndim == 3, "Dimensionality of outlines should be 3", e_name
                )

                # Either fixed outline or same length with poses
                assert_validate(
                    outlines.shape[0] == 1
                    or poses is None
                    or outlines.shape[0] == poses.shape[0],
                    "The outline has to be either fixed or it has to have the same length as poses",
                    e_name,
                )

                # Outline from two dimensional points
                assert_validate(
                    outlines.shape[2] == 2,
                    "The third dimension of outlines should have the length 2",
                    e_name,
                )

            # time
            # TODO: Implement and test the temporal validation of entities:
            #   - With time points: positions have to be fixed in place
            #     (length 1) or have the same length as the time points.
            #   - With time points: outlines have to be absent, fixed
            #     (length 1) or have the same length as the time points.
            #   - Without temporal definition: the entity has to be fixed in
            #     place and the outline has to be absent or fixed.

    except AssertionError as e:
        if strict_validate:
            # Re-raise unchanged to keep the original traceback.
            raise
        logging.warning(e)
        return (False, str(e))
    return (True, f"Common positions/ orientations shape: {common_poses_shape}")
def validate_iso8601(str_val: str) ‑> bool

This function validates strings to match the ISO8601 format.

The source of the regex is https://stackoverflow.com/questions/41129921/validate-an-iso-8601-datetime-string-in-python

Args

str_val
A string to be validated

Returns

bool
validity of the string to iso8601
Expand source code
def validate_iso8601(str_val: str) -> bool:
    """Check whether a string is a datetime in the ISO8601 format.

    The accepted form is YYYY-MM-DDThh:mm:ss followed by a fraction of a
    second and a time zone designator (Z or an offset like +hh:mm / -hh:mm).
    Both the fraction and the time zone designator are required.

    The source of the regex is https://stackoverflow.com/questions/41129921/validate-an-iso-8601-datetime-string-in-python

    Args:
        str_val: A string to be validated
    Returns:
        bool: validity of the string to iso8601
    """
    pattern = r"^(-?(?:[1-9][0-9]*)?[0-9]{4})-(1[0-2]|0[1-9])-(3[01]|0[1-9]|[12][0-9])T(2[0-3]|[01][0-9]):([0-5][0-9]):([0-5][0-9])(\.[0-9]+)(Z|[+-](?:2[0-3]|[01][0-9]):[0-5][0-9])$"
    found = re.match(pattern, str_val)
    return found is not None
def validate_orientations_length(orientations, e_name)
Expand source code
def validate_orientations_length(orientations, e_name):
    """Log a warning if the orientation vectors are not unit vectors.

    Rows which contain any NaN are ignored. If no row is left, nothing is
    validated.

    Args:
        orientations: The orientation vectors, one vector per row.
        e_name: The name of the entity, attached to the warning message.
    """
    # Remove rows where there is any nan
    orientations = np.asarray(orientations)
    orientations = orientations[~np.isnan(orientations).any(axis=1)]

    # The message below is built before the check is evaluated, so min() and
    # max() would raise a ValueError for empty orientations.
    if orientations.shape[0] == 0:
        return

    ori_lengths = np.linalg.norm(orientations, axis=1)

    # Check if all orientation lengths are all 1. Different lengths cause warnings.
    assert_validate(
        np.isclose(ori_lengths, 1).all(),
        "The orientation vectors were not unit vectors. Their length was in the range [%.4f, %.4f] when it should be 1"
        % (ori_lengths.min(), ori_lengths.max()),
        e_name,
        strict_validate=False,
    )
def validate_positions_range(world_size, positions, e_name)
Expand source code
def validate_positions_range(world_size, positions, e_name):
    """Validate that all positions are within the world.

    The allowed range of each axis is symmetric around 0 and spans the world
    size of that axis, extended by an allowance of 1 %. Rows which contain any
    NaN are ignored. If no row is left, nothing is validated.

    Args:
        world_size: The world size, indexable as (x, y).
        positions: The positions with x in column 0 and y in column 1.
        e_name: The name of the entity, attached to the error message.
    Throws:
        AssertionError: If a position is outside of the allowed range
    """
    # positions which are just a bit over the world edge are fine
    error_allowance = 1.01

    # Remove rows where there is any nan
    positions = np.asarray(positions)
    positions = positions[~np.isnan(positions).any(axis=1)]

    # Without any valid row there is nothing to check and min() / max() would
    # raise a ValueError for the empty array.
    if positions.shape[0] == 0:
        return

    for axis, axis_name in enumerate(("x", "y")):
        allowed_min = -1 * world_size[axis] * error_allowance / 2
        allowed_max = world_size[axis] * error_allowance / 2
        real_min = positions[:, axis].min()
        real_max = positions[:, axis].max()

        assert_validate(
            allowed_min <= real_min and real_max <= allowed_max,
            "Positions of %s axis were not in range. The allowed range is [%.1f, %.1f], which was [%.1f, %.1f] in the Positions"
            % (axis_name, allowed_min, allowed_max, real_min, real_max),
            e_name,
        )