matrix multiplication!

This commit is contained in:
Elizabeth Hunt 2023-10-23 10:17:03 -06:00
parent 8ba71912fa
commit 81979f09cf
Signed by: simponic
GPG Key ID: 52B3774857EB24B1
3 changed files with 65 additions and 0 deletions

View File

@ -35,6 +35,8 @@ extern Array_double *bsubst(Matrix_double *u, Array_double *b);
extern Array_double *fsubst(Matrix_double *l, Array_double *b); extern Array_double *fsubst(Matrix_double *l, Array_double *b);
extern Array_double *solve_matrix(Matrix_double *m, Array_double *b); extern Array_double *solve_matrix(Matrix_double *m, Array_double *b);
extern Array_double *m_dot_v(Matrix_double *m, Array_double *v); extern Array_double *m_dot_v(Matrix_double *m, Array_double *v);
extern Matrix_double *m_dot_m(Matrix_double *a, Matrix_double *b);
extern Array_double *col_v(Matrix_double *m, size_t x);
extern Matrix_double *copy_matrix(Matrix_double *m); extern Matrix_double *copy_matrix(Matrix_double *m);
extern void free_matrix(Matrix_double *m); extern void free_matrix(Matrix_double *m);
extern void format_matrix_into(Matrix_double *m, char *s); extern void format_matrix_into(Matrix_double *m, char *s);

View File

@ -14,6 +14,33 @@ Array_double *m_dot_v(Matrix_double *m, Array_double *v) {
return product; return product;
} }
Array_double *col_v(Matrix_double *m, size_t x) {
assert(x < m->cols);
Array_double *col = InitArrayWithSize(double, m->rows, 0.0);
for (size_t y = 0; y < m->rows; y++)
col->data[y] = m->data[y]->data[x];
return col;
}
Matrix_double *m_dot_m(Matrix_double *a, Matrix_double *b) {
assert(a->cols == b->rows);
Matrix_double *prod = InitMatrixWithSize(double, a->rows, b->cols, 0.0);
Array_double *curr_col;
for (size_t y = 0; y < a->rows; y++) {
for (size_t x = 0; x < b->cols; x++) {
curr_col = col_v(b, x);
prod->data[y]->data[x] = v_dot_v(curr_col, a->data[y]);
free_vector(curr_col);
}
}
return prod;
}
Matrix_double *put_identity_diagonal(Matrix_double *m) { Matrix_double *put_identity_diagonal(Matrix_double *m) {
assert(m->rows == m->cols); assert(m->rows == m->cols);
Matrix_double *copy = copy_matrix(m); Matrix_double *copy = copy_matrix(m);

View File

@ -94,3 +94,39 @@ UTEST(matrix, solve_matrix) {
free_vector(b); free_vector(b);
free_vector(solution); free_vector(solution);
} }
UTEST(matrix, col_v) {
Matrix_double *m = InitMatrixWithSize(double, 2, 3, 0.0);
// set element to its column index
for (size_t y = 0; y < m->rows; y++) {
for (size_t x = 0; x < m->cols; x++) {
m->data[y]->data[x] = x;
}
}
Array_double *col, *expected;
for (size_t x = 0; x < m->cols; x++) {
col = col_v(m, x);
expected = InitArrayWithSize(double, m->rows, (double)x);
EXPECT_TRUE(vector_equal(expected, col));
free_vector(col);
free_vector(expected);
}
free_matrix(m);
}
UTEST(matrix, m_dot_m) {
Matrix_double *a = InitMatrixWithSize(double, 1, 3, 12.0);
Matrix_double *b = InitMatrixWithSize(double, 3, 1, 10.0);
Matrix_double *prod = m_dot_m(a, b);
EXPECT_EQ(prod->cols, 1);
EXPECT_EQ(prod->rows, 1);
EXPECT_EQ(12.0 * 10.0 * 3, prod->data[0]->data[0]);
free_matrix(a);
free_matrix(b);
free_matrix(prod);
}