File: transformations.py

package info (click to toggle)
python-astropy 1.3-8~bpo8+2
  • links: PTS, VCS
  • area: main
  • in suites: jessie-backports
  • size: 44,292 kB
  • sloc: ansic: 160,360; python: 137,322; sh: 11,493; lex: 7,638; yacc: 4,956; xml: 1,796; makefile: 474; cpp: 364
file content (911 lines) | stat: -rw-r--r-- 33,474 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
# Licensed under a 3-clause BSD style license - see LICENSE.rst

"""
This module contains a general framework for defining graphs of transformations
between coordinates, suitable for either spatial coordinates or more generalized
coordinate systems.

The fundamental idea is that each class is a node in the transformation graph,
and transitions from one node to another are defined as functions (or methods)
wrapped in transformation objects.

This module also includes more specific transformation classes for
celestial/spatial coordinate frames, generally focused around matrix-style
transformations that are typically how the algorithms are defined.
"""

from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

import heapq
import inspect
import subprocess

from abc import ABCMeta, abstractmethod
from collections import defaultdict

import numpy as np

from ..utils.compat import suppress
from ..utils.compat.funcsigs import signature
from ..extern import six
from ..extern.six.moves import range


# public names re-exported by ``from ... import *``
__all__ = [
    'TransformGraph',
    'CoordinateTransform',
    'FunctionTransform',
    'StaticMatrixTransform',
    'DynamicMatrixTransform',
    'CompositeTransform',
]


