본문 바로가기
AI

(2024-03-14) PyTorch Lightning

by busybee-busylife 2024. 3. 14.
반응형

PyTorch Lightning을 이용한 CNN Classifier

 

class CNNClassifier(pl.LightningModule):
    """Skeleton of the hooks a LightningModule classifier implements.

    Every hook is a placeholder that raises NotImplementedError until it is
    filled in, so a half-finished model fails loudly instead of silently
    returning None.

    Args:
        num_classes: Number of output classes, stored for the layer definitions.
        dropout_ratio: Dropout probability, stored for the layer definitions.
    """

    def __init__(self, num_classes=10, dropout_ratio=0.2):
        super().__init__()
        # Accepting these (with defaults) lets the model be built as
        # CNNClassifier(num_classes=10, dropout_ratio=0.2).
        self.num_classes = num_classes
        self.dropout_ratio = dropout_ratio

    def forward(self, x):
        """Run the network on a batch of inputs and return class logits."""
        raise NotImplementedError

    def configure_optimizers(self):
        """Return the optimizer (optionally with an LR scheduler) Lightning uses."""
        raise NotImplementedError

    def training_step(self, batch, batch_idx):
        """Compute and return the training loss for one batch.

        Lightning runs backward() and the optimizer step on the returned loss.
        """
        raise NotImplementedError

    def validation_step(self, batch, batch_idx):
        """Compute and log validation metrics for one batch.

        If early stopping monitors 'valid_loss', the loss must be logged with
        self.log under exactly that key.
        """
        raise NotImplementedError

    def test_step(self, batch, batch_idx):
        """Compute and log test metrics for one batch (called by trainer.test)."""
        raise NotImplementedError

    def predict_step(self, batch, batch_idx):
        """Return predictions for one batch (called by trainer.predict)."""
        raise NotImplementedError
# These names are not imported anywhere before this snippet, so import them here.
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.loggers import CSVLogger

# NOTE(review): train_dataloader, val_dataloader and test_dataloader are not
# defined in this snippet; they must be built beforehand (e.g. DataLoaders
# over train/validation/test splits).
model = CNNClassifier(num_classes=10, dropout_ratio=0.2)

# Stop training once the logged 'valid_loss' stops decreasing; the model's
# validation_step must log its loss under exactly this key.
early_stopping = EarlyStopping(monitor='valid_loss', mode='min')
csv_logger = CSVLogger(save_dir="./csv_logger", name='test')
# EarlyStopping and CSVLogger ship with Lightning, so they only need to be imported.

trainer = Trainer(max_epochs=100, accelerator='auto', callbacks=[early_stopping], logger=csv_logger)

trainer.fit(model, train_dataloader, val_dataloader)

trainer.test(model, dataloaders=test_dataloader)

 

 

 

 

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import pytorch_lightning as pl

class CNNClassifier(pl.LightningModule):
    """Small two-block CNN classifier for 3-channel 32x32 images (e.g. CIFAR-10).

    Layers: conv(3->16) -> ReLU -> maxpool -> conv(16->32) -> ReLU -> maxpool
    -> fc(32*8*8 -> 128) -> ReLU -> fc(128 -> num_classes).

    Args:
        num_classes: Number of output logits. Defaults to 10.
        lr: Learning rate for Adam. Defaults to 0.001.
    """

    def __init__(self, num_classes=10, lr=0.001):
        super().__init__()
        self.lr = lr
        self.conv1 = nn.Conv2d(3, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        # Two 2x2 max-pools shrink a 32x32 input to 8x8, hence 32 * 8 * 8 features.
        self.fc1 = nn.Linear(32 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, num_classes)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool2d(2, 2)

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for images x.

        x is expected to be (batch, 3, 32, 32): conv1 takes 3 channels and fc1
        assumes an 8x8 map after the two pools.
        """
        x = self.maxpool(self.relu(self.conv1(x)))
        x = self.maxpool(self.relu(self.conv2(x)))
        # flatten(1) keeps the batch axis intact; view(-1, 32*8*8) would fold
        # a wrongly sized feature map into the batch axis instead of erroring.
        x = torch.flatten(x, 1)
        x = self.relu(self.fc1(x))
        return self.fc2(x)

    def configure_optimizers(self):
        """Optimize all parameters with Adam at the configured learning rate."""
        return optim.Adam(self.parameters(), lr=self.lr)

    def _shared_step(self, batch):
        """Return the cross-entropy loss for one (inputs, labels) batch."""
        inputs, labels = batch
        outputs = self(inputs)
        return nn.functional.cross_entropy(outputs, labels)

    def training_step(self, batch, batch_idx):
        """Log and return the training loss; Lightning backpropagates it."""
        loss = self._shared_step(batch)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log and return the validation loss for one batch."""
        loss = self._shared_step(batch)
        self.log('val_loss', loss)
        return loss

class MyCallback(pl.Callback):
    """Print the most recent train/validation loss after every training epoch."""

    def on_train_epoch_end(self, trainer, pl_module):
        # The generic `on_epoch_end` Callback hook was deprecated and then
        # removed (Lightning 1.8); `on_train_epoch_end` is its replacement.
        metrics = trainer.callback_metrics
        # .get avoids a KeyError when a metric has not been logged yet
        # (e.g. validation disabled); the missing value prints as None.
        train_loss = metrics.get('train_loss')
        val_loss = metrics.get('val_loss')
        print(f"Epoch {trainer.current_epoch}, Train Loss: {train_loss}, Val Loss: {val_loss}")

# Prepare the dataset: ToTensor yields values in [0, 1], and Normalize with
# mean 0.5 / std 0.5 per RGB channel rescales them to [-1, 1].
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)

# NOTE(review): the CIFAR-10 *test* split (train=False) doubles as the
# validation set here, so it cannot also serve as an untouched final test set.
val_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=64, shuffle=False)

# Create the model and the Trainer.
model = CNNClassifier()
callback = MyCallback()  # callback that prints train/val loss at the end of each epoch
# Trainer no longer accepts `progress_bar_refresh_rate`; the refresh rate is
# configured on the TQDMProgressBar callback instead.
progress_bar = pl.callbacks.TQDMProgressBar(refresh_rate=20)
trainer = pl.Trainer(max_epochs=5, callbacks=[callback, progress_bar])
# The Trainer owns the training options, the training loop and evaluation.

# Train the model (validation runs on val_loader during fitting).
trainer.fit(model, train_loader, val_loader)
반응형

'AI' 카테고리의 다른 글

2024 AI Stages : ML Competition_House Price Prediction  (0) 2024.04.03
(2024-03-14) Hydra  (1) 2024.03.14
(2024-02-02) XGBoost와 아이들  (0) 2024.02.02
(2024-01-17) 통계  (0) 2024.01.24
(2023-12-29) Python EDA  (0) 2024.01.10