稀疏矩阵乘法问题!
程序代码:
#define MAX 9 typedef struct { int row; int col; int value; }term; term a[MAX]={{6,6,8},{0,0,15},{0,3,22},{0,5,-15},{1,1,11},{1,2,3},{2,3,-6},{4,0,91},{5,2,28}}; term b[MAX]={{6,6,8},{1,0,34},{3,2,15},{1,5,10},{1,2,4},{4,2,3},{5,3,-6},{4,0,41},{5,2,30}}; void mmult(term a[],term b[],term d[])//d是相乘后的矩阵 { int i,j,column,totalb=b[0].value,totald=0; int rows_a=a[0].row,cols_a=a[0].col, totala=a[0].value; //printf("%d",totala); int cols_b=b[0].col; int row_begin=1,row=a[1].row,sum=0; term new_b[MAX];//书上是这样写的:int new_b[MAX][3],这不是明显的错误吗? if(cols_a!=b[0].row) { fprintf(stderr,"Incompatible matrices\n"); exit(1); } transpose(b,new_b); /*prints_matrix_test(b); printf("\n"); prints_matrix_test(new_b); print(new_b);*/ //这一段只是测试转置是否成功 a[totala+1].row=rows_a;// new_b[totalb+1].row=cols_b;// new_b[totalb+1].col=0;//很奇怪这三段代码不会报错?这三段代码有什么用处?看不出~~! for(i=1;i<=totala;)//下面的注释都是自己写上去的,貌似思想是对的,可是结果不对! { column=new_b[1].row;//获取b的列 for(j=1;j<=totalb+1;)// 遍历列 { if(a[i].row!=row)//如果遍历到的行与当前所在行不一致 { storesum(d,&totald,row,column,&sum);//将当前行列元素写入d中 i=row_begin;//重置当前行 for(;new_b[j].row==column;j++)//遍历至下一列 { ; } column=new_b[j].row;//将下一列作为当前列 } else if(new_b[j].row!=column)//如果遍历到的列与当前所在列不一致 { storesum(d,&totald,row,column,&sum);//将当前行列元素写入d中 i=row_begin;//重置当前行 column=new_b[j].row;//重置当前列 } else { switch(COMPARE(a[i].col,new_b[j].col))//比较a当前列的下标与b当前行的下标的大小 { case -1://列下标比行下标小,列下标移至下一列 i++;break; case 0: sum+=(a[i++].value*new_b[j++].value); break; case 1://行下标比列下标小,行下标移至下一行 j++;break; } } } for(;a[i].row==row;i++)//a矩阵跳到下一行 { ; } row_begin=i; row=a[i].row; } d[0].row=rows_a; d[0].col=cols_b; d[0].value=totald; } void storesum(term d[],int *totald,int row,int column,int *sum) { if(*sum) { if((*totald)<MAX) { d[++*totald].row=row; d[*totald].col=column; d[*totald].value=*sum; } else { fprintf(stderr,"Numbers of terms in product exceeds %d\n",MAX); exit(1); } } } int COMPARE(int a,int b) { if(a<b) return -1; else if(a==b) return 0; else return 1; }
下面附上测试程序:
程序代码:
void transpose(term a[],term b[])//转置函数 { int row_terms[7],startingpos[7]; int i,j,num_cols=a[0].col,num_terms=a[0].value; b[0].row=num_cols;b[0].col=a[0].row; b[0].value=num_terms; if(num_terms>0) { for(i=0;i<num_cols;i++) { row_terms[i]=0;//将行中元素个数置为0 } for(i=1;i<=num_terms;i++) { row_terms[a[i].col]++;//记录行中非零元素的个数 } startingpos[0]=1; for(i=1;i<num_cols;i++) { startingpos[i]=startingpos[i-1]+row_terms[i-1]; } for(i=1;i<=num_terms;i++) { j=startingpos[a[i].col]++; b[j].row=a[i].col; b[j].col=a[i].row; b[j].value=a[i].value; } printf("\n"); } } void prints_matrix_test(term *arry)//打印矩阵函数 { int i,j; int k=1; for(i=0;i<(*arry).row;++i) { for(j=0;j<(*arry).col;++j) { while(k<MAX)//通过循环遍历到符合位置的元素并打印 { if(i==(*(arry+k)).row&&j==(*(arry+k)).col) { printf("%4d",(*(arry+k)).value); break; } ++k; } if(k==MAX) { printf("%4d",0); } k=1;//k重置,为下一次循环做准备 } printf("\n"); } } int print(term arry[])//打印三元组函数 { int i; for(i=0;i<MAX;++i) { printf("%d %d %d \n",arry[i].row,arry[i].col,arry[i].value); } return 0; }