PyTorch Extension

This package extends PyTorch with additional mathematical functions, module utilities, and tensor operations.

Mathematics

abs2(x)

Computes the squared magnitude of a complex tensor.
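To make the definition concrete, here is a minimal sketch of what abs2 computes. It is not the package's implementation. For a complex value z, |z|^2 = Re(z)^2 + Im(z)^2, which skips the square root that torch.abs would take. The real-input branch is an assumption added for convenience.

```python
import torch

def abs2(x: torch.Tensor) -> torch.Tensor:
    # Sketch only: squared magnitude |x|^2 without the sqrt inside torch.abs
    if torch.is_complex(x):
        return x.real.square() + x.imag.square()
    # Assumption: for real inputs the squared magnitude is simply x^2
    return x.square()
```

For example, abs2(torch.tensor([3 + 4j, 1 - 1j])) yields a real tensor holding 25 and 2.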

expi(x)

Computes the complex exponential of a real-valued tensor, $e^{ix} = \cos x + i \sin x$.
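The following sketch, which is not the package's code, shows one way to compute e^{ix} = cos x + i sin x for a real tensor x. It uses torch.polar(abs, angle), which builds abs * e^{i*angle}, with a unit magnitude.

```python
import torch

def expi(x: torch.Tensor) -> torch.Tensor:
    # Sketch only: exp(i*x) = cos(x) + i*sin(x) for real-valued x.
    # torch.polar(abs, angle) returns abs * exp(i * angle); here abs = 1.
    return torch.polar(torch.ones_like(x), x)
```

Building the result from a real angle avoids first casting x to a complex dtype and calling torch.exp.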

polynomial(x, coefficients)

Computes the value at x of a polynomial with the given coefficients.
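Here is a minimal sketch of polynomial evaluation using Horner's method. It assumes that coefficients are listed in ascending order of degree, so [c0, c1, c2] means c0 + c1*x + c2*x^2. The package's actual ordering is not stated here and may differ.

```python
import torch
from typing import Sequence

def polynomial(x: torch.Tensor, coefficients: Sequence[float]) -> torch.Tensor:
    # Assumption: coefficients[k] multiplies x**k (ascending order).
    # Horner's method: c0 + x*(c1 + x*(c2 + ...)), evaluated from the highest degree down.
    result = torch.zeros_like(x)
    for c in reversed(coefficients):
        result = result * x + c
    return result
```

Horner's method needs one multiply and one add per coefficient, and it never forms explicit powers of x.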

Module utilities

WrapperModule(func, *args, **kwargs)

A class to wrap a function as a torch.nn.Module.
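This is a hypothetical sketch of the idea behind WrapperModule, not the package's implementation. It assumes that the extra *args and **kwargs are stored and passed to func after the inputs on every forward call. A wrapped function can then be used anywhere a module is expected, such as inside nn.Sequential.

```python
import torch
from torch import nn
from typing import Any, Callable

class WrapperModule(nn.Module):
    # Hypothetical sketch: the argument-binding order below is an assumption
    def __init__(self, func: Callable[..., Any], *args: Any, **kwargs: Any) -> None:
        super().__init__()
        self.func = func
        self.args = args      # extra positional arguments, appended after the inputs
        self.kwargs = kwargs  # extra keyword arguments, forwarded on every call

    def forward(self, *inputs: Any) -> Any:
        return self.func(*inputs, *self.args, **self.kwargs)
```

For example, nn.Sequential(nn.Linear(4, 3), WrapperModule(torch.sum, dim=1)) sums the three outputs of the linear layer.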

ParamTransformModule(*args, **kwargs)

A subclass of torch.nn.Module that can create and manage transformed parameters.

Transform(fn, *args, **kwargs)

Base class for parameter transformations (see ParamTransformModule and transformed parameters).
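The exact interfaces of ParamTransformModule and Transform are not described here. The following is a hypothetical sketch of the general idea of a transformed parameter. The module stores an unconstrained raw parameter and exposes a constrained view of it, recomputed on each access. The Transform class and the PositiveScale module below are illustrative stand-ins, not the package's classes.

```python
import torch
from torch import nn
from typing import Any, Callable

class Transform:
    # Hypothetical stand-in: bundles fn with extra args, applied to a raw tensor
    def __init__(self, fn: Callable[..., torch.Tensor], *args: Any, **kwargs: Any) -> None:
        self.fn = fn
        self.args = args
        self.kwargs = kwargs

    def __call__(self, raw: torch.Tensor) -> torch.Tensor:
        return self.fn(raw, *self.args, **self.kwargs)

class PositiveScale(nn.Module):
    # Illustrative module: the optimizer updates raw_scale freely,
    # while users only see the positive value softplus(raw_scale)
    def __init__(self) -> None:
        super().__init__()
        self.raw_scale = nn.Parameter(torch.zeros(()))
        self.scale_transform = Transform(torch.nn.functional.softplus, beta=1.0)

    @property
    def scale(self) -> torch.Tensor:
        # Recomputed on every access, so gradients flow back to raw_scale
        return self.scale_transform(self.raw_scale)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.scale * x
```

Optimizing the unconstrained value and exposing only its transformed version keeps the constraint (here, positivity) satisfied after every gradient step, with no clamping needed.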

Tensor operations

as1d(x[, ndim, dim])

Reshapes x into an ndim-dimensional tensor whose elements all lie along dimension dim; every other dimension has size 1.
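The following is a minimal sketch of the reshaping described for as1d. The defaults ndim=1 and dim=-1 are assumptions, since the signature above shows both arguments as optional without giving their values. The sketch flattens x and places every element along dim, which lets a vector broadcast against one chosen axis of a larger tensor.

```python
import torch

def as1d(x: torch.Tensor, ndim: int = 1, dim: int = -1) -> torch.Tensor:
    # Sketch only; the defaults ndim=1, dim=-1 are assumptions
    if not -ndim <= dim < ndim:
        raise IndexError(f"dim {dim} out of range for ndim={ndim}")
    shape = [1] * ndim
    shape[dim] = -1  # all elements go along `dim`; every other dimension has size 1
    return x.reshape(shape)
```

For example, with per-channel weights w of length C and images of shape (N, C, H, W), as1d(w, 4, 1) has shape (1, C, 1, 1), so as1d(w, 4, 1) * images scales each channel.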

broadcastable(*tensors_or_shapes)

Checks whether the given tensors or tensor shapes are mutually broadcastable.
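Here is a minimal sketch of such a check, not the package's implementation. It normalizes each argument to a shape and delegates to torch.broadcast_shapes, which is available since PyTorch 1.8 and raises an error when the shapes are incompatible. Catching ValueError as well as RuntimeError is a defensive assumption.

```python
import torch
from typing import Sequence, Union

def broadcastable(*tensors_or_shapes: Union[torch.Tensor, Sequence[int]]) -> bool:
    # Normalize every argument to a shape tuple
    shapes = [t.shape if isinstance(t, torch.Tensor) else tuple(t)
              for t in tensors_or_shapes]
    try:
        torch.broadcast_shapes(*shapes)  # raises if the shapes cannot be broadcast together
    except (RuntimeError, ValueError):
        return False
    return True
```

Under the standard rules, shapes are compared from the trailing dimension, and two sizes are compatible when they are equal or one of them is 1. So (3, 1) and (1, 4) are broadcastable, while (2, 3) and (4,) are not.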