Liu/defend/main.py

37 lines
810 B
Python
Raw Normal View History

2024-01-26 20:42:33 +08:00
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
import keras
from data_load import data_format
from attack_craft import craft_adv
# Load the dataset; the result is unpacked as train/test features and labels.
X_train, X_test, Y_train, Y_test = data_format(
    'data/archive/PowerQualityDistributionDataset1.csv')
# Shuffle the test set with a fixed seed so runs are reproducible.
# One permutation is applied to both features and labels, so (x, y) pairs
# stay aligned by construction instead of relying on two separate shuffles
# that replay the same seeded RNG sequence.
# NOTE(review): assumes X_test / Y_test are numpy arrays (fancy indexing).
if len(X_test) != len(Y_test):
    raise ValueError(
        f"X_test and Y_test lengths differ: {len(X_test)} != {len(Y_test)}")
np.random.seed(7)
perm = np.random.permutation(len(X_test))
X_test = X_test[perm]
Y_test = Y_test[perm]
# Seed TensorFlow's global RNG for reproducibility of later TF ops.
tf.random.set_seed(7)
# Load the trained models: the baseline and the adversarially trained one.
model = keras.models.load_model('model')
model_adv = keras.models.load_model('model_adv')
# Loss function handed to craft_adv and reused below to score model_adv,
# so both printed numbers are measured with the same metric.
loss_fn = tf.keras.losses.MeanSquaredError()
# Craft adversarial test examples against the baseline model.
# NOTE(review): 0.4 and 0.5 are attack hyperparameters whose meaning is
# defined by craft_adv (presumably perturbation budget / step size) -- confirm.
x_adv, loss = craft_adv(
    X_test, Y_test, 0.4, 0.5, model, loss_fn)
# Score the adversarially trained model on the same adversarial examples
# with the same loss_fn. model_adv.evaluate would instead use model_adv's
# compiled loss (which may differ from loss_fn) and returns a list rather
# than a scalar when extra metrics were compiled.
loss_adv = float(loss_fn(Y_test, model_adv.predict(x_adv, verbose=0)))
print(f"原始模型:{loss},对抗训练后的模型:{loss_adv}")