티스토리 뷰

ML & DL/PyTorch

Dataset & DataLoader

Enna 2023. 2. 20. 15:53

PyTorch Dataset

  • torch.utils.data.Dataset
    • 모듈을 상속해 새로운 class를 만듦으로써 우리가 원하는 데이터셋 지정
  • __len__()
    • 데이터셋의 총 데이터 수
  • __getitem__()
    • 어떠한 인덱스 idx를 받았을 때, 그에 상응하는 입출력 데이터 반환
from torch.utils.data import Dataset

class CustomDataset(Dataset):
    def __init__(self):
        self.x_data = [[73, 80, 75],
                       [93, 88, 93],
                       [89, 91, 90],
                       [96, 98, 100],
                       [73, 66, 70]]
        self.y_data = [[152], [185], [180], [196], [142]]

    def __len__(self):
        return len(self.x_data)

    def __getitem__(self, idx):
        x = torch.FloatTensor(self.x_data[idx])
        y = torch.FloatTensor(self.y_data[idx])
        return x, y
dataset = CustomDataset()

PyTorch DataLoader

  • torch.utils.data.DataLoader
  • batch_size=2
    • 각 mini-batch의 크기
    • 통상적으로 2의 제곱수로 설정
      • CPU와 GPU의 메모리가 2의 배수이므로 배치크기가 2의 제곱수일 경우에 데이터 송수신의 효율을 높일 수 있음
  • shuffle=True
    • epoch마다 데이터셋을 섞어서 데이터가 학습되는 순서를 바꿈
    • 데이터셋의 순서를 외우지 못하게 방지
from torch.utils.data import DataLoader

dataloader = DataLoader(
    dataset,
    batch_size=2,
    shuffle=True
)

Full Code

  • enumerate(dataloader): mini-batch의 인덱스와 데이터를 받음
  • len(dataloader): 한 epoch당 mini-batch의 개수
# 모델 초기화
model = MultivariateLinearRegressionModel()

# optimizer 설정
optimizer = optim.SGD(model.parameters(), lr=1e-5)

# 모델 학습
nb_epochs = 20
for epoch in range(nb_epochs + 1):
    for batch_idx, samples in enumerate(dataloader):
        x_train, y_train = samples
        책
        # H(x) 계산
        prediction = model(x_train)

        # cost 계산
        cost = F.mse_loss(prediction, y_train)

        # cost로 H(x) 개선
        optimizer.zero_grad()
        cost.backward()
        optimizer.step()

        print('Epoch {:4d}/{} Batch {}/{} Cost: {:.6f}'.format(
            epoch, nb_epochs, batch_idx+1, len(dataloader), cost.item()
        ))

'ML & DL > PyTorch' 카테고리의 다른 글

no_grad vs. requires_grad  (0) 2023.02.20
torch.optim.Optimizer  (0) 2023.02.20
torch.nn.Linear  (0) 2023.02.20
댓글
공지사항
최근에 올라온 글
최근에 달린 댓글
Total
Today
Yesterday
링크
«   2024/06   »
1
2 3 4 5 6 7 8
9 10 11 12 13 14 15
16 17 18 19 20 21 22
23 24 25 26 27 28 29
30
글 보관함