
Fails to pass torch.autograd.gradcheck #17

@zhf-0

Thank you very much for providing such a good tool!

My problem is that when the input A is a 'real' sparse matrix, rather than a sparse matrix converted from a dense one, torch.autograd.gradcheck() throws an exception. The Python program I use is

import scipy.sparse.linalg
from scipy.sparse import coo_matrix
import torch
class SparseSolve(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, b):
        '''
        A is a torch coo sparse matrix
        b is a tensor
        '''
        if A.ndim != 2 or (A.shape[0] != A.shape[1]):
            raise ValueError("A should be a square 2D matrix.")

        A = A.coalesce()
        A_idx = A.indices().to('cpu').numpy()
        A_val = A.values().to('cpu').numpy()
        sci_A = coo_matrix((A_val, (A_idx[0, :], A_idx[1, :])), shape=A.shape)
        sci_A = sci_A.tocsr()

        np_b = b.detach().cpu().numpy()
        # Solve the sparse system
        if np_b.ndim == 1:
            np_x = scipy.sparse.linalg.spsolve(sci_A, np_b)
        else:
            factorisedsolver = scipy.sparse.linalg.factorized(sci_A)
            np_x = factorisedsolver(np_b)

        x = torch.as_tensor(np_x)
        # Not sure if the following is needed / helpful
        if A.requires_grad or b.requires_grad:
            x.requires_grad = True

        # Save context for backward pass
        ctx.save_for_backward(A, b, x)
        return x

    @staticmethod
    def backward(ctx, grad):
        # Recover context
        A, b, x = ctx.saved_tensors

        # Compute gradient with respect to b
        gradb = SparseSolve.apply(A.t(), grad)

        gradAidx = A.indices()
        mgradbselect = -gradb.index_select(0,gradAidx[0,:])
        xselect = x.index_select(0,gradAidx[1,:])
        mgbx = mgradbselect * xselect
        if x.dim() == 1:
            gradAvals = mgbx
        else:
            gradAvals = torch.sum( mgbx, dim=1 )
        gradA = torch.sparse_coo_tensor(gradAidx, gradAvals, A.shape)
        return gradA, gradb

sparsesolve = SparseSolve.apply

row_vec = torch.tensor([0, 0, 1, 2])
col_vec = torch.tensor([0, 2, 1, 2])
val_vec = torch.tensor([3.0, 4.0, 5.0, 6.0], dtype=torch.float64)
A = torch.sparse_coo_tensor(torch.stack((row_vec, col_vec), 0), val_vec, (3, 3))
b = torch.ones(3, dtype=torch.float64, requires_grad=False)
A.requires_grad = True
b.requires_grad = True
res = torch.autograd.gradcheck(sparsesolve, [A, b], raise_exception=True)
print(res)

which is based on the program from the discussion "Differentiable sparse linear solver with cupy backend - 'unsupported tensor layout: Sparse' in gradcheck", which @tvercaut wrote based on your blog post and code. I modified that program and limited it to run only on the CPU.

The output is

torch.autograd.gradcheck.GradcheckError: Jacobian mismatch for output 0 with respect to input 0,
numerical:tensor([[-3.7037e-02,  0.0000e+00, -1.3878e-11],
        [-6.6667e-02,  0.0000e+00,  0.0000e+00],
        [-5.5556e-02,  0.0000e+00,  0.0000e+00],
        [ 0.0000e+00, -2.2222e-02,  0.0000e+00],
        [ 0.0000e+00, -4.0000e-02,  0.0000e+00],
        [ 0.0000e+00, -3.3333e-02,  0.0000e+00],
        [ 2.4691e-02,  0.0000e+00, -1.8519e-02],
        [ 4.4444e-02,  0.0000e+00, -3.3333e-02],
        [ 3.7037e-02,  0.0000e+00, -2.7778e-02]], dtype=torch.float64)
analytical:tensor([[-0.0370,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [-0.0556,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000, -0.0400,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000],
        [ 0.0370,  0.0000, -0.0278]], dtype=torch.float64)

Your program and @tvercaut's pass the gradient check because the sparse matrix A you used is actually a dense matrix, so autograd computes the gradient for every element.
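To make the distinction concrete, here is a minimal sketch of the two cases I mean (the variable names are only for illustration; it reuses row_vec, col_vec, and val_vec from the script above):

# A sparse tensor converted from a dense matrix: (almost surely) every
# entry is nonzero and therefore stored, so autograd sees all 9 values.
dense = torch.rand(3, 3, dtype=torch.float64)
A_from_dense = dense.to_sparse()
print(A_from_dense._nnz())   # 9 stored values

# A "real" sparse matrix: only the structural nonzeros are stored.
A_real = torch.sparse_coo_tensor(torch.stack((row_vec, col_vec), 0), val_vec, (3, 3))
print(A_real._nnz())         # 4 stored values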

The derivative formula from your blog is
$$\frac{\partial L}{\partial A} = - \frac{\partial L}{\partial b} \otimes x$$
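Written entry-wise (my reading of that formula, with $x = A^{-1}b$ and a scalar loss $L$), this is
$$\frac{\partial L}{\partial A_{ij}} = -\left(\frac{\partial L}{\partial b}\right)_i x_j,$$
which is exactly what gradAvals computes at the stored indices of A.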
Since the matrix A is sparse, $\frac{\partial L}{\partial A_{ij}}$ should be $0$ whenever $A_{ij}=0$, but the results computed by PyTorch show that this is not the case. If I change the backward() function to

@staticmethod
def backward(ctx, grad):
    A, b, x = ctx.saved_tensors
    gradb = SparseSolve.apply(A.t(), grad)

    # Dense outer product -gradb x^T instead of the sparse gradient
    if x.ndim == 1:
        gradA = -gradb.reshape(-1, 1) @ x.reshape(1, -1)
    else:
        gradA = -gradb @ x.T
    return gradA, gradb

Then the gradient check passes. However, gradA is now a dense matrix, which is not consistent with the theoretical result. There is a similar issue #13, but without a detailed explanation. So I want to ask: which gradient is right, the sparse one or the dense one?
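For reference, the two candidate gradients agree on the stored entries of A; here is a minimal sketch (not part of the script above; the names gradA_sparse, gradA_dense, and mask are mine; it reuses A, b, and sparsesolve with the original sparse backward(), and forms the dense outer product by hand):

x = sparsesolve(A, b)
loss = x.sum()
loss.backward()                       # runs the sparse backward()
gradA_sparse = A.grad.to_dense()      # nonzero only on the stored pattern

with torch.no_grad():
    gradb = sparsesolve(A.t(), torch.ones_like(x))           # dL/db for L = sum(x)
    gradA_dense = -gradb.reshape(-1, 1) @ x.reshape(1, -1)   # dense outer product

mask = A.to_dense() != 0              # stored (nonzero) pattern of A
print(torch.allclose(gradA_sparse[mask], gradA_dense[mask]))  # True
print(gradA_dense[~mask].abs().max())                         # generally nonzero off the pattern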
