pycharm · 2026年3月11日

查看训练样本的分布

from pathlib import Path


def count_classes(data_yaml: str):
    """Count how many labelled objects each class ID has in the training split.

    Reads the dataset YAML at ``data_yaml``, builds the training labels
    directory as ``'datasets/' / <path> / <parent of train> / 'labels'``, and
    tallies the first whitespace-separated token (the class ID) of every
    non-blank line in every ``*.txt`` file found recursively beneath it.

    Args:
        data_yaml: Path to the dataset YAML file; it must contain the keys
            ``path`` and ``train``.

    Returns:
        A dict mapping class ID (int) to the number of label lines with that ID.

    Raises:
        KeyError: If ``path`` or ``train`` is missing from the YAML.
        FileNotFoundError: If the resolved labels directory does not exist.
        ValueError: If a label line's first token is not an integer.
    """
    import yaml  # third-party; imported lazily so the module loads without it

    with open(data_yaml, encoding='utf-8') as f:
        data = yaml.safe_load(f)

    # Labels are expected in a sibling ``labels`` folder next to the images
    # folder named by ``train`` (e.g. ``train/images`` -> ``train/labels``).
    # A relative ``path`` is placed under ``datasets/`` (relative to the
    # current working directory); an absolute ``path`` replaces that prefix,
    # because joining with an absolute path discards the left-hand part.
    labels_dir = 'datasets/' / Path(data['path']) / Path(data['train']).parent / 'labels'
    if not labels_dir.is_dir():
        # Without this check rglob() would silently yield nothing and a wrong
        # path would be reported as an empty dataset.
        raise FileNotFoundError(f"labels directory not found: {labels_dir}")

    class_counts = {}
    for label_file in labels_dir.rglob('*.txt'):
        with open(label_file, encoding='utf-8') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    # Skip blank / whitespace-only lines (e.g. a trailing empty
                    # line); these previously raised IndexError.
                    continue
                cls_id = int(fields[0])
                class_counts[cls_id] = class_counts.get(cls_id, 0) + 1

    return class_counts


# Tally label lines per class for the hands dataset, then print one
# indented "class_id: count" row per class in ascending class-ID order.
counts = count_classes('models/hands.yaml')
print("类别ID → 样本数:")
for cls_id in sorted(counts):
    print(f"  {cls_id}: {counts[cls_id]}")
Python