"""
Attacker Class
==============
"""

import collections
import logging
import multiprocessing as mp
import os
import queue
import random
import traceback

import torch
import tqdm

import textattack
from textattack.attack_results import (
    FailedAttackResult,
    MaximizedAttackResult,
    SkippedAttackResult,
    SuccessfulAttackResult,
)
from textattack.shared.utils import logger

from .attack import Attack
from .attack_args import AttackArgs


class Attacker:
    """Class for running attacks on a dataset with specified parameters. This
    class uses the :class:`~textattack.Attack` to actually run the attacks,
    while also providing useful features such as parallel processing,
    saving/resuming from a checkpoint, logging to files and stdout.

    Args:
        attack (:class:`~textattack.Attack`):
            :class:`~textattack.Attack` used to actually carry out the attack.
        dataset (:class:`~textattack.datasets.Dataset`):
            Dataset to attack.
        attack_args (:class:`~textattack.AttackArgs`):
            Arguments for attacking the dataset. For default settings, look at the `AttackArgs` class.

    Example::

        >>> import textattack
        >>> import transformers

        >>> model = transformers.AutoModelForSequenceClassification.from_pretrained("textattack/bert-base-uncased-imdb")
        >>> tokenizer = transformers.AutoTokenizer.from_pretrained("textattack/bert-base-uncased-imdb")
        >>> model_wrapper = textattack.models.wrappers.HuggingFaceModelWrapper(model, tokenizer)

        >>> attack = textattack.attack_recipes.TextFoolerJin2019.build(model_wrapper)
        >>> dataset = textattack.datasets.HuggingFaceDataset("imdb", split="test")

        >>> # Attack 20 samples with CSV logging and checkpoint saved every 5 interval
        >>> attack_args = textattack.AttackArgs(
        ...     num_examples=20,
        ...     log_to_csv="log.csv",
        ...     checkpoint_interval=5,
        ...     checkpoint_dir="checkpoints",
        ...     disable_stdout=True
        ... )

        >>> attacker = textattack.Attacker(attack, dataset, attack_args)
        >>> attacker.attack_dataset()
    """

    def __init__(self, attack, dataset, attack_args=None):
        assert isinstance(
            attack, Attack
        ), f"`attack` argument must be of type `textattack.Attack`, but got type of `{type(attack)}`."
        assert isinstance(
            dataset, textattack.datasets.Dataset
        ), f"`dataset` must be of type `textattack.datasets.Dataset`, but got type `{type(dataset)}`."

        if attack_args:
            assert isinstance(
                attack_args, AttackArgs
            ), f"`attack_args` must be of type `textattack.AttackArgs`, but got type `{type(attack_args)}`."
        else:
            attack_args = AttackArgs()

        self.attack = attack
        self.dataset = dataset
        self.attack_args = attack_args
        self.attack_log_manager = None

        # This is to be set if loading from a checkpoint
        self._checkpoint = None

    def _get_worklist(self, start, end, num_examples, shuffle):
        """Split dataset indices ``[start, end)`` into a worklist and spare candidates.

        Args:
            start (int): First dataset index to consider.
            end (int): One past the last dataset index to consider.
            num_examples (int): Number of indices to place in the worklist.
            shuffle (bool): If true, shuffle the indices before splitting.

        Returns:
            tuple: ``(worklist, candidates)``, both :class:`collections.deque`.
            ``worklist`` holds the first ``num_examples`` indices and
            ``candidates`` holds the rest.
        """
        if end - start < num_examples:
            logger.warning(
                f"Attempting to attack {num_examples} samples when only {end-start} are available."
            )

        candidates = list(range(start, end))
        if shuffle:
            random.shuffle(candidates)
        worklist = collections.deque(candidates[:num_examples])
        candidates = collections.deque(candidates[num_examples:])

        assert (len(worklist) + len(candidates)) == (end - start)
        return worklist, candidates

    def _init_worklist(self):
        """Return ``(num_remaining_attacks, worklist, worklist_candidates)``.

        The values are taken from the loaded checkpoint when there is one.
        Otherwise they are built from ``attack_args``, where
        ``num_successful_examples`` takes precedence over ``num_examples``
        as the target count whenever it is set.
        """
        if self._checkpoint:
            num_remaining_attacks = self._checkpoint.num_remaining_attacks
            worklist = self._checkpoint.worklist
            worklist_candidates = self._checkpoint.worklist_candidates
            logger.info(
                f"Recovered from checkpoint previously saved at {self._checkpoint.datetime}."
            )
            return num_remaining_attacks, worklist, worklist_candidates

        num_remaining_attacks = (
            self.attack_args.num_successful_examples or self.attack_args.num_examples
        )
        # We make `worklist` deque (linked-list) for easy pop and append.
        # Candidates are other samples we can attack if we need more samples.
        worklist, worklist_candidates = self._get_worklist(
            self.attack_args.num_examples_offset,
            len(self.dataset),
            num_remaining_attacks,
            self.attack_args.shuffle,
        )
        return num_remaining_attacks, worklist, worklist_candidates

    def _init_counts(self):
        """Return ``(num_results, num_failures, num_skipped, num_successes)``.

        Counters continue from the checkpoint when resuming, else start at zero.
        """
        if self._checkpoint:
            return (
                self._checkpoint.results_count,
                self._checkpoint.num_failed_attacks,
                self._checkpoint.num_skipped_attacks,
                self._checkpoint.num_successful_attacks,
            )
        return 0, 0, 0, 0

    def _wrap_example(self, text):
        """Build an ``AttackedText`` from ``text`` and attach the dataset's label names, if any."""
        example = textattack.shared.AttackedText(text)
        if self.dataset.label_names is not None:
            example.attack_attrs["label_names"] = self.dataset.label_names
        return example

    def _is_rejected(self, result):
        """Return a truthy value if ``result`` does not count toward the requested quota.

        That is the case for a skipped result when ``attack_n`` is set, and
        for any non-successful result when ``num_successful_examples`` is set.
        """
        return (
            isinstance(result, SkippedAttackResult) and self.attack_args.attack_n
        ) or (
            not isinstance(result, SuccessfulAttackResult)
            and self.attack_args.num_successful_examples
        )

    def _save_checkpoint_if_due(self, worklist, worklist_candidates):
        """Save a checkpoint when the number of logged results hits the checkpoint interval."""
        interval = self.attack_args.checkpoint_interval
        if interval and len(self.attack_log_manager.results) % interval == 0:
            new_checkpoint = textattack.shared.AttackCheckpoint(
                self.attack_args,
                self.attack_log_manager,
                worklist,
                worklist_candidates,
            )
            new_checkpoint.save()
            self.attack_log_manager.flush()

    def _log_summary(self):
        """Emit the end-of-run summary through the log manager."""
        print()
        # Enable summary stdout
        if not self.attack_args.silent and self.attack_args.disable_stdout:
            self.attack_log_manager.enable_stdout()

        if self.attack_args.enable_advance_metrics:
            self.attack_log_manager.enable_advance_metrics = True

        self.attack_log_manager.log_summary()
        self.attack_log_manager.flush()
        print()

    def simple_attack(self, text, label):
        """Attack a single ``(text, label)`` pair. No parallel processing is
        involved.

        Args:
            text: Input passed to ``textattack.shared.AttackedText``.
            label: Ground-truth output passed to the attack.

        Returns:
            The attack result, or ``None`` when the result is rejected under
            the current ``attack_args`` (skipped while ``attack_n`` is set, or
            not successful while ``num_successful_examples`` is set).
        """
        if torch.cuda.is_available():
            self.attack.cuda_()

        example = self._wrap_example(text)
        result = self.attack.attack(example, label)
        if self._is_rejected(result):
            return None
        return result

    def _attack(self):
        """Internal method that carries out attack.

        No parallel processing is involved.
        """
        if torch.cuda.is_available():
            self.attack.cuda_()

        num_remaining_attacks, worklist, worklist_candidates = self._init_worklist()

        if not self.attack_args.silent:
            print(self.attack, "\n")

        pbar = tqdm.tqdm(total=num_remaining_attacks, smoothing=0, dynamic_ncols=True)
        num_results, num_failures, num_skipped, num_successes = self._init_counts()

        sample_exhaustion_warned = False
        while worklist:
            idx = worklist.popleft()
            try:
                text, ground_truth_output = self.dataset[idx]
            except IndexError:
                # Indices beyond the dataset are skipped rather than aborting the run.
                continue
            example = self._wrap_example(text)
            result = self.attack.attack(example, ground_truth_output)

            if self._is_rejected(result):
                # This sample does not count; pull in a replacement if one is left.
                if worklist_candidates:
                    worklist.append(worklist_candidates.popleft())
                elif not sample_exhaustion_warned:
                    logger.warning("Ran out of samples to attack!")
                    sample_exhaustion_warned = True
            else:
                pbar.update(1)

            self.attack_log_manager.log_result(result)
            if not self.attack_args.disable_stdout and not self.attack_args.silent:
                print("\n")
            num_results += 1

            if isinstance(result, SkippedAttackResult):
                num_skipped += 1
            if isinstance(result, (SuccessfulAttackResult, MaximizedAttackResult)):
                num_successes += 1
            if isinstance(result, FailedAttackResult):
                num_failures += 1
            pbar.set_description(
                f"[Succeeded / Failed / Skipped / Total] {num_successes} / {num_failures} / {num_skipped} / {num_results}"
            )

            self._save_checkpoint_if_due(worklist, worklist_candidates)

        pbar.close()
        self._log_summary()

    def _attack_parallel(self):
        """Internal method that carries out the attack using a pool of worker processes.

        Samples are sent to the workers through ``in_queue`` and results are
        read back from ``out_queue``. If a worker reports an exception, it is
        logged, the pool is terminated and the method returns early without
        logging a summary.
        """
        pytorch_multiprocessing_workaround()

        num_remaining_attacks, worklist, worklist_candidates = self._init_worklist()

        in_queue = torch.multiprocessing.Queue()
        out_queue = torch.multiprocessing.Queue()
        for i in worklist:
            try:
                text, ground_truth_output = self.dataset[i]
                example = self._wrap_example(text)
                in_queue.put((i, example, ground_truth_output))
            except IndexError:
                raise IndexError(
                    f"Tried to access element at {i} in dataset of size {len(self.dataset)}."
                )

        # Workers pick a GPU round-robin from their process identity
        # (see `attack_from_queue`), so all GPUs are used.
        num_gpus = torch.cuda.device_count()
        num_workers = self.attack_args.num_workers_per_device * num_gpus
        logger.info(f"Running {num_workers} worker(s) on {num_gpus} GPU(s).")

        # Lock for synchronization
        lock = mp.Lock()

        # We move Attacker (and its components) to CPU b/c we don't want models using wrong GPU in worker processes.
        self.attack.cpu_()
        torch.cuda.empty_cache()

        # Start workers. `attack_from_queue` is passed as the pool initializer,
        # so each worker runs its consume loop as soon as it starts.
        worker_pool = torch.multiprocessing.Pool(
            num_workers,
            attack_from_queue,
            (
                self.attack,
                self.attack_args,
                num_gpus,
                mp.Value("i", 1, lock=False),
                lock,
                in_queue,
                out_queue,
            ),
        )

        # Log results asynchronously and update progress bar.
        num_results, num_failures, num_skipped, num_successes = self._init_counts()

        logger.info(f"Worklist size: {len(worklist)}")
        logger.info(f"Worklist candidate size: {len(worklist_candidates)}")

        sample_exhaustion_warned = False
        pbar = tqdm.tqdm(total=num_remaining_attacks, smoothing=0, dynamic_ncols=True)
        while worklist:
            idx, result = out_queue.get(block=True)
            worklist.remove(idx)

            if isinstance(result, tuple) and isinstance(result[0], Exception):
                # A worker failed: report the failure and shut everything down.
                logger.error(
                    f'Exception encountered for input "{self.dataset[idx][0]}".'
                )
                error_trace = result[1]
                logger.error(error_trace)
                in_queue.close()
                in_queue.join_thread()
                out_queue.close()
                out_queue.join_thread()
                worker_pool.terminate()
                worker_pool.join()
                pbar.close()
                return
            elif self._is_rejected(result):
                # This sample does not count; hand a replacement to the workers if one is left.
                if worklist_candidates:
                    next_sample = worklist_candidates.popleft()
                    text, ground_truth_output = self.dataset[next_sample]
                    example = self._wrap_example(text)
                    worklist.append(next_sample)
                    in_queue.put((next_sample, example, ground_truth_output))
                elif not sample_exhaustion_warned:
                    logger.warning("Ran out of samples to attack!")
                    sample_exhaustion_warned = True
            else:
                pbar.update()

            self.attack_log_manager.log_result(result)
            num_results += 1

            if isinstance(result, SkippedAttackResult):
                num_skipped += 1
            if isinstance(result, (SuccessfulAttackResult, MaximizedAttackResult)):
                num_successes += 1
            if isinstance(result, FailedAttackResult):
                num_failures += 1
            pbar.set_description(
                f"[Succeeded / Failed / Skipped / Total] {num_successes} / {num_failures} / {num_skipped} / {num_results}"
            )

            self._save_checkpoint_if_due(worklist, worklist_candidates)

        # Send sentinel values to worker processes
        for _ in range(num_workers):
            in_queue.put(("END", "END", "END"))
        worker_pool.close()
        worker_pool.join()

        pbar.close()
        self._log_summary()

    def attack_dataset(self):
        """Attack the dataset.

        Returns:
            :obj:`list[AttackResult]` - List of :class:`~textattack.attack_results.AttackResult` obtained after attacking the given dataset.

        Raises:
            ValueError: If the dataset is internally shuffled and a checkpoint interval is set.
            Exception: If parallel mode is requested but no GPU is found.
        """
        silenced = bool(self.attack_args.silent)
        previous_level = logger.level
        if silenced:
            logger.setLevel(logging.ERROR)

        try:
            if self.attack_args.query_budget:
                self.attack.goal_function.query_budget = self.attack_args.query_budget

            if not self.attack_log_manager:
                self.attack_log_manager = AttackArgs.create_loggers_from_args(
                    self.attack_args
                )

            textattack.shared.utils.set_seed(self.attack_args.random_seed)
            if self.dataset.shuffled and self.attack_args.checkpoint_interval:
                # Not allowed b/c we cannot recover order of shuffled data
                raise ValueError(
                    "Cannot use `--checkpoint-interval` with dataset that has been internally shuffled."
                )

            # `-1` means "attack the whole dataset".
            self.attack_args.num_examples = (
                len(self.dataset)
                if self.attack_args.num_examples == -1
                else self.attack_args.num_examples
            )
            if self.attack_args.parallel:
                if torch.cuda.device_count() == 0:
                    raise Exception(
                        "Found no GPU on your system. To run attacks in parallel, GPU is required."
                    )
                self._attack_parallel()
            else:
                self._attack()
        finally:
            # Restore the logger level even when the attack raises.
            if silenced:
                logger.setLevel(previous_level)

        return self.attack_log_manager.results

    def update_attack_args(self, **kwargs):
        """To update any attack args, pass the new argument as keyword argument
        to this function.

        Raises:
            ValueError: If a keyword does not name an existing field of the attack args.

        Examples::

        >>> attacker = #some instance of Attacker
        >>> # To switch to parallel mode and increase checkpoint interval from 100 to 500
        >>> attacker.update_attack_args(parallel=True, checkpoint_interval=500)
        """
        for k, value in kwargs.items():
            if hasattr(self.attack_args, k):
                # `setattr` is required: `self.attack_args.k = ...` would set
                # an attribute literally named "k".
                setattr(self.attack_args, k, value)
            else:
                raise ValueError(f"`textattack.AttackArgs` does not have field {k}.")

    @classmethod
    def from_checkpoint(cls, attack, dataset, checkpoint):
        """Resume attacking from a saved checkpoint. Attacker and dataset must
        be recovered by the user again, while attack args are loaded from the
        saved checkpoint.

        Args:
            attack (:class:`~textattack.Attack`):
                Attack object for carrying out the attack.
            dataset (:class:`~textattack.datasets.Dataset`):
                Dataset to attack.
            checkpoint (:obj:`Union[str, :class:`~textattack.shared.AttackCheckpoint`]`):
                Path of saved checkpoint or the actual saved checkpoint.

        Returns:
            :class:`Attacker`: Attacker whose log manager and worklist state come from the checkpoint.
        """
        assert isinstance(
            checkpoint, (str, textattack.shared.AttackCheckpoint)
        ), f"`checkpoint` must be of type `str` or `textattack.shared.AttackCheckpoint`, but got type `{type(checkpoint)}`."

        if isinstance(checkpoint, str):
            checkpoint = textattack.shared.AttackCheckpoint.load(checkpoint)
        attacker = cls(attack, dataset, checkpoint.attack_args)
        attacker.attack_log_manager = checkpoint.attack_log_manager
        attacker._checkpoint = checkpoint
        return attacker

    @staticmethod
    def attack_interactive(attack):
        """Read sentences from stdin and attack each one until the user enters ``q``.

        The model's own output for the sentence is used as the ground truth
        passed to the attack. Empty input lines are ignored.

        Args:
            attack (:class:`~textattack.Attack`): Attack to run on each sentence.
        """
        print(attack, "\n")

        print("Running in interactive mode")
        print("----------------------------")

        while True:
            print('Enter a sentence to attack or "q" to quit:')
            text = input()

            if text == "q":
                break

            if not text:
                continue

            print("Attacking...")

            example = textattack.shared.attacked_text.AttackedText(text)
            output = attack.goal_function.get_output(example)
            result = attack.attack(example, output)
            print(result.__str__(color_method="ansi") + "\n")


