File: test_core.py

package info (click to toggle)
reproject 0.14.1-4
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid
  • size: 3,024 kB
  • sloc: python: 4,749; ansic: 1,022; makefile: 114
file content (957 lines) | stat: -rw-r--r-- 33,960 bytes parent folder | download | duplicates (2)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
# Licensed under a 3-clause BSD style license - see LICENSE.rst

import itertools

import dask.array as da
import numpy as np
import pytest
from astropy import units as u
from astropy.io import fits
from astropy.utils.data import get_pkg_data_filename
from astropy.wcs import WCS
from astropy.wcs.wcs import FITSFixedWarning
from astropy.wcs.wcsapi import HighLevelWCSWrapper, SlicedLowLevelWCS
from numpy.testing import assert_allclose

from reproject.interpolation.high_level import reproject_interp
from reproject.tests.helpers import array_footprint_to_hdulist

# TODO: add reference comparisons


def as_high_level_wcs(wcs):
    """
    Return ``wcs`` wrapped so that only the generic WCS API is exposed.

    The WCS is sliced with ``Ellipsis`` (keeping every axis) to obtain a
    low-level WCS API object, which is then wrapped as a high-level WCS API
    object. Tests use this to enforce a pure WCS-API code path.
    """
    low_level_view = SlicedLowLevelWCS(wcs, Ellipsis)
    return HighLevelWCSWrapper(low_level_view)


@pytest.mark.array_compare(single_reference=True)
@pytest.mark.parametrize("wcsapi", (False, True))
@pytest.mark.parametrize("roundtrip_coords", (False, True))
@pytest.mark.remote_data
def test_reproject_celestial_2d_gal2equ(wcsapi, roundtrip_coords):
    """
    Reproject a 2D Galactic image onto an equatorial (RA/Dec TAN) grid, so
    that a celestial frame conversion is involved.

    With ``wcsapi=True`` the input and output WCS are only exposed through the
    high-level WCS API and the output shape is passed explicitly; otherwise
    the HDU and output header are handed to reproject directly. The returned
    HDU list is compared against a stored reference.
    """
    filename = get_pkg_data_filename("data/galactic_2d.fits", package="reproject.tests")
    with fits.open(filename) as hdulist:
        hdu_in = hdulist[0]

        header_out = hdu_in.header.copy()
        for keyword, value in (
            ("CTYPE1", "RA---TAN"),
            ("CTYPE2", "DEC--TAN"),
            ("CRVAL1", 266.39311),
            ("CRVAL2", -28.939779),
        ):
            header_out[keyword] = value

        if wcsapi:
            # Only the WCS API is available, so the output shape cannot be
            # taken from a header and has to be given explicitly.
            wcs_in = as_high_level_wcs(WCS(hdu_in.header))
            input_data = (hdu_in.data, wcs_in)
            output_projection = as_high_level_wcs(WCS(header_out))
            extra_kwargs = {"shape_out": (header_out["NAXIS2"], header_out["NAXIS1"])}
        else:
            input_data = hdu_in
            output_projection = header_out
            extra_kwargs = {}

        array_out, footprint_out = reproject_interp(
            input_data, output_projection, roundtrip_coords=roundtrip_coords, **extra_kwargs
        )

    return array_footprint_to_hdulist(array_out, footprint_out, header_out)


# Every permutation of the three input axes, and every (wcsapi, axis_order)
# pairing of those permutations with both values of ``wcsapi``. The pairs are
# passed as a single parametrization to test_reproject_celestial_3d_equ2gal.
# itertools.product is used rather than a nested loop so that no loop
# variables leak into the module namespace.
AXIS_ORDER = list(itertools.permutations((0, 1, 2)))
COMBINATIONS = list(itertools.product((False, True), AXIS_ORDER))


