File: _covtype.py

package info (click to toggle)
scikit-learn 1.4.2+dfsg-8
  • links: PTS, VCS
  • area: main
  • in suites: forky, sid, trixie
  • size: 25,036 kB
  • sloc: python: 201,105; cpp: 5,790; ansic: 854; makefile: 304; sh: 56; javascript: 20
file content (236 lines) | stat: -rw-r--r-- 7,603 bytes parent folder | download
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
"""Forest covertype dataset.

A classic dataset for classification benchmarks, featuring categorical and
real-valued features.

The dataset page is available from the UCI Machine Learning Repository:

    https://archive.ics.uci.edu/ml/datasets/Covertype

Courtesy of Jock A. Blackard and Colorado State University.
"""

# Author: Lars Buitinck
#         Peter Prettenhofer <peter.prettenhofer@gmail.com>
# License: BSD 3 clause

import logging
import os
from gzip import GzipFile
from os.path import exists, join
from tempfile import TemporaryDirectory

import joblib
import numpy as np

from ..utils import Bunch, check_random_state
from ..utils._param_validation import validate_params
from . import get_data_home
from ._base import (
    RemoteFileMetadata,
    _convert_data_dataframe,
    _fetch_remote,
    _pkl_filepath,
    load_descr,
)

# Module-level logger used to report download progress.
logger = logging.getLogger(__name__)

# Remote location of the gzipped covertype CSV. The file is served from a
# figshare mirror; the upstream UCI copy lives at
# https://archive.ics.uci.edu/ml/machine-learning-databases/covtype/covtype.data.gz
# ``checksum`` is the expected digest of the downloaded archive (presumably
# verified by ``_fetch_remote`` after download -- confirm in ``_base``).
ARCHIVE = RemoteFileMetadata(
    url="https://ndownloader.figshare.com/files/5976039",
    filename="covtype.data.gz",
    checksum="614360d0257557dd1792834a85a1cdebfadc3c4f30b011d56afee7ffb5b15771",
)

# Column names, in file order. Reference:
# https://archive.ics.uci.edu/ml/machine-learning-databases/covtype/covtype.info
# Ten quantitative columns come first, followed by 4 binary wilderness-area
# indicator columns and 40 binary soil-type indicator columns (54 in total).
FEATURE_NAMES = [
    "Elevation",
    "Aspect",
    "Slope",
    "Horizontal_Distance_To_Hydrology",
    "Vertical_Distance_To_Hydrology",
    "Horizontal_Distance_To_Roadways",
    "Hillshade_9am",
    "Hillshade_Noon",
    "Hillshade_3pm",
    "Horizontal_Distance_To_Fire_Points",
    *(f"Wilderness_Area_{idx}" for idx in range(4)),
    *(f"Soil_Type_{idx}" for idx in range(40)),
]
# Name of the single target column (the last column of the raw file).
TARGET_NAMES = ["Cover_Type"]