class TransformGraph(object):
    """
    A graph representing the paths between coordinate frames.

    Nodes are frame classes.  A directed edge ``a -> b`` is the transform
    object stored in ``self._graph[a][b]``.  Shortest paths, composite
    transforms, the frame set and the name lookup are all cached; those
    caches are cleared by `invalidate_cache`.
    """

    def __init__(self):
        # adjacency mapping: {fromsys: {tosys: transform}}
        self._graph = defaultdict(dict)
        self.invalidate_cache()  # generates cache entries

    @property
    def _cached_names(self):
        # lazily-built {name: frame class} mapping for every frame class in
        # the graph that has a non-None ``name`` attribute
        if self._cached_names_dct is None:
            self._cached_names_dct = dct = {}
            for c in self.frame_set:
                nm = getattr(c, 'name', None)
                if nm is not None:
                    dct[nm] = c

        return self._cached_names_dct

    @property
    def frame_set(self):
        """
        A `set` of all the frame classes present in this `TransformGraph`.
        """
        if self._cached_frame_set is None:
            self._cached_frame_set = frm_set = set()
            for a in self._graph:
                frm_set.add(a)
                for b in self._graph[a]:
                    frm_set.add(b)

        # hand out a copy so callers cannot mutate the cache
        return self._cached_frame_set.copy()

    def invalidate_cache(self):
        """
        Invalidates the cache that stores optimizations for traversing the
        transform graph.  This is called automatically when transforms
        are added or removed, but will need to be called manually if
        weights on transforms are modified inplace.
        """
        self._cached_names_dct = None
        self._cached_frame_set = None
        self._shortestpaths = {}
        self._composite_cache = {}

    def add_transform(self, fromsys, tosys, transform):
        """
        Add a new coordinate transformation to the graph.

        Parameters
        ----------
        fromsys : class
            The coordinate frame class to start from.
        tosys : class
            The coordinate frame class to transform into.
        transform : CoordinateTransform or similar callable
            The transformation object. Typically a `CoordinateTransform` object,
            although it may be some other callable that is called with the same
            signature.

        Raises
        ------
        TypeError
            If ``fromsys`` or ``tosys`` are not classes or ``transform`` is
            not callable.
        """

        if not inspect.isclass(fromsys):
            raise TypeError('fromsys must be a class')
        if not inspect.isclass(tosys):
            raise TypeError('tosys must be a class')
        if not six.callable(transform):
            raise TypeError('transform must be callable')

        self._graph[fromsys][tosys] = transform
        self.invalidate_cache()

    def remove_transform(self, fromsys, tosys, transform):
        """
        Removes a coordinate transform from the graph.

        Parameters
        ----------
        fromsys : class or `None`
            The coordinate frame *class* to start from. If `None`,
            ``transform`` will be searched for and removed (``tosys`` must
            also be `None`).
        tosys : class or `None`
            The coordinate frame *class* to transform into. If `None`,
            ``transform`` will be searched for and removed (``fromsys`` must
            also be `None`).
        transform : callable or `None`
            The transformation object to be removed or `None`.  If `None`
            and ``tosys`` and ``fromsys`` are supplied, there will be no
            check to ensure the correct object is removed.

        Raises
        ------
        ValueError
            If only one of ``fromsys``/``tosys`` is `None`, if all three
            arguments are `None`, or if ``transform`` is not found where it
            was requested.
        """
        if fromsys is None or tosys is None:
            if not (tosys is None and fromsys is None):
                raise ValueError('fromsys and tosys must both be None if either are')
            if transform is None:
                raise ValueError('cannot give all Nones to remove_transform')

            # search for the requested transform by brute force and remove it.
            # The inner-dict *values* are the transforms (the keys are frame
            # classes).  Only the first matching edge is removed, and the
            # outer loop is left as soon as a match is found.
            for agraph in self._graph.values():
                matches = [b for b, t in agraph.items() if t is transform]
                if matches:
                    del agraph[matches[0]]
                    break
            else:
                raise ValueError('Could not find transform {0} in the '
                                 'graph'.format(transform))

        else:
            # use .get so an unknown ``fromsys`` is not inserted into the
            # defaultdict as a side effect
            fromedges = self._graph.get(fromsys, {})
            if transform is None:
                fromedges.pop(tosys, None)
            else:
                curr = fromedges.get(tosys, None)
                if curr is transform:
                    fromedges.pop(tosys)
                else:
                    raise ValueError('Current transform from {0} to {1} is not '
                                     '{2}'.format(fromsys, tosys, transform))
        self.invalidate_cache()

    def find_shortest_path(self, fromsys, tosys):
        """
        Computes the shortest distance along the transform graph from
        one system to another.

        Parameters
        ----------
        fromsys : class
            The coordinate frame class to start from.
        tosys : class
            The coordinate frame class to transform into.

        Returns
        -------
        path : list of classes or `None`
            The path from ``fromsys`` to ``tosys`` as an in-order sequence
            of classes.  This list includes *both* ``fromsys`` and
            ``tosys``. Is `None` if there is no possible path.
        distance : number
            The total distance/priority from ``fromsys`` to ``tosys``.  If
            priorities are not set this is the number of transforms
            needed. Is ``inf`` if there is no possible path.
        """

        inf = float('inf')

        # Don't index self._graph directly here: it is a defaultdict, so a
        # plain lookup would silently add ``fromsys`` to the graph as an
        # edgeless node (without invalidating the caches).
        fromedges = self._graph.get(fromsys, {})

        # special-case the 0 or 1-path
        if tosys is fromsys and tosys not in fromedges:
            # Means there's no transform necessary to go from it to itself.
            return [tosys], 0
        if tosys in fromedges:
            # this will also catch the case where tosys is fromsys, but has
            # a defined transform.
            t = fromedges[tosys]
            return [fromsys, tosys], float(t.priority if hasattr(t, 'priority') else 1)

        #otherwise, need to construct the path:

        if fromsys in self._shortestpaths:
            # already have a cached result
            fpaths = self._shortestpaths[fromsys]
            if tosys in fpaths:
                return fpaths[tosys]
            else:
                return None, inf

        # use Dijkstra's algorithm to find shortest path in all other cases

        nodes = []
        # first make the list of nodes
        for a in self._graph:
            if a not in nodes:
                nodes.append(a)
            for b in self._graph[a]:
                if b not in nodes:
                    nodes.append(b)

        if fromsys not in nodes or tosys not in nodes:
            # fromsys or tosys are isolated or not registered, so there's
            # certainly no way to get from one to the other
            return None, inf

        edgeweights = {}
        # construct another graph that is a dict of dicts of priorities
        # (used as edge weights in Dijkstra's algorithm)
        for a in self._graph:
            edgeweights[a] = aew = {}
            agraph = self._graph[a]
            for b in agraph:
                aew[b] = float(agraph[b].priority if hasattr(agraph[b], 'priority') else 1)

        # entries in q are [distance, count, nodeobj, pathlist]
        # count is needed because in py 3.x, tie-breaking fails on the nodes.
        # this way, insertion order is preserved if the weights are the same
        q = [[inf, i, n, []] for i, n in enumerate(nodes) if n is not fromsys]
        q.insert(0, [0, -1, fromsys, []])

        # this dict will store the distance to node from ``fromsys`` and the path
        result = {}

        # definitely starts as a valid heap because of the insert line; from the
        # node to itself is always the shortest distance
        while len(q) > 0:
            d, orderi, n, path = heapq.heappop(q)

            if d == inf:
                # everything left is unreachable from fromsys, just copy them to
                # the results and jump out of the loop
                result[n] = (None, d)
                for d, orderi, n, path in q:
                    result[n] = (None, d)
                break
            else:
                # ``path`` is shared with result[n], so appending n here
                # completes the stored path to include n itself
                result[n] = (path, d)
                path.append(n)
                if n not in edgeweights:
                    # this is a system that can be transformed to, but not from.
                    continue
                for n2 in edgeweights[n]:
                    if n2 not in result:  # already visited
                        # find where n2 is in the heap
                        for i in range(len(q)):
                            if q[i][2] == n2:
                                break
                        else:
                            raise ValueError('n2 not in heap - this should be impossible!')

                        newd = d + edgeweights[n][n2]
                        if newd < q[i][0]:
                            q[i][0] = newd
                            q[i][3] = list(path)
                            heapq.heapify(q)

        # cache for later use
        self._shortestpaths[fromsys] = result
        return result[tosys]

    def get_transform(self, fromsys, tosys):
        """
        Generates and returns the `CompositeTransform` for a transformation
        between two coordinate systems.

        Parameters
        ----------
        fromsys : class
            The coordinate frame class to start from.
        tosys : class
            The coordinate frame class to transform into.

        Returns
        -------
        trans : `CompositeTransform` or `None`
            If there is a path from ``fromsys`` to ``tosys``, this is a
            transform object for that path.   If no path could be found, this is
            `None`.

        Raises
        ------
        TypeError
            If ``fromsys`` or ``tosys`` is not a class.

        Notes
        -----
        This function always returns a `CompositeTransform`, because
        `CompositeTransform` is slightly more adaptable in the way it can be
        called than other transform classes. Specifically, it takes care of
        intermediate steps of transformations in a way that is consistent with
        1-hop transformations.

        """
        if not inspect.isclass(fromsys):
            raise TypeError('fromsys is not a class')
        if not inspect.isclass(tosys):
            raise TypeError('tosys is not a class')

        fttuple = (fromsys, tosys)
        if fttuple in self._composite_cache:
            return self._composite_cache[fttuple]

        path, _ = self.find_shortest_path(fromsys, tosys)
        if path is None:
            return None

        # ``path`` starts with fromsys; pair each frame with its successor
        # to collect the per-hop transforms
        transforms = [self._graph[a][b] for a, b in zip(path[:-1], path[1:])]

        comptrans = CompositeTransform(transforms, fromsys, tosys,
                                       register_graph=False)
        self._composite_cache[fttuple] = comptrans
        return comptrans

    def lookup_name(self, name):
        """
        Tries to locate the coordinate class with the provided alias.

        Parameters
        ----------
        name : str
            The alias to look up.

        Returns
        -------
        coordcls
            The coordinate class corresponding to the ``name`` or `None` if
            no such class exists.
        """

        return self._cached_names.get(name, None)

    def get_names(self):
        """
        Returns all available transform names. They will all be
        valid arguments to `lookup_name`.

        Returns
        -------
        nms : list
            The aliases for coordinate systems.
        """
        return list(six.iterkeys(self._cached_names))

    def to_dot_graph(self, priorities=True, addnodes=(), savefn=None,
                     savelayout='plain', saveformat=None):
        """
        Converts this transform graph to the graphviz_ DOT format.

        Optionally saves it (requires `graphviz`_ be installed and on your path).

        .. _graphviz: http://www.graphviz.org/

        Parameters
        ----------
        priorities : bool
            If `True`, show the priority values for each transform.  Otherwise,
            they will not be included in the graph.
        addnodes : sequence of class
            Additional coordinate frame classes to add (this can include
            classes already in the transform graph, but they will only appear
            once).
        savefn : `None` or str
            The file name to save this graph to or `None` to not save
            to a file.
        savelayout : str
            The graphviz program to use to layout the graph (see
            graphviz_ for details) or 'plain' to just save the DOT graph
            content. Ignored if ``savefn`` is `None`.
        saveformat : str
            The graphviz output format. (e.g. the ``-Txxx`` option for
            the command line program - see graphviz docs for details).
            Ignored if ``savefn`` is `None`.

        Returns
        -------
        dotgraph : str
            A string with the DOT format graph.

        Raises
        ------
        IOError
            If the graphviz program exits with a non-zero status.
        """

        nodes = []
        # find the node names
        for a in self._graph:
            if a not in nodes:
                nodes.append(a)
            for b in self._graph[a]:
                if b not in nodes:
                    nodes.append(b)
        for node in addnodes:
            if node not in nodes:
                nodes.append(node)
        nodenames = []
        invclsaliases = {v: k for k, v in self._cached_names.items()}
        for n in nodes:
            if n in invclsaliases:
                nodenames.append('{0} [shape=oval label="{0}\\n`{1}`"]'.format(n.__name__, invclsaliases[n]))
            else:
                nodenames.append(n.__name__ + '[ shape=oval ]')

        edgenames = []
        # Now the edges
        for a in self._graph:
            agraph = self._graph[a]
            for b in agraph:
                pri = agraph[b].priority if hasattr(agraph[b], 'priority') else 1
                edgenames.append((a.__name__, b.__name__, pri))

        # generate simple dot format graph
        lines = ['digraph AstropyCoordinateTransformGraph {']
        lines.append('; '.join(nodenames) + ';')
        for enm1, enm2, weights in edgenames:
            labelstr = '[ label = "{0}" ]'.format(weights) if priorities else ''
            lines.append('{0} -> {1}{2};'.format(enm1, enm2, labelstr))
        lines.append('')
        lines.append('overlap=false')
        lines.append('}')
        dotgraph = '\n'.join(lines)

        if savefn is not None:
            if savelayout == 'plain':
                with open(savefn, 'w') as f:
                    f.write(dotgraph)
            else:
                args = [savelayout]
                if saveformat is not None:
                    args.append('-T' + saveformat)
                proc = subprocess.Popen(args, stdin=subprocess.PIPE,
                                        stdout=subprocess.PIPE,
                                        stderr=subprocess.PIPE)
                # The pipes are opened in binary mode: feed encoded text in and
                # write graphviz's output (possibly binary, e.g. png) verbatim.
                stdout, stderr = proc.communicate(dotgraph.encode('utf-8'))
                if proc.returncode != 0:
                    raise IOError('problem running graphviz: \n' +
                                  stderr.decode('utf-8', 'replace'))

                with open(savefn, 'wb') as f:
                    f.write(stdout)

        return dotgraph

    def to_networkx_graph(self):
        """
        Converts this transform graph into a networkx graph.

        .. note::
            You must have the `networkx <http://networkx.lanl.gov/>`_
            package installed for this to work.

        .. note::
            The result is an undirected ``nx.Graph``, so the direction of
            each transform is not preserved.

        Returns
        -------
        nxgraph : `networkx.Graph <http://networkx.lanl.gov/reference/classes.graph.html>`_
            This `TransformGraph` as a `networkx.Graph`_.
        """
        import networkx as nx

        nxgraph = nx.Graph()

        # first make the nodes
        for a in self._graph:
            if a not in nxgraph:
                nxgraph.add_node(a)
            for b in self._graph[a]:
                if b not in nxgraph:
                    nxgraph.add_node(b)

        # Now the edges
        for a in self._graph:
            agraph = self._graph[a]
            for b in agraph:
                pri = agraph[b].priority if hasattr(agraph[b], 'priority') else 1
                nxgraph.add_edge(a, b, weight=pri)

        return nxgraph

    def transform(self, transcls, fromsys, tosys, priority=1):
        """
        A function decorator for defining transformations.

        .. note::
            If decorating a static method of a class, ``@staticmethod``
            should be  added *above* this decorator.

        Parameters
        ----------
        transcls : class
            The class of the transformation object to create.
        fromsys : class
            The coordinate frame class to start from.
        tosys : class
            The coordinate frame class to transform into.
        priority : number
            The priority of this transform when finding the shortest
            coordinate transform path - large numbers are lower priorities.

        Returns
        -------
        deco : function
            A function that can be called on another function as a decorator
            (see example).

        Notes
        -----
        This decorator assumes the first argument of the ``transcls``
        initializer accepts a callable, and that the second and third
        are ``fromsys`` and ``tosys``. If this is not true, you should just
        initialize the class manually and use `add_transform` instead of
        using this decorator.

        Examples
        --------

        ::

            graph = TransformGraph()

            class Frame1(BaseCoordinateFrame):
                ...

            class Frame2(BaseCoordinateFrame):
                ...

            @graph.transform(FunctionTransform, Frame1, Frame2)
            def f1_to_f2(f1_obj):
                ... do something with f1_obj ...
                return f2_obj


        """
        def deco(func):
            # this doesn't do anything directly with the transform because
            # ``register_graph=self`` stores it in the transform graph
            # automatically
            transcls(func, fromsys, tosys, priority=priority,
                     register_graph=self)
            return func
        return deco