@pytest.mark.array_compare(single_reference=True)
@pytest.mark.parametrize(("wcsapi", "axis_order"), tuple(COMBINATIONS))
@pytest.mark.parametrize("roundtrip_coords", (False, True))
@pytest.mark.remote_data
def test_reproject_celestial_3d_equ2gal(wcsapi, axis_order, roundtrip_coords):
    """
    Reproject a 3D cube with celestial axes from equatorial to Galactic
    (SIN) coordinates, so that a celestial frame conversion is involved.

    Before reprojecting, the input axes (data and WCS alike) are permuted
    according to ``axis_order`` to check that reproject copes with arbitrary
    axis ordering. With ``wcsapi=True`` only the high-level WCS API is used
    and the output shape is passed explicitly.
    """
    filename = get_pkg_data_filename("data/equatorial_3d.fits", package="reproject.tests")
    with fits.open(filename) as hdulist:
        hdu_in = hdulist[0]

        # The output header must not depend on the parametrization, so that
        # every variant can be compared against one reference file.
        header_out = hdu_in.header.copy()
        for keyword, value in (
            ("NAXIS1", 10),
            ("NAXIS2", 9),
            ("CTYPE1", "GLON-SIN"),
            ("CTYPE2", "GLAT-SIN"),
            ("CRVAL1", 163.16724),
            ("CRVAL2", -15.777405),
            ("CRPIX1", 6),
            ("CRPIX2", 5),
        ):
            header_out[keyword] = value

        if axis_order != (0, 1, 2):
            # ``WCS.sub`` takes 1-based FITS axis numbers, which run in the
            # reverse order of the numpy axes, so the numpy permutation is
            # mapped accordingly before transposing the data the same way.
            fits_axes = (3 - np.array(axis_order)[::-1]).tolist()
            hdu_in.header = WCS(hdu_in.header).sub(fits_axes).to_header()
            hdu_in.data = np.transpose(hdu_in.data, axis_order)

        if wcsapi:
            wcs_in = as_high_level_wcs(WCS(hdu_in.header))
            input_data = (hdu_in.data, wcs_in)
            output_projection = as_high_level_wcs(WCS(header_out))
            extra_kwargs = {
                "shape_out": (header_out["NAXIS3"], header_out["NAXIS2"], header_out["NAXIS1"])
            }
        else:
            input_data = hdu_in
            output_projection = header_out
            extra_kwargs = {}

        array_out, footprint_out = reproject_interp(
            input_data, output_projection, roundtrip_coords=roundtrip_coords, **extra_kwargs
        )

    return array_footprint_to_hdulist(array_out, footprint_out, header_out)


@pytest.mark.array_compare(single_reference=True)
@pytest.mark.parametrize("wcsapi", (False, True))
@pytest.mark.parametrize("roundtrip_coords", (False, True))
@pytest.mark.remote_data
def test_small_cutout(wcsapi, roundtrip_coords):
    """
    Reproject a large Galactic image onto a small (10x9) equatorial cutout,
    exercising the cropping of the input before reprojection.

    With ``wcsapi=True`` only the high-level WCS API is used and the output
    shape is passed explicitly.
    """
    filename = get_pkg_data_filename("data/galactic_2d.fits", package="reproject.tests")
    with fits.open(filename) as hdulist:
        hdu_in = hdulist[0]

        header_out = hdu_in.header.copy()
        for keyword, value in (
            ("NAXIS1", 10),
            ("NAXIS2", 9),
            ("CTYPE1", "RA---TAN"),
            ("CTYPE2", "DEC--TAN"),
            ("CRVAL1", 266.39311),
            ("CRVAL2", -28.939779),
            ("CRPIX1", 5.1),
            ("CRPIX2", 4.7),
        ):
            header_out[keyword] = value

        if wcsapi:
            wcs_in = as_high_level_wcs(WCS(hdu_in.header))
            input_data = (hdu_in.data, wcs_in)
            output_projection = as_high_level_wcs(WCS(header_out))
            extra_kwargs = {"shape_out": (header_out["NAXIS2"], header_out["NAXIS1"])}
        else:
            input_data = hdu_in
            output_projection = header_out
            extra_kwargs = {}

        array_out, footprint_out = reproject_interp(
            input_data, output_projection, roundtrip_coords=roundtrip_coords, **extra_kwargs
        )

    return array_footprint_to_hdulist(array_out, footprint_out, header_out)


@pytest.mark.parametrize("roundtrip_coords", (False, True))
@pytest.mark.remote_data
def test_mwpan_car_to_mol(roundtrip_coords):
    """
    Reproject the Mellinger Milky Way Panorama from CAR to MOL and check that
    at least some output pixels are finite.

    Regression test: reproject 0.3 returned all NaNs for this case
    (https://github.com/astrofrog/reproject/pull/124).
    """
    header_in = fits.Header.fromtextfile(
        get_pkg_data_filename("data/mwpan2_RGB_3600.hdr", package="reproject.tests")
    )
    # Parsing this header is expected to emit a FITSFixedWarning; only the
    # first two WCS axes are kept.
    with pytest.warns(FITSFixedWarning):
        wcs_in = WCS(header_in, naxis=2)

    data_in = np.ones((header_in["NAXIS2"], header_in["NAXIS1"]), dtype=float)

    # 360 x 180 pixel Galactic Mollweide output grid
    header_out = fits.Header()
    for keyword, value in (
        ("NAXIS", 2),
        ("NAXIS1", 360),
        ("NAXIS2", 180),
        ("CRPIX1", 180),
        ("CRPIX2", 90),
        ("CRVAL1", 0),
        ("CRVAL2", 0),
        ("CDELT1", -2 * np.sqrt(2) / np.pi),
        ("CDELT2", 2 * np.sqrt(2) / np.pi),
        ("CTYPE1", "GLON-MOL"),
        ("CTYPE2", "GLAT-MOL"),
        ("RADESYS", "ICRS"),
    ):
        header_out[keyword] = value

    array_out, _ = reproject_interp(
        (data_in, wcs_in), header_out, roundtrip_coords=roundtrip_coords
    )
    assert np.isfinite(array_out).any()


