歡迎來到Linux教程網
Linux教程網
Linux教程網
Linux教程網
您现在的位置: Linux教程網 >> UnixLinux >  >> Linux編程 >> Linux編程

C++實現的線性代數矩陣計算

/**
 *  線性代數矩陣計算
 *  實現功能:行列向量獲取,子矩陣獲取,轉置矩陣獲取,
 *          行列式計算,轉置伴隨矩陣獲取,逆矩陣計算
 *
 *  Copyright 2011 Shi Y.M. All Rights Reserved
 */

#ifndef MATICAL_TMATRIX_H_INCLUDED
#define MATICAL_TMATRIX_H_INCLUDED

#include <memory.h>

namespace matical
{

template<typename T>
class TMatrix
{
private:
    T * m_Elems;
    int m_Rows;
    int m_Cols;

public:
    typedef T ElemType;

public:

    TMatrix() : m_Elems(NULL), m_Rows(0), m_Cols(0) {}
    TMatrix(int rows, int cols) : m_Rows(rows), m_Cols(cols) {
        m_Elems = new ElemType[m_Rows * m_Cols];
        memset(m_Elems, 0, sizeof(ElemType) * m_Rows * m_Cols);
    }
    // 生成n階方陣
    TMatrix(int n) : m_Rows(n), m_Cols(n){
        m_Elems = new ElemType[m_Rows * m_Cols];
        memset(m_Elems, 0, sizeof(ElemType) * m_Rows * m_Cols);
    }

    TMatrix(const TMatrix& m) : m_Rows(m.Rows()), m_Cols(m.Cols()) {
        m_Elems = new ElemType[m_Rows * m_Cols];
        memcpy(m_Elems, m.m_Elems, sizeof(ElemType) * m_Rows * m_Cols);
    }
    virtual ~TMatrix() {
        if ( m_Elems ) delete[] m_Elems;
        m_Elems = NULL;
        m_Rows  = 0;
        m_Cols  = 0;
    }

public:

    int     Rows() const { return m_Rows;}
    int     Cols() const { return m_Cols;}

    // 是否為方陣
    bool    IsSquare()   const { return ( m_Rows > 0 && (m_Rows == m_Cols)); }

    // 取行向量
    TMatrix<T> Row(int row) const {
        TMatrix m(1, m_Cols);
        for(int i = 0; i < m_Cols; i++) {
            m(0, i) = (*this)(row, i);
        }
        return m;
    }

    // 取列向量
    TMatrix<T> Col(int col) const {
        TMatrix m(m_Rows, 1);
        for(int i = 0; i < m_Rows; i++) {
            m(i, 0) = (*this)(i, col);
        }
        return m;
    }

public:

    // 將當前矩陣設置為單位矩陣
    void Identity() {
        for(int r = 0; r < this->m_Rows; r++) {
            for(int c = 0; c < this->m_Cols; c++) {
                if ( r == c ) (*this)(r, c) = 1;
                else  (*this)(r, c) = 0;
            }
        }
        return ;
    }

    // 當前矩陣設置為零矩陣
    void Zero() {
        int sz = m_Rows * m_Cols;
        for (int i = 0; i < sz; i++ ) m_Elems[i] = 0;
        return ;
    }

    // 獲取指定范圍的子矩陣,rrow>=lrow, rcol>=lcol
    TMatrix<T> SubMatrix(int lrow, int lcol, int rrow, int rcol) const {
        TMatrix m(rrow - lrow + 1, rcol - lcol + 1);
        for ( int r = 0; r < m.m_Rows; r++) {
            for (int c = 0; c < m.m_Cols; c++) {
                m(r, c ) = (*this)(r + lrow, c + lcol);
            }
        }
        return m;
    }

    // 獲取當前矩陣元素(row, col)的子矩陣(去掉元素(row,col)所在的行和列)
    TMatrix<T> SubMatrixEx(int row, int col) const {
        int rr = 0;
        int cc = 0;
        TMatrix m(this->m_Rows - 1, this->m_Cols - 1);
        for ( int r = 0; r < m.m_Rows; r++, rr++) {
            if ( r == row) rr++;
            cc = 0;
            for (int c = 0; c < m.m_Cols; c++, cc++) {
                if ( c == col ) cc++;
                m(r, c) = (*this)(rr, cc);
            }
        }
        return m;
    }