#<--------------------Define the builtin transform classes--------------------->

@six.add_metaclass(ABCMeta)
class CoordinateTransform(object):
    """
    An object that transforms a coordinate from one system to another.
    Subclasses must implement `__call__` with the provided signature.
    They should also call this superclass's ``__init__`` in their
    ``__init__``.

    Parameters
    ----------
    fromsys : class
        The coordinate frame class to start from.
    tosys : class
        The coordinate frame class to transform into.
    priority : number
        The priority of this transform when finding the shortest
        coordinate transform path - large numbers are lower priorities.
    register_graph : `TransformGraph` or `None`
        A graph to register this transformation with on creation, or
        `None` to leave it unregistered.

    Raises
    ------
    TypeError
        If ``fromsys`` or ``tosys`` is not a class.
    """

    def __init__(self, fromsys, tosys, priority=1, register_graph=None):
        if not inspect.isclass(fromsys):
            raise TypeError('fromsys must be a class')
        if not inspect.isclass(tosys):
            raise TypeError('tosys must be a class')

        self.fromsys = fromsys
        self.tosys = tosys
        self.priority = float(priority)

        # Names of frame attributes shared by both frame classes.
        self.overlapping_frame_attr_names = overlap = []
        if (hasattr(fromsys, 'get_frame_attr_names') and
                hasattr(tosys, 'get_frame_attr_names')):
            #the if statement is there so that non-frame things might be usable
            #if it makes sense
            to_attr_names = tosys.get_frame_attr_names()
            for from_nm in fromsys.get_frame_attr_names():
                if from_nm in to_attr_names:
                    overlap.append(from_nm)

        # Register only once the object is fully initialized, so a failure
        # above cannot leave a half-built transform in a shared graph.
        if register_graph:
            self.register(register_graph)

    def register(self, graph):
        """
        Add this transformation to the requested Transformation graph,
        replacing anything already connecting these two coordinates.

        Parameters
        ----------
        graph : a TransformGraph object
            The graph to register this transformation with.
        """
        graph.add_transform(self.fromsys, self.tosys, self)

    def unregister(self, graph):
        """
        Remove this transformation from the requested transformation
        graph.

        Parameters
        ----------
        graph : a TransformGraph object
            The graph to unregister this transformation from.

        Raises
        ------
        ValueError
            If this is not currently in the transform graph.
        """
        graph.remove_transform(self.fromsys, self.tosys, self)

    @abstractmethod
    def __call__(self, fromcoord, toframe):
        """
        Does the actual coordinate transformation from the ``fromsys`` class to
        the ``tosys`` class.

        Parameters
        ----------
        fromcoord : fromsys object
            An object of class matching ``fromsys`` that is to be transformed.
        toframe : object
            An object that has the attributes necessary to fully specify the
            frame.  That is, it must have attributes with names that match the
            keys of the dictionary that ``tosys.get_frame_attr_names()``
            returns. Typically this is of class ``tosys``, but it *might* be
            some other class as long as it has the appropriate attributes.

        Returns
        -------
        tocoord : tosys object
            The new coordinate after the transform has been applied.
        """


