Python中更有效的加权基尼系数
发布时间:2020-05-25 06:25:33 所属栏目:Python 来源:互联网
导读:根据 https://stackoverflow.com/a/48981834/1840471,这是Python中加权基尼系数的实现: import numpy as npdef gini(x, weights=None): if weights is None: weights = np.ones_like(x) # Calculate mean absolut
|
根据 https://stackoverflow.com/a/48981834/1840471,这是Python中加权基尼系数的实现: import numpy as np
def gini(x,weights=None):
if weights is None:
weights = np.ones_like(x)
# Calculate mean absolute deviation in two steps,for weights.
count = np.multiply.outer(weights,weights)
mad = np.abs(np.subtract.outer(x,x) * count).sum() / count.sum()
rmad = mad / np.average(x,weights=weights)
# Gini equals half the relative mean absolute deviation.
return 0.5 * rmad
这很干净,适用于中型阵列,但正如其最初的建议(https://stackoverflow.com/a/39513799/1840471)所警告的那样是O(n2).在我的计算机上,这意味着它在大约20k行之后中断: n = 20000 # Works,30000 fails. gini(np.random.rand(n),np.random.rand(n)) 可以调整它以适用于更大的数据集吗?我的行是~150k行. 解决方法这是一个比上面提供的版本快得多的版本,并且在没有重量的情况下使用简化的公式来获得更快的结果.def gini(x,w=None):
# The rest of the code requires numpy arrays.
x = np.asarray(x)
if w is not None:
w = np.asarray(w)
sorted_indices = np.argsort(x)
sorted_x = x[sorted_indices]
sorted_w = w[sorted_indices]
# Force float dtype to avoid overflows
cumw = np.cumsum(sorted_w,dtype=float)
cumxw = np.cumsum(sorted_x * sorted_w,dtype=float)
return (np.sum(cumxw[1:] * cumw[:-1] - cumxw[:-1] * cumw[1:]) /
(cumxw[-1] * cumw[-1]))
else:
sorted_x = np.sort(x)
n = len(x)
cumx = np.cumsum(sorted_x,dtype=float)
# The above formula,with all weights equal to 1 simplifies to:
return (n + 1 - 2 * np.sum(cumx) / cumx[-1]) / n
这里有一些测试代码来检查我们得到(大多数)相同的结果: >>> x = np.random.rand(1000000) >>> w = np.random.rand(1000000) >>> gini_slow(x,w) 0.33376310938610521 >>> gini(x,w) 0.33376310938610382 但速度差异很大: %timeit gini(x,w) 203 ms ± 3.68 ms per loop (mean ± std. dev. of 7 runs,1 loop each) %timeit gini_slow(x,w) 55.6 s ± 3.35 s per loop (mean ± std. dev. of 7 runs,1 loop each) 如果从函数中删除pandas ops,它已经快得多: %timeit gini_slow2(x,w) 1.62 s ± 75 ms per loop (mean ± std. dev. of 7 runs,1 loop each) 如果你想获得最后一滴性能,你可以使用numba或cython,但这只会获得几个百分点,因为大部分时间都花在排序上. %timeit ind = np.argsort(x); sx = x[ind]; sw = w[ind] 180 ms ± 4.82 ms per loop (mean ± std. dev. of 7 runs,10 loops each) (编辑:安卓应用网) 【声明】本站内容均来自网络,其相关言论仅代表作者个人观点,不代表本站立场。若无意侵犯到您的权利,请及时与联系站长删除相关内容! |
