File: util.py

Package: python-cooler 0.9.1-1 (Debian bookworm, area: main)
import os
import re
from collections import OrderedDict, defaultdict
from contextlib import contextmanager

import h5py
import numpy as np
import pandas as pd
from pandas.api.types import is_integer, is_scalar


def partition(start, stop, step):
    """Partition an integer interval into equally-sized subintervals.
    Like builtin :py:func:`range`, but yields pairs of end points.

    Examples
    --------
    >>> for lo, hi in partition(0, 9, 2):
    ...     print(lo, hi)
    0 2
    2 4
    4 6
    6 8
    8 9

    """
    return ((i, min(i + step, stop)) for i in range(start, stop, step))


def parse_cooler_uri(s):
    """
    Parse a Cooler URI string

    e.g. /path/to/mycoolers.cool::/path/to/cooler

    """
    parts = s.split("::")
    if len(parts) == 1:
        file_path, group_path = parts[0], "/"
    elif len(parts) == 2:
        file_path, group_path = parts
        if not group_path.startswith("/"):
            group_path = "/" + group_path
    else:
        raise ValueError("Invalid Cooler URI string")
    return file_path, group_path
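
# Illustrative usage (a sketch, not part of the upstream module; file names are
# hypothetical):
#   parse_cooler_uri("test.mcool::/resolutions/1000")  # -> ("test.mcool", "/resolutions/1000")
#   parse_cooler_uri("test.cool")                       # -> ("test.cool", "/")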


def atoi(s):
    return int(s.replace(",", ""))


def parse_humanized(s):
    _NUMERIC_RE = re.compile("([0-9,.]+)")
    _, value, unit = _NUMERIC_RE.split(s.replace(",", ""))
    if not len(unit):
        return int(value)

    value = float(value)
    unit = unit.upper().strip()
    if unit in ("K", "KB"):
        value *= 1000
    elif unit in ("M", "MB"):
        value *= 1000000
    elif unit in ("G", "GB"):
        value *= 1000000000
    else:
        raise ValueError(f"Unknown unit '{unit}'")
    return int(value)
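
# Illustrative usage (a sketch, not part of the upstream module):
#   parse_humanized("1,000")  # -> 1000
#   parse_humanized("25kb")   # -> 25000
#   parse_humanized("10.5M")  # -> 10500000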


def parse_region_string(s):
    """
    Parse a UCSC-style genomic region string into a triple.

    Parameters
    ----------
    s : str
        UCSC-style string, e.g. "chr5:10,100,000-30,000,000". Ensembl and FASTA
        style sequence names are allowed. End coordinate must be greater than
        or equal to start.

    Returns
    -------
    (str, int or None, int or None)

    """

    def _tokenize(s):
        token_spec = [
            ("HYPHEN", r"-"),
            ("COORD", r"[0-9,]+(\.[0-9]*)?(?:[a-z]+)?"),
            ("OTHER", r".+"),
        ]
        tok_regex = r"\s*" + r"|\s*".join(r"(?P<%s>%s)" % pair for pair in token_spec)
        tok_regex = re.compile(tok_regex, re.IGNORECASE)
        for match in tok_regex.finditer(s):
            typ = match.lastgroup
            yield typ, match.group(typ)

    def _check_token(typ, token, expected):
        if typ is None:
            raise ValueError("Expected {} token missing".format(" or ".join(expected)))
        else:
            if typ not in expected:
                raise ValueError(f'Unexpected token "{token}"')

    def _expect(tokens):
        typ, token = next(tokens, (None, None))
        _check_token(typ, token, ["COORD"])
        start = parse_humanized(token)

        typ, token = next(tokens, (None, None))
        _check_token(typ, token, ["HYPHEN"])

        typ, token = next(tokens, (None, None))
        if typ is None:
            return start, None

        _check_token(typ, token, ["COORD"])
        end = parse_humanized(token)
        if end < start:
            raise ValueError("End coordinate less than start")

        return start, end

    parts = s.split(":")
    chrom = parts[0].strip()
    if not len(chrom):
        raise ValueError("Chromosome name cannot be empty")
    if len(parts) < 2:
        return (chrom, None, None)
    start, end = _expect(_tokenize(parts[1]))
    return (chrom, start, end)
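
# Illustrative usage (a sketch, not part of the upstream module):
#   parse_region_string("chr5:10,100,000-30,000,000")  # -> ("chr5", 10100000, 30000000)
#   parse_region_string("chr5:1.5M-2M")                # -> ("chr5", 1500000, 2000000)
#   parse_region_string("chr5")                        # -> ("chr5", None, None)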


def parse_region(reg, chromsizes=None):
    """
    Coerce a genomic region string or tuple into a well-formed region triple.

    Genomic regions are represented as half-open intervals (0-based starts,
    1-based ends) along the length coordinate of a contig/scaffold/chromosome.

    Parameters
    ----------
    reg : str or tuple
        UCSC-style genomic region string, or
        Triple (chrom, start, end), where ``start`` or ``end`` may be ``None``.
    chromsizes : mapping, optional
        Lookup table of scaffold lengths to check against ``chrom`` and the
        ``end`` coordinate. Required if ``end`` is not supplied.

    Returns
    -------
    A well-formed genomic region triple (str, int, int)

    """
    if isinstance(reg, str):
        chrom, start, end = parse_region_string(reg)
    else:
        chrom, start, end = reg
        start = int(start) if start is not None else start
        end = int(end) if end is not None else end

    try:
        clen = chromsizes[chrom] if chromsizes is not None else None
    except KeyError as e:
        raise ValueError(f"Unknown sequence label: {chrom}") from e

    start = 0 if start is None else start
    if end is None:
        if clen is None:  # TODO --- remove?
            raise ValueError("Cannot determine end coordinate.")
        end = clen

    if end < start:
        raise ValueError("End cannot be less than start")

    if start < 0 or (clen is not None and end > clen):
        raise ValueError(f"Genomic region out of bounds: [{start}, {end})")

    return chrom, start, end
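
# Illustrative usage (a sketch, not part of the upstream module; the chromosome
# length is hypothetical):
#   chromsizes = pd.Series({"chr1": 1000})
#   parse_region("chr1", chromsizes)               # -> ("chr1", 0, 1000)
#   parse_region(("chr1", 100, None), chromsizes)  # -> ("chr1", 100, 1000)
#   parse_region(("chr1", None, 500))              # -> ("chr1", 0, 500)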


def natsort_key(s, _NS_REGEX=re.compile(r"(\d+)", re.U)):
    return tuple([int(x) if x.isdigit() else x for x in _NS_REGEX.split(s) if x])


def natsorted(iterable):
    return sorted(iterable, key=natsort_key)


def argnatsort(array):
    array = np.asarray(array)
    if not len(array):
        return np.array([], dtype=int)
    cols = tuple(zip(*(natsort_key(x) for x in array)))
    return np.lexsort(cols[::-1])
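
# Illustrative usage (a sketch, not part of the upstream module):
#   natsorted(["chr10", "chr2", "chr1"])             # -> ["chr1", "chr2", "chr10"]
#   argnatsort(np.array(["chr10", "chr2", "chr1"]))  # -> array([2, 1, 0])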


def read_chromsizes(
    filepath_or,
    name_patterns=(r"^chr[0-9]+$", r"^chr[XY]$", r"^chrM$"),
    all_names=False,
    **kwargs
):
    """
    Parse a ``<db>.chrom.sizes`` or ``<db>.chromInfo.txt`` file from the UCSC
    database, where ``db`` is a genome assembly name.

    Parameters
    ----------
    filepath_or : str or file-like
        Path or url to text file, or buffer.
    name_patterns : sequence, optional
        Sequence of regular expressions to capture desired sequence names.
        Each corresponding set of records will be sorted in natural order.
    all_names : bool, optional
        Whether to return all contigs listed in the file. Default is
        ``False``.

    Returns
    -------
    :py:class:`pandas.Series`
        Series of integer bp lengths indexed by sequence name.

    References
    ----------
    * `UCSC assembly terminology <http://genome.ucsc.edu/FAQ/FAQdownloads.html#download9>`_
    * `GRC assembly terminology <https://www.ncbi.nlm.nih.gov/grc/help/definitions>`_

    """
    if isinstance(filepath_or, str) and filepath_or.endswith(".gz"):
        kwargs.setdefault("compression", "gzip")
    chromtable = pd.read_csv(
        filepath_or,
        sep="\t",
        usecols=[0, 1],
        names=["name", "length"],
        dtype={"name": str},
        **kwargs
    )
    if not all_names:
        parts = []
        for pattern in name_patterns:
            part = chromtable[chromtable["name"].str.contains(pattern)]
            part = part.iloc[argnatsort(part["name"])]
            parts.append(part)
        chromtable = pd.concat(parts, axis=0)
    chromtable.index = chromtable["name"].values
    return chromtable["length"]


def fetch_chromsizes(db, **kwargs):
    """
    Download chromosome sizes from UCSC as a :py:class:`pandas.Series`, indexed
    by chromosome label.

    """
    return read_chromsizes(
        "http://hgdownload.cse.ucsc.edu/goldenPath/{}/database/chromInfo.txt.gz".format(
            db
        ),
        **kwargs
    )
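
# Illustrative usage (a sketch, not part of the upstream module; the local file
# path is hypothetical and fetch_chromsizes requires network access):
#   chromsizes = read_chromsizes("hg38.chrom.sizes")  # local chrom.sizes file
#   chromsizes = fetch_chromsizes("hg38")             # downloaded from UCSC
#   chromsizes["chr1"]                                # -> length of chr1 in bp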


def load_fasta(names, *filepaths):
    """
    Load lazy FASTA records from one or multiple files without reading them
    into memory.

    Parameters
    ----------
    names : sequence of str
        Names of sequence records in FASTA file or files.
    filepaths : str
        Paths to one or more FASTA files to gather records from.

    Returns
    -------
    OrderedDict of sequence name -> sequence record

    """
    import pyfaidx

    if len(filepaths) == 0:
        raise ValueError("Need at least one file")

    if len(filepaths) == 1:
        fa = pyfaidx.Fasta(filepaths[0], as_raw=True)

    else:
        fa = {}
        for filepath in filepaths:
            fa.update(pyfaidx.Fasta(filepath, as_raw=True).records)

    records = OrderedDict((chrom, fa[chrom]) for chrom in names)
    return records
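
# Illustrative usage (a sketch, not part of the upstream module; requires
# pyfaidx and the FASTA file name is hypothetical):
#   fasta_records = load_fasta(["chr1", "chr2"], "genome.fa")
#   fasta_records["chr1"][:10]  # slicing a lazy record reads sequence from disk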


def binnify(chromsizes, binsize):
    """
    Divide a genome into evenly sized bins.

    Parameters
    ----------
    chromsizes : Series
        pandas Series indexed by chromosome name with chromosome lengths in bp.
    binsize : int
        size of bins in bp

    Returns
    -------
    bins : :py:class:`pandas.DataFrame`
        Dataframe with columns: ``chrom``, ``start``, ``end``.

    """

    def _each(chrom):
        clen = chromsizes[chrom]
        n_bins = int(np.ceil(clen / binsize))
        binedges = np.arange(0, (n_bins + 1)) * binsize
        binedges[-1] = clen
        return pd.DataFrame(
            {"chrom": [chrom] * n_bins, "start": binedges[:-1], "end": binedges[1:]},
            columns=["chrom", "start", "end"],
        )

    bintable = pd.concat(map(_each, chromsizes.keys()), axis=0, ignore_index=True)

    bintable["chrom"] = pd.Categorical(
        bintable["chrom"], categories=list(chromsizes.index), ordered=True
    )

    return bintable
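
# Illustrative usage (a sketch, not part of the upstream module; the chromosome
# lengths are hypothetical):
#   chromsizes = pd.Series({"chr1": 2500, "chr2": 1000})
#   binnify(chromsizes, 1000)
#   # -> 4 rows: chr1 [0, 1000), [1000, 2000), [2000, 2500); chr2 [0, 1000)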


make_bintable = binnify


def digest(fasta_records, enzyme):
    """
    Divide a genome into restriction fragments.

    Parameters
    ----------
    fasta_records : OrderedDict
        Dictionary of chromosome names to sequence records.
    enzyme : str
        Name of restriction enzyme (e.g., 'DpnII').

    Returns
    -------
    frags : :py:class:`pandas.DataFrame`
        Dataframe with columns: ``chrom``, ``start``, ``end``.

    """
    try:
        import Bio.Restriction as biorst
        import Bio.Seq as bioseq
    except ImportError:
        raise ImportError(
            "Biopython is required to find restriction fragments."
        ) from None

    # http://biopython.org/DIST/docs/cookbook/Restriction.html#mozTocId447698
    chroms = fasta_records.keys()
    try:
        cut_finder = getattr(biorst, enzyme).search
    except AttributeError as e:
        raise ValueError(f"Unknown enzyme name: {enzyme}") from e

    def _each(chrom):
        seq = bioseq.Seq(str(fasta_records[chrom]))
        cuts = np.r_[0, np.array(cut_finder(seq)) + 1, len(seq)].astype(int)
        n_frags = len(cuts) - 1

        frags = pd.DataFrame(
            {"chrom": [chrom] * n_frags, "start": cuts[:-1], "end": cuts[1:]},
            columns=["chrom", "start", "end"],
        )
        return frags

    return pd.concat(map(_each, chroms), axis=0, ignore_index=True)
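
# Illustrative usage (a sketch, not part of the upstream module; requires
# Biopython, and the FASTA file and record names are hypothetical):
#   fasta_records = load_fasta(["chr1"], "genome.fa")
#   frags = digest(fasta_records, "DpnII")  # one row per restriction fragment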


def get_binsize(bins):
    """
    Infer bin size from a bin DataFrame. Assumes that the last bin of each
    contig is allowed to differ in size from the rest.

    Returns
    -------
    int or None if bins are non-uniform

    """
    sizes = set()
    for _chrom, group in bins.groupby("chrom"):
        sizes.update((group["end"] - group["start"]).iloc[:-1].unique())
        if len(sizes) > 1:
            return None
    if len(sizes) == 1:
        return next(iter(sizes))
    else:
        return None
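
# Illustrative usage (a sketch, not part of the upstream module):
#   bins = binnify(pd.Series({"chr1": 2500}), 1000)
#   get_binsize(bins)  # -> 1000 (the short last bin of each contig is ignored)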


def get_chromsizes(bins):
    """
    Infer chromsizes Series from a bin DataFrame. Assumes that the last bin of
    each contig is allowed to differ in size from the rest.

    Returns
    -------
    :py:class:`pandas.Series`
        Series of contig lengths in bp, indexed by contig name.

    """
    chromtable = (
        bins.drop_duplicates(["chrom"], keep="last")[["chrom", "end"]]
        .reset_index(drop=True)
        .rename(columns={"chrom": "name", "end": "length"})
    )
    chroms, lengths = list(chromtable["name"]), list(chromtable["length"])
    return pd.Series(index=chroms, data=lengths)
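
# Illustrative usage (a sketch, not part of the upstream module):
#   bins = binnify(pd.Series({"chr1": 2500, "chr2": 1000}), 1000)
#   get_chromsizes(bins)  # -> Series with chr1 -> 2500, chr2 -> 1000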


def bedslice(grouped, chromsizes, region):
    """
    Range query on a BED-like dataframe with non-overlapping intervals.

    """
    chrom, start, end = parse_region(region, chromsizes)
    result = grouped.get_group(chrom)
    if start > 0 or end < chromsizes[chrom]:
        lo = result["end"].values.searchsorted(start, side="right")
        hi = lo + result["start"].values[lo:].searchsorted(end, side="left")
        result = result.iloc[lo:hi]
    return result
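
# Illustrative usage (a sketch, not part of the upstream module):
#   chromsizes = pd.Series({"chr1": 2500})
#   bins = binnify(chromsizes, 1000)
#   bedslice(bins.groupby("chrom"), chromsizes, "chr1:500-1500")
#   # -> the bins [0, 1000) and [1000, 2000), which overlap [500, 1500)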


def asarray_or_dataset(x):
    return x if isinstance(x, h5py.Dataset) else np.asarray(x)


def rlencode(array, chunksize=None):
    """
    Run length encoding.
    Based on http://stackoverflow.com/a/32681075, which is based on the rle
    function from R.

    Parameters
    ----------
    array : 1D array_like
        Input array to encode.
    chunksize : int, optional
        Number of elements to process per chunk. Defaults to the full length
        of the input.

    Returns
    -------
    start positions, run lengths, run values

    """
    where = np.flatnonzero
    array = asarray_or_dataset(array)
    n = len(array)
    if n == 0:
        return (
            np.array([], dtype=int),
            np.array([], dtype=int),
            np.array([], dtype=array.dtype),
        )

    if chunksize is None:
        chunksize = n

    starts, values = [], []
    last_val = np.nan
    for i in range(0, n, chunksize):
        x = array[i : i + chunksize]
        locs = where(x[1:] != x[:-1]) + 1
        if x[0] != last_val:
            locs = np.r_[0, locs]
        starts.append(i + locs)
        values.append(x[locs])
        last_val = x[-1]
    starts = np.concatenate(starts)
    lengths = np.diff(np.r_[starts, n])
    values = np.concatenate(values)

    return starts, lengths, values
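
# Illustrative usage (a sketch, not part of the upstream module):
#   rlencode(np.array([1, 1, 2, 2, 2, 3]))
#   # -> (array([0, 2, 5]), array([2, 3, 1]), array([1, 2, 3]))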


def cmd_exists(cmd):
    return any(
        os.access(os.path.join(path, cmd), os.X_OK)
        for path in os.environ["PATH"].split(os.pathsep)
    )


def mad(data, axis=None):
    return np.median(np.abs(data - np.median(data, axis)), axis)
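
# mad() above is the (unscaled) median absolute deviation, median(|x - median(x)|).
# Illustrative usage (a sketch, not part of the upstream module):
#   mad(np.array([1, 2, 3, 4, 100]))  # -> 1.0 (median 3; deviations [2, 1, 0, 1, 97])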


@contextmanager
def open_hdf5(fp, mode="r", *args, **kwargs):
    """
    Context manager like ``h5py.File`` but accepts already open HDF5 file
    handles which do not get closed on teardown.

    Parameters
    ----------
    fp : str or ``h5py.File`` object
        If an open file object is provided, it passes through unchanged,
        provided that the requested mode is compatible.
        If a filepath is passed, the context manager will close the file on
        tear down.

    mode : str
        * r        Readonly, file must exist
        * r+       Read/write, file must exist
        * a        Read/write if exists, create otherwise
        * w        Truncate if exists, create otherwise
        * w- or x  Fail if exists, create otherwise

    """
    if isinstance(fp, str):
        own_fh = True
        fh = h5py.File(fp, mode, *args, **kwargs)
    else:
        own_fh = False
        if mode == "r" and fp.file.mode == "r+":
            # warnings.warn("File object provided is writeable but intent is read-only")
            pass
        elif mode in ("r+", "a") and fp.file.mode == "r":
            raise ValueError("File object provided is not writeable")
        elif mode == "w":
            raise ValueError("Cannot truncate open file")
        elif mode in ("w-", "x"):
            raise ValueError("File exists")
        fh = fp
    try:
        yield fh
    finally:
        if own_fh:
            fh.close()
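
# Illustrative usage (a sketch, not part of the upstream module; the file path
# is hypothetical):
#   with open_hdf5("data.h5", "r") as f:  # opened from a path: closed on exit
#       ...
#   fh = h5py.File("data.h5", "r")
#   with open_hdf5(fh) as f:              # already-open handle: left open
#       ...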


class closing_hdf5(h5py.Group):
    def __init__(self, grp):
        super().__init__(grp.id)

    def __enter__(self):
        return self

    def __exit__(self, *exc_info):
        return self.file.close()

    def close(self):
        self.file.close()


def attrs_to_jsonable(attrs):
    out = dict(attrs)
    for k, v in attrs.items():
        try:
            out[k] = v.item()
        except ValueError:
            out[k] = v.tolist()
        except AttributeError:
            out[k] = v
    return out


def infer_meta(x, index=None):  # pragma: no cover
    """
    Extracted and modified from dask/dataframe/utils.py :
        make_meta (BSD licensed)

    Create an empty pandas object containing the desired metadata.

    Parameters
    ----------
    x : dict, tuple, list, pd.Series, pd.DataFrame, pd.Index, dtype, scalar
        To create a DataFrame, provide a `dict` mapping of `{name: dtype}`, or
        an iterable of `(name, dtype)` tuples. To create a `Series`, provide a
        tuple of `(name, dtype)`. If a pandas object, names, dtypes, and index
        should match the desired output. If a dtype or scalar, a scalar of the
        same dtype is returned.
    index :  pd.Index, optional
        Any pandas index to use in the metadata. If none provided, a
        `RangeIndex` will be used.

    Examples
    --------
    >>> infer_meta([('a', 'i8'), ('b', 'O')])
    Empty DataFrame
    Columns: [a, b]
    Index: []
    >>> infer_meta(('a', 'f8'))
    Series([], Name: a, dtype: float64)
    >>> infer_meta('i8')
    1

    """

    _simple_fake_mapping = {
        "b": np.bool_(True),
        "V": np.void(b" "),
        "M": np.datetime64("1970-01-01"),
        "m": np.timedelta64(1),
        "S": np.str_("foo"),
        "a": np.str_("foo"),
        "U": np.unicode_("foo"),
        "O": "foo",
    }

    UNKNOWN_CATEGORIES = "__UNKNOWN_CATEGORIES__"

    def _scalar_from_dtype(dtype):
        if dtype.kind in ("i", "f", "u"):
            return dtype.type(1)
        elif dtype.kind == "c":
            return dtype.type(complex(1, 0))
        elif dtype.kind in _simple_fake_mapping:
            o = _simple_fake_mapping[dtype.kind]
            return o.astype(dtype) if dtype.kind in ("m", "M") else o
        else:
            raise TypeError(f"Can't handle dtype: {dtype}")

    def _nonempty_scalar(x):
        if isinstance(x, (pd.Timestamp, pd.Timedelta, pd.Period)):
            return x
        elif np.isscalar(x):
            dtype = x.dtype if hasattr(x, "dtype") else np.dtype(type(x))
            return _scalar_from_dtype(dtype)
        else:
            raise TypeError(f"Can't handle meta of type '{type(x).__name__}'")

    def _empty_series(name, dtype, index=None):
        if isinstance(dtype, str) and dtype == "category":
            return pd.Series(
                pd.Categorical([UNKNOWN_CATEGORIES]), name=name, index=index
            ).iloc[:0]
        return pd.Series([], dtype=dtype, name=name, index=index)

    if hasattr(x, "_meta"):
        return x._meta
    if isinstance(x, (pd.Series, pd.DataFrame)):
        return x.iloc[0:0]
    elif isinstance(x, pd.Index):
        return x[0:0]
    index = index if index is None else index[0:0]

    if isinstance(x, dict):
        return pd.DataFrame(
            {c: _empty_series(c, d, index=index) for (c, d) in x.items()}, index=index
        )
    if isinstance(x, tuple) and len(x) == 2:
        return _empty_series(x[0], x[1], index=index)
    elif isinstance(x, (list, tuple)):
        if not all(isinstance(i, tuple) and len(i) == 2 for i in x):
            raise ValueError(
                f"Expected iterable of tuples of (name, dtype), got {x}"
            )
        return pd.DataFrame(
            {c: _empty_series(c, d, index=index) for (c, d) in x},
            columns=[c for c, d in x],
            index=index,
        )
    elif not hasattr(x, "dtype") and x is not None:
        # could be a string, a dtype object, or a python type. Skip `None`,
        # because it is implicitly converted to `dtype('f8')`, which we don't
        # want here.
        try:
            dtype = np.dtype(x)
            return _scalar_from_dtype(dtype)
        except:  # noqa
            # Continue on to next check
            pass

    if is_scalar(x):
        return _nonempty_scalar(x)

    raise TypeError(f"Don't know how to create metadata from {x}")


def get_meta(
    columns, dtype=None, index_columns=None, index_names=None, default_dtype=np.object_
):  # pragma: no cover
    """
    Extracted and modified from pandas/io/parsers.py :
        _get_empty_meta (BSD licensed).

    """
    columns = list(columns)

    # Convert `dtype` to a defaultdict of some kind.
    # This will enable us to write `dtype[col_name]`
    # without worrying about KeyError issues later on.
    if not isinstance(dtype, dict):
        # if dtype == None, default will be default_dtype.
        dtype = defaultdict(lambda: dtype or default_dtype)
    else:
        # Save a copy of the dictionary.
        _dtype = dtype.copy()
        dtype = defaultdict(lambda: default_dtype)

        # Convert column indexes to column names.
        for k, v in _dtype.items():
            col = columns[k] if is_integer(k) else k
            dtype[col] = v

    if index_columns is None or index_columns is False:
        index = pd.Index([])
    else:
        data = [pd.Series([], dtype=dtype[name]) for name in index_names]
        if len(data) == 1:
            index = pd.Index(data[0], name=index_names[0])
        else:
            index = pd.MultiIndex.from_arrays(data, names=index_names)
        index_columns.sort()
        for i, n in enumerate(index_columns):
            columns.pop(n - i)

    col_dict = {col_name: pd.Series([], dtype=dtype[col_name]) for col_name in columns}

    return pd.DataFrame(col_dict, columns=columns, index=index)
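
# Illustrative usage (a sketch, not part of the upstream module):
#   get_meta(["chrom", "start", "end"], {"start": np.int64, "end": np.int64})
#   # -> empty DataFrame with columns chrom (object), start (int64), end (int64)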


def check_bins(bins, chromsizes):
    is_cat = pd.api.types.is_categorical_dtype(bins["chrom"])
    bins = bins.copy()
    if not is_cat:
        bins["chrom"] = pd.Categorical(
            bins.chrom, categories=list(chromsizes.index), ordered=True
        )
    else:
        assert (bins["chrom"].cat.categories == chromsizes.index).all()

    return bins


def balanced_partition(gs, n_chunk_max, file_contigs, loadings=None):
    # n_bins = len(gs.bins)
    grouped = gs._bins_grouped

    chrom_nbins = grouped.size()
    if loadings is None:
        loadings = chrom_nbins
    chrmax = loadings.idxmax()
    loadings = loadings / loadings.loc[chrmax]
    const = chrom_nbins.loc[chrmax] / n_chunk_max

    granges = []
    for chrom, group in grouped:
        if chrom not in file_contigs:
            continue
        clen = gs.chromsizes[chrom]
        step = int(np.ceil(const / loadings.loc[chrom]))
        anchors = group.start.values[::step]
        if anchors[-1] != clen:
            anchors = np.r_[anchors, clen]
        granges.extend(
            (chrom, start, end) for start, end in zip(anchors[:-1], anchors[1:])
        )
    return granges


class GenomeSegmentation:
    def __init__(self, chromsizes, bins):
        bins = check_bins(bins, chromsizes)
        self._bins_grouped = bins.groupby("chrom", sort=False)
        nbins_per_chrom = self._bins_grouped.size().values

        self.chromsizes = chromsizes
        self.binsize = get_binsize(bins)
        self.contigs = list(chromsizes.keys())
        self.bins = bins
        self.idmap = pd.Series(index=chromsizes.keys(), data=range(len(chromsizes)))
        self.chrom_binoffset = np.r_[0, np.cumsum(nbins_per_chrom)]
        self.chrom_abspos = np.r_[0, np.cumsum(chromsizes.values)]
        self.start_abspos = (
            self.chrom_abspos[bins["chrom"].cat.codes] + bins["start"].values
        )

    def fetch(self, region):
        chrom, start, end = parse_region(region, self.chromsizes)
        result = self._bins_grouped.get_group(chrom)
        if start > 0 or end < self.chromsizes[chrom]:
            lo = result["end"].values.searchsorted(start, side="right")
            hi = lo + result["start"].values[lo:].searchsorted(end, side="left")
            result = result.iloc[lo:hi]
        return result
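
# Illustrative usage (a sketch, not part of the upstream module; the chromosome
# lengths are hypothetical):
#   chromsizes = pd.Series({"chr1": 2500, "chr2": 1000})
#   gs = GenomeSegmentation(chromsizes, binnify(chromsizes, 1000))
#   gs.binsize               # -> 1000
#   gs.fetch("chr1:0-1500")  # -> the chr1 bins overlapping [0, 1500)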


def buffered(chunks, size=10000000):
    """
    Take an incoming iterator of small data frame chunks and buffer them into
    an outgoing iterator of larger chunks.

    Parameters
    ----------
    chunks : iterator of :py:class:`pandas.DataFrame`
        Each chunk should have the same column names.
    size : int
        Accumulate incoming chunks until their combined length exceeds this
        many rows, then yield; the final chunk may be smaller.

    Yields
    ------
    Larger outgoing :py:class:`pandas.DataFrame` chunks made from concatenating
    the incoming ones.

    """
    buf = []
    n = 0
    for chunk in chunks:
        n += len(chunk)
        buf.append(chunk)
        if n > size:
            yield pd.concat(buf, axis=0)
            buf = []
            n = 0
    if len(buf):
        yield pd.concat(buf, axis=0)
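
# Illustrative usage (a sketch, not part of the upstream module):
#   small = (pd.DataFrame({"x": range(i, i + 5)}) for i in range(0, 20, 5))
#   [len(df) for df in buffered(small, size=8)]  # -> [10, 10]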