from pathlib import Path
from typing import Any, Dict, Optional, Union
import matplotlib.pyplot as plt
import motion3d as m3d
import numpy as np
import networkx as nx
import yaml
from excalibur.utils.logging import logger
def load_calibration(filename: Union[str, Path]) -> Optional[Dict[Any, m3d.TransformInterface]]:
    """Load a flat YAML calibration file mapping keys to serialized transformations.

    Every value of the top-level YAML mapping must contain a ``'type'`` entry
    (a character accepted by ``m3d.TransformType.FromChar``) and a ``'data'``
    entry holding the transformation values.

    :param filename: path of the YAML file
    :return: dict from key to normalized transformation; an empty dict for an
        empty (or comment-only) file; ``None`` if ``filename`` is not an
        existing regular file
    """
    filename = Path(filename)
    # is_file() instead of exists(): a directory would otherwise crash in open()
    if not filename.is_file():
        return None

    with open(filename, 'r', encoding='utf-8') as stream:
        data = yaml.safe_load(stream)

    # an empty or comment-only YAML document parses to None, not to an empty mapping
    if data is None:
        return {}

    # unsafe=True is kept from the original implementation; results are normalized by the helper
    return {key: _deserialize_transform(value, unsafe=True) for key, value in data.items()}
def _serialize_transform(transform):
    """Convert a motion3d transformation into a YAML-friendly dict.

    :param transform: motion3d transformation to serialize
    :return: ``{'type': <type character>, 'data': <list of values>}``
    :raises RuntimeError: if ``transform`` matches none of the ``m3d.TransformType`` members
    """
    # first TransformType member the transformation reports itself to be
    transform_type = next(
        (ttype for ttype in m3d.TransformType.__members__.values() if transform.isType(ttype)),
        None,
    )
    if transform_type is None:
        # previously this surfaced as an opaque AttributeError on None.toChar()
        raise RuntimeError(f"Unsupported transformation type '{type(transform).__name__}'.")

    return {
        'type': transform_type.toChar(),
        'data': transform.toList(),
    }
def _deserialize_transform(data, unsafe=False):
    """Rebuild a transformation from the dict layout written by ``_serialize_transform``.

    :param data: mapping with a ``'type'`` character and a ``'data'`` value list
    :param unsafe: forwarded unchanged to ``m3d.TransformInterface.Factory``
    :return: the reconstructed transformation after calling ``normalized_()`` on it
    """
    ttype = m3d.TransformType.FromChar(data['type'])
    values = data['data']
    transform = m3d.TransformInterface.Factory(ttype, values, unsafe=unsafe)
    return transform.normalized_()