    // 計算當前矩陣的轉置矩陣
    TMatrix<T> Transpose() const {
        printf("Transpose\n");
        TMatrix m(this->m_Cols, this->m_Rows);
        for (int r = 0; r < m.m_Rows; r++ ) {
            for ( int c = 0; c < m.m_Cols; c++ ) {
                m(r, c) = (*this)(c, r);
            }
        }
        return m;
    }

    // 計算當前矩陣的行列式(遞歸方式), 計算行列式的矩陣,須是n階方陣
    ElemType Det() const {
        ElemType det = 0;
        if ( m_Rows == 1) {
            det = (*this)(0, 0);
        } else if ( m_Rows == 2) {
            det = (*this)(0, 0) * (*this)(1, 1) - (*this)(0 ,1) * (*this)(1, 0);
        } else {
            for (int r = 0; r < m_Rows; r++) {
                TMatrix<ElemType> m = this->SubMatrixEx(r, 0);
                if ( r % 2 ) det += (-1) * (*this)(r, 0) * m.Det();
                else det += (*this)(r, 0) * m.Det();
            }
        }
        return det;
    }

    // 計算當前矩陣的轉置伴隨矩陣,當前矩陣須為n階方陣
    TMatrix<T> Adj() const {
        TMatrix<T> m(m_Rows, m_Cols);
        for (int r = 0; r < m_Rows; r++) {
            for (int c = 0; c < m_Cols; c++) {
                if ( (r + c) % 2 ) m(c, r) = -1 * this->SubMatrixEx(r, c).Det();
                else m(c, r) = this->SubMatrixEx(r, c).Det();
            }
        }
        return m;
    }

    // 計算當前矩陣的逆矩陣,當前矩陣須為n階方陣
    TMatrix<T> Inv() const {
        ElemType   det = this->Det();
        TMatrix<T> adj = this->Adj();
        return adj * (1 / det);
    }

public:

    ElemType operator()(int row, int col) const { return m_Elems[row * m_Cols + col];};
    ElemType& operator()(int row, int col) { return m_Elems[row * m_Cols + col]; }

public:

    // 兩矩陣相加,必須具有相同的行數以及列數
    TMatrix<T>  operator+(const TMatrix& m) const {
        TMatrix mr(this->m_Rows, this->m_Cols);
        for(int r = 0; r < this->m_Rows; r++) {
            for(int c = 0; c < this->m_Cols; c++) {
                mr(r, c) = (*this)(r, c) + m(r, c);
            }
        }
        return mr;
    } // operator+(const TMatrix& m) const

    // 兩矩陣相減,必須具有相同的行數以及列數
    TMatrix<T> operator-(const TMatrix& m) const {
        TMatrix mr(this->m_Rows, this->m_Cols);
        for(int r = 0; r < this->m_Rows; r++) {
            for(int c = 0; c < this->m_Cols; c++) {
                mr(r, c) = (*this)(r, c) - m(r, c);
            }
        }
        return mr;
    } // operator-(const TMatrix& m) const

    // 矩陣與常數相乘(數乘)
    TMatrix<T>  operator*(ElemType v) const {
        TMatrix mr(this->m_Rows, this->m_Cols);
        for(int r = 0; r < this->m_Rows; r++) {
            for(int c = 0; c < this->m_Cols; c++) {
                mr(r, c) = (*this)(r, c) * v;
            }
        }
        return mr;
    } // operator*(ElemType v) const

    // 矩陣相乘(當前矩陣列數須等於參數矩陣的列數),
    // 結果矩陣行數等於當前矩陣,列數等於參數矩陣的列數
    TMatrix<T>  operator*(const TMatrix& m) const {
        TMatrix mr(this->m_Rows, m.m_Cols);
        for (int r = 0; r < this->m_Rows; r++) {
            for (int c = 0; c < m.m_Cols; c++) {
                for (int i = 0; i < this->m_Cols; i++) {
                    mr(r, c) += ( (*this)(r, i) * m(i, c) );
                    printf("(%d, %d)\n", r, c);
                }
            }
        }
        return mr;
    } // operator*(const TMatrix& m) const
}; // class TMatrix

} // namespace matical

#endif // MATICAL_TMATRIX_H_INCLUDED

Copyright © Linux教程網 All Rights Reserved