在写 class CMatrix 之前,先将 main 代码写好。
这一点非常重要:
应该是先有需求、后有代码,而不是照着烂代码制定需求。
main 大概是这个样子,不同的人可能写出不一样的代码,但大体差不多(绝不会像你写的那样):
(绝不会像你写的那样)
// Usage sketch written before the class exists: it fixes the interface that
// CMatrix has to offer (construction from dimensions plus a flat value list,
// multiplication, stream output and Transpose()).
int main( void )
{
    // 2x3 operand; values are listed row by row.
    CMatrix a{ 2, 3, { 1, 2, 3,
                       4, 5, 6 } };
    // 3x2 operand; values are listed row by row.
    CMatrix b{ 3, 2, { 1, 4,
                       2, 5,
                       3, 6 } };
    CMatrix c( a * b );

    std::cout << "a matrix:\n";
    std::cout << a << '\n';

    std::cout << "b matrix:\n";
    std::cout << b << '\n';

    std::cout << "a*b matrix:\n";
    std::cout << c << '\n';

    std::cout << "Transpose matrix:\n";
    std::cout << c.Transpose() << '\n';

    return 0;
}
有了用法的代码,class CMatrix 的大体结构就有了,顶多以后会补充一些其它的成员:
/// Interface sketch for an int matrix that owns a heap buffer (pdata_) and
/// therefore declares all five special member functions itself.
class CMatrix
{
    /// Stream output needs access to the protected dimensions.
    friend std::ostream& operator<<( std::ostream& os, const CMatrix& mtx );

public:
    /// Builds a rownum x colnum matrix; `data` supplies the initial values
    /// as a flat list and may be omitted.
    CMatrix( size_t rownum, size_t colnum, std::initializer_list<int> data = {} );

    // Copy and move support.
    CMatrix( const CMatrix& mtx );
    CMatrix( CMatrix&& mtx ) noexcept;
    CMatrix& operator=( const CMatrix& mtx );
    CMatrix& operator=( CMatrix&& mtx ) noexcept;

    ~CMatrix();

    /// Element access by (row, col): read-only value ...
    int at( size_t row, size_t col ) const;
    /// ... or writable reference.
    int& at( size_t row, size_t col );

    /// Matrix product of *this (left operand) and mtx (right operand).
    CMatrix operator*( const CMatrix& mtx );

    /// Transposes the matrix and returns a reference to it.
    CMatrix& Transpose();

protected:
    // Declaration order matters: members are initialised in this order.
    size_t rownum_;
    size_t colnum_;
    int* pdata_;
};
最后,就剩下枯燥的填代码了,没什么好说的了,结束。
随便填了一些,仅供参考:
程序代码:
#include <algorithm>
#include <cstddef>
#include <initializer_list>
#include <iostream>
#include <limits>
#include <stdexcept>
/// Matrix of int stored in one heap buffer, addressed as row*colnum_ + col.
/// The class owns pdata_ and therefore declares all five special members.
class CMatrix
{
    /// Stream output reads the protected dimensions directly.
    friend std::ostream& operator<<( std::ostream& os, const CMatrix& mtx );

public:
    /// Builds a rownum x colnum matrix. `data` supplies initial values as a
    /// flat list; it may be omitted.
    /// @throws std::invalid_argument if either dimension is zero.
    CMatrix( size_t rownum, size_t colnum, std::initializer_list<int> data = {} );

    // Copy and move support.
    CMatrix( const CMatrix& mtx );
    CMatrix( CMatrix&& mtx ) noexcept;
    CMatrix& operator=( const CMatrix& mtx );
    CMatrix& operator=( CMatrix&& mtx ) noexcept;

    ~CMatrix();

    /// Bounds-checked element access.
    /// @throws std::invalid_argument if row or col is out of range.
    int at( size_t row, size_t col ) const;
    int& at( size_t row, size_t col );

    /// Matrix product of *this (left operand) and mtx (right operand).
    /// @throws std::invalid_argument if the shapes are incompatible.
    CMatrix operator*( const CMatrix& mtx );

