# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""
This module contains a general framework for defining graphs of transformations
between coordinates, suitable for either spatial coordinates or more generalized
coordinate systems.

The fundamental idea is that each class is a node in the transformation graph,
and transitions from one node to another are defined as functions (or methods)
wrapped in transformation objects.

This module also includes more specific transformation classes for
celestial/spatial coordinate frames, generally focused around matrix-style
transformations that are typically how the algorithms are defined.
"""

from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

import heapq
import inspect
import subprocess

from abc import ABCMeta, abstractmethod
from collections import defaultdict

import numpy as np

from ..utils.compat import suppress
from ..utils.compat.funcsigs import signature
from ..extern import six
from ..extern.six.moves import range


__all__ = ['TransformGraph', 'CoordinateTransform', 'FunctionTransform',
           'StaticMatrixTransform', 'DynamicMatrixTransform', 'CompositeTransform']
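

# Illustrative sketch (hypothetical helper, not part of this module's API).
# ``TransformGraph.find_shortest_path``, defined below, runs Dijkstra's
# algorithm by keeping every node in the heap, finding a node by linear
# search and calling ``heapq.heapify`` after each decrease-key.  This
# function shows the more common "lazy deletion" variant under the same
# assumptions: ``edgeweights[a][b]`` holds the (non-negative) priority of the
# edge ``a -> b`` and the result is ``(path, distance)``, or ``(None, inf)``
# when ``tosys`` is unreachable.  Instead of updating heap entries in place,
# a new entry is pushed and outdated ones are skipped when popped.
def _sketch_dijkstra(edgeweights, fromsys, tosys):
    import heapq
    from itertools import count

    inf = float('inf')
    tiebreak = count()  # unique counter so heap entries never compare nodes
    heap = [(0.0, next(tiebreak), fromsys, [fromsys])]
    visited = set()
    while heap:
        dist, _, node, path = heapq.heappop(heap)
        if node in visited:
            continue  # stale entry: a shorter path to node was already found
        if node == tosys:
            return path, dist
        visited.add(node)
        for nxt, weight in edgeweights.get(node, {}).items():
            if nxt not in visited:
                heapq.heappush(heap, (dist + weight, next(tiebreak), nxt,
                                      path + [nxt]))
    return None, inf

# e.g. _sketch_dijkstra({'A': {'B': 1.0}}, 'A', 'B') gives (['A', 'B'], 1.0)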


class TransformGraph(object):
    """
    A graph representing the paths between coordinate frames.
    """

    def __init__(self):
        self._graph = defaultdict(dict)
        self.invalidate_cache()  # generates cache entries

    @property
    def _cached_names(self):
        if self._cached_names_dct is None:
            self._cached_names_dct = dct = {}
            for c in self.frame_set:
                nm = getattr(c, 'name', None)
                if nm is not None:
                    dct[nm] = c

        return self._cached_names_dct

    @property
    def frame_set(self):
        """
        A `set` of all the frame classes present in this `TransformGraph`.
        """
        if self._cached_frame_set is None:
            self._cached_frame_set = frm_set = set()
            for a in self._graph:
                frm_set.add(a)
                for b in self._graph[a]:
                    frm_set.add(b)

        return self._cached_frame_set.copy()

    def invalidate_cache(self):
        """
        Invalidates the cache that stores optimizations for traversing the
        transform graph.  This is called automatically when transforms
        are added or removed, but will need to be called manually if
        weights on transforms are modified in place.
        """
        self._cached_names_dct = None
        self._cached_frame_set = None
        self._shortestpaths = {}
        self._composite_cache = {}

    def add_transform(self, fromsys, tosys, transform):
        """
        Add a new coordinate transformation to the graph.

        Parameters
        ----------
        fromsys : class
            The coordinate frame class to start from.
        tosys : class
            The coordinate frame class to transform into.
        transform : CoordinateTransform or similar callable
            The transformation object. Typically a `CoordinateTransform` object,
            although it may be some other callable that is called with the same
            signature.

        Raises
        ------
        TypeError
            If ``fromsys`` or ``tosys`` are not classes or ``transform`` is
            not callable.
        """
        if not inspect.isclass(fromsys):
            raise TypeError('fromsys must be a class')
        if not inspect.isclass(tosys):
            raise TypeError('tosys must be a class')
        if not six.callable(transform):
            raise TypeError('transform must be callable')

        self._graph[fromsys][tosys] = transform
        self.invalidate_cache()

    def remove_transform(self, fromsys, tosys, transform):
        """
        Removes a coordinate transform from the graph.

        Parameters
        ----------
        fromsys : class or `None`
            The coordinate frame *class* to start from. If `None`,
            ``transform`` will be searched for and removed (``tosys`` must
            also be `None`).
        tosys : class or `None`
            The coordinate frame *class* to transform into. If `None`,
            ``transform`` will be searched for and removed (``fromsys`` must
            also be `None`).
        transform : callable or `None`
            The transformation object to be removed or `None`. If `None`
            and ``tosys`` and ``fromsys`` are supplied, there will be no
            check to ensure the correct object is removed.

        Raises
        ------
        ValueError
            If the combination of `None` arguments is invalid, or if the
            requested transform is not in the graph.
        """
        if fromsys is None or tosys is None:
            if not (tosys is None and fromsys is None):
                raise ValueError('fromsys and tosys must both be None if either is')
            if transform is None:
                raise ValueError('cannot give all Nones to remove_transform')

            # search for the requested transform by brute force and remove it.
            # The transform objects are the dict *values*; the keys are the
            # destination classes.
            found = False
            for a in self._graph:
                agraph = self._graph[a]
                for b in agraph:
                    if agraph[b] is transform:
                        del agraph[b]
                        found = True
                        break
                if found:
                    # also leave the outer loop, which skips the ``else`` below
                    break
            else:
                raise ValueError('Could not find transform {0} in the '
                                 'graph'.format(transform))

        else:
            if transform is None:
                self._graph[fromsys].pop(tosys, None)
            else:
                curr = self._graph[fromsys].get(tosys, None)
                if curr is transform:
                    self._graph[fromsys].pop(tosys)
                else:
                    raise ValueError('Current transform from {0} to {1} is not '
                                     '{2}'.format(fromsys, tosys, transform))
        self.invalidate_cache()

    def find_shortest_path(self, fromsys, tosys):
        """
        Computes the shortest distance along the transform graph from
        one system to another.

        Parameters
        ----------
        fromsys : class
            The coordinate frame class to start from.
        tosys : class
            The coordinate frame class to transform into.

        Returns
        -------
        path : list of classes or `None`
            The path from ``fromsys`` to ``tosys`` as an in-order sequence
            of classes.  This list includes *both* ``fromsys`` and
            ``tosys``. Is `None` if there is no possible path.
        distance : number
            The total distance/priority from ``fromsys`` to ``tosys``.  If
            priorities are not set this is the number of transforms
            needed. Is ``inf`` if there is no possible path.
        """

        inf = float('inf')

        # special-case the 0 or 1-path
        if tosys is fromsys:
            if tosys not in self._graph[fromsys]:
                # Means there's no transform necessary to go from it to itself.
                return [tosys], 0
        if tosys in self._graph[fromsys]:
            # this will also catch the case where tosys is fromsys, but has
            # a defined transform.
            t = self._graph[fromsys][tosys]
            return [fromsys, tosys], float(t.priority if hasattr(t, 'priority') else 1)

        # otherwise, need to construct the path:

        if fromsys in self._shortestpaths:
            # already have a cached result
            fpaths = self._shortestpaths[fromsys]
            if tosys in fpaths:
                return fpaths[tosys]
            else:
                return None, inf

        # use Dijkstra's algorithm to find shortest path in all other cases

        nodes = []
        # first make the list of nodes
        for a in self._graph:
            if a not in nodes:
                nodes.append(a)
            for b in self._graph[a]:
                if b not in nodes:
                    nodes.append(b)

        if fromsys not in nodes or tosys not in nodes:
            # fromsys or tosys are isolated or not registered, so there's
            # certainly no way to get from one to the other
            return None, inf

        edgeweights = {}
        # construct another graph that is a dict of dicts of priorities
        # (used as edge weights in Dijkstra's algorithm)
        for a in self._graph:
            edgeweights[a] = aew = {}
            agraph = self._graph[a]
            for b in agraph:
                aew[b] = float(agraph[b].priority if hasattr(agraph[b], 'priority') else 1)

        # entries in q are [distance, count, nodeobj, pathlist]
        # count is needed because in py 3.x, tie-breaking fails on the nodes.
        # this way, insertion order is preserved if the weights are the same
        q = [[inf, i, n, []] for i, n in enumerate(nodes) if n is not fromsys]
        q.insert(0, [0, -1, fromsys, []])

        # this dict will store the distance to node from ``fromsys`` and the path
        result = {}

        # definitely starts as a valid heap because of the insert line; from the
        # node to itself is always the shortest distance
        while len(q) > 0:
            d, orderi, n, path = heapq.heappop(q)

            if d == inf:
                # everything left is unreachable from fromsys, just copy them to
                # the results and jump out of the loop
                result[n] = (None, d)
                for d, orderi, n, path in q:
                    result[n] = (None, d)
                break
            else:
                result[n] = (path, d)
                path.append(n)
                if n not in edgeweights:
                    # this is a system that can be transformed to, but not from.
                    continue
                for n2 in edgeweights[n]:
                    if n2 not in result:  # skip nodes that were already visited
                        # find where n2 is in the heap
                        for i in range(len(q)):
                            if q[i][2] == n2:
                                break
                        else:
                            raise ValueError('n2 not in heap - this should be impossible!')

                        newd = d + edgeweights[n][n2]
                        if newd < q[i][0]:
                            q[i][0] = newd
                            q[i][3] = list(path)
                            heapq.heapify(q)

        # cache for later use
        self._shortestpaths[fromsys] = result
        return result[tosys]

    def get_transform(self, fromsys, tosys):
        """
        Generates and returns the `CompositeTransform` for a transformation
        between two coordinate systems.

        Parameters
        ----------
        fromsys : class
            The coordinate frame class to start from.
        tosys : class
            The coordinate frame class to transform into.

        Returns
        -------
        trans : `CompositeTransform` or `None`
            If there is a path from ``fromsys`` to ``tosys``, this is a
            transform object for that path.  If no path could be found, this is
            `None`.

        Notes
        -----
        When a path exists, this function always returns a
        `CompositeTransform`, even for a single-step path, because
        `CompositeTransform` is slightly more adaptable in the way it can be
        called than other transform classes.  Specifically, it takes care of
        intermediate steps of transformations in a way that is consistent with
        1-hop transformations.
        """
        if not inspect.isclass(fromsys):
            raise TypeError('fromsys is not a class')
        if not inspect.isclass(tosys):
            raise TypeError('tosys is not a class')

        fttuple = (fromsys, tosys)
        if fttuple in self._composite_cache:
            # reuse the transform built by an earlier call
            return self._composite_cache[fttuple]

        path, distance = self.find_shortest_path(fromsys, tosys)

        if path is None:
            return None

        transforms = []
        currsys = fromsys
        for p in path[1:]:  # first element is fromsys so we skip it
            transforms.append(self._graph[currsys][p])
            currsys = p

        comptrans = CompositeTransform(transforms, fromsys, tosys,
                                       register_graph=False)
        self._composite_cache[fttuple] = comptrans
        return comptrans

    def lookup_name(self, name):
        """
        Tries to locate the coordinate class with the provided alias.

        Parameters
        ----------
        name : str
            The alias to look up.

        Returns
        -------
        coordcls
            The coordinate class corresponding to the ``name`` or `None` if
            no such class exists.
        """
        return self._cached_names.get(name, None)

    def get_names(self):
        """
        Returns all available transform names. They will all be
        valid arguments to `lookup_name`.

        Returns
        -------
        nms : list
            The aliases for coordinate systems.
        """
        return list(six.iterkeys(self._cached_names))

    def to_dot_graph(self, priorities=True, addnodes=(), savefn=None,
                     savelayout='plain', saveformat=None):
        """
        Converts this transform graph to the graphviz_ DOT format.

        Optionally saves it (requires `graphviz`_ to be installed and on
        your path).

        .. _graphviz: http://www.graphviz.org/

        Parameters
        ----------
        priorities : bool
            If `True`, show the priority values for each transform.  Otherwise,
            they will not be included in the graph.
        addnodes : sequence of classes
            Additional coordinate systems to add (this can include systems
            already in the transform graph, but they will only appear once).
        savefn : `None` or str
            The file name to save this graph to or `None` to not save
            to a file.
        savelayout : str
            The graphviz program to use to lay out the graph (see
            graphviz_ for details) or 'plain' to just save the DOT graph
            content. Ignored if ``savefn`` is `None`.
        saveformat : str
            The graphviz output format (i.e., the ``-Txxx`` option for
            the command line program; see graphviz docs for details).
            Ignored if ``savefn`` is `None`.

        Returns
        -------
        dotgraph : str
            A string with the DOT format graph.
        """

        nodes = []
        # find the node names
        for a in self._graph:
            if a not in nodes:
                nodes.append(a)
            for b in self._graph[a]:
                if b not in nodes:
                    nodes.append(b)
        for node in addnodes:
            if node not in nodes:
                nodes.append(node)
        nodenames = []
        invclsaliases = dict([(v, k) for k, v in six.iteritems(self._cached_names)])
        for n in nodes:
            if n in invclsaliases:
                nodenames.append('{0} [shape=oval label="{0}\\n`{1}`"]'.format(n.__name__, invclsaliases[n]))
            else:
                nodenames.append(n.__name__ + '[ shape=oval ]')

        edgenames = []
        # Now the edges
        for a in self._graph:
            agraph = self._graph[a]
            for b in agraph:
                pri = agraph[b].priority if hasattr(agraph[b], 'priority') else 1
                edgenames.append((a.__name__, b.__name__, pri))

        # generate simple dot format graph
        lines = ['digraph AstropyCoordinateTransformGraph {']
        lines.append('; '.join(nodenames) + ';')
        for enm1, enm2, weights in edgenames:
            labelstr = '[ label = "{0}" ]'.format(weights) if priorities else ''
            lines.append('{0} -> {1}{2};'.format(enm1, enm2, labelstr))
        lines.append('')
        lines.append('overlap=false')
        lines.append('}')
        dotgraph = '\n'.join(lines)

        if savefn is not None:
            if savelayout == 'plain':
                with open(savefn, 'w') as f:
                    f.write(dotgraph)
            else:
                args = [savelayout]
                if saveformat is not None:
                    args.append('-T' + saveformat)
                proc = subprocess.Popen(args, stdin=subprocess.PIPE,
                                        stdout=subprocess.PIPE,
                                        stderr=subprocess.PIPE)
                # the pipes carry bytes, so encode the DOT text on the way in
                stdout, stderr = proc.communicate(dotgraph.encode('utf-8'))
                if proc.returncode != 0:
                    raise IOError('problem running graphviz: \n' +
                                  stderr.decode('utf-8', 'replace'))

                # graphviz output may be binary (e.g. png), so write bytes
                with open(savefn, 'wb') as f:
                    f.write(stdout)

        return dotgraph

    def to_networkx_graph(self):
        """
        Converts this transform graph into a networkx graph.

        .. note::
            You must have the `networkx <http://networkx.lanl.gov/>`_
            package installed for this to work.

        Returns
        -------
        nxgraph : `networkx.Graph <http://networkx.lanl.gov/reference/classes.graph.html>`_
            This `TransformGraph` as a `networkx.Graph`_.
        """
        import networkx as nx

        nxgraph = nx.Graph()

        # first make the nodes
        for a in self._graph:
            if a not in nxgraph:
                nxgraph.add_node(a)
            for b in self._graph[a]:
                if b not in nxgraph:
                    nxgraph.add_node(b)

        # Now the edges
        for a in self._graph:
            agraph = self._graph[a]
            for b in agraph:
                pri = agraph[b].priority if hasattr(agraph[b], 'priority') else 1
                nxgraph.add_edge(a, b, weight=pri)

        return nxgraph

    def transform(self, transcls, fromsys, tosys, priority=1):
        """
        A function decorator for defining transformations.

        .. note::
            If decorating a static method of a class, ``@staticmethod``
            should be added *above* this decorator.

        Parameters
        ----------
        transcls : class
            The class of the transformation object to create.
        fromsys : class
            The coordinate frame class to start from.
        tosys : class
            The coordinate frame class to transform into.
        priority : number
            The priority of this transform when finding the shortest
            coordinate transform path - large numbers are lower priorities.

        Returns
        -------
        deco : function
            A function that can be called on another function as a decorator
            (see example).

        Notes
        -----
        This decorator assumes the first argument of the ``transcls``
        initializer accepts a callable, and that the second and third
        are ``fromsys`` and ``tosys``. If this is not true, you should just
        initialize the class manually and use `add_transform` instead of
        using this decorator.

        Examples
        --------

        ::

            graph = TransformGraph()

            class Frame1(BaseCoordinateFrame):
                ...

            class Frame2(BaseCoordinateFrame):
                ...

            @graph.transform(FunctionTransform, Frame1, Frame2)
            def f1_to_f2(f1_obj, f2_frame):
                ... build f2_obj from f1_obj and f2_frame ...
                return f2_obj

        """
        def deco(func):
            # this doesn't do anything directly with the transform because
            # ``register_graph=self`` stores it in the transform graph
            # automatically
            transcls(func, fromsys, tosys, priority=priority,
                     register_graph=self)
            return func
        return deco
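

# Illustrative sketch (hypothetical helper, not part of this module's API):
# the registration pattern behind ``TransformGraph.transform``.  The outer
# call captures ``fromsys``, ``tosys`` and ``priority``; the inner decorator
# stores the wrapped function in a dict-of-dicts keyed by source and
# destination and returns the function unchanged, so it stays directly
# callable.  A plain ``(priority, func)`` tuple stands in for the transform
# object that the real method creates via ``transcls``.
def _sketch_make_registry():
    from collections import defaultdict

    graph = defaultdict(dict)  # graph[fromsys][tosys] -> (priority, func)

    def transform(fromsys, tosys, priority=1):
        def deco(func):
            graph[fromsys][tosys] = (float(priority), func)
            return func  # the undecorated function remains usable
        return deco

    return graph, transform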
#<--------------------Define the builtin transform classes--------------------->

@six.add_metaclass(ABCMeta)
class CoordinateTransform(object):
    """
    An object that transforms a coordinate from one system to another.
    Subclasses must implement `__call__` with the provided signature.
    They should also call this superclass's ``__init__`` in their
    ``__init__``.

    Parameters
    ----------
    fromsys : class
        The coordinate frame class to start from.
    tosys : class
        The coordinate frame class to transform into.
    priority : number
        The priority of this transform when finding the shortest
        coordinate transform path - large numbers are lower priorities.
    register_graph : `TransformGraph` or `None`
        A graph to register this transformation with on creation, or
        `None` to leave it unregistered.
    """

    def __init__(self, fromsys, tosys, priority=1, register_graph=None):
        if not inspect.isclass(fromsys):
            raise TypeError('fromsys must be a class')
        if not inspect.isclass(tosys):
            raise TypeError('tosys must be a class')

        self.fromsys = fromsys
        self.tosys = tosys
        self.priority = float(priority)

        if register_graph:
            # add this transform to the requested graph
            self.register(register_graph)

        self.overlapping_frame_attr_names = overlap = []
        if (hasattr(fromsys, 'get_frame_attr_names') and
                hasattr(tosys, 'get_frame_attr_names')):
            # the if statement is there so that non-frame things might be usable
            # if it makes sense
            for from_nm in fromsys.get_frame_attr_names():
                if from_nm in tosys.get_frame_attr_names():
                    overlap.append(from_nm)

    def register(self, graph):
        """
        Add this transformation to the requested transformation graph,
        replacing anything already connecting these two coordinates.

        Parameters
        ----------
        graph : a TransformGraph object
            The graph to register this transformation with.
        """
        graph.add_transform(self.fromsys, self.tosys, self)

    def unregister(self, graph):
        """
        Remove this transformation from the requested transformation
        graph.

        Parameters
        ----------
        graph : a TransformGraph object
            The graph to unregister this transformation from.

        Raises
        ------
        ValueError
            If this is not currently in the transform graph.
        """
        graph.remove_transform(self.fromsys, self.tosys, self)

    @abstractmethod
    def __call__(self, fromcoord, toframe):
        """
        Does the actual coordinate transformation from the ``fromsys`` class to
        the ``tosys`` class.

        Parameters
        ----------
        fromcoord : fromsys object
            An object of class matching ``fromsys`` that is to be transformed.
        toframe : object
            An object that has the attributes necessary to fully specify the
            frame.  That is, it must have attributes with names that match the
            keys of the dictionary that ``tosys.get_frame_attr_names()``
            returns. Typically this is of class ``tosys``, but it *might* be
            some other class as long as it has the appropriate attributes.

        Returns
        -------
        tocoord : tosys object
            The new coordinate after the transform has been applied.
        """


class FunctionTransform(CoordinateTransform):
    """
    A coordinate transformation defined by a function that accepts a
    coordinate object and returns the transformed coordinate object.

    Parameters
    ----------
    func : callable
        The transformation function. Should have a call signature
        ``func(fromcoord, toframe)``. Note that, unlike
        `CoordinateTransform.__call__`, ``toframe`` is assumed to be of type
        ``tosys`` for this function.
    fromsys : class
        The coordinate frame class to start from.
    tosys : class
        The coordinate frame class to transform into.
    priority : number
        The priority of this transform when finding the shortest
        coordinate transform path - large numbers are lower priorities.
    register_graph : `TransformGraph` or `None`
        A graph to register this transformation with on creation, or
        `None` to leave it unregistered.

    Raises
    ------
    TypeError
        If ``func`` is not callable.
    ValueError
        If ``func`` cannot accept two arguments.
    """

    def __init__(self, func, fromsys, tosys, priority=1, register_graph=None):
        if not six.callable(func):
            raise TypeError('func must be callable')

        # skip the check if signature() raises TypeError for an object it
        # cannot introspect
        with suppress(TypeError):
            params = list(signature(func).parameters.values())
            positional = [p for p in params
                          if p.kind in (p.POSITIONAL_ONLY,
                                        p.POSITIONAL_OR_KEYWORD)]
            n_required = sum(1 for p in positional if p.default is p.empty)
            has_varargs = any(p.kind == p.VAR_POSITIONAL for p in params)
            if n_required > 2 or (len(positional) < 2 and not has_varargs):
                raise ValueError('provided function does not accept two arguments')

        self.func = func

        super(FunctionTransform, self).__init__(fromsys, tosys,
                                                priority=priority,
                                                register_graph=register_graph)

    def __call__(self, fromcoord, toframe):
        res = self.func(fromcoord, toframe)

        if not isinstance(res, self.tosys):
            raise TypeError('the transformation function yielded {0} but '
                            'should have been of type {1}'.format(res, self.tosys))

        return res
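

# Illustrative sketch (hypothetical helper, not part of this module's API):
# one way to answer "can ``func`` be called as ``func(fromcoord, toframe)``?",
# the question ``FunctionTransform.__init__`` asks before storing ``func``.
# It assumes Python 3.3+ and uses the standard-library ``inspect.signature``
# instead of the ``funcsigs`` backport imported by this module; callables
# whose signature cannot be introspected are treated as acceptable.
def _sketch_accepts_two_args(func):
    import inspect

    try:
        params = list(inspect.signature(func).parameters.values())
    except (TypeError, ValueError):
        return True  # no introspectable signature: give it the benefit of the doubt

    positional = [p for p in params
                  if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)]
    n_required = sum(1 for p in positional if p.default is p.empty)
    has_varargs = any(p.kind == p.VAR_POSITIONAL for p in params)
    # keyword-only parameters without defaults can never be filled positionally
    kwonly_required = any(p.kind == p.KEYWORD_ONLY and p.default is p.empty
                          for p in params)
    if kwonly_required or n_required > 2:
        return False
    return len(positional) >= 2 or has_varargs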


class StaticMatrixTransform(CoordinateTransform):
    """
    A coordinate transformation defined as a 3 x 3 cartesian
    transformation matrix.

    This is distinct from `DynamicMatrixTransform` in that this kind of
    matrix is independent of frame attributes.  That is, it depends *only*
    on the class of the frame.

    Parameters
    ----------
    matrix : array-like or callable
        A 3 x 3 matrix for transforming 3-vectors. In most cases this will
        be orthogonal (unitary), although this is not strictly required. If
        a callable, it will be called *with no arguments* to get the matrix.
    fromsys : class
        The coordinate frame class to start from.
    tosys : class
        The coordinate frame class to transform into.
    priority : number
        The priority of this transform when finding the shortest
        coordinate transform path - large numbers are lower priorities.
    register_graph : `TransformGraph` or `None`
        A graph to register this transformation with on creation, or
        `None` to leave it unregistered.

    Raises
    ------
    ValueError
        If the matrix is not 3 x 3.
    """

    def __init__(self, matrix, fromsys, tosys, priority=1, register_graph=None):
        if six.callable(matrix):
            matrix = matrix()
        self.matrix = np.array(matrix)

        if self.matrix.shape != (3, 3):
            raise ValueError('Provided matrix is not 3 x 3')

        super(StaticMatrixTransform, self).__init__(fromsys, tosys,
                                                    priority=priority,
                                                    register_graph=register_graph)

    def __call__(self, fromcoord, toframe):
        from .representation import UnitSphericalRepresentation

        newrep = fromcoord.cartesian.transform(self.matrix)
        if issubclass(fromcoord.data.__class__, UnitSphericalRepresentation):
            # need to special-case this because otherwise the new class will
            # think it has a valid distance
            newrep = newrep.represent_as(fromcoord.data.__class__)

        frameattrs = dict([(attrnm, getattr(fromcoord, attrnm))
                           for attrnm in self.overlapping_frame_attr_names])

        return toframe.realize_frame(newrep, **frameattrs)
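

# Illustrative sketch (hypothetical helper, not part of this module's API):
# the kind of 3 x 3 matrix ``StaticMatrixTransform`` expects.  This is an
# *active*, right-handed rotation about the z axis; matrices for real frames
# may follow a different (e.g. passive) convention, so the sign of the angle
# here is an assumption.  Because it needs no frame attributes, a
# zero-argument callable such as ``lambda: _sketch_rotation_z(30.)`` could
# also be passed as the ``matrix`` argument.
def _sketch_rotation_z(angle_deg):
    import numpy as np

    theta = np.radians(angle_deg)
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[c, -s, 0.],
                     [s, c, 0.],
                     [0., 0., 1.]])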


class DynamicMatrixTransform(CoordinateTransform):
    """
    A coordinate transformation specified as a function that yields a
    3 x 3 cartesian transformation matrix.

    This is similar to, but distinct from, `StaticMatrixTransform`, in that
    the matrix for this class might depend on frame attributes.

    Parameters
    ----------
    matrix_func : callable
        A callable that has the signature ``matrix_func(fromcoord, toframe)``
        and returns a 3 x 3 matrix that converts ``fromcoord`` in a cartesian
        representation to the new coordinate system.
    fromsys : class
        The coordinate frame class to start from.
    tosys : class
        The coordinate frame class to transform into.
    priority : number
        The priority of this transform when finding the shortest
        coordinate transform path - large numbers are lower priorities.
    register_graph : `TransformGraph` or `None`
        A graph to register this transformation with on creation, or
        `None` to leave it unregistered.

    Raises
    ------
    TypeError
        If ``matrix_func`` is not callable.
    """

    def __init__(self, matrix_func, fromsys, tosys, priority=1,
                 register_graph=None):
        if not six.callable(matrix_func):
            raise TypeError('matrix_func is not callable')
        self.matrix_func = matrix_func

        super(DynamicMatrixTransform, self).__init__(fromsys, tosys,
                                                     priority=priority,
                                                     register_graph=register_graph)

    def __call__(self, fromcoord, toframe):
        from .representation import (CartesianRepresentation,
                                     UnitSphericalRepresentation)

        transform_matrix = self.matrix_func(fromcoord, toframe)

        rep = fromcoord.represent_as(CartesianRepresentation)
        newrep = rep.transform(transform_matrix)

        if issubclass(fromcoord.data.__class__, UnitSphericalRepresentation):
            # need to special-case this because otherwise the new class will
            # think it has a valid distance
            newrep = newrep.represent_as(fromcoord.data.__class__)

        return toframe.realize_frame(newrep)
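

# Illustrative sketch (hypothetical helper, not part of this module's API):
# a ``matrix_func`` for ``DynamicMatrixTransform`` whose matrix depends on
# frame attributes.  The attribute name ``tilt_deg`` is an assumption made
# purely for illustration; real frames expose their own attributes.  The
# result is a passive rotation about the x axis by the difference in tilt
# between the two frames (the sign convention is also an assumption).
def _sketch_matrix_func(fromcoord, toframe):
    import numpy as np

    theta = np.radians(toframe.tilt_deg - fromcoord.tilt_deg)
    c, s = np.cos(theta), np.sin(theta)
    return np.array([[1., 0., 0.],
                     [0., c, s],
                     [0., -s, c]])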


class CompositeTransform(CoordinateTransform):
    """
    A transformation constructed by combining together a series of single-step
    transformations.

    Note that the intermediate frame objects are constructed using any frame
    attributes in ``toframe`` or ``fromcoord`` that overlap with the
    intermediate frame (``toframe`` favored over ``fromcoord`` if there's a
    conflict).  Any frame attributes that are not present use the defaults.

    Parameters
    ----------
    transforms : sequence of `CoordinateTransform` objects
        The sequence of transformations to apply.
    fromsys : class
        The coordinate frame class to start from.
    tosys : class
        The coordinate frame class to transform into.
    priority : number
        The priority of this transform when finding the shortest
        coordinate transform path - large numbers are lower priorities.
    register_graph : `TransformGraph` or `None`
        A graph to register this transformation with on creation, or
        `None` to leave it unregistered.
    collapse_static_mats : bool
        If `True`, consecutive `StaticMatrixTransform` objects will be
        collapsed into a single transformation to speed up the calculation.
    """

    def __init__(self, transforms, fromsys, tosys, priority=1,
                 register_graph=None, collapse_static_mats=True):
        super(CompositeTransform, self).__init__(fromsys, tosys,
                                                 priority=priority,
                                                 register_graph=register_graph)

        if collapse_static_mats:
            transforms = self._combine_statics(transforms)

        self.transforms = tuple(transforms)

    def _combine_statics(self, transforms):
        """
        Combines together sequences of `StaticMatrixTransform`s into a single
        transform and returns the resulting list of transforms.
        """
        newtrans = []
        for currtrans in transforms:
            lasttrans = newtrans[-1] if len(newtrans) > 0 else None

            if (isinstance(lasttrans, StaticMatrixTransform) and
                    isinstance(currtrans, StaticMatrixTransform)):
                # ``lasttrans`` is applied first, so the combined matrix is
                # ``currtrans.matrix . lasttrans.matrix``
                combinedmat = np.dot(currtrans.matrix, lasttrans.matrix)
                newtrans[-1] = StaticMatrixTransform(combinedmat,
                                                     lasttrans.fromsys,
                                                     currtrans.tosys)
            else:
                newtrans.append(currtrans)
        return newtrans

    def __call__(self, fromcoord, toframe):
        curr_coord = fromcoord
        for t in self.transforms:
            # build an intermediate frame with attributes taken from either
            # `toframe`, or if not there, `fromcoord`, or if not there, use
            # the defaults
            # TODO: caching this information when creating the transform may
            # speed things up a lot
            frattrs = {}
            for inter_frame_attr_nm in t.tosys.get_frame_attr_names():
                if hasattr(toframe, inter_frame_attr_nm):
                    attr = getattr(toframe, inter_frame_attr_nm)
                    frattrs[inter_frame_attr_nm] = attr
                elif hasattr(fromcoord, inter_frame_attr_nm):
                    attr = getattr(fromcoord, inter_frame_attr_nm)
                    frattrs[inter_frame_attr_nm] = attr

            curr_toframe = t.tosys(**frattrs)
            curr_coord = t(curr_coord, curr_toframe)

        # this is safe even in the case where self.transforms is empty, because
        # coordinate objects are immutable, so copying is not needed
        return curr_coord
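

# Illustrative sketch (hypothetical helper, not part of this module's API):
# why collapsing consecutive static matrices must multiply in reverse order
# of application.  If ``M1`` is applied first and ``M2`` second, a vector
# ``v`` becomes ``M2 . (M1 . v)``, so the single collapsed matrix is
# ``M2 . M1``; ``M1 . M2`` gives a different result whenever the two
# matrices do not commute (e.g. rotations about different axes).
def _sketch_collapse(matrices):
    import numpy as np

    combined = np.identity(3)
    for mat in matrices:  # matrices listed in the order they are applied
        combined = np.dot(mat, combined)
    return combined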