加载中...

Pytorch


数据集处理

基础操作

TensorDatasetDataLoader

import torch
import torch.utils.data

# 创建一个tensor输入
inputs = torch.tensor([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0], [4.0, 5.0], [5.0, 6.0], 
                       [6.0, 7.0], [7.0, 8.0], [8.0, 9.0], [9.0, 10.0], [10.0, 11.0]])

# 创建一个tensor标签
labels = torch.tensor([0, 1, 0, 1, 0, 1, 0, 1, 0, 1])

# 创建TensorDataset对象
dataset = torch.utils.data.TensorDataset(inputs, labels) # 利用torch.utils.data.TensorDataset()
print(dataset)
print(type(dataset))
for data, i in dataset: 
    print(data, i)
print()

# 创建Dataloader对象
dataloader = torch.utils.data.DataLoader(dataset, batch_size=2, shuffle=True) # 利用torch.utils.
print(dataloader)
print(type(dataloader))
for batch_data, batch_label in dataloader:
    print(batch_data, batch_label)

输出

<torch.utils.data.dataset.TensorDataset object at 0x7ff28b5da220>
<class 'torch.utils.data.dataset.TensorDataset'>
tensor([1., 2.]) tensor(0)
tensor([2., 3.]) tensor(1)
tensor([3., 4.]) tensor(0)
tensor([4., 5.]) tensor(1)
tensor([5., 6.]) tensor(0)
tensor([6., 7.]) tensor(1)
tensor([7., 8.]) tensor(0)
tensor([8., 9.]) tensor(1)
tensor([ 9., 10.]) tensor(0)
tensor([10., 11.]) tensor(1)

<torch.utils.data.dataloader.DataLoader object at 0x7ff28b5dabb0>
<class 'torch.utils.data.dataloader.DataLoader'>
tensor([[4., 5.],
        [6., 7.]]) tensor([1, 1])
tensor([[5., 6.],
        [2., 3.]]) tensor([0, 1])
tensor([[ 7.,  8.],
        [10., 11.]]) tensor([0, 1])
tensor([[1., 2.],
        [8., 9.]]) tensor([0, 1])
tensor([[ 9., 10.],
        [ 3.,  4.]]) tensor([0, 0])

经典数据集导入

MNIST

内容 链接

CNN

https://pytorch.org/tutorials/beginner/blitz/cifar10_tutorial.html

内容 链接

CSDN

神经网络搭建

LeNet-5

内容 链接 创作者
Python Class Tutorial 链接🔗 Corey Schafer

文章作者: Rickyの水果摊
版权声明: 本博客所有文章除特別声明外,均采用 CC BY 4.0 许可协议。转载请注明来源 Rickyの水果摊 !
 本篇
Pytorch Pytorch
2024-09-12 Rickyの水果摊
本篇 
Pytorch Pytorch
2024-09-12 Rickyの水果摊
  目录