# Source code for gdt.core.coords.quaternion

# CONTAINS TECHNICAL DATA/COMPUTER SOFTWARE DELIVERED TO THE U.S. GOVERNMENT WITH UNLIMITED RIGHTS
#
# Contract No.: CA 80MSFC17M0022
# Contractor Name: Universities Space Research Association
# Contractor Address: 7178 Columbia Gateway Drive, Columbia, MD 21046
#
# Copyright 2017-2022 by Universities Space Research Association (USRA). All rights reserved.
#
# Developed by: William Cleveland and Adam Goldstein
#               Universities Space Research Association
#               Science and Technology Institute
#               https://sti.usra.edu
#
# Developed by: Daniel Kocevski
#               National Aeronautics and Space Administration (NASA)
#               Marshall Space Flight Center
#               Astrophysics Branch (ST-12)
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software distributed under the License
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing permissions and limitations under the
# License.
#
from collections.abc import Sequence
from typing import Union, List
import numpy as np
from astropy.coordinates import Attribute
from astropy.time import Time
from scipy.spatial.transform import Rotation, Slerp
from astropy.utils.data_info import MixinInfo

from gdt.core.types import ArrayBase, assert_array_valid

# Anything that can be interpreted as one or more cartesian vectors.
VECTORS = Union[Sequence, np.ndarray]

# Quaternion input is accepted in the same container types as vectors.
QUATERNIONS = VECTORS

# One or more scalar values: plain numbers or any of the vector containers.
# Union flattens the nested alias, so this is identical to
# Union[int, float, Sequence, np.ndarray].
SCALARS = Union[int, float, VECTORS]


