102 分類花画像データセットの準備と拡張
本プロジェクトでは、102 種類の異なる花の画像を識別する分類モデルを構築します。データセット内の各クラスに含まれる画像枚数が不均一であるため、学習の安定化を図るためにデータ拡張を行い、各クラスあたり少なくとも 100 枚の画像を確保する前処理を行います。
以下のスクリプトは、既存の画像リストを読み込み、不足している分を回転や反転などの变换を適用して生成し、テキストファイルに追記します。
データ拡張スクリプト
import os
import random
from pathlib import Path
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset
# 各クラスあたりの目標画像数
TARGET_SAMPLES_PER_CLASS = 100
# 拡張変換パイプライン
augmentation_pipeline = transforms.Compose([
transforms.RandomAffine(degrees=30, translate=(0.1, 0.1)),
transforms.RandomHorizontalFlip(p=0.5),
transforms.RandomResizedCrop(size=150, scale=(0.8, 1.0)),
transforms.ToTensor(),
transforms.ToPILImage()
])
class ImageLabelReader(Dataset):
def __init__(self, base_path, list_file, transform=None):
self.base_path = Path(base_path)
self.transform = transform
self.samples = []
with open(list_file, 'r', encoding='utf-8') as f:
for line in f:
parts = line.strip().split()
if len(parts) >= 2:
fname, cls_name = parts[0], parts[1]
self.samples.append((fname, cls_name))
def __len__(self):
return len(self.samples)
def __getitem__(self, index):
fname, cls_name = self.samples[index]
img_path = self.base_path / cls_name / fname
img = Image.open(img_path).convert('RGB')
if self.transform:
img = self.transform(img)
return img, cls_name, fname
def generate_augmented_data(root_dir, list_path):
# クラスごとに画像ファイルをグループ化
class_groups = {}
with open(list_path, 'r', encoding='utf-8') as f:
for line in f:
parts = line.strip().split()
if len(parts) >= 2:
fname, cls_name = parts[0], parts[1]
if cls_name not in class_groups:
class_groups[cls_name] = []
class_groups[cls_name].append(fname)
# 各クラスについて処理
for cls_name, file_list in class_groups.items():
current_count = len(file_list)
if current_count >= TARGET_SAMPLES_PER_CLASS:
continue
needed_count = TARGET_SAMPLES_PER_CLASS - current_count
print(f"Class '{cls_name}': Generating {needed_count} images...")
generated_files = []
for i in range(needed_count):
# 既存画像からランダムに選択
source_file = random.choice(file_list)
source_path = Path(root_dir) / cls_name / source_file
img = Image.open(source_path).convert('RGB')
augmented_img = augmentation_pipeline(img)
new_fname = f"aug_{cls_name}_{i}.jpg"
save_path = Path(root_dir) / cls_name / new_fname
augmented_img.save(save_path)
generated_files.append(new_fname)
# リストファイルに追記
with open(list_path, 'a', encoding='utf-8') as f:
for new_fname in generated_files:
f.write(f"{new_fname} {cls_name}\n")
if __name__ == "__main__":
# 環境に合わせてパスを変更してください
data_root = "./flower_dataset/train"
annotation_file = "./flower_dataset/train.txt"
generate_augmented_data(data_root, annotation_file)
ResNet18 モデルの定義と学習プロセス
データ準備が整った後、PyTorch を使用して ResNet18 アーキテクチャをロードし、最終全結合層を 102 分類用に修正します。学習時には検証データを用いてモデルの性能を監視し、最も精度の高い時点の重みを保存します。
学習用メインスクリプト
import os
import copy
import time
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from torch.utils.tensorboard import SummaryWriter
from PIL import Image
# TensorBoard 設定
log_writer = SummaryWriter(log_dir='./runs/flower_classification')
class FlowerClassificationDataset(Dataset):
def __init__(self, root_dir, annotation_path, transform=None):
self.root_dir = root_dir
self.transform = transform
self.data_info = self._parse_annotation(annotation_path)
self.image_paths = list(self.data_info.keys())
self.labels = list(self.data_info.values())
def _parse_annotation(self, path):
info = {}
with open(path, 'r') as f:
for line in f:
parts = line.strip().split()
if len(parts) >= 2:
fname, label = parts[0], parts[1]
info[fname] = np.int64(label)
return info
def __len__(self):
return len(self.image_paths)
def __getitem__(self, idx):
img_name = self.image_paths[idx]
label = self.labels[idx]
img_path = os.path.join(self.root_dir, img_name)
# 簡略化のためルートディレクトリ直下と仮定、必要に応じて修正
image = Image.open(img_path).convert('RGB')
if self.transform:
image = self.transform(image)
return image, torch.tensor(label, dtype=torch.long)
# 変換設定
transform_config = {
'train': transforms.Compose([
transforms.Resize((70, 70)),
transforms.RandomRotation(30),
transforms.CenterCrop(64),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(brightness=0.3, contrast=0.2),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
]),
'val': transforms.Compose([
transforms.Resize((70, 70)),
transforms.CenterCrop(64),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
}
# データローダーの準備
train_data = FlowerClassificationDataset('./train_filelist', './train.txt', transform_config['train'])
val_data = FlowerClassificationDataset('./val_filelist', './val.txt', transform_config['val'])
train_loader = DataLoader(train_data, batch_size=32, shuffle=True, num_workers=2)
val_loader = DataLoader(val_data, batch_size=32, shuffle=False, num_workers=2)
loaders = {'train': train_loader, 'val': val_loader}
# モデルの構築
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
network = models.resnet18(weights=None)
num_features = network.fc.in_features
network.fc = nn.Linear(num_features, 102)
network = network.to(device)
# 損失関数と最適化
loss_function = nn.CrossEntropyLoss()
optimizer = optim.Adam(network.parameters(), lr=0.001, weight_decay=1e-3)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=7, gamma=0.1)
def train_network(model, loaders, criterion, opt, scheduler, epochs=30, save_path='checkpoint.pth'):
start_time = time.time()
best_accuracy = 0.0
best_weights = copy.deepcopy(model.state_dict())
for epoch in range(epochs):
print(f"Epoch {epoch}/{epochs - 1}")
print("-" * 20)
for phase in ['train', 'val']:
if phase == 'train':
model.train()
else:
model.eval()
running_loss = 0.0
correct_preds = 0
total_samples = 0
for inputs, targets in loaders[phase]:
inputs = inputs.to(device)
targets = targets.to(device)
opt.zero_grad()
with torch.set_grad_enabled(phase == 'train'):
outputs = model(inputs)
loss = criterion(outputs, targets)
_, predictions = torch.max(outputs, 1)
if phase == 'train':
loss.backward()
opt.step()
running_loss += loss.item() * inputs.size(0)
correct_preds += torch.sum(predictions == targets.data)
total_samples += inputs.size(0)
epoch_loss = running_loss / total_samples
epoch_acc = correct_preds.double() / total_samples
print(f"{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")
log_writer.add_scalar(f"{phase}/loss", epoch_loss, epoch)
log_writer.add_scalar(f"{phase}/accuracy", epoch_acc, epoch)
if phase == 'val' and epoch_acc > best_accuracy:
best_accuracy = epoch_acc
best_weights = copy.deepcopy(model.state_dict())
torch.save({
'epoch': epoch,
'model_state_dict': model.state_dict(),
'optimizer_state_dict': opt.state_dict(),
'accuracy': best_accuracy
}, save_path)
scheduler.step()
print(f"Current LR: {opt.param_groups[0]['lr']}")
total_time = time.time() - start_time
print(f"Training completed in {total_time // 60:.0f}m {total_time % 60:.0f}s")
print(f"Best Validation Accuracy: {best_accuracy:.4f}")
model.load_state_dict(best_weights)
return model
# 学習実行
trained_model = train_network(network, loaders, loss_function, optimizer, lr_scheduler, epochs=25)
上記のコードでは、学習率スケジューリングにより途中学習率を減衰させ、過学習を防ぐ工夫を行っています。また、検証セットでの精度が更新された時点でモデルの重みをファイルに保存し、最終的に最も性能の良いモデルをロードして返す仕組みになっています。