// mnist.cpp
// Train and evaluate a small convolutional network on the MNIST digits
// using the PyTorch C++ frontend (LibTorch).
#include <torch/torch.h>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <iostream>
#include <string>
#include <vector>
// Where to find the MNIST dataset.
const char* kDataRoot = "./data";
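// (torch::data::datasets::MNIST reads the raw idx files from this
// directory; it does not download them, so they must already be in place.)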
// The batch size for training.
const int64_t kTrainBatchSize = 64;
// The batch size for testing.
const int64_t kTestBatchSize = 1000;
// The number of epochs to train.
const int64_t kNumberOfEpochs = 10;
// How many batches to wait between logging updates of the loss value.
const int64_t kLogInterval = 10;
struct Net : torch::nn::Module {
Net()
: conv1(torch::nn::Conv2dOptions(1, 10, /*kernel_size=*/5)),
conv2(torch::nn::Conv2dOptions(10, 20, /*kernel_size=*/5)),
fc1(320, 50),
fc2(50, 10) {
register_module("conv1", conv1);
register_module("conv2", conv2);
register_module("conv2_drop", conv2_drop);
register_module("fc1", fc1);
register_module("fc2", fc2);
}
torch::Tensor forward(torch::Tensor x) {
x = torch::relu(torch::max_pool2d(conv1->forward(x), 2));
x = torch::relu(
torch::max_pool2d(conv2_drop->forward(conv2->forward(x)), 2));
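    // Flatten for the fully connected layers: a 28x28 input becomes 24x24
    // after conv1 (5x5 kernel, no padding), 12x12 after the first 2x2 pool,
    // 8x8 after conv2, and 4x4 after the second pool, so each example
    // carries 20 channels * 4 * 4 = 320 features.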
x = x.view({-1, 320});
x = torch::relu(fc1->forward(x));
x = torch::dropout(x, /*p=*/0.5, /*training=*/is_training());
x = fc2->forward(x);
return torch::log_softmax(x, /*dim=*/1);
}
torch::nn::Conv2d conv1;
torch::nn::Conv2d conv2;
torch::nn::Dropout2d conv2_drop;
torch::nn::Linear fc1;
torch::nn::Linear fc2;
};
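// Shape sketch: a batch of N MNIST images enters forward() as
// {N, 1, 28, 28} and comes out as {N, 10} log-probabilities, one per digit.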
template <typename DataLoader>
void train(
size_t epoch,
Net& model,
torch::Device device,
DataLoader& data_loader,
torch::optim::Optimizer& optimizer,
size_t dataset_size) {
model.train();
size_t batch_idx = 0;
for (auto& batch : data_loader) {
auto data = batch.data.to(device), targets = batch.target.to(device);
optimizer.zero_grad();
auto output = model.forward(data);
auto loss = torch::nll_loss(output, targets);
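    // Fail fast if training diverges: a NaN loss would otherwise poison
    // every subsequent gradient step.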
AT_ASSERT(!std::isnan(loss.template item<float>()));
loss.backward();
optimizer.step();
if (batch_idx++ % kLogInterval == 0) {
      std::printf(
          "\rTrain Epoch: %zu [%5zu/%5zu] Loss: %.4f",
          epoch,
          static_cast<size_t>(batch_idx * batch.data.size(0)),
          dataset_size,
          loss.template item<float>());
}
}
}
template <typename DataLoader>
void test(
Net& model,
torch::Device device,
DataLoader& data_loader,
size_t dataset_size) {
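  // No gradients are needed for evaluation; eval() additionally switches
  // dropout into its identity (inference) behavior.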
torch::NoGradGuard no_grad;
model.eval();
double test_loss = 0;
int32_t correct = 0;
for (const auto& batch : data_loader) {
auto data = batch.data.to(device), targets = batch.target.to(device);
auto output = model.forward(data);
test_loss += torch::nll_loss(
output,
targets,
/*weight=*/{},
torch::Reduction::Sum)
.template item<float>();
auto pred = output.argmax(1);
correct += pred.eq(targets).sum().template item<int64_t>();
}
test_loss /= dataset_size;
std::printf(
"\nTest set: Average loss: %.4f | Accuracy: %.3f\n",
test_loss,
static_cast<double>(correct) / dataset_size);
}
auto main() -> int {
torch::manual_seed(1);
torch::DeviceType device_type;
if (torch::cuda::is_available()) {
std::cout << "CUDA available! Training on GPU." << std::endl;
device_type = torch::kCUDA;
} else {
std::cout << "Training on CPU." << std::endl;
device_type = torch::kCPU;
}
torch::Device device(device_type);
Net model;
model.to(device);
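  // 0.1307 and 0.3081 below are the commonly quoted per-pixel mean and
  // standard deviation of the MNIST training images.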
auto train_dataset = torch::data::datasets::MNIST(kDataRoot)
.map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
.map(torch::data::transforms::Stack<>());
const size_t train_dataset_size = train_dataset.size().value();
auto train_loader =
torch::data::make_data_loader<torch::data::samplers::SequentialSampler>(
std::move(train_dataset), kTrainBatchSize);
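  // A SequentialSampler visits the training data in a fixed order every
  // epoch; swapping in torch::data::samplers::RandomSampler (the default
  // when no sampler is named, as in the test loader below) would shuffle it.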
auto test_dataset = torch::data::datasets::MNIST(
kDataRoot, torch::data::datasets::MNIST::Mode::kTest)
.map(torch::data::transforms::Normalize<>(0.1307, 0.3081))
.map(torch::data::transforms::Stack<>());
const size_t test_dataset_size = test_dataset.size().value();
auto test_loader =
torch::data::make_data_loader(std::move(test_dataset), kTestBatchSize);
torch::optim::SGD optimizer(
model.parameters(), torch::optim::SGDOptions(0.01).momentum(0.5));
for (size_t epoch = 1; epoch <= kNumberOfEpochs; ++epoch) {
train(epoch, model, device, *train_loader, optimizer, train_dataset_size);
test(model, device, *test_loader, test_dataset_size);
}
}
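// Build sketch (an assumption, not part of this file): with LibTorch on
// CMAKE_PREFIX_PATH, a minimal CMakeLists.txt would contain
//   find_package(Torch REQUIRED)
//   add_executable(mnist mnist.cpp)
//   target_link_libraries(mnist "${TORCH_LIBRARIES}")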