matrix multiplication!
This commit is contained in:
parent
8ba71912fa
commit
81979f09cf
@ -35,6 +35,8 @@ extern Array_double *bsubst(Matrix_double *u, Array_double *b);
|
||||
extern Array_double *fsubst(Matrix_double *l, Array_double *b);
|
||||
extern Array_double *solve_matrix(Matrix_double *m, Array_double *b);
|
||||
extern Array_double *m_dot_v(Matrix_double *m, Array_double *v);
|
||||
extern Matrix_double *m_dot_m(Matrix_double *a, Matrix_double *b);
|
||||
extern Array_double *col_v(Matrix_double *m, size_t x);
|
||||
extern Matrix_double *copy_matrix(Matrix_double *m);
|
||||
extern void free_matrix(Matrix_double *m);
|
||||
extern void format_matrix_into(Matrix_double *m, char *s);
|
||||
|
27
src/matrix.c
27
src/matrix.c
@ -14,6 +14,33 @@ Array_double *m_dot_v(Matrix_double *m, Array_double *v) {
|
||||
return product;
|
||||
}
|
||||
|
||||
Array_double *col_v(Matrix_double *m, size_t x) {
|
||||
assert(x < m->cols);
|
||||
|
||||
Array_double *col = InitArrayWithSize(double, m->rows, 0.0);
|
||||
for (size_t y = 0; y < m->rows; y++)
|
||||
col->data[y] = m->data[y]->data[x];
|
||||
|
||||
return col;
|
||||
}
|
||||
|
||||
Matrix_double *m_dot_m(Matrix_double *a, Matrix_double *b) {
|
||||
assert(a->cols == b->rows);
|
||||
|
||||
Matrix_double *prod = InitMatrixWithSize(double, a->rows, b->cols, 0.0);
|
||||
|
||||
Array_double *curr_col;
|
||||
for (size_t y = 0; y < a->rows; y++) {
|
||||
for (size_t x = 0; x < b->cols; x++) {
|
||||
curr_col = col_v(b, x);
|
||||
prod->data[y]->data[x] = v_dot_v(curr_col, a->data[y]);
|
||||
free_vector(curr_col);
|
||||
}
|
||||
}
|
||||
|
||||
return prod;
|
||||
}
|
||||
|
||||
Matrix_double *put_identity_diagonal(Matrix_double *m) {
|
||||
assert(m->rows == m->cols);
|
||||
Matrix_double *copy = copy_matrix(m);
|
||||
|
@ -94,3 +94,39 @@ UTEST(matrix, solve_matrix) {
|
||||
free_vector(b);
|
||||
free_vector(solution);
|
||||
}
|
||||
|
||||
UTEST(matrix, col_v) {
|
||||
Matrix_double *m = InitMatrixWithSize(double, 2, 3, 0.0);
|
||||
// set element to its column index
|
||||
for (size_t y = 0; y < m->rows; y++) {
|
||||
for (size_t x = 0; x < m->cols; x++) {
|
||||
m->data[y]->data[x] = x;
|
||||
}
|
||||
}
|
||||
|
||||
Array_double *col, *expected;
|
||||
for (size_t x = 0; x < m->cols; x++) {
|
||||
col = col_v(m, x);
|
||||
expected = InitArrayWithSize(double, m->rows, (double)x);
|
||||
EXPECT_TRUE(vector_equal(expected, col));
|
||||
free_vector(col);
|
||||
free_vector(expected);
|
||||
}
|
||||
|
||||
free_matrix(m);
|
||||
}
|
||||
|
||||
UTEST(matrix, m_dot_m) {
|
||||
Matrix_double *a = InitMatrixWithSize(double, 1, 3, 12.0);
|
||||
Matrix_double *b = InitMatrixWithSize(double, 3, 1, 10.0);
|
||||
|
||||
Matrix_double *prod = m_dot_m(a, b);
|
||||
|
||||
EXPECT_EQ(prod->cols, 1);
|
||||
EXPECT_EQ(prod->rows, 1);
|
||||
EXPECT_EQ(12.0 * 10.0 * 3, prod->data[0]->data[0]);
|
||||
|
||||
free_matrix(a);
|
||||
free_matrix(b);
|
||||
free_matrix(prod);
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user