@pytest.mark.parametrize("roundtrip_coords", (False, True))
@pytest.mark.remote_data
def test_small_cutout_outside(roundtrip_coords):
    """
    Reproject onto a small output grid lying entirely outside the input
    image. Every output pixel must be NaN and the footprint all zeros; this
    case is meant to go through reproject's shortcut for empty overlaps.
    """
    filename = get_pkg_data_filename("data/galactic_2d.fits", package="reproject.tests")
    with fits.open(filename) as hdulist:
        hdu_in = hdulist[0]

        header_out = hdu_in.header.copy()
        for keyword, value in (
            ("NAXIS1", 10),
            ("NAXIS2", 9),
            ("CTYPE1", "RA---TAN"),
            ("CTYPE2", "DEC--TAN"),
            ("CRVAL1", 216.39311),
            ("CRVAL2", -21.939779),
            ("CRPIX1", 5.1),
            ("CRPIX2", 4.7),
        ):
            header_out[keyword] = value

        array_out, footprint_out = reproject_interp(
            hdu_in, header_out, roundtrip_coords=roundtrip_coords
        )

    assert np.isnan(array_out).all()
    assert (footprint_out == 0).all()


@pytest.mark.parametrize("roundtrip_coords", (False, True))
@pytest.mark.remote_data
def test_celestial_mismatch_2d(roundtrip_coords):
    """
    Make sure an error is raised if the input image has celestial WCS
    information and the output does not, and vice-versa. This example will
    use the _reproject_celestial route.
    """
    with fits.open(get_pkg_data_filename("data/galactic_2d.fits", package="reproject.tests")) as pf:
        hdu_in = pf[0]

        # Same image, but with axes that are not recognised as celestial
        header_out = hdu_in.header.copy()
        header_out["CTYPE1"] = "APPLES"
        header_out["CTYPE2"] = "ORANGES"

        data = hdu_in.data
        wcs1 = WCS(hdu_in.header)
        wcs2 = WCS(header_out)

        # Celestial input, non-celestial output
        with pytest.raises(
            ValueError, match="Input WCS has celestial components but output WCS does not"
        ):
            reproject_interp(
                (data, wcs1), wcs2, shape_out=(2, 2), roundtrip_coords=roundtrip_coords
            )

        # Non-celestial input, celestial output (the "vice-versa" case, which
        # test_celestial_mismatch_3d also covers)
        with pytest.raises(
            ValueError, match="Output WCS has celestial components but input WCS does not"
        ):
            reproject_interp(
                (data, wcs2), wcs1, shape_out=(2, 2), roundtrip_coords=roundtrip_coords
            )


@pytest.mark.parametrize("roundtrip_coords", (False, True))
@pytest.mark.remote_data
def test_celestial_mismatch_3d(roundtrip_coords):
    """
    Check that reproject_interp refuses to reproject a 3D cube between a WCS
    with celestial axes and one without, in either direction.
    """
    filename = get_pkg_data_filename("data/equatorial_3d.fits", package="reproject.tests")
    with fits.open(filename) as hdulist:
        hdu_in = hdulist[0]

        # Same cube, but with axis types that are not recognised as celestial
        header_out = hdu_in.header.copy()
        for index, ctype in enumerate(("APPLES", "ORANGES", "BANANAS"), start=1):
            header_out[f"CTYPE{index}"] = ctype

        data = hdu_in.data
        wcs_celestial = WCS(hdu_in.header)
        wcs_other = WCS(header_out)

        cases = (
            (
                wcs_celestial,
                wcs_other,
                "Input WCS has celestial components but output WCS does not",
            ),
            (
                wcs_other,
                wcs_celestial,
                "Output WCS has celestial components but input WCS does not",
            ),
        )
        for wcs_from, wcs_to, message in cases:
            with pytest.raises(ValueError, match=message):
                reproject_interp(
                    (data, wcs_from), wcs_to, shape_out=(1, 2, 3), roundtrip_coords=roundtrip_coords
                )


@pytest.mark.parametrize("roundtrip_coords", (False, True))
@pytest.mark.remote_data
def test_spectral_mismatch_3d(roundtrip_coords):
    """
    Check that reproject_interp rejects mismatched spectral axes: two
    spectral axes of non-equivalent types, or a spectral axis on only one
    side (input or output).
    """
    filename = get_pkg_data_filename("data/equatorial_3d.fits", package="reproject.tests")
    with fits.open(filename) as hdulist:
        hdu_in = hdulist[0]

        header_out = hdu_in.header.copy()
        header_out["CTYPE3"] = "FREQ"
        header_out["CUNIT3"] = "Hz"

        data = hdu_in.data
        wcs_in = WCS(hdu_in.header)
        wcs_freq = WCS(header_out)

        def reproject_between(wcs_from, wcs_to):
            return reproject_interp(
                (data, wcs_from), wcs_to, shape_out=(1, 2, 3), roundtrip_coords=roundtrip_coords
            )

        # Both WCS have a spectral axis, but of non-equivalent types
        with pytest.raises(
            ValueError,
            match=r"The input \(VOPT\) and output \(FREQ\) spectral "
            r"coordinate types are not equivalent\.",
        ):
            reproject_between(wcs_in, wcs_freq)

        # Replace the output spectral axis with a non-spectral one
        header_out["CTYPE3"] = "BANANAS"
        wcs_other = WCS(header_out)

        with pytest.raises(
            ValueError, match="Input WCS has a spectral component but output WCS does not"
        ):
            reproject_between(wcs_in, wcs_other)

        with pytest.raises(
            ValueError, match="Output WCS has a spectral component but input WCS does not"
        ):
            reproject_between(wcs_other, wcs_in)


