# Liu/attack/attack_craft.py
#
# 72 lines
# 2.7 KiB
# Python

import tensorflow as tf
def craft_adv(X, Y, gamma, learning_rate, model, loss_fn, md=0):
    """Craft perturbed samples with a sparse, single-step gradient update.

    For every sample in ``X`` the gradient of ``loss_fn`` with respect to the
    input is computed. Only the ``gamma`` fraction of gradient entries with
    the largest *signed* values (not magnitudes) is kept; all other entries
    are zeroed. The sample is then moved by ``learning_rate`` times that
    sparse gradient, and the perturbed set is scored with ``model.evaluate``.

    Args:
        X: Input samples. Converted to a ``tf.float64`` tensor and processed
            one slice at a time along axis 0.
        Y: Targets aligned with ``X`` along axis 0. Converted to ``tf.int32``
            when ``md == 0`` and to ``tf.float64`` when ``md == 1``.
        gamma: Fraction of gradient entries kept per sample.
        learning_rate: Step size applied to the sparse gradient.
        model: Object callable on a one-sample batch that also provides
            ``evaluate(x, y)`` returning exactly two values (presumably a
            Keras model compiled with a single metric -- confirm with caller).
        loss_fn: Callable invoked as ``loss_fn(y_true, y_pred)``.
        md: Mode selector. ``0`` returns the second ``evaluate`` value as
            accuracy; ``1`` returns both the loss and the second value
            (named ``mape`` here).

    Returns:
        ``(X_updated, accuracy)`` when ``md == 0`` or
        ``(X_updated, loss, mape)`` when ``md == 1``, where ``X_updated`` is
        the tensor of perturbed samples concatenated along axis 0.

    Raises:
        ValueError: If ``md`` is neither 0 nor 1.
    """
    # Fail fast: an unsupported mode previously left the label tensor
    # undefined and only blew up inside the per-sample loop.
    if md not in (0, 1):
        raise ValueError(f"md must be 0 or 1, got {md!r}")

    # Convert the data to TensorFlow tensors.
    X_test_tensor = tf.convert_to_tensor(X, dtype=tf.float64)
    label_dtype = tf.int32 if md == 0 else tf.float64
    Y_test_tensor = tf.convert_to_tensor(Y, dtype=label_dtype)

    # Perturbed samples, collected one per input row.
    X_train_updated = []
    for i in range(X_test_tensor.shape[0]):
        # Keep the batch dimension so the model sees a one-sample batch.
        current_sample = X_test_tensor[i:i + 1]
        current_label = Y_test_tensor[i:i + 1]

        with tf.GradientTape() as tape:
            # The sample is a plain tensor, so it must be watched explicitly.
            tape.watch(current_sample)
            predictions = model(current_sample)
            loss = loss_fn(current_label, predictions)

        # Gradient of the loss with respect to the input sample.
        gradients = tape.gradient(loss, current_sample)
        flattened_gradients = tf.reshape(gradients, [-1])

        # Number of entries to keep: gamma * (number of input entries).
        num_gradients_to_select = int(
            gamma * tf.size(flattened_gradients, out_type=tf.dtypes.float32)
        )
        top_gradients_indices = tf.argsort(
            flattened_gradients, direction='DESCENDING'
        )[:num_gradients_to_select]

        # Boolean mask that is True only at the selected positions.
        keep_mask = tf.tensor_scatter_nd_update(
            tf.zeros_like(flattened_gradients, dtype=bool),
            tf.expand_dims(top_gradients_indices, 1),
            tf.ones_like(top_gradients_indices, dtype=bool),
        )

        # Zero every gradient entry that was not selected, then restore the
        # original sample shape.
        sparse_gradients = tf.where(
            keep_mask, flattened_gradients, tf.zeros_like(flattened_gradients)
        )
        sparse_gradients = tf.reshape(sparse_gradients, tf.shape(gradients))

        # Step the sample along the sparse gradient.
        current_sample_updated = current_sample + learning_rate * sparse_gradients
        X_train_updated.append(current_sample_updated.numpy())

    # Stitch the per-sample results back into one tensor.
    X_train_updated = tf.concat(X_train_updated, axis=0)

    # Evaluate the model on the perturbed data.
    if md == 1:
        loss, mape = model.evaluate(X_train_updated, Y)
        print(f"Accuracy gamma: {gamma},learning:{learning_rate}", loss)
        return X_train_updated, loss, mape

    loss, accuracy = model.evaluate(X_train_updated, Y)
    print(f"Accuracy gamma: {gamma},learning:{learning_rate},accuracy{accuracy}" )
    return X_train_updated, accuracy