Module robofish.io.file
File objects are the root of the project. A File contains all information about the environment, entities, and time.
In the simplest form we define a new File with a world size in cm and a frequency in Hz. Afterwards, we can save it to a path.
import robofish.io
# Create a new robofish io file
f = robofish.io.File(world_size_cm=[100, 100], frequency_hz=25.0)
f.save_as("test.hdf5")
The File object can also be created with a given path. In this case, we work on the file directly. The with
block ensures that the file is validated when the block is left.
with robofish.io.File(
"test.hdf5", mode="x", world_size_cm=[100, 100], frequency_hz=25.0
) as f:
# Use file f here
When opening a file with a path, a mode should be specified to describe how the file should be opened.
| Mode | Description |
|------|-------------|
| r    | Readonly, file must exist (default) |
| r+   | Read/write, file must exist |
| w    | Create file, truncate if exists |
| x    | Create file, fail if exists |
| a    | Read/write if exists, create otherwise |
Attributes
Attributes can be added to the file to describe its contents. They can be set like this:
f.attrs["experiment_setup"] = "This file comes from the tutorial."
f.attrs["experiment_issues"] = "All data in this file is made up."
Any attribute is allowed, but some canonical attributes are prepared:
publication_url, video_url, tracking_software_name, tracking_software_version, tracking_software_url, experiment_setup, experiment_issues
Properties
As described in robofish.io, all properties of robofish.io.entity entities can be accessed on the file by adding the prefix entity_ to the property name.
Plotting
Files have a built-in plotting tool, File.plot().
With the option lw_distances=True, the distance between fish is represented through the line width.
import robofish.io
import matplotlib.pyplot as plt
fig, ax = plt.subplots(1,2, figsize=(10,5))
f = robofish.io.File("...")
f.plot(ax=ax[0])
f.plot(ax=ax[1], lw_distances=True)
plt.show()
For all other plotting options, please check File.plot().
Expand source code
# -*- coding: utf-8 -*-
"""
.. include:: ../../../docs/file.md
"""
# -----------------------------------------------------------
# Utils functions for reading, validating and writing hdf5 files according to
# Robofish track format (1.0 Draft 7). The standard is available at
# https://git.imp.fu-berlin.de/bioroboticslab/robofish/track_format
#
# Dec 2020 Andreas Gerken, Berlin, Germany
# Released under GNU 3.0 License
# email andi.gerken@gmail.com
# -----------------------------------------------------------
import robofish.io
from robofish.io.entity import Entity
import h5py
import numpy as np
import logging
from typing import Iterable, Union, Tuple, List, Optional
from pathlib import Path
import shutil
import datetime
import tempfile
import uuid
import deprecation
import types
import warnings
from textwrap import wrap
import platform
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib import animation
from matplotlib import patches
from matplotlib import cm
from tqdm.auto import tqdm
from subprocess import run
# Remember: Update docstring when updating these two global variables
default_format_version = np.array([1, 0], dtype=np.int32)
default_format_url = (
"https://git.imp.fu-berlin.de/bioroboticslab/robofish/track_format/-/releases/1.0"
)
class File(h5py.File):
"""Represents a RoboFish Track Format file, which should be used to store tracking data of individual animals or swarms.
Files can be opened (with optional creation), modified inplace, and have copies of them saved.
"""
_temp_dir = None
def __init__(
self,
path: Union[str, Path] = None,
mode: str = "r",
*, # PEP 3102
world_size_cm: List[int] = None,
validate: bool = False,
validate_when_saving: bool = True,
strict_validate: bool = False,
format_version: List[int] = default_format_version,
format_url: str = default_format_url,
sampling_name: str = None,
frequency_hz: int = None,
monotonic_time_points_us: Iterable = None,
calendar_time_points: Iterable = None,
open_copy: bool = False,
validate_poses_hash: bool = True,
):
"""Create a new RoboFish Track Format object.
When called with a path, it is loaded, otherwise a new temporary
file is created. File contents can be validated against the
track format specification.
Parameters
----------
path : str or Path, optional
Location of file to be opened. If not provided, mode is ignored.
mode : str, default='r'
'r' Readonly, file must exist
'r+' Read/write, file must exist
'w' Create file, truncate if exists
'x' Create file, fail if exists
'a' Read/write if exists, create otherwise
world_size_cm : [int, int] , optional
side lengths [x, y] of the world in cm.
rectangular world shape is assumed.
validate: bool, default=False
Should the track be validated? This is normally switched off for performance reasons.
strict_validate : bool, default=False
if the file should be strictly validated against the track
format specification, when loaded from a path.
TODO: Should this validate against the version specified in
format_version or just against the most recent version?
format_version : [int, int], default=[1,0]
version [major, minor] of track format specification
format_url : str, default="https://git.imp.fu-berlin.de/bioroboticslab/robofish/track_format/-/releases/1.0"
location of track format specification.
should fit `format_version`.
sampling_name : str, optional
How to specify your sampling:
1. (optional)
provide text description of your sampling in `sampling_name`
2.a (mandatory, if you have a constant sampling frequency)
specify `frequency_hz` with your sampling frequency in Hz
2.b (mandatory, if you do NOT have a constant sampling frequency)
specify `monotonic_time_points_us` with a list[1] of time
points in microseconds on a monotonic clock, one for each
sample in your dataset.
3. (optional)
specify `calendar_time_points` with a list[2] of time points
in the ISO 8601 extended format with microsecond precision
and time zone designator[3], one for each sample in your
dataset.
[1] any Iterable of int
[2] any Iterable of str
[3] example: "2020-11-18T13:21:34.117015+01:00"
frequency_hz: int, optional
refer to explanation of `sampling_name`
monotonic_time_points_us: Iterable of int, optional
refer to explanation of `sampling_name`
calendar_time_points: Iterable of str, optional
refer to explanation of `sampling_name`
open_copy: bool, optional
a temporary copy of the file will be opened instead of the file itself.
"""
self.path = path
self.validate_when_saving = validate_when_saving
if open_copy:
assert (
path is not None
), "A path has to be given if a copy should be opened."
temp_file = self.temp_dir / str(uuid.uuid4())
logging.info(
f"Copying file to temporary file and opening it:\n{path} -> {temp_file}"
)
shutil.copyfile(path, temp_file)
super().__init__(
temp_file,
mode="r+",
driver="core",
backing_store=True,
libver=("earliest", "v110"),
)
initialize = False
elif path is None:
temp_file = self.temp_dir / str(uuid.uuid4())
logging.info(f"Opening New temporary file {temp_file}")
super().__init__(
temp_file,
mode="x",
driver="core",
backing_store=True,
libver=("earliest", "v110"),
)
initialize = True
else:
# mode
# r Readonly, file must exist (default)
# r+ Read/write, file must exist
# w Create file, truncate if exists
# x Create file, fail if exists
# a Read/write if exists, create otherwise
logging.info(f"Opening File {path}")
assert mode in ["r", "r+", "w", "x", "a"], f"Unknown mode {mode}."
# If the file does not exist or if it should be truncated with mode=w, initialize it.
if Path(path).exists() and mode != "w":
initialize = False
else:
initialize = True
try:
super().__init__(path, mode, libver=("earliest", "v110"))
except OSError as e:
raise OSError(f"Could not open file {path} with mode {mode}.\n{e}")
if initialize:
assert (
world_size_cm is not None and format_version is not None
), "It seems like the file is already initialized. Try opening it with mode 'r+'."
self.attrs["world_size_cm"] = np.array(world_size_cm, dtype=np.float32)
self.attrs["format_version"] = np.array(format_version, dtype=np.int32)
self.attrs["format_url"] = format_url
self.create_group("entities")
self.create_group("samplings")
if frequency_hz is not None or monotonic_time_points_us is not None:
self.create_sampling(
name=sampling_name,
frequency_hz=frequency_hz,
monotonic_time_points_us=monotonic_time_points_us,
calendar_time_points=calendar_time_points,
default=True,
)
else:
# A quick validation to find h5py files which are not robofish.io files
if any([a not in self.attrs for a in ["world_size_cm", "format_version"]]):
msg = f"The opened file {self.path} does not include world_size_cm or format_version. It seems that the file is not a robofish.io.File."
if strict_validate:
raise KeyError(msg)
else:
warnings.warn(msg)
return
# Validate that the stored poses hash still fits.
if validate_poses_hash:
for entity in self.entities:
if "poses_hash" in entity.attrs:
if entity.attrs["poses_hash"] != entity.poses_hash:
warnings.warn(
f"The stored poses hash is not identical to the newly calculated hash for entity {entity.name} in {self.path}. f.entity_actions_turns_speeds and f.entity_orientation_rad will return wrong results.\n"
f"stored: {entity.attrs['poses_hash']}, calculated: {entity.poses_hash}"
)
assert (
"unfinished_calculations" not in entity.attrs
), f"The calculated data of file {self.path} is incomplete and was probably aborted during calculation. Please recalculate with `robofish-io-update-calculated-data {self.path}`."
else:
warnings.warn(
f"The file did not include pre-calculated data so the actions_speeds_turns "
f"and orientations_rad will have to be recalculated every time.\n"
f"Please use `robofish-io-update-calculated-data {self.path}` in the "
f"commandline or\nopen and close the file with robofish.io.File(f, 'r+') "
f"in python.\nIf the data should be recalculated every time open the file "
"with the bool option validate_poses_hash=False."
)
if validate:
self.validate(strict_validate)
def __enter__(self):
return self
def __exit__(self, type, value, traceback):
# Check if the context was left under normal circumstances
if (hasattr(self, "closed") and not self.closed) and (
type,
value,
traceback,
) == (None, None, None):
if (
self.mode != "r" and self.validate_when_saving
): # No need to validate read only files (performance).
self.validate()
super().__exit__(type, value, traceback)
def close(self):
if self.mode != "r":
self.update_calculated_data()
super().close()
def save_as(
self,
path: Union[str, Path],
strict_validate: bool = True,
no_warning: bool = False,
):
"""Save a copy of the file
Args:
path: path to an io file as a string or Path object.
strict_validate: optional boolean, if the file should be strictly validated, before saving. The default is True.
no_warning: optional boolean, to remove the warning from the function.
Returns:
The file itself, so something like f = robofish.io.File().save_as("file.hdf5") works
"""
self.update_calculated_data()
self.validate(strict_validate=strict_validate)
# Ensure all buffered data has been written to disk
self.flush()
path = Path(path).resolve()
path.parent.mkdir(parents=True, exist_ok=True)
filename = self.filename
self.flush()
self.close()
self.closed = True
shutil.copyfile(filename, path)
if not no_warning:
warnings.warn(
"The 'save_as' function currently closes the file to be able to store it. If you want to use the file after saving, please reload it. The save_as function can be avoided by opening the correct file directly. If you want to get rid of this warning, use 'save_as(..., no_warning=True)'"
)
return None
def create_sampling(
self,
name: str = None,
frequency_hz: int = None,
monotonic_time_points_us: Iterable = None,
calendar_time_points: Iterable = None,
default: bool = False,
):
# Find Name for sampling if none is given
if name is None:
if frequency_hz is not None:
name = "%d hz" % frequency_hz
i = 1
while name is None or name in self["samplings"]:
name = "sampling_%d" % i
i += 1
sampling = self["samplings"].create_group(name)
if monotonic_time_points_us is not None:
monotonic_time_points_us = np.array(
monotonic_time_points_us, dtype=np.int64
)
sampling.create_dataset(
"monotonic_time_points_us", data=monotonic_time_points_us
)
if frequency_hz is None:
diff = np.diff(monotonic_time_points_us)
if np.all(diff == diff[0]) and diff[0] > 0:
frequency_hz = 1e6 / diff[0]
warnings.warn(
f"The frequency_hz of {frequency_hz:.2f}hz was calculated automatically by robofish.io. The safer variant is to pass it using frequency_hz.\nThis is important when using fish_models with the files."
)
else:
warnings.warn(
"The frequency_hz could not be calculated automatically. When using fish_models, the file will need frequency_hz."
)
if frequency_hz is not None:
sampling.attrs["frequency_hz"] = (np.float32)(frequency_hz)
if calendar_time_points is not None:
def format_calendar_time_point(p):
if isinstance(p, datetime.datetime):
assert p.tzinfo is not None, "Missing timezone for calendar point."
return p.isoformat(timespec="microseconds")
elif isinstance(p, str):
assert p == datetime.datetime.fromisoformat(p).isoformat(
timespec="microseconds"
)
return p
else:
assert (
False
), "Calendar points must be datetime.datetime instances or strings."
calendar_time_points = [
format_calendar_time_point(p) for p in calendar_time_points
]
sampling.create_dataset(
"calendar_time_points",
data=calendar_time_points,
dtype=h5py.string_dtype(encoding="utf-8"),
)
if default:
self["samplings"].attrs["default"] = name
return name
@property
def temp_dir(self):
cla = type(self)
if cla._temp_dir is None:
cla._temp_dir = tempfile.TemporaryDirectory(prefix="robofish-io-")
return Path(cla._temp_dir.name)
@property
def world_size(self):
return self.attrs["world_size_cm"]
@property
def default_sampling(self):
assert (
"samplings" in self
), "The file does not have a group 'sampling' which is required."
if "default" in self["samplings"].attrs:
return self["samplings"].attrs["default"]
return None
@property
def frequency(self):
common_sampling = self.common_sampling()
assert common_sampling is not None, "The sampling differs between entities."
assert (
"frequency_hz" in common_sampling.attrs
), "The common sampling has no frequency_hz"
return common_sampling.attrs["frequency_hz"]
def common_sampling(
self, entities: Iterable["robofish.io.Entity"] = None
) -> h5py.Group:
"""Check if all entities have the same sampling.
Args:
entities: optional array of entities. If None is given, all entities are checked.
Returns:
The h5py group of the common sampling. If there is no common sampling, None will be returned.
"""
custom_sampling = None
for entity in self.entities:
if "sampling" in entity["positions"].attrs:
this_sampling = entity["positions"].attrs["sampling"]
if custom_sampling is None:
custom_sampling = this_sampling
elif custom_sampling != this_sampling:
return None
sampling = self.default_sampling if custom_sampling is None else custom_sampling
return self["samplings"][sampling]
def create_entity(
self,
category: str,
poses: Iterable = None,
name: str = None,
positions: Iterable = None,
orientations: Iterable = None,
outlines: Iterable = None,
sampling: str = None,
) -> str:
"""Creates a new single entity.
Args:
TODO
category: the category of the entity. The canonical values are ['organism', 'robot', 'obstacle'].
poses: optional two dimensional array, containing the poses of the entity (x,y,orientation_x, orientation_y).
poses_rad: optional two dimensional array, containing the poses of the entity (x, y, orientation_rad).
name: optional name of the entity. If no name is given, the category is used with an id (e.g. 'fish_1')
outlines: optional three dimensional array, containing the outlines of the entity
Returns:
Name of the created entity
"""
if sampling is None and self.default_sampling is None:
raise Exception(
"There was no sampling specified, when creating the file, nor when creating the entity."
)
entity = robofish.io.Entity.create_entity(
self["entities"],
category,
poses,
name,
positions,
orientations,
outlines,
sampling,
)
return entity
def create_multiple_entities(
self,
category: str,
poses: Iterable,
names: Iterable[str] = None,
outlines=None,
sampling=None,
) -> Iterable:
"""Creates multiple entities.
Args:
category: The common category for the entities. The canonical values are ['organism', 'robot', 'obstacle'].
poses: three dimensional array, containing the poses of the entity.
names: optional array of names for the entities. If no names are given, the category is used with an id (e.g. 'fish_1')
outlines: optional array, containing the outlines of the entities, either a three dimensional common outline array can be given, or a four dimensional array.
sampling: The string reference to the sampling. If none is given, the default sampling from creating the file is used.
Returns:
Array of names of the created entities
"""
assert (
poses.ndim == 3
), f"A 3 dimensional array was expected (entity, timestep, 3). There were {poses.ndim} dimensions in poses: {poses.shape}"
assert poses.shape[2] in [3, 4]
agents = poses.shape[0]
entity_names = []
for i in range(agents):
e_name = None if names is None else names[i]
e_outline = (
outlines if outlines is None or outlines.ndim == 3 else outlines[i]
)
entity_names.append(
self.create_entity(
category=category,
sampling=sampling,
poses=poses[i],
name=e_name,
outlines=e_outline,
)
)
return entity_names
def update_calculated_data(self, verbose=False):
changed = any([e.update_calculated_data(verbose) for e in self.entities])
return changed
def clear_calculated_data(self, verbose=True):
"""Delete all calculated data from the files."""
txt = ""
for e in self.entities:
txt += f"Deleting from {e}. Attrs: ["
for a in ["poses_hash"]:
if a in e.attrs:
del e.attrs[a]
txt += f"{a}, "
txt = txt[:-2] + "] Datasets: ["
for g in ["calculated_actions_speeds_turns", "calculated_orientations_rad"]:
if g in e:
del e[g]
txt += f"{g}, "
txt = txt[:-2] + "]\n"
if verbose:
print(txt[:-1])
@property
def entity_names(self) -> Iterable[str]:
"""Getter for the names of all entities
Returns:
Array of all names.
"""
return sorted(self["entities"].keys())
@property
def entities(self):
return [
robofish.io.Entity.from_h5py_group(self["entities"][name])
for name in self.entity_names
]
@property
def entity_positions(self):
return self.select_entity_property(None, entity_property=Entity.positions)
@property
def entity_orientations(self):
return self.select_entity_property(None, entity_property=Entity.orientations)
@property
def entity_orientations_rad(self):
return self.select_entity_property(
None, entity_property=Entity.orientations_rad
)
@property
def entity_poses(self):
return self.select_entity_property(None, entity_property=Entity.poses)
@property
def entity_poses_rad(self):
return self.select_entity_property(None, entity_property=Entity.poses_rad)
@property
@deprecation.deprecated(
deprecated_in="0.2",
removed_in="0.2.4",
details="We found that our calculation of 'poses_calc_ori' is flawed."
"Please replace it with 'poses' and use the tracked orientation."
"If you see this message and you don't know what to do, update all packages and if nothing helps, contact Andi.\n"
"Don't ignore this warning, it's a serious issue.",
)
def entity_poses_calc_ori(self):
return self.select_entity_property(None, entity_property=Entity.poses_calc_ori)
@property
@deprecation.deprecated(
deprecated_in="0.2",
removed_in="0.2.4",
details="We found that our calculation of 'poses_calc_ori_rad' is flawed."
"Please replace it with 'poses_rad' and use the tracked orientation."
"If you see this message and you don't know what to do, update all packages and if nothing helps, contact Andi.\n"
"Don't ignore this warning, it's a serious issue.",
)
def entity_poses_calc_ori_rad(self):
return self.select_entity_property(
None, entity_property=Entity.poses_calc_ori_rad
)
@property
@deprecation.deprecated(
deprecated_in="0.2",
removed_in="0.2.4",
details="We found that our calculation of 'entity_speeds_turns' is flawed and replaced it "
"with 'entity_actions_speeds_turns'. The difference in calculation is, that the tracked "
"orientation is used now which gives the fish the ability to swim backwards. "
"If you see this message and you don't know what to do, update all packages and if nothing helps, contact Andi.\n"
"Don't ignore this warning, it's a serious issue.",
)
def entity_speeds_turns(self):
return self.select_entity_property(None, entity_property=Entity.speed_turn)
@property
def entity_actions_speeds_turns(self):
"""Calculate the speed and turn from the recorded positions and orientations.
The turn is calculated by the change of orientation between frames.
The speed is calculated by the distance between the points, projected on the new orientation vector.
The sideway change of position cannot be represented with this method.
Returns:
An array with shape (number_of_entities, number_of_positions - 1, 2), containing speed in cm/frame and turn in rad/frame.
"""
return self.select_entity_property(
None, entity_property=Entity.actions_speeds_turns
)
def select_entity_poses(self, *args, ori_rad=False, **kwargs):
entity_property = Entity.poses_rad if ori_rad else Entity.poses
return self.select_entity_property(
*args, entity_property=entity_property, **kwargs
)
def select_entity_property(
self,
predicate: types.LambdaType = None,
entity_property: Union[property, str] = Entity.poses,
) -> Iterable:
"""Get a property of selected entities.
Entities can be selected, using a lambda function.
The property of the entities can be selected.
Args:
predicate: a lambda function, selecting entities
(example: lambda e: e.category == "fish")
entity_property: a property of the Entity class (example: Entity.poses_rad) or a string with the name of the dataset.
Returns:
A three dimensional array of all properties of all entities with the shape (entity, time, property_length).
If an entity has a shorter length of the property, the output will be filled with nans.
"""
entities = self.entities
if predicate is not None:
entities = [e for e in entities if predicate(e)]
assert self.common_sampling(entities) is not None
# Initialize poses output array
if isinstance(entity_property, str):
properties = [entity[entity_property] for entity in entities]
else:
properties = [entity_property.__get__(entity) for entity in entities]
max_timesteps = max([0] + [p.shape[0] for p in properties])
property_array = np.empty(
(len(entities), max_timesteps, properties[0].shape[1])
)
property_array[:] = np.nan
# Fill output array
for i, entity in enumerate(entities):
property_array[i][: properties[i].shape[0]] = properties[i]
return property_array
def validate(self, strict_validate: bool = True) -> Tuple[bool, str]:
"""Validate the file to the specification.
The function compares a given file to the robofish track format specification:
https://git.imp.fu-berlin.de/bioroboticslab/robofish/track_format
First all specified arrays are formatted to be numpy arrays with the specified
datatype. Then all specified shapes are validated. Lastly calendar points
are validated to be datetimes according to ISO8601.
Args:
strict_validate: Throw an exception instead of just returning False.
Returns:
The function returns a tuple of validity and an error message
Throws:
AssertionError: When the file is invalid and strict_validate is True
"""
return robofish.io.validate(self, strict_validate)
def to_string(
self,
output_format: str = "shape",
max_width: int = 120,
full_attrs: bool = False,
) -> str:
"""The file is formatted to a human readable format.
Args:
output_format: ['shape', 'full'] show the shape, or the full content of datasets
max_width: set the width in characters after which attribute values get abbreviated
full_attrs: do not abbreviate attribute values if True
Returns:
A human readable string, representing the file
"""
def recursive_stringify(
obj: h5py.Group,
output_format: str,
parent_indices: List[int] = [],
parent_siblings: List[int] = [],
) -> str:
"""This function crawls recursively into hdf5 groups.
Datasets and attributes are directly attached, for groups, the function is recursively called again.
Args:
obj: a h5py group
output_format: ['shape', 'full'] show the shape, or the full content of datasets
Returns:
A string representation of the group
"""
def lines(dataset_attribute: bool = False) -> str:
"""Get box-drawing characters for the graph lines."""
line = ""
for pi, ps in zip(parent_indices, parent_siblings):
if pi < ps - 1:
line += "│ "
else:
line += " "
if dataset_attribute:
line += " "
line += "─ "
junction_index = 2 * len(parent_indices) + dataset_attribute * 2 - 1
last = "└"
other = "├"
if dataset_attribute:
j = (
last
if list(value.attrs.keys()).index(d_key) == len(value.attrs) - 1
else other
)
else:
j = last if index == num_children - 1 else other
line = line[: junction_index + 1] + j + line[junction_index + 1 :]
if isinstance(value, h5py.Group) or (
isinstance(value, h5py.Dataset)
and not dataset_attribute
and value.attrs
):
line = line[:-1] + "┬─"
else:
line = line[:-1] + "──"
return line + " "
s = ""
max_key_len = 0
num_children = 0
if obj.attrs:
max_key_len = max(len(key) for key in obj.attrs)
num_children += len(obj.attrs)
if hasattr(obj, "items"):
max_key_len = max([len(key) for key in obj] + [max_key_len])
num_children += len(obj)
index = 0
if obj.attrs:
for key, value in obj.attrs.items():
if not full_attrs:
value = str(value).replace("\n", " ").strip()
if len(value) > max_width - max_key_len - len(lines()):
value = (
value[: max_width - max_key_len - len(lines()) - 3]
+ "..."
)
s += f"{lines()}{key: <{max_key_len}} {value}\n"
index += 1
if hasattr(obj, "items"):
for key, value in obj.items():
if isinstance(value, h5py.Dataset):
if output_format == "shape":
s += (
f"{lines()}"
f"{key: <{max_key_len}} Shape {value.shape}\n"
)
else:
s += f"{lines()}{key}:\n"
s += np.array2string(
value,
precision=2,
separator=" ",
suppress_small=True,
)
s += "\n"
if value.attrs:
d_max_key_len = max(len(dk) for dk in value.attrs)
for d_key, d_value in value.attrs.items():
d_value = str(d_value).replace("\n", " ").strip()
if len(d_value) > max_width - d_max_key_len - len(
lines(True)
):
if not full_attrs:
d_value = d_value[
: max_width - d_max_key_len - len(lines(True))
]
d_value = d_value[:-3] + "..."
s += f"{lines(True)}{d_key: <{d_max_key_len}} {d_value}\n"
if isinstance(value, h5py.Group):
s += f"{lines()}{key}\n" + recursive_stringify(
obj=value,
output_format=output_format,
parent_indices=parent_indices + [index],
parent_siblings=parent_siblings + [num_children],
)
index += 1
return s
return recursive_stringify(self, output_format)
def __str__(self):
return self.to_string()
def plot(
self,
ax=None,
lw_distances=False,
lw=2,
ms=32,
figsize=None,
step_size=4,
c=None,
cmap="Set1",
skip_timesteps=0,
max_timesteps=None,
show=False,
legend=True,
):
"""Plot the file using matplotlib.pyplot
The tracks in the file are plotted using matplotlib.plot().
Args:
ax (matplotlib.axes, optional): An axes object to plot in. If None is given, a new figure is created.
lw_distances (bool, optional): Flag to show the distances between individuals through line width.
figsize (Tuple[int], optional): Size of a newly created figure.
step_size (int, optional): when using lw_distances, the track is split into sections which have a common line width. This parameter defines the length of the sections.
c (Array[color_representation], optional): An array of colors. Each item has to be matplotlib.colors.is_color_like(item).
cmap (matplotlib.colors.Colormap, optional): The colormap to use
skip_timesteps (int, optional): Skip timesteps at the beginning of the file
max_timesteps (int, optional): Cut off timesteps at the end of the file.
show (bool, optional): Show the created plot.
Returns:
matplotlib.axes: The axes object with the plot.
"""
if max_timesteps is not None:
poses = self.entity_positions[
:, skip_timesteps : max_timesteps + skip_timesteps
]
else:
poses = self.entity_positions[:, skip_timesteps:]
if lw_distances and poses.shape[0] < 2:
lw_distances = False
if lw_distances:
poses_diff = np.diff(poses, axis=0) # Axis 0 is fish
distances = np.linalg.norm(poses_diff, axis=2)
min_distances = np.min(distances, axis=0)
# Magic numbers found by trial and error. Distances above max_distance are drawn with the minimum line width
max_distance = 10
max_lw = 4
line_width = (
np.clip(max_distance - min_distances, 1, max_distance)
* max_lw
/ max_distance
)
else:
step_size = poses.shape[1]
cmap = plt.get_cmap(cmap)
x_world, y_world = self.world_size
if figsize is None:
figsize = (8, 8)
if ax is None:
fig, ax = plt.subplots(1, 1, figsize=figsize)
if self.path is not None:
ax.set_title("\n".join(wrap(Path(self.path).name, width=35)))
ax.set_xlim(-x_world / 2, x_world / 2)
ax.set_ylim(-y_world / 2, y_world / 2)
for fish_id in range(poses.shape[0]):
if c is None:
this_c = cmap(fish_id)
elif isinstance(c, list):
this_c = c[fish_id]
else:
this_c = c
timesteps = poses.shape[1] - 1
for t in range(0, timesteps, step_size):
if lw_distances:
lw = np.mean(line_width[t : t + step_size + 1])
ax.plot(
poses[fish_id, t : t + step_size + 1, 0],
poses[fish_id, t : t + step_size + 1, 1],
c=this_c,
lw=lw,
)
# Plotting outside of the figure to have the label
ax.plot([550, 600], [550, 600], lw=5, c=this_c, label=fish_id)
# ax.scatter(
# [poses[:, skip_timesteps, 0]],
# [poses[:, skip_timesteps, 1]],
# marker="h",
# c="black",
# s=ms,
# label="Start",
# zorder=5,
# )
ax.scatter(
[poses[:, -1, 0]],
[poses[:, -1, 1]],
marker="x",
c="black",
s=ms,
label="End",
zorder=5,
)
if legend and isinstance(legend, str):
ax.legend(legend)
elif legend:
ax.legend()
ax.set_xlabel("x [cm]")
ax.set_ylabel("y [cm]")
if show:
plt.show()
return ax
def render(self, video_path=None, **kwargs):
"""Render a video of the file.
As there are render functions in gym_guppy and robofish.trackviewer, this function is a temporary addition.
The goal should be to bring together the rendering tools."""
if video_path is not None:
try:
run(["ffmpeg"], capture_output=True)
except Exception as e:
raise Exception(
f"ffmpeg is required to store videos. Please install it.\n{e}"
)
def shape_vertices(scale=1) -> np.ndarray:
base_shape = np.array(
[
(+3.0, +0.0),
(+2.5, +1.0),
(+1.5, +1.5),
(-2.5, +1.0),
(-4.5, +0.0),
(-2.5, -1.0),
(+1.5, -1.5),
(+2.5, -1.0),
]
)
return base_shape * scale
default_options = {
"linewidth": 2,
"speedup": 1,
"trail": 100,
"entity_scale": 0.2,
"fixed_view": False,
"view_size": 50,
"margin": 15,
"slow_view": 0.8,
"slow_zoom": 0.95,
"cut_frames_start": None,
"cut_frames_end": None,
"show_text": False,
"render_goals": False,
"render_targets": False,
"dpi": 200,
"figsize": 10,
}
options = {
key: kwargs[key] if key in kwargs else default_options[key]
for key in default_options.keys()
}
fig, ax = plt.subplots(figsize=(options["figsize"], options["figsize"]))
ax.set_aspect("equal")
ax.set_facecolor("gray")
plt.tight_layout(pad=0.05)

n_entities = len(self.entities)
lines = [
    plt.plot([], [], lw=options["linewidth"], zorder=0)[0]
    for _ in range(n_entities)
]
points = [
    plt.scatter([], [], marker="x", color="k"),
    plt.plot([], [], linestyle="dotted", alpha=0.5, color="k", zorder=0)[0],
]
categories = [entity.attrs.get("category", None) for entity in self.entities]
entity_polygons = [
    patches.Polygon(shape_vertices(options["entity_scale"]), facecolor=color)
    for color in [
        "gray" if category == "robot" else "k" for category in categories
    ]
]

border_vertices = np.array(
    [
        np.array([-1, -1, 1, 1, -1]) * self.world_size[0] / 2,
        np.array([-1, 1, 1, -1, -1]) * self.world_size[1] / 2,
    ]
)

spacing = 10
x = np.arange(
    -0.5 * self.world_size[0] + spacing, 0.5 * self.world_size[0], spacing
)
y = np.arange(
    -0.5 * self.world_size[1] + spacing, 0.5 * self.world_size[1], spacing
)
xv, yv = np.meshgrid(x, y)
grid_points = plt.scatter(xv, yv, c="gray", s=1.5)

# border = plt.plot(border_vertices[0], border_vertices[1], "k")
border = patches.Polygon(border_vertices.T, facecolor="w", zorder=-1)
def title(file_frame: int) -> str:
    """Search for datasets containing text, to display it in the video."""
    output = []
    for e in self.entities:
        for key, val in e.items():
            if val.dtype == object and type(val[0]) == bytes:
                output.append(f"{e.name}.{key}='{val[file_frame].decode()}'")
    return ", ".join(output)

def get_goal(file_frame: int) -> Optional[np.ndarray]:
    """Return the current goal of the robot, if a robot exists and has a goal."""
    goal = None
    if "robot" in categories:
        robot = self.entities[categories.index("robot")]
        try:
            goal = robot["goals"][file_frame]
        except KeyError:
            pass
    if goal is not None and np.isnan(goal).any():
        goal = None
    return goal

def get_target(file_frame: int) -> Tuple[List, List]:
    """Return line points from the robot to its target."""
    if "robot" in categories:
        robot = self.entities[categories.index("robot")]
        rpos = robot["positions"][file_frame]
        target = robot["targets"][file_frame]
        return [rpos[0], target[0]], [rpos[1], target[1]]
    return [], []
def init():
    ax.set_xlim(-0.5 * self.world_size[0], 0.5 * self.world_size[0])
    ax.set_ylim(-0.5 * self.world_size[1], 0.5 * self.world_size[1])
    ax.set_xticks([])
    ax.set_xticks([], minor=True)
    ax.set_yticks([])
    ax.set_yticks([], minor=True)
    for e_poly in entity_polygons:
        ax.add_patch(e_poly)
    ax.add_patch(border)
    return lines + entity_polygons + [border] + points

n_frames = self.entity_poses.shape[1]
if options["cut_frames_end"] is None or options["cut_frames_end"] == 0:
    options["cut_frames_end"] = n_frames
if options["cut_frames_start"] is None:
    options["cut_frames_start"] = 0
frame_range = (
    options["cut_frames_start"],
    min(n_frames, options["cut_frames_end"]),
)
n_frames = int((frame_range[1] - frame_range[0]) / options["speedup"])

start_pose = self.entity_poses_rad[:, frame_range[0]]
self.middle_of_swarm = np.mean(start_pose, axis=0)
min_view = np.max((np.max(start_pose, axis=0) - np.min(start_pose, axis=0))[:2])
self.view_size = np.max([options["view_size"], min_view + options["margin"]])

if video_path is not None:
    pbar = tqdm(range(n_frames))
def update(frame):
    # pbar only exists when a video is being written.
    if video_path is not None:
        pbar.update(1)
        pbar.refresh()
    if frame < n_frames:
        entity_poses = self.entity_poses_rad
        file_frame = (frame * options["speedup"]) + frame_range[0]
        this_pose = entity_poses[:, file_frame]
        if not options["fixed_view"]:
            # Find the maximal distance between the entities in x or y direction
            min_view = np.max(
                (np.max(this_pose, axis=0) - np.min(this_pose, axis=0))[:2]
            )
            new_view_size = np.max(
                [options["view_size"], min_view + options["margin"]]
            )
            if not np.isnan(min_view).any() and not np.isnan(new_view_size):
                self.middle_of_swarm = options["slow_view"] * self.middle_of_swarm + (
                    1 - options["slow_view"]
                ) * np.mean(this_pose, axis=0)
                self.view_size = (
                    options["slow_zoom"] * self.view_size
                    + (1 - options["slow_zoom"]) * new_view_size
                )
        ax.set_xlim(
            self.middle_of_swarm[0] - self.view_size / 2,
            self.middle_of_swarm[0] + self.view_size / 2,
        )
        ax.set_ylim(
            self.middle_of_swarm[1] - self.view_size / 2,
            self.middle_of_swarm[1] + self.view_size / 2,
        )
        if options["show_text"]:
            ax.set_title(title(file_frame))
        if options["render_goals"]:
            goal = get_goal(file_frame)
            if goal is not None:
                points[0].set_offsets(goal)
        if options["render_targets"]:
            points[1].set_data(get_target(file_frame))
        poses_trails = entity_poses[
            :, max(0, file_frame - options["trail"]) : file_frame
        ]
        for i_entity in range(n_entities):
            lines[i_entity].set_data(
                poses_trails[i_entity, :, 0], poses_trails[i_entity, :, 1]
            )
            current_pose = entity_poses[i_entity, file_frame]
            t = mpl.transforms.Affine2D().translate(
                current_pose[0], current_pose[1]
            )
            r = mpl.transforms.Affine2D().rotate(current_pose[2])
            entity_polygons[i_entity].set_transform(r + t + ax.transData)
    else:
        raise Exception(f"Frame {frame} is out of range (n_frames = {n_frames}).")
    return lines + entity_polygons + [border] + points
print(f"Preparing to render n_frames: {n_frames}")

ani = animation.FuncAnimation(
    fig,
    update,
    frames=n_frames,
    init_func=init,
    blit=platform.system() != "Darwin",
    interval=1000 / self.frequency,
    repeat=False,
)

if video_path is not None:
    video_path = Path(video_path)
    if video_path.exists():
        answer = input(f"Video {video_path} exists. Overwrite? (y/n) ")
        if answer == "y":
            video_path.unlink()
    if not video_path.exists():
        print(f"Saving video to {video_path}")
        writervideo = animation.FFMpegWriter(fps=self.frequency)
        ani.save(video_path, writer=writervideo, dpi=options["dpi"])
    plt.close()
else:
    plt.show()
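When `fixed_view` is `False`, `update()` implements a follow camera: the view centre and view size are exponential moving averages that drift towards the swarm's current centre and extent, with `slow_view` and `slow_zoom` as the smoothing factors. A minimal standalone sketch of one smoothing step (plain NumPy; the function name and defaults here are illustrative, not part of the `robofish.io` API):

```python
import numpy as np

def smooth_view(middle, view_size, this_pose, view_min=50, margin=15,
                slow_view=0.8, slow_zoom=0.95):
    """One smoothing step of a follow camera (illustrative sketch).

    middle:    current view centre, shape (2,)
    view_size: current edge length of the (square) view
    this_pose: entity positions of this frame, shape (n_entities, >=2)
    """
    # Required view: largest x/y extent of the swarm plus a margin,
    # but never smaller than view_min.
    extent = np.max((np.max(this_pose, axis=0) - np.min(this_pose, axis=0))[:2])
    target_size = max(view_min, extent + margin)
    # Exponential moving averages: the old value keeps weight slow_*,
    # the new target gets weight (1 - slow_*).
    new_middle = (slow_view * np.asarray(middle)
                  + (1 - slow_view) * np.mean(this_pose[:, :2], axis=0))
    new_size = slow_zoom * view_size + (1 - slow_zoom) * target_size
    return new_middle, new_size
```

With `slow_view=0.8` the camera keeps 80% of its previous centre each frame, so it follows the swarm with a lag instead of jittering with every movement.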
Classes
class File (path: Union[str, pathlib.Path] = None, mode: str = 'r', *, world_size_cm: List[int] = None, validate: bool = False, validate_when_saving: bool = True, strict_validate: bool = False, format_version: List[int] = array([1, 0], dtype=int32), format_url: str = 'https://git.imp.fu-berlin.de/bioroboticslab/robofish/track_format/-/releases/1.0', sampling_name: str = None, frequency_hz: int = None, monotonic_time_points_us: Iterable = None, calendar_time_points: Iterable = None, open_copy: bool = False, validate_poses_hash: bool = True)
-
Represents a RoboFish Track Format file, which should be used to store tracking data of individual animals or swarms. Files can be opened (with optional creation), modified in place, and saved as copies.
When called with a path, the file is loaded; otherwise a new temporary file is created. File contents can be validated against the track format specification.
Parameters
- `path` : `str` or `Path`, optional. Location of the file to be opened. If not provided, `mode` is ignored.
- `mode` : `str`, default `'r'`. One of `'r'` (readonly, file must exist), `'r+'` (read/write, file must exist), `'w'` (create file, truncate if exists), `'x'` (create file, fail if exists), `'a'` (read/write if exists, create otherwise).
- `world_size_cm` : `[int, int]`, optional. Side lengths [x, y] of the world in cm. A rectangular world shape is assumed.
- `validate` : `bool`, default `False`. Whether the track should be validated on loading. This is normally switched off for performance reasons.
- `validate_when_saving` : `bool`, default `True`. Whether the file should be validated before it is saved or closed.
- `strict_validate` : `bool`, default `False`. Whether the file should be strictly validated against the track format specification when loaded from a path. TODO: Should this validate against the version specified in `format_version` or just against the most recent version?
- `format_version` : `[int, int]`, default `[1, 0]`. Version [major, minor] of the track format specification.
- `format_url` : `str`, default `"https://git.imp.fu-berlin.de/bioroboticslab/robofish/track_format/-/releases/1.0"`. Location of the track format specification. Should match `format_version`.
- `sampling_name` : `str`, optional. How to specify your sampling:
    1. (optional) Provide a text description of your sampling in `sampling_name`.
    2. a) (mandatory if you have a constant sampling frequency) Specify `frequency_hz` with your sampling frequency in Hz.
       b) (mandatory if you do NOT have a constant sampling frequency) Specify `monotonic_time_points_us` with a list[1] of time points in microseconds on a monotonic clock, one for each sample in your dataset.
    3. (optional) Specify `calendar_time_points` with a list[2] of time points in the ISO 8601 extended format with microsecond precision and time zone designator[3], one for each sample in your dataset.

    [1] any Iterable of int, [2] any Iterable of str, [3] example: `"2020-11-18T13:21:34.117015+01:00"`
- `frequency_hz` : `int`, optional. Refer to the explanation of `sampling_name`.
- `monotonic_time_points_us` : `Iterable` of `int`, optional. Refer to the explanation of `sampling_name`.
- `calendar_time_points` : `Iterable` of `str`, optional. Refer to the explanation of `sampling_name`.
- `open_copy` : `bool`, optional. If `True`, a temporary copy of the file is opened instead of the file itself.
- `validate_poses_hash` : `bool`, default `True`. Whether the hash stored with pre-calculated poses data should be checked against a newly calculated hash when the file is opened.
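The three sampling variants above are related: for a constant frequency, `frequency_hz` and `monotonic_time_points_us` encode the same information, and `calendar_time_points` adds absolute wall-clock times. A small sketch (plain Python/NumPy, independent of `robofish.io`) of how the representations correspond:

```python
import datetime
import numpy as np

frequency_hz = 25.0
n_samples = 100

# Variant 2.b: monotonic clock time points in microseconds, one per sample.
# At 25 Hz, consecutive samples are 40000 us apart.
monotonic_time_points_us = (
    np.arange(n_samples) * 1_000_000 / frequency_hz
).astype(np.int64)

# A constant frequency can be recovered from equally spaced time points.
diffs = np.diff(monotonic_time_points_us)
assert np.all(diffs == diffs[0])
recovered_hz = 1e6 / diffs[0]

# Variant 3: calendar time points, ISO 8601 extended format with
# microsecond precision and a time zone designator.
start = datetime.datetime(
    2020, 11, 18, 13, 21, 34, 117015,
    tzinfo=datetime.timezone(datetime.timedelta(hours=1)),
)
calendar_time_points = [
    (start + datetime.timedelta(microseconds=int(us))).isoformat(
        timespec="microseconds"
    )
    for us in monotonic_time_points_us
]
```

In practice one of these would be passed to `robofish.io.File(...)` as described above; only the constant-frequency case allows `frequency_hz` alone.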
np.array([-1, 1, 1, -1, -1]) * self.world_size[1] / 2, ] ) spacing = 10 x = np.arange( -0.5 * self.world_size[0] + spacing, 0.5 * self.world_size[0], spacing ) y = np.arange( -0.5 * self.world_size[1] + spacing, 0.5 * self.world_size[1], spacing ) xv, yv = np.meshgrid(x, y) grid_points = plt.scatter(xv, yv, c="gray", s=1.5) # border = plt.plot(border_vertices[0], border_vertices[1], "k") border = patches.Polygon(border_vertices.T, facecolor="w", zorder=-1) def title(file_frame: int) -> str: """Search for datasets containing text for displaying it in the video""" output = [] for e in self.entities: for key, val in e.items(): if val.dtype == object and type(val[0]) == bytes: output.append(f"{e.name}.{key}='{val[file_frame].decode()}'") return ", ".join(output) def get_goal(file_frame: int) -> Optional[np.ndarray]: """Return current goal of robot, if robot exists and has a goal.""" goal = None if "robot" in categories: robot = self.entities[categories.index("robot")] try: goal = robot["goals"][file_frame] except KeyError: pass if goal is not None and np.isnan(goal).any(): goal = None return goal def get_target(file_frame: int) -> Tuple[List, List]: """Return line points from robot to target""" if "robot" in categories: robot = self.entities[categories.index("robot")] rpos = robot["positions"][file_frame] target = robot["targets"][file_frame] return [rpos[0], target[0]], [rpos[1], target[1]] return [], [] def init(): ax.set_xlim(-0.5 * self.world_size[0], 0.5 * self.world_size[0]) ax.set_ylim(-0.5 * self.world_size[1], 0.5 * self.world_size[1]) ax.set_xticks([]) ax.set_xticks([], minor=True) ax.set_yticks([]) ax.set_yticks([], minor=True) for e_poly in entity_polygons: ax.add_patch(e_poly) ax.add_patch(border) return lines + entity_polygons + [border] + points n_frames = self.entity_poses.shape[1] if options["cut_frames_end"] == 0 or options["cut_frames_end"] is None: options["cut_frames_end"] = n_frames if options["cut_frames_start"] is None: 
options["cut_frames_start"] = 0 frame_range = ( options["cut_frames_start"], min(n_frames, options["cut_frames_end"]), ) n_frames = int((frame_range[1] - frame_range[0]) / options["speedup"]) start_pose = self.entity_poses_rad[:, frame_range[0]] self.middle_of_swarm = np.mean(start_pose, axis=0) min_view = np.max((np.max(start_pose, axis=0) - np.min(start_pose, axis=0))[:2]) self.view_size = np.max([options["view_size"], min_view + options["margin"]]) if video_path is not None: pbar = tqdm(range(n_frames)) def update(frame): if "pbar" in locals().keys(): pbar.update(1) pbar.refresh() if frame < n_frames: entity_poses = self.entity_poses_rad file_frame = (frame * options["speedup"]) + frame_range[0] this_pose = entity_poses[:, file_frame] if not options["fixed_view"]: # Find the maximal distance between the entities in x or y direction min_view = np.max( (np.max(this_pose, axis=0) - np.min(this_pose, axis=0))[:2] ) new_view_size = np.max( [options["view_size"], min_view + options["margin"]] ) if not np.isnan(min_view).any() and not new_view_size is np.nan: self.middle_of_swarm = options[ "slow_view" ] * self.middle_of_swarm + (1 - options["slow_view"]) * np.mean( this_pose, axis=0 ) self.view_size = ( options["slow_zoom"] * self.view_size + (1 - options["slow_zoom"]) * new_view_size ) ax.set_xlim( self.middle_of_swarm[0] - self.view_size / 2, self.middle_of_swarm[0] + self.view_size / 2, ) ax.set_ylim( self.middle_of_swarm[1] - self.view_size / 2, self.middle_of_swarm[1] + self.view_size / 2, ) if options["show_text"]: ax.set_title(title(file_frame)) if options["render_goals"]: goal = get_goal(file_frame) if goal is not None: points[0].set_offsets(goal) if options["render_targets"]: points[1].set_data(get_target(file_frame)) poses_trails = entity_poses[ :, max(0, file_frame - options["trail"]) : file_frame ] for i_entity in range(n_entities): lines[i_entity].set_data( poses_trails[i_entity, :, 0], poses_trails[i_entity, :, 1] ) current_pose = entity_poses[i_entity, 
file_frame] t = mpl.transforms.Affine2D().translate( current_pose[0], current_pose[1] ) r = mpl.transforms.Affine2D().rotate(current_pose[2]) tra = r + t + ax.transData entity_polygons[i_entity].set_transform(tra) else: raise Exception( f"Frame is bigger than n_frames {file_frame} of {n_frames}" ) return lines + entity_polygons + [border] + points print(f"Preparing to render n_frames: {n_frames}") ani = animation.FuncAnimation( fig, update, frames=n_frames, init_func=init, blit=platform.system() != "Darwin", interval=1000 / self.frequency, repeat=False, ) if video_path is not None: # if i % (n / 40) == 0: # print(f"Saving frame {i} of {n} ({100*i/n:.1f}%)") video_path = Path(video_path) if video_path.exists(): y = input(f"Video {str(video_path)} exists. Overwrite? (y/n)") if y == "y": video_path.unlink() if not video_path.exists(): print(f"saving video to {video_path}") writervideo = animation.FFMpegWriter(fps=self.frequency) ani.save(video_path, writer=writervideo, dpi=options["dpi"]) plt.close() else: plt.show()
Ancestors
- h5py._hl.files.File
- h5py._hl.group.Group
- h5py._hl.base.HLObject
- h5py._hl.base.CommonStateObject
- h5py._hl.base.MutableMappingHDF5
- h5py._hl.base.MappingHDF5
- collections.abc.MutableMapping
- collections.abc.Mapping
- collections.abc.Collection
- collections.abc.Sized
- collections.abc.Iterable
- collections.abc.Container
Instance variables
var default_sampling
-
Expand source code
@property
def default_sampling(self):
    assert (
        "samplings" in self
    ), "The file does not have a group 'sampling' which is required."
    if "default" in self["samplings"].attrs:
        return self["samplings"].attrs["default"]
    return None
var entities
-
Expand source code
@property
def entities(self):
    return [
        robofish.io.Entity.from_h5py_group(self["entities"][name])
        for name in self.entity_names
    ]
var entity_actions_speeds_turns
-
Calculate the speed and turn from the recorded positions and orientations.
The turn is calculated by the change of orientation between frames. The speed is calculated by the displacement between consecutive positions, projected onto the new orientation vector. Sideways changes of position cannot be represented with this method.
Returns
An array with shape (number_of_entities, number_of_positions - 1, 2), where the last axis holds the speed in cm/frame and the turn in rad/frame.
Expand source code
@property
def entity_actions_speeds_turns(self):
    """Calculate the speed and turn from the recorded positions and orientations.

    The turn is calculated by the change of orientation between frames. The
    speed is calculated by the displacement between consecutive positions,
    projected onto the new orientation vector. Sideways changes of position
    cannot be represented with this method.

    Returns:
        An array with shape (number_of_entities, number_of_positions - 1, 2),
        where the last axis holds the speed in cm/frame and the turn in
        rad/frame.
    """
    return self.select_entity_property(
        None, entity_property=Entity.actions_speeds_turns
    )
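The calculation described above can be sketched with plain NumPy. This is an illustration of the documented math, not the library's implementation; the helper name `speeds_turns` is made up for the example:

```python
import numpy as np

def speeds_turns(positions, orientations_rad):
    """Sketch: the turn is the change of orientation between frames; the
    speed is the displacement between consecutive positions, projected
    onto the new orientation vector."""
    turns = np.diff(orientations_rad)
    displacement = np.diff(positions, axis=0)  # shape (T - 1, 2)
    new_ori = np.stack(
        [np.cos(orientations_rad[1:]), np.sin(orientations_rad[1:])], axis=1
    )
    # Projection of the displacement onto the new orientation (dot product)
    speeds = np.sum(displacement * new_ori, axis=1)
    return np.stack([speeds, turns], axis=1)  # shape (T - 1, 2)

# A fish moving 1 cm per frame along +x, without turning:
positions = np.array([[0.0, 0.0], [1.0, 0.0], [2.0, 0.0]])
orientations = np.zeros(3)
actions = speeds_turns(positions, orientations)
```

Because the displacement is projected onto the orientation, a fish swimming backwards yields a negative speed rather than a flipped orientation.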
var entity_names : Iterable[str]
-
Getter for the names of all entities
Returns
Array of all names.
Expand source code
@property
def entity_names(self) -> Iterable[str]:
    """Getter for the names of all entities.

    Returns:
        Array of all names.
    """
    return sorted(self["entities"].keys())
var entity_orientations
-
Expand source code
@property
def entity_orientations(self):
    return self.select_entity_property(None, entity_property=Entity.orientations)
var entity_orientations_rad
-
Expand source code
@property
def entity_orientations_rad(self):
    return self.select_entity_property(
        None, entity_property=Entity.orientations_rad
    )
var entity_poses
-
Expand source code
@property
def entity_poses(self):
    return self.select_entity_property(None, entity_property=Entity.poses)
var entity_poses_calc_ori
-
Deprecated since version: 0.2
This will be removed in 0.2.4. We found that our calculation of 'poses_calc_ori' is flawed. Please replace it with 'poses' and use the tracked orientation. If you see this message and you don't know what to do, update all packages and if nothing helps, contact Andi.
Don't ignore this warning, it's a serious issue.
Expand source code
@property
@deprecation.deprecated(
    deprecated_in="0.2",
    removed_in="0.2.4",
    details="We found that our calculation of 'poses_calc_ori' is flawed. "
    "Please replace it with 'poses' and use the tracked orientation. "
    "If you see this message and you don't know what to do, update all packages and if nothing helps, contact Andi.\n"
    "Don't ignore this warning, it's a serious issue.",
)
def entity_poses_calc_ori(self):
    return self.select_entity_property(None, entity_property=Entity.poses_calc_ori)
var entity_poses_calc_ori_rad
-
Deprecated since version: 0.2
This will be removed in 0.2.4. We found that our calculation of 'poses_calc_ori_rad' is flawed. Please replace it with 'poses_rad' and use the tracked orientation. If you see this message and you don't know what to do, update all packages and if nothing helps, contact Andi.
Don't ignore this warning, it's a serious issue.
Expand source code
@property
@deprecation.deprecated(
    deprecated_in="0.2",
    removed_in="0.2.4",
    details="We found that our calculation of 'poses_calc_ori_rad' is flawed. "
    "Please replace it with 'poses_rad' and use the tracked orientation. "
    "If you see this message and you don't know what to do, update all packages and if nothing helps, contact Andi.\n"
    "Don't ignore this warning, it's a serious issue.",
)
def entity_poses_calc_ori_rad(self):
    return self.select_entity_property(
        None, entity_property=Entity.poses_calc_ori_rad
    )
var entity_poses_rad
-
Expand source code
@property
def entity_poses_rad(self):
    return self.select_entity_property(None, entity_property=Entity.poses_rad)
var entity_positions
-
Expand source code
@property
def entity_positions(self):
    return self.select_entity_property(None, entity_property=Entity.positions)
var entity_speeds_turns
-
Deprecated since version: 0.2
This will be removed in 0.2.4. We found that our calculation of 'entity_speeds_turns' is flawed and replaced it with 'entity_actions_speeds_turns'. The difference is that the tracked orientation is now used, which gives the fish the ability to swim backwards. If you see this message and you don't know what to do, update all packages and if nothing helps, contact Andi.
Don't ignore this warning, it's a serious issue.
Expand source code
@property
@deprecation.deprecated(
    deprecated_in="0.2",
    removed_in="0.2.4",
    details="We found that our calculation of 'entity_speeds_turns' is flawed and replaced it "
    "with 'entity_actions_speeds_turns'. The difference in calculation is, that the tracked "
    "orientation is used now which gives the fish the ability to swim backwards. "
    "If you see this message and you don't know what to do, update all packages and if nothing helps, contact Andi.\n"
    "Don't ignore this warning, it's a serious issue.",
)
def entity_speeds_turns(self):
    return self.select_entity_property(None, entity_property=Entity.speed_turn)
var frequency
-
Expand source code
@property
def frequency(self):
    common_sampling = self.common_sampling()
    assert common_sampling is not None, "The sampling differs between entities."
    assert (
        "frequency_hz" in common_sampling.attrs
    ), "The common sampling has no frequency_hz"
    return common_sampling.attrs["frequency_hz"]
var temp_dir
-
Expand source code
@property
def temp_dir(self):
    cla = type(self)
    if cla._temp_dir is None:
        cla._temp_dir = tempfile.TemporaryDirectory(prefix="robofish-io-")
    return Path(cla._temp_dir.name)
var world_size
-
Expand source code
@property
def world_size(self):
    return self.attrs["world_size_cm"]
Methods
def clear_calculated_data(self, verbose=True)
-
Delete all calculated data from the files.
Expand source code
def clear_calculated_data(self, verbose=True):
    """Delete all calculated data from the files."""
    txt = ""
    for e in self.entities:
        txt += f"Deleting from {e}. Attrs: ["
        for a in ["poses_hash"]:
            if a in e.attrs:
                del e.attrs[a]
                txt += f"{a}, "
        txt = txt[:-2] + "] Datasets: ["
        for g in ["calculated_actions_speeds_turns", "calculated_orientations_rad"]:
            if g in e:
                del e[g]
                txt += f"{g}, "
        txt = txt[:-2] + "]\n"
    if verbose:
        print(txt[:-1])
def close(self)
-
Close the file. All open objects become invalid
Expand source code
def close(self):
    if self.mode != "r":
        self.update_calculated_data()
    super().close()
def common_sampling(self, entities: Iterable[ForwardRef('robofish.io.Entity')] = None) ‑> h5py._hl.group.Group
-
Check if all entities have the same sampling.
Args
entities
- optional array of entities. If None is given, all entities are checked.
Returns
The h5py group of the common sampling. If there is no common sampling, None will be returned.
Expand source code
def common_sampling(
    self, entities: Iterable["robofish.io.Entity"] = None
) -> h5py.Group:
    """Check if all entities have the same sampling.

    Args:
        entities: optional array of entities. If None is given, all entities
            are checked.
    Returns:
        The h5py group of the common sampling. If there is no common sampling,
        None will be returned.
    """
    custom_sampling = None
    for entity in self.entities:
        if "sampling" in entity["positions"].attrs:
            this_sampling = entity["positions"].attrs["sampling"]
            if custom_sampling is None:
                custom_sampling = this_sampling
            elif custom_sampling != this_sampling:
                return None
    sampling = self.default_sampling if custom_sampling is None else custom_sampling
    return self["samplings"][sampling]
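The agreement check can be sketched in plain Python. The helper `common_sampling_name` and its arguments are illustrative, not part of the API; it mirrors the documented rule that entities may carry a custom sampling reference, and that differing references mean there is no common sampling:

```python
def common_sampling_name(entity_samplings, default_sampling):
    """Sketch: None entries fall back to the file's default sampling;
    two different custom samplings mean there is no common one."""
    custom = None
    for s in entity_samplings:
        if s is not None:
            if custom is None:
                custom = s
            elif custom != s:
                return None  # samplings differ between entities
    return default_sampling if custom is None else custom

# No entity overrides the default:
all_default = common_sampling_name([None, None], "25 hz")
# Two entities disagree -> no common sampling:
conflict = common_sampling_name(["25 hz", "30 hz"], "25 hz")
```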
def create_entity(self, category: str, poses: Iterable = None, name: str = None, positions: Iterable = None, orientations: Iterable = None, outlines: Iterable = None, sampling: str = None) ‑> str
-
Creates a new single entity.
Args
- TODO
category
- the category of the entity. The canonical values are ['organism', 'robot', 'obstacle'].
poses
- optional two dimensional array, containing the poses of the entity (x,y,orientation_x, orientation_y).
poses_rad
- optional two dimensional array, containing the poses of the entity (x, y, orientation_rad).
name
- optional name of the entity. If no name is given, the category is used with an id (e.g. 'fish_1').
outlines
- optional three dimensional array, containing the outlines of the entity
Returns
Name of the created entity
Expand source code
def create_entity(
    self,
    category: str,
    poses: Iterable = None,
    name: str = None,
    positions: Iterable = None,
    orientations: Iterable = None,
    outlines: Iterable = None,
    sampling: str = None,
) -> str:
    """Creates a new single entity.

    Args:
        TODO
        category: the category of the entity. The canonical values are
            ['organism', 'robot', 'obstacle'].
        poses: optional two dimensional array, containing the poses of the
            entity (x, y, orientation_x, orientation_y).
        poses_rad: optional two dimensional array, containing the poses of
            the entity (x, y, orientation_rad).
        name: optional name of the entity. If no name is given, the category
            is used with an id (e.g. 'fish_1').
        outlines: optional three dimensional array, containing the outlines
            of the entity.
    Returns:
        Name of the created entity
    """
    if sampling is None and self.default_sampling is None:
        raise Exception(
            "There was no sampling specified, when creating the file, nor when creating the entity."
        )
    entity = robofish.io.Entity.create_entity(
        self["entities"],
        category,
        poses,
        name,
        positions,
        orientations,
        outlines,
        sampling,
    )
    return entity
def create_multiple_entities(self, category: str, poses: Iterable, names: Iterable[str] = None, outlines=None, sampling=None) ‑> Iterable
-
Creates multiple entities.
Args
category
- The common category for the entities. The canonical values are ['organism', 'robot', 'obstacle'].
poses
- three dimensional array, containing the poses of the entities.
names
- optional array of names of the entities. If no names are given, the category is used with an id (e.g. 'fish_1')
outlines
- optional array, containing the outlines of the entities. Either a three dimensional common outline array can be given, or a four dimensional array.
sampling
- The string reference to the sampling. If none is given, the default sampling from creating the file is used.
Returns
Array of names of the created entities
Expand source code
def create_multiple_entities(
    self,
    category: str,
    poses: Iterable,
    names: Iterable[str] = None,
    outlines=None,
    sampling=None,
) -> Iterable:
    """Creates multiple entities.

    Args:
        category: The common category for the entities. The canonical values
            are ['organism', 'robot', 'obstacle'].
        poses: three dimensional array, containing the poses of the entities.
        names: optional array of names of the entities. If no names are given,
            the category is used with an id (e.g. 'fish_1').
        outlines: optional array, containing the outlines of the entities.
            Either a three dimensional common outline array can be given, or
            a four dimensional array.
        sampling: The string reference to the sampling. If none is given, the
            default sampling from creating the file is used.
    Returns:
        Array of names of the created entities
    """
    assert (
        poses.ndim == 3
    ), f"A 3 dimensional array was expected (entity, timestep, 3). There were {poses.ndim} dimensions in poses: {poses.shape}"
    assert poses.shape[2] in [3, 4]
    agents = poses.shape[0]
    entity_names = []
    for i in range(agents):
        e_name = None if names is None else names[i]
        e_outline = (
            outlines if outlines is None or outlines.ndim == 3 else outlines[i]
        )
        entity_names.append(
            self.create_entity(
                category=category,
                sampling=sampling,
                poses=poses[i],
                name=e_name,
                outlines=e_outline,
            )
        )
    return entity_names
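The shape contract enforced by the asserts in the source can be checked before calling. A minimal sketch with pure NumPy (no file involved); the array here is dummy data for illustration:

```python
import numpy as np

# Poses for 3 entities over 100 timesteps; the last axis holds either
# 3 values (x, y, orientation_rad) or 4 values (x, y, ori_x, ori_y).
poses = np.zeros((3, 100, 3))

# The same checks create_multiple_entities performs on its input:
assert poses.ndim == 3, "expected a (entity, timestep, pose) array"
assert poses.shape[2] in (3, 4)

# Internally, each entity is then created from its own 2-d slice:
single_entity_poses = poses[0]
```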
def create_sampling(self, name: str = None, frequency_hz: int = None, monotonic_time_points_us: Iterable = None, calendar_time_points: Iterable = None, default: bool = False)
-
Expand source code
def create_sampling(
    self,
    name: str = None,
    frequency_hz: int = None,
    monotonic_time_points_us: Iterable = None,
    calendar_time_points: Iterable = None,
    default: bool = False,
):
    # Find Name for sampling if none is given
    if name is None:
        if frequency_hz is not None:
            name = "%d hz" % frequency_hz
        i = 1
        while name is None or name in self["samplings"]:
            name = "sampling_%d" % i
            i += 1
    sampling = self["samplings"].create_group(name)
    if monotonic_time_points_us is not None:
        monotonic_time_points_us = np.array(
            monotonic_time_points_us, dtype=np.int64
        )
        sampling.create_dataset(
            "monotonic_time_points_us", data=monotonic_time_points_us
        )
        if frequency_hz is None:
            diff = np.diff(monotonic_time_points_us)
            if np.all(diff == diff[0]) and diff[0] > 0:
                frequency_hz = 1e6 / diff[0]
                warnings.warn(
                    f"The frequency_hz of {frequency_hz:.2f}hz was calculated automatically by robofish.io. The safer variant is to pass it using frequency_hz.\nThis is important when using fish_models with the files."
                )
            else:
                warnings.warn(
                    "The frequency_hz could not be calculated automatically. When using fish_models, the file will access frequency_hz."
                )
    if frequency_hz is not None:
        sampling.attrs["frequency_hz"] = (np.float32)(frequency_hz)
    if calendar_time_points is not None:

        def format_calendar_time_point(p):
            if isinstance(p, datetime.datetime):
                assert p.tzinfo is not None, "Missing timezone for calendar point."
                return p.isoformat(timespec="microseconds")
            elif isinstance(p, str):
                assert p == datetime.datetime.fromisoformat(p).isoformat(
                    timespec="microseconds"
                )
                return p
            else:
                assert (
                    False
                ), "Calendar points must be datetime.datetime instances or strings."

        calendar_time_points = [
            format_calendar_time_point(p) for p in calendar_time_points
        ]
        sampling.create_dataset(
            "calendar_time_points",
            data=calendar_time_points,
            dtype=h5py.string_dtype(encoding="utf-8"),
        )
    if default:
        self["samplings"].attrs["default"] = name
    return name
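When only monotonic_time_points_us is given, the frequency is derived from the spacing of the time points, provided they are evenly spaced. A minimal sketch of that derivation, using dummy timestamps for illustration:

```python
import numpy as np

# Monotonic time points in microseconds, evenly spaced at 40 ms
# (the spacing a 25 hz recording would produce):
t_us = np.array([0, 40_000, 80_000, 120_000], dtype=np.int64)

diff = np.diff(t_us)
if np.all(diff == diff[0]) and diff[0] > 0:
    # Even spacing lets the frequency be derived from one interval
    frequency_hz = 1e6 / diff[0]
```

Passing frequency_hz explicitly is the safer variant; the automatic derivation only works (and warns) when the spacing is perfectly uniform.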
def plot(self, ax=None, lw_distances=False, lw=2, ms=32, figsize=None, step_size=4, c=None, cmap='Set1', skip_timesteps=0, max_timesteps=None, show=False, legend=True)
-
Plot the file using matplotlib.pyplot
The tracks in the file are plotted using matplotlib.plot().
Args
ax
:matplotlib.axes
, optional- An axes object to plot in. If None is given, a new figure is created.
lw_distances
:bool
, optional- Flag to show the distances between individuals through line width.
figsize
:Tuple[int]
, optional- Size of a newly created figure.
step_size
:int
, optional- when using lw_distances, the track is split into sections which have a common line width. This parameter defines the length of the sections.
c
:Array[color_representation]
, optional- An array of colors. Each item has to be matplotlib.colors.is_color_like(item).
cmap
:matplotlib.colors.Colormap
, optional- The colormap to use
skip_timesteps
:int
, optional- Skip timesteps at the beginning of the file
max_timesteps
:int
, optional- Cut off timesteps at the end of the file.
show
:bool
, optional- Show the created plot.
Returns
matplotlib.axes
- The axes object with the plot.
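The mapping from inter-fish distance to line width used by lw_distances can be sketched as follows. The helper name `distance_line_widths` is made up for the example; the constants max_distance = 10 and max_lw = 4 are the "magic numbers" from the source, where anything at max_distance or further away collapses to the minimum width:

```python
import numpy as np

def distance_line_widths(min_distances, max_distance=10, max_lw=4):
    """Sketch of the lw_distances mapping: close pairs of fish get thick
    lines; distances >= max_distance clamp to the minimum width."""
    return (
        np.clip(max_distance - min_distances, 1, max_distance)
        * max_lw
        / max_distance
    )

# Distances of 0 cm, 5 cm, and 20 cm between the nearest fish:
widths = distance_line_widths(np.array([0.0, 5.0, 20.0]))
```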
def render(self, video_path=None, **kwargs)
-
Render a video of the file.
As there are render functions in gym_guppy and robofish.trackviewer, this function is a temporary addition. The goal should be to bring together the rendering tools.
Expand source code
def render(self, video_path=None, **kwargs): """Render a video of the file. As there are render functions in gym_guppy and robofish.trackviewer, this function is a temporary addition. The goal should be to bring together the rendering tools.""" if video_path is not None: try: run(["ffmpeg"], capture_output=True) except Exception as e: raise Exception( f"ffmpeg is required to store videos. Please install it.\n{e}" ) def shape_vertices(scale=1) -> np.ndarray: base_shape = np.array( [ (+3.0, +0.0), (+2.5, +1.0), (+1.5, +1.5), (-2.5, +1.0), (-4.5, +0.0), (-2.5, -1.0), (+1.5, -1.5), (+2.5, -1.0), ] ) return base_shape * scale default_options = { "linewidth": 2, "speedup": 1, "trail": 100, "entity_scale": 0.2, "fixed_view": False, "view_size": 50, "margin": 15, "slow_view": 0.8, "slow_zoom": 0.95, "cut_frames_start": None, "cut_frames_end": None, "show_text": False, "render_goals": False, "render_targets": False, "dpi": 200, "figsize": 10, } options = { key: kwargs[key] if key in kwargs else default_options[key] for key in default_options.keys() } fig, ax = plt.subplots(figsize=(options["figsize"], options["figsize"])) ax.set_aspect("equal") ax.set_facecolor("gray") plt.tight_layout(pad=0.05) n_entities = len(self.entities) lines = [ plt.plot([], [], lw=options["linewidth"], zorder=0)[0] for _ in range(n_entities) ] points = [ plt.scatter([], [], marker="x", color="k"), plt.plot([], [], linestyle="dotted", alpha=0.5, color="k", zorder=0)[0], ] categories = [entity.attrs.get("category", None) for entity in self.entities] entity_polygons = [ patches.Polygon(shape_vertices(options["entity_scale"]), facecolor=color) for color in [ "gray" if category == "robot" else "k" for category in categories ] ] border_vertices = np.array( [ np.array([-1, -1, 1, 1, -1]) * self.world_size[0] / 2, np.array([-1, 1, 1, -1, -1]) * self.world_size[1] / 2, ] ) spacing = 10 x = np.arange( -0.5 * self.world_size[0] + spacing, 0.5 * self.world_size[0], spacing ) y = np.arange( -0.5 * 
        self.world_size[1] + spacing, 0.5 * self.world_size[1], spacing
    )
    xv, yv = np.meshgrid(x, y)
    grid_points = plt.scatter(xv, yv, c="gray", s=1.5)

    # border = plt.plot(border_vertices[0], border_vertices[1], "k")
    border = patches.Polygon(border_vertices.T, facecolor="w", zorder=-1)

    def title(file_frame: int) -> str:
        """Search for datasets containing text for displaying it in the video"""
        output = []
        for e in self.entities:
            for key, val in e.items():
                if val.dtype == object and type(val[0]) == bytes:
                    output.append(f"{e.name}.{key}='{val[file_frame].decode()}'")
        return ", ".join(output)

    def get_goal(file_frame: int) -> Optional[np.ndarray]:
        """Return current goal of robot, if robot exists and has a goal."""
        goal = None
        if "robot" in categories:
            robot = self.entities[categories.index("robot")]
            try:
                goal = robot["goals"][file_frame]
            except KeyError:
                pass
        if goal is not None and np.isnan(goal).any():
            goal = None
        return goal

    def get_target(file_frame: int) -> Tuple[List, List]:
        """Return line points from robot to target"""
        if "robot" in categories:
            robot = self.entities[categories.index("robot")]
            rpos = robot["positions"][file_frame]
            target = robot["targets"][file_frame]
            return [rpos[0], target[0]], [rpos[1], target[1]]
        return [], []

    def init():
        ax.set_xlim(-0.5 * self.world_size[0], 0.5 * self.world_size[0])
        ax.set_ylim(-0.5 * self.world_size[1], 0.5 * self.world_size[1])
        ax.set_xticks([])
        ax.set_xticks([], minor=True)
        ax.set_yticks([])
        ax.set_yticks([], minor=True)
        for e_poly in entity_polygons:
            ax.add_patch(e_poly)
        ax.add_patch(border)
        return lines + entity_polygons + [border] + points

    n_frames = self.entity_poses.shape[1]

    if options["cut_frames_end"] == 0 or options["cut_frames_end"] is None:
        options["cut_frames_end"] = n_frames
    if options["cut_frames_start"] is None:
        options["cut_frames_start"] = 0
    frame_range = (
        options["cut_frames_start"],
        min(n_frames, options["cut_frames_end"]),
    )
    n_frames = int((frame_range[1] - frame_range[0]) / options["speedup"])

    start_pose = self.entity_poses_rad[:, frame_range[0]]
    self.middle_of_swarm = np.mean(start_pose, axis=0)
    min_view = np.max((np.max(start_pose, axis=0) - np.min(start_pose, axis=0))[:2])
    self.view_size = np.max([options["view_size"], min_view + options["margin"]])

    if video_path is not None:
        pbar = tqdm(range(n_frames))

    def update(frame):
        if "pbar" in locals().keys():
            pbar.update(1)
            pbar.refresh()
        if frame < n_frames:
            entity_poses = self.entity_poses_rad
            file_frame = (frame * options["speedup"]) + frame_range[0]
            this_pose = entity_poses[:, file_frame]

            if not options["fixed_view"]:
                # Find the maximal distance between the entities in x or y direction
                min_view = np.max(
                    (np.max(this_pose, axis=0) - np.min(this_pose, axis=0))[:2]
                )
                new_view_size = np.max(
                    [options["view_size"], min_view + options["margin"]]
                )
                if not np.isnan(min_view).any() and not np.isnan(new_view_size):
                    self.middle_of_swarm = options["slow_view"] * self.middle_of_swarm + (
                        1 - options["slow_view"]
                    ) * np.mean(this_pose, axis=0)
                    self.view_size = (
                        options["slow_zoom"] * self.view_size
                        + (1 - options["slow_zoom"]) * new_view_size
                    )
                ax.set_xlim(
                    self.middle_of_swarm[0] - self.view_size / 2,
                    self.middle_of_swarm[0] + self.view_size / 2,
                )
                ax.set_ylim(
                    self.middle_of_swarm[1] - self.view_size / 2,
                    self.middle_of_swarm[1] + self.view_size / 2,
                )

            if options["show_text"]:
                ax.set_title(title(file_frame))
            if options["render_goals"]:
                goal = get_goal(file_frame)
                if goal is not None:
                    points[0].set_offsets(goal)
            if options["render_targets"]:
                points[1].set_data(get_target(file_frame))

            poses_trails = entity_poses[
                :, max(0, file_frame - options["trail"]) : file_frame
            ]
            for i_entity in range(n_entities):
                lines[i_entity].set_data(
                    poses_trails[i_entity, :, 0], poses_trails[i_entity, :, 1]
                )
                current_pose = entity_poses[i_entity, file_frame]
                t = mpl.transforms.Affine2D().translate(
                    current_pose[0], current_pose[1]
                )
                r = mpl.transforms.Affine2D().rotate(current_pose[2])
                tra = r + t + ax.transData
                entity_polygons[i_entity].set_transform(tra)
        else:
            raise Exception(
                f"Frame is bigger than n_frames {file_frame} of {n_frames}"
            )
        return lines + entity_polygons + [border] + points

    print(f"Preparing to render n_frames: {n_frames}")
    ani = animation.FuncAnimation(
        fig,
        update,
        frames=n_frames,
        init_func=init,
        blit=platform.system() != "Darwin",
        interval=1000 / self.frequency,
        repeat=False,
    )

    if video_path is not None:
        # if i % (n / 40) == 0:
        #     print(f"Saving frame {i} of {n} ({100*i/n:.1f}%)")
        video_path = Path(video_path)
        if video_path.exists():
            y = input(f"Video {str(video_path)} exists. Overwrite? (y/n)")
            if y == "y":
                video_path.unlink()
        if not video_path.exists():
            print(f"saving video to {video_path}")
            writervideo = animation.FFMpegWriter(fps=self.frequency)
            ani.save(video_path, writer=writervideo, dpi=options["dpi"])
        plt.close()
    else:
        plt.show()
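The rendering code above follows matplotlib's `FuncAnimation` callback pattern: an `init` function that prepares the axes and returns the artists to be blitted, and an `update(frame)` function that mutates those artists per frame. A minimal, self-contained sketch of that pattern (all names here are illustrative, not part of robofish.io; the Agg backend is used so no window opens):

```python
import matplotlib

matplotlib.use("Agg")  # headless backend; nothing is displayed
import matplotlib.pyplot as plt
import numpy as np
from matplotlib import animation

fig, ax = plt.subplots()
(line,) = ax.plot([], [], lw=2)
n_frames = 25
xs = np.linspace(0, 2 * np.pi, 100)


def init():
    # Called once before the first frame; returns the artists to blit
    ax.set_xlim(0, 2 * np.pi)
    ax.set_ylim(-1.1, 1.1)
    return [line]


def update(frame):
    # Called once per frame; mutate the artists and return them
    line.set_data(xs, np.sin(xs + frame / n_frames * 2 * np.pi))
    return [line]


ani = animation.FuncAnimation(
    fig, update, frames=n_frames, init_func=init, interval=1000 / 25, repeat=False
)
```

As in the robofish.io renderer, the animation can then either be shown with `plt.show()` or written to disk with an `FFMpegWriter`.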
def save_as(self, path: Union[str, pathlib.Path], strict_validate: bool = True, no_warning: bool = False)
-
Save a copy of the file
Args
path
- path to an io file as a string or path object. If no path is specified, the last known path (from loading or saving) is used.
strict_validate
- optional boolean, whether the file should be strictly validated before saving. The default is True.
no_warning
- optional boolean, to suppress the warning emitted by the function.
Returns
None. The function closes the file in order to copy it; reload the file if you need it afterwards.
Expand source code
def save_as(
    self,
    path: Union[str, Path],
    strict_validate: bool = True,
    no_warning: bool = False,
):
    """Save a copy of the file

    Args:
        path: path to an io file as a string or path object. If no path is
            specified, the last known path (from loading or saving) is used.
        strict_validate: optional boolean, whether the file should be strictly
            validated before saving. The default is True.
        no_warning: optional boolean, to suppress the warning emitted by the
            function.

    Returns:
        None. The file is closed in the process; reload it if you need it
        afterwards.
    """
    self.update_calculated_data()
    self.validate(strict_validate=strict_validate)

    # Ensure all buffered data has been written to disk
    self.flush()

    path = Path(path).resolve()
    path.parent.mkdir(parents=True, exist_ok=True)

    filename = self.filename
    self.flush()
    self.close()
    self.closed = True
    shutil.copyfile(filename, path)

    if not no_warning:
        warnings.warn(
            "The 'save_as' function closes the file currently to be able to "
            "store it. If you want to use the file after saving it, please "
            "reload the file. The save_as function can be avoided by opening "
            "the correct file directly. If you want to get rid of this "
            "warning use 'save_as(..., no_warning=True)'"
        )
    return None
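The core of `save_as` is a flush-then-copy: validate, write buffered data to disk, create the target directory, copy the underlying file, and close the handle. A simplified stand-in using plain files (the `save_copy` helper is hypothetical, not the real robofish.io API):

```python
import shutil
import tempfile
import warnings
from pathlib import Path


def save_copy(src: Path, dest: Path, no_warning: bool = False) -> Path:
    """Sketch of save_as semantics: ensure the target directory exists,
    copy the file, and optionally warn that the source handle would be
    closed. Simplified stand-in, not the real robofish.io API."""
    dest = Path(dest).resolve()
    dest.parent.mkdir(parents=True, exist_ok=True)
    shutil.copyfile(src, dest)
    if not no_warning:
        warnings.warn("save_as closes the file; reload it to keep working.")
    return dest


# Usage: copy a dummy file into a not-yet-existing subdirectory
with tempfile.TemporaryDirectory() as tmp:
    src = Path(tmp) / "test.hdf5"
    src.write_bytes(b"dummy")
    copied = save_copy(src, Path(tmp) / "backup" / "test.hdf5", no_warning=True)
    assert copied.read_bytes() == b"dummy"
```

Note that the real method closes the file before copying, which is why the warning recommends reloading the file or opening the target path directly in the first place.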
def select_entity_poses(self, *args, ori_rad=False, **kwargs)
-
Select the poses of entities; orientations are returned in radians when ori_rad is True. A shorthand for select_entity_property() with Entity.poses or Entity.poses_rad.
Expand source code
def select_entity_poses(self, *args, ori_rad=False, **kwargs):
    entity_property = Entity.poses_rad if ori_rad else Entity.poses
    return self.select_entity_property(
        *args, entity_property=entity_property, **kwargs
    )
def select_entity_property(self, predicate: types.LambdaType = None, entity_property: Union[property, str] = Entity.poses) ‑> Iterable
-
Get a property of selected entities.
Entities can be selected using a lambda function; the desired property is given as entity_property.
Args
predicate
- a lambda function selecting entities (example: lambda e: e.category == "fish")
entity_property
- a property of the Entity class (example: Entity.poses_rad) or a string with the name of the dataset.
Returns
A three-dimensional array of all properties of all entities with the shape (entity, time, property_length). If an entity's property has fewer timesteps, the remainder is filled with NaNs.
Expand source code
def select_entity_property(
    self,
    predicate: types.LambdaType = None,
    entity_property: Union[property, str] = Entity.poses,
) -> Iterable:
    """Get a property of selected entities.

    Entities can be selected using a lambda function. The property of the
    entities can be selected.

    Args:
        predicate: a lambda function selecting entities
            (example: lambda e: e.category == "fish")
        entity_property: a property of the Entity class (example:
            Entity.poses_rad) or a string with the name of the dataset.

    Returns:
        A three-dimensional array of all properties of all entities with the
        shape (entity, time, property_length). If an entity's property has
        fewer timesteps, the remainder is filled with NaNs.
    """
    entities = self.entities
    if predicate is not None:
        entities = [e for e in entities if predicate(e)]

    assert self.common_sampling(entities) is not None

    # Initialize poses output array
    if isinstance(entity_property, str):
        properties = [entity[entity_property] for entity in entities]
    else:
        properties = [entity_property.__get__(entity) for entity in entities]
    max_timesteps = max([0] + [p.shape[0] for p in properties])
    property_array = np.empty(
        (len(entities), max_timesteps, properties[0].shape[1])
    )
    property_array[:] = np.nan

    # Fill output array
    for i, entity in enumerate(entities):
        property_array[i][: properties[i].shape[0]] = properties[i]
    return property_array
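The NaN-padding step at the end of this method can be demonstrated in isolation: per-entity `(time, dim)` arrays of different lengths are stacked into one `(entity, time, dim)` array, with shorter entities padded by NaNs. A sketch with plain numpy arrays (the `stack_with_nan_padding` name is illustrative, not part of the library):

```python
import numpy as np


def stack_with_nan_padding(properties):
    """Stack per-entity (time, dim) arrays into one (entity, time, dim)
    array, padding shorter entities with NaNs, mirroring what
    select_entity_property does internally (sketch, no HDF5 involved)."""
    max_t = max(p.shape[0] for p in properties)
    out = np.full((len(properties), max_t, properties[0].shape[1]), np.nan)
    for i, p in enumerate(properties):
        out[i, : p.shape[0]] = p
    return out


a = np.ones((3, 2))   # entity with 3 timesteps
b = np.zeros((5, 2))  # entity with 5 timesteps
stacked = stack_with_nan_padding([a, b])
print(stacked.shape)                   # → (2, 5, 2)
print(np.isnan(stacked[0, 3:]).all())  # → True: padded timesteps of entity 0
```

This is why downstream code working with the returned array typically has to be NaN-aware (e.g. `np.nanmean` instead of `np.mean`).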
def to_string(self, output_format: str = 'shape', max_width: int = 120, full_attrs: bool = False) ‑> str
-
The file is formatted to a human-readable string.
Args
output_format
- ['shape', 'full'] show the shape, or the full content of datasets
max_width
- set the width in characters after which attribute values get abbreviated
full_attrs
- do not abbreviate attribute values if True
Returns
A human-readable string representing the file
Expand source code
def to_string(
    self,
    output_format: str = "shape",
    max_width: int = 120,
    full_attrs: bool = False,
) -> str:
    """The file is formatted to a human-readable string.

    Args:
        output_format: ['shape', 'full'] show the shape, or the full content of datasets
        max_width: set the width in characters after which attribute values get abbreviated
        full_attrs: do not abbreviate attribute values if True

    Returns:
        A human-readable string representing the file
    """

    def recursive_stringify(
        obj: h5py.Group,
        output_format: str,
        parent_indices: List[int] = [],
        parent_siblings: List[int] = [],
    ) -> str:
        """This function crawls recursively into hdf5 groups.

        Datasets and attributes are directly attached; for groups, the
        function is recursively called again.

        Args:
            obj: a h5py group
            output_format: ['shape', 'full'] show the shape, or the full content of datasets

        Returns:
            A string representation of the group
        """

        def lines(dataset_attribute: bool = False) -> str:
            """Get box-drawing characters for the graph lines."""
            line = ""
            for pi, ps in zip(parent_indices, parent_siblings):
                if pi < ps - 1:
                    line += "│ "
                else:
                    line += " "
            if dataset_attribute:
                line += " "
            line += "─ "
            junction_index = 2 * len(parent_indices) + dataset_attribute * 2 - 1
            last = "└"
            other = "├"
            if dataset_attribute:
                j = (
                    last
                    if list(value.attrs.keys()).index(d_key) == len(value.attrs) - 1
                    else other
                )
            else:
                j = last if index == num_children - 1 else other
            line = line[: junction_index + 1] + j + line[junction_index + 1 :]
            if isinstance(value, h5py.Group) or (
                isinstance(value, h5py.Dataset)
                and not dataset_attribute
                and value.attrs
            ):
                line = line[:-1] + "┬─"
            else:
                line = line[:-1] + "──"
            return line + " "

        s = ""
        max_key_len = 0
        num_children = 0
        if obj.attrs:
            max_key_len = max(len(key) for key in obj.attrs)
            num_children += len(obj.attrs)
        if hasattr(obj, "items"):
            max_key_len = max([len(key) for key in obj] + [max_key_len])
            num_children += len(obj)

        index = 0
        if obj.attrs:
            for key, value in obj.attrs.items():
                if not full_attrs:
                    value = str(value).replace("\n", " ").strip()
                    if len(value) > max_width - max_key_len - len(lines()):
                        value = (
                            value[: max_width - max_key_len - len(lines()) - 3]
                            + "..."
                        )
                s += f"{lines()}{key: <{max_key_len}} {value}\n"
                index += 1

        if hasattr(obj, "items"):
            for key, value in obj.items():
                if isinstance(value, h5py.Dataset):
                    if output_format == "shape":
                        s += (
                            f"{lines()}"
                            f"{key: <{max_key_len}} Shape {value.shape}\n"
                        )
                    else:
                        s += f"{lines()}{key}:\n"
                        s += np.array2string(
                            value,
                            precision=2,
                            separator=" ",
                            suppress_small=True,
                        )
                        s += "\n"
                    if value.attrs:
                        d_max_key_len = max(len(dk) for dk in value.attrs)
                        for d_key, d_value in value.attrs.items():
                            d_value = str(d_value).replace("\n", " ").strip()
                            if len(d_value) > max_width - d_max_key_len - len(
                                lines(True)
                            ):
                                if not full_attrs:
                                    d_value = d_value[
                                        : max_width
                                        - d_max_key_len
                                        - len(lines(True))
                                    ]
                                    d_value = d_value[:-3] + "..."
                            s += f"{lines(True)}{d_key: <{d_max_key_len}} {d_value}\n"
                if isinstance(value, h5py.Group):
                    s += f"{lines()}{key}\n" + recursive_stringify(
                        obj=value,
                        output_format=output_format,
                        parent_indices=parent_indices + [index],
                        parent_siblings=parent_siblings + [num_children],
                    )
                index += 1
        return s

    return recursive_stringify(self, output_format)
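The recursion in `to_string` walks the HDF5 hierarchy and draws a tree with box-drawing characters. The same idea reduces to a short sketch over nested dicts instead of h5py groups (the `tree_string` helper is illustrative, not part of the library):

```python
def tree_string(d: dict, prefix: str = "") -> str:
    """Render a nested dict as a box-drawing tree, in the spirit of
    to_string's recursive_stringify (simplified: dicts stand in for
    h5py groups, plain values for datasets)."""
    s = ""
    items = list(d.items())
    for i, (key, value) in enumerate(items):
        last = i == len(items) - 1
        junction = "└─" if last else "├─"
        if isinstance(value, dict):
            # Groups recurse with an extended prefix (│ while siblings remain)
            s += f"{prefix}{junction} {key}\n"
            s += tree_string(value, prefix + ("   " if last else "│  "))
        else:
            # Leaves are printed inline, like datasets in 'shape' mode
            s += f"{prefix}{junction} {key}: {value}\n"
    return s


print(tree_string({"entities": {"fish_1": {"positions": "Shape (100, 2)"}}}))
```

The real method additionally renders attributes, abbreviates long values to `max_width`, and can print full dataset contents with `output_format="full"`.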
def update_calculated_data(self, verbose=False)
-
Expand source code
def update_calculated_data(self, verbose=False):
    changed = any([e.update_calculated_data(verbose) for e in self.entities])
    return changed
def validate(self, strict_validate: bool = True) ‑> Tuple[bool, str]
-
Validate the file against the specification.
The function compares a given file to the robofish track format specification: https://git.imp.fu-berlin.de/bioroboticslab/robofish/track_format. First, all specified arrays are formatted to be numpy arrays with the specified datatype. Then all specified shapes are validated. Lastly, calendar points are validated to be datetimes according to ISO 8601.
Args
strict_validate
- Throw an exception instead of just returning False.
Returns
The function returns a tuple of validity and an error message
Throws
AssertionError: When the file is invalid and strict_validate is True
Expand source code
def validate(self, strict_validate: bool = True) -> Tuple[bool, str]:
    """Validate the file against the specification.

    The function compares a given file to the robofish track format
    specification:
    https://git.imp.fu-berlin.de/bioroboticslab/robofish/track_format
    First, all specified arrays are formatted to be numpy arrays with the
    specified datatype. Then all specified shapes are validated. Lastly,
    calendar points are validated to be datetimes according to ISO 8601.

    Args:
        strict_validate: Throw an exception instead of just returning False.

    Returns:
        A tuple of validity and an error message.

    Throws:
        AssertionError: When the file is invalid and strict_validate is True
    """
    return robofish.io.validate(self, strict_validate)
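The `(valid, message)` return value combined with the raise-on-strict flag is a small but reusable contract. A sketch with one made-up rule, to show how both modes behave (the `validate_world_size` function and its rule are illustrative, not the real specification checks):

```python
from typing import Tuple


def validate_world_size(world_size, strict_validate: bool = True) -> Tuple[bool, str]:
    """Illustrates validate()'s contract with a single hypothetical rule:
    world_size must be a pair of positive numbers. With strict_validate,
    an AssertionError is raised; otherwise (valid, message) is returned."""
    valid = (
        isinstance(world_size, (list, tuple))
        and len(world_size) == 2
        and all(s > 0 for s in world_size)
    )
    message = (
        "" if valid else f"world_size must be two positive numbers, got {world_size!r}"
    )
    if strict_validate:
        assert valid, message
    return valid, message


print(validate_world_size([100, 100]))                          # → (True, '')
print(validate_world_size([-1, 5], strict_validate=False)[0])   # → False
```

With `strict_validate=True` (the default, matching `validate`), an invalid input raises instead of returning, which is what makes the `with` block shown in the module introduction fail loudly on broken files.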