From 636cd2da16a111e1c9e846d5bc9b6223b3f367ad Mon Sep 17 00:00:00 2001 From: baoqiwen Date: Tue, 26 May 2026 19:38:03 +0800 Subject: [PATCH] Add headdim 256 md5 and Disable Python GC during timed iterations --- benchmark_flashmask.py | 8 + flashmask_bwd_gt.json | 387 ++++++++++++++++++++++++++++++++++++++- flashmask_fwd_gt.json | 79 +++++++- kernel_test_seq_info.txt | 7 + run.sh | 1 + run_aadiff.sh | 26 ++- test_bwd_md5sum.py | 1 + test_flashmask.py | 4 +- test_fwd_md5sum.py | 1 + 9 files changed, 500 insertions(+), 14 deletions(-) diff --git a/benchmark_flashmask.py b/benchmark_flashmask.py index 548a69d..9d6d3b7 100644 --- a/benchmark_flashmask.py +++ b/benchmark_flashmask.py @@ -11,6 +11,8 @@ from paddle.nn.functional.flash_attention import flashmask_attention import random import os +import gc + from datetime import datetime np.random.seed(0) @@ -89,7 +91,11 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flu # Warm-up for _ in range(n_warmup): fn() + # Benchmark + # Disable Python GC during timed iterations + gc.collect() + gc.disable() for i in range(n_repeat): # we don't want `fn` to accumulate gradient values # if it contains a backward pass. So we clear the @@ -103,6 +109,8 @@ def do_bench(fn, warmup=25, rep=100, grad_to_none=None, quantiles=None, fast_flu start_event[i].record() fn() end_event[i].record() + gc.enable() + # Record clocks paddle.device.synchronize() times = paddle.to_tensor([s.elapsed_time(e) for s, e in zip(start_event, end_event)], dtype=paddle.float32) diff --git a/flashmask_bwd_gt.json b/flashmask_bwd_gt.json index 9567771..12f9be0 100644 --- a/flashmask_bwd_gt.json +++ b/flashmask_bwd_gt.json @@ -219,6 +219,61 @@ "dk": "420e704788b21312a1d0b83d589bcd9d", "dv": "2eb37a48b132af94dd4fd937f5b508ae" }, + "full-1-8192-33792-2-1-1-256-256-4-dtype0": { + "dq": "f37c4e9e91b236a9926a9b1b31af23fc", + "dk": "a12b57b57fa1703f4e6ed046c9f649e5", + "dv": "7cd5f5173bfedc14260a7fe46c0394f5" + }, + "causal-1-8192-33792-2-1-1-256-256-4-dtype0": { + "dq": "9f9bf1144659f614bf0df8059c88cc15", + "dk": "ff2a54065eb7797723249ed4b87cf9bc", + "dv": "07efcc556381a21af4893fd4d1cea57f" + }, + "sliding_window-1-8192-33792-2-1-1-256-256-4-dtype0": { + "dq": "40e108fc467f71165e89d631bfaa74dc", + "dk": "79d223b18eb0b38168adce53a92227ea", + "dv": "f982186897e1a2b537cc5e912496651b" + }, + "causal_document-1-8192-33792-2-1-1-256-256-4-dtype0": { + "dq": "9f9bf1144659f614bf0df8059c88cc15", + "dk": "ff2a54065eb7797723249ed4b87cf9bc", + "dv": "07efcc556381a21af4893fd4d1cea57f" + }, + "document-1-8192-33792-2-1-1-256-256-4-dtype0": { + "dq": "e4be8e62e097de6be460f771043bf9fb", + "dk": "ce0d25700da3a7445ae86c2cf32ffe7d", + "dv": "9822883fddf4895532a01756ca99eca2" + }, + "share_question-1-8192-33792-2-1-1-256-256-4-dtype0": { + "dq": "9f9bf1144659f614bf0df8059c88cc15", + "dk": "ff2a54065eb7797723249ed4b87cf9bc", + "dv": "07efcc556381a21af4893fd4d1cea57f" + }, + "causal_blockwise-1-8192-33792-2-1-1-256-256-4-dtype0": { + "dq": "9f9bf1144659f614bf0df8059c88cc15", + "dk": "ff2a54065eb7797723249ed4b87cf9bc", + "dv": "07efcc556381a21af4893fd4d1cea57f" + }, + "prefix_lm_document_mask-1-8192-33792-2-1-1-256-256-4-dtype0": { + "dq": "ea01590c67a6d948324e41693b08a5a7", + "dk": "d21acdef6f97d1202902b89c9f9dbfd8", + "dv": "38a71d6b20e762d2040890b171672478" + }, + "prefix_lm_causal-1-8192-33792-2-1-1-256-256-4-dtype0": { + "dq": "ea01590c67a6d948324e41693b08a5a7", + "dk": "d21acdef6f97d1202902b89c9f9dbfd8", + "dv": "38a71d6b20e762d2040890b171672478" + }, + "qk_sparse-1-8192-33792-2-1-1-256-256-4-dtype0": { + "dq": "e7c653ae0d2be8b2f65027b34daf9eac", + "dk": "0140387a179f308be875d79b825a4d6c", + "dv": "aa2393fada4ed70ff4de68b6a397d7c7" + }, + "random_eviction-1-8192-33792-2-1-1-256-256-4-dtype0": { + "dq": "9f9bf1144659f614bf0df8059c88cc15", + "dk": "ff2a54065eb7797723249ed4b87cf9bc", + "dv": "07efcc556381a21af4893fd4d1cea57f" + }, "full-2840-32-32-16-4-1-64-64-4-dtype0": { "dq": "ad6cf01fed79595c11cd53eb4626309f", "dk": "b80c66a9fed4f85bd23ce2053e9f7388", @@ -439,6 +494,61 @@ "dk": "5a8e45b578aae3f2a45b9d7eeab8cd36", "dv": "fe269c5f60c63b2d02cd91bfc32e80c5" }, + "full-2840-32-32-16-4-1-256-256-4-dtype0": { + "dq": "7df6cfd40316e3199d8992b47619debf", + "dk": "22f5a451c040cf741abec3389221a144", + "dv": "36931880c76c98e8f727bc19be3e81ca" + }, + "causal-2840-32-32-16-4-1-256-256-4-dtype0": { + "dq": "b0ec6fcc9efb02ba402bb31b653725ab", + "dk": "12939996b0603342abc3630380c0348f", + "dv": "9492983ed08f7ff484a541f8c126df15" + }, + "sliding_window-2840-32-32-16-4-1-256-256-4-dtype0": { + "dq": "6c44504571983ee33add848126aa7273", + "dk": "6bcb14a26b403d6bc80e473b72815d5a", + "dv": "1010e2fdeef04e904524cb83bf431d9f" + }, + "causal_document-2840-32-32-16-4-1-256-256-4-dtype0": { + "dq": "9c0f305e5cc77fc30135b7da82c6ce56", + "dk": "3e1726f5265eea6c9aa871b0b4463d2c", + "dv": "3d3a2c9de8fea9938bc3b8bf318cc1fc" + }, + "document-2840-32-32-16-4-1-256-256-4-dtype0": { + "dq": "c98cd20dee78dddb248b310aa2bd5021", + "dk": "0f3c5fb7ca8af39e201d40b289439702", + "dv": "55f4b609a2d476a14ef13a45c831eb01" + }, + "share_question-2840-32-32-16-4-1-256-256-4-dtype0": { + "dq": "34285cb50850713e8ac471dd4fde775a", + "dk": "5dc0664dbdbf599a47a39174b8d89ef5", + "dv": "dffe9bc3bdf147af3f89184ea3d9d14c" + }, + "causal_blockwise-2840-32-32-16-4-1-256-256-4-dtype0": { + "dq": "680a05f7f3d2a8b803158da509bf15e9", + "dk": "fbd767bd34c9b70593308e7d88772484", + "dv": "656ca12d8bef491df17c2af5061f9aae" + }, + "prefix_lm_document_mask-2840-32-32-16-4-1-256-256-4-dtype0": { + "dq": "4656a73f0538afe153dbf04dbafdb7e2", + "dk": "6964c2df38af6044dd0c7ca8fa942d0e", + "dv": "b35ee05e3c9f681b8e6856f56950d6c0" + }, + "prefix_lm_causal-2840-32-32-16-4-1-256-256-4-dtype0": { + "dq": "1f0be8e1bd1b19f569861a0df535c10b", + "dk": "6de8348093217f9d4080d5bb918f38c1", + "dv": "9543802e7f87a2a2fd56bc4174bf378e" + }, + "qk_sparse-2840-32-32-16-4-1-256-256-4-dtype0": { + "dq": "7e484844581fe46bde9e098aa0b0fbbf", + "dk": "3f2ac3e713bcd7cf40db70b57f2840b6", + "dv": "b250bd21f8a7d6770ef776271aaae84b" + }, + "random_eviction-2840-32-32-16-4-1-256-256-4-dtype0": { + "dq": "d644d373c77db58b1715fb9693ddb52f", + "dk": "c3384712078b752863addfc00a2e008b", + "dv": "4ca5aa20c2c6ce073df961d7b8d032c2" + }, "full-2840-32-32-16-4-4-64-64-4-dtype0": { "dq": "ad6cf01fed79595c11cd53eb4626309f", "dk": "b80c66a9fed4f85bd23ce2053e9f7388", @@ -659,6 +769,61 @@ "dk": "1249f3d733a9175a9d9c69ddedc18109", "dv": "0c4b71c12fc119282d32dab0ab09a6a0" }, + "full-2840-32-32-16-4-4-256-256-4-dtype0": { + "dq": "7df6cfd40316e3199d8992b47619debf", + "dk": "22f5a451c040cf741abec3389221a144", + "dv": "36931880c76c98e8f727bc19be3e81ca" + }, + "causal-2840-32-32-16-4-4-256-256-4-dtype0": { + "dq": "b0ec6fcc9efb02ba402bb31b653725ab", + "dk": "12939996b0603342abc3630380c0348f", + "dv": "9492983ed08f7ff484a541f8c126df15" + }, + "sliding_window-2840-32-32-16-4-4-256-256-4-dtype0": { + "dq": "6c44504571983ee33add848126aa7273", + "dk": "6bcb14a26b403d6bc80e473b72815d5a", + "dv": "1010e2fdeef04e904524cb83bf431d9f" + }, + "causal_document-2840-32-32-16-4-4-256-256-4-dtype0": { + "dq": "9c0f305e5cc77fc30135b7da82c6ce56", + "dk": "3e1726f5265eea6c9aa871b0b4463d2c", + "dv": "3d3a2c9de8fea9938bc3b8bf318cc1fc" + }, + "document-2840-32-32-16-4-4-256-256-4-dtype0": { + "dq": "c98cd20dee78dddb248b310aa2bd5021", + "dk": "0f3c5fb7ca8af39e201d40b289439702", + "dv": "55f4b609a2d476a14ef13a45c831eb01" + }, + "share_question-2840-32-32-16-4-4-256-256-4-dtype0": { + "dq": "34285cb50850713e8ac471dd4fde775a", + "dk": "5dc0664dbdbf599a47a39174b8d89ef5", + "dv": "dffe9bc3bdf147af3f89184ea3d9d14c" + }, + "causal_blockwise-2840-32-32-16-4-4-256-256-4-dtype0": { + "dq": "680a05f7f3d2a8b803158da509bf15e9", + "dk": "fbd767bd34c9b70593308e7d88772484", + "dv": "656ca12d8bef491df17c2af5061f9aae" + }, + "prefix_lm_document_mask-2840-32-32-16-4-4-256-256-4-dtype0": { + "dq": "4656a73f0538afe153dbf04dbafdb7e2", + "dk": "6964c2df38af6044dd0c7ca8fa942d0e", + "dv": "b35ee05e3c9f681b8e6856f56950d6c0" + }, + "prefix_lm_causal-2840-32-32-16-4-4-256-256-4-dtype0": { + "dq": "1f0be8e1bd1b19f569861a0df535c10b", + "dk": "6de8348093217f9d4080d5bb918f38c1", + "dv": "9543802e7f87a2a2fd56bc4174bf378e" + }, + "qk_sparse-2840-32-32-16-4-4-256-256-4-dtype0": { + "dq": "7e484844581fe46bde9e098aa0b0fbbf", + "dk": "3f2ac3e713bcd7cf40db70b57f2840b6", + "dv": "b250bd21f8a7d6770ef776271aaae84b" + }, + "random_eviction-2840-32-32-16-4-4-256-256-4-dtype0": { + "dq": "2c16779037a920dee61581cb925f9b36", + "dk": "c4987d90482049babb972e61c10a6aab", + "dv": "37cba2e1f710d8cbea8255420609d8fc" + }, "full-1-300-300-16-16-1-64-64-4-dtype0": { "dq": "95ef08c4df18c656f47577883f7141e8", "dk": "760437ffc8679f279c2376d40d58e707", @@ -879,6 +1044,61 @@ "dk": "9aff28f1cede37590e865eb65a4d968e", "dv": "6e4a72323b5e061c351d88925631b17e" }, + "full-1-300-300-16-16-1-256-256-4-dtype0": { + "dq": "2d53d055dc0a56e2914ddbdd7f6674cf", + "dk": "74648f29e278e04f94a5a6676d97613b", + "dv": "48fbd5d5673deec8dc1828972e3799bd" + }, + "causal-1-300-300-16-16-1-256-256-4-dtype0": { + "dq": "1ed4f59218b0a69f6d709fc09cef05b3", + "dk": "8072a33327e8bfe8256e0f0f4f473d13", + "dv": "749729ea89ce5306d15f4a9fd210e978" + }, + "sliding_window-1-300-300-16-16-1-256-256-4-dtype0": { + "dq": "bbf11190a901cb3e78cf312a7ff72d10", + "dk": "b68fa3450715e120f6b78f4223d11eca", + "dv": "b92e364678efe25afba63d500a9b7900" + }, + "causal_document-1-300-300-16-16-1-256-256-4-dtype0": { + "dq": "5c729067d803742c7bb62b12d5994c90", + "dk": "3cf05422f1564ec6358667b70ac336b4", + "dv": "ac19c7491ce0de7757546bd0b510b529" + }, + "document-1-300-300-16-16-1-256-256-4-dtype0": { + "dq": "2a299fb677cedf17cf4f3d37f1f52692", + "dk": "f0997d442bd4116948f0a4870bfa2ce5", + "dv": "a669f4019a2a73ee186ec737169e48d0" + }, + "share_question-1-300-300-16-16-1-256-256-4-dtype0": { + "dq": "15bdad1af7ee9d9a6c4ec8026e1cfdd7", + "dk": "d81cc5edb81d1affecdb85972d1c29a8", + "dv": "6f51a5ecd87683aa20522420de9fe659" + }, + "causal_blockwise-1-300-300-16-16-1-256-256-4-dtype0": { + "dq": "540339b8af91dc7523d7a1d0d5ca43a0", + "dk": "2e0b2da35b43f37f776bcb7984d9c68c", + "dv": "9ecd2543feeee2d401a1909a03d2581e" + }, + "prefix_lm_document_mask-1-300-300-16-16-1-256-256-4-dtype0": { + "dq": "b23f679c8d4badce1f23a38577e01a28", + "dk": "9525a3d85e65dd26f1cb63ef9decc48d", + "dv": "9c73725cb01c2b9aea5c10c1c4b1fd43" + }, + "prefix_lm_causal-1-300-300-16-16-1-256-256-4-dtype0": { + "dq": "9960e8daa7448606db34c4c4c5b59d36", + "dk": "b10b50318ae247a56739e18f01faf6d2", + "dv": "efbca1839edc75eab5fe406afa242ce8" + }, + "qk_sparse-1-300-300-16-16-1-256-256-4-dtype0": { + "dq": "e3adb4b47a0170e364ca2b6642b5df47", + "dk": "5c2e49a5a911cee4ea3900f9de6a23b3", + "dv": "d53a96948a6809a843afb3d588f3120d" + }, + "random_eviction-1-300-300-16-16-1-256-256-4-dtype0": { + "dq": "5fdba436715aa3e4aa788bcbdc28ac3a", + "dk": "4a89af39357f87f1586007d99c12b03e", + "dv": "b62988fde36c2b7f7a366c8e2db7553b" + }, "full-1-300-300-16-16-16-64-64-4-dtype0": { "dq": "95ef08c4df18c656f47577883f7141e8", "dk": "760437ffc8679f279c2376d40d58e707", @@ -1099,6 +1319,61 @@ "dk": "379cf160f5143153d5299779660249ac", "dv": "e767841c37c15ece41f573b117ad367f" }, + "full-1-300-300-16-16-16-256-256-4-dtype0": { + "dq": "2d53d055dc0a56e2914ddbdd7f6674cf", + "dk": "74648f29e278e04f94a5a6676d97613b", + "dv": "48fbd5d5673deec8dc1828972e3799bd" + }, + "causal-1-300-300-16-16-16-256-256-4-dtype0": { + "dq": "1ed4f59218b0a69f6d709fc09cef05b3", + "dk": "8072a33327e8bfe8256e0f0f4f473d13", + "dv": "749729ea89ce5306d15f4a9fd210e978" + }, + "sliding_window-1-300-300-16-16-16-256-256-4-dtype0": { + "dq": "bbf11190a901cb3e78cf312a7ff72d10", + "dk": "b68fa3450715e120f6b78f4223d11eca", + "dv": "b92e364678efe25afba63d500a9b7900" + }, + "causal_document-1-300-300-16-16-16-256-256-4-dtype0": { + "dq": "5c729067d803742c7bb62b12d5994c90", + "dk": "3cf05422f1564ec6358667b70ac336b4", + "dv": "ac19c7491ce0de7757546bd0b510b529" + }, + "document-1-300-300-16-16-16-256-256-4-dtype0": { + "dq": "2a299fb677cedf17cf4f3d37f1f52692", + "dk": "f0997d442bd4116948f0a4870bfa2ce5", + "dv": "a669f4019a2a73ee186ec737169e48d0" + }, + "share_question-1-300-300-16-16-16-256-256-4-dtype0": { + "dq": "15bdad1af7ee9d9a6c4ec8026e1cfdd7", + "dk": "d81cc5edb81d1affecdb85972d1c29a8", + "dv": "6f51a5ecd87683aa20522420de9fe659" + }, + "causal_blockwise-1-300-300-16-16-16-256-256-4-dtype0": { + "dq": "540339b8af91dc7523d7a1d0d5ca43a0", + "dk": "2e0b2da35b43f37f776bcb7984d9c68c", + "dv": "9ecd2543feeee2d401a1909a03d2581e" + }, + "prefix_lm_document_mask-1-300-300-16-16-16-256-256-4-dtype0": { + "dq": "b23f679c8d4badce1f23a38577e01a28", + "dk": "9525a3d85e65dd26f1cb63ef9decc48d", + "dv": "9c73725cb01c2b9aea5c10c1c4b1fd43" + }, + "prefix_lm_causal-1-300-300-16-16-16-256-256-4-dtype0": { + "dq": "9960e8daa7448606db34c4c4c5b59d36", + "dk": "b10b50318ae247a56739e18f01faf6d2", + "dv": "efbca1839edc75eab5fe406afa242ce8" + }, + "qk_sparse-1-300-300-16-16-16-256-256-4-dtype0": { + "dq": "e3adb4b47a0170e364ca2b6642b5df47", + "dk": "5c2e49a5a911cee4ea3900f9de6a23b3", + "dv": "d53a96948a6809a843afb3d588f3120d" + }, + "random_eviction-1-300-300-16-16-16-256-256-4-dtype0": { + "dq": "9a9385c88e645fda63c9bf7cbe1d4592", + "dk": "bb80c74681136b63ebd8770f33411b38", + "dv": "657795a01a7c3fb0ee9a305b57c4f688" + }, "full-1-128-127-1-1-1-64-64-4-dtype0": { "dq": "a807005279a56317146d91c3027d387e", "dk": "ff139a0c811542436b5d65128e95707f", @@ -1319,6 +1594,61 @@ "dk": "9e07a3aa44fbc4e8232bc4b62ed9cc03", "dv": "28405fd0a3593340e49be52f99fb923c" }, + "full-1-128-127-1-1-1-256-256-4-dtype0": { + "dq": "e3adda315f37c5201965f1c4f472f277", + "dk": "395a8f5e7c3079ed0393b668562d65e7", + "dv": "b68a3fdd7a5a4a6ca38c8aa825e91e0a" + }, + "causal-1-128-127-1-1-1-256-256-4-dtype0": { + "dq": "3eb9b71d90c628b66aa9f6879ff2dcad", + "dk": "558fcd199316e9420b1a9d6a820bb13d", + "dv": "2d487c3ed089f92aab49122371d9e188" + }, + "sliding_window-1-128-127-1-1-1-256-256-4-dtype0": { + "dq": "de5b8ea1ea75c22469b8b1cbc3a93914", + "dk": "5a09909b80cf86c6884bf9e6f5c5978e", + "dv": "1b9ef991d2f90295f77f5954f2f8a69b" + }, + "causal_document-1-128-127-1-1-1-256-256-4-dtype0": { + "dq": "df18a0d1abfeb8446a9c95ebb6c5bb7a", + "dk": "91b8c277b97602d07bd3be63459e260f", + "dv": "0745dbae5ab1e5a6861914a8146c0c1f" + }, + "document-1-128-127-1-1-1-256-256-4-dtype0": { + "dq": "725102d50c45b7340fe51e18430e1969", + "dk": "27ad37ed335c2bff6f9ca04d9a392075", + "dv": "3b2aeb5818f1c3a66b5d6cc164f6eb10" + }, + "share_question-1-128-127-1-1-1-256-256-4-dtype0": { + "dq": "711bf8ab4c774dddb1ec7a92744cfb1a", + "dk": "28edbc0c20dafc20623defdf72a86e32", + "dv": "ec49941d5a0eae6030efbc70dff0aa26" + }, + "causal_blockwise-1-128-127-1-1-1-256-256-4-dtype0": { + "dq": "f17d378f010b51c35b1d4473368f67ee", + "dk": "26707d2058cda8c576213e0da9ad80d2", + "dv": "97df905e4bdb4f8ed81ef4fcf2627294" + }, + "prefix_lm_document_mask-1-128-127-1-1-1-256-256-4-dtype0": { + "dq": "31ea12e2253f38f27d351a2be20295fd", + "dk": "959ed2fe034e853aada72c40c686099e", + "dv": "bfdca160bfc216e55b9386d07a44d3ee" + }, + "prefix_lm_causal-1-128-127-1-1-1-256-256-4-dtype0": { + "dq": "2ca1e8bc4ddbe0d448b424e133641d50", + "dk": "f9a11a6bf508a2574abd31bbe35a1e17", + "dv": "7728355fb2d2dc8ecf46b7092a0a7034" + }, + "qk_sparse-1-128-127-1-1-1-256-256-4-dtype0": { + "dq": "ed5593574b09f2c51d38d9e39bfac94a", + "dk": "c7561e84f0cc74dfab75161f5b13b94e", + "dv": "5ae4250a2080af9438a2839924a73db7" + }, + "random_eviction-1-128-127-1-1-1-256-256-4-dtype0": { + "dq": "48d8d96bb1b5c2ed341062980e9d4e92", + "dk": "6bf887d178e1119b155291dbbf6e500e", + "dv": "ed77ba57b3bba6e5411fa3041907fc99" + }, "full-2-16384-16383-4-1-1-64-64-4-dtype0": { "dq": "9c036d6211d18b63ed7714aabffc94b3", "dk": "26794e040dd0751d70d917646232393b", @@ -1539,6 +1869,61 @@ "dk": "69697dbd1135b5ba0fe504a27e0d4f68", "dv": "a268e43dca9047eda452833c1df928fa" }, - "gt_commit_id": "7f507e048511557704cd270d58de360e07b536fa", + "full-2-16384-16383-4-1-1-256-256-4-dtype0": { + "dq": "17f61323d8ffac0a3be2d040cd44cc09", + "dk": "4ddb028b162b63267ea4d0f95f65318d", + "dv": "50b17b0ef0fb540f8abf3ffb32284ed0" + }, + "causal-2-16384-16383-4-1-1-256-256-4-dtype0": { + "dq": "0f19d0ef1e392ed837e5d384eb3493ef", + "dk": "1f8748f23336deb808a6c9560d10d5ab", + "dv": "4478212702917bce656e26eec1010a7e" + }, + "sliding_window-2-16384-16383-4-1-1-256-256-4-dtype0": { + "dq": "90d30161a97514341027b7dfa9e08c1e", + "dk": "1b5d91882771b430ed02336903cdf02b", + "dv": "adcb879f211dd35ad435c1297fc413db" + }, + "causal_document-2-16384-16383-4-1-1-256-256-4-dtype0": { + "dq": "f00a1429eba09c05302d4a38dc2ea486", + "dk": "114b240d2c4990634e4e6657420a5368", + "dv": "34bc535529eef334fdd69ff6efd5fa77" + }, + "document-2-16384-16383-4-1-1-256-256-4-dtype0": { + "dq": "035a23c0a0a21676e2f0d21ba88cafe8", + "dk": "195f2eb71d1143fa0bf3e97e2b40ec48", + "dv": "91979d12a63167fab2d375db4b1f90b5" + }, + "share_question-2-16384-16383-4-1-1-256-256-4-dtype0": { + "dq": "195ecada54db7f462beaf21bb682de88", + "dk": "efb9762ea789286b56c2e637dabe715e", + "dv": "5b67c9f6cb5718b4f23db7b4cbbd146f" + }, + "causal_blockwise-2-16384-16383-4-1-1-256-256-4-dtype0": { + "dq": "f33b726f801b37a37caff17f4c7386e5", + "dk": "e8e03cf115df7d663f9456e0ca55e3ad", + "dv": "3e6e19337fef2657759991f483dbc787" + }, + "prefix_lm_document_mask-2-16384-16383-4-1-1-256-256-4-dtype0": { + "dq": "5ea9cfc7ad449545435cf7d41fb423b6", + "dk": "a3c13cb7bd24d36f3570fb63c3d7011b", + "dv": "48014f28ef5c85b6fddeab7003424f57" + }, + "prefix_lm_causal-2-16384-16383-4-1-1-256-256-4-dtype0": { + "dq": "b9f8d9c8b1102e2e222a1d9570dc8ed1", + "dk": "6b4970fe997748d431d88a6fc5feed75", + "dv": "023d3126985e23d5438757c12f30d923" + }, + "qk_sparse-2-16384-16383-4-1-1-256-256-4-dtype0": { + "dq": "b8e8031f7b45382718d5ab35e87ca9aa", + "dk": "0ccc19c5cc1f8c5c8f0c124ce811c101", + "dv": "eec3ee32267d8fdf6bff096c89168117" + }, + "random_eviction-2-16384-16383-4-1-1-256-256-4-dtype0": { + "dq": "5785d85a3773004c626d27a0d1891131", + "dk": "8d6ad979ce374f229444839646b28940", + "dv": "c1ccdd1db12083e80641851499b5f457" + }, + "gt_commit_id": "e7805901d8a215494f6189fc0fa8e4c298b295a5", "gt_commit_msg": "refine whl version" } \ No newline at end of file diff --git a/flashmask_fwd_gt.json b/flashmask_fwd_gt.json index dee96e8..f2ede67 100644 --- a/flashmask_fwd_gt.json +++ b/flashmask_fwd_gt.json @@ -43,6 +43,17 @@ "prefix_lm_causal-1-8192-33792-2-1-1-192-128-4-dtype0": "8b3fce9d026d3d0f71f044634548d20b", "qk_sparse-1-8192-33792-2-1-1-192-128-4-dtype0": "f8ade49dbd32678c7c92a5539a2838b1", "random_eviction-1-8192-33792-2-1-1-192-128-4-dtype0": "4deb8ff9b3ee81e5c8528670b545268d", + "full-1-8192-33792-2-1-1-256-256-4-dtype0": "2512056c01dc2166825e1ce5896c6104", + "causal-1-8192-33792-2-1-1-256-256-4-dtype0": "d1fd9053e16318cbb6e619ac37640006", + "sliding_window-1-8192-33792-2-1-1-256-256-4-dtype0": "68fcfa2fb15d15c6ba811ad134320acf", + "causal_document-1-8192-33792-2-1-1-256-256-4-dtype0": "d1fd9053e16318cbb6e619ac37640006", + "document-1-8192-33792-2-1-1-256-256-4-dtype0": "bbc8db3ea7ec41e17642a2fc909f5b47", + "share_question-1-8192-33792-2-1-1-256-256-4-dtype0": "d1fd9053e16318cbb6e619ac37640006", + "causal_blockwise-1-8192-33792-2-1-1-256-256-4-dtype0": "d1fd9053e16318cbb6e619ac37640006", + "prefix_lm_document_mask-1-8192-33792-2-1-1-256-256-4-dtype0": "49d430bda1e37f0fbaf8c6987ec165b1", + "prefix_lm_causal-1-8192-33792-2-1-1-256-256-4-dtype0": "49d430bda1e37f0fbaf8c6987ec165b1", + "qk_sparse-1-8192-33792-2-1-1-256-256-4-dtype0": "ec49380afef526c3fc4ee6e7d91a6358", + "random_eviction-1-8192-33792-2-1-1-256-256-4-dtype0": "d1fd9053e16318cbb6e619ac37640006", "full-2840-32-32-16-4-1-64-64-4-dtype0": "1bb4902d0ccca7aead38c07e22b73fec", "causal-2840-32-32-16-4-1-64-64-4-dtype0": "01974b750ab35886a463d2d348c3ac00", "sliding_window-2840-32-32-16-4-1-64-64-4-dtype0": "e50620dd504429343ca63a21f3b9e277", @@ -87,6 +98,17 @@ "prefix_lm_causal-2840-32-32-16-4-1-192-128-4-dtype0": "296b8bc9f2e329cdacd03915060d4a78", "qk_sparse-2840-32-32-16-4-1-192-128-4-dtype0": "39b82ee44fe106cd6f854e0d70300ab5", "random_eviction-2840-32-32-16-4-1-192-128-4-dtype0": "86f7206a79fcb62347c0be7747cbbbbd", + "full-2840-32-32-16-4-1-256-256-4-dtype0": "db290790e3263adb1ad44c61ed5ee9ae", + "causal-2840-32-32-16-4-1-256-256-4-dtype0": "d22755d41575b3f9f6174ddee03fa6e4", + "sliding_window-2840-32-32-16-4-1-256-256-4-dtype0": "960464b113a63b4a9d756fd7d969b861", + "causal_document-2840-32-32-16-4-1-256-256-4-dtype0": "d34a288eb32d25752c9f86dc96e18a52", + "document-2840-32-32-16-4-1-256-256-4-dtype0": "acd0e3fd3612cd966415ba8bfe91822e", + "share_question-2840-32-32-16-4-1-256-256-4-dtype0": "726ef844a6e93fa713c1dd608a60f533", + "causal_blockwise-2840-32-32-16-4-1-256-256-4-dtype0": "d8d0d92d3fd631bbc49e38a7dabdd7c9", + "prefix_lm_document_mask-2840-32-32-16-4-1-256-256-4-dtype0": "dd45b95d387ae3525b9d3d0e1d5ea6cf", + "prefix_lm_causal-2840-32-32-16-4-1-256-256-4-dtype0": "c40bfd6dd27f4539405619e884c7258f", + "qk_sparse-2840-32-32-16-4-1-256-256-4-dtype0": "286fcb8839b0f962a0109b584eab48aa", + "random_eviction-2840-32-32-16-4-1-256-256-4-dtype0": "2f8e28db940670d444abab7b9752c64a", "full-2840-32-32-16-4-4-64-64-4-dtype0": "1bb4902d0ccca7aead38c07e22b73fec", "causal-2840-32-32-16-4-4-64-64-4-dtype0": "01974b750ab35886a463d2d348c3ac00", "sliding_window-2840-32-32-16-4-4-64-64-4-dtype0": "e50620dd504429343ca63a21f3b9e277", @@ -131,6 +153,17 @@ "prefix_lm_causal-2840-32-32-16-4-4-192-128-4-dtype0": "296b8bc9f2e329cdacd03915060d4a78", "qk_sparse-2840-32-32-16-4-4-192-128-4-dtype0": "39b82ee44fe106cd6f854e0d70300ab5", "random_eviction-2840-32-32-16-4-4-192-128-4-dtype0": "9495d566799894c24f3660a69c2e8663", + "full-2840-32-32-16-4-4-256-256-4-dtype0": "db290790e3263adb1ad44c61ed5ee9ae", + "causal-2840-32-32-16-4-4-256-256-4-dtype0": "d22755d41575b3f9f6174ddee03fa6e4", + "sliding_window-2840-32-32-16-4-4-256-256-4-dtype0": "960464b113a63b4a9d756fd7d969b861", + "causal_document-2840-32-32-16-4-4-256-256-4-dtype0": "d34a288eb32d25752c9f86dc96e18a52", + "document-2840-32-32-16-4-4-256-256-4-dtype0": "acd0e3fd3612cd966415ba8bfe91822e", + "share_question-2840-32-32-16-4-4-256-256-4-dtype0": "726ef844a6e93fa713c1dd608a60f533", + "causal_blockwise-2840-32-32-16-4-4-256-256-4-dtype0": "d8d0d92d3fd631bbc49e38a7dabdd7c9", + "prefix_lm_document_mask-2840-32-32-16-4-4-256-256-4-dtype0": "dd45b95d387ae3525b9d3d0e1d5ea6cf", + "prefix_lm_causal-2840-32-32-16-4-4-256-256-4-dtype0": "c40bfd6dd27f4539405619e884c7258f", + "qk_sparse-2840-32-32-16-4-4-256-256-4-dtype0": "286fcb8839b0f962a0109b584eab48aa", + "random_eviction-2840-32-32-16-4-4-256-256-4-dtype0": "a466d532b280bb2b2eb12b98e4ecbb32", "full-1-300-300-16-16-1-64-64-4-dtype0": "4ba18c71d08d97bc2e5a7f16051a12c6", "causal-1-300-300-16-16-1-64-64-4-dtype0": "121992fa6d459eb365186f4e409d62b9", "sliding_window-1-300-300-16-16-1-64-64-4-dtype0": "7a16cf529bb605d2ea78425d51616fc4", @@ -175,6 +208,17 @@ "prefix_lm_causal-1-300-300-16-16-1-192-128-4-dtype0": "6e9b93bd8e4d50253784bcf7c4cda7c0", "qk_sparse-1-300-300-16-16-1-192-128-4-dtype0": "5f0ac1efc9f23f21f368b691b7a5da68", "random_eviction-1-300-300-16-16-1-192-128-4-dtype0": "f218fb60e7525a71d031bcb8232f01d9", + "full-1-300-300-16-16-1-256-256-4-dtype0": "e0c1062f5217a2ba4ce4114f16c23609", + "causal-1-300-300-16-16-1-256-256-4-dtype0": "2568ddce35d354959dc196110720b021", + "sliding_window-1-300-300-16-16-1-256-256-4-dtype0": "b8123eb8ee62220a82d68592f305d879", + "causal_document-1-300-300-16-16-1-256-256-4-dtype0": "938070c9f49eced813a85d385e7a8f5e", + "document-1-300-300-16-16-1-256-256-4-dtype0": "24b57c40f82e1be8bc4f86de4298166c", + "share_question-1-300-300-16-16-1-256-256-4-dtype0": "43581148de9114c4e4bcfe9e67965468", + "causal_blockwise-1-300-300-16-16-1-256-256-4-dtype0": "2c71794ecb6150f2024694cfb4fe065c", + "prefix_lm_document_mask-1-300-300-16-16-1-256-256-4-dtype0": "1a6d83ca81d6db6ab6ab569e2cc0c102", + "prefix_lm_causal-1-300-300-16-16-1-256-256-4-dtype0": "b93f670d92b8b7dd2b155bc110cef0c4", + "qk_sparse-1-300-300-16-16-1-256-256-4-dtype0": "e1a308ffadcf2a66ef85240315fcdbba", + "random_eviction-1-300-300-16-16-1-256-256-4-dtype0": "66dec9ae4de7850b7331ce2614d518b6", "full-1-300-300-16-16-16-64-64-4-dtype0": "4ba18c71d08d97bc2e5a7f16051a12c6", "causal-1-300-300-16-16-16-64-64-4-dtype0": "121992fa6d459eb365186f4e409d62b9", "sliding_window-1-300-300-16-16-16-64-64-4-dtype0": "7a16cf529bb605d2ea78425d51616fc4", @@ -219,6 +263,17 @@ "prefix_lm_causal-1-300-300-16-16-16-192-128-4-dtype0": "6e9b93bd8e4d50253784bcf7c4cda7c0", "qk_sparse-1-300-300-16-16-16-192-128-4-dtype0": "5f0ac1efc9f23f21f368b691b7a5da68", "random_eviction-1-300-300-16-16-16-192-128-4-dtype0": "a91d14fe5c2b3fba9ebe5fe21ef45db9", + "full-1-300-300-16-16-16-256-256-4-dtype0": "e0c1062f5217a2ba4ce4114f16c23609", + "causal-1-300-300-16-16-16-256-256-4-dtype0": "2568ddce35d354959dc196110720b021", + "sliding_window-1-300-300-16-16-16-256-256-4-dtype0": "b8123eb8ee62220a82d68592f305d879", + "causal_document-1-300-300-16-16-16-256-256-4-dtype0": "938070c9f49eced813a85d385e7a8f5e", + "document-1-300-300-16-16-16-256-256-4-dtype0": "24b57c40f82e1be8bc4f86de4298166c", + "share_question-1-300-300-16-16-16-256-256-4-dtype0": "43581148de9114c4e4bcfe9e67965468", + "causal_blockwise-1-300-300-16-16-16-256-256-4-dtype0": "2c71794ecb6150f2024694cfb4fe065c", + "prefix_lm_document_mask-1-300-300-16-16-16-256-256-4-dtype0": "1a6d83ca81d6db6ab6ab569e2cc0c102", + "prefix_lm_causal-1-300-300-16-16-16-256-256-4-dtype0": "b93f670d92b8b7dd2b155bc110cef0c4", + "qk_sparse-1-300-300-16-16-16-256-256-4-dtype0": "e1a308ffadcf2a66ef85240315fcdbba", + "random_eviction-1-300-300-16-16-16-256-256-4-dtype0": "e47afa39ccf066f9fac6d81d711820fa", "full-1-128-127-1-1-1-64-64-4-dtype0": "a28ce01d7ba95891fba4d549b8b03ed3", "causal-1-128-127-1-1-1-64-64-4-dtype0": "1f4699648cb95f5a9b15d3bfc86f6a5f", "sliding_window-1-128-127-1-1-1-64-64-4-dtype0": "4031fbc9e9d39f6ad46ea19baf306522", @@ -263,6 +318,17 @@ "prefix_lm_causal-1-128-127-1-1-1-192-128-4-dtype0": "f1337e64cf2a201b95153c28ff6a519e", "qk_sparse-1-128-127-1-1-1-192-128-4-dtype0": "154ce297cb176261c902192975dc13f5", "random_eviction-1-128-127-1-1-1-192-128-4-dtype0": "ad172ed052a16668ee214e7c8fe715f3", + "full-1-128-127-1-1-1-256-256-4-dtype0": "0c9efc3a3dcc8d9155def6d2f32a811d", + "causal-1-128-127-1-1-1-256-256-4-dtype0": "ba93b2bec16f4c0c512098b7b997152c", + "sliding_window-1-128-127-1-1-1-256-256-4-dtype0": "d28d97dba639520c9cf0b98f98db25c8", + "causal_document-1-128-127-1-1-1-256-256-4-dtype0": "07204ce5572d8960a3018c44d01b61e5", + "document-1-128-127-1-1-1-256-256-4-dtype0": "4d12d442e05c8a0223027c7387f42180", + "share_question-1-128-127-1-1-1-256-256-4-dtype0": "3960ae3e447aea0ea82db0e486cf1419", + "causal_blockwise-1-128-127-1-1-1-256-256-4-dtype0": "2c97c3ac6d2156259b2c3317d9945586", + "prefix_lm_document_mask-1-128-127-1-1-1-256-256-4-dtype0": "3396a4d92548686199e698c1b15e8766", + "prefix_lm_causal-1-128-127-1-1-1-256-256-4-dtype0": "f9473b4d062903a538f29a8dd87bcefe", + "qk_sparse-1-128-127-1-1-1-256-256-4-dtype0": "2eca07e59e0b46a92e41ec975495926d", + "random_eviction-1-128-127-1-1-1-256-256-4-dtype0": "7d72c778008cca62e9e4a11b44ae0775", "full-2-16384-16383-4-1-1-64-64-4-dtype0": "7698afe4e07f559126e38a343b5a4b31", "causal-2-16384-16383-4-1-1-64-64-4-dtype0": "ce2aaabb573e07da055c844be4970e8b", "sliding_window-2-16384-16383-4-1-1-64-64-4-dtype0": "234f06ec8f7e655f6161dfa51b56e9f6", @@ -307,6 +373,17 @@ "prefix_lm_causal-2-16384-16383-4-1-1-192-128-4-dtype0": "8321c594cd4852d80b24fefa6ee40452", "qk_sparse-2-16384-16383-4-1-1-192-128-4-dtype0": "0c7d7d772ff6f578e285d5f6e1f24d92", "random_eviction-2-16384-16383-4-1-1-192-128-4-dtype0": "bfaeafd020eec761a14a96ab4d8ff1b9", - "gt_commit_id": "7f507e048511557704cd270d58de360e07b536fa", + "full-2-16384-16383-4-1-1-256-256-4-dtype0": "77043465011d68aabb26b59df891deff", + "causal-2-16384-16383-4-1-1-256-256-4-dtype0": "afc28e04f074ed6ae077868a108d4bef", + "sliding_window-2-16384-16383-4-1-1-256-256-4-dtype0": "0cbdd9065fdf761b3bf08165c531b73d", + "causal_document-2-16384-16383-4-1-1-256-256-4-dtype0": "5d708152a9a561b225ab921a80d0e9bf", + "document-2-16384-16383-4-1-1-256-256-4-dtype0": "e5ac1cd70ee7f2012878bf3c30f49a40", + "share_question-2-16384-16383-4-1-1-256-256-4-dtype0": "11e7c8a9f4f2eee9c46d3576a6eadf6f", + "causal_blockwise-2-16384-16383-4-1-1-256-256-4-dtype0": "1fddc698ed0f089ba7d0df2f563afce6", + "prefix_lm_document_mask-2-16384-16383-4-1-1-256-256-4-dtype0": "dba14a2580babc53aff3fd971d4ba0ff", + "prefix_lm_causal-2-16384-16383-4-1-1-256-256-4-dtype0": "92214a8df8b8b7396e37d728df473305", + "qk_sparse-2-16384-16383-4-1-1-256-256-4-dtype0": "29c00d2f4d6928d068a19bd66ad64646", + "random_eviction-2-16384-16383-4-1-1-256-256-4-dtype0": "ae53e82b44d17b365c560b0f1c9f5e46", + "gt_commit_id": "e7805901d8a215494f6189fc0fa8e4c298b295a5", "gt_commit_msg": "refine whl version" } \ No newline at end of file diff --git a/kernel_test_seq_info.txt b/kernel_test_seq_info.txt index 43aa47d..67141c3 100644 --- a/kernel_test_seq_info.txt +++ b/kernel_test_seq_info.txt @@ -1,3 +1,10 @@ +Total length: 4096, Document count range: (2, 6) +Sample 1, num_docs 3: [(443, 1573), (148, 1101), (264, 1422)]# [0, 1, 0] +Sample 2, num_docs 4: [(161, 1240), (123, 694), (214, 989), (179, 1173)]# [0, 0, 1, 0] +Sample 3, num_docs 5: [(123, 521), (369, 1359), (196, 710), (138, 596), (164, 910)]# [0, 1, 1, 0, 0] +Sample 4, num_docs 6: [(155, 929), (109, 436), (78, 265), (170, 1030), (69, 544), (208, 892)]# [0, 1, 0, 0, 1, 0] +Sample 5, num_docs 7: [(57, 572), (32, 204), (242, 854), (150, 594), (84, 569), (75, 458), (245, 845)]# [0, 0, 1, 0, 0, 1, 0] + Total length: 8192, Document count range: (2, 6) Sample 1, num_docs 3: [(885, 3145), (296, 2203), (528, 2844)]# [0, 1, 0] Sample 2, num_docs 4: [(321, 2481), (245, 1388), (428, 1978), (357, 2345)]# [0, 0, 1, 0] diff --git a/run.sh b/run.sh index b5b37b5..1ca7390 100644 --- a/run.sh +++ b/run.sh @@ -16,6 +16,7 @@ export FLAGS_alloc_fill_value=255 export FLAGS_use_system_allocator=1 export FLAGS_check_cuda_error=1 +# export FLAGS_cudnn_deterministic=1 # ========================= # 默认:单卡整文件测试 diff --git a/run_aadiff.sh b/run_aadiff.sh index a659bb5..06cec8b 100644 --- a/run_aadiff.sh +++ b/run_aadiff.sh @@ -1,19 +1,25 @@ export CUDA_VISIBLE_DEVICES=3 -export FLAGS_alloc_fill_value=255 -export FLAGS_use_system_allocator=1 -export FLAGS_check_cuda_error=1 -python -m pytest -v test_fwd_md5sum.py 2>&1 | tee test.log -# python -m pytest -v test_bwd_md5sum.py 2>&1 | tee test.log - -# export FLAGS_cudnn_deterministic=1 +export FLAGS_cudnn_deterministic=1 # run this if you want to update gt # python test_fwd_md5sum.py # python test_bwd_md5sum.py +export FLAGS_alloc_fill_value=255 +export FLAGS_use_system_allocator=1 +export FLAGS_check_cuda_error=1 + +python3 -m pytest \ + test_fwd_md5sum.py \ + test_bwd_md5sum.py \ + -v 2>&1 | tee test_md5.log + # if you update flash attention varlen -# python -m pytest -v test_fwd_varlen_md5sum.py 2>&1 | tee test.log -# python -m pytest -v test_bwd_varlen_md5sum.py 2>&1 | tee test.log -# run this if you want to update gt # python test_fwd_varlen_md5sum.py # python test_bwd_varlen_md5sum.py + +# python3 -m pytest \ +# test_fwd_varlen_md5sum.py \ +# test_bwd_varlen_md5sum.py \ +# -v 2>&1 | tee test_varlen_md5.log + diff --git a/test_bwd_md5sum.py b/test_bwd_md5sum.py index 79abfac..7e6b43d 100644 --- a/test_bwd_md5sum.py +++ b/test_bwd_md5sum.py @@ -52,6 +52,7 @@ (80, 80), (128, 128), (192, 128), + (256, 256), ] def record_gt(output_file="flashmask_bwd_gt.json"): diff --git a/test_flashmask.py b/test_flashmask.py index e27d413..02a7cce 100644 --- a/test_flashmask.py +++ b/test_flashmask.py @@ -136,8 +136,8 @@ def test_flashmask( startend_row_indices, causal = gen_startend_row_indices(batch_size, seqlen_q, seqlen_k, nheads_startend_row_indices) - if fa_version == 4 and seqlen_q != seqlen_k and causal and d > 128: - pytest.skip(f"Skipping because running fa4 and {d=} > 128 and {seqlen_q=} {seqlen_k} {causal=}") + if fa_version == 4 and seqlen_q != seqlen_k and causal and (d > 128 and d != 256): + pytest.skip(f"Skipping because running fa4 and {d=} > 128 (except 256) and {seqlen_q=} {seqlen_k} {causal=}") if fa_version == 2 and seqlen_q != seqlen_k and causal: pytest.skip(f"Skipping because running fa2 in causal when seqlen_q != seqlen_k") diff --git a/test_fwd_md5sum.py b/test_fwd_md5sum.py index db3f20d..09faad9 100644 --- a/test_fwd_md5sum.py +++ b/test_fwd_md5sum.py @@ -52,6 +52,7 @@ (80, 80), (128, 128), (192, 128), + (256, 256), ] def record_gt(output_file="flashmask_fwd_gt.json"):