@pytest.mark.parametrize("roundtrip_coords", (False, True))
def test_naxis_mismatch(roundtrip_coords):
    """
    Reprojecting between WCS objects with different numbers of axes must
    raise a ValueError.
    """
    wcs_in, wcs_out = WCS(naxis=3), WCS(naxis=2)
    data = np.ones((3, 2, 2))

    expected_message = "Number of dimensions in input and output WCS should match"
    with pytest.raises(ValueError, match=expected_message):
        reproject_interp(
            (data, wcs_in), wcs_out, shape_out=(1, 2), roundtrip_coords=roundtrip_coords
        )


@pytest.mark.parametrize("roundtrip_coords", (False, True))
@pytest.mark.remote_data
def test_slice_reprojection(roundtrip_coords):
    """
    Test the case where only the spectral slices change while the celestial
    projection stays the same.
    """
    # Plane k of the input cube is filled with the value k
    inp_cube = np.arange(3, dtype="float").repeat(4 * 5).reshape(3, 4, 5)

    header_in = fits.Header.fromtextfile(
        get_pkg_data_filename("data/cube.hdr", package="reproject.tests")
    )
    for keyword, value in (("NAXIS1", 5), ("NAXIS2", 4), ("NAXIS3", 3)):
        header_in[keyword] = value

    # Two output planes, shifted by half a pixel along the third axis so that
    # each output plane falls midway between two consecutive input planes.
    header_out = header_in.copy()
    header_out["NAXIS3"] = 2
    header_out["CRPIX3"] -= 0.5

    out_cube, out_cube_valid = reproject_interp(
        (inp_cube, WCS(header_in)),
        WCS(header_out),
        shape_out=(2, 4, 5),
        roundtrip_coords=roundtrip_coords,
    )

    # Order-1 interpolation along the third axis gives the mean of each pair
    # of neighbouring input planes.
    expected = (inp_cube[:-1] + inp_cube[1:]) / 2.0

    assert out_cube.shape == (2, 4, 5)
    assert out_cube_valid.sum() == 40.0

    # The footprint must work as a mask selecting correct values...
    valid = out_cube_valid.astype("bool")
    np.testing.assert_allclose(out_cube[valid], expected[valid])

    # ...and in fact every output pixel must match.
    np.testing.assert_allclose(out_cube, expected)


@pytest.mark.parametrize("roundtrip_coords", (False, True))
@pytest.mark.remote_data
def test_inequal_wcs_dims(roundtrip_coords):
    """
    A spectral (VRAD) output axis paired with a STOKES input axis must be
    rejected because the input has no spectral component.
    """
    inp_cube = np.arange(3, dtype="float").repeat(4 * 5).reshape(3, 4, 5)
    header_in = fits.Header.fromtextfile(
        get_pkg_data_filename("data/cube.hdr", package="reproject.tests")
    )

    # Copy before modifying the input header, so the output keeps the
    # original axis definitions apart from the overrides below.
    header_out = header_in.copy()
    header_out["CTYPE3"], header_out["CUNIT3"] = "VRAD", "m/s"
    header_in["CTYPE3"], header_in["CUNIT3"] = "STOKES", ""

    wcs_out = WCS(header_out)

    expected_message = "Output WCS has a spectral component but input WCS does not"
    with pytest.raises(ValueError, match=expected_message):
        reproject_interp(
            (inp_cube, header_in), wcs_out, shape_out=(2, 4, 5), roundtrip_coords=roundtrip_coords
        )


@pytest.mark.parametrize("roundtrip_coords", (False, True))
@pytest.mark.remote_data
def test_different_wcs_types(roundtrip_coords):
    """
    Spectral axes of different types (VELO in, VRAD out) must be rejected as
    non-equivalent.
    """
    inp_cube = np.arange(3, dtype="float").repeat(4 * 5).reshape(3, 4, 5)
    header_in = fits.Header.fromtextfile(
        get_pkg_data_filename("data/cube.hdr", package="reproject.tests")
    )

    header_out = header_in.copy()
    header_out["CTYPE3"], header_out["CUNIT3"] = "VRAD", "m/s"
    header_in["CTYPE3"], header_in["CUNIT3"] = "VELO", "m/s"

    wcs_out = WCS(header_out)

    expected_message = (
        r"The input \(VELO\) and output \(VRAD\) spectral "
        r"coordinate types are not equivalent\."
    )
    with pytest.raises(ValueError, match=expected_message):
        reproject_interp(
            (inp_cube, header_in), wcs_out, shape_out=(2, 4, 5), roundtrip_coords=roundtrip_coords
        )


