ParamTransformModule
- class dnois.torch.ParamTransformModule(*args, **kwargs)
A subclass of torch.nn.Module that can create and manage transformed parameters.

A transformed parameter can be registered given either its nominal value (by calling the overridden register_parameter()) or its latent value (by calling register_latent_parameter()). The nominal value is obtained by applying the transformation to the latent value, so in the first case an inverse transformation must be given to compute the latent value. If a parameter with the same name has already been registered, it will be automatically converted to a transformed parameter. Specifying a transformation for an existing parameter by calling set_transform() has the same effect. Transformations are specified in the form of a Transform instance.

Note
Implementation Detail: The latent value of each transformed parameter is a torch.nn.Parameter attribute with a name like _latent_<param>.

- register_latent_parameter(name: str, param: Parameter | None, transform: Transform)
Similar to register_parameter(), but takes as input the latent value rather than the nominal value. In this way, the inverse transformation need not be provided, since the initial latent value is known.

If name corresponds to a vanilla parameter (i.e. not a transformed parameter), it will be converted to a transformed one.

- Parameters:
name (str) – Name of the parameter.
param (Parameter) – Latent torch.nn.Parameter instance to be registered.
transform (Transform) – Transformation object.
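As a rough illustration of registering by latent value, the plain-Python sketch below models the bookkeeping without PyTorch or dnois. ExpTransform, TinyModule and its nominal() method are hypothetical, plain floats stand in for torch.nn.Parameter, and the _latent_<name> attribute naming mirrors the implementation note. Because the latent value is supplied directly, only the forward mapping (latent to nominal) is ever needed.

```python
import math


class ExpTransform:
    """Hypothetical transform: nominal = exp(latent), so the nominal value stays positive."""

    def forward(self, latent: float) -> float:
        return math.exp(latent)


class TinyModule:
    """Hypothetical stand-in for a module that holds transformed parameters."""

    def __init__(self) -> None:
        self._transforms = {}  # name -> transformation object

    def register_latent_parameter(self, name: str, latent: float, transform) -> None:
        # Store the latent value directly; no inverse is needed because the
        # latent value is given rather than derived from a nominal value.
        setattr(self, f"_latent_{name}", latent)
        self._transforms[name] = transform

    def nominal(self, name: str) -> float:
        # The nominal value is always recomputed from the latent one.
        return self._transforms[name].forward(getattr(self, f"_latent_{name}"))


m = TinyModule()
m.register_latent_parameter("radius", 0.0, ExpTransform())
print(m.nominal("radius"))  # 1.0, since exp(0) = 1
```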
- register_parameter(name: str, param: Parameter | None, transform: Transform = None) -> None
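Similar to torch.nn.Module.register_parameter(), but allows you to register a transformed parameter as long as transform is given. Since param holds the nominal value in this case, transform must provide an inverse transformation so that the latent value can be computed.

If name corresponds to a vanilla parameter (i.e. not a transformed parameter) but transform is given, it will be converted to a transformed one.

- Parameters: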
name (str) – Name of the parameter.
param (Parameter) – Nominal torch.nn.Parameter instance to be registered.
transform (Transform) – Transformation object.
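When the nominal value is given instead, the latent value has to be recovered through the inverse transformation before it can be stored. The sketch below is plain Python rather than PyTorch; ExpTransform, its forward/inverse methods and the register_nominal() helper are hypothetical and are not the dnois Transform API. It checks that mapping the derived latent value forward reproduces the nominal value.

```python
import math


class ExpTransform:
    """Hypothetical transform keeping a parameter positive: nominal = exp(latent)."""

    def forward(self, latent: float) -> float:
        return math.exp(latent)

    def inverse(self, nominal: float) -> float:
        # Required when only the nominal value is known at registration time.
        return math.log(nominal)


def register_nominal(store: dict, name: str, nominal: float, transform) -> None:
    """Hypothetical helper: derive the latent value and store it under '_latent_<name>'."""
    store[f"_latent_{name}"] = transform.inverse(nominal)
    store[f"_transform_{name}"] = transform


params = {}
register_nominal(params, "focal_length", 50.0, ExpTransform())
latent = params["_latent_focal_length"]
print(latent)  # log(50), about 3.912
recovered = params["_transform_focal_length"].forward(latent)
print(math.isclose(recovered, 50.0))  # True
```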
- remove_transform(name: str)
Remove the transformation for parameter name.

- Parameters:
name (str) – Name of the parameter.
- set_transform(name: str, transform: Transform)
Set the transformation for parameter name.

If name corresponds to a vanilla parameter (i.e. not a transformed parameter), it will be converted to a transformed one.

- Parameters:
name (str) – Name of the parameter.
transform (Transform) – Transformation object.
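A possible model of the vanilla-to-transformed conversion, in plain Python: this sketch assumes that the existing value is treated as the nominal value and mapped through the inverse transformation to obtain the latent value, which is consistent with the class description but not spelled out by it. SigmoidTransform and TinyModule are hypothetical, and the case where name is already a transformed parameter is deliberately not modeled.

```python
import math


class SigmoidTransform:
    """Hypothetical transform mapping any real latent value into (0, 1)."""

    def forward(self, latent: float) -> float:
        return 1.0 / (1.0 + math.exp(-latent))

    def inverse(self, nominal: float) -> float:
        # Logit; only defined for 0 < nominal < 1.
        return math.log(nominal / (1.0 - nominal))


class TinyModule:
    """Hypothetical module holding vanilla and transformed parameters."""

    def __init__(self) -> None:
        self.vanilla = {}     # name -> plain value
        self.latent = {}      # name -> latent value of a transformed parameter
        self.transforms = {}  # name -> transformation object

    def set_transform(self, name: str, transform) -> None:
        # Only the vanilla-to-transformed conversion is modeled here;
        # pop() raises KeyError if name is not a vanilla parameter.
        value = self.vanilla.pop(name)
        # Assumption: the existing value is the nominal value, so the
        # latent value is obtained through the inverse transformation.
        self.latent[name] = transform.inverse(value)
        self.transforms[name] = transform

    def nominal(self, name: str) -> float:
        return self.transforms[name].forward(self.latent[name])


m = TinyModule()
m.vanilla["transmittance"] = 0.5
m.set_transform("transmittance", SigmoidTransform())
print(m.latent["transmittance"])   # 0.0, since logit(0.5) = 0
print(m.nominal("transmittance"))  # 0.5
```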
- property nominal_values: dict[str, Tensor]
A dict whose keys are the names of all parameters of this module and whose values are the corresponding parameter values. For transformed parameters, the values are the nominal ones.

- Type:
dict[str, Tensor]
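The merging rule described for nominal_values can be sketched as follows. This is a hypothetical plain-Python function, not the property's actual implementation: vanilla values pass through unchanged, while each transformed parameter contributes the result of applying its forward mapping to its latent value. A transformation is represented here simply by a callable, and floats stand in for tensors.

```python
import math


def nominal_values(vanilla: dict, transformed: dict) -> dict:
    """Hypothetical model of the property: vanilla values pass through,
    transformed ones are mapped from latent to nominal."""
    result = dict(vanilla)
    for name, (latent, forward) in transformed.items():
        result[name] = forward(latent)
    return result


vanilla = {"thickness": 2.0}
transformed = {"radius": (0.0, math.exp)}  # latent value 0.0, nominal = exp(latent)
print(nominal_values(vanilla, transformed))  # {'thickness': 2.0, 'radius': 1.0}
```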
- property transformed_parameters: dict[str, tuple[Parameter, Transform]]
A dict whose keys are the names of all transformed parameters of this module. The value corresponding to each key is a tuple containing:

1. The latent value, a torch.nn.Parameter instance;
2. The corresponding transformation object.
- Type:
dict[str, tuple[Parameter, Transform]]
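Since each value of transformed_parameters pairs a latent value with its transformation, both views of a parameter can be derived from it. In the plain-Python sketch below, the dict only imitates the property's shape; the Transform dataclass with a forward field is a hypothetical stand-in for dnois's Transform, and floats replace torch.nn.Parameter. Because the latent values are the actual torch.nn.Parameter attributes, they are presumably what an optimizer built from the module's parameters would update.

```python
import math
from dataclasses import dataclass
from typing import Callable


@dataclass
class Transform:
    """Hypothetical stand-in for dnois's Transform: only a forward mapping."""
    forward: Callable[[float], float]  # latent -> nominal


# Imitates the property's shape: name -> (latent value, transformation object).
transformed_parameters = {
    "radius": (math.log(25.0), Transform(math.exp)),
    "gap": (1.0, Transform(lambda x: x * x)),
}

# Latent values: what an optimizer would see and update.
latents = {name: latent for name, (latent, _) in transformed_parameters.items()}
# Nominal values: what the rest of the model would use.
nominals = {
    name: transform.forward(latent)
    for name, (latent, transform) in transformed_parameters.items()
}

for name in transformed_parameters:
    print(f"{name}: latent={latents[name]:.3f}, nominal={nominals[name]:.3f}")
```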