class FunctionTransform(CoordinateTransform):
    """
    A coordinate transformation defined by a function that accepts a
    coordinate object and returns the transformed coordinate object.

    Parameters
    ----------
    func : callable
        The transformation function. Should have a call signature
        ``func(fromcoord, toframe)``. Note that, unlike
        `CoordinateTransform.__call__`, ``toframe`` is assumed to be of type
        ``tosys`` for this function.
    fromsys : class
        The coordinate frame class to start from.
    tosys : class
        The coordinate frame class to transform into.
    priority : number
        The priority of this transform when finding the shortest
        coordinate transform path - large numbers are lower priorities.
    register_graph : `TransformGraph` or `None`
        A graph to register this transformation with on creation, or
        `None` to leave it unregistered.

    Raises
    ------
    TypeError
        If ``func`` is not callable.
    ValueError
        If ``func`` cannot accept two arguments.


    """
    def __init__(self, func, fromsys, tosys, priority=1, register_graph=None):
        if not six.callable(func):
            raise TypeError('func must be callable')

        try:
            sig = signature(func)
        except (TypeError, ValueError):
            # Some callables cannot be introspected; any argument mismatch
            # will then surface when the transform is actually called.
            sig = None

        if sig is not None and not self._accepts_two_positional_args(sig):
            raise ValueError('provided function does not accept two arguments')

        self.func = func

        super(FunctionTransform, self).__init__(fromsys, tosys,
            priority=priority, register_graph=register_graph)

    @staticmethod
    def _accepts_two_positional_args(sig):
        """
        Return `True` if a callable with signature ``sig`` can be invoked
        as ``func(a, b)``, the way `__call__` invokes it.
        """
        params = list(sig.parameters.values())
        has_varargs = any(p.kind == p.VAR_POSITIONAL for p in params)
        positional = [p for p in params
                      if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)]
        n_required = sum(1 for p in positional if p.default is p.empty)
        # keyword-only parameters without defaults cannot be satisfied by a
        # purely positional call
        required_kwonly = any(p.kind == p.KEYWORD_ONLY and p.default is p.empty
                              for p in params)
        can_take_two = has_varargs or len(positional) >= 2
        return can_take_two and n_required <= 2 and not required_kwonly

    def __call__(self, fromcoord, toframe):
        res = self.func(fromcoord, toframe)
        if not isinstance(res, self.tosys):
            raise TypeError('the transformation function yielded {0} but '
                'should have been of type {1}'.format(res, self.tosys))
        return res


