palma.utils package
Submodules
palma.utils.checker module
- class palma.utils.checker.ProjectPlanChecker
Bases: object
ProjectPlanChecker is an object that checks the project plan.
At build() time, this object runs several checks to verify that the project plan is well designed. Here is an overview of the checks performed by the object:
- _check_arrays(): checks whether the X and y attributes comply with scikit-learn standards.
- _check_project_problem(): checks whether the user has correctly specified the problem type.
- _check_problem_metrics(): checks whether the known metrics are consistent with the project problem.
Methods
run_checks(project): Perform some tests on the project plan.
- run_checks(project: Project) -> None
Perform some tests on the project plan.
Several checks are performed to verify that the project plan is consistent:
- checks the project problem
- checks the metrics provided by the user
- checks the data provided by the user (scikit-learn wrapper)
- Parameters:
- project : Project
A Project instance.
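The check sequence described for run_checks can be sketched as follows. This is a hypothetical, dependency-free illustration, not palma's implementation: the MiniProject class, the PROBLEM_METRICS table, the metric names and the use of ValueError are all assumptions, and the array check is only a loose analogue of scikit-learn's X/y conventions.

```python
from dataclasses import dataclass, field

# Hypothetical mapping of problem types to the metric names they accept.
PROBLEM_METRICS = {
    "classification": {"accuracy", "f1", "roc_auc"},
    "regression": {"r2", "mean_squared_error", "mean_absolute_error"},
}


@dataclass
class MiniProject:
    """Hypothetical stand-in for palma's Project."""
    problem: str
    X: list  # list of rows (2-D)
    y: list  # one target per row
    metrics: list = field(default_factory=list)


class SketchPlanChecker:
    """Runs the checks in the order listed for run_checks."""

    def run_checks(self, project: MiniProject) -> None:
        self._check_project_problem(project)
        self._check_problem_metrics(project)
        self._check_arrays(project)

    @staticmethod
    def _check_project_problem(project):
        if project.problem not in PROBLEM_METRICS:
            raise ValueError(f"Unknown problem type: {project.problem!r}")

    @staticmethod
    def _check_problem_metrics(project):
        allowed = PROBLEM_METRICS[project.problem]
        bad = [m for m in project.metrics if m not in allowed]
        if bad:
            raise ValueError(f"Metrics {bad} do not fit a {project.problem} problem")

    @staticmethod
    def _check_arrays(project):
        # X must be non-empty and 2-D with equal-length rows,
        # and there must be exactly one target per sample.
        if not project.X or len({len(row) for row in project.X}) != 1:
            raise ValueError("X must be a non-empty 2-D array with equal-length rows")
        if len(project.X) != len(project.y):
            raise ValueError("X and y have inconsistent numbers of samples")
```

Running the cheap, metadata-only checks before the data check means a misconfigured problem type is reported without touching the arrays.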
palma.utils.names module
- palma.utils.names.get_random_name() -> str
get_random_name generates a random name
- Returns:
- str
random name
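The documentation does not say how the name is built. A common pattern is to join a random adjective and a random noun; the sketch below assumes that pattern, and its word lists, the underscore separator and the optional rng argument are hypothetical, not palma's actual behaviour.

```python
import random
from typing import Optional

# Hypothetical vocabularies; palma's real word lists are not documented here.
ADJECTIVES = ["brave", "calm", "eager", "gentle", "lucky"]
NOUNS = ["falcon", "maple", "otter", "river", "comet"]


def get_random_name_sketch(rng: Optional[random.Random] = None) -> str:
    """Return a name such as 'calm_otter' built from two random words."""
    rng = rng or random.Random()
    return f"{rng.choice(ADJECTIVES)}_{rng.choice(NOUNS)}"
```

Passing a seeded random.Random makes the name reproducible, which can help in tests.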
palma.utils.plotting module
- palma.utils.plotting.plot_correlation(df: DataFrame, cmap: str = 'RdBu_r', method: str = 'spearman', linewidths=1, fmt='0.2f', vmin=-1, vmax=1)
- palma.utils.plotting.plot_splitting_strategy(X: DataFrame, y: Series, iter_cross_validation: iter, cmap, sort_by=None, modulus=1)
- palma.utils.plotting.plot_variable_importance(variable_importance: DataFrame, mode='minmax', color='C0', cmap='flare', alpha=1, **kwargs)
- palma.utils.plotting.roc_plot_base()
- palma.utils.plotting.roc_plot_bundle(list_fpr, list_tpr, mean_fpr=np.linspace(0, 1, 100), plot_all=False, plot_beam=True, cmap='inferno', plot_mean=True, c='C0', label_iter=None, mode='std', label='', **args)
palma.utils.utils module
- class palma.utils.utils.AverageEstimator(estimator_list: list)
Bases: object
A simple ensemble estimator that computes the average prediction of a list of estimators.
- Parameters:
- estimator_list : list
A list of individual estimators to be averaged.
- Attributes:
- estimator_list : list
The list of individual estimators.
- n : int
The number of estimators in the list.
Methods
predict(*args, **kwargs)
Compute the average prediction across all estimators.
predict_proba(*args, **kwargs)
Compute the average class probabilities across all estimators.
- Returns:
- numpy.ndarray
The averaged prediction or class probabilities.
- predict(*args, **kwargs) -> iter
- predict_proba(*args, **kwargs) -> iter
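The averaging behaviour can be sketched without NumPy as below. It is an illustration under assumptions, not palma's code: the documented class returns a numpy.ndarray, whereas this sketch returns plain lists, and the class name AverageEstimatorSketch is hypothetical. Each estimator's output is collected, then averaged element-wise (per sample for predict, per sample and class for predict_proba).

```python
from statistics import fmean


class AverageEstimatorSketch:
    """Average the outputs of several fitted estimators (pure-Python sketch)."""

    def __init__(self, estimator_list: list):
        self.estimator_list = estimator_list
        self.n = len(estimator_list)

    @staticmethod
    def _mean(values):
        # predict_proba rows are lists of class probabilities: average per class.
        if isinstance(values[0], (list, tuple)):
            return [fmean(col) for col in zip(*values)]
        # predict values are scalars: plain mean across estimators.
        return fmean(values)

    def _average(self, method: str, *args, **kwargs):
        # One output per estimator; zip(*outputs) groups them per sample.
        outputs = [getattr(est, method)(*args, **kwargs) for est in self.estimator_list]
        return [self._mean(values) for values in zip(*outputs)]

    def predict(self, *args, **kwargs):
        return self._average("predict", *args, **kwargs)

    def predict_proba(self, *args, **kwargs):
        return self._average("predict_proba", *args, **kwargs)
```

Averaging predict outputs only makes sense for regression or scores; for classifiers, averaging predict_proba and taking the arg-max is usually the safer choice.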
- palma.utils.utils.check_splitting_strategy(X: DataFrame, iter_cross_validation: iter)
- palma.utils.utils.check_started(message: str, need_build: bool = False) -> Callable
check_started is a decorator for methods that must be called on a built or an unbuilt Project. If the Project's is_built attribute does not have the expected value, an AttributeError is raised with the message passed as argument.
- Parameters:
- message : str
Error message.
- need_build : bool
Expected value of the Project is_built attribute.
- Returns:
- Callable
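The guard described above can be sketched as a decorator factory. This is a hypothetical illustration of the pattern, not palma's source: check_started_sketch and ToyProject are invented names, and the only behaviour taken from the documentation is comparing is_built with need_build and raising AttributeError(message) on mismatch.

```python
import functools
from typing import Callable


def check_started_sketch(message: str, need_build: bool = False) -> Callable:
    """Decorator factory: guard a method on the owner's ``is_built`` flag."""
    def decorator(method: Callable) -> Callable:
        @functools.wraps(method)  # keep the wrapped method's name and docstring
        def wrapper(self, *args, **kwargs):
            # Refuse to run unless the instance is in the expected state.
            if self.is_built != need_build:
                raise AttributeError(message)
            return method(self, *args, **kwargs)
        return wrapper
    return decorator


class ToyProject:
    """Hypothetical stand-in for palma's Project."""

    def __init__(self):
        self.is_built = False

    @check_started_sketch("Project is already built", need_build=False)
    def add(self, component):
        return f"added {component}"

    def build(self):
        self.is_built = True

    @check_started_sketch("Build the project first", need_build=True)
    def run(self):
        return "running"
```

With this pattern, configuration methods use need_build=False and execution methods use need_build=True, so calling them in the wrong phase fails early with a readable message.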
- palma.utils.utils.get_estimator_name(estimator) -> str
- palma.utils.utils.get_hash(**kwargs) -> str
Return a hash of the parameters.
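One way to hash keyword parameters is to serialise them deterministically and digest the result. The sketch below assumes JSON serialisation with sorted keys and SHA-256; palma's actual serialisation and hash algorithm are not documented here, and get_hash_sketch is a hypothetical name.

```python
import hashlib
import json


def get_hash_sketch(**kwargs) -> str:
    """Return a stable hex digest of keyword arguments (order-independent)."""
    # sort_keys makes the digest independent of argument order;
    # default=str is a fallback for values json cannot serialise natively.
    payload = json.dumps(kwargs, sort_keys=True, default=str)
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()
```

The default=str fallback keeps the function from failing on arbitrary objects, but two distinct objects with the same str() would then share a hash.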
- palma.utils.utils.get_splitting_matrix(X: DataFrame, iter_cross_validation: iter, expand=False) -> DataFrame
Generate a splitting matrix based on cross-validation iterations.
- Parameters:
- X : pd.DataFrame
The input dataframe.
- iter_cross_validation : Iterable
An iterable containing cross-validation splits (train, test).
- expand : bool, optional
If True, the output matrix will have columns for both train and test splits for each iteration. If False (default), the output matrix will have columns for each iteration with 1 for train and 2 for test.
- Returns:
- pd.DataFrame
A matrix indicating the train (1) and test (2) splits for each iteration. Rows represent data points, and columns represent iterations.
Examples
>>> import pandas as pd
>>> from palma.utils.utils import get_splitting_matrix
>>> X = pd.DataFrame({'feature1': [1, 2, 3, 4, 5],
...                   'feature2': ['A', 'B', 'C', 'D', 'E']})
>>> iter_cv = [(range(3), range(3, 5)), (range(2), range(2, 5))]
>>> get_splitting_matrix(X, iter_cv)
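The default (expand=False) encoding can be illustrated with a dependency-free sketch that returns one list of codes per iteration instead of a DataFrame. It assumes rows that appear in neither split are marked 0, which the documentation does not specify; splitting_matrix_sketch is a hypothetical name and the expand=True layout is not reproduced.

```python
def splitting_matrix_sketch(n_rows, iter_cross_validation):
    """Return {iteration: [code per row]} with 1 = train, 2 = test."""
    matrix = {}
    for i, (train, test) in enumerate(iter_cross_validation):
        column = [0] * n_rows  # assumption: rows outside both splits get 0
        for idx in train:
            column[idx] = 1
        for idx in test:
            column[idx] = 2
        matrix[i] = column
    return matrix


iter_cv = [(range(3), range(3, 5)), (range(2), range(2, 5))]
print(splitting_matrix_sketch(5, iter_cv))
# {0: [1, 1, 1, 2, 2], 1: [1, 1, 2, 2, 2]}
```

Each column in the result corresponds to one iteration, and reading across a row shows how a single data point moves between train and test over the folds.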
- palma.utils.utils.hash_dataframe(data: DataFrame, how='whole')
- palma.utils.utils.interpolate_roc(roc_curve_metric: dict[dict[tuple[dict[array]]]], mean_fpr=np.linspace(0, 1, 100))
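The default mean_fpr is a grid of 100 evenly spaced false-positive rates from 0 to 1. A common way to put per-fold ROC curves on such a shared grid, used for instance in scikit-learn's cross-validated ROC example, is linear interpolation followed by averaging. Whether interpolate_roc does exactly this is an assumption. The sketch below reimplements numpy.interp in pure Python, and the names interp, mean_roc and MEAN_FPR are hypothetical.

```python
from bisect import bisect_right

# Same grid as the documented default: 100 evenly spaced points in [0, 1].
MEAN_FPR = [i / 99 for i in range(100)]


def interp(x, xs, ys):
    """Piecewise-linear interpolation like numpy.interp (xs ascending)."""
    if x <= xs[0]:
        return ys[0]
    if x >= xs[-1]:
        return ys[-1]
    j = bisect_right(xs, x)  # guarantees xs[j - 1] <= x < xs[j]
    x0, x1, y0, y1 = xs[j - 1], xs[j], ys[j - 1], ys[j]
    return y0 + (y1 - y0) * (x - x0) / (x1 - x0)


def mean_roc(curves, mean_fpr=MEAN_FPR):
    """Average several (fpr, tpr) curves on a common FPR grid."""
    interpolated = []
    for fpr, tpr in curves:
        row = [interp(x, fpr, tpr) for x in mean_fpr]
        row[0] = 0.0  # force each curve to start at (0, 0)
        interpolated.append(row)
    return [sum(col) / len(col) for col in zip(*interpolated)]
```

Interpolating onto a fixed grid is what makes curves from different folds, which have different numbers of thresholds, comparable point by point.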