    /// Transposes the matrix and returns a reference to it.
    /// Declared only: the definition in this file is commented out.
    CMatrix& Transpose();

protected:
    // Declaration order matters: the constructors initialise pdata_ from
    // rownum_ and colnum_, and members are initialised in this order.
    size_t rownum_;
    size_t colnum_;
    int* pdata_;
};
#include <new>
#include <exception>
using namespace std;
/// Constructs a rownum x colnum matrix.
///
/// Every element starts at zero; the values in `data` are then copied into
/// the buffer in order. Surplus values are ignored and missing ones stay zero.
///
/// @param rownum  number of rows, must be non-zero
/// @param colnum  number of columns, must be non-zero
/// @param data    initial values as a flat list
/// @throws std::invalid_argument if a dimension is zero or rownum*colnum
///         does not fit in size_t
/// @throws std::bad_alloc if the buffer cannot be allocated
CMatrix::CMatrix( size_t rownum, size_t colnum, std::initializer_list<int> data )
    : rownum_( rownum ), colnum_( colnum ), pdata_( nullptr )
{
    // Validate BEFORE allocating: the destructor does not run for an object
    // whose constructor throws, so throwing after new[] would leak the buffer.
    if( rownum == 0 || colnum == 0 )
        throw std::invalid_argument( "CMatrix: row and column counts must be non-zero" );

    // A wrapped product would allocate a buffer smaller than at() assumes.
    if( rownum > std::numeric_limits<size_t>::max() / colnum )
        throw std::invalid_argument( "CMatrix: rownum*colnum overflows size_t" );

    const size_t count = rownum * colnum;
    pdata_ = new int[count]();   // the trailing () value-initialises to zero

    std::copy_n( data.begin(), std::min( count, data.size() ), pdata_ );
}
/// Copy constructor: allocates a buffer of the source's element count and
/// copies every element. An allocation failure propagates to the caller as
/// std::bad_alloc.
CMatrix::CMatrix( const CMatrix& mtx )
    : rownum_( mtx.rownum_ )
    , colnum_( mtx.colnum_ )
    , pdata_( new int[ mtx.rownum_ * mtx.colnum_ ] )
{
    const int* const first = mtx.pdata_;
    const int* const last  = first + rownum_ * colnum_;
    std::copy( first, last, pdata_ );
}
CMatrix::CMatrix( CMatrix&& mtx ) noexcept : rownum_(mtx.rownum_), colnum_(mtx.colnum_), pdata_(mtx.pdata_)
{
mtx.rownum_ = 0;
mtx.colnum_ = 0;
mtx.pdata_ = nullptr;
}
/// Copy assignment: makes *this an element-wise copy of mtx, including shape.
///
/// The existing buffer is reused when the element counts match. Otherwise a
/// buffer of the source's size is allocated before the old one is released,
/// so *this is left untouched if the allocation throws.
///
/// @param mtx  matrix to copy from; self-assignment is a no-op
/// @return *this
/// @throws std::bad_alloc if a new buffer cannot be allocated
CMatrix& CMatrix::operator=( const CMatrix& mtx )
{
    if( this == &mtx )
        return *this;

    // Size the buffer from the SOURCE dimensions, not the current ones.
    const size_t newCount = mtx.rownum_ * mtx.colnum_;
    if( newCount != rownum_ * colnum_ )
    {
        int* const fresh = new int[newCount];   // may throw; *this still intact
        delete[] pdata_;
        pdata_ = fresh;
    }

    std::copy( mtx.pdata_, mtx.pdata_ + newCount, pdata_ );

    // Always take over the shape: equal element counts do not imply equal
    // shapes (e.g. 2x3 assigned from 3x2).
    rownum_ = mtx.rownum_;
    colnum_ = mtx.colnum_;
    return *this;
}
/// Move assignment: releases the buffer currently owned by *this, takes over
/// the source's buffer and dimensions, and resets the source to 0 x 0 with a
/// null buffer.
///
/// @param mtx  matrix to steal from; self-assignment is a no-op
/// @return *this
CMatrix& CMatrix::operator=( CMatrix&& mtx ) noexcept
{
    if( this != &mtx )
    {
        // Free our own buffer first; overwriting pdata_ without this would
        // leak it.
        delete[] pdata_;

        rownum_ = mtx.rownum_;
        colnum_ = mtx.colnum_;
        pdata_  = mtx.pdata_;

        mtx.rownum_ = 0;
        mtx.colnum_ = 0;
        mtx.pdata_  = nullptr;
    }
    return *this;
}
/// Destructor: frees the element buffer. A moved-from object holds a null
/// pointer, for which delete[] does nothing.
CMatrix::~CMatrix() { delete[] pdata_; }
/// Bounds-checked read access.
///
/// @param row  zero-based row index, must be < rownum_
/// @param col  zero-based column index, must be < colnum_
/// @return the element at (row, col)
/// @throws std::invalid_argument if either index is out of range
int CMatrix::at( size_t row, size_t col ) const
{
    if( row >= rownum_ || col >= colnum_ )
        throw std::invalid_argument( "CMatrix::at: row or column index out of range" );
    return pdata_[ row*colnum_ + col ];
}
/// Bounds-checked write access.
///
/// @param row  zero-based row index, must be < rownum_
/// @param col  zero-based column index, must be < colnum_
/// @return a reference to the element at (row, col)
/// @throws std::invalid_argument if either index is out of range
int& CMatrix::at( size_t row, size_t col )
{
    if( row >= rownum_ || col >= colnum_ )
        throw std::invalid_argument( "CMatrix::at: row or column index out of range" );
    return pdata_[ row*colnum_ + col ];
}
/// Matrix product: returns (*this) * mtx as a new rownum_ x mtx.colnum_ matrix.
///
/// @param mtx  right-hand operand; its row count must equal this matrix's
///             column count
/// @return the product matrix
/// @throws std::invalid_argument if either operand is empty or the inner
///         dimensions differ
/// @throws std::bad_alloc if the result cannot be allocated
///
/// NOTE(review): this member is declared non-const in the class, so it cannot
/// be called on a const left operand; making it const requires changing the
/// declaration as well.
CMatrix CMatrix::operator*( const CMatrix& mtx )
{
    if( rownum_==0 || colnum_==0 || mtx.rownum_==0 || mtx.colnum_==0 || colnum_!=mtx.rownum_ )
        throw std::invalid_argument( "CMatrix::operator*: operands must be non-empty and left column count must equal right row count" );

    CMatrix ret( rownum_, mtx.colnum_ );
    for( size_t row=0; row!=ret.rownum_; ++row )
    {
        for( size_t col=0; col!=ret.colnum_; ++col )
        {
            // All indices are in range by construction of the loops, so the
            // buffers are indexed directly instead of through checked at().
            int sum = 0;
            for( size_t k=0; k!=colnum_; ++k )
                sum += pdata_[ row*colnum_ + k ] * mtx.pdata_[ k*mtx.colnum_ + col ];
            ret.pdata_[ row*ret.colnum_ + col ] = sum;
        }
    }
    return ret;
}
/// Writes the matrix to `os` one row per line: elements of a row are
/// separated by a single space and each row ends with '\n'.
/// Returns `os` so that output can be chained.
std::ostream& operator<<( std::ostream& os, const CMatrix& mtx )
{
    for( size_t row = 0; row < mtx.rownum_; ++row )
    {
        for( size_t col = 0; col < mtx.colnum_; ++col )
        {
            os << mtx.at( row, col );

            const bool lastInRow = ( col + 1 == mtx.colnum_ );
            if( lastInRow )
                os << '\n';
            else
                os << ' ';
        }
    }
    return os;
}
//CMatrix& CMatrix::Transpose()
//{
// // 未实现,矩阵的原位转置 自己写
//}
// Demo driver: builds a 2x3 and a 3x2 matrix, multiplies them and prints
// both operands and the product.
int main( void )
{
    // 2x3 operand; values are listed row by row.
    CMatrix a{ 2, 3, { 1, 2, 3,
                       4, 5, 6 } };
    // 3x2 operand; values are listed row by row.
    CMatrix b{ 3, 2, { 1, 4,
                       2, 5,
                       3, 6 } };
    CMatrix c( a * b );

    std::cout << "a matrix:\n";
    std::cout << a << '\n';

    std::cout << "b matrix:\n";
    std::cout << b << '\n';

    std::cout << "a*b matrix:\n";
    std::cout << c << '\n';

    // Disabled until CMatrix::Transpose() has a definition.
    //std::cout << "Transpose matrix:\n" << c.Transpose() << '\n';
    return 0;
}