# TODO: add a test to check the units are the same.


@pytest.mark.parametrize("roundtrip_coords", (False, True))
@pytest.mark.remote_data
def test_reproject_3d_celestial_correctness_ra2gal(roundtrip_coords):
    """
    Reproject a cube from equatorial to Galactic coordinates onto a smaller
    cutout, with the third axis shifted by half a pixel. Check that the
    whole output is valid and that the values along the third axis at one
    spatial pixel are the mean of neighbouring input planes.
    """
    # Plane k of the input cube is filled with the value k
    inp_cube = np.arange(3, dtype="float").repeat(7 * 8).reshape(3, 7, 8)

    header_in = fits.Header.fromtextfile(
        get_pkg_data_filename("data/cube.hdr", package="reproject.tests")
    )
    for keyword, value in (("NAXIS1", 8), ("NAXIS2", 7), ("NAXIS3", 3)):
        header_in[keyword] = value

    header_out = header_in.copy()
    for keyword, value in (
        ("CTYPE1", "GLON-TAN"),
        ("CTYPE2", "GLAT-TAN"),
        ("CRVAL1", 158.5644791),
        ("CRVAL2", -21.59589875),
        # A smaller cutout, approximately in the centre of the input cube
        ("NAXIS1", 4),
        ("CRPIX1", 2),
        ("NAXIS2", 3),
        ("CRPIX2", 1.5),
        ("NAXIS3", 2),
    ):
        header_out[keyword] = value
    header_out["CRPIX3"] -= 0.5

    out_cube, out_cube_valid = reproject_interp(
        (inp_cube, WCS(header_in)),
        WCS(header_out),
        shape_out=(2, 3, 4),
        roundtrip_coords=roundtrip_coords,
    )

    assert out_cube.shape == (2, 3, 4)
    assert out_cube_valid.sum() == out_cube.size

    # Only the third (spectral) axis is compared, at a single spatial pixel
    expected = (inp_cube[:-1] + inp_cube[1:]) / 2.0
    np.testing.assert_allclose(out_cube[:, 0, 0], expected[:, 0, 0])


@pytest.mark.parametrize("roundtrip_coords", (False, True))
@pytest.mark.remote_data
def test_reproject_with_output_array(roundtrip_coords):
    """
    When a pre-allocated ``output_array`` is passed (and no footprint is
    requested), reproject_interp must return that very array object.
    """
    header_in = fits.Header.fromtextfile(
        get_pkg_data_filename("data/cube.hdr", package="reproject.tests")
    )
    wcs_in = WCS(header_in)

    # Celestial axes switch to a Galactic SIN projection; the third axis
    # keeps its type and value, with a sub-pixel reference shift.
    wcs_out = wcs_in.deepcopy()
    wcs_out.wcs.ctype = ["GLON-SIN", "GLAT-SIN", wcs_in.wcs.ctype[2]]
    wcs_out.wcs.crval = [158.0501, -21.530282, wcs_in.wcs.crval[2]]
    wcs_out.wcs.crpix = [50.0, 50.0, wcs_in.wcs.crpix[2] + 0.4]

    array_in = np.ones((3, 200, 180))
    out_full = np.empty((3, 160, 170))

    # TODO: also check that no intermediate full-size copy is made.
    returned_array = reproject_interp(
        (array_in, wcs_in),
        wcs_out,
        output_array=out_full,
        return_footprint=False,
        roundtrip_coords=roundtrip_coords,
    )

    assert returned_array is out_full


@pytest.mark.array_compare(single_reference=True)
@pytest.mark.remote_data
def test_reproject_roundtrip(aia_test_data):
    """
    Reproject solar (AIA) data and compare against a reference image.

    This checks that pixels are masked correctly based on coordinate
    round-tripping. The asdf input is there to exercise GWCS support, not
    just a different file format.
    """
    pytest.importorskip("sunpy", minversion="6.0.1")

    data, wcs, target_wcs = aia_test_data
    output, footprint = reproject_interp((data, wcs), target_wcs, (128, 128))

    header_out = target_wcs.to_header()
    header_out["DATE-OBS"] = header_out["DATE-OBS"].replace("T", " ")

    # sunpy >= 6.0.0 writes out extra keywords that do not matter for the
    # comparison with the reference files, so drop them if present.
    for keyword in ("DATE-AVG", "MJD-AVG"):
        header_out.pop(keyword, None)

    return array_footprint_to_hdulist(output, footprint, header_out)


