Skip to content

Commit

Permalink
Merge pull request #41 from GauravPandeyLab/minor-fixes
Browse files Browse the repository at this point in the history
Minor fixes
  • Loading branch information
03bennej committed Oct 31, 2023
2 parents 804e5c9 + fb12d6b commit 072b456
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 45 deletions.
2 changes: 2 additions & 0 deletions docs/source/datasets.rst
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
Datasets
--------

If you use any of the datasets below in a scientific study, please cite the relevant publication listed in the docstring of its loader.

.. autofunction:: eipy.datasets.load_diabetes
85 changes: 45 additions & 40 deletions eipy/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,45 +6,6 @@
import zipfile


def _load_csv(file_path, fn, suffix):
    """Read ``<file_path>/<fn>_<suffix>.csv`` into a DataFrame.

    The first column of the CSV file is used as the index.

    Parameters
    ----------
    file_path : str or path-like
        Directory that contains the CSV file.
    fn : str
        File-name stem placed before the underscore.
    suffix : str
        Part of the file name between the underscore and ``.csv``.

    Returns
    -------
    pandas.DataFrame
        Contents of the CSV file, indexed by its first column.
    """
    return pd.read_csv(join(file_path, f"{fn}_{suffix}.csv"), index_col=0)


def get_data_home(data_home=None):
    """Return the path of the eipy data directory, creating it if needed.

    Adapted from the scikit-learn function of the same name.

    This folder is used by some large dataset loaders to avoid downloading
    the data several times.

    By default the data directory is a folder named 'eipy_data' in the user
    home folder. Alternatively, it can be set by the 'EIPY_DATA' environment
    variable or programmatically by giving an explicit folder path. The '~'
    symbol is expanded to the user home folder.

    If the folder does not already exist, it is automatically created.

    Parameters
    ----------
    data_home : str or path-like, default=None
        The path to the eipy data directory. If `None`, the value of the
        'EIPY_DATA' environment variable is used, falling back to
        `~/eipy_data` when that variable is not set.

    Returns
    -------
    data_home : str
        The path to the eipy data directory.
    """
    if data_home is None:
        # An explicit argument wins; otherwise environment, then the default.
        data_home = environ.get("EIPY_DATA", join("~", "eipy_data"))
    data_home = expanduser(data_home)
    # exist_ok=True makes repeated calls harmless.
    makedirs(data_home, exist_ok=True)
    return data_home


def load_diabetes():
"""
Loads a multi-modal youth diabetes dataset.
Expand All @@ -56,6 +17,11 @@ def load_diabetes():
dataset available through a public web portal. medRxiv 2023.08.02.23293517.
https://doi.org/10.1101/2023.08.02.23293517
Returns
-------
data : dict
Dictionary with keys 'X_train', 'y_train', 'X_test', 'y_test', 'data_dict'.
"""
zenodo_link = "https://zenodo.org/records/10035422/files/diabetes.zip?download=1"
# Get data path
Expand Down Expand Up @@ -93,4 +59,43 @@ def load_diabetes():
"X_test": X_test,
"y_test": y_test,
"data_dict": dictionary,
}
}


def _load_csv(file_path, fn, suffix):
    """Read ``<file_path>/<fn>_<suffix>.csv`` into a DataFrame.

    The first column of the CSV file is used as the index.

    Parameters
    ----------
    file_path : str or path-like
        Directory that contains the CSV file.
    fn : str
        File-name stem placed before the underscore.
    suffix : str
        Part of the file name between the underscore and ``.csv``.

    Returns
    -------
    pandas.DataFrame
        Contents of the CSV file, indexed by its first column.
    """
    return pd.read_csv(join(file_path, f"{fn}_{suffix}.csv"), index_col=0)


def get_data_home(data_home=None):
    """Return the path of the eipy data directory, creating it if needed.

    Adapted from the scikit-learn function of the same name.

    This folder is used by some large dataset loaders to avoid downloading
    the data several times.

    By default the data directory is a folder named 'eipy_data' in the user
    home folder. Alternatively, it can be set by the 'EIPY_DATA' environment
    variable or programmatically by giving an explicit folder path. The '~'
    symbol is expanded to the user home folder.

    If the folder does not already exist, it is automatically created.

    Parameters
    ----------
    data_home : str or path-like, default=None
        The path to the eipy data directory. If `None`, the value of the
        'EIPY_DATA' environment variable is used, falling back to
        `~/eipy_data` when that variable is not set.

    Returns
    -------
    data_home : str
        The path to the eipy data directory.
    """
    if data_home is None:
        # An explicit argument wins; otherwise environment, then the default.
        data_home = environ.get("EIPY_DATA", join("~", "eipy_data"))
    data_home = expanduser(data_home)
    # exist_ok=True makes repeated calls harmless.
    makedirs(data_home, exist_ok=True)
    return data_home
7 changes: 2 additions & 5 deletions eipy/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,11 @@
from sklearn.metrics import roc_auc_score, precision_recall_curve


def fmax_score(y_test, y_score, beta=1.0, pos_label=1):
    """Return the first and last results of `fmax_precision_recall_threshold`.

    Thin wrapper: forwards all arguments to
    `fmax_precision_recall_threshold` and keeps only the first and fourth of
    its four return values, discarding the middle two.

    Parameters
    ----------
    y_test : array-like
        Forwarded as the first positional argument (the labels).
    y_score : array-like
        Forwarded as the second positional argument (the scores).
    beta : float, default=1.0
        Forwarded as the `beta` keyword argument.
    pos_label : int, default=1
        Forwarded as the `pos_label` keyword argument.

    Returns
    -------
    fmax : object
        First value returned by `fmax_precision_recall_threshold`
        (presumably the maximal F-beta score -- confirm against that
        function).
    threshold_fmax : object
        Fourth value returned by `fmax_precision_recall_threshold`
        (presumably the threshold at which that score is reached -- confirm
        against that function).
    """
    # The local is deliberately not called `fmax_score`, so it does not
    # shadow this function's own name.
    fmax, _, _, threshold_fmax = fmax_precision_recall_threshold(
        y_test, y_score, beta=beta, pos_label=pos_label
    )
    return fmax, threshold_fmax


def fmax_precision_recall_threshold(labels, y_score, beta=1.0, pos_label=1):
Expand Down

0 comments on commit 072b456

Please sign in to comment.