from pathlib import Path
def count_classes(data_yaml: str, datasets_root: str = 'datasets'):
    """Count label lines (object instances) per class ID for a dataset.

    Reads the dataset YAML config, derives a labels directory from its
    ``path`` and ``train`` entries, and tallies the first whitespace-separated
    field (parsed as an int class ID) of every non-blank line in every
    ``*.txt`` file found recursively under that directory.

    Args:
        data_yaml: Path to the dataset YAML file; it must provide ``path``
            and ``train`` keys.
        datasets_root: Directory that the YAML ``path`` entry is resolved
            against. Defaults to ``'datasets'`` (relative to the current
            working directory), the value previously hard-coded here.

    Returns:
        A dict mapping integer class ID to the number of label lines carrying
        that ID. Empty if no ``*.txt`` files are found under the labels
        directory (including when that directory does not exist).

    Raises:
        KeyError: if ``path`` or ``train`` is missing from the YAML mapping.
        ValueError: if a label line's first field is not an integer.
    """
    from collections import Counter

    import yaml  # third-party (PyYAML); imported locally as in the original

    with open(data_yaml, encoding='utf-8') as f:
        data = yaml.safe_load(f)

    # NOTE(review): labels are looked up in a 'labels' directory next to the
    # train entry's parent, e.g. train='train/images' -> '<path>/train/labels'.
    # For a layout like train='images/train' this resolves to
    # '<path>/images/labels'; confirm it matches the dataset structure.
    labels_dir = (
        Path(datasets_root)
        / Path(data['path'])
        / Path(data['train']).parent
        / 'labels'
    )

    class_counts = Counter()
    # NOTE(review): every *.txt under labels_dir is parsed as a label file; a
    # non-label text file there (e.g. a class-name list) would raise ValueError.
    for label_file in labels_dir.rglob('*.txt'):
        with open(label_file, encoding='utf-8') as f:
            for line in f:
                fields = line.split()
                # Skip blank / whitespace-only lines (e.g. trailing newlines);
                # indexing fields[0] on them would raise IndexError.
                if not fields:
                    continue
                class_counts[int(fields[0])] += 1
    return dict(class_counts)
if __name__ == "__main__":
    # Print per-class instance counts ("class ID -> sample count"),
    # ordered by class ID.
    counts = count_classes('models/hands.yaml')
    print("类别ID → 样本数:")
    for cls_id, n in sorted(counts.items()):
        print(f" {cls_id}: {n}")