NeuroSurgeon.Probing package
Submodules
Probe Configs
- class NeuroSurgeon.Probing.probe_configs.ResidualUpdateModelConfig(*args: Any, **kwargs: Any)
Bases: PretrainedConfig
A config object defining the behavior of the ResidualUpdateModel
- Parameters:
model_type (str) – The type of model that you are hooking. One of [“bert”, “gpt”, “gptneox”, “vit”]
target_layers (List[int]) – A list of layers to hook
mlp (bool) – Whether to hook the mlp layers
attn (bool) – Whether to hook the attention layers
updates (bool, optional) – Whether to store updates to the residual stream, defaults to True
stream (bool, optional) – Whether to store the intermediate residual stream values, defaults to False
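To make the interaction of these flags concrete, here is a minimal sketch that lists the vector_cache entries a given configuration would be expected to produce. It does not import NeuroSurgeon: `ResidConfigSketch` and `expected_cache_keys` are hypothetical stand-ins, and the key format `<attn|mlp>_<update|stream>_<layer>` is an assumption inferred from the examples `attn_update_1` and `mlp_stream_5` that appear elsewhere in these docs.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ResidConfigSketch:
    """Hypothetical stand-in that mirrors the documented parameters."""
    model_type: str
    target_layers: List[int]
    mlp: bool
    attn: bool
    updates: bool = True   # documented default
    stream: bool = False   # documented default


def expected_cache_keys(cfg: ResidConfigSketch) -> List[str]:
    # ASSUMPTION: cache keys look like "<attn|mlp>_<update|stream>_<layer>".
    components = [name for name, on in (("attn", cfg.attn), ("mlp", cfg.mlp)) if on]
    kinds = [name for name, on in (("update", cfg.updates), ("stream", cfg.stream)) if on]
    return [
        f"{component}_{kind}_{layer}"
        for layer in cfg.target_layers
        for component in components
        for kind in kinds
    ]


cfg = ResidConfigSketch(model_type="gpt", target_layers=[1, 5], mlp=True, attn=True)
keys = expected_cache_keys(cfg)
print(keys)
```

With the default flags only update entries are listed; setting `stream=True` in the sketch adds the corresponding stream entries.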
- class NeuroSurgeon.Probing.probe_configs.CircuitProbeConfig(*args: Any, **kwargs: Any)
Bases: PretrainedConfig
A config object defining the behavior of the CircuitProbe
- Parameters:
probe_vectors (str) – The entries in a ResidualUpdateModel’s vector_cache that one will train a probe on. Ex. attn_update_1, mlp_stream_5
circuit_config (CircuitConfig) – A CircuitConfig object defining the masking behavior of the model
resid_config (ResidualUpdateModelConfig) – A ResidualUpdateModelConfig defining the behavior of the residual update model. Make sure that the probe_vectors argument aligns with this config!
loss (str, optional) – Either “contrastive” or “linear”. “Contrastive” refers to the standard contrastive loss described in Lepori et al. 2024. “Linear” refers to an experimental loss function that links circuit probing and linear probing.
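Because probe_vectors must agree with resid_config, a quick sanity check before training can catch mismatches early. The helper below is a hypothetical sketch, not part of NeuroSurgeon. It assumes keys of the form `<attn|mlp>_<update|stream>_<layer>` (inferred from the examples above) and takes the residual-config fields as plain arguments.

```python
import re
from typing import List

# ASSUMPTION: key format inferred from "attn_update_1" and "mlp_stream_5".
KEY_PATTERN = re.compile(r"(attn|mlp)_(update|stream)_(\d+)")


def probe_vector_problems(
    probe_vectors: str,
    target_layers: List[int],
    mlp: bool,
    attn: bool,
    updates: bool = True,
    stream: bool = False,
) -> List[str]:
    """Return a list of mismatches between a key and the residual config."""
    match = KEY_PATTERN.fullmatch(probe_vectors)
    if match is None:
        return [f"unrecognized key format: {probe_vectors!r}"]
    component, kind, layer = match.group(1), match.group(2), int(match.group(3))
    problems = []
    if layer not in target_layers:
        problems.append(f"layer {layer} is not in target_layers")
    if not {"attn": attn, "mlp": mlp}[component]:
        problems.append(f"{component} layers are not hooked")
    if not {"update": updates, "stream": stream}[kind]:
        problems.append(f"{kind} values are not stored")
    return problems


ok = probe_vector_problems("attn_update_1", target_layers=[1, 5], mlp=True, attn=True)
bad = probe_vector_problems("mlp_stream_3", target_layers=[1, 5], mlp=True, attn=True)
print(ok, bad)
```

An empty list means the key is consistent with the configuration under the stated assumption about key names.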
- class NeuroSurgeon.Probing.probe_configs.SubnetworkProbeConfig(*args: Any, **kwargs: Any)
Bases: PretrainedConfig
A config object defining the behavior of the SubnetworkProbe
- Parameters:
probe_vectors (str) – The entries in a ResidualUpdateModel’s vector_cache that one will train a probe on. Ex. attn_update_1, mlp_stream_5
n_classes (int) – The number of classes in the probe task
circuit_config (CircuitConfig) – A CircuitConfig object defining the masking behavior of the model
resid_config (ResidualUpdateModelConfig) – A ResidualUpdateModelConfig defining the behavior of the residual update model. Make sure that the probe_vectors argument aligns with this config!
intermediate_size (int, optional) – If an MLP probe is required, the dimensionality of the hidden layer, defaults to -1
labeling (str, optional) – Either “sequence” or “token” - whether to expect one label per input or many, defaults to “sequence”
ResidualUpdateModel
- class NeuroSurgeon.Probing.residual_update_model.ResidualUpdateModel(*args: Any, **kwargs: Any)
Bases: Module
A simple wrapper that adds hooks into a model to return intermediate hidden states or updates from particular MLP or Attention layers. This is loosely modeled on the Transformer Lens library from Neel Nanda (https://github.com/neelnanda-io/TransformerLens). This is a minimal implementation that is built to support subnetwork-based pruning efforts. Currently, it supports ViT, GPT2, GPTNeoX, BERT, RoBERTa, MPNet, ConvBERT, Ernie, and Electra models. On every forward pass, the specified updates/activations are placed in a vector_cache dictionary.
- Parameters:
config (ResidualUpdateModelConfig) – A configuration object defining which updates/activations to store
model (nn.Module (should be one of the model architectures stated above)) – A transformer model to wrap
- forward(**kwargs)
- train(train_bool: bool = True)
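The difference between an update and a stream value can be illustrated with a toy residual stream in plain Python. This is only a conceptual sketch, not the library's hook mechanism: `run_with_cache` and the toy layers are hypothetical, layers are indexed from 0 by assumption, and the stream value is assumed to be the residual stream after the update has been added.

```python
from typing import Callable, Dict, List, Tuple

Vector = List[float]
Layer = Dict[str, Callable[[Vector], Vector]]


def add(a: Vector, b: Vector) -> Vector:
    return [x + y for x, y in zip(a, b)]


def run_with_cache(
    x: Vector,
    layers: List[Layer],
    target_layers: List[int],
    updates: bool = True,
    stream: bool = False,
) -> Tuple[Vector, Dict[str, Vector]]:
    vector_cache: Dict[str, Vector] = {}  # rebuilt on every forward pass
    for idx, layer in enumerate(layers):
        for name in ("attn", "mlp"):
            update = layer[name](x)  # what the sublayer writes to the stream
            x = add(x, update)       # residual connection
            if idx in target_layers:
                if updates:
                    vector_cache[f"{name}_update_{idx}"] = update
                if stream:
                    # ASSUMPTION: stream is recorded after adding the update.
                    vector_cache[f"{name}_stream_{idx}"] = x
    return x, vector_cache


# Toy sublayers: "attn" halves the input, "mlp" writes a constant.
toy_layer: Layer = {
    "attn": lambda v: [0.5 * e for e in v],
    "mlp": lambda v: [1.0 for _ in v],
}
output, cache = run_with_cache(
    [2.0, 4.0], [toy_layer, toy_layer], target_layers=[1], updates=True, stream=True
)
print(sorted(cache))
```

Only layer 1 is cached here, so the cache holds four entries: the attention and MLP updates for that layer and the stream values that result from adding them.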
SubnetworkProbe
- class NeuroSurgeon.Probing.subnetwork_probe.SubnetworkProbe(*args: Any, **kwargs: Any)
Bases: Module
This reimplements the technique introduced in Cao et al. 2021 (https://arxiv.org/abs/2104.03514). Probing introduces a linear layer or MLP to extract information from intermediate representations in a model. One can train the probe to classify inputs at the token or sequence level, using either intermediate updates or intermediate activations. Cao et al. introduced subnetwork probing, which optimizes a binary mask and a linear probe at the same time, resulting in low-complexity probes. This class implements this technique by introducing probing layers into a CircuitModel. One can also use this class to perform regular probing by specifying that no layers get masked in the CircuitModel config.
- Parameters:
config (SubnetworkProbeConfig) – A config file determining the behavior of the subnetwork probe
model (nn.Module) – The model to probe. Currently, it supports ViT, GPT2, GPTNeoX, BERT, RoBERTa, MPNet, ConvBERT, Ernie, and Electra models.
- train(train_bool: bool = True)
- forward(input_ids=None, labels=None, token_mask=None, return_dict=True, **kwargs)
Forward pass of the model
- Parameters:
input_ids (torch.Tensor, optional) – input tensors, defaults to None
labels (torch.Tensor, optional) – probing labels, defaults to None
token_mask (torch.Tensor, optional) – A mask defining which updates/residual stream entries should be mapped to labels, defaults to None
return_dict (bool, optional) – Whether to return an output dictionary or a tuple, defaults to True
- Returns:
An object that contains output predictions, loss, etc. (dictionary)
- Return type:
SequenceClassifierOutput or Tuple
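As a rough illustration of what token_mask expresses, the sketch below keeps only the per-token vectors that the mask marks and pairs them with labels. It is a plain-Python stand-in rather than the library's tensor code: `select_masked` is hypothetical, and the assumption that labels are given in the same order as the selected tokens is ours, not something these docs state.

```python
from typing import List, Sequence


def select_masked(
    vectors: Sequence[Sequence[float]], token_mask: Sequence[bool]
) -> List[List[float]]:
    """Keep the vectors at positions where token_mask is True."""
    return [list(vec) for vec, keep in zip(vectors, token_mask) if keep]


# Four token positions, each with a 2-dimensional cached vector.
token_vectors = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6], [0.7, 0.8]]
token_mask = [False, True, False, True]
labels = [0, 1]  # ASSUMPTION: one label per selected token, in order

selected = select_masked(token_vectors, token_mask)
pairs = list(zip(selected, labels))
print(pairs)
```

In this toy case two of the four positions are selected, so two labels are consumed.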