Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions include/utils/matrix.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,19 @@ typedef void (*matrix_broadcast_fill_values_fn)(matrix *A, broadcast_type type,
typedef matrix *(*matrix_diag_vec_alloc_fn)(matrix *A);
typedef void (*matrix_diag_vec_fill_values_fn)(matrix *A, matrix *out);

/* Allocate C as a row-wise reduction of A. The reduction pattern is chosen by
axis:
- axis = -1: sum all rows of A. C has shape (1, A->n).
- axis = 0: block-sum rows in consecutive groups of d1. C has shape (A->m
/ d1, A->n). C[j, :] = sum_{i in [j*d1, (j+1)*d1)} A[i, :].
- axis = 1: stride-sum rows at spacing d1. C has shape (d1, A->n). C[j, :]
= sum_{i : i % d1 == j} A[i, :].

Caller pre-allocates idx_map with A->nnz entries. The implementation fills it
so that the numerical result can later be computed by accumulation: the k-th
stored value of A is added into C's value array at position idx_map[k].

Returns the newly allocated matrix C.
NOTE(review): the axis = 0 shape uses integer division, so this presumably
expects d1 to divide A->m -- confirm against callers. */
typedef matrix *(*matrix_sum_row_partition_alloc_fn)(matrix *A, int axis, int d1,
int *idx_map);

typedef void (*matrix_free_fn)(matrix *self);

struct matrix
Expand Down Expand Up @@ -141,6 +154,7 @@ struct matrix
matrix_broadcast_fill_values_fn broadcast_fill_values;
matrix_diag_vec_alloc_fn diag_vec_alloc;
matrix_diag_vec_fill_values_fn diag_vec_fill_values;
matrix_sum_row_partition_alloc_fn sum_row_partition_alloc;

/* Lifecycle */
matrix_free_fn free_fn;
Expand Down
46 changes: 5 additions & 41 deletions src/atoms/affine/sum.c
Original file line number Diff line number Diff line change
Expand Up @@ -85,49 +85,13 @@ static void jacobian_init_impl(expr *node)
{
expr *x = node->left;
sum_expr *snode = (sum_expr *) node;
int axis = snode->axis;

/* initialize child's jacobian */
jacobian_init(x);
CSR_matrix *Jx = x->jacobian->to_csr(x->jacobian);

/* we never have to store more than the child's nnz, nor more than the
output's cell count */
int max_nnz = MIN(Jx->nnz, node->size * node->n_vars);
CSR_matrix *jac = new_CSR_matrix(node->size, node->n_vars, max_nnz);
node->work->iwork = sp_malloc(MAX(jac->n, Jx->nnz) * sizeof(int));
snode->idx_map = sp_malloc(Jx->nnz * sizeof(int));

/* the idx_map array maps each nonzero entry j in x->jacobian
to the corresponding index in the output row matrix C. Specifically, for
each nonzero entry j in A, idx_map[j] gives the position in C->x where
the value from x->jacobian->x[j] should be accumulated. */

if (axis == -1)
{
sum_all_rows_csr_alloc(Jx, jac, node->work->iwork, snode->idx_map);
}
else if (axis == 0)
{
sum_block_of_rows_csr_alloc(Jx, jac, x->d1, node->work->iwork,
snode->idx_map);
}
else if (axis == 1)
{
sum_evenly_spaced_rows_csr_alloc(Jx, jac, node->size, node->work->iwork,
snode->idx_map);
}

/* For stacked_pd children, child->jacobian->base.x is block-major while
csr->x is row-major sorted. Re-index idx_map so it can be applied
directly to base.x in eval_jacobian. */
if (x->jacobian->is_stacked_pd)
{
compose_csr_idx_map_for_spd((const stacked_pd *) x->jacobian, Jx,
snode->idx_map);
}

node->jacobian = new_sparse_matrix(jac);
/* sum_row_partition_alloc fills idx_map so eval_jacobian can accumulate from
child->jacobian->x. */
snode->idx_map = sp_malloc(x->jacobian->nnz * sizeof(int));
node->jacobian = x->jacobian->sum_row_partition_alloc(x->jacobian, snode->axis,
x->d1, snode->idx_map);
}

static void eval_jacobian(expr *node)
Expand Down
143 changes: 143 additions & 0 deletions src/utils/permuted_dense.c
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,148 @@ static void permuted_dense_vtable_block_left_mult_values(const matrix *A,
I_kron_A_fill_values(A, J, C, pd->kernel_dwork);
}

/* Sum every row of A into a single output row.

   C is a (1, A->n) permuted-dense matrix with one dense row (row index 0)
   that reuses A's column permutation. idx_map is written in the order of A's
   dense block, row by row: the entry for dense element (row, col) is col,
   i.e. every dense row of A accumulates into the same n0 output slots. */
static matrix *sum_all_rows_pd_alloc(matrix *A, int *idx_map)
{
    permuted_dense *pd = (permuted_dense *) A;

    /* the single output row sits at row 0 of C */
    int only_row = 0;
    matrix *C =
        new_permuted_dense(1, A->n, 1, pd->n0, &only_row, pd->col_perm, NULL);

    /* walk idx_map linearly; each dense row contributes 0, 1, ..., n0 - 1 */
    int *dst = idx_map;
    for (int row = 0; row < pd->m0; row++)
    {
        for (int col = 0; col < pd->n0; col++)
        {
            *dst++ = col;
        }
    }

    return C;
}

/* Sum A's rows in consecutive groups of d1.

   Dense row r of A (global row A->row_perm[r]) belongs to output row
   A->row_perm[r] / d1. A new output row is opened whenever that group index
   differs from the previous dense row's group. C has shape
   (A_matrix->m / d1, A_matrix->n) and reuses A's column permutation.

   idx_map is written in the order of A's dense block: the entry for dense
   element (r, c) is slot * n0 + c, where slot is the position of r's group
   among C's dense rows. */
static matrix *sum_block_of_rows_pd_alloc(matrix *A_matrix, int d1, int *idx_map)
{
    permuted_dense *A = (permuted_dense *) A_matrix;
    const int n_rows = A->m0;
    const int n_cols = A->n0;

    /* group index of each dense row of C (at most one per dense row of A) */
    int *group_rows = (int *) sp_malloc(n_rows * sizeof(int));

    /* for each dense row of A, which dense row of C it accumulates into */
    int *slot_of_row = (int *) sp_malloc(n_rows * sizeof(int));

    /* pass 1: discover the groups that are hit, in order of appearance */
    int n_groups = 0;
    int prev_group = -1;
    for (int r = 0; r < n_rows; r++)
    {
        const int group = A->row_perm[r] / d1;
        if (group != prev_group)
        {
            group_rows[n_groups] = group;
            n_groups++;
            prev_group = group;
        }
        slot_of_row[r] = n_groups - 1;
    }

    matrix *C = new_permuted_dense(A_matrix->m / d1, A_matrix->n, n_groups,
                                   n_cols, group_rows, A->col_perm, NULL);

    /* pass 2: each dense row of A maps onto the n0 slots of its output row */
    int *dst = idx_map;
    for (int r = 0; r < n_rows; r++)
    {
        const int first = slot_of_row[r] * n_cols;
        for (int c = 0; c < n_cols; c++)
        {
            *dst++ = first + c;
        }
    }

    sp_free(slot_of_row);
    sp_free(group_rows);
    return C;
}

/* Sum A's rows at modular spacing d1: C has shape (d1, A->n) and
   C[j, :] = sum_{i : i % d1 == j} A[i, :].

   Only residues j that are hit by some entry of A->row_perm become dense
   rows of C; they are emitted in increasing j, so C's row_perm is sorted.
   C reuses A's column permutation.

   idx_map is written in the order of A's dense block: the entry for dense
   element (r, c) is slot * n0 + c, where slot is the dense row of C that
   holds residue A->row_perm[r] % d1. */
static matrix *sum_evenly_spaced_rows_pd_alloc(matrix *self, int d1, int *idx_map)
{
    permuted_dense *A = (permuted_dense *) self;

    /* mark every residue class in [0, d1) that A's dense rows fall into */
    bool *hit = (bool *) sp_calloc(d1, sizeof(bool));
    for (int r = 0; r < A->m0; r++)
    {
        hit[A->row_perm[r] % d1] = true;
    }

    /* number the hit residues in increasing order; that numbering is both
       C's row_perm and the residue -> dense-row-of-C lookup */
    int *residues = (int *) sp_malloc(A->m0 * sizeof(int));
    int *slot_of_residue = (int *) sp_malloc(d1 * sizeof(int));
    int n_slots = 0;
    for (int res = 0; res < d1; res++)
    {
        if (!hit[res])
        {
            continue;
        }
        slot_of_residue[res] = n_slots;
        residues[n_slots] = res;
        n_slots++;
    }
    sp_free(hit);

    matrix *C = new_permuted_dense(d1, self->n, n_slots, A->n0, residues,
                                   A->col_perm, NULL);

    /* each dense row of A maps onto the n0 slots of its residue's row in C */
    int *dst = idx_map;
    for (int r = 0; r < A->m0; r++)
    {
        const int first = slot_of_residue[A->row_perm[r] % d1] * A->n0;
        for (int c = 0; c < A->n0; c++)
        {
            *dst++ = first + c;
        }
    }

    sp_free(slot_of_residue);
    sp_free(residues);
    return C;
}

/* Permuted-dense entry point for sum_row_partition_alloc: picks the
   reduction by axis and forwards to the matching allocator.
   - axis == -1: all rows are summed into one row;
   - axis ==  0: rows are summed in consecutive groups of d1;
   - any other value is handled as axis == 1 (rows at spacing d1). */
static matrix *permuted_dense_vtable_sum_row_partition_alloc(matrix *self, int axis,
                                                             int d1, int *idx_map)
{
    switch (axis)
    {
        case -1:
            return sum_all_rows_pd_alloc(self, idx_map);
        case 0:
            return sum_block_of_rows_pd_alloc(self, d1, idx_map);
        default: /* axis == 1 */
            return sum_evenly_spaced_rows_pd_alloc(self, d1, idx_map);
    }
}

static void wire_vtable(permuted_dense *pd)
{
pd->base.is_permuted_dense = true;
Expand All @@ -420,6 +562,7 @@ static void wire_vtable(permuted_dense *pd)
pd->base.broadcast_fill_values = permuted_dense_vtable_broadcast_fill_values;
pd->base.diag_vec_alloc = permuted_dense_vtable_diag_vec_alloc;
pd->base.diag_vec_fill_values = permuted_dense_vtable_diag_vec_fill_values;
pd->base.sum_row_partition_alloc = permuted_dense_vtable_sum_row_partition_alloc;
pd->base.refresh_csc_values = permuted_dense_refresh_csc_values;
}

Expand Down
40 changes: 40 additions & 0 deletions src/utils/sparse_matrix.c
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#include "utils/sparse_matrix.h"

#include "utils/CSC_matrix.h"
#include "utils/CSR_sum.h"
#include "utils/linalg_sparse_matmuls.h"
#include "utils/matrix.h"
#include "utils/mini_numpy.h"
Expand Down Expand Up @@ -311,6 +312,44 @@ static void sparse_refresh_csc_values(matrix *self)
csr_to_csc_fill_values(sm->csr, sm->csc_cache, sm->csc_iwork);
}

