# 37 lines · 810 B · Python
"""Compare the original and the adversarially trained model on adversarial inputs.

Adversarial examples are crafted against the original model (``model``) with
``craft_adv`` and then scored on the adversarially trained model
(``model_adv``). Both losses are printed side by side.
"""

import matplotlib.pyplot as plt  # NOTE(review): not used in this script
import tensorflow as tf
import numpy as np
import keras

from data_load import data_format
from attack_craft import craft_adv

SEED = 7
DATA_PATH = 'data/archive/PowerQualityDistributionDataset1.csv'

# Load the dataset. The train split is not used below but is kept as
# module-level names, as before.
X_train, X_test, Y_train, Y_test = data_format(DATA_PATH)

# Seed the RNGs for reproducibility and shuffle the test set.
# X_test and Y_test must stay row-aligned, so draw ONE permutation and apply it
# to both. The previous approach re-seeded the global RNG before two independent
# shuffles, which silently misaligns the pairs if the lengths ever differ.
# For equal lengths this gives the same ordering and leaves the global NumPy
# RNG in the same state as the old two-shuffle version.
if len(X_test) != len(Y_test):
    raise ValueError(
        f"X_test and Y_test length mismatch: {len(X_test)} != {len(Y_test)}")
np.random.seed(SEED)
order = np.random.permutation(len(X_test))
X_test = np.asarray(X_test)[order]
Y_test = np.asarray(Y_test)[order]
tf.random.set_seed(SEED)

# Load the trained models: the original one and the adversarially trained one.
model = keras.models.load_model('model')
model_adv = keras.models.load_model('model_adv')

# Loss function used to score BOTH models, so the two printed numbers are
# directly comparable.
loss_fn = tf.keras.losses.MeanSquaredError()

# Craft adversarial examples against the original model.
# NOTE(review): the meaning of 0.4 / 0.5 is defined by craft_adv (presumably
# step size / perturbation budget). Confirm there.
x_adv, loss = craft_adv(
    X_test, Y_test, 0.4, 0.5, model, loss_fn)

# Score the adversarially trained model on the same adversarial inputs with the
# same loss_fn. model_adv.evaluate would use whatever loss/metrics model_adv was
# compiled with (and may return a list), which need not match loss_fn.
# Assumes craft_adv's returned `loss` is computed with loss_fn. TODO confirm.
# Tensors are passed so Keras aligns y_true/y_pred ranks before computing MSE.
y_true = tf.convert_to_tensor(Y_test)
y_pred = tf.convert_to_tensor(model_adv.predict(x_adv))
loss_adv = float(loss_fn(y_true, y_pred))

print(f"原始模型:{loss},对抗训练后的模型:{loss_adv}")
|
|
|
|
|