class StaticMatrixTransform(CoordinateTransform):
    """
    A coordinate transformation given by a fixed 3 x 3 cartesian matrix.

    Unlike `DynamicMatrixTransform`, the matrix here does not depend on any
    frame attributes; it is determined once, from the frame classes alone.

    Parameters
    ----------
    matrix : array-like or callable
        A 3 x 3 matrix applied to cartesian 3-vectors (usually, though not
        necessarily, unitary).  If callable, it is invoked *with no
        arguments* to produce the matrix.
    fromsys : class
        The coordinate frame class to start from.
    tosys : class
        The coordinate frame class to transform into.
    priority : number
        The priority of this transform when finding the shortest
        coordinate transform path - large numbers are lower priorities.
    register_graph : `TransformGraph` or `None`
        A graph to register this transformation with on creation, or
        `None` to leave it unregistered.

    Raises
    ------
    ValueError
        If the matrix is not 3 x 3

    """
    def __init__(self, matrix, fromsys, tosys, priority=1, register_graph=None):
        # a callable is evaluated once, here, to obtain the fixed matrix
        raw = matrix() if six.callable(matrix) else matrix
        self.matrix = np.array(raw)

        if self.matrix.shape != (3, 3):
            raise ValueError('Provided matrix is not 3 x 3')

        super(StaticMatrixTransform, self).__init__(
            fromsys, tosys, priority=priority, register_graph=register_graph)

    def __call__(self, fromcoord, toframe):
        from .representation import UnitSphericalRepresentation

        transformed = fromcoord.cartesian.transform(self.matrix)

        source_rep_cls = fromcoord.data.__class__
        if issubclass(source_rep_cls, UnitSphericalRepresentation):
            # convert back to the unit-spherical class, otherwise the new
            # frame would believe it carries a meaningful distance
            transformed = transformed.represent_as(source_rep_cls)

        # carry over any frame attributes common to both frame classes
        shared_attrs = {}
        for attr_name in self.overlapping_frame_attr_names:
            shared_attrs[attr_name] = getattr(fromcoord, attr_name)

        return toframe.realize_frame(transformed, **shared_attrs)


