matrix multiplication!
This commit is contained in:
parent
8ba71912fa
commit
81979f09cf
@ -35,6 +35,8 @@ extern Array_double *bsubst(Matrix_double *u, Array_double *b);
|
|||||||
extern Array_double *fsubst(Matrix_double *l, Array_double *b);
|
extern Array_double *fsubst(Matrix_double *l, Array_double *b);
|
||||||
extern Array_double *solve_matrix(Matrix_double *m, Array_double *b);
|
extern Array_double *solve_matrix(Matrix_double *m, Array_double *b);
|
||||||
extern Array_double *m_dot_v(Matrix_double *m, Array_double *v);
|
extern Array_double *m_dot_v(Matrix_double *m, Array_double *v);
|
||||||
|
extern Matrix_double *m_dot_m(Matrix_double *a, Matrix_double *b);
|
||||||
|
extern Array_double *col_v(Matrix_double *m, size_t x);
|
||||||
extern Matrix_double *copy_matrix(Matrix_double *m);
|
extern Matrix_double *copy_matrix(Matrix_double *m);
|
||||||
extern void free_matrix(Matrix_double *m);
|
extern void free_matrix(Matrix_double *m);
|
||||||
extern void format_matrix_into(Matrix_double *m, char *s);
|
extern void format_matrix_into(Matrix_double *m, char *s);
|
||||||
|
27
src/matrix.c
27
src/matrix.c
@ -14,6 +14,33 @@ Array_double *m_dot_v(Matrix_double *m, Array_double *v) {
|
|||||||
return product;
|
return product;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Array_double *col_v(Matrix_double *m, size_t x) {
|
||||||
|
assert(x < m->cols);
|
||||||
|
|
||||||
|
Array_double *col = InitArrayWithSize(double, m->rows, 0.0);
|
||||||
|
for (size_t y = 0; y < m->rows; y++)
|
||||||
|
col->data[y] = m->data[y]->data[x];
|
||||||
|
|
||||||
|
return col;
|
||||||
|
}
|
||||||
|
|
||||||
|
Matrix_double *m_dot_m(Matrix_double *a, Matrix_double *b) {
|
||||||
|
assert(a->cols == b->rows);
|
||||||
|
|
||||||
|
Matrix_double *prod = InitMatrixWithSize(double, a->rows, b->cols, 0.0);
|
||||||
|
|
||||||
|
Array_double *curr_col;
|
||||||
|
for (size_t y = 0; y < a->rows; y++) {
|
||||||
|
for (size_t x = 0; x < b->cols; x++) {
|
||||||
|
curr_col = col_v(b, x);
|
||||||
|
prod->data[y]->data[x] = v_dot_v(curr_col, a->data[y]);
|
||||||
|
free_vector(curr_col);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return prod;
|
||||||
|
}
|
||||||
|
|
||||||
Matrix_double *put_identity_diagonal(Matrix_double *m) {
|
Matrix_double *put_identity_diagonal(Matrix_double *m) {
|
||||||
assert(m->rows == m->cols);
|
assert(m->rows == m->cols);
|
||||||
Matrix_double *copy = copy_matrix(m);
|
Matrix_double *copy = copy_matrix(m);
|
||||||
|
@ -94,3 +94,39 @@ UTEST(matrix, solve_matrix) {
|
|||||||
free_vector(b);
|
free_vector(b);
|
||||||
free_vector(solution);
|
free_vector(solution);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
UTEST(matrix, col_v) {
|
||||||
|
Matrix_double *m = InitMatrixWithSize(double, 2, 3, 0.0);
|
||||||
|
// set element to its column index
|
||||||
|
for (size_t y = 0; y < m->rows; y++) {
|
||||||
|
for (size_t x = 0; x < m->cols; x++) {
|
||||||
|
m->data[y]->data[x] = x;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Array_double *col, *expected;
|
||||||
|
for (size_t x = 0; x < m->cols; x++) {
|
||||||
|
col = col_v(m, x);
|
||||||
|
expected = InitArrayWithSize(double, m->rows, (double)x);
|
||||||
|
EXPECT_TRUE(vector_equal(expected, col));
|
||||||
|
free_vector(col);
|
||||||
|
free_vector(expected);
|
||||||
|
}
|
||||||
|
|
||||||
|
free_matrix(m);
|
||||||
|
}
|
||||||
|
|
||||||
|
UTEST(matrix, m_dot_m) {
|
||||||
|
Matrix_double *a = InitMatrixWithSize(double, 1, 3, 12.0);
|
||||||
|
Matrix_double *b = InitMatrixWithSize(double, 3, 1, 10.0);
|
||||||
|
|
||||||
|
Matrix_double *prod = m_dot_m(a, b);
|
||||||
|
|
||||||
|
EXPECT_EQ(prod->cols, 1);
|
||||||
|
EXPECT_EQ(prod->rows, 1);
|
||||||
|
EXPECT_EQ(12.0 * 10.0 * 3, prod->data[0]->data[0]);
|
||||||
|
|
||||||
|
free_matrix(a);
|
||||||
|
free_matrix(b);
|
||||||
|
free_matrix(prod);
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user