@validate_params(
    {
        "data_home": [str, os.PathLike, None],
        "download_if_missing": ["boolean"],
        "random_state": ["random_state"],
        "shuffle": ["boolean"],
        "return_X_y": ["boolean"],
        "as_frame": ["boolean"],
    },
    prefer_skip_nested_validation=True,
)
def fetch_covtype(
    *,
    data_home=None,
    download_if_missing=True,
    random_state=None,
    shuffle=False,
    return_X_y=False,
    as_frame=False,
):
    """Load the covertype dataset (classification).

    Download it if necessary.

    =================   ============
    Classes                        7
    Samples total             581012
    Dimensionality                54
    Features                     int
    =================   ============

    Read more in the :ref:`User Guide <covtype_dataset>`.

    Parameters
    ----------
    data_home : str or path-like, default=None
        Specify another download and cache folder for the datasets. By default
        all scikit-learn data is stored in '~/scikit_learn_data' subfolders.

    download_if_missing : bool, default=True
        If False, raise an OSError if the data is not locally available
        instead of trying to download the data from the source site.

    random_state : int, RandomState instance or None, default=None
        Determines random number generation for dataset shuffling. Pass an int
        for reproducible output across multiple function calls.
        See :term:`Glossary <random_state>`.

    shuffle : bool, default=False
        Whether to shuffle dataset.

    return_X_y : bool, default=False
        If True, returns ``(data.data, data.target)`` instead of a Bunch
        object.

        .. versionadded:: 0.20

    as_frame : bool, default=False
        If True, the data is a pandas DataFrame including columns with
        appropriate dtypes (numeric). The target is a pandas DataFrame or
        Series depending on the number of target columns. If `return_X_y` is
        True, then (`data`, `target`) will be pandas DataFrames or Series as
        described below.

        .. versionadded:: 0.24

    Returns
    -------
    dataset : :class:`~sklearn.utils.Bunch`
        Dictionary-like object, with the following attributes.

        data : ndarray of shape (581012, 54)
            Each row corresponds to the 54 features in the dataset. The
            values are parsed with ``np.genfromtxt`` and therefore stored as
            floating point numbers.
        target : ndarray of shape (581012,)
            Each value corresponds to one of
            the 7 forest covertypes with values
            ranging from 1 to 7, stored as ``int32``.
        frame : dataframe of shape (581012, 55)
            Only present when `as_frame=True`. Contains `data` and `target`.
        DESCR : str
            Description of the forest covertype dataset.
        feature_names : list
            The names of the dataset columns.
        target_names : list
            The names of the target columns.

    (data, target) : tuple if ``return_X_y`` is True
        A tuple of two ndarray. The first containing a 2D array of
        shape (n_samples, n_features) with each row representing one
        sample and each column representing the features. The second
        ndarray of shape (n_samples,) containing the target samples.

        .. versionadded:: 0.20

    Raises
    ------
    OSError
        If the cached data is not available locally and
        `download_if_missing` is False.

    Examples
    --------
    >>> from sklearn.datasets import fetch_covtype
    >>> cov_type = fetch_covtype()
    >>> cov_type.data.shape
    (581012, 54)
    >>> cov_type.target.shape
    (581012,)
    >>> # Let's check the 4 first feature names
    >>> cov_type.feature_names[:4]
    ['Elevation', 'Aspect', 'Slope', 'Horizontal_Distance_To_Hydrology']
    """
    data_home = get_data_home(data_home=data_home)
    covtype_dir = join(data_home, "covertype")
    samples_path = _pkl_filepath(covtype_dir, "samples")
    targets_path = _pkl_filepath(covtype_dir, "targets")
    # The cache is only usable when *both* pickles are present.
    available = exists(samples_path) and exists(targets_path)

    if available:
        X = joblib.load(samples_path)
        y = joblib.load(targets_path)
    elif not download_if_missing:
        raise OSError("Data not found and `download_if_missing` is False")
    else:
        os.makedirs(covtype_dir, exist_ok=True)

        # Creating temp_dir as a direct subdirectory of the target directory
        # guarantees that both reside on the same filesystem, so that the
        # data files can be moved atomically to their target location.
        with TemporaryDirectory(dir=covtype_dir) as temp_dir:
            logger.info("Downloading %s", ARCHIVE.url)
            archive_path = _fetch_remote(ARCHIVE, dirname=temp_dir)
            # Close the gzip handle deterministically: an open handle leaks a
            # file descriptor and, on Windows, prevents TemporaryDirectory
            # from removing the archive during cleanup.
            with GzipFile(filename=archive_path) as archive:
                Xy = np.genfromtxt(archive, delimiter=",")

            # The last column of the raw file is the class label.
            X = Xy[:, :-1]
            y = Xy[:, -1].astype(np.int32, copy=False)

            # os.replace (unlike os.rename) also overwrites an existing
            # destination on Windows, e.g. a stale "samples" pickle left over
            # from an earlier run that failed before writing "targets".
            samples_tmp_path = _pkl_filepath(temp_dir, "samples")
            joblib.dump(X, samples_tmp_path, compress=9)
            os.replace(samples_tmp_path, samples_path)

            targets_tmp_path = _pkl_filepath(temp_dir, "targets")
            joblib.dump(y, targets_tmp_path, compress=9)
            os.replace(targets_tmp_path, targets_path)

    if shuffle:
        ind = np.arange(X.shape[0])
        rng = check_random_state(random_state)
        rng.shuffle(ind)
        X = X[ind]
        y = y[ind]

    fdescr = load_descr("covtype.rst")

    frame = None
    if as_frame:
        frame, X, y = _convert_data_dataframe(
            caller_name="fetch_covtype",
            data=X,
            target=y,
            feature_names=FEATURE_NAMES,
            target_names=TARGET_NAMES,
        )
    if return_X_y:
        return X, y

    return Bunch(
        data=X,
        target=y,
        frame=frame,
        target_names=TARGET_NAMES,
        feature_names=FEATURE_NAMES,
        DESCR=fdescr,
    )