class DynamicMatrixTransform(CoordinateTransform):
    """
    A coordinate transformation whose 3 x 3 cartesian matrix is produced by a
    function each time the transform is applied.

    Unlike `StaticMatrixTransform`, whose matrix depends only on the frame
    classes, the matrix here may depend on frame attributes.

    Parameters
    ----------
    matrix_func : callable
        A callable with the signature ``matrix_func(fromcoord, toframe)`` that
        returns a 3 x 3 matrix converting ``fromcoord``, expressed in a
        cartesian representation, to the new coordinate system.
    fromsys : class
        The coordinate frame class to start from.
    tosys : class
        The coordinate frame class to transform into.
    priority : number
        The priority of this transform when finding the shortest
        coordinate transform path - large numbers are lower priorities.
    register_graph : `TransformGraph` or `None`
        A graph to register this transformation with on creation, or
        `None` to leave it unregistered.

    Raises
    ------
    TypeError
        If ``matrix_func`` is not callable

    """
    def __init__(self, matrix_func, fromsys, tosys, priority=1,
                 register_graph=None):
        if six.callable(matrix_func):
            self.matrix_func = matrix_func
        else:
            raise TypeError('matrix_func is not callable')

        super(DynamicMatrixTransform, self).__init__(
            fromsys, tosys, priority=priority, register_graph=register_graph)

    def __call__(self, fromcoord, toframe):
        """
        Compute the matrix for this pair of frames and apply it to ``fromcoord``.

        Returns the result of ``toframe.realize_frame`` on the transformed data.
        """
        from .representation import (CartesianRepresentation,
                                     UnitSphericalRepresentation)

        # The matrix is requested on every call, since it may depend on the
        # specific input coordinate and target frame.
        matrix = self.matrix_func(fromcoord, toframe)

        cartesian = fromcoord.represent_as(CartesianRepresentation)
        result = cartesian.transform(matrix)

        # Keep distance-less (unit-spherical) input distance-less; otherwise
        # the new frame would think it has a valid distance.
        input_cls = fromcoord.data.__class__
        if issubclass(input_cls, UnitSphericalRepresentation):
            result = result.represent_as(input_cls)

        return toframe.realize_frame(result)


