kempnerforge.training.freeze

Parameter freezing helpers.

freeze_params is the primitive: toggle requires_grad on parameters whose fully-qualified name matches any of the provided fnmatch patterns.

apply_freeze_specs consumes a list of FreezeSpec entries (e.g. from VLMConfig.freeze) and resolves each spec’s module field against a pattern map (typically DEFAULT_MODULE_PATTERNS or an arch-specific override on the config). A raw fnmatch pattern is passed through unchanged when it is not a known alias.

canonical_freeze_meta produces a stable, reorder-invariant serialization of a freeze-spec list for checkpoint metadata, so a checkpoint saved with [A, B] matches one loaded with [B, A] as long as the effective mask is identical.

effective_freeze resolves the active freeze-spec list at a given training step from a base list (always-on) and a list of FreezeStage step-boundary transitions. It is used in three places: at save, to record the post-transition state in metadata; at load, to compute the expected metadata for the comparison; and at the training-loop hook site, to apply stage transitions when the step reaches them.

Functions

apply_freeze_specs(model, specs, pattern_map)

Apply a list of freeze specs to model.

canonical_freeze_meta(specs)

Return a sorted, deduplicated serialization of freeze specs.

effective_freeze(step, base, schedule[, ...])

Compute the active freeze-spec list at step.

freeze_params(model, patterns, *[, frozen])

Toggle requires_grad on parameters matching any fnmatch pattern.

kempnerforge.training.freeze.freeze_params(model, patterns, *, frozen=True)

Toggle requires_grad on parameters matching any fnmatch pattern.

Only parameters whose current state differs from the target are flipped, so the call is idempotent: repeating it with the same arguments flips nothing further. Returns the number of elements (param.numel() summed) that were actually flipped.

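Parameters:

model: the model whose parameters are toggled

patterns: fnmatch patterns matched against fully-qualified parameter names

frozen: keyword-only, defaults to True

Return type: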

int
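The flip-and-count behaviour described above can be sketched without PyTorch. This is a minimal illustration, not the library's implementation: FakeParam is a hypothetical stand-in for a model parameter, freeze_params_sketch is an illustrative name, and the sketch assumes that frozen=True corresponds to requires_grad=False and that matching is case-sensitive (fnmatch.fnmatchcase).

```python
from dataclasses import dataclass
from fnmatch import fnmatchcase


@dataclass
class FakeParam:
    """Hypothetical stand-in for a model parameter (illustration only)."""
    size: int
    requires_grad: bool = True

    def numel(self) -> int:
        return self.size


def freeze_params_sketch(named_params, patterns, *, frozen=True):
    """Flip requires_grad on matching params; return the element count flipped."""
    target = not frozen  # assumption: frozen=True means requires_grad=False
    flipped = 0
    for name, param in named_params:
        if not any(fnmatchcase(name, pat) for pat in patterns):
            continue
        if param.requires_grad != target:  # only touch params that differ
            param.requires_grad = target
            flipped += param.numel()
    return flipped


params = {
    "vision_encoder.layer0.weight": FakeParam(12),
    "vision_encoder.layer0.bias": FakeParam(4),
    "lm_head.weight": FakeParam(8),
}
first = freeze_params_sketch(params.items(), ["vision_encoder.*"])   # 16
second = freeze_params_sketch(params.items(), ["vision_encoder.*"])  # 0
```

The second call returns 0 because every matching parameter is already in the target state, which is what makes repeated calls harmless.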

kempnerforge.training.freeze.apply_freeze_specs(model, specs, pattern_map)

Apply a list of freeze specs to model.

For each spec, the module field is looked up in pattern_map; if present, its pattern list is used. Otherwise module is treated as a raw fnmatch pattern. Returns {spec.module: n_params_flipped}.

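Parameters:

model: the model the specs are applied to

specs: list of FreezeSpec entries

pattern_map: mapping from a module alias to its list of fnmatch patterns

Return type: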

dict[str, int]
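The alias lookup with raw-pattern fallback might look like the sketch below. SpecSketch, resolve_patterns, and the contents of pattern_map are hypothetical names chosen for illustration; only the module and frozen fields are taken from the description above.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SpecSketch:
    """Hypothetical stand-in for FreezeSpec with the two documented fields."""
    module: str
    frozen: bool = True


def resolve_patterns(spec, pattern_map):
    """Return the fnmatch patterns a spec expands to."""
    # Known alias: use its pattern list. Unknown: treat module as a raw pattern.
    return list(pattern_map.get(spec.module, [spec.module]))


# Hypothetical alias map; the real defaults live in DEFAULT_MODULE_PATTERNS.
pattern_map = {"vision_encoder": ["vision_encoder.*"]}

alias_patterns = resolve_patterns(SpecSketch("vision_encoder"), pattern_map)
raw_patterns = resolve_patterns(SpecSketch("adapter.*.bias"), pattern_map)
# alias_patterns == ["vision_encoder.*"]
# raw_patterns == ["adapter.*.bias"]
```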

kempnerforge.training.freeze.canonical_freeze_meta(specs)

Return a sorted, deduplicated serialization of freeze specs.

The output is safe to JSON-encode and compare across runs: two semantically equivalent freeze-spec lists (same (module, frozen) pairs, any order or duplicates) produce byte-equal JSON.

Parameters:

specs (Iterable[FreezeSpec])

Return type:

list[dict[str, object]]
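One way to obtain a reorder-invariant, duplicate-free serialization is to collect the (module, frozen) pairs into a set and sort them. The sketch below assumes this approach; SpecSketch, canonical_meta_sketch, and the dictionary keys "module" and "frozen" are illustrative choices and may differ from the actual output format.

```python
import json
from dataclasses import dataclass


@dataclass(frozen=True)
class SpecSketch:
    """Hypothetical stand-in for FreezeSpec with the two documented fields."""
    module: str
    frozen: bool = True


def canonical_meta_sketch(specs):
    """Sorted, deduplicated list of JSON-friendly dicts."""
    pairs = sorted({(s.module, s.frozen) for s in specs})
    return [{"module": module, "frozen": frozen} for module, frozen in pairs]


a = [SpecSketch("vision_encoder"), SpecSketch("adapter", frozen=False)]
b = [
    SpecSketch("adapter", frozen=False),
    SpecSketch("vision_encoder"),
    SpecSketch("vision_encoder"),  # duplicate, dropped by the set
]
same_json = json.dumps(canonical_meta_sketch(a)) == json.dumps(canonical_meta_sketch(b))
# same_json is True
```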

kempnerforge.training.freeze.effective_freeze(step, base, schedule, valid_modules=None)

Compute the active freeze-spec list at step.

Resolution rule:

  • Start from base (build-time freeze list).

  • For each FreezeStage with start_step <= step, in ascending start_step order, override base entries on conflicting module keys (last-write-wins). Stages with start_step > step are ignored.

Module-key validation:

  • When valid_modules is provided, every spec.module referenced in base or in any applied stage must appear in the set; otherwise a ValueError is raised. This catches typos at load time rather than at the next step boundary.

Returns the list of active specs (one per module key).

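Parameters:

step: the training step to resolve for

base: the always-on, build-time freeze list

schedule: list of FreezeStage step-boundary transitions

valid_modules: optional set of allowed module keys, defaults to None

Return type: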

list[FreezeSpec]
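The resolution rule for effective_freeze can be sketched as a dictionary keyed by module, filled from the base list first and then from each applicable stage in ascending start_step order. This is an illustration under stated assumptions: SpecSketch and StageSketch are hypothetical stand-ins, and the specs field on the stage is an assumed name (only start_step is documented above).

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SpecSketch:
    """Hypothetical stand-in for FreezeSpec."""
    module: str
    frozen: bool = True


@dataclass(frozen=True)
class StageSketch:
    """Hypothetical stand-in for FreezeStage; `specs` is an assumed field name."""
    start_step: int
    specs: tuple


def effective_freeze_sketch(step, base, schedule, valid_modules=None):
    """Resolve the active specs at `step`, one per module key."""
    groups = [list(base)]
    for stage in sorted(schedule, key=lambda s: s.start_step):
        if stage.start_step <= step:  # later stages are ignored
            groups.append(list(stage.specs))
    active = {}
    for group in groups:
        for spec in group:
            if valid_modules is not None and spec.module not in valid_modules:
                raise ValueError(f"unknown freeze module: {spec.module!r}")
            active[spec.module] = spec  # last write wins
    return list(active.values())


base = [SpecSketch("vision_encoder", True), SpecSketch("adapter", False)]
schedule = [StageSketch(1000, (SpecSketch("vision_encoder", False),))]

before = {s.module: s.frozen for s in effective_freeze_sketch(500, base, schedule)}
after = {s.module: s.frozen for s in effective_freeze_sketch(1000, base, schedule)}
# before == {"vision_encoder": True, "adapter": False}
# after == {"vision_encoder": False, "adapter": False}

try:
    effective_freeze_sketch(0, base, schedule, valid_modules={"adapter"})
    rejected = False
except ValueError:
    rejected = True  # "vision_encoder" is not in valid_modules
```

Keying the result by module is what guarantees a single active spec per module regardless of how many stages touch it.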