baoshiwei
2024-03-04 e595c312581496403ac182f12f3d4939d3d00998
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import os.path
 
import torch
import torch.nn as nn
from PIL import Image
from torchvision import datasets
 
from HerbModel import Model
 
from torch.nn import Sequential, Conv2d, MaxPool2d, Flatten, Linear
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.transforms import transforms
 
# Preprocessing pipeline handed to the training ImageFolder below: resize,
# apply random flip/crop augmentation, convert to a tensor, then normalize
# each of the three channels with mean 0.5 and std 0.5.
img_transforms = transforms.Compose(
    [
        transforms.Resize(size=128),
        transforms.RandomHorizontalFlip(),
        transforms.RandomCrop(size=128),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)
 
# Second pipeline with no random augmentation that resizes to 32; nothing in
# the visible part of this file references it.
# NOTE(review): Resize given an int scales the shorter edge only, so
# non-square inputs will not come out 32x32 - confirm before batching with it.
img_transforms2 = transforms.Compose(
    [
        transforms.Resize(size=32),
        # Steps that were disabled in the original pipeline:
        # transforms.RandomHorizontalFlip(),
        # transforms.CenterCrop(32),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
    ]
)
 
 
# Data loading: ImageFolder reads the images under ./yaocai and applies the
# training transform pipeline to each of them.
data_dir = '.'
train_data = datasets.ImageFolder(os.path.join(data_dir, 'yaocai'), img_transforms)

# Mini-batches of 4 images, reshuffled on every pass over the dataset.
train_loader = DataLoader(train_data, batch_size=4, shuffle=True)

print("训练数据集数量:{}".format(len(train_data)))
model = Model()
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
writer = SummaryWriter("yc_logs")
epoch = 50  # total number of passes over the training set
step = 0  # optimizer updates performed so far, counted across all epochs

# NOTE(review): the trained weights are never saved by this script - confirm
# whether a torch.save(...) call is expected after training.
try:
    for i in range(epoch):
        print("第{}轮训练开始。。。。。".format(i+1))
        for imgs, target in train_loader:
            output = model(imgs)
            loss = loss_fn(output, target)

            print("损失:{}".format(loss))

            # Predicted class = index of the largest score along dim 1.
            argmax = output.argmax(1)
            acc = (argmax == target).sum()
            # Percentage of correct predictions within this batch.
            accuracy = acc / len(imgs) * 100

            print("结果:{}".format(argmax))
            print("实际:{}".format(target))
            print("正确率:{}".format(accuracy))

            # Standard update: clear old gradients, backpropagate, step.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            step = step + 1
            print("第{}次训练".format(step))

            # Record the per-batch metrics so the TensorBoard log directory
            # actually contains the training curves.
            writer.add_scalar("train_loss", loss.item(), step)
            writer.add_scalar("train_accuracy", accuracy.item(), step)
finally:
    # Flush and close the event file even if training is interrupted.
    writer.close()