class CalibrationManager:
    """Store transformations between named frames in an acyclic, undirected graph.

    Every edge of the graph corresponds to exactly one stored transformation,
    keyed by ``(topic_from, topic_to)`` in the direction it was added. Because
    cycles are rejected, the path between two frames is unique, and ``get``
    chains the stored transformations along it.
    """

    def __init__(self):
        self._graph = nx.Graph()
        # (topic_from, topic_to) -> transformation; the key order is the stored direction
        self._data = dict()

    def clear(self):
        """Remove all frames and transformations."""
        self._graph.clear()
        self._data.clear()

    def items(self):
        """Return a view of ``((topic_from, topic_to), transformation)`` pairs."""
        return self._data.items()

    def frames(self):
        """Return a list of all known frame names."""
        return list(self._graph.nodes)

    def add(self, topic_from: str, topic_to: str, transformation: m3d.TransformInterface,
            overwrite: bool = False) -> bool:
        """Store ``transformation`` as the transform from ``topic_from`` to ``topic_to``.

        Additions that would close a cycle (including a frame to itself) are
        rejected. If the two frames are already directly connected, in either
        direction, and ``overwrite`` is set, the stored transformation is
        replaced and kept in the ``topic_from -> topic_to`` direction.

        :return: ``True`` if the transformation was stored, ``False`` otherwise
        :raises RuntimeError: if ``transformation`` is not an ``m3d.TransformInterface``
        """
        if not isinstance(transformation, m3d.TransformInterface):
            raise RuntimeError("Input 'transformation' is not an 'm3d.TransformInterface'.")

        # a frame-to-itself edge is a self-loop, i.e. a cycle that load() would reject
        if topic_from == topic_to:
            return False

        # frames already connected (directly or indirectly): a new edge would close a cycle
        has_nodes = self._graph.has_node(topic_from) and self._graph.has_node(topic_to)
        if has_nodes and nx.has_path(self._graph, topic_from, topic_to):
            if not overwrite:
                return False
            if (topic_from, topic_to) in self._data:
                self._data[topic_from, topic_to] = transformation
                return True
            if (topic_to, topic_from) in self._data:
                # re-store the direct connection in the requested direction
                del self._data[topic_to, topic_from]
                self._data[topic_from, topic_to] = transformation
                return True
            # only an indirect connection exists; replacing it would create a cycle
            return False

        # add_edge creates missing nodes (topic_from first, then topic_to)
        self._graph.add_edge(topic_from, topic_to)
        self._data[topic_from, topic_to] = transformation
        return True

    def has(self, topic_from: str, topic_to: str) -> bool:
        """Return whether a chain of transformations connects the two frames.

        Unknown frames yield ``False``, matching ``get`` returning ``None``.
        """
        if not (self._graph.has_node(topic_from) and self._graph.has_node(topic_to)):
            return False
        return nx.has_path(self._graph, topic_from, topic_to)

    def get(self, topic_from: str, topic_to: str) -> Optional[m3d.TransformInterface]:
        """Return the transformation from ``topic_from`` to ``topic_to``.

        Stored transformations along the unique path are chained; edges stored
        in the opposite direction are applied with ``/=``.

        :return: the composed transformation (identity ``MatrixTransform`` if
            both frames are the same), or ``None`` if a frame is unknown or the
            frames are not connected
        """
        try:
            path = nx.shortest_path(self._graph, topic_from, topic_to)
        except (nx.NodeNotFound, nx.NetworkXNoPath):
            return None

        # accumulate into a fresh identity so the stored transformations are never mutated by *= or /=
        transformation = m3d.MatrixTransform()
        for frame_a, frame_b in zip(path[:-1], path[1:]):
            if (frame_a, frame_b) in self._data:
                transformation *= self._data[frame_a, frame_b]
            else:
                # edge is stored as frame_b -> frame_a, so divide by it instead
                transformation /= self._data[frame_b, frame_a]
        return transformation

    def remove(self, topic_from: str, topic_to: str) -> bool:
        """Remove the transformation stored directly between two frames, in either direction.

        The frames themselves remain in the graph, even if they become isolated.

        :return: ``True`` if a transformation was removed, ``False`` if none was stored directly
        """
        for key in ((topic_from, topic_to), (topic_to, topic_from)):
            if key in self._data:
                self._graph.remove_edge(*key)
                del self._data[key]
                return True
        return False

    def extend(self, other, overwrite: bool = False):
        """Add all transformations of ``other``; those rejected by ``add`` are skipped silently."""
        for (topic_from, topic_to), transform in other.items():
            self.add(topic_from, topic_to, transform, overwrite=overwrite)

    def copy(self):
        """Return an independent copy with copied graph and copied transformations."""
        new_obj = CalibrationManager()
        new_obj._graph = self._graph.copy(as_view=False)
        new_obj._data = {k: v.copy() for k, v in self._data.items()}
        return new_obj

    @classmethod
    def load(cls, filename: Union[str, Path], unsafe: bool = False):
        """Load a manager from a YAML file written by ``save``.

        The file maps ``topic_from`` to a mapping from ``topic_to`` to a
        serialized transformation.

        :param filename: path of the YAML file
        :param unsafe: forwarded to the transformation factory
        :return: the loaded manager, or ``None`` (with an error logged) if the
            file describes a cycle
        """
        with open(str(filename), 'r', encoding='utf-8') as file:
            data = yaml.safe_load(file)

        manager = cls()
        # an empty file (None) or an empty sub-mapping (None) describes no transformations
        for topic_from, sub_data in (data or {}).items():
            manager._graph.add_node(topic_from)
            for topic_to, d in (sub_data or {}).items():
                # the undirected graph would merge a->b and b->a into a single edge,
                # hiding this 2-cycle from cycle_basis below
                if manager._graph.has_edge(topic_from, topic_to):
                    logger.error("Calibration file contains cycles")
                    return None
                manager._graph.add_edge(topic_from, topic_to)
                manager._data[topic_from, topic_to] = _deserialize_transform(d, unsafe=unsafe)

        # remaining cycles, including self-loops a->a
        if nx.cycle_basis(manager._graph):
            logger.error("Calibration file contains cycles")
            return None
        return manager

    def save(self, filename: Union[str, Path]):
        """Write all transformations to a YAML file readable by ``load``."""
        data = dict()
        for (topic_from, topic_to), transformation in self._data.items():
            data.setdefault(topic_from, dict())[topic_to] = _serialize_transform(transformation)
        with open(str(filename), 'w', encoding='utf-8') as file:
            yaml.dump(data, file)

    def plot_graph(self):
        """Draw the frame graph with labeled nodes in a new figure."""
        plt.figure()
        nx.draw(self._graph, with_labels=True, font_weight='bold')

    def plot_frames(self, origin: str, length: float):
        """Plot every frame connected to ``origin`` as a labeled axis tripod.

        :param origin: frame in which all other frames are drawn
        :param length: drawn length of each axis
        :raises RuntimeError: if ``origin`` is not a known frame
        """
        if origin not in self._graph.nodes:
            raise RuntimeError(f"Frame '{origin}' does not exist.")

        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')

        for frame in self._graph.nodes:
            m = self.get(origin, frame)
            if m is None:
                # frame is not connected to origin
                continue
            # frame origin followed by the tips of its x, y and z axes, one point per row
            points = np.array([
                [0, 0, 0],
                [length, 0, 0],
                [0, length, 0],
                [0, 0, length],
            ])
            points = m.transformCloud(points.T).T
            for tip, color in zip((1, 2, 3), ('tab:red', 'tab:green', 'tab:blue')):
                ax.plot([points[0, 0], points[tip, 0]], [points[0, 1], points[tip, 1]],
                        [points[0, 2], points[tip, 2]], c=color)
            ax.text(points[0, 0], points[0, 1], points[0, 2], frame, size='x-small', color='black')

        # emulate an equal aspect ratio: center every axis on its middle with the largest half-range
        limits = [ax.get_xlim3d(), ax.get_ylim3d(), ax.get_zlim3d()]
        plot_radius = 0.5 * max(abs(lim[1] - lim[0]) for lim in limits)
        x_middle, y_middle, z_middle = (np.mean(lim) for lim in limits)
        ax.set_xlim3d([x_middle - plot_radius, x_middle + plot_radius])
        ax.set_ylim3d([y_middle - plot_radius, y_middle + plot_radius])
        ax.set_zlim3d([z_middle - plot_radius, z_middle + plot_radius])