-
Notifications
You must be signed in to change notification settings - Fork 0
/
autoencoder.cpp
executable file
·35 lines (30 loc) · 1.18 KB
/
autoencoder.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
#include "autoencoder.h"
using namespace std;
// Constructs the eight layers fc1..fc8 and registers each one as a submodule
// under its own member name.
//
//   inputSize - first size argument of fc1 and second size argument of fc8
//   hSize     - size argument used by all intermediate layers
//   codeSize  - second size argument of fc4 and first size argument of fc5
AEImpl::AEImpl(int64_t inputSize, int64_t hSize, int64_t codeSize)
    // encoder: inputSize -> hSize -> hSize -> hSize -> codeSize
    : fc1(inputSize, hSize),
      fc2(hSize, hSize),
      fc3(hSize, hSize),
      fc4(hSize, codeSize),
      // decoder: codeSize -> hSize -> hSize -> hSize -> inputSize
      fc5(codeSize, hSize),
      fc6(hSize, hSize),
      fc7(hSize, hSize),
      fc8(hSize, inputSize) {
    // Encoder layers, registered in the order they are applied.
    register_module("fc1", fc1);
    register_module("fc2", fc2);
    register_module("fc3", fc3);
    register_module("fc4", fc4);

    // Decoder layers, registered in the order they are applied.
    register_module("fc5", fc5);
    register_module("fc6", fc6);
    register_module("fc7", fc7);
    register_module("fc8", fc8);
}
// Encoder half: passes x through fc1, fc2 and fc3, applying ReLU after each,
// then returns the output of fc4 with no activation applied.
torch::Tensor AEImpl::encode(torch::Tensor x) {
    namespace F = torch::nn::functional;

    torch::Tensor activ = F::relu(fc1->forward(x));
    activ = F::relu(fc2->forward(activ));
    activ = F::relu(fc3->forward(activ));

    // The final encoder layer is left linear (no ReLU).
    return fc4->forward(activ);
}
// Decoder half: passes z through fc5, fc6 and fc7, applying ReLU after each,
// then returns the output of fc8 with no activation applied.
torch::Tensor AEImpl::decode(torch::Tensor z) {
    namespace F = torch::nn::functional;

    torch::Tensor hidden = F::relu(fc5->forward(z));
    hidden = F::relu(fc6->forward(hidden));
    hidden = F::relu(fc7->forward(hidden));

    // The final decoder layer is left linear (no ReLU).
    return fc8->forward(hidden);
}
// Full pass: encodes x and immediately decodes the resulting codes.
torch::Tensor AEImpl::forward(torch::Tensor x) {
    return decode(encode(x));
}