-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmajority_vote_classifier.py
More file actions
50 lines (44 loc) · 2.01 KB
/
majority_vote_classifier.py
File metadata and controls
50 lines (44 loc) · 2.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
from sklearn.base import BaseEstimator
from sklearn.base import ClassifierMixin
from sklearn.base import clone
from sklearn.externals import six
from sklearn.pipeline import _name_estimators
from sklearn.preprocessing import LabelEncoder
import numpy as np
import operator
class MajorityVoteClassifier(BaseEstimator, ClassifierMixin):
    """Ensemble classifier that combines several classifiers by voting.

    Parameters
    ----------
    classifiers : list of estimators
        Unfitted classifiers. Each one is cloned and the clone is fitted
        in ``fit``; the objects passed in are not fitted themselves.
    vote : {'classlabel', 'probability'}, default 'classlabel'
        ``'classlabel'`` predicts the (optionally weighted) most frequent
        class label among the members' predictions. ``'probability'``
        predicts the argmax of the (optionally weighted) average of the
        members' ``predict_proba`` outputs.
    weights : sequence of numbers or None, default None
        One weight per classifier. ``None`` means uniform weights.

    Attributes
    ----------
    named_classifiers : dict
        Maps the names generated by ``_name_estimators`` to the
        classifiers passed to the constructor.
        NOTE(review): built once in ``__init__``; it is not refreshed if
        ``classifiers`` is later replaced through ``set_params``.
    lablenc_ : LabelEncoder
        Encoder fitted on the training labels in ``fit``.
    classes_ : array
        The label encoder's ``classes_`` after fitting.
    classifiers_ : list
        The fitted clones of ``classifiers``.
    """

    def __init__(self, classifiers, vote='classlabel', weights=None):
        self.classifiers = classifiers
        self.named_classifiers = dict(_name_estimators(classifiers))
        self.vote = vote
        self.weights = weights

    def fit(self, X, y):
        """Fit a clone of every classifier on ``X`` and the encoded ``y``.

        Parameters
        ----------
        X : training samples, passed unchanged to each classifier's ``fit``.
        y : training labels; encoded with a ``LabelEncoder`` before fitting.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If ``vote`` is not 'probability' or 'classlabel', or if
            ``weights`` is given and its length differs from the number
            of classifiers.
        """
        # Fail early on misconfiguration instead of silently falling back
        # to class-label voting or crashing later inside NumPy at predict
        # time.
        if self.vote not in ('probability', 'classlabel'):
            raise ValueError(
                "vote must be 'probability' or 'classlabel'; got (vote=%r)"
                % (self.vote,)
            )
        if (self.weights is not None
                and len(self.weights) != len(self.classifiers)):
            raise ValueError(
                'Number of classifiers and weights must be equal; '
                'got %d weights, %d classifiers'
                % (len(self.weights), len(self.classifiers))
            )

        # Members are trained on integer-encoded labels so that
        # np.bincount / np.argmax can be applied to their predictions.
        self.lablenc_ = LabelEncoder()
        self.lablenc_.fit(y)
        self.classes_ = self.lablenc_.classes_
        y_trans = self.lablenc_.transform(y)

        self.classifiers_ = [
            clone(clf).fit(X, y_trans) for clf in self.classifiers
        ]
        return self

    def predict(self, X):
        """Predict class labels for ``X``.

        Returns
        -------
        maj_vote : array
            Predicted labels, mapped back to the original label values
            through the fitted label encoder.
        """
        if self.vote == 'probability':
            maj_vote = np.argmax(self.predict_proba(X), axis=1)
        else:  # 'classlabel' vote
            # Rows are samples, columns are the individual classifiers'
            # (encoded) predictions.
            predictions = np.asarray(
                [clf.predict(X) for clf in self.classifiers_]
            ).T
            # Per sample: weighted count of each encoded label, then pick
            # the label with the largest count.
            maj_vote = np.apply_along_axis(
                lambda x: np.argmax(np.bincount(x, weights=self.weights)),
                axis=1,
                arr=predictions,
            )
        maj_vote = self.lablenc_.inverse_transform(maj_vote)
        return maj_vote

    def predict_proba(self, X):
        """Return the (weighted) average of the members' ``predict_proba``.

        The average is taken over the classifier axis, so the result has
        the same shape as a single member's ``predict_proba`` output.
        """
        probas = np.asarray(
            [clf.predict_proba(X) for clf in self.classifiers_]
        )
        return np.average(probas, axis=0, weights=self.weights)

    def get_params(self, deep=True):
        """Return the estimator's parameters.

        Parameters
        ----------
        deep : bool, default True
            If False, return only the constructor parameters. If True,
            additionally return each named classifier and its nested
            parameters under ``'<name>__<param>'`` keys.

        Returns
        -------
        dict
        """
        # Always start from the ensemble's own constructor parameters so
        # that set_params(vote=...) / set_params(weights=...) is accepted
        # in the deep view as well.
        out = super(MajorityVoteClassifier, self).get_params(deep=False)
        if not deep:
            return out

        out.update(self.named_classifiers)
        # dict.items() works on both Python 2 and 3, so no six helper is
        # needed here.
        for name, step in self.named_classifiers.items():
            for key, value in step.get_params(deep=True).items():
                out['%s__%s' % (name, key)] = value
        return out