palma.utils package

Submodules

palma.utils.checker module

class palma.utils.checker.ProjectPlanChecker

Bases: object

ProjectPlanChecker checks the consistency of the project plan.

When build() is called, this object runs several checks to verify that the project plan is well designed.

Here is an overview of the checks performed by the object:
  • _check_arrays(): check whether the X and y attributes comply with scikit-learn standards.

  • _check_project_problem(): check that the user has correctly specified the problem type.

  • _check_problem_metrics(): check that the known metrics are consistent with the project problem.

Methods

run_checks(project)

Run consistency checks on the project plan.

run_checks(project: Project) -> None

Run consistency checks on the project plan.

Several checks are performed to verify that the project plan is consistent:

  • checks the project problem

  • checks the metrics provided by the user

  • checks the data provided by the user (scikit-learn wrapper)

Parameters:
project : Project

A Project instance.
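The checks listed above can be illustrated with a minimal, self-contained sketch. ToyProject, ToyPlanChecker and KNOWN_METRICS are hypothetical: the private method names mirror those documented for ProjectPlanChecker, and the order follows run_checks (problem, metrics, data), but the method bodies, problem types and metric names are assumptions, not palma's implementation.

```python
from dataclasses import dataclass, field
from typing import List, Sequence


# Hypothetical stand-in for palma's Project, holding only what the checks read.
@dataclass
class ToyProject:
    X: Sequence[Sequence[float]]
    y: Sequence[float]
    problem: str
    metrics: List[str] = field(default_factory=list)


# Assumed mapping from problem type to accepted metric names (illustrative only).
KNOWN_METRICS = {
    "classification": {"roc_auc", "accuracy", "f1"},
    "regression": {"r2", "mean_squared_error", "mean_absolute_error"},
}


class ToyPlanChecker:
    """Run each check in turn; the first failing check raises ValueError."""

    def run_checks(self, project: ToyProject) -> None:
        for check in (self._check_project_problem,
                      self._check_problem_metrics,
                      self._check_arrays):
            check(project)

    def _check_project_problem(self, project: ToyProject) -> None:
        if project.problem not in KNOWN_METRICS:
            raise ValueError(f"Unknown problem type: {project.problem!r}")

    def _check_problem_metrics(self, project: ToyProject) -> None:
        unknown = set(project.metrics) - KNOWN_METRICS[project.problem]
        if unknown:
            raise ValueError(
                f"Metrics not valid for {project.problem}: {sorted(unknown)}")

    def _check_arrays(self, project: ToyProject) -> None:
        # scikit-learn convention: exactly one target value per sample.
        if len(project.X) != len(project.y):
            raise ValueError("X and y must have the same number of samples")
```

Running the checks inside build() means an inconsistent plan fails early, before any model is fitted.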

palma.utils.names module

palma.utils.names.get_random_name() -> str

get_random_name generates a random name

Returns:
str

random name
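A random-name generator of this kind can be sketched as picking one word from each of two lists. The word lists, the "adjective_noun" format and the name get_random_name_sketch are hypothetical; palma's actual vocabulary and name format are not documented here.

```python
import random
from typing import Optional

# Hypothetical word lists; the real vocabulary used by palma may differ.
ADJECTIVES = ["brave", "calm", "eager", "gentle", "lucky", "swift"]
NOUNS = ["falcon", "otter", "maple", "comet", "river", "tiger"]


def get_random_name_sketch(rng: Optional[random.Random] = None) -> str:
    """Return an 'adjective_noun' name, e.g. to label a project run."""
    if rng is None:
        rng = random.Random()
    return f"{rng.choice(ADJECTIVES)}_{rng.choice(NOUNS)}"
```

Passing a seeded random.Random makes the generated name reproducible, which can help in tests.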

palma.utils.plotting module

palma.utils.plotting.plot_correlation(df: DataFrame, cmap: str = 'RdBu_r', method: str = 'spearman', linewidths=1, fmt='0.2f', vmin=-1, vmax=1)
palma.utils.plotting.plot_splitting_strategy(X: DataFrame, y: Series, iter_cross_validation: iter, cmap, sort_by=None, modulus=1)
palma.utils.plotting.plot_variable_importance(variable_importance: DataFrame, mode='minmax', color='C0', cmap='flare', alpha=1, **kwargs)
palma.utils.plotting.roc_plot_base()
palma.utils.plotting.roc_plot_bundle(list_fpr, list_tpr, mean_fpr=numpy.linspace(0, 1, 100), plot_all=False, plot_beam=True, cmap='inferno', plot_mean=True, c='C0', label_iter=None, mode='std', label='', **args)

palma.utils.utils module

class palma.utils.utils.AverageEstimator(estimator_list: list)

Bases: object

A simple ensemble estimator that computes the average prediction of a list of estimators.

Parameters:
estimator_list : list

A list of individual estimators to be averaged.

Attributes:
estimator_list : list

The list of individual estimators.

n : int

The number of estimators in the list.

Methods

predict(*args, **kwargs)

Compute the average prediction across all estimators.

predict_proba(*args, **kwargs)

Compute the average class probabilities across all estimators.

Returns:
numpy.ndarray

The averaged prediction or class probabilities.

predict(*args, **kwargs) -> iter
predict_proba(*args, **kwargs) -> iter
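The averaging described above can be sketched with NumPy as an element-wise mean over the estimators' outputs. AverageEstimatorSketch and ConstantModel are hypothetical names; the sketch assumes each estimator is already fitted and exposes scikit-learn style predict and predict_proba methods. Averaging predict outputs is mainly meaningful for regression targets or scores; for classifiers, averaging predict_proba is the usual choice.

```python
import numpy as np


class AverageEstimatorSketch:
    """Average the outputs of already-fitted estimators (sketch)."""

    def __init__(self, estimator_list: list):
        self.estimator_list = estimator_list
        self.n = len(estimator_list)

    def predict(self, *args, **kwargs) -> np.ndarray:
        # Stack each estimator's (n_samples,) predictions and average over estimators.
        return np.mean([e.predict(*args, **kwargs) for e in self.estimator_list], axis=0)

    def predict_proba(self, *args, **kwargs) -> np.ndarray:
        # Average the (n_samples, n_classes) probability arrays over estimators.
        return np.mean([e.predict_proba(*args, **kwargs) for e in self.estimator_list], axis=0)


class ConstantModel:
    """Hypothetical stand-in for a fitted binary classifier."""

    def __init__(self, p: float):
        self.p = p

    def predict(self, X):
        return np.full(len(X), self.p)

    def predict_proba(self, X):
        return np.column_stack([np.full(len(X), 1 - self.p), np.full(len(X), self.p)])
```

For example, AverageEstimatorSketch([ConstantModel(0.2), ConstantModel(0.6)]).predict_proba([[0], [1]]) returns rows close to [0.6, 0.4].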
palma.utils.utils.check_splitting_strategy(X: DataFrame, iter_cross_validation: iter)
palma.utils.utils.check_started(message: str, need_build: bool = False) -> Callable

check_started is a decorator for methods that must be called on a built (or an unbuilt) Project. If the Project's is_built attribute does not have the expected value, an AttributeError is raised with the message passed as argument.

Parameters:
message: str

Error message

need_build: bool

Expected value of the Project's is_built attribute

Returns:
Callable
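The behaviour described above can be sketched as a decorator factory that compares self.is_built with need_build before calling the wrapped method. check_started_sketch and MiniProject are hypothetical names; the only assumption about the real Project is that it exposes an is_built attribute, as stated above.

```python
import functools
from typing import Callable


def check_started_sketch(message: str, need_build: bool = False) -> Callable:
    """Allow the wrapped method only when self.is_built == need_build."""
    def decorator(method: Callable) -> Callable:
        @functools.wraps(method)
        def wrapper(self, *args, **kwargs):
            if self.is_built != need_build:
                raise AttributeError(message)
            return method(self, *args, **kwargs)
        return wrapper
    return decorator


class MiniProject:
    """Hypothetical stand-in exposing only the is_built flag."""

    def __init__(self):
        self.is_built = False

    @check_started_sketch("Project already built", need_build=False)
    def build(self):
        self.is_built = True

    @check_started_sketch("Build the project first", need_build=True)
    def summary(self):
        return "built"
```

Here build() may only run once on an unbuilt project, while summary() is refused until build() has been called.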
palma.utils.utils.get_estimator_name(estimator) -> str
palma.utils.utils.get_hash(**kwargs) -> str

Return a hash of the given keyword arguments.
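A deterministic hash of keyword arguments can be sketched by serialising them in a canonical order and digesting the result. The choice of JSON serialisation and SHA-256, and the name get_hash_sketch, are assumptions; palma's actual hashing scheme is not documented here.

```python
import hashlib
import json


def get_hash_sketch(**kwargs) -> str:
    """Return a hex digest that depends only on the keyword arguments' values."""
    # sort_keys makes the digest independent of argument order;
    # default=str is a fallback for values json cannot serialise natively.
    payload = json.dumps(kwargs, sort_keys=True, default=str)
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()
```

Sorting the keys matters: get_hash_sketch(a=1, b="x") and get_hash_sketch(b="x", a=1) produce the same digest.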

palma.utils.utils.get_splitting_matrix(X: DataFrame, iter_cross_validation: iter, expand=False) -> DataFrame

Generate a splitting matrix based on cross-validation iterations.

Parameters:
X : pd.DataFrame

The input dataframe.

iter_cross_validation : Iterable

An iterable containing cross-validation splits (train, test).

expand : bool, optional

If True, the output matrix has separate train and test columns for each iteration. If False (default), it has one column per iteration, with 1 marking train rows and 2 marking test rows.

Returns:
pd.DataFrame

A matrix indicating the train (1) and test (2) splits for each iteration. Rows represent data points, and columns represent iterations.

Examples

>>> import pandas as pd
>>> X = pd.DataFrame({'feature1': [1, 2, 3, 4, 5],
...                   'feature2': ['A', 'B', 'C', 'D', 'E']})
>>> iter_cv = [(range(3), range(3, 5)), (range(2), range(2, 5))]
>>> get_splitting_matrix(X, iter_cv)
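The default (expand=False) encoding can be sketched with NumPy and pandas as below. splitting_matrix_sketch is a hypothetical name; the sketch assumes the split indices are positional (as with range objects or scikit-learn splitters), marks rows absent from both train and test with 0, and does not cover expand=True.

```python
import numpy as np
import pandas as pd


def splitting_matrix_sketch(X: pd.DataFrame, iter_cross_validation) -> pd.DataFrame:
    """Rows follow X; column i holds 1 (train) or 2 (test) for split i."""
    columns = {}
    for i, (train, test) in enumerate(iter_cross_validation):
        col = np.zeros(len(X), dtype=int)  # 0: row unused in this split (assumption)
        col[list(train)] = 1  # positional train indices
        col[list(test)] = 2  # positional test indices
        columns[i] = col
    return pd.DataFrame(columns, index=X.index)
```

With the X and iter_cv from the example above, column 0 reads [1, 1, 1, 2, 2] and column 1 reads [1, 1, 2, 2, 2].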
palma.utils.utils.hash_dataframe(data: DataFrame, how='whole')
palma.utils.utils.interpolate_roc(roc_curve_metric: dict[dict[tuple[dict[array]]]], mean_fpr=numpy.linspace(0, 1, 100))
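The default mean_fpr in interpolate_roc and roc_plot_bundle is a grid of 100 evenly spaced false-positive rates from 0 to 1. A common way to average ROC curves across folds is to interpolate each fold's TPR onto that grid with numpy.interp and take the mean and standard deviation, the latter plausibly what roc_plot_bundle's mode='std' band shows. mean_roc_sketch is hypothetical and takes plain lists of per-fold fpr/tpr arrays (as list_fpr and list_tpr suggest), because the nested structure of roc_curve_metric is not described here.

```python
import numpy as np


def mean_roc_sketch(list_fpr, list_tpr, mean_fpr=np.linspace(0, 1, 100)):
    """Interpolate per-fold ROC curves onto a shared FPR grid, then average."""
    tprs = []
    for fpr, tpr in zip(list_fpr, list_tpr):
        # np.interp expects fpr sorted in increasing order, as roc_curve returns it.
        interp_tpr = np.interp(mean_fpr, fpr, tpr)
        interp_tpr[0] = 0.0  # anchor every curve at the origin
        tprs.append(interp_tpr)
    tprs = np.asarray(tprs)  # shape: (n_folds, len(mean_fpr))
    return tprs.mean(axis=0), tprs.std(axis=0)
```

Interpolating onto a fixed grid is what makes curves with different numbers of thresholds comparable point by point.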

Module contents