1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216
|
# -*- coding: utf-8 -*-
"""MultiLabelCM module."""
from __future__ import division
from .errors import pycmVectorError, pycmMultiLabelError
from .params import *
from .cm import ConfusionMatrix
import numpy
class MultiLabelCM():
"""
Multilabel confusion matrix class.
>>> mlcm = MultiLabelCM([{'dog'}, {'cat', 'dog'}], [{'cat'}, {'cat'}])
>>> mlcm.actual_vector_multihot
[[0, 1], [1, 1]]
>>> mlcm.predict_vector_multihot
[[1, 0], [1, 0]]
"""
def __init__(
self,
actual_vector,
predict_vector,
sample_weight=None,
classes=None):
"""
Init method.
:param actual_vector: actual vector
:type actual_vector: python list or numpy array of sets
:param predict_vector: vector of predictions
:type predict_vector: python list or numpy array of sets
:param sample_weight: sample weights list
:type sample_weight: python list or numpy array
:param classes: ordered labels of classes
:type classes: list
"""
self.actual_vector = actual_vector
self.actual_vector_multihot = []
self.predict_vector = predict_vector
self.predict_vector_multihot = []
self.weights = None
self.classes = None
self.classwise_cms = {}
self.samplewise_cms = {}
__mlcm_vector_handler__(
self,
actual_vector,
predict_vector,
sample_weight,
classes)
__mlcm_assign_classes__(self, classes)
__mlcm_vectors_filter__(self)
def get_cm_by_class(self, class_name):
"""
Return confusion matrices based on classes.
:param class_name: target class name for confusion matrix
:type class_name: any valid type
:return: confusion matrix corresponding to the given class name
"""
if class_name not in self.classwise_cms:
try:
class_index = self.classes.index(class_name)
except ValueError:
raise pycmMultiLabelError(INVALID_CLASS_NAME_ERROR)
actual_vector_temp = []
predict_vector_temp = []
for item1, item2 in zip(
self.actual_vector_multihot, self.predict_vector_multihot):
actual_vector_temp.append(item1[class_index])
predict_vector_temp.append(item2[class_index])
cm = ConfusionMatrix(
actual_vector_temp,
predict_vector_temp,
sample_weight=self.weights)
self.classwise_cms[class_name] = cm
return cm
return self.classwise_cms[class_name]
def get_cm_by_sample(self, index):
"""
Return confusion matrices based on samples.
:param index: sample index for confusion matrix
:type index: int
:return: confusion matrix corresponding to the given sample number
"""
if index < 0 or index >= len(self.actual_vector):
raise pycmMultiLabelError(VECTOR_INDEX_ERROR)
if index not in self.samplewise_cms:
cm = ConfusionMatrix(
self.actual_vector_multihot[index],
self.predict_vector_multihot[index],
sample_weight=self.weights)
self.samplewise_cms[index] = cm
return cm
return self.samplewise_cms[index]
def __str__(self):
"""
Multilabel confusion matrix object string representation method.
:return: representation as str
"""
return self.__repr__()
def __repr__(self):
"""
Multilabel confusion matrix object representation method.
:return: representation as str
"""
return "pycm.MultiLabelCM(classes: " + str(self.classes) + ")"
def __len__(self):
"""
Multilabel confusion matrix object length method.
:return: length as int
"""
return len(self.classes)
def __mlcm_vector_handler__(
mlcm,
actual_vector,
predict_vector,
sample_weight,
classes):
"""
Handle multilabel object conditions for vectors.
:param mlcm: multilabel confusion matrix
:type mlcm: pycm.MultiLabelCM object
:param actual_vector: actual vector
:type actual_vector: python list or numpy array of sets
:param predict_vector: vector of predictions
:type predict_vector: python list or numpy array of sets
:param sample_weight: sample weights list
:type sample_weight: python list or numpy array
:param classes: ordered labels of classes
:type classes: list
:return: None
"""
if not isinstance(actual_vector, (list, numpy.ndarray)) or not \
isinstance(predict_vector, (list, numpy.ndarray)):
raise pycmVectorError(VECTOR_TYPE_ERROR)
if len(actual_vector) != len(predict_vector):
raise pycmVectorError(VECTOR_SIZE_ERROR)
if len(actual_vector) == 0 or len(predict_vector) == 0:
raise pycmVectorError(VECTOR_EMPTY_ERROR)
if not all(map(lambda x: isinstance(x, set), actual_vector)):
raise pycmVectorError(NOT_ALL_SET_VECTOR_ERROR)
if not all(map(lambda x: isinstance(x, set), predict_vector)):
raise pycmVectorError(NOT_ALL_SET_VECTOR_ERROR)
if classes is not None and len(set(classes)) != len(classes):
raise pycmVectorError(VECTOR_UNIQUE_CLASS_ERROR)
if isinstance(sample_weight, (list, numpy.ndarray)):
mlcm.weights = sample_weight
def __mlcm_assign_classes__(
mlcm,
classes):
"""
Assign multilabel object class.
:param mlcm: multilabel confusion matrix
:type mlcm: pycm.MultiLabelCM object
:param classes: ordered labels of classes
:type classes: list
:return: None
"""
mlcm.classes = classes
if classes is None:
mlcm.classes = sorted(
list(
set.union(
*mlcm.actual_vector,
*mlcm.predict_vector)))
def __mlcm_vectors_filter__(mlcm):
"""
Normalize multilabel object vectors.
:param mlcm: multilabel confusion matrix
:type mlcm: pycm.MultiLabelCM object
:param classes: ordered labels of classes
:type classes: list
:return: None
"""
mlcm.actual_vector_multihot = [__set_to_multihot__(
x, mlcm.classes) for x in mlcm.actual_vector]
mlcm.predict_vector_multihot = [__set_to_multihot__(
x, mlcm.classes) for x in mlcm.predict_vector]
def __set_to_multihot__(input_set, classes):
"""
Convert a set into a multi-hot vector based in classes.
:param input_set: input set
:type input_set: set
:param classes: ordered labels of classes
:type classes: list
:return: multi-hot vector as a list
"""
result = [0] * len(classes)
for i, x in enumerate(classes):
if x in input_set:
result[i] = 1
return result
|