# import libraries
import torch
import numpy as np

from torchvision import datasets
import torchvision.transforms as transforms

from torch.quantization import QuantStub, DeQuantStub


# Number of worker subprocesses used for data loading (0 = load in the main process).
num_workers = 0
# How many samples per batch to load.
batch_size = 20

# Resize every MNIST digit to 16x16 and convert it to a float tensor.
# Net's fc1 layer is sized for this: its two 2x2 max-pools reduce 16x16 to 4x4.
transform = transforms.Compose([
    transforms.Resize((16, 16)),
    transforms.ToTensor(),
])

# Training and test splits, downloaded into ./data if not already present.
train_data = datasets.MNIST(root='data', train=True, download=True, transform=transform)
test_data = datasets.MNIST(root='data', train=False, download=True, transform=transform)

# Data loaders. The training set is reshuffled every epoch so SGD does not see
# the same fixed sequence of batches on each pass; the test set keeps its order.
train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=batch_size, shuffle=True, num_workers=num_workers)
test_loader = torch.utils.data.DataLoader(
    test_data, batch_size=batch_size, num_workers=num_workers)

# Peek at one training batch; the images are kept as a NumPy array.
dataiter = iter(train_loader)
images, labels = next(dataiter)
images = images.numpy()

import torch.nn as nn
import torch.nn.functional as F

## Define the NN architecture
class Net(nn.Module):
    """Small two-stage CNN classifier for 1x16x16 images, ready for static quantization.

    ``QuantStub`` / ``DeQuantStub`` mark where tensors enter and leave the
    quantized region. Until ``torch.quantization.convert`` replaces them, they
    leave tensor values unchanged, so the model trains as a normal float network.

    Shape flow for an (N, 1, 16, 16) input:
        conv1 + relu + pool -> (N, 4, 8, 8)
        conv2 + relu + pool -> (N, 8, 4, 4)
        flatten             -> (N, 128)
        fc1                 -> (N, 10) raw class logits
    """

    # Flattened feature count after the second pooling: 8 channels x 4 x 4.
    _FLAT_FEATURES = 8 * 4 * 4

    def __init__(self):
        super(Net, self).__init__()
        self.quant = QuantStub()
        # First conv layer: 1 input channel -> 4 output channels, 3x3 kernel, padding 1.
        self.conv1 = nn.Conv2d(1, 4, 3, padding=1)
        # Second conv layer: 4 input channels -> 8 output channels, 3x3 kernel, padding 1.
        self.conv2 = nn.Conv2d(4, 8, 3, padding=1)

        # 2x2 max pooling with stride 2, shared by both conv stages (halves H and W).
        self.pool = nn.MaxPool2d(2, 2)

        # Fully connected layer: padding=1 keeps H/W through each conv and the two
        # pools shrink 16x16 to 4x4, so 8 * 4 * 4 input features -> 10 classes.
        self.fc1 = nn.Linear(self._FLAT_FEATURES, 10)
        self.dequant = DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        # Flatten per sample, keeping the batch axis. reshape(-1, 128) could
        # silently regroup samples when the input size is wrong; here fc1 raises.
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = self.dequant(x)
        return x

# Build the network and show its layer summary.
model = Net()
print(model)

# Loss and optimizer.
criterion = nn.CrossEntropyLoss()  # applied to the raw logits from Net.forward
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Number of full passes over the training set (20-50 epochs is a reasonable range).
n_epochs = 30

# Switch to training mode once; nothing inside the loop changes it back.
model.train()

for epoch in range(n_epochs):
    # Sum of per-sample losses seen during this epoch.
    running_loss = 0.0

    for batch_inputs, batch_targets in train_loader:
        optimizer.zero_grad()
        batch_loss = criterion(model(batch_inputs), batch_targets)
        batch_loss.backward()
        optimizer.step()
        # The criterion returns a batch mean; scale it back up by the batch size
        # so the epoch average below is a true per-sample average.
        running_loss += batch_loss.item() * batch_inputs.size(0)

    # Average per-sample training loss over the whole dataset.
    train_loss = running_loss / len(train_loader.dataset)
    print('Epoch: {} \tTraining Loss: {:.6f}'.format(epoch + 1, train_loss))

print(model)

# Export every entry of the trained float model's state_dict to "<name>.csv"
# (e.g. "conv1.weight.csv") in the current working directory.
# The state_dict is built once, not re-created on every iteration.
for param_tensor, value in model.state_dict().items():
    # Detach and move to host memory so .numpy() also works for tensors
    # living on an accelerator.
    numpy_array = value.detach().cpu().numpy()

    # np.savetxt accepts only 1-D or 2-D arrays. Collapse everything after the
    # first axis for higher-rank tensors (conv weights become out_channels x rest),
    # and promote 0-d scalars to 1-D.
    if numpy_array.ndim > 2:
        numpy_array = numpy_array.reshape(numpy_array.shape[0], -1)
    numpy_array = np.atleast_1d(numpy_array)

    # Save as a comma-separated CSV file.
    np.savetxt(f"{param_tensor}.csv", numpy_array, delimiter=",")



# ---------------------------------------------------------------------------
# Evaluate: post-training static quantization, then test-set loss/accuracy.
# ---------------------------------------------------------------------------

# Running per-sample test loss and per-class counters for accuracy.
test_loss = 0.0
class_correct = [0.] * 10
class_total = [0.] * 10

model.eval()  # prep model for evaluation (observers must run in eval mode)
# NOTE(review): 'fbgemm' targets x86 CPUs; on ARM builds the 'qnnpack' backend
# may be required instead -- confirm for the deployment platform.
model.qconfig = torch.quantization.get_default_qconfig('fbgemm')
torch.quantization.prepare(model, inplace=True)

# Calibration: feed a few training batches through the prepared model so its
# observers record real activation ranges. Without this step the observers are
# empty at convert time and fall back to default qparams (scale 1.0, zero
# point 0), which collapses the [0, 1] inputs and the intermediate activations.
num_calibration_batches = 50
with torch.no_grad():
    for batch_idx, (calib_images, _) in enumerate(train_loader):
        if batch_idx >= num_calibration_batches:
            break
        model(calib_images)

torch.quantization.convert(model, inplace=True)

with torch.no_grad():
    for data, target in test_loader:
        # Forward pass through the quantized model (outputs are dequantized floats).
        output = model(data)
        # The criterion returns a batch mean; weight it by the batch size.
        loss = criterion(output, target)
        test_loss += loss.item() * data.size(0)
        # The predicted class is the index of the largest logit.
        _, pred = torch.max(output, 1)
        correct = pred.eq(target)
        # Walk the actual batch (the last one may be shorter than batch_size).
        for label, is_correct in zip(target.tolist(), correct.tolist()):
            class_correct[label] += is_correct
            class_total[label] += 1

# Average and print the per-sample test loss.
test_loss = test_loss / len(test_loader.dataset)
print('Test Loss: {:.6f}\n'.format(test_loss))

for i in range(10):
    if class_total[i] > 0:
        print('Test Accuracy of %5s: %2d%% (%2d/%2d)' % (
            str(i), 100 * class_correct[i] / class_total[i],
            np.sum(class_correct[i]), np.sum(class_total[i])))
    else:
        print('Test Accuracy of %5s: N/A (no test examples)' % str(i))

print('\nTest Accuracy (Overall): %2d%% (%2d/%2d)' % (
    100. * np.sum(class_correct) / np.sum(class_total),
    np.sum(class_correct), np.sum(class_total)))

print(model)