class Quaternion(ArrayBase):
    """A class for containing a quaternion and performing quaternion operations.

    The quaternion is represented internally as a two-dimensional array of
    size (n, 4) where the elements follow numpy's row-major order (i.e.
    array[row, column]) and each quaternion is stored by row with the columns
    representing i(x), j(y), k(z), and w in that order (scalar-last).

    Overloaded operations supported are:

    * Addition
    * Subtraction
    * Multiplication
    * Division
    * Negation
    * Equivalence

    Parameters:
        quaternion (np.array): Either a 1D 4-element array containing a single
                               quaternion or a 2D array of shape (`n`, 4)
                               containing `n` quaternions.
        scalar_first (bool, optional): Set to True if the input arrays are in
                                       scalar-first representation, False if
                                       in scalar-last representation.
                                       Default is False.
    """
    # Attributes used by ArrayBase
    _rowsize = 4
    _prefix = '(x, y, z, w) '
    _dtype = np.float64
    _name = 'quaternion'

    info = MixinInfo()
    """(:class:`astropy.utils.data_info.MixinInfo`): Used by astropy Table to store column information """

    EPSILON = 1e-12
    """(float): How close the floating point numbers need to be before they are considered equal."""

    def __init__(self, quaternion: QUATERNIONS, scalar_first: bool = False):
        super().__init__(quaternion)
        if scalar_first:
            # Internal storage is always scalar-last: move column 0 (w) to
            # the end of each row.
            self._array = np.hstack((self._array[:, 1:4],
                                     self._array[:, 0].reshape(-1, 1)))

    @property
    def conjugate(self) -> 'Quaternion':
        """(:class:`Quaternion`): Conjugate of the quaternion(s)"""
        new = self._obj()
        # Adding 0.0 turns any -0.0 produced by the negation into +0.0.
        new._array = np.hstack((-self._array[:, :3] + 0.0,
                                self._array[:, 3].reshape(-1, 1)))
        return new

    @property
    def inverse(self) -> 'Quaternion':
        """(:class:`Quaternion`): The inverse of the quaternion(s)

        Computed as the conjugate divided by the squared norm of each row.
        """
        new = self.conjugate
        new._array = new._array / (self._array ** 2).sum(axis=1)[:, np.newaxis]
        return new

    @property
    def norm(self) -> Union[float, np.ndarray]:
        """(float or np.array): The norm of the quaternion(s)"""
        return self._scale(np.sqrt((self._array ** 2).sum(axis=1)))

    @property
    def rotation(self) -> Rotation:
        """(:class:`scipy.spatial.transform.Rotation`): The rotation
        represented by the quaternion(s)"""
        # SciPy documents the from_* factories, not the constructor, as the
        # supported way to build a Rotation; the values are scalar-last,
        # which is from_quat's default convention.
        return Rotation.from_quat(self.scalar_last)

    @property
    def scalar_first(self):
        """(np.array): Array of quaternion values with the scalar as the first column"""
        val = np.hstack((self._array[:, 3].reshape(-1, 1), self._array[:, :3]))
        return self._scale(val)

    @property
    def scalar_last(self):
        """(np.array): An array of quaternion values with the scalar as the last column"""
        return self._value

    @property
    def unit(self) -> 'Quaternion':
        """(:class:`Quaternion`): The quaternion(s) as unit quaternion(s)"""
        new = self._obj()
        # norm may be a scalar or an array; reshape so it divides row-wise.
        new._array = self._array / np.asarray(self.norm).reshape(-1, 1)
        return new

    @property
    def w(self):
        """(np.array): The w value(s) in the quaternion"""
        return self._scale(self._array[:, 3])

    @property
    def xyz(self):
        """(np.array): The vector portion of the quaternion(s)"""
        return self._scale(self._array[:, :3])

    @property
    def x(self):
        """(np.array): The x value(s) in the quaternion"""
        return self._scale(self._array[:, 0])

    @property
    def y(self):
        """(np.array): The y value(s) in the quaternion"""
        return self._scale(self._array[:, 1])

    @property
    def z(self):
        """(np.array): The z value(s) in the quaternion"""
        return self._scale(self._array[:, 2])

    def dot(self, other):
        """Return the dot product of this quaternion and another.

        The product is taken row by row over all four components.

        Args:
            other (:class:`Quaternion`): The other quaternion

        Returns:
            (float or np.array)
        """
        return self._scale(np.einsum('ij,ij->i', self._array, other._array))

    def equal_rotation(self, other):
        """Determine if another quaternion represents the same rotation as
        this quaternion.

        Both operands are normalized first and the absolute value of their
        dot product is compared against 1, so ``q`` and ``-q`` are treated as
        the same rotation.

        Args:
            other (:class:`Quaternion`): The other quaternion

        Returns:
            (bool): True only if every pair of quaternions matches
        """
        self_u = self.unit
        other_u = other.unit
        return np.all(np.abs(self_u.dot(other_u)) > 1.0 - self.EPSILON)

    def round(self, decimals: int) -> 'Quaternion':
        """Returns the quaternion(s) with the components rounded to the given
        decimal places

        Args:
            decimals: Number of decimal places to round the components of the
                      quaternion.

        Returns:
            (:class:`Quaternion`)
        """
        new = self._obj()
        new._array = np.round(self._array, decimals=decimals)
        return new

    @classmethod
    def from_rotation(cls, rot: Rotation) -> 'Quaternion':
        """Create a quaternion from SciPy Rotation objects.

        Parameters:
            rot (scipy.spatial.transform.Rotation): The SciPy Rotation object

        Returns:
            (:class:`Quaternion`)
        """
        return cls(rot.as_quat(), scalar_first=False)

    @classmethod
    def from_vectors(cls, vec1: Union[np.ndarray, List],
                     vec2: Union[np.ndarray, List]):
        """Create a quaternion from the rotation between two vectors.

        The vector part is ``vec1 x vec2`` and the scalar part is
        ``|vec1| * |vec2| + vec1 . vec2``. The result is not normalized.

        Note:
            For exactly anti-parallel vectors both the cross product and the
            scalar term are zero, so an all-zero quaternion is returned.

        Args:
            vec1 (np.array): A cartesian vector
            vec2 (np.array): Another cartesian vector

        Returns:
            (:class:`Quaternion`)
        """
        # verify dimensions and type
        vec1 = np.asarray(vec1)
        vec2 = np.asarray(vec2)
        assert_array_valid(vec1, rowsize=3, name='vec1')
        assert_array_valid(vec2, rowsize=3, name='vec2')

        if vec1.shape == (3,):
            # single pair of vectors
            vals = np.hstack((np.cross(vec1, vec2),
                              np.linalg.norm(vec1) * np.linalg.norm(vec2)
                              + np.dot(vec1, vec2)))
        else:
            # one quaternion per row
            vals = np.hstack((np.cross(vec1, vec2),
                              np.asarray([np.linalg.norm(vec1, axis=1)
                                          * np.linalg.norm(vec2, axis=1)
                                          + np.einsum('ij,ij->i', vec1, vec2)]).T))
        obj = cls(vals)
        return obj

    @classmethod
    def from_xyz_w(cls, xyz: VECTORS, w: SCALARS) -> 'Quaternion':
        """Create a quaternion from a vector and scalar.

        Args:
            xyz: Represents the values for the x, y and z axis
            w: Represents the rotation for the resultant (w)

        Returns:
            (:class:`Quaternion`)
        """
        xyz = np.asarray(xyz)
        w = np.asarray(w)
        assert_array_valid(xyz, rowsize=3, name='xyz')
        assert_array_valid(w, rowsize=1, name='w')

        if xyz.shape == (3,) and w.shape == ():
            # single quaternion: keep the result one-dimensional
            value = np.concatenate((xyz, w.reshape(1,)))
        else:
            value = np.hstack((xyz.reshape(-1, 3), w.reshape(-1, 1)))
        return cls(value)

    def _setup_interpolation(self, unix_tai: Union[float, np.ndarray]) -> Slerp:
        """Called to initialize the interpolation function"""
        # NOTE(review): the object itself is handed to SciPy, which presumably
        # relies on ArrayBase being array-convertible -- confirm.
        return Slerp(unix_tai, Rotation.from_quat(self))

    def __add__(self, other):
        """Perform addition. Only the addition of another Quaternion is allowed.

        Parameters:
            other: The addend
        """
        if not isinstance(other, self.__class__):
            raise TypeError('Only addition of two quaternions are supported.')

        new = self._obj()
        new._array = self._array + other._array
        return new

    def __eq__(self, other):
        """Determine if the two quaternions have (nearly) equal components.

        This is a component-wise comparison; use :meth:`equal_rotation` to
        test whether two quaternions describe the same rotation.
        """
        if not isinstance(other, Quaternion):
            # Let Python fall back to the other operand / identity instead of
            # failing on a missing ``_array`` attribute (e.g. ``q == None``).
            return NotImplemented
        if self._array.shape != other._array.shape:
            return False
        # EPSILON sets only the absolute tolerance; np.allclose also applies
        # its default relative tolerance.
        return np.allclose(self._array, other._array, atol=self.EPSILON)

    def __mul__(self, other):
        """Perform multiplication.

        If it's Quaternion * Quaternion, then Hamilton Product is used to
        speed up the operation. Otherwise, scalar multiplication is performed
        following Numpy broadcast rules.

        Parameters:
            other: The multiplier
        """
        new = self._obj()
        if isinstance(other, self.__class__):
            s = self._array
            o = other._array
            sx, sy, sz, sw = s[:, 0], s[:, 1], s[:, 2], s[:, 3]
            ox, oy, oz, ow = o[:, 0], o[:, 1], o[:, 2], o[:, 3]
            # Hamilton product; the resulting stack is in (x, y, z, w) order:
            # x = [:, 0], y = [:, 1], z = [:, 2], w = [:, 3]
            new._array = np.hstack((
                (sw * ox + sx * ow + sy * oz - sz * oy)[:, np.newaxis],
                (sw * oy - sx * oz + sy * ow + sz * ox)[:, np.newaxis],
                (sw * oz + sx * oy - sy * ox + sz * ow)[:, np.newaxis],
                (sw * ow - sx * ox - sy * oy - sz * oz)[:, np.newaxis]
            ))
        else:
            new._array = self._array * other
        return new

    def __neg__(self):
        """Returns the negative of the quaternion(s)"""
        new = self._obj()
        new._array = -self._array
        return new

    def __sub__(self, other):
        """Perform subtraction. Only the subtraction of another Quaternion is
        allowed.

        Parameters:
            other: The subtrahend
        """
        if not isinstance(other, self.__class__):
            raise TypeError('Only subtraction of two quaternions are supported.')

        new = self._obj()
        new._array = self._array - other._array
        return new

    def __truediv__(self, other) -> 'Quaternion':
        """Perform division.

        If it's Quaternion / Quaternion, then the operation is performed by
        multiplying with the inverse of other(s). Otherwise, scalar division
        is performed following Numpy broadcast rules.

        Parameters:
            other: The divisor
        """
        if isinstance(other, self.__class__):
            inv = other.inverse
            new = self.__mul__(inv)
        else:
            new = self._obj()
            new._array = self._array / other
        return new
class QuaternionAttribute(Attribute):
    """A quaternion attribute class to use with
    astropy.coordinates.BaseCoordinateFrame
    """

    def convert_input(self, value):
        """Function called by Astropy Representation/Frame system.

        This function verifies that the value is a quaternion.

        Args:
            value: None, a :class:`Quaternion`, or a tuple, list or np.array
                   that can be converted to a :class:`Quaternion`

        Returns:
            (tuple): The (possibly converted) value and a bool that is True
            only when a new :class:`Quaternion` was built from the input

        Raises:
            TypeError: If the value is of an unsupported type or cannot be
                       converted to a quaternion
        """
        message = 'Value must be a Quaternion object or an array that can be converted to a quaternion.'

        if value is None:
            converted = False
        elif isinstance(value, (tuple, list, np.ndarray)):
            try:
                value = Quaternion(value)
            except (ValueError, TypeError) as err:
                # Chain the original failure so the underlying reason
                # (bad shape, bad dtype, ...) stays visible in the traceback.
                raise TypeError(message) from err
            converted = True
        elif isinstance(value, Quaternion):
            converted = False
        else:
            raise TypeError(message)

        return value, converted