26 changes: 17 additions & 9 deletions crates/core_arch/src/x86/avx2.rs
@@ -1813,14 +1813,20 @@ pub const fn _mm256_inserti128_si256<const IMM1: i32>(a: __m256i, b: __m128i) ->
#[target_feature(enable = "avx2")]
#[cfg_attr(test, assert_instr(vpmaddwd))]
#[stable(feature = "simd_x86", since = "1.27.0")]
#[rustc_const_unstable(feature = "stdarch_const_x86", issue = "149298")]
pub const fn _mm256_madd_epi16(a: __m256i, b: __m256i) -> __m256i {
unsafe {
let r: i32x16 = simd_mul(simd_cast(a.as_i16x16()), simd_cast(b.as_i16x16()));
let even: i32x8 = simd_shuffle!(r, r, [0, 2, 4, 6, 8, 10, 12, 14]);
let odd: i32x8 = simd_shuffle!(r, r, [1, 3, 5, 7, 9, 11, 13, 15]);
simd_add(even, odd).as_m256i()
}
pub fn _mm256_madd_epi16(a: __m256i, b: __m256i) -> __m256i {
// It's a trick used in the Adler-32 algorithm to perform a widening addition.
//
// ```rust
// #[target_feature(enable = "avx2")]
// unsafe fn widening_add(mad: __m256i) -> __m256i {
// _mm256_madd_epi16(mad, _mm256_set1_epi16(1))
// }
// ```
//
// If we implement this using generic vector intrinsics, the optimizer
// will eliminate this pattern, and `vpmaddwd` will no longer be emitted.
// For this reason, we use x86 intrinsics.
unsafe { transmute(pmaddwd(a.as_i16x16(), b.as_i16x16())) }
}

/// Vertically multiplies each unsigned 8-bit integer from `a` with the
@@ -3789,6 +3795,8 @@ unsafe extern "C" {
    fn phaddsw(a: i16x16, b: i16x16) -> i16x16;
    #[link_name = "llvm.x86.avx2.phsub.sw"]
    fn phsubsw(a: i16x16, b: i16x16) -> i16x16;
+    #[link_name = "llvm.x86.avx2.pmadd.wd"]
+    fn pmaddwd(a: i16x16, b: i16x16) -> i32x8;
    #[link_name = "llvm.x86.avx2.pmadd.ub.sw"]
    fn pmaddubsw(a: u8x32, b: i8x32) -> i16x16;
    #[link_name = "llvm.x86.avx2.mpsadbw"]
@@ -4637,7 +4645,7 @@ mod tests {
    }

    #[simd_test(enable = "avx2")]
-    const fn test_mm256_madd_epi16() {
+    fn test_mm256_madd_epi16() {
        let a = _mm256_set1_epi16(2);
        let b = _mm256_set1_epi16(4);
        let r = _mm256_madd_epi16(a, b);
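For reference, here is a scalar model of what `pmaddwd`/`vpmaddwd` computes; this is a sketch for readers of the hunks above and below, not part of the patch, and the helper name is illustrative. Each output lane sums two adjacent widened products, which is why the `set1_epi16(1)` trick in the Adler-32 snippet behaves as a pairwise widening add:

```rust
/// Scalar model of `pmaddwd` for the 256-bit case (illustrative sketch,
/// mirroring the shuffle-based implementation deleted above):
/// out[i] = a[2i] * b[2i] + a[2i + 1] * b[2i + 1], widened to i32.
fn pmaddwd_model(a: [i16; 16], b: [i16; 16]) -> [i32; 8] {
    let mut out = [0i32; 8];
    for i in 0..8 {
        // Each i16 * i16 product fits in i32; the only sum that can overflow
        // is 2 * (-32768 * -32768), which the real instruction wraps to
        // i32::MIN, so wrapping_add matches the hardware behavior.
        out[i] = ((a[2 * i] as i32) * (b[2 * i] as i32))
            .wrapping_add((a[2 * i + 1] as i32) * (b[2 * i + 1] as i32));
    }
    out
}
```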
64 changes: 29 additions & 35 deletions crates/core_arch/src/x86/avx512bw.rs
@@ -6321,22 +6321,20 @@ pub const unsafe fn _mm_mask_storeu_epi8(mem_addr: *mut i8, mask: __mmask16, a:
#[target_feature(enable = "avx512bw")]
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
#[cfg_attr(test, assert_instr(vpmaddwd))]
#[rustc_const_unstable(feature = "stdarch_const_x86", issue = "149298")]
pub const fn _mm512_madd_epi16(a: __m512i, b: __m512i) -> __m512i {
unsafe {
let r: i32x32 = simd_mul(simd_cast(a.as_i16x32()), simd_cast(b.as_i16x32()));
let even: i32x16 = simd_shuffle!(
r,
r,
[0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30]
);
let odd: i32x16 = simd_shuffle!(
r,
r,
[1, 3, 5, 7, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31]
);
simd_add(even, odd).as_m512i()
}
pub fn _mm512_madd_epi16(a: __m512i, b: __m512i) -> __m512i {
// It's a trick used in the Adler-32 algorithm to perform a widening addition.
//
// ```rust
// #[target_feature(enable = "avx512bw")]
// unsafe fn widening_add(mad: __m512i) -> __m512i {
// _mm512_madd_epi16(mad, _mm512_set1_epi16(1))
// }
// ```
//
// If we implement this using generic vector intrinsics, the optimizer
// will eliminate this pattern, and `vpmaddwd` will no longer be emitted.
// For this reason, we use x86 intrinsics.
unsafe { transmute(vpmaddwd(a.as_i16x32(), b.as_i16x32())) }
}

/// Multiply packed signed 16-bit integers in a and b, producing intermediate signed 32-bit integers. Horizontally add adjacent pairs of intermediate 32-bit integers, and pack the results in dst using writemask k (elements are copied from src when the corresponding mask bit is not set).
@@ -6346,8 +6344,7 @@ pub const fn _mm512_madd_epi16(a: __m512i, b: __m512i) -> __m512i {
#[target_feature(enable = "avx512bw")]
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
#[cfg_attr(test, assert_instr(vpmaddwd))]
#[rustc_const_unstable(feature = "stdarch_const_x86", issue = "149298")]
pub const fn _mm512_mask_madd_epi16(src: __m512i, k: __mmask16, a: __m512i, b: __m512i) -> __m512i {
pub fn _mm512_mask_madd_epi16(src: __m512i, k: __mmask16, a: __m512i, b: __m512i) -> __m512i {
unsafe {
let madd = _mm512_madd_epi16(a, b).as_i32x16();
transmute(simd_select_bitmask(k, madd, src.as_i32x16()))
@@ -6361,8 +6358,7 @@ pub const fn _mm512_mask_madd_epi16(src: __m512i, k: __mmask16, a: __m512i, b: _
#[target_feature(enable = "avx512bw")]
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
#[cfg_attr(test, assert_instr(vpmaddwd))]
#[rustc_const_unstable(feature = "stdarch_const_x86", issue = "149298")]
pub const fn _mm512_maskz_madd_epi16(k: __mmask16, a: __m512i, b: __m512i) -> __m512i {
pub fn _mm512_maskz_madd_epi16(k: __mmask16, a: __m512i, b: __m512i) -> __m512i {
unsafe {
let madd = _mm512_madd_epi16(a, b).as_i32x16();
transmute(simd_select_bitmask(k, madd, i32x16::ZERO))
@@ -6376,8 +6372,7 @@ pub const fn _mm512_maskz_madd_epi16(k: __mmask16, a: __m512i, b: __m512i) -> __
#[target_feature(enable = "avx512bw,avx512vl")]
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
#[cfg_attr(test, assert_instr(vpmaddwd))]
#[rustc_const_unstable(feature = "stdarch_const_x86", issue = "149298")]
pub const fn _mm256_mask_madd_epi16(src: __m256i, k: __mmask8, a: __m256i, b: __m256i) -> __m256i {
pub fn _mm256_mask_madd_epi16(src: __m256i, k: __mmask8, a: __m256i, b: __m256i) -> __m256i {
unsafe {
let madd = _mm256_madd_epi16(a, b).as_i32x8();
transmute(simd_select_bitmask(k, madd, src.as_i32x8()))
@@ -6391,8 +6386,7 @@ pub const fn _mm256_mask_madd_epi16(src: __m256i, k: __mmask8, a: __m256i, b: __
#[target_feature(enable = "avx512bw,avx512vl")]
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
#[cfg_attr(test, assert_instr(vpmaddwd))]
#[rustc_const_unstable(feature = "stdarch_const_x86", issue = "149298")]
pub const fn _mm256_maskz_madd_epi16(k: __mmask8, a: __m256i, b: __m256i) -> __m256i {
pub fn _mm256_maskz_madd_epi16(k: __mmask8, a: __m256i, b: __m256i) -> __m256i {
unsafe {
let madd = _mm256_madd_epi16(a, b).as_i32x8();
transmute(simd_select_bitmask(k, madd, i32x8::ZERO))
@@ -6406,8 +6400,7 @@ pub const fn _mm256_maskz_madd_epi16(k: __mmask8, a: __m256i, b: __m256i) -> __m
#[target_feature(enable = "avx512bw,avx512vl")]
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
#[cfg_attr(test, assert_instr(vpmaddwd))]
#[rustc_const_unstable(feature = "stdarch_const_x86", issue = "149298")]
pub const fn _mm_mask_madd_epi16(src: __m128i, k: __mmask8, a: __m128i, b: __m128i) -> __m128i {
pub fn _mm_mask_madd_epi16(src: __m128i, k: __mmask8, a: __m128i, b: __m128i) -> __m128i {
unsafe {
let madd = _mm_madd_epi16(a, b).as_i32x4();
transmute(simd_select_bitmask(k, madd, src.as_i32x4()))
@@ -6421,8 +6414,7 @@ pub const fn _mm_mask_madd_epi16(src: __m128i, k: __mmask8, a: __m128i, b: __m12
#[target_feature(enable = "avx512bw,avx512vl")]
#[stable(feature = "stdarch_x86_avx512", since = "1.89")]
#[cfg_attr(test, assert_instr(vpmaddwd))]
#[rustc_const_unstable(feature = "stdarch_const_x86", issue = "149298")]
pub const fn _mm_maskz_madd_epi16(k: __mmask8, a: __m128i, b: __m128i) -> __m128i {
pub fn _mm_maskz_madd_epi16(k: __mmask8, a: __m128i, b: __m128i) -> __m128i {
unsafe {
let madd = _mm_madd_epi16(a, b).as_i32x4();
transmute(simd_select_bitmask(k, madd, i32x4::ZERO))
@@ -12582,6 +12574,8 @@ unsafe extern "C" {
#[link_name = "llvm.x86.avx512.pmul.hr.sw.512"]
fn vpmulhrsw(a: i16x32, b: i16x32) -> i16x32;

#[link_name = "llvm.x86.avx512.pmaddw.d.512"]
fn vpmaddwd(a: i16x32, b: i16x32) -> i32x16;
#[link_name = "llvm.x86.avx512.pmaddubs.w.512"]
fn vpmaddubsw(a: u8x64, b: i8x64) -> i16x32;

@@ -17486,7 +17480,7 @@ mod tests {
    }

    #[simd_test(enable = "avx512bw")]
-    const fn test_mm512_madd_epi16() {
+    fn test_mm512_madd_epi16() {
        let a = _mm512_set1_epi16(1);
        let b = _mm512_set1_epi16(1);
        let r = _mm512_madd_epi16(a, b);
@@ -17495,7 +17489,7 @@ mod tests {
    }

    #[simd_test(enable = "avx512bw")]
-    const fn test_mm512_mask_madd_epi16() {
+    fn test_mm512_mask_madd_epi16() {
        let a = _mm512_set1_epi16(1);
        let b = _mm512_set1_epi16(1);
        let r = _mm512_mask_madd_epi16(a, 0, a, b);
@@ -17523,7 +17517,7 @@ mod tests {
    }

    #[simd_test(enable = "avx512bw")]
-    const fn test_mm512_maskz_madd_epi16() {
+    fn test_mm512_maskz_madd_epi16() {
        let a = _mm512_set1_epi16(1);
        let b = _mm512_set1_epi16(1);
        let r = _mm512_maskz_madd_epi16(0, a, b);
@@ -17534,7 +17528,7 @@ mod tests {
    }

    #[simd_test(enable = "avx512bw,avx512vl")]
-    const fn test_mm256_mask_madd_epi16() {
+    fn test_mm256_mask_madd_epi16() {
        let a = _mm256_set1_epi16(1);
        let b = _mm256_set1_epi16(1);
        let r = _mm256_mask_madd_epi16(a, 0, a, b);
@@ -17554,7 +17548,7 @@ mod tests {
    }

    #[simd_test(enable = "avx512bw,avx512vl")]
-    const fn test_mm256_maskz_madd_epi16() {
+    fn test_mm256_maskz_madd_epi16() {
        let a = _mm256_set1_epi16(1);
        let b = _mm256_set1_epi16(1);
        let r = _mm256_maskz_madd_epi16(0, a, b);
@@ -17565,7 +17559,7 @@ mod tests {
    }

    #[simd_test(enable = "avx512bw,avx512vl")]
-    const fn test_mm_mask_madd_epi16() {
+    fn test_mm_mask_madd_epi16() {
        let a = _mm_set1_epi16(1);
        let b = _mm_set1_epi16(1);
        let r = _mm_mask_madd_epi16(a, 0, a, b);
@@ -17576,7 +17570,7 @@ mod tests {
    }

    #[simd_test(enable = "avx512bw,avx512vl")]
-    const fn test_mm_maskz_madd_epi16() {
+    fn test_mm_maskz_madd_epi16() {
        let a = _mm_set1_epi16(1);
        let b = _mm_set1_epi16(1);
        let r = _mm_maskz_madd_epi16(0, a, b);
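All six masked variants above share the same `simd_select_bitmask` shape. A scalar sketch of that lane selection (hypothetical helper, shown for the 512-bit/16-lane case; not part of the patch): each bit of `k` picks between the madd result and the corresponding lane of `src`:

```rust
/// Scalar model of the mask/maskz selection used above (illustrative):
/// dst[i] = if bit i of k is set { madd[i] } else { src[i] }.
fn mask_select_model(src: [i32; 16], k: u16, madd: [i32; 16]) -> [i32; 16] {
    let mut dst = src;
    for i in 0..16 {
        if k & (1 << i) != 0 {
            dst[i] = madd[i];
        }
    }
    dst
}
```

For the maskz forms, pass `[0i32; 16]` as `src`, matching the `i32x16::ZERO` fallback in the code above.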
26 changes: 17 additions & 9 deletions crates/core_arch/src/x86/sse2.rs
@@ -210,14 +210,20 @@ pub const fn _mm_avg_epu16(a: __m128i, b: __m128i) -> __m128i {
#[target_feature(enable = "sse2")]
#[cfg_attr(test, assert_instr(pmaddwd))]
#[stable(feature = "simd_x86", since = "1.27.0")]
#[rustc_const_unstable(feature = "stdarch_const_x86", issue = "149298")]
pub const fn _mm_madd_epi16(a: __m128i, b: __m128i) -> __m128i {
unsafe {
let r: i32x8 = simd_mul(simd_cast(a.as_i16x8()), simd_cast(b.as_i16x8()));
let even: i32x4 = simd_shuffle!(r, r, [0, 2, 4, 6]);
let odd: i32x4 = simd_shuffle!(r, r, [1, 3, 5, 7]);
simd_add(even, odd).as_m128i()
}
pub fn _mm_madd_epi16(a: __m128i, b: __m128i) -> __m128i {
// It's a trick used in the Adler-32 algorithm to perform a widening addition.
//
// ```rust
// #[target_feature(enable = "sse2")]
// unsafe fn widening_add(mad: __m128i) -> __m128i {
// _mm_madd_epi16(mad, _mm_set1_epi16(1))
// }
// ```
//
// If we implement this using generic vector intrinsics, the optimizer
// will eliminate this pattern, and `pmaddwd` will no longer be emitted.
// For this reason, we use x86 intrinsics.
unsafe { transmute(pmaddwd(a.as_i16x8(), b.as_i16x8())) }
}
Comment on lines -213 to 227

Contributor: We could consider using https://doc.rust-lang.org/std/intrinsics/fn.const_eval_select.html so we don't lose all of the const stuff. Up to @sayantn though, I don't have full context on what we'd like to be const fn currently.

Contributor Author: #[target_feature] functions do not implement the Fn traits, while const_eval_select requires its arguments to implement FnOnce, so this does not seem feasible.

Contributor: We can still do it: the inner function that gets invoked doesn't need the correct target features if it is marked #[inline] (#[inline(always)] will probably be better; see Godbolt). But I don't think it is that important to make this const right now. A better approach would be to fix the LLVM bug and make this truly const in the future; this should be enough for this patch.
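A rough sketch of the shape being proposed in this thread, assuming `const_eval_select` keeps its current nightly signature in `core::intrinsics`; this is not compile-checked and the helper names are hypothetical. The const arm reuses the shuffle-based formulation this PR deletes (no instruction selection happens at compile time), while the runtime arm keeps the LLVM intrinsic so `pmaddwd` is still emitted:

```rust
// Hypothetical sketch of the review suggestion, not part of the patch.
// Shown only to illustrate the #[inline(always)] trick that sidesteps
// the Fn-trait restriction on #[target_feature] functions.
pub const fn _mm_madd_epi16(a: __m128i, b: __m128i) -> __m128i {
    const fn ct(a: __m128i, b: __m128i) -> __m128i {
        // Generic-vector version, evaluable in const contexts.
        unsafe {
            let r: i32x8 = simd_mul(simd_cast(a.as_i16x8()), simd_cast(b.as_i16x8()));
            let even: i32x4 = simd_shuffle!(r, r, [0, 2, 4, 6]);
            let odd: i32x4 = simd_shuffle!(r, r, [1, 3, 5, 7]);
            simd_add(even, odd).as_m128i()
        }
    }
    // No #[target_feature] here (it would break the FnOnce bound);
    // #[inline(always)] lets the body inherit the caller's features.
    #[inline(always)]
    fn rt(a: __m128i, b: __m128i) -> __m128i {
        unsafe { transmute(pmaddwd(a.as_i16x8(), b.as_i16x8())) }
    }
    core::intrinsics::const_eval_select((a, b), ct, rt)
}
```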


/// Compares packed 16-bit integers in `a` and `b`, and returns the packed
@@ -3187,6 +3193,8 @@ unsafe extern "C" {
    fn lfence();
    #[link_name = "llvm.x86.sse2.mfence"]
    fn mfence();
+    #[link_name = "llvm.x86.sse2.pmadd.wd"]
+    fn pmaddwd(a: i16x8, b: i16x8) -> i32x4;
    #[link_name = "llvm.x86.sse2.psad.bw"]
    fn psadbw(a: u8x16, b: u8x16) -> u64x2;
    #[link_name = "llvm.x86.sse2.psll.w"]
@@ -3467,7 +3475,7 @@ mod tests {
    }

    #[simd_test(enable = "sse2")]
-    const fn test_mm_madd_epi16() {
+    fn test_mm_madd_epi16() {
        let a = _mm_setr_epi16(1, 2, 3, 4, 5, 6, 7, 8);
        let b = _mm_setr_epi16(9, 10, 11, 12, 13, 14, 15, 16);
        let r = _mm_madd_epi16(a, b);