class CompositeTransform(CoordinateTransform):
    """
    A transformation constructed by combining together a series of single-step
    transformations.

    Note that the intermediate frame objects are constructed using any frame
    attributes in ``toframe`` or ``fromframe`` that overlap with the intermediate
    frame (``toframe`` favored over ``fromframe`` if there's a conflict).  Any frame
    attributes that are not present use the defaults.

    Parameters
    ----------
    transforms : sequence of `CoordinateTransform` objects
        The sequence of transformations to apply.
    fromsys : class
        The coordinate frame class to start from.
    tosys : class
        The coordinate frame class to transform into.
    priority : number
        The priority of this transform when finding the shortest
        coordinate transform path - large numbers are lower priorities.
    register_graph : `TransformGraph` or `None`
        A graph to register this transformation with on creation, or
        `None` to leave it unregistered.
    collapse_static_mats : bool
        If `True`, consecutive `StaticMatrixTransform` will be collapsed into a
        single transformation to speed up the calculation.

    """
    def __init__(self, transforms, fromsys, tosys, priority=1,
                 register_graph=None, collapse_static_mats=True):
        super(CompositeTransform, self).__init__(fromsys, tosys,
                                                 priority=priority,
                                                 register_graph=register_graph)

        if collapse_static_mats:
            transforms = self._combine_statics(transforms)

        self.transforms = tuple(transforms)

    def _combine_statics(self, transforms):
        """
        Collapse runs of consecutive `StaticMatrixTransform` objects.

        Each run of adjacent `StaticMatrixTransform` steps is replaced by a
        single `StaticMatrixTransform` whose matrix is equivalent to applying
        the run's matrices in order. It goes from the first step's ``fromsys``
        to the last step's ``tosys``. All other transforms are kept unchanged
        and in their original order.

        Parameters
        ----------
        transforms : sequence of `CoordinateTransform`
            The single-step transforms, in the order they are applied.

        Returns
        -------
        newtrans : list of `CoordinateTransform`
            The (possibly shorter) list of transforms.
        """
        newtrans = []
        for currtrans in transforms:
            lasttrans = newtrans[-1] if newtrans else None

            if (isinstance(lasttrans, StaticMatrixTransform) and
                    isinstance(currtrans, StaticMatrixTransform)):
                # The matrices act on column vectors (v' = M v), so applying
                # ``lasttrans`` first and then ``currtrans`` gives
                # v'' = M_curr (M_last v): the later matrix goes on the left.
                combinedmat = np.dot(currtrans.matrix, lasttrans.matrix)
                newtrans[-1] = StaticMatrixTransform(combinedmat,
                                                     lasttrans.fromsys,
                                                     currtrans.tosys)
            else:
                newtrans.append(currtrans)
        return newtrans

    def __call__(self, fromcoord, toframe):
        """
        Apply each single-step transform in turn, starting from ``fromcoord``.

        Parameters
        ----------
        fromcoord : frame object
            The coordinate (with data) to transform.
        toframe : frame object
            The final frame. Its attributes take precedence over those of
            ``fromcoord`` when building the intermediate frames.

        Returns
        -------
        coord : frame object
            The output of the last step, or ``fromcoord`` itself if there are
            no steps.
        """
        curr_coord = fromcoord
        for t in self.transforms:
            # Build an intermediate frame with attributes taken from
            # ``toframe``, or if not there, from ``fromcoord``, or if not
            # there, left to the intermediate frame class defaults.
            # TODO: caching this information when creating the transform may
            # speed things up a lot
            frattrs = {}
            for inter_frame_attr_nm in t.tosys.get_frame_attr_names():
                if hasattr(toframe, inter_frame_attr_nm):
                    attr = getattr(toframe, inter_frame_attr_nm)
                    frattrs[inter_frame_attr_nm] = attr
                elif hasattr(fromcoord, inter_frame_attr_nm):
                    attr = getattr(fromcoord, inter_frame_attr_nm)
                    frattrs[inter_frame_attr_nm] = attr

            curr_toframe = t.tosys(**frattrs)
            curr_coord = t(curr_coord, curr_toframe)

        # This is safe even when self.transforms is empty, because coordinate
        # objects are immutable, so copying is not needed.
        return curr_coord