microsoft/qdk
Public · mirrored from https://github.com/microsoft/qdk · Available
source/paulimer/tests/bitmatrix_test.rs
423 lines · mode: code
| 1 | // Copyright (c) Microsoft Corporation. |
| 2 | // Licensed under the MIT License. |
| 3 | |
| 4 | use paulimer::bits::bitmatrix::{directly_summed, kernel_basis_matrix, rref_with_transforms}; |
| 5 | use paulimer::bits::tiny_matrix::{tiny_matrix_from_bitmatrix, tiny_matrix_rref}; |
| 6 | use paulimer::bits::{BitMatrix, BitVec, Bitwise, BitwiseBinaryOps, WORD_COUNT_DEFAULT}; |
| 7 | use proptest::prelude::*; |
| 8 | use rand::prelude::*; |
| 9 | use rand::Rng; |
| 10 | use rustc_hash::FxHashSet; |
| 11 | use sorted_iter::assume::AssumeSortedByItemExt; |
| 12 | use sorted_iter::SortedIterator; |
| 13 | use std::collections::BTreeMap; |
| 14 | use std::str::FromStr; |
| 15 | |
| 16 | proptest! { |
| 17 | #[test] |
| 18 | fn shape(rowcount in 0..100usize, columncount in 0..100usize) { |
| 19 | let matrix = BitMatrix::<WORD_COUNT_DEFAULT>::with_shape(rowcount, columncount); |
| 20 | assert_eq!(matrix.rowcount(), rowcount); |
| 21 | assert_eq!(matrix.columncount(), columncount); |
| 22 | assert_eq!(matrix.shape(), (rowcount, columncount)); |
| 23 | } |
| 24 | |
| 25 | #[test] |
| 26 | fn zeros(rowcount in 0..100usize, columncount in 0..100usize) { |
| 27 | let matrix = BitMatrix::<WORD_COUNT_DEFAULT>::zeros(rowcount, columncount); |
| 28 | for irow in 0..matrix.rowcount() { |
| 29 | for icol in 0..matrix.columncount() { |
| 30 | assert!(!matrix[(irow, icol)]); |
| 31 | } |
| 32 | } |
| 33 | } |
| 34 | |
| 35 | #[test] |
| 36 | fn indexing(matrix in arbitrary_bitmatrix(100)) { |
| 37 | for irow in 0..matrix.rowcount() { |
| 38 | for icol in 0..matrix.columncount() { |
| 39 | assert_eq!(matrix[(irow, icol)], matrix[[irow, icol]]); |
| 40 | } |
| 41 | } |
| 42 | } |
| 43 | |
| 44 | #[test] |
| 45 | fn clone(matrix in arbitrary_bitmatrix(100)) { |
| 46 | assert_eq!(matrix, matrix.clone()); |
| 47 | } |
| 48 | |
| 49 | #[test] |
| 50 | fn swap_rows(matrix in nonempty_bitmatrix(100), raw_row_indexes in (0..100usize, 0..100usize)) { |
| 51 | let row_indexes = [raw_row_indexes.0 % matrix.rowcount(), raw_row_indexes.1 % matrix.rowcount()]; |
| 52 | let mut swapped = matrix.clone(); |
| 53 | swapped.swap_rows(row_indexes[0], row_indexes[1]); |
| 54 | for column_index in 0..matrix.columncount() { |
| 55 | assert_eq!(matrix[[row_indexes[0], column_index]], swapped[[row_indexes[1], column_index]]); |
| 56 | } |
| 57 | let row_indexes = row_indexes.into_iter().collect::<rustc_hash::FxHashSet<usize>>(); |
| 58 | for row_index in (0..matrix.rowcount()).collect::<FxHashSet<usize>>().difference(&row_indexes) { |
| 59 | for column_index in 0..matrix.columncount() { |
| 60 | assert_eq!(matrix[[*row_index, column_index]], swapped[[*row_index, column_index]]); |
| 61 | } |
| 62 | } |
| 63 | } |
| 64 | |
| 65 | #[test] |
| 66 | fn swap_columns(matrix in nonempty_bitmatrix(100), raw_column_indexes in (0..100usize, 0..100usize)) { |
| 67 | let column_indexes = [raw_column_indexes.0 % matrix.columncount(), raw_column_indexes.1 % matrix.columncount()]; |
| 68 | let mut swapped = matrix.clone(); |
| 69 | swapped.swap_columns(column_indexes[0], column_indexes[1]); |
| 70 | for row_index in 0..matrix.rowcount() { |
| 71 | assert_eq!(matrix[[row_index, column_indexes[0]]], swapped[[row_index, column_indexes[1]]]); |
| 72 | } |
| 73 | let column_indexes = column_indexes.into_iter().collect::<rustc_hash::FxHashSet<usize>>(); |
| 74 | for column_index in (0..matrix.columncount()).collect::<FxHashSet<usize>>().difference(&column_indexes) { |
| 75 | for row_index in 0..matrix.rowcount() { |
| 76 | assert_eq!(matrix[[row_index, *column_index]], swapped[[row_index, *column_index]]); |
| 77 | } |
| 78 | } |
| 79 | } |
| 80 | |
| 81 | #[test] |
| 82 | fn addition((left, right) in equal_shape_bitmatrices(100)) { |
| 83 | let sum = &left + &right; |
| 84 | for irow in 0..left.rowcount() { |
| 85 | for icol in 0..right.columncount() { |
| 86 | let index = (irow, icol); |
| 87 | assert_eq!(sum[index], left[index] ^ right[index]); |
| 88 | } |
| 89 | } |
| 90 | assert_eq!(sum, &right + &left); |
| 91 | } |
| 92 | |
| 93 | #[test] |
| 94 | fn addition_inplace((mut left, right) in equal_shape_bitmatrices(100)) { |
| 95 | let sum = &left + &right; |
| 96 | left += &right; |
| 97 | assert_eq!(sum, left); |
| 98 | } |
| 99 | |
| 100 | #[test] |
| 101 | fn xor((left, right) in equal_shape_bitmatrices(100)) { |
| 102 | assert_eq!(&left ^ &right, &left + &right); |
| 103 | } |
| 104 | |
| 105 | #[test] |
| 106 | fn xor_inplace((mut left, right) in equal_shape_bitmatrices(100)) { |
| 107 | let xor = &left ^ &right; |
| 108 | left ^= &right; |
| 109 | assert_eq!(xor, left); |
| 110 | } |
| 111 | |
| 112 | #[test] |
| 113 | fn and((left, right) in equal_shape_bitmatrices(100)) { |
| 114 | let and = &left & &right; |
| 115 | for irow in 0..left.rowcount() { |
| 116 | for icol in 0..left.columncount() { |
| 117 | let index = (irow, icol); |
| 118 | assert_eq!(and[index], left[index] & right[index]); |
| 119 | } |
| 120 | } |
| 121 | assert_eq!(and, &right & &left); |
| 122 | } |
| 123 | |
| 124 | |
| 125 | #[test] |
| 126 | fn and_inplace((mut left, right) in equal_shape_bitmatrices(100)) { |
| 127 | let and = &left & &right; |
| 128 | left &= &right; |
| 129 | assert_eq!(and, left); |
| 130 | } |
| 131 | |
| 132 | #[test] |
| 133 | fn equality(left in arbitrary_bitmatrix(100), right in arbitrary_bitmatrix(100)) { |
| 134 | let mut are_equal = left.shape() == right.shape(); |
| 135 | if are_equal { |
| 136 | for irow in 0..left.rowcount() { |
| 137 | for icol in 0..right.columncount() { |
| 138 | let index = (irow, icol); |
| 139 | are_equal &= left[index] == right[index]; |
| 140 | } |
| 141 | } |
| 142 | } |
| 143 | assert_eq!(left == right, are_equal); |
| 144 | } |
| 145 | |
| 146 | #[test] |
| 147 | fn transpose(matrix in arbitrary_bitmatrix(100)) { |
| 148 | let transposed = matrix.transposed(); |
| 149 | for row in 0..matrix.rowcount() { |
| 150 | for column in 0..matrix.columncount() { |
| 151 | assert_eq!(matrix[(row, column)], transposed[(column, row)]); |
| 152 | } |
| 153 | } |
| 154 | } |
| 155 | |
| 156 | #[test] |
| 157 | fn inverse(matrix in invertible_bitmatrix(100)) { |
| 158 | let inverted = matrix.inverted(); |
| 159 | let identity = BitMatrix::identity(matrix.rowcount()); |
| 160 | assert_eq!(&matrix * &inverted, identity); |
| 161 | } |
| 162 | |
| 163 | #[test] |
| 164 | fn echelon_form(matrix in arbitrary_bitmatrix(100)) { |
| 165 | let mut echeloned = matrix.clone(); |
| 166 | let profile = echeloned.echelonize(); |
| 167 | assert!(is_rref(&echeloned, &profile)); |
| 168 | assert!(preserves_rowspan_of(&matrix, &echeloned)); |
| 169 | } |
| 170 | |
| 171 | #[test] |
| 172 | fn tiny_matrix_echelon_form(matrix in fixed_size_bitmatrix(32,60)) { |
| 173 | let mut echeloned = matrix.clone(); |
| 174 | let _ = echeloned.echelonize(); |
| 175 | let mut tiny1 = tiny_matrix_from_bitmatrix::<32>(&matrix); |
| 176 | tiny_matrix_rref::<32,60>(&mut tiny1); |
| 177 | let tiny2 = tiny_matrix_from_bitmatrix::<32>(&echeloned); |
| 178 | assert_eq!(tiny1,tiny2); |
| 179 | } |
| 180 | |
| 181 | #[test] |
| 182 | fn direct_sum(left in arbitrary_bitmatrix(100), right in arbitrary_bitmatrix(100)) { |
| 183 | let summed = directly_summed([&left, &right]); |
| 184 | let expected_shape = (left.rowcount() + right.rowcount(), left.columncount() + right.columncount()); |
| 185 | assert_eq!(expected_shape, summed.shape()); |
| 186 | for row_index in 0..left.rowcount() { |
| 187 | for column_index in 0..left.columncount() { |
| 188 | assert_eq!(left[(row_index, column_index)], summed[(row_index, column_index)]); |
| 189 | } |
| 190 | for column_index in left.columncount()..summed.columncount() { |
| 191 | assert!(!summed[(row_index, column_index)]); |
| 192 | } |
| 193 | } |
| 194 | for row_index in 0..right.rowcount() { |
| 195 | for column_index in 0..right.columncount() { |
| 196 | assert_eq!(right[(row_index, column_index)], summed[(left.rowcount() + row_index, left.columncount() + column_index)]); |
| 197 | } |
| 198 | for column_index in 0..left.columncount() { |
| 199 | assert!(!summed[(left.rowcount() + row_index, column_index)]); |
| 200 | } |
| 201 | } |
| 202 | } |
| 203 | |
| 204 | } |
| 205 | |
/// Builds a `BitMatrix` from its textual row form written directly as macro tokens.
///
/// The tokens are turned back into a string with `stringify!` and parsed through the
/// `FromStr` implementation; the macro panics (via `unwrap`) if parsing fails.
macro_rules! bitmatrix {
    ($($tokens:tt)+) => {
        <$crate::BitMatrix<{ paulimer::bits::WORD_COUNT_DEFAULT }> as FromStr>::from_str(
            stringify!($($tokens)+),
        )
        .unwrap()
    };
}
| 211 | |
| 212 | prop_compose! { |
| 213 | fn arbitrary_bitmatrix(max_dimension: usize)(shape in (0..=max_dimension, 0..=max_dimension)) -> BitMatrix { |
| 214 | random_bitmatrix(shape.0, shape.1) |
| 215 | } |
| 216 | } |
| 217 | |
| 218 | prop_compose! { |
| 219 | fn fixed_size_bitmatrix(row_count: usize, column_count: usize)(_ in 0..column_count) -> BitMatrix { |
| 220 | random_bitmatrix(row_count, column_count) |
| 221 | } |
| 222 | } |
| 223 | |
| 224 | prop_compose! { |
| 225 | fn invertible_bitmatrix(max_dimension: usize)(dimension in 1..=max_dimension) -> BitMatrix { |
| 226 | let mut matrix = BitMatrix::identity(dimension); |
| 227 | for _ in 0..dimension^2 { |
| 228 | let from_index = thread_rng().gen_range(0..dimension); |
| 229 | let to_index = thread_rng().gen_range(0..dimension); |
| 230 | if from_index != to_index { |
| 231 | matrix.add_into_row(to_index, from_index); |
| 232 | } |
| 233 | } |
| 234 | for _ in 0..dimension.pow(2) { |
| 235 | let from_index = thread_rng().gen_range(0..dimension); |
| 236 | let to_index = thread_rng().gen_range(0..dimension); |
| 237 | matrix.swap_rows(from_index, to_index); |
| 238 | } |
| 239 | matrix |
| 240 | } |
| 241 | } |
| 242 | |
| 243 | prop_compose! { |
| 244 | fn nonempty_bitmatrix(max_dimension: usize)(shape in (1..=max_dimension, 1..=max_dimension)) -> BitMatrix { |
| 245 | random_bitmatrix(shape.0, shape.1) |
| 246 | } |
| 247 | } |
| 248 | |
| 249 | prop_compose! { |
| 250 | fn equal_shape_bitmatrices(max_dimension: usize)(shape in (1..=max_dimension, 1..=max_dimension)) -> (BitMatrix, BitMatrix) { |
| 251 | (random_bitmatrix(shape.0, shape.1), random_bitmatrix(shape.0, shape.1)) |
| 252 | } |
| 253 | } |
| 254 | |
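// NOTE: disabled test kept for reference. It calls `rref` and `rref_with_rank_profile`, which
// this file does not import, and a one-argument `is_rref`. The `echelon_form` property test and
// `reduce_with_transforms` cover much of the same ground.
// #[test]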
| 256 | // fn reduce() { |
| 257 | // for _ in 0..100 { |
| 258 | // let array = random_bitmatrix(100, 100); |
| 259 | // let reduced = rref(array); |
| 260 | // assert!(is_rref(&reduced)); |
| 261 | // } |
| 262 | |
| 263 | // for _ in 0..100 { |
| 264 | // let array = random_bitmatrix(50, 100); |
| 265 | // let (reduced, profile) = rref_with_rank_profile(array); |
| 266 | // assert_eq!(profile.len(), reduced.rowcount()); |
| 267 | // assert!(is_rref(&reduced)); |
| 268 | // } |
| 269 | |
| 270 | // { |
| 271 | // let matrix = bitmatrix!( |
| 272 | // |10 011 01| |
| 273 | // |.. 111 01| |
| 274 | // |.. ... 10|); |
| 275 | // assert!(is_rref(&matrix)); |
| 276 | // let (reduced, profile) = rref_with_rank_profile(matrix); |
| 277 | // assert!(is_rref(&reduced)); |
| 278 | // assert_eq!(profile, vec![0, 2, 5]); |
| 279 | // } |
| 280 | // } |
| 281 | |
/// Runs the `rref_with_transforms` check on 100 random 100 x 100 matrices, then on 100
/// random 50 x 100 matrices.
#[test]
fn reduce_with_transforms() {
    for (nrows, ncols) in [(100, 100), (50, 100)] {
        (0..100).for_each(|_| check_rref_with_transforms_on_random_matrix(nrows, ncols));
    }
}
| 291 | |
| 292 | fn check_rref_with_transforms_on_random_matrix(nrows: usize, ncols: usize) { |
| 293 | let array = random_bitmatrix(nrows, ncols); |
| 294 | let (reduced, t, t_inv_t, profile) = rref_with_transforms(array.clone()); |
| 295 | assert!(is_rref(&reduced, &profile)); |
| 296 | assert_eq!(t.dot(&array), reduced); |
| 297 | assert_eq!( |
| 298 | t.dot(&t_inv_t.transposed()), |
| 299 | BitMatrix::identity(array.rowcount()) |
| 300 | ); |
| 301 | } |
| 302 | |
/// Checks `dot` on a fixed 2 x 2 example, then checks associativity and zero and identity
/// products on both sides using random matrices.
#[test]
fn test_dot() {
    // `x` swaps the two coordinates, so it is its own inverse.
    let x = bitmatrix!(
        |01|
        |10|);
    let id = bitmatrix!(
        |10|
        |01|);
    assert_eq!(x.dot(&x), id);
    assert_eq!(x.dot(&id), x);
    assert_eq!(id.dot(&x), x);

    // Multiplication is associative.
    for _ in 0..100 {
        let a = random_bitmatrix(10, 10);
        let b = random_bitmatrix(10, 10);
        let c = random_bitmatrix(10, 10);
        assert_eq!((a.dot(&b)).dot(&c), a.dot(&b.dot(&c)));
    }

    // Multiplication by zero on either side gives zero.
    for _ in 0..100 {
        let a = random_bitmatrix(10, 10);
        let z = BitMatrix::zeros(10, 10);
        assert_eq!(a.dot(&z), z);
        assert_eq!(z.dot(&a), z);
    }

    // Multiplication by the identity on either side is a no-op.
    for _ in 0..100 {
        let a = random_bitmatrix(3, 3);
        let id = BitMatrix::identity(3);
        assert_eq!(a.dot(&id), a);
        assert_eq!(id.dot(&a), a);
    }
}
| 341 | |
/// Checks `kernel_basis_matrix` on 100 random 50 x 100 matrices.
///
/// Each matrix is echelonized in place before its kernel basis is computed. Every kernel basis
/// row must be orthogonal to every matrix row. The ranks of the matrix and of the kernel basis
/// (the lengths of their echelon profiles) must add up to the column count (rank–nullity).
#[test]
fn test_kernel_basis() {
    for _ in 0..100 {
        let mut matrix = random_bitmatrix(50, 100);
        let column_count = matrix.columncount();
        let matrix_profile = matrix.echelonize();
        // Named `kernel` so the local does not shadow the `kernel_basis_matrix` function.
        let mut kernel = kernel_basis_matrix(&matrix);
        let product = matrix.dot(&kernel.transposed());
        assert!(product.is_zero());
        let kernel_profile = kernel.echelonize();
        assert_eq!(matrix_profile.len() + kernel_profile.len(), column_count);
    }
}
| 355 | |
| 356 | fn preserves_rowspan_of(matrix: &BitMatrix, rref_matrix: &BitMatrix) -> bool { |
| 357 | let profile = fast_profile_of(rref_matrix); |
| 358 | let mut profile_rows = BTreeMap::new(); |
| 359 | for (row_index, column_index) in profile.iter().enumerate() { |
| 360 | profile_rows.insert(column_index, row_index); |
| 361 | } |
| 362 | for row in matrix.rows() { |
| 363 | let mut reduced = BitVec::<WORD_COUNT_DEFAULT>::from_view(&row); |
| 364 | let support = row |
| 365 | .support() |
| 366 | .assume_sorted_by_item() |
| 367 | .intersection(profile.iter().copied().assume_sorted_by_item()); |
| 368 | |
| 369 | for column_index in support { |
| 370 | let row_index = profile_rows[&column_index]; |
| 371 | let rref_row = BitVec::<WORD_COUNT_DEFAULT>::from_view(&rref_matrix.row(row_index)); |
| 372 | reduced.bitxor_assign(&rref_row); |
| 373 | } |
| 374 | if reduced.weight() > 0 { |
| 375 | return false; |
| 376 | } |
| 377 | } |
| 378 | true |
| 379 | } |
| 380 | |
| 381 | fn is_rref(matrix: &BitMatrix, with_profile: &[usize]) -> bool { |
| 382 | let expected_profile = fast_profile_of(matrix); |
| 383 | (expected_profile == with_profile) && columns_are_pivots_of(matrix, with_profile) |
| 384 | } |
| 385 | |
| 386 | fn columns_are_pivots_of(matrix: &BitMatrix, column_indexes: &[usize]) -> bool { |
| 387 | for &column_index in column_indexes { |
| 388 | let column = matrix.column(column_index); |
| 389 | if column.weight() != 1 { |
| 390 | return false; |
| 391 | } |
| 392 | } |
| 393 | true |
| 394 | } |
| 395 | |
| 396 | fn fast_profile_of(matrix: &BitMatrix) -> Vec<usize> { |
| 397 | let mut profile = vec![]; |
| 398 | for row_index in 0..matrix.rowcount() { |
| 399 | let row = matrix.row(row_index); |
| 400 | let pivot = row.into_iter().position(|bit| bit); |
| 401 | if pivot.is_none() { |
| 402 | break; |
| 403 | } |
| 404 | profile.push(pivot.unwrap()); |
| 405 | } |
| 406 | profile |
| 407 | } |
| 408 | |
| 409 | fn random_bitmatrix(rowcount: usize, columncount: usize) -> BitMatrix { |
| 410 | let mut matrix = BitMatrix::with_shape(rowcount, columncount); |
| 411 | let mut bits = std::iter::from_fn(move || Some(thread_rng().gen::<bool>())); |
| 412 | for row_index in 0..rowcount { |
| 413 | for column_index in 0..columncount { |
| 414 | matrix.set((row_index, column_index), bits.next().expect("boom")); |
| 415 | } |
| 416 | } |
| 417 | for _ in 0..rowcount { |
| 418 | let from_index = thread_rng().gen_range(0..rowcount); |
| 419 | let to_index = thread_rng().gen_range(0..rowcount); |
| 420 | matrix.swap_rows(from_index, to_index); |
| 421 | } |
| 422 | matrix |
| 423 | } |
| 424 | |