PySpark target mean encoding: a beginner's version

I wrote a simple version of target mean encoding with smoothing. The code is as follows:

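from itertools import chain
from pyspark.sql.functions import avg, col, count, create_map, lit

# Assumes df (a DataFrame), f (categorical column name), target (label column name)
# and m (smoothing weight) are already defined.
global_mean = df.agg(avg(target)).first()[0]  # prior: overall mean of the target

# Per-category mean and count of the target
agg = df.select([f, target]).groupBy(f).agg(avg(target).alias('mean'), count(target).alias('count'))
# Shrink each category mean toward the global mean; a larger m shrinks more
agg = agg.withColumn("smooth", (col('count') * col('mean') + m * global_mean) / (col('count') + m))

# Collect the (category, smoothed mean) pairs to the driver; only suitable for low-cardinality columns
agg_data = agg.select([f, 'smooth']).collect()
map_dict = {}
for r in agg_data:
    map_dict[r[0]] = r[1]

# Build a literal map column and look up each row's category in it
mapping_expr = create_map([lit(x) for x in chain(*map_dict.items())])
df = df.withColumn(f'{f}_encoded', mapping_expr[df[f]])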
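
To check the smoothing arithmetic without starting Spark, here is a minimal pure-Python sketch. It assumes the usual m-estimate form, smooth = (count * mean + m * global_mean) / (count + m), where global_mean is the mean of the target over all rows. The function name smoothed_target_means and the toy data are hypothetical and exist only for illustration.

```python
from collections import defaultdict

def smoothed_target_means(categories, targets, m):
    """Return {category: smoothed mean} under the assumed m-estimate formula."""
    global_mean = sum(targets) / len(targets)  # prior shared by all categories
    sums = defaultdict(float)
    counts = defaultdict(int)
    for c, t in zip(categories, targets):
        sums[c] += t
        counts[c] += 1
    # count * mean equals the per-category sum, so the sum is used directly
    return {
        c: (sums[c] + m * global_mean) / (counts[c] + m)
        for c in counts
    }

# Toy data: category "a" appears three times, "b" only once
cats = ["a", "a", "a", "b"]
ys = [1, 1, 0, 0]
enc = smoothed_target_means(cats, ys, m=2)
print(enc)
```

With m=2 the rare category "b" moves from its raw mean of 0 a good part of the way toward the global mean of 0.5, while the more frequent "a" moves less. That is the point of the smoothing term: categories with few rows get less trust.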

