So, I spent quite a while helping someone on Stack Overflow prove the correctness of a matrix multiplication algorithm. Once we got to a point where they were happy with it, they thanked me and then deleted their question. I take this, and the fact that in hindsight their understanding did not seem to match up with the code they had, to mean that they were cheating on some university coursework. It took me quite a lot of effort to write the answer, and I wasn’t doing it primarily to help them specifically; I was doing it to help Dafny become more popular. By which I mean: it was only worth the effort of writing such a detailed answer if the answer could help many other people.
Unfortunately, I don’t have a copy of the answer I wrote. But I do have a copy of all the code. So here it is. If you are a lecturer and this is your coursework question then I can be contacted via the “about” page.
First the question:
http://rise4fun.com/Dafny/Bztr
method Main()
{
  var m1: array2<int>, m2: array2<int>, m3: array2<int>;
  m1 := new int[2,3];
  m2 := new int[3,1];
  m1[0,0] := 1; m1[0,1] := 2; m1[0,2] := 3;
  m1[1,0] := 4; m1[1,1] := 5; m1[1,2] := 6;
  m2[0,0] := 7;
  m2[1,0] := 8;
  m2[2,0] := 9;
  m3 := Multiply'(m1, m2);
  PrintMatrix(m1);
  print "\n*\n";
  PrintMatrix(m2);
  print "\n=\n";
  PrintMatrix(m3);
}

method PrintMatrix(m: array2<int>)
  requires m != null
{
  var i: nat := 0;
  while (i < m.Length0)
  {
    var j: nat := 0;
    print "\n";
    while (j < m.Length1)
    {
      print m[i,j];
      print "\t";
      j := j + 1;
    }
    i := i + 1;
  }
  print "\n";
}

predicate MM(m1: array2<int>, m2: array2<int>, m3: array2<int>)
{ // m3 is the result of multiplying the matrix m1 by the matrix m2
  m1 != null && m2 != null && m3 != null &&
  m1.Length1 == m2.Length0 && m3.Length0 == m1.Length0 && m3.Length1 == m2.Length1 &&
  forall i,j :: 0 <= i < m3.Length0 && 0 <= j < m3.Length1 ==> m3[i,j] == RowColumnProduct(m1,m2,i,j)
}

function RowColumnProduct(m1: array2<int>, m2: array2<int>, row: nat, column: nat): int
  requires m1 != null && m2 != null && m1.Length1 == m2.Length0
  requires row < m1.Length0 && column < m2.Length1
{
  RowColumnProductFrom(m1, m2, row, column, 0)
}

function RowColumnProductFrom(m1: array2<int>, m2: array2<int>, row: nat, column: nat, k: nat): int
  requires m1 != null && m2 != null && k <= m1.Length1 == m2.Length0
  requires row < m1.Length0 && column < m2.Length1
  decreases m1.Length1 - k
{
  if k == m1.Length1 then
    0
  else
    m1[row,k]*m2[k,column] + RowColumnProductFrom(m1, m2, row, column, k+1)
}

function RowColumnProductTo(m1: array2<int>, m2: array2<int>, row: nat, column: nat, k: nat,i:nat): int
  requires m1 != null && m2 != null && k <= m1.Length1 == m2.Length0
  requires row < m1.Length0 && column < m2.Length1 && i < m1.Length1 == m2.Length0
  requires k<=i
  decreases i - k
{
  if k == i then
    0
  else
    m1[row,k]*m2[k,column] + RowColumnProductTo(m1, m2, row, column, k+1,i)
}

predicate MMROW(m1: array2<int>, m2: array2<int>, m3: array2<int>,row:nat,col:nat)
{ // m3 is the result of multiplying the matrix m1 by the matrix m2
  m1 != null && m2 != null && m3 != null &&
  m1.Length1 == m2.Length0 && m3.Length0 == m1.Length0 && m3.Length1 == m2.Length1 &&
  row <= m1.Length0 && col <= m2.Length1 &&
  forall i,j :: 0 <= i < row && 0 <= j < col ==> m3[i,j] == RowColumnProduct(m1,m2,i,j)
}

predicate MMCOL(m1: array2<int>, m2: array2<int>, m3: array2<int>,row:nat,col:nat)
{ // m3 is the result of multiplying the matrix m1 by the matrix m2
  m1 != null && m2 != null && m3 != null &&
  m1.Length1 == m2.Length0 && m3.Length0 == m1.Length0 && m3.Length1 == m2.Length1 &&
  row <= m1.Length0 && col <= m2.Length1 &&
  forall i,j :: 0 <= i < row && 0 <= j < col ==> m3[i,j] == RowColumnProduct(m1,m2,i,j)
}

predicate MMI(m1: array2<int>, m2: array2<int>, m3: array2<int>,row:nat,col:nat,i:nat)
{ // m3 is the result of multiplying the matrix m1 by the matrix m2
  m1 != null && m2 != null && m3 != null &&
  m1.Length1 == m2.Length0 && m3.Length0 == m1.Length0 && m3.Length1 == m2.Length1 &&
  row < m1.Length0 && col < m2.Length1 && 0<=i<m1.Length1 &&
  forall n,j :: 0 <= n < row && 0 <= j < col ==> m3[n,j] == RowColumnProduct(m1,m2,n,j)
  && m3[row,col] == RowColumnProductTo(m1, m2, row, col ,0,i)
}

method Multiply'(m1: array2<int>, m2: array2<int>) returns (m3: array2<int>)
  requires m1 != null && m2 != null
  requires m1.Length1 > 0 && m2.Length0 > 0
  requires m1.Length1 == m2.Length0
  ensures MM(m1, m2, m3)
{
  m3 := new int[m1.Length0, m2.Length1];
  var row:nat := 0;
  var col:nat := 0;
  var i:nat := 0;
  while(row != m1.Length0)
    invariant MMROW(m1, m2, m3,row, col)
    invariant (0<=row<= m1.Length0)
    decreases m1.Length0 - row
  {
    while(col != m2.Length1)
      invariant MMCOL(m1, m2, m3,row, col)
      invariant (0<=col<= m2.Length1)
      decreases m2.Length1 - col
    {
      while(i != m1.Length1)
        invariant MMI(m1, m2, m3,row, col,i)
        invariant (i<= m1.Length1==m2.Length0)&&(0<=col<= m2.Length1)&&(0<=row<= m1.Length0)
        decreases m1.Length1 - i
      {
        m3[row,col]:= m3[row,col]+(m1[row,i]*m2[i,col]);
        i := i+1;
      }
      col := col+1;
      i := 0;
    }
    row := row+1;
    col:= 0;
  }
}
Now several different solutions.
The questioner insisted that I not change their definition of the main MM predicate, even though the direction of recursion in the predicate is opposite to the direction of iteration in the while loop: RowColumnProductFrom peels terms off the front of the sum, while the loop adds terms onto the end of a running total. The proof strategy I showed them was to define a new predicate that did recurse in the same direction as the while loop, and then prove the equivalence of the two predicates.
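The same strategy can be sketched on a smaller problem: summing a sequence instead of a row and column. The names SumFrom, SumTo, SumSplit and SumFromEqualsSumTo are made up for this illustration and are not part of the original code. The sketch has not been checked against any particular Dafny release.

```dafny
// Recurses in the "wrong" direction for a loop: a[k] + (a[k+1] + (... + 0)).
// Plays the role of RowColumnProductFrom.
function SumFrom(a: seq<int>, k: nat): int
  requires k <= |a|
  decreases |a| - k
{
  if k == |a| then 0 else a[k] + SumFrom(a, k + 1)
}

// Recurses in the same direction as an accumulating loop: ((0 + a[0]) + a[1]) + ...
// Plays the role of RowColumnProductForCount.
function SumTo(a: seq<int>, n: nat): int
  requires n <= |a|
{
  if n == 0 then 0 else SumTo(a, n - 1) + a[n - 1]
}

// Splitting the sum at any point k gives the same total.
lemma SumSplit(a: seq<int>, k: nat)
  requires k <= |a|
  ensures SumTo(a, k) + SumFrom(a, k) == SumTo(a, |a|)
  decreases |a| - k
{
  if k < |a| {
    // Moving a[k] from the right-hand part to the left-hand part
    // does not change the total.
    SumSplit(a, k + 1);
  }
}

// The equivalence of the two definitions is the split at k == 0.
lemma SumFromEqualsSumTo(a: seq<int>)
  ensures SumFrom(a, 0) == SumTo(a, |a|)
{
  SumSplit(a, 0);
}
```

A loop that accumulates a running total can then use SumTo in its invariant and call the equivalence lemma once at the end, which is the shape of the solution that follows.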
http://rise4fun.com/Dafny/noVy
method Main()
{
  var m1: array2<int>, m2: array2<int>, m3: array2<int>;
  m1 := new int[2,3];
  m2 := new int[3,1];
  m1[0,0] := 1; m1[0,1] := 2; m1[0,2] := 3;
  m1[1,0] := 4; m1[1,1] := 5; m1[1,2] := 6;
  m2[0,0] := 7;
  m2[1,0] := 8;
  m2[2,0] := 9;
  m3 := Multiply'(m1, m2);
  PrintMatrix(m1);
  print "\n*\n";
  PrintMatrix(m2);
  print "\n=\n";
  PrintMatrix(m3);
}

method PrintMatrix(m: array2<int>)
  requires m != null
{
  var i: nat := 0;
  while (i < m.Length0)
  {
    var j: nat := 0;
    print "\n";
    while (j < m.Length1)
    {
      print m[i,j];
      print "\t";
      j := j + 1;
    }
    i := i + 1;
  }
  print "\n";
}

predicate AllowedToMultiply(m1: array2<int>, m2: array2<int>)
{
  m1 != null && m2 != null && m1.Length1 == m2.Length0
}

predicate AllowedToMultiplyInto(m1: array2<int>, m2: array2<int>, m3: array2<int>)
{
  AllowedToMultiply(m1,m2) && m3 != null &&
  m3.Length0 == m1.Length0 && m3.Length1 == m2.Length1
}

predicate MM(m1: array2<int>, m2: array2<int>, m3: array2<int>)
{ // m3 is the result of multiplying the matrix m1 by the matrix m2
  AllowedToMultiplyInto(m1,m2,m3) &&
  forall i,j :: 0 <= i < m3.Length0 && 0 <= j < m3.Length1 ==> m3[i,j] == RowColumnProduct(m1,m2,i,j)
}

function RowColumnProduct(m1: array2<int>, m2: array2<int>, row: nat, column: nat): int
  requires AllowedToMultiply(m1,m2)
  requires row < m1.Length0 && column < m2.Length1
{
  RowColumnProductFrom(m1, m2, row, column, 0)
}

function RowColumnProductFrom(m1: array2<int>, m2: array2<int>, row: nat, column: nat, k: nat): int
  requires AllowedToMultiply(m1,m2)
  requires row < m1.Length0 && column < m2.Length1
  requires k <= m1.Length1
  decreases m1.Length1 - k
{
  if k == m1.Length1 then
    0
  else
    m1[row,k]*m2[k,column] + RowColumnProductFrom(m1, m2, row, column, k+1)
}

function RowColumnProductTo(m1: array2<int>, m2: array2<int>, row: nat, column: nat, k: nat,i:nat): int
  requires AllowedToMultiply(m1,m2)
  requires row < m1.Length0 && column < m2.Length1 && i < m1.Length1 == m2.Length0
  requires k<=i
  decreases i - k
{
  if k == i then
    0
  else
    m1[row,k]*m2[k,column] + RowColumnProductTo(m1, m2, row, column, k+1,i)
}

function RowColumnProductForCount(m1: array2<int>, m2: array2<int>, row: nat, column: nat, n:nat): int
  requires AllowedToMultiply(m1, m2)
  requires row < m1.Length0 && column < m2.Length1 && n <= m1.Length1
{
  if n == 0 then
    0
  else
    RowColumnProductForCount(m1, m2, row, column, n-1) + m1[row,n-1]*m2[n-1,column]
}

predicate MMROW(m1: array2<int>, m2: array2<int>, m3: array2<int>, rown:nat)
  requires AllowedToMultiplyInto(m1, m2, m3)
  requires rown <= m1.Length0
{
  forall r:nat,c:nat :: r < rown && c < m2.Length1 ==>
    m3[r,c] == RowColumnProductForCount(m1,m2,r,c,m1.Length1)
}

predicate MMCOL(m1: array2<int>, m2: array2<int>, m3: array2<int>,row:nat,coln:nat)
  requires AllowedToMultiplyInto(m1, m2, m3)
  requires row < m1.Length0 && coln <= m2.Length1
{
  forall c:nat :: c < coln ==>
    m3[row,c] == RowColumnProductForCount(m1,m2,row,c,m1.Length1)
}

predicate MMI(m1: array2<int>, m2: array2<int>, m3: array2<int>,row:nat,col:nat,n:nat)
  requires AllowedToMultiplyInto(m1, m2, m3)
  requires row < m1.Length0 && col < m2.Length1 && n<=m1.Length1
{
  m3[row,col] == RowColumnProductForCount(m1, m2, row, col, n)
}

method Multiply'(m1: array2<int>, m2: array2<int>) returns (m3: array2<int>)
  requires AllowedToMultiply(m1, m2)
  ensures MM(m1, m2, m3)
{
  m3 := new int[m1.Length0, m2.Length1];
  var row:nat := 0;
  // loop over rows of m1
  while(row < m1.Length0)
    invariant row <= m1.Length0
    invariant forall rn:nat :: rn <= row ==> MMROW(m1, m2, m3, rn)
    modifies m3
  {
    assert MMROW(m1, m2, m3, row);
    // loop over columns of m2
    var col:nat := 0;
    while(col < m2.Length1)
      invariant col <= m2.Length1
      invariant forall rn:nat :: rn <= row ==> MMROW(m1, m2, m3, rn)
      invariant forall n:nat :: n <= col ==> MMCOL(m1, m2, m3,row, n)
    {
      assert MMCOL(m1, m2, m3, row, col);
      // loop over elements of m1 row / m2 column
      var i:nat := 0;
      m3[row,col] := 0;
      while(i < m1.Length1)
        invariant i <= m1.Length1
        invariant forall rn:nat :: rn < row ==> MMROW(m1, m2, m3, rn)
        invariant forall c:nat :: c < col ==> MMCOL(m1, m2, m3, row, c)
        invariant forall j:nat :: j <= i ==> MMI(m1, m2, m3, row, col, j)
      {
        assert MMI(m1, m2, m3, row, col, i);
        m3[row,col]:= m3[row,col]+(m1[row,i]*m2[i,col]);
        i := i+1;
        assert MMI(m1, m2, m3, row, col, i);
      }
      assert MMI(m1, m2, m3, row, col, m1.Length1);
      assert m3[row,col] == RowColumnProductForCount(m1,m2,row,col,m1.Length1);
      col := col+1;
      assert MMCOL(m1, m2, m3, row, col);
    }
    assert MMCOL(m1, m2, m3, row, m2.Length1);
    row := row+1;
    assert MMROW(m1, m2, m3, row);
  }
  assert MMROW(m1, m2, m3, m1.Length0);
  MMROWImpliesMM(m1, m2, m3);
}

lemma MMROWImpliesMM(m1: array2<int>, m2: array2<int>, m3: array2<int>)
  requires AllowedToMultiplyInto(m1,m2,m3)
  requires MMROW(m1, m2, m3, m1.Length0)
  ensures MM(m1, m2, m3)
{
  assert forall r:nat,c:nat :: r < m1.Length0 && c < m2.Length1 ==>
    m3[r,c] == RowColumnProductForCount(m1,m2,r,c,m1.Length1);
  forall r:nat,c:nat | r < m1.Length0 && c < m2.Length1
    ensures m3[r,c] == RowColumnProduct(m1,m2,r,c)
  {
    assert m3[r,c] == RowColumnProductForCount(m1,m2,r,c,m1.Length1);
    RowColumnProductForCountImpliesRowColumnProduct(m1, m2, m3, r, c);
  }
  assert forall r:nat,c:nat :: r < m3.Length0 && c < m3.Length1 ==>
    m3[r,c] == RowColumnProduct(m1,m2,r,c);
}

lemma RowColumnProductForCountImpliesRowColumnProduct(m1: array2<int>, m2: array2<int>, m3: array2<int>, r:nat, c:nat)
  requires AllowedToMultiplyInto(m1,m2,m3)
  requires r < m1.Length0 && c < m2.Length1;
  requires m3[r,c] == RowColumnProductForCount(m1,m2,r,c,m1.Length1)
  ensures m3[r,c] == RowColumnProduct(m1,m2,r,c)
{
  assert RowColumnProduct(m1,m2,r,c) == RowColumnProductFrom(m1,m2,r,c,0);
  var i:nat := 0;
  var total := RowColumnProductForCount(m1,m2,r,c,m1.Length1);
  while i < m1.Length1
    invariant i <= m1.Length1
    invariant total == RowColumnProductForCount(m1,m2,r,c,m1.Length1-i) + RowColumnProductFrom(m1,m2,r,c,m1.Length1-i)
  {
    i := i+1;
  }
}
I suggested an alternative strategy which uses Dafny’s forall statement to implement the matrix multiplication. Since this is not a loop it does not require any invariants to be given in order to verify. This is the best solution if you are just trying to get a matrix multiply working.
http://rise4fun.com/Dafny/mgoeu
method Main()
{
  var m1: array2<int>, m2: array2<int>, m3: array2<int>;
  m1 := new int[2,3];
  m2 := new int[3,1];
  m1[0,0] := 1; m1[0,1] := 2; m1[0,2] := 3;
  m1[1,0] := 4; m1[1,1] := 5; m1[1,2] := 6;
  m2[0,0] := 7;
  m2[1,0] := 8;
  m2[2,0] := 9;
  m3 := Multiply'(m1, m2);
  PrintMatrix(m1);
  print "\n*\n";
  PrintMatrix(m2);
  print "\n=\n";
  PrintMatrix(m3);
}

method PrintMatrix(m: array2<int>)
  requires m != null
{
  var i: nat := 0;
  while (i < m.Length0)
  {
    var j: nat := 0;
    print "\n";
    while (j < m.Length1)
    {
      print m[i,j];
      print "\t";
      j := j + 1;
    }
    i := i + 1;
  }
  print "\n";
}

predicate AllowedToMultiply(m1: array2<int>, m2: array2<int>)
{
  m1 != null && m2 != null && m1.Length1 == m2.Length0
}

predicate MM(m1: array2<int>, m2: array2<int>, m3: array2<int>)
  requires AllowedToMultiply(m1, m2)
{ // m3 is the result of multiplying the matrix m1 by the matrix m2
  m3 != null && m3.Length0 == m1.Length0 && m3.Length1 == m2.Length1 &&
  forall r:nat,c:nat :: r < m3.Length0 && c < m3.Length1 ==>
    m3[r,c] == RowColumnProductTo(m1,m2,r,c,m1.Length1)
}

function method RowColumnProductTo(m1: array2<int>, m2: array2<int>, row: nat, column: nat, n:nat): int
  requires AllowedToMultiply(m1, m2)
  requires row < m1.Length0 && column < m2.Length1 && n <= m1.Length1
{
  if n == 0 then
    0
  else
    RowColumnProductTo(m1, m2, row, column, n-1) + m1[row,n-1]*m2[n-1,column]
}

method Multiply'(m1: array2<int>, m2: array2<int>) returns (m3: array2<int>)
  requires AllowedToMultiply(m1, m2)
  ensures MM(m1, m2, m3)
{
  m3 := new int[m1.Length0, m2.Length1];
  forall r:nat,c:nat | r < m3.Length0 && c < m3.Length1
  {
    m3[r,c] := RowColumnProductTo(m1,m2,r,c,m1.Length1);
  }
}
In this version I take the approach of changing the definition of the MM predicate to work in the same direction as the iteration. If you don’t want to go down the route of using the forall statement, then in my opinion it is usually most productive to get your recursive and iterative definitions to bracket the same way. For example, have them both do ((((a+b)+c)+d)+e); don’t have one of them do (a+(b+(c+(d+e)))).
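Spelled out for a single entry of the product, the two bracketings correspond to two recursive definitions. The symbols here are introduced only for illustration: $a_k$ stands for m1[row,k], $b_k$ for m2[k,column], $n$ for m1.Length1, and $S$ and $T$ for the partial sums.

```latex
\begin{align*}
  S_0 &= 0, & S_{k+1} &= S_k + a_k b_k
    && \text{(left-bracketed: a count-based definition)} \\
  T_n &= 0, & T_k &= a_k b_k + T_{k+1}
    && \text{(right-bracketed: like RowColumnProductFrom)}
\end{align*}
```

The loop body m3[row,col] := m3[row,col] + m1[row,i]*m2[i,col] turns $S_i$ into $S_{i+1}$ by a single unfolding of the definition, so an invariant of the form m3[row,col] == $S_i$ is preserved without further help. With the right-bracketed $T$, the running total after $i$ steps is not any $T_k$, so the proof needs a separate argument that $S_i + T_i = T_0$ for every $i$, and in particular that $S_n = T_0$.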
method Main()
{
  var m1: array2<int>, m2: array2<int>, m3: array2<int>;
  m1 := new int[2,3];
  m2 := new int[3,1];
  m1[0,0] := 1; m1[0,1] := 2; m1[0,2] := 3;
  m1[1,0] := 4; m1[1,1] := 5; m1[1,2] := 6;
  m2[0,0] := 7;
  m2[1,0] := 8;
  m2[2,0] := 9;
  m3 := Multiply'(m1, m2);
  PrintMatrix(m1);
  print "\n*\n";
  PrintMatrix(m2);
  print "\n=\n";
  PrintMatrix(m3);
}

method PrintMatrix(m: array2<int>)
  requires m != null
{
  var i: nat := 0;
  while (i < m.Length0)
  {
    var j: nat := 0;
    print "\n";
    while (j < m.Length1)
    {
      print m[i,j];
      print "\t";
      j := j + 1;
    }
    i := i + 1;
  }
  print "\n";
}

predicate AllowedToMultiply(m1: array2<int>, m2: array2<int>)
{
  m1 != null && m2 != null && m1.Length1 == m2.Length0
}

predicate AllowedToMultiplyInto(m1: array2<int>, m2: array2<int>, m3: array2<int>)
{
  AllowedToMultiply(m1,m2) && m3 != null &&
  m3.Length0 == m1.Length0 && m3.Length1 == m2.Length1
}

predicate MM(m1: array2<int>, m2: array2<int>, m3: array2<int>)
  requires AllowedToMultiply(m1, m2)
{ // m3 is the result of multiplying the matrix m1 by the matrix m2
  m3 != null && m3.Length0 == m1.Length0 && m3.Length1 == m2.Length1 &&
  forall r:nat,c:nat :: r < m3.Length0 && c < m3.Length1 ==>
    m3[r,c] == RowColumnProductTo(m1,m2,r,c,m1.Length1)
}

function RowColumnProductTo(m1: array2<int>, m2: array2<int>, row: nat, column: nat, n:nat): int
  requires AllowedToMultiply(m1, m2)
  requires row < m1.Length0 && column < m2.Length1 && n <= m1.Length1
{
  if n == 0 then
    0
  else
    RowColumnProductTo(m1, m2, row, column, n-1) + m1[row,n-1]*m2[n-1,column]
}

predicate MMROW(m1: array2<int>, m2: array2<int>, m3: array2<int>, rown:nat)
  requires AllowedToMultiplyInto(m1, m2, m3)
  requires rown <= m1.Length0
{
  forall r:nat,c:nat :: r < rown && c < m2.Length1 ==>
    m3[r,c] == RowColumnProductTo(m1,m2,r,c,m1.Length1)
}

predicate MMCOL(m1: array2<int>, m2: array2<int>, m3: array2<int>,row:nat,coln:nat)
  requires AllowedToMultiplyInto(m1, m2, m3)
  requires row < m1.Length0 && coln <= m2.Length1
{
  forall c:nat :: c < coln ==>
    m3[row,c] == RowColumnProductTo(m1,m2,row,c,m1.Length1)
}

predicate MMI(m1: array2<int>, m2: array2<int>, m3: array2<int>,row:nat,col:nat,n:nat)
  requires AllowedToMultiplyInto(m1, m2, m3)
  requires row < m1.Length0 && col < m2.Length1 && n<=m1.Length1
{
  m3[row,col] == RowColumnProductTo(m1, m2, row, col, n)
}

method Multiply'(m1: array2<int>, m2: array2<int>) returns (m3: array2<int>)
  requires AllowedToMultiply(m1, m2)
  ensures MM(m1, m2, m3)
{
  m3 := new int[m1.Length0, m2.Length1];
  var row:nat := 0;
  // loop over rows of m1
  while(row < m1.Length0)
    invariant row <= m1.Length0
    invariant forall rn:nat :: rn <= row ==> MMROW(m1, m2, m3, rn)
    modifies m3
  {
    assert MMROW(m1, m2, m3, row);
    // loop over columns of m2
    var col:nat := 0;
    while(col < m2.Length1)
      invariant col <= m2.Length1
      invariant forall rn:nat :: rn <= row ==> MMROW(m1, m2, m3, rn)
      invariant forall n:nat :: n <= col ==> MMCOL(m1, m2, m3,row, n)
    {
      assert MMCOL(m1, m2, m3, row, col);
      // loop over elements of m1 row / m2 column
      var i:nat := 0;
      m3[row,col] := 0;
      while(i < m1.Length1)
        invariant i <= m1.Length1
        invariant forall rn:nat :: rn < row ==> MMROW(m1, m2, m3, rn)
        invariant forall c:nat :: c < col ==> MMCOL(m1, m2, m3, row, c)
        invariant forall j:nat :: j <= i ==> MMI(m1, m2, m3, row, col, j)
      {
        assert MMI(m1, m2, m3, row, col, i);
        m3[row,col]:= m3[row,col]+(m1[row,i]*m2[i,col]);
        i := i+1;
        assert MMI(m1, m2, m3, row, col, i);
      }
      assert MMI(m1, m2, m3, row, col, m1.Length1);
      assert m3[row,col] == RowColumnProductTo(m1,m2,row,col,m1.Length1);
      col := col+1;
      assert MMCOL(m1, m2, m3, row, col);
    }
    assert MMCOL(m1, m2, m3, row, m2.Length1);
    row := row+1;
    assert MMROW(m1, m2, m3, row);
  }
  assert MMROW(m1, m2, m3, m1.Length0);
}
Here are various other intermediate stages and suggestions. We went through quite a few iterations, because their requirements on what I could and couldn’t change were not intuitive. I now presume that was because they didn’t want to say to me “here is the coursework specification, look it says don’t change that bit”.
In fact, looking in more detail, one of the attempts that they shared with me has this comment in it
// TODO: continue here, multiplying m1 by m2 placing the result in m3 such that MM(m1, m2, m3) will become true
That looks to me just like the kind of thing that I have seen written in many other coursework assignments. Hmm. Sucks.
http://rise4fun.com/Dafny/5F2R2
http://rise4fun.com/Dafny/6PNo
http://rise4fun.com/Dafny/mXi49
http://rise4fun.com/Dafny/VtXb
http://rise4fun.com/Dafny/mgoeu
http://rise4fun.com/Dafny/1Yslx
http://rise4fun.com/Dafny/WUop
http://rise4fun.com/Dafny/RPnU
http://rise4fun.com/Dafny/ji9A