Linear algebra things
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

175 lines
5.0 KiB

  1. /*
  2. SSAD Assignment 3
  3. Implement least square approximation
  4. (matrix multiplication and inverse as well)
  5. TODO: Refactor the code so that it can be reused (make modular)
  6. */
  #include <cmath>
  #include <cstddef>
  #include <fstream>
  #include <iostream>
  #include <stdexcept>
  #include <utility>
  #include <vector>
  // Dense matrix of doubles stored as a vector of rows.
  //
  // The class exposes its storage through GetData(), so callers fill and read
  // elements directly as data[row][col]. All rows are created with the same
  // length by the constructors below; GetCols() reports the length of row 0.
  class Matrix {
    std::vector<std::vector<double>> m_data;

   public:
    // Creates an empty (0 x 0) matrix.
    Matrix() {}

    // Copy constructor combined with transpose.
    //
    // With transpose == false this is a plain copy of `m`. With
    // transpose == true the new matrix is the transpose of `m`; an `m` with no
    // rows yields an empty matrix.
    Matrix(const Matrix& m, bool transpose = false) {
      if (!transpose) {
        m_data = m.m_data;
        return;
      }
      const std::size_t src_rows = m.GetRows();
      if (src_rows == 0) return;
      const std::size_t src_cols = m.GetCols();
      m_data.assign(src_cols, std::vector<double>(src_rows, 0.0));
      for (std::size_t r = 0; r < src_cols; ++r)
        for (std::size_t c = 0; c < src_rows; ++c) m_data[r][c] = m.m_data[c][r];
    }

    // Move constructor: takes over the storage of `m`.
    Matrix(Matrix&& m) noexcept : m_data(std::move(m.m_data)) {}

    // Declaring the move constructor suppresses the implicit copy assignment,
    // so both assignments are restored explicitly to keep `a = b` usable.
    Matrix& operator=(const Matrix&) = default;
    Matrix& operator=(Matrix&&) = default;

    // Creates a rows x cols matrix filled with zeros.
    Matrix(size_t rows, size_t cols)
        : m_data(rows, std::vector<double>(cols, 0.0)) {}

    // Mutable access to the underlying rows.
    std::vector<std::vector<double>>& GetData() { return m_data; }
    // Read-only access to the underlying rows.
    const std::vector<std::vector<double>>& GetData() const { return m_data; }

    // Number of rows.
    size_t GetRows() const { return m_data.size(); }
    // Number of columns, taken from the first row (0 when there are no rows).
    size_t GetCols() const {
      if (m_data.empty()) return 0;
      return m_data[0].size();
    }

    // Matrix multiplication operator (defined outside the class).
    friend Matrix operator*(const Matrix& m1, const Matrix& m2);

    // Returns the inverse of `m`, computed by Gauss-Jordan elimination with
    // partial pivoting (row swaps).
    //
    // Throws std::invalid_argument if `m` is not square, or if no non-zero
    // pivot can be found in some column (the matrix is singular).
    static Matrix Inverse(const Matrix& m) {
      const std::size_t n = m.GetRows();
      if (n != m.GetCols())
        throw std::invalid_argument("Inverse is only for square matrices!");

      // `work` is reduced to a diagonal matrix while the same row operations
      // turn the identity stored in `res` into the (unscaled) inverse.
      std::vector<std::vector<double>> work = m.GetData();
      Matrix res(n, n);
      std::vector<std::vector<double>>& inv = res.GetData();
      for (std::size_t i = 0; i < n; ++i) inv[i][i] = 1;

      for (std::size_t col = 0; col < n; ++col) {
        // Pick the row at or below the diagonal with the largest magnitude in
        // this column; a zero on the diagonal alone does not imply singularity.
        std::size_t pivot = col;
        for (std::size_t r = col + 1; r < n; ++r)
          if (std::abs(work[r][col]) > std::abs(work[pivot][col])) pivot = r;
        if (work[pivot][col] == 0.0)
          throw std::invalid_argument("No inverse for this matrix!");
        if (pivot != col) {
          work[pivot].swap(work[col]);
          inv[pivot].swap(inv[col]);
        }

        // Clear this column in every other row.
        for (std::size_t row = 0; row < n; ++row) {
          if (row == col) continue;
          const double ratio = work[row][col] / work[col][col];
          for (std::size_t k = 0; k < n; ++k) {
            work[row][k] -= ratio * work[col][k];
            inv[row][k] -= ratio * inv[col][k];
          }
        }
      }

      // `work` is now diagonal; divide each row by its diagonal entry.
      for (std::size_t i = 0; i < n; ++i)
        for (std::size_t j = 0; j < n; ++j) inv[i][j] /= work[i][i];
      return res;
    }

    // Writes the matrix to `out`, one row per line, each element followed by a
    // single space.
    void print(std::ostream& out) const {
      for (const auto& row : m_data) {
        for (const auto& value : row) out << value << ' ';
        out << '\n';
      }
    }
  };
  81. Matrix operator*(const Matrix& m1, const Matrix& m2) {
  82. size_t rows = m1.GetRows(), cols = m2.GetCols(), mid = m1.GetCols();
  83. if (mid != m2.GetRows()) throw std::invalid_argument("Couldn't multiply!");
  84. Matrix res(rows, cols);
  85. const std::vector<std::vector<double>>&data1 = m1.GetData(),
  86. &data2 = m2.GetData();
  87. std::vector<std::vector<double>>& datares = res.GetData();
  88. for (size_t i = 0; i < rows; ++i) {
  89. for (size_t j = 0; j < cols; ++j) {
  90. datares[i][j] = 0;
  91. for (size_t k = 0; k < mid; ++k) {
  92. datares[i][j] += data1[i][k] * data2[k][j];
  93. }
  94. }
  95. }
  96. return res;
  97. }
  98. inline std::ostream& operator<<(std::ostream& out, const Matrix& m) {
  99. m.print(out);
  100. return out;
  101. }
  102. int main() {
  103. std::ifstream inputtxt("input.txt");
  104. std::ofstream outputtxt("output.txt");
  105. #ifdef DEBUG
  106. std::istream& in = inputtxt;
  107. std::ostream& out = std::cout;
  108. #else
  109. std::istream& in = inputtxt;
  110. std::ostream& out = outputtxt;
  111. #endif
  112. out.precision(2);
  113. int n, m;
  114. in >> n >> m;
  115. Matrix A(m, n + 1), Y(m, 1);
  116. // input
  117. std::vector<std::vector<double>>&a = A.GetData(), &y = Y.GetData();
  118. for (int i = 0; i < m; ++i) {
  119. for (int j = 0; j < n + 1; ++j) {
  120. if (j == 0) a[i][j] = 1;
  121. if (j < n)
  122. in >> a[i][j + 1];
  123. else
  124. in >> y[i][0];
  125. }
  126. }
  127. out << "A:\n" << std::fixed << A << '\n';
  128. out << "b:\n" << std::fixed << Y << '\n';
  129. Matrix A_T(A, true), // A transpose
  130. A_TA(std::move(A_T * A)), // A transpose multiplied by A
  131. A_TA_inv(
  132. std::move(Matrix::Inverse(A_TA))), // inverse of the previous one
  133. A_TA_invA_T(std::move(A_TA_inv * A_T)), // multiplied by A transpose
  134. ans(std::move(A_TA_invA_T * Y)); // answer x
  135. out << "A_T*A:\n" << std::fixed << A_TA << '\n';
  136. out << "(A_T*A)_-1:\n" << std::fixed << A_TA_inv << '\n';
  137. out << "(A_T*A)_-1*A_T:\n" << std::fixed << A_TA_invA_T << '\n';
  138. out << "x:\n" << std::fixed << ans << '\n';
  139. }