/* Sparse (CSR) implementation of sum_row_partition_alloc.

   Allocates the CSR matrix that holds the row-wise reduction of self and
   lets the CSR sum helpers fill idx_map. The result has A->n columns and
   - 1 row          when axis == -1 (all rows summed),
   - A->m / d1 rows when axis ==  0 (consecutive groups of d1 rows),
   - d1 rows        otherwise       (handled as axis == 1, spacing d1).

   idx_map is caller-allocated and is passed through to the helpers
   unchanged. Returns a new sparse_matrix wrapping the allocated CSR. */
static matrix *sparse_sum_row_partition_alloc(matrix *self, int axis, int d1,
                                              int *idx_map)
{
    CSR_matrix *A = ((sparse_matrix *) self)->csr;

    /* number of rows of the reduced matrix */
    int m;
    if (axis == -1)
    {
        m = 1;
    }
    else if (axis == 0)
    {
        m = A->m / d1;
    }
    else
    {
        m = d1;
    }

    /* The output never stores more entries than A does, nor more than it has
       cells. The cell count is formed in 64-bit arithmetic: m * A->n can
       exceed INT_MAX for large problems, and signed int overflow is
       undefined behavior. When the cell count is the smaller bound it is
       below A->nnz, so narrowing it back to int is safe. */
    long long n_cells = (long long) m * (long long) A->n;
    int max_nnz = (n_cells < (long long) A->nnz) ? (int) n_cells : A->nnz;

    CSR_matrix *out = new_CSR_matrix(m, A->n, max_nnz);

    /* scratch buffer handed to the CSR helpers; released before returning */
    int *iwork = (int *) sp_malloc(MAX(A->n, A->nnz) * sizeof(int));

    if (axis == -1)
    {
        sum_all_rows_csr_alloc(A, out, iwork, idx_map);
    }
    else if (axis == 0)
    {
        sum_block_of_rows_csr_alloc(A, out, d1, iwork, idx_map);
    }
    else
    {
        sum_evenly_spaced_rows_csr_alloc(A, out, m, iwork, idx_map);
    }

    sp_free(iwork);
    return new_sparse_matrix(out);
}

static void wire_vtable(sparse_matrix *sm)
{
sm->base.block_left_mult_vec = sparse_block_left_mult_vec;
Expand All @@ -331,6 +370,7 @@ static void wire_vtable(sparse_matrix *sm)
sm->base.broadcast_fill_values = sparse_broadcast_fill_values;
sm->base.diag_vec_alloc = sparse_diag_vec_alloc;
sm->base.diag_vec_fill_values = sparse_diag_vec_fill_values;
sm->base.sum_row_partition_alloc = sparse_sum_row_partition_alloc;
sm->base.refresh_csc_values = sparse_refresh_csc_values;
sm->base.free_fn = sparse_free;
}
Expand Down
Loading
Loading