Skip to content

Commit

Permalink
Merge pull request #42 from GauravPandeyLab/minor-fixes
Browse files Browse the repository at this point in the history
Minor fixes
  • Loading branch information
03bennej committed Oct 31, 2023
2 parents 072b456 + 98e0ac9 commit a0980af
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 5 deletions.
5 changes: 3 additions & 2 deletions eipy/ei.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
from eipy.metrics import (
base_summary,
ensemble_summary,
fmax_score,
roc_auc_score
)

warnings.filterwarnings("ignore", category=DeprecationWarning)
Expand Down Expand Up @@ -70,8 +72,7 @@ class EnsembleIntegration:
n_jobs : int, default=1
Number of workers for parallelization in joblib.
metrics : dict, default=None
If None, the maximized F1-score and AUC scores are calculated. If specified, keys ending
in 'max' or 'min' will maximize/minimize the given metric by calculating a threshold.
If None, the maximized F1-score and AUC scores are calculated.
random_state : int, default=None
Random state for cross-validation and use in some models.
parallel_backend : str, default='loky'
Expand Down
12 changes: 9 additions & 3 deletions eipy/interpretation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from operator import itemgetter
from sklearn.ensemble import VotingClassifier
from sklearn.preprocessing import LabelEncoder
from eipy.metrics import fmax_score

import warnings

Expand All @@ -25,7 +26,7 @@ class PermutationInterpreter:
EI : EnsembleIntegration class object
Fitted EnsembleIntegration model, i.e. with model_building=True.
metric : function
sklearn-like metric function.
sklearn-like metric function. If None, the fmax score is used.
n_repeats : int, default=10
Number of repeats in PermutationImportance.
ensemble_predictor_keys: default='all'
Expand Down Expand Up @@ -54,14 +55,19 @@ class PermutationInterpreter:
def __init__(
self,
EI,
metric,
metric=None,
ensemble_predictor_keys="all", # can be "all" or a list of keys for ensemble methods
n_repeats=10,
n_jobs=1,
metric_greater_is_better=True,
):
self.EI = EI
self.metric = metric

if metric is None: # use fmax score if metric not specified
self.metric = lambda y_test, y_pred: fmax_score(y_test, y_pred)[0]
else:
self.metric = metric

self.n_repeats = n_repeats
self.n_jobs = n_jobs
self.ensemble_predictor_keys = ensemble_predictor_keys
Expand Down

0 comments on commit a0980af

Please sign in to comment.