歡迎來到Linux教程網
Linux教程網
Linux教程網
Linux教程網
您现在的位置: Linux教程網 >> UnixLinux >  >> Linux編程 >> Linux編程

C++實現矩陣乘法

重載*運算符為友元函數。

#include <iostream>
#include <cmath>
using namespace std;

class Matrix{
    public:
        Matrix(){}
        Matrix(int,int);
        void setMatrix();
        void showMatrix();
        void showTransposedMatrix();
        friend Matrix operator *(Matrix m1,Matrix m2);
    protected:
        int m;
        int n;
        int mn;
        double* matrixPtr;
        double* transposedMPtr;
        void transpose();
};

class SquareMatrix:public Matrix{
    public:
        SquareMatrix(){}
        SquareMatrix(int);
        void setSquareMatrix();
        void setDet();
        void getDet();
    private:
        double det;
};

Matrix::Matrix(int mt,int nt){
    m=mt;
    n=nt;
    mn=m*n;
    matrixPtr=new double[mn];
}

void Matrix::setMatrix(){
    cout<<"輸入矩陣的行數和列數:"<<endl;
    cin>>m>>n;
    mn=m*n;
    matrixPtr=new double[mn];
    for(int i=0;i<mn;i++)
        cin>>matrixPtr[i];
}

void Matrix::transpose(){
    transposedMPtr=new double[mn];
    for(int i=0;i<n;i++)
        for(int j=0;j<m;j++)
            transposedMPtr[m*i+j]=matrixPtr[n*j+i];
}

void Matrix::showMatrix(){
    for(int i=0;i<m;i++){
        for(int j=0;j<n;j++)
            cout<<matrixPtr[n*i+j]<<' ';
        cout<<endl;
    }
}

void Matrix::showTransposedMatrix(){
    for(int i=0;i<n;i++){
        for(int j=0;j<m;j++)
            cout<<transposedMPtr[m*i+j]<<' ';
        cout<<endl;
    }
}

Matrix operator *(Matrix m1,Matrix m2){
    Matrix m3(m1.m,m2.n);
    for(int i=0;i<m3.m;i++)
        for(int j=0;j<m3.n;j++){
            double val=0;
            for(int k=0;k<m2.m;k++)
                val+=m1.matrixPtr[m1.n*i+k]*m2.matrixPtr[m2.n*k+j];
            m3.matrixPtr[m3.n*i+j]=val;
        }
    return m3;
}

SquareMatrix::SquareMatrix(int m){
    Matrix(m,m);                      //right?
}

void SquareMatrix::setSquareMatrix(){
    cout<<"輸入方陣的階數:"<<endl;
    cin>>m;
    n=m;
    mn=m*n;
    matrixPtr=new double[mn];
    for(int i=0;i<mn;i++)
        cin>>matrixPtr[i];
}

void SquareMatrix::setDet(){
    double valDet(double*,int);
    det=valDet(matrixPtr,m);
}

void SquareMatrix::getDet(){
    cout<<det<<endl;
}
double valDet( double *detPtr, int rank)
{
    double val=0;
    if(rank==1) return detPtr[0];
    for(int i=0;i<rank;i++)                  //計算余子式保存在nextDetPtr[]中
    {
        double *nextDetPtr=new double[(rank-1)*(rank-1)];
        for(int j=0;j<rank-1;j++)
            for(int k=0;k<i;k++)
                nextDetPtr[j*(rank-1)+k]=detPtr[(j+1)*rank+k];
        for(int j=0;j<rank-1;j++)
            for(int k=i;k<rank-1;k++)
                nextDetPtr[j*(rank-1)+k]=detPtr[(j+1)*rank+k+1];
        val+=detPtr[i]*valDet(nextDetPtr,rank-1)*pow(-1.0,i);
    }
    return val;
}

int main(){
    Matrix m1,m2,m3;
    m1.setMatrix();
    m2.setMatrix();
    m3=m1*m2;
    m3.showMatrix();
    return 0;
}

Copyright © Linux教程網 All Rights Reserved