深度学习样本均衡问题

可否请教一下大家这个样本均衡的代码如何理解,谢谢
#进行基本的样本均衡
Label_num = Counter(step_labels)
min_Label_num = min(Label_num.values())#取得各标签中最少出现标签的个数
formal_id = list()
output = np.array(step_labels)
for i in range(class_nb):#从各标签对应的场景中,按照最少的标签个数抽取出对应数量的场景id
idx = np.where(output == i)[0]
balanced_idx = np.random.choice(idx, size = min_Label_num, replace = False)
formal_id = formal_id + list(balanced_idx)
step_labels = output[np.array(formal_id)]
step_data = np.array(step_data)[np.array(formal_id)]

Counter()方法可以统计数据集中各标签的样本的数量,然后用min()方法选择样本数量最少的类别数,这里代码均衡的意思是其它的类别也取同最小样本数类别相同的样本数,也就是for循环中干的活。np.where可以筛选出属于类别i的样本的索引,然后用np.random.choice取这些对应的索引里面取样同最小样本数量的类别同数量的样本,每个类别的样本索引都会保存到formal_id这个列表中去,step_labels是根据对应的样本索引取到对应的样本标签,step_data是根据formal_id里面对应的样本索引去取的对应的样本特征吧。
其实这段代码挺通俗易懂的,就是需要提问者取了解numpy的特性,它可以根据一个索引数组快速取出对应的样本。还有就是查一下对应的numpy的几个API的功能,你就懂了。

就是说这个程序可以这样理解:假设有样本A,B,A有100条,B有50条,这个程序可以把A里的数据随机削掉50条,这样样本就均衡了

假设 step_labels=["cat","cat","dog","dog","dog","fish","fish",fish","fish"] ,
有三个类别样本,其中标注为“cat"的样本数量最少,因此有”不均衡”(实际使用中的不同类别的样本数量差异可能很大)
下面依据代码,介绍如何“均衡”

Label_num = Counter(step_labels)  # 统计中每个类别的数量,“cat":2,"dog",3,"fish":4   
min_Label_num = min(Label_num.values())#取得各标签中最少出现标签的个数,例如这里的”cat"的类别数量2
formal_id = list()  # 初始化一个存放均衡后类别的list
output = np.array(step_labels)   # 拷贝原始的label
for i in range(class_nb):#  这里 class_nb 可以理解为”cat","dog","fish“,所有三个类别,然后遍历每个类别
    idx = np.where(output == i)[0]  # 例如遍历第二个类别的时候,找出dog所有的位置,[2,3,4] (索引从0开始)
   # 从idx中随机抽取不重复min_Label_num(示例中为2)个下标,意思是从三个dong从选择随机选择两个
    balanced_idx = np.random.choice(idx, size = min_Label_num, replace = False),
    formal_id = formal_id + list(balanced_idx) # 将均衡后的索引放到一起
# 上面f的均衡思路:(1)获取最少类别的样本数量(2)从所有类别中都随机抽取"最少类别数量”的样本
#  这样每个类别的数量都是一样,也就均衡了
step_labels = output[np.array(formal_id)]    # 根据均衡后的类别索引,构造新的索引,例如【0,1,3,4,5,7】 ,每个类别随机选取两个数量的结果
step_data = np.array(step_data)[np.array(formal_id)]  # 根据均衡后的索引取获取,类别对应的数据(类别的顺序和数据的须保持一致)

采用“欠采样”的思想来均衡,每个类别都用了最少的样本来训练,画图就是:

img

深度学习训练阶段,防止数据样本太集中,学习不到少量数据的特征。

https://www.cnblogs.com/yumoye/p/10517154.html

使用transforms随机裁剪