From 81979f09cf100db32deb0e1917dabb1fe435194c Mon Sep 17 00:00:00 2001 From: Elizabeth Hunt Date: Mon, 23 Oct 2023 10:17:03 -0600 Subject: [PATCH] matrix multiplication! --- inc/lizfcm.h | 2 ++ src/matrix.c | 27 +++++++++++++++++++++++++++ test/matrix.t.c | 36 ++++++++++++++++++++++++++++++++++++ 3 files changed, 65 insertions(+) diff --git a/inc/lizfcm.h b/inc/lizfcm.h index b59651e..12c1278 100644 --- a/inc/lizfcm.h +++ b/inc/lizfcm.h @@ -35,6 +35,8 @@ extern Array_double *bsubst(Matrix_double *u, Array_double *b); extern Array_double *fsubst(Matrix_double *l, Array_double *b); extern Array_double *solve_matrix(Matrix_double *m, Array_double *b); extern Array_double *m_dot_v(Matrix_double *m, Array_double *v); +extern Matrix_double *m_dot_m(Matrix_double *a, Matrix_double *b); +extern Array_double *col_v(Matrix_double *m, size_t x); extern Matrix_double *copy_matrix(Matrix_double *m); extern void free_matrix(Matrix_double *m); extern void format_matrix_into(Matrix_double *m, char *s); diff --git a/src/matrix.c b/src/matrix.c index 5766d94..22dd171 100644 --- a/src/matrix.c +++ b/src/matrix.c @@ -14,6 +14,33 @@ Array_double *m_dot_v(Matrix_double *m, Array_double *v) { return product; } +Array_double *col_v(Matrix_double *m, size_t x) { + assert(x < m->cols); + + Array_double *col = InitArrayWithSize(double, m->rows, 0.0); + for (size_t y = 0; y < m->rows; y++) + col->data[y] = m->data[y]->data[x]; + + return col; +} + +Matrix_double *m_dot_m(Matrix_double *a, Matrix_double *b) { + assert(a->cols == b->rows); + + Matrix_double *prod = InitMatrixWithSize(double, a->rows, b->cols, 0.0); + + Array_double *curr_col; + for (size_t y = 0; y < a->rows; y++) { + for (size_t x = 0; x < b->cols; x++) { + curr_col = col_v(b, x); + prod->data[y]->data[x] = v_dot_v(curr_col, a->data[y]); + free_vector(curr_col); + } + } + + return prod; +} + Matrix_double *put_identity_diagonal(Matrix_double *m) { assert(m->rows == m->cols); Matrix_double *copy = copy_matrix(m); diff --git a/test/matrix.t.c b/test/matrix.t.c index 1def9ef..5386635 100644 --- a/test/matrix.t.c +++ b/test/matrix.t.c @@ -94,3 +94,39 @@ UTEST(matrix, solve_matrix) { free_vector(b); free_vector(solution); } + +UTEST(matrix, col_v) { + Matrix_double *m = InitMatrixWithSize(double, 2, 3, 0.0); + // set element to its column index + for (size_t y = 0; y < m->rows; y++) { + for (size_t x = 0; x < m->cols; x++) { + m->data[y]->data[x] = x; + } + } + + Array_double *col, *expected; + for (size_t x = 0; x < m->cols; x++) { + col = col_v(m, x); + expected = InitArrayWithSize(double, m->rows, (double)x); + EXPECT_TRUE(vector_equal(expected, col)); + free_vector(col); + free_vector(expected); + } + + free_matrix(m); +} + +UTEST(matrix, m_dot_m) { + Matrix_double *a = InitMatrixWithSize(double, 1, 3, 12.0); + Matrix_double *b = InitMatrixWithSize(double, 3, 1, 10.0); + + Matrix_double *prod = m_dot_m(a, b); + + EXPECT_EQ(prod->cols, 1); + EXPECT_EQ(prod->rows, 1); + EXPECT_EQ(12.0 * 10.0 * 3, prod->data[0]->data[0]); + + free_matrix(a); + free_matrix(b); + free_matrix(prod); +}