Module robofish.io.file

File objects are the root of the project. The object contains all information about the environment, entities, and time.

In the simplest form, we define a new File with a world size in cm and a sampling frequency in Hz. Afterwards, we can save it to a path with save_as().

import robofish.io

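# Create a new robofish io file (world size in cm, sampling frequency in Hz)
f = robofish.io.File(world_size_cm=[100, 100], frequency_hz=25.0)
# save_as() validates the file, copies it to the given path and closes f
f.save_as("test.hdf5")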
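The frequency is tied to the time points of a sampling. If a sampling is created with monotonic_time_points_us but without frequency_hz, create_sampling (see the source below) only infers a frequency when all time steps are equal and positive, as 1e6 divided by the step in microseconds; otherwise it just warns. The stdlib-only sketch below restates that rule. The function name infer_frequency_hz is hypothetical and not part of robofish.io.

```python
from typing import Optional, Sequence


def infer_frequency_hz(monotonic_time_points_us: Sequence[int]) -> Optional[float]:
    """Hypothetical helper restating the rule used in File.create_sampling.

    Returns 1e6 / step if all consecutive differences are equal and positive,
    otherwise None (robofish.io only emits a warning in that case).
    """
    points = list(monotonic_time_points_us)
    diffs = [b - a for a, b in zip(points, points[1:])]
    if not diffs:
        # A single time point has no step, so no frequency can be derived.
        return None
    step = diffs[0]
    if step > 0 and all(d == step for d in diffs):
        return 1e6 / step
    return None


# 40 000 us between samples corresponds to 25 Hz.
print(infer_frequency_hz([0, 40_000, 80_000, 120_000]))  # 25.0
print(infer_frequency_hz([0, 40_000, 90_000]))  # None
```

Passing frequency_hz explicitly avoids relying on this inference, which is also what the warning in create_sampling recommends.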

The File object can also be created with a given path. In this case, we work on the file directly. The with block ensures that the file is validated when the block is left.

with robofish.io.File(
    "test.hdf5", mode="x", world_size_cm=[100, 100], frequency_hz=25.0
) as f:
    # Use file f here
    pass
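The validation when leaving the with block follows File.__exit__ in the source below: it only runs if the block finished without an exception, the file was not opened read-only ('r'), and validate_when_saving is True. The class below is a hypothetical stand-in that does not touch HDF5 at all; it reproduces just that decision so it can be tried in isolation.

```python
class ValidatingContext:
    """Hypothetical stand-in for the exit logic of robofish.io.File."""

    def __init__(self, mode: str = "r", validate_when_saving: bool = True):
        self.mode = mode
        self.validate_when_saving = validate_when_saving
        self.validation_calls = 0

    def validate(self) -> None:
        # The real File.validate checks the track format specification.
        self.validation_calls += 1

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        left_normally = (exc_type, exc_value, traceback) == (None, None, None)
        # Read-only files are skipped for performance, as in the original.
        if left_normally and self.mode != "r" and self.validate_when_saving:
            self.validate()
        return False  # never swallow exceptions raised inside the block


with ValidatingContext(mode="x") as ctx:
    pass
print(ctx.validation_calls)  # 1
```

If the block raises, validation is skipped and the exception propagates, so an invalid half-written file does not additionally trigger a validation error.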

When opening a file with a path, a mode should be specified to describe how the file should be opened.

Mode   Description
r      Readonly, file must exist (default)
r+     Read/write, file must exist
w      Create file, truncate if exists
x      Create file, fail if exists
a      Read/write if exists, create otherwise
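The mode also decides whether File writes the required root attributes (world_size_cm, format_version, format_url) and the entities and samplings groups. According to the __init__ source below, a file opened with a path is initialized when it does not exist yet or when the mode is 'w'; otherwise the existing content is kept and world_size_cm is not needed. A small stdlib sketch of that decision follows; needs_initialization is a hypothetical name, and it raises ValueError where the original uses an assert.

```python
import tempfile
from pathlib import Path
from typing import Union

VALID_MODES = ("r", "r+", "w", "x", "a")


def needs_initialization(path: Union[str, Path], mode: str = "r") -> bool:
    """Hypothetical helper: would File(path, mode) write a fresh structure?"""
    if mode not in VALID_MODES:
        raise ValueError(f"Unknown mode {mode}.")
    # Existing files keep their content unless mode 'w' truncates them.
    return not (Path(path).exists() and mode != "w")


with tempfile.TemporaryDirectory() as tmp:
    existing = Path(tmp) / "existing.hdf5"
    existing.touch()
    print(needs_initialization(existing, "r+"))  # False: content is kept
    print(needs_initialization(existing, "w"))  # True: file is truncated
    print(needs_initialization(Path(tmp) / "new.hdf5", "x"))  # True: new file
```

For mode 'r' on a missing path the sketch also returns True, but in that case h5py fails to open the file before any initialization happens.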

Attributes

Attributes can be added to the file to describe its contents. They can be set like this:

f.attrs["experiment_setup"] = "This file comes from the tutorial."
f.attrs["experiment_issues"] = "All data in this file is made up."

Any attribute is allowed, but some canonical attributes are predefined:
publication_url, video_url, tracking_software_name, tracking_software_version, tracking_software_url, experiment_setup, experiment_issues
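Because any attribute is accepted, it can be useful to check which of the canonical ones are still missing before sharing a file. A minimal sketch, assuming only that f.attrs supports the `in` operator like a mapping (as h5py attribute managers do, and as the source below relies on); the helper name missing_canonical_attributes is hypothetical.

```python
from typing import Container, List

CANONICAL_ATTRIBUTES = (
    "publication_url",
    "video_url",
    "tracking_software_name",
    "tracking_software_version",
    "tracking_software_url",
    "experiment_setup",
    "experiment_issues",
)


def missing_canonical_attributes(attrs: Container) -> List[str]:
    """List the canonical attributes that are not set yet, in canonical order."""
    return [name for name in CANONICAL_ATTRIBUTES if name not in attrs]


# A plain dict stands in for f.attrs here.
attrs = {
    "experiment_setup": "This file comes from the tutorial.",
    "experiment_issues": "All data in this file is made up.",
}
print(missing_canonical_attributes(attrs))
```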

Properties

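As described in robofish.io, all properties of robofish.io.Entity objects can be accessed for all entities of a file at once by adding the prefix entity_ to the property name, e.g. f.entity_poses or f.entity_actions_speeds_turns.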
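Concretely, an entity_ property such as f.entity_poses calls select_entity_property (see the source below), which stacks the per-entity arrays into one array of shape (entity, time, property_length) and pads entities with fewer time steps with NaN. The sketch below reproduces only that padding step with plain lists so it runs without robofish.io or NumPy; stack_with_nan and its dim parameter are hypothetical.

```python
import math
from typing import List, Sequence


def stack_with_nan(
    properties: Sequence[Sequence[Sequence[float]]], dim: int
) -> List[List[List[float]]]:
    """Stack per-entity (time, dim) data into (entity, max_time, dim), NaN-padded."""
    max_timesteps = max([0] + [len(p) for p in properties])
    stacked = []
    for p in properties:
        rows = [list(row) for row in p]
        # Pad shorter entities with NaN rows, as select_entity_property does.
        rows += [[math.nan] * dim for _ in range(max_timesteps - len(rows))]
        stacked.append(rows)
    return stacked


# Two fish with 3 and 2 poses (x, y, orientation_x, orientation_y).
fish_a = [[0, 0, 1, 0], [1, 0, 1, 0], [2, 0, 1, 0]]
fish_b = [[5, 5, 0, 1], [5, 6, 0, 1]]
poses = stack_with_nan([fish_a, fish_b], dim=4)
print(len(poses), len(poses[0]), len(poses[0][0]))  # 2 3 4
print(poses[1][2])  # [nan, nan, nan, nan]
```

NaN padding keeps the array rectangular, so downstream code should use NaN-aware reductions when entities have different lengths.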

Plotting

Files have a built-in plotting tool, File.plot(). With the option lw_distances=True, the distance between two fish is represented through the line width.

import robofish.io 
import matplotlib.pyplot as plt 

fig, ax = plt.subplots(1,2, figsize=(10,5))
f = robofish.io.File("...")
f.plot(ax=ax[0])
f.plot(ax=ax[1], lw_distances=True)
plt.show()

For all other options while plotting please check File.plot().
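The distance behind lw_distances=True is, at each time step, the Euclidean distance between the positions of the two fish. How File.plot maps that distance to a line width is not described here, so the linear mapping below (closer fish drawn thicker, between min_lw and max_lw) is purely an assumption for illustration; pairwise_distances and distances_to_linewidths are hypothetical names.

```python
import math
from typing import List, Sequence, Tuple

Point = Tuple[float, float]


def pairwise_distances(a: Sequence[Point], b: Sequence[Point]) -> List[float]:
    """Euclidean distance between two trajectories, one value per time step."""
    return [math.dist(p, q) for p, q in zip(a, b)]


def distances_to_linewidths(
    distances: Sequence[float], min_lw: float = 0.5, max_lw: float = 4.0
) -> List[float]:
    """ASSUMED mapping: linear scaling so the closest distance gets max_lw."""
    d_min, d_max = min(distances), max(distances)
    if d_max == d_min:
        return [max_lw] * len(distances)
    return [
        max_lw - (d - d_min) / (d_max - d_min) * (max_lw - min_lw) for d in distances
    ]


fish_a = [(0.0, 0.0), (1.0, 0.0), (2.0, 0.0)]
fish_b = [(0.0, 3.0), (1.0, 4.0), (2.0, 5.0)]
distances = pairwise_distances(fish_a, fish_b)
print(distances)  # [3.0, 4.0, 5.0]
print(distances_to_linewidths(distances))  # [4.0, 2.25, 0.5]
```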

# -*- coding: utf-8 -*-

"""
.. include:: ../../../docs/file.md
"""

# -----------------------------------------------------------
# Utils functions for reading, validating and writing hdf5 files according to
# Robofish track format (1.0 Draft 7). The standard is available at
# https://git.imp.fu-berlin.de/bioroboticslab/robofish/track_format

#
# Dec 2020 Andreas Gerken, Berlin, Germany
# Released under GNU 3.0 License
# email andi.gerken@gmail.com
# -----------------------------------------------------------

import robofish.io
from robofish.io.entity import Entity
import h5py

import numpy as np

import logging
from typing import Iterable, Union, Tuple, List, Optional
from pathlib import Path
import shutil
import datetime
import tempfile
import uuid
import deprecation
import types
import warnings
from textwrap import wrap
import platform

import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib import animation
from matplotlib import patches
from matplotlib import cm

from tqdm.auto import tqdm

from subprocess import run

# Remember: Update docstring when updating these two global variables
default_format_version = np.array([1, 0], dtype=np.int32)

default_format_url = (
    "https://git.imp.fu-berlin.de/bioroboticslab/robofish/track_format/-/releases/1.0"
)


class File(h5py.File):
    """Represents a RoboFish Track Format file, which should be used to store tracking data of individual animals or swarms.

    Files can be opened (with optional creation), modified in place, and have copies of them saved.
    """

    _temp_dir = None

    def __init__(
        self,
        path: Union[str, Path] = None,
        mode: str = "r",
        *,  # PEP 3102
        world_size_cm: List[int] = None,
        validate: bool = False,
        validate_when_saving: bool = True,
        strict_validate: bool = False,
        format_version: List[int] = default_format_version,
        format_url: str = default_format_url,
        sampling_name: str = None,
        frequency_hz: int = None,
        monotonic_time_points_us: Iterable = None,
        calendar_time_points: Iterable = None,
        open_copy: bool = False,
        validate_poses_hash: bool = True,
    ):
        """Create a new RoboFish Track Format object.

        When called with a path, it is loaded, otherwise a new temporary
        file is created. File contents can be validated against the
        track format specification.

        Parameters
        ----------
        path : str or Path, optional
            Location of file to be opened. If not provided, mode is ignored.
        mode : str, default='r'
            'r'        Readonly, file must exist
            'r+'       Read/write, file must exist
            'w'        Create file, truncate if exists
            'x'        Create file, fail if exists
            'a'        Read/write if exists, create otherwise
        world_size_cm : [int, int] , optional
            side lengths [x, y] of the world in cm.
            rectangular world shape is assumed.
        validate: bool, default=False
            Should the track be validated? This is normally switched off for performance reasons.
        strict_validate : bool, default=False
            if the file should be strictly validated against the track
            format specification, when loaded from a path.
            TODO: Should this validate against the version specified in
            format_version or just against the most recent version?
        format_version : [int, int], default=[1,0]
            version [major, minor] of track format specification
        format_url : str, default="https://git.imp.fu-berlin.de/bioroboticslab/robofish/track_format/-/releases/1.0"
            location of track format specification.
            should fit `format_version`.
        sampling_name : str, optional
            How to specify your sampling:

            1. (optional)
                provide text description of your sampling in `sampling_name`

            2.a (mandatory, if you have a constant sampling frequency)
                specify `frequency_hz` with your sampling frequency in Hz

            2.b (mandatory, if you do NOT have a constant sampling frequency)
                specify `monotonic_time_points_us` with a list[1] of time
                points in microseconds on a monotonic clock, one for each
                sample in your dataset.

            3.  (optional)
                specify `calendar_time_points` with a list[2] of time points
                in the ISO 8601 extended format with microsecond precision
                and time zone designator[3],  one for each sample in your
                dataset.

            [1] any Iterable of int
            [2] any Iterable of str
            [3] example:  "2020-11-18T13:21:34.117015+01:00"

        frequency_hz: int, optional
            refer to explanation of `sampling_name`
        monotonic_time_points_us: Iterable of int, optional
            refer to explanation of `sampling_name`
        calendar_time_points: Iterable of str, optional
            refer to explanation of `sampling_name`
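        open_copy: bool, optional
            a temporary copy of the file will be opened instead of the file itself.
        validate_when_saving: bool, default=True
            if a writable file should be validated when it is left through a with block.
        validate_poses_hash: bool, default=True
            if the stored poses hash of each entity should be compared with a newly
            calculated one, warning about outdated or missing pre-calculated data.
        """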

        self.path = path
        self.validate_when_saving = validate_when_saving

        if open_copy:
            assert (
                path is not None
            ), "A path has to be given if a copy should be opened."

            temp_file = self.temp_dir / str(uuid.uuid4())
            logging.info(
                f"Copying file to temporary file and opening it:\n{path} -> {temp_file}"
            )

            shutil.copyfile(path, temp_file)
            super().__init__(
                temp_file,
                mode="r+",
                driver="core",
                backing_store=True,
                libver=("earliest", "v110"),
            )
            initialize = False

        elif path is None:
            temp_file = self.temp_dir / str(uuid.uuid4())
            logging.info(f"Opening New temporary file {temp_file}")
            super().__init__(
                temp_file,
                mode="x",
                driver="core",
                backing_store=True,
                libver=("earliest", "v110"),
            )
            initialize = True
        else:
            # mode
            # r        Readonly, file must exist (default)
            # r+       Read/write, file must exist
            # w        Create file, truncate if exists
            # x        Create file, fail if exists
            # a        Read/write if exists, create otherwise
            logging.info(f"Opening File {path}")

            assert mode in ["r", "r+", "w", "x", "a"], f"Unknown mode {mode}."

            # If the file does not exist or if it should be truncated with mode=w, initialize it.
            if Path(path).exists() and mode != "w":
                initialize = False
            else:
                initialize = True

            try:
                super().__init__(path, mode, libver=("earliest", "v110"))
            except OSError as e:
                raise OSError(f"Could not open file {path} with mode {mode}.\n{e}")

        if initialize:
            assert (
                world_size_cm is not None and format_version is not None
            ), "world_size_cm and format_version are required to initialize a new file. To modify an existing file, open it with mode 'r+'."

            self.attrs["world_size_cm"] = np.array(world_size_cm, dtype=np.float32)
            self.attrs["format_version"] = np.array(format_version, dtype=np.int32)
            self.attrs["format_url"] = format_url

            self.create_group("entities")
            self.create_group("samplings")

            if frequency_hz is not None or monotonic_time_points_us is not None:
                self.create_sampling(
                    name=sampling_name,
                    frequency_hz=frequency_hz,
                    monotonic_time_points_us=monotonic_time_points_us,
                    calendar_time_points=calendar_time_points,
                    default=True,
                )
        else:
            # A quick validation to find h5py files which are not robofish.io files
            if any([a not in self.attrs for a in ["world_size_cm", "format_version"]]):
                msg = f"The opened file {self.path} does not include world_size_cm or format_version. It seems that the file is not a robofish.io.File."
                if strict_validate:
                    raise KeyError(msg)
                else:
                    warnings.warn(msg)
                return

            # Validate that the stored poses hash still fits.
            if validate_poses_hash:
                for entity in self.entities:
                    if "poses_hash" in entity.attrs:
                        if entity.attrs["poses_hash"] != entity.poses_hash:
                            warnings.warn(
                                f"The stored hash is not identical with the newly calculated hash. In entity {entity.name} in {self.path}. f.entity_actions_speeds_turns and f.entity_orientations_rad will return wrong results.\n"
                                f"stored: {entity.attrs['poses_hash']}, calculated: {entity.poses_hash}"
                            )
                        assert (
                            "unfinished_calculations" not in entity.attrs
                        ), f"The calculated data of file {self.path} is incomplete and was probably aborted during calculation. Please recalculate with `robofish-io-update-calculated-data {self.path}`."

                    else:
                        warnings.warn(
                            f"The file did not include pre-calculated data, so the actions_speeds_turns "
                            f"and orientations_rad will have to be recalculated every time.\n"
                            f"Please use `robofish-io-update-calculated-data {self.path}` in the "
                            f"command line or\nopen and close the file with robofish.io.File(f, 'r+') "
                            f"in python.\nIf the data should be recalculated every time, open the file "
                            "with the bool option validate_poses_hash=False."
                        )
        if validate:
            self.validate(strict_validate)

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):

        # Check if the context was left under normal circumstances
        if (hasattr(self, "closed") and not self.closed) and (
            type,
            value,
            traceback,
        ) == (None, None, None):
            if (
                self.mode != "r" and self.validate_when_saving
            ):  # No need to validate read only files (performance).
                self.validate()

        super().__exit__(type, value, traceback)

    def close(self):
        if self.mode != "r":
            self.update_calculated_data()
        super().close()

    def save_as(
        self,
        path: Union[str, Path],
        strict_validate: bool = True,
        no_warning: bool = False,
    ):
        """Save a copy of the file

        The calculated data is updated, the file is validated, copied to the given path and closed afterwards.

        Args:
            path: path to the new io file as a string or path object. Missing parent directories are created.
            strict_validate: optional boolean, if the file should be strictly validated before saving. The default is True.
            no_warning: optional boolean, to suppress the warning that the file is closed after saving.
        Returns:
            None. To keep working with the data, reopen the saved file, e.g. robofish.io.File(path, "r+").
        """

        self.update_calculated_data()
        self.validate(strict_validate=strict_validate)

        # Ensure all buffered data has been written to disk
        self.flush()

        path = Path(path).resolve()
        path.parent.mkdir(parents=True, exist_ok=True)

        filename = self.filename
        self.flush()
        self.close()

        self.closed = True

        shutil.copyfile(filename, path)
        if not no_warning:
            warnings.warn(
                "The 'save_as' function closes the file to be able to store it. If you want to use the file after saving it, please reload the file. The save_as function can be avoided by opening the correct file directly. If you want to get rid of this warning use 'save_as(..., no_warning=True)'"
            )
        return None

    def create_sampling(
        self,
        name: str = None,
        frequency_hz: int = None,
        monotonic_time_points_us: Iterable = None,
        calendar_time_points: Iterable = None,
        default: bool = False,
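    ):
        """Create a new sampling group in "samplings".

        Args:
            name: optional name of the sampling. If None, "<frequency> hz" is used when
                frequency_hz is given, otherwise (or if that name is taken) "sampling_<i>".
            frequency_hz: optional constant sampling frequency in Hz. If omitted but evenly
                spaced monotonic_time_points_us are given, it is derived from them.
            monotonic_time_points_us: optional time points in microseconds on a monotonic clock.
            calendar_time_points: optional ISO 8601 strings or timezone-aware datetime objects.
            default: if True, the sampling is stored as the default sampling of the file.
        Returns:
            The name of the created sampling.
        """

        # Find Name for sampling if none is given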
        if name is None:
            if frequency_hz is not None:
                name = "%d hz" % frequency_hz

            i = 1
            while name is None or name in self["samplings"]:
                name = "sampling_%d" % i
                i += 1

        sampling = self["samplings"].create_group(name)

        if monotonic_time_points_us is not None:

            monotonic_time_points_us = np.array(
                monotonic_time_points_us, dtype=np.int64
            )
            sampling.create_dataset(
                "monotonic_time_points_us", data=monotonic_time_points_us
            )
            if frequency_hz is None:
                diff = np.diff(monotonic_time_points_us)
                if np.all(diff == diff[0]) and diff[0] > 0:
                    frequency_hz = 1e6 / diff[0]
                    warnings.warn(
                        f"The frequency_hz of {frequency_hz:.2f}hz was calculated automatically by robofish.io. The safer variant is to pass it using frequency_hz.\nThis is important when using fish_models with the files."
                    )

                else:
                    warnings.warn(
                        "The frequency_hz could not be calculated automatically. fish_models accesses frequency_hz, so please pass it explicitly when using fish_models."
                    )

        if frequency_hz is not None:
            sampling.attrs["frequency_hz"] = np.float32(frequency_hz)

        if calendar_time_points is not None:

            def format_calendar_time_point(p):
                if isinstance(p, datetime.datetime):
                    assert p.tzinfo is not None, "Missing timezone for calendar point."
                    return p.isoformat(timespec="microseconds")
                elif isinstance(p, str):
                    assert p == datetime.datetime.fromisoformat(p).isoformat(
                        timespec="microseconds"
                    ), f"Calendar point {p} is not an ISO 8601 string with microsecond precision."
                    return p
                else:
                    assert (
                        False
                    ), "Calendar points must be datetime.datetime instances or strings."

            calendar_time_points = [
                format_calendar_time_point(p) for p in calendar_time_points
            ]

            sampling.create_dataset(
                "calendar_time_points",
                data=calendar_time_points,
                dtype=h5py.string_dtype(encoding="utf-8"),
            )

        if default:
            self["samplings"].attrs["default"] = name
        return name

    @property
    def temp_dir(self):
        cla = type(self)
        if cla._temp_dir is None:
            cla._temp_dir = tempfile.TemporaryDirectory(prefix="robofish-io-")
        return Path(cla._temp_dir.name)

    @property
    def world_size(self):
        return self.attrs["world_size_cm"]

    @property
    def default_sampling(self):
        assert (
            "samplings" in self
        ), "The file does not have a group 'samplings' which is required."
        if "default" in self["samplings"].attrs:
            return self["samplings"].attrs["default"]
        return None

    @property
    def frequency(self):
        common_sampling = self.common_sampling()
        assert common_sampling is not None, "The sampling differs between entities."
        assert (
            "frequency_hz" in common_sampling.attrs
        ), "The common sampling has no frequency_hz"
        return common_sampling.attrs["frequency_hz"]

    def common_sampling(
        self, entities: Iterable["robofish.io.Entity"] = None
    ) -> h5py.Group:
        """Check if all entities have the same sampling.

        Args:
            entities: optional array of entities. If None is given, all entities are checked.
        Returns:
            The h5py group of the common sampling. If there is no common sampling, None will be returned.
        """
        if entities is None:
            entities = self.entities
        custom_sampling = None
        for entity in entities:
            if "sampling" in entity["positions"].attrs:
                this_sampling = entity["positions"].attrs["sampling"]
                if custom_sampling is None:
                    custom_sampling = this_sampling
                elif custom_sampling != this_sampling:
                    return None
        sampling = self.default_sampling if custom_sampling is None else custom_sampling
        return self["samplings"][sampling]

    def create_entity(
        self,
        category: str,
        poses: Iterable = None,
        name: str = None,
        positions: Iterable = None,
        orientations: Iterable = None,
        outlines: Iterable = None,
        sampling: str = None,
    ) -> str:
        """Creates a new single entity.

        Args:
            category: the category of the entity. The canonical values are ['organism', 'robot', 'obstacle'].
            poses: optional two dimensional array, containing the poses of the entity (x, y, orientation_x, orientation_y).
            name: optional name of the entity. If no name is given, the category is used with an id (e.g. 'fish_1')
            positions: optional two dimensional array, containing the positions of the entity.
            orientations: optional two dimensional array, containing the orientations of the entity.
            outlines: optional three dimensional array, containing the outlines of the entity
            sampling: optional name of the sampling. If none is given, the default sampling of the file is used.
        Returns:
            Name of the created entity
        """

        if sampling is None and self.default_sampling is None:
            raise Exception(
                "No sampling was specified, neither when creating the file nor when creating the entity."
            )

        entity = robofish.io.Entity.create_entity(
            self["entities"],
            category,
            poses,
            name,
            positions,
            orientations,
            outlines,
            sampling,
        )

        return entity

    def create_multiple_entities(
        self,
        category: str,
        poses: Iterable,
        names: Iterable[str] = None,
        outlines=None,
        sampling=None,
    ) -> Iterable:
        """Creates multiple entities.

        Args:
            category: The common category for the entities. The canonical values are ['organism', 'robot', 'obstacle'].
            poses: three dimensional array, containing the poses of the entities with the shape (entity, timestep, 3 or 4).
            names: optional array of names of the entities. If no names are given, the category is used with an id (e.g. 'fish_1')
            outlines: optional array, containing the outlines of the entities, either a three dimensional common outline array can be given, or a four dimensional array.
            sampling: The string reference to the sampling. If none is given, the default sampling from creating the file is used.
        Returns:
            Array of names of the created entities
        """

        assert (
            poses.ndim == 3
        ), f"A 3 dimensional array was expected (entity, timestep, 3 or 4). There were {poses.ndim} dimensions in poses: {poses.shape}"
        assert poses.shape[2] in [3, 4]
        agents = poses.shape[0]
        entity_names = []

        for i in range(agents):
            e_name = None if names is None else names[i]
            e_outline = (
                outlines if outlines is None or outlines.ndim == 3 else outlines[i]
            )

            entity_names.append(
                self.create_entity(
                    category=category,
                    sampling=sampling,
                    poses=poses[i],
                    name=e_name,
                    outlines=e_outline,
                )
            )
        return entity_names

    def update_calculated_data(self, verbose=False):
        changed = any([e.update_calculated_data(verbose) for e in self.entities])
        return changed

    def clear_calculated_data(self, verbose=True):
        """Delete all calculated data from the file."""
        txt = ""
        for e in self.entities:
            deleted_attrs = []
            for a in ["poses_hash"]:
                if a in e.attrs:
                    del e.attrs[a]
                    deleted_attrs.append(a)
            deleted_datasets = []
            for g in ["calculated_actions_speeds_turns", "calculated_orientations_rad"]:
                if g in e:
                    del e[g]
                    deleted_datasets.append(g)
            # Joining avoids malformed brackets when nothing was deleted.
            txt += (
                f"Deleting from {e}. Attrs: [{', '.join(deleted_attrs)}] "
                f"Datasets: [{', '.join(deleted_datasets)}]\n"
            )
        if verbose:
            print(txt[:-1])

    @property
    def entity_names(self) -> Iterable[str]:
        """Getter for the names of all entities

        Returns:
            Array of all names.
        """
        return sorted(self["entities"].keys())

    @property
    def entities(self):
        return [
            robofish.io.Entity.from_h5py_group(self["entities"][name])
            for name in self.entity_names
        ]

    @property
    def entity_positions(self):
        return self.select_entity_property(None, entity_property=Entity.positions)

    @property
    def entity_orientations(self):
        return self.select_entity_property(None, entity_property=Entity.orientations)

    @property
    def entity_orientations_rad(self):
        return self.select_entity_property(
            None, entity_property=Entity.orientations_rad
        )

    @property
    def entity_poses(self):
        return self.select_entity_property(None, entity_property=Entity.poses)

    @property
    def entity_poses_rad(self):
        return self.select_entity_property(None, entity_property=Entity.poses_rad)

    @property
    @deprecation.deprecated(
        deprecated_in="0.2",
        removed_in="0.2.4",
        details="We found that our calculation of 'poses_calc_ori' is flawed."
        "Please replace it with 'poses' and use the tracked orientation."
        "If you see this message and you don't know what to do, update all packages and if nothing helps, contact Andi.\n"
        "Don't ignore this warning, it's a serious issue.",
    )
    def entity_poses_calc_ori(self):
        return self.select_entity_property(None, entity_property=Entity.poses_calc_ori)

    @property
    @deprecation.deprecated(
        deprecated_in="0.2",
        removed_in="0.2.4",
        details="We found that our calculation of 'poses_calc_ori_rad' is flawed."
        "Please replace it with 'poses_rad' and use the tracked orientation."
        "If you see this message and you don't know what to do, update all packages and if nothing helps, contact Andi.\n"
        "Don't ignore this warning, it's a serious issue.",
    )
    def entity_poses_calc_ori_rad(self):
        return self.select_entity_property(
            None, entity_property=Entity.poses_calc_ori_rad
        )

    @property
    @deprecation.deprecated(
        deprecated_in="0.2",
        removed_in="0.2.4",
        details="We found that our calculation of 'entity_speeds_turns' is flawed and replaced it "
        "with 'entity_actions_speeds_turns'. The difference in calculation is, that the tracked "
        "orientation is used now which gives the fish the ability to swim backwards. "
        "If you see this message and you don't know what to do, update all packages and if nothing helps, contact Andi.\n"
        "Don't ignore this warning, it's a serious issue.",
    )
    def entity_speeds_turns(self):
        return self.select_entity_property(None, entity_property=Entity.speed_turn)

    @property
    def entity_actions_speeds_turns(self):
        """Calculate the speed and turn from the recorded positions and orientations.

        The turn is calculated by the change of orientation between frames.
        The speed is calculated by the distance between the points, projected on the new orientation vector.
        The sideway change of position cannot be represented with this method.

        Returns:
            An array with shape (number_of_entities, number_of_positions - 1, 2), containing the speed in cm/frame and the turn in rad/frame.
        """
        return self.select_entity_property(
            None, entity_property=Entity.actions_speeds_turns
        )

    def select_entity_poses(self, *args, ori_rad=False, **kwargs):
        entity_property = Entity.poses_rad if ori_rad else Entity.poses
        return self.select_entity_property(
            *args, entity_property=entity_property, **kwargs
        )

    def select_entity_property(
        self,
        predicate: types.LambdaType = None,
        entity_property: Union[property, str] = Entity.poses,
    ) -> Iterable:
        """Get a property of selected entities.

        Entities can be selected, using a lambda function.
        The property of the entities can be selected.

        Args:
            predicate: a lambda function, selecting entities
            (example: lambda e: e.category == "organism")
            entity_property: a property of the Entity class (example: Entity.poses_rad) or a string with the name of the dataset.
        Returns:
            A three dimensional array of all properties of all entities with the shape (entity, time, property_length).
            If an entity has a shorter length of the property, the output will be filled with nans.
        """

        entities = self.entities
        if predicate is not None:
            entities = [e for e in entities if predicate(e)]

        assert self.common_sampling(entities) is not None

        # Initialize poses output array
        if isinstance(entity_property, str):
            properties = [entity[entity_property] for entity in entities]
        else:
            properties = [entity_property.__get__(entity) for entity in entities]

        max_timesteps = max([0] + [p.shape[0] for p in properties])

        property_array = np.full(
            (len(entities), max_timesteps, properties[0].shape[1]), np.nan
        )

        # Fill output array
        for i, entity in enumerate(entities):
            property_array[i][: properties[i].shape[0]] = properties[i]
        return property_array

    def validate(self, strict_validate: bool = True) -> Tuple[bool, str]:
        """Validate the file to the specification.

        The function compares a given file to the robofish track format specification:
        https://git.imp.fu-berlin.de/bioroboticslab/robofish/track_format
        First all specified arrays are formatted to be numpy arrays with the specified
        datatype. Then all specified shapes are validated. Lastly calendar points
        are validated to be datetimes according to ISO8601.

        Args:
            strict_validate: Throw an exception instead of just returning false.
        Returns:
            The function returns a tuple of validity and an error message
        Throws:
            AssertionError: When the file is invalid and strict_validate is True
        """
        return robofish.io.validate(self, strict_validate)

    def to_string(
        self,
        output_format: str = "shape",
        max_width: int = 120,
        full_attrs: bool = False,
    ) -> str:
        """The file is formatted to a human readable format.
        Args:
            output_format: ['shape', 'full'] show the shape, or the full content of datasets
            max_width: set the width in characters after which attribute values get abbreviated
            full_attrs: do not abbreviate attribute values if True
        Returns:
            A human readable string, representing the file
        """

        def recursive_stringify(
            obj: h5py.Group,
            output_format: str,
            parent_indices: List[int] = [],
            parent_siblings: List[int] = [],
        ) -> str:
            """This function crawls recursively into hdf5 groups.
            Datasets and attributes are directly attached, for groups, the function is recursively called again.
            Args:
                obj: a h5py group
                output_format: ['shape', 'full'] show the shape, or the full content of datasets
            Returns:
                A string representation of the group
            """

            def lines(dataset_attribute: bool = False) -> str:
                """Get box-drawing characters for the graph lines."""
                line = ""
                for pi, ps in zip(parent_indices, parent_siblings):
                    if pi < ps - 1:
                        line += "│ "
                    else:
                        line += "  "
                if dataset_attribute:
                    line += "  "
                line += "─ "
                junction_index = 2 * len(parent_indices) + dataset_attribute * 2 - 1
                last = "└"
                other = "├"
                if dataset_attribute:
                    j = (
                        last
                        if list(value.attrs.keys()).index(d_key) == len(value.attrs) - 1
                        else other
                    )
                else:
                    j = last if index == num_children - 1 else other
                line = line[: junction_index + 1] + j + line[junction_index + 1 :]
                if isinstance(value, h5py.Group) or (
                    isinstance(value, h5py.Dataset)
                    and not dataset_attribute
                    and value.attrs
                ):
                    line = line[:-1] + "┬─"
                else:
                    line = line[:-1] + "──"

                return line + " "

            s = ""
            max_key_len = 0
            num_children = 0
            if obj.attrs:
                max_key_len = max(len(key) for key in obj.attrs)
                num_children += len(obj.attrs)
            if hasattr(obj, "items"):
                max_key_len = max([len(key) for key in obj] + [max_key_len])
                num_children += len(obj)
            index = 0
            if obj.attrs:
                for key, value in obj.attrs.items():
                    if not full_attrs:
                        value = str(value).replace("\n", " ").strip()
                        if len(value) > max_width - max_key_len - len(lines()):
                            value = (
                                value[: max_width - max_key_len - len(lines()) - 3]
                                + "..."
                            )
                    s += f"{lines()}{key: <{max_key_len}}  {value}\n"
                    index += 1
            if hasattr(obj, "items"):
                for key, value in obj.items():
                    if isinstance(value, h5py.Dataset):
                        if output_format == "shape":
                            s += (
                                f"{lines()}"
                                f"{key: <{max_key_len}}  Shape {value.shape}\n"
                            )
                        else:
                            s += f"{lines()}{key}:\n"
                            s += np.array2string(
                                value,
                                precision=2,
                                separator=" ",
                                suppress_small=True,
                            )
                            s += "\n"

                        if value.attrs:
                            d_max_key_len = max(len(dk) for dk in value.attrs)
                        for d_key, d_value in value.attrs.items():
                            d_value = str(d_value).replace("\n", " ").strip()
                            if len(d_value) > max_width - d_max_key_len - len(
                                lines(True)
                            ):
                                if not full_attrs:
                                    d_value = d_value[
                                        : max_width - d_max_key_len - len(lines(True))
                                    ]
                                    d_value = d_value[:-3] + "..."
                            s += f"{lines(True)}{d_key: <{d_max_key_len}}  {d_value}\n"
                    if isinstance(value, h5py.Group):
                        s += f"{lines()}{key}\n" + recursive_stringify(
                            obj=value,
                            output_format=output_format,
                            parent_indices=parent_indices + [index],
                            parent_siblings=parent_siblings + [num_children],
                        )
                    index += 1
            return s

        return recursive_stringify(self, output_format)

    def __str__(self):
        return self.to_string()

    def plot(
        self,
        ax=None,
        lw_distances=False,
        lw=2,
        ms=32,
        figsize=None,
        step_size=4,
        c=None,
        cmap="Set1",
        skip_timesteps=0,
        max_timesteps=None,
        show=False,
        legend=True,
    ):
        """Plot the file using matplotlib.pyplot

        The tracks in the file are plotted using matplotlib.plot().

        Args:
            ax (matplotlib.axes, optional): An axes object to plot in. If None is given, a new figure is created.
            lw_distances (bool, optional): Flag to show the distances between individuals through line width. Ignored if the file contains fewer than two entities.
            lw (float, optional): Line width of the tracks when lw_distances is False.
            ms (float, optional): Marker size (scatter `s`) of the end-position markers.
            figsize (Tuple[int], optional): Size of a newly created figure. Defaults to (8, 8).
            step_size (int, optional): When using lw_distances, the track is split into sections of step_size timesteps which share a common line width.
            c (Array[color_representation], optional): An array of colors, one per entity. Each item has to be matplotlib.colors.is_color_like(item).
            cmap (str or matplotlib.colors.Colormap, optional): The colormap used when c is None.
            skip_timesteps (int, optional): Skip timesteps at the beginning of the file.
            max_timesteps (int, optional): Number of timesteps to plot after skip_timesteps; later timesteps are cut off.
            show (bool, optional): Show the created plot.
            legend (bool, optional): Whether to draw a legend with one entry per fish and the end marker.
        Returns:
            matplotlib.axes: The axes object with the plot.
        """

        if max_timesteps is not None:
            poses = self.entity_positions[
                :, skip_timesteps : max_timesteps + skip_timesteps
            ]
        else:
            poses = self.entity_positions[:, skip_timesteps:]

        if lw_distances and poses.shape[0] < 2:
            lw_distances = False

        if lw_distances:
            poses_diff = np.diff(poses, axis=0)  # Axis 0 is fish
            distances = np.linalg.norm(poses_diff, axis=2)

            min_distances = np.min(distances, axis=0)

            # Magic numbers found by trial and error: distances of 9 cm or more are drawn with
            # the minimum line width (0.4), fish at the same position with max_lw (4).
            max_distance = 10
            max_lw = 4
            line_width = (
                np.clip(max_distance - min_distances, 1, max_distance)
                * max_lw
                / max_distance
            )
        else:
            step_size = poses.shape[1]

        # matplotlib.cm.get_cmap was removed in matplotlib 3.9; pyplot.get_cmap is still available
        cmap = plt.get_cmap(cmap)

        x_world, y_world = self.world_size
        if figsize is None:
            figsize = (8, 8)

        if ax is None:
            fig, ax = plt.subplots(1, 1, figsize=figsize)

        if self.path is not None:
            ax.set_title("\n".join(wrap(Path(self.path).name, width=35)))

        ax.set_xlim(-x_world / 2, x_world / 2)
        ax.set_ylim(-y_world / 2, y_world / 2)
        for fish_id in range(poses.shape[0]):
            if c is None:
                this_c = cmap(fish_id)
            else:
                # c may be a list or any other indexable sequence of colors
                this_c = c[fish_id]

            timesteps = poses.shape[1] - 1
            for t in range(0, timesteps, step_size):
                if lw_distances:
                    lw = np.mean(line_width[t : t + step_size + 1])

                ax.plot(
                    poses[fish_id, t : t + step_size + 1, 0],
                    poses[fish_id, t : t + step_size + 1, 1],
                    c=this_c,
                    lw=lw,
                )
            # Plot a dummy segment outside the axes limits so that each fish gets a legend entry
            ax.plot([550, 600], [550, 600], lw=5, c=this_c, label=fish_id)

        # ax.scatter(
        #     [poses[:, skip_timesteps, 0]],
        #     [poses[:, skip_timesteps, 1]],
        #     marker="h",
        #     c="black",
        #     s=ms,
        #     label="Start",
        #     zorder=5,
        # )
        ax.scatter(
            [poses[:, -1, 0]],
            [poses[:, -1, 1]],
            marker="x",
            c="black",
            s=ms,
            label="End",
            zorder=5,
        )
        if legend and isinstance(legend, str):
            ax.legend(legend)
        elif legend:
            ax.legend()
        ax.set_xlabel("x [cm]")
        ax.set_ylabel("y [cm]")

        if show:
            plt.show()

        return ax

    def render(self, video_path=None, **kwargs):
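        """Render a video of the file.

        As there are render functions in gym_guppy and robofish.trackviewer, this function is a temporary addition.
        The goal should be to bring together the rendering tools.

        Args:
            video_path: If given, the animation is saved to this path using ffmpeg; otherwise it is shown with plt.show().
            **kwargs: Overrides for the entries of default_options (e.g. speedup, trail, fixed_view, dpi). Unknown keys are ignored.
        """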

        if video_path is not None:
            try:
                run(["ffmpeg"], capture_output=True)
            except Exception as e:
                raise Exception(
                    f"ffmpeg is required to store videos. Please install it.\n{e}"
                )

        def shape_vertices(scale=1) -> np.ndarray:
            base_shape = np.array(
                [
                    (+3.0, +0.0),
                    (+2.5, +1.0),
                    (+1.5, +1.5),
                    (-2.5, +1.0),
                    (-4.5, +0.0),
                    (-2.5, -1.0),
                    (+1.5, -1.5),
                    (+2.5, -1.0),
                ]
            )
            return base_shape * scale

        default_options = {
            "linewidth": 2,
            "speedup": 1,
            "trail": 100,
            "entity_scale": 0.2,
            "fixed_view": False,
            "view_size": 50,
            "margin": 15,
            "slow_view": 0.8,
            "slow_zoom": 0.95,
            "cut_frames_start": None,
            "cut_frames_end": None,
            "show_text": False,
            "render_goals": False,
            "render_targets": False,
            "dpi": 200,
            "figsize": 10,
        }

        options = {
            key: kwargs[key] if key in kwargs else default_options[key]
            for key in default_options.keys()
        }

        fig, ax = plt.subplots(figsize=(options["figsize"], options["figsize"]))
        ax.set_aspect("equal")
        ax.set_facecolor("gray")
        plt.tight_layout(pad=0.05)
        n_entities = len(self.entities)
        lines = [
            plt.plot([], [], lw=options["linewidth"], zorder=0)[0]
            for _ in range(n_entities)
        ]
        points = [
            plt.scatter([], [], marker="x", color="k"),
            plt.plot([], [], linestyle="dotted", alpha=0.5, color="k", zorder=0)[0],
        ]
        categories = [entity.attrs.get("category", None) for entity in self.entities]
        entity_polygons = [
            patches.Polygon(shape_vertices(options["entity_scale"]), facecolor=color)
            for color in [
                "gray" if category == "robot" else "k" for category in categories
            ]
        ]

        border_vertices = np.array(
            [
                np.array([-1, -1, 1, 1, -1]) * self.world_size[0] / 2,
                np.array([-1, 1, 1, -1, -1]) * self.world_size[1] / 2,
            ]
        )

        spacing = 10
        x = np.arange(
            -0.5 * self.world_size[0] + spacing, 0.5 * self.world_size[0], spacing
        )
        y = np.arange(
            -0.5 * self.world_size[1] + spacing, 0.5 * self.world_size[1], spacing
        )
        xv, yv = np.meshgrid(x, y)

        grid_points = plt.scatter(xv, yv, c="gray", s=1.5)

        # border = plt.plot(border_vertices[0], border_vertices[1], "k")
        border = patches.Polygon(border_vertices.T, facecolor="w", zorder=-1)

        def title(file_frame: int) -> str:
            """Search for datasets containing text for displaying it in the video"""
            output = []
            for e in self.entities:
                for key, val in e.items():
                    if val.dtype == object and type(val[0]) == bytes:
                        output.append(f"{e.name}.{key}='{val[file_frame].decode()}'")
            return ", ".join(output)

        def get_goal(file_frame: int) -> Optional[np.ndarray]:
            """Return current goal of robot, if robot exists and has a goal."""
            goal = None
            if "robot" in categories:
                robot = self.entities[categories.index("robot")]
                try:
                    goal = robot["goals"][file_frame]
                except KeyError:
                    pass
            if goal is not None and np.isnan(goal).any():
                goal = None
            return goal

        def get_target(file_frame: int) -> Tuple[List, List]:
            """Return line points from robot to target"""
            if "robot" in categories:
                robot = self.entities[categories.index("robot")]
                rpos = robot["positions"][file_frame]
                target = robot["targets"][file_frame]
                return [rpos[0], target[0]], [rpos[1], target[1]]
            return [], []

        def init():
            ax.set_xlim(-0.5 * self.world_size[0], 0.5 * self.world_size[0])
            ax.set_ylim(-0.5 * self.world_size[1], 0.5 * self.world_size[1])
            ax.set_xticks([])
            ax.set_xticks([], minor=True)
            ax.set_yticks([])
            ax.set_yticks([], minor=True)

            for e_poly in entity_polygons:
                ax.add_patch(e_poly)
            ax.add_patch(border)
            return lines + entity_polygons + [border] + points

        n_frames = self.entity_poses.shape[1]

        if options["cut_frames_end"] == 0 or options["cut_frames_end"] is None:
            options["cut_frames_end"] = n_frames
        if options["cut_frames_start"] is None:
            options["cut_frames_start"] = 0
        frame_range = (
            options["cut_frames_start"],
            min(n_frames, options["cut_frames_end"]),
        )

        n_frames = int((frame_range[1] - frame_range[0]) / options["speedup"])

        start_pose = self.entity_poses_rad[:, frame_range[0]]

        self.middle_of_swarm = np.mean(start_pose, axis=0)
        min_view = np.max((np.max(start_pose, axis=0) - np.min(start_pose, axis=0))[:2])
        self.view_size = np.max([options["view_size"], min_view + options["margin"]])

        # Progress bar only when a video is saved
        pbar = tqdm(range(n_frames)) if video_path is not None else None

        def update(frame):
            if pbar is not None:
                pbar.update(1)
                pbar.refresh()

            if frame < n_frames:
                entity_poses = self.entity_poses_rad

                file_frame = (frame * options["speedup"]) + frame_range[0]
                this_pose = entity_poses[:, file_frame]

                if not options["fixed_view"]:

                    # Find the maximal distance between the entities in x or y direction
                    min_view = np.max(
                        (np.max(this_pose, axis=0) - np.min(this_pose, axis=0))[:2]
                    )

                    new_view_size = np.max(
                        [options["view_size"], min_view + options["margin"]]
                    )

                    if not np.isnan(min_view).any() and not np.isnan(new_view_size):
                        self.middle_of_swarm = options[
                            "slow_view"
                        ] * self.middle_of_swarm + (1 - options["slow_view"]) * np.mean(
                            this_pose, axis=0
                        )

                        self.view_size = (
                            options["slow_zoom"] * self.view_size
                            + (1 - options["slow_zoom"]) * new_view_size
                        )

                    ax.set_xlim(
                        self.middle_of_swarm[0] - self.view_size / 2,
                        self.middle_of_swarm[0] + self.view_size / 2,
                    )
                    ax.set_ylim(
                        self.middle_of_swarm[1] - self.view_size / 2,
                        self.middle_of_swarm[1] + self.view_size / 2,
                    )
                if options["show_text"]:
                    ax.set_title(title(file_frame))

                if options["render_goals"]:
                    goal = get_goal(file_frame)
                    if goal is not None:
                        points[0].set_offsets(goal)

                if options["render_targets"]:
                    points[1].set_data(get_target(file_frame))

                poses_trails = entity_poses[
                    :, max(0, file_frame - options["trail"]) : file_frame
                ]
                for i_entity in range(n_entities):
                    lines[i_entity].set_data(
                        poses_trails[i_entity, :, 0], poses_trails[i_entity, :, 1]
                    )

                    current_pose = entity_poses[i_entity, file_frame]
                    t = mpl.transforms.Affine2D().translate(
                        current_pose[0], current_pose[1]
                    )
                    r = mpl.transforms.Affine2D().rotate(current_pose[2])
                    tra = r + t + ax.transData
                    entity_polygons[i_entity].set_transform(tra)
            else:
                raise Exception(
                    f"Frame {frame} is out of range, the animation has {n_frames} frames."
                )
            return lines + entity_polygons + [border] + points

        print(f"Preparing to render n_frames: {n_frames}")

        ani = animation.FuncAnimation(
            fig,
            update,
            frames=n_frames,
            init_func=init,
            blit=platform.system() != "Darwin",
            interval=1000 / self.frequency,
            repeat=False,
        )

        if video_path is not None:

            # if i % (n / 40) == 0:
            #     print(f"Saving frame {i} of {n} ({100*i/n:.1f}%)")

            video_path = Path(video_path)
            if video_path.exists():
                y = input(f"Video {str(video_path)} exists. Overwrite? (y/n) ")
                if y == "y":
                    video_path.unlink()

            if not video_path.exists():
                print(f"saving video to {video_path}")

                writervideo = animation.FFMpegWriter(fps=self.frequency)
                ani.save(video_path, writer=writervideo, dpi=options["dpi"])
            plt.close()
        else:
            plt.show()
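The `lw_distances` option of `plot()` above turns the distance between neighbouring fish into a line width. The standalone sketch below restates that arithmetic in plain Python so the mapping can be checked by hand. The helper names are hypothetical and not part of robofish.io; numpy is replaced by `math.dist` and `min`/`max`, and, like the original `np.diff(poses, axis=0)`, only fish that are adjacent in entity order are compared, not all pairs.

```python
# Standalone sketch of the lw_distances arithmetic in File.plot().
# The helper names are hypothetical; neither robofish.io nor numpy is used.
import math
from typing import List, Sequence, Tuple

MAX_DISTANCE = 10.0  # cm, the max_distance constant in File.plot()
MAX_LW = 4.0  # the max_lw constant in File.plot()

Track = Sequence[Tuple[float, float]]


def distance_to_line_width(distance_cm: float) -> float:
    """Same result as np.clip(max_distance - d, 1, max_distance) * max_lw / max_distance."""
    clipped = min(max(MAX_DISTANCE - distance_cm, 1.0), MAX_DISTANCE)
    return clipped * MAX_LW / MAX_DISTANCE


def min_neighbour_distances(tracks: Sequence[Track]) -> List[float]:
    """Per timestep, the smallest distance between fish i and fish i + 1.

    Only fish that are adjacent in entity order are compared, as in
    File.plot(). Requires at least two tracks of equal length.
    """
    n_steps = len(tracks[0])
    return [
        min(math.dist(tracks[i][t], tracks[i + 1][t]) for i in range(len(tracks) - 1))
        for t in range(n_steps)
    ]


# Two fish over three timesteps: fish 1 moves away from fish 0.
tracks = [
    [(0.0, 0.0), (0.0, 0.0), (0.0, 0.0)],
    [(0.0, 3.0), (0.0, 8.0), (0.0, 20.0)],
]
distances = min_neighbour_distances(tracks)
widths = [distance_to_line_width(d) for d in distances]
print(distances)
print(widths)
```

With these constants, every distance of 9 cm or more gets the thinnest line (0.4), while two fish at the same position get `max_lw` (4.0); the example track yields widths of about 2.8, 0.8 and 0.4.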

Classes

class File (path: Union[str, pathlib.Path] = None, mode: str = 'r', *, world_size_cm: List[int] = None, validate: bool = False, validate_when_saving: bool = True, strict_validate: bool = False, format_version: List[int] = array([1, 0], dtype=int32), format_url: str = 'https://git.imp.fu-berlin.de/bioroboticslab/robofish/track_format/-/releases/1.0', sampling_name: str = None, frequency_hz: int = None, monotonic_time_points_us: Iterable = None, calendar_time_points: Iterable = None, open_copy: bool = False, validate_poses_hash: bool = True)

Represents a RoboFish Track Format file, which should be used to store tracking data of individual animals or swarms.

Files can be opened (with optional creation), modified in place, and have copies of them saved.

Create a new RoboFish Track Format object.

When called with a path, it is loaded, otherwise a new temporary file is created. File contents can be validated against the track format specification.
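Whether the constructor writes the root attributes (`world_size_cm`, `format_version`, `format_url`) and creates the `entities` and `samplings` groups depends only on `path`, `mode` and `open_copy`. The following sketch restates that decision as a pure function. `needs_initialization` is a hypothetical name, and raising `ValueError` for an unknown mode is this sketch's choice, whereas the constructor uses an `assert`.

```python
# Hypothetical helper restating how File.__init__ decides whether to write the
# root attributes and groups; it does not open any HDF5 file.
import tempfile
from pathlib import Path
from typing import Optional, Union

VALID_MODES = ("r", "r+", "w", "x", "a")


def needs_initialization(
    path: Optional[Union[str, Path]], mode: str = "r", open_copy: bool = False
) -> bool:
    """Return True if a new File would be initialized with world_size_cm etc."""
    if open_copy:
        # A temporary copy of an existing file is opened as it is.
        return False
    if path is None:
        # A new temporary file is created, so it always needs initializing.
        return True
    if mode not in VALID_MODES:
        # File.__init__ uses an assert here; ValueError is this sketch's choice.
        raise ValueError(f"Unknown mode {mode}.")
    # Initialize if the file does not exist yet or is truncated by mode 'w'.
    return not Path(path).exists() or mode == "w"


with tempfile.TemporaryDirectory() as tmp:
    target = Path(tmp) / "test.hdf5"
    print(needs_initialization(target, "x"))  # file does not exist yet
    target.touch()
    print(needs_initialization(target, "r+"))  # existing file is kept
    print(needs_initialization(target, "w"))  # existing file is truncated
```

For 'r' and 'r+' a missing file makes h5py raise an OSError as soon as the file is opened, so initialization is never reached in those cases even though the function returns True.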

Parameters

path : str or Path, optional
Location of file to be opened. If not provided, mode is ignored.
mode : str, default='r'
'r' Readonly, file must exist
'r+' Read/write, file must exist
'w' Create file, truncate if exists
'x' Create file, fail if exists
'a' Read/write if exists, create otherwise
world_size_cm : [int, int] , optional
Side lengths [x, y] of the world in cm. A rectangular world shape is assumed.
validate : bool, default=False
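Should the track be validated? This is normally switched off for performance reasons.
validate_when_saving : bool, default=True
If the file was opened in a writable mode, it is validated when the with block that opened it exits without an exception.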
strict_validate : bool, default=False
If the file should be strictly validated against the track format specification when loaded from a path. TODO: Should this validate against the version specified in format_version or just against the most recent version?
format_version : [int, int], default=[1,0]
version [major, minor] of track format specification
format_url : str, default="https://git.imp.fu-berlin.de/bioroboticslab/robofish/track_format/-/releases/1.0"
Location of the track format specification. Should match format_version.
sampling_name : str, optional
How to specify your sampling:

1. (optional) Provide a text description of your sampling in sampling_name.

2.a (mandatory if you have a constant sampling frequency) Specify frequency_hz with your sampling frequency in Hz.

2.b (mandatory if you do NOT have a constant sampling frequency) Specify monotonic_time_points_us with a list[1] of time points in microseconds on a monotonic clock, one for each sample in your dataset.

3. (optional) Specify calendar_time_points with a list[2] of time points in the ISO 8601 extended format with microsecond precision and a time zone designator[3], one for each sample in your dataset.

[1] any Iterable of int
[2] any Iterable of str
[3] example: "2020-11-18T13:21:34.117015+01:00"

frequency_hz : int, optional
refer to explanation of sampling_name
monotonic_time_points_us : Iterable of int, optional
refer to explanation of sampling_name
calendar_time_points : Iterable of str, optional
refer to explanation of sampling_name
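open_copy : bool, optional
A temporary copy of the file will be opened instead of the file itself.
validate_poses_hash : bool, default=True
When an existing file is opened, the stored poses_hash of each entity is compared with a freshly calculated one, and a warning is issued if they differ or if no pre-calculated data is stored.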
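The sampling rules listed under `sampling_name` can be illustrated with plain Python. This is a sketch under assumptions: the helper names are hypothetical, the samples are assumed to be evenly spaced at 25 Hz, and the resulting lists are meant to be passed as `monotonic_time_points_us` and `calendar_time_points`. `infer_frequency_hz` mirrors the check that `create_sampling` applies when `frequency_hz` is omitted, with an extra guard for fewer than two samples.

```python
# Hypothetical helpers for building the per-sample time points described under
# sampling_name; they only produce plain lists and do not import robofish.io.
import datetime
from typing import List, Optional


def make_monotonic_time_points_us(
    n_samples: int, frequency_hz: float, start_us: int = 0
) -> List[int]:
    """Evenly spaced time points in microseconds on a monotonic clock (item 2.b)."""
    step_us = round(1e6 / frequency_hz)
    return [start_us + i * step_us for i in range(n_samples)]


def make_calendar_time_points(
    start: datetime.datetime, monotonic_us: List[int]
) -> List[str]:
    """ISO 8601 strings with microsecond precision and time zone designator (item 3)."""
    if start.tzinfo is None:
        raise ValueError("Missing timezone for calendar point.")
    return [
        (start + datetime.timedelta(microseconds=t - monotonic_us[0])).isoformat(
            timespec="microseconds"
        )
        for t in monotonic_us
    ]


def infer_frequency_hz(monotonic_us: List[int]) -> Optional[float]:
    """Equal, positive steps give a frequency, as in create_sampling."""
    steps = [b - a for a, b in zip(monotonic_us, monotonic_us[1:])]
    if steps and all(s == steps[0] for s in steps) and steps[0] > 0:
        return 1e6 / steps[0]
    return None  # not evenly spaced, or fewer than two samples


utc_plus_1 = datetime.timezone(datetime.timedelta(hours=1))
start = datetime.datetime(2020, 11, 18, 13, 21, 34, 117015, tzinfo=utc_plus_1)
mono = make_monotonic_time_points_us(3, 25.0)
cal = make_calendar_time_points(start, mono)
print(mono)
print(cal[0])
print(infer_frequency_hz(mono))
```

`create_sampling` also accepts timezone-aware `datetime` objects for `calendar_time_points` and formats them with `isoformat(timespec="microseconds")`, so building the strings yourself is optional.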
class File(h5py.File):
    """Represents a RoboFish Track Format file, which should be used to store tracking data of individual animals or swarms.

    Files can be opened (with optional creation), modified in place, and have copies of them saved.
    """

    _temp_dir = None

    def __init__(
        self,
        path: Union[str, Path] = None,
        mode: str = "r",
        *,  # PEP 3102
        world_size_cm: List[int] = None,
        validate: bool = False,
        validate_when_saving: bool = True,
        strict_validate: bool = False,
        format_version: List[int] = default_format_version,
        format_url: str = default_format_url,
        sampling_name: str = None,
        frequency_hz: int = None,
        monotonic_time_points_us: Iterable = None,
        calendar_time_points: Iterable = None,
        open_copy: bool = False,
        validate_poses_hash: bool = True,
    ):
        """Create a new RoboFish Track Format object.

        When called with a path, it is loaded, otherwise a new temporary
        file is created. File contents can be validated against the
        track format specification.

        Parameters
        ----------
        path : str or Path, optional
            Location of file to be opened. If not provided, mode is ignored.
        mode : str, default='r'
            'r'        Readonly, file must exist
            'r+'       Read/write, file must exist
            'w'        Create file, truncate if exists
            'x'        Create file, fail if exists
            'a'        Read/write if exists, create otherwise
        world_size_cm : [int, int] , optional
            side lengths [x, y] of the world in cm.
            rectangular world shape is assumed.
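        validate : bool, default=False
            Should the track be validated? This is normally switched off for performance reasons.
        validate_when_saving : bool, default=True
            If the file was opened in a writable mode, validate it when the
            `with` block that opened it exits without an exception.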
        strict_validate : bool, default=False
            if the file should be strictly validated against the track
            format specification, when loaded from a path.
            TODO: Should this validate against the version specified in
            format_version or just against the most recent version?
        format_version : [int, int], default=[1,0]
            version [major, minor] of track format specification
        format_url : str, default="https://git.imp.fu-berlin.de/bioroboticslab/robofish/track_format/-/releases/1.0"
            location of track format specification.
            should fit `format_version`.
        sampling_name : str, optional
            How to specify your sampling:

            1. (optional)
                provide text description of your sampling in `sampling_name`

            2.a (mandatory, if you have a constant sampling frequency)
                specify `frequency_hz` with your sampling frequency in Hz

            2.b (mandatory, if you do NOT have a constant sampling frequency)
                specify `monotonic_time_points_us` with a list[1] of time
                points in microseconds on a monotonic clock, one for each
                sample in your dataset.

            3.  (optional)
                specify `calendar_time_points` with a list[2] of time points
                in the ISO 8601 extended format with microsecond precision
                and time zone designator[3], one for each sample in your
                dataset.

            [1] any Iterable of int
            [2] any Iterable of str
            [3] example: "2020-11-18T13:21:34.117015+01:00"

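        frequency_hz : int, optional
            refer to explanation of `sampling_name`
        monotonic_time_points_us : Iterable of int, optional
            refer to explanation of `sampling_name`
        calendar_time_points : Iterable of str, optional
            refer to explanation of `sampling_name`
        open_copy : bool, optional
            a temporary copy of the file will be opened instead of the file itself.
        validate_poses_hash : bool, default=True
            when an existing file is opened, compare the stored `poses_hash` of
            each entity with a freshly calculated one and warn if they differ
            or if no pre-calculated data is stored.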
        """

        self.path = path
        self.validate_when_saving = validate_when_saving

        if open_copy:
            assert (
                path is not None
            ), "A path has to be given if a copy should be opened."

            temp_file = self.temp_dir / str(uuid.uuid4())
            logging.info(
                f"Copying file to temporary file and opening it:\n{path} -> {temp_file}"
            )

            shutil.copyfile(path, temp_file)
            super().__init__(
                temp_file,
                mode="r+",
                driver="core",
                backing_store=True,
                libver=("earliest", "v110"),
            )
            initialize = False

        elif path is None:
            temp_file = self.temp_dir / str(uuid.uuid4())
            logging.info(f"Opening New temporary file {temp_file}")
            super().__init__(
                temp_file,
                mode="x",
                driver="core",
                backing_store=True,
                libver=("earliest", "v110"),
            )
            initialize = True
        else:
            # mode
            # r        Readonly, file must exist (default)
            # r+       Read/write, file must exist
            # w        Create file, truncate if exists
            # x        Create file, fail if exists
            # a        Read/write if exists, create otherwise
            logging.info(f"Opening File {path}")

            assert mode in ["r", "r+", "w", "x", "a"], f"Unknown mode {mode}."

            # If the file does not exist or if it should be truncated with mode=w, initialize it.
            if Path(path).exists() and mode != "w":
                initialize = False
            else:
                initialize = True

            try:
                super().__init__(path, mode, libver=("earliest", "v110"))
            except OSError as e:
                raise OSError(f"Could not open file {path} with mode {mode}.\n{e}") from e

        if initialize:
            assert (
                world_size_cm is not None and format_version is not None
            ), "world_size_cm has to be given when a new file is created. If the file already exists and should be modified, open it with mode 'r+' instead."

            self.attrs["world_size_cm"] = np.array(world_size_cm, dtype=np.float32)
            self.attrs["format_version"] = np.array(format_version, dtype=np.int32)
            self.attrs["format_url"] = format_url

            self.create_group("entities")
            self.create_group("samplings")

            if frequency_hz is not None or monotonic_time_points_us is not None:
                self.create_sampling(
                    name=sampling_name,
                    frequency_hz=frequency_hz,
                    monotonic_time_points_us=monotonic_time_points_us,
                    calendar_time_points=calendar_time_points,
                    default=True,
                )
        else:
            # A quick validation to find h5py files which are not robofish.io files
            if any([a not in self.attrs for a in ["world_size_cm", "format_version"]]):
                msg = f"The opened file {self.path} does not include world_size_cm or format_version. It seems that the file is not a robofish.io.File."
                if strict_validate:
                    raise KeyError(msg)
                else:
                    warnings.warn(msg)
                return

            # Validate that the stored poses hash still fits.
            if validate_poses_hash:
                for entity in self.entities:
                    if "poses_hash" in entity.attrs:
                        if entity.attrs["poses_hash"] != entity.poses_hash:
                            warnings.warn(
                                f"The stored hash is not identical with the newly calculated hash. In entity {entity.name} in {self.path}. f.entity_actions_speeds_turns and f.entity_orientation_rad will return wrong results.\n"
                                f"stored: {entity.attrs['poses_hash']}, calculated: {entity.poses_hash}"
                            )
                        assert (
                            "unfinished_calculations" not in entity.attrs
                        ), f"The calculated data of file {self.path} is incomplete and was probably aborted during calculation. Please recalculate with `robofish-io-update-calculated-data {self.path}`."

                    else:
                        warnings.warn(
                            f"The file did not include pre-calculated data, so the actions_speeds_turns "
                            f"and orientations_rad will have to be recalculated every time.\n"
                            f"Please use `robofish-io-update-calculated-data {self.path}` in the "
                            f"command line or\nopen and close the file with robofish.io.File(f, 'r+') "
                            f"in python.\nIf the data should be recalculated every time, open the file "
                            "with the bool option validate_poses_hash=False."
                        )
        if validate:
            self.validate(strict_validate)

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):

        # Check if the context was left under normal circumstances
        if (hasattr(self, "closed") and not self.closed) and (
            type,
            value,
            traceback,
        ) == (None, None, None):
            if (
                self.mode != "r" and self.validate_when_saving
            ):  # No need to validate read only files (performance).
                self.validate()

        super().__exit__(type, value, traceback)

    def close(self):
        if self.mode != "r":
            self.update_calculated_data()
        super().close()

    def save_as(
        self,
        path: Union[str, Path],
        strict_validate: bool = True,
        no_warning: bool = False,
    ):
        """Save a copy of the file

        Args:
            path: path to an io file as a string or path object. Missing parent directories are created.
            strict_validate: optional boolean, if the file should be strictly validated before saving. The default is True.
            no_warning: optional boolean, to remove the warning from the function.
        Returns:
            None. The file is closed in the process, so reload it from path to continue working with it.
        """

        self.update_calculated_data()
        self.validate(strict_validate=strict_validate)

        # Ensure all buffered data has been written to disk
        self.flush()

        path = Path(path).resolve()
        path.parent.mkdir(parents=True, exist_ok=True)

        filename = self.filename
        self.flush()
        self.close()

        self.closed = True

        shutil.copyfile(filename, path)
        if not no_warning:
            warnings.warn(
                "The 'save_as' function closes the file currently to be able to store it. If you want to use the file after saving it, please reload the file. The save_as function can be avoided by opening the correct file directly. If you want to get rid of this warning use 'save_as(..., no_warning=True)'"
            )
        return None

    def create_sampling(
        self,
        name: str = None,
        frequency_hz: float = None,
        monotonic_time_points_us: Iterable = None,
        calendar_time_points: Iterable = None,
        default: bool = False,
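    ):
        """Create a new sampling in the group 'samplings'.

        Args:
            name: optional name of the sampling. If None is given, "<frequency_hz> hz" is used
                if a frequency is given and that name is still free, otherwise the first free "sampling_<i>".
            frequency_hz: optional frequency in hz. If it is missing and the monotonic time points are
                evenly spaced, it is calculated from them.
            monotonic_time_points_us: optional array of monotonic time points in microseconds.
            calendar_time_points: optional array of timezone-aware datetime.datetime objects, or strings
                in the format of datetime.isoformat(timespec="microseconds").
            default: optional boolean, if the new sampling should become the default sampling of the file.
        Returns:
            The name of the created sampling.
        """
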
        # Find Name for sampling if none is given
        if name is None:
            if frequency_hz is not None:
                name = "%d hz" % frequency_hz

            i = 1
            while name is None or name in self["samplings"]:
                name = "sampling_%d" % i
                i += 1

        sampling = self["samplings"].create_group(name)

        if monotonic_time_points_us is not None:

            monotonic_time_points_us = np.array(
                monotonic_time_points_us, dtype=np.int64
            )
            sampling.create_dataset(
                "monotonic_time_points_us", data=monotonic_time_points_us
            )
            if frequency_hz is None:
                diff = np.diff(monotonic_time_points_us)
                # At least two time points with a constant, positive spacing are needed
                if len(diff) > 0 and np.all(diff == diff[0]) and diff[0] > 0:
                    frequency_hz = 1e6 / diff[0]
                    warnings.warn(
                        f"The frequency_hz of {frequency_hz:.2f} hz was calculated automatically by robofish.io. It is safer to pass it explicitly using frequency_hz.\nThis is important when using fish_models with the files."
                    )

                else:
                    warnings.warn(
                        "The frequency_hz could not be calculated automatically, because there are fewer than two time points or they are not evenly spaced. fish_models reads frequency_hz from the file, so please pass it explicitly."
                    )

        if frequency_hz is not None:
            sampling.attrs["frequency_hz"] = np.float32(frequency_hz)

        if calendar_time_points is not None:

            def format_calendar_time_point(p):
                if isinstance(p, datetime.datetime):
                    assert p.tzinfo is not None, "Missing timezone for calendar point."
                    return p.isoformat(timespec="microseconds")
                elif isinstance(p, str):
                    assert p == datetime.datetime.fromisoformat(p).isoformat(
                        timespec="microseconds"
                    ), f"Calendar point '{p}' is not an ISO 8601 string with microsecond precision."
                    return p
                else:
                    assert (
                        False
                    ), "Calendar points must be datetime.datetime instances or strings."

            calendar_time_points = [
                format_calendar_time_point(p) for p in calendar_time_points
            ]

            sampling.create_dataset(
                "calendar_time_points",
                data=calendar_time_points,
                dtype=h5py.string_dtype(encoding="utf-8"),
            )

        if default:
            self["samplings"].attrs["default"] = name
        return name

    @property
    def temp_dir(self):
        cla = type(self)
        if cla._temp_dir is None:
            cla._temp_dir = tempfile.TemporaryDirectory(prefix="robofish-io-")
        return Path(cla._temp_dir.name)

    @property
    def world_size(self):
        return self.attrs["world_size_cm"]

    @property
    def default_sampling(self):
        assert (
            "samplings" in self
        ), "The file does not have a group 'samplings', which is required."
        if "default" in self["samplings"].attrs:
            return self["samplings"].attrs["default"]
        return None

    @property
    def frequency(self):
        common_sampling = self.common_sampling()
        assert common_sampling is not None, "The sampling differs between entities."
        assert (
            "frequency_hz" in common_sampling.attrs
        ), "The common sampling has no frequency_hz"
        return common_sampling.attrs["frequency_hz"]

    def common_sampling(
        self, entities: Iterable["robofish.io.Entity"] = None
    ) -> h5py.Group:
        """Check if all entities have the same sampling.

        Args:
            entities: optional array of entities. If None is given, all entities of the file are checked.
        Returns:
            The h5py group of the common sampling. If there is no common sampling, None is returned.
        """
        if entities is None:
            entities = self.entities

        custom_sampling = None
        for entity in entities:
            if "sampling" in entity["positions"].attrs:
                this_sampling = entity["positions"].attrs["sampling"]
                if custom_sampling is None:
                    custom_sampling = this_sampling
                elif custom_sampling != this_sampling:
                    return None
        sampling = self.default_sampling if custom_sampling is None else custom_sampling
        return self["samplings"][sampling]

    def create_entity(
        self,
        category: str,
        poses: Iterable = None,
        name: str = None,
        positions: Iterable = None,
        orientations: Iterable = None,
        outlines: Iterable = None,
        sampling: str = None,
    ) -> str:
        """Creates a new single entity.

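        Args:
            category: the category of the entity. The canonical values are ['organism', 'robot', 'obstacle'].
            poses: optional two dimensional array, containing the poses of the entity, either as
                (x, y, orientation_x, orientation_y) or as (x, y, orientation_rad) per timestep.
            name: optional name of the entity. If no name is given, the category is used with an id (e.g. 'fish_1').
            positions: optional two dimensional array, containing the positions (x, y) of the entity.
            orientations: optional two dimensional array, containing the orientations of the entity.
            outlines: optional three dimensional array, containing the outlines of the entity.
            sampling: optional name of the sampling. If None is given, the default sampling of the file is used.
        Returns: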
            Name of the created entity
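        Example:
            A minimal sketch with made-up data, assuming numpy is imported as np:
            100 timesteps at the origin, with all orientation vectors pointing along +x.

                f = robofish.io.File(world_size_cm=[100, 100], frequency_hz=25.0)
                poses = np.zeros((100, 4))
                poses[:, 2] = 1  # orientation_x = 1, orientation_y = 0
                f.create_entity(category="organism", poses=poses, name="fish_1")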
        """

        if sampling is None and self.default_sampling is None:
            raise Exception(
                "No sampling was specified, neither when creating the file nor when creating the entity."
            )

        entity = robofish.io.Entity.create_entity(
            self["entities"],
            category,
            poses,
            name,
            positions,
            orientations,
            outlines,
            sampling,
        )

        return entity

    def create_multiple_entities(
        self,
        category: str,
        poses: Iterable,
        names: Iterable[str] = None,
        outlines=None,
        sampling=None,
    ) -> Iterable:
        """Creates multiple entities.

        Args:
            category: the common category of the entities. The canonical values are ['organism', 'robot', 'obstacle'].
            poses: three dimensional array with the shape (entity, timestep, 3 or 4), containing the poses of the entities.
            names: optional array of names of the entities. If no names are given, the category is used with an id (e.g. 'fish_1').
            outlines: optional array, containing the outlines of the entities. Either a three dimensional outline
                array shared by all entities, or a four dimensional array with one outline array per entity.
            sampling: optional name of the sampling. If None is given, the default sampling of the file is used.
        Returns:
            Array of names of the created entities
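        Example:
            A minimal sketch with made-up data. It assumes numpy is imported as np and
            that poses with three values per timestep are read as (x, y, orientation_rad).

                f = robofish.io.File(world_size_cm=[100, 100], frequency_hz=25.0)
                poses = np.zeros((2, 100, 3))  # (entity, timestep, 3)
                names = f.create_multiple_entities("organism", poses)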
        """

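        poses = np.asarray(poses)
        if outlines is not None:
            outlines = np.asarray(outlines)

        assert (
            poses.ndim == 3
        ), f"A 3 dimensional array was expected (entity, timestep, 3 or 4). There were {poses.ndim} dimensions in poses: {poses.shape}"
        assert (
            poses.shape[2] in [3, 4]
        ), f"The poses must have 3 or 4 values per timestep, not {poses.shape[2]}."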
        agents = poses.shape[0]
        entity_names = []

        for i in range(agents):
            e_name = None if names is None else names[i]
            e_outline = (
                outlines if outlines is None or outlines.ndim == 3 else outlines[i]
            )

            entity_names.append(
                self.create_entity(
                    category=category,
                    sampling=sampling,
                    poses=poses[i],
                    name=e_name,
                    outlines=e_outline,
                )
            )
        return entity_names

    def update_calculated_data(self, verbose=False):
        # A list (not a generator) makes sure every entity is updated before any() evaluates
        changed = any([e.update_calculated_data(verbose) for e in self.entities])
        return changed

    def clear_calculated_data(self, verbose=True):
        """Delete all calculated data from the file."""
        txt = ""
        for e in self.entities:
            deleted_attrs = []
            for a in ["poses_hash"]:
                if a in e.attrs:
                    del e.attrs[a]
                    deleted_attrs.append(a)
            deleted_datasets = []
            for g in ["calculated_actions_speeds_turns", "calculated_orientations_rad"]:
                if g in e:
                    del e[g]
                    deleted_datasets.append(g)
            txt += (
                f"Deleting from {e}. Attrs: [{', '.join(deleted_attrs)}] "
                f"Datasets: [{', '.join(deleted_datasets)}]\n"
            )
        if verbose:
            print(txt[:-1])

    @property
    def entity_names(self) -> Iterable[str]:
        """Getter for the names of all entities

        Returns:
            Array of all names.
        """
        return sorted(self["entities"].keys())

    @property
    def entities(self):
        return [
            robofish.io.Entity.from_h5py_group(self["entities"][name])
            for name in self.entity_names
        ]

    @property
    def entity_positions(self):
        return self.select_entity_property(None, entity_property=Entity.positions)

    @property
    def entity_orientations(self):
        return self.select_entity_property(None, entity_property=Entity.orientations)

    @property
    def entity_orientations_rad(self):
        return self.select_entity_property(
            None, entity_property=Entity.orientations_rad
        )

    @property
    def entity_poses(self):
        return self.select_entity_property(None, entity_property=Entity.poses)

    @property
    def entity_poses_rad(self):
        return self.select_entity_property(None, entity_property=Entity.poses_rad)

    @property
    @deprecation.deprecated(
        deprecated_in="0.2",
        removed_in="0.2.4",
        details="We found that our calculation of 'poses_calc_ori' is flawed. "
        "Please replace it with 'poses' and use the tracked orientation. "
        "If you see this message and you don't know what to do, update all packages and if nothing helps, contact Andi.\n"
        "Don't ignore this warning, it's a serious issue.",
    )
    def entity_poses_calc_ori(self):
        return self.select_entity_property(None, entity_property=Entity.poses_calc_ori)

    @property
    @deprecation.deprecated(
        deprecated_in="0.2",
        removed_in="0.2.4",
        details="We found that our calculation of 'poses_calc_ori_rad' is flawed. "
        "Please replace it with 'poses_rad' and use the tracked orientation. "
        "If you see this message and you don't know what to do, update all packages and if nothing helps, contact Andi.\n"
        "Don't ignore this warning, it's a serious issue.",
    )
    def entity_poses_calc_ori_rad(self):
        return self.select_entity_property(
            None, entity_property=Entity.poses_calc_ori_rad
        )

    @property
    @deprecation.deprecated(
        deprecated_in="0.2",
        removed_in="0.2.4",
        details="We found that our calculation of 'entity_speeds_turns' is flawed and replaced it "
        "with 'entity_actions_speeds_turns'. The difference in calculation is, that the tracked "
        "orientation is used now which gives the fish the ability to swim backwards. "
        "If you see this message and you don't know what to do, update all packages and if nothing helps, contact Andi.\n"
        "Don't ignore this warning, it's a serious issue.",
    )
    def entity_speeds_turns(self):
        return self.select_entity_property(None, entity_property=Entity.speed_turn)

    @property
    def entity_actions_speeds_turns(self):
        """Calculate the speed and turn from the recorded positions and orientations.

        The turn is calculated by the change of orientation between frames.
        The speed is calculated by the distance between the points, projected on the new orientation vector.
        The sideways change of position cannot be represented with this method.

        Returns:
            An array with shape (number_of_entities, number_of_positions - 1, 2), containing the speed in cm/frame and the turn in rad/frame.
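
        Example:
            A rough numpy sketch of the idea described above for a single entity,
            assuming poses with the columns (x, y, orientation_rad). Wrapping the turn
            into [-pi, pi) is an assumption of this sketch. It only illustrates the
            description and is not necessarily identical to Entity.actions_speeds_turns.

                import numpy as np

                poses = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, np.pi / 2]])
                turns = np.diff(poses[:, 2])  # change of orientation per frame
                turns = (turns + np.pi) % (2 * np.pi) - np.pi  # wrap into [-pi, pi)
                new_ori = np.stack([np.cos(poses[1:, 2]), np.sin(poses[1:, 2])], axis=1)
                steps = np.diff(poses[:, :2], axis=0)  # displacement per frame
                speeds = np.sum(steps * new_ori, axis=1)  # projection on the new orientation
                speeds_turns = np.stack([speeds, turns], axis=1)  # shape (frames - 1, 2)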
        """
        return self.select_entity_property(
            None, entity_property=Entity.actions_speeds_turns
        )

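    def select_entity_poses(self, *args, ori_rad=False, **kwargs):
        """Get the poses of selected entities, see select_entity_property for the arguments.

        Args:
            ori_rad: if True, the orientation is given in radians (Entity.poses_rad) instead of as an orientation vector (Entity.poses).
        """
        entity_property = Entity.poses_rad if ori_rad else Entity.poses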
        return self.select_entity_property(
            *args, entity_property=entity_property, **kwargs
        )

    def select_entity_property(
        self,
        predicate: types.LambdaType = None,
        entity_property: Union[property, str] = Entity.poses,
    ) -> Iterable:
        """Get a property of selected entities.

        Entities can be selected using a predicate function,
        and the property returned for them can be chosen.

        Args:
            predicate: optional function (e.g. a lambda) selecting entities
                (example: lambda e: e.category == "organism"). If None is given, all entities are selected.
            entity_property: a property of the Entity class (example: Entity.poses_rad) or a string with the name of the dataset.
        Returns:
            A three dimensional array of the property of all selected entities with the shape (entity, time, property_length).
            If an entity has a shorter length of the property, the output will be filled with nans.
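
        Example:
            A sketch, assuming f is an open robofish.io.File containing entities of
            the category "organism":

                fish_poses_rad = f.select_entity_property(
                    lambda e: e.category == "organism",
                    entity_property=robofish.io.Entity.poses_rad,
                )
                # Shape (entity, time, property_length), padded with NaN where needed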
        """

        entities = self.entities
        if predicate is not None:
            entities = [e for e in entities if predicate(e)]

        assert (
            self.common_sampling(entities) is not None
        ), "The selected entities do not share a common sampling."

        # Collect the property of each selected entity
        if isinstance(entity_property, str):
            properties = [entity[entity_property] for entity in entities]
        else:
            properties = [entity_property.__get__(entity) for entity in entities]

        max_timesteps = max([0] + [p.shape[0] for p in properties])

        # Initialize the output array with NaN, so that shorter properties are padded
        property_array = np.full(
            (len(entities), max_timesteps, properties[0].shape[1]), np.nan
        )

        # Fill the output array
        for i, prop in enumerate(properties):
            property_array[i][: prop.shape[0]] = prop
        return property_array

    def validate(self, strict_validate: bool = True) -> Tuple[bool, str]:
        """Validate the file to the specification.

        The function compares a given file to the robofish track format specification:
        https://git.imp.fu-berlin.de/bioroboticslab/robofish/track_format
        First all specified arrays are formatted to be numpy arrays with the specified
        datatype. Then all specified shapes are validated. Lastly calendar points
        are validated to be datetimes according to ISO8601.

        Args:
            strict_validate: Throw an exception instead of just returning False.
        Returns:
            A tuple of the validity and an error message.
        Raises:
            AssertionError: When the file is invalid and strict_validate is True
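
        Example:
            Check a file without raising an exception, assuming f is an open
            robofish.io.File:

                valid, error_message = f.validate(strict_validate=False)
                if not valid:
                    print(error_message)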
        """
        return robofish.io.validate(self, strict_validate)

    def to_string(
        self,
        output_format: str = "shape",
        max_width: int = 120,
        full_attrs: bool = False,
    ) -> str:
        """Format the file as a human-readable string.

        Args:
            output_format: ['shape', 'full'] show the shape, or the full content of datasets.
            max_width: set the width in characters after which attribute values get abbreviated.
            full_attrs: do not abbreviate attribute values if True.
        Returns:
            A human-readable string representing the file.
        """

        def recursive_stringify(
            obj: h5py.Group,
            output_format: str,
            parent_indices: List[int] = [],
            parent_siblings: List[int] = [],
        ) -> str:
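            """This function crawls recursively into hdf5 groups.

            Datasets and attributes are attached directly; for groups, the function is called recursively.

            Args:
                obj: a h5py group
                output_format: ['shape', 'full'] show the shape, or the full content of datasets
                parent_indices: for each ancestor level, the index of the entry on the path to obj
                parent_siblings: for each ancestor level, the number of entries on that level
                    (both are used to draw the tree lines)
            Returns:
                A string representation of the group
            """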

            def lines(dataset_attribute: bool = False) -> str:
                """Get box-drawing characters for the graph lines."""
                line = ""
                for pi, ps in zip(parent_indices, parent_siblings):
                    if pi < ps - 1:
                        line += "│ "
                    else:
                        line += "  "
                if dataset_attribute:
                    line += "  "
                line += "─ "
                junction_index = 2 * len(parent_indices) + dataset_attribute * 2 - 1
                last = "└"
                other = "├"
                if dataset_attribute:
                    j = (
                        last
                        if list(value.attrs.keys()).index(d_key) == len(value.attrs) - 1
                        else other
                    )
                else:
                    j = last if index == num_children - 1 else other
                line = line[: junction_index + 1] + j + line[junction_index + 1 :]
                if isinstance(value, h5py.Group) or (
                    isinstance(value, h5py.Dataset)
                    and not dataset_attribute
                    and value.attrs
                ):
                    line = line[:-1] + "┬─"
                else:
                    line = line[:-1] + "──"

                return line + " "

            s = ""
            max_key_len = 0
            num_children = 0
            if obj.attrs:
                max_key_len = max(len(key) for key in obj.attrs)
                num_children += len(obj.attrs)
            if hasattr(obj, "items"):
                max_key_len = max([len(key) for key in obj] + [max_key_len])
                num_children += len(obj)
            index = 0
            if obj.attrs:
                for key, value in obj.attrs.items():
                    if not full_attrs:
                        value = str(value).replace("\n", " ").strip()
                        if len(value) > max_width - max_key_len - len(lines()):
                            value = (
                                value[: max_width - max_key_len - len(lines()) - 3]
                                + "..."
                            )
                    s += f"{lines()}{key: <{max_key_len}}  {value}\n"
                    index += 1
            if hasattr(obj, "items"):
                for key, value in obj.items():
                    if isinstance(value, h5py.Dataset):
                        if output_format == "shape":
                            s += (
                                f"{lines()}"
                                f"{key: <{max_key_len}}  Shape {value.shape}\n"
                            )
                        else:
                            s += f"{lines()}{key}:\n"
                            s += np.array2string(
                                value,
                                precision=2,
                                separator=" ",
                                suppress_small=True,
                            )
                            s += "\n"

                        if value.attrs:
                            d_max_key_len = max(len(dk) for dk in value.attrs)
                        for d_key, d_value in value.attrs.items():
                            d_value = str(d_value).replace("\n", " ").strip()
                            if len(d_value) > max_width - d_max_key_len - len(
                                lines(True)
                            ):
                                if not full_attrs:
                                    d_value = d_value[
                                        : max_width - d_max_key_len - len(lines(True))
                                    ]
                                    d_value = d_value[:-3] + "..."
                            s += f"{lines(True)}{d_key: <{d_max_key_len}}  {d_value}\n"
                    if isinstance(value, h5py.Group):
                        s += f"{lines()}{key}\n" + recursive_stringify(
                            obj=value,
                            output_format=output_format,
                            parent_indices=parent_indices + [index],
                            parent_siblings=parent_siblings + [num_children],
                        )
                    index += 1
            return s

        return recursive_stringify(self, output_format)

    def __str__(self):
        return self.to_string()

    def plot(
        self,
        ax=None,
        lw_distances=False,
        lw=2,
        ms=32,
        figsize=None,
        step_size=4,
        c=None,
        cmap="Set1",
        skip_timesteps=0,
        max_timesteps=None,
        show=False,
        legend=True,
    ):
        """Plot the file using matplotlib.pyplot

        The tracks in the file are plotted using matplotlib's ax.plot().

        Args:
            ax (matplotlib.axes.Axes, optional): An axes object to plot in. If None is given, a new figure is created.
            lw_distances (bool, optional): Flag to show the distances between individuals through the line width. It is ignored for fewer than two entities.
            lw (float, optional): Line width of the tracks, if lw_distances is False.
            ms (float, optional): Marker size of the end points.
            figsize (Tuple[int], optional): Size of a newly created figure. The default is (8, 8).
            step_size (int, optional): When using lw_distances, the track is split into sections which have a common line width. This parameter defines the length of the sections.
            c (Array[color_representation], optional): An array of colors, one per entity. Each item has to be matplotlib.colors.is_color_like(item).
            cmap (matplotlib.colors.Colormap or str, optional): The colormap to use if c is not given.
            skip_timesteps (int, optional): Skip timesteps at the beginning of the file.
            max_timesteps (int, optional): Maximum number of timesteps to plot, counted after skip_timesteps.
            show (bool, optional): Show the created plot.
            legend (bool or str, optional): Whether to draw a legend. A string is passed on to ax.legend().
        Returns:
            matplotlib.axes: The axes object with the plot.
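
        Example:
            A sketch comparing both line width modes side by side; "recording.hdf5"
            is a placeholder path to an existing file:

                import matplotlib.pyplot as plt

                f = robofish.io.File("recording.hdf5")
                fig, axes = plt.subplots(1, 2, figsize=(10, 5))
                f.plot(ax=axes[0])
                f.plot(ax=axes[1], lw_distances=True)
                plt.show()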
        """

        if max_timesteps is not None:
            poses = self.entity_positions[
                :, skip_timesteps : max_timesteps + skip_timesteps
            ]
        else:
            poses = self.entity_positions[:, skip_timesteps:]

        if lw_distances and poses.shape[0] < 2:
            lw_distances = False

        if lw_distances:
            # Axis 0 is fish: distances between consecutively indexed entities
            poses_diff = np.diff(poses, axis=0)
            distances = np.linalg.norm(poses_diff, axis=2)

            min_distances = np.min(distances, axis=0)

            # Magic numbers found by trial and error. The line width shrinks linearly from
            # max_lw at distance 0 to its minimum of max_lw / max_distance (0.4) at 9 cm and above.
            max_distance = 10
            max_lw = 4
            line_width = (
                np.clip(max_distance - min_distances, 1, max_distance)
                * max_lw
                / max_distance
            )
        else:
            step_size = poses.shape[1]

        cmap = plt.get_cmap(cmap)

        x_world, y_world = self.world_size
        if figsize is None:
            figsize = (8, 8)

        if ax is None:
            fig, ax = plt.subplots(1, 1, figsize=figsize)

        if self.path is not None:
            ax.set_title("\n".join(wrap(Path(self.path).name, width=35)))

        ax.set_xlim(-x_world / 2, x_world / 2)
        ax.set_ylim(-y_world / 2, y_world / 2)
        for fish_id in range(poses.shape[0]):
            if c is None:
                this_c = cmap(fish_id)
            else:
                this_c = c[fish_id]

            timesteps = poses.shape[1] - 1
            for t in range(0, timesteps, step_size):
                if lw_distances:
                    lw = np.mean(line_width[t : t + step_size + 1])

                ax.plot(
                    poses[fish_id, t : t + step_size + 1, 0],
                    poses[fish_id, t : t + step_size + 1, 1],
                    c=this_c,
                    lw=lw,
                )
            # Plot a dummy line outside of the visible area to get one legend entry per entity
            ax.plot([550, 600], [550, 600], lw=5, c=this_c, label=fish_id)

        # ax.scatter(
        #     [poses[:, skip_timesteps, 0]],
        #     [poses[:, skip_timesteps, 1]],
        #     marker="h",
        #     c="black",
        #     s=ms,
        #     label="Start",
        #     zorder=5,
        # )
        ax.scatter(
            [poses[:, -1, 0]],
            [poses[:, -1, 1]],
            marker="x",
            c="black",
            s=ms,
            label="End",
            zorder=5,
        )
        if legend and isinstance(legend, str):
            ax.legend(legend)
        elif legend:
            ax.legend()
        ax.set_xlabel("x [cm]")
        ax.set_ylabel("y [cm]")

        if show:
            plt.show()

        return ax

    def render(self, video_path=None, **kwargs):
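        """Render a video of the file.

        As there are render functions in gym_guppy and robofish.trackviewer, this function is a temporary addition.
        The goal should be to bring together the rendering tools.

        Args:
            video_path: optional path to store the video. Storing videos requires ffmpeg.
            **kwargs: rendering options overriding the defaults in default_options below
                (e.g. linewidth, speedup, trail, fixed_view, cut_frames_start, cut_frames_end, dpi, figsize).
        """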

        if video_path is not None and shutil.which("ffmpeg") is None:
            raise Exception("ffmpeg is required to store videos. Please install it.")

        def shape_vertices(scale=1) -> np.ndarray:
            base_shape = np.array(
                [
                    (+3.0, +0.0),
                    (+2.5, +1.0),
                    (+1.5, +1.5),
                    (-2.5, +1.0),
                    (-4.5, +0.0),
                    (-2.5, -1.0),
                    (+1.5, -1.5),
                    (+2.5, -1.0),
                ]
            )
            return base_shape * scale

        default_options = {
            "linewidth": 2,
            "speedup": 1,
            "trail": 100,
            "entity_scale": 0.2,
            "fixed_view": False,
            "view_size": 50,
            "margin": 15,
            "slow_view": 0.8,
            "slow_zoom": 0.95,
            "cut_frames_start": None,
            "cut_frames_end": None,
            "show_text": False,
            "render_goals": False,
            "render_targets": False,
            "dpi": 200,
            "figsize": 10,
        }

        options = {
            key: kwargs[key] if key in kwargs else default_options[key]
            for key in default_options.keys()
        }

        fig, ax = plt.subplots(figsize=(options["figsize"], options["figsize"]))
        ax.set_aspect("equal")
        ax.set_facecolor("gray")
        plt.tight_layout(pad=0.05)
        n_entities = len(self.entities)
        lines = [
            plt.plot([], [], lw=options["linewidth"], zorder=0)[0]
            for _ in range(n_entities)
        ]
        points = [
            plt.scatter([], [], marker="x", color="k"),
            plt.plot([], [], linestyle="dotted", alpha=0.5, color="k", zorder=0)[0],
        ]
        categories = [entity.attrs.get("category", None) for entity in self.entities]
        entity_polygons = [
            patches.Polygon(shape_vertices(options["entity_scale"]), facecolor=color)
            for color in [
                "gray" if category == "robot" else "k" for category in categories
            ]
        ]

        border_vertices = np.array(
            [
                np.array([-1, -1, 1, 1, -1]) * self.world_size[0] / 2,
                np.array([-1, 1, 1, -1, -1]) * self.world_size[1] / 2,
            ]
        )

        spacing = 10
        x = np.arange(
            -0.5 * self.world_size[0] + spacing, 0.5 * self.world_size[0], spacing
        )
        y = np.arange(
            -0.5 * self.world_size[1] + spacing, 0.5 * self.world_size[1], spacing
        )
        xv, yv = np.meshgrid(x, y)

        grid_points = plt.scatter(xv, yv, c="gray", s=1.5)

        # border = plt.plot(border_vertices[0], border_vertices[1], "k")
        border = patches.Polygon(border_vertices.T, facecolor="w", zorder=-1)

        def title(file_frame: int) -> str:
            """Search for datasets containing text for displaying it in the video"""
            output = []
            for e in self.entities:
                for key, val in e.items():
                    if val.dtype == object and type(val[0]) == bytes:
                        output.append(f"{e.name}.{key}='{val[file_frame].decode()}'")
            return ", ".join(output)

        def get_goal(file_frame: int) -> Optional[np.ndarray]:
            """Return current goal of robot, if robot exists and has a goal."""
            goal = None
            if "robot" in categories:
                robot = self.entities[categories.index("robot")]
                try:
                    goal = robot["goals"][file_frame]
                except KeyError:
                    pass
            if goal is not None and np.isnan(goal).any():
                goal = None
            return goal

        def get_target(file_frame: int) -> Tuple[List, List]:
            """Return line points from robot to target"""
            if "robot" in categories:
                robot = self.entities[categories.index("robot")]
                rpos = robot["positions"][file_frame]
                target = robot["targets"][file_frame]
                return [rpos[0], target[0]], [rpos[1], target[1]]
            return [], []

        def init():
            ax.set_xlim(-0.5 * self.world_size[0], 0.5 * self.world_size[0])
            ax.set_ylim(-0.5 * self.world_size[1], 0.5 * self.world_size[1])
            ax.set_xticks([])
            ax.set_xticks([], minor=True)
            ax.set_yticks([])
            ax.set_yticks([], minor=True)

            for e_poly in entity_polygons:
                ax.add_patch(e_poly)
            ax.add_patch(border)
            return lines + entity_polygons + [border] + points

        n_frames = self.entity_poses.shape[1]

        if options["cut_frames_end"] == 0 or options["cut_frames_end"] is None:
            options["cut_frames_end"] = n_frames
        if options["cut_frames_start"] is None:
            options["cut_frames_start"] = 0
        frame_range = (
            options["cut_frames_start"],
            min(n_frames, options["cut_frames_end"]),
        )

        n_frames = int((frame_range[1] - frame_range[0]) / options["speedup"])

        start_pose = self.entity_poses_rad[:, frame_range[0]]

        self.middle_of_swarm = np.mean(start_pose, axis=0)
        min_view = np.max((np.max(start_pose, axis=0) - np.min(start_pose, axis=0))[:2])
        self.view_size = np.max([options["view_size"], min_view + options["margin"]])

        if video_path is not None:
            pbar = tqdm(range(n_frames))

        def update(frame):
            # pbar only exists when a video is stored; it is never in update's own locals()
            if video_path is not None:
                pbar.update(1)
                pbar.refresh()

            if frame < n_frames:
                entity_poses = self.entity_poses_rad

                file_frame = (frame * options["speedup"]) + frame_range[0]
                this_pose = entity_poses[:, file_frame]

                if not options["fixed_view"]:

                    # Find the maximal distance between the entities in x or y direction
                    min_view = np.max(
                        (np.max(this_pose, axis=0) - np.min(this_pose, axis=0))[:2]
                    )

                    new_view_size = np.max(
                        [options["view_size"], min_view + options["margin"]]
                    )

                    if not np.isnan(min_view) and not np.isnan(new_view_size):
                        self.middle_of_swarm = options[
                            "slow_view"
                        ] * self.middle_of_swarm + (1 - options["slow_view"]) * np.mean(
                            this_pose, axis=0
                        )

                        self.view_size = (
                            options["slow_zoom"] * self.view_size
                            + (1 - options["slow_zoom"]) * new_view_size
                        )

                    ax.set_xlim(
                        self.middle_of_swarm[0] - self.view_size / 2,
                        self.middle_of_swarm[0] + self.view_size / 2,
                    )
                    ax.set_ylim(
                        self.middle_of_swarm[1] - self.view_size / 2,
                        self.middle_of_swarm[1] + self.view_size / 2,
                    )
                if options["show_text"]:
                    ax.set_title(title(file_frame))

                if options["render_goals"]:
                    goal = get_goal(file_frame)
                    if goal is not None:
                        points[0].set_offsets(goal)

                if options["render_targets"]:
                    points[1].set_data(get_target(file_frame))

                poses_trails = entity_poses[
                    :, max(0, file_frame - options["trail"]) : file_frame
                ]
                for i_entity in range(n_entities):
                    lines[i_entity].set_data(
                        poses_trails[i_entity, :, 0], poses_trails[i_entity, :, 1]
                    )

                    current_pose = entity_poses[i_entity, file_frame]
                    t = mpl.transforms.Affine2D().translate(
                        current_pose[0], current_pose[1]
                    )
                    r = mpl.transforms.Affine2D().rotate(current_pose[2])
                    tra = r + t + ax.transData
                    entity_polygons[i_entity].set_transform(tra)
            else:
                raise Exception(
                    f"Frame {frame} exceeds the number of frames ({n_frames})"
                )
            return lines + entity_polygons + [border] + points

        print(f"Preparing to render n_frames: {n_frames}")

        ani = animation.FuncAnimation(
            fig,
            update,
            frames=n_frames,
            init_func=init,
            blit=platform.system() != "Darwin",
            interval=1000 / self.frequency,
            repeat=False,
        )

        if video_path is not None:

            # if i % (n / 40) == 0:
            #     print(f"Saving frame {i} of {n} ({100*i/n:.1f}%)")

            video_path = Path(video_path)
            if video_path.exists():
                y = input(f"Video {str(video_path)} exists. Overwrite? (y/n)")
                if y == "y":
                    video_path.unlink()

            if not video_path.exists():
                print(f"saving video to {video_path}")

                writervideo = animation.FFMpegWriter(fps=self.frequency)
                ani.save(video_path, writer=writervideo, dpi=options["dpi"])
            plt.close()
        else:
            plt.show()
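The camera logic in update() above is an exponential moving average: on each frame, the view centre and the view size move a fraction (1 - slow_view) and (1 - slow_zoom) towards the current swarm centre and spread. The following is a minimal standalone sketch with NumPy, assuming a pose array of shape (n_entities, 3) with (x, y, orientation_rad). The function name smooth_camera is hypothetical, and unlike the original it averages only the x and y columns and detects missing data with np.isnan.

```python
import numpy as np


def smooth_camera(middle, view_size, pose, min_view_size=50.0, margin=15.0,
                  slow_view=0.8, slow_zoom=0.95):
    """Hypothetical helper: one smoothing step of the render camera.

    middle: array (2,) with the current view centre (x, y).
    view_size: current edge length of the square view in cm.
    pose: array (n_entities, 3) with (x, y, orientation_rad) of the current frame.
    """
    pose = np.asarray(pose, dtype=float)
    # Largest extent of the swarm in x or y direction.
    spread = np.max(np.max(pose[:, :2], axis=0) - np.min(pose[:, :2], axis=0))
    if np.isnan(spread):
        # Missing tracking data: keep the previous camera.
        return middle, view_size
    target_size = max(min_view_size, spread + margin)
    # Exponential moving average towards the swarm centre and the target size.
    middle = slow_view * np.asarray(middle) + (1 - slow_view) * pose[:, :2].mean(axis=0)
    view_size = slow_zoom * view_size + (1 - slow_zoom) * target_size
    return middle, view_size


middle, size = np.zeros(2), 50.0
middle, size = smooth_camera(middle, size, [[0.0, 0.0, 0.0], [100.0, 0.0, 1.0]])
print(middle, size)  # the centre moves 20 % of the way towards (50, 0); the view grows slightly
```

Values of slow_view and slow_zoom close to 1 make the camera react slowly, which avoids jitter when individual entities move quickly.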

Ancestors

  • h5py._hl.files.File
  • h5py._hl.group.Group
  • h5py._hl.base.HLObject
  • h5py._hl.base.CommonStateObject
  • h5py._hl.base.MutableMappingHDF5
  • h5py._hl.base.MappingHDF5
  • collections.abc.MutableMapping
  • collections.abc.Mapping
  • collections.abc.Collection
  • collections.abc.Sized
  • collections.abc.Iterable
  • collections.abc.Container

Instance variables

var default_sampling
@property
def default_sampling(self):
    assert (
        "samplings" in self
    ), "The file does not have the required group 'samplings'."
    if "default" in self["samplings"].attrs:
        return self["samplings"].attrs["default"]
    return None
var entities
@property
def entities(self):
    return [
        robofish.io.Entity.from_h5py_group(self["entities"][name])
        for name in self.entity_names
    ]
var entity_actions_speeds_turns

Calculate speeds and turns from the recorded positions and orientations.

The turn is the change of orientation between consecutive frames. The speed is the displacement between consecutive positions, projected onto the new orientation vector. Sideways movement cannot be represented with this method.

Returns

An array with shape (number_of_entities, number_of_positions - 1, 2), where the last axis holds (speed in cm/frame, turn in rad/frame).

@property
def entity_actions_speeds_turns(self):
    """Calculate speeds and turns from the recorded positions and orientations.

    The turn is the change of orientation between consecutive frames.
    The speed is the displacement between consecutive positions, projected onto the new orientation vector.
    Sideways movement cannot be represented with this method.

    Returns:
        An array with shape (number_of_entities, number_of_positions - 1, 2), where the last axis holds (speed in cm/frame, turn in rad/frame).
    """
    return self.select_entity_property(
        None, entity_property=Entity.actions_speeds_turns
    )
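The description above can be sketched in NumPy. This is not the library's implementation (that lives in Entity.actions_speeds_turns); it is an illustration that assumes poses of shape (n_entities, n_frames, 3) with (x, y, orientation_rad). Wrapping turns into [-π, π) is an assumption about how jumps across ±π are handled.

```python
import numpy as np


def speeds_turns_sketch(poses_rad: np.ndarray) -> np.ndarray:
    """Hypothetical re-implementation for illustration only.

    poses_rad: array (n_entities, n_frames, 3) with (x, y, orientation_rad).
    Returns an array (n_entities, n_frames - 1, 2) with (speed in cm/frame, turn in rad/frame).
    """
    positions = poses_rad[..., :2]
    orientations = poses_rad[..., 2]
    # Turn: change of orientation between frames, wrapped to [-pi, pi).
    turn = np.diff(orientations, axis=1)
    turn = (turn + np.pi) % (2 * np.pi) - np.pi
    # Speed: displacement projected onto the orientation vector of the new frame.
    displacement = np.diff(positions, axis=1)
    new_ori = orientations[:, 1:]
    heading = np.stack([np.cos(new_ori), np.sin(new_ori)], axis=-1)
    speed = np.sum(displacement * heading, axis=-1)
    return np.stack([speed, turn], axis=-1)


# One fish: moves 1 cm forward, then 2 cm along x while turning to face +y.
poses = np.array([[[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [3.0, 0.0, np.pi / 2]]])
print(speeds_turns_sketch(poses))
```

In the second step the displacement is perpendicular to the new orientation, so the speed is (numerically) zero: this is the sideways movement that the method cannot represent. Moving against the orientation yields a negative speed, i.e. swimming backwards.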
var entity_names : Iterable[str]

Getter for the names of all entities.

Returns

Alphabetically sorted list of all entity names.

@property
def entity_names(self) -> Iterable[str]:
    """Getter for the names of all entities

    Returns:
        Alphabetically sorted list of all entity names.
    """
    return sorted(self["entities"].keys())
var entity_orientations
@property
def entity_orientations(self):
    return self.select_entity_property(None, entity_property=Entity.orientations)
var entity_orientations_rad
@property
def entity_orientations_rad(self):
    return self.select_entity_property(
        None, entity_property=Entity.orientations_rad
    )
var entity_poses
@property
def entity_poses(self):
    return self.select_entity_property(None, entity_property=Entity.poses)
var entity_poses_calc_ori

Deprecated since version: 0.2

This will be removed in 0.2.4. We found that our calculation of 'poses_calc_ori' is flawed. Please replace it with 'poses' and use the tracked orientation. If you see this message and you don't know what to do, update all packages, and if nothing helps, contact Andi.

Don't ignore this warning, it's a serious issue.

@property
@deprecation.deprecated(
    deprecated_in="0.2",
    removed_in="0.2.4",
    details="We found that our calculation of 'poses_calc_ori' is flawed. "
    "Please replace it with 'poses' and use the tracked orientation. "
    "If you see this message and you don't know what to do, update all packages and if nothing helps, contact Andi.\n"
    "Don't ignore this warning, it's a serious issue.",
)
def entity_poses_calc_ori(self):
    return self.select_entity_property(None, entity_property=Entity.poses_calc_ori)
var entity_poses_calc_ori_rad

Deprecated since version: 0.2

This will be removed in 0.2.4. We found that our calculation of 'poses_calc_ori_rad' is flawed. Please replace it with 'poses_rad' and use the tracked orientation. If you see this message and you don't know what to do, update all packages, and if nothing helps, contact Andi.

Don't ignore this warning, it's a serious issue.

@property
@deprecation.deprecated(
    deprecated_in="0.2",
    removed_in="0.2.4",
    details="We found that our calculation of 'poses_calc_ori_rad' is flawed. "
    "Please replace it with 'poses_rad' and use the tracked orientation. "
    "If you see this message and you don't know what to do, update all packages and if nothing helps, contact Andi.\n"
    "Don't ignore this warning, it's a serious issue.",
)
def entity_poses_calc_ori_rad(self):
    return self.select_entity_property(
        None, entity_property=Entity.poses_calc_ori_rad
    )
var entity_poses_rad
@property
def entity_poses_rad(self):
    return self.select_entity_property(None, entity_property=Entity.poses_rad)
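entity_poses and entity_poses_rad describe the same poses with two orientation encodings: a unit vector (orientation_x, orientation_y) and an angle in radians. The sketch below converts between them, assuming the angle is measured counter-clockwise from the x axis, so that (orientation_x, orientation_y) = (cos θ, sin θ). Both function names are hypothetical and not part of robofish.io.

```python
import numpy as np


def poses_to_poses_rad(poses):
    """Hypothetical helper: (..., 4) x, y, orientation_x, orientation_y -> (..., 3) x, y, orientation_rad."""
    poses = np.asarray(poses, dtype=float)
    angle = np.arctan2(poses[..., 3], poses[..., 2])  # result lies in [-pi, pi]
    return np.concatenate([poses[..., :2], angle[..., np.newaxis]], axis=-1)


def poses_rad_to_poses(poses_rad):
    """Hypothetical helper: (..., 3) x, y, orientation_rad -> (..., 4) with a unit orientation vector."""
    poses_rad = np.asarray(poses_rad, dtype=float)
    theta = poses_rad[..., 2:3]
    return np.concatenate([poses_rad[..., :2], np.cos(theta), np.sin(theta)], axis=-1)


# One entity, two frames: facing +y, then facing -x.
poses = np.array([[[1.0, 2.0, 0.0, 1.0], [1.5, 2.0, -1.0, 0.0]]])
print(poses_to_poses_rad(poses)[..., 2])  # angles pi/2 and pi
```

The vector form avoids the discontinuity at ±π, while the angle form is what the rendering code passes to Affine2D().rotate().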
var entity_positions
@property
def entity_positions(self):
    return self.select_entity_property(None, entity_property=Entity.positions)
var entity_speeds_turns

Deprecated since version: 0.2

This will be removed in 0.2.4. We found that our calculation of 'entity_speeds_turns' is flawed and replaced it with 'entity_actions_speeds_turns'. The difference is that the tracked orientation is now used, which gives the fish the ability to swim backwards. If you see this message and you don't know what to do, update all packages, and if nothing helps, contact Andi.

Don't ignore this warning, it's a serious issue.

@property
@deprecation.deprecated(
    deprecated_in="0.2",
    removed_in="0.2.4",
    details="We found that our calculation of 'entity_speeds_turns' is flawed and replaced it "
    "with 'entity_actions_speeds_turns'. The difference in calculation is that the tracked "
    "orientation is now used, which gives the fish the ability to swim backwards. "
    "If you see this message and you don't know what to do, update all packages and if nothing helps, contact Andi.\n"
    "Don't ignore this warning, it's a serious issue.",
)
def entity_speeds_turns(self):
    return self.select_entity_property(None, entity_property=Entity.speed_turn)
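Because the deprecated accessors only emit a warning, a call can slip through unnoticed. One way to follow the advice above, for example in a test suite, is to turn deprecation warnings into errors with the standard warnings module. The sketch uses a hypothetical stand-in function instead of robofish.io, and it assumes that the warnings raised by the deprecation package derive from DeprecationWarning, as its DeprecatedWarning class does.

```python
import warnings


def legacy_speeds_turns():
    """Hypothetical stand-in for a deprecated accessor such as File.entity_speeds_turns."""
    warnings.warn("entity_speeds_turns is deprecated", DeprecationWarning, stacklevel=2)
    return "old result"


def call_strictly(func, *args, **kwargs):
    """Call func, raising instead of warning if it uses deprecated functionality."""
    with warnings.catch_warnings():
        # Inside this block, any DeprecationWarning is raised as an exception.
        warnings.simplefilter("error", DeprecationWarning)
        return func(*args, **kwargs)


try:
    call_strictly(legacy_speeds_turns)
except DeprecationWarning as err:
    print(f"Deprecated call detected: {err}")
```

For a whole script, the same effect can be achieved with the interpreter flag python -W error::DeprecationWarning.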
var frequency
@property
def frequency(self):
    common_sampling = self.common_sampling()
    assert common_sampling is not None, "The sampling differs between entities."
    assert (
        "frequency_hz" in common_sampling.attrs
    ), "The common sampling has no frequency_hz"
    return common_sampling.attrs["frequency_hz"]
var temp_dir
@property
def temp_dir(self):
    cla = type(self)
    if cla._temp_dir is None:
        cla._temp_dir = tempfile.TemporaryDirectory(prefix="robofish-io-")
    return Path(cla._temp_dir.name)
var world_size
@property
def world_size(self):
    return self.attrs["world_size_cm"]

Methods

def clear_calculated_data(self, verbose=True)

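Delete all calculated data from the file: the attribute poses_hash and the datasets calculated_actions_speeds_turns and calculated_orientations_rad of every entity. If verbose is True, the deleted items are printed.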

def clear_calculated_data(self, verbose=True):
    """Delete all calculated data from the file."""
    txt = ""
    for e in self.entities:
        txt += f"Deleting from {e}. Attrs: ["
        for a in ["poses_hash"]:
            if a in e.attrs:
                del e.attrs[a]
                txt += f"{a}, "
        txt = txt[:-2] + "] Datasets: ["
        for g in ["calculated_actions_speeds_turns", "calculated_orientations_rad"]:
            if g in e:
                del e[g]
                txt += f"{g}, "
        txt = txt[:-2] + "]\n"
    if verbose:
        print(txt[:-1])
def close(self)

Close the file. All open objects become invalid. Unless the file was opened read-only, the calculated data is updated before closing.

def close(self):
    if self.mode != "r":
        self.update_calculated_data()
    super().close()
def common_sampling(self, entities: Iterable['robofish.io.Entity'] = None) -> h5py._hl.group.Group

Check if all entities have the same sampling.

Args

entities
optional array of entities. If None is given, all entities are checked.

Returns

The h5py group of the common sampling. If there is no common sampling, None will be returned.

def common_sampling(
    self, entities: Iterable["robofish.io.Entity"] = None
) -> h5py.Group:
    """Check if all entities have the same sampling.

    Args:
        entities: optional array of entities. If None is given, all entities are checked.
    Returns:
        The h5py group of the common sampling. If there is no common sampling, None will be returned.
    """
    custom_sampling = None
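    # Use the given entities; fall back to all entities of the file.
    for entity in (self.entities if entities is None else entities):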
        if "sampling" in entity["positions"].attrs:
            this_sampling = entity["positions"].attrs["sampling"]
            if custom_sampling is None:
                custom_sampling = this_sampling
            elif custom_sampling != this_sampling:
                return None
    sampling = self.default_sampling if custom_sampling is None else custom_sampling
    return self["samplings"][sampling]
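The selection rule in common_sampling can be sketched without h5py: only entities that carry their own sampling attribute are compared, and the file's default sampling is used when none of them has one. In this hypothetical helper, each entity's sampling is given as a string, or None when the entity has no sampling attribute. Following the loop above, an entity without the attribute does not conflict with one that has it.

```python
from typing import Iterable, Optional


def resolve_common_sampling(entity_samplings: Iterable[Optional[str]],
                            default: Optional[str]) -> Optional[str]:
    """Hypothetical helper mirroring the logic of File.common_sampling."""
    custom = None
    for sampling in entity_samplings:
        if sampling is None:
            continue  # entity has no 'sampling' attribute
        if custom is None:
            custom = sampling
        elif custom != sampling:
            return None  # two entities use different samplings
    return default if custom is None else custom


print(resolve_common_sampling([None, None], default="25 hz"))  # 25 hz
```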
def create_entity(self, category: str, poses: Iterable = None, name: str = None, positions: Iterable = None, orientations: Iterable = None, outlines: Iterable = None, sampling: str = None) -> str

Creates a new single entity.

Args

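category
the category of the entity. The canonical values are ['organism', 'robot', 'obstacle'].
poses
optional two-dimensional array containing the poses of the entity, either (x, y, orientation_x, orientation_y) or (x, y, orientation_rad).
name
optional name of the entity. If no name is given, the category is used with an id (e.g. 'fish_1').
positions
optional two-dimensional array containing the positions of the entity (x, y).
orientations
optional two-dimensional array containing the orientations of the entity.
outlines
optional three-dimensional array containing the outlines of the entity.
sampling
optional string reference to the sampling. If none is given, the default sampling of the file is used; if the file has no default sampling either, an exception is raised.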

Returns

Name of the created entity

def create_entity(
    self,
    category: str,
    poses: Iterable = None,
    name: str = None,
    positions: Iterable = None,
    orientations: Iterable = None,
    outlines: Iterable = None,
    sampling: str = None,
) -> str:
    """Creates a new single entity.

    Args:
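        category: the category of the entity. The canonical values are ['organism', 'robot', 'obstacle'].
        poses: optional two-dimensional array containing the poses of the entity, either (x, y, orientation_x, orientation_y) or (x, y, orientation_rad).
        name: optional name of the entity. If no name is given, the category is used with an id (e.g. 'fish_1').
        positions: optional two-dimensional array containing the positions of the entity (x, y).
        orientations: optional two-dimensional array containing the orientations of the entity.
        outlines: optional three-dimensional array containing the outlines of the entity.
        sampling: optional string reference to the sampling. If none is given, the default sampling of the file is used.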
    Returns:
        Name of the created entity
    """

    if sampling is None and self.default_sampling is None:
        raise Exception(
            "There was no sampling specified, when creating the file, nor when creating the entity."
        )

    entity = robofish.io.Entity.create_entity(
        self["entities"],
        category,
        poses,
        name,
        positions,
        orientations,
        outlines,
        sampling,
    )

    return entity
def create_multiple_entities(self, category: str, poses: Iterable, names: Iterable[str] = None, outlines=None, sampling=None) -> Iterable

Creates multiple entities.

Args

category
The common category for the entities. The canonical values are ['organism', 'robot', 'obstacle'].
poses
three-dimensional array containing the poses of the entities, with shape (entity, timestep, 3 or 4).
names
optional array of names of the entities. If no names are given, the category is used with an id (e.g. 'fish_1').
outlines
optional array containing the outlines of the entities: either a three-dimensional outline array shared by all entities, or a four-dimensional array with one outline array per entity.
sampling
The string reference to the sampling. If none is given, the default sampling of the file is used.

Returns

Array of names of the created entities

def create_multiple_entities(
    self,
    category: str,
    poses: Iterable,
    names: Iterable[str] = None,
    outlines=None,
    sampling=None,
) -> Iterable:
    """Creates multiple entities.

    Args:
        category: The common category for the entities. The canonical values are ['organism', 'robot', 'obstacle'].
        poses: three-dimensional array containing the poses of the entities, with shape (entity, timestep, 3 or 4).
        names: optional array of names of the entities. If no names are given, the category is used with an id (e.g. 'fish_1').
        outlines: optional array containing the outlines of the entities: either a three-dimensional outline array shared by all entities, or a four-dimensional array with one outline array per entity.
        sampling: The string reference to the sampling. If none is given, the default sampling of the file is used.
    Returns:
        Array of names of the created entities
    """

    assert (
        poses.ndim == 3
    ), f"A 3 dimensional array was expected (entity, timestep, 3 or 4). There were {poses.ndim} dimensions in poses: {poses.shape}"
    assert poses.shape[2] in [3, 4]
    agents = poses.shape[0]
    entity_names = []

    for i in range(agents):
        e_name = None if names is None else names[i]
        e_outline = (
            outlines if outlines is None or outlines.ndim == 3 else outlines[i]
        )

        entity_names.append(
            self.create_entity(
                category=category,
                sampling=sampling,
                poses=poses[i],
                name=e_name,
                outlines=e_outline,
            )
        )
    return entity_names
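create_multiple_entities expects a single three-dimensional array for all entities. The sketch below builds such an array with NumPy for entities on concentric circles and mirrors the checks and per-entity slicing of the loop above. The helper names and the circle parameters are illustrative assumptions. The resulting array could then be passed as poses, for example f.create_multiple_entities("organism", poses), on a File that has a default sampling.

```python
import numpy as np


def circle_poses(n_entities, n_timesteps, radius=20.0):
    """Illustrative poses (entity, timestep, 3) with (x, y, orientation_rad) on concentric circles."""
    t = np.linspace(0, 2 * np.pi, n_timesteps)
    poses = np.empty((n_entities, n_timesteps, 3))
    for i in range(n_entities):
        r = radius * (i + 1)
        poses[i, :, 0] = r * np.cos(t)
        poses[i, :, 1] = r * np.sin(t)
        poses[i, :, 2] = t + np.pi / 2  # heading tangential to the circle (counter-clockwise)
    return poses


def split_per_entity(poses, outlines=None):
    """Hypothetical helper mirroring the checks and slicing in create_multiple_entities."""
    assert poses.ndim == 3, f"expected (entity, timestep, 3 or 4), got shape {poses.shape}"
    assert poses.shape[2] in (3, 4)
    result = []
    for i in range(poses.shape[0]):
        # A 3-D outline array is shared by all entities; a 4-D array holds one per entity.
        e_outline = outlines if outlines is None or outlines.ndim == 3 else outlines[i]
        result.append((poses[i], e_outline))
    return result


poses = circle_poses(n_entities=2, n_timesteps=100)
print(poses.shape)  # (2, 100, 3)
```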
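def create_sampling(self, name: str = None, frequency_hz: int = None, monotonic_time_points_us: Iterable = None, calendar_time_points: Iterable = None, default: bool = False)

Create a new sampling in the group 'samplings' and return its name.

If no name is given, "<frequency_hz> hz" is used when a frequency is given; if that name is already taken or no frequency is given, the first free name of the form "sampling_<i>" is used. If monotonic_time_points_us is given without frequency_hz, the frequency is inferred from evenly spaced time points, and a warning is issued either way. Calendar time points must be timezone-aware datetime.datetime objects or strings in the format produced by isoformat(timespec="microseconds"). With default=True, the new sampling becomes the default sampling of the file.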
def create_sampling(
    self,
    name: str = None,
    frequency_hz: int = None,
    monotonic_time_points_us: Iterable = None,
    calendar_time_points: Iterable = None,
    default: bool = False,
):

    # Find Name for sampling if none is given
    if name is None:
        if frequency_hz is not None:
            name = "%d hz" % frequency_hz

        i = 1
        while name is None or name in self["samplings"]:
            name = "sampling_%d" % i
            i += 1

    sampling = self["samplings"].create_group(name)

    if monotonic_time_points_us is not None:

        monotonic_time_points_us = np.array(
            monotonic_time_points_us, dtype=np.int64
        )
        sampling.create_dataset(
            "monotonic_time_points_us", data=monotonic_time_points_us
        )
        if frequency_hz is None:
            diff = np.diff(monotonic_time_points_us)
            if np.all(diff == diff[0]) and diff[0] > 0:
                frequency_hz = 1e6 / diff[0]
                warnings.warn(
                    f"The frequency_hz of {frequency_hz:.2f}hz was calculated automatically by robofish.io. The safer variant is to pass it using frequency_hz.\nThis is important when using fish_models with the files."
                )

            else:
                warnings.warn(
                    "The frequency_hz could not be calculated automatically. When using fish_models, the file will access frequency_hz."
                )

    if frequency_hz is not None:
        sampling.attrs["frequency_hz"] = (np.float32)(frequency_hz)

    if calendar_time_points is not None:

        def format_calendar_time_point(p):
            if isinstance(p, datetime.datetime):
                assert p.tzinfo is not None, "Missing timezone for calendar point."
                return p.isoformat(timespec="microseconds")
            elif isinstance(p, str):
                assert p == datetime.datetime.fromisoformat(p).isoformat(
                    timespec="microseconds"
                )
                return p
            else:
                assert (
                    False
                ), "Calendar points must be datetime.datetime instances or strings."

        calendar_time_points = [
            format_calendar_time_point(p) for p in calendar_time_points
        ]

        sampling.create_dataset(
            "calendar_time_points",
            data=calendar_time_points,
            dtype=h5py.string_dtype(encoding="utf-8"),
        )

    if default:
        self["samplings"].attrs["default"] = name
    return name
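When only monotonic_time_points_us is given, the frequency is derived from the spacing of the time points: evenly spaced points Δt microseconds apart give 10^6 / Δt Hz, so a spacing of 40 000 µs corresponds to 25 Hz. The sketch below isolates that rule. The function name is hypothetical, and the length check is an addition: with a single time point, the code above would index an empty array.

```python
from typing import Iterable, Optional

import numpy as np


def infer_frequency_hz(monotonic_time_points_us: Iterable[int]) -> Optional[float]:
    """Hypothetical helper: frequency in Hz from evenly spaced time points in microseconds."""
    t = np.asarray(monotonic_time_points_us, dtype=np.int64)
    if t.size < 2:
        return None  # no spacing to measure
    diff = np.diff(t)
    if np.all(diff == diff[0]) and diff[0] > 0:
        return float(1e6 / diff[0])
    return None  # uneven spacing: pass frequency_hz explicitly instead


print(infer_frequency_hz([0, 40_000, 80_000]))  # 25.0
```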
def plot(self, ax=None, lw_distances=False, lw=2, ms=32, figsize=None, step_size=4, c=None, cmap='Set1', skip_timesteps=0, max_timesteps=None, show=False, legend=True)

Plot the file using matplotlib.pyplot

The tracks in the file are plotted using matplotlib.plot().

Args

ax : matplotlib.axes, optional
An axes object to plot in. If None is given, a new figure is created.
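lw_distances : bool, optional
Flag to show the distances between individuals through line width. It is ignored if the file contains fewer than two entities.
lw : float, optional
Line width of the tracks if lw_distances is False.
ms : float, optional
Marker size of the end markers (passed as s to ax.scatter).
figsize : Tuple[int], optional
Size of a newly created figure. The default is (8, 8).
step_size : int, optional
When using lw_distances, the track is split into sections which share a common line width. This parameter defines the length of the sections.
c : List[color_representation], optional
A list of colors, one per entity. Each item has to satisfy matplotlib.colors.is_color_like(item).
cmap : str or matplotlib.colors.Colormap, optional
The colormap to use if c is not given. The default is 'Set1'.
skip_timesteps : int, optional
Number of timesteps to skip at the beginning of the file.
max_timesteps : int, optional
Maximum number of timesteps to plot, counted from skip_timesteps. If None, all remaining timesteps are plotted.
show : bool, optional
Show the created plot.
legend : bool, optional
Draw a legend with the entity ids and the end marker.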

Returns

matplotlib.axes
The axes object with the plot.
def plot(
    self,
    ax=None,
    lw_distances=False,
    lw=2,
    ms=32,
    figsize=None,
    step_size=4,
    c=None,
    cmap="Set1",
    skip_timesteps=0,
    max_timesteps=None,
    show=False,
    legend=True,
):
    """Plot the file using matplotlib.pyplot

    The tracks in the file are plotted using matplotlib.plot().

    Args:
        ax (matplotlib.axes, optional): An axes object to plot in. If None is given, a new figure is created.
        lw_distances (bool, optional):  Flag to show the distances between individuals through line width.
        figsize (Tuple[int], optional): Size of a newly created figure.
        step_size (int, optional): when using lw_distances, the track is split into sections which have a common line width. This parameter defines the length of the sections.
        c (Array[color_representation], optional): An array of colors. Each item has to be matplotlib.colors.is_color_like(item).
        cmap (matplotlib.colors.Colormap, optional): The colormap to use
        skip_timesteps (int, optional): Number of timesteps to skip at the beginning of the file.
        max_timesteps (int, optional): Maximum number of timesteps to plot, counted from skip_timesteps.
        show (bool, optional): Show the created plot.
    Returns:
        matplotlib.axes: The axes object with the plot.
    """

    if max_timesteps is not None:
        poses = self.entity_positions[
            :, skip_timesteps : max_timesteps + skip_timesteps
        ]
    else:
        poses = self.entity_positions[:, skip_timesteps:]

    if lw_distances and poses.shape[0] < 2:
        lw_distances = False

    if lw_distances:
        poses_diff = np.diff(poses, axis=0)  # Axis 0 is fish
        distances = np.linalg.norm(poses_diff, axis=2)

        min_distances = np.min(distances, axis=0)

    # Magic numbers found by trial and error. Distances of 9 cm or more map to the minimum line width of 0.4
        max_distance = 10
        max_lw = 4
        line_width = (
            np.clip(max_distance - min_distances, 1, max_distance)
            * max_lw
            / max_distance
        )
    else:
        step_size = poses.shape[1]

    cmap = cm.get_cmap(cmap)

    x_world, y_world = self.world_size
    if figsize is None:
        figsize = (8, 8)

    if ax is None:
        fig, ax = plt.subplots(1, 1, figsize=figsize)

    if self.path is not None:
        ax.set_title("\n".join(wrap(Path(self.path).name, width=35)))

    ax.set_xlim(-x_world / 2, x_world / 2)
    ax.set_ylim(-y_world / 2, y_world / 2)
    for fish_id in range(poses.shape[0]):
        if c is None:
            this_c = cmap(fish_id)
        elif isinstance(c, list):
            this_c = c[fish_id]

        timesteps = poses.shape[1] - 1
        for t in range(0, timesteps, step_size):
            if lw_distances:
                lw = np.mean(line_width[t : t + step_size + 1])

            ax.plot(
                poses[fish_id, t : t + step_size + 1, 0],
                poses[fish_id, t : t + step_size + 1, 1],
                c=this_c,
                lw=lw,
            )
        # Plotting outside of the figure to have the label
        ax.plot([550, 600], [550, 600], lw=5, c=this_c, label=fish_id)

    # ax.scatter(
    #     [poses[:, skip_timesteps, 0]],
    #     [poses[:, skip_timesteps, 1]],
    #     marker="h",
    #     c="black",
    #     s=ms,
    #     label="Start",
    #     zorder=5,
    # )
    ax.scatter(
        [poses[:, -1, 0]],
        [poses[:, -1, 1]],
        marker="x",
        c="black",
        s=ms,
        label="End",
        zorder=5,
    )
    if legend and isinstance(legend, str):
        ax.legend(legend)
    elif legend:
        ax.legend()
    ax.set_xlabel("x [cm]")
    ax.set_ylabel("y [cm]")

    if show:
        plt.show()

    return ax
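With lw_distances=True, the line width at each timestep is derived from the smallest distance between entities. The sketch below isolates that mapping with the same constants as above (max_distance = 10, max_lw = 4): a distance of 0 cm gives width 4, 5 cm gives width 2, and 9 cm or more is clipped to 0.4. Note that np.diff(poses, axis=0) only compares neighbouring entities (0 and 1, 1 and 2, and so on), not every pair. The function name is hypothetical.

```python
import numpy as np


def distance_line_widths(positions, max_distance=10.0, max_lw=4.0):
    """Hypothetical helper: line width per timestep from inter-entity distances.

    positions: array (n_entities, n_timesteps, 2) with x, y in cm.
    """
    # Distances between neighbouring entities along the entity axis.
    distances = np.linalg.norm(np.diff(positions, axis=0), axis=2)
    min_distances = np.min(distances, axis=0)
    # Close entities get thick lines; distances >= max_distance - 1 map to the minimum width.
    return np.clip(max_distance - min_distances, 1, max_distance) * max_lw / max_distance


positions = np.array([
    [[0.0, 0.0], [0.0, 0.0], [0.0, 0.0]],    # entity 0 stays at the origin
    [[0.0, 0.0], [5.0, 0.0], [20.0, 0.0]],   # entity 1 moves away along x
])
print(distance_line_widths(positions))  # widths 4.0, 2.0 and 0.4
```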
def render(self, video_path=None, **kwargs)

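Render a video of the file.

If video_path is given, the animation is saved to that path with ffmpeg, which must be installed; if the file already exists, you are asked whether to overwrite it. Otherwise, the animation is shown in a matplotlib window. Rendering options such as speedup, trail, fixed_view, cut_frames_start, cut_frames_end, dpi and figsize are passed as keyword arguments; their defaults are listed in default_options in the source.

As there are render functions in gym_guppy and robofish.trackviewer, this function is a temporary addition. The goal should be to bring together the rendering tools.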

def render(self, video_path=None, **kwargs):
    """Render a video of the file.

    As there are render functions in gym_guppy and robofish.trackviewer, this function is a temporary addition.
    The goal should be to bring together the rendering tools."""

    if video_path is not None:
        try:
            run(["ffmpeg"], capture_output=True)
        except Exception as e:
            raise Exception(
                f"ffmpeg is required to store videos. Please install it.\n{e}"
            )

    def shape_vertices(scale=1) -> np.ndarray:
        base_shape = np.array(
            [
                (+3.0, +0.0),
                (+2.5, +1.0),
                (+1.5, +1.5),
                (-2.5, +1.0),
                (-4.5, +0.0),
                (-2.5, -1.0),
                (+1.5, -1.5),
                (+2.5, -1.0),
            ]
        )
        return base_shape * scale

    default_options = {
        "linewidth": 2,
        "speedup": 1,
        "trail": 100,
        "entity_scale": 0.2,
        "fixed_view": False,
        "view_size": 50,
        "margin": 15,
        "slow_view": 0.8,
        "slow_zoom": 0.95,
        "cut_frames_start": None,
        "cut_frames_end": None,
        "show_text": False,
        "render_goals": False,
        "render_targets": False,
        "dpi": 200,
        "figsize": 10,
    }

    options = {
        key: kwargs[key] if key in kwargs else default_options[key]
        for key in default_options.keys()
    }

    fig, ax = plt.subplots(figsize=(options["figsize"], options["figsize"]))
    ax.set_aspect("equal")
    ax.set_facecolor("gray")
    plt.tight_layout(pad=0.05)
    n_entities = len(self.entities)
    lines = [
        plt.plot([], [], lw=options["linewidth"], zorder=0)[0]
        for _ in range(n_entities)
    ]
    points = [
        plt.scatter([], [], marker="x", color="k"),
        plt.plot([], [], linestyle="dotted", alpha=0.5, color="k", zorder=0)[0],
    ]
    categories = [entity.attrs.get("category", None) for entity in self.entities]
    entity_polygons = [
        patches.Polygon(shape_vertices(options["entity_scale"]), facecolor=color)
        for color in [
            "gray" if category == "robot" else "k" for category in categories
        ]
    ]

    border_vertices = np.array(
        [
            np.array([-1, -1, 1, 1, -1]) * self.world_size[0] / 2,
            np.array([-1, 1, 1, -1, -1]) * self.world_size[1] / 2,
        ]
    )

    spacing = 10
    x = np.arange(
        -0.5 * self.world_size[0] + spacing, 0.5 * self.world_size[0], spacing
    )
    y = np.arange(
        -0.5 * self.world_size[1] + spacing, 0.5 * self.world_size[1], spacing
    )
    xv, yv = np.meshgrid(x, y)

    grid_points = plt.scatter(xv, yv, c="gray", s=1.5)

    # border = plt.plot(border_vertices[0], border_vertices[1], "k")
    border = patches.Polygon(border_vertices.T, facecolor="w", zorder=-1)

    def title(file_frame: int) -> str:
        """Search for datasets containing text for displaying it in the video"""
        output = []
        for e in self.entities:
            for key, val in e.items():
                if val.dtype == object and type(val[0]) == bytes:
                    output.append(f"{e.name}.{key}='{val[file_frame].decode()}'")
        return ", ".join(output)

    def get_goal(file_frame: int) -> Optional[np.ndarray]:
        """Return current goal of robot, if robot exists and has a goal."""
        goal = None
        if "robot" in categories:
            robot = self.entities[categories.index("robot")]
            try:
                goal = robot["goals"][file_frame]
            except KeyError:
                pass
        if goal is not None and np.isnan(goal).any():
            goal = None
        return goal

    def get_target(file_frame: int) -> Tuple[List, List]:
        """Return line points from robot to target"""
        if "robot" in categories:
            robot = self.entities[categories.index("robot")]
            rpos = robot["positions"][file_frame]
            target = robot["targets"][file_frame]
            return [rpos[0], target[0]], [rpos[1], target[1]]
        return [], []

    def init():
        ax.set_xlim(-0.5 * self.world_size[0], 0.5 * self.world_size[0])
        ax.set_ylim(-0.5 * self.world_size[1], 0.5 * self.world_size[1])
        ax.set_xticks([])
        ax.set_xticks([], minor=True)
        ax.set_yticks([])
        ax.set_yticks([], minor=True)

        for e_poly in entity_polygons:
            ax.add_patch(e_poly)
        ax.add_patch(border)
        return lines + entity_polygons + [border] + points

    n_frames = self.entity_poses.shape[1]

    if options["cut_frames_end"] == 0 or options["cut_frames_end"] is None:
        options["cut_frames_end"] = n_frames
    if options["cut_frames_start"] is None:
        options["cut_frames_start"] = 0
    frame_range = (
        options["cut_frames_start"],
        min(n_frames, options["cut_frames_end"]),
    )

    n_frames = int((frame_range[1] - frame_range[0]) / options["speedup"])

    start_pose = self.entity_poses_rad[:, frame_range[0]]

    self.middle_of_swarm = np.mean(start_pose, axis=0)
    min_view = np.max((np.max(start_pose, axis=0) - np.min(start_pose, axis=0))[:2])
    self.view_size = np.max([options["view_size"], min_view + options["margin"]])

    if video_path is not None:
        pbar = tqdm(range(n_frames))

    def update(frame):
        if "pbar" in locals().keys():
            pbar.update(1)
            pbar.refresh()

        if frame < n_frames:
            entity_poses = self.entity_poses_rad

            file_frame = (frame * options["speedup"]) + frame_range[0]
            this_pose = entity_poses[:, file_frame]

            if not options["fixed_view"]:

                # Find the maximal distance between the entities in x or y direction
                min_view = np.max(
                    (np.max(this_pose, axis=0) - np.min(this_pose, axis=0))[:2]
                )

                new_view_size = np.max(
                    [options["view_size"], min_view + options["margin"]]
                )

                if not np.isnan(min_view).any() and not new_view_size is np.nan:
                    self.middle_of_swarm = options[
                        "slow_view"
                    ] * self.middle_of_swarm + (1 - options["slow_view"]) * np.mean(
                        this_pose, axis=0
                    )

                    self.view_size = (
                        options["slow_zoom"] * self.view_size
                        + (1 - options["slow_zoom"]) * new_view_size
                    )

                ax.set_xlim(
                    self.middle_of_swarm[0] - self.view_size / 2,
                    self.middle_of_swarm[0] + self.view_size / 2,
                )
                ax.set_ylim(
                    self.middle_of_swarm[1] - self.view_size / 2,
                    self.middle_of_swarm[1] + self.view_size / 2,
                )
            if options["show_text"]:
                ax.set_title(title(file_frame))

            if options["render_goals"]:
                goal = get_goal(file_frame)
                if goal is not None:
                    points[0].set_offsets(goal)

            if options["render_targets"]:
                points[1].set_data(get_target(file_frame))

            poses_trails = entity_poses[
                :, max(0, file_frame - options["trail"]) : file_frame
            ]
            for i_entity in range(n_entities):
                lines[i_entity].set_data(
                    poses_trails[i_entity, :, 0], poses_trails[i_entity, :, 1]
                )

                current_pose = entity_poses[i_entity, file_frame]
                t = mpl.transforms.Affine2D().translate(
                    current_pose[0], current_pose[1]
                )
                r = mpl.transforms.Affine2D().rotate(current_pose[2])
                tra = r + t + ax.transData
                entity_polygons[i_entity].set_transform(tra)
        else:
            raise Exception(
                f"Frame {frame} is out of range (n_frames = {n_frames})"
            )
        return lines + entity_polygons + [border] + points

    print(f"Preparing to render n_frames: {n_frames}")

    ani = animation.FuncAnimation(
        fig,
        update,
        frames=n_frames,
        init_func=init,
        blit=platform.system() != "Darwin",
        interval=1000 / self.frequency,
        repeat=False,
    )

    if video_path is not None:

        # if i % (n / 40) == 0:
        #     print(f"Saving frame {i} of {n} ({100*i/n:.1f}%)")

        video_path = Path(video_path)
        if video_path.exists():
            y = input(f"Video {str(video_path)} exists. Overwrite? (y/n)")
            if y == "y":
                video_path.unlink()

        if not video_path.exists():
            print(f"saving video to {video_path}")

            writervideo = animation.FFMpegWriter(fps=self.frequency)
            ani.save(video_path, writer=writervideo, dpi=options["dpi"])
        plt.close()
    else:
        plt.show()
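In the render source above, when fixed_view is off, the camera center (middle_of_swarm) and view_size follow the swarm with exponential smoothing. On each frame, a fraction slow_view (or slow_zoom) of the previous value is kept and the rest comes from the new target: the mean entity pose and max(view_size, spread + margin). The update is skipped when the target is NaN. A minimal sketch of one smoothing step and the resulting axis limits follows. The function names are hypothetical, and it uses plain tuples instead of numpy arrays:

```python
from typing import Tuple

Point = Tuple[float, float]


def smooth_view(
    center: Point,
    view_size: float,
    target_center: Point,
    target_size: float,
    slow_view: float,
    slow_zoom: float,
) -> Tuple[Point, float]:
    """One exponential-smoothing step for a camera following the swarm.

    slow_view and slow_zoom are weights in [0, 1] kept from the previous
    value: 1 freezes the camera, 0 jumps straight to the target.
    """
    new_center = (
        slow_view * center[0] + (1 - slow_view) * target_center[0],
        slow_view * center[1] + (1 - slow_view) * target_center[1],
    )
    new_size = slow_zoom * view_size + (1 - slow_zoom) * target_size
    return new_center, new_size


def view_limits(center: Point, view_size: float) -> Tuple[Point, Point]:
    """x and y axis limits of a square view centred on center."""
    half = view_size / 2
    return (center[0] - half, center[0] + half), (center[1] - half, center[1] + half)
```

Weights close to 1 give a calm camera that lags behind fast movements, while weights close to 0 track the swarm tightly but jitter.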
def save_as(self, path: Union[str, pathlib.Path], strict_validate: bool = True, no_warning: bool = False)

Save a copy of the file

Args

path
Path of the destination io file, as a string or path object.
strict_validate
Optional boolean, whether the file should be strictly validated before saving. The default is True.
no_warning
Optional boolean, suppresses the warning that the file is closed after saving.

Returns

None. The file is closed after saving; reopen the saved copy to continue working with it.
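save_as does not write a new file from memory. It updates calculated data, validates, flushes and closes the open HDF5 file, and then copies it on disk with shutil.copyfile. This is why the File object cannot be used afterwards. The copy step alone can be sketched as follows. copy_closed_file is a hypothetical helper, not part of robofish.io:

```python
import shutil
from pathlib import Path
from typing import Union


def copy_closed_file(source: Union[str, Path], destination: Union[str, Path]) -> Path:
    """Copy an already closed file to destination and return the resolved path.

    Hypothetical sketch of the last step of save_as; the real method also
    validates, flushes and closes the HDF5 file before copying.
    """
    destination = Path(destination).resolve()
    # Like save_as, create missing parent directories before copying
    destination.parent.mkdir(parents=True, exist_ok=True)
    shutil.copyfile(source, destination)
    return destination
```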

def save_as(
    self,
    path: Union[str, Path],
    strict_validate: bool = True,
    no_warning: bool = False,
):
    """Save a copy of the file

    Args:
        path: path of the destination io file, as a string or path object.
        strict_validate: optional boolean, whether the file should be strictly validated before saving. The default is True.
        no_warning: optional boolean, suppresses the warning that the file is closed after saving.
    Returns:
        None. The file is closed after saving; reopen the saved copy to continue working with it.
    """

    self.update_calculated_data()
    self.validate(strict_validate=strict_validate)

    # Ensure all buffered data has been written to disk
    self.flush()

    path = Path(path).resolve()
    path.parent.mkdir(parents=True, exist_ok=True)

    # Remember the file name, then close the file so it can be copied on disk
    filename = self.filename
    self.close()

    self.closed = True

    shutil.copyfile(filename, path)
    if not no_warning:
        warnings.warn(
            "The 'save_as' function closes the file currently to be able to store it. If you want to use the file after saving it, please reload the file. The save_as function can be avoided by opening the correct file directly. If you want to get rid of this warning use 'save_as(..., no_warning=True)'"
        )
    return None
def select_entity_poses(self, *args, ori_rad=False, **kwargs)
def select_entity_poses(self, *args, ori_rad=False, **kwargs):
    entity_property = Entity.poses_rad if ori_rad else Entity.poses
    return self.select_entity_property(
        *args, entity_property=entity_property, **kwargs
    )
def select_entity_property(self, predicate: types.LambdaType = None, entity_property: Union[property, str] = Entity.poses) -> Iterable

Get a property of selected entities.

Entities are selected with a predicate such as a lambda function, and the property to return can be chosen.

Args

predicate
a function (e.g. a lambda) that returns True for each entity to select
(example: lambda e: e.category == "fish")
entity_property
a property of the Entity class (example: Entity.poses_rad) or a string with the name of the dataset.

Returns

A three-dimensional array of the property of all selected entities with the shape (entity, time, property_length). If an entity has fewer timesteps than the longest one, the remaining entries are filled with NaNs.
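The NaN padding described above can be sketched in plain numpy. The helper name pad_properties is hypothetical. It takes the per-entity (time, property_length) arrays that select_entity_property collects and stacks them into one (entity, time, property_length) array. Returning a (0, 0, 0) array for an empty selection is an assumption; select_entity_property itself reads properties[0] and does not special-case that situation:

```python
from typing import List

import numpy as np


def pad_properties(properties: List[np.ndarray]) -> np.ndarray:
    """Stack (time, property_length) arrays into (entity, time, property_length).

    Entities with fewer timesteps are padded with NaN at the end.
    """
    if not properties:
        # Assumption: an empty selection yields an empty 3D array
        return np.empty((0, 0, 0))
    max_timesteps = max(p.shape[0] for p in properties)
    property_length = properties[0].shape[1]
    out = np.full((len(properties), max_timesteps, property_length), np.nan)
    for i, p in enumerate(properties):
        # Copy the available timesteps; the remaining rows stay NaN
        out[i, : p.shape[0]] = p
    return out
```

For example, stacking a (3, 3) array and a (2, 3) array gives shape (2, 3, 3), and the last timestep of the second entity is all NaN.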

def select_entity_property(
    self,
    predicate: types.LambdaType = None,
    entity_property: Union[property, str] = Entity.poses,
) -> Iterable:
    """Get a property of selected entities.

    Entities are selected with a predicate such as a lambda function, and the
    property to return can be chosen.

    Args:
        predicate: a function (e.g. a lambda) that returns True for each entity to select
        (example: lambda e: e.category == "fish")
        entity_property: a property of the Entity class (example: Entity.poses_rad) or a string with the name of the dataset.
    Returns:
        A three-dimensional array of the property of all selected entities with the shape (entity, time, property_length).
        If an entity has fewer timesteps than the longest one, the remaining entries are filled with NaNs.
    """

    entities = self.entities
    if predicate is not None:
        entities = [e for e in entities if predicate(e)]

    assert self.common_sampling(entities) is not None

    # Collect the requested property of every selected entity
    if isinstance(entity_property, str):
        properties = [entity[entity_property] for entity in entities]
    else:
        properties = [entity_property.__get__(entity) for entity in entities]

    max_timesteps = max([0] + [p.shape[0] for p in properties])

    # Initialize the output array with NaNs, so shorter entities stay NaN-padded
    property_array = np.full(
        (len(entities), max_timesteps, properties[0].shape[1]), np.nan
    )

    # Fill output array
    for i, entity in enumerate(entities):
        property_array[i][: properties[i].shape[0]] = properties[i]
    return property_array
def to_string(self, output_format: str = 'shape', max_width: int = 120, full_attrs: bool = False) -> str

Format the file as a human-readable string.

Args

output_format
'shape' or 'full': show either the shape or the full content of each dataset
max_width
set the width in characters after which attribute values get abbreviated
full_attrs
do not abbreviate attribute values if True

Returns

A human-readable string representing the file
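The tree layout that to_string draws with box characters can be approximated with a short recursive function. This is a simplified, hypothetical sketch over nested dictionaries instead of h5py groups. Unlike to_string, it abbreviates every leaf line (not only attribute values) at max_width, and it does not align keys:

```python
from typing import Any, Dict


def tree_lines(node: Dict[str, Any], prefix: str = "", max_width: int = 120) -> str:
    """Render nested dicts as a tree with box-drawing characters.

    Dict values become branches; any other value is printed next to its key.
    """
    out = ""
    keys = list(node)
    for i, key in enumerate(keys):
        is_last = i == len(keys) - 1
        value = node[key]
        connector = "└─ " if is_last else "├─ "
        if isinstance(value, dict):
            out += f"{prefix}{connector}{key}\n"
            # Children of the last entry need no vertical guide line
            child_prefix = prefix + ("   " if is_last else "│  ")
            out += tree_lines(value, child_prefix, max_width)
        else:
            line = f"{prefix}{connector}{key}  {value}"
            if len(line) > max_width:
                # Abbreviate long lines, as to_string does for attribute values
                line = line[: max_width - 3] + "..."
            out += line + "\n"
    return out
```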

def to_string(
    self,
    output_format: str = "shape",
    max_width: int = 120,
    full_attrs: bool = False,
) -> str:
    """The file is formatted to a human readable format.
    Args:
        output_format: ['shape', 'full'] show the shape, or the full content of datasets
        max_width: set the width in characters after which attribute values get abbreviated
        full_attrs: do not abbreviate attribute values if True
    Returns:
        A human readable string, representing the file
    """

    def recursive_stringify(
        obj: h5py.Group,
        output_format: str,
        parent_indices: List[int] = [],
        parent_siblings: List[int] = [],
    ) -> str:
        """This function crawls recursively into hdf5 groups.
        Datasets and attributes are directly attached, for groups, the function is recursively called again.
        Args:
            obj: a h5py group
            output_format: ['shape', 'full'] show the shape, or the full content of datasets
        Returns:
            A string representation of the group
        """

        def lines(dataset_attribute: bool = False) -> str:
            """Get box-drawing characters for the graph lines."""
            line = ""
            for pi, ps in zip(parent_indices, parent_siblings):
                if pi < ps - 1:
                    line += "│ "
                else:
                    line += "  "
            if dataset_attribute:
                line += "  "
            line += "─ "
            junction_index = 2 * len(parent_indices) + dataset_attribute * 2 - 1
            last = "└"
            other = "├"
            if dataset_attribute:
                j = (
                    last
                    if list(value.attrs.keys()).index(d_key) == len(value.attrs) - 1
                    else other
                )
            else:
                j = last if index == num_children - 1 else other
            line = line[: junction_index + 1] + j + line[junction_index + 1 :]
            if isinstance(value, h5py.Group) or (
                isinstance(value, h5py.Dataset)
                and not dataset_attribute
                and value.attrs
            ):
                line = line[:-1] + "┬─"
            else:
                line = line[:-1] + "──"

            return line + " "

        s = ""
        max_key_len = 0
        num_children = 0
        if obj.attrs:
            max_key_len = max(len(key) for key in obj.attrs)
            num_children += len(obj.attrs)
        if hasattr(obj, "items"):
            max_key_len = max([len(key) for key in obj] + [max_key_len])
            num_children += len(obj)
        index = 0
        if obj.attrs:
            for key, value in obj.attrs.items():
                if not full_attrs:
                    value = str(value).replace("\n", " ").strip()
                    if len(value) > max_width - max_key_len - len(lines()):
                        value = (
                            value[: max_width - max_key_len - len(lines()) - 3]
                            + "..."
                        )
                s += f"{lines()}{key: <{max_key_len}}  {value}\n"
                index += 1
        if hasattr(obj, "items"):
            for key, value in obj.items():
                if isinstance(value, h5py.Dataset):
                    if output_format == "shape":
                        s += (
                            f"{lines()}"
                            f"{key: <{max_key_len}}  Shape {value.shape}\n"
                        )
                    else:
                        s += f"{lines()}{key}:\n"
                        s += np.array2string(
                            value,
                            precision=2,
                            separator=" ",
                            suppress_small=True,
                        )
                        s += "\n"

                    if value.attrs:
                        d_max_key_len = max(len(dk) for dk in value.attrs)
                    for d_key, d_value in value.attrs.items():
                        d_value = str(d_value).replace("\n", " ").strip()
                        if len(d_value) > max_width - d_max_key_len - len(
                            lines(True)
                        ):
                            if not full_attrs:
                                d_value = d_value[
                                    : max_width - d_max_key_len - len(lines(True))
                                ]
                                d_value = d_value[:-3] + "..."
                        s += f"{lines(True)}{d_key: <{d_max_key_len}}  {d_value}\n"
                if isinstance(value, h5py.Group):
                    s += f"{lines()}{key}\n" + recursive_stringify(
                        obj=value,
                        output_format=output_format,
                        parent_indices=parent_indices + [index],
                        parent_siblings=parent_siblings + [num_children],
                    )
                index += 1
        return s

    return recursive_stringify(self, output_format)
def update_calculated_data(self, verbose=False)
def update_calculated_data(self, verbose=False):
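    # Build a list (not a generator) so every entity is updated; any() on a
    # generator would stop at the first entity that reports a change.
    changed = any([e.update_calculated_data(verbose) for e in self.entities])
    return changed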
def validate(self, strict_validate: bool = True) -> Tuple[bool, str]

Validate the file to the specification.

The function compares the file to the robofish track format specification: https://git.imp.fu-berlin.de/bioroboticslab/robofish/track_format

First, all specified arrays are converted to numpy arrays with the specified datatype. Then all specified shapes are validated. Finally, calendar points are validated as ISO 8601 datetimes.

Args

strict_validate
If True, raise an AssertionError instead of just returning False.

Returns

A tuple of the validity (bool) and an error message.

Throws

AssertionError: When the file is invalid and strict_validate is True
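The last validation step and the documented return contract (a (validity, message) tuple, or an AssertionError in strict mode) can be sketched as follows. validate_calendar_points is a hypothetical helper, not part of robofish.io. It relies on datetime.fromisoformat, which before Python 3.11 accepts only a subset of ISO 8601 (for example, no trailing "Z"), so the real validator may accept or reject different strings:

```python
from datetime import datetime
from typing import Iterable, Tuple


def validate_calendar_points(
    points: Iterable[str], strict_validate: bool = True
) -> Tuple[bool, str]:
    """Check that every calendar point parses as an ISO 8601 datetime.

    Returns (True, "") if all points are valid. Otherwise it raises
    AssertionError when strict_validate is True, else returns (False, message).
    """
    for point in points:
        try:
            datetime.fromisoformat(point)
        except ValueError:
            message = f"Calendar point {point!r} is not an ISO 8601 datetime"
            if strict_validate:
                raise AssertionError(message)
            return False, message
    return True, ""
```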

def validate(self, strict_validate: bool = True) -> Tuple[bool, str]:
    """Validate the file to the specification.

    The function compares the file to the robofish track format specification:
    https://git.imp.fu-berlin.de/bioroboticslab/robofish/track_format
    First, all specified arrays are converted to numpy arrays with the specified
    datatype. Then all specified shapes are validated. Finally, calendar points
    are validated as ISO 8601 datetimes.

    Args:
        strict_validate: If True, raise an AssertionError instead of just returning False.
    Returns:
        A tuple of the validity (bool) and an error message.
    Throws:
        AssertionError: When the file is invalid and strict_validate is True
    """
    return robofish.io.validate(self, strict_validate)