37 lines
810 B
Python
37 lines
810 B
Python
|
import matplotlib.pyplot as plt
|
||
|
import tensorflow as tf
|
||
|
import numpy as np
|
||
|
import keras
|
||
|
from data_load import data_format
|
||
|
from attack_craft import craft_adv
|
||
|
|
||
|
|
||
|
# Robustness comparison: craft adversarial examples against the baseline
# model, then evaluate the adversarially trained model on the same perturbed
# test inputs and print both losses.

# Load the dataset (train/test split as produced by data_format).
X_train, X_test, Y_train, Y_test = data_format(
    'data/archive/PowerQualityDistributionDataset1.csv')

# Seed the RNGs for reproducibility and shuffle the test split.
# Features and labels are reordered with ONE shared permutation so they are
# guaranteed to stay paired. The previous approach re-seeded NumPy and called
# np.random.shuffle on each array separately; that only keeps them aligned
# when both arrays have exactly the same length. np.random.permutation(n)
# shuffles np.arange(n) with the same legacy generator, so for NumPy inputs
# the resulting order (and the RNG state afterwards) matches the old code.
if len(X_test) != len(Y_test):
    raise ValueError(
        f"X_test and Y_test must have the same length, "
        f"got {len(X_test)} and {len(Y_test)}")
np.random.seed(7)
test_order = np.random.permutation(len(X_test))
# np.asarray makes the reordering positional even if data_format returns
# lists or pandas objects (its return types are not visible here — TODO
# confirm); for ndarray inputs it is a no-op.
X_test = np.asarray(X_test)[test_order]
Y_test = np.asarray(Y_test)[test_order]
tf.random.set_seed(7)

# Load the trained models: the baseline and the adversarially trained one.
model = keras.models.load_model('model')

model_adv = keras.models.load_model('model_adv')

# Loss function handed to craft_adv for crafting the perturbation.
loss_fn = tf.keras.losses.MeanSquaredError()

# Craft adversarial examples against the baseline model. The meaning of the
# constants 0.4 and 0.5 and of the returned `loss` is defined by craft_adv
# (presumably perturbation size/step and the baseline's loss on x_adv —
# TODO confirm against attack_craft).
x_adv, loss = craft_adv(
    X_test, Y_test, 0.4, 0.5, model, loss_fn)

# Evaluate the adversarially trained model on the same adversarial inputs.
# Keras' evaluate returns a scalar loss, or a list [loss, *metrics] when the
# model was compiled with metrics.
loss_adv = model_adv.evaluate(x_adv, Y_test)

print(f"原始模型:{loss},对抗训练后的模型:{loss_adv}")