NeuroSurgeon.Models package

Submodules

CircuitModel

class NeuroSurgeon.Models.circuit_model.CircuitModel(*args: Any, **kwargs: Any)

Bases: Module

CircuitModel is a wrapper around a Hugging Face Transformers model (or a custom nn.Module whose forward pass also returns an output object) that replaces selected layers with MaskLayers. These MaskLayers modify the weight matrices of those layers according to a masking strategy defined in the CircuitConfig.

Parameters:
  • config (CircuitConfig) – A configuration object defining how the CircuitModel wrapper modifies a model

  • model (nn.Module) – The model to modify
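To make the wrapping idea concrete, here is a minimal pure-Python sketch of what a MaskLayer does conceptually: the frozen weight matrix is multiplied element-wise by a binary mask before the layer is applied. It is not NeuroSurgeon's implementation; apply_mask and linear are hypothetical helpers, and real MaskLayers operate on PyTorch tensors and derive the mask from learned parameters.

```python
# Conceptual sketch only (not NeuroSurgeon code): a MaskLayer multiplies the
# frozen weight matrix element-wise by a binary mask before applying the layer.
from typing import List

Matrix = List[List[float]]


def apply_mask(weight: Matrix, mask: List[List[int]]) -> Matrix:
    """Hypothetical helper: return weight * mask element-wise (a new matrix)."""
    return [
        [w * m for w, m in zip(w_row, m_row)]
        for w_row, m_row in zip(weight, mask)
    ]


def linear(x: List[float], weight: Matrix) -> List[float]:
    """Hypothetical helper: bias-free linear layer y = W x (one row per output)."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weight]


weight = [[0.5, -1.0, 2.0],
          [1.5, 0.25, -0.75]]
mask = [[1, 0, 1],   # 0 = weight removed from the subnetwork
        [0, 1, 1]]
x = [1.0, 2.0, 3.0]

print(linear(x, weight))                    # [4.5, -0.25]: full layer
print(linear(x, apply_mask(weight, mask)))  # [6.5, -1.75]: masked subnetwork
```

The original weight matrix is never modified; only the effective weights seen by the forward pass change, which is what allows the base model to stay frozen.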

train(train_bool=True)

Similar to the standard nn.Module train function, except that it keeps the underlying model weights frozen if the config specifies this (see freeze_base in CircuitConfig)

Parameters:

train_bool (bool) – Whether to put the model in train mode or eval mode, defaults to True

compute_l0_statistics()

Compute the overall L0, the maximum possible L0 (the total number of masking parameters), and per-layer L0 statistics

Returns:

A dictionary containing the computed statistics. Keys are “total_l0”, “max_l0”, “layer2l0”, “layer2maxl0”

Return type:

dict
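A minimal sketch of how the four returned statistics could relate to each other, assuming each layer's mask is already binarized and flattened into a list, and that max_l0 counts every mask entry. The function below is a hypothetical stand-in rather than the library's method, and the layer names are invented.

```python
# Hypothetical sketch (not the library's code): derive L0 statistics from
# binarized, flattened masks keyed by layer name.
from typing import Dict, List


def compute_l0_statistics(masks: Dict[str, List[int]]) -> dict:
    # L0 of a mask = number of nonzero (kept) entries
    layer2l0 = {name: sum(1 for m in mask if m != 0) for name, mask in masks.items()}
    # Assumed: the maximum L0 of a layer is its total number of mask entries
    layer2maxl0 = {name: len(mask) for name, mask in masks.items()}
    return {
        "total_l0": sum(layer2l0.values()),
        "max_l0": sum(layer2maxl0.values()),
        "layer2l0": layer2l0,
        "layer2maxl0": layer2maxl0,
    }


masks = {
    "encoder.layer.0.query": [1, 0, 0, 1, 1, 0],  # invented layer names
    "encoder.layer.0.key": [0, 0, 1, 0],
}
stats = compute_l0_statistics(masks)
print(stats["total_l0"], stats["max_l0"])   # 4 10
print(stats["total_l0"] / stats["max_l0"])  # 0.4: fraction of weights kept
```

Dividing total_l0 by max_l0 gives the fraction of maskable weights that survive, a convenient single number for tracking sparsity during training.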

property temperature

The temperature parameter; used by ContSparseLayers only

Returns:

Temperature parameter

Return type:

float
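The following sketch assumes the continuous sparsification formulation, in which a soft mask is computed as sigmoid(temperature × s) for each mask parameter s. It shows why the temperature matters: as it grows, the soft mask approaches the hard step mask (1 where s > 0, else 0). soft_mask and hard_mask are hypothetical helpers, and the exact formula and temperature schedule used by ContSparseLayers may differ.

```python
# Sketch assuming a continuous-sparsification-style soft mask,
# sigmoid(temperature * s); the library's exact formula may differ.
import math
from typing import List


def _sigmoid(z: float) -> float:
    # Numerically stable logistic function (avoids overflow in math.exp)
    if z >= 0:
        return 1.0 / (1.0 + math.exp(-z))
    e = math.exp(z)
    return e / (1.0 + e)


def soft_mask(mask_params: List[float], temperature: float) -> List[float]:
    """Differentiable mask in (0, 1); sharper for higher temperatures."""
    return [_sigmoid(temperature * s) for s in mask_params]


def hard_mask(mask_params: List[float]) -> List[int]:
    """Binary mask that the soft mask approaches: 1 where s > 0, else 0."""
    return [1 if s > 0 else 0 for s in mask_params]


params = [-0.5, 0.1, 2.0]
for temperature in (1.0, 10.0, 200.0):
    print(temperature, [round(m, 3) for m in soft_mask(params, temperature)])
print("hard:", hard_mask(params))
```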

forward(**kwargs)

Forward call of the wrapped model; adds L0 regularization to the loss if specified by the config
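As a rough sketch of the regularized objective, assuming the penalty is simply added to the task loss and scaled by l0_lambda (the add_l0 and l0_lambda defaults below mirror CircuitConfig). total_loss is a hypothetical helper, and during training the L0 term would typically be a differentiable surrogate rather than a plain count.

```python
# Hypothetical sketch: add a weighted L0 penalty to the task loss.
def total_loss(task_loss: float, l0: float,
               l0_lambda: float = 1e-8, add_l0: bool = True) -> float:
    """Task loss plus l0_lambda * l0 when add_l0 is set (defaults mirror CircuitConfig)."""
    if not add_l0:
        return task_loss
    return task_loss + l0_lambda * l0


print(total_loss(2.3, l0=1_000_000))                # about 2.31
print(total_loss(2.3, l0=1_000_000, add_l0=False))  # 2.3
```

With the default l0_lambda of 1e-8, one million active mask entries add only about 0.01 to the loss, which is why the weighting usually needs to be scaled with the parameter count.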

set_ablate_mode(ablation, force_resample=False)

Overrides the ablate mode of the MaskLayers that was specified in the CircuitConfig

Parameters:
  • ablation (str) –

    A string that determines how masks are produced from the mask layer parameters. Valid options include:

    • none: Producing a standard binary mask

    • zero_ablate: Inverting the standard binary mask. Used for pruning discovered subnetworks.

    • random_ablate: Inverting the standard binary mask and reinitializing zero’d elements. Used for pruning discovered subnetworks.

    • randomly_sampled: Sampling a random binary mask of the same size as the standard mask.

    • complement_sampled: Sampling a random binary mask of the same size as the standard mask, drawn only from the complement of the set of entries selected by the standard mask.

  • force_resample (bool, optional) – If ablation is “randomly_sampled” or “complement_sampled”, whether to resample the random mask, defaults to False
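The sketch below illustrates four of the ablation modes on a flattened binary mask. It assumes that the sampled modes keep the same number of active entries as the standard mask, which the descriptions above do not state explicitly; produce_mask is a hypothetical helper, and random_ablate is left out because it also reinitializes weights rather than only changing the mask.

```python
# Hypothetical sketch of mask production under several ablation modes.
import random
from typing import List, Optional


def produce_mask(standard: List[int], ablation: str,
                 rng: Optional[random.Random] = None) -> List[int]:
    """Return the mask used under `ablation`, given the standard binary mask."""
    rng = rng if rng is not None else random.Random(0)
    n = len(standard)
    k = sum(standard)  # assumed: sampled masks keep this many active entries
    if ablation == "none":
        return list(standard)
    if ablation == "zero_ablate":
        # Invert: keep everything except the discovered subnetwork
        return [1 - m for m in standard]
    if ablation == "randomly_sampled":
        chosen = set(rng.sample(range(n), k))
        return [1 if i in chosen else 0 for i in range(n)]
    if ablation == "complement_sampled":
        # Sample only among entries the standard mask leaves out
        complement = [i for i, m in enumerate(standard) if m == 0]
        chosen = set(rng.sample(complement, min(k, len(complement))))
        return [1 if i in chosen else 0 for i in range(n)]
    raise ValueError(f"unsupported ablation mode in this sketch: {ablation!r}")


standard = [1, 0, 0, 1, 0, 1, 0, 0]
for mode in ("none", "zero_ablate", "randomly_sampled", "complement_sampled"):
    print(mode, produce_mask(standard, mode))
```

The sampled modes act as controls: comparing the discovered subnetwork against a random mask of matching sparsity helps show whether it is special.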

use_masks(value, name_list=None)

This function turns masking behavior on or off for either the entire model or particular layers (if name_list is not None)

Parameters:
  • value (bool) – Whether to use masks or not

  • name_list (List[str], optional) – If set, it determines which subset of MaskLayers to turn on or off, defaults to None
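A toy illustration of the use_masks contract, assuming each MaskLayer carries a simple on/off flag keyed by its layer name. ToyMaskedModel is hypothetical; in particular, raising KeyError for an unknown name is a choice made for this sketch, not documented library behavior.

```python
# Toy illustration (hypothetical class, not NeuroSurgeon's): per-layer mask switches.
from typing import Dict, List, Optional


class ToyMaskedModel:
    def __init__(self, layer_names: List[str]):
        # Every mask layer starts with masking enabled
        self.use_mask: Dict[str, bool] = {name: True for name in layer_names}

    def use_masks(self, value: bool, name_list: Optional[List[str]] = None) -> None:
        """Set masking on/off for all layers, or only for those in name_list."""
        targets = list(self.use_mask) if name_list is None else name_list
        for name in targets:
            if name not in self.use_mask:
                raise KeyError(f"{name!r} is not a mask layer in this toy model")
            self.use_mask[name] = value


model = ToyMaskedModel(["layer.0.query", "layer.0.key", "layer.1.query"])
model.use_masks(False)                  # turn masking off everywhere
model.use_masks(True, ["layer.0.key"])  # re-enable a single layer
print(model.use_mask)
```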

Model Configs

class NeuroSurgeon.Models.model_configs.CircuitConfig(*args: Any, **kwargs: Any)

Bases: PretrainedConfig

A config object to define the behavior of a CircuitModel

Parameters:
  • mask_method (str, optional) – Which masking technique to use. Options include: [“continuous_sparsification”, “hard_concrete”, “magnitude_pruning”], defaults to “continuous_sparsification”

  • mask_hparams (dict, optional) –

    A dictionary defining hyperparameters specific to the chosen mask method. See the documentation for each layer for details about these hyperparameters, defaults to {}

    • For “continuous_sparsification”: Requires “ablation”, “mask_unit”, “mask_bias”, “mask_init_value”

    • For “hard_concrete”: Requires “ablation”, “mask_unit”, “mask_bias”, “mask_init_percentage”

    • For “magnitude_pruning”: Requires “ablation”, “mask_bias”, “prune_percentage”

  • target_layers (List[str], optional) – A list of layers to turn into mask layers. These layer names can be obtained from the wrapped model’s state dict. They must correspond to nn.Linear, GPT-style Conv1D, nn.Conv2d, or nn.Conv1d layers, defaults to []

  • freeze_base (bool, optional) – Whether to freeze the weights and biases of the wrapped model, defaults to True

  • add_l0 (bool, optional) – Whether to add L0 regularization to the loss computed during a transformer model’s forward pass, defaults to True

  • l0_lambda (float, optional) – The weight of the L0 regularization term; it should usually be scaled with the parameter count, defaults to 1e-8
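Because each mask method needs a different set of mask_hparams, a quick pre-flight check can catch missing keys before a model is wrapped. The sketch below encodes the required keys listed for each mask_method; missing_hparams is a hypothetical helper, and the values in config_kwargs (including the layer name and the mask_unit value) are placeholders rather than recommended settings.

```python
# Hypothetical pre-flight check for CircuitConfig-style settings.
from typing import Dict, List

# Required mask_hparams keys per mask_method, as listed in the documentation
REQUIRED_HPARAMS: Dict[str, List[str]] = {
    "continuous_sparsification": ["ablation", "mask_unit", "mask_bias", "mask_init_value"],
    "hard_concrete": ["ablation", "mask_unit", "mask_bias", "mask_init_percentage"],
    "magnitude_pruning": ["ablation", "mask_bias", "prune_percentage"],
}


def missing_hparams(mask_method: str, mask_hparams: dict) -> List[str]:
    """Return the required keys absent from mask_hparams, in documented order."""
    if mask_method not in REQUIRED_HPARAMS:
        raise ValueError(f"unknown mask_method: {mask_method!r}")
    return [key for key in REQUIRED_HPARAMS[mask_method] if key not in mask_hparams]


config_kwargs = {
    "mask_method": "continuous_sparsification",
    "mask_hparams": {
        "ablation": "none",
        "mask_unit": "weight",    # placeholder value, not verified
        "mask_bias": False,
        "mask_init_value": 0.0,   # placeholder value
    },
    "target_layers": ["encoder.layer.0.attention.self.query"],  # hypothetical name
    "freeze_base": True,
    "add_l0": True,
    "l0_lambda": 1e-8,
}
print(missing_hparams(config_kwargs["mask_method"], config_kwargs["mask_hparams"]))  # []
```

The keys of config_kwargs mirror the documented CircuitConfig parameters, so they could presumably be passed as keyword arguments, but that has not been checked against the library here.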

Module contents