#
# Helper Methods for multiprocess attacks
#
def pytorch_multiprocessing_workaround():
    """Switch torch multiprocessing to the ``spawn`` start method and ``file_system`` sharing.

    A ``RuntimeError`` raised by either call is ignored.
    """
    # This is a fix for a known bug
    try:
        torch.multiprocessing.set_start_method("spawn", force=True)
        torch.multiprocessing.set_sharing_strategy("file_system")
    except RuntimeError:
        pass


def set_env_variables(gpu_id):
    """Configure the current process to use the GPU with index ``gpu_id``.

    Args:
        gpu_id (int): Index of the GPU this process should use.
    """
    # Disable tensorflow logs, except in the case of an error.
    if "TF_CPP_MIN_LOG_LEVEL" not in os.environ:
        os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

    # Set sharing strategy to file_system to avoid file descriptor leaks
    torch.multiprocessing.set_sharing_strategy("file_system")

    # Only use one GPU, if we have one.
    # For Tensorflow
    # TODO: Using USE with `--parallel` raises similar issue as https://github.com/tensorflow/tensorflow/issues/38518#
    os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
    # For PyTorch
    torch.cuda.set_device(gpu_id)

    # Fix TensorFlow GPU memory growth
    try:
        import tensorflow as tf

        gpus = tf.config.experimental.list_physical_devices("GPU")
        if gpus:
            try:
                # Currently, memory growth needs to be the same across GPUs
                gpu = gpus[gpu_id]
                tf.config.experimental.set_visible_devices(gpu, "GPU")
                tf.config.experimental.set_memory_growth(gpu, True)
            except RuntimeError as e:
                print(e)
    except ModuleNotFoundError:
        # TensorFlow is optional; nothing to configure without it.
        pass