def test_reproject_roundtrip_kwarg(aia_test_data):
    """
    Regression test: ``roundtrip_coords`` used to be ignored in
    parallel/blocked mode.

    For each setting, the blocked and unblocked results must agree, and
    enabling round-tripping must mask noticeably more pixels.
    """
    pytest.importorskip("sunpy", minversion="6.0.1")

    data, wcs, target_wcs = aia_test_data

    def reproject_aia(roundtrip, **kwargs):
        return reproject_interp(
            (data, wcs),
            target_wcs,
            shape_out=(128, 128),
            return_footprint=False,
            roundtrip_coords=roundtrip,
            **kwargs,
        )

    output_roundtrip = reproject_aia(True)
    assert_allclose(output_roundtrip, reproject_aia(True, block_size=(32, 32)))

    output_noroundtrip = reproject_aia(False)
    assert_allclose(output_noroundtrip, reproject_aia(False, block_size=(32, 32)))

    # Round-tripping should leave more NaN values in the output
    assert np.isnan(output_roundtrip).sum() > 9500
    assert np.isnan(output_noroundtrip).sum() < 7000


@pytest.mark.parametrize("roundtrip_coords", (False, True))
@pytest.mark.remote_data
def test_identity_with_offset(roundtrip_coords):
    """
    Reproject an image onto its own WCS enlarged by a one-pixel margin on
    every side; the result must be the input surrounded by a NaN border.

    Regression test for a bug where values extended beyond the original
    footprint.
    """
    wcs = WCS(naxis=2)
    wcs.wcs.ctype = "RA---TAN", "DEC--TAN"
    wcs.wcs.crpix = 322, 151
    wcs.wcs.crval = 43, 23
    wcs.wcs.cdelt = -0.1, 0.1
    wcs.wcs.equinox = 2000.0

    array_in = np.random.random((233, 123))

    # Moving the reference pixel by one and growing each dimension by two
    # adds a one-pixel margin around the input image.
    wcs_out = wcs.deepcopy()
    wcs_out.wcs.crpix += 1
    n_rows, n_cols = array_in.shape
    shape_out = (n_rows + 2, n_cols + 2)

    array_out, _ = reproject_interp(
        (array_in, wcs), wcs_out, shape_out=shape_out, roundtrip_coords=roundtrip_coords
    )

    expected = np.pad(array_in, 1, "constant", constant_values=np.nan)
    assert_allclose(expected, array_out, atol=1e-10)


def _setup_for_broadcast_test():
    """
    Build the inputs and expected outputs shared by the broadcasting tests.

    Returns
    -------
    image_stack : ndarray
        Four variants of the galactic_2d.fits image (original, transposed,
        flipped along the first axis, flipped along the second axis) stacked
        along a new leading axis.
    array_ref, footprint_ref : ndarray
        Results of reprojecting each image of the stack on its own (no
        broadcasting); same shape and dtype as ``image_stack``.
    header_in, header_out : Header
        The input header and an equatorial (RA/Dec TAN) output header.
    """
    with fits.open(get_pkg_data_filename("data/galactic_2d.fits", package="reproject.tests")) as pf:
        hdu_in = pf[0]
        header_in = hdu_in.header.copy()
        header_out = hdu_in.header.copy()
        header_out["CTYPE1"] = "RA---TAN"
        header_out["CTYPE2"] = "DEC--TAN"
        header_out["CRVAL1"] = 266.39311
        header_out["CRVAL2"] = -28.939779

        data = hdu_in.data
        # np.stack copies the data, so the stack stays valid once the file is
        # closed. Stacking ``data.T`` with ``data`` assumes the image is square.
        image_stack = np.stack((data, data.T, data[::-1], data[:, ::-1]))

    # Build the reference arrays through un-broadcast reprojections
    array_ref = np.empty_like(image_stack)
    footprint_ref = np.empty_like(image_stack)
    for i, image in enumerate(image_stack):
        array_ref[i], footprint_ref[i] = reproject_interp((image, header_in), header_out)

    return image_stack, array_ref, footprint_ref, header_in, header_out


@pytest.mark.parametrize("input_extra_dims", (1, 2))
@pytest.mark.parametrize("output_shape", (None, "single", "full"))
@pytest.mark.parametrize("input_as_wcs", (True, False))
@pytest.mark.parametrize("output_as_wcs", (True, False))
def test_broadcast_reprojection(input_extra_dims, output_shape, input_as_wcs, output_as_wcs):
    """
    Reprojecting a stack of images in one call (broadcasting over the leading
    dimensions) must match reprojecting each image separately.

    The parameters cover one or two broadcast dimensions, three ways of
    giving the output shape (not at all, only the image shape, or the full
    shape), and Header vs WCS for both the input and output projections.
    """
    if output_as_wcs and output_shape is None:
        # This combination of parameter values is not valid. Report it as
        # skipped instead of silently passing, and before the costly setup.
        pytest.skip("an output shape is required when the output projection is a WCS")

    image_stack, array_ref, footprint_ref, header_in, header_out = _setup_for_broadcast_test()

    # Test both single and multiple dimensions being broadcast
    if input_extra_dims == 2:
        image_stack = image_stack.reshape((2, 2, *image_stack.shape[-2:]))
        array_ref = array_ref.reshape(image_stack.shape)
        footprint_ref = footprint_ref.reshape(image_stack.shape)

    # Test different ways of providing the output shape
    if output_shape == "single":
        # Have the broadcast dimensions be auto-added to the output shape
        output_shape = image_stack.shape[-2:]
    elif output_shape == "full":
        # Provide the broadcast dimensions as part of the output shape
        output_shape = image_stack.shape

    # Ensure logic works with WCS inputs as well as Header inputs
    if input_as_wcs:
        header_in = WCS(header_in)
    if output_as_wcs:
        header_out = WCS(header_out)

    array_broadcast, footprint_broadcast = reproject_interp(
        (image_stack, header_in),
        header_out,
        output_shape,
    )

    np.testing.assert_array_equal(footprint_broadcast, footprint_ref)
    np.testing.assert_allclose(array_broadcast, array_ref)


