NeuroSurgeon.Probing package

Submodules

Probe Configs

class NeuroSurgeon.Probing.probe_configs.ResidualUpdateModelConfig(*args: Any, **kwargs: Any)

Bases: PretrainedConfig

A config object defining the behavior of the ResidualUpdateModel

Parameters:
  • model_type (str) – The type of model that you are hooking. One of [“bert”, “gpt”, “gptneox”, “vit”]

  • target_layers (List[int]) – A list of layers to hook

  • mlp (bool) – Whether to hook the mlp layers

  • attn (bool) – Whether to hook the attention layers

  • updates (bool, optional) – Whether to store updates to the residual stream, defaults to True

  • stream (bool, optional) – Whether to store the intermediate residual stream values, defaults to False
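The probe configs below refer to entries of the vector_cache by name (for example attn_update_1 or mlp_stream_5). As a rough sketch, the helper below lists the keys one might expect a given configuration to produce. The function expected_cache_keys is hypothetical and not part of NeuroSurgeon, and the "<component>_<kind>_<layer>" naming pattern is an assumption inferred only from the two example names in this documentation.

```python
from typing import List


def expected_cache_keys(
    target_layers: List[int],
    mlp: bool,
    attn: bool,
    updates: bool = True,
    stream: bool = False,
) -> List[str]:
    """List the vector_cache keys one would expect for a given config.

    Hypothetical helper. Assumes keys look like "<component>_<kind>_<layer>",
    a pattern inferred from the documented examples "attn_update_1" and
    "mlp_stream_5"; the library itself may name entries differently.
    """
    # Which sublayers are hooked, and which quantities are stored for each.
    components = [name for name, on in (("attn", attn), ("mlp", mlp)) if on]
    kinds = [name for name, on in (("update", updates), ("stream", stream)) if on]
    return [
        f"{component}_{kind}_{layer}"
        for layer in target_layers
        for component in components
        for kind in kinds
    ]


keys = expected_cache_keys(target_layers=[1, 5], mlp=True, attn=True, stream=True)
print(keys)
```

A list like this can be used to sanity-check a probe_vectors value before training, since a name that is not in the list would not be stored under the assumed naming scheme.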

class NeuroSurgeon.Probing.probe_configs.CircuitProbeConfig(*args: Any, **kwargs: Any)

Bases: PretrainedConfig

A config object defining the behavior of the CircuitProbe

Parameters:
  • probe_vectors (str) – The entries in a ResidualUpdateModel’s vector_cache that one will train a probe on. Ex. attn_update_1, mlp_stream_5

  • circuit_config (CircuitConfig) – A CircuitConfig object defining the masking behavior of the model

  • resid_config (ResidualUpdateModelConfig) – A ResidualUpdateModelConfig defining the behavior of the residual update model. Make sure that the probe_vectors argument aligns with this config!

  • loss (str, optional) – Either “contrastive” or “linear”. “Contrastive” refers to the standard contrastive loss as described in Lepori et al. 2024. “Linear” refers to an experimental loss function that links circuit probing and linear probing.

class NeuroSurgeon.Probing.probe_configs.SubnetworkProbeConfig(*args: Any, **kwargs: Any)

Bases: PretrainedConfig

A config object defining the behavior of the SubnetworkProbe

Parameters:
  • probe_vectors (str) – The entries in a ResidualUpdateModel’s vector_cache that one will train a probe on. Ex. attn_update_1, mlp_stream_5

  • n_classes (int) – The number of classes in the probe task

  • circuit_config (CircuitConfig) – A CircuitConfig object defining the masking behavior of the model

  • resid_config (ResidualUpdateModelConfig) – A ResidualUpdateModelConfig defining the behavior of the residual update model. Make sure that the probe_vectors argument aligns with this config!

  • intermediate_size (int, optional) – If an MLP probe is required, the dimensionality of the hidden layer, defaults to -1

  • labeling (str, optional) – Either “sequence” or “token” - whether to expect one label per input or many, defaults to “sequence”

ResidualUpdateModel

class NeuroSurgeon.Probing.residual_update_model.ResidualUpdateModel(*args: Any, **kwargs: Any)

Bases: Module

A simple wrapper that adds hooks into a model to return intermediate hidden states or updates from particular MLP or attention layers. This is loosely modeled on the TransformerLens library from Neel Nanda (https://github.com/neelnanda-io/TransformerLens). It is a minimal implementation built to support subnetwork-based pruning efforts. Currently, it supports ViT, GPT2, GPTNeoX, BERT, RoBERTa, MPNet, ConvBERT, Ernie, and Electra models. On every forward pass, the specified updates/activations are placed in a vector_cache dictionary.

Parameters:
  • config (ResidualUpdateModelConfig) – A configuration object defining which updates/activations to store

  • model (nn.Module) – A transformer model to wrap. It should be one of the model architectures stated above

forward(**kwargs)
train(train_bool: bool = True)
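To illustrate the difference between an update and a stream value, here is a toy, dependency-free sketch of one transformer block. It is not the library's implementation: the run_block function is hypothetical, layer normalization is omitted, the cache key names follow the pattern suggested by the documented examples, and storing the stream value after each sublayer's write is an assumption.

```python
from typing import Callable, Dict, List

Vector = List[float]


def add(a: Vector, b: Vector) -> Vector:
    """Elementwise sum of two equal-length vectors."""
    return [x + y for x, y in zip(a, b)]


def run_block(
    resid: Vector,
    layer: int,
    attn: Callable[[Vector], Vector],
    mlp: Callable[[Vector], Vector],
    cache: Dict[str, Vector],
) -> Vector:
    """Run one toy block, caching each sublayer's update and the stream.

    Hypothetical sketch: key names and the choice to record the stream
    after the write are assumptions, not confirmed library behavior.
    """
    for name, sublayer in (("attn", attn), ("mlp", mlp)):
        update = sublayer(resid)        # what the sublayer writes
        resid = add(resid, update)      # residual stream after the write
        cache[f"{name}_update_{layer}"] = update
        cache[f"{name}_stream_{layer}"] = resid
    return resid


vector_cache: Dict[str, Vector] = {}
out = run_block(
    resid=[1.0, 2.0],
    layer=0,
    attn=lambda h: [0.5 * x for x in h],   # stand-in for attention
    mlp=lambda h: [-1.0 for _ in h],       # stand-in for the MLP
    cache=vector_cache,
)
print(vector_cache)
```

In this toy setting the update is the sublayer's contribution alone, while the stream value is the running sum that the next sublayer reads.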

SubnetworkProbe

class NeuroSurgeon.Probing.subnetwork_probe.SubnetworkProbe(*args: Any, **kwargs: Any)

Bases: Module

This reimplements the technique introduced in Cao et al. 2021 (https://arxiv.org/abs/2104.03514). Probing introduces a linear layer or MLP to extract information from intermediate representations in a model. One can train the probe to classify inputs at the token or sequence level, using either intermediate updates or intermediate activations. Cao et al. introduced subnetwork probing, which optimizes a binary mask and a linear probe at the same time, resulting in low-complexity probes. This class implements this technique by introducing probing layers into a CircuitModel. One can also use this class to perform regular probing by specifying that no layers get masked in the CircuitModel config.

Parameters:
  • config (SubnetworkProbeConfig) – A config object determining the behavior of the subnetwork probe

  • model (nn.Module) – The model to probe. Currently, it supports ViT, GPT2, GPTNeoX, BERT, RoBERTa, MPNet, ConvBERT, Ernie, and Electra models.
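The idea of combining a binary weight mask with a linear probe can be sketched in plain Python as follows. This is only a forward-pass illustration under simplifying assumptions: a single toy weight matrix stands in for the model, the mask is fixed by hand rather than learned, and all names (masked, probe_logits, and the example values) are hypothetical. The procedure that optimizes the mask is out of scope here.

```python
from typing import List

Matrix = List[List[float]]
Vector = List[float]


def matvec(m: Matrix, v: Vector) -> Vector:
    """Matrix-vector product for small dense matrices."""
    return [sum(w * x for w, x in zip(row, v)) for row in m]


def masked(weights: Matrix, mask: Matrix) -> Matrix:
    """Elementwise product of a weight matrix with a binary mask."""
    return [
        [w * b for w, b in zip(w_row, m_row)]
        for w_row, m_row in zip(weights, mask)
    ]


def probe_logits(x: Vector, weights: Matrix, mask: Matrix, probe: Matrix) -> Vector:
    """Apply the masked toy layer, then a linear probe on its output."""
    hidden = matvec(masked(weights, mask), x)
    return matvec(probe, hidden)


x = [1.0, 2.0]
weights = [[1.0, -1.0], [0.5, 2.0]]            # toy "model" layer
probe = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # linear probe, 3 classes

full_mask = [[1.0, 1.0], [1.0, 1.0]]     # nothing masked: regular probing
sparse_mask = [[1.0, 0.0], [0.0, 1.0]]   # a hand-picked subnetwork

regular = probe_logits(x, weights, full_mask, probe)
subnetwork = probe_logits(x, weights, sparse_mask, probe)
print(regular, subnetwork)
```

With an all-ones mask the computation reduces to a linear probe on the unmodified layer, which mirrors the note above that regular probing corresponds to masking no layers.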

train(train_bool: bool = True)
forward(input_ids=None, labels=None, token_mask=None, return_dict=True, **kwargs)

Forward pass of the model

Parameters:
  • input_ids (torch.Tensor, optional) – input tensors, defaults to None

  • labels (torch.Tensor, optional) – probing labels, defaults to None

  • token_mask (torch.Tensor, optional) – A mask defining which updates/residual stream entries should be mapped to labels, defaults to None

  • return_dict (bool, optional) – Whether to return an output dictionary or a tuple, defaults to True

Returns:

An object that contains output predictions, loss, etc. (dictionary)

Return type:

SequenceClassifierOutput or Tuple
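As a rough illustration of what a token_mask does under “token” labeling, the sketch below selects the per-token vectors marked by a boolean mask and pairs them with labels. The function select_labeled_vectors is hypothetical, and the convention that labels are flattened across the batch in reading order is an assumption; the actual forward method operates on torch.Tensor inputs and may differ.

```python
from typing import List, Sequence, Tuple

Vector = List[float]


def select_labeled_vectors(
    vectors: Sequence[Sequence[Vector]],
    token_mask: Sequence[Sequence[bool]],
    labels: Sequence[int],
) -> List[Tuple[Vector, int]]:
    """Pair each masked-in token vector with its label.

    Hypothetical sketch: assumes one label per True entry in token_mask,
    ordered by sequence and then by token position.
    """
    selected = [
        vec
        for seq_vecs, seq_mask in zip(vectors, token_mask)
        for vec, keep in zip(seq_vecs, seq_mask)
        if keep
    ]
    if len(selected) != len(labels):
        raise ValueError(
            f"token_mask selects {len(selected)} vectors but got {len(labels)} labels"
        )
    return list(zip(selected, labels))


# Batch of 2 sequences, 3 tokens each, 1-dimensional vectors.
vectors = [[[0.1], [0.2], [0.3]], [[0.4], [0.5], [0.6]]]
token_mask = [[False, True, False], [True, False, True]]
labels = [2, 0, 1]

pairs = select_labeled_vectors(vectors, token_mask, labels)
print(pairs)
```

The length check makes a mismatch between the mask and the labels fail early instead of silently misaligning them.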

CircuitProbe

Module contents