Skip to content

Commit

Permalink
Fixing "GCXS matmul => slice leads to incorrect results" (#611)
Browse files Browse the repository at this point in the history
  • Loading branch information
EuGig authored Dec 19, 2023
1 parent 7c0c962 commit 0e283ff
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 1 deletion.
7 changes: 7 additions & 0 deletions sparse/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -724,6 +724,13 @@ def _dot_csr_csr(
sums[temp] = 0

indptr[i + 1] = nnz

if len(indices) == (n_col * n_row):
for i in range(len(indices) // n_col):
j = n_col * i
k = n_col * (1 + i)
data[j:k] = data[j:k][::-1]
indices[j:k] = indices[j:k][::-1]
return data, indices, indptr

return _dot_csr_csr
Expand Down
23 changes: 23 additions & 0 deletions sparse/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,31 @@ def assert_eq(x, y, check_nnz=True, compare_dtype=True, **kwargs):
assert check_equal(xx, yy, **kwargs)


def assert_gcxs_slicing(s, x):
"""
Util function to test slicing of GCXS matrices after product multiplication.
For simplicity, it tests only tensors with number of dimension = 3.
Parameters
----------
s: sparse product matrix
x: dense product matrix
"""
row = np.random.randint(s.shape[s.ndim - 2])
assert np.allclose(s[0][row].data, [num for num in x[0][row] if num != 0])

# regression test
col = s.shape[s.ndim - 1]
for i in range(len(s.indices) // col):
j = col * i
k = col * (1 + i)
s.data[j:k] = s.data[j:k][::-1]
s.indices[j:k] = s.indices[j:k][::-1]
assert np.array_equal(s[0][row].data, np.array([]))


def assert_nnz(s, x):
fill_value = s.fill_value if hasattr(s, "fill_value") else _zero_of_dtype(s.dtype)

assert np.sum(~equivalent(x, fill_value)) == s.nnz


Expand Down
63 changes: 62 additions & 1 deletion sparse/tests/test_dot.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import sparse
from sparse._compressed import GCXS
from sparse import COO
from sparse._utils import assert_eq
from sparse._utils import assert_eq, assert_gcxs_slicing


@pytest.mark.parametrize(
Expand Down Expand Up @@ -341,3 +341,64 @@ def test_dot_dense(dtype1, dtype2, ndim1, ndim2):
assert_eq(sparse.matmul(a, b), np.matmul(a, b))
if ndim1 == 2 and ndim2 == 2:
assert_eq(sparse.tensordot(a, b), np.tensordot(a, b))


@pytest.mark.parametrize(
"a_shape, b_shape",
[((3, 4, 5), (5, 6)), ((2, 8, 6), (6, 3))],
)
def test_dot_GCXS_slicing(a_shape, b_shape):
sa = sparse.random(shape=a_shape, density=1, format="gcxs")
sb = sparse.random(shape=b_shape, density=1, format="gcxs")

a = sa.todense()
b = sb.todense()

# tests dot
sa_sb = sparse.dot(sa, sb)
a_b = np.dot(a, b)

assert_gcxs_slicing(sa_sb, a_b)


@pytest.mark.parametrize(
"a_shape,b_shape,axes",
[
[(3, 4, 5), (4, 3), (1, 0)],
[(3, 4), (5, 4, 3), (1, 1)],
[(5, 9), (9, 5, 6), (0, 1)],
],
)
def test_tensordot_GCXS_slicing(a_shape, b_shape, axes):
sa = sparse.random(shape=a_shape, density=1, format="gcxs")
sb = sparse.random(shape=b_shape, density=1, format="gcxs")

a = sa.todense()
b = sb.todense()

sa_sb = sparse.tensordot(sa, sb, axes)
a_b = np.tensordot(a, b, axes)

assert_gcxs_slicing(sa_sb, a_b)


@pytest.mark.parametrize(
"a_shape, b_shape",
[
[(1, 1, 5), (3, 5, 6)],
[(3, 4, 5), (1, 5, 6)],
[(3, 4, 5), (3, 5, 6)],
[(3, 4, 5), (5, 6)],
],
)
def test_matmul_GCXS_slicing(a_shape, b_shape):
sa = sparse.random(shape=a_shape, density=1, format="gcxs")
sb = sparse.random(shape=b_shape, density=1, format="gcxs")

a = sa.todense()
b = sb.todense()

sa_sb = sparse.matmul(sa, sb)
a_b = np.matmul(a, b)

assert_gcxs_slicing(sa_sb, a_b)

0 comments on commit 0e283ff

Please sign in to comment.