# In the tests below we ignore FITSFixedWarning due to:
# https://github.com/astropy/astropy/pull/12844


@pytest.mark.parametrize("input_extra_dims", (1, 2))
@pytest.mark.parametrize("output_shape", (None, "single", "full"))
@pytest.mark.parametrize("parallel", [True, False])
@pytest.mark.parametrize("header_or_wcs", (lambda x: x, WCS))
@pytest.mark.filterwarnings("ignore::astropy.wcs.wcs.FITSFixedWarning")
def test_blocked_broadcast_reprojection(input_extra_dims, output_shape, parallel, header_or_wcs):
    """
    Broadcast reprojection split into 5x5 blocks (optionally in parallel)
    must match reprojecting each image of the stack separately.

    The output projection is given either as a Header or as a WCS, depending
    on ``header_or_wcs``.
    """
    image_stack, array_ref, footprint_ref, header_in, header_out = _setup_for_broadcast_test()

    if input_extra_dims == 2:
        # Spread the four stacked images over two broadcast dimensions
        image_stack = image_stack.reshape((2, 2, *image_stack.shape[-2:]))
        array_ref = array_ref.reshape(image_stack.shape)
        footprint_ref = footprint_ref.reshape(image_stack.shape)

    # None: no explicit shape; "single": only the image shape (the broadcast
    # dimensions are added automatically); "full": the complete shape.
    shape_choices = {"single": image_stack.shape[-2:], "full": image_stack.shape}
    output_shape = shape_choices.get(output_shape, output_shape)

    projection_out = header_or_wcs(header_out)

    array_broadcast, footprint_broadcast = reproject_interp(
        (image_stack, header_in),
        projection_out,
        output_shape,
        parallel=parallel,
        block_size=[5, 5],
    )

    np.testing.assert_array_equal(footprint_broadcast, footprint_ref)
    np.testing.assert_allclose(array_broadcast, array_ref)


@pytest.mark.parametrize("parallel", [True, 2, False])
@pytest.mark.parametrize("block_size", [[500, 500], [500, 100], None])
@pytest.mark.parametrize("return_footprint", [False, True])
@pytest.mark.parametrize("existing_outputs", [False, True])
@pytest.mark.parametrize("header_or_wcs", (lambda x: x, WCS))
@pytest.mark.remote_data
@pytest.mark.filterwarnings("ignore::astropy.wcs.wcs.FITSFixedWarning")
def test_blocked_against_single(
    parallel, block_size, return_footprint, existing_outputs, header_or_wcs
):
    # Ensure that when we break a reprojection down into multiple discrete
    # blocks, it gives the same result as reprojecting all pixels at once.

    shape_out = (720, 721)

    if existing_outputs:
        # Pre-allocated outputs; the identity checks below verify that
        # reproject_interp fills and returns these exact objects.
        output_array_test = np.zeros(shape_out)
        output_footprint_test = np.zeros(shape_out)
        output_array_reference = np.zeros(shape_out)
        output_footprint_reference = np.zeros(shape_out)
    else:
        output_array_test = None
        output_footprint_test = None
        output_array_reference = None
        output_footprint_reference = None

    filename_target = get_pkg_data_filename("galactic_center/gc_2mass_k.fits")
    filename_input = get_pkg_data_filename("galactic_center/gc_msx_e.fits")

    # Open both files with context managers so their handles are closed when
    # the test finishes instead of being leaked.
    with fits.open(filename_target) as hdul_target, fits.open(filename_input) as hdul_input:
        hdu1 = hdul_target[0]
        hdu2 = hdul_input[0]

        result_test = reproject_interp(
            hdu2,
            header_or_wcs(hdu1.header),
            parallel=parallel,
            block_size=block_size,
            return_footprint=return_footprint,
            output_array=output_array_test,
            output_footprint=output_footprint_test,
        )

        # Reference: serial, unblocked reprojection of the same data.
        result_reference = reproject_interp(
            hdu2,
            header_or_wcs(hdu1.header),
            parallel=False,
            block_size=None,
            return_footprint=return_footprint,
            output_array=output_array_reference,
            output_footprint=output_footprint_reference,
        )

        if return_footprint:
            array_test, footprint_test = result_test
            array_reference, footprint_reference = result_reference
        else:
            array_test = result_test
            array_reference = result_reference

        if existing_outputs:
            assert array_test is output_array_test
            assert array_reference is output_array_reference
            if return_footprint:
                assert footprint_test is output_footprint_test
                assert footprint_reference is output_footprint_reference

        np.testing.assert_allclose(array_test, array_reference, equal_nan=True)
        if return_footprint:
            np.testing.assert_allclose(footprint_test, footprint_reference, equal_nan=True)