def attack_from_queue(
    attack, attack_args, num_gpus, first_to_start, lock, in_queue, out_queue
):
    """Worker loop: take samples from ``in_queue``, attack them, and put results on ``out_queue``.

    ``Attacker._attack_parallel`` passes this function as the pool
    initializer. The loop ends when the ``("END", "END", "END")`` sentinel is
    received.

    Args:
        attack (:class:`~textattack.Attack`): Attack to run in this worker.
        attack_args (:class:`~textattack.AttackArgs`): Provides ``random_seed`` and ``silent``.
        num_gpus (int): Number of GPUs; used to pick this worker's GPU.
        first_to_start: Shared integer flag, non-zero until one worker has printed the attack.
        lock: Lock guarding updates to ``first_to_start``.
        in_queue: Queue of ``(index, example, ground_truth_output)`` items.
        out_queue: Queue receiving ``(index, result)``, or
            ``(index, (exception, traceback_str))`` on failure.
    """
    assert isinstance(
        attack, Attack
    ), f"`attack` must be of type `Attack`, but got type `{type(attack)}`."

    # Worker identities start at 1; map them round-robin onto the GPUs.
    gpu_id = (torch.multiprocessing.current_process()._identity[0] - 1) % num_gpus
    set_env_variables(gpu_id)
    textattack.shared.utils.set_seed(attack_args.random_seed)
    # Only the first worker keeps logging enabled.
    if torch.multiprocessing.current_process()._identity[0] > 1:
        logging.disable()

    attack.cuda_()

    # Simple non-synchronized check to see if it's the first process to reach this point.
    # This let us avoid waiting for lock.
    if bool(first_to_start.value):
        # If it's first process to reach this step, we first try to acquire the lock to update the value.
        with lock:
            # Because another process could have changed `first_to_start=False` while we wait, we check again.
            if bool(first_to_start.value):
                first_to_start.value = 0
                if not attack_args.silent:
                    print(attack, "\n")

    while True:
        try:
            i, example, ground_truth_output = in_queue.get(timeout=5)
            if i == "END" and example == "END" and ground_truth_output == "END":
                # End process when sentinel value is received
                break
            else:
                result = attack.attack(example, ground_truth_output)
                out_queue.put((i, result))
        except Exception as e:
            if isinstance(e, queue.Empty):
                # Nothing to do yet; keep polling.
                continue
            else:
                # NOTE(review): if `in_queue.get` itself fails with something
                # other than `queue.Empty`, `i` is unbound (first iteration)
                # or refers to the previous sample — confirm whether that
                # case needs separate handling.
                out_queue.put((i, (e, traceback.format_exc())))