def test_interp_input_output_types(valid_celestial_input_data, valid_celestial_output_projections):
    # Every supported way of passing the input data and the output projection
    # (as provided by the two fixtures) must reproduce the result obtained
    # from an explicit (array, WCS) input plus explicit output WCS and shape.

    base_array, base_wcs_in, input_value, kwargs_in = valid_celestial_input_data
    base_wcs_out, base_shape, output_value, kwargs_out = valid_celestial_output_projections

    expected_pair = reproject_interp((base_array, base_wcs_in), base_wcs_out, shape_out=base_shape)
    expected_array, expected_footprint = expected_pair

    actual_pair = reproject_interp(input_value, output_value, **kwargs_in, **kwargs_out)
    actual_array, actual_footprint = actual_pair

    assert_allclose(expected_array, actual_array)
    assert_allclose(expected_footprint, actual_footprint)


@pytest.mark.parametrize("block_size", [None, (32, 32)])
def test_reproject_order(block_size):
    # Regression test: the order= keyword used to be silently dropped for
    # parallel/blocked reprojection, so two different interpolation orders
    # must produce results that are NOT all-close.

    fits_path = get_pkg_data_filename("data/galactic_2d.fits", package="reproject.tests")
    celestial_overrides = {
        "CTYPE1": "RA---TAN",
        "CTYPE2": "DEC--TAN",
        "CRVAL1": 266.39311,
        "CRVAL2": -28.939779,
    }

    with fits.open(fits_path) as hdulist:
        source_hdu = hdulist[0]

        target_header = source_hdu.header.copy()
        for keyword, value in celestial_overrides.items():
            target_header[keyword] = value

        outputs = {
            interp_order: reproject_interp(
                source_hdu,
                target_header,
                return_footprint=False,
                order=interp_order,
                block_size=block_size,
            )
            for interp_order in ("bilinear", "biquadratic")
        }

        with pytest.raises(AssertionError):
            assert_allclose(outputs["bilinear"], outputs["biquadratic"])


@pytest.mark.skip(reason="needs too much memory on our ARM platforms")
def test_reproject_block_size_broadcasting():
    # Regression test: the default chunk size used to be inadequate when an
    # extra (broadcast) dimension was present in parallel mode.

    cube = np.ones((350, 250, 150))
    source_wcs = WCS(naxis=2)
    target_wcs = WCS(naxis=2)

    def run_reprojection(**extra_kwargs):
        # All calls share the same input, output projection and options;
        # only block_size varies between them.
        return reproject_interp(
            (cube, source_wcs),
            target_wcs,
            shape_out=(300, 300),
            parallel=1,
            return_footprint=False,
            **extra_kwargs,
        )

    # Default block size.
    run_reprojection()

    # A block size that omits the extra dimension is accepted.
    run_reprojection(block_size=(100, 100))

    # A block size that includes the extra dimension is accepted as long as it
    # covers that dimension's full extent (350).
    run_reprojection(block_size=(350, 100, 100))

    # A block size smaller than the full extent of the extra dimension is rejected.
    with pytest.raises(ValueError, match="block shape for extra broadcasted dimensions"):
        run_reprojection(block_size=(100, 100, 100))


def test_reproject_dask_return_type():
    # Regression test: with a dask array as input, return_type='dask' used to
    # yield a dask array that could not be computed.

    lazy_cube = da.ones((350, 250, 150))
    source_wcs = WCS(naxis=2)
    target_wcs = WCS(naxis=2)
    shared_kwargs = {"shape_out": (300, 300), "return_footprint": False}

    eager_result = reproject_interp(
        (lazy_cube, source_wcs), target_wcs, return_type="numpy", **shared_kwargs
    )

    deferred_result = reproject_interp(
        (lazy_cube, source_wcs),
        target_wcs,
        block_size=(100, 100),
        return_type="dask",
        **shared_kwargs,
    )
    materialized = deferred_result.compute(scheduler="synchronous")

    assert_allclose(eager_result, materialized)


def test_auto_block_size():
    # block_size='auto' must allow dask output and keep the broadcast
    # (leading) dimension in a single chunk.

    lazy_cube = da.ones((350, 250, 150))
    source_wcs = WCS(naxis=2)
    target_wcs = WCS(naxis=2)
    dask_kwargs = {"shape_out": (300, 300), "return_type": "dask"}

    # Without block_size or parallel, dask output is refused.
    with pytest.raises(ValueError, match="Output cannot be returned as dask arrays"):
        reproject_interp((lazy_cube, source_wcs), target_wcs, **dask_kwargs)

    reprojected, footprint = reproject_interp(
        (lazy_cube, source_wcs), target_wcs, block_size="auto", **dask_kwargs
    )

    for result in (reprojected, footprint):
        assert result.chunksize[0] == 350