Commit 859e386176 for aom
commit 859e3861762eccae630c440278b4eb124c03eeff
Author: Jerome Jiang <jianj@google.com>
Date: Mon May 11 12:55:17 2026 -0400
Update highway to adf91c746
Change-Id: I50303d4fde8ae1f06ca96be0926965dacc761f47
diff --git a/third_party/highway/README.libaom b/third_party/highway/README.libaom
index 7fc5db9998..b19ee19117 100644
--- a/third_party/highway/README.libaom
+++ b/third_party/highway/README.libaom
@@ -3,7 +3,7 @@ Short Name: highway
URL: https://github.com/google/highway
Version: N/A
Update Mechanism: Manual
-Revision: e92c12750d18c372867809b882dd3ec6874ecc73
+Revision: adf91c746e70c58389e5793e658ba4d04051b138
License: BSD-3-Clause
License File: LICENSE-BSD3
Shipped in Chromium: no
@@ -14,3 +14,6 @@ Highway is a C++ library that provides portable SIMD/vector intrinsics.
Local Changes:
Remove everything except hwy/ and LICENSE-BSD3
+Remove hwy/examples and hwy/tests
+Remove hwy/*test.cc
+Update include path by adding "third_party/highway"
diff --git a/third_party/highway/hwy/abort.cc b/third_party/highway/hwy/abort.cc
new file mode 100644
index 0000000000..3f535e01eb
--- /dev/null
+++ b/third_party/highway/hwy/abort.cc
@@ -0,0 +1,117 @@
+// Copyright 2019 Google LLC
+// Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-License-Identifier: BSD-3-Clause
+
+#include "third_party/highway/hwy/abort.h"
+
+#include <stdarg.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+#include <atomic>
+#include <string>
+
+#include "third_party/highway/hwy/base.h"
+
+#if HWY_IS_ASAN || HWY_IS_MSAN || HWY_IS_TSAN
+#include "sanitizer/common_interface_defs.h" // __sanitizer_print_stack_trace
+#endif
+
+namespace hwy {
+
+namespace {
+
+std::atomic<WarnFunc>& AtomicWarnFunc() {
+ static std::atomic<WarnFunc> func;
+ return func;
+}
+
+std::atomic<AbortFunc>& AtomicAbortFunc() {
+ static std::atomic<AbortFunc> func;
+ return func;
+}
+
+std::string GetBaseName(std::string const& file_name) {
+ auto last_slash = file_name.find_last_of("/\\");
+ return file_name.substr(last_slash + 1);
+}
+
+} // namespace
+
+// Returning a reference is unfortunately incompatible with `std::atomic`, which
+// is required to safely implement `SetWarnFunc`. As a workaround, we store a
+// copy here, update it when called, and return a reference to the copy. This
+// has the added benefit of protecting the actual pointer from modification.
+HWY_DLLEXPORT WarnFunc& GetWarnFunc() {
+ static WarnFunc func;
+ func = AtomicWarnFunc().load();
+ return func;
+}
+
+HWY_DLLEXPORT AbortFunc& GetAbortFunc() {
+ static AbortFunc func;
+ func = AtomicAbortFunc().load();
+ return func;
+}
+
+HWY_DLLEXPORT WarnFunc SetWarnFunc(WarnFunc func) {
+ return AtomicWarnFunc().exchange(func);
+}
+
+HWY_DLLEXPORT AbortFunc SetAbortFunc(AbortFunc func) {
+ return AtomicAbortFunc().exchange(func);
+}
+
+HWY_DLLEXPORT void HWY_FORMAT(3, 4)
+ Warn(const char* file, int line, const char* format, ...) {
+ char buf[800];
+ va_list args;
+ va_start(args, format);
+ vsnprintf(buf, sizeof(buf), format, args);
+ va_end(args);
+
+ WarnFunc handler = AtomicWarnFunc().load();
+ if (handler != nullptr) {
+ handler(file, line, buf);
+ } else {
+ fprintf(stderr, "Warn at %s:%d: %s\n", GetBaseName(file).data(), line, buf);
+ }
+}
+
+HWY_DLLEXPORT HWY_NORETURN void HWY_FORMAT(3, 4)
+ Abort(const char* file, int line, const char* format, ...) {
+ char buf[800];
+ va_list args;
+ va_start(args, format);
+ vsnprintf(buf, sizeof(buf), format, args);
+ va_end(args);
+
+ AbortFunc handler = AtomicAbortFunc().load();
+ if (handler != nullptr) {
+ handler(file, line, buf);
+ } else {
+ fprintf(stderr, "Abort at %s:%d: %s\n", GetBaseName(file).data(), line,
+ buf);
+ }
+
+// If compiled with any sanitizer, it can also print a stack trace.
+#if HWY_IS_ASAN || HWY_IS_MSAN || HWY_IS_TSAN
+ __sanitizer_print_stack_trace();
+#endif // HWY_IS_*
+ fflush(stderr);
+
+// Now terminate the program:
+#if HWY_ARCH_RISCV
+ exit(1); // trap/abort just freeze Spike.
+#elif HWY_IS_DEBUG_BUILD && !HWY_COMPILER_MSVC && !HWY_ARCH_ARM
+ // Facilitates breaking into a debugger, but don't use this in non-debug
+ // builds because it looks like "illegal instruction", which is misleading.
+ // Also does not work on Arm.
+ __builtin_trap();
+#else
+ abort(); // Compile error without this due to HWY_NORETURN.
+#endif
+}
+
+} // namespace hwy
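
The handlers installed via `SetWarnFunc`/`SetAbortFunc` let embedders redirect
Highway diagnostics instead of writing to stderr. A minimal sketch of
installing a warning handler; the handler signature is inferred from the
`handler(file, line, buf)` calls above:

    #include <cstdio>

    #include "third_party/highway/hwy/abort.h"

    // Hypothetical handler: routes warnings to stdout instead of stderr.
    void MyWarnHandler(const char* file, int line, const char* message) {
      std::fprintf(stdout, "[hwy warn] %s:%d: %s\n", file, line, message);
    }

    int main() {
      // SetWarnFunc returns the previously installed handler (or nullptr).
      hwy::WarnFunc prev = hwy::SetWarnFunc(&MyWarnHandler);
      hwy::Warn(__FILE__, __LINE__, "value %d out of range", 42);
      hwy::SetWarnFunc(prev);  // restore the default stderr path
      return 0;
    }
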
diff --git a/third_party/highway/hwy/aligned_allocator.cc b/third_party/highway/hwy/aligned_allocator.cc
new file mode 100644
index 0000000000..a08f56a97c
--- /dev/null
+++ b/third_party/highway/hwy/aligned_allocator.cc
@@ -0,0 +1,156 @@
+// Copyright 2019 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "third_party/highway/hwy/aligned_allocator.h"
+
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h> // malloc
+
+#include <atomic>
+#include <limits>
+
+#include "third_party/highway/hwy/base.h"
+
+namespace hwy {
+namespace {
+
+#if HWY_ARCH_RISCV && defined(__riscv_v_intrinsic) && \
+ __riscv_v_intrinsic >= 11000
+// Not actually an upper bound on the size, but this value prevents crossing a
+// 4K boundary (relevant on Andes).
+constexpr size_t kAlignment = HWY_MAX(HWY_ALIGNMENT, 4096);
+#else
+constexpr size_t kAlignment = HWY_ALIGNMENT;
+#endif
+
+#if HWY_ARCH_X86
+// On x86, aliasing can only occur at multiples of 2K. To reduce the chance of
+// allocations being equal mod 2K, we round up to kAlias and add a cyclic
+// offset which is a multiple of kAlignment. Rounding up to only 1K decreases
+// the number of alias-free allocations, but also wastes less memory.
+constexpr size_t kAlias = HWY_MAX(kAlignment, 1024);
+#else
+constexpr size_t kAlias = kAlignment;
+#endif
+
+#pragma pack(push, 1)
+struct AllocationHeader {
+ void* allocated;
+ size_t payload_size;
+};
+#pragma pack(pop)
+
+// Returns a 'random' (cyclical) offset for AllocateAlignedBytes.
+size_t NextAlignedOffset() {
+ static std::atomic<size_t> next{0};
+ static_assert(kAlias % kAlignment == 0, "kAlias must be a multiple");
+ constexpr size_t kGroups = kAlias / kAlignment;
+ const size_t group = next.fetch_add(1, std::memory_order_relaxed) % kGroups;
+ const size_t offset = kAlignment * group;
+ HWY_DASSERT((offset % kAlignment == 0) && offset <= kAlias);
+ return offset;
+}
+
+} // namespace
+
+HWY_DLLEXPORT void* AllocateAlignedBytes(const size_t payload_size,
+ AllocPtr alloc_ptr, void* opaque_ptr) {
+ HWY_ASSERT(payload_size != 0); // likely a bug in caller
+ if (payload_size >= std::numeric_limits<size_t>::max() / 2) {
+ HWY_DASSERT(false && "payload_size too large");
+ return nullptr;
+ }
+
+ size_t offset = NextAlignedOffset();
+
+ // What: | misalign | unused | AllocationHeader |payload
+ // Size: |<= kAlias | offset |payload_size
+ // ^allocated.^aligned.^header............^payload
+ // The header must immediately precede payload, which must remain aligned.
+ // To avoid wasting space, the header resides at the end of `unused`,
+ // which therefore cannot be empty (offset == 0).
+ if (offset == 0) {
+ offset = RoundUpTo(sizeof(AllocationHeader), kAlignment);
+ }
+
+ const size_t allocated_size = kAlias + offset + payload_size;
+ void* allocated;
+ if (alloc_ptr == nullptr) {
+ allocated = malloc(allocated_size);
+ } else {
+ allocated = (*alloc_ptr)(opaque_ptr, allocated_size);
+ }
+ if (allocated == nullptr) return nullptr;
+ // Always round up even if already aligned - we already asked for kAlias
+ // extra bytes and there's no way to give them back.
+ uintptr_t aligned = reinterpret_cast<uintptr_t>(allocated) + kAlias;
+ static_assert((kAlias & (kAlias - 1)) == 0, "kAlias must be a power of 2");
+ static_assert(kAlias >= kAlignment, "Cannot align to more than kAlias");
+ aligned &= ~(kAlias - 1);
+
+ const uintptr_t payload = aligned + offset; // still aligned
+ HWY_DASSERT(payload % kAlignment == 0);
+
+ // Stash `allocated` and payload_size inside header for FreeAlignedBytes().
+ // The allocated_size can be reconstructed from the payload_size.
+ AllocationHeader* header = reinterpret_cast<AllocationHeader*>(payload) - 1;
+ HWY_DASSERT(reinterpret_cast<uintptr_t>(header) >= aligned);
+ header->allocated = allocated;
+ header->payload_size = payload_size;
+
+ return HWY_ASSUME_ALIGNED(reinterpret_cast<void*>(payload), kAlignment);
+}
+
+HWY_DLLEXPORT void FreeAlignedBytes(const void* aligned_pointer,
+ FreePtr free_ptr, void* opaque_ptr) {
+ if (aligned_pointer == nullptr) return;
+
+ const uintptr_t payload = reinterpret_cast<uintptr_t>(aligned_pointer);
+ HWY_DASSERT(payload % kAlignment == 0);
+ const AllocationHeader* header =
+ reinterpret_cast<const AllocationHeader*>(payload) - 1;
+
+ if (free_ptr == nullptr) {
+ free(header->allocated);
+ } else {
+ (*free_ptr)(opaque_ptr, header->allocated);
+ }
+}
+
+// static
+HWY_DLLEXPORT void AlignedDeleter::DeleteAlignedArray(void* aligned_pointer,
+ FreePtr free_ptr,
+ void* opaque_ptr,
+ ArrayDeleter deleter) {
+ if (aligned_pointer == nullptr) return;
+
+ const uintptr_t payload = reinterpret_cast<uintptr_t>(aligned_pointer);
+ HWY_DASSERT(payload % kAlignment == 0);
+ const AllocationHeader* header =
+ reinterpret_cast<const AllocationHeader*>(payload) - 1;
+
+ if (deleter) {
+ (*deleter)(aligned_pointer, header->payload_size);
+ }
+
+ if (free_ptr == nullptr) {
+ free(header->allocated);
+ } else {
+ (*free_ptr)(opaque_ptr, header->allocated);
+ }
+}
+
+} // namespace hwy
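
The allocator hides an `AllocationHeader` (original pointer plus payload size)
immediately before the returned payload, and staggers successive allocations by
a cyclic multiple of `kAlignment` so they are less likely to alias modulo 2K on
x86. A minimal usage sketch of the raw byte interface above; null alloc/free
pointers select the built-in malloc/free path:

    #include "third_party/highway/hwy/aligned_allocator.h"

    int main() {
      // 256 bytes, aligned to at least HWY_ALIGNMENT.
      void* p = hwy::AllocateAlignedBytes(256, /*alloc_ptr=*/nullptr,
                                          /*opaque_ptr=*/nullptr);
      if (p == nullptr) return 1;
      static_cast<unsigned char*>(p)[0] = 0xFF;
      // Reads the hidden header to recover the original allocation.
      hwy::FreeAlignedBytes(p, /*free_ptr=*/nullptr, /*opaque_ptr=*/nullptr);
      return 0;
    }
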
diff --git a/third_party/highway/hwy/aligned_allocator.h b/third_party/highway/hwy/aligned_allocator.h
index 149f18e65c..17d8b4c2e9 100644
--- a/third_party/highway/hwy/aligned_allocator.h
+++ b/third_party/highway/hwy/aligned_allocator.h
@@ -40,6 +40,7 @@ namespace hwy {
// access pairs of lines, and M1 L2 and POWER8 lines are also 128 bytes.
#define HWY_ALIGNMENT 128
+// `align` is in bytes.
template <typename T>
HWY_API constexpr bool IsAligned(T* ptr, size_t align = HWY_ALIGNMENT) {
return reinterpret_cast<uintptr_t>(ptr) % align == 0;
@@ -118,8 +119,21 @@ template <typename T, typename... Args>
AlignedUniquePtr<T> MakeUniqueAlignedWithAlloc(AllocPtr alloc, FreePtr free,
void* opaque, Args&&... args) {
T* ptr = static_cast<T*>(AllocateAlignedBytes(sizeof(T), alloc, opaque));
+ if (HWY_UNLIKELY(ptr == nullptr)) {
+ return AlignedUniquePtr<T>(nullptr, AlignedDeleter(free, opaque));
+ }
+#ifdef HWY_EXCEPTIONS_ENABLED
+ try {
+ return AlignedUniquePtr<T>(new (ptr) T(std::forward<Args>(args)...),
+ AlignedDeleter(free, opaque));
+ } catch (...) {
+ FreeAlignedBytes(ptr, free, opaque);
+ throw;
+ }
+#else
return AlignedUniquePtr<T>(new (ptr) T(std::forward<Args>(args)...),
AlignedDeleter(free, opaque));
+#endif
}
// Similar to MakeUniqueAlignedWithAlloc but using the default alloc/free
@@ -127,8 +141,21 @@ AlignedUniquePtr<T> MakeUniqueAlignedWithAlloc(AllocPtr alloc, FreePtr free,
template <typename T, typename... Args>
AlignedUniquePtr<T> MakeUniqueAligned(Args&&... args) {
T* ptr = static_cast<T*>(AllocateAlignedBytes(sizeof(T)));
+ if (HWY_UNLIKELY(ptr == nullptr)) {
+ return AlignedUniquePtr<T>(nullptr, AlignedDeleter());
+ }
+#ifdef HWY_EXCEPTIONS_ENABLED
+ try {
+ return AlignedUniquePtr<T>(new (ptr) T(std::forward<Args>(args)...),
+ AlignedDeleter());
+ } catch (...) {
+ FreeAlignedBytes(ptr, nullptr, nullptr);
+ throw;
+ }
+#else
return AlignedUniquePtr<T>(new (ptr) T(std::forward<Args>(args)...),
AlignedDeleter());
+#endif
}
template <class T>
@@ -146,8 +173,13 @@ struct AlignedAllocator {
"AlignedAllocator only supports integer types");
static_assert(sizeof(V) <= sizeof(std::size_t),
"V n must be smaller or equal size_t to avoid overflow");
+ const size_t count = static_cast<std::size_t>(n);
+ if (HWY_LIKELY(count != 0) && sizeof(value_type) > SIZE_MAX / count) {
+ HWY_ABORT("AlignedAllocator: allocation size overflow "
+ "(%zu * %zu exceeds size_t)", count, sizeof(value_type));
+ }
return static_cast<value_type*>(
- AllocateAlignedBytes(static_cast<std::size_t>(n) * sizeof(value_type)));
+ AllocateAlignedBytes(count * sizeof(value_type)));
}
template <class V>
@@ -207,7 +239,7 @@ AlignedUniquePtr<T[]> MakeUniqueAlignedArrayWithAlloc(
T* ptr = detail::AllocateAlignedItems<T>(items, alloc, opaque);
if (ptr != nullptr) {
for (size_t i = 0; i < items; i++) {
- new (ptr + i) T(std::forward<Args>(args)...);
+ new (ptr + i) T(args...);
}
}
return AlignedUniquePtr<T[]>(ptr, AlignedDeleter(free, opaque));
@@ -398,6 +430,10 @@ class AlignedNDArray {
size_t offset = 0;
size_t shape_index = 0;
for (const size_t axis_index : indices) {
+ if (HWY_UNLIKELY(axis_index >= shape_[shape_index])) {
+ HWY_ABORT("AlignedNDArray index %zu out of bounds (axis %zu, size %zu)",
+ axis_index, shape_index, shape_[shape_index]);
+ }
offset += memory_sizes_[shape_index + 1] * axis_index;
shape_index++;
}
@@ -416,6 +452,13 @@ class AlignedNDArray {
sizes[axis] = 1;
while (axis > 0) {
--axis;
+ // Check for integer overflow in dimension multiplication.
+ if (HWY_LIKELY(shape[axis] != 0) &&
+ sizes[axis + 1] > SIZE_MAX / shape[axis]) {
+ HWY_ABORT("AlignedNDArray: dimension overflow at axis %zu "
+ "(%zu * %zu exceeds size_t)",
+ axis, sizes[axis + 1], shape[axis]);
+ }
sizes[axis] = sizes[axis + 1] * shape[axis];
}
return sizes;
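
These header changes make the aligned factories fail soft: a null allocation
now yields a null `AlignedUniquePtr` instead of placement-new on nullptr, and
when `HWY_EXCEPTIONS_ENABLED` is defined, a throwing constructor no longer
leaks the backing bytes. A minimal sketch, assuming a type with a user-defined
constructor:

    #include "third_party/highway/hwy/aligned_allocator.h"

    struct Widget {
      explicit Widget(int id) : id_(id) {}  // may throw in real code
      int id_;
    };

    int main() {
      // Cache-aligned storage, released via AlignedDeleter.
      hwy::AlignedUniquePtr<Widget> w = hwy::MakeUniqueAligned<Widget>(42);
      if (w == nullptr) return 1;  // allocation failure is now observable
      return w->id_ == 42 ? 0 : 1;
    }
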
diff --git a/third_party/highway/hwy/auto_tune.h b/third_party/highway/hwy/auto_tune.h
index c94b5eb451..183487d909 100644
--- a/third_party/highway/hwy/auto_tune.h
+++ b/third_party/highway/hwy/auto_tune.h
@@ -25,7 +25,18 @@
#include "third_party/highway/hwy/aligned_allocator.h" // Span
#include "third_party/highway/hwy/base.h" // HWY_MIN
-#include "third_party/highway/hwy/contrib/sort/vqsort.h"
+
+// Configuration to allow auto_tune to use std::sort instead of VQSort
+// (also enabled in header-only mode).
+#if defined(HWY_HEADER_ONLY)
+#define HWY_AUTOTUNE_STDSORT
+#endif
+
+#ifdef HWY_AUTOTUNE_STDSORT
+#include <algorithm> // std::sort
+#else
+#include "third_party/highway/hwy/contrib/sort/vqsort.h" // VQSort
+#endif
// Infrastructure for auto-tuning (choosing optimal parameters at runtime).
@@ -104,6 +115,10 @@ class CostDistribution {
private:
static double Median(double* to_sort, size_t n) {
HWY_DASSERT(n >= 2);
+
+#ifdef HWY_AUTOTUNE_STDSORT
+ std::sort(to_sort, to_sort + n);
+#else
// F64 is supported everywhere except Armv7.
#if !HWY_ARCH_ARM_V7
VQSort(to_sort, n, SortAscending());
@@ -112,6 +127,8 @@ class CostDistribution {
// equivalent.
VQSort(reinterpret_cast<uint64_t*>(to_sort), n, SortAscending());
#endif
+#endif
+
if (n & 1) return to_sort[n / 2];
// Even length: average of two middle elements.
return (to_sort[n / 2] + to_sort[n / 2 - 1]) * 0.5;
@@ -405,10 +422,9 @@ class AutoTune {
const Config* Best() const { return best_; }
// If false, caller must call `SetCandidates` before `NextConfig`.
- bool HasCandidates() const {
- HWY_DASSERT(!Best());
- return !candidates_.empty();
- }
+ // NOTE: also called after Best() is non-null.
+ bool HasCandidates() const { return !candidates_.empty(); }
+
// WARNING: invalidates `Best()`, do not call if that is non-null.
void SetCandidates(std::vector<Config> candidates) {
HWY_DASSERT(!Best() && !HasCandidates());
@@ -429,7 +445,7 @@ class AutoTune {
// Returns the current `Config` to measure.
const Config& NextConfig() const {
- HWY_DASSERT(!Best() && HasCandidates());
+ HWY_DASSERT(HasCandidates());
return candidates_[config_idx_];
}
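
With this change, builds that cannot or prefer not to pull in VQSort can opt
into a std::sort-based median, and header-only builds do so automatically. The
opt-in is just a macro defined before the first include:

    // Use std::sort instead of VQSort for the auto-tuner's median.
    #define HWY_AUTOTUNE_STDSORT
    #include "third_party/highway/hwy/auto_tune.h"
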
diff --git a/third_party/highway/hwy/base.h b/third_party/highway/hwy/base.h
index 54b71c7e12..db8e1d6872 100644
--- a/third_party/highway/hwy/base.h
+++ b/third_party/highway/hwy/base.h
@@ -22,8 +22,8 @@
#include <stddef.h>
#include <stdint.h>
#if defined(HWY_HEADER_ONLY)
-#include <cstdarg>
-#include <cstdio>
+#include <stdarg.h>
+#include <stdio.h>
#endif
#if !defined(HWY_NO_LIBCXX)
@@ -33,9 +33,10 @@
#include "third_party/highway/hwy/detect_compiler_arch.h"
#include "third_party/highway/hwy/highway_export.h"
-// API version (https://semver.org/); keep in sync with CMakeLists.txt.
+// API version (https://semver.org/); keep in sync with CMakeLists.txt and
+// meson.build.
#define HWY_MAJOR 1
-#define HWY_MINOR 2
+#define HWY_MINOR 4
#define HWY_PATCH 0
// True if the Highway version >= major.minor.0. Added in 1.2.0.
@@ -55,12 +56,12 @@
#include <inttypes.h>
#endif
-#if (HWY_ARCH_X86 && !defined(HWY_NO_LIBCXX)) || HWY_COMPILER_MSVC
+#endif // !HWY_IDE
+
+#if !defined(HWY_NO_LIBCXX) || HWY_COMPILER_MSVC
#include <atomic>
#endif
-#endif // !HWY_IDE
-
#ifndef HWY_HAVE_COMPARE_HEADER // allow override
#define HWY_HAVE_COMPARE_HEADER 0
#if defined(__has_include) // note: wrapper macro fails on Clang ~17
@@ -148,6 +149,21 @@
#endif // !HWY_COMPILER_MSVC
+#if (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1200) || \
+ (HWY_COMPILER_ICC && !HWY_COMPILER_ICX)
+// The use of __attribute__((unused)) on private class member variables
+// triggers a compiler warning with GCC 11 and earlier and ICC.
+
+// GCC 11 and earlier and ICC also do not emit -Wunused-private-field warnings
+// for unused private class member variables, so the attribute is unnecessary.
+#define HWY_MEMBER_VAR_MAYBE_UNUSED
+#else
+// Clang and ICX need __attribute__((unused)) on unused private class member
+// variables to suppress -Wunused-private-field warnings, unless the warning
+// is disabled via HWY_DIAGNOSTICS_OFF.
+#define HWY_MEMBER_VAR_MAYBE_UNUSED HWY_MAYBE_UNUSED
+#endif
+
//------------------------------------------------------------------------------
// Builtin/attributes (no more #include after this point due to namespace!)
@@ -202,7 +218,9 @@ namespace hwy {
//------------------------------------------------------------------------------
// Macros
-#define HWY_API static HWY_INLINE HWY_FLATTEN HWY_MAYBE_UNUSED
+// Note: it is safe to remove `static` for users who want to use modules, but
+// that might be a breaking change for some users, hence we do not by default.
+#define HWY_API static HWY_INLINE HWY_FLATTEN
#define HWY_CONCAT_IMPL(a, b) a##b
#define HWY_CONCAT(a, b) HWY_CONCAT_IMPL(a, b)
@@ -248,13 +266,15 @@ namespace hwy {
#define HWY_ASSUME(expr) static_cast<void>(0)
#endif
-// Compile-time fence to prevent undesirable code reordering. On Clang x86, the
-// typical asm volatile("" : : : "memory") has no effect, whereas atomic fence
-// does, without generating code.
-#if HWY_ARCH_X86 && !defined(HWY_NO_LIBCXX)
-#define HWY_FENCE std::atomic_thread_fence(std::memory_order_acq_rel)
+// Compile-time fence to prevent undesirable code reordering. On Clang, the
+// typical `asm volatile("" : : : "memory")` seems to be ignored. Note that
+// `std::atomic_thread_fence` affects other threads, hence might generate a
+// barrier instruction, but this does not.
+#if !defined(HWY_NO_LIBCXX)
+#define HWY_FENCE std::atomic_signal_fence(std::memory_order_seq_cst)
+#elif HWY_COMPILER_GCC
+#define HWY_FENCE asm volatile("" : : : "memory")
#else
-// TODO(janwas): investigate alternatives. On Arm, the above generates barriers.
#define HWY_FENCE
#endif
@@ -347,60 +367,20 @@ HWY_DLLEXPORT HWY_NORETURN void HWY_FORMAT(3, 4)
} while (0)
#define HWY_ASSERT(condition) HWY_ASSERT_M(condition, "")
-#if HWY_HAS_FEATURE(memory_sanitizer) || defined(MEMORY_SANITIZER) || \
- defined(__SANITIZE_MEMORY__)
-#define HWY_IS_MSAN 1
-#else
-#define HWY_IS_MSAN 0
-#endif
-
-#if HWY_HAS_FEATURE(address_sanitizer) || defined(ADDRESS_SANITIZER) || \
- defined(__SANITIZE_ADDRESS__)
-#define HWY_IS_ASAN 1
-#else
-#define HWY_IS_ASAN 0
-#endif
-
-#if HWY_HAS_FEATURE(hwaddress_sanitizer) || defined(HWADDRESS_SANITIZER) || \
- defined(__SANITIZE_HWADDRESS__)
-#define HWY_IS_HWASAN 1
-#else
-#define HWY_IS_HWASAN 0
-#endif
-
-#if HWY_HAS_FEATURE(thread_sanitizer) || defined(THREAD_SANITIZER) || \
- defined(__SANITIZE_THREAD__)
-#define HWY_IS_TSAN 1
-#else
-#define HWY_IS_TSAN 0
-#endif
-
-#if HWY_HAS_FEATURE(undefined_behavior_sanitizer) || \
- defined(UNDEFINED_BEHAVIOR_SANITIZER)
-#define HWY_IS_UBSAN 1
-#else
-#define HWY_IS_UBSAN 0
-#endif
-
-// MSAN may cause lengthy build times or false positives e.g. in AVX3 DemoteTo.
-// You can disable MSAN by adding this attribute to the function that fails.
-#if HWY_IS_MSAN
-#define HWY_ATTR_NO_MSAN __attribute__((no_sanitize_memory))
-#else
-#define HWY_ATTR_NO_MSAN
-#endif
-
-#if HWY_IS_ASAN || HWY_IS_HWASAN || HWY_IS_MSAN || HWY_IS_TSAN || HWY_IS_UBSAN
-#define HWY_IS_SANITIZER 1
-#else
-#define HWY_IS_SANITIZER 0
-#endif
-
// For enabling HWY_DASSERT and shortening tests in slower debug builds
+//
+// Note: `HWY_IS_UBSAN` alone is specifically excluded from engaging debug
+// builds. This is in service of Chromium, which enables
+// `-fsanitize=array-bounds` by default and where we do not want Highway to
+// unconditionally build in debug mode.
+//
+// See also:
+// https://docs.google.com/document/d/1eCtY4AZF-SiFHxhIYWzEytdIx3C24de7ccD6Y5Gn2H8/edit?tab=t.9zkn85hr82ms#heading=h.efcshvfql42c
#if !defined(HWY_IS_DEBUG_BUILD)
// Clang does not define NDEBUG, but it and GCC define __OPTIMIZE__, and recent
// MSVC defines NDEBUG (if not, could instead check _DEBUG).
-#if (!defined(__OPTIMIZE__) && !defined(NDEBUG)) || HWY_IS_SANITIZER || \
+#if (!defined(__OPTIMIZE__) && !defined(NDEBUG)) || \
+ (HWY_IS_ASAN || (HWY_IS_SANITIZER && !HWY_IS_UBSAN)) || \
defined(__clang_analyzer__)
#define HWY_IS_DEBUG_BUILD 1
#else
@@ -453,21 +433,22 @@ HWY_API void CopySameSize(const From* HWY_RESTRICT from, To* HWY_RESTRICT to) {
CopyBytes<sizeof(From)>(from, to);
}
-template <size_t kBytes, typename To>
-HWY_API void ZeroBytes(To* to) {
+// Sets each byte to `byte_value`, which must be between 0 and 255.
+HWY_API void FillBytes(void* to, int byte_value, size_t num_bytes) {
#if HWY_COMPILER_MSVC
- memset(to, 0, kBytes);
+ memset(to, byte_value, num_bytes);
#else
- __builtin_memset(to, 0, kBytes);
+ __builtin_memset(to, byte_value, num_bytes);
#endif
}
HWY_API void ZeroBytes(void* to, size_t num_bytes) {
-#if HWY_COMPILER_MSVC
- memset(to, 0, num_bytes);
-#else
- __builtin_memset(to, 0, num_bytes);
-#endif
+ FillBytes(to, 0, num_bytes);
+}
+
+template <size_t kBytes, typename To>
+HWY_API void ZeroBytes(To* to) {
+ ZeroBytes(to, kBytes);
}
//------------------------------------------------------------------------------
@@ -1168,6 +1149,7 @@ HWY_API HWY_BITCASTSCALAR_CONSTEXPR To BitCastScalar(const From& val) {
#pragma pack(push, 1)
+#ifndef HWY_NEON_HAVE_F16C // allow override
// Compiler supports __fp16 and load/store/conversion NEON intrinsics, which are
// included in Armv8 and VFPv4 (except with MSVC). On Armv7 Clang requires
// __ARM_FP & 2 whereas Armv7 GCC requires -mfp16-format=ieee.
@@ -1178,6 +1160,7 @@ HWY_API HWY_BITCASTSCALAR_CONSTEXPR To BitCastScalar(const From& val) {
#else
#define HWY_NEON_HAVE_F16C 0
#endif
+#endif // HWY_NEON_HAVE_F16C
// RVV with f16 extension supports _Float16 and f16 vector ops. If set, implies
// HWY_HAVE_FLOAT16.
@@ -1197,7 +1180,7 @@ HWY_API HWY_BITCASTSCALAR_CONSTEXPR To BitCastScalar(const From& val) {
#define HWY_SSE2_HAVE_F16_TYPE 0
#endif
-#ifndef HWY_HAVE_SCALAR_F16_TYPE
+#ifndef HWY_HAVE_SCALAR_F16_TYPE // allow override
// Compiler supports _Float16, not necessarily with operators.
#if HWY_NEON_HAVE_F16C || HWY_RVV_HAVE_F16_VEC || HWY_SSE2_HAVE_F16_TYPE || \
__SPIRV_DEVICE__
@@ -1695,24 +1678,14 @@ HWY_F16_CONSTEXPR inline std::partial_ordering operator<=>(
//------------------------------------------------------------------------------
// BF16 lane type
-// Compiler supports ACLE __bf16, not necessarily with operators.
-
-// Disable the __bf16 type on AArch64 with GCC 13 or earlier as there is a bug
-// in GCC 13 and earlier that sometimes causes BF16 constant values to be
-// incorrectly loaded on AArch64, and this GCC bug on AArch64 is
-// described at https://gcc.gnu.org/bugzilla/show_bug.cgi?id=111867.
-
-#if HWY_ARCH_ARM_A64 && \
- (HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400)
-#define HWY_ARM_HAVE_SCALAR_BF16_TYPE 1
-#else
-#define HWY_ARM_HAVE_SCALAR_BF16_TYPE 0
-#endif
-
// x86 compiler supports __bf16, not necessarily with operators.
+// Disable in debug builds due to clang miscompiles as of 2025-07-22: casting
+// bf16 <-> f32 in convert_test results in 0x2525 for 1.0 instead of 0x3f80.
+// Reported at https://github.com/llvm/llvm-project/issues/151692.
#ifndef HWY_SSE2_HAVE_SCALAR_BF16_TYPE
-#if HWY_ARCH_X86 && defined(__SSE2__) && \
- ((HWY_COMPILER_CLANG >= 1700 && !HWY_COMPILER_CLANGCL) || \
+#if HWY_ARCH_X86 && defined(__SSE2__) && \
+ ((HWY_COMPILER_CLANG >= 1700 && !HWY_COMPILER_CLANGCL && \
+ (!HWY_IS_DEBUG_BUILD || HWY_COMPILER3_CLANG >= 220101)) || \
HWY_COMPILER_GCC_ACTUAL >= 1300)
#define HWY_SSE2_HAVE_SCALAR_BF16_TYPE 1
#else
@@ -1730,7 +1703,11 @@ HWY_F16_CONSTEXPR inline std::partial_ordering operator<=>(
#ifndef HWY_HAVE_SCALAR_BF16_OPERATORS
// Recent enough compiler also has operators. aarch64 clang 18 hits internal
// compiler errors on bf16 ToString, hence only enable on GCC for now.
-#if HWY_HAVE_SCALAR_BF16_TYPE && (HWY_COMPILER_GCC_ACTUAL >= 1300)
+// GCC >= 13 will insert a function call to the __extendbfsf2 helper function
+// for scalar conversions from __bf16 to float. This is prohibitively expensive,
+// so refrain from using scalar BF16 operators on these compiler versions.
+// See https://gcc.gnu.org/bugzilla/show_bug.cgi?id=121853
+#if HWY_HAVE_SCALAR_BF16_TYPE && (HWY_COMPILER_GCC_ACTUAL >= 1700)
#define HWY_HAVE_SCALAR_BF16_OPERATORS 1
#else
#define HWY_HAVE_SCALAR_BF16_OPERATORS 0
@@ -1949,21 +1926,21 @@ static HWY_INLINE HWY_MAYBE_UNUSED constexpr uint32_t F32BitsToBF16RoundIncr(
: 0u);
}
+// If f32_bits is the bit representation of a NaN F32 value, make sure that
+// bit 6 of the BF16 result is set to convert SNaN F32 values to QNaN BF16
+// values and to prevent NaN F32 values from being converted to an infinite
+// BF16 value.
+static HWY_INLINE constexpr uint32_t BF16BitsIfSNAN(uint32_t f32_bits) {
+ return ((f32_bits & 0x7FFFFFFFu) > 0x7F800000u) ? (uint32_t{1} << 6) : 0;
+}
+
// Converts f32_bits (which is the bits of a F32 value) to BF16 bits,
// rounded to the nearest F16 value
static HWY_INLINE HWY_MAYBE_UNUSED constexpr uint16_t F32BitsToBF16Bits(
const uint32_t f32_bits) {
- // Round f32_bits to the nearest BF16 by first adding
- // F32BitsToBF16RoundIncr(f32_bits) to f32_bits and then right shifting
- // f32_bits + F32BitsToBF16RoundIncr(f32_bits) by 16
-
- // If f32_bits is the bit representation of a NaN F32 value, make sure that
- // bit 6 of the BF16 result is set to convert SNaN F32 values to QNaN BF16
- // values and to prevent NaN F32 values from being converted to an infinite
- // BF16 value
return static_cast<uint16_t>(
- ((f32_bits + F32BitsToBF16RoundIncr(f32_bits)) >> 16) |
- (static_cast<uint32_t>((f32_bits & 0x7FFFFFFFu) > 0x7F800000u) << 6));
+ BF16BitsIfSNAN(f32_bits) |
+ ((f32_bits + F32BitsToBF16RoundIncr(f32_bits)) >> 16));
}
} // namespace detail
@@ -2303,6 +2280,11 @@ constexpr bool IsSigned<hwy::K32V32>() {
return false;
}
+template <typename T>
+HWY_API constexpr bool IsUnsigned() {
+ return IsInteger<T>() && !IsSigned<T>();
+}
+
template <typename T, bool = IsInteger<T>() && !IsIntegerLaneType<T>()>
struct MakeLaneTypeIfIntegerT {
using type = T;
@@ -2400,6 +2382,21 @@ HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR double Epsilon<double>() {
return 2.2204460492503131e-16;
}
+// Smallest positive normal value, equal to 2^(1 - bias).
+template <typename T>
+HWY_API HWY_BITCASTSCALAR_CONSTEXPR T SmallestNormal() {
+ static_assert(sizeof(T) == 0, "Only instantiate the specializations");
+ return T{};
+}
+template <>
+HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR float SmallestNormal<float>() {
+ return 1.175494351e-38f; // 2^-126
+}
+template <>
+HWY_INLINE HWY_BITCASTSCALAR_CONSTEXPR double SmallestNormal<double>() {
+ return 2.2250738585072014e-308; // 2^-1022
+}
+
// Returns width in bits of the mantissa field in IEEE binary16/32/64.
template <typename T>
constexpr int MantissaBits() {
@@ -2636,53 +2633,83 @@ HWY_API float F32FromBF16Mem(const void* ptr) {
#define HWY_BF16_TO_F16_CONSTEXPR HWY_F16_CONSTEXPR
#endif
-// For casting from TFrom to TTo
-template <typename TTo, typename TFrom, HWY_IF_NOT_SPECIAL_FLOAT(TTo),
- HWY_IF_NOT_SPECIAL_FLOAT(TFrom), HWY_IF_NOT_SAME(TTo, TFrom)>
-HWY_API constexpr TTo ConvertScalarTo(const TFrom in) {
- return static_cast<TTo>(in);
-}
-template <typename TTo, typename TFrom, HWY_IF_F16(TTo),
- HWY_IF_NOT_SPECIAL_FLOAT(TFrom), HWY_IF_NOT_SAME(TFrom, double)>
-HWY_API constexpr TTo ConvertScalarTo(const TFrom in) {
- return F16FromF32(static_cast<float>(in));
+namespace detail {
+
+template <class TTo, class TFrom>
+static HWY_INLINE HWY_MAYBE_UNUSED constexpr TTo ConvertScalarToResult(
+ hwy::SizeTag<0> /*conv_to_tag*/, TFrom in) {
+ return static_cast<TTo>(static_cast<TFrom>(in));
}
-template <typename TTo, HWY_IF_F16(TTo)>
-HWY_API HWY_BF16_TO_F16_CONSTEXPR TTo
-ConvertScalarTo(const hwy::bfloat16_t in) {
- return F16FromF32(F32FromBF16(in));
+
+template <class TTo>
+static HWY_INLINE HWY_MAYBE_UNUSED HWY_F16_CONSTEXPR TTo
+ConvertScalarToResult(hwy::FloatTag /*conv_to_tag*/, float in) {
+ return F16FromF32(in);
}
-template <typename TTo, HWY_IF_F16(TTo)>
-HWY_API HWY_F16_CONSTEXPR TTo ConvertScalarTo(const double in) {
+
+template <class TTo>
+static HWY_INLINE HWY_MAYBE_UNUSED HWY_F16_CONSTEXPR TTo
+ConvertScalarToResult(hwy::FloatTag /*conv_to_tag*/, double in) {
return F16FromF64(in);
}
-template <typename TTo, typename TFrom, HWY_IF_BF16(TTo),
- HWY_IF_NOT_SPECIAL_FLOAT(TFrom), HWY_IF_NOT_SAME(TFrom, double)>
-HWY_API HWY_BF16_CONSTEXPR TTo ConvertScalarTo(const TFrom in) {
- return BF16FromF32(static_cast<float>(in));
-}
-template <typename TTo, HWY_IF_BF16(TTo)>
-HWY_API HWY_BF16_TO_F16_CONSTEXPR TTo ConvertScalarTo(const hwy::float16_t in) {
- return BF16FromF32(F32FromF16(in));
+
+template <class TTo>
+static HWY_INLINE HWY_MAYBE_UNUSED HWY_BF16_CONSTEXPR TTo
+ConvertScalarToResult(hwy::SpecialTag /*conv_to_tag*/, float in) {
+ return BF16FromF32(in);
}
-template <typename TTo, HWY_IF_BF16(TTo)>
-HWY_API HWY_BF16_CONSTEXPR TTo ConvertScalarTo(const double in) {
+
+template <class TTo>
+static HWY_INLINE HWY_MAYBE_UNUSED HWY_BF16_CONSTEXPR TTo
+ConvertScalarToResult(hwy::SpecialTag /*conv_to_tag*/, double in) {
return BF16FromF64(in);
}
-template <typename TTo, typename TFrom, HWY_IF_F16(TFrom),
- HWY_IF_NOT_SPECIAL_FLOAT(TTo)>
-HWY_API HWY_F16_CONSTEXPR TTo ConvertScalarTo(const TFrom in) {
- return static_cast<TTo>(F32FromF16(in));
+
+template <class TFrom, HWY_IF_BF16(TFrom)>
+static HWY_INLINE HWY_MAYBE_UNUSED HWY_BF16_CONSTEXPR float
+ConvertScalarSpecialFloatToF32(hwy::SpecialTag /*conv_from_tag*/, TFrom in) {
+ return F32FromBF16(in);
}
-template <typename TTo, typename TFrom, HWY_IF_BF16(TFrom),
- HWY_IF_NOT_SPECIAL_FLOAT(TTo)>
-HWY_API HWY_BF16_CONSTEXPR TTo ConvertScalarTo(TFrom in) {
- return static_cast<TTo>(F32FromBF16(in));
+
+template <class TFrom, HWY_IF_F16(TFrom)>
+static HWY_INLINE HWY_MAYBE_UNUSED HWY_F16_CONSTEXPR float
+ConvertScalarSpecialFloatToF32(hwy::SpecialTag /*conv_from_tag*/, TFrom in) {
+ return F32FromF16(in);
}
-// Same: return unchanged
-template <typename TTo>
-HWY_API constexpr TTo ConvertScalarTo(TTo in) {
- return in;
+
+template <class TFrom>
+static HWY_INLINE HWY_MAYBE_UNUSED constexpr auto
+ConvertScalarSpecialFloatToF32(hwy::FloatTag /*conv_from_tag*/, TFrom in)
+ -> hwy::If<hwy::IsSame<hwy::RemoveCvRef<TFrom>, double>(), double, float> {
+ return static_cast<
+ hwy::If<hwy::IsSame<hwy::RemoveCvRef<TFrom>, double>(), double, float>>(
+ in);
+}
+
+template <class TFrom>
+static HWY_INLINE HWY_MAYBE_UNUSED constexpr TFrom
+ConvertScalarSpecialFloatToF32(hwy::SizeTag<0> /*conv_from_tag*/, TFrom in) {
+ return static_cast<TFrom>(in);
+}
+
+} // namespace detail
+
+template <typename TTo, typename TFrom>
+HWY_API constexpr TTo ConvertScalarTo(TFrom in) {
+ return detail::ConvertScalarToResult<TTo>(
+ hwy::SizeTag<
+ (!hwy::IsSame<hwy::RemoveCvRef<TFrom>, hwy::RemoveCvRef<TTo>>() &&
+ hwy::IsSpecialFloat<TTo>())
+ ? (hwy::IsSame<RemoveCvRef<TTo>, hwy::bfloat16_t>() ? 0x300
+ : 0x200)
+ : 0>(),
+ detail::ConvertScalarSpecialFloatToF32(
+ hwy::SizeTag<
+ (!hwy::IsSame<hwy::RemoveCvRef<TFrom>, hwy::RemoveCvRef<TTo>>() &&
+ (hwy::IsSpecialFloat<TFrom>() || hwy::IsSpecialFloat<TTo>()))
+ ? (hwy::IsSpecialFloat<TFrom>() ? 0x300 : 0x200)
+ : 0>(),
+ static_cast<TFrom&&>(in)));
}
//------------------------------------------------------------------------------
@@ -2691,7 +2718,7 @@ HWY_API constexpr TTo ConvertScalarTo(TTo in) {
template <typename T1, typename T2>
constexpr inline T1 DivCeil(T1 a, T2 b) {
#if HWY_CXX_LANG >= 201703L
- HWY_DASSERT(b != 0);
+ HWY_DASSERT(b != T2{0});
#endif
return (a + b - 1) / b;
}
@@ -2914,9 +2941,10 @@ HWY_INLINE constexpr T AddWithWraparound(T t, T2 n) {
// 64 x 64 = 128 bit multiplication
HWY_API uint64_t Mul128(uint64_t a, uint64_t b, uint64_t* HWY_RESTRICT upper) {
#if defined(__SIZEOF_INT128__)
- __uint128_t product = (__uint128_t)a * (__uint128_t)b;
- *upper = (uint64_t)(product >> 64);
- return (uint64_t)(product & 0xFFFFFFFFFFFFFFFFULL);
+ __uint128_t product =
+ static_cast<__uint128_t>(a) * static_cast<__uint128_t>(b);
+ *upper = static_cast<uint64_t>(product >> 64);
+ return static_cast<uint64_t>(product & 0xFFFFFFFFFFFFFFFFULL);
#elif HWY_COMPILER_MSVC && HWY_ARCH_X86_64
return _umul128(a, b, upper);
#else
@@ -2933,9 +2961,9 @@ HWY_API uint64_t Mul128(uint64_t a, uint64_t b, uint64_t* HWY_RESTRICT upper) {
HWY_API int64_t Mul128(int64_t a, int64_t b, int64_t* HWY_RESTRICT upper) {
#if defined(__SIZEOF_INT128__)
- __int128_t product = (__int128_t)a * (__int128_t)b;
- *upper = (int64_t)(product >> 64);
- return (int64_t)(product & 0xFFFFFFFFFFFFFFFFULL);
+ __int128_t product = static_cast<__int128_t>(a) * static_cast<__int128_t>(b);
+ *upper = static_cast<int64_t>(product >> 64);
+ return static_cast<int64_t>(product & 0xFFFFFFFFFFFFFFFFULL);
#elif HWY_COMPILER_MSVC && HWY_ARCH_X86_64
return _mul128(a, b, upper);
#else
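
The `ConvertScalarTo` overload set collapses into a single entry point that
tag-dispatches: special floats (f16/bf16) are first widened to f32 (or kept as
double) by `ConvertScalarSpecialFloatToF32`, then narrowed or cast by
`ConvertScalarToResult`. A minimal sketch of the user-facing behavior, which
the refactor does not change:

    #include "third_party/highway/hwy/base.h"

    int main() {
      // double -> bf16 dispatches to BF16FromF64.
      const hwy::bfloat16_t b = hwy::ConvertScalarTo<hwy::bfloat16_t>(1.0);
      // bf16 -> double goes through F32FromBF16, then a widening cast.
      const double d = hwy::ConvertScalarTo<double>(b);
      return (d == 1.0) ? 0 : 1;
    }
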
diff --git a/third_party/highway/hwy/bit_set.h b/third_party/highway/hwy/bit_set.h
index b747416527..65b068b721 100644
--- a/third_party/highway/hwy/bit_set.h
+++ b/third_party/highway/hwy/bit_set.h
@@ -16,17 +16,21 @@
#ifndef HIGHWAY_HWY_BIT_SET_H_
#define HIGHWAY_HWY_BIT_SET_H_
-// BitSet with fast Foreach for up to 64 and 4096 members.
+// Various BitSet classes: 64 bits, up to 4096 bits, or any number of bits.
#include <stddef.h>
+#include <atomic>
+
#include "third_party/highway/hwy/base.h"
namespace hwy {
-// 64-bit specialization of std::bitset, which lacks Foreach.
+// 64-bit specialization of `std::bitset`, which lacks `Foreach`.
class BitSet64 {
public:
+ constexpr size_t MaxSize() const { return 64; }
+
// No harm if `i` is already set.
void Set(size_t i) {
HWY_DASSERT(i < 64);
@@ -48,15 +52,24 @@ class BitSet64 {
return (bits_ & (1ULL << i)) != 0;
}
- // Returns true if any Get(i) would return true for i in [0, 64).
+ // Returns true if Get(i) would return true for any i in [0, 64).
bool Any() const { return bits_ != 0; }
- // Returns lowest i such that Get(i). Caller must ensure Any() beforehand!
+ // Returns true if Get(i) would return true for all i in [0, 64).
+ bool All() const { return bits_ == ~uint64_t{0}; }
+
+ // Returns lowest i such that `Get(i)`. Caller must first ensure `Any()`!
size_t First() const {
HWY_DASSERT(Any());
return Num0BitsBelowLS1Bit_Nonzero64(bits_);
}
+ // Returns lowest i such that `!Get(i)`. Caller must first ensure `!All()`!
+ size_t First0() const {
+ HWY_DASSERT(!All());
+ return Num0BitsBelowLS1Bit_Nonzero64(~bits_);
+ }
+
// Returns uint64_t(Get(i)) << i for i in [0, 64).
uint64_t Get64() const { return bits_; }
@@ -78,10 +91,226 @@ class BitSet64 {
uint64_t bits_ = 0;
};
-// Two-level bitset for up to kMaxSize <= 4096 values.
+// Any number of bits, flat array.
+template <size_t kMaxSize>
+class BitSet {
+ static_assert(kMaxSize != 0, "BitSet requires non-zero size");
+
+ public:
+ constexpr size_t MaxSize() const { return kMaxSize; }
+
+ // No harm if `i` is already set.
+ void Set(size_t i) {
+ HWY_DASSERT(i < kMaxSize);
+ const size_t idx = i / 64;
+ const size_t mod = i % 64;
+ bits_[idx].Set(mod);
+ }
+
+ void Clear(size_t i) {
+ HWY_DASSERT(i < kMaxSize);
+ const size_t idx = i / 64;
+ const size_t mod = i % 64;
+ bits_[idx].Clear(mod);
+ HWY_DASSERT(!Get(i));
+ }
+
+ bool Get(size_t i) const {
+ HWY_DASSERT(i < kMaxSize);
+ const size_t idx = i / 64;
+ const size_t mod = i % 64;
+ return bits_[idx].Get(mod);
+ }
+
+ // Returns true if Get(i) would return true for any i in [0, kMaxSize).
+ bool Any() const {
+ for (const BitSet64& bits : bits_) {
+ if (bits.Any()) return true;
+ }
+ return false;
+ }
+
+ // Returns true if Get(i) would return true for all i in [0, kMaxSize).
+ bool All() const {
+ for (size_t idx = 0; idx < kNum64 - 1; ++idx) {
+ if (!bits_[idx].All()) return false;
+ }
+
+ constexpr size_t kRemainder = kMaxSize % 64;
+ if (kRemainder == 0) {
+ return bits_[kNum64 - 1].All();
+ }
+ return bits_[kNum64 - 1].Count() == kRemainder;
+ }
+
+ // Returns lowest i such that `Get(i)`. Caller must first ensure `Any()`!
+ size_t First() const {
+ HWY_DASSERT(Any());
+ for (size_t idx = 0;; ++idx) {
+ HWY_DASSERT(idx < kNum64);
+ if (bits_[idx].Any()) return idx * 64 + bits_[idx].First();
+ }
+ }
+
+ // Returns lowest i such that `!Get(i)`. Caller must first ensure `!All()`!
+ size_t First0() const {
+ HWY_DASSERT(!All());
+ for (size_t idx = 0;; ++idx) {
+ HWY_DASSERT(idx < kNum64);
+ if (!bits_[idx].All()) {
+ const size_t first0 = idx * 64 + bits_[idx].First0();
+ HWY_DASSERT(first0 < kMaxSize);
+ return first0;
+ }
+ }
+ }
+
+ // Calls `func(i)` for each `i` in the set. It is safe for `func` to modify
+ // the set, but the current Foreach call is only affected if changing one of
+ // the not yet visited BitSet64.
+ template <class Func>
+ void Foreach(const Func& func) const {
+ for (size_t idx = 0; idx < kNum64; ++idx) {
+ bits_[idx].Foreach([idx, &func](size_t mod) { func(idx * 64 + mod); });
+ }
+ }
+
+ size_t Count() const {
+ size_t total = 0;
+ for (const BitSet64& bits : bits_) {
+ total += bits.Count();
+ }
+ return total;
+ }
+
+ private:
+ static constexpr size_t kNum64 = DivCeil(kMaxSize, size_t{64});
+ BitSet64 bits_[kNum64];
+};
+
+// Any number of bits, flat array, atomic updates to the u64.
+template <size_t kMaxSize>
+class AtomicBitSet {
+ static_assert(kMaxSize != 0, "AtomicBitSet requires non-zero size");
+
+ // Bits may signal something to other threads, hence relaxed is insufficient.
+ // Acq/Rel ensures a happens-before relationship.
+ static constexpr auto kAcq = std::memory_order_acquire;
+ static constexpr auto kRel = std::memory_order_release;
+
+ public:
+ constexpr size_t MaxSize() const { return kMaxSize; }
+
+ // No harm if `i` is already set.
+ void Set(size_t i) {
+ HWY_DASSERT(i < kMaxSize);
+ const size_t idx = i / 64;
+ const size_t mod = i % 64;
+ bits_[idx].fetch_or(1ULL << mod, kRel);
+ }
+
+ void Clear(size_t i) {
+ HWY_DASSERT(i < kMaxSize);
+ const size_t idx = i / 64;
+ const size_t mod = i % 64;
+ bits_[idx].fetch_and(~(1ULL << mod), kRel);
+ HWY_DASSERT(!Get(i));
+ }
+
+ bool Get(size_t i) const {
+ HWY_DASSERT(i < kMaxSize);
+ const size_t idx = i / 64;
+ const size_t mod = i % 64;
+ return (bits_[idx].load(kAcq) & (1ULL << mod)) != 0;
+ }
+
+ // Returns true if Get(i) would return true for any i in [0, kMaxSize).
+ bool Any() const {
+ for (const std::atomic<uint64_t>& bits : bits_) {
+ if (bits.load(kAcq)) return true;
+ }
+ return false;
+ }
+
+ // Returns true if Get(i) would return true for all i in [0, kMaxSize).
+ bool All() const {
+ for (size_t idx = 0; idx < kNum64 - 1; ++idx) {
+ if (bits_[idx].load(kAcq) != ~uint64_t{0}) return false;
+ }
+
+ constexpr size_t kRemainder = kMaxSize % 64;
+ const uint64_t last_bits = bits_[kNum64 - 1].load(kAcq);
+ if (kRemainder == 0) {
+ return last_bits == ~uint64_t{0};
+ }
+ return PopCount(last_bits) == kRemainder;
+ }
+
+ // Returns lowest i such that `Get(i)`. Caller must first ensure `Any()`!
+ size_t First() const {
+ HWY_DASSERT(Any());
+ for (size_t idx = 0;; ++idx) {
+ HWY_DASSERT(idx < kNum64);
+ const uint64_t bits = bits_[idx].load(kAcq);
+ if (bits != 0) {
+ return idx * 64 + Num0BitsBelowLS1Bit_Nonzero64(bits);
+ }
+ }
+ }
+
+ // Returns lowest i such that `!Get(i)`. Caller must first ensure `!All()`!
+ size_t First0() const {
+ HWY_DASSERT(!All());
+ for (size_t idx = 0;; ++idx) {
+ HWY_DASSERT(idx < kNum64);
+ const uint64_t inv_bits = ~bits_[idx].load(kAcq);
+ if (inv_bits != 0) {
+ const size_t first0 =
+ idx * 64 + Num0BitsBelowLS1Bit_Nonzero64(inv_bits);
+ HWY_DASSERT(first0 < kMaxSize);
+ return first0;
+ }
+ }
+ }
+
+ // Calls `func(i)` for each `i` in the set. It is safe for `func` to modify
+ // the set, but the current Foreach call is only affected if changing one of
+ // the not yet visited uint64_t.
+ template <class Func>
+ void Foreach(const Func& func) const {
+ for (size_t idx = 0; idx < kNum64; ++idx) {
+ uint64_t remaining_bits = bits_[idx].load(kAcq);
+ while (remaining_bits != 0) {
+ const size_t i = Num0BitsBelowLS1Bit_Nonzero64(remaining_bits);
+ remaining_bits &= remaining_bits - 1; // clear LSB
+ func(idx * 64 + i);
+ }
+ }
+ }
+
+ size_t Count() const {
+ size_t total = 0;
+ for (const std::atomic<uint64_t>& bits : bits_) {
+ total += PopCount(bits.load(kAcq));
+ }
+ return total;
+ }
+
+ private:
+ static constexpr size_t kNum64 = DivCeil(kMaxSize, size_t{64});
+ std::atomic<uint64_t> bits_[kNum64] = {};
+};
+
+// Two-level bitset for up to `kMaxSize` <= 4096 values. The iterators
+// (`Any/First/Foreach/Count`) are more efficient than `BitSet` for sparse sets.
+// This comes at the cost of slightly slower mutators (`Set/Clear`).
template <size_t kMaxSize = 4096>
class BitSet4096 {
+ static_assert(kMaxSize != 0, "BitSet4096 requires non-zero size");
+
public:
+ constexpr size_t MaxSize() const { return kMaxSize; }
+
// No harm if `i` is already set.
void Set(size_t i) {
HWY_DASSERT(i < kMaxSize);
@@ -117,16 +346,38 @@ class BitSet4096 {
return bits_[idx].Get(mod);
}
- // Returns true if any Get(i) would return true for i in [0, 64).
+ // Returns true if `Get(i)` would return true for any i in [0, kMaxSize).
bool Any() const { return nonzero_.Any(); }
- // Returns lowest i such that Get(i). Caller must ensure Any() beforehand!
+ // Returns true if `Get(i)` would return true for all i in [0, kMaxSize).
+ bool All() const {
+ // Do not check `nonzero_.All()` - that only works if `kMaxSize` is 4096.
+ if (nonzero_.Count() != kNum64) return false;
+ return Count() == kMaxSize;
+ }
+
+ // Returns lowest i such that `Get(i)`. Caller must first ensure `Any()`!
size_t First() const {
HWY_DASSERT(Any());
const size_t idx = nonzero_.First();
return idx * 64 + bits_[idx].First();
}
+ // Returns lowest i such that `!Get(i)`. Caller must first ensure `!All()`!
+ size_t First0() const {
+ HWY_DASSERT(!All());
+ // It is likely not worthwhile to have a separate `BitSet64` for `not_all_`,
+ // hence iterate over all u64.
+ for (size_t idx = 0;; ++idx) {
+ HWY_DASSERT(idx < kNum64);
+ if (!bits_[idx].All()) {
+ const size_t first0 = idx * 64 + bits_[idx].First0();
+ HWY_DASSERT(first0 < kMaxSize);
+ return first0;
+ }
+ }
+ }
+
// Returns uint64_t(Get(i)) << i for i in [0, 64).
uint64_t Get64() const { return bits_[0].Get64(); }
@@ -149,8 +400,9 @@ class BitSet4096 {
private:
static_assert(kMaxSize <= 64 * 64, "One BitSet64 insufficient");
+ static constexpr size_t kNum64 = DivCeil(kMaxSize, size_t{64});
BitSet64 nonzero_;
- BitSet64 bits_[kMaxSize / 64];
+ BitSet64 bits_[kNum64];
};
} // namespace hwy
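
Besides `All()` and `First0()`, the header now provides a flat `BitSet<N>` for
arbitrary sizes and an `AtomicBitSet<N>` whose per-u64 updates use
acquire/release ordering. A minimal sketch of the flat variant, using only the
operations defined above:

    #include <cstdio>

    #include "third_party/highway/hwy/bit_set.h"

    int main() {
      hwy::BitSet<100> set;  // two BitSet64 under the hood
      set.Set(3);
      set.Set(64);  // lands in the second u64
      // Prints count=2 first=3 first0=0.
      std::printf("count=%zu first=%zu first0=%zu\n", set.Count(),
                  set.First(), set.First0());
      set.Foreach([](size_t i) { std::printf("member %zu\n", i); });
      return set.Get(64) ? 0 : 1;
    }
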
diff --git a/third_party/highway/hwy/cache_control.h b/third_party/highway/hwy/cache_control.h
index b3bf5a8323..e93c30bbc3 100644
--- a/third_party/highway/hwy/cache_control.h
+++ b/third_party/highway/hwy/cache_control.h
@@ -16,6 +16,7 @@
#ifndef HIGHWAY_HWY_CACHE_CONTROL_H_
#define HIGHWAY_HWY_CACHE_CONTROL_H_
+#include "third_party/highway/hwy/aligned_allocator.h" // HWY_ALIGNMENT
#include "third_party/highway/hwy/base.h"
// Requires SSE2; fails to compile on 32-bit Clang 7 (see
@@ -66,6 +67,21 @@ HWY_INLINE HWY_ATTR_CACHE void LoadFence() {
// TODO(janwas): remove when this function is removed. (See above.)
#pragma pop_macro("LoadFence")
+// Overwrites "to" while attempting to bypass the cache (read-for-ownership).
+// Both pointers must be aligned.
+static HWY_INLINE void StreamCacheLine(const uint64_t* HWY_RESTRICT from,
+ uint64_t* HWY_RESTRICT to) {
+ HWY_DASSERT(IsAligned(from));
+ HWY_DASSERT(IsAligned(to));
+#if HWY_COMPILER_CLANG && !defined(HWY_DISABLE_CACHE_CONTROL)
+ for (size_t i = 0; i < HWY_ALIGNMENT / sizeof(uint64_t); ++i) {
+ __builtin_nontemporal_store(from[i], to + i);
+ }
+#else
+ hwy::CopyBytes(from, to, HWY_ALIGNMENT);
+#endif
+}
+
// Ensures values written by previous `Stream` calls are visible on the current
// core. This is NOT sufficient for synchronizing across cores; when `Stream`
// outputs are to be consumed by other core(s), the producer must publish
@@ -82,9 +98,10 @@ template <typename T>
HWY_INLINE HWY_ATTR_CACHE void Prefetch(const T* p) {
(void)p;
#ifndef HWY_DISABLE_CACHE_CONTROL
-#if HWY_ARCH_X86
+// Use _mm_prefetch on x86/x64, except under clang-cl with -mno-mmx.
+#if HWY_ARCH_X86 && !(HWY_COMPILER_CLANGCL && !defined(__MMX__))
_mm_prefetch(reinterpret_cast<const char*>(p), _MM_HINT_T0);
-#elif HWY_COMPILER_GCC // includes clang
+#elif HWY_COMPILER_GCC || HWY_COMPILER_CLANGCL // includes clang
// Hint=0 (NTA) behavior differs, but skipping outer caches is probably not
// desirable, so use the default 3 (keep in caches).
__builtin_prefetch(p, /*write=*/0, /*hint=*/3);
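
`StreamCacheLine` copies one `HWY_ALIGNMENT`-sized block using Clang's
non-temporal stores where available, falling back to a plain copy. A minimal
sketch; the `FlushStream` call to publish the stores before other cores read
them is assumed from the surrounding header comments:

    #include <stddef.h>
    #include <stdint.h>

    #include "third_party/highway/hwy/cache_control.h"

    constexpr size_t kU64PerBlock = HWY_ALIGNMENT / sizeof(uint64_t);

    alignas(HWY_ALIGNMENT) uint64_t src[kU64PerBlock] = {1, 2, 3};
    alignas(HWY_ALIGNMENT) uint64_t dst[kU64PerBlock];

    void Produce() {
      hwy::StreamCacheLine(src, dst);  // bypasses the cache on Clang
      hwy::FlushStream();  // make the stores visible to consumers
    }
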
diff --git a/third_party/highway/hwy/contrib/algo/copy_test.cc b/third_party/highway/hwy/contrib/algo/copy_test.cc
new file mode 100644
index 0000000000..c21f4236c1
--- /dev/null
+++ b/third_party/highway/hwy/contrib/algo/copy_test.cc
@@ -0,0 +1,185 @@
+// Copyright 2022 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stddef.h>
+
+#include "third_party/highway/hwy/aligned_allocator.h"
+
+// clang-format off
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "hwy/contrib/algo/copy_test.cc"
+#include "third_party/highway/hwy/foreach_target.h" // IWYU pragma: keep
+#include "third_party/highway/hwy/highway.h"
+#include "third_party/highway/hwy/contrib/algo/copy-inl.h"
+#include "third_party/highway/hwy/tests/test_util-inl.h"
+// clang-format on
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+namespace {
+
+// Returns random integer in [0, 128), which fits in any lane type.
+template <typename T>
+T Random7Bit(RandomState& rng) {
+ return ConvertScalarTo<T>(Random32(&rng) & 127);
+}
+
+// Invokes Test (e.g. TestCopyIf) with all arg combinations. T comes from
+// ForFloatTypes.
+template <class Test>
+struct ForeachCountAndMisalign {
+ template <typename T, class D>
+ HWY_NOINLINE void operator()(T /*unused*/, D d) const {
+ RandomState rng;
+ const size_t N = Lanes(d);
+ const size_t misalignments[3] = {0, N / 4, 3 * N / 5};
+
+ for (size_t count = 0; count < 2 * N; ++count) {
+ for (size_t ma : misalignments) {
+ for (size_t mb : misalignments) {
+ Test()(d, count, ma, mb, rng);
+ }
+ }
+ }
+ }
+};
+
+struct TestFill {
+ template <class D>
+ void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
+ RandomState& rng) {
+ using T = TFromD<D>;
+ // HWY_MAX prevents error when misalign == count == 0.
+ AlignedFreeUniquePtr<T[]> pa =
+ AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
+ AlignedFreeUniquePtr<T[]> pb = AllocateAligned<T>(misalign_b + count + 1);
+ HWY_ASSERT(pa && pb);
+ T* expected = pa.get() + misalign_a;
+ const T value = Random7Bit<T>(rng);
+ for (size_t i = 0; i < count; ++i) {
+ expected[i] = value;
+ }
+ T* actual = pb.get() + misalign_b;
+
+ actual[count] = ConvertScalarTo<T>(0); // sentinel
+ Fill(d, value, count, actual);
+ HWY_ASSERT_EQ(ConvertScalarTo<T>(0), actual[count]); // no write past end
+
+ const auto info = hwy::detail::MakeTypeInfo<T>();
+ const char* target_name = hwy::TargetName(HWY_TARGET);
+ hwy::detail::AssertArrayEqual(info, expected, actual, count, target_name,
+ __FILE__, __LINE__);
+ }
+};
+
+void TestAllFill() {
+ ForAllTypes(ForPartialVectors<ForeachCountAndMisalign<TestFill>>());
+}
+
+struct TestCopy {
+ template <class D>
+ void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
+ RandomState& rng) {
+ using T = TFromD<D>;
+ // Prevents error if size to allocate is zero.
+ AlignedFreeUniquePtr<T[]> pa =
+ AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
+ AlignedFreeUniquePtr<T[]> pb =
+ AllocateAligned<T>(HWY_MAX(1, misalign_b + count));
+ HWY_ASSERT(pa && pb);
+ T* a = pa.get() + misalign_a;
+ for (size_t i = 0; i < count; ++i) {
+ a[i] = Random7Bit<T>(rng);
+ }
+ T* b = pb.get() + misalign_b;
+
+ Copy(d, a, count, b);
+
+ const auto info = hwy::detail::MakeTypeInfo<T>();
+ const char* target_name = hwy::TargetName(HWY_TARGET);
+ hwy::detail::AssertArrayEqual(info, a, b, count, target_name, __FILE__,
+ __LINE__);
+ }
+};
+
+void TestAllCopy() {
+ ForAllTypes(ForPartialVectors<ForeachCountAndMisalign<TestCopy>>());
+}
+
+struct TestCopyIf {
+ template <class D>
+ void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
+ RandomState& rng) {
+ using T = TFromD<D>;
+ const size_t padding = Lanes(ScalableTag<T>());
+
+ // Prevents error if size to allocate is zero.
+ AlignedFreeUniquePtr<T[]> pa =
+ AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
+ AlignedFreeUniquePtr<T[]> pb =
+ AllocateAligned<T>(HWY_MAX(1, misalign_b + count + padding));
+ AlignedFreeUniquePtr<T[]> expected = AllocateAligned<T>(HWY_MAX(1, count));
+ HWY_ASSERT(pa && pb && expected);
+
+ T* a = pa.get() + misalign_a;
+ for (size_t i = 0; i < count; ++i) {
+ a[i] = Random7Bit<T>(rng);
+ }
+ T* b = pb.get() + misalign_b;
+
+ size_t num_odd = 0;
+ for (size_t i = 0; i < count; ++i) {
+ if (a[i] & 1) {
+ expected[num_odd++] = a[i];
+ }
+ }
+
+ const auto is_odd = [](const auto d2, const auto v) HWY_ATTR {
+ return TestBit(v, Set(d2, TFromD<decltype(d2)>{1}));
+ };
+ T* end = CopyIf(d, a, count, b, is_odd);
+ const size_t num_written = static_cast<size_t>(end - b);
+ HWY_ASSERT_EQ(num_odd, num_written);
+
+ const auto info = hwy::detail::MakeTypeInfo<T>();
+ const char* target_name = hwy::TargetName(HWY_TARGET);
+ hwy::detail::AssertArrayEqual(info, expected.get(), b, num_odd, target_name,
+ __FILE__, __LINE__);
+ }
+};
+
+void TestAllCopyIf() {
+ ForUI163264(ForPartialVectors<ForeachCountAndMisalign<TestCopyIf>>());
+}
+
+} // namespace
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+} // namespace HWY_NAMESPACE
+} // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace hwy {
+namespace {
+HWY_BEFORE_TEST(CopyTest);
+HWY_EXPORT_AND_TEST_P(CopyTest, TestAllFill);
+HWY_EXPORT_AND_TEST_P(CopyTest, TestAllCopy);
+HWY_EXPORT_AND_TEST_P(CopyTest, TestAllCopyIf);
+HWY_AFTER_TEST();
+} // namespace
+} // namespace hwy
+HWY_TEST_MAIN();
+#endif // HWY_ONCE
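
The test sweeps all counts and misalignments for `Fill`, `Copy` and `CopyIf`
from `copy-inl.h`. A minimal static-dispatch sketch of `CopyIf`, with call
forms matching those exercised by the test (the const-qualification of the
input pointer is an assumption):

    #include <stddef.h>
    #include <stdint.h>

    #include "third_party/highway/hwy/highway.h"
    #include "third_party/highway/hwy/contrib/algo/copy-inl.h"

    namespace hn = hwy::HWY_NAMESPACE;

    // Compacts the odd values of in[0, count) into out; returns the number
    // of values written.
    HWY_ATTR size_t KeepOdd(const int32_t* in, size_t count, int32_t* out) {
      const hn::ScalableTag<int32_t> d;
      const auto is_odd = [](const auto d2, const auto v) HWY_ATTR {
        return hn::TestBit(v, hn::Set(d2, int32_t{1}));
      };
      int32_t* end = hn::CopyIf(d, in, count, out, is_odd);
      return static_cast<size_t>(end - out);
    }
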
diff --git a/third_party/highway/hwy/contrib/algo/count-inl.h b/third_party/highway/hwy/contrib/algo/count-inl.h
new file mode 100644
index 0000000000..85554c6518
--- /dev/null
+++ b/third_party/highway/hwy/contrib/algo/count-inl.h
@@ -0,0 +1,393 @@
+// Copyright 2026 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stddef.h>
+#include <stdint.h>
+
+// Per-target include guard
+#if defined(HIGHWAY_HWY_CONTRIB_ALGO_COUNT_INL_H_) == \
+ defined(HWY_TARGET_TOGGLE) // NOLINT
+#ifdef HIGHWAY_HWY_CONTRIB_ALGO_COUNT_INL_H_
+#undef HIGHWAY_HWY_CONTRIB_ALGO_COUNT_INL_H_
+#else
+#define HIGHWAY_HWY_CONTRIB_ALGO_COUNT_INL_H_
+#endif
+
+#include "third_party/highway/hwy/highway.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+
+// Returns the number of elements in `in[0, count)` equal to `value`.
+template <class D, typename T = TFromD<D>>
+size_t Count(D d, T value, const T* HWY_RESTRICT in, size_t count) {
+ const size_t N = Lanes(d);
+ using V = Vec<D>;
+ const V broadcasted = Set(d, value);
+ const RebindToSigned<D> di;
+ using TI = TFromD<decltype(di)>;
+ using VI = Vec<decltype(di)>;
+
+ size_t total = 0;
+ size_t i = 0;
+
+ // Min 4 lanes needed for two pairwise widenings, 8->16->32
+ if constexpr (sizeof(T) == 1 && HWY_MAX_LANES_D(D) >= 4) {
+ const VI k1 = Set(di, TI{1});
+ const RebindToUnsigned<decltype(di)> du;
+ const RepartitionToWide<decltype(di)> di16;
+ const Repartition<int32_t, D> di32;
+ auto wide_sum = Zero(di32);
+
+ if (count >= 4 * N && N >= 4) {
+ while (i <= count - 4 * N) {
+ VI acc0 = Zero(di);
+ VI acc1 = Zero(di);
+ VI acc2 = Zero(di);
+ VI acc3 = Zero(di);
+ const size_t cap = HWY_MIN(i + 128 * 4 * N, count);
+
+ if constexpr (HWY_NATIVE_MASK) {
+ for (; i <= cap - 4 * N; i += 4 * N) {
+ const auto m0 = RebindMask(di, Eq(broadcasted, LoadU(d, in + i)));
+ const auto m1 =
+ RebindMask(di, Eq(broadcasted, LoadU(d, in + i + N)));
+ const auto m2 =
+ RebindMask(di, Eq(broadcasted, LoadU(d, in + i + 2 * N)));
+ const auto m3 =
+ RebindMask(di, Eq(broadcasted, LoadU(d, in + i + 3 * N)));
+ acc0 = MaskedAddOr(acc0, m0, acc0, k1);
+ acc1 = MaskedAddOr(acc1, m1, acc1, k1);
+ acc2 = MaskedAddOr(acc2, m2, acc2, k1);
+ acc3 = MaskedAddOr(acc3, m3, acc3, k1);
+ }
+ } else {
+ for (; i <= cap - 4 * N; i += 4 * N) {
+ const auto v0 = VecFromMask(d, Eq(broadcasted, LoadU(d, in + i)));
+ const auto v1 =
+ VecFromMask(d, Eq(broadcasted, LoadU(d, in + i + N)));
+ const auto v2 =
+ VecFromMask(d, Eq(broadcasted, LoadU(d, in + i + 2 * N)));
+ const auto v3 =
+ VecFromMask(d, Eq(broadcasted, LoadU(d, in + i + 3 * N)));
+ acc0 = Add(acc0, BitCast(di, v0));
+ acc1 = Add(acc1, BitCast(di, v1));
+ acc2 = Add(acc2, BitCast(di, v2));
+ acc3 = Add(acc3, BitCast(di, v3));
+ }
+
+ acc0 = Neg(acc0);
+ acc1 = Neg(acc1);
+ acc2 = Neg(acc2);
+ acc3 = Neg(acc3);
+ }
+
+ const auto w0 = SatWidenMulPairwiseAdd(di16, BitCast(du, acc0), k1);
+ const auto w1 = SatWidenMulPairwiseAdd(di16, BitCast(du, acc1), k1);
+ const auto w2 = SatWidenMulPairwiseAdd(di16, BitCast(du, acc2), k1);
+ const auto w3 = SatWidenMulPairwiseAdd(di16, BitCast(du, acc3), k1);
+ const auto sum16 = Add(Add(w0, w1), Add(w2, w3));
+ wide_sum = SatWidenMulPairwiseAccumulate(
+ di32, sum16, Set(di16, int16_t{1}), wide_sum);
+ }
+ }
+ total += static_cast<size_t>(ReduceSum(di32, wide_sum));
+ } else if constexpr (sizeof(T) == 2 && HWY_MAX_LANES_D(D) >= 2) {
+ // Min 2 lanes needed for pairwise widening, 16->32
+ const Repartition<int32_t, D> di32;
+ auto wide_sum = Zero(di32);
+
+ if (count >= 4 * N && N >= 2) {
+ while (i <= count - 4 * N) {
+ VI acc0 = Zero(di);
+ VI acc1 = Zero(di);
+ VI acc2 = Zero(di);
+ VI acc3 = Zero(di);
+ const size_t cap = HWY_MIN(i + 32768 * 4 * N, count);
+
+ if constexpr (HWY_NATIVE_MASK) {
+ const auto k1 = Set(di, TI{1});
+ for (; i <= cap - 4 * N; i += 4 * N) {
+ const auto m0 = RebindMask(di, Eq(broadcasted, LoadU(d, in + i)));
+ const auto m1 =
+ RebindMask(di, Eq(broadcasted, LoadU(d, in + i + N)));
+ const auto m2 =
+ RebindMask(di, Eq(broadcasted, LoadU(d, in + i + 2 * N)));
+ const auto m3 =
+ RebindMask(di, Eq(broadcasted, LoadU(d, in + i + 3 * N)));
+ acc0 = MaskedAddOr(acc0, m0, acc0, k1);
+ acc1 = MaskedAddOr(acc1, m1, acc1, k1);
+ acc2 = MaskedAddOr(acc2, m2, acc2, k1);
+ acc3 = MaskedAddOr(acc3, m3, acc3, k1);
+ }
+ } else {
+ for (; i <= cap - 4 * N; i += 4 * N) {
+ const auto v0 = VecFromMask(d, Eq(broadcasted, LoadU(d, in + i)));
+ const auto v1 =
+ VecFromMask(d, Eq(broadcasted, LoadU(d, in + i + N)));
+ const auto v2 =
+ VecFromMask(d, Eq(broadcasted, LoadU(d, in + i + 2 * N)));
+ const auto v3 =
+ VecFromMask(d, Eq(broadcasted, LoadU(d, in + i + 3 * N)));
+ acc0 = Add(acc0, BitCast(di, v0));
+ acc1 = Add(acc1, BitCast(di, v1));
+ acc2 = Add(acc2, BitCast(di, v2));
+ acc3 = Add(acc3, BitCast(di, v3));
+ }
+ }
+ const auto mul = Set(di, HWY_NATIVE_MASK ? TI{1} : TI{-1});
+ wide_sum = SatWidenMulPairwiseAccumulate(di32, acc0, mul, wide_sum);
+ wide_sum = SatWidenMulPairwiseAccumulate(di32, acc1, mul, wide_sum);
+ wide_sum = SatWidenMulPairwiseAccumulate(di32, acc2, mul, wide_sum);
+ wide_sum = SatWidenMulPairwiseAccumulate(di32, acc3, mul, wide_sum);
+ }
+ }
+ total += static_cast<size_t>(ReduceSum(di32, wide_sum));
+ } else {
+ // Lane type wide enough to accumulate directly
+ if (count >= 4 * N) {
+ VI acc0 = Zero(di);
+ VI acc1 = Zero(di);
+ VI acc2 = Zero(di);
+ VI acc3 = Zero(di);
+
+ if constexpr (HWY_NATIVE_MASK) {
+ const auto k1 = Set(di, TI{1});
+ for (; i <= count - 4 * N; i += 4 * N) {
+ const auto m0 = RebindMask(di, Eq(broadcasted, LoadU(d, in + i)));
+ const auto m1 = RebindMask(di, Eq(broadcasted, LoadU(d, in + i + N)));
+ const auto m2 =
+ RebindMask(di, Eq(broadcasted, LoadU(d, in + i + 2 * N)));
+ const auto m3 =
+ RebindMask(di, Eq(broadcasted, LoadU(d, in + i + 3 * N)));
+ acc0 = MaskedAddOr(acc0, m0, acc0, k1);
+ acc1 = MaskedAddOr(acc1, m1, acc1, k1);
+ acc2 = MaskedAddOr(acc2, m2, acc2, k1);
+ acc3 = MaskedAddOr(acc3, m3, acc3, k1);
+ }
+ acc0 = Add(Add(acc0, acc1), Add(acc2, acc3));
+ } else {
+ for (; i <= count - 4 * N; i += 4 * N) {
+ const auto v0 = VecFromMask(d, Eq(broadcasted, LoadU(d, in + i)));
+ const auto v1 = VecFromMask(d, Eq(broadcasted, LoadU(d, in + i + N)));
+ const auto v2 =
+ VecFromMask(d, Eq(broadcasted, LoadU(d, in + i + 2 * N)));
+ const auto v3 =
+ VecFromMask(d, Eq(broadcasted, LoadU(d, in + i + 3 * N)));
+ acc0 = Add(acc0, BitCast(di, v0));
+ acc1 = Add(acc1, BitCast(di, v1));
+ acc2 = Add(acc2, BitCast(di, v2));
+ acc3 = Add(acc3, BitCast(di, v3));
+ }
+ acc0 = Neg(Add(Add(acc0, acc1), Add(acc2, acc3)));
+ }
+ total += static_cast<size_t>(ReduceSum(di, acc0));
+ }
+ }
+
+ if (count >= N) {
+ for (; i <= count - N; i += N) {
+ total += CountTrue(d, Eq(broadcasted, LoadU(d, in + i)));
+ }
+ }
+
+ if (i != count) {
+ const size_t remaining = count - i;
+ HWY_DASSERT(0 != remaining && remaining < N);
+ const V v = LoadN(d, in + i, remaining);
+ total += CountTrue(d, And(Eq(broadcasted, v), FirstN(d, remaining)));
+ }
+
+ return total;
+}
+
+// Returns the number of elements in `in[0, count)` for which `func(d, vec)`
+// returns true.
+template <class D, class Func, typename T = TFromD<D>>
+size_t CountIf(D d, const T* HWY_RESTRICT in, size_t count, const Func& func) {
+ const size_t N = Lanes(d);
+ using V = Vec<D>;
+ const RebindToSigned<D> di;
+ using TI = TFromD<decltype(di)>;
+ using VI = Vec<decltype(di)>;
+ using MI = Mask<decltype(di)>;
+
+ size_t total = 0;
+ size_t i = 0;
+
+ // Min 4 lanes needed for two pairwise widenings, 8->16->32
+ if constexpr (sizeof(T) == 1 && HWY_MAX_LANES_D(D) >= 4) {
+ const VI k1 = Set(di, TI{1});
+ const RebindToUnsigned<decltype(di)> du;
+ const RepartitionToWide<decltype(di)> di16;
+ const Repartition<int32_t, D> di32;
+ using VI16 = Vec<decltype(di16)>;
+ using VI32 = Vec<decltype(di32)>;
+ VI32 wide_sum = Zero(di32);
+
+ if (count >= 4 * N && N >= 4) {
+ while (i <= count - 4 * N) {
+ VI acc0 = Zero(di);
+ VI acc1 = Zero(di);
+ VI acc2 = Zero(di);
+ VI acc3 = Zero(di);
+ // Limit the inner loop so an i8 lane accumulates at most 128; 128 wraps
+ // to -128 but is BitCast to u8 (= 128) before the pairwise widening.
+ const size_t cap = HWY_MIN(i + 128 * 4 * N, count);
+
+ if constexpr (HWY_NATIVE_MASK) {
+ for (; i <= cap - 4 * N; i += 4 * N) {
+ const MI m0 = RebindMask(di, func(d, LoadU(d, in + i)));
+ const MI m1 = RebindMask(di, func(d, LoadU(d, in + i + N)));
+ const MI m2 = RebindMask(di, func(d, LoadU(d, in + i + 2 * N)));
+ const MI m3 = RebindMask(di, func(d, LoadU(d, in + i + 3 * N)));
+ acc0 = MaskedAddOr(acc0, m0, acc0, k1);
+ acc1 = MaskedAddOr(acc1, m1, acc1, k1);
+ acc2 = MaskedAddOr(acc2, m2, acc2, k1);
+ acc3 = MaskedAddOr(acc3, m3, acc3, k1);
+ }
+ } else {
+ for (; i <= cap - 4 * N; i += 4 * N) {
+ const V v0 = VecFromMask(d, func(d, LoadU(d, in + i)));
+ const V v1 = VecFromMask(d, func(d, LoadU(d, in + i + N)));
+ const V v2 = VecFromMask(d, func(d, LoadU(d, in + i + 2 * N)));
+ const V v3 = VecFromMask(d, func(d, LoadU(d, in + i + 3 * N)));
+ acc0 = Add(acc0, BitCast(di, v0));
+ acc1 = Add(acc1, BitCast(di, v1));
+ acc2 = Add(acc2, BitCast(di, v2));
+ acc3 = Add(acc3, BitCast(di, v3));
+ }
+
+ acc0 = Neg(acc0);
+ acc1 = Neg(acc1);
+ acc2 = Neg(acc2);
+ acc3 = Neg(acc3);
+ }
+
+ const VI16 w0 = SatWidenMulPairwiseAdd(di16, BitCast(du, acc0), k1);
+ const VI16 w1 = SatWidenMulPairwiseAdd(di16, BitCast(du, acc1), k1);
+ const VI16 w2 = SatWidenMulPairwiseAdd(di16, BitCast(du, acc2), k1);
+ const VI16 w3 = SatWidenMulPairwiseAdd(di16, BitCast(du, acc3), k1);
+ const VI16 sum16 = Add(Add(w0, w1), Add(w2, w3));
+ wide_sum = SatWidenMulPairwiseAccumulate(
+ di32, sum16, Set(di16, int16_t{1}), wide_sum);
+ }
+ }
+ total += static_cast<size_t>(ReduceSum(di32, wide_sum));
+ } else if constexpr (sizeof(T) == 2 && HWY_MAX_LANES_D(D) >= 2) {
+ // Min 2 lanes needed for pairwise widening, 16->32
+ const Repartition<int32_t, D> di32;
+ using VI32 = Vec<decltype(di32)>;
+ VI32 wide_sum = Zero(di32);
+
+ if (count >= 4 * N && N >= 2) {
+ while (i <= count - 4 * N) {
+ VI acc0 = Zero(di);
+ VI acc1 = Zero(di);
+ VI acc2 = Zero(di);
+ VI acc3 = Zero(di);
+ // Limit the inner loop so an i16 lane accumulates at most 32767; with
+ // 32768 iterations a fully-matching lane would wrap to -32768, and unlike
+ // the 8-bit path there is no unsigned BitCast to recover it.
+ const size_t cap = HWY_MIN(i + 32767 * 4 * N, count);
+
+ if constexpr (HWY_NATIVE_MASK) {
+ const VI k1 = Set(di, TI{1});
+ for (; i <= cap - 4 * N; i += 4 * N) {
+ const MI m0 = RebindMask(di, func(d, LoadU(d, in + i)));
+ const MI m1 = RebindMask(di, func(d, LoadU(d, in + i + N)));
+ const MI m2 = RebindMask(di, func(d, LoadU(d, in + i + 2 * N)));
+ const MI m3 = RebindMask(di, func(d, LoadU(d, in + i + 3 * N)));
+ acc0 = MaskedAddOr(acc0, m0, acc0, k1);
+ acc1 = MaskedAddOr(acc1, m1, acc1, k1);
+ acc2 = MaskedAddOr(acc2, m2, acc2, k1);
+ acc3 = MaskedAddOr(acc3, m3, acc3, k1);
+ }
+ } else {
+ for (; i <= cap - 4 * N; i += 4 * N) {
+ const V v0 = VecFromMask(d, func(d, LoadU(d, in + i)));
+ const V v1 = VecFromMask(d, func(d, LoadU(d, in + i + N)));
+ const V v2 = VecFromMask(d, func(d, LoadU(d, in + i + 2 * N)));
+ const V v3 = VecFromMask(d, func(d, LoadU(d, in + i + 3 * N)));
+ acc0 = Add(acc0, BitCast(di, v0));
+ acc1 = Add(acc1, BitCast(di, v1));
+ acc2 = Add(acc2, BitCast(di, v2));
+ acc3 = Add(acc3, BitCast(di, v3));
+ }
+ }
+ const VI mul = Set(di, HWY_NATIVE_MASK ? TI{1} : TI{-1});
+ wide_sum = SatWidenMulPairwiseAccumulate(di32, acc0, mul, wide_sum);
+ wide_sum = SatWidenMulPairwiseAccumulate(di32, acc1, mul, wide_sum);
+ wide_sum = SatWidenMulPairwiseAccumulate(di32, acc2, mul, wide_sum);
+ wide_sum = SatWidenMulPairwiseAccumulate(di32, acc3, mul, wide_sum);
+ }
+ }
+ total += static_cast<size_t>(ReduceSum(di32, wide_sum));
+ } else {
+ // Lane type wide enough to accumulate directly
+ if (count >= 4 * N) {
+ VI acc0 = Zero(di);
+ VI acc1 = Zero(di);
+ VI acc2 = Zero(di);
+ VI acc3 = Zero(di);
+
+ if constexpr (HWY_NATIVE_MASK) {
+ const VI k1 = Set(di, TI{1});
+ for (; i <= count - 4 * N; i += 4 * N) {
+ const MI m0 = RebindMask(di, func(d, LoadU(d, in + i)));
+ const MI m1 = RebindMask(di, func(d, LoadU(d, in + i + N)));
+ const MI m2 = RebindMask(di, func(d, LoadU(d, in + i + 2 * N)));
+ const MI m3 = RebindMask(di, func(d, LoadU(d, in + i + 3 * N)));
+ acc0 = MaskedAddOr(acc0, m0, acc0, k1);
+ acc1 = MaskedAddOr(acc1, m1, acc1, k1);
+ acc2 = MaskedAddOr(acc2, m2, acc2, k1);
+ acc3 = MaskedAddOr(acc3, m3, acc3, k1);
+ }
+ acc0 = Add(Add(acc0, acc1), Add(acc2, acc3));
+ } else {
+ for (; i <= count - 4 * N; i += 4 * N) {
+ const V v0 = VecFromMask(d, func(d, LoadU(d, in + i)));
+ const V v1 = VecFromMask(d, func(d, LoadU(d, in + i + N)));
+ const V v2 = VecFromMask(d, func(d, LoadU(d, in + i + 2 * N)));
+ const V v3 = VecFromMask(d, func(d, LoadU(d, in + i + 3 * N)));
+ acc0 = Add(acc0, BitCast(di, v0));
+ acc1 = Add(acc1, BitCast(di, v1));
+ acc2 = Add(acc2, BitCast(di, v2));
+ acc3 = Add(acc3, BitCast(di, v3));
+ }
+ acc0 = Neg(Add(Add(acc0, acc1), Add(acc2, acc3)));
+ }
+ total += static_cast<size_t>(ReduceSum(di, acc0));
+ }
+ }
+
+ if (count >= N) {
+ for (; i <= count - N; i += N) {
+ total += CountTrue(d, func(d, LoadU(d, in + i)));
+ }
+ }
+
+ if (i != count) {
+ const size_t remaining = count - i;
+ HWY_DASSERT(0 != remaining && remaining < N);
+ const V v = LoadN(d, in + i, remaining);
+ total += CountTrue(d, And(func(d, v), FirstN(d, remaining)));
+ }
+
+ return total;
+}
+
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+} // namespace HWY_NAMESPACE
+} // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#endif // HIGHWAY_HWY_CONTRIB_ALGO_COUNT_INL_H_
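For orientation, a minimal usage sketch of the two entry points above. It assumes the usual Highway setup (code compiled inside the per-target HWY_NAMESPACE block, as in count-inl.h itself); the function name and types are illustrative, not part of the commit.

  // Counts zero and negative elements of `data` (illustrative sketch).
  size_t CountZerosAndNegatives(const int16_t* HWY_RESTRICT data, size_t len) {
    const ScalableTag<int16_t> d;
    // Exact matches of a single value:
    const size_t zeros = Count(d, int16_t{0}, data, len);
    // Arbitrary predicate, evaluated on whole vectors:
    const auto is_negative = [](const auto d2, const auto v) HWY_ATTR {
      return Lt(v, Zero(d2));
    };
    return zeros + CountIf(d, data, len, is_negative);
  }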
diff --git a/third_party/highway/hwy/contrib/algo/count_value_test.cc b/third_party/highway/hwy/contrib/algo/count_value_test.cc
new file mode 100644
index 0000000000..210d170c43
--- /dev/null
+++ b/third_party/highway/hwy/contrib/algo/count_value_test.cc
@@ -0,0 +1,165 @@
+// Copyright 2026 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdio.h>
+
+#include <algorithm> // std::count, std::count_if
+#include <vector>
+
+#include "third_party/highway/hwy/aligned_allocator.h"
+#include "third_party/highway/hwy/base.h"
+
+// clang-format off
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE \
+ "hwy/contrib/algo/count_value_test.cc"
+#include "third_party/highway/hwy/foreach_target.h" // IWYU pragma: keep
+#include "third_party/highway/hwy/highway.h"
+#include "third_party/highway/hwy/contrib/algo/count-inl.h"
+#include "third_party/highway/hwy/tests/test_util-inl.h"
+// clang-format on
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+namespace {
+
+template <typename T>
+T Random(RandomState& rng) {
+ const int32_t bits = static_cast<int32_t>(Random32(&rng)) & 1023;
+ double val = (bits - 512) / 64.0;
+ if (!hwy::IsSigned<T>() && val < 0.0) {
+ val = -val;
+ }
+ return ConvertScalarTo<T>(val);
+}
+
+template <class Test>
+struct ForeachCountAndMisalign {
+ template <typename T, class D>
+ HWY_NOINLINE void operator()(T /*unused*/, D d) const {
+ RandomState rng;
+ const size_t N = Lanes(d);
+ const size_t misalignments[3] = {0, N / 4, 3 * N / 5};
+
+ std::vector<size_t> counts(AdjustedReps(512));
+ for (size_t& count : counts) {
+ count = static_cast<size_t>(rng()) % (16 * N + 1);
+ }
+ counts[0] = 0; // ensure we test count=0.
+
+ for (size_t count : counts) {
+ for (size_t m : misalignments) {
+ Test()(d, count, m, rng);
+ }
+ }
+ }
+};
+
+struct TestCount {
+ template <class D>
+ void operator()(D d, size_t count, size_t misalign, RandomState& rng) {
+ using T = TFromD<D>;
+ AlignedFreeUniquePtr<T[]> storage =
+ AllocateAligned<T>(HWY_MAX(1, misalign + count));
+ HWY_ASSERT(storage);
+ T* in = storage.get() + misalign;
+ for (size_t i = 0; i < count; ++i) {
+ in[i] = Random<T>(rng);
+ }
+
+ for (size_t pos = 0; pos < HWY_MIN(count, size_t{3}); ++pos) {
+ const T value = in[pos];
+ const size_t actual = Count(d, value, in, count);
+ const size_t expected =
+ static_cast<size_t>(std::count(in, in + count, value));
+
+ if (expected != actual) {
+ fprintf(stderr,
+ "%s count %d misalign %d: Count(%f) expected %d got %d\n",
+ hwy::TypeName(T(), Lanes(d)).c_str(), static_cast<int>(count),
+ static_cast<int>(misalign), ConvertScalarTo<double>(value),
+ static_cast<int>(expected), static_cast<int>(actual));
+ HWY_ASSERT(false);
+ }
+ }
+
+ HWY_ASSERT_EQ(size_t{0}, Count(d, ConvertScalarTo<T>(9), in, count));
+ }
+};
+
+void TestAllCount() {
+ // Widens to i32, hence requires at least 4 i8 or 2 i16 lanes. The
+ // ForGE128Vectors adapter (128-bit and above) is stricter than required.
+ ForAllTypes(ForGE128Vectors<ForeachCountAndMisalign<TestCount>>());
+}
+
+struct TestCountIf {
+ template <class D>
+ void operator()(D d, size_t count, size_t misalign, RandomState& rng) {
+ using T = TFromD<D>;
+ AlignedFreeUniquePtr<T[]> storage =
+ AllocateAligned<T>(HWY_MAX(1, misalign + count));
+ HWY_ASSERT(storage);
+ T* in = storage.get() + misalign;
+ for (size_t i = 0; i < count; ++i) {
+ in[i] = Random<T>(rng);
+ }
+
+ const int min_val = IsSigned<T>() ? -9 : 0;
+ for (int val = min_val; val <= 9; ++val) {
+ const auto greater = [val](const auto d2, const auto v) HWY_ATTR {
+ return Gt(v, Set(d2, ConvertScalarTo<T>(val)));
+ };
+ const size_t actual = CountIf(d, in, count, greater);
+
+ const size_t expected = static_cast<size_t>(std::count_if(
+ in, in + count, [val](T x) { return x > ConvertScalarTo<T>(val); }));
+
+ if (expected != actual) {
+ fprintf(stderr,
+ "%s count %d misalign %d val %d: CountIf expected %d got %d\n",
+ hwy::TypeName(T(), Lanes(d)).c_str(), static_cast<int>(count),
+ static_cast<int>(misalign), val, static_cast<int>(expected),
+ static_cast<int>(actual));
+ HWY_ASSERT(false);
+ }
+ }
+ }
+};
+
+void TestAllCountIf() {
+ // Widens to i32, hence requires at least 4 i8 or 2 i16 lanes. The
+ // ForGE128Vectors adapter (128-bit and above) is stricter than required.
+ ForAllTypes(ForGE128Vectors<ForeachCountAndMisalign<TestCountIf>>());
+}
+
+} // namespace
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+} // namespace HWY_NAMESPACE
+} // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace hwy {
+namespace {
+HWY_BEFORE_TEST(CountTest);
+HWY_EXPORT_AND_TEST_P(CountTest, TestAllCount);
+HWY_EXPORT_AND_TEST_P(CountTest, TestAllCountIf);
+HWY_AFTER_TEST();
+} // namespace
+} // namespace hwy
+HWY_TEST_MAIN();
+#endif // HWY_ONCE
diff --git a/third_party/highway/hwy/contrib/algo/find_test.cc b/third_party/highway/hwy/contrib/algo/find_test.cc
new file mode 100644
index 0000000000..ecbc6c8028
--- /dev/null
+++ b/third_party/highway/hwy/contrib/algo/find_test.cc
@@ -0,0 +1,200 @@
+// Copyright 2022 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdio.h>
+
+#include <algorithm> // std::find_if
+#include <vector>
+
+#include "third_party/highway/hwy/aligned_allocator.h"
+#include "third_party/highway/hwy/base.h"
+#include "third_party/highway/hwy/print.h"
+
+// clang-format off
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "hwy/contrib/algo/find_test.cc"
+#include "third_party/highway/hwy/foreach_target.h" // IWYU pragma: keep
+#include "third_party/highway/hwy/highway.h"
+#include "third_party/highway/hwy/contrib/algo/find-inl.h"
+#include "third_party/highway/hwy/tests/test_util-inl.h"
+// clang-format on
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+namespace {
+
+// Returns random number in [-8, 8] - we use knowledge of the range to Find()
+// values we know are not present.
+template <typename T>
+T Random(RandomState& rng) {
+ const int32_t bits = static_cast<int32_t>(Random32(&rng)) & 1023;
+ double val = (bits - 512) / 64.0;
+ // Clamp negative to zero for unsigned types.
+ if (!hwy::IsSigned<T>() && val < 0.0) {
+ val = -val;
+ }
+ return ConvertScalarTo<T>(val);
+}
+
+// Invokes Test (e.g. TestFind) with all arg combinations.
+template <class Test>
+struct ForeachCountAndMisalign {
+ template <typename T, class D>
+ HWY_NOINLINE void operator()(T /*unused*/, D d) const {
+ RandomState rng;
+ const size_t N = Lanes(d);
+ const size_t misalignments[3] = {0, N / 4, 3 * N / 5};
+
+ // Find() checks 8 vectors at a time, so we want to cover a fairly large
+ // range without oversampling (checking every possible count).
+ std::vector<size_t> counts(AdjustedReps(512));
+ for (size_t& count : counts) {
+ count = static_cast<size_t>(rng()) % (16 * N + 1);
+ }
+ counts[0] = 0; // ensure we test count=0.
+
+ for (size_t count : counts) {
+ for (size_t m : misalignments) {
+ Test()(d, count, m, rng);
+ }
+ }
+ }
+};
+
+struct TestFind {
+ template <class D>
+ void operator()(D d, size_t count, size_t misalign, RandomState& rng) {
+ using T = TFromD<D>;
+ // Must allocate at least one element even if count is zero.
+ AlignedFreeUniquePtr<T[]> storage =
+ AllocateAligned<T>(HWY_MAX(1, misalign + count));
+ HWY_ASSERT(storage);
+ T* in = storage.get() + misalign;
+ for (size_t i = 0; i < count; ++i) {
+ in[i] = Random<T>(rng);
+ }
+
+ // For each position, search for that element (which we know is there)
+ for (size_t pos = 0; pos < count; ++pos) {
+ const size_t actual = Find(d, in[pos], in, count);
+
+ // We may have found an earlier occurrence of the same value; ensure the
+ // value is the same, and that it is the first.
+ if (!IsEqual(in[pos], in[actual])) {
+ fprintf(stderr, "%s count %d, found %.15f at %d but wanted %.15f\n",
+ hwy::TypeName(T(), Lanes(d)).c_str(), static_cast<int>(count),
+ ConvertScalarTo<double>(in[actual]), static_cast<int>(actual),
+ ConvertScalarTo<double>(in[pos]));
+ HWY_ASSERT(false);
+ }
+ for (size_t i = 0; i < actual; ++i) {
+ if (IsEqual(in[i], in[pos])) {
+ fprintf(stderr, "%s count %d, found %f at %d but Find returned %d\n",
+ hwy::TypeName(T(), Lanes(d)).c_str(), static_cast<int>(count),
+ ConvertScalarTo<double>(in[i]), static_cast<int>(i),
+ static_cast<int>(actual));
+ HWY_ASSERT(false);
+ }
+ }
+ }
+
+ // Also search for values we know not to be present (out of range)
+ HWY_ASSERT_EQ(count, Find(d, ConvertScalarTo<T>(9), in, count));
+ HWY_ASSERT_EQ(count, Find(d, ConvertScalarTo<T>(-9), in, count));
+ }
+};
+
+void TestAllFind() {
+ ForAllTypes(ForPartialVectors<ForeachCountAndMisalign<TestFind>>());
+}
+
+struct TestFindIf {
+ template <class D>
+ void operator()(D d, size_t count, size_t misalign, RandomState& rng) {
+ using T = TFromD<D>;
+ using TI = MakeSigned<T>;
+ // Must allocate at least one element even if count is zero.
+ AlignedFreeUniquePtr<T[]> storage =
+ AllocateAligned<T>(HWY_MAX(1, misalign + count));
+ HWY_ASSERT(storage);
+ T* in = storage.get() + misalign;
+ for (size_t i = 0; i < count; ++i) {
+ in[i] = Random<T>(rng);
+ HWY_ASSERT(ConvertScalarTo<TI>(in[i]) <= 8);
+ HWY_ASSERT(!hwy::IsSigned<T>() || ConvertScalarTo<TI>(in[i]) >= -8);
+ }
+
+ bool found_any = false;
+ bool not_found_any = false;
+
+ // For unsigned T, the scalar comparison may promote x to signed, so x > val
+ // holds for any negative val; Set() instead casts val to a (large) unsigned
+ // value and the vector comparison stays unsigned. Avoid negative val so the
+ // two agree.
+ const int min_val = IsSigned<T>() ? -9 : 0;
+ // Includes out-of-range value 9 to test the not-found path.
+ for (int val = min_val; val <= 9; ++val) {
+ const auto greater = [val](const auto d2, const auto v) HWY_ATTR {
+ return Gt(v, Set(d2, ConvertScalarTo<T>(val)));
+ };
+ const size_t actual = FindIf(d, in, count, greater);
+ found_any |= actual < count;
+ not_found_any |= actual == count;
+
+ const auto pos = std::find_if(
+ in, in + count, [val](T x) { return x > ConvertScalarTo<T>(val); });
+ // Convert returned iterator to index.
+ const size_t expected = static_cast<size_t>(pos - in);
+ if (expected != actual) {
+ fprintf(stderr, "%s count %d val %d, expected %d actual %d\n",
+ hwy::TypeName(T(), Lanes(d)).c_str(), static_cast<int>(count),
+ val, static_cast<int>(expected), static_cast<int>(actual));
+ hwy::detail::PrintArray(hwy::detail::MakeTypeInfo<T>(), "in", in, count,
+ 0, count);
+ HWY_ASSERT(false);
+ }
+ }
+
+ // We will always not-find something due to val=9.
+ HWY_ASSERT(not_found_any);
+ // We'll find something unless the input is empty or in[0] == 0, because
+ // 0 > val is false for every val in [0, 9].
+ if (count != 0 && in[0] != ConvertScalarTo<T>(0)) {
+ HWY_ASSERT(found_any);
+ }
+ }
+};
+
+void TestAllFindIf() {
+ ForAllTypes(ForPartialVectors<ForeachCountAndMisalign<TestFindIf>>());
+}
+
+} // namespace
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+} // namespace HWY_NAMESPACE
+} // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace hwy {
+namespace {
+HWY_BEFORE_TEST(FindTest);
+HWY_EXPORT_AND_TEST_P(FindTest, TestAllFind);
+HWY_EXPORT_AND_TEST_P(FindTest, TestAllFindIf);
+HWY_AFTER_TEST();
+} // namespace
+} // namespace hwy
+HWY_TEST_MAIN();
+#endif // HWY_ONCE
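A note on the predicate idiom exercised by TestFindIf: the lambda receives a tag plus one loaded vector, and FindIf returns `count` when no element matches. A standalone sketch under the same assumptions as the test (per-target HWY_NAMESPACE context, find-inl.h API); the function name and threshold are illustrative:

  // Index of the first element greater than 4, or `count` if none.
  size_t FirstAboveFour(const uint8_t* HWY_RESTRICT in, size_t count) {
    const ScalableTag<uint8_t> d;
    // Keep the threshold non-negative for unsigned lanes; see the comment in
    // TestFindIf about signed promotion vs. unsigned vector comparison.
    const auto greater = [](const auto d2, const auto v) HWY_ATTR {
      return Gt(v, Set(d2, uint8_t{4}));
    };
    return FindIf(d, in, count, greater);
  }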
diff --git a/third_party/highway/hwy/contrib/algo/minmax-inl.h b/third_party/highway/hwy/contrib/algo/minmax-inl.h
new file mode 100644
index 0000000000..7815f655e5
--- /dev/null
+++ b/third_party/highway/hwy/contrib/algo/minmax-inl.h
@@ -0,0 +1,102 @@
+// Copyright 2026 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Per-target include guard
+#if defined(HIGHWAY_HWY_CONTRIB_ALGO_MINMAX_INL_H_) == \
+ defined(HWY_TARGET_TOGGLE) // NOLINT
+#ifdef HIGHWAY_HWY_CONTRIB_ALGO_MINMAX_INL_H_
+#undef HIGHWAY_HWY_CONTRIB_ALGO_MINMAX_INL_H_
+#else
+#define HIGHWAY_HWY_CONTRIB_ALGO_MINMAX_INL_H_
+#endif
+
+#include "third_party/highway/hwy/highway.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+
+// Returns the minimum value in `in[0, count)`, or
+// `PositiveInfOrHighestValue<T>()` if `count == 0`.
+template <class D, typename T = TFromD<D>>
+T MinValue(D d, const T* HWY_RESTRICT in, size_t count) {
+ const size_t N = Lanes(d);
+ const T identity = hwy::PositiveInfOrHighestValue<T>();
+ const Vec<D> identity_vec = Set(d, identity);
+
+ Vec<D> acc0 = identity_vec;
+ Vec<D> acc1 = identity_vec;
+ Vec<D> acc2 = identity_vec;
+ Vec<D> acc3 = identity_vec;
+
+ size_t i = 0;
+ if (count >= 4 * N) {
+ for (; i <= count - 4 * N; i += 4 * N) {
+ acc0 = Min(acc0, LoadU(d, in + i));
+ acc1 = Min(acc1, LoadU(d, in + i + N));
+ acc2 = Min(acc2, LoadU(d, in + i + 2 * N));
+ acc3 = Min(acc3, LoadU(d, in + i + 3 * N));
+ }
+ }
+
+ acc0 = Min(Min(acc0, acc1), Min(acc2, acc3));
+
+ for (; i < count; i += N) {
+ const size_t remaining = count - i;
+ const size_t n = HWY_MIN(remaining, N);
+ acc0 = Min(acc0, LoadNOr(identity_vec, d, in + i, n));
+ }
+
+ return ReduceMin(d, acc0);
+}
+
+// Returns the maximum value in `in[0, count)`, or
+// `NegativeInfOrLowestValue<T>()` if `count == 0`.
+template <class D, typename T = TFromD<D>>
+T MaxValue(D d, const T* HWY_RESTRICT in, size_t count) {
+ const size_t N = Lanes(d);
+ const T identity = hwy::NegativeInfOrLowestValue<T>();
+ const Vec<D> identity_vec = Set(d, identity);
+
+ Vec<D> acc0 = identity_vec;
+ Vec<D> acc1 = identity_vec;
+ Vec<D> acc2 = identity_vec;
+ Vec<D> acc3 = identity_vec;
+
+ size_t i = 0;
+ if (count >= 4 * N) {
+ for (; i <= count - 4 * N; i += 4 * N) {
+ acc0 = Max(acc0, LoadU(d, in + i));
+ acc1 = Max(acc1, LoadU(d, in + i + N));
+ acc2 = Max(acc2, LoadU(d, in + i + 2 * N));
+ acc3 = Max(acc3, LoadU(d, in + i + 3 * N));
+ }
+ }
+
+ acc0 = Max(Max(acc0, acc1), Max(acc2, acc3));
+
+ for (; i < count; i += N) {
+ const size_t remaining = count - i;
+ const size_t n = HWY_MIN(remaining, N);
+ acc0 = Max(acc0, LoadNOr(identity_vec, d, in + i, n));
+ }
+
+ return ReduceMax(d, acc0);
+}
+
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+} // namespace HWY_NAMESPACE
+} // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#endif // HIGHWAY_HWY_CONTRIB_ALGO_MINMAX_INL_H_
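A minimal caller for the two reductions above, again assuming the per-target HWY_NAMESPACE context; the function name is illustrative:

  // Writes the [min, max] range of `data`. For count == 0 the identity
  // values are returned (+inf and -inf for float).
  void Range(const float* HWY_RESTRICT data, size_t count, float* lo,
             float* hi) {
    const ScalableTag<float> d;
    *lo = MinValue(d, data, count);
    *hi = MaxValue(d, data, count);
  }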
diff --git a/third_party/highway/hwy/contrib/algo/minmax_value_test.cc b/third_party/highway/hwy/contrib/algo/minmax_value_test.cc
new file mode 100644
index 0000000000..fbbcc37382
--- /dev/null
+++ b/third_party/highway/hwy/contrib/algo/minmax_value_test.cc
@@ -0,0 +1,165 @@
+// Copyright 2026 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdio.h>
+
+#include <vector>
+
+#include "third_party/highway/hwy/aligned_allocator.h"
+#include "third_party/highway/hwy/base.h"
+
+// clang-format off
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "hwy/contrib/algo/minmax_value_test.cc"
+#include "third_party/highway/hwy/foreach_target.h" // IWYU pragma: keep
+#include "third_party/highway/hwy/highway.h"
+#include "third_party/highway/hwy/contrib/algo/minmax-inl.h"
+#include "third_party/highway/hwy/tests/test_util-inl.h"
+// clang-format on
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+namespace {
+
+template <typename T>
+T Random(RandomState& rng) {
+ const int32_t bits = static_cast<int32_t>(Random32(&rng)) & 1023;
+ double val = (bits - 512) / 64.0;
+ if (!hwy::IsSigned<T>() && val < 0.0) {
+ val = -val;
+ }
+ return ConvertScalarTo<T>(val);
+}
+
+template <typename T>
+T ScalarMin(const T* in, size_t count) {
+ T result = hwy::PositiveInfOrHighestValue<T>();
+ for (size_t i = 0; i < count; ++i) {
+ if (in[i] < result) {
+ result = in[i];
+ }
+ }
+ return result;
+}
+
+template <typename T>
+T ScalarMax(const T* in, size_t count) {
+ T result = hwy::NegativeInfOrLowestValue<T>();
+ for (size_t i = 0; i < count; ++i) {
+ if (in[i] > result) {
+ result = in[i];
+ }
+ }
+ return result;
+}
+
+template <class Test>
+struct ForeachCountAndMisalign {
+ template <typename T, class D>
+ HWY_NOINLINE void operator()(T /*unused*/, D d) const {
+ RandomState rng;
+ const size_t N = Lanes(d);
+ const size_t misalignments[3] = {0, N / 4, 3 * N / 5};
+
+ std::vector<size_t> counts(AdjustedReps(512));
+ for (size_t& count : counts) {
+ count = static_cast<size_t>(rng()) % (16 * N + 1);
+ }
+ counts[0] = 0;
+
+ for (size_t count : counts) {
+ for (size_t m : misalignments) {
+ Test()(d, count, m, rng);
+ }
+ }
+ }
+};
+
+struct TestMinValue {
+ template <class D>
+ void operator()(D d, size_t count, size_t misalign, RandomState& rng) {
+ using T = TFromD<D>;
+ AlignedFreeUniquePtr<T[]> storage =
+ AllocateAligned<T>(HWY_MAX(1, misalign + count));
+ HWY_ASSERT(storage);
+ T* in = storage.get() + misalign;
+ for (size_t i = 0; i < count; ++i) {
+ in[i] = Random<T>(rng);
+ }
+
+ const T expected = ScalarMin(in, count);
+ const T actual = MinValue(d, in, count);
+
+ if (!IsEqual(expected, actual)) {
+ fprintf(stderr, "%s count %d misalign %d: MinValue expected %f got %f\n",
+ hwy::TypeName(T(), Lanes(d)).c_str(), static_cast<int>(count),
+ static_cast<int>(misalign), ConvertScalarTo<double>(expected),
+ ConvertScalarTo<double>(actual));
+ HWY_ASSERT(false);
+ }
+ }
+};
+
+struct TestMaxValue {
+ template <class D>
+ void operator()(D d, size_t count, size_t misalign, RandomState& rng) {
+ using T = TFromD<D>;
+ AlignedFreeUniquePtr<T[]> storage =
+ AllocateAligned<T>(HWY_MAX(1, misalign + count));
+ HWY_ASSERT(storage);
+ T* in = storage.get() + misalign;
+ for (size_t i = 0; i < count; ++i) {
+ in[i] = Random<T>(rng);
+ }
+
+ const T expected = ScalarMax(in, count);
+ const T actual = MaxValue(d, in, count);
+
+ if (!IsEqual(expected, actual)) {
+ fprintf(stderr, "%s count %d misalign %d: MaxValue expected %f got %f\n",
+ hwy::TypeName(T(), Lanes(d)).c_str(), static_cast<int>(count),
+ static_cast<int>(misalign), ConvertScalarTo<double>(expected),
+ ConvertScalarTo<double>(actual));
+ HWY_ASSERT(false);
+ }
+ }
+};
+
+void TestAllMinValue() {
+ ForAllTypes(ForPartialVectors<ForeachCountAndMisalign<TestMinValue>>());
+}
+
+void TestAllMaxValue() {
+ ForAllTypes(ForPartialVectors<ForeachCountAndMisalign<TestMaxValue>>());
+}
+
+} // namespace
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+} // namespace HWY_NAMESPACE
+} // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace hwy {
+namespace {
+HWY_BEFORE_TEST(MinMaxTest);
+HWY_EXPORT_AND_TEST_P(MinMaxTest, TestAllMinValue);
+HWY_EXPORT_AND_TEST_P(MinMaxTest, TestAllMaxValue);
+HWY_AFTER_TEST();
+} // namespace
+} // namespace hwy
+HWY_TEST_MAIN();
+#endif // HWY_ONCE
diff --git a/third_party/highway/hwy/contrib/algo/transform_test.cc b/third_party/highway/hwy/contrib/algo/transform_test.cc
new file mode 100644
index 0000000000..77a16fc9b1
--- /dev/null
+++ b/third_party/highway/hwy/contrib/algo/transform_test.cc
@@ -0,0 +1,403 @@
+// Copyright 2022 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <string.h> // memcpy
+
+#include <vector>
+
+#include "third_party/highway/hwy/aligned_allocator.h"
+#include "third_party/highway/hwy/base.h"
+
+// clang-format off
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "hwy/contrib/algo/transform_test.cc" // NOLINT
+#include "third_party/highway/hwy/foreach_target.h" // IWYU pragma: keep
+#include "third_party/highway/hwy/highway.h"
+#include "third_party/highway/hwy/contrib/algo/transform-inl.h"
+#include "third_party/highway/hwy/tests/test_util-inl.h"
+// clang-format on
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+namespace {
+
+constexpr double kAlpha = 1.5; // arbitrary scalar
+
+// Returns random floating-point number in [-8, 8) to ensure computations do
+// not exceed float32 precision.
+template <typename T>
+T Random(RandomState& rng) {
+ const int32_t bits = static_cast<int32_t>(Random32(&rng)) & 1023;
+ const double val = (bits - 512) / 64.0;
+ // Clamp negative to zero for unsigned types.
+ return ConvertScalarTo<T>(
+ HWY_MAX(ConvertScalarTo<double>(hwy::LowestValue<T>()), val));
+}
+
+// SCAL, AXPY names are from BLAS.
+template <typename T>
+HWY_NOINLINE void SimpleSCAL(const T* x, T* out, size_t count) {
+ for (size_t i = 0; i < count; ++i) {
+ out[i] = ConvertScalarTo<T>(ConvertScalarTo<T>(kAlpha) * x[i]);
+ }
+}
+
+template <typename T>
+HWY_NOINLINE void SimpleAXPY(const T* x, const T* y, T* out, size_t count) {
+ for (size_t i = 0; i < count; ++i) {
+ out[i] = ConvertScalarTo<T>(
+ ConvertScalarTo<T>(ConvertScalarTo<T>(kAlpha) * x[i]) + y[i]);
+ }
+}
+
+template <typename T>
+HWY_NOINLINE void SimpleFMA4(const T* x, const T* y, const T* z, T* out,
+ size_t count) {
+ for (size_t i = 0; i < count; ++i) {
+ out[i] = ConvertScalarTo<T>(x[i] * y[i] + z[i]);
+ }
+}
+
+// Invokes Test (e.g. TestTransform1) with all arg combinations. T comes from
+// ForFloatTypes.
+template <class Test>
+struct ForeachCountAndMisalign {
+ template <typename T, class D>
+ HWY_NOINLINE void operator()(T /*unused*/, D d) const {
+ RandomState rng;
+ const size_t N = Lanes(d);
+ const size_t misalignments[3] = {0, N / 4, 3 * N / 5};
+
+ for (size_t count = 0; count < 2 * N; ++count) {
+ for (size_t ma : misalignments) {
+ for (size_t mb : misalignments) {
+ Test()(d, count, ma, mb, rng);
+ }
+ }
+ }
+ }
+};
+
+// Fills an array with random values, placing a given sentinel value both before
+// (when misalignment space is available) and after. Requires an allocation of
+// at least count + misalign + 1 elements.
+template <typename T>
+T* FillRandom(AlignedFreeUniquePtr<T[]>& pa, size_t count, size_t misalign,
+ T sentinel, RandomState& rng) {
+ for (size_t i = 0; i < misalign; ++i) {
+ pa[i] = sentinel;
+ }
+
+ T* a = pa.get() + misalign;
+ for (size_t i = 0; i < count; ++i) {
+ a[i] = Random<T>(rng);
+ }
+ a[count] = sentinel;
+ return a;
+}
+
+// Output-only, no loads
+struct TestGenerate {
+ template <class D>
+ void operator()(D d, size_t count, size_t misalign_a, size_t /*misalign_b*/,
+ RandomState& /*rng*/) {
+ using T = TFromD<D>;
+ AlignedFreeUniquePtr<T[]> pa = AllocateAligned<T>(misalign_a + count + 1);
+ AlignedFreeUniquePtr<T[]> expected = AllocateAligned<T>(HWY_MAX(1, count));
+ HWY_ASSERT(pa && expected);
+
+ T* actual = pa.get() + misalign_a;
+
+ for (size_t i = 0; i < count; ++i) {
+ expected[i] = ConvertScalarTo<T>(2 * i);
+ }
+
+ // TODO(janwas): can we update the apply_to in HWY_PUSH_ATTRIBUTES so that
+ // the attribute also applies to lambdas? If so, remove HWY_ATTR.
+ const auto gen2 = [](const auto d2, const auto vidx)
+ HWY_ATTR { return BitCast(d2, Add(vidx, vidx)); };
+ actual[count] = ConvertScalarTo<T>(0); // sentinel
+ Generate(d, actual, count, gen2);
+ HWY_ASSERT_EQ(ConvertScalarTo<T>(0), actual[count]); // no write past end
+
+ const auto info = hwy::detail::MakeTypeInfo<T>();
+ const char* target_name = hwy::TargetName(HWY_TARGET);
+ hwy::detail::AssertArrayEqual(info, expected.get(), actual, count,
+ target_name, __FILE__, __LINE__);
+ }
+};
+
+// Input-only, no stores
+struct TestForeach {
+ template <class D>
+ void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
+ RandomState& /*rng*/) {
+ if (misalign_b != 0) return;
+ using T = TFromD<D>;
+ AlignedFreeUniquePtr<T[]> pa = AllocateAligned<T>(misalign_a + count + 1);
+ HWY_ASSERT(pa);
+
+ T* actual = pa.get() + misalign_a;
+ T max = hwy::LowestValue<T>();
+ for (size_t i = 0; i < count; ++i) {
+ actual[i] = hwy::ConvertScalarTo<T>(i <= count / 2 ? 2 * i : i);
+ max = HWY_MAX(max, actual[i]);
+ }
+
+ // Place sentinel values in the misalignment area and at the input's end.
+ for (size_t i = 0; i < misalign_a; ++i) {
+ pa[i] = ConvertScalarTo<T>(2 * count);
+ }
+ actual[count] = ConvertScalarTo<T>(2 * count);
+
+ const Vec<D> vmin = Set(d, hwy::LowestValue<T>());
+ // TODO(janwas): can we update the apply_to in HWY_PUSH_ATTRIBUTES so that
+ // the attribute also applies to lambdas? If so, remove HWY_ATTR.
+ Vec<D> vmax = vmin;
+ const auto func = [&vmax](const D, const Vec<D> v)
+ HWY_ATTR { vmax = Max(vmax, v); };
+ Foreach(d, actual, count, vmin, func);
+
+ const char* target_name = hwy::TargetName(HWY_TARGET);
+ AssertEqual(max, ReduceMax(d, vmax), target_name, __FILE__, __LINE__);
+ }
+};
+
+// Zero extra input arrays
+struct TestTransform {
+ template <class D>
+ void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
+ RandomState& rng) {
+ if (misalign_b != 0) return;
+ using T = TFromD<D>;
+ // Prevents error if size to allocate is zero.
+ AlignedFreeUniquePtr<T[]> pa =
+ AllocateAligned<T>(HWY_MAX(1, misalign_a + count + 1));
+ AlignedFreeUniquePtr<T[]> expected = AllocateAligned<T>(HWY_MAX(1, count));
+ HWY_ASSERT(pa && expected);
+
+ const T sentinel = ConvertScalarTo<T>(-42);
+ T* a = FillRandom(pa, count, misalign_a, sentinel, rng);
+ SimpleSCAL(a, expected.get(), count);
+
+ // TODO(janwas): can we update the apply_to in HWY_PUSH_ATTRIBUTES so that
+ // the attribute also applies to lambdas? If so, remove HWY_ATTR.
+ const auto scal = [](const auto d2, const auto v) HWY_ATTR {
+ return Mul(Set(d2, ConvertScalarTo<T>(kAlpha)), v);
+ };
+ Transform(d, a, count, scal);
+
+ const auto info = hwy::detail::MakeTypeInfo<T>();
+ const char* target_name = hwy::TargetName(HWY_TARGET);
+ hwy::detail::AssertArrayEqual(info, expected.get(), a, count, target_name,
+ __FILE__, __LINE__);
+
+ // Ensure no out-of-bound writes.
+ for (size_t i = 0; i < misalign_a; ++i) {
+ HWY_ASSERT_EQ(sentinel, pa[i]);
+ }
+ HWY_ASSERT_EQ(sentinel, a[count]);
+ }
+};
+
+// One extra input array
+struct TestTransform1 {
+ template <class D>
+ void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
+ RandomState& rng) {
+ using T = TFromD<D>;
+ // Prevents error if size to allocate is zero.
+ AlignedFreeUniquePtr<T[]> pa =
+ AllocateAligned<T>(HWY_MAX(1, misalign_a + count + 1));
+ AlignedFreeUniquePtr<T[]> pb =
+ AllocateAligned<T>(HWY_MAX(1, misalign_b + count));
+ AlignedFreeUniquePtr<T[]> expected = AllocateAligned<T>(HWY_MAX(1, count));
+ HWY_ASSERT(pa && pb && expected);
+
+ const T sentinel = ConvertScalarTo<T>(-42);
+ T* a = FillRandom(pa, count, misalign_a, sentinel, rng);
+ T* b = pb.get() + misalign_b;
+ for (size_t i = 0; i < count; ++i) {
+ b[i] = Random<T>(rng);
+ }
+
+ SimpleAXPY(a, b, expected.get(), count);
+
+ const auto axpy = [](const auto d2, const auto v, const auto v1) HWY_ATTR {
+ return MulAdd(Set(d2, ConvertScalarTo<T>(kAlpha)), v, v1);
+ };
+ Transform1(d, a, count, b, axpy);
+
+ AssertArraySimilar(expected.get(), a, count, hwy::TargetName(HWY_TARGET),
+ __FILE__, __LINE__);
+ // Ensure no out-of-bound writes.
+ for (size_t i = 0; i < misalign_a; ++i) {
+ HWY_ASSERT_EQ(sentinel, pa[i]);
+ }
+ HWY_ASSERT_EQ(sentinel, a[count]);
+ }
+};
+
+// Two extra input arrays
+struct TestTransform2 {
+ template <class D>
+ void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
+ RandomState& rng) {
+ using T = TFromD<D>;
+ // Prevents error if size to allocate is zero.
+ AlignedFreeUniquePtr<T[]> pa =
+ AllocateAligned<T>(HWY_MAX(1, misalign_a + count + 1));
+ AlignedFreeUniquePtr<T[]> pb =
+ AllocateAligned<T>(HWY_MAX(1, misalign_b + count));
+ AlignedFreeUniquePtr<T[]> pc =
+ AllocateAligned<T>(HWY_MAX(1, misalign_a + count));
+ AlignedFreeUniquePtr<T[]> expected = AllocateAligned<T>(HWY_MAX(1, count));
+ HWY_ASSERT(pa && pb && pc && expected);
+
+ const T sentinel = ConvertScalarTo<T>(-42);
+ T* a = FillRandom(pa, count, misalign_a, sentinel, rng);
+ T* b = pb.get() + misalign_b;
+ T* c = pc.get() + misalign_a;
+ for (size_t i = 0; i < count; ++i) {
+ b[i] = Random<T>(rng);
+ c[i] = Random<T>(rng);
+ }
+
+ SimpleFMA4(a, b, c, expected.get(), count);
+
+ const auto fma4 = [](auto /*d*/, auto v, auto v1, auto v2)
+ HWY_ATTR { return MulAdd(v, v1, v2); };
+ Transform2(d, a, count, b, c, fma4);
+
+ AssertArraySimilar(expected.get(), a, count, hwy::TargetName(HWY_TARGET),
+ __FILE__, __LINE__);
+ // Ensure no out-of-bound writes.
+ for (size_t i = 0; i < misalign_a; ++i) {
+ HWY_ASSERT_EQ(sentinel, pa[i]);
+ }
+ HWY_ASSERT_EQ(sentinel, a[count]);
+ }
+};
+
+template <typename T>
+class IfEq {
+ public:
+ IfEq(T val) : val_(val) {}
+
+ template <class D, class V>
+ Mask<D> operator()(D d, V v) const {
+ return Eq(v, Set(d, val_));
+ }
+
+ private:
+ T val_;
+};
+
+struct TestReplace {
+ template <class D>
+ void operator()(D d, size_t count, size_t misalign_a, size_t misalign_b,
+ RandomState& rng) {
+ if (misalign_b != 0) return;
+ if (count == 0) return;
+ using T = TFromD<D>;
+ AlignedFreeUniquePtr<T[]> pa = AllocateAligned<T>(misalign_a + count + 1);
+ AlignedFreeUniquePtr<T[]> pb = AllocateAligned<T>(count);
+ AlignedFreeUniquePtr<T[]> expected = AllocateAligned<T>(count);
+ HWY_ASSERT(pa && pb && expected);
+
+ const T sentinel = ConvertScalarTo<T>(-42);
+ T* a = FillRandom(pa, count, misalign_a, sentinel, rng);
+
+ for (size_t pos = 0; pos < count; ++pos) {
+ const T old_t = a[pos];
+ const T new_t = Random<T>(rng);
+ for (size_t i = 0; i < count; ++i) {
+ expected[i] = IsEqual(a[i], old_t) ? new_t : a[i];
+ }
+
+ // Copy so ReplaceIf gets the same input (and thus also outputs expected)
+ memcpy(pb.get(), a, count * sizeof(T));
+
+ Replace(d, a, count, new_t, old_t);
+ HWY_ASSERT_ARRAY_EQ(expected.get(), a, count);
+ // Ensure no out-of-bound writes.
+ for (size_t i = 0; i < misalign_a; ++i) {
+ HWY_ASSERT_EQ(sentinel, pa[i]);
+ }
+ HWY_ASSERT_EQ(sentinel, a[count]);
+
+ ReplaceIf(d, pb.get(), count, new_t, IfEq<T>(old_t));
+ HWY_ASSERT_ARRAY_EQ(expected.get(), pb.get(), count);
+ // Ensure no out-of-bound writes.
+ for (size_t i = 0; i < misalign_a; ++i) {
+ HWY_ASSERT_EQ(sentinel, pa[i]);
+ }
+ HWY_ASSERT_EQ(sentinel, a[count]);
+ }
+ }
+};
+
+void TestAllGenerate() {
+ // The test BitCast-s the indices, which does not work for floats.
+ ForIntegerTypes(ForPartialVectors<ForeachCountAndMisalign<TestGenerate>>());
+}
+
+void TestAllForeach() {
+ ForAllTypes(ForPartialVectors<ForeachCountAndMisalign<TestForeach>>());
+}
+
+void TestAllTransform() {
+ ForFloatTypes(ForPartialVectors<ForeachCountAndMisalign<TestTransform>>());
+}
+
+void TestAllTransform1() {
+ ForFloatTypes(ForPartialVectors<ForeachCountAndMisalign<TestTransform1>>());
+}
+
+void TestAllTransform2() {
+ ForFloatTypes(ForPartialVectors<ForeachCountAndMisalign<TestTransform2>>());
+}
+
+void TestAllReplace() {
+ ForFloatTypes(ForPartialVectors<ForeachCountAndMisalign<TestReplace>>());
+}
+
+} // namespace
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+} // namespace HWY_NAMESPACE
+} // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace hwy {
+namespace {
+HWY_BEFORE_TEST(TransformTest);
+HWY_EXPORT_AND_TEST_P(TransformTest, TestAllGenerate);
+HWY_EXPORT_AND_TEST_P(TransformTest, TestAllForeach);
+HWY_EXPORT_AND_TEST_P(TransformTest, TestAllTransform);
+HWY_EXPORT_AND_TEST_P(TransformTest, TestAllTransform1);
+HWY_EXPORT_AND_TEST_P(TransformTest, TestAllTransform2);
+HWY_EXPORT_AND_TEST_P(TransformTest, TestAllReplace);
+HWY_AFTER_TEST();
+} // namespace
+} // namespace hwy
+HWY_TEST_MAIN();
+#endif // HWY_ONCE
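TestTransform1 above doubles as documentation for the in-place AXPY idiom; extracted into a standalone sketch (assuming transform-inl.h and the per-target HWY_NAMESPACE context; the function name is illustrative):

  // x[i] = alpha * x[i] + y[i], overwriting x in place; Transform1 handles
  // the final partial vector internally.
  void Axpy(float alpha, float* HWY_RESTRICT x, const float* HWY_RESTRICT y,
            size_t count) {
    const ScalableTag<float> d;
    const auto axpy = [alpha](const auto d2, const auto vx, const auto vy)
        HWY_ATTR { return MulAdd(Set(d2, alpha), vx, vy); };
    Transform1(d, x, count, y, axpy);
  }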
diff --git a/third_party/highway/hwy/contrib/bit_pack/bit_pack-inl.h b/third_party/highway/hwy/contrib/bit_pack/bit_pack-inl.h
index 0b3902e0ec..2f211608de 100644
--- a/third_party/highway/hwy/contrib/bit_pack/bit_pack-inl.h
+++ b/third_party/highway/hwy/contrib/bit_pack/bit_pack-inl.h
@@ -2623,7 +2623,7 @@ struct BitPackUnroller {
T* HWY_RESTRICT packed_out, const V& mask,
const V& frame_of_reference, V& in, V& out) {
// Avoid compilation errors and unnecessary template instantiation if
- // compiling in C++11 or C++14 mode
+ // compiling in C++14 mode
using NextUnroller = BitPackUnroller<
T, kBits, ((S <= B) ? (S + ((S < B) ? kBits : 0)) : (S % B)),
kLoadPos + static_cast<size_t>(S < B),
@@ -2669,7 +2669,7 @@ struct BitPackUnroller {
T* HWY_RESTRICT raw, const V& mask,
const V& frame_of_reference, V& in, V& out) {
// Avoid compilation errors and unnecessary template instantiation if
- // compiling in C++11 or C++14 mode
+ // compiling in C++14 mode
using NextUnroller = BitPackUnroller<
T, kBits, ((S <= B) ? (S + ((S < B) ? kBits : 0)) : (S % B)),
kLoadPos + static_cast<size_t>(S > B),
diff --git a/third_party/highway/hwy/contrib/bit_pack/bit_pack_test.cc b/third_party/highway/hwy/contrib/bit_pack/bit_pack_test.cc
new file mode 100644
index 0000000000..d59f24f0bd
--- /dev/null
+++ b/third_party/highway/hwy/contrib/bit_pack/bit_pack_test.cc
@@ -0,0 +1,244 @@
+// Copyright 2022 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdio.h>
+
+#include <vector>
+
+#include "third_party/highway/hwy/aligned_allocator.h"
+#include "third_party/highway/hwy/base.h"
+#include "third_party/highway/hwy/nanobenchmark.h"
+
+// clang-format off
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "hwy/contrib/bit_pack/bit_pack_test.cc" // NOLINT
+#include "third_party/highway/hwy/foreach_target.h" // IWYU pragma: keep
+#include "third_party/highway/hwy/highway.h"
+#include "third_party/highway/hwy/timer.h"
+#include "third_party/highway/hwy/contrib/bit_pack/bit_pack-inl.h"
+#include "third_party/highway/hwy/tests/test_util-inl.h"
+// clang-format on
+
+#ifndef HWY_BIT_PACK_BENCHMARK
+#define HWY_BIT_PACK_BENCHMARK 0
+#endif
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+// Used to prevent running benchmark (slow) for partial vectors and targets
+// except the best available. Global, not per-target, hence must be outside
+// HWY_NAMESPACE. Declare first because HWY_ONCE is only true after some code
+// has been re-included.
+extern size_t last_bits;
+extern uint64_t best_target;
+#if HWY_ONCE
+size_t last_bits = 0;
+uint64_t best_target = ~0ull;
+#endif
+namespace HWY_NAMESPACE {
+namespace {
+
+template <size_t kBits, typename T>
+T Random(RandomState& rng) {
+ return ConvertScalarTo<T>(Random32(&rng) & kBits);
+}
+
+template <typename T>
+class Checker {
+ public:
+ explicit Checker(size_t num) { raw_.reserve(num); }
+ void NotifyRaw(T raw) { raw_.push_back(raw); }
+
+ void NotifyRawOutput(size_t bits, T raw) {
+ if (raw_[num_verified_] != raw) {
+ HWY_ABORT("%zu bits: pos %zu of %zu, expected %.0f actual %.0f\n", bits,
+ num_verified_, raw_.size(),
+ ConvertScalarTo<double>(raw_[num_verified_]),
+ ConvertScalarTo<double>(raw));
+ }
+ ++num_verified_;
+ }
+
+ private:
+ std::vector<T> raw_;
+ size_t num_verified_ = 0;
+};
+
+template <template <size_t> class PackT, size_t kVectors, size_t kBits>
+struct TestPack {
+ template <typename T, class D>
+ void operator()(T /* t */, D d) {
+ constexpr size_t kLoops = 16; // working set slightly larger than L1
+ const size_t N = Lanes(d);
+ RandomState rng(N * 129);
+ static_assert(kBits <= kVectors, "");
+ const size_t num_per_loop = N * kVectors;
+ const size_t num = num_per_loop * kLoops;
+ const size_t num_packed_per_loop = N * kBits;
+ const size_t num_packed = num_packed_per_loop * kLoops;
+ Checker<T> checker(num);
+ AlignedFreeUniquePtr<T[]> raw = hwy::AllocateAligned<T>(num);
+ AlignedFreeUniquePtr<T[]> raw2 = hwy::AllocateAligned<T>(num);
+ AlignedFreeUniquePtr<T[]> packed = hwy::AllocateAligned<T>(num_packed);
+ HWY_ASSERT(raw && raw2 && packed);
+
+ for (size_t i = 0; i < num; ++i) {
+ raw[i] = Random<kBits, T>(rng);
+ checker.NotifyRaw(raw[i]);
+ }
+
+ best_target = HWY_MIN(best_target, HWY_TARGET);
+ const bool run_bench = HWY_BIT_PACK_BENCHMARK && (kBits != last_bits) &&
+ (HWY_TARGET == best_target);
+ last_bits = kBits;
+
+ const PackT<kBits> func;
+
+ if (run_bench) {
+ const size_t kNumInputs = 1;
+ const size_t num_items = num * size_t(Unpredictable1());
+ const FuncInput inputs[kNumInputs] = {num_items};
+ Result results[kNumInputs];
+
+ Params p;
+ p.verbose = false;
+ p.max_evals = 7;
+ p.target_rel_mad = 0.002;
+ const size_t num_results = MeasureClosure(
+ [&](FuncInput) HWY_ATTR {
+ for (size_t i = 0, pi = 0; i < num;
+ i += num_per_loop, pi += num_packed_per_loop) {
+ func.Pack(d, raw.get() + i, packed.get() + pi);
+ }
+ // Touch one packed value via Unpredictable1() so the compiler cannot
+ // constant-fold the Pack+Unpack round trip.
+ T& val = packed.get()[Random32(&rng) % num_packed];
+ T zero = static_cast<T>(Unpredictable1() - 1);
+ val = static_cast<T>(val + zero);
+ for (size_t i = 0, pi = 0; i < num;
+ i += num_per_loop, pi += num_packed_per_loop) {
+ func.Unpack(d, packed.get() + pi, raw2.get() + i);
+ }
+ return raw2[Random32(&rng) % num];
+ },
+ inputs, kNumInputs, results, p);
+ if (num_results != kNumInputs) {
+ HWY_WARN("MeasureClosure failed.\n");
+ return;
+ }
+ // Print throughput for pack+unpack round trip
+ for (size_t i = 0; i < num_results; ++i) {
+ const size_t bytes_per_element = (kBits + 7) / 8;
+ const double bytes =
+ static_cast<double>(results[i].input * bytes_per_element);
+ const double seconds =
+ results[i].ticks / platform::InvariantTicksPerSecond();
+ printf("Bits:%2d elements:%3d GB/s:%4.1f (+/-%3.1f%%)\n",
+ static_cast<int>(kBits), static_cast<int>(results[i].input),
+ 1E-9 * bytes / seconds, results[i].variability * 100.0);
+ }
+ } else {
+ for (size_t i = 0, pi = 0; i < num;
+ i += num_per_loop, pi += num_packed_per_loop) {
+ func.Pack(d, raw.get() + i, packed.get() + pi);
+ }
+ // Touch one packed value via Unpredictable1() so the compiler cannot
+ // constant-fold the Pack+Unpack round trip.
+ T& val = packed.get()[Random32(&rng) % num_packed];
+ T zero = static_cast<T>(Unpredictable1() - 1);
+ val = static_cast<T>(val + zero);
+ for (size_t i = 0, pi = 0; i < num;
+ i += num_per_loop, pi += num_packed_per_loop) {
+ func.Unpack(d, packed.get() + pi, raw2.get() + i);
+ }
+ }
+
+ for (size_t i = 0; i < num; ++i) {
+ checker.NotifyRawOutput(kBits, raw2[i]);
+ }
+ }
+};
+
+void TestAllPack8() {
+ ForShrinkableVectors<TestPack<Pack8, 8, 1>>()(uint8_t());
+ ForShrinkableVectors<TestPack<Pack8, 8, 2>>()(uint8_t());
+ ForShrinkableVectors<TestPack<Pack8, 8, 3>>()(uint8_t());
+ ForShrinkableVectors<TestPack<Pack8, 8, 4>>()(uint8_t());
+ ForShrinkableVectors<TestPack<Pack8, 8, 5>>()(uint8_t());
+ ForShrinkableVectors<TestPack<Pack8, 8, 6>>()(uint8_t());
+ ForShrinkableVectors<TestPack<Pack8, 8, 7>>()(uint8_t());
+ ForShrinkableVectors<TestPack<Pack8, 8, 8>>()(uint8_t());
+}
+
+void TestAllPack16() {
+ ForShrinkableVectors<TestPack<Pack16, 16, 1>>()(uint16_t());
+ ForShrinkableVectors<TestPack<Pack16, 16, 2>>()(uint16_t());
+ ForShrinkableVectors<TestPack<Pack16, 16, 3>>()(uint16_t());
+ ForShrinkableVectors<TestPack<Pack16, 16, 4>>()(uint16_t());
+ ForShrinkableVectors<TestPack<Pack16, 16, 5>>()(uint16_t());
+ ForShrinkableVectors<TestPack<Pack16, 16, 6>>()(uint16_t());
+ ForShrinkableVectors<TestPack<Pack16, 16, 7>>()(uint16_t());
+ ForShrinkableVectors<TestPack<Pack16, 16, 8>>()(uint16_t());
+ ForShrinkableVectors<TestPack<Pack16, 16, 9>>()(uint16_t());
+ ForShrinkableVectors<TestPack<Pack16, 16, 10>>()(uint16_t());
+ ForShrinkableVectors<TestPack<Pack16, 16, 11>>()(uint16_t());
+ ForShrinkableVectors<TestPack<Pack16, 16, 12>>()(uint16_t());
+ ForShrinkableVectors<TestPack<Pack16, 16, 13>>()(uint16_t());
+ ForShrinkableVectors<TestPack<Pack16, 16, 14>>()(uint16_t());
+ ForShrinkableVectors<TestPack<Pack16, 16, 15>>()(uint16_t());
+ ForShrinkableVectors<TestPack<Pack16, 16, 16>>()(uint16_t());
+}
+
+void TestAllPack32() {
+ ForShrinkableVectors<TestPack<Pack32, 32, 1>>()(uint32_t());
+ ForShrinkableVectors<TestPack<Pack32, 32, 2>>()(uint32_t());
+ ForShrinkableVectors<TestPack<Pack32, 32, 6>>()(uint32_t());
+ ForShrinkableVectors<TestPack<Pack32, 32, 11>>()(uint32_t());
+ ForShrinkableVectors<TestPack<Pack32, 32, 16>>()(uint32_t());
+ ForShrinkableVectors<TestPack<Pack32, 32, 31>>()(uint32_t());
+ ForShrinkableVectors<TestPack<Pack32, 32, 32>>()(uint32_t());
+}
+
+void TestAllPack64() {
+ // Fails on RVV with GCC 13 and earlier, so skip that combination.
+#if !(HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1400 && \
+ HWY_TARGET == HWY_RVV)
+ ForShrinkableVectors<TestPack<Pack64, 64, 1>>()(uint64_t());
+ ForShrinkableVectors<TestPack<Pack64, 64, 5>>()(uint64_t());
+ ForShrinkableVectors<TestPack<Pack64, 64, 12>>()(uint64_t());
+ ForShrinkableVectors<TestPack<Pack64, 64, 16>>()(uint64_t());
+ ForShrinkableVectors<TestPack<Pack64, 64, 27>>()(uint64_t());
+ ForShrinkableVectors<TestPack<Pack64, 64, 31>>()(uint64_t());
+ ForShrinkableVectors<TestPack<Pack64, 64, 33>>()(uint64_t());
+ ForShrinkableVectors<TestPack<Pack64, 64, 41>>()(uint64_t());
+ ForShrinkableVectors<TestPack<Pack64, 64, 61>>()(uint64_t());
+#endif
+}
+
+} // namespace
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+} // namespace HWY_NAMESPACE
+} // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace hwy {
+namespace {
+HWY_BEFORE_TEST(BitPackTest);
+HWY_EXPORT_AND_TEST_P(BitPackTest, TestAllPack8);
+HWY_EXPORT_AND_TEST_P(BitPackTest, TestAllPack16);
+HWY_EXPORT_AND_TEST_P(BitPackTest, TestAllPack32);
+HWY_EXPORT_AND_TEST_P(BitPackTest, TestAllPack64);
+HWY_AFTER_TEST();
+} // namespace
+} // namespace hwy
+HWY_TEST_MAIN();
+#endif // HWY_ONCE
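For orientation, the buffer shapes used by TestPack reduce to the following round trip, sketched under the same assumptions (per-target HWY_NAMESPACE context, bit_pack-inl.h). Per N-lane block, Pack8<kBits>::Pack consumes 8 * N raw bytes and emits kBits * N packed bytes; Unpack reverses this. Names and the choice of kBits = 3 are illustrative:

  // Packs 8*N 3-bit values (one per byte in `raw`) into 3*N bytes and back.
  void RoundTrip3Bit(const uint8_t* HWY_RESTRICT raw,     // 8 * N elements
                     uint8_t* HWY_RESTRICT packed,        // 3 * N elements
                     uint8_t* HWY_RESTRICT out) {         // 8 * N elements
    const ScalableTag<uint8_t> d;
    const Pack8<3> func;
    func.Pack(d, raw, packed);    // inputs must fit in 3 bits
    func.Unpack(d, packed, out);  // out now equals raw
  }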
diff --git a/third_party/highway/hwy/contrib/dot/dot_test.cc b/third_party/highway/hwy/contrib/dot/dot_test.cc
new file mode 100644
index 0000000000..f70eb35710
--- /dev/null
+++ b/third_party/highway/hwy/contrib/dot/dot_test.cc
@@ -0,0 +1,302 @@
+// Copyright 2021 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+#include "third_party/highway/hwy/aligned_allocator.h"
+#include "third_party/highway/hwy/base.h"
+
+// clang-format off
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "hwy/contrib/dot/dot_test.cc"
+#include "third_party/highway/hwy/foreach_target.h" // IWYU pragma: keep
+#include "third_party/highway/hwy/highway.h"
+#include "third_party/highway/hwy/contrib/dot/dot-inl.h"
+#include "third_party/highway/hwy/tests/test_util-inl.h"
+// clang-format on
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+namespace {
+
+template <typename T1, typename T2>
+HWY_NOINLINE T1 SimpleDot(const T1* pa, const T2* pb, size_t num) {
+ float sum = 0.0f;
+ for (size_t i = 0; i < num; ++i) {
+ sum += ConvertScalarTo<float>(pa[i]) * ConvertScalarTo<float>(pb[i]);
+ }
+ return ConvertScalarTo<T1>(sum);
+}
+
+HWY_MAYBE_UNUSED HWY_NOINLINE float SimpleDot(const float* pa,
+ const hwy::bfloat16_t* pb,
+ size_t num) {
+ float sum = 0.0f;
+ for (size_t i = 0; i < num; ++i) {
+ sum += pa[i] * F32FromBF16(pb[i]);
+ }
+ return sum;
+}
+
+// Overload is required because the generic template hits an internal compiler
+// error on aarch64 clang.
+HWY_MAYBE_UNUSED HWY_NOINLINE float SimpleDot(const bfloat16_t* pa,
+ const bfloat16_t* pb,
+ size_t num) {
+ float sum = 0.0f;
+ for (size_t i = 0; i < num; ++i) {
+ sum += F32FromBF16(pa[i]) * F32FromBF16(pb[i]);
+ }
+ return sum;
+}
+
+class TestDot {
+ // Computes/verifies one dot product.
+ template <int kAssumptions, class D>
+ void Test(D d, size_t num, size_t misalign_a, size_t misalign_b,
+ RandomState& rng) {
+ using T = TFromD<D>;
+ const size_t N = Lanes(d);
+ const auto random_t = [&rng]() {
+ const int32_t bits = static_cast<int32_t>(Random32(&rng)) & 1023;
+ return static_cast<float>(bits - 512) * (1.0f / 64);
+ };
+
+ const size_t padded =
+ (kAssumptions & Dot::kPaddedToVector) ? RoundUpTo(num, N) : num;
+ AlignedFreeUniquePtr<T[]> pa = AllocateAligned<T>(misalign_a + padded);
+ AlignedFreeUniquePtr<T[]> pb = AllocateAligned<T>(misalign_b + padded);
+ HWY_ASSERT(pa && pb);
+ T* a = pa.get() + misalign_a;
+ T* b = pb.get() + misalign_b;
+ size_t i = 0;
+ for (; i < num; ++i) {
+ a[i] = ConvertScalarTo<T>(random_t());
+ b[i] = ConvertScalarTo<T>(random_t());
+ }
+ // Fill padding - the values are unused, but writing them avoids MSAN errors.
+ for (; i < padded; ++i) {
+ a[i] = ConvertScalarTo<T>(0);
+ b[i] = ConvertScalarTo<T>(0);
+ }
+
+ const double expected = SimpleDot(a, b, num);
+ const double magnitude = expected > 0.0 ? expected : -expected;
+ const double actual =
+ ConvertScalarTo<double>(Dot::Compute<kAssumptions>(d, a, b, num));
+ const double max = static_cast<double>(8 * 8 * num);
+ HWY_ASSERT(-max <= actual && actual <= max);
+ // Integer math is exact, so no tolerance.
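+    // e.g. for float32 (Epsilon ~= 1.19e-7) and magnitude 100, the tolerance
+    // below is roughly 96 * 1.19e-7 * 100 ~= 1.1e-3.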
+ const double tolerance =
+ IsFloat<T>() ? 96.0 * ConvertScalarTo<double>(Epsilon<T>()) *
+ HWY_MAX(magnitude, 1.0)
+ : 0;
+ HWY_ASSERT(expected - tolerance <= actual &&
+ actual <= expected + tolerance);
+ }
+
+ // Runs tests with various alignments.
+ template <int kAssumptions, class D>
+ void ForeachMisalign(D d, size_t num, RandomState& rng) {
+ const size_t N = Lanes(d);
+ const size_t misalignments[3] = {0, N / 4, 3 * N / 5};
+ for (size_t ma : misalignments) {
+ for (size_t mb : misalignments) {
+ Test<kAssumptions>(d, num, ma, mb, rng);
+ }
+ }
+ }
+
+ // Runs tests with various lengths compatible with the given assumptions.
+ template <int kAssumptions, class D>
+ void ForeachCount(D d, RandomState& rng) {
+ const size_t N = Lanes(d);
+ const size_t counts[] = {1,
+ 3,
+ 7,
+ 16,
+ HWY_MAX(N / 2, 1),
+ HWY_MAX(2 * N / 3, 1),
+ N,
+ N + 1,
+ 4 * N / 3,
+ 3 * N,
+ 8 * N,
+ 8 * N + 2};
+ for (size_t num : counts) {
+ if ((kAssumptions & Dot::kAtLeastOneVector) && num < N) continue;
+ if ((kAssumptions & Dot::kMultipleOfVector) && (num % N) != 0) continue;
+ ForeachMisalign<kAssumptions>(d, num, rng);
+ }
+ }
+
+ public:
+ // Must be inlined on aarch64 for bf16, else clang crashes.
+ template <class T, class D>
+ HWY_INLINE void operator()(T /*unused*/, D d) {
+ RandomState rng;
+
+ // All 8 combinations of the three length-related flags:
+ ForeachCount<0>(d, rng);
+ ForeachCount<Dot::kAtLeastOneVector>(d, rng);
+ ForeachCount<Dot::kMultipleOfVector>(d, rng);
+ ForeachCount<Dot::kMultipleOfVector | Dot::kAtLeastOneVector>(d, rng);
+ ForeachCount<Dot::kPaddedToVector>(d, rng);
+ ForeachCount<Dot::kPaddedToVector | Dot::kAtLeastOneVector>(d, rng);
+ ForeachCount<Dot::kPaddedToVector | Dot::kMultipleOfVector>(d, rng);
+ ForeachCount<Dot::kPaddedToVector | Dot::kMultipleOfVector |
+ Dot::kAtLeastOneVector>(d, rng);
+ }
+};
+
+class TestDotF32BF16 {
+ // Computes/verifies one dot product.
+ template <int kAssumptions, class D>
+ void Test(D d, size_t num, size_t misalign_a, size_t misalign_b,
+ RandomState& rng) {
+ using T = TFromD<D>;
+ using T2 = hwy::bfloat16_t;
+ const size_t N = Lanes(d);
+ const auto random_t = [&rng]() {
+ const int32_t bits = static_cast<int32_t>(Random32(&rng)) & 1023;
+ return static_cast<float>(bits - 512) * (1.0f / 64);
+ };
+
+ const size_t padded =
+ (kAssumptions & Dot::kPaddedToVector) ? RoundUpTo(num, N) : num;
+ AlignedFreeUniquePtr<T[]> pa = AllocateAligned<T>(misalign_a + padded);
+ AlignedFreeUniquePtr<T2[]> pb = AllocateAligned<T2>(misalign_b + padded);
+ HWY_ASSERT(pa && pb);
+ T* a = pa.get() + misalign_a;
+ T2* b = pb.get() + misalign_b;
+ size_t i = 0;
+ for (; i < num; ++i) {
+ a[i] = ConvertScalarTo<T>(random_t());
+ b[i] = ConvertScalarTo<T2>(random_t());
+ }
+    // Fill padding with NaN - the values are unused, but initializing them
+    // avoids MSAN errors.
+ for (; i < padded; ++i) {
+ ScalableTag<float> df1;
+ a[i] = ConvertScalarTo<T>(GetLane(NaN(df1)));
+ b[i] = ConvertScalarTo<T2>(GetLane(NaN(df1)));
+ }
+
+ double expected = SimpleDot(a, b, num);
+ const double magnitude = expected > 0.0 ? expected : -expected;
+ const double actual =
+ ConvertScalarTo<double>(Dot::Compute<kAssumptions>(d, a, b, num));
+ const double max = static_cast<double>(8 * 8 * num);
+ HWY_ASSERT(-max <= actual && actual <= max);
+ const double tolerance =
+ 64.0 * ConvertScalarTo<double>(Epsilon<T2>()) * HWY_MAX(magnitude, 1.0);
+    // Workaround for a GCC 15 bug causing test failures: without this
+    // barrier, the computed `expected` value is far from correct.
+#if HWY_COMPILER_GCC_ACTUAL
+ __asm__ volatile("" : "+r"(expected));
+#endif
+ if (!(expected - tolerance <= actual && actual <= expected + tolerance)) {
+ HWY_ABORT("expected: %E actual: %E tolerance: %E", expected, actual,
+ tolerance);
+ }
+ }
+
+ // Runs tests with various alignments.
+ template <int kAssumptions, class D>
+ void ForeachMisalign(D d, size_t num, RandomState& rng) {
+ const size_t N = Lanes(d);
+ const size_t misalignments[3] = {0, N / 4, 3 * N / 5};
+ for (size_t ma : misalignments) {
+ for (size_t mb : misalignments) {
+ Test<kAssumptions>(d, num, ma, mb, rng);
+ }
+ }
+ }
+
+ // Runs tests with various lengths compatible with the given assumptions.
+ template <int kAssumptions, class D>
+ void ForeachCount(D d, RandomState& rng) {
+ const size_t N = Lanes(d);
+ const size_t counts[] = {1,
+ 3,
+ 7,
+ 16,
+ HWY_MAX(N / 2, 1),
+ HWY_MAX(2 * N / 3, 1),
+ N,
+ N + 1,
+ 4 * N / 3,
+ 3 * N,
+ 8 * N,
+ 8 * N + 2};
+ for (size_t num : counts) {
+ if ((kAssumptions & Dot::kAtLeastOneVector) && num < N) continue;
+ if ((kAssumptions & Dot::kMultipleOfVector) && (num % N) != 0) continue;
+ ForeachMisalign<kAssumptions>(d, num, rng);
+ }
+ }
+
+ public:
+ // Must be inlined on aarch64 for bf16, else clang crashes.
+ template <class T, class D>
+ HWY_INLINE void operator()(T /*unused*/, D d) {
+ RandomState rng;
+
+ // All 8 combinations of the three length-related flags:
+ ForeachCount<0>(d, rng);
+ ForeachCount<Dot::kAtLeastOneVector>(d, rng);
+ ForeachCount<Dot::kMultipleOfVector>(d, rng);
+ ForeachCount<Dot::kMultipleOfVector | Dot::kAtLeastOneVector>(d, rng);
+ ForeachCount<Dot::kPaddedToVector>(d, rng);
+ ForeachCount<Dot::kPaddedToVector | Dot::kAtLeastOneVector>(d, rng);
+ ForeachCount<Dot::kPaddedToVector | Dot::kMultipleOfVector>(d, rng);
+ ForeachCount<Dot::kPaddedToVector | Dot::kMultipleOfVector |
+ Dot::kAtLeastOneVector>(d, rng);
+ }
+};
+
+// All floating-point types, both arguments same.
+void TestAllDot() { ForFloatTypes(ForPartialVectors<TestDot>()); }
+
+// Mixed f32 and bf16 inputs.
+void TestAllDotF32BF16() { ForPartialVectors<TestDotF32BF16>()(float()); }
+
+// Both inputs bf16.
+void TestAllDotBF16() { ForShrinkableVectors<TestDot>()(bfloat16_t()); }
+
+// Both inputs i16.
+void TestAllDotI16() { ForShrinkableVectors<TestDot>()(int16_t()); }
+
+} // namespace
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+} // namespace HWY_NAMESPACE
+} // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace hwy {
+namespace {
+HWY_BEFORE_TEST(DotTest);
+HWY_EXPORT_AND_TEST_P(DotTest, TestAllDot);
+HWY_EXPORT_AND_TEST_P(DotTest, TestAllDotF32BF16);
+HWY_EXPORT_AND_TEST_P(DotTest, TestAllDotBF16);
+HWY_EXPORT_AND_TEST_P(DotTest, TestAllDotI16);
+HWY_AFTER_TEST();
+} // namespace
+} // namespace hwy
+HWY_TEST_MAIN();
+#endif // HWY_ONCE
diff --git a/third_party/highway/hwy/contrib/image/image.cc b/third_party/highway/hwy/contrib/image/image.cc
new file mode 100644
index 0000000000..11cf019ee6
--- /dev/null
+++ b/third_party/highway/hwy/contrib/image/image.cc
@@ -0,0 +1,144 @@
+// Copyright 2020 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "third_party/highway/hwy/contrib/image/image.h"
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <algorithm> // std::swap
+
+#include "third_party/highway/hwy/aligned_allocator.h"
+#include "third_party/highway/hwy/base.h"
+#include "third_party/highway/hwy/per_target.h"
+
+namespace hwy {
+
+size_t ImageBase::VectorSize() {
+  // Do not cache the result - this must return the current value, which may
+  // differ from the first call's result if targets were changed via
+  // DisableTargets!
+ return VectorBytes();
+}
+
+size_t ImageBase::BytesPerRow(const size_t xsize, const size_t sizeof_t) {
+ const size_t vec_size = VectorSize();
+ // Check for integer overflow in xsize * sizeof_t.
+ if (HWY_LIKELY(xsize != 0) && sizeof_t > SIZE_MAX / xsize) {
+ HWY_ABORT("ImageBase::BytesPerRow overflow: xsize=%zu, sizeof_t=%zu",
+ xsize, sizeof_t);
+ }
+ size_t valid_bytes = xsize * sizeof_t;
+
+ // Allow unaligned accesses starting at the last valid value - this may raise
+ // msan errors unless the user calls InitializePaddingForUnalignedAccesses.
+ // Skip for the scalar case because no extra lanes will be loaded.
+ if (vec_size != 1) {
+ HWY_DASSERT(vec_size >= sizeof_t);
+ valid_bytes += vec_size - sizeof_t;
+ }
+
+ // Round up to vector and cache line size.
+ const size_t align = HWY_MAX(vec_size, HWY_ALIGNMENT);
+ size_t bytes_per_row = RoundUpTo(valid_bytes, align);
+
+ // During the lengthy window before writes are committed to memory, CPUs
+ // guard against read after write hazards by checking the address, but
+ // only the lower 11 bits. We avoid a false dependency between writes to
+ // consecutive rows by ensuring their sizes are not multiples of 2 KiB.
+  // This padding also prevents the same problem between the planes of an
+  // Image3.
+ if (bytes_per_row % HWY_ALIGNMENT == 0) {
+ bytes_per_row += align;
+ }
+
+ HWY_DASSERT(bytes_per_row % align == 0);
+ return bytes_per_row;
+}
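+
+// Worked example (illustrative; assumes a 32-byte vector and HWY_ALIGNMENT ==
+// 64): xsize = 100 floats (sizeof_t = 4) gives valid_bytes = 400 + 28 = 428,
+// RoundUpTo(428, 64) = 448, and because 448 % 64 == 0, the final
+// bytes_per_row is 448 + 64 = 512.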
+
+ImageBase::ImageBase(const size_t xsize, const size_t ysize,
+ const size_t sizeof_t)
+ : xsize_(static_cast<uint32_t>(xsize)),
+ ysize_(static_cast<uint32_t>(ysize)),
+ bytes_(nullptr, AlignedFreer(&AlignedFreer::DoNothing, nullptr)) {
+ HWY_ASSERT(sizeof_t == 1 || sizeof_t == 2 || sizeof_t == 4 || sizeof_t == 8);
+ // Validate dimensions fit in uint32_t to prevent silent truncation.
+ HWY_ASSERT(xsize <= UINT32_MAX);
+ HWY_ASSERT(ysize <= UINT32_MAX);
+
+ bytes_per_row_ = 0;
+ // Dimensions can be zero, e.g. for lazily-allocated images. Only allocate
+ // if nonzero, because "zero" bytes still have padding/bookkeeping overhead.
+ if (xsize != 0 && ysize != 0) {
+ bytes_per_row_ = BytesPerRow(xsize, sizeof_t);
+ // Check for integer overflow in allocation size.
+ if (bytes_per_row_ > SIZE_MAX / ysize) {
+ HWY_ABORT("ImageBase: allocation overflow (%zu * %zu exceeds size_t)",
+ bytes_per_row_, ysize);
+ }
+ bytes_ = AllocateAligned<uint8_t>(bytes_per_row_ * ysize);
+ HWY_ASSERT(bytes_.get() != nullptr);
+ InitializePadding(sizeof_t, Padding::kRoundUp);
+ }
+}
+
+ImageBase::ImageBase(const size_t xsize, const size_t ysize,
+ const size_t bytes_per_row, void* const aligned)
+ : xsize_(static_cast<uint32_t>(xsize)),
+ ysize_(static_cast<uint32_t>(ysize)),
+ bytes_per_row_(bytes_per_row),
+ bytes_(static_cast<uint8_t*>(aligned),
+ AlignedFreer(&AlignedFreer::DoNothing, nullptr)) {
+ const size_t vec_size = VectorSize();
+ HWY_ASSERT(bytes_per_row % vec_size == 0);
+ HWY_ASSERT(reinterpret_cast<uintptr_t>(aligned) % vec_size == 0);
+}
+
+void ImageBase::InitializePadding(const size_t sizeof_t, Padding padding) {
+#if HWY_IS_MSAN || HWY_IDE
+ if (xsize_ == 0 || ysize_ == 0) return;
+
+ const size_t vec_size = VectorSize(); // Bytes, independent of sizeof_t!
+ if (vec_size == 1) return; // Scalar mode: no padding needed
+
+ const size_t valid_size = xsize_ * sizeof_t;
+ const size_t initialize_size = padding == Padding::kRoundUp
+ ? RoundUpTo(valid_size, vec_size)
+ : valid_size + vec_size - sizeof_t;
+ if (valid_size == initialize_size) return;
+
+ for (size_t y = 0; y < ysize_; ++y) {
+ uint8_t* HWY_RESTRICT row = static_cast<uint8_t*>(VoidRow(y));
+#if defined(__clang__) && (__clang_major__ <= 6)
+ // There's a bug in msan in clang-6 when handling AVX2 operations. This
+ // workaround allows tests to pass on msan, although it is slower and
+ // prevents msan warnings from uninitialized images.
+ hwy::ZeroBytes(row, initialize_size);
+#else
+ hwy::ZeroBytes(row + valid_size, initialize_size - valid_size);
+#endif // clang6
+ }
+#else
+ (void)sizeof_t;
+ (void)padding;
+#endif // HWY_IS_MSAN
+}
+
+void ImageBase::Swap(ImageBase& other) {
+ std::swap(xsize_, other.xsize_);
+ std::swap(ysize_, other.ysize_);
+ std::swap(bytes_per_row_, other.bytes_per_row_);
+ std::swap(bytes_, other.bytes_);
+}
+
+} // namespace hwy
diff --git a/third_party/highway/hwy/contrib/image/image_test.cc b/third_party/highway/hwy/contrib/image/image_test.cc
new file mode 100644
index 0000000000..5a216a328b
--- /dev/null
+++ b/third_party/highway/hwy/contrib/image/image_test.cc
@@ -0,0 +1,153 @@
+// Copyright (c) the JPEG XL Project
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "third_party/highway/hwy/contrib/image/image.h"
+
+#include <stddef.h>
+#include <stdio.h>
+#include <stdlib.h>
+
+#include <random>
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "hwy/contrib/image/image_test.cc"
+#include "third_party/highway/hwy/foreach_target.h" // IWYU pragma: keep
+#include "third_party/highway/hwy/highway.h"
+#include "third_party/highway/hwy/tests/test_util-inl.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+namespace {
+
+// Ensure we can always write full aligned vectors.
+struct TestAlignedT {
+ template <typename T>
+ void operator()(T /*unused*/) const {
+ std::mt19937 rng(129);
+ std::uniform_int_distribution<int> dist(0, 16);
+ const ScalableTag<T> d;
+
+ for (size_t ysize = 1; ysize < 4; ++ysize) {
+ for (size_t xsize = 1; xsize < 64; ++xsize) {
+ Image<T> img(xsize, ysize);
+
+ for (size_t y = 0; y < ysize; ++y) {
+ T* HWY_RESTRICT row = img.MutableRow(y);
+ for (size_t x = 0; x < xsize; x += Lanes(d)) {
+ const auto values = Iota(d, dist(rng));
+ Store(values, d, row + x);
+ }
+ }
+
+ // Sanity check to prevent optimizing out the writes
+ const auto x = std::uniform_int_distribution<size_t>(0, xsize - 1)(rng);
+ const auto y = std::uniform_int_distribution<size_t>(0, ysize - 1)(rng);
+ HWY_ASSERT(img.ConstRow(y)[x] < 16 + Lanes(d));
+ }
+ }
+ }
+};
+
+void TestAligned() { ForUnsignedTypes(TestAlignedT()); }
+
+// Ensure we can write an unaligned vector starting at the last valid value.
+struct TestUnalignedT {
+ template <typename T>
+ void operator()(T /*unused*/) const {
+ std::mt19937 rng(129);
+ std::uniform_int_distribution<int> dist(0, 3);
+ const ScalableTag<T> d;
+
+ for (size_t ysize = 1; ysize < 4; ++ysize) {
+ for (size_t xsize = 1; xsize < 128; ++xsize) {
+ Image<T> img(xsize, ysize);
+ img.InitializePaddingForUnalignedAccesses();
+
+// This test reads padding, which is only valid if it was initialized,
+// and that only happens in MSAN builds.
+#if HWY_IS_MSAN || HWY_IDE
+ // Initialize only the valid samples
+ for (size_t y = 0; y < ysize; ++y) {
+ T* HWY_RESTRICT row = img.MutableRow(y);
+ for (size_t x = 0; x < xsize; ++x) {
+ row[x] = ConvertScalarTo<T>(1u << dist(rng));
+ }
+ }
+
+ // Read padding bits
+ auto accum = Zero(d);
+ for (size_t y = 0; y < ysize; ++y) {
+ T* HWY_RESTRICT row = img.MutableRow(y);
+ for (size_t x = 0; x < xsize; ++x) {
+ accum = Or(accum, LoadU(d, row + x));
+ }
+ }
+
+ // Ensure padding was zero
+ const size_t N = Lanes(d);
+ auto lanes = AllocateAligned<T>(N);
+ HWY_ASSERT(lanes);
+ Store(accum, d, lanes.get());
+ for (size_t i = 0; i < N; ++i) {
+ HWY_ASSERT(lanes[i] < 16);
+ }
+#else // Check that writing padding does not overwrite valid samples
+ // Initialize only the valid samples
+ for (size_t y = 0; y < ysize; ++y) {
+ T* HWY_RESTRICT row = img.MutableRow(y);
+ for (size_t x = 0; x < xsize; ++x) {
+ row[x] = ConvertScalarTo<T>(x);
+ }
+ }
+
+ // Zero padding and rightmost sample
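+      // (the unaligned full-vector store starting at the last valid sample
+      // stays in bounds because BytesPerRow reserves vec_size - sizeof_t
+      // extra bytes per row)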
+ for (size_t y = 0; y < ysize; ++y) {
+ T* HWY_RESTRICT row = img.MutableRow(y);
+ StoreU(Zero(d), d, row + xsize - 1);
+ }
+
+ // Ensure no samples except the rightmost were overwritten
+ for (size_t y = 0; y < ysize; ++y) {
+ T* HWY_RESTRICT row = img.MutableRow(y);
+ for (size_t x = 0; x < xsize - 1; ++x) {
+ HWY_ASSERT_EQ(ConvertScalarTo<T>(x), row[x]);
+ }
+ }
+#endif
+ }
+ }
+ }
+};
+
+void TestUnaligned() { ForUnsignedTypes(TestUnalignedT()); }
+
+} // namespace
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+} // namespace HWY_NAMESPACE
+} // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace hwy {
+namespace {
+HWY_BEFORE_TEST(ImageTest);
+HWY_EXPORT_AND_TEST_P(ImageTest, TestAligned);
+HWY_EXPORT_AND_TEST_P(ImageTest, TestUnaligned);
+HWY_AFTER_TEST();
+} // namespace
+} // namespace hwy
+HWY_TEST_MAIN();
+#endif // HWY_ONCE
diff --git a/third_party/highway/hwy/contrib/math/fast_math-inl.h b/third_party/highway/hwy/contrib/math/fast_math-inl.h
new file mode 100644
index 0000000000..8d78403488
--- /dev/null
+++ b/third_party/highway/hwy/contrib/math/fast_math-inl.h
@@ -0,0 +1,1694 @@
+// Copyright 2026 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Include guard (still compiled once per target)
+#if defined(HIGHWAY_HWY_CONTRIB_MATH_FAST_MATH_INL_H_) == \
+ defined(HWY_TARGET_TOGGLE) // NOLINT
+#ifdef HIGHWAY_HWY_CONTRIB_MATH_FAST_MATH_INL_H_
+#undef HIGHWAY_HWY_CONTRIB_MATH_FAST_MATH_INL_H_
+#else
+#define HIGHWAY_HWY_CONTRIB_MATH_FAST_MATH_INL_H_
+#endif
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include "third_party/highway/hwy/highway.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+
+namespace impl {
+
+// Port of reduce_angle_tan_SIMD
+template <class D, class V = VFromD<D>>
+HWY_INLINE void ReduceAngleTan(D d, V ang, V& x_red, V& sign) {
+ using T = TFromD<D>;
+ const auto pi = Set(d, static_cast<T>(3.14159265358979323846));
+ const auto zero = Set(d, static_cast<T>(0.0));
+ const auto one = Set(d, static_cast<T>(1.0));
+ const auto minus_one = Set(d, static_cast<T>(-1.0));
+
+ const auto inv_pi = Set(d, static_cast<T>(0.31830988618379067153777));
+
+ // Modulo pi
+ auto quotient = Mul(ang, inv_pi);
+ quotient = Round(quotient);
+ auto ang_mod = NegMulAdd(quotient, pi, ang);
+
+ // Determine sign
+ auto mask_neg = Lt(ang_mod, zero);
+ sign = IfThenElse(mask_neg, minus_one, one);
+
+ // Absolute value
+ x_red = Abs(ang_mod);
+}
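+
+// Example: ang = 2.0 rad reduces to quotient = round(2.0 / pi) = 1, so
+// ang_mod = 2.0 - pi ~= -1.1416; hence sign = -1 and x_red ~= 1.1416, and
+// tan(2.0) = -tan(1.1416) by the period-pi symmetry of tan.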
+
+// Range reduction and exponent extraction for logarithm functions.
+// Normalizes x to y in [0.707, 1.414] and extracts the exponent as a float in
+// 'exp'. If kHandleSubnormals is true, scales subnormal inputs to prevent
+// underflow.
+template <bool kHandleSubnormals = true, class D, class V>
+HWY_INLINE void FastLogRangeReduction(D d, V x, V& y, V& exp) {
+ using T = TFromD<D>;
+ const RebindToSigned<D> di;
+ const RebindToUnsigned<D> du;
+ using TI = TFromD<decltype(di)>;
+ using VI = decltype(Zero(di));
+
+ constexpr bool kIsF32 = (sizeof(T) == 4);
+
+ const VI kExpMagicDiff = Set(
+ di, kIsF32
+ ? static_cast<TI>(0x3F800000L - 0x3F3504F3L)
+ : static_cast<TI>(0x3FF0000000000000LL - 0x3FE6A09E00000000LL));
+
+ MFromD<D> is_denormal;
+ if constexpr (kHandleSubnormals) {
+ const V kMinNormal =
+ Set(d, kIsF32 ? static_cast<T>(1.175494351e-38f)
+ : static_cast<T>(2.2250738585072014e-308));
+ const V kScale = Set(d, kIsF32 ? static_cast<T>(3.355443200e+7f)
+ : static_cast<T>(1.8014398509481984e+16));
+ is_denormal = Lt(x, kMinNormal);
+ x = MaskedMulOr(x, is_denormal, x, kScale);
+ } else {
+ (void)is_denormal;
+ }
+
+ auto exp_bits = Add(BitCast(di, x), kExpMagicDiff);
+
+ constexpr int kMantissaShift = kIsF32 ? 23 : 52;
+ const auto kBias = Set(di, kIsF32 ? 0x7F : 0x3FF);
+ const auto exp_int = Sub(
+ BitCast(di, ShiftRight<kMantissaShift>(BitCast(du, exp_bits))), kBias);
+ exp = ConvertTo(d, exp_int);
+
+ if constexpr (kHandleSubnormals) {
+ const V kExpScaleFloat =
+ Set(d, kIsF32 ? static_cast<T>(-25.0) : static_cast<T>(-54.0));
+ exp = MaskedAddOr(exp, is_denormal, exp, kExpScaleFloat);
+ }
+
+ const VI exp_int_shifted = ShiftLeft<kMantissaShift>(exp_int);
+ const VI y_bits = Sub(BitCast(di, x), exp_int_shifted);
+ y = BitCast(d, y_bits);
+}
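+
+// Example: x = 6.0f (bits 0x40C00000) yields exp = 3.0 and y = 0.75, i.e.
+// 6.0 = 0.75 * 2^3: the raw mantissa 1.5 lies above the sqrt(0.5) boundary
+// encoded in kExpMagicDiff, so the exponent is rounded up to 3.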
+
+} // namespace impl
+
+namespace impl {
+
+template <class T>
+struct FastExpImpl {};
+
+template <>
+struct FastExpImpl<float> {
+ // Rounds float toward zero and returns as int32_t.
+ template <class D, class V = VFromD<D>, HWY_IF_F32_D(D)>
+ HWY_INLINE Vec<Rebind<int32_t, D>> ToInt32(D /*unused*/, V x) {
+ return ConvertInRangeTo(Rebind<int32_t, D>(), x);
+ }
+
+ // Computes 2^x, where x is an integer.
+ template <class D, class VI32 = Vec<Rebind<int32_t, D>>, HWY_IF_F32_D(D)>
+ HWY_INLINE Vec<D> Pow2I(D d, VI32 x) {
+ const Rebind<int32_t, D> di32;
+ const VI32 kOffset = Set(di32, 0x7F);
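+    // Example: x = 3 yields (3 + 127) << 23 == 0x41000000, the bit pattern
+    // of 2^3 = 8.0f.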
+ return BitCast(d, ShiftLeft<23>(Add(x, kOffset)));
+ }
+
+ // Sets the exponent of 'x' to 2^e.
+ template <class D, class V = VFromD<D>, class VI32 = Vec<Rebind<int32_t, D>>,
+ HWY_IF_F32_D(D)>
+ HWY_INLINE V LoadExpShortRange(D d, V x, VI32 e) {
+ const VI32 y = ShiftRight<1>(e);
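+    // Splitting e into two halves keeps each Pow2I argument within the
+    // exponent range; e.g. e = 200 becomes x * 2^100 * 2^100, even though
+    // 2^200 alone is not representable in float32.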
+ return Mul(Mul(x, Pow2I(d, y)), Pow2I(d, Sub(e, y)));
+ }
+
+ template <class D, class V = VFromD<D>, class VI32 = Vec<Rebind<int32_t, D>>,
+ HWY_IF_F32_D(D)>
+ HWY_INLINE V ExpReduce(D d, V x, VI32 q) {
+ // kMinusLn2 ~= -ln(2)
+ const V kMinusLn2 = Set(d, -0.69314718056f);
+
+ // Extended precision modular arithmetic.
+ const V qf = ConvertTo(d, q);
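+    // Example: x = 10 gives q = round(10 / ln(2)) = 14 and a reduced argument
+    // of 10 - 14 * ln(2) ~= 0.296, within [-ln(2)/2, +ln(2)/2].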
+ return MulAdd(qf, kMinusLn2, x);
+ }
+
+ template <class D, class V = VFromD<D>, class VI32 = Vec<Rebind<int32_t, D>>,
+ HWY_IF_F32_D(D)>
+ HWY_INLINE V Exp2Reduce(D d, V x, VI32 q) {
+ const V qf = ConvertTo(d, q);
+ return Sub(x, qf);
+ }
+};
+
+#if HWY_HAVE_FLOAT64 && HWY_HAVE_INTEGER64
+template <>
+struct FastExpImpl<double> {
+ // Rounds double toward zero and returns as int32_t.
+ template <class D, class V = VFromD<D>, HWY_IF_F64_D(D)>
+ HWY_INLINE Vec<Rebind<int32_t, D>> ToInt32(D /*unused*/, V x) {
+ return DemoteInRangeTo(Rebind<int32_t, D>(), x);
+ }
+
+ // Computes 2^x, where x is an integer.
+ template <class D, class VI32 = Vec<Rebind<int32_t, D>>, HWY_IF_F64_D(D)>
+ HWY_INLINE Vec<D> Pow2I(D d, VI32 x) {
+ const Rebind<int32_t, D> di32;
+ const Rebind<int64_t, D> di64;
+ const VI32 kOffset = Set(di32, 0x3FF);
+ return BitCast(d, ShiftLeft<52>(PromoteTo(di64, Add(x, kOffset))));
+ }
+
+ // Sets the exponent of 'x' to 2^e.
+ template <class D, class V = VFromD<D>, class VI32 = Vec<Rebind<int32_t, D>>,
+ HWY_IF_F64_D(D)>
+ HWY_INLINE V LoadExpShortRange(D d, V x, VI32 e) {
+ const VI32 y = ShiftRight<1>(e);
+ return Mul(Mul(x, Pow2I(d, y)), Pow2I(d, Sub(e, y)));
+ }
+
+ template <class D, class V = VFromD<D>, class VI32 = Vec<Rebind<int32_t, D>>,
+ HWY_IF_F64_D(D)>
+ HWY_INLINE V ExpReduce(D d, V x, VI32 q) {
+ // kMinusLn2 ~= -ln(2)
+ const V kMinusLn2 = Set(d, -0.6931471805599453);
+
+ // Extended precision modular arithmetic.
+ const V qf = PromoteTo(d, q);
+ return MulAdd(qf, kMinusLn2, x);
+ }
+
+ template <class D, class V = VFromD<D>, class VI32 = Vec<Rebind<int32_t, D>>,
+ HWY_IF_F64_D(D)>
+ HWY_INLINE V Exp2Reduce(D d, V x, VI32 q) {
+ const V qf = PromoteTo(d, q);
+ return Sub(x, qf);
+ }
+};
+#endif
+
+} // namespace impl
+
+/**
+ * Fast approximation of tan(x).
+ *
+ * Valid Lane Types: float32, float64
+ * Max Relative Error: < 0.35% for angles whose reduced value falls within
+ *                     [-89.99, +89.99] degrees (float32) or
+ *                     [-89.9999999, +89.9999999] degrees (float64).
+ * Valid Range: float32: [-20, +20] rad
+ *              float64: [-39000, +39000] rad
+ *
+ * Note: Inputs extremely close to asymptotes may result in
+ * a sign flip due to precision limits.
+ *
+ * @return tangent of 'x'
+ */
+template <class D, class V>
+HWY_INLINE V FastTan(D d, V x) {
+ using T = TFromD<D>;
+
+ // Reduction
+ V x_red, sign;
+ impl::ReduceAngleTan(d, x, x_red, sign);
+
+ V b, c, d_val;
+
+ if constexpr (CanLookup8(d)) {
+ // --- Table Lookup ---
+ const auto scale = Set(d, static_cast<T>(3.8197186342));
+ auto idx_float = Floor(Mul(x_red, scale));
+
+ // Convert to Integer Vector (Signed)
+ auto idx_int = ConvertTo(RebindToSigned<D>(), idx_float);
+
+ HWY_ALIGN static constexpr T arr_b[8] = {
+ static_cast<T>(0),
+ static_cast<T>(0.0174532925199432955),
+ static_cast<T>(0.133808575986231942),
+ static_cast<T>(0.378736447682769484),
+ static_cast<T>(1.29590696960578966),
+ static_cast<T>(9.45968454580926554),
+ static_cast<T>(9.45968454580926554),
+ static_cast<T>(9.45968454580926554)};
+
+ HWY_ALIGN static constexpr T arr_c[8] = {
+ static_cast<T>(-0.0909090909092633431),
+ static_cast<T>(-0.400000000000000022),
+ static_cast<T>(-0.83333333333333337),
+ static_cast<T>(-1.29999999999999982),
+ static_cast<T>(-2.5),
+ static_cast<T>(-10.9999999999791349),
+ static_cast<T>(-10.9999999999791349),
+ static_cast<T>(-10.9999999999791349)};
+
+ HWY_ALIGN static constexpr T arr_d[8] = {
+ static_cast<T>(1.00277098842046231),
+ static_cast<T>(1.14668131856027444),
+ static_cast<T>(1.57370520888155374),
+ static_cast<T>(2.18515222349690053),
+ static_cast<T>(3.97062404828709958),
+ static_cast<T>(17.2787595947438639),
+ static_cast<T>(17.2787595947438639),
+ static_cast<T>(17.2787595947438639)};
+
+    // Lookup8 is available whenever HWY_MIN_BYTES / sizeof(T) >= 4, so this
+    // branch covers all cases that reach the top-level if block in FastTan.
+ b = Lookup8(d, arr_b, idx_int);
+ c = Lookup8(d, arr_c, idx_int);
+ d_val = Lookup8(d, arr_d, idx_int);
+ } else {
+ // --- FALLBACK PATH: Blend Chain ---
+ if constexpr (HWY_REGISTERS >= 32) {
+ // Split into two parallel chains to reduce dependency latency.
+ const auto t0 = Set(d, static_cast<T>(0.2617993877995256));
+ const auto t1 = Set(d, static_cast<T>(0.5235987755990512));
+ const auto t2 = Set(d, static_cast<T>(0.7853981633985767));
+ const auto t3 = Set(d, static_cast<T>(1.0471975511981024));
+ const auto t4 = Set(d, static_cast<T>(1.3089969389976279));
+
+ // -- Chain 1: Indices 0 to 2 (Evaluated starting from t1 down to t0)
+ auto b_low = Set(d, static_cast<T>(0.133808575986231942)); // idx 2
+ auto c_low = Set(d, static_cast<T>(-0.83333333333333337));
+ auto d_low = Set(d, static_cast<T>(1.57370520888155374));
+
+ auto mask = Lt(x_red, t1);
+ b_low = IfThenElse(mask, Set(d, static_cast<T>(0.0174532925199432955)),
+ b_low);
+ c_low = IfThenElse(mask, Set(d, static_cast<T>(-0.400000000000000022)),
+ c_low);
+ d_low =
+ IfThenElse(mask, Set(d, static_cast<T>(1.14668131856027444)), d_low);
+
+ mask = Lt(x_red, t0);
+ b_low = IfThenZeroElse(mask, b_low);
+ c_low = IfThenElse(mask, Set(d, static_cast<T>(-0.0909090909092633431)),
+ c_low);
+ d_low =
+ IfThenElse(mask, Set(d, static_cast<T>(1.00277098842046231)), d_low);
+
+ // -- Chain 2: Indices 3 to 5 (Evaluated starting from t4 down to t3)
+ auto b_high = Set(d, static_cast<T>(9.45968454580926554)); // idx 5
+ auto c_high = Set(d, static_cast<T>(-10.9999999999791349));
+ auto d_high = Set(d, static_cast<T>(17.2787595947438639));
+
+ mask = Lt(x_red, t4);
+ b_high =
+ IfThenElse(mask, Set(d, static_cast<T>(1.29590696960578966)), b_high);
+ c_high = IfThenElse(mask, Set(d, static_cast<T>(-2.5)), c_high);
+ d_high =
+ IfThenElse(mask, Set(d, static_cast<T>(3.97062404828709958)), d_high);
+
+ mask = Lt(x_red, t3);
+ b_high = IfThenElse(mask, Set(d, static_cast<T>(0.378736447682769484)),
+ b_high);
+ c_high = IfThenElse(mask, Set(d, static_cast<T>(-1.29999999999999982)),
+ c_high);
+ d_high =
+ IfThenElse(mask, Set(d, static_cast<T>(2.18515222349690053)), d_high);
+
+ // -- Merge the two chains
+ auto merge_mask = Lt(x_red, t2);
+ b = IfThenElse(merge_mask, b_low, b_high);
+ c = IfThenElse(merge_mask, c_low, c_high);
+ d_val = IfThenElse(merge_mask, d_low, d_high);
+ } else {
+ b = Set(d, static_cast<T>(9.45968454580926554));
+ c = Set(d, static_cast<T>(-10.9999999999791349));
+ d_val = Set(d, static_cast<T>(17.2787595947438639));
+
+ auto mask = Lt(x_red, Set(d, static_cast<T>(1.3089969389976279)));
+ b = IfThenElse(mask, Set(d, static_cast<T>(1.29590696960578966)), b);
+ c = IfThenElse(mask, Set(d, static_cast<T>(-2.5)), c);
+ d_val =
+ IfThenElse(mask, Set(d, static_cast<T>(3.97062404828709958)), d_val);
+
+ mask = Lt(x_red, Set(d, static_cast<T>(1.0471975511981024)));
+ b = IfThenElse(mask, Set(d, static_cast<T>(0.378736447682769484)), b);
+ c = IfThenElse(mask, Set(d, static_cast<T>(-1.29999999999999982)), c);
+ d_val =
+ IfThenElse(mask, Set(d, static_cast<T>(2.18515222349690053)), d_val);
+
+ mask = Lt(x_red, Set(d, static_cast<T>(0.7853981633985767)));
+ b = IfThenElse(mask, Set(d, static_cast<T>(0.133808575986231942)), b);
+ c = IfThenElse(mask, Set(d, static_cast<T>(-0.83333333333333337)), c);
+ d_val =
+ IfThenElse(mask, Set(d, static_cast<T>(1.57370520888155374)), d_val);
+
+ mask = Lt(x_red, Set(d, static_cast<T>(0.5235987755990512)));
+ b = IfThenElse(mask, Set(d, static_cast<T>(0.0174532925199432955)), b);
+ c = IfThenElse(mask, Set(d, static_cast<T>(-0.400000000000000022)), c);
+ d_val =
+ IfThenElse(mask, Set(d, static_cast<T>(1.14668131856027444)), d_val);
+
+ mask = Lt(x_red, Set(d, static_cast<T>(0.2617993877995256)));
+ b = IfThenZeroElse(mask, b);
+ c = IfThenElse(mask, Set(d, static_cast<T>(-0.0909090909092633431)), c);
+ d_val =
+ IfThenElse(mask, Set(d, static_cast<T>(1.00277098842046231)), d_val);
+ }
+ }
+
+ // Math: y=(x + b)/(cx + d)
+ auto num = Add(x_red, b);
+ auto den = MulAdd(c, x_red, d_val);
+
+ // Guard against denominator underflow/sign-flip near singularities
+ T epsilon_val;
+ if constexpr (sizeof(T) == 8) {
+ epsilon_val = static_cast<T>(1e-15);
+ } else {
+ epsilon_val = static_cast<T>(1e-6);
+ }
+ const auto kMinDenom = Set(d, epsilon_val);
+  // We use Abs() because on the reduced interval [0, pi/2) the tangent is
+  // non-negative, so the denominator must be positive. If the rational
+  // approximation produces a negative denominator (overshoot), that is an
+  // error, and we force it to be positive.
+ den = Max(Abs(den), kMinDenom);
+
+ auto result = Div(num, den);
+
+ // Apply Sign
+ return CopySign(result, sign);
+}
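+
+// Usage sketch (illustrative only; assumes `in` and `out` hold `count` floats
+// and `count` is a multiple of Lanes(d)):
+//   const ScalableTag<float> d;
+//   for (size_t i = 0; i < count; i += Lanes(d)) {
+//     StoreU(FastTan(d, LoadU(d, in + i)), d, out + i);
+//   }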
+
+/**
+ * Fast approximation of atan(x).
+ *
+ * Valid Lane Types: float32, float64
+ * Max Relative Error: 0.0034%
+ * Average Relative Error: 0.0002% for float32
+ *                         0.0002% for float64
+ * Valid Range: float32: [-1e35, +1e35]
+ * float64: [-1e305, +1e305]
+ *
+ * @return arctangent of 'x'
+ */
+// If kAssumePositive is true, inputs are assumed to be non-negative.
+template <bool kAssumePositive = false, class D, class V>
+HWY_INLINE V FastAtan(D d, V val) {
+ using T = TFromD<D>;
+
+ // Abs(val) and preserve sign for later (if needed)
+ V y;
+ if constexpr (kAssumePositive) {
+ y = val;
+ } else {
+ y = Abs(val);
+ }
+
+ const V kOne = Set(d, static_cast<T>(1.0));
+ const auto gt1_mask = Gt(y, kOne);
+ // Domain reduction: map [1, inf) to [0, 1]
+ const V mapped_y = MaskedDivOr(y, gt1_mask, kOne, y);
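+  // For y > 1 this uses atan(y) = pi/2 - atan(1/y); the pi/2 subtraction is
+  // applied after the polynomial below.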
+
+ // Degree 4 polynomial for atan(x) / x over [0, 1]
+ const V c0 = Set(d, static_cast<T>(0.9999653683169244));
+ const V c1 = Set(d, static_cast<T>(-0.3315525587266785));
+ const V c2 = Set(d, static_cast<T>(0.1844770291758270));
+ const V c3 = Set(d, static_cast<T>(-0.0907475543745560));
+ const V c4 = Set(d, static_cast<T>(0.0232748721030191));
+
+ const V z = Mul(mapped_y, mapped_y);
+ const V z2 = Mul(z, z);
+
+ const V p01 = MulAdd(c1, z, c0);
+ const V p23 = MulAdd(c3, z, c2);
+ const V p234 = MulAdd(z2, c4, p23);
+ const V p = MulAdd(z2, p234, p01);
+
+ const V poly = Mul(mapped_y, p);
+
+ const V kPiOverTwo = Set(d, static_cast<T>(1.57079632679489661923));
+ auto result = MaskedSubOr(poly, gt1_mask, kPiOverTwo, poly);
+
+ if constexpr (kAssumePositive) {
+ return result;
+ } else {
+ return CopySign(result, val);
+ }
+}
+
+/**
+ * Fast approximation of atan2(y, x).
+ *
+ * Valid Lane Types: float32, float64
+ * Valid Range: As long as y/x is in Valid Range for FastAtan()
+ * Correctly handles negative zero, infinities, and NaN.
+ * @return atan2 of 'y', 'x'
+ */
+template <class D, class V>
+HWY_INLINE V FastAtan2(const D d, V y, V x) {
+ using T = TFromD<D>;
+ using M = MFromD<D>;
+
+ const V kPi = Set(d, static_cast<T>(3.14159265358979323846264));
+ const V kPiOverTwo = Set(d, static_cast<T>(1.57079632679489661923));
+ const V kOne = Set(d, static_cast<T>(1.0));
+ const V k0 = Zero(d);
+
+ const V ax = Abs(x);
+ const V ay = Abs(y);
+
+ const V num = Min(ax, ay);
+ const V den = Max(ax, ay);
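+  // Taking num = Min and den = Max keeps num / den in [0, 1] regardless of
+  // which of |x|, |y| is larger; the ay > ax case is corrected below via
+  // pi/2 - poly.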
+
+ const M is_inf = IsInf(num);
+ V mapped_y = MaskedDivOr(k0, Ne(den, k0), num, den);
+ mapped_y = IfThenElse(is_inf, kOne, mapped_y);
+
+ // Degree 4 polynomial for atan(x) / x over [0, 1]
+ const V c0 = Set(d, static_cast<T>(0.9999653683169244));
+ const V c1 = Set(d, static_cast<T>(-0.3315525587266785));
+ const V c2 = Set(d, static_cast<T>(0.1844770291758270));
+ const V c3 = Set(d, static_cast<T>(-0.0907475543745560));
+ const V c4 = Set(d, static_cast<T>(0.0232748721030191));
+
+ const V z = Mul(mapped_y, mapped_y);
+ const V z2 = Mul(z, z);
+
+ const V p01 = MulAdd(c1, z, c0);
+ const V p23 = MulAdd(c3, z, c2);
+ const V p234 = MulAdd(z2, c4, p23);
+ const V p = MulAdd(z2, p234, p01);
+
+ const V poly = Mul(mapped_y, p);
+
+ const M ay_gt_ax = Gt(ay, ax);
+ V angle = MaskedSubOr(poly, ay_gt_ax, kPiOverTwo, poly);
+
+ const M x_neg = Lt(x, k0);
+ angle = MaskedSubOr(angle, x_neg, kPi, angle);
+
+ const M is_nan = IsEitherNaN(y, x);
+ return IfThenElse(is_nan, NaN(d), CopySign(angle, y));
+}
+
+namespace impl {
+
+// Computes the index vector required for Lookup8 when the
+// intervals are uneven. Runs either an adder tree or a sequential add chain
+// depending on the number of registers.
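+// The result equals the number of thresholds that y meets or exceeds, e.g.
+// thresholds[2] <= y < thresholds[3] yields index 3.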
+template <class D, class V>
+HWY_INLINE Vec<RebindToSigned<D>> ComputeIndices8Intervals(
+ D d, V y, const TFromD<D>* HWY_RESTRICT thresholds) {
+ using DI = RebindToSigned<D>;
+ auto idx_i = Zero(DI());
+ const auto one_i = Set(DI(), 1);
+
+ const auto t0 = Set(d, thresholds[0]);
+ const auto t1 = Set(d, thresholds[1]);
+ const auto t2 = Set(d, thresholds[2]);
+ const auto t3 = Set(d, thresholds[3]);
+ const auto t4 = Set(d, thresholds[4]);
+ const auto t5 = Set(d, thresholds[5]);
+ const auto t6 = Set(d, thresholds[6]);
+
+ const auto mask0 = RebindMask(DI(), Ge(y, t0));
+ const auto mask1 = RebindMask(DI(), Ge(y, t1));
+ const auto mask2 = RebindMask(DI(), Ge(y, t2));
+ const auto mask3 = RebindMask(DI(), Ge(y, t3));
+ const auto mask4 = RebindMask(DI(), Ge(y, t4));
+ const auto mask5 = RebindMask(DI(), Ge(y, t5));
+ const auto mask6 = RebindMask(DI(), Ge(y, t6));
+
+#ifdef HWY_NATIVE_MASK
+ if constexpr (HWY_REGISTERS >= 32) {
+ // Adder tree for native masks.
+ const auto sum0 = IfThenElseZero(mask0, one_i);
+ const auto sum01 = MaskedAddOr(sum0, mask1, sum0, one_i);
+
+ const auto sum2 = IfThenElseZero(mask2, one_i);
+ const auto sum23 = MaskedAddOr(sum2, mask3, sum2, one_i);
+
+ const auto sum4 = IfThenElseZero(mask4, one_i);
+ const auto sum45 = MaskedAddOr(sum4, mask5, sum4, one_i);
+
+ const auto sum6 = IfThenElseZero(mask6, one_i);
+
+ const auto sum03 = Add(sum01, sum23);
+ const auto sum46 = Add(sum45, sum6);
+
+ idx_i = Add(sum03, sum46);
+ } else {
+ // 2x unrolled sequential chain.
+ const auto sum0 = IfThenElseZero(mask0, one_i);
+ const auto sum02 = MaskedAddOr(sum0, mask2, sum0, one_i);
+ const auto sum024 = MaskedAddOr(sum02, mask4, sum02, one_i);
+ const auto sum0246 = MaskedAddOr(sum024, mask6, sum024, one_i);
+
+ const auto sum1 = IfThenElseZero(mask1, one_i);
+ const auto sum13 = MaskedAddOr(sum1, mask3, sum1, one_i);
+ const auto sum135 = MaskedAddOr(sum13, mask5, sum13, one_i);
+
+ idx_i = Add(sum0246, sum135);
+ }
+#else
+ (void)one_i;
+ if constexpr (HWY_REGISTERS >= 32) {
+    // Accumulate the masks' -1 values in a tree to reduce latency.
+ const auto m0 = VecFromMask(DI(), mask0);
+ const auto m1 = VecFromMask(DI(), mask1);
+ const auto m2 = VecFromMask(DI(), mask2);
+ const auto m3 = VecFromMask(DI(), mask3);
+ const auto m4 = VecFromMask(DI(), mask4);
+ const auto m5 = VecFromMask(DI(), mask5);
+ const auto m6 = VecFromMask(DI(), mask6);
+
+ const auto sum01 = Add(m0, m1);
+ const auto sum23 = Add(m2, m3);
+ const auto sum45 = Add(m4, m5);
+
+ const auto sum03 = Add(sum01, sum23);
+ const auto sum46 = Add(sum45, m6);
+
+ idx_i = Neg(Add(sum03, sum46));
+ } else {
+ // Subtract in a 2x unrolled chain
+ auto sum0246 = Sub(idx_i, VecFromMask(DI(), mask0));
+ sum0246 = Sub(sum0246, VecFromMask(DI(), mask2));
+ sum0246 = Sub(sum0246, VecFromMask(DI(), mask4));
+ sum0246 = Sub(sum0246, VecFromMask(DI(), mask6));
+
+ auto sum135 = Zero(DI());
+ sum135 = Sub(sum135, VecFromMask(DI(), mask1));
+ sum135 = Sub(sum135, VecFromMask(DI(), mask3));
+ sum135 = Sub(sum135, VecFromMask(DI(), mask5));
+
+ idx_i = Add(sum0246, sum135);
+ }
+#endif
+ return idx_i;
+}
+
+} // namespace impl
+
+/**
+ * Fast approximation of tanh(x).
+ *
+ * Valid Lane Types: float32, float64
+ * Max Relative Error: 0.0006% for float32, 0.0006% for float64
+ * Average Relative Error: 0.00002% for float32, 3e-6% for float64
+ * Max Relative Error for [-0.01, 0.01]: 0.00003%
+ * Average Relative Error for [-0.01, 0.01]: 3e-7%
+ * Valid Range: float32: [-1e35, +1e35]
+ * float64: [-1e305, +1e305]
+ *
+ * @return hyperbolic tangent of 'x'
+ */
+template <class D, class V>
+HWY_INLINE V FastTanh(D d, V val) {
+ using T = TFromD<D>;
+
+ // Abs(val) and preserve sign for later
+ auto y = Abs(val);
+
+ V a, b, c, d_val, e, f;
+
+ HWY_ALIGN static constexpr T thresholds[7] = {
+ static_cast<T>(0.168236118310606), static_cast<T>(0.365443754271396),
+ static_cast<T>(0.549306144334055), static_cast<T>(0.804718956217050),
+ static_cast<T>(1.203972804325936), static_cast<T>(2.969315202883957),
+ static_cast<T>(4.734657601441978)};
+
+ if constexpr (CanLookup8(d)) {
+ auto idx_i = impl::ComputeIndices8Intervals(d, y, thresholds);
+
+ HWY_ALIGN static constexpr T arr_a[8] = {
+ static_cast<T>(0.124683326807972),
+ static_cast<T>(0.0650303189120701),
+ static_cast<T>(-0.012865365312548),
+ static_cast<T>(-0.0600814996891072),
+ static_cast<T>(-0.0456234607880718),
+ static_cast<T>(0.00382424142943801),
+ static_cast<T>(0.000272471748022028),
+ static_cast<T>(8.15222218981581e-06)};
+
+ HWY_ALIGN static constexpr T arr_b[8] = {
+ static_cast<T>(0.00220499585798237),
+ static_cast<T>(0.0576751179885766),
+ static_cast<T>(0.197711481341899),
+ static_cast<T>(0.325772792901464),
+ static_cast<T>(0.256935827482807),
+ static_cast<T>(-0.0559644608292387),
+ static_cast<T>(-0.00594131995144914),
+ static_cast<T>(-0.000249501181979395)};
+
+ HWY_ALIGN static constexpr T arr_c[8] = {
+ static_cast<T>(-0.333553679129082), static_cast<T>(-0.355207242694923),
+ static_cast<T>(-0.457494048166093), static_cast<T>(-0.597530579050227),
+ static_cast<T>(-0.467928171646331), static_cast<T>(0.337223118234543),
+ static_cast<T>(0.0522965013918652), static_cast<T>(0.0030684314549725)};
+
+ HWY_ALIGN static constexpr T arr_d[8] = {
+ static_cast<T>(9.56515391952438e-06),
+ static_cast<T>(0.00439112349520911),
+ static_cast<T>(0.042321724042285),
+ static_cast<T>(0.119506151013929),
+ static_cast<T>(-0.00146222560953702),
+ static_cast<T>(-1.05430328521368),
+ static_cast<T>(-0.232902952602555),
+ static_cast<T>(-0.0189739119945283)};
+
+ HWY_ALIGN static constexpr T arr_e[8] = {
+ static_cast<T>(0.999999846647538), static_cast<T>(0.999545718850479),
+ static_cast<T>(0.992411248107215), static_cast<T>(0.970968585888465),
+ static_cast<T>(1.02705747414947), static_cast<T>(1.7260732085471),
+ static_cast<T>(0.526554911169526), static_cast<T>(0.0590636129554906)};
+
+ HWY_ALIGN static constexpr T arr_f[8] = {
+ static_cast<T>(4.72832130652986e-10),
+ static_cast<T>(1.91000262234958e-05),
+ static_cast<T>(0.000562934372196535),
+ static_cast<T>(0.00296459124064406),
+ static_cast<T>(-0.0073852515760983),
+ static_cast<T>(-0.195632701517054),
+ static_cast<T>(0.51454722991951),
+ static_cast<T>(0.92584756176511)};
+
+ a = Lookup8(d, arr_a, idx_i);
+ b = Lookup8(d, arr_b, idx_i);
+ c = Lookup8(d, arr_c, idx_i);
+ d_val = Lookup8(d, arr_d, idx_i);
+ e = Lookup8(d, arr_e, idx_i);
+ f = Lookup8(d, arr_f, idx_i);
+ } else {
+ const auto t0 = Set(d, thresholds[0]);
+ const auto t1 = Set(d, thresholds[1]);
+ const auto t2 = Set(d, thresholds[2]);
+ const auto t3 = Set(d, thresholds[3]);
+ const auto t4 = Set(d, thresholds[4]);
+ const auto t5 = Set(d, thresholds[5]);
+ const auto t6 = Set(d, thresholds[6]);
+ // --- FALLBACK PATH: Blend Chain ---
+ if constexpr (HWY_REGISTERS >= 32) {
+ // Split into two parallel chains to reduce dependency latency.
+
+ // -- Chain 1: Indices 0 to 3
+ auto a_low = Set(d, static_cast<T>(-0.0600814996891072)); // idx 3
+ auto b_low = Set(d, static_cast<T>(0.325772792901464));
+ auto c_low = Set(d, static_cast<T>(-0.597530579050227));
+ auto d_low = Set(d, static_cast<T>(0.119506151013929));
+ auto e_low = Set(d, static_cast<T>(0.970968585888465));
+ auto f_low = Set(d, static_cast<T>(0.00296459124064406));
+
+ auto mask = Lt(y, t2);
+ a_low =
+ IfThenElse(mask, Set(d, static_cast<T>(-0.012865365312548)), a_low);
+ b_low =
+ IfThenElse(mask, Set(d, static_cast<T>(0.197711481341899)), b_low);
+ c_low =
+ IfThenElse(mask, Set(d, static_cast<T>(-0.457494048166093)), c_low);
+ d_low =
+ IfThenElse(mask, Set(d, static_cast<T>(0.042321724042285)), d_low);
+ e_low =
+ IfThenElse(mask, Set(d, static_cast<T>(0.992411248107215)), e_low);
+ f_low =
+ IfThenElse(mask, Set(d, static_cast<T>(0.000562934372196535)), f_low);
+
+ mask = Lt(y, t1);
+ a_low =
+ IfThenElse(mask, Set(d, static_cast<T>(0.0650303189120701)), a_low);
+ b_low =
+ IfThenElse(mask, Set(d, static_cast<T>(0.0576751179885766)), b_low);
+ c_low =
+ IfThenElse(mask, Set(d, static_cast<T>(-0.355207242694923)), c_low);
+ d_low =
+ IfThenElse(mask, Set(d, static_cast<T>(0.00439112349520911)), d_low);
+ e_low =
+ IfThenElse(mask, Set(d, static_cast<T>(0.999545718850479)), e_low);
+ f_low =
+ IfThenElse(mask, Set(d, static_cast<T>(1.91000262234958e-05)), f_low);
+
+ mask = Lt(y, t0);
+ a_low =
+ IfThenElse(mask, Set(d, static_cast<T>(0.124683326807972)), a_low);
+ b_low =
+ IfThenElse(mask, Set(d, static_cast<T>(0.00220499585798237)), b_low);
+ c_low =
+ IfThenElse(mask, Set(d, static_cast<T>(-0.333553679129082)), c_low);
+ d_low =
+ IfThenElse(mask, Set(d, static_cast<T>(9.56515391952438e-06)), d_low);
+ e_low =
+ IfThenElse(mask, Set(d, static_cast<T>(0.999999846647538)), e_low);
+ f_low =
+ IfThenElse(mask, Set(d, static_cast<T>(4.72832130652986e-10)), f_low);
+
+ // -- Chain 2: Indices 4 to 7
+ auto a_high = Set(d, static_cast<T>(8.15222218981581e-06)); // idx 7
+ auto b_high = Set(d, static_cast<T>(-0.000249501181979395));
+ auto c_high = Set(d, static_cast<T>(0.0030684314549725));
+ auto d_high = Set(d, static_cast<T>(-0.0189739119945283));
+ auto e_high = Set(d, static_cast<T>(0.0590636129554906));
+ auto f_high = Set(d, static_cast<T>(0.92584756176511));
+
+ mask = Lt(y, t6);
+ a_high = IfThenElse(mask, Set(d, static_cast<T>(0.000272471748022028)),
+ a_high);
+ b_high = IfThenElse(mask, Set(d, static_cast<T>(-0.00594131995144914)),
+ b_high);
+ c_high =
+ IfThenElse(mask, Set(d, static_cast<T>(0.0522965013918652)), c_high);
+ d_high =
+ IfThenElse(mask, Set(d, static_cast<T>(-0.232902952602555)), d_high);
+ e_high =
+ IfThenElse(mask, Set(d, static_cast<T>(0.526554911169526)), e_high);
+ f_high =
+ IfThenElse(mask, Set(d, static_cast<T>(0.51454722991951)), f_high);
+
+ mask = Lt(y, t5);
+ a_high =
+ IfThenElse(mask, Set(d, static_cast<T>(0.00382424142943801)), a_high);
+ b_high =
+ IfThenElse(mask, Set(d, static_cast<T>(-0.0559644608292387)), b_high);
+ c_high =
+ IfThenElse(mask, Set(d, static_cast<T>(0.337223118234543)), c_high);
+ d_high =
+ IfThenElse(mask, Set(d, static_cast<T>(-1.05430328521368)), d_high);
+ e_high =
+ IfThenElse(mask, Set(d, static_cast<T>(1.7260732085471)), e_high);
+ f_high =
+ IfThenElse(mask, Set(d, static_cast<T>(-0.195632701517054)), f_high);
+
+ mask = Lt(y, t4);
+ a_high =
+ IfThenElse(mask, Set(d, static_cast<T>(-0.0456234607880718)), a_high);
+ b_high =
+ IfThenElse(mask, Set(d, static_cast<T>(0.256935827482807)), b_high);
+ c_high =
+ IfThenElse(mask, Set(d, static_cast<T>(-0.467928171646331)), c_high);
+ d_high = IfThenElse(mask, Set(d, static_cast<T>(-0.00146222560953702)),
+ d_high);
+ e_high =
+ IfThenElse(mask, Set(d, static_cast<T>(1.02705747414947)), e_high);
+ f_high =
+ IfThenElse(mask, Set(d, static_cast<T>(-0.0073852515760983)), f_high);
+
+ // Combine chains
+ mask = Lt(y, t3);
+ a = IfThenElse(mask, a_low, a_high);
+ b = IfThenElse(mask, b_low, b_high);
+ c = IfThenElse(mask, c_low, c_high);
+ d_val = IfThenElse(mask, d_low, d_high);
+ e = IfThenElse(mask, e_low, e_high);
+ f = IfThenElse(mask, f_low, f_high);
+ } else {
+ // Serial chain for lower register count
+ // Start with highest index (7)
+ a = Set(d, static_cast<T>(8.15222218981581e-06));
+ b = Set(d, static_cast<T>(-0.000249501181979395));
+ c = Set(d, static_cast<T>(0.0030684314549725));
+ d_val = Set(d, static_cast<T>(-0.0189739119945283));
+ e = Set(d, static_cast<T>(0.0590636129554906));
+ f = Set(d, static_cast<T>(0.92584756176511));
+
+ // If y < t6 (idx 6)
+ auto mask = Lt(y, t6);
+ a = IfThenElse(mask, Set(d, static_cast<T>(0.000272471748022028)), a);
+ b = IfThenElse(mask, Set(d, static_cast<T>(-0.00594131995144914)), b);
+ c = IfThenElse(mask, Set(d, static_cast<T>(0.0522965013918652)), c);
+ d_val =
+ IfThenElse(mask, Set(d, static_cast<T>(-0.232902952602555)), d_val);
+ e = IfThenElse(mask, Set(d, static_cast<T>(0.526554911169526)), e);
+ f = IfThenElse(mask, Set(d, static_cast<T>(0.51454722991951)), f);
+
+ // If y < t5 (idx 5)
+ mask = Lt(y, t5);
+ a = IfThenElse(mask, Set(d, static_cast<T>(0.00382424142943801)), a);
+ b = IfThenElse(mask, Set(d, static_cast<T>(-0.0559644608292387)), b);
+ c = IfThenElse(mask, Set(d, static_cast<T>(0.337223118234543)), c);
+ d_val =
+ IfThenElse(mask, Set(d, static_cast<T>(-1.05430328521368)), d_val);
+ e = IfThenElse(mask, Set(d, static_cast<T>(1.7260732085471)), e);
+ f = IfThenElse(mask, Set(d, static_cast<T>(-0.195632701517054)), f);
+
+ // If y < t4 (idx 4)
+ mask = Lt(y, t4);
+ a = IfThenElse(mask, Set(d, static_cast<T>(-0.0456234607880718)), a);
+ b = IfThenElse(mask, Set(d, static_cast<T>(0.256935827482807)), b);
+ c = IfThenElse(mask, Set(d, static_cast<T>(-0.467928171646331)), c);
+ d_val =
+ IfThenElse(mask, Set(d, static_cast<T>(-0.00146222560953702)), d_val);
+ e = IfThenElse(mask, Set(d, static_cast<T>(1.02705747414947)), e);
+ f = IfThenElse(mask, Set(d, static_cast<T>(-0.0073852515760983)), f);
+
+ // If y < t3 (idx 3)
+ mask = Lt(y, t3);
+ a = IfThenElse(mask, Set(d, static_cast<T>(-0.0600814996891072)), a);
+ b = IfThenElse(mask, Set(d, static_cast<T>(0.325772792901464)), b);
+ c = IfThenElse(mask, Set(d, static_cast<T>(-0.597530579050227)), c);
+ d_val =
+ IfThenElse(mask, Set(d, static_cast<T>(0.119506151013929)), d_val);
+ e = IfThenElse(mask, Set(d, static_cast<T>(0.970968585888465)), e);
+ f = IfThenElse(mask, Set(d, static_cast<T>(0.00296459124064406)), f);
+
+ // If y < t2 (idx 2)
+ mask = Lt(y, t2);
+ a = IfThenElse(mask, Set(d, static_cast<T>(-0.012865365312548)), a);
+ b = IfThenElse(mask, Set(d, static_cast<T>(0.197711481341899)), b);
+ c = IfThenElse(mask, Set(d, static_cast<T>(-0.457494048166093)), c);
+ d_val =
+ IfThenElse(mask, Set(d, static_cast<T>(0.042321724042285)), d_val);
+ e = IfThenElse(mask, Set(d, static_cast<T>(0.992411248107215)), e);
+ f = IfThenElse(mask, Set(d, static_cast<T>(0.000562934372196535)), f);
+
+ // If y < t1 (idx 1)
+ mask = Lt(y, t1);
+ a = IfThenElse(mask, Set(d, static_cast<T>(0.0650303189120701)), a);
+ b = IfThenElse(mask, Set(d, static_cast<T>(0.0576751179885766)), b);
+ c = IfThenElse(mask, Set(d, static_cast<T>(-0.355207242694923)), c);
+ d_val =
+ IfThenElse(mask, Set(d, static_cast<T>(0.00439112349520911)), d_val);
+ e = IfThenElse(mask, Set(d, static_cast<T>(0.999545718850479)), e);
+ f = IfThenElse(mask, Set(d, static_cast<T>(1.91000262234958e-05)), f);
+
+ // If y < t0 (idx 0)
+ mask = Lt(y, t0);
+ a = IfThenElse(mask, Set(d, static_cast<T>(0.124683326807972)), a);
+ b = IfThenElse(mask, Set(d, static_cast<T>(0.00220499585798237)), b);
+ c = IfThenElse(mask, Set(d, static_cast<T>(-0.333553679129082)), c);
+ d_val =
+ IfThenElse(mask, Set(d, static_cast<T>(9.56515391952438e-06)), d_val);
+ e = IfThenElse(mask, Set(d, static_cast<T>(0.999999846647538)), e);
+ f = IfThenElse(mask, Set(d, static_cast<T>(4.72832130652986e-10)), f);
+ }
+ }
+
+ // Math: f(y) = ay^5 + by^4 + cy^3 + dy^2 + ey + f
+ // Using Estrin's scheme
+ const auto y2 = Mul(y, y);
+ // term0 = e*y + f
+ const auto term0 = MulAdd(e, y, f);
+ // term1 = c*y + d
+ const auto term1 = MulAdd(c, y, d_val);
+ // term2 = a*y + b
+ const auto term2 = MulAdd(a, y, b);
+ // term3 = term2 * y2 + term1
+ const auto term3 = MulAdd(term2, y2, term1);
+ // result = term3 * y2 + term0
+ auto result = MulAdd(term3, y2, term0);
+
+ const auto kSmall = Set(d, static_cast<T>(0.001));
+ result = IfThenElse(Lt(y, kSmall), y, result);
+
+ const auto k1 = Set(d, static_cast<T>(1.0));
+  // We can take Min() because the degree-5 polynomial for index 7 is
+  // monotonically increasing: for inputs > 6.5 it outputs > 1.0, so Min()
+  // clamps the result to 1.0 without needing IfThenElse().
+ result = Min(result, k1);
+
+ return CopySign(result, val); // Restore sign
+}
+
+namespace impl {
+
+// Fallback path for targets where Lookup8 is unavailable. Computes the 4
+// final coefficient vectors by running a blend chain (either serially or as
+// two parallel chains, depending on the number of registers).
+template <class D, class V>
+HWY_INLINE void FallbackBlendChain4Coeff(
+ D d, V y, const TFromD<D>* HWY_RESTRICT thresholds,
+ const TFromD<D>* HWY_RESTRICT arr_a, const TFromD<D>* HWY_RESTRICT arr_b,
+ const TFromD<D>* HWY_RESTRICT arr_c, const TFromD<D>* HWY_RESTRICT arr_d,
+ V& a, V& b, V& c, V& d_val) {
+ const auto t0 = Set(d, thresholds[0]);
+ const auto t1 = Set(d, thresholds[1]);
+ const auto t2 = Set(d, thresholds[2]);
+ const auto t3 = Set(d, thresholds[3]);
+ const auto t4 = Set(d, thresholds[4]);
+ const auto t5 = Set(d, thresholds[5]);
+ const auto t6 = Set(d, thresholds[6]);
+
+ if constexpr (HWY_REGISTERS >= 32) {
+ // Split into two parallel chains to reduce dependency latency.
+ // -- Chain 1: Indices 0 to 3 (Evaluated starting from t3 down to t0)
+ auto a_low = Set(d, arr_a[3]);
+ auto b_low = Set(d, arr_b[3]);
+ auto c_low = Set(d, arr_c[3]);
+ auto d_low = Set(d, arr_d[3]);
+
+ auto mask = Lt(y, t2);
+ a_low = IfThenElse(mask, Set(d, arr_a[2]), a_low);
+ b_low = IfThenElse(mask, Set(d, arr_b[2]), b_low);
+ c_low = IfThenElse(mask, Set(d, arr_c[2]), c_low);
+ d_low = IfThenElse(mask, Set(d, arr_d[2]), d_low);
+
+ mask = Lt(y, t1);
+ a_low = IfThenElse(mask, Set(d, arr_a[1]), a_low);
+ b_low = IfThenElse(mask, Set(d, arr_b[1]), b_low);
+ c_low = IfThenElse(mask, Set(d, arr_c[1]), c_low);
+ d_low = IfThenElse(mask, Set(d, arr_d[1]), d_low);
+
+ mask = Lt(y, t0);
+ a_low = IfThenElse(mask, Set(d, arr_a[0]), a_low);
+ b_low = IfThenElse(mask, Set(d, arr_b[0]), b_low);
+ c_low = IfThenElse(mask, Set(d, arr_c[0]), c_low);
+ d_low = IfThenElse(mask, Set(d, arr_d[0]), d_low);
+
+ // -- Chain 2: Indices 4 to 7 (Evaluated starting from t6 down to t4)
+ auto a_high = Set(d, arr_a[7]);
+ auto b_high = Set(d, arr_b[7]);
+ auto c_high = Set(d, arr_c[7]);
+ auto d_high = Set(d, arr_d[7]);
+
+ mask = Lt(y, t6);
+ a_high = IfThenElse(mask, Set(d, arr_a[6]), a_high);
+ b_high = IfThenElse(mask, Set(d, arr_b[6]), b_high);
+ c_high = IfThenElse(mask, Set(d, arr_c[6]), c_high);
+ d_high = IfThenElse(mask, Set(d, arr_d[6]), d_high);
+
+ mask = Lt(y, t5);
+ a_high = IfThenElse(mask, Set(d, arr_a[5]), a_high);
+ b_high = IfThenElse(mask, Set(d, arr_b[5]), b_high);
+ c_high = IfThenElse(mask, Set(d, arr_c[5]), c_high);
+ d_high = IfThenElse(mask, Set(d, arr_d[5]), d_high);
+
+ mask = Lt(y, t4);
+ a_high = IfThenElse(mask, Set(d, arr_a[4]), a_high);
+ b_high = IfThenElse(mask, Set(d, arr_b[4]), b_high);
+ c_high = IfThenElse(mask, Set(d, arr_c[4]), c_high);
+ d_high = IfThenElse(mask, Set(d, arr_d[4]), d_high);
+
+ // -- Merge the two chains
+ auto merge_mask = Lt(y, t3);
+ a = IfThenElse(merge_mask, a_low, a_high);
+ b = IfThenElse(merge_mask, b_low, b_high);
+ c = IfThenElse(merge_mask, c_low, c_high);
+ d_val = IfThenElse(merge_mask, d_low, d_high);
+ } else {
+ // Start with highest index (7)
+ a = Set(d, arr_a[7]);
+ b = Set(d, arr_b[7]);
+ c = Set(d, arr_c[7]);
+ d_val = Set(d, arr_d[7]);
+
+ // If y < t6 (idx 6)
+ auto mask = Lt(y, t6);
+ a = IfThenElse(mask, Set(d, arr_a[6]), a);
+ b = IfThenElse(mask, Set(d, arr_b[6]), b);
+ c = IfThenElse(mask, Set(d, arr_c[6]), c);
+ d_val = IfThenElse(mask, Set(d, arr_d[6]), d_val);
+
+ // If y < t5 (idx 5)
+ mask = Lt(y, t5);
+ a = IfThenElse(mask, Set(d, arr_a[5]), a);
+ b = IfThenElse(mask, Set(d, arr_b[5]), b);
+ c = IfThenElse(mask, Set(d, arr_c[5]), c);
+ d_val = IfThenElse(mask, Set(d, arr_d[5]), d_val);
+
+ // If y < t4 (idx 4)
+ mask = Lt(y, t4);
+ a = IfThenElse(mask, Set(d, arr_a[4]), a);
+ b = IfThenElse(mask, Set(d, arr_b[4]), b);
+ c = IfThenElse(mask, Set(d, arr_c[4]), c);
+ d_val = IfThenElse(mask, Set(d, arr_d[4]), d_val);
+
+ // If y < t3 (idx 3)
+ mask = Lt(y, t3);
+ a = IfThenElse(mask, Set(d, arr_a[3]), a);
+ b = IfThenElse(mask, Set(d, arr_b[3]), b);
+ c = IfThenElse(mask, Set(d, arr_c[3]), c);
+ d_val = IfThenElse(mask, Set(d, arr_d[3]), d_val);
+
+ // If y < t2 (idx 2)
+ mask = Lt(y, t2);
+ a = IfThenElse(mask, Set(d, arr_a[2]), a);
+ b = IfThenElse(mask, Set(d, arr_b[2]), b);
+ c = IfThenElse(mask, Set(d, arr_c[2]), c);
+ d_val = IfThenElse(mask, Set(d, arr_d[2]), d_val);
+
+ // If y < t1 (idx 1)
+ mask = Lt(y, t1);
+ a = IfThenElse(mask, Set(d, arr_a[1]), a);
+ b = IfThenElse(mask, Set(d, arr_b[1]), b);
+ c = IfThenElse(mask, Set(d, arr_c[1]), c);
+ d_val = IfThenElse(mask, Set(d, arr_d[1]), d_val);
+
+ // If y < t0 (idx 0)
+ mask = Lt(y, t0);
+ a = IfThenElse(mask, Set(d, arr_a[0]), a);
+ b = IfThenElse(mask, Set(d, arr_b[0]), b);
+ c = IfThenElse(mask, Set(d, arr_c[0]), c);
+ d_val = IfThenElse(mask, Set(d, arr_d[0]), d_val);
+ }
+}
+
+} // namespace impl
+
+/**
+ * Fast approximation of log(x).
+ *
+ * Valid Lane Types: float32, float64
+ * Max Relative Error: 0.0012%
+ * Average Relative Error: 3e-6% for float32, 1.8e-7% for float64
+ * Valid Range: float32: (0, +FLT_MAX]
+ * float64: (0, +DBL_MAX]
+ *
+ * @return natural logarithm of 'x'
+ */
+// If false, subnormals are treated as zero.
+template <bool kHandleSubnormals = true, class D, class V>
+HWY_INLINE V FastLog(D d, V x) {
+ using T = TFromD<D>;
+ const V kLn2 = Set(d, static_cast<T>(0.6931471805599453));
+ V y, exp;
+ impl::FastLogRangeReduction<kHandleSubnormals>(d, x, y, exp);
+
+ V approx;
+
+ V a, b, c, d_val;
+ // Centering the approximation around y=1.0 by using z = y - 1.0 significantly
+ // improves accuracy for low-degree polynomials compared to approximating
+ // log(y) directly.
+ const V z = Sub(y, Set(d, static_cast<T>(1.0)));
+
+ HWY_ALIGN static constexpr T arr_a[8] = {
+ static_cast<T>(0.78766119873962426), static_cast<T>(0.56395605885234767),
+ static_cast<T>(0.41755823888409732), static_cast<T>(0.31775546220809975),
+ static_cast<T>(0.24738922014476947), static_cast<T>(0.19635862241628779),
+ static_cast<T>(0.15845269802741027), static_cast<T>(0.12974944454997622)};
+
+ HWY_ALIGN static constexpr T arr_b[8] = {
+ static_cast<T>(-0.29967724727628686),
+ static_cast<T>(-0.43890059639104201),
+ static_cast<T>(-0.49106265092580692),
+ static_cast<T>(-0.50008637171949),
+ static_cast<T>(-0.48774751153444412),
+ static_cast<T>(-0.46524164863536055),
+ static_cast<T>(-0.43845622820596808),
+ static_cast<T>(-0.41055729878598496)};
+
+ HWY_ALIGN static constexpr T arr_c[8] = {
+ static_cast<T>(1.0358118335702087), static_cast<T>(1.0067153345685411),
+ static_cast<T>(1.000379283174812), static_cast<T>(1.0000110351938951),
+ static_cast<T>(0.99922180103707492), static_cast<T>(0.99586383428901692),
+ static_cast<T>(0.98951797571207256), static_cast<T>(0.98045123777070986)};
+
+ HWY_ALIGN static constexpr T arr_d[8] = {
+ static_cast<T>(0.0023082932745966296),
+ static_cast<T>(0.00026712584767189665),
+ static_cast<T>(5.4447452042709148e-06),
+ static_cast<T>(0),
+ static_cast<T>(1.8158320065986679e-05),
+ static_cast<T>(0.00018763480353754217),
+ static_cast<T>(0.0006917031865196592),
+ static_cast<T>(0.0016769113228540019)};
+
+ if constexpr (CanLookup8(d)) {
+ // --- Table Lookup ---
+ const auto scale = Set(d, static_cast<T>(11.3137085));
+ // Input is always non-negative, so Floor() + ConvertTo()
+ // can be replaced by direct ConvertTo() (truncation), which is faster.
+ // We use MulAdd(y, scale, -8.0) instead of Mul(Sub(y, lower_bound), scale)
+ // to save instructions. 0.70710678 * 11.3137085 ~= 8.0.
+ auto idx_i = ConvertInRangeTo(
+ RebindToSigned<D>(), MulAdd(y, scale, Set(d, static_cast<T>(-8.0))));
+
+ // Clamp index to 7 to handle overshoots
+ idx_i = Min(idx_i, Set(RebindToSigned<D>(), 7));
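+  // Worked example of the index mapping (scale = 8 * sqrt(2), values
+  // rounded): the reduced y lies in [sqrt(2)/2, sqrt(2)].
+  //   y = sqrt(2)/2 ~= 0.7071: trunc(0.7071 * 11.3137085 - 8.0) ~= 0 -> 0
+  //   y = 1.0:                 trunc(11.3137085 - 8.0)           = 3 -> 3
+  //   y = sqrt(2)  ~= 1.4142:  trunc(16.0 - 8.0)                 = 8 -> 7
+  // (the last is clamped by the Min above).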
+
+ a = Lookup8(d, arr_a, idx_i);
+ b = Lookup8(d, arr_b, idx_i);
+ c = Lookup8(d, arr_c, idx_i);
+ d_val = Lookup8(d, arr_d, idx_i);
+ } else {
+ HWY_ALIGN static constexpr T thresholds[7] = {
+ static_cast<T>(0.7954951287634819), static_cast<T>(0.8838834764038688),
+ static_cast<T>(0.9722718240442556), static_cast<T>(1.0606601716846424),
+ static_cast<T>(1.1490485193250295), static_cast<T>(1.2374368669654163),
+ static_cast<T>(1.3258252146058032)};
+ impl::FallbackBlendChain4Coeff(d, y, thresholds, arr_a, arr_b, arr_c, arr_d,
+ a, b, c, d_val);
+ }
+ // Math: approx = (a*z + b)*z^2 + (c*z + d_val)
+ const auto z2 = Mul(z, z);
+ const auto pab = MulAdd(a, z, b);
+ const auto pcd = MulAdd(c, z, d_val);
+ approx = MulAdd(pab, z2, pcd);
+
+ return MulAdd(exp, kLn2, approx);
+}
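+
+// Illustrative usage of FastLog (a sketch, not part of this patch; the helper
+// name is hypothetical and n is assumed to be a multiple of Lanes(d)):
+//
+//   template <class D>
+//   void LogInPlace(D d, TFromD<D>* HWY_RESTRICT p, size_t n) {
+//     for (size_t i = 0; i < n; i += Lanes(d)) {
+//       Store(FastLog(d, Load(d, p + i)), d, p + i);
+//     }
+//   }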
+
+/**
+ * Fast approximation of exp(x).
+ *
+ * Valid Lane Types: float32, float64
+ * Max Relative Error: 0.0007% for float32 [-87, 88]
+ * Max Relative Error: 0.0007% for float64 [-708, 706]
+ * Average Relative Error: 0.00002% for float32 [-87, 88]
+ * Average Relative Error: 0.00001% for float64 [-708, 706]
+ * Max Relative Error for Subnormals: 2.4% for float32 [-FLT_MAX, -87]
+ * Max Relative Error for Subnormals: 0.006% for float64 [-DBL_MAX, -708]
+ * Valid Range: float32[-FLT_MAX, +88], float64[-DBL_MAX, +706]
+ *
+ * @return e^x
+ */
+template <bool kHandleSubnormals = true, class D, class V>
+HWY_INLINE V FastExp(D d, V x) {
+ using T = TFromD<D>;
+ impl::FastExpImpl<T> impl;
+
+ T lower_bound_val;
+ if constexpr (kHandleSubnormals) {
+ lower_bound_val = sizeof(T) == 4 ? -104.0 : -1000.0;
+ } else {
+ lower_bound_val = sizeof(T) == 4 ? -88.0 : -709.0;
+ }
+ const V kLowerBound = Set(d, static_cast<T>(lower_bound_val));
+
+ const V kHalf = Set(d, static_cast<T>(+0.5));
+ const V kNegZero = Set(d, static_cast<T>(-0.0));
+
+ const V kOneOverLog2 = Set(d, static_cast<T>(+1.442695040888963407359924681));
+
+ using TI = MakeSigned<T>;
+ const Rebind<TI, D> di;
+
+ V x_clamped = x;
+ if constexpr (!kHandleSubnormals) {
+ x_clamped = Max(x, kLowerBound);
+ }
+
+ const auto rounded_offs = BitCast(
+ d,
+ OrAnd(BitCast(di, kHalf), BitCast(di, x_clamped), BitCast(di, kNegZero)));
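+  // The OrAnd above computes copysign(0.5, x): OrAnd(o, a, b) is o | (a & b),
+  // so the sign bit of x_clamped is ORed into +0.5. Adding +/-0.5 before the
+  // truncating ToInt32 below makes q = round-to-nearest(x / ln(2)).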
+
+ const auto q = impl.ToInt32(d, MulAdd(x_clamped, kOneOverLog2, rounded_offs));
+
+ const auto x_red = impl.ExpReduce(d, x_clamped, q);
+
+ // Degree 4 polynomial approximation of e^x on [-ln2/2, ln2/2]
+ // Generated via Caratheodory-Fejer approximation.
+ const auto c0 = Set(d, static_cast<T>(1.0000001510806224569));
+ const auto c1 = Set(d, static_cast<T>(0.99996228117046825901));
+ const auto c2 = Set(d, static_cast<T>(0.49998365704575670199));
+ const auto c3 = Set(d, static_cast<T>(0.16792157982876812494));
+ const auto c4 = Set(d, static_cast<T>(0.041959439862987071845));
+
+ // Estrin's scheme
+ const auto x2 = Mul(x_red, x_red);
+ // term0 = c1*x + c0
+ const auto term0 = MulAdd(c1, x_red, c0);
+ // term1 = c3*x + c2
+ const auto term1 = MulAdd(c3, x_red, c2);
+ // term2 = c4*x^2 + term1
+ const auto term2 = MulAdd(c4, x2, term1);
+ // approx = term2 * x^2 + term0
+ const auto approx = MulAdd(term2, x2, term0);
+
+ if constexpr (kHandleSubnormals) {
+ const V res = impl.LoadExpShortRange(d, approx, q);
+ // Handle underflow
+ return IfThenElseZero(Ge(x, kLowerBound), res);
+ } else {
+ // Optimization: avoid splitting the exponent since 'q' is guaranteed
+ // to fall within the normal floating-point ranges.
+ return Mul(approx, impl.Pow2I(d, q));
+ }
+}
+
+/**
+ * Fast approximation of exp2(x).
+ *
+ * Valid Lane Types: float32, float64
+ * Max Relative Error: 0.0007% for float32 [-150, 128]
+ * Max Relative Error: 0.0007% for float64 [-1075, 1024]
+ * Average Relative Error: 0.00002% for float32 [-150, 128]
+ * Average Relative Error: 0.00001% for float64 [-1075, 1024]
+ * Max Relative Error for Subnormals: 0.08% for float32 [-FLT_MAX, -150]
+ * Max Relative Error for Subnormals: 0.03% for float64 [-DBL_MAX, -1075]
+ * Valid Range: float32[-FLT_MAX, +128], float64[-DBL_MAX, +1024]
+ *
+ * @return 2^x
+ */
+template <bool kHandleSubnormals = true, class D, class V>
+HWY_INLINE V FastExp2(D d, V x) {
+ using T = TFromD<D>;
+ impl::FastExpImpl<T> impl;
+
+ T lower_bound_val;
+ if constexpr (kHandleSubnormals) {
+    // FastExp uses kLowerBound = -104.0 / -1000.0 since it operates on e^x.
+    // For FastExp2, the corresponding base-2 lower limits are -150.0 and
+    // -1075.0.
+ lower_bound_val = sizeof(T) == 4 ? -150.0 : -1075.0;
+ } else {
+ lower_bound_val = sizeof(T) == 4 ? -127.0 : -1023.0;
+ }
+ const V kLowerBound = Set(d, static_cast<T>(lower_bound_val));
+
+ const V kHalf = Set(d, static_cast<T>(+0.5));
+ const V kNegZero = Set(d, static_cast<T>(-0.0));
+
+ using TI = MakeSigned<T>;
+ const Rebind<TI, D> di;
+
+ V x_clamped = x;
+ if constexpr (!kHandleSubnormals) {
+ x_clamped = Max(x, kLowerBound);
+ }
+
+ const auto rounded_offs = BitCast(
+ d,
+ OrAnd(BitCast(di, kHalf), BitCast(di, x_clamped), BitCast(di, kNegZero)));
+
+ // FastExp calculates q = ToInt32(x * (1/ln(2)) + rounded_offs)
+ // FastExp2 does not need the (1/ln(2)) scaling factor since the input is
+ // already in base 2.
+ const auto q = impl.ToInt32(d, Add(x_clamped, rounded_offs));
+
+ const auto x_red = impl.Exp2Reduce(d, x_clamped, q);
+
+ // Degree 4 polynomial approximation of 2^x on [-1/2, 1/2]
+ // Derived from FastExp coefficients by pre-absorbing ln2:
+ // c_fast_exp2[i] = c_fast_exp[i] * (ln2)^i.
+ const auto c0 = Set(d, static_cast<T>(1.0000001510806224569));
+ const auto c1 = Set(d, static_cast<T>(0.69312104523363065471));
+ const auto c2 = Set(d, static_cast<T>(0.24021865239713606622));
+ const auto c3 = Set(d, static_cast<T>(0.05592203117565365516));
+ const auto c4 = Set(d, static_cast<T>(0.00968574163456345638));
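+  // Illustrative check (rounded): 0.99996228 * ln(2) ~= 0.693121 and
+  // 0.49998366 * ln(2)^2 ~= 0.240219, matching c1 and c2 above to about
+  // eight significant digits.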
+
+ // Estrin's scheme
+ const auto x2 = Mul(x_red, x_red);
+ // term0 = c1*x + c0
+ const auto term0 = MulAdd(c1, x_red, c0);
+ // term1 = c3*x + c2
+ const auto term1 = MulAdd(c3, x_red, c2);
+ // term2 = c4*x^2 + term1
+ const auto term2 = MulAdd(c4, x2, term1);
+ // approx = term2 * x^2 + term0
+ const auto approx = MulAdd(term2, x2, term0);
+
+ if constexpr (kHandleSubnormals) {
+ const V res = impl.LoadExpShortRange(d, approx, q);
+ // Handle underflow
+ return IfThenElseZero(Ge(x, kLowerBound), res);
+ } else {
+ // Optimization: avoid splitting the exponent since 'q' is guaranteed
+ // to fall within the normal floating-point ranges.
+ return Mul(approx, impl.Pow2I(d, q));
+ }
+}
+
+/**
+ * Fast approximation of exp(x) for x <= 0. Subnormals are flushed to zero.
+ *
+ * Valid Lane Types: float32, float64
+ * Max Relative Error: 0.0007% for float32 [-87, 0]
+ * Max Relative Error: 0.0007% for float64 [-708, 0]
+ * Average Relative Error: 0.00002% for float32 [-87, 0]
+ * Average Relative Error: 0.00001% for float64 [-708, 0]
+ * Valid Range: float32[-FLT_MAX, +0.0], float64[-DBL_MAX, +0.0]
+ *
+ * @return e^x
+ */
+template <class D, class V>
+HWY_INLINE V FastExpMinusOrZero(D d, V x) {
+ using T = TFromD<D>;
+ impl::FastExpImpl<T> impl;
+
+ const V kHalfMinus = Set(d, static_cast<T>(-0.5));
+ const V kLowerBound =
+ Set(d, static_cast<T>((sizeof(T) == 4 ? -88.0 : -709.0)));
+
+ const V kOneOverLog2 = Set(d, static_cast<T>(+1.442695040888963407359924681));
+
+ // Optimization for x <= 0:
+ // FastExp computes `rounded_offs = sign(x) ? -0.5 : 0.5` to round the
+ // multiplied argument towards zero. Since x <= 0, we avoid the dynamic
+ // calculation and simply use a constant -0.5 (kHalfMinus).
+ //
+ // We clamp x to be >= kLowerBound. For x < kLowerBound, the remapped
+ // exponent q becomes -127 (f32) or -1023 (f64), which Pow2I converts to
+ // exactly 0.0. This avoids subnormals and the need for a final mask.
+ const auto x_clamped = Max(x, kLowerBound);
+ const auto q = impl.ToInt32(d, MulAdd(x_clamped, kOneOverLog2, kHalfMinus));
+
+ const auto x_red = impl.ExpReduce(d, x_clamped, q);
+
+ // Degree 4 polynomial approximation of e^x on [-ln2/2, ln2/2]
+ // Generated via Caratheodory-Fejer approximation.
+ const auto c0 = Set(d, static_cast<T>(1.0000001510806224569));
+ const auto c1 = Set(d, static_cast<T>(0.99996228117046825901));
+ const auto c2 = Set(d, static_cast<T>(0.49998365704575670199));
+ const auto c3 = Set(d, static_cast<T>(0.16792157982876812494));
+ const auto c4 = Set(d, static_cast<T>(0.041959439862987071845));
+
+ // Estrin's scheme
+ const auto x2 = Mul(x_red, x_red);
+ // term0 = c1*x + c0
+ const auto term0 = MulAdd(c1, x_red, c0);
+ // term1 = c3*x + c2
+ const auto term1 = MulAdd(c3, x_red, c2);
+ // term2 = c4*x^2 + term1
+ const auto term2 = MulAdd(c4, x2, term1);
+ // approx = term2 * x^2 + term0
+ const auto approx = MulAdd(term2, x2, term0);
+
+ // Since inputs < -88.0 (f32) and < -709.0 (f64) are flushed to zero,
+ // we do not generate subnormals. Therefore, q is guaranteed to be >= -127
+ // and we can use Pow2I directly without splitting the exponent computation.
+ return Mul(approx, impl.Pow2I(d, q));
+}
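+
+// Illustrative use (a hedged sketch, not part of this patch): softmax-style
+// weights. Subtracting a running maximum guarantees x <= 0, which is exactly
+// the domain this function assumes:
+//
+//   const auto e = FastExpMinusOrZero(d, Sub(v, v_max));  // v_max >= v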
+
+/**
+ * Fast approximation of log2(x).
+ *
+ * Valid Lane Types: float32, float64
+ * Max Relative Error: 0.0012%
+ * Average Relative Error: 1.2e-6% for float32, 1.8e-7% for float64
+ * Valid Range: float32: (0, +FLT_MAX]
+ * float64: (0, +DBL_MAX]
+ *
+ * @return base 2 logarithm of 'x'
+ */
+// If kHandleSubnormals is false, subnormals are treated as zero.
+template <bool kHandleSubnormals = true, class D, class V>
+HWY_INLINE V FastLog2(D d, V x) {
+ using T = TFromD<D>;
+ V y, exp;
+ impl::FastLogRangeReduction<kHandleSubnormals>(d, x, y, exp);
+
+ V approx;
+
+ V a, b, c, d_val;
+ // Centering the approximation around y=1.0 by using z = y - 1.0 significantly
+ // improves accuracy for low-degree polynomials compared to approximating
+ // log(y) directly.
+ const V z = Sub(y, Set(d, static_cast<T>(1.0)));
+
+ HWY_ALIGN static constexpr T arr_a[8] = {
+ static_cast<T>(1.136354905322312), static_cast<T>(0.8136166093855663),
+ static_cast<T>(0.6024092005204164), static_cast<T>(0.45842422954300593),
+ static_cast<T>(0.35690720107224694), static_cast<T>(0.28328561079576686),
+ static_cast<T>(0.22859892165962123), static_cast<T>(0.18718888021034824)};
+
+ HWY_ALIGN static constexpr T arr_b[8] = {static_cast<T>(-0.43234287851275466),
+ static_cast<T>(-0.6331997138565648),
+ static_cast<T>(-0.7084536512564498),
+ static_cast<T>(-0.721472128495863),
+ static_cast<T>(-0.703670916096675),
+ static_cast<T>(-0.6712018193012402),
+ static_cast<T>(-0.6325586260796298),
+ static_cast<T>(-0.5923089789593089)};
+
+ HWY_ALIGN static constexpr T arr_c[8] = {
+ static_cast<T>(1.4943605955858443), static_cast<T>(1.4523832207689078),
+ static_cast<T>(1.4432422308443573), static_cast<T>(1.4427109613084712),
+ static_cast<T>(1.4415723371043265), static_cast<T>(1.4367278151294332),
+ static_cast<T>(1.4275726764302927), static_cast<T>(1.4144921385652491)};
+
+ HWY_ALIGN static constexpr T arr_d[8] = {
+ static_cast<T>(0.0033301632601779037),
+ static_cast<T>(0.00038538113572950594),
+ static_cast<T>(7.855106905105614e-06),
+ static_cast<T>(0.0),
+ static_cast<T>(2.6196918310073536e-05),
+ static_cast<T>(0.000270699800561787),
+ static_cast<T>(0.000997916756959006),
+ static_cast<T>(0.00241927164949202)};
+
+ if constexpr (CanLookup8(d)) {
+ // --- Table Lookup ---
+ const auto scale = Set(d, static_cast<T>(11.3137085));
+ auto idx_i = ConvertInRangeTo(
+ RebindToSigned<D>(), MulAdd(y, scale, Set(d, static_cast<T>(-8.0))));
+
+ idx_i = Min(idx_i, Set(RebindToSigned<D>(), 7));
+
+ a = Lookup8(d, arr_a, idx_i);
+ b = Lookup8(d, arr_b, idx_i);
+ c = Lookup8(d, arr_c, idx_i);
+ d_val = Lookup8(d, arr_d, idx_i);
+ } else {
+ HWY_ALIGN static constexpr T thresholds[7] = {
+ static_cast<T>(0.7954951287634819), static_cast<T>(0.8838834764038688),
+ static_cast<T>(0.9722718240442556), static_cast<T>(1.0606601716846424),
+ static_cast<T>(1.1490485193250295), static_cast<T>(1.2374368669654163),
+ static_cast<T>(1.3258252146058032)};
+ impl::FallbackBlendChain4Coeff(d, y, thresholds, arr_a, arr_b, arr_c, arr_d,
+ a, b, c, d_val);
+ }
+ // Math: approx = (a*z + b)*z^2 + (c*z + d_val)
+ const auto z2 = Mul(z, z);
+ const auto pab = MulAdd(a, z, b);
+ const auto pcd = MulAdd(c, z, d_val);
+ approx = MulAdd(pab, z2, pcd);
+
+ return Add(exp, approx);
+}
+
+/**
+ * Fast approximation of log10(x).
+ *
+ * Valid Lane Types: float32, float64
+ * Max Relative Error: 0.0012%
+ * Average Relative Error: 5.4e-6% for float32, 1.8e-7% for float64
+ * Valid Range: float32: (0, +FLT_MAX]
+ * float64: (0, +DBL_MAX]
+ *
+ * @return base 10 logarithm of 'x'
+ */
+// If kHandleSubnormals is false, subnormals are treated as zero.
+template <bool kHandleSubnormals = true, class D, class V>
+HWY_INLINE V FastLog10(D d, V x) {
+ using T = TFromD<D>;
+ V y, exp;
+ impl::FastLogRangeReduction<kHandleSubnormals>(d, x, y, exp);
+
+ V approx;
+
+ V a, b, c, d_val;
+ // Centering the approximation around y=1.0 by using z = y - 1.0 significantly
+ // improves accuracy for low-degree polynomials compared to approximating
+ // log(y) directly.
+ const V z = Sub(y, Set(d, static_cast<T>(1.0)));
+
+ HWY_ALIGN static constexpr T arr_a[8] = {
+ static_cast<T>(0.3420769122219194), static_cast<T>(0.24492300439548012),
+ static_cast<T>(0.18134323902060331), static_cast<T>(0.137999443831595),
+ static_cast<T>(0.10743977319122217), static_cast<T>(0.08527746618951795),
+ static_cast<T>(0.06881513239598655), static_cast<T>(0.05634946779806663)};
+
+ HWY_ALIGN static constexpr T arr_b[8] = {
+ static_cast<T>(-0.1301481748440477),
+ static_cast<T>(-0.1906121071166758),
+ static_cast<T>(-0.21326579956586073),
+ static_cast<T>(-0.21718475171279292),
+ static_cast<T>(-0.21182605282145175),
+ static_cast<T>(-0.20205188075390862),
+ static_cast<T>(-0.19041912046596485),
+ static_cast<T>(-0.17830276936785788)};
+
+ HWY_ALIGN static constexpr T arr_c[8] = {
+ static_cast<T>(0.44984736360963107), static_cast<T>(0.4372109146505034),
+ static_cast<T>(0.4344592024931514), static_cast<T>(0.4342992744270672),
+ static_cast<T>(0.4339565143878306), static_cast<T>(0.4324981679587344),
+ static_cast<T>(0.4297421965958291), static_cast<T>(0.4258045623390324)};
+
+ HWY_ALIGN static constexpr T arr_d[8] = {
+ static_cast<T>(0.0010024790317717039),
+ static_cast<T>(0.00011601128161763332),
+ static_cast<T>(2.364622797584052e-06),
+ static_cast<T>(0.0),
+ static_cast<T>(7.886058205291105e-06),
+ static_cast<T>(8.148875978935532e-05),
+ static_cast<T>(0.00030040287702038374),
+ static_cast<T>(0.0007282733341565754)};
+
+ if constexpr (CanLookup8(d)) {
+ // --- Table Lookup ---
+ const auto scale = Set(d, static_cast<T>(11.3137085));
+ auto idx_i = ConvertInRangeTo(
+ RebindToSigned<D>(), MulAdd(y, scale, Set(d, static_cast<T>(-8.0))));
+
+ idx_i = Min(idx_i, Set(RebindToSigned<D>(), 7));
+
+ a = Lookup8(d, arr_a, idx_i);
+ b = Lookup8(d, arr_b, idx_i);
+ c = Lookup8(d, arr_c, idx_i);
+ d_val = Lookup8(d, arr_d, idx_i);
+ } else {
+ HWY_ALIGN static constexpr T thresholds[7] = {
+ static_cast<T>(0.7954951287634819), static_cast<T>(0.8838834764038688),
+ static_cast<T>(0.9722718240442556), static_cast<T>(1.0606601716846424),
+ static_cast<T>(1.1490485193250295), static_cast<T>(1.2374368669654163),
+ static_cast<T>(1.3258252146058032)};
+ impl::FallbackBlendChain4Coeff(d, y, thresholds, arr_a, arr_b, arr_c, arr_d,
+ a, b, c, d_val);
+ }
+ // Math: approx = (a*z + b)*z^2 + (c*z + d_val)
+ const auto z2 = Mul(z, z);
+ const auto pab = MulAdd(a, z, b);
+ const auto pcd = MulAdd(c, z, d_val);
+ approx = MulAdd(pab, z2, pcd);
+
+ const auto kLog10_2 = Set(d, static_cast<T>(0.3010299956639812)); // log10(2)
+  // Computes exp * log10(2) + approx. Since approx was scaled by 1/ln(10)
+ // via the pre-scaled coefficients, this yields the correct log10 result
+ // using a single MulAdd instruction.
+ return MulAdd(exp, kLog10_2, approx);
+}
+
+/**
+ * Fast approximation of log(1 + x).
+ *
+ * Valid Lane Types: float32, float64
+ * Max Relative Error: 0.0012%
+ * Average Relative Error: 0.00013% for float32, 0.000039% for float64
+ * Valid Range: float32: [-1 + epsilon, +FLT_MAX]
+ * float64: [-1 + epsilon, +DBL_MAX]
+ *
+ * @return natural logarithm of '1 + x'
+ */
+// If kHandleSubnormals is false, subnormals are treated as zero.
+template <bool kHandleSubnormals = true, class D, class V>
+HWY_INLINE V FastLog1p(const D d, V x) {
+ using T = TFromD<D>;
+ const V kOne = Set(d, static_cast<T>(+1.0));
+
+ const V y = Add(x, kOne);
+ const Mask<D> not_pole = Ne(y, kOne);
+ // If y == 1, divisor becomes 1 (dummy), avoiding division by zero.
+ const V divisor = MaskedSubOr(y, not_pole, y, kOne);
+ // Ensure exactly 1.0 when x == divisor. This is necessary because some
+ // platforms (like Armv7) use Newton-Raphson for division, which can return
+  // 0.0 instead of 1.0 when the reciprocal calculation underflows
+ // for very large x.
+ const V div_res = MaskedDivOr(kOne, Ne(x, divisor), x, divisor);
+ const V non_pole = Mul(FastLog<kHandleSubnormals>(d, y), div_res);
+ return IfThenElse(not_pole, non_pole, x);
+}
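+
+// Scalar sketch of the correction used above (Kahan's log1p trick, shown for
+// illustration only; not part of this patch). Multiplying log(y) by
+// x / (y - 1) cancels the rounding error committed when forming y = 1 + x:
+//
+//   double Log1pScalar(double x) {  // assumes <cmath>
+//     const double y = 1.0 + x;
+//     return (y == 1.0) ? x : std::log(y) * (x / (y - 1.0));
+//   }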
+
+/**
+ * Fast approximation of base^exp.
+ *
+ * Valid Lane Types: float32, float64
+ * Valid Range: float32: base in (0, +FLT_MAX],
+ *                       exp * log(base) in [-25.0, +25.0]
+ *              float64: base in (0, +DBL_MAX],
+ *                       exp * log(base) in [-25.0, +25.0]
+ * Max Relative Error for Valid Range: float32: 0.03%, float64: 0.03%
+ *
+ * @return base^exp
+ */
+// If kHandleSubnormals is false, subnormals are treated as zero.
+template <bool kHandleSubnormals = true, class D, class V>
+HWY_INLINE V FastPow(D d, V base, V exp) {
+ return FastExp<kHandleSubnormals>(
+ d, Mul(exp, FastLog<kHandleSubnormals>(d, base)));
+}
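+
+// Illustrative use (a sketch; the exponent value is arbitrary). FastPow is
+// FastExp(exp * FastLog(base)), so its accuracy and range follow from those:
+//
+//   const auto g = Set(d, static_cast<T>(2.2));  // e.g. gamma decoding
+//   const auto out = FastPow(d, v, g);  // valid while |exp * log(base)| <= 25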
+
+template <class D, class V>
+HWY_NOINLINE V CallFastAtan(const D d, VecArg<V> x) {
+ return FastAtan(d, x);
+}
+
+template <class D, class V>
+HWY_NOINLINE V CallFastTan(const D d, VecArg<V> x) {
+ return FastTan(d, x);
+}
+
+template <class D, class V>
+HWY_NOINLINE V CallFastAtan2(const D d, VecArg<V> y, VecArg<V> x) {
+ return FastAtan2(d, y, x);
+}
+
+template <class D, class V>
+HWY_NOINLINE V CallFastTanh(const D d, VecArg<V> x) {
+ return FastTanh(d, x);
+}
+
+template <class D, class V>
+HWY_NOINLINE V CallFastLog(const D d, VecArg<V> x) {
+ return FastLog<>(d, x);
+}
+
+template <class D, class V>
+HWY_NOINLINE V CallFastExp(const D d, VecArg<V> x) {
+ return FastExp(d, x);
+}
+
+template <class D, class V>
+HWY_NOINLINE V CallFastExp2(const D d, VecArg<V> x) {
+ return FastExp2(d, x);
+}
+
+template <class D, class V>
+HWY_NOINLINE V CallFastExpMinusOrZero(const D d, VecArg<V> x) {
+ return FastExpMinusOrZero(d, x);
+}
+template <class D, class V>
+HWY_NOINLINE V CallFastLog2(const D d, VecArg<V> x) {
+ return FastLog2<>(d, x);
+}
+
+template <class D, class V>
+HWY_NOINLINE V CallFastLog10(const D d, VecArg<V> x) {
+ return FastLog10<>(d, x);
+}
+
+template <class D, class V>
+HWY_NOINLINE V CallFastLog1p(const D d, VecArg<V> x) {
+ return FastLog1p<>(d, x);
+}
+
+template <class D, class V>
+HWY_NOINLINE V CallFastPow(const D d, VecArg<V> base, VecArg<V> exp) {
+ return FastPow<>(d, base, exp);
+}
+
+template <class D, class V>
+HWY_NOINLINE V CallFastExpNormal(const D d, VecArg<V> x) {
+ return FastExp</*kHandleSubnormals=*/false>(d, x);
+}
+
+template <class D, class V>
+HWY_NOINLINE V CallFastExp2Normal(const D d, VecArg<V> x) {
+ return FastExp2</*kHandleSubnormals=*/false>(d, x);
+}
+
+template <class D, class V>
+HWY_NOINLINE V CallFastLogPositiveNormal(const D d, VecArg<V> x) {
+ return FastLog</*kHandleSubnormals=*/false>(d, x);
+}
+
+template <class D, class V>
+HWY_NOINLINE V CallFastLog2PositiveNormal(const D d, VecArg<V> x) {
+ return FastLog2</*kHandleSubnormals=*/false>(d, x);
+}
+
+template <class D, class V>
+HWY_NOINLINE V CallFastLog10PositiveNormal(const D d, VecArg<V> x) {
+ return FastLog10</*kHandleSubnormals=*/false>(d, x);
+}
+
+template <class D, class V>
+HWY_NOINLINE V CallFastLog1pPositiveNormal(const D d, VecArg<V> x) {
+ return FastLog1p</*kHandleSubnormals=*/false>(d, x);
+}
+
+template <class D, class V>
+HWY_NOINLINE V CallFastAtanPositive(const D d, VecArg<V> x) {
+ return FastAtan</*kAssumePositive=*/true>(d, x);
+}
+} // namespace HWY_NAMESPACE
+} // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#endif // HIGHWAY_HWY_CONTRIB_MATH_FAST_MATH_INL_H_
diff --git a/third_party/highway/hwy/contrib/math/math-inl.h b/third_party/highway/hwy/contrib/math/math-inl.h
index 5bb536d9f8..dc5c90990e 100644
--- a/third_party/highway/hwy/contrib/math/math-inl.h
+++ b/third_party/highway/hwy/contrib/math/math-inl.h
@@ -134,41 +134,28 @@ HWY_NOINLINE V CallAtanh(const D d, VecArg<V> x) {
* Correctly handles negative zero, infinities, and NaN.
* @return atan2 of 'y', 'x'
*/
-template <class D, class V = VFromD<D>, class M = MFromD<D>,
- typename T = TFromD<D>>
-HWY_INLINE V Atan2(const D d, V y, V x) {
- const V kHalf = Set(d, static_cast<T>(+0.5));
- const V kPi = Set(d, static_cast<T>(+3.14159265358979323846264));
- const V kPi2 = Mul(kPi, kHalf);
-
- const V k0 = Zero(d);
- const M y_0 = Eq(y, k0);
- const M x_0 = Eq(x, k0);
- const M x_neg = Lt(x, k0);
- const M y_inf = IsInf(y);
- const M x_inf = IsInf(x);
- const M nan = Or(IsNaN(y), IsNaN(x));
-
- const V if_xneg_pi = IfThenElseZero(x_neg, kPi);
- // x= +inf: pi/4; -inf: 3*pi/4; else: pi/2
- const V if_yinf = Mul(kHalf, IfThenElse(x_inf, Add(kPi2, if_xneg_pi), kPi));
-
- V t = Atan(d, Div(y, x));
- // Disambiguate between quadrants 1/3 and 2/4 by adding (Q2: Pi; Q3: -Pi).
- t = Add(t, CopySignToAbs(if_xneg_pi, y));
- // Special cases for 0 and infinity:
- t = IfThenElse(x_inf, if_xneg_pi, t);
- t = IfThenElse(x_0, kPi2, t);
- t = IfThenElse(y_inf, if_yinf, t);
- t = IfThenElse(y_0, if_xneg_pi, t);
- // Any input NaN => NaN, otherwise fix sign.
- return IfThenElse(nan, NaN(d), CopySign(t, y));
-}
+template <class D, class V>
+HWY_INLINE V Atan2(D d, V y, V x);
template <class D, class V>
HWY_NOINLINE V CallAtan2(const D d, VecArg<V> y, VecArg<V> x) {
return Atan2(d, y, x);
}
+/**
+ * Highway SIMD version of std::cbrt(x).
+ *
+ * Valid Lane Types: float32, float64
+ * Max Error: ULP = 6
+ * Valid Range: float32[-FLT_MAX, +FLT_MAX], float64[-DBL_MAX, +DBL_MAX]
+ * @return cube root of 'x'
+ */
+template <bool kHandleSubnormals = true, class D, class V>
+HWY_INLINE V Cbrt(D d, V x);
+template <class D, class V>
+HWY_NOINLINE V CallCbrt(const D d, VecArg<V> x) {
+ return Cbrt<true>(d, x);
+}
+
/**
* Highway SIMD version of std::cos(x).
*
@@ -342,7 +329,6 @@ HWY_NOINLINE V CallTanh(const D d, VecArg<V> x) {
* Valid Lane Types: float32, float64
* Max Error: ULP = 1
* Valid Range: [-39000, +39000]
- * @return sine and cosine of 'x'
*/
template <class D, class V>
HWY_INLINE void SinCos(D d, V x, V& s, V& c);
@@ -575,7 +561,7 @@ struct SinCosImpl {};
template <>
struct AsinImpl<float> {
// Polynomial approximation for asin(x) over the range [0, 0.5).
- template <class D, class V>
+ template <class D, class V = VFromD<D>, HWY_IF_F32_D(D)>
HWY_INLINE V AsinPoly(D d, V x2, V /*x*/) {
const auto k0 = Set(d, +0.1666677296f);
const auto k1 = Set(d, +0.07495029271f);
@@ -592,7 +578,7 @@ struct AsinImpl<float> {
template <>
struct AsinImpl<double> {
// Polynomial approximation for asin(x) over the range [0, 0.5).
- template <class D, class V>
+ template <class D, class V = VFromD<D>, HWY_IF_F64_D(D)>
HWY_INLINE V AsinPoly(D d, V x2, V /*x*/) {
const auto k0 = Set(d, +0.1666666666666497543);
const auto k1 = Set(d, +0.07500000000378581611);
@@ -616,7 +602,7 @@ struct AsinImpl<double> {
template <>
struct AtanImpl<float> {
// Polynomial approximation for atan(x) over the range [0, 1.0).
- template <class D, class V>
+ template <class D, class V = VFromD<D>, HWY_IF_F32_D(D)>
HWY_INLINE V AtanPoly(D d, V x) {
const auto k0 = Set(d, -0.333331018686294555664062f);
const auto k1 = Set(d, +0.199926957488059997558594f);
@@ -637,7 +623,7 @@ struct AtanImpl<float> {
template <>
struct AtanImpl<double> {
// Polynomial approximation for atan(x) over the range [0, 1.0).
- template <class D, class V>
+ template <class D, class V = VFromD<D>, HWY_IF_F64_D(D)>
HWY_INLINE V AtanPoly(D d, V x) {
const auto k0 = Set(d, -0.333333333333311110369124);
const auto k1 = Set(d, +0.199999999996591265594148);
@@ -671,12 +657,12 @@ struct AtanImpl<double> {
template <>
struct CosSinImpl<float> {
// Rounds float toward zero and returns as int32_t.
- template <class D, class V>
+ template <class D, class V = VFromD<D>, HWY_IF_F32_D(D)>
HWY_INLINE Vec<Rebind<int32_t, D>> ToInt32(D /*unused*/, V x) {
return ConvertTo(Rebind<int32_t, D>(), x);
}
- template <class D, class V>
+ template <class D, class V = VFromD<D>, HWY_IF_F32_D(D)>
HWY_INLINE V Poly(D d, V x) {
const auto k0 = Set(d, -1.66666597127914428710938e-1f);
const auto k1 = Set(d, +8.33307858556509017944336e-3f);
@@ -687,7 +673,8 @@ struct CosSinImpl<float> {
return MulAdd(Estrin(y, k0, k1, k2, k3), Mul(y, x), x);
}
- template <class D, class V, class VI32>
+ template <class D, class V = VFromD<D>, class VI32 = Vec<Rebind<int32_t, D>>,
+ HWY_IF_F32_D(D)>
HWY_INLINE V CosReduce(D d, V x, VI32 q) {
// kHalfPiPart0f + kHalfPiPart1f + kHalfPiPart2f + kHalfPiPart3f ~= -pi/2
const V kHalfPiPart0f = Set(d, -0.5f * 3.140625f);
@@ -704,7 +691,8 @@ struct CosSinImpl<float> {
return x;
}
- template <class D, class V, class VI32>
+ template <class D, class V = VFromD<D>, class VI32 = Vec<Rebind<int32_t, D>>,
+ HWY_IF_F32_D(D)>
HWY_INLINE V SinReduce(D d, V x, VI32 q) {
// kPiPart0f + kPiPart1f + kPiPart2f + kPiPart3f ~= -pi
const V kPiPart0f = Set(d, -3.140625f);
@@ -722,14 +710,14 @@ struct CosSinImpl<float> {
}
// (q & 2) == 0 ? -0.0 : +0.0
- template <class D, class VI32>
+ template <class D, class VI32 = Vec<Rebind<int32_t, D>>, HWY_IF_F32_D(D)>
HWY_INLINE Vec<Rebind<float, D>> CosSignFromQuadrant(D d, VI32 q) {
const VI32 kTwo = Set(Rebind<int32_t, D>(), 2);
return BitCast(d, ShiftLeft<30>(AndNot(q, kTwo)));
}
// ((q & 1) ? -0.0 : +0.0)
- template <class D, class VI32>
+ template <class D, class VI32 = Vec<Rebind<int32_t, D>>, HWY_IF_F32_D(D)>
HWY_INLINE Vec<Rebind<float, D>> SinSignFromQuadrant(D d, VI32 q) {
const VI32 kOne = Set(Rebind<int32_t, D>(), 1);
return BitCast(d, ShiftLeft<31>(And(q, kOne)));
@@ -741,12 +729,12 @@ struct CosSinImpl<float> {
template <>
struct CosSinImpl<double> {
// Rounds double toward zero and returns as int32_t.
- template <class D, class V>
+ template <class D, class V = VFromD<D>, HWY_IF_F64_D(D)>
HWY_INLINE Vec<Rebind<int32_t, D>> ToInt32(D /*unused*/, V x) {
return DemoteTo(Rebind<int32_t, D>(), x);
}
- template <class D, class V>
+ template <class D, class V = VFromD<D>, HWY_IF_F64_D(D)>
HWY_INLINE V Poly(D d, V x) {
const auto k0 = Set(d, -0.166666666666666657414808);
const auto k1 = Set(d, +0.00833333333333332974823815);
@@ -762,7 +750,8 @@ struct CosSinImpl<double> {
return MulAdd(Estrin(y, k0, k1, k2, k3, k4, k5, k6, k7, k8), Mul(y, x), x);
}
- template <class D, class V, class VI32>
+ template <class D, class V = VFromD<D>, class VI32 = Vec<Rebind<int32_t, D>>,
+ HWY_IF_F64_D(D)>
HWY_INLINE V CosReduce(D d, V x, VI32 q) {
// kHalfPiPart0d + kHalfPiPart1d + kHalfPiPart2d + kHalfPiPart3d ~= -pi/2
const V kHalfPiPart0d = Set(d, -0.5 * 3.1415926218032836914);
@@ -779,7 +768,8 @@ struct CosSinImpl<double> {
return x;
}
- template <class D, class V, class VI32>
+ template <class D, class V = VFromD<D>, class VI32 = Vec<Rebind<int32_t, D>>,
+ HWY_IF_F64_D(D)>
HWY_INLINE V SinReduce(D d, V x, VI32 q) {
// kPiPart0d + kPiPart1d + kPiPart2d + kPiPart3d ~= -pi
const V kPiPart0d = Set(d, -3.1415926218032836914);
@@ -797,7 +787,7 @@ struct CosSinImpl<double> {
}
// (q & 2) == 0 ? -0.0 : +0.0
- template <class D, class VI32>
+ template <class D, class VI32 = Vec<Rebind<int32_t, D>>, HWY_IF_F64_D(D)>
HWY_INLINE Vec<Rebind<double, D>> CosSignFromQuadrant(D d, VI32 q) {
const VI32 kTwo = Set(Rebind<int32_t, D>(), 2);
return BitCast(
@@ -805,7 +795,7 @@ struct CosSinImpl<double> {
}
// ((q & 1) ? -0.0 : +0.0)
- template <class D, class VI32>
+ template <class D, class VI32 = Vec<Rebind<int32_t, D>>, HWY_IF_F64_D(D)>
HWY_INLINE Vec<Rebind<double, D>> SinSignFromQuadrant(D d, VI32 q) {
const VI32 kOne = Set(Rebind<int32_t, D>(), 1);
return BitCast(
@@ -818,18 +808,18 @@ struct CosSinImpl<double> {
template <>
struct ExpImpl<float> {
// Rounds float toward zero and returns as int32_t.
- template <class D, class V>
+ template <class D, class V = VFromD<D>, HWY_IF_F32_D(D)>
HWY_INLINE Vec<Rebind<int32_t, D>> ToInt32(D /*unused*/, V x) {
return ConvertTo(Rebind<int32_t, D>(), x);
}
// Rounds float to nearest int32_t
- template <class D, class V>
+ template <class D, class V = VFromD<D>, HWY_IF_F32_D(D)>
HWY_INLINE Vec<Rebind<int32_t, D>> ToNearestInt32(D /*unused*/, V x) {
return NearestInt(x);
}
- template <class D, class V>
+ template <class D, class V = VFromD<D>, HWY_IF_F32_D(D)>
HWY_INLINE V ExpPoly(D d, V x) {
const auto k0 = Set(d, +0.5f);
const auto k1 = Set(d, +0.166666671633720397949219f);
@@ -842,7 +832,7 @@ struct ExpImpl<float> {
}
// Computes 2^x, where x is an integer.
- template <class D, class VI32>
+ template <class D, class VI32 = Vec<Rebind<int32_t, D>>, HWY_IF_F32_D(D)>
HWY_INLINE Vec<D> Pow2I(D d, VI32 x) {
const Rebind<int32_t, D> di32;
const VI32 kOffset = Set(di32, 0x7F);
@@ -850,13 +840,15 @@ struct ExpImpl<float> {
}
// Sets the exponent of 'x' to 2^e.
- template <class D, class V, class VI32>
+ template <class D, class V = VFromD<D>, class VI32 = Vec<Rebind<int32_t, D>>,
+ HWY_IF_F32_D(D)>
HWY_INLINE V LoadExpShortRange(D d, V x, VI32 e) {
const VI32 y = ShiftRight<1>(e);
return Mul(Mul(x, Pow2I(d, y)), Pow2I(d, Sub(e, y)));
}
- template <class D, class V, class VI32>
+ template <class D, class V = VFromD<D>, class VI32 = Vec<Rebind<int32_t, D>>,
+ HWY_IF_F32_D(D)>
HWY_INLINE V ExpReduce(D d, V x, VI32 q) {
// kLn2Part0f + kLn2Part1f ~= -ln(2)
const V kLn2Part0f = Set(d, -0.693145751953125f);
@@ -869,7 +861,8 @@ struct ExpImpl<float> {
return x;
}
- template <class D, class V, class VI32>
+ template <class D, class V = VFromD<D>, class VI32 = Vec<Rebind<int32_t, D>>,
+ HWY_IF_F32_D(D)>
HWY_INLINE V Exp2Reduce(D d, V x, VI32 q) {
const V x_frac = Sub(x, ConvertTo(d, q));
return MulAdd(x_frac, Set(d, 0.193147182464599609375f),
@@ -879,8 +872,9 @@ struct ExpImpl<float> {
template <>
struct LogImpl<float> {
- template <class D, class V>
- HWY_INLINE Vec<Rebind<int32_t, D>> Log2p1NoSubnormal(D /*d*/, V x) {
+ template <class D, class V = VFromD<D>, HWY_IF_F32_D(D)>
+ HWY_INLINE Vec<Rebind<int32_t, D>> Log2p1NoSubnormal(
+ D /*d*/, Vec<Rebind<int32_t, D>> x) {
const Rebind<int32_t, D> di32;
const Rebind<uint32_t, D> du32;
const auto kBias = Set(di32, 0x7F);
@@ -888,7 +882,7 @@ struct LogImpl<float> {
}
// Approximates Log(x) over the range [sqrt(2) / 2, sqrt(2)].
- template <class D, class V>
+ template <class D, class V = VFromD<D>, HWY_IF_F32_D(D)>
HWY_INLINE V LogPoly(D d, V x) {
const V k0 = Set(d, 0.66666662693f);
const V k1 = Set(d, 0.40000972152f);
@@ -905,18 +899,18 @@ struct LogImpl<float> {
template <>
struct ExpImpl<double> {
// Rounds double toward zero and returns as int32_t.
- template <class D, class V>
+ template <class D, class V = VFromD<D>, HWY_IF_F64_D(D)>
HWY_INLINE Vec<Rebind<int32_t, D>> ToInt32(D /*unused*/, V x) {
return DemoteTo(Rebind<int32_t, D>(), x);
}
// Rounds double to nearest int32_t
- template <class D, class V>
+ template <class D, class V = VFromD<D>, HWY_IF_F64_D(D)>
HWY_INLINE Vec<Rebind<int32_t, D>> ToNearestInt32(D /*unused*/, V x) {
return DemoteToNearestInt(Rebind<int32_t, D>(), x);
}
- template <class D, class V>
+ template <class D, class V = VFromD<D>, HWY_IF_F64_D(D)>
HWY_INLINE V ExpPoly(D d, V x) {
const auto k0 = Set(d, +0.5);
const auto k1 = Set(d, +0.166666666666666851703837);
@@ -935,7 +929,7 @@ struct ExpImpl<double> {
}
// Computes 2^x, where x is an integer.
- template <class D, class VI32>
+ template <class D, class VI32 = Vec<Rebind<int32_t, D>>, HWY_IF_F64_D(D)>
HWY_INLINE Vec<D> Pow2I(D d, VI32 x) {
const Rebind<int32_t, D> di32;
const Rebind<int64_t, D> di64;
@@ -944,13 +938,15 @@ struct ExpImpl<double> {
}
// Sets the exponent of 'x' to 2^e.
- template <class D, class V, class VI32>
+ template <class D, class V = VFromD<D>, class VI32 = Vec<Rebind<int32_t, D>>,
+ HWY_IF_F64_D(D)>
HWY_INLINE V LoadExpShortRange(D d, V x, VI32 e) {
const VI32 y = ShiftRight<1>(e);
return Mul(Mul(x, Pow2I(d, y)), Pow2I(d, Sub(e, y)));
}
- template <class D, class V, class VI32>
+ template <class D, class V = VFromD<D>, class VI32 = Vec<Rebind<int32_t, D>>,
+ HWY_IF_F64_D(D)>
HWY_INLINE V ExpReduce(D d, V x, VI32 q) {
// kLn2Part0d + kLn2Part1d ~= -ln(2)
const V kLn2Part0d = Set(d, -0.6931471805596629565116018);
@@ -963,7 +959,8 @@ struct ExpImpl<double> {
return x;
}
- template <class D, class V, class VI32>
+ template <class D, class V = VFromD<D>, class VI32 = Vec<Rebind<int32_t, D>>,
+ HWY_IF_F64_D(D)>
HWY_INLINE V Exp2Reduce(D d, V x, VI32 q) {
const V x_frac = Sub(x, PromoteTo(d, q));
return MulAdd(x_frac, Set(d, 0.1931471805599453139823396),
@@ -973,8 +970,9 @@ struct ExpImpl<double> {
template <>
struct LogImpl<double> {
- template <class D, class V>
- HWY_INLINE Vec<Rebind<int64_t, D>> Log2p1NoSubnormal(D /*d*/, V x) {
+ template <class D, class V = VFromD<D>, HWY_IF_F64_D(D)>
+ HWY_INLINE Vec<Rebind<int64_t, D>> Log2p1NoSubnormal(
+ D /*d*/, Vec<Rebind<int64_t, D>> x) {
const Rebind<int64_t, D> di64;
const Rebind<uint64_t, D> du64;
return Sub(BitCast(di64, ShiftRight<52>(BitCast(du64, x))),
@@ -982,7 +980,7 @@ struct LogImpl<double> {
}
// Approximates Log(x) over the range [sqrt(2) / 2, sqrt(2)].
- template <class D, class V>
+ template <class D, class V = VFromD<D>, HWY_IF_F64_D(D)>
HWY_INLINE V LogPoly(D d, V x) {
const V k0 = Set(d, 0.6666666666666735130);
const V k1 = Set(d, 0.3999999999940941908);
@@ -1028,8 +1026,10 @@ HWY_INLINE V Log(const D d, V x) {
: static_cast<TI>(0xFFFFFFFFLL));
const VI kMagic = Set(di, kIsF32 ? static_cast<TI>(0x3F3504F3L)
: static_cast<TI>(0x3FE6A09E00000000LL));
- const VI kExpMask = Set(di, kIsF32 ? static_cast<TI>(0x3F800000L)
- : static_cast<TI>(0x3FF0000000000000LL));
+ const VI kExpMagicDiff = Set(
+ di, kIsF32
+ ? static_cast<TI>(0x3F800000L - 0x3F3504F3L)
+ : static_cast<TI>(0x3FF0000000000000LL - 0x3FE6A09E00000000LL));
const VI kExpScale =
Set(di, kIsF32 ? static_cast<TI>(-25) : static_cast<TI>(-54));
const VI kManMask = Set(di, kIsF32 ? static_cast<TI>(0x7FFFFFL)
@@ -1043,15 +1043,14 @@ HWY_INLINE V Log(const D d, V x) {
x = IfThenElse(is_denormal, Mul(x, kScale), x);
// Compute the new exponent.
- exp_bits = Add(BitCast(di, x), Sub(kExpMask, kMagic));
+ exp_bits = Add(BitCast(di, x), kExpMagicDiff);
const VI exp_scale =
BitCast(di, IfThenElseZero(is_denormal, BitCast(d, kExpScale)));
- exp = ConvertTo(
- d, Add(exp_scale, impl.Log2p1NoSubnormal(d, BitCast(d, exp_bits))));
+ exp = ConvertTo(d, Add(exp_scale, impl.Log2p1NoSubnormal(d, exp_bits)));
} else {
// Compute the new exponent.
- exp_bits = Add(BitCast(di, x), Sub(kExpMask, kMagic));
- exp = ConvertTo(d, impl.Log2p1NoSubnormal(d, BitCast(d, exp_bits)));
+ exp_bits = Add(BitCast(di, x), kExpMagicDiff);
+ exp = ConvertTo(d, impl.Log2p1NoSubnormal(d, exp_bits));
}
// Renormalize.
@@ -1072,7 +1071,7 @@ HWY_INLINE V Log(const D d, V x) {
// http://gruntthepeon.free.fr/ssemath/
// Third degree poly
-template <class D, class V>
+template <class D, class V = VFromD<D>>
HWY_INLINE void SinCos3(D d, TFromD<D> dp1, TFromD<D> dp2, TFromD<D> dp3, V x,
V& s, V& c) {
using T = TFromD<D>;
@@ -1171,7 +1170,7 @@ HWY_INLINE void SinCos3(D d, TFromD<D> dp1, TFromD<D> dp2, TFromD<D> dp3, V x,
}
// Sixth degree poly
-template <class D, class V>
+template <class D, class V = VFromD<D>>
HWY_INLINE void SinCos6(D d, TFromD<D> dp1, TFromD<D> dp2, TFromD<D> dp3, V x,
V& s, V& c) {
using T = TFromD<D>;
@@ -1283,7 +1282,7 @@ HWY_INLINE void SinCos6(D d, TFromD<D> dp1, TFromD<D> dp2, TFromD<D> dp3, V x,
template <>
struct SinCosImpl<float> {
- template <class D, class V>
+ template <class D, class V = VFromD<D>, HWY_IF_F32_D(D)>
HWY_INLINE void SinCos(D d, V x, V& s, V& c) {
SinCos3(d, -0.78515625f, -2.4187564849853515625e-4f,
-3.77489497744594108e-8f, x, s, c);
@@ -1293,7 +1292,7 @@ struct SinCosImpl<float> {
#if HWY_HAVE_FLOAT64 && HWY_HAVE_INTEGER64
template <>
struct SinCosImpl<double> {
- template <class D, class V>
+ template <class D, class V = VFromD<D>, HWY_IF_F64_D(D)>
HWY_INLINE void SinCos(D d, V x, V& s, V& c) {
SinCos6(d, -7.85398125648498535156E-1, -3.77489470793079817668E-8,
-2.69515142907905952645E-15, x, s, c);
@@ -1427,6 +1426,39 @@ HWY_INLINE V Atan(const D d, V x) {
return Or(IfThenElse(mask, Sub(kPiOverTwo, y), y), sign);
}
+template <class D, class V>
+HWY_INLINE V Atan2(const D d, V y, V x) {
+ using T = TFromD<D>;
+ using M = MFromD<D>;
+
+ const V kPi = Set(d, static_cast<T>(3.14159265358979323846264));
+ const V kPiOverTwo = Set(d, static_cast<T>(1.57079632679489661923132169));
+ const V kOne = Set(d, static_cast<T>(1.0));
+ const V k0 = Zero(d);
+
+ const V ax = Abs(x);
+ const V ay = Abs(y);
+
+ const V num = Min(ax, ay);
+ const V den = Max(ax, ay);
+
+ const M is_inf = IsInf(num);
+ V mapped_y = MaskedDivOr(k0, Ne(den, k0), num, den);
+ mapped_y = IfThenElse(is_inf, kOne, mapped_y);
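+  // mapped_y = min(|x|, |y|) / max(|x|, |y|) lies in [0, 1], the domain of
+  // AtanPoly; inf/inf lanes (num == den == inf) map to 1, i.e. atan -> pi/4.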
+
+ impl::AtanImpl<T> impl;
+ const V poly = impl.AtanPoly(d, mapped_y);
+
+ const M ay_gt_ax = Gt(ay, ax);
+ V angle = MaskedSubOr(poly, ay_gt_ax, kPiOverTwo, poly);
+
+ const M x_neg = Lt(x, k0);
+ angle = MaskedSubOr(angle, x_neg, kPi, angle);
+
+ const M is_nan = IsEitherNaN(y, x);
+ return IfThenElse(is_nan, NaN(d), CopySign(angle, y));
+}
+
template <class D, class V>
HWY_INLINE V Atanh(const D d, V x) {
using T = TFromD<D>;
@@ -1440,6 +1472,151 @@ HWY_INLINE V Atanh(const D d, V x) {
Xor(kHalf, sign));
}
+namespace impl {
+
+// Barrett reduction (n/3) via MulHigh by 0x5556, repartitions to u16 lanes
+template <class DI, class VI = decltype(Zero(DI()))>
+HWY_INLINE void CbrtDivMod3(DI di, VI exp_shifted, VI& div, VI& mod) {
+ const Repartition<uint16_t, DI> du16;
+ using VU16 = decltype(Zero(du16));
+ const VU16 exp_u16 = BitCast(du16, exp_shifted);
+ const VU16 div_u16 =
+ MulHigh(exp_u16, Set(du16, static_cast<uint16_t>(0x5556)));
+ const VU16 mod_u16 =
+ Sub(exp_u16, Mul(div_u16, Set(du16, static_cast<uint16_t>(3))));
+ div = BitCast(di, div_u16);
+ mod = BitCast(di, mod_u16);
+}
+
+// Single-lane fallback, Barrett reduction on the int lane with no Repartition
+template <class DI, class VI = decltype(Zero(DI()))>
+HWY_INLINE void CbrtDivMod3Scalar(DI di, VI exp_shifted, VI& div, VI& mod) {
+ using TI = TFromD<DI>;
+ div = ShiftRight<16>(Mul(exp_shifted, Set(di, static_cast<TI>(0x5556))));
+ mod = Sub(exp_shifted, Mul(div, Set(di, static_cast<TI>(3))));
+}
+
+} // namespace impl
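+
+// Illustrative check of the 0x5556 Barrett constant (not part of this patch):
+// (n * 0x5556) >> 16 == n / 3 holds exactly for all values this code can
+// produce (biased, shifted exponents stay far below the ~32767 limit):
+//
+//   for (uint32_t n = 0; n < 3000; ++n) {
+//     assert(((n * 0x5556u) >> 16) == n / 3);
+//   }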
+
+// Modified from BSD-licensed code
+// Copyright (c) the JPEG XL Project Authors. All rights reserved.
+// See https://github.com/libjxl/libjxl/blob/main/LICENSE.
+template <bool kHandleSubnormals, class D, class V>
+HWY_INLINE V Cbrt(const D d, V x) {
+ using T = TFromD<D>;
+
+ const V sign = And(SignBit(d), x);
+ const V abs_x = Xor(x, sign);
+ x = abs_x;
+
+ constexpr bool kIsF32 = (sizeof(T) == 4);
+
+ MFromD<D> is_denormal;
+ if constexpr (kHandleSubnormals) {
+ const V kMinNormal = Set(d, SmallestNormal<T>());
+    // Scale subnormals by 2^24 or 2^54; the exponent is divisible by 3, so
+    // the cube root of the scale is an exact power of two.
+ const V kScale = Set(d, kIsF32 ? static_cast<T>(16777216.0f)
+ : static_cast<T>(18014398509481984.0));
+ is_denormal = Lt(x, kMinNormal);
+ x = MaskedMulOr(x, is_denormal, x, kScale);
+ } else {
+ (void)is_denormal;
+ }
+
+ V y;
+ const RebindToSigned<D> di;
+ using TI = TFromD<decltype(di)>;
+ using VI = decltype(Zero(di));
+
+ const VI x_int = BitCast(di, x);
+
+  // Extract the exponent and bias it so the unbiased exponent is offset by
+  // 3*128 (f32) or 3*512 (f64), keeping it non-negative for Barrett reduction.
+ const VI exp_shifted =
+ Add(ShiftRight < kIsF32 ? 23 : 52 > (x_int),
+ Set(di, kIsF32 ? static_cast<TI>(257) : static_cast<TI>(513)));
+
+ VI exp_shifted_div_3;
+ VI exp_mod_3;
+ if constexpr ((HWY_MAX_LANES_D(D) > 1 && !HWY_HAVE_SCALABLE) ||
+ (HWY_HAVE_SCALABLE && detail::IsFull(d))) {
+ impl::CbrtDivMod3(di, exp_shifted, exp_shifted_div_3, exp_mod_3);
+ } else if constexpr (HWY_MAX_LANES_D(D) > 1) {
+ if (Lanes(d) > 1) {
+ impl::CbrtDivMod3(di, exp_shifted, exp_shifted_div_3, exp_mod_3);
+ } else {
+ impl::CbrtDivMod3Scalar(di, exp_shifted, exp_shifted_div_3, exp_mod_3);
+ }
+ } else {
+ impl::CbrtDivMod3Scalar(di, exp_shifted, exp_shifted_div_3, exp_mod_3);
+ }
+
+  // Undo the constant shift added above to keep the value non-negative.
+ const VI neg_exp_div_3 =
+ Sub(Set(di, kIsF32 ? static_cast<TI>(128) : static_cast<TI>(512)),
+ exp_shifted_div_3);
+ // Combine exp mod 3 index with the top mantissa bits
+ const VI top_mant =
+ And(ShiftRight < kIsF32 ? 22 : 50 > (x_int),
+ Set(di, kIsF32 ? static_cast<TI>(1) : static_cast<TI>(3)));
+ const VI idx = Add(ShiftLeft < kIsF32 ? 1 : 2 > (exp_mod_3), top_mant);
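+  // f32: idx = 2*(exp mod 3) + 1 mantissa bit  -> 6 bins;
+  // f64: idx = 4*(exp mod 3) + 2 mantissa bits -> 12 bins (tables below).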
+
+ V r;
+ if constexpr (kIsF32) {
+ // (1/cbrt(lo) + 1/cbrt(hi))/2 over 6 bins of [1,8)
+ HWY_ALIGN static constexpr float initial_guess[8] = {
+ 0.92807984f, 0.81504166f, 0.73603648f, 0.65004617f,
+ 0.58375800f, 0.51406258f, 0.0f, 0.0f};
+ if constexpr ((HWY_MAX_LANES_D(D) >= 4 && !HWY_HAVE_SCALABLE) ||
+ (HWY_HAVE_SCALABLE && sizeof(T) == 4 && detail::IsFull(d))) {
+ r = Lookup8(d, initial_guess, idx);
+ } else {
+ r = GatherIndex(d, initial_guess, idx);
+ }
+ } else {
+ // (1/cbrt(lo) + 1/cbrt(hi))/2 over 12 bins of [1,8)
+ HWY_ALIGN static constexpr double initial_guess[12] = {
+ 9.6415888336127797e-01, 9.0094911572942737e-01, 8.5170349905127118e-01,
+ 8.1176352967517162e-01, 7.6525341285608861e-01, 7.1508378703935604e-01,
+ 6.7599751517949214e-01, 6.4429714047789299e-01, 6.0738203629500487e-01,
+ 5.6756237789583885e-01, 5.3653958336190732e-01, 5.1137897928735510e-01};
+ r = GatherIndex(d, initial_guess, idx);
+ }
+
+ // Apply 2^(-exp/3) to scale lookup result to 1/cbrt(x).
+ r = MulByPow2(r, neg_exp_div_3);
+
+ const V kOneThird = Set(d, static_cast<T>(1.0 / 3.0));
+ const V kFourThirds = Set(d, static_cast<T>(4.0 / 3.0));
+ const V x_div_3 = Mul(kOneThird, x);
+ constexpr size_t kIters = kIsF32 ? 2 : 3;
+ // Newton iteration for 1/cbrt(x): r = r * (4/3 - (x/3) * r^3).
+ for (size_t i = 0; i < kIters; ++i) {
+ const V r2 = Mul(r, r);
+ const V x_div_3_r = Mul(x_div_3, r);
+ r = Mul(r, NegMulAdd(x_div_3_r, r2, kFourThirds));
+ }
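+  // Derivation (for illustration): with f(r) = r^-3 - x, Newton's update is
+  // r' = r - f(r)/f'(r) = r * (4 - x*r^3) / 3 = r * (4/3 - (x/3)*r^3),
+  // matching the loop above; the error is roughly squared each iteration.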
+
+ // Fused finalizer: y = r*r*x * (5/3 - (2/3) * r*r*r*x).
+ const V kFiveThirds = Set(d, static_cast<T>(5.0 / 3.0));
+ const V kTwoThirds = Set(d, static_cast<T>(2.0 / 3.0));
+ const V y0 = Mul(Mul(r, r), x);
+ const V h = Mul(y0, r);
+ y = Mul(y0, NegMulAdd(kTwoThirds, h, kFiveThirds));
+
+ if constexpr (kHandleSubnormals) {
+ // 1 / cbrt(kScale), 1 / 2^8 or 1 / 2^18
+ const auto kUnscale = Set(d, kIsF32 ? static_cast<T>(1.0 / 256.0)
+ : static_cast<T>(1.0 / 262144.0));
+ y = MaskedMulOr(y, is_denormal, y, kUnscale);
+ }
+
+ y = IfThenElse(Or(Eq(abs_x, Zero(d)), Not(IsFinite(abs_x))), abs_x, y);
+
+ y = Or(y, sign);
+ return y;
+}
+
template <class D, class V>
HWY_INLINE V Cos(const D d, V x) {
using T = TFromD<D>;
@@ -1548,11 +1725,17 @@ HWY_INLINE V Log1p(const D d, V x) {
const V kOne = Set(d, static_cast<T>(+1.0));
const V y = Add(x, kOne);
- const auto is_pole = Eq(y, kOne);
- const auto divisor = Sub(IfThenZeroElse(is_pole, y), kOne);
+ const Mask<D> not_pole = Ne(y, kOne);
+ // If y == 1, divisor becomes 1 (dummy), avoiding division by zero.
+ const V divisor = MaskedSubOr(y, not_pole, y, kOne);
+ // Ensure exactly 1.0 when x == divisor. This is necessary because some
+ // platforms (like Armv7) use Newton-Raphson for division, which can return
+  // 0.0 instead of 1.0 when the reciprocal calculation underflows
+ // for very large x.
+ const V div_res = MaskedDivOr(kOne, Ne(x, divisor), x, divisor);
const auto non_pole =
- Mul(impl::Log<D, V, /*kAllowSubnormals=*/false>(d, y), Div(x, divisor));
- return IfThenElse(is_pole, x, non_pole);
+ Mul(impl::Log<D, V, /*kAllowSubnormals=*/false>(d, y), div_res);
+ return IfThenElse(not_pole, non_pole, x);
}
template <class D, class V>
diff --git a/third_party/highway/hwy/contrib/math/math_benchmark.cc b/third_party/highway/hwy/contrib/math/math_benchmark.cc
new file mode 100644
index 0000000000..c5b48a6a0c
--- /dev/null
+++ b/third_party/highway/hwy/contrib/math/math_benchmark.cc
@@ -0,0 +1,326 @@
+// Copyright 2026 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// clang-format off
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "hwy/contrib/math/math_benchmark.cc"
+#include "third_party/highway/hwy/foreach_target.h" // IWYU pragma: keep
+#include "third_party/highway/hwy/highway.h"
+#include "third_party/highway/hwy/contrib/math/math-inl.h"
+#include "third_party/highway/hwy/contrib/math/fast_math-inl.h"
+#include "third_party/highway/hwy/nanobenchmark.h"
+#include "third_party/highway/hwy/tests/test_util-inl.h"
+// clang-format on
+
+#include <stdio.h>
+#include <string.h>
+
+// ============================================================================
+// You can dynamically select two Highway math functions from the list below
+// to benchmark by passing their names via the command-line flags `--fxn1`
+// and `--fxn2`. If no flags are passed, the benchmark defaults to `FastExp`
+// and `Exp`.
+//
+// Note:
+// - Either provide 2 function flags or none.
+// - Function names must be from the list of valid functions.
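+//
+// Example invocation (binary name is illustrative):
+//   ./math_benchmark --fxn1=FastLog --fxn2=Log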
+// ============================================================================
+
+namespace hwy {
+extern const char* g_fxn1;
+extern const char* g_fxn2;
+} // namespace hwy
+
+#define HWY_MATH_BENCHMARKS(V) \
+ V(Exp) \
+ V(FastExp) \
+ V(FastExpNormal) \
+ V(Exp2) \
+ V(FastExp2) \
+ V(FastExp2Normal) \
+ V(FastExpMinusOrZero) \
+ V(Log) \
+ V(FastLog) \
+ V(FastLogPositiveNormal) \
+ V(Log2) \
+ V(FastLog2) \
+ V(FastLog2PositiveNormal) \
+ V(Log10) \
+ V(FastLog10) \
+ V(FastLog10PositiveNormal) \
+ V(Log1p) \
+ V(FastLog1p) \
+ V(FastLog1pPositiveNormal) \
+ V(Atan) \
+ V(FastAtan) \
+ V(FastAtanPositive) \
+ V(Tanh) \
+ V(FastTanh) \
+ V(Atan2) \
+ V(FastAtan2)
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+
+namespace hn = hwy::HWY_NAMESPACE;
+
+// Helper to safely convert float/double results to benchmark output
+// without triggering SIGILL on Infinity or compile errors on size mismatch.
+template <class Val>
+hwy::FuncOutput SafeCast(Val v) {
+ using BitsT = hwy::MakeUnsigned<Val>;
+ auto bits = hwy::BitCastScalar<BitsT>(v);
+ return static_cast<hwy::FuncOutput>(bits);
+}
+
+// Macro to define benchmark body
+#define DEFINE_MATH_BENCH(NAME, FUNC, MAP_EXPR) \
+ template <class D> \
+ void Bench##NAME(D d) { \
+ using T = hn::TFromD<D>; \
+ printf("Benchmarking " #NAME " for %s:\n", \
+ hwy::TypeName(T(), hn::Lanes(d)).c_str()); \
+ auto func = [d](const hwy::FuncInput in) -> hwy::FuncOutput { \
+ const double val = MAP_EXPR; \
+ const auto v = hn::Set(d, static_cast<T>(val)); \
+ const auto res = FUNC; \
+ return SafeCast(hn::GetLane(res)); \
+ }; \
+ const size_t kNumInputs = 16; \
+ hwy::FuncInput inputs[kNumInputs]; \
+ for (size_t i = 0; i < kNumInputs; ++i) inputs[i] = i; \
+ hwy::Result results[kNumInputs]; \
+ hwy::Params params = hwy::DefaultBenchmarkParams(); \
+ const size_t num_results = \
+ hwy::MeasureClosure(func, inputs, kNumInputs, results, params); \
+ double sum_ticks = 0; \
+ for (size_t i = 0; i < num_results; ++i) sum_ticks += results[i].ticks; \
+ if (num_results > 0) { \
+ printf(" Avg ticks: %f\n", sum_ticks / num_results); \
+ } \
+ }
+
+// Macro to define benchmark body for 2-element variants
+#define DEFINE_MATH_BENCH_2ARG(NAME, FUNC, MAP_EXPR_Y, MAP_EXPR_X) \
+ template <class D> \
+ void Bench##NAME(D d) { \
+ using T = hn::TFromD<D>; \
+ printf("Benchmarking " #NAME " for %s:\n", \
+ hwy::TypeName(T(), hn::Lanes(d)).c_str()); \
+ auto func = [d](const hwy::FuncInput in) -> hwy::FuncOutput { \
+ const double val_y = MAP_EXPR_Y; \
+ const double val_x = MAP_EXPR_X; \
+ const auto y = hn::Set(d, static_cast<T>(val_y)); \
+ const auto x = hn::Set(d, static_cast<T>(val_x)); \
+ const auto res = FUNC; \
+ return SafeCast(hn::GetLane(res)); \
+ }; \
+ const size_t kNumInputs = 16; \
+ hwy::FuncInput inputs[kNumInputs]; \
+ for (size_t i = 0; i < kNumInputs; ++i) inputs[i] = i; \
+ hwy::Result results[kNumInputs]; \
+ hwy::Params params = hwy::DefaultBenchmarkParams(); \
+ const size_t num_results = \
+ hwy::MeasureClosure(func, inputs, kNumInputs, results, params); \
+ double sum_ticks = 0; \
+ for (size_t i = 0; i < num_results; ++i) sum_ticks += results[i].ticks; \
+ if (num_results > 0) { \
+ printf(" Avg ticks: %f\n", sum_ticks / num_results); \
+ } \
+ }
+
+// Exp / FastExp
+DEFINE_MATH_BENCH(CallExp, hn::CallExp(d, v),
+ -10.0 + static_cast<double>(in) * (20.0 / 15.0))
+DEFINE_MATH_BENCH(CallFastExp, hn::CallFastExp(d, v),
+ -10.0 + static_cast<double>(in) * (20.0 / 15.0))
+DEFINE_MATH_BENCH(CallFastExpNormal, hn::CallFastExpNormal(d, v),
+ -10.0 + static_cast<double>(in) * (20.0 / 15.0))
+
+// Exp2 / FastExp2
+DEFINE_MATH_BENCH(CallExp2, hn::CallExp2(d, v),
+ -10.0 + static_cast<double>(in) * (20.0 / 15.0))
+DEFINE_MATH_BENCH(CallFastExp2, hn::CallFastExp2(d, v),
+ -10.0 + static_cast<double>(in) * (20.0 / 15.0))
+DEFINE_MATH_BENCH(CallFastExp2Normal, hn::CallFastExp2Normal(d, v),
+ -10.0 + static_cast<double>(in) * (20.0 / 15.0))
+
+// FastExpMinusOrZero
+DEFINE_MATH_BENCH(CallFastExpMinusOrZero, hn::CallFastExpMinusOrZero(d, v),
+ -10.0 + static_cast<double>(in) * (10.0 / 15.0))
+
+// Log / FastLog
+DEFINE_MATH_BENCH(CallLog, hn::CallLog(d, v),
+ 0.1 + static_cast<double>(in) * 1.0)
+DEFINE_MATH_BENCH(CallFastLog, hn::CallFastLog(d, v),
+ 0.1 + static_cast<double>(in) * 1.0)
+DEFINE_MATH_BENCH(CallFastLogPositiveNormal,
+ hn::CallFastLogPositiveNormal(d, v),
+ 0.1 + static_cast<double>(in) * 1.0)
+
+// Log2 / FastLog2
+DEFINE_MATH_BENCH(CallLog2, hn::CallLog2(d, v),
+ 0.1 + static_cast<double>(in) * 1.0)
+DEFINE_MATH_BENCH(CallFastLog2, hn::CallFastLog2(d, v),
+ 0.1 + static_cast<double>(in) * 1.0)
+DEFINE_MATH_BENCH(CallFastLog2PositiveNormal,
+ hn::CallFastLog2PositiveNormal(d, v),
+ 0.1 + static_cast<double>(in) * 1.0)
+
+// Log10 / FastLog10
+DEFINE_MATH_BENCH(CallLog10, hn::CallLog10(d, v),
+ 0.1 + static_cast<double>(in) * 1.0)
+DEFINE_MATH_BENCH(CallFastLog10, hn::CallFastLog10(d, v),
+ 0.1 + static_cast<double>(in) * 1.0)
+DEFINE_MATH_BENCH(CallFastLog10PositiveNormal,
+ hn::CallFastLog10PositiveNormal(d, v),
+ 0.1 + static_cast<double>(in) * 1.0)
+
+// Log1p / FastLog1p
+DEFINE_MATH_BENCH(CallLog1p, hn::CallLog1p(d, v),
+ 0.1 + static_cast<double>(in) * 1.0)
+DEFINE_MATH_BENCH(CallFastLog1p, hn::CallFastLog1p(d, v),
+ 0.1 + static_cast<double>(in) * 1.0)
+DEFINE_MATH_BENCH(CallFastLog1pPositiveNormal,
+ hn::CallFastLog1pPositiveNormal(d, v),
+ 0.1 + static_cast<double>(in) * 1.0)
+
+// Atan / FastAtan
+DEFINE_MATH_BENCH(CallAtan, hn::CallAtan(d, v),
+ -10.0 + static_cast<double>(in) * (20.0 / 15.0))
+DEFINE_MATH_BENCH(CallFastAtan, hn::CallFastAtan(d, v),
+ -10.0 + static_cast<double>(in) * (20.0 / 15.0))
+DEFINE_MATH_BENCH(CallFastAtanPositive, hn::CallFastAtanPositive(d, v),
+ 0.1 + static_cast<double>(in) * (20.0 / 15.0))
+
+// Tanh / FastTanh
+DEFINE_MATH_BENCH(CallTanh, hn::CallTanh(d, v),
+ -10.0 + static_cast<double>(in) * (20.0 / 15.0))
+DEFINE_MATH_BENCH(CallFastTanh, hn::CallFastTanh(d, v),
+ -10.0 + static_cast<double>(in) * (20.0 / 15.0))
+
+// Atan2 / FastAtan2
+DEFINE_MATH_BENCH_2ARG(CallAtan2, hn::CallAtan2(d, y, x),
+ -10.0 + static_cast<double>(in / 4) * (20.0 / 3.0),
+ -10.0 + static_cast<double>(in % 4) * (20.0 / 3.0))
+DEFINE_MATH_BENCH_2ARG(CallFastAtan2, hn::CallFastAtan2(d, y, x),
+ -10.0 + static_cast<double>(in / 4) * (20.0 / 3.0),
+ -10.0 + static_cast<double>(in % 4) * (20.0 / 3.0))
+
+struct RunBenchmarks {
+ template <class T, class D>
+ HWY_NOINLINE void operator()(T, D d) {
+ auto run_fxn = [&](const char* fxn) {
+ struct BenchmarkTable {
+ const char* name;
+ void (*func)(D);
+ };
+
+ static const BenchmarkTable table[] = {
+#define V(NAME) {#NAME, &BenchCall##NAME<D>},
+ HWY_MATH_BENCHMARKS(V)
+#undef V
+ };
+
+ for (const auto& entry : table) {
+ if (strcmp(fxn, entry.name) == 0) {
+ entry.func(d);
+ return;
+ }
+ }
+ };
+
+ if (hwy::g_fxn1) run_fxn(hwy::g_fxn1);
+ if (hwy::g_fxn2) run_fxn(hwy::g_fxn2);
+ }
+};
+
+HWY_NOINLINE void RunAllBenchmarks() {
+ hn::ForFloat3264Types(hn::ForPartialVectors<RunBenchmarks>());
+}
+
+} // namespace HWY_NAMESPACE
+} // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace hwy {
+const char* g_fxn1 = nullptr;
+const char* g_fxn2 = nullptr;
+
+namespace {
+HWY_EXPORT(RunAllBenchmarks);
+} // namespace
+} // namespace hwy
+
+int main(int argc, char** argv) {
+ const char* fxn1 = nullptr;
+ const char* fxn2 = nullptr;
+
+ for (int i = 1; i < argc; ++i) {
+ if (strncmp(argv[i], "--fxn1=", 7) == 0) {
+ fxn1 = argv[i] + 7;
+ } else if (strncmp(argv[i], "--fxn2=", 7) == 0) {
+ fxn2 = argv[i] + 7;
+ }
+ }
+
+ if ((fxn1 == nullptr) != (fxn2 == nullptr)) {
+ fprintf(stderr,
+ "Error: Either pass both --fxn1 and --fxn2, or pass neither to use "
+ "defaults.\n");
+ return 1;
+ }
+
+ if (fxn1 == nullptr && fxn2 == nullptr) {
+ fxn1 = "FastExp";
+ fxn2 = "Exp";
+ }
+
+ const char* valid_fxns[] = {
+#define V(NAME) #NAME,
+ HWY_MATH_BENCHMARKS(V)
+#undef V
+ };
+
+ bool valid1 = false;
+ bool valid2 = false;
+ for (const char* v : valid_fxns) {
+ if (strcmp(fxn1, v) == 0) valid1 = true;
+ if (strcmp(fxn2, v) == 0) valid2 = true;
+ }
+
+ if (!valid1 || !valid2) {
+ fprintf(stderr,
+ "Error: One or both function names do not match any defined "
+ "benchmark functions.\n");
+ fprintf(stderr, "Valid functions are: ");
+ const size_t num_valid = sizeof(valid_fxns) / sizeof(valid_fxns[0]);
+ for (size_t i = 0; i < num_valid; ++i) {
+ fprintf(stderr, "%s%s", valid_fxns[i],
+ (i == num_valid - 1) ? ".\n" : ", ");
+ }
+ return 1;
+ }
+
+ hwy::g_fxn1 = fxn1;
+ hwy::g_fxn2 = fxn2;
+
+ HWY_DYNAMIC_DISPATCH(hwy::RunAllBenchmarks)();
+ return 0;
+}
+#endif // HWY_ONCE
diff --git a/third_party/highway/hwy/contrib/math/math_hyper_test.cc b/third_party/highway/hwy/contrib/math/math_hyper_test.cc
new file mode 100644
index 0000000000..b7669aa322
--- /dev/null
+++ b/third_party/highway/hwy/contrib/math/math_hyper_test.cc
@@ -0,0 +1,355 @@
+// Copyright 2020 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdint.h>
+#include <stdio.h>
+
+#include <cfloat> // FLT_MAX
+#include <cmath> // std::abs
+
+#include "third_party/highway/hwy/base.h"
+
+// clang-format off
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "hwy/contrib/math/math_hyper_test.cc"
+#include "third_party/highway/hwy/foreach_target.h" // IWYU pragma: keep
+#include "third_party/highway/hwy/highway.h"
+#include "third_party/highway/hwy/contrib/math/math-inl.h"
+#include "third_party/highway/hwy/contrib/math/fast_math-inl.h"
+#include "third_party/highway/hwy/tests/test_util-inl.h"
+// clang-format on
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+namespace {
+
+// We have had test failures caused by excess precision due to keeping
+// intermediate results in 80-bit x87 registers. One such failure mode is that
+// Log1p computes a 1.0 which is not exactly equal to 1.0f, causing is_pole to
+// incorrectly evaluate to false.
+#undef HWY_MATH_TEST_EXCESS_PRECISION
+#if HWY_ARCH_X86_32 && HWY_COMPILER_GCC_ACTUAL && \
+ (HWY_TARGET == HWY_SCALAR || HWY_TARGET == HWY_EMU128)
+
+// GCC 13+: because CMAKE_CXX_EXTENSIONS is OFF, we build with -std= and hence
+// also -fexcess-precision=standard, so there is no problem. See #1708 and
+// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=323.
+#if HWY_COMPILER_GCC_ACTUAL >= 1300
+#define HWY_MATH_TEST_EXCESS_PRECISION 0
+
+#else // HWY_COMPILER_GCC_ACTUAL < 1300
+
+// The build system must enable SSE2, e.g. via HWY_CMAKE_SSE2 - see
+// https://stackoverflow.com/questions/20869904/c-handling-of-excess-precision .
+#if defined(__SSE2__) // correct flag given, no problem
+#define HWY_MATH_TEST_EXCESS_PRECISION 0
+#else
+#define HWY_MATH_TEST_EXCESS_PRECISION 1
+#pragma message( \
+ "Skipping scalar math_test on 32-bit x86 GCC <13 without HWY_CMAKE_SSE2")
+#endif // defined(__SSE2__)
+
+#endif // HWY_COMPILER_GCC_ACTUAL
+#else // not (x86-32, GCC, scalar target): running math_test normally
+#define HWY_MATH_TEST_EXCESS_PRECISION 0
+#endif // HWY_ARCH_X86_32 etc
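+
+// Illustration (not from this file): under x87 80-bit evaluation, a float
+// expression such as `float y = x + 1e-8f; bool eq = (y == 1.0f);` may keep
+// `y` in an extended-precision register, so `eq` can be false even though
+// storing `y` to memory would round it to exactly 1.0f.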
+
+template <class T, class D>
+HWY_NOINLINE void TestMath(const char* name, T (*fx1)(T),
+ Vec<D> (*fxN)(D, VecArg<Vec<D>>), D d, T min, T max,
+ uint64_t max_error_ulp) {
+ if (HWY_MATH_TEST_EXCESS_PRECISION) {
+ static bool once = true;
+ if (once) {
+ once = false;
+ HWY_WARN("Skipping math_test due to GCC issue with excess precision.\n");
+ }
+ return;
+ }
+
+ using UintT = MakeUnsigned<T>;
+
+ const UintT min_bits = BitCastScalar<UintT>(min);
+ const UintT max_bits = BitCastScalar<UintT>(max);
+
+ // If min is negative and max is positive, the range needs to be broken into
+ // two pieces, [+0, max] and [-0, min], otherwise [min, max].
+ int range_count = 1;
+ UintT ranges[2][2] = {{min_bits, max_bits}, {0, 0}};
+ if ((min < 0.0) && (max > 0.0)) {
+ ranges[0][0] = BitCastScalar<UintT>(ConvertScalarTo<T>(+0.0));
+ ranges[0][1] = max_bits;
+ ranges[1][0] = BitCastScalar<UintT>(ConvertScalarTo<T>(-0.0));
+ ranges[1][1] = min_bits;
+ range_count = 2;
+ } else {
+ // If not splitting, ensure we iterate from smaller uint to larger uint.
+ // For negative numbers, min (e.g. -1000) has larger uint representation
+ // than max (e.g. -1).
+ if (ranges[0][0] > ranges[0][1]) {
+ auto tmp = ranges[0][0];
+ ranges[0][0] = ranges[0][1];
+ ranges[0][1] = tmp;
+ }
+ }
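+
+ // Why split at +/-0 (sketch, assuming IEEE-754): for non-negative floats,
+ // the unsigned bit pattern is monotone, e.g. bits(+0.0f) = 0x00000000 and
+ // each increment yields the next larger float up to bits(+inf) = 0x7F800000.
+ // Negative floats start at bits(-0.0f) = 0x80000000 and likewise grow toward
+ // bits(-inf) = 0xFF800000, so one unsigned range cannot cover a mixed-sign
+ // span.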
+
+ uint64_t max_ulp = 0;
+ // Emulation is slower, so we cannot afford as many samples.
+ constexpr UintT kSamplesPerRange = static_cast<UintT>(AdjustedReps(4000));
+ for (int range_index = 0; range_index < range_count; ++range_index) {
+ const UintT start = ranges[range_index][0];
+ const UintT stop = ranges[range_index][1];
+ const UintT step = HWY_MAX(1, ((stop - start) / kSamplesPerRange));
+ for (UintT value_bits = start; value_bits <= stop; value_bits += step) {
+ // For reasons unknown, the HWY_MAX is necessary on RVV; otherwise
+ // value_bits can be less than start, and thus possibly NaN.
+ const T value =
+ BitCastScalar<T>(HWY_MIN(HWY_MAX(start, value_bits), stop));
+ const T actual = GetLane(fxN(d, Set(d, value)));
+ const T expected = fx1(value);
+
+ // Skip small inputs and outputs on Armv7, which flushes subnormals to zero.
+#if HWY_TARGET <= HWY_NEON_WITHOUT_AES && HWY_ARCH_ARM_V7
+ if ((std::abs(value) < 1e-37f) || (std::abs(expected) < 1e-37f)) {
+ continue;
+ }
+#endif
+
+ const auto ulp = hwy::detail::ComputeUlpDelta(actual, expected);
+ max_ulp = HWY_MAX(max_ulp, ulp);
+ if (ulp > max_error_ulp) {
+ fprintf(stderr, "%s: %s(%f) expected %E actual %E ulp %g max ulp %u\n",
+ hwy::TypeName(T(), Lanes(d)).c_str(), name, value, expected,
+ actual, static_cast<double>(ulp),
+ static_cast<uint32_t>(max_error_ulp));
+ }
+ }
+ }
+ fprintf(stderr, "%s: %s max_ulp %g\n", hwy::TypeName(T(), Lanes(d)).c_str(),
+ name, static_cast<double>(max_ulp));
+ HWY_ASSERT(max_ulp <= max_error_ulp);
+}
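+
+// A ULP ("unit in the last place") is the spacing between adjacent
+// representable values at the result's magnitude; e.g. max_error_ulp = 3
+// allows the vector result to differ from the scalar reference by at most
+// three representable values.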
+
+#define DEFINE_MATH_TEST_FUNC(NAME) \
+ HWY_NOINLINE void TestAll##NAME() { \
+ ForFloat3264Types(ForPartialVectors<Test##NAME>()); \
+ }
+
+#undef DEFINE_MATH_TEST
+#define DEFINE_MATH_TEST(NAME, F32x1, F32xN, F32_MIN, F32_MAX, F32_ERROR, \
+ F64x1, F64xN, F64_MIN, F64_MAX, F64_ERROR) \
+ struct Test##NAME { \
+ template <class T, class D> \
+ HWY_NOINLINE void operator()(T, D d) { \
+ if (sizeof(T) == 4) { \
+ TestMath<T, D>(HWY_STR(NAME), F32x1, F32xN, d, F32_MIN, F32_MAX, \
+ F32_ERROR); \
+ } else { \
+ TestMath<T, D>(HWY_STR(NAME), F64x1, F64xN, d, \
+ static_cast<T>(F64_MIN), static_cast<T>(F64_MAX), \
+ F64_ERROR); \
+ } \
+ } \
+ }; \
+ DEFINE_MATH_TEST_FUNC(NAME)
+
+// Floating point values closest to but less than 1.0. Avoid variables with
+// static initializers inside HWY_BEFORE_NAMESPACE/HWY_AFTER_NAMESPACE to
+// ensure target-specific code does not leak into startup code.
+float kNearOneF() { return BitCastScalar<float>(0x3F7FFFFF); }
+double kNearOneD() { return BitCastScalar<double>(0x3FEFFFFFFFFFFFFFULL); }
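+
+// For reference: 0x3F7FFFFF is one ULP below 1.0f (whose bits are
+// 0x3F800000), i.e. 0.99999994f, and 0x3FEFFFFFFFFFFFFF is the analogous
+// double. These keep the atanh test range strictly inside its (-1, 1) domain.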
+
+constexpr uint64_t ACosh32ULP() {
+#if defined(__MINGW32__)
+ return 8;
+#else
+ return 3;
+#endif
+}
+
+// clang-format off
+DEFINE_MATH_TEST(Acosh,
+ std::acosh, CallAcosh, +1.0f, +FLT_MAX, ACosh32ULP(),
+ std::acosh, CallAcosh, +1.0, +DBL_MAX, 3)
+DEFINE_MATH_TEST(Asinh,
+ std::asinh, CallAsinh, -FLT_MAX, +FLT_MAX, 3,
+ std::asinh, CallAsinh, -DBL_MAX, +DBL_MAX, 3)
+// NEON has ULP 4 instead of 3
+DEFINE_MATH_TEST(Atanh,
+ std::atanh, CallAtanh, -kNearOneF(), +kNearOneF(), 4,
+ std::atanh, CallAtanh, -kNearOneD(), +kNearOneD(), 3)
+DEFINE_MATH_TEST(Sinh,
+ std::sinh, CallSinh, -80.0f, +80.0f, 4,
+ std::sinh, CallSinh, -709.0, +709.0, 4)
+DEFINE_MATH_TEST(Tanh,
+ std::tanh, CallTanh, -FLT_MAX, +FLT_MAX, 4,
+ std::tanh, CallTanh, -DBL_MAX, +DBL_MAX, 4)
+// clang-format on
+
+template <class T, class D>
+HWY_NOINLINE void TestMathRelative(const char* name, T (*fx1)(T),
+ Vec<D> (*fxN)(D, VecArg<Vec<D>>), D d, T min,
+ T max, double max_relative_error,
+ uint64_t samples = 4000) {
+ if (HWY_MATH_TEST_EXCESS_PRECISION) {
+ static bool once = true;
+ if (once) {
+ once = false;
+ HWY_WARN("Skipping math_test due to GCC issue with excess precision.\n");
+ }
+ return;
+ }
+
+ using UintT = MakeUnsigned<T>;
+
+ const UintT min_bits = BitCastScalar<UintT>(min);
+ const UintT max_bits = BitCastScalar<UintT>(max);
+
+ // If min is negative and max is positive, the range needs to be broken into
+ // two pieces, [+0, max] and [-0, min], otherwise [min, max].
+ int range_count = 1;
+ UintT ranges[2][2] = {{min_bits, max_bits}, {0, 0}};
+ if ((min < 0.0) && (max > 0.0)) {
+ ranges[0][0] = BitCastScalar<UintT>(ConvertScalarTo<T>(+0.0));
+ ranges[0][1] = max_bits;
+ ranges[1][0] = BitCastScalar<UintT>(ConvertScalarTo<T>(-0.0));
+ ranges[1][1] = min_bits;
+ range_count = 2;
+ } else {
+ // If not splitting, ensure we iterate from smaller uint to larger uint.
+ // For negative numbers, min (e.g. -1000) has larger uint representation
+ // than max (e.g. -1).
+ if (ranges[0][0] > ranges[0][1]) {
+ auto tmp = ranges[0][0];
+ ranges[0][0] = ranges[0][1];
+ ranges[0][1] = tmp;
+ }
+ }
+
+ double max_actual_rel_error = 0.0;
+ double max_rel_err_value = 0.0;
+ double max_rel_err_expected = 0.0;
+ double max_rel_err_actual = 0.0;
+ double sum_rel_error = 0.0;
+ uint64_t count = 0;
+ // Emulation is slower, so we cannot afford as many samples.
+ const UintT kSamplesPerRange =
+ static_cast<UintT>(AdjustedReps(static_cast<size_t>(samples)));
+ for (int range_index = 0; range_index < range_count; ++range_index) {
+ const UintT start = ranges[range_index][0];
+ const UintT stop = ranges[range_index][1];
+ const UintT step = HWY_MAX(1, ((stop - start) / kSamplesPerRange));
+ for (UintT value_bits = start; value_bits <= stop; value_bits += step) {
+ // For reasons unknown, the HWY_MAX is necessary on RVV; otherwise
+ // value_bits can be less than start, and thus possibly NaN.
+ const T value =
+ BitCastScalar<T>(HWY_MIN(HWY_MAX(start, value_bits), stop));
+ const T actual = GetLane(fxN(d, Set(d, value)));
+ const T expected = fx1(value);
+
+ // Skip small inputs and outputs on Armv7, which flushes subnormals to zero.
+#if HWY_TARGET <= HWY_NEON_WITHOUT_AES && HWY_ARCH_ARM_V7
+ if ((std::abs(value) < 1e-37f) || (std::abs(expected) < 1e-37f)) {
+ continue;
+ }
+#endif
+
+ if (std::abs(expected) > 0.0) {
+ double rel = std::abs(static_cast<double>(actual) -
+ static_cast<double>(expected)) /
+ std::abs(static_cast<double>(expected));
+ if (rel > max_actual_rel_error) {
+ max_actual_rel_error = rel;
+ max_rel_err_value = static_cast<double>(value);
+ max_rel_err_expected = static_cast<double>(expected);
+ max_rel_err_actual = static_cast<double>(actual);
+ }
+ sum_rel_error += rel;
+ count++;
+ if (rel > max_relative_error) {
+ static int print_count = 0;
+ if (print_count < 10) {
+ fprintf(stderr,
+ "%s: %s(%f) expected %E actual %E rel %E max rel %E\n",
+ hwy::TypeName(T(), Lanes(d)).c_str(), name,
+ static_cast<double>(value), static_cast<double>(expected),
+ static_cast<double>(actual), rel, max_relative_error);
+ print_count++;
+ }
+ }
+ }
+ }
+ }
+ fprintf(stderr, "%s: %s max_rel_error %E at input %E actual %E expected %E\n",
+ hwy::TypeName(T(), Lanes(d)).c_str(), name, max_actual_rel_error,
+ max_rel_err_value, max_rel_err_actual, max_rel_err_expected);
+ if (count > 0) {
+ fprintf(stderr, "%s: %s avg_rel_error %E\n",
+ hwy::TypeName(T(), Lanes(d)).c_str(), name,
+ sum_rel_error / static_cast<double>(count));
+ }
+ HWY_ASSERT(max_actual_rel_error <= max_relative_error);
+}
+
+struct TestFastTanh {
+ template <class T, class D>
+ HWY_NOINLINE void operator()(T, D d) {
+ const double max_relative_error_float = 0.000007;
+ const double max_relative_error_double = 0.000007;
+ const double max_relative_error_small = 0.0000004;
+ const uint64_t samples = 1000000;
+ const uint64_t samples_small = 10000;
+ TestMathRelative<T, D>("FastTanh Small", std::tanh, CallFastTanh, d,
+ static_cast<T>(-1e-2), static_cast<T>(1e-2),
+ max_relative_error_small, samples_small);
+ if (sizeof(T) == 4) {
+ TestMathRelative<T, D>("FastTanh Float", std::tanh, CallFastTanh, d,
+ static_cast<T>(-1e35), static_cast<T>(1e35),
+ max_relative_error_float, samples);
+ } else {
+ TestMathRelative<T, D>("FastTanh Double", std::tanh, CallFastTanh, d,
+ static_cast<T>(-1e305), static_cast<T>(1e305),
+ max_relative_error_double, samples);
+ }
+ }
+};
+
+HWY_NOINLINE void TestAllFastTanh() {
+ if (HWY_MATH_TEST_EXCESS_PRECISION) return;
+ ForFloat3264Types(ForPartialVectors<TestFastTanh>());
+}
+
+} // namespace
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+} // namespace HWY_NAMESPACE
+} // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace hwy {
+namespace {
+HWY_BEFORE_TEST(HwyMathHyperTest);
+HWY_EXPORT_AND_TEST_P(HwyMathHyperTest, TestAllAcosh);
+HWY_EXPORT_AND_TEST_P(HwyMathHyperTest, TestAllAsinh);
+HWY_EXPORT_AND_TEST_P(HwyMathHyperTest, TestAllAtanh);
+HWY_EXPORT_AND_TEST_P(HwyMathHyperTest, TestAllSinh);
+HWY_EXPORT_AND_TEST_P(HwyMathHyperTest, TestAllTanh);
+HWY_EXPORT_AND_TEST_P(HwyMathHyperTest, TestAllFastTanh);
+HWY_AFTER_TEST();
+} // namespace
+} // namespace hwy
+HWY_TEST_MAIN();
+#endif // HWY_ONCE
diff --git a/third_party/highway/hwy/contrib/math/math_tan_test.cc b/third_party/highway/hwy/contrib/math/math_tan_test.cc
new file mode 100644
index 0000000000..d6e0384a1a
--- /dev/null
+++ b/third_party/highway/hwy/contrib/math/math_tan_test.cc
@@ -0,0 +1,751 @@
+// Copyright 2020 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdint.h>
+#include <stdio.h>
+
+#include <cfloat> // FLT_MAX
+#include <cmath> // std::abs
+
+#include "third_party/highway/hwy/base.h"
+#include "third_party/highway/hwy/nanobenchmark.h"
+
+// Clang build timeout on RVV as of 2025-09-19.
+#if !HWY_ARCH_RVV
+
+// clang-format off
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "hwy/contrib/math/math_tan_test.cc"
+#include "third_party/highway/hwy/foreach_target.h" // IWYU pragma: keep
+#include "third_party/highway/hwy/highway.h"
+#include "third_party/highway/hwy/contrib/math/math-inl.h"
+#include "third_party/highway/hwy/contrib/math/fast_math-inl.h"
+#include "third_party/highway/hwy/tests/test_util-inl.h"
+// clang-format on
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+namespace {
+
+// We have had test failures caused by excess precision due to keeping
+// intermediate results in 80-bit x87 registers. One such failure mode is that
+// Log1p computes a 1.0 which is not exactly equal to 1.0f, causing is_pole to
+// incorrectly evaluate to false.
+#undef HWY_MATH_TEST_EXCESS_PRECISION
+#if HWY_ARCH_X86_32 && HWY_COMPILER_GCC_ACTUAL && \
+ (HWY_TARGET == HWY_SCALAR || HWY_TARGET == HWY_EMU128)
+
+// GCC 13+: because CMAKE_CXX_EXTENSIONS is OFF, we build with -std= and hence
+// also -fexcess-precision=standard, so there is no problem. See #1708 and
+// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=323.
+#if HWY_COMPILER_GCC_ACTUAL >= 1300
+#define HWY_MATH_TEST_EXCESS_PRECISION 0
+
+#else // HWY_COMPILER_GCC_ACTUAL < 1300
+
+// The build system must enable SSE2, e.g. via HWY_CMAKE_SSE2 - see
+// https://stackoverflow.com/questions/20869904/c-handling-of-excess-precision .
+#if defined(__SSE2__) // correct flag given, no problem
+#define HWY_MATH_TEST_EXCESS_PRECISION 0
+#else
+#define HWY_MATH_TEST_EXCESS_PRECISION 1
+#pragma message( \
+ "Skipping scalar math_test on 32-bit x86 GCC <13 without HWY_CMAKE_SSE2")
+#endif // defined(__SSE2__)
+
+#endif // HWY_COMPILER_GCC_ACTUAL
+#else // not (x86-32, GCC, scalar target): running math_test normally
+#define HWY_MATH_TEST_EXCESS_PRECISION 0
+#endif // HWY_ARCH_X86_32 etc
+
+template <class T, class D>
+HWY_NOINLINE void TestMath(const char* name, T (*fx1)(T),
+ Vec<D> (*fxN)(D, VecArg<Vec<D>>), D d, T min, T max,
+ uint64_t max_error_ulp) {
+ if (HWY_MATH_TEST_EXCESS_PRECISION) {
+ static bool once = true;
+ if (once) {
+ once = false;
+ HWY_WARN("Skipping math_test due to GCC issue with excess precision.\n");
+ }
+ return;
+ }
+
+ using UintT = MakeUnsigned<T>;
+
+ const UintT min_bits = BitCastScalar<UintT>(min);
+ const UintT max_bits = BitCastScalar<UintT>(max);
+
+ // If min is negative and max is positive, the range needs to be broken into
+ // two pieces, [+0, max] and [-0, min], otherwise [min, max].
+ int range_count = 1;
+ UintT ranges[2][2] = {{min_bits, max_bits}, {0, 0}};
+ if ((min < 0.0) && (max > 0.0)) {
+ ranges[0][0] = BitCastScalar<UintT>(ConvertScalarTo<T>(+0.0));
+ ranges[0][1] = max_bits;
+ ranges[1][0] = BitCastScalar<UintT>(ConvertScalarTo<T>(-0.0));
+ ranges[1][1] = min_bits;
+ range_count = 2;
+ }
+
+ uint64_t max_ulp = 0;
+ // Emulation is slower, so we cannot afford as many samples.
+ constexpr UintT kSamplesPerRange = static_cast<UintT>(AdjustedReps(4000));
+ for (int range_index = 0; range_index < range_count; ++range_index) {
+ const UintT start = ranges[range_index][0];
+ const UintT stop = ranges[range_index][1];
+ const UintT step = HWY_MAX(1, ((stop - start) / kSamplesPerRange));
+ for (UintT value_bits = start; value_bits <= stop; value_bits += step) {
+ // For reasons unknown, the HWY_MAX is necessary on RVV; otherwise
+ // value_bits can be less than start, and thus possibly NaN.
+ const T value =
+ BitCastScalar<T>(HWY_MIN(HWY_MAX(start, value_bits), stop));
+ const T actual = GetLane(fxN(d, Set(d, value)));
+ const T expected = fx1(value);
+
+ // Skip small inputs and outputs on Armv7, which flushes subnormals to zero.
+#if HWY_TARGET <= HWY_NEON_WITHOUT_AES && HWY_ARCH_ARM_V7
+ if ((std::abs(value) < 1e-37f) || (std::abs(expected) < 1e-37f)) {
+ continue;
+ }
+#endif
+
+ const auto ulp = hwy::detail::ComputeUlpDelta(actual, expected);
+ max_ulp = HWY_MAX(max_ulp, ulp);
+ if (ulp > max_error_ulp) {
+ fprintf(stderr, "%s: %s(%f) expected %E actual %E ulp %g max ulp %u\n",
+ hwy::TypeName(T(), Lanes(d)).c_str(), name, value, expected,
+ actual, static_cast<double>(ulp),
+ static_cast<uint32_t>(max_error_ulp));
+ }
+ }
+ }
+ fprintf(stderr, "%s: %s max_ulp %g\n", hwy::TypeName(T(), Lanes(d)).c_str(),
+ name, static_cast<double>(max_ulp));
+ HWY_ASSERT(max_ulp <= max_error_ulp);
+}
+
+template <class T, class D>
+HWY_NOINLINE void TestMathRelative(const char* name, T (*fx1)(T),
+ Vec<D> (*fxN)(D, VecArg<Vec<D>>), D d, T min,
+ T max, double max_relative_error,
+ uint64_t samples = 4000) {
+ if (HWY_MATH_TEST_EXCESS_PRECISION) {
+ static bool once = true;
+ if (once) {
+ once = false;
+ HWY_WARN("Skipping math_test due to GCC issue with excess precision.\n");
+ }
+ return;
+ }
+
+ using UintT = MakeUnsigned<T>;
+
+ const UintT min_bits = BitCastScalar<UintT>(min);
+ const UintT max_bits = BitCastScalar<UintT>(max);
+
+ // If min is negative and max is positive, the range needs to be broken into
+ // two pieces, [+0, max] and [-0, min], otherwise [min, max].
+ int range_count = 1;
+ UintT ranges[2][2] = {{min_bits, max_bits}, {0, 0}};
+ if ((min < 0.0) && (max > 0.0)) {
+ ranges[0][0] = BitCastScalar<UintT>(ConvertScalarTo<T>(+0.0));
+ ranges[0][1] = max_bits;
+ ranges[1][0] = BitCastScalar<UintT>(ConvertScalarTo<T>(-0.0));
+ ranges[1][1] = min_bits;
+ range_count = 2;
+ }
+
+ double max_actual_rel_error = 0.0;
+ double max_error_value = 0.0;
+ double sum_rel_error = 0.0;
+ uint64_t count = 0;
+ // Emulation is slower, so we cannot afford as many samples.
+ const UintT kSamplesPerRange =
+ static_cast<UintT>(AdjustedReps(static_cast<size_t>(samples)));
+ for (int range_index = 0; range_index < range_count; ++range_index) {
+ const UintT start = ranges[range_index][0];
+ const UintT stop = ranges[range_index][1];
+ const UintT step = HWY_MAX(1, ((stop - start) / kSamplesPerRange));
+ for (UintT value_bits = start; value_bits <= stop; value_bits += step) {
+ // For reasons unknown, the HWY_MAX is necessary on RVV; otherwise
+ // value_bits can be less than start, and thus possibly NaN.
+ const T value =
+ BitCastScalar<T>(HWY_MIN(HWY_MAX(start, value_bits), stop));
+ const T actual = GetLane(fxN(d, Set(d, value)));
+ const T expected = fx1(value);
+
+ // Skip small inputs and outputs on Armv7, which flushes subnormals to zero.
+#if HWY_TARGET <= HWY_NEON_WITHOUT_AES && HWY_ARCH_ARM_V7
+ if ((std::abs(value) < 1e-37f) || (std::abs(expected) < 1e-37f)) {
+ continue;
+ }
+#endif
+
+ if (std::abs(expected) > 1e-20) {
+ double rel = std::abs(static_cast<double>(actual) -
+ static_cast<double>(expected)) /
+ std::abs(static_cast<double>(expected));
+ if (rel > max_actual_rel_error) {
+ max_actual_rel_error = rel;
+ max_error_value = static_cast<double>(value);
+ }
+ sum_rel_error += rel;
+ count++;
+ if (rel > max_relative_error) {
+ fprintf(stderr,
+ "%s: %s(%f) expected %E actual %E rel %E max rel %E\n",
+ hwy::TypeName(T(), Lanes(d)).c_str(), name,
+ static_cast<double>(value), static_cast<double>(expected),
+ static_cast<double>(actual), rel, max_relative_error);
+ }
+ }
+ }
+ }
+ fprintf(stderr, "%s: %s max_rel_error %E at %E\n",
+ hwy::TypeName(T(), Lanes(d)).c_str(), name, max_actual_rel_error,
+ max_error_value);
+ if (count > 0) {
+ fprintf(stderr, "%s: %s avg_rel_error %E\n",
+ hwy::TypeName(T(), Lanes(d)).c_str(), name,
+ sum_rel_error / static_cast<double>(count));
+ }
+ HWY_ASSERT(max_actual_rel_error <= max_relative_error);
+}
+
+#define DEFINE_MATH_TEST_FUNC(NAME) \
+ HWY_NOINLINE void TestAll##NAME() { \
+ ForFloat3264Types(ForPartialVectors<Test##NAME>()); \
+ }
+
+#undef DEFINE_MATH_TEST
+#define DEFINE_MATH_TEST(NAME, F32x1, F32xN, F32_MIN, F32_MAX, F32_ERROR, \
+ F64x1, F64xN, F64_MIN, F64_MAX, F64_ERROR) \
+ struct Test##NAME { \
+ template <class T, class D> \
+ HWY_NOINLINE void operator()(T, D d) { \
+ if (sizeof(T) == 4) { \
+ TestMath<T, D>(HWY_STR(NAME), F32x1, F32xN, d, F32_MIN, F32_MAX, \
+ F32_ERROR); \
+ } else { \
+ TestMath<T, D>(HWY_STR(NAME), F64x1, F64xN, d, \
+ static_cast<T>(F64_MIN), static_cast<T>(F64_MAX), \
+ F64_ERROR); \
+ } \
+ } \
+ }; \
+ DEFINE_MATH_TEST_FUNC(NAME)
+
+// clang-format off
+DEFINE_MATH_TEST(Atan,
+ std::atan, CallAtan, -FLT_MAX, +FLT_MAX, 3,
+ std::atan, CallAtan, -DBL_MAX, +DBL_MAX, 3)
+// clang-format on
+
+template <typename T, class D>
+void Atan2TestCases(T /*unused*/, D d, size_t& padded,
+ AlignedFreeUniquePtr<T[]>& out_y,
+ AlignedFreeUniquePtr<T[]>& out_x,
+ AlignedFreeUniquePtr<T[]>& out_expected) {
+ struct YX {
+ T y;
+ T x;
+ T expected;
+ };
+ const T pos = ConvertScalarTo<T>(1E5);
+ const T neg = ConvertScalarTo<T>(-1E7);
+ const T p0 = ConvertScalarTo<T>(0);
+ // A plain -0 literal is not enough to get an actual negative zero; use -0.0.
+ const T n0 = ConvertScalarTo<T>(-0.0);
+ const T p1 = ConvertScalarTo<T>(1);
+ const T n1 = ConvertScalarTo<T>(-1);
+ const T p2 = ConvertScalarTo<T>(2);
+ const T n2 = ConvertScalarTo<T>(-2);
+ const T inf = GetLane(Inf(d));
+ const T nan = GetLane(NaN(d));
+
+ const T pi = ConvertScalarTo<T>(3.141592653589793238);
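+ // The expected values below follow the IEEE-754 / C99 Annex F special cases
+ // for atan2, e.g. atan2(+/-0, x<0) = +/-pi, atan2(+/-inf, +inf) = +/-pi/4.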
+ const YX test_cases[] = { // 45 degree steps:
+ {p0, p1, p0}, // E
+ {n1, p1, -pi / 4}, // SE
+ {n1, p0, -pi / 2}, // S
+ {n1, n1, -3 * pi / 4}, // SW
+ {p0, n1, pi}, // W
+ {p1, n1, 3 * pi / 4}, // NW
+ {p1, p0, pi / 2}, // N
+ {p1, p1, pi / 4}, // NE
+
+ // y = ±0, x < 0 or -0
+ {p0, n1, pi},
+ {n0, n2, -pi},
+ // y = ±0, x > 0 or +0
+ {p0, p2, p0},
+ {n0, p2, n0},
+ // y = ±∞, x finite
+ {inf, p2, pi / 2},
+ {-inf, p2, -pi / 2},
+ // y = ±∞, x = -∞
+ {inf, -inf, 3 * pi / 4},
+ {-inf, -inf, -3 * pi / 4},
+ // y = ±∞, x = +∞
+ {inf, inf, pi / 4},
+ {-inf, inf, -pi / 4},
+ // y < 0, x = ±0
+ {n2, p0, -pi / 2},
+ {n1, n0, -pi / 2},
+ // y > 0, x = ±0
+ {pos, p0, pi / 2},
+ {p2, n0, pi / 2},
+ // finite y > 0, x = -∞
+ {pos, -inf, pi},
+ // finite y < 0, x = -∞
+ {neg, -inf, -pi},
+ // finite y > 0, x = +∞
+ {pos, inf, p0},
+ // finite y < 0, x = +∞
+ {neg, inf, n0},
+ // y NaN xor x NaN
+ {nan, p0, nan},
+ {pos, nan, nan}};
+ const size_t kNumTestCases = sizeof(test_cases) / sizeof(test_cases[0]);
+ const size_t N = Lanes(d);
+ padded = RoundUpTo(kNumTestCases, N); // allow loading whole vectors
+ out_y = AllocateAligned<T>(padded);
+ out_x = AllocateAligned<T>(padded);
+ out_expected = AllocateAligned<T>(padded);
+ HWY_ASSERT(out_y && out_x && out_expected);
+ size_t i = 0;
+ for (; i < kNumTestCases; ++i) {
+ out_y[i] = test_cases[i].y;
+ out_x[i] = test_cases[i].x;
+ out_expected[i] = test_cases[i].expected;
+ }
+ for (; i < padded; ++i) {
+ out_y[i] = p0;
+ out_x[i] = p0;
+ out_expected[i] = p0;
+ }
+}
+
+struct TestAtan2 {
+ template <typename T, class D>
+ HWY_NOINLINE void operator()(T t, D d) {
+ const size_t N = Lanes(d);
+
+ size_t padded;
+ AlignedFreeUniquePtr<T[]> in_y, in_x, expected;
+ Atan2TestCases(t, d, padded, in_y, in_x, expected);
+
+ const Vec<D> tolerance = Set(d, ConvertScalarTo<T>(1E-5));
+
+ for (size_t i = 0; i < padded; ++i) {
+ const T actual = ConvertScalarTo<T>(atan2(in_y[i], in_x[i]));
+ // fprintf(stderr, "%zu: table %f atan2 %f\n", i, expected[i], actual);
+ HWY_ASSERT_EQ(expected[i], actual);
+ }
+ for (size_t i = 0; i < padded; i += N) {
+ const Vec<D> y = Load(d, &in_y[i]);
+ const Vec<D> x = Load(d, &in_x[i]);
+#if HWY_ARCH_ARM_A64
+ // TODO(b/287462770): inline to work around incorrect SVE codegen
+ const Vec<D> actual = Atan2(d, y, x);
+#else
+ const Vec<D> actual = CallAtan2(d, y, x);
+#endif
+ const Vec<D> vexpected = Load(d, &expected[i]);
+
+ const Mask<D> exp_nan = IsNaN(vexpected);
+ const Mask<D> act_nan = IsNaN(actual);
+ HWY_ASSERT_MASK_EQ(d, exp_nan, act_nan);
+
+ // If not NaN, then compare with tolerance
+ const Mask<D> ge = Ge(actual, Sub(vexpected, tolerance));
+ const Mask<D> le = Le(actual, Add(vexpected, tolerance));
+ const Mask<D> ok = Or(act_nan, And(le, ge));
+ if (!AllTrue(d, ok)) {
+ const size_t mismatch =
+ static_cast<size_t>(FindKnownFirstTrue(d, Not(ok)));
+ fprintf(stderr, "Mismatch for i=%d expected %E actual %E\n",
+ static_cast<int>(i + mismatch), expected[i + mismatch],
+ ExtractLane(actual, mismatch));
+ HWY_ASSERT(0);
+ }
+ }
+ }
+};
+
+HWY_NOINLINE void TestAllAtan2() {
+ if (HWY_MATH_TEST_EXCESS_PRECISION) return;
+
+ ForFloat3264Types(ForPartialVectors<TestAtan2>());
+}
+
+template <typename T, class D>
+void HypotTestCases(T /*unused*/, D d, size_t& padded,
+ AlignedFreeUniquePtr<T[]>& out_a,
+ AlignedFreeUniquePtr<T[]>& out_b,
+ AlignedFreeUniquePtr<T[]>& out_expected) {
+ using TU = MakeUnsigned<T>;
+
+ struct AB {
+ T a;
+ T b;
+ };
+
+ constexpr int kNumOfMantBits = MantissaBits<T>();
+ static_assert(kNumOfMantBits > 0, "kNumOfMantBits > 0 must be true");
+
+ // Unpredictable1() returns 1 in a way the compiler cannot prove, ensuring
+ // the inputs below are not compile-time constants.
+ const TU u1 = static_cast<TU>(hwy::Unpredictable1());
+ const double k1 = static_cast<double>(u1);
+
+ const T pos = ConvertScalarTo<T>(1E5 * k1);
+ const T neg = ConvertScalarTo<T>(-1E7 * k1);
+ const T p0 = ConvertScalarTo<T>(k1 - 1.0);
+ // A plain -0 literal is not enough to get an actual negative zero; copy the
+ // sign from a negative value instead.
+ const T n0 = ScalarCopySign<T>(p0, neg);
+ const T p1 = ConvertScalarTo<T>(k1);
+ const T n1 = ConvertScalarTo<T>(-k1);
+ const T p2 = ConvertScalarTo<T>(2 * k1);
+ const T n2 = ConvertScalarTo<T>(-2 * k1);
+ const T inf = BitCastScalar<T>(ExponentMask<T>() * u1);
+ const T neg_inf = ScalarCopySign(inf, n1);
+ const T nan = BitCastScalar<T>(
+ static_cast<TU>(ExponentMask<T>() | (u1 << (kNumOfMantBits - 1))));
+
+ const double max_as_f64 = ConvertScalarTo<double>(HighestValue<T>()) * k1;
+ const T max = ConvertScalarTo<T>(max_as_f64);
+
+ const T huge = ConvertScalarTo<T>(max_as_f64 * 0.25);
+ const T neg_huge = ScalarCopySign(huge, n1);
+
+ const T huge2 = ConvertScalarTo<T>(max_as_f64 * 0.039415044328304796);
+
+ const T large = ConvertScalarTo<T>(3.512227595593985E18 * k1);
+ const T neg_large = ScalarCopySign(large, n1);
+ const T large2 = ConvertScalarTo<T>(2.1190576943127544E16 * k1);
+
+ const T small = ConvertScalarTo<T>(1.067033284841808E-11 * k1);
+ const T neg_small = ScalarCopySign(small, n1);
+ const T small2 = ConvertScalarTo<T>(1.9401409532292856E-12 * k1);
+
+ const T tiny = BitCastScalar<T>(static_cast<TU>(u1 << kNumOfMantBits));
+ const T neg_tiny = ScalarCopySign(tiny, n1);
+
+ const T tiny2 =
+ ConvertScalarTo<T>(78.68466968859765 * ConvertScalarTo<double>(tiny));
+
+ const AB test_cases[] = {{p0, p0}, {p0, n0},
+ {n0, n0}, {p1, p1},
+ {p1, n1}, {n1, n1},
+ {p2, p2}, {p2, n2},
+ {p2, pos}, {p2, neg},
+ {n2, pos}, {n2, neg},
+ {n2, n2}, {p0, tiny},
+ {p0, neg_tiny}, {n0, tiny},
+ {n0, neg_tiny}, {p1, tiny},
+ {p1, neg_tiny}, {n1, tiny},
+ {n1, neg_tiny}, {tiny, p0},
+ {tiny2, p0}, {tiny, tiny2},
+ {neg_tiny, tiny2}, {huge, huge2},
+ {neg_huge, huge2}, {huge, p0},
+ {huge, tiny}, {huge2, tiny2},
+ {large, p0}, {large, large2},
+ {neg_large, p0}, {neg_large, large2},
+ {small, p0}, {small, small2},
+ {neg_small, p0}, {neg_small, small2},
+ {max, p0}, {max, huge},
+ {max, max}, {p0, inf},
+ {n0, inf}, {p1, inf},
+ {n1, inf}, {p2, inf},
+ {n2, inf}, {p0, neg_inf},
+ {n0, neg_inf}, {p1, neg_inf},
+ {n1, neg_inf}, {p2, neg_inf},
+ {n2, neg_inf}, {p0, nan},
+ {n0, nan}, {p1, nan},
+ {n1, nan}, {p2, nan},
+ {n2, nan}, {huge, inf},
+ {inf, nan}, {neg_inf, nan},
+ {nan, nan}};
+
+ const size_t kNumTestCases = sizeof(test_cases) / sizeof(test_cases[0]);
+ const size_t N = Lanes(d);
+ padded = RoundUpTo(kNumTestCases, N); // allow loading whole vectors
+ out_a = AllocateAligned<T>(padded);
+ out_b = AllocateAligned<T>(padded);
+ out_expected = AllocateAligned<T>(padded);
+ HWY_ASSERT(out_a && out_b && out_expected);
+
+ size_t i = 0;
+ for (; i < kNumTestCases; ++i) {
+ const T a =
+ test_cases[i].a * hwy::ConvertScalarTo<T>(hwy::Unpredictable1());
+ const T b = test_cases[i].b;
+
+#if HWY_TARGET <= HWY_NEON_WITHOUT_AES && HWY_ARCH_ARM_V7
+ // Ignore test cases that have infinite or NaN inputs on Armv7 NEON
+ if (!ScalarIsFinite(a) || !ScalarIsFinite(b)) {
+ out_a[i] = p0;
+ out_b[i] = p0;
+ out_expected[i] = p0;
+ continue;
+ }
+#endif
+
+ out_a[i] = a;
+ out_b[i] = b;
+
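+ // Per C99 Annex F, hypot(+/-inf, y) is +inf even when y is NaN, so the
+ // inf check deliberately precedes the NaN check.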
+ if (ScalarIsInf(a) || ScalarIsInf(b)) {
+ out_expected[i] = inf;
+ } else if (ScalarIsNaN(a) || ScalarIsNaN(b)) {
+ out_expected[i] = nan;
+ } else {
+ out_expected[i] = std::hypot(a, b);
+ }
+ }
+ for (; i < padded; ++i) {
+ out_a[i] = p0;
+ out_b[i] = p0;
+ out_expected[i] = p0;
+ }
+}
+
+struct TestHypot {
+ template <typename T, class D>
+ HWY_NOINLINE void operator()(T t, D d) {
+ if (HWY_MATH_TEST_EXCESS_PRECISION) {
+ return;
+ }
+
+ const size_t N = Lanes(d);
+
+ constexpr uint64_t kMaxErrorUlp = 4;
+
+ size_t padded;
+ AlignedFreeUniquePtr<T[]> in_a, in_b, expected;
+ HypotTestCases(t, d, padded, in_a, in_b, expected);
+
+ auto actual1_lanes = AllocateAligned<T>(N);
+ auto actual2_lanes = AllocateAligned<T>(N);
+ HWY_ASSERT(actual1_lanes && actual2_lanes);
+
+ uint64_t max_ulp = 0;
+ for (size_t i = 0; i < padded; i += N) {
+ const auto a = Load(d, in_a.get() + i);
+ const auto b = Load(d, in_b.get() + i);
+
+#if HWY_ARCH_ARM_A64
+ // TODO(b/287462770): inline to work around incorrect SVE codegen
+ const auto actual1 = Hypot(d, a, b);
+ const auto actual2 = Hypot(d, b, a);
+#else
+ const auto actual1 = CallHypot(d, a, b);
+ const auto actual2 = CallHypot(d, b, a);
+#endif
+
+ Store(actual1, d, actual1_lanes.get());
+ Store(actual2, d, actual2_lanes.get());
+
+ for (size_t j = 0; j < N; j++) {
+ const T val_a = in_a[i + j];
+ const T val_b = in_b[i + j];
+ const T expected_val = expected[i + j];
+ const T actual1_val = actual1_lanes[j];
+ const T actual2_val = actual2_lanes[j];
+
+ const auto ulp1 =
+ hwy::detail::ComputeUlpDelta(actual1_val, expected_val);
+ if (ulp1 > kMaxErrorUlp) {
+ fprintf(stderr,
+ "%s: Hypot(%e, %e) lane %d expected %E actual %E ulp %g max "
+ "ulp %u\n",
+ hwy::TypeName(T(), Lanes(d)).c_str(), val_a, val_b,
+ static_cast<int>(j), expected_val, actual1_val,
+ static_cast<double>(ulp1),
+ static_cast<uint32_t>(kMaxErrorUlp));
+ }
+
+ const auto ulp2 =
+ hwy::detail::ComputeUlpDelta(actual2_val, expected_val);
+ if (ulp2 > kMaxErrorUlp) {
+ fprintf(stderr,
+ "%s: Hypot(%e, %e) expected %E actual %E ulp %g max ulp %u\n",
+ hwy::TypeName(T(), Lanes(d)).c_str(), val_b, val_a,
+ expected_val, actual2_val, static_cast<double>(ulp2),
+ static_cast<uint32_t>(kMaxErrorUlp));
+ }
+
+ max_ulp = HWY_MAX(max_ulp, HWY_MAX(ulp1, ulp2));
+ }
+ }
+
+ if (max_ulp != 0) {
+ fprintf(stderr, "%s: Hypot max_ulp %g\n",
+ hwy::TypeName(T(), Lanes(d)).c_str(),
+ static_cast<double>(max_ulp));
+ HWY_ASSERT(max_ulp <= kMaxErrorUlp);
+ }
+ }
+};
+
+HWY_NOINLINE void TestAllHypot() {
+ if (HWY_MATH_TEST_EXCESS_PRECISION) return;
+
+ ForFloat3264Types(ForPartialVectors<TestHypot>());
+}
+
+struct TestFastTanRelative {
+ template <class T, class D>
+ HWY_NOINLINE void operator()(T, D d) {
+ if (sizeof(T) == 4) {
+ // Float: [-89.99, +89.99] deg
+ // 89.99 deg = 1.570621794 rad
+ TestMathRelative<T, D>("FastTan", std::tan, CallFastTan, d,
+ static_cast<T>(-1.570621794),
+ static_cast<T>(1.570621794), 0.0035);
+ } else {
+ // Double: [-89.9999999, +89.9999999] deg
+ // 89.9999999 deg = 1.570796325 rad
+ TestMathRelative<T, D>("FastTan", std::tan, CallFastTan, d,
+ static_cast<T>(-1.570796325),
+ static_cast<T>(1.570796325), 0.0035);
+ }
+ }
+};
+
+HWY_NOINLINE void TestAllFastTan() {
+ if (HWY_MATH_TEST_EXCESS_PRECISION) return;
+ ForFloat3264Types(ForPartialVectors<TestFastTanRelative>());
+}
+
+struct TestFastAtanRelative {
+ template <class T, class D>
+ HWY_NOINLINE void operator()(T, D d) {
+ if (sizeof(T) == 4) {
+ // Float: [-1e35, +1e35]
+ TestMathRelative<T, D>("FastAtan", std::atan, CallFastAtan, d,
+ static_cast<T>(-1e35), static_cast<T>(1e35),
+ 0.000035, 1000000);
+ // Float: [0, +1e35]
+ TestMathRelative<T, D>("FastAtanPositive", std::atan,
+ CallFastAtanPositive, d, static_cast<T>(0),
+ static_cast<T>(1e35), 0.000035, 1000000);
+ } else {
+ // Double: [-1e305, +1e305]
+ TestMathRelative<T, D>("FastAtan", std::atan, CallFastAtan, d,
+ static_cast<T>(-1e305), static_cast<T>(1e305),
+ 0.000035, 1000000);
+ // Double: [0, +1e305]
+ TestMathRelative<T, D>("FastAtanPositive", std::atan,
+ CallFastAtanPositive, d, static_cast<T>(0),
+ static_cast<T>(1e305), 0.000035, 1000000);
+ }
+ }
+};
+
+HWY_NOINLINE void TestAllFastAtan() {
+ if (HWY_MATH_TEST_EXCESS_PRECISION) return;
+ ForFloat3264Types(ForPartialVectors<TestFastAtanRelative>());
+}
+
+struct TestFastAtan2 {
+ template <typename T, class D>
+ HWY_NOINLINE void operator()(T t, D d) {
+ const size_t N = Lanes(d);
+
+ size_t padded;
+ AlignedFreeUniquePtr<T[]> in_y, in_x, expected;
+ Atan2TestCases(t, d, padded, in_y, in_x, expected);
+
+ // Constants for error checking
+ const T rel_limit = static_cast<T>(0.000035);
+ const T tiny_threshold = static_cast<T>(1e-20);
+ const Vec<D> v_rel_limit = Set(d, rel_limit);
+ const Vec<D> v_tiny_threshold = Set(d, tiny_threshold);
+
+ for (size_t i = 0; i < padded; i += N) {
+ const Vec<D> y = Load(d, &in_y[i]);
+ const Vec<D> x = Load(d, &in_x[i]);
+#if HWY_ARCH_ARM_A64
+ // TODO(b/287462770): inline to work around incorrect SVE codegen
+ const Vec<D> actual = FastAtan2(d, y, x);
+#else
+ const Vec<D> actual = CallFastAtan2(d, y, x);
+#endif
+ const Vec<D> vexpected = Load(d, &expected[i]);
+
+ // 1. Check NaNs match exactly
+ const Mask<D> exp_nan = IsNaN(vexpected);
+ const Mask<D> act_nan = IsNaN(actual);
+ HWY_ASSERT_MASK_EQ(d, exp_nan, act_nan);
+
+ // 2. Calculate Error
+ const Vec<D> abs_exp = Abs(vexpected);
+ const Vec<D> diff = Abs(Sub(actual, vexpected));
+
+ // 3. Determine Tolerance
+ // If abs_exp > v_tiny_threshold (1e-20), tolerance = abs_exp * rel_limit
+ // (a relative error bound). Else, tolerance = rel_limit itself, i.e. an
+ // absolute bound for expected values near zero.
+ const Mask<D> use_relative = Gt(abs_exp, v_tiny_threshold);
+ const Vec<D> tolerance =
+ IfThenElse(use_relative, Mul(abs_exp, v_rel_limit), v_rel_limit);
+
+ // 4. Verify
+ // Pass if it's NaN (checked above) OR if within tolerance
+ const Mask<D> ok = Or(act_nan, Le(diff, tolerance));
+
+ if (!AllTrue(d, ok)) {
+ const size_t mismatch =
+ static_cast<size_t>(FindKnownFirstTrue(d, Not(ok)));
+ fprintf(stderr, "Mismatch for i=%d expected %E actual %E\n",
+ static_cast<int>(i + mismatch), expected[i + mismatch],
+ ExtractLane(actual, mismatch));
+ HWY_ASSERT(0);
+ }
+ }
+ }
+};
+
+HWY_NOINLINE void TestAllFastAtan2() {
+ if (HWY_MATH_TEST_EXCESS_PRECISION) return;
+ ForFloat3264Types(ForPartialVectors<TestFastAtan2>());
+}
+
+} // namespace
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+} // namespace HWY_NAMESPACE
+} // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace hwy {
+namespace {
+HWY_BEFORE_TEST(HwyMathTanTest);
+HWY_EXPORT_AND_TEST_P(HwyMathTanTest, TestAllAtan);
+HWY_EXPORT_AND_TEST_P(HwyMathTanTest, TestAllAtan2);
+HWY_EXPORT_AND_TEST_P(HwyMathTanTest, TestAllHypot);
+HWY_EXPORT_AND_TEST_P(HwyMathTanTest, TestAllFastTan);
+HWY_EXPORT_AND_TEST_P(HwyMathTanTest, TestAllFastAtan);
+HWY_EXPORT_AND_TEST_P(HwyMathTanTest, TestAllFastAtan2);
+
+HWY_AFTER_TEST();
+} // namespace
+} // namespace hwy
+HWY_TEST_MAIN();
+#endif // HWY_ONCE
+
+#endif // !HWY_ARCH_RVV
diff --git a/third_party/highway/hwy/contrib/math/math_test.cc b/third_party/highway/hwy/contrib/math/math_test.cc
new file mode 100644
index 0000000000..2049eba617
--- /dev/null
+++ b/third_party/highway/hwy/contrib/math/math_test.cc
@@ -0,0 +1,646 @@
+// Copyright 2020 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdint.h>
+#include <stdio.h>
+
+#include <cfloat> // FLT_MAX
+#include <cmath> // std::abs
+
+// For faster tests. Not using AES, hence NEON_WITHOUT_AES is sufficient.
+// SVE is mostly superseded by SVE2.
+#ifndef HWY_DISABLED_TARGETS
+#define HWY_DISABLED_TARGETS (HWY_NEON | HWY_SVE)
+#endif // HWY_DISABLED_TARGETS
+
+#include "third_party/highway/hwy/base.h"
+
+// clang-format off
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "hwy/contrib/math/math_test.cc"
+#include "third_party/highway/hwy/foreach_target.h" // IWYU pragma: keep
+#include "third_party/highway/hwy/highway.h"
+#include "third_party/highway/hwy/contrib/math/fast_math-inl.h"
+#include "third_party/highway/hwy/contrib/math/math-inl.h"
+#include "third_party/highway/hwy/tests/test_util-inl.h"
+// clang-format on
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+namespace {
+
+// We have had test failures caused by excess precision due to keeping
+// intermediate results in 80-bit x87 registers. One such failure mode is that
+// Log1p computes a 1.0 which is not exactly equal to 1.0f, causing is_pole to
+// incorrectly evaluate to false.
+#undef HWY_MATH_TEST_EXCESS_PRECISION
+#if HWY_ARCH_X86_32 && HWY_COMPILER_GCC_ACTUAL && \
+ (HWY_TARGET == HWY_SCALAR || HWY_TARGET == HWY_EMU128)
+
+// GCC 13+: because CMAKE_CXX_EXTENSIONS is OFF, we build with -std= and hence
+// also -fexcess-precision=standard, so there is no problem. See #1708 and
+// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=323.
+#if HWY_COMPILER_GCC_ACTUAL >= 1300
+#define HWY_MATH_TEST_EXCESS_PRECISION 0
+
+#else // HWY_COMPILER_GCC_ACTUAL < 1300
+
+// The build system must enable SSE2, e.g. via HWY_CMAKE_SSE2 - see
+// https://stackoverflow.com/questions/20869904/c-handling-of-excess-precision .
+#if defined(__SSE2__) // correct flag given, no problem
+#define HWY_MATH_TEST_EXCESS_PRECISION 0
+#else
+#define HWY_MATH_TEST_EXCESS_PRECISION 1
+#pragma message( \
+ "Skipping scalar math_test on 32-bit x86 GCC <13 without HWY_CMAKE_SSE2")
+#endif // defined(__SSE2__)
+
+#endif // HWY_COMPILER_GCC_ACTUAL
+#else // not (x86-32, GCC, scalar target): running math_test normally
+#define HWY_MATH_TEST_EXCESS_PRECISION 0
+#endif // HWY_ARCH_X86_32 etc
+
+template <class T, class D>
+HWY_NOINLINE void TestMath(const char* name, T (*fx1)(T),
+ Vec<D> (*fxN)(D, VecArg<Vec<D>>), D d, T min, T max,
+ uint64_t max_error_ulp) {
+ if (HWY_MATH_TEST_EXCESS_PRECISION) {
+ static bool once = true;
+ if (once) {
+ once = false;
+ HWY_WARN("Skipping math_test due to GCC issue with excess precision.\n");
+ }
+ return;
+ }
+
+ using UintT = MakeUnsigned<T>;
+
+ const UintT min_bits = BitCastScalar<UintT>(min);
+ const UintT max_bits = BitCastScalar<UintT>(max);
+
+ // If min is negative and max is positive, the range needs to be broken into
+ // two pieces, [+0, max] and [-0, min], otherwise [min, max].
+ int range_count = 1;
+ UintT ranges[2][2] = {{min_bits, max_bits}, {0, 0}};
+ if ((min < T{0}) && (max > T{0})) {
+ ranges[0][0] = BitCastScalar<UintT>(ConvertScalarTo<T>(+0.0));
+ ranges[0][1] = max_bits;
+ ranges[1][0] = BitCastScalar<UintT>(ConvertScalarTo<T>(-0.0));
+ ranges[1][1] = min_bits;
+ range_count = 2;
+ } else {
+ // If not splitting, ensure we iterate from smaller uint to larger uint.
+ // For negative numbers, min (e.g. -1000) has larger uint representation
+ // than max (e.g. -1).
+ if (ranges[0][0] > ranges[0][1]) {
+ auto tmp = ranges[0][0];
+ ranges[0][0] = ranges[0][1];
+ ranges[0][1] = tmp;
+ }
+ }
+
+ uint64_t max_ulp = 0;
+ // Emulation is slower, so we cannot afford as many samples.
+ constexpr UintT kSamplesPerRange =
+ static_cast<UintT>(AdjustedReps(static_cast<size_t>(1000)));
+ for (int range_index = 0; range_index < range_count; ++range_index) {
+ const UintT start = ranges[range_index][0];
+ const UintT stop = ranges[range_index][1];
+ const UintT step = HWY_MAX(1, ((stop - start) / kSamplesPerRange));
+ for (UintT value_bits = start; value_bits <= stop; value_bits += step) {
+ // For reasons unknown, the HWY_MAX is necessary on RVV; otherwise
+ // value_bits can be less than start, and thus possibly NaN.
+ const T value =
+ BitCastScalar<T>(HWY_MIN(HWY_MAX(start, value_bits), stop));
+ const T actual = GetLane(fxN(d, Set(d, value)));
+ const T expected = fx1(value);
+
+ // Skip small inputs and outputs on Armv7, which flushes subnormals to zero.
+#if HWY_TARGET <= HWY_NEON_WITHOUT_AES && HWY_ARCH_ARM_V7
+ if ((std::abs(value) < 1e-37f) || (std::abs(expected) < 1e-37f)) {
+ continue;
+ }
+#endif
+
+ const auto ulp = hwy::detail::ComputeUlpDelta(actual, expected);
+ max_ulp = HWY_MAX(max_ulp, ulp);
+ if (ulp > max_error_ulp) {
+ fprintf(stderr, "%s: %s(%f) expected %E actual %E ulp %g max ulp %u\n",
+ hwy::TypeName(T(), Lanes(d)).c_str(), name, value,
+ static_cast<double>(expected), static_cast<double>(actual),
+ static_cast<double>(ulp), static_cast<uint32_t>(max_error_ulp));
+ }
+ }
+ }
+ fprintf(stderr, "%s: %s max_ulp %g\n", hwy::TypeName(T(), Lanes(d)).c_str(),
+ name, static_cast<double>(max_ulp));
+ HWY_ASSERT(max_ulp <= max_error_ulp);
+}
+
+#define DEFINE_MATH_TEST_FUNC(NAME) \
+ HWY_NOINLINE void TestAll##NAME() { \
+ ForFloat3264Types(ForPartialVectors<Test##NAME>()); \
+ }
+
+#undef DEFINE_MATH_TEST
+#define DEFINE_MATH_TEST(NAME, F32x1, F32xN, F32_MIN, F32_MAX, F32_ERROR, \
+ F64x1, F64xN, F64_MIN, F64_MAX, F64_ERROR) \
+ struct Test##NAME { \
+ template <class T, class D, HWY_IF_T_SIZE(T, 4)> \
+ HWY_NOINLINE void operator()(T, D d) { \
+ TestMath<T, D>(HWY_STR(NAME), F32x1, F32xN, d, F32_MIN, F32_MAX, \
+ F32_ERROR); \
+ } \
+ template <class T, class D, HWY_IF_T_SIZE(T, 8)> \
+ HWY_NOINLINE void operator()(T, D d) { \
+ TestMath<T, D>(HWY_STR(NAME), F64x1, F64xN, d, static_cast<T>(F64_MIN), \
+ static_cast<T>(F64_MAX), F64_ERROR); \
+ } \
+ }; \
+ DEFINE_MATH_TEST_FUNC(NAME)
+
+// clang-format off
+DEFINE_MATH_TEST(Exp,
+ std::exp, CallExp, -FLT_MAX, +104.0f, 1,
+ std::exp, CallExp, -DBL_MAX, +104.0, 1)
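+// Note: +104 exceeds ln(FLT_MAX) ~= 88.72, so the float upper range also
+// exercises the overflow path, where std::exp returns +inf (the vector
+// version is expected to match, giving a ULP delta of 0).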
+DEFINE_MATH_TEST(Exp2,
+ std::exp2, CallExp2, -FLT_MAX, +128.0f, 2,
+ std::exp2, CallExp2, -DBL_MAX, +128.0, 2)
+DEFINE_MATH_TEST(Expm1,
+ std::expm1, CallExpm1, -FLT_MAX, +104.0f, 4,
+ std::expm1, CallExpm1, -DBL_MAX, +104.0, 4)
+DEFINE_MATH_TEST(Log,
+ std::log, CallLog, +FLT_MIN, +FLT_MAX, 1,
+ std::log, CallLog, +DBL_MIN, +DBL_MAX, 1)
+DEFINE_MATH_TEST(Log10,
+ std::log10, CallLog10, +FLT_MIN, +FLT_MAX, 2,
+ std::log10, CallLog10, +DBL_MIN, +DBL_MAX, 2)
+DEFINE_MATH_TEST(Log1p,
+ std::log1p, CallLog1p, +0.0f, +FLT_MAX, 3, // NEON is 3 instead of 2
+ std::log1p, CallLog1p, +0.0, +DBL_MAX, 2)
+DEFINE_MATH_TEST(Log2,
+ std::log2, CallLog2, +FLT_MIN, +FLT_MAX, 2,
+ std::log2, CallLog2, +DBL_MIN, +DBL_MAX, 2)
+DEFINE_MATH_TEST(Cbrt,
+ std::cbrt, CallCbrt, -FLT_MAX, +FLT_MAX, 6,
+ std::cbrt, CallCbrt, -DBL_MAX, +DBL_MAX, 6)
+
+// clang-format on
+
+template <class T, class D>
+HWY_NOINLINE void TestMathRelative(const char* name, T (*fx1)(T),
+ Vec<D> (*fxN)(D, VecArg<Vec<D>>), D d, T min,
+ T max, double max_relative_error,
+ uint64_t samples = 4000) {
+ if (HWY_MATH_TEST_EXCESS_PRECISION) {
+ static bool once = true;
+ if (once) {
+ once = false;
+ HWY_WARN("Skipping math_test due to GCC issue with excess precision.\n");
+ }
+ return;
+ }
+
+ using UintT = MakeUnsigned<T>;
+
+ const UintT min_bits = BitCastScalar<UintT>(min);
+ const UintT max_bits = BitCastScalar<UintT>(max);
+
+ // If min is negative and max is positive, the range needs to be broken into
+ // two pieces, [+0, max] and [-0, min], otherwise [min, max].
+ int range_count = 1;
+ UintT ranges[2][2] = {{min_bits, max_bits}, {0, 0}};
+ if ((min < 0.0) && (max > 0.0)) {
+ ranges[0][0] = BitCastScalar<UintT>(ConvertScalarTo<T>(+0.0));
+ ranges[0][1] = max_bits;
+ ranges[1][0] = BitCastScalar<UintT>(ConvertScalarTo<T>(-0.0));
+ ranges[1][1] = min_bits;
+ range_count = 2;
+ } else {
+ // If not splitting, ensure we iterate from smaller uint to larger uint.
+ // For negative numbers, min (e.g. -1000) has larger uint representation
+ // than max (e.g. -1).
+ if (ranges[0][0] > ranges[0][1]) {
+ auto tmp = ranges[0][0];
+ ranges[0][0] = ranges[0][1];
+ ranges[0][1] = tmp;
+ }
+ }
+
+ double max_actual_rel_error = 0.0;
+ double max_error_value = 0.0;
+ double sum_rel_error = 0.0;
+ uint64_t count = 0;
+ // Emulation is slower, so we cannot afford as many samples.
+ const UintT kSamplesPerRange =
+ static_cast<UintT>(AdjustedReps(static_cast<size_t>(samples)));
+ for (int range_index = 0; range_index < range_count; ++range_index) {
+ const UintT start = ranges[range_index][0];
+ const UintT stop = ranges[range_index][1];
+ const UintT step = HWY_MAX(1, ((stop - start) / kSamplesPerRange));
+ for (UintT value_bits = start; value_bits <= stop; value_bits += step) {
+ // For reasons unknown, the HWY_MAX is necessary on RVV; otherwise
+ // value_bits can be less than start, and thus possibly NaN.
+ const T value =
+ BitCastScalar<T>(HWY_MIN(HWY_MAX(start, value_bits), stop));
+ const T actual = GetLane(fxN(d, Set(d, value)));
+ const T expected = fx1(value);
+
+ // Skip small inputs and outputs on Armv7, which flushes subnormals to zero.
+#if HWY_TARGET <= HWY_NEON_WITHOUT_AES && HWY_ARCH_ARM_V7
+ if ((std::abs(value) < 1e-37f) || (std::abs(expected) < 1e-37f)) {
+ continue;
+ }
+#endif
+
+ if (std::abs(expected) > 0.0) {
+ double rel = std::abs(static_cast<double>(actual) -
+ static_cast<double>(expected)) /
+ std::abs(static_cast<double>(expected));
+ if (ScalarIsNaN(rel) || rel > max_actual_rel_error) {
+ max_actual_rel_error = rel;
+ max_error_value = static_cast<double>(value);
+ }
+ sum_rel_error += rel;
+ count++;
+ if (rel > max_relative_error) {
+ static int print_count = 0;
+ if (print_count < 10) {
+ fprintf(stderr,
+ "%s: %s(%f) expected %E actual %E rel %E max rel %E\n",
+ hwy::TypeName(T(), Lanes(d)).c_str(), name,
+ static_cast<double>(value), static_cast<double>(expected),
+ static_cast<double>(actual), rel, max_relative_error);
+ print_count++;
+ }
+ }
+ }
+ }
+ }
+ fprintf(stderr, "%s: %s max_rel_error %E at %E\n",
+ hwy::TypeName(T(), Lanes(d)).c_str(), name, max_actual_rel_error,
+ max_error_value);
+ if (count > 0) {
+ fprintf(stderr, "%s: %s avg_rel_error %E\n",
+ hwy::TypeName(T(), Lanes(d)).c_str(), name,
+ sum_rel_error / static_cast<double>(count));
+ }
+ HWY_ASSERT(max_actual_rel_error <= max_relative_error);
+}
+
+struct TestFastLog {
+ template <class T, class D>
+ HWY_NOINLINE void operator()(T, D d) {
+ const double max_relative_error = 1.15E-5;
+ const uint64_t samples = 1000000;
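+ // 1.18e-38f and 2.23e-308 below approximate FLT_MIN and DBL_MIN (the
+ // smallest normal values), so the "PositiveNormal" variants only see
+ // positive, normal inputs.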
+ if (sizeof(T) == 4) {
+ TestMathRelative<T, D>("FastLog", std::log, CallFastLog, d,
+ static_cast<T>(FLT_MIN), static_cast<T>(FLT_MAX),
+ max_relative_error, samples);
+ TestMathRelative<T, D>("FastLogPositiveNormal", std::log,
+ CallFastLogPositiveNormal, d,
+ static_cast<T>(1.18e-38f), static_cast<T>(FLT_MAX),
+ max_relative_error, samples);
+
+ } else {
+ TestMathRelative<T, D>("FastLog", std::log, CallFastLog, d,
+ static_cast<T>(DBL_MIN), static_cast<T>(DBL_MAX),
+ max_relative_error, samples);
+ TestMathRelative<T, D>("FastLogPositiveNormal", std::log,
+ CallFastLogPositiveNormal, d,
+ static_cast<T>(2.23e-308), static_cast<T>(DBL_MAX),
+ max_relative_error, samples);
+ }
+ }
+};
+
+struct TestFastExp {
+ template <class T, class D>
+ HWY_NOINLINE void operator()(T, D d) {
+ if (sizeof(T) == 4) {
+ // Float Normal Range: [-87.0, +88.0]
+ // exp(-87) ~= 1.6e-38 (just above min normal 1.17e-38)
+ TestMathRelative<T, D>("FastExpNormal", std::exp, CallFastExp, d,
+ static_cast<T>(-87.0), static_cast<T>(88.0),
+ 0.000008, 10'000'000);
+
+ // Float Subnormal Range: [-104.0, -87.0]
+ // exp(-104) is very small. Quantization error is expected.
+ TestMathRelative<T, D>("FastExpSubnormal", std::exp, CallFastExp, d,
+ static_cast<T>(-104.0), static_cast<T>(-87.0),
+ 0.03);
+ } else {
+ // Double Normal Range: [-708.0, +706.0]
+ // exp(-708) ~= 2.2e-308 (min normal 2.22e-308)
+ TestMathRelative<T, D>("FastExpNormal", std::exp, CallFastExp, d,
+ static_cast<T>(-708.0), static_cast<T>(706.0),
+ 0.000008, 10'000'000);
+
+ // Double Subnormal Range: [-744.0, -708.0]
+ // exp(-744) is very small. Quantization error is expected.
+ TestMathRelative<T, D>("FastExpSubnormal", std::exp, CallFastExp, d,
+ static_cast<T>(-744.0), static_cast<T>(-708.0),
+ 1.4E-4);
+ }
+ }
+};
+
+struct TestFastExp2 {
+ template <class T, class D>
+ HWY_NOINLINE void operator()(T, D d) {
+ if (sizeof(T) == 4) {
+ // Float Normal Range: [-126.0, +127.0]
+ // exp2(-126) is min normal
+ TestMathRelative<T, D>("FastExp2Normal", std::exp2, CallFastExp2, d,
+ static_cast<T>(-126.0), static_cast<T>(127.0),
+ 0.000008, 10'000'000);
+
+ // Float Subnormal Range: [-150.0, -126.0]
+ TestMathRelative<T, D>("FastExp2Subnormal", std::exp2, CallFastExp2, d,
+ static_cast<T>(-150.0), static_cast<T>(-126.0),
+ 0.0009);
+ } else {
+ // Double Normal Range: [-1022.0, +1023.0]
+ TestMathRelative<T, D>("FastExp2Normal", std::exp2, CallFastExp2, d,
+ static_cast<T>(-1022.0), static_cast<T>(1023.0),
+ 0.000008, 10'000'000);
+
+ // Double Subnormal Range: [-1075.0, -1022.0]
+ TestMathRelative<T, D>("FastExp2Subnormal", std::exp2, CallFastExp2, d,
+ static_cast<T>(-1075.0), static_cast<T>(-1022.0),
+ 0.0004);
+ }
+ }
+};
+
+struct TestFastExpMinusOrZero {
+ template <class T, class D>
+ HWY_NOINLINE void operator()(T, D d) {
+ if (sizeof(T) == 4) {
+ // Float Normal Range: [-87.0, 0.0]
+ TestMathRelative<T, D>("FastExpMinusOrZeroNormal", std::exp,
+ CallFastExpMinusOrZero, d, static_cast<T>(-87.0),
+ static_cast<T>(-0.0), 0.000008, 10'000'000);
+ } else {
+ // Double Normal Range: [-708.0, 0.0]
+ TestMathRelative<T, D>("FastExpMinusOrZeroNormal", std::exp,
+ CallFastExpMinusOrZero, d, static_cast<T>(-708.0),
+ static_cast<T>(-0.0), 0.000008, 10'000'000);
+ }
+ }
+};
+
+struct TestFastLog2 {
+ template <class T, class D>
+ HWY_NOINLINE void operator()(T, D d) {
+ const double max_relative_error = 1.15E-5;
+ const uint64_t samples = 1000000;
+ if (sizeof(T) == 4) {
+ TestMathRelative<T, D>("FastLog2", std::log2, CallFastLog2, d,
+ static_cast<T>(FLT_MIN), static_cast<T>(FLT_MAX),
+ max_relative_error, samples);
+ TestMathRelative<T, D>("FastLog2PositiveNormal", std::log2,
+ CallFastLog2PositiveNormal, d,
+ static_cast<T>(1.18e-38f), static_cast<T>(FLT_MAX),
+ max_relative_error, samples);
+ } else {
+ TestMathRelative<T, D>("FastLog2", std::log2, CallFastLog2, d,
+ static_cast<T>(DBL_MIN), static_cast<T>(DBL_MAX),
+ max_relative_error, samples);
+ TestMathRelative<T, D>("FastLog2PositiveNormal", std::log2,
+ CallFastLog2PositiveNormal, d,
+ static_cast<T>(2.23e-308), static_cast<T>(DBL_MAX),
+ max_relative_error, samples);
+ }
+ }
+};
+
+struct TestFastLog10 {
+ template <class T, class D>
+ HWY_NOINLINE void operator()(T, D d) {
+ const double max_relative_error = 1.15E-5;
+ const uint64_t samples = 1000000;
+ if (sizeof(T) == 4) {
+ TestMathRelative<T, D>("FastLog10", std::log10, CallFastLog10, d,
+ static_cast<T>(FLT_MIN), static_cast<T>(FLT_MAX),
+ max_relative_error, samples);
+ TestMathRelative<T, D>("FastLog10PositiveNormal", std::log10,
+ CallFastLog10PositiveNormal, d,
+ static_cast<T>(1.18e-38f), static_cast<T>(FLT_MAX),
+ max_relative_error, samples);
+ } else {
+ TestMathRelative<T, D>("FastLog10", std::log10, CallFastLog10, d,
+ static_cast<T>(DBL_MIN), static_cast<T>(DBL_MAX),
+ max_relative_error, samples);
+ TestMathRelative<T, D>("FastLog10PositiveNormal", std::log10,
+ CallFastLog10PositiveNormal, d,
+ static_cast<T>(2.23e-308), static_cast<T>(DBL_MAX),
+ max_relative_error, samples);
+ }
+ }
+};
+
+struct TestFastLog1p {
+ template <class T, class D>
+ HWY_NOINLINE void operator()(T, D d) {
+ const double max_relative_error = 1.15E-5;
+ const uint64_t samples = 1000000;
+ if (sizeof(T) == 4) {
+ TestMathRelative<T, D>("FastLog1p", std::log1p, CallFastLog1p, d,
+ static_cast<T>(-0.9f), static_cast<T>(FLT_MAX),
+ max_relative_error, samples);
+ TestMathRelative<T, D>("FastLog1pPositiveNormal", std::log1p,
+ CallFastLog1pPositiveNormal, d,
+ static_cast<T>(0.0f), static_cast<T>(FLT_MAX),
+ max_relative_error, samples);
+ } else {
+ TestMathRelative<T, D>("FastLog1p", std::log1p, CallFastLog1p, d,
+ static_cast<T>(-0.9), static_cast<T>(DBL_MAX),
+ max_relative_error, samples);
+ TestMathRelative<T, D>("FastLog1pPositiveNormal", std::log1p,
+ CallFastLog1pPositiveNormal, d,
+ static_cast<T>(0.0), static_cast<T>(DBL_MAX),
+ max_relative_error, samples);
+ }
+ }
+};
+
+HWY_NOINLINE void TestAllFastExp() {
+ ForFloat3264Types(ForPartialVectors<TestFastExp>());
+}
+
+HWY_NOINLINE void TestAllFastExp2() {
+ ForFloat3264Types(ForPartialVectors<TestFastExp2>());
+}
+
+HWY_NOINLINE void TestAllFastExpMinusOrZero() {
+ ForFloat3264Types(ForPartialVectors<TestFastExpMinusOrZero>());
+}
+
+HWY_NOINLINE void TestAllFastLog() {
+ ForFloat3264Types(ForPartialVectors<TestFastLog>());
+}
+
+HWY_NOINLINE void TestAllFastLog2() {
+ ForFloat3264Types(ForPartialVectors<TestFastLog2>());
+}
+
+HWY_NOINLINE void TestAllFastLog10() {
+ ForFloat3264Types(ForPartialVectors<TestFastLog10>());
+}
+
+HWY_NOINLINE void TestAllFastLog1p() {
+ ForFloat3264Types(ForPartialVectors<TestFastLog1p>());
+}
+
+struct TestFastPow {
+ template <class T, class D>
+ HWY_NOINLINE void operator()(T, D d) {
+ if (HWY_MATH_TEST_EXCESS_PRECISION) {
+ return;
+ }
+
+ const T bases[] = {
+ static_cast<T>(0.1), static_cast<T>(0.5), static_cast<T>(0.99),
+ static_cast<T>(1.0), static_cast<T>(1.0001), static_cast<T>(1.5),
+ static_cast<T>(2.0), static_cast<T>(2.71828), static_cast<T>(10.0),
+ static_cast<T>(100.0), static_cast<T>(1.0e-10), static_cast<T>(1.0e10),
+ };
+
+ double max_actual_rel_error = 0.0;
+ double max_error_base = 0.0;
+ double max_error_exp = 0.0;
+
+ for (T base : bases) {
+ T logb = std::log(base);
+ const T limit = static_cast<T>(25.0); // same bound for float and double
+ T min_exp_val = -limit / logb;
+ T max_exp_val = limit / logb;
+
+ if (min_exp_val > max_exp_val) {
+ T tmp = min_exp_val;
+ min_exp_val = max_exp_val;
+ max_exp_val = tmp;
+ }
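+
+ // With |exp_val| <= 25 / |log(base)|, the expected result satisfies
+ // |log(pow(base, exp_val))| <= 25, i.e. it lies within [e^-25, e^+25]
+ // (roughly [1.4e-11, 7.2e10]) and cannot overflow float or double.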
+
+ using UintT = MakeUnsigned<T>;
+ const UintT min_bits = BitCastScalar<UintT>(min_exp_val);
+ const UintT max_bits = BitCastScalar<UintT>(max_exp_val);
+
+ int range_count = 1;
+ UintT ranges[2][2] = {{min_bits, max_bits}, {0, 0}};
+ if ((min_exp_val < 0.0) && (max_exp_val > 0.0)) {
+ ranges[0][0] = BitCastScalar<UintT>(ConvertScalarTo<T>(+0.0));
+ ranges[0][1] = max_bits;
+ ranges[1][0] = BitCastScalar<UintT>(ConvertScalarTo<T>(-0.0));
+ ranges[1][1] = min_bits;
+ range_count = 2;
+ } else {
+ if (ranges[0][0] > ranges[0][1]) {
+ auto tmp = ranges[0][0];
+ ranges[0][0] = ranges[0][1];
+ ranges[0][1] = tmp;
+ }
+ }
+
+ const UintT kSamplesPerRange =
+ static_cast<UintT>(AdjustedReps(static_cast<size_t>(10000)));
+ for (int range_index = 0; range_index < range_count; ++range_index) {
+ const UintT start = ranges[range_index][0];
+ const UintT stop = ranges[range_index][1];
+ const UintT step = HWY_MAX(1, ((stop - start) / kSamplesPerRange));
+ for (UintT value_bits = start; value_bits <= stop; value_bits += step) {
+ const T exp_val =
+ BitCastScalar<T>(HWY_MIN(HWY_MAX(start, value_bits), stop));
+ const T actual =
+ GetLane(CallFastPow(d, Set(d, base), Set(d, exp_val)));
+ const T expected = std::pow(base, exp_val);
+
+#if HWY_TARGET <= HWY_NEON_WITHOUT_AES && HWY_ARCH_ARM_V7
+ if ((std::abs(exp_val) < 1e-37f) || (std::abs(expected) < 1e-37f)) {
+ continue;
+ }
+#endif
+
+ if (std::abs(expected) > 0.0) {
+ double rel = std::abs(static_cast<double>(actual) -
+ static_cast<double>(expected)) /
+ std::abs(static_cast<double>(expected));
+
+ if (ScalarIsNaN(rel) || rel > max_actual_rel_error) {
+ max_actual_rel_error = rel;
+ max_error_base = static_cast<double>(base);
+ max_error_exp = static_cast<double>(exp_val);
+ }
+ if (rel > 0.0003) {
+ static int print_count = 0;
+ if (print_count < 10) {
+ fprintf(stderr,
+ "%s: FastPow(%f, %f) expected %E actual %E rel %E max "
+ "rel %E\n",
+ hwy::TypeName(T(), Lanes(d)).c_str(),
+ static_cast<double>(base), static_cast<double>(exp_val),
+ static_cast<double>(expected),
+ static_cast<double>(actual), rel, 0.0003);
+ print_count++;
+ }
+ }
+ }
+ }
+ }
+ }
+ fprintf(stderr, "%s: FastPow max_rel_error %E at base=%E exp=%E\n",
+ hwy::TypeName(T(), Lanes(d)).c_str(), max_actual_rel_error,
+ max_error_base, max_error_exp);
+ HWY_ASSERT(max_actual_rel_error <= 0.0003);
+ }
+};
+
+HWY_NOINLINE void TestAllFastPow() {
+ ForFloat3264Types(ForPartialVectors<TestFastPow>());
+}
+
+} // namespace
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+} // namespace HWY_NAMESPACE
+} // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace hwy {
+namespace {
+HWY_BEFORE_TEST(HwyMathTest);
+HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllExp);
+HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllExp2);
+HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllExpm1);
+HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllLog);
+HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllLog10);
+HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllLog1p);
+HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllLog2);
+HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllCbrt);
+HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllFastLog);
+HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllFastExp);
+HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllFastExp2);
+HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllFastExpMinusOrZero);
+HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllFastLog2);
+HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllFastLog10);
+HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllFastLog1p);
+HWY_EXPORT_AND_TEST_P(HwyMathTest, TestAllFastPow);
+HWY_AFTER_TEST();
+} // namespace
+} // namespace hwy
+HWY_TEST_MAIN();
+#endif // HWY_ONCE
diff --git a/third_party/highway/hwy/contrib/math/math_trig_test.cc b/third_party/highway/hwy/contrib/math/math_trig_test.cc
new file mode 100644
index 0000000000..6e8450f2c5
--- /dev/null
+++ b/third_party/highway/hwy/contrib/math/math_trig_test.cc
@@ -0,0 +1,241 @@
+// Copyright 2020 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdint.h>
+#include <stdio.h>
+
+#include <cmath> // std::abs
+
+#include "third_party/highway/hwy/base.h"
+
+// clang-format off
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "hwy/contrib/math/math_trig_test.cc"
+#include "third_party/highway/hwy/foreach_target.h" // IWYU pragma: keep
+#include "third_party/highway/hwy/highway.h"
+#include "third_party/highway/hwy/contrib/math/math-inl.h"
+#include "third_party/highway/hwy/tests/test_util-inl.h"
+// clang-format on
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+namespace {
+
+// We have had test failures caused by excess precision due to keeping
+// intermediate results in 80-bit x87 registers. One such failure mode is that
+// Log1p computes a 1.0 which is not exactly equal to 1.0f, causing is_pole to
+// incorrectly evaluate to false.
+#undef HWY_MATH_TEST_EXCESS_PRECISION
+#if HWY_ARCH_X86_32 && HWY_COMPILER_GCC_ACTUAL && \
+ (HWY_TARGET == HWY_SCALAR || HWY_TARGET == HWY_EMU128)
+
+// GCC 13+: because CMAKE_CXX_EXTENSIONS is OFF, we build with -std= and hence
+// also -fexcess-precision=standard, so there is no problem. See #1708 and
+// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=323.
+#if HWY_COMPILER_GCC_ACTUAL >= 1300
+#define HWY_MATH_TEST_EXCESS_PRECISION 0
+
+#else // HWY_COMPILER_GCC_ACTUAL < 1300
+
+// The build system must enable SSE2, e.g. via HWY_CMAKE_SSE2 - see
+// https://stackoverflow.com/questions/20869904/c-handling-of-excess-precision .
+#if defined(__SSE2__) // correct flag given, no problem
+#define HWY_MATH_TEST_EXCESS_PRECISION 0
+#else
+#define HWY_MATH_TEST_EXCESS_PRECISION 1
+#pragma message( \
+ "Skipping scalar math_test on 32-bit x86 GCC <13 without HWY_CMAKE_SSE2")
+#endif // defined(__SSE2__)
+
+#endif // HWY_COMPILER_GCC_ACTUAL
+#else // not (x86-32, GCC, scalar target): running math_test normally
+#define HWY_MATH_TEST_EXCESS_PRECISION 0
+#endif // HWY_ARCH_X86_32 etc
+
+template <class T, class D>
+HWY_NOINLINE void TestMath(const char* name, T (*fx1)(T),
+ Vec<D> (*fxN)(D, VecArg<Vec<D>>), D d, T min, T max,
+ uint64_t max_error_ulp) {
+ if (HWY_MATH_TEST_EXCESS_PRECISION) {
+ static bool once = true;
+ if (once) {
+ once = false;
+ HWY_WARN("Skipping math_test due to GCC issue with excess precision.\n");
+ }
+ return;
+ }
+
+ using UintT = MakeUnsigned<T>;
+
+ const UintT min_bits = BitCastScalar<UintT>(min);
+ const UintT max_bits = BitCastScalar<UintT>(max);
+
+ // If min is negative and max is positive, the range needs to be broken into
+ // two pieces, [+0, max] and [-0, min], otherwise [min, max].
+ int range_count = 1;
+ UintT ranges[2][2] = {{min_bits, max_bits}, {0, 0}};
+ if ((min < 0.0) && (max > 0.0)) {
+ ranges[0][0] = BitCastScalar<UintT>(ConvertScalarTo<T>(+0.0));
+ ranges[0][1] = max_bits;
+ ranges[1][0] = BitCastScalar<UintT>(ConvertScalarTo<T>(-0.0));
+ ranges[1][1] = min_bits;
+ range_count = 2;
+ }
+
+ uint64_t max_ulp = 0;
+ // Emulation is slower, so cannot afford as many.
+ constexpr UintT kSamplesPerRange = static_cast<UintT>(AdjustedReps(4000));
+ for (int range_index = 0; range_index < range_count; ++range_index) {
+ const UintT start = ranges[range_index][0];
+ const UintT stop = ranges[range_index][1];
+ const UintT step = HWY_MAX(1, ((stop - start) / kSamplesPerRange));
+ for (UintT value_bits = start; value_bits <= stop; value_bits += step) {
+ // For reasons unknown, the HWY_MAX is necessary on RVV, otherwise
+ // value_bits can be less than start, and thus possibly NaN.
+ const T value =
+ BitCastScalar<T>(HWY_MIN(HWY_MAX(start, value_bits), stop));
+ const T actual = GetLane(fxN(d, Set(d, value)));
+ const T expected = fx1(value);
+
+ // Skip small inputs and outputs on armv7, it flushes subnormals to zero.
+#if HWY_TARGET <= HWY_NEON_WITHOUT_AES && HWY_ARCH_ARM_V7
+ if ((std::abs(value) < 1e-37f) || (std::abs(expected) < 1e-37f)) {
+ continue;
+ }
+#endif
+
+ const auto ulp = hwy::detail::ComputeUlpDelta(actual, expected);
+ max_ulp = HWY_MAX(max_ulp, ulp);
+ if (ulp > max_error_ulp) {
+ fprintf(stderr, "%s: %s(%f) expected %E actual %E ulp %g max ulp %u\n",
+ hwy::TypeName(T(), Lanes(d)).c_str(), name, value, expected,
+ actual, static_cast<double>(ulp),
+ static_cast<uint32_t>(max_error_ulp));
+ }
+ }
+ }
+ fprintf(stderr, "%s: %s max_ulp %g\n", hwy::TypeName(T(), Lanes(d)).c_str(),
+ name, static_cast<double>(max_ulp));
+ HWY_ASSERT(max_ulp <= max_error_ulp);
+}
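+
+// Illustrative sketch (our addition, not upstream code) of the bit-stepping
+// scan above, specialized to positive finite floats and samples > 0: stepping
+// the bit pattern instead of the value visits every binade, so tiny and huge
+// inputs receive similar coverage from a fixed number of probes.
+// Example: ScanPositiveFloats(1e-6f, 1e6f, 1000, [](float x) { /* check */ });
+template <class Visitor>
+HWY_MAYBE_UNUSED void ScanPositiveFloats(float min, float max,
+                                         uint32_t samples,
+                                         const Visitor& visit) {
+  const uint32_t start = BitCastScalar<uint32_t>(min);
+  const uint32_t stop = BitCastScalar<uint32_t>(max);
+  const uint32_t step = HWY_MAX(uint32_t{1}, (stop - start) / samples);
+  // For positive floats, bit-pattern order equals value order.
+  for (uint32_t bits = start; bits <= stop; bits += step) {
+    visit(BitCastScalar<float>(bits));
+  }
+}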
+
+#define DEFINE_MATH_TEST_FUNC(NAME) \
+ HWY_NOINLINE void TestAll##NAME() { \
+ ForFloat3264Types(ForPartialVectors<Test##NAME>()); \
+ }
+
+#undef DEFINE_MATH_TEST
+#define DEFINE_MATH_TEST(NAME, F32x1, F32xN, F32_MIN, F32_MAX, F32_ERROR, \
+ F64x1, F64xN, F64_MIN, F64_MAX, F64_ERROR) \
+ struct Test##NAME { \
+ template <class T, class D> \
+ HWY_NOINLINE void operator()(T, D d) { \
+ if (sizeof(T) == 4) { \
+ TestMath<T, D>(HWY_STR(NAME), F32x1, F32xN, d, F32_MIN, F32_MAX, \
+ F32_ERROR); \
+ } else { \
+ TestMath<T, D>(HWY_STR(NAME), F64x1, F64xN, d, \
+ static_cast<T>(F64_MIN), static_cast<T>(F64_MAX), \
+ F64_ERROR); \
+ } \
+ } \
+ }; \
+ DEFINE_MATH_TEST_FUNC(NAME)
+
+// The discrepancy is unacceptably large for MSYS2 (less accurate libm?), so
+// only increase the error tolerance there.
+constexpr uint64_t Cos64ULP() {
+#if defined(__MINGW32__)
+ return 23;
+#else
+ return 3;
+#endif
+}
+
+template <class D>
+static Vec<D> SinCosSin(const D d, VecArg<Vec<D>> x) {
+ Vec<D> s, c;
+ CallSinCos(d, x, s, c);
+ return s;
+}
+
+template <class D>
+static Vec<D> SinCosCos(const D d, VecArg<Vec<D>> x) {
+ Vec<D> s, c;
+ CallSinCos(d, x, s, c);
+ return c;
+}
+
+// On targets without FMA the result is less accurate.
+constexpr uint64_t SinCosSin32ULP() {
+#if !(HWY_NATIVE_FMA)
+ return 256;
+#else
+ return 3;
+#endif
+}
+
+constexpr uint64_t SinCosCos32ULP() {
+#if !(HWY_NATIVE_FMA)
+ return 64;
+#else
+ return 3;
+#endif
+}
+
+// clang-format off
+DEFINE_MATH_TEST(Acos,
+ std::acos, CallAcos, -1.0f, +1.0f, 3, // NEON is 3 instead of 2
+ std::acos, CallAcos, -1.0, +1.0, 2)
+DEFINE_MATH_TEST(Asin,
+ std::asin, CallAsin, -1.0f, +1.0f, 4, // 4 ulp on Armv7, not 2
+ std::asin, CallAsin, -1.0, +1.0, 2)
+// NEON has ULP 4 instead of 3
+DEFINE_MATH_TEST(Cos,
+ std::cos, CallCos, -39000.0f, +39000.0f, 3,
+ std::cos, CallCos, -39000.0, +39000.0, Cos64ULP())
+DEFINE_MATH_TEST(Sin,
+ std::sin, CallSin, -39000.0f, +39000.0f, 3,
+ std::sin, CallSin, -39000.0, +39000.0, 4) // MSYS is 4 instead of 3
+DEFINE_MATH_TEST(SinCosSin,
+ std::sin, SinCosSin, -39000.0f, +39000.0f, SinCosSin32ULP(),
+ std::sin, SinCosSin, -39000.0, +39000.0, 1)
+DEFINE_MATH_TEST(SinCosCos,
+ std::cos, SinCosCos, -39000.0f, +39000.0f, SinCosCos32ULP(),
+ std::cos, SinCosCos, -39000.0, +39000.0, 1)
+// clang-format on
+
+} // namespace
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+} // namespace HWY_NAMESPACE
+} // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace hwy {
+namespace {
+HWY_BEFORE_TEST(HwyMathTrigTest);
+HWY_EXPORT_AND_TEST_P(HwyMathTrigTest, TestAllAcos);
+HWY_EXPORT_AND_TEST_P(HwyMathTrigTest, TestAllAsin);
+HWY_EXPORT_AND_TEST_P(HwyMathTrigTest, TestAllCos);
+HWY_EXPORT_AND_TEST_P(HwyMathTrigTest, TestAllSin);
+HWY_EXPORT_AND_TEST_P(HwyMathTrigTest, TestAllSinCosSin);
+HWY_EXPORT_AND_TEST_P(HwyMathTrigTest, TestAllSinCosCos);
+HWY_AFTER_TEST();
+} // namespace
+} // namespace hwy
+HWY_TEST_MAIN();
+#endif // HWY_ONCE
diff --git a/third_party/highway/hwy/contrib/matvec/matvec-inl.h b/third_party/highway/hwy/contrib/matvec/matvec-inl.h
index 32cd8df147..caa95cd48c 100644
--- a/third_party/highway/hwy/contrib/matvec/matvec-inl.h
+++ b/third_party/highway/hwy/contrib/matvec/matvec-inl.h
@@ -23,6 +23,7 @@
#endif
#include <stddef.h>
+#include <stdint.h>
#include "third_party/highway/hwy/cache_control.h"
#include "third_party/highway/hwy/contrib/thread_pool/thread_pool.h"
@@ -48,13 +49,13 @@ HWY_NOINLINE void MatVecAddImpl(const T* HWY_RESTRICT mat,
// Process multiple rows at a time so that we write multiples of a cache line
// to avoid false sharing (>= 64). 128 is better than 256. 512 has too little
// parallelization potential.
- constexpr size_t kChunkSize = 64 / sizeof(T);
- const uint64_t num_chunks = static_cast<uint64_t>(kOuter / kChunkSize);
+ constexpr size_t kChunkSize2 = 64 / sizeof(T);
+ const uint64_t num_chunks = static_cast<uint64_t>(kOuter / kChunkSize2);
const ScalableTag<T> d;
const size_t N = Lanes(d);
// Required for Stream loop, otherwise we might have partial vectors.
- HWY_DASSERT(kChunkSize >= N);
+ HWY_DASSERT(kChunkSize2 >= N);
pool.Run(0, num_chunks,
[&](const uint64_t chunk, size_t /*thread*/) HWY_ATTR {
// MSVC workaround: duplicate to ensure constexpr.
@@ -125,7 +126,7 @@ HWY_NOINLINE void MatVecAddImpl(const T* HWY_RESTRICT mat,
hwy::FlushStream();
// Handle remainder rows which are not a multiple of the chunk size.
- for (size_t r = num_chunks * kChunkSize; r < kOuter; ++r) {
+ for (size_t r = num_chunks * kChunkSize2; r < kOuter; ++r) {
auto sum0 = Zero(d);
const T* HWY_RESTRICT row = &mat[r * kInner];
@@ -193,8 +194,8 @@ HWY_NOINLINE void MatVecAddImpl(const hwy::bfloat16_t* HWY_RESTRICT mat,
// Process multiple rows at a time so that we write multiples of a cache line
// to avoid false sharing (>= 64). 128 is better than 256. 512 has too little
// parallelization potential.
- constexpr size_t kChunkSize = 64 / sizeof(float);
- const uint64_t num_chunks = static_cast<uint64_t>(kOuter / kChunkSize);
+ constexpr size_t kChunkSize2 = 64 / sizeof(float);
+ const uint64_t num_chunks = static_cast<uint64_t>(kOuter / kChunkSize2);
const ScalableTag<float> d;
const Repartition<hwy::bfloat16_t, decltype(d)> d16;
@@ -206,7 +207,7 @@ HWY_NOINLINE void MatVecAddImpl(const hwy::bfloat16_t* HWY_RESTRICT mat,
using V16H = Vec<decltype(d16h)>;
const size_t N = Lanes(d);
// Required for Stream loop, otherwise we might have partial vectors.
- HWY_DASSERT(kChunkSize >= N);
+ HWY_DASSERT(kChunkSize2 >= N);
pool.Run(0, num_chunks,
[&](const uint64_t chunk, size_t /*thread*/) HWY_ATTR {
// MSVC workaround: duplicate to ensure constexpr.
@@ -284,7 +285,7 @@ HWY_NOINLINE void MatVecAddImpl(const hwy::bfloat16_t* HWY_RESTRICT mat,
hwy::FlushStream();
// Handle remainder rows which are not a multiple of the chunk size.
- for (size_t r = num_chunks * kChunkSize; r < kOuter; ++r) {
+ for (size_t r = num_chunks * kChunkSize2; r < kOuter; ++r) {
auto sum0 = Zero(d);
const hwy::bfloat16_t* HWY_RESTRICT row = &mat[r * kInner];
@@ -333,15 +334,15 @@ HWY_NOINLINE void MatVecAddImpl(const hwy::bfloat16_t* HWY_RESTRICT mat,
// Process multiple rows at a time so that we write multiples of a cache line
// to avoid false sharing (>= 64). 128 is better than 256. 512 has too little
// parallelization potential.
- constexpr size_t kChunkSize = 64 / sizeof(bfloat16_t);
- const uint64_t num_chunks = static_cast<uint64_t>(kOuter / kChunkSize);
+ constexpr size_t kChunkSize2 = 64 / sizeof(bfloat16_t);
+ const uint64_t num_chunks = static_cast<uint64_t>(kOuter / kChunkSize2);
const ScalableTag<float> df;
const Repartition<hwy::bfloat16_t, decltype(df)> d16;
using V16 = Vec<decltype(d16)>;
const size_t N = Lanes(d16);
// Required for Stream loop, otherwise we might have partial vectors.
- HWY_DASSERT(kChunkSize >= N);
+ HWY_DASSERT(kChunkSize2 >= N);
pool.Run(0, num_chunks,
[&](const uint64_t chunk, size_t /*thread*/) HWY_ATTR {
// MSVC workaround: duplicate to ensure constexpr.
@@ -403,7 +404,7 @@ HWY_NOINLINE void MatVecAddImpl(const hwy::bfloat16_t* HWY_RESTRICT mat,
hwy::FlushStream();
// Handle remainder rows which are not a multiple of the chunk size.
- for (size_t r = num_chunks * kChunkSize; r < kOuter; ++r) {
+ for (size_t r = num_chunks * kChunkSize2; r < kOuter; ++r) {
auto sum0 = Zero(df);
auto sum1 = Zero(df);
diff --git a/third_party/highway/hwy/contrib/matvec/matvec_test.cc b/third_party/highway/hwy/contrib/matvec/matvec_test.cc
new file mode 100644
index 0000000000..b09f5b3a6d
--- /dev/null
+++ b/third_party/highway/hwy/contrib/matvec/matvec_test.cc
@@ -0,0 +1,312 @@
+// Copyright 2023 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "third_party/highway/hwy/base.h"
+
+// Reduce targets to avoid timeout under emulation.
+#ifndef HWY_DISABLED_TARGETS
+#define HWY_DISABLED_TARGETS (HWY_SVE2_128 | HWY_SVE2 | HWY_SVE_256 | HWY_NEON)
+#endif
+
+#include <stddef.h>
+#include <stdint.h>
+
+#include <cmath> // std::abs
+
+#include "third_party/highway/hwy/aligned_allocator.h"
+
+// clang-format off
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "hwy/contrib/matvec/matvec_test.cc" // NOLINT
+#include "third_party/highway/hwy/foreach_target.h" // IWYU pragma: keep
+// Must come after foreach_target.h
+#include "third_party/highway/hwy/contrib/algo/transform-inl.h"
+#include "third_party/highway/hwy/contrib/matvec/matvec-inl.h"
+#include "third_party/highway/hwy/highway.h"
+#include "third_party/highway/hwy/contrib/thread_pool/thread_pool.h"
+#include "third_party/highway/hwy/contrib/thread_pool/topology.h"
+#include "third_party/highway/hwy/tests/test_util-inl.h"
+// clang-format on
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+namespace {
+
+template <typename MatT, typename T>
+HWY_NOINLINE void SimpleMatVecAdd(const MatT* HWY_RESTRICT mat,
+ const T* HWY_RESTRICT vec,
+ const T* HWY_RESTRICT add, size_t rows,
+ size_t cols, T* HWY_RESTRICT out,
+ ThreadPool& pool) {
+ if (add) {
+ pool.Run(0, rows, [=](uint64_t r, size_t /*thread*/) {
+ double dot = 0.0;
+ for (size_t c = 0; c < cols; c++) {
+ // For reasons unknown, fp16 += does not compile on clang (Arm).
+ dot += ConvertScalarTo<double>(mat[r * cols + c]) *
+ ConvertScalarTo<double>(vec[c]);
+ }
+ out[r] = ConvertScalarTo<T>(dot + ConvertScalarTo<double>(add[r]));
+ });
+ } else {
+ pool.Run(0, rows, [=](uint64_t r, size_t /*thread*/) {
+ double dot = 0.0;
+ for (size_t c = 0; c < cols; c++) {
+ // For reasons unknown, fp16 += does not compile on clang (Arm).
+ dot += ConvertScalarTo<double>(mat[r * cols + c]) *
+ ConvertScalarTo<double>(vec[c]);
+ }
+ out[r] = ConvertScalarTo<T>(dot);
+ });
+ }
+}
+
+HWY_MAYBE_UNUSED HWY_NOINLINE void SimpleMatVecAdd(
+ const hwy::bfloat16_t* HWY_RESTRICT mat, const float* HWY_RESTRICT vec,
+ const float* add, size_t rows, size_t cols, float* HWY_RESTRICT out,
+ ThreadPool& pool) {
+ if (add) {
+ pool.Run(0, rows, [=](uint64_t r, size_t /*thread*/) {
+ float dot = 0.0f;
+ for (size_t c = 0; c < cols; c++) {
+ dot += F32FromBF16(mat[r * cols + c]) * vec[c];
+ }
+ out[r] = dot + add[r];
+ });
+ } else {
+ pool.Run(0, rows, [=](uint64_t r, size_t /*thread*/) {
+ float dot = 0.0f;
+ for (size_t c = 0; c < cols; c++) {
+ dot += F32FromBF16(mat[r * cols + c]) * vec[c];
+ }
+ out[r] = dot;
+ });
+ }
+}
+
+HWY_MAYBE_UNUSED HWY_NOINLINE void SimpleMatVecAdd(
+ const hwy::bfloat16_t* HWY_RESTRICT mat,
+ const hwy::bfloat16_t* HWY_RESTRICT vec,
+ const hwy::bfloat16_t* HWY_RESTRICT add, size_t rows, size_t cols,
+ float* HWY_RESTRICT out, ThreadPool& pool) {
+ if (add) {
+ pool.Run(0, rows, [=](uint64_t r, size_t /*thread*/) {
+ float dot = 0.0f;
+ for (size_t c = 0; c < cols; c++) {
+ dot += F32FromBF16(mat[r * cols + c]) * F32FromBF16(vec[c]);
+ }
+ out[r] = dot + F32FromBF16(add[r]);
+ });
+ } else {
+ pool.Run(0, rows, [=](uint64_t r, size_t /*thread*/) {
+ float dot = 0.0f;
+ for (size_t c = 0; c < cols; c++) {
+ dot += F32FromBF16(mat[r * cols + c]) * F32FromBF16(vec[c]);
+ }
+ out[r] = dot;
+ });
+ }
+}
+
+// Workaround for incorrect codegen on Arm, which results in values of `av`
+// >= 1E10. Can also be prevented by calling `Print(du, indices)`.
+#if HWY_ARCH_ARM && HWY_COMPILER_CLANG
+#define GENERATE_INLINE HWY_NOINLINE
+#else
+#define GENERATE_INLINE HWY_INLINE
+#endif
+
+struct GenerateMod {
+ template <class D, HWY_IF_NOT_BF16_D(D), HWY_IF_LANES_GT_D(D, 1)>
+ GENERATE_INLINE Vec<D> operator()(D d,
+ Vec<RebindToUnsigned<D>> indices) const {
+ const RebindToUnsigned<D> du;
+ return Reverse2(d, ConvertTo(d, And(indices, Set(du, 0xF))));
+ }
+
+ template <class D, HWY_IF_NOT_BF16_D(D), HWY_IF_LANES_LE_D(D, 1)>
+ GENERATE_INLINE Vec<D> operator()(D d,
+ Vec<RebindToUnsigned<D>> indices) const {
+ const RebindToUnsigned<D> du;
+ return ConvertTo(d, And(indices, Set(du, 0xF)));
+ }
+
+ // Requires >= 4 bf16 lanes for float32 Reverse2.
+ template <class D, HWY_IF_BF16_D(D), HWY_IF_LANES_GT_D(D, 2)>
+ GENERATE_INLINE Vec<D> operator()(D d,
+ Vec<RebindToUnsigned<D>> indices) const {
+ const RebindToUnsigned<D> du;
+ const RebindToSigned<D> di;
+ const RepartitionToWide<decltype(di)> dw;
+ const RebindToFloat<decltype(dw)> df;
+ indices = And(indices, Set(du, 0xF));
+ const Vec<decltype(df)> i0 = ConvertTo(df, PromoteLowerTo(dw, indices));
+ const Vec<decltype(df)> i1 = ConvertTo(df, PromoteUpperTo(dw, indices));
+ return OrderedDemote2To(d, Reverse2(df, i0), Reverse2(df, i1));
+ }
+
+ // For one or two lanes, we don't have OrderedDemote2To nor Reverse2.
+ template <class D, HWY_IF_BF16_D(D), HWY_IF_LANES_LE_D(D, 2)>
+ GENERATE_INLINE Vec<D> operator()(D d,
+ Vec<RebindToUnsigned<D>> indices) const {
+ const Rebind<float, D> df;
+ return DemoteTo(d, Set(df, static_cast<float>(GetLane(indices))));
+ }
+};
+
+// MatT is usually the same as T, but can also be bfloat16_t when T = float.
+template <typename MatT, typename VecT>
+class TestMatVecAdd {
+ template <size_t kRows, size_t kCols, class D, typename T = TFromD<D>>
+ HWY_NOINLINE void Test(D d, ThreadPool& pool) {
+// This target lacks too many ops required in our implementation, use
+// HWY_EMU128 instead.
+#if HWY_TARGET != HWY_SCALAR
+ const Repartition<MatT, D> dm;
+ const Repartition<VecT, D> dv;
+ const size_t misalign = 3 * Lanes(d) / 5;
+ // Fill matrix and vector with small integer values
+ const size_t area = kRows * kCols;
+ AlignedFreeUniquePtr<MatT[]> storage_m =
+ AllocateAligned<MatT>(misalign + area);
+ AlignedFreeUniquePtr<VecT[]> storage_v =
+ AllocateAligned<VecT>(misalign + kCols);
+ AlignedFreeUniquePtr<VecT[]> storage_a =
+ AllocateAligned<VecT>(misalign + kRows);
+ HWY_ASSERT(storage_m && storage_v && storage_a);
+ MatT* pm = storage_m.get() + misalign;
+ VecT* pv = storage_v.get() + misalign;
+ VecT* av = storage_a.get() + misalign;
+ Generate(dm, pm, area, GenerateMod());
+ Generate(dv, pv, kCols, GenerateMod());
+ Generate(dv, av, kRows, GenerateMod());
+
+ AlignedFreeUniquePtr<T[]> expected_without_add = AllocateAligned<T>(kRows);
+ HWY_ASSERT(expected_without_add);
+ SimpleMatVecAdd(pm, pv, static_cast<VecT*>(nullptr), kRows, kCols,
+ expected_without_add.get(), pool);
+
+ AlignedFreeUniquePtr<T[]> actual_without_add = AllocateAligned<T>(kRows);
+ HWY_ASSERT(actual_without_add);
+ MatVec<kRows, kCols>(pm, pv, actual_without_add.get(), pool);
+
+ const auto assert_close = [&](const AlignedFreeUniquePtr<T[]>& expected,
+ const AlignedFreeUniquePtr<T[]>& actual,
+ bool with_add) {
+ for (size_t i = 0; i < kRows; ++i) {
+ const double exp = ConvertScalarTo<double>(expected[i]);
+ const double act = ConvertScalarTo<double>(actual[i]);
+ const double epsilon =
+ 1.0 / (1ULL << HWY_MIN(MantissaBits<MatT>(), MantissaBits<VecT>()));
+ // Allow ~20 ULPs at the coarser of the two precisions.
+ const double tolerance = exp * 20.0 * epsilon;
+ const double l1 = std::abs(exp - act);
+ const double rel = exp == 0.0 ? 0.0 : l1 / exp;
+
+ if (l1 > tolerance && rel > epsilon) {
+ fprintf(stderr,
+ "%s/%s %zu x %zu, %s: mismatch at %zu: %E != %E; "
+ "tol %f l1 %f rel %E\n",
+ TypeName(MatT(), 1).c_str(), TypeName(VecT(), 1).c_str(),
+ kRows, kCols, (with_add ? "with add" : "without add"), i, exp,
+ act, tolerance, l1, rel);
+ HWY_ASSERT(0);
+ }
+ }
+ };
+
+ assert_close(expected_without_add, actual_without_add, /*with_add=*/false);
+
+ AlignedFreeUniquePtr<T[]> expected_with_add = AllocateAligned<T>(kRows);
+ SimpleMatVecAdd(pm, pv, av, kRows, kCols, expected_with_add.get(), pool);
+
+ AlignedFreeUniquePtr<T[]> actual_with_add = AllocateAligned<T>(kRows);
+ MatVecAdd<kRows, kCols>(pm, pv, av, actual_with_add.get(), pool);
+
+ assert_close(expected_with_add, actual_with_add, /*with_add=*/true);
+
+#else
+ (void)d;
+ (void)pool;
+#endif // HWY_TARGET != HWY_SCALAR
+ }
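+
+ // Note on assert_close above (our derivation): with m = min(mantissa bits
+ // of MatT, VecT) and epsilon = 2^-m, a mismatch requires BOTH
+ // |exp - act| > 20 * exp * epsilon AND |exp - act| / exp > epsilon,
+ // i.e. roughly 20 ULPs of slack at the coarser precision. For bf16,
+ // m = 7 and epsilon ~= 0.0078.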
+
+ template <class D>
+ HWY_NOINLINE void CreatePoolAndTest(D d, size_t num_threads) {
+ // Threads might not work on WASM; run only on main thread.
+ if (!HaveThreadingSupport()) num_threads = 0;
+
+ ThreadPool pool(HWY_MIN(num_threads, ThreadPool::MaxThreads()));
+
+ Test<AdjustedReps(192), AdjustedReps(256)>(d, pool);
+// Fewer tests due to compiler OOM
+#if !HWY_ARCH_RISCV
+ Test<40, AdjustedReps(512)>(d, pool);
+ Test<AdjustedReps(1024), 50>(d, pool);
+
+ // Too large for low-precision vectors/accumulators.
+ if (sizeof(TFromD<D>) != 2 && sizeof(VecT) != 2) {
+ Test<AdjustedReps(1536), AdjustedReps(1536)>(d, pool);
+ }
+#endif // !HWY_ARCH_RISCV
+ }
+
+ public:
+ template <class T, class D>
+ HWY_NOINLINE void operator()(T /*unused*/, D d) {
+ CreatePoolAndTest(d, 13);
+// Fewer tests due to compiler OOM
+#if !HWY_ARCH_RISCV
+ CreatePoolAndTest(d, 16);
+#endif
+ }
+};
+
+void TestAllMatVecAdd() {
+#if HWY_HAVE_FLOAT16
+ ForPartialVectors<TestMatVecAdd<float16_t, float16_t>>()(float16_t());
+#endif
+ ForPartialVectors<TestMatVecAdd<float, float>>()(float());
+#if HWY_HAVE_FLOAT64
+ ForPartialVectors<TestMatVecAdd<double, double>>()(double());
+#endif
+}
+
+void TestAllMatVecBF16() {
+ ForGEVectors<32, TestMatVecAdd<bfloat16_t, float>>()(float());
+}
+
+void TestAllMatVecBF16Both() {
+ ForGEVectors<32, TestMatVecAdd<bfloat16_t, bfloat16_t>>()(float());
+}
+
+} // namespace
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+} // namespace HWY_NAMESPACE
+} // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace hwy {
+namespace {
+HWY_BEFORE_TEST(MatVecTest);
+HWY_EXPORT_AND_TEST_P(MatVecTest, TestAllMatVecAdd);
+HWY_EXPORT_AND_TEST_P(MatVecTest, TestAllMatVecBF16);
+HWY_EXPORT_AND_TEST_P(MatVecTest, TestAllMatVecBF16Both);
+HWY_AFTER_TEST();
+} // namespace
+} // namespace hwy
+HWY_TEST_MAIN();
+#endif // HWY_ONCE
diff --git a/third_party/highway/hwy/contrib/random/random-inl.h b/third_party/highway/hwy/contrib/random/random-inl.h
index b96ef8a4ec..ca79a62a90 100644
--- a/third_party/highway/hwy/contrib/random/random-inl.h
+++ b/third_party/highway/hwy/contrib/random/random-inl.h
@@ -21,12 +21,16 @@
#define HIGHWAY_HWY_CONTRIB_RANDOM_RANDOM_H_
#endif
+#include <stddef.h>
+
#include <array>
#include <cstdint>
#include <limits>
#include "third_party/highway/hwy/aligned_allocator.h"
+#include "third_party/highway/hwy/contrib/sort/vqsort.h" // Fill16BytesSecure
#include "third_party/highway/hwy/highway.h"
+#include "third_party/highway/hwy/timer.h"
HWY_BEFORE_NAMESPACE(); // required if not using HWY_ATTR
@@ -35,7 +39,6 @@ namespace hwy {
namespace HWY_NAMESPACE { // required: unique per target
namespace internal {
-namespace {
#if HWY_HAVE_FLOAT64
// C++ < 17 does not support hexfloat
#if __cpp_hex_float > 201603L
@@ -52,7 +55,6 @@ constexpr std::uint64_t kJump[] = {0x180ec6d33cfd0aba, 0xd5a61266f0c9392c,
constexpr std::uint64_t kLongJump[] = {0x76e15d3efefdcbbf, 0xc5004e441c522fb3,
0x77710069854ee241, 0x39109bb02acbe635};
-} // namespace
class SplitMix64 {
public:
@@ -177,6 +179,7 @@ class VectorXoshiro {
#if HWY_HAVE_FLOAT64
using VF64 = Vec<ScalableTag<double>>;
#endif
+
public:
explicit VectorXoshiro(const std::uint64_t seed,
const std::uint64_t threadNumber = 0)
@@ -200,7 +203,7 @@ class VectorXoshiro {
HWY_INLINE VU64 operator()() noexcept { return Next(); }
- AlignedVector<std::uint64_t> operator()(const std::size_t n) {
+ AlignedVector<std::uint64_t> operator()(const size_t n) {
AlignedVector<std::uint64_t> result(n);
const ScalableTag<std::uint64_t> tag{};
auto s0 = Load(tag, state_[{0}].data());
@@ -253,7 +256,7 @@ class VectorXoshiro {
return Mul(real, MUL_VALUE);
}
- AlignedVector<double> Uniform(const std::size_t n) {
+ AlignedVector<double> Uniform(const size_t n) {
AlignedVector<double> result(n);
const ScalableTag<std::uint64_t> tag{};
const ScalableTag<double> real_tag{};
@@ -370,15 +373,133 @@ class CachedXoshiro {
private:
VectorXoshiro generator_;
alignas(HWY_ALIGNMENT) std::array<result_type, size> cache_;
- std::size_t index_;
+ size_t index_;
static_assert((size & (size - 1)) == 0 && size != 0,
"only power of 2 are supported");
};
+// Non-cryptographic 64-bit pseudo-random number generator. Supports random or
+// deterministic seeding.
+//
+// Based on 5-round AES-CTR. Supports 2^64 streams, each with period 2^64. This
+// is useful for parallel sampling. Each thread can generate the stream for a
+// particular task, without caring about prior/subsequent generations.
+class alignas(16) AesCtrEngine {
+ // "Large-scale randomness study of security margins for 100+ cryptographic
+ // functions": at least four.
+ // "Parallel Random Numbers: As Easy as 1, 2, 3": four not Crush-resistant.
+ static constexpr size_t kRounds = 5;
+
+ public:
+ // If `deterministic` is true, uses a fixed seed; otherwise, attempts to
+ // grab entropy from the OS.
+ explicit AesCtrEngine(bool deterministic) {
+ // Pi-based nothing up my sleeve numbers from Randen.
+ key_[0] = 0x243F6A8885A308D3ull;
+ key_[1] = 0x13198A2E03707344ull;
+
+ if (!deterministic) { // want random seed
+ if (!hwy::Fill16BytesSecure(key_)) {
+ HWY_WARN("Failed to fill RNG key with secure random bits");
+ // Entropy not available. The test requires that we inject some
+ // differences relative to the deterministic seeds.
+ key_[0] ^= reinterpret_cast<uint64_t>(this);
+ key_[1] ^= hwy::timer::Start();
+ }
+ }
+
+ // Simple key schedule: swap and add constant (also from Randen).
+ for (size_t i = 0; i < kRounds; ++i) {
+ key_[2 + 2 * i + 0] = key_[2 * i + 1] + 0xA4093822299F31D0ull;
+ key_[2 + 2 * i + 1] = key_[2 * i + 0] + 0x082EFA98EC4E6C89ull;
+ }
+ }
+
+ // Pure and thread safe; typically called via `RngStream`, which increments
+ // `counter`. Throughput is about 100M/s on 3 GHz Skylake. It could be
+ // increased 4x via unrolling by the AES latency (4-7 cycles), but because
+ // users generally call once at a time, this requires buffering, which is not
+ // worth the complexity in this application.
+ uint64_t operator()(uint64_t stream, uint64_t counter) const {
+#if HWY_TARGET != HWY_SCALAR
+ using D = Full128<uint8_t>; // 128 bits for AES
+ using V = Vec<D>;
+ const Repartition<uint64_t, D> d64;
+
+ auto LoadKey = [](const uint64_t* ptr) HWY_ATTR -> V {
+ return Load(D(), reinterpret_cast<const uint8_t*>(ptr));
+ };
+
+ V state = BitCast(D(), Dup128VecFromValues(d64, counter, stream));
+ state = Xor(state, LoadKey(key_)); // initial whitening
+
+ static_assert(kRounds == 5 && sizeof(key_) == 12 * sizeof(uint64_t), "");
+ state = AESRound(state, LoadKey(key_ + 2));
+ state = AESRound(state, LoadKey(key_ + 4));
+ state = AESRound(state, LoadKey(key_ + 6));
+ state = AESRound(state, LoadKey(key_ + 8));
+ // Final round: fine to use another AESRound, including MixColumns.
+ state = AESRound(state, LoadKey(key_ + 10));
+
+ // Return lower 64 bits of the u8 vector.
+ return GetLane(BitCast(d64, state));
+#else
+ HWY_DASSERT(0); // Not supported.
+ (void)stream;
+ (void)counter;
+ return 0;
+#endif // HWY_TARGET != HWY_SCALAR
+ }
+
+ private:
+ uint64_t key_[2 * (1 + kRounds)];
+};
+
+// Flyweight per-thread adapter that maintains the counter. Conforms to C++
+// `UniformRandomBitGenerator`.
+class RngStream {
+ public:
+ RngStream() = default; // Allow C arrays with subsequent initialization.
+
+ // Binds to an engine, which holds the seed and must outlive this object.
+ // Sets the stream; any other `RngStream` with the same `counter_rng` and
+ // `stream` will return the same sequence. This is typically the task ID, so
+ // that threads can independently generate values for each task.
+ RngStream(const AesCtrEngine& counter_rng, uint64_t stream)
+ : engine_(&counter_rng), stream_(stream), counter_(0) {}
+
+ using result_type = uint64_t;
+ static constexpr result_type min() { return 0; }
+ static constexpr result_type max() { return ~result_type{0}; }
+ result_type operator()() { return (*engine_)(stream_, counter_++); }
+
+ private:
+ const AesCtrEngine* engine_ = nullptr;
+ uint64_t stream_ = 0; // immutable after ctor
+ uint64_t counter_ = 0;
+ // Prevent false sharing if used by multiple threads.
+ HWY_MEMBER_VAR_MAYBE_UNUSED uint8_t
+ padding_[HWY_ALIGNMENT - 16 - sizeof(engine_)];
+};
+
+// Returns normalized float in [-1, 1).
+HWY_INLINE float RandomNormalizedFloat(RngStream& rng) {
+ const uint32_t exp = hwy::BitCastScalar<uint32_t>(1.0f);
+ const uint32_t mantissa_mask = hwy::MantissaMask<float>();
+ const uint32_t mantissa =
+ static_cast<uint32_t>(rng() & static_cast<uint64_t>(mantissa_mask));
+ const uint32_t representation = exp | mantissa;
+ const float f12 = hwy::BitCastScalar<float>(representation);
+ HWY_DASSERT(1.0f <= f12 && f12 < 2.0f); // exponent is 2^0, only mantissa
+ const float f = (2.0f * (f12 - 1.0f)) - 1.0f;
+ HWY_DASSERT(-1.0f <= f && f < 1.0f);
+ return f;
+}
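+
+// Usage sketch (illustrative only, not part of the interface): one engine
+// shared by all threads, one RngStream per task, so results do not depend on
+// which thread runs which task.
+HWY_MAYBE_UNUSED inline void ExampleSamplePerTask(const AesCtrEngine& engine,
+                                                  uint64_t num_tasks) {
+  for (uint64_t task = 0; task < num_tasks; ++task) {
+    RngStream rng(engine, /*stream=*/task);  // same (engine, task) => same seq
+    const uint64_t raw = rng();                      // 64 random bits
+    const float noise = RandomNormalizedFloat(rng);  // uniform in [-1, 1)
+    (void)raw;
+    (void)noise;
+  }
+}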
+
} // namespace HWY_NAMESPACE
} // namespace hwy
HWY_AFTER_NAMESPACE();
-#endif // HIGHWAY_HWY_CONTRIB_MATH_MATH_INL_H_
\ No newline at end of file
+#endif // HIGHWAY_HWY_CONTRIB_RANDOM_RANDOM_H_
diff --git a/third_party/highway/hwy/contrib/random/random_test.cc b/third_party/highway/hwy/contrib/random/random_test.cc
new file mode 100644
index 0000000000..8c19a28a03
--- /dev/null
+++ b/third_party/highway/hwy/contrib/random/random_test.cc
@@ -0,0 +1,457 @@
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <cstdint>
+#include <cstdio>
+#include <ctime>
+#include <iostream> // cerr
+#include <random>
+#include <vector>
+
+// clang-format off
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "hwy/contrib/random/random_test.cc" // NOLINT
+#include "third_party/highway/hwy/foreach_target.h" // IWYU pragma: keep
+#include "third_party/highway/hwy/highway.h"
+#include "third_party/highway/hwy/contrib/random/random-inl.h"
+#include "third_party/highway/hwy/tests/test_util-inl.h"
+// clang-format on
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE { // required: unique per target
+namespace {
+
+constexpr std::uint64_t tests = 1UL << 10;
+
+std::uint64_t GetSeed() { return static_cast<uint64_t>(std::time(nullptr)); }
+
+void RngLoop(const std::uint64_t seed, std::uint64_t* HWY_RESTRICT result,
+ const size_t size) {
+ const ScalableTag<std::uint64_t> d;
+ VectorXoshiro generator{seed};
+ for (size_t i = 0; i < size; i += Lanes(d)) {
+ Store(generator(), d, result + i);
+ }
+}
+
+#if HWY_HAVE_FLOAT64
+void UniformLoop(const std::uint64_t seed, double* HWY_RESTRICT result,
+ const size_t size) {
+ const ScalableTag<double> d;
+ VectorXoshiro generator{seed};
+ for (size_t i = 0; i < size; i += Lanes(d)) {
+ Store(generator.Uniform(), d, result + i);
+ }
+}
+#endif
+
+void TestSeeding() {
+ const std::uint64_t seed = GetSeed();
+ VectorXoshiro generator{seed};
+ internal::Xoshiro reference{seed};
+ const auto& state = generator.GetState();
+ const ScalableTag<std::uint64_t> d;
+ const std::size_t lanes = Lanes(d);
+ for (std::size_t i = 0UL; i < lanes; ++i) {
+ const auto& reference_state = reference.GetState();
+ for (std::size_t j = 0UL; j < reference_state.size(); ++j) {
+ if (state[{j}][i] != reference_state[j]) {
+ std::cerr << "SEED: " << seed << "\n";
+ std::cerr << "TEST SEEDING ERROR: ";
+ std::cerr << "state[" << j << "][" << i << "] -> " << state[{j}][i]
+ << " != " << reference_state[j] << "\n";
+ HWY_ASSERT(0);
+ }
+ }
+ reference.Jump();
+ }
+}
+
+void TestMultiThreadSeeding() {
+ const std::uint64_t seed = GetSeed();
+ const std::uint64_t threadId = GetSeed() % 1000;
+ VectorXoshiro generator{seed, threadId};
+ internal::Xoshiro reference{seed};
+
+ for (std::size_t i = 0UL; i < threadId; ++i) {
+ reference.LongJump();
+ }
+
+ const auto& state = generator.GetState();
+ const ScalableTag<std::uint64_t> d;
+ const std::size_t lanes = Lanes(d);
+ for (std::size_t i = 0UL; i < lanes; ++i) {
+ const auto& reference_state = reference.GetState();
+ for (std::size_t j = 0UL; j < reference_state.size(); ++j) {
+ if (state[{j}][i] != reference_state[j]) {
+ std::cerr << "SEED: " << seed << std::endl;
+ std::cerr << "TEST SEEDING ERROR: ";
+ std::cerr << "state[" << j << "][" << i << "] -> " << state[{j}][i]
+ << " != " << reference_state[j] << "\n";
+ HWY_ASSERT(0);
+ }
+ }
+ reference.Jump();
+ }
+}
+
+void TestRandomUint64() {
+ const std::uint64_t seed = GetSeed();
+ const auto result_array = hwy::MakeUniqueAlignedArray<std::uint64_t>(tests);
+ RngLoop(seed, result_array.get(), tests);
+ std::vector<internal::Xoshiro> reference;
+ reference.emplace_back(seed);
+ const ScalableTag<std::uint64_t> d;
+ const std::size_t lanes = Lanes(d);
+ for (std::size_t i = 1UL; i < lanes; ++i) {
+ auto rng = reference.back();
+ rng.Jump();
+ reference.emplace_back(rng);
+ }
+
+ for (std::size_t i = 0UL; i < tests; i += lanes) {
+ for (std::size_t lane = 0UL; lane < lanes; ++lane) {
+ const std::uint64_t result = reference[lane]();
+ if (result_array[i + lane] != result) {
+ std::cerr << "SEED: " << seed << std::endl;
+ std::cerr << "TEST UINT64 GENERATOR ERROR: result_array[" << i + lane
+ << "] -> " << result_array[i + lane] << " != " << result
+ << std::endl;
+ HWY_ASSERT(0);
+ }
+ }
+ }
+}
+void TestUniformDist() {
+#if HWY_HAVE_FLOAT64
+ const std::uint64_t seed = GetSeed();
+ const auto result_array = hwy::MakeUniqueAlignedArray<double>(tests);
+ UniformLoop(seed, result_array.get(), tests);
+ internal::Xoshiro reference{seed};
+ const ScalableTag<double> d;
+ const std::size_t lanes = Lanes(d);
+ for (std::size_t i = 0UL; i < tests; i += lanes) {
+ const double result = reference.Uniform();
+ if (result_array[i] != result) {
+ std::cerr << "SEED: " << seed << std::endl;
+ std::cerr << "TEST UNIFORM GENERATOR ERROR: result_array[" << i << "] -> "
+ << result_array[i] << " != " << result << std::endl;
+ HWY_ASSERT(0);
+ }
+ }
+#endif // HWY_HAVE_FLOAT64
+}
+
+void TestNextNRandomUint64() {
+ const std::uint64_t seed = GetSeed();
+ VectorXoshiro generator{seed};
+ const auto result_array = generator.operator()(tests);
+ std::vector<internal::Xoshiro> reference;
+ reference.emplace_back(seed);
+ const ScalableTag<std::uint64_t> d;
+ const std::size_t lanes = Lanes(d);
+ for (std::size_t i = 1UL; i < lanes; ++i) {
+ auto rng = reference.back();
+ rng.Jump();
+ reference.emplace_back(rng);
+ }
+
+ for (std::size_t i = 0UL; i < tests; i += lanes) {
+ for (std::size_t lane = 0UL; lane < lanes; ++lane) {
+ const std::uint64_t result = reference[lane]();
+ if (result_array[i + lane] != result) {
+ std::cerr << "SEED: " << seed << std::endl;
+ std::cerr << "TEST UINT64 GENERATOR ERROR: result_array[" << i + lane
+ << "] -> " << result_array[i + lane] << " != " << result
+ << std::endl;
+ HWY_ASSERT(0);
+ }
+ }
+ }
+}
+
+void TestNextFixedNRandomUint64() {
+ const std::uint64_t seed = GetSeed();
+ VectorXoshiro generator{seed};
+ const auto result_array = generator.operator()<tests>();
+ std::vector<internal::Xoshiro> reference;
+ reference.emplace_back(seed);
+ const ScalableTag<std::uint64_t> d;
+ const std::size_t lanes = Lanes(d);
+ for (std::size_t i = 1UL; i < lanes; ++i) {
+ auto rng = reference.back();
+ rng.Jump();
+ reference.emplace_back(rng);
+ }
+
+ for (std::size_t i = 0UL; i < tests; i += lanes) {
+ for (std::size_t lane = 0UL; lane < lanes; ++lane) {
+ const std::uint64_t result = reference[lane]();
+ if (result_array[i + lane] != result) {
+ std::cerr << "SEED: " << seed << std::endl;
+ std::cerr << "TEST UINT64 GENERATOR ERROR: result_array[" << i + lane
+ << "] -> " << result_array[i + lane] << " != " << result
+ << std::endl;
+
+ HWY_ASSERT(0);
+ }
+ }
+ }
+}
+void TestNextNUniformDist() {
+#if HWY_HAVE_FLOAT64
+ const std::uint64_t seed = GetSeed();
+ VectorXoshiro generator{seed};
+ const auto result_array = generator.Uniform(tests);
+ internal::Xoshiro reference{seed};
+ const ScalableTag<double> d;
+ const std::size_t lanes = Lanes(d);
+ for (std::size_t i = 0UL; i < tests; i += lanes) {
+ const double result = reference.Uniform();
+ if (result_array[i] != result) {
+ std::cerr << "SEED: " << seed << std::endl;
+ std::cerr << "TEST UNIFORM GENERATOR ERROR: result_array[" << i << "] -> "
+ << result_array[i] << " != " << result << std::endl;
+
+ HWY_ASSERT(0);
+ }
+ }
+#endif // HWY_HAVE_FLOAT64
+}
+
+void TestNextFixedNUniformDist() {
+#if HWY_HAVE_FLOAT64
+ const std::uint64_t seed = GetSeed();
+ VectorXoshiro generator{seed};
+ const auto result_array = generator.Uniform<tests>();
+ internal::Xoshiro reference{seed};
+ const ScalableTag<double> d;
+ const std::size_t lanes = Lanes(d);
+ for (std::size_t i = 0UL; i < tests; i += lanes) {
+ const double result = reference.Uniform();
+ if (result_array[i] != result) {
+ std::cerr << "SEED: " << seed << std::endl;
+ std::cerr << "TEST UNIFORM GENERATOR ERROR: result_array[" << i << "] -> "
+ << result_array[i] << " != " << result << std::endl;
+ HWY_ASSERT(0);
+ }
+ }
+#endif // HWY_HAVE_FLOAT64
+}
+
+void TestCachedXorshiro() {
+ const std::uint64_t seed = GetSeed();
+
+ CachedXoshiro<> generator{seed};
+ std::vector<internal::Xoshiro> reference;
+ reference.emplace_back(seed);
+ const ScalableTag<std::uint64_t> d;
+ const std::size_t lanes = Lanes(d);
+ for (std::size_t i = 1UL; i < lanes; ++i) {
+ auto rng = reference.back();
+ rng.Jump();
+ reference.emplace_back(rng);
+ }
+
+ for (std::size_t i = 0UL; i < tests; i += lanes) {
+ for (std::size_t lane = 0UL; lane < lanes; ++lane) {
+ const std::uint64_t result = reference[lane]();
+ const std::uint64_t got = generator();
+ if (got != result) {
+ std::cerr << "SEED: " << seed << std::endl;
+ std::cerr << "TEST CachedXoshiro GENERATOR ERROR: result_array["
+ << i + lane << "] -> " << got << " != " << result
+ << std::endl;
+
+ HWY_ASSERT(0);
+ }
+ }
+ }
+}
+void TestUniformCachedXorshiro() {
+#if HWY_HAVE_FLOAT64
+ const std::uint64_t seed = GetSeed();
+
+ CachedXoshiro<> generator{seed};
+ std::uniform_real_distribution<double> distribution{0., 1.};
+ for (std::size_t i = 0UL; i < tests; ++i) {
+ const double result = distribution(generator);
+
+ if (result < 0. || result >= 1.) {
+ std::cerr << "SEED: " << seed << std::endl;
+ std::cerr << "TEST CachedXoshiro GENERATOR ERROR: result_array[" << i
+ << "] -> " << result << " not in interval [0, 1)" << std::endl;
+ HWY_ASSERT(0);
+ }
+ }
+#endif // HWY_HAVE_FLOAT64
+}
+
+// ----- AesCtrEngine / RngStream / RandomNormalizedFloat tests -----
+
+#if HWY_TARGET != HWY_SCALAR
+
+void TestAesCtrDeterministic() {
+ const AesCtrEngine engine1(/*deterministic=*/true);
+ const AesCtrEngine engine2(/*deterministic=*/true);
+ RngStream rng1(engine1, 0);
+ RngStream rng2(engine2, 0);
+ // Remember for later testing after resetting the stream.
+ const uint64_t r0 = rng1();
+ const uint64_t r1 = rng1();
+ // Not consecutive values.
+ HWY_ASSERT(r0 != r1);
+ // Let rng2 catch up.
+ HWY_ASSERT(r0 == rng2());
+ HWY_ASSERT(r1 == rng2());
+
+ for (size_t i = 0; i < 1000; ++i) {
+ HWY_ASSERT(rng1() == rng2());
+ }
+
+ // Reset counter, ensure it matches the prior sequence.
+ rng1 = RngStream(engine1, 0);
+ HWY_ASSERT(r0 == rng1());
+ HWY_ASSERT(r1 == rng1());
+}
+
+void TestAesCtrSeeded() {
+ AesCtrEngine engine1(/*deterministic=*/true);
+ AesCtrEngine engine2(/*deterministic=*/false);
+ RngStream rng1(engine1, 0);
+ RngStream rng2(engine2, 0);
+ // It would be very unlucky to have even one 64-bit value match, and two are
+ // extremely unlikely.
+ const uint64_t a0 = rng1();
+ const uint64_t a1 = rng1();
+ const uint64_t b0 = rng2();
+ const uint64_t b1 = rng2();
+ HWY_ASSERT(a0 != b0 || a1 != b1);
+}
+
+void TestAesCtrStreamsDiffer() {
+ AesCtrEngine engine(/*deterministic=*/true);
+ // Compare random streams for more coverage than just the first N streams.
+ RngStream rng_for_stream(engine, 0);
+ for (size_t i = 0; i < 1000; ++i) {
+ RngStream rng1(engine, rng_for_stream());
+ RngStream rng2(engine, rng_for_stream());
+ // It would be very unlucky to have even one 64-bit value match, and two are
+ // extremely unlikely.
+ const uint64_t a0 = rng1();
+ const uint64_t a1 = rng1();
+ const uint64_t b0 = rng2();
+ const uint64_t b1 = rng2();
+ HWY_ASSERT(a0 != b0 || a1 != b1);
+ }
+}
+
+// If not close to 50% 1-bits, the RNG is quite broken.
+void TestAesCtrBitDistribution() {
+ AesCtrEngine engine(/*deterministic=*/true);
+ RngStream rng(engine, 0);
+ constexpr size_t kU64 = 2 * 1000 * 1000;
+ uint64_t one_bits = 0;
+ for (size_t i = 0; i < kU64; ++i) {
+ one_bits += hwy::PopCount(rng());
+ }
+ const uint64_t total_bits = kU64 * 64;
+ const double one_ratio = static_cast<double>(one_bits) / total_bits;
+ fprintf(stderr, "AesCtr 1-bit ratio %.5f\n", one_ratio);
+ HWY_ASSERT(0.4999 <= one_ratio && one_ratio <= 0.5001);
+}
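+
+// Aside (our arithmetic): with N = 1.28e8 bits and p = 0.5, the standard
+// deviation of one_ratio is sqrt(p * (1 - p) / N) ~= 4.4e-5, so the allowed
+// window of +/- 1e-4 is about 2.3 sigma: loose, but catches gross breakage.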
+
+void TestAesCtrChiSquared() {
+ AesCtrEngine engine(/*deterministic=*/true);
+ RngStream rng(engine, 0);
+ constexpr size_t kU64 = 1 * 1000 * 1000;
+
+ // Test each byte separately.
+ for (size_t shift = 0; shift < 64; shift += 8) {
+ size_t counts[256] = {};
+ for (size_t i = 0; i < kU64; ++i) {
+ const size_t byte = (rng() >> shift) & 0xFF;
+ counts[byte]++;
+ }
+
+ double chi_squared = 0.0;
+ const double expected = static_cast<double>(kU64) / 256.0;
+ for (size_t i = 0; i < 256; ++i) {
+ const double diff = static_cast<double>(counts[i]) - expected;
+ chi_squared += diff * diff / expected;
+ }
+ // Should be within ~0.5% and 99.5% percentiles. See
+ // https://www.medcalc.org/manual/chi-square-table.php
+ if (chi_squared < 196.0 || chi_squared > 311.0) {
+ HWY_ABORT("Chi-squared byte %zu: %.5f \n", shift / 8, chi_squared);
+ }
+ }
+}
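+
+// Aside (our arithmetic, not upstream): this is Pearson's chi-squared over
+// 256 byte values, i.e. 255 degrees of freedom:
+//   chi^2 = sum_{i=0}^{255} (O_i - E_i)^2 / E_i,  E_i = 1e6 / 256 ~= 3906.25.
+// The distribution has mean 255 and sd sqrt(2 * 255) ~= 22.6, so the accepted
+// window [196, 311] is roughly the 0.5%..99.5% percentile range cited above.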
+
+void TestRandomNormalizedFloat() {
+ AesCtrEngine engine(/*deterministic=*/true);
+ RngStream rng(engine, 0);
+ constexpr size_t kCount = 100000;
+ double sum = 0.0;
+ for (size_t i = 0; i < kCount; ++i) {
+ const float f = RandomNormalizedFloat(rng);
+ HWY_ASSERT(-1.0f <= f && f < 1.0f);
+ sum += static_cast<double>(f);
+ }
+ // Mean should be near 0 for uniform [-1, 1).
+ const double mean = sum / kCount;
+ fprintf(stderr, "RandomNormalizedFloat mean: %.6f\n", mean);
+ HWY_ASSERT(-0.01 < mean && mean < 0.01);
+}
+
+#else
+
+void TestAesCtrDeterministic() {}
+
+void TestAesCtrSeeded() {}
+
+void TestAesCtrStreamsDiffer() {}
+
+void TestAesCtrBitDistribution() {}
+
+void TestAesCtrChiSquared() {}
+
+void TestRandomNormalizedFloat() {}
+
+#endif // HWY_TARGET != HWY_SCALAR
+
+} // namespace
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+} // namespace HWY_NAMESPACE
+} // namespace hwy
+HWY_AFTER_NAMESPACE(); // required if not using HWY_ATTR
+
+#if HWY_ONCE
+namespace hwy {
+namespace {
+HWY_BEFORE_TEST(HwyRandomTest);
+HWY_EXPORT_AND_TEST_P(HwyRandomTest, TestSeeding);
+HWY_EXPORT_AND_TEST_P(HwyRandomTest, TestMultiThreadSeeding);
+HWY_EXPORT_AND_TEST_P(HwyRandomTest, TestRandomUint64);
+HWY_EXPORT_AND_TEST_P(HwyRandomTest, TestNextNRandomUint64);
+HWY_EXPORT_AND_TEST_P(HwyRandomTest, TestNextFixedNRandomUint64);
+HWY_EXPORT_AND_TEST_P(HwyRandomTest, TestCachedXorshiro);
+HWY_EXPORT_AND_TEST_P(HwyRandomTest, TestUniformDist);
+HWY_EXPORT_AND_TEST_P(HwyRandomTest, TestNextNUniformDist);
+HWY_EXPORT_AND_TEST_P(HwyRandomTest, TestNextFixedNUniformDist);
+HWY_EXPORT_AND_TEST_P(HwyRandomTest, TestUniformCachedXorshiro);
+HWY_EXPORT_AND_TEST_P(HwyRandomTest, TestAesCtrDeterministic);
+HWY_EXPORT_AND_TEST_P(HwyRandomTest, TestAesCtrSeeded);
+HWY_EXPORT_AND_TEST_P(HwyRandomTest, TestAesCtrStreamsDiffer);
+HWY_EXPORT_AND_TEST_P(HwyRandomTest, TestAesCtrBitDistribution);
+HWY_EXPORT_AND_TEST_P(HwyRandomTest, TestAesCtrChiSquared);
+HWY_EXPORT_AND_TEST_P(HwyRandomTest, TestRandomNormalizedFloat);
+HWY_AFTER_TEST();
+} // namespace
+} // namespace hwy
+HWY_TEST_MAIN();
+#endif // HWY_ONCE
diff --git a/third_party/highway/hwy/contrib/sort/BUILD b/third_party/highway/hwy/contrib/sort/BUILD
index 9dd625f5a3..a146efbe3a 100644
--- a/third_party/highway/hwy/contrib/sort/BUILD
+++ b/third_party/highway/hwy/contrib/sort/BUILD
@@ -1,3 +1,7 @@
+load("@rules_cc//cc:cc_binary.bzl", "cc_binary")
+load("@rules_cc//cc:cc_library.bzl", "cc_library")
+load("@rules_cc//cc:cc_test.bzl", "cc_test")
+
package(
default_applicable_licenses = ["//:license"],
default_visibility = ["//visibility:public"],
@@ -82,7 +86,6 @@ cc_library(
)
VQSORT_SRCS = [
- "vqsort.cc",
# Split into separate files to reduce MSVC build time.
"vqsort_128a.cc",
"vqsort_128d.cc",
@@ -98,8 +101,7 @@ VQSORT_SRCS = [
"vqsort_i32d.cc",
"vqsort_i64a.cc",
"vqsort_i64d.cc",
- "vqsort_kv64a.cc",
- "vqsort_kv64d.cc",
+ # vqsort_kv64a/d.cc are in :vqsort_k32v32 and vqsort.cc is in :vqsort_shared.
"vqsort_kv128a.cc",
"vqsort_kv128d.cc",
"vqsort_u16a.cc",
@@ -119,6 +121,45 @@ VQSORT_TEXTUAL_HDRS = [
# Placeholder for internal instrumentation. Do not remove.
]
+# Both :vqsort_k32v32 and :vqsort depend on this.
+cc_library(
+ name = "vqsort_shared",
+ srcs = [
+ "vqsort.cc",
+ ],
+ hdrs = [
+ "order.h", # part of public interface, included by vqsort.h
+ "vqsort.h", # public interface
+ ],
+ compatible_with = [],
+ local_defines = ["hwy_contrib_EXPORTS"],
+ textual_hdrs = VQSORT_TEXTUAL_HDRS,
+ deps = [
+ "//:algo",
+ "//:hwy",
+ ],
+)
+
+cc_library(
+ name = "vqsort_k32v32",
+ srcs = [
+ "vqsort_kv64a.cc",
+ "vqsort_kv64d.cc",
+ ],
+ hdrs = [
+ "order.h", # part of public interface, included by vqsort.h
+ "vqsort.h", # public interface
+ ],
+ compatible_with = [],
+ local_defines = ["hwy_contrib_EXPORTS"],
+ textual_hdrs = VQSORT_TEXTUAL_HDRS,
+ deps = [
+ ":vqsort_shared",
+ "//:algo",
+ "//:hwy",
+ ],
+)
+
cc_library(
name = "vqsort",
srcs = VQSORT_SRCS,
@@ -131,9 +172,12 @@ cc_library(
textual_hdrs = VQSORT_TEXTUAL_HDRS,
deps = [
":intel", # required if HAVE_INTEL
+ ":vqsort_k32v32",
+ ":vqsort_shared",
":vxsort", # required if HAVE_VXSORT
"//:algo",
"//:hwy",
+ "//:nanobenchmark",
],
)
@@ -160,6 +204,7 @@ cc_library(
deps = [
"//:algo",
"//:hwy",
+ "//:nanobenchmark",
],
)
diff --git a/third_party/highway/hwy/contrib/sort/README.md b/third_party/highway/hwy/contrib/sort/README.md
new file mode 100644
index 0000000000..800d2cfcda
--- /dev/null
+++ b/third_party/highway/hwy/contrib/sort/README.md
@@ -0,0 +1,361 @@
+# Vectorized and performance-portable Quicksort
+
+## Introduction
+
+As of 2022-06-07 this sorts large arrays of built-in types about ten times as
+fast as LLVM's `std::sort`. Note that other algorithms such as pdqsort can be
+about twice as fast as LLVM's `std::sort` as of 2023-06.
+
+See also our
+[blog post](https://opensource.googleblog.com/2022/06/Vectorized%20and%20performance%20portable%20Quicksort.html)
+and [paper](https://arxiv.org/abs/2205.05982).
+
+## Instructions
+
+Here are instructions for reproducing our results with cross-platform CMake,
+Linux, or AWS (SVE, NEON).
+
+### CMake, any platform
+
+Please first ensure that Clang (tested with 13.0.1 and 15.0.6) is installed, and
+if it is not the default compiler, point the CC and CXX environment variables to
+it, e.g.
+
+```
+export CC=clang-15
+export CXX=clang++-15
+```
+
+Then run the usual CMake workflow, also documented in the Highway README, e.g.:
+
+```
+mkdir -p build && cd build && cmake .. && make -j
+taskset -c 2 tests/bench_sort
+```
+
+The optional `taskset -c 2` part reduces the variability of measurements by
+preventing the OS from migrating the benchmark between cores.
+
+### Linux
+
+Please first ensure that golang and Clang (tested with 13.0.1) are installed
+via your system's package manager.
+
+```
+go install github.com/bazelbuild/bazelisk@latest
+git clone https://github.com/google/highway
+cd highway
+CC=clang CXX=clang++ ~/go/bin/bazelisk build -c opt hwy/contrib/sort:all
+bazel-bin/hwy/contrib/sort/sort_test
+bazel-bin/hwy/contrib/sort/bench_sort
+```
+
+### AWS Graviton3
+
+Instance config: amazon linux 5.10 arm64, c7g.8xlarge (largest allowed config is
+32 vCPU). Initial launch will fail. Wait a few minutes for an email saying the
+config is verified, then re-launch. See IPv4 hostname in list of instances.
+
+`ssh -i /path/key.pem ec2-user@hostname`
+
+Note that the AWS CMake package is too old to build LLVM, so we first build a
+newer CMake from source:
+```
+wget https://cmake.org/files/v3.23/cmake-3.23.2.tar.gz
+tar -xvzf cmake-3.23.2.tar.gz && cd cmake-3.23.2/
+./bootstrap -- -DCMAKE_USE_OPENSSL=OFF
+make -j8 && sudo make install
+cd ..
+```
+
+AWS clang is at version 11.1, which generates unnecessary `AND` instructions
+that slow down the sort by 1.15x. We tested with clang trunk as of June 13
+(which reports Git hash 8f6512fea000c3a0d394864bb94e524bee375069). To build:
+
+```
+git clone --depth 1 https://github.com/llvm/llvm-project.git
+cd llvm-project
+mkdir -p build && cd build
+/usr/local/bin/cmake ../llvm -DLLVM_ENABLE_PROJECTS="clang" -DLLVM_ENABLE_RUNTIMES="libcxx;libcxxabi" -DCMAKE_BUILD_TYPE=Release
+make -j32 && sudo make install
+```
+
+```
+sudo yum install go
+go install github.com/bazelbuild/bazelisk@latest
+git clone https://github.com/google/highway
+cd highway
+CC=/usr/local/bin/clang CXX=/usr/local/bin/clang++ ~/go/bin/bazelisk build -c opt --copt=-march=armv8.2-a+sve hwy/contrib/sort:all
+bazel-bin/hwy/contrib/sort/sort_test
+bazel-bin/hwy/contrib/sort/bench_sort
+```
+
+The above command line enables SVE, which is currently only available on
+Graviton 3. You can also test NEON on the same processor, or other Arm CPUs, by
+changing the `-march=` option to `--copt=-march=armv8.2-a+crypto`. Note that
+such flags will be unnecessary once Clang supports `#pragma target` for NEON and
+SVE intrinsics, as it does for x86.
+
+## Results
+
+`bench_sort` outputs the instruction set (AVX3 refers to AVX-512), the sort
+algorithm (std for `std::sort`, vq for our vqsort), the type of keys being
+sorted (f32 is float), the distribution of keys (uniform32 for uniform random
+with range 0-2^32), the number of keys, then the throughput of sorted keys (i.e.
+number of key bytes output per second).
+
+Example excerpt from Xeon 6154 (Skylake-X) CPU clocked at 3 GHz:
+
+```
+[ RUN ] BenchSortGroup/BenchSort.BenchAllSort/AVX3
+ AVX3: std: f32: uniform32: 1.00E+06 54 MB/s ( 1 threads)
+ AVX3: vq: f32: uniform32: 1.00E+06 1143 MB/s ( 1 threads)
+```
+
+## Additional results
+
+Thanks to Lukas Bergdoll, who did a thorough [performance analysis](https://github.com/Voultapher/sort-research-rs/blob/main/writeup/intel_avx512/text.md)
+on various sort implementations. This helped us identify a performance bug,
+caused by obtaining entropy from the OS on each call. This was fixed in #1334
+and we look forward to the updated results.
+
+### Optimizations for small arrays
+
+Our initial focus was on large arrays. Since the VQSort paper was published,
+we have improved its performance for small arrays:
+
+- Previously, each call to VQSort obtained entropy from the OS. Unpredictable
+ seeding does help avoid worst-cases, and the cost is negligible when the
+ input size is at least 100K elements. However, the overhead is very costly
+ for arrays of just 100 or 1000, so we now obtain entropy only once per
+ thread and cache the seeds in TLS (see the sketch after this list). This
+ significantly improves performance on subsequent calls. Users can also
+ explicitly initialize the random generator.
+
+- We also improved the efficiency of our sorting network for inputs shorter
+ than half its size. Our approach avoids costly transposes by interpreting
+ inputs as a 2D matrix. Previously, we always used 16 rows, which means only
+ a single vector lane is active for up to 16 elements. We have added 8x2 and
+ 8x4 networks which use more lanes when available, and also 4x1 and 8x1
+ networks for very small inputs.
+
+- Previously we also loaded (overlapping) full vectors, with the offsets
+ determined by the number of columns. Now we use the minimum vector size
+ sufficient for the number of columns, which enables higher IPC on Skylake
+ and reduces the cost of unaligned loads.
+
+ Unfortunately this decreases code reuse; VQSort now consists of about 1500
+ instructions (https://gcc.godbolt.org/z/ojYKfjPe6). The size of sorting
+ networks has nearly doubled to 10.8 KiB, 70% of the total. Although large,
+ this still fits comfortably within 32 KiB instruction caches, and possibly
+ even in micro-op caches (DSB, 1500-2300 micro-ops), especially given that
+ not all instructions are guaranteed to execute.
+
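+A minimal sketch of the once-per-thread entropy caching described in the first
+bullet above (hypothetical names; VQSort's actual implementation differs in
+detail):
+
+```
+#include <cstdint>
+#include <random>
+
+// Returns this thread's cached seeds, querying OS entropy at most once per
+// thread; later calls on the same thread are nearly free.
+std::uint64_t* GetCachedSeeds() {
+  thread_local std::uint64_t seeds[2];
+  thread_local bool initialized = false;
+  if (!initialized) {
+    std::random_device entropy;  // OS entropy: the expensive part
+    seeds[0] = (std::uint64_t{entropy()} << 32) | entropy();
+    seeds[1] = (std::uint64_t{entropy()} << 32) | entropy();
+    initialized = true;
+  }
+  return seeds;
+}
+```
+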
+### Study of AVX-512 downclocking
+
+We study whether AVX-512 downclocking affects performance. Using the GHz
+reported by perf, we find an upper bound on the effects of downclocking, and
+observe that this effect is negligible compared to the speedup over scalar
+code.
+
+This issue has somehow attracted far more attention than seems warranted. An
+attempt by Daniel Lemire to measure the
+[worst-case](https://lemire.me/blog/2018/08/15/the-dangers-of-avx-512-throttling-a-3-impact/)
+only saw a **3% decrease**, and Intel CPUs since Icelake, as well as AMD Zen4,
+are much less impacted by throttling, if at all. By contrast, "Silver" and
+"Bronze" Intel Xeons have more severe throttling and would require a large(r)
+speedup from AVX-512 to outweigh the downclocking. However, these CPUs are
+marketed towards "entry compute, network and storage" and "small business and
+storage server solutions", and are thus less suitable for the high-performance
+workloads we consider.
+
+Our test workstation runs Linux (6.1.20-2rodete1-amd64) and has the same Xeon
+Gold 6154 CPU used in our paper because its Skylake microarchitecture is the
+most (potentially) affected. The compiler is a Clang similar to the LLVM trunk.
+
+We added a new 'cold' benchmark that initializes random seeds, fills an array
+with a constant except at one random index, calls VQSort, and then prints a
+random element to ensure the computations are not elided. To run it, we build
+bench_sort with `-DSORT_ONLY_COLD=1` and then invoke
+`taskset -c 6 setarch -R x86_64 perf stat -r 15 -d bench_sort`. The taskset and
+setarch serve to reduce variability by avoiding thread migration, and disabling
+address space randomization. `-r 15` requests 15 runs so that perf can display
+the variability of the measurements: < 1% for cycles, instructions, L1 dcache
+loads; LLC miss variability is much higher (> 10%) presumably due to the
+remaining background activity on this machine.
+
+For our measurements, we use the GHz value reported by `perf`. This does not
+include time spent in the kernel, and is thus noisy for short runtimes. Note
+that running `perf` under `sudo` is not an option because it results in
+"Workload failed: Cannot allocate memory". We see results between 2.6 - 2.9 GHz
+when running AVX-512 code. This is relative to 3.0 GHz nominal; we disabled
+Turbo Boost via MSR and ran `sudo cpupower frequency-set --governor performance`
+to prevent unnecessary frequency reductions. To the best of our knowledge, the
+remaining gap is explained by time spent in the kernel (in particular handling
+page faults) and downclocking. Thus an *upper-bound* for the latter is
+(3 - 2.9)/3 to (3 - 2.6)/3, or **1.03 - 1.13x**. Such a frequency reduction
+would already be negligible compared to the 2-4x increase in work per cycle
+from 512-bit SIMD relative to 256-bit or 128-bit SIMD, which are typically
+affected less by downclocking, if at all.
+
+To further tighten this bound, we compare AVX-512 code vs. non-AVX-512 code, in
+the form of `std::sort`. Ensuring the remainder of the binary does not use
+AVX-512 is nontrivial. Library functions such as `memset` are known to use
+AVX-512, and they would not show up in a disassembly of our binary. Neither
+would they raise exceptions if run on a CPU lacking AVX-512 support, because
+software typically verifies CPU support before running AVX-512. As a first step,
+we take care to avoid calls to such library functions in our test, which is more
+feasible with a self-contained small binary. In particular, array
+zero-initialization typically compiles to `memset` (verified with clang-16), so
+we manually initialize the array to the return value of an `Unpredictable1`
+function whose implementation is not visible to the compiler. This indeed
+compiles to a scalar loop. To further increase confidence that the binary lacks
+AVX-512 instructions before VQSort, we replace the initialization loop with
+AVX-512 stores. This indeed raises the measured throughput from a fairly
+consistent 9 GB/s to 9-15 GB/s, likely because some of the AVX-512 startup now
+occurs outside of our timings. We examine this effect in the next section, but
+for now we can conclude that because adding AVX-512 makes a difference, the
+binary was otherwise not using it. Now we can revert to scalar initialization
+and compare the GHz reported for VQSort vs. `std::sort`. Across three runs,
+the ranges are 2.8-2.9 GHz for VQSort and a constant 2.8 GHz for `std::sort`.
+Thus we conclude: if there is any
+downclocking for a single core running AVX-512 on this Skylake-X CPU, the effect
+is **under the noise floor of our measurement**, and certainly far below any
+speedup one can reasonably predict from 512-bit SIMD. We expect this result to
+generalize to AMD Zen4 and any Gold/Platinum Intel Xeon.
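+
+The following sketch shows the initialization trick described above. It is a
+simplified version, not the exact benchmark code: because `Unpredictable1` is
+defined in another translation unit, the compiler must assume it can return
+any value, so the loop cannot be lowered to `memset`.
+
+```cpp
+#include <cstddef>
+#include <cstdint>
+
+// Defined in another translation unit; always returns 1 at runtime.
+int Unpredictable1();
+
+void InitArray(uint64_t* items, size_t size) {
+  for (size_t i = 0; i < size; ++i) {
+    // Stores an opaque value; this compiles to a scalar loop, not memset.
+    items[i] = static_cast<uint64_t>(Unpredictable1());
+  }
+}
+```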
+
+### Study of AVX-512 startup overhead
+
+In the previous section, we saw that downclocking is negligible on our system,
+but there is a noticeable benefit to warming up AVX-512 before the sort. To
+understand why, we refer to Travis Downs' excellent
+[measurements](https://travisdowns.github.io/blog/2020/01/17/avxfreq1.html#summary)
+of how Skylake reacts to an AVX-512 instruction: 8-20 us of reduced instruction
+throughput, an additional potential halt of 10 us, and then downclocking.
+Note that downclocking is negligible on a single core per the previous section.
+
+We choose an array length of 10K unsigned 64-bit keys such that VQSort
+completes in 7-10 us. Thus in this benchmark, VQSort (almost) finishes before
+AVX-512 is fully warmed up, and the speedup is reduced because the startup costs
+are amortized over relatively little data. Across five series of 15 runs, the
+average of average throughputs is 9.3 GB/s, implying a runtime of 8.6 us
+including startup costs (10K keys of 8 bytes each are 80 KB, and 80 KB at
+9.3 GB/s takes 8.6 us).
+
+Note that the two-valued, almost all-equal input distribution is quite skewed.
+The above throughput does not reflect the performance attainable on other
+distributions, especially uniform random. However, this choice is deliberate
+because Quicksort can terminate early if all values in a partition are equal.
+When measuring such a 'best-case' input, we are more likely to observe the cost
+of startup overhead in surrounding code. Otherwise, this overhead might be
+hidden by the increase in sorting time.
+
+Now let us compare this throughput to the previously mentioned measurement with
+AVX-512 warmed up (via slow scatter instructions so that initialization takes
+about 100 us, well in excess of the warmup period): 15.2 GB/s, or 5.3 us without
+startup cost. It appears the 10 us halt is not happening, possibly because we
+use neither SIMD floating-point nor multiplication instructions. Thus we only
+experience reduced instruction throughput and/or increased latency. The ratio
+between cold and warmed-up time is only 1.6, which is plausible if the Skylake
+throttling is actually rounding latencies up to a multiple of four cycles, as
+Downs speculates. Indeed a large fraction of the SIMD instructions especially in
+the VQSort base case are cross-lane or 64-bit min/max operations with latencies
+of 3 cycles on Skylake, so their slowdown might only be 1.3x. The measured 1.6x
+could plausibly derive from a mix of 7/8 at 1.3x and 1/8 at 4x (the latter for
+single-cycle latency instructions): 7/8 * 1.3 + 1/8 * 4 is about 1.64.
+
+Assuming this understanding of AVX-512 startup cost is valid, how long does it
+remain active before the CPU reverts to the previous settings? The CPU cannot
+know what future instructions are coming, and to prevent unnecessary
+transitions, it has a hysteresis (delay after the last AVX-512 instruction
+before shutting down) which Downs measures as 680 us. Thus our benchmark
+subsequently sleeps for 100 ms to ensure the next run of the binary sees the
+original CPU state. Indeed we find for the five series that the slopes of the
+lines of best fit are negative in one case, positive in two, and flat in two,
+indicating there is no consistent pattern of benefit for earlier or later runs.
+
+What are the implications for users of VQSort? If the surrounding code executes
+an AVX-512 instruction at least every 500 us, then AVX-512 remains active and
+**any call to VQSort will benefit from it, no matter how small the input**.
+This is a reasonable expectation for modern systems whose designers were aware
+of data-oriented programming principles, because many (though not all) domains
+and operations can benefit from SIMD. By contrast, consider the case of dropping
+VQSort into an existing legacy system that does not yet use SIMD. In the case of
+10K input sizes, we still observe a 2.3x speedup vs. `std::sort`. However, the
+following code may have to deal with throttling for the remainder of the 20 us
+startup period. With VQSort we have 8.6 us runtime plus up to 11.4 us of
+throttled code (potentially running at quarter speed) plus the remaining 3/4
+of that 11.4 us of work, for a total of 28.6 us. With `std::sort` we have
+19.5 us runtime plus 20 us of normal
+subsequent code, or 39.5 us. Thus the overall speedup for the 20 us region plus
+VQSort **shrinks to 1.4x**, and it is possible to imagine an actual slowdown for
+sufficiently small inputs, when factoring in the throttling of subsequent code.
+This unfortunate 'beggar thy neighbor' effect cannot be solved at the level of
+individual building blocks such as a sort, and must instead be addressed at the
+system level. For example:
+
+- vectorizing more and more parts of the code to amortize startup cost;
+- relying on newer CPUs than Skylake (launched 2015!) which have little or no
+ AVX-512 startup overhead, such as Intel Icelake (2021) or AMD Zen4 (2022);
+- ensuring sorts (or anything else using AVX-512) process at least 100 KiB
+ of data, such that the expected speedup outweighs any startup cost.
+
+Any of these solutions is sufficient to render AVX-512 startup overhead a
+non-issue.
+
+### Comparison with Intel's x86-simd-sort and vxsort
+
+Our May 2022 paper compared performance with `ips4o` and `std::sort`. We now add
+results for Intel's [x86-simd-sort](https://github.com/intel/x86-simd-sort),
+released as open source around October 2022, and
+[vxsort](https://github.com/damageboy/vxsort-cpp/tree/master). We find that
+VQSort is generally about 1.4 times as fast as either, and in a few cases equal
+or up to 2% slower.
+
+Note that vxsort was open-sourced around May 2020; we were unaware of it at the
+time of writing because it had been published in the form of a blog series. We
+imported both from GitHub on 2023-06-06 at about 10:15 UTC. Both are integrated
+into our bench_sort, running on the same Linux OS and Xeon 6154 CPU mentioned
+above. We use uniform random inputs, because vxsort and x86-simd-sort appear to
+have much less robust handling of skewed input distributions. They choose the
+pivot as the median of three keys, or of 64 bytes, respectively. By contrast,
+VQSort draws a 384-byte sample and analyzes its distribution, which improves
+load balance and prevents recursing into all-equal partitions. Lacking this, the
+other algorithms are more vulnerable to worst-cases. Choosing uniform random
+thus prevents disadvantaging the other algorithms.
+
+We sample performance across a range of input sizes and types:
+
+- To isolate the performance of the sorting networks used by all three
+ algorithms, we start with powers of two up to 128. VQSort is generally the
+ fastest for 64-bit keys with the following exceptions: tie with vxsort at
+ N=2 (537 MB/s), slower than vxsort at N=16 (2114 vs. 2147 MB/s), tie with
+ x86-simd-sort at N=32 (2643 MB/s). Note that VQSort is about 1.6 times as
+ fast as both others for N=128, possibly because its 2D structure enables
+ larger networks.
+
+- The `kPow10` mode in bench_sort measures power-of-ten input sizes between
+ 10 and 100K. Note that this covers non-power-of-two sizes, as well as the
+ crossover point between sorting networks and Quicksort recursion. The
+ speedups of VQSort relative to x86-simd-sort range from 1.33 to 1.81
+ (32-bit keys), and 1.25 to 1.68 (64-bit keys), with geomeans of 1.48 and
+ 1.44. The speedups of VQSort relative to vxsort range from 1.08 to 2.10
+ (32-bit keys), and 1.00 to 1.47 (64-bit keys), with geomeans of 1.41 and
+ 1.20. Note that vxsort matches VQSort at 10 64-bit elements; in all other
+ cases, VQSort is strictly faster.
+
+- Finally, we study the effect of key type at a fixed input size of 10K
+ elements. x86-simd-sort requires AVX512-VBMI2 for int16, which our CPU does
+ not support. Also, neither of the other algorithms supports 128-bit keys, so
+ we only consider 32/64-bit integer and float types. The results in MB/s are:
+
+ |Type|VQSort|x86-simd-sort|vxsort|
+ |---|---|---|---|
+ |f32|**1551**| 798| 823|
+ |f64|**1773**|1147| 745|
+ |i32|**1509**|1042| 968|
+ |i64|**1365**|1043|1145|
+
+ VQSort is the fastest for each type, in some cases even about twice as fast.
+ Interestingly, vxsort performs at its best on i64, whereas the others are at
+ their best for f64. A potential explanation is that this CPU can execute two
+ f64 min/max per cycle, but only one i64.
+
+In conclusion, VQSort is generally more efficient than vxsort and x86-simd-sort
+across a range of input sizes and types. Occasionally, it is up to 2% slower,
+but the geomean of its speedup (32-bit keys and power-of-ten sizes) vs. vxsort
+is **1.41**, and **1.48** vs. x86-simd-sort.
diff --git a/third_party/highway/hwy/contrib/sort/algo-inl.h b/third_party/highway/hwy/contrib/sort/algo-inl.h
index 1530c5c57d..95cf2bb9ae 100644
--- a/third_party/highway/hwy/contrib/sort/algo-inl.h
+++ b/third_party/highway/hwy/contrib/sort/algo-inl.h
@@ -24,6 +24,7 @@
#include <functional> // std::less, std::greater
#include <vector>
+#include "third_party/highway/hwy/base.h"
#include "third_party/highway/hwy/contrib/sort/vqsort.h"
#include "third_party/highway/hwy/highway.h"
#include "third_party/highway/hwy/print.h"
@@ -207,36 +208,18 @@ enum class Algo {
};
static inline bool IsVQ(Algo algo) {
- switch (algo) {
- case Algo::kVQSort:
- case Algo::kVQPartialSort:
- case Algo::kVQSelect:
- return true;
- default:
- return false;
- }
+ return algo == Algo::kVQSort || algo == Algo::kVQPartialSort ||
+ algo == Algo::kVQSelect;
}
static inline bool IsSelect(Algo algo) {
- switch (algo) {
- case Algo::kStdSelect:
- case Algo::kVQSelect:
- case Algo::kHeapSelect:
- return true;
- default:
- return false;
- }
+ return algo == Algo::kStdSelect || algo == Algo::kVQSelect ||
+ algo == Algo::kHeapSelect;
}
static inline bool IsPartialSort(Algo algo) {
- switch (algo) {
- case Algo::kStdPartialSort:
- case Algo::kVQPartialSort:
- case Algo::kHeapPartialSort:
- return true;
- default:
- return false;
- }
+ return algo == Algo::kStdPartialSort || algo == Algo::kVQPartialSort ||
+ algo == Algo::kHeapPartialSort;
}
static inline Algo ReferenceAlgoFor(Algo algo) {
@@ -451,8 +434,8 @@ InputStats<T> GenerateInput(const Dist dist, T* v, size_t num_lanes) {
}
InputStats<T> input_stats;
- for (size_t i = 0; i < num_lanes; ++i) {
- input_stats.Notify(v[i]);
+ for (size_t j = 0; j < num_lanes; ++j) {
+ input_stats.Notify(v[j]);
}
return input_stats;
}
@@ -606,9 +589,6 @@ void Run(Algo algo, KeyType* inout, size_t num_keys, SharedState& shared,
return CallHeapPartialSort(inout, num_keys, k_keys, Order());
case Algo::kHeapSelect:
return CallHeapSelect(inout, num_keys, k_keys, Order());
-
- default:
- HWY_ABORT("Not implemented");
}
}
diff --git a/third_party/highway/hwy/contrib/sort/bench_parallel.cc b/third_party/highway/hwy/contrib/sort/bench_parallel.cc
new file mode 100644
index 0000000000..34054b5e12
--- /dev/null
+++ b/third_party/highway/hwy/contrib/sort/bench_parallel.cc
@@ -0,0 +1,242 @@
+// Copyright 2021 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Concurrent, independent sorts for generating more memory traffic and testing
+// scalability when bandwidth-limited. If you want to use multiple threads for
+// a single sort, you can use ips4o and integrate vqsort by calling it from
+// `baseCaseSort` and increasing `IPS4OML_BASE_CASE_SIZE` to say 8192.
+
+#include <stdint.h>
+#include <stdio.h>
+
+#include <condition_variable> //NOLINT
+#include <functional>
+#include <mutex> //NOLINT
+#include <thread> //NOLINT
+#include <vector>
+
+#include "third_party/highway/hwy/timer.h"
+
+// clang-format off
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "hwy/contrib/sort/bench_parallel.cc" //NOLINT
+#include "third_party/highway/hwy/foreach_target.h" // IWYU pragma: keep
+
+// After foreach_target
+#include "third_party/highway/hwy/contrib/sort/algo-inl.h"
+#include "third_party/highway/hwy/contrib/sort/result-inl.h"
+#include "third_party/highway/hwy/aligned_allocator.h"
+// Last
+#include "third_party/highway/hwy/tests/test_util-inl.h"
+// clang-format on
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+namespace {
+
+class ThreadPool {
+ public:
+ // Starts the given number of worker threads and blocks until they are ready.
+ explicit ThreadPool(
+ const size_t num_threads = std::thread::hardware_concurrency())
+ : num_threads_(num_threads) {
+ HWY_ASSERT(num_threads_ > 0);
+ threads_.reserve(num_threads_);
+ for (size_t i = 0; i < num_threads_; ++i) {
+ threads_.emplace_back(ThreadFunc, this, i);
+ }
+
+ WorkersReadyBarrier();
+ }
+
+ ThreadPool(const ThreadPool&) = delete;
+ ThreadPool& operator=(const ThreadPool&) = delete;
+
+ // Waits for all threads to exit.
+ ~ThreadPool() {
+ StartWorkers(kWorkerExit);
+
+ for (std::thread& thread : threads_) {
+ thread.join();
+ }
+ }
+
+ size_t NumThreads() const { return threads_.size(); }
+
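+ // Runs `func(thread)` on all workers whose index is below `max_threads`,
+ // then blocks until every worker is ready again, i.e. the task completed.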
+ template <class Func>
+ void RunOnThreads(size_t max_threads, const Func& func) {
+ task_ = &CallClosure<Func>;
+ data_ = &func;
+ StartWorkers(max_threads);
+ WorkersReadyBarrier();
+ }
+
+ private:
+ // After construction and between calls to Run, workers are "ready", i.e.
+ // waiting on worker_start_cv_. They are "started" by sending a "command"
+ // and notifying all worker_start_cv_ waiters. (That is why all workers
+ // must be ready/waiting - otherwise, the notification will not reach all of
+ // them and the main thread waits in vain for them to report readiness.)
+ using WorkerCommand = uint64_t;
+
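+ // Sentinels; any other command value is the maximum number of threads that
+ // may run the task (see ThreadFunc).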
+ static constexpr WorkerCommand kWorkerWait = ~1ULL;
+ static constexpr WorkerCommand kWorkerExit = ~2ULL;
+
+ // Calls a closure (lambda with captures).
+ template <class Closure>
+ static void CallClosure(const void* f, size_t thread) {
+ (*reinterpret_cast<const Closure*>(f))(thread);
+ }
+
+ void WorkersReadyBarrier() {
+ std::unique_lock<std::mutex> lock(mutex_);
+ // Typically only a single iteration.
+ while (workers_ready_ != threads_.size()) {
+ workers_ready_cv_.wait(lock);
+ }
+ workers_ready_ = 0;
+
+ // Safely handle spurious worker wakeups.
+ worker_start_command_ = kWorkerWait;
+ }
+
+ // Precondition: all workers are ready.
+ void StartWorkers(const WorkerCommand worker_command) {
+ std::unique_lock<std::mutex> lock(mutex_);
+ worker_start_command_ = worker_command;
+ // Workers will need this lock, so release it before they wake up.
+ lock.unlock();
+ worker_start_cv_.notify_all();
+ }
+
+ static void ThreadFunc(ThreadPool* self, size_t thread) {
+ // Until kWorkerExit command received:
+ for (;;) {
+ std::unique_lock<std::mutex> lock(self->mutex_);
+ // Notify main thread that this thread is ready.
+ if (++self->workers_ready_ == self->num_threads_) {
+ self->workers_ready_cv_.notify_one();
+ }
+ RESUME_WAIT:
+ // Wait for a command.
+ self->worker_start_cv_.wait(lock);
+ const WorkerCommand command = self->worker_start_command_;
+ switch (command) {
+ case kWorkerWait: // spurious wakeup:
+ goto RESUME_WAIT; // lock still held, avoid incrementing ready.
+ case kWorkerExit:
+ return; // exits thread
+ default:
+ break;
+ }
+
+ lock.unlock();
+ // Command is the maximum number of threads that should run the task.
+ HWY_ASSERT(command < self->NumThreads());
+ if (thread < command) {
+ self->task_(self->data_, thread);
+ }
+ }
+ }
+
+ const size_t num_threads_;
+
+ // Unmodified after ctor, but cannot be const because we call thread::join().
+ std::vector<std::thread> threads_;
+
+ std::mutex mutex_; // guards both cv and their variables.
+ std::condition_variable workers_ready_cv_;
+ size_t workers_ready_ = 0;
+ std::condition_variable worker_start_cv_;
+ WorkerCommand worker_start_command_;
+
+ // Written by main thread, read by workers (after mutex lock/unlock).
+ std::function<void(const void*, size_t)> task_; // points to CallClosure
+ const void* data_; // points to caller's Func
+};
+
+template <class Traits>
+void RunWithoutVerify(Traits st, const Dist dist, const size_t num_keys,
+ const Algo algo, SharedState& shared, size_t thread) {
+ using LaneType = typename Traits::LaneType;
+ using KeyType = typename Traits::KeyType;
+ using Order = typename Traits::Order;
+ const size_t num_lanes = num_keys * st.LanesPerKey();
+ auto aligned = hwy::AllocateAligned<LaneType>(num_lanes);
+
+ (void)GenerateInput(dist, aligned.get(), num_lanes);
+
+ const Timestamp t0;
+ Run(algo, reinterpret_cast<KeyType*>(aligned.get()), num_keys, shared, thread,
+ /*k_keys=*/0, Order());
+ HWY_ASSERT(aligned[0] < aligned[num_lanes - 1]);
+}
+
+void BenchParallel() {
+ // Not interested in benchmark results for other targets on x86
+ if (HWY_ARCH_X86 &&
+ (HWY_TARGET != HWY_AVX2 && HWY_TARGET != HWY_AVX3 &&
+ HWY_TARGET != HWY_AVX3_ZEN4 && HWY_TARGET != HWY_AVX3_SPR)) {
+ return;
+ }
+
+ ThreadPool pool;
+ const size_t NT = pool.NumThreads();
+
+ detail::SharedTraits<detail::TraitsLane<detail::OrderAscending<int64_t>>> st;
+ using KeyType = typename decltype(st)::KeyType;
+ const size_t num_keys = size_t{100} * 1000 * 1000;
+
+#if HAVE_IPS4O
+ const Algo algo = Algo::kIPS4O;
+#else
+ const Algo algo = Algo::kVQSort;
+#endif
+ const Dist dist = Dist::kUniform32;
+
+ SharedState shared;
+
+ std::vector<SortResult> results;
+ for (size_t nt = 1; nt < NT; nt += HWY_MAX(1, NT / 16)) {
+ Timestamp t0;
+ // Default capture because MSVC wants algo/dist but clang does not.
+ pool.RunOnThreads(nt, [=, &shared](size_t thread) {
+ RunWithoutVerify(st, dist, num_keys, algo, shared, thread);
+ });
+ const double sec = SecondsSince(t0);
+ results.emplace_back(algo, dist, num_keys, nt, sec, sizeof(KeyType),
+ st.KeyString());
+ results.back().Print();
+ }
+}
+
+} // namespace
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+} // namespace HWY_NAMESPACE
+} // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+
+namespace hwy {
+namespace {
+HWY_BEFORE_TEST(BenchParallel);
+HWY_EXPORT_AND_TEST_P(BenchParallel, BenchParallel);
+HWY_AFTER_TEST();
+} // namespace
+} // namespace hwy
+
+#endif // HWY_ONCE
diff --git a/third_party/highway/hwy/contrib/sort/bench_sort.cc b/third_party/highway/hwy/contrib/sort/bench_sort.cc
new file mode 100644
index 0000000000..fd03879414
--- /dev/null
+++ b/third_party/highway/hwy/contrib/sort/bench_sort.cc
@@ -0,0 +1,473 @@
+// Copyright 2021 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdint.h>
+#include <stdio.h>
+
+#include <vector>
+
+// clang-format off
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "hwy/contrib/sort/bench_sort.cc"
+#include "third_party/highway/hwy/foreach_target.h" // IWYU pragma: keep
+
+// After foreach_target
+#include "third_party/highway/hwy/contrib/sort/algo-inl.h"
+#include "third_party/highway/hwy/contrib/sort/vqsort.h"
+#include "third_party/highway/hwy/contrib/sort/result-inl.h"
+#include "third_party/highway/hwy/contrib/sort/sorting_networks-inl.h" // SharedTraits
+#include "third_party/highway/hwy/contrib/sort/traits-inl.h"
+#include "third_party/highway/hwy/contrib/sort/traits128-inl.h"
+#include "third_party/highway/hwy/tests/test_util-inl.h"
+#include "third_party/highway/hwy/nanobenchmark.h"
+#include "third_party/highway/hwy/timer.h"
+#include "third_party/highway/hwy/contrib/thread_pool/futex.h" // NanoSleep
+// clang-format on
+
+// Mode for larger sorts because M1 is able to access more than the per-core
+// share of L2, so 1M elements might still be in cache.
+#define SORT_100M 0
+
+#ifndef SORT_ONLY_COLD
+#define SORT_ONLY_COLD 0
+#endif
+#ifndef SORT_BENCH_BASE_AND_PARTITION
+#define SORT_BENCH_BASE_AND_PARTITION (!SORT_ONLY_COLD && 0)
+#endif
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+// Defined within HWY_ONCE, used by BenchAllSort.
+extern int64_t first_sort_target;
+extern int64_t first_cold_target; // for BenchAllColdSort
+
+namespace HWY_NAMESPACE {
+namespace {
+using detail::OrderAscending;
+using detail::OrderDescending;
+using detail::SharedTraits;
+using detail::TraitsLane;
+
+#if HWY_TARGET != HWY_SCALAR
+using detail::OrderAscending128;
+using detail::OrderAscendingKV128;
+using detail::Traits128;
+#endif // HWY_TARGET != HWY_SCALAR
+
+HWY_NOINLINE void BenchAllColdSort() {
+ // Only run the best (first) enabled target.
+ if (first_cold_target == 0) first_cold_target = HWY_TARGET;
+ if (HWY_TARGET != first_cold_target) {
+ return;
+ }
+
+ char cpu100[100];
+ if (!platform::HaveTimerStop(cpu100)) {
+ HWY_WARN("CPU '%s' does not support RDTSCP, skipping benchmark.\n", cpu100);
+ return;
+ }
+
+ // Initialize random seeds
+#if VQSORT_ENABLED
+ HWY_ASSERT(GetGeneratorState() != nullptr); // vqsort
+#endif
+ RandomState rng(static_cast<uint64_t>(Unpredictable1() * 129)); // this test
+
+ using T = uint64_t;
+ constexpr size_t kSize = 10 * 1000;
+ HWY_ALIGN T items[kSize];
+
+ // Initialize array
+#if 0 // optional: deliberate AVX-512 to verify VQSort performance improves
+ const ScalableTag<T> d;
+ const RebindToSigned<decltype(d)> di;
+ const size_t N = Lanes(d);
+ size_t i = 0;
+ for (; i + N <= kSize; i += N) {
+ // Super-slow scatter so that we spend enough time to warm up SKX.
+ const Vec<decltype(d)> val = Set(d, static_cast<T>(Unpredictable1()));
+ const Vec<decltype(di)> idx =
+ Iota(di, static_cast<T>(Unpredictable1() - 1));
+ ScatterIndex(val, d, items + i, idx);
+ }
+ for (; i < kSize; ++i) {
+ items[i] = static_cast<T>(Unpredictable1());
+ }
+#else // scalar-only, verified with clang-16
+ for (size_t i = 0; i < kSize; ++i) {
+ items[i] = static_cast<T>(Unpredictable1());
+ }
+#endif
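+ // Perturb one random element so the input is not entirely constant.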
+ items[Random32(&rng) % kSize] = static_cast<T>(Unpredictable1() + 1);
+
+ const timer::Ticks t0 = timer::Start();
+ const SortAscending order;
+#if VQSORT_ENABLED && 1 // change to && 0 to switch to std::sort.
+ VQSort(items, kSize, order);
+#else
+ SharedState shared;
+ Run(Algo::kStdSort, items, kSize, shared, /*thread=*/0, /*k_keys=*/0, order);
+#endif
+ const timer::Ticks t1 = timer::Stop();
+
+ const double ticks = static_cast<double>(t1 - t0);
+ const double elapsed = ticks / platform::InvariantTicksPerSecond();
+ const double GBps = kSize * sizeof(T) * 1E-9 / elapsed;
+
+ fprintf(stderr, "N=%zu GB/s=%.2f ns=%.1f random output: %g\n", kSize, GBps,
+ elapsed * 1E9, static_cast<double>(items[Random32(&rng) % kSize]));
+
+#if SORT_ONLY_COLD
+ // Long enough for the CPU to switch off AVX-512 mode before the next run.
+ NanoSleep(100 * 1000 * 1000);
+#endif
+}
+
+#if (VQSORT_ENABLED && SORT_BENCH_BASE_AND_PARTITION) || HWY_IDE
+
+template <class Traits>
+HWY_NOINLINE void BenchPartition() {
+ using LaneType = typename Traits::LaneType;
+ using KeyType = typename Traits::KeyType;
+ const SortTag<LaneType> d;
+ detail::SharedTraits<Traits> st;
+ const Dist dist = Dist::kUniform8;
+ double sum = 0.0;
+
+ constexpr size_t kLPK = st.LanesPerKey();
+ HWY_ALIGN LaneType
+ buf[SortConstants::BufBytes<LaneType, kLPK>(HWY_MAX_BYTES) /
+ sizeof(LaneType)];
+ uint64_t* HWY_RESTRICT state = GetGeneratorState();
+
+ const size_t max_log2 = AdjustedLog2Reps(20);
+ for (size_t log2 = max_log2; log2 < max_log2 + 1; ++log2) {
+ const size_t num_lanes = 1ull << log2;
+ const size_t num_keys = num_lanes / kLPK;
+ auto aligned = hwy::AllocateAligned<LaneType>(num_lanes);
+
+ std::vector<double> seconds;
+ const size_t num_reps = (1ull << (14 - log2 / 2)) * 30;
+ for (size_t rep = 0; rep < num_reps; ++rep) {
+ (void)GenerateInput(dist, aligned.get(), num_lanes);
+
+ // The pivot value can influence performance. Do exactly what vqsort will
+ // do so that the performance (influenced by prefetching and branch
+ // prediction) is likely to predict the actual performance inside vqsort.
+ detail::DrawSamples(d, st, aligned.get(), num_lanes, buf, state);
+ detail::SortSamples(d, st, buf);
+ auto pivot = detail::ChoosePivotByRank(d, st, buf);
+
+ const Timestamp t0;
+ detail::Partition(d, st, aligned.get(), num_lanes - 1, pivot, buf);
+ seconds.push_back(SecondsSince(t0));
+ // 'Use' the result to prevent optimizing out the partition.
+ sum += static_cast<double>(aligned.get()[num_lanes / 2]);
+ }
+
+ SortResult(Algo::kVQSort, dist, num_keys, 1, SummarizeMeasurements(seconds),
+ sizeof(KeyType), st.KeyString())
+ .Print();
+ }
+ HWY_ASSERT(sum != 999999); // Prevent optimizing out
+}
+
+HWY_NOINLINE void BenchAllPartition() {
+ // Not interested in benchmark results for these targets
+ if (HWY_TARGET == HWY_SSSE3) {
+ return;
+ }
+
+ BenchPartition<TraitsLane<OrderDescending<float>>>();
+ BenchPartition<TraitsLane<OrderDescending<int32_t>>>();
+ BenchPartition<TraitsLane<OrderDescending<int64_t>>>();
+ BenchPartition<Traits128<OrderAscending128>>();
+ // BenchPartition<Traits128<OrderDescending128>>();
+ BenchPartition<Traits128<OrderAscendingKV128>>();
+}
+
+template <class Traits>
+HWY_NOINLINE void BenchBase(std::vector<SortResult>& results) {
+ // Not interested in benchmark results for these targets
+ if (HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4) {
+ return;
+ }
+
+ using LaneType = typename Traits::LaneType;
+ using KeyType = typename Traits::KeyType;
+ const SortTag<LaneType> d;
+ detail::SharedTraits<Traits> st;
+ const Dist dist = Dist::kUniform32;
+ const Algo algo = Algo::kVQSort;
+
+ const size_t N = Lanes(d);
+ constexpr size_t kLPK = st.LanesPerKey();
+ const size_t num_lanes = SortConstants::BaseCaseNumLanes<kLPK>(N);
+ const size_t num_keys = num_lanes / kLPK;
+ auto keys = hwy::AllocateAligned<LaneType>(num_lanes);
+ auto buf = hwy::AllocateAligned<LaneType>(num_lanes + N);
+
+ std::vector<double> seconds;
+ double sum = 0; // prevents elision
+ constexpr size_t kMul = AdjustedReps(600); // ensures long enough to measure
+
+ for (size_t rep = 0; rep < 30; ++rep) {
+ InputStats<LaneType> input_stats =
+ GenerateInput(dist, keys.get(), num_lanes);
+
+ const Timestamp t0;
+ for (size_t i = 0; i < kMul; ++i) {
+ detail::BaseCase(d, st, keys.get(), num_lanes, buf.get());
+ sum += static_cast<double>(keys[0]);
+ }
+ seconds.push_back(SecondsSince(t0));
+ // printf("%f\n", seconds.back());
+
+ SortOrderVerifier<Traits>()(algo, input_stats, keys.get(), num_keys,
+ num_keys);
+ }
+ HWY_ASSERT(sum < 1E99);
+ results.emplace_back(algo, dist, num_keys * kMul, 1,
+ SummarizeMeasurements(seconds), sizeof(KeyType),
+ st.KeyString());
+}
+
+HWY_NOINLINE void BenchAllBase() {
+ // Not interested in benchmark results for these targets
+ if (HWY_TARGET == HWY_SSSE3) {
+ return;
+ }
+
+ std::vector<SortResult> results;
+ BenchBase<TraitsLane<OrderAscending<float>>>(results);
+ BenchBase<TraitsLane<OrderDescending<int64_t>>>(results);
+ BenchBase<Traits128<OrderAscending128>>(results);
+ for (const SortResult& r : results) {
+ r.Print();
+ }
+}
+
+#endif // VQSORT_ENABLED && SORT_BENCH_BASE_AND_PARTITION
+
+std::vector<Algo> AlgoForBench() {
+ return {
+#if HAVE_AVX2SORT
+ Algo::kSEA,
+#endif
+#if HAVE_PARALLEL_IPS4O
+ Algo::kParallelIPS4O,
+#elif HAVE_IPS4O
+ Algo::kIPS4O,
+#endif
+#if HAVE_PDQSORT
+ Algo::kPDQ,
+#endif
+#if HAVE_SORT512
+ Algo::kSort512,
+#endif
+// Only include if we're compiling for the target it supports.
+#if HAVE_VXSORT && ((VXSORT_AVX3 && HWY_TARGET == HWY_AVX3) || \
+ (!VXSORT_AVX3 && HWY_TARGET == HWY_AVX2))
+ Algo::kVXSort,
+#endif
+// Only include if we're compiling for the target it supports.
+#if HAVE_INTEL && HWY_TARGET <= HWY_AVX3
+ Algo::kIntel,
+#endif
+
+#if !HAVE_PARALLEL_IPS4O
+#if !SORT_100M
+ // 10-20x slower, but that's OK for the default size when we are not
+ // testing the parallel nor 100M modes.
+ // Algo::kStdSort,
+#endif
+
+#if VQSORT_ENABLED
+ Algo::kVQSort,
+#endif
+#endif // !HAVE_PARALLEL_IPS4O
+ };
+}
+
+template <class Traits>
+HWY_NOINLINE void BenchSort(size_t num_keys) {
+ if (first_sort_target == 0) first_sort_target = HWY_TARGET;
+
+ SharedState shared;
+ detail::SharedTraits<Traits> st;
+ using Order = typename Traits::Order;
+ using LaneType = typename Traits::LaneType;
+ using KeyType = typename Traits::KeyType;
+ const size_t num_lanes = num_keys * st.LanesPerKey();
+ auto aligned = hwy::AllocateAligned<LaneType>(num_lanes);
+
+ const size_t reps = num_keys > 1000 * 1000 ? 10 : 30;
+
+ for (Algo algo : AlgoForBench()) {
+ // Other algorithms don't depend on the vector instructions, so only run
+ // them for the first target.
+#if !HAVE_VXSORT
+ if (algo != Algo::kVQSort && HWY_TARGET != first_sort_target) {
+ continue;
+ }
+#endif
+
+ for (Dist dist : AllDist()) {
+ std::vector<double> seconds;
+ for (size_t rep = 0; rep < reps; ++rep) {
+ InputStats<LaneType> input_stats =
+ GenerateInput(dist, aligned.get(), num_lanes);
+
+ const Timestamp t0;
+ Run(algo, HWY_RCAST_ALIGNED(KeyType*, aligned.get()), num_keys, shared,
+ /*thread=*/0, /*k_keys=*/0, Order());
+ seconds.push_back(SecondsSince(t0));
+ // printf("%f\n", seconds.back());
+
+ SortOrderVerifier<Traits>()(algo, input_stats, aligned.get(), num_keys,
+ num_keys);
+ }
+ SortResult(algo, dist, num_keys, 1, SummarizeMeasurements(seconds),
+ sizeof(KeyType), st.KeyString())
+ .Print();
+ } // dist
+ } // algo
+}
+
+enum class BenchmarkModes {
+ kDefault,
+ k1M,
+ k10K,
+ kAllSmall,
+ kSmallPow2,
+ kSmallPow2Between, // includes padding
+ kPow4,
+ kPow10
+};
+
+std::vector<size_t> SizesToBenchmark(BenchmarkModes mode) {
+ std::vector<size_t> sizes;
+ switch (mode) {
+ case BenchmarkModes::kDefault:
+#if HAVE_PARALLEL_IPS4O || SORT_100M
+ sizes.push_back(100 * 1000 * size_t{1000});
+#else
+ sizes.push_back(100);
+ sizes.push_back(100 * 1000);
+#endif
+ break;
+ case BenchmarkModes::k1M:
+ sizes.push_back(1000 * 1000);
+ break;
+ case BenchmarkModes::k10K:
+ sizes.push_back(10 * 1000);
+ break;
+
+ case BenchmarkModes::kAllSmall:
+ sizes.reserve(128);
+ for (size_t i = 1; i <= 128; ++i) {
+ sizes.push_back(i);
+ }
+ break;
+ case BenchmarkModes::kSmallPow2:
+ for (size_t size = 2; size <= 128; size *= 2) {
+ sizes.push_back(size);
+ }
+ break;
+ case BenchmarkModes::kSmallPow2Between:
+ for (size_t size = 2; size <= 128; size *= 2) {
+ sizes.push_back(3 * size / 2);
+ }
+ break;
+
+ case BenchmarkModes::kPow4:
+ for (size_t size = 4; size <= 256 * 1024; size *= 4) {
+ sizes.push_back(size);
+ }
+ break;
+ case BenchmarkModes::kPow10:
+ for (size_t size = 10; size <= 100 * 1000; size *= 10) {
+ sizes.push_back(size);
+ }
+ break;
+ }
+ return sizes;
+}
+
+HWY_NOINLINE void BenchAllSort() {
+ // Not interested in benchmark results for these targets. Note that SSE4 is
+ // numerically less than SSE2, hence it is the lower bound.
+ if (HWY_SSE4 <= HWY_TARGET && HWY_TARGET <= HWY_SSE2 && Unpredictable1()) {
+ return;
+ }
+#if HAVE_INTEL
+ if (HWY_TARGET > HWY_AVX3) return;
+#endif
+
+ for (size_t num_keys : SizesToBenchmark(BenchmarkModes::kSmallPow2)) {
+#if !HAVE_INTEL
+#if HWY_HAVE_FLOAT16
+ if (hwy::HaveFloat16()) {
+ BenchSort<TraitsLane<OtherOrder<float16_t>>>(num_keys);
+ }
+#endif
+ BenchSort<TraitsLane<OrderAscending<float>>>(num_keys);
+#if HWY_HAVE_FLOAT64
+ if (hwy::HaveFloat64()) {
+ // BenchSort<TraitsLane<OtherOrder<double>>>(num_keys);
+ }
+#endif
+#endif // !HAVE_INTEL
+ // BenchSort<TraitsLane<OrderAscending<int16_t>>>(num_keys);
+ BenchSort<TraitsLane<OtherOrder<int32_t>>>(num_keys);
+ BenchSort<TraitsLane<OrderAscending<int64_t>>>(num_keys);
+ // BenchSort<TraitsLane<OtherOrder<uint16_t>>>(num_keys);
+ // BenchSort<TraitsLane<OtherOrder<uint32_t>>>(num_keys);
+ // BenchSort<TraitsLane<OrderAscending<uint64_t>>>(num_keys);
+
+#if !HAVE_VXSORT && !HAVE_INTEL && HWY_TARGET != HWY_SCALAR
+ BenchSort<Traits128<OrderAscending128>>(num_keys);
+ BenchSort<Traits128<OrderAscendingKV128>>(num_keys);
+#endif
+ }
+}
+
+} // namespace
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+} // namespace HWY_NAMESPACE
+} // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+
+namespace hwy {
+int64_t first_sort_target = 0; // none run yet
+int64_t first_cold_target = 0; // none run yet
+HWY_BEFORE_TEST(BenchSort);
+HWY_EXPORT_AND_TEST_P(BenchSort, BenchAllColdSort);
+#if SORT_BENCH_BASE_AND_PARTITION
+HWY_EXPORT_AND_TEST_P(BenchSort, BenchAllPartition);
+HWY_EXPORT_AND_TEST_P(BenchSort, BenchAllBase);
+#endif
+
+#if !SORT_ONLY_COLD // skip (warms up vector unit for next run)
+HWY_EXPORT_AND_TEST_P(BenchSort, BenchAllSort);
+#endif
+HWY_AFTER_TEST();
+} // namespace hwy
+
+HWY_TEST_MAIN();
+
+#endif // HWY_ONCE
diff --git a/third_party/highway/hwy/contrib/sort/print_network.cc b/third_party/highway/hwy/contrib/sort/print_network.cc
new file mode 100644
index 0000000000..35562c6e60
--- /dev/null
+++ b/third_party/highway/hwy/contrib/sort/print_network.cc
@@ -0,0 +1,90 @@
+// Copyright 2021 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdio.h>
+
+#include <vector>
+
+#include "third_party/highway/hwy/base.h"
+
+// Based on A.7 in "Entwurf und Implementierung vektorisierter
+// Sortieralgorithmen" and code by Mark Blacher.
+static void PrintMergeNetwork(int rows, int cols) {
+ printf("\n%d x %d:\n", rows, cols);
+ // Powers of two
+ HWY_ASSERT(rows != 0 && (rows & (rows - 1)) == 0);
+ HWY_ASSERT(cols != 0 && (cols & (cols - 1)) == 0);
+ HWY_ASSERT(rows >= 4);
+ HWY_ASSERT(cols >= 2); // otherwise no cross-column merging required
+ HWY_ASSERT(cols <= 16); // SortTraits lacks Reverse32
+
+ // Log(rows) times: sort half of the vectors with reversed groups of the
+ // other half. Group size halves until we are sorting adjacent vectors.
+ int group_size = rows;
+ int num_groups = 1;
+ for (; group_size >= 2; group_size /= 2, num_groups *= 2) {
+ // All vectors except those being reversed. Allows us to group the
+ // ReverseKeys and Sort2 operations, which is easier to read and may help
+ // in-order machines with high-latency ReverseKeys.
+ std::vector<int> all_vi;
+ for (int group = 0; group < num_groups; ++group) {
+ for (int i = 0; i < group_size / 2; ++i) {
+ all_vi.push_back(group * group_size + i);
+ }
+ }
+ for (int vi : all_vi) {
+ const int vr = vi ^ (group_size - 1);
+ printf("v%x = st.ReverseKeys%d(d, v%x);\n", vr, cols, vr);
+ }
+ for (int vi : all_vi) {
+ const int vr = vi ^ (group_size - 1);
+ printf("st.Sort2(d, v%x, v%x);\n", vi, vr);
+ }
+ printf("\n");
+ }
+
+ // Now merge across columns in all vectors.
+ if (cols > 2) {
+ for (int i = 0; i < rows; ++i) {
+ printf("v%x = st.SortPairsReverse%d(d, v%x);\n", i, cols, i);
+ }
+ printf("\n");
+ }
+ if (cols >= 16) {
+ for (int i = 0; i < rows; ++i) {
+ printf("v%x = st.SortPairsDistance4(d, v%x);\n", i, i);
+ }
+ printf("\n");
+ }
+ if (cols >= 8) {
+ for (int i = 0; i < rows; ++i) {
+ printf("v%x = st.SortPairsDistance2(d, v%x);\n", i, i);
+ }
+ printf("\n");
+ }
+ for (int i = 0; i < rows; ++i) {
+ printf("v%x = st.SortPairsDistance1(d, v%x);\n", i, i);
+ }
+ printf("\n");
+}
+
+int main(int /*argc*/, char** /*argv*/) {
+ PrintMergeNetwork(8, 2);
+ PrintMergeNetwork(8, 4);
+ PrintMergeNetwork(16, 4);
+ PrintMergeNetwork(16, 8);
+ PrintMergeNetwork(16, 16);
+ return 0;
+}
diff --git a/third_party/highway/hwy/contrib/sort/result-inl.h b/third_party/highway/hwy/contrib/sort/result-inl.h
index 5f6d2ca8ec..47659cd399 100644
--- a/third_party/highway/hwy/contrib/sort/result-inl.h
+++ b/third_party/highway/hwy/contrib/sort/result-inl.h
@@ -29,7 +29,6 @@
#include "third_party/highway/hwy/aligned_allocator.h"
#include "third_party/highway/hwy/base.h"
-#include "third_party/highway/hwy/contrib/sort/order.h"
#include "third_party/highway/hwy/per_target.h" // DispatchedTarget
#include "third_party/highway/hwy/targets.h" // TargetName
@@ -51,16 +50,17 @@ static inline double SummarizeMeasurements(std::vector<double>& seconds) {
struct SortResult {
SortResult() {}
- SortResult(const Algo algo, Dist dist, size_t num_keys, size_t num_threads,
- double sec, size_t sizeof_key, const char* key_name)
+ SortResult(Algo algo_in, Dist dist_in, size_t num_keys_in,
+ size_t num_threads_in, double sec_in, size_t sizeof_key_in,
+ const char* key_name_in)
: target(DispatchedTarget()),
- algo(algo),
- dist(dist),
- num_keys(num_keys),
- num_threads(num_threads),
- sec(sec),
- sizeof_key(sizeof_key),
- key_name(key_name) {}
+ algo(algo_in),
+ dist(dist_in),
+ num_keys(num_keys_in),
+ num_threads(num_threads_in),
+ sec(sec_in),
+ sizeof_key(sizeof_key_in),
+ key_name(key_name_in) {}
void Print() const {
const double bytes = static_cast<double>(num_keys) *
diff --git a/third_party/highway/hwy/contrib/sort/shared-inl.h b/third_party/highway/hwy/contrib/sort/shared-inl.h
index e534d3ba92..161e7ed6de 100644
--- a/third_party/highway/hwy/contrib/sort/shared-inl.h
+++ b/third_party/highway/hwy/contrib/sort/shared-inl.h
@@ -141,21 +141,24 @@ static_assert(SortConstants::MaxBufBytes<2>(64) <= 1664, "Unexpectedly high");
// vqsort isn't available on HWY_SCALAR, and builds time out on MSVC opt and
// Armv7 debug, and Armv8 GCC 11 asan hits an internal compiler error likely
// due to https://gcc.gnu.org/bugzilla/show_bug.cgi?id=97696. Armv8 Clang
-// hwasan/msan/tsan/asan also fail to build SVE (b/335157772). RVV currently
-// has a compiler issue.
+// hwasan/msan/tsan/asan also fail to build SVE (b/335157772), and SVE2
+// hits a compiler crash, though SVE2_128 is fine.
#undef VQSORT_ENABLED
#undef VQSORT_COMPILER_COMPATIBLE
-#if (HWY_COMPILER_MSVC && !HWY_IS_DEBUG_BUILD) || \
- (HWY_ARCH_ARM_V7 && HWY_IS_DEBUG_BUILD) || \
- (HWY_ARCH_ARM_A64 && HWY_COMPILER_GCC_ACTUAL && HWY_IS_ASAN) || \
- (HWY_ARCH_RISCV)
+#if (HWY_COMPILER_MSVC && !HWY_IS_DEBUG_BUILD) || \
+ (HWY_ARCH_ARM_V7 && HWY_IS_DEBUG_BUILD) || \
+ (HWY_ARCH_ARM_A64 && HWY_IS_ASAN) || \
+ (HWY_ARCH_RISCV && HWY_COMPILER_GCC_ACTUAL < 1400)
#define VQSORT_COMPILER_COMPATIBLE 0
#else
#define VQSORT_COMPILER_COMPATIBLE 1
#endif
-#if (HWY_TARGET == HWY_SCALAR) || !VQSORT_COMPILER_COMPATIBLE
+#if (HWY_TARGET == HWY_SCALAR) || !VQSORT_COMPILER_COMPATIBLE || \
+ ((HWY_TARGET & HWY_ALL_SVE) && defined(SAFESTACK_SANITIZER)) || \
+ ((HWY_TARGET & HWY_ALL_SVE) && HWY_HAVE_SCALABLE) || \
+ defined(HWY_DISABLE_VQSORT)
#define VQSORT_ENABLED 0
#else
#define VQSORT_ENABLED 1
diff --git a/third_party/highway/hwy/contrib/sort/sort_test.cc b/third_party/highway/hwy/contrib/sort/sort_test.cc
new file mode 100644
index 0000000000..d1ec181ecd
--- /dev/null
+++ b/third_party/highway/hwy/contrib/sort/sort_test.cc
@@ -0,0 +1,277 @@
+// Copyright 2021 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdint.h>
+#include <stdio.h>
+
+#include <numeric> // std::iota
+#include <random>
+#include <vector>
+
+#if !defined(HWY_DISABLED_TARGETS) && HWY_IS_DEBUG_BUILD
+#define HWY_DISABLED_TARGETS (HWY_SSE2 | HWY_SSSE3)
+#endif
+
+#include "third_party/highway/hwy/aligned_allocator.h" // IsAligned
+#include "third_party/highway/hwy/base.h"
+#include "third_party/highway/hwy/contrib/sort/vqsort.h"
+#include "third_party/highway/hwy/contrib/thread_pool/thread_pool.h"
+#include "third_party/highway/hwy/contrib/thread_pool/topology.h"
+#include "third_party/highway/hwy/per_target.h"
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "hwy/contrib/sort/sort_test.cc"
+#include "third_party/highway/hwy/foreach_target.h" // IWYU pragma: keep
+#include "third_party/highway/hwy/highway.h"
+// After highway.h
+#include "third_party/highway/hwy/contrib/sort/algo-inl.h"
+#include "third_party/highway/hwy/contrib/sort/result-inl.h"
+#include "third_party/highway/hwy/contrib/sort/vqsort-inl.h" // BaseCase
+#include "third_party/highway/hwy/print-inl.h"
+#include "third_party/highway/hwy/tests/test_util-inl.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+namespace {
+
+using detail::OrderAscending;
+using detail::OrderAscendingKV64;
+using detail::OrderDescendingKV64;
+using detail::SharedTraits;
+using detail::TraitsLane;
+
+#if !HAVE_INTEL && HWY_TARGET != HWY_SCALAR
+using detail::OrderAscending128;
+using detail::OrderAscendingKV128;
+using detail::OrderDescending128;
+using detail::OrderDescendingKV128;
+using detail::Traits128;
+#endif // !HAVE_INTEL && HWY_TARGET != HWY_SCALAR
+
+template <typename Key>
+void TestSortIota(hwy::ThreadPool& pool) {
+ pool.Run(128, 300, [](uint64_t task, size_t /*thread*/) {
+ const size_t num = static_cast<size_t>(task);
+ Key keys[300];
+ std::iota(keys, keys + num, Key{0});
+ VQSort(keys, num, hwy::SortAscending());
+ for (size_t i = 0; i < num; ++i) {
+ if (keys[i] != static_cast<Key>(i)) {
+ HWY_ABORT("num %zu i %zu: not iota, got %.0f\n", num, i,
+ static_cast<double>(keys[i]));
+ }
+ }
+ });
+}
+
+void TestAllSortIota() {
+#if VQSORT_ENABLED
+ hwy::ThreadPool pool(hwy::HaveThreadingSupport() ? 4 : 0);
+ TestSortIota<uint32_t>(pool);
+ TestSortIota<int32_t>(pool);
+ if (hwy::HaveInteger64()) {
+ TestSortIota<int64_t>(pool);
+ TestSortIota<uint64_t>(pool);
+ }
+ TestSortIota<float>(pool);
+ if (hwy::HaveFloat64()) {
+ TestSortIota<double>(pool);
+ }
+#endif
+}
+
+// Supports full/partial sort and select.
+template <class Traits>
+void TestAnySort(const std::vector<Algo>& algos, size_t num_lanes) {
+// Workaround for stack overflow on clang-cl (/F 8388608 does not help).
+#if defined(_MSC_VER)
+ return;
+#endif
+ using Order = typename Traits::Order;
+ using LaneType = typename Traits::LaneType;
+ using KeyType = typename Traits::KeyType;
+ SharedState shared;
+ SharedTraits<Traits> st;
+
+ constexpr size_t kLPK = st.LanesPerKey();
+ num_lanes = hwy::RoundUpTo(num_lanes, kLPK);
+ const size_t num_keys = num_lanes / kLPK;
+
+ std::mt19937 rng(42);
+ std::uniform_int_distribution<size_t> k_dist(1, num_keys - 1);
+
+ constexpr size_t kMaxMisalign = 16;
+ auto aligned =
+ hwy::AllocateAligned<LaneType>(kMaxMisalign + num_lanes + kMaxMisalign);
+ HWY_ASSERT(aligned);
+
+ for (Algo algo : algos) {
+ if (IsVQ(algo) && !VQSORT_ENABLED) continue;
+
+ for (Dist dist : AllDist()) {
+ for (size_t misalign :
+ {size_t{0}, size_t{kLPK}, size_t{3 * kLPK}, kMaxMisalign / 2}) {
+ for (size_t k_rep = 0; k_rep < AdjustedReps(10); ++k_rep) {
+ // Skip reps for full sort because they do not use k.
+ if (!IsPartialSort(algo) && !IsSelect(algo) && k_rep > 0) break;
+
+ LaneType* lanes = aligned.get() + misalign;
+ HWY_ASSERT(hwy::IsAligned(lanes, sizeof(KeyType)));
+ KeyType* keys = HWY_RCAST_ALIGNED(KeyType*, lanes);
+
+ // Set up red zones before/after the keys to sort
+ for (size_t i = 0; i < misalign; ++i) {
+ aligned[i] = hwy::LowestValue<LaneType>();
+ }
+ for (size_t i = 0; i < kMaxMisalign; ++i) {
+ lanes[num_lanes + i] = hwy::HighestValue<LaneType>();
+ }
+ detail::MaybePoison(aligned.get(), misalign * sizeof(LaneType));
+ detail::MaybePoison(lanes + num_lanes,
+ kMaxMisalign * sizeof(LaneType));
+
+ InputStats<LaneType> input_stats =
+ GenerateInput(dist, lanes, num_lanes);
+ ReferenceSortVerifier<Traits> reference_verifier(lanes, num_lanes);
+ const size_t k_keys = k_dist(rng);
+ Run(algo, keys, num_keys, shared, /*thread=*/0, k_keys, Order());
+ reference_verifier(algo, lanes, k_keys);
+ SortOrderVerifier<Traits>()(algo, input_stats, lanes, num_keys,
+ k_keys);
+
+ // Check red zones
+ detail::MaybeUnpoison(aligned.get(), misalign);
+ detail::MaybeUnpoison(lanes + num_lanes, kMaxMisalign);
+ for (size_t i = 0; i < misalign; ++i) {
+ if (aligned[i] != hwy::LowestValue<LaneType>())
+ HWY_ABORT("Overrun left at %d\n", static_cast<int>(i));
+ }
+ for (size_t i = num_lanes; i < num_lanes + kMaxMisalign; ++i) {
+ if (lanes[i] != hwy::HighestValue<LaneType>())
+ HWY_ABORT("Overrun right at %d\n", static_cast<int>(i));
+ }
+ } // k_rep
+ } // misalign
+ } // dist
+ } // algo
+}
+
+// Calls TestAnySort with all traits.
+void CallAllSortTraits(const std::vector<Algo>& algos, size_t num_lanes) {
+#if !HAVE_INTEL
+ TestAnySort<TraitsLane<OrderAscending<int16_t>>>(algos, num_lanes);
+ TestAnySort<TraitsLane<OtherOrder<uint16_t>>>(algos, num_lanes);
+#endif
+
+ TestAnySort<TraitsLane<OtherOrder<int32_t>>>(algos, num_lanes);
+ TestAnySort<TraitsLane<OtherOrder<uint32_t>>>(algos, num_lanes);
+
+ TestAnySort<TraitsLane<OrderAscending<int64_t>>>(algos, num_lanes);
+ TestAnySort<TraitsLane<OrderAscending<uint64_t>>>(algos, num_lanes);
+
+ // WARNING: for float types, SIMD comparisons will flush denormals to
+ // zero, causing mismatches with scalar sorts. In this test, we avoid
+ // generating denormal inputs.
+#if HWY_HAVE_FLOAT16 // #if protects algo-inl.h's GenerateRandom
+ // Must also check whether the dynamic-dispatch target supports float16_t!
+ if (hwy::HaveFloat16()) {
+ TestAnySort<TraitsLane<OrderAscending<float16_t>>>(algos, num_lanes);
+ }
+#endif
+ TestAnySort<TraitsLane<OrderAscending<float>>>(algos, num_lanes);
+#if HWY_HAVE_FLOAT64 // #if protects algo-inl.h's GenerateRandom
+ // Must also check whether the dynamic-dispatch target supports float64!
+ if (hwy::HaveFloat64()) {
+ TestAnySort<TraitsLane<OtherOrder<double>>>(algos, num_lanes);
+ }
+#endif
+
+ // Other algorithms do not support 128-bit nor KV keys.
+#if !HAVE_VXSORT && !HAVE_INTEL
+ TestAnySort<TraitsLane<OrderAscendingKV64>>(algos, num_lanes);
+ TestAnySort<TraitsLane<OrderDescendingKV64>>(algos, num_lanes);
+
+// 128-bit keys require 128-bit SIMD.
+#if HWY_TARGET != HWY_SCALAR
+ TestAnySort<Traits128<OrderAscending128>>(algos, num_lanes);
+ TestAnySort<Traits128<OrderDescending128>>(algos, num_lanes);
+
+ TestAnySort<Traits128<OrderAscendingKV128>>(algos, num_lanes);
+ TestAnySort<Traits128<OrderDescendingKV128>>(algos, num_lanes);
+#endif // HWY_TARGET != HWY_SCALAR
+#endif // !HAVE_VXSORT && !HAVE_INTEL
+}
+
+void TestAllSort() {
+ const std::vector<Algo> algos{
+#if HAVE_AVX2SORT
+ Algo::kSEA,
+#endif
+#if HAVE_IPS4O
+ Algo::kIPS4O,
+#endif
+#if HAVE_PDQSORT
+ Algo::kPDQ,
+#endif
+#if HAVE_SORT512
+ Algo::kSort512,
+#endif
+ Algo::kVQSort, Algo::kHeapSort,
+ };
+
+ for (int num : {129, 504, 3 * 1000, 14567}) {
+ const size_t num_lanes = AdjustedReps(static_cast<size_t>(num));
+ CallAllSortTraits(algos, num_lanes);
+ }
+}
+
+void TestAllPartialSort() {
+ const std::vector<Algo> algos{Algo::kVQPartialSort, Algo::kHeapPartialSort};
+
+ for (int num : {129, 504, 3 * 1000, 14567}) {
+ const size_t num_lanes = AdjustedReps(static_cast<size_t>(num));
+ CallAllSortTraits(algos, num_lanes);
+ }
+}
+
+void TestAllSelect() {
+ const std::vector<Algo> algos{Algo::kVQSelect, Algo::kHeapSelect};
+
+ for (int num : {129, 504, 3 * 1000, 14567}) {
+ const size_t num_lanes = AdjustedReps(static_cast<size_t>(num));
+ CallAllSortTraits(algos, num_lanes);
+ }
+}
+
+} // namespace
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+} // namespace HWY_NAMESPACE
+} // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace hwy {
+namespace {
+HWY_BEFORE_TEST(SortTest);
+HWY_EXPORT_AND_TEST_P(SortTest, TestAllSortIota);
+HWY_EXPORT_AND_TEST_P(SortTest, TestAllSort);
+HWY_EXPORT_AND_TEST_P(SortTest, TestAllSelect);
+HWY_EXPORT_AND_TEST_P(SortTest, TestAllPartialSort);
+HWY_AFTER_TEST();
+} // namespace
+} // namespace hwy
+HWY_TEST_MAIN();
+#endif // HWY_ONCE
diff --git a/third_party/highway/hwy/contrib/sort/sort_unit_test.cc b/third_party/highway/hwy/contrib/sort/sort_unit_test.cc
new file mode 100644
index 0000000000..c5b39297d5
--- /dev/null
+++ b/third_party/highway/hwy/contrib/sort/sort_unit_test.cc
@@ -0,0 +1,567 @@
+// Copyright 2021 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdio.h>
+
+#include <unordered_map>
+#include <vector>
+
+#include "third_party/highway/hwy/aligned_allocator.h" // IsAligned
+#include "third_party/highway/hwy/base.h"
+#include "third_party/highway/hwy/contrib/sort/vqsort.h"
+#include "third_party/highway/hwy/detect_compiler_arch.h"
+
+// clang-format off
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "hwy/contrib/sort/sort_unit_test.cc" // NOLINT
+// clang-format on
+#include "third_party/highway/hwy/foreach_target.h" // IWYU pragma: keep
+#include "third_party/highway/hwy/highway.h"
+// After highway.h
+#include "third_party/highway/hwy/contrib/sort/algo-inl.h"
+#include "third_party/highway/hwy/contrib/sort/result-inl.h"
+#include "third_party/highway/hwy/contrib/sort/traits128-inl.h"
+#include "third_party/highway/hwy/contrib/sort/vqsort-inl.h" // BaseCase
+#include "third_party/highway/hwy/print-inl.h"
+#include "third_party/highway/hwy/tests/test_util-inl.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+namespace {
+
+using detail::OrderAscending;
+using detail::SharedTraits;
+using detail::TraitsLane;
+
+#if !HAVE_INTEL && HWY_TARGET != HWY_SCALAR
+using detail::OrderAscending128;
+using detail::OrderDescending128;
+using detail::Traits128;
+#endif // !HAVE_INTEL && HWY_TARGET != HWY_SCALAR
+
+#if VQSORT_ENABLED || HWY_IDE
+
+// Verify the corner cases of LargerSortValue/SmallerSortValue, used to
+// implement PrevValue/NextValue.
+struct TestFloatLargerSmaller {
+ template <typename T, class D>
+ HWY_NOINLINE void operator()(T, D d) {
+ const Vec<D> p0 = Zero(d);
+ const Vec<D> p1 = Set(d, ConvertScalarTo<T>(1));
+ const Vec<D> pinf = Inf(d);
+ const Vec<D> peps = Set(d, hwy::Epsilon<T>());
+ const Vec<D> pmax = Set(d, hwy::HighestValue<T>());
+
+ const Vec<D> n0 = Neg(p0);
+ const Vec<D> n1 = Neg(p1);
+ const Vec<D> ninf = Neg(pinf);
+ const Vec<D> neps = Neg(peps);
+ const Vec<D> nmax = Neg(pmax);
+
+ // Larger(0) is the smallest subnormal, typically eps * FLT_MIN.
+ const RebindToUnsigned<D> du;
+ const Vec<D> psub = BitCast(d, Set(du, 1));
+ const Vec<D> nsub = Neg(psub);
+ HWY_ASSERT(AllTrue(d, Lt(psub, peps)));
+ HWY_ASSERT(AllTrue(d, Gt(nsub, neps)));
+
+ // +/-0 moves to +/- smallest subnormal.
+ HWY_ASSERT_VEC_EQ(d, psub, detail::LargerSortValue(d, p0));
+ HWY_ASSERT_VEC_EQ(d, nsub, detail::SmallerSortValue(d, p0));
+ HWY_ASSERT_VEC_EQ(d, psub, detail::LargerSortValue(d, n0));
+ HWY_ASSERT_VEC_EQ(d, nsub, detail::SmallerSortValue(d, n0));
+
+ // The next magnitude larger than 1 is (1 + eps) by definition.
+ HWY_ASSERT_VEC_EQ(d, Add(p1, peps), detail::LargerSortValue(d, p1));
+ HWY_ASSERT_VEC_EQ(d, Add(n1, neps), detail::SmallerSortValue(d, n1));
+ // 1-eps and -1+eps are slightly different, but we can still ensure the
+ // next values are less than 1 / greater than -1.
+ HWY_ASSERT(AllTrue(d, Gt(p1, detail::SmallerSortValue(d, p1))));
+ HWY_ASSERT(AllTrue(d, Lt(n1, detail::LargerSortValue(d, n1))));
+
+ // Even for large (finite) values, we can move toward/away from infinity.
+ HWY_ASSERT_VEC_EQ(d, pinf, detail::LargerSortValue(d, pmax));
+ HWY_ASSERT_VEC_EQ(d, ninf, detail::SmallerSortValue(d, nmax));
+ HWY_ASSERT(AllTrue(d, Gt(pmax, detail::SmallerSortValue(d, pmax))));
+ HWY_ASSERT(AllTrue(d, Lt(nmax, detail::LargerSortValue(d, nmax))));
+
+ // For infinities, results are unchanged or the extremal finite value.
+ HWY_ASSERT_VEC_EQ(d, pinf, detail::LargerSortValue(d, pinf));
+ HWY_ASSERT_VEC_EQ(d, pmax, detail::SmallerSortValue(d, pinf));
+ HWY_ASSERT_VEC_EQ(d, nmax, detail::LargerSortValue(d, ninf));
+ HWY_ASSERT_VEC_EQ(d, ninf, detail::SmallerSortValue(d, ninf));
+ }
+};
+HWY_NOINLINE void TestAllFloatLargerSmaller() {
+ ForFloatTypesDynamic(ForPartialVectors<TestFloatLargerSmaller>());
+}
+
+// Previously, LastValue was the largest normal float, so we injected that
+// value into arrays containing only infinities. Ensure that does not happen.
+struct TestFloatInf {
+ template <typename T, class D>
+ HWY_NOINLINE void operator()(T, D d) {
+ const size_t N = Lanes(d);
+ const size_t num = N * 3;
+ auto in = hwy::AllocateAligned<T>(num);
+ HWY_ASSERT(in);
+ Fill(d, GetLane(Inf(d)), num, in.get());
+ VQSort(in.get(), num, SortAscending());
+ for (size_t i = 0; i < num; i += N) {
+ HWY_ASSERT(AllTrue(d, IsInf(LoadU(d, in.get() + i))));
+ }
+ }
+};
+
+HWY_NOINLINE void TestAllFloatInf() {
+ // TODO(janwas): bfloat16_t not yet supported.
+ ForFloatTypesDynamic(ForPartialVectors<TestFloatInf>());
+}
+
+template <class Traits>
+static HWY_NOINLINE void TestMedian3() {
+ using LaneType = typename Traits::LaneType;
+ using D = CappedTag<LaneType, 1>;
+ SharedTraits<Traits> st;
+ const D d;
+ using V = Vec<D>;
+ for (uint32_t bits = 0; bits < 8; ++bits) {
+ const V v0 = Set(d, LaneType{(bits & (1u << 0)) ? 1u : 0u});
+ const V v1 = Set(d, LaneType{(bits & (1u << 1)) ? 1u : 0u});
+ const V v2 = Set(d, LaneType{(bits & (1u << 2)) ? 1u : 0u});
+ const LaneType m = GetLane(detail::MedianOf3(st, v0, v1, v2));
+ // If at least half (rounded up) of the bits are 1, so is the median.
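+ // e.g. bits=0b011 sets v0=v1=1 and v2=0, so the median is 1.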
+ const size_t count = PopCount(bits);
+ HWY_ASSERT_EQ((count >= 2) ? static_cast<LaneType>(1) : 0, m);
+ }
+}
+
+HWY_NOINLINE void TestAllMedian() {
+ TestMedian3<TraitsLane<OrderAscending<uint64_t> > >();
+}
+
+template <class Traits>
+static HWY_NOINLINE void TestBaseCaseAscDesc() {
+ using LaneType = typename Traits::LaneType;
+ SharedTraits<Traits> st;
+ const SortTag<LaneType> d;
+ const size_t N = Lanes(d);
+ constexpr size_t N1 = st.LanesPerKey();
+ const size_t base_case_num = SortConstants::BaseCaseNumLanes<N1>(N);
+
+ constexpr int kDebug = 0;
+ auto aligned_lanes = hwy::AllocateAligned<LaneType>(N + base_case_num + N);
+ auto buf = hwy::AllocateAligned<LaneType>(base_case_num + 2 * N);
+ HWY_ASSERT(aligned_lanes && buf);
+
+ std::vector<size_t> lengths;
+ lengths.push_back(HWY_MAX(1, N1));
+ lengths.push_back(3 * N1);
+ lengths.push_back(base_case_num / 2);
+ lengths.push_back(base_case_num / 2 + N1);
+ lengths.push_back(base_case_num - N1);
+ lengths.push_back(base_case_num);
+
+ std::vector<size_t> misalignments;
+ misalignments.push_back(0);
+ misalignments.push_back(1);
+ if (N >= 6) misalignments.push_back(N / 2 - 1);
+ misalignments.push_back(N / 2);
+ misalignments.push_back(N / 2 + 1);
+ misalignments.push_back(HWY_MIN(2 * N / 3 + 3, size_t{N - 1}));
+
+ for (bool asc : {false, true}) {
+ for (size_t len : lengths) {
+ for (size_t misalign : misalignments) {
+ LaneType* HWY_RESTRICT lanes = aligned_lanes.get() + misalign;
+ if (kDebug) {
+ printf("============%s asc %d N1 %d len %d misalign %d\n",
+ st.KeyString(), asc, static_cast<int>(N1),
+ static_cast<int>(len), static_cast<int>(misalign));
+ }
+
+ for (size_t i = 0; i < misalign; ++i) {
+ aligned_lanes[i] = hwy::LowestValue<LaneType>();
+ }
+ InputStats<LaneType> input_stats;
+ for (size_t i = 0; i < len; ++i) {
+ lanes[i] = asc ? static_cast<LaneType>(LaneType(i) + 1)
+ : static_cast<LaneType>(LaneType(len) - LaneType(i));
+ input_stats.Notify(lanes[i]);
+ if (kDebug >= 2) {
+ printf("%3zu: %f\n", i, static_cast<double>(lanes[i]));
+ }
+ }
+ for (size_t i = len; i < base_case_num + N; ++i) {
+ lanes[i] = hwy::LowestValue<LaneType>();
+ }
+
+ detail::BaseCase(d, st, lanes, len, buf.get());
+
+ if (kDebug >= 2) {
+ printf("out>>>>>>\n");
+ for (size_t i = 0; i < len; ++i) {
+ printf("%3zu: %f\n", i, static_cast<double>(lanes[i]));
+ }
+ }
+
+ SortOrderVerifier<Traits>()(Algo::kVQSort, input_stats, lanes, len / N1,
+ len / N1);
+ for (size_t i = 0; i < misalign; ++i) {
+ if (aligned_lanes[i] != hwy::LowestValue<LaneType>())
+ HWY_ABORT("Overrun misalign at %d\n", static_cast<int>(i));
+ }
+ for (size_t i = len; i < base_case_num + N; ++i) {
+ if (lanes[i] != hwy::LowestValue<LaneType>())
+ HWY_ABORT("Overrun right at %d\n", static_cast<int>(i));
+ }
+ } // misalign
+ } // len
+ } // asc
+}
+
+template <class Traits>
+static HWY_NOINLINE void TestBaseCase01() {
+ using LaneType = typename Traits::LaneType;
+ SharedTraits<Traits> st;
+ const SortTag<LaneType> d;
+ const size_t N = Lanes(d);
+ constexpr size_t N1 = st.LanesPerKey();
+ const size_t base_case_num = SortConstants::BaseCaseNumLanes<N1>(N);
+
+ constexpr int kDebug = 0;
+ auto lanes = hwy::AllocateAligned<LaneType>(base_case_num + N);
+ auto buf = hwy::AllocateAligned<LaneType>(base_case_num + 2 * N);
+ HWY_ASSERT(lanes && buf);
+
+ std::vector<size_t> lengths;
+ lengths.push_back(HWY_MAX(1, N1));
+ lengths.push_back(3 * N1);
+ lengths.push_back(base_case_num / 2);
+ lengths.push_back(base_case_num / 2 + N1);
+ lengths.push_back(base_case_num - N1);
+ lengths.push_back(base_case_num);
+
+ for (size_t len : lengths) {
+ if (kDebug) {
+ printf("============%s 01 N1 %d len %d\n", st.KeyString(),
+ static_cast<int>(N1), static_cast<int>(len));
+ }
+ const uint64_t kMaxBits = AdjustedLog2Reps(HWY_MIN(len, size_t{14}));
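+ // Covers (almost) all 0/1 patterns of the first kMaxBits lanes; at most
+ // 2^14 combinations.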
+ for (uint64_t bits = 0; bits < ((1ull << kMaxBits) - 1); ++bits) {
+ InputStats<LaneType> input_stats;
+ for (size_t i = 0; i < len; ++i) {
+ lanes[i] = (i < 64 && (bits & (1ull << i))) ? 1 : 0;
+ input_stats.Notify(lanes[i]);
+ if (kDebug >= 2) {
+ printf("%3zu: %f\n", i, static_cast<double>(lanes[i]));
+ }
+ }
+ for (size_t i = len; i < base_case_num + N; ++i) {
+ lanes[i] = hwy::LowestValue<LaneType>();
+ }
+
+ detail::BaseCase(d, st, lanes.get(), len, buf.get());
+
+ if (kDebug >= 2) {
+ printf("out>>>>>>\n");
+ for (size_t i = 0; i < len; ++i) {
+ printf("%3zu: %f\n", i, static_cast<double>(lanes[i]));
+ }
+ }
+
+ SortOrderVerifier<Traits>()(Algo::kVQSort, input_stats, lanes.get(),
+ len / N1, len / N1);
+ for (size_t i = len; i < base_case_num + N; ++i) {
+ if (lanes[i] != hwy::LowestValue<LaneType>())
+ HWY_ABORT("Overrun right at %d\n", static_cast<int>(i));
+ }
+ } // bits
+ } // len
+}
+
+template <class Traits>
+static HWY_NOINLINE void TestBaseCase() {
+ TestBaseCaseAscDesc<Traits>();
+ TestBaseCase01<Traits>();
+}
+
+HWY_NOINLINE void TestAllBaseCase() {
+ // Workaround for stack overflow on MSVC debug.
+#if defined(_MSC_VER)
+ return;
+#endif
+
+ TestBaseCase<TraitsLane<OrderAscending<int32_t> > >();
+ TestBaseCase<TraitsLane<OtherOrder<int64_t> > >();
+#if !HAVE_INTEL
+ TestBaseCase<Traits128<OrderAscending128> >();
+ TestBaseCase<Traits128<OrderDescending128> >();
+#endif
+}
+
+template <class Traits>
+static HWY_NOINLINE void VerifyPartition(
+ Traits st, typename Traits::LaneType* HWY_RESTRICT lanes, size_t left,
+ size_t border, size_t right, const size_t N1,
+ const typename Traits::LaneType* pivot) {
+ /* for (size_t i = left; i < right; ++i) {
+ if (i == border) printf("--\n");
+ printf("%4zu: %3d\n", i, lanes[i]);
+ }*/
+
+ HWY_ASSERT(left % N1 == 0);
+ HWY_ASSERT(border % N1 == 0);
+ HWY_ASSERT(right % N1 == 0);
+ constexpr bool kAscending = Traits::Order::IsAscending();
+ for (size_t i = left; i < border; i += N1) {
+ if (st.Compare1(pivot, lanes + i)) {
+ HWY_ABORT(
+ "%s: asc %d left[%d] piv %.0f %.0f compares before %.0f %.0f "
+ "border %d",
+ st.KeyString(), kAscending, static_cast<int>(i),
+ static_cast<double>(pivot[1]), static_cast<double>(pivot[0]),
+ static_cast<double>(lanes[i + 1]), static_cast<double>(lanes[i + 0]),
+ static_cast<int>(border));
+ }
+ }
+ for (size_t i = border; i < right; i += N1) {
+ if (!st.Compare1(pivot, lanes + i)) {
+ HWY_ABORT(
+ "%s: asc %d right[%d] piv %.0f %.0f compares after %.0f %.0f "
+ "border %d",
+ st.KeyString(), kAscending, static_cast<int>(i),
+ static_cast<double>(pivot[1]), static_cast<double>(pivot[0]),
+ static_cast<double>(lanes[i + 1]), static_cast<double>(lanes[i]),
+ static_cast<int>(border));
+ }
+ }
+}
+
+template <class Traits>
+static HWY_NOINLINE void TestPartition() {
+ using LaneType = typename Traits::LaneType;
+ // See HandleSpecialCases and HWY_ASSERT below.
+ const CappedTag<LaneType, 64 / sizeof(LaneType)> d;
+ SharedTraits<Traits> st;
+ constexpr bool kAscending = Traits::Order::IsAscending();
+ const size_t N = Lanes(d);
+ constexpr int kDebug = 0;
+ constexpr size_t N1 = st.LanesPerKey();
+ const size_t base_case_num = SortConstants::BaseCaseNumLanes<N1>(N);
+ HWY_ASSERT(2 * N <= base_case_num); // See HandleSpecialCases
+
+ // left + len + align
+ const size_t total = 32 + (base_case_num + 4 * HWY_MAX(N, 4)) + 2 * N;
+ auto aligned_lanes = hwy::AllocateAligned<LaneType>(total);
+ HWY_ASSERT(aligned_lanes);
+ HWY_ALIGN LaneType buf[SortConstants::BufBytes<LaneType, N1>(HWY_MAX_BYTES) /
+ sizeof(LaneType)];
+
+ for (bool in_asc : {false, true}) {
+ for (int left_i : {0, 1, 7, 8, 30, 31}) {
+ const size_t left = static_cast<size_t>(left_i) & ~(N1 - 1);
+ for (size_t ofs :
+ {N, N + 3, 2 * N, 2 * N + 2, 2 * N + 3, 3 * N - 1, 4 * N - 2}) {
+ const size_t len = (base_case_num + ofs) & ~(N1 - 1);
+ for (LaneType pivot1 : {LaneType(0), LaneType(len / 3),
+ LaneType(2 * len / 3), LaneType(len)}) {
+ const LaneType pivot2[2] = {pivot1, 0};
+ const auto pivot = st.SetKey(d, pivot2);
+ for (size_t misalign = 0; misalign < N; misalign += N1) {
+ LaneType* HWY_RESTRICT lanes = aligned_lanes.get() + misalign;
+ const size_t right = left + len;
+ if (kDebug) {
+ printf(
+ "=========%s asc %d left %d len %d right %d piv %.0f %.0f\n",
+ st.KeyString(), kAscending, static_cast<int>(left),
+ static_cast<int>(len), static_cast<int>(right),
+ static_cast<double>(pivot2[1]),
+ static_cast<double>(pivot2[0]));
+ }
+
+ for (size_t i = 0; i < misalign; ++i) {
+ aligned_lanes[i] = hwy::LowestValue<LaneType>();
+ }
+ for (size_t i = 0; i < left; ++i) {
+ lanes[i] = hwy::LowestValue<LaneType>();
+ }
+ std::unordered_map<LaneType, int> counts;
+ for (size_t i = left; i < right; ++i) {
+ lanes[i] = static_cast<LaneType>(
+ in_asc ? LaneType(i + 1) - static_cast<LaneType>(left)
+ : static_cast<LaneType>(right) - LaneType(i));
+ ++counts[lanes[i]];
+ if (kDebug >= 2) {
+ printf("%3zu: %f\n", i, static_cast<double>(lanes[i]));
+ }
+ }
+ for (size_t i = right; i < total - misalign; ++i) {
+ lanes[i] = hwy::LowestValue<LaneType>();
+ }
+
+ size_t border = left + detail::Partition(d, st, lanes + left,
+ right - left, pivot, buf);
+
+ if (kDebug >= 2) {
+ printf("out>>>>>>\n");
+ for (size_t i = left; i < right; ++i) {
+ printf("%3zu: %f\n", i, static_cast<double>(lanes[i]));
+ }
+ for (size_t i = right; i < total - misalign; ++i) {
+ printf("%3zu: sentinel %f\n", i, static_cast<double>(lanes[i]));
+ }
+ }
+ for (size_t i = left; i < right; ++i) {
+ --counts[lanes[i]];
+ }
+ for (auto kv : counts) {
+ if (kv.second != 0) {
+ PrintValue(kv.first);
+ HWY_ABORT("Incorrect count %d\n", kv.second);
+ }
+ }
+ VerifyPartition(st, lanes, left, border, right, N1, pivot2);
+ for (size_t i = 0; i < misalign; ++i) {
+ if (aligned_lanes[i] != hwy::LowestValue<LaneType>())
+ HWY_ABORT("Overrun misalign at %d\n", static_cast<int>(i));
+ }
+ for (size_t i = 0; i < left; ++i) {
+ if (lanes[i] != hwy::LowestValue<LaneType>())
+ HWY_ABORT("Overrun left at %d\n", static_cast<int>(i));
+ }
+ for (size_t i = right; i < total - misalign; ++i) {
+ if (lanes[i] != hwy::LowestValue<LaneType>())
+ HWY_ABORT("Overrun right at %d\n", static_cast<int>(i));
+ }
+ } // misalign
+ } // pivot
+ } // len
+ } // left
+ } // asc
+}
+
+#undef HWY_BROKEN_U128
+#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1400 && \
+ HWY_TARGET == HWY_RVV
+#define HWY_BROKEN_U128 1
+#else
+#define HWY_BROKEN_U128 0
+#endif
+
+HWY_NOINLINE void TestAllPartition() {
+ TestPartition<TraitsLane<OtherOrder<int32_t> > >();
+
+#if !HAVE_INTEL && !HWY_BROKEN_U128
+ TestPartition<Traits128<OrderAscending128> >();
+#endif
+
+#if !HWY_IS_DEBUG_BUILD
+ TestPartition<TraitsLane<OrderAscending<int16_t> > >();
+ TestPartition<TraitsLane<OrderAscending<int64_t> > >();
+ TestPartition<TraitsLane<OtherOrder<float> > >();
+ // OK to check current target, not using dynamic dispatch here.
+#if HWY_HAVE_FLOAT64
+ TestPartition<TraitsLane<OtherOrder<double> > >();
+#endif
+#if !HAVE_INTEL && !HWY_BROKEN_U128
+ TestPartition<Traits128<OrderDescending128> >();
+#endif
+#endif
+}
+
+// (used to select the samples from which the pivot is chosen)
+template <typename TU>
+static HWY_NOINLINE void TestRandomGenerator() {
+ static_assert(!hwy::IsSigned<TU>(), "");
+ SortTag<TU> du;
+ const size_t N = Lanes(du);
+
+ uint64_t* state = GetGeneratorState();
+
+ // Ensure lower and upper 32 bits are uniformly distributed.
+ uint64_t sum_lo = 0, sum_hi = 0;
+ for (size_t i = 0; i < 1000; ++i) {
+ const uint64_t bits = RandomBits(state);
+ sum_lo += bits & 0xFFFFFFFF;
+ sum_hi += bits >> 32;
+ }
+ {
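+ // The mean of a uniform 32-bit value is (2^32 - 1) / 2, i.e. about 2^31.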
+ const double expected = 1000 * (1ULL << 31);
+ HWY_ASSERT(0.9 * expected <= static_cast<double>(sum_lo) &&
+ static_cast<double>(sum_lo) <= 1.1 * expected);
+ HWY_ASSERT(0.9 * expected <= static_cast<double>(sum_hi) &&
+ static_cast<double>(sum_hi) <= 1.1 * expected);
+ }
+
+ const size_t lanes_per_block = HWY_MAX(64 / sizeof(TU), N); // power of two
+
+ for (uint32_t num_blocks = 2; num_blocks < 100000;
+ num_blocks = 3 * num_blocks / 2) {
+ // Generate some numbers and ensure all are in range
+ uint64_t sum = 0;
+ constexpr size_t kReps = 10000;
+ for (size_t rep = 0; rep < kReps; ++rep) {
+ const uint32_t bits = RandomBits(state) & 0xFFFFFFFF;
+ const size_t index = detail::RandomChunkIndex(num_blocks, bits);
+ HWY_ASSERT(((index + 1) * lanes_per_block) <=
+ num_blocks * lanes_per_block);
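+ // Equivalent to index < num_blocks: the chunk lies within the array.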
+
+ sum += index;
+ }
+
+ // Also ensure the mean is near the middle of the range
+ const double expected = (num_blocks - 1) / 2.0;
+ const double actual = static_cast<double>(sum) / kReps;
+ HWY_ASSERT(0.9 * expected <= actual && actual <= 1.1 * expected);
+ }
+}
+
+HWY_NOINLINE void TestAllGenerator() {
+ TestRandomGenerator<uint32_t>();
+ TestRandomGenerator<uint64_t>();
+}
+
+#else
+static void TestAllFloatLargerSmaller() {}
+static void TestAllFloatInf() {}
+static void TestAllMedian() {}
+static void TestAllBaseCase() {}
+static void TestAllPartition() {}
+static void TestAllGenerator() {}
+#endif // VQSORT_ENABLED
+
+} // namespace
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+} // namespace HWY_NAMESPACE
+} // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace hwy {
+namespace {
+HWY_BEFORE_TEST(SortTest);
+HWY_EXPORT_AND_TEST_P(SortTest, TestAllFloatLargerSmaller);
+HWY_EXPORT_AND_TEST_P(SortTest, TestAllFloatInf);
+HWY_EXPORT_AND_TEST_P(SortTest, TestAllMedian);
+HWY_EXPORT_AND_TEST_P(SortTest, TestAllBaseCase);
+HWY_EXPORT_AND_TEST_P(SortTest, TestAllPartition);
+HWY_EXPORT_AND_TEST_P(SortTest, TestAllGenerator);
+HWY_AFTER_TEST();
+} // namespace
+} // namespace hwy
+HWY_TEST_MAIN();
+#endif // HWY_ONCE
diff --git a/third_party/highway/hwy/contrib/sort/sorting_networks-inl.h b/third_party/highway/hwy/contrib/sort/sorting_networks-inl.h
index 2158e7ea99..8541d3a6b3 100644
--- a/third_party/highway/hwy/contrib/sort/sorting_networks-inl.h
+++ b/third_party/highway/hwy/contrib/sort/sorting_networks-inl.h
@@ -197,6 +197,7 @@ HWY_INLINE void Sort16(D d, Traits st, V& v0, V& v1, V& v2, V& v3, V& v4, V& v5,
// st.SortPairsDistance1 to compile. `if constexpr` in the caller would also
// work, but is not available in C++11. We write out the (unused) argument types
// rather than `...` because GCC 9 (but not 10) fails to compile with `...`.
+// TODO: use C++17.
template <size_t kKeysPerVector, class D, class Traits, class V,
HWY_IF_LANES_LE(kKeysPerVector, 1)>
@@ -891,6 +892,16 @@ HWY_NOINLINE void SortingNetwork(Traits st, T* HWY_RESTRICT buf, size_t cols) {
#else
template <class Base>
struct SharedTraits : public Base {};
+
+namespace detail {
+
+// Empty function to avoid a possible -Wpragma-clang-attribute warning when
+// compiling with Clang.
+static HWY_INLINE HWY_MAYBE_UNUSED void HWY_CONCAT(UnusedSortingNetworksFunc,
+ __LINE__)() {}
+
+} // namespace detail
+
#endif // VQSORT_ENABLED
} // namespace detail
diff --git a/third_party/highway/hwy/contrib/sort/traits-inl.h b/third_party/highway/hwy/contrib/sort/traits-inl.h
index efa410c81d..82251a8473 100644
--- a/third_party/highway/hwy/contrib/sort/traits-inl.h
+++ b/third_party/highway/hwy/contrib/sort/traits-inl.h
@@ -89,7 +89,6 @@ Vec<D> LargerSortValue(D d, Vec<D> v) {
using T = TFromD<decltype(d)>;
const RebindToUnsigned<D> du;
using VU = Vec<decltype(du)>;
- using TU = TFromD<decltype(du)>;
const VU vu = BitCast(du, Abs(v));
@@ -97,7 +96,14 @@ Vec<D> LargerSortValue(D d, Vec<D> v) {
// than float comparison and treats -0 as 0 (so we return +epsilon).
const Mask<decltype(du)> was_pos = Le(BitCast(du, v), SignBit(du));
// If positive, add 1, else -1.
+#if HWY_ARCH_ARM_V7
+ // Workaround for incorrect codegen: ~x - x is equivalent to x ? 1 : -1.
+ const VU was_pos_u = VecFromMask(du, was_pos);
+ const VU add = Not(was_pos_u) - was_pos_u;
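+ // (For an all-ones mask m, Not(m) - m wraps around to 1; for m == 0, it
+ // yields the all-ones pattern, i.e. the unsigned maximum.)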
+#else
+ using TU = TFromD<decltype(du)>;
const VU add = IfThenElse(was_pos, Set(du, 1u), Set(du, LimitsMax<TU>()));
+#endif
// Prev/next integer is the prev/next value, even if mantissa under/overflows.
v = BitCast(d, Add(vu, add));
// But we may have overflowed into inf or NaN; replace with inf if positive,
diff --git a/third_party/highway/hwy/contrib/sort/vqsort-inl.h b/third_party/highway/hwy/contrib/sort/vqsort-inl.h
index 5eaf4d56f8..0bb03aaeb3 100644
--- a/third_party/highway/hwy/contrib/sort/vqsort-inl.h
+++ b/third_party/highway/hwy/contrib/sort/vqsort-inl.h
@@ -691,9 +691,11 @@ HWY_INLINE size_t PartitionRightmost(D d, Traits st, T* const keys,
}
const size_t numWrittenR = bufR - max_buf;
- // MSan seems not to understand CompressStore.
+// Prior to 2022-10, Clang MSAN did not understand AVX-512 CompressStore.
+#if HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1600
detail::MaybeUnpoison(buf, bufL);
detail::MaybeUnpoison(buf + max_buf, numWrittenR);
+#endif
// Overwrite already-read end of keys with bufR.
writeR = num - numWrittenR;
diff --git a/third_party/highway/hwy/contrib/sort/vqsort.cc b/third_party/highway/hwy/contrib/sort/vqsort.cc
new file mode 100644
index 0000000000..42f7f154ab
--- /dev/null
+++ b/third_party/highway/hwy/contrib/sort/vqsort.cc
@@ -0,0 +1,121 @@
+// Copyright 2021 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "third_party/highway/hwy/contrib/sort/vqsort.h"
+
+#include "third_party/highway/hwy/base.h"
+#include "third_party/highway/hwy/contrib/sort/vqsort-inl.h"
+#include "third_party/highway/hwy/per_target.h"
+
+// Check if we have getrandom from <sys/random.h>. Because <features.h> is
+// unavailable on Android and non-Linux RVV, we assume that those systems lack
+// getrandom. Note that the only supported entropy sources are getrandom and
+// the Windows crypto API; thus VQSORT_SECURE_SEED is 0 when this is 0 and we
+// are not on Windows.
+#if defined(ANDROID) || defined(__ANDROID__) || \
+ (HWY_ARCH_RISCV && !HWY_OS_LINUX)
+#define VQSORT_GETRANDOM 0
+#endif
+
+#if !defined(VQSORT_GETRANDOM) && HWY_OS_LINUX
+#include <features.h>
+
+// ---- which libc
+#if defined(__UCLIBC__)
+#define VQSORT_GETRANDOM 1 // added Mar 2015, before uclibc-ng 1.0
+
+#elif defined(__GLIBC__) && defined(__GLIBC_PREREQ)
+#if __GLIBC_PREREQ(2, 25)
+#define VQSORT_GETRANDOM 1
+#else
+#define VQSORT_GETRANDOM 0
+#endif
+
+#else
+// Assume MUSL, which has getrandom since 2018. There is no macro to test, see
+// https://www.openwall.com/lists/musl/2013/03/29/13.
+#define VQSORT_GETRANDOM 1
+
+#endif // ---- which libc
+#endif // linux
+
+#if !defined(VQSORT_GETRANDOM)
+#define VQSORT_GETRANDOM 0
+#endif
+
+// Choose a seed source for the SFC generator: 0=none, 1=getrandom,
+// 2=CryptGenRandom. Allow user override - not all Android systems support
+// the getrandom wrapper.
+#ifndef VQSORT_SECURE_SEED
+
+#if VQSORT_GETRANDOM
+#define VQSORT_SECURE_SEED 1
+#elif defined(_WIN32) || defined(_WIN64)
+#define VQSORT_SECURE_SEED 2
+#else
+#define VQSORT_SECURE_SEED 0
+#endif
+
+#endif // VQSORT_SECURE_SEED
+
+// Pull in dependencies of the chosen seed source.
+#if VQSORT_SECURE_SEED == 1
+#include <sys/random.h>
+#elif VQSORT_SECURE_SEED == 2
+#ifndef NOMINMAX
+#define NOMINMAX
+#endif // NOMINMAX
+#ifndef WIN32_LEAN_AND_MEAN
+#define WIN32_LEAN_AND_MEAN
+#endif // WIN32_LEAN_AND_MEAN
+#include <windows.h>
+#if HWY_COMPILER_MSVC || HWY_COMPILER_CLANGCL
+#pragma comment(lib, "advapi32.lib")
+#endif // HWY_COMPILER_MSVC || HWY_COMPILER_CLANGCL
+// Must come after windows.h.
+#include <wincrypt.h>
+#endif // VQSORT_SECURE_SEED
+
+namespace hwy {
+
+// Returns false, or performs the equivalent of `memcpy(bytes, r, 16)`, where
+// r is 16 bytes of high-quality (unpredictable, uniformly distributed)
+// random bits, and returns true.
+bool Fill16BytesSecure(void* bytes) {
+#if VQSORT_SECURE_SEED == 1
+ // May block if urandom is not yet initialized.
+ const ssize_t ret = getrandom(bytes, 16, /*flags=*/0);
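+ // Any result other than 16 (error or short read) is treated as failure.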
+ if (ret == 16) return true;
+#elif VQSORT_SECURE_SEED == 2
+ HCRYPTPROV hProvider{};
+ if (CryptAcquireContextA(&hProvider, nullptr, nullptr, PROV_RSA_FULL,
+ CRYPT_VERIFYCONTEXT)) {
+ const BOOL ok =
+ CryptGenRandom(hProvider, 16, reinterpret_cast<BYTE*>(bytes));
+ CryptReleaseContext(hProvider, 0);
+ if (ok) return true;
+ }
+#else
+ (void)bytes;
+#endif
+
+ return false;
+}
+
+// Unused, kept only for ABI compatibility.
+void Sorter::Fill24Bytes(const void*, size_t, void*) {}
+bool Sorter::HaveFloat64() { return hwy::HaveFloat64(); }
+Sorter::Sorter() {}
+void Sorter::Delete() {}
+uint64_t* GetGeneratorState() { return hwy::detail::GetGeneratorStateStatic(); }
+
+} // namespace hwy
diff --git a/third_party/highway/hwy/contrib/sort/vqsort_128a.cc b/third_party/highway/hwy/contrib/sort/vqsort_128a.cc
new file mode 100644
index 0000000000..c807de0bda
--- /dev/null
+++ b/third_party/highway/hwy/contrib/sort/vqsort_128a.cc
@@ -0,0 +1,98 @@
+// Copyright 2021 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "third_party/highway/hwy/contrib/sort/vqsort.h" // VQSort
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_128a.cc"
+#include "third_party/highway/hwy/foreach_target.h" // IWYU pragma: keep
+
+// After foreach_target
+#include "third_party/highway/hwy/contrib/sort/vqsort-inl.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+namespace {
+
+void Sort128Asc(uint128_t* HWY_RESTRICT keys, const size_t num) {
+// 128-bit keys require 128-bit SIMD.
+#if HWY_TARGET != HWY_SCALAR
+ return VQSortStatic(keys, num, SortAscending());
+#else
+ (void)keys;
+ (void)num;
+#endif
+}
+
+void PartialSort128Asc(uint128_t* HWY_RESTRICT keys, const size_t num,
+ const size_t k) {
+// 128-bit keys require 128-bit SIMD.
+#if HWY_TARGET != HWY_SCALAR
+ return VQPartialSortStatic(keys, num, k, SortAscending());
+#else
+ (void)keys;
+ (void)num;
+ (void)k;
+#endif
+}
+
+void Select128Asc(uint128_t* HWY_RESTRICT keys, const size_t num,
+ const size_t k) {
+// 128-bit keys require 128-bit SIMD.
+#if HWY_TARGET != HWY_SCALAR
+ return VQSelectStatic(keys, num, k, SortAscending());
+#else
+ (void)keys;
+ (void)num;
+ (void)k;
+#endif
+}
+
+} // namespace
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+} // namespace HWY_NAMESPACE
+} // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace hwy {
+namespace {
+HWY_EXPORT(Sort128Asc);
+HWY_EXPORT(PartialSort128Asc);
+HWY_EXPORT(Select128Asc);
+} // namespace
+
+void VQSort(uint128_t* HWY_RESTRICT keys, const size_t n, SortAscending) {
+ HWY_DYNAMIC_DISPATCH(Sort128Asc)(keys, n);
+}
+
+void VQPartialSort(uint128_t* HWY_RESTRICT keys, const size_t n, const size_t k,
+ SortAscending) {
+ HWY_DYNAMIC_DISPATCH(PartialSort128Asc)(keys, n, k);
+}
+
+void VQSelect(uint128_t* HWY_RESTRICT keys, const size_t n, const size_t k,
+ SortAscending) {
+ HWY_DYNAMIC_DISPATCH(Select128Asc)(keys, n, k);
+}
+
+void Sorter::operator()(uint128_t* HWY_RESTRICT keys, size_t n,
+ SortAscending tag) const {
+ VQSort(keys, n, tag);
+}
+
+} // namespace hwy
+#endif // HWY_ONCE
diff --git a/third_party/highway/hwy/contrib/sort/vqsort_128d.cc b/third_party/highway/hwy/contrib/sort/vqsort_128d.cc
new file mode 100644
index 0000000000..42a367a459
--- /dev/null
+++ b/third_party/highway/hwy/contrib/sort/vqsort_128d.cc
@@ -0,0 +1,98 @@
+// Copyright 2021 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "third_party/highway/hwy/contrib/sort/vqsort.h" // VQSort
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_128d.cc"
+#include "third_party/highway/hwy/foreach_target.h" // IWYU pragma: keep
+
+// After foreach_target
+#include "third_party/highway/hwy/contrib/sort/vqsort-inl.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+namespace {
+
+void Sort128Desc(uint128_t* HWY_RESTRICT keys, const size_t num) {
+// 128-bit keys require 128-bit SIMD.
+#if HWY_TARGET != HWY_SCALAR
+ return VQSortStatic(keys, num, SortDescending());
+#else
+ (void)keys;
+ (void)num;
+#endif
+}
+
+void PartialSort128Desc(uint128_t* HWY_RESTRICT keys, const size_t num,
+ const size_t k) {
+// 128-bit keys require 128-bit SIMD.
+#if HWY_TARGET != HWY_SCALAR
+ return VQPartialSortStatic(keys, num, k, SortDescending());
+#else
+ (void)keys;
+ (void)num;
+ (void)k;
+#endif
+}
+
+void Select128Desc(uint128_t* HWY_RESTRICT keys, const size_t num,
+ const size_t k) {
+// 128-bit keys require 128-bit SIMD.
+#if HWY_TARGET != HWY_SCALAR
+ return VQSelectStatic(keys, num, k, SortDescending());
+#else
+ (void)keys;
+ (void)num;
+ (void)k;
+#endif
+}
+
+} // namespace
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+} // namespace HWY_NAMESPACE
+} // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace hwy {
+namespace {
+HWY_EXPORT(Sort128Desc);
+HWY_EXPORT(PartialSort128Desc);
+HWY_EXPORT(Select128Desc);
+} // namespace
+
+void VQSort(uint128_t* HWY_RESTRICT keys, const size_t n, SortDescending) {
+ HWY_DYNAMIC_DISPATCH(Sort128Desc)(keys, n);
+}
+
+void VQPartialSort(uint128_t* HWY_RESTRICT keys, const size_t n, const size_t k,
+ SortDescending) {
+ HWY_DYNAMIC_DISPATCH(PartialSort128Desc)(keys, n, k);
+}
+
+void VQSelect(uint128_t* HWY_RESTRICT keys, const size_t n, const size_t k,
+ SortDescending) {
+ HWY_DYNAMIC_DISPATCH(Select128Desc)(keys, n, k);
+}
+
+void Sorter::operator()(uint128_t* HWY_RESTRICT keys, size_t n,
+ SortDescending tag) const {
+ VQSort(keys, n, tag);
+}
+
+} // namespace hwy
+#endif // HWY_ONCE
diff --git a/third_party/highway/hwy/contrib/sort/vqsort_f16a.cc b/third_party/highway/hwy/contrib/sort/vqsort_f16a.cc
new file mode 100644
index 0000000000..841aaf4411
--- /dev/null
+++ b/third_party/highway/hwy/contrib/sort/vqsort_f16a.cc
@@ -0,0 +1,99 @@
+// Copyright 2021 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "third_party/highway/hwy/contrib/sort/vqsort.h" // VQSort
+#include "third_party/highway/hwy/nanobenchmark.h" // Unpredictable1
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_f16a.cc"
+#include "third_party/highway/hwy/foreach_target.h" // IWYU pragma: keep
+
+// After foreach_target
+#include "third_party/highway/hwy/contrib/sort/vqsort-inl.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+namespace {
+
+void SortF16Asc(float16_t* HWY_RESTRICT keys, const size_t num) {
+#if HWY_HAVE_FLOAT16
+ return VQSortStatic(keys, num, SortAscending());
+#else
+ (void)keys;
+ (void)num;
+ if (Unpredictable1()) HWY_ASSERT(0);
+#endif
+}
+
+void PartialSortF16Asc(float16_t* HWY_RESTRICT keys, const size_t num,
+ const size_t k) {
+#if HWY_HAVE_FLOAT16
+ return VQPartialSortStatic(keys, num, k, SortAscending());
+#else
+ (void)keys;
+ (void)num;
+ (void)k;
+ if (Unpredictable1()) HWY_ASSERT(0);
+#endif
+}
+
+void SelectF16Asc(float16_t* HWY_RESTRICT keys, const size_t num,
+ const size_t k) {
+#if HWY_HAVE_FLOAT16
+ return VQSelectStatic(keys, num, k, SortAscending());
+#else
+ (void)keys;
+ (void)num;
+ (void)k;
+ if (Unpredictable1()) HWY_ASSERT(0);
+#endif
+}
+
+} // namespace
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+} // namespace HWY_NAMESPACE
+} // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace hwy {
+namespace {
+HWY_EXPORT(SortF16Asc);
+HWY_EXPORT(PartialSortF16Asc);
+HWY_EXPORT(SelectF16Asc);
+} // namespace
+
+void VQSort(float16_t* HWY_RESTRICT keys, const size_t n, SortAscending) {
+ HWY_DYNAMIC_DISPATCH(SortF16Asc)(keys, n);
+}
+
+void VQPartialSort(float16_t* HWY_RESTRICT keys, const size_t n, const size_t k,
+ SortAscending) {
+ HWY_DYNAMIC_DISPATCH(PartialSortF16Asc)(keys, n, k);
+}
+
+void VQSelect(float16_t* HWY_RESTRICT keys, const size_t n, const size_t k,
+ SortAscending) {
+ HWY_DYNAMIC_DISPATCH(SelectF16Asc)(keys, n, k);
+}
+
+void Sorter::operator()(float16_t* HWY_RESTRICT keys, size_t n,
+ SortAscending tag) const {
+ VQSort(keys, n, tag);
+}
+
+} // namespace hwy
+#endif // HWY_ONCE
diff --git a/third_party/highway/hwy/contrib/sort/vqsort_f16d.cc b/third_party/highway/hwy/contrib/sort/vqsort_f16d.cc
new file mode 100644
index 0000000000..d744efcc94
--- /dev/null
+++ b/third_party/highway/hwy/contrib/sort/vqsort_f16d.cc
@@ -0,0 +1,99 @@
+// Copyright 2021 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "third_party/highway/hwy/contrib/sort/vqsort.h" // VQSort
+#include "third_party/highway/hwy/nanobenchmark.h" //
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_f16d.cc"
+#include "third_party/highway/hwy/foreach_target.h" // IWYU pragma: keep
+
+// After foreach_target
+#include "third_party/highway/hwy/contrib/sort/vqsort-inl.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+namespace {
+
+void SortF16Desc(float16_t* HWY_RESTRICT keys, const size_t num) {
+#if HWY_HAVE_FLOAT16
+ return VQSortStatic(keys, num, SortDescending());
+#else
+ (void)keys;
+ (void)num;
+ if (Unpredictable1()) HWY_ASSERT(0);
+#endif
+}
+
+void PartialSortF16Desc(float16_t* HWY_RESTRICT keys, const size_t num,
+ const size_t k) {
+#if HWY_HAVE_FLOAT16
+ return VQPartialSortStatic(keys, num, k, SortDescending());
+#else
+ (void)keys;
+ (void)num;
+ (void)k;
+ if (Unpredictable1()) HWY_ASSERT(0);
+#endif
+}
+
+void SelectF16Desc(float16_t* HWY_RESTRICT keys, const size_t num,
+ const size_t k) {
+#if HWY_HAVE_FLOAT16
+ return VQSelectStatic(keys, num, k, SortDescending());
+#else
+ (void)keys;
+ (void)num;
+ (void)k;
+ if (Unpredictable1()) HWY_ASSERT(0);
+#endif
+}
+
+} // namespace
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+} // namespace HWY_NAMESPACE
+} // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace hwy {
+namespace {
+HWY_EXPORT(SortF16Desc);
+HWY_EXPORT(PartialSortF16Desc);
+HWY_EXPORT(SelectF16Desc);
+} // namespace
+
+void VQSort(float16_t* HWY_RESTRICT keys, const size_t n, SortDescending) {
+ HWY_DYNAMIC_DISPATCH(SortF16Desc)(keys, n);
+}
+
+void VQPartialSort(float16_t* HWY_RESTRICT keys, const size_t n, const size_t k,
+ SortDescending) {
+ HWY_DYNAMIC_DISPATCH(PartialSortF16Desc)(keys, n, k);
+}
+
+void VQSelect(float16_t* HWY_RESTRICT keys, const size_t n, const size_t k,
+ SortDescending) {
+ HWY_DYNAMIC_DISPATCH(SelectF16Desc)(keys, n, k);
+}
+
+void Sorter::operator()(float16_t* HWY_RESTRICT keys, size_t n,
+ SortDescending tag) const {
+ VQSort(keys, n, tag);
+}
+
+} // namespace hwy
+#endif // HWY_ONCE
diff --git a/third_party/highway/hwy/contrib/sort/vqsort_f32a.cc b/third_party/highway/hwy/contrib/sort/vqsort_f32a.cc
new file mode 100644
index 0000000000..1d82a2503d
--- /dev/null
+++ b/third_party/highway/hwy/contrib/sort/vqsort_f32a.cc
@@ -0,0 +1,77 @@
+// Copyright 2021 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "third_party/highway/hwy/contrib/sort/vqsort.h" // VQSort
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_f32a.cc"
+#include "third_party/highway/hwy/foreach_target.h" // IWYU pragma: keep
+
+// After foreach_target
+#include "third_party/highway/hwy/contrib/sort/vqsort-inl.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+namespace {
+
+void SortF32Asc(float* HWY_RESTRICT keys, const size_t num) {
+ return VQSortStatic(keys, num, SortAscending());
+}
+
+void PartialSortF32Asc(float* HWY_RESTRICT keys, const size_t num,
+ const size_t k) {
+ return VQPartialSortStatic(keys, num, k, SortAscending());
+}
+
+void SelectF32Asc(float* HWY_RESTRICT keys, const size_t num, const size_t k) {
+ return VQSelectStatic(keys, num, k, SortAscending());
+}
+
+} // namespace
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+} // namespace HWY_NAMESPACE
+} // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace hwy {
+namespace {
+HWY_EXPORT(SortF32Asc);
+HWY_EXPORT(PartialSortF32Asc);
+HWY_EXPORT(SelectF32Asc);
+} // namespace
+
+void VQSort(float* HWY_RESTRICT keys, const size_t n, SortAscending) {
+ HWY_DYNAMIC_DISPATCH(SortF32Asc)(keys, n);
+}
+
+void VQPartialSort(float* HWY_RESTRICT keys, const size_t n, const size_t k,
+ SortAscending) {
+ HWY_DYNAMIC_DISPATCH(PartialSortF32Asc)(keys, n, k);
+}
+
+void VQSelect(float* HWY_RESTRICT keys, const size_t n, const size_t k,
+ SortAscending) {
+ HWY_DYNAMIC_DISPATCH(SelectF32Asc)(keys, n, k);
+}
+
+void Sorter::operator()(float* HWY_RESTRICT keys, size_t n,
+ SortAscending tag) const {
+ VQSort(keys, n, tag);
+}
+
+} // namespace hwy
+#endif // HWY_ONCE
diff --git a/third_party/highway/hwy/contrib/sort/vqsort_f32d.cc b/third_party/highway/hwy/contrib/sort/vqsort_f32d.cc
new file mode 100644
index 0000000000..24c6724156
--- /dev/null
+++ b/third_party/highway/hwy/contrib/sort/vqsort_f32d.cc
@@ -0,0 +1,77 @@
+// Copyright 2021 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "third_party/highway/hwy/contrib/sort/vqsort.h" // VQSort
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_f32d.cc"
+#include "third_party/highway/hwy/foreach_target.h" // IWYU pragma: keep
+
+// After foreach_target
+#include "third_party/highway/hwy/contrib/sort/vqsort-inl.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+namespace {
+
+void SortF32Desc(float* HWY_RESTRICT keys, const size_t num) {
+ return VQSortStatic(keys, num, SortDescending());
+}
+
+void PartialSortF32Desc(float* HWY_RESTRICT keys, const size_t num,
+ const size_t k) {
+ return VQPartialSortStatic(keys, num, k, SortDescending());
+}
+
+void SelectF32Desc(float* HWY_RESTRICT keys, const size_t num, const size_t k) {
+ return VQSelectStatic(keys, num, k, SortDescending());
+}
+
+} // namespace
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+} // namespace HWY_NAMESPACE
+} // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace hwy {
+namespace {
+HWY_EXPORT(SortF32Desc);
+HWY_EXPORT(PartialSortF32Desc);
+HWY_EXPORT(SelectF32Desc);
+} // namespace
+
+void VQSort(float* HWY_RESTRICT keys, const size_t n, SortDescending) {
+ HWY_DYNAMIC_DISPATCH(SortF32Desc)(keys, n);
+}
+
+void VQPartialSort(float* HWY_RESTRICT keys, const size_t n, const size_t k,
+ SortDescending) {
+ HWY_DYNAMIC_DISPATCH(PartialSortF32Desc)(keys, n, k);
+}
+
+void VQSelect(float* HWY_RESTRICT keys, const size_t n, const size_t k,
+ SortDescending) {
+ HWY_DYNAMIC_DISPATCH(SelectF32Desc)(keys, n, k);
+}
+
+void Sorter::operator()(float* HWY_RESTRICT keys, size_t n,
+ SortDescending tag) const {
+ VQSort(keys, n, tag);
+}
+
+} // namespace hwy
+#endif // HWY_ONCE
diff --git a/third_party/highway/hwy/contrib/sort/vqsort_f64a.cc b/third_party/highway/hwy/contrib/sort/vqsort_f64a.cc
new file mode 100644
index 0000000000..311788b46e
--- /dev/null
+++ b/third_party/highway/hwy/contrib/sort/vqsort_f64a.cc
@@ -0,0 +1,97 @@
+// Copyright 2021 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "third_party/highway/hwy/contrib/sort/vqsort.h" // VQSort
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_f64a.cc"
+#include "third_party/highway/hwy/foreach_target.h" // IWYU pragma: keep
+
+// After foreach_target
+#include "third_party/highway/hwy/contrib/sort/vqsort-inl.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+namespace {
+
+void SortF64Asc(double* HWY_RESTRICT keys, const size_t num) {
+#if HWY_HAVE_FLOAT64
+ return VQSortStatic(keys, num, SortAscending());
+#else
+ (void)keys;
+ (void)num;
+ HWY_ASSERT(0);
+#endif
+}
+
+void PartialSortF64Asc(double* HWY_RESTRICT keys, const size_t num,
+ const size_t k) {
+#if HWY_HAVE_FLOAT64
+ return VQPartialSortStatic(keys, num, k, SortAscending());
+#else
+ (void)keys;
+ (void)num;
+ (void)k;
+ HWY_ASSERT(0);
+#endif
+}
+
+void SelectF64Asc(double* HWY_RESTRICT keys, const size_t num, const size_t k) {
+#if HWY_HAVE_FLOAT64
+ return VQSelectStatic(keys, num, k, SortAscending());
+#else
+ (void)keys;
+ (void)num;
+ (void)k;
+ HWY_ASSERT(0);
+#endif
+}
+
+} // namespace
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+} // namespace HWY_NAMESPACE
+} // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace hwy {
+namespace {
+HWY_EXPORT(SortF64Asc);
+HWY_EXPORT(PartialSortF64Asc);
+HWY_EXPORT(SelectF64Asc);
+} // namespace
+
+void VQSort(double* HWY_RESTRICT keys, const size_t n, SortAscending) {
+ HWY_DYNAMIC_DISPATCH(SortF64Asc)(keys, n);
+}
+
+void VQPartialSort(double* HWY_RESTRICT keys, const size_t n, const size_t k,
+ SortAscending) {
+ HWY_DYNAMIC_DISPATCH(PartialSortF64Asc)(keys, n, k);
+}
+
+void VQSelect(double* HWY_RESTRICT keys, const size_t n, const size_t k,
+ SortAscending) {
+ HWY_DYNAMIC_DISPATCH(SelectF64Asc)(keys, n, k);
+}
+
+void Sorter::operator()(double* HWY_RESTRICT keys, size_t n,
+ SortAscending tag) const {
+ VQSort(keys, n, tag);
+}
+
+} // namespace hwy
+#endif // HWY_ONCE
diff --git a/third_party/highway/hwy/contrib/sort/vqsort_f64d.cc b/third_party/highway/hwy/contrib/sort/vqsort_f64d.cc
new file mode 100644
index 0000000000..9d8cd5ee08
--- /dev/null
+++ b/third_party/highway/hwy/contrib/sort/vqsort_f64d.cc
@@ -0,0 +1,98 @@
+// Copyright 2021 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "third_party/highway/hwy/contrib/sort/vqsort.h" // VQSort
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_f64d.cc"
+#include "third_party/highway/hwy/foreach_target.h" // IWYU pragma: keep
+
+// After foreach_target
+#include "third_party/highway/hwy/contrib/sort/vqsort-inl.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+namespace {
+
+void SortF64Desc(double* HWY_RESTRICT keys, const size_t num) {
+#if HWY_HAVE_FLOAT64
+ return VQSortStatic(keys, num, SortDescending());
+#else
+ (void)keys;
+ (void)num;
+ HWY_ASSERT(0);
+#endif
+}
+
+void PartialSortF64Desc(double* HWY_RESTRICT keys, const size_t num,
+ const size_t k) {
+#if HWY_HAVE_FLOAT64
+ return VQPartialSortStatic(keys, num, k, SortDescending());
+#else
+ (void)keys;
+ (void)num;
+ (void)k;
+ HWY_ASSERT(0);
+#endif
+}
+
+void SelectF64Desc(double* HWY_RESTRICT keys, const size_t num,
+ const size_t k) {
+#if HWY_HAVE_FLOAT64
+ return VQSelectStatic(keys, num, k, SortDescending());
+#else
+ (void)keys;
+ (void)num;
+ (void)k;
+ HWY_ASSERT(0);
+#endif
+}
+
+} // namespace
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+} // namespace HWY_NAMESPACE
+} // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace hwy {
+namespace {
+HWY_EXPORT(SortF64Desc);
+HWY_EXPORT(PartialSortF64Desc);
+HWY_EXPORT(SelectF64Desc);
+} // namespace
+
+void VQSort(double* HWY_RESTRICT keys, const size_t n, SortDescending) {
+ HWY_DYNAMIC_DISPATCH(SortF64Desc)(keys, n);
+}
+
+void VQPartialSort(double* HWY_RESTRICT keys, const size_t n, const size_t k,
+ SortDescending) {
+ HWY_DYNAMIC_DISPATCH(PartialSortF64Desc)(keys, n, k);
+}
+
+void VQSelect(double* HWY_RESTRICT keys, const size_t n, const size_t k,
+ SortDescending) {
+ HWY_DYNAMIC_DISPATCH(SelectF64Desc)(keys, n, k);
+}
+
+void Sorter::operator()(double* HWY_RESTRICT keys, size_t n,
+ SortDescending tag) const {
+ VQSort(keys, n, tag);
+}
+
+} // namespace hwy
+#endif // HWY_ONCE
diff --git a/third_party/highway/hwy/contrib/sort/vqsort_i16a.cc b/third_party/highway/hwy/contrib/sort/vqsort_i16a.cc
new file mode 100644
index 0000000000..91bb9809c3
--- /dev/null
+++ b/third_party/highway/hwy/contrib/sort/vqsort_i16a.cc
@@ -0,0 +1,78 @@
+// Copyright 2021 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "third_party/highway/hwy/contrib/sort/vqsort.h" // VQSort
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_i16a.cc"
+#include "third_party/highway/hwy/foreach_target.h" // IWYU pragma: keep
+
+// After foreach_target
+#include "third_party/highway/hwy/contrib/sort/vqsort-inl.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+namespace {
+
+void SortI16Asc(int16_t* HWY_RESTRICT keys, const size_t num) {
+ return VQSortStatic(keys, num, SortAscending());
+}
+
+void PartialSortI16Asc(int16_t* HWY_RESTRICT keys, const size_t num,
+ const size_t k) {
+ return VQPartialSortStatic(keys, num, k, SortAscending());
+}
+
+void SelectI16Asc(int16_t* HWY_RESTRICT keys, const size_t num,
+ const size_t k) {
+ return VQSelectStatic(keys, num, k, SortAscending());
+}
+
+} // namespace
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+} // namespace HWY_NAMESPACE
+} // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace hwy {
+namespace {
+HWY_EXPORT(SortI16Asc);
+HWY_EXPORT(PartialSortI16Asc);
+HWY_EXPORT(SelectI16Asc);
+} // namespace
+
+void VQSort(int16_t* HWY_RESTRICT keys, const size_t n, SortAscending) {
+ HWY_DYNAMIC_DISPATCH(SortI16Asc)(keys, n);
+}
+
+void VQPartialSort(int16_t* HWY_RESTRICT keys, const size_t n, const size_t k,
+ SortAscending) {
+ HWY_DYNAMIC_DISPATCH(PartialSortI16Asc)(keys, n, k);
+}
+
+void VQSelect(int16_t* HWY_RESTRICT keys, const size_t n, const size_t k,
+ SortAscending) {
+ HWY_DYNAMIC_DISPATCH(SelectI16Asc)(keys, n, k);
+}
+
+void Sorter::operator()(int16_t* HWY_RESTRICT keys, size_t n,
+ SortAscending tag) const {
+ VQSort(keys, n, tag);
+}
+
+} // namespace hwy
+#endif // HWY_ONCE
diff --git a/third_party/highway/hwy/contrib/sort/vqsort_i16d.cc b/third_party/highway/hwy/contrib/sort/vqsort_i16d.cc
new file mode 100644
index 0000000000..96674bbe4d
--- /dev/null
+++ b/third_party/highway/hwy/contrib/sort/vqsort_i16d.cc
@@ -0,0 +1,78 @@
+// Copyright 2021 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "third_party/highway/hwy/contrib/sort/vqsort.h" // VQSort
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_i16d.cc"
+#include "third_party/highway/hwy/foreach_target.h" // IWYU pragma: keep
+
+// After foreach_target
+#include "third_party/highway/hwy/contrib/sort/vqsort-inl.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+namespace {
+
+void SortI16Desc(int16_t* HWY_RESTRICT keys, const size_t num) {
+ return VQSortStatic(keys, num, SortDescending());
+}
+
+void PartialSortI16Desc(int16_t* HWY_RESTRICT keys, const size_t num,
+ const size_t k) {
+ return VQPartialSortStatic(keys, num, k, SortDescending());
+}
+
+void SelectI16Desc(int16_t* HWY_RESTRICT keys, const size_t num,
+ const size_t k) {
+ return VQSelectStatic(keys, num, k, SortDescending());
+}
+
+} // namespace
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+} // namespace HWY_NAMESPACE
+} // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace hwy {
+namespace {
+HWY_EXPORT(SortI16Desc);
+HWY_EXPORT(PartialSortI16Desc);
+HWY_EXPORT(SelectI16Desc);
+} // namespace
+
+void VQSort(int16_t* HWY_RESTRICT keys, const size_t n, SortDescending) {
+ HWY_DYNAMIC_DISPATCH(SortI16Desc)(keys, n);
+}
+
+void VQPartialSort(int16_t* HWY_RESTRICT keys, const size_t n, const size_t k,
+ SortDescending) {
+ HWY_DYNAMIC_DISPATCH(PartialSortI16Desc)(keys, n, k);
+}
+
+void VQSelect(int16_t* HWY_RESTRICT keys, const size_t n, const size_t k,
+ SortDescending) {
+ HWY_DYNAMIC_DISPATCH(SelectI16Desc)(keys, n, k);
+}
+
+void Sorter::operator()(int16_t* HWY_RESTRICT keys, size_t n,
+ SortDescending tag) const {
+ VQSort(keys, n, tag);
+}
+
+} // namespace hwy
+#endif // HWY_ONCE
diff --git a/third_party/highway/hwy/contrib/sort/vqsort_i32a.cc b/third_party/highway/hwy/contrib/sort/vqsort_i32a.cc
new file mode 100644
index 0000000000..7bbfcfb232
--- /dev/null
+++ b/third_party/highway/hwy/contrib/sort/vqsort_i32a.cc
@@ -0,0 +1,78 @@
+// Copyright 2021 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "third_party/highway/hwy/contrib/sort/vqsort.h" // VQSort
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_i32a.cc"
+#include "third_party/highway/hwy/foreach_target.h" // IWYU pragma: keep
+
+// After foreach_target
+#include "third_party/highway/hwy/contrib/sort/vqsort-inl.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+namespace {
+
+void SortI32Asc(int32_t* HWY_RESTRICT keys, const size_t num) {
+ return VQSortStatic(keys, num, SortAscending());
+}
+
+void PartialSortI32Asc(int32_t* HWY_RESTRICT keys, const size_t num,
+ const size_t k) {
+ return VQPartialSortStatic(keys, num, k, SortAscending());
+}
+
+void SelectI32Asc(int32_t* HWY_RESTRICT keys, const size_t num,
+ const size_t k) {
+ return VQSelectStatic(keys, num, k, SortAscending());
+}
+
+} // namespace
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+} // namespace HWY_NAMESPACE
+} // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace hwy {
+namespace {
+HWY_EXPORT(SortI32Asc);
+HWY_EXPORT(PartialSortI32Asc);
+HWY_EXPORT(SelectI32Asc);
+} // namespace
+
+void VQSort(int32_t* HWY_RESTRICT keys, const size_t n, SortAscending) {
+ HWY_DYNAMIC_DISPATCH(SortI32Asc)(keys, n);
+}
+
+void VQPartialSort(int32_t* HWY_RESTRICT keys, const size_t n, const size_t k,
+ SortAscending) {
+ HWY_DYNAMIC_DISPATCH(PartialSortI32Asc)(keys, n, k);
+}
+
+void VQSelect(int32_t* HWY_RESTRICT keys, const size_t n, const size_t k,
+ SortAscending) {
+ HWY_DYNAMIC_DISPATCH(SelectI32Asc)(keys, n, k);
+}
+
+void Sorter::operator()(int32_t* HWY_RESTRICT keys, size_t n,
+ SortAscending tag) const {
+ VQSort(keys, n, tag);
+}
+
+} // namespace hwy
+#endif // HWY_ONCE
diff --git a/third_party/highway/hwy/contrib/sort/vqsort_i32d.cc b/third_party/highway/hwy/contrib/sort/vqsort_i32d.cc
new file mode 100644
index 0000000000..cf7ab6657e
--- /dev/null
+++ b/third_party/highway/hwy/contrib/sort/vqsort_i32d.cc
@@ -0,0 +1,78 @@
+// Copyright 2021 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "third_party/highway/hwy/contrib/sort/vqsort.h" // VQSort
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_i32d.cc"
+#include "third_party/highway/hwy/foreach_target.h" // IWYU pragma: keep
+
+// After foreach_target
+#include "third_party/highway/hwy/contrib/sort/vqsort-inl.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+namespace {
+
+void SortI32Desc(int32_t* HWY_RESTRICT keys, const size_t num) {
+ return VQSortStatic(keys, num, SortDescending());
+}
+
+void PartialSortI32Desc(int32_t* HWY_RESTRICT keys, const size_t num,
+ const size_t k) {
+ return VQPartialSortStatic(keys, num, k, SortDescending());
+}
+
+void SelectI32Desc(int32_t* HWY_RESTRICT keys, const size_t num,
+ const size_t k) {
+ return VQSelectStatic(keys, num, k, SortDescending());
+}
+
+} // namespace
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+} // namespace HWY_NAMESPACE
+} // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace hwy {
+namespace {
+HWY_EXPORT(SortI32Desc);
+HWY_EXPORT(PartialSortI32Desc);
+HWY_EXPORT(SelectI32Desc);
+} // namespace
+
+void VQSort(int32_t* HWY_RESTRICT keys, const size_t n, SortDescending) {
+ HWY_DYNAMIC_DISPATCH(SortI32Desc)(keys, n);
+}
+
+void VQPartialSort(int32_t* HWY_RESTRICT keys, const size_t n, const size_t k,
+ SortDescending) {
+ HWY_DYNAMIC_DISPATCH(PartialSortI32Desc)(keys, n, k);
+}
+
+void VQSelect(int32_t* HWY_RESTRICT keys, const size_t n, const size_t k,
+ SortDescending) {
+ HWY_DYNAMIC_DISPATCH(SelectI32Desc)(keys, n, k);
+}
+
+void Sorter::operator()(int32_t* HWY_RESTRICT keys, size_t n,
+ SortDescending tag) const {
+ VQSort(keys, n, tag);
+}
+
+} // namespace hwy
+#endif // HWY_ONCE
diff --git a/third_party/highway/hwy/contrib/sort/vqsort_i64a.cc b/third_party/highway/hwy/contrib/sort/vqsort_i64a.cc
new file mode 100644
index 0000000000..4d01a454ef
--- /dev/null
+++ b/third_party/highway/hwy/contrib/sort/vqsort_i64a.cc
@@ -0,0 +1,78 @@
+// Copyright 2021 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "third_party/highway/hwy/contrib/sort/vqsort.h" // VQSort
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_i64a.cc"
+#include "third_party/highway/hwy/foreach_target.h" // IWYU pragma: keep
+
+// After foreach_target
+#include "third_party/highway/hwy/contrib/sort/vqsort-inl.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+namespace {
+
+void SortI64Asc(int64_t* HWY_RESTRICT keys, const size_t num) {
+ return VQSortStatic(keys, num, SortAscending());
+}
+
+void PartialSortI64Asc(int64_t* HWY_RESTRICT keys, const size_t num,
+ const size_t k) {
+ return VQPartialSortStatic(keys, num, k, SortAscending());
+}
+
+void SelectI64Asc(int64_t* HWY_RESTRICT keys, const size_t num,
+ const size_t k) {
+ return VQSelectStatic(keys, num, k, SortAscending());
+}
+
+} // namespace
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+} // namespace HWY_NAMESPACE
+} // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace hwy {
+namespace {
+HWY_EXPORT(SortI64Asc);
+HWY_EXPORT(PartialSortI64Asc);
+HWY_EXPORT(SelectI64Asc);
+} // namespace
+
+void VQSort(int64_t* HWY_RESTRICT keys, const size_t n, SortAscending) {
+ HWY_DYNAMIC_DISPATCH(SortI64Asc)(keys, n);
+}
+
+void VQPartialSort(int64_t* HWY_RESTRICT keys, const size_t n, const size_t k,
+ SortAscending) {
+ HWY_DYNAMIC_DISPATCH(PartialSortI64Asc)(keys, n, k);
+}
+
+void VQSelect(int64_t* HWY_RESTRICT keys, const size_t n, const size_t k,
+ SortAscending) {
+ HWY_DYNAMIC_DISPATCH(SelectI64Asc)(keys, n, k);
+}
+
+void Sorter::operator()(int64_t* HWY_RESTRICT keys, size_t n,
+ SortAscending tag) const {
+ VQSort(keys, n, tag);
+}
+
+} // namespace hwy
+#endif // HWY_ONCE
diff --git a/third_party/highway/hwy/contrib/sort/vqsort_i64d.cc b/third_party/highway/hwy/contrib/sort/vqsort_i64d.cc
new file mode 100644
index 0000000000..e63a7f8c6f
--- /dev/null
+++ b/third_party/highway/hwy/contrib/sort/vqsort_i64d.cc
@@ -0,0 +1,78 @@
+// Copyright 2021 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "third_party/highway/hwy/contrib/sort/vqsort.h" // VQSort
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_i64d.cc"
+#include "third_party/highway/hwy/foreach_target.h" // IWYU pragma: keep
+
+// After foreach_target
+#include "third_party/highway/hwy/contrib/sort/vqsort-inl.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+namespace {
+
+void SortI64Desc(int64_t* HWY_RESTRICT keys, const size_t num) {
+ return VQSortStatic(keys, num, SortDescending());
+}
+
+void PartialSortI64Desc(int64_t* HWY_RESTRICT keys, const size_t num,
+ const size_t k) {
+ return VQPartialSortStatic(keys, num, k, SortDescending());
+}
+
+void SelectI64Desc(int64_t* HWY_RESTRICT keys, const size_t num,
+ const size_t k) {
+ return VQSelectStatic(keys, num, k, SortDescending());
+}
+
+} // namespace
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+} // namespace HWY_NAMESPACE
+} // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace hwy {
+namespace {
+HWY_EXPORT(SortI64Desc);
+HWY_EXPORT(PartialSortI64Desc);
+HWY_EXPORT(SelectI64Desc);
+} // namespace
+
+void VQSort(int64_t* HWY_RESTRICT keys, const size_t n, SortDescending) {
+ HWY_DYNAMIC_DISPATCH(SortI64Desc)(keys, n);
+}
+
+void VQPartialSort(int64_t* HWY_RESTRICT keys, const size_t n, const size_t k,
+ SortDescending) {
+ HWY_DYNAMIC_DISPATCH(PartialSortI64Desc)(keys, n, k);
+}
+
+void VQSelect(int64_t* HWY_RESTRICT keys, const size_t n, const size_t k,
+ SortDescending) {
+ HWY_DYNAMIC_DISPATCH(SelectI64Desc)(keys, n, k);
+}
+
+void Sorter::operator()(int64_t* HWY_RESTRICT keys, size_t n,
+ SortDescending tag) const {
+ VQSort(keys, n, tag);
+}
+
+} // namespace hwy
+#endif // HWY_ONCE
diff --git a/third_party/highway/hwy/contrib/sort/vqsort_kv128a.cc b/third_party/highway/hwy/contrib/sort/vqsort_kv128a.cc
new file mode 100644
index 0000000000..976d9906fb
--- /dev/null
+++ b/third_party/highway/hwy/contrib/sort/vqsort_kv128a.cc
@@ -0,0 +1,101 @@
+// Copyright 2021 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "third_party/highway/hwy/contrib/sort/vqsort.h" // VQSort
+
+#undef HWY_TARGET_INCLUDE
+// clang-format off
+// (avoid line break, which would prevent Copybara rules from matching)
+#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_kv128a.cc" //NOLINT
+// clang-format on
+#include "third_party/highway/hwy/foreach_target.h" // IWYU pragma: keep
+
+// After foreach_target
+#include "third_party/highway/hwy/contrib/sort/vqsort-inl.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+namespace {
+
+void SortKV128Asc(K64V64* HWY_RESTRICT keys, const size_t num) {
+ // 128-bit keys require 128-bit SIMD.
+#if HWY_TARGET != HWY_SCALAR
+ return VQSortStatic(keys, num, SortAscending());
+#else
+ (void)keys;
+ (void)num;
+#endif
+}
+
+void PartialSortKV128Asc(K64V64* HWY_RESTRICT keys, const size_t num,
+ const size_t k) {
+ // 128-bit keys require 128-bit SIMD.
+#if HWY_TARGET != HWY_SCALAR
+ return VQPartialSortStatic(keys, num, k, SortAscending());
+#else
+ (void)keys;
+ (void)num;
+ (void)k;
+#endif
+}
+
+void SelectKV128Asc(K64V64* HWY_RESTRICT keys, const size_t num,
+ const size_t k) {
+ // 128-bit keys require 128-bit SIMD.
+#if HWY_TARGET != HWY_SCALAR
+ return VQSelectStatic(keys, num, k, SortAscending());
+#else
+ (void)keys;
+ (void)num;
+ (void)k;
+#endif
+}
+
+} // namespace
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+} // namespace HWY_NAMESPACE
+} // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace hwy {
+namespace {
+HWY_EXPORT(SortKV128Asc);
+HWY_EXPORT(PartialSortKV128Asc);
+HWY_EXPORT(SelectKV128Asc);
+} // namespace
+
+void VQSort(K64V64* HWY_RESTRICT keys, const size_t n, SortAscending) {
+ HWY_DYNAMIC_DISPATCH(SortKV128Asc)(keys, n);
+}
+
+void VQPartialSort(K64V64* HWY_RESTRICT keys, const size_t n, const size_t k,
+ SortAscending) {
+ HWY_DYNAMIC_DISPATCH(PartialSortKV128Asc)(keys, n, k);
+}
+
+void VQSelect(K64V64* HWY_RESTRICT keys, const size_t n, const size_t k,
+ SortAscending) {
+ HWY_DYNAMIC_DISPATCH(SelectKV128Asc)(keys, n, k);
+}
+
+void Sorter::operator()(K64V64* HWY_RESTRICT keys, size_t n,
+ SortAscending tag) const {
+ VQSort(keys, n, tag);
+}
+
+} // namespace hwy
+#endif // HWY_ONCE
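
The kv128 variants sort 128-bit key-value pairs, which require 128-bit vectors; under HWY_SCALAR the bodies therefore compile to no-ops instead of failing to build. A hedged usage sketch, assuming the key/value members that K64V64 declares in vqsort.h (the key occupies the upper 64 bits and is compared first):

    #include <stdint.h>
    #include <vector>
    #include "third_party/highway/hwy/contrib/sort/vqsort.h"

    int main() {
      std::vector<hwy::K64V64> pairs(4);
      for (uint64_t i = 0; i < pairs.size(); ++i) {
        pairs[i].key = 100 - i;  // sort key
        pairs[i].value = i;      // payload carried along with the key
      }
      hwy::VQSort(pairs.data(), pairs.size(), hwy::SortAscending());
      return 0;
    }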
diff --git a/third_party/highway/hwy/contrib/sort/vqsort_kv128d.cc b/third_party/highway/hwy/contrib/sort/vqsort_kv128d.cc
new file mode 100644
index 0000000000..a8f1680cd7
--- /dev/null
+++ b/third_party/highway/hwy/contrib/sort/vqsort_kv128d.cc
@@ -0,0 +1,101 @@
+// Copyright 2021 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "third_party/highway/hwy/contrib/sort/vqsort.h" // VQSort
+
+#undef HWY_TARGET_INCLUDE
+// clang-format off
+// (avoid line break, which would prevent Copybara rules from matching)
+#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_kv128d.cc" //NOLINT
+// clang-format on
+#include "third_party/highway/hwy/foreach_target.h" // IWYU pragma: keep
+
+// After foreach_target
+#include "third_party/highway/hwy/contrib/sort/vqsort-inl.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+namespace {
+
+void SortKV128Desc(K64V64* HWY_RESTRICT keys, const size_t num) {
+ // 128-bit keys require 128-bit SIMD.
+#if HWY_TARGET != HWY_SCALAR
+ return VQSortStatic(keys, num, SortDescending());
+#else
+ (void)keys;
+ (void)num;
+#endif
+}
+
+void PartialSortKV128Desc(K64V64* HWY_RESTRICT keys, const size_t num,
+ const size_t k) {
+ // 128-bit keys require 128-bit SIMD.
+#if HWY_TARGET != HWY_SCALAR
+ return VQPartialSortStatic(keys, num, k, SortDescending());
+#else
+ (void)keys;
+ (void)num;
+ (void)k;
+#endif
+}
+
+void SelectKV128Desc(K64V64* HWY_RESTRICT keys, const size_t num,
+ const size_t k) {
+ // 128-bit keys require 128-bit SIMD.
+#if HWY_TARGET != HWY_SCALAR
+ return VQSelectStatic(keys, num, k, SortDescending());
+#else
+ (void)keys;
+ (void)num;
+ (void)k;
+#endif
+}
+
+} // namespace
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+} // namespace HWY_NAMESPACE
+} // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace hwy {
+namespace {
+HWY_EXPORT(SortKV128Desc);
+HWY_EXPORT(PartialSortKV128Desc);
+HWY_EXPORT(SelectKV128Desc);
+} // namespace
+
+void VQSort(K64V64* HWY_RESTRICT keys, const size_t n, SortDescending) {
+ HWY_DYNAMIC_DISPATCH(SortKV128Desc)(keys, n);
+}
+
+void VQPartialSort(K64V64* HWY_RESTRICT keys, const size_t n, const size_t k,
+ SortDescending) {
+ HWY_DYNAMIC_DISPATCH(PartialSortKV128Desc)(keys, n, k);
+}
+
+void VQSelect(K64V64* HWY_RESTRICT keys, const size_t n, const size_t k,
+ SortDescending) {
+ HWY_DYNAMIC_DISPATCH(SelectKV128Desc)(keys, n, k);
+}
+
+void Sorter::operator()(K64V64* HWY_RESTRICT keys, size_t n,
+ SortDescending tag) const {
+ VQSort(keys, n, tag);
+}
+
+} // namespace hwy
+#endif // HWY_ONCE
diff --git a/third_party/highway/hwy/contrib/sort/vqsort_kv64a.cc b/third_party/highway/hwy/contrib/sort/vqsort_kv64a.cc
new file mode 100644
index 0000000000..e740c97db4
--- /dev/null
+++ b/third_party/highway/hwy/contrib/sort/vqsort_kv64a.cc
@@ -0,0 +1,81 @@
+// Copyright 2022 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "third_party/highway/hwy/contrib/sort/vqsort.h" // VQSort
+
+#undef HWY_TARGET_INCLUDE
+// clang-format off
+// (avoid line break, which would prevent Copybara rules from matching)
+#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_kv64a.cc" //NOLINT
+// clang-format on
+#include "third_party/highway/hwy/foreach_target.h" // IWYU pragma: keep
+
+// After foreach_target
+#include "third_party/highway/hwy/contrib/sort/vqsort-inl.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+namespace {
+
+void SortKV64Asc(K32V32* HWY_RESTRICT keys, const size_t num) {
+ return VQSortStatic(keys, num, SortAscending());
+}
+
+void PartialSortKV64Asc(K32V32* HWY_RESTRICT keys, const size_t num,
+ const size_t k) {
+ return VQPartialSortStatic(keys, num, k, SortAscending());
+}
+
+void SelectKV64Asc(K32V32* HWY_RESTRICT keys, const size_t num,
+ const size_t k) {
+ return VQSelectStatic(keys, num, k, SortAscending());
+}
+
+} // namespace
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+} // namespace HWY_NAMESPACE
+} // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace hwy {
+namespace {
+HWY_EXPORT(SortKV64Asc);
+HWY_EXPORT(PartialSortKV64Asc);
+HWY_EXPORT(SelectKV64Asc);
+} // namespace
+
+void VQSort(K32V32* HWY_RESTRICT keys, const size_t n, SortAscending) {
+ HWY_DYNAMIC_DISPATCH(SortKV64Asc)(keys, n);
+}
+
+void VQPartialSort(K32V32* HWY_RESTRICT keys, const size_t n, const size_t k,
+ SortAscending) {
+ HWY_DYNAMIC_DISPATCH(PartialSortKV64Asc)(keys, n, k);
+}
+
+void VQSelect(K32V32* HWY_RESTRICT keys, const size_t n, const size_t k,
+ SortAscending) {
+ HWY_DYNAMIC_DISPATCH(SelectKV64Asc)(keys, n, k);
+}
+
+void Sorter::operator()(K32V32* HWY_RESTRICT keys, size_t n,
+ SortAscending tag) const {
+ VQSort(keys, n, tag);
+}
+
+} // namespace hwy
+#endif // HWY_ONCE
diff --git a/third_party/highway/hwy/contrib/sort/vqsort_kv64d.cc b/third_party/highway/hwy/contrib/sort/vqsort_kv64d.cc
new file mode 100644
index 0000000000..c52f69037c
--- /dev/null
+++ b/third_party/highway/hwy/contrib/sort/vqsort_kv64d.cc
@@ -0,0 +1,81 @@
+// Copyright 2022 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "third_party/highway/hwy/contrib/sort/vqsort.h" // VQSort
+
+#undef HWY_TARGET_INCLUDE
+// clang-format off
+// (avoid line break, which would prevent Copybara rules from matching)
+#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_kv64d.cc" //NOLINT
+// clang-format on
+#include "third_party/highway/hwy/foreach_target.h" // IWYU pragma: keep
+
+// After foreach_target
+#include "third_party/highway/hwy/contrib/sort/vqsort-inl.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+namespace {
+
+void SortKV64Desc(K32V32* HWY_RESTRICT keys, const size_t num) {
+ return VQSortStatic(keys, num, SortDescending());
+}
+
+void PartialSortKV64Desc(K32V32* HWY_RESTRICT keys, const size_t num,
+ const size_t k) {
+ return VQPartialSortStatic(keys, num, k, SortDescending());
+}
+
+void SelectKV64Desc(K32V32* HWY_RESTRICT keys, const size_t num,
+ const size_t k) {
+ return VQSelectStatic(keys, num, k, SortDescending());
+}
+
+} // namespace
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+} // namespace HWY_NAMESPACE
+} // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace hwy {
+namespace {
+HWY_EXPORT(SortKV64Desc);
+HWY_EXPORT(PartialSortKV64Desc);
+HWY_EXPORT(SelectKV64Desc);
+} // namespace
+
+void VQSort(K32V32* HWY_RESTRICT keys, const size_t n, SortDescending) {
+ HWY_DYNAMIC_DISPATCH(SortKV64Desc)(keys, n);
+}
+
+void VQPartialSort(K32V32* HWY_RESTRICT keys, const size_t n, const size_t k,
+ SortDescending) {
+ HWY_DYNAMIC_DISPATCH(PartialSortKV64Desc)(keys, n, k);
+}
+
+void VQSelect(K32V32* HWY_RESTRICT keys, const size_t n, const size_t k,
+ SortDescending) {
+ HWY_DYNAMIC_DISPATCH(SelectKV64Desc)(keys, n, k);
+}
+
+void Sorter::operator()(K32V32* HWY_RESTRICT keys, size_t n,
+ SortDescending tag) const {
+ VQSort(keys, n, tag);
+}
+
+} // namespace hwy
+#endif // HWY_ONCE
diff --git a/third_party/highway/hwy/contrib/sort/vqsort_u16a.cc b/third_party/highway/hwy/contrib/sort/vqsort_u16a.cc
new file mode 100644
index 0000000000..9f0a40c224
--- /dev/null
+++ b/third_party/highway/hwy/contrib/sort/vqsort_u16a.cc
@@ -0,0 +1,78 @@
+// Copyright 2021 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "third_party/highway/hwy/contrib/sort/vqsort.h" // VQSort
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_u16a.cc"
+#include "third_party/highway/hwy/foreach_target.h" // IWYU pragma: keep
+
+// After foreach_target
+#include "third_party/highway/hwy/contrib/sort/vqsort-inl.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+namespace {
+
+void SortU16Asc(uint16_t* HWY_RESTRICT keys, const size_t num) {
+ return VQSortStatic(keys, num, SortAscending());
+}
+
+void PartialSortU16Asc(uint16_t* HWY_RESTRICT keys, const size_t num,
+ const size_t k) {
+ return VQPartialSortStatic(keys, num, k, SortAscending());
+}
+
+void SelectU16Asc(uint16_t* HWY_RESTRICT keys, const size_t num,
+ const size_t k) {
+ return VQSelectStatic(keys, num, k, SortAscending());
+}
+
+} // namespace
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+} // namespace HWY_NAMESPACE
+} // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace hwy {
+namespace {
+HWY_EXPORT(SortU16Asc);
+HWY_EXPORT(PartialSortU16Asc);
+HWY_EXPORT(SelectU16Asc);
+} // namespace
+
+void VQSort(uint16_t* HWY_RESTRICT keys, const size_t n, SortAscending) {
+ HWY_DYNAMIC_DISPATCH(SortU16Asc)(keys, n);
+}
+
+void VQPartialSort(uint16_t* HWY_RESTRICT keys, const size_t n, const size_t k,
+ SortAscending) {
+ HWY_DYNAMIC_DISPATCH(PartialSortU16Asc)(keys, n, k);
+}
+
+void VQSelect(uint16_t* HWY_RESTRICT keys, const size_t n, const size_t k,
+ SortAscending) {
+ HWY_DYNAMIC_DISPATCH(SelectU16Asc)(keys, n, k);
+}
+
+void Sorter::operator()(uint16_t* HWY_RESTRICT keys, size_t n,
+ SortAscending tag) const {
+ VQSort(keys, n, tag);
+}
+
+} // namespace hwy
+#endif // HWY_ONCE
diff --git a/third_party/highway/hwy/contrib/sort/vqsort_u16d.cc b/third_party/highway/hwy/contrib/sort/vqsort_u16d.cc
new file mode 100644
index 0000000000..9639c86c3b
--- /dev/null
+++ b/third_party/highway/hwy/contrib/sort/vqsort_u16d.cc
@@ -0,0 +1,78 @@
+// Copyright 2021 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "third_party/highway/hwy/contrib/sort/vqsort.h" // VQSort
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_u16d.cc"
+#include "third_party/highway/hwy/foreach_target.h" // IWYU pragma: keep
+
+// After foreach_target
+#include "third_party/highway/hwy/contrib/sort/vqsort-inl.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+namespace {
+
+void SortU16Desc(uint16_t* HWY_RESTRICT keys, const size_t num) {
+ return VQSortStatic(keys, num, SortDescending());
+}
+
+void PartialSortU16Desc(uint16_t* HWY_RESTRICT keys, const size_t num,
+ const size_t k) {
+ return VQPartialSortStatic(keys, num, k, SortDescending());
+}
+
+void SelectU16Desc(uint16_t* HWY_RESTRICT keys, const size_t num,
+ const size_t k) {
+ return VQSelectStatic(keys, num, k, SortDescending());
+}
+
+} // namespace
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+} // namespace HWY_NAMESPACE
+} // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace hwy {
+namespace {
+HWY_EXPORT(SortU16Desc);
+HWY_EXPORT(PartialSortU16Desc);
+HWY_EXPORT(SelectU16Desc);
+} // namespace
+
+void VQSort(uint16_t* HWY_RESTRICT keys, const size_t n, SortDescending) {
+ HWY_DYNAMIC_DISPATCH(SortU16Desc)(keys, n);
+}
+
+void VQPartialSort(uint16_t* HWY_RESTRICT keys, const size_t n, const size_t k,
+ SortDescending) {
+ HWY_DYNAMIC_DISPATCH(PartialSortU16Desc)(keys, n, k);
+}
+
+void VQSelect(uint16_t* HWY_RESTRICT keys, const size_t n, const size_t k,
+ SortDescending) {
+ HWY_DYNAMIC_DISPATCH(SelectU16Desc)(keys, n, k);
+}
+
+void Sorter::operator()(uint16_t* HWY_RESTRICT keys, size_t n,
+ SortDescending tag) const {
+ VQSort(keys, n, tag);
+}
+
+} // namespace hwy
+#endif // HWY_ONCE
diff --git a/third_party/highway/hwy/contrib/sort/vqsort_u32a.cc b/third_party/highway/hwy/contrib/sort/vqsort_u32a.cc
new file mode 100644
index 0000000000..183b2ffb88
--- /dev/null
+++ b/third_party/highway/hwy/contrib/sort/vqsort_u32a.cc
@@ -0,0 +1,78 @@
+// Copyright 2021 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "third_party/highway/hwy/contrib/sort/vqsort.h" // VQSort
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_u32a.cc"
+#include "third_party/highway/hwy/foreach_target.h" // IWYU pragma: keep
+
+// After foreach_target
+#include "third_party/highway/hwy/contrib/sort/vqsort-inl.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+namespace {
+
+void SortU32Asc(uint32_t* HWY_RESTRICT keys, const size_t num) {
+ return VQSortStatic(keys, num, SortAscending());
+}
+
+void PartialSortU32Asc(uint32_t* HWY_RESTRICT keys, const size_t num,
+ const size_t k) {
+ return VQPartialSortStatic(keys, num, k, SortAscending());
+}
+
+void SelectU32Asc(uint32_t* HWY_RESTRICT keys, const size_t num,
+ const size_t k) {
+ return VQSelectStatic(keys, num, k, SortAscending());
+}
+
+} // namespace
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+} // namespace HWY_NAMESPACE
+} // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace hwy {
+namespace {
+HWY_EXPORT(SortU32Asc);
+HWY_EXPORT(PartialSortU32Asc);
+HWY_EXPORT(SelectU32Asc);
+} // namespace
+
+void VQSort(uint32_t* HWY_RESTRICT keys, const size_t n, SortAscending) {
+ HWY_DYNAMIC_DISPATCH(SortU32Asc)(keys, n);
+}
+
+void VQPartialSort(uint32_t* HWY_RESTRICT keys, const size_t n, const size_t k,
+ SortAscending) {
+ HWY_DYNAMIC_DISPATCH(PartialSortU32Asc)(keys, n, k);
+}
+
+void VQSelect(uint32_t* HWY_RESTRICT keys, const size_t n, const size_t k,
+ SortAscending) {
+ HWY_DYNAMIC_DISPATCH(SelectU32Asc)(keys, n, k);
+}
+
+void Sorter::operator()(uint32_t* HWY_RESTRICT keys, size_t n,
+ SortAscending tag) const {
+ VQSort(keys, n, tag);
+}
+
+} // namespace hwy
+#endif // HWY_ONCE
diff --git a/third_party/highway/hwy/contrib/sort/vqsort_u32d.cc b/third_party/highway/hwy/contrib/sort/vqsort_u32d.cc
new file mode 100644
index 0000000000..c3f77704af
--- /dev/null
+++ b/third_party/highway/hwy/contrib/sort/vqsort_u32d.cc
@@ -0,0 +1,78 @@
+// Copyright 2021 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "third_party/highway/hwy/contrib/sort/vqsort.h" // VQSort
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_u32d.cc"
+#include "third_party/highway/hwy/foreach_target.h" // IWYU pragma: keep
+
+// After foreach_target
+#include "third_party/highway/hwy/contrib/sort/vqsort-inl.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+namespace {
+
+void SortU32Desc(uint32_t* HWY_RESTRICT keys, const size_t num) {
+ return VQSortStatic(keys, num, SortDescending());
+}
+
+void PartialSortU32Desc(uint32_t* HWY_RESTRICT keys, const size_t num,
+ const size_t k) {
+ return VQPartialSortStatic(keys, num, k, SortDescending());
+}
+
+void SelectU32Desc(uint32_t* HWY_RESTRICT keys, const size_t num,
+ const size_t k) {
+ return VQSelectStatic(keys, num, k, SortDescending());
+}
+
+} // namespace
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+} // namespace HWY_NAMESPACE
+} // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace hwy {
+namespace {
+HWY_EXPORT(SortU32Desc);
+HWY_EXPORT(PartialSortU32Desc);
+HWY_EXPORT(SelectU32Desc);
+} // namespace
+
+void VQSort(uint32_t* HWY_RESTRICT keys, const size_t n, SortDescending) {
+ HWY_DYNAMIC_DISPATCH(SortU32Desc)(keys, n);
+}
+
+void VQPartialSort(uint32_t* HWY_RESTRICT keys, const size_t n, const size_t k,
+ SortDescending) {
+ HWY_DYNAMIC_DISPATCH(PartialSortU32Desc)(keys, n, k);
+}
+
+void VQSelect(uint32_t* HWY_RESTRICT keys, const size_t n, const size_t k,
+ SortDescending) {
+ HWY_DYNAMIC_DISPATCH(SelectU32Desc)(keys, n, k);
+}
+
+void Sorter::operator()(uint32_t* HWY_RESTRICT keys, size_t n,
+ SortDescending tag) const {
+ VQSort(keys, n, tag);
+}
+
+} // namespace hwy
+#endif // HWY_ONCE
diff --git a/third_party/highway/hwy/contrib/sort/vqsort_u64a.cc b/third_party/highway/hwy/contrib/sort/vqsort_u64a.cc
new file mode 100644
index 0000000000..e14cfe7a18
--- /dev/null
+++ b/third_party/highway/hwy/contrib/sort/vqsort_u64a.cc
@@ -0,0 +1,78 @@
+// Copyright 2021 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "third_party/highway/hwy/contrib/sort/vqsort.h" // VQSort
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_u64a.cc"
+#include "third_party/highway/hwy/foreach_target.h" // IWYU pragma: keep
+
+// After foreach_target
+#include "third_party/highway/hwy/contrib/sort/vqsort-inl.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+namespace {
+
+void SortU64Asc(uint64_t* HWY_RESTRICT keys, const size_t num) {
+ return VQSortStatic(keys, num, SortAscending());
+}
+
+void PartialSortU64Asc(uint64_t* HWY_RESTRICT keys, const size_t num,
+ const size_t k) {
+ return VQPartialSortStatic(keys, num, k, SortAscending());
+}
+
+void SelectU64Asc(uint64_t* HWY_RESTRICT keys, const size_t num,
+ const size_t k) {
+ return VQSelectStatic(keys, num, k, SortAscending());
+}
+
+} // namespace
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+} // namespace HWY_NAMESPACE
+} // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace hwy {
+namespace {
+HWY_EXPORT(SortU64Asc);
+HWY_EXPORT(PartialSortU64Asc);
+HWY_EXPORT(SelectU64Asc);
+} // namespace
+
+void VQSort(uint64_t* HWY_RESTRICT keys, const size_t n, SortAscending) {
+ HWY_DYNAMIC_DISPATCH(SortU64Asc)(keys, n);
+}
+
+void VQPartialSort(uint64_t* HWY_RESTRICT keys, const size_t n, const size_t k,
+ SortAscending) {
+ HWY_DYNAMIC_DISPATCH(PartialSortU64Asc)(keys, n, k);
+}
+
+void VQSelect(uint64_t* HWY_RESTRICT keys, const size_t n, const size_t k,
+ SortAscending) {
+ HWY_DYNAMIC_DISPATCH(SelectU64Asc)(keys, n, k);
+}
+
+void Sorter::operator()(uint64_t* HWY_RESTRICT keys, size_t n,
+ SortAscending tag) const {
+ VQSort(keys, n, tag);
+}
+
+} // namespace hwy
+#endif // HWY_ONCE
diff --git a/third_party/highway/hwy/contrib/sort/vqsort_u64d.cc b/third_party/highway/hwy/contrib/sort/vqsort_u64d.cc
new file mode 100644
index 0000000000..a8dc6c9aa3
--- /dev/null
+++ b/third_party/highway/hwy/contrib/sort/vqsort_u64d.cc
@@ -0,0 +1,78 @@
+// Copyright 2021 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "third_party/highway/hwy/contrib/sort/vqsort.h" // VQSort
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "hwy/contrib/sort/vqsort_u64d.cc"
+#include "third_party/highway/hwy/foreach_target.h" // IWYU pragma: keep
+
+// After foreach_target
+#include "third_party/highway/hwy/contrib/sort/vqsort-inl.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+namespace {
+
+void SortU64Desc(uint64_t* HWY_RESTRICT keys, const size_t num) {
+ return VQSortStatic(keys, num, SortDescending());
+}
+
+void PartialSortU64Desc(uint64_t* HWY_RESTRICT keys, const size_t num,
+ const size_t k) {
+ return VQPartialSortStatic(keys, num, k, SortDescending());
+}
+
+void SelectU64Desc(uint64_t* HWY_RESTRICT keys, const size_t num,
+ const size_t k) {
+ return VQSelectStatic(keys, num, k, SortDescending());
+}
+
+} // namespace
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+} // namespace HWY_NAMESPACE
+} // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace hwy {
+namespace {
+HWY_EXPORT(SortU64Desc);
+HWY_EXPORT(PartialSortU64Desc);
+HWY_EXPORT(SelectU64Desc);
+} // namespace
+
+void VQSort(uint64_t* HWY_RESTRICT keys, const size_t n, SortDescending) {
+ HWY_DYNAMIC_DISPATCH(SortU64Desc)(keys, n);
+}
+
+void VQPartialSort(uint64_t* HWY_RESTRICT keys, const size_t n, const size_t k,
+ SortDescending) {
+ HWY_DYNAMIC_DISPATCH(PartialSortU64Desc)(keys, n, k);
+}
+
+void VQSelect(uint64_t* HWY_RESTRICT keys, const size_t n, const size_t k,
+ SortDescending) {
+ HWY_DYNAMIC_DISPATCH(SelectU64Desc)(keys, n, k);
+}
+
+void Sorter::operator()(uint64_t* HWY_RESTRICT keys, size_t n,
+ SortDescending tag) const {
+ VQSort(keys, n, tag);
+}
+
+} // namespace hwy
+#endif // HWY_ONCE
diff --git a/third_party/highway/hwy/contrib/thread_pool/futex.h b/third_party/highway/hwy/contrib/thread_pool/futex.h
index 740cbd23fc..d362d34f14 100644
--- a/third_party/highway/hwy/contrib/thread_pool/futex.h
+++ b/third_party/highway/hwy/contrib/thread_pool/futex.h
@@ -21,8 +21,8 @@
// use with shared-memory mappings).
//
// Futex equivalents: https://outerproduct.net/futex-dictionary.html; we
-// support Linux/Emscripten/Apple/Windows and C++20 std::atomic::wait, plus a
-// NanoSleep fallback.
+// support Linux/Emscripten/FreeBSD/Apple/Windows and C++20 std::atomic::wait,
+// plus a NanoSleep fallback.
#include <time.h>
@@ -31,6 +31,26 @@
#include "third_party/highway/hwy/base.h"
+#if HWY_OS_APPLE
+#include <AvailabilityMacros.h>
+// __ulock* were added in OS X 10.12 (Sierra, 2016).
+#if MAC_OS_X_VERSION_MAX_ALLOWED < 101200 && !defined(HWY_DISABLE_FUTEX)
+#define HWY_DISABLE_FUTEX
+#endif
+#endif // HWY_OS_APPLE
+
+#if HWY_OS_WIN
+// Need to include <windows.h> on Windows, even if HWY_DISABLE_FUTEX is defined,
+// since hwy::NanoSleep uses Windows APIs that are defined in windows.h.
+#ifndef NOMINMAX
+#define NOMINMAX
+#endif // NOMINMAX
+#ifndef WIN32_LEAN_AND_MEAN
+#define WIN32_LEAN_AND_MEAN
+#endif // WIN32_LEAN_AND_MEAN
+#include <windows.h>
+#endif
+
#if HWY_ARCH_WASM
#include <emscripten/threading.h>
#include <math.h> // INFINITY
@@ -56,6 +76,16 @@
#define FUTEX_WAKE_PRIVATE (FUTEX_WAKE | 128)
#endif
+#elif HWY_OS_FREEBSD && !defined(HWY_DISABLE_FUTEX)
+#include <sys/param.h> // __FreeBSD_version
+#if __FreeBSD_version >= 600000
+#include <errno.h>
+#include <sys/types.h>
+#include <sys/umtx.h>
+#else
+#define HWY_DISABLE_FUTEX
+#endif
+
#elif HWY_OS_APPLE && !defined(HWY_DISABLE_FUTEX)
// These are private APIs, so add an opt-out.
extern "C" {
@@ -67,13 +97,6 @@ int __ulock_wake(uint32_t op, void* address, uint64_t zero);
#elif HWY_OS_WIN && !defined(HWY_DISABLE_FUTEX)
// WakeByAddressAll requires Windows 8, so add an opt-out.
-#ifndef NOMINMAX
-#define NOMINMAX
-#endif // NOMINMAX
-#ifndef WIN32_LEAN_AND_MEAN
-#define WIN32_LEAN_AND_MEAN
-#endif // WIN32_LEAN_AND_MEAN
-#include <windows.h>
#if HWY_COMPILER_MSVC || HWY_COMPILER_CLANGCL
#pragma comment(lib, "synchronization.lib")
#endif
@@ -159,6 +182,19 @@ static inline uint32_t BlockUntilDifferent(
}
}
+#elif HWY_OS_FREEBSD && !defined(HWY_DISABLE_FUTEX) // >= 6.0
+ // _umtx_op with UMTX_OP_WAIT_UINT_PRIVATE: process-private futex on FreeBSD.
+ void* address = const_cast<void*>(static_cast<const void*>(&current));
+ for (;;) {
+ const uint32_t next = current.load(acq);
+ if (next != prev) return next;
+ const int ret = _umtx_op(address, UMTX_OP_WAIT_UINT_PRIVATE,
+ static_cast<u_long>(prev), nullptr, nullptr);
+ if (ret == -1) {
+ HWY_DASSERT(errno == EAGAIN || errno == EINTR);
+ }
+ }
+
#elif HWY_OS_WIN && !defined(HWY_DISABLE_FUTEX)
// It is always safe to cast to void.
volatile void* address =
@@ -221,6 +257,13 @@ static inline void WakeAll(std::atomic<uint32_t>& current) {
HWY_DASSERT(ret >= 0); // number woken
(void)ret;
+#elif HWY_OS_FREEBSD && !defined(HWY_DISABLE_FUTEX) // >= 6.0
+ void* address = static_cast<void*>(&current);
+ const int ret = _umtx_op(address, UMTX_OP_WAKE_PRIVATE, INT_MAX, nullptr,
+ nullptr);
+ HWY_DASSERT(ret >= 0);
+ (void)ret;
+
#elif HWY_OS_WIN && !defined(HWY_DISABLE_FUTEX)
// It is always safe to cast to void.
void* address = static_cast<void*>(&current);
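
The new FreeBSD branch implements the same wait/wake contract via _umtx_op with UMTX_OP_WAIT_UINT_PRIVATE and UMTX_OP_WAKE_PRIVATE (process-private, analogous to Linux's FUTEX_PRIVATE_FLAG). A minimal sketch of the caller-side pairing, using the BlockUntilDifferent/WakeAll helpers shown in this header's hunks:

    #include <atomic>
    #include <cstdint>
    #include <thread>
    #include "third_party/highway/hwy/contrib/thread_pool/futex.h"

    int main() {
      std::atomic<uint32_t> epoch{0};
      std::thread waiter([&] {
        // Blocks (platform futex, or the fallback) until epoch != 0.
        const uint32_t next = hwy::BlockUntilDifferent(0, epoch);
        (void)next;  // next == 1 here
      });
      epoch.store(1, std::memory_order_release);
      hwy::WakeAll(epoch);  // Wakes any thread blocked on epoch.
      waiter.join();
      return 0;
    }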
diff --git a/third_party/highway/hwy/contrib/thread_pool/spin.h b/third_party/highway/hwy/contrib/thread_pool/spin.h
index 57973a7610..dd4a74df01 100644
--- a/third_party/highway/hwy/contrib/thread_pool/spin.h
+++ b/third_party/highway/hwy/contrib/thread_pool/spin.h
@@ -82,25 +82,32 @@ struct SpinResult {
// `HWY_TARGET` and its runtime dispatch mechanism. Returned by `Type()`, also
// used by callers to set the `disabled` argument for `DetectSpin`.
enum class SpinType : uint8_t {
+#if HWY_ENABLE_MONITORX
kMonitorX = 1, // AMD
- kUMonitor, // Intel
- kPause,
+#endif
+#if HWY_ENABLE_UMONITOR
+ kUMonitor = 2, // Intel
+#endif
+ kPause = 3,
kSentinel // for iterating over all enumerators. Must be last.
};
// For printing which is in use.
static inline const char* ToString(SpinType type) {
switch (type) {
+#if HWY_ENABLE_MONITORX
case SpinType::kMonitorX:
return "MonitorX_C1";
+#endif
+#if HWY_ENABLE_UMONITOR
case SpinType::kUMonitor:
return "UMonitor_C0.2";
+#endif
case SpinType::kPause:
return "Pause";
case SpinType::kSentinel:
- return nullptr;
default:
- HWY_UNREACHABLE;
+ return nullptr;
}
}
@@ -276,9 +283,10 @@ HWY_POP_ATTRIBUTES
// Ignores `disabled` for `kPause` if it is the only supported and enabled type.
// Somewhat expensive, typically called during initialization.
static inline SpinType DetectSpin(int disabled = 0) {
- const auto HWY_MAYBE_UNUSED enabled = [disabled](SpinType type) {
+ const auto enabled = [disabled](SpinType type) {
return (disabled & (1 << static_cast<int>(type))) == 0;
};
+ (void)enabled;
#if HWY_ENABLE_MONITORX
if (enabled(SpinType::kMonitorX) && x86::IsAMD()) {
@@ -302,23 +310,23 @@ static inline SpinType DetectSpin(int disabled = 0) {
return SpinType::kPause;
}
-// Calls `func(spin)` for the given `spin_type`.
-template <class Func>
-HWY_INLINE void CallWithSpin(SpinType spin_type, Func&& func) {
+// Calls `func(spin, args)` for the given `spin_type`.
+template <class Func, typename... Args>
+HWY_INLINE void CallWithSpin(SpinType spin_type, Func&& func, Args&&... args) {
switch (spin_type) {
#if HWY_ENABLE_MONITORX
case SpinType::kMonitorX:
- func(SpinMonitorX());
+ func(SpinMonitorX(), std::forward<Args>(args)...);
break;
#endif
#if HWY_ENABLE_UMONITOR
case SpinType::kUMonitor:
- func(SpinUMonitor());
+ func(SpinUMonitor(), std::forward<Args>(args)...);
break;
#endif
case SpinType::kPause:
default:
- func(SpinPause());
+ func(SpinPause(), std::forward<Args>(args)...);
break;
}
}
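
CallWithSpin now forwards trailing arguments to the functor, so per-call state no longer has to be captured in a wrapper lambda for each spin type. A sketch with a hypothetical functor:

    #include <atomic>
    #include <cstdint>
    #include "third_party/highway/hwy/contrib/thread_pool/spin.h"

    struct WaitForFlag {
      // Invoked with the spin policy chosen at runtime, plus forwarded args.
      template <class Spin>
      void operator()(const Spin& spin, std::atomic<uint32_t>& flag) const {
        spin.UntilEqual(1, flag);  // Spins until flag == 1.
      }
    };

    void Example(std::atomic<uint32_t>& flag) {
      hwy::CallWithSpin(hwy::DetectSpin(), WaitForFlag(), flag);
    }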
diff --git a/third_party/highway/hwy/contrib/thread_pool/spin_test.cc b/third_party/highway/hwy/contrib/thread_pool/spin_test.cc
new file mode 100644
index 0000000000..eace4dad61
--- /dev/null
+++ b/third_party/highway/hwy/contrib/thread_pool/spin_test.cc
@@ -0,0 +1,116 @@
+// Copyright 2025 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "third_party/highway/hwy/contrib/thread_pool/spin.h"
+
+#include <stddef.h>
+#include <stdint.h>
+#include <stdio.h>
+
+#include <atomic>
+
+#include "third_party/highway/hwy/aligned_allocator.h" // HWY_ALIGNMENT
+#include "third_party/highway/hwy/contrib/thread_pool/futex.h" // NanoSleep
+#include "third_party/highway/hwy/contrib/thread_pool/thread_pool.h"
+#include "third_party/highway/hwy/contrib/thread_pool/topology.h"
+#include "third_party/highway/hwy/tests/hwy_gtest.h"
+#include "third_party/highway/hwy/tests/test_util-inl.h"
+#include "third_party/highway/hwy/timer.h"
+
+namespace hwy {
+namespace {
+
+struct TestPingPongT {
+ template <class Spin>
+ void operator()(const Spin& spin) const {
+ constexpr size_t kU32PerLine = HWY_ALIGNMENT / 4;
+ constexpr size_t kF64PerLine = HWY_ALIGNMENT / 8;
+ alignas(HWY_ALIGNMENT) std::atomic<uint32_t> thread_active[kU32PerLine];
+ alignas(HWY_ALIGNMENT) std::atomic<uint32_t> thread_done[kU32PerLine];
+
+ thread_active[0].store(0, std::memory_order_release);
+ thread_done[0].store(0, std::memory_order_release);
+ hwy::ThreadPool pool(1);
+ HWY_ASSERT(pool.NumWorkers() == 2);
+
+ const double t0 = hwy::platform::Now();
+ std::atomic_flag error = ATOMIC_FLAG_INIT;
+
+ alignas(HWY_ALIGNMENT) std::atomic<size_t> reps1;
+ alignas(HWY_ALIGNMENT) std::atomic<size_t> reps2;
+
+ alignas(HWY_ALIGNMENT) std::atomic<double> before_thread_done[kF64PerLine];
+ alignas(HWY_ALIGNMENT) std::atomic<double> before_thread_go[kF64PerLine];
+ alignas(HWY_ALIGNMENT) std::atomic<double> ack_thread_done[kF64PerLine];
+ alignas(HWY_ALIGNMENT) std::atomic<double> ack_thread_release[kF64PerLine];
+
+ const auto kAcq = std::memory_order_acquire;
+ const auto kRel = std::memory_order_release;
+ pool.Run(0, 2, [&](uint64_t task, size_t thread) {
+ HWY_ASSERT(task == thread);
+ if (task == 0) { // new thread
+ SpinResult result = spin.UntilDifferent(0, thread_active[0]);
+ ack_thread_release[0].store(hwy::platform::Now(), kRel);
+ reps1.store(result.reps);
+ if (!NanoSleep(20 * 1000 * 1000)) {
+ error.test_and_set();
+ }
+ before_thread_done[0].store(hwy::platform::Now(), kRel);
+ thread_done[0].store(1, kRel);
+ } else { // main thread
+ if (!NanoSleep(30 * 1000 * 1000)) {
+ error.test_and_set();
+ }
+ // Release the thread.
+ before_thread_go[0].store(hwy::platform::Now(), kRel);
+ thread_active[0].store(1, kRel);
+ // Wait for it to finish.
+ const size_t reps = spin.UntilEqual(1, thread_done[0]);
+ ack_thread_done[0].store(hwy::platform::Now(), kRel);
+ reps2.store(reps);
+ }
+ });
+
+ const double t1 = hwy::platform::Now();
+ const double elapsed = t1 - t0;
+ const double latency1 =
+ ack_thread_release[0].load(kAcq) - before_thread_go[0].load(kAcq);
+ const double latency2 =
+ ack_thread_done[0].load(kAcq) - before_thread_done[0].load(kAcq);
+ fprintf(stderr,
+ "Elapsed time: %f us; reps1=%zu, reps2=%zu, latency=%f %f us\n",
+ elapsed * 1E6, reps1.load(), reps2.load(), latency1 * 1E6,
+ latency2 * 1E6);
+ // Unless NanoSleep failed to sleep, this should take 50ms+epsilon; assert
+ // a lenient 25ms lower bound to tolerate scheduling jitter.
+ HWY_ASSERT(error.test_and_set() || elapsed > 25E-3);
+ }
+}; // TestPingPongT
+
+// Simple mutex.
+TEST(SpinTest, TestPingPong) {
+ if (!HaveThreadingSupport()) {
+ HWY_WARN("Threads not supported, skipping test\n");
+ return;
+ }
+
+ const SpinType spin_type = DetectSpin();
+ fprintf(stderr, "Spin method : %s\n", ToString(spin_type));
+ CallWithSpin(spin_type, TestPingPongT());
+}
+
+} // namespace
+} // namespace hwy
+
+HWY_TEST_MAIN();
diff --git a/third_party/highway/hwy/contrib/thread_pool/thread_pool.cc b/third_party/highway/hwy/contrib/thread_pool/thread_pool.cc
new file mode 100644
index 0000000000..d22e1cfe8b
--- /dev/null
+++ b/third_party/highway/hwy/contrib/thread_pool/thread_pool.cc
@@ -0,0 +1,31 @@
+// Copyright 2025 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "third_party/highway/hwy/contrib/thread_pool/thread_pool.h"
+
+#include "third_party/highway/hwy/highway_export.h"
+
+namespace hwy {
+namespace pool {
+
+// TODO: move implementation here.
+
+HWY_CONTRIB_DLLEXPORT Shared& Shared::Get() {
+ static Shared* shared = new Shared();
+ return *shared;
+}
+
+} // namespace pool
+} // namespace hwy
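
The new-without-delete in Shared::Get is the usual intentionally leaky singleton: the instance must remain valid for worker threads that may still touch it during static destruction, so it is constructed once and never destroyed.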
diff --git a/third_party/highway/hwy/contrib/thread_pool/thread_pool.h b/third_party/highway/hwy/contrib/thread_pool/thread_pool.h
index d3516174ae..81599ba157 100644
--- a/third_party/highway/hwy/contrib/thread_pool/thread_pool.h
+++ b/third_party/highway/hwy/contrib/thread_pool/thread_pool.h
@@ -23,6 +23,7 @@
#include <stddef.h>
#include <stdint.h>
#include <stdio.h> // snprintf
+#include <string.h>
#include <array>
#include <atomic>
@@ -30,11 +31,6 @@
#include <thread> // NOLINT
#include <vector>
-#include "third_party/highway/hwy/detect_compiler_arch.h"
-#if HWY_OS_FREEBSD
-#include <pthread_np.h>
-#endif
-
#include "third_party/highway/hwy/aligned_allocator.h" // HWY_ALIGNMENT
#include "third_party/highway/hwy/auto_tune.h"
#include "third_party/highway/hwy/base.h"
@@ -42,11 +38,19 @@
#include "third_party/highway/hwy/contrib/thread_pool/futex.h"
#include "third_party/highway/hwy/contrib/thread_pool/spin.h"
#include "third_party/highway/hwy/contrib/thread_pool/topology.h"
+#include "third_party/highway/hwy/profiler.h"
#include "third_party/highway/hwy/stats.h"
#include "third_party/highway/hwy/timer.h"
-// Define to HWY_NOINLINE to see profiles of `WorkerRun*` and waits.
-#define HWY_POOL_PROFILE
+#if HWY_OS_APPLE
+#include <AvailabilityMacros.h>
+#endif
+
+#if PROFILER_ENABLED
+#include <algorithm> // std::sort
+
+#include "third_party/highway/hwy/bit_set.h"
+#endif
namespace hwy {
@@ -59,19 +63,67 @@ static inline void SetThreadName(const char* format, int thread) {
HWY_ASSERT(0 < chars_written &&
chars_written <= static_cast<int>(sizeof(buf) - 1));
-#if HWY_OS_LINUX && (!defined(__ANDROID__) || __ANDROID_API__ >= 19)
+#if (HWY_OS_LINUX && (!defined(__ANDROID__) || __ANDROID_API__ >= 19)) || \
+ HWY_OS_FREEBSD
+ // FreeBSD also provides pthread_setname_np; the older pthread_set_name_np
+ // returns void, so its result could not be asserted (#2669).
HWY_ASSERT(0 == pthread_setname_np(pthread_self(), buf));
-#elif HWY_OS_FREEBSD
- HWY_ASSERT(0 == pthread_set_name_np(pthread_self(), buf));
-#elif HWY_OS_APPLE
+#elif HWY_OS_APPLE && (MAC_OS_X_VERSION_MIN_REQUIRED >= 1060)
// Different interface: single argument, current thread only.
HWY_ASSERT(0 == pthread_setname_np(buf));
+#elif defined(__EMSCRIPTEN__)
+ emscripten_set_thread_name(pthread_self(), buf);
+#else
+ (void)format;
+ (void)thread;
#endif
}
// Whether workers should block or spin.
enum class PoolWaitMode : uint8_t { kBlock = 1, kSpin };
+enum class Exit : uint32_t { kNone, kLoop, kThread };
+
+// Upper bound on non-empty `ThreadPool` (single-worker pools do not count).
+// Turin has 16 clusters. Add one for the across-cluster pool.
+HWY_INLINE_VAR constexpr size_t kMaxClusters = 32 + 1;
+
+// Use the last slot so that `PoolWorkerMapping` does not have to know the
+// total number of clusters.
+HWY_INLINE_VAR constexpr size_t kAllClusters = kMaxClusters - 1;
+
+// Argument to `ThreadPool`: how to map local worker_idx to global.
+class PoolWorkerMapping {
+ public:
+ // Backward-compatible mode: returns local worker index.
+ PoolWorkerMapping() : cluster_idx_(0), max_cluster_workers_(0) {}
+ PoolWorkerMapping(size_t cluster_idx, size_t max_cluster_workers)
+ : cluster_idx_(cluster_idx), max_cluster_workers_(max_cluster_workers) {
+ HWY_DASSERT(cluster_idx <= kAllClusters);
+ // Only use this ctor for the new global worker index mode. If this were
+ // zero, we would still return local indices.
+ HWY_DASSERT(max_cluster_workers != 0);
+ }
+
+ size_t ClusterIdx() const { return cluster_idx_; }
+ size_t MaxClusterWorkers() const { return max_cluster_workers_; }
+
+ // Returns global_idx, or unchanged local worker_idx if default-constructed.
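+ // Example: with max_cluster_workers = 8, cluster 2's local worker 3 maps to
+ // global index 2 * 8 + 3 = 19. In kAllClusters mode, worker c of the
+ // across-cluster pool maps to the first slot of cluster c, i.e. c * 8.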
+ size_t operator()(size_t worker_idx) const {
+ if (cluster_idx_ == kAllClusters) {
+ const size_t cluster_idx = worker_idx;
+ HWY_DASSERT(cluster_idx < kAllClusters);
+ // First index within the N-th cluster. The main thread is the first.
+ return cluster_idx * max_cluster_workers_;
+ }
+ HWY_DASSERT(max_cluster_workers_ == 0 || worker_idx < max_cluster_workers_);
+ return cluster_idx_ * max_cluster_workers_ + worker_idx;
+ }
+
+ private:
+ size_t cluster_idx_;
+ size_t max_cluster_workers_;
+};
+
namespace pool {
#ifndef HWY_POOL_VERBOSITY
@@ -84,7 +136,7 @@ static constexpr int kVerbosity = HWY_POOL_VERBOSITY;
// large pool, we assume applications create multiple pools, ideally per
// cluster (cores sharing a cache), because this improves locality and barrier
// latency. In that case, this is a generous upper bound.
-static constexpr size_t kMaxThreads = 63;
+static constexpr size_t kMaxThreads = 127;
// Generates a random permutation of [0, size). O(1) storage.
class ShuffledIota {
@@ -150,7 +202,7 @@ class ShuffledIota {
};
// 'Policies' suitable for various worker counts and locality. To define a
-// new class, add an enum and update `ToString` plus `FunctorAddWait`. The
+// new class, add an enum and update `ToString` plus `CallWithConfig`. The
// enumerators must be contiguous so we can iterate over them.
enum class WaitType : uint8_t {
kBlock,
@@ -158,15 +210,6 @@ enum class WaitType : uint8_t {
kSpinSeparate,
kSentinel // Must be last.
};
-enum class BarrierType : uint8_t {
- kOrdered,
- kCounter1,
- kCounter2,
- kCounter4,
- kGroup2,
- kGroup4,
- kSentinel // Must be last.
-};
// For printing which is in use.
static inline const char* ToString(WaitType type) {
@@ -179,107 +222,558 @@ static inline const char* ToString(WaitType type) {
return "Separate";
case WaitType::kSentinel:
return nullptr;
- default:
- HWY_UNREACHABLE;
}
}
-static inline const char* ToString(BarrierType type) {
- switch (type) {
- case BarrierType::kOrdered:
- return "Ordered";
- case BarrierType::kCounter1:
- return "Counter1";
- case BarrierType::kCounter2:
- return "Counter2";
- case BarrierType::kCounter4:
- return "Counter4";
- case BarrierType::kGroup2:
- return "Group2";
- case BarrierType::kGroup4:
- return "Group4";
- case BarrierType::kSentinel:
- return nullptr;
- default:
- HWY_UNREACHABLE;
- }
-}
-
-// We want predictable struct/class sizes so we can reason about cache lines.
-#pragma pack(push, 1)
-
// Parameters governing the main and worker thread behavior. Can be updated at
-// runtime via `SetWaitMode`. Both have copies which are carefully synchronized
-// (two-phase barrier). 64-bit allows adding fields (e.g. for load-balancing)
-// without having to bit-pack members, and is fine because this is only moved
-// with relaxed stores, hence we do not have to fit it in the 32 futex bits.
-class Config { // 8 bytes
- public:
- static std::vector<Config> AllCandidates(PoolWaitMode wait_mode,
- size_t num_threads) {
- std::vector<SpinType> spin_types(size_t{1}, DetectSpin());
- // Monitor-based spin may be slower, so also try Pause.
- if (spin_types[0] != SpinType::kPause) {
- spin_types.push_back(SpinType::kPause);
- }
+// runtime via `SetWaitMode`, which calls `SendConfig`. Both have copies which
+// are carefully synchronized. 32 bits leave room for two future fields.
+// 64 bits would also be fine because this does not go through futex.
+struct Config { // 4 bytes
+ static std::vector<Config> AllCandidates(PoolWaitMode wait_mode) {
+ std::vector<Config> candidates;
- std::vector<WaitType> wait_types;
if (wait_mode == PoolWaitMode::kSpin) {
+ std::vector<SpinType> spin_types;
+ spin_types.reserve(2);
+ spin_types.push_back(DetectSpin());
+ // Monitor-based spin may be slower, so also try Pause.
+ if (spin_types[0] != SpinType::kPause) {
+ spin_types.push_back(SpinType::kPause);
+ }
+
// All except `kBlock`.
+ std::vector<WaitType> wait_types;
for (size_t wait = 0;; ++wait) {
const WaitType wait_type = static_cast<WaitType>(wait);
if (wait_type == WaitType::kSentinel) break;
if (wait_type != WaitType::kBlock) wait_types.push_back(wait_type);
}
- } else {
- wait_types.push_back(WaitType::kBlock);
- }
- std::vector<BarrierType> barrier_types;
- // Note that casting an integer is UB if there is no matching enumerator,
- // but we define a sentinel to prevent this.
- for (size_t barrier = 0;; ++barrier) {
- const BarrierType barrier_type = static_cast<BarrierType>(barrier);
- if (barrier_type == BarrierType::kSentinel) break;
- // If <= 2 workers, group size of 4 is the same as 2.
- if (num_threads <= 1 && barrier_type == BarrierType::kCounter4) continue;
- if (num_threads <= 1 && barrier_type == BarrierType::kGroup4) continue;
- barrier_types.push_back(barrier_type);
- }
-
- std::vector<Config> candidates;
- candidates.reserve(50);
- for (const SpinType spin_type : spin_types) {
- for (const WaitType wait_type : wait_types) {
- for (const BarrierType barrier_type : barrier_types) {
- candidates.emplace_back(spin_type, wait_type, barrier_type);
+ candidates.reserve(spin_types.size() * wait_types.size());
+ for (const SpinType spin_type : spin_types) {
+ for (const WaitType wait_type : wait_types) {
+ candidates.emplace_back(spin_type, wait_type);
}
}
+ } else {
+ // kBlock does not use spin, so there is only one candidate.
+ candidates.emplace_back(SpinType::kPause, WaitType::kBlock);
}
+
return candidates;
}
std::string ToString() const {
char buf[128];
- snprintf(buf, sizeof(buf), "%14s %9s %9s", hwy::ToString(spin_type),
- pool::ToString(wait_type), pool::ToString(barrier_type));
+ snprintf(buf, sizeof(buf), "%-14s %-9s", hwy::ToString(spin_type),
+ pool::ToString(wait_type));
return buf;
}
- Config() {}
- Config(SpinType spin_type, WaitType wait_type, BarrierType barrier_type)
- : spin_type(spin_type),
- wait_type(wait_type),
- barrier_type(barrier_type),
- exit(false) {}
+ Config(SpinType spin_type_in, WaitType wait_type_in)
+ : spin_type(spin_type_in), wait_type(wait_type_in) {}
+ // Workers initially spin until ThreadPool sends them their actual config.
+ Config() : Config(SpinType::kPause, WaitType::kSpinSeparate) {}
SpinType spin_type;
WaitType wait_type;
- BarrierType barrier_type;
- bool exit;
- uint32_t reserved = 0;
+ HWY_MEMBER_VAR_MAYBE_UNUSED uint8_t reserved[2];
+};
+static_assert(sizeof(Config) == 4, "");
+
+#if PROFILER_ENABLED
+
+// Accumulates timings and stats from main thread and workers.
+class Stats {
+ // Up to `HWY_ALIGNMENT / 8` slots/offsets, passed to `PerThread`.
+ static constexpr size_t kDWait = 0;
+ static constexpr size_t kWaitReps = 1;
+ static constexpr size_t kTBeforeRun = 2;
+ static constexpr size_t kDRun = 3;
+ static constexpr size_t kTasksStatic = 4;
+ static constexpr size_t kTasksDynamic = 5;
+ static constexpr size_t kTasksStolen = 6;
+ static constexpr size_t kDFuncStatic = 7;
+ static constexpr size_t kDFuncDynamic = 8;
+ static constexpr size_t kSentinel = 9;
+
+ public:
+ Stats() {
+ for (size_t thread_idx = 0; thread_idx < kMaxThreads; ++thread_idx) {
+ for (size_t offset = 0; offset < kSentinel; ++offset) {
+ PerThread(thread_idx, offset) = 0;
+ }
+ }
+ Reset();
+ }
+
+ // Called by the N lowest-indexed workers that got one of the N tasks, which
+ // includes the main thread because its index is 0.
+ // `d_*` denotes "difference" (of timestamps) and thus also duration.
+ void NotifyRunStatic(size_t worker_idx, timer::Ticks d_func) {
+ if (worker_idx == 0) { // main thread
+ num_run_static_++;
+ sum_tasks_static_++;
+ sum_d_func_static_ += d_func;
+ } else {
+ const size_t thread_idx = worker_idx - 1;
+ // Defer the sums until `NotifyMainRun` to avoid atomic RMW.
+ PerThread(thread_idx, kTasksStatic)++;
+ PerThread(thread_idx, kDFuncStatic) += d_func;
+ }
+ }
+
+ // Called by all workers, including the main thread, regardless of whether
+ // they actually stole or even ran a task.
+ void NotifyRunDynamic(size_t worker_idx, size_t tasks, size_t stolen,
+ timer::Ticks d_func) {
+ if (worker_idx == 0) { // main thread
+ num_run_dynamic_++;
+ sum_tasks_dynamic_ += tasks;
+ sum_tasks_stolen_ += stolen;
+ sum_d_func_dynamic_ += d_func;
+ } else {
+ const size_t thread_idx = worker_idx - 1;
+ // Defer the sums until `NotifyMainRun` to avoid atomic RMW.
+ PerThread(thread_idx, kTasksDynamic) += tasks;
+ PerThread(thread_idx, kTasksStolen) += stolen;
+ PerThread(thread_idx, kDFuncDynamic) += d_func;
+ }
+ }
+
+ // Called concurrently by non-main worker threads after their `WorkerRun` and
+ // before the barrier.
+ void NotifyThreadRun(size_t worker_idx, timer::Ticks d_wait, size_t wait_reps,
+ timer::Ticks t_before_run, timer::Ticks d_run) {
+ HWY_DASSERT(worker_idx != 0); // Not called by main thread.
+ const size_t thread_idx = worker_idx - 1;
+ HWY_DASSERT(PerThread(thread_idx, kDWait) == 0);
+ HWY_DASSERT(PerThread(thread_idx, kWaitReps) == 0);
+ HWY_DASSERT(PerThread(thread_idx, kTBeforeRun) == 0);
+ HWY_DASSERT(PerThread(thread_idx, kDRun) == 0);
+ PerThread(thread_idx, kDWait) = d_wait;
+ PerThread(thread_idx, kWaitReps) = wait_reps;
+ PerThread(thread_idx, kTBeforeRun) = t_before_run; // For wake latency.
+ PerThread(thread_idx, kDRun) = d_run;
+ }
+
+ // Called by the main thread after the barrier, whose store-release and
+ // load-acquire publishes all prior writes. Note: only the main thread can
+ // store `after_barrier`. If workers did, which by definition happens after
+ // the barrier, then they would race with this function's reads.
+ void NotifyMainRun(size_t num_threads, timer::Ticks t_before_wake,
+ timer::Ticks d_wake, timer::Ticks d_main_run,
+ timer::Ticks d_barrier) {
+ HWY_DASSERT(num_threads <= kMaxThreads);
+
+ timer::Ticks min_d_run = ~timer::Ticks{0};
+ timer::Ticks max_d_run = 0;
+ timer::Ticks sum_d_run = 0;
+ for (size_t thread_idx = 0; thread_idx < num_threads; ++thread_idx) {
+ sum_tasks_static_ += PerThread(thread_idx, kTasksStatic);
+ sum_tasks_dynamic_ += PerThread(thread_idx, kTasksDynamic);
+ sum_tasks_stolen_ += PerThread(thread_idx, kTasksStolen);
+ sum_d_func_static_ += PerThread(thread_idx, kDFuncStatic);
+ sum_d_func_dynamic_ += PerThread(thread_idx, kDFuncDynamic);
+ sum_d_wait_ += PerThread(thread_idx, kDWait);
+ sum_wait_reps_ += PerThread(thread_idx, kWaitReps);
+ const timer::Ticks d_thread_run = PerThread(thread_idx, kDRun);
+ min_d_run = HWY_MIN(min_d_run, d_thread_run);
+ max_d_run = HWY_MAX(max_d_run, d_thread_run);
+ sum_d_run += d_thread_run;
+ const timer::Ticks t_before_run = PerThread(thread_idx, kTBeforeRun);
+
+ for (size_t offset = 0; offset < kSentinel; ++offset) {
+ PerThread(thread_idx, offset) = 0;
+ }
+
+ HWY_DASSERT(t_before_run != 0);
+ const timer::Ticks d_latency = t_before_run - t_before_wake;
+ sum_wake_latency_ += d_latency;
+ max_wake_latency_ = HWY_MAX(max_wake_latency_, d_latency);
+ }
+ const double inv_avg_d_run =
+ static_cast<double>(num_threads) / static_cast<double>(sum_d_run);
+ // Ratios of min and max run times to the average, for this pool.Run.
+ const double r_min = static_cast<double>(min_d_run) * inv_avg_d_run;
+ const double r_max = static_cast<double>(max_d_run) * inv_avg_d_run;
+
+ num_run_++; // `num_run_static_`/`_dynamic_` were already updated by `NotifyRun*`.
+ sum_d_run_ += sum_d_run;
+ sum_r_min_ += r_min; // For average across all pool.Run.
+ sum_r_max_ += r_max;
+
+ sum_d_wake_ += d_wake; // `*wake_latency_` are updated above.
+ sum_d_barrier_ += d_barrier;
+
+ sum_d_run_ += d_main_run;
+ sum_d_run_main_ += d_main_run;
+ }
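+
+ // Illustrative sketch (commentary, not upstream code): the publication
+ // relies on the release/acquire pair in `Worker::StoreBarrier` and
+ // `Barrier::HasReached`, both defined later in this file:
+ //   worker: StoreBarrier(epoch);        // release: publishes PerThread()
+ //   main:   HasReached(&worker, epoch)  // acquire: observes those writes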
+
+ void PrintAndReset(size_t num_threads, timer::Ticks d_thread_lifetime_ticks) {
+ // This is unconditionally called via `ProfilerFunc`. If the pool was unused
+ // in this invocation, skip it.
+ if (num_run_ == 0) return;
+ HWY_ASSERT(num_run_ == num_run_static_ + num_run_dynamic_);
+
+ const double d_func_static = Seconds(sum_d_func_static_);
+ const double d_func_dynamic = Seconds(sum_d_func_dynamic_);
+ const double sum_d_run = Seconds(sum_d_run_);
+ const double func_div_run = (d_func_static + d_func_dynamic) / sum_d_run;
+ if (!(0.95 <= func_div_run && func_div_run <= 1.0)) {
+ HWY_WARN("Func time %f should be similar to total run %f.",
+ d_func_static + d_func_dynamic, sum_d_run);
+ }
+ const double sum_d_run_main = Seconds(sum_d_run_main_);
+ const double max_wake_latency = Seconds(max_wake_latency_);
+ const double sum_d_wait = Seconds(sum_d_wait_);
+ const double d_thread_lifetime = Seconds(d_thread_lifetime_ticks);
+
+ const double inv_run = 1.0 / static_cast<double>(num_run_);
+ const auto per_run = [inv_run](double sum) { return sum * inv_run; };
+ const double avg_d_wake = per_run(Seconds(sum_d_wake_));
+ const double avg_wake_latency = per_run(Seconds(sum_wake_latency_));
+ const double avg_d_wait = per_run(sum_d_wait);
+ const double avg_wait_reps = per_run(static_cast<double>(sum_wait_reps_));
+ const double avg_d_barrier = per_run(Seconds(sum_d_barrier_));
+ const double avg_r_min = per_run(sum_r_min_);
+ const double avg_r_max = per_run(sum_r_max_);
+
+ const size_t num_workers = 1 + num_threads;
+ const double avg_tasks_static =
+ Avg(sum_tasks_static_, num_run_static_ * num_workers);
+ const double avg_tasks_dynamic =
+ Avg(sum_tasks_dynamic_, num_run_dynamic_ * num_workers);
+ const double avg_steals =
+ Avg(sum_tasks_stolen_, num_run_dynamic_ * num_workers);
+ const double avg_d_run = sum_d_run / num_workers;
+
+ const double pc_wait = sum_d_wait / d_thread_lifetime * 100.0;
+ const double pc_run = sum_d_run / d_thread_lifetime * 100.0;
+ const double pc_main = sum_d_run_main / avg_d_run * 100.0;
+
+ const auto us = [](double sec) { return sec * 1E6; };
+ const auto ns = [](double sec) { return sec * 1E9; };
+ printf(
+ "%3zu: %5d x %.2f/%5d x %4.1f tasks, %.2f steals; "
+ "wake %7.3f ns, latency %6.3f < %7.3f us, barrier %7.3f us; "
+ "wait %.1f us (%6.0f reps, %4.1f%%), balance %4.1f%%-%5.1f%%, "
+ "func: %6.3f + %7.3f, "
+ "%.1f%% of thread time %7.3f s; main:worker %5.1f%%\n",
+ num_threads, num_run_static_, avg_tasks_static, num_run_dynamic_,
+ avg_tasks_dynamic, avg_steals, ns(avg_d_wake), us(avg_wake_latency),
+ us(max_wake_latency), us(avg_d_barrier), us(avg_d_wait), avg_wait_reps,
+ pc_wait, avg_r_min * 100.0, avg_r_max * 100.0, d_func_static,
+ d_func_dynamic, pc_run, d_thread_lifetime, pc_main);
+
+ Reset(num_threads);
+ }
+
+ void Reset(size_t num_threads = kMaxThreads) {
+ num_run_ = 0;
+ num_run_static_ = 0;
+ num_run_dynamic_ = 0;
+
+ sum_tasks_stolen_ = 0;
+ sum_tasks_static_ = 0;
+ sum_tasks_dynamic_ = 0;
+
+ sum_d_wake_ = 0;
+ sum_wake_latency_ = 0;
+ max_wake_latency_ = 0;
+ sum_d_wait_ = 0;
+ sum_wait_reps_ = 0;
+ sum_d_barrier_ = 0;
+
+ sum_d_func_static_ = 0;
+ sum_d_func_dynamic_ = 0;
+ sum_r_min_ = 0.0;
+ sum_r_max_ = 0.0;
+ sum_d_run_ = 0;
+ sum_d_run_main_ = 0;
+ // ctor and `NotifyMainRun` already reset `PerThread`.
+ }
+
+ private:
+ template <typename T>
+ static double Avg(T sum, size_t div) {
+ return div == 0 ? 0.0 : static_cast<double>(sum) / static_cast<double>(div);
+ }
+
+ static constexpr size_t kU64PerLine = HWY_ALIGNMENT / sizeof(uint64_t);
+
+ uint64_t& PerThread(size_t thread_idx, size_t offset) {
+ HWY_DASSERT(thread_idx < kMaxThreads);
+ HWY_DASSERT(offset < kSentinel);
+ return per_thread_[thread_idx * kU64PerLine + offset];
+ }
+
+ int32_t num_run_;
+ int32_t num_run_static_;
+ int32_t num_run_dynamic_;
+
+ int32_t sum_tasks_stolen_;
+ int64_t sum_tasks_static_;
+ int64_t sum_tasks_dynamic_;
+
+ timer::Ticks sum_d_wake_;
+ timer::Ticks sum_wake_latency_;
+ timer::Ticks max_wake_latency_;
+ timer::Ticks sum_d_wait_;
+ uint64_t sum_wait_reps_;
+ timer::Ticks sum_d_barrier_;
+
+ timer::Ticks sum_d_func_static_;
+ timer::Ticks sum_d_func_dynamic_;
+ double sum_r_min_;
+ double sum_r_max_;
+ timer::Ticks sum_d_run_;
+ timer::Ticks sum_d_run_main_;
+
+ // One cache line per pool thread to avoid false sharing.
+ uint64_t per_thread_[kMaxThreads * kU64PerLine];
+};
+// Enables shift rather than multiplication.
+static_assert(sizeof(Stats) == (kMaxThreads + 1) * HWY_ALIGNMENT, "Wrong size");
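+
+// Worked example (assuming HWY_ALIGNMENT == 64, which is typical but
+// platform-dependent): kU64PerLine = 64 / 8 = 8, so the `PerThread` index
+// `thread_idx * kU64PerLine + offset` lowers to `(thread_idx << 3) + offset`,
+// a shift rather than a multiply.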
+
+// Non-power of two to avoid 2K aliasing.
+HWY_INLINE_VAR constexpr size_t kMaxCallers = 60;
+
+// Per-caller stats, stored in `PerCluster`.
+class CallerAccumulator {
+ public:
+ bool Any() const { return calls_ != 0; }
+
+ void Add(size_t tasks, size_t workers, bool is_root, timer::Ticks wait_before,
+ timer::Ticks elapsed) {
+ calls_++;
+ root_ += is_root;
+ workers_ += workers;
+ min_tasks_ = HWY_MIN(min_tasks_, tasks);
+ max_tasks_ = HWY_MAX(max_tasks_, tasks);
+ tasks_ += tasks;
+ wait_before_ += wait_before;
+ elapsed_ += elapsed;
+ }
+
+ void AddFrom(const CallerAccumulator& other) {
+ calls_ += other.calls_;
+ root_ += other.root_;
+ workers_ += other.workers_;
+ min_tasks_ = HWY_MIN(min_tasks_, other.min_tasks_);
+ max_tasks_ = HWY_MAX(max_tasks_, other.max_tasks_);
+ tasks_ += other.tasks_;
+ wait_before_ += other.wait_before_;
+ elapsed_ += other.elapsed_;
+ }
+
+ bool operator>(const CallerAccumulator& other) const {
+ return elapsed_ > other.elapsed_;
+ }
+
+ void PrintAndReset(const char* caller, size_t active_clusters) {
+ if (!Any()) return;
+ HWY_ASSERT(root_ <= calls_);
+ const double inv_calls = 1.0 / static_cast<double>(calls_);
+ const double pc_root = static_cast<double>(root_) * inv_calls * 100.0;
+ const double avg_workers = static_cast<double>(workers_) * inv_calls;
+ const double avg_tasks = static_cast<double>(tasks_) * inv_calls;
+ const double avg_tasks_per_worker = avg_tasks / avg_workers;
+ const double inv_freq = 1.0 / platform::InvariantTicksPerSecond();
+ const double sum_wait_before = static_cast<double>(wait_before_) * inv_freq;
+ const double avg_wait_before =
+ root_ ? sum_wait_before / static_cast<double>(root_) : 0.0;
+ const double elapsed = static_cast<double>(elapsed_) * inv_freq;
+ const double avg_elapsed = elapsed * inv_calls;
+ const double task_len = avg_elapsed / avg_tasks_per_worker;
+ printf(
+ "%40s: %7.0f x (%3.0f%%) %2zu clusters, %4.1f workers @ "
+ "%5.1f tasks (%5u-%5u), "
+ "%5.0f us wait, %6.1E us run (task len %6.1E us), total %6.2f s\n",
+ caller, static_cast<double>(calls_), pc_root, active_clusters,
+ avg_workers, avg_tasks_per_worker, static_cast<uint32_t>(min_tasks_),
+ static_cast<uint32_t>(max_tasks_), avg_wait_before * 1E6,
+ avg_elapsed * 1E6, task_len * 1E6, elapsed);
+ *this = CallerAccumulator();
+ }
+
+ // For the grand total, only print calls and elapsed because averaging
+ // the other stats is not very useful. No need to reset because this is called
+ // on a temporary.
+ void PrintTotal() {
+ if (!Any()) return;
+ HWY_ASSERT(root_ <= calls_);
+ const double elapsed =
+ static_cast<double>(elapsed_) / platform::InvariantTicksPerSecond();
+ printf("TOTAL: %7.0f x run %6.2f s\n", static_cast<double>(calls_),
+ elapsed);
+ }
+
+ private:
+ int64_t calls_ = 0;
+ int64_t root_ = 0;
+ uint64_t workers_ = 0;
+ uint64_t min_tasks_ = ~uint64_t{0};
+ uint64_t max_tasks_ = 0;
+ uint64_t tasks_ = 0;
+ // Both are wall time for the root Run, otherwise CPU time.
+ timer::Ticks wait_before_ = 0;
+ timer::Ticks elapsed_ = 0;
+};
+static_assert(sizeof(CallerAccumulator) == 64, "");
+
+class PerCluster {
+ public:
+ CallerAccumulator& Get(size_t caller_idx) {
+ HWY_DASSERT(caller_idx < kMaxCallers);
+ callers_.Set(caller_idx);
+ return accumulators_[caller_idx];
+ }
+
+ template <class Func>
+ void ForeachCaller(Func&& func) {
+ callers_.Foreach([&](size_t caller_idx) {
+ func(caller_idx, accumulators_[caller_idx]);
+ });
+ }
+
+ // Returns indices (required for `StringTable::Name`) in descending order of
+ // elapsed time.
+ std::vector<size_t> Sorted() {
+ std::vector<size_t> vec;
+ vec.reserve(kMaxCallers);
+ ForeachCaller([&](size_t caller_idx, CallerAccumulator&) {
+ vec.push_back(caller_idx);
+ });
+ std::sort(vec.begin(), vec.end(), [&](size_t a, size_t b) {
+ return accumulators_[a] > accumulators_[b];
+ });
+ return vec;
+ }
+
+ // The calling code takes care of resetting `accumulators_`.
+ void ResetBits() { callers_ = hwy::BitSet<kMaxCallers>(); }
+
+ private:
+ CallerAccumulator accumulators_[kMaxCallers];
+ hwy::BitSet<kMaxCallers> callers_;
+};
+
+// Type-safe wrapper.
+class Caller {
+ public:
+ Caller() : idx_(0) {} // `AddCaller` never returns 0.
+ explicit Caller(size_t idx) : idx_(idx) { HWY_DASSERT(idx < kMaxCallers); }
+ size_t Idx() const { return idx_; }
+
+ private:
+ size_t idx_;
};
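+
+// Sketch (illustrative): `Caller` wraps a checked index into the table of
+// caller names, so it is cheap to copy and pass by value:
+//   pool::Caller c(5);     // debug builds assert that 5 < kMaxCallers
+//   size_t i = c.Idx();    // index for PerCluster::Get / StringTable::Name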
-static_assert(sizeof(Config) == 8, "");
+
+// Singleton, shared by all ThreadPool.
+class Shared {
+ public:
+ static HWY_CONTRIB_DLLEXPORT Shared& Get(); // Thread-safe.
+
+ Stopwatch MakeStopwatch() const { return Stopwatch(timer_); }
+ Stopwatch& LastRootEnd() { return last_root_end_; }
+
+ // Thread-safe. Calls with the same `name` return the same `Caller`.
+ Caller AddCaller(const char* name) { return Caller(callers_.Add(name)); }
+
+ PerCluster& Cluster(size_t cluster_idx) {
+ HWY_DASSERT(cluster_idx < kMaxClusters);
+ return per_cluster_[cluster_idx];
+ }
+
+ // Called from the main thread via `Profiler::PrintResults`.
+ void PrintAndReset() {
+ // Start counting pools (= one per cluster) invoked by each caller.
+ size_t active_clusters[kMaxCallers] = {};
+ per_cluster_[0].ForeachCaller(
+ [&](size_t caller_idx, CallerAccumulator& acc) {
+ active_clusters[caller_idx] = acc.Any();
+ });
+ // Reduce per-cluster accumulators into the first cluster.
+ for (size_t cluster_idx = 1; cluster_idx < kMaxClusters; ++cluster_idx) {
+ per_cluster_[cluster_idx].ForeachCaller(
+ [&](size_t caller_idx, CallerAccumulator& acc) {
+ active_clusters[caller_idx] += acc.Any();
+ per_cluster_[0].Get(caller_idx).AddFrom(acc);
+ acc = CallerAccumulator();
+ });
+ per_cluster_[cluster_idx].ResetBits();
+ }
+
+ CallerAccumulator total;
+ for (size_t caller_idx : per_cluster_[0].Sorted()) {
+ CallerAccumulator& acc = per_cluster_[0].Get(caller_idx);
+ total.AddFrom(acc); // must be before PrintAndReset.
+ acc.PrintAndReset(callers_.Name(caller_idx), active_clusters[caller_idx]);
+ }
+ total.PrintTotal();
+ per_cluster_[0].ResetBits();
+ }
+
+ private:
+ Shared() // called via Get().
+ : last_root_end_(timer_),
+ send_config(callers_.Add("SendConfig")),
+ dtor(callers_.Add("PoolDtor")),
+ print_stats(callers_.Add("PrintStats")) {
+ Profiler::Get().AddFunc(this, [this]() { PrintAndReset(); });
+ // Can skip `RemoveFunc` because the singleton never dies.
+ }
+
+ const Timer timer_;
+ Stopwatch last_root_end_;
+
+ PerCluster per_cluster_[kMaxClusters];
+ StringTable<kMaxCallers> callers_;
+
+ public:
+ // Returned from `callers_.Add`:
+ Caller send_config;
+ Caller dtor;
+ Caller print_stats;
+};
+
+#else
+
+struct Stats {
+ void NotifyRunStatic(size_t, timer::Ticks) {}
+ void NotifyRunDynamic(size_t, size_t, size_t, timer::Ticks) {}
+ void NotifyThreadRun(size_t, timer::Ticks, size_t, timer::Ticks,
+ timer::Ticks) {}
+ void NotifyMainRun(size_t, timer::Ticks, timer::Ticks, timer::Ticks,
+ timer::Ticks) {}
+ void PrintAndReset(size_t, timer::Ticks) {}
+ void Reset(size_t = kMaxThreads) {}
+};
+
+struct Caller {};
+
+class Shared {
+ public:
+ static HWY_CONTRIB_DLLEXPORT Shared& Get(); // Thread-safe.
+
+ Stopwatch MakeStopwatch() const { return Stopwatch(timer_); }
+
+ Caller AddCaller(const char*) { return Caller(); }
+
+ private:
+ Shared() {}
+
+ const Timer timer_;
+
+ public:
+ Caller send_config;
+ Caller dtor;
+ Caller print_stats;
+};
+
+#endif // PROFILER_ENABLED
// Per-worker state used by both main and worker threads. `ThreadFunc`
// (threads) and `ThreadPool` (main) have a few additional members of their own.
@@ -289,12 +783,33 @@ class alignas(HWY_ALIGNMENT) Worker { // HWY_ALIGNMENT bytes
static constexpr auto kAcq = std::memory_order_acquire;
static constexpr auto kRel = std::memory_order_release;
+ bool OwnsGlobalIdx() const {
+#if PROFILER_ENABLED
+ if (global_idx_ >= profiler::kMaxWorkers) {
+ HWY_WARN("Windows-only bug? global_idx %zu >= %zu.", global_idx_,
+ profiler::kMaxWorkers);
+ }
+#endif // PROFILER_ENABLED
+ // Across-cluster pool owns all except the main thread, which is reserved by
+ // profiler.cc.
+ if (cluster_idx_ == kAllClusters) return global_idx_ != 0;
+ // Within-cluster pool owns all except *its* main thread, because that is
+ // owned by the across-cluster pool.
+ return worker_ != 0;
+ }
+
public:
Worker(const size_t worker, const size_t num_threads,
- const Divisor64& div_workers)
- : worker_(worker), num_threads_(num_threads), workers_(this - worker) {
- (void)padding_;
-
+ const PoolWorkerMapping mapping, const Divisor64& div_workers,
+ const Stopwatch& stopwatch)
+ : workers_(this - worker),
+ worker_(worker),
+ num_threads_(num_threads),
+ stopwatch_(stopwatch),
+ // If `num_threads == 0`, we might be in an inner pool and must use
+ // the `global_idx` we are currently running on.
+ global_idx_(num_threads == 0 ? Profiler::GlobalIdx() : mapping(worker)),
+ cluster_idx_(mapping.ClusterIdx()) {
HWY_DASSERT(IsAligned(this, HWY_ALIGNMENT));
HWY_DASSERT(worker <= num_threads);
const size_t num_workers = static_cast<size_t>(div_workers.GetDivisor());
@@ -312,6 +827,20 @@ class alignas(HWY_ALIGNMENT) Worker { // HWY_ALIGNMENT bytes
victims_[i] = shuffled_iota.Next(victims_[i - 1], div_workers);
HWY_DASSERT(victims_[i] != worker);
}
+
+ HWY_IF_CONSTEXPR(PROFILER_ENABLED) {
+ if (HWY_LIKELY(OwnsGlobalIdx())) {
+ Profiler::Get().ReserveWorker(global_idx_);
+ }
+ }
+ }
+
+ ~Worker() {
+ HWY_IF_CONSTEXPR(PROFILER_ENABLED) {
+ if (HWY_LIKELY(OwnsGlobalIdx())) {
+ Profiler::Get().FreeWorker(global_idx_);
+ }
+ }
}
// Placement-newed by `WorkerLifecycle`, we do not expect any copying.
@@ -319,15 +848,29 @@ class alignas(HWY_ALIGNMENT) Worker { // HWY_ALIGNMENT bytes
Worker& operator=(const Worker&) = delete;
size_t Index() const { return worker_; }
+ // For work stealing.
Worker* AllWorkers() { return workers_; }
const Worker* AllWorkers() const { return workers_; }
size_t NumThreads() const { return num_threads_; }
+ size_t GlobalIdx() const { return global_idx_; }
+ size_t ClusterIdx() const { return cluster_idx_; }
+
+ void SetStartTime() { stopwatch_.Reset(); }
+ timer::Ticks ElapsedTime() { return stopwatch_.Elapsed(); }
+
// ------------------------ Per-worker storage for `SendConfig`
- Config LatchedConfig() const { return latched_; }
- // For workers, but no harm if also called by main thread.
- void LatchConfig(Config copy) { latched_ = copy; }
+ Config NextConfig() const { return next_config_; }
+ // Called during `SendConfig` by workers and now also the main thread. This
+ // avoids a separate `ThreadPool` member which risks going out of sync.
+ void SetNextConfig(Config copy) { next_config_ = copy; }
+
+ Exit GetExit() const { return exit_; }
+ void SetExit(Exit exit) { exit_ = exit; }
+
+ uint32_t WorkerEpoch() const { return worker_epoch_; }
+ uint32_t AdvanceWorkerEpoch() { return ++worker_epoch_; }
// ------------------------ Task assignment
@@ -364,36 +907,43 @@ class alignas(HWY_ALIGNMENT) Worker { // HWY_ALIGNMENT bytes
// ------------------------ Barrier: Main thread waits for workers
+ // For use by `HasReached` and `UntilReached`.
const std::atomic<uint32_t>& Barrier() const { return barrier_epoch_; }
- std::atomic<uint32_t>& MutableBarrier() { return barrier_epoch_; }
+ // Setting to `epoch` signals that the worker has reached the barrier.
void StoreBarrier(uint32_t epoch) { barrier_epoch_.store(epoch, kRel); }
private:
- // Atomics first because arm7 clang otherwise makes them unaligned.
-
// Set by `SetRange`:
- alignas(8) std::atomic<uint64_t> my_begin_;
- alignas(8) std::atomic<uint64_t> my_end_;
+ std::atomic<uint64_t> my_begin_;
+ std::atomic<uint64_t> my_end_;
+
+ Worker* const workers_;
+ const size_t worker_;
+ const size_t num_threads_;
- // Use u32 to match futex.h.
- alignas(4) std::atomic<uint32_t> wait_epoch_{0};
- alignas(4) std::atomic<uint32_t> barrier_epoch_{0}; // is reset
+ Stopwatch stopwatch_; // Reset by `SetStartTime`.
+ const size_t global_idx_;
+ const size_t cluster_idx_;
+
+ // Use u32 to match futex.h. These must start at the initial value of
+ // `worker_epoch_`.
+ std::atomic<uint32_t> wait_epoch_{1};
+ std::atomic<uint32_t> barrier_epoch_{1};
uint32_t num_victims_; // <= kPoolMaxVictims
std::array<uint32_t, kMaxVictims> victims_;
- Config latched_;
-
- const size_t worker_;
- const size_t num_threads_;
- Worker* const workers_;
+ // Written and read by the same thread, hence not atomic.
+ Config next_config_;
+ Exit exit_ = Exit::kNone;
+ // thread_pool_test requires nonzero epoch.
+ uint32_t worker_epoch_ = 1;
- uint8_t padding_[HWY_ALIGNMENT - 64 - sizeof(victims_)];
+ HWY_MEMBER_VAR_MAYBE_UNUSED uint8_t
+ padding_[HWY_ALIGNMENT - 56 - 6 * sizeof(void*) - sizeof(victims_)];
};
static_assert(sizeof(Worker) == HWY_ALIGNMENT, "");
-#pragma pack(pop)
-
// Creates/destroys `Worker` using preallocated storage. See comment at
// `ThreadPool::worker_bytes_` for why we do not dynamically allocate.
class WorkerLifecycle { // 0 bytes
@@ -401,10 +951,13 @@ class WorkerLifecycle { // 0 bytes
// Placement new for `Worker` into `storage` because its ctor requires
// the worker index. Returns array of all workers.
static Worker* Init(uint8_t* storage, size_t num_threads,
- const Divisor64& div_workers) {
- Worker* workers = new (storage) Worker(0, num_threads, div_workers);
+ PoolWorkerMapping mapping, const Divisor64& div_workers,
+ Shared& shared) {
+ Worker* workers = new (storage)
+ Worker(0, num_threads, mapping, div_workers, shared.MakeStopwatch());
for (size_t worker = 1; worker <= num_threads; ++worker) {
- new (Addr(storage, worker)) Worker(worker, num_threads, div_workers);
+ new (Addr(storage, worker)) Worker(worker, num_threads, mapping,
+ div_workers, shared.MakeStopwatch());
// Ensure pointer arithmetic is the same (will be used in Destroy).
HWY_DASSERT(reinterpret_cast<uintptr_t>(workers + worker) ==
reinterpret_cast<uintptr_t>(Addr(storage, worker)));
@@ -428,10 +981,9 @@ class WorkerLifecycle { // 0 bytes
}
};
-#pragma pack(push, 1)
// Stores arguments to `Run`: the function and range of task indices. Set by
// the main thread, read by workers including the main thread.
-class alignas(8) Tasks {
+class Tasks {
static constexpr auto kAcq = std::memory_order_acquire;
// Signature of the (internal) function called from workers(s) for each
@@ -454,7 +1006,8 @@ class alignas(8) Tasks {
}
// Assigns workers their share of `[begin, end)`. Called from the main
- // thread; workers are initializing or spinning for a command.
+ // thread; workers are initializing or waiting for a command.
+ // Negligible CPU time.
static void DivideRangeAmongWorkers(const uint64_t begin, const uint64_t end,
const Divisor64& div_workers,
Worker* workers) {
@@ -480,27 +1033,31 @@ class alignas(8) Tasks {
}
// Runs the worker's assigned range of tasks, plus work stealing if needed.
- HWY_POOL_PROFILE void WorkerRun(Worker* worker) const {
+ void WorkerRun(Worker* worker, const Shared& shared, Stats& stats) const {
if (NumTasks() > worker->NumThreads() + 1) {
- WorkerRunWithStealing(worker);
+ WorkerRunDynamic(worker, shared, stats);
} else {
- WorkerRunSingle(worker->Index());
+ WorkerRunStatic(worker, shared, stats);
}
}
private:
// Special case for <= 1 task per worker, where stealing is unnecessary.
- void WorkerRunSingle(size_t worker) const {
+ void WorkerRunStatic(Worker* worker, const Shared& shared,
+ Stats& stats) const {
const uint64_t begin = begin_.load(kAcq);
const uint64_t end = end_.load(kAcq);
HWY_DASSERT(begin <= end);
+ const size_t index = worker->Index();
- const uint64_t task = begin + worker;
+ const uint64_t task = begin + index;
// We might still have more workers than tasks, so check first.
if (HWY_LIKELY(task < end)) {
const void* opaque = Opaque();
const RunFunc func = Func();
- func(opaque, task, worker);
+ Stopwatch stopwatch = shared.MakeStopwatch();
+ func(opaque, task, index);
+ stats.NotifyRunStatic(index, stopwatch.Elapsed());
}
}
@@ -516,12 +1073,16 @@ class alignas(8) Tasks {
// and perform work from others, as if they were that worker. This deals with
// imbalances as they arise, but care is required to reduce contention. We
// randomize the order in which threads choose victims to steal from.
- HWY_POOL_PROFILE void WorkerRunWithStealing(Worker* worker) const {
+ void WorkerRunDynamic(Worker* worker, const Shared& shared,
+ Stats& stats) const {
Worker* workers = worker->AllWorkers();
const size_t index = worker->Index();
const RunFunc func = Func();
const void* opaque = Opaque();
+ size_t sum_tasks = 0;
+ size_t sum_stolen = 0;
+ timer::Ticks sum_d_func = 0;
// For each worker in random order, starting with our own, attempt to do
// all their work.
for (uint32_t victim : worker->Victims()) {
@@ -538,11 +1099,16 @@ class alignas(8) Tasks {
hwy::Pause(); // Reduce coherency traffic while stealing.
break;
}
+ Stopwatch stopwatch = shared.MakeStopwatch();
// Pass the index we are actually running on; this is important
// because it is the TLS index for user code.
func(opaque, task, index);
+ sum_tasks++;
+ sum_stolen += worker != other_worker;
+ sum_d_func += stopwatch.Elapsed();
}
}
+ stats.NotifyRunDynamic(index, sum_tasks, sum_stolen, sum_d_func);
}
size_t NumTasks() const {
@@ -564,7 +1130,6 @@ class alignas(8) Tasks {
std::atomic<const void*> opaque_;
};
static_assert(sizeof(Tasks) == 16 + 2 * sizeof(void*), "");
-#pragma pack(pop)
// ------------------------------ Threads wait, main wakes them
@@ -590,20 +1155,21 @@ static_assert(sizeof(Tasks) == 16 + 2 * sizeof(void*), "");
// Futex: blocking reduces apparent CPU usage, but has higher wake latency.
struct WaitBlock {
- WaitType Type() const { return WaitType::kBlock; }
-
// Wakes all workers by storing the current `epoch`.
void WakeWorkers(Worker* workers, const uint32_t epoch) const {
HWY_DASSERT(epoch != 0);
- workers[0].StoreWaiter(epoch);
- WakeAll(workers[0].MutableWaiter()); // futex: expensive syscall
+ workers[1].StoreWaiter(epoch);
+ WakeAll(workers[1].MutableWaiter()); // futex: expensive syscall
}
// Waits until `WakeWorkers(_, epoch)` has been called.
template <class Spin>
- void UntilWoken(const Worker* worker, const Spin& /*spin*/,
- const uint32_t epoch) const {
- BlockUntilDifferent(epoch - 1, worker->AllWorkers()->Waiter());
+ size_t UntilWoken(const Worker& worker, const Spin& /*spin*/) const {
+ HWY_DASSERT(worker.Index() != 0); // main is 0
+ const uint32_t epoch = worker.WorkerEpoch();
+ const Worker* workers = worker.AllWorkers();
+ BlockUntilDifferent(epoch - 1, workers[1].Waiter());
+ return 1; // iterations
}
};
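
// Sketch of the blocking handshake above (`WakeAll` and `BlockUntilDifferent`
// are the futex wrappers already called by `WaitBlock`):
//   main:   workers[1].StoreWaiter(epoch); WakeAll(...);          // syscall
//   worker: BlockUntilDifferent(epoch - 1, workers[1].Waiter());  // sleeps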
@@ -611,376 +1177,95 @@ struct WaitBlock {
// one cache line and thus have it in a shared state, which means the store
// will invalidate each of them, leading to more transactions than SpinSeparate.
struct WaitSpin1 {
- WaitType Type() const { return WaitType::kSpin1; }
-
void WakeWorkers(Worker* workers, const uint32_t epoch) const {
- workers[0].StoreWaiter(epoch);
+ workers[1].StoreWaiter(epoch);
}
+ // Returns the number of spin-wait iterations.
template <class Spin>
- void UntilWoken(const Worker* worker, const Spin& spin,
- const uint32_t epoch) const {
- (void)spin.UntilEqual(epoch, worker->AllWorkers()->Waiter());
- // TODO: store reps in stats.
+ size_t UntilWoken(const Worker& worker, const Spin& spin) const {
+ HWY_DASSERT(worker.Index() != 0); // main is 0
+ const Worker* workers = worker.AllWorkers();
+ const uint32_t epoch = worker.WorkerEpoch();
+ return spin.UntilEqual(epoch, workers[1].Waiter());
}
};
// Separate u32 per thread: more stores for the main thread, but each worker
// only polls its own cache line, leading to fewer cache-coherency transactions.
struct WaitSpinSeparate {
- WaitType Type() const { return WaitType::kSpinSeparate; }
-
void WakeWorkers(Worker* workers, const uint32_t epoch) const {
for (size_t thread = 0; thread < workers->NumThreads(); ++thread) {
- workers[thread].StoreWaiter(epoch);
+ workers[1 + thread].StoreWaiter(epoch);
}
}
template <class Spin>
- void UntilWoken(const Worker* worker, const Spin& spin,
- const uint32_t epoch) const {
- (void)spin.UntilEqual(epoch, worker->Waiter());
- // TODO: store reps in stats.
+ size_t UntilWoken(const Worker& worker, const Spin& spin) const {
+ HWY_DASSERT(worker.Index() != 0); // main is 0
+ const uint32_t epoch = worker.WorkerEpoch();
+ return spin.UntilEqual(epoch, worker.Waiter());
}
};
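
// Illustrative comparison (a sketch, not a measurement): with WaitSpin1, all
// N workers poll workers[1].Waiter(), so the main thread's single store must
// invalidate N shared copies of that cache line; with WaitSpinSeparate, the
// main thread issues N stores but each worker polls only its own line:
//   spin.UntilEqual(epoch, worker.Waiter());  // no cross-worker traffic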
-// ------------------------------ Barrier: Main thread waits for workers
-
-// Single atomic counter. TODO: remove if not competitive?
-template <size_t kShards>
-class BarrierCounter {
- static_assert(kShards == 1 || kShards == 2 || kShards == 4, ""); // pow2
-
- public:
- BarrierType Type() const {
- return kShards == 1 ? BarrierType::kCounter1
- : kShards == 2 ? BarrierType::kCounter2
- : BarrierType::kCounter4;
- }
-
- void Reset(Worker* workers) const {
- for (size_t i = 0; i < kShards; ++i) {
- // Use last worker(s) to avoid contention with other stores to the Worker.
- // Note that there are kMaxThreads + 1 workers, hence i == 0 is the last.
- workers[kMaxThreads - i].StoreBarrier(0);
- }
- }
-
- template <class Spin>
- void WorkerReached(Worker* worker, const Spin& /*spin*/,
- uint32_t /*epoch*/) const {
- const size_t shard = worker->Index() & (kShards - 1);
- const auto kAcqRel = std::memory_order_acq_rel;
- worker->AllWorkers()[kMaxThreads - shard].MutableBarrier().fetch_add(
- 1, kAcqRel);
- }
-
- template <class Spin>
- void UntilReached(size_t num_threads, const Worker* workers, const Spin& spin,
- uint32_t /*epoch*/) const {
- HWY_IF_CONSTEXPR(kShards == 1) {
- (void)spin.UntilEqual(static_cast<uint32_t>(num_threads),
- workers[kMaxThreads].Barrier());
- }
- HWY_IF_CONSTEXPR(kShards == 2) {
- const auto kAcq = std::memory_order_acquire;
- for (;;) {
- hwy::Pause();
- const uint64_t sum = workers[kMaxThreads - 0].Barrier().load(kAcq) +
- workers[kMaxThreads - 1].Barrier().load(kAcq);
- if (sum == num_threads) break;
- }
- }
- HWY_IF_CONSTEXPR(kShards == 4) {
- const auto kAcq = std::memory_order_acquire;
- for (;;) {
- hwy::Pause();
- const uint64_t sum = workers[kMaxThreads - 0].Barrier().load(kAcq) +
- workers[kMaxThreads - 1].Barrier().load(kAcq) +
- workers[kMaxThreads - 2].Barrier().load(kAcq) +
- workers[kMaxThreads - 3].Barrier().load(kAcq);
- if (sum == num_threads) break;
- }
- }
+// Calls unrolled code selected by all config enums.
+template <class Func, typename... Args>
+HWY_INLINE void CallWithConfig(const Config& config, Func&& func,
+ Args&&... args) {
+ switch (config.wait_type) {
+ case WaitType::kBlock:
+ return func(SpinPause(), WaitBlock(), std::forward<Args>(args)...);
+ case WaitType::kSpin1:
+ return CallWithSpin(config.spin_type, func, WaitSpin1(),
+ std::forward<Args>(args)...);
+ case WaitType::kSpinSeparate:
+ return CallWithSpin(config.spin_type, func, WaitSpinSeparate(),
+ std::forward<Args>(args)...);
+ case WaitType::kSentinel:
+ HWY_UNREACHABLE;
}
-};
+}
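+
+// Usage sketch: `func` must provide a templated operator() taking the
+// concrete Spin and Wait policy objects, so each enum combination gets its
+// own inlined instantiation, as `ThreadFunc::WorkerLoop` does below:
+//   CallWithConfig(worker.NextConfig(), WorkerLoop(), worker, tasks, shared,
+//                  stats);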
-// As with the wait, a store-release of the same local epoch counter serves as a
-// "have arrived" flag that does not require resetting.
+// ------------------------------ Barrier: Main thread waits for workers
-// Main thread loops over each worker.
-class BarrierOrdered {
+// Similar to `WaitSpinSeparate`, a store-release of the same local epoch
+// counter serves as a "have arrived" flag that does not require resetting.
+class Barrier {
public:
- BarrierType Type() const { return BarrierType::kOrdered; }
-
- void Reset(Worker* /*workers*/) const {}
-
- template <class Spin>
- void WorkerReached(Worker* worker, const Spin&, uint32_t epoch) const {
- worker->StoreBarrier(epoch);
- }
-
- template <class Spin>
- void UntilReached(size_t num_threads, const Worker* workers, const Spin& spin,
- uint32_t epoch) const {
- for (size_t i = 0; i < num_threads; ++i) {
- (void)spin.UntilEqual(epoch, workers[i].Barrier());
- }
+ void WorkerReached(Worker& worker, uint32_t epoch) const {
+ HWY_DASSERT(worker.Index() != 0); // main is 0
+ worker.StoreBarrier(epoch);
}
-};
-// Leader threads wait for others in the group, main thread loops over leaders.
-template <size_t kGroupSize>
-class BarrierGroup {
- public:
- BarrierType Type() const {
- return kGroupSize == 2 ? BarrierType::kGroup2 : BarrierType::kGroup4;
+ // Returns true if `worker` (can be the main thread) reached the barrier.
+ bool HasReached(const Worker* worker, uint32_t epoch) const {
+ const uint32_t barrier = worker->Barrier().load(std::memory_order_acquire);
+ HWY_DASSERT(barrier <= epoch);
+ return barrier == epoch;
}
- void Reset(Worker* /*workers*/) const {}
-
+ // Main thread loops over each worker. A "group of 2 or 4" barrier was not
+ // competitive on Skylake, Granite Rapids and Zen5.
template <class Spin>
- void WorkerReached(Worker* worker, const Spin& spin, uint32_t epoch) const {
- const size_t thread = worker->Index();
- // Leaders wait for all others in their group before marking themselves.
- if (thread % kGroupSize == 0) {
- for (size_t i = thread + 1;
- i < HWY_MIN(thread + kGroupSize, worker->NumThreads()); ++i) {
- (void)spin.UntilEqual(epoch, worker->AllWorkers()[i].Barrier());
- }
- }
- worker->StoreBarrier(epoch);
- }
-
- template <class Spin>
- void UntilReached(size_t num_threads, const Worker* workers, const Spin& spin,
+ void UntilReached(size_t num_threads, Worker* workers, const Spin& spin,
uint32_t epoch) const {
- for (size_t i = 0; i < num_threads; i += kGroupSize) {
- (void)spin.UntilEqual(epoch, workers[i].Barrier());
- }
- }
-};
+ workers[0].StoreBarrier(epoch); // for the main thread's HasReached.
-// ------------------------------ Inlining policy classes
-
-// We want to inline the various spin/wait/barrier policy classes into larger
-// code sections because both the main and worker threads use two or three of
-// them at a time, and we do not want separate branches around each.
-//
-// We generate code for three combinations of the enums, hence implement
-// composable adapters that 'add' `Wait` and `Barrier` arguments. `spin.h`
-// provides a `CallWithSpin`, hence it is the outermost. C++11 lacks generic
-// lambdas, so we implement these as classes.
-template <class Func>
-class FunctorAddWait {
- public:
- FunctorAddWait(WaitType wait_type, Func&& func)
- : func_(std::forward<Func>(func)), wait_type_(wait_type) {}
-
- template <class Spin>
- HWY_INLINE void operator()(const Spin& spin) {
- switch (wait_type_) {
- case WaitType::kBlock:
- return func_(spin, WaitBlock());
- case WaitType::kSpin1:
- return func_(spin, WaitSpin1());
- case WaitType::kSpinSeparate:
- return func_(spin, WaitSpinSeparate());
- default:
- HWY_UNREACHABLE;
+ for (size_t i = 0; i < num_threads; ++i) {
+ // TODO: log number of spin-wait iterations.
+ (void)spin.UntilEqual(epoch, workers[1 + i].Barrier());
}
}
-
- private:
- Func&& func_;
- WaitType wait_type_;
};
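
// Sketch of one barrier round (illustrative): epochs only increase, so a
// store of the current epoch is a one-shot "arrived" flag that needs no
// reset:
//   worker: barrier.WorkerReached(worker, epoch);                  // release
//   main:   barrier.UntilReached(num_threads, workers, spin, epoch);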
-template <class Func>
-class FunctorAddBarrier {
+// In debug builds, detects when functions are re-entered.
+class BusyFlag {
public:
- FunctorAddBarrier(BarrierType barrier_type, Func&& func)
- : func_(std::forward<Func>(func)), barrier_type_(barrier_type) {}
-
- template <class Wait>
- HWY_INLINE void operator()(const Wait& wait) {
- switch (barrier_type_) {
- case BarrierType::kOrdered:
- return func_(wait, BarrierOrdered());
- case BarrierType::kCounter1:
- return func_(wait, BarrierCounter<1>());
- case BarrierType::kCounter2:
- return func_(wait, BarrierCounter<2>());
- case BarrierType::kCounter4:
- return func_(wait, BarrierCounter<4>());
- case BarrierType::kGroup2:
- return func_(wait, BarrierGroup<2>());
- case BarrierType::kGroup4:
- return func_(wait, BarrierGroup<4>());
- default:
- HWY_UNREACHABLE;
- }
- }
- template <class Spin, class Wait>
- HWY_INLINE void operator()(const Spin& spin, const Wait& wait) {
- switch (barrier_type_) {
- case BarrierType::kOrdered:
- return func_(spin, wait, BarrierOrdered());
- case BarrierType::kCounter1:
- return func_(spin, wait, BarrierCounter<1>());
- case BarrierType::kCounter2:
- return func_(spin, wait, BarrierCounter<2>());
- case BarrierType::kCounter4:
- return func_(spin, wait, BarrierCounter<4>());
- case BarrierType::kGroup2:
- return func_(spin, wait, BarrierGroup<2>());
- case BarrierType::kGroup4:
- return func_(spin, wait, BarrierGroup<4>());
- default:
- HWY_UNREACHABLE;
- }
- }
+ void Set() { HWY_DASSERT(!busy_.test_and_set()); }
+ void Clear() { HWY_IF_CONSTEXPR(HWY_IS_DEBUG_BUILD) busy_.clear(); }
private:
- Func&& func_;
- BarrierType barrier_type_;
-};
-
-// Calls unrolled code selected by all 3 enums.
-template <class Func>
-HWY_INLINE void CallWithConfig(const Config& config, Func&& func) {
- CallWithSpin(
- config.spin_type,
- FunctorAddWait<FunctorAddBarrier<Func>>(
- config.wait_type, FunctorAddBarrier<Func>(config.barrier_type,
- std::forward<Func>(func))));
-}
-
-// For `WorkerAdapter`, `Spin` and `Wait`.
-template <class Func>
-HWY_INLINE void CallWithSpinWait(const Config& config, Func&& func) {
- CallWithSpin(
- config.spin_type,
- FunctorAddWait<Func>(config.wait_type, std::forward<Func>(func)));
-}
-
-// For `WorkerAdapter`, only `Spin` and `Barrier`.
-template <class Func>
-HWY_INLINE void CallWithSpinBarrier(const Config& config, Func&& func) {
- CallWithSpin(
- config.spin_type,
- FunctorAddBarrier<Func>(config.barrier_type, std::forward<Func>(func)));
-}
-
-// ------------------------------ Adapters
-
-// Logic of the main and worker threads, again packaged as classes because
-// C++11 lacks generic lambdas, called by `CallWith*`.
-
-class MainAdapter {
- public:
- MainAdapter(Worker* main, const Tasks* tasks) : main_(main), tasks_(tasks) {}
-
- void SetEpoch(uint32_t epoch) { epoch_ = epoch; }
-
- template <class Spin, class Wait, class Barrier>
- HWY_POOL_PROFILE void operator()(const Spin& spin, const Wait& wait,
- const Barrier& barrier) const {
- Worker* workers = main_->AllWorkers();
- const size_t num_threads = main_->NumThreads();
- barrier.Reset(workers);
-
- wait.WakeWorkers(workers, epoch_);
- // Threads might still be starting up and wake up late, but we wait for
- // them at the barrier below.
-
- // Also perform work on the main thread before the barrier.
- tasks_->WorkerRun(main_);
-
- // Waits until all *threads* (not the main thread, because it already knows
- // it is here) called `WorkerReached`. All `barrier` types use spinning.
-
- barrier.UntilReached(num_threads, workers, spin, epoch_);
-
- // Threads may already be waiting `UntilWoken`, which serves as the
- // 'release' phase of the barrier.
- }
-
- private:
- Worker* const main_;
- const Tasks* const tasks_;
- uint32_t epoch_;
-};
-
-class WorkerAdapter {
- public:
- explicit WorkerAdapter(Worker* worker) : worker_(worker) {}
-
- void SetEpoch(uint32_t epoch) { epoch_ = epoch; }
-
- // Split into separate wait/barrier functions because `ThreadFunc` latches
- // the config in between them.
- template <class Spin, class Wait,
- HWY_IF_SAME(decltype(Wait().Type()), WaitType)>
- void operator()(const Spin& spin, const Wait& wait) const {
- wait.UntilWoken(worker_, spin, epoch_);
- }
-
- template <class Spin, class Barrier,
- HWY_IF_SAME(decltype(Barrier().Type()), BarrierType)>
- void operator()(const Spin& spin, const Barrier& barrier) const {
- barrier.WorkerReached(worker_, spin, epoch_);
- }
-
- private:
- Worker* const worker_;
- uint32_t epoch_;
-};
-
-// Could also be a lambda in ThreadPool ctor, but this allows annotating with
-// `HWY_POOL_PROFILE` so we can more easily inspect the generated code.
-class ThreadFunc {
- public:
- ThreadFunc(Worker* worker, Tasks* tasks, Config config)
- : worker_(worker),
- tasks_(tasks),
- config_(config),
- worker_adapter_(worker_) {
- worker->LatchConfig(config);
- }
-
- HWY_POOL_PROFILE void operator()() {
- SetThreadName("worker%03zu", static_cast<int>(worker_->Index()));
-
- // Ensure main thread's writes are visible (synchronizes with fence in
- // `WorkerLifecycle::Init`).
- std::atomic_thread_fence(std::memory_order_acquire);
-
- // Initialization must match pre-increment in `MainAdapter::SetEpoch`.
- // Loop termination is triggered by `~ThreadPool`.
- for (uint32_t epoch = 1;; ++epoch) {
- worker_adapter_.SetEpoch(epoch);
- CallWithSpinWait(config_, worker_adapter_);
-
- // Must happen before `WorkerRun` because `SendConfig` writes it there.
- config_ = worker_->LatchedConfig();
-
- tasks_->WorkerRun(worker_);
-
- // Notify barrier after `WorkerRun`.
- CallWithSpinBarrier(config_, worker_adapter_);
-
- // Check after notifying the barrier, otherwise the main thread deadlocks.
- if (HWY_UNLIKELY(config_.exit)) break;
- }
- }
-
- private:
- Worker* const worker_;
- Tasks* const tasks_;
-
- Config config_;
- WorkerAdapter worker_adapter_;
+ std::atomic_flag busy_ = ATOMIC_FLAG_INIT;
};
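
// Usage sketch (see RunWithoutAutotune below); both calls are effectively
// no-ops outside debug builds because HWY_DASSERT compiles out:
//   busy_.Set();    // asserts if the non-reentrant region was re-entered
//   // ... parallel-for body ...
//   busy_.Clear();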
} // namespace pool
@@ -993,11 +1278,11 @@ class ThreadFunc {
// that threads do not schedule new work themselves. This allows us to avoid
// queues and only store a counter plus the current task. The latter is a
// pointer to a lambda function, without the allocation/indirection required for
-// std::function.
+// `std::function`.
//
// To reduce fork/join latency, we choose an efficient barrier, optionally
-// enable spin-waits via SetWaitMode, and avoid any mutex/lock. We largely even
-// avoid atomic RMW operations (LOCK prefix): currently for the wait and
+// enable spin-waits via `SetWaitMode`, and avoid any mutex/lock. We largely
+// even avoid atomic RMW operations (LOCK prefix): currently for the wait and
// barrier, in future hopefully also for work stealing.
//
// To eliminate false sharing and enable reasoning about cache line traffic, the
@@ -1005,6 +1290,17 @@ class ThreadFunc {
//
// For load-balancing, we use work stealing in random order.
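//
// Usage sketch (caller code is hypothetical, not from this patch):
//   hwy::ThreadPool pool(/*num_threads=*/7);  // 8 workers incl. main thread
//   pool.Run(0, 100, [&](uint64_t task, size_t worker) {
//     results[task] = Compute(task);  // worker < pool.NumWorkers()
//   });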
class alignas(HWY_ALIGNMENT) ThreadPool {
+ // Used to initialize `num_threads_` from the ctor argument.
+ static size_t ClampedNumThreads(size_t num_threads) {
+ // Upper bound is required for `worker_bytes_`.
+ if (HWY_UNLIKELY(num_threads > pool::kMaxThreads)) {
+ HWY_WARN("ThreadPool: clamping num_threads %zu to %zu.", num_threads,
+ pool::kMaxThreads);
+ num_threads = pool::kMaxThreads;
+ }
+ return num_threads;
+ }
+
public:
// This typically includes hyperthreads, hence it is a loose upper bound.
// -1 because these are in addition to the main thread.
@@ -1020,48 +1316,61 @@ class alignas(HWY_ALIGNMENT) ThreadPool {
// `num_threads` is the number of *additional* threads to spawn, which should
// not exceed `MaxThreads()`. Note that the main thread also performs work.
- explicit ThreadPool(size_t num_threads)
- : have_timer_stop_(platform::HaveTimerStop(cpu100_)),
- num_threads_(ClampedNumThreads(num_threads)),
- div_workers_(num_threads_ + 1),
+ // `mapping` indicates how to map local worker_idx to global.
+ ThreadPool(size_t num_threads,
+ PoolWorkerMapping mapping = PoolWorkerMapping())
+ : num_threads_(ClampedNumThreads(num_threads)),
+ div_workers_(1 + num_threads_),
+ shared_(pool::Shared::Get()), // on first call, calls ReserveWorker(0)!
workers_(pool::WorkerLifecycle::Init(worker_bytes_, num_threads_,
- div_workers_)),
- main_adapter_(workers_ + num_threads_, &tasks_) {
+ mapping, div_workers_, shared_)) {
// Leaves the default wait mode as `kBlock`, which means futex, because
// spinning only makes sense when threads are pinned and wake latency is
// important, so it must explicitly be requested by calling `SetWaitMode`.
for (PoolWaitMode mode : {PoolWaitMode::kSpin, PoolWaitMode::kBlock}) {
wait_mode_ = mode; // for AutoTuner
AutoTuner().SetCandidates(
- pool::Config::AllCandidates(mode, num_threads_));
+ pool::Config::AllCandidates(mode));
+ }
+
+ // Skip empty pools because they do not update stats anyway.
+ if (num_threads_ > 0) {
+ Profiler::Get().AddFunc(this, [this]() { PrintStats(); });
}
- config_ = AutoTuner().Candidates()[0];
threads_.reserve(num_threads_);
for (size_t thread = 0; thread < num_threads_; ++thread) {
threads_.emplace_back(
- pool::ThreadFunc(workers_ + thread, &tasks_, config_));
+ ThreadFunc(workers_[1 + thread], tasks_, shared_, stats_));
}
- // No barrier is required here because wakeup works regardless of the
- // relative order of wake and wait.
+ // Threads' `Config` defaults to spinning. Change to `kBlock` (see above).
+ // This also ensures all threads have started before we return, so that
+ // startup latency is billed to the ctor, not the first `Run`.
+ SendConfig(AutoTuner().Candidates()[0]);
}
- // Waits for all threads to exit.
+ // If we created threads, waits for them all to exit.
~ThreadPool() {
// There is no portable way to request threads to exit like `ExitThread` on
// Windows, otherwise we could call that from `Run`. Instead, we must cause
- // the thread to wake up and exit. We can use the same `SendConfig`
- // mechanism as `SetWaitMode`.
- pool::Config copy = config_;
- copy.exit = true;
- SendConfig(copy);
+ // the thread to wake up and exit. We can just use `Run`.
+ (void)RunWithoutAutotune(
+ 0, NumWorkers(), shared_.dtor,
+ [this](HWY_MAYBE_UNUSED uint64_t task, size_t worker) {
+ HWY_DASSERT(task == worker);
+ workers_[worker].SetExit(Exit::kThread);
+ });
for (std::thread& thread : threads_) {
HWY_DASSERT(thread.joinable());
thread.join();
}
+ if (num_threads_ > 0) {
+ Profiler::Get().RemoveFunc(this);
+ }
+
pool::WorkerLifecycle::Destroy(workers_, num_threads_);
}
@@ -1084,12 +1393,16 @@ class alignas(HWY_ALIGNMENT) ThreadPool {
: AutoTuner().NextConfig());
}
- // For printing which are in use.
- pool::Config config() const { return config_; }
+ // For printing which is in use.
+ pool::Config config() const { return workers_[0].NextConfig(); }
bool AutoTuneComplete() const { return AutoTuner().Best(); }
Span<CostDistribution> AutoTuneCosts() { return AutoTuner().Costs(); }
+ static pool::Caller AddCaller(const char* name) {
+ return pool::Shared::Get().AddCaller(name);
+ }
+
// parallel-for: Runs `closure(task, worker)` on workers for every `task` in
// `[begin, end)`. Note that the unit of work should be large enough to
// amortize the function call overhead, but small enough that each worker
@@ -1098,7 +1411,204 @@ class alignas(HWY_ALIGNMENT) ThreadPool {
// Not thread-safe - concurrent parallel-for in the same `ThreadPool` are
// forbidden unless `NumWorkers() == 1` or `end <= begin + 1`.
template <class Closure>
+ void Run(uint64_t begin, uint64_t end, pool::Caller caller,
+ const Closure& closure) {
+ AutoTuneT& auto_tuner = AutoTuner();
+ // Already finished tuning: run without time measurement.
+ if (HWY_LIKELY(auto_tuner.Best())) {
+ // Don't care whether threads ran, we are done either way.
+ (void)RunWithoutAutotune(begin, end, caller, closure);
+ return;
+ }
+
+ // Not yet finished: measure time and notify autotuner.
+ Stopwatch stopwatch(shared_.MakeStopwatch());
+ // Skip update if threads didn't actually run.
+ if (!RunWithoutAutotune(begin, end, caller, closure)) return;
+ auto_tuner.NotifyCost(stopwatch.Elapsed());
+
+ pool::Config next = auto_tuner.NextConfig(); // may be overwritten below
+ if (auto_tuner.Best()) { // just finished
+ next = *auto_tuner.Best();
+ HWY_IF_CONSTEXPR(pool::kVerbosity >= 1) {
+ const size_t idx_best = static_cast<size_t>(
+ auto_tuner.Best() - auto_tuner.Candidates().data());
+ HWY_DASSERT(idx_best < auto_tuner.Costs().size());
+ auto& AT = auto_tuner.Costs()[idx_best];
+ const double best_cost = AT.EstimateCost();
+ HWY_DASSERT(best_cost > 0.0); // will divide by this below
+
+ Stats s_ratio;
+ for (size_t i = 0; i < auto_tuner.Costs().size(); ++i) {
+ if (i == idx_best) continue;
+ const double cost = auto_tuner.Costs()[i].EstimateCost();
+ s_ratio.Notify(static_cast<float>(cost / best_cost));
+ }
+
+ fprintf(stderr,
+ "Pool %3zu: %s %8.0f +/- %6.0f. Gain %.2fx [%.2fx, %.2fx]\n",
+ NumWorkers(), auto_tuner.Best()->ToString().c_str(), best_cost,
+ AT.Stddev(), s_ratio.GeometricMean(),
+ static_cast<double>(s_ratio.Min()),
+ static_cast<double>(s_ratio.Max()));
+ }
+ }
+ SendConfig(next);
+ }
+
+ // Backward-compatible version without Caller.
+ template <class Closure>
void Run(uint64_t begin, uint64_t end, const Closure& closure) {
+ Run(begin, end, pool::Caller(), closure);
+ }
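+
+ // Usage sketch: prefer the `Caller` overload so elapsed time is attributed
+ // to a named call site (names here are illustrative):
+ //   static const pool::Caller kStage = ThreadPool::AddCaller("MyStage");
+ //   pool.Run(begin, end, kStage, closure);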
+
+ private:
+ // Called via `CallWithConfig`.
+ struct MainWakeAndBarrier {
+ template <class Spin, class Wait>
+ void operator()(const Spin& spin, const Wait& wait, pool::Worker& main,
+ const pool::Tasks& tasks, const pool::Shared& shared,
+ pool::Stats& stats) const {
+ const pool::Barrier barrier;
+ pool::Worker* workers = main.AllWorkers();
+ HWY_DASSERT(&main == main.AllWorkers()); // main is first.
+ const size_t num_threads = main.NumThreads();
+ const uint32_t epoch = main.AdvanceWorkerEpoch();
+
+ HWY_IF_CONSTEXPR(HWY_IS_DEBUG_BUILD) {
+ for (size_t i = 0; i < 1 + num_threads; ++i) {
+ HWY_DASSERT(!barrier.HasReached(workers + i, epoch));
+ }
+ }
+
+ Stopwatch stopwatch(shared.MakeStopwatch());
+ const timer::Ticks t_before_wake = stopwatch.Origin();
+ wait.WakeWorkers(workers, epoch);
+ const timer::Ticks d_wake = stopwatch.Elapsed();
+
+ // Also perform work on the main thread before the barrier.
+ tasks.WorkerRun(&main, shared, stats);
+ const timer::Ticks d_run = stopwatch.Elapsed();
+
+ // Spin-waits until all worker *threads* (not `main`, because it already
+ // knows it is here) called `WorkerReached`.
+ barrier.UntilReached(num_threads, workers, spin, epoch);
+ const timer::Ticks d_barrier = stopwatch.Elapsed();
+ stats.NotifyMainRun(main.NumThreads(), t_before_wake, d_wake, d_run,
+ d_barrier);
+
+ HWY_IF_CONSTEXPR(HWY_IS_DEBUG_BUILD) {
+ for (size_t i = 0; i < 1 + num_threads; ++i) {
+ HWY_DASSERT(barrier.HasReached(workers + i, epoch));
+ }
+ }
+
+ // Threads are or will soon be waiting `UntilWoken`, which serves as the
+ // 'release' phase of the barrier.
+ }
+ };
+
+ // Called by `std::thread`. Could also be a lambda.
+ class ThreadFunc {
+ // Functor called by `CallWithConfig`. Loops until `SendConfig` changes the
+ // Spin or Wait policy or the pool is destroyed.
+ struct WorkerLoop {
+ template <class Spin, class Wait>
+ void operator()(const Spin& spin, const Wait& wait, pool::Worker& worker,
+ pool::Tasks& tasks, const pool::Shared& shared,
+ pool::Stats& stats) const {
+ do {
+ // Main worker also calls this, so their epochs match.
+ const uint32_t epoch = worker.AdvanceWorkerEpoch();
+
+ Stopwatch stopwatch(shared.MakeStopwatch());
+
+ const size_t wait_reps = wait.UntilWoken(worker, spin);
+ const timer::Ticks d_wait = stopwatch.Elapsed();
+ const timer::Ticks t_before_run = stopwatch.Origin();
+
+ tasks.WorkerRun(&worker, shared, stats);
+ const timer::Ticks d_run = stopwatch.Elapsed();
+ stats.NotifyThreadRun(worker.Index(), d_wait, wait_reps, t_before_run,
+ d_run);
+
+ // Notify barrier after `WorkerRun`. Note that we cannot send an
+ // after-barrier timestamp, see above.
+ pool::Barrier().WorkerReached(worker, epoch);
+ // Check after `WorkerReached`, otherwise the main thread deadlocks.
+ } while (worker.GetExit() == Exit::kNone);
+ }
+ };
+
+ public:
+ ThreadFunc(pool::Worker& worker, pool::Tasks& tasks,
+ const pool::Shared& shared, pool::Stats& stats)
+ : worker_(worker), tasks_(tasks), shared_(shared), stats_(stats) {}
+
+ void operator()() {
+ // Ensure main thread's writes are visible (synchronizes with fence in
+ // `WorkerLifecycle::Init`).
+ std::atomic_thread_fence(std::memory_order_acquire);
+
+ HWY_DASSERT(worker_.Index() != 0); // main is 0
+ SetThreadName("worker%03zu", static_cast<int>(worker_.Index() - 1));
+
+ worker_.SetStartTime();
+ Profiler& profiler = Profiler::Get();
+ profiler.SetGlobalIdx(worker_.GlobalIdx());
+ // No Zone here because it would only exit after `GetExit`, which may be
+ // after the main thread's `PROFILER_END_ROOT_RUN`, and thus too late to
+ // be counted. Instead, `ProfilerFunc` records the elapsed time.
+
+ // Loop termination via `GetExit` is triggered by `~ThreadPool`.
+ for (;;) {
+ // Uses the initial config, or the last one set during WorkerRun.
+ CallWithConfig(worker_.NextConfig(), WorkerLoop(), worker_, tasks_,
+ shared_, stats_);
+
+ // Exit or reset the flag and return to WorkerLoop with a new config.
+ if (worker_.GetExit() == Exit::kThread) break;
+ worker_.SetExit(Exit::kNone);
+ }
+
+ profiler.SetGlobalIdx(~size_t{0});
+
+ // Defer `FreeWorker` until workers are destroyed to ensure the profiler
+ // is not still using the worker.
+ }
+
+ private:
+ pool::Worker& worker_;
+ pool::Tasks& tasks_;
+ const pool::Shared& shared_;
+ pool::Stats& stats_;
+ };
+
+ void PrintStats() {
+ // Total run time from all non-main threads.
+ std::atomic<timer::Ticks> sum_thread_elapsed{0};
+ (void)RunWithoutAutotune(
+ 0, NumWorkers(), shared_.print_stats,
+ [this, &sum_thread_elapsed](HWY_MAYBE_UNUSED uint64_t task,
+ size_t worker) {
+ HWY_DASSERT(task == worker);
+ // Skip any main thread(s) because they did not init the stopwatch.
+ if (worker != 0) {
+ sum_thread_elapsed.fetch_add(workers_[worker].ElapsedTime());
+ }
+ });
+ const timer::Ticks thread_total =
+ sum_thread_elapsed.load(std::memory_order_acquire);
+ stats_.PrintAndReset(num_threads_, thread_total);
+ }
+
+ // Returns whether threads were used. If not, there is no need to update
+ // the autotuner config.
+ template <class Closure>
+ bool RunWithoutAutotune(uint64_t begin, uint64_t end, pool::Caller caller,
+ const Closure& closure) {
+ pool::Worker& main = workers_[0];
+
const size_t num_tasks = static_cast<size_t>(end - begin);
const size_t num_workers = NumWorkers();
@@ -1108,10 +1618,18 @@ class alignas(HWY_ALIGNMENT) ThreadPool {
for (uint64_t task = begin; task < end; ++task) {
closure(task, /*worker=*/0);
}
- return;
+ return false;
}
- SetBusy();
+ busy_.Set();
+
+#if PROFILER_ENABLED
+ const bool is_root = PROFILER_IS_ROOT_RUN();
+ Stopwatch stopwatch(shared_.MakeStopwatch());
+ const timer::Ticks wait_before =
+ is_root ? shared_.LastRootEnd().Elapsed() : 0;
+#endif
+
tasks_.Set(begin, end, closure);
// More than one task per worker: use work stealing.
@@ -1119,129 +1637,43 @@ class alignas(HWY_ALIGNMENT) ThreadPool {
pool::Tasks::DivideRangeAmongWorkers(begin, end, div_workers_, workers_);
}
- main_adapter_.SetEpoch(++epoch_);
-
- AutoTuneT& auto_tuner = AutoTuner();
- if (HWY_LIKELY(auto_tuner.Best())) {
- CallWithConfig(config_, main_adapter_);
- ClearBusy();
- } else {
- const uint64_t t0 = timer::Start();
- CallWithConfig(config_, main_adapter_);
- const uint64_t t1 = have_timer_stop_ ? timer::Stop() : timer::Start();
- auto_tuner.NotifyCost(t1 - t0);
- ClearBusy(); // before `SendConfig`
- if (auto_tuner.Best()) { // just finished
- HWY_IF_CONSTEXPR(pool::kVerbosity >= 1) {
- const size_t idx_best = static_cast<size_t>(
- auto_tuner.Best() - auto_tuner.Candidates().data());
- HWY_DASSERT(idx_best < auto_tuner.Costs().size());
- auto& AT = auto_tuner.Costs()[idx_best];
- const double best_cost = AT.EstimateCost();
- HWY_DASSERT(best_cost > 0.0); // will divide by this below
-
- Stats s_ratio;
- for (size_t i = 0; i < auto_tuner.Costs().size(); ++i) {
- if (i == idx_best) continue;
- const double cost = auto_tuner.Costs()[i].EstimateCost();
- s_ratio.Notify(static_cast<float>(cost / best_cost));
- }
-
- fprintf(stderr, " %s %5.0f +/- %4.0f. Gain %.2fx [%.2fx, %.2fx]\n",
- auto_tuner.Best()->ToString().c_str(), best_cost, AT.Stddev(),
- s_ratio.GeometricMean(), s_ratio.Min(), s_ratio.Max());
- }
- SendConfig(*auto_tuner.Best());
- } else {
- HWY_IF_CONSTEXPR(pool::kVerbosity >= 2) {
- fprintf(stderr, " %s %5lu\n", config_.ToString().c_str(), t1 - t0);
- }
- SendConfig(auto_tuner.NextConfig());
- }
+ // Runs `MainWakeAndBarrier` using the first worker slot (the main thread).
+ CallWithConfig(config(), MainWakeAndBarrier(), main, tasks_, shared_,
+ stats_);
+
+#if PROFILER_ENABLED
+ pool::CallerAccumulator& acc =
+ shared_.Cluster(main.ClusterIdx()).Get(caller.Idx());
+ acc.Add(num_tasks, num_workers, is_root, wait_before, stopwatch.Elapsed());
+ if (is_root) {
+ PROFILER_END_ROOT_RUN();
+ shared_.LastRootEnd().Reset();
}
- }
-
- // Can pass this as init_closure when no initialization is needed.
- // DEPRECATED, better to call the Run() overload without the init_closure arg.
- static bool NoInit(size_t /*num_threads*/) { return true; } // DEPRECATED
-
- // DEPRECATED equivalent of NumWorkers. Note that this is not the same as the
- // ctor argument because num_threads = 0 has the same effect as 1.
- size_t NumThreads() const { return NumWorkers(); } // DEPRECATED
+#else
+ (void)caller;
+#endif
- // DEPRECATED prior interface with 32-bit tasks and first calling
- // `init_closure(num_threads)`. Instead, perform any init before this, calling
- // NumWorkers() for an upper bound on the worker index, then call the other
- // overload of Run().
- template <class InitClosure, class RunClosure>
- bool Run(uint64_t begin, uint64_t end, const InitClosure& init_closure,
- const RunClosure& run_closure) {
- if (!init_closure(NumThreads())) return false;
- Run(begin, end, run_closure);
+ busy_.Clear();
return true;
}
- private:
- // Used to initialize ThreadPool::num_threads_ from its ctor argument.
- static size_t ClampedNumThreads(size_t num_threads) {
- // Upper bound is required for `worker_bytes_`.
- if (HWY_UNLIKELY(num_threads > pool::kMaxThreads)) {
- HWY_WARN("ThreadPool: clamping num_threads %zu to %zu.", num_threads,
- pool::kMaxThreads);
- num_threads = pool::kMaxThreads;
- }
- return num_threads;
- }
-
- // Debug-only re-entrancy detection.
- void SetBusy() { HWY_DASSERT(!busy_.test_and_set()); }
- void ClearBusy() { HWY_IF_CONSTEXPR(HWY_IS_DEBUG_BUILD) busy_.clear(); }
-
- // Two-phase barrier protocol for sending `copy` to workers, similar to the
- // 'quiescent state' used in RCU.
- //
- // Phase 1:
- // - Main wakes threads using the old config.
- // - Threads latch `copy` during `WorkerRun`.
- // - Threads notify a barrier and wait for the next wake using the old config.
- //
- // Phase 2:
- // - Main wakes threads still using the old config.
- // - Threads switch their config to their latched `copy`.
- // - Threads notify a barrier and wait, BOTH with the new config.
- // - Main thread switches to `copy` for the next wake.
- HWY_NOINLINE void SendConfig(pool::Config copy) {
- if (NumWorkers() == 1) {
- config_ = copy;
- return;
- }
-
- SetBusy();
-
- const auto closure = [this, copy](uint64_t task, size_t worker) {
- (void)task;
- HWY_DASSERT(task == worker); // one task per worker
- workers_[worker].LatchConfig(copy);
- };
- tasks_.Set(0, NumWorkers(), closure);
- // Same config as workers are *currently* using.
- main_adapter_.SetEpoch(++epoch_);
- CallWithConfig(config_, main_adapter_);
- // All workers have latched `copy` and are waiting with the old config.
-
- // No-op task; will not be called because begin == end.
- tasks_.Set(0, 0, [](uint64_t /*task*/, size_t /*worker*/) {});
- // Threads are waiting using the old config, but will switch after waking,
- // which means we must already use the new barrier.
- pool::Config new_barrier = config_;
- new_barrier.barrier_type = copy.barrier_type;
- main_adapter_.SetEpoch(++epoch_);
- CallWithConfig(new_barrier, main_adapter_);
- // All have woken and are, or will be, waiting per the *new* config. Now we
+ // Sends `next_config` to workers:
+ // - Main wakes threads using the current config.
+ // - Threads copy `next_config` into their `Worker` during `WorkerRun`.
+ // - Threads notify the (same) barrier, then wait for the next wake already
+ // using `next_config`.
+ HWY_NOINLINE void SendConfig(pool::Config next_config) {
+ (void)RunWithoutAutotune(
+ 0, NumWorkers(), shared_.send_config,
+ [this, next_config](HWY_MAYBE_UNUSED uint64_t task, size_t worker) {
+ HWY_DASSERT(task == worker); // one task per worker
+ workers_[worker].SetNextConfig(next_config);
+ workers_[worker].SetExit(Exit::kLoop);
+ });
+
+ // All have woken and are, or will be, waiting per `next_config`. Now we
// can entirely switch the main thread's config for the next wake.
- config_ = copy;
-
- ClearBusy();
+ workers_[0].SetNextConfig(next_config);
}
using AutoTuneT = AutoTune<pool::Config, 30>;
@@ -1253,21 +1685,21 @@ class alignas(HWY_ALIGNMENT) ThreadPool {
return auto_tune_[static_cast<size_t>(wait_mode_) - 1];
}
- char cpu100_[100];
- const bool have_timer_stop_;
const size_t num_threads_; // not including main thread
const Divisor64 div_workers_;
+ pool::Shared& shared_;
pool::Worker* const workers_; // points into `worker_bytes_`
- pool::MainAdapter main_adapter_;
+ alignas(HWY_ALIGNMENT) pool::Stats stats_;
- // The only mutable state:
- pool::Tasks tasks_; // written by `Run` and read by workers.
- pool::Config config_; // for use by the next `Run`. Updated via `SendConfig`.
- uint32_t epoch_ = 0; // passed to `MainAdapter`.
+ // This is written by the main thread and read by workers, via reference
+ // passed to `ThreadFunc`. Padding ensures that the workers' cache lines are
+ // not unnecessarily invalidated when the main thread writes other members.
+ alignas(HWY_ALIGNMENT) pool::Tasks tasks_;
+ HWY_MEMBER_VAR_MAYBE_UNUSED char
+ padding_[HWY_ALIGNMENT - sizeof(pool::Tasks)];
- // In debug builds, detects if functions are re-entered.
- std::atomic_flag busy_ = ATOMIC_FLAG_INIT;
+ pool::BusyFlag busy_;
// Unmodified after ctor, but cannot be const because we call thread::join().
std::vector<std::thread> threads_;
diff --git a/third_party/highway/hwy/contrib/thread_pool/thread_pool_test.cc b/third_party/highway/hwy/contrib/thread_pool/thread_pool_test.cc
new file mode 100644
index 0000000000..a2bafb30a4
--- /dev/null
+++ b/third_party/highway/hwy/contrib/thread_pool/thread_pool_test.cc
@@ -0,0 +1,485 @@
+// Copyright 2023 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Modified from BSD-licensed code
+// Copyright (c) the JPEG XL Project Authors. All rights reserved.
+// See https://github.com/libjxl/libjxl/blob/main/LICENSE.
+
+#include "third_party/highway/hwy/contrib/thread_pool/thread_pool.h"
+
+#include <math.h> // sqrtf
+#include <stddef.h>
+#include <stdint.h>
+#include <stdio.h>
+
+#include <atomic>
+#include <thread> // NOLINT
+#include <vector>
+
+#include "third_party/highway/hwy/base.h" // PopCount
+#include "third_party/highway/hwy/contrib/thread_pool/spin.h"
+#include "third_party/highway/hwy/contrib/thread_pool/topology.h"
+#include "third_party/highway/hwy/tests/hwy_gtest.h"
+#include "third_party/highway/hwy/tests/test_util-inl.h" // AdjustedReps
+
+namespace hwy {
+namespace pool {
+namespace {
+
+TEST(ThreadPoolTest, TestCoprime) {
+ // 1 is coprime with anything
+ for (uint32_t i = 1; i < 500; ++i) {
+ HWY_ASSERT(ShuffledIota::CoprimeNonzero(1, i));
+ HWY_ASSERT(ShuffledIota::CoprimeNonzero(i, 1));
+ }
+
+ // Powers of two >= 2 are not coprime
+ for (size_t i = 1; i < 20; ++i) {
+ for (size_t j = 1; j < 20; ++j) {
+ HWY_ASSERT(!ShuffledIota::CoprimeNonzero(1u << i, 1u << j));
+ }
+ }
+
+ // 2^x and 2^x +/- 1 are coprime
+ for (size_t i = 1; i < 30; ++i) {
+ const uint32_t pow2 = 1u << i;
+ HWY_ASSERT(ShuffledIota::CoprimeNonzero(pow2, pow2 + 1));
+ HWY_ASSERT(ShuffledIota::CoprimeNonzero(pow2, pow2 - 1));
+ HWY_ASSERT(ShuffledIota::CoprimeNonzero(pow2 + 1, pow2));
+ HWY_ASSERT(ShuffledIota::CoprimeNonzero(pow2 - 1, pow2));
+ }
+
+ // Random x * random y (both >= 2) is coprime with neither x nor y.
+ RandomState rng;
+ for (size_t i = 1; i < 5000; ++i) {
+ const uint32_t x = (Random32(&rng) & 0xFFF7) + 2;
+ const uint32_t y = (Random32(&rng) & 0xFFF7) + 2;
+ HWY_ASSERT(!ShuffledIota::CoprimeNonzero(x * y, x));
+ HWY_ASSERT(!ShuffledIota::CoprimeNonzero(x * y, y));
+ HWY_ASSERT(!ShuffledIota::CoprimeNonzero(x, x * y));
+ HWY_ASSERT(!ShuffledIota::CoprimeNonzero(y, x * y));
+ }
+
+ // Primes are all coprime (list from https://oeis.org/A000040)
+ static constexpr uint32_t primes[] = {
+ 2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47,
+ 53, 59, 61, 67, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109, 113,
+ 127, 131, 137, 139, 149, 151, 157, 163, 167, 173, 179, 181, 191, 193, 197,
+ 199, 211, 223, 227, 229, 233, 239, 241, 251, 257, 263, 269, 271};
+ for (size_t i = 0; i < sizeof(primes) / sizeof(primes[0]); ++i) {
+ for (size_t j = i + 1; j < sizeof(primes) / sizeof(primes[0]); ++j) {
+ HWY_ASSERT(ShuffledIota::CoprimeNonzero(primes[i], primes[j]));
+ HWY_ASSERT(ShuffledIota::CoprimeNonzero(primes[j], primes[i]));
+ }
+ }
+}
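+
+// For reference, a minimal sketch of the logic `CoprimeNonzero` is assumed
+// to implement (Euclid's algorithm; not necessarily the shipped code):
+//
+//   bool CoprimeNonzero(uint32_t a, uint32_t b) { // requires a, b != 0
+//     while (b != 0) { // invariant: gcd(a, b) is unchanged
+//       const uint32_t t = a % b;
+//       a = b;
+//       b = t;
+//     }
+//     return a == 1; // coprime iff gcd == 1
+//   }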
+
+// Ensures `shuffled` visits [0, size) exactly once starting from `current`.
+void VerifyPermutation(uint32_t size, const Divisor64& divisor,
+ const ShuffledIota& shuffled, uint32_t current,
+ uint32_t* visited) {
+ for (size_t i = 0; i < size; i++) {
+ visited[i] = 0;
+ }
+
+ for (size_t i = 0; i < size; i++) {
+ ++visited[current];
+ current = shuffled.Next(current, divisor);
+ }
+
+ for (size_t i = 0; i < size; i++) {
+ HWY_ASSERT(visited[i] == 1);
+ }
+}
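+
+// Why one coprime stride suffices: when gcd(step, size) == 1, repeatedly
+// adding `step` modulo `size` cycles through all of [0, size) before
+// returning to the start. A hedged sketch of what `Next` presumably
+// computes (member name `coprime_` and `Divisor64::Remainder` assumed):
+//
+//   uint32_t Next(uint32_t current, const Divisor64& div) const {
+//     return static_cast<uint32_t>(div.Remainder(uint64_t{current} + coprime_));
+//   }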
+
+// Verifies ShuffledIota generates a permutation of [0, size).
+TEST(ThreadPoolTest, TestRandomPermutation) {
+ constexpr size_t kMaxSize = 40;
+ uint32_t visited[kMaxSize];
+
+ // Exhaustive enumeration of size and starting point.
+ for (uint32_t size = 1; size < kMaxSize; ++size) {
+ const Divisor64 divisor(size);
+
+ const uint32_t coprime = ShuffledIota::FindAnotherCoprime(size, 1);
+ const ShuffledIota shuffled(coprime);
+
+ for (uint32_t start = 0; start < size; ++start) {
+ VerifyPermutation(size, divisor, shuffled, start, visited);
+ }
+ }
+}
+
+// Verifies multiple ShuffledIota are relatively independent.
+TEST(ThreadPoolTest, TestMultiplePermutations) {
+ constexpr size_t kMaxSize = 40;
+ uint32_t coprimes[kMaxSize];
+ // One per ShuffledIota; initially the starting value, then its Next().
+ uint32_t current[kMaxSize];
+
+ for (uint32_t size = 1; size < kMaxSize; ++size) {
+ const Divisor64 divisor(size);
+
+ // Create `size` ShuffledIota instances with unique coprimes.
+ std::vector<ShuffledIota> shuffled;
+ for (size_t i = 0; i < size; ++i) {
+ coprimes[i] = ShuffledIota::FindAnotherCoprime(
+ size, static_cast<uint32_t>((i + 1) * 257 + i * 13));
+ shuffled.emplace_back(coprimes[i]);
+ }
+
+ // ShuffledIota[i] starts at i to match the worker thread use case.
+ for (uint32_t i = 0; i < size; ++i) {
+ current[i] = i;
+ }
+
+ size_t num_bad = 0;
+ uint32_t all_visited[kMaxSize] = {0};
+
+ // For each step, ensure there are few non-unique current[].
+ for (size_t step = 0; step < size; ++step) {
+ // How many times is each number visited?
+ uint32_t visited[kMaxSize] = {0};
+ for (size_t i = 0; i < size; ++i) {
+ visited[current[i]] += 1;
+ all_visited[current[i]] = 1; // visited at all across all steps?
+ }
+
+ // How many numbers are visited multiple times?
+ size_t num_contended = 0;
+ uint32_t max_contention = 0;
+ for (size_t i = 0; i < size; ++i) {
+ num_contended += visited[i] > 1;
+ max_contention = HWY_MAX(max_contention, visited[i]);
+ }
+
+ // Count/print if excessive collisions.
+ const size_t expected =
+ static_cast<size_t>(sqrtf(static_cast<float>(size)) * 2.0f);
+ if ((num_contended > expected) && (max_contention > 3)) {
+ ++num_bad;
+ fprintf(stderr, "size %u step %zu contended %zu max contention %u\n",
+ size, step, num_contended, max_contention);
+ for (size_t i = 0; i < size; ++i) {
+ fprintf(stderr, " %u\n", current[i]);
+ }
+ fprintf(stderr, "coprimes\n");
+ for (size_t i = 0; i < size; ++i) {
+ fprintf(stderr, " %u\n", coprimes[i]);
+ }
+ }
+
+ // Advance all ShuffledIota generators.
+ for (size_t i = 0; i < size; ++i) {
+ current[i] = shuffled[i].Next(current[i], divisor);
+ }
+ } // step
+
+ // Ensure each task was visited during at least one step.
+ for (size_t i = 0; i < size; ++i) {
+ HWY_ASSERT(all_visited[i] != 0);
+ }
+
+ if (num_bad != 0) {
+ fprintf(stderr, "size %u total bad: %zu\n", size, num_bad);
+ }
+ HWY_ASSERT(num_bad < kMaxSize / 10);
+ } // size
+}
+
+class DoWait {
+ public:
+ explicit DoWait(Worker& worker) : worker_(worker) {}
+
+ template <class Spin, class Wait>
+ void operator()(const Spin& spin, const Wait& wait) const {
+ wait.UntilWoken(worker_, spin);
+ }
+
+ private:
+ Worker& worker_;
+};
+
+class DoWakeWorkers {
+ public:
+ explicit DoWakeWorkers(Worker* workers) : workers_(workers) {}
+
+ template <class Spin, class Wait>
+ void operator()(const Spin&, const Wait& wait) const {
+ wait.WakeWorkers(workers_, workers_[0].WorkerEpoch());
+ }
+
+ private:
+ Worker* const workers_;
+};
+
+// Verifies that waiter(s) can be woken by another thread.
+TEST(ThreadPoolTest, TestWaiter) {
+ if (!hwy::HaveThreadingSupport()) return;
+
+ // Not actual threads, but we allocate and loop over this many workers.
+ for (size_t num_threads = 1; num_threads < 6; ++num_threads) {
+ const size_t num_workers = 1 + num_threads;
+ auto storage = hwy::AllocateAligned<uint8_t>(num_workers * sizeof(Worker));
+ HWY_ASSERT(storage);
+ const Divisor64 div_workers(num_workers);
+ Shared& shared = Shared::Get(); // already calls ReserveWorker(0).
+
+ for (WaitType wait_type :
+ {WaitType::kBlock, WaitType::kSpin1, WaitType::kSpinSeparate}) {
+ Worker* workers = pool::WorkerLifecycle::Init(
+ storage.get(), num_threads, PoolWorkerMapping(), div_workers, shared);
+
+ alignas(8) const Config config(SpinType::kPause, wait_type);
+
+ // This thread acts as the "main thread", which will wake the actual main
+ // and all its worker instances.
+ std::thread thread(
+ [&]() { CallWithConfig(config, DoWakeWorkers(workers)); });
+
+ // Worker 0 is the main thread; wait on behalf of each actual worker.
+ for (size_t worker = 1; worker < num_workers; ++worker) {
+ CallWithConfig(config, DoWait(workers[worker]));
+ }
+ thread.join();
+
+ pool::WorkerLifecycle::Destroy(workers, num_workers);
+ }
+ }
+}
+
+// Ensures all tasks are run. Similar to TestPool below but without threads.
+TEST(ThreadPoolTest, TestTasks) {
+ for (size_t num_threads = 1; num_threads <= 8; ++num_threads) {
+ const size_t num_workers = num_threads + 1;
+ auto storage = hwy::AllocateAligned<uint8_t>(num_workers * sizeof(Worker));
+ HWY_ASSERT(storage);
+ const Divisor64 div_workers(num_workers);
+ Shared& shared = Shared::Get();
+ Stats stats;
+ Worker* workers = WorkerLifecycle::Init(
+ storage.get(), num_threads, PoolWorkerMapping(), div_workers, shared);
+
+ constexpr uint64_t kMaxTasks = 20;
+ uint64_t mementos[kMaxTasks]; // non-atomic, no threads involved.
+ for (uint64_t num_tasks = 0; num_tasks < kMaxTasks; ++num_tasks) {
+ for (uint64_t begin = 0; begin < AdjustedReps(32); ++begin) {
+ const uint64_t end = begin + num_tasks;
+
+ ZeroBytes(mementos, sizeof(mementos));
+ const auto func = [begin, end, &mementos](uint64_t task,
+ size_t /*worker*/) {
+ HWY_ASSERT(begin <= task && task < end);
+
+ // Store mementos to ensure we visited each task.
+ mementos[task - begin] = 1000 + task;
+ };
+ Tasks tasks;
+ tasks.Set(begin, end, func);
+
+ Tasks::DivideRangeAmongWorkers(begin, end, div_workers, workers);
+ // The `tasks < workers` special case requires running by all workers.
+ for (size_t worker = 0; worker < num_workers; ++worker) {
+ tasks.WorkerRun(workers + worker, shared, stats);
+ }
+
+ // Ensure all tasks were run.
+ for (uint64_t task = begin; task < end; ++task) {
+ HWY_ASSERT_EQ(1000 + task, mementos[task - begin]);
+ }
+ }
+ }
+
+ WorkerLifecycle::Destroy(workers, num_workers);
+ }
+}
+
+// Ensures task parameter is in bounds, every parameter is reached,
+// pool can be reused (multiple consecutive Run calls), pool can be destroyed
+// (joining with its threads), num_threads=0 works (runs on current thread).
+TEST(ThreadPoolTest, TestPool) {
+ if (!hwy::HaveThreadingSupport()) return;
+
+ constexpr uint64_t kMaxTasks = 20;
+ static std::atomic<uint64_t> mementos[kMaxTasks];
+ static std::atomic<uint64_t> a_begin;
+ static std::atomic<uint64_t> a_end;
+ static std::atomic<uint64_t> a_num_workers;
+
+ // Called by pool; sets mementos and runs a nested but serial Run.
+ const auto func = [](uint64_t task, size_t worker) {
+ HWY_ASSERT(worker < a_num_workers.load());
+ const uint64_t begin = a_begin.load(std::memory_order_acquire);
+ const uint64_t end = a_end.load(std::memory_order_acquire);
+
+ if (!(begin <= task && task < end)) {
+ HWY_ABORT("Task %d not in [%d, %d]", static_cast<int>(task),
+ static_cast<int>(begin), static_cast<int>(end));
+ }
+
+ // Store mementos to ensure we visited each task.
+ mementos[task - begin].store(1000 + task);
+
+ // Re-entering Run is fine on a pool with zero threads. Note that the inner
+ // pool must be created per-thread so that it gets the `global_idx` it is
+ // running on.
+ hwy::ThreadPool inner(0);
+ inner.Run(begin, end,
+ [begin, end](uint64_t inner_task, size_t inner_worker) {
+ HWY_ASSERT(inner_worker == 0);
+ HWY_ASSERT(begin <= inner_task && inner_task < end);
+ });
+ };
+
+ for (size_t num_threads = 0; num_threads <= 6; num_threads += 3) {
+ hwy::ThreadPool pool(HWY_MIN(ThreadPool::MaxThreads(), num_threads));
+ a_num_workers.store(pool.NumWorkers());
+ for (bool spin : {true, false}) {
+ pool.SetWaitMode(spin ? PoolWaitMode::kSpin : PoolWaitMode::kBlock);
+
+ for (uint64_t num_tasks = 0; num_tasks < kMaxTasks; ++num_tasks) {
+ for (uint64_t all_begin = 0; all_begin < AdjustedReps(32);
+ ++all_begin) {
+ const uint64_t all_end = all_begin + num_tasks;
+ a_begin.store(all_begin, std::memory_order_release);
+ a_end.store(all_end, std::memory_order_release);
+
+ for (size_t i = 0; i < kMaxTasks; ++i) {
+ mementos[i].store(0);
+ }
+
+ pool.Run(all_begin, all_end, func);
+
+ for (uint64_t task = all_begin; task < all_end; ++task) {
+ const uint64_t expected = 1000 + task;
+ const uint64_t actual = mementos[task - all_begin].load();
+ if (expected != actual) {
+ HWY_ABORT(
+ "threads %zu, tasks %d: task not run, expected %d, got %d\n",
+ num_threads, static_cast<int>(num_tasks),
+ static_cast<int>(expected), static_cast<int>(actual));
+ }
+ }
+ }
+ }
+ }
+ }
+}
+
+// Debug tsan builds seem to generate incorrect codegen for [&] of atomics, so
+// use a static state object and capture-less lambdas instead.
+struct SmallAssignmentState {
+ // (Avoid mutex because it may perturb the worker thread scheduling)
+ std::atomic<uint64_t> num_tasks{0};
+ std::atomic<uint64_t> num_workers{0};
+ std::atomic<uint64_t> id_bits{0};
+ std::atomic<uint64_t> num_calls{0};
+};
+
+// Verify "thread" parameter when processing few tasks.
+TEST(ThreadPoolTest, TestSmallAssignments) {
+ if (!hwy::HaveThreadingSupport()) return;
+
+ static SmallAssignmentState state;
+
+ for (size_t num_threads :
+ {size_t{0}, size_t{1}, size_t{3}, size_t{5}, size_t{8}}) {
+ ThreadPool pool(HWY_MIN(ThreadPool::MaxThreads(), num_threads));
+ state.num_workers.store(pool.NumWorkers());
+
+ for (size_t mul = 1; mul <= 2; ++mul) {
+ const size_t num_tasks = pool.NumWorkers() * mul;
+ state.num_tasks.store(num_tasks);
+ state.id_bits.store(0);
+ state.num_calls.store(0);
+
+ pool.Run(0, num_tasks, [](uint64_t task, size_t worker) {
+ HWY_ASSERT(task < state.num_tasks.load());
+ HWY_ASSERT(worker < state.num_workers.load());
+
+ state.num_calls.fetch_add(1);
+
+ uint64_t bits = state.id_bits.load();
+ while (!state.id_bits.compare_exchange_weak(bits,
+ bits | (1ULL << worker))) {
+ }
+ });
+
+ // Correct number of tasks.
+ const uint64_t actual_calls = state.num_calls.load();
+ HWY_ASSERT(num_tasks == actual_calls);
+
+ const size_t num_participants = PopCount(state.id_bits.load());
+ // <= because some workers may not manage to run any tasks.
+ HWY_ASSERT(num_participants <= pool.NumWorkers());
+ }
+ }
+}
+
+struct Counter {
+ Counter() {
+ // Suppress "unused-field" warning.
+ (void)padding;
+ }
+ void Assimilate(const Counter& victim) { counter += victim.counter; }
+ std::atomic<uint64_t> counter{0};
+ uint64_t padding[15];
+};
+
+// Can switch between any wait mode, and multiple times.
+TEST(ThreadPoolTest, TestWaitMode) {
+ if (!hwy::HaveThreadingSupport()) return;
+
+ ThreadPool pool(9);
+ RandomState rng;
+ for (size_t i = 0; i < 100; ++i) {
+ pool.SetWaitMode((Random32(&rng) & 1u) ? PoolWaitMode::kSpin
+ : PoolWaitMode::kBlock);
+ }
+}
+
+TEST(ThreadPoolTest, TestCounter) {
+ if (!hwy::HaveThreadingSupport()) return;
+
+ const size_t kNumThreads = 12;
+ ThreadPool pool(kNumThreads);
+ for (PoolWaitMode mode : {PoolWaitMode::kSpin, PoolWaitMode::kBlock}) {
+ pool.SetWaitMode(mode);
+ alignas(128) Counter counters[1 + kNumThreads];
+
+ const uint64_t kNumTasks = kNumThreads * 19;
+ pool.Run(0, kNumTasks,
+ [&counters](const uint64_t task, const size_t worker) {
+ counters[worker].counter.fetch_add(task);
+ });
+
+ uint64_t expected = 0;
+ for (uint64_t i = 0; i < kNumTasks; ++i) {
+ expected += i;
+ }
+
+ for (size_t i = 1; i < pool.NumWorkers(); ++i) {
+ counters[0].Assimilate(counters[i]);
+ }
+ HWY_ASSERT_EQ(expected, counters[0].counter.load());
+ }
+}
+
+} // namespace
+} // namespace pool
+} // namespace hwy
+
+HWY_TEST_MAIN();
diff --git a/third_party/highway/hwy/contrib/thread_pool/topology.cc b/third_party/highway/hwy/contrib/thread_pool/topology.cc
new file mode 100644
index 0000000000..02b927e4ae
--- /dev/null
+++ b/third_party/highway/hwy/contrib/thread_pool/topology.cc
@@ -0,0 +1,1280 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "third_party/highway/hwy/contrib/thread_pool/topology.h"
+
+#include <ctype.h> // isspace
+#include <stddef.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <string.h> // strchr
+
+#include <array>
+#include <map>
+#include <string>
+#include <vector>
+
+#include "third_party/highway/hwy/base.h" // HWY_OS_WIN, HWY_WARN
+
+#if HWY_OS_APPLE
+#include <sys/sysctl.h>
+
+#include "third_party/highway/hwy/aligned_allocator.h" // HWY_ALIGNMENT
+#endif
+
+#if HWY_OS_WIN
+#ifndef NOMINMAX
+#define NOMINMAX
+#endif
+#ifndef WIN32_LEAN_AND_MEAN
+#define WIN32_LEAN_AND_MEAN
+#endif
+#ifndef _WIN32_WINNT
+#define _WIN32_WINNT 0x0601 // Windows 7 / Server 2008
+#endif
+#include <windows.h>
+#endif // HWY_OS_WIN
+
+#if HWY_OS_LINUX || HWY_OS_FREEBSD
+#ifndef _GNU_SOURCE
+#define _GNU_SOURCE
+#endif
+#include <errno.h>
+#include <fcntl.h>
+#include <pthread.h>
+#include <sched.h>
+#include <sys/stat.h>
+#include <sys/syscall.h>
+#include <sys/types.h>
+#include <unistd.h> // sysconf
+#endif // HWY_OS_LINUX || HWY_OS_FREEBSD
+
+#if HWY_OS_FREEBSD
+#include <sys/param.h>
+// After param.h / types.h.
+#include <sys/cpuset.h>
+#endif // HWY_OS_FREEBSD
+
+#if HWY_ARCH_WASM
+#include <emscripten/threading.h>
+#endif
+
+namespace hwy {
+
+HWY_CONTRIB_DLLEXPORT bool HaveThreadingSupport() {
+#if HWY_ARCH_WASM
+ return emscripten_has_threading_support() != 0;
+#else
+ return true;
+#endif
+}
+
+namespace {
+
+// Returns `whole / part`, with a check that `part` evenly divides `whole`,
+// which implies the result is exact.
+HWY_MAYBE_UNUSED size_t DivByFactor(size_t whole, size_t part) {
+ HWY_ASSERT(part != 0);
+ const size_t div = whole / part;
+ const size_t mul = div * part;
+ if (mul != whole) {
+ HWY_ABORT("%zu / %zu = %zu; *%zu = %zu\n", whole, part, div, part, mul);
+ }
+ return div;
+}
+
+#if HWY_OS_WIN
+
+using SLPI = SYSTEM_LOGICAL_PROCESSOR_INFORMATION_EX;
+
+template <class Func>
+bool ForEachSLPI(LOGICAL_PROCESSOR_RELATIONSHIP rel, Func&& func) {
+ // Get required buffer size.
+ DWORD buf_bytes = 0;
+ HWY_ASSERT(!GetLogicalProcessorInformationEx(rel, nullptr, &buf_bytes));
+ // Observed when `rel` is not supported:
+ if (HWY_UNLIKELY(buf_bytes == 0 && GetLastError() == ERROR_GEN_FAILURE)) {
+ if (rel != RelationNumaNodeEx && rel != RelationProcessorDie) {
+ HWY_WARN("Unexpected err %lx for GLPI relationship %d\n", GetLastError(),
+ static_cast<int>(rel));
+ }
+ return false;
+ }
+ HWY_ASSERT(GetLastError() == ERROR_INSUFFICIENT_BUFFER);
+ // Note: `buf_bytes` may be less than `sizeof(SLPI)`, which has padding.
+ // `calloc` zero-initializes the `Reserved` field, part of which has been
+ // repurposed into `GroupCount` in SDK 10.0.22000.0, or possibly earlier.
+ uint8_t* buf = static_cast<uint8_t*>(calloc(1, buf_bytes));
+ HWY_ASSERT(buf);
+
+ // Fill the buffer.
+ SLPI* info = reinterpret_cast<SLPI*>(buf);
+ if (HWY_UNLIKELY(!GetLogicalProcessorInformationEx(rel, info, &buf_bytes))) {
+ free(buf);
+ return false;
+ }
+
+ // Iterate over each SLPI. `sizeof(SLPI)` is unreliable, see above.
+ uint8_t* pos = buf;
+ while (pos < buf + buf_bytes) {
+ info = reinterpret_cast<SLPI*>(pos);
+ HWY_ASSERT(rel == RelationAll || info->Relationship == rel);
+ func(*info);
+ pos += info->Size;
+ }
+ if (pos != buf + buf_bytes) {
+ HWY_WARN("unexpected pos %p, end %p, buf_bytes %lu, sizeof(SLPI) %zu\n",
+ pos, buf + buf_bytes, buf_bytes, sizeof(SLPI));
+ }
+
+ free(buf);
+ return true;
+}
+
+size_t NumBits(size_t num_groups, const GROUP_AFFINITY* affinity) {
+ size_t total_bits = 0;
+ for (size_t i = 0; i < num_groups; ++i) {
+ size_t bits = 0;
+ hwy::CopyBytes<sizeof(bits)>(&affinity[i].Mask, &bits);
+ total_bits += hwy::PopCount(bits);
+ }
+ return total_bits;
+}
+
+// Calls `func(lp, lps)` for each index `lp` in the set, after ensuring that
+// `lp < lps.size()`. `line` is for debugging via Warn().
+template <class Func>
+void ForeachBit(size_t num_groups, const GROUP_AFFINITY* affinity,
+ std::vector<Topology::LP>& lps, int line, const Func& func) {
+ for (size_t group = 0; group < num_groups; ++group) {
+ size_t bits = 0;
+ hwy::CopyBytes<sizeof(bits)>(&affinity[group].Mask, &bits);
+ while (bits != 0) {
+ size_t lp = group * 64 + Num0BitsBelowLS1Bit_Nonzero64(bits);
+ bits &= bits - 1; // clear LSB
+ if (HWY_UNLIKELY(lp >= lps.size())) {
+ Warn(__FILE__, line, "Clamping lp %zu to lps.size() %zu, groups %zu\n",
+ lp, lps.size(), num_groups);
+ lp = lps.size() - 1;
+ }
+ func(lp, lps);
+ }
+ }
+}
+
+#elif HWY_OS_APPLE
+
+// Returns whether sysctlbyname() succeeded; if so, writes `val / div` to
+// `out`, otherwise sets `err`.
+template <typename T>
+bool Sysctl(const char* name, size_t div, int& err, T* out) {
+ size_t val = 0;
+ size_t size = sizeof(val);
+ // Last two arguments are for updating the value, which we do not want.
+ const int ret = sysctlbyname(name, &val, &size, nullptr, 0);
+ if (HWY_UNLIKELY(ret != 0)) {
+ // Do not print warnings because some `name` are expected to fail.
+ err = ret;
+ return false;
+ }
+ *out = static_cast<T>(DivByFactor(val, div));
+ return true;
+}
+
+#endif // HWY_OS_*
+
+} // namespace
+
+HWY_CONTRIB_DLLEXPORT size_t TotalLogicalProcessors() {
+ size_t total_lps = 0;
+#if HWY_ARCH_WASM
+ const int num_cores = emscripten_num_logical_cores();
+ if (num_cores > 0) total_lps = static_cast<size_t>(num_cores);
+#elif HWY_OS_WIN
+ // If there are multiple groups, this should return them all, rather than
+ // just the first 64, but VMs may report fewer.
+ (void)ForEachSLPI(RelationProcessorCore, [&total_lps](const SLPI& info) {
+ const PROCESSOR_RELATIONSHIP& p = info.Processor;
+ total_lps += NumBits(p.GroupCount, p.GroupMask);
+ });
+#elif HWY_OS_LINUX
+ // Only check "online" because sysfs entries such as topology are missing for
+ // offline CPUs, which will cause `DetectPackages` to fail.
+ const long ret = sysconf(_SC_NPROCESSORS_ONLN); // NOLINT(runtime/int)
+ if (ret < 0) {
+ HWY_WARN("Unexpected _SC_NPROCESSORS_CONF = %d\n", static_cast<int>(ret));
+ } else {
+ total_lps = static_cast<size_t>(ret);
+ }
+#elif HWY_OS_APPLE
+ int err;
+ // Only report P processors.
+ if (!Sysctl("hw.perflevel0.logicalcpu", 1, err, &total_lps)) {
+ total_lps = 0;
+ }
+#endif
+
+ if (HWY_UNLIKELY(total_lps == 0)) { // Failed to detect.
+ HWY_WARN(
+ "Unknown TotalLogicalProcessors, assuming 1. "
+ "HWY_OS_: WIN=%d LINUX=%d APPLE=%d;\n"
+ "HWY_ARCH_: WASM=%d X86=%d PPC=%d ARM=%d RISCV=%d S390X=%d\n",
+ HWY_OS_WIN, HWY_OS_LINUX, HWY_OS_APPLE, HWY_ARCH_WASM, HWY_ARCH_X86,
+ HWY_ARCH_PPC, HWY_ARCH_ARM, HWY_ARCH_RISCV, HWY_ARCH_S390X);
+ return 1;
+ }
+
+ // Warn that we are clamping.
+ if (HWY_UNLIKELY(total_lps > kMaxLogicalProcessors)) {
+ HWY_WARN("OS reports %zu processors but clamping to %zu\n", total_lps,
+ kMaxLogicalProcessors);
+ total_lps = kMaxLogicalProcessors;
+ }
+
+ return total_lps;
+}
+
+// ------------------------------ Affinity
+
+#if HWY_OS_LINUX || HWY_OS_FREEBSD
+namespace {
+
+#if HWY_OS_LINUX
+using CpuSet = cpu_set_t;
+#else
+using CpuSet = cpuset_t;
+#endif
+
+// Helper functions reduce the number of #if in GetThreadAffinity.
+int GetAffinity(CpuSet* set) {
+ // To specify the current thread, pass 0 on Linux/Android and -1 on FreeBSD.
+#if defined(__ANDROID__) && __ANDROID_API__ < 12
+ return syscall(__NR_sched_getaffinity, 0, sizeof(CpuSet), set);
+#elif HWY_OS_FREEBSD
+ return cpuset_getaffinity(CPU_LEVEL_WHICH, CPU_WHICH_TID, -1, sizeof(CpuSet),
+ set);
+#else // normal Linux
+ return sched_getaffinity(0, sizeof(CpuSet), set);
+#endif
+}
+
+int SetAffinity(CpuSet* set) {
+ // To specify the current thread, pass 0 on Linux/Android and -1 on FreeBSD.
+#if defined(__ANDROID__) && __ANDROID_API__ < 12
+ return syscall(__NR_sched_setaffinity, 0, sizeof(CpuSet), set);
+#elif HWY_OS_FREEBSD
+ return cpuset_setaffinity(CPU_LEVEL_WHICH, CPU_WHICH_TID, -1, sizeof(CpuSet),
+ set);
+#else // normal Linux
+ return sched_setaffinity(0, sizeof(CpuSet), set);
+#endif
+}
+
+bool IsSet(size_t lp, const CpuSet* set) {
+#if HWY_COMPILER_GCC_ACTUAL
+ // Workaround for GCC compiler warning with CPU_ISSET macro
+ HWY_DIAGNOSTICS(push)
+ HWY_DIAGNOSTICS_OFF(disable : 4305 4309, ignored "-Wsign-conversion")
+#endif
+ const int is_set = CPU_ISSET(static_cast<int>(lp), set);
+#if HWY_COMPILER_GCC_ACTUAL
+ HWY_DIAGNOSTICS(pop)
+#endif
+ return is_set != 0;
+}
+
+void Set(size_t lp, CpuSet* set) {
+#if HWY_COMPILER_GCC_ACTUAL
+ // Workaround for GCC compiler warning with CPU_SET macro
+ HWY_DIAGNOSTICS(push)
+ HWY_DIAGNOSTICS_OFF(disable : 4305 4309, ignored "-Wsign-conversion")
+#endif
+ CPU_SET(static_cast<int>(lp), set);
+#if HWY_COMPILER_GCC_ACTUAL
+ HWY_DIAGNOSTICS(pop)
+#endif
+}
+
+} // namespace
+#endif // HWY_OS_LINUX || HWY_OS_FREEBSD
+
+HWY_CONTRIB_DLLEXPORT bool GetThreadAffinity(LogicalProcessorSet& lps) {
+#if HWY_OS_WIN
+ // Only support the first 64 because WINE does not support processor groups.
+ const HANDLE hThread = GetCurrentThread();
+ const DWORD_PTR prev = SetThreadAffinityMask(hThread, ~DWORD_PTR(0));
+ if (!prev) return false;
+ (void)SetThreadAffinityMask(hThread, prev);
+ lps = LogicalProcessorSet(); // clear all
+ lps.SetNonzeroBitsFrom64(prev);
+ return true;
+#elif HWY_OS_LINUX || HWY_OS_FREEBSD
+ CpuSet set;
+ CPU_ZERO(&set);
+ const int err = GetAffinity(&set);
+ if (err != 0) return false;
+ for (size_t lp = 0; lp < kMaxLogicalProcessors; ++lp) {
+ if (IsSet(lp, &set)) lps.Set(lp);
+ }
+ return true;
+#else
+ // For HWY_OS_APPLE, affinity is not supported. Do not even set lp=0 to force
+ // callers to handle this case.
+ (void)lps;
+ return false;
+#endif
+}
+
+HWY_CONTRIB_DLLEXPORT bool SetThreadAffinity(const LogicalProcessorSet& lps) {
+#if HWY_OS_WIN
+ const HANDLE hThread = GetCurrentThread();
+ const DWORD_PTR prev = SetThreadAffinityMask(hThread, lps.Get64());
+ return prev != 0;
+#elif HWY_OS_LINUX || HWY_OS_FREEBSD
+ CpuSet set;
+ CPU_ZERO(&set);
+ lps.Foreach([&set](size_t lp) { Set(lp, &set); });
+ const int err = SetAffinity(&set);
+ if (err != 0) return false;
+ return true;
+#else
+ // Apple THREAD_AFFINITY_POLICY is only an (often ignored) hint.
+ (void)lps;
+ return false;
+#endif
+}
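+
+// Example usage, a sketch (`Set()` and `First()` as used elsewhere in this
+// file): pin the calling thread to the first LP it may currently run on.
+//
+//   hwy::LogicalProcessorSet lps;
+//   if (hwy::GetThreadAffinity(lps)) {
+//     hwy::LogicalProcessorSet only_first;
+//     only_first.Set(lps.First());
+//     (void)hwy::SetThreadAffinity(only_first); // false on Apple, see above
+//   }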
+
+namespace {
+
+struct PackageSizes {
+ size_t num_clusters;
+ size_t num_cores;
+};
+
+#if HWY_OS_LINUX
+
+class File {
+ public:
+ explicit File(const char* path) {
+ for (;;) {
+ fd_ = open(path, O_RDONLY);
+ if (fd_ > 0) return; // success
+ if (errno == EINTR) continue; // signal: retry
+ if (errno == ENOENT) return; // not found, give up
+ HWY_WARN("Unexpected error opening %s: %d\n", path, errno);
+ return; // unknown error, give up
+ }
+ }
+
+ ~File() {
+ if (fd_ > 0) {
+ for (;;) {
+ const int ret = close(fd_);
+ if (ret == 0) break; // success
+ if (errno == EINTR) continue; // signal: retry
+ HWY_WARN("Unexpected error closing file: %d\n", errno);
+ return; // unknown error, ignore
+ }
+ }
+ }
+
+ // Returns number of bytes read or 0 on failure.
+ size_t Read(char* buf200) const {
+ if (fd_ < 0) return 0;
+ size_t pos = 0;
+ for (;;) {
+ // read instead of `pread`, which might not work for sysfs.
+ const auto bytes_read = read(fd_, buf200 + pos, 200 - pos);
+ if (bytes_read == 0) { // EOF: done
+ buf200[pos++] = '\0';
+ return pos;
+ }
+ if (bytes_read == -1) {
+ if (errno == EINTR) continue; // signal: retry
+ HWY_WARN("Unexpected error reading file: %d\n", errno);
+ return 0;
+ }
+ pos += static_cast<size_t>(bytes_read);
+ HWY_ASSERT(pos <= 200);
+ }
+ }
+
+ private:
+ int fd_;
+};
+
+// Returns bytes read, or 0 on failure.
+size_t ReadSysfs(const char* format, size_t lp, char* buf200) {
+ char path[200];
+ const int bytes_written = snprintf(path, sizeof(path), format, lp);
+ HWY_ASSERT(0 < bytes_written &&
+ bytes_written < static_cast<int>(sizeof(path) - 1));
+
+ const File file(path);
+ return file.Read(buf200);
+}
+
+// Interprets [str + pos, str + end) as base-10 ASCII. Stops when any non-digit
+// is found, or at end. Returns false if no digits found.
+bool ParseDigits(const char* str, const size_t end, size_t& pos, size_t* out) {
+ HWY_ASSERT(pos <= end);
+ const size_t start = pos; // to detect whether any digits were consumed
+ // 9 digits cannot overflow even 32-bit size_t.
+ const size_t stop = pos + 9;
+ *out = 0;
+ for (; pos < HWY_MIN(end, stop); ++pos) {
+ const int c = str[pos];
+ if (c < '0' || c > '9') break;
+ *out *= 10;
+ *out += static_cast<size_t>(c - '0');
+ }
+ if (pos == start) { // No digits found (pos may start nonzero)
+ *out = 0;
+ return false;
+ }
+ return true;
+}
+
+// Number, plus optional K or M suffix, plus terminator.
+bool ParseNumberWithOptionalSuffix(const char* str, size_t len, size_t* out) {
+ size_t pos = 0;
+ if (!ParseDigits(str, len, pos, out)) return false;
+ if (str[pos] == 'K') {
+ *out <<= 10;
+ ++pos;
+ }
+ if (str[pos] == 'M') {
+ *out <<= 20;
+ ++pos;
+ }
+ if (str[pos] != '\0' && str[pos] != '\n') {
+ HWY_ABORT("Expected [suffix] terminator at %zu %s\n", pos, str);
+ }
+ return true;
+}
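+
+// Worked examples: "64\n" parses to 64, "512K" to 512 << 10 == 524288, and
+// "3M" to 3 << 20 == 3145728.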
+
+bool ReadNumberWithOptionalSuffix(const char* format, size_t lp, size_t* out) {
+ char buf200[200];
+ const size_t pos = ReadSysfs(format, lp, buf200);
+ if (pos == 0) return false;
+ return ParseNumberWithOptionalSuffix(buf200, pos, out);
+}
+
+const char* kPackage =
+ "/sys/devices/system/cpu/cpu%zu/topology/physical_package_id";
+const char* kCluster = "/sys/devices/system/cpu/cpu%zu/cache/index3/id";
+const char* kCore = "/sys/devices/system/cpu/cpu%zu/topology/core_id";
+const char* kL2Size = "/sys/devices/system/cpu/cpu%zu/cache/index2/size";
+const char* kL3Size = "/sys/devices/system/cpu/cpu%zu/cache/index3/size";
+const char* kNode = "/sys/devices/system/node/node%zu/cpulist";
+
+// sysfs values can be arbitrarily large, so store in a map and replace with
+// indices in order of appearance.
+class Remapper {
+ public:
+ // Returns false on error, or sets `out_index` to the index of the sysfs
+ // value selected by `format` and `lp`.
+ template <typename T>
+ bool operator()(const char* format, size_t lp, T* HWY_RESTRICT out_index) {
+ size_t opaque;
+ if (!ReadNumberWithOptionalSuffix(format, lp, &opaque)) return false;
+
+ const auto ib = indices_.insert({opaque, num_});
+ num_ += ib.second; // increment if inserted
+ const size_t index = ib.first->second; // new or existing
+ HWY_ASSERT(index < num_);
+ HWY_ASSERT(index < hwy::LimitsMax<T>());
+ *out_index = static_cast<T>(index);
+ return true;
+ }
+
+ size_t Num() const { return num_; }
+
+ private:
+ std::map<size_t, size_t> indices_;
+ size_t num_ = 0;
+};
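+
+// Example: sysfs values 36, 36, 8 remap to indices 0, 0, 1; each distinct
+// value gets the next dense index in order of first appearance.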
+
+// For internal use by `DetectPackages`.
+struct PerPackage {
+ Remapper clusters;
+ Remapper cores;
+ // We rely on this zero-init and increment it below.
+ uint8_t smt_per_core[kMaxLogicalProcessors] = {0};
+};
+
+// Initializes `lps` and returns a PackageSizes vector (empty on failure)
+// indicating the number of clusters and cores per package.
+std::vector<PackageSizes> DetectPackages(std::vector<Topology::LP>& lps) {
+ std::vector<PackageSizes> empty;
+
+ Remapper packages;
+ for (size_t lp = 0; lp < lps.size(); ++lp) {
+ if (!packages(kPackage, lp, &lps[lp].package)) {
+ HWY_WARN("Failed to read sysfs package for LP %zu\n", lp);
+ return empty;
+ }
+ }
+ std::vector<PerPackage> per_package(packages.Num());
+ HWY_ASSERT(!per_package.empty());
+
+ for (size_t lp = 0; lp < lps.size(); ++lp) {
+ PerPackage& pp = per_package[lps[lp].package];
+ // Not a failure: some CPUs lack a (shared) L3 cache.
+ if (!pp.clusters(kCluster, lp, &lps[lp].cluster)) {
+ lps[lp].cluster = 0;
+ }
+
+ if (!pp.cores(kCore, lp, &lps[lp].core)) {
+ HWY_WARN("Failed to read sysfs core for LP %zu\n", lp);
+ return empty;
+ }
+
+ // SMT ID is how many LP we have already seen assigned to the same core.
+ HWY_ASSERT(lps[lp].core < kMaxLogicalProcessors);
+ lps[lp].smt = pp.smt_per_core[lps[lp].core]++;
+ HWY_ASSERT(lps[lp].smt < 16);
+ }
+
+ std::vector<PackageSizes> package_sizes(per_package.size());
+ for (size_t p = 0; p < package_sizes.size(); ++p) {
+ // Was zero if the package has no shared L3, see above.
+ package_sizes[p].num_clusters = HWY_MAX(1, per_package[p].clusters.Num());
+ package_sizes[p].num_cores = per_package[p].cores.Num();
+ HWY_ASSERT(package_sizes[p].num_cores != 0);
+ }
+ return package_sizes;
+}
+
+std::vector<size_t> ExpandList(const char* list, size_t list_end,
+ size_t max_lp) {
+ std::vector<size_t> expanded;
+ constexpr size_t kNotFound = ~size_t{0};
+ size_t pos = 0;
+
+ // Gracefully handle empty lists, which can happen on GH200 systems (#2668).
+ if (isspace(list[0]) && list_end <= 2) return expanded;
+
+ // Returns first `found_pos >= pos` where `list[found_pos] == c`, or
+ // `kNotFound`.
+ const auto find = [list, list_end, &pos](char c) -> size_t {
+ const char* found_ptr = strchr(list + pos, c);
+ if (found_ptr == nullptr) return kNotFound;
+ const size_t found_pos = static_cast<size_t>(found_ptr - list);
+ HWY_ASSERT(found_pos < list_end && list[found_pos] == c);
+ return found_pos;
+ };
+
+ // Reads LP number and advances `pos`. `end` is for verifying we did not
+ // read past a known terminator, or the end of string.
+ const auto parse_lp = [list, list_end, &pos, max_lp](size_t end) -> size_t {
+ end = HWY_MIN(end, list_end);
+ size_t lp;
+ HWY_ASSERT(ParseDigits(list, end, pos, &lp));
+ HWY_IF_CONSTEXPR(HWY_ARCH_RISCV) {
+ // On RISC-V, both TotalLogicalProcessors and GetThreadAffinity may
+ // under-report the count, hence clamp.
+ lp = HWY_MIN(lp, max_lp);
+ }
+ HWY_ASSERT(lp <= max_lp);
+ HWY_ASSERT(pos <= end);
+ return lp;
+ };
+
+ // Parse all [first-]last separated by commas.
+ for (;;) {
+ // Single number or first of range: ends with dash, comma, or end.
+ const size_t lp_range_first = parse_lp(HWY_MIN(find('-'), find(',')));
+
+ if (list[pos] == '-') { // range
+ ++pos; // skip dash
+ // Last of range ends with comma or end.
+ const size_t lp_range_last = parse_lp(find(','));
+
+ expanded.reserve(expanded.size() + lp_range_last - lp_range_first + 1);
+ for (size_t lp = lp_range_first; lp <= lp_range_last; ++lp) {
+ expanded.push_back(lp);
+ }
+ } else { // single number
+ expanded.push_back(lp_range_first);
+ }
+
+ // Done if reached end of string.
+ if (pos == list_end || list[pos] == '\0' || list[pos] == '\n') {
+ break;
+ }
+ // Comma means at least one more term is coming.
+ if (list[pos] == ',') {
+ ++pos;
+ continue;
+ }
+ HWY_ABORT("Unexpected character at %zu in %s\n", pos, list);
+ } // for pos
+
+ return expanded;
+}
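+
+// Example: ExpandList("0-2,5\n", 6, /*max_lp=*/7) returns {0, 1, 2, 5}.
+// Ranges are inclusive and comma-separated, per the kernel cpulist format.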
+
+// Sets LP.node for all `lps`.
+void SetNodes(std::vector<Topology::LP>& lps) {
+ // For each NUMA node found via sysfs:
+ for (size_t node = 0;; node++) {
+ // Read its cpulist so we can scatter `node` to all its `lps`.
+ char buf200[200];
+ const size_t bytes_read = ReadSysfs(kNode, node, buf200);
+ if (bytes_read == 0) break;
+ const std::vector<size_t> list =
+ ExpandList(buf200, bytes_read, lps.size() - 1);
+ for (size_t lp : list) {
+ lps[lp].node = static_cast<uint8_t>(node);
+ }
+ }
+}
+
+void SetClusterCacheSizes(std::vector<Topology::Package>& packages) {
+ for (size_t ip = 0; ip < packages.size(); ++ip) {
+ Topology::Package& p = packages[ip];
+ for (size_t ic = 0; ic < p.clusters.size(); ++ic) {
+ Topology::Cluster& c = p.clusters[ic];
+ const size_t lp = c.lps.First();
+ size_t bytes;
+ if (ReadNumberWithOptionalSuffix(kL2Size, lp, &bytes)) {
+ c.private_kib = bytes >> 10;
+ }
+ if (ReadNumberWithOptionalSuffix(kL3Size, lp, &bytes)) {
+ c.shared_kib = bytes >> 10;
+ }
+ }
+ }
+}
+
+#elif HWY_OS_WIN
+
+// See #2734. GroupCount was added around Windows 10, but SDK docs do not
+// mention the actual version required. It is known to be absent in 8.1 and
+// MinGW 5.0.1, and present in the 10.0.22000.0 SDK. However, the OS must also
+// know about the field. Thus we zero-initialize the reserved field, assume it
+// remains zero, and return 1 if zero (old style single GroupMask), otherwise
+// the number of groups. There are two such structures, but note that
+// `PROCESSOR_RELATIONSHIP` already had this field.
+static size_t GroupCount(const CACHE_RELATIONSHIP& cr) {
+ // Added as the last u16 in the reserved area before GroupMask. We only read
+ // one byte because 256*64 processor bits are plenty.
+ const uint8_t* pcount =
+ reinterpret_cast<const uint8_t*>(&cr.GroupMask) - sizeof(uint16_t);
+ return HWY_MAX(pcount[HWY_IS_BIG_ENDIAN], 1);
+}
+
+static size_t GroupCount(const NUMA_NODE_RELATIONSHIP& nn) {
+ const uint8_t* pcount =
+ reinterpret_cast<const uint8_t*>(&nn.GroupMask) - sizeof(uint16_t);
+ return HWY_MAX(pcount[HWY_IS_BIG_ENDIAN], 1);
+}
+
+struct PerPackage {
+ size_t clusters = 0;
+ size_t cores = 0;
+ size_t max_lps_per_core = 0;
+};
+
+// Returns per-package vector and assigns LP.package to an index within it.
+std::vector<PerPackage> AssignPackageIndices(std::vector<Topology::LP>& lps) {
+ size_t package_idx = 0;
+ (void)ForEachSLPI(
+ RelationProcessorPackage, [&lps, &package_idx](const SLPI& info) {
+ const PROCESSOR_RELATIONSHIP& p = info.Processor;
+ ForeachBit(p.GroupCount, p.GroupMask, lps, __LINE__,
+ [&package_idx](size_t lp, std::vector<Topology::LP>& lps) {
+ lps[lp].package = static_cast<uint8_t>(package_idx);
+ });
+ ++package_idx;
+ });
+ return std::vector<PerPackage>(package_idx);
+}
+
+// Sets LP.core and LP.smt and updates `PerPackage.cores/max_lps_per_core`.
+void AssignCoreSmtIndices(std::vector<Topology::LP>& lps,
+ std::vector<PerPackage>& per_package) {
+ (void)ForEachSLPI(
+ RelationProcessorCore, [&lps, &per_package](const SLPI& info) {
+ const PROCESSOR_RELATIONSHIP& p = info.Processor;
+ PerPackage* pp = nullptr;
+ // Foreach LP in this core: assign its core and smt.
+ size_t smt = 0;
+ ForeachBit(p.GroupCount, p.GroupMask, lps, __LINE__,
+ [&](size_t lp, std::vector<Topology::LP>& lps) {
+ pp = &per_package[lps[lp].package];
+ lps[lp].core = static_cast<uint16_t>(pp->cores);
+ lps[lp].smt = static_cast<uint8_t>(smt++);
+ });
+ HWY_ASSERT(pp != nullptr);
+ ++pp->cores;
+ pp->max_lps_per_core =
+ HWY_MAX(pp->max_lps_per_core, NumBits(p.GroupCount, p.GroupMask));
+ HWY_ASSERT(pp->max_lps_per_core != 0);
+ });
+}
+
+// Interprets cluster (typically a shared L3 cache) as a "processor die". Sets
+// LP.cluster and updates `PerPackage.clusters`.
+void AssignClusterIndices(std::vector<Topology::LP>& lps,
+ std::vector<PerPackage>& per_package) {
+ // Shared between `foreach_die` and `foreach_l3`. Assigns all LPs to this
+ // cluster and increments the cluster index.
+ const auto foreach_cluster = [&](size_t num_groups,
+ const GROUP_AFFINITY* groups) {
+ PerPackage* pp = nullptr;
+ ForeachBit(num_groups, groups, lps, __LINE__,
+ [&per_package, &pp](size_t lp, std::vector<Topology::LP>& lps) {
+ pp = &per_package[lps[lp].package];
+ lps[lp].cluster = static_cast<uint16_t>(pp->clusters);
+ });
+ if (pp != nullptr) {
+ ++pp->clusters;
+ }
+ };
+
+ // Passes group bits to `foreach_cluster`, depending on relationship type.
+ const auto foreach_die = [&foreach_cluster](const SLPI& info) {
+ const PROCESSOR_RELATIONSHIP& p = info.Processor;
+ foreach_cluster(p.GroupCount, p.GroupMask);
+ };
+ const auto foreach_l3 = [&foreach_cluster](const SLPI& info) {
+ const CACHE_RELATIONSHIP& cr = info.Cache;
+ if (cr.Type != CacheUnified && cr.Type != CacheData) return;
+ if (cr.Level != 3) return;
+ foreach_cluster(GroupCount(cr), cr.GroupMasks);
+ };
+
+ if (!ForEachSLPI(RelationProcessorDie, foreach_die)) {
+ // Has been observed to fail; also check for shared L3 caches.
+ (void)ForEachSLPI(RelationCache, foreach_l3);
+ }
+
+ // All packages should have the same number of clusters.
+ for (size_t package_idx = 1; package_idx < per_package.size();
+ ++package_idx) {
+ if (per_package[package_idx].clusters != per_package[0].clusters) {
+ HWY_ABORT("pkg %zu has %zu clusters, expected %zu\n", package_idx,
+ per_package[package_idx].clusters, per_package[0].clusters);
+ }
+ }
+
+ if (per_package[0].clusters == 0) {
+ HWY_WARN("No clusters found, assuming 1 cluster\n");
+ for (PerPackage& pp : per_package) {
+ pp.clusters = 1;
+ }
+ for (Topology::LP& lp : lps) {
+ lp.cluster = 0;
+ }
+ }
+}
+
+// Initializes `lps` and returns a `PackageSizes` vector (empty on failure)
+// indicating the number of clusters and cores per package.
+std::vector<PackageSizes> DetectPackages(std::vector<Topology::LP>& lps) {
+ std::vector<PerPackage> per_package = AssignPackageIndices(lps);
+ if (per_package.empty()) return {};
+ AssignCoreSmtIndices(lps, per_package);
+ AssignClusterIndices(lps, per_package);
+
+ std::vector<PackageSizes> packages(per_package.size());
+ for (size_t package_idx = 0; package_idx < per_package.size();
+ ++package_idx) {
+ packages[package_idx].num_clusters = per_package[package_idx].clusters;
+ packages[package_idx].num_cores = per_package[package_idx].cores;
+ }
+ return packages;
+}
+
+// Sets LP.node for all `lps`.
+void SetNodes(std::vector<Topology::LP>& lps) {
+ // Zero-initialize all nodes in case the below fails.
+ for (size_t lp = 0; lp < lps.size(); ++lp) {
+ lps[lp].node = 0;
+ }
+
+ // We want the full NUMA nodes, but Windows Server 2022 truncates the results
+ // of `RelationNumaNode` to a single 64-LP group. To get the old, unlimited
+ // behavior without using the new `RelationNumaNodeEx` symbol, use the old
+ // `RelationAll` and filter the SLPI we want.
+ (void)ForEachSLPI(RelationAll, [&](const SLPI& info) {
+ if (info.Relationship != RelationNumaNode) return;
+ const NUMA_NODE_RELATIONSHIP& nn = info.NumaNode;
+ // This field was previously reserved/zero. There is at least one group.
+ const size_t num_groups = HWY_MAX(1, GroupCount(nn));
+ const uint8_t node = static_cast<uint8_t>(nn.NodeNumber);
+ ForeachBit(num_groups, nn.GroupMasks, lps, __LINE__,
+ [node](size_t lp, std::vector<Topology::LP>& lps) {
+ lps[lp].node = node;
+ });
+ });
+}
+
+#elif HWY_OS_APPLE
+
+// Initializes `lps` and returns a `PackageSizes` vector (empty on failure)
+// indicating the number of clusters and cores per package.
+std::vector<PackageSizes> DetectPackages(std::vector<Topology::LP>& lps) {
+ int err;
+
+ size_t total_cores = 0;
+ if (!Sysctl("hw.perflevel0.physicalcpu", 1, err, &total_cores)) {
+ HWY_WARN("Error %d detecting total_cores, assuming one per LP\n", err);
+ total_cores = lps.size();
+ }
+
+ if (lps.size() % total_cores != 0) {
+ HWY_WARN("LPs %zu not a multiple of total_cores %zu\n", lps.size(),
+ total_cores);
+ }
+ const size_t lp_per_core = DivCeil(lps.size(), total_cores);
+
+ size_t cores_per_cluster = 0;
+ if (!Sysctl("hw.perflevel0.cpusperl2", 1, err, &cores_per_cluster)) {
+ HWY_WARN("Error %d detecting cores_per_cluster\n", err);
+ cores_per_cluster = HWY_MIN(4, total_cores);
+ }
+
+ if (total_cores % cores_per_cluster != 0) {
+ HWY_WARN("total_cores %zu not a multiple of cores_per_cluster %zu\n",
+ total_cores, cores_per_cluster);
+ }
+
+ for (size_t lp = 0; lp < lps.size(); ++lp) {
+ lps[lp].package = 0; // single package
+ lps[lp].core = static_cast<uint16_t>(lp / lp_per_core);
+ lps[lp].smt = static_cast<uint8_t>(lp % lp_per_core);
+ lps[lp].cluster = static_cast<uint16_t>(lps[lp].core / cores_per_cluster);
+ }
+
+ PackageSizes ps;
+ ps.num_clusters = DivCeil(total_cores, cores_per_cluster);
+ ps.num_cores = total_cores;
+ return std::vector<PackageSizes>{ps};
+}
+
+// Sets LP.node for all `lps`.
+void SetNodes(std::vector<Topology::LP>& lps) {
+ for (size_t lp = 0; lp < lps.size(); ++lp) {
+ lps[lp].node = 0; // no NUMA
+ }
+}
+
+#endif // HWY_OS_*
+
+#if HWY_OS_WIN || HWY_OS_APPLE
+
+void SetClusterCacheSizes(std::vector<Topology::Package>& packages) {
+ // Assumes clusters are homogeneous. Otherwise, we would have to scan
+ // `RelationCache` again and find the corresponding package_idx.
+ const Cache* caches = DataCaches();
+ const size_t private_kib = caches ? caches[2].size_kib : 0;
+ const size_t shared_kib = caches ? caches[3].size_kib : 0;
+
+ for (size_t ip = 0; ip < packages.size(); ++ip) {
+ Topology::Package& p = packages[ip];
+ for (size_t ic = 0; ic < p.clusters.size(); ++ic) {
+ Topology::Cluster& c = p.clusters[ic];
+ c.private_kib = private_kib;
+ c.shared_kib = shared_kib;
+ }
+ }
+}
+
+#endif // HWY_OS_WIN || HWY_OS_APPLE
+
+} // namespace
+
+HWY_CONTRIB_DLLEXPORT Topology::Topology() {
+#if HWY_OS_LINUX || HWY_OS_WIN || HWY_OS_APPLE
+ lps.resize(TotalLogicalProcessors());
+ const std::vector<PackageSizes>& package_sizes = DetectPackages(lps);
+ if (package_sizes.empty()) return;
+ SetNodes(lps);
+
+ // Allocate per-package/cluster/core vectors. This indicates to callers that
+ // detection succeeded.
+ packages.resize(package_sizes.size());
+ for (size_t p = 0; p < packages.size(); ++p) {
+ packages[p].clusters.resize(package_sizes[p].num_clusters);
+ packages[p].cores.resize(package_sizes[p].num_cores);
+ }
+
+ // Populate the per-cluster/core sets of LP.
+ for (size_t lp = 0; lp < lps.size(); ++lp) {
+ Package& p = packages[lps[lp].package];
+ p.clusters[lps[lp].cluster].lps.Set(lp);
+ p.cores[lps[lp].core].lps.Set(lp);
+ }
+
+ SetClusterCacheSizes(packages);
+#endif // HWY_OS_*
+}
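+
+// Example usage (a sketch; an empty `packages` means detection failed):
+//
+//   hwy::Topology topology;
+//   if (!topology.packages.empty()) {
+//     const size_t num_lps = topology.lps.size();
+//     const size_t num_clusters = topology.packages[0].clusters.size();
+//     // ... partition work across clusters ...
+//   }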
+
+// ------------------------------ Cache detection
+
+namespace {
+
+using Caches = std::array<Cache, 4>;
+
+// We assume homogeneous caches across all clusters because some OS APIs return
+// a single value for a class of CPUs.
+
+#if HWY_OS_LINUX
+std::string ReadString(const char* name, size_t index) {
+ // First CPU is usually a P core.
+ const std::string path("/sys/devices/system/cpu/cpu0/cache/index%zu/");
+ char buf200[200];
+ size_t end = ReadSysfs((path + name).c_str(), index, buf200);
+ // Remove trailing newline/null to simplify string comparison.
+ for (; end != 0; --end) {
+ if (buf200[end - 1] != '\0' && buf200[end - 1] != '\n') break;
+ }
+ return std::string(buf200, buf200 + end);
+}
+
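+// Note: despite the name, this reads a sysfs value and stores it in `out`;
+// returns false if no digits were parsed.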
+template <typename T>
+bool WriteSysfs(const char* name, size_t index, T* out) {
+ const std::string str = ReadString(name, index);
+ // Do not call `ParseNumberWithOptionalSuffix` because it acts on the
+ // K suffix in "size", but we actually want KiB.
+ size_t pos = 0;
+ size_t val;
+ if (!ParseDigits(str.c_str(), str.length(), pos, &val)) return false;
+ HWY_ASSERT(pos <= str.length());
+ *out = static_cast<T>(val);
+ return true;
+}
+
+// Reading from sysfs is preferred because sysconf returns L3 associativity = 0
+// on some CPUs, and does not indicate sharing across cores.
+// https://www.kernel.org/doc/Documentation/ABI/testing/sysfs-devices-system-cpu
+bool InitCachesSysfs(Caches& caches) {
+ // For computing shared cache sizes.
+ std::vector<hwy::Topology::LP> lps(TotalLogicalProcessors());
+ const std::vector<PackageSizes> package_sizes = DetectPackages(lps);
+ // `package_sizes` is only used to check that `lps` were filled.
+ if (package_sizes.empty()) {
+ HWY_WARN("no packages, shared cache sizes may be incorrect\n");
+ return false;
+ }
+
+ for (size_t i = 0;; ++i) {
+ const std::string type = ReadString("type", i);
+ if (type.empty()) break; // done, no more entries
+ if (type != "Data" && type != "Unified") continue;
+ uint32_t level;
+ if (!WriteSysfs("level", i, &level)) continue;
+ if (level != 1 && level != 2 && level != 3) continue;
+ Cache& c = caches[level];
+
+ // Check before overwriting any fields.
+ if (c.size_kib != 0) {
+ HWY_WARN("ignoring another L%u, first size %u\n", level, c.size_kib);
+ continue;
+ }
+
+ const bool ok = WriteSysfs("size", i, &c.size_kib) &&
+ WriteSysfs("ways_of_associativity", i, &c.associativity) &&
+ WriteSysfs("number_of_sets", i, &c.sets);
+ if (HWY_UNLIKELY(!ok)) {
+ HWY_WARN("skipping partially-detected L%u, error %d\n", level, errno);
+ c = Cache();
+ continue;
+ }
+
+ // Compute line size *before* adjusting the size for sharing. Note that
+ // `coherency_line_size` exists, but we are not sure that is the line size.
+ const size_t bytes = static_cast<size_t>(c.size_kib) * 1024;
+ const size_t lines = c.associativity * c.sets;
+ c.bytes_per_line = static_cast<uint16_t>(DivByFactor(bytes, lines));
+
+ // Divide by number of *cores* sharing the cache.
+ const std::string shared_str = ReadString("shared_cpu_list", i);
+ if (HWY_UNLIKELY(shared_str.empty())) {
+ HWY_WARN("no shared_cpu_list for L%u %s\n", level, type.c_str());
+ c.cores_sharing = 1;
+ } else {
+ const std::vector<size_t> shared_lps =
+ ExpandList(shared_str.c_str(), shared_str.length(), lps.size() - 1);
+ size_t num_cores = 0;
+ for (size_t lp : shared_lps) {
+ if (HWY_LIKELY(lp < lps.size())) {
+ num_cores += lps[lp].smt == 0;
+ } else {
+ HWY_WARN("out of bounds lp %zu of %zu from %s\n", lp, lps.size(),
+ shared_str.c_str());
+ }
+ }
+ if (num_cores == 0) {
+ HWY_WARN("no cores sharing L%u %s, setting to 1\n", level,
+ type.c_str());
+ num_cores = 1;
+ }
+ c.cores_sharing = static_cast<uint16_t>(num_cores);
+ // There exist CPUs for which L3 is not evenly divisible by `num_cores`,
+ // hence do not use `DivByFactor`. It is safer to round down.
+ c.size_kib = static_cast<uint32_t>(c.size_kib / num_cores);
+ c.sets = static_cast<uint32_t>(c.sets / num_cores);
+ }
+ }
+
+ // Require L1 and L2 cache.
+ if (HWY_UNLIKELY(caches[1].size_kib == 0 || caches[2].size_kib == 0)) {
+// Don't complain on Android because this is known to happen there. We are
+// unaware of good alternatives: `getauxval(AT_L1D_CACHEGEOMETRY)` and
+// `sysconf(_SC_LEVEL1_DCACHE_SIZE)` are unreliable, detecting via timing seems
+// difficult to do reliably, and we do not want to maintain lists of known CPUs
+// and their properties. It's OK to return false; callers are responsible for
+// assuming reasonable defaults.
+#ifndef __ANDROID__
+ HWY_WARN("sysfs detected L1=%u L2=%u, err %d\n", caches[1].size_kib,
+ caches[2].size_kib, errno);
+#endif
+ return false;
+ }
+
+ // L3 is optional; if not found, its size is already zero from static init.
+ return true;
+}
+
+#elif HWY_OS_WIN
+
+bool InitCachesWin(Caches& caches) {
+ std::vector<hwy::Topology::LP> lps(TotalLogicalProcessors());
+ std::vector<PerPackage> per_package = AssignPackageIndices(lps);
+ if (per_package.empty()) return false;
+ AssignCoreSmtIndices(lps, per_package);
+
+ (void)ForEachSLPI(RelationCache, [&per_package, &caches](const SLPI& info) {
+ const CACHE_RELATIONSHIP& cr = info.Cache;
+ if (cr.Type != CacheUnified && cr.Type != CacheData) return;
+ if (1 <= cr.Level && cr.Level <= 3) {
+ Cache& c = caches[cr.Level];
+ // If the size is non-zero then we (probably) have already detected this
+ // cache and can skip the CR.
+ if (c.size_kib > 0) return;
+ c.size_kib = static_cast<uint32_t>(DivByFactor(cr.CacheSize, 1024));
+ c.bytes_per_line = static_cast<uint16_t>(cr.LineSize);
+ c.associativity = (cr.Associativity == CACHE_FULLY_ASSOCIATIVE)
+ ? Cache::kMaxAssociativity
+ : cr.Associativity;
+
+ // How many cores share this cache?
+ size_t shared_with = NumBits(GroupCount(cr), cr.GroupMasks);
+ // Divide out hyperthreads. This core may have fewer than
+ // `max_lps_per_core`, hence round up.
+ shared_with = DivCeil(shared_with, per_package[0].max_lps_per_core);
+ if (shared_with == 0) {
+ HWY_WARN("no cores sharing L%u, setting to 1\n", cr.Level);
+ shared_with = 1;
+ }
+
+ // Update `size_kib` to *per-core* portion.
+ // There exist CPUs for which L3 is not evenly divisible by `shared_with`,
+ // hence do not use `DivByFactor`. It is safer to round down.
+ c.size_kib = static_cast<uint32_t>(c.size_kib / shared_with);
+ c.cores_sharing = static_cast<uint16_t>(shared_with);
+ }
+ });
+
+ // Require L1 and L2 cache.
+ if (HWY_UNLIKELY(caches[1].size_kib == 0 || caches[2].size_kib == 0)) {
+ HWY_WARN("Windows detected L1=%u, L2=%u, err %lx\n", caches[1].size_kib,
+ caches[2].size_kib, GetLastError());
+ return false;
+ }
+
+ // L3 is optional; if not found, its size is already zero from static init.
+ return true;
+}
+
+#elif HWY_OS_APPLE
+
+bool InitCachesApple(Caches& caches) {
+ int err = 0;
+ Cache& L1 = caches[1];
+ Cache& L2 = caches[2];
+ Cache& L3 = caches[3];
+
+ // Total L1 and L2 size can be reliably queried, but prefer perflevel0
+ // (P-cores) because hw.l1dcachesize etc. are documented to describe the
+ // "least performant core".
+ bool ok = Sysctl("hw.perflevel0.l1dcachesize", 1024, err, &L1.size_kib) ||
+ Sysctl("hw.l1dcachesize", 1024, err, &L1.size_kib);
+ ok &= Sysctl("hw.perflevel0.l2cachesize", 1024, err, &L2.size_kib) ||
+ Sysctl("hw.l2cachesize", 1024, err, &L2.size_kib);
+ if (HWY_UNLIKELY(!ok)) {
+ HWY_WARN("Apple cache detection failed, error %d\n", err);
+ return false;
+ }
+ L1.cores_sharing = 1;
+ if (Sysctl("hw.perflevel0.cpusperl2", 1, err, &L2.cores_sharing)) {
+ // There exist CPUs for which L2 is not evenly divisible by `cores_sharing`,
+ // hence do not use `DivByFactor`. It is safer to round down.
+ L2.size_kib /= L2.cores_sharing;
+ } else {
+ L2.cores_sharing = 1;
+ }
+
+ // Other properties are not always reported. Set `associativity` and
+ // `bytes_per_line` based on known models.
+ char brand[128] = {0};
+ size_t size = sizeof(brand);
+ if (!sysctlbyname("machdep.cpu.brand_string", brand, &size, nullptr, 0)) {
+ if (strncmp(brand, "Apple ", 6) != 0) {
+ // Unexpected, but we will continue checking the string suffixes.
+ HWY_WARN("unexpected Apple brand %s\n", brand);
+ }
+
+ if (brand[6] == 'M') {
+ // https://dougallj.github.io/applecpu/firestorm.html,
+ // https://www.7-cpu.com/cpu/Apple_M1.html:
+ L1.bytes_per_line = 64;
+ L1.associativity = 8;
+ L2.bytes_per_line = 128;
+ if (brand[7] == '1') { // M1
+ L2.associativity = 12;
+ } else if ('2' <= brand[7] && brand[7] <= '4') { // M2/M3, maybe also M4
+ L2.associativity = 16;
+ } else {
+ L2.associativity = 0; // Unknown, set below via sysctl.
+ }
+
+ // Although Wikipedia lists SLC sizes per model, we do not know how it is
+ // partitioned/allocated, so do not treat it as a reliable L3.
+ } // M*
+ } // brand string
+
+ // This sysctl does not distinguish between L1 and L2 line sizes, so only use
+ // it if we have not already set `bytes_per_line` above.
+ uint16_t bytes_per_line;
+ if (!Sysctl("hw.cachelinesize", 1, err, &bytes_per_line)) {
+ bytes_per_line = static_cast<uint16_t>(HWY_ALIGNMENT); // guess
+ }
+ for (size_t level = 1; level <= 3; ++level) {
+ if (caches[level].bytes_per_line == 0) {
+ caches[level].bytes_per_line = bytes_per_line;
+ }
+ }
+
+ // Fill in associativity if not already set. Unfortunately this is only
+ // reported on x86, not on M*.
+ if (L1.associativity == 0 && !Sysctl("machdep.cpu.cache.L1_associativity", 1,
+ err, &L1.associativity)) {
+ L1.associativity = 8; // guess
+ }
+ if (L2.associativity == 0 && !Sysctl("machdep.cpu.cache.L2_associativity", 1,
+ err, &L2.associativity)) {
+ L2.associativity = 12; // guess
+ }
+ // There is no L3_associativity.
+ if (L3.associativity == 0) {
+ L3.associativity = 12; // guess
+ }
+
+ // Now attempt to query L3. Although this sysctl is documented, M3 does not
+ // report an L3 cache.
+ if (L3.size_kib == 0 &&
+ (Sysctl("hw.perflevel0.l3cachesize", 1024, err, &L3.size_kib) ||
+ Sysctl("hw.l3cachesize", 1024, err, &L3.size_kib))) {
+ // There exist CPUs for which L3 is not evenly divisible by `cores_sharing`,
+ // hence do not use `DivByFactor`. It is safer to round down.
+ if (Sysctl("hw.perflevel0.cpusperl3", 1, err, &L3.cores_sharing)) {
+ L3.size_kib /= L3.cores_sharing;
+ } else {
+ L3.cores_sharing = 1;
+ }
+ }
+ // If no L3 cache, reset all fields for consistency.
+ if (L3.size_kib == 0) {
+ L3 = Cache();
+ }
+
+  // Are there other useful sysctls? hw.cacheconfig appears to report how many
+  // cores share the memory and caches, though this is not documented, and it
+  // duplicates information in hw.perflevel0.cpusperl*.
+
+ return true;
+}
+
+#endif // HWY_OS_*
+
+// Most APIs do not set the `sets` field, so compute it from the size and
+// associativity, and if a value is already set, ensure it matches.
+HWY_MAYBE_UNUSED void ComputeSets(Cache& c) {
+ // If there is no such cache, avoid division by zero.
+ if (HWY_UNLIKELY(c.size_kib == 0)) {
+ c.sets = 0;
+ return;
+ }
+ const size_t bytes = static_cast<size_t>(c.size_kib) * 1024;
+  // `size_kib` may have been rounded down, hence the `lines` and `sets`
+  // divisions are not necessarily exact, so round down instead of
+  // `DivByFactor`.
+ const size_t lines = bytes / c.bytes_per_line;
+ const size_t sets = lines / c.associativity;
+
+ if (c.sets == 0) {
+ c.sets = static_cast<uint32_t>(sets);
+ } else {
+ const size_t diff = c.sets - sets;
+ if (diff > 1) {
+ HWY_ABORT("Inconsistent cache sets %u != %zu\n", c.sets, sets);
+ }
+ }
+}
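To make the arithmetic concrete: a 1 MiB (1024 KiB) cache with 128-byte lines and 16-way associativity has 1024*1024/128 = 8192 lines and 8192/16 = 512 sets. A minimal standalone sketch of the same computation (the struct is a simplified stand-in, not the real `Cache`):

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Simplified stand-in for the Cache fields used by ComputeSets.
struct CacheSketch {
  uint32_t size_kib = 1024;       // 1 MiB
  uint16_t bytes_per_line = 128;
  uint16_t associativity = 16;
};

int main() {
  CacheSketch c;
  const size_t bytes = static_cast<size_t>(c.size_kib) * 1024;  // 1048576
  const size_t lines = bytes / c.bytes_per_line;                // 8192
  const size_t sets = lines / c.associativity;                  // 512
  printf("lines=%zu sets=%zu\n", lines, sets);
  return 0;
}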
+
+const Cache* InitDataCaches() {
+ alignas(64) static Caches caches;
+
+ // On failure, return immediately because InitCaches*() already warn.
+#if HWY_OS_LINUX
+ if (HWY_UNLIKELY(!InitCachesSysfs(caches))) return nullptr;
+#elif HWY_OS_WIN
+ if (HWY_UNLIKELY(!InitCachesWin(caches))) return nullptr;
+#elif HWY_OS_APPLE
+ if (HWY_UNLIKELY(!InitCachesApple(caches))) return nullptr;
+#else
+ HWY_WARN("Cache detection not implemented for this platform.\n");
+ (void)caches;
+ return nullptr;
+#define HWY_NO_CACHE_DETECTION
+#endif
+
+ // Prevents "code not reached" warnings on WASM.
+#ifndef HWY_NO_CACHE_DETECTION
+ for (size_t level = 1; level <= 3; ++level) {
+ ComputeSets(caches[level]);
+ }
+
+ // Heuristic to ignore SLCs such as on Ampere Altra, which should not be
+ // treated as a reliable L3 because of their cache inclusion policy.
+ // On Apple M*, these are not even reported as an L3.
+ if (caches[3].cores_sharing >= 16 && caches[3].size_kib <= 512) {
+ caches[3] = Cache();
+ }
+
+ return &caches[0];
+#endif // HWY_NO_CACHE_DETECTION
+}
+
+} // namespace
+
+HWY_CONTRIB_DLLEXPORT const Cache* DataCaches() {
+ static const Cache* caches = InitDataCaches();
+ return caches;
+}
+
+} // namespace hwy
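A short usage sketch for the API added above: per `InitDataCaches`, the returned pointer indexes levels 1..3 (index 0 is unused) and is nullptr when detection is unsupported or fails, so callers must handle both cases.

#include <cstddef>
#include <cstdio>

#include "third_party/highway/hwy/contrib/thread_pool/topology.h"

void ReportCaches() {
  const hwy::Cache* caches = hwy::DataCaches();
  if (caches == nullptr) {
    fprintf(stderr, "Cache detection unavailable; using defaults.\n");
    return;
  }
  for (size_t level = 1; level <= 3; ++level) {
    // L3 may be all-zero if absent (see InitDataCaches above).
    if (caches[level].size_kib == 0) continue;
    fprintf(stderr, "L%zu: %u KiB per core-group, %u B lines\n", level,
            caches[level].size_kib, caches[level].bytes_per_line);
  }
}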
diff --git a/third_party/highway/hwy/contrib/thread_pool/topology_test.cc b/third_party/highway/hwy/contrib/thread_pool/topology_test.cc
new file mode 100644
index 0000000000..447f8f0f22
--- /dev/null
+++ b/third_party/highway/hwy/contrib/thread_pool/topology_test.cc
@@ -0,0 +1,135 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "third_party/highway/hwy/contrib/thread_pool/topology.h"
+
+#include <stddef.h>
+#include <stdio.h>
+
+#include <vector>
+
+#include "third_party/highway/hwy/base.h"
+#include "third_party/highway/hwy/tests/hwy_gtest.h"
+#include "third_party/highway/hwy/tests/test_util-inl.h"
+#include "third_party/highway/hwy/timer.h"
+
+namespace hwy {
+namespace {
+
+TEST(TopologyTest, TestNum) {
+ const size_t total = TotalLogicalProcessors();
+ fprintf(stderr, "TotalLogical %zu\n", total);
+
+ LogicalProcessorSet lps;
+ if (GetThreadAffinity(lps)) {
+ fprintf(stderr, "Active %zu\n", lps.Count());
+ HWY_ASSERT(lps.Count() <= total);
+ }
+}
+
+TEST(TopologyTest, TestTopology) {
+ char cpu100[100];
+ if (hwy::platform::GetCpuString(cpu100)) {
+ fprintf(stderr, "%s\n", cpu100);
+ }
+
+ Topology topology;
+ if (topology.packages.empty()) return;
+
+ fprintf(stderr, "Topology: %zuP %zuX %zuC\n", topology.packages.size(),
+ topology.packages[0].clusters.size(),
+ topology.packages[0].clusters[0].lps.Count());
+
+ HWY_ASSERT(!topology.lps.empty());
+ LogicalProcessorSet nodes;
+ for (size_t lp = 0; lp < topology.lps.size(); ++lp) {
+ const size_t node = static_cast<size_t>(topology.lps[lp].node);
+ if (!nodes.Get(node)) {
+ fprintf(stderr, "Found NUMA node %zu, LP %zu\n", node, lp);
+ nodes.Set(node);
+ }
+ }
+
+ size_t lps_by_cluster = 0;
+ size_t lps_by_core = 0;
+ LogicalProcessorSet all_lps;
+ for (const Topology::Package& pkg : topology.packages) {
+ HWY_ASSERT(!pkg.clusters.empty());
+ HWY_ASSERT(!pkg.cores.empty());
+ HWY_ASSERT(pkg.clusters.size() <= pkg.cores.size());
+
+ for (const Topology::Cluster& c : pkg.clusters) {
+ lps_by_cluster += c.lps.Count();
+ c.lps.Foreach([&all_lps](size_t lp) { all_lps.Set(lp); });
+ }
+ for (const Topology::Core& c : pkg.cores) {
+ lps_by_core += c.lps.Count();
+ c.lps.Foreach([&all_lps](size_t lp) { all_lps.Set(lp); });
+ }
+ }
+ // Ensure the per-cluster and per-core sets sum to the total.
+ HWY_ASSERT(lps_by_cluster == topology.lps.size());
+ HWY_ASSERT(lps_by_core == topology.lps.size());
+  // ... and that the sets cover all LPs (no LP is missed).
+ HWY_ASSERT(all_lps.Count() == topology.lps.size());
+}
+
+void PrintCache(const Cache& c, size_t level) {
+ fprintf(stderr,
+ "L%zu: size %u KiB, line size %u, assoc %u, sets %u, cores %u\n",
+ level, c.size_kib, c.bytes_per_line, c.associativity, c.sets,
+ c.cores_sharing);
+}
+
+static void CheckCache(const Cache& c, size_t level) {
+ // L1-L2 must exist, L3 is not guaranteed.
+ if (level == 3 && c.size_kib == 0) {
+ HWY_ASSERT(c.associativity == 0 && c.bytes_per_line == 0 && c.sets == 0);
+ return;
+ }
+
+  // Size, and thus sets, are not necessarily powers of two.
+ HWY_ASSERT(c.size_kib != 0);
+ HWY_ASSERT(c.sets != 0);
+
+  // Intel Skylake has non-pow2 L3 associativity, as does Apple's L2, so we
+  // can only check loose bounds.
+ HWY_ASSERT(c.associativity >= 2);
+ HWY_ASSERT(c.associativity <= Cache::kMaxAssociativity);
+
+  // Line sizes are always powers of two because CPUs partition addresses into
+  // a line offset (the lower bits), set index, and tag.
+ const auto is_pow2 = [](uint32_t x) { return x != 0 && (x & (x - 1)) == 0; };
+ HWY_ASSERT(is_pow2(c.bytes_per_line));
+ HWY_ASSERT(32 <= c.bytes_per_line && c.bytes_per_line <= 1024);
+
+ HWY_ASSERT(c.cores_sharing != 0);
+ // +1 observed on RISC-V.
+ HWY_ASSERT(c.cores_sharing <= TotalLogicalProcessors() + 1);
+}
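To illustrate the offset/set/tag split asserted above, a minimal sketch assuming a 64-byte line and 64 sets (both powers of two, so the fields are contiguous bit ranges):

#include <cstdint>

constexpr uint64_t kBytesPerLine = 64;  // bits [0,6): offset within the line
constexpr uint64_t kSets = 64;          // bits [6,12): set index

constexpr uint64_t Offset(uint64_t addr) { return addr % kBytesPerLine; }
constexpr uint64_t SetIndex(uint64_t addr) {
  return (addr / kBytesPerLine) % kSets;
}
constexpr uint64_t Tag(uint64_t addr) { return addr / (kBytesPerLine * kSets); }

static_assert(Offset(0x12345) == 0x05, "low 6 bits");
static_assert(SetIndex(0x12345) == 0x0D, "next 6 bits");
static_assert(Tag(0x12345) == 0x12, "remaining bits");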
+
+TEST(TopologyTest, TestCaches) {
+ const Cache* caches = DataCaches();
+ if (!caches) return;
+ for (size_t level = 1; level <= 3; ++level) {
+ PrintCache(caches[level], level);
+ CheckCache(caches[level], level);
+ }
+}
+
+} // namespace
+} // namespace hwy
+
+HWY_TEST_MAIN();
diff --git a/third_party/highway/hwy/contrib/unroller/unroller-inl.h b/third_party/highway/hwy/contrib/unroller/unroller-inl.h
deleted file mode 100644
index 7008e7ef41..0000000000
--- a/third_party/highway/hwy/contrib/unroller/unroller-inl.h
+++ /dev/null
@@ -1,473 +0,0 @@
-// Copyright 2023 Matthew Kolbe
-// SPDX-License-Identifier: Apache-2.0
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#if defined(HIGHWAY_HWY_CONTRIB_UNROLLER_UNROLLER_INL_H_) == \
- defined(HWY_TARGET_TOGGLE)
-#ifdef HIGHWAY_HWY_CONTRIB_UNROLLER_UNROLLER_INL_H_
-#undef HIGHWAY_HWY_CONTRIB_UNROLLER_UNROLLER_INL_H_
-#else
-#define HIGHWAY_HWY_CONTRIB_UNROLLER_UNROLLER_INL_H_
-#endif
-
-#include <cstdlib> // std::abs
-
-#include "third_party/highway/hwy/highway.h"
-
-HWY_BEFORE_NAMESPACE();
-namespace hwy {
-namespace HWY_NAMESPACE {
-
-namespace hn = hwy::HWY_NAMESPACE;
-
-template <class DERIVED, typename IN_T, typename OUT_T>
-struct UnrollerUnit {
- static constexpr size_t kMaxTSize = HWY_MAX(sizeof(IN_T), sizeof(OUT_T));
- using LargerT = SignedFromSize<kMaxTSize>; // only the size matters.
-
- DERIVED* me() { return static_cast<DERIVED*>(this); }
-
- static constexpr size_t MaxUnitLanes() {
- return HWY_MAX_LANES_D(hn::ScalableTag<LargerT>);
- }
- static size_t ActualLanes() { return Lanes(hn::ScalableTag<LargerT>()); }
-
- using LargerD = hn::CappedTag<LargerT, MaxUnitLanes()>;
- using IT = hn::Rebind<IN_T, LargerD>;
- using OT = hn::Rebind<OUT_T, LargerD>;
- IT d_in;
- OT d_out;
- using Y_VEC = hn::Vec<OT>;
- using X_VEC = hn::Vec<IT>;
-
- Y_VEC Func(const ptrdiff_t idx, const X_VEC x, const Y_VEC y) {
- return me()->Func(idx, x, y);
- }
-
- X_VEC X0Init() { return me()->X0InitImpl(); }
-
- X_VEC X0InitImpl() { return hn::Zero(d_in); }
-
- Y_VEC YInit() { return me()->YInitImpl(); }
-
- Y_VEC YInitImpl() { return hn::Zero(d_out); }
-
- X_VEC Load(const ptrdiff_t idx, const IN_T* from) {
- return me()->LoadImpl(idx, from);
- }
-
- X_VEC LoadImpl(const ptrdiff_t idx, const IN_T* from) {
- return hn::LoadU(d_in, from + idx);
- }
-
-  // MaskLoad accepts either a positive or negative number for `places`. If
-  // the number is positive, it loads the first `places` lanes; if it is
-  // negative, it loads the last |places| lanes. Example: places = 3
-  // | o | o | o | x | x | x | x | x |
-  // Example: places = -3
-  // | x | x | x | x | x | o | o | o |
- X_VEC MaskLoad(const ptrdiff_t idx, const IN_T* from,
- const ptrdiff_t places) {
- return me()->MaskLoadImpl(idx, from, places);
- }
-
- X_VEC MaskLoadImpl(const ptrdiff_t idx, const IN_T* from,
- const ptrdiff_t places) {
- auto mask = hn::FirstN(d_in, static_cast<size_t>(places));
- auto maskneg = hn::Not(hn::FirstN(
- d_in,
- static_cast<size_t>(places + static_cast<ptrdiff_t>(ActualLanes()))));
- if (places < 0) mask = maskneg;
-
- return hn::MaskedLoad(mask, d_in, from + idx);
- }
-
- bool StoreAndShortCircuit(const ptrdiff_t idx, OUT_T* to, const Y_VEC x) {
- return me()->StoreAndShortCircuitImpl(idx, to, x);
- }
-
- bool StoreAndShortCircuitImpl(const ptrdiff_t idx, OUT_T* to, const Y_VEC x) {
- hn::StoreU(x, d_out, to + idx);
- return true;
- }
-
- ptrdiff_t MaskStore(const ptrdiff_t idx, OUT_T* to, const Y_VEC x,
- ptrdiff_t const places) {
- return me()->MaskStoreImpl(idx, to, x, places);
- }
-
- ptrdiff_t MaskStoreImpl(const ptrdiff_t idx, OUT_T* to, const Y_VEC x,
- const ptrdiff_t places) {
- auto mask = hn::FirstN(d_out, static_cast<size_t>(places));
- auto maskneg = hn::Not(hn::FirstN(
- d_out,
- static_cast<size_t>(places + static_cast<ptrdiff_t>(ActualLanes()))));
- if (places < 0) mask = maskneg;
-
- hn::BlendedStore(x, mask, d_out, to + idx);
- return std::abs(places);
- }
-
- ptrdiff_t Reduce(const Y_VEC x, OUT_T* to) { return me()->ReduceImpl(x, to); }
-
- ptrdiff_t ReduceImpl(const Y_VEC x, OUT_T* to) {
- // default does nothing
- (void)x;
- (void)to;
- return 0;
- }
-
- void Reduce(const Y_VEC x0, const Y_VEC x1, const Y_VEC x2, Y_VEC* y) {
- me()->ReduceImpl(x0, x1, x2, y);
- }
-
- void ReduceImpl(const Y_VEC x0, const Y_VEC x1, const Y_VEC x2, Y_VEC* y) {
- // default does nothing
- (void)x0;
- (void)x1;
- (void)x2;
- (void)y;
- }
-};
-
-template <class DERIVED, typename IN0_T, typename IN1_T, typename OUT_T>
-struct UnrollerUnit2D {
- DERIVED* me() { return static_cast<DERIVED*>(this); }
-
- static constexpr size_t kMaxTSize =
- HWY_MAX(sizeof(IN0_T), HWY_MAX(sizeof(IN1_T), sizeof(OUT_T)));
- using LargerT = SignedFromSize<kMaxTSize>; // only the size matters.
-
- static constexpr size_t MaxUnitLanes() {
- return HWY_MAX_LANES_D(hn::ScalableTag<LargerT>);
- }
- static size_t ActualLanes() { return Lanes(hn::ScalableTag<LargerT>()); }
-
- using LargerD = hn::CappedTag<LargerT, MaxUnitLanes()>;
-
- using I0T = hn::Rebind<IN0_T, LargerD>;
- using I1T = hn::Rebind<IN1_T, LargerD>;
- using OT = hn::Rebind<OUT_T, LargerD>;
- I0T d_in0;
- I1T d_in1;
- OT d_out;
- using Y_VEC = hn::Vec<OT>;
- using X0_VEC = hn::Vec<I0T>;
- using X1_VEC = hn::Vec<I1T>;
-
- hn::Vec<OT> Func(const ptrdiff_t idx, const hn::Vec<I0T> x0,
- const hn::Vec<I1T> x1, const Y_VEC y) {
- return me()->Func(idx, x0, x1, y);
- }
-
- X0_VEC X0Init() { return me()->X0InitImpl(); }
-
- X0_VEC X0InitImpl() { return hn::Zero(d_in0); }
-
- X1_VEC X1Init() { return me()->X1InitImpl(); }
-
- X1_VEC X1InitImpl() { return hn::Zero(d_in1); }
-
- Y_VEC YInit() { return me()->YInitImpl(); }
-
- Y_VEC YInitImpl() { return hn::Zero(d_out); }
-
- X0_VEC Load0(const ptrdiff_t idx, const IN0_T* from) {
- return me()->Load0Impl(idx, from);
- }
-
- X0_VEC Load0Impl(const ptrdiff_t idx, const IN0_T* from) {
- return hn::LoadU(d_in0, from + idx);
- }
-
- X1_VEC Load1(const ptrdiff_t idx, const IN1_T* from) {
- return me()->Load1Impl(idx, from);
- }
-
- X1_VEC Load1Impl(const ptrdiff_t idx, const IN1_T* from) {
- return hn::LoadU(d_in1, from + idx);
- }
-
-  // MaskLoad0/MaskLoad1 accept either a positive or negative number for
-  // `places`. If the number is positive, they load the first `places` lanes;
-  // if it is negative, they load the last |places| lanes. Example: places = 3
-  // | o | o | o | x | x | x | x | x |
-  // Example: places = -3
-  // | x | x | x | x | x | o | o | o |
- X0_VEC MaskLoad0(const ptrdiff_t idx, const IN0_T* from,
- const ptrdiff_t places) {
- return me()->MaskLoad0Impl(idx, from, places);
- }
-
- X0_VEC MaskLoad0Impl(const ptrdiff_t idx, const IN0_T* from,
- const ptrdiff_t places) {
- auto mask = hn::FirstN(d_in0, static_cast<size_t>(places));
- auto maskneg = hn::Not(hn::FirstN(
- d_in0,
- static_cast<size_t>(places + static_cast<ptrdiff_t>(ActualLanes()))));
- if (places < 0) mask = maskneg;
-
- return hn::MaskedLoad(mask, d_in0, from + idx);
- }
-
- hn::Vec<I1T> MaskLoad1(const ptrdiff_t idx, const IN1_T* from,
- const ptrdiff_t places) {
- return me()->MaskLoad1Impl(idx, from, places);
- }
-
- hn::Vec<I1T> MaskLoad1Impl(const ptrdiff_t idx, const IN1_T* from,
- const ptrdiff_t places) {
- auto mask = hn::FirstN(d_in1, static_cast<size_t>(places));
- auto maskneg = hn::Not(hn::FirstN(
- d_in1,
- static_cast<size_t>(places + static_cast<ptrdiff_t>(ActualLanes()))));
- if (places < 0) mask = maskneg;
-
- return hn::MaskedLoad(mask, d_in1, from + idx);
- }
-
-  // Store returns a bool that is `false` when the caller should stop early.
- bool StoreAndShortCircuit(const ptrdiff_t idx, OUT_T* to, const Y_VEC x) {
- return me()->StoreAndShortCircuitImpl(idx, to, x);
- }
-
- bool StoreAndShortCircuitImpl(const ptrdiff_t idx, OUT_T* to, const Y_VEC x) {
- hn::StoreU(x, d_out, to + idx);
- return true;
- }
-
- ptrdiff_t MaskStore(const ptrdiff_t idx, OUT_T* to, const Y_VEC x,
- const ptrdiff_t places) {
- return me()->MaskStoreImpl(idx, to, x, places);
- }
-
- ptrdiff_t MaskStoreImpl(const ptrdiff_t idx, OUT_T* to, const Y_VEC x,
- const ptrdiff_t places) {
- auto mask = hn::FirstN(d_out, static_cast<size_t>(places));
- auto maskneg = hn::Not(hn::FirstN(
- d_out,
- static_cast<size_t>(places + static_cast<ptrdiff_t>(ActualLanes()))));
- if (places < 0) mask = maskneg;
-
- hn::BlendedStore(x, mask, d_out, to + idx);
- return std::abs(places);
- }
-
- ptrdiff_t Reduce(const Y_VEC x, OUT_T* to) { return me()->ReduceImpl(x, to); }
-
- ptrdiff_t ReduceImpl(const Y_VEC x, OUT_T* to) {
- // default does nothing
- (void)x;
- (void)to;
- return 0;
- }
-
- void Reduce(const Y_VEC x0, const Y_VEC x1, const Y_VEC x2, Y_VEC* y) {
- me()->ReduceImpl(x0, x1, x2, y);
- }
-
- void ReduceImpl(const Y_VEC x0, const Y_VEC x1, const Y_VEC x2, Y_VEC* y) {
- // default does nothing
- (void)x0;
- (void)x1;
- (void)x2;
- (void)y;
- }
-};
-
-template <class FUNC, typename IN_T, typename OUT_T>
-inline void Unroller(FUNC& f, const IN_T* HWY_RESTRICT x, OUT_T* HWY_RESTRICT y,
- const ptrdiff_t n) {
- auto xx = f.X0Init();
- auto yy = f.YInit();
- ptrdiff_t i = 0;
-
-#if HWY_MEM_OPS_MIGHT_FAULT
- constexpr auto lane_sz =
- static_cast<ptrdiff_t>(RemoveRef<FUNC>::MaxUnitLanes());
- if (n < lane_sz) {
- const DFromV<decltype(yy)> d;
- // this may not fit on the stack for HWY_RVV, but we do not reach this code
- // there
- HWY_ALIGN IN_T xtmp[static_cast<size_t>(lane_sz)];
- HWY_ALIGN OUT_T ytmp[static_cast<size_t>(lane_sz)];
-
- CopyBytes(x, xtmp, static_cast<size_t>(n) * sizeof(IN_T));
- xx = f.MaskLoad(0, xtmp, n);
- yy = f.Func(0, xx, yy);
- Store(Zero(d), d, ytmp);
- i += f.MaskStore(0, ytmp, yy, n);
- i += f.Reduce(yy, ytmp);
- CopyBytes(ytmp, y, static_cast<size_t>(i) * sizeof(OUT_T));
- return;
- }
-#endif
-
- const ptrdiff_t actual_lanes =
- static_cast<ptrdiff_t>(RemoveRef<FUNC>::ActualLanes());
- if (n > 4 * actual_lanes) {
- auto xx1 = f.X0Init();
- auto yy1 = f.YInit();
- auto xx2 = f.X0Init();
- auto yy2 = f.YInit();
- auto xx3 = f.X0Init();
- auto yy3 = f.YInit();
-
- while (i + 4 * actual_lanes - 1 < n) {
- xx = f.Load(i, x);
- i += actual_lanes;
- xx1 = f.Load(i, x);
- i += actual_lanes;
- xx2 = f.Load(i, x);
- i += actual_lanes;
- xx3 = f.Load(i, x);
- i -= 3 * actual_lanes;
-
- yy = f.Func(i, xx, yy);
- yy1 = f.Func(i + actual_lanes, xx1, yy1);
- yy2 = f.Func(i + 2 * actual_lanes, xx2, yy2);
- yy3 = f.Func(i + 3 * actual_lanes, xx3, yy3);
-
- if (!f.StoreAndShortCircuit(i, y, yy)) return;
- i += actual_lanes;
- if (!f.StoreAndShortCircuit(i, y, yy1)) return;
- i += actual_lanes;
- if (!f.StoreAndShortCircuit(i, y, yy2)) return;
- i += actual_lanes;
- if (!f.StoreAndShortCircuit(i, y, yy3)) return;
- i += actual_lanes;
- }
-
- f.Reduce(yy3, yy2, yy1, &yy);
- }
-
- while (i + actual_lanes - 1 < n) {
- xx = f.Load(i, x);
- yy = f.Func(i, xx, yy);
- if (!f.StoreAndShortCircuit(i, y, yy)) return;
- i += actual_lanes;
- }
-
- if (i != n) {
- xx = f.MaskLoad(n - actual_lanes, x, i - n);
- yy = f.Func(n - actual_lanes, xx, yy);
- f.MaskStore(n - actual_lanes, y, yy, i - n);
- }
-
- f.Reduce(yy, y);
-}
-
-template <class FUNC, typename IN0_T, typename IN1_T, typename OUT_T>
-inline void Unroller(FUNC& HWY_RESTRICT f, IN0_T* HWY_RESTRICT x0,
- IN1_T* HWY_RESTRICT x1, OUT_T* HWY_RESTRICT y,
- const ptrdiff_t n) {
- const ptrdiff_t lane_sz =
- static_cast<ptrdiff_t>(RemoveRef<FUNC>::ActualLanes());
-
- auto xx00 = f.X0Init();
- auto xx10 = f.X1Init();
- auto yy = f.YInit();
-
- ptrdiff_t i = 0;
-
-#if HWY_MEM_OPS_MIGHT_FAULT
- if (n < lane_sz) {
- const DFromV<decltype(yy)> d;
- // this may not fit on the stack for HWY_RVV, but we do not reach this code
- // there
- constexpr auto max_lane_sz =
- static_cast<ptrdiff_t>(RemoveRef<FUNC>::MaxUnitLanes());
- HWY_ALIGN IN0_T xtmp0[static_cast<size_t>(max_lane_sz)];
- HWY_ALIGN IN1_T xtmp1[static_cast<size_t>(max_lane_sz)];
- HWY_ALIGN OUT_T ytmp[static_cast<size_t>(max_lane_sz)];
-
- CopyBytes(x0, xtmp0, static_cast<size_t>(n) * sizeof(IN0_T));
- CopyBytes(x1, xtmp1, static_cast<size_t>(n) * sizeof(IN1_T));
- xx00 = f.MaskLoad0(0, xtmp0, n);
- xx10 = f.MaskLoad1(0, xtmp1, n);
- yy = f.Func(0, xx00, xx10, yy);
- Store(Zero(d), d, ytmp);
- i += f.MaskStore(0, ytmp, yy, n);
- i += f.Reduce(yy, ytmp);
- CopyBytes(ytmp, y, static_cast<size_t>(i) * sizeof(OUT_T));
- return;
- }
-#endif
-
- if (n > 4 * lane_sz) {
- auto xx01 = f.X0Init();
- auto xx11 = f.X1Init();
- auto yy1 = f.YInit();
- auto xx02 = f.X0Init();
- auto xx12 = f.X1Init();
- auto yy2 = f.YInit();
- auto xx03 = f.X0Init();
- auto xx13 = f.X1Init();
- auto yy3 = f.YInit();
-
- while (i + 4 * lane_sz - 1 < n) {
- xx00 = f.Load0(i, x0);
- xx10 = f.Load1(i, x1);
- i += lane_sz;
- xx01 = f.Load0(i, x0);
- xx11 = f.Load1(i, x1);
- i += lane_sz;
- xx02 = f.Load0(i, x0);
- xx12 = f.Load1(i, x1);
- i += lane_sz;
- xx03 = f.Load0(i, x0);
- xx13 = f.Load1(i, x1);
- i -= 3 * lane_sz;
-
- yy = f.Func(i, xx00, xx10, yy);
- yy1 = f.Func(i + lane_sz, xx01, xx11, yy1);
- yy2 = f.Func(i + 2 * lane_sz, xx02, xx12, yy2);
- yy3 = f.Func(i + 3 * lane_sz, xx03, xx13, yy3);
-
- if (!f.StoreAndShortCircuit(i, y, yy)) return;
- i += lane_sz;
- if (!f.StoreAndShortCircuit(i, y, yy1)) return;
- i += lane_sz;
- if (!f.StoreAndShortCircuit(i, y, yy2)) return;
- i += lane_sz;
- if (!f.StoreAndShortCircuit(i, y, yy3)) return;
- i += lane_sz;
- }
-
- f.Reduce(yy3, yy2, yy1, &yy);
- }
-
- while (i + lane_sz - 1 < n) {
- xx00 = f.Load0(i, x0);
- xx10 = f.Load1(i, x1);
- yy = f.Func(i, xx00, xx10, yy);
- if (!f.StoreAndShortCircuit(i, y, yy)) return;
- i += lane_sz;
- }
-
- if (i != n) {
- xx00 = f.MaskLoad0(n - lane_sz, x0, i - n);
- xx10 = f.MaskLoad1(n - lane_sz, x1, i - n);
- yy = f.Func(n - lane_sz, xx00, xx10, yy);
- f.MaskStore(n - lane_sz, y, yy, i - n);
- }
-
- f.Reduce(yy, y);
-}
-
-} // namespace HWY_NAMESPACE
-} // namespace hwy
-HWY_AFTER_NAMESPACE();
-
-#endif // HIGHWAY_HWY_CONTRIB_UNROLLER_UNROLLER_INL_H_
diff --git a/third_party/highway/hwy/detect_compiler_arch.h b/third_party/highway/hwy/detect_compiler_arch.h
index 9d4d56b0a0..1810a5434d 100644
--- a/third_party/highway/hwy/detect_compiler_arch.h
+++ b/third_party/highway/hwy/detect_compiler_arch.h
@@ -20,9 +20,14 @@
// inclusion by foreach_target.h.
// Add to #if conditions to prevent IDE from graying out code.
+// Note for clangd users: clangd does not predefine an identifying macro, so
+// you must manually add these two lines (without the preceding '// ') to your
+// project's `.clangd` file:
+// CompileFlags:
+// Add: [-D__CLANGD__]
#if (defined __CDT_PARSER__) || (defined __INTELLISENSE__) || \
(defined Q_CREATOR_RUN) || (defined __CLANGD__) || \
- (defined GROK_ELLIPSIS_BUILD)
+ (defined GROK_ELLIPSIS_BUILD) || (defined __JETBRAINS_IDE__)
#define HWY_IDE 1
#else
#define HWY_IDE 0
@@ -65,15 +70,23 @@
#define HWY_COMPILER_GCC 0
#endif
-// Clang or clang-cl, not GCC.
-#ifdef __clang__
+#ifndef HWY_COMPILER_CLANG // Allow user override.
+#ifdef __clang__ // Clang or clang-cl, not GCC.
// In case of Apple LLVM (whose version number is unrelated to that of LLVM) or
// an invalid version number, deduce it from the presence of warnings.
// Originally based on
// https://github.com/simd-everywhere/simde/blob/47d6e603de9d04ee05cdfbc57cf282a02be1bf2a/simde/simde-detect-clang.h#L59.
// Please send updates below to them as well, thanks!
#if defined(__apple_build_version__) || __clang_major__ >= 999
-#if __has_warning("-Woverriding-option")
+#if __has_builtin(__builtin_elementwise_fshl)
+#define HWY_COMPILER_CLANG 2201
+#elif __has_builtin(__builtin_structured_binding_size)
+#define HWY_COMPILER_CLANG 2101
+#elif __has_builtin(__builtin_common_type)
+#define HWY_COMPILER_CLANG 2001
+#elif __has_warning("-Wreturn-mismatch")
+#define HWY_COMPILER_CLANG 1901
+#elif __has_warning("-Woverriding-option")
#define HWY_COMPILER_CLANG 1801
// No new warnings in 17.0, and Apple LLVM 15.3, which should be 1600, already
// has the unsafe_buffer_usage attribute, so we instead check for new builtins.
@@ -108,7 +121,6 @@
#else // Anything older than 7.0 is not recommended for Highway.
#define HWY_COMPILER_CLANG 600
#endif // __has_warning chain
-#define HWY_COMPILER3_CLANG (HWY_COMPILER_CLANG * 100)
#else // use normal version
#define HWY_COMPILER_CLANG (__clang_major__ * 100 + __clang_minor__)
#define HWY_COMPILER3_CLANG \
@@ -117,6 +129,12 @@
#else // Not clang
#define HWY_COMPILER_CLANG 0
#define HWY_COMPILER3_CLANG 0
+#endif // __clang__
+#endif // HWY_COMPILER_CLANG
+
+// User-defined or deduced HWY_COMPILER_CLANG: derive HWY_COMPILER3_CLANG.
+#ifndef HWY_COMPILER3_CLANG
+#define HWY_COMPILER3_CLANG (HWY_COMPILER_CLANG * 100)
#endif
#if HWY_COMPILER_GCC && !HWY_COMPILER_CLANG && !HWY_COMPILER_ICC && \
@@ -174,22 +192,97 @@
#define HWY_CXX_LANG __cplusplus
#endif
+// Use instead of constexpr to avoid compiler errors for older compilers. This
+// macro is for when a constexpr function involves multiple statements and
+// loops, which is allowed in C++14 but not before. If the compiler does not
+// support C++14 constexpr, this evaluates to nothing.
+#if defined(__cpp_constexpr) && __cpp_constexpr >= 201304L
+#define HWY_CXX14_CONSTEXPR constexpr
+#else
+#define HWY_CXX14_CONSTEXPR
+#endif
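A hedged sketch of the intended use: a multi-statement function with a loop, constexpr from C++14 onward and an ordinary function before that (the function name is illustrative):

#include <cstddef>

// constexpr under C++14+; still compiles (as a normal function) under C++11.
HWY_CXX14_CONSTEXPR size_t CountBits(size_t x) {
  size_t n = 0;
  for (; x != 0; x >>= 1) n += x & 1;  // loop: disallowed in C++11 constexpr
  return n;
}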
+
+// Same as above, but for C++17 constexpr, which adds support for lambdas, the
+// standard library, and capturing *this.
+// Note that C++17 constexpr still disallows allocating and virtual functions,
+// which are allowed in C++20, but we do not have a use case yet.
#if defined(__cpp_constexpr) && __cpp_constexpr >= 201603L
#define HWY_CXX17_CONSTEXPR constexpr
#else
#define HWY_CXX17_CONSTEXPR
#endif
-#if defined(__cpp_constexpr) && __cpp_constexpr >= 201304L
-#define HWY_CXX14_CONSTEXPR constexpr
+// Use instead of `if constexpr` to avoid compiler errors for older compilers.
+// When compilers lack C++17 support, this evaluates to a normal if statement.
+#if HWY_CXX_LANG >= 201703L || \
+ (defined(__cpp_if_constexpr) && __cpp_if_constexpr >= 201606L)
+#define HWY_IF_CONSTEXPR if constexpr
#else
-#define HWY_CXX14_CONSTEXPR
+#define HWY_IF_CONSTEXPR if
#endif
-#if HWY_CXX_LANG >= 201703L
-#define HWY_IF_CONSTEXPR if constexpr
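Sketch of use; note that when this falls back to a plain `if`, both branches are still instantiated, so they must compile for every template argument:

#include <cstddef>

template <typename T>
size_t LanesPer16Bytes() {
  HWY_IF_CONSTEXPR (sizeof(T) == 8) {
    return 2;  // two 64-bit lanes per 16-byte block
  }
  return 16 / sizeof(T);
}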
+// Use for constexpr variables at namespace scope in headers. Constexpr is
+// separate to allow using `HWY_CXX14_CONSTEXPR` if required.
+#ifndef HWY_INLINE_VAR
+#if __cplusplus > 201402L
+// C++17: mark as COMDAT to ensure linkers de-duplicate it. See
+// https://quuxplusone.github.io/blog/2022/07/08/inline-constexpr/
+#define HWY_INLINE_VAR inline
#else
-#define HWY_IF_CONSTEXPR if
+#define HWY_INLINE_VAR
+#endif
+#endif
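A sketch of the intended pattern for a header-scope constant (the name is illustrative):

// Under C++17 this is an inline variable (one COMDAT copy program-wide);
// before C++17 each translation unit gets its own internal-linkage copy.
HWY_INLINE_VAR constexpr int kMaxCacheLevels = 3;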
+
+//------------------------------------------------------------------------------
+// Sanitizers
+
+#if HWY_HAS_FEATURE(memory_sanitizer) || defined(MEMORY_SANITIZER) || \
+ defined(__SANITIZE_MEMORY__)
+#define HWY_IS_MSAN 1
+#else
+#define HWY_IS_MSAN 0
+#endif
+
+#if HWY_HAS_FEATURE(address_sanitizer) || defined(ADDRESS_SANITIZER) || \
+ defined(__SANITIZE_ADDRESS__)
+#define HWY_IS_ASAN 1
+#else
+#define HWY_IS_ASAN 0
+#endif
+
+#if HWY_HAS_FEATURE(hwaddress_sanitizer) || defined(HWADDRESS_SANITIZER) || \
+ defined(__SANITIZE_HWADDRESS__)
+#define HWY_IS_HWASAN 1
+#else
+#define HWY_IS_HWASAN 0
+#endif
+
+#if HWY_HAS_FEATURE(thread_sanitizer) || defined(THREAD_SANITIZER) || \
+ defined(__SANITIZE_THREAD__)
+#define HWY_IS_TSAN 1
+#else
+#define HWY_IS_TSAN 0
+#endif
+
+#if HWY_HAS_FEATURE(undefined_behavior_sanitizer) || \
+ defined(UNDEFINED_BEHAVIOR_SANITIZER)
+#define HWY_IS_UBSAN 1
+#else
+#define HWY_IS_UBSAN 0
+#endif
+
+// MSAN may cause lengthy build times or false positives e.g. in AVX3 DemoteTo.
+// You can disable MSAN by adding this attribute to the function that fails.
+#if HWY_IS_MSAN
+#define HWY_ATTR_NO_MSAN __attribute__((no_sanitize_memory))
+#else
+#define HWY_ATTR_NO_MSAN
+#endif
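Per the comment above, a hedged sketch of opting a single function out of MSAN (the function and its logic are illustrative):

#include <cstddef>
#include <cstdint>

// Only this function is exempt from MSAN instrumentation.
HWY_ATTR_NO_MSAN void CopyPossiblyUninit(uint8_t* dst, const uint8_t* src,
                                         size_t n) {
  for (size_t i = 0; i < n; ++i) dst[i] = src[i];
}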
+
+#if HWY_IS_ASAN || HWY_IS_HWASAN || HWY_IS_MSAN || HWY_IS_TSAN || HWY_IS_UBSAN
+#define HWY_IS_SANITIZER 1
+#else
+#define HWY_IS_SANITIZER 0
#endif
//------------------------------------------------------------------------------
@@ -217,13 +310,16 @@
#define HWY_ARCH_X86 0
#endif
-#if defined(__powerpc64__) || defined(_M_PPC) || defined(__powerpc__)
+// Apple uses __ppc__, MSVC uses _M_PPC.
+#if defined(__powerpc64__) || defined(_M_PPC) || defined(__powerpc__) || \
+ defined(__PPC__) || defined(__ppc__) || defined(__POWERPC__)
#define HWY_ARCH_PPC 1
#else
#define HWY_ARCH_PPC 0
#endif
-#if defined(__powerpc64__) || (HWY_ARCH_PPC && defined(__64BIT__))
+#if defined(__powerpc64__) || defined(__PPC64__) || defined(__ppc64__) || \
+ (HWY_ARCH_PPC && defined(__64BIT__))
#define HWY_ARCH_PPC_64 1
#else
#define HWY_ARCH_PPC_64 0
@@ -322,13 +418,36 @@
#define HWY_ARCH_LOONGARCH 0
#endif
+#if defined(__hexagon__) || defined(__HEXAGON_ARCH__)
+#define HWY_ARCH_HEXAGON 1
+#else
+#define HWY_ARCH_HEXAGON 0
+#endif
+
// It is an error to detect multiple architectures at the same time, but OK to
// detect none of the above.
-#if (HWY_ARCH_X86 + HWY_ARCH_PPC + HWY_ARCH_ARM + HWY_ARCH_ARM_OLD + \
- HWY_ARCH_WASM + HWY_ARCH_RISCV + HWY_ARCH_S390X + HWY_ARCH_LOONGARCH) > 1
+#if (HWY_ARCH_X86 + HWY_ARCH_PPC + HWY_ARCH_ARM + HWY_ARCH_ARM_OLD + \
+ HWY_ARCH_WASM + HWY_ARCH_RISCV + HWY_ARCH_S390X + HWY_ARCH_LOONGARCH + \
+ HWY_ARCH_HEXAGON) > 1
#error "Must not detect more than one architecture"
#endif
+#if HWY_ARCH_RISCV
+#define HWY_ARCH_MAX_BYTES 65536
+#elif HWY_ARCH_ARM_A64
+#define HWY_ARCH_MAX_BYTES 256
+#elif HWY_ARCH_HEXAGON
+#define HWY_ARCH_MAX_BYTES 128
+#elif HWY_ARCH_X86
+#define HWY_ARCH_MAX_BYTES 64
+#elif HWY_ARCH_WASM || HWY_ARCH_LOONGARCH
+#define HWY_ARCH_MAX_BYTES 32
+#elif HWY_ARCH_PPC || HWY_ARCH_S390X || HWY_ARCH_ARM_V7 || HWY_ARCH_ARM_OLD
+#define HWY_ARCH_MAX_BYTES 16
+#else
+#error "Missing case for HWY_ARCH_*"
+#endif
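HWY_ARCH_MAX_BYTES is a compile-time upper bound on vector length for the architecture; one hedged use is sizing a scratch buffer that must hold a full vector:

#include <cstdint>

// Scratch buffer large enough for one vector of any target on this arch.
// (On RISC-V this bound is 64 KiB, so prefer heap allocation there.)
alignas(64) static uint8_t g_scratch[HWY_ARCH_MAX_BYTES];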
+
//------------------------------------------------------------------------------
// Operating system
@@ -392,4 +511,20 @@
#error "Must only detect one byte order"
#endif
+//------------------------------------------------------------------------------
+// Features checked in set_macros-inl.h
+
+// Compiler supports ACLE __bf16, not necessarily with operators.
+//
+// Disable the __bf16 type on AArch64 with GCC 13 or earlier as there is a bug
+// in GCC 13 and earlier that sometimes causes BF16 constant values to be
+// incorrectly loaded on AArch64, and this GCC bug on AArch64 is
+// described at https://gcc.gnu.org/bugzilla/show_bug.cgi?id=111867.
+#if HWY_ARCH_ARM_A64 && \
+ (HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400)
+#define HWY_ARM_HAVE_SCALAR_BF16_TYPE 1
+#else
+#define HWY_ARM_HAVE_SCALAR_BF16_TYPE 0
+#endif
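A hedged sketch of consuming this macro: use the ACLE scalar type when available, otherwise a plain 16-bit carrier (the type name is illustrative):

#include <cstdint>

#if HWY_ARM_HAVE_SCALAR_BF16_TYPE
using ScalarBF16 = __bf16;  // ACLE type; operators are not guaranteed
#else
struct ScalarBF16 {
  uint16_t bits;  // raw BF16 payload; convert via float for arithmetic
};
#endif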
+
#endif // HIGHWAY_HWY_DETECT_COMPILER_ARCH_H_
diff --git a/third_party/highway/hwy/detect_targets.h b/third_party/highway/hwy/detect_targets.h
index 491f3ee8f9..7bff06d2ba 100644
--- a/third_party/highway/hwy/detect_targets.h
+++ b/third_party/highway/hwy/detect_targets.h
@@ -1,4 +1,5 @@
// Copyright 2021 Google LLC
+// Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
@@ -60,9 +61,9 @@
// --------------------------- x86: 15 targets (+ one fallback)
// Bits 0..2 reserved (3 targets)
-#define HWY_AVX10_2_512 (1LL << 3) // AVX10.2 with 512-bit vectors
+#define HWY_AVX10_2 (1LL << 3) // AVX10.2 with 512-bit vectors
#define HWY_AVX3_SPR (1LL << 4)
-#define HWY_AVX10_2 (1LL << 5) // AVX10.2 with 256-bit vectors
+// Bit 5: reserved (1 target)
// Currently `HWY_AVX3_DL` plus `AVX512BF16` and a special case for
// `CompressStore` (10x as fast, still useful on Zen5). We may later also use
// `VPCONFLICT`. Note that `VP2INTERSECT` is available in Zen5.
@@ -71,8 +72,8 @@
// Currently satisfiable by Ice Lake (`VNNI`, `VPCLMULQDQ`, `VPOPCNTDQ`,
// `VBMI`, `VBMI2`, `VAES`, `BITALG`, `GFNI`).
#define HWY_AVX3_DL (1LL << 7)
-#define HWY_AVX3 (1LL << 8) // HWY_AVX2 plus AVX-512F/BW/CD/DQ/VL
-#define HWY_AVX2 (1LL << 9) // HWY_SSE4 plus BMI2 + F16 + FMA
+#define HWY_AVX3 (1LL << 8) // HWY_AVX2 plus AVX-512F/BW/CD/DQ/VL
+#define HWY_AVX2 (1LL << 9) // HWY_SSE4 plus BMI2 + F16 + FMA
// Bit 10: reserved
#define HWY_SSE4 (1LL << 11) // SSE4.2 plus AES + CLMUL
#define HWY_SSSE3 (1LL << 12) // S-SSE3
@@ -92,7 +93,7 @@
#define HWY_SVE2 (1LL << 23)
#define HWY_SVE (1LL << 24)
// Bit 25 reserved for NEON
-#define HWY_NEON_BF16 (1LL << 26) // fp16/dot/bf16 (e.g. Neoverse V2/N2/N3)
+#define HWY_NEON_BF16 (1LL << 26) // fp16/dot/bf16 (e.g. Neoverse V2/N2)
// Bit 27 reserved for NEON
#define HWY_NEON (1LL << 28) // Implies support for AES
#define HWY_NEON_WITHOUT_AES (1LL << 29)
@@ -194,10 +195,21 @@
#endif
#endif // HWY_BROKEN_MSVC
+#ifndef HWY_BROKEN_AVX10_2 // allow override
+// AVX10_2 requires clang >= 20.1 (postpone to 23 due to "avx10.2-512" remnant,
+// only removed in https://github.com/llvm/llvm-project/pull/157034) or
+// gcc >= 15.2 with binutils 2.44.
+#if (HWY_COMPILER_CLANG < 2300) && (HWY_COMPILER_GCC_ACTUAL < 1502)
+#define HWY_BROKEN_AVX10_2 HWY_AVX10_2
+#else
+#define HWY_BROKEN_AVX10_2 0
+#endif
+#endif // HWY_BROKEN_AVX10_2
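As the `#ifndef` indicates, users may override this before including any Highway header, without any guarantee of success; a sketch:

// At your own risk: force-enable AVX10.2 despite the compiler-version check.
#define HWY_BROKEN_AVX10_2 0
#include "third_party/highway/hwy/highway.h"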
+
#ifndef HWY_BROKEN_AVX3_DL_ZEN4 // allow override
-// AVX3_DL and AVX3_ZEN4 require clang >= 7 (ensured above), gcc >= 8.1 or ICC
+// AVX3_DL and AVX3_ZEN4 require clang >= 7 (ensured above), gcc >= 10.1 or ICC
// 2021.
-#if (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 801) || \
+#if (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1001) || \
(HWY_COMPILER_ICC && HWY_COMPILER_ICC < 2021)
#define HWY_BROKEN_AVX3_DL_ZEN4 (HWY_AVX3_DL | HWY_AVX3_ZEN4)
#else
@@ -245,9 +257,10 @@
#endif // HWY_BROKEN_ARM7_WITHOUT_VFP4
#ifndef HWY_BROKEN_NEON_BF16 // allow override
-// HWY_NEON_BF16 requires recent compilers.
+// Broken on older compilers:
#if (HWY_COMPILER_CLANG != 0 && HWY_COMPILER_CLANG < 1700) || \
- (HWY_COMPILER_GCC_ACTUAL != 0 && HWY_COMPILER_GCC_ACTUAL < 1302)
+ (HWY_COMPILER_GCC_ACTUAL != 0 && HWY_COMPILER_GCC_ACTUAL < 1302) || \
+ (defined(__apple_build_version__) && __apple_build_version__ <= 17000000)
#define HWY_BROKEN_NEON_BF16 (HWY_NEON_BF16)
#else
#define HWY_BROKEN_NEON_BF16 0
@@ -257,11 +270,11 @@
// SVE[2] require recent clang or gcc versions.
#ifndef HWY_BROKEN_SVE // allow override
-// GCC 10+. Clang 19 still has many test failures for SVE. No Apple CPU (at
-// least up to and including M4 and A18) has SVE.
-#if (HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 2000) || \
+// Clang 22+, GCC 10+, except MSAN does not yet support SVE.
+// No Apple CPU (at least up to and including M4 and A18) has SVE.
+#if (HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 2200) || \
(HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1000) || \
- HWY_OS_APPLE
+ HWY_OS_APPLE || HWY_IS_MSAN
#define HWY_BROKEN_SVE (HWY_SVE | HWY_SVE_256)
#else
#define HWY_BROKEN_SVE 0
@@ -269,16 +282,28 @@
#endif // HWY_BROKEN_SVE
#ifndef HWY_BROKEN_SVE2 // allow override
-// Clang 19 still has many test failures for SVE2.
-#if (HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 2000) || \
+// Clang 22+, GCC 10+, except MSAN does not yet support SVE2.
+// No Apple CPU (at least up to and including M4 and A18) has SVE2.
+#if (HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 2200) || \
(HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1000) || \
- HWY_OS_APPLE
-#define HWY_BROKEN_SVE2 (HWY_SVE2 | HWY_SVE2_128)
+ HWY_OS_APPLE || HWY_IS_MSAN
+#define HWY_BROKEN_SVE2 (HWY_SVE2)
#else
#define HWY_BROKEN_SVE2 0
#endif
#endif // HWY_BROKEN_SVE2
+#ifndef HWY_BROKEN_SVE2_128 // allow override
+// GCC 10+. Clang 21 works for SVE2_128, but not for SVE2 nor MSAN.
+#if (HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 2100) || \
+ (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1000) || \
+ HWY_OS_APPLE || HWY_IS_MSAN
+#define HWY_BROKEN_SVE2_128 (HWY_SVE2_128)
+#else
+#define HWY_BROKEN_SVE2_128 0
+#endif
+#endif // HWY_BROKEN_SVE2_128
+
#ifndef HWY_BROKEN_PPC10 // allow override
#if (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1100)
// GCC 10 supports the -mcpu=power10 option but does not support the PPC10
@@ -330,11 +355,15 @@
#endif // HWY_BROKEN_RVV
#ifndef HWY_BROKEN_LOONGARCH // allow override
-// HWY_LSX/HWY_LASX require GCC 14 or Clang 18.
-#if HWY_ARCH_LOONGARCH && \
- ((HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 1800) || \
- (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1400))
+// Use the __loongarch_sx and __loongarch_asx macros to
+// check whether the LSX/LASX targets are available.
+// GCC does not work yet, see https://gcc.gnu.org/PR121875.
+#if !defined(__loongarch_sx) && \
+ !(HWY_COMPILER_CLANG && HWY_COMPILER_CLANG >= 1800)
#define HWY_BROKEN_LOONGARCH (HWY_LSX | HWY_LASX)
+#elif !defined(__loongarch_asx) && \
+ !(HWY_COMPILER_CLANG && HWY_COMPILER_CLANG >= 1800)
+#define HWY_BROKEN_LOONGARCH (HWY_LASX)
#else
#define HWY_BROKEN_LOONGARCH 0
#endif
@@ -359,13 +388,13 @@
// Allow the user to override this without any guarantee of success.
#ifndef HWY_BROKEN_TARGETS
-#define HWY_BROKEN_TARGETS \
- (HWY_BROKEN_CLANG6 | HWY_BROKEN_32BIT | HWY_BROKEN_MSVC | \
- HWY_BROKEN_AVX3_DL_ZEN4 | HWY_BROKEN_AVX3_SPR | \
- HWY_BROKEN_ARM7_BIG_ENDIAN | HWY_BROKEN_ARM7_WITHOUT_VFP4 | \
- HWY_BROKEN_NEON_BF16 | HWY_BROKEN_SVE | HWY_BROKEN_SVE2 | \
- HWY_BROKEN_PPC10 | HWY_BROKEN_PPC_32BIT | HWY_BROKEN_RVV | \
- HWY_BROKEN_LOONGARCH | HWY_BROKEN_Z14)
+#define HWY_BROKEN_TARGETS \
+ (HWY_BROKEN_CLANG6 | HWY_BROKEN_32BIT | HWY_BROKEN_MSVC | \
+ HWY_BROKEN_AVX10_2 | HWY_BROKEN_AVX3_DL_ZEN4 | HWY_BROKEN_AVX3_SPR | \
+ HWY_BROKEN_ARM7_BIG_ENDIAN | HWY_BROKEN_ARM7_WITHOUT_VFP4 | \
+ HWY_BROKEN_NEON_BF16 | HWY_BROKEN_SVE | HWY_BROKEN_SVE2 | \
+ HWY_BROKEN_SVE2_128 | HWY_BROKEN_PPC10 | HWY_BROKEN_PPC_32BIT | \
+ HWY_BROKEN_RVV | HWY_BROKEN_LOONGARCH | HWY_BROKEN_Z14)
#endif // HWY_BROKEN_TARGETS
@@ -379,7 +408,7 @@
// because it affects the fallback target, which must always be enabled. If 1,
// we instead choose HWY_SCALAR even without HWY_COMPILE_ONLY_SCALAR being set.
#if !defined(HWY_BROKEN_EMU128) // allow overriding
-#if (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1400) || \
+#if (HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1600) || \
defined(HWY_NO_LIBCXX)
#define HWY_BROKEN_EMU128 1
#else
@@ -488,7 +517,8 @@
#if defined(__ARM_FEATURE_AES) && \
defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC) && \
defined(__ARM_FEATURE_DOTPROD) && \
- defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC)
+ defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) && \
+ defined(__ARM_FEATURE_MATMUL_INT8)
#define HWY_BASELINE_NEON HWY_ALL_NEON
#elif defined(__ARM_FEATURE_AES)
#define HWY_BASELINE_NEON (HWY_NEON_WITHOUT_AES | HWY_NEON)
@@ -600,10 +630,12 @@
#endif
// Require everything in AVX2 plus AVX-512 flags (also set by MSVC)
-#if HWY_BASELINE_AVX2 != 0 && defined(__AVX512F__) && defined(__AVX512BW__) && \
- defined(__AVX512DQ__) && defined(__AVX512VL__) && \
- ((!HWY_COMPILER_GCC_ACTUAL && !HWY_COMPILER_CLANG) || \
- HWY_COMPILER_GCC_ACTUAL < 1400 || HWY_COMPILER_CLANG < 1800 || \
+#if HWY_BASELINE_AVX2 != 0 && \
+ ((defined(__AVX512F__) && defined(__AVX512BW__) && \
+ defined(__AVX512DQ__) && defined(__AVX512VL__)) || \
+ defined(__AVX10_2__)) && \
+ ((!HWY_COMPILER_GCC_ACTUAL && !HWY_COMPILER_CLANG) || \
+ HWY_COMPILER_GCC_ACTUAL < 1400 || HWY_COMPILER_CLANG < 1800 || \
defined(__EVEX512__))
#define HWY_BASELINE_AVX3 HWY_AVX3
#else
@@ -611,10 +643,12 @@
#endif
// TODO(janwas): not yet known whether these will be set by MSVC
-#if HWY_BASELINE_AVX3 != 0 && defined(__AVX512VNNI__) && defined(__VAES__) && \
- defined(__VPCLMULQDQ__) && defined(__AVX512VBMI__) && \
- defined(__AVX512VBMI2__) && defined(__AVX512VPOPCNTDQ__) && \
- defined(__AVX512BITALG__)
+#if HWY_BASELINE_AVX3 != 0 && \
+ ((defined(__AVX512VNNI__) && defined(__VAES__) && \
+ defined(__VPCLMULQDQ__) && defined(__AVX512VBMI__) && \
+ defined(__AVX512VBMI2__) && defined(__AVX512VPOPCNTDQ__) && \
+ defined(__AVX512BITALG__)) || \
+ defined(__AVX10_2__))
#define HWY_BASELINE_AVX3_DL HWY_AVX3_DL
#else
#define HWY_BASELINE_AVX3_DL 0
@@ -629,27 +663,26 @@
#define HWY_BASELINE_AVX3_ZEN4 0
#endif
-#if HWY_BASELINE_AVX2 != 0 && defined(__AVX10_2__)
-#define HWY_BASELINE_AVX10_2 HWY_AVX10_2
-#else
-#define HWY_BASELINE_AVX10_2 0
-#endif
-
-#if HWY_BASELINE_AVX3_DL != 0 && defined(__AVX512BF16__) && \
- defined(__AVX512FP16__)
+#if HWY_BASELINE_AVX3_DL != 0 && \
+ ((defined(__AVX512BF16__) && defined(__AVX512FP16__)) || \
+ defined(__AVX10_2__))
#define HWY_BASELINE_AVX3_SPR HWY_AVX3_SPR
#else
#define HWY_BASELINE_AVX3_SPR 0
#endif
-#if HWY_BASELINE_AVX3_SPR != 0 && defined(__AVX10_2_512__)
-#define HWY_BASELINE_AVX10_2_512 HWY_AVX10_2_512
+#if HWY_BASELINE_AVX3_SPR != 0 && defined(__AVX10_2__)
+#define HWY_BASELINE_AVX10_2 HWY_AVX10_2
#else
-#define HWY_BASELINE_AVX10_2_512 0
+#define HWY_BASELINE_AVX10_2 0
#endif
// RVV requires intrinsics 0.11 or later, see #1156.
-#if HWY_ARCH_RISCV && defined(__riscv_v_intrinsic) && \
+
+// Also check that the __riscv_v macro is defined, as GCC and Clang only
+// define __riscv_v if the RISC-V "V" extension is enabled.
+
+#if HWY_ARCH_RISCV && defined(__riscv_v) && defined(__riscv_v_intrinsic) && \
__riscv_v_intrinsic >= 11000
#define HWY_BASELINE_RVV HWY_RVV
#else
@@ -664,17 +697,29 @@
#define HWY_BASELINE_LOONGARCH 0
#endif
-// Allow the user to override this without any guarantee of success.
+// Workaround for libaom, which unconditionally defines HWY_BASELINE_TARGETS
+// even when that would be disabled/broken. If so, at least use AVX2.
+#if defined(HWY_BASELINE_TARGETS)
+#if HWY_BASELINE_TARGETS == HWY_AVX3_DL && \
+ ((HWY_BROKEN_TARGETS | HWY_DISABLED_TARGETS) & HWY_AVX3_DL)
+#undef HWY_BASELINE_TARGETS
+#define HWY_BASELINE_TARGETS HWY_AVX2
+#endif
+#endif // HWY_BASELINE_TARGETS
+
+// Allow the user to override this without any guarantee of success. If the
+// compiler invocation considers that target to be broken/disabled, then
+// `HWY_ENABLED_BASELINE` will be 0 and users will have to check for that and
+// skip their code.
#ifndef HWY_BASELINE_TARGETS
-#define HWY_BASELINE_TARGETS \
- (HWY_BASELINE_SCALAR | HWY_BASELINE_WASM | HWY_BASELINE_PPC8 | \
- HWY_BASELINE_PPC9 | HWY_BASELINE_PPC10 | HWY_BASELINE_Z14 | \
- HWY_BASELINE_Z15 | HWY_BASELINE_SVE2 | HWY_BASELINE_SVE | \
- HWY_BASELINE_NEON | HWY_BASELINE_SSE2 | HWY_BASELINE_SSSE3 | \
- HWY_BASELINE_SSE4 | HWY_BASELINE_AVX2 | HWY_BASELINE_AVX3 | \
- HWY_BASELINE_AVX3_DL | HWY_BASELINE_AVX3_ZEN4 | HWY_BASELINE_AVX10_2 | \
- HWY_BASELINE_AVX3_SPR | HWY_BASELINE_AVX10_2_512 | HWY_BASELINE_RVV | \
- HWY_BASELINE_LOONGARCH)
+#define HWY_BASELINE_TARGETS \
+ (HWY_BASELINE_SCALAR | HWY_BASELINE_WASM | HWY_BASELINE_PPC8 | \
+ HWY_BASELINE_PPC9 | HWY_BASELINE_PPC10 | HWY_BASELINE_Z14 | \
+ HWY_BASELINE_Z15 | HWY_BASELINE_SVE2 | HWY_BASELINE_SVE | \
+ HWY_BASELINE_NEON | HWY_BASELINE_SSE2 | HWY_BASELINE_SSSE3 | \
+ HWY_BASELINE_SSE4 | HWY_BASELINE_AVX2 | HWY_BASELINE_AVX3 | \
+ HWY_BASELINE_AVX3_DL | HWY_BASELINE_AVX3_ZEN4 | HWY_BASELINE_AVX3_SPR | \
+ HWY_BASELINE_AVX10_2 | HWY_BASELINE_RVV | HWY_BASELINE_LOONGARCH)
#endif // HWY_BASELINE_TARGETS
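A sketch of the override-and-check pattern described above; per the comment, a user-forced baseline can still be rejected, so HWY_ENABLED_BASELINE must be tested:

// Request AVX2 as the only baseline (no guarantee it survives the checks).
#define HWY_BASELINE_TARGETS HWY_AVX2
#include "third_party/highway/hwy/highway.h"

#if HWY_ENABLED_BASELINE == 0
// The request was broken/disabled: skip all SIMD usage in this TU.
#endif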
//------------------------------------------------------------------------------
@@ -682,7 +727,11 @@
#define HWY_ENABLED_BASELINE HWY_ENABLED(HWY_BASELINE_TARGETS)
#if HWY_ENABLED_BASELINE == 0
-#error "At least one baseline target must be defined and enabled"
+#pragma message \
+    "All baseline targets are disabled or considered broken. " \
+    "This is typically due to very restrictive HWY_BASELINE_TARGETS, or " \
+    "too expansive HWY_BROKEN_TARGETS or HWY_DISABLED_TARGETS. User code " \
+    "must also check for this and skip any usage of SIMD."
#endif
// Best baseline, used for static dispatch. This is the least-significant 1-bit
@@ -698,13 +747,6 @@
//------------------------------------------------------------------------------
// Choose targets for dynamic dispatch according to one of four policies
-// TODO: remove once HWY_LSX is actually supported
-#if HWY_ARCH_LOONGARCH && !defined(HWY_COMPILE_ONLY_SCALAR) && \
- !defined(HWY_COMPILE_ONLY_EMU128)
-#undef HWY_COMPILE_ONLY_STATIC
-#define HWY_COMPILE_ONLY_EMU128
-#endif
-
#if 1 < (defined(HWY_COMPILE_ONLY_SCALAR) + defined(HWY_COMPILE_ONLY_EMU128) + \
defined(HWY_COMPILE_ONLY_STATIC))
#error "Can only define one of HWY_COMPILE_ONLY_{SCALAR|EMU128|STATIC} - bug?"
@@ -747,12 +789,10 @@
#endif // HWY_HAVE_AUXV
#ifndef HWY_HAVE_RUNTIME_DISPATCH_RVV // allow override
-// The riscv_vector.h in Clang 16-18 requires compiler flags, and 19 still has
-// some missing intrinsics, see
-// https://github.com/llvm/llvm-project/issues/56592. GCC 13.3 also has an
-// #error check, whereas 14.1 fails with "argument type 'vuint16m8_t' requires
-// the V ISA extension": https://gcc.gnu.org/bugzilla/show_bug.cgi?id=115325.
-#if HWY_ARCH_RISCV && HWY_COMPILER_CLANG >= 1900 && 0
+// Clang 19+ supports target attributes for RVV intrinsics (resolved in
+// https://github.com/llvm/llvm-project/issues/56592 and
+// https://gcc.gnu.org/bugzilla/show_bug.cgi?id=115325).
+#if HWY_ARCH_RISCV && HWY_COMPILER_CLANG >= 1900
#define HWY_HAVE_RUNTIME_DISPATCH_RVV 1
#else
#define HWY_HAVE_RUNTIME_DISPATCH_RVV 0
@@ -768,6 +808,15 @@
#endif
#endif // HWY_HAVE_RUNTIME_DISPATCH_APPLE
+#ifndef HWY_HAVE_RUNTIME_DISPATCH_LOONGARCH // allow override
+#if HWY_ARCH_LOONGARCH && HWY_HAVE_AUXV && !defined(__loongarch_asx) && \
+ HWY_COMPILER_CLANG && HWY_COMPILER_CLANG >= 1800
+#define HWY_HAVE_RUNTIME_DISPATCH_LOONGARCH 1
+#else
+#define HWY_HAVE_RUNTIME_DISPATCH_LOONGARCH 0
+#endif
+#endif // HWY_HAVE_RUNTIME_DISPATCH_LOONGARCH
+
#ifndef HWY_HAVE_RUNTIME_DISPATCH_LINUX // allow override
#if (HWY_ARCH_ARM || HWY_ARCH_PPC || HWY_ARCH_S390X) && HWY_OS_LINUX && \
(HWY_COMPILER_GCC_ACTUAL || HWY_COMPILER_CLANG >= 1700) && HWY_HAVE_AUXV
@@ -780,8 +829,12 @@
// Allow opting out, and without a guarantee of success, opting-in.
#ifndef HWY_HAVE_RUNTIME_DISPATCH
// Clang, GCC and MSVC allow OS-independent runtime dispatch on x86.
-#if HWY_ARCH_X86 || HWY_HAVE_RUNTIME_DISPATCH_RVV || \
- HWY_HAVE_RUNTIME_DISPATCH_APPLE || HWY_HAVE_RUNTIME_DISPATCH_LINUX
+// Wasm does not, because browsers reject a binary containing any SIMD
+// instructions when the browser does not support them. Typical practice there
+// is to build two binaries, one with the -msimd128 flag.
+#if HWY_ARCH_X86 || HWY_HAVE_RUNTIME_DISPATCH_RVV || \
+ HWY_HAVE_RUNTIME_DISPATCH_APPLE || HWY_HAVE_RUNTIME_DISPATCH_LOONGARCH || \
+ HWY_HAVE_RUNTIME_DISPATCH_LINUX
#define HWY_HAVE_RUNTIME_DISPATCH 1
#else
#define HWY_HAVE_RUNTIME_DISPATCH 0
@@ -859,7 +912,7 @@
#define HWY_ATTAINABLE_TARGETS_X86 \
HWY_ENABLED(HWY_BASELINE_SCALAR | HWY_SSE2 | HWY_SSSE3 | HWY_SSE4 | \
HWY_AVX2 | HWY_AVX3 | HWY_AVX3_DL | HWY_AVX3_ZEN4 | \
- HWY_AVX3_SPR)
+ HWY_AVX3_SPR | HWY_AVX10_2)
#endif // !HWY_COMPILER_MSVC
#endif // HWY_ATTAINABLE_TARGETS_X86
@@ -923,7 +976,7 @@
// HWY_ONCE and the multiple-inclusion mechanism rely on HWY_STATIC_TARGET being
// one of the dynamic targets. This also implies HWY_TARGETS != 0 and
// (HWY_TARGETS & HWY_ENABLED_BASELINE) != 0.
-#if (HWY_TARGETS & HWY_STATIC_TARGET) == 0
+#if (HWY_TARGETS & HWY_STATIC_TARGET) == 0 && HWY_ENABLED_BASELINE != 0
#error "Logic error: best baseline should be included in dynamic targets"
#endif
diff --git a/third_party/highway/hwy/examples/skeleton-inl.h b/third_party/highway/hwy/examples/skeleton-inl.h
deleted file mode 100644
index 227ef462e5..0000000000
--- a/third_party/highway/hwy/examples/skeleton-inl.h
+++ /dev/null
@@ -1,64 +0,0 @@
-// Copyright 2020 Google LLC
-// SPDX-License-Identifier: Apache-2.0
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Demo of functions that might be called from multiple SIMD modules (either
-// other -inl.h files, or a .cc file between begin/end_target-inl). This is
-// optional - all SIMD code can reside in .cc files. However, this allows
-// splitting code into different files while still inlining instead of requiring
-// calling through function pointers.
-
-// Per-target include guard. This is only required when using dynamic dispatch,
-// i.e. including foreach_target.h. For static dispatch, a normal include
-// guard would be fine because the header is only compiled once.
-#if defined(HIGHWAY_HWY_EXAMPLES_SKELETON_INL_H_) == defined(HWY_TARGET_TOGGLE)
-#ifdef HIGHWAY_HWY_EXAMPLES_SKELETON_INL_H_
-#undef HIGHWAY_HWY_EXAMPLES_SKELETON_INL_H_
-#else
-#define HIGHWAY_HWY_EXAMPLES_SKELETON_INL_H_
-#endif
-
-// It is fine to #include normal or *-inl headers.
-#include "third_party/highway/hwy/highway.h"
-
-HWY_BEFORE_NAMESPACE();
-namespace skeleton {
-namespace HWY_NAMESPACE {
-
-// Highway ops reside here; ADL does not find templates nor builtins.
-namespace hn = hwy::HWY_NAMESPACE;
-
-// Example of a type-agnostic (caller-specified lane type) and width-agnostic
-// (uses best available instruction set) function in a header.
-//
-// Computes x[i] = mul_array[i] * x_array[i] + add_array[i] for i < size.
-template <class D, typename T>
-HWY_MAYBE_UNUSED void MulAddLoop(const D d, const T* HWY_RESTRICT mul_array,
- const T* HWY_RESTRICT add_array,
- const size_t size, T* HWY_RESTRICT x_array) {
- for (size_t i = 0; i < size; i += hn::Lanes(d)) {
- const auto mul = hn::Load(d, mul_array + i);
- const auto add = hn::Load(d, add_array + i);
- auto x = hn::Load(d, x_array + i);
- x = hn::MulAdd(mul, x, add);
- hn::Store(x, d, x_array + i);
- }
-}
-
-// NOLINTNEXTLINE(google-readability-namespace-comments)
-} // namespace HWY_NAMESPACE
-} // namespace skeleton
-HWY_AFTER_NAMESPACE();
-
-#endif // include guard
diff --git a/third_party/highway/hwy/examples/skeleton.h b/third_party/highway/hwy/examples/skeleton.h
deleted file mode 100644
index 55e15a49dc..0000000000
--- a/third_party/highway/hwy/examples/skeleton.h
+++ /dev/null
@@ -1,38 +0,0 @@
-// Copyright 2020 Google LLC
-// SPDX-License-Identifier: Apache-2.0
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Demo interface to target-specific code in skeleton.cc
-
-// Normal header with include guard and namespace.
-#ifndef HIGHWAY_HWY_EXAMPLES_SKELETON_H_
-#define HIGHWAY_HWY_EXAMPLES_SKELETON_H_
-
-// Platform-specific definitions used for declaring an interface, independent of
-// the SIMD instruction set.
-#include "third_party/highway/hwy/base.h" // HWY_RESTRICT
-
-namespace skeleton {
-
-// Computes base-2 logarithm by converting to float. Supports dynamic dispatch.
-HWY_DLLEXPORT void CallFloorLog2(const uint8_t* HWY_RESTRICT in, size_t count,
- uint8_t* HWY_RESTRICT out);
-
-// Same, but uses HWY_DYNAMIC_POINTER to save a function pointer and call it.
-HWY_DLLEXPORT void SavedCallFloorLog2(const uint8_t* HWY_RESTRICT in,
- size_t count, uint8_t* HWY_RESTRICT out);
-
-} // namespace skeleton
-
-#endif // HIGHWAY_HWY_EXAMPLES_SKELETON_H_
diff --git a/third_party/highway/hwy/foreach_target.h b/third_party/highway/hwy/foreach_target.h
index 33faf8507d..11c7ed5979 100644
--- a/third_party/highway/hwy/foreach_target.h
+++ b/third_party/highway/hwy/foreach_target.h
@@ -154,17 +154,6 @@
#endif
#endif
-#if (HWY_TARGETS & HWY_AVX10_2_512) && (HWY_STATIC_TARGET != HWY_AVX10_2_512)
-#undef HWY_TARGET
-#define HWY_TARGET HWY_AVX10_2_512
-#include HWY_TARGET_INCLUDE
-#ifdef HWY_TARGET_TOGGLE
-#undef HWY_TARGET_TOGGLE
-#else
-#define HWY_TARGET_TOGGLE
-#endif
-#endif
-
// ------------------------------ HWY_ARCH_ARM
#if (HWY_TARGETS & HWY_NEON_WITHOUT_AES) && \
diff --git a/third_party/highway/hwy/highway.h b/third_party/highway/hwy/highway.h
index a50d9a271f..f797dd8519 100644
--- a/third_party/highway/hwy/highway.h
+++ b/third_party/highway/hwy/highway.h
@@ -61,249 +61,305 @@ namespace hwy {
//------------------------------------------------------------------------------
// Export user functions for static/dynamic dispatch
+// The static target is the best baseline. When using foreach_target.h, this is
+// the last target compiled. Otherwise, it is the only target.
+
// Evaluates to 0 inside a translation unit if it is generating anything but the
-// static target (the last one if multiple targets are enabled). Used to prevent
-// redefinitions of HWY_EXPORT. Unless foreach_target.h is included, we only
-// compile once anyway, so this is 1 unless it is or has been included.
+// static target. Used to prevent redefinitions of HWY_EXPORT. Unless
+// foreach_target.h is included, we only compile once anyway, so this is 1
+// unless it is or has been included.
#ifndef HWY_ONCE
#define HWY_ONCE 1
#endif
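A hedged sketch of the standard multiple-inclusion pattern (as in the skeleton example removed elsewhere in this change): per-target code is recompiled once per target, while HWY_ONCE guards the code that must be emitted exactly once. The file and function names are illustrative.

// mycode.cc (file name is illustrative)
#define HWY_TARGET_INCLUDE "mycode.cc"  // foreach_target.h re-includes us
#include "third_party/highway/hwy/foreach_target.h"
#include "third_party/highway/hwy/highway.h"

HWY_BEFORE_NAMESPACE();
namespace project {
namespace HWY_NAMESPACE {
// Per-target function: compiled once per enabled target.
void MulBy2(float* HWY_RESTRICT p, size_t n) {
  for (size_t i = 0; i < n; ++i) p[i] *= 2.0f;  // placeholder body
}
}  // namespace HWY_NAMESPACE
}  // namespace project
HWY_AFTER_NAMESPACE();

#if HWY_ONCE
namespace project {
HWY_EXPORT(MulBy2);  // emitted once: table over all compiled targets
}  // namespace project
#endif  // HWY_ONCE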
-// HWY_STATIC_DISPATCH(FUNC_NAME) is the namespace-qualified FUNC_NAME for
-// HWY_STATIC_TARGET (the only defined namespace unless HWY_TARGET_INCLUDE is
-// defined), and can be used to deduce the return type of Choose*.
+// `HWY_STATIC_NAMESPACE` expands to its namespace name, e.g. `N_AVX2`.
#if HWY_STATIC_TARGET == HWY_SCALAR
-#define HWY_STATIC_DISPATCH(FUNC_NAME) N_SCALAR::FUNC_NAME
+#define HWY_STATIC_NAMESPACE N_SCALAR
#elif HWY_STATIC_TARGET == HWY_EMU128
-#define HWY_STATIC_DISPATCH(FUNC_NAME) N_EMU128::FUNC_NAME
+#define HWY_STATIC_NAMESPACE N_EMU128
#elif HWY_STATIC_TARGET == HWY_WASM
-#define HWY_STATIC_DISPATCH(FUNC_NAME) N_WASM::FUNC_NAME
+#define HWY_STATIC_NAMESPACE N_WASM
#elif HWY_STATIC_TARGET == HWY_WASM_EMU256
-#define HWY_STATIC_DISPATCH(FUNC_NAME) N_WASM_EMU256::FUNC_NAME
+#define HWY_STATIC_NAMESPACE N_WASM_EMU256
#elif HWY_STATIC_TARGET == HWY_Z14
-#define HWY_STATIC_DISPATCH(FUNC_NAME) N_Z14::FUNC_NAME
+#define HWY_STATIC_NAMESPACE N_Z14
#elif HWY_STATIC_TARGET == HWY_Z15
-#define HWY_STATIC_DISPATCH(FUNC_NAME) N_Z15::FUNC_NAME
+#define HWY_STATIC_NAMESPACE N_Z15
#elif HWY_STATIC_TARGET == HWY_PPC8
-#define HWY_STATIC_DISPATCH(FUNC_NAME) N_PPC8::FUNC_NAME
+#define HWY_STATIC_NAMESPACE N_PPC8
#elif HWY_STATIC_TARGET == HWY_PPC9
-#define HWY_STATIC_DISPATCH(FUNC_NAME) N_PPC9::FUNC_NAME
+#define HWY_STATIC_NAMESPACE N_PPC9
#elif HWY_STATIC_TARGET == HWY_PPC10
-#define HWY_STATIC_DISPATCH(FUNC_NAME) N_PPC10::FUNC_NAME
+#define HWY_STATIC_NAMESPACE N_PPC10
#elif HWY_STATIC_TARGET == HWY_LSX
-#define HWY_STATIC_DISPATCH(FUNC_NAME) N_LSX::FUNC_NAME
+#define HWY_STATIC_NAMESPACE N_LSX
#elif HWY_STATIC_TARGET == HWY_LASX
-#define HWY_STATIC_DISPATCH(FUNC_NAME) N_LASX::FUNC_NAME
+#define HWY_STATIC_NAMESPACE N_LASX
#elif HWY_STATIC_TARGET == HWY_RVV
-#define HWY_STATIC_DISPATCH(FUNC_NAME) N_RVV::FUNC_NAME
+#define HWY_STATIC_NAMESPACE N_RVV
#elif HWY_STATIC_TARGET == HWY_NEON_WITHOUT_AES
-#define HWY_STATIC_DISPATCH(FUNC_NAME) N_NEON_WITHOUT_AES::FUNC_NAME
+#define HWY_STATIC_NAMESPACE N_NEON_WITHOUT_AES
#elif HWY_STATIC_TARGET == HWY_NEON
-#define HWY_STATIC_DISPATCH(FUNC_NAME) N_NEON::FUNC_NAME
+#define HWY_STATIC_NAMESPACE N_NEON
#elif HWY_STATIC_TARGET == HWY_NEON_BF16
-#define HWY_STATIC_DISPATCH(FUNC_NAME) N_NEON_BF16::FUNC_NAME
+#define HWY_STATIC_NAMESPACE N_NEON_BF16
#elif HWY_STATIC_TARGET == HWY_SVE
-#define HWY_STATIC_DISPATCH(FUNC_NAME) N_SVE::FUNC_NAME
+#define HWY_STATIC_NAMESPACE N_SVE
#elif HWY_STATIC_TARGET == HWY_SVE2
-#define HWY_STATIC_DISPATCH(FUNC_NAME) N_SVE2::FUNC_NAME
+#define HWY_STATIC_NAMESPACE N_SVE2
#elif HWY_STATIC_TARGET == HWY_SVE_256
-#define HWY_STATIC_DISPATCH(FUNC_NAME) N_SVE_256::FUNC_NAME
+#define HWY_STATIC_NAMESPACE N_SVE_256
#elif HWY_STATIC_TARGET == HWY_SVE2_128
-#define HWY_STATIC_DISPATCH(FUNC_NAME) N_SVE2_128::FUNC_NAME
+#define HWY_STATIC_NAMESPACE N_SVE2_128
#elif HWY_STATIC_TARGET == HWY_SSE2
-#define HWY_STATIC_DISPATCH(FUNC_NAME) N_SSE2::FUNC_NAME
+#define HWY_STATIC_NAMESPACE N_SSE2
#elif HWY_STATIC_TARGET == HWY_SSSE3
-#define HWY_STATIC_DISPATCH(FUNC_NAME) N_SSSE3::FUNC_NAME
+#define HWY_STATIC_NAMESPACE N_SSSE3
#elif HWY_STATIC_TARGET == HWY_SSE4
-#define HWY_STATIC_DISPATCH(FUNC_NAME) N_SSE4::FUNC_NAME
+#define HWY_STATIC_NAMESPACE N_SSE4
#elif HWY_STATIC_TARGET == HWY_AVX2
-#define HWY_STATIC_DISPATCH(FUNC_NAME) N_AVX2::FUNC_NAME
+#define HWY_STATIC_NAMESPACE N_AVX2
#elif HWY_STATIC_TARGET == HWY_AVX3
-#define HWY_STATIC_DISPATCH(FUNC_NAME) N_AVX3::FUNC_NAME
+#define HWY_STATIC_NAMESPACE N_AVX3
#elif HWY_STATIC_TARGET == HWY_AVX3_DL
-#define HWY_STATIC_DISPATCH(FUNC_NAME) N_AVX3_DL::FUNC_NAME
+#define HWY_STATIC_NAMESPACE N_AVX3_DL
#elif HWY_STATIC_TARGET == HWY_AVX3_ZEN4
-#define HWY_STATIC_DISPATCH(FUNC_NAME) N_AVX3_ZEN4::FUNC_NAME
-#elif HWY_STATIC_TARGET == HWY_AVX10_2
-#define HWY_STATIC_DISPATCH(FUNC_NAME) N_AVX10_2::FUNC_NAME
+#define HWY_STATIC_NAMESPACE N_AVX3_ZEN4
#elif HWY_STATIC_TARGET == HWY_AVX3_SPR
-#define HWY_STATIC_DISPATCH(FUNC_NAME) N_AVX3_SPR::FUNC_NAME
-#elif HWY_STATIC_TARGET == HWY_AVX10_2_512
-#define HWY_STATIC_DISPATCH(FUNC_NAME) N_AVX10_2_512::FUNC_NAME
+#define HWY_STATIC_NAMESPACE N_AVX3_SPR
+#elif HWY_STATIC_TARGET == HWY_AVX10_2
+#define HWY_STATIC_NAMESPACE N_AVX10_2
#endif
-// HWY_CHOOSE_*(FUNC_NAME) expands to the function pointer for that target or
-// nullptr is that target was not compiled.
+// `HWY_STATIC_DISPATCH(FUNC_NAME)` is the namespace-qualified FUNC_NAME for
+// `HWY_STATIC_TARGET`, and can be used to deduce the return type of Choose*.
+#define HWY_STATIC_DISPATCH(FUNC_NAME) HWY_STATIC_NAMESPACE::FUNC_NAME
+
+// `HWY_CHOOSE_*(FUNC_NAME)` expands to the function pointer for that target or
+// nullptr if that target was not compiled.
+// `HWY_VISIT_*(VISITOR)` expands to `VISITOR(TARGET, NAMESPACE)` or nothing if
+// that target was not compiled.
#if HWY_TARGETS & HWY_EMU128
#define HWY_CHOOSE_FALLBACK(FUNC_NAME) &N_EMU128::FUNC_NAME
+#define HWY_VISIT_FALLBACK(VISITOR) VISITOR(HWY_EMU128, N_EMU128)
#elif HWY_TARGETS & HWY_SCALAR
#define HWY_CHOOSE_FALLBACK(FUNC_NAME) &N_SCALAR::FUNC_NAME
+#define HWY_VISIT_FALLBACK(VISITOR) VISITOR(HWY_SCALAR, N_SCALAR)
#else
// When HWY_SCALAR/HWY_EMU128 are not present and other targets were disabled at
// runtime, fall back to the baseline with HWY_STATIC_DISPATCH().
#define HWY_CHOOSE_FALLBACK(FUNC_NAME) &HWY_STATIC_DISPATCH(FUNC_NAME)
+#define HWY_VISIT_FALLBACK(VISITOR) \
+ VISITOR(HWY_STATIC_TARGET, HWY_STATIC_NAMESPACE)
#endif
#if HWY_TARGETS & HWY_WASM
#define HWY_CHOOSE_WASM(FUNC_NAME) &N_WASM::FUNC_NAME
+#define HWY_VISIT_WASM(VISITOR) VISITOR(HWY_WASM, N_WASM)
#else
#define HWY_CHOOSE_WASM(FUNC_NAME) nullptr
+#define HWY_VISIT_WASM(VISITOR)
#endif
#if HWY_TARGETS & HWY_WASM_EMU256
#define HWY_CHOOSE_WASM_EMU256(FUNC_NAME) &N_WASM_EMU256::FUNC_NAME
+#define HWY_VISIT_WASM_EMU256(VISITOR) VISITOR(HWY_WASM_EMU256, N_WASM_EMU256)
#else
#define HWY_CHOOSE_WASM_EMU256(FUNC_NAME) nullptr
+#define HWY_VISIT_WASM_EMU256(VISITOR)
#endif
#if HWY_TARGETS & HWY_Z14
#define HWY_CHOOSE_Z14(FUNC_NAME) &N_Z14::FUNC_NAME
+#define HWY_VISIT_Z14(VISITOR) VISITOR(HWY_Z14, N_Z14)
#else
#define HWY_CHOOSE_Z14(FUNC_NAME) nullptr
+#define HWY_VISIT_Z14(VISITOR)
#endif
#if HWY_TARGETS & HWY_Z15
#define HWY_CHOOSE_Z15(FUNC_NAME) &N_Z15::FUNC_NAME
+#define HWY_VISIT_Z15(VISITOR) VISITOR(HWY_Z15, N_Z15)
#else
#define HWY_CHOOSE_Z15(FUNC_NAME) nullptr
+#define HWY_VISIT_Z15(VISITOR)
#endif
#if HWY_TARGETS & HWY_PPC8
#define HWY_CHOOSE_PPC8(FUNC_NAME) &N_PPC8::FUNC_NAME
+#define HWY_VISIT_PPC8(VISITOR) VISITOR(HWY_PPC8, N_PPC8)
#else
#define HWY_CHOOSE_PPC8(FUNC_NAME) nullptr
+#define HWY_VISIT_PPC8(VISITOR)
#endif
#if HWY_TARGETS & HWY_PPC9
#define HWY_CHOOSE_PPC9(FUNC_NAME) &N_PPC9::FUNC_NAME
+#define HWY_VISIT_PPC9(VISITOR) VISITOR(HWY_PPC9, N_PPC9)
#else
#define HWY_CHOOSE_PPC9(FUNC_NAME) nullptr
+#define HWY_VISIT_PPC9(VISITOR)
#endif
#if HWY_TARGETS & HWY_LSX
#define HWY_CHOOSE_LSX(FUNC_NAME) &N_LSX::FUNC_NAME
+#define HWY_VISIT_LSX(VISITOR) VISITOR(HWY_LSX, N_LSX)
#else
#define HWY_CHOOSE_LSX(FUNC_NAME) nullptr
+#define HWY_VISIT_LSX(VISITOR)
#endif
#if HWY_TARGETS & HWY_LASX
#define HWY_CHOOSE_LASX(FUNC_NAME) &N_LASX::FUNC_NAME
+#define HWY_VISIT_LASX(VISITOR) VISITOR(HWY_LASX, N_LASX)
#else
#define HWY_CHOOSE_LASX(FUNC_NAME) nullptr
+#define HWY_VISIT_LASX(VISITOR)
#endif
#if HWY_TARGETS & HWY_PPC10
#define HWY_CHOOSE_PPC10(FUNC_NAME) &N_PPC10::FUNC_NAME
+#define HWY_VISIT_PPC10(VISITOR) VISITOR(HWY_PPC10, N_PPC10)
#else
#define HWY_CHOOSE_PPC10(FUNC_NAME) nullptr
+#define HWY_VISIT_PPC10(VISITOR)
#endif
#if HWY_TARGETS & HWY_RVV
#define HWY_CHOOSE_RVV(FUNC_NAME) &N_RVV::FUNC_NAME
+#define HWY_VISIT_RVV(VISITOR) VISITOR(HWY_RVV, N_RVV)
#else
#define HWY_CHOOSE_RVV(FUNC_NAME) nullptr
+#define HWY_VISIT_RVV(VISITOR)
#endif
#if HWY_TARGETS & HWY_NEON_WITHOUT_AES
#define HWY_CHOOSE_NEON_WITHOUT_AES(FUNC_NAME) &N_NEON_WITHOUT_AES::FUNC_NAME
+#define HWY_VISIT_NEON_WITHOUT_AES(VISITOR) \
+ VISITOR(HWY_NEON_WITHOUT_AES, N_NEON_WITHOUT_AES)
#else
#define HWY_CHOOSE_NEON_WITHOUT_AES(FUNC_NAME) nullptr
+#define HWY_VISIT_NEON_WITHOUT_AES(VISITOR)
#endif
#if HWY_TARGETS & HWY_NEON
#define HWY_CHOOSE_NEON(FUNC_NAME) &N_NEON::FUNC_NAME
+#define HWY_VISIT_NEON(VISITOR) VISITOR(HWY_NEON, N_NEON)
#else
#define HWY_CHOOSE_NEON(FUNC_NAME) nullptr
+#define HWY_VISIT_NEON(VISITOR)
#endif
#if HWY_TARGETS & HWY_NEON_BF16
#define HWY_CHOOSE_NEON_BF16(FUNC_NAME) &N_NEON_BF16::FUNC_NAME
+#define HWY_VISIT_NEON_BF16(VISITOR) VISITOR(HWY_NEON_BF16, N_NEON_BF16)
#else
#define HWY_CHOOSE_NEON_BF16(FUNC_NAME) nullptr
+#define HWY_VISIT_NEON_BF16(VISITOR)
#endif
#if HWY_TARGETS & HWY_SVE
#define HWY_CHOOSE_SVE(FUNC_NAME) &N_SVE::FUNC_NAME
+#define HWY_VISIT_SVE(VISITOR) VISITOR(HWY_SVE, N_SVE)
#else
#define HWY_CHOOSE_SVE(FUNC_NAME) nullptr
+#define HWY_VISIT_SVE(VISITOR)
#endif
#if HWY_TARGETS & HWY_SVE2
#define HWY_CHOOSE_SVE2(FUNC_NAME) &N_SVE2::FUNC_NAME
+#define HWY_VISIT_SVE2(VISITOR) VISITOR(HWY_SVE2, N_SVE2)
#else
#define HWY_CHOOSE_SVE2(FUNC_NAME) nullptr
+#define HWY_VISIT_SVE2(VISITOR)
#endif
#if HWY_TARGETS & HWY_SVE_256
#define HWY_CHOOSE_SVE_256(FUNC_NAME) &N_SVE_256::FUNC_NAME
+#define HWY_VISIT_SVE_256(VISITOR) VISITOR(HWY_SVE_256, N_SVE_256)
#else
#define HWY_CHOOSE_SVE_256(FUNC_NAME) nullptr
+#define HWY_VISIT_SVE_256(VISITOR)
#endif
#if HWY_TARGETS & HWY_SVE2_128
#define HWY_CHOOSE_SVE2_128(FUNC_NAME) &N_SVE2_128::FUNC_NAME
+#define HWY_VISIT_SVE2_128(VISITOR) VISITOR(HWY_SVE2_128, N_SVE2_128)
#else
#define HWY_CHOOSE_SVE2_128(FUNC_NAME) nullptr
+#define HWY_VISIT_SVE2_128(VISITOR)
#endif
#if HWY_TARGETS & HWY_SSE2
#define HWY_CHOOSE_SSE2(FUNC_NAME) &N_SSE2::FUNC_NAME
+#define HWY_VISIT_SSE2(VISITOR) VISITOR(HWY_SSE2, N_SSE2)
#else
#define HWY_CHOOSE_SSE2(FUNC_NAME) nullptr
+#define HWY_VISIT_SSE2(VISITOR)
#endif
#if HWY_TARGETS & HWY_SSSE3
#define HWY_CHOOSE_SSSE3(FUNC_NAME) &N_SSSE3::FUNC_NAME
+#define HWY_VISIT_SSSE3(VISITOR) VISITOR(HWY_SSSE3, N_SSSE3)
#else
#define HWY_CHOOSE_SSSE3(FUNC_NAME) nullptr
+#define HWY_VISIT_SSSE3(VISITOR)
#endif
#if HWY_TARGETS & HWY_SSE4
#define HWY_CHOOSE_SSE4(FUNC_NAME) &N_SSE4::FUNC_NAME
+#define HWY_VISIT_SSE4(VISITOR) VISITOR(HWY_SSE4, N_SSE4)
#else
#define HWY_CHOOSE_SSE4(FUNC_NAME) nullptr
+#define HWY_VISIT_SSE4(VISITOR)
#endif
#if HWY_TARGETS & HWY_AVX2
#define HWY_CHOOSE_AVX2(FUNC_NAME) &N_AVX2::FUNC_NAME
+#define HWY_VISIT_AVX2(VISITOR) VISITOR(HWY_AVX2, N_AVX2)
#else
#define HWY_CHOOSE_AVX2(FUNC_NAME) nullptr
+#define HWY_VISIT_AVX2(VISITOR)
#endif
#if HWY_TARGETS & HWY_AVX3
#define HWY_CHOOSE_AVX3(FUNC_NAME) &N_AVX3::FUNC_NAME
+#define HWY_VISIT_AVX3(VISITOR) VISITOR(HWY_AVX3, N_AVX3)
#else
#define HWY_CHOOSE_AVX3(FUNC_NAME) nullptr
+#define HWY_VISIT_AVX3(VISITOR)
#endif
#if HWY_TARGETS & HWY_AVX3_DL
#define HWY_CHOOSE_AVX3_DL(FUNC_NAME) &N_AVX3_DL::FUNC_NAME
+#define HWY_VISIT_AVX3_DL(VISITOR) VISITOR(HWY_AVX3_DL, N_AVX3_DL)
#else
#define HWY_CHOOSE_AVX3_DL(FUNC_NAME) nullptr
+#define HWY_VISIT_AVX3_DL(VISITOR)
#endif
#if HWY_TARGETS & HWY_AVX3_ZEN4
#define HWY_CHOOSE_AVX3_ZEN4(FUNC_NAME) &N_AVX3_ZEN4::FUNC_NAME
+#define HWY_VISIT_AVX3_ZEN4(VISITOR) VISITOR(HWY_AVX3_ZEN4, N_AVX3_ZEN4)
#else
#define HWY_CHOOSE_AVX3_ZEN4(FUNC_NAME) nullptr
-#endif
-
-#if HWY_TARGETS & HWY_AVX10_2
-#define HWY_CHOOSE_AVX10_2(FUNC_NAME) &N_AVX10_2::FUNC_NAME
-#else
-#define HWY_CHOOSE_AVX10_2(FUNC_NAME) nullptr
+#define HWY_VISIT_AVX3_ZEN4(VISITOR)
#endif
#if HWY_TARGETS & HWY_AVX3_SPR
#define HWY_CHOOSE_AVX3_SPR(FUNC_NAME) &N_AVX3_SPR::FUNC_NAME
+#define HWY_VISIT_AVX3_SPR(VISITOR) VISITOR(HWY_AVX3_SPR, N_AVX3_SPR)
#else
#define HWY_CHOOSE_AVX3_SPR(FUNC_NAME) nullptr
+#define HWY_VISIT_AVX3_SPR(VISITOR)
#endif
-#if HWY_TARGETS & HWY_AVX10_2_512
-#define HWY_CHOOSE_AVX10_2_512(FUNC_NAME) &N_AVX10_2_512::FUNC_NAME
+#if HWY_TARGETS & HWY_AVX10_2
+#define HWY_CHOOSE_AVX10_2(FUNC_NAME) &N_AVX10_2::FUNC_NAME
+#define HWY_VISIT_AVX10_2(VISITOR) VISITOR(HWY_AVX10_2, N_AVX10_2)
#else
-#define HWY_CHOOSE_AVX10_2_512(FUNC_NAME) nullptr
+#define HWY_CHOOSE_AVX10_2(FUNC_NAME) nullptr
+#define HWY_VISIT_AVX10_2(VISITOR)
#endif
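
Taken together, the HWY_VISIT_* macros invoke a user-supplied VISITOR(TARGET, NAMESPACE) macro once per compiled target and expand to nothing otherwise. A minimal sketch of such a visitor; PRINT_TARGET, ListCompiledTargets and the particular HWY_VISIT_* calls are illustrative, not part of the library:

  #include <stdio.h>

  // Hypothetical visitor: prints each compiled target and its namespace.
  #define PRINT_TARGET(TARGET, NAMESPACE) \
    printf("compiled: %s -> %s\n", #TARGET, #NAMESPACE);

  void ListCompiledTargets() {
    // Each HWY_VISIT_* expands to PRINT_TARGET(...) or to nothing.
    HWY_VISIT_NEON(PRINT_TARGET)
    HWY_VISIT_AVX2(PRINT_TARGET)
    HWY_VISIT_FALLBACK(PRINT_TARGET)
  }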
// MSVC 2017 workaround: the non-type template parameter to ChooseAndCall
@@ -567,8 +623,24 @@ struct AddExport {
(HWY_DISPATCH_TABLE(FUNC_NAME)[hwy::GetChosenTarget().GetIndex()])
// Calls the function pointer for the chosen target.
+#if HWY_COMPILER_GCC || HWY_COMPILER_CLANG
+
+// On GCC or Clang, we call hwy::PreventElision(...) to work around a
+// compiler crash: the LLVM inliner may crash when inlining functions that
+// use incompatible target intrinsics.
+
+#define HWY_DYNAMIC_DISPATCH(FUNC_NAME) \
+ __extension__({ \
+ auto HWY_CONCAT(hwy_tmp_, __LINE__) = *(HWY_DYNAMIC_POINTER(FUNC_NAME)); \
+ hwy::PreventElision(HWY_CONCAT(hwy_tmp_, __LINE__)); \
+ HWY_CONCAT(hwy_tmp_, __LINE__); \
+ })
+
+#else // !(HWY_COMPILER_GCC || HWY_COMPILER_CLANG)
+
#define HWY_DYNAMIC_DISPATCH(FUNC_NAME) (*(HWY_DYNAMIC_POINTER(FUNC_NAME)))
+#endif // HWY_COMPILER_GCC || HWY_COMPILER_CLANG
+
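Either way, the macro evaluates to a callable function pointer, so call sites are identical on all compilers. A minimal usage sketch, assuming the usual HWY_EXPORT table setup described elsewhere in this header (MulByTwo and CallIt are illustrative names):

  // In the translation unit, inside the HWY_ONCE section:
  HWY_EXPORT(MulByTwo);  // builds the per-target pointer table

  void CallIt(const float* in, float* out, size_t n) {
    // Selects the best enabled target at runtime and calls it.
    HWY_DYNAMIC_DISPATCH(MulByTwo)(in, out, n);
  }
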
// Same as DISPATCH, but provide a different arg name to clarify usage.
#define HWY_DYNAMIC_DISPATCH_T(TABLE_NAME) HWY_DYNAMIC_DISPATCH(TABLE_NAME)
#define HWY_DYNAMIC_POINTER_T(TABLE_NAME) HWY_DYNAMIC_POINTER(TABLE_NAME)
@@ -605,14 +677,27 @@ struct AddExport {
#define HWY_HIGHWAY_PER_TARGET
#endif
+// No SIMD target enabled, skip header inclusion.
+#if HWY_ENABLED_BASELINE == 0
+
+// We would expect that HWY_TARGET and HWY_STATIC_TARGET are now both 0.
+#if HWY_TARGET != 0
+#error "Why is HWY_TARGET not 0 when HWY_ENABLED_BASELINE == 0?"
+#endif
+#if HWY_STATIC_TARGET != 0
+#error "Why is HWY_STATIC_TARGET not 0 when HWY_ENABLED_BASELINE == 0?"
+#endif
+
+#else
+
// These define ops inside namespace hwy::HWY_NAMESPACE.
#if HWY_TARGET == HWY_SSE2 || HWY_TARGET == HWY_SSSE3 || HWY_TARGET == HWY_SSE4
#include "third_party/highway/hwy/ops/x86_128-inl.h"
#elif HWY_TARGET == HWY_AVX2
#include "third_party/highway/hwy/ops/x86_256-inl.h"
-#elif HWY_TARGET == HWY_AVX3 || HWY_TARGET == HWY_AVX3_DL || \
- HWY_TARGET == HWY_AVX3_ZEN4 || HWY_TARGET == HWY_AVX10_2 || \
- HWY_TARGET == HWY_AVX3_SPR || HWY_TARGET == HWY_AVX10_2_512
+#elif HWY_TARGET == HWY_AVX3 || HWY_TARGET == HWY_AVX3_DL || \
+ HWY_TARGET == HWY_AVX3_ZEN4 || HWY_TARGET == HWY_AVX3_SPR || \
+ HWY_TARGET == HWY_AVX10_2
#include "third_party/highway/hwy/ops/x86_avx3-inl.h"
#elif HWY_TARGET == HWY_Z14 || HWY_TARGET == HWY_Z15 || \
(HWY_TARGET & HWY_ALL_PPC)
@@ -627,16 +712,20 @@ struct AddExport {
#include "third_party/highway/hwy/ops/wasm_128-inl.h"
#elif HWY_TARGET == HWY_RVV
#include "third_party/highway/hwy/ops/rvv-inl.h"
+#elif HWY_TARGET == HWY_LSX
+#include "third_party/highway/hwy/ops/loongarch_lsx-inl.h"
+#elif HWY_TARGET == HWY_LASX
+#include "third_party/highway/hwy/ops/loongarch_lasx-inl.h"
#elif HWY_TARGET == HWY_EMU128
#include "third_party/highway/hwy/ops/emu128-inl.h"
#elif HWY_TARGET == HWY_SCALAR
#include "third_party/highway/hwy/ops/scalar-inl.h"
-#elif HWY_TARGET == HWY_LSX || HWY_TARGET == HWY_LASX
-#include "third_party/highway/hwy/ops/loongarch_lsx-inl.h"
#else
#pragma message("HWY_TARGET does not match any known target")
#endif // HWY_TARGET
#include "third_party/highway/hwy/ops/generic_ops-inl.h"
+#endif // HWY_ENABLED_BASELINE
+
#endif // HWY_HIGHWAY_PER_TARGET
diff --git a/third_party/highway/hwy/nanobenchmark.cc b/third_party/highway/hwy/nanobenchmark.cc
new file mode 100644
index 0000000000..32cb3905bb
--- /dev/null
+++ b/third_party/highway/hwy/nanobenchmark.cc
@@ -0,0 +1,302 @@
+// Copyright 2019 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "third_party/highway/hwy/nanobenchmark.h"
+
+#include <stdio.h>
+#include <stdlib.h>
+#include <time.h> // clock_gettime
+
+#include <algorithm> // std::sort, std::find_if
+#include <numeric> // std::iota
+#include <random>
+#include <vector>
+
+#include "third_party/highway/hwy/base.h"
+#include "third_party/highway/hwy/robust_statistics.h"
+#include "third_party/highway/hwy/timer.h"
+
+namespace hwy {
+namespace {
+const timer::Ticks& GetTimerResolution() {
+ static const timer::Ticks timer_resolution = platform::TimerResolution();
+ return timer_resolution;
+}
+
+// Estimates the expected duration of a call to "lambda", taking a variable
+// number of samples until the variability "rel_mad" is less than
+// "max_rel_mad".
+template <class Lambda>
+timer::Ticks SampleUntilStable(const double max_rel_mad, double* rel_mad,
+ const Params& p, const Lambda& lambda) {
+ // Choose initial samples_per_eval based on a single estimated duration.
+ timer::Ticks t0 = timer::Start();
+ lambda();
+ timer::Ticks t1 = timer::Stop(); // Caller checks HaveTimerStop
+ timer::Ticks est = t1 - t0;
+ static const double ticks_per_second = platform::InvariantTicksPerSecond();
+ const size_t ticks_per_eval =
+ static_cast<size_t>(ticks_per_second * p.seconds_per_eval);
+ size_t samples_per_eval = est == 0
+ ? p.min_samples_per_eval
+ : static_cast<size_t>(ticks_per_eval / est);
+ samples_per_eval = HWY_MAX(samples_per_eval, p.min_samples_per_eval);
+
+ std::vector<timer::Ticks> samples;
+ samples.reserve(1 + samples_per_eval);
+ samples.push_back(est);
+
+ // Percentage is too strict for tiny differences, so also allow a small
+ // absolute "median absolute deviation".
+ const timer::Ticks max_abs_mad = (GetTimerResolution() + 99) / 100;
+ *rel_mad = 0.0; // ensure initialized
+
+ for (size_t eval = 0; eval < p.max_evals; ++eval, samples_per_eval *= 2) {
+ samples.reserve(samples.size() + samples_per_eval);
+ for (size_t i = 0; i < samples_per_eval; ++i) {
+ t0 = timer::Start();
+ lambda();
+ t1 = timer::Stop(); // Caller checks HaveTimerStop
+ samples.push_back(t1 - t0);
+ }
+
+ if (samples.size() >= p.min_mode_samples) {
+ est = robust_statistics::Mode(samples.data(), samples.size());
+ } else {
+ // For "few" (depends also on the variance) samples, Median is safer.
+ est = robust_statistics::Median(samples.data(), samples.size());
+ }
+ if (est == 0) {
+ HWY_WARN("estimated duration is 0\n");
+ }
+
+ // Median absolute deviation (mad) is a robust measure of 'variability'.
+ const timer::Ticks abs_mad = robust_statistics::MedianAbsoluteDeviation(
+ samples.data(), samples.size(), est);
+ *rel_mad = static_cast<double>(abs_mad) / static_cast<double>(est);
+
+ if (*rel_mad <= max_rel_mad || abs_mad <= max_abs_mad) {
+ if (p.verbose) {
+ printf("%6d samples => %5d (abs_mad=%4d, rel_mad=%4.2f%%)\n",
+ static_cast<int>(samples.size()), static_cast<int>(est),
+ static_cast<int>(abs_mad), *rel_mad * 100.0);
+ }
+ return est;
+ }
+ }
+
+ if (p.verbose) {
+ printf("WARNING: rel_mad=%4.2f%% still exceeds %4.2f%% after %6d samples\n",
+ *rel_mad * 100.0, max_rel_mad * 100.0,
+ static_cast<int>(samples.size()));
+ }
+ return est;
+}
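
To make the stopping rule concrete with illustrative numbers: if the mode estimate is 2000 ticks and the median absolute deviation is 4 ticks, then rel_mad = 4 / 2000 = 0.2%, which satisfies a max_rel_mad of 0.25% and ends sampling early; otherwise samples_per_eval doubles and another round of samples is taken.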
+
+using InputVec = std::vector<FuncInput>;
+
+// Returns vector of unique input values.
+InputVec UniqueInputs(const FuncInput* inputs, const size_t num_inputs) {
+ InputVec unique(inputs, inputs + num_inputs);
+ std::sort(unique.begin(), unique.end());
+ unique.erase(std::unique(unique.begin(), unique.end()), unique.end());
+ return unique;
+}
+
+// Returns how often we need to call func for sufficient precision.
+size_t NumSkip(const Func func, const uint8_t* arg, const InputVec& unique,
+ const Params& p) {
+ // Min elapsed ticks for any input.
+ timer::Ticks min_duration = ~timer::Ticks(0);
+
+ for (const FuncInput input : unique) {
+ double rel_mad;
+ const timer::Ticks total = SampleUntilStable(
+ p.target_rel_mad, &rel_mad, p,
+ [func, arg, input]() { PreventElision(func(arg, input)); });
+ min_duration = HWY_MIN(min_duration, total - GetTimerResolution());
+ }
+
+ // Number of repetitions required to reach the target resolution.
+ const size_t max_skip = p.precision_divisor;
+ // Number of repetitions given the estimated duration.
+ const size_t num_skip =
+ min_duration == 0
+ ? 0
+ : static_cast<size_t>((max_skip + min_duration - 1) / min_duration);
+ if (p.verbose) {
+ printf("res=%d max_skip=%d min_dur=%d num_skip=%d\n",
+ static_cast<int>(GetTimerResolution()), static_cast<int>(max_skip),
+ static_cast<int>(min_duration), static_cast<int>(num_skip));
+ }
+ return num_skip;
+}
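
A worked instance of the rounding above (illustrative numbers): with precision_divisor = 1024 and min_duration = 100 ticks, num_skip = (1024 + 100 - 1) / 100 = 11, so each unique input must be repeated at least 11 times before omitting occurrences yields a measurable difference.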
+
+// Replicates inputs until we can omit "num_skip" occurrences of an input.
+InputVec ReplicateInputs(const FuncInput* inputs, const size_t num_inputs,
+ const size_t num_unique, const size_t num_skip,
+ const Params& p) {
+ InputVec full;
+ if (num_unique == 1) {
+ full.assign(p.subset_ratio * num_skip, inputs[0]);
+ return full;
+ }
+
+ full.reserve(p.subset_ratio * num_skip * num_inputs);
+ for (size_t i = 0; i < p.subset_ratio * num_skip; ++i) {
+ full.insert(full.end(), inputs, inputs + num_inputs);
+ }
+ std::mt19937 rng;
+ std::shuffle(full.begin(), full.end(), rng);
+ return full;
+}
+
+// Copies "full" to "subset" in the same order, but with "num_skip"
+// randomly selected occurrences of "input_to_skip" removed.
+void FillSubset(const InputVec& full, const FuncInput input_to_skip,
+ const size_t num_skip, InputVec* subset) {
+ const size_t count =
+ static_cast<size_t>(std::count(full.begin(), full.end(), input_to_skip));
+ // Generate num_skip random indices: which occurrence to skip.
+ std::vector<uint32_t> omit(count);
+ std::iota(omit.begin(), omit.end(), 0);
+ // omit[] is the same on every call, but that's OK because they identify the
+ // Nth instance of input_to_skip, so the position within full[] differs.
+ std::mt19937 rng;
+ std::shuffle(omit.begin(), omit.end(), rng);
+ omit.resize(num_skip);
+ std::sort(omit.begin(), omit.end());
+
+ uint32_t occurrence = ~0u; // 0 after preincrement
+ size_t idx_omit = 0; // cursor within omit[]
+ size_t idx_subset = 0; // cursor within *subset
+ for (const FuncInput next : full) {
+ if (next == input_to_skip) {
+ ++occurrence;
+ // Haven't removed enough already
+ if (idx_omit < num_skip) {
+ // This one is up for removal
+ if (occurrence == omit[idx_omit]) {
+ ++idx_omit;
+ continue;
+ }
+ }
+ }
+ if (idx_subset < subset->size()) {
+ (*subset)[idx_subset++] = next;
+ }
+ }
+ HWY_DASSERT(idx_subset == subset->size());
+ HWY_DASSERT(idx_omit == omit.size());
+ HWY_DASSERT(occurrence == count - 1);
+}
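
A small worked example of the skip logic (values chosen for illustration): with full = {3, 5, 3, 7, 3}, input_to_skip = 3 and num_skip = 2, omit might select occurrences {0, 2}, so the first and third 3 are dropped and subset = {5, 3, 7}, exactly num_skip elements shorter than full.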
+
+// Returns total ticks elapsed for all inputs.
+timer::Ticks TotalDuration(const Func func, const uint8_t* arg,
+ const InputVec* inputs, const Params& p,
+ double* max_rel_mad) {
+ double rel_mad;
+ const timer::Ticks duration =
+ SampleUntilStable(p.target_rel_mad, &rel_mad, p, [func, arg, inputs]() {
+ for (const FuncInput input : *inputs) {
+ PreventElision(func(arg, input));
+ }
+ });
+ *max_rel_mad = HWY_MAX(*max_rel_mad, rel_mad);
+ return duration;
+}
+
+// (Nearly) empty Func for measuring timer overhead/resolution.
+HWY_NOINLINE FuncOutput EmptyFunc(const void* /*arg*/, const FuncInput input) {
+ return input;
+}
+
+// Returns overhead of accessing inputs[] and calling a function; this will
+// be deducted from future TotalDuration return values.
+timer::Ticks Overhead(const uint8_t* arg, const InputVec* inputs,
+ const Params& p) {
+ double rel_mad;
+ // Zero tolerance because repeatability is crucial and EmptyFunc is fast.
+ return SampleUntilStable(0.0, &rel_mad, p, [arg, inputs]() {
+ for (const FuncInput input : *inputs) {
+ PreventElision(EmptyFunc(arg, input));
+ }
+ });
+}
+
+} // namespace
+
+HWY_DLLEXPORT int Unpredictable1() { return timer::Start() != ~0ULL; }
+
+HWY_DLLEXPORT size_t Measure(const Func func, const uint8_t* arg,
+ const FuncInput* inputs, const size_t num_inputs,
+ Result* results, const Params& p) {
+ HWY_DASSERT(num_inputs != 0);
+
+ char cpu100[100];
+ if (!platform::HaveTimerStop(cpu100)) {
+ HWY_WARN("CPU '%s' does not support RDTSCP, skipping benchmark.\n", cpu100);
+ return 0;
+ }
+
+ const InputVec& unique = UniqueInputs(inputs, num_inputs);
+
+ const size_t num_skip = NumSkip(func, arg, unique, p); // 0 on failure
+ if (num_skip == 0) return 0; // NumSkip printed a message if verbose
+ // (slightly less work on x86 to cast from signed integer)
+ const float mul = 1.0f / static_cast<float>(static_cast<int>(num_skip));
+
+ const InputVec& full =
+ ReplicateInputs(inputs, num_inputs, unique.size(), num_skip, p);
+ InputVec subset(full.size() - num_skip);
+
+ const timer::Ticks overhead = Overhead(arg, &full, p);
+ const timer::Ticks overhead_skip = Overhead(arg, &subset, p);
+ if (overhead < overhead_skip) {
+ HWY_WARN("Measurement failed: overhead %d < %d\n",
+ static_cast<int>(overhead), static_cast<int>(overhead_skip));
+ return 0;
+ }
+
+ if (p.verbose) {
+ printf("#inputs=%5d,%5d overhead=%5d,%5d\n", static_cast<int>(full.size()),
+ static_cast<int>(subset.size()), static_cast<int>(overhead),
+ static_cast<int>(overhead_skip));
+ }
+
+ double max_rel_mad = 0.0;
+ const timer::Ticks total = TotalDuration(func, arg, &full, p, &max_rel_mad);
+
+ for (size_t i = 0; i < unique.size(); ++i) {
+ FillSubset(full, unique[i], num_skip, &subset);
+ const timer::Ticks total_skip =
+ TotalDuration(func, arg, &subset, p, &max_rel_mad);
+
+ if (total < total_skip) {
+ HWY_WARN("Measurement failed: total %f < %f\n",
+ static_cast<double>(total), static_cast<double>(total_skip));
+ return 0;
+ }
+
+ const timer::Ticks duration =
+ (total - overhead) - (total_skip - overhead_skip);
+ results[i].input = unique[i];
+ results[i].ticks = static_cast<float>(duration) * mul;
+ results[i].variability = static_cast<float>(max_rel_mad);
+ }
+
+ return unique.size();
+}
+
+} // namespace hwy
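
The per-input cost computed by the loop in Measure follows a difference-of-totals scheme: ticks(input) = ((total - overhead) - (total_skip - overhead_skip)) * mul, with mul = 1 / num_skip, because the subset differs from the full sequence by exactly num_skip occurrences of that input. With illustrative numbers: total - overhead = 10000, total_skip - overhead_skip = 8900 and num_skip = 11 give (10000 - 8900) / 11 = 100 ticks per call.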
diff --git a/third_party/highway/hwy/nanobenchmark.h b/third_party/highway/hwy/nanobenchmark.h
index 9001767a51..4b1a0b6b17 100644
--- a/third_party/highway/hwy/nanobenchmark.h
+++ b/third_party/highway/hwy/nanobenchmark.h
@@ -50,6 +50,7 @@
#include "third_party/highway/hwy/highway_export.h"
#include "third_party/highway/hwy/timer.h" // IWYU pragma: export
+#include "third_party/highway/hwy/base.h"
namespace hwy {
@@ -114,6 +115,14 @@ struct Result {
float variability;
};
+// Returns a Params struct with customized configuration for benchmarks.
+// Specifically, it limits `max_evals` to prevent timeouts in tests.
+static inline Params DefaultBenchmarkParams() {
+ Params p;
+ p.max_evals = HWY_IS_DEBUG_BUILD ? 3 : 4;
+ return p;
+}
+
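A short sketch of feeding these parameters into a measurement (the lambda and inputs are illustrative; see MeasureClosure below):

  hwy::FuncInput inputs[2] = {10, 100};
  hwy::Result results[2];
  const hwy::Params p = hwy::DefaultBenchmarkParams();
  const size_t num_results = hwy::MeasureClosure(
      [](hwy::FuncInput input) -> hwy::FuncOutput { return input * 3; },
      inputs, 2, results, p);
  // results[0..num_results) now hold per-input ticks and variability.
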
// Precisely measures the number of ticks elapsed when calling "func" with the
// given inputs, shuffled to ensure realistic branch prediction hit rates.
//
@@ -132,8 +141,8 @@ HWY_DLLEXPORT size_t Measure(Func func, const uint8_t* arg,
// Calls operator() of the given closure (lambda function).
template <class Closure>
-static FuncOutput CallClosure(const Closure* f, const FuncInput input) {
- return (*f)(input);
+static FuncOutput CallClosure(const void* f, const FuncInput input) {
+ return (*reinterpret_cast<const Closure*>(f))(input);
}
// Same as Measure, except "closure" is typically a lambda function of
@@ -143,7 +152,7 @@ static inline size_t MeasureClosure(const Closure& closure,
const FuncInput* inputs,
const size_t num_inputs, Result* results,
const Params& p = Params()) {
- return Measure(reinterpret_cast<Func>(&CallClosure<Closure>),
+ return Measure(static_cast<Func>(&CallClosure<Closure>),
reinterpret_cast<const uint8_t*>(&closure), inputs, num_inputs,
results, p);
}
diff --git a/third_party/highway/hwy/ops/arm_neon-inl.h b/third_party/highway/hwy/ops/arm_neon-inl.h
index f7e587eb3f..176f681223 100644
--- a/third_party/highway/hwy/ops/arm_neon-inl.h
+++ b/third_party/highway/hwy/ops/arm_neon-inl.h
@@ -1,5 +1,5 @@
// Copyright 2019 Google LLC
-// Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
+// Copyright 2024-2026 Arm Limited and/or its affiliates <open-source-office@arm.com>
// SPDX-License-Identifier: Apache-2.0
// SPDX-License-Identifier: BSD-3-Clause
//
@@ -142,29 +142,6 @@ namespace detail { // for code folding and Raw128
HWY_NEON_DEF_FUNCTION(int64, 2, name, prefix##q, infix, s64, args) \
HWY_NEON_DEF_FUNCTION(int64, 1, name, prefix, infix, s64, args)
-// Clang 17 crashes with bf16, see github.com/llvm/llvm-project/issues/64179.
-#undef HWY_NEON_HAVE_BFLOAT16
-#if HWY_HAVE_SCALAR_BF16_TYPE && \
- ((HWY_TARGET == HWY_NEON_BF16 && \
- (!HWY_COMPILER_CLANG || HWY_COMPILER_CLANG >= 1800)) || \
- defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC))
-#define HWY_NEON_HAVE_BFLOAT16 1
-#else
-#define HWY_NEON_HAVE_BFLOAT16 0
-#endif
-
-// HWY_NEON_HAVE_F32_TO_BF16C is defined if NEON vcvt_bf16_f32 and
-// vbfdot_f32 are available, even if the __bf16 type is disabled due to
-// GCC/Clang bugs.
-#undef HWY_NEON_HAVE_F32_TO_BF16C
-#if HWY_NEON_HAVE_BFLOAT16 || HWY_TARGET == HWY_NEON_BF16 || \
- (defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) && \
- (HWY_COMPILER_GCC_ACTUAL >= 1000 || HWY_COMPILER_CLANG >= 1100))
-#define HWY_NEON_HAVE_F32_TO_BF16C 1
-#else
-#define HWY_NEON_HAVE_F32_TO_BF16C 0
-#endif
-
// bfloat16_t
#if HWY_NEON_HAVE_BFLOAT16
#define HWY_NEON_DEF_FUNCTION_BFLOAT_16(name, prefix, infix, args) \
@@ -907,6 +884,249 @@ using DFromM = Simd<typename M::PrivateT, M::kPrivateN, 0>;
template <class V>
using TFromV = typename V::PrivateT;
+// TODO(janwas): ForDemoteVectors, in convert_test and demote_test, appears to
+// instantiate this with D = double x 4. The cause is unknown. Previously,
+// defining this in terms of Set rejected that via SFINAE because only
+// V_SIZE = 16 and V_SIZE <= 8 overloads were defined. As a workaround,
+// truncate the lane count to 128 bits.
+template <class D>
+using VFromD =
+ Vec128<TFromD<D>, HWY_MIN(16 / sizeof(TFromD<D>), MaxLanes(D()))>;
+
+// ------------------------------ BitCast
+
+namespace detail {
+
+// Converts from Vec128<T, N> to Vec128<uint8_t, N * sizeof(T)> using the
+// vreinterpret*_u8_*() set of functions.
+#define HWY_NEON_BUILD_TPL_HWY_CAST_TO_U8
+#define HWY_NEON_BUILD_RET_HWY_CAST_TO_U8(type, size) \
+ Vec128<uint8_t, size * sizeof(type##_t)>
+#define HWY_NEON_BUILD_PARAM_HWY_CAST_TO_U8(type, size) Vec128<type##_t, size> v
+#define HWY_NEON_BUILD_ARG_HWY_CAST_TO_U8 v.raw
+
+// Special case of u8 to u8 since vreinterpret*_u8_u8 is obviously not defined.
+template <size_t N>
+HWY_INLINE Vec128<uint8_t, N> BitCastToByte(Vec128<uint8_t, N> v) {
+ return v;
+}
+
+HWY_NEON_DEF_FUNCTION_ALL_FLOATS(BitCastToByte, vreinterpret, _u8_,
+ HWY_CAST_TO_U8)
+HWY_NEON_DEF_FUNCTION_BFLOAT_16(BitCastToByte, vreinterpret, _u8_,
+ HWY_CAST_TO_U8)
+
+HWY_NEON_DEF_FUNCTION_INTS(BitCastToByte, vreinterpret, _u8_, HWY_CAST_TO_U8)
+HWY_NEON_DEF_FUNCTION_UINT_16(BitCastToByte, vreinterpret, _u8_, HWY_CAST_TO_U8)
+HWY_NEON_DEF_FUNCTION_UINT_32(BitCastToByte, vreinterpret, _u8_, HWY_CAST_TO_U8)
+HWY_NEON_DEF_FUNCTION_UINT_64(BitCastToByte, vreinterpret, _u8_, HWY_CAST_TO_U8)
+
+#if !HWY_HAVE_FLOAT16
+#if HWY_NEON_HAVE_F16C
+HWY_NEON_DEF_FUNCTION_FLOAT_16_UNCONDITIONAL(BitCastToByte, vreinterpret, _u8_,
+ HWY_CAST_TO_U8)
+#else
+template <size_t N>
+HWY_INLINE Vec128<uint8_t, N * 2> BitCastToByte(Vec128<float16_t, N> v) {
+ return BitCastToByte(Vec128<uint16_t, N>(v.raw));
+}
+#endif // HWY_NEON_HAVE_F16C
+#endif // !HWY_HAVE_FLOAT16
+
+#if !HWY_NEON_HAVE_BFLOAT16
+template <size_t N>
+HWY_INLINE Vec128<uint8_t, N * 2> BitCastToByte(Vec128<bfloat16_t, N> v) {
+ return BitCastToByte(Vec128<uint16_t, N>(v.raw));
+}
+#endif // !HWY_NEON_HAVE_BFLOAT16
+
+#undef HWY_NEON_BUILD_TPL_HWY_CAST_TO_U8
+#undef HWY_NEON_BUILD_RET_HWY_CAST_TO_U8
+#undef HWY_NEON_BUILD_PARAM_HWY_CAST_TO_U8
+#undef HWY_NEON_BUILD_ARG_HWY_CAST_TO_U8
+
+template <class D, HWY_IF_U8_D(D)>
+HWY_INLINE VFromD<D> BitCastFromByte(D /* tag */, VFromD<D> v) {
+ return v;
+}
+
+// 64-bit or less:
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 8), HWY_IF_I8_D(D)>
+HWY_INLINE VFromD<D> BitCastFromByte(D /* tag */,
+ VFromD<RebindToUnsigned<D>> v) {
+ return VFromD<D>(vreinterpret_s8_u8(v.raw));
+}
+template <class D, HWY_IF_V_SIZE_LE_D(D, 8), HWY_IF_U16_D(D)>
+HWY_INLINE VFromD<D> BitCastFromByte(D /* tag */,
+ VFromD<Repartition<uint8_t, D>> v) {
+ return VFromD<D>(vreinterpret_u16_u8(v.raw));
+}
+template <class D, HWY_IF_V_SIZE_LE_D(D, 8), HWY_IF_I16_D(D)>
+HWY_INLINE VFromD<D> BitCastFromByte(D /* tag */,
+ VFromD<Repartition<uint8_t, D>> v) {
+ return VFromD<D>(vreinterpret_s16_u8(v.raw));
+}
+template <class D, HWY_IF_V_SIZE_LE_D(D, 8), HWY_IF_U32_D(D)>
+HWY_INLINE VFromD<D> BitCastFromByte(D /* tag */,
+ VFromD<Repartition<uint8_t, D>> v) {
+ return VFromD<D>(vreinterpret_u32_u8(v.raw));
+}
+template <class D, HWY_IF_V_SIZE_LE_D(D, 8), HWY_IF_I32_D(D)>
+HWY_INLINE VFromD<D> BitCastFromByte(D /* tag */,
+ VFromD<Repartition<uint8_t, D>> v) {
+ return VFromD<D>(vreinterpret_s32_u8(v.raw));
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 8), HWY_IF_U64_D(D)>
+HWY_INLINE Vec64<uint64_t> BitCastFromByte(D /* tag */, Vec64<uint8_t> v) {
+ return Vec64<uint64_t>(vreinterpret_u64_u8(v.raw));
+}
+template <class D, HWY_IF_V_SIZE_LE_D(D, 8), HWY_IF_I64_D(D)>
+HWY_INLINE Vec64<int64_t> BitCastFromByte(D /* tag */, Vec64<uint8_t> v) {
+ return Vec64<int64_t>(vreinterpret_s64_u8(v.raw));
+}
+
+// Cannot use HWY_NEON_IF_EMULATED_D due to the extra HWY_NEON_HAVE_F16C.
+template <class D, HWY_IF_V_SIZE_LE_D(D, 8), HWY_IF_F16_D(D)>
+HWY_INLINE VFromD<D> BitCastFromByte(D, VFromD<Repartition<uint8_t, D>> v) {
+#if HWY_HAVE_FLOAT16 || HWY_NEON_HAVE_F16C
+ return VFromD<D>(vreinterpret_f16_u8(v.raw));
+#else
+ const RebindToUnsigned<D> du;
+ return VFromD<D>(BitCastFromByte(du, v).raw);
+#endif
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 8), HWY_IF_BF16_D(D)>
+HWY_INLINE VFromD<D> BitCastFromByte(D, VFromD<Repartition<uint8_t, D>> v) {
+#if HWY_NEON_HAVE_BFLOAT16
+ return VFromD<D>(vreinterpret_bf16_u8(v.raw));
+#else
+ const RebindToUnsigned<D> du;
+ return VFromD<D>(BitCastFromByte(du, v).raw);
+#endif
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 8), HWY_IF_F32_D(D)>
+HWY_INLINE VFromD<D> BitCastFromByte(D /* tag */,
+ VFromD<Repartition<uint8_t, D>> v) {
+ return VFromD<D>(vreinterpret_f32_u8(v.raw));
+}
+
+#if HWY_HAVE_FLOAT64
+template <class D, HWY_IF_V_SIZE_LE_D(D, 8), HWY_IF_F64_D(D)>
+HWY_INLINE Vec64<double> BitCastFromByte(D /* tag */, Vec64<uint8_t> v) {
+ return Vec64<double>(vreinterpret_f64_u8(v.raw));
+}
+#endif // HWY_HAVE_FLOAT64
+
+// 128-bit full:
+
+template <class D, HWY_IF_I8_D(D)>
+HWY_INLINE Vec128<int8_t> BitCastFromByte(D /* tag */, Vec128<uint8_t> v) {
+ return Vec128<int8_t>(vreinterpretq_s8_u8(v.raw));
+}
+template <class D, HWY_IF_U16_D(D)>
+HWY_INLINE Vec128<uint16_t> BitCastFromByte(D /* tag */, Vec128<uint8_t> v) {
+ return Vec128<uint16_t>(vreinterpretq_u16_u8(v.raw));
+}
+template <class D, HWY_IF_I16_D(D)>
+HWY_INLINE Vec128<int16_t> BitCastFromByte(D /* tag */, Vec128<uint8_t> v) {
+ return Vec128<int16_t>(vreinterpretq_s16_u8(v.raw));
+}
+template <class D, HWY_IF_U32_D(D)>
+HWY_INLINE Vec128<uint32_t> BitCastFromByte(D /* tag */, Vec128<uint8_t> v) {
+ return Vec128<uint32_t>(vreinterpretq_u32_u8(v.raw));
+}
+template <class D, HWY_IF_I32_D(D)>
+HWY_INLINE Vec128<int32_t> BitCastFromByte(D /* tag */, Vec128<uint8_t> v) {
+ return Vec128<int32_t>(vreinterpretq_s32_u8(v.raw));
+}
+template <class D, HWY_IF_U64_D(D)>
+HWY_INLINE Vec128<uint64_t> BitCastFromByte(D /* tag */, Vec128<uint8_t> v) {
+ return Vec128<uint64_t>(vreinterpretq_u64_u8(v.raw));
+}
+template <class D, HWY_IF_I64_D(D)>
+HWY_INLINE Vec128<int64_t> BitCastFromByte(D /* tag */, Vec128<uint8_t> v) {
+ return Vec128<int64_t>(vreinterpretq_s64_u8(v.raw));
+}
+
+template <class D, HWY_IF_F32_D(D)>
+HWY_INLINE Vec128<float> BitCastFromByte(D /* tag */, Vec128<uint8_t> v) {
+ return Vec128<float>(vreinterpretq_f32_u8(v.raw));
+}
+
+#if HWY_HAVE_FLOAT64
+template <class D, HWY_IF_F64_D(D)>
+HWY_INLINE Vec128<double> BitCastFromByte(D /* tag */, Vec128<uint8_t> v) {
+ return Vec128<double>(vreinterpretq_f64_u8(v.raw));
+}
+#endif // HWY_HAVE_FLOAT64
+
+// Cannot use HWY_NEON_IF_EMULATED_D due to the extra HWY_NEON_HAVE_F16C.
+template <class D, HWY_IF_F16_D(D)>
+HWY_INLINE VFromD<D> BitCastFromByte(D, Vec128<uint8_t> v) {
+#if HWY_HAVE_FLOAT16 || HWY_NEON_HAVE_F16C
+ return VFromD<D>(vreinterpretq_f16_u8(v.raw));
+#else
+ return VFromD<D>(BitCastFromByte(RebindToUnsigned<D>(), v).raw);
+#endif
+}
+
+template <class D, HWY_IF_BF16_D(D)>
+HWY_INLINE VFromD<D> BitCastFromByte(D, Vec128<uint8_t> v) {
+#if HWY_NEON_HAVE_BFLOAT16
+ return VFromD<D>(vreinterpretq_bf16_u8(v.raw));
+#else
+ return VFromD<D>(BitCastFromByte(RebindToUnsigned<D>(), v).raw);
+#endif
+}
+
+} // namespace detail
+
+template <class D, class FromT>
+HWY_API VFromD<D> BitCast(D d,
+ Vec128<FromT, Repartition<FromT, D>().MaxLanes()> v) {
+ return detail::BitCastFromByte(d, detail::BitCastToByte(v));
+}
+
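BitCast is the public entry point built on the helpers above; a minimal usage sketch (the alias hn is the conventional shorthand inside per-target code):

  namespace hn = hwy::HWY_NAMESPACE;
  const hn::Full128<float> df;
  const hn::RebindToUnsigned<decltype(df)> du;
  // Reinterprets the bits of each 1.0f lane as the uint32_t 0x3F800000.
  const auto bits = hn::BitCast(du, hn::Set(df, 1.0f));
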
+// ------------------------------ ResizeBitCast
+
+// <= 8-byte vector to <= 8-byte vector
+template <class D, class FromV, HWY_IF_V_SIZE_LE_V(FromV, 8),
+ HWY_IF_V_SIZE_LE_D(D, 8)>
+HWY_API VFromD<D> ResizeBitCast(D d, FromV v) {
+ const Repartition<uint8_t, decltype(d)> du8;
+ return BitCast(d, VFromD<decltype(du8)>{detail::BitCastToByte(v).raw});
+}
+
+// 16-byte vector to 16-byte vector: same as BitCast
+template <class D, class FromV, HWY_IF_V_SIZE_V(FromV, 16),
+ HWY_IF_V_SIZE_D(D, 16)>
+HWY_API VFromD<D> ResizeBitCast(D d, FromV v) {
+ return BitCast(d, v);
+}
+
+// 16-byte vector to <= 8-byte vector
+template <class D, class FromV, HWY_IF_V_SIZE_V(FromV, 16),
+ HWY_IF_V_SIZE_LE_D(D, 8)>
+HWY_API VFromD<D> ResizeBitCast(D d, FromV v) {
+ const DFromV<decltype(v)> d_from;
+ const Half<decltype(d_from)> dh_from;
+ return ResizeBitCast(d, LowerHalf(dh_from, v));
+}
+
+// <= 8-byte vector to 16-byte vector
+template <class D, class FromV, HWY_IF_V_SIZE_LE_V(FromV, 8),
+ HWY_IF_V_SIZE_D(D, 16)>
+HWY_API VFromD<D> ResizeBitCast(D d, FromV v) {
+ const Full64<TFromV<FromV>> d_full64_from;
+ const Full128<TFromV<FromV>> d_full128_from;
+ return BitCast(d, Combine(d_full128_from, Zero(d_full64_from),
+ ResizeBitCast(d_full64_from, v)));
+}
+
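ResizeBitCast relaxes BitCast's requirement that source and destination have equal size; a sketch of the narrowing and widening cases (illustrative, using the hn alias as above):

  const hn::Full128<uint32_t> d128;
  const hn::Full64<uint32_t> d64;
  const auto v128 = hn::Iota(d128, 0);
  // 16-byte -> 8-byte: keeps the lower half.
  const auto lo = hn::ResizeBitCast(d64, v128);
  // 8-byte -> 16-byte: this NEON path zeroes the upper half via Combine,
  // though portable code should not rely on the new lanes' contents.
  const auto widened = hn::ResizeBitCast(d128, lo);
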
// ------------------------------ Set
namespace detail {
@@ -923,16 +1143,26 @@ namespace detail {
#define HWY_NEON_BUILD_ARG_HWY_SET t
HWY_NEON_DEF_FUNCTION_ALL_TYPES(NativeSet, vdup, _n_, HWY_SET)
-#if !HWY_HAVE_FLOAT16 && HWY_NEON_HAVE_F16C
+#if !HWY_HAVE_FLOAT16 && HWY_NEON_HAVE_F16C && HWY_HAVE_SCALAR_F16_TYPE
HWY_NEON_DEF_FUNCTION_FLOAT_16_UNCONDITIONAL(NativeSet, vdup, _n_, HWY_SET)
#endif
HWY_NEON_DEF_FUNCTION_BFLOAT_16(NativeSet, vdup, _n_, HWY_SET)
-template <class D, HWY_NEON_IF_EMULATED_D(D)>
-HWY_API Vec128<TFromD<D>, MaxLanes(D())> NativeSet(D d, TFromD<D> t) {
+#if !HWY_NEON_HAVE_F16C || !HWY_HAVE_SCALAR_F16_TYPE
+template <class D, HWY_IF_F16_D(D)>
+HWY_API VFromD<D> NativeSet(D d, TFromD<D> t) {
+ const uint16_t tu = BitCastScalar<uint16_t>(t);
+ return BitCast(d, Set(RebindToUnsigned<D>(), tu));
+}
+#endif
+
+#if !HWY_NEON_HAVE_BFLOAT16
+template <class D, HWY_IF_BF16_D(D)>
+HWY_API VFromD<D> NativeSet(D d, TFromD<D> t) {
const uint16_t tu = BitCastScalar<uint16_t>(t);
- return Vec128<TFromD<D>, d.MaxLanes()>(Set(RebindToUnsigned<D>(), tu).raw);
+ return BitCast(d, Set(RebindToUnsigned<D>(), tu));
}
+#endif
#undef HWY_NEON_BUILD_TPL_HWY_SET
#undef HWY_NEON_BUILD_RET_HWY_SET
@@ -941,25 +1171,21 @@ HWY_API Vec128<TFromD<D>, MaxLanes(D())> NativeSet(D d, TFromD<D> t) {
} // namespace detail
-// Full vector. Cannot yet use VFromD because that is defined in terms of Set.
+// Full vector.
// Do not use a typename T = TFromD<D> argument because T will be deduced from
// the actual argument type, which can differ from TFromD<D>.
template <class D, HWY_IF_V_SIZE_D(D, 16), typename T>
-HWY_INLINE Vec128<TFromD<D>> Set(D /* tag */, T t) {
+HWY_INLINE VFromD<D> Set(D /* tag */, T t) {
return detail::NativeSet(Full128<TFromD<D>>(), static_cast<TFromD<D>>(t));
}
// Partial vector: create 64-bit and return wrapper.
template <class D, HWY_IF_V_SIZE_LE_D(D, 8), typename T>
-HWY_API Vec128<TFromD<D>, MaxLanes(D())> Set(D /* tag */, T t) {
+HWY_API VFromD<D> Set(D /* tag */, T t) {
const Full64<TFromD<D>> dfull;
- return Vec128<TFromD<D>, MaxLanes(D())>(
- detail::NativeSet(dfull, static_cast<TFromD<D>>(t)).raw);
+ return VFromD<D>(detail::NativeSet(dfull, static_cast<TFromD<D>>(t)).raw);
}
-template <class D>
-using VFromD = decltype(Set(D(), TFromD<D>()));
-
template <class D>
HWY_API VFromD<D> Zero(D d) {
// Default ctor also works for bfloat16_t and float16_t.
@@ -1211,7 +1437,8 @@ HWY_API VFromD<D> Dup128VecFromValues(D d, TFromD<D> t0, TFromD<D> t1,
BitCastScalar<int16_t>(t6), BitCastScalar<int16_t>(t7)));
}
-#if (HWY_COMPILER_GCC || HWY_COMPILER_CLANGCL) && HWY_NEON_HAVE_F16C
+#if (HWY_COMPILER_GCC || HWY_COMPILER_CLANGCL) && HWY_NEON_HAVE_F16C && \
+ HWY_HAVE_SCALAR_F16_TYPE
template <class D, HWY_IF_F16_D(D), HWY_IF_V_SIZE_LE_D(D, 8)>
HWY_API VFromD<D> Dup128VecFromValues(D d, TFromD<D> t0, TFromD<D> t1,
TFromD<D> t2, TFromD<D> t3,
@@ -1359,283 +1586,49 @@ HWY_API Vec128<int16_t> Combine(D /* tag */, Vec64<int16_t> hi,
Vec64<int16_t> lo) {
return Vec128<int16_t>(vcombine_s16(lo.raw, hi.raw));
}
-template <class D, HWY_IF_I32_D(D)>
-HWY_API Vec128<int32_t> Combine(D /* tag */, Vec64<int32_t> hi,
- Vec64<int32_t> lo) {
- return Vec128<int32_t>(vcombine_s32(lo.raw, hi.raw));
-}
-template <class D, HWY_IF_I64_D(D)>
-HWY_API Vec128<int64_t> Combine(D /* tag */, Vec64<int64_t> hi,
- Vec64<int64_t> lo) {
- return Vec128<int64_t>(vcombine_s64(lo.raw, hi.raw));
-}
-
-#if HWY_HAVE_FLOAT16
-template <class D, HWY_IF_F16_D(D)>
-HWY_API Vec128<float16_t> Combine(D, Vec64<float16_t> hi, Vec64<float16_t> lo) {
- return Vec128<float16_t>(vcombine_f16(lo.raw, hi.raw));
-}
-#endif // HWY_HAVE_FLOAT16
-
-#if HWY_NEON_HAVE_BFLOAT16
-template <class D, HWY_IF_BF16_D(D)>
-HWY_API VFromD<D> Combine(D, Vec64<bfloat16_t> hi, Vec64<bfloat16_t> lo) {
- return VFromD<D>(vcombine_bf16(lo.raw, hi.raw));
-}
-#endif // HWY_NEON_HAVE_BFLOAT16
-
-template <class D, class DH = Half<D>, HWY_NEON_IF_EMULATED_D(D)>
-HWY_API VFromD<D> Combine(D d, VFromD<DH> hi, VFromD<DH> lo) {
- const RebindToUnsigned<D> du;
- const Half<decltype(du)> duh;
- return BitCast(d, Combine(du, BitCast(duh, hi), BitCast(duh, lo)));
-}
-
-template <class D, HWY_IF_F32_D(D)>
-HWY_API Vec128<float> Combine(D /* tag */, Vec64<float> hi, Vec64<float> lo) {
- return Vec128<float>(vcombine_f32(lo.raw, hi.raw));
-}
-#if HWY_HAVE_FLOAT64
-template <class D, HWY_IF_F64_D(D)>
-HWY_API Vec128<double> Combine(D /* tag */, Vec64<double> hi,
- Vec64<double> lo) {
- return Vec128<double>(vcombine_f64(lo.raw, hi.raw));
-}
-#endif // HWY_HAVE_FLOAT64
-
-// ------------------------------ BitCast
-
-namespace detail {
-
-// Converts from Vec128<T, N> to Vec128<uint8_t, N * sizeof(T)> using the
-// vreinterpret*_u8_*() set of functions.
-#define HWY_NEON_BUILD_TPL_HWY_CAST_TO_U8
-#define HWY_NEON_BUILD_RET_HWY_CAST_TO_U8(type, size) \
- Vec128<uint8_t, size * sizeof(type##_t)>
-#define HWY_NEON_BUILD_PARAM_HWY_CAST_TO_U8(type, size) Vec128<type##_t, size> v
-#define HWY_NEON_BUILD_ARG_HWY_CAST_TO_U8 v.raw
-
-// Special case of u8 to u8 since vreinterpret*_u8_u8 is obviously not defined.
-template <size_t N>
-HWY_INLINE Vec128<uint8_t, N> BitCastToByte(Vec128<uint8_t, N> v) {
- return v;
-}
-
-HWY_NEON_DEF_FUNCTION_ALL_FLOATS(BitCastToByte, vreinterpret, _u8_,
- HWY_CAST_TO_U8)
-HWY_NEON_DEF_FUNCTION_BFLOAT_16(BitCastToByte, vreinterpret, _u8_,
- HWY_CAST_TO_U8)
-
-HWY_NEON_DEF_FUNCTION_INTS(BitCastToByte, vreinterpret, _u8_, HWY_CAST_TO_U8)
-HWY_NEON_DEF_FUNCTION_UINT_16(BitCastToByte, vreinterpret, _u8_, HWY_CAST_TO_U8)
-HWY_NEON_DEF_FUNCTION_UINT_32(BitCastToByte, vreinterpret, _u8_, HWY_CAST_TO_U8)
-HWY_NEON_DEF_FUNCTION_UINT_64(BitCastToByte, vreinterpret, _u8_, HWY_CAST_TO_U8)
-
-#if !HWY_HAVE_FLOAT16
-#if HWY_NEON_HAVE_F16C
-HWY_NEON_DEF_FUNCTION_FLOAT_16_UNCONDITIONAL(BitCastToByte, vreinterpret, _u8_,
- HWY_CAST_TO_U8)
-#else
-template <size_t N>
-HWY_INLINE Vec128<uint8_t, N * 2> BitCastToByte(Vec128<float16_t, N> v) {
- return BitCastToByte(Vec128<uint16_t, N>(v.raw));
-}
-#endif // HWY_NEON_HAVE_F16C
-#endif // !HWY_HAVE_FLOAT16
-
-#if !HWY_NEON_HAVE_BFLOAT16
-template <size_t N>
-HWY_INLINE Vec128<uint8_t, N * 2> BitCastToByte(Vec128<bfloat16_t, N> v) {
- return BitCastToByte(Vec128<uint16_t, N>(v.raw));
-}
-#endif // !HWY_NEON_HAVE_BFLOAT16
-
-#undef HWY_NEON_BUILD_TPL_HWY_CAST_TO_U8
-#undef HWY_NEON_BUILD_RET_HWY_CAST_TO_U8
-#undef HWY_NEON_BUILD_PARAM_HWY_CAST_TO_U8
-#undef HWY_NEON_BUILD_ARG_HWY_CAST_TO_U8
-
-template <class D, HWY_IF_U8_D(D)>
-HWY_INLINE VFromD<D> BitCastFromByte(D /* tag */, VFromD<D> v) {
- return v;
-}
-
-// 64-bit or less:
-
-template <class D, HWY_IF_V_SIZE_LE_D(D, 8), HWY_IF_I8_D(D)>
-HWY_INLINE VFromD<D> BitCastFromByte(D /* tag */,
- VFromD<RebindToUnsigned<D>> v) {
- return VFromD<D>(vreinterpret_s8_u8(v.raw));
-}
-template <class D, HWY_IF_V_SIZE_LE_D(D, 8), HWY_IF_U16_D(D)>
-HWY_INLINE VFromD<D> BitCastFromByte(D /* tag */,
- VFromD<Repartition<uint8_t, D>> v) {
- return VFromD<D>(vreinterpret_u16_u8(v.raw));
-}
-template <class D, HWY_IF_V_SIZE_LE_D(D, 8), HWY_IF_I16_D(D)>
-HWY_INLINE VFromD<D> BitCastFromByte(D /* tag */,
- VFromD<Repartition<uint8_t, D>> v) {
- return VFromD<D>(vreinterpret_s16_u8(v.raw));
-}
-template <class D, HWY_IF_V_SIZE_LE_D(D, 8), HWY_IF_U32_D(D)>
-HWY_INLINE VFromD<D> BitCastFromByte(D /* tag */,
- VFromD<Repartition<uint8_t, D>> v) {
- return VFromD<D>(vreinterpret_u32_u8(v.raw));
-}
-template <class D, HWY_IF_V_SIZE_LE_D(D, 8), HWY_IF_I32_D(D)>
-HWY_INLINE VFromD<D> BitCastFromByte(D /* tag */,
- VFromD<Repartition<uint8_t, D>> v) {
- return VFromD<D>(vreinterpret_s32_u8(v.raw));
-}
-
-template <class D, HWY_IF_V_SIZE_LE_D(D, 8), HWY_IF_U64_D(D)>
-HWY_INLINE Vec64<uint64_t> BitCastFromByte(D /* tag */, Vec64<uint8_t> v) {
- return Vec64<uint64_t>(vreinterpret_u64_u8(v.raw));
-}
-template <class D, HWY_IF_V_SIZE_LE_D(D, 8), HWY_IF_I64_D(D)>
-HWY_INLINE Vec64<int64_t> BitCastFromByte(D /* tag */, Vec64<uint8_t> v) {
- return Vec64<int64_t>(vreinterpret_s64_u8(v.raw));
-}
-
-// Cannot use HWY_NEON_IF_EMULATED_D due to the extra HWY_NEON_HAVE_F16C.
-template <class D, HWY_IF_V_SIZE_LE_D(D, 8), HWY_IF_F16_D(D)>
-HWY_INLINE VFromD<D> BitCastFromByte(D, VFromD<Repartition<uint8_t, D>> v) {
-#if HWY_HAVE_FLOAT16 || HWY_NEON_HAVE_F16C
- return VFromD<D>(vreinterpret_f16_u8(v.raw));
-#else
- const RebindToUnsigned<D> du;
- return VFromD<D>(BitCastFromByte(du, v).raw);
-#endif
-}
-
-template <class D, HWY_IF_V_SIZE_LE_D(D, 8), HWY_IF_BF16_D(D)>
-HWY_INLINE VFromD<D> BitCastFromByte(D, VFromD<Repartition<uint8_t, D>> v) {
-#if HWY_NEON_HAVE_BFLOAT16
- return VFromD<D>(vreinterpret_bf16_u8(v.raw));
-#else
- const RebindToUnsigned<D> du;
- return VFromD<D>(BitCastFromByte(du, v).raw);
-#endif
-}
-
-template <class D, HWY_IF_V_SIZE_LE_D(D, 8), HWY_IF_F32_D(D)>
-HWY_INLINE VFromD<D> BitCastFromByte(D /* tag */,
- VFromD<Repartition<uint8_t, D>> v) {
- return VFromD<D>(vreinterpret_f32_u8(v.raw));
-}
-
-#if HWY_HAVE_FLOAT64
-template <class D, HWY_IF_V_SIZE_LE_D(D, 8), HWY_IF_F64_D(D)>
-HWY_INLINE Vec64<double> BitCastFromByte(D /* tag */, Vec64<uint8_t> v) {
- return Vec64<double>(vreinterpret_f64_u8(v.raw));
-}
-#endif // HWY_HAVE_FLOAT64
-
-// 128-bit full:
-
-template <class D, HWY_IF_I8_D(D)>
-HWY_INLINE Vec128<int8_t> BitCastFromByte(D /* tag */, Vec128<uint8_t> v) {
- return Vec128<int8_t>(vreinterpretq_s8_u8(v.raw));
-}
-template <class D, HWY_IF_U16_D(D)>
-HWY_INLINE Vec128<uint16_t> BitCastFromByte(D /* tag */, Vec128<uint8_t> v) {
- return Vec128<uint16_t>(vreinterpretq_u16_u8(v.raw));
-}
-template <class D, HWY_IF_I16_D(D)>
-HWY_INLINE Vec128<int16_t> BitCastFromByte(D /* tag */, Vec128<uint8_t> v) {
- return Vec128<int16_t>(vreinterpretq_s16_u8(v.raw));
-}
-template <class D, HWY_IF_U32_D(D)>
-HWY_INLINE Vec128<uint32_t> BitCastFromByte(D /* tag */, Vec128<uint8_t> v) {
- return Vec128<uint32_t>(vreinterpretq_u32_u8(v.raw));
-}
-template <class D, HWY_IF_I32_D(D)>
-HWY_INLINE Vec128<int32_t> BitCastFromByte(D /* tag */, Vec128<uint8_t> v) {
- return Vec128<int32_t>(vreinterpretq_s32_u8(v.raw));
-}
-template <class D, HWY_IF_U64_D(D)>
-HWY_INLINE Vec128<uint64_t> BitCastFromByte(D /* tag */, Vec128<uint8_t> v) {
- return Vec128<uint64_t>(vreinterpretq_u64_u8(v.raw));
-}
-template <class D, HWY_IF_I64_D(D)>
-HWY_INLINE Vec128<int64_t> BitCastFromByte(D /* tag */, Vec128<uint8_t> v) {
- return Vec128<int64_t>(vreinterpretq_s64_u8(v.raw));
-}
-
-template <class D, HWY_IF_F32_D(D)>
-HWY_INLINE Vec128<float> BitCastFromByte(D /* tag */, Vec128<uint8_t> v) {
- return Vec128<float>(vreinterpretq_f32_u8(v.raw));
-}
-
-#if HWY_HAVE_FLOAT64
-template <class D, HWY_IF_F64_D(D)>
-HWY_INLINE Vec128<double> BitCastFromByte(D /* tag */, Vec128<uint8_t> v) {
- return Vec128<double>(vreinterpretq_f64_u8(v.raw));
-}
-#endif // HWY_HAVE_FLOAT64
-
-// Cannot use HWY_NEON_IF_EMULATED_D due to the extra HWY_NEON_HAVE_F16C.
-template <class D, HWY_IF_F16_D(D)>
-HWY_INLINE VFromD<D> BitCastFromByte(D, Vec128<uint8_t> v) {
-#if HWY_HAVE_FLOAT16 || HWY_NEON_HAVE_F16C
- return VFromD<D>(vreinterpretq_f16_u8(v.raw));
-#else
- return VFromD<D>(BitCastFromByte(RebindToUnsigned<D>(), v).raw);
-#endif
-}
-
-template <class D, HWY_IF_BF16_D(D)>
-HWY_INLINE VFromD<D> BitCastFromByte(D, Vec128<uint8_t> v) {
-#if HWY_NEON_HAVE_BFLOAT16
- return VFromD<D>(vreinterpretq_bf16_u8(v.raw));
-#else
- return VFromD<D>(BitCastFromByte(RebindToUnsigned<D>(), v).raw);
-#endif
+template <class D, HWY_IF_I32_D(D)>
+HWY_API Vec128<int32_t> Combine(D /* tag */, Vec64<int32_t> hi,
+ Vec64<int32_t> lo) {
+ return Vec128<int32_t>(vcombine_s32(lo.raw, hi.raw));
}
-
-} // namespace detail
-
-template <class D, class FromT>
-HWY_API VFromD<D> BitCast(D d,
- Vec128<FromT, Repartition<FromT, D>().MaxLanes()> v) {
- return detail::BitCastFromByte(d, detail::BitCastToByte(v));
+template <class D, HWY_IF_I64_D(D)>
+HWY_API Vec128<int64_t> Combine(D /* tag */, Vec64<int64_t> hi,
+ Vec64<int64_t> lo) {
+ return Vec128<int64_t>(vcombine_s64(lo.raw, hi.raw));
}
-// ------------------------------ ResizeBitCast
-
-// <= 8 byte vector to <= 8 byte vector
-template <class D, class FromV, HWY_IF_V_SIZE_LE_V(FromV, 8),
- HWY_IF_V_SIZE_LE_D(D, 8)>
-HWY_API VFromD<D> ResizeBitCast(D d, FromV v) {
- const Repartition<uint8_t, decltype(d)> du8;
- return BitCast(d, VFromD<decltype(du8)>{detail::BitCastToByte(v).raw});
+#if HWY_HAVE_FLOAT16
+template <class D, HWY_IF_F16_D(D)>
+HWY_API Vec128<float16_t> Combine(D, Vec64<float16_t> hi, Vec64<float16_t> lo) {
+ return Vec128<float16_t>(vcombine_f16(lo.raw, hi.raw));
}
+#endif // HWY_HAVE_FLOAT16
-// 16-byte vector to 16-byte vector: same as BitCast
-template <class D, class FromV, HWY_IF_V_SIZE_V(FromV, 16),
- HWY_IF_V_SIZE_D(D, 16)>
-HWY_API VFromD<D> ResizeBitCast(D d, FromV v) {
- return BitCast(d, v);
+#if HWY_NEON_HAVE_BFLOAT16
+template <class D, HWY_IF_BF16_D(D)>
+HWY_API VFromD<D> Combine(D, Vec64<bfloat16_t> hi, Vec64<bfloat16_t> lo) {
+ return VFromD<D>(vcombine_bf16(lo.raw, hi.raw));
}
+#endif // HWY_NEON_HAVE_BFLOAT16
-// 16-byte vector to <= 8-byte vector
-template <class D, class FromV, HWY_IF_V_SIZE_V(FromV, 16),
- HWY_IF_V_SIZE_LE_D(D, 8)>
-HWY_API VFromD<D> ResizeBitCast(D d, FromV v) {
- const DFromV<decltype(v)> d_from;
- const Half<decltype(d_from)> dh_from;
- return ResizeBitCast(d, LowerHalf(dh_from, v));
+template <class D, class DH = Half<D>, HWY_NEON_IF_EMULATED_D(D)>
+HWY_API VFromD<D> Combine(D d, VFromD<DH> hi, VFromD<DH> lo) {
+ const RebindToUnsigned<D> du;
+ const Half<decltype(du)> duh;
+ return BitCast(d, Combine(du, BitCast(duh, hi), BitCast(duh, lo)));
}
-// <= 8-bit vector to 16-byte vector
-template <class D, class FromV, HWY_IF_V_SIZE_LE_V(FromV, 8),
- HWY_IF_V_SIZE_D(D, 16)>
-HWY_API VFromD<D> ResizeBitCast(D d, FromV v) {
- const Full64<TFromV<FromV>> d_full64_from;
- const Full128<TFromV<FromV>> d_full128_from;
- return BitCast(d, Combine(d_full128_from, Zero(d_full64_from),
- ResizeBitCast(d_full64_from, v)));
+template <class D, HWY_IF_F32_D(D)>
+HWY_API Vec128<float> Combine(D /* tag */, Vec64<float> hi, Vec64<float> lo) {
+ return Vec128<float>(vcombine_f32(lo.raw, hi.raw));
+}
+#if HWY_HAVE_FLOAT64
+template <class D, HWY_IF_F64_D(D)>
+HWY_API Vec128<double> Combine(D /* tag */, Vec64<double> hi,
+ Vec64<double> lo) {
+ return Vec128<double>(vcombine_f64(lo.raw, hi.raw));
}
+#endif // HWY_HAVE_FLOAT64
// ------------------------------ GetLane
@@ -1950,10 +1943,74 @@ HWY_API Vec128<T, 16> InsertLane(const Vec128<T, 16> v, size_t i, T t) {
// ================================================== ARITHMETIC
// ------------------------------ Addition
-HWY_NEON_DEF_FUNCTION_ALL_TYPES(operator+, vadd, _, 2)
+HWY_NEON_DEF_FUNCTION_UINTS(operator+, vadd, _, 2)
+HWY_NEON_DEF_FUNCTION_ALL_FLOATS(operator+, vadd, _, 2)
+
+template <size_t N>
+HWY_API Vec128<int8_t, N> operator+(Vec128<int8_t, N> a, Vec128<int8_t, N> b) {
+ const DFromV<decltype(a)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(d, BitCast(du, a) + BitCast(du, b));
+}
+
+template <size_t N>
+HWY_API Vec128<int16_t, N> operator+(Vec128<int16_t, N> a,
+ Vec128<int16_t, N> b) {
+ const DFromV<decltype(a)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(d, BitCast(du, a) + BitCast(du, b));
+}
+
+template <size_t N>
+HWY_API Vec128<int32_t, N> operator+(Vec128<int32_t, N> a,
+ Vec128<int32_t, N> b) {
+ const DFromV<decltype(a)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(d, BitCast(du, a) + BitCast(du, b));
+}
+
+template <size_t N>
+HWY_API Vec128<int64_t, N> operator+(Vec128<int64_t, N> a,
+ Vec128<int64_t, N> b) {
+ const DFromV<decltype(a)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(d, BitCast(du, a) + BitCast(du, b));
+}
// ------------------------------ Subtraction
-HWY_NEON_DEF_FUNCTION_ALL_TYPES(operator-, vsub, _, 2)
+HWY_NEON_DEF_FUNCTION_UINTS(operator-, vsub, _, 2)
+HWY_NEON_DEF_FUNCTION_ALL_FLOATS(operator-, vsub, _, 2)
+
+template <size_t N>
+HWY_API Vec128<int8_t, N> operator-(Vec128<int8_t, N> a, Vec128<int8_t, N> b) {
+ const DFromV<decltype(a)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(d, BitCast(du, a) - BitCast(du, b));
+}
+
+template <size_t N>
+HWY_API Vec128<int16_t, N> operator-(Vec128<int16_t, N> a,
+ Vec128<int16_t, N> b) {
+ const DFromV<decltype(a)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(d, BitCast(du, a) - BitCast(du, b));
+}
+
+template <size_t N>
+HWY_API Vec128<int32_t, N> operator-(Vec128<int32_t, N> a,
+ Vec128<int32_t, N> b) {
+ const DFromV<decltype(a)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(d, BitCast(du, a) - BitCast(du, b));
+}
+
+template <size_t N>
+HWY_API Vec128<int64_t, N> operator-(Vec128<int64_t, N> a,
+ Vec128<int64_t, N> b) {
+ const DFromV<decltype(a)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(d, BitCast(du, a) - BitCast(du, b));
+}
// ------------------------------ SumsOf8
@@ -2476,9 +2533,31 @@ HWY_API Vec128<T, N> RoundingShiftRightSame(const Vec128<T, N> v, int bits) {
// All except ui64
HWY_NEON_DEF_FUNCTION_UINT_8_16_32(operator*, vmul, _, 2)
-HWY_NEON_DEF_FUNCTION_INT_8_16_32(operator*, vmul, _, 2)
HWY_NEON_DEF_FUNCTION_ALL_FLOATS(operator*, vmul, _, 2)
+template <size_t N>
+HWY_API Vec128<int8_t, N> operator*(Vec128<int8_t, N> a, Vec128<int8_t, N> b) {
+ const DFromV<decltype(a)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(d, BitCast(du, a) * BitCast(du, b));
+}
+
+template <size_t N>
+HWY_API Vec128<int16_t, N> operator*(Vec128<int16_t, N> a,
+ Vec128<int16_t, N> b) {
+ const DFromV<decltype(a)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(d, BitCast(du, a) * BitCast(du, b));
+}
+
+template <size_t N>
+HWY_API Vec128<int32_t, N> operator*(Vec128<int32_t, N> a,
+ Vec128<int32_t, N> b) {
+ const DFromV<decltype(a)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(d, BitCast(du, a) * BitCast(du, b));
+}
+
// ------------------------------ Integer multiplication
// Returns the upper sizeof(T)*8 bits of a * b in each lane.
@@ -2824,15 +2903,15 @@ HWY_API Vec128<T, N> And(const Vec128<T, N> a, const Vec128<T, N> b) {
// ------------------------------ AndNot
namespace detail {
-// reversed_andnot returns a & ~b.
-HWY_NEON_DEF_FUNCTION_INTS_UINTS(reversed_andnot, vbic, _, 2)
+// AndNotSwap returns a & ~b, whereas AndNot is defined as ~a & b.
+HWY_NEON_DEF_FUNCTION_INTS_UINTS(AndNotSwap, vbic, _, 2)
} // namespace detail
// Returns ~not_mask & mask.
template <typename T, size_t N, HWY_IF_NOT_FLOAT(T)>
HWY_API Vec128<T, N> AndNot(const Vec128<T, N> not_mask,
const Vec128<T, N> mask) {
- return detail::reversed_andnot(mask, not_mask);
+ return detail::AndNotSwap(mask, not_mask);
}
// Uses the u32/64 defined above.
@@ -2842,7 +2921,7 @@ HWY_API Vec128<T, N> AndNot(const Vec128<T, N> not_mask,
const DFromV<decltype(mask)> d;
const RebindToUnsigned<decltype(d)> du;
VFromD<decltype(du)> ret =
- detail::reversed_andnot(BitCast(du, mask), BitCast(du, not_mask));
+ detail::AndNotSwap(BitCast(du, mask), BitCast(du, not_mask));
return BitCast(d, ret);
}
@@ -2872,6 +2951,13 @@ HWY_API Vec128<T, N> Xor(const Vec128<T, N> a, const Vec128<T, N> b) {
// ------------------------------ Xor3
#if HWY_ARCH_ARM_A64 && defined(__ARM_FEATURE_SHA3)
+
+#ifdef HWY_NATIVE_XOR3
+#undef HWY_NATIVE_XOR3
+#else
+#define HWY_NATIVE_XOR3
+#endif
+
HWY_NEON_DEF_FUNCTION_FULL_UI(Xor3, veor3, _, 3)
// Half vectors are not natively supported. Two Xor are likely more efficient
@@ -2889,11 +2975,6 @@ HWY_API Vec128<T, N> Xor3(const Vec128<T, N> x1, const Vec128<T, N> x2,
return BitCast(d, Xor3(BitCast(du, x1), BitCast(du, x2), BitCast(du, x3)));
}
-#else
-template <typename T, size_t N>
-HWY_API Vec128<T, N> Xor3(Vec128<T, N> x1, Vec128<T, N> x2, Vec128<T, N> x3) {
- return Xor(x1, Xor(x2, x3));
-}
#endif
// ------------------------------ Or3
@@ -2908,26 +2989,45 @@ HWY_API Vec128<T, N> OrAnd(Vec128<T, N> o, Vec128<T, N> a1, Vec128<T, N> a2) {
return Or(o, And(a1, a2));
}
-// ------------------------------ IfVecThenElse
-template <typename T, size_t N>
-HWY_API Vec128<T, N> IfVecThenElse(Vec128<T, N> mask, Vec128<T, N> yes,
- Vec128<T, N> no) {
- return IfThenElse(MaskFromVec(mask), yes, no);
-}
-
-// ------------------------------ BitwiseIfThenElse
+// ------------------------------ XorAndNot
+#if HWY_ARCH_ARM_A64 && defined(__ARM_FEATURE_SHA3)
-#ifdef HWY_NATIVE_BITWISE_IF_THEN_ELSE
-#undef HWY_NATIVE_BITWISE_IF_THEN_ELSE
+#ifdef HWY_NATIVE_BCAX
+#undef HWY_NATIVE_BCAX
#else
-#define HWY_NATIVE_BITWISE_IF_THEN_ELSE
+#define HWY_NATIVE_BCAX
#endif
-template <class V>
-HWY_API V BitwiseIfThenElse(V mask, V yes, V no) {
- return IfVecThenElse(mask, yes, no);
+namespace detail {
+HWY_NEON_DEF_FUNCTION_FULL_UI(XorAndNotSwap, vbcax, _, 3)
+} // namespace detail
+
+// As with AndNot, swap the last two arguments because our "negated first"
+// convention does not match the intrinsics, which take the negated arg last.
+template <class V, HWY_IF_V_SIZE_V(V, 16), HWY_IF_NOT_FLOAT_V(V)>
+HWY_API V XorAndNot(V x, V a1, V a2) {
+ return detail::XorAndNotSwap(x, a2, a1);
+}
+
+// Half vectors are not natively supported. Two ops are likely more efficient
+// than Combine to 128-bit.
+template <typename T, size_t N, HWY_IF_V_SIZE_LE(T, N, 8), HWY_IF_NOT_FLOAT(T)>
+HWY_API Vec128<T, N> XorAndNot(Vec128<T, N> x, Vec128<T, N> a1,
+ Vec128<T, N> a2) {
+ return Xor(x, AndNot(a1, a2));
+}
+
+template <typename T, size_t N, HWY_IF_FLOAT(T)>
+HWY_API Vec128<T, N> XorAndNot(const Vec128<T, N> x, const Vec128<T, N> a1,
+ const Vec128<T, N> a2) {
+ const DFromV<decltype(x)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(d,
+ XorAndNot(BitCast(du, x), BitCast(du, a1), BitCast(du, a2)));
}
+#endif
+
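So XorAndNot(x, a1, a2) computes x ^ AndNot(a1, a2) = x ^ (~a1 & a2), a single BCAX instruction on SHA3-capable cores. A scalar sketch of the per-lane semantics (illustrative, not library code):

  // For each lane: result = x ^ (~a1 & a2).
  uint32_t XorAndNotScalar(uint32_t x, uint32_t a1, uint32_t a2) {
    return x ^ (~a1 & a2);
  }
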
// ------------------------------ Operator overloads (internal-only if float)
template <typename T, size_t N>
@@ -3047,14 +3147,6 @@ HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Abs, vabs, _, 1)
HWY_NEON_DEF_FUNCTION_INT_8_16_32(SaturatedAbs, vqabs, _, 1)
-// ------------------------------ CopySign
-template <typename T, size_t N>
-HWY_API Vec128<T, N> CopySign(Vec128<T, N> magn, Vec128<T, N> sign) {
- static_assert(IsFloat<T>(), "Only makes sense for floating-point");
- const DFromV<decltype(magn)> d;
- return BitwiseIfThenElse(SignBit(d), sign, magn);
-}
-
// ------------------------------ CopySignToAbs
template <typename T, size_t N>
HWY_API Vec128<T, N> CopySignToAbs(Vec128<T, N> abs, Vec128<T, N> sign) {
@@ -3101,6 +3193,21 @@ HWY_API MFromD<DTo> RebindMask(DTo /* tag */, Mask128<TFrom, NFrom> m) {
// ------------------------------ IfThenElse
+// Workaround for incorrect Armv7 codegen: use bitwise ops instead of VBSL.
+#if HWY_ARCH_ARM_V7
+
+template <class V, class D = DFromV<V>>
+HWY_API V IfThenElse(MFromD<D> mask, V yes, V no) {
+ const RebindToUnsigned<D> du;
+ using VU = VFromD<decltype(du)>;
+ const VU no_u = BitCast(du, no);
+ const VU diff_u = BitCast(du, yes) ^ no_u;
+ const VU mask_u = BitCast(du, VecFromMask(D(), mask));
+ return BitCast(D(), no_u ^ (diff_u & mask_u));
+}
+
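The workaround relies on the standard bitwise-select identity; a scalar sketch (assuming mask lanes are all-ones or all-zeros, as VecFromMask guarantees):

    #include <cstdint>
    // Selects yes-bits where mask bits are set: no ^ ((yes ^ no) & mask).
    uint32_t SelectBits(uint32_t mask, uint32_t yes, uint32_t no) {
      return no ^ ((yes ^ no) & mask);
    }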
+#else // normal VBSL instruction
+
#define HWY_NEON_BUILD_TPL_HWY_IF
#define HWY_NEON_BUILD_RET_HWY_IF(type, size) Vec128<type##_t, size>
#define HWY_NEON_BUILD_PARAM_HWY_IF(type, size) \
@@ -3110,6 +3217,8 @@ HWY_API MFromD<DTo> RebindMask(DTo /* tag */, Mask128<TFrom, NFrom> m) {
HWY_NEON_DEF_FUNCTION_ALL_TYPES(IfThenElse, vbsl, _, HWY_IF)
+#endif // HWY_ARCH_ARM_V7
+
#if HWY_HAVE_FLOAT16
#define HWY_NEON_IF_EMULATED_IF_THEN_ELSE(V) HWY_IF_BF16(TFromV<V>)
#else
@@ -3165,6 +3274,33 @@ HWY_API Vec128<T, N> IfNegativeThenElse(Vec128<T, N> v, Vec128<T, N> yes,
return IfThenElse(m, yes, no);
}
+template <typename T, size_t N>
+HWY_API Vec128<T, N> IfVecThenElse(Vec128<T, N> mask, Vec128<T, N> yes,
+ Vec128<T, N> no) {
+ return IfThenElse(MaskFromVec(mask), yes, no);
+}
+
+// ------------------------------ BitwiseIfThenElse
+
+#ifdef HWY_NATIVE_BITWISE_IF_THEN_ELSE
+#undef HWY_NATIVE_BITWISE_IF_THEN_ELSE
+#else
+#define HWY_NATIVE_BITWISE_IF_THEN_ELSE
+#endif
+
+template <class V>
+HWY_API V BitwiseIfThenElse(V mask, V yes, V no) {
+ return IfVecThenElse(mask, yes, no);
+}
+
+// ------------------------------ CopySign (BitwiseIfThenElse)
+template <typename T, size_t N>
+HWY_API Vec128<T, N> CopySign(Vec128<T, N> magn, Vec128<T, N> sign) {
+ static_assert(IsFloat<T>(), "Only makes sense for floating-point");
+ const DFromV<decltype(magn)> d;
+ return BitwiseIfThenElse(SignBit(d), sign, magn);
+}
+
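A scalar sketch of this composition for binary32 (illustrative only; SignBit(d) is the MSB-only bit pattern):

    #include <cstdint>
    #include <cstring>
    // Sign bit from `sign`, all other bits from `magn`.
    float CopySignScalar(float magn, float sign) {
      uint32_t m, s;
      std::memcpy(&m, &magn, sizeof(m));
      std::memcpy(&s, &sign, sizeof(s));
      const uint32_t kSign = 0x80000000u;
      const uint32_t bits = (s & kSign) | (m & ~kSign);
      float result;
      std::memcpy(&result, &bits, sizeof(result));
      return result;
    }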
// ------------------------------ Mask logical
template <typename T, size_t N>
@@ -3395,21 +3531,19 @@ HWY_API Mask128<int64_t, N> TestBit(Vec128<int64_t, N> v,
#undef HWY_NEON_BUILD_PARAM_HWY_TESTBIT
#undef HWY_NEON_BUILD_ARG_HWY_TESTBIT
-// ------------------------------ Abs i64 (IfThenElse, BroadcastSignBit)
+// ------------------------------ Abs i64 (IfNegativeThenElse, Neg)
HWY_API Vec128<int64_t> Abs(const Vec128<int64_t> v) {
#if HWY_ARCH_ARM_A64
return Vec128<int64_t>(vabsq_s64(v.raw));
#else
- const auto zero = Zero(DFromV<decltype(v)>());
- return IfThenElse(MaskFromVec(BroadcastSignBit(v)), zero - v, v);
+ return IfNegativeThenElse(v, Neg(v), v);
#endif
}
HWY_API Vec64<int64_t> Abs(const Vec64<int64_t> v) {
#if HWY_ARCH_ARM_A64
return Vec64<int64_t>(vabs_s64(v.raw));
#else
- const auto zero = Zero(DFromV<decltype(v)>());
- return IfThenElse(MaskFromVec(BroadcastSignBit(v)), zero - v, v);
+ return IfNegativeThenElse(v, Neg(v), v);
#endif
}
@@ -3418,7 +3552,7 @@ HWY_API Vec128<int64_t> SaturatedAbs(const Vec128<int64_t> v) {
return Vec128<int64_t>(vqabsq_s64(v.raw));
#else
const auto zero = Zero(DFromV<decltype(v)>());
- return IfThenElse(MaskFromVec(BroadcastSignBit(v)), SaturatedSub(zero, v), v);
+ return IfNegativeThenElse(v, SaturatedSub(zero, v), v);
#endif
}
HWY_API Vec64<int64_t> SaturatedAbs(const Vec64<int64_t> v) {
@@ -3426,7 +3560,7 @@ HWY_API Vec64<int64_t> SaturatedAbs(const Vec64<int64_t> v) {
return Vec64<int64_t>(vqabs_s64(v.raw));
#else
const auto zero = Zero(DFromV<decltype(v)>());
- return IfThenElse(MaskFromVec(BroadcastSignBit(v)), SaturatedSub(zero, v), v);
+ return IfNegativeThenElse(v, SaturatedSub(zero, v), v);
#endif
}
@@ -3562,6 +3696,28 @@ HWY_API Vec128<double> Max(Vec128<double> a, Vec128<double> b) {
HWY_NEON_DEF_FUNCTION_ALL_FLOATS(Max, vmax, _, 2)
#endif // HWY_ARCH_ARM_A64
+// ------------------------------ MinNumber and MaxNumber
+
+#if !HWY_ARCH_ARM_A64
+
+#ifdef HWY_NATIVE_FLOAT_MIN_MAX_NUMBER
+#undef HWY_NATIVE_FLOAT_MIN_MAX_NUMBER
+#else
+#define HWY_NATIVE_FLOAT_MIN_MAX_NUMBER
+#endif
+
+template <class V, HWY_IF_FLOAT_OR_SPECIAL_V(V)>
+HWY_API V MinNumber(V a, V b) {
+ return Min(IfThenElse(IsNaN(a), b, a), IfThenElse(IsNaN(b), a, b));
+}
+
+template <class V, HWY_IF_FLOAT_OR_SPECIAL_V(V)>
+HWY_API V MaxNumber(V a, V b) {
+ return Max(IfThenElse(IsNaN(a), b, a), IfThenElse(IsNaN(b), a, b));
+}
+
+#endif
+
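A scalar sketch of the intended semantics (assumption: IEEE-754 minimumNumber/maximumNumber, i.e. a single NaN input loses):

    #include <cmath>
    float MinNumberScalar(float a, float b) {
      if (std::isnan(a)) return b;  // IfThenElse(IsNaN(a), b, a)
      if (std::isnan(b)) return a;  // IfThenElse(IsNaN(b), a, b)
      return a < b ? a : b;         // Min of two non-NaN values
    }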
// ================================================== MEMORY
// ------------------------------ Load 128
@@ -7237,6 +7393,9 @@ static HWY_INLINE bfloat16x8_t BitCastToRawNeonBF16(bfloat16x8_t raw) {
// The uint16x4_t or uint16x8_t vector needs to be bitcast to a bfloat16x4_t
// or a bfloat16x8_t vector for the vbfdot_f32 and vbfdotq_f32 intrinsics if
// HWY_NEON_HAVE_F32_TO_BF16C && !HWY_NEON_HAVE_BFLOAT16 is true
+
+// NOTE: vbfdot uses round to odd unless the additional FEAT_EBF16 feature is
+// available and enabled.
static HWY_INLINE bfloat16x4_t BitCastToRawNeonBF16(uint16x4_t raw) {
return vreinterpret_bf16_u16(raw);
}
@@ -7302,6 +7461,12 @@ HWY_API VFromD<D> ReorderWidenMulAccumulate(
detail::BitCastToRawNeonBF16(b.raw)));
}
+template <size_t N>
+HWY_API Vec128<float, N> RearrangeToOddPlusEven(Vec128<float, N> sum0,
+ Vec128<float, N>) {
+ return sum0;
+}
+
#endif // HWY_NEON_HAVE_F32_TO_BF16C
template <class D, HWY_IF_I32_D(D)>
@@ -7403,16 +7568,30 @@ HWY_API VFromD<D> Combine(D d, VFromD<Half<D>> hi, VFromD<Half<D>> lo) {
// ------------------------------ RearrangeToOddPlusEven (Combine)
-template <size_t N>
-HWY_API Vec128<float, N> RearrangeToOddPlusEven(Vec128<float, N> sum0,
- Vec128<float, N> sum1) {
-#if HWY_NEON_HAVE_BFLOAT16
- (void)sum1; // unused by bf16 ReorderWidenMulAccumulate
- return sum0;
-#else
- return Add(sum0, sum1);
-#endif
-}
+namespace detail {
+// Armv7 only provides 64-bit (half-vector) pairwise operations.
+#define HWY_NEON_DEF_PAIRWISE_OP(T, name, prefix, suffix) \
+ HWY_INLINE Vec64<T> Pairwise##name(Vec64<T> a, Vec64<T> b) { \
+ return Vec64<T>(prefix##_##suffix(a.raw, b.raw)); \
+ }
+
+// Note that Armv7 also lacks [u]int64 instructions, which are handled by
+// generic_ops-inl.h SumOfLanes etc., hence no 64-bit overloads here.
+#define HWY_NEON_DEF_PAIRWISE_OPS(name, prefix) \
+ HWY_NEON_DEF_PAIRWISE_OP(uint32_t, name, prefix, u32) \
+ HWY_NEON_DEF_PAIRWISE_OP(uint16_t, name, prefix, u16) \
+ HWY_NEON_DEF_PAIRWISE_OP(uint8_t, name, prefix, u8) \
+ HWY_NEON_DEF_PAIRWISE_OP(int32_t, name, prefix, s32) \
+ HWY_NEON_DEF_PAIRWISE_OP(int16_t, name, prefix, s16) \
+ HWY_NEON_DEF_PAIRWISE_OP(int8_t, name, prefix, s8) \
+ HWY_NEON_DEF_PAIRWISE_OP(float32_t, name, prefix, f32)
+
+HWY_NEON_DEF_PAIRWISE_OPS(Sum, vpadd)
+HWY_NEON_DEF_PAIRWISE_OPS(Min, vpmin)
+HWY_NEON_DEF_PAIRWISE_OPS(Max, vpmax)
+#undef HWY_NEON_DEF_PAIRWISE_OPS
+#undef HWY_NEON_DEF_PAIRWISE_OP
+} // namespace detail
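For readers unfamiliar with the vpadd family, a scalar sketch of the pairwise contract on two 2-lane inputs (lane count illustrative):

    #include <array>
    #include <cstdint>
    // PairwiseSum(a, b) concatenates a and b, then combines adjacent pairs:
    // {a[0]+a[1], b[0]+b[1]}. PairwiseMin/Max replace + with min/max.
    std::array<int32_t, 2> PairwiseSum2(const std::array<int32_t, 2>& a,
                                        const std::array<int32_t, 2>& b) {
      return {a[0] + a[1], b[0] + b[1]};
    }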
HWY_API Vec128<int32_t> RearrangeToOddPlusEven(Vec128<int32_t> sum0,
Vec128<int32_t> sum1) {
@@ -7422,18 +7601,18 @@ HWY_API Vec128<int32_t> RearrangeToOddPlusEven(Vec128<int32_t> sum0,
#else
const Full128<int32_t> d;
const Half<decltype(d)> d64;
- const Vec64<int32_t> hi(
- vpadd_s32(LowerHalf(d64, sum1).raw, UpperHalf(d64, sum1).raw));
+ const Vec64<int32_t> hi =
+ detail::PairwiseSum(LowerHalf(d64, sum1), UpperHalf(d64, sum1));
const Vec64<int32_t> lo(
- vpadd_s32(LowerHalf(d64, sum0).raw, UpperHalf(d64, sum0).raw));
- return Combine(Full128<int32_t>(), hi, lo);
+ detail::PairwiseSum(LowerHalf(d64, sum0), UpperHalf(d64, sum0)));
+ return Combine(d, hi, lo);
#endif
}
HWY_API Vec64<int32_t> RearrangeToOddPlusEven(Vec64<int32_t> sum0,
Vec64<int32_t> sum1) {
// vmlal_s16 multiplied the lower half into sum0 and upper into sum1.
- return Vec64<int32_t>(vpadd_s32(sum0.raw, sum1.raw));
+ return detail::PairwiseSum(sum0, sum1);
}
HWY_API Vec32<int32_t> RearrangeToOddPlusEven(Vec32<int32_t> sum0,
@@ -7450,18 +7629,18 @@ HWY_API Vec128<uint32_t> RearrangeToOddPlusEven(Vec128<uint32_t> sum0,
#else
const Full128<uint32_t> d;
const Half<decltype(d)> d64;
- const Vec64<uint32_t> hi(
- vpadd_u32(LowerHalf(d64, sum1).raw, UpperHalf(d64, sum1).raw));
- const Vec64<uint32_t> lo(
- vpadd_u32(LowerHalf(d64, sum0).raw, UpperHalf(d64, sum0).raw));
- return Combine(Full128<uint32_t>(), hi, lo);
+ const Vec64<uint32_t> hi =
+ detail::PairwiseSum(LowerHalf(d64, sum1), UpperHalf(d64, sum1));
+ const Vec64<uint32_t> lo =
+ detail::PairwiseSum(LowerHalf(d64, sum0), UpperHalf(d64, sum0));
+ return Combine(d, hi, lo);
#endif
}
HWY_API Vec64<uint32_t> RearrangeToOddPlusEven(Vec64<uint32_t> sum0,
Vec64<uint32_t> sum1) {
// vmlal_u16 multiplied the lower half into sum0 and upper into sum1.
- return Vec64<uint32_t>(vpadd_u32(sum0.raw, sum1.raw));
+ return detail::PairwiseSum(sum0, sum1);
}
HWY_API Vec32<uint32_t> RearrangeToOddPlusEven(Vec32<uint32_t> sum0,
@@ -7472,7 +7651,8 @@ HWY_API Vec32<uint32_t> RearrangeToOddPlusEven(Vec32<uint32_t> sum0,
// ------------------------------ SumOfMulQuadAccumulate
-#if HWY_TARGET == HWY_NEON_BF16
+
+#if HWY_TARGET == HWY_NEON_BF16 || defined(__ARM_FEATURE_DOTPROD)
#ifdef HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE
#undef HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE
@@ -7516,18 +7696,38 @@ HWY_API VFromD<DU32> SumOfMulQuadAccumulate(
return VFromD<DU32>(vdotq_u32(sum.raw, a.raw, b.raw));
}
+#endif // __ARM_FEATURE_DOTPROD || HWY_TARGET == HWY_NEON_BF16
+
#ifdef HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE
#undef HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE
#else
#define HWY_NATIVE_U8_I8_SUMOFMULQUADACCUMULATE
#endif
+#if defined(__ARM_FEATURE_MATMUL_INT8) || \
+ (HWY_TARGET == HWY_NEON_BF16 && HWY_OS_APPLE && HWY_ARCH_ARM_A64 && \
+ HWY_HAVE_RUNTIME_DISPATCH)
+
+template <class DI32, HWY_IF_I32_D(DI32), HWY_IF_V_SIZE_LE_D(DI32, 8)>
+HWY_API VFromD<DI32> SumOfMulQuadAccumulate(
+ DI32 /*di32*/, VFromD<Repartition<uint8_t, DI32>> a_u,
+ VFromD<Repartition<int8_t, DI32>> b_i, VFromD<DI32> sum) {
+ return VFromD<DI32>(vusdot_s32(sum.raw, a_u.raw, b_i.raw));
+}
+
+template <class DI32, HWY_IF_I32_D(DI32), HWY_IF_V_SIZE_D(DI32, 16)>
+HWY_API VFromD<DI32> SumOfMulQuadAccumulate(
+ DI32 /*di32*/, VFromD<Repartition<uint8_t, DI32>> a_u,
+ VFromD<Repartition<int8_t, DI32>> b_i, VFromD<DI32> sum) {
+ return VFromD<DI32>(vusdotq_s32(sum.raw, a_u.raw, b_i.raw));
+}
+
+#else
+
template <class DI32, HWY_IF_I32_D(DI32)>
HWY_API VFromD<DI32> SumOfMulQuadAccumulate(
DI32 di32, VFromD<Repartition<uint8_t, DI32>> a_u,
VFromD<Repartition<int8_t, DI32>> b_i, VFromD<DI32> sum) {
- // TODO: use vusdot[q]_s32 on NEON targets that require support for NEON I8MM
-
const RebindToUnsigned<decltype(di32)> du32;
const Repartition<uint8_t, decltype(di32)> du8;
@@ -7540,7 +7740,7 @@ HWY_API VFromD<DI32> SumOfMulQuadAccumulate(
return BitCast(di32, Sub(result_sum0, result_sum1));
}
-#endif // HWY_TARGET == HWY_NEON_BF16
+#endif // __ARM_FEATURE_MATMUL_INT8
// ------------------------------ WidenMulPairwiseAdd
@@ -7959,6 +8159,17 @@ HWY_API V InterleaveOddBlocks(D, V a, V /*b*/) {
return a;
}
+// ------------------------------ InterleaveLowerBlocks
+template <class D, class V = VFromD<D>>
+HWY_API V InterleaveLowerBlocks(D, V a, V /*b*/) {
+ return a;
+}
+// ------------------------------ InterleaveUpperBlocks
+template <class D, class V = VFromD<D>>
+HWY_API V InterleaveUpperBlocks(D, V a, V /*b*/) {
+ return a;
+}
+
// ------------------------------ ReverseBlocks
// Single block: no change
template <class D, HWY_IF_V_SIZE_LE_D(D, 16)>
@@ -8745,71 +8956,47 @@ HWY_API VFromD<D> MaxOfLanes(D d, VFromD<D> v) {
// On Armv7 we define SumOfLanes and generic_ops defines ReduceSum via GetLane.
#else // !HWY_ARCH_ARM_A64
-// Armv7 lacks N=2 and 8-bit x4, so enable generic versions of those.
+// Armv7 lacks N=2 (except 32-bit) and 8-bit x4, so enable them in generic_ops.
#undef HWY_IF_SUM_OF_LANES_D
#define HWY_IF_SUM_OF_LANES_D(D) \
- hwy::EnableIf<(HWY_MAX_LANES_D(D) == 2) || \
+ hwy::EnableIf<(sizeof(TFromD<D>) != 4 && HWY_MAX_LANES_D(D) == 2) || \
(sizeof(TFromD<D>) == 1 && HWY_MAX_LANES_D(D) == 4)>* = \
nullptr
#undef HWY_IF_MINMAX_OF_LANES_D
#define HWY_IF_MINMAX_OF_LANES_D(D) \
- hwy::EnableIf<(HWY_MAX_LANES_D(D) == 2) || \
+ hwy::EnableIf<(sizeof(TFromD<D>) != 4 && HWY_MAX_LANES_D(D) == 2) || \
(sizeof(TFromD<D>) == 1 && HWY_MAX_LANES_D(D) == 4)>* = \
nullptr
// For Armv7, we implement reductions using a series of pairwise operations. This
// produces the full vector result, so we express Reduce* in terms of *OfLanes.
-#define HWY_NEON_BUILD_TYPE_T(type, size) type##x##size##_t
-#define HWY_NEON_DEF_PAIRWISE_REDUCTION(type, size, name, prefix, suffix) \
- template <class D, HWY_IF_LANES_D(D, size)> \
- HWY_API Vec128<type##_t, size> name##OfLanes(D /* d */, \
- Vec128<type##_t, size> v) { \
- HWY_NEON_BUILD_TYPE_T(type, size) tmp = prefix##_##suffix(v.raw, v.raw); \
- if ((size / 2) > 1) tmp = prefix##_##suffix(tmp, tmp); \
- if ((size / 4) > 1) tmp = prefix##_##suffix(tmp, tmp); \
- return Vec128<type##_t, size>(tmp); \
- }
-// For the wide versions, the pairwise operations produce a half-length vector.
-// We produce that `tmp` and then Combine.
-#define HWY_NEON_DEF_WIDE_PAIRWISE_REDUCTION(type, size, half, name, prefix, \
- suffix) \
- template <class D, HWY_IF_LANES_D(D, size)> \
- HWY_API Vec128<type##_t, size> name##OfLanes(D /* d */, \
- Vec128<type##_t, size> v) { \
- HWY_NEON_BUILD_TYPE_T(type, half) tmp; \
- tmp = prefix##_##suffix(vget_high_##suffix(v.raw), \
- vget_low_##suffix(v.raw)); \
- if ((size / 2) > 1) tmp = prefix##_##suffix(tmp, tmp); \
- if ((size / 4) > 1) tmp = prefix##_##suffix(tmp, tmp); \
- if ((size / 8) > 1) tmp = prefix##_##suffix(tmp, tmp); \
- return Vec128<type##_t, size>(vcombine_##suffix(tmp, tmp)); \
+#define HWY_NEON_DEF_PAIRWISE_REDUCTION(name) \
+ /* generic_ops-inl.h handles 64-bit types. */ \
+ template <class D, HWY_IF_V_SIZE_D(D, 8), HWY_IF_NOT_T_SIZE_D(D, 8)> \
+ HWY_API VFromD<D> name##OfLanes(D d, VFromD<D> v) { \
+ HWY_LANES_CONSTEXPR size_t N = Lanes(d); \
+ VFromD<D> tmp = detail::Pairwise##name(v, v); \
+ if ((N / 2) > 1) tmp = detail::Pairwise##name(tmp, tmp); \
+ if ((N / 4) > 1) tmp = detail::Pairwise##name(tmp, tmp); \
+ return tmp; \
+ } \
+ /* Armv7 lacks q (full-vector) instructions, so first reduce 128-bit v */ \
+ /* into a half-vector, then reduce that. */ \
+ template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_NOT_T_SIZE_D(D, 8)> \
+ HWY_API VFromD<D> name##OfLanes(D d, VFromD<D> v) { \
+ const Half<D> dh; \
+ VFromD<decltype(dh)> upper = UpperHalf(dh, v); \
+ VFromD<decltype(dh)> lower = LowerHalf(dh, v); \
+ VFromD<decltype(dh)> half = detail::Pairwise##name(upper, lower); \
+ half = name##OfLanes(dh, half); \
+ return Combine(d, half, half); \
}
-#define HWY_NEON_DEF_PAIRWISE_REDUCTIONS(name, prefix) \
- HWY_NEON_DEF_PAIRWISE_REDUCTION(uint32, 2, name, prefix, u32) \
- HWY_NEON_DEF_PAIRWISE_REDUCTION(uint16, 4, name, prefix, u16) \
- HWY_NEON_DEF_PAIRWISE_REDUCTION(uint8, 8, name, prefix, u8) \
- HWY_NEON_DEF_PAIRWISE_REDUCTION(int32, 2, name, prefix, s32) \
- HWY_NEON_DEF_PAIRWISE_REDUCTION(int16, 4, name, prefix, s16) \
- HWY_NEON_DEF_PAIRWISE_REDUCTION(int8, 8, name, prefix, s8) \
- HWY_NEON_DEF_PAIRWISE_REDUCTION(float32, 2, name, prefix, f32) \
- HWY_NEON_DEF_WIDE_PAIRWISE_REDUCTION(uint32, 4, 2, name, prefix, u32) \
- HWY_NEON_DEF_WIDE_PAIRWISE_REDUCTION(uint16, 8, 4, name, prefix, u16) \
- HWY_NEON_DEF_WIDE_PAIRWISE_REDUCTION(uint8, 16, 8, name, prefix, u8) \
- HWY_NEON_DEF_WIDE_PAIRWISE_REDUCTION(int32, 4, 2, name, prefix, s32) \
- HWY_NEON_DEF_WIDE_PAIRWISE_REDUCTION(int16, 8, 4, name, prefix, s16) \
- HWY_NEON_DEF_WIDE_PAIRWISE_REDUCTION(int8, 16, 8, name, prefix, s8) \
- HWY_NEON_DEF_WIDE_PAIRWISE_REDUCTION(float32, 4, 2, name, prefix, f32)
-
-HWY_NEON_DEF_PAIRWISE_REDUCTIONS(Sum, vpadd)
-HWY_NEON_DEF_PAIRWISE_REDUCTIONS(Min, vpmin)
-HWY_NEON_DEF_PAIRWISE_REDUCTIONS(Max, vpmax)
-
-#undef HWY_NEON_DEF_PAIRWISE_REDUCTIONS
-#undef HWY_NEON_DEF_WIDE_PAIRWISE_REDUCTION
+HWY_NEON_DEF_PAIRWISE_REDUCTION(Sum)
+HWY_NEON_DEF_PAIRWISE_REDUCTION(Min)
+HWY_NEON_DEF_PAIRWISE_REDUCTION(Max)
#undef HWY_NEON_DEF_PAIRWISE_REDUCTION
-#undef HWY_NEON_BUILD_TYPE_T
// GetLane(SumsOf4(v)) is more efficient on Armv7 NEON than the default
// N=4 I8/U8 ReduceSum implementation in generic_ops-inl.h
diff --git a/third_party/highway/hwy/ops/arm_sve-inl.h b/third_party/highway/hwy/ops/arm_sve-inl.h
index 87f0e49996..ed6529f96f 100644
--- a/third_party/highway/hwy/ops/arm_sve-inl.h
+++ b/third_party/highway/hwy/ops/arm_sve-inl.h
@@ -1,4 +1,5 @@
// Copyright 2021 Google LLC
+// Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
@@ -33,14 +34,6 @@
#define HWY_SVE_HAVE_2 0
#endif
-// If 1, both __bf16 and a limited set of *_bf16 SVE intrinsics are available:
-// create/get/set/dup, ld/st, sel, rev, trn, uzp, zip.
-#if HWY_ARM_HAVE_SCALAR_BF16_TYPE && defined(__ARM_FEATURE_SVE_BF16)
-#define HWY_SVE_HAVE_BF16_FEATURE 1
-#else
-#define HWY_SVE_HAVE_BF16_FEATURE 0
-#endif
-
// HWY_SVE_HAVE_BF16_VEC is defined to 1 if the SVE svbfloat16_t vector type
// is supported, even if HWY_SVE_HAVE_BF16_FEATURE (= intrinsics) is 0.
#if HWY_SVE_HAVE_BF16_FEATURE || \
@@ -389,9 +382,18 @@ HWY_API svbool_t PFalse() { return svpfalse_b(); }
//
// This is used in functions that load/store memory; other functions (e.g.
// arithmetic) can ignore d and use PTrue instead.
+//
+// Always use FirstN(N) for HWY_TARGET == HWY_SVE2_128 to avoid vector length
+// information loss when using PTrue(d) predicates in memory intrinsics.
+//
+// SVE2_256 is untested because no hardware is available; unlike SVE2_128, it
+// cannot assume that the minimum and maximum vector lengths are equal.
template <class D>
svbool_t MakeMask(D d) {
- return IsFull(d) ? PTrue(d) : FirstN(d, Lanes(d));
+#if HWY_TARGET != HWY_SVE2_128
+ HWY_IF_CONSTEXPR(IsFull(d)) { return PTrue(d); }
+#endif
+ return FirstN(d, Lanes(d));
}
} // namespace detail
@@ -407,6 +409,20 @@ HWY_API svbool_t MaskFalse(const D /*d*/) {
return detail::PFalse();
}
+#ifdef HWY_NATIVE_SET_MASK
+#undef HWY_NATIVE_SET_MASK
+#else
+#define HWY_NATIVE_SET_MASK
+#endif
+
+template <class D>
+HWY_API svbool_t SetMask(D d, bool val) {
+ // The SVE svdup_n_b* intrinsics are equivalent to the FirstN op below if
+ // detail::IsFull(d) is true since svdup_n_b* is simply a wrapper around the
+ // SVE whilelo instruction.
+ return FirstN(d, size_t{0} - static_cast<size_t>(val));
+}
+
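The unsigned wraparound is intentional; a scalar sketch of the lane count passed to FirstN:

    #include <cstddef>
    // val=true wraps to SIZE_MAX (>= any lane count, so every lane is active);
    // val=false yields 0 (no lanes active).
    size_t SetMaskFirstNCount(bool val) {
      return size_t{0} - static_cast<size_t>(val);
    }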
// ================================================== INIT
// ------------------------------ Set
@@ -848,8 +864,31 @@ HWY_API V AndNot(const V a, const V b) {
#if HWY_SVE_HAVE_2
+#ifdef HWY_NATIVE_XOR3
+#undef HWY_NATIVE_XOR3
+#else
+#define HWY_NATIVE_XOR3
+#endif
+
+#ifdef HWY_NATIVE_BCAX
+#undef HWY_NATIVE_BCAX
+#else
+#define HWY_NATIVE_BCAX
+#endif
+
HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGVVV, Xor3, eor3)
+// As with AndNot, we follow the x86 convention where the first argument is
+// negated, whereas the intrinsic has it last.
+#define HWY_SVE_RETV_ARGVVV_SWAP(BASE, CHAR, BITS, HALF, NAME, OP) \
+ HWY_API HWY_SVE_V(BASE, BITS) \
+ NAME(HWY_SVE_V(BASE, BITS) a, HWY_SVE_V(BASE, BITS) b, \
+ HWY_SVE_V(BASE, BITS) c) { \
+ return sv##OP##_##CHAR##BITS(a, c, b); \
+ }
+HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGVVV_SWAP, XorAndNot, bcax)
+#undef HWY_SVE_RETV_ARGVVV_SWAP
+
template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V Xor3(const V x1, const V x2, const V x3) {
const DFromV<V> df;
@@ -857,12 +896,15 @@ HWY_API V Xor3(const V x1, const V x2, const V x3) {
return BitCast(df, Xor3(BitCast(du, x1), BitCast(du, x2), BitCast(du, x3)));
}
-#else
-template <class V>
-HWY_API V Xor3(V x1, V x2, V x3) {
- return Xor(x1, Xor(x2, x3));
+template <class V, HWY_IF_FLOAT_V(V)>
+HWY_API V XorAndNot(const V x, const V a1, const V a2) {
+ const DFromV<V> df;
+ const RebindToUnsigned<decltype(df)> du;
+ return BitCast(df,
+ XorAndNot(BitCast(du, x), BitCast(du, a1), BitCast(du, a2)));
}
-#endif
+
+#endif // HWY_SVE_HAVE_2
// ------------------------------ Or3
template <class V>
@@ -1645,8 +1687,10 @@ HWY_API svbool_t LowerHalfOfMask(D /*d*/, svbool_t m) {
#endif
namespace detail {
-HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV, MaskedMin, min)
-HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV, MaskedMax, max)
+HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGMVV, MaskedMin, minnm)
+HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGMVV, MaskedMax, maxnm)
+HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGMVV, MaskedMin, min)
+HWY_SVE_FOREACH_UI(HWY_SVE_RETV_ARGMVV, MaskedMax, max)
HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV, MaskedAdd, add)
HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV, MaskedSub, sub)
HWY_SVE_FOREACH(HWY_SVE_RETV_ARGMVV, MaskedMul, mul)
@@ -1971,9 +2015,8 @@ HWY_API V Min(V a, V b) {
}
template <class V, HWY_IF_FLOAT_OR_SPECIAL_V(V)>
HWY_API V Min(V a, V b) {
- return IfThenElse(Lt(a, b), a, b);
+ return IfThenElse(Or(Lt(a, b), Ne(b, b)), a, b);
}
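A scalar sketch of the fixed predicate (Ne(b, b) is true only when b is NaN):

    #include <cmath>
    // Returns a when a < b or b is NaN; otherwise b. A lone NaN thus never
    // wins, matching the FMINNM-style semantics of the native float path.
    float MinWithNaNFix(float a, float b) {
      return (a < b || std::isnan(b)) ? a : b;
    }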
-
#else
HWY_SVE_FOREACH_I(HWY_SVE_RETV_ARGPVV, Min, min)
HWY_SVE_FOREACH_F(HWY_SVE_RETV_ARGPVV, Min, minnm)
@@ -2469,12 +2512,13 @@ HWY_API VFromD<D> LoadDup128(D d, const TFromD<D>* HWY_RESTRICT p) {
#define HWY_NATIVE_STORE_TRUNCATED
#endif
-#define HWY_SVE_STORE_TRUNCATED(BASE, CHAR, BITS, HALF, NAME, OP, TO_BITS) \
- template <size_t N, int kPow2> \
- HWY_API void NAME(HWY_SVE_V(BASE, BITS) v, \
- const HWY_SVE_D(BASE, BITS, N, kPow2) d, \
- HWY_SVE_T(BASE, TO_BITS) * HWY_RESTRICT p) { \
- sv##OP##_##CHAR##BITS(detail::PTrue(d), detail::NativeLanePointer(p), v); \
+#define HWY_SVE_STORE_TRUNCATED(BASE, CHAR, BITS, HALF, NAME, OP, TO_BITS) \
+ template <size_t N, int kPow2> \
+ HWY_API void NAME(HWY_SVE_V(BASE, BITS) v, \
+ const HWY_SVE_D(BASE, BITS, N, kPow2) d, \
+ HWY_SVE_T(BASE, TO_BITS) * HWY_RESTRICT p) { \
+ sv##OP##_##CHAR##BITS(detail::MakeMask(d), detail::NativeLanePointer(p), \
+ v); \
}
#define HWY_SVE_STORE_TRUNCATED_BYTE(BASE, CHAR, BITS, HALF, NAME, OP) \
@@ -3540,10 +3584,10 @@ HWY_API V InterleaveUpper(D d, const V a, const V b) {
}
// ------------------------------ InterleaveWholeLower
-#ifdef HWY_NATIVE_INTERLEAVE_WHOLE
-#undef HWY_NATIVE_INTERLEAVE_WHOLE
+#ifdef HWY_TOGGLE_INTERLEAVE_WHOLE
+#undef HWY_TOGGLE_INTERLEAVE_WHOLE
#else
-#define HWY_NATIVE_INTERLEAVE_WHOLE
+#define HWY_TOGGLE_INTERLEAVE_WHOLE
#endif
template <class D>
@@ -4077,6 +4121,56 @@ HWY_API V InterleaveOddBlocks(D d, V a, V b) {
#endif
}
+// ------------------------------ InterleaveLowerBlocks
+// (InterleaveEvenBlocks)
+
+template <class D, class V = VFromD<D>>
+HWY_API V InterleaveLowerBlocks(D d, V a, V b) {
+#if HWY_TARGET == HWY_SVE_256
+ return InterleaveEvenBlocks(d, a, b);
+#elif HWY_TARGET == HWY_SVE2_128
+ (void)d;
+ (void)b;
+ return a;
+#else
+ const Repartition<uint64_t, decltype(d)> du64;
+ const svuint64_t a64 = BitCast(du64, a);
+ const svuint64_t b64 = BitCast(du64, b);
+ svuint64_t even = detail::InterleaveEven(a64, b64); // a0 b0 a2 b2
+ svuint64_t odd = detail::InterleaveOdd(a64, b64); // a1 b1 a3 b3
+ return BitCast(d, detail::ZipLowerSame(even, odd)); // a10 b10
+#endif
+}
+
+// ------------------------------ InterleaveUpperBlocks
+// (ConcatUpperUpper, SlideDownLanes, OddEvenBlocks)
+
+template <class D, class V = VFromD<D>>
+HWY_API V InterleaveUpperBlocks(D d, V a, V b) {
+#if HWY_TARGET == HWY_SVE_256
+ return InterleaveOddBlocks(d, a, b);
+#elif HWY_TARGET == HWY_SVE2_128
+ (void)d;
+ (void)b;
+ return a;
+#else
+ const Repartition<uint64_t, decltype(d)> du64;
+ const svuint64_t a64 = BitCast(du64, a);
+ const svuint64_t b64 = BitCast(du64, b);
+ svuint64_t even = detail::InterleaveEven(a64, b64); // a0 b0 a2 b2
+ svuint64_t odd = detail::InterleaveOdd(a64, b64); // a1 b1 a3 b3
+ HWY_IF_CONSTEXPR(detail::IsFull(d)) {
+ return BitCast(d, detail::ZipUpperSame(even, odd)); // a32 b32
+ }
+ // ZipUpperSame assumes full vectors; instead use UpperHalf to honor the
+ // capped/fractional tag.
+ const Half<decltype(du64)> du64h;
+ even = ResizeBitCast(du64, UpperHalf(du64h, even));
+ odd = ResizeBitCast(du64, UpperHalf(du64h, odd));
+ return BitCast(d, detail::ZipLowerSame(even, odd));
+#endif // HWY_TARGET
+}
+
// ------------------------------ Reverse
namespace detail {
@@ -5302,7 +5396,7 @@ HWY_API V AverageRound(const V a, const V b) {
// `p` points to at least 8 readable bytes, not all of which need be valid.
template <class D, HWY_IF_T_SIZE_D(D, 1)>
HWY_INLINE svbool_t LoadMaskBits(D d, const uint8_t* HWY_RESTRICT bits) {
-#if HWY_COMPILER_CLANG >= 1901 || HWY_COMPILER_GCC_ACTUAL >= 1200
+#if HWY_COMPILER_CLANG >= 2200 || HWY_COMPILER_GCC_ACTUAL >= 1200
typedef svbool_t UnalignedSveMaskT
__attribute__((__aligned__(1), __may_alias__));
(void)d;
@@ -5465,12 +5559,11 @@ HWY_API size_t StoreMaskBits(D d, svbool_t m, uint8_t* bits) {
HWY_IF_CONSTEXPR(N < 8) {
// BitsFromMask guarantees upper bits are zero, hence no masking.
bits[0] = static_cast<uint8_t>(bits64);
+ return 1;
}
- else {
- static_assert(N % 8 == 0, "N is pow2 >= 8, hence divisible");
- static_assert(HWY_IS_LITTLE_ENDIAN, "");
- hwy::CopyBytes<N / 8>(&bits64, bits);
- }
+ static_assert(N < 8 || N % 8 == 0, "N is pow2; if >= 8, divisible by 8");
+ static_assert(HWY_IS_LITTLE_ENDIAN, "");
+ hwy::CopyBytes<N / 8>(&bits64, bits);
constexpr size_t num_bytes = hwy::DivCeil(N, size_t{8});
return num_bytes;
#else
@@ -6344,42 +6437,126 @@ HWY_API V PairwiseAdd(D d, V a, V b) {
#endif // HWY_SVE_HAVE_2
#endif // HWY_TARGET != HWY_SCALAR
-// ------------------------------ WidenMulPairwiseAdd
+// ------------------------------ MulEvenAdd/MulOddAdd (PromoteEvenTo)
+
+// Always implemented here because this op is used in WidenMulEven.
+#ifdef HWY_NATIVE_MUL_EVEN_BF16
+#undef HWY_NATIVE_MUL_EVEN_BF16
+#else
+#define HWY_NATIVE_MUL_EVEN_BF16
+#endif
template <size_t N, int kPow2>
-HWY_API svfloat32_t WidenMulPairwiseAdd(Simd<float, N, kPow2> df, VBF16 a,
- VBF16 b) {
-#if HWY_SVE_HAVE_F32_TO_BF16C
- const svfloat32_t even = svbfmlalb_f32(Zero(df), a, b);
- return svbfmlalt_f32(even, a, b);
+HWY_API svfloat32_t MulEvenAdd(Simd<float, N, kPow2> dw, VBF16 a, VBF16 b,
+ const svfloat32_t c) {
+#if HWY_SVE_HAVE_BF16_FEATURE
+ return svbfmlalb_f32(c, a, b);
#else
- return MulAdd(PromoteEvenTo(df, a), PromoteEvenTo(df, b),
- Mul(PromoteOddTo(df, a), PromoteOddTo(df, b)));
+ return MulAdd(PromoteEvenTo(dw, a), PromoteEvenTo(dw, b), c);
#endif // HWY_SVE_HAVE_BF16_FEATURE
}
template <size_t N, int kPow2>
-HWY_API svint32_t WidenMulPairwiseAdd(Simd<int32_t, N, kPow2> d32, svint16_t a,
- svint16_t b) {
+HWY_API svfloat32_t MulOddAdd(Simd<float, N, kPow2> dw, VBF16 a, VBF16 b,
+ const svfloat32_t c) {
+#if HWY_SVE_HAVE_BF16_FEATURE
+ (void)dw;
+ return svbfmlalt_f32(c, a, b);
+#else
+ return MulAdd(PromoteOddTo(dw, a), PromoteOddTo(dw, b), c);
+#endif // HWY_SVE_HAVE_BF16_FEATURE
+}
+
+// Highway API only guarantees support for bf16*bf16+f32; also implement UI32
+// to allow reusing this op for integer WidenMulPairwiseAdd:
+
+template <size_t N, int kPow2>
+HWY_API svint32_t MulEvenAdd(Simd<int32_t, N, kPow2> dw, svint16_t a,
+ svint16_t b, const svint32_t c) {
#if HWY_SVE_HAVE_2
- (void)d32;
- return svmlalt_s32(svmullb_s32(a, b), a, b);
+ (void)dw;
+ return svmlalb_s32(c, a, b);
#else
- return MulAdd(PromoteEvenTo(d32, a), PromoteEvenTo(d32, b),
- Mul(PromoteOddTo(d32, a), PromoteOddTo(d32, b)));
-#endif
+ return MulAdd(PromoteEvenTo(dw, a), PromoteEvenTo(dw, b), c);
+#endif // HWY_SVE_HAVE_2
}
template <size_t N, int kPow2>
-HWY_API svuint32_t WidenMulPairwiseAdd(Simd<uint32_t, N, kPow2> d32,
- svuint16_t a, svuint16_t b) {
+HWY_API svuint32_t MulEvenAdd(Simd<uint32_t, N, kPow2> dw, svuint16_t a,
+ svuint16_t b, const svuint32_t c) {
#if HWY_SVE_HAVE_2
- (void)d32;
- return svmlalt_u32(svmullb_u32(a, b), a, b);
+ (void)dw;
+ return svmlalb_u32(c, a, b);
#else
- return MulAdd(PromoteEvenTo(d32, a), PromoteEvenTo(d32, b),
- Mul(PromoteOddTo(d32, a), PromoteOddTo(d32, b)));
-#endif
+ return MulAdd(PromoteEvenTo(dw, a), PromoteEvenTo(dw, b), c);
+#endif // HWY_SVE_HAVE_2
+}
+
+template <size_t N, int kPow2>
+HWY_API svint32_t MulOddAdd(Simd<int32_t, N, kPow2> dw, svint16_t a,
+ svint16_t b, const svint32_t c) {
+#if HWY_SVE_HAVE_2
+ (void)dw;
+ return svmlalt_s32(c, a, b);
+#else
+ return MulAdd(PromoteOddTo(dw, a), PromoteOddTo(dw, b), c);
+#endif // HWY_SVE_HAVE_2
+}
+
+template <size_t N, int kPow2>
+HWY_API svuint32_t MulOddAdd(Simd<uint32_t, N, kPow2> dw, svuint16_t a,
+ svuint16_t b, const svuint32_t c) {
+#if HWY_SVE_HAVE_2
+ (void)dw;
+ return svmlalt_u32(c, a, b);
+#else
+ return MulAdd(PromoteOddTo(dw, a), PromoteOddTo(dw, b), c);
+#endif // HWY_SVE_HAVE_2
+}
+
+// ------------------------------ WidenMulEven (MulEvenAdd, PromoteEvenTo)
+
+template <size_t N, int kPow2>
+HWY_API svfloat32_t WidenMulEven(Simd<float, N, kPow2> dw, VBF16 a, VBF16 b) {
+#if HWY_SVE_HAVE_BF16_FEATURE
+ return MulEvenAdd(dw, a, b, Zero(dw));
+#else
+ // Same as MulEvenAdd, but without generating a zero argument.
+ return Mul(PromoteEvenTo(dw, a), PromoteEvenTo(dw, b));
+#endif // HWY_SVE_HAVE_BF16_FEATURE
+}
+
+template <size_t N, int kPow2>
+HWY_API svint32_t WidenMulEven(Simd<int32_t, N, kPow2> dw, svint16_t a,
+ svint16_t b) {
+#if HWY_SVE_HAVE_2
+ (void)dw;
+ return svmullb_s32(a, b);
+#else
+ // Same as MulEvenAdd, but without generating a zero argument.
+ return Mul(PromoteEvenTo(dw, a), PromoteEvenTo(dw, b));
+#endif // HWY_SVE_HAVE_2
+}
+
+template <size_t N, int kPow2>
+HWY_API svuint32_t WidenMulEven(Simd<uint32_t, N, kPow2> dw, svuint16_t a,
+ svuint16_t b) {
+#if HWY_SVE_HAVE_2
+ (void)dw;
+ return svmullb_u32(a, b);
+#else
+ // Same as MulEvenAdd, but without generating a zero argument.
+ return Mul(PromoteEvenTo(dw, a), PromoteEvenTo(dw, b));
+#endif // HWY_SVE_HAVE_2
+}
+
+// ------------------------------ WidenMulPairwiseAdd (WidenMulEven, MulOddAdd)
+// Deduce from VN: RepartitionToNarrow could be either F16 or BF16 for float.
+template <class VN, class DN = DFromV<VN>, class DW = RepartitionToWide<DN>,
+ class VW = VFromD<DW>>
+HWY_API VW WidenMulPairwiseAdd(DW dw, VN a, VN b) {
+ return MulOddAdd(dw, a, b, WidenMulEven(dw, a, b));
}
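Per wide output lane, the decomposition computes the following (scalar sketch of the i16-to-i32 case; names illustrative):

    #include <cstddef>
    #include <cstdint>
    // Wide lane i = a[2i]*b[2i] + a[2i+1]*b[2i+1]: the even products come from
    // WidenMulEven, then MulOddAdd folds in the odd products.
    int32_t WidenMulPairwiseAddLane(const int16_t* a, const int16_t* b,
                                    size_t i) {
      const int32_t even = int32_t{a[2 * i]} * int32_t{b[2 * i]};
      return int32_t{a[2 * i + 1]} * int32_t{b[2 * i + 1]} + even;
    }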
// ------------------------------ SatWidenMulPairwiseAccumulate
@@ -6426,69 +6603,47 @@ HWY_API VFromD<DI32> SatWidenMulAccumFixedPoint(DI32 /*di32*/,
#endif // HWY_SVE_HAVE_2
-// ------------------------------ ReorderWidenMulAccumulate (MulAdd, ZipLower)
+// ------------------------------ ReorderWidenMulAccumulate (MulOddEven)
#if HWY_SVE_HAVE_BF16_FEATURE
-// NOTE: we currently do not use SVE BFDOT for bf16 ReorderWidenMulAccumulate
-// because, apparently unlike NEON, it uses round to odd unless the additional
-// FEAT_EBF16 feature is available and enabled.
-#ifdef HWY_NATIVE_MUL_EVEN_BF16
-#undef HWY_NATIVE_MUL_EVEN_BF16
+#ifdef HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16
+#undef HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16
#else
-#define HWY_NATIVE_MUL_EVEN_BF16
+#define HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16
#endif
+// NOTE: svbfdot uses round to odd unless the additional FEAT_EBF16 feature is
+// available and enabled.
template <size_t N, int kPow2>
-HWY_API svfloat32_t MulEvenAdd(Simd<float, N, kPow2> /* d */, VBF16 a, VBF16 b,
- const svfloat32_t c) {
- return svbfmlalb_f32(c, a, b);
+HWY_API svfloat32_t ReorderWidenMulAccumulate(Simd<float, N, kPow2> d32,
+ VBF16 a, VBF16 b,
+ const svfloat32_t sum0,
+ svfloat32_t& sum1) {
+ (void)d32;
+ (void)sum1;
+ return svbfdot_f32(sum0, a, b);
}
-template <size_t N, int kPow2>
-HWY_API svfloat32_t MulOddAdd(Simd<float, N, kPow2> /* d */, VBF16 a, VBF16 b,
- const svfloat32_t c) {
- return svbfmlalt_f32(c, a, b);
+template <class VW, HWY_IF_FLOAT_V(VW)>
+HWY_API VW RearrangeToOddPlusEven(const VW sum0, const VW) {
+ // sum1 is unused and the invariant already holds.
+ return sum0;
}
#endif // HWY_SVE_HAVE_BF16_FEATURE
-template <size_t N, int kPow2>
-HWY_API svint32_t ReorderWidenMulAccumulate(Simd<int32_t, N, kPow2> d32,
- svint16_t a, svint16_t b,
- const svint32_t sum0,
- svint32_t& sum1) {
-#if HWY_SVE_HAVE_2
- (void)d32;
- sum1 = svmlalt_s32(sum1, a, b);
- return svmlalb_s32(sum0, a, b);
-#else
- // Lane order within sum0/1 is undefined, hence we can avoid the
- // longer-latency lane-crossing PromoteTo by using PromoteEvenTo.
- sum1 = MulAdd(PromoteOddTo(d32, a), PromoteOddTo(d32, b), sum1);
- return MulAdd(PromoteEvenTo(d32, a), PromoteEvenTo(d32, b), sum0);
-#endif
-}
-
-template <size_t N, int kPow2>
-HWY_API svuint32_t ReorderWidenMulAccumulate(Simd<uint32_t, N, kPow2> d32,
- svuint16_t a, svuint16_t b,
- const svuint32_t sum0,
- svuint32_t& sum1) {
-#if HWY_SVE_HAVE_2
- (void)d32;
- sum1 = svmlalt_u32(sum1, a, b);
- return svmlalb_u32(sum0, a, b);
-#else
- // Lane order within sum0/1 is undefined, hence we can avoid the
- // longer-latency lane-crossing PromoteTo by using PromoteEvenTo.
- sum1 = MulAdd(PromoteOddTo(d32, a), PromoteOddTo(d32, b), sum1);
- return MulAdd(PromoteEvenTo(d32, a), PromoteEvenTo(d32, b), sum0);
-#endif
+// F32 version is implemented above, or in generic_ops-inl.h.
+template <class VN, class DN = DFromV<VN>, class DW = RepartitionToWide<DN>,
+ class VW = VFromD<DW>, HWY_IF_NOT_FLOAT_D(DW)>
+HWY_API VW ReorderWidenMulAccumulate(DW dw, VN a, VN b, const VW sum0,
+ VW& sum1) {
+ sum1 = MulOddAdd(dw, a, b, sum1);
+ return MulEvenAdd(dw, a, b, sum0);
}
// ------------------------------ RearrangeToOddPlusEven
-template <class VW>
+template <class VW, HWY_IF_NOT_FLOAT_V(VW)>
HWY_API VW RearrangeToOddPlusEven(const VW sum0, const VW sum1) {
// sum0 is the sum of bottom/even lanes and sum1 of top/odd lanes.
return Add(sum0, sum1);
@@ -6529,9 +6684,10 @@ HWY_API VFromD<DU32> SumOfMulQuadAccumulate(DU32 /*du32*/, svuint8_t a,
template <class DI32, HWY_IF_I32_D(DI32)>
HWY_API VFromD<DI32> SumOfMulQuadAccumulate(DI32 di32, svuint8_t a_u,
svint8_t b_i, svint32_t sum) {
- // TODO: use svusdot_u32 on SVE targets that require support for both SVE2
- // and SVE I8MM.
-
+#if HWY_SVE_HAVE_2 && defined(__ARM_FEATURE_MATMUL_INT8)
+ (void)di32;
+ return svusdot_s32(sum, a_u, b_i);
+#else
const RebindToUnsigned<decltype(di32)> du32;
const Repartition<uint8_t, decltype(di32)> du8;
@@ -6541,6 +6697,7 @@ HWY_API VFromD<DI32> SumOfMulQuadAccumulate(DI32 di32, svuint8_t a_u,
ShiftLeft<8>(svdot_u32(Zero(du32), a_u, ShiftRight<7>(b_u)));
return BitCast(di32, Sub(result_sum0, result_sum1));
+#endif
}
#ifdef HWY_NATIVE_I16_I16_SUMOFMULQUADACCUMULATE
diff --git a/third_party/highway/hwy/ops/emu128-inl.h b/third_party/highway/hwy/ops/emu128-inl.h
index 7d54b79f57..229d6d544e 100644
--- a/third_party/highway/hwy/ops/emu128-inl.h
+++ b/third_party/highway/hwy/ops/emu128-inl.h
@@ -316,12 +316,6 @@ HWY_API Vec128<T, N> operator^(Vec128<T, N> a, Vec128<T, N> b) {
return Xor(a, b);
}
-// ------------------------------ Xor3
-template <typename T, size_t N>
-HWY_API Vec128<T, N> Xor3(Vec128<T, N> x1, Vec128<T, N> x2, Vec128<T, N> x3) {
- return Xor(x1, Xor(x2, x3));
-}
-
// ------------------------------ Or3
template <typename T, size_t N>
HWY_API Vec128<T, N> Or3(Vec128<T, N> o1, Vec128<T, N> o2, Vec128<T, N> o3) {
@@ -2323,6 +2317,17 @@ HWY_API V InterleaveOddBlocks(D, V a, V /*b*/) {
return a;
}
+// ------------------------------ InterleaveLowerBlocks
+template <class D, class V = VFromD<D>>
+HWY_API V InterleaveLowerBlocks(D, V a, V /*b*/) {
+ return a;
+}
+// ------------------------------ InterleaveUpperBlocks
+template <class D, class V = VFromD<D>>
+HWY_API V InterleaveUpperBlocks(D, V a, V /*b*/) {
+ return a;
+}
+
// ------------------------------ TableLookupLanes
// Returned by SetTableIndices for use by TableLookupLanes.
@@ -2905,7 +2910,7 @@ HWY_API VFromD<D> ReorderWidenMulAccumulate(D d32, V16 a, V16 b,
}
// ------------------------------ RearrangeToOddPlusEven
-template <class VW>
+template <class VW, HWY_IF_NOT_FLOAT_V(VW)>
HWY_API VW RearrangeToOddPlusEven(VW sum0, VW sum1) {
return Add(sum0, sum1);
}
diff --git a/third_party/highway/hwy/ops/generic_ops-inl.h b/third_party/highway/hwy/ops/generic_ops-inl.h
index d8bc111e3c..644f03d561 100644
--- a/third_party/highway/hwy/ops/generic_ops-inl.h
+++ b/third_party/highway/hwy/ops/generic_ops-inl.h
@@ -22,6 +22,7 @@
// the generic implementation here if native ops are already defined.
#include "third_party/highway/hwy/base.h"
+#include "third_party/highway/hwy/detect_compiler_arch.h"
// Define detail::Shuffle1230 etc, but only when viewing the current header;
// normally this is included via highway.h, which includes ops/*.h.
@@ -245,6 +246,22 @@ HWY_API Mask<D> MaskFalse(D d) {
#endif // HWY_NATIVE_MASK_FALSE
+// ------------------------------ SetMask
+#if (defined(HWY_NATIVE_SET_MASK) == defined(HWY_TARGET_TOGGLE))
+#ifdef HWY_NATIVE_SET_MASK
+#undef HWY_NATIVE_SET_MASK
+#else
+#define HWY_NATIVE_SET_MASK
+#endif
+
+template <class D>
+HWY_API Mask<D> SetMask(D d, bool val) {
+ const Repartition<int32_t, decltype(d)> di32;
+ return MaskFromVec(ResizeBitCast(d, Set(di32, -static_cast<int32_t>(val))));
+}
+
+#endif // HWY_NATIVE_SET_MASK
+
// ------------------------------ IfNegativeThenElseZero
#if (defined(HWY_NATIVE_IF_NEG_THEN_ELSE_ZERO) == defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_IF_NEG_THEN_ELSE_ZERO
@@ -466,11 +483,11 @@ HWY_API V RotateLeft(V v) {
}
// ------------------------------ InterleaveWholeLower/InterleaveWholeUpper
-#if (defined(HWY_NATIVE_INTERLEAVE_WHOLE) == defined(HWY_TARGET_TOGGLE))
-#ifdef HWY_NATIVE_INTERLEAVE_WHOLE
-#undef HWY_NATIVE_INTERLEAVE_WHOLE
+#if (defined(HWY_TOGGLE_INTERLEAVE_WHOLE) == defined(HWY_TARGET_TOGGLE))
+#ifdef HWY_TOGGLE_INTERLEAVE_WHOLE
+#undef HWY_TOGGLE_INTERLEAVE_WHOLE
#else
-#define HWY_NATIVE_INTERLEAVE_WHOLE
+#define HWY_TOGGLE_INTERLEAVE_WHOLE
#endif
#if HWY_TARGET != HWY_SCALAR || HWY_IDE
@@ -497,7 +514,7 @@ HWY_API VFromD<D> InterleaveWholeUpper(D d, VFromD<D> a, VFromD<D> b) {
// is implemented in wasm_256-inl.h.
#endif // HWY_TARGET != HWY_SCALAR
-#endif // HWY_NATIVE_INTERLEAVE_WHOLE
+#endif // HWY_TOGGLE_INTERLEAVE_WHOLE
#if HWY_TARGET != HWY_SCALAR || HWY_IDE
// The InterleaveWholeLower without the optional D parameter is generic for all
@@ -519,6 +536,37 @@ HWY_API V InterleaveEven(V a, V b) {
}
#endif
+// ------------------------------ MinNumber/MaxNumber
+
+#if (defined(HWY_NATIVE_FLOAT_MIN_MAX_NUMBER) == defined(HWY_TARGET_TOGGLE))
+#ifdef HWY_NATIVE_FLOAT_MIN_MAX_NUMBER
+#undef HWY_NATIVE_FLOAT_MIN_MAX_NUMBER
+#else
+#define HWY_NATIVE_FLOAT_MIN_MAX_NUMBER
+#endif
+
+template <class V, HWY_IF_FLOAT_OR_SPECIAL_V(V)>
+HWY_API V MinNumber(V a, V b) {
+ return Min(a, b);
+}
+
+template <class V, HWY_IF_FLOAT_OR_SPECIAL_V(V)>
+HWY_API V MaxNumber(V a, V b) {
+ return Max(a, b);
+}
+
+#endif
+
+template <class V, HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V)>
+HWY_API V MinNumber(V a, V b) {
+ return Min(a, b);
+}
+
+template <class V, HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V)>
+HWY_API V MaxNumber(V a, V b) {
+ return Max(a, b);
+}
+
// ------------------------------ MinMagnitude/MaxMagnitude
#if (defined(HWY_NATIVE_FLOAT_MIN_MAX_MAGNITUDE) == defined(HWY_TARGET_TOGGLE))
@@ -644,12 +692,18 @@ HWY_API V MaskedMulOr(V no, M m, V a, V b) {
template <class V, class M>
HWY_API V MaskedDivOr(V no, M m, V a, V b) {
- return IfThenElse(m, Div(a, b), no);
+ const DFromV<V> d;
+ // Avoid division by zero for masked-out lanes.
+ const V nonzero = Set(d, TFromD<decltype(d)>{1});
+ return IfThenElse(m, Div(a, IfThenElse(m, b, nonzero)), no);
}
template <class V, class M>
HWY_API V MaskedModOr(V no, M m, V a, V b) {
- return IfThenElse(m, Mod(a, b), no);
+ const DFromV<V> d;
+ // Avoid division by zero for masked-out lanes.
+ const V nonzero = Set(d, TFromD<decltype(d)>{1});
+ return IfThenElse(m, Mod(a, IfThenElse(m, b, nonzero)), no);
}
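A scalar sketch of why the divisor substitution is needed (assuming integer lanes, where dividing by zero is undefined even when the result is discarded):

    #include <cstdint>
    // Masked-out lanes still execute the division in a SIMD register, so give
    // them a harmless divisor of 1; their quotient is then replaced by `no`.
    int32_t MaskedDivOrScalar(int32_t no, bool m, int32_t a, int32_t b) {
      const int32_t safe_b = m ? b : 1;
      return m ? (a / safe_b) : no;
    }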
template <class V, class M>
@@ -797,6 +851,54 @@ HWY_API MFromD<D> MaskedIsNaN(const M m, const V v) {
}
#endif // HWY_NATIVE_MASKED_COMP
+// ------------------------------ Xor3
+
+#if (defined(HWY_NATIVE_XOR3) == defined(HWY_TARGET_TOGGLE))
+#ifdef HWY_NATIVE_XOR3
+#undef HWY_NATIVE_XOR3
+#else
+#define HWY_NATIVE_XOR3
+#endif
+
+template <class V>
+HWY_API V Xor3(V x1, V x2, V x3) {
+ return Xor(x1, Xor(x2, x3));
+}
+
+#endif // HWY_NATIVE_XOR3
+
+// ------------------------------ XorAndNot
+
+#if (defined(HWY_NATIVE_BCAX) == defined(HWY_TARGET_TOGGLE))
+#ifdef HWY_NATIVE_BCAX
+#undef HWY_NATIVE_BCAX
+#else
+#define HWY_NATIVE_BCAX
+#endif
+
+template <class V>
+HWY_API V XorAndNot(const V x, const V a1, const V a2) {
+ return Xor(x, AndNot(a1, a2));
+}
+
+#endif // HWY_NATIVE_BCAX
+
+// ------------------------------ AndXor
+
+#if (defined(HWY_NATIVE_TERNLOG) == defined(HWY_TARGET_TOGGLE))
+#ifdef HWY_NATIVE_TERNLOG
+#undef HWY_NATIVE_TERNLOG
+#else
+#define HWY_NATIVE_TERNLOG
+#endif
+
+template <class V>
+HWY_API V AndXor(const V a, const V x1, const V x2) {
+ return And(a, Xor(x1, x2));
+}
+
+#endif // HWY_NATIVE_TERNLOG
+
// ------------------------------ IfNegativeThenNegOrUndefIfZero
#if (defined(HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG) == \
@@ -1090,6 +1192,7 @@ HWY_API VFromD<D> MaxOfLanes(D /* tag */, VFromD<D> v) {
#else
#define HWY_NATIVE_REDUCE_SUM_4_UI8
#endif
+
template <class D, HWY_IF_V_SIZE_D(D, 4), HWY_IF_UI8_D(D)>
HWY_API TFromD<D> ReduceSum(D d, VFromD<D> v) {
const Twice<RepartitionToWide<decltype(d)>> dw;
@@ -1241,12 +1344,10 @@ HWY_API VFromD<RebindToSigned<DFromV<V>>> FloorInt(V v) {
template <class V, HWY_IF_FLOAT_V(V)>
HWY_API V MulByPow2(V v, VFromD<RebindToSigned<DFromV<V>>> exp) {
const DFromV<decltype(v)> df;
- const RebindToUnsigned<decltype(df)> du;
const RebindToSigned<decltype(df)> di;
using TF = TFromD<decltype(df)>;
using TI = TFromD<decltype(di)>;
- using TU = TFromD<decltype(du)>;
using VF = VFromD<decltype(df)>;
using VI = VFromD<decltype(di)>;
@@ -1268,85 +1369,43 @@ HWY_API V MulByPow2(V v, VFromD<RebindToSigned<DFromV<V>>> exp) {
using TExpMinMax = TI;
#endif
-#if HWY_TARGET == HWY_EMU128 || HWY_TARGET == HWY_SCALAR
- using TExpSatSub = TU;
-#elif HWY_TARGET <= HWY_SSE2 || HWY_TARGET == HWY_WASM || \
- HWY_TARGET == HWY_WASM_EMU256
- using TExpSatSub = If<(sizeof(TF) == 4), uint8_t, uint16_t>;
-#elif HWY_TARGET_IS_PPC
- using TExpSatSub = If<(sizeof(TF) >= 4), uint32_t, TU>;
-#else
- using TExpSatSub = If<(sizeof(TF) == 4), uint8_t, TU>;
-#endif
-
static_assert(kExpBias <= static_cast<TI>(LimitsMax<TExpMinMax>() / 3),
"kExpBias <= LimitsMax<TExpMinMax>() / 3 must be true");
const Repartition<TExpMinMax, decltype(df)> d_exp_min_max;
- const Repartition<TExpSatSub, decltype(df)> d_sat_exp_sub;
- constexpr int kNumOfExpBits = ExponentBits<TF>();
constexpr int kNumOfMantBits = MantissaBits<TF>();
- // The sign bit of BitCastScalar<TU>(a[i]) >> kNumOfMantBits can be zeroed out
- // using SaturatedSub if kZeroOutSignUsingSatSub is true.
-
- // If kZeroOutSignUsingSatSub is true, then val_for_exp_sub will be bitcasted
- // to a vector that has a smaller lane size than TU for the SaturatedSub
- // operation below.
- constexpr bool kZeroOutSignUsingSatSub =
- ((sizeof(TExpSatSub) * 8) == static_cast<size_t>(kNumOfExpBits));
-
- // If kZeroOutSignUsingSatSub is true, then the upper
- // (sizeof(TU) - sizeof(TExpSatSub)) * 8 bits of kExpDecrBy1Bits will be all
- // ones and the lower sizeof(TExpSatSub) * 8 bits of kExpDecrBy1Bits will be
- // equal to 1.
-
- // Otherwise, if kZeroOutSignUsingSatSub is false, kExpDecrBy1Bits will be
- // equal to 1.
- constexpr TU kExpDecrBy1Bits = static_cast<TU>(
- TU{1} - (static_cast<TU>(kZeroOutSignUsingSatSub) << kNumOfExpBits));
-
- VF val_for_exp_sub = v;
- HWY_IF_CONSTEXPR(!kZeroOutSignUsingSatSub) {
- // If kZeroOutSignUsingSatSub is not true, zero out the sign bit of
- // val_for_exp_sub[i] using Abs
- val_for_exp_sub = Abs(val_for_exp_sub);
- }
-
- // min_exp1_plus_min_exp2[i] is the smallest exponent such that
- // min_exp1_plus_min_exp2[i] >= 2 - kExpBias * 2 and
- // std::ldexp(v[i], min_exp1_plus_min_exp2[i]) is a normal floating-point
- // number if v[i] is a normal number
- const VI min_exp1_plus_min_exp2 = BitCast(
- di,
- Max(BitCast(
- d_exp_min_max,
- Neg(BitCast(
- di,
- SaturatedSub(
- BitCast(d_sat_exp_sub, ShiftRight<kNumOfMantBits>(
- BitCast(du, val_for_exp_sub))),
- BitCast(d_sat_exp_sub, Set(du, kExpDecrBy1Bits)))))),
- BitCast(d_exp_min_max,
- Set(di, static_cast<TI>(2 - kExpBias - kExpBias)))));
+ const VI exp_bias = Set(di, kExpBias);
const VI clamped_exp =
- Max(Min(exp, Set(di, static_cast<TI>(kExpBias * 3))),
- Add(min_exp1_plus_min_exp2, Set(di, static_cast<TI>(1 - kExpBias))));
+ Clamp(exp, Set(di, 3 - 3 * kExpBias), Set(di, 3 * kExpBias));
- const VI exp1_plus_exp2 = BitCast(
- di, Max(Min(BitCast(d_exp_min_max,
- Sub(clamped_exp, ShiftRight<2>(clamped_exp))),
- BitCast(d_exp_min_max,
- Set(di, static_cast<TI>(kExpBias + kExpBias)))),
- BitCast(d_exp_min_max, min_exp1_plus_min_exp2)));
+ const auto min_scale_factor_exp =
+ BitCast(d_exp_min_max, Set(di, 1 - kExpBias));
+ const auto max_scale_factor_exp = BitCast(d_exp_min_max, exp_bias);
- const VI exp1 = ShiftRight<1>(exp1_plus_exp2);
- const VI exp2 = Sub(exp1_plus_exp2, exp1);
- const VI exp3 = Sub(clamped_exp, exp1_plus_exp2);
+ // If clamped_exp[i] < 0, ensure that 1 - kExpBias <= exp1[i] <= 0,
+ // 1 - kExpBias <= exp2[i] <= 0, and 1 - kExpBias <= exp3[i] <= 0 are
+ // true.
- const VI exp_bias = Set(di, kExpBias);
+ // In addition, if clamped_exp[i] < 1 - kExpBias, ensure that
+ // exp3[i] == 1 - kExpBias to ensure results are correctly rounded if the
+ // exact value of |x[i] * factor1[i] * factor2[i] * factor3[i]| is less than
+ // the smallest positive normal value.
+
+ // Otherwise, if clamped_exp[i] >= 0, ensure that 0 <= exp1[i] <= kExpBias,
+ // 0 <= exp2[i] <= kExpBias, and 0 <= exp3[i] <= kExpBias are all true.
+
+ const VI exp3 =
+ BitCast(di, Clamp(BitCast(d_exp_min_max, clamped_exp),
+ min_scale_factor_exp, max_scale_factor_exp));
+
+ const VI clamped_exp_minus_exp3 = Sub(clamped_exp, exp3);
+ const VI exp2 =
+ BitCast(di, Clamp(BitCast(d_exp_min_max, clamped_exp_minus_exp3),
+ min_scale_factor_exp, max_scale_factor_exp));
+ const VI exp1 = Sub(clamped_exp_minus_exp3, exp2);
const VF factor1 =
BitCast(df, ShiftLeft<kNumOfMantBits>(Add(exp1, exp_bias)));
@@ -1355,6 +1414,37 @@ HWY_API V MulByPow2(V v, VFromD<RebindToSigned<DFromV<V>>> exp) {
const VF factor3 =
BitCast(df, ShiftLeft<kNumOfMantBits>(Add(exp3, exp_bias)));
+ // If exp2[i] < 0, then clamped_exp[i] < 1 - kExpBias and
+ // exp3[i] == 1 - kExpBias will both be true. factor3[i] will be equal to the
+ // smallest positive normal value if exp3[i] == 1 - kExpBias.
+
+ // If exp2[i] >= 0, then exp1[i] >= 0 and factor1[i] * factor2[i] >= 1.
+
+ // If exp2[i] < 0 and the exact value of |v[i] * factor1[i] * factor2[i]| is
+ // less than the smallest positive normal value, then the exact value of
+ // |v[i] * factor1[i] * factor2[i] * factor3[i]| will be much smaller than
+ // half of the smallest positive denormal value (since factor3[i] will be
+ // equal to the smallest positive normal value in this case), resulting in a
+ // correctly rounded result in this case.
+
+ // If kExpBias >= kNumOfMantBits + 3 and exp3[i] == 1 - kExpBias are both
+ // true, then factor3[i] will be small enough such that
+ // v[i] * factor1[i] * factor2[i] * factor3[i] will be correctly rounded,
+ // even if the exact value of |v[i] * factor1[i] * factor2[i]| is smaller than
+ // the smallest positive normal value.
+
+ // kExpBias >= kNumOfMantBits + 3 is true for the F16, F32, and F64
+ // floating-point types.
+
+ // Otherwise, either exp2[i] >= 0, the exact value of
+ // |v[i] * factor1[i] * factor2[i]| is greater than or equal to the smallest
+ // positive normal value, or v[i] is NaN. In these cases,
+ // v[i] * factor1[i] * factor2[i] will either be exact or overflow to
+ // infinity (if clamped_exp[i] > 0 and v[i] is a non-zero finite value),
+ // resulting in a correctly rounded result if the exact value of
+ // |v[i] * factor1[i] * factor2[i] * factor3[i]| is less than the smallest
+ // positive normal value.
+
return Mul(Mul(Mul(v, factor1), factor2), factor3);
}
@@ -3118,8 +3208,8 @@ HWY_API VFromD<D> GatherIndexN(D d, const T* HWY_RESTRICT base,
template <class D, typename T = TFromD<D>>
HWY_API VFromD<D> GatherIndexNOr(VFromD<D> no, D d, const T* HWY_RESTRICT base,
- VFromD<RebindToSigned<D>> index,
- const size_t max_lanes_to_load) {
+ VFromD<RebindToSigned<D>> index,
+ const size_t max_lanes_to_load) {
const RebindToSigned<D> di;
using TI = TFromD<decltype(di)>;
static_assert(sizeof(T) == sizeof(TI), "Index/lane size must match");
@@ -3140,8 +3230,8 @@ HWY_API VFromD<D> GatherIndexN(D d, const T* HWY_RESTRICT base,
}
template <class D, typename T = TFromD<D>>
HWY_API VFromD<D> GatherIndexNOr(VFromD<D> no, D d, const T* HWY_RESTRICT base,
- VFromD<RebindToSigned<D>> index,
- const size_t max_lanes_to_load) {
+ VFromD<RebindToSigned<D>> index,
+ const size_t max_lanes_to_load) {
return MaskedGatherIndexOr(no, FirstN(d, max_lanes_to_load), d, base, index);
}
#endif // (defined(HWY_NATIVE_GATHER) == defined(HWY_TARGET_TOGGLE))
@@ -4487,43 +4577,46 @@ HWY_API V CLMulUpper(V a, V b) {
#define HWY_NATIVE_POPCNT
#endif
-// This overload requires vectors to be at least 16 bytes, which is the case
-// for LMUL >= 2.
-#undef HWY_IF_POPCNT
-#if HWY_TARGET == HWY_RVV
-#define HWY_IF_POPCNT(D) \
- hwy::EnableIf<D().Pow2() >= 1 && D().MaxLanes() >= 16>* = nullptr
-#else
-// Other targets only have these two overloads which are mutually exclusive, so
-// no further conditions are required.
-#define HWY_IF_POPCNT(D) void* = nullptr
-#endif // HWY_TARGET == HWY_RVV
-
-template <class V, class D = DFromV<V>, HWY_IF_U8_D(D),
- HWY_IF_V_SIZE_GT_D(D, 8), HWY_IF_POPCNT(D)>
+template <class V, class D = DFromV<V>, HWY_IF_U8_D(D)>
HWY_API V PopulationCount(V v) {
const D d;
- const V lookup =
- Dup128VecFromValues(d, 0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4);
- const auto lo = And(v, Set(d, uint8_t{0xF}));
- const auto hi = ShiftRight<4>(v);
- return Add(TableLookupBytes(lookup, hi), TableLookupBytes(lookup, lo));
-}
-// RVV has a specialization that avoids the Set().
-#if HWY_TARGET != HWY_RVV
-// Slower fallback for capped vectors.
-template <class V, class D = DFromV<V>, HWY_IF_U8_D(D),
- HWY_IF_V_SIZE_LE_D(D, 8)>
-HWY_API V PopulationCount(V v) {
- const D d;
+#if HWY_TARGET == HWY_SSE2
+ // TableLookupBytes is slow on SSE2
+
// See https://arxiv.org/pdf/1611.07612.pdf, Figure 3
const V k33 = Set(d, uint8_t{0x33});
v = Sub(v, And(ShiftRight<1>(v), Set(d, uint8_t{0x55})));
v = Add(And(ShiftRight<2>(v), k33), And(v, k33));
return And(Add(v, ShiftRight<4>(v)), Set(d, uint8_t{0x0F}));
+#else // HWY_TARGET != HWY_SSE2
+
+#if HWY_TARGET == HWY_RVV
+ // Need at least LMUL=1 on RVV to ensure that Lanes(d_tbl) is at least 16
+ const ScalableTag<uint8_t, HWY_MAX(HWY_POW2_D(D), 0)> d_tbl;
+#else
+ const FixedTag<uint8_t, HWY_MAX(HWY_MAX_LANES_D(D), 16)> d_tbl;
+#endif
+
+ const auto lookup = Dup128VecFromValues(d_tbl, 0, 1, 1, 2, 1, 2, 2, 3, 1, 2,
+ 2, 3, 2, 3, 3, 4);
+ const auto lo = And(v, Set(d, uint8_t{0xF}));
+ const auto hi = ShiftRight<4>(v);
+
+#if HWY_TARGET == HWY_RVV
+ // On RVV, use TableLookupLanes to avoid unnecessary overhead
+ const auto hi_popcnt =
+ ResizeBitCast(d, TableLookupLanes(lookup, ResizeBitCast(d_tbl, hi)));
+ const auto lo_popcnt =
+ ResizeBitCast(d, TableLookupLanes(lookup, ResizeBitCast(d_tbl, lo)));
+#else // HWY_TARGET != HWY_RVV
+ const auto hi_popcnt = TableLookupBytes(lookup, hi);
+ const auto lo_popcnt = TableLookupBytes(lookup, lo);
+#endif // HWY_TARGET == HWY_RVV
+
+ return Add(hi_popcnt, lo_popcnt);
+#endif // HWY_TARGET == HWY_SSE2
}
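Scalar sketches of both per-byte strategies (the bit-twiddling from the cited paper, and the nibble lookup table):

    #include <cstdint>
    uint8_t PopCountBits(uint8_t v) {  // SSE2 path
      v = static_cast<uint8_t>(v - ((v >> 1) & 0x55));
      v = static_cast<uint8_t>(((v >> 2) & 0x33) + (v & 0x33));
      return static_cast<uint8_t>((v + (v >> 4)) & 0x0F);
    }
    uint8_t PopCountTable(uint8_t v) {  // TableLookupBytes path
      static constexpr uint8_t kNibblePopCount[16] = {0, 1, 1, 2, 1, 2, 2, 3,
                                                      1, 2, 2, 3, 2, 3, 3, 4};
      return static_cast<uint8_t>(kNibblePopCount[v & 0xF] +
                                  kNibblePopCount[v >> 4]);
    }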
-#endif // HWY_TARGET != HWY_RVV
template <class V, class D = DFromV<V>, HWY_IF_U16_D(D)>
HWY_API V PopulationCount(V v) {
@@ -5285,18 +5378,20 @@ HWY_INLINE V IntDiv(V a, V b) {
#endif // HWY_HAVE_FLOAT64
template <size_t kOrigLaneSize, class V, HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V),
- HWY_IF_T_SIZE_ONE_OF_V(V, ((HWY_TARGET <= HWY_SSE2 ||
- HWY_TARGET == HWY_WASM ||
- HWY_TARGET == HWY_WASM_EMU256)
- ? 0
- : (1 << 1)) |
- (1 << 2) | (1 << 4) | (1 << 8))>
+ HWY_IF_T_SIZE_ONE_OF_V(
+ V, ((HWY_TARGET <= HWY_SSE2 || HWY_TARGET == HWY_WASM ||
+ HWY_TARGET == HWY_WASM_EMU256 || HWY_TARGET == HWY_LSX ||
+ HWY_TARGET == HWY_LASX)
+ ? 0
+ : (1 << 1)) |
+ (1 << 2) | (1 << 4) | (1 << 8))>
HWY_INLINE V IntMod(V a, V b) {
return hwy::HWY_NAMESPACE::NegMulAdd(IntDiv<kOrigLaneSize>(a, b), b, a);
}
-#if HWY_TARGET <= HWY_SSE2 || HWY_TARGET == HWY_WASM || \
- HWY_TARGET == HWY_WASM_EMU256
+#if HWY_TARGET <= HWY_SSE2 || HWY_TARGET == HWY_WASM || \
+ HWY_TARGET == HWY_WASM_EMU256 || HWY_TARGET == HWY_LSX || \
+ HWY_TARGET == HWY_LASX
template <size_t kOrigLaneSize, class V, HWY_IF_UI8(TFromV<V>),
HWY_IF_V_SIZE_LE_V(V, HWY_MAX_BYTES / 2)>
HWY_INLINE V IntMod(V a, V b) {
@@ -5315,7 +5410,7 @@ HWY_INLINE V IntMod(V a, V b) {
IntMod<kOrigLaneSize>(PromoteUpperTo(dw, a), PromoteUpperTo(dw, b)));
}
#endif // HWY_TARGET <= HWY_SSE2 || HWY_TARGET == HWY_WASM || HWY_TARGET ==
- // HWY_WASM_EMU256
+ // HWY_WASM_EMU256 || HWY_TARGET == HWY_LSX || HWY_TARGET == HWY_LASX
} // namespace detail
@@ -5433,17 +5528,15 @@ HWY_API V RoundingShiftRight(V v) {
template <class V, HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V)>
HWY_API V RoundingShiftRightSame(V v, int shift_amt) {
const DFromV<V> d;
- using T = TFromD<decltype(d)>;
-
- const int shift_amt_is_zero_mask = -static_cast<int>(shift_amt == 0);
+ const bool shift_amt_is_zero = (shift_amt == 0);
const auto scaled_down_v = ShiftRightSame(
v, static_cast<int>(static_cast<unsigned>(shift_amt) +
- static_cast<unsigned>(~shift_amt_is_zero_mask)));
+ static_cast<unsigned>(shift_amt_is_zero) - 1u));
return AverageRound(
scaled_down_v,
- And(scaled_down_v, Set(d, static_cast<T>(shift_amt_is_zero_mask))));
+ IfThenElseZero(SetMask(d, shift_amt_is_zero), scaled_down_v));
}
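A scalar sketch of the identity in play (for unsigned lanes, AverageRound(a, b) is (a + b + 1) >> 1 without intermediate overflow):

    #include <cstdint>
    // For s > 0: (v + (1 << (s-1))) >> s == AverageRound(v >> (s-1), 0).
    // For s == 0: the code averages v with itself, returning v unchanged.
    uint32_t RoundingShiftRightScalar(uint32_t v, int s) {
      const uint32_t scaled = v >> (s == 0 ? 0 : s - 1);
      const uint32_t other = (s == 0) ? scaled : 0;  // IfThenElseZero above
      return static_cast<uint32_t>((uint64_t{scaled} + other + 1) >> 1);
    }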
template <class V, HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V)>
@@ -5490,7 +5583,7 @@ HWY_API VFromD<DF> MulOddAdd(DF df, VBF a, VBF b, VFromD<DF> c) {
// ------------------------------ ReorderWidenMulAccumulate (MulEvenAdd)
-// AVX3_SPR/ZEN4, and NEON with bf16 but not(!) SVE override this.
+// AVX3_SPR/ZEN4, NEON with bf16, and SVE override this.
#if (defined(HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16) == \
defined(HWY_TARGET_TOGGLE))
#ifdef HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16
@@ -5510,6 +5603,13 @@ HWY_API VFromD<DF> ReorderWidenMulAccumulate(DF df, VBF a, VBF b,
return MulEvenAdd(df, a, b, sum0);
}
+template <class VW, HWY_IF_FLOAT_V(VW)>
+HWY_API VW RearrangeToOddPlusEven(const VW sum0, const VW sum1) {
+ // sum1 contains the odd lanes and sum0 the even, hence their sum is the
+ // desired pairwise sum.
+ return Add(sum0, sum1);
+}
+
#endif // HWY_NATIVE_REORDER_WIDEN_MUL_ACC_BF16
// ------------------------------ WidenMulAccumulate
@@ -5521,8 +5621,7 @@ HWY_API VFromD<DF> ReorderWidenMulAccumulate(DF df, VBF a, VBF b,
#define HWY_NATIVE_WIDEN_MUL_ACCUMULATE
#endif
-template<class D, HWY_IF_INTEGER(TFromD<D>),
- class DN = RepartitionToNarrow<D>>
+template <class D, HWY_IF_INTEGER(TFromD<D>), class DN = RepartitionToNarrow<D>>
HWY_API VFromD<D> WidenMulAccumulate(D d, VFromD<DN> mul, VFromD<DN> x,
VFromD<D> low, VFromD<D>& high) {
high = MulAdd(PromoteUpperTo(d, mul), PromoteUpperTo(d, x), high);
@@ -5901,7 +6000,7 @@ HWY_API size_t CompressBitsStore(V v, const uint8_t* HWY_RESTRICT bits, D d,
Store(v, d, lanes);
const Simd<T, HWY_MIN(MaxLanes(d), 8), 0> d8;
- T* HWY_RESTRICT pos = unaligned;
+ T* pos = unaligned;
HWY_ALIGN constexpr T table[2048] = {
0, 1, 2, 3, 4, 5, 6, 7, /**/ 0, 1, 2, 3, 4, 5, 6, 7, //
@@ -6033,15 +6132,40 @@ HWY_API size_t CompressBitsStore(V v, const uint8_t* HWY_RESTRICT bits, D d,
2, 3, 4, 5, 6, 7, 0, 1, /**/ 0, 2, 3, 4, 5, 6, 7, 1, //
1, 2, 3, 4, 5, 6, 7, 0, /**/ 0, 1, 2, 3, 4, 5, 6, 7};
- for (size_t i = 0; i < Lanes(d); i += 8) {
- // Each byte worth of bits is the index of one of 256 8-byte ranges, and its
- // population count determines how far to advance the write position.
- const size_t bits8 = bits[i / 8];
- const auto indices = Load(d8, table + bits8 * 8);
- const auto compressed = TableLookupBytes(LoadU(d8, lanes + i), indices);
- StoreU(compressed, d8, pos);
- pos += PopCount(bits8);
+ size_t i = 0;
+ HWY_LANES_CONSTEXPR size_t N = Lanes(d);
+ constexpr bool kMaybeLt128 =
+ (HWY_TARGET == HWY_SCALAR) || !detail::IsFull(D());
+ // If less than 128 bits, we may not enter the main loop below, and even
+ // the remainder loop might not write anything if bits are not set.
+ // Ensure the output is initialized. GCC seems not to understand this is only
+ // necessary if kMaybeLt128.
+ HWY_IF_CONSTEXPR(kMaybeLt128 || HWY_COMPILER_GCC_ACTUAL) {
+ StoreU(v, d, unaligned);
+ }
+ HWY_ASSUME(N >= 8 || kMaybeLt128);
+ if (N >= 8) {
+ for (; i <= N - 8; i += 8) {
+ // Each byte worth of bits is the index of one of 256 8-byte ranges, and
+ // its population count determines how far to advance the write position.
+ const size_t bits8 = bits[i / 8];
+ const auto indices = Load(d8, table + bits8 * 8);
+ const auto compressed = TableLookupBytes(LoadU(d8, lanes + i), indices);
+ StoreU(compressed, d8, pos);
+ pos += PopCount(bits8);
+ }
+ }
+  // Not required if we have full vectors of >= 128 bits, because they are
+  // multiples of 8 bytes. This inefficient loop is mainly required for
+  // safely handling compress_test.
+ HWY_IF_CONSTEXPR(kMaybeLt128) {
+ for (; i < N; ++i) {
+ if (bits[i / 8] & (1u << (i % 8))) {
+ *pos++ = lanes[i];
+ }
+ }
}
+
return static_cast<size_t>(pos - unaligned);
}
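
// Editor's sketch (hypothetical scalar model of one 8-lane step above): a
// byte of mask bits selects one row of the 256-row permutation table, and
// its population count advances the write position so that the next
// (overlapping) store lands directly after the kept lanes.
static HWY_MAYBE_UNUSED size_t CompressStepSketch(const uint8_t lanes[8],
                                                  uint8_t bits8,
                                                  const uint8_t indices[8],
                                                  uint8_t* HWY_RESTRICT pos) {
  size_t count = 0;
  for (size_t j = 0; j < 8; ++j) {
    pos[j] = lanes[indices[j]];  // permute: kept lanes come first
    count += (bits8 >> j) & 1u;  // popcount of the mask byte
  }
  return count;  // caller advances pos by this many lanes
}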
@@ -6734,12 +6858,19 @@ HWY_API Vec128<T, 1> Expand(Vec128<T, 1> v, Mask128<T, 1> mask) {
}
// ------------------------------ LoadExpand
+
+// #2957: clangd warning because x86_128-inl.h defines an overload with this
+// condition, so negate it here.
+#if !(HWY_TARGET <= HWY_AVX3 || HWY_IDE)
+
template <class D, HWY_IF_V_SIZE_LE_D(D, 16)>
HWY_API VFromD<D> LoadExpand(MFromD<D> mask, D d,
const TFromD<D>* HWY_RESTRICT unaligned) {
return Expand(LoadU(d, unaligned), mask);
}
+#endif // !(HWY_TARGET <= HWY_AVX3 || HWY_IDE)
+
#endif // HWY_NATIVE_EXPAND
// ------------------------------ TwoTablesLookupLanes
@@ -6757,6 +6888,138 @@ HWY_API VFromD<D> TwoTablesLookupLanes(D /*d*/, VFromD<D> a, VFromD<D> b,
}
#endif
+// ------------------------------ Lookup8
+
+template <class D, typename T = TFromD<D>, class VI>
+HWY_INLINE Vec<D> Lookup8(D d, const T* HWY_RESTRICT table, VI indices) {
+ // `di` describes the indices given - same bits per lane, but `d` determines
+ // the actual lane count of the result and also of the table vectors, which
+ // is relevant for adjusting the index values, see below.
+ DFromV<VI> di;
+ static_assert(sizeof(T) == sizeof(TFromD<decltype(di)>),
+ "Index/vector must have same lane size");
+ HWY_IF_CONSTEXPR(HWY_IS_DEBUG_BUILD) {
+    // Asserting Lanes(di) >= 4 is not needed since both d and di have the
+    // same number of Lanes().
+ HWY_DASSERT(Lanes(d) >= 4);
+ HWY_DASSERT(AllTrue(di, Lt(indices, Set(di, 8))));
+ }
+
+ HWY_IF_CONSTEXPR(!HWY_HAVE_SCALABLE) {
+ // Fixed-size vectors: we know they are >= 128 bit, so either one or two
+ // tables are sufficient.
+ HWY_IF_CONSTEXPR(MaxLanes(d) >= 8) {
+ const CappedTag<T, 8> d8;
+ // We want to perform one lookup per index, hence cast. This has no
+ // runtime cost; the upper lanes are unused.
+ const Vec<D> t0 = ResizeBitCast(d, Load(d8, table));
+ return TableLookupLanes(t0, IndicesFromVec(d, indices));
+ }
+ HWY_IF_CONSTEXPR(MaxLanes(d) < 8) {
+ // Exactly 4 lanes per vector, because we ensured >= 4 above.
+ const Vec<D> t0 = Load(d, table);
+ const Vec<D> t1 = Load(d, table + 4);
+ return TwoTablesLookupLanes(d, t0, t1, IndicesFromVec(d, indices));
+ }
+ }
+
+ HWY_IF_CONSTEXPR(HWY_HAVE_SCALABLE) {
+ // Scalable: first we must load two halves of the table into two vectors,
+ // regardless of vector size. We always use two-vector lookups to avoid
+ // runtime branching. Note that RVV can have U64x8 even with 128-bit
+ // vectors (LMUL=4), hence we must use the given LMUL, not FixedTag, but we
+ // still want to cap at 4 lanes to avoid overrunning the table.
+ const CappedTag<T, 4, d.Pow2()> d4;
+
+ // We want to use native lookup instructions (more efficient on SVE than two
+ // lookups plus a blend), hence cast. This has no runtime cost. No LoadU
+ // required because + 4 is still aligned relative to `d4`.
+ const Vec<D> t0 = ResizeBitCast(d, Load(d4, table));
+ const Vec<D> t1 = ResizeBitCast(d, Load(d4, table + 4));
+
+ // Now ensure indices for the second half of the table point to the second
+ // vector. Note that SVE2_128 and SVE_256 are handled by the fixed-size case
+ // above. The adjustment factor is 0 for 128-bit SIMD, which can happen with
+ // 128-bit SVE1 hardware, but we do not know that at compile time.
+ using TI = TFromD<decltype(di)>;
+ const VI adjust = Set(di, static_cast<TI>(Lanes(d) - 4));
+#if HWY_TARGET_IS_SVE
+ const Mask<decltype(di)> ge_4 = detail::GeN(indices, 4);
+#else
+ const Mask<decltype(di)> ge_4 = Ge(indices, Set(di, 4));
+#endif
+ indices = MaskedAddOr(indices, ge_4, indices, adjust);
+
+ return TwoTablesLookupLanes(d, t0, t1, IndicesFromVec(d, indices));
+ }
+}
+
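+// Editor's note (illustrative usage, names hypothetical): the per-lane
+// semantics are r[i] = table[indices[i]], with all indices in [0, 8):
+//   HWY_ALIGN int32_t table[8] = {10, 11, 12, 13, 14, 15, 16, 17};
+//   const auto idx = And(Iota(di, 0), Set(di, 7));  // clamp lanes to [0, 8)
+//   const auto r = Lookup8(d, table, idx);
+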
+// ------------------------------ Lookup16
+
+template <class D, typename T = TFromD<D>, class VI>
+HWY_INLINE Vec<D> Lookup16(D d, const T* HWY_RESTRICT table, VI indices) {
+ // `di` describes the indices given - same bits per lane, but `d` determines
+ // the actual lane count of the result and also of the table vectors, which
+ // is relevant for adjusting the index values, see below.
+ DFromV<VI> di;
+ static_assert(sizeof(T) == sizeof(TFromD<decltype(di)>),
+ "Index/vector must have same lane size");
+ HWY_IF_CONSTEXPR(HWY_IS_DEBUG_BUILD) {
+    // Asserting Lanes(di) >= 8 is not needed since both d and di have the
+    // same number of Lanes().
+ HWY_DASSERT(Lanes(d) >= 8);
+ HWY_DASSERT(AllTrue(di, Lt(indices, Set(di, 16))));
+ }
+
+ HWY_IF_CONSTEXPR(!HWY_HAVE_SCALABLE) {
+ // Fixed-size vectors: we know they are >= 128 bit, so either one or two
+ // tables are sufficient.
+ HWY_IF_CONSTEXPR(MaxLanes(d) >= 16) {
+ const CappedTag<T, 16> d16;
+ // We want to perform one lookup per index, hence cast. This has no
+ // runtime cost; the upper lanes are unused.
+ const Vec<D> t0 = ResizeBitCast(d, Load(d16, table));
+ return TableLookupLanes(t0, IndicesFromVec(d, indices));
+ }
+ HWY_IF_CONSTEXPR(MaxLanes(d) < 16) {
+ // Exactly 8 lanes per vector, because we ensured >= 8 above.
+ const Vec<D> t0 = Load(d, table);
+ const Vec<D> t1 = Load(d, table + 8);
+ return TwoTablesLookupLanes(d, t0, t1, IndicesFromVec(d, indices));
+ }
+ }
+
+ HWY_IF_CONSTEXPR(HWY_HAVE_SCALABLE) {
+ // Scalable: first we must load two halves of the table into two vectors,
+ // regardless of vector size. We always use two-vector lookups to avoid
+ // runtime branching. Note that RVV can have U32x16 even with 128-bit
+ // vectors (LMUL=4), hence we must use the given LMUL, not FixedTag, but we
+ // still want to cap at 8 lanes to avoid overrunning the table.
+ const CappedTag<T, 8, d.Pow2()> d8;
+
+ // We want to use native lookup instructions (more efficient on SVE than two
+ // lookups plus a blend), hence cast. This has no runtime cost. No LoadU
+ // required because + 8 is still aligned relative to `d8`.
+ const Vec<D> t0 = ResizeBitCast(d, Load(d8, table));
+ const Vec<D> t1 = ResizeBitCast(d, Load(d8, table + 8));
+
+ // Now ensure indices for the second half of the table point to the second
+ // vector. Note that SVE2_128 and SVE_256 are handled by the fixed-size case
+ // above. The adjustment factor is 0 for 128-bit SIMD, which can happen with
+ // 128-bit SVE1 hardware, but we do not know that at compile time.
+ using TI = TFromD<decltype(di)>;
+ const VI adjust = Set(di, static_cast<TI>(Lanes(d) - 8));
+#if HWY_TARGET_IS_SVE
+ const Mask<decltype(di)> ge_8 = detail::GeN(indices, 8);
+#else
+ const Mask<decltype(di)> ge_8 = Ge(indices, Set(di, 8));
+#endif
+ indices = MaskedAddOr(indices, ge_8, indices, adjust);
+
+ return TwoTablesLookupLanes(d, t0, t1, IndicesFromVec(d, indices));
+ }
+}
+
// ------------------------------ Reverse2, Reverse4, Reverse8 (8-bit)
#if (defined(HWY_NATIVE_REVERSE2_8) == defined(HWY_TARGET_TOGGLE)) || HWY_IDE
@@ -7398,7 +7661,8 @@ namespace detail {
// detail::BlockwiseConcatOddEven(d, v) returns the even lanes of each block of
// v followed by the odd lanes of v
-#if HWY_TARGET_IS_NEON || HWY_TARGET_IS_SVE || HWY_TARGET == HWY_RVV
+#if HWY_TARGET_IS_NEON || HWY_TARGET_IS_SVE || HWY_TARGET == HWY_RVV || \
+ HWY_TARGET == HWY_LSX || HWY_TARGET == HWY_LASX
template <class D, HWY_IF_T_SIZE_ONE_OF_D(D, (1 << 1) | (1 << 2) | (1 << 4)),
HWY_IF_V_SIZE_GT_D(D, 8)>
static HWY_INLINE HWY_MAYBE_UNUSED Vec<D> BlockwiseConcatOddEven(D d,
diff --git a/third_party/highway/hwy/ops/loongarch_lasx-inl.h b/third_party/highway/hwy/ops/loongarch_lasx-inl.h
new file mode 100644
index 0000000000..7de2ed6421
--- /dev/null
+++ b/third_party/highway/hwy/ops/loongarch_lasx-inl.h
@@ -0,0 +1,4686 @@
+// Copyright 2019 Google LLC
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// 256-bit LASX vectors and operations.
+// External include guard in highway.h - see comment there.
+
+#include "third_party/highway/hwy/ops/loongarch_lsx-inl.h"
+#include "third_party/highway/hwy/ops/shared-inl.h"
+
+#ifndef __loongarch_asx
+// If LASX is to be runtime dispatched (instead of in baseline), we need
+// to enable it *and* define __loongarch_asx or the intrinsic header will
+// fail to compile.
+//
+// For consistency, the same pattern as the lsxintrin.h handling in
+// loongarch_lsx-inl.h is used (instead of moving lasxintrin.h after
+// HWY_BEFORE_NAMESPACE).
+HWY_PUSH_ATTRIBUTES("lsx,lasx")
+#define __loongarch_asx
+#include <lasxintrin.h>
+#undef __loongarch_asx
+// Prevent "unused push_attribute" warning from Clang.
+HWY_MAYBE_UNUSED static void HWY_CONCAT(hwy_lasx_dummy, __COUNTER__)() {}
+HWY_POP_ATTRIBUTES
+#else
+#include <lasxintrin.h>
+#endif
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+namespace detail {
+
+template <typename T>
+struct Raw256 {
+ using type = __m256i;
+};
+template <>
+struct Raw256<float> {
+ using type = __m256;
+};
+template <>
+struct Raw256<double> {
+ using type = __m256d;
+};
+
+} // namespace detail
+
+template <typename T>
+class Vec256 {
+ using Raw = typename detail::Raw256<T>::type;
+
+ public:
+ using PrivateT = T; // only for DFromV
+ static constexpr size_t kPrivateN = 32 / sizeof(T); // only for DFromV
+
+  // Compound assignment. Only usable if there is a corresponding non-member
+  // binary operator overload. For example, only integer types support %.
+ HWY_INLINE Vec256& operator*=(const Vec256 other) {
+ return *this = (*this * other);
+ }
+ HWY_INLINE Vec256& operator/=(const Vec256 other) {
+ return *this = (*this / other);
+ }
+ HWY_INLINE Vec256& operator+=(const Vec256 other) {
+ return *this = (*this + other);
+ }
+ HWY_INLINE Vec256& operator-=(const Vec256 other) {
+ return *this = (*this - other);
+ }
+ HWY_INLINE Vec256& operator%=(const Vec256 other) {
+ return *this = (*this % other);
+ }
+ HWY_INLINE Vec256& operator&=(const Vec256 other) {
+ return *this = (*this & other);
+ }
+ HWY_INLINE Vec256& operator|=(const Vec256 other) {
+ return *this = (*this | other);
+ }
+ HWY_INLINE Vec256& operator^=(const Vec256 other) {
+ return *this = (*this ^ other);
+ }
+
+ Raw raw;
+};
+
+namespace detail {
+
+template <typename T>
+using RawMask256 = typename Raw256<T>::type;
+
+} // namespace detail
+
+template <typename T>
+struct Mask256 {
+ using Raw = typename detail::RawMask256<T>;
+
+ using PrivateT = T; // only for DFromM
+ static constexpr size_t kPrivateN = 32 / sizeof(T); // only for DFromM
+
+ Raw raw;
+};
+
+template <typename T>
+using Full256 = Simd<T, 32 / sizeof(T), 0>;
+
+// ------------------------------ Zero
+
+// Cannot use VFromD here because it is defined in terms of Zero.
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_NOT_FLOAT_NOR_SPECIAL_D(D)>
+HWY_API Vec256<TFromD<D>> Zero(D /* tag */) {
+ return Vec256<TFromD<D>>{__lasx_xvreplgr2vr_d(0)};
+}
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_BF16_D(D)>
+HWY_API Vec256<bfloat16_t> Zero(D /* tag */) {
+ return Vec256<bfloat16_t>{__lasx_xvreplgr2vr_d(0)};
+}
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F16_D(D)>
+HWY_API Vec256<float16_t> Zero(D /* tag */) {
+ return Vec256<float16_t>{__lasx_xvreplgr2vr_d(0)};
+}
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)>
+HWY_API Vec256<float> Zero(D /* tag */) {
+ return Vec256<float>{reinterpret_cast<__m256>(__lasx_xvreplgr2vr_d(0))};
+}
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)>
+HWY_API Vec256<double> Zero(D /* tag */) {
+ return Vec256<double>{reinterpret_cast<__m256d>(__lasx_xvreplgr2vr_d(0))};
+}
+
+// ------------------------------ BitCast
+
+namespace detail {
+
+HWY_INLINE __m256i BitCastToInteger(__m256i v) { return v; }
+HWY_INLINE __m256i BitCastToInteger(__m256 v) {
+ return reinterpret_cast<__m256i>(v);
+}
+HWY_INLINE __m256i BitCastToInteger(__m256d v) {
+ return reinterpret_cast<__m256i>(v);
+}
+
+template <typename T>
+HWY_INLINE Vec256<uint8_t> BitCastToByte(Vec256<T> v) {
+ return Vec256<uint8_t>{BitCastToInteger(v.raw)};
+}
+
+// Cannot rely on function overloading because return types differ.
+template <typename T>
+struct BitCastFromInteger256 {
+ HWY_INLINE __m256i operator()(__m256i v) { return v; }
+};
+template <>
+struct BitCastFromInteger256<float> {
+ HWY_INLINE __m256 operator()(__m256i v) {
+ return reinterpret_cast<__m256>(v);
+ }
+};
+template <>
+struct BitCastFromInteger256<double> {
+ HWY_INLINE __m256d operator()(__m256i v) {
+ return reinterpret_cast<__m256d>(v);
+ }
+};
+
+template <class D, HWY_IF_V_SIZE_D(D, 32)>
+HWY_INLINE VFromD<D> BitCastFromByte(D /* tag */, Vec256<uint8_t> v) {
+ return VFromD<D>{BitCastFromInteger256<TFromD<D>>()(v.raw)};
+}
+
+} // namespace detail
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), typename FromT>
+HWY_API VFromD<D> BitCast(D d, Vec256<FromT> v) {
+ return detail::BitCastFromByte(d, detail::BitCastToByte(v));
+}
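+
+// Editor's note (illustrative): BitCast reinterprets lane bits unchanged;
+// e.g., flipping f32 sign bits via integer ops (v is a hypothetical
+// Vec256<float>):
+//   const Full256<float> df;
+//   const Full256<uint32_t> du;
+//   const auto negated = Xor(v, BitCast(df, Set(du, 0x80000000u)));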
+
+// ------------------------------ Set
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 1)>
+HWY_API VFromD<D> Set(D /* tag */, TFromD<D> t) {
+ return VFromD<D>{__lasx_xvreplgr2vr_b(static_cast<char>(t))}; // NOLINT
+}
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI16_D(D)>
+HWY_API VFromD<D> Set(D /* tag */, TFromD<D> t) {
+ return VFromD<D>{__lasx_xvreplgr2vr_h(static_cast<short>(t))}; // NOLINT
+}
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI32_D(D)>
+HWY_API VFromD<D> Set(D /* tag */, TFromD<D> t) {
+ return VFromD<D>{__lasx_xvreplgr2vr_w(static_cast<int>(t))};
+}
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI64_D(D)>
+HWY_API VFromD<D> Set(D /* tag */, TFromD<D> t) {
+ return VFromD<D>{__lasx_xvreplgr2vr_d(static_cast<long long>(t))}; // NOLINT
+}
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)>
+HWY_API Vec256<float> Set(D /* tag */, float t) {
+ return BitCast(D(), Vec256<int32_t>{__lasx_xvldrepl_w(&t, 0)});
+}
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)>
+HWY_API Vec256<double> Set(D /* tag */, double t) {
+ return BitCast(D(), Vec256<int64_t>{__lasx_xvldrepl_d(&t, 0)});
+}
+
+// ------------------------------ ResizeBitCast
+
+// 32-byte vector to 32-byte vector
+template <class D, class FromV, HWY_IF_V_SIZE_GT_V(FromV, 16),
+ HWY_IF_V_SIZE_D(D, HWY_MAX_LANES_V(FromV) * sizeof(TFromV<FromV>))>
+HWY_API VFromD<D> ResizeBitCast(D d, FromV v) {
+ return BitCast(d, v);
+}
+
+// 32-byte vector to 16-byte vector
+template <class D, class FromV, HWY_IF_V_SIZE_GT_V(FromV, 16),
+ HWY_IF_V_SIZE_D(D, 16)>
+HWY_API VFromD<D> ResizeBitCast(D d, FromV v) {
+ const DFromV<decltype(v)> d_from;
+ const Half<decltype(d_from)> dh_from;
+ return BitCast(d, LowerHalf(dh_from, v));
+}
+
+// 32-byte vector to <= 8-byte vector
+template <class D, class FromV, HWY_IF_V_SIZE_GT_V(FromV, 16),
+ HWY_IF_V_SIZE_LE_D(D, 8)>
+HWY_API VFromD<D> ResizeBitCast(D /*d*/, FromV v) {
+ return VFromD<D>{ResizeBitCast(Full128<TFromD<D>>(), v).raw};
+}
+
+// <= 16-byte vector to 32-byte vector
+template <class D, class FromV, HWY_IF_V_SIZE_LE_V(FromV, 16),
+ HWY_IF_V_SIZE_D(D, 32)>
+HWY_API VFromD<D> ResizeBitCast(D d, FromV v) {
+ typedef uint64_t GccRawU64M128Vec __attribute__((__vector_size__(16)));
+
+ const GccRawU64M128Vec raw_v0 = reinterpret_cast<GccRawU64M128Vec>(v.raw);
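+  // (Presumably: __builtin_nondeterministic_value tells Clang the upper half
+  // may hold any value, avoiding a needless zero-fill when combining below.)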
+#if HWY_COMPILER_CLANG && HWY_HAS_BUILTIN(__builtin_nondeterministic_value)
+ const GccRawU64M128Vec raw_v1 = __builtin_nondeterministic_value(raw_v0);
+#else
+ const GccRawU64M128Vec raw_v1 = raw_v0;
+#endif
+
+ const Repartition<uint64_t, decltype(d)> du64;
+ const Half<decltype(du64)> dh_u64;
+ return BitCast(
+ d,
+ Combine(du64, VFromD<decltype(dh_u64)>{reinterpret_cast<__m128i>(raw_v1)},
+ VFromD<decltype(dh_u64)>{reinterpret_cast<__m128i>(raw_v0)}));
+}
+
+// ------------------------------ Dup128VecFromValues
+
+template <class D, HWY_IF_UI8_D(D), HWY_IF_V_SIZE_D(D, 32)>
+HWY_API VFromD<D> Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1,
+ TFromD<D> t2, TFromD<D> t3, TFromD<D> t4,
+ TFromD<D> t5, TFromD<D> t6, TFromD<D> t7,
+ TFromD<D> t8, TFromD<D> t9, TFromD<D> t10,
+ TFromD<D> t11, TFromD<D> t12,
+ TFromD<D> t13, TFromD<D> t14,
+ TFromD<D> t15) {
+ typedef int8_t GccI8RawVectType __attribute__((__vector_size__(32)));
+ GccI8RawVectType raw_i8_vec = {
+ static_cast<char>(t0), static_cast<char>(t1), static_cast<char>(t2),
+ static_cast<char>(t3), static_cast<char>(t4), static_cast<char>(t5),
+ static_cast<char>(t6), static_cast<char>(t7), static_cast<char>(t8),
+ static_cast<char>(t9), static_cast<char>(t10), static_cast<char>(t11),
+ static_cast<char>(t12), static_cast<char>(t13), static_cast<char>(t14),
+ static_cast<char>(t15), static_cast<char>(t0), static_cast<char>(t1),
+ static_cast<char>(t2), static_cast<char>(t3), static_cast<char>(t4),
+ static_cast<char>(t5), static_cast<char>(t6), static_cast<char>(t7),
+ static_cast<char>(t8), static_cast<char>(t9), static_cast<char>(t10),
+ static_cast<char>(t11), static_cast<char>(t12), static_cast<char>(t13),
+ static_cast<char>(t14), static_cast<char>(t15)};
+ return VFromD<D>{reinterpret_cast<__m256i>(raw_i8_vec)};
+}
+
+template <class D, HWY_IF_UI16_D(D), HWY_IF_V_SIZE_D(D, 32)>
+HWY_API VFromD<D> Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1,
+ TFromD<D> t2, TFromD<D> t3, TFromD<D> t4,
+ TFromD<D> t5, TFromD<D> t6,
+ TFromD<D> t7) {
+ typedef int16_t GccI16RawVectType __attribute__((__vector_size__(32)));
+ GccI16RawVectType raw_i16_vec = {
+ static_cast<int16_t>(t0), static_cast<int16_t>(t1),
+ static_cast<int16_t>(t2), static_cast<int16_t>(t3),
+ static_cast<int16_t>(t4), static_cast<int16_t>(t5),
+ static_cast<int16_t>(t6), static_cast<int16_t>(t7),
+ static_cast<int16_t>(t0), static_cast<int16_t>(t1),
+ static_cast<int16_t>(t2), static_cast<int16_t>(t3),
+ static_cast<int16_t>(t4), static_cast<int16_t>(t5),
+ static_cast<int16_t>(t6), static_cast<int16_t>(t7)};
+ return VFromD<D>{reinterpret_cast<__m256i>(raw_i16_vec)};
+}
+
+template <class D, HWY_IF_UI32_D(D), HWY_IF_V_SIZE_D(D, 32)>
+HWY_API VFromD<D> Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1,
+ TFromD<D> t2, TFromD<D> t3) {
+ typedef int32_t GccI32RawVectType __attribute__((__vector_size__(32)));
+ GccI32RawVectType raw_i32_vec = {
+ static_cast<int32_t>(t0), static_cast<int32_t>(t1),
+ static_cast<int32_t>(t2), static_cast<int32_t>(t3),
+ static_cast<int32_t>(t0), static_cast<int32_t>(t1),
+ static_cast<int32_t>(t2), static_cast<int32_t>(t3)};
+ return VFromD<D>{reinterpret_cast<__m256i>(raw_i32_vec)};
+}
+
+template <class D, HWY_IF_F32_D(D), HWY_IF_V_SIZE_D(D, 32)>
+HWY_API VFromD<D> Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1,
+ TFromD<D> t2, TFromD<D> t3) {
+ typedef float GccF32RawVectType __attribute__((__vector_size__(32)));
+ GccF32RawVectType raw_f32_vec = {t0, t1, t2, t3, t0, t1, t2, t3};
+ return Vec256<float>{reinterpret_cast<__m256>(raw_f32_vec)};
+}
+
+template <class D, HWY_IF_UI64_D(D), HWY_IF_V_SIZE_D(D, 32)>
+HWY_API VFromD<D> Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1) {
+ typedef int64_t GccI64RawVectType __attribute__((__vector_size__(32)));
+ const GccI64RawVectType raw_i64_vec = {
+ static_cast<int64_t>(t0), static_cast<int64_t>(t1),
+ static_cast<int64_t>(t0), static_cast<int64_t>(t1)};
+ return VFromD<D>{reinterpret_cast<__m256i>(raw_i64_vec)};
+}
+
+template <class D, HWY_IF_F64_D(D), HWY_IF_V_SIZE_D(D, 32)>
+HWY_API VFromD<D> Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1) {
+ typedef double GccF64RawVectType __attribute__((__vector_size__(32)));
+ const GccF64RawVectType raw_f64_vec = {t0, t1, t0, t1};
+ return VFromD<D>{reinterpret_cast<__m256d>(raw_f64_vec)};
+}
+
+// ------------------------------ And
+
+template <typename T>
+HWY_API Vec256<T> And(Vec256<T> a, Vec256<T> b) {
+ const DFromV<decltype(a)> d; // for float16_t
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(d, VFromD<decltype(du)>{__lasx_xvand_v(BitCast(du, a).raw,
+ BitCast(du, b).raw)});
+}
+
+// ------------------------------ AndNot
+
+// Returns ~not_mask & mask.
+template <typename T>
+HWY_API Vec256<T> AndNot(Vec256<T> not_mask, Vec256<T> mask) {
+ const DFromV<decltype(mask)> d; // for float16_t
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(d, VFromD<decltype(du)>{__lasx_xvandn_v(
+ BitCast(du, not_mask).raw, BitCast(du, mask).raw)});
+}
+
+// ------------------------------ Or
+
+template <typename T>
+HWY_API Vec256<T> Or(Vec256<T> a, Vec256<T> b) {
+ const DFromV<decltype(a)> d; // for float16_t
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(d, VFromD<decltype(du)>{
+ __lasx_xvor_v(BitCast(du, a).raw, BitCast(du, b).raw)});
+}
+
+// ------------------------------ Xor
+
+template <typename T>
+HWY_API Vec256<T> Xor(Vec256<T> a, Vec256<T> b) {
+ const DFromV<decltype(a)> d; // for float16_t
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(d, VFromD<decltype(du)>{__lasx_xvxor_v(BitCast(du, a).raw,
+ BitCast(du, b).raw)});
+}
+
+// ------------------------------ Not
+template <typename T>
+HWY_API Vec256<T> Not(const Vec256<T> v) {
+ const DFromV<decltype(v)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(d, VFromD<decltype(du)>{__lasx_xvnor_v(BitCast(du, v).raw,
+ BitCast(du, v).raw)});
+}
+
+// ------------------------------ Or3
+template <typename T>
+HWY_API Vec256<T> Or3(Vec256<T> o1, Vec256<T> o2, Vec256<T> o3) {
+ return Or(o1, Or(o2, o3));
+}
+
+// ------------------------------ OrAnd
+template <typename T>
+HWY_API Vec256<T> OrAnd(Vec256<T> o, Vec256<T> a1, Vec256<T> a2) {
+ return Or(o, And(a1, a2));
+}
+
+// ------------------------------ IfVecThenElse
+template <typename T>
+HWY_API Vec256<T> IfVecThenElse(Vec256<T> mask, Vec256<T> yes, Vec256<T> no) {
+ return IfThenElse(MaskFromVec(mask), yes, no);
+}
+
+// ------------------------------ Operator overloads (internal-only if float)
+
+template <typename T>
+HWY_API Vec256<T> operator&(const Vec256<T> a, const Vec256<T> b) {
+ return And(a, b);
+}
+
+template <typename T>
+HWY_API Vec256<T> operator|(const Vec256<T> a, const Vec256<T> b) {
+ return Or(a, b);
+}
+
+template <typename T>
+HWY_API Vec256<T> operator^(const Vec256<T> a, const Vec256<T> b) {
+ return Xor(a, b);
+}
+
+// ------------------------------ PopulationCount
+
+namespace detail {
+
+template <typename T>
+HWY_INLINE Vec256<T> PopulationCount(hwy::SizeTag<1> /* tag */, Vec256<T> v) {
+ return Vec256<T>{__lasx_xvpcnt_b(v.raw)};
+}
+template <typename T>
+HWY_INLINE Vec256<T> PopulationCount(hwy::SizeTag<2> /* tag */, Vec256<T> v) {
+ return Vec256<T>{__lasx_xvpcnt_h(v.raw)};
+}
+template <typename T>
+HWY_INLINE Vec256<T> PopulationCount(hwy::SizeTag<4> /* tag */, Vec256<T> v) {
+ return Vec256<T>{__lasx_xvpcnt_w(v.raw)};
+}
+template <typename T>
+HWY_INLINE Vec256<T> PopulationCount(hwy::SizeTag<8> /* tag */, Vec256<T> v) {
+ return Vec256<T>{__lasx_xvpcnt_d(v.raw)};
+}
+
+} // namespace detail
+
+template <typename T>
+HWY_API Vec256<T> PopulationCount(Vec256<T> v) {
+ return detail::PopulationCount(hwy::SizeTag<sizeof(T)>(), v);
+}
+
+// ------------------------------ Mask
+
+// Mask and Vec are the same (true = FF..FF).
+template <typename T>
+HWY_API Mask256<T> MaskFromVec(const Vec256<T> v) {
+ return Mask256<T>{v.raw};
+}
+
+template <typename T>
+HWY_API Vec256<T> VecFromMask(const Mask256<T> v) {
+ return Vec256<T>{v.raw};
+}
+
+// ------------------------------ IfThenElse
+
+// mask ? yes : no
+template <typename T>
+HWY_API Vec256<T> IfThenElse(Mask256<T> mask, Vec256<T> yes, Vec256<T> no) {
+ const DFromV<decltype(yes)> d;
+ RebindToSigned<decltype(d)> di;
+ return BitCast(d, VFromD<decltype(di)>{__lasx_xvbitsel_v(
+ BitCast(di, no).raw, BitCast(di, yes).raw,
+ RebindMask(di, mask).raw)});
+}
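+
+// Editor's note (illustrative): with all-ones lane masks, the bit-select
+// above reduces to the classic blend (yes & mask) | (no & ~mask).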
+
+// mask ? yes : 0
+template <typename T>
+HWY_API Vec256<T> IfThenElseZero(Mask256<T> mask, Vec256<T> yes) {
+ return yes & VecFromMask(mask);
+}
+
+// mask ? 0 : no
+template <typename T>
+HWY_API Vec256<T> IfThenZeroElse(Mask256<T> mask, Vec256<T> no) {
+ return AndNot(VecFromMask(mask), no);
+}
+
+template <typename T>
+HWY_API Vec256<T> ZeroIfNegative(Vec256<T> v) {
+  static_assert(IsSigned<T>(), "Only for signed/float");
+ const DFromV<decltype(v)> d;
+ const auto zero = Zero(d);
+ return IfThenElse(v < zero, zero, v);
+}
+
+// ------------------------------ Mask logical
+
+template <typename T>
+HWY_API Mask256<T> Not(const Mask256<T> m) {
+ const Full256<T> d;
+ return MaskFromVec(Not(VecFromMask(d, m)));
+}
+
+template <typename T>
+HWY_API Mask256<T> And(const Mask256<T> a, Mask256<T> b) {
+ const Full256<T> d;
+ return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b)));
+}
+
+template <typename T>
+HWY_API Mask256<T> AndNot(const Mask256<T> a, Mask256<T> b) {
+ const Full256<T> d;
+ return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b)));
+}
+
+template <typename T>
+HWY_API Mask256<T> Or(const Mask256<T> a, Mask256<T> b) {
+ const Full256<T> d;
+ return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b)));
+}
+
+template <typename T>
+HWY_API Mask256<T> Xor(const Mask256<T> a, Mask256<T> b) {
+ const Full256<T> d;
+ return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b)));
+}
+
+template <typename T>
+HWY_API Mask256<T> ExclusiveNeither(const Mask256<T> a, Mask256<T> b) {
+ const Full256<T> d;
+ return MaskFromVec(AndNot(VecFromMask(d, a), Not(VecFromMask(d, b))));
+}
+
+// ================================================== COMPARE
+
+// Comparisons fill a lane with 1-bits if the condition is true, else 0.
+
+template <class DTo, HWY_IF_V_SIZE_D(DTo, 32), typename TFrom>
+HWY_API MFromD<DTo> RebindMask(DTo d_to, Mask256<TFrom> m) {
+ static_assert(sizeof(TFrom) == sizeof(TFromD<DTo>), "Must have same size");
+ const Full256<TFrom> dfrom;
+ return MaskFromVec(BitCast(d_to, VecFromMask(dfrom, m)));
+}
+
+template <typename T>
+HWY_API Mask256<T> TestBit(const Vec256<T> v, const Vec256<T> bit) {
+ static_assert(!hwy::IsFloat<T>(), "Only integer vectors supported");
+ return (v & bit) == bit;
+}
+
+// ------------------------------ Equality
+
+template <typename T, HWY_IF_T_SIZE(T, 1)>
+HWY_API Mask256<T> operator==(Vec256<T> a, Vec256<T> b) {
+ return Mask256<T>{__lasx_xvseq_b(a.raw, b.raw)};
+}
+
+template <typename T, HWY_IF_UI16(T)>
+HWY_API Mask256<T> operator==(Vec256<T> a, Vec256<T> b) {
+ return Mask256<T>{__lasx_xvseq_h(a.raw, b.raw)};
+}
+
+template <typename T, HWY_IF_UI32(T)>
+HWY_API Mask256<T> operator==(Vec256<T> a, Vec256<T> b) {
+ return Mask256<T>{__lasx_xvseq_w(a.raw, b.raw)};
+}
+
+template <typename T, HWY_IF_UI64(T)>
+HWY_API Mask256<T> operator==(Vec256<T> a, Vec256<T> b) {
+ return Mask256<T>{__lasx_xvseq_d(a.raw, b.raw)};
+}
+
+HWY_API Mask256<float> operator==(Vec256<float> a, Vec256<float> b) {
+ const DFromV<decltype(a)> d;
+ const RebindToSigned<decltype(d)> di;
+ return RebindMask(d, MFromD<decltype(di)>{__lasx_xvfcmp_ceq_s(a.raw, b.raw)});
+}
+
+HWY_API Mask256<double> operator==(Vec256<double> a, Vec256<double> b) {
+ const DFromV<decltype(a)> d;
+ const RebindToSigned<decltype(d)> di;
+ return RebindMask(d, MFromD<decltype(di)>{__lasx_xvfcmp_ceq_d(a.raw, b.raw)});
+}
+
+// ------------------------------ Inequality
+
+template <typename T, HWY_IF_NOT_FLOAT3264(T)>
+HWY_API Mask256<T> operator!=(Vec256<T> a, Vec256<T> b) {
+ return Not(a == b);
+}
+HWY_API Mask256<float> operator!=(Vec256<float> a, Vec256<float> b) {
+ const DFromV<decltype(a)> d;
+ const RebindToSigned<decltype(d)> di;
+ return RebindMask(d, MFromD<decltype(di)>{__lasx_xvfcmp_cne_s(a.raw, b.raw)});
+}
+HWY_API Mask256<double> operator!=(Vec256<double> a, Vec256<double> b) {
+ const DFromV<decltype(a)> d;
+ const RebindToSigned<decltype(d)> di;
+ return RebindMask(d, MFromD<decltype(di)>{__lasx_xvfcmp_cne_d(a.raw, b.raw)});
+}
+
+// ------------------------------ Strict inequality
+
+namespace detail {
+
+HWY_API Mask256<int8_t> Gt(hwy::SignedTag /*tag*/, Vec256<int8_t> a,
+ Vec256<int8_t> b) {
+ return Mask256<int8_t>{__lasx_xvslt_b(b.raw, a.raw)};
+}
+HWY_API Mask256<int16_t> Gt(hwy::SignedTag /*tag*/, Vec256<int16_t> a,
+ Vec256<int16_t> b) {
+ return Mask256<int16_t>{__lasx_xvslt_h(b.raw, a.raw)};
+}
+HWY_API Mask256<int32_t> Gt(hwy::SignedTag /*tag*/, Vec256<int32_t> a,
+ Vec256<int32_t> b) {
+ return Mask256<int32_t>{__lasx_xvslt_w(b.raw, a.raw)};
+}
+HWY_API Mask256<int64_t> Gt(hwy::SignedTag /*tag*/, Vec256<int64_t> a,
+ Vec256<int64_t> b) {
+ return Mask256<int64_t>{__lasx_xvslt_d(b.raw, a.raw)};
+}
+
+HWY_API Mask256<uint8_t> Gt(hwy::UnsignedTag /*tag*/, Vec256<uint8_t> a,
+ Vec256<uint8_t> b) {
+ return Mask256<uint8_t>{__lasx_xvslt_bu(b.raw, a.raw)};
+}
+HWY_API Mask256<uint16_t> Gt(hwy::UnsignedTag /*tag*/, Vec256<uint16_t> a,
+ Vec256<uint16_t> b) {
+ return Mask256<uint16_t>{__lasx_xvslt_hu(b.raw, a.raw)};
+}
+HWY_API Mask256<uint32_t> Gt(hwy::UnsignedTag /*tag*/, Vec256<uint32_t> a,
+ Vec256<uint32_t> b) {
+ return Mask256<uint32_t>{__lasx_xvslt_wu(b.raw, a.raw)};
+}
+HWY_API Mask256<uint64_t> Gt(hwy::UnsignedTag /*tag*/, Vec256<uint64_t> a,
+ Vec256<uint64_t> b) {
+ return Mask256<uint64_t>{__lasx_xvslt_du(b.raw, a.raw)};
+}
+
+HWY_API Mask256<float> Gt(hwy::FloatTag /*tag*/, Vec256<float> a,
+ Vec256<float> b) {
+ const DFromV<decltype(a)> d;
+ const RebindToSigned<decltype(d)> di;
+ return RebindMask(d, MFromD<decltype(di)>{__lasx_xvfcmp_clt_s(b.raw, a.raw)});
+}
+HWY_API Mask256<double> Gt(hwy::FloatTag /*tag*/, Vec256<double> a,
+ Vec256<double> b) {
+ const DFromV<decltype(a)> d;
+ const RebindToSigned<decltype(d)> di;
+ return RebindMask(d, MFromD<decltype(di)>{__lasx_xvfcmp_clt_d(b.raw, a.raw)});
+}
+
+} // namespace detail
+
+template <typename T>
+HWY_API Mask256<T> operator>(Vec256<T> a, Vec256<T> b) {
+ return detail::Gt(hwy::TypeTag<T>(), a, b);
+}
+
+// ------------------------------ Weak inequality
+
+namespace detail {
+
+template <typename T>
+HWY_INLINE Mask256<T> Ge(hwy::SignedTag /*tag*/, Vec256<T> a, Vec256<T> b) {
+ return Not(b > a);
+}
+
+template <typename T>
+HWY_INLINE Mask256<T> Ge(hwy::UnsignedTag /*tag*/, Vec256<T> a, Vec256<T> b) {
+ return Not(b > a);
+}
+
+HWY_INLINE Mask256<float> Ge(hwy::FloatTag /*tag*/, Vec256<float> a,
+ Vec256<float> b) {
+ const DFromV<decltype(a)> d;
+ const RebindToSigned<decltype(d)> di;
+ return RebindMask(d, MFromD<decltype(di)>{__lasx_xvfcmp_cle_s(b.raw, a.raw)});
+}
+HWY_INLINE Mask256<double> Ge(hwy::FloatTag /*tag*/, Vec256<double> a,
+ Vec256<double> b) {
+ const DFromV<decltype(a)> d;
+ const RebindToSigned<decltype(d)> di;
+ return RebindMask(d, MFromD<decltype(di)>{__lasx_xvfcmp_cle_d(b.raw, a.raw)});
+}
+
+} // namespace detail
+
+template <typename T>
+HWY_API Mask256<T> operator>=(Vec256<T> a, Vec256<T> b) {
+ return detail::Ge(hwy::TypeTag<T>(), a, b);
+}
+
+// ------------------------------ Reversed comparisons
+
+template <typename T>
+HWY_API Mask256<T> operator<(const Vec256<T> a, const Vec256<T> b) {
+ return b > a;
+}
+
+template <typename T>
+HWY_API Mask256<T> operator<=(const Vec256<T> a, const Vec256<T> b) {
+ return b >= a;
+}
+
+// ------------------------------ Min (Gt, IfThenElse)
+
+// Unsigned
+HWY_API Vec256<uint8_t> Min(const Vec256<uint8_t> a, const Vec256<uint8_t> b) {
+ return Vec256<uint8_t>{__lasx_xvmin_bu(a.raw, b.raw)};
+}
+HWY_API Vec256<uint16_t> Min(const Vec256<uint16_t> a,
+ const Vec256<uint16_t> b) {
+ return Vec256<uint16_t>{__lasx_xvmin_hu(a.raw, b.raw)};
+}
+HWY_API Vec256<uint32_t> Min(const Vec256<uint32_t> a,
+ const Vec256<uint32_t> b) {
+ return Vec256<uint32_t>{__lasx_xvmin_wu(a.raw, b.raw)};
+}
+HWY_API Vec256<uint64_t> Min(const Vec256<uint64_t> a,
+ const Vec256<uint64_t> b) {
+ return Vec256<uint64_t>{__lasx_xvmin_du(a.raw, b.raw)};
+}
+
+// Signed
+HWY_API Vec256<int8_t> Min(const Vec256<int8_t> a, const Vec256<int8_t> b) {
+ return Vec256<int8_t>{__lasx_xvmin_b(a.raw, b.raw)};
+}
+HWY_API Vec256<int16_t> Min(const Vec256<int16_t> a, const Vec256<int16_t> b) {
+ return Vec256<int16_t>{__lasx_xvmin_h(a.raw, b.raw)};
+}
+HWY_API Vec256<int32_t> Min(const Vec256<int32_t> a, const Vec256<int32_t> b) {
+ return Vec256<int32_t>{__lasx_xvmin_w(a.raw, b.raw)};
+}
+HWY_API Vec256<int64_t> Min(const Vec256<int64_t> a, const Vec256<int64_t> b) {
+ return Vec256<int64_t>{__lasx_xvmin_d(a.raw, b.raw)};
+}
+
+// Float
+HWY_API Vec256<float> Min(const Vec256<float> a, const Vec256<float> b) {
+ return Vec256<float>{__lasx_xvfmin_s(a.raw, b.raw)};
+}
+HWY_API Vec256<double> Min(const Vec256<double> a, const Vec256<double> b) {
+ return Vec256<double>{__lasx_xvfmin_d(a.raw, b.raw)};
+}
+
+// ------------------------------ Max (Gt, IfThenElse)
+
+// Unsigned
+HWY_API Vec256<uint8_t> Max(const Vec256<uint8_t> a, const Vec256<uint8_t> b) {
+ return Vec256<uint8_t>{__lasx_xvmax_bu(a.raw, b.raw)};
+}
+HWY_API Vec256<uint16_t> Max(const Vec256<uint16_t> a,
+ const Vec256<uint16_t> b) {
+ return Vec256<uint16_t>{__lasx_xvmax_hu(a.raw, b.raw)};
+}
+HWY_API Vec256<uint32_t> Max(const Vec256<uint32_t> a,
+ const Vec256<uint32_t> b) {
+ return Vec256<uint32_t>{__lasx_xvmax_wu(a.raw, b.raw)};
+}
+HWY_API Vec256<uint64_t> Max(const Vec256<uint64_t> a,
+ const Vec256<uint64_t> b) {
+ return Vec256<uint64_t>{__lasx_xvmax_du(a.raw, b.raw)};
+}
+
+// Signed
+HWY_API Vec256<int8_t> Max(const Vec256<int8_t> a, const Vec256<int8_t> b) {
+ return Vec256<int8_t>{__lasx_xvmax_b(a.raw, b.raw)};
+}
+HWY_API Vec256<int16_t> Max(const Vec256<int16_t> a, const Vec256<int16_t> b) {
+ return Vec256<int16_t>{__lasx_xvmax_h(a.raw, b.raw)};
+}
+HWY_API Vec256<int32_t> Max(const Vec256<int32_t> a, const Vec256<int32_t> b) {
+ return Vec256<int32_t>{__lasx_xvmax_w(a.raw, b.raw)};
+}
+HWY_API Vec256<int64_t> Max(const Vec256<int64_t> a, const Vec256<int64_t> b) {
+ return Vec256<int64_t>{__lasx_xvmax_d(a.raw, b.raw)};
+}
+
+// Float
+HWY_API Vec256<float> Max(const Vec256<float> a, const Vec256<float> b) {
+ return Vec256<float>{__lasx_xvfmax_s(a.raw, b.raw)};
+}
+HWY_API Vec256<double> Max(const Vec256<double> a, const Vec256<double> b) {
+ return Vec256<double>{__lasx_xvfmax_d(a.raw, b.raw)};
+}
+
+// ------------------------------ MinMagnitude and MaxMagnitude
+
+HWY_API Vec256<float> MinMagnitude(Vec256<float> a, Vec256<float> b) {
+ return Vec256<float>{__lasx_xvfmina_s(a.raw, b.raw)};
+}
+HWY_API Vec256<double> MinMagnitude(Vec256<double> a, Vec256<double> b) {
+ return Vec256<double>{__lasx_xvfmina_d(a.raw, b.raw)};
+}
+
+HWY_API Vec256<float> MaxMagnitude(Vec256<float> a, Vec256<float> b) {
+ return Vec256<float>{__lasx_xvfmaxa_s(a.raw, b.raw)};
+}
+HWY_API Vec256<double> MaxMagnitude(Vec256<double> a, Vec256<double> b) {
+ return Vec256<double>{__lasx_xvfmaxa_d(a.raw, b.raw)};
+}
+
+// ------------------------------ Iota
+
+namespace detail {
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 1)>
+HWY_INLINE VFromD<D> Iota0(D /*d*/) {
+ typedef int8_t GccI8RawVectType __attribute__((__vector_size__(32)));
+ const GccI8RawVectType raw_i8_vec = {
+ static_cast<char>(0), static_cast<char>(1), static_cast<char>(2),
+ static_cast<char>(3), static_cast<char>(4), static_cast<char>(5),
+ static_cast<char>(6), static_cast<char>(7), static_cast<char>(8),
+ static_cast<char>(9), static_cast<char>(10), static_cast<char>(11),
+ static_cast<char>(12), static_cast<char>(13), static_cast<char>(14),
+ static_cast<char>(15), static_cast<char>(16), static_cast<char>(17),
+ static_cast<char>(18), static_cast<char>(19), static_cast<char>(20),
+ static_cast<char>(21), static_cast<char>(22), static_cast<char>(23),
+ static_cast<char>(24), static_cast<char>(25), static_cast<char>(26),
+ static_cast<char>(27), static_cast<char>(28), static_cast<char>(29),
+ static_cast<char>(30), static_cast<char>(31)};
+ return VFromD<D>{reinterpret_cast<__m256i>(raw_i8_vec)};
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI16_D(D)>
+HWY_INLINE VFromD<D> Iota0(D /*d*/) {
+ typedef int16_t GccI16RawVectType __attribute__((__vector_size__(32)));
+ const GccI16RawVectType raw_i16_vec = {
+ static_cast<int16_t>(0), static_cast<int16_t>(1),
+ static_cast<int16_t>(2), static_cast<int16_t>(3),
+ static_cast<int16_t>(4), static_cast<int16_t>(5),
+ static_cast<int16_t>(6), static_cast<int16_t>(7),
+ static_cast<int16_t>(8), static_cast<int16_t>(9),
+ static_cast<int16_t>(10), static_cast<int16_t>(11),
+ static_cast<int16_t>(12), static_cast<int16_t>(13),
+ static_cast<int16_t>(14), static_cast<int16_t>(15)};
+ return VFromD<D>{reinterpret_cast<__m256i>(raw_i16_vec)};
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI32_D(D)>
+HWY_INLINE VFromD<D> Iota0(D /*d*/) {
+ typedef int32_t GccI32RawVectType __attribute__((__vector_size__(32)));
+ const GccI32RawVectType raw_i32_vec = {
+ static_cast<int32_t>(0), static_cast<int32_t>(1), static_cast<int32_t>(2),
+ static_cast<int32_t>(3), static_cast<int32_t>(4), static_cast<int32_t>(5),
+ static_cast<int32_t>(6), static_cast<int32_t>(7)};
+ return VFromD<D>{reinterpret_cast<__m256i>(raw_i32_vec)};
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI64_D(D)>
+HWY_INLINE VFromD<D> Iota0(D /*d*/) {
+ typedef int64_t GccI64RawVectType __attribute__((__vector_size__(32)));
+ const GccI64RawVectType raw_i64_vec = {
+ static_cast<int64_t>(0), static_cast<int64_t>(1), static_cast<int64_t>(2),
+ static_cast<int64_t>(3)};
+ return VFromD<D>{reinterpret_cast<__m256i>(raw_i64_vec)};
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)>
+HWY_INLINE VFromD<D> Iota0(D /*d*/) {
+ typedef float GccF32RawVectType __attribute__((__vector_size__(32)));
+ const GccF32RawVectType raw_f32_vec = {0.0f, 1.0f, 2.0f, 3.0f,
+ 4.0f, 5.0f, 6.0f, 7.0f};
+ return VFromD<D>{reinterpret_cast<__m256>(raw_f32_vec)};
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)>
+HWY_INLINE VFromD<D> Iota0(D /*d*/) {
+ typedef double GccF64RawVectType __attribute__((__vector_size__(32)));
+ const GccF64RawVectType raw_f64_vec = {0.0, 1.0, 2.0, 3.0};
+ return VFromD<D>{reinterpret_cast<__m256d>(raw_f64_vec)};
+}
+
+} // namespace detail
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), typename T2>
+HWY_API VFromD<D> Iota(D d, const T2 first) {
+ return detail::Iota0(d) + Set(d, ConvertScalarTo<TFromD<D>>(first));
+}
+
+// ------------------------------ FirstN (Iota, Lt)
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), class M = MFromD<D>>
+HWY_API M FirstN(const D d, size_t n) {
+ constexpr size_t kN = MaxLanes(d);
+ n = HWY_MIN(n, kN);
+ const RebindToSigned<decltype(d)> di; // Signed comparisons are cheaper.
+ using TI = TFromD<decltype(di)>;
+ return RebindMask(d, detail::Iota0(di) < Set(di, static_cast<TI>(n)));
+}
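+
+// Editor's note (illustrative): FirstN yields a mask that is true for the
+// first n lanes, e.g. for zeroing the tail of a partially valid vector:
+//   const auto m = FirstN(d, 3);             // lanes 0..2 true
+//   const auto head = IfThenElseZero(m, v);  // lanes 3.. become zero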
+
+// ================================================== ARITHMETIC
+
+// ------------------------------ Addition
+
+// Unsigned
+HWY_API Vec256<uint8_t> operator+(Vec256<uint8_t> a, Vec256<uint8_t> b) {
+ return Vec256<uint8_t>{__lasx_xvadd_b(a.raw, b.raw)};
+}
+HWY_API Vec256<uint16_t> operator+(Vec256<uint16_t> a, Vec256<uint16_t> b) {
+ return Vec256<uint16_t>{__lasx_xvadd_h(a.raw, b.raw)};
+}
+HWY_API Vec256<uint32_t> operator+(Vec256<uint32_t> a, Vec256<uint32_t> b) {
+ return Vec256<uint32_t>{__lasx_xvadd_w(a.raw, b.raw)};
+}
+HWY_API Vec256<uint64_t> operator+(Vec256<uint64_t> a, Vec256<uint64_t> b) {
+ return Vec256<uint64_t>{__lasx_xvadd_d(a.raw, b.raw)};
+}
+
+// Signed
+HWY_API Vec256<int8_t> operator+(Vec256<int8_t> a, Vec256<int8_t> b) {
+ return Vec256<int8_t>{__lasx_xvadd_b(a.raw, b.raw)};
+}
+HWY_API Vec256<int16_t> operator+(Vec256<int16_t> a, Vec256<int16_t> b) {
+ return Vec256<int16_t>{__lasx_xvadd_h(a.raw, b.raw)};
+}
+HWY_API Vec256<int32_t> operator+(Vec256<int32_t> a, Vec256<int32_t> b) {
+ return Vec256<int32_t>{__lasx_xvadd_w(a.raw, b.raw)};
+}
+HWY_API Vec256<int64_t> operator+(Vec256<int64_t> a, Vec256<int64_t> b) {
+ return Vec256<int64_t>{__lasx_xvadd_d(a.raw, b.raw)};
+}
+
+HWY_API Vec256<float> operator+(Vec256<float> a, Vec256<float> b) {
+ return Vec256<float>{__lasx_xvfadd_s(a.raw, b.raw)};
+}
+HWY_API Vec256<double> operator+(Vec256<double> a, Vec256<double> b) {
+ return Vec256<double>{__lasx_xvfadd_d(a.raw, b.raw)};
+}
+
+template <typename T>
+HWY_API Vec256<T> Add(Vec256<T> a, Vec256<T> b) {
+ return a + b;
+}
+
+// ------------------------------ Subtraction
+
+// Unsigned
+HWY_API Vec256<uint8_t> operator-(Vec256<uint8_t> a, Vec256<uint8_t> b) {
+ return Vec256<uint8_t>{__lasx_xvsub_b(a.raw, b.raw)};
+}
+HWY_API Vec256<uint16_t> operator-(Vec256<uint16_t> a, Vec256<uint16_t> b) {
+ return Vec256<uint16_t>{__lasx_xvsub_h(a.raw, b.raw)};
+}
+HWY_API Vec256<uint32_t> operator-(Vec256<uint32_t> a, Vec256<uint32_t> b) {
+ return Vec256<uint32_t>{__lasx_xvsub_w(a.raw, b.raw)};
+}
+HWY_API Vec256<uint64_t> operator-(Vec256<uint64_t> a, Vec256<uint64_t> b) {
+ return Vec256<uint64_t>{__lasx_xvsub_d(a.raw, b.raw)};
+}
+
+// Signed
+HWY_API Vec256<int8_t> operator-(Vec256<int8_t> a, Vec256<int8_t> b) {
+ return Vec256<int8_t>{__lasx_xvsub_b(a.raw, b.raw)};
+}
+HWY_API Vec256<int16_t> operator-(Vec256<int16_t> a, Vec256<int16_t> b) {
+ return Vec256<int16_t>{__lasx_xvsub_h(a.raw, b.raw)};
+}
+HWY_API Vec256<int32_t> operator-(Vec256<int32_t> a, Vec256<int32_t> b) {
+ return Vec256<int32_t>{__lasx_xvsub_w(a.raw, b.raw)};
+}
+HWY_API Vec256<int64_t> operator-(Vec256<int64_t> a, Vec256<int64_t> b) {
+ return Vec256<int64_t>{__lasx_xvsub_d(a.raw, b.raw)};
+}
+
+HWY_API Vec256<float> operator-(Vec256<float> a, Vec256<float> b) {
+ return Vec256<float>{__lasx_xvfsub_s(a.raw, b.raw)};
+}
+HWY_API Vec256<double> operator-(Vec256<double> a, Vec256<double> b) {
+ return Vec256<double>{__lasx_xvfsub_d(a.raw, b.raw)};
+}
+
+// ------------------------------ SumsOf8
+HWY_API Vec256<uint64_t> SumsOf8(Vec256<uint8_t> v) {
+ v.raw = __lasx_xvhaddw_hu_bu(v.raw, v.raw);
+ v.raw = __lasx_xvhaddw_wu_hu(v.raw, v.raw);
+ return Vec256<uint64_t>{__lasx_xvhaddw_du_wu(v.raw, v.raw)};
+}
+HWY_API Vec256<int64_t> SumsOf8(Vec256<int8_t> v) {
+ v.raw = __lasx_xvhaddw_h_b(v.raw, v.raw);
+ v.raw = __lasx_xvhaddw_w_h(v.raw, v.raw);
+ return Vec256<int64_t>{__lasx_xvhaddw_d_w(v.raw, v.raw)};
+}
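+
+// Editor's sketch (hypothetical helper, not part of the library): scalar
+// model of SumsOf8. Each u64 output lane receives the sum of its 8
+// consecutive u8 input lanes; the three horizontal widening adds above halve
+// the lane count and double the lane width at every step.
+static HWY_MAYBE_UNUSED uint64_t SumsOf8Sketch(const uint8_t* HWY_RESTRICT p) {
+  uint64_t sum = 0;
+  for (size_t i = 0; i < 8; ++i) sum += p[i];
+  return sum;  // one output lane per 8 input bytes
+}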
+
+// ------------------------------ SaturatedAdd
+
+// Returns a + b clamped to the destination range.
+
+// Unsigned
+HWY_API Vec256<uint8_t> SaturatedAdd(Vec256<uint8_t> a, Vec256<uint8_t> b) {
+ return Vec256<uint8_t>{__lasx_xvsadd_bu(a.raw, b.raw)};
+}
+HWY_API Vec256<uint16_t> SaturatedAdd(Vec256<uint16_t> a, Vec256<uint16_t> b) {
+ return Vec256<uint16_t>{__lasx_xvsadd_hu(a.raw, b.raw)};
+}
+HWY_API Vec256<uint32_t> SaturatedAdd(Vec256<uint32_t> a, Vec256<uint32_t> b) {
+ return Vec256<uint32_t>{__lasx_xvsadd_wu(a.raw, b.raw)};
+}
+HWY_API Vec256<uint64_t> SaturatedAdd(Vec256<uint64_t> a, Vec256<uint64_t> b) {
+ return Vec256<uint64_t>{__lasx_xvsadd_du(a.raw, b.raw)};
+}
+
+// Signed
+HWY_API Vec256<int8_t> SaturatedAdd(Vec256<int8_t> a, Vec256<int8_t> b) {
+ return Vec256<int8_t>{__lasx_xvsadd_b(a.raw, b.raw)};
+}
+HWY_API Vec256<int16_t> SaturatedAdd(Vec256<int16_t> a, Vec256<int16_t> b) {
+ return Vec256<int16_t>{__lasx_xvsadd_h(a.raw, b.raw)};
+}
+HWY_API Vec256<int32_t> SaturatedAdd(Vec256<int32_t> a, Vec256<int32_t> b) {
+ return Vec256<int32_t>{__lasx_xvsadd_w(a.raw, b.raw)};
+}
+HWY_API Vec256<int64_t> SaturatedAdd(Vec256<int64_t> a, Vec256<int64_t> b) {
+ return Vec256<int64_t>{__lasx_xvsadd_d(a.raw, b.raw)};
+}
+
+// ------------------------------ SaturatedSub
+
+// Returns a - b clamped to the destination range.
+
+// Unsigned
+HWY_API Vec256<uint8_t> SaturatedSub(Vec256<uint8_t> a, Vec256<uint8_t> b) {
+ return Vec256<uint8_t>{__lasx_xvssub_bu(a.raw, b.raw)};
+}
+HWY_API Vec256<uint16_t> SaturatedSub(Vec256<uint16_t> a, Vec256<uint16_t> b) {
+ return Vec256<uint16_t>{__lasx_xvssub_hu(a.raw, b.raw)};
+}
+HWY_API Vec256<uint32_t> SaturatedSub(Vec256<uint32_t> a, Vec256<uint32_t> b) {
+ return Vec256<uint32_t>{__lasx_xvssub_wu(a.raw, b.raw)};
+}
+HWY_API Vec256<uint64_t> SaturatedSub(Vec256<uint64_t> a, Vec256<uint64_t> b) {
+ return Vec256<uint64_t>{__lasx_xvssub_du(a.raw, b.raw)};
+}
+
+// Signed
+HWY_API Vec256<int8_t> SaturatedSub(Vec256<int8_t> a, Vec256<int8_t> b) {
+ return Vec256<int8_t>{__lasx_xvssub_b(a.raw, b.raw)};
+}
+HWY_API Vec256<int16_t> SaturatedSub(Vec256<int16_t> a, Vec256<int16_t> b) {
+ return Vec256<int16_t>{__lasx_xvssub_h(a.raw, b.raw)};
+}
+HWY_API Vec256<int32_t> SaturatedSub(Vec256<int32_t> a, Vec256<int32_t> b) {
+ return Vec256<int32_t>{__lasx_xvssub_w(a.raw, b.raw)};
+}
+HWY_API Vec256<int64_t> SaturatedSub(Vec256<int64_t> a, Vec256<int64_t> b) {
+ return Vec256<int64_t>{__lasx_xvssub_d(a.raw, b.raw)};
+}
+
+// ------------------------------ Average
+
+// Returns (a + b + 1) / 2
+
+// Signed and unsigned
+HWY_API Vec256<int8_t> AverageRound(Vec256<int8_t> a, Vec256<int8_t> b) {
+ return Vec256<int8_t>{__lasx_xvavgr_b(a.raw, b.raw)};
+}
+HWY_API Vec256<uint8_t> AverageRound(Vec256<uint8_t> a, Vec256<uint8_t> b) {
+ return Vec256<uint8_t>{__lasx_xvavgr_bu(a.raw, b.raw)};
+}
+HWY_API Vec256<int16_t> AverageRound(Vec256<int16_t> a, Vec256<int16_t> b) {
+ return Vec256<int16_t>{__lasx_xvavgr_h(a.raw, b.raw)};
+}
+HWY_API Vec256<uint16_t> AverageRound(Vec256<uint16_t> a, Vec256<uint16_t> b) {
+ return Vec256<uint16_t>{__lasx_xvavgr_hu(a.raw, b.raw)};
+}
+HWY_API Vec256<int32_t> AverageRound(Vec256<int32_t> a, Vec256<int32_t> b) {
+ return Vec256<int32_t>{__lasx_xvavgr_w(a.raw, b.raw)};
+}
+HWY_API Vec256<uint32_t> AverageRound(Vec256<uint32_t> a, Vec256<uint32_t> b) {
+ return Vec256<uint32_t>{__lasx_xvavgr_wu(a.raw, b.raw)};
+}
+HWY_API Vec256<int64_t> AverageRound(Vec256<int64_t> a, Vec256<int64_t> b) {
+ return Vec256<int64_t>{__lasx_xvavgr_d(a.raw, b.raw)};
+}
+HWY_API Vec256<uint64_t> AverageRound(Vec256<uint64_t> a, Vec256<uint64_t> b) {
+ return Vec256<uint64_t>{__lasx_xvavgr_du(a.raw, b.raw)};
+}
+
+// ------------------------------ Abs
+
+// Returns absolute value, except that LimitsMin() maps to LimitsMax() + 1.
+HWY_API Vec256<int8_t> Abs(Vec256<int8_t> v) {
+ return Vec256<int8_t>{__lasx_xvabsd_b(v.raw, __lasx_xvreplgr2vr_b(0))};
+}
+HWY_API Vec256<int16_t> Abs(const Vec256<int16_t> v) {
+ return Vec256<int16_t>{__lasx_xvabsd_h(v.raw, __lasx_xvreplgr2vr_h(0))};
+}
+HWY_API Vec256<int32_t> Abs(const Vec256<int32_t> v) {
+ return Vec256<int32_t>{__lasx_xvabsd_w(v.raw, __lasx_xvreplgr2vr_w(0))};
+}
+HWY_API Vec256<int64_t> Abs(const Vec256<int64_t> v) {
+ return Vec256<int64_t>{__lasx_xvabsd_d(v.raw, __lasx_xvreplgr2vr_d(0))};
+}
+
+// ------------------------------ Integer AbsDiff
+HWY_API Vec256<int8_t> AbsDiff(const Vec256<int8_t> a, Vec256<int8_t> b) {
+ return Vec256<int8_t>{__lasx_xvabsd_b(a.raw, b.raw)};
+}
+HWY_API Vec256<int16_t> AbsDiff(const Vec256<int16_t> a, Vec256<int16_t> b) {
+ return Vec256<int16_t>{__lasx_xvabsd_h(a.raw, b.raw)};
+}
+HWY_API Vec256<int32_t> AbsDiff(const Vec256<int32_t> a, Vec256<int32_t> b) {
+ return Vec256<int32_t>{__lasx_xvabsd_w(a.raw, b.raw)};
+}
+HWY_API Vec256<int64_t> AbsDiff(const Vec256<int64_t> a, Vec256<int64_t> b) {
+ return Vec256<int64_t>{__lasx_xvabsd_d(a.raw, b.raw)};
+}
+
+HWY_API Vec256<uint8_t> AbsDiff(const Vec256<uint8_t> a, Vec256<uint8_t> b) {
+ return Vec256<uint8_t>{__lasx_xvabsd_bu(a.raw, b.raw)};
+}
+HWY_API Vec256<uint16_t> AbsDiff(const Vec256<uint16_t> a, Vec256<uint16_t> b) {
+ return Vec256<uint16_t>{__lasx_xvabsd_hu(a.raw, b.raw)};
+}
+HWY_API Vec256<uint32_t> AbsDiff(const Vec256<uint32_t> a, Vec256<uint32_t> b) {
+ return Vec256<uint32_t>{__lasx_xvabsd_wu(a.raw, b.raw)};
+}
+HWY_API Vec256<uint64_t> AbsDiff(const Vec256<uint64_t> a, Vec256<uint64_t> b) {
+ return Vec256<uint64_t>{__lasx_xvabsd_du(a.raw, b.raw)};
+}
+
+// ------------------------------ Integer multiplication
+
+// Unsigned
+HWY_API Vec256<uint8_t> operator*(Vec256<uint8_t> a, Vec256<uint8_t> b) {
+ return Vec256<uint8_t>{__lasx_xvmul_b(a.raw, b.raw)};
+}
+HWY_API Vec256<uint16_t> operator*(Vec256<uint16_t> a, Vec256<uint16_t> b) {
+ return Vec256<uint16_t>{__lasx_xvmul_h(a.raw, b.raw)};
+}
+HWY_API Vec256<uint32_t> operator*(Vec256<uint32_t> a, Vec256<uint32_t> b) {
+ return Vec256<uint32_t>{__lasx_xvmul_w(a.raw, b.raw)};
+}
+HWY_API Vec256<uint64_t> operator*(Vec256<uint64_t> a, Vec256<uint64_t> b) {
+ return Vec256<uint64_t>{__lasx_xvmul_d(a.raw, b.raw)};
+}
+
+// Signed
+HWY_API Vec256<int8_t> operator*(Vec256<int8_t> a, Vec256<int8_t> b) {
+ return Vec256<int8_t>{__lasx_xvmul_b(a.raw, b.raw)};
+}
+HWY_API Vec256<int16_t> operator*(Vec256<int16_t> a, Vec256<int16_t> b) {
+ return Vec256<int16_t>{__lasx_xvmul_h(a.raw, b.raw)};
+}
+HWY_API Vec256<int32_t> operator*(Vec256<int32_t> a, Vec256<int32_t> b) {
+ return Vec256<int32_t>{__lasx_xvmul_w(a.raw, b.raw)};
+}
+HWY_API Vec256<int64_t> operator*(Vec256<int64_t> a, Vec256<int64_t> b) {
+ return Vec256<int64_t>{__lasx_xvmul_d(a.raw, b.raw)};
+}
+
+HWY_API Vec256<uint8_t> MulHigh(Vec256<uint8_t> a, Vec256<uint8_t> b) {
+ return Vec256<uint8_t>{__lasx_xvmuh_bu(a.raw, b.raw)};
+}
+HWY_API Vec256<int8_t> MulHigh(Vec256<int8_t> a, Vec256<int8_t> b) {
+ return Vec256<int8_t>{__lasx_xvmuh_b(a.raw, b.raw)};
+}
+HWY_API Vec256<uint16_t> MulHigh(Vec256<uint16_t> a, Vec256<uint16_t> b) {
+ return Vec256<uint16_t>{__lasx_xvmuh_hu(a.raw, b.raw)};
+}
+HWY_API Vec256<int16_t> MulHigh(Vec256<int16_t> a, Vec256<int16_t> b) {
+ return Vec256<int16_t>{__lasx_xvmuh_h(a.raw, b.raw)};
+}
+HWY_API Vec256<uint32_t> MulHigh(Vec256<uint32_t> a, Vec256<uint32_t> b) {
+ return Vec256<uint32_t>{__lasx_xvmuh_wu(a.raw, b.raw)};
+}
+HWY_API Vec256<int32_t> MulHigh(Vec256<int32_t> a, Vec256<int32_t> b) {
+ return Vec256<int32_t>{__lasx_xvmuh_w(a.raw, b.raw)};
+}
+HWY_API Vec256<uint64_t> MulHigh(Vec256<uint64_t> a, Vec256<uint64_t> b) {
+ return Vec256<uint64_t>{__lasx_xvmuh_du(a.raw, b.raw)};
+}
+HWY_API Vec256<int64_t> MulHigh(Vec256<int64_t> a, Vec256<int64_t> b) {
+ return Vec256<int64_t>{__lasx_xvmuh_d(a.raw, b.raw)};
+}
+
+// Multiplies even lanes (0, 2, ...) and places the double-wide result into
+// the corresponding lane pair: the lower half in the even lane and the upper
+// half in its odd neighbor.
+HWY_API Vec256<int16_t> MulEven(Vec256<int8_t> a, Vec256<int8_t> b) {
+ return Vec256<int16_t>{__lasx_xvmulwev_h_b(a.raw, b.raw)};
+}
+HWY_API Vec256<uint16_t> MulEven(Vec256<uint8_t> a, Vec256<uint8_t> b) {
+ return Vec256<uint16_t>{__lasx_xvmulwev_h_bu(a.raw, b.raw)};
+}
+HWY_API Vec256<int32_t> MulEven(Vec256<int16_t> a, Vec256<int16_t> b) {
+ return Vec256<int32_t>{__lasx_xvmulwev_w_h(a.raw, b.raw)};
+}
+HWY_API Vec256<uint32_t> MulEven(Vec256<uint16_t> a, Vec256<uint16_t> b) {
+ return Vec256<uint32_t>{__lasx_xvmulwev_w_hu(a.raw, b.raw)};
+}
+HWY_API Vec256<int64_t> MulEven(Vec256<int32_t> a, Vec256<int32_t> b) {
+ return Vec256<int64_t>{__lasx_xvmulwev_d_w(a.raw, b.raw)};
+}
+HWY_API Vec256<uint64_t> MulEven(Vec256<uint32_t> a, Vec256<uint32_t> b) {
+ return Vec256<uint64_t>{__lasx_xvmulwev_d_wu(a.raw, b.raw)};
+}
+template <typename T, HWY_IF_I64(T)>
+HWY_API Vec256<T> MulEven(Vec256<T> a, Vec256<T> b) {
+ return Vec256<T>{__lasx_xvmulwev_q_d(a.raw, b.raw)};
+}
+template <typename T, HWY_IF_U64(T)>
+HWY_API Vec256<T> MulEven(Vec256<T> a, Vec256<T> b) {
+ return Vec256<T>{__lasx_xvmulwev_q_du(a.raw, b.raw)};
+}
+
+HWY_API Vec256<int16_t> MulOdd(Vec256<int8_t> a, Vec256<int8_t> b) {
+ return Vec256<int16_t>{__lasx_xvmulwod_h_b(a.raw, b.raw)};
+}
+HWY_API Vec256<uint16_t> MulOdd(Vec256<uint8_t> a, Vec256<uint8_t> b) {
+ return Vec256<uint16_t>{__lasx_xvmulwod_h_bu(a.raw, b.raw)};
+}
+HWY_API Vec256<int32_t> MulOdd(Vec256<int16_t> a, Vec256<int16_t> b) {
+ return Vec256<int32_t>{__lasx_xvmulwod_w_h(a.raw, b.raw)};
+}
+HWY_API Vec256<uint32_t> MulOdd(Vec256<uint16_t> a, Vec256<uint16_t> b) {
+ return Vec256<uint32_t>{__lasx_xvmulwod_w_hu(a.raw, b.raw)};
+}
+HWY_API Vec256<int64_t> MulOdd(Vec256<int32_t> a, Vec256<int32_t> b) {
+ return Vec256<int64_t>{__lasx_xvmulwod_d_w(a.raw, b.raw)};
+}
+HWY_API Vec256<uint64_t> MulOdd(Vec256<uint32_t> a, Vec256<uint32_t> b) {
+ return Vec256<uint64_t>{__lasx_xvmulwod_d_wu(a.raw, b.raw)};
+}
+template <typename T, HWY_IF_I64(T)>
+HWY_API Vec256<T> MulOdd(Vec256<T> a, Vec256<T> b) {
+ return Vec256<T>{__lasx_xvmulwod_q_d(a.raw, b.raw)};
+}
+template <typename T, HWY_IF_U64(T)>
+HWY_API Vec256<T> MulOdd(Vec256<T> a, Vec256<T> b) {
+ return Vec256<T>{__lasx_xvmulwod_q_du(a.raw, b.raw)};
+}
+
+template <typename T, HWY_IF_I16(T)>
+HWY_API Vec256<T> MulFixedPoint15(Vec256<T> a, Vec256<T> b) {
+ const auto i32_ev = MulEven(a, b);
+ const auto i32_od = MulOdd(a, b);
+ const auto i64_lo = InterleaveLower(i32_ev, i32_od);
+ const auto i64_hi = InterleaveUpper(Full256<int32_t>(), i32_ev, i32_od);
+ return Vec256<T>{__lasx_xvssrarni_h_w(i64_hi.raw, i64_lo.raw, 15)};
+}
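+
+// Editor's sketch (hypothetical helper, not part of the library): scalar
+// model of MulFixedPoint15, matching the widening multiplies plus the
+// saturating rounding narrow-by-15 (xvssrarni) above.
+static HWY_MAYBE_UNUSED int16_t MulFixedPoint15Sketch(int16_t a, int16_t b) {
+  const int32_t prod = static_cast<int32_t>(a) * static_cast<int32_t>(b);
+  const int32_t rounded = (prod + (1 << 14)) >> 15;  // round half up
+  // Only -32768 * -32768 exceeds the i16 range, hence one-sided saturation.
+  return static_cast<int16_t>(rounded > 32767 ? 32767 : rounded);
+}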
+
+// ------------------------------ Integer division
+
+HWY_API Vec256<int8_t> operator/(const Vec256<int8_t> a,
+ const Vec256<int8_t> b) {
+ // Use inline assembly to avoid undefined behavior if any lanes of b are zero
+ // or a[i] == LimitsMin<int8_t>() && b[i] == -1
+ __m256i raw_result;
+ __asm__("xvdiv.b %u0,%u1,%u2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :);
+ return Vec256<int8_t>{raw_result};
+}
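+
+// Editor's note: the "f" register constraints and "%u" operand modifier are
+// presumably the LASX register class and register-name printer; the division
+// and modulo variants below all follow this same pattern, differing only in
+// the instruction's element-type suffix.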
+
+HWY_API Vec256<uint8_t> operator/(const Vec256<uint8_t> a,
+ const Vec256<uint8_t> b) {
+ // Use inline assembly to avoid undefined behavior if any lanes of b are zero
+ __m256i raw_result;
+ __asm__("xvdiv.bu %u0,%u1,%u2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :);
+ return Vec256<uint8_t>{raw_result};
+}
+
+HWY_API Vec256<int16_t> operator/(const Vec256<int16_t> a,
+ const Vec256<int16_t> b) {
+ // Use inline assembly to avoid undefined behavior if any lanes of b are zero
+ // or a[i] == LimitsMin<int16_t>() && b[i] == -1
+ __m256i raw_result;
+ __asm__("xvdiv.h %u0,%u1,%u2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :);
+ return Vec256<int16_t>{raw_result};
+}
+
+HWY_API Vec256<uint16_t> operator/(const Vec256<uint16_t> a,
+ const Vec256<uint16_t> b) {
+ // Use inline assembly to avoid undefined behavior if any lanes of b are zero
+ __m256i raw_result;
+ __asm__("xvdiv.hu %u0,%u1,%u2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :);
+ return Vec256<uint16_t>{raw_result};
+}
+
+HWY_API Vec256<int32_t> operator/(const Vec256<int32_t> a,
+ const Vec256<int32_t> b) {
+ // Use inline assembly to avoid undefined behavior if any lanes of b are zero
+ // or a[i] == LimitsMin<int32_t>() && b[i] == -1
+ __m256i raw_result;
+ __asm__("xvdiv.w %u0,%u1,%u2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :);
+ return Vec256<int32_t>{raw_result};
+}
+
+HWY_API Vec256<uint32_t> operator/(const Vec256<uint32_t> a,
+ const Vec256<uint32_t> b) {
+ // Use inline assembly to avoid undefined behavior if any lanes of b are zero
+ __m256i raw_result;
+ __asm__("xvdiv.wu %u0,%u1,%u2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :);
+ return Vec256<uint32_t>{raw_result};
+}
+
+HWY_API Vec256<int64_t> operator/(const Vec256<int64_t> a,
+ const Vec256<int64_t> b) {
+ // Use inline assembly to avoid undefined behavior if any lanes of b are zero
+ // or a[i] == LimitsMin<int64_t>() && b[i] == -1
+ __m256i raw_result;
+ __asm__("xvdiv.d %u0,%u1,%u2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :);
+ return Vec256<int64_t>{raw_result};
+}
+
+HWY_API Vec256<uint64_t> operator/(const Vec256<uint64_t> a,
+ const Vec256<uint64_t> b) {
+ // Use inline assembly to avoid undefined behavior if any lanes of b are zero
+ __m256i raw_result;
+ __asm__("xvdiv.du %u0,%u1,%u2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :);
+ return Vec256<uint64_t>{raw_result};
+}
+
+// ------------------------------ Integer modulo
+
+HWY_API Vec256<int8_t> operator%(const Vec256<int8_t> a,
+ const Vec256<int8_t> b) {
+ // Use inline assembly to avoid undefined behavior if any lanes of b are zero
+ // or a[i] == LimitsMin<int8_t>() && b[i] == -1
+ __m256i raw_result;
+ __asm__("xvmod.b %u0,%u1,%u2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :);
+ return Vec256<int8_t>{raw_result};
+}
+
+HWY_API Vec256<uint8_t> operator%(const Vec256<uint8_t> a,
+ const Vec256<uint8_t> b) {
+ // Use inline assembly to avoid undefined behavior if any lanes of b are zero
+ __m256i raw_result;
+ __asm__("xvmod.bu %u0,%u1,%u2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :);
+ return Vec256<uint8_t>{raw_result};
+}
+
+HWY_API Vec256<int16_t> operator%(const Vec256<int16_t> a,
+ const Vec256<int16_t> b) {
+ // Use inline assembly to avoid undefined behavior if any lanes of b are zero
+ // or a[i] == LimitsMin<int16_t>() && b[i] == -1
+ __m256i raw_result;
+ __asm__("xvmod.h %u0,%u1,%u2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :);
+ return Vec256<int16_t>{raw_result};
+}
+
+HWY_API Vec256<uint16_t> operator%(const Vec256<uint16_t> a,
+ const Vec256<uint16_t> b) {
+ // Use inline assembly to avoid undefined behavior if any lanes of b are zero
+ __m256i raw_result;
+ __asm__("xvmod.hu %u0,%u1,%u2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :);
+ return Vec256<uint16_t>{raw_result};
+}
+
+HWY_API Vec256<int32_t> operator%(const Vec256<int32_t> a,
+ const Vec256<int32_t> b) {
+ // Use inline assembly to avoid undefined behavior if any lanes of b are zero
+ // or a[i] == LimitsMin<int32_t>() && b[i] == -1
+ __m256i raw_result;
+ __asm__("xvmod.w %u0,%u1,%u2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :);
+ return Vec256<int32_t>{raw_result};
+}
+
+HWY_API Vec256<uint32_t> operator%(const Vec256<uint32_t> a,
+ const Vec256<uint32_t> b) {
+ // Use inline assembly to avoid undefined behavior if any lanes of b are zero
+ __m256i raw_result;
+ __asm__("xvmod.wu %u0,%u1,%u2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :);
+ return Vec256<uint32_t>{raw_result};
+}
+
+HWY_API Vec256<int64_t> operator%(const Vec256<int64_t> a,
+ const Vec256<int64_t> b) {
+ // Use inline assembly to avoid undefined behavior if any lanes of b are zero
+ // or a[i] == LimitsMin<int64_t>() && b[i] == -1
+ __m256i raw_result;
+ __asm__("xvmod.d %u0,%u1,%u2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :);
+ return Vec256<int64_t>{raw_result};
+}
+
+HWY_API Vec256<uint64_t> operator%(const Vec256<uint64_t> a,
+ const Vec256<uint64_t> b) {
+ // Use inline assembly to avoid undefined behavior if any lanes of b are zero
+ __m256i raw_result;
+ __asm__("xvmod.du %u0,%u1,%u2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :);
+ return Vec256<uint64_t>{raw_result};
+}
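+
+// Example (illustrative): Set(d, 100) % Set(d, 7) yields 2 in every lane;
+// as with operator/, the inline assembly keeps b[i] == 0 out of C++-level
+// undefined behavior.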
+
+// ------------------------------ ShiftLeft (Compile-time constant shifts)
+
+template <int kBits, typename T, HWY_IF_UI8(T)>
+HWY_API Vec256<T> ShiftLeft(Vec256<T> v) {
+ return Vec256<T>{__lasx_xvslli_b(v.raw, kBits)};
+}
+
+template <int kBits, typename T, HWY_IF_UI16(T)>
+HWY_API Vec256<T> ShiftLeft(Vec256<T> v) {
+ return Vec256<T>{__lasx_xvslli_h(v.raw, kBits)};
+}
+
+template <int kBits, typename T, HWY_IF_UI32(T)>
+HWY_API Vec256<T> ShiftLeft(Vec256<T> v) {
+ return Vec256<T>{__lasx_xvslli_w(v.raw, kBits)};
+}
+
+template <int kBits, typename T, HWY_IF_UI64(T)>
+HWY_API Vec256<T> ShiftLeft(Vec256<T> v) {
+ return Vec256<T>{__lasx_xvslli_d(v.raw, kBits)};
+}
+
+// ------------------------------ ShiftRight (Compile-time constant shifts)
+
+template <int kBits>
+HWY_API Vec256<uint8_t> ShiftRight(Vec256<uint8_t> v) {
+ return Vec256<uint8_t>{__lasx_xvsrli_b(v.raw, kBits)};
+}
+
+template <int kBits>
+HWY_API Vec256<uint16_t> ShiftRight(Vec256<uint16_t> v) {
+ return Vec256<uint16_t>{__lasx_xvsrli_h(v.raw, kBits)};
+}
+
+template <int kBits>
+HWY_API Vec256<uint32_t> ShiftRight(Vec256<uint32_t> v) {
+ return Vec256<uint32_t>{__lasx_xvsrli_w(v.raw, kBits)};
+}
+
+template <int kBits>
+HWY_API Vec256<uint64_t> ShiftRight(Vec256<uint64_t> v) {
+ return Vec256<uint64_t>{__lasx_xvsrli_d(v.raw, kBits)};
+}
+
+template <int kBits>
+HWY_API Vec256<int8_t> ShiftRight(Vec256<int8_t> v) {
+ return Vec256<int8_t>{__lasx_xvsrai_b(v.raw, kBits)};
+}
+
+template <int kBits>
+HWY_API Vec256<int16_t> ShiftRight(Vec256<int16_t> v) {
+ return Vec256<int16_t>{__lasx_xvsrai_h(v.raw, kBits)};
+}
+
+template <int kBits>
+HWY_API Vec256<int32_t> ShiftRight(Vec256<int32_t> v) {
+ return Vec256<int32_t>{__lasx_xvsrai_w(v.raw, kBits)};
+}
+
+template <int kBits>
+HWY_API Vec256<int64_t> ShiftRight(Vec256<int64_t> v) {
+ return Vec256<int64_t>{__lasx_xvsrai_d(v.raw, kBits)};
+}
+
+// ------------------------------ RoundingShiftRight
+
+template <int kBits>
+HWY_API Vec256<int8_t> RoundingShiftRight(Vec256<int8_t> v) {
+ return Vec256<int8_t>{__lasx_xvsrari_b(v.raw, kBits)};
+}
+template <int kBits>
+HWY_API Vec256<int16_t> RoundingShiftRight(Vec256<int16_t> v) {
+ return Vec256<int16_t>{__lasx_xvsrari_h(v.raw, kBits)};
+}
+template <int kBits>
+HWY_API Vec256<int32_t> RoundingShiftRight(Vec256<int32_t> v) {
+ return Vec256<int32_t>{__lasx_xvsrari_w(v.raw, kBits)};
+}
+template <int kBits>
+HWY_API Vec256<int64_t> RoundingShiftRight(Vec256<int64_t> v) {
+ return Vec256<int64_t>{__lasx_xvsrari_d(v.raw, kBits)};
+}
+
+template <int kBits>
+HWY_API Vec256<uint8_t> RoundingShiftRight(Vec256<uint8_t> v) {
+ return Vec256<uint8_t>{__lasx_xvsrlri_b(v.raw, kBits)};
+}
+template <int kBits>
+HWY_API Vec256<uint16_t> RoundingShiftRight(Vec256<uint16_t> v) {
+ return Vec256<uint16_t>{__lasx_xvsrlri_h(v.raw, kBits)};
+}
+template <int kBits>
+HWY_API Vec256<uint32_t> RoundingShiftRight(Vec256<uint32_t> v) {
+ return Vec256<uint32_t>{__lasx_xvsrlri_w(v.raw, kBits)};
+}
+template <int kBits>
+HWY_API Vec256<uint64_t> RoundingShiftRight(Vec256<uint64_t> v) {
+ return Vec256<uint64_t>{__lasx_xvsrlri_d(v.raw, kBits)};
+}
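+
+// Worked example (assuming xvsrari/xvsrlri round by adding 1 << (kBits - 1)
+// before shifting): for int16_t lanes, RoundingShiftRight<2> maps 7 to 2
+// (7/4 = 1.75 rounds to 2), whereas the truncating ShiftRight<2> maps 7 to 1.
+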
+// ------------------------------ RoundingShr
+
+HWY_API Vec256<uint8_t> RoundingShr(Vec256<uint8_t> v, Vec256<uint8_t> bits) {
+ return Vec256<uint8_t>{__lasx_xvsrlr_b(v.raw, bits.raw)};
+}
+HWY_API Vec256<uint16_t> RoundingShr(Vec256<uint16_t> v,
+ Vec256<uint16_t> bits) {
+ return Vec256<uint16_t>{__lasx_xvsrlr_h(v.raw, bits.raw)};
+}
+HWY_API Vec256<uint32_t> RoundingShr(Vec256<uint32_t> v,
+ Vec256<uint32_t> bits) {
+ return Vec256<uint32_t>{__lasx_xvsrlr_w(v.raw, bits.raw)};
+}
+HWY_API Vec256<uint64_t> RoundingShr(Vec256<uint64_t> v,
+ Vec256<uint64_t> bits) {
+ return Vec256<uint64_t>{__lasx_xvsrlr_d(v.raw, bits.raw)};
+}
+
+HWY_API Vec256<int8_t> RoundingShr(Vec256<int8_t> v, Vec256<int8_t> bits) {
+ return Vec256<int8_t>{__lasx_xvsrar_b(v.raw, bits.raw)};
+}
+HWY_API Vec256<int16_t> RoundingShr(Vec256<int16_t> v, Vec256<int16_t> bits) {
+ return Vec256<int16_t>{__lasx_xvsrar_h(v.raw, bits.raw)};
+}
+HWY_API Vec256<int32_t> RoundingShr(Vec256<int32_t> v, Vec256<int32_t> bits) {
+ return Vec256<int32_t>{__lasx_xvsrar_w(v.raw, bits.raw)};
+}
+HWY_API Vec256<int64_t> RoundingShr(Vec256<int64_t> v, Vec256<int64_t> bits) {
+ return Vec256<int64_t>{__lasx_xvsrar_d(v.raw, bits.raw)};
+}
+
+// ------------------------------ RoundingShiftRightSame (RoundingShr)
+
+template <typename T>
+HWY_API Vec256<T> RoundingShiftRightSame(const Vec256<T> v, int bits) {
+ return RoundingShr(v, Set(DFromV<decltype(v)>(), static_cast<T>(bits)));
+}
+
+// ------------------------------ RotateRight (Compile-time constant shifts)
+
+template <int kBits, typename T, HWY_IF_UI8(T)>
+HWY_API Vec256<T> RotateRight(const Vec256<T> v) {
+ static_assert(0 <= kBits && kBits < 8, "Invalid shift count");
+ if (kBits == 0) return v;
+ return Vec256<T>{__lasx_xvrotri_b(v.raw, kBits)};
+}
+
+template <int kBits, typename T, HWY_IF_UI16(T)>
+HWY_API Vec256<T> RotateRight(const Vec256<T> v) {
+ static_assert(0 <= kBits && kBits < 16, "Invalid shift count");
+ if (kBits == 0) return v;
+ return Vec256<T>{__lasx_xvrotri_h(v.raw, kBits)};
+}
+
+template <int kBits, typename T, HWY_IF_UI32(T)>
+HWY_API Vec256<T> RotateRight(const Vec256<T> v) {
+ static_assert(0 <= kBits && kBits < 32, "Invalid shift count");
+ if (kBits == 0) return v;
+ return Vec256<T>{__lasx_xvrotri_w(v.raw, kBits)};
+}
+
+template <int kBits, typename T, HWY_IF_UI64(T)>
+HWY_API Vec256<T> RotateRight(const Vec256<T> v) {
+ static_assert(0 <= kBits && kBits < 64, "Invalid shift count");
+ if (kBits == 0) return v;
+ return Vec256<T>{__lasx_xvrotri_d(v.raw, kBits)};
+}
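+
+// Example (illustrative): for uint8_t lanes, RotateRight<1> maps 0x01 to
+// 0x80 and 0x80 to 0x40; bits shifted out on the right re-enter on the left.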
+
+// ------------------------------ Rol/Ror
+template <class T, HWY_IF_UI8(T)>
+HWY_API Vec256<T> Ror(Vec256<T> a, Vec256<T> b) {
+ return Vec256<T>{__lasx_xvrotr_b(a.raw, b.raw)};
+}
+
+template <class T, HWY_IF_UI16(T)>
+HWY_API Vec256<T> Ror(Vec256<T> a, Vec256<T> b) {
+ return Vec256<T>{__lasx_xvrotr_h(a.raw, b.raw)};
+}
+
+template <class T, HWY_IF_UI32(T)>
+HWY_API Vec256<T> Ror(Vec256<T> a, Vec256<T> b) {
+ return Vec256<T>{__lasx_xvrotr_w(a.raw, b.raw)};
+}
+
+template <class T, HWY_IF_UI64(T)>
+HWY_API Vec256<T> Ror(Vec256<T> a, Vec256<T> b) {
+ return Vec256<T>{__lasx_xvrotr_d(a.raw, b.raw)};
+}
+
+// ------------------------------ BroadcastSignBit (ShiftRight, compare, mask)
+
+HWY_API Vec256<int8_t> BroadcastSignBit(const Vec256<int8_t> v) {
+ return Vec256<int8_t>{__lasx_xvsrai_b(v.raw, 7)};
+}
+
+HWY_API Vec256<int16_t> BroadcastSignBit(const Vec256<int16_t> v) {
+ return Vec256<int16_t>{__lasx_xvsrai_h(v.raw, 15)};
+}
+
+HWY_API Vec256<int32_t> BroadcastSignBit(const Vec256<int32_t> v) {
+ return Vec256<int32_t>{__lasx_xvsrai_w(v.raw, 31)};
+}
+
+HWY_API Vec256<int64_t> BroadcastSignBit(const Vec256<int64_t> v) {
+ return Vec256<int64_t>{__lasx_xvsrai_d(v.raw, 63)};
+}
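+
+// Example (illustrative): int8_t lane -5 (0xFB) becomes -1 (0xFF) and 5
+// becomes 0; the arithmetic shift by (lane bits - 1) copies the sign bit
+// into every bit of the lane.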
+
+// ------------------------------ IfNegativeThenElse (BroadcastSignBit)
+template <typename T>
+HWY_API Vec256<T> IfNegativeThenElse(Vec256<T> v, Vec256<T> yes, Vec256<T> no) {
+ static_assert(IsSigned<T>(), "Only works for signed/float");
+ const DFromV<decltype(v)> d;
+ const RebindToSigned<decltype(d)> di;
+ const auto mask = MaskFromVec(BitCast(d, BroadcastSignBit(BitCast(di, v))));
+ return IfThenElse(mask, yes, no);
+}
+
+// ------------------------------ IfNegativeThenNegOrUndefIfZero
+
+HWY_API Vec256<int8_t> IfNegativeThenNegOrUndefIfZero(Vec256<int8_t> mask,
+ Vec256<int8_t> v) {
+ return Vec256<int8_t>{__lasx_xvsigncov_b(mask.raw, v.raw)};
+}
+
+HWY_API Vec256<int16_t> IfNegativeThenNegOrUndefIfZero(Vec256<int16_t> mask,
+ Vec256<int16_t> v) {
+ return Vec256<int16_t>{__lasx_xvsigncov_h(mask.raw, v.raw)};
+}
+
+HWY_API Vec256<int32_t> IfNegativeThenNegOrUndefIfZero(Vec256<int32_t> mask,
+ Vec256<int32_t> v) {
+ return Vec256<int32_t>{__lasx_xvsigncov_w(mask.raw, v.raw)};
+}
+
+HWY_API Vec256<int64_t> IfNegativeThenNegOrUndefIfZero(Vec256<int64_t> mask,
+ Vec256<int64_t> v) {
+ return Vec256<int64_t>{__lasx_xvsigncov_d(mask.raw, v.raw)};
+}
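+
+// Note: xvsigncov yields -v[i] where mask[i] < 0, v[i] where mask[i] > 0 and
+// 0 where mask[i] == 0; the zero case is acceptable because this op leaves
+// the result for zero mask lanes unspecified.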
+
+// ------------------------------ ShiftLeftSame
+
+template <typename T>
+HWY_API Vec256<T> ShiftLeftSame(const Vec256<T> v, const int bits) {
+ return Shl(v, Set(DFromV<decltype(v)>(), static_cast<T>(bits)));
+}
+
+// ------------------------------ ShiftRightSame (BroadcastSignBit)
+
+HWY_API Vec256<uint8_t> ShiftRightSame(const Vec256<uint8_t> v,
+ const int bits) {
+ return Vec256<uint8_t>{__lasx_xvsrl_b(v.raw, __lasx_xvreplgr2vr_b(bits))};
+}
+
+HWY_API Vec256<uint16_t> ShiftRightSame(const Vec256<uint16_t> v,
+ const int bits) {
+ return Vec256<uint16_t>{__lasx_xvsrl_h(v.raw, __lasx_xvreplgr2vr_h(bits))};
+}
+
+HWY_API Vec256<uint32_t> ShiftRightSame(const Vec256<uint32_t> v,
+ const int bits) {
+ return Vec256<uint32_t>{__lasx_xvsrl_w(v.raw, __lasx_xvreplgr2vr_w(bits))};
+}
+
+HWY_API Vec256<uint64_t> ShiftRightSame(const Vec256<uint64_t> v,
+ const int bits) {
+ return Vec256<uint64_t>{__lasx_xvsrl_d(v.raw, __lasx_xvreplgr2vr_d(bits))};
+}
+
+HWY_API Vec256<int8_t> ShiftRightSame(const Vec256<int8_t> v, const int bits) {
+ return Vec256<int8_t>{__lasx_xvsra_b(v.raw, __lasx_xvreplgr2vr_b(bits))};
+}
+
+HWY_API Vec256<int16_t> ShiftRightSame(const Vec256<int16_t> v,
+ const int bits) {
+ return Vec256<int16_t>{__lasx_xvsra_h(v.raw, __lasx_xvreplgr2vr_h(bits))};
+}
+
+HWY_API Vec256<int32_t> ShiftRightSame(const Vec256<int32_t> v,
+ const int bits) {
+ return Vec256<int32_t>{__lasx_xvsra_w(v.raw, __lasx_xvreplgr2vr_w(bits))};
+}
+
+HWY_API Vec256<int64_t> ShiftRightSame(const Vec256<int64_t> v,
+ const int bits) {
+ return Vec256<int64_t>{__lasx_xvsra_d(v.raw, __lasx_xvreplgr2vr_d(bits))};
+}
+
+// ------------------------------ Neg (Xor, Sub)
+
+namespace detail {
+
+template <typename T>
+HWY_INLINE Vec256<T> Neg(hwy::FloatTag /*tag*/, const Vec256<T> v) {
+ const DFromV<decltype(v)> d;
+ return Xor(v, SignBit(d));
+}
+
+template <typename T>
+HWY_INLINE Vec256<T> Neg(hwy::SpecialTag /*tag*/, const Vec256<T> v) {
+ const DFromV<decltype(v)> d;
+ return Xor(v, SignBit(d));
+}
+
+// Not floating-point
+template <typename T, HWY_IF_UI8(T)>
+HWY_INLINE Vec256<T> Neg(hwy::SignedTag /*tag*/, const Vec256<T> v) {
+ return Vec256<T>{__lasx_xvneg_b(v.raw)};
+}
+
+template <typename T, HWY_IF_UI16(T)>
+HWY_INLINE Vec256<T> Neg(hwy::SignedTag /*tag*/, const Vec256<T> v) {
+ return Vec256<T>{__lasx_xvneg_h(v.raw)};
+}
+
+template <typename T, HWY_IF_UI32(T)>
+HWY_INLINE Vec256<T> Neg(hwy::SignedTag /*tag*/, const Vec256<T> v) {
+ return Vec256<T>{__lasx_xvneg_w(v.raw)};
+}
+
+template <typename T, HWY_IF_UI64(T)>
+HWY_INLINE Vec256<T> Neg(hwy::SignedTag /*tag*/, const Vec256<T> v) {
+ return Vec256<T>{__lasx_xvneg_d(v.raw)};
+}
+
+} // namespace detail
+
+template <typename T>
+HWY_API Vec256<T> Neg(const Vec256<T> v) {
+ return detail::Neg(hwy::TypeTag<T>(), v);
+}
+
+// ------------------------------ Floating-point mul / div
+
+HWY_API Vec256<float> operator*(Vec256<float> a, Vec256<float> b) {
+ return Vec256<float>{__lasx_xvfmul_s(a.raw, b.raw)};
+}
+HWY_API Vec256<double> operator*(Vec256<double> a, Vec256<double> b) {
+ return Vec256<double>{__lasx_xvfmul_d(a.raw, b.raw)};
+}
+
+HWY_API Vec256<float> operator/(Vec256<float> a, Vec256<float> b) {
+ return Vec256<float>{__lasx_xvfdiv_s(a.raw, b.raw)};
+}
+HWY_API Vec256<double> operator/(Vec256<double> a, Vec256<double> b) {
+ return Vec256<double>{__lasx_xvfdiv_d(a.raw, b.raw)};
+}
+
+// Approximate reciprocal
+
+HWY_API Vec256<float> ApproximateReciprocal(Vec256<float> v) {
+ return Vec256<float>{__lasx_xvfrecip_s(v.raw)};
+}
+
+HWY_API Vec256<double> ApproximateReciprocal(Vec256<double> v) {
+ return Vec256<double>{__lasx_xvfrecip_d(v.raw)};
+}
+
+// Integer multiply-add variants
+
+// signed
+HWY_API Vec256<int8_t> MulAdd(Vec256<int8_t> mul, Vec256<int8_t> x,
+ Vec256<int8_t> add) {
+ return Vec256<int8_t>{__lasx_xvmadd_b(add.raw, mul.raw, x.raw)};
+}
+HWY_API Vec256<int16_t> MulAdd(Vec256<int16_t> mul, Vec256<int16_t> x,
+ Vec256<int16_t> add) {
+ return Vec256<int16_t>{__lasx_xvmadd_h(add.raw, mul.raw, x.raw)};
+}
+HWY_API Vec256<int32_t> MulAdd(Vec256<int32_t> mul, Vec256<int32_t> x,
+ Vec256<int32_t> add) {
+ return Vec256<int32_t>{__lasx_xvmadd_w(add.raw, mul.raw, x.raw)};
+}
+HWY_API Vec256<int64_t> MulAdd(Vec256<int64_t> mul, Vec256<int64_t> x,
+ Vec256<int64_t> add) {
+ return Vec256<int64_t>{__lasx_xvmadd_d(add.raw, mul.raw, x.raw)};
+}
+
+// unsigned
+HWY_API Vec256<uint8_t> MulAdd(Vec256<uint8_t> mul, Vec256<uint8_t> x,
+ Vec256<uint8_t> add) {
+ return Vec256<uint8_t>{__lasx_xvmadd_b(add.raw, mul.raw, x.raw)};
+}
+HWY_API Vec256<uint16_t> MulAdd(Vec256<uint16_t> mul, Vec256<uint16_t> x,
+ Vec256<uint16_t> add) {
+ return Vec256<uint16_t>{__lasx_xvmadd_h(add.raw, mul.raw, x.raw)};
+}
+HWY_API Vec256<uint32_t> MulAdd(Vec256<uint32_t> mul, Vec256<uint32_t> x,
+ Vec256<uint32_t> add) {
+ return Vec256<uint32_t>{__lasx_xvmadd_w(add.raw, mul.raw, x.raw)};
+}
+HWY_API Vec256<uint64_t> MulAdd(Vec256<uint64_t> mul, Vec256<uint64_t> x,
+ Vec256<uint64_t> add) {
+ return Vec256<uint64_t>{__lasx_xvmadd_d(add.raw, mul.raw, x.raw)};
+}
+
+// signed
+HWY_API Vec256<int8_t> NegMulAdd(Vec256<int8_t> mul, Vec256<int8_t> x,
+ Vec256<int8_t> add) {
+ return Vec256<int8_t>{__lasx_xvmsub_b(add.raw, mul.raw, x.raw)};
+}
+HWY_API Vec256<int16_t> NegMulAdd(Vec256<int16_t> mul, Vec256<int16_t> x,
+ Vec256<int16_t> add) {
+ return Vec256<int16_t>{__lasx_xvmsub_h(add.raw, mul.raw, x.raw)};
+}
+HWY_API Vec256<int32_t> NegMulAdd(Vec256<int32_t> mul, Vec256<int32_t> x,
+ Vec256<int32_t> add) {
+ return Vec256<int32_t>{__lasx_xvmsub_w(add.raw, mul.raw, x.raw)};
+}
+HWY_API Vec256<int64_t> NegMulAdd(Vec256<int64_t> mul, Vec256<int64_t> x,
+ Vec256<int64_t> add) {
+ return Vec256<int64_t>{__lasx_xvmsub_d(add.raw, mul.raw, x.raw)};
+}
+
+// unsigned
+HWY_API Vec256<uint8_t> NegMulAdd(Vec256<uint8_t> mul, Vec256<uint8_t> x,
+ Vec256<uint8_t> add) {
+ return Vec256<uint8_t>{__lasx_xvmsub_b(add.raw, mul.raw, x.raw)};
+}
+HWY_API Vec256<uint16_t> NegMulAdd(Vec256<uint16_t> mul, Vec256<uint16_t> x,
+ Vec256<uint16_t> add) {
+ return Vec256<uint16_t>{__lasx_xvmsub_h(add.raw, mul.raw, x.raw)};
+}
+HWY_API Vec256<uint32_t> NegMulAdd(Vec256<uint32_t> mul, Vec256<uint32_t> x,
+ Vec256<uint32_t> add) {
+ return Vec256<uint32_t>{__lasx_xvmsub_w(add.raw, mul.raw, x.raw)};
+}
+HWY_API Vec256<uint64_t> NegMulAdd(Vec256<uint64_t> mul, Vec256<uint64_t> x,
+ Vec256<uint64_t> add) {
+ return Vec256<uint64_t>{__lasx_xvmsub_d(add.raw, mul.raw, x.raw)};
+}
+
+// ------------------------------ Floating-point multiply-add variants
+
+HWY_API Vec256<float> MulAdd(Vec256<float> mul, Vec256<float> x,
+ Vec256<float> add) {
+ return Vec256<float>{__lasx_xvfmadd_s(mul.raw, x.raw, add.raw)};
+}
+HWY_API Vec256<double> MulAdd(Vec256<double> mul, Vec256<double> x,
+ Vec256<double> add) {
+ return Vec256<double>{__lasx_xvfmadd_d(mul.raw, x.raw, add.raw)};
+}
+
+HWY_API Vec256<float> NegMulAdd(Vec256<float> mul, Vec256<float> x,
+ Vec256<float> add) {
+ return add - mul * x;
+}
+HWY_API Vec256<double> NegMulAdd(Vec256<double> mul, Vec256<double> x,
+ Vec256<double> add) {
+ return add - mul * x;
+}
+
+HWY_API Vec256<float> MulSub(Vec256<float> mul, Vec256<float> x,
+ Vec256<float> sub) {
+ return Vec256<float>{__lasx_xvfmsub_s(mul.raw, x.raw, sub.raw)};
+}
+HWY_API Vec256<double> MulSub(Vec256<double> mul, Vec256<double> x,
+ Vec256<double> sub) {
+ return Vec256<double>{__lasx_xvfmsub_d(mul.raw, x.raw, sub.raw)};
+}
+
+HWY_API Vec256<float> NegMulSub(Vec256<float> mul, Vec256<float> x,
+ Vec256<float> sub) {
+ return Vec256<float>{__lasx_xvfnmadd_s(mul.raw, x.raw, sub.raw)};
+}
+HWY_API Vec256<double> NegMulSub(Vec256<double> mul, Vec256<double> x,
+ Vec256<double> sub) {
+ return Vec256<double>{__lasx_xvfnmadd_d(mul.raw, x.raw, sub.raw)};
+}
+
+// ------------------------------ MulAddSub(Float)
+
+template <typename T, HWY_IF_FLOAT3264(T)>
+HWY_API Vec256<T> MulAddSub(Vec256<T> mul, Vec256<T> x, Vec256<T> sub_or_add) {
+ return OddEven(MulAdd(mul, x, sub_or_add), MulSub(mul, x, sub_or_add));
+}
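+
+// Example (illustrative): OddEven takes odd lanes from its first argument,
+// so lane i of MulAddSub is mul[i] * x[i] + sub_or_add[i] for odd i and
+// mul[i] * x[i] - sub_or_add[i] for even i.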
+
+// ------------------------------ Floating-point square root
+
+// Full precision square root
+HWY_API Vec256<float> Sqrt(Vec256<float> v) {
+ return Vec256<float>{__lasx_xvfsqrt_s(v.raw)};
+}
+
+HWY_API Vec256<double> Sqrt(Vec256<double> v) {
+ return Vec256<double>{__lasx_xvfsqrt_d(v.raw)};
+}
+
+// Approximate reciprocal square root
+HWY_API Vec256<float> ApproximateReciprocalSqrt(Vec256<float> v) {
+ return Vec256<float>{__lasx_xvfrsqrt_s(v.raw)};
+}
+
+HWY_API Vec256<double> ApproximateReciprocalSqrt(Vec256<double> v) {
+ return Vec256<double>{__lasx_xvfrsqrt_d(v.raw)};
+}
+
+// ------------------------------ Floating-point rounding
+
+// Toward nearest integer, tie to even
+HWY_API Vec256<float> Round(Vec256<float> v) {
+ return Vec256<float>{__lasx_xvfrintrne_s(v.raw)};
+}
+
+HWY_API Vec256<double> Round(Vec256<double> v) {
+ return Vec256<double>{__lasx_xvfrintrne_d(v.raw)};
+}
+
+// Toward zero, aka truncate
+HWY_API Vec256<float> Trunc(Vec256<float> v) {
+ return Vec256<float>{__lasx_xvfrintrz_s(v.raw)};
+}
+
+HWY_API Vec256<double> Trunc(Vec256<double> v) {
+ return Vec256<double>{__lasx_xvfrintrz_d(v.raw)};
+}
+
+// Toward +infinity, aka ceiling
+HWY_API Vec256<float> Ceil(Vec256<float> v) {
+ return Vec256<float>{__lasx_xvfrintrp_s(v.raw)};
+}
+
+HWY_API Vec256<double> Ceil(Vec256<double> v) {
+ return Vec256<double>{__lasx_xvfrintrp_d(v.raw)};
+}
+
+// Toward -infinity, aka floor
+HWY_API Vec256<float> Floor(Vec256<float> v) {
+ return Vec256<float>{__lasx_xvfrintrm_s(v.raw)};
+}
+
+HWY_API Vec256<double> Floor(Vec256<double> v) {
+ return Vec256<double>{__lasx_xvfrintrm_d(v.raw)};
+}
+
+// ------------------------------ Floating-point classification
+
+// FIXME: disable gcc-14 tree-based loop optimizations to prevent
+// 'HighwayTestGroup/HighwayTest.TestAllIsNaN/LASX' failures
+#if HWY_COMPILER_GCC && !HWY_COMPILER_CLANG
+#pragma GCC push_options
+#pragma GCC optimize("-fno-tree-loop-optimize")
+#endif
+
+HWY_API Mask256<float> IsNaN(Vec256<float> v) {
+ const DFromV<decltype(v)> d;
+ const RebindToSigned<decltype(d)> di;
+ return RebindMask(d,
+ MFromD<decltype(di)>{__lasx_xvfcmp_cune_s(v.raw, v.raw)});
+}
+
+HWY_API Mask256<double> IsNaN(Vec256<double> v) {
+ const DFromV<decltype(v)> d;
+ const RebindToSigned<decltype(d)> di;
+ return RebindMask(d,
+ MFromD<decltype(di)>{__lasx_xvfcmp_cune_d(v.raw, v.raw)});
+}
+
+#if HWY_COMPILER_GCC && !HWY_COMPILER_CLANG
+#pragma GCC pop_options
+#endif
+
+HWY_API Mask256<float> IsEitherNaN(Vec256<float> a, Vec256<float> b) {
+ const DFromV<decltype(a)> d;
+ const RebindToSigned<decltype(d)> di;
+ return RebindMask(d, MFromD<decltype(di)>{__lasx_xvfcmp_cun_s(a.raw, b.raw)});
+}
+
+HWY_API Mask256<double> IsEitherNaN(Vec256<double> a, Vec256<double> b) {
+ const DFromV<decltype(a)> d;
+ const RebindToSigned<decltype(d)> di;
+ return RebindMask(d, MFromD<decltype(di)>{__lasx_xvfcmp_cun_d(a.raw, b.raw)});
+}
+
+// ================================================== MEMORY
+
+// ------------------------------ Load
+
+template <class D, HWY_IF_V_SIZE_D(D, 32)>
+HWY_API VFromD<D> Load(D /* tag */, const TFromD<D>* HWY_RESTRICT aligned) {
+ const RebindToSigned<D> di;
+ return BitCast(D(), VFromD<decltype(di)>{__lasx_xvld(aligned, 0)});
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32)>
+HWY_API VFromD<D> LoadU(D /* tag */, const TFromD<D>* HWY_RESTRICT p) {
+ const RebindToSigned<D> di;
+ return BitCast(D(), VFromD<decltype(di)>{__lasx_xvld(p, 0)});
+}
+
+// ------------------------------ MaskedLoad
+
+template <class D, HWY_IF_V_SIZE_D(D, 32)>
+HWY_API VFromD<D> MaskedLoad(MFromD<D> m, D d,
+ const TFromD<D>* HWY_RESTRICT p) {
+ return IfThenElseZero(m, LoadU(d, p));
+}
+
+// ------------------------------ LoadDup128
+
+template <class D, HWY_IF_V_SIZE_D(D, 32)>
+HWY_API VFromD<D> LoadDup128(D d, const TFromD<D>* HWY_RESTRICT p) {
+ const VFromD<Half<D>> vec_tmp = LoadU(Half<D>(), p);
+ return Combine(d, vec_tmp, vec_tmp);
+}
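+
+// Example (illustrative): with int32_t lanes, LoadDup128 of {0,1,2,3} yields
+// {0,1,2,3, 0,1,2,3}; the 128-bit half is loaded once and duplicated into
+// both blocks.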
+
+// ------------------------------ Store
+
+template <class D, HWY_IF_V_SIZE_D(D, 32)>
+HWY_API void Store(VFromD<D> v, D /* tag */, TFromD<D>* HWY_RESTRICT aligned) {
+ __lasx_xvst(v.raw, aligned, 0);
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32)>
+HWY_API void StoreU(VFromD<D> v, D /* tag */, TFromD<D>* HWY_RESTRICT p) {
+ __lasx_xvst(v.raw, p, 0);
+}
+
+// ------------------------------ BlendedStore
+template <class D, HWY_IF_V_SIZE_D(D, 32)>
+HWY_API void BlendedStore(VFromD<D> v, MFromD<D> m, D d,
+ TFromD<D>* HWY_RESTRICT p) {
+ const RebindToUnsigned<decltype(d)> du;
+ const auto blended =
+ IfThenElse(RebindMask(du, m), BitCast(du, v), BitCast(du, LoadU(d, p)));
+ StoreU(BitCast(d, blended), d, p);
+}
+
+// ================================================== SWIZZLE
+// ------------------------------ LowerHalf
+
+template <class D, HWY_IF_V_SIZE_D(D, 16)>
+HWY_API VFromD<D> LowerHalf(D /* tag */, VFromD<Twice<D>> v) {
+#if HWY_HAS_BUILTIN(__builtin_shufflevector)
+ typedef uint32_t U32RawVectType __attribute__((__vector_size__(32)));
+ return VFromD<D>{reinterpret_cast<typename detail::Raw128<TFromD<D>>::type>(
+ __builtin_shufflevector(reinterpret_cast<U32RawVectType>(v.raw),
+ reinterpret_cast<U32RawVectType>(v.raw), 0, 1, 2,
+ 3))};
+#else
+ const RebindToUnsigned<D> du;
+ const Twice<decltype(du)> dut;
+ alignas(32) __m128i vec_tmp[2];
+ __m256i vec_result = BitCast(dut, v).raw;
+ CopyBytes<32>(&vec_result, vec_tmp);
+ return BitCast(D(), VFromD<decltype(du)>{vec_tmp[0]});
+#endif
+}
+
+template <typename T>
+HWY_API Vec128<T> LowerHalf(Vec256<T> v) {
+ const Full128<T> dh;
+ return LowerHalf(dh, v);
+}
+
+// ------------------------------ UpperHalf
+
+template <class D, HWY_IF_V_SIZE_D(D, 16)>
+HWY_API VFromD<D> UpperHalf(D d, VFromD<Twice<D>> v) {
+#if HWY_HAS_BUILTIN(__builtin_shufflevector)
+ (void)d;
+ typedef uint32_t U32RawVectType __attribute__((__vector_size__(32)));
+ return VFromD<D>{reinterpret_cast<typename detail::Raw128<TFromD<D>>::type>(
+ __builtin_shufflevector(reinterpret_cast<U32RawVectType>(v.raw),
+ reinterpret_cast<U32RawVectType>(v.raw), 4, 5, 6,
+ 7))};
+#else
+ const RebindToUnsigned<decltype(d)> du;
+ const Twice<decltype(du)> dut;
+ alignas(32) __m128i vec_tmp[2];
+ __m256i vec_result = BitCast(dut, v).raw;
+ CopyBytes<32>(&vec_result, vec_tmp);
+ return BitCast(d, VFromD<decltype(du)>{vec_tmp[1]});
+#endif
+}
+
+// ------------------------------ ExtractLane (Store)
+template <typename T>
+HWY_API T ExtractLane(const Vec256<T> v, size_t i) {
+ const DFromV<decltype(v)> d;
+ HWY_DASSERT(i < Lanes(d));
+
+#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang
+ constexpr size_t kLanesPerBlock = 16 / sizeof(T);
+ if (__builtin_constant_p(i < kLanesPerBlock) && (i < kLanesPerBlock)) {
+ return ExtractLane(LowerHalf(Half<decltype(d)>(), v), i);
+ }
+#endif
+
+ alignas(32) T lanes[32 / sizeof(T)];
+ Store(v, d, lanes);
+ return lanes[i];
+}
+
+// ------------------------------ InsertLane (Store)
+template <typename T>
+HWY_API Vec256<T> InsertLane(const Vec256<T> v, size_t i, T t) {
+ return detail::InsertLaneUsingBroadcastAndBlend(v, i, t);
+}
+
+// ------------------------------ GetLane (LowerHalf)
+template <typename T>
+HWY_API T GetLane(const Vec256<T> v) {
+ return GetLane(LowerHalf(v));
+}
+
+// ------------------------------ ExtractBlock (LowerHalf, UpperHalf)
+
+template <int kBlockIdx, class T>
+HWY_API Vec128<T> ExtractBlock(Vec256<T> v) {
+ static_assert(kBlockIdx == 0 || kBlockIdx == 1, "Invalid block index");
+ const Half<DFromV<decltype(v)>> dh;
+ return (kBlockIdx == 0) ? LowerHalf(dh, v) : UpperHalf(dh, v);
+}
+
+// ------------------------------ ZeroExtendVector
+
+template <class D, HWY_IF_V_SIZE_D(D, 32)>
+HWY_API VFromD<D> ZeroExtendVector(D /* tag */, VFromD<Half<D>> lo) {
+#if HWY_HAS_BUILTIN(__builtin_shufflevector)
+ typedef uint32_t U32RawVectType __attribute__((__vector_size__(16)));
+ U32RawVectType zero = {0, 0, 0, 0};
+ return VFromD<D>{reinterpret_cast<typename detail::Raw256<TFromD<D>>::type>(
+ __builtin_shufflevector(reinterpret_cast<U32RawVectType>(lo.raw), zero, 0,
+ 1, 2, 3, 4, 5, 6, 7))};
+#else
+ return Combine(D(), Zero(Half<D>()), lo);
+#endif
+}
+
+// ------------------------------ ZeroExtendResizeBitCast
+
+namespace detail {
+
+template <class DTo, class DFrom>
+HWY_INLINE VFromD<DTo> ZeroExtendResizeBitCast(
+ hwy::SizeTag<8> /* from_size_tag */, hwy::SizeTag<32> /* to_size_tag */,
+ DTo d_to, DFrom d_from, VFromD<DFrom> v) {
+ const Twice<decltype(d_from)> dt_from;
+ const Twice<decltype(dt_from)> dq_from;
+ return BitCast(d_to, ZeroExtendVector(dq_from, ZeroExtendVector(dt_from, v)));
+}
+
+} // namespace detail
+
+// ------------------------------ Combine
+
+template <class D, HWY_IF_V_SIZE_D(D, 32)>
+HWY_API VFromD<D> Combine(D d, VFromD<Half<D>> hi, VFromD<Half<D>> lo) {
+#if HWY_HAS_BUILTIN(__builtin_shufflevector)
+ (void)d;
+ typedef uint32_t U32RawVectType __attribute__((__vector_size__(16)));
+ return VFromD<D>{reinterpret_cast<typename detail::Raw256<TFromD<D>>::type>(
+ __builtin_shufflevector(reinterpret_cast<U32RawVectType>(lo.raw),
+ reinterpret_cast<U32RawVectType>(hi.raw), 0, 1, 2,
+ 3, 4, 5, 6, 7))};
+#else
+ const RebindToUnsigned<decltype(d)> du;
+ const Half<decltype(du)> du128;
+ alignas(32) __m128i vec_tmp[2];
+ __m256i vec_result;
+ vec_tmp[0] = BitCast(du128, lo).raw;
+ vec_tmp[1] = BitCast(du128, hi).raw;
+ CopyBytes<32>(vec_tmp, &vec_result);
+ return BitCast(d, VFromD<decltype(du)>{vec_result});
+#endif
+}
+
+// ------------------------------ ShiftLeftBytes
+template <int kBytes, class D, HWY_IF_V_SIZE_D(D, 32)>
+HWY_API VFromD<D> ShiftLeftBytes(D d, VFromD<D> v) {
+ static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes");
+ if (kBytes == 0) return v;
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(
+ d, VFromD<decltype(du)>{__lasx_xvbsll_v(BitCast(du, v).raw, kBytes)});
+}
+
+// ------------------------------ ShiftRightBytes
+template <int kBytes, class D, HWY_IF_V_SIZE_D(D, 32)>
+HWY_API VFromD<D> ShiftRightBytes(D d, VFromD<D> v) {
+ static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes");
+ if (kBytes == 0) return v;
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(
+ d, VFromD<decltype(du)>{__lasx_xvbsrl_v(BitCast(du, v).raw, kBytes)});
+}
+
+// ------------------------------ CombineShiftRightBytes
+template <int kBytes, class D, HWY_IF_V_SIZE_D(D, 32)>
+HWY_API VFromD<D> CombineShiftRightBytes(D d, VFromD<D> hi, VFromD<D> lo) {
+ return Or(ShiftRightBytes<kBytes>(d, lo), ShiftLeftBytes<16 - kBytes>(d, hi));
+}
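+
+// Worked example (per 128-bit block): for kBytes = 4, each result block is
+// bytes [4, 15] of lo's block followed by bytes [0, 3] of hi's block, i.e.
+// lo shifted right by 4 bytes OR'd with hi shifted left by 12 bytes.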
+
+// ------------------------------ Broadcast
+
+template <int kLane, class T, HWY_IF_T_SIZE(T, 1)>
+HWY_API Vec256<T> Broadcast(const Vec256<T> v) {
+ static_assert(0 <= kLane && kLane < 16, "Invalid lane");
+ return Vec256<T>{__lasx_xvreplve_b(v.raw, kLane)};
+}
+
+template <int kLane, typename T, HWY_IF_T_SIZE(T, 2)>
+HWY_API Vec256<T> Broadcast(const Vec256<T> v) {
+ static_assert(0 <= kLane && kLane < 8, "Invalid lane");
+ const DFromV<decltype(v)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(
+ d, VFromD<decltype(du)>{__lasx_xvreplve_h(BitCast(du, v).raw, kLane)});
+}
+
+template <int kLane, typename T, HWY_IF_UI32(T)>
+HWY_API Vec256<T> Broadcast(const Vec256<T> v) {
+ static_assert(0 <= kLane && kLane < 4, "Invalid lane");
+ return Vec256<T>{__lasx_xvreplve_w(v.raw, kLane)};
+}
+
+template <int kLane, typename T, HWY_IF_UI64(T)>
+HWY_API Vec256<T> Broadcast(const Vec256<T> v) {
+ static_assert(0 <= kLane && kLane < 2, "Invalid lane");
+ return Vec256<T>{__lasx_xvreplve_d(v.raw, kLane)};
+}
+
+template <int kLane>
+HWY_API Vec256<float> Broadcast(Vec256<float> v) {
+ static_assert(0 <= kLane && kLane < 4, "Invalid lane");
+ const DFromV<decltype(v)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(
+ d, VFromD<decltype(du)>{__lasx_xvreplve_w(BitCast(du, v).raw, kLane)});
+}
+
+template <int kLane>
+HWY_API Vec256<double> Broadcast(const Vec256<double> v) {
+ static_assert(0 <= kLane && kLane < 2, "Invalid lane");
+ const DFromV<decltype(v)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(
+ d, VFromD<decltype(du)>{__lasx_xvreplve_d(BitCast(du, v).raw, kLane)});
+}
+
+// ------------------------------ BroadcastBlock
+
+template <int kBlockIdx, class T>
+HWY_API Vec256<T> BroadcastBlock(Vec256<T> v) {
+ static_assert(kBlockIdx == 0 || kBlockIdx == 1, "Invalid block index");
+ const DFromV<decltype(v)> d;
+ return (kBlockIdx == 0) ? ConcatLowerLower(d, v, v)
+ : ConcatUpperUpper(d, v, v);
+}
+
+// ------------------------------ BroadcastLane
+
+namespace detail {
+
+template <class T, HWY_IF_T_SIZE(T, 1)>
+HWY_INLINE Vec256<T> BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */,
+ Vec256<T> v) {
+ return Vec256<T>{__lasx_xvreplve0_b(v.raw)};
+}
+
+template <class T, HWY_IF_T_SIZE(T, 2)>
+HWY_INLINE Vec256<T> BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */,
+ Vec256<T> v) {
+ const DFromV<decltype(v)> d;
+ const RebindToUnsigned<decltype(d)> du; // for float16_t
+ return BitCast(d,
+ VFromD<decltype(du)>{__lasx_xvreplve0_h(BitCast(du, v).raw)});
+}
+
+template <class T, HWY_IF_UI32(T)>
+HWY_INLINE Vec256<T> BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */,
+ Vec256<T> v) {
+ return Vec256<T>{__lasx_xvreplve0_w(v.raw)};
+}
+
+template <class T, HWY_IF_UI64(T)>
+HWY_INLINE Vec256<T> BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */,
+ Vec256<T> v) {
+ return Vec256<T>{__lasx_xvreplve0_d(v.raw)};
+}
+
+HWY_INLINE Vec256<float> BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */,
+ Vec256<float> v) {
+ const DFromV<decltype(v)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(d,
+ VFromD<decltype(du)>{__lasx_xvreplve0_w(BitCast(du, v).raw)});
+}
+
+HWY_INLINE Vec256<double> BroadcastLane(hwy::SizeTag<0> /* lane_idx_tag */,
+ Vec256<double> v) {
+ const DFromV<decltype(v)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(d,
+ VFromD<decltype(du)>{__lasx_xvreplve0_d(BitCast(du, v).raw)});
+}
+
+template <size_t kLaneIdx, class T, hwy::EnableIf<kLaneIdx != 0>* = nullptr>
+HWY_INLINE Vec256<T> BroadcastLane(hwy::SizeTag<kLaneIdx> /* lane_idx_tag */,
+ Vec256<T> v) {
+ constexpr size_t kLanesPerBlock = 16 / sizeof(T);
+ constexpr int kBlockIdx = static_cast<int>(kLaneIdx / kLanesPerBlock);
+ constexpr int kLaneInBlkIdx =
+ static_cast<int>(kLaneIdx) & (kLanesPerBlock - 1);
+ return Broadcast<kLaneInBlkIdx>(BroadcastBlock<kBlockIdx>(v));
+}
+} // namespace detail
+
+template <int kLaneIdx, class T>
+HWY_API Vec256<T> BroadcastLane(Vec256<T> v) {
+ static_assert(kLaneIdx >= 0, "Invalid lane");
+ return detail::BroadcastLane(hwy::SizeTag<static_cast<size_t>(kLaneIdx)>(),
+ v);
+}
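+
+// Worked example (illustrative): for int16_t and kLaneIdx = 9,
+// kLanesPerBlock = 8, so kBlockIdx = 1 and kLaneInBlkIdx = 1: block 1 is
+// first broadcast to both blocks, then lane 1 within each block.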
+
+// ------------------------------ Hard-coded shuffles
+
+// Notation: let Vec256<int32_t> have lanes 7,6,5,4,3,2,1,0 (0 is
+// least-significant). Shuffle0321 rotates four-lane blocks one lane to the
+// right (the previous least-significant lane is now most-significant =>
+// 47650321). These could also be implemented via CombineShiftRightBytes but
+// the shuffle_abcd notation is more convenient.
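+//
+// For example, per 128-bit block Shuffle2301 turns input lanes 3,2,1,0 into
+// 2,3,0,1; the xvshuf4i immediate (0xb1 here) encodes, for each destination
+// lane, which source lane to copy.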
+
+// Swap 32-bit halves in 64-bit halves.
+template <typename T, HWY_IF_UI32(T)>
+HWY_API Vec256<T> Shuffle2301(const Vec256<T> v) {
+ return Vec256<T>{__lasx_xvshuf4i_w(v.raw, 0xb1)};
+}
+HWY_API Vec256<float> Shuffle2301(const Vec256<float> v) {
+ const DFromV<decltype(v)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(
+ d, VFromD<decltype(du)>{__lasx_xvshuf4i_w(BitCast(du, v).raw, 0xb1)});
+}
+
+// Used by generic_ops-inl.h
+namespace detail {
+
+template <typename T, HWY_IF_T_SIZE(T, 4)>
+HWY_API Vec256<T> ShuffleTwo2301(const Vec256<T> a, const Vec256<T> b) {
+ const DFromV<decltype(a)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(d, VFromD<decltype(du)>{__lasx_xvpermi_w(
+ BitCast(du, b).raw, BitCast(du, a).raw, 0xb1)});
+}
+template <typename T, HWY_IF_T_SIZE(T, 4)>
+HWY_API Vec256<T> ShuffleTwo1230(const Vec256<T> a, const Vec256<T> b) {
+ const DFromV<decltype(a)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(d, VFromD<decltype(du)>{__lasx_xvpermi_w(
+ BitCast(du, b).raw, BitCast(du, a).raw, 0x6c)});
+}
+template <typename T, HWY_IF_T_SIZE(T, 4)>
+HWY_API Vec256<T> ShuffleTwo3012(const Vec256<T> a, const Vec256<T> b) {
+ const DFromV<decltype(a)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(d, VFromD<decltype(du)>{__lasx_xvpermi_w(
+ BitCast(du, b).raw, BitCast(du, a).raw, 0xc6)});
+}
+
+} // namespace detail
+
+// Swap 64-bit halves
+HWY_API Vec256<uint32_t> Shuffle1032(const Vec256<uint32_t> v) {
+ return Vec256<uint32_t>{__lasx_xvshuf4i_w(v.raw, 0x4e)};
+}
+HWY_API Vec256<int32_t> Shuffle1032(const Vec256<int32_t> v) {
+ return Vec256<int32_t>{__lasx_xvshuf4i_w(v.raw, 0x4e)};
+}
+HWY_API Vec256<float> Shuffle1032(const Vec256<float> v) {
+ const DFromV<decltype(v)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(
+ d, VFromD<decltype(du)>{__lasx_xvshuf4i_w(BitCast(du, v).raw, 0x4e)});
+}
+HWY_API Vec256<uint64_t> Shuffle01(const Vec256<uint64_t> v) {
+ return Vec256<uint64_t>{__lasx_xvshuf4i_w(v.raw, 0x4e)};
+}
+HWY_API Vec256<int64_t> Shuffle01(const Vec256<int64_t> v) {
+ return Vec256<int64_t>{__lasx_xvshuf4i_w(v.raw, 0x4e)};
+}
+HWY_API Vec256<double> Shuffle01(const Vec256<double> v) {
+ const DFromV<decltype(v)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(
+ d, VFromD<decltype(du)>{__lasx_xvshuf4i_w(BitCast(du, v).raw, 0x4e)});
+}
+
+// Rotate right 32 bits
+HWY_API Vec256<uint32_t> Shuffle0321(const Vec256<uint32_t> v) {
+ return Vec256<uint32_t>{__lasx_xvshuf4i_w(v.raw, 0x39)};
+}
+HWY_API Vec256<int32_t> Shuffle0321(const Vec256<int32_t> v) {
+ return Vec256<int32_t>{__lasx_xvshuf4i_w(v.raw, 0x39)};
+}
+HWY_API Vec256<float> Shuffle0321(const Vec256<float> v) {
+ const DFromV<decltype(v)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(
+ d, VFromD<decltype(du)>{__lasx_xvshuf4i_w(BitCast(du, v).raw, 0x39)});
+}
+
+// Rotate left 32 bits
+HWY_API Vec256<uint32_t> Shuffle2103(const Vec256<uint32_t> v) {
+ return Vec256<uint32_t>{__lasx_xvshuf4i_w(v.raw, 0x93)};
+}
+HWY_API Vec256<int32_t> Shuffle2103(const Vec256<int32_t> v) {
+ return Vec256<int32_t>{__lasx_xvshuf4i_w(v.raw, 0x93)};
+}
+HWY_API Vec256<float> Shuffle2103(const Vec256<float> v) {
+ const DFromV<decltype(v)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(
+ d, VFromD<decltype(du)>{__lasx_xvshuf4i_w(BitCast(du, v).raw, 0x93)});
+}
+
+// Reverse
+HWY_API Vec256<uint32_t> Shuffle0123(const Vec256<uint32_t> v) {
+ return Vec256<uint32_t>{__lasx_xvshuf4i_w(v.raw, 0x1B)};
+}
+HWY_API Vec256<int32_t> Shuffle0123(const Vec256<int32_t> v) {
+ return Vec256<int32_t>{__lasx_xvshuf4i_w(v.raw, 0x1B)};
+}
+HWY_API Vec256<float> Shuffle0123(const Vec256<float> v) {
+ const DFromV<decltype(v)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(
+ d, VFromD<decltype(du)>{__lasx_xvshuf4i_w(BitCast(du, v).raw, 0x1b)});
+}
+
+// ------------------------------ TableLookupLanes
+
+// Returned by SetTableIndices/IndicesFromVec for use by TableLookupLanes.
+template <typename T>
+struct Indices256 {
+ __m256i raw;
+};
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), typename TI>
+HWY_API Indices256<TFromD<D>> IndicesFromVec(D /* tag */, Vec256<TI> vec) {
+ static_assert(sizeof(TFromD<D>) == sizeof(TI), "Index size must match lane");
+#if HWY_IS_DEBUG_BUILD
+ const Full256<TI> di;
+ HWY_DASSERT(AllFalse(di, Lt(vec, Zero(di))) &&
+ AllTrue(di, Lt(vec, Set(di, static_cast<TI>(2 * Lanes(di))))));
+#endif
+ return Indices256<TFromD<D>>{vec.raw};
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), typename TI>
+HWY_API Indices256<TFromD<D>> SetTableIndices(D d, const TI* idx) {
+ const Rebind<TI, decltype(d)> di;
+ return IndicesFromVec(d, LoadU(di, idx));
+}
+
+template <typename T, HWY_IF_T_SIZE(T, 1)>
+HWY_API Vec256<T> TableLookupLanes(Vec256<T> v, Indices256<T> idx) {
+ const DFromV<decltype(v)> d;
+ const auto a = ConcatLowerLower(d, v, v);
+ const auto b = ConcatUpperUpper(d, v, v);
+ return Vec256<T>{__lasx_xvshuf_b(b.raw, a.raw, idx.raw)};
+}
+
+template <typename T, HWY_IF_T_SIZE(T, 2)>
+HWY_API Vec256<T> TableLookupLanes(Vec256<T> v, Indices256<T> idx) {
+ const DFromV<decltype(v)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ const auto a = ConcatLowerLower(d, v, v);
+ const auto b = ConcatUpperUpper(d, v, v);
+ return BitCast(d, VFromD<decltype(du)>{__lasx_xvshuf_h(
+ idx.raw, BitCast(du, b).raw, BitCast(du, a).raw)});
+}
+
+template <typename T, HWY_IF_T_SIZE(T, 4)>
+HWY_API Vec256<T> TableLookupLanes(Vec256<T> v, Indices256<T> idx) {
+ const DFromV<decltype(v)> d;
+ const RebindToSigned<decltype(d)> di;
+ return BitCast(d,
+ Vec256<int32_t>{__lasx_xvperm_w(BitCast(di, v).raw, idx.raw)});
+}
+
+template <typename T, HWY_IF_T_SIZE(T, 8)>
+HWY_API Vec256<T> TableLookupLanes(Vec256<T> v, Indices256<T> idx) {
+ using TI = MakeSigned<T>;
+ const DFromV<decltype(v)> d;
+ const RebindToSigned<decltype(d)> di64;
+ const Repartition<int32_t, decltype(d)> di32;
+ // Replicate the low half of each 64-bit index into its upper 32 bits
+ const Vec256<TI> dup{__lasx_xvpackev_w(idx.raw, idx.raw)};
+ // For each idx64 i, idx32 are 2*i and 2*i+1.
+ const Vec256<TI> idx32 = dup + dup + Set(di64, int64_t{1} << 32);
+ return BitCast(
+ d, TableLookupLanes(BitCast(di32, v), Indices256<int32_t>{idx32.raw}));
+}
+
+template <typename T, HWY_IF_T_SIZE(T, 1)>
+HWY_API Vec256<T> TwoTablesLookupLanes(Vec256<T> a, Vec256<T> b,
+ Indices256<T> idx) {
+ const auto idx2 = Indices256<T>{__lasx_xvandi_b(idx.raw, 31)};
+ const Vec256<T> idx_vec{idx.raw};
+ const auto sel_hi_mask = ShiftLeft<2>(idx_vec);
+ const auto mask0or1 = __lasx_xvslti_b(sel_hi_mask.raw, 0);
+ const auto lo_lookup_result = TableLookupLanes(a, idx);
+ const auto hi_lookup_result = TableLookupLanes(b, idx2);
+ return IfThenElse(Mask256<T>{mask0or1}, hi_lookup_result, lo_lookup_result);
+}
+
+template <typename T, HWY_IF_NOT_T_SIZE(T, 1)>
+HWY_API Vec256<T> TwoTablesLookupLanes(Vec256<T> a, Vec256<T> b,
+ Indices256<T> idx) {
+ const DFromV<decltype(a)> d;
+ const RebindToSigned<decltype(d)> di;
+ const Vec256<TFromD<decltype(di)>> idx_vec{idx.raw};
+ constexpr int shift_count = 8 * sizeof(T) - 6 + CeilLog2(sizeof(T));
+ const auto sel_hi_mask = BitCast(di, ShiftLeft<shift_count>(idx_vec));
+ const auto lo_lookup_result = BitCast(di, TableLookupLanes(a, idx));
+ const auto hi_lookup_result = BitCast(di, TableLookupLanes(b, idx));
+ return BitCast(
+ d, IfNegativeThenElse(sel_hi_mask, hi_lookup_result, lo_lookup_result));
+}
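+
+// Why this shift_count: indices lie in [0, 2 * Lanes(d)), so the bit that
+// selects between the two tables is bit 5 - CeilLog2(sizeof(T)) (e.g. bit 4
+// for 16-bit lanes). Shifting left by 8 * sizeof(T) - 6 + CeilLog2(sizeof(T))
+// moves exactly that bit into the lane's sign bit for IfNegativeThenElse;
+// for bytes this shift is 2, matching the byte overload above.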
+
+// ------------------------------ SwapAdjacentBlocks
+
+template <typename T>
+HWY_API Vec256<T> SwapAdjacentBlocks(Vec256<T> v) {
+ const DFromV<decltype(v)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(d, Vec256<uint8_t>{__lasx_xvpermi_q(
+ BitCast(du, v).raw, BitCast(du, v).raw, 0x01)});
+}
+
+// ------------------------------ InterleaveEvenBlocks (ConcatLowerLower)
+template <class D, class V = VFromD<D>, HWY_IF_V_SIZE_D(D, 32)>
+HWY_API V InterleaveEvenBlocks(D d, V a, V b) {
+ return ConcatLowerLower(d, b, a);
+}
+
+// ------------------------------ InterleaveOddBlocks (ConcatUpperUpper)
+template <class D, class V = VFromD<D>, HWY_IF_V_SIZE_D(D, 32)>
+HWY_API V InterleaveOddBlocks(D d, V a, V b) {
+ return ConcatUpperUpper(d, b, a);
+}
+
+// ------------------------------ InterleaveLowerBlocks
+template <class D, class V = VFromD<D>, HWY_IF_V_SIZE_D(D, 32)>
+HWY_API V InterleaveLowerBlocks(D d, V a, V b) {
+ return InterleaveEvenBlocks(d, a, b);
+}
+// ------------------------------ InterleaveUpperBlocks
+template <class D, class V = VFromD<D>, HWY_IF_V_SIZE_D(D, 32)>
+HWY_API V InterleaveUpperBlocks(D d, V a, V b) {
+ return InterleaveOddBlocks(d, a, b);
+}
+
+// ------------------------------ Reverse (RotateRight)
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 4)>
+HWY_API VFromD<D> Reverse(D d, const VFromD<D> v) {
+ alignas(32) static constexpr int32_t kReverse[8] = {7, 6, 5, 4, 3, 2, 1, 0};
+ return TableLookupLanes(v, SetTableIndices(d, kReverse));
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 8)>
+HWY_API VFromD<D> Reverse(D d, const VFromD<D> v) {
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(
+ d, VFromD<decltype(du)>{__lasx_xvpermi_d(BitCast(du, v).raw, 0x1b)});
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 2)>
+HWY_API VFromD<D> Reverse(D d, const VFromD<D> v) {
+ alignas(32) static constexpr int16_t kReverse[16] = {
+ 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0};
+ return TableLookupLanes(v, SetTableIndices(d, kReverse));
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 1)>
+HWY_API VFromD<D> Reverse(D d, const VFromD<D> v) {
+ alignas(32) static constexpr TFromD<D> kReverse[32] = {
+ 31, 30, 29, 28, 27, 26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 16,
+ 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0};
+ return TableLookupLanes(v, SetTableIndices(d, kReverse));
+}
+
+// ------------------------------ Reverse4 (SwapAdjacentBlocks)
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 2)>
+HWY_API VFromD<D> Reverse4(D d, const VFromD<D> v) {
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(
+ d, VFromD<decltype(du)>{__lasx_xvshuf4i_h(BitCast(du, v).raw, 0x1b)});
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 8)>
+HWY_API VFromD<D> Reverse4(D /* tag */, const VFromD<D> v) {
+ const RebindToUnsigned<D> du;
+ return BitCast(
+ D(), VFromD<decltype(du)>{__lasx_xvpermi_d(BitCast(du, v).raw, 0x1b)});
+}
+
+// ------------------------------ Reverse8
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 2)>
+HWY_API VFromD<D> Reverse8(D d, const VFromD<D> v) {
+ const RebindToSigned<decltype(d)> di;
+ const VFromD<decltype(di)> shuffle = Dup128VecFromValues(
+ di, 0x0F0E, 0x0D0C, 0x0B0A, 0x0908, 0x0706, 0x0504, 0x0302, 0x0100);
+ return BitCast(d, TableLookupBytes(v, shuffle));
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 4)>
+HWY_API VFromD<D> Reverse8(D d, const VFromD<D> v) {
+ return Reverse(d, v);
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 8)>
+HWY_API VFromD<D> Reverse8(D /* tag */, const VFromD<D> /* v */) {
+ HWY_ASSERT(0);  // 256-bit vectors have only four 64-bit lanes
+}
+
+// ------------------------------ InterleaveLower
+
+// Interleaves lanes from halves of the 128-bit blocks of "a" (which provides
+// the least-significant lane) and "b". To concatenate two half-width integers
+// into one, use ZipLower/Upper instead (also works with scalar).
+
+template <typename T, HWY_IF_T_SIZE(T, 1)>
+HWY_API Vec256<T> InterleaveLower(Vec256<T> a, Vec256<T> b) {
+ return Vec256<T>{__lasx_xvilvl_b(b.raw, a.raw)};
+}
+template <typename T, HWY_IF_T_SIZE(T, 2)>
+HWY_API Vec256<T> InterleaveLower(Vec256<T> a, Vec256<T> b) {
+ const DFromV<decltype(a)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ using VU = VFromD<decltype(du)>; // for float16_t
+ return BitCast(d,
+ VU{__lasx_xvilvl_h(BitCast(du, b).raw, BitCast(du, a).raw)});
+}
+template <typename T, HWY_IF_UI32(T)>
+HWY_API Vec256<T> InterleaveLower(Vec256<T> a, Vec256<T> b) {
+ return Vec256<T>{__lasx_xvilvl_w(b.raw, a.raw)};
+}
+template <typename T, HWY_IF_UI64(T)>
+HWY_API Vec256<T> InterleaveLower(Vec256<T> a, Vec256<T> b) {
+ return Vec256<T>{__lasx_xvilvl_d(b.raw, a.raw)};
+}
+
+HWY_API Vec256<float> InterleaveLower(Vec256<float> a, Vec256<float> b) {
+ const Full256<uint32_t> du;
+ const Full256<float> df;
+ return BitCast(df, Vec256<uint32_t>{__lasx_xvilvl_w(BitCast(du, b).raw,
+ BitCast(du, a).raw)});
+}
+HWY_API Vec256<double> InterleaveLower(Vec256<double> a, Vec256<double> b) {
+ const Full256<uint64_t> du;
+ const Full256<double> df;
+ return BitCast(df, Vec256<uint64_t>{__lasx_xvilvl_d(BitCast(du, b).raw,
+ BitCast(du, a).raw)});
+}
+
+// ------------------------------ InterleaveUpper
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 1)>
+HWY_API VFromD<D> InterleaveUpper(D /* tag */, VFromD<D> a, VFromD<D> b) {
+ return VFromD<D>{__lasx_xvilvh_b(b.raw, a.raw)};
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 2)>
+HWY_API VFromD<D> InterleaveUpper(D d, VFromD<D> a, VFromD<D> b) {
+ const RebindToUnsigned<decltype(d)> du;
+ using VU = VFromD<decltype(du)>; // for float16_t
+ return BitCast(d,
+ VU{__lasx_xvilvh_h(BitCast(du, b).raw, BitCast(du, a).raw)});
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI32_D(D)>
+HWY_API VFromD<D> InterleaveUpper(D /* tag */, VFromD<D> a, VFromD<D> b) {
+ return VFromD<D>{__lasx_xvilvh_w(b.raw, a.raw)};
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI64_D(D)>
+HWY_API VFromD<D> InterleaveUpper(D /* tag */, VFromD<D> a, VFromD<D> b) {
+ return VFromD<D>{__lasx_xvilvh_d(b.raw, a.raw)};
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)>
+HWY_API VFromD<D> InterleaveUpper(D /* tag */, VFromD<D> a, VFromD<D> b) {
+ const RebindToUnsigned<D> du;
+ return BitCast(D(), VFromD<decltype(du)>{__lasx_xvilvh_w(
+ BitCast(du, b).raw, BitCast(du, a).raw)});
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)>
+HWY_API VFromD<D> InterleaveUpper(D /* tag */, VFromD<D> a, VFromD<D> b) {
+ const RebindToUnsigned<D> du;
+ return BitCast(D(), VFromD<decltype(du)>{__lasx_xvilvh_d(
+ BitCast(du, b).raw, BitCast(du, a).raw)});
+}
+
+// ------------------------------ Blocks (LowerHalf, ZeroExtendVector)
+
+// hiH,hiL loH,loL |-> hiL,loL (= lower halves)
+template <class D, HWY_IF_V_SIZE_D(D, 32)>
+HWY_API VFromD<D> ConcatLowerLower(D d, VFromD<D> hi, VFromD<D> lo) {
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(d, VFromD<decltype(du)>{__lasx_xvpermi_q(
+ BitCast(du, hi).raw, BitCast(du, lo).raw, 0x20)});
+}
+
+// hiH,hiL loH,loL |-> hiL,loH (= inner halves / swap blocks)
+template <class D, HWY_IF_V_SIZE_D(D, 32)>
+HWY_API VFromD<D> ConcatLowerUpper(D d, VFromD<D> hi, VFromD<D> lo) {
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(d, VFromD<decltype(du)>{__lasx_xvpermi_q(
+ BitCast(du, hi).raw, BitCast(du, lo).raw, 0x21)});
+}
+
+// hiH,hiL loH,loL |-> hiH,loL (= outer halves)
+template <class D, HWY_IF_V_SIZE_D(D, 32)>
+HWY_API VFromD<D> ConcatUpperLower(D d, VFromD<D> hi, VFromD<D> lo) {
+ const RebindToUnsigned<decltype(d)> du; // for float16_t
+ return BitCast(d, VFromD<decltype(du)>{__lasx_xvpermi_q(
+ BitCast(du, hi).raw, BitCast(du, lo).raw, 0x30)});
+}
+
+// hiH,hiL loH,loL |-> hiH,loH (= upper halves)
+template <class D, HWY_IF_V_SIZE_D(D, 32)>
+HWY_API VFromD<D> ConcatUpperUpper(D d, VFromD<D> hi, VFromD<D> lo) {
+ const RebindToUnsigned<decltype(d)> du; // for float16_t
+ return BitCast(d, VFromD<decltype(du)>{__lasx_xvpermi_q(
+ BitCast(du, hi).raw, BitCast(du, lo).raw, 0x31)});
+}
+
+// ---------------------------- InsertBlock (ConcatLowerLower, ConcatUpperLower)
+template <int kBlockIdx, class T>
+HWY_API Vec256<T> InsertBlock(Vec256<T> v, Vec128<T> blk_to_insert) {
+ static_assert(kBlockIdx == 0 || kBlockIdx == 1, "Invalid block index");
+
+ const DFromV<decltype(v)> d;
+ const auto vec_to_insert = ResizeBitCast(d, blk_to_insert);
+ return (kBlockIdx == 0) ? ConcatUpperLower(d, v, vec_to_insert)
+ : ConcatLowerLower(d, vec_to_insert, v);
+}
+
+// ------------------------------ ConcatOdd
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 1)>
+HWY_API VFromD<D> ConcatOdd(D /* tag */, VFromD<D> hi, VFromD<D> lo) {
+ __m256i od = __lasx_xvpickod_b(hi.raw, lo.raw);
+ return VFromD<D>{__lasx_xvpermi_d(od, 0xd8)};
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 2)>
+HWY_API VFromD<D> ConcatOdd(D d, VFromD<D> hi, VFromD<D> lo) {
+ const RebindToUnsigned<decltype(d)> du;
+ __m256i od = __lasx_xvpickod_h(BitCast(du, hi).raw, BitCast(du, lo).raw);
+ return BitCast(d, VFromD<decltype(du)>{__lasx_xvpermi_d(od, 0xd8)});
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI32_D(D)>
+HWY_API VFromD<D> ConcatOdd(D /* tag */, VFromD<D> hi, VFromD<D> lo) {
+ __m256i od = __lasx_xvpickod_w(hi.raw, lo.raw);
+ return VFromD<D>{__lasx_xvpermi_d(od, 0xd8)};
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)>
+HWY_API VFromD<D> ConcatOdd(D d, VFromD<D> hi, VFromD<D> lo) {
+ const RebindToUnsigned<decltype(d)> du;
+ __m256i od = __lasx_xvpickod_w(BitCast(du, hi).raw, BitCast(du, lo).raw);
+ return BitCast(d, VFromD<decltype(du)>{__lasx_xvpermi_d(od, 0xd8)});
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI64_D(D)>
+HWY_API VFromD<D> ConcatOdd(D /* tag */, VFromD<D> hi, VFromD<D> lo) {
+ __m256i od = __lasx_xvpickod_d(hi.raw, lo.raw);
+ return VFromD<D>{__lasx_xvpermi_d(od, 0xd8)};
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)>
+HWY_API Vec256<double> ConcatOdd(D d, Vec256<double> hi, Vec256<double> lo) {
+ const RebindToUnsigned<decltype(d)> du;
+ __m256i od = __lasx_xvpickod_d(BitCast(du, hi).raw, BitCast(du, lo).raw);
+ return BitCast(d, VFromD<decltype(du)>{__lasx_xvpermi_d(od, 0xd8)});
+}
+
+// ------------------------------ ConcatEven
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 1)>
+HWY_API VFromD<D> ConcatEven(D /* tag */, VFromD<D> hi, VFromD<D> lo) {
+ __m256i ev = __lasx_xvpickev_b(hi.raw, lo.raw);
+ return VFromD<D>{__lasx_xvpermi_d(ev, 0xd8)};
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 2)>
+HWY_API VFromD<D> ConcatEven(D d, VFromD<D> hi, VFromD<D> lo) {
+ const RebindToUnsigned<decltype(d)> du;
+ __m256i ev = __lasx_xvpickev_h(BitCast(du, hi).raw, BitCast(du, lo).raw);
+ return BitCast(d, VFromD<decltype(du)>{__lasx_xvpermi_d(ev, 0xd8)});
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI32_D(D)>
+HWY_API VFromD<D> ConcatEven(D /* tag */, VFromD<D> hi, VFromD<D> lo) {
+ __m256i ev = __lasx_xvpickev_w(hi.raw, lo.raw);
+ return VFromD<D>{__lasx_xvpermi_d(ev, 0xd8)};
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)>
+HWY_API VFromD<D> ConcatEven(D d, VFromD<D> hi, VFromD<D> lo) {
+ const RebindToUnsigned<decltype(d)> du;
+ __m256i ev = __lasx_xvpickev_w(BitCast(du, hi).raw, BitCast(du, lo).raw);
+ return BitCast(d, VFromD<decltype(du)>{__lasx_xvpermi_d(ev, 0xd8)});
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_UI64_D(D)>
+HWY_API VFromD<D> ConcatEven(D /* tag */, VFromD<D> hi, VFromD<D> lo) {
+ __m256i ev = __lasx_xvpickev_d(hi.raw, lo.raw);
+ return VFromD<D>{__lasx_xvpermi_d(ev, 0xd8)};
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)>
+HWY_API Vec256<double> ConcatEven(D d, Vec256<double> hi, Vec256<double> lo) {
+ const RebindToUnsigned<decltype(d)> du;
+ __m256i ev = __lasx_xvpickev_d(BitCast(du, hi).raw, BitCast(du, lo).raw);
+ return BitCast(d, VFromD<decltype(du)>{__lasx_xvpermi_d(ev, 0xd8)});
+}
+
+// ------------------------------ InterleaveWholeLower
+
+template <class D, HWY_IF_V_SIZE_D(D, 32)>
+HWY_API VFromD<D> InterleaveWholeLower(D d, VFromD<D> a, VFromD<D> b) {
+ return ConcatLowerLower(d, InterleaveUpper(d, a, b), InterleaveLower(a, b));
+}
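+
+// Rationale (illustrative): InterleaveLower/Upper act per 128-bit block, so
+// taking the lower block of each partial result via ConcatLowerLower yields
+// the interleave of the whole lower halves of a and b; the Upper variant
+// below mirrors this with ConcatUpperUpper.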
+
+// ------------------------------ InterleaveWholeUpper
+
+template <class D, HWY_IF_V_SIZE_D(D, 32)>
+HWY_API VFromD<D> InterleaveWholeUpper(D d, VFromD<D> a, VFromD<D> b) {
+ return ConcatUpperUpper(d, InterleaveUpper(d, a, b), InterleaveLower(a, b));
+}
+
+// ------------------------------ DupEven (InterleaveLower)
+
+template <typename T, HWY_IF_UI8(T)>
+HWY_API Vec256<T> DupEven(Vec256<T> v) {
+ return Vec256<T>{__lasx_xvpackev_b(v.raw, v.raw)};
+}
+
+template <typename T, HWY_IF_UI16(T)>
+HWY_API Vec256<T> DupEven(Vec256<T> v) {
+ return Vec256<T>{__lasx_xvpackev_h(v.raw, v.raw)};
+}
+
+template <typename T, HWY_IF_UI32(T)>
+HWY_API Vec256<T> DupEven(Vec256<T> v) {
+ return Vec256<T>{__lasx_xvpackev_w(v.raw, v.raw)};
+}
+
+HWY_API Vec256<float> DupEven(Vec256<float> v) {
+ const Full256<uint32_t> du;
+ const DFromV<decltype(v)> d;
+ return BitCast(d, Vec256<uint32_t>{__lasx_xvpackev_w(BitCast(du, v).raw,
+ BitCast(du, v).raw)});
+}
+
+template <typename T, HWY_IF_T_SIZE(T, 8)>
+HWY_API Vec256<T> DupEven(const Vec256<T> v) {
+ const DFromV<decltype(v)> d;
+ return InterleaveLower(d, v, v);
+}
+
+// ------------------------------ DupOdd (InterleaveUpper)
+
+template <typename T, HWY_IF_UI8(T)>
+HWY_API Vec256<T> DupOdd(Vec256<T> v) {
+ return Vec256<T>{__lasx_xvpackod_b(v.raw, v.raw)};
+}
+
+template <typename T, HWY_IF_UI16(T)>
+HWY_API Vec256<T> DupOdd(Vec256<T> v) {
+ return Vec256<T>{__lasx_xvpackod_h(v.raw, v.raw)};
+}
+
+template <typename T, HWY_IF_UI32(T)>
+HWY_API Vec256<T> DupOdd(Vec256<T> v) {
+ return Vec256<T>{__lasx_xvpackod_w(v.raw, v.raw)};
+}
+
+HWY_API Vec256<float> DupOdd(Vec256<float> v) {
+ const Full256<uint32_t> du;
+ const DFromV<decltype(v)> d;
+ return BitCast(d, Vec256<uint32_t>{__lasx_xvpackod_w(BitCast(du, v).raw,
+ BitCast(du, v).raw)});
+}
+
+template <typename T, HWY_IF_T_SIZE(T, 8)>
+HWY_API Vec256<T> DupOdd(const Vec256<T> v) {
+ const DFromV<decltype(v)> d;
+ return InterleaveUpper(d, v, v);
+}
+
+// ------------------------------ OddEven
+
+template <typename T, HWY_IF_T_SIZE(T, 1)>
+HWY_INLINE Vec256<T> OddEven(Vec256<T> a, Vec256<T> b) {
+ __m256i c = __lasx_xvpackod_b(a.raw, a.raw);
+ return Vec256<T>{__lasx_xvpackev_b(c, b.raw)};
+}
+
+template <typename T, HWY_IF_UI16(T)>
+HWY_INLINE Vec256<T> OddEven(Vec256<T> a, Vec256<T> b) {
+ __m256i c = __lasx_xvpackod_h(a.raw, a.raw);
+ return Vec256<T>{__lasx_xvpackev_h(c, b.raw)};
+}
+
+template <typename T, HWY_IF_UI32(T)>
+HWY_INLINE Vec256<T> OddEven(Vec256<T> a, Vec256<T> b) {
+ __m256i c = __lasx_xvpackod_w(a.raw, a.raw);
+ return Vec256<T>{__lasx_xvpackev_w(c, b.raw)};
+}
+
+template <typename T, HWY_IF_UI64(T)>
+HWY_INLINE Vec256<T> OddEven(Vec256<T> a, Vec256<T> b) {
+ return Vec256<T>{__lasx_xvextrins_d(b.raw, a.raw, 0x11)};
+}
+
+HWY_API Vec256<float> OddEven(Vec256<float> a, Vec256<float> b) {
+ const DFromV<decltype(a)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ __m256i c = __lasx_xvpackod_w(BitCast(du, a).raw, BitCast(du, a).raw);
+ return BitCast(
+ d, VFromD<decltype(du)>{__lasx_xvpackev_w(c, BitCast(du, b).raw)});
+}
+
+HWY_API Vec256<double> OddEven(Vec256<double> a, Vec256<double> b) {
+ const DFromV<decltype(a)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(d, VFromD<decltype(du)>{__lasx_xvextrins_d(
+ BitCast(du, b).raw, BitCast(du, a).raw, 0x11)});
+}
+
+// -------------------------- InterleaveEven
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 1)>
+HWY_API VFromD<D> InterleaveEven(D /*d*/, VFromD<D> a, VFromD<D> b) {
+ return VFromD<D>{__lasx_xvpackev_b(b.raw, a.raw)};
+}
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 2)>
+HWY_API VFromD<D> InterleaveEven(D /*d*/, VFromD<D> a, VFromD<D> b) {
+ return VFromD<D>{__lasx_xvpackev_h(b.raw, a.raw)};
+}
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 4)>
+HWY_API VFromD<D> InterleaveEven(D d, VFromD<D> a, VFromD<D> b) {
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(d, VFromD<decltype(du)>{__lasx_xvpackev_w(
+ BitCast(du, b).raw, BitCast(du, a).raw)});
+}
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 8)>
+HWY_API VFromD<D> InterleaveEven(D /*d*/, VFromD<D> a, VFromD<D> b) {
+ return InterleaveLower(a, b);
+}
+
+// -------------------------- InterleaveOdd
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 1)>
+HWY_API VFromD<D> InterleaveOdd(D /*d*/, VFromD<D> a, VFromD<D> b) {
+ return VFromD<D>{__lasx_xvpackod_b(b.raw, a.raw)};
+}
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 2)>
+HWY_API VFromD<D> InterleaveOdd(D /*d*/, VFromD<D> a, VFromD<D> b) {
+ return VFromD<D>{__lasx_xvpackod_h(b.raw, a.raw)};
+}
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 4)>
+HWY_API VFromD<D> InterleaveOdd(D d, VFromD<D> a, VFromD<D> b) {
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(d, VFromD<decltype(du)>{__lasx_xvpackod_w(
+ BitCast(du, b).raw, BitCast(du, a).raw)});
+}
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 8)>
+HWY_API VFromD<D> InterleaveOdd(D d, VFromD<D> a, VFromD<D> b) {
+ return InterleaveUpper(d, a, b);
+}
+
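+// Illustrative example (same a, b as for OddEven above):
+//   InterleaveEven(d, a, b) == [0 100 2 102 4 104 6 106]
+//   InterleaveOdd(d, a, b)  == [1 101 3 103 5 105 7 107]
+// i.e. lane 2i takes a's even/odd lane and lane 2i+1 takes b's.
+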
+// ------------------------------ OddEvenBlocks
+
+template <typename T>
+Vec256<T> OddEvenBlocks(Vec256<T> odd, Vec256<T> even) {
+ const DFromV<decltype(odd)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(d, VFromD<decltype(du)>{__lasx_xvpermi_q(
+ BitCast(du, odd).raw, BitCast(du, even).raw, 0x30)});
+}
+
+// ------------------------------ ReverseBlocks (SwapAdjacentBlocks)
+
+template <class D, HWY_IF_V_SIZE_D(D, 32)>
+HWY_API VFromD<D> ReverseBlocks(D /*d*/, VFromD<D> v) {
+ return SwapAdjacentBlocks(v);
+}
+
+// ------------------------------ TableLookupBytes (ZeroExtendVector)
+
+// Both full
+template <typename T, typename TI>
+HWY_API Vec256<TI> TableLookupBytes(Vec256<T> bytes, Vec256<TI> from) {
+ const DFromV<decltype(from)> d;
+ return BitCast(d, Vec256<uint8_t>{__lasx_xvshuf_b(
+ BitCast(Full256<uint8_t>(), bytes).raw,
+ BitCast(Full256<uint8_t>(), bytes).raw,
+ BitCast(Full256<uint8_t>(), from).raw)});
+}
+
+// Partial index vector
+template <typename T, typename TI, size_t NI>
+HWY_API Vec128<TI, NI> TableLookupBytes(Vec256<T> bytes, Vec128<TI, NI> from) {
+ const Full256<TI> di;
+ const Half<decltype(di)> dih;
+ // First expand to full 128, then 256.
+ const auto from_256 = ZeroExtendVector(di, Vec128<TI>{from.raw});
+ const auto tbl_full = TableLookupBytes(bytes, from_256);
+ // Shrink to 128, then partial.
+ return Vec128<TI, NI>{LowerHalf(dih, tbl_full).raw};
+}
+
+// Partial table vector
+template <typename T, size_t N, typename TI>
+HWY_API Vec256<TI> TableLookupBytes(Vec128<T, N> bytes, Vec256<TI> from) {
+ const Full256<T> d;
+ // First expand to full 128, then 256.
+ const auto bytes_256 = ZeroExtendVector(d, Vec128<T>{bytes.raw});
+ return TableLookupBytes(bytes_256, from);
+}
+
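+// Illustrative sketch of the per-block indexing: each index in `from`
+// selects a byte within the matching 128-bit block of `bytes`. Assuming u8
+// vectors whose first bytes are bytes = [0x00 0x10 0x20 ...] and
+// from = [2 2 1 0 ...], the result begins [0x20 0x20 0x10 0x00 ...]; the
+// upper block can only select from the upper 16 bytes of `bytes`.
+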
+// ------------------------------ Per4LaneBlockShuffle
+
+namespace detail {
+
+template <class D, HWY_IF_V_SIZE_D(D, 32)>
+HWY_INLINE VFromD<D> Per4LaneBlkShufDupSet4xU32(D d, const uint32_t x3,
+ const uint32_t x2,
+ const uint32_t x1,
+ const uint32_t x0) {
+ alignas(32) uint32_t rawU32[8] = {x0, x1, x2, x3, x0, x1, x2, x3};
+ return BitCast(d, Vec256<uint32_t>{__lasx_xvld(rawU32, 0)});
+}
+
+template <size_t kIdx3210, class V, HWY_IF_NOT_FLOAT(TFromV<V>)>
+HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<kIdx3210> /*idx_3210_tag*/,
+ hwy::SizeTag<4> /*lane_size_tag*/,
+ hwy::SizeTag<32> /*vect_size_tag*/, V v) {
+ const DFromV<decltype(v)> d;
+ V idx =
+ Per4LaneBlkShufDupSet4xU32(d, (kIdx3210 >> 6) & 3, (kIdx3210 >> 4) & 3,
+ (kIdx3210 >> 2) & 3, kIdx3210 & 3);
+ return V{__lasx_xvshuf_w(idx.raw, v.raw, v.raw)};
+}
+
+template <size_t kIdx3210, class V, HWY_IF_FLOAT(TFromV<V>)>
+HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<kIdx3210> /*idx_3210_tag*/,
+ hwy::SizeTag<4> /*lane_size_tag*/,
+ hwy::SizeTag<32> /*vect_size_tag*/, V v) {
+ const DFromV<decltype(v)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ const auto idx =
+ Per4LaneBlkShufDupSet4xU32(du, (kIdx3210 >> 6) & 3, (kIdx3210 >> 4) & 3,
+ (kIdx3210 >> 2) & 3, kIdx3210 & 3);
+ return BitCast(d, VFromD<decltype(du)>{__lasx_xvshuf_w(
+ idx.raw, BitCast(du, v).raw, BitCast(du, v).raw)});
+}
+
+template <class V>
+HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0x44> /*idx_3210_tag*/,
+ hwy::SizeTag<8> /*lane_size_tag*/,
+ hwy::SizeTag<32> /*vect_size_tag*/, V v) {
+ const DFromV<decltype(v)> d;
+ return ConcatLowerLower(d, v, v);
+}
+
+template <class V>
+HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<0xEE> /*idx_3210_tag*/,
+ hwy::SizeTag<8> /*lane_size_tag*/,
+ hwy::SizeTag<32> /*vect_size_tag*/, V v) {
+ const DFromV<decltype(v)> d;
+ return ConcatUpperUpper(d, v, v);
+}
+
+template <size_t kIdx3210, class V, HWY_IF_NOT_FLOAT(TFromV<V>)>
+HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<kIdx3210> /*idx_3210_tag*/,
+ hwy::SizeTag<8> /*lane_size_tag*/,
+ hwy::SizeTag<32> /*vect_size_tag*/, V v) {
+ const DFromV<decltype(v)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ using VU = VFromD<decltype(du)>;
+
+ const VU vu = BitCast(du, v);
+ return BitCast(
+ d, VU{__lasx_xvpermi_d(vu.raw, static_cast<int>(kIdx3210 & 0xFF))});
+}
+
+} // namespace detail
+
+// ------------------------------ SlideUpLanes
+
+namespace detail {
+
+template <class D, HWY_IF_V_SIZE_D(D, 32)>
+HWY_INLINE VFromD<D> TableLookupSlideUpLanes(D d, VFromD<D> v, size_t amt) {
+ const RebindToUnsigned<D> du;
+ using TU = TFromD<decltype(du)>;
+ const auto idx = Iota(du, static_cast<TU>(size_t{0} - amt));
+ const auto masked_idx = And(idx, Set(du, static_cast<TU>(MaxLanes(d) - 1)));
+ return BitCast(
+ d, IfThenElseZero(
+ idx == masked_idx,
+ TableLookupLanes(BitCast(du, v), IndicesFromVec(du, masked_idx))));
+}
+
+} // namespace detail
+
+template <int kBlocks, class D, HWY_IF_V_SIZE_D(D, 32)>
+HWY_API VFromD<D> SlideUpBlocks(D d, VFromD<D> v) {
+ static_assert(0 <= kBlocks && kBlocks <= 1,
+ "kBlocks must be between 0 and 1");
+ return (kBlocks == 1) ? ConcatLowerLower(d, v, Zero(d)) : v;
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32)>
+HWY_API VFromD<D> SlideUpLanes(D d, VFromD<D> v, size_t amt) {
+#if !HWY_IS_DEBUG_BUILD
+ constexpr size_t kLanesPerBlock = 16 / sizeof(TFromD<D>);
+ if (__builtin_constant_p(amt)) {
+ const auto v_lo = ConcatLowerLower(d, v, Zero(d));
+ switch (amt * sizeof(TFromD<D>)) {
+ case 0:
+ return v;
+ case 1:
+ return CombineShiftRightBytes<15>(d, v, v_lo);
+ case 2:
+ return CombineShiftRightBytes<14>(d, v, v_lo);
+ case 3:
+ return CombineShiftRightBytes<13>(d, v, v_lo);
+ case 4:
+ return CombineShiftRightBytes<12>(d, v, v_lo);
+ case 5:
+ return CombineShiftRightBytes<11>(d, v, v_lo);
+ case 6:
+ return CombineShiftRightBytes<10>(d, v, v_lo);
+ case 7:
+ return CombineShiftRightBytes<9>(d, v, v_lo);
+ case 8:
+ return CombineShiftRightBytes<8>(d, v, v_lo);
+ case 9:
+ return CombineShiftRightBytes<7>(d, v, v_lo);
+ case 10:
+ return CombineShiftRightBytes<6>(d, v, v_lo);
+ case 11:
+ return CombineShiftRightBytes<5>(d, v, v_lo);
+ case 12:
+ return CombineShiftRightBytes<4>(d, v, v_lo);
+ case 13:
+ return CombineShiftRightBytes<3>(d, v, v_lo);
+ case 14:
+ return CombineShiftRightBytes<2>(d, v, v_lo);
+ case 15:
+ return CombineShiftRightBytes<1>(d, v, v_lo);
+ }
+ }
+
+ if (__builtin_constant_p(amt >= kLanesPerBlock) && amt >= kLanesPerBlock) {
+ const Half<decltype(d)> dh;
+ return Combine(d, SlideUpLanes(dh, LowerHalf(dh, v), amt - kLanesPerBlock),
+ Zero(dh));
+ }
+#endif
+
+ return detail::TableLookupSlideUpLanes(d, v, amt);
+}
+
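+// Illustrative example (assuming d = Full256<int32_t>() and v = Iota(d, 0)):
+// SlideUpLanes(d, v, 3) == [0 0 0 0 1 2 3 4]; lanes move towards higher
+// indices and the vacated low lanes are zero. Slide1Up below is the
+// amt == 1 special case.
+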
+// ------------------------------ Slide1Up
+
+template <typename D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 1)>
+HWY_API VFromD<D> Slide1Up(D d, VFromD<D> v) {
+ const auto v_lo = ConcatLowerLower(d, v, Zero(d));
+ return CombineShiftRightBytes<15>(d, v, v_lo);
+}
+
+template <typename D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 2)>
+HWY_API VFromD<D> Slide1Up(D d, VFromD<D> v) {
+ const auto v_lo = ConcatLowerLower(d, v, Zero(d));
+ return CombineShiftRightBytes<14>(d, v, v_lo);
+}
+
+template <typename D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 4)>
+HWY_API VFromD<D> Slide1Up(D d, VFromD<D> v) {
+ const auto v_lo = ConcatLowerLower(d, v, Zero(d));
+ return CombineShiftRightBytes<12>(d, v, v_lo);
+}
+
+template <typename D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 8)>
+HWY_API VFromD<D> Slide1Up(D d, VFromD<D> v) {
+ const auto v_lo = ConcatLowerLower(d, v, Zero(d));
+ return CombineShiftRightBytes<8>(d, v, v_lo);
+}
+
+// ------------------------------ SlideDownLanes
+
+namespace detail {
+
+template <class D, HWY_IF_V_SIZE_D(D, 32)>
+HWY_INLINE VFromD<D> TableLookupSlideDownLanes(D d, VFromD<D> v, size_t amt) {
+ const RebindToUnsigned<decltype(d)> du;
+ using TU = TFromD<decltype(du)>;
+ const auto idx = Iota(du, static_cast<TU>(amt));
+ const auto masked_idx = And(idx, Set(du, static_cast<TU>(MaxLanes(d) - 1)));
+ return IfThenElseZero(RebindMask(d, idx == masked_idx),
+ TableLookupLanes(v, IndicesFromVec(d, masked_idx)));
+}
+
+} // namespace detail
+
+template <int kBlocks, class D, HWY_IF_V_SIZE_D(D, 32)>
+HWY_API VFromD<D> SlideDownBlocks(D d, VFromD<D> v) {
+ static_assert(0 <= kBlocks && kBlocks <= 1,
+ "kBlocks must be between 0 and 1");
+ const Half<decltype(d)> dh;
+ return (kBlocks == 1) ? ZeroExtendVector(d, UpperHalf(dh, v)) : v;
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32)>
+HWY_API VFromD<D> SlideDownLanes(D d, VFromD<D> v, size_t amt) {
+#if !HWY_IS_DEBUG_BUILD
+ constexpr size_t kLanesPerBlock = 16 / sizeof(TFromD<D>);
+ const Half<decltype(d)> dh;
+ if (__builtin_constant_p(amt)) {
+ const auto v_hi = ZeroExtendVector(d, UpperHalf(dh, v));
+ switch (amt * sizeof(TFromD<D>)) {
+ case 0:
+ return v;
+ case 1:
+ return CombineShiftRightBytes<1>(d, v_hi, v);
+ case 2:
+ return CombineShiftRightBytes<2>(d, v_hi, v);
+ case 3:
+ return CombineShiftRightBytes<3>(d, v_hi, v);
+ case 4:
+ return CombineShiftRightBytes<4>(d, v_hi, v);
+ case 5:
+ return CombineShiftRightBytes<5>(d, v_hi, v);
+ case 6:
+ return CombineShiftRightBytes<6>(d, v_hi, v);
+ case 7:
+ return CombineShiftRightBytes<7>(d, v_hi, v);
+ case 8:
+ return CombineShiftRightBytes<8>(d, v_hi, v);
+ case 9:
+ return CombineShiftRightBytes<9>(d, v_hi, v);
+ case 10:
+ return CombineShiftRightBytes<10>(d, v_hi, v);
+ case 11:
+ return CombineShiftRightBytes<11>(d, v_hi, v);
+ case 12:
+ return CombineShiftRightBytes<12>(d, v_hi, v);
+ case 13:
+ return CombineShiftRightBytes<13>(d, v_hi, v);
+ case 14:
+ return CombineShiftRightBytes<14>(d, v_hi, v);
+ case 15:
+ return CombineShiftRightBytes<15>(d, v_hi, v);
+ }
+ }
+
+ if (__builtin_constant_p(amt >= kLanesPerBlock) && amt >= kLanesPerBlock) {
+ return ZeroExtendVector(
+ d, SlideDownLanes(dh, UpperHalf(dh, v), amt - kLanesPerBlock));
+ }
+#endif
+
+ return detail::TableLookupSlideDownLanes(d, v, amt);
+}
+
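+// Illustrative example (assuming d = Full256<int32_t>() and v = Iota(d, 0)):
+// SlideDownLanes(d, v, 3) == [3 4 5 6 7 0 0 0]; lanes move towards lane 0
+// and the vacated high lanes are zero. Slide1Down below is the amt == 1
+// special case.
+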
+// ------------------------------ Slide1Down
+
+template <typename D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 1)>
+HWY_API VFromD<D> Slide1Down(D d, VFromD<D> v) {
+ const Half<decltype(d)> dh;
+ const auto v_hi = ZeroExtendVector(d, UpperHalf(dh, v));
+ return CombineShiftRightBytes<1>(d, v_hi, v);
+}
+
+template <typename D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 2)>
+HWY_API VFromD<D> Slide1Down(D d, VFromD<D> v) {
+ const Half<decltype(d)> dh;
+ const auto v_hi = ZeroExtendVector(d, UpperHalf(dh, v));
+ return CombineShiftRightBytes<2>(d, v_hi, v);
+}
+
+template <typename D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 4)>
+HWY_API VFromD<D> Slide1Down(D d, VFromD<D> v) {
+ const Half<decltype(d)> dh;
+ const auto v_hi = ZeroExtendVector(d, UpperHalf(dh, v));
+ return CombineShiftRightBytes<4>(d, v_hi, v);
+}
+
+template <typename D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 8)>
+HWY_API VFromD<D> Slide1Down(D d, VFromD<D> v) {
+ const Half<decltype(d)> dh;
+ const auto v_hi = ZeroExtendVector(d, UpperHalf(dh, v));
+ return CombineShiftRightBytes<8>(d, v_hi, v);
+}
+
+// ------------------------------ Shl (Mul, ZipLower)
+namespace detail {
+
+HWY_INLINE Vec256<uint8_t> Shl(hwy::UnsignedTag /*tag*/, Vec256<uint8_t> v,
+ Vec256<uint8_t> bits) {
+ return Vec256<uint8_t>{__lasx_xvsll_b(v.raw, bits.raw)};
+}
+
+HWY_INLINE Vec256<uint16_t> Shl(hwy::UnsignedTag /*tag*/, Vec256<uint16_t> v,
+ Vec256<uint16_t> bits) {
+ return Vec256<uint16_t>{__lasx_xvsll_h(v.raw, bits.raw)};
+}
+
+HWY_INLINE Vec256<uint32_t> Shl(hwy::UnsignedTag /*tag*/, Vec256<uint32_t> v,
+ Vec256<uint32_t> bits) {
+ return Vec256<uint32_t>{__lasx_xvsll_w(v.raw, bits.raw)};
+}
+
+HWY_INLINE Vec256<uint64_t> Shl(hwy::UnsignedTag /*tag*/, Vec256<uint64_t> v,
+ Vec256<uint64_t> bits) {
+ return Vec256<uint64_t>{__lasx_xvsll_d(v.raw, bits.raw)};
+}
+
+template <typename T>
+HWY_INLINE Vec256<T> Shl(hwy::SignedTag /*tag*/, Vec256<T> v, Vec256<T> bits) {
+ // Signed left shifts are the same as unsigned.
+ const Full256<T> di;
+ const Full256<MakeUnsigned<T>> du;
+ return BitCast(di,
+ Shl(hwy::UnsignedTag(), BitCast(du, v), BitCast(du, bits)));
+}
+
+} // namespace detail
+
+template <typename T>
+HWY_API Vec256<T> operator<<(Vec256<T> v, Vec256<T> bits) {
+ return detail::Shl(hwy::TypeTag<T>(), v, bits);
+}
+
+// ------------------------------ Shr (MulHigh, IfThenElse, Not)
+
+HWY_API Vec256<uint8_t> operator>>(Vec256<uint8_t> v, Vec256<uint8_t> bits) {
+ return Vec256<uint8_t>{__lasx_xvsrl_b(v.raw, bits.raw)};
+}
+
+HWY_API Vec256<uint16_t> operator>>(Vec256<uint16_t> v, Vec256<uint16_t> bits) {
+ return Vec256<uint16_t>{__lasx_xvsrl_h(v.raw, bits.raw)};
+}
+
+HWY_API Vec256<uint32_t> operator>>(Vec256<uint32_t> v, Vec256<uint32_t> bits) {
+ return Vec256<uint32_t>{__lasx_xvsrl_w(v.raw, bits.raw)};
+}
+
+HWY_API Vec256<uint64_t> operator>>(Vec256<uint64_t> v, Vec256<uint64_t> bits) {
+ return Vec256<uint64_t>{__lasx_xvsrl_d(v.raw, bits.raw)};
+}
+
+HWY_API Vec256<int8_t> operator>>(Vec256<int8_t> v, Vec256<int8_t> bits) {
+ return Vec256<int8_t>{__lasx_xvsra_b(v.raw, bits.raw)};
+}
+
+HWY_API Vec256<int16_t> operator>>(Vec256<int16_t> v, Vec256<int16_t> bits) {
+ return Vec256<int16_t>{__lasx_xvsra_h(v.raw, bits.raw)};
+}
+
+HWY_API Vec256<int32_t> operator>>(Vec256<int32_t> v, Vec256<int32_t> bits) {
+ return Vec256<int32_t>{__lasx_xvsra_w(v.raw, bits.raw)};
+}
+
+HWY_API Vec256<int64_t> operator>>(Vec256<int64_t> v, Vec256<int64_t> bits) {
+ return Vec256<int64_t>{__lasx_xvsra_d(v.raw, bits.raw)};
+}
+
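+// Illustrative example of the per-lane shift counts (assuming
+// d = Full256<uint32_t>(), v = Set(d, 1) and bits = Iota(d, 0)):
+//   v << bits == [1 2 4 8 16 32 64 128]
+// For signed T, operator>> maps to xvsra_* and thus shifts in copies of
+// the sign bit.
+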
+// ------------------------------ WidenMulPairwiseAdd
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_I32_D(D)>
+HWY_API VFromD<D> WidenMulPairwiseAdd(D /*d32*/, Vec256<int16_t> a,
+ Vec256<int16_t> b) {
+ __m256i ev = __lasx_xvmulwev_w_h(b.raw, a.raw);
+ return VFromD<D>{__lasx_xvmaddwod_w_h(ev, b.raw, a.raw)};
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_U32_D(D)>
+HWY_API VFromD<D> WidenMulPairwiseAdd(D /*d32*/, Vec256<uint16_t> a,
+ Vec256<uint16_t> b) {
+ __m256i ev = __lasx_xvmulwev_w_hu(b.raw, a.raw);
+ return VFromD<D>{__lasx_xvmaddwod_w_hu(ev, b.raw, a.raw)};
+}
+
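+// Worked example: for i16 inputs whose first lanes are a = [1 2 ...] and
+// b = [10 20 ...], lane 0 of the i32 result is a[0]*b[0] + a[1]*b[1] =
+// 1*10 + 2*20 = 50; xvmulwev forms the even-lane products and xvmaddwod
+// accumulates the odd-lane products.
+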
+// ------------------------------ ReorderWidenMulAccumulate
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_I32_D(D)>
+HWY_API VFromD<D> ReorderWidenMulAccumulate(D /*tag*/, Vec256<int16_t> a,
+ Vec256<int16_t> b,
+ const VFromD<D> sum0,
+ VFromD<D>& /*sum1*/) {
+ return VFromD<D>{__lasx_xvmaddwev_w_h(
+ __lasx_xvmaddwod_w_h(sum0.raw, a.raw, b.raw), a.raw, b.raw)};
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_U32_D(D)>
+HWY_API VFromD<D> ReorderWidenMulAccumulate(D /*tag*/, Vec256<uint16_t> a,
+ Vec256<uint16_t> b,
+ const VFromD<D> sum0,
+ VFromD<D>& /*sum1*/) {
+ return VFromD<D>{__lasx_xvmaddwev_w_hu(
+ __lasx_xvmaddwod_w_hu(sum0.raw, a.raw, b.raw), a.raw, b.raw)};
+}
+
+// ------------------------------ RearrangeToOddPlusEven
+HWY_API Vec256<int32_t> RearrangeToOddPlusEven(const Vec256<int32_t> sum0,
+ Vec256<int32_t> /*sum1*/) {
+ return sum0; // invariant already holds
+}
+
+HWY_API Vec256<uint32_t> RearrangeToOddPlusEven(const Vec256<uint32_t> sum0,
+ Vec256<uint32_t> /*sum1*/) {
+ return sum0; // invariant already holds
+}
+
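+// Note on the pair above: the xvmaddwod/xvmaddwev chain already adds both
+// the odd and even widened products into sum0, so each i32 lane of sum0
+// holds sum0[i] + a[2i]*b[2i] + a[2i+1]*b[2i+1] and RearrangeToOddPlusEven
+// can return sum0 unchanged; sum1 is never used on this target.
+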
+// ================================================== CONVERT
+
+// ------------------------------ Promotions (part w/ narrow lanes -> full)
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)>
+HWY_API VFromD<D> PromoteTo(D /* tag */, Vec128<hwy::float16_t> v) {
+ const Repartition<hwy::float16_t, D> df16;
+ const auto from_128 = ZeroExtendVector(df16, v);
+ const VFromD<decltype(df16)> f16_concat{__lasx_xvpermi_d(from_128.raw, 0xd8)};
+ return VFromD<D>{__lasx_xvfcvtl_s_h(f16_concat.raw)};
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)>
+HWY_API VFromD<D> PromoteTo(D /* tag */, Vec128<float> v) {
+ const Repartition<float, D> df;
+ const RebindToSigned<decltype(df)> di;
+ const auto from_128 = ZeroExtendVector(df, v);
+ const auto f32_concat = BitCast(
+ df, Vec256<uint32_t>{__lasx_xvpermi_d(BitCast(di, from_128).raw, 0xd8)});
+ return VFromD<D>{__lasx_xvfcvtl_d_s(f32_concat.raw)};
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_I64_D(D)>
+HWY_API VFromD<D> PromoteTo(D /*di64*/, Vec128<float> v) {
+ const Repartition<float, D> df;
+ const RebindToSigned<decltype(df)> di;
+ const auto from_128 = ZeroExtendVector(df, v);
+ const auto f32_concat = BitCast(
+ df, Vec256<uint32_t>{__lasx_xvpermi_d(BitCast(di, from_128).raw, 0xd8)});
+ return VFromD<D>{__lasx_xvftintrzl_l_s(f32_concat.raw)};
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)>
+HWY_API VFromD<D> PromoteTo(D /* tag */, Vec128<int32_t> v) {
+ alignas(32) __m128i vec_tmp[2];
+ __m256i vec_temp;
+ vec_tmp[0] = v.raw;
+ CopyBytes<32>(vec_tmp, &vec_temp);
+ vec_temp = __lasx_xvpermi_d(vec_temp, 0xd8);
+ vec_temp = __lasx_xvsllwil_d_w(vec_temp, 0);
+ return VFromD<D>{__lasx_xvffint_d_l(vec_temp)};
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)>
+HWY_API Vec256<double> PromoteTo(D /* tag */, Vec128<uint32_t> v) {
+ alignas(32) __m128i vec_tmp[2];
+ __m256i vec_temp;
+ vec_tmp[0] = v.raw;
+ CopyBytes<32>(vec_tmp, &vec_temp);
+ vec_temp = __lasx_xvpermi_d(vec_temp, 0xd8);
+ vec_temp = __lasx_xvsllwil_du_wu(vec_temp, 0);
+ return VFromD<D>{__lasx_xvffint_d_lu(vec_temp)};
+}
+
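+// The widening conversions in this section share one pattern: copy the
+// 128-bit source into the low half of a 256-bit register (the upper half
+// is a don't-care), widen step by step with xvsllwil (a shift-left-and-
+// widen whose shift amount of 0 makes it a pure sign/zero extension of the
+// low half of each 128-bit lane), and apply xvpermi_d(..., 0xd8) just
+// before the final step so each lane holds the 64 bits it is about to
+// widen.
+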
+// Unsigned: zero-extend.
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_U16_D(D)>
+HWY_API VFromD<D> PromoteTo(D /* tag */, Vec128<uint8_t> v) {
+ alignas(32) __m128i vec_tmp[2];
+ __m256i vec_temp;
+ vec_tmp[0] = v.raw;
+ CopyBytes<32>(vec_tmp, &vec_temp);
+ vec_temp = __lasx_xvpermi_d(vec_temp, 0xd8);
+ return VFromD<D>{__lasx_xvsllwil_hu_bu(vec_temp, 0)};
+}
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_U32_D(D)>
+HWY_API VFromD<D> PromoteTo(D /* tag */, Vec128<uint8_t, 8> v) {
+ alignas(32) __m128i vec_tmp[2];
+ __m256i vec_temp;
+ vec_tmp[0] = v.raw;
+ CopyBytes<32>(vec_tmp, &vec_temp);
+ vec_temp = __lasx_xvsllwil_hu_bu(vec_temp, 0);
+ vec_temp = __lasx_xvpermi_d(vec_temp, 0xd8);
+ return VFromD<D>{__lasx_xvsllwil_wu_hu(vec_temp, 0)};
+}
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_U32_D(D)>
+HWY_API VFromD<D> PromoteTo(D /* tag */, Vec128<uint16_t> v) {
+ alignas(32) __m128i vec_tmp[2];
+ __m256i vec_temp;
+ vec_tmp[0] = v.raw;
+ CopyBytes<32>(vec_tmp, &vec_temp);
+ vec_temp = __lasx_xvpermi_d(vec_temp, 0xd8);
+ return VFromD<D>{__lasx_xvsllwil_wu_hu(vec_temp, 0)};
+}
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_U64_D(D)>
+HWY_API VFromD<D> PromoteTo(D /* tag */, Vec128<uint32_t> v) {
+ alignas(32) __m128i vec_tmp[2];
+ __m256i vec_temp;
+ vec_tmp[0] = v.raw;
+ CopyBytes<32>(vec_tmp, &vec_temp);
+ vec_temp = __lasx_xvpermi_d(vec_temp, 0xd8);
+ return VFromD<D>{__lasx_xvsllwil_du_wu(vec_temp, 0)};
+}
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_U64_D(D)>
+HWY_API VFromD<D> PromoteTo(D /* tag */, Vec64<uint16_t> v) {
+ alignas(32) __m128i vec_tmp[2];
+ __m256i vec_temp;
+ vec_tmp[0] = v.raw;
+ CopyBytes<32>(vec_tmp, &vec_temp);
+ vec_temp = __lasx_xvsllwil_wu_hu(vec_temp, 0);
+ vec_temp = __lasx_xvpermi_d(vec_temp, 0xd8);
+ return VFromD<D>{__lasx_xvsllwil_du_wu(vec_temp, 0)};
+}
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_U64_D(D)>
+HWY_API VFromD<D> PromoteTo(D /* tag */, Vec32<uint8_t> v) {
+ alignas(32) __m128i vec_tmp[2];
+ __m256i vec_temp;
+ vec_tmp[0] = v.raw;
+ CopyBytes<32>(vec_tmp, &vec_temp);
+ vec_temp = __lasx_xvsllwil_hu_bu(vec_temp, 0);
+ vec_temp = __lasx_xvsllwil_wu_hu(vec_temp, 0);
+ vec_temp = __lasx_xvpermi_d(vec_temp, 0xd8);
+ return VFromD<D>{__lasx_xvsllwil_du_wu(vec_temp, 0)};
+}
+
+// Signed: replicate sign bit.
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_I16_D(D)>
+HWY_API VFromD<D> PromoteTo(D /* tag */, Vec128<int8_t> v) {
+ alignas(32) __m128i vec_tmp[2];
+ __m256i vec_temp;
+ vec_tmp[0] = v.raw;
+ CopyBytes<32>(vec_tmp, &vec_temp);
+ vec_temp = __lasx_xvpermi_d(vec_temp, 0xd8);
+ return VFromD<D>{__lasx_xvsllwil_h_b(vec_temp, 0)};
+}
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_I32_D(D)>
+HWY_API VFromD<D> PromoteTo(D /* tag */, Vec128<int8_t, 8> v) {
+ alignas(32) __m128i vec_tmp[2];
+ __m256i vec_temp;
+ vec_tmp[0] = v.raw;
+ CopyBytes<32>(vec_tmp, &vec_temp);
+ vec_temp = __lasx_xvsllwil_h_b(vec_temp, 0);
+ vec_temp = __lasx_xvpermi_d(vec_temp, 0xd8);
+ return VFromD<D>{__lasx_xvsllwil_w_h(vec_temp, 0)};
+}
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_I32_D(D)>
+HWY_API VFromD<D> PromoteTo(D /* tag */, Vec128<int16_t> v) {
+ alignas(32) __m128i vec_tmp[2];
+ __m256i vec_temp;
+ vec_tmp[0] = v.raw;
+ CopyBytes<32>(vec_tmp, &vec_temp);
+ vec_temp = __lasx_xvpermi_d(vec_temp, 0xd8);
+ return VFromD<D>{__lasx_xvsllwil_w_h(vec_temp, 0)};
+}
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_I64_D(D)>
+HWY_API VFromD<D> PromoteTo(D /* tag */, Vec128<int32_t> v) {
+ alignas(32) __m128i vec_tmp[2];
+ __m256i vec_temp;
+ vec_tmp[0] = v.raw;
+ CopyBytes<32>(vec_tmp, &vec_temp);
+ vec_temp = __lasx_xvpermi_d(vec_temp, 0xd8);
+ return VFromD<D>{__lasx_xvsllwil_d_w(vec_temp, 0)};
+}
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_I64_D(D)>
+HWY_API VFromD<D> PromoteTo(D /* tag */, Vec64<int16_t> v) {
+ alignas(32) __m128i vec_tmp[2];
+ __m256i vec_temp;
+ vec_tmp[0] = v.raw;
+ CopyBytes<32>(vec_tmp, &vec_temp);
+ vec_temp = __lasx_xvsllwil_w_h(vec_temp, 0);
+ vec_temp = __lasx_xvpermi_d(vec_temp, 0xd8);
+ return VFromD<D>{__lasx_xvsllwil_d_w(vec_temp, 0)};
+}
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_I64_D(D)>
+HWY_API VFromD<D> PromoteTo(D /* tag */, Vec32<int8_t> v) {
+ alignas(32) __m128i vec_tmp[2];
+ __m256i vec_temp;
+ vec_tmp[0] = v.raw;
+ CopyBytes<32>(vec_tmp, &vec_temp);
+ vec_temp = __lasx_xvsllwil_h_b(vec_temp, 0);
+ vec_temp = __lasx_xvsllwil_w_h(vec_temp, 0);
+ vec_temp = __lasx_xvpermi_d(vec_temp, 0xd8);
+ return VFromD<D>{__lasx_xvsllwil_d_w(vec_temp, 0)};
+}
+
+// ------------------------------ PromoteEvenTo/PromoteOddTo
+namespace detail {
+
+// I32->I64 PromoteEvenTo/PromoteOddTo
+
+template <class D, HWY_IF_LANES_D(D, 4)>
+HWY_INLINE VFromD<D> PromoteEvenTo(hwy::SignedTag /*to_type_tag*/,
+ hwy::SizeTag<8> /*to_lane_size_tag*/,
+ hwy::SignedTag /*from_type_tag*/, D d_to,
+ Vec256<int32_t> v) {
+ return BitCast(d_to, OddEven(DupEven(BroadcastSignBit(v)), v));
+}
+
+template <class D, HWY_IF_LANES_D(D, 4)>
+HWY_INLINE VFromD<D> PromoteOddTo(hwy::SignedTag /*to_type_tag*/,
+ hwy::SizeTag<8> /*to_lane_size_tag*/,
+ hwy::SignedTag /*from_type_tag*/, D d_to,
+ Vec256<int32_t> v) {
+ return BitCast(d_to, OddEven(BroadcastSignBit(v), DupOdd(v)));
+}
+
+} // namespace detail
+
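+// Illustrative example (assuming i32 lanes v = [0 1 -2 3 4 5 -6 7]):
+//   PromoteEvenTo(d_to, v) == [0 -2 4 -6]  // i64, sign-extended even lanes
+//   PromoteOddTo(d_to, v)  == [1 3 5 7]
+// Splicing each source lane with a replicated sign word via OddEven avoids
+// a cross-lane widening shuffle.
+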
+// ------------------------------ Demotions (full -> part w/ narrow lanes)
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_I8_D(D)>
+HWY_API VFromD<D> ReorderDemote2To(D /* tag */, Vec256<int16_t> a,
+ Vec256<int16_t> b) {
+ return VFromD<D>{__lasx_xvssrani_b_h(b.raw, a.raw, 0)};
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_U8_D(D)>
+HWY_API VFromD<D> ReorderDemote2To(D /* tag */, Vec256<int16_t> a,
+ Vec256<int16_t> b) {
+ return VFromD<D>{__lasx_xvssrani_bu_h(b.raw, a.raw, 0)};
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_I8_D(D)>
+HWY_API VFromD<D> ReorderDemote2To(D /* tag */, Vec256<uint16_t> a,
+ Vec256<uint16_t> b) {
+ return VFromD<D>{__lasx_xvssrlni_b_h(b.raw, a.raw, 0)};
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_U8_D(D)>
+HWY_API VFromD<D> ReorderDemote2To(D /* tag */, Vec256<uint16_t> a,
+ Vec256<uint16_t> b) {
+ return VFromD<D>{__lasx_xvssrlni_bu_h(b.raw, a.raw, 0)};
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_I16_D(D)>
+HWY_API VFromD<D> ReorderDemote2To(D /* tag */, Vec256<int32_t> a,
+ Vec256<int32_t> b) {
+ return VFromD<D>{__lasx_xvssrani_h_w(b.raw, a.raw, 0)};
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_U16_D(D)>
+HWY_API VFromD<D> ReorderDemote2To(D /* tag */, Vec256<int32_t> a,
+ Vec256<int32_t> b) {
+ return VFromD<D>{__lasx_xvssrani_hu_w(b.raw, a.raw, 0)};
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_I16_D(D)>
+HWY_API VFromD<D> ReorderDemote2To(D /* tag */, Vec256<uint32_t> a,
+ Vec256<uint32_t> b) {
+ return VFromD<D>{__lasx_xvssrlni_h_w(b.raw, a.raw, 0)};
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_U16_D(D)>
+HWY_API VFromD<D> ReorderDemote2To(D /* tag */, Vec256<uint32_t> a,
+ Vec256<uint32_t> b) {
+ return VFromD<D>{__lasx_xvssrlni_hu_w(b.raw, a.raw, 0)};
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_I32_D(D)>
+HWY_API VFromD<D> ReorderDemote2To(D /* tag */, Vec256<int64_t> a,
+ Vec256<int64_t> b) {
+ return VFromD<D>{__lasx_xvssrani_w_d(b.raw, a.raw, 0)};
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_U32_D(D)>
+HWY_API VFromD<D> ReorderDemote2To(D /* tag */, Vec256<int64_t> a,
+ Vec256<int64_t> b) {
+ return VFromD<D>{__lasx_xvssrani_wu_d(b.raw, a.raw, 0)};
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_I32_D(D)>
+HWY_API VFromD<D> ReorderDemote2To(D /* tag */, Vec256<uint64_t> a,
+ Vec256<uint64_t> b) {
+ return VFromD<D>{__lasx_xvssrlni_w_d(b.raw, a.raw, 0)};
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_U32_D(D)>
+HWY_API VFromD<D> ReorderDemote2To(D /* tag */, Vec256<uint64_t> a,
+ Vec256<uint64_t> b) {
+ return VFromD<D>{__lasx_xvssrlni_wu_d(b.raw, a.raw, 0)};
+}
+
+template <class D, class V, HWY_IF_NOT_FLOAT_NOR_SPECIAL(TFromD<D>),
+ HWY_IF_V_SIZE_D(D, 32), HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V),
+ HWY_IF_T_SIZE_V(V, sizeof(TFromD<D>) * 2),
+ HWY_IF_LANES_D(D, HWY_MAX_LANES_D(DFromV<V>) * 2)>
+HWY_API VFromD<D> OrderedDemote2To(D d, V a, V b) {
+ return VFromD<D>{__lasx_xvpermi_d(ReorderDemote2To(d, a, b).raw, 0xd8)};
+}
+
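+// Note on ordering: the xvssrani/xvssrlni forms above saturate and pack
+// per 128-bit block, so ReorderDemote2To holds the demoted halves as
+// [a0 b0 a1 b1] in 64-bit chunks; the xvpermi_d(..., 0xd8) in
+// OrderedDemote2To rearranges this into the fully ordered [a0 a1 b0 b1].
+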
+template <class D, class V, HWY_IF_NOT_FLOAT_NOR_SPECIAL_D(D),
+ HWY_IF_V_SIZE_D(D, 16), HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V),
+ HWY_IF_T_SIZE_V(V, sizeof(TFromD<D>) * 2),
+ HWY_IF_LANES_D(D, HWY_MAX_LANES_D(DFromV<V>))>
+HWY_API VFromD<D> DemoteTo(D d, V v) {
+ return LowerHalf(OrderedDemote2To(Twice<decltype(d)>(), v, v));
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_F16_D(D)>
+HWY_API VFromD<D> DemoteTo(D /* tag */, Vec256<float> v) {
+ const Full256<int16_t> di;
+ const Vec256<hwy::float16_t> f16_blocks{__lasx_xvfcvt_h_s(v.raw, v.raw)};
+ const auto f16_concat =
+ BitCast(Twice<D>(), VFromD<decltype(di)>{__lasx_xvpermi_d(
+ BitCast(di, f16_blocks).raw, 0xd8)});
+ return LowerHalf(f16_concat);
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_F32_D(D)>
+HWY_API VFromD<D> DemoteTo(D /* tag */, Vec256<double> v) {
+ const Full256<int32_t> di;
+ const Vec256<float> f32_blocks{__lasx_xvfcvt_s_d(v.raw, v.raw)};
+ const auto f32_concat =
+ BitCast(Twice<D>(), VFromD<decltype(di)>{__lasx_xvpermi_d(
+ BitCast(di, f32_blocks).raw, 0xd8)});
+ return LowerHalf(f32_concat);
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_I32_D(D)>
+HWY_API VFromD<D> DemoteTo(D dn, Vec256<double> v) {
+ const __m256i i32_blocks = __lasx_xvftintrz_w_d(v.raw, v.raw);
+ return LowerHalf(dn, VFromD<Twice<D>>{__lasx_xvpermi_d(i32_blocks, 0xd8)});
+}
+
+// For already range-limited input [0, 255].
+HWY_API Vec128<uint8_t, 8> U8FromU32(const Vec256<uint32_t> v) {
+ const Full256<uint32_t> d32;
+ const Full64<uint8_t> d8;
+ alignas(32) static constexpr uint32_t k8From32[8] = {
+ 0x0C080400u, 0x13121110u, 0, 0, 0x13121110u, 0x0C080400u, 0, 0};
+ // Place the first four bytes in lo[0], the remaining four in hi[1].
+ const auto quad = VFromD<decltype(d32)>{
+ __lasx_xvshuf_b(Zero(d32).raw, v.raw, Load(d32, k8From32).raw)};
+ // Interleave both quadruplets - OR instead of unpack avoids another shuffle.
+ const auto lo = LowerHalf(quad);
+ const auto hi = UpperHalf(Half<decltype(d32)>(), quad);
+ return BitCast(d8, LowerHalf(lo | hi));
+}
+
+// ------------------------------ Truncations
+
+template <class D, HWY_IF_V_SIZE_D(D, 4), HWY_IF_U8_D(D)>
+HWY_API VFromD<D> TruncateTo(D /* tag */, Vec256<uint64_t> v) {
+ const Full256<uint8_t> d8;
+ alignas(32) static constexpr uint8_t kMap[32] = {0, 8, 16, 24};
+ const auto i8 = TableLookupLanes(BitCast(d8, v), SetTableIndices(d8, kMap));
+ return LowerHalf(LowerHalf(LowerHalf(Vec256<uint8_t>{i8.raw})));
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 8), HWY_IF_U16_D(D)>
+HWY_API VFromD<D> TruncateTo(D /* tag */, Vec256<uint64_t> v) {
+ const __m256i i32_blocks = __lasx_xvpickev_w(v.raw, v.raw);
+ const __m256i i32_concat = __lasx_xvpermi_d(i32_blocks, 0xd8);
+ const __m256i i16 = __lasx_xvpickev_h(i32_concat, i32_concat);
+ return LowerHalf(LowerHalf(Vec256<uint16_t>{i16}));
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_U32_D(D)>
+HWY_API VFromD<D> TruncateTo(D /* tag */, Vec256<uint64_t> v) {
+ const Full256<uint32_t> d32;
+ alignas(32) static constexpr uint32_t kEven[8] = {0, 2, 4, 6, 0, 2, 4, 6};
+ const auto v32 =
+ TableLookupLanes(BitCast(d32, v), SetTableIndices(d32, kEven));
+ return LowerHalf(Vec256<uint32_t>{v32.raw});
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 8), HWY_IF_U8_D(D)>
+HWY_API VFromD<D> TruncateTo(D /* tag */, Vec256<uint32_t> v) {
+ const Full256<uint8_t> d8;
+ alignas(32) static constexpr uint8_t kEven[32] = {0, 4, 8, 12,
+ 16, 20, 24, 28};
+ const auto i8 = TableLookupLanes(BitCast(d8, v), SetTableIndices(d8, kEven));
+ return LowerHalf(LowerHalf(Vec256<uint8_t>{i8.raw}));
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_U16_D(D)>
+HWY_API VFromD<D> TruncateTo(D /* tag */, Vec256<uint32_t> v) {
+ const __m256i i16_blocks = __lasx_xvpickev_h(v.raw, v.raw);
+ const __m256i i16_concat = __lasx_xvpermi_d(i16_blocks, 0xd8);
+ return LowerHalf(Vec256<uint16_t>{i16_concat});
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_U8_D(D)>
+HWY_API VFromD<D> TruncateTo(D /* tag */, Vec256<uint16_t> v) {
+ const __m256i i8_blocks = __lasx_xvpickev_b(v.raw, v.raw);
+ const __m256i i8_concat = __lasx_xvpermi_d(i8_blocks, 0xd8);
+ return LowerHalf(Vec256<uint8_t>{i8_concat});
+}
+
+// ------------------------------ Integer <=> fp (ShiftRight, OddEven)
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)>
+HWY_API VFromD<D> ConvertTo(D /* tag */, Vec256<int32_t> v) {
+ return VFromD<D>{__lasx_xvffint_s_w(v.raw)};
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F32_D(D)>
+HWY_API VFromD<D> ConvertTo(D /*df*/, Vec256<uint32_t> v) {
+ return VFromD<D>{__lasx_xvffint_s_wu(v.raw)};
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)>
+HWY_API VFromD<D> ConvertTo(D /*dd*/, Vec256<int64_t> v) {
+ return VFromD<D>{__lasx_xvffint_d_l(v.raw)};
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_F64_D(D)>
+HWY_API VFromD<D> ConvertTo(D /*dd*/, Vec256<uint64_t> v) {
+ return VFromD<D>{__lasx_xvffint_d_lu(v.raw)};
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_I32_D(D)>
+HWY_API VFromD<D> ConvertTo(D /*d*/, Vec256<float> v) {
+ return VFromD<D>{__lasx_xvftintrz_w_s(v.raw)};
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_I64_D(D)>
+HWY_API VFromD<D> ConvertTo(D /*di*/, Vec256<double> v) {
+ return VFromD<D>{__lasx_xvftintrz_l_d(v.raw)};
+}
+
+template <class DU, HWY_IF_V_SIZE_D(DU, 32), HWY_IF_U32_D(DU)>
+HWY_API VFromD<DU> ConvertTo(DU /*du*/, VFromD<RebindToFloat<DU>> v) {
+ return VFromD<DU>{__lasx_xvftintrz_wu_s(v.raw)};
+}
+
+template <class DU, HWY_IF_V_SIZE_D(DU, 32), HWY_IF_U64_D(DU)>
+HWY_API VFromD<DU> ConvertTo(DU /*du*/, VFromD<RebindToFloat<DU>> v) {
+ return VFromD<DU>{__lasx_xvftintrz_lu_d(v.raw)};
+}
+
+template <typename T, HWY_IF_FLOAT3264(T)>
+HWY_API Vec256<MakeSigned<T>> NearestInt(const Vec256<T> v) {
+ return ConvertTo(Full256<MakeSigned<T>>(), Round(v));
+}
+
+// ------------------------------ LoadMaskBits (TestBit)
+
+namespace detail {
+
+template <typename T, HWY_IF_T_SIZE(T, 1)>
+HWY_INLINE Mask256<T> LoadMaskBits256(uint64_t mask_bits) {
+ const Full256<T> d;
+ const RebindToUnsigned<decltype(d)> du;
+ const Repartition<uint32_t, decltype(d)> du32;
+ const auto vbits = BitCast(du, Set(du32, static_cast<uint32_t>(mask_bits)));
+
+ // Replicate bytes 8x such that each byte contains the bit that governs it.
+ const Repartition<uint64_t, decltype(d)> du64;
+ alignas(32) static constexpr uint64_t kRep8[4] = {
+ 0x0000000000000000ull, 0x0101010101010101ull, 0x0202020202020202ull,
+ 0x0303030303030303ull};
+ const auto rep8 = TableLookupBytes(vbits, BitCast(du, Load(du64, kRep8)));
+
+ const VFromD<decltype(du)> bit = Dup128VecFromValues(
+ du, 1, 2, 4, 8, 16, 32, 64, 128, 1, 2, 4, 8, 16, 32, 64, 128);
+ return RebindMask(d, TestBit(rep8, bit));
+}
+
+template <typename T, HWY_IF_T_SIZE(T, 2)>
+HWY_INLINE Mask256<T> LoadMaskBits256(uint64_t mask_bits) {
+ const Full256<T> d;
+ const RebindToUnsigned<decltype(d)> du;
+ alignas(32) static constexpr uint16_t kBit[16] = {
+ 1, 2, 4, 8, 16, 32, 64, 128,
+ 0x100, 0x200, 0x400, 0x800, 0x1000, 0x2000, 0x4000, 0x8000};
+ const auto vmask_bits = Set(du, static_cast<uint16_t>(mask_bits));
+ return RebindMask(d, TestBit(vmask_bits, Load(du, kBit)));
+}
+
+template <typename T, HWY_IF_T_SIZE(T, 4)>
+HWY_INLINE Mask256<T> LoadMaskBits256(uint64_t mask_bits) {
+ const Full256<T> d;
+ const RebindToUnsigned<decltype(d)> du;
+ alignas(32) static constexpr uint32_t kBit[8] = {1, 2, 4, 8, 16, 32, 64, 128};
+ const auto vmask_bits = Set(du, static_cast<uint32_t>(mask_bits));
+ return RebindMask(d, TestBit(vmask_bits, Load(du, kBit)));
+}
+
+template <typename T, HWY_IF_T_SIZE(T, 8)>
+HWY_INLINE Mask256<T> LoadMaskBits256(uint64_t mask_bits) {
+ const Full256<T> d;
+ const RebindToUnsigned<decltype(d)> du;
+ alignas(32) static constexpr uint64_t kBit[8] = {1, 2, 4, 8};
+ return RebindMask(d, TestBit(Set(du, mask_bits), Load(du, kBit)));
+}
+
+} // namespace detail
+
+// `bits` points to at least 8 readable bytes, not all of which need be valid.
+template <class D, HWY_IF_V_SIZE_D(D, 32)>
+HWY_API MFromD<D> LoadMaskBits(D d, const uint8_t* HWY_RESTRICT bits) {
+ constexpr size_t kN = MaxLanes(d);
+ constexpr size_t kNumBytes = (kN + 7) / 8;
+
+ uint64_t mask_bits = 0;
+ CopyBytes<kNumBytes>(bits, &mask_bits);
+
+ if (kN < 8) {
+ mask_bits &= (1ull << kN) - 1;
+ }
+
+ return detail::LoadMaskBits256<TFromD<D>>(mask_bits);
+}
+
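+// Illustrative example (assuming d = Full256<int32_t>(), i.e. 8 lanes, and
+// bits[0] == 0x05): LoadMaskBits(d, bits) sets lanes 0 and 2; bit i of the
+// packed input governs lane i, LSB first.
+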
+// ------------------------------ BitsFromMask
+
+template <class D, HWY_IF_T_SIZE_D(D, 1), HWY_IF_V_SIZE_D(D, 32)>
+HWY_API uint64_t BitsFromMask(D /*tag*/, MFromD<D> mask) {
+ const auto sign_bits = __lasx_xvmskltz_b(mask.raw);
+ return static_cast<uint32_t>(__lasx_xvpickve2gr_w(sign_bits, 0) |
+ (__lasx_xvpickve2gr_w(sign_bits, 4) << 16));
+}
+
+template <class D, HWY_IF_T_SIZE_D(D, 2), HWY_IF_V_SIZE_D(D, 32)>
+HWY_API uint64_t BitsFromMask(D d, MFromD<D> mask) {
+ const RebindToSigned<decltype(d)> di;
+ const auto vec_mask = VecFromMask(mask);
+ const auto sign_bits =
+ __lasx_xvpickod_b(BitCast(di, vec_mask).raw, BitCast(di, vec_mask).raw);
+ const auto sign_shuf = __lasx_xvpermi_d(sign_bits, 0xd8);
+ const auto sign_last = __lasx_xvmskltz_b(sign_shuf);
+ return static_cast<unsigned>(__lasx_xvpickve2gr_w(sign_last, 0));
+}
+
+template <class D, HWY_IF_T_SIZE_D(D, 4), HWY_IF_V_SIZE_D(D, 32)>
+HWY_API uint64_t BitsFromMask(D d, MFromD<D> mask) {
+ const RebindToSigned<decltype(d)> di;
+ const auto vec_mask = VecFromMask(mask);
+ const auto sign_bits =
+ __lasx_xvpickod_h(BitCast(di, vec_mask).raw, BitCast(di, vec_mask).raw);
+ const auto sign_shuf = __lasx_xvpermi_d(sign_bits, 0xd8);
+ const auto sign_last = __lasx_xvmskltz_h(sign_shuf);
+ return static_cast<unsigned>(__lasx_xvpickve2gr_w(sign_last, 0));
+}
+
+template <class D, HWY_IF_T_SIZE_D(D, 8), HWY_IF_V_SIZE_D(D, 32)>
+HWY_API uint64_t BitsFromMask(D d, MFromD<D> mask) {
+ const RebindToSigned<decltype(d)> di;
+ const auto vec_mask = VecFromMask(mask);
+ const auto sign_bits =
+ __lasx_xvpickod_w(BitCast(di, vec_mask).raw, BitCast(di, vec_mask).raw);
+ const auto sign_shuf = __lasx_xvpermi_d(sign_bits, 0xd8);
+ const auto sign_last = __lasx_xvmskltz_w(sign_shuf);
+ return static_cast<unsigned>(__lasx_xvpickve2gr_w(sign_last, 0));
+}
+
+// ------------------------------ StoreMaskBits
+// `bits` points to at least 8 writable bytes.
+template <class D, HWY_IF_V_SIZE_D(D, 32)>
+HWY_API size_t StoreMaskBits(D d, MFromD<D> mask, uint8_t* bits) {
+ constexpr size_t N = MaxLanes(d);
+ constexpr size_t kNumBytes = (N + 7) / 8;
+
+ const uint64_t mask_bits = BitsFromMask(d, mask);
+ CopyBytes<kNumBytes>(&mask_bits, bits);
+ return kNumBytes;
+}
+
+// ------------------------------ Mask testing
+
+template <class D, HWY_IF_V_SIZE_D(D, 32)>
+HWY_API bool AllFalse(D d, MFromD<D> mask) {
+ return BitsFromMask(d, mask) == 0;
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32)>
+HWY_API bool AllTrue(D d, MFromD<D> mask) {
+ constexpr size_t kN = MaxLanes(d);
+ constexpr uint64_t kAllBits = (1ull << kN) - 1;
+ return BitsFromMask(d, mask) == kAllBits;
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32)>
+HWY_API size_t CountTrue(D d, MFromD<D> mask) {
+ return PopCount(BitsFromMask(d, mask));
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32)>
+HWY_API size_t FindKnownFirstTrue(D d, MFromD<D> mask) {
+ const uint32_t mask_bits = static_cast<uint32_t>(BitsFromMask(d, mask));
+ return Num0BitsBelowLS1Bit_Nonzero32(mask_bits);
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32)>
+HWY_API intptr_t FindFirstTrue(D d, MFromD<D> mask) {
+ const uint32_t mask_bits = static_cast<uint32_t>(BitsFromMask(d, mask));
+ return mask_bits ? intptr_t(Num0BitsBelowLS1Bit_Nonzero32(mask_bits)) : -1;
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32)>
+HWY_API size_t FindKnownLastTrue(D d, MFromD<D> mask) {
+ const uint32_t mask_bits = static_cast<uint32_t>(BitsFromMask(d, mask));
+ return 31 - Num0BitsAboveMS1Bit_Nonzero32(mask_bits);
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32)>
+HWY_API intptr_t FindLastTrue(D d, MFromD<D> mask) {
+ const uint32_t mask_bits = static_cast<uint32_t>(BitsFromMask(d, mask));
+ return mask_bits ? intptr_t(31 - Num0BitsAboveMS1Bit_Nonzero32(mask_bits))
+ : -1;
+}
+
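+// Illustrative example tying these together: for an 8-lane mask m with
+// BitsFromMask(d, m) == 0b00100110, CountTrue(d, m) == 3,
+// FindFirstTrue(d, m) == 1 and FindLastTrue(d, m) == 5.
+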
+// ------------------------------ Compress, CompressBits
+
+namespace detail {
+
+template <typename T, HWY_IF_T_SIZE(T, 4)>
+HWY_INLINE Vec256<uint32_t> IndicesFromBits256(uint64_t mask_bits) {
+ const Full256<uint32_t> d32;
+ // We need a masked Iota(). With 8 lanes, there are 256 combinations and a LUT
+ // of SetTableIndices would require 8 KiB, a large part of L1D. We instead
+ // compress each index into 4 bits, for a total of 1 KiB.
+ alignas(16) static constexpr uint32_t packed_array[256] = {
+ // PrintCompress32x8Tables
+ 0x76543210, 0x76543218, 0x76543209, 0x76543298, 0x7654310a, 0x765431a8,
+ 0x765430a9, 0x76543a98, 0x7654210b, 0x765421b8, 0x765420b9, 0x76542b98,
+ 0x765410ba, 0x76541ba8, 0x76540ba9, 0x7654ba98, 0x7653210c, 0x765321c8,
+ 0x765320c9, 0x76532c98, 0x765310ca, 0x76531ca8, 0x76530ca9, 0x7653ca98,
+ 0x765210cb, 0x76521cb8, 0x76520cb9, 0x7652cb98, 0x76510cba, 0x7651cba8,
+ 0x7650cba9, 0x765cba98, 0x7643210d, 0x764321d8, 0x764320d9, 0x76432d98,
+ 0x764310da, 0x76431da8, 0x76430da9, 0x7643da98, 0x764210db, 0x76421db8,
+ 0x76420db9, 0x7642db98, 0x76410dba, 0x7641dba8, 0x7640dba9, 0x764dba98,
+ 0x763210dc, 0x76321dc8, 0x76320dc9, 0x7632dc98, 0x76310dca, 0x7631dca8,
+ 0x7630dca9, 0x763dca98, 0x76210dcb, 0x7621dcb8, 0x7620dcb9, 0x762dcb98,
+ 0x7610dcba, 0x761dcba8, 0x760dcba9, 0x76dcba98, 0x7543210e, 0x754321e8,
+ 0x754320e9, 0x75432e98, 0x754310ea, 0x75431ea8, 0x75430ea9, 0x7543ea98,
+ 0x754210eb, 0x75421eb8, 0x75420eb9, 0x7542eb98, 0x75410eba, 0x7541eba8,
+ 0x7540eba9, 0x754eba98, 0x753210ec, 0x75321ec8, 0x75320ec9, 0x7532ec98,
+ 0x75310eca, 0x7531eca8, 0x7530eca9, 0x753eca98, 0x75210ecb, 0x7521ecb8,
+ 0x7520ecb9, 0x752ecb98, 0x7510ecba, 0x751ecba8, 0x750ecba9, 0x75ecba98,
+ 0x743210ed, 0x74321ed8, 0x74320ed9, 0x7432ed98, 0x74310eda, 0x7431eda8,
+ 0x7430eda9, 0x743eda98, 0x74210edb, 0x7421edb8, 0x7420edb9, 0x742edb98,
+ 0x7410edba, 0x741edba8, 0x740edba9, 0x74edba98, 0x73210edc, 0x7321edc8,
+ 0x7320edc9, 0x732edc98, 0x7310edca, 0x731edca8, 0x730edca9, 0x73edca98,
+ 0x7210edcb, 0x721edcb8, 0x720edcb9, 0x72edcb98, 0x710edcba, 0x71edcba8,
+ 0x70edcba9, 0x7edcba98, 0x6543210f, 0x654321f8, 0x654320f9, 0x65432f98,
+ 0x654310fa, 0x65431fa8, 0x65430fa9, 0x6543fa98, 0x654210fb, 0x65421fb8,
+ 0x65420fb9, 0x6542fb98, 0x65410fba, 0x6541fba8, 0x6540fba9, 0x654fba98,
+ 0x653210fc, 0x65321fc8, 0x65320fc9, 0x6532fc98, 0x65310fca, 0x6531fca8,
+ 0x6530fca9, 0x653fca98, 0x65210fcb, 0x6521fcb8, 0x6520fcb9, 0x652fcb98,
+ 0x6510fcba, 0x651fcba8, 0x650fcba9, 0x65fcba98, 0x643210fd, 0x64321fd8,
+ 0x64320fd9, 0x6432fd98, 0x64310fda, 0x6431fda8, 0x6430fda9, 0x643fda98,
+ 0x64210fdb, 0x6421fdb8, 0x6420fdb9, 0x642fdb98, 0x6410fdba, 0x641fdba8,
+ 0x640fdba9, 0x64fdba98, 0x63210fdc, 0x6321fdc8, 0x6320fdc9, 0x632fdc98,
+ 0x6310fdca, 0x631fdca8, 0x630fdca9, 0x63fdca98, 0x6210fdcb, 0x621fdcb8,
+ 0x620fdcb9, 0x62fdcb98, 0x610fdcba, 0x61fdcba8, 0x60fdcba9, 0x6fdcba98,
+ 0x543210fe, 0x54321fe8, 0x54320fe9, 0x5432fe98, 0x54310fea, 0x5431fea8,
+ 0x5430fea9, 0x543fea98, 0x54210feb, 0x5421feb8, 0x5420feb9, 0x542feb98,
+ 0x5410feba, 0x541feba8, 0x540feba9, 0x54feba98, 0x53210fec, 0x5321fec8,
+ 0x5320fec9, 0x532fec98, 0x5310feca, 0x531feca8, 0x530feca9, 0x53feca98,
+ 0x5210fecb, 0x521fecb8, 0x520fecb9, 0x52fecb98, 0x510fecba, 0x51fecba8,
+ 0x50fecba9, 0x5fecba98, 0x43210fed, 0x4321fed8, 0x4320fed9, 0x432fed98,
+ 0x4310feda, 0x431feda8, 0x430feda9, 0x43feda98, 0x4210fedb, 0x421fedb8,
+ 0x420fedb9, 0x42fedb98, 0x410fedba, 0x41fedba8, 0x40fedba9, 0x4fedba98,
+ 0x3210fedc, 0x321fedc8, 0x320fedc9, 0x32fedc98, 0x310fedca, 0x31fedca8,
+ 0x30fedca9, 0x3fedca98, 0x210fedcb, 0x21fedcb8, 0x20fedcb9, 0x2fedcb98,
+ 0x10fedcba, 0x1fedcba8, 0x0fedcba9, 0xfedcba98};
+
+ // No need to mask because __lasx_xvperm_w ignores bits 3..31.
+ // Just shift each copy of the 32-bit LUT entry to extract its 4-bit field.
+ const auto packed = Set(d32, packed_array[mask_bits]);
+ alignas(32) static constexpr uint32_t shifts[8] = {0, 4, 8, 12,
+ 16, 20, 24, 28};
+ return packed >> Load(d32, shifts);
+}
+
+template <typename T, HWY_IF_T_SIZE(T, 8)>
+HWY_INLINE Vec256<uint64_t> IndicesFromBits256(uint64_t mask_bits) {
+ const Full256<uint64_t> d64;
+
+ // For 64-bit, there are only 4 lanes, so we can afford to load the
+ // entire index vector directly.
+ alignas(32) static constexpr uint64_t u64_indices[64] = {
+ // PrintCompress64x4PairTables
+ 0, 1, 2, 3, 8, 1, 2, 3, 9, 0, 2, 3, 8, 9, 2, 3,
+ 10, 0, 1, 3, 8, 10, 1, 3, 9, 10, 0, 3, 8, 9, 10, 3,
+ 11, 0, 1, 2, 8, 11, 1, 2, 9, 11, 0, 2, 8, 9, 11, 2,
+ 10, 11, 0, 1, 8, 10, 11, 1, 9, 10, 11, 0, 8, 9, 10, 11};
+ return Load(d64, u64_indices + 4 * mask_bits);
+}
+
+template <typename T, HWY_IF_T_SIZE(T, 4)>
+HWY_INLINE Vec256<uint32_t> IndicesFromNotBits256(uint64_t mask_bits) {
+ const Full256<uint32_t> d32;
+ // We need a masked Iota(). With 8 lanes, there are 256 combinations and a LUT
+ // of SetTableIndices would require 8 KiB, a large part of L1D. We instead
+ // compress each index into 4 bits, for a total of 1 KiB.
+ alignas(16) static constexpr uint32_t packed_array[256] = {
+ // PrintCompressNot32x8Tables
+ 0xfedcba98, 0x8fedcba9, 0x9fedcba8, 0x98fedcba, 0xafedcb98, 0xa8fedcb9,
+ 0xa9fedcb8, 0xa98fedcb, 0xbfedca98, 0xb8fedca9, 0xb9fedca8, 0xb98fedca,
+ 0xbafedc98, 0xba8fedc9, 0xba9fedc8, 0xba98fedc, 0xcfedba98, 0xc8fedba9,
+ 0xc9fedba8, 0xc98fedba, 0xcafedb98, 0xca8fedb9, 0xca9fedb8, 0xca98fedb,
+ 0xcbfeda98, 0xcb8feda9, 0xcb9feda8, 0xcb98feda, 0xcbafed98, 0xcba8fed9,
+ 0xcba9fed8, 0xcba98fed, 0xdfecba98, 0xd8fecba9, 0xd9fecba8, 0xd98fecba,
+ 0xdafecb98, 0xda8fecb9, 0xda9fecb8, 0xda98fecb, 0xdbfeca98, 0xdb8feca9,
+ 0xdb9feca8, 0xdb98feca, 0xdbafec98, 0xdba8fec9, 0xdba9fec8, 0xdba98fec,
+ 0xdcfeba98, 0xdc8feba9, 0xdc9feba8, 0xdc98feba, 0xdcafeb98, 0xdca8feb9,
+ 0xdca9feb8, 0xdca98feb, 0xdcbfea98, 0xdcb8fea9, 0xdcb9fea8, 0xdcb98fea,
+ 0xdcbafe98, 0xdcba8fe9, 0xdcba9fe8, 0xdcba98fe, 0xefdcba98, 0xe8fdcba9,
+ 0xe9fdcba8, 0xe98fdcba, 0xeafdcb98, 0xea8fdcb9, 0xea9fdcb8, 0xea98fdcb,
+ 0xebfdca98, 0xeb8fdca9, 0xeb9fdca8, 0xeb98fdca, 0xebafdc98, 0xeba8fdc9,
+ 0xeba9fdc8, 0xeba98fdc, 0xecfdba98, 0xec8fdba9, 0xec9fdba8, 0xec98fdba,
+ 0xecafdb98, 0xeca8fdb9, 0xeca9fdb8, 0xeca98fdb, 0xecbfda98, 0xecb8fda9,
+ 0xecb9fda8, 0xecb98fda, 0xecbafd98, 0xecba8fd9, 0xecba9fd8, 0xecba98fd,
+ 0xedfcba98, 0xed8fcba9, 0xed9fcba8, 0xed98fcba, 0xedafcb98, 0xeda8fcb9,
+ 0xeda9fcb8, 0xeda98fcb, 0xedbfca98, 0xedb8fca9, 0xedb9fca8, 0xedb98fca,
+ 0xedbafc98, 0xedba8fc9, 0xedba9fc8, 0xedba98fc, 0xedcfba98, 0xedc8fba9,
+ 0xedc9fba8, 0xedc98fba, 0xedcafb98, 0xedca8fb9, 0xedca9fb8, 0xedca98fb,
+ 0xedcbfa98, 0xedcb8fa9, 0xedcb9fa8, 0xedcb98fa, 0xedcbaf98, 0xedcba8f9,
+ 0xedcba9f8, 0xedcba98f, 0xfedcba98, 0xf8edcba9, 0xf9edcba8, 0xf98edcba,
+ 0xfaedcb98, 0xfa8edcb9, 0xfa9edcb8, 0xfa98edcb, 0xfbedca98, 0xfb8edca9,
+ 0xfb9edca8, 0xfb98edca, 0xfbaedc98, 0xfba8edc9, 0xfba9edc8, 0xfba98edc,
+ 0xfcedba98, 0xfc8edba9, 0xfc9edba8, 0xfc98edba, 0xfcaedb98, 0xfca8edb9,
+ 0xfca9edb8, 0xfca98edb, 0xfcbeda98, 0xfcb8eda9, 0xfcb9eda8, 0xfcb98eda,
+ 0xfcbaed98, 0xfcba8ed9, 0xfcba9ed8, 0xfcba98ed, 0xfdecba98, 0xfd8ecba9,
+ 0xfd9ecba8, 0xfd98ecba, 0xfdaecb98, 0xfda8ecb9, 0xfda9ecb8, 0xfda98ecb,
+ 0xfdbeca98, 0xfdb8eca9, 0xfdb9eca8, 0xfdb98eca, 0xfdbaec98, 0xfdba8ec9,
+ 0xfdba9ec8, 0xfdba98ec, 0xfdceba98, 0xfdc8eba9, 0xfdc9eba8, 0xfdc98eba,
+ 0xfdcaeb98, 0xfdca8eb9, 0xfdca9eb8, 0xfdca98eb, 0xfdcbea98, 0xfdcb8ea9,
+ 0xfdcb9ea8, 0xfdcb98ea, 0xfdcbae98, 0xfdcba8e9, 0xfdcba9e8, 0xfdcba98e,
+ 0xfedcba98, 0xfe8dcba9, 0xfe9dcba8, 0xfe98dcba, 0xfeadcb98, 0xfea8dcb9,
+ 0xfea9dcb8, 0xfea98dcb, 0xfebdca98, 0xfeb8dca9, 0xfeb9dca8, 0xfeb98dca,
+ 0xfebadc98, 0xfeba8dc9, 0xfeba9dc8, 0xfeba98dc, 0xfecdba98, 0xfec8dba9,
+ 0xfec9dba8, 0xfec98dba, 0xfecadb98, 0xfeca8db9, 0xfeca9db8, 0xfeca98db,
+ 0xfecbda98, 0xfecb8da9, 0xfecb9da8, 0xfecb98da, 0xfecbad98, 0xfecba8d9,
+ 0xfecba9d8, 0xfecba98d, 0xfedcba98, 0xfed8cba9, 0xfed9cba8, 0xfed98cba,
+ 0xfedacb98, 0xfeda8cb9, 0xfeda9cb8, 0xfeda98cb, 0xfedbca98, 0xfedb8ca9,
+ 0xfedb9ca8, 0xfedb98ca, 0xfedbac98, 0xfedba8c9, 0xfedba9c8, 0xfedba98c,
+ 0xfedcba98, 0xfedc8ba9, 0xfedc9ba8, 0xfedc98ba, 0xfedcab98, 0xfedca8b9,
+ 0xfedca9b8, 0xfedca98b, 0xfedcba98, 0xfedcb8a9, 0xfedcb9a8, 0xfedcb98a,
+ 0xfedcba98, 0xfedcba89, 0xfedcba98, 0xfedcba98};
+
+ // No need to mask because __lasx_xvperm_w ignores bits 3..31.
+ // Just shift each copy of the 32-bit LUT entry to extract its 4-bit field.
+ const Vec256<uint32_t> packed = Set(d32, packed_array[mask_bits]);
+ alignas(32) static constexpr uint32_t shifts[8] = {0, 4, 8, 12,
+ 16, 20, 24, 28};
+ return packed >> Load(d32, shifts);
+}
+
+template <typename T, HWY_IF_T_SIZE(T, 8)>
+HWY_INLINE Vec256<uint64_t> IndicesFromNotBits256(uint64_t mask_bits) {
+ const Full256<uint64_t> d64;
+
+ // For 64-bit, there are only 4 lanes, so we can afford to load
+ // the entire index vector directly.
+ alignas(32) static constexpr uint64_t u64_indices[64] = {
+ // PrintCompressNot64x4PairTables
+ 8, 9, 10, 11, 9, 10, 11, 0, 8, 10, 11, 1, 10, 11, 0, 1,
+ 8, 9, 11, 2, 9, 11, 0, 2, 8, 11, 1, 2, 11, 0, 1, 2,
+ 8, 9, 10, 3, 9, 10, 0, 3, 8, 10, 1, 3, 10, 0, 1, 3,
+ 8, 9, 2, 3, 9, 0, 2, 3, 8, 1, 2, 3, 0, 1, 2, 3};
+ return Load(d64, u64_indices + 4 * mask_bits);
+}
+
+template <typename T, HWY_IF_NOT_T_SIZE(T, 2)>
+HWY_INLINE Vec256<T> Compress(Vec256<T> v, const uint64_t mask_bits) {
+ const DFromV<decltype(v)> d;
+ const RebindToSigned<decltype(d)> di;
+
+ HWY_DASSERT(mask_bits < (1ull << Lanes(d)));
+ const Indices256<TFromD<decltype(di)>> indices{
+ IndicesFromBits256<T>(mask_bits).raw};
+ return BitCast(d, TableLookupLanes(BitCast(di, v), indices));
+}
+
+// LUTs are infeasible for 2^16 possible masks, so splice together two
+// half-vector Compress.
+template <typename T, HWY_IF_T_SIZE(T, 2)>
+HWY_INLINE Vec256<T> Compress(Vec256<T> v, const uint64_t mask_bits) {
+ const DFromV<decltype(v)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ const auto vu16 = BitCast(du, v); // (required for float16_t inputs)
+ const Half<decltype(du)> duh;
+ const auto half0 = LowerHalf(duh, vu16);
+ const auto half1 = UpperHalf(duh, vu16);
+
+ const uint64_t mask_bits0 = mask_bits & 0xFF;
+ const uint64_t mask_bits1 = mask_bits >> 8;
+ const auto compressed0 = detail::CompressBits(half0, mask_bits0);
+ const auto compressed1 = detail::CompressBits(half1, mask_bits1);
+
+ alignas(32) uint16_t all_true[16] = {};
+ // Store mask=true lanes, left to right.
+ const size_t num_true0 = PopCount(mask_bits0);
+ Store(compressed0, duh, all_true);
+ StoreU(compressed1, duh, all_true + num_true0);
+
+ if (hwy::HWY_NAMESPACE::CompressIsPartition<T>::value) {
+ // Store mask=false lanes, right to left. The second vector fills the upper
+ // half with right-aligned false lanes. The first vector is shifted
+ // rightwards to overwrite the true lanes of the second.
+ alignas(32) uint16_t all_false[16] = {};
+ const size_t num_true1 = PopCount(mask_bits1);
+ Store(compressed1, duh, all_false + 8);
+ StoreU(compressed0, duh, all_false + num_true1);
+
+ const auto mask = FirstN(du, num_true0 + num_true1);
+ return BitCast(d,
+ IfThenElse(mask, Load(du, all_true), Load(du, all_false)));
+ } else {
+ // Only care about the mask=true lanes.
+ return BitCast(d, Load(du, all_true));
+ }
+}
+
+template <typename T, HWY_IF_T_SIZE_ONE_OF(T, (1 << 4) | (1 << 8))>
+HWY_INLINE Vec256<T> CompressNot(Vec256<T> v, const uint64_t mask_bits) {
+ const DFromV<decltype(v)> d;
+ const RebindToSigned<decltype(d)> di;
+
+ HWY_DASSERT(mask_bits < (1ull << Lanes(d)));
+ const Indices256<TFromD<decltype(di)>> indices{
+ IndicesFromNotBits256<T>(mask_bits).raw};
+ return BitCast(d, TableLookupLanes(BitCast(di, v), indices));
+}
+
+// LUTs are infeasible for 2^16 possible masks, so splice together two
+// half-vector Compress.
+template <typename T, HWY_IF_T_SIZE(T, 2)>
+HWY_INLINE Vec256<T> CompressNot(Vec256<T> v, const uint64_t mask_bits) {
+ // Compress ensures only the lower 16 bits are set, so flip those.
+ return Compress(v, mask_bits ^ 0xFFFF);
+}
+
+} // namespace detail
+
+template <typename T, HWY_IF_NOT_T_SIZE(T, 1)>
+HWY_API Vec256<T> Compress(Vec256<T> v, Mask256<T> m) {
+ const DFromV<decltype(v)> d;
+ return detail::Compress(v, BitsFromMask(d, m));
+}
+
+template <typename T, HWY_IF_NOT_T_SIZE(T, 1)>
+HWY_API Vec256<T> CompressNot(Vec256<T> v, Mask256<T> m) {
+ const DFromV<decltype(v)> d;
+ return detail::CompressNot(v, BitsFromMask(d, m));
+}
+
+HWY_API Vec256<uint64_t> CompressBlocksNot(Vec256<uint64_t> v,
+ Mask256<uint64_t> mask) {
+ return CompressNot(v, mask);
+}
+
+template <typename T, HWY_IF_NOT_T_SIZE(T, 1)>
+HWY_API Vec256<T> CompressBits(Vec256<T> v, const uint8_t* HWY_RESTRICT bits) {
+ constexpr size_t N = 32 / sizeof(T);
+ constexpr size_t kNumBytes = (N + 7) / 8;
+
+ uint64_t mask_bits = 0;
+ CopyBytes<kNumBytes>(bits, &mask_bits);
+
+ if (N < 8) {
+ mask_bits &= (1ull << N) - 1;
+ }
+
+ return detail::Compress(v, mask_bits);
+}
+
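+// Illustrative example (assuming d = Full256<int32_t>(), v = Iota(d, 0)
+// and a mask m with BitsFromMask(d, m) == 0b10100101): Compress(v, m)
+// begins [0 2 5 7 ...]; because these index tables partition rather than
+// discard, the remaining lanes hold the mask=false elements.
+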
+// ------------------------------ CompressStore, CompressBitsStore
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_NOT_T_SIZE_D(D, 1)>
+HWY_API size_t CompressStore(VFromD<D> v, MFromD<D> m, D d,
+ TFromD<D>* HWY_RESTRICT unaligned) {
+ const uint64_t mask_bits = BitsFromMask(d, m);
+ const size_t count = PopCount(mask_bits);
+ StoreU(detail::Compress(v, mask_bits), d, unaligned);
+ detail::MaybeUnpoison(unaligned, count);
+ return count;
+}
+
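+// Usage sketch for CompressStore (assuming an i32 buffer `out`, a
+// hypothetical name, with at least Lanes(d) writable elements):
+//   const size_t n = CompressStore(v, m, d, out);  // n == CountTrue(d, m)
+// A full vector is written, but only the first n elements are meaningful.
+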
+template <class D, HWY_IF_V_SIZE_D(D, 32),
+ HWY_IF_T_SIZE_ONE_OF_D(D, (1 << 4) | (1 << 8))>
+HWY_API size_t CompressBlendedStore(VFromD<D> v, MFromD<D> m, D d,
+ TFromD<D>* HWY_RESTRICT unaligned) {
+ const uint64_t mask_bits = BitsFromMask(d, m);
+ const size_t count = PopCount(mask_bits);
+ using TU = MakeUnsigned<TFromD<D>>;
+
+ const RebindToUnsigned<decltype(d)> du;
+ HWY_DASSERT(mask_bits < (1ull << Lanes(d)));
+ const Vec256<TU> idx_mask = detail::IndicesFromBits256<TFromD<D>>(mask_bits);
+ // Shift each index nibble's MSB into the lane's MSB.
+ const auto shiftVal = sizeof(TU) == 4 ? 28 : 60;
+ const Mask256<TU> mask32or64 = MaskFromVec(ShiftLeft<shiftVal>(idx_mask));
+ const Mask256<TU> masku{sizeof(TU) == 4 ? __lasx_xvslti_w(mask32or64.raw, 0)
+ : __lasx_xvslti_d(mask32or64.raw, 0)};
+ const MFromD<D> mask = RebindMask(d, masku);
+ const VFromD<D> compressed = BitCast(
+ d, TableLookupLanes(BitCast(du, v), Indices256<TU>{idx_mask.raw}));
+
+ BlendedStore(compressed, mask, d, unaligned);
+ detail::MaybeUnpoison(unaligned, count);
+ return count;
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 2)>
+HWY_API size_t CompressBlendedStore(VFromD<D> v, MFromD<D> m, D d,
+ TFromD<D>* HWY_RESTRICT unaligned) {
+ const uint64_t mask_bits = BitsFromMask(d, m);
+ const size_t count = PopCount(mask_bits);
+ const VFromD<D> compressed = detail::Compress(v, mask_bits);
+ BlendedStore(compressed, FirstN(d, count), d, unaligned);
+ return count;
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_NOT_T_SIZE_D(D, 1)>
+HWY_API size_t CompressBitsStore(VFromD<D> v, const uint8_t* HWY_RESTRICT bits,
+ D d, TFromD<D>* HWY_RESTRICT unaligned) {
+ constexpr size_t N = MaxLanes(d);
+ constexpr size_t kNumBytes = (N + 7) / 8;
+
+ uint64_t mask_bits = 0;
+ CopyBytes<kNumBytes>(bits, &mask_bits);
+
+ if (N < 8) {
+ mask_bits &= (1ull << N) - 1;
+ }
+ const size_t count = PopCount(mask_bits);
+
+ StoreU(detail::Compress(v, mask_bits), d, unaligned);
+ detail::MaybeUnpoison(unaligned, count);
+ return count;
+}
+
+// ------------------------------ Dup128MaskFromMaskBits
+
+// Generic for all vector lengths >= 32 bytes
+template <class D, HWY_IF_V_SIZE_GT_D(D, 16)>
+HWY_API MFromD<D> Dup128MaskFromMaskBits(D d, unsigned mask_bits) {
+ const Half<decltype(d)> dh;
+ const auto mh = Dup128MaskFromMaskBits(dh, mask_bits);
+ return CombineMasks(d, mh, mh);
+}
+
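+// Illustrative example (assuming d = Full256<int32_t>()):
+// Dup128MaskFromMaskBits(d, 0b0011) sets lanes 0, 1, 4 and 5; the bit
+// pattern applies to each 128-bit block independently.
+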
+// ------------------------------ Expand
+
+template <typename T, HWY_IF_T_SIZE(T, 1)>
+HWY_API Vec256<T> Expand(Vec256<T> v, Mask256<T> mask) {
+ const DFromV<decltype(v)> d;
+ // LUTs are infeasible for so many mask combinations, so Combine two
+ // half-vector Expand.
+ const Half<decltype(d)> dh;
+ const uint64_t mask_bits = BitsFromMask(d, mask);
+ constexpr size_t N = 32 / sizeof(T);
+ const size_t countL = PopCount(mask_bits & ((1 << (N / 2)) - 1));
+ const Mask128<T> maskL = MaskFromVec(LowerHalf(VecFromMask(d, mask)));
+ const Vec128<T> expandL = Expand(LowerHalf(v), maskL);
+
+ alignas(32) T lanes[N];
+ Store(v, d, lanes);
+ const Mask128<T> maskH = MaskFromVec(UpperHalf(dh, VecFromMask(d, mask)));
+ const Vec128<T> expandH = Expand(LoadU(dh, lanes + countL), maskH);
+ return Combine(d, expandH, expandL);
+}
+
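+// Illustrative example of the Expand contract (assuming
+// d = Full256<int32_t>(), v = Iota(d, 1) and a mask that is true in lanes
+// 0 and 2 only): Expand(v, mask) == [1 0 2 0 0 0 0 0]; consecutive source
+// lanes are scattered to the true positions and false positions are
+// zeroed.
+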
+template <typename T, HWY_IF_T_SIZE(T, 2)>
+HWY_API Vec256<T> Expand(Vec256<T> v, Mask256<T> mask) {
+ const Full256<T> d;
+ // LUTs are infeasible for 2^16 possible masks, so splice together two
+ // half-vector Expand.
+ const Half<decltype(d)> dh;
+ const Mask128<T> maskL = MaskFromVec(LowerHalf(VecFromMask(d, mask)));
+ const Vec128<T> expandL = Expand(LowerHalf(v), maskL);
+
+ alignas(32) T lanes[32 / sizeof(T)];
+ Store(v, d, lanes);
+ const Vec128<T> vH = LoadU(dh, lanes + CountTrue(dh, maskL));
+ const Mask128<T> maskH = MaskFromVec(UpperHalf(dh, VecFromMask(d, mask)));
+ const Vec128<T> expandH = Expand(vH, maskH);
+ return Combine(d, expandH, expandL);
+}
+
+template <typename T, HWY_IF_T_SIZE(T, 4)>
+HWY_API Vec256<T> Expand(Vec256<T> v, Mask256<T> mask) {
+ const Full256<T> d;
+ const RebindToUnsigned<decltype(d)> du;
+ const uint64_t mask_bits = BitsFromMask(d, mask);
+ alignas(16) constexpr uint32_t packed_array[256] = {
+ // PrintExpand32x8Nibble.
+ 0xffffffff, 0xfffffff0, 0xffffff0f, 0xffffff10, 0xfffff0ff, 0xfffff1f0,
+ 0xfffff10f, 0xfffff210, 0xffff0fff, 0xffff1ff0, 0xffff1f0f, 0xffff2f10,
+ 0xffff10ff, 0xffff21f0, 0xffff210f, 0xffff3210, 0xfff0ffff, 0xfff1fff0,
+ 0xfff1ff0f, 0xfff2ff10, 0xfff1f0ff, 0xfff2f1f0, 0xfff2f10f, 0xfff3f210,
+ 0xfff10fff, 0xfff21ff0, 0xfff21f0f, 0xfff32f10, 0xfff210ff, 0xfff321f0,
+ 0xfff3210f, 0xfff43210, 0xff0fffff, 0xff1ffff0, 0xff1fff0f, 0xff2fff10,
+ 0xff1ff0ff, 0xff2ff1f0, 0xff2ff10f, 0xff3ff210, 0xff1f0fff, 0xff2f1ff0,
+ 0xff2f1f0f, 0xff3f2f10, 0xff2f10ff, 0xff3f21f0, 0xff3f210f, 0xff4f3210,
+ 0xff10ffff, 0xff21fff0, 0xff21ff0f, 0xff32ff10, 0xff21f0ff, 0xff32f1f0,
+ 0xff32f10f, 0xff43f210, 0xff210fff, 0xff321ff0, 0xff321f0f, 0xff432f10,
+ 0xff3210ff, 0xff4321f0, 0xff43210f, 0xff543210, 0xf0ffffff, 0xf1fffff0,
+ 0xf1ffff0f, 0xf2ffff10, 0xf1fff0ff, 0xf2fff1f0, 0xf2fff10f, 0xf3fff210,
+ 0xf1ff0fff, 0xf2ff1ff0, 0xf2ff1f0f, 0xf3ff2f10, 0xf2ff10ff, 0xf3ff21f0,
+ 0xf3ff210f, 0xf4ff3210, 0xf1f0ffff, 0xf2f1fff0, 0xf2f1ff0f, 0xf3f2ff10,
+ 0xf2f1f0ff, 0xf3f2f1f0, 0xf3f2f10f, 0xf4f3f210, 0xf2f10fff, 0xf3f21ff0,
+ 0xf3f21f0f, 0xf4f32f10, 0xf3f210ff, 0xf4f321f0, 0xf4f3210f, 0xf5f43210,
+ 0xf10fffff, 0xf21ffff0, 0xf21fff0f, 0xf32fff10, 0xf21ff0ff, 0xf32ff1f0,
+ 0xf32ff10f, 0xf43ff210, 0xf21f0fff, 0xf32f1ff0, 0xf32f1f0f, 0xf43f2f10,
+ 0xf32f10ff, 0xf43f21f0, 0xf43f210f, 0xf54f3210, 0xf210ffff, 0xf321fff0,
+ 0xf321ff0f, 0xf432ff10, 0xf321f0ff, 0xf432f1f0, 0xf432f10f, 0xf543f210,
+ 0xf3210fff, 0xf4321ff0, 0xf4321f0f, 0xf5432f10, 0xf43210ff, 0xf54321f0,
+ 0xf543210f, 0xf6543210, 0x0fffffff, 0x1ffffff0, 0x1fffff0f, 0x2fffff10,
+ 0x1ffff0ff, 0x2ffff1f0, 0x2ffff10f, 0x3ffff210, 0x1fff0fff, 0x2fff1ff0,
+ 0x2fff1f0f, 0x3fff2f10, 0x2fff10ff, 0x3fff21f0, 0x3fff210f, 0x4fff3210,
+ 0x1ff0ffff, 0x2ff1fff0, 0x2ff1ff0f, 0x3ff2ff10, 0x2ff1f0ff, 0x3ff2f1f0,
+ 0x3ff2f10f, 0x4ff3f210, 0x2ff10fff, 0x3ff21ff0, 0x3ff21f0f, 0x4ff32f10,
+ 0x3ff210ff, 0x4ff321f0, 0x4ff3210f, 0x5ff43210, 0x1f0fffff, 0x2f1ffff0,
+ 0x2f1fff0f, 0x3f2fff10, 0x2f1ff0ff, 0x3f2ff1f0, 0x3f2ff10f, 0x4f3ff210,
+ 0x2f1f0fff, 0x3f2f1ff0, 0x3f2f1f0f, 0x4f3f2f10, 0x3f2f10ff, 0x4f3f21f0,
+ 0x4f3f210f, 0x5f4f3210, 0x2f10ffff, 0x3f21fff0, 0x3f21ff0f, 0x4f32ff10,
+ 0x3f21f0ff, 0x4f32f1f0, 0x4f32f10f, 0x5f43f210, 0x3f210fff, 0x4f321ff0,
+ 0x4f321f0f, 0x5f432f10, 0x4f3210ff, 0x5f4321f0, 0x5f43210f, 0x6f543210,
+ 0x10ffffff, 0x21fffff0, 0x21ffff0f, 0x32ffff10, 0x21fff0ff, 0x32fff1f0,
+ 0x32fff10f, 0x43fff210, 0x21ff0fff, 0x32ff1ff0, 0x32ff1f0f, 0x43ff2f10,
+ 0x32ff10ff, 0x43ff21f0, 0x43ff210f, 0x54ff3210, 0x21f0ffff, 0x32f1fff0,
+ 0x32f1ff0f, 0x43f2ff10, 0x32f1f0ff, 0x43f2f1f0, 0x43f2f10f, 0x54f3f210,
+ 0x32f10fff, 0x43f21ff0, 0x43f21f0f, 0x54f32f10, 0x43f210ff, 0x54f321f0,
+ 0x54f3210f, 0x65f43210, 0x210fffff, 0x321ffff0, 0x321fff0f, 0x432fff10,
+ 0x321ff0ff, 0x432ff1f0, 0x432ff10f, 0x543ff210, 0x321f0fff, 0x432f1ff0,
+ 0x432f1f0f, 0x543f2f10, 0x432f10ff, 0x543f21f0, 0x543f210f, 0x654f3210,
+ 0x3210ffff, 0x4321fff0, 0x4321ff0f, 0x5432ff10, 0x4321f0ff, 0x5432f1f0,
+ 0x5432f10f, 0x6543f210, 0x43210fff, 0x54321ff0, 0x54321f0f, 0x65432f10,
+ 0x543210ff, 0x654321f0, 0x6543210f, 0x76543210,
+ };
+
+ // For lane i, shift the i-th 4-bit index down to bits [0, 3).
+ const Vec256<uint32_t> packed = Set(du, packed_array[mask_bits]);
+ alignas(32) constexpr uint32_t shifts[8] = {0, 4, 8, 12, 16, 20, 24, 28};
+ // TableLookupLanes ignores upper bits; avoid bounds-check in IndicesFromVec.
+ const Indices256<uint32_t> indices{(packed >> Load(du, shifts)).raw};
+ const Vec256<uint32_t> expand = TableLookupLanes(BitCast(du, v), indices);
+ // TableLookupLanes cannot also zero masked-off lanes, so do that now.
+ return IfThenElseZero(mask, BitCast(d, expand));
+}
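+
+// Illustrative note (not part of the library): each nibble of
+// packed_array[mask_bits] is the source-lane index for one output lane,
+// LSB nibble first; 0xf marks "don't care" lanes that are zeroed afterwards.
+// E.g. for mask_bits = 0b101 (lanes 0 and 2 active),
+// packed_array[5] = 0xfffff1f0: lane 0 <- v[0], lane 2 <- v[1], and all
+// other lanes are masked to zero by IfThenElseZero.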
+
+template <typename T, HWY_IF_T_SIZE(T, 8)>
+HWY_API Vec256<T> Expand(Vec256<T> v, Mask256<T> mask) {
+ const Full256<T> d;
+ const RebindToUnsigned<decltype(d)> du;
+ const uint64_t mask_bits = BitsFromMask(d, mask);
+
+ alignas(16) constexpr uint64_t packed_array[16] = {
+ // PrintExpand64x4Nibble.
+ 0x0000ffff, 0x0000fff0, 0x0000ff0f, 0x0000ff10, 0x0000f0ff, 0x0000f1f0,
+ 0x0000f10f, 0x0000f210, 0x00000fff, 0x00001ff0, 0x00001f0f, 0x00002f10,
+ 0x000010ff, 0x000021f0, 0x0000210f, 0x00003210};
+
+ // For lane i, shift the i-th 4-bit index down to bits [0, 2).
+ const Vec256<uint64_t> packed = Set(du, packed_array[mask_bits]);
+  alignas(32) constexpr uint64_t shifts[4] = {0, 4, 8, 12};
+ // 64-bit TableLookupLanes on LASX requires IndicesFromVec, which checks
+ // bounds, so clear the upper bits.
+ const Vec256<uint64_t> masked = And(packed >> Load(du, shifts), Set(du, 3));
+ const Indices256<uint64_t> indices = IndicesFromVec(du, masked);
+ const Vec256<uint64_t> expand = TableLookupLanes(BitCast(du, v), indices);
+ // TableLookupLanes cannot also zero masked-off lanes, so do that now.
+ return IfThenElseZero(mask, BitCast(d, expand));
+}
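+
+// Illustrative note (not part of the library): masking with 3 turns the 0xf
+// "don't care" nibbles into the in-range index 3; the lanes they feed are
+// discarded by IfThenElseZero anyway, so only the bounds check is appeased.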
+
+// ------------------------------ LoadExpand
+
+template <class D, HWY_IF_V_SIZE_D(D, 32),
+ HWY_IF_T_SIZE_ONE_OF_D(D, (1 << 1) | (1 << 2))>
+HWY_API VFromD<D> LoadExpand(MFromD<D> mask, D d,
+ const TFromD<D>* HWY_RESTRICT unaligned) {
+ return Expand(LoadU(d, unaligned), mask);
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 32),
+ HWY_IF_T_SIZE_ONE_OF_D(D, (1 << 4) | (1 << 8))>
+HWY_API VFromD<D> LoadExpand(MFromD<D> mask, D d,
+ const TFromD<D>* HWY_RESTRICT unaligned) {
+ return Expand(LoadU(d, unaligned), mask);
+}
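+
+// Usage sketch (hypothetical caller, not part of the library):
+//   const hn::Full256<float> d;  // assuming namespace hn = hwy::HWY_NAMESPACE
+//   const auto m = hn::FirstN(d, num_valid);
+//   const auto v = hn::LoadExpand(m, d, ptr);  // == Expand(LoadU(d, ptr), m)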
+
+// ------------------------------ LoadInterleaved3/4
+
+// Implemented in generic_ops; we just overload LoadTransposedBlocks3/4.
+
+namespace detail {
+// Input:
+// 1 0 (<- first block of unaligned)
+// 3 2
+// 5 4
+// Output:
+// 3 0
+// 4 1
+// 5 2
+template <class D, HWY_IF_V_SIZE_D(D, 32)>
+HWY_API void LoadTransposedBlocks3(D d, const TFromD<D>* HWY_RESTRICT unaligned,
+ VFromD<D>& A, VFromD<D>& B, VFromD<D>& C) {
+ constexpr size_t N = MaxLanes(d);
+ const VFromD<D> v10 = LoadU(d, unaligned + 0 * N); // 1 0
+ const VFromD<D> v32 = LoadU(d, unaligned + 1 * N);
+ const VFromD<D> v54 = LoadU(d, unaligned + 2 * N);
+
+ A = ConcatUpperLower(d, v32, v10);
+ B = ConcatLowerUpper(d, v54, v10);
+ C = ConcatUpperLower(d, v54, v32);
+}
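+
+// Illustrative check (not part of the library): ConcatUpperLower(d, hi, lo)
+// yields [upper block of hi, lower block of lo], so A = [3, 0]; similarly
+// B = [lower(v54)=4, upper(v10)=1] and C = [upper(v54)=5, lower(v32)=2],
+// matching the output diagram above.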
+
+// Input (128-bit blocks):
+// 1 0 (first block of unaligned)
+// 3 2
+// 5 4
+// 7 6
+// Output:
+// 4 0 (LSB of vA)
+// 5 1
+// 6 2
+// 7 3
+template <class D, HWY_IF_V_SIZE_D(D, 32)>
+HWY_API void LoadTransposedBlocks4(D d, const TFromD<D>* HWY_RESTRICT unaligned,
+ VFromD<D>& vA, VFromD<D>& vB, VFromD<D>& vC,
+ VFromD<D>& vD) {
+ constexpr size_t N = MaxLanes(d);
+ const VFromD<D> v10 = LoadU(d, unaligned + 0 * N);
+ const VFromD<D> v32 = LoadU(d, unaligned + 1 * N);
+ const VFromD<D> v54 = LoadU(d, unaligned + 2 * N);
+ const VFromD<D> v76 = LoadU(d, unaligned + 3 * N);
+
+ vA = ConcatLowerLower(d, v54, v10);
+ vB = ConcatUpperUpper(d, v54, v10);
+ vC = ConcatLowerLower(d, v76, v32);
+ vD = ConcatUpperUpper(d, v76, v32);
+}
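+
+// Illustrative check (not part of the library): ConcatLowerLower pairs the
+// two lower blocks and ConcatUpperUpper the two upper ones, giving
+// vA = [4, 0], vB = [5, 1], vC = [6, 2], vD = [7, 3] as diagrammed.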
+} // namespace detail
+
+// ------------------------------ StoreInterleaved2/3/4 (ConcatUpperLower)
+
+// Implemented in generic_ops; we just overload StoreTransposedBlocks2/3/4.
+
+namespace detail {
+// Input (128-bit blocks):
+// 2 0 (LSB of i)
+// 3 1
+// Output:
+// 1 0
+// 3 2
+template <class D, HWY_IF_V_SIZE_D(D, 32)>
+HWY_API void StoreTransposedBlocks2(VFromD<D> i, VFromD<D> j, D d,
+ TFromD<D>* HWY_RESTRICT unaligned) {
+ constexpr size_t N = MaxLanes(d);
+ const auto out0 = ConcatLowerLower(d, j, i);
+ const auto out1 = ConcatUpperUpper(d, j, i);
+ StoreU(out0, d, unaligned + 0 * N);
+ StoreU(out1, d, unaligned + 1 * N);
+}
+
+// Input (128-bit blocks):
+// 3 0 (LSB of i)
+// 4 1
+// 5 2
+// Output:
+// 1 0
+// 3 2
+// 5 4
+template <class D, HWY_IF_V_SIZE_D(D, 32)>
+HWY_API void StoreTransposedBlocks3(VFromD<D> i, VFromD<D> j, VFromD<D> k, D d,
+ TFromD<D>* HWY_RESTRICT unaligned) {
+ constexpr size_t N = MaxLanes(d);
+ const auto out0 = ConcatLowerLower(d, j, i);
+ const auto out1 = ConcatUpperLower(d, i, k);
+ const auto out2 = ConcatUpperUpper(d, k, j);
+ StoreU(out0, d, unaligned + 0 * N);
+ StoreU(out1, d, unaligned + 1 * N);
+ StoreU(out2, d, unaligned + 2 * N);
+}
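+
+// Illustrative check (not part of the library): out0 = [lower(j)=1,
+// lower(i)=0], out1 = [upper(i)=3, lower(k)=2], out2 = [upper(k)=5,
+// upper(j)=4], i.e. blocks 0..5 in memory order as diagrammed.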
+
+// Input (128-bit blocks):
+// 4 0 (LSB of i)
+// 5 1
+// 6 2
+// 7 3
+// Output:
+// 1 0
+// 3 2
+// 5 4
+// 7 6
+template <class D, HWY_IF_V_SIZE_D(D, 32)>
+HWY_API void StoreTransposedBlocks4(VFromD<D> i, VFromD<D> j, VFromD<D> k,
+ VFromD<D> l, D d,
+ TFromD<D>* HWY_RESTRICT unaligned) {
+ constexpr size_t N = MaxLanes(d);
+ // Write lower halves, then upper.
+ const auto out0 = ConcatLowerLower(d, j, i);
+ const auto out1 = ConcatLowerLower(d, l, k);
+ StoreU(out0, d, unaligned + 0 * N);
+ StoreU(out1, d, unaligned + 1 * N);
+ const auto out2 = ConcatUpperUpper(d, j, i);
+ const auto out3 = ConcatUpperUpper(d, l, k);
+ StoreU(out2, d, unaligned + 2 * N);
+ StoreU(out3, d, unaligned + 3 * N);
+}
+} // namespace detail
+
+// ------------------------------ Additional mask logical operations
+
+namespace detail {
+
+template <class T>
+static HWY_INLINE HWY_MAYBE_UNUSED Vec256<T> LasxI256Neg(Vec256<T> v) {
+ const Full256<T> d;
+ const Repartition<uint64_t, decltype(d)> du64;
+
+ const auto vu64 = BitCast(du64, v);
+ const auto vu64_zero = Zero(du64);
+ const auto i128_ne_zero = VecFromMask(du64, Ne128(du64, vu64, vu64_zero));
+ const VFromD<decltype(du64)> i128_neg_result{
+ __lasx_xvsub_q(vu64_zero.raw, vu64.raw)};
+ const VFromD<decltype(du64)> i256_neg_result_as_u64{
+ __lasx_xvadd_q(i128_neg_result.raw,
+ ConcatLowerLower(du64, i128_ne_zero, vu64_zero).raw)};
+
+ return BitCast(d, i256_neg_result_as_u64);
+}
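+
+// Illustrative note (not part of the library): LasxI256Neg computes the
+// two's complement of the full 256-bit value. __lasx_xvsub_q negates each
+// 128-bit half independently; adding all-ones (i.e. subtracting 1) to the
+// upper half whenever the lower half is nonzero propagates the borrow across
+// the 128-bit boundary: -x = {-hi - (lo != 0), -lo} in 128-bit limbs.
+// SetAtOrAfterFirst below then relies on the classic identity that m | -m
+// sets every bit at or above the lowest set bit of m.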
+
+} // namespace detail
+
+template <class T>
+HWY_API Mask256<T> SetAtOrAfterFirst(Mask256<T> mask) {
+ const Full256<T> d;
+ return Or(mask, MaskFromVec(detail::LasxI256Neg(VecFromMask(d, mask))));
+}
+
+template <class T>
+HWY_API Mask256<T> SetBeforeFirst(Mask256<T> mask) {
+ return Not(SetAtOrAfterFirst(mask));
+}
+
+template <class T>
+HWY_API Mask256<T> SetOnlyFirst(Mask256<T> mask) {
+ const Full256<T> d;
+ const RebindToSigned<decltype(d)> di;
+
+ const auto vmask = BitCast(di, VecFromMask(d, mask));
+ const auto neg_vmask = detail::LasxI256Neg(vmask);
+
+ return MaskFromVec(BitCast(d, Neg(And(vmask, neg_vmask))));
+}
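+
+// Illustrative note (not part of the library): treating the mask as one
+// 256-bit integer, vmask & -vmask isolates its lowest set bit. Because mask
+// lanes are all-ones, that bit is bit 0 of the first true lane, so the lane
+// holds the value 1; per-lane Neg turns it into all-ones (-1), recreating a
+// full mask lane, while every other lane stays zero.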
+
+template <class T>
+HWY_API Mask256<T> SetAtOrBeforeFirst(Mask256<T> mask) {
+ const Full256<T> d;
+ constexpr size_t kLanesPerBlock = MaxLanes(d) / 2;
+
+ const auto vmask = VecFromMask(d, mask);
+ const auto vmask_lo = ConcatLowerLower(d, vmask, Zero(d));
+ return SetBeforeFirst(
+ MaskFromVec(CombineShiftRightBytes<(kLanesPerBlock - 1) * sizeof(T)>(
+ d, vmask, vmask_lo)));
+}
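+
+// Illustrative note (not part of the library): the CombineShiftRightBytes
+// above shifts the whole mask left by one lane (shifting in zero), so
+// SetBeforeFirst of the shifted mask also includes the first true lane,
+// which is exactly SetAtOrBeforeFirst.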
+
+// ------------------------------ LeadingZeroCount
+
+template <class V, HWY_IF_UI8(TFromV<V>), HWY_IF_V_SIZE_V(V, 32)>
+HWY_API V LeadingZeroCount(V v) {
+ return V{__lasx_xvclz_b(v.raw)};
+}
+template <class V, HWY_IF_UI16(TFromV<V>), HWY_IF_V_SIZE_V(V, 32)>
+HWY_API V LeadingZeroCount(V v) {
+ return V{__lasx_xvclz_h(v.raw)};
+}
+template <class V, HWY_IF_UI32(TFromV<V>), HWY_IF_V_SIZE_V(V, 32)>
+HWY_API V LeadingZeroCount(V v) {
+ return V{__lasx_xvclz_w(v.raw)};
+}
+template <class V, HWY_IF_UI64(TFromV<V>), HWY_IF_V_SIZE_V(V, 32)>
+HWY_API V LeadingZeroCount(V v) {
+ return V{__lasx_xvclz_d(v.raw)};
+}
+
+template <class V, HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V), HWY_IF_V_SIZE_V(V, 32)>
+HWY_API V HighestSetBitIndex(V v) {
+ const DFromV<decltype(v)> d;
+ using T = TFromD<decltype(d)>;
+ return BitCast(d, Set(d, T{sizeof(T) * 8 - 1}) - LeadingZeroCount(v));
+}
+
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+} // namespace HWY_NAMESPACE
+} // namespace hwy
+HWY_AFTER_NAMESPACE();
diff --git a/third_party/highway/hwy/ops/loongarch_lsx-inl.h b/third_party/highway/hwy/ops/loongarch_lsx-inl.h
index 035e38b978..8f8ba3c31b 100644
--- a/third_party/highway/hwy/ops/loongarch_lsx-inl.h
+++ b/third_party/highway/hwy/ops/loongarch_lsx-inl.h
@@ -13,4 +13,5942 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// TODO: fill
\ No newline at end of file
+#include <stdio.h>
+
+#ifndef __loongarch_sx
+// If LSX is runtime-dispatched (instead of being part of the baseline), we
+// need to enable it *and* define __loongarch_sx, or the intrinsic header
+// will fail to compile.
+//
+// We cannot simply move lsxintrin.h after HWY_BEFORE_NAMESPACE because
+// doing so may cause the first (and only effective) inclusion of
+// lsxintrin.h to be compiled with both LSX and LASX enabled. Then, when
+// we call the inline functions in the header with only LSX enabled,
+// we'll get an "always_inline function requires lasx but would be inlined
+// into a function that is compiled without support for lasx" error.
+HWY_PUSH_ATTRIBUTES("lsx")
+#define __loongarch_sx
+#include <lsxintrin.h>
+#undef __loongarch_sx
+// Prevent "unused push_attribute" warning from Clang.
+HWY_MAYBE_UNUSED static void HWY_CONCAT(hwy_lsx_dummy, __COUNTER__)() {}
+HWY_POP_ATTRIBUTES
+#else
+#include <lsxintrin.h>
+#endif
+
+#include "third_party/highway/hwy/base.h"
+#include "third_party/highway/hwy/ops/shared-inl.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+namespace detail {
+
+// Enable generic functions for whichever of (f16, bf16) are not supported.
+#define HWY_LSX_IF_EMULATED_D(D) HWY_IF_SPECIAL_FLOAT_D(D)
+
+template <typename T>
+struct Raw128 {
+ using type = __m128i;
+};
+template <>
+struct Raw128<float> {
+ using type = __m128;
+};
+template <>
+struct Raw128<double> {
+ using type = __m128d;
+};
+
+} // namespace detail
+
+template <typename T, size_t N = 16 / sizeof(T)>
+class Vec128 {
+ using Raw = typename detail::Raw128<T>::type;
+
+ public:
+ using PrivateT = T; // only for DFromV
+ static constexpr size_t kPrivateN = N; // only for DFromV
+
+ // Compound assignment. Only usable if there is a corresponding non-member
+ // binary operator overload. For example, only f32 and f64 support division.
+ HWY_INLINE Vec128& operator*=(const Vec128 other) {
+ return *this = (*this * other);
+ }
+ HWY_INLINE Vec128& operator/=(const Vec128 other) {
+ return *this = (*this / other);
+ }
+ HWY_INLINE Vec128& operator+=(const Vec128 other) {
+ return *this = (*this + other);
+ }
+ HWY_INLINE Vec128& operator-=(const Vec128 other) {
+ return *this = (*this - other);
+ }
+ HWY_INLINE Vec128& operator%=(const Vec128 other) {
+ return *this = (*this % other);
+ }
+ HWY_INLINE Vec128& operator&=(const Vec128 other) {
+ return *this = (*this & other);
+ }
+ HWY_INLINE Vec128& operator|=(const Vec128 other) {
+ return *this = (*this | other);
+ }
+ HWY_INLINE Vec128& operator^=(const Vec128 other) {
+ return *this = (*this ^ other);
+ }
+
+ Raw raw;
+};
+
+template <typename T>
+using Vec64 = Vec128<T, 8 / sizeof(T)>;
+
+template <typename T>
+using Vec32 = Vec128<T, 4 / sizeof(T)>;
+
+template <typename T>
+using Vec16 = Vec128<T, 2 / sizeof(T)>;
+
+namespace detail {
+
+template <typename T>
+using RawMask128 = typename Raw128<T>::type;
+
+} // namespace detail
+
+template <typename T, size_t N = 16 / sizeof(T)>
+struct Mask128 {
+ using Raw = typename detail::RawMask128<T>;
+
+ using PrivateT = T; // only for DFromM
+ static constexpr size_t kPrivateN = N; // only for DFromM
+
+ Raw raw;
+};
+
+template <class V>
+using DFromV = Simd<typename V::PrivateT, V::kPrivateN, 0>;
+
+template <class M>
+using DFromM = Simd<typename M::PrivateT, M::kPrivateN, 0>;
+
+template <class V>
+using TFromV = typename V::PrivateT;
+
+// ------------------------------ BitCast
+
+namespace detail {
+
+HWY_INLINE __m128i BitCastToInteger(__m128i v) { return v; }
+HWY_INLINE __m128i BitCastToInteger(__m128 v) {
+ return reinterpret_cast<__m128i>(v);
+}
+HWY_INLINE __m128i BitCastToInteger(__m128d v) {
+ return reinterpret_cast<__m128i>(v);
+}
+
+template <typename T, size_t N>
+HWY_INLINE Vec128<uint8_t, N * sizeof(T)> BitCastToByte(Vec128<T, N> v) {
+ return Vec128<uint8_t, N * sizeof(T)>{BitCastToInteger(v.raw)};
+}
+
+// Cannot rely on function overloading because return types differ.
+template <typename T>
+struct BitCastFromInteger128 {
+ HWY_INLINE __m128i operator()(__m128i v) { return v; }
+};
+template <>
+struct BitCastFromInteger128<float> {
+ HWY_INLINE __m128 operator()(__m128i v) {
+ return reinterpret_cast<__m128>(v);
+ }
+};
+template <>
+struct BitCastFromInteger128<double> {
+ HWY_INLINE __m128d operator()(__m128i v) {
+ return reinterpret_cast<__m128d>(v);
+ }
+};
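+
+// Note (not part of the library): function overloads cannot differ only in
+// return type, hence the functor keyed on T; callers select the conversion
+// via BitCastFromInteger128<T>()(raw).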
+
+} // namespace detail
+
+// ------------------------------ Zero
+
+// Use HWY_MAX_LANES_D here because VFromD is defined in terms of Zero.
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_NOT_FLOAT3264_D(D)>
+HWY_API Vec128<TFromD<D>, HWY_MAX_LANES_D(D)> Zero(D /* tag */) {
+ return Vec128<TFromD<D>, HWY_MAX_LANES_D(D)>{(__lsx_vreplgr2vr_w(0))};
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_FLOAT3264_D(D)>
+HWY_API Vec128<TFromD<D>, HWY_MAX_LANES_D(D)> Zero(D /* tag */) {
+ return Vec128<TFromD<D>, HWY_MAX_LANES_D(D)>{
+ detail::BitCastFromInteger128<TFromD<D>>()(__lsx_vreplgr2vr_w(0))};
+}
+
+template <class D>
+using VFromD = decltype(Zero(D()));
+
+namespace detail {
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16)>
+HWY_INLINE VFromD<D> BitCastFromByte(D /* tag */,
+ Vec128<uint8_t, D().MaxBytes()> v) {
+ return VFromD<D>{BitCastFromInteger128<TFromD<D>>()(v.raw)};
+}
+
+} // namespace detail
+
+template <class D, typename FromT, HWY_IF_V_SIZE_LE_D(D, 16)>
+HWY_API VFromD<D> BitCast(D d,
+ Vec128<FromT, Repartition<FromT, D>().MaxLanes()> v) {
+ return detail::BitCastFromByte(d, detail::BitCastToByte(v));
+}
+
+// ------------------------------ Set
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_T_SIZE_D(D, 1)>
+HWY_API VFromD<D> Set(D /* tag */, TFromD<D> t) {
+ return VFromD<D>{__lsx_vreplgr2vr_b(static_cast<int>(t))};
+}
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_UI16_D(D)>
+HWY_API VFromD<D> Set(D /* tag */, TFromD<D> t) {
+ return VFromD<D>{__lsx_vreplgr2vr_h(static_cast<int>(t))};
+}
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_UI32_D(D)>
+HWY_API VFromD<D> Set(D /* tag */, TFromD<D> t) {
+ return VFromD<D>{__lsx_vreplgr2vr_w(static_cast<int>(t))};
+}
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_UI64_D(D)>
+HWY_API VFromD<D> Set(D /* tag */, TFromD<D> t) {
+ return VFromD<D>{__lsx_vreplgr2vr_d(static_cast<long int>(t))};
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_F32_D(D)>
+HWY_API VFromD<D> Set(D d, float t) {
+ const RebindToSigned<decltype(d)> di;
+ return BitCast(d, VFromD<decltype(di)>{__lsx_vldrepl_w(&t, 0)});
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_F64_D(D)>
+HWY_API VFromD<D> Set(D d, double t) {
+ const RebindToSigned<decltype(d)> di;
+ return BitCast(d, VFromD<decltype(di)>{__lsx_vldrepl_d(&t, 0)});
+}
+
+// Generic for all vector lengths.
+template <class D, HWY_LSX_IF_EMULATED_D(D)>
+HWY_API VFromD<D> Set(D df, TFromD<D> t) {
+ const RebindToUnsigned<decltype(df)> du;
+ static_assert(sizeof(TFromD<D>) == 2, "Expecting [b]f16");
+ uint16_t bits;
+ CopyBytes<2>(&t, &bits);
+ return BitCast(df, Set(du, bits));
+}
+
+// ------------------------------ Undefined
+
+HWY_DIAGNOSTICS(push)
+HWY_DIAGNOSTICS_OFF(disable : 4700, ignored "-Wuninitialized")
+
+// Returns a vector with uninitialized elements.
+template <class D>
+HWY_API VFromD<D> Undefined(D /* tag */) {
+ VFromD<D> v;
+ return v;
+}
+
+HWY_DIAGNOSTICS(pop)
+
+// ------------------------------ GetLane
+
+template <typename T, size_t N, HWY_IF_T_SIZE(T, 1)>
+HWY_API T GetLane(const Vec128<T, N> v) {
+ return static_cast<T>(__lsx_vpickve2gr_b(v.raw, 0));
+}
+template <typename T, size_t N, HWY_IF_T_SIZE(T, 2)>
+HWY_API T GetLane(const Vec128<T, N> v) {
+ return static_cast<T>(__lsx_vpickve2gr_h(v.raw, 0));
+}
+template <typename T, size_t N, HWY_IF_T_SIZE(T, 4)>
+HWY_API T GetLane(const Vec128<T, N> v) {
+ return static_cast<T>(__lsx_vpickve2gr_w(v.raw, 0));
+}
+template <typename T, size_t N, HWY_IF_T_SIZE(T, 8)>
+HWY_API T GetLane(const Vec128<T, N> v) {
+ return static_cast<T>(__lsx_vpickve2gr_d(v.raw, 0));
+}
+template <size_t N>
+HWY_API float GetLane(const Vec128<float, N> v) {
+ float f32;
+ int32_t i32 = __lsx_vpickve2gr_w(reinterpret_cast<__m128i>(v.raw), 0);
+ CopyBytes<4>(&i32, &f32);
+ return f32;
+}
+template <size_t N>
+HWY_API double GetLane(const Vec128<double, N> v) {
+ double f64;
+ int64_t i64 = __lsx_vpickve2gr_d(reinterpret_cast<__m128i>(v.raw), 0);
+ CopyBytes<8>(&i64, &f64);
+ return f64;
+}
+
+// ------------------------------ ResizeBitCast
+
+template <class D, class FromV, HWY_IF_V_SIZE_LE_V(FromV, 16),
+ HWY_IF_V_SIZE_LE_D(D, 16)>
+HWY_API VFromD<D> ResizeBitCast(D d, FromV v) {
+ const Repartition<uint8_t, decltype(d)> du8;
+ return BitCast(d, VFromD<decltype(du8)>{detail::BitCastToInteger(v.raw)});
+}
+
+// ------------------------------ Dup128VecFromValues
+
+template <class D, HWY_IF_UI8_D(D), HWY_IF_V_SIZE_LE_D(D, 16)>
+HWY_API VFromD<D> Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1,
+ TFromD<D> t2, TFromD<D> t3, TFromD<D> t4,
+ TFromD<D> t5, TFromD<D> t6, TFromD<D> t7,
+ TFromD<D> t8, TFromD<D> t9, TFromD<D> t10,
+ TFromD<D> t11, TFromD<D> t12,
+ TFromD<D> t13, TFromD<D> t14,
+ TFromD<D> t15) {
+ typedef int8_t GccI8RawVectType __attribute__((__vector_size__(16)));
+ const GccI8RawVectType raw = {
+ static_cast<int8_t>(t0), static_cast<int8_t>(t1),
+ static_cast<int8_t>(t2), static_cast<int8_t>(t3),
+ static_cast<int8_t>(t4), static_cast<int8_t>(t5),
+ static_cast<int8_t>(t6), static_cast<int8_t>(t7),
+ static_cast<int8_t>(t8), static_cast<int8_t>(t9),
+ static_cast<int8_t>(t10), static_cast<int8_t>(t11),
+ static_cast<int8_t>(t12), static_cast<int8_t>(t13),
+ static_cast<int8_t>(t14), static_cast<int8_t>(t15)};
+ return VFromD<D>{reinterpret_cast<__m128i>(raw)};
+}
+
+template <class D, HWY_IF_UI16_D(D), HWY_IF_V_SIZE_LE_D(D, 16)>
+HWY_API VFromD<D> Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1,
+ TFromD<D> t2, TFromD<D> t3, TFromD<D> t4,
+ TFromD<D> t5, TFromD<D> t6,
+ TFromD<D> t7) {
+ typedef int16_t GccI16RawVectType __attribute__((__vector_size__(16)));
+ const GccI16RawVectType raw = {
+ static_cast<int16_t>(t0), static_cast<int16_t>(t1),
+ static_cast<int16_t>(t2), static_cast<int16_t>(t3),
+ static_cast<int16_t>(t4), static_cast<int16_t>(t5),
+ static_cast<int16_t>(t6), static_cast<int16_t>(t7)};
+ return VFromD<D>{reinterpret_cast<__m128i>(raw)};
+}
+
+template <class D, HWY_IF_SPECIAL_FLOAT_D(D)>
+HWY_API VFromD<D> Dup128VecFromValues(D d, TFromD<D> t0, TFromD<D> t1,
+ TFromD<D> t2, TFromD<D> t3, TFromD<D> t4,
+ TFromD<D> t5, TFromD<D> t6,
+ TFromD<D> t7) {
+ const RebindToSigned<decltype(d)> di;
+ return BitCast(d,
+ Dup128VecFromValues(
+ di, BitCastScalar<int16_t>(t0), BitCastScalar<int16_t>(t1),
+ BitCastScalar<int16_t>(t2), BitCastScalar<int16_t>(t3),
+ BitCastScalar<int16_t>(t4), BitCastScalar<int16_t>(t5),
+ BitCastScalar<int16_t>(t6), BitCastScalar<int16_t>(t7)));
+}
+
+template <class D, HWY_IF_UI32_D(D), HWY_IF_V_SIZE_LE_D(D, 16)>
+HWY_API VFromD<D> Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1,
+ TFromD<D> t2, TFromD<D> t3) {
+ typedef int32_t GccI32RawVectType __attribute__((__vector_size__(16)));
+ const GccI32RawVectType raw = {
+ static_cast<int32_t>(t0), static_cast<int32_t>(t1),
+ static_cast<int32_t>(t2), static_cast<int32_t>(t3)};
+ return VFromD<D>{reinterpret_cast<__m128i>(raw)};
+}
+template <class D, HWY_IF_UI64_D(D), HWY_IF_V_SIZE_LE_D(D, 16)>
+HWY_API VFromD<D> Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1) {
+ typedef int64_t GccI64RawVectType __attribute__((__vector_size__(16)));
+ const GccI64RawVectType raw = {static_cast<int64_t>(t0),
+ static_cast<int64_t>(t1)};
+ return VFromD<D>{reinterpret_cast<__m128i>(raw)};
+}
+template <class D, HWY_IF_F32_D(D), HWY_IF_V_SIZE_LE_D(D, 16)>
+HWY_API VFromD<D> Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1,
+ TFromD<D> t2, TFromD<D> t3) {
+ typedef float GccF32RawVectType __attribute__((__vector_size__(16)));
+ const GccF32RawVectType raw = {t0, t1, t2, t3};
+ return VFromD<D>{reinterpret_cast<__m128>(raw)};
+}
+template <class D, HWY_IF_F64_D(D), HWY_IF_V_SIZE_LE_D(D, 16)>
+HWY_API VFromD<D> Dup128VecFromValues(D /*d*/, TFromD<D> t0, TFromD<D> t1) {
+ typedef double GccF64RawVectType __attribute__((__vector_size__(16)));
+ const GccF64RawVectType raw = {t0, t1};
+ return VFromD<D>{reinterpret_cast<__m128d>(raw)};
+}
+
+// ================================================== LOGICAL
+
+// ------------------------------ And
+
+template <typename T, size_t N>
+HWY_API Vec128<T, N> And(Vec128<T, N> a, Vec128<T, N> b) {
+ const DFromV<decltype(a)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(d, VFromD<decltype(du)>{
+ __lsx_vand_v(BitCast(du, a).raw, BitCast(du, b).raw)});
+}
+
+// ------------------------------ AndNot
+
+// Returns ~not_mask & mask.
+template <typename T, size_t N>
+HWY_API Vec128<T, N> AndNot(Vec128<T, N> not_mask, Vec128<T, N> mask) {
+ const DFromV<decltype(mask)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(d, VFromD<decltype(du)>{__lsx_vandn_v(
+ BitCast(du, not_mask).raw, BitCast(du, mask).raw)});
+}
+
+// ------------------------------ Or
+
+template <typename T, size_t N>
+HWY_API Vec128<T, N> Or(Vec128<T, N> a, Vec128<T, N> b) {
+ const DFromV<decltype(a)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(d, VFromD<decltype(du)>{
+ __lsx_vor_v(BitCast(du, a).raw, BitCast(du, b).raw)});
+}
+
+// ------------------------------ Xor
+
+template <typename T, size_t N>
+HWY_API Vec128<T, N> Xor(Vec128<T, N> a, Vec128<T, N> b) {
+ const DFromV<decltype(a)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(d, VFromD<decltype(du)>{
+ __lsx_vxor_v(BitCast(du, a).raw, BitCast(du, b).raw)});
+}
+
+// ------------------------------ Not
+template <typename T, size_t N>
+HWY_API Vec128<T, N> Not(const Vec128<T, N> v) {
+ const DFromV<decltype(v)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(d, VFromD<decltype(du)>{
+ __lsx_vnor_v(BitCast(du, v).raw, BitCast(du, v).raw)});
+}
+
+// ------------------------------ Or3
+template <typename T, size_t N>
+HWY_API Vec128<T, N> Or3(Vec128<T, N> o1, Vec128<T, N> o2, Vec128<T, N> o3) {
+ return Or(o1, Or(o2, o3));
+}
+
+// ------------------------------ OrAnd
+template <typename T, size_t N>
+HWY_API Vec128<T, N> OrAnd(Vec128<T, N> o, Vec128<T, N> a1, Vec128<T, N> a2) {
+ return Or(o, And(a1, a2));
+}
+
+// ------------------------------ Mask
+
+// Mask and Vec are the same (true = FF..FF).
+template <typename T, size_t N>
+HWY_API Mask128<T, N> MaskFromVec(const Vec128<T, N> v) {
+ return Mask128<T, N>{v.raw};
+}
+
+template <class D>
+using MFromD = decltype(MaskFromVec(VFromD<D>()));
+
+template <typename T, size_t N>
+HWY_API Vec128<T, N> VecFromMask(const Mask128<T, N> v) {
+ return Vec128<T, N>{v.raw};
+}
+
+// Generic for all vector lengths.
+template <class D>
+HWY_API VFromD<D> VecFromMask(D /* tag */, MFromD<D> v) {
+ return VecFromMask(v);
+}
+
+template <typename T, size_t N>
+HWY_API Vec128<T, N> IfThenElse(Mask128<T, N> mask, Vec128<T, N> yes,
+ Vec128<T, N> no) {
+ const DFromV<decltype(yes)> d;
+  const RebindToSigned<decltype(d)> di;
+ return BitCast(d, VFromD<decltype(di)>{__lsx_vbitsel_v(
+ BitCast(di, no).raw, BitCast(di, yes).raw,
+ RebindMask(di, mask).raw)});
+}
+
+// ------------------------------ IfVecThenElse
+template <typename T, size_t N>
+HWY_API Vec128<T, N> IfVecThenElse(Vec128<T, N> mask, Vec128<T, N> yes,
+ Vec128<T, N> no) {
+ return IfThenElse(MaskFromVec(mask), yes, no);
+}
+
+// ------------------------------ BitwiseIfThenElse
+
+#ifdef HWY_NATIVE_BITWISE_IF_THEN_ELSE
+#undef HWY_NATIVE_BITWISE_IF_THEN_ELSE
+#else
+#define HWY_NATIVE_BITWISE_IF_THEN_ELSE
+#endif
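+
+// Note (assumed convention): ops headers are re-included once per target, so
+// rather than simply defining HWY_NATIVE_*, each target toggles it;
+// generic_ops-inl.h compares its defined-state parity against
+// HWY_TARGET_TOGGLE to decide whether to emit the emulated fallback.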
+
+template <class V>
+HWY_API V BitwiseIfThenElse(V mask, V yes, V no) {
+ return IfVecThenElse(mask, yes, no);
+}
+
+// ------------------------------ Operator overloads (internal-only if float)
+
+template <typename T, size_t N>
+HWY_API Vec128<T, N> operator&(const Vec128<T, N> a, const Vec128<T, N> b) {
+ return And(a, b);
+}
+
+template <typename T, size_t N>
+HWY_API Vec128<T, N> operator|(const Vec128<T, N> a, const Vec128<T, N> b) {
+ return Or(a, b);
+}
+
+template <typename T, size_t N>
+HWY_API Vec128<T, N> operator^(const Vec128<T, N> a, const Vec128<T, N> b) {
+ return Xor(a, b);
+}
+
+// ------------------------------ PopulationCount
+
+#ifdef HWY_NATIVE_POPCNT
+#undef HWY_NATIVE_POPCNT
+#else
+#define HWY_NATIVE_POPCNT
+#endif
+
+namespace detail {
+
+template <typename T, size_t N>
+HWY_INLINE Vec128<T, N> PopulationCount(hwy::SizeTag<1> /* tag */,
+ Vec128<T, N> v) {
+ return Vec128<T, N>{__lsx_vpcnt_b(v.raw)};
+}
+template <typename T, size_t N>
+HWY_INLINE Vec128<T, N> PopulationCount(hwy::SizeTag<2> /* tag */,
+ Vec128<T, N> v) {
+ return Vec128<T, N>{__lsx_vpcnt_h(v.raw)};
+}
+template <typename T, size_t N>
+HWY_INLINE Vec128<T, N> PopulationCount(hwy::SizeTag<4> /* tag */,
+ Vec128<T, N> v) {
+ return Vec128<T, N>{__lsx_vpcnt_w(v.raw)};
+}
+template <typename T, size_t N>
+HWY_INLINE Vec128<T, N> PopulationCount(hwy::SizeTag<8> /* tag */,
+ Vec128<T, N> v) {
+ return Vec128<T, N>{__lsx_vpcnt_d(v.raw)};
+}
+
+} // namespace detail
+
+template <typename T, size_t N>
+HWY_API Vec128<T, N> PopulationCount(Vec128<T, N> v) {
+ return detail::PopulationCount(hwy::SizeTag<sizeof(T)>(), v);
+}
+
+// ================================================== SIGN
+
+// ------------------------------ Neg
+
+template <typename T, size_t N, HWY_IF_FLOAT_OR_SPECIAL(T)>
+HWY_API Vec128<T, N> Neg(const Vec128<T, N> v) {
+ return Xor(v, SignBit(DFromV<decltype(v)>()));
+}
+
+template <typename T, size_t N, HWY_IF_UI8(T)>
+HWY_API Vec128<T, N> Neg(const Vec128<T, N> v) {
+ return Vec128<T, N>{__lsx_vneg_b(v.raw)};
+}
+
+template <typename T, size_t N, HWY_IF_UI16(T)>
+HWY_API Vec128<T, N> Neg(const Vec128<T, N> v) {
+ return Vec128<T, N>{__lsx_vneg_h(v.raw)};
+}
+
+template <typename T, size_t N, HWY_IF_UI32(T)>
+HWY_API Vec128<T, N> Neg(const Vec128<T, N> v) {
+ return Vec128<T, N>{__lsx_vneg_w(v.raw)};
+}
+
+template <typename T, size_t N, HWY_IF_UI64(T)>
+HWY_API Vec128<T, N> Neg(const Vec128<T, N> v) {
+ return Vec128<T, N>{__lsx_vneg_d(v.raw)};
+}
+
+// ------------------------------ Floating-point Abs
+// Generic for all vector lengths
+template <class V, HWY_IF_FLOAT(TFromV<V>)>
+HWY_API V Abs(V v) {
+ const DFromV<decltype(v)> d;
+ const RebindToSigned<decltype(d)> di;
+ using TI = TFromD<decltype(di)>;
+ return v & BitCast(d, Set(di, static_cast<TI>(~SignMask<TI>())));
+}
+
+// ------------------------------ CopySign
+// Generic for all vector lengths.
+template <class V>
+HWY_API V CopySign(const V magn, const V sign) {
+ static_assert(IsFloat<TFromV<V>>(), "Only makes sense for floating-point");
+
+ const DFromV<decltype(magn)> d;
+ const auto msb = SignBit(d);
+ return BitwiseIfThenElse(msb, sign, magn);
+}
+
+// ------------------------------ CopySignToAbs
+// Generic for all vector lengths.
+template <class V>
+HWY_API V CopySignToAbs(const V abs, const V sign) {
+ const DFromV<decltype(abs)> d;
+ return OrAnd(abs, SignBit(d), sign);
+}
+
+// ------------------------------ IfThenElseZero
+
+template <typename T, size_t N>
+HWY_API Vec128<T, N> IfThenElseZero(Mask128<T, N> mask, Vec128<T, N> yes) {
+ return yes & VecFromMask(DFromV<decltype(yes)>(), mask);
+}
+
+template <typename T, size_t N>
+HWY_API Vec128<T, N> IfThenZeroElse(Mask128<T, N> mask, Vec128<T, N> no) {
+ return AndNot(VecFromMask(DFromV<decltype(no)>(), mask), no);
+}
+
+// ------------------------------ Mask logical
+
+template <typename T, size_t N>
+HWY_API Mask128<T, N> Not(const Mask128<T, N> m) {
+ const Simd<T, N, 0> d;
+ return MaskFromVec(Not(VecFromMask(d, m)));
+}
+
+template <typename T, size_t N>
+HWY_API Mask128<T, N> And(const Mask128<T, N> a, Mask128<T, N> b) {
+ const Simd<T, N, 0> d;
+ return MaskFromVec(And(VecFromMask(d, a), VecFromMask(d, b)));
+}
+
+template <typename T, size_t N>
+HWY_API Mask128<T, N> AndNot(const Mask128<T, N> a, Mask128<T, N> b) {
+ const Simd<T, N, 0> d;
+ return MaskFromVec(AndNot(VecFromMask(d, a), VecFromMask(d, b)));
+}
+
+template <typename T, size_t N>
+HWY_API Mask128<T, N> Or(const Mask128<T, N> a, Mask128<T, N> b) {
+ const Simd<T, N, 0> d;
+ return MaskFromVec(Or(VecFromMask(d, a), VecFromMask(d, b)));
+}
+
+template <typename T, size_t N>
+HWY_API Mask128<T, N> Xor(const Mask128<T, N> a, Mask128<T, N> b) {
+ const Simd<T, N, 0> d;
+ return MaskFromVec(Xor(VecFromMask(d, a), VecFromMask(d, b)));
+}
+
+// ------------------------------ ExclusiveNeither
+
+template <typename T, size_t N>
+HWY_API Mask128<T, N> ExclusiveNeither(const Mask128<T, N> a, Mask128<T, N> b) {
+ const Simd<T, N, 0> d;
+ return MaskFromVec(AndNot(VecFromMask(d, a), Not(VecFromMask(d, b))));
+}
+
+// ------------------------------ ShiftLeft
+
+template <int kBits, size_t N>
+HWY_API Vec128<uint8_t, N> ShiftLeft(const Vec128<uint8_t, N> v) {
+ return Vec128<uint8_t, N>{__lsx_vslli_b(v.raw, kBits)};
+}
+template <int kBits, size_t N>
+HWY_API Vec128<uint16_t, N> ShiftLeft(const Vec128<uint16_t, N> v) {
+ return Vec128<uint16_t, N>{__lsx_vslli_h(v.raw, kBits)};
+}
+template <int kBits, size_t N>
+HWY_API Vec128<uint32_t, N> ShiftLeft(const Vec128<uint32_t, N> v) {
+ return Vec128<uint32_t, N>{__lsx_vslli_w(v.raw, kBits)};
+}
+template <int kBits, size_t N>
+HWY_API Vec128<uint64_t, N> ShiftLeft(const Vec128<uint64_t, N> v) {
+ return Vec128<uint64_t, N>{__lsx_vslli_d(v.raw, kBits)};
+}
+
+template <int kBits, size_t N>
+HWY_API Vec128<int8_t, N> ShiftLeft(const Vec128<int8_t, N> v) {
+ return Vec128<int8_t, N>{__lsx_vslli_b(v.raw, kBits)};
+}
+template <int kBits, size_t N>
+HWY_API Vec128<int16_t, N> ShiftLeft(const Vec128<int16_t, N> v) {
+ return Vec128<int16_t, N>{__lsx_vslli_h(v.raw, kBits)};
+}
+template <int kBits, size_t N>
+HWY_API Vec128<int32_t, N> ShiftLeft(const Vec128<int32_t, N> v) {
+ return Vec128<int32_t, N>{__lsx_vslli_w(v.raw, kBits)};
+}
+template <int kBits, size_t N>
+HWY_API Vec128<int64_t, N> ShiftLeft(const Vec128<int64_t, N> v) {
+ return Vec128<int64_t, N>{__lsx_vslli_d(v.raw, kBits)};
+}
+
+// ------------------------------ ShiftRight
+
+template <int kBits, size_t N>
+HWY_API Vec128<uint8_t, N> ShiftRight(Vec128<uint8_t, N> v) {
+ return Vec128<uint8_t, N>{__lsx_vsrli_b(v.raw, kBits)};
+}
+template <int kBits, size_t N>
+HWY_API Vec128<uint16_t, N> ShiftRight(Vec128<uint16_t, N> v) {
+ return Vec128<uint16_t, N>{__lsx_vsrli_h(v.raw, kBits)};
+}
+template <int kBits, size_t N>
+HWY_API Vec128<uint32_t, N> ShiftRight(Vec128<uint32_t, N> v) {
+ return Vec128<uint32_t, N>{__lsx_vsrli_w(v.raw, kBits)};
+}
+template <int kBits, size_t N>
+HWY_API Vec128<uint64_t, N> ShiftRight(Vec128<uint64_t, N> v) {
+ return Vec128<uint64_t, N>{__lsx_vsrli_d(v.raw, kBits)};
+}
+
+template <int kBits, size_t N>
+HWY_API Vec128<int8_t, N> ShiftRight(Vec128<int8_t, N> v) {
+ return Vec128<int8_t, N>{__lsx_vsrai_b(v.raw, kBits)};
+}
+template <int kBits, size_t N>
+HWY_API Vec128<int16_t, N> ShiftRight(Vec128<int16_t, N> v) {
+ return Vec128<int16_t, N>{__lsx_vsrai_h(v.raw, kBits)};
+}
+template <int kBits, size_t N>
+HWY_API Vec128<int32_t, N> ShiftRight(Vec128<int32_t, N> v) {
+ return Vec128<int32_t, N>{__lsx_vsrai_w(v.raw, kBits)};
+}
+template <int kBits, size_t N>
+HWY_API Vec128<int64_t, N> ShiftRight(Vec128<int64_t, N> v) {
+ return Vec128<int64_t, N>{__lsx_vsrai_d(v.raw, kBits)};
+}
+
+// ------------------------------ RoundingShiftRight
+
+#ifdef HWY_NATIVE_ROUNDING_SHR
+#undef HWY_NATIVE_ROUNDING_SHR
+#else
+#define HWY_NATIVE_ROUNDING_SHR
+#endif
+
+template <int kBits, size_t N>
+HWY_API Vec128<int8_t, N> RoundingShiftRight(Vec128<int8_t, N> v) {
+ return Vec128<int8_t, N>{__lsx_vsrari_b(v.raw, kBits)};
+}
+template <int kBits, size_t N>
+HWY_API Vec128<int16_t, N> RoundingShiftRight(Vec128<int16_t, N> v) {
+ return Vec128<int16_t, N>{__lsx_vsrari_h(v.raw, kBits)};
+}
+template <int kBits, size_t N>
+HWY_API Vec128<int32_t, N> RoundingShiftRight(Vec128<int32_t, N> v) {
+ return Vec128<int32_t, N>{__lsx_vsrari_w(v.raw, kBits)};
+}
+template <int kBits, size_t N>
+HWY_API Vec128<int64_t, N> RoundingShiftRight(Vec128<int64_t, N> v) {
+ return Vec128<int64_t, N>{__lsx_vsrari_d(v.raw, kBits)};
+}
+
+template <int kBits, size_t N>
+HWY_API Vec128<uint8_t, N> RoundingShiftRight(Vec128<uint8_t, N> v) {
+ return Vec128<uint8_t, N>{__lsx_vsrlri_b(v.raw, kBits)};
+}
+template <int kBits, size_t N>
+HWY_API Vec128<uint16_t, N> RoundingShiftRight(Vec128<uint16_t, N> v) {
+ return Vec128<uint16_t, N>{__lsx_vsrlri_h(v.raw, kBits)};
+}
+template <int kBits, size_t N>
+HWY_API Vec128<uint32_t, N> RoundingShiftRight(Vec128<uint32_t, N> v) {
+ return Vec128<uint32_t, N>{__lsx_vsrlri_w(v.raw, kBits)};
+}
+template <int kBits, size_t N>
+HWY_API Vec128<uint64_t, N> RoundingShiftRight(Vec128<uint64_t, N> v) {
+ return Vec128<uint64_t, N>{__lsx_vsrlri_d(v.raw, kBits)};
+}
+
+// ------------------------------ RoundingShr
+
+template <size_t N>
+HWY_API Vec128<int8_t, N> RoundingShr(Vec128<int8_t, N> v,
+ Vec128<int8_t, N> bits) {
+ return Vec128<int8_t, N>{__lsx_vsrar_b(v.raw, bits.raw)};
+}
+template <size_t N>
+HWY_API Vec128<int16_t, N> RoundingShr(Vec128<int16_t, N> v,
+ Vec128<int16_t, N> bits) {
+ return Vec128<int16_t, N>{__lsx_vsrar_h(v.raw, bits.raw)};
+}
+template <size_t N>
+HWY_API Vec128<int32_t, N> RoundingShr(Vec128<int32_t, N> v,
+ Vec128<int32_t, N> bits) {
+ return Vec128<int32_t, N>{__lsx_vsrar_w(v.raw, bits.raw)};
+}
+template <size_t N>
+HWY_API Vec128<int64_t, N> RoundingShr(Vec128<int64_t, N> v,
+ Vec128<int64_t, N> bits) {
+ return Vec128<int64_t, N>{__lsx_vsrar_d(v.raw, bits.raw)};
+}
+
+template <size_t N>
+HWY_API Vec128<uint8_t, N> RoundingShr(Vec128<uint8_t, N> v,
+ Vec128<uint8_t, N> bits) {
+ return Vec128<uint8_t, N>{__lsx_vsrlr_b(v.raw, bits.raw)};
+}
+template <size_t N>
+HWY_API Vec128<uint16_t, N> RoundingShr(Vec128<uint16_t, N> v,
+ Vec128<uint16_t, N> bits) {
+ return Vec128<uint16_t, N>{__lsx_vsrlr_h(v.raw, bits.raw)};
+}
+template <size_t N>
+HWY_API Vec128<uint32_t, N> RoundingShr(Vec128<uint32_t, N> v,
+ Vec128<uint32_t, N> bits) {
+ return Vec128<uint32_t, N>{__lsx_vsrlr_w(v.raw, bits.raw)};
+}
+template <size_t N>
+HWY_API Vec128<uint64_t, N> RoundingShr(Vec128<uint64_t, N> v,
+ Vec128<uint64_t, N> bits) {
+ return Vec128<uint64_t, N>{__lsx_vsrlr_d(v.raw, bits.raw)};
+}
+
+// ------------------------------ RoundingShiftRightSame (RoundingShr)
+
+template <typename T, size_t N>
+HWY_API Vec128<T, N> RoundingShiftRightSame(const Vec128<T, N> v, int bits) {
+ return RoundingShr(v, Set(DFromV<decltype(v)>(), static_cast<T>(bits)));
+}
+
+// ================================================== MEMORY (1)
+
+// ------------------------------ Load 128
+
+template <class D, HWY_IF_V_SIZE_D(D, 16), typename T = TFromD<D>>
+HWY_API Vec128<T> Load(D d, const T* HWY_RESTRICT aligned) {
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(d, VFromD<decltype(du)>{__lsx_vld(aligned, 0)});
+}
+
+// Partial
+template <class D, HWY_IF_V_SIZE_LE_D(D, 8)>
+HWY_API VFromD<D> Load(D d, const TFromD<D>* HWY_RESTRICT p) {
+ VFromD<D> v;
+ CopyBytes<d.MaxBytes()>(p, &v);
+ return v;
+}
+
+// LoadU == Load
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16)>
+HWY_API VFromD<D> LoadU(D d, const TFromD<D>* HWY_RESTRICT p) {
+ return Load(d, p);
+}
+
+// ------------------------------ MaskedLoad
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16)>
+HWY_API VFromD<D> MaskedLoad(MFromD<D> m, D d,
+ const TFromD<D>* HWY_RESTRICT p) {
+ return IfThenElseZero(m, LoadU(d, p));
+}
+
+// ------------------------------ MaskedLoadOr
+
+template <class D>
+HWY_API VFromD<D> MaskedLoadOr(VFromD<D> v, MFromD<D> m, D d,
+ const TFromD<D>* HWY_RESTRICT p) {
+ return IfThenElse(m, LoadU(d, p), v);
+}
+
+// 128-bit SIMD => nothing to duplicate, same as an unaligned load.
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16)>
+HWY_API VFromD<D> LoadDup128(D d, const TFromD<D>* HWY_RESTRICT p) {
+ return Load(d, p);
+}
+
+// ------------------------------ Store 128
+
+template <class D, HWY_IF_V_SIZE_D(D, 16)>
+HWY_API void Store(VFromD<D> v, D /* tag */, void* HWY_RESTRICT aligned) {
+ __lsx_vst(v.raw, aligned, 0);
+}
+
+// ------------------------------ Store 64
+
+template <class D, HWY_IF_V_SIZE_D(D, 8)>
+HWY_API void Store(VFromD<D> v, D /* tag */, void* HWY_RESTRICT aligned) {
+ __lsx_vstelm_d(v.raw, aligned, 0, 0);
+}
+
+// ------------------------------ Store 32
+
+template <class D, HWY_IF_V_SIZE_D(D, 4)>
+HWY_API void Store(VFromD<D> v, D /* tag */, void* HWY_RESTRICT aligned) {
+ __lsx_vstelm_w(v.raw, aligned, 0, 0);
+}
+
+// ------------------------------ Store 16
+
+template <class D, HWY_IF_V_SIZE_D(D, 2)>
+HWY_API void Store(VFromD<D> v, D /* tag */, void* HWY_RESTRICT aligned) {
+ __lsx_vstelm_h(v.raw, aligned, 0, 0);
+}
+
+// ------------------------------ Store 8
+
+template <class D, HWY_IF_V_SIZE_D(D, 1)>
+HWY_API void Store(VFromD<D> v, D /* tag */, void* HWY_RESTRICT aligned) {
+ __lsx_vstelm_b(v.raw, aligned, 0, 0);
+}
+
+template <class D>
+HWY_API void StoreU(VFromD<D> v, D d, void* HWY_RESTRICT p) {
+ Store(v, d, p);
+}
+
+// ================================================== SWIZZLE (1)
+
+// ------------------------------ TableLookupBytes
+template <typename T, size_t N, typename TI, size_t NI>
+HWY_API Vec128<TI, NI> TableLookupBytes(const Vec128<T, N> bytes,
+ const Vec128<TI, NI> from) {
+ const DFromV<decltype(from)> d;
+ const Repartition<uint8_t, decltype(d)> du8;
+ const DFromV<decltype(bytes)> d_bytes;
+ const Repartition<uint8_t, decltype(d_bytes)> du8_bytes;
+ return BitCast(
+ d, VFromD<decltype(du8)>{__lsx_vshuf_b(BitCast(du8_bytes, bytes).raw,
+ BitCast(du8_bytes, bytes).raw,
+ (BitCast(du8, from).raw))});
+}
+
+// ------------------------------ TableLookupBytesOr0
+template <class V, class VI>
+HWY_API VI TableLookupBytesOr0(const V bytes, const VI from) {
+ const DFromV<VI> d;
+ const Repartition<int8_t, decltype(d)> di8;
+ return BitCast(d,
+ IfThenZeroElse(Lt(BitCast(di8, from), Zero(di8)),
+ BitCast(di8, TableLookupBytes(bytes, from))));
+}
+
+// ------------------------------ Shuffles (ShiftRight, TableLookupBytes)
+
+// Notation: let Vec128<int32_t> have lanes 3,2,1,0 (0 is least-significant).
+// Shuffle0321 rotates one lane to the right (the previous least-significant
+// lane is now most-significant). These could also be implemented via
+// CombineShiftRightBytes but the shuffle_abcd notation is more convenient.
+
+// Swap 32-bit halves in 64-bit halves.
+template <typename T, size_t N>
+HWY_API Vec128<T, N> Shuffle2301(const Vec128<T, N> v) {
+ static_assert(sizeof(T) == 4, "Only for 32-bit lanes");
+ static_assert(N == 2 || N == 4, "Does not make sense for N=1");
+ const DFromV<decltype(v)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(d, VFromD<decltype(du)>{__lsx_vshuf4i_w(
+ detail::BitCastToInteger(v.raw), 0xB1)});
+}
+
+namespace detail {
+
+template <typename T, HWY_IF_T_SIZE(T, 1)>
+HWY_API Vec32<T> ShuffleTwo2301(const Vec32<T> a, const Vec32<T> b) {
+ const int8_t _data_idx[] = {1, 0, 19, 18};
+ __m128i shuffle_idx = __lsx_vld(_data_idx, 0);
+ return Vec32<T>{__lsx_vshuf_b(b.raw, a.raw, shuffle_idx)};
+}
+template <typename T, HWY_IF_T_SIZE(T, 2)>
+HWY_API Vec64<T> ShuffleTwo2301(const Vec64<T> a, const Vec64<T> b) {
+ const int16_t _data_idx[] = {9, 8, 3, 2};
+ __m128i shuffle_idx = __lsx_vld(_data_idx, 0);
+ return Vec64<T>{__lsx_vshuf_h(shuffle_idx, a.raw, b.raw)};
+}
+template <typename T, HWY_IF_T_SIZE(T, 4)>
+HWY_API Vec128<T> ShuffleTwo2301(const Vec128<T> a, const Vec128<T> b) {
+ const DFromV<decltype(a)> d;
+ const RebindToSigned<decltype(d)> di;
+ return BitCast(d, Vec128<int32_t>{__lsx_vpermi_w(BitCast(di, b).raw,
+ BitCast(di, a).raw, 0xB1)});
+}
+
+template <typename T, HWY_IF_T_SIZE(T, 1)>
+HWY_API Vec32<T> ShuffleTwo1230(const Vec32<T> a, const Vec32<T> b) {
+ const int8_t _data_idx[] = {0, 3, 18, 17};
+ __m128i shuffle_idx = __lsx_vld(_data_idx, 0);
+ return Vec32<T>{__lsx_vshuf_b(b.raw, a.raw, shuffle_idx)};
+}
+template <typename T, HWY_IF_T_SIZE(T, 2)>
+HWY_API Vec64<T> ShuffleTwo1230(const Vec64<T> a, const Vec64<T> b) {
+ const int16_t _data_idx[] = {10, 11, 2, 1};
+ __m128i shuffle_idx = __lsx_vld(_data_idx, 0);
+ auto t0 = __lsx_vshuf_h(shuffle_idx, a.raw, b.raw);
+ return Vec64<T>{t0};
+}
+template <typename T, HWY_IF_T_SIZE(T, 4)>
+HWY_API Vec128<T> ShuffleTwo1230(const Vec128<T> a, const Vec128<T> b) {
+ const DFromV<decltype(a)> d;
+ const RebindToSigned<decltype(d)> di;
+ return BitCast(d, Vec128<int32_t>{__lsx_vpermi_w(BitCast(di, b).raw,
+ BitCast(di, a).raw, 0x6C)});
+}
+
+template <typename T, HWY_IF_T_SIZE(T, 1)>
+HWY_API Vec32<T> ShuffleTwo3012(const Vec32<T> a, const Vec32<T> b) {
+ const int8_t _data_idx[] = {2, 1, 16, 19};
+ __m128i shuffle_idx = __lsx_vld(_data_idx, 0);
+ return Vec32<T>{__lsx_vshuf_b(b.raw, a.raw, shuffle_idx)};
+}
+template <typename T, HWY_IF_T_SIZE(T, 2)>
+HWY_API Vec64<T> ShuffleTwo3012(const Vec64<T> a, const Vec64<T> b) {
+ const int16_t _data_idx[] = {8, 9, 0, 3};
+ __m128i shuffle_idx = __lsx_vld(_data_idx, 0);
+ return Vec64<T>{__lsx_vshuf_h(shuffle_idx, a.raw, b.raw)};
+}
+template <typename T, HWY_IF_T_SIZE(T, 4)>
+HWY_API Vec128<T> ShuffleTwo3012(const Vec128<T> a, const Vec128<T> b) {
+ const DFromV<decltype(a)> d;
+ const RebindToSigned<decltype(d)> di;
+ return BitCast(d, Vec128<int32_t>{__lsx_vpermi_w(BitCast(di, b).raw,
+ BitCast(di, a).raw, 0xC6)});
+}
+
+} // namespace detail
+
+// Swap 64-bit halves
+template <typename T, HWY_IF_T_SIZE(T, 4)>
+HWY_API Vec128<T> Shuffle1032(const Vec128<T> v) {
+ const DFromV<decltype(v)> d;
+ return BitCast(d, Vec128<uint32_t>{__lsx_vshuf4i_w(
+ reinterpret_cast<__m128i>(v.raw), 0x4E)});
+}
+HWY_API Vec128<uint64_t> Shuffle01(const Vec128<uint64_t> v) {
+ return Vec128<uint64_t>{__lsx_vshuf4i_w(v.raw, 0x4E)};
+}
+HWY_API Vec128<int64_t> Shuffle01(const Vec128<int64_t> v) {
+ return Vec128<int64_t>{__lsx_vshuf4i_w(v.raw, 0x4E)};
+}
+HWY_API Vec128<double> Shuffle01(const Vec128<double> v) {
+ const DFromV<decltype(v)> d;
+ return BitCast(d, Vec128<uint64_t>{__lsx_vshuf4i_d(
+ reinterpret_cast<__m128i>(v.raw),
+ reinterpret_cast<__m128i>(v.raw), 0x1)});
+}
+
+// Rotate right 32 bits
+template <typename T, HWY_IF_T_SIZE(T, 4)>
+HWY_API Vec128<T> Shuffle0321(const Vec128<T> v) {
+ const DFromV<decltype(v)> d;
+ return BitCast(d, Vec128<uint32_t>{__lsx_vshuf4i_w(
+ reinterpret_cast<__m128i>(v.raw), 0x39)});
+}
+// Rotate left 32 bits
+template <typename T, HWY_IF_T_SIZE(T, 4)>
+HWY_API Vec128<T> Shuffle2103(const Vec128<T> v) {
+ const DFromV<decltype(v)> d;
+ return BitCast(d, Vec128<uint32_t>{__lsx_vshuf4i_w(
+ reinterpret_cast<__m128i>(v.raw), 0x93)});
+}
+// Reverse
+template <typename T, HWY_IF_T_SIZE(T, 4)>
+HWY_API Vec128<T> Shuffle0123(const Vec128<T> v) {
+ const DFromV<decltype(v)> d;
+ return BitCast(d, Vec128<uint32_t>{__lsx_vshuf4i_w(
+ reinterpret_cast<__m128i>(v.raw), 0x1B)});
+}
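+
+// Illustrative note (not part of the library): each 2-bit field of the
+// vshuf4i_w immediate selects a source lane, LSB field first. E.g.
+// 0x1B = 0b00'01'10'11 picks lanes 3,2,1,0 (reverse), 0x4E = 0b01'00'11'10
+// picks 2,3,0,1 (swap 64-bit halves), and 0x39/0x93 rotate right/left.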
+
+// Comparisons fill a lane with 1-bits if the condition is true, else 0.
+
+template <class DTo, typename TFrom, size_t NFrom, HWY_IF_V_SIZE_LE_D(DTo, 16)>
+HWY_API MFromD<DTo> RebindMask(DTo dto, Mask128<TFrom, NFrom> m) {
+ static_assert(sizeof(TFrom) == sizeof(TFromD<DTo>), "Must have same size");
+ const Simd<TFrom, NFrom, 0> d;
+ return MaskFromVec(BitCast(dto, VecFromMask(d, m)));
+}
+
+// ================================================== COMPARE
+
+template <typename T, size_t N>
+HWY_API Mask128<T, N> TestBit(Vec128<T, N> v, Vec128<T, N> bit) {
+ static_assert(!hwy::IsFloat<T>(), "Only integer vectors supported");
+ return (v & bit) == bit;
+}
+
+// ------------------------------ Equality
+
+// Unsigned
+template <size_t N>
+HWY_API Mask128<uint8_t, N> operator==(Vec128<uint8_t, N> a,
+ Vec128<uint8_t, N> b) {
+ return Mask128<uint8_t, N>{__lsx_vseq_b(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Mask128<uint16_t, N> operator==(Vec128<uint16_t, N> a,
+ Vec128<uint16_t, N> b) {
+ return Mask128<uint16_t, N>{__lsx_vseq_h(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Mask128<uint32_t, N> operator==(Vec128<uint32_t, N> a,
+ Vec128<uint32_t, N> b) {
+ return Mask128<uint32_t, N>{__lsx_vseq_w(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Mask128<uint64_t, N> operator==(const Vec128<uint64_t, N> a,
+ const Vec128<uint64_t, N> b) {
+ return Mask128<uint64_t, N>{__lsx_vseq_d(a.raw, b.raw)};
+}
+
+// Signed
+template <size_t N>
+HWY_API Mask128<int8_t, N> operator==(Vec128<int8_t, N> a,
+ Vec128<int8_t, N> b) {
+ return Mask128<int8_t, N>{__lsx_vseq_b(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Mask128<int16_t, N> operator==(Vec128<int16_t, N> a,
+ Vec128<int16_t, N> b) {
+ return Mask128<int16_t, N>{__lsx_vseq_h(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Mask128<int32_t, N> operator==(Vec128<int32_t, N> a,
+ Vec128<int32_t, N> b) {
+ return Mask128<int32_t, N>{__lsx_vseq_w(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Mask128<int64_t, N> operator==(const Vec128<int64_t, N> a,
+ const Vec128<int64_t, N> b) {
+ return Mask128<int64_t, N>{__lsx_vseq_d(a.raw, b.raw)};
+}
+
+// Float
+template <size_t N>
+HWY_API Mask128<float, N> operator==(Vec128<float, N> a, Vec128<float, N> b) {
+ return Mask128<float, N>{
+ reinterpret_cast<__m128>(__lsx_vfcmp_ceq_s(a.raw, b.raw))};
+}
+template <size_t N>
+HWY_API Mask128<double, N> operator==(Vec128<double, N> a,
+ Vec128<double, N> b) {
+ return Mask128<double, N>{
+ reinterpret_cast<__m128d>(__lsx_vfcmp_ceq_d(a.raw, b.raw))};
+}
+
+// ------------------------------ Inequality
+
+// This cannot have T as a template argument, otherwise it is not more
+// specialized than the rewritten operator== in C++20, leading to compile
+// errors: https://gcc.godbolt.org/z/xsrPhPvPT.
+template <size_t N>
+HWY_API Mask128<uint8_t, N> operator!=(Vec128<uint8_t, N> a,
+ Vec128<uint8_t, N> b) {
+ return Not(a == b);
+}
+template <size_t N>
+HWY_API Mask128<uint16_t, N> operator!=(Vec128<uint16_t, N> a,
+ Vec128<uint16_t, N> b) {
+ return Not(a == b);
+}
+template <size_t N>
+HWY_API Mask128<uint32_t, N> operator!=(Vec128<uint32_t, N> a,
+ Vec128<uint32_t, N> b) {
+ return Not(a == b);
+}
+template <size_t N>
+HWY_API Mask128<uint64_t, N> operator!=(Vec128<uint64_t, N> a,
+ Vec128<uint64_t, N> b) {
+ return Not(a == b);
+}
+template <size_t N>
+HWY_API Mask128<int8_t, N> operator!=(Vec128<int8_t, N> a,
+ Vec128<int8_t, N> b) {
+ return Not(a == b);
+}
+template <size_t N>
+HWY_API Mask128<int16_t, N> operator!=(Vec128<int16_t, N> a,
+ Vec128<int16_t, N> b) {
+ return Not(a == b);
+}
+template <size_t N>
+HWY_API Mask128<int32_t, N> operator!=(Vec128<int32_t, N> a,
+ Vec128<int32_t, N> b) {
+ return Not(a == b);
+}
+template <size_t N>
+HWY_API Mask128<int64_t, N> operator!=(Vec128<int64_t, N> a,
+ Vec128<int64_t, N> b) {
+ return Not(a == b);
+}
+
+template <size_t N>
+HWY_API Mask128<float, N> operator!=(Vec128<float, N> a, Vec128<float, N> b) {
+ return Mask128<float, N>{
+ reinterpret_cast<__m128>(__lsx_vfcmp_cune_s(a.raw, b.raw))};
+}
+template <size_t N>
+HWY_API Mask128<double, N> operator!=(Vec128<double, N> a,
+ Vec128<double, N> b) {
+ return Mask128<double, N>{
+ reinterpret_cast<__m128d>(__lsx_vfcmp_cune_d(a.raw, b.raw))};
+}
+
+// ------------------------------ Strict inequality
+
+namespace detail {
+
+template <size_t N>
+HWY_INLINE Mask128<int8_t, N> Gt(hwy::SignedTag /*tag*/, Vec128<int8_t, N> a,
+ Vec128<int8_t, N> b) {
+ return Mask128<int8_t, N>{__lsx_vslt_b(b.raw, a.raw)};
+}
+template <size_t N>
+HWY_INLINE Mask128<int16_t, N> Gt(hwy::SignedTag /*tag*/, Vec128<int16_t, N> a,
+ Vec128<int16_t, N> b) {
+ return Mask128<int16_t, N>{__lsx_vslt_h(b.raw, a.raw)};
+}
+template <size_t N>
+HWY_INLINE Mask128<int32_t, N> Gt(hwy::SignedTag /*tag*/, Vec128<int32_t, N> a,
+ Vec128<int32_t, N> b) {
+ return Mask128<int32_t, N>{__lsx_vslt_w(b.raw, a.raw)};
+}
+template <size_t N>
+HWY_INLINE Mask128<int64_t, N> Gt(hwy::SignedTag /*tag*/,
+ const Vec128<int64_t, N> a,
+ const Vec128<int64_t, N> b) {
+ return Mask128<int64_t, N>{__lsx_vslt_d(b.raw, a.raw)};
+}
+
+template <size_t N>
+HWY_INLINE Mask128<uint8_t, N> Gt(hwy::SignedTag /*tag*/, Vec128<uint8_t, N> a,
+ Vec128<uint8_t, N> b) {
+ return Mask128<uint8_t, N>{__lsx_vslt_b(b.raw, a.raw)};
+}
+template <size_t N>
+HWY_INLINE Mask128<uint16_t, N> Gt(hwy::SignedTag /*tag*/,
+ Vec128<uint16_t, N> a,
+ Vec128<uint16_t, N> b) {
+ return Mask128<uint16_t, N>{__lsx_vslt_h(b.raw, a.raw)};
+}
+template <size_t N>
+HWY_INLINE Mask128<uint32_t, N> Gt(hwy::SignedTag /*tag*/,
+ Vec128<uint32_t, N> a,
+ Vec128<uint32_t, N> b) {
+ return Mask128<uint32_t, N>{__lsx_vslt_w(b.raw, a.raw)};
+}
+template <size_t N>
+HWY_INLINE Mask128<uint64_t, N> Gt(hwy::SignedTag /*tag*/,
+ const Vec128<uint64_t, N> a,
+ const Vec128<uint64_t, N> b) {
+ return Mask128<uint64_t, N>{__lsx_vslt_d(b.raw, a.raw)};
+}
+
+template <typename T, size_t N>
+HWY_INLINE Mask128<T, N> Gt(hwy::UnsignedTag /*tag*/, Vec128<T, N> a,
+ Vec128<T, N> b) {
+ const DFromV<decltype(a)> du;
+ const RebindToSigned<decltype(du)> di;
+ const Vec128<T, N> msb = Set(du, (LimitsMax<T>() >> 1) + 1);
+ const auto sa = BitCast(di, Xor(a, msb));
+ const auto sb = BitCast(di, Xor(b, msb));
+ return RebindMask(du, Gt(hwy::SignedTag(), sa, sb));
+}
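+
+// Illustrative note (not part of the library): XOR-ing with the MSB maps
+// unsigned order onto signed order. E.g. for uint8_t, a = 200 and b = 100
+// become 200^0x80 = 72 and 100^0x80 = -28 (as int8_t), and 72 > -28
+// preserves 200 > 100.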
+
+template <size_t N>
+HWY_INLINE Mask128<float, N> Gt(hwy::FloatTag /*tag*/, Vec128<float, N> a,
+ Vec128<float, N> b) {
+ return Mask128<float, N>{
+ reinterpret_cast<__m128>(__lsx_vfcmp_clt_s(b.raw, a.raw))};
+}
+template <size_t N>
+HWY_INLINE Mask128<double, N> Gt(hwy::FloatTag /*tag*/, Vec128<double, N> a,
+ Vec128<double, N> b) {
+ return Mask128<double, N>{
+ reinterpret_cast<__m128d>(__lsx_vfcmp_clt_d(b.raw, a.raw))};
+}
+
+} // namespace detail
+
+template <typename T, size_t N>
+HWY_INLINE Mask128<T, N> operator>(Vec128<T, N> a, Vec128<T, N> b) {
+ return detail::Gt(hwy::TypeTag<T>(), a, b);
+}
+
+// ------------------------------ Weak inequality
+
+namespace detail {
+template <typename T, size_t N>
+HWY_INLINE Mask128<T, N> Ge(hwy::SignedTag tag, Vec128<T, N> a,
+ Vec128<T, N> b) {
+ return Not(Gt(tag, b, a));
+}
+
+template <typename T, size_t N>
+HWY_INLINE Mask128<T, N> Ge(hwy::UnsignedTag tag, Vec128<T, N> a,
+ Vec128<T, N> b) {
+ return Not(Gt(tag, b, a));
+}
+
+template <size_t N>
+HWY_INLINE Mask128<float, N> Ge(hwy::FloatTag /*tag*/, Vec128<float, N> a,
+ Vec128<float, N> b) {
+ return Mask128<float, N>{
+ reinterpret_cast<__m128>(__lsx_vfcmp_cle_s(b.raw, a.raw))};
+}
+template <size_t N>
+HWY_INLINE Mask128<double, N> Ge(hwy::FloatTag /*tag*/, Vec128<double, N> a,
+ Vec128<double, N> b) {
+ return Mask128<double, N>{
+ reinterpret_cast<__m128d>(__lsx_vfcmp_cle_d(b.raw, a.raw))};
+}
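+
+// Illustrative note (not part of the library): unlike the integer cases,
+// float Ge cannot be Not(Gt(b, a)) because that would return true when
+// either input is NaN; the ordered cle comparison returns false instead,
+// matching IEEE-754 >= semantics.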
+
+} // namespace detail
+
+template <typename T, size_t N>
+HWY_API Mask128<T, N> operator>=(Vec128<T, N> a, Vec128<T, N> b) {
+ return detail::Ge(hwy::TypeTag<T>(), a, b);
+}
+
+// ------------------------------ Reversed comparisons
+
+template <typename T, size_t N>
+HWY_API Mask128<T, N> operator<(Vec128<T, N> a, Vec128<T, N> b) {
+ return b > a;
+}
+
+template <typename T, size_t N>
+HWY_API Mask128<T, N> operator<=(Vec128<T, N> a, Vec128<T, N> b) {
+ return b >= a;
+}
+
+// ------------------------------ Iota (Load)
+
+namespace detail {
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_T_SIZE_D(D, 1)>
+HWY_INLINE VFromD<D> Iota0(D d) {
+ return Dup128VecFromValues(
+ d, TFromD<D>{0}, TFromD<D>{1}, TFromD<D>{2}, TFromD<D>{3}, TFromD<D>{4},
+ TFromD<D>{5}, TFromD<D>{6}, TFromD<D>{7}, TFromD<D>{8}, TFromD<D>{9},
+ TFromD<D>{10}, TFromD<D>{11}, TFromD<D>{12}, TFromD<D>{13}, TFromD<D>{14},
+ TFromD<D>{15});
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_UI16_D(D)>
+HWY_INLINE VFromD<D> Iota0(D d) {
+ return Dup128VecFromValues(d, TFromD<D>{0}, TFromD<D>{1}, TFromD<D>{2},
+ TFromD<D>{3}, TFromD<D>{4}, TFromD<D>{5},
+ TFromD<D>{6}, TFromD<D>{7});
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_T_SIZE_D(D, 4)>
+HWY_INLINE VFromD<D> Iota0(D d) {
+ return Dup128VecFromValues(
+ d, static_cast<TFromD<D>>(0), static_cast<TFromD<D>>(1),
+ static_cast<TFromD<D>>(2), static_cast<TFromD<D>>(3));
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_T_SIZE_D(D, 8)>
+HWY_INLINE VFromD<D> Iota0(D d) {
+ return Dup128VecFromValues(d, static_cast<TFromD<D>>(0),
+ static_cast<TFromD<D>>(1));
+}
+
+} // namespace detail
+
+template <class D, typename T2, HWY_IF_V_SIZE_LE_D(D, 16)>
+HWY_API VFromD<D> Iota(D d, const T2 first) {
+ const auto result_iota =
+ detail::Iota0(d) + Set(d, static_cast<TFromD<D>>(first));
+ return result_iota;
+}
+
+// ------------------------------ FirstN (Iota, Lt)
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16)>
+HWY_API MFromD<D> FirstN(D d, size_t num) {
+ const RebindToSigned<decltype(d)> di; // Signed comparisons are cheaper.
+ using TI = TFromD<decltype(di)>;
+ return RebindMask(d, detail::Iota0(di) < Set(di, static_cast<TI>(num)));
+}
+
+// ------------------------------ InterleaveLower
+
+// Interleaves lanes from halves of the 128-bit blocks of "a" (which provides
+// the least-significant lane) and "b". To concatenate two half-width integers
+// into one, use ZipLower/Upper instead (also works with scalar).
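+//
+// Illustrative example (not part of the library): with 32-bit lanes
+// a = {a3,a2,a1,a0} and b = {b3,b2,b1,b0} (MSB first), InterleaveLower
+// returns {b1,a1,b0,a0}.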
+
+template <typename T, size_t N, HWY_IF_T_SIZE(T, 1)>
+HWY_API Vec128<T, N> InterleaveLower(Vec128<T, N> a, Vec128<T, N> b) {
+ return Vec128<T, N>{__lsx_vilvl_b(b.raw, a.raw)};
+}
+template <typename T, size_t N, HWY_IF_T_SIZE(T, 2)>
+HWY_API Vec128<T, N> InterleaveLower(Vec128<T, N> a, Vec128<T, N> b) {
+ return Vec128<T, N>{__lsx_vilvl_h(b.raw, a.raw)};
+}
+template <typename T, size_t N, HWY_IF_UI32(T)>
+HWY_API Vec128<T, N> InterleaveLower(Vec128<T, N> a, Vec128<T, N> b) {
+ return Vec128<T, N>{__lsx_vilvl_w(b.raw, a.raw)};
+}
+template <typename T, size_t N, HWY_IF_UI64(T)>
+HWY_API Vec128<T, N> InterleaveLower(Vec128<T, N> a, Vec128<T, N> b) {
+ return Vec128<T, N>{__lsx_vilvl_d(b.raw, a.raw)};
+}
+
+template <size_t N>
+HWY_API Vec128<float, N> InterleaveLower(Vec128<float, N> a,
+ Vec128<float, N> b) {
+ return Vec128<float, N>{reinterpret_cast<__m128>(__lsx_vilvl_w(
+ reinterpret_cast<__m128i>(b.raw), reinterpret_cast<__m128i>(a.raw)))};
+}
+template <size_t N>
+HWY_API Vec128<double, N> InterleaveLower(Vec128<double, N> a,
+ Vec128<double, N> b) {
+ return Vec128<double, N>{reinterpret_cast<__m128d>(__lsx_vilvl_d(
+ reinterpret_cast<__m128i>(b.raw), reinterpret_cast<__m128i>(a.raw)))};
+}
+
+// Generic for all vector lengths.
+template <class D>
+HWY_API VFromD<D> InterleaveLower(D /* tag */, VFromD<D> a, VFromD<D> b) {
+ return InterleaveLower(a, b);
+}
+
+// ------------------------------ BlendedStore
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16)>
+HWY_API void BlendedStore(VFromD<D> v, MFromD<D> m, D d,
+ TFromD<D>* HWY_RESTRICT p) {
+ StoreU(IfThenElse(m, v, LoadU(d, p)), d, p);
+}
+
+// ================================================== ARITHMETIC
+
+// ------------------------------ Addition
+
+// Unsigned
+template <size_t N>
+HWY_API Vec128<uint8_t, N> operator+(const Vec128<uint8_t, N> a,
+ const Vec128<uint8_t, N> b) {
+ return Vec128<uint8_t, N>{__lsx_vadd_b(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<uint16_t, N> operator+(const Vec128<uint16_t, N> a,
+ const Vec128<uint16_t, N> b) {
+ return Vec128<uint16_t, N>{__lsx_vadd_h(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<uint32_t, N> operator+(const Vec128<uint32_t, N> a,
+ const Vec128<uint32_t, N> b) {
+ return Vec128<uint32_t, N>{__lsx_vadd_w(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<uint64_t, N> operator+(const Vec128<uint64_t, N> a,
+ const Vec128<uint64_t, N> b) {
+ return Vec128<uint64_t, N>{__lsx_vadd_d(a.raw, b.raw)};
+}
+
+// Signed
+template <size_t N>
+HWY_API Vec128<int8_t, N> operator+(const Vec128<int8_t, N> a,
+ const Vec128<int8_t, N> b) {
+ return Vec128<int8_t, N>{__lsx_vadd_b(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<int16_t, N> operator+(const Vec128<int16_t, N> a,
+ const Vec128<int16_t, N> b) {
+ return Vec128<int16_t, N>{__lsx_vadd_h(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<int32_t, N> operator+(const Vec128<int32_t, N> a,
+ const Vec128<int32_t, N> b) {
+ return Vec128<int32_t, N>{__lsx_vadd_w(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<int64_t, N> operator+(const Vec128<int64_t, N> a,
+ const Vec128<int64_t, N> b) {
+ return Vec128<int64_t, N>{__lsx_vadd_d(a.raw, b.raw)};
+}
+
+template <size_t N>
+HWY_API Vec128<float, N> operator+(const Vec128<float, N> a,
+ const Vec128<float, N> b) {
+ return Vec128<float, N>{__lsx_vfadd_s(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<double, N> operator+(const Vec128<double, N> a,
+ const Vec128<double, N> b) {
+ return Vec128<double, N>{__lsx_vfadd_d(a.raw, b.raw)};
+}
+
+// ------------------------------ Subtraction
+
+// Unsigned
+template <size_t N>
+HWY_API Vec128<uint8_t, N> operator-(const Vec128<uint8_t, N> a,
+ const Vec128<uint8_t, N> b) {
+ return Vec128<uint8_t, N>{__lsx_vsub_b(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<uint16_t, N> operator-(Vec128<uint16_t, N> a,
+ Vec128<uint16_t, N> b) {
+ return Vec128<uint16_t, N>{__lsx_vsub_h(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<uint32_t, N> operator-(const Vec128<uint32_t, N> a,
+ const Vec128<uint32_t, N> b) {
+ return Vec128<uint32_t, N>{__lsx_vsub_w(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<uint64_t, N> operator-(const Vec128<uint64_t, N> a,
+ const Vec128<uint64_t, N> b) {
+ return Vec128<uint64_t, N>{__lsx_vsub_d(a.raw, b.raw)};
+}
+
+// Signed
+template <size_t N>
+HWY_API Vec128<int8_t, N> operator-(const Vec128<int8_t, N> a,
+ const Vec128<int8_t, N> b) {
+ return Vec128<int8_t, N>{__lsx_vsub_b(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<int16_t, N> operator-(const Vec128<int16_t, N> a,
+ const Vec128<int16_t, N> b) {
+ return Vec128<int16_t, N>{__lsx_vsub_h(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<int32_t, N> operator-(const Vec128<int32_t, N> a,
+ const Vec128<int32_t, N> b) {
+ return Vec128<int32_t, N>{__lsx_vsub_w(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<int64_t, N> operator-(const Vec128<int64_t, N> a,
+ const Vec128<int64_t, N> b) {
+ return Vec128<int64_t, N>{__lsx_vsub_d(a.raw, b.raw)};
+}
+
+template <size_t N>
+HWY_API Vec128<float, N> operator-(const Vec128<float, N> a,
+ const Vec128<float, N> b) {
+ return Vec128<float, N>{__lsx_vfsub_s(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<double, N> operator-(const Vec128<double, N> a,
+ const Vec128<double, N> b) {
+ return Vec128<double, N>{__lsx_vfsub_d(a.raw, b.raw)};
+}
+
+// ------------------------------ SumsOf2
+namespace detail {
+
+template <class V, HWY_IF_V_SIZE_LE_V(V, 16)>
+HWY_INLINE VFromD<RepartitionToWide<DFromV<V>>> SumsOf2(
+ hwy::SignedTag, hwy::SizeTag<1> /*lane_size_tag*/, V v) {
+ return VFromD<RepartitionToWide<DFromV<V>>>{__lsx_vhaddw_h_b(v.raw, v.raw)};
+}
+template <class V, HWY_IF_V_SIZE_LE_V(V, 16)>
+HWY_INLINE VFromD<RepartitionToWide<DFromV<V>>> SumsOf2(
+ hwy::UnsignedTag, hwy::SizeTag<1> /*lane_size_tag*/, V v) {
+ return VFromD<RepartitionToWide<DFromV<V>>>{__lsx_vhaddw_hu_bu(v.raw, v.raw)};
+}
+template <class V, HWY_IF_V_SIZE_LE_V(V, 16)>
+HWY_INLINE VFromD<RepartitionToWide<DFromV<V>>> SumsOf2(
+ hwy::SignedTag, hwy::SizeTag<2> /*lane_size_tag*/, V v) {
+ return VFromD<RepartitionToWide<DFromV<V>>>{__lsx_vhaddw_w_h(v.raw, v.raw)};
+}
+template <class V, HWY_IF_V_SIZE_LE_V(V, 16)>
+HWY_INLINE VFromD<RepartitionToWide<DFromV<V>>> SumsOf2(
+ hwy::UnsignedTag, hwy::SizeTag<2> /*lane_size_tag*/, V v) {
+ return VFromD<RepartitionToWide<DFromV<V>>>{__lsx_vhaddw_wu_hu(v.raw, v.raw)};
+}
+template <class V, HWY_IF_V_SIZE_LE_V(V, 16)>
+HWY_INLINE VFromD<RepartitionToWide<DFromV<V>>> SumsOf2(
+ hwy::SignedTag, hwy::SizeTag<4> /*lane_size_tag*/, V v) {
+ return VFromD<RepartitionToWide<DFromV<V>>>{__lsx_vhaddw_d_w(v.raw, v.raw)};
+}
+template <class V, HWY_IF_V_SIZE_LE_V(V, 16)>
+HWY_INLINE VFromD<RepartitionToWide<DFromV<V>>> SumsOf2(
+ hwy::UnsignedTag, hwy::SizeTag<4> /*lane_size_tag*/, V v) {
+ return VFromD<RepartitionToWide<DFromV<V>>>{__lsx_vhaddw_du_wu(v.raw, v.raw)};
+}
+
+} // namespace detail
+
+// ------------------------------ SumsOf8
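+// Sums each group of 8 consecutive byte lanes into one 64-bit lane via three
+// pairwise widening horizontal adds: byte pairs -> 16-bit, 16-bit pairs ->
+// 32-bit, 32-bit pairs -> 64-bit.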
+template <size_t N>
+HWY_API Vec128<uint64_t, N / 8> SumsOf8(const Vec128<uint8_t, N> v) {
+ __m128i temp = __lsx_vhaddw_hu_bu(v.raw, v.raw);
+ temp = __lsx_vhaddw_wu_hu(temp, temp);
+ return Vec128<uint64_t, N / 8>{__lsx_vhaddw_du_wu(temp, temp)};
+}
+template <size_t N>
+HWY_API Vec128<int64_t, N / 8> SumsOf8(const Vec128<int8_t, N> v) {
+ __m128i temp = __lsx_vhaddw_h_b(v.raw, v.raw);
+ temp = __lsx_vhaddw_w_h(temp, temp);
+ return Vec128<int64_t, N / 8>{__lsx_vhaddw_d_w(temp, temp)};
+}
+
+// ------------------------------ SaturatedAdd
+
+// Returns a + b clamped to the destination range.
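+// For example, with uint8_t lanes 200 + 100 yields 255, and with int8_t
+// lanes 100 + 100 yields 127 instead of wrapping.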
+
+#ifdef HWY_NATIVE_I32_SATURATED_ADDSUB
+#undef HWY_NATIVE_I32_SATURATED_ADDSUB
+#else
+#define HWY_NATIVE_I32_SATURATED_ADDSUB
+#endif
+
+#ifdef HWY_NATIVE_I64_SATURATED_ADDSUB
+#undef HWY_NATIVE_I64_SATURATED_ADDSUB
+#else
+#define HWY_NATIVE_I64_SATURATED_ADDSUB
+#endif
+
+#ifdef HWY_NATIVE_U32_SATURATED_ADDSUB
+#undef HWY_NATIVE_U32_SATURATED_ADDSUB
+#else
+#define HWY_NATIVE_U32_SATURATED_ADDSUB
+#endif
+
+#ifdef HWY_NATIVE_U64_SATURATED_ADDSUB
+#undef HWY_NATIVE_U64_SATURATED_ADDSUB
+#else
+#define HWY_NATIVE_U64_SATURATED_ADDSUB
+#endif
+
+// Unsigned
+template <size_t N>
+HWY_API Vec128<uint8_t, N> SaturatedAdd(const Vec128<uint8_t, N> a,
+ const Vec128<uint8_t, N> b) {
+ return Vec128<uint8_t, N>{__lsx_vsadd_bu(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<uint16_t, N> SaturatedAdd(const Vec128<uint16_t, N> a,
+ const Vec128<uint16_t, N> b) {
+ return Vec128<uint16_t, N>{__lsx_vsadd_hu(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<uint32_t, N> SaturatedAdd(const Vec128<uint32_t, N> a,
+ const Vec128<uint32_t, N> b) {
+ return Vec128<uint32_t, N>{__lsx_vsadd_wu(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<uint64_t, N> SaturatedAdd(const Vec128<uint64_t, N> a,
+ const Vec128<uint64_t, N> b) {
+ return Vec128<uint64_t, N>{__lsx_vsadd_du(a.raw, b.raw)};
+}
+
+// Signed
+template <size_t N>
+HWY_API Vec128<int8_t, N> SaturatedAdd(const Vec128<int8_t, N> a,
+ const Vec128<int8_t, N> b) {
+ return Vec128<int8_t, N>{__lsx_vsadd_b(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<int16_t, N> SaturatedAdd(const Vec128<int16_t, N> a,
+ const Vec128<int16_t, N> b) {
+ return Vec128<int16_t, N>{__lsx_vsadd_h(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<int32_t, N> SaturatedAdd(const Vec128<int32_t, N> a,
+ const Vec128<int32_t, N> b) {
+ return Vec128<int32_t, N>{__lsx_vsadd_w(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<int64_t, N> SaturatedAdd(const Vec128<int64_t, N> a,
+ const Vec128<int64_t, N> b) {
+ return Vec128<int64_t, N>{__lsx_vsadd_d(a.raw, b.raw)};
+}
+
+// ------------------------------ SaturatedSub
+
+// Returns a - b clamped to the destination range.
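+// For example, with uint8_t lanes 5 - 10 yields 0, and with int8_t lanes
+// -100 - 100 yields -128 instead of wrapping.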
+
+// Unsigned
+template <size_t N>
+HWY_API Vec128<uint8_t, N> SaturatedSub(const Vec128<uint8_t, N> a,
+ const Vec128<uint8_t, N> b) {
+ return Vec128<uint8_t, N>{__lsx_vssub_bu(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<uint16_t, N> SaturatedSub(const Vec128<uint16_t, N> a,
+ const Vec128<uint16_t, N> b) {
+ return Vec128<uint16_t, N>{__lsx_vssub_hu(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<uint32_t, N> SaturatedSub(const Vec128<uint32_t, N> a,
+ const Vec128<uint32_t, N> b) {
+ return Vec128<uint32_t, N>{__lsx_vssub_wu(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<uint64_t, N> SaturatedSub(const Vec128<uint64_t, N> a,
+ const Vec128<uint64_t, N> b) {
+ return Vec128<uint64_t, N>{__lsx_vssub_du(a.raw, b.raw)};
+}
+
+// Signed
+template <size_t N>
+HWY_API Vec128<int8_t, N> SaturatedSub(const Vec128<int8_t, N> a,
+ const Vec128<int8_t, N> b) {
+ return Vec128<int8_t, N>{__lsx_vssub_b(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<int16_t, N> SaturatedSub(const Vec128<int16_t, N> a,
+ const Vec128<int16_t, N> b) {
+ return Vec128<int16_t, N>{__lsx_vssub_h(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<int32_t, N> SaturatedSub(const Vec128<int32_t, N> a,
+ const Vec128<int32_t, N> b) {
+ return Vec128<int32_t, N>{__lsx_vssub_w(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<int64_t, N> SaturatedSub(const Vec128<int64_t, N> a,
+ const Vec128<int64_t, N> b) {
+ return Vec128<int64_t, N>{__lsx_vssub_d(a.raw, b.raw)};
+}
+
+// ------------------------------ AverageRound
+
+// Returns (a + b + 1) / 2
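+// For example, AverageRound(3, 4) = (3 + 4 + 1) / 2 = 4, i.e. ties round up.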
+
+#ifdef HWY_NATIVE_AVERAGE_ROUND_UI32
+#undef HWY_NATIVE_AVERAGE_ROUND_UI32
+#else
+#define HWY_NATIVE_AVERAGE_ROUND_UI32
+#endif
+
+#ifdef HWY_NATIVE_AVERAGE_ROUND_UI64
+#undef HWY_NATIVE_AVERAGE_ROUND_UI64
+#else
+#define HWY_NATIVE_AVERAGE_ROUND_UI64
+#endif
+
+// Unsigned
+template <size_t N>
+HWY_API Vec128<uint8_t, N> AverageRound(const Vec128<uint8_t, N> a,
+ const Vec128<uint8_t, N> b) {
+ return Vec128<uint8_t, N>{__lsx_vavgr_bu(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<uint16_t, N> AverageRound(const Vec128<uint16_t, N> a,
+ const Vec128<uint16_t, N> b) {
+ return Vec128<uint16_t, N>{__lsx_vavgr_hu(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<uint32_t, N> AverageRound(const Vec128<uint32_t, N> a,
+ const Vec128<uint32_t, N> b) {
+ return Vec128<uint32_t, N>{__lsx_vavgr_wu(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<uint64_t, N> AverageRound(const Vec128<uint64_t, N> a,
+ const Vec128<uint64_t, N> b) {
+ return Vec128<uint64_t, N>{__lsx_vavgr_du(a.raw, b.raw)};
+}
+
+// Signed
+template <size_t N>
+HWY_API Vec128<int8_t, N> AverageRound(const Vec128<int8_t, N> a,
+ const Vec128<int8_t, N> b) {
+ return Vec128<int8_t, N>{__lsx_vavgr_b(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<int16_t, N> AverageRound(const Vec128<int16_t, N> a,
+ const Vec128<int16_t, N> b) {
+ return Vec128<int16_t, N>{__lsx_vavgr_h(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<int32_t, N> AverageRound(const Vec128<int32_t, N> a,
+ const Vec128<int32_t, N> b) {
+ return Vec128<int32_t, N>{__lsx_vavgr_w(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<int64_t, N> AverageRound(const Vec128<int64_t, N> a,
+ const Vec128<int64_t, N> b) {
+ return Vec128<int64_t, N>{__lsx_vavgr_d(a.raw, b.raw)};
+}
+
+// ------------------------------ Integer/Float multiplication
+
+// Per-target flags to prevent generic_ops-inl.h defining 8/64-bit operator*.
+#ifdef HWY_NATIVE_MUL_8
+#undef HWY_NATIVE_MUL_8
+#else
+#define HWY_NATIVE_MUL_8
+#endif
+#ifdef HWY_NATIVE_MUL_64
+#undef HWY_NATIVE_MUL_64
+#else
+#define HWY_NATIVE_MUL_64
+#endif
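+
+// (These #ifdef/#undef/#else/#define blocks toggle each flag's defined-ness;
+// together with highway's per-target toggle convention, generic_ops-inl.h can
+// detect that this target provides the op natively and skip its fallback.)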
+
+template <typename T, size_t N, HWY_IF_UI8(T)>
+HWY_API Vec128<T, N> operator*(const Vec128<T, N> a, const Vec128<T, N> b) {
+ return Vec128<T, N>{__lsx_vmul_b(a.raw, b.raw)};
+}
+template <typename T, size_t N, HWY_IF_UI16(T)>
+HWY_API Vec128<T, N> operator*(const Vec128<T, N> a, const Vec128<T, N> b) {
+ return Vec128<T, N>{__lsx_vmul_h(a.raw, b.raw)};
+}
+template <typename T, size_t N, HWY_IF_UI32(T)>
+HWY_API Vec128<T, N> operator*(const Vec128<T, N> a, const Vec128<T, N> b) {
+ return Vec128<T, N>{__lsx_vmul_w(a.raw, b.raw)};
+}
+template <typename T, size_t N, HWY_IF_UI64(T)>
+HWY_API Vec128<T, N> operator*(const Vec128<T, N> a, const Vec128<T, N> b) {
+ return Vec128<T, N>{__lsx_vmul_d(a.raw, b.raw)};
+}
+
+template <size_t N>
+HWY_API Vec128<float, N> operator*(const Vec128<float, N> a,
+ const Vec128<float, N> b) {
+ return Vec128<float, N>{__lsx_vfmul_s(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<double, N> operator*(const Vec128<double, N> a,
+ const Vec128<double, N> b) {
+ return Vec128<double, N>{__lsx_vfmul_d(a.raw, b.raw)};
+}
+
+// ------------------------------ MulHigh
+
+// Unsigned
+template <size_t N>
+HWY_API Vec128<uint8_t, N> MulHigh(const Vec128<uint8_t, N> a,
+ const Vec128<uint8_t, N> b) {
+ return Vec128<uint8_t, N>{__lsx_vmuh_bu(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<uint16_t, N> MulHigh(const Vec128<uint16_t, N> a,
+ const Vec128<uint16_t, N> b) {
+ return Vec128<uint16_t, N>{__lsx_vmuh_hu(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<uint32_t, N> MulHigh(const Vec128<uint32_t, N> a,
+ const Vec128<uint32_t, N> b) {
+ return Vec128<uint32_t, N>{__lsx_vmuh_wu(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<uint64_t, N> MulHigh(const Vec128<uint64_t, N> a,
+ const Vec128<uint64_t, N> b) {
+ return Vec128<uint64_t, N>{__lsx_vmuh_du(a.raw, b.raw)};
+}
+
+// Signed
+template <size_t N>
+HWY_API Vec128<int8_t, N> MulHigh(const Vec128<int8_t, N> a,
+ const Vec128<int8_t, N> b) {
+ return Vec128<int8_t, N>{__lsx_vmuh_b(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<int16_t, N> MulHigh(const Vec128<int16_t, N> a,
+ const Vec128<int16_t, N> b) {
+ return Vec128<int16_t, N>{__lsx_vmuh_h(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<int32_t, N> MulHigh(const Vec128<int32_t, N> a,
+ const Vec128<int32_t, N> b) {
+ return Vec128<int32_t, N>{__lsx_vmuh_w(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<int64_t, N> MulHigh(const Vec128<int64_t, N> a,
+ const Vec128<int64_t, N> b) {
+ return Vec128<int64_t, N>{__lsx_vmuh_d(a.raw, b.raw)};
+}
+
+// ------------------------------ MulEven
+
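+// Widening multiply of the even-indexed lanes; each product fills one lane
+// of the next wider type. MulOdd below does the same for odd-indexed lanes.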
+template <size_t N>
+HWY_API Vec128<int16_t, (N + 1) / 2> MulEven(Vec128<int8_t, N> a,
+ Vec128<int8_t, N> b) {
+ return Vec128<int16_t, (N + 1) / 2>{__lsx_vmulwev_h_b(a.raw, b.raw)};
+}
+
+template <size_t N>
+HWY_API Vec128<uint16_t, (N + 1) / 2> MulEven(Vec128<uint8_t, N> a,
+ Vec128<uint8_t, N> b) {
+ return Vec128<uint16_t, (N + 1) / 2>{__lsx_vmulwev_h_bu(a.raw, b.raw)};
+}
+
+template <size_t N>
+HWY_API Vec128<int32_t, (N + 1) / 2> MulEven(Vec128<int16_t, N> a,
+ Vec128<int16_t, N> b) {
+ return Vec128<int32_t, (N + 1) / 2>{__lsx_vmulwev_w_h(a.raw, b.raw)};
+}
+
+template <size_t N>
+HWY_API Vec128<uint32_t, (N + 1) / 2> MulEven(Vec128<uint16_t, N> a,
+ Vec128<uint16_t, N> b) {
+ return Vec128<uint32_t, (N + 1) / 2>{__lsx_vmulwev_w_hu(a.raw, b.raw)};
+}
+
+template <size_t N>
+HWY_API Vec128<int64_t, (N + 1) / 2> MulEven(Vec128<int32_t, N> a,
+ Vec128<int32_t, N> b) {
+ return Vec128<int64_t, (N + 1) / 2>{__lsx_vmulwev_d_w(a.raw, b.raw)};
+}
+
+template <size_t N>
+HWY_API Vec128<uint64_t, (N + 1) / 2> MulEven(Vec128<uint32_t, N> a,
+ Vec128<uint32_t, N> b) {
+ return Vec128<uint64_t, (N + 1) / 2>{__lsx_vmulwev_d_wu(a.raw, b.raw)};
+}
+
+template <typename T, HWY_IF_I64(T)>
+HWY_API Vec128<T> MulEven(Vec128<T> a, Vec128<T> b) {
+ return Vec128<T>{__lsx_vmulwev_q_d(a.raw, b.raw)};
+}
+
+template <typename T, HWY_IF_U64(T)>
+HWY_API Vec128<T> MulEven(Vec128<T> a, Vec128<T> b) {
+ return Vec128<T>{__lsx_vmulwev_q_du(a.raw, b.raw)};
+}
+
+// ------------------------------ MulOdd
+
+template <size_t N>
+HWY_API Vec128<int16_t, (N + 1) / 2> MulOdd(Vec128<int8_t, N> a,
+ Vec128<int8_t, N> b) {
+ return Vec128<int16_t, (N + 1) / 2>{__lsx_vmulwod_h_b(a.raw, b.raw)};
+}
+
+template <size_t N>
+HWY_API Vec128<uint16_t, (N + 1) / 2> MulOdd(Vec128<uint8_t, N> a,
+ Vec128<uint8_t, N> b) {
+ return Vec128<uint16_t, (N + 1) / 2>{__lsx_vmulwod_h_bu(a.raw, b.raw)};
+}
+
+template <size_t N>
+HWY_API Vec128<int32_t, (N + 1) / 2> MulOdd(Vec128<int16_t, N> a,
+ Vec128<int16_t, N> b) {
+ return Vec128<int32_t, (N + 1) / 2>{__lsx_vmulwod_w_h(a.raw, b.raw)};
+}
+
+template <size_t N>
+HWY_API Vec128<uint32_t, (N + 1) / 2> MulOdd(Vec128<uint16_t, N> a,
+ Vec128<uint16_t, N> b) {
+ return Vec128<uint32_t, (N + 1) / 2>{__lsx_vmulwod_w_hu(a.raw, b.raw)};
+}
+
+template <size_t N>
+HWY_API Vec128<int64_t, (N + 1) / 2> MulOdd(Vec128<int32_t, N> a,
+ Vec128<int32_t, N> b) {
+ return Vec128<int64_t, (N + 1) / 2>{__lsx_vmulwod_d_w(a.raw, b.raw)};
+}
+
+template <size_t N>
+HWY_API Vec128<uint64_t, (N + 1) / 2> MulOdd(Vec128<uint32_t, N> a,
+ Vec128<uint32_t, N> b) {
+ return Vec128<uint64_t, (N + 1) / 2>{__lsx_vmulwod_d_wu(a.raw, b.raw)};
+}
+
+template <typename T, HWY_IF_I64(T)>
+HWY_API Vec128<T> MulOdd(Vec128<T> a, Vec128<T> b) {
+ return Vec128<T>{__lsx_vmulwod_q_d(a.raw, b.raw)};
+}
+
+template <typename T, HWY_IF_U64(T)>
+HWY_API Vec128<T> MulOdd(Vec128<T> a, Vec128<T> b) {
+ return Vec128<T>{__lsx_vmulwod_q_du(a.raw, b.raw)};
+}
+
+// ------------------------------ RotateRight (ShiftRight, Or)
+
+template <int kBits, typename T, size_t N, HWY_IF_UI8(T)>
+HWY_API Vec128<T, N> RotateRight(const Vec128<T, N> v) {
+ return Vec128<T, N>{__lsx_vrotri_b(v.raw, kBits)};
+}
+template <int kBits, typename T, size_t N, HWY_IF_UI16(T)>
+HWY_API Vec128<T, N> RotateRight(const Vec128<T, N> v) {
+ return Vec128<T, N>{__lsx_vrotri_h(v.raw, kBits)};
+}
+template <int kBits, typename T, size_t N, HWY_IF_UI32(T)>
+HWY_API Vec128<T, N> RotateRight(const Vec128<T, N> v) {
+ return Vec128<T, N>{__lsx_vrotri_w(v.raw, kBits)};
+}
+template <int kBits, typename T, size_t N, HWY_IF_UI64(T)>
+HWY_API Vec128<T, N> RotateRight(const Vec128<T, N> v) {
+ return Vec128<T, N>{__lsx_vrotri_d(v.raw, kBits)};
+}
+
+// ------------------------------ Ror
+#ifdef HWY_NATIVE_ROL_ROR_8
+#undef HWY_NATIVE_ROL_ROR_8
+#else
+#define HWY_NATIVE_ROL_ROR_8
+#endif
+
+#ifdef HWY_NATIVE_ROL_ROR_16
+#undef HWY_NATIVE_ROL_ROR_16
+#else
+#define HWY_NATIVE_ROL_ROR_16
+#endif
+
+#ifdef HWY_NATIVE_ROL_ROR_32_64
+#undef HWY_NATIVE_ROL_ROR_32_64
+#else
+#define HWY_NATIVE_ROL_ROR_32_64
+#endif
+
+template <class T, size_t N, HWY_IF_UI8(T)>
+HWY_API Vec128<T, N> Ror(Vec128<T, N> a, Vec128<T, N> b) {
+ return Vec128<T, N>{__lsx_vrotr_b(a.raw, b.raw)};
+}
+
+template <class T, size_t N, HWY_IF_UI16(T)>
+HWY_API Vec128<T, N> Ror(Vec128<T, N> a, Vec128<T, N> b) {
+ return Vec128<T, N>{__lsx_vrotr_h(a.raw, b.raw)};
+}
+
+template <class T, size_t N, HWY_IF_UI32(T)>
+HWY_API Vec128<T, N> Ror(Vec128<T, N> a, Vec128<T, N> b) {
+ return Vec128<T, N>{__lsx_vrotr_w(a.raw, b.raw)};
+}
+
+template <class T, size_t N, HWY_IF_UI64(T)>
+HWY_API Vec128<T, N> Ror(Vec128<T, N> a, Vec128<T, N> b) {
+ return Vec128<T, N>{__lsx_vrotr_d(a.raw, b.raw)};
+}
+
+// Rol is generic for all vector lengths
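+// Rotating left by b equals rotating right by -b: rotate counts are
+// effectively taken modulo the lane width, so the negated count wraps to
+// lane_bits - b.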
+template <class V, HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V)>
+HWY_API V Rol(V a, V b) {
+ const DFromV<decltype(a)> d;
+ const RebindToSigned<decltype(d)> di;
+
+ return Ror(a, BitCast(d, Neg(BitCast(di, b))));
+}
+
+// ------------------------------ RotateLeftSame/RotateRightSame
+
+#ifdef HWY_NATIVE_ROL_ROR_SAME_8
+#undef HWY_NATIVE_ROL_ROR_SAME_8
+#else
+#define HWY_NATIVE_ROL_ROR_SAME_8
+#endif
+
+#ifdef HWY_NATIVE_ROL_ROR_SAME_16
+#undef HWY_NATIVE_ROL_ROR_SAME_16
+#else
+#define HWY_NATIVE_ROL_ROR_SAME_16
+#endif
+
+#ifdef HWY_NATIVE_ROL_ROR_SAME_32_64
+#undef HWY_NATIVE_ROL_ROR_SAME_32_64
+#else
+#define HWY_NATIVE_ROL_ROR_SAME_32_64
+#endif
+
+// RotateLeftSame/RotateRightSame are generic for all vector lengths
+template <class V, HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V)>
+HWY_API V RotateLeftSame(V v, int bits) {
+ using T = TFromV<V>;
+ const DFromV<decltype(v)> d;
+ return Rol(v, Set(d, static_cast<T>(bits)));
+}
+
+template <class V, HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V)>
+HWY_API V RotateRightSame(V v, int bits) {
+ using T = TFromV<V>;
+ const DFromV<decltype(v)> d;
+ return Ror(v, Set(d, static_cast<T>(bits)));
+}
+
+// ------------------------------ BroadcastSignBit
+
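+// An arithmetic shift right by (lane bits - 1) replicates the sign bit into
+// every bit: non-negative lanes become 0 and negative lanes become -1.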
+template <typename T, size_t N, HWY_IF_SIGNED(T)>
+HWY_API Vec128<T, N> BroadcastSignBit(const Vec128<T, N> v) {
+ return ShiftRight<sizeof(T) * 8 - 1>(v);
+}
+
+// ------------------------------ Integer Abs
+
+// Returns absolute value, except that LimitsMin() maps to LimitsMax() + 1.
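+// e.g. Abs(int8_t{-128}) wraps back to -128, since +128 is not representable.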
+template <size_t N>
+HWY_API Vec128<int8_t, N> Abs(const Vec128<int8_t, N> v) {
+ return Vec128<int8_t, N>{__lsx_vabsd_b(v.raw, __lsx_vreplgr2vr_b(0))};
+}
+template <size_t N>
+HWY_API Vec128<int16_t, N> Abs(const Vec128<int16_t, N> v) {
+ return Vec128<int16_t, N>{__lsx_vabsd_h(v.raw, __lsx_vreplgr2vr_b(0))};
+}
+template <size_t N>
+HWY_API Vec128<int32_t, N> Abs(const Vec128<int32_t, N> v) {
+ return Vec128<int32_t, N>{__lsx_vabsd_w(v.raw, __lsx_vreplgr2vr_b(0))};
+}
+template <size_t N>
+HWY_API Vec128<int64_t, N> Abs(const Vec128<int64_t, N> v) {
+ return Vec128<int64_t, N>{__lsx_vabsd_d(v.raw, __lsx_vreplgr2vr_b(0))};
+}
+
+// ------------------------------ SaturatedAbs
+
+#ifdef HWY_NATIVE_SATURATED_ABS
+#undef HWY_NATIVE_SATURATED_ABS
+#else
+#define HWY_NATIVE_SATURATED_ABS
+#endif
+
+template <class V, HWY_IF_I8(TFromV<V>)>
+HWY_API V SaturatedAbs(V v) {
+ const DFromV<decltype(v)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(d, Min(BitCast(du, v), BitCast(du, SaturatedSub(Zero(d), v))));
+}
+template <class V, HWY_IF_I16(TFromV<V>)>
+HWY_API V SaturatedAbs(V v) {
+ return Max(v, SaturatedSub(Zero(DFromV<V>()), v));
+}
+template <class V, HWY_IF_I32(TFromV<V>)>
+HWY_API V SaturatedAbs(V v) {
+ const auto abs_v = Abs(v);
+ const DFromV<decltype(v)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(d, Min(BitCast(du, abs_v),
+ Set(du, static_cast<uint32_t>(LimitsMax<int32_t>()))));
+}
+template <class V, HWY_IF_I64(TFromV<V>)>
+HWY_API V SaturatedAbs(V v) {
+ const auto abs_v = Abs(v);
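+  // abs_v is negative only when v == LimitsMin (Abs wraps); in that case the
+  // broadcast sign bit is -1, so the add maps LimitsMin to LimitsMax.
+  // Otherwise the sign bit is 0 and abs_v is returned unchanged.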
+ return Add(abs_v, BroadcastSignBit(abs_v));
+}
+
+// ------------------------------ IfNegativeThenElse
+template <typename T, size_t N>
+HWY_API Vec128<T, N> IfNegativeThenElse(Vec128<T, N> v, Vec128<T, N> yes,
+ Vec128<T, N> no) {
+ static_assert(IsSigned<T>(), "Only works for signed/float");
+ const DFromV<decltype(no)> d;
+ const RebindToSigned<decltype(d)> di;
+
+ Mask128<T, N> m = MaskFromVec(BitCast(d, BroadcastSignBit(BitCast(di, v))));
+ return IfThenElse(m, yes, no);
+}
+
+// ------------------------------ IfNegativeThenNegOrUndefIfZero
+
+#ifdef HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG
+#undef HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG
+#else
+#define HWY_NATIVE_INTEGER_IF_NEGATIVE_THEN_NEG
+#endif
+
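+// The LSX vsigncov instructions return -v where mask is negative, v where
+// mask is positive, and 0 where mask is zero (permitted, since the result is
+// unspecified when mask is zero).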
+template <size_t N>
+HWY_API Vec128<int8_t, N> IfNegativeThenNegOrUndefIfZero(Vec128<int8_t, N> mask,
+ Vec128<int8_t, N> v) {
+ return Vec128<int8_t, N>{__lsx_vsigncov_b(mask.raw, v.raw)};
+}
+
+template <size_t N>
+HWY_API Vec128<int16_t, N> IfNegativeThenNegOrUndefIfZero(
+ Vec128<int16_t, N> mask, Vec128<int16_t, N> v) {
+ return Vec128<int16_t, N>{__lsx_vsigncov_h(mask.raw, v.raw)};
+}
+
+template <size_t N>
+HWY_API Vec128<int32_t, N> IfNegativeThenNegOrUndefIfZero(
+ Vec128<int32_t, N> mask, Vec128<int32_t, N> v) {
+ return Vec128<int32_t, N>{__lsx_vsigncov_w(mask.raw, v.raw)};
+}
+
+template <size_t N>
+HWY_API Vec128<int64_t, N> IfNegativeThenNegOrUndefIfZero(
+ Vec128<int64_t, N> mask, Vec128<int64_t, N> v) {
+ return Vec128<int64_t, N>{__lsx_vsigncov_d(mask.raw, v.raw)};
+}
+
+// ------------------------------ ShiftLeftSame/ShiftRightSame
+
+template <typename T, size_t N>
+HWY_API Vec128<T, N> ShiftLeftSame(const Vec128<T, N> v, int bits) {
+ return v << Set(DFromV<decltype(v)>(), static_cast<T>(bits));
+}
+template <typename T, size_t N>
+HWY_API Vec128<T, N> ShiftRightSame(const Vec128<T, N> v, int bits) {
+ return v >> Set(DFromV<decltype(v)>(), static_cast<T>(bits));
+}
+
+// ------------------------------ Integer/Float Div
+
+#ifdef HWY_NATIVE_INT_DIV
+#undef HWY_NATIVE_INT_DIV
+#else
+#define HWY_NATIVE_INT_DIV
+#endif
+
+template <size_t N>
+HWY_API Vec128<int8_t, N> operator/(const Vec128<int8_t, N> a,
+ const Vec128<int8_t, N> b) {
+ // Use inline assembly to avoid undefined behavior if any lanes of b are zero
+ // or a[i] == LimitsMin<int8_t>() && b[i] == -1
+ __m128i raw_result;
+ __asm__("vdiv.b %w0,%w1,%w2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :);
+ return Vec128<int8_t, N>{raw_result};
+}
+
+template <size_t N>
+HWY_API Vec128<uint8_t, N> operator/(const Vec128<uint8_t, N> a,
+ const Vec128<uint8_t, N> b) {
+ // Use inline assembly to avoid undefined behavior if any lanes of b are zero
+ __m128i raw_result;
+ __asm__("vdiv.bu %w0,%w1,%w2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :);
+ return Vec128<uint8_t, N>{raw_result};
+}
+
+template <size_t N>
+HWY_API Vec128<int16_t, N> operator/(const Vec128<int16_t, N> a,
+ const Vec128<int16_t, N> b) {
+ // Use inline assembly to avoid undefined behavior if any lanes of b are zero
+ // or a[i] == LimitsMin<int16_t>() && b[i] == -1
+ __m128i raw_result;
+ __asm__("vdiv.h %w0,%w1,%w2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :);
+ return Vec128<int16_t, N>{raw_result};
+}
+
+template <size_t N>
+HWY_API Vec128<uint16_t, N> operator/(const Vec128<uint16_t, N> a,
+ const Vec128<uint16_t, N> b) {
+ // Use inline assembly to avoid undefined behavior if any lanes of b are zero
+ __m128i raw_result;
+ __asm__("vdiv.hu %w0,%w1,%w2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :);
+ return Vec128<uint16_t, N>{raw_result};
+}
+
+template <size_t N>
+HWY_API Vec128<int32_t, N> operator/(const Vec128<int32_t, N> a,
+ const Vec128<int32_t, N> b) {
+ // Use inline assembly to avoid undefined behavior if any lanes of b are zero
+ // or a[i] == LimitsMin<int32_t>() && b[i] == -1
+ __m128i raw_result;
+ __asm__("vdiv.w %w0,%w1,%w2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :);
+ return Vec128<int32_t, N>{raw_result};
+}
+
+template <size_t N>
+HWY_API Vec128<uint32_t, N> operator/(const Vec128<uint32_t, N> a,
+ const Vec128<uint32_t, N> b) {
+ // Use inline assembly to avoid undefined behavior if any lanes of b are zero
+ __m128i raw_result;
+ __asm__("vdiv.wu %w0,%w1,%w2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :);
+ return Vec128<uint32_t, N>{raw_result};
+}
+
+template <size_t N>
+HWY_API Vec128<int64_t, N> operator/(const Vec128<int64_t, N> a,
+ const Vec128<int64_t, N> b) {
+ // Use inline assembly to avoid undefined behavior if any lanes of b are zero
+ // or a[i] == LimitsMin<int64_t>() && b[i] == -1
+ __m128i raw_result;
+ __asm__("vdiv.d %w0,%w1,%w2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :);
+ return Vec128<int64_t, N>{raw_result};
+}
+
+template <size_t N>
+HWY_API Vec128<uint64_t, N> operator/(const Vec128<uint64_t, N> a,
+ const Vec128<uint64_t, N> b) {
+ // Use inline assembly to avoid undefined behavior if any lanes of b are zero
+ __m128i raw_result;
+ __asm__("vdiv.du %w0,%w1,%w2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :);
+ return Vec128<uint64_t, N>{raw_result};
+}
+
+template <size_t N>
+HWY_API Vec128<float, N> operator/(const Vec128<float, N> a,
+ const Vec128<float, N> b) {
+ return Vec128<float, N>{__lsx_vfdiv_s(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<double, N> operator/(const Vec128<double, N> a,
+ const Vec128<double, N> b) {
+ return Vec128<double, N>{__lsx_vfdiv_d(a.raw, b.raw)};
+}
+
+// ------------------------------ Integer Mod
+
+template <size_t N>
+HWY_API Vec128<int8_t, N> operator%(const Vec128<int8_t, N> a,
+ const Vec128<int8_t, N> b) {
+ // Use inline assembly to avoid undefined behavior if any lanes of b are zero
+ // or a[i] == LimitsMin<int8_t>() && b[i] == -1
+ __m128i raw_result;
+ __asm__("vmod.b %w0,%w1,%w2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :);
+ return Vec128<int8_t, N>{raw_result};
+}
+
+template <size_t N>
+HWY_API Vec128<uint8_t, N> operator%(const Vec128<uint8_t, N> a,
+ const Vec128<uint8_t, N> b) {
+ // Use inline assembly to avoid undefined behavior if any lanes of b are zero
+ __m128i raw_result;
+ __asm__("vmod.bu %w0,%w1,%w2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :);
+ return Vec128<uint8_t, N>{raw_result};
+}
+
+template <size_t N>
+HWY_API Vec128<int16_t, N> operator%(const Vec128<int16_t, N> a,
+ const Vec128<int16_t, N> b) {
+ // Use inline assembly to avoid undefined behavior if any lanes of b are zero
+ // or a[i] == LimitsMin<int16_t>() && b[i] == -1
+ __m128i raw_result;
+ __asm__("vmod.h %w0,%w1,%w2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :);
+ return Vec128<int16_t, N>{raw_result};
+}
+
+template <size_t N>
+HWY_API Vec128<uint16_t, N> operator%(const Vec128<uint16_t, N> a,
+ const Vec128<uint16_t, N> b) {
+ // Use inline assembly to avoid undefined behavior if any lanes of b are zero
+ __m128i raw_result;
+ __asm__("vmod.hu %w0,%w1,%w2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :);
+ return Vec128<uint16_t, N>{raw_result};
+}
+
+template <size_t N>
+HWY_API Vec128<int32_t, N> operator%(const Vec128<int32_t, N> a,
+ const Vec128<int32_t, N> b) {
+ // Use inline assembly to avoid undefined behavior if any lanes of b are zero
+ // or a[i] == LimitsMin<int32_t>() && b[i] == -1
+ __m128i raw_result;
+ __asm__("vmod.w %w0,%w1,%w2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :);
+ return Vec128<int32_t, N>{raw_result};
+}
+
+template <size_t N>
+HWY_API Vec128<uint32_t, N> operator%(const Vec128<uint32_t, N> a,
+ const Vec128<uint32_t, N> b) {
+ // Use inline assembly to avoid undefined behavior if any lanes of b are zero
+ __m128i raw_result;
+ __asm__("vmod.wu %w0,%w1,%w2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :);
+ return Vec128<uint32_t, N>{raw_result};
+}
+
+template <size_t N>
+HWY_API Vec128<int64_t, N> operator%(const Vec128<int64_t, N> a,
+ const Vec128<int64_t, N> b) {
+ // Use inline assembly to avoid undefined behavior if any lanes of b are zero
+ // or a[i] == LimitsMin<int64_t>() && b[i] == -1
+ __m128i raw_result;
+ __asm__("vmod.d %w0,%w1,%w2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :);
+ return Vec128<int64_t, N>{raw_result};
+}
+
+template <size_t N>
+HWY_API Vec128<uint64_t, N> operator%(const Vec128<uint64_t, N> a,
+ const Vec128<uint64_t, N> b) {
+ // Use inline assembly to avoid undefined behavior if any lanes of b are zero
+ __m128i raw_result;
+ __asm__("vmod.du %w0,%w1,%w2" : "=f"(raw_result) : "f"(a.raw), "f"(b.raw) :);
+ return Vec128<uint64_t, N>{raw_result};
+}
+
+// ------------------------------ ApproximateReciprocal
+
+#ifdef HWY_NATIVE_F64_APPROX_RECIP
+#undef HWY_NATIVE_F64_APPROX_RECIP
+#else
+#define HWY_NATIVE_F64_APPROX_RECIP
+#endif
+
+template <size_t N>
+HWY_API Vec128<float, N> ApproximateReciprocal(const Vec128<float, N> v) {
+ return Vec128<float, N>{__lsx_vfrecip_s(v.raw)};
+}
+template <size_t N>
+HWY_API Vec128<double, N> ApproximateReciprocal(const Vec128<double, N> v) {
+ return Vec128<double, N>{__lsx_vfrecip_d(v.raw)};
+}
+
+// ------------------------------ Absolute value of difference
+
+#ifdef HWY_NATIVE_INTEGER_ABS_DIFF
+#undef HWY_NATIVE_INTEGER_ABS_DIFF
+#else
+#define HWY_NATIVE_INTEGER_ABS_DIFF
+#endif
+
+template <size_t N>
+HWY_API Vec128<int8_t, N> AbsDiff(const Vec128<int8_t, N> a,
+ Vec128<int8_t, N> b) {
+ return Vec128<int8_t, N>{__lsx_vabsd_b(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<int16_t, N> AbsDiff(const Vec128<int16_t, N> a,
+ Vec128<int16_t, N> b) {
+ return Vec128<int16_t, N>{__lsx_vabsd_h(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<int32_t, N> AbsDiff(const Vec128<int32_t, N> a,
+ Vec128<int32_t, N> b) {
+ return Vec128<int32_t, N>{__lsx_vabsd_w(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<int64_t, N> AbsDiff(const Vec128<int64_t, N> a,
+ Vec128<int64_t, N> b) {
+ return Vec128<int64_t, N>{__lsx_vabsd_d(a.raw, b.raw)};
+}
+
+template <size_t N>
+HWY_API Vec128<uint8_t, N> AbsDiff(const Vec128<uint8_t, N> a,
+ Vec128<uint8_t, N> b) {
+ return Vec128<uint8_t, N>{__lsx_vabsd_bu(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<uint16_t, N> AbsDiff(const Vec128<uint16_t, N> a,
+ Vec128<uint16_t, N> b) {
+ return Vec128<uint16_t, N>{__lsx_vabsd_hu(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<uint32_t, N> AbsDiff(const Vec128<uint32_t, N> a,
+ Vec128<uint32_t, N> b) {
+ return Vec128<uint32_t, N>{__lsx_vabsd_wu(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<uint64_t, N> AbsDiff(const Vec128<uint64_t, N> a,
+ Vec128<uint64_t, N> b) {
+ return Vec128<uint64_t, N>{__lsx_vabsd_du(a.raw, b.raw)};
+}
+
+// Generic for all vector lengths.
+template <class V, HWY_IF_FLOAT_V(V)>
+HWY_API V AbsDiff(V a, V b) {
+ return Abs(a - b);
+}
+
+// ------------------------------ Integer/Float multiply-add
+
+#ifdef HWY_NATIVE_INT_FMA
+#undef HWY_NATIVE_INT_FMA
+#else
+#define HWY_NATIVE_INT_FMA
+#endif
+
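+// Note the operand order: the LSX integer madd intrinsics take the
+// accumulator first, so __lsx_vmadd_*(add, mul, x) computes add + mul * x.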
+template <size_t N>
+HWY_API Vec128<int8_t, N> MulAdd(Vec128<int8_t, N> mul, Vec128<int8_t, N> x,
+ Vec128<int8_t, N> add) {
+ return Vec128<int8_t, N>{__lsx_vmadd_b(add.raw, mul.raw, x.raw)};
+}
+template <size_t N>
+HWY_API Vec128<int16_t, N> MulAdd(Vec128<int16_t, N> mul, Vec128<int16_t, N> x,
+ Vec128<int16_t, N> add) {
+ return Vec128<int16_t, N>{__lsx_vmadd_h(add.raw, mul.raw, x.raw)};
+}
+template <size_t N>
+HWY_API Vec128<int32_t, N> MulAdd(Vec128<int32_t, N> mul, Vec128<int32_t, N> x,
+ Vec128<int32_t, N> add) {
+ return Vec128<int32_t, N>{__lsx_vmadd_w(add.raw, mul.raw, x.raw)};
+}
+template <size_t N>
+HWY_API Vec128<int64_t, N> MulAdd(Vec128<int64_t, N> mul, Vec128<int64_t, N> x,
+ Vec128<int64_t, N> add) {
+ return Vec128<int64_t, N>{__lsx_vmadd_d(add.raw, mul.raw, x.raw)};
+}
+
+template <size_t N>
+HWY_API Vec128<float, N> MulAdd(Vec128<float, N> mul, Vec128<float, N> x,
+ Vec128<float, N> add) {
+ return Vec128<float, N>{__lsx_vfmadd_s(mul.raw, x.raw, add.raw)};
+}
+template <size_t N>
+HWY_API Vec128<double, N> MulAdd(Vec128<double, N> mul, Vec128<double, N> x,
+ Vec128<double, N> add) {
+ return Vec128<double, N>{__lsx_vfmadd_d(mul.raw, x.raw, add.raw)};
+}
+
+// Unsigned
+template <typename T, size_t N, HWY_IF_UNSIGNED(T)>
+HWY_API Vec128<T, N> MulAdd(Vec128<T, N> mul, Vec128<T, N> x,
+ Vec128<T, N> add) {
+ return mul * x + add;
+}
+
+// ------------------------------ Integer/Float NegMulAdd
+
+template <size_t N>
+HWY_API Vec128<int8_t, N> NegMulAdd(Vec128<int8_t, N> mul, Vec128<int8_t, N> x,
+ Vec128<int8_t, N> add) {
+ return Vec128<int8_t, N>{__lsx_vmsub_b(add.raw, mul.raw, x.raw)};
+}
+template <size_t N>
+HWY_API Vec128<int16_t, N> NegMulAdd(Vec128<int16_t, N> mul,
+ Vec128<int16_t, N> x,
+ Vec128<int16_t, N> add) {
+ return Vec128<int16_t, N>{__lsx_vmsub_h(add.raw, mul.raw, x.raw)};
+}
+template <size_t N>
+HWY_API Vec128<int32_t, N> NegMulAdd(Vec128<int32_t, N> mul,
+                                     Vec128<int32_t, N> x,
+                                     Vec128<int32_t, N> add) {
+  return Vec128<int32_t, N>{__lsx_vmsub_w(add.raw, mul.raw, x.raw)};
+}
+template <size_t N>
+HWY_API Vec128<int64_t, N> NegMulAdd(Vec128<int64_t, N> mul,
+                                     Vec128<int64_t, N> x,
+                                     Vec128<int64_t, N> add) {
+  return Vec128<int64_t, N>{__lsx_vmsub_d(add.raw, mul.raw, x.raw)};
+}
+
+// Float/unsigned
+template <typename T, size_t N, HWY_IF_NOT_SPECIAL_FLOAT(T)>
+HWY_API Vec128<T, N> NegMulAdd(Vec128<T, N> mul, Vec128<T, N> x,
+ Vec128<T, N> add) {
+ return add - mul * x;
+}
+
+// ------------------------------ Float MulSub
+
+// Float
+template <size_t N>
+HWY_API Vec128<float, N> MulSub(Vec128<float, N> mul, Vec128<float, N> x,
+ Vec128<float, N> sub) {
+ return Vec128<float, N>{__lsx_vfmsub_s(x.raw, mul.raw, sub.raw)};
+}
+template <size_t N>
+HWY_API Vec128<double, N> MulSub(Vec128<double, N> mul, Vec128<double, N> x,
+ Vec128<double, N> sub) {
+ return Vec128<double, N>{__lsx_vfmsub_d(x.raw, mul.raw, sub.raw)};
+}
+
+// Integer (signed and unsigned); float/double use the overloads above.
+template <typename T, size_t N, HWY_IF_NOT_SPECIAL_FLOAT(T)>
+HWY_API Vec128<T, N> MulSub(Vec128<T, N> mul, Vec128<T, N> x,
+ Vec128<T, N> sub) {
+ return mul * x - sub;
+}
+
+// ------------------------------ Float NegMulSub
+
+// Float and integer
+template <typename T, size_t N, HWY_IF_NOT_SPECIAL_FLOAT(T)>
+HWY_API Vec128<T, N> NegMulSub(Vec128<T, N> mul, Vec128<T, N> x,
+ Vec128<T, N> sub) {
+ return Neg(mul) * x - sub;
+}
+
+// ------------------------------ Floating-point square root
+
+template <size_t N>
+HWY_API Vec128<float, N> Sqrt(Vec128<float, N> v) {
+ return Vec128<float, N>{__lsx_vfsqrt_s(v.raw)};
+}
+template <size_t N>
+HWY_API Vec128<double, N> Sqrt(Vec128<double, N> v) {
+ return Vec128<double, N>{__lsx_vfsqrt_d(v.raw)};
+}
+
+// ------------------------------ ApproximateReciprocalSqrt
+#ifdef HWY_NATIVE_F64_APPROX_RSQRT
+#undef HWY_NATIVE_F64_APPROX_RSQRT
+#else
+#define HWY_NATIVE_F64_APPROX_RSQRT
+#endif
+
+template <size_t N>
+HWY_API Vec128<float, N> ApproximateReciprocalSqrt(Vec128<float, N> v) {
+ return Vec128<float, N>{__lsx_vfrsqrt_s(v.raw)};
+}
+template <size_t N>
+HWY_API Vec128<double, N> ApproximateReciprocalSqrt(Vec128<double, N> v) {
+ return Vec128<double, N>{__lsx_vfrsqrt_d(v.raw)};
+}
+
+// ------------------------------ Min
+
+template <size_t N>
+HWY_API Vec128<uint8_t, N> Min(Vec128<uint8_t, N> a, Vec128<uint8_t, N> b) {
+ return Vec128<uint8_t, N>{__lsx_vmin_bu(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<uint16_t, N> Min(Vec128<uint16_t, N> a, Vec128<uint16_t, N> b) {
+ return Vec128<uint16_t, N>{__lsx_vmin_hu(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<uint32_t, N> Min(Vec128<uint32_t, N> a, Vec128<uint32_t, N> b) {
+ return Vec128<uint32_t, N>{__lsx_vmin_wu(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<uint64_t, N> Min(Vec128<uint64_t, N> a, Vec128<uint64_t, N> b) {
+ return Vec128<uint64_t, N>{__lsx_vmin_du(a.raw, b.raw)};
+}
+
+template <size_t N>
+HWY_API Vec128<int8_t, N> Min(Vec128<int8_t, N> a, Vec128<int8_t, N> b) {
+ return Vec128<int8_t, N>{__lsx_vmin_b(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<int16_t, N> Min(Vec128<int16_t, N> a, Vec128<int16_t, N> b) {
+ return Vec128<int16_t, N>{__lsx_vmin_h(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<int32_t, N> Min(Vec128<int32_t, N> a, Vec128<int32_t, N> b) {
+ return Vec128<int32_t, N>{__lsx_vmin_w(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<int64_t, N> Min(Vec128<int64_t, N> a, Vec128<int64_t, N> b) {
+ return Vec128<int64_t, N>{__lsx_vmin_d(a.raw, b.raw)};
+}
+
+template <size_t N>
+HWY_API Vec128<float, N> Min(Vec128<float, N> a, Vec128<float, N> b) {
+ return Vec128<float, N>{__lsx_vfmin_s(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<double, N> Min(Vec128<double, N> a, Vec128<double, N> b) {
+ return Vec128<double, N>{__lsx_vfmin_d(a.raw, b.raw)};
+}
+
+// ------------------------------ Max
+
+template <size_t N>
+HWY_API Vec128<uint8_t, N> Max(Vec128<uint8_t, N> a, Vec128<uint8_t, N> b) {
+ return Vec128<uint8_t, N>{__lsx_vmax_bu(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<uint16_t, N> Max(Vec128<uint16_t, N> a, Vec128<uint16_t, N> b) {
+ return Vec128<uint16_t, N>{__lsx_vmax_hu(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<uint32_t, N> Max(Vec128<uint32_t, N> a, Vec128<uint32_t, N> b) {
+ return Vec128<uint32_t, N>{__lsx_vmax_wu(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<uint64_t, N> Max(Vec128<uint64_t, N> a, Vec128<uint64_t, N> b) {
+ return Vec128<uint64_t, N>{__lsx_vmax_du(a.raw, b.raw)};
+}
+
+template <size_t N>
+HWY_API Vec128<int8_t, N> Max(Vec128<int8_t, N> a, Vec128<int8_t, N> b) {
+ return Vec128<int8_t, N>{__lsx_vmax_b(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<int16_t, N> Max(Vec128<int16_t, N> a, Vec128<int16_t, N> b) {
+ return Vec128<int16_t, N>{__lsx_vmax_h(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<int32_t, N> Max(Vec128<int32_t, N> a, Vec128<int32_t, N> b) {
+ return Vec128<int32_t, N>{__lsx_vmax_w(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<int64_t, N> Max(Vec128<int64_t, N> a, Vec128<int64_t, N> b) {
+ return Vec128<int64_t, N>{__lsx_vmax_d(a.raw, b.raw)};
+}
+
+template <size_t N>
+HWY_API Vec128<float, N> Max(Vec128<float, N> a, Vec128<float, N> b) {
+ return Vec128<float, N>{__lsx_vfmax_s(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<double, N> Max(Vec128<double, N> a, Vec128<double, N> b) {
+ return Vec128<double, N>{__lsx_vfmax_d(a.raw, b.raw)};
+}
+
+// ------------------------------ MinMagnitude and MaxMagnitude
+
+#ifdef HWY_NATIVE_FLOAT_MIN_MAX_MAGNITUDE
+#undef HWY_NATIVE_FLOAT_MIN_MAX_MAGNITUDE
+#else
+#define HWY_NATIVE_FLOAT_MIN_MAX_MAGNITUDE
+#endif
+
+template <size_t N>
+HWY_API Vec128<float, N> MinMagnitude(Vec128<float, N> a, Vec128<float, N> b) {
+ return Vec128<float, N>{__lsx_vfmina_s(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<double, N> MinMagnitude(Vec128<double, N> a,
+ Vec128<double, N> b) {
+ return Vec128<double, N>{__lsx_vfmina_d(a.raw, b.raw)};
+}
+
+template <size_t N>
+HWY_API Vec128<float, N> MaxMagnitude(Vec128<float, N> a, Vec128<float, N> b) {
+ return Vec128<float, N>{__lsx_vfmaxa_s(a.raw, b.raw)};
+}
+template <size_t N>
+HWY_API Vec128<double, N> MaxMagnitude(Vec128<double, N> a,
+ Vec128<double, N> b) {
+ return Vec128<double, N>{__lsx_vfmaxa_d(a.raw, b.raw)};
+}
+
+// ------------------------------ Non-temporal stores
+
+// Same as aligned stores on non-x86.
+
+template <class D>
+HWY_API void Stream(const VFromD<D> v, D d, TFromD<D>* HWY_RESTRICT aligned) {
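+  // Prefetch for write (rw = 1) with no expected temporal locality
+  // (locality = 0) before the regular store.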
+ __builtin_prefetch(aligned, 1, 0);
+ Store(v, d, aligned);
+}
+
+// ------------------------------ Scatter in generic_ops-inl.h
+// ------------------------------ Gather in generic_ops-inl.h
+
+// ================================================== SWIZZLE (2)
+
+// ------------------------------ LowerHalf
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 8)>
+HWY_API VFromD<D> LowerHalf(D /* tag */, VFromD<Twice<D>> v) {
+ return VFromD<D>{v.raw};
+}
+template <typename T, size_t N>
+HWY_API Vec128<T, N / 2> LowerHalf(Vec128<T, N> v) {
+ return Vec128<T, N / 2>{v.raw};
+}
+
+// ------------------------------ ShiftLeftBytes
+
+template <int kBytes, class D, HWY_IF_V_SIZE_LE_D(D, 16)>
+HWY_API VFromD<D> ShiftLeftBytes(D d, VFromD<D> v) {
+ static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes");
+ if (kBytes == 0) return v;
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(
+ d, VFromD<decltype(du)>{__lsx_vbsll_v(BitCast(du, v).raw, kBytes)});
+}
+
+// Generic for all vector lengths.
+template <int kBytes, class V>
+HWY_API V ShiftLeftBytes(const V v) {
+ return ShiftLeftBytes<kBytes>(DFromV<decltype(v)>(), v);
+}
+
+// ------------------------------ ShiftLeftLanes
+
+// Generic for all vector lengths.
+template <int kLanes, class D>
+HWY_API VFromD<D> ShiftLeftLanes(D d, const VFromD<D> v) {
+ const Repartition<uint8_t, decltype(d)> d8;
+ return BitCast(d, ShiftLeftBytes<kLanes * sizeof(TFromD<D>)>(BitCast(d8, v)));
+}
+
+// Generic for all vector lengths.
+template <int kLanes, class V>
+HWY_API V ShiftLeftLanes(const V v) {
+ return ShiftLeftLanes<kLanes>(DFromV<decltype(v)>(), v);
+}
+
+// ------------------------------ ShiftRightBytes
+template <int kBytes, class D, HWY_IF_V_SIZE_LE_D(D, 16)>
+HWY_API VFromD<D> ShiftRightBytes(D d, VFromD<D> v) {
+ static_assert(0 <= kBytes && kBytes <= 16, "Invalid kBytes");
+ if (kBytes == 0) return v;
+ const RebindToUnsigned<decltype(d)> du;
+ // For partial vectors, clear upper lanes so we shift in zeros.
+ if (d.MaxBytes() != 16) {
+ const Full128<TFromD<D>> dfull;
+ const VFromD<decltype(dfull)> vfull{v.raw};
+ v = VFromD<D>{IfThenElseZero(FirstN(dfull, MaxLanes(d)), vfull).raw};
+ }
+ return BitCast(
+ d, VFromD<decltype(du)>{__lsx_vbsrl_v(BitCast(du, v).raw, kBytes)});
+}
+
+// ------------------------------ ShiftRightLanes
+// Generic for all vector lengths.
+template <int kLanes, class D>
+HWY_API VFromD<D> ShiftRightLanes(D d, const VFromD<D> v) {
+ const Repartition<uint8_t, decltype(d)> d8;
+ constexpr size_t kBytes = kLanes * sizeof(TFromD<D>);
+ return BitCast(d, ShiftRightBytes<kBytes>(d8, BitCast(d8, v)));
+}
+
+// ------------------------------ UpperHalf (ShiftRightBytes)
+
+template <class D, HWY_IF_V_SIZE_D(D, 8)>
+HWY_API VFromD<D> UpperHalf(D d, VFromD<Twice<D>> v) {
+ const Twice<RebindToUnsigned<decltype(d)>> dut;
+ using VUT = VFromD<decltype(dut)>; // for float16_t
+ const VUT vut = BitCast(dut, v);
+ return BitCast(d, LowerHalf(VUT{__lsx_vilvh_d(vut.raw, vut.raw)}));
+}
+
+// Partial
+template <class D, HWY_IF_V_SIZE_LE_D(D, 4)>
+HWY_API VFromD<D> UpperHalf(D d, VFromD<Twice<D>> v) {
+ return LowerHalf(d, ShiftRightBytes<d.MaxBytes()>(Twice<D>(), v));
+}
+
+// ------------------------------ ExtractLane (UpperHalf)
+
+namespace detail {
+
+template <size_t kLane, typename T, size_t N, HWY_IF_T_SIZE(T, 1)>
+HWY_INLINE T ExtractLane(const Vec128<T, N> v) {
+ static_assert(kLane < N, "Lane index out of bounds");
+ return static_cast<T>(__lsx_vpickve2gr_b(v.raw, kLane) & 0xFF);
+}
+
+template <size_t kLane, typename T, size_t N, HWY_IF_T_SIZE(T, 2)>
+HWY_INLINE T ExtractLane(const Vec128<T, N> v) {
+ static_assert(kLane < N, "Lane index out of bounds");
+ const DFromV<decltype(v)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ const uint16_t lane = static_cast<uint16_t>(
+ __lsx_vpickve2gr_hu(BitCast(du, v).raw, kLane) & 0xFFFF);
+ return BitCastScalar<T>(lane);
+}
+
+template <size_t kLane, typename T, size_t N, HWY_IF_T_SIZE(T, 4)>
+HWY_INLINE T ExtractLane(const Vec128<T, N> v) {
+ static_assert(kLane < N, "Lane index out of bounds");
+ return static_cast<T>(__lsx_vpickve2gr_w(v.raw, kLane));
+}
+
+template <size_t kLane, typename T, size_t N, HWY_IF_T_SIZE(T, 8)>
+HWY_INLINE T ExtractLane(const Vec128<T, N> v) {
+ static_assert(kLane < N, "Lane index out of bounds");
+ return static_cast<T>(__lsx_vpickve2gr_d(v.raw, kLane));
+}
+
+template <size_t kLane, size_t N>
+HWY_INLINE float ExtractLane(const Vec128<float, N> v) {
+ float f32;
+ int32_t i32 = __lsx_vpickve2gr_w(reinterpret_cast<__m128i>(v.raw), kLane);
+ CopyBytes<4>(&i32, &f32);
+ return f32;
+}
+template <size_t kLane, size_t N>
+HWY_INLINE double ExtractLane(const Vec128<double, N> v) {
+ double f64;
+ int64_t i64 = __lsx_vpickve2gr_d(reinterpret_cast<__m128i>(v.raw), kLane);
+ CopyBytes<8>(&i64, &f64);
+ return f64;
+}
+
+} // namespace detail
+
+template <typename T>
+HWY_API T ExtractLane(const Vec128<T, 1> v, size_t i) {
+ HWY_DASSERT(i == 0);
+ (void)i;
+ return GetLane(v);
+}
+
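+// For a runtime lane index, dispatch to the constant-lane path when the
+// compiler can prove the index is constant; otherwise spill the vector to
+// the stack and load the requested lane.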
+template <typename T>
+HWY_API T ExtractLane(const Vec128<T, 2> v, size_t i) {
+#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang
+ if (__builtin_constant_p(i)) {
+ switch (i) {
+ case 0:
+ return detail::ExtractLane<0>(v);
+ case 1:
+ return detail::ExtractLane<1>(v);
+ }
+ }
+#endif
+ alignas(16) T lanes[2];
+ Store(v, DFromV<decltype(v)>(), lanes);
+ return lanes[i];
+}
+
+template <typename T>
+HWY_API T ExtractLane(const Vec128<T, 4> v, size_t i) {
+#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang
+ if (__builtin_constant_p(i)) {
+ switch (i) {
+ case 0:
+ return detail::ExtractLane<0>(v);
+ case 1:
+ return detail::ExtractLane<1>(v);
+ case 2:
+ return detail::ExtractLane<2>(v);
+ case 3:
+ return detail::ExtractLane<3>(v);
+ }
+ }
+#endif
+ alignas(16) T lanes[4];
+ Store(v, DFromV<decltype(v)>(), lanes);
+ return lanes[i];
+}
+
+template <typename T>
+HWY_API T ExtractLane(const Vec128<T, 8> v, size_t i) {
+#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang
+ if (__builtin_constant_p(i)) {
+ switch (i) {
+ case 0:
+ return detail::ExtractLane<0>(v);
+ case 1:
+ return detail::ExtractLane<1>(v);
+ case 2:
+ return detail::ExtractLane<2>(v);
+ case 3:
+ return detail::ExtractLane<3>(v);
+ case 4:
+ return detail::ExtractLane<4>(v);
+ case 5:
+ return detail::ExtractLane<5>(v);
+ case 6:
+ return detail::ExtractLane<6>(v);
+ case 7:
+ return detail::ExtractLane<7>(v);
+ }
+ }
+#endif
+ alignas(16) T lanes[8];
+ Store(v, DFromV<decltype(v)>(), lanes);
+ return lanes[i];
+}
+
+template <typename T>
+HWY_API T ExtractLane(const Vec128<T, 16> v, size_t i) {
+#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang
+ if (__builtin_constant_p(i)) {
+ switch (i) {
+ case 0:
+ return detail::ExtractLane<0>(v);
+ case 1:
+ return detail::ExtractLane<1>(v);
+ case 2:
+ return detail::ExtractLane<2>(v);
+ case 3:
+ return detail::ExtractLane<3>(v);
+ case 4:
+ return detail::ExtractLane<4>(v);
+ case 5:
+ return detail::ExtractLane<5>(v);
+ case 6:
+ return detail::ExtractLane<6>(v);
+ case 7:
+ return detail::ExtractLane<7>(v);
+ case 8:
+ return detail::ExtractLane<8>(v);
+ case 9:
+ return detail::ExtractLane<9>(v);
+ case 10:
+ return detail::ExtractLane<10>(v);
+ case 11:
+ return detail::ExtractLane<11>(v);
+ case 12:
+ return detail::ExtractLane<12>(v);
+ case 13:
+ return detail::ExtractLane<13>(v);
+ case 14:
+ return detail::ExtractLane<14>(v);
+ case 15:
+ return detail::ExtractLane<15>(v);
+ }
+ }
+#endif
+ alignas(16) T lanes[16];
+ Store(v, DFromV<decltype(v)>(), lanes);
+ return lanes[i];
+}
+
+// ------------------------------ InsertLane (UpperHalf)
+
+namespace detail {
+
+template <class V>
+HWY_INLINE V InsertLaneUsingBroadcastAndBlend(V v, size_t i, TFromV<V> t) {
+ const DFromV<decltype(v)> d;
+
+#if HWY_TARGET <= HWY_AVX3
+ using RawMask = decltype(MaskFromVec(VFromD<decltype(d)>()).raw);
+ const auto mask = MFromD<decltype(d)>{static_cast<RawMask>(uint64_t{1} << i)};
+#else
+ const RebindToUnsigned<decltype(d)> du;
+ using TU = TFromD<decltype(du)>;
+ const auto mask = RebindMask(d, Iota(du, 0) == Set(du, static_cast<TU>(i)));
+#endif
+
+ return IfThenElse(mask, Set(d, t), v);
+}
+
+template <size_t kLane, typename T, size_t N, HWY_IF_T_SIZE(T, 1)>
+HWY_INLINE Vec128<T, N> InsertLane(const Vec128<T, N> v, T t) {
+ static_assert(kLane < N, "Lane index out of bounds");
+ return Vec128<T, N>{__lsx_vinsgr2vr_b(v.raw, t, kLane)};
+}
+
+template <size_t kLane, typename T, size_t N, HWY_IF_T_SIZE(T, 2)>
+HWY_INLINE Vec128<T, N> InsertLane(const Vec128<T, N> v, T t) {
+ static_assert(kLane < N, "Lane index out of bounds");
+ const DFromV<decltype(v)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ const uint16_t bits = BitCastScalar<uint16_t>(t);
+ return BitCast(d, VFromD<decltype(du)>{
+ __lsx_vinsgr2vr_h(BitCast(du, v).raw, bits, kLane)});
+}
+template <size_t kLane, typename T, size_t N, HWY_IF_UI32(T)>
+HWY_INLINE Vec128<T, N> InsertLane(const Vec128<T, N> v, T t) {
+ static_assert(kLane < N, "Lane index out of bounds");
+ return Vec128<T, N>{__lsx_vinsgr2vr_w(v.raw, t, kLane)};
+}
+template <size_t kLane, typename T, size_t N, HWY_IF_UI64(T)>
+HWY_INLINE Vec128<T, N> InsertLane(const Vec128<T, N> v, T t) {
+ static_assert(kLane < N, "Lane index out of bounds");
+ return Vec128<T, N>{__lsx_vinsgr2vr_d(v.raw, t, kLane)};
+}
+
+template <size_t kLane, size_t N>
+HWY_INLINE Vec128<float, N> InsertLane(const Vec128<float, N> v, float t) {
+ static_assert(kLane < N, "Lane index out of bounds");
+ const DFromV<decltype(v)> d;
+  const int32_t ti = BitCastScalar<int32_t>(t);
+  const RebindToUnsigned<decltype(d)> du;
+ return BitCast(d, VFromD<decltype(du)>{__lsx_vinsgr2vr_w(
+ reinterpret_cast<__m128i>(v.raw), ti, kLane)});
+}
+
+template <size_t kLane>
+HWY_INLINE Vec128<double> InsertLane(const Vec128<double> v, double t) {
+ static_assert(kLane < 2, "Lane index out of bounds");
+ const DFromV<decltype(v)> d;
+  const int64_t ti = BitCastScalar<int64_t>(t);
+  const RebindToUnsigned<decltype(d)> du;
+ return BitCast(d, VFromD<decltype(du)>{__lsx_vinsgr2vr_d(
+ reinterpret_cast<__m128i>(v.raw), ti, kLane)});
+}
+
+} // namespace detail
+
+template <typename T>
+HWY_API Vec128<T, 1> InsertLane(const Vec128<T, 1> v, size_t i, T t) {
+ HWY_DASSERT(i == 0);
+ (void)i;
+ return Set(DFromV<decltype(v)>(), t);
+}
+
+template <typename T>
+HWY_API Vec128<T, 2> InsertLane(const Vec128<T, 2> v, size_t i, T t) {
+#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang
+ if (__builtin_constant_p(i)) {
+ switch (i) {
+ case 0:
+ return detail::InsertLane<0>(v, t);
+ case 1:
+ return detail::InsertLane<1>(v, t);
+ }
+ }
+#endif
+ return detail::InsertLaneUsingBroadcastAndBlend(v, i, t);
+}
+
+template <typename T>
+HWY_API Vec128<T, 4> InsertLane(const Vec128<T, 4> v, size_t i, T t) {
+#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang
+ if (__builtin_constant_p(i)) {
+ switch (i) {
+ case 0:
+ return detail::InsertLane<0>(v, t);
+ case 1:
+ return detail::InsertLane<1>(v, t);
+ case 2:
+ return detail::InsertLane<2>(v, t);
+ case 3:
+ return detail::InsertLane<3>(v, t);
+ }
+ }
+#endif
+ return detail::InsertLaneUsingBroadcastAndBlend(v, i, t);
+}
+
+template <typename T>
+HWY_API Vec128<T, 8> InsertLane(const Vec128<T, 8> v, size_t i, T t) {
+#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang
+ if (__builtin_constant_p(i)) {
+ switch (i) {
+ case 0:
+ return detail::InsertLane<0>(v, t);
+ case 1:
+ return detail::InsertLane<1>(v, t);
+ case 2:
+ return detail::InsertLane<2>(v, t);
+ case 3:
+ return detail::InsertLane<3>(v, t);
+ case 4:
+ return detail::InsertLane<4>(v, t);
+ case 5:
+ return detail::InsertLane<5>(v, t);
+ case 6:
+ return detail::InsertLane<6>(v, t);
+ case 7:
+ return detail::InsertLane<7>(v, t);
+ }
+ }
+#endif
+ return detail::InsertLaneUsingBroadcastAndBlend(v, i, t);
+}
+
+template <typename T>
+HWY_API Vec128<T, 16> InsertLane(const Vec128<T, 16> v, size_t i, T t) {
+#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang
+ if (__builtin_constant_p(i)) {
+ switch (i) {
+ case 0:
+ return detail::InsertLane<0>(v, t);
+ case 1:
+ return detail::InsertLane<1>(v, t);
+ case 2:
+ return detail::InsertLane<2>(v, t);
+ case 3:
+ return detail::InsertLane<3>(v, t);
+ case 4:
+ return detail::InsertLane<4>(v, t);
+ case 5:
+ return detail::InsertLane<5>(v, t);
+ case 6:
+ return detail::InsertLane<6>(v, t);
+ case 7:
+ return detail::InsertLane<7>(v, t);
+ case 8:
+ return detail::InsertLane<8>(v, t);
+ case 9:
+ return detail::InsertLane<9>(v, t);
+ case 10:
+ return detail::InsertLane<10>(v, t);
+ case 11:
+ return detail::InsertLane<11>(v, t);
+ case 12:
+ return detail::InsertLane<12>(v, t);
+ case 13:
+ return detail::InsertLane<13>(v, t);
+ case 14:
+ return detail::InsertLane<14>(v, t);
+ case 15:
+ return detail::InsertLane<15>(v, t);
+ }
+ }
+#endif
+ return detail::InsertLaneUsingBroadcastAndBlend(v, i, t);
+}
+
+// ------------------------------ CombineShiftRightBytes
+template <int kBytes, class D, HWY_IF_V_SIZE_D(D, 16)>
+HWY_API VFromD<D> CombineShiftRightBytes(D d, VFromD<D> hi, VFromD<D> lo) {
+ static_assert(0 < kBytes && kBytes < 16, "kBytes invalid");
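+  // e.g. kBytes = 4 yields bytes {lo[4..15], hi[0..3]}, i.e. the 32-byte
+  // concatenation hi:lo shifted right by 4 bytes.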
+ return Or(ShiftRightBytes<kBytes>(d, lo), ShiftLeftBytes<16 - kBytes>(d, hi));
+}
+template <int kBytes, class D, HWY_IF_V_SIZE_LE_D(D, 8)>
+HWY_API VFromD<D> CombineShiftRightBytes(D d, VFromD<D> hi, VFromD<D> lo) {
+ constexpr size_t kSize = d.MaxBytes();
+ static_assert(0 < kBytes && kBytes < kSize, "kBytes invalid");
+
+ const Twice<decltype(d)> dt;
+ return VFromD<D>{ShiftRightBytes<kBytes>(dt, Combine(dt, hi, lo)).raw};
+}
+
+// ------------------------------ Broadcast/splat any lane
+
+template <int kLane, typename T, size_t N, HWY_IF_UI8(T)>
+HWY_API Vec128<T, N> Broadcast(Vec128<T, N> v) {
+ static_assert(0 <= kLane && kLane < N, "Invalid lane");
+ return Vec128<T, N>{__lsx_vreplvei_b(v.raw, kLane)};
+}
+template <int kLane, typename T, size_t N, HWY_IF_UI16(T)>
+HWY_API Vec128<T, N> Broadcast(Vec128<T, N> v) {
+ static_assert(0 <= kLane && kLane < N, "Invalid lane");
+ return Vec128<T, N>{__lsx_vreplvei_h(v.raw, kLane)};
+}
+template <int kLane, typename T, size_t N, HWY_IF_T_SIZE(T, 4)>
+HWY_API Vec128<T, N> Broadcast(Vec128<T, N> v) {
+ static_assert(0 <= kLane && kLane < N, "Invalid lane");
+ const DFromV<decltype(v)> d;
+ return BitCast(d, Vec128<int32_t, N>{__lsx_vreplvei_w(
+ reinterpret_cast<__m128i>(v.raw), kLane)});
+}
+template <int kLane, typename T, size_t N, HWY_IF_T_SIZE(T, 8)>
+HWY_API Vec128<T, N> Broadcast(Vec128<T, N> v) {
+ static_assert(0 <= kLane && kLane < N, "Invalid lane");
+ const DFromV<decltype(v)> d;
+ return BitCast(d, Vec128<int64_t, N>{__lsx_vreplvei_d(
+ reinterpret_cast<__m128i>(v.raw), kLane)});
+}
+
+// ------------------------------ TableLookupLanes (Shuffle01)
+
+// Returned by SetTableIndices/IndicesFromVec for use by TableLookupLanes.
+template <typename T, size_t N = 16 / sizeof(T)>
+struct Indices128 {
+ __m128i raw;
+};
+
+namespace detail {
+
+template <class D, HWY_IF_T_SIZE_D(D, 1)>
+HWY_INLINE VFromD<Repartition<uint8_t, D>> IndicesFromVecBroadcastLaneBytes(
+ D d) {
+ const Repartition<uint8_t, decltype(d)> d8;
+ return Iota(d8, 0);
+}
+
+template <class D, HWY_IF_T_SIZE_D(D, 2)>
+HWY_INLINE VFromD<Repartition<uint8_t, D>> IndicesFromVecBroadcastLaneBytes(
+ D d) {
+ const Repartition<uint8_t, decltype(d)> d8;
+ alignas(16) static constexpr uint8_t kBroadcastLaneBytes[16] = {
+ 0, 0, 2, 2, 4, 4, 6, 6, 8, 8, 10, 10, 12, 12, 14, 14};
+ return Load(d8, kBroadcastLaneBytes);
+}
+
+template <class D, HWY_IF_T_SIZE_D(D, 4)>
+HWY_INLINE VFromD<Repartition<uint8_t, D>> IndicesFromVecBroadcastLaneBytes(
+ D d) {
+ const Repartition<uint8_t, decltype(d)> d8;
+ alignas(16) static constexpr uint8_t kBroadcastLaneBytes[16] = {
+ 0, 0, 0, 0, 4, 4, 4, 4, 8, 8, 8, 8, 12, 12, 12, 12};
+ return Load(d8, kBroadcastLaneBytes);
+}
+
+template <class D, HWY_IF_T_SIZE_D(D, 8)>
+HWY_INLINE VFromD<Repartition<uint8_t, D>> IndicesFromVecBroadcastLaneBytes(
+ D d) {
+ const Repartition<uint8_t, decltype(d)> d8;
+ alignas(16) static constexpr uint8_t kBroadcastLaneBytes[16] = {
+ 0, 0, 0, 0, 0, 0, 0, 0, 8, 8, 8, 8, 8, 8, 8, 8};
+ return Load(d8, kBroadcastLaneBytes);
+}
+
+template <class D, HWY_IF_T_SIZE_D(D, 1)>
+HWY_INLINE VFromD<Repartition<uint8_t, D>> IndicesFromVecByteOffsets(D d) {
+ const Repartition<uint8_t, decltype(d)> d8;
+ return Zero(d8);
+}
+
+template <class D, HWY_IF_T_SIZE_D(D, 2)>
+HWY_INLINE VFromD<Repartition<uint8_t, D>> IndicesFromVecByteOffsets(D d) {
+ const Repartition<uint8_t, decltype(d)> d8;
+ alignas(16) static constexpr uint8_t kByteOffsets[16] = {
+ 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1};
+ return Load(d8, kByteOffsets);
+}
+
+template <class D, HWY_IF_T_SIZE_D(D, 4)>
+HWY_INLINE VFromD<Repartition<uint8_t, D>> IndicesFromVecByteOffsets(D d) {
+ const Repartition<uint8_t, decltype(d)> d8;
+ alignas(16) static constexpr uint8_t kByteOffsets[16] = {
+ 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3};
+ return Load(d8, kByteOffsets);
+}
+
+template <class D, HWY_IF_T_SIZE_D(D, 8)>
+HWY_INLINE VFromD<Repartition<uint8_t, D>> IndicesFromVecByteOffsets(D d) {
+ const Repartition<uint8_t, decltype(d)> d8;
+ alignas(16) static constexpr uint8_t kByteOffsets[16] = {
+ 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7};
+ return Load(d8, kByteOffsets);
+}
+
+} // namespace detail
+
+template <class D, typename TI, HWY_IF_T_SIZE_D(D, 1)>
+HWY_API Indices128<TFromD<D>, MaxLanes(D())> IndicesFromVec(
+ D d, Vec128<TI, MaxLanes(D())> vec) {
+ using T = TFromD<D>;
+ static_assert(sizeof(T) == sizeof(TI), "Index size must match lane");
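+ // The debug check below permits indices in [0, 2 * MaxLanes(d)) so that the
+ // same indices can also be used with TwoTablesLookupLanes.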
+#if HWY_IS_DEBUG_BUILD
+ const RebindToUnsigned<decltype(d)> du;
+ using TU = TFromD<decltype(du)>;
+ HWY_DASSERT(AllTrue(
+ du, Lt(BitCast(du, vec), Set(du, static_cast<TU>(MaxLanes(d) * 2)))));
+#endif
+
+ (void)d;
+ return Indices128<TFromD<D>, MaxLanes(D())>{BitCast(d, vec).raw};
+}
+
+template <class D, typename TI,
+ HWY_IF_T_SIZE_ONE_OF_D(D, (1 << 2) | (1 << 4) | (1 << 8))>
+HWY_API Indices128<TFromD<D>, MaxLanes(D())> IndicesFromVec(
+ D d, Vec128<TI, MaxLanes(D())> vec) {
+ using T = TFromD<D>;
+ static_assert(sizeof(T) == sizeof(TI), "Index size must match lane");
+#if HWY_IS_DEBUG_BUILD
+ const RebindToUnsigned<decltype(d)> du;
+ using TU = TFromD<decltype(du)>;
+ HWY_DASSERT(AllTrue(
+ du, Lt(BitCast(du, vec), Set(du, static_cast<TU>(MaxLanes(d) * 2)))));
+#endif
+
+ const Repartition<uint8_t, decltype(d)> d8;
+ using V8 = VFromD<decltype(d8)>;
+
+ // Broadcast each lane index to all bytes of its lane, scale lane indices
+ // to byte indices, then add the byte offsets within each lane.
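+ // Example for T = uint32_t: lane index 2 broadcasts to bytes {2,2,2,2};
+ // shifting left by FloorLog2(4) = 2 gives {8,8,8,8}, and adding the offsets
+ // {0,1,2,3} yields the byte indices {8,9,10,11} of lane 2.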
+ const V8 lane_indices = TableLookupBytes(
+ BitCast(d8, vec), detail::IndicesFromVecBroadcastLaneBytes(d));
+ constexpr int kIndexShiftAmt = static_cast<int>(FloorLog2(sizeof(T)));
+ const V8 byte_indices = ShiftLeft<kIndexShiftAmt>(lane_indices);
+ const V8 sum = Add(byte_indices, detail::IndicesFromVecByteOffsets(d));
+ return Indices128<TFromD<D>, MaxLanes(D())>{sum.raw};
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), typename TI>
+HWY_API Indices128<TFromD<D>, MaxLanes(D())> SetTableIndices(D d,
+ const TI* idx) {
+ const Rebind<TI, decltype(d)> di;
+ return IndicesFromVec(d, LoadU(di, idx));
+}
+
+template <typename T, size_t N, HWY_IF_NOT_SPECIAL_FLOAT(T)>
+HWY_API Vec128<T, N> TableLookupLanes(Vec128<T, N> v, Indices128<T, N> idx) {
+ using TI = MakeSigned<T>;
+ const DFromV<decltype(v)> d;
+ const Rebind<TI, decltype(d)> di;
+ auto t1 = TableLookupBytes(BitCast(di, v), Vec128<TI, N>{idx.raw});
+ return BitCast(d, t1);
+}
+
+// Single lane: no change
+template <typename T>
+HWY_API Vec128<T, 1> TableLookupLanes(Vec128<T, 1> v,
+ Indices128<T, 1> /* idx */) {
+ return v;
+}
+
+// ------------------------------ ReverseBlocks
+
+// Single block: no change
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16)>
+HWY_API VFromD<D> ReverseBlocks(D /* tag */, VFromD<D> v) {
+ return v;
+}
+
+// ------------------------------ Reverse (Shuffle0123, Shuffle2301)
+
+// Single lane: no change
+template <class D, HWY_IF_LANES_D(D, 1)>
+HWY_API VFromD<D> Reverse(D /* tag */, VFromD<D> v) {
+ return v;
+}
+// 32-bit x2: shuffle
+template <class D, HWY_IF_V_SIZE_D(D, 8), HWY_IF_T_SIZE_D(D, 4)>
+HWY_API VFromD<D> Reverse(D /* tag */, const VFromD<D> v) {
+ return VFromD<D>{Shuffle2301(Vec128<TFromD<D>>{v.raw}).raw};
+}
+// 64-bit x2: shuffle
+template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_T_SIZE_D(D, 8)>
+HWY_API VFromD<D> Reverse(D /* tag */, const VFromD<D> v) {
+ return Shuffle01(v);
+}
+// 32-bit x4: shuffle
+template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_T_SIZE_D(D, 4)>
+HWY_API VFromD<D> Reverse(D /* tag */, const VFromD<D> v) {
+ return Shuffle0123(v);
+}
+
+// 16-bit
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_T_SIZE_D(D, 2),
+ HWY_IF_LANES_GT_D(D, 1)>
+HWY_API VFromD<D> Reverse(D d, const VFromD<D> v) {
+ const RebindToUnsigned<decltype(d)> du;
+ using VU = VFromD<decltype(du)>;
+ const VU vu = BitCast(du, v);
+ constexpr size_t kN = MaxLanes(d);
+ if (kN == 1) return v;
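+ // vshuf4i_h permutes each aligned group of four lanes via 2-bit immediate
+ // fields: 0x11 selects {1,0,1,0} (reversing the two valid lanes) and 0x1B
+ // selects {3,2,1,0}.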
+ if (kN == 2) {
+ return BitCast(d, VU{__lsx_vshuf4i_h(vu.raw, 0x11)});
+ }
+ if (kN == 4) {
+ return BitCast(d, VU{__lsx_vshuf4i_h(vu.raw, 0x1B)});
+ }
+ const RebindToSigned<decltype(d)> di;
+ const VFromD<decltype(di)> shuffle = Dup128VecFromValues(
+ di, 0x0F0E, 0x0D0C, 0x0B0A, 0x0908, 0x0706, 0x0504, 0x0302, 0x0100);
+ return BitCast(d, TableLookupBytes(v, shuffle));
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_T_SIZE_D(D, 1),
+ HWY_IF_LANES_GT_D(D, 1)>
+HWY_API VFromD<D> Reverse(D d, const VFromD<D> v) {
+ static constexpr int kN = static_cast<int>(MaxLanes(d));
+ if (kN == 1) return v;
+ alignas(16) static constexpr int8_t _tmp_data[] = {
+ kN - 1, kN - 2, kN - 3, kN - 4, kN - 5, kN - 6, kN - 7, kN - 8,
+ kN - 9, kN - 10, kN - 11, kN - 12, kN - 13, kN - 14, kN - 15, kN - 16};
+ return VFromD<D>{__lsx_vshuf_b(v.raw, v.raw, __lsx_vld(_tmp_data, 0))};
+}
+
+// ------------------------------ Reverse2
+
+// Single lane: no change
+template <class D, HWY_IF_LANES_D(D, 1)>
+HWY_API VFromD<D> Reverse2(D /* tag */, VFromD<D> v) {
+ return v;
+}
+
+template <class D, HWY_IF_T_SIZE_D(D, 2)>
+HWY_API VFromD<D> Reverse2(D d, const VFromD<D> v) {
+ const RepartitionToWide<RebindToUnsigned<decltype(d)>> dw;
+ return BitCast(d, RotateRight<16>(BitCast(dw, v)));
+}
+
+// Generic for all vector lengths.
+template <class D, HWY_IF_T_SIZE_D(D, 4), HWY_IF_LANES_GT_D(D, 1)>
+HWY_API VFromD<D> Reverse2(D /* tag */, VFromD<D> v) {
+ return Shuffle2301(v);
+}
+
+// Generic for all vector lengths.
+template <class D, HWY_IF_T_SIZE_D(D, 8), HWY_IF_LANES_GT_D(D, 1)>
+HWY_API VFromD<D> Reverse2(D /* tag */, VFromD<D> v) {
+ return Shuffle01(v);
+}
+
+// ------------------------------ Reverse4
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_T_SIZE_D(D, 2)>
+HWY_API VFromD<D> Reverse4(D /* tag */, VFromD<D> v) {
+ return VFromD<D>{__lsx_vshuf4i_h(v.raw, 0x1B)};
+}
+
+// Generic for all vector lengths.
+template <class D, HWY_IF_T_SIZE_D(D, 4)>
+HWY_API VFromD<D> Reverse4(D /* tag */, const VFromD<D> v) {
+ return Shuffle0123(v);
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_T_SIZE_D(D, 8)>
+HWY_API VFromD<D> Reverse4(D /* tag */, VFromD<D> /* v */) {
+ HWY_ASSERT(0); // don't have 4 u64 lanes
+}
+
+// ------------------------------ Reverse8
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_T_SIZE_D(D, 2)>
+HWY_API VFromD<D> Reverse8(D d, const VFromD<D> v) {
+ const RepartitionToWide<decltype(d)> dw;
+ return Reverse2(d, BitCast(d, Shuffle0123(BitCast(dw, v))));
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16),
+ HWY_IF_T_SIZE_ONE_OF_D(D, (1 << 4) | (1 << 8))>
+HWY_API VFromD<D> Reverse8(D /* tag */, VFromD<D> /* v */) {
+ HWY_ASSERT(0); // don't have 8 lanes if larger than 16-bit
+}
+
+// ------------------------------ InterleaveUpper (UpperHalf)
+
+// Full
+template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_T_SIZE_D(D, 1)>
+HWY_API VFromD<D> InterleaveUpper(D /* tag */, VFromD<D> a, VFromD<D> b) {
+ return VFromD<D>{__lsx_vilvh_b(b.raw, a.raw)};
+}
+template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_T_SIZE_D(D, 2)>
+HWY_API VFromD<D> InterleaveUpper(D /* tag */, VFromD<D> a, VFromD<D> b) {
+ return VFromD<D>{__lsx_vilvh_h(b.raw, a.raw)};
+}
+template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_T_SIZE_D(D, 4)>
+HWY_API VFromD<D> InterleaveUpper(D d, VFromD<D> a, VFromD<D> b) {
+ const RebindToSigned<decltype(d)> df;
+ return BitCast(d, VFromD<decltype(df)>{
+ __lsx_vilvh_w(BitCast(df, b).raw, BitCast(df, a).raw)});
+}
+template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_T_SIZE_D(D, 8)>
+HWY_API VFromD<D> InterleaveUpper(D d, VFromD<D> a, VFromD<D> b) {
+ const RebindToSigned<decltype(d)> dd;
+ return BitCast(d, VFromD<decltype(dd)>{
+ __lsx_vilvh_d(BitCast(dd, b).raw, BitCast(dd, a).raw)});
+}
+
+// Partial
+template <class D, HWY_IF_V_SIZE_LE_D(D, 8)>
+HWY_API VFromD<D> InterleaveUpper(D d, VFromD<D> a, VFromD<D> b) {
+ const Half<decltype(d)> d2;
+ return InterleaveLower(d, VFromD<D>{UpperHalf(d2, a).raw},
+ VFromD<D>{UpperHalf(d2, b).raw});
+}
+
+// ------------------------------ ZipLower/ZipUpper (InterleaveLower)
+
+// Same as Interleave*, except that the return lanes are double-width integers;
+// this is necessary because the single-lane scalar cannot return two values.
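+// E.g. for u8 inputs, each u16 output lane holds a[i] in its lower byte and
+// b[i] in its upper byte (little-endian).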
+template <class V, class DW = RepartitionToWide<DFromV<V>>>
+HWY_API VFromD<DW> ZipLower(V a, V b) {
+ return BitCast(DW(), InterleaveLower(a, b));
+}
+template <class V, class D = DFromV<V>, class DW = RepartitionToWide<D>>
+HWY_API VFromD<DW> ZipLower(DW dw, V a, V b) {
+ return BitCast(dw, InterleaveLower(D(), a, b));
+}
+
+template <class V, class D = DFromV<V>, class DW = RepartitionToWide<D>>
+HWY_API VFromD<DW> ZipUpper(DW dw, V a, V b) {
+ return BitCast(dw, InterleaveUpper(D(), a, b));
+}
+
+// ================================================== CONVERT (1)
+
+// ------------------------------ PromoteTo unsigned
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_U16_D(D)>
+HWY_API VFromD<D> PromoteTo(D /* tag */, VFromD<Rebind<uint8_t, D>> v) {
+ return VFromD<D>{__lsx_vsllwil_hu_bu(v.raw, 0)};
+}
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_U32_D(D)>
+HWY_API VFromD<D> PromoteTo(D /* tag */, VFromD<Rebind<uint16_t, D>> v) {
+ return VFromD<D>{__lsx_vsllwil_wu_hu(v.raw, 0)};
+}
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_U64_D(D)>
+HWY_API VFromD<D> PromoteTo(D /* tag */, VFromD<Rebind<uint32_t, D>> v) {
+ return VFromD<D>{__lsx_vsllwil_du_wu(v.raw, 0)};
+}
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_U32_D(D)>
+HWY_API VFromD<D> PromoteTo(D /* tag */, VFromD<Rebind<uint8_t, D>> v) {
+ const __m128i u16 = __lsx_vsllwil_hu_bu(v.raw, 0);
+ return VFromD<D>{__lsx_vsllwil_wu_hu(u16, 0)};
+}
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_U64_D(D)>
+HWY_API VFromD<D> PromoteTo(D d, VFromD<Rebind<uint8_t, D>> v) {
+ const Rebind<uint32_t, decltype(d)> du32;
+ return PromoteTo(d, PromoteTo(du32, v));
+}
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_U64_D(D)>
+HWY_API VFromD<D> PromoteTo(D /*tag*/, VFromD<Rebind<uint16_t, D>> v) {
+ const __m128i u32 = __lsx_vsllwil_wu_hu(v.raw, 0);
+ return VFromD<D>{__lsx_vsllwil_du_wu(u32, 0)};
+}
+
+// Unsigned to signed: same plus cast.
+template <class D, class V, HWY_IF_SIGNED_D(D), HWY_IF_UNSIGNED_V(V),
+ HWY_IF_LANES_GT(sizeof(TFromD<D>), sizeof(TFromV<V>)),
+ HWY_IF_LANES_D(D, HWY_MAX_LANES_V(V))>
+HWY_API VFromD<D> PromoteTo(D di, V v) {
+ const RebindToUnsigned<decltype(di)> du;
+ return BitCast(di, PromoteTo(du, v));
+}
+
+// signed
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_I16_D(D)>
+HWY_API VFromD<D> PromoteTo(D /* tag */, VFromD<Rebind<int8_t, D>> v) {
+ return VFromD<D>{__lsx_vsllwil_h_b(v.raw, 0)};
+}
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_I32_D(D)>
+HWY_API VFromD<D> PromoteTo(D /* tag */, VFromD<Rebind<int16_t, D>> v) {
+ return VFromD<D>{__lsx_vsllwil_w_h(v.raw, 0)};
+}
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_I64_D(D)>
+HWY_API VFromD<D> PromoteTo(D /* tag */, VFromD<Rebind<int32_t, D>> v) {
+ return VFromD<D>{__lsx_vsllwil_d_w(v.raw, 0)};
+}
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_I32_D(D)>
+HWY_API VFromD<D> PromoteTo(D /* tag */, VFromD<Rebind<int8_t, D>> v) {
+ const __m128i i16 = __lsx_vsllwil_h_b(v.raw, 0);
+ return VFromD<D>{__lsx_vsllwil_w_h(i16, 0)};
+}
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_I64_D(D)>
+HWY_API VFromD<D> PromoteTo(D d, VFromD<Rebind<int8_t, D>> v) {
+ const Rebind<int32_t, decltype(d)> di32;
+ return PromoteTo(d, PromoteTo(di32, v));
+}
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_I64_D(D)>
+HWY_API VFromD<D> PromoteTo(D /*tag*/, VFromD<Rebind<int16_t, D>> v) {
+ const __m128i i32 = __lsx_vsllwil_w_h(v.raw, 0);
+ return VFromD<D>{__lsx_vsllwil_d_w(i32, 0)};
+}
+
+// ------------------------------ PromoteTo float
+
+#ifdef HWY_NATIVE_F16C
+#undef HWY_NATIVE_F16C
+#else
+#define HWY_NATIVE_F16C
+#endif
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_F32_D(D)>
+HWY_API VFromD<D> PromoteTo(D /* tag */, VFromD<Rebind<hwy::float16_t, D>> v) {
+ return VFromD<D>{__lsx_vfcvtl_s_h(v.raw)};
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_F64_D(D)>
+HWY_API VFromD<D> PromoteTo(D /* tag */, VFromD<Rebind<float, D>> v) {
+ return VFromD<D>{__lsx_vfcvtl_d_s(v.raw)};
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_F64_D(D)>
+HWY_API VFromD<D> PromoteTo(D /* tag */, VFromD<Rebind<int32_t, D>> v) {
+ return VFromD<D>{__lsx_vffintl_d_w(v.raw)};
+}
+
+template <class D, HWY_IF_F64_D(D), HWY_IF_V_SIZE_LE_D(D, 16)>
+HWY_API VFromD<D> PromoteTo(D df64, VFromD<Rebind<uint32_t, D>> v) {
+ const Rebind<int32_t, decltype(df64)> di32;
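+ // Converting via i32 treats inputs >= 2^31 as negative, making the result
+ // 2^32 too small; detect this via the sign and add 2^32 back.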
+ const auto i32_to_f64_result = PromoteTo(df64, BitCast(di32, v));
+ return i32_to_f64_result + IfNegativeThenElse(i32_to_f64_result,
+ Set(df64, 4294967296.0),
+ Zero(df64));
+}
+
+template <class D, HWY_IF_F32_D(D)>
+HWY_API VFromD<D> PromoteTo(D d, VFromD<Rebind<hwy::bfloat16_t, D>> v) {
+ const RebindToSigned<decltype(d)> di32;
+ const Rebind<uint16_t, decltype(d)> du16;
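+ // bfloat16 is the upper half of an f32: zero-extend to 32 bits and shift
+ // into the upper half.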
+ return BitCast(d, ShiftLeft<16>(PromoteTo(di32, BitCast(du16, v))));
+}
+
+// ------------------------------ Per4LaneBlockShuffle
+
+namespace detail {
+
+#ifdef HWY_NATIVE_PER4LANEBLKSHUF_DUP32
+#undef HWY_NATIVE_PER4LANEBLKSHUF_DUP32
+#else
+#define HWY_NATIVE_PER4LANEBLKSHUF_DUP32
+#endif
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16)>
+HWY_INLINE VFromD<D> Per4LaneBlkShufDupSet4xU32(D d, const uint32_t x3,
+ const uint32_t x2,
+ const uint32_t x1,
+ const uint32_t x0) {
+ typedef uint32_t GccU32RawVectType __attribute__((__vector_size__(16)));
+ const GccU32RawVectType raw = {x0, x1, x2, x3};
+ return ResizeBitCast(d, Vec128<uint32_t>{reinterpret_cast<__m128i>(raw)});
+}
+
+template <size_t kIdx3210, size_t kVectSize, class V,
+ HWY_IF_LANES_LE(kVectSize, 16)>
+HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<kIdx3210> /*idx_3210_tag*/,
+ hwy::SizeTag<1> /*lane_size_tag*/,
+ hwy::SizeTag<kVectSize> /*vect_size_tag*/,
+ V v) {
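+ // The low byte of kIdx3210 packs the four 2-bit lane indices in the same
+ // layout as the vshuf4i immediate.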
+ constexpr int kShuffle = static_cast<int>(kIdx3210 & 0xFF);
+ return V{__lsx_vshuf4i_b(v.raw, kShuffle)};
+}
+
+template <size_t kIdx3210, size_t kVectSize, class V,
+ HWY_IF_LANES_LE(kVectSize, 16)>
+HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<kIdx3210> /*idx_3210_tag*/,
+ hwy::SizeTag<2> /*lane_size_tag*/,
+ hwy::SizeTag<kVectSize> /*vect_size_tag*/,
+ V v) {
+ const DFromV<decltype(v)> d;
+ const RebindToUnsigned<decltype(d)> du; // for float16_t
+ constexpr int kShuffle = static_cast<int>(kIdx3210 & 0xFF);
+ return BitCast(
+ d, VFromD<decltype(du)>{__lsx_vshuf4i_h(BitCast(du, v).raw, kShuffle)});
+}
+
+template <size_t kIdx3210, class V>
+HWY_INLINE V Per4LaneBlockShuffle(hwy::SizeTag<kIdx3210> /*idx_3210_tag*/,
+ hwy::SizeTag<4> /*lane_size_tag*/,
+ hwy::SizeTag<16> /*vect_size_tag*/, V v) {
+ const DFromV<decltype(v)> d;
+ constexpr int kShuffle = static_cast<int>(kIdx3210 & 0xFF);
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(d, VFromD<decltype(du)>{__lsx_vshuf4i_w(
+ reinterpret_cast<__m128i>(v.raw), kShuffle)});
+}
+
+} // namespace detail
+
+// ------------------------------ SlideUpLanes
+
+namespace detail {
+
+template <class V, HWY_IF_V_SIZE_LE_V(V, 8)>
+HWY_INLINE V SlideUpLanes(V v, size_t amt) {
+ const DFromV<decltype(v)> d;
+ const Full64<uint64_t> du64;
+ const auto vu64 = ResizeBitCast(du64, v);
+ return ResizeBitCast(
+ d, ShiftLeftSame(vu64, static_cast<int>(amt * sizeof(TFromV<V>) * 8)));
+}
+
+template <class V, HWY_IF_V_SIZE_V(V, 16)>
+HWY_INLINE V SlideUpLanes(V v, size_t amt) {
+ const DFromV<decltype(v)> d;
+ const Repartition<uint8_t, decltype(d)> du8;
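+ // Iota wraps modulo 256, so the first amt * sizeof(T) byte indices are
+ // >= 0x80 and TableLookupBytesOr0 zeroes those bytes; the remaining indices
+ // select the input bytes shifted up by amt lanes.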
+ const auto idx =
+ Iota(du8, static_cast<uint8_t>(size_t{0} - amt * sizeof(TFromV<V>)));
+ return BitCast(d, TableLookupBytesOr0(BitCast(du8, v), idx));
+}
+
+} // namespace detail
+
+template <class D, HWY_IF_LANES_D(D, 1)>
+HWY_API VFromD<D> SlideUpLanes(D /*d*/, VFromD<D> v, size_t /*amt*/) {
+ return v;
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_LANES_D(D, 2)>
+HWY_API VFromD<D> SlideUpLanes(D d, VFromD<D> v, size_t amt) {
+#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang
+ if (__builtin_constant_p(amt)) {
+ switch (amt) {
+ case 0:
+ return v;
+ case 1:
+ return ShiftLeftLanes<1>(d, v);
+ }
+ }
+#else
+ (void)d;
+#endif
+
+ return detail::SlideUpLanes(v, amt);
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_LANES_D(D, 4)>
+HWY_API VFromD<D> SlideUpLanes(D d, VFromD<D> v, size_t amt) {
+#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang
+ if (__builtin_constant_p(amt)) {
+ switch (amt) {
+ case 0:
+ return v;
+ case 1:
+ return ShiftLeftLanes<1>(d, v);
+ case 2:
+ return ShiftLeftLanes<2>(d, v);
+ case 3:
+ return ShiftLeftLanes<3>(d, v);
+ }
+ }
+#else
+ (void)d;
+#endif
+
+ return detail::SlideUpLanes(v, amt);
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_LANES_D(D, 8)>
+HWY_API VFromD<D> SlideUpLanes(D d, VFromD<D> v, size_t amt) {
+#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang
+ if (__builtin_constant_p(amt)) {
+ switch (amt) {
+ case 0:
+ return v;
+ case 1:
+ return ShiftLeftLanes<1>(d, v);
+ case 2:
+ return ShiftLeftLanes<2>(d, v);
+ case 3:
+ return ShiftLeftLanes<3>(d, v);
+ case 4:
+ return ShiftLeftLanes<4>(d, v);
+ case 5:
+ return ShiftLeftLanes<5>(d, v);
+ case 6:
+ return ShiftLeftLanes<6>(d, v);
+ case 7:
+ return ShiftLeftLanes<7>(d, v);
+ }
+ }
+#else
+ (void)d;
+#endif
+
+ return detail::SlideUpLanes(v, amt);
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_LANES_D(D, 16)>
+HWY_API VFromD<D> SlideUpLanes(D d, VFromD<D> v, size_t amt) {
+#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang
+ if (__builtin_constant_p(amt)) {
+ switch (amt) {
+ case 0:
+ return v;
+ case 1:
+ return ShiftLeftLanes<1>(d, v);
+ case 2:
+ return ShiftLeftLanes<2>(d, v);
+ case 3:
+ return ShiftLeftLanes<3>(d, v);
+ case 4:
+ return ShiftLeftLanes<4>(d, v);
+ case 5:
+ return ShiftLeftLanes<5>(d, v);
+ case 6:
+ return ShiftLeftLanes<6>(d, v);
+ case 7:
+ return ShiftLeftLanes<7>(d, v);
+ case 8:
+ return ShiftLeftLanes<8>(d, v);
+ case 9:
+ return ShiftLeftLanes<9>(d, v);
+ case 10:
+ return ShiftLeftLanes<10>(d, v);
+ case 11:
+ return ShiftLeftLanes<11>(d, v);
+ case 12:
+ return ShiftLeftLanes<12>(d, v);
+ case 13:
+ return ShiftLeftLanes<13>(d, v);
+ case 14:
+ return ShiftLeftLanes<14>(d, v);
+ case 15:
+ return ShiftLeftLanes<15>(d, v);
+ }
+ }
+#else
+ (void)d;
+#endif
+
+ return detail::SlideUpLanes(v, amt);
+}
+
+// ------------------------------ SlideDownLanes
+
+namespace detail {
+
+template <class V, HWY_IF_V_SIZE_LE_V(V, 8)>
+HWY_INLINE V SlideDownLanes(V v, size_t amt) {
+ const DFromV<decltype(v)> d;
+ const Repartition<UnsignedFromSize<d.MaxBytes()>, decltype(d)> dv;
+ return BitCast(d,
+ ShiftRightSame(BitCast(dv, v),
+ static_cast<int>(amt * sizeof(TFromV<V>) * 8)));
+}
+
+template <class V, HWY_IF_V_SIZE_V(V, 16)>
+HWY_INLINE V SlideDownLanes(V v, size_t amt) {
+ const DFromV<decltype(v)> d;
+ const Repartition<int8_t, decltype(d)> di8;
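+ // Byte indices start at amt * sizeof(T); indices beyond the last input byte
+ // are forced to 0xFF by the mask so that TableLookupBytesOr0 zeroes them.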
+ auto idx = Iota(di8, static_cast<int8_t>(amt * sizeof(TFromV<V>)));
+ idx = Or(idx, VecFromMask(di8, idx > Set(di8, int8_t{15})));
+ return BitCast(d, TableLookupBytesOr0(BitCast(di8, v), idx));
+}
+
+} // namespace detail
+
+template <class D, HWY_IF_LANES_D(D, 1)>
+HWY_API VFromD<D> SlideDownLanes(D /*d*/, VFromD<D> v, size_t /*amt*/) {
+ return v;
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_LANES_D(D, 2)>
+HWY_API VFromD<D> SlideDownLanes(D d, VFromD<D> v, size_t amt) {
+#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang
+ if (__builtin_constant_p(amt)) {
+ switch (amt) {
+ case 0:
+ return v;
+ case 1:
+ return ShiftRightLanes<1>(d, v);
+ }
+ }
+#else
+ (void)d;
+#endif
+
+ return detail::SlideDownLanes(v, amt);
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_LANES_D(D, 4)>
+HWY_API VFromD<D> SlideDownLanes(D d, VFromD<D> v, size_t amt) {
+#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang
+ if (__builtin_constant_p(amt)) {
+ switch (amt) {
+ case 0:
+ return v;
+ case 1:
+ return ShiftRightLanes<1>(d, v);
+ case 2:
+ return ShiftRightLanes<2>(d, v);
+ case 3:
+ return ShiftRightLanes<3>(d, v);
+ }
+ }
+#else
+ (void)d;
+#endif
+
+ return detail::SlideDownLanes(v, amt);
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_LANES_D(D, 8)>
+HWY_API VFromD<D> SlideDownLanes(D d, VFromD<D> v, size_t amt) {
+#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang
+ if (__builtin_constant_p(amt)) {
+ switch (amt) {
+ case 0:
+ return v;
+ case 1:
+ return ShiftRightLanes<1>(d, v);
+ case 2:
+ return ShiftRightLanes<2>(d, v);
+ case 3:
+ return ShiftRightLanes<3>(d, v);
+ case 4:
+ return ShiftRightLanes<4>(d, v);
+ case 5:
+ return ShiftRightLanes<5>(d, v);
+ case 6:
+ return ShiftRightLanes<6>(d, v);
+ case 7:
+ return ShiftRightLanes<7>(d, v);
+ }
+ }
+#else
+ (void)d;
+#endif
+
+ return detail::SlideDownLanes(v, amt);
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_LANES_D(D, 16)>
+HWY_API VFromD<D> SlideDownLanes(D d, VFromD<D> v, size_t amt) {
+#if !HWY_IS_DEBUG_BUILD && HWY_COMPILER_GCC // includes clang
+ if (__builtin_constant_p(amt)) {
+ switch (amt) {
+ case 0:
+ return v;
+ case 1:
+ return ShiftRightLanes<1>(d, v);
+ case 2:
+ return ShiftRightLanes<2>(d, v);
+ case 3:
+ return ShiftRightLanes<3>(d, v);
+ case 4:
+ return ShiftRightLanes<4>(d, v);
+ case 5:
+ return ShiftRightLanes<5>(d, v);
+ case 6:
+ return ShiftRightLanes<6>(d, v);
+ case 7:
+ return ShiftRightLanes<7>(d, v);
+ case 8:
+ return ShiftRightLanes<8>(d, v);
+ case 9:
+ return ShiftRightLanes<9>(d, v);
+ case 10:
+ return ShiftRightLanes<10>(d, v);
+ case 11:
+ return ShiftRightLanes<11>(d, v);
+ case 12:
+ return ShiftRightLanes<12>(d, v);
+ case 13:
+ return ShiftRightLanes<13>(d, v);
+ case 14:
+ return ShiftRightLanes<14>(d, v);
+ case 15:
+ return ShiftRightLanes<15>(d, v);
+ }
+ }
+#else
+ (void)d;
+#endif
+
+ return detail::SlideDownLanes(v, amt);
+}
+
+// ================================================== COMBINE
+
+// ------------------------------ Combine (InterleaveLower)
+
+// N = N/2 + N/2 (upper half undefined)
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), class VH = VFromD<Half<D>>>
+HWY_API VFromD<D> Combine(D d, VH hi_half, VH lo_half) {
+ const Half<decltype(d)> dh;
+ const RebindToUnsigned<decltype(dh)> duh;
+ // Treat half-width input as one lane, and expand to two lanes.
+ using VU = Vec128<UnsignedFromSize<dh.MaxBytes()>, 2>;
+ const VU lo{BitCast(duh, lo_half).raw};
+ const VU hi{BitCast(duh, hi_half).raw};
+ return BitCast(d, InterleaveLower(lo, hi));
+}
+
+// ------------------------------ ZeroExtendVector (Combine)
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16)>
+HWY_API VFromD<D> ZeroExtendVector(D d, VFromD<Half<D>> lo) {
+ return Combine(d, Zero(Half<decltype(d)>()), lo);
+}
+
+// ------------------------------ Concat full (InterleaveLower)
+
+// hiH,hiL loH,loL |-> hiL,loL (= lower halves)
+template <class D, HWY_IF_V_SIZE_D(D, 16)>
+HWY_API VFromD<D> ConcatLowerLower(D d, VFromD<D> hi, VFromD<D> lo) {
+ const Repartition<uint64_t, decltype(d)> d64;
+ return BitCast(d, InterleaveLower(BitCast(d64, lo), BitCast(d64, hi)));
+}
+
+// hiH,hiL loH,loL |-> hiH,loH (= upper halves)
+template <class D, HWY_IF_V_SIZE_D(D, 16)>
+HWY_API VFromD<D> ConcatUpperUpper(D d, VFromD<D> hi, VFromD<D> lo) {
+ const Repartition<uint64_t, decltype(d)> d64;
+ return BitCast(d, InterleaveUpper(d64, BitCast(d64, lo), BitCast(d64, hi)));
+}
+
+// hiH,hiL loH,loL |-> hiL,loH (= inner halves)
+template <class D, HWY_IF_V_SIZE_D(D, 16)>
+HWY_API VFromD<D> ConcatLowerUpper(D d, VFromD<D> hi, VFromD<D> lo) {
+ return CombineShiftRightBytes<8>(d, hi, lo);
+}
+
+// hiH,hiL loH,loL |-> hiH,loL (= outer halves)
+template <class D, HWY_IF_V_SIZE_D(D, 16)>
+HWY_API VFromD<D> ConcatUpperLower(D d, VFromD<D> hi, VFromD<D> lo) {
+ return BitCast(d, Vec128<uint8_t>{__lsx_vshuf4i_d(
+ reinterpret_cast<__m128i>(lo.raw),
+ reinterpret_cast<__m128i>(hi.raw), 0xC)});
+}
+
+// ------------------------------ Concat partial (Combine, LowerHalf)
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 8)>
+HWY_API VFromD<D> ConcatLowerLower(D d, VFromD<D> hi, VFromD<D> lo) {
+ const Half<decltype(d)> d2;
+ return Combine(d, LowerHalf(d2, hi), LowerHalf(d2, lo));
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 8)>
+HWY_API VFromD<D> ConcatUpperUpper(D d, VFromD<D> hi, VFromD<D> lo) {
+ const Half<decltype(d)> d2;
+ return Combine(d, UpperHalf(d2, hi), UpperHalf(d2, lo));
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 8)>
+HWY_API VFromD<D> ConcatLowerUpper(D d, const VFromD<D> hi,
+ const VFromD<D> lo) {
+ const Half<decltype(d)> d2;
+ return Combine(d, LowerHalf(d2, hi), UpperHalf(d2, lo));
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 8)>
+HWY_API VFromD<D> ConcatUpperLower(D d, VFromD<D> hi, VFromD<D> lo) {
+ const Half<decltype(d)> d2;
+ return Combine(d, UpperHalf(d2, hi), LowerHalf(d2, lo));
+}
+
+// ------------------------------ ConcatOdd
+
+// 8-bit full
+template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_T_SIZE_D(D, 1)>
+HWY_API VFromD<D> ConcatOdd(D /* tag */, VFromD<D> hi, VFromD<D> lo) {
+ return VFromD<D>{__lsx_vpickod_b(hi.raw, lo.raw)};
+}
+// 8-bit x8
+template <class D, HWY_IF_V_SIZE_D(D, 8), HWY_IF_T_SIZE_D(D, 1)>
+HWY_API VFromD<D> ConcatOdd(D /* tag */, VFromD<D> hi, VFromD<D> lo) {
+ __m128i _tmp = __lsx_vpickod_b(hi.raw, lo.raw);
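+ // The valid odd bytes of lo and hi are in 32-bit words 0 and 2; move word 2
+ // down to word 1 so the eight result bytes are contiguous.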
+ return VFromD<D>{__lsx_vextrins_w(_tmp, _tmp, 0x12)};
+}
+// 8-bit x4
+template <class D, HWY_IF_V_SIZE_D(D, 4), HWY_IF_T_SIZE_D(D, 1)>
+HWY_API VFromD<D> ConcatOdd(D /* tag */, VFromD<D> hi, VFromD<D> lo) {
+ __m128i _tmp = __lsx_vpickod_b(hi.raw, lo.raw);
+ return VFromD<D>{__lsx_vextrins_h(_tmp, _tmp, 0x14)};
+}
+
+// 16-bit full
+template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_T_SIZE_D(D, 2)>
+HWY_API VFromD<D> ConcatOdd(D /* tag */, VFromD<D> hi, VFromD<D> lo) {
+ return VFromD<D>{__lsx_vpickod_h(hi.raw, lo.raw)};
+}
+// 16-bit x4
+template <class D, HWY_IF_V_SIZE_D(D, 8), HWY_IF_T_SIZE_D(D, 2)>
+HWY_API VFromD<D> ConcatOdd(D /* tag */, VFromD<D> hi, VFromD<D> lo) {
+ __m128i _tmp = __lsx_vpickod_h(hi.raw, lo.raw);
+ return VFromD<D>{__lsx_vextrins_w(_tmp, _tmp, 0x12)};
+}
+
+// 32-bit full
+template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_T_SIZE_D(D, 4)>
+HWY_API VFromD<D> ConcatOdd(D d, VFromD<D> hi, VFromD<D> lo) {
+ return BitCast(
+ d, Vec128<uint8_t>{__lsx_vpickod_w(reinterpret_cast<__m128i>(hi.raw),
+ reinterpret_cast<__m128i>(lo.raw))});
+}
+
+// Any T x2
+template <class D, HWY_IF_LANES_D(D, 2)>
+HWY_API VFromD<D> ConcatOdd(D d, VFromD<D> hi, VFromD<D> lo) {
+ return InterleaveUpper(d, lo, hi);
+}
+
+// ------------------------------ ConcatEven
+
+// 8-bit full
+template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_T_SIZE_D(D, 1)>
+HWY_API VFromD<D> ConcatEven(D /* tag */, VFromD<D> hi, VFromD<D> lo) {
+ return VFromD<D>{__lsx_vpickev_b(hi.raw, lo.raw)};
+}
+// 8-bit x8
+template <class D, HWY_IF_V_SIZE_D(D, 8), HWY_IF_T_SIZE_D(D, 1)>
+HWY_API VFromD<D> ConcatEven(D /* tag */, VFromD<D> hi, VFromD<D> lo) {
+ __m128i _tmp = __lsx_vpickev_b(hi.raw, lo.raw);
+ return VFromD<D>{__lsx_vextrins_w(_tmp, _tmp, 0x12)};
+}
+// 8-bit x4
+template <class D, HWY_IF_V_SIZE_D(D, 4), HWY_IF_T_SIZE_D(D, 1)>
+HWY_API VFromD<D> ConcatEven(D /* tag */, VFromD<D> hi, VFromD<D> lo) {
+ __m128i _tmp = __lsx_vpickev_b(hi.raw, lo.raw);
+ return VFromD<D>{__lsx_vextrins_h(_tmp, _tmp, 0x14)};
+}
+
+// 16-bit full
+template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_T_SIZE_D(D, 2)>
+HWY_API VFromD<D> ConcatEven(D /* tag */, VFromD<D> hi, VFromD<D> lo) {
+ return VFromD<D>{__lsx_vpickev_h(hi.raw, lo.raw)};
+}
+// 16-bit x4
+template <class D, HWY_IF_V_SIZE_D(D, 8), HWY_IF_T_SIZE_D(D, 2)>
+HWY_API VFromD<D> ConcatEven(D /* tag */, VFromD<D> hi, VFromD<D> lo) {
+ __m128i _tmp = __lsx_vpickev_h(hi.raw, lo.raw);
+ return VFromD<D>{__lsx_vextrins_w(_tmp, _tmp, 0x12)};
+}
+
+// 32-bit full
+template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_T_SIZE_D(D, 4)>
+HWY_API VFromD<D> ConcatEven(D d, VFromD<D> hi, VFromD<D> lo) {
+ return BitCast(
+ d, Vec128<uint8_t>{__lsx_vpickev_w(reinterpret_cast<__m128i>(hi.raw),
+ reinterpret_cast<__m128i>(lo.raw))});
+}
+
+// Any T x2
+template <class D, HWY_IF_LANES_D(D, 2)>
+HWY_API VFromD<D> ConcatEven(D d, VFromD<D> hi, VFromD<D> lo) {
+ return InterleaveLower(d, lo, hi);
+}
+
+template <size_t N>
+HWY_INLINE Vec128<float16_t, N> ConcatEven(Vec128<float16_t, N> hi,
+ Vec128<float16_t, N> lo) {
+ const DFromV<decltype(hi)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(d, ConcatEven(BitCast(du, hi), BitCast(du, lo)));
+}
+
+// ------------------------------ DupEven (InterleaveLower)
+
+template <typename T>
+HWY_API Vec128<T, 1> DupEven(const Vec128<T, 1> v) {
+ return v;
+}
+
+template <typename T, size_t N, HWY_IF_T_SIZE(T, 1)>
+HWY_API Vec128<T, N> DupEven(const Vec128<T, N> v) {
+ __m128i _tmp = __lsx_vpickev_b(v.raw, v.raw);
+ return Vec128<T, N>{__lsx_vilvl_b(_tmp, _tmp)};
+}
+
+template <typename T, size_t N, HWY_IF_T_SIZE(T, 2)>
+HWY_API Vec128<T, N> DupEven(const Vec128<T, N> v) {
+ const DFromV<decltype(v)> d;
+ const RebindToUnsigned<decltype(d)> du; // for float16_t
+ __m128i _tmp = __lsx_vpickev_h(BitCast(du, v).raw, BitCast(du, v).raw);
+ return BitCast(d, VFromD<decltype(du)>{__lsx_vilvl_h(_tmp, _tmp)});
+}
+
+template <typename T, size_t N, HWY_IF_T_SIZE(T, 4)>
+HWY_API Vec128<T, N> DupEven(const Vec128<T, N> v) {
+ const DFromV<decltype(v)> d;
+ __m128i _tmp = detail::BitCastToInteger(v.raw);
+ __m128i _tmp1 = __lsx_vpickev_w(_tmp, _tmp);
+ return BitCast(d, Vec128<uint32_t, N>{__lsx_vilvl_w(_tmp1, _tmp1)});
+}
+
+template <typename T, size_t N, HWY_IF_T_SIZE(T, 8)>
+HWY_API Vec128<T, N> DupEven(Vec128<T, N> v) {
+ return InterleaveLower(DFromV<decltype(v)>(), v, v);
+}
+
+// ------------------------------ DupOdd (InterleaveUpper)
+
+template <typename T, HWY_IF_T_SIZE(T, 1)>
+HWY_API Vec128<T, 1> DupOdd(Vec128<T, 1> v) {
+ return v;
+}
+
+template <typename T, size_t N, HWY_IF_T_SIZE(T, 1)>
+HWY_API Vec128<T, N> DupOdd(const Vec128<T, N> v) {
+ __m128i _tmp = __lsx_vpickod_b(v.raw, v.raw);
+ return Vec128<T, N>{__lsx_vilvl_b(_tmp, _tmp)};
+}
+
+template <typename T, size_t N, HWY_IF_T_SIZE(T, 2)>
+HWY_API Vec128<T, N> DupOdd(const Vec128<T, N> v) {
+ __m128i _tmp = __lsx_vpickod_h(v.raw, v.raw);
+ return Vec128<T, N>{__lsx_vilvl_h(_tmp, _tmp)};
+}
+
+template <typename T, size_t N, HWY_IF_T_SIZE(T, 4)>
+HWY_API Vec128<T, N> DupOdd(const Vec128<T, N> v) {
+ const DFromV<decltype(v)> d;
+ __m128i _tmp = detail::BitCastToInteger(v.raw);
+ __m128i _tmp1 = __lsx_vpickod_w(_tmp, _tmp);
+ return BitCast(d, Vec128<uint32_t, N>{__lsx_vilvl_w(_tmp1, _tmp1)});
+}
+
+template <typename T, size_t N, HWY_IF_T_SIZE(T, 8)>
+HWY_API Vec128<T, N> DupOdd(Vec128<T, N> v) {
+ return InterleaveUpper(DFromV<decltype(v)>(), v, v);
+}
+
+// ------------------------------ TwoTablesLookupLanes (DupEven)
+
+template <typename T, size_t N, HWY_IF_V_SIZE_LE(T, N, 8)>
+HWY_API Vec128<T, N> TwoTablesLookupLanes(Vec128<T, N> a, Vec128<T, N> b,
+ Indices128<T, N> idx) {
+ const DFromV<decltype(a)> d;
+ const Twice<decltype(d)> dt;
+ const Repartition<uint8_t, decltype(dt)> dt_u8;
+// TableLookupLanes currently requires table and index vectors to be the same
+// size, though a half-length index vector would be sufficient here.
+#if HWY_IS_MSAN
+ const Vec128<T, N> idx_vec{idx.raw};
+ const Indices128<T, N * 2> idx2{Combine(dt, idx_vec, idx_vec).raw};
+#else
+ // We only keep LowerHalf of the result, which is valid in idx.
+ const Indices128<T, N * 2> idx2{idx.raw};
+#endif
+ return LowerHalf(
+ d, TableLookupBytes(Combine(dt, b, a),
+ BitCast(dt, VFromD<decltype(dt_u8)>{idx2.raw})));
+}
+
+template <typename T, HWY_IF_UI8(T)>
+HWY_API Vec128<T> TwoTablesLookupLanes(Vec128<T> a, Vec128<T> b,
+ Indices128<T> idx) {
+ return Vec128<T>{__lsx_vshuf_b(b.raw, a.raw, idx.raw)};
+}
+
+template <typename T, HWY_IF_T_SIZE_ONE_OF(T, ((1 << 2) | (1 << 4) | (1 << 8)))>
+HWY_API Vec128<T> TwoTablesLookupLanes(Vec128<T> a, Vec128<T> b,
+ Indices128<T> idx) {
+ const DFromV<decltype(a)> d;
+ const Repartition<uint8_t, decltype(d)> du8;
+ return BitCast(d, TwoTablesLookupLanes(BitCast(du8, a), BitCast(du8, b),
+ Indices128<uint8_t>{idx.raw}));
+}
+
+// ------------------------------ OddEven
+
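+// OddEven(a, b) returns the odd-indexed lanes of a and the even-indexed lanes
+// of b: pack the odd lanes of a down to even positions, then interleave them
+// with the even lanes of b.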
+template <typename T, size_t N, HWY_IF_UI8(T)>
+HWY_INLINE Vec128<T, N> OddEven(const Vec128<T, N> a, const Vec128<T, N> b) {
+ __m128i t0 = __lsx_vpackod_b(a.raw, a.raw);
+ return Vec128<T, N>{__lsx_vpackev_b(t0, b.raw)};
+}
+template <typename T, size_t N, HWY_IF_UI16(T)>
+HWY_INLINE Vec128<T, N> OddEven(const Vec128<T, N> a, const Vec128<T, N> b) {
+ __m128i t0 = __lsx_vpackod_h(a.raw, a.raw);
+ return Vec128<T, N>{__lsx_vpackev_h(t0, b.raw)};
+}
+template <typename T, size_t N, HWY_IF_T_SIZE(T, 4)>
+HWY_INLINE Vec128<T, N> OddEven(const Vec128<T, N> a, const Vec128<T, N> b) {
+ const DFromV<decltype(a)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ __m128i t0 = __lsx_vpackod_w(BitCast(du, a).raw, BitCast(du, a).raw);
+ return BitCast(d,
+ VFromD<decltype(du)>{__lsx_vpackev_w(t0, BitCast(du, b).raw)});
+}
+template <typename T, size_t N, HWY_IF_T_SIZE(T, 8)>
+HWY_INLINE Vec128<T, N> OddEven(const Vec128<T, N> a, const Vec128<T, N> b) {
+ const DFromV<decltype(a)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ return BitCast(d, VFromD<decltype(du)>{__lsx_vextrins_d(
+ BitCast(du, b).raw, BitCast(du, a).raw, 0x11)});
+}
+
+// ------------------------------ InterleaveEven
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_T_SIZE_D(D, 1)>
+HWY_API VFromD<D> InterleaveEven(D /*d*/, VFromD<D> a, VFromD<D> b) {
+ return VFromD<D>{__lsx_vpackev_b(b.raw, a.raw)};
+}
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_T_SIZE_D(D, 2)>
+HWY_API VFromD<D> InterleaveEven(D /*d*/, VFromD<D> a, VFromD<D> b) {
+ return VFromD<D>{__lsx_vpackev_h(b.raw, a.raw)};
+}
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_T_SIZE_D(D, 4)>
+HWY_API VFromD<D> InterleaveEven(D d, VFromD<D> a, VFromD<D> b) {
+ const RebindToSigned<D> di;
+ return BitCast(d, VFromD<decltype(di)>{__lsx_vpackev_w(BitCast(di, b).raw,
+ BitCast(di, a).raw)});
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_T_SIZE_D(D, 8)>
+HWY_API VFromD<D> InterleaveEven(D d, VFromD<D> a, VFromD<D> b) {
+ const RebindToSigned<D> di;
+ return BitCast(d, VFromD<decltype(di)>{__lsx_vpackev_d(BitCast(di, b).raw,
+ BitCast(di, a).raw)});
+}
+
+// ------------------------------ InterleaveOdd
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_T_SIZE_D(D, 1)>
+HWY_API VFromD<D> InterleaveOdd(D /*d*/, VFromD<D> a, VFromD<D> b) {
+ return VFromD<D>{__lsx_vpackod_b(b.raw, a.raw)};
+}
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_T_SIZE_D(D, 2)>
+HWY_API VFromD<D> InterleaveOdd(D /*d*/, VFromD<D> a, VFromD<D> b) {
+ return VFromD<D>{__lsx_vpackod_h(b.raw, a.raw)};
+}
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_T_SIZE_D(D, 4)>
+HWY_API VFromD<D> InterleaveOdd(D d, VFromD<D> a, VFromD<D> b) {
+ const RebindToSigned<D> di;
+ return BitCast(d, VFromD<decltype(di)>{__lsx_vpackod_w(BitCast(di, b).raw,
+ BitCast(di, a).raw)});
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_T_SIZE_D(D, 8)>
+HWY_API VFromD<D> InterleaveOdd(D d, VFromD<D> a, VFromD<D> b) {
+ const RebindToSigned<D> di;
+ return BitCast(d, VFromD<decltype(di)>{__lsx_vpackod_d(BitCast(di, b).raw,
+ BitCast(di, a).raw)});
+}
+
+// ------------------------------ OddEvenBlocks
+
+template <typename T, size_t N>
+HWY_API Vec128<T, N> OddEvenBlocks(Vec128<T, N> /* odd */, Vec128<T, N> even) {
+ return even;
+}
+
+// ------------------------------ SwapAdjacentBlocks
+
+template <typename T, size_t N>
+HWY_API Vec128<T, N> SwapAdjacentBlocks(Vec128<T, N> v) {
+ return v;
+}
+
+// ------------------------------ InterleaveEvenBlocks
+template <class D, class V = VFromD<D>, HWY_IF_V_SIZE_LE_D(D, 16)>
+HWY_API V InterleaveEvenBlocks(D, V a, V /*b*/) {
+ return a;
+}
+// ------------------------------ InterleaveOddBlocks
+template <class D, class V = VFromD<D>, HWY_IF_V_SIZE_LE_D(D, 16)>
+HWY_API V InterleaveOddBlocks(D, V a, V /*b*/) {
+ return a;
+}
+
+// ------------------------------ InterleaveLowerBlocks
+template <class D, class V = VFromD<D>, HWY_IF_V_SIZE_LE_D(D, 16)>
+HWY_API V InterleaveLowerBlocks(D, V a, V /*b*/) {
+ return a;
+}
+// ------------------------------ InterleaveUpperBlocks
+template <class D, class V = VFromD<D>, HWY_IF_V_SIZE_LE_D(D, 16)>
+HWY_API V InterleaveUpperBlocks(D, V a, V /*b*/) {
+ return a;
+}
+
+// ------------------------------ Shl
+
+template <typename T, size_t N, HWY_IF_UI8(T)>
+HWY_API Vec128<T, N> operator<<(Vec128<T, N> v, Vec128<T, N> bits) {
+ return Vec128<T, N>{__lsx_vsll_b(v.raw, bits.raw)};
+}
+
+template <typename T, size_t N, HWY_IF_UI16(T)>
+HWY_API Vec128<T, N> operator<<(Vec128<T, N> v, Vec128<T, N> bits) {
+ return Vec128<T, N>{__lsx_vsll_h(v.raw, bits.raw)};
+}
+
+template <typename T, size_t N, HWY_IF_UI32(T)>
+HWY_API Vec128<T, N> operator<<(Vec128<T, N> v, Vec128<T, N> bits) {
+ return Vec128<T, N>{__lsx_vsll_w(v.raw, bits.raw)};
+}
+
+template <typename T, size_t N, HWY_IF_UI64(T)>
+HWY_API Vec128<T, N> operator<<(Vec128<T, N> v, Vec128<T, N> bits) {
+ return Vec128<T, N>{__lsx_vsll_d(v.raw, bits.raw)};
+}
+
+// ------------------------------ Shr
+
+namespace detail {
+
+template <size_t N>
+HWY_API Vec128<uint8_t, N> Shr(Vec128<uint8_t, N> v, Vec128<uint8_t, N> bits) {
+ return Vec128<uint8_t, N>{__lsx_vsrl_b(v.raw, bits.raw)};
+}
+template <size_t N>
+HWY_API Vec128<uint16_t, N> Shr(Vec128<uint16_t, N> v,
+ Vec128<uint16_t, N> bits) {
+ return Vec128<uint16_t, N>{__lsx_vsrl_h(v.raw, bits.raw)};
+}
+template <size_t N>
+HWY_API Vec128<uint32_t, N> Shr(Vec128<uint32_t, N> v,
+ Vec128<uint32_t, N> bits) {
+ return Vec128<uint32_t, N>{__lsx_vsrl_w(v.raw, bits.raw)};
+}
+template <size_t N>
+HWY_API Vec128<uint64_t, N> Shr(Vec128<uint64_t, N> v,
+ Vec128<uint64_t, N> bits) {
+ return Vec128<uint64_t, N>{__lsx_vsrl_d(v.raw, bits.raw)};
+}
+
+template <size_t N>
+HWY_API Vec128<int8_t, N> Shr(Vec128<int8_t, N> v, Vec128<int8_t, N> bits) {
+ return Vec128<int8_t, N>{__lsx_vsra_b(v.raw, bits.raw)};
+}
+template <size_t N>
+HWY_API Vec128<int16_t, N> Shr(Vec128<int16_t, N> v, Vec128<int16_t, N> bits) {
+ return Vec128<int16_t, N>{__lsx_vsra_h(v.raw, bits.raw)};
+}
+template <size_t N>
+HWY_API Vec128<int32_t, N> Shr(Vec128<int32_t, N> v, Vec128<int32_t, N> bits) {
+ return Vec128<int32_t, N>{__lsx_vsra_w(v.raw, bits.raw)};
+}
+template <size_t N>
+HWY_API Vec128<int64_t, N> Shr(Vec128<int64_t, N> v, Vec128<int64_t, N> bits) {
+ return Vec128<int64_t, N>{__lsx_vsra_d(v.raw, bits.raw)};
+}
+
+} // namespace detail
+
+template <typename T, size_t N>
+HWY_API Vec128<T, N> operator>>(Vec128<T, N> v, Vec128<T, N> bits) {
+ return detail::Shr(v, bits);
+}
+
+// ================================================== CONVERT (2)
+
+// ------------------------------ PromoteEvenTo/PromoteOddTo
+#include "third_party/highway/hwy/ops/inside-inl.h"
+
+// Generic for all vector lengths.
+template <class DF, HWY_IF_F32_D(DF),
+ class VBF = VFromD<Repartition<bfloat16_t, DF>>>
+HWY_API VFromD<DF> WidenMulPairwiseAdd(DF df, VBF a, VBF b) {
+ return MulAdd(PromoteEvenTo(df, a), PromoteEvenTo(df, b),
+ Mul(PromoteOddTo(df, a), PromoteOddTo(df, b)));
+}
+
+template <class D32, HWY_IF_I32_D(D32), HWY_IF_V_SIZE_LE_D(D32, 16),
+ class V16 = VFromD<RepartitionToNarrow<D32>>>
+HWY_API VFromD<D32> WidenMulPairwiseAdd(D32 /* tag */, V16 a, V16 b) {
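+ // Widening multiply of the even 16-bit lanes, then multiply-accumulate the
+ // odd lanes: each i32 lane is a[2i] * b[2i] + a[2i+1] * b[2i+1].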
+ __m128i _tmp = __lsx_vmulwev_w_h(a.raw, b.raw);
+ return VFromD<D32>{__lsx_vmaddwod_w_h(_tmp, a.raw, b.raw)};
+}
+
+template <class DU32, HWY_IF_U32_D(DU32), HWY_IF_V_SIZE_LE_D(DU32, 16),
+ class VU16 = VFromD<RepartitionToNarrow<DU32>>>
+HWY_API VFromD<DU32> WidenMulPairwiseAdd(DU32 /* tag */, VU16 a, VU16 b) {
+ __m128i _tmp = __lsx_vmulwev_w_hu(a.raw, b.raw);
+ return VFromD<DU32>{__lsx_vmaddwod_w_hu(_tmp, a.raw, b.raw)};
+}
+
+// ------------------------------ ReorderWidenMulAccumulate
+
+template <class D32, HWY_IF_I32_D(D32), HWY_IF_V_SIZE_LE_D(D32, 16),
+ class V16 = VFromD<RepartitionToNarrow<D32>>>
+HWY_API VFromD<D32> ReorderWidenMulAccumulate(D32 /* tag */, V16 a, V16 b,
+ const VFromD<D32> sum0,
+ VFromD<D32>& /* sum1 */) {
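+ // Accumulate both the even and odd widening products into sum0; sum1 is
+ // unused because the odd+even invariant already holds, see
+ // RearrangeToOddPlusEven.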
+ return VFromD<D32>{__lsx_vmaddwev_w_h(
+ __lsx_vmaddwod_w_h(sum0.raw, a.raw, b.raw), a.raw, b.raw)};
+}
+
+template <class DU32, HWY_IF_U32_D(DU32),
+ class VU16 = VFromD<RepartitionToNarrow<DU32>>>
+HWY_API VFromD<DU32> ReorderWidenMulAccumulate(DU32 /* tag */, VU16 a, VU16 b,
+ const VFromD<DU32> sum0,
+ VFromD<DU32>& /* sum1 */) {
+ return VFromD<DU32>{__lsx_vmaddwev_w_hu(
+ __lsx_vmaddwod_w_hu(sum0.raw, a.raw, b.raw), a.raw, b.raw)};
+}
+
+// ------------------------------ RearrangeToOddPlusEven
+template <size_t N>
+HWY_API Vec128<int32_t, N> RearrangeToOddPlusEven(const Vec128<int32_t, N> sum0,
+ Vec128<int32_t, N> /*sum1*/) {
+ return sum0; // invariant already holds
+}
+
+template <size_t N>
+HWY_API Vec128<uint32_t, N> RearrangeToOddPlusEven(
+ const Vec128<uint32_t, N> sum0, Vec128<uint32_t, N> /*sum1*/) {
+ return sum0; // invariant already holds
+}
+
+// ------------------------------ Demotions
+
+// NOTE: hwy::EnableIf<!hwy::IsSame<V, V>()>* = nullptr is used instead of
+// hwy::EnableIf<false>* = nullptr to avoid compiler errors: although
+// !hwy::IsSame<V, V>() is always false, it depends on the template argument V,
+// so it triggers SFINAE rather than a hard error.
+#undef HWY_IF_U2I_DEMOTE_FROM_LANE_SIZE_V
+#define HWY_IF_U2I_DEMOTE_FROM_LANE_SIZE_V(V) \
+ hwy::EnableIf<!hwy::IsSame<V, V>()>* = nullptr
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 8), HWY_IF_I8_D(D)>
+HWY_API VFromD<D> DemoteTo(D /* tag */, VFromD<Rebind<int16_t, D>> v) {
+ return VFromD<D>{__lsx_vssrani_b_h(v.raw, v.raw, 0)};
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 8), HWY_IF_U8_D(D)>
+HWY_API VFromD<D> DemoteTo(D /* tag */, VFromD<Rebind<int16_t, D>> v) {
+ return VFromD<D>{__lsx_vssrani_bu_h(v.raw, v.raw, 0)};
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 8), HWY_IF_I8_D(D)>
+HWY_API VFromD<D> DemoteTo(D /* tag */, VFromD<Rebind<uint16_t, D>> v) {
+ return VFromD<D>{__lsx_vssrlni_b_h(v.raw, v.raw, 0)};
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 8), HWY_IF_U8_D(D)>
+HWY_API VFromD<D> DemoteTo(D /* tag */, VFromD<Rebind<uint16_t, D>> v) {
+ return VFromD<D>{__lsx_vssrlni_bu_h(v.raw, v.raw, 0)};
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 8), HWY_IF_I16_D(D)>
+HWY_API VFromD<D> DemoteTo(D /* tag */, VFromD<Rebind<int32_t, D>> v) {
+ return VFromD<D>{__lsx_vssrani_h_w(v.raw, v.raw, 0)};
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 8), HWY_IF_U16_D(D)>
+HWY_API VFromD<D> DemoteTo(D /* tag */, VFromD<Rebind<int32_t, D>> v) {
+ return VFromD<D>{__lsx_vssrani_hu_w(v.raw, v.raw, 0)};
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 8), HWY_IF_I16_D(D)>
+HWY_API VFromD<D> DemoteTo(D /* tag */, VFromD<Rebind<uint32_t, D>> v) {
+ return VFromD<D>{__lsx_vssrlni_h_w(v.raw, v.raw, 0)};
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 8), HWY_IF_U16_D(D)>
+HWY_API VFromD<D> DemoteTo(D /* tag */, VFromD<Rebind<uint32_t, D>> v) {
+ return VFromD<D>{__lsx_vssrlni_hu_w(v.raw, v.raw, 0)};
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 8), HWY_IF_I32_D(D)>
+HWY_API VFromD<D> DemoteTo(D /* tag */, VFromD<Rebind<int64_t, D>> v) {
+ return VFromD<D>{__lsx_vssrani_w_d(v.raw, v.raw, 0)};
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 8), HWY_IF_U32_D(D)>
+HWY_API VFromD<D> DemoteTo(D /* tag */, VFromD<Rebind<int64_t, D>> v) {
+ return VFromD<D>{__lsx_vssrani_wu_d(v.raw, v.raw, 0)};
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 8), HWY_IF_I32_D(D)>
+HWY_API VFromD<D> DemoteTo(D /* tag */, VFromD<Rebind<uint64_t, D>> v) {
+ return VFromD<D>{__lsx_vssrlni_w_d(v.raw, v.raw, 0)};
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 8), HWY_IF_U32_D(D)>
+HWY_API VFromD<D> DemoteTo(D /* tag */, VFromD<Rebind<uint64_t, D>> v) {
+ return VFromD<D>{__lsx_vssrlni_wu_d(v.raw, v.raw, 0)};
+}
+
+// UI->UI DemoteTo for the case where sizeof(TFromD<D>) <= sizeof(TFromV<V>)/4;
+// generic for all vector lengths.
+template <class DN, class V, HWY_IF_NOT_FLOAT_NOR_SPECIAL_D(DN),
+ HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V),
+ HWY_IF_T_SIZE_LE_D(DN, sizeof(TFromV<V>) / 4)>
+HWY_API VFromD<DN> DemoteTo(DN dn, V v) {
+ using T = TFromV<V>;
+ using TN = TFromD<DN>;
+
+ using TDemoteTo =
+ MakeNarrow<If<IsSigned<T>() && IsSigned<TN>(), T, MakeUnsigned<T>>>;
+ return DemoteTo(dn, DemoteTo(Rebind<TDemoteTo, DN>(), v));
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 8), HWY_IF_F16_D(D)>
+HWY_API VFromD<D> DemoteTo(D /* tag */, VFromD<Rebind<float, D>> v) {
+ return VFromD<D>{__lsx_vfcvt_h_s(v.raw, v.raw)};
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 8), HWY_IF_F32_D(D)>
+HWY_API VFromD<D> DemoteTo(D /* tag */, VFromD<Rebind<double, D>> v) {
+ return VFromD<D>{__lsx_vfcvt_s_d(v.raw, v.raw)};
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 8), HWY_IF_I32_D(D)>
+HWY_API VFromD<D> DemoteTo(D /* tag */, VFromD<Rebind<double, D>> v) {
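+ // vftintrz_w_d narrows two f64 vectors into one i32 vector; the zero first
+ // operand supplies the (unused) upper two result lanes.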
+ return VFromD<D>{__lsx_vftintrz_w_d(
+ reinterpret_cast<__m128d>(__lsx_vreplgr2vr_w(0)), v.raw)};
+}
+
+template <class D, HWY_IF_U32_D(D)>
+HWY_API VFromD<D> DemoteTo(D du32, VFromD<Rebind<double, D>> v) {
+ const Rebind<uint64_t, decltype(du32)> du64;
+ return DemoteTo(du32, ConvertTo(du64, v));
+}
+
+template <class D, HWY_IF_F32_D(D)>
+HWY_API VFromD<D> DemoteTo(D df32, VFromD<Rebind<int64_t, D>> v) {
+ const Rebind<double, decltype(df32)> df64;
+ const RebindToUnsigned<decltype(df64)> du64;
+ const RebindToSigned<decltype(df32)> di32;
+ const RebindToUnsigned<decltype(df32)> du32;
+
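+ // Split the i64 into its upper 52 and lower 12 bits and convert each
+ // exactly: XORing the shifted bits into the mantissa of 2^64 + 2^63 and
+ // subtracting that constant recovers the signed upper bits. The compensated
+ // sum is rounded to odd so the final f64 -> f32 demote cannot double-round.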
+ const auto k2p64_63 = Set(df64, 27670116110564327424.0);
+ const auto f64_hi52 =
+ Xor(BitCast(df64, ShiftRight<12>(BitCast(du64, v))), k2p64_63) - k2p64_63;
+ const auto f64_lo12 =
+ PromoteTo(df64, BitCast(di32, And(TruncateTo(du32, BitCast(du64, v)),
+ Set(du32, uint32_t{0x00000FFF}))));
+
+ const auto f64_sum = f64_hi52 + f64_lo12;
+ const auto f64_carry = (f64_hi52 - f64_sum) + f64_lo12;
+
+ const auto f64_sum_is_inexact =
+ ShiftRight<63>(BitCast(du64, VecFromMask(df64, f64_carry != Zero(df64))));
+ const auto f64_bits_decrement =
+ And(ShiftRight<63>(BitCast(du64, Xor(f64_sum, f64_carry))),
+ f64_sum_is_inexact);
+
+ const auto adj_f64_val = BitCast(
+ df64,
+ Or(BitCast(du64, f64_sum) - f64_bits_decrement, f64_sum_is_inexact));
+
+ return DemoteTo(df32, adj_f64_val);
+}
+
+template <class D, HWY_IF_F32_D(D)>
+HWY_API VFromD<D> DemoteTo(D df32, VFromD<Rebind<uint64_t, D>> v) {
+ const Rebind<double, decltype(df32)> df64;
+ const RebindToUnsigned<decltype(df64)> du64;
+ const RebindToSigned<decltype(df32)> di32;
+ const RebindToUnsigned<decltype(df32)> du32;
+
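+ // Same scheme as the i64 -> f32 demotion above, but unsigned: the bias is
+ // 2^64 and Or suffices because there is no sign bit to fold in.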
+ const auto k2p64 = Set(df64, 18446744073709551616.0);
+ const auto f64_hi52 = Or(BitCast(df64, ShiftRight<12>(v)), k2p64) - k2p64;
+ const auto f64_lo12 =
+ PromoteTo(df64, BitCast(di32, And(TruncateTo(du32, BitCast(du64, v)),
+ Set(du32, uint32_t{0x00000FFF}))));
+
+ const auto f64_sum = f64_hi52 + f64_lo12;
+ const auto f64_carry = (f64_hi52 - f64_sum) + f64_lo12;
+ const auto f64_sum_is_inexact =
+ ShiftRight<63>(BitCast(du64, VecFromMask(df64, f64_carry != Zero(df64))));
+
+ const auto adj_f64_val = BitCast(
+ df64,
+ Or(BitCast(du64, f64_sum) - ShiftRight<63>(BitCast(du64, f64_carry)),
+ f64_sum_is_inexact));
+
+ return DemoteTo(df32, adj_f64_val);
+}
+
+// ------------------------------ ReorderDemote2To
+
+// ReorderDemote2To for 8-byte UI64->UI32, <= 4-byte UI32->UI16,
+// and <= 4-byte UI16->UI8
+template <class DN, class V,
+ HWY_IF_V_SIZE_LE_D(DN, ((sizeof(TFromD<DN>) <= 2 ? 4 : 8))),
+ HWY_IF_NOT_FLOAT_NOR_SPECIAL_D(DN), HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V),
+ HWY_IF_T_SIZE_V(V, sizeof(TFromD<DN>) * 2),
+ HWY_IF_LANES_D(DN, HWY_MAX_LANES_D(DFromV<V>) * 2)>
+HWY_API VFromD<DN> ReorderDemote2To(DN dn, V a, V b) {
+ const DFromV<decltype(a)> d;
+ const Twice<decltype(d)> dt;
+ return DemoteTo(dn, Combine(dt, b, a));
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_I8_D(D)>
+HWY_API VFromD<D> ReorderDemote2To(D /* tag */, Vec128<int16_t> a,
+ Vec128<int16_t> b) {
+ return VFromD<D>{__lsx_vssrani_b_h(b.raw, a.raw, 0)};
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_U8_D(D)>
+HWY_API VFromD<D> ReorderDemote2To(D /* tag */, Vec128<int16_t> a,
+ Vec128<int16_t> b) {
+ return VFromD<D>{__lsx_vssrani_bu_h(b.raw, a.raw, 0)};
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_I8_D(D)>
+HWY_API VFromD<D> ReorderDemote2To(D /* tag */, Vec128<uint16_t> a,
+ Vec128<uint16_t> b) {
+ return VFromD<D>{__lsx_vssrlni_b_h(b.raw, a.raw, 0)};
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_U8_D(D)>
+HWY_API VFromD<D> ReorderDemote2To(D /* tag */, Vec128<uint16_t> a,
+ Vec128<uint16_t> b) {
+ return VFromD<D>{__lsx_vssrlni_bu_h(b.raw, a.raw, 0)};
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_I16_D(D)>
+HWY_API VFromD<D> ReorderDemote2To(D /* tag */, Vec128<int32_t> a,
+ Vec128<int32_t> b) {
+ return VFromD<D>{__lsx_vssrani_h_w(b.raw, a.raw, 0)};
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_U16_D(D)>
+HWY_API VFromD<D> ReorderDemote2To(D /* tag */, Vec128<int32_t> a,
+ Vec128<int32_t> b) {
+ return VFromD<D>{__lsx_vssrani_hu_w(b.raw, a.raw, 0)};
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_I16_D(D)>
+HWY_API VFromD<D> ReorderDemote2To(D /* tag */, Vec128<uint32_t> a,
+ Vec128<uint32_t> b) {
+ return VFromD<D>{__lsx_vssrlni_h_w(b.raw, a.raw, 0)};
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_U16_D(D)>
+HWY_API VFromD<D> ReorderDemote2To(D /* tag */, Vec128<uint32_t> a,
+ Vec128<uint32_t> b) {
+ return VFromD<D>{__lsx_vssrlni_hu_w(b.raw, a.raw, 0)};
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_I32_D(D)>
+HWY_API VFromD<D> ReorderDemote2To(D /* tag */, Vec128<int64_t> a,
+ Vec128<int64_t> b) {
+ return VFromD<D>{__lsx_vssrani_w_d(b.raw, a.raw, 0)};
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_U32_D(D)>
+HWY_API VFromD<D> ReorderDemote2To(D /* tag */, Vec128<int64_t> a,
+ Vec128<int64_t> b) {
+ return VFromD<D>{__lsx_vssrani_wu_d(b.raw, a.raw, 0)};
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_I32_D(D)>
+HWY_API VFromD<D> ReorderDemote2To(D /* tag */, Vec128<uint64_t> a,
+ Vec128<uint64_t> b) {
+ return VFromD<D>{__lsx_vssrlni_w_d(b.raw, a.raw, 0)};
+}
+
+template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_U32_D(D)>
+HWY_API VFromD<D> ReorderDemote2To(D /* tag */, Vec128<uint64_t> a,
+ Vec128<uint64_t> b) {
+ return VFromD<D>{__lsx_vssrlni_wu_d(b.raw, a.raw, 0)};
+}
+
+// 8-byte UI32->UI16 and UI16->UI8 ReorderDemote2To
+template <class DN, class V, HWY_IF_V_SIZE_D(DN, 8),
+ HWY_IF_NOT_FLOAT_NOR_SPECIAL_D(DN), HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V),
+ HWY_IF_T_SIZE_LE_D(DN, 2), HWY_IF_T_SIZE_V(V, sizeof(TFromD<DN>) * 2),
+ HWY_IF_LANES_D(DN, HWY_MAX_LANES_D(DFromV<V>) * 2)>
+HWY_API VFromD<DN> ReorderDemote2To(DN dn, V a, V b) {
+ const Twice<DFromV<V>> dt;
+ const Twice<decltype(dn)> dt_n;
+
+ const auto demote2_result =
+ ReorderDemote2To(dt_n, ResizeBitCast(dt, a), ResizeBitCast(dt, b));
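+ // The valid halves are in 32-bit words 0 and 2; vshuf4i_w with 0x88 gathers
+ // them into the lower 8 bytes.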
+ return VFromD<DN>{__lsx_vshuf4i_w(demote2_result.raw, 0x88)};
+}
+
+template <class D, class V, HWY_IF_V_SIZE_LE_D(D, 16),
+ HWY_IF_NOT_FLOAT_NOR_SPECIAL(TFromD<D>),
+ HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V),
+ HWY_IF_T_SIZE_V(V, sizeof(TFromD<D>) * 2),
+ HWY_IF_LANES_D(D, HWY_MAX_LANES_D(DFromV<V>) * 2)>
+HWY_API VFromD<D> OrderedDemote2To(D d, V a, V b) {
+ return ReorderDemote2To(d, a, b);
+}
+
+template <size_t N>
+HWY_API Vec128<uint8_t, N> U8FromU32(const Vec128<uint32_t, N> v) {
+ const DFromV<decltype(v)> du32;
+ const Rebind<uint8_t, decltype(du32)> du8;
+ return DemoteTo(du8, BitCast(du32, v));
+}
+
+// ------------------------------ F32->UI64 PromoteTo
+
+// f32 -> i64
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_I64_D(D)>
+HWY_API VFromD<D> PromoteTo(D /*di64*/, VFromD<Rebind<float, D>> v) {
+ return VFromD<D>{__lsx_vftintrzl_l_s(v.raw)};
+}
+
+// F32->U64 PromoteTo generic for all vector lengths
+template <class D, HWY_IF_U64_D(D)>
+HWY_API VFromD<D> PromoteTo(D du64, VFromD<Rebind<float, D>> v) {
+ const RebindToFloat<decltype(du64)> df64;
+ return ConvertTo(du64, PromoteTo(df64, v));
+}
+
+// ------------------------------ MulFixedPoint15
+
+template <size_t N>
+HWY_API Vec128<int16_t, N> MulFixedPoint15(const Vec128<int16_t, N> a,
+ const Vec128<int16_t, N> b) {
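+ // Widening multiplies of even/odd lanes, interleaved back into lane order,
+ // then a rounding saturating narrowing shift by 15 computes
+ // (a * b + 16384) >> 15.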
+ __m128i temp_ev = __lsx_vmulwev_w_h(a.raw, b.raw);
+ __m128i temp_od = __lsx_vmulwod_w_h(a.raw, b.raw);
+ __m128i temp1 = __lsx_vilvl_w(temp_od, temp_ev);
+ __m128i temp2 = __lsx_vilvh_w(temp_od, temp_ev);
+ return Vec128<int16_t, N>{__lsx_vssrarni_h_w(temp2, temp1, 15)};
+}
+
+// ------------------------------ Truncations
+
+template <typename From, class DTo, HWY_IF_LANES_D(DTo, 1)>
+HWY_API VFromD<DTo> TruncateTo(DTo /* tag */, Vec128<From, 1> v) {
+ const Repartition<TFromD<DTo>, DFromV<decltype(v)>> dto;
+ return VFromD<DTo>{BitCast(dto, v).raw};
+}
+
+template <class D, HWY_IF_U8_D(D)>
+HWY_API Vec16<uint8_t> TruncateTo(D /* tag */, Vec128<uint64_t> v) {
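+ // vextrins_b with 0x18 copies byte 8 (the low byte of lane 1) into byte 1,
+ // so the lowest two bytes are the truncated lanes.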
+ return Vec16<uint8_t>{__lsx_vextrins_b(v.raw, v.raw, 0x18)};
+}
+
+template <class D, HWY_IF_U16_D(D)>
+HWY_API Vec32<uint16_t> TruncateTo(D /* tag */, Vec128<uint64_t> v) {
+ return Vec32<uint16_t>{__lsx_vextrins_h(v.raw, v.raw, 0x14)};
+}
+
+template <class D, HWY_IF_U32_D(D)>
+HWY_API Vec64<uint32_t> TruncateTo(D /* tag */, Vec128<uint64_t> v) {
+ return Vec64<uint32_t>{__lsx_vpickev_w(v.raw, v.raw)};
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 4), HWY_IF_U8_D(D)>
+HWY_API VFromD<D> TruncateTo(D /* tag */, VFromD<Rebind<uint32_t, D>> v) {
+ __m128i v_ev = __lsx_vpickev_b(v.raw, v.raw);
+ return VFromD<D>{__lsx_vpickev_b(v_ev, v_ev)};
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 8), HWY_IF_U16_D(D)>
+HWY_API VFromD<D> TruncateTo(D /* tag */, VFromD<Rebind<uint32_t, D>> v) {
+ return VFromD<D>{__lsx_vpickev_h(v.raw, v.raw)};
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 8), HWY_IF_U8_D(D)>
+HWY_API VFromD<D> TruncateTo(D /* tag */, VFromD<Rebind<uint16_t, D>> v) {
+ return VFromD<D>{__lsx_vpickev_b(v.raw, v.raw)};
+}
+
+// ------------------------------ int -> float ConvertTo
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_F32_D(D)>
+HWY_API VFromD<D> ConvertTo(D /* tag */, VFromD<Rebind<int32_t, D>> v) {
+ return VFromD<D>{__lsx_vffint_s_w(v.raw)};
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_F32_D(D)>
+HWY_API VFromD<D> ConvertTo(D /* tag */, VFromD<Rebind<uint32_t, D>> v) {
+ return VFromD<D>{__lsx_vffint_s_wu(v.raw)};
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_F64_D(D)>
+HWY_API VFromD<D> ConvertTo(D /* tag */, VFromD<Rebind<int64_t, D>> v) {
+ return VFromD<D>{__lsx_vffint_d_l(v.raw)};
+}
+
+// ------------------------------ float -> int ConvertTo
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_F64_D(D)>
+HWY_API VFromD<D> ConvertTo(D /* tag */, VFromD<Rebind<uint64_t, D>> v) {
+ return VFromD<D>{__lsx_vffint_d_lu(v.raw)};
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_I32_D(D)>
+HWY_API VFromD<D> ConvertTo(D /* tag */, VFromD<Rebind<float, D>> v) {
+ return VFromD<D>{__lsx_vftintrz_w_s(v.raw)};
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_U32_D(D)>
+HWY_API VFromD<D> ConvertTo(D /* tag */, VFromD<Rebind<float, D>> v) {
+ return VFromD<D>{__lsx_vftintrz_wu_s(v.raw)};
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_I64_D(D)>
+HWY_API VFromD<D> ConvertTo(D /* tag */, VFromD<Rebind<double, D>> v) {
+ return VFromD<D>{__lsx_vftintrz_l_d(v.raw)};
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_U64_D(D)>
+HWY_API VFromD<D> ConvertTo(D /* tag */, VFromD<Rebind<double, D>> v) {
+ return VFromD<D>{__lsx_vftintrz_lu_d(v.raw)};
+}
+
+// ------------------------------ NearestInt (Round)
+
+template <size_t N>
+HWY_API Vec128<int32_t, N> NearestInt(const Vec128<float, N> v) {
+ return Vec128<int32_t, N>{__lsx_vftintrne_w_s(v.raw)};
+}
+
+template <size_t N>
+HWY_API Vec128<int64_t, N> NearestInt(const Vec128<double, N> v) {
+ return Vec128<int64_t, N>{__lsx_vftintrne_l_d(v.raw)};
+}
+
+template <class DI32, HWY_IF_I32_D(DI32)>
+HWY_API VFromD<DI32> DemoteToNearestInt(DI32 di32,
+ VFromD<Rebind<double, DI32>> v) {
+ return DemoteTo(di32, NearestInt(v));
+}
+
+// ------------------------------ Floating-point rounding
+
+template <size_t N>
+HWY_API Vec128<float, N> Round(const Vec128<float, N> v) {
+ return Vec128<float, N>{__lsx_vfrintrne_s(v.raw)};
+}
+template <size_t N>
+HWY_API Vec128<double, N> Round(const Vec128<double, N> v) {
+ return Vec128<double, N>{__lsx_vfrintrne_d(v.raw)};
+}
+template <size_t N>
+HWY_API Vec128<float, N> Trunc(const Vec128<float, N> v) {
+ return Vec128<float, N>{__lsx_vfrintrz_s(v.raw)};
+}
+template <size_t N>
+HWY_API Vec128<double, N> Trunc(const Vec128<double, N> v) {
+ return Vec128<double, N>{__lsx_vfrintrz_d(v.raw)};
+}
+template <size_t N>
+HWY_API Vec128<float, N> Ceil(const Vec128<float, N> v) {
+ return Vec128<float, N>{__lsx_vfrintrp_s(v.raw)};
+}
+template <size_t N>
+HWY_API Vec128<double, N> Ceil(const Vec128<double, N> v) {
+ return Vec128<double, N>{__lsx_vfrintrp_d(v.raw)};
+}
+// Toward -infinity, aka floor
+template <size_t N>
+HWY_API Vec128<float, N> Floor(const Vec128<float, N> v) {
+ return Vec128<float, N>{__lsx_vfrintrm_s(v.raw)};
+}
+template <size_t N>
+HWY_API Vec128<double, N> Floor(const Vec128<double, N> v) {
+ return Vec128<double, N>{__lsx_vfrintrm_d(v.raw)};
+}
+
+// ------------------------------ Floating-point classification
+
+// FIXME: disable gcc-14 tree-based loop optimizations to prevent
+// 'HighwayTestGroup/HighwayTest.TestAllIsNaN/LSX' failures
+#if HWY_COMPILER_GCC && !HWY_COMPILER_CLANG
+#pragma GCC push_options
+#pragma GCC optimize("-fno-tree-loop-optimize")
+#endif
+
+template <size_t N>
+HWY_API Mask128<float, N> IsNaN(const Vec128<float, N> v) {
+ return Mask128<float, N>{
+ reinterpret_cast<__m128>(__lsx_vfcmp_cune_s(v.raw, v.raw))};
+}
+
+template <size_t N>
+HWY_API Mask128<double, N> IsNaN(const Vec128<double, N> v) {
+ return Mask128<double, N>{
+ reinterpret_cast<__m128d>(__lsx_vfcmp_cune_d(v.raw, v.raw))};
+}
+
+#if HWY_COMPILER_GCC && !HWY_COMPILER_CLANG
+#pragma GCC pop_options
+#endif
+
+#ifdef HWY_NATIVE_IS_EITHER_NAN
+#undef HWY_NATIVE_IS_EITHER_NAN
+#else
+#define HWY_NATIVE_IS_EITHER_NAN
+#endif
+
+template <size_t N>
+HWY_API Mask128<float, N> IsEitherNaN(Vec128<float, N> a, Vec128<float, N> b) {
+ return Mask128<float, N>{
+ reinterpret_cast<__m128>(__lsx_vfcmp_cun_s(a.raw, b.raw))};
+}
+
+template <size_t N>
+HWY_API Mask128<double, N> IsEitherNaN(Vec128<double, N> a,
+ Vec128<double, N> b) {
+ __m128i _tmp = __lsx_vor_v(__lsx_vfcmp_cune_d(a.raw, a.raw),
+ __lsx_vfcmp_cune_d(b.raw, b.raw));
+ return Mask128<double, N>{reinterpret_cast<__m128d>(_tmp)};
+}
+
+#ifdef HWY_NATIVE_ISINF
+#undef HWY_NATIVE_ISINF
+#else
+#define HWY_NATIVE_ISINF
+#endif
+
+template <class V>
+HWY_API MFromD<DFromV<V>> IsInf(V v) {
+ using T = TFromV<V>;
+
+ static_assert(IsFloat<T>(), "Only for float");
+ using TU = MakeUnsigned<T>;
+ const DFromV<decltype(v)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ const VFromD<decltype(du)> vu = BitCast(du, v);
+ // 'Shift left' to clear the sign bit, check for exponent=max and
+ // mantissa=0.
+ return RebindMask(
+ d,
+ Eq(Add(vu, vu), Set(du, static_cast<TU>(hwy::MaxExponentTimes2<T>()))));
+}
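+// E.g. for f32, Inf is 0x7F800000; doubling gives 0xFF000000, which equals
+// MaxExponentTimes2<float>(). NaNs double to larger values, finite to smaller.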
+
+// Returns whether normal/subnormal/zero.
+template <class V>
+HWY_API MFromD<DFromV<V>> IsFinite(V v) {
+ using T = TFromV<V>;
+
+ static_assert(IsFloat<T>(), "Only for float");
+ using TU = MakeUnsigned<T>;
+ const DFromV<decltype(v)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ const VFromD<decltype(du)> vu = BitCast(du, v);
+ // 'Shift left' to clear the sign bit, check for exponent<max.
+ return RebindMask(
+ d,
+ Lt(Add(vu, vu), Set(du, static_cast<TU>(hwy::MaxExponentTimes2<T>()))));
+}
+
+// ================================================== MISC
+
+// ------------------------------ LoadMaskBits (TestBit)
+
+namespace detail {
+
+template <class D, HWY_IF_T_SIZE_D(D, 1)>
+HWY_INLINE MFromD<D> LoadMaskBits(D d, uint64_t bits) {
+ const RebindToUnsigned<decltype(d)> du;
+  // Easier than Set(), which would require a type wider than 8 bits; that
+  // would not compile for T=uint8_t, N=1.
+ const VFromD<D> vbits{__lsx_vreplgr2vr_w(static_cast<int32_t>(bits))};
+
+ // Replicate bytes 8x such that each byte contains the bit that governs it.
+ alignas(16) static constexpr uint8_t kRep8[16] = {0, 0, 0, 0, 0, 0, 0, 0,
+ 1, 1, 1, 1, 1, 1, 1, 1};
+ const auto rep8 = TableLookupBytes(vbits, Load(du, kRep8));
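+  // E.g. bits = 0x8001: lanes 0..7 now hold byte 0x01 and lanes 8..15 hold
+  // byte 0x80, so TestBit against {1,2,...,128} sets exactly lanes 0 and 15.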
+
+ alignas(16) static constexpr uint8_t kBit[16] = {1, 2, 4, 8, 16, 32, 64, 128,
+ 1, 2, 4, 8, 16, 32, 64, 128};
+ return RebindMask(d, TestBit(rep8, LoadDup128(du, kBit)));
+}
+
+template <class D, HWY_IF_T_SIZE_D(D, 2)>
+HWY_INLINE MFromD<D> LoadMaskBits(D d, uint64_t bits) {
+ const RebindToUnsigned<decltype(d)> du;
+ alignas(16) static constexpr uint16_t kBit[8] = {1, 2, 4, 8, 16, 32, 64, 128};
+ return RebindMask(
+ d, TestBit(Set(du, static_cast<uint16_t>(bits)), Load(du, kBit)));
+}
+
+template <class D, HWY_IF_T_SIZE_D(D, 4)>
+HWY_INLINE MFromD<D> LoadMaskBits(D d, uint64_t bits) {
+ const RebindToUnsigned<decltype(d)> du;
+ alignas(16) static constexpr uint32_t kBit[8] = {1, 2, 4, 8};
+ return RebindMask(
+ d, TestBit(Set(du, static_cast<uint32_t>(bits)), Load(du, kBit)));
+}
+
+template <class D, HWY_IF_T_SIZE_D(D, 8)>
+HWY_INLINE MFromD<D> LoadMaskBits(D d, uint64_t bits) {
+ const RebindToUnsigned<decltype(d)> du;
+ alignas(16) static constexpr uint64_t kBit[8] = {1, 2};
+ return RebindMask(d, TestBit(Set(du, bits), Load(du, kBit)));
+}
+
+} // namespace detail
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16)>
+HWY_API MFromD<D> LoadMaskBits(D d, const uint8_t* HWY_RESTRICT bits) {
+ uint64_t mask_bits = 0;
+ CopyBytes<(d.MaxLanes() + 7) / 8>(bits, &mask_bits);
+ return detail::LoadMaskBits(d, mask_bits);
+}
+
+// ------------------------------ Dup128MaskFromMaskBits
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16)>
+HWY_API MFromD<D> Dup128MaskFromMaskBits(D d, unsigned mask_bits) {
+ constexpr size_t kN = MaxLanes(d);
+ if (kN < 8) mask_bits &= (1u << kN) - 1;
+ return detail::LoadMaskBits(d, mask_bits);
+}
+
+template <typename T>
+struct CompressIsPartition {
+ enum { value = (sizeof(T) != 1) };
+};
+
+// ------------------------------ BitsFromMask
+
+namespace detail {
+
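+// Clears mask bits beyond the last lane of partial vectors; the upper lanes
+// are undefined, so vmskltz may set their bits arbitrarily.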
+template <class D>
+constexpr uint64_t OnlyActive(D d, uint64_t mask_bits) {
+ return (d.MaxBytes() >= 16) ? mask_bits
+ : mask_bits & ((1ull << d.MaxLanes()) - 1);
+}
+
+constexpr HWY_INLINE uint64_t U64FromInt(int mask_bits) {
+ return static_cast<uint64_t>(static_cast<unsigned>(mask_bits));
+}
+
+} // namespace detail
+
+template <class D, HWY_IF_T_SIZE_D(D, 1), HWY_IF_V_SIZE_LE_D(D, 16)>
+HWY_API uint64_t BitsFromMask(D d, MFromD<D> mask) {
+ return detail::OnlyActive(
+ d, detail::U64FromInt(__lsx_vpickve2gr_w(__lsx_vmskltz_b(mask.raw), 0)));
+}
+
+template <class D, HWY_IF_T_SIZE_D(D, 2), HWY_IF_V_SIZE_LE_D(D, 16)>
+HWY_API uint64_t BitsFromMask(D d, MFromD<D> mask) {
+ return detail::OnlyActive(
+ d, detail::U64FromInt(__lsx_vpickve2gr_w(__lsx_vmskltz_h(mask.raw), 0)));
+}
+
+template <class D, HWY_IF_T_SIZE_D(D, 4), HWY_IF_V_SIZE_LE_D(D, 16)>
+HWY_API uint64_t BitsFromMask(D d, MFromD<D> mask) {
+ return detail::OnlyActive(
+ d, detail::U64FromInt(__lsx_vpickve2gr_w(
+ __lsx_vmskltz_w(reinterpret_cast<__m128i>(mask.raw)), 0)));
+}
+
+template <class D, HWY_IF_T_SIZE_D(D, 8), HWY_IF_V_SIZE_LE_D(D, 16)>
+HWY_API uint64_t BitsFromMask(D d, MFromD<D> mask) {
+ return detail::OnlyActive(
+ d, detail::U64FromInt(__lsx_vpickve2gr_w(
+ __lsx_vmskltz_d(reinterpret_cast<__m128i>(mask.raw)), 0)));
+}
+
+// ------------------------------ StoreMaskBits
+// `p` points to at least 8 writable bytes.
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16)>
+HWY_API size_t StoreMaskBits(D d, MFromD<D> mask, uint8_t* bits) {
+ constexpr size_t kNumBytes = (MaxLanes(d) + 7) / 8;
+ const uint64_t mask_bits = BitsFromMask(d, mask);
+ CopyBytes<kNumBytes>(&mask_bits, bits);
+ return kNumBytes;
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16)>
+HWY_API bool AllFalse(D d, MFromD<D> mask) {
+ return BitsFromMask(d, mask) == 0;
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16)>
+HWY_API bool AllTrue(D d, MFromD<D> mask) {
+ constexpr size_t kN = MaxLanes(d);
+ constexpr uint64_t kAllBits = (1ull << kN) - 1;
+ return BitsFromMask(d, mask) == kAllBits;
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16)>
+HWY_API size_t CountTrue(D d, MFromD<D> mask) {
+ return PopCount(BitsFromMask(d, mask));
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16)>
+HWY_API size_t FindKnownFirstTrue(D d, MFromD<D> mask) {
+ return Num0BitsBelowLS1Bit_Nonzero64(BitsFromMask(d, mask));
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16)>
+HWY_API intptr_t FindFirstTrue(D d, MFromD<D> mask) {
+ const uint64_t mask_bits = BitsFromMask(d, mask);
+ return mask_bits ? intptr_t(Num0BitsBelowLS1Bit_Nonzero64(mask_bits)) : -1;
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16)>
+HWY_API size_t FindKnownLastTrue(D d, MFromD<D> mask) {
+ return 31 - Num0BitsAboveMS1Bit_Nonzero32(
+ static_cast<uint32_t>(BitsFromMask(d, mask)));
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16)>
+HWY_API intptr_t FindLastTrue(D d, MFromD<D> mask) {
+ const uint32_t mask_bits = static_cast<uint32_t>(BitsFromMask(d, mask));
+ return mask_bits ? intptr_t(31 - Num0BitsAboveMS1Bit_Nonzero32(mask_bits))
+ : -1;
+}
+
+// ------------------------------ Compress, CompressBits
+
+namespace detail {
+
+// Also works for N < 8 because the first 16 4-tuples only reference bytes 0-6.
+template <class D, HWY_IF_T_SIZE_D(D, 2)>
+HWY_INLINE VFromD<D> IndicesFromBits128(D d, uint64_t mask_bits) {
+ HWY_DASSERT(mask_bits < 256);
+ const Rebind<uint8_t, decltype(d)> d8;
+ const Twice<decltype(d8)> d8t;
+ const RebindToUnsigned<decltype(d)> du;
+
+ alignas(16) static constexpr uint8_t table[2048] = {
+ // PrintCompress16x8Tables
+ 0, 2, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, //
+ 2, 0, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, //
+ 4, 0, 2, 6, 8, 10, 12, 14, /**/ 0, 4, 2, 6, 8, 10, 12, 14, //
+ 2, 4, 0, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, //
+ 6, 0, 2, 4, 8, 10, 12, 14, /**/ 0, 6, 2, 4, 8, 10, 12, 14, //
+ 2, 6, 0, 4, 8, 10, 12, 14, /**/ 0, 2, 6, 4, 8, 10, 12, 14, //
+ 4, 6, 0, 2, 8, 10, 12, 14, /**/ 0, 4, 6, 2, 8, 10, 12, 14, //
+ 2, 4, 6, 0, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, //
+ 8, 0, 2, 4, 6, 10, 12, 14, /**/ 0, 8, 2, 4, 6, 10, 12, 14, //
+ 2, 8, 0, 4, 6, 10, 12, 14, /**/ 0, 2, 8, 4, 6, 10, 12, 14, //
+ 4, 8, 0, 2, 6, 10, 12, 14, /**/ 0, 4, 8, 2, 6, 10, 12, 14, //
+ 2, 4, 8, 0, 6, 10, 12, 14, /**/ 0, 2, 4, 8, 6, 10, 12, 14, //
+ 6, 8, 0, 2, 4, 10, 12, 14, /**/ 0, 6, 8, 2, 4, 10, 12, 14, //
+ 2, 6, 8, 0, 4, 10, 12, 14, /**/ 0, 2, 6, 8, 4, 10, 12, 14, //
+ 4, 6, 8, 0, 2, 10, 12, 14, /**/ 0, 4, 6, 8, 2, 10, 12, 14, //
+ 2, 4, 6, 8, 0, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, //
+ 10, 0, 2, 4, 6, 8, 12, 14, /**/ 0, 10, 2, 4, 6, 8, 12, 14, //
+ 2, 10, 0, 4, 6, 8, 12, 14, /**/ 0, 2, 10, 4, 6, 8, 12, 14, //
+ 4, 10, 0, 2, 6, 8, 12, 14, /**/ 0, 4, 10, 2, 6, 8, 12, 14, //
+ 2, 4, 10, 0, 6, 8, 12, 14, /**/ 0, 2, 4, 10, 6, 8, 12, 14, //
+ 6, 10, 0, 2, 4, 8, 12, 14, /**/ 0, 6, 10, 2, 4, 8, 12, 14, //
+ 2, 6, 10, 0, 4, 8, 12, 14, /**/ 0, 2, 6, 10, 4, 8, 12, 14, //
+ 4, 6, 10, 0, 2, 8, 12, 14, /**/ 0, 4, 6, 10, 2, 8, 12, 14, //
+ 2, 4, 6, 10, 0, 8, 12, 14, /**/ 0, 2, 4, 6, 10, 8, 12, 14, //
+ 8, 10, 0, 2, 4, 6, 12, 14, /**/ 0, 8, 10, 2, 4, 6, 12, 14, //
+ 2, 8, 10, 0, 4, 6, 12, 14, /**/ 0, 2, 8, 10, 4, 6, 12, 14, //
+ 4, 8, 10, 0, 2, 6, 12, 14, /**/ 0, 4, 8, 10, 2, 6, 12, 14, //
+ 2, 4, 8, 10, 0, 6, 12, 14, /**/ 0, 2, 4, 8, 10, 6, 12, 14, //
+ 6, 8, 10, 0, 2, 4, 12, 14, /**/ 0, 6, 8, 10, 2, 4, 12, 14, //
+ 2, 6, 8, 10, 0, 4, 12, 14, /**/ 0, 2, 6, 8, 10, 4, 12, 14, //
+ 4, 6, 8, 10, 0, 2, 12, 14, /**/ 0, 4, 6, 8, 10, 2, 12, 14, //
+ 2, 4, 6, 8, 10, 0, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, //
+ 12, 0, 2, 4, 6, 8, 10, 14, /**/ 0, 12, 2, 4, 6, 8, 10, 14, //
+ 2, 12, 0, 4, 6, 8, 10, 14, /**/ 0, 2, 12, 4, 6, 8, 10, 14, //
+ 4, 12, 0, 2, 6, 8, 10, 14, /**/ 0, 4, 12, 2, 6, 8, 10, 14, //
+ 2, 4, 12, 0, 6, 8, 10, 14, /**/ 0, 2, 4, 12, 6, 8, 10, 14, //
+ 6, 12, 0, 2, 4, 8, 10, 14, /**/ 0, 6, 12, 2, 4, 8, 10, 14, //
+ 2, 6, 12, 0, 4, 8, 10, 14, /**/ 0, 2, 6, 12, 4, 8, 10, 14, //
+ 4, 6, 12, 0, 2, 8, 10, 14, /**/ 0, 4, 6, 12, 2, 8, 10, 14, //
+ 2, 4, 6, 12, 0, 8, 10, 14, /**/ 0, 2, 4, 6, 12, 8, 10, 14, //
+ 8, 12, 0, 2, 4, 6, 10, 14, /**/ 0, 8, 12, 2, 4, 6, 10, 14, //
+ 2, 8, 12, 0, 4, 6, 10, 14, /**/ 0, 2, 8, 12, 4, 6, 10, 14, //
+ 4, 8, 12, 0, 2, 6, 10, 14, /**/ 0, 4, 8, 12, 2, 6, 10, 14, //
+ 2, 4, 8, 12, 0, 6, 10, 14, /**/ 0, 2, 4, 8, 12, 6, 10, 14, //
+ 6, 8, 12, 0, 2, 4, 10, 14, /**/ 0, 6, 8, 12, 2, 4, 10, 14, //
+ 2, 6, 8, 12, 0, 4, 10, 14, /**/ 0, 2, 6, 8, 12, 4, 10, 14, //
+ 4, 6, 8, 12, 0, 2, 10, 14, /**/ 0, 4, 6, 8, 12, 2, 10, 14, //
+ 2, 4, 6, 8, 12, 0, 10, 14, /**/ 0, 2, 4, 6, 8, 12, 10, 14, //
+ 10, 12, 0, 2, 4, 6, 8, 14, /**/ 0, 10, 12, 2, 4, 6, 8, 14, //
+ 2, 10, 12, 0, 4, 6, 8, 14, /**/ 0, 2, 10, 12, 4, 6, 8, 14, //
+ 4, 10, 12, 0, 2, 6, 8, 14, /**/ 0, 4, 10, 12, 2, 6, 8, 14, //
+ 2, 4, 10, 12, 0, 6, 8, 14, /**/ 0, 2, 4, 10, 12, 6, 8, 14, //
+ 6, 10, 12, 0, 2, 4, 8, 14, /**/ 0, 6, 10, 12, 2, 4, 8, 14, //
+ 2, 6, 10, 12, 0, 4, 8, 14, /**/ 0, 2, 6, 10, 12, 4, 8, 14, //
+ 4, 6, 10, 12, 0, 2, 8, 14, /**/ 0, 4, 6, 10, 12, 2, 8, 14, //
+ 2, 4, 6, 10, 12, 0, 8, 14, /**/ 0, 2, 4, 6, 10, 12, 8, 14, //
+ 8, 10, 12, 0, 2, 4, 6, 14, /**/ 0, 8, 10, 12, 2, 4, 6, 14, //
+ 2, 8, 10, 12, 0, 4, 6, 14, /**/ 0, 2, 8, 10, 12, 4, 6, 14, //
+ 4, 8, 10, 12, 0, 2, 6, 14, /**/ 0, 4, 8, 10, 12, 2, 6, 14, //
+ 2, 4, 8, 10, 12, 0, 6, 14, /**/ 0, 2, 4, 8, 10, 12, 6, 14, //
+ 6, 8, 10, 12, 0, 2, 4, 14, /**/ 0, 6, 8, 10, 12, 2, 4, 14, //
+ 2, 6, 8, 10, 12, 0, 4, 14, /**/ 0, 2, 6, 8, 10, 12, 4, 14, //
+ 4, 6, 8, 10, 12, 0, 2, 14, /**/ 0, 4, 6, 8, 10, 12, 2, 14, //
+ 2, 4, 6, 8, 10, 12, 0, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14, //
+ 14, 0, 2, 4, 6, 8, 10, 12, /**/ 0, 14, 2, 4, 6, 8, 10, 12, //
+ 2, 14, 0, 4, 6, 8, 10, 12, /**/ 0, 2, 14, 4, 6, 8, 10, 12, //
+ 4, 14, 0, 2, 6, 8, 10, 12, /**/ 0, 4, 14, 2, 6, 8, 10, 12, //
+ 2, 4, 14, 0, 6, 8, 10, 12, /**/ 0, 2, 4, 14, 6, 8, 10, 12, //
+ 6, 14, 0, 2, 4, 8, 10, 12, /**/ 0, 6, 14, 2, 4, 8, 10, 12, //
+ 2, 6, 14, 0, 4, 8, 10, 12, /**/ 0, 2, 6, 14, 4, 8, 10, 12, //
+ 4, 6, 14, 0, 2, 8, 10, 12, /**/ 0, 4, 6, 14, 2, 8, 10, 12, //
+ 2, 4, 6, 14, 0, 8, 10, 12, /**/ 0, 2, 4, 6, 14, 8, 10, 12, //
+ 8, 14, 0, 2, 4, 6, 10, 12, /**/ 0, 8, 14, 2, 4, 6, 10, 12, //
+ 2, 8, 14, 0, 4, 6, 10, 12, /**/ 0, 2, 8, 14, 4, 6, 10, 12, //
+ 4, 8, 14, 0, 2, 6, 10, 12, /**/ 0, 4, 8, 14, 2, 6, 10, 12, //
+ 2, 4, 8, 14, 0, 6, 10, 12, /**/ 0, 2, 4, 8, 14, 6, 10, 12, //
+ 6, 8, 14, 0, 2, 4, 10, 12, /**/ 0, 6, 8, 14, 2, 4, 10, 12, //
+ 2, 6, 8, 14, 0, 4, 10, 12, /**/ 0, 2, 6, 8, 14, 4, 10, 12, //
+ 4, 6, 8, 14, 0, 2, 10, 12, /**/ 0, 4, 6, 8, 14, 2, 10, 12, //
+ 2, 4, 6, 8, 14, 0, 10, 12, /**/ 0, 2, 4, 6, 8, 14, 10, 12, //
+ 10, 14, 0, 2, 4, 6, 8, 12, /**/ 0, 10, 14, 2, 4, 6, 8, 12, //
+ 2, 10, 14, 0, 4, 6, 8, 12, /**/ 0, 2, 10, 14, 4, 6, 8, 12, //
+ 4, 10, 14, 0, 2, 6, 8, 12, /**/ 0, 4, 10, 14, 2, 6, 8, 12, //
+ 2, 4, 10, 14, 0, 6, 8, 12, /**/ 0, 2, 4, 10, 14, 6, 8, 12, //
+ 6, 10, 14, 0, 2, 4, 8, 12, /**/ 0, 6, 10, 14, 2, 4, 8, 12, //
+ 2, 6, 10, 14, 0, 4, 8, 12, /**/ 0, 2, 6, 10, 14, 4, 8, 12, //
+ 4, 6, 10, 14, 0, 2, 8, 12, /**/ 0, 4, 6, 10, 14, 2, 8, 12, //
+ 2, 4, 6, 10, 14, 0, 8, 12, /**/ 0, 2, 4, 6, 10, 14, 8, 12, //
+ 8, 10, 14, 0, 2, 4, 6, 12, /**/ 0, 8, 10, 14, 2, 4, 6, 12, //
+ 2, 8, 10, 14, 0, 4, 6, 12, /**/ 0, 2, 8, 10, 14, 4, 6, 12, //
+ 4, 8, 10, 14, 0, 2, 6, 12, /**/ 0, 4, 8, 10, 14, 2, 6, 12, //
+ 2, 4, 8, 10, 14, 0, 6, 12, /**/ 0, 2, 4, 8, 10, 14, 6, 12, //
+ 6, 8, 10, 14, 0, 2, 4, 12, /**/ 0, 6, 8, 10, 14, 2, 4, 12, //
+ 2, 6, 8, 10, 14, 0, 4, 12, /**/ 0, 2, 6, 8, 10, 14, 4, 12, //
+ 4, 6, 8, 10, 14, 0, 2, 12, /**/ 0, 4, 6, 8, 10, 14, 2, 12, //
+ 2, 4, 6, 8, 10, 14, 0, 12, /**/ 0, 2, 4, 6, 8, 10, 14, 12, //
+ 12, 14, 0, 2, 4, 6, 8, 10, /**/ 0, 12, 14, 2, 4, 6, 8, 10, //
+ 2, 12, 14, 0, 4, 6, 8, 10, /**/ 0, 2, 12, 14, 4, 6, 8, 10, //
+ 4, 12, 14, 0, 2, 6, 8, 10, /**/ 0, 4, 12, 14, 2, 6, 8, 10, //
+ 2, 4, 12, 14, 0, 6, 8, 10, /**/ 0, 2, 4, 12, 14, 6, 8, 10, //
+ 6, 12, 14, 0, 2, 4, 8, 10, /**/ 0, 6, 12, 14, 2, 4, 8, 10, //
+ 2, 6, 12, 14, 0, 4, 8, 10, /**/ 0, 2, 6, 12, 14, 4, 8, 10, //
+ 4, 6, 12, 14, 0, 2, 8, 10, /**/ 0, 4, 6, 12, 14, 2, 8, 10, //
+ 2, 4, 6, 12, 14, 0, 8, 10, /**/ 0, 2, 4, 6, 12, 14, 8, 10, //
+ 8, 12, 14, 0, 2, 4, 6, 10, /**/ 0, 8, 12, 14, 2, 4, 6, 10, //
+ 2, 8, 12, 14, 0, 4, 6, 10, /**/ 0, 2, 8, 12, 14, 4, 6, 10, //
+ 4, 8, 12, 14, 0, 2, 6, 10, /**/ 0, 4, 8, 12, 14, 2, 6, 10, //
+ 2, 4, 8, 12, 14, 0, 6, 10, /**/ 0, 2, 4, 8, 12, 14, 6, 10, //
+ 6, 8, 12, 14, 0, 2, 4, 10, /**/ 0, 6, 8, 12, 14, 2, 4, 10, //
+ 2, 6, 8, 12, 14, 0, 4, 10, /**/ 0, 2, 6, 8, 12, 14, 4, 10, //
+ 4, 6, 8, 12, 14, 0, 2, 10, /**/ 0, 4, 6, 8, 12, 14, 2, 10, //
+ 2, 4, 6, 8, 12, 14, 0, 10, /**/ 0, 2, 4, 6, 8, 12, 14, 10, //
+ 10, 12, 14, 0, 2, 4, 6, 8, /**/ 0, 10, 12, 14, 2, 4, 6, 8, //
+ 2, 10, 12, 14, 0, 4, 6, 8, /**/ 0, 2, 10, 12, 14, 4, 6, 8, //
+ 4, 10, 12, 14, 0, 2, 6, 8, /**/ 0, 4, 10, 12, 14, 2, 6, 8, //
+ 2, 4, 10, 12, 14, 0, 6, 8, /**/ 0, 2, 4, 10, 12, 14, 6, 8, //
+ 6, 10, 12, 14, 0, 2, 4, 8, /**/ 0, 6, 10, 12, 14, 2, 4, 8, //
+ 2, 6, 10, 12, 14, 0, 4, 8, /**/ 0, 2, 6, 10, 12, 14, 4, 8, //
+ 4, 6, 10, 12, 14, 0, 2, 8, /**/ 0, 4, 6, 10, 12, 14, 2, 8, //
+ 2, 4, 6, 10, 12, 14, 0, 8, /**/ 0, 2, 4, 6, 10, 12, 14, 8, //
+ 8, 10, 12, 14, 0, 2, 4, 6, /**/ 0, 8, 10, 12, 14, 2, 4, 6, //
+ 2, 8, 10, 12, 14, 0, 4, 6, /**/ 0, 2, 8, 10, 12, 14, 4, 6, //
+ 4, 8, 10, 12, 14, 0, 2, 6, /**/ 0, 4, 8, 10, 12, 14, 2, 6, //
+ 2, 4, 8, 10, 12, 14, 0, 6, /**/ 0, 2, 4, 8, 10, 12, 14, 6, //
+ 6, 8, 10, 12, 14, 0, 2, 4, /**/ 0, 6, 8, 10, 12, 14, 2, 4, //
+ 2, 6, 8, 10, 12, 14, 0, 4, /**/ 0, 2, 6, 8, 10, 12, 14, 4, //
+ 4, 6, 8, 10, 12, 14, 0, 2, /**/ 0, 4, 6, 8, 10, 12, 14, 2, //
+ 2, 4, 6, 8, 10, 12, 14, 0, /**/ 0, 2, 4, 6, 8, 10, 12, 14};
+
+ const VFromD<decltype(d8t)> byte_idx{Load(d8, table + mask_bits * 8).raw};
+ const VFromD<decltype(du)> pairs = ZipLower(byte_idx, byte_idx);
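+  // Each u16 lane of `pairs` holds (idx, idx); adding 0x0100 turns it into
+  // (idx, idx + 1), the two byte offsets of the u16 lane being gathered.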
+ return BitCast(d, pairs + Set(du, 0x0100));
+}
+
+template <class D, HWY_IF_T_SIZE_D(D, 2)>
+HWY_INLINE VFromD<D> IndicesFromNotBits128(D d, uint64_t mask_bits) {
+ HWY_DASSERT(mask_bits < 256);
+ const Rebind<uint8_t, decltype(d)> d8;
+ const Twice<decltype(d8)> d8t;
+ const RebindToUnsigned<decltype(d)> du;
+
+ alignas(16) static constexpr uint8_t table[2048] = {
+ // PrintCompressNot16x8Tables
+ 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 12, 14, 0, //
+ 0, 4, 6, 8, 10, 12, 14, 2, /**/ 4, 6, 8, 10, 12, 14, 0, 2, //
+ 0, 2, 6, 8, 10, 12, 14, 4, /**/ 2, 6, 8, 10, 12, 14, 0, 4, //
+ 0, 6, 8, 10, 12, 14, 2, 4, /**/ 6, 8, 10, 12, 14, 0, 2, 4, //
+ 0, 2, 4, 8, 10, 12, 14, 6, /**/ 2, 4, 8, 10, 12, 14, 0, 6, //
+ 0, 4, 8, 10, 12, 14, 2, 6, /**/ 4, 8, 10, 12, 14, 0, 2, 6, //
+ 0, 2, 8, 10, 12, 14, 4, 6, /**/ 2, 8, 10, 12, 14, 0, 4, 6, //
+ 0, 8, 10, 12, 14, 2, 4, 6, /**/ 8, 10, 12, 14, 0, 2, 4, 6, //
+ 0, 2, 4, 6, 10, 12, 14, 8, /**/ 2, 4, 6, 10, 12, 14, 0, 8, //
+ 0, 4, 6, 10, 12, 14, 2, 8, /**/ 4, 6, 10, 12, 14, 0, 2, 8, //
+ 0, 2, 6, 10, 12, 14, 4, 8, /**/ 2, 6, 10, 12, 14, 0, 4, 8, //
+ 0, 6, 10, 12, 14, 2, 4, 8, /**/ 6, 10, 12, 14, 0, 2, 4, 8, //
+ 0, 2, 4, 10, 12, 14, 6, 8, /**/ 2, 4, 10, 12, 14, 0, 6, 8, //
+ 0, 4, 10, 12, 14, 2, 6, 8, /**/ 4, 10, 12, 14, 0, 2, 6, 8, //
+ 0, 2, 10, 12, 14, 4, 6, 8, /**/ 2, 10, 12, 14, 0, 4, 6, 8, //
+ 0, 10, 12, 14, 2, 4, 6, 8, /**/ 10, 12, 14, 0, 2, 4, 6, 8, //
+ 0, 2, 4, 6, 8, 12, 14, 10, /**/ 2, 4, 6, 8, 12, 14, 0, 10, //
+ 0, 4, 6, 8, 12, 14, 2, 10, /**/ 4, 6, 8, 12, 14, 0, 2, 10, //
+ 0, 2, 6, 8, 12, 14, 4, 10, /**/ 2, 6, 8, 12, 14, 0, 4, 10, //
+ 0, 6, 8, 12, 14, 2, 4, 10, /**/ 6, 8, 12, 14, 0, 2, 4, 10, //
+ 0, 2, 4, 8, 12, 14, 6, 10, /**/ 2, 4, 8, 12, 14, 0, 6, 10, //
+ 0, 4, 8, 12, 14, 2, 6, 10, /**/ 4, 8, 12, 14, 0, 2, 6, 10, //
+ 0, 2, 8, 12, 14, 4, 6, 10, /**/ 2, 8, 12, 14, 0, 4, 6, 10, //
+ 0, 8, 12, 14, 2, 4, 6, 10, /**/ 8, 12, 14, 0, 2, 4, 6, 10, //
+ 0, 2, 4, 6, 12, 14, 8, 10, /**/ 2, 4, 6, 12, 14, 0, 8, 10, //
+ 0, 4, 6, 12, 14, 2, 8, 10, /**/ 4, 6, 12, 14, 0, 2, 8, 10, //
+ 0, 2, 6, 12, 14, 4, 8, 10, /**/ 2, 6, 12, 14, 0, 4, 8, 10, //
+ 0, 6, 12, 14, 2, 4, 8, 10, /**/ 6, 12, 14, 0, 2, 4, 8, 10, //
+ 0, 2, 4, 12, 14, 6, 8, 10, /**/ 2, 4, 12, 14, 0, 6, 8, 10, //
+ 0, 4, 12, 14, 2, 6, 8, 10, /**/ 4, 12, 14, 0, 2, 6, 8, 10, //
+ 0, 2, 12, 14, 4, 6, 8, 10, /**/ 2, 12, 14, 0, 4, 6, 8, 10, //
+ 0, 12, 14, 2, 4, 6, 8, 10, /**/ 12, 14, 0, 2, 4, 6, 8, 10, //
+ 0, 2, 4, 6, 8, 10, 14, 12, /**/ 2, 4, 6, 8, 10, 14, 0, 12, //
+ 0, 4, 6, 8, 10, 14, 2, 12, /**/ 4, 6, 8, 10, 14, 0, 2, 12, //
+ 0, 2, 6, 8, 10, 14, 4, 12, /**/ 2, 6, 8, 10, 14, 0, 4, 12, //
+ 0, 6, 8, 10, 14, 2, 4, 12, /**/ 6, 8, 10, 14, 0, 2, 4, 12, //
+ 0, 2, 4, 8, 10, 14, 6, 12, /**/ 2, 4, 8, 10, 14, 0, 6, 12, //
+ 0, 4, 8, 10, 14, 2, 6, 12, /**/ 4, 8, 10, 14, 0, 2, 6, 12, //
+ 0, 2, 8, 10, 14, 4, 6, 12, /**/ 2, 8, 10, 14, 0, 4, 6, 12, //
+ 0, 8, 10, 14, 2, 4, 6, 12, /**/ 8, 10, 14, 0, 2, 4, 6, 12, //
+ 0, 2, 4, 6, 10, 14, 8, 12, /**/ 2, 4, 6, 10, 14, 0, 8, 12, //
+ 0, 4, 6, 10, 14, 2, 8, 12, /**/ 4, 6, 10, 14, 0, 2, 8, 12, //
+ 0, 2, 6, 10, 14, 4, 8, 12, /**/ 2, 6, 10, 14, 0, 4, 8, 12, //
+ 0, 6, 10, 14, 2, 4, 8, 12, /**/ 6, 10, 14, 0, 2, 4, 8, 12, //
+ 0, 2, 4, 10, 14, 6, 8, 12, /**/ 2, 4, 10, 14, 0, 6, 8, 12, //
+ 0, 4, 10, 14, 2, 6, 8, 12, /**/ 4, 10, 14, 0, 2, 6, 8, 12, //
+ 0, 2, 10, 14, 4, 6, 8, 12, /**/ 2, 10, 14, 0, 4, 6, 8, 12, //
+ 0, 10, 14, 2, 4, 6, 8, 12, /**/ 10, 14, 0, 2, 4, 6, 8, 12, //
+ 0, 2, 4, 6, 8, 14, 10, 12, /**/ 2, 4, 6, 8, 14, 0, 10, 12, //
+ 0, 4, 6, 8, 14, 2, 10, 12, /**/ 4, 6, 8, 14, 0, 2, 10, 12, //
+ 0, 2, 6, 8, 14, 4, 10, 12, /**/ 2, 6, 8, 14, 0, 4, 10, 12, //
+ 0, 6, 8, 14, 2, 4, 10, 12, /**/ 6, 8, 14, 0, 2, 4, 10, 12, //
+ 0, 2, 4, 8, 14, 6, 10, 12, /**/ 2, 4, 8, 14, 0, 6, 10, 12, //
+ 0, 4, 8, 14, 2, 6, 10, 12, /**/ 4, 8, 14, 0, 2, 6, 10, 12, //
+ 0, 2, 8, 14, 4, 6, 10, 12, /**/ 2, 8, 14, 0, 4, 6, 10, 12, //
+ 0, 8, 14, 2, 4, 6, 10, 12, /**/ 8, 14, 0, 2, 4, 6, 10, 12, //
+ 0, 2, 4, 6, 14, 8, 10, 12, /**/ 2, 4, 6, 14, 0, 8, 10, 12, //
+ 0, 4, 6, 14, 2, 8, 10, 12, /**/ 4, 6, 14, 0, 2, 8, 10, 12, //
+ 0, 2, 6, 14, 4, 8, 10, 12, /**/ 2, 6, 14, 0, 4, 8, 10, 12, //
+ 0, 6, 14, 2, 4, 8, 10, 12, /**/ 6, 14, 0, 2, 4, 8, 10, 12, //
+ 0, 2, 4, 14, 6, 8, 10, 12, /**/ 2, 4, 14, 0, 6, 8, 10, 12, //
+ 0, 4, 14, 2, 6, 8, 10, 12, /**/ 4, 14, 0, 2, 6, 8, 10, 12, //
+ 0, 2, 14, 4, 6, 8, 10, 12, /**/ 2, 14, 0, 4, 6, 8, 10, 12, //
+ 0, 14, 2, 4, 6, 8, 10, 12, /**/ 14, 0, 2, 4, 6, 8, 10, 12, //
+ 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 12, 0, 14, //
+ 0, 4, 6, 8, 10, 12, 2, 14, /**/ 4, 6, 8, 10, 12, 0, 2, 14, //
+ 0, 2, 6, 8, 10, 12, 4, 14, /**/ 2, 6, 8, 10, 12, 0, 4, 14, //
+ 0, 6, 8, 10, 12, 2, 4, 14, /**/ 6, 8, 10, 12, 0, 2, 4, 14, //
+ 0, 2, 4, 8, 10, 12, 6, 14, /**/ 2, 4, 8, 10, 12, 0, 6, 14, //
+ 0, 4, 8, 10, 12, 2, 6, 14, /**/ 4, 8, 10, 12, 0, 2, 6, 14, //
+ 0, 2, 8, 10, 12, 4, 6, 14, /**/ 2, 8, 10, 12, 0, 4, 6, 14, //
+ 0, 8, 10, 12, 2, 4, 6, 14, /**/ 8, 10, 12, 0, 2, 4, 6, 14, //
+ 0, 2, 4, 6, 10, 12, 8, 14, /**/ 2, 4, 6, 10, 12, 0, 8, 14, //
+ 0, 4, 6, 10, 12, 2, 8, 14, /**/ 4, 6, 10, 12, 0, 2, 8, 14, //
+ 0, 2, 6, 10, 12, 4, 8, 14, /**/ 2, 6, 10, 12, 0, 4, 8, 14, //
+ 0, 6, 10, 12, 2, 4, 8, 14, /**/ 6, 10, 12, 0, 2, 4, 8, 14, //
+ 0, 2, 4, 10, 12, 6, 8, 14, /**/ 2, 4, 10, 12, 0, 6, 8, 14, //
+ 0, 4, 10, 12, 2, 6, 8, 14, /**/ 4, 10, 12, 0, 2, 6, 8, 14, //
+ 0, 2, 10, 12, 4, 6, 8, 14, /**/ 2, 10, 12, 0, 4, 6, 8, 14, //
+ 0, 10, 12, 2, 4, 6, 8, 14, /**/ 10, 12, 0, 2, 4, 6, 8, 14, //
+ 0, 2, 4, 6, 8, 12, 10, 14, /**/ 2, 4, 6, 8, 12, 0, 10, 14, //
+ 0, 4, 6, 8, 12, 2, 10, 14, /**/ 4, 6, 8, 12, 0, 2, 10, 14, //
+ 0, 2, 6, 8, 12, 4, 10, 14, /**/ 2, 6, 8, 12, 0, 4, 10, 14, //
+ 0, 6, 8, 12, 2, 4, 10, 14, /**/ 6, 8, 12, 0, 2, 4, 10, 14, //
+ 0, 2, 4, 8, 12, 6, 10, 14, /**/ 2, 4, 8, 12, 0, 6, 10, 14, //
+ 0, 4, 8, 12, 2, 6, 10, 14, /**/ 4, 8, 12, 0, 2, 6, 10, 14, //
+ 0, 2, 8, 12, 4, 6, 10, 14, /**/ 2, 8, 12, 0, 4, 6, 10, 14, //
+ 0, 8, 12, 2, 4, 6, 10, 14, /**/ 8, 12, 0, 2, 4, 6, 10, 14, //
+ 0, 2, 4, 6, 12, 8, 10, 14, /**/ 2, 4, 6, 12, 0, 8, 10, 14, //
+ 0, 4, 6, 12, 2, 8, 10, 14, /**/ 4, 6, 12, 0, 2, 8, 10, 14, //
+ 0, 2, 6, 12, 4, 8, 10, 14, /**/ 2, 6, 12, 0, 4, 8, 10, 14, //
+ 0, 6, 12, 2, 4, 8, 10, 14, /**/ 6, 12, 0, 2, 4, 8, 10, 14, //
+ 0, 2, 4, 12, 6, 8, 10, 14, /**/ 2, 4, 12, 0, 6, 8, 10, 14, //
+ 0, 4, 12, 2, 6, 8, 10, 14, /**/ 4, 12, 0, 2, 6, 8, 10, 14, //
+ 0, 2, 12, 4, 6, 8, 10, 14, /**/ 2, 12, 0, 4, 6, 8, 10, 14, //
+ 0, 12, 2, 4, 6, 8, 10, 14, /**/ 12, 0, 2, 4, 6, 8, 10, 14, //
+ 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 10, 0, 12, 14, //
+ 0, 4, 6, 8, 10, 2, 12, 14, /**/ 4, 6, 8, 10, 0, 2, 12, 14, //
+ 0, 2, 6, 8, 10, 4, 12, 14, /**/ 2, 6, 8, 10, 0, 4, 12, 14, //
+ 0, 6, 8, 10, 2, 4, 12, 14, /**/ 6, 8, 10, 0, 2, 4, 12, 14, //
+ 0, 2, 4, 8, 10, 6, 12, 14, /**/ 2, 4, 8, 10, 0, 6, 12, 14, //
+ 0, 4, 8, 10, 2, 6, 12, 14, /**/ 4, 8, 10, 0, 2, 6, 12, 14, //
+ 0, 2, 8, 10, 4, 6, 12, 14, /**/ 2, 8, 10, 0, 4, 6, 12, 14, //
+ 0, 8, 10, 2, 4, 6, 12, 14, /**/ 8, 10, 0, 2, 4, 6, 12, 14, //
+ 0, 2, 4, 6, 10, 8, 12, 14, /**/ 2, 4, 6, 10, 0, 8, 12, 14, //
+ 0, 4, 6, 10, 2, 8, 12, 14, /**/ 4, 6, 10, 0, 2, 8, 12, 14, //
+ 0, 2, 6, 10, 4, 8, 12, 14, /**/ 2, 6, 10, 0, 4, 8, 12, 14, //
+ 0, 6, 10, 2, 4, 8, 12, 14, /**/ 6, 10, 0, 2, 4, 8, 12, 14, //
+ 0, 2, 4, 10, 6, 8, 12, 14, /**/ 2, 4, 10, 0, 6, 8, 12, 14, //
+ 0, 4, 10, 2, 6, 8, 12, 14, /**/ 4, 10, 0, 2, 6, 8, 12, 14, //
+ 0, 2, 10, 4, 6, 8, 12, 14, /**/ 2, 10, 0, 4, 6, 8, 12, 14, //
+ 0, 10, 2, 4, 6, 8, 12, 14, /**/ 10, 0, 2, 4, 6, 8, 12, 14, //
+ 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 8, 0, 10, 12, 14, //
+ 0, 4, 6, 8, 2, 10, 12, 14, /**/ 4, 6, 8, 0, 2, 10, 12, 14, //
+ 0, 2, 6, 8, 4, 10, 12, 14, /**/ 2, 6, 8, 0, 4, 10, 12, 14, //
+ 0, 6, 8, 2, 4, 10, 12, 14, /**/ 6, 8, 0, 2, 4, 10, 12, 14, //
+ 0, 2, 4, 8, 6, 10, 12, 14, /**/ 2, 4, 8, 0, 6, 10, 12, 14, //
+ 0, 4, 8, 2, 6, 10, 12, 14, /**/ 4, 8, 0, 2, 6, 10, 12, 14, //
+ 0, 2, 8, 4, 6, 10, 12, 14, /**/ 2, 8, 0, 4, 6, 10, 12, 14, //
+ 0, 8, 2, 4, 6, 10, 12, 14, /**/ 8, 0, 2, 4, 6, 10, 12, 14, //
+ 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 6, 0, 8, 10, 12, 14, //
+ 0, 4, 6, 2, 8, 10, 12, 14, /**/ 4, 6, 0, 2, 8, 10, 12, 14, //
+ 0, 2, 6, 4, 8, 10, 12, 14, /**/ 2, 6, 0, 4, 8, 10, 12, 14, //
+ 0, 6, 2, 4, 8, 10, 12, 14, /**/ 6, 0, 2, 4, 8, 10, 12, 14, //
+ 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 4, 0, 6, 8, 10, 12, 14, //
+ 0, 4, 2, 6, 8, 10, 12, 14, /**/ 4, 0, 2, 6, 8, 10, 12, 14, //
+ 0, 2, 4, 6, 8, 10, 12, 14, /**/ 2, 0, 4, 6, 8, 10, 12, 14, //
+ 0, 2, 4, 6, 8, 10, 12, 14, /**/ 0, 2, 4, 6, 8, 10, 12, 14};
+
+ const VFromD<decltype(d8t)> byte_idx{Load(d8, table + mask_bits * 8).raw};
+ const VFromD<decltype(du)> pairs = ZipLower(byte_idx, byte_idx);
+ return BitCast(d, pairs + Set(du, 0x0100));
+}
+
+template <class D, HWY_IF_T_SIZE_D(D, 4)>
+HWY_INLINE VFromD<D> IndicesFromBits128(D d, uint64_t mask_bits) {
+ HWY_DASSERT(mask_bits < 16);
+
+ // There are only 4 lanes, so we can afford to load the index vector directly.
+ alignas(16) static constexpr uint8_t u8_indices[256] = {
+ // PrintCompress32x4Tables
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, //
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, //
+ 4, 5, 6, 7, 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, //
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, //
+ 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, //
+ 0, 1, 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, //
+ 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, //
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, //
+ 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, //
+ 0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7, 8, 9, 10, 11, //
+ 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, 8, 9, 10, 11, //
+ 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 8, 9, 10, 11, //
+ 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, //
+ 0, 1, 2, 3, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, //
+ 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, //
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
+
+ const Repartition<uint8_t, decltype(d)> d8;
+ return BitCast(d, Load(d8, u8_indices + 16 * mask_bits));
+}
+
+template <class D, HWY_IF_T_SIZE_D(D, 4)>
+HWY_INLINE VFromD<D> IndicesFromNotBits128(D d, uint64_t mask_bits) {
+ HWY_DASSERT(mask_bits < 16);
+
+ // There are only 4 lanes, so we can afford to load the index vector directly.
+ alignas(16) static constexpr uint8_t u8_indices[256] = {
+ // PrintCompressNot32x4Tables
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5,
+ 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 0, 1, 2, 3,
+ 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,
+ 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7,
+ 12, 13, 14, 15, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, 0, 1,
+ 2, 3, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, 4, 5, 6, 7,
+ 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
+ 10, 11, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+ 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 12, 13, 14, 15, 0, 1,
+ 2, 3, 8, 9, 10, 11, 4, 5, 6, 7, 12, 13, 14, 15, 8, 9, 10, 11,
+ 0, 1, 2, 3, 4, 5, 6, 7, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5,
+ 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 4, 5, 6, 7, 0, 1, 2, 3,
+ 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9,
+ 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11,
+ 12, 13, 14, 15};
+
+ const Repartition<uint8_t, decltype(d)> d8;
+ return BitCast(d, Load(d8, u8_indices + 16 * mask_bits));
+}
+
+template <class D, HWY_IF_T_SIZE_D(D, 8)>
+HWY_INLINE VFromD<D> IndicesFromBits128(D d, uint64_t mask_bits) {
+ HWY_DASSERT(mask_bits < 4);
+
+ // There are only 2 lanes, so we can afford to load the index vector directly.
+ alignas(16) static constexpr uint8_t u8_indices[64] = {
+ // PrintCompress64x2Tables
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+ 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7,
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
+
+ const Repartition<uint8_t, decltype(d)> d8;
+ return BitCast(d, Load(d8, u8_indices + 16 * mask_bits));
+}
+
+template <class D, HWY_IF_T_SIZE_D(D, 8)>
+HWY_INLINE VFromD<D> IndicesFromNotBits128(D d, uint64_t mask_bits) {
+ HWY_DASSERT(mask_bits < 4);
+
+ // There are only 2 lanes, so we can afford to load the index vector directly.
+ alignas(16) static constexpr uint8_t u8_indices[64] = {
+ // PrintCompressNot64x2Tables
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+ 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7,
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15,
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15};
+
+ const Repartition<uint8_t, decltype(d)> d8;
+ return BitCast(d, Load(d8, u8_indices + 16 * mask_bits));
+}
+
+template <typename T, size_t N, HWY_IF_NOT_T_SIZE(T, 1)>
+HWY_API Vec128<T, N> CompressBits(Vec128<T, N> v, uint64_t mask_bits) {
+ const DFromV<decltype(v)> d;
+ const RebindToUnsigned<decltype(d)> du;
+
+ HWY_DASSERT(mask_bits < (1ull << N));
+ const auto indices = BitCast(du, detail::IndicesFromBits128(d, mask_bits));
+ return BitCast(d, TableLookupBytes(BitCast(du, v), indices));
+}
+
+template <typename T, size_t N, HWY_IF_NOT_T_SIZE(T, 1)>
+HWY_API Vec128<T, N> CompressNotBits(Vec128<T, N> v, uint64_t mask_bits) {
+ const DFromV<decltype(v)> d;
+ const RebindToUnsigned<decltype(d)> du;
+
+ HWY_DASSERT(mask_bits < (1ull << N));
+ const auto indices = BitCast(du, detail::IndicesFromNotBits128(d, mask_bits));
+ return BitCast(d, TableLookupBytes(BitCast(du, v), indices));
+}
+
+} // namespace detail
+
+// Single lane: no-op
+template <typename T>
+HWY_API Vec128<T, 1> Compress(Vec128<T, 1> v, Mask128<T, 1> /*m*/) {
+ return v;
+}
+
+// Two lanes: conditional swap
+template <typename T, HWY_IF_T_SIZE(T, 8)>
+HWY_API Vec128<T> Compress(Vec128<T> v, Mask128<T> mask) {
+ // If mask[1] = 1 and mask[0] = 0, then swap both halves, else keep.
+ const DFromV<decltype(v)> d;
+ const Vec128<T> m = VecFromMask(d, mask);
+ const Vec128<T> maskL = DupEven(m);
+ const Vec128<T> maskH = DupOdd(m);
+ const Vec128<T> swap = AndNot(maskL, maskH);
+ return IfVecThenElse(swap, Shuffle01(v), v);
+}
+
+// General case, 2 or 4 bytes
+template <typename T, size_t N, HWY_IF_T_SIZE_ONE_OF(T, (1 << 2) | (1 << 4))>
+HWY_API Vec128<T, N> Compress(Vec128<T, N> v, Mask128<T, N> mask) {
+ const DFromV<decltype(v)> d;
+ return detail::CompressBits(v, BitsFromMask(d, mask));
+}
+
+// ------------------------------ CompressNot
+
+// Single lane: no-op
+template <typename T>
+HWY_API Vec128<T, 1> CompressNot(Vec128<T, 1> v, Mask128<T, 1> /*m*/) {
+ return v;
+}
+
+// Two lanes: conditional swap
+template <typename T, HWY_IF_T_SIZE(T, 8)>
+HWY_API Vec128<T> CompressNot(Vec128<T> v, Mask128<T> mask) {
+ // If mask[1] = 0 and mask[0] = 1, then swap both halves, else keep.
+ const DFromV<decltype(v)> d;
+ const Vec128<T> m = VecFromMask(d, mask);
+ const Vec128<T> maskL = DupEven(m);
+ const Vec128<T> maskH = DupOdd(m);
+ const Vec128<T> swap = AndNot(maskH, maskL);
+ return IfVecThenElse(swap, Shuffle01(v), v);
+}
+
+template <typename T, size_t N, HWY_IF_T_SIZE_ONE_OF(T, (1 << 2) | (1 << 4))>
+HWY_API Vec128<T, N> CompressNot(Vec128<T, N> v, Mask128<T, N> mask) {
+ const DFromV<decltype(v)> d;
+ // For partial vectors, we cannot pull the Not() into the table because
+ // BitsFromMask clears the upper bits.
+ if (N < 16 / sizeof(T)) {
+ return detail::CompressBits(v, BitsFromMask(d, Not(mask)));
+ }
+ return detail::CompressNotBits(v, BitsFromMask(d, mask));
+}
+
+// ------------------------------ CompressBlocksNot
+HWY_API Vec128<uint64_t> CompressBlocksNot(Vec128<uint64_t> v,
+ Mask128<uint64_t> /* m */) {
+ return v;
+}
+
+template <typename T, size_t N, HWY_IF_NOT_T_SIZE(T, 1)>
+HWY_API Vec128<T, N> CompressBits(Vec128<T, N> v,
+ const uint8_t* HWY_RESTRICT bits) {
+ uint64_t mask_bits = 0;
+ constexpr size_t kNumBytes = (N + 7) / 8;
+ CopyBytes<kNumBytes>(bits, &mask_bits);
+ if (N < 8) {
+ mask_bits &= (1ull << N) - 1;
+ }
+
+ return detail::CompressBits(v, mask_bits);
+}
+
+// ------------------------------ CompressStore, CompressBitsStore
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_NOT_T_SIZE_D(D, 1)>
+HWY_API size_t CompressStore(VFromD<D> v, MFromD<D> m, D d,
+ TFromD<D>* HWY_RESTRICT unaligned) {
+ const RebindToUnsigned<decltype(d)> du;
+
+ const uint64_t mask_bits = BitsFromMask(d, m);
+ HWY_DASSERT(mask_bits < (1ull << MaxLanes(d)));
+ const size_t count = PopCount(mask_bits);
+
+ const auto indices = BitCast(du, detail::IndicesFromBits128(d, mask_bits));
+ const auto compressed = BitCast(d, TableLookupBytes(BitCast(du, v), indices));
+ StoreU(compressed, d, unaligned);
+ detail::MaybeUnpoison(unaligned, count);
+ return count;
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_NOT_T_SIZE_D(D, 1)>
+HWY_API size_t CompressBlendedStore(VFromD<D> v, MFromD<D> m, D d,
+ TFromD<D>* HWY_RESTRICT unaligned) {
+ const RebindToUnsigned<decltype(d)> du;
+
+ const uint64_t mask_bits = BitsFromMask(d, m);
+ HWY_DASSERT(mask_bits < (1ull << MaxLanes(d)));
+ const size_t count = PopCount(mask_bits);
+
+ const auto indices = BitCast(du, detail::IndicesFromBits128(d, mask_bits));
+ const auto compressed = BitCast(d, TableLookupBytes(BitCast(du, v), indices));
+ BlendedStore(compressed, FirstN(d, count), d, unaligned);
+ detail::MaybeUnpoison(unaligned, count);
+ return count;
+}
+
+template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_NOT_T_SIZE_D(D, 1)>
+HWY_API size_t CompressBitsStore(VFromD<D> v, const uint8_t* HWY_RESTRICT bits,
+ D d, TFromD<D>* HWY_RESTRICT unaligned) {
+ const RebindToUnsigned<decltype(d)> du;
+
+ uint64_t mask_bits = 0;
+ constexpr size_t kN = MaxLanes(d);
+ constexpr size_t kNumBytes = (kN + 7) / 8;
+ CopyBytes<kNumBytes>(bits, &mask_bits);
+ if (kN < 8) {
+ mask_bits &= (1ull << kN) - 1;
+ }
+ const size_t count = PopCount(mask_bits);
+
+ const auto indices = BitCast(du, detail::IndicesFromBits128(d, mask_bits));
+ const auto compressed = BitCast(d, TableLookupBytes(BitCast(du, v), indices));
+ StoreU(compressed, d, unaligned);
+
+ detail::MaybeUnpoison(unaligned, count);
+ return count;
+}
+
+// ------------------------------ StoreInterleaved2/3/4
+
+// HWY_NATIVE_LOAD_STORE_INTERLEAVED not set, hence defined in
+// generic_ops-inl.h.
+
+// ------------------------------ Additional mask logical operations
+
+template <class T>
+HWY_API Mask128<T, 1> SetAtOrAfterFirst(Mask128<T, 1> mask) {
+ return mask;
+}
+template <class T>
+HWY_API Mask128<T, 2> SetAtOrAfterFirst(Mask128<T, 2> mask) {
+ const FixedTag<T, 2> d;
+ const auto vmask = VecFromMask(d, mask);
+ return MaskFromVec(Or(vmask, InterleaveLower(vmask, vmask)));
+}
+template <class T, size_t N, HWY_IF_LANES_GT(N, 2), HWY_IF_V_SIZE_LE(T, N, 8)>
+HWY_API Mask128<T, N> SetAtOrAfterFirst(Mask128<T, N> mask) {
+ const Simd<T, N, 0> d;
+ const auto vmask = VecFromMask(d, mask);
+ const auto neg_vmask =
+ ResizeBitCast(d, Neg(ResizeBitCast(Full64<int64_t>(), vmask)));
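+  // x | -x sets every bit at or above the lowest set bit, so all lanes at or
+  // after the first true lane become all-ones.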
+ return MaskFromVec(Or(vmask, neg_vmask));
+}
+template <class T, HWY_IF_NOT_T_SIZE(T, 8)>
+HWY_API Mask128<T> SetAtOrAfterFirst(Mask128<T> mask) {
+ const Full128<T> d;
+ const Repartition<int64_t, decltype(d)> di64;
+
+ auto vmask = BitCast(di64, VecFromMask(d, mask));
+ VFromD<decltype(di64)> neg_vmask{__lsx_vsub_q(Zero(di64).raw, vmask.raw)};
+
+ return MaskFromVec(BitCast(d, Or(vmask, neg_vmask)));
+}
+
+template <class T, size_t N>
+HWY_API Mask128<T, N> SetBeforeFirst(Mask128<T, N> mask) {
+ return Not(SetAtOrAfterFirst(mask));
+}
+
+template <class T>
+HWY_API Mask128<T, 1> SetOnlyFirst(Mask128<T, 1> mask) {
+ return mask;
+}
+template <class T>
+HWY_API Mask128<T, 2> SetOnlyFirst(Mask128<T, 2> mask) {
+ const FixedTag<T, 2> d;
+ const RebindToSigned<decltype(d)> di;
+
+ const auto vmask = BitCast(di, VecFromMask(d, mask));
+ const auto zero = Zero(di);
+ const auto vmask2 = VecFromMask(di, InterleaveLower(zero, vmask) == zero);
+ return MaskFromVec(BitCast(d, And(vmask, vmask2)));
+}
+template <class T, size_t N, HWY_IF_LANES_GT(N, 2), HWY_IF_V_SIZE_LE(T, N, 8)>
+HWY_API Mask128<T, N> SetOnlyFirst(Mask128<T, N> mask) {
+ const Simd<T, N, 0> d;
+ const RebindToSigned<decltype(d)> di;
+
+ const auto vmask = ResizeBitCast(Full64<int64_t>(), VecFromMask(d, mask));
+ const auto only_first_vmask =
+ BitCast(d, Neg(ResizeBitCast(di, And(vmask, Neg(vmask)))));
+ return MaskFromVec(only_first_vmask);
+}
+template <class T, HWY_IF_NOT_T_SIZE(T, 8)>
+HWY_API Mask128<T> SetOnlyFirst(Mask128<T> mask) {
+ const Full128<T> d;
+ const RebindToSigned<decltype(d)> di;
+
+ auto vmask = BitCast(di, VecFromMask(d, mask));
+ VFromD<decltype(di)> neg_vmask{__lsx_vsub_q(Zero(di).raw, vmask.raw)};
+
+ return MaskFromVec(BitCast(d, Neg(And(vmask, neg_vmask))));
+}
+
+template <class T>
+HWY_API Mask128<T, 1> SetAtOrBeforeFirst(Mask128<T, 1> /*mask*/) {
+ const FixedTag<T, 1> d;
+ const RebindToSigned<decltype(d)> di;
+ using TI = MakeSigned<T>;
+
+ return RebindMask(d, MaskFromVec(Set(di, TI(-1))));
+}
+template <class T, size_t N, HWY_IF_LANES_GT(N, 1)>
+HWY_API Mask128<T, N> SetAtOrBeforeFirst(Mask128<T, N> mask) {
+ const Simd<T, N, 0> d;
+ return SetBeforeFirst(MaskFromVec(ShiftLeftLanes<1>(VecFromMask(d, mask))));
+}
+
+// ------------------------------ Reductions
+#undef HWY_IF_SUM_OF_LANES_D
+#define HWY_IF_SUM_OF_LANES_D(D) \
+ HWY_IF_LANES_GT_D(D, 1), \
+ hwy::EnableIf<!hwy::IsSame<TFromD<D>, uint8_t>() || \
+ (HWY_V_SIZE_D(D) != 8 && HWY_V_SIZE_D(D) != 16)>* = \
+ nullptr
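+// (u8 vectors with 8 or 16 lanes are excluded: they use the SumsOf8-based
+// overloads below)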
+// ------------------------------ SumOfLanes
+
+template <class D, HWY_IF_U8_D(D), HWY_IF_LANES_D(D, 8)>
+HWY_API VFromD<D> SumOfLanes(D d, VFromD<D> v) {
+ return Set(d, static_cast<uint8_t>(GetLane(SumsOf8(v)) & 0xFF));
+}
+template <class D, HWY_IF_U8_D(D), HWY_IF_LANES_D(D, 16)>
+HWY_API VFromD<D> SumOfLanes(D d, VFromD<D> v) {
+ const Repartition<uint64_t, decltype(d)> d64;
+ VFromD<decltype(d64)> sums = SumsOf8(v);
+ sums = SumOfLanes(d64, sums);
+ return Broadcast<0>(BitCast(d, sums));
+}
+
+// ------------------------------ Lt128
+
+namespace detail {
+
+// Returns vector-mask for Lt128. Generic for all vector lengths.
+template <class D, HWY_IF_U64_D(D)>
+HWY_INLINE VFromD<D> Lt128Vec(const D d, VFromD<D> a, VFromD<D> b) {
+ // Truth table of Eq and Lt for Hi and Lo u64.
+ // (removed lines with (=H && cH) or (=L && cL) - cannot both be true)
+ // =H =L cH cL | out = cH | (=H & cL)
+ // 0 0 0 0 | 0
+ // 0 0 0 1 | 0
+ // 0 0 1 0 | 1
+ // 0 0 1 1 | 1
+ // 0 1 0 0 | 0
+ // 0 1 0 1 | 0
+ // 0 1 1 0 | 1
+ // 1 0 0 0 | 0
+ // 1 0 0 1 | 1
+ // 1 1 0 0 | 0
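+  // E.g. a = {lo=5, hi=7}, b = {lo=9, hi=7}: the upper halves tie, so the
+  // result is decided by lo(a) < lo(b), making both lanes of the block ~0.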
+ const auto eqHL = Eq(a, b);
+ const VFromD<D> ltHL = VecFromMask(d, Lt(a, b));
+ const VFromD<D> ltLX = ShiftLeftLanes<1>(ltHL);
+ const VFromD<D> vecHx = IfThenElse(eqHL, ltLX, ltHL);
+ return InterleaveUpper(d, vecHx, vecHx);
+}
+
+// Returns vector-mask for Eq128. Generic for all vector lengths.
+template <class D, HWY_IF_U64_D(D)>
+HWY_INLINE VFromD<D> Eq128Vec(D d, VFromD<D> a, VFromD<D> b) {
+ const auto eqHL = VecFromMask(d, Eq(a, b));
+ const auto eqLH = Reverse2(d, eqHL);
+ return And(eqHL, eqLH);
+}
+
+template <class D, HWY_IF_U64_D(D)>
+HWY_INLINE VFromD<D> Ne128Vec(D d, VFromD<D> a, VFromD<D> b) {
+ const auto neHL = VecFromMask(d, Ne(a, b));
+ const auto neLH = Reverse2(d, neHL);
+ return Or(neHL, neLH);
+}
+
+template <class D, HWY_IF_U64_D(D)>
+HWY_INLINE VFromD<D> Lt128UpperVec(D d, VFromD<D> a, VFromD<D> b) {
+ const VFromD<D> ltHL = VecFromMask(d, Lt(a, b));
+ return InterleaveUpper(d, ltHL, ltHL);
+}
+
+template <class D, HWY_IF_U64_D(D)>
+HWY_INLINE VFromD<D> Eq128UpperVec(D d, VFromD<D> a, VFromD<D> b) {
+ const VFromD<D> eqHL = VecFromMask(d, Eq(a, b));
+ return InterleaveUpper(d, eqHL, eqHL);
+}
+
+template <class D, HWY_IF_U64_D(D)>
+HWY_INLINE VFromD<D> Ne128UpperVec(D d, VFromD<D> a, VFromD<D> b) {
+ const VFromD<D> neHL = VecFromMask(d, Ne(a, b));
+ return InterleaveUpper(d, neHL, neHL);
+}
+
+} // namespace detail
+
+template <class D, HWY_IF_U64_D(D)>
+HWY_API MFromD<D> Lt128(D d, VFromD<D> a, VFromD<D> b) {
+ return MaskFromVec(detail::Lt128Vec(d, a, b));
+}
+
+template <class D, HWY_IF_U64_D(D)>
+HWY_API MFromD<D> Eq128(D d, VFromD<D> a, VFromD<D> b) {
+ return MaskFromVec(detail::Eq128Vec(d, a, b));
+}
+
+template <class D, HWY_IF_U64_D(D)>
+HWY_API MFromD<D> Ne128(D d, VFromD<D> a, VFromD<D> b) {
+ return MaskFromVec(detail::Ne128Vec(d, a, b));
+}
+
+template <class D, HWY_IF_U64_D(D)>
+HWY_API MFromD<D> Lt128Upper(D d, VFromD<D> a, VFromD<D> b) {
+ return MaskFromVec(detail::Lt128UpperVec(d, a, b));
+}
+
+template <class D, HWY_IF_U64_D(D)>
+HWY_API MFromD<D> Eq128Upper(D d, VFromD<D> a, VFromD<D> b) {
+ return MaskFromVec(detail::Eq128UpperVec(d, a, b));
+}
+
+template <class D, HWY_IF_U64_D(D)>
+HWY_API MFromD<D> Ne128Upper(D d, VFromD<D> a, VFromD<D> b) {
+ return MaskFromVec(detail::Ne128UpperVec(d, a, b));
+}
+
+// ------------------------------ Min128, Max128 (Lt128)
+
+// Avoids the extra MaskFromVec in Lt128.
+template <class D, HWY_IF_U64_D(D)>
+HWY_API VFromD<D> Min128(D d, VFromD<D> a, VFromD<D> b) {
+ return IfVecThenElse(detail::Lt128Vec(d, a, b), a, b);
+}
+
+template <class D, HWY_IF_U64_D(D)>
+HWY_API VFromD<D> Max128(D d, VFromD<D> a, VFromD<D> b) {
+ return IfVecThenElse(detail::Lt128Vec(d, b, a), a, b);
+}
+
+template <class D, HWY_IF_U64_D(D)>
+HWY_API VFromD<D> Min128Upper(D d, VFromD<D> a, VFromD<D> b) {
+ return IfVecThenElse(detail::Lt128UpperVec(d, a, b), a, b);
+}
+
+template <class D, HWY_IF_U64_D(D)>
+HWY_API VFromD<D> Max128Upper(D d, VFromD<D> a, VFromD<D> b) {
+ return IfVecThenElse(detail::Lt128UpperVec(d, b, a), a, b);
+}
+
+// ------------------------------ LeadingZeroCount, TrailingZeroCount,
+// HighestSetBitIndex
+
+#ifdef HWY_NATIVE_LEADING_ZERO_COUNT
+#undef HWY_NATIVE_LEADING_ZERO_COUNT
+#else
+#define HWY_NATIVE_LEADING_ZERO_COUNT
+#endif
+
+template <class V, HWY_IF_UI8_D(DFromV<V>), HWY_IF_V_SIZE_LE_D(DFromV<V>, 16)>
+HWY_API V LeadingZeroCount(V v) {
+ return V{__lsx_vclz_b(v.raw)};
+}
+
+template <class V, HWY_IF_UI16_D(DFromV<V>), HWY_IF_V_SIZE_LE_D(DFromV<V>, 16)>
+HWY_API V LeadingZeroCount(V v) {
+ return V{__lsx_vclz_h(v.raw)};
+}
+
+template <class V, HWY_IF_UI32_D(DFromV<V>), HWY_IF_V_SIZE_LE_D(DFromV<V>, 16)>
+HWY_API V LeadingZeroCount(V v) {
+ return V{__lsx_vclz_w(v.raw)};
+}
+
+template <class V, HWY_IF_UI64_D(DFromV<V>), HWY_IF_V_SIZE_LE_D(DFromV<V>, 16)>
+HWY_API V LeadingZeroCount(V v) {
+ return V{__lsx_vclz_d(v.raw)};
+}
+
+template <class V, HWY_IF_V_SIZE_LE_V(V, 16), HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V)>
+HWY_API V HighestSetBitIndex(V v) {
+ const DFromV<decltype(v)> d;
+ using T = TFromD<decltype(d)>;
+ return BitCast(d, Set(d, T{sizeof(T) * 8 - 1}) - LeadingZeroCount(v));
+}
+
+template <class V, HWY_IF_NOT_FLOAT_NOR_SPECIAL_V(V)>
+HWY_API V TrailingZeroCount(V v) {
+ const DFromV<decltype(v)> d;
+ const RebindToSigned<decltype(d)> di;
+ using T = TFromD<decltype(d)>;
+
+ const auto lsb = And(v, BitCast(d, Neg(BitCast(di, v))));
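+  // v & -v isolates the lowest set bit; its HighestSetBitIndex equals the
+  // trailing-zero count. All-zero lanes are special-cased to the bit width.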
+ return IfThenElse(Eq(v, Zero(d)), Set(d, T{sizeof(T) * 8}),
+ HighestSetBitIndex(lsb));
+}
+
+} // namespace HWY_NAMESPACE
+} // namespace hwy
+
+HWY_AFTER_NAMESPACE();
+
+#undef HWY_LSX_IF_EMULATED_D
diff --git a/third_party/highway/hwy/ops/ppc_vsx-inl.h b/third_party/highway/hwy/ops/ppc_vsx-inl.h
index 02de0175a0..9953dabf21 100644
--- a/third_party/highway/hwy/ops/ppc_vsx-inl.h
+++ b/third_party/highway/hwy/ops/ppc_vsx-inl.h
@@ -428,11 +428,17 @@ static HWY_INLINE bool IsConstantRawAltivecVect(RawV v) {
// ------------------------------ TernaryLogic
#if HWY_PPC_HAVE_10
+
+#ifdef HWY_NATIVE_TERNLOG
+#undef HWY_NATIVE_TERNLOG
+#else
+#define HWY_NATIVE_TERNLOG
+#endif
+
namespace detail {
// NOTE: the kTernLogOp bits of the PPC10 TernaryLogic operation are in reverse
-// order of the kTernLogOp bits of AVX3
-// _mm_ternarylogic_epi64(a, b, c, kTernLogOp)
+// order of the kTernLogOp bits of AVX3's _mm_ternarylogic_epi64
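+// (e.g. AVX3's 0xD2 for x ^ (~a1 & a2) maps to 0x4B here)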
template <uint8_t kTernLogOp, class V>
HWY_INLINE V TernaryLogic(V a, V b, V c) {
const DFromV<decltype(a)> d;
@@ -459,12 +465,17 @@ HWY_INLINE V TernaryLogic(V a, V b, V c) {
}
} // namespace detail
-#endif // HWY_PPC_HAVE_10
// ------------------------------ Xor3
+
+#ifdef HWY_NATIVE_XOR3
+#undef HWY_NATIVE_XOR3
+#else
+#define HWY_NATIVE_XOR3
+#endif
+
template <typename T, size_t N>
HWY_API Vec128<T, N> Xor3(Vec128<T, N> x1, Vec128<T, N> x2, Vec128<T, N> x3) {
-#if HWY_PPC_HAVE_10
#if defined(__OPTIMIZE__)
if (static_cast<int>(detail::IsConstantRawAltivecVect(x1.raw)) +
static_cast<int>(detail::IsConstantRawAltivecVect(x2.raw)) +
@@ -476,11 +487,55 @@ HWY_API Vec128<T, N> Xor3(Vec128<T, N> x1, Vec128<T, N> x2, Vec128<T, N> x3) {
{
return detail::TernaryLogic<0x69>(x1, x2, x3);
}
+}
+
+// ------------------------------ XorAndNot
+
+#ifdef HWY_NATIVE_BCAX
+#undef HWY_NATIVE_BCAX
#else
- return Xor(x1, Xor(x2, x3));
+#define HWY_NATIVE_BCAX
+#endif
+
+template <typename T, size_t N>
+HWY_API Vec128<T, N> XorAndNot(Vec128<T, N> x, Vec128<T, N> a1,
+ Vec128<T, N> a2) {
+#if defined(__OPTIMIZE__)
+ if (static_cast<int>(detail::IsConstantRawAltivecVect(x.raw)) +
+ static_cast<int>(detail::IsConstantRawAltivecVect(a1.raw)) +
+ static_cast<int>(detail::IsConstantRawAltivecVect(a2.raw)) >=
+ 2) {
+ return Xor(x, AndNot(a1, a2));
+ } else // NOLINT
#endif
+ {
+ return detail::TernaryLogic<0x4B>(x, a1, a2);
+ }
}
+// ------------------------------ AndXor
+
+// (the HWY_NATIVE_TERNLOG toggle above already covers this op)
+
+template <typename T, size_t N>
+HWY_API Vec128<T, N> AndXor(Vec128<T, N> a, Vec128<T, N> x1,
+ Vec128<T, N> x2) {
+#if defined(__OPTIMIZE__)
+ if (static_cast<int>(detail::IsConstantRawAltivecVect(a.raw)) +
+ static_cast<int>(detail::IsConstantRawAltivecVect(x1.raw)) +
+ static_cast<int>(detail::IsConstantRawAltivecVect(x2.raw)) >=
+ 2) {
+ return And(a, Xor(x1, x2));
+ } else // NOLINT
+#endif
+ {
+    return detail::TernaryLogic<0x06>(a, x1, x2);
+ }
+}
+
+#endif // HWY_PPC_HAVE_10
+
// ------------------------------ Or3
template <typename T, size_t N>
HWY_API Vec128<T, N> Or3(Vec128<T, N> o1, Vec128<T, N> o2, Vec128<T, N> o3) {
@@ -560,6 +615,19 @@ HWY_API Vec128<T, N> operator^(Vec128<T, N> a, Vec128<T, N> b) {
return Xor(a, b);
}
+// ------------------------------ PopulationCount
+
+#ifdef HWY_NATIVE_POPCNT
+#undef HWY_NATIVE_POPCNT
+#else
+#define HWY_NATIVE_POPCNT
+#endif
+
+template <typename T, size_t N, HWY_IF_UNSIGNED(T)>
+HWY_API Vec128<T, N> PopulationCount(Vec128<T, N> v) {
+ return Vec128<T, N>{vec_popcnt(v.raw)};
+}
+
// ================================================== SIGN
// ------------------------------ Neg
@@ -3372,6 +3440,17 @@ HWY_API V InterleaveOddBlocks(D, V a, V /*b*/) {
return a;
}
+// ------------------------------ InterleaveLowerBlocks
+template <class D, class V = VFromD<D>>
+HWY_API V InterleaveLowerBlocks(D, V a, V /*b*/) {
+ return a;
+}
+// ------------------------------ InterleaveUpperBlocks
+template <class D, class V = VFromD<D>>
+HWY_API V InterleaveUpperBlocks(D, V a, V /*b*/) {
+ return a;
+}
+
// ------------------------------ MulFixedPoint15 (OddEven)
#if HWY_S390X_HAVE_Z14
@@ -3552,23 +3631,11 @@ HWY_API VFromD<D32> ReorderWidenMulAccumulate(D32 /*d32*/, V16 a, V16 b,
}
// ------------------------------ RearrangeToOddPlusEven
-template <size_t N>
-HWY_API Vec128<int32_t, N> RearrangeToOddPlusEven(Vec128<int32_t, N> sum0,
- Vec128<int32_t, N> /*sum1*/) {
+template <class VW, HWY_IF_NOT_FLOAT_V(VW)>
+HWY_API VW RearrangeToOddPlusEven(const VW sum0, const VW) {
return sum0; // invariant already holds
}
-template <size_t N>
-HWY_API Vec128<uint32_t, N> RearrangeToOddPlusEven(
- Vec128<uint32_t, N> sum0, Vec128<uint32_t, N> /*sum1*/) {
- return sum0; // invariant already holds
-}
-
-template <class VW>
-HWY_API VW RearrangeToOddPlusEven(const VW sum0, const VW sum1) {
- return Add(sum0, sum1);
-}
-
// ------------------------------ SatWidenMulPairwiseAccumulate
#if !HWY_S390X_HAVE_Z14
@@ -5369,9 +5436,21 @@ HWY_INLINE uint64_t ExtractSignBits(Vec128<uint8_t, N> sign_bits,
// vec_vbpermq: unsigned or signed, so cast to avoid a warning.
using VU64 = detail::Raw128<uint64_t>::type;
#if HWY_S390X_HAVE_Z14
+
+#if HWY_COMPILER_GCC_ACTUAL >= 1500 || HWY_COMPILER_CLANG >= 2100
+  // GCC 15 and Clang 21 have added the vec_bperm intrinsic
+
+ // Need to use vec_bperm instead of vec_bperm_u128 with GCC 15 and later to
+ // avoid compiler warning
+ using VU128 = __vector unsigned __int128;
+ const Vec128<uint64_t> extracted{reinterpret_cast<VU64>(
+ vec_bperm(reinterpret_cast<VU128>(sign_bits.raw), bit_shuffle))};
+#else // !(HWY_COMPILER_GCC_ACTUAL >= 1500 || HWY_COMPILER_CLANG >= 2100)
const Vec128<uint64_t> extracted{
reinterpret_cast<VU64>(vec_bperm_u128(sign_bits.raw, bit_shuffle))};
-#else
+#endif // HWY_COMPILER_GCC_ACTUAL >= 1500 || HWY_COMPILER_CLANG >= 2100
+
+#else // !HWY_S390X_HAVE_Z14
const Vec128<uint64_t> extracted{
reinterpret_cast<VU64>(vec_vbpermq(sign_bits.raw, bit_shuffle))};
#endif
@@ -6380,8 +6459,13 @@ HWY_INLINE V Per128BitBlkRevLanesOnBe(V v) {
template <class V>
HWY_INLINE V I128Subtract(V a, V b) {
#if HWY_S390X_HAVE_Z14
-#if HWY_COMPILER_CLANG
+#if HWY_COMPILER_CLANG || HWY_COMPILER_GCC_ACTUAL >= 1500
// Workaround for bug in vec_sub_u128 in Clang vecintrin.h
+
+ // The vec_sub_u128 intrinsic is also now deprecated in GCC 15 and later.
+ // The built-in U128x1 vector subtraction operator should be used instead of
+ // vec_sub_u128 with GCC 15 and later to avoid compiler warnings.
+
typedef __uint128_t VU128 __attribute__((__vector_size__(16)));
const V diff_i128{reinterpret_cast<typename detail::Raw128<TFromV<V>>::type>(
reinterpret_cast<VU128>(a.raw) - reinterpret_cast<VU128>(b.raw))};
@@ -6909,8 +6993,18 @@ template <class T, HWY_IF_NOT_FLOAT_NOR_SPECIAL(T),
HWY_INLINE Vec128<T> SumOfU32OrU64LanesAsU128(Vec128<T> v) {
const DFromV<decltype(v)> d;
const RebindToUnsigned<decltype(d)> du;
+#if HWY_COMPILER_GCC_ACTUAL >= 1500 || HWY_COMPILER_CLANG >= 2100
+  // GCC 15 and Clang 21 have new vec_sum intrinsics that replaced the
+ // vec_sum_u128 intrinsic
+
+ // vec_sum needs to be used instead of vec_sum_u128 with GCC 15 or later to
+ // avoid compiler warnings
+ return Vec128<T>{reinterpret_cast<typename detail::Raw128<T>::type>(
+ vec_sum(BitCast(du, v).raw, Zero(du).raw))};
+#else
return BitCast(
d, Vec128<uint8_t>{vec_sum_u128(BitCast(du, v).raw, Zero(du).raw)});
+#endif
}
#endif
@@ -7153,8 +7247,21 @@ HWY_API V BitShuffle(V v, VI idx) {
#endif
#if HWY_S390X_HAVE_Z14
+
+#if HWY_COMPILER_GCC_ACTUAL >= 1500 || HWY_COMPILER_CLANG >= 2100
+  // GCC 15 and Clang 21 have added the vec_bperm intrinsic
+
+ // Need to use vec_bperm instead of vec_bperm_u128 with GCC 15 and later to
+ // avoid compiler warning
+ using RawVU128 = __vector unsigned __int128;
+
+ const VFromD<decltype(d_full_u64)> bit_shuf_result{reinterpret_cast<RawVU64>(
+ vec_bperm(reinterpret_cast<RawVU128>(v.raw), bit_idx.raw))};
+#else
const VFromD<decltype(d_full_u64)> bit_shuf_result{reinterpret_cast<RawVU64>(
vec_bperm_u128(BitCast(du8, v).raw, bit_idx.raw))};
+#endif  // HWY_COMPILER_GCC_ACTUAL >= 1500 || HWY_COMPILER_CLANG >= 2100
+
#elif defined(__SIZEOF_INT128__)
using RawVU128 = __vector unsigned __int128;
const VFromD<decltype(d_full_u64)> bit_shuf_result{reinterpret_cast<RawVU64>(
diff --git a/third_party/highway/hwy/ops/rvv-inl.h b/third_party/highway/hwy/ops/rvv-inl.h
index 752c87de6e..3c01d83b26 100644
--- a/third_party/highway/hwy/ops/rvv-inl.h
+++ b/third_party/highway/hwy/ops/rvv-inl.h
@@ -16,8 +16,21 @@
// RISC-V V vectors (length not known at compile time).
// External include guard in highway.h - see comment there.
+#pragma push_macro("__riscv_v_elen")
+
+// Workaround that ensures that all of the __riscv_vsetvl_* and
+// __riscv_vsetvlmax_* macros in riscv_vector.h are defined when compiling with
+// Clang 20 with dynamic dispatch and a baseline target of SCALAR or EMU128
+#if HWY_COMPILER_CLANG >= 2000 && HWY_COMPILER_CLANG < 2100 && \
+ (!defined(__riscv_v_elen) || __riscv_v_elen < 64)
+#undef __riscv_v_elen
+#define __riscv_v_elen 64
+#endif
+
#include <riscv_vector.h>
+#pragma pop_macro("__riscv_v_elen")
+
#include "third_party/highway/hwy/ops/shared-inl.h"
HWY_BEFORE_NAMESPACE();
@@ -1045,12 +1058,6 @@ HWY_API V AndNot(const V not_a, const V b) {
return And(Not(not_a), b);
}
-// ------------------------------ Xor3
-template <class V>
-HWY_API V Xor3(V x1, V x2, V x3) {
- return Xor(x1, Xor(x2, x3));
-}
-
// ------------------------------ Or3
template <class V>
HWY_API V Or3(V o1, V o2, V o3) {
@@ -1188,9 +1195,9 @@ HWY_RVV_FOREACH_I(HWY_RVV_RETV_ARGVV, SaturatedSub, ssub, _ALL)
#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1400
#define HWY_RVV_AVOID_VXRM
// Clang 16 with __riscv_v_intrinsic == 11000 may either require VXRM or avoid.
-// Assume earlier versions avoid.
+// Assume that Clang 16 and earlier avoid VXRM.
#elif HWY_COMPILER_CLANG && \
- (HWY_COMPILER_CLANG < 1600 || __riscv_v_intrinsic < 11000)
+ (HWY_COMPILER_CLANG < 1700 || __riscv_v_intrinsic < 11000)
#define HWY_RVV_AVOID_VXRM
#endif
@@ -1677,10 +1684,17 @@ HWY_API MFromD<D> MaskedIsNaN(const M m, const V v) {
#undef HWY_RVV_RETM_ARGMVV
#undef HWY_RVV_RETM_ARGVV
-#undef HWY_RVV_RETM_ARGVS
// ------------------------------ Gt/Ge (Lt, Le)
+namespace detail {
+HWY_RVV_FOREACH_I(HWY_RVV_RETM_ARGVS, GtS, msgt_vx, _ALL)
+HWY_RVV_FOREACH_U(HWY_RVV_RETM_ARGVS, GtS, msgtu_vx, _ALL)
+HWY_RVV_FOREACH_F(HWY_RVV_RETM_ARGVS, GtS, mfgt_vf, _ALL)
+} // namespace detail
+
+#undef HWY_RVV_RETM_ARGVS
+
// Swap args to reverse comparisons:
template <class V>
HWY_API auto Gt(const V a, const V b) -> decltype(Lt(a, b)) {
@@ -2422,6 +2436,8 @@ HWY_API VFromD<D> PromoteTo(D d, VFromD<Rebind<hwy::bfloat16_t, D>> v) {
// ------------------------------ DemoteTo U
+HWY_INLINE_VAR constexpr size_t kClipShift = 0;
+
// SEW is for the source so we can use _DEMOTE_VIRT.
#define HWY_RVV_DEMOTE(BASE, CHAR, SEW, SEWD, SEWH, LMUL, LMULD, LMULH, SHIFT, \
MLEN, NAME, OP) \
@@ -2429,7 +2445,7 @@ HWY_API VFromD<D> PromoteTo(D d, VFromD<Rebind<hwy::bfloat16_t, D>> v) {
HWY_API HWY_RVV_V(BASE, SEWH, LMULH) NAME( \
HWY_RVV_D(BASE, SEWH, N, SHIFT - 1) d, HWY_RVV_V(BASE, SEW, LMUL) v) { \
return __riscv_v##OP##CHAR##SEWH##LMULH( \
- v, 0, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, Lanes(d))); \
+ v, kClipShift, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, Lanes(d))); \
}
// Unsigned -> unsigned
@@ -2455,62 +2471,62 @@ HWY_RVV_FOREACH_I16(HWY_RVV_DEMOTE_I_TO_U, DemoteTo, _, _DEMOTE_VIRT)
template <size_t N>
HWY_API vuint8mf8_t DemoteTo(Simd<uint8_t, N, -3> d, const vint32mf2_t v) {
return __riscv_vnclipu_wx_u8mf8(
- DemoteTo(Simd<uint16_t, N, -2>(), v), 0,
+ DemoteTo(Simd<uint16_t, N, -2>(), v), kClipShift,
HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, Lanes(d)));
}
template <size_t N>
HWY_API vuint8mf4_t DemoteTo(Simd<uint8_t, N, -2> d, const vint32m1_t v) {
return __riscv_vnclipu_wx_u8mf4(
- DemoteTo(Simd<uint16_t, N, -1>(), v), 0,
+ DemoteTo(Simd<uint16_t, N, -1>(), v), kClipShift,
HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, Lanes(d)));
}
template <size_t N>
HWY_API vuint8mf2_t DemoteTo(Simd<uint8_t, N, -1> d, const vint32m2_t v) {
return __riscv_vnclipu_wx_u8mf2(
- DemoteTo(Simd<uint16_t, N, 0>(), v), 0,
+ DemoteTo(Simd<uint16_t, N, 0>(), v), kClipShift,
HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, Lanes(d)));
}
template <size_t N>
HWY_API vuint8m1_t DemoteTo(Simd<uint8_t, N, 0> d, const vint32m4_t v) {
return __riscv_vnclipu_wx_u8m1(
- DemoteTo(Simd<uint16_t, N, 1>(), v), 0,
+ DemoteTo(Simd<uint16_t, N, 1>(), v), kClipShift,
HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, Lanes(d)));
}
template <size_t N>
HWY_API vuint8m2_t DemoteTo(Simd<uint8_t, N, 1> d, const vint32m8_t v) {
return __riscv_vnclipu_wx_u8m2(
- DemoteTo(Simd<uint16_t, N, 2>(), v), 0,
+ DemoteTo(Simd<uint16_t, N, 2>(), v), kClipShift,
HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, Lanes(d)));
}
template <size_t N>
HWY_API vuint8mf8_t DemoteTo(Simd<uint8_t, N, -3> d, const vuint32mf2_t v) {
return __riscv_vnclipu_wx_u8mf8(
- DemoteTo(Simd<uint16_t, N, -2>(), v), 0,
+ DemoteTo(Simd<uint16_t, N, -2>(), v), kClipShift,
HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, Lanes(d)));
}
template <size_t N>
HWY_API vuint8mf4_t DemoteTo(Simd<uint8_t, N, -2> d, const vuint32m1_t v) {
return __riscv_vnclipu_wx_u8mf4(
- DemoteTo(Simd<uint16_t, N, -1>(), v), 0,
+ DemoteTo(Simd<uint16_t, N, -1>(), v), kClipShift,
HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, Lanes(d)));
}
template <size_t N>
HWY_API vuint8mf2_t DemoteTo(Simd<uint8_t, N, -1> d, const vuint32m2_t v) {
return __riscv_vnclipu_wx_u8mf2(
- DemoteTo(Simd<uint16_t, N, 0>(), v), 0,
+ DemoteTo(Simd<uint16_t, N, 0>(), v), kClipShift,
HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, Lanes(d)));
}
template <size_t N>
HWY_API vuint8m1_t DemoteTo(Simd<uint8_t, N, 0> d, const vuint32m4_t v) {
return __riscv_vnclipu_wx_u8m1(
- DemoteTo(Simd<uint16_t, N, 1>(), v), 0,
+ DemoteTo(Simd<uint16_t, N, 1>(), v), kClipShift,
HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, Lanes(d)));
}
template <size_t N>
HWY_API vuint8m2_t DemoteTo(Simd<uint8_t, N, 1> d, const vuint32m8_t v) {
return __riscv_vnclipu_wx_u8m2(
- DemoteTo(Simd<uint16_t, N, 2>(), v), 0,
+ DemoteTo(Simd<uint16_t, N, 2>(), v), kClipShift,
HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, Lanes(d)));
}
@@ -2537,37 +2553,37 @@ HWY_API VFromD<D> DemoteTo(D d, VFromD<Rebind<uint64_t, D>> v) {
HWY_API vuint8mf8_t U8FromU32(const vuint32mf2_t v) {
const size_t avl = Lanes(ScalableTag<uint8_t, -3>());
return __riscv_vnclipu_wx_u8mf8(
- __riscv_vnclipu_wx_u16mf4(v, 0,
+ __riscv_vnclipu_wx_u16mf4(v, kClipShift,
HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)),
- 0, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
+ kClipShift, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
}
HWY_API vuint8mf4_t U8FromU32(const vuint32m1_t v) {
const size_t avl = Lanes(ScalableTag<uint8_t, -2>());
return __riscv_vnclipu_wx_u8mf4(
- __riscv_vnclipu_wx_u16mf2(v, 0,
+ __riscv_vnclipu_wx_u16mf2(v, kClipShift,
HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)),
- 0, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
+ kClipShift, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
}
HWY_API vuint8mf2_t U8FromU32(const vuint32m2_t v) {
const size_t avl = Lanes(ScalableTag<uint8_t, -1>());
return __riscv_vnclipu_wx_u8mf2(
- __riscv_vnclipu_wx_u16m1(v, 0,
+ __riscv_vnclipu_wx_u16m1(v, kClipShift,
HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)),
- 0, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
+ kClipShift, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
}
HWY_API vuint8m1_t U8FromU32(const vuint32m4_t v) {
const size_t avl = Lanes(ScalableTag<uint8_t, 0>());
return __riscv_vnclipu_wx_u8m1(
- __riscv_vnclipu_wx_u16m2(v, 0,
+ __riscv_vnclipu_wx_u16m2(v, kClipShift,
HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)),
- 0, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
+ kClipShift, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
}
HWY_API vuint8m2_t U8FromU32(const vuint32m8_t v) {
const size_t avl = Lanes(ScalableTag<uint8_t, 1>());
return __riscv_vnclipu_wx_u8m2(
- __riscv_vnclipu_wx_u16m4(v, 0,
+ __riscv_vnclipu_wx_u16m4(v, kClipShift,
HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl)),
- 0, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
+ kClipShift, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
}
// ------------------------------ Truncations
@@ -2578,10 +2594,10 @@ HWY_API vuint8mf8_t TruncateTo(Simd<uint8_t, N, -3> d,
const size_t avl = Lanes(d);
const vuint64m1_t v1 = __riscv_vand(v, 0xFF, avl);
const vuint32mf2_t v2 = __riscv_vnclipu_wx_u32mf2(
- v1, 0, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
+ v1, kClipShift, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
const vuint16mf4_t v3 = __riscv_vnclipu_wx_u16mf4(
- v2, 0, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
- return __riscv_vnclipu_wx_u8mf8(v3, 0,
+ v2, kClipShift, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
+ return __riscv_vnclipu_wx_u8mf8(v3, kClipShift,
HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
}
@@ -2591,10 +2607,10 @@ HWY_API vuint8mf4_t TruncateTo(Simd<uint8_t, N, -2> d,
const size_t avl = Lanes(d);
const vuint64m2_t v1 = __riscv_vand(v, 0xFF, avl);
const vuint32m1_t v2 = __riscv_vnclipu_wx_u32m1(
- v1, 0, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
+ v1, kClipShift, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
const vuint16mf2_t v3 = __riscv_vnclipu_wx_u16mf2(
- v2, 0, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
- return __riscv_vnclipu_wx_u8mf4(v3, 0,
+ v2, kClipShift, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
+ return __riscv_vnclipu_wx_u8mf4(v3, kClipShift,
HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
}
@@ -2604,10 +2620,10 @@ HWY_API vuint8mf2_t TruncateTo(Simd<uint8_t, N, -1> d,
const size_t avl = Lanes(d);
const vuint64m4_t v1 = __riscv_vand(v, 0xFF, avl);
const vuint32m2_t v2 = __riscv_vnclipu_wx_u32m2(
- v1, 0, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
+ v1, kClipShift, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
const vuint16m1_t v3 = __riscv_vnclipu_wx_u16m1(
- v2, 0, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
- return __riscv_vnclipu_wx_u8mf2(v3, 0,
+ v2, kClipShift, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
+ return __riscv_vnclipu_wx_u8mf2(v3, kClipShift,
HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
}
@@ -2617,10 +2633,10 @@ HWY_API vuint8m1_t TruncateTo(Simd<uint8_t, N, 0> d,
const size_t avl = Lanes(d);
const vuint64m8_t v1 = __riscv_vand(v, 0xFF, avl);
const vuint32m4_t v2 = __riscv_vnclipu_wx_u32m4(
- v1, 0, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
+ v1, kClipShift, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
const vuint16m2_t v3 = __riscv_vnclipu_wx_u16m2(
- v2, 0, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
- return __riscv_vnclipu_wx_u8m1(v3, 0,
+ v2, kClipShift, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
+ return __riscv_vnclipu_wx_u8m1(v3, kClipShift,
HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
}
@@ -2630,8 +2646,8 @@ HWY_API vuint16mf4_t TruncateTo(Simd<uint16_t, N, -3> d,
const size_t avl = Lanes(d);
const vuint64m1_t v1 = __riscv_vand(v, 0xFFFF, avl);
const vuint32mf2_t v2 = __riscv_vnclipu_wx_u32mf2(
- v1, 0, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
- return __riscv_vnclipu_wx_u16mf4(v2, 0,
+ v1, kClipShift, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
+ return __riscv_vnclipu_wx_u16mf4(v2, kClipShift,
HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
}
@@ -2641,8 +2657,8 @@ HWY_API vuint16mf4_t TruncateTo(Simd<uint16_t, N, -2> d,
const size_t avl = Lanes(d);
const vuint64m1_t v1 = __riscv_vand(v, 0xFFFF, avl);
const vuint32mf2_t v2 = __riscv_vnclipu_wx_u32mf2(
- v1, 0, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
- return __riscv_vnclipu_wx_u16mf4(v2, 0,
+ v1, kClipShift, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
+ return __riscv_vnclipu_wx_u16mf4(v2, kClipShift,
HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
}
@@ -2652,8 +2668,8 @@ HWY_API vuint16mf2_t TruncateTo(Simd<uint16_t, N, -1> d,
const size_t avl = Lanes(d);
const vuint64m2_t v1 = __riscv_vand(v, 0xFFFF, avl);
const vuint32m1_t v2 = __riscv_vnclipu_wx_u32m1(
- v1, 0, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
- return __riscv_vnclipu_wx_u16mf2(v2, 0,
+ v1, kClipShift, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
+ return __riscv_vnclipu_wx_u16mf2(v2, kClipShift,
HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
}
@@ -2663,8 +2679,8 @@ HWY_API vuint16m1_t TruncateTo(Simd<uint16_t, N, 0> d,
const size_t avl = Lanes(d);
const vuint64m4_t v1 = __riscv_vand(v, 0xFFFF, avl);
const vuint32m2_t v2 = __riscv_vnclipu_wx_u32m2(
- v1, 0, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
- return __riscv_vnclipu_wx_u16m1(v2, 0,
+ v1, kClipShift, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
+ return __riscv_vnclipu_wx_u16m1(v2, kClipShift,
HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
}
@@ -2674,8 +2690,8 @@ HWY_API vuint16m2_t TruncateTo(Simd<uint16_t, N, 1> d,
const size_t avl = Lanes(d);
const vuint64m8_t v1 = __riscv_vand(v, 0xFFFF, avl);
const vuint32m4_t v2 = __riscv_vnclipu_wx_u32m4(
- v1, 0, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
- return __riscv_vnclipu_wx_u16m2(v2, 0,
+ v1, kClipShift, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
+ return __riscv_vnclipu_wx_u16m2(v2, kClipShift,
HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
}
@@ -2684,7 +2700,7 @@ HWY_API vuint32mf2_t TruncateTo(Simd<uint32_t, N, -2> d,
const VFromD<Simd<uint64_t, N, -1>> v) {
const size_t avl = Lanes(d);
const vuint64m1_t v1 = __riscv_vand(v, 0xFFFFFFFFu, avl);
- return __riscv_vnclipu_wx_u32mf2(v1, 0,
+ return __riscv_vnclipu_wx_u32mf2(v1, kClipShift,
HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
}
@@ -2693,7 +2709,7 @@ HWY_API vuint32mf2_t TruncateTo(Simd<uint32_t, N, -1> d,
const VFromD<Simd<uint64_t, N, 0>> v) {
const size_t avl = Lanes(d);
const vuint64m1_t v1 = __riscv_vand(v, 0xFFFFFFFFu, avl);
- return __riscv_vnclipu_wx_u32mf2(v1, 0,
+ return __riscv_vnclipu_wx_u32mf2(v1, kClipShift,
HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
}
@@ -2702,7 +2718,7 @@ HWY_API vuint32m1_t TruncateTo(Simd<uint32_t, N, 0> d,
const VFromD<Simd<uint64_t, N, 1>> v) {
const size_t avl = Lanes(d);
const vuint64m2_t v1 = __riscv_vand(v, 0xFFFFFFFFu, avl);
- return __riscv_vnclipu_wx_u32m1(v1, 0,
+ return __riscv_vnclipu_wx_u32m1(v1, kClipShift,
HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
}
@@ -2711,7 +2727,7 @@ HWY_API vuint32m2_t TruncateTo(Simd<uint32_t, N, 1> d,
const VFromD<Simd<uint64_t, N, 2>> v) {
const size_t avl = Lanes(d);
const vuint64m4_t v1 = __riscv_vand(v, 0xFFFFFFFFu, avl);
- return __riscv_vnclipu_wx_u32m2(v1, 0,
+ return __riscv_vnclipu_wx_u32m2(v1, kClipShift,
HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
}
@@ -2720,7 +2736,7 @@ HWY_API vuint32m4_t TruncateTo(Simd<uint32_t, N, 2> d,
const VFromD<Simd<uint64_t, N, 3>> v) {
const size_t avl = Lanes(d);
const vuint64m8_t v1 = __riscv_vand(v, 0xFFFFFFFFu, avl);
- return __riscv_vnclipu_wx_u32m4(v1, 0,
+ return __riscv_vnclipu_wx_u32m4(v1, kClipShift,
HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
}
@@ -2730,8 +2746,8 @@ HWY_API vuint8mf8_t TruncateTo(Simd<uint8_t, N, -3> d,
const size_t avl = Lanes(d);
const vuint32mf2_t v1 = __riscv_vand(v, 0xFF, avl);
const vuint16mf4_t v2 = __riscv_vnclipu_wx_u16mf4(
- v1, 0, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
- return __riscv_vnclipu_wx_u8mf8(v2, 0,
+ v1, kClipShift, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
+ return __riscv_vnclipu_wx_u8mf8(v2, kClipShift,
HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
}
@@ -2741,8 +2757,8 @@ HWY_API vuint8mf4_t TruncateTo(Simd<uint8_t, N, -2> d,
const size_t avl = Lanes(d);
const vuint32m1_t v1 = __riscv_vand(v, 0xFF, avl);
const vuint16mf2_t v2 = __riscv_vnclipu_wx_u16mf2(
- v1, 0, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
- return __riscv_vnclipu_wx_u8mf4(v2, 0,
+ v1, kClipShift, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
+ return __riscv_vnclipu_wx_u8mf4(v2, kClipShift,
HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
}
@@ -2752,8 +2768,8 @@ HWY_API vuint8mf2_t TruncateTo(Simd<uint8_t, N, -1> d,
const size_t avl = Lanes(d);
const vuint32m2_t v1 = __riscv_vand(v, 0xFF, avl);
const vuint16m1_t v2 = __riscv_vnclipu_wx_u16m1(
- v1, 0, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
- return __riscv_vnclipu_wx_u8mf2(v2, 0,
+ v1, kClipShift, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
+ return __riscv_vnclipu_wx_u8mf2(v2, kClipShift,
HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
}
@@ -2763,8 +2779,8 @@ HWY_API vuint8m1_t TruncateTo(Simd<uint8_t, N, 0> d,
const size_t avl = Lanes(d);
const vuint32m4_t v1 = __riscv_vand(v, 0xFF, avl);
const vuint16m2_t v2 = __riscv_vnclipu_wx_u16m2(
- v1, 0, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
- return __riscv_vnclipu_wx_u8m1(v2, 0,
+ v1, kClipShift, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
+ return __riscv_vnclipu_wx_u8m1(v2, kClipShift,
HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
}
@@ -2774,8 +2790,8 @@ HWY_API vuint8m2_t TruncateTo(Simd<uint8_t, N, 1> d,
const size_t avl = Lanes(d);
const vuint32m8_t v1 = __riscv_vand(v, 0xFF, avl);
const vuint16m4_t v2 = __riscv_vnclipu_wx_u16m4(
- v1, 0, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
- return __riscv_vnclipu_wx_u8m2(v2, 0,
+ v1, kClipShift, HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
+ return __riscv_vnclipu_wx_u8m2(v2, kClipShift,
HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
}
@@ -2784,7 +2800,7 @@ HWY_API vuint16mf4_t TruncateTo(Simd<uint16_t, N, -3> d,
const VFromD<Simd<uint32_t, N, -2>> v) {
const size_t avl = Lanes(d);
const vuint32mf2_t v1 = __riscv_vand(v, 0xFFFF, avl);
- return __riscv_vnclipu_wx_u16mf4(v1, 0,
+ return __riscv_vnclipu_wx_u16mf4(v1, kClipShift,
HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
}
@@ -2793,7 +2809,7 @@ HWY_API vuint16mf4_t TruncateTo(Simd<uint16_t, N, -2> d,
const VFromD<Simd<uint32_t, N, -1>> v) {
const size_t avl = Lanes(d);
const vuint32mf2_t v1 = __riscv_vand(v, 0xFFFF, avl);
- return __riscv_vnclipu_wx_u16mf4(v1, 0,
+ return __riscv_vnclipu_wx_u16mf4(v1, kClipShift,
HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
}
@@ -2802,7 +2818,7 @@ HWY_API vuint16mf2_t TruncateTo(Simd<uint16_t, N, -1> d,
const VFromD<Simd<uint32_t, N, 0>> v) {
const size_t avl = Lanes(d);
const vuint32m1_t v1 = __riscv_vand(v, 0xFFFF, avl);
- return __riscv_vnclipu_wx_u16mf2(v1, 0,
+ return __riscv_vnclipu_wx_u16mf2(v1, kClipShift,
HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
}
@@ -2811,7 +2827,7 @@ HWY_API vuint16m1_t TruncateTo(Simd<uint16_t, N, 0> d,
const VFromD<Simd<uint32_t, N, 1>> v) {
const size_t avl = Lanes(d);
const vuint32m2_t v1 = __riscv_vand(v, 0xFFFF, avl);
- return __riscv_vnclipu_wx_u16m1(v1, 0,
+ return __riscv_vnclipu_wx_u16m1(v1, kClipShift,
HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
}
@@ -2820,7 +2836,7 @@ HWY_API vuint16m2_t TruncateTo(Simd<uint16_t, N, 1> d,
const VFromD<Simd<uint32_t, N, 2>> v) {
const size_t avl = Lanes(d);
const vuint32m4_t v1 = __riscv_vand(v, 0xFFFF, avl);
- return __riscv_vnclipu_wx_u16m2(v1, 0,
+ return __riscv_vnclipu_wx_u16m2(v1, kClipShift,
HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
}
@@ -2829,7 +2845,7 @@ HWY_API vuint16m4_t TruncateTo(Simd<uint16_t, N, 2> d,
const VFromD<Simd<uint32_t, N, 3>> v) {
const size_t avl = Lanes(d);
const vuint32m8_t v1 = __riscv_vand(v, 0xFFFF, avl);
- return __riscv_vnclipu_wx_u16m4(v1, 0,
+ return __riscv_vnclipu_wx_u16m4(v1, kClipShift,
HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
}
@@ -2838,7 +2854,7 @@ HWY_API vuint8mf8_t TruncateTo(Simd<uint8_t, N, -3> d,
const VFromD<Simd<uint16_t, N, -2>> v) {
const size_t avl = Lanes(d);
const vuint16mf4_t v1 = __riscv_vand(v, 0xFF, avl);
- return __riscv_vnclipu_wx_u8mf8(v1, 0,
+ return __riscv_vnclipu_wx_u8mf8(v1, kClipShift,
HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
}
@@ -2847,7 +2863,7 @@ HWY_API vuint8mf4_t TruncateTo(Simd<uint8_t, N, -2> d,
const VFromD<Simd<uint16_t, N, -1>> v) {
const size_t avl = Lanes(d);
const vuint16mf2_t v1 = __riscv_vand(v, 0xFF, avl);
- return __riscv_vnclipu_wx_u8mf4(v1, 0,
+ return __riscv_vnclipu_wx_u8mf4(v1, kClipShift,
HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
}
@@ -2856,7 +2872,7 @@ HWY_API vuint8mf2_t TruncateTo(Simd<uint8_t, N, -1> d,
const VFromD<Simd<uint16_t, N, 0>> v) {
const size_t avl = Lanes(d);
const vuint16m1_t v1 = __riscv_vand(v, 0xFF, avl);
- return __riscv_vnclipu_wx_u8mf2(v1, 0,
+ return __riscv_vnclipu_wx_u8mf2(v1, kClipShift,
HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
}
@@ -2865,7 +2881,7 @@ HWY_API vuint8m1_t TruncateTo(Simd<uint8_t, N, 0> d,
const VFromD<Simd<uint16_t, N, 1>> v) {
const size_t avl = Lanes(d);
const vuint16m2_t v1 = __riscv_vand(v, 0xFF, avl);
- return __riscv_vnclipu_wx_u8m1(v1, 0,
+ return __riscv_vnclipu_wx_u8m1(v1, kClipShift,
HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
}
@@ -2874,7 +2890,7 @@ HWY_API vuint8m2_t TruncateTo(Simd<uint8_t, N, 1> d,
const VFromD<Simd<uint16_t, N, 2>> v) {
const size_t avl = Lanes(d);
const vuint16m4_t v1 = __riscv_vand(v, 0xFF, avl);
- return __riscv_vnclipu_wx_u8m2(v1, 0,
+ return __riscv_vnclipu_wx_u8m2(v1, kClipShift,
HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
}
@@ -2883,7 +2899,7 @@ HWY_API vuint8m4_t TruncateTo(Simd<uint8_t, N, 2> d,
const VFromD<Simd<uint16_t, N, 3>> v) {
const size_t avl = Lanes(d);
const vuint16m8_t v1 = __riscv_vand(v, 0xFF, avl);
- return __riscv_vnclipu_wx_u8m4(v1, 0,
+ return __riscv_vnclipu_wx_u8m4(v1, kClipShift,
HWY_RVV_INSERT_VXRM(__RISCV_VXRM_RDN, avl));
}
@@ -3173,7 +3189,7 @@ namespace detail {
// For x86-compatible behaviour mandated by Highway API: TableLookupBytes
// offsets are implicitly relative to the start of their 128-bit block.
template <typename T, size_t N, int kPow2>
-HWY_INLINE size_t LanesPerBlock(Simd<T, N, kPow2> d) {
+HWY_LANES_CONSTEXPR HWY_INLINE size_t LanesPerBlock(Simd<T, N, kPow2> d) {
// kMinVecBytes is the minimum size of VFromD<decltype(d)> in bytes
constexpr size_t kMinVecBytes =
ScaleByPower(16, HWY_MAX(HWY_MIN(kPow2, 3), -3));
@@ -3660,7 +3676,7 @@ HWY_API V SwapAdjacentBlocks(const V v) {
template <class D, class V = VFromD<D>>
HWY_API V InterleaveEvenBlocks(D d, V a, V b) {
- const size_t lpb = detail::LanesPerBlock(d);
+ HWY_LANES_CONSTEXPR size_t lpb = detail::LanesPerBlock(d);
return OddEvenBlocks(SlideUpLanes(d, b, lpb), a);
}
@@ -3669,7 +3685,7 @@ HWY_API V InterleaveEvenBlocks(D d, V a, V b) {
template <class D, class V = VFromD<D>>
HWY_API V InterleaveOddBlocks(D d, V a, V b) {
- const size_t lpb = detail::LanesPerBlock(d);
+ HWY_LANES_CONSTEXPR size_t lpb = detail::LanesPerBlock(d);
return OddEvenBlocks(b, SlideDownLanes(d, a, lpb));
}
@@ -3751,6 +3767,83 @@ HWY_RVV_FOREACH_UI08(HWY_RVV_MASKED_TABLE16, MaskedTableLookupLanes16,
} // namespace detail
+// ------------------------------ InterleaveLowerBlocks
+// (ConcatLowerLower, OddEvenBlocks, TableLookupLanes)
+
+namespace detail {
+
+// Given the concatenation of the (either upper or lower) halves of `b` and
+// `a`, permutes the vector to return interleaved blocks from both.
+template <class D, class V = VFromD<D>, HWY_IF_NOT_T_SIZE_D(D, 1)>
+HWY_INLINE V InterleaveBlocks(D d, const V ba) {
+ const RebindToUnsigned<decltype(d)> du;
+ using VU = VFromD<decltype(du)>;
+ using TU = TFromD<decltype(du)>;
+ HWY_LANES_CONSTEXPR size_t lpb = detail::LanesPerBlock(d);
+ const VU iota = detail::Iota0(du);
+ // Divide the block index by 2, without affecting the within-block lane.
+ const VU idx_blocks =
+ detail::AndS(ShiftRight<1>(iota), static_cast<TU>(~(lpb - 1)));
+ VU idx = Or(idx_blocks, detail::AndS(iota, static_cast<TU>(lpb - 1)));
+ // Odd blocks come from `b`, i.e. the upper half of `ba`.
+ idx = OddEvenBlocks(Add(idx, Set(du, static_cast<TU>(Lanes(d) / 2))), idx);
+
+ return TableLookupLanes(ba, IndicesFromVec(d, idx));
+}
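A scalar model of the index computation may help: for a hypothetical vector of N = 16 lanes with lpb = 4 lanes per block, the printed mapping shows result blocks 0..3 drawing from a0, b0, a1, b1 of the concatenated `ba`.

    #include <cstdint>
    #include <cstdio>

    int main() {
      constexpr uint32_t kN = 16, kLpb = 4;  // lanes, lanes per 128-bit block
      for (uint32_t i = 0; i < kN; ++i) {
        // Halve the block index while keeping the within-block lane.
        uint32_t idx = ((i >> 1) & ~(kLpb - 1)) | (i & (kLpb - 1));
        if ((i / kLpb) & 1) idx += kN / 2;  // odd blocks read from `b`
        printf("lane %2u <- ba[%2u]\n", i, idx);
      }
      return 0;
    }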
+
+// As above, but uses 16-bit indices because 8-bit lanes might overflow.
+// Identical except for `du` and the `TableLookupLanes16` call.
+template <class D, class V = VFromD<D>, HWY_IF_T_SIZE_D(D, 1)>
+HWY_INLINE V InterleaveBlocks(D d, const V ba) {
+ const Rebind<uint16_t, decltype(d)> du; // Increases LMUL
+ using VU = VFromD<decltype(du)>;
+ using TU = TFromD<decltype(du)>;
+ HWY_LANES_CONSTEXPR size_t lpb = detail::LanesPerBlock(d);
+ const VU iota = detail::Iota0(du);
+ // Divide the block index by 2, without affecting the within-block lane.
+ const VU idx_blocks =
+ detail::AndS(ShiftRight<1>(iota), static_cast<TU>(~(lpb - 1)));
+ VU idx = Or(idx_blocks, detail::AndS(iota, static_cast<TU>(lpb - 1)));
+ // Odd blocks come from `b`, i.e. the upper half of `ba`.
+ idx = OddEvenBlocks(Add(idx, Set(du, static_cast<TU>(Lanes(d) / 2))), idx);
+
+ return detail::TableLookupLanes16(ba, idx);
+}
+
+} // namespace detail
+
+template <class D, class V = VFromD<D>, HWY_IF_POW2_GT_D(DFromV<V>, -3)>
+HWY_INLINE V InterleaveLowerBlocks(D d, const V a, const V b) {
+ // Blocks are independent of type, hence cast to a type where indices will not
+ // overflow.
+ const Repartition<uint16_t, decltype(d)> du;
+ return BitCast(
+ d, detail::InterleaveBlocks(du, BitCast(du, ConcatLowerLower(d, b, a))));
+}
+
+template <class D, class V = VFromD<D>, HWY_IF_POW2_LE_D(DFromV<V>, -3)>
+HWY_INLINE V InterleaveLowerBlocks(D d, const V a, const V b) {
+ // Might be bytes; if so, the second overload will call TableLookupLanes16.
+ return detail::InterleaveBlocks(d, ConcatLowerLower(d, b, a));
+}
+
+// ------------------------------ InterleaveUpperBlocks
+
+template <class D, class V = VFromD<D>, HWY_IF_POW2_GT_D(DFromV<V>, -3)>
+HWY_INLINE V InterleaveUpperBlocks(D d, const V a, const V b) {
+ // Blocks are independent of type, hence cast to a type where indices will not
+ // overflow.
+ const Repartition<uint16_t, decltype(d)> du;
+ return BitCast(
+ d, detail::InterleaveBlocks(du, BitCast(du, ConcatUpperUpper(d, b, a))));
+}
+
+template <class D, class V = VFromD<D>, HWY_IF_POW2_LE_D(DFromV<V>, -3)>
+HWY_INLINE V InterleaveUpperBlocks(D d, const V a, const V b) {
+ // Might be bytes; if so, the second overload will call TableLookupLanes16.
+ return detail::InterleaveBlocks(d, ConcatUpperUpper(d, b, a));
+}
+
// ------------------------------ Reverse (TableLookupLanes)
template <class D, HWY_IF_T_SIZE_D(D, 1), HWY_IF_POW2_LE_D(D, 2)>
HWY_API VFromD<D> Reverse(D d, VFromD<D> v) {
@@ -4406,11 +4499,14 @@ HWY_API V BroadcastBlock(V v) {
return BitCast(d, detail::TableLookupLanes16(BitCast(du8, v), idx));
}
-template <int kBlockIdx, class V, HWY_IF_POW2_GT_D(DFromV<V>, -3)>
+namespace detail {
+
+// Called for at least 16-bit lanes to ensure indices do not overflow.
+template <int kBlockIdx, class V>
HWY_API V BroadcastBlock(V v) {
const DFromV<decltype(v)> d;
- using TU = If<sizeof(TFromV<V>) == 1, uint16_t, MakeUnsigned<TFromV<V>>>;
- const Repartition<TU, decltype(d)> du;
+ const RebindToUnsigned<decltype(d)> du;
+ using TU = TFromD<decltype(du)>;
static_assert(0 <= kBlockIdx && kBlockIdx < d.MaxBlocks(),
"Invalid block index");
@@ -4422,6 +4518,17 @@ HWY_API V BroadcastBlock(V v) {
return BitCast(d, TableLookupLanes(BitCast(du, v), idx));
}
+} // namespace detail
+
+template <int kBlockIdx, class V, HWY_IF_POW2_GT_D(DFromV<V>, -3)>
+HWY_API V BroadcastBlock(V v) {
+ // Because we are broadcasting 128-bit blocks, the type does not matter.
+ // We can cast to uint16_t to ensure indices do not overflow.
+ const DFromV<decltype(v)> d;
+ const Repartition<uint16_t, decltype(d)> du16;
+ return BitCast(d, detail::BroadcastBlock<kBlockIdx>(BitCast(du16, v)));
+}
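A hedged usage sketch (the wrapper is illustrative, not part of Highway): BroadcastBlock<kBlockIdx> copies one 128-bit block to all blocks, so with u32 lanes the result repeats lanes 4..7 of the input.

    // Requires a vector of at least two 128-bit blocks.
    template <class V>
    V RepeatSecondBlock(V v) {
      return BroadcastBlock<1>(v);
    }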
+
// ------------------------------ ExtractBlock
template <int kBlockIdx, class V>
HWY_API VFromD<BlockDFromD<DFromV<V>>> ExtractBlock(V v) {
@@ -4485,7 +4592,7 @@ HWY_API V ShiftRightLanes(const Simd<T, N, kPow2> d, V v) {
const auto shifted = detail::SlideDown(v, kLanes);
// Match x86 semantics by zeroing upper lanes in 128-bit blocks
- const size_t lpb = detail::LanesPerBlock(di);
+ HWY_LANES_CONSTEXPR size_t lpb = detail::LanesPerBlock(di);
const auto idx_mod =
detail::AndS(BitCast(di, detail::Iota0(du)), static_cast<TI>(lpb - 1));
const auto keep = detail::LtS(idx_mod, static_cast<TI>(lpb - kLanes));
@@ -4500,10 +4607,10 @@ HWY_API V ShiftRightBytes(const D d, const V v) {
}
// ------------------------------ InterleaveWholeLower
-#ifdef HWY_NATIVE_INTERLEAVE_WHOLE
-#undef HWY_NATIVE_INTERLEAVE_WHOLE
+#ifdef HWY_TOGGLE_INTERLEAVE_WHOLE
+#undef HWY_TOGGLE_INTERLEAVE_WHOLE
#else
-#define HWY_NATIVE_INTERLEAVE_WHOLE
+#define HWY_TOGGLE_INTERLEAVE_WHOLE
#endif
namespace detail {
@@ -4594,7 +4701,7 @@ namespace detail {
// Definitely at least 128 bit: match x86 semantics (independent blocks). Using
// InterleaveWhole and 64-bit Compress avoids 8-bit overflow.
template <class D, class V, HWY_IF_POW2_LE_D(D, 2)>
-HWY_INLINE V InterleaveLowerBlocks(D d, const V a, const V b) {
+HWY_INLINE V InterleaveLowerImpl(D d, const V a, const V b) {
static_assert(IsSame<TFromD<D>, TFromV<V>>(), "D/V mismatch");
const Twice<D> dt;
const RebindToUnsigned<decltype(dt)> dt_u;
@@ -4609,18 +4716,18 @@ HWY_INLINE V InterleaveLowerBlocks(D d, const V a, const V b) {
return BitCast(d, LowerHalf(Compress(BitCast(dt_u, interleaved), is_even)));
}
template <class D, class V, HWY_IF_POW2_GT_D(D, 2)>
-HWY_INLINE V InterleaveLowerBlocks(D d, const V a, const V b) {
+HWY_INLINE V InterleaveLowerImpl(D d, const V a, const V b) {
const Half<D> dh;
const VFromD<decltype(dh)> i0 =
- InterleaveLowerBlocks(dh, LowerHalf(dh, a), LowerHalf(dh, b));
+ InterleaveLowerImpl(dh, LowerHalf(dh, a), LowerHalf(dh, b));
const VFromD<decltype(dh)> i1 =
- InterleaveLowerBlocks(dh, UpperHalf(dh, a), UpperHalf(dh, b));
+ InterleaveLowerImpl(dh, UpperHalf(dh, a), UpperHalf(dh, b));
return Combine(d, i1, i0);
}
// As above, for the upper half of blocks.
template <class D, class V, HWY_IF_POW2_LE_D(D, 2)>
-HWY_INLINE V InterleaveUpperBlocks(D d, const V a, const V b) {
+HWY_INLINE V InterleaveUpperImpl(D d, const V a, const V b) {
static_assert(IsSame<TFromD<D>, TFromV<V>>(), "D/V mismatch");
const Twice<D> dt;
const RebindToUnsigned<decltype(dt)> dt_u;
@@ -4635,12 +4742,12 @@ HWY_INLINE V InterleaveUpperBlocks(D d, const V a, const V b) {
return BitCast(d, LowerHalf(Compress(BitCast(dt_u, interleaved), is_odd)));
}
template <class D, class V, HWY_IF_POW2_GT_D(D, 2)>
-HWY_INLINE V InterleaveUpperBlocks(D d, const V a, const V b) {
+HWY_INLINE V InterleaveUpperImpl(D d, const V a, const V b) {
const Half<D> dh;
const VFromD<decltype(dh)> i0 =
- InterleaveUpperBlocks(dh, LowerHalf(dh, a), LowerHalf(dh, b));
+ InterleaveUpperImpl(dh, LowerHalf(dh, a), LowerHalf(dh, b));
const VFromD<decltype(dh)> i1 =
- InterleaveUpperBlocks(dh, UpperHalf(dh, a), UpperHalf(dh, b));
+ InterleaveUpperImpl(dh, UpperHalf(dh, a), UpperHalf(dh, b));
return Combine(d, i1, i0);
}
@@ -4667,7 +4774,7 @@ constexpr bool IsLT128(Simd<T, N, kPow2> /* d */) {
template <class D, class V, HWY_RVV_IF_GE128_D(D)>
HWY_API V InterleaveLower(D d, const V a, const V b) {
- return detail::InterleaveLowerBlocks(d, a, b);
+ return detail::InterleaveLowerImpl(d, a, b);
}
// Single block: interleave without extra Compress.
@@ -4685,8 +4792,8 @@ HWY_API V InterleaveLower(D d, const V a, const V b) {
}
// Fractional LMUL: use LMUL=1 to ensure we can cast to u64.
const ScalableTag<TFromD<D>, HWY_MAX(d.Pow2(), 0)> d1;
- return ResizeBitCast(d, detail::InterleaveLowerBlocks(
- d1, ResizeBitCast(d1, a), ResizeBitCast(d1, b)));
+ return ResizeBitCast(d, detail::InterleaveLowerImpl(d1, ResizeBitCast(d1, a),
+ ResizeBitCast(d1, b)));
}
template <class V>
@@ -4698,7 +4805,7 @@ HWY_API V InterleaveLower(const V a, const V b) {
template <class D, class V, HWY_RVV_IF_GE128_D(D)>
HWY_API V InterleaveUpper(D d, const V a, const V b) {
- return detail::InterleaveUpperBlocks(d, a, b);
+ return detail::InterleaveUpperImpl(d, a, b);
}
// Single block: interleave without extra Compress.
@@ -4716,8 +4823,8 @@ HWY_API V InterleaveUpper(D d, const V a, const V b) {
}
// Fractional LMUL: use LMUL=1 to ensure we can cast to u64.
const ScalableTag<TFromD<D>, HWY_MAX(d.Pow2(), 0)> d1;
- return ResizeBitCast(d, detail::InterleaveUpperBlocks(
- d1, ResizeBitCast(d1, a), ResizeBitCast(d1, b)));
+ return ResizeBitCast(d, detail::InterleaveUpperImpl(d1, ResizeBitCast(d1, a),
+ ResizeBitCast(d1, b)));
}
// ------------------------------ ZipLower
@@ -5253,18 +5360,6 @@ HWY_API VFromD<D> Dup128VecFromValues(D d, TFromD<D> t0, TFromD<D> t1,
detail::Vec64ValsWrapper<TFromD<D>>{{t2, t3}})));
}
-// ------------------------------ PopulationCount (ShiftRight)
-
-// Handles LMUL < 2 or capped vectors, which generic_ops-inl cannot.
-template <typename V, class D = DFromV<V>, HWY_IF_U8_D(D),
- hwy::EnableIf<D().Pow2() < 1 || D().MaxLanes() < 16>* = nullptr>
-HWY_API V PopulationCount(V v) {
- // See https://arxiv.org/pdf/1611.07612.pdf, Figure 3
- v = Sub(v, detail::AndS(ShiftRight<1>(v), 0x55));
- v = Add(detail::AndS(ShiftRight<2>(v), 0x33), detail::AndS(v, 0x33));
- return detail::AndS(Add(v, ShiftRight<4>(v)), 0x0F);
-}
-
// ------------------------------ LoadDup128
template <class D>
@@ -5291,7 +5386,7 @@ HWY_API VFromD<D> LoadDup128(D d, const TFromD<D>* const HWY_RESTRICT p) {
// idx must be unsigned for TableLookupLanes.
using TU = TFromD<decltype(du)>;
- const TU mask = static_cast<TU>(detail::LanesPerBlock(d) - 1);
+ HWY_LANES_CONSTEXPR TU mask = static_cast<TU>(detail::LanesPerBlock(d) - 1);
// Broadcast the first block.
const VFromD<RebindToUnsigned<D>> idx = detail::AndS(detail::Iota0(du), mask);
// Safe even for 8-bit lanes because indices never exceed 15.
@@ -5552,6 +5647,15 @@ constexpr int SufficientPow2ForMask() {
return HWY_MAX(
D().Pow2() - 3 - static_cast<int>(FloorLog2(sizeof(TFromD<D>))), -3);
}
+
+template <class M>
+static HWY_INLINE HWY_MAYBE_UNUSED M RvvVmmv(M mask) {
+ // The below And operation is equivalent to the RVV vmmv instruction and
+ // ensures that mask is not in the same register as a vector operand when used
+ // in RVV instructions that take both a vector operand and a mask operand.
+ return And(mask, mask);
+}
+
} // namespace detail
template <class D, HWY_IF_T_SIZE_D(D, 1), HWY_IF_LANES_LE_D(D, 8)>
@@ -5560,8 +5664,10 @@ HWY_API MFromD<D> Dup128MaskFromMaskBits(D d, unsigned mask_bits) {
if (kN < 8) mask_bits &= detail::MaxMaskBits<kN>();
#if HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400
- return detail::U8MaskBitsVecToMask(
- d, Set(ScalableTag<uint8_t>(), static_cast<uint8_t>(mask_bits)));
+ const ScalableTag<uint8_t, detail::SufficientPow2ForMask<D>()> du8;
+ return detail::RvvVmmv(detail::U8MaskBitsVecToMask(
+ d, detail::ChangeLMUL(ScalableTag<uint8_t>(),
+ Set(du8, static_cast<uint8_t>(mask_bits)))));
#else
const RebindToUnsigned<decltype(d)> du8;
const detail::AdjustSimdTagToMinVecPow2<Repartition<uint64_t, decltype(du8)>>
@@ -5581,10 +5687,10 @@ HWY_API MFromD<D> Dup128MaskFromMaskBits(D d, unsigned mask_bits) {
const ScalableTag<uint8_t, detail::SufficientPow2ForMask<D>()> du8;
const ScalableTag<uint16_t, detail::SufficientPow2ForMask<D>()> du16;
// There are exactly 16 mask bits for 128 vector bits of 8-bit lanes.
- return detail::U8MaskBitsVecToMask(
+ return detail::RvvVmmv(detail::U8MaskBitsVecToMask(
d, detail::ChangeLMUL(
ScalableTag<uint8_t>(),
- BitCast(du8, Set(du16, static_cast<uint16_t>(mask_bits)))));
+ BitCast(du8, Set(du16, static_cast<uint16_t>(mask_bits))))));
#else
// Slow fallback for completeness; the above bits to mask cast is preferred.
const RebindToUnsigned<decltype(d)> du8;
@@ -5613,9 +5719,9 @@ HWY_API MFromD<D> Dup128MaskFromMaskBits(D d, unsigned mask_bits) {
#if HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400
const ScalableTag<uint8_t, detail::SufficientPow2ForMask<D>()> du8;
// There are exactly 8 mask bits for 128 vector bits of 16-bit lanes.
- return detail::U8MaskBitsVecToMask(
+ return detail::RvvVmmv(detail::U8MaskBitsVecToMask(
d, detail::ChangeLMUL(ScalableTag<uint8_t>(),
- Set(du8, static_cast<uint8_t>(mask_bits))));
+ Set(du8, static_cast<uint8_t>(mask_bits)))));
#else
// Slow fallback for completeness; the above bits to mask cast is preferred.
const RebindToUnsigned<D> du;
@@ -5632,9 +5738,9 @@ HWY_API MFromD<D> Dup128MaskFromMaskBits(D d, unsigned mask_bits) {
#if HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400
const ScalableTag<uint8_t, detail::SufficientPow2ForMask<D>()> du8;
- return detail::U8MaskBitsVecToMask(
+ return detail::RvvVmmv(detail::U8MaskBitsVecToMask(
d, detail::ChangeLMUL(ScalableTag<uint8_t>(),
- Set(du8, static_cast<uint8_t>(mask_bits * 0x11))));
+ Set(du8, static_cast<uint8_t>(mask_bits * 0x11)))));
#else
// Slow fallback for completeness; the above bits to mask cast is preferred.
const RebindToUnsigned<D> du;
@@ -5650,9 +5756,9 @@ HWY_API MFromD<D> Dup128MaskFromMaskBits(D d, unsigned mask_bits) {
#if HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400
const ScalableTag<uint8_t, detail::SufficientPow2ForMask<D>()> du8;
- return detail::U8MaskBitsVecToMask(
+ return detail::RvvVmmv(detail::U8MaskBitsVecToMask(
d, detail::ChangeLMUL(ScalableTag<uint8_t>(),
- Set(du8, static_cast<uint8_t>(mask_bits * 0x55))));
+ Set(du8, static_cast<uint8_t>(mask_bits * 0x55)))));
#else
// Slow fallback for completeness; the above bits to mask cast is preferred.
const RebindToUnsigned<D> du;
@@ -5661,6 +5767,27 @@ HWY_API MFromD<D> Dup128MaskFromMaskBits(D d, unsigned mask_bits) {
#endif
}
+// ------------------------------ SetMask
+
+#ifdef HWY_NATIVE_SET_MASK
+#undef HWY_NATIVE_SET_MASK
+#else
+#define HWY_NATIVE_SET_MASK
+#endif
+
+template <class D>
+HWY_API MFromD<D> SetMask(D d, bool val) {
+ const uint8_t u8_mask_val = static_cast<uint8_t>(-static_cast<int>(val));
+#if HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400
+ const ScalableTag<uint8_t, detail::SufficientPow2ForMask<D>()> du8;
+ return detail::RvvVmmv(detail::U8MaskBitsVecToMask(
+ d, detail::ChangeLMUL(ScalableTag<uint8_t>(), Set(du8, u8_mask_val))));
+#else
+ const Rebind<uint8_t, DFromV<VFromD<decltype(d)>>> du8;
+ return MaskFromVec(Set(du8, u8_mask_val));
+#endif
+}
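A hedged usage sketch of the new op (the helper is illustrative; SetMask, IfThenElse, MFromD and VFromD are existing Highway names):

    // Select `a` in every lane when take_a is true, otherwise `b`:
    // SetMask(d, true) activates all lanes, SetMask(d, false) none.
    template <class D>
    VFromD<D> SelectAllOrNone(D d, bool take_a, VFromD<D> a, VFromD<D> b) {
      return IfThenElse(SetMask(d, take_a), a, b);
    }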
+
// ------------------------------ Abs (Max, Neg)
template <class V, HWY_IF_SIGNED_V(V)>
@@ -5934,9 +6061,13 @@ HWY_API V64 BitShuffle(V64 values, VI idx) {
template <class V, HWY_IF_T_SIZE_ONE_OF_V(V, (1 << 1) | (1 << 2) | (1 << 4)),
class D = DFromV<V>, class DW = RepartitionToWide<D>>
HWY_API VFromD<DW> MulEven(const V a, const V b) {
- const auto lo = Mul(a, b);
- const auto hi = MulHigh(a, b);
- return BitCast(DW(), OddEven(detail::Slide1Up(hi), lo));
+ constexpr int maskVal = sizeof(TFromD<D>) == 4 ? 5
+ : sizeof(TFromD<D>) == 2 ? 0x55
+ : 0x5555;
+ const auto mask = Dup128MaskFromMaskBits(D(), maskVal);
+ const auto hi = Slide1Up(D(), MulHigh(a, b));
+ const auto res = MaskedMulOr(hi, mask, a, b);
+ return BitCast(DW(), res);
}
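The mask literals select the even lanes of a 128-bit block (LSB = lane 0). A compile-time check of that derivation, as a hedged standalone sketch (assumes C++14 relaxed constexpr):

    #include <cstdint>

    // Bitmask with one bit per lane of a 128-bit block, even lanes set.
    constexpr uint32_t EvenLaneMaskBits(uint32_t lane_bytes) {
      uint32_t bits = 0;
      for (uint32_t lane = 0; lane < 16 / lane_bytes; lane += 2) {
        bits |= 1u << lane;
      }
      return bits;
    }
    static_assert(EvenLaneMaskBits(8) == 1, "u64: 2 lanes");
    static_assert(EvenLaneMaskBits(4) == 5, "u32: 4 lanes");
    static_assert(EvenLaneMaskBits(2) == 0x55, "u16: 8 lanes");
    static_assert(EvenLaneMaskBits(1) == 0x5555, "u8: 16 lanes");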
template <class V, HWY_IF_T_SIZE_ONE_OF_V(V, (1 << 1) | (1 << 2) | (1 << 4)),
@@ -5950,9 +6081,9 @@ HWY_API VFromD<DW> MulOdd(const V a, const V b) {
// There is no 64x64 vwmul.
template <class V, HWY_IF_T_SIZE_V(V, 8)>
HWY_INLINE V MulEven(const V a, const V b) {
- const auto lo = Mul(a, b);
- const auto hi = MulHigh(a, b);
- return OddEven(detail::Slide1Up(hi), lo);
+ const auto mask = Dup128MaskFromMaskBits(DFromV<V>(), 1);
+ const auto hi = Slide1Up(DFromV<V>(), MulHigh(a, b));
+ return MaskedMulOr(hi, mask, a, b);
}
template <class V, HWY_IF_T_SIZE_V(V, 8)>
@@ -6228,11 +6359,6 @@ HWY_API vuint32m8_t RearrangeToOddPlusEven(vuint32m8_t sum0, vuint32m8_t sum1) {
return Combine(d, hi, lo);
}
-template <class VW, HWY_IF_FLOAT_V(VW)> // vfloat*
-HWY_API VW RearrangeToOddPlusEven(const VW sum0, const VW sum1) {
- return Add(sum0, sum1); // invariant already holds
-}
-
// ------------------------------ Lt128
#if HWY_COMPILER_CLANG >= 1700 || HWY_COMPILER_GCC_ACTUAL >= 1400
diff --git a/third_party/highway/hwy/ops/scalar-inl.h b/third_party/highway/hwy/ops/scalar-inl.h
index ca59e80169..1ec2662f65 100644
--- a/third_party/highway/hwy/ops/scalar-inl.h
+++ b/third_party/highway/hwy/ops/scalar-inl.h
@@ -253,13 +253,6 @@ HWY_API Vec1<T> operator^(const Vec1<T> a, const Vec1<T> b) {
return Xor(a, b);
}
-// ------------------------------ Xor3
-
-template <typename T>
-HWY_API Vec1<T> Xor3(Vec1<T> x1, Vec1<T> x2, Vec1<T> x3) {
- return Xor(x1, Xor(x2, x3));
-}
-
// ------------------------------ Or3
template <typename T>
@@ -310,6 +303,17 @@ HWY_API Mask1<T> FirstN(D /*tag*/, size_t n) {
return Mask1<T>::FromBool(n != 0);
}
+#ifdef HWY_NATIVE_SET_MASK
+#undef HWY_NATIVE_SET_MASK
+#else
+#define HWY_NATIVE_SET_MASK
+#endif
+
+template <class D>
+HWY_API MFromD<D> SetMask(D /*d*/, bool val) {
+ return MFromD<D>::FromBool(val);
+}
+
// ------------------------------ IfVecThenElse
template <typename T>
HWY_API Vec1<T> IfVecThenElse(Vec1<T> mask, Vec1<T> yes, Vec1<T> no) {
@@ -1673,6 +1677,17 @@ HWY_API V InterleaveOddBlocks(D, V a, V /*b*/) {
return a;
}
+// ------------------------------ InterleaveLowerBlocks
+template <class D, class V = VFromD<D>>
+HWY_API V InterleaveLowerBlocks(D, V a, V /*b*/) {
+ return a;
+}
+// ------------------------------ InterleaveUpperBlocks
+template <class D, class V = VFromD<D>>
+HWY_API V InterleaveUpperBlocks(D, V a, V /*b*/) {
+ return a;
+}
+
// ------------------------------ TableLookupLanes
// Returned by SetTableIndices for use by TableLookupLanes.
diff --git a/third_party/highway/hwy/ops/set_macros-inl.h b/third_party/highway/hwy/ops/set_macros-inl.h
index 2cadeb8ad7..a090549bc5 100644
--- a/third_party/highway/hwy/ops/set_macros-inl.h
+++ b/third_party/highway/hwy/ops/set_macros-inl.h
@@ -1,5 +1,6 @@
// Copyright 2020 Google LLC
-// Copyright 2024 Arm Limited and/or its affiliates <open-source-office@arm.com>
+// Copyright 2024-2025 Arm Limited and/or its affiliates
+// <open-source-office@arm.com>
// SPDX-License-Identifier: Apache-2.0
// SPDX-License-Identifier: BSD-3-Clause
//
@@ -34,18 +35,39 @@
#undef HWY_NAMESPACE
#undef HWY_ALIGN
#undef HWY_MAX_BYTES
-#undef HWY_LANES
+#undef HWY_MIN_BYTES
#undef HWY_HAVE_SCALABLE
#undef HWY_HAVE_TUPLE
+#undef HWY_REGISTERS
#undef HWY_HAVE_INTEGER64
#undef HWY_HAVE_FLOAT16
#undef HWY_HAVE_FLOAT64
#undef HWY_MEM_OPS_MIGHT_FAULT
#undef HWY_NATIVE_FMA
#undef HWY_NATIVE_DOT_BF16
-#undef HWY_CAP_GE256
-#undef HWY_CAP_GE512
+#undef HWY_NATIVE_MASK
+#undef HWY_NATIVE_INTERLEAVE_WHOLE
+
+#ifndef HWY_CAP_GE256
+#define HWY_CAP_GE256 (HWY_MIN_BYTES >= 32)
+#endif
+#ifndef HWY_CAP_GE512
+#define HWY_CAP_GE512 (HWY_MIN_BYTES >= 64)
+#endif
+
+// Almost all targets (except RVV and SCALAR) use this definition.
+#undef HWY_LANES
+#define HWY_LANES(T) (HWY_MAX_BYTES / sizeof(T))
+
+// If 1, both __bf16 and a limited set of *_bf16 SVE intrinsics are available:
+// create/get/set/dup, ld/st, sel, rev, trn, uzp, zip.
+// Consulted below, hence define here rather than in arm_sve-inl.h.
+#if HWY_ARM_HAVE_SCALAR_BF16_TYPE && defined(__ARM_FEATURE_SVE_BF16)
+#define HWY_SVE_HAVE_BF16_FEATURE 1
+#else
+#define HWY_SVE_HAVE_BF16_FEATURE 0
+#endif
#undef HWY_TARGET_IS_SVE
#if HWY_TARGET & HWY_ALL_SVE
@@ -69,7 +91,7 @@
#endif
#undef HWY_TARGET_IS_AVX10_2
-#if HWY_TARGET == HWY_AVX10_2 || HWY_TARGET == HWY_AVX10_2_512
+#if HWY_TARGET == HWY_AVX10_2
#define HWY_TARGET_IS_AVX10_2 1
#else
#define HWY_TARGET_IS_AVX10_2 0
@@ -84,6 +106,13 @@
#define HWY_HAVE_TUPLE 1
#endif
+// Target-specific number of architectural vector registers available.
+#if !HWY_ARCH_X86 || (HWY_TARGET <= HWY_AVX3)
+#define HWY_REGISTERS 32
+#else
+#define HWY_REGISTERS 16
+#endif
+
// For internal use (clamping/validating N for Simd<>)
#undef HWY_MAX_N
#if HWY_TARGET == HWY_SCALAR
@@ -141,56 +170,67 @@
#define HWY_TARGET_STR_AVX2 \
HWY_TARGET_STR_SSE4 ",avx,avx2" HWY_TARGET_STR_BMI2_FMA HWY_TARGET_STR_F16C
-#if HWY_COMPILER_GCC_ACTUAL >= 1400 || HWY_COMPILER_CLANG >= 1800
+#ifndef HWY_HAVE_EVEX512 // allow override
+// evex512 has been removed from clang 22, see
+// https://github.com/llvm/llvm-project/pull/157034
+#if (1400 <= HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1600) || \
+ (1800 <= HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 2200)
+#define HWY_HAVE_EVEX512 1
+#else
+#define HWY_HAVE_EVEX512 0
+#endif
+#endif
+
+#if (HWY_HAVE_EVEX512 == 1)
#define HWY_TARGET_STR_AVX3_VL512 ",evex512"
#else
#define HWY_TARGET_STR_AVX3_VL512
#endif
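Since HWY_HAVE_EVEX512 is now overridable, a build that knows better than the version check can pre-define it; a hedged example translation unit:

    // Force the evex512 handling off, e.g. for a toolchain known to mis-handle
    // it, before any Highway header is included.
    #define HWY_HAVE_EVEX512 0
    #include "third_party/highway/hwy/highway.h"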
-#define HWY_TARGET_STR_AVX3_256 \
- HWY_TARGET_STR_AVX2 \
- ",avx512f,avx512cd,avx512vl,avx512dq,avx512bw" HWY_TARGET_STR_AVX3_VL512
-
-#define HWY_TARGET_STR_AVX3 HWY_TARGET_STR_AVX3_256 HWY_TARGET_STR_AVX3_VL512
+#define HWY_TARGET_STR_AVX3 \
+ HWY_TARGET_STR_AVX2 \
+ ",avx512f,avx512cd,avx512vl,avx512dq,avx512bw" HWY_TARGET_STR_AVX3_VL512
-#define HWY_TARGET_STR_AVX3_DL_256 \
- HWY_TARGET_STR_AVX3_256 \
+#define HWY_TARGET_STR_AVX3_DL \
+ HWY_TARGET_STR_AVX3 \
",vpclmulqdq,avx512vbmi,avx512vbmi2,vaes,avx512vnni,avx512bitalg," \
"avx512vpopcntdq,gfni"
-#define HWY_TARGET_STR_AVX3_DL \
- HWY_TARGET_STR_AVX3_DL_256 HWY_TARGET_STR_AVX3_VL512
-
-// Force-disable for compilers that do not properly support avx512bf16.
-#if !defined(HWY_AVX3_DISABLE_AVX512BF16) && \
+// Opt-out for compilers that do not properly support avx512bf16.
+#ifndef HWY_AVX3_ENABLE_AVX512BF16 // allow override
+// Default is to disable if the DISABLE macro is defined, or if the compiler
+// is too old. clang-cl 21.1.4 reportedly works; feel free to define this to 1
+// there.
+#if defined(HWY_AVX3_DISABLE_AVX512BF16) || \
(HWY_COMPILER_CLANGCL || \
(HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1000) || \
(HWY_COMPILER_CLANG && HWY_COMPILER_CLANG < 900))
-#define HWY_AVX3_DISABLE_AVX512BF16
+#define HWY_AVX3_ENABLE_AVX512BF16 0
+#else
+#define HWY_AVX3_ENABLE_AVX512BF16 1
#endif
+#endif // HWY_AVX3_ENABLE_AVX512BF16
-#if !defined(HWY_AVX3_DISABLE_AVX512BF16)
-#define HWY_TARGET_STR_AVX3_ZEN4_256 HWY_TARGET_STR_AVX3_DL ",avx512bf16"
+#if HWY_AVX3_ENABLE_AVX512BF16
+#define HWY_TARGET_STR_AVX3_ZEN4 HWY_TARGET_STR_AVX3_DL ",avx512bf16"
#else
-#define HWY_TARGET_STR_AVX3_ZEN4_256 HWY_TARGET_STR_AVX3_DL
+#define HWY_TARGET_STR_AVX3_ZEN4 HWY_TARGET_STR_AVX3_DL
#endif
-#define HWY_TARGET_STR_AVX3_ZEN4 \
- HWY_TARGET_STR_AVX3_ZEN4_256 HWY_TARGET_STR_AVX3_VL512
-
-#define HWY_TARGET_STR_AVX3_SPR_256 HWY_TARGET_STR_AVX3_ZEN4 ",avx512fp16"
-
-#define HWY_TARGET_STR_AVX3_SPR \
- HWY_TARGET_STR_AVX3_SPR_256 HWY_TARGET_STR_AVX3_VL512
+#if HWY_COMPILER_GCC_ACTUAL >= 1200 || HWY_COMPILER_CLANG >= 1400
+#define HWY_TARGET_STR_AVX3_SPR HWY_TARGET_STR_AVX3_ZEN4 ",avx512fp16"
+#else
+#define HWY_TARGET_STR_AVX3_SPR HWY_TARGET_STR_AVX3_ZEN4
+#endif
-#if HWY_COMPILER_GCC_ACTUAL >= 1500 || HWY_COMPILER_CLANG >= 2000
-#define HWY_TARGET_STR_AVX10_2 \
- HWY_TARGET_STR_AVX3_SPR_256 ",no-evex512,avx10.2-256"
-#define HWY_TARGET_STR_AVX10_2_512 \
- HWY_TARGET_STR_AVX3_SPR ",avx10.2-256,avx10.2-512"
+// Support for avx10.2-512 was removed between clang 22 and 23 without a
+// feature test macro.
+#if HWY_COMPILER_CLANG >= 2200 && HWY_HAVE_EVEX512
+#define HWY_TARGET_STR_AVX10_2 HWY_TARGET_STR_AVX3_SPR ",avx10.2-512"
+// Recent compilers drop the -512 suffix because 512 bits are always available.
+#elif HWY_COMPILER_GCC_ACTUAL >= 1500 || HWY_COMPILER_CLANG >= 2200
+#define HWY_TARGET_STR_AVX10_2 HWY_TARGET_STR_AVX3_SPR ",avx10.2"
#else
-#define HWY_TARGET_STR_AVX10_2 HWY_TARGET_STR_AVX3_SPR_256 ",no-evex512"
-#define HWY_TARGET_STR_AVX10_2_512 HWY_TARGET_STR_AVX3_SPR
+#define HWY_TARGET_STR_AVX10_2 HWY_TARGET_STR_AVX3_SPR
#endif
#if defined(HWY_DISABLE_PPC8_CRYPTO)
@@ -231,7 +271,7 @@
#define HWY_NAMESPACE N_SSE2
#define HWY_ALIGN alignas(16)
#define HWY_MAX_BYTES 16
-#define HWY_LANES(T) (16 / sizeof(T))
+#define HWY_MIN_BYTES 16
#define HWY_HAVE_SCALABLE 0
#define HWY_HAVE_INTEGER64 1
@@ -240,8 +280,7 @@
#define HWY_MEM_OPS_MIGHT_FAULT 1
#define HWY_NATIVE_FMA 0
#define HWY_NATIVE_DOT_BF16 0
-#define HWY_CAP_GE256 0
-#define HWY_CAP_GE512 0
+#define HWY_NATIVE_MASK 0 // a few actually are
#define HWY_TARGET_STR HWY_TARGET_STR_SSE2
//-----------------------------------------------------------------------------
@@ -251,7 +290,7 @@
#define HWY_NAMESPACE N_SSSE3
#define HWY_ALIGN alignas(16)
#define HWY_MAX_BYTES 16
-#define HWY_LANES(T) (16 / sizeof(T))
+#define HWY_MIN_BYTES 16
#define HWY_HAVE_SCALABLE 0
#define HWY_HAVE_INTEGER64 1
@@ -260,8 +299,7 @@
#define HWY_MEM_OPS_MIGHT_FAULT 1
#define HWY_NATIVE_FMA 0
#define HWY_NATIVE_DOT_BF16 0
-#define HWY_CAP_GE256 0
-#define HWY_CAP_GE512 0
+#define HWY_NATIVE_MASK 0 // a few actually are
#define HWY_TARGET_STR HWY_TARGET_STR_SSSE3
@@ -272,7 +310,7 @@
#define HWY_NAMESPACE N_SSE4
#define HWY_ALIGN alignas(16)
#define HWY_MAX_BYTES 16
-#define HWY_LANES(T) (16 / sizeof(T))
+#define HWY_MIN_BYTES 16
#define HWY_HAVE_SCALABLE 0
#define HWY_HAVE_INTEGER64 1
@@ -281,8 +319,7 @@
#define HWY_MEM_OPS_MIGHT_FAULT 1
#define HWY_NATIVE_FMA 0
#define HWY_NATIVE_DOT_BF16 0
-#define HWY_CAP_GE256 0
-#define HWY_CAP_GE512 0
+#define HWY_NATIVE_MASK 0 // a few actually are
#define HWY_TARGET_STR HWY_TARGET_STR_SSE4
@@ -293,7 +330,7 @@
#define HWY_NAMESPACE N_AVX2
#define HWY_ALIGN alignas(32)
#define HWY_MAX_BYTES 32
-#define HWY_LANES(T) (32 / sizeof(T))
+#define HWY_MIN_BYTES 32
#define HWY_HAVE_SCALABLE 0
#define HWY_HAVE_INTEGER64 1
@@ -307,32 +344,22 @@
#define HWY_NATIVE_FMA 1
#endif
#define HWY_NATIVE_DOT_BF16 0
-
-#define HWY_CAP_GE256 1
-#define HWY_CAP_GE512 0
+#define HWY_NATIVE_MASK 0 // a few actually are
#define HWY_TARGET_STR HWY_TARGET_STR_AVX2
//-----------------------------------------------------------------------------
-// AVX3[_DL]/AVX10
-#elif HWY_TARGET == HWY_AVX3 || HWY_TARGET == HWY_AVX3_DL || \
- HWY_TARGET == HWY_AVX3_ZEN4 || HWY_TARGET == HWY_AVX3_SPR || \
- HWY_TARGET == HWY_AVX10_2 || HWY_TARGET == HWY_AVX10_2_512
+// AVX3[_DL/ZEN4/SPR]/AVX10
+#elif HWY_TARGET <= HWY_AVX3
-#if HWY_TARGET == HWY_AVX10_2
-#define HWY_ALIGN alignas(32)
-#define HWY_MAX_BYTES 32
-#define HWY_LANES(T) (32 / sizeof(T))
-#else
#define HWY_ALIGN alignas(64)
#define HWY_MAX_BYTES 64
-#define HWY_LANES(T) (64 / sizeof(T))
-#endif
+#define HWY_MIN_BYTES 64
#define HWY_HAVE_SCALABLE 0
#define HWY_HAVE_INTEGER64 1
-#if HWY_TARGET <= HWY_AVX10_2 && \
- (HWY_COMPILER_GCC_ACTUAL || HWY_COMPILER_CLANG >= 1901) && \
+#if HWY_TARGET <= HWY_AVX3_SPR && \
+ (HWY_COMPILER_GCC_ACTUAL || HWY_COMPILER_CLANG >= 2200) && \
HWY_HAVE_SCALAR_F16_TYPE
#define HWY_HAVE_FLOAT16 1
#else
@@ -341,18 +368,12 @@
#define HWY_HAVE_FLOAT64 1
#define HWY_MEM_OPS_MIGHT_FAULT 0
#define HWY_NATIVE_FMA 1
-#if (HWY_TARGET <= HWY_AVX3_ZEN4) && !defined(HWY_AVX3_DISABLE_AVX512BF16)
+#if (HWY_TARGET <= HWY_AVX3_ZEN4) && HWY_AVX3_ENABLE_AVX512BF16
#define HWY_NATIVE_DOT_BF16 1
#else
#define HWY_NATIVE_DOT_BF16 0
#endif
-#define HWY_CAP_GE256 1
-
-#if HWY_MAX_BYTES >= 64
-#define HWY_CAP_GE512 1
-#else
-#define HWY_CAP_GE512 0
-#endif
+#define HWY_NATIVE_MASK 1
#if HWY_TARGET == HWY_AVX3
@@ -379,11 +400,6 @@
#define HWY_NAMESPACE N_AVX10_2
#define HWY_TARGET_STR HWY_TARGET_STR_AVX10_2
-#elif HWY_TARGET == HWY_AVX10_2_512
-
-#define HWY_NAMESPACE N_AVX10_2_512
-#define HWY_TARGET_STR HWY_TARGET_STR_AVX10_2_512
-
#else
#error "Logic error"
#endif // HWY_TARGET
@@ -394,7 +410,7 @@
#define HWY_ALIGN alignas(16)
#define HWY_MAX_BYTES 16
-#define HWY_LANES(T) (16 / sizeof(T))
+#define HWY_MIN_BYTES 16
#define HWY_HAVE_SCALABLE 0
#define HWY_HAVE_INTEGER64 1
@@ -403,8 +419,7 @@
#define HWY_MEM_OPS_MIGHT_FAULT 1
#define HWY_NATIVE_FMA 1
#define HWY_NATIVE_DOT_BF16 0
-#define HWY_CAP_GE256 0
-#define HWY_CAP_GE512 0
+#define HWY_NATIVE_MASK 0
#if HWY_TARGET == HWY_PPC8
@@ -431,7 +446,7 @@
#define HWY_ALIGN alignas(16)
#define HWY_MAX_BYTES 16
-#define HWY_LANES(T) (16 / sizeof(T))
+#define HWY_MIN_BYTES 16
#define HWY_HAVE_SCALABLE 0
#define HWY_HAVE_INTEGER64 1
@@ -440,8 +455,7 @@
#define HWY_MEM_OPS_MIGHT_FAULT 1
#define HWY_NATIVE_FMA 1
#define HWY_NATIVE_DOT_BF16 0
-#define HWY_CAP_GE256 0
-#define HWY_CAP_GE512 0
+#define HWY_NATIVE_MASK 0
#if HWY_TARGET == HWY_Z14
@@ -461,9 +475,32 @@
// NEON
#elif HWY_TARGET_IS_NEON
+// Clang 17 crashes with bf16, see github.com/llvm/llvm-project/issues/64179.
+#undef HWY_NEON_HAVE_BFLOAT16
+#if HWY_HAVE_SCALAR_BF16_TYPE && \
+ ((HWY_TARGET == HWY_NEON_BF16 && \
+ (!HWY_COMPILER_CLANG || HWY_COMPILER_CLANG >= 1800)) || \
+ defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC))
+#define HWY_NEON_HAVE_BFLOAT16 1
+#else
+#define HWY_NEON_HAVE_BFLOAT16 0
+#endif
+
+// HWY_NEON_HAVE_F32_TO_BF16C is defined if NEON vcvt_bf16_f32 and
+// vbfdot_f32 are available, even if the __bf16 type is disabled due to
+// GCC/Clang bugs.
+#undef HWY_NEON_HAVE_F32_TO_BF16C
+#if HWY_NEON_HAVE_BFLOAT16 || HWY_TARGET == HWY_NEON_BF16 || \
+ (defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) && \
+ (HWY_COMPILER_GCC_ACTUAL >= 1000 || HWY_COMPILER_CLANG >= 1100))
+#define HWY_NEON_HAVE_F32_TO_BF16C 1
+#else
+#define HWY_NEON_HAVE_F32_TO_BF16C 0
+#endif
+
#define HWY_ALIGN alignas(16)
#define HWY_MAX_BYTES 16
-#define HWY_LANES(T) (16 / sizeof(T))
+#define HWY_MIN_BYTES 16
#define HWY_HAVE_SCALABLE 0
#define HWY_HAVE_INTEGER64 1
@@ -486,14 +523,14 @@
#else
#define HWY_NATIVE_FMA 0
#endif
-#if HWY_NEON_HAVE_F32_TO_BF16C || HWY_TARGET == HWY_NEON_BF16
+
+#if HWY_NEON_HAVE_F32_TO_BF16C
#define HWY_NATIVE_DOT_BF16 1
#else
#define HWY_NATIVE_DOT_BF16 0
#endif
-#define HWY_CAP_GE256 0
-#define HWY_CAP_GE512 0
+#define HWY_NATIVE_MASK 0
#if HWY_TARGET == HWY_NEON_WITHOUT_AES
#define HWY_NAMESPACE N_NEON_WITHOUT_AES
@@ -537,12 +574,27 @@
#define HWY_TARGET_STR_FP16 "+fp16"
#endif
+#if HWY_OS_APPLE
+// Enable i8mm for the NEON_BF16 target when compiling for macOS, iOS, or
+// iPadOS, as all Apple Silicon CPUs that support BF16 also support I8MM.
+#define HWY_TARGET_STR_NEON_BF16_EXTRA "+i8mm"
+#else
+#define HWY_TARGET_STR_NEON_BF16_EXTRA ""
+#endif
+
#if HWY_TARGET == HWY_NEON_WITHOUT_AES
+#if HWY_COMPILER_GCC_ACTUAL && HWY_COMPILER_GCC_ACTUAL < 1400
+// Prevents inadvertent use of SVE by GCC 13.4 and earlier, see #2689.
+#define HWY_TARGET_STR "+nosve"
+#else
// Do not define HWY_TARGET_STR (no pragma).
+#endif // HWY_COMPILER_GCC_ACTUAL
#elif HWY_TARGET == HWY_NEON
#define HWY_TARGET_STR HWY_TARGET_STR_NEON
#elif HWY_TARGET == HWY_NEON_BF16
-#define HWY_TARGET_STR HWY_TARGET_STR_FP16 "+bf16+dotprod" HWY_TARGET_STR_NEON
+#define HWY_TARGET_STR \
+ HWY_TARGET_STR_FP16 \
+ "+bf16+dotprod" HWY_TARGET_STR_NEON_BF16_EXTRA HWY_TARGET_STR_NEON
#else
#error "Logic error, missing case"
#endif // HWY_TARGET
@@ -559,51 +611,65 @@
// SVE only requires lane alignment, not natural alignment of the entire vector.
#define HWY_ALIGN alignas(8)
-// Value ensures MaxLanes() is the tightest possible upper bound to reduce
-// overallocation.
-#define HWY_LANES(T) ((HWY_MAX_BYTES) / sizeof(T))
-
#define HWY_HAVE_INTEGER64 1
#define HWY_HAVE_FLOAT16 1
#define HWY_HAVE_FLOAT64 1
#define HWY_MEM_OPS_MIGHT_FAULT 0
#define HWY_NATIVE_FMA 1
-#if HWY_SVE_HAVE_BF16_FEATURE
+#if HWY_SVE_HAVE_BF16_FEATURE || HWY_TARGET == HWY_SVE2_128
#define HWY_NATIVE_DOT_BF16 1
#else
#define HWY_NATIVE_DOT_BF16 0
#endif
-#define HWY_CAP_GE256 0
-#define HWY_CAP_GE512 0
+#define HWY_NATIVE_MASK 1
#if HWY_TARGET == HWY_SVE2
#define HWY_NAMESPACE N_SVE2
#define HWY_MAX_BYTES 256
+#define HWY_MIN_BYTES 16
#define HWY_HAVE_SCALABLE 1
#elif HWY_TARGET == HWY_SVE_256
#define HWY_NAMESPACE N_SVE_256
#define HWY_MAX_BYTES 32
+#define HWY_MIN_BYTES 32
#define HWY_HAVE_SCALABLE 0
#elif HWY_TARGET == HWY_SVE2_128
#define HWY_NAMESPACE N_SVE2_128
#define HWY_MAX_BYTES 16
+#define HWY_MIN_BYTES 16
#define HWY_HAVE_SCALABLE 0
#else
#define HWY_NAMESPACE N_SVE
#define HWY_MAX_BYTES 256
+#define HWY_MIN_BYTES 16
#define HWY_HAVE_SCALABLE 1
#endif
-// Can use pragmas instead of -march compiler flag
-#if HWY_HAVE_RUNTIME_DISPATCH
-#if HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE2_128
+// Note: -march strings are delimited by '+'. GCC additionally requires a '+'
+// before each pragma target; the targets themselves are comma-separated.
+
+#undef HWY_TARGET_STR_SVE2_AES
// Static dispatch with -march=armv8-a+sve2+aes, or no baseline, hence dynamic
// dispatch, which checks for AES support at runtime.
#if defined(__ARM_FEATURE_SVE2_AES) || (HWY_BASELINE_SVE2 == 0)
-#define HWY_TARGET_STR "+sve2+sve2-aes,+sve"
+#define HWY_TARGET_STR_SVE2_AES ",+sve2-aes"
#else // SVE2 without AES
-#define HWY_TARGET_STR "+sve2,+sve"
+#define HWY_TARGET_STR_SVE2_AES ""
+#endif
+
+#undef HWY_TARGET_STR_SVE2_128
+// SVE2_128 implies/requires I8MM and BF16, see #2973.
+#if HWY_TARGET == HWY_SVE2_128
+#define HWY_TARGET_STR_SVE2_128 ",+i8mm,+bf16"
+#else
+#define HWY_TARGET_STR_SVE2_128 ""
#endif
+
+// Can use pragmas instead of -march compiler flag
+#if HWY_HAVE_RUNTIME_DISPATCH
+#if HWY_TARGET == HWY_SVE2 || HWY_TARGET == HWY_SVE2_128
+#define HWY_TARGET_STR \
+ "+sve,+sve2" HWY_TARGET_STR_SVE2_AES HWY_TARGET_STR_SVE2_128
#else // not SVE2 target
#define HWY_TARGET_STR "+sve"
#endif
@@ -617,7 +683,7 @@
#define HWY_ALIGN alignas(16)
#define HWY_MAX_BYTES 16
-#define HWY_LANES(T) (16 / sizeof(T))
+#define HWY_MIN_BYTES 16
#define HWY_HAVE_SCALABLE 0
#define HWY_HAVE_INTEGER64 1
@@ -626,8 +692,7 @@
#define HWY_MEM_OPS_MIGHT_FAULT 1
#define HWY_NATIVE_FMA 0
#define HWY_NATIVE_DOT_BF16 0
-#define HWY_CAP_GE256 0
-#define HWY_CAP_GE512 0
+#define HWY_NATIVE_MASK 0
#define HWY_NAMESPACE N_WASM
@@ -639,7 +704,7 @@
#define HWY_ALIGN alignas(32)
#define HWY_MAX_BYTES 32
-#define HWY_LANES(T) (32 / sizeof(T))
+#define HWY_MIN_BYTES 32
#define HWY_HAVE_SCALABLE 0
#define HWY_HAVE_INTEGER64 1
@@ -648,8 +713,7 @@
#define HWY_MEM_OPS_MIGHT_FAULT 1
#define HWY_NATIVE_FMA 0
#define HWY_NATIVE_DOT_BF16 0
-#define HWY_CAP_GE256 1
-#define HWY_CAP_GE512 0
+#define HWY_NATIVE_MASK 0
#define HWY_NAMESPACE N_WASM_EMU256
@@ -665,9 +729,11 @@
// The spec requires VLEN <= 2^16 bits, so the limit is 2^16 bytes (LMUL=8).
#define HWY_MAX_BYTES 65536
+#define HWY_MIN_BYTES 16
// = HWY_MAX_BYTES divided by max LMUL=8 because MaxLanes includes the actual
// LMUL. This is the tightest possible upper bound.
+#undef HWY_LANES
#define HWY_LANES(T) (8192 / sizeof(T))
#define HWY_HAVE_SCALABLE 1
@@ -676,8 +742,7 @@
#define HWY_MEM_OPS_MIGHT_FAULT 0
#define HWY_NATIVE_FMA 1
#define HWY_NATIVE_DOT_BF16 0
-#define HWY_CAP_GE256 0
-#define HWY_CAP_GE512 0
+#define HWY_NATIVE_MASK 1
#if HWY_RVV_HAVE_F16_VEC
#define HWY_HAVE_FLOAT16 1
@@ -689,7 +754,7 @@
#if HWY_COMPILER_CLANG >= 1900
// https://github.com/riscv/riscv-v-spec/blob/master/v-spec.adoc#181-zvl-minimum-vector-length-standard-extensions
-#define HWY_TARGET_STR "Zvl128b,Zve64d"
+#define HWY_TARGET_STR "arch=+v"
#else
// HWY_TARGET_STR remains undefined so HWY_ATTR is a no-op.
#endif
@@ -701,23 +766,27 @@
#if HWY_TARGET == HWY_LSX
#define HWY_ALIGN alignas(16)
#define HWY_MAX_BYTES 16
+#define HWY_MIN_BYTES 16
+#ifndef __loongarch_sx
+#define HWY_TARGET_STR "lsx"
+#endif
#else
#define HWY_ALIGN alignas(32)
#define HWY_MAX_BYTES 32
+#define HWY_MIN_BYTES 32
+#ifndef __loongarch_asx
+#define HWY_TARGET_STR "lsx,lasx"
+#endif
#endif
-#define HWY_LANES(T) (HWY_MAX_BYTES / sizeof(T))
-
-// TODO: check flag values
#define HWY_HAVE_SCALABLE 0
#define HWY_HAVE_INTEGER64 1
-#define HWY_HAVE_FLOAT16 1
+#define HWY_HAVE_FLOAT16 0
#define HWY_HAVE_FLOAT64 1
-#define HWY_MEM_OPS_MIGHT_FAULT 0
+#define HWY_MEM_OPS_MIGHT_FAULT 1
#define HWY_NATIVE_FMA 1
#define HWY_NATIVE_DOT_BF16 0
-#define HWY_CAP_GE256 0
-#define HWY_CAP_GE512 0
+#define HWY_NATIVE_MASK 0
#if HWY_TARGET == HWY_LSX
#define HWY_NAMESPACE N_LSX
@@ -733,7 +802,7 @@
#define HWY_ALIGN alignas(16)
#define HWY_MAX_BYTES 16
-#define HWY_LANES(T) (16 / sizeof(T))
+#define HWY_MIN_BYTES 16
#define HWY_HAVE_SCALABLE 0
#define HWY_HAVE_INTEGER64 1
@@ -742,8 +811,7 @@
#define HWY_MEM_OPS_MIGHT_FAULT 1
#define HWY_NATIVE_FMA 0
#define HWY_NATIVE_DOT_BF16 0
-#define HWY_CAP_GE256 0
-#define HWY_CAP_GE512 0
+#define HWY_NATIVE_MASK 0
#define HWY_NAMESPACE N_EMU128
@@ -755,6 +823,8 @@
#define HWY_ALIGN
#define HWY_MAX_BYTES 8
+#define HWY_MIN_BYTES 8
+#undef HWY_LANES
#define HWY_LANES(T) 1
#define HWY_HAVE_SCALABLE 0
@@ -764,8 +834,7 @@
#define HWY_MEM_OPS_MIGHT_FAULT 0
#define HWY_NATIVE_FMA 0
#define HWY_NATIVE_DOT_BF16 0
-#define HWY_CAP_GE256 0
-#define HWY_CAP_GE512 0
+#define HWY_NATIVE_MASK 0
#define HWY_NAMESPACE N_SCALAR
@@ -791,7 +860,7 @@
// Clang <9 requires this be invoked at file scope, before any namespace.
#undef HWY_BEFORE_NAMESPACE
-#if defined(HWY_TARGET_STR)
+#if defined(HWY_TARGET_STR) && !defined(HWY_DISABLE_ATTR)
#define HWY_BEFORE_NAMESPACE() \
HWY_PUSH_ATTRIBUTES(HWY_TARGET_STR) \
static_assert(true, "For requiring trailing semicolon")
@@ -803,7 +872,7 @@
// Clang <9 requires any namespaces be closed before this macro.
#undef HWY_AFTER_NAMESPACE
-#if defined(HWY_TARGET_STR)
+#if defined(HWY_TARGET_STR) && !defined(HWY_DISABLE_ATTR)
#define HWY_AFTER_NAMESPACE() \
HWY_POP_ATTRIBUTES \
static_assert(true, "For requiring trailing semicolon")
@@ -814,8 +883,16 @@
#endif
#undef HWY_ATTR
-#if defined(HWY_TARGET_STR) && HWY_HAS_ATTRIBUTE(target)
+#if defined(HWY_TARGET_STR) && HWY_HAS_ATTRIBUTE(target) && \
+ !defined(HWY_DISABLE_ATTR)
#define HWY_ATTR __attribute__((target(HWY_TARGET_STR)))
#else
#define HWY_ATTR
#endif
+
+#if (HWY_MAX_BYTES <= 16) || HWY_TARGET_IS_SVE || (HWY_TARGET == HWY_RVV) || \
+ (HWY_TARGET == HWY_WASM_EMU256)
+#define HWY_NATIVE_INTERLEAVE_WHOLE 1
+#else
+#define HWY_NATIVE_INTERLEAVE_WHOLE 0
+#endif
diff --git a/third_party/highway/hwy/ops/shared-inl.h b/third_party/highway/hwy/ops/shared-inl.h
index 95e3399954..ec97a78564 100644
--- a/third_party/highway/hwy/ops/shared-inl.h
+++ b/third_party/highway/hwy/ops/shared-inl.h
@@ -162,9 +162,10 @@ HWY_INLINE void MaybePoison(T* HWY_RESTRICT unaligned, size_t count) {
#endif
}
+// This can be useful for working around MSAN limitations. For example, prior
+// to Clang 16, MSAN did not understand AVX-512 CompressStore.
template <typename T>
HWY_INLINE void MaybeUnpoison(T* HWY_RESTRICT unaligned, size_t count) {
- // Workaround for MSAN not marking compressstore as initialized (b/233326619)
#if HWY_IS_MSAN
__msan_unpoison(unaligned, count * sizeof(T));
#else
@@ -559,6 +560,29 @@ HWY_API bool IsAligned(D d, T* ptr) {
return reinterpret_cast<uintptr_t>(ptr) % (N * sizeof(T)) == 0;
}
+// Returns whether `Lookup8` can definitely be used for vectors created from
+// tag `d`. May return a false negative for large scalable vectors.
+template <class D, typename T = TFromD<D>>
+HWY_API constexpr bool CanLookup8(D d) {
+ // `Lookup8` can use two-register tables, so it is sufficient to ensure
+ // vectors have at least four lanes (8/2). For fixed-length vectors: check
+ // `MaxLanes` directly. For scalable vectors, first require full
+ // (non-partial) vectors, which implies they are at least 128 bits. Then also
+ // require 16- or 32-bit elements, which implies at least 128/{16,32} =
+ // {8,4} lanes per vector. For 8-bit T, `TableLookupBytes` is more efficient.
+ return (!HWY_HAVE_SCALABLE && MaxLanes(d) >= 4) ||
+ (HWY_HAVE_SCALABLE && detail::IsFull(d) &&
+ (sizeof(T) == 2 || sizeof(T) == 4));
+}
+
+// Returns whether `Lookup16` can definitely be used for vectors created from
+// tag `d`. May return a false negative for large scalable vectors.
+template <class D, typename T = TFromD<D>>
+HWY_API constexpr bool CanLookup16(D d) {
+ return (!HWY_HAVE_SCALABLE && MaxLanes(d) >= 8) ||
+ (HWY_HAVE_SCALABLE && detail::IsFull(d) && (sizeof(T) == 2));
+}
+
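A usage sketch for these predicates (hypothetical caller; the Lookup8 signature shown is an assumption, not taken from this patch). Because the result is constexpr, callers can select an implementation at compile time via the HWY_IF_CONSTEXPR pattern used elsewhere in this patch:

    // Hypothetical caller; Lookup8's signature is assumed for illustration.
    template <class D, class V = VFromD<D>, class VI>
    V SelectLanes(D d, V table0, V table1, VI idx) {
      HWY_IF_CONSTEXPR(CanLookup8(D())) {
        return Lookup8(d, table0, table1, idx);  // two-register table path
      }
      HWY_IF_CONSTEXPR(!CanLookup8(D())) {
        // Fallback: single-register lookup via TableLookupLanes.
        return TableLookupLanes(table0, IndicesFromVec(d, idx));
      }
    }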
// ------------------------------ Choosing overloads (SFINAE)
// Same as base.h macros but with a Simd<T, N, kPow2> argument instead of T.
@@ -703,7 +727,7 @@ HWY_API bool IsAligned(D d, T* ptr) {
// HWY_IF_U2I_DEMOTE_FROM_LANE_SIZE_V is used to disable the default
// implementation of unsigned to signed DemoteTo/ReorderDemote2To in
// generic_ops-inl.h for at least some of the unsigned to signed demotions on
-// SCALAR/EMU128/SSE2/SSSE3/SSE4/AVX2/SVE/SVE2
+// SCALAR/EMU128/SSE2/SSSE3/SSE4/AVX2/SVE/SVE2/LSX/LASX
#undef HWY_IF_U2I_DEMOTE_FROM_LANE_SIZE_V
#define HWY_IF_U2I_DEMOTE_FROM_LANE_SIZE_V(V) void* = nullptr
diff --git a/third_party/highway/hwy/ops/wasm_128-inl.h b/third_party/highway/hwy/ops/wasm_128-inl.h
index 207b57c994..01bc2b014a 100644
--- a/third_party/highway/hwy/ops/wasm_128-inl.h
+++ b/third_party/highway/hwy/ops/wasm_128-inl.h
@@ -903,6 +903,24 @@ HWY_API Vec128<double, N> Max(Vec128<double, N> a, Vec128<double, N> b) {
return Vec128<double, N>{wasm_f64x2_pmax(b.raw, a.raw)};
}
+// ------------------------------ MinNumber and MaxNumber
+
+#ifdef HWY_NATIVE_FLOAT_MIN_MAX_NUMBER
+#undef HWY_NATIVE_FLOAT_MIN_MAX_NUMBER
+#else
+#define HWY_NATIVE_FLOAT_MIN_MAX_NUMBER
+#endif
+
+template <class V, HWY_IF_FLOAT_OR_SPECIAL_V(V)>
+HWY_API V MinNumber(V a, V b) {
+ return Min(a, IfThenElse(IsNaN(b), a, b));
+}
+
+template <class V, HWY_IF_FLOAT_OR_SPECIAL_V(V)>
+HWY_API V MaxNumber(V a, V b) {
+ return Max(a, IfThenElse(IsNaN(b), a, b));
+}
+
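For intuition, a scalar model of this identity (a sketch; it assumes Min returns its second operand when the first is NaN, the pmin/minps convention these backends build on):

    #include <algorithm>
    #include <cmath>
    float MinNumberScalar(float a, float b) {
      const float b2 = std::isnan(b) ? a : b;       // IfThenElse(IsNaN(b), a, b)
      return std::isnan(a) ? b2 : std::min(a, b2);  // scalar stand-in for Min
    }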
// ------------------------------ Integer multiplication
// Unsigned
@@ -1555,13 +1573,6 @@ HWY_API Vec128<T, N> Xor(Vec128<T, N> a, Vec128<T, N> b) {
return Vec128<T, N>{wasm_v128_xor(a.raw, b.raw)};
}
-// ------------------------------ Xor3
-
-template <typename T, size_t N>
-HWY_API Vec128<T, N> Xor3(Vec128<T, N> x1, Vec128<T, N> x2, Vec128<T, N> x3) {
- return Xor(x1, Xor(x2, x3));
-}
-
// ------------------------------ Or3
template <typename T, size_t N>
@@ -3925,6 +3936,17 @@ HWY_API V InterleaveOddBlocks(D, V a, V /*b*/) {
return a;
}
+// ------------------------------ InterleaveLowerBlocks
+template <class D, class V = VFromD<D>, HWY_IF_V_SIZE_LE_D(D, 16)>
+HWY_API V InterleaveLowerBlocks(D, V a, V /*b*/) {
+ return a;
+}
+// ------------------------------ InterleaveUpperBlocks
+template <class D, class V = VFromD<D>, HWY_IF_V_SIZE_LE_D(D, 16)>
+HWY_API V InterleaveUpperBlocks(D, V a, V /*b*/) {
+ return a;
+}
+
// ------------------------------ ReverseBlocks
template <class D>
HWY_API VFromD<D> ReverseBlocks(D /* tag */, VFromD<D> v) {
@@ -5870,24 +5892,11 @@ HWY_API VFromD<D32> ReorderWidenMulAccumulate(D32 d32, V16 a, V16 b,
}
// ------------------------------ RearrangeToOddPlusEven
-template <size_t N>
-HWY_API Vec128<int32_t, N> RearrangeToOddPlusEven(
- const Vec128<int32_t, N> sum0, const Vec128<int32_t, N> /*sum1*/) {
+template <class VW, HWY_IF_NOT_FLOAT_V(VW)>
+HWY_API VW RearrangeToOddPlusEven(const VW sum0, const VW) {
return sum0; // invariant already holds
}
-template <size_t N>
-HWY_API Vec128<uint32_t, N> RearrangeToOddPlusEven(
- const Vec128<uint32_t, N> sum0, const Vec128<uint32_t, N> /*sum1*/) {
- return sum0; // invariant already holds
-}
-
-template <size_t N>
-HWY_API Vec128<float, N> RearrangeToOddPlusEven(const Vec128<float, N> sum0,
- const Vec128<float, N> sum1) {
- return Add(sum0, sum1);
-}
-
// ------------------------------ Reductions
// Nothing native, generic_ops-inl defines SumOfLanes and ReduceSum.
diff --git a/third_party/highway/hwy/ops/wasm_256-inl.h b/third_party/highway/hwy/ops/wasm_256-inl.h
index e81f33f3ab..e65c8388d3 100644
--- a/third_party/highway/hwy/ops/wasm_256-inl.h
+++ b/third_party/highway/hwy/ops/wasm_256-inl.h
@@ -587,11 +587,6 @@ HWY_API Vec256<T> Xor(Vec256<T> a, Vec256<T> b) {
return a;
}
-template <typename T>
-HWY_API Vec256<T> Xor3(Vec256<T> x1, Vec256<T> x2, Vec256<T> x3) {
- return Xor(x1, Xor(x2, x3));
-}
-
template <typename T>
HWY_API Vec256<T> Or3(Vec256<T> o1, Vec256<T> o2, Vec256<T> o3) {
return Or(o1, Or(o2, o3));
@@ -662,7 +657,7 @@ HWY_API Vec256<T> VecFromMask(D d, Mask256<T> m) {
template <class D, HWY_IF_V_SIZE_D(D, 32)>
HWY_API uint64_t BitsFromMask(D d, MFromD<D> m) {
- const Half<D> dh;
+ const Half<decltype(d)> dh;
const uint64_t lo = BitsFromMask(dh, m.m0);
const uint64_t hi = BitsFromMask(dh, m.m1);
return (hi << Lanes(dh)) | lo;
@@ -1424,6 +1419,17 @@ HWY_API V InterleaveOddBlocks(D, V a, V b) {
return ret;
}
+// ------------------------------ InterleaveLowerBlocks
+template <class D, class V = VFromD<D>, HWY_IF_V_SIZE_D(D, 32)>
+HWY_API V InterleaveLowerBlocks(D d, V a, V b) {
+ return InterleaveEvenBlocks(d, a, b);
+}
+// ------------------------------ InterleaveUpperBlocks
+template <class D, class V = VFromD<D>, HWY_IF_V_SIZE_D(D, 32)>
+HWY_API V InterleaveUpperBlocks(D d, V a, V b) {
+ return InterleaveOddBlocks(d, a, b);
+}
+
// ------------------------------ ReverseBlocks
template <class D, typename T = TFromD<D>>
HWY_API Vec256<T> ReverseBlocks(D /* tag */, const Vec256<T> v) {
@@ -2411,14 +2417,6 @@ HWY_API Vec256<T32> ReorderWidenMulAccumulate(D32 d32, Vec256<T16> a,
return sum0;
}
-// ------------------------------ RearrangeToOddPlusEven
-template <typename TW>
-HWY_API Vec256<TW> RearrangeToOddPlusEven(Vec256<TW> sum0, Vec256<TW> sum1) {
- sum0.v0 = RearrangeToOddPlusEven(sum0.v0, sum1.v0);
- sum0.v1 = RearrangeToOddPlusEven(sum0.v1, sum1.v1);
- return sum0;
-}
-
// ------------------------------ Reductions in generic_ops
// ------------------------------ Lt128
diff --git a/third_party/highway/hwy/ops/x86_128-inl.h b/third_party/highway/hwy/ops/x86_128-inl.h
index db38e79cec..3c525ad159 100644
--- a/third_party/highway/hwy/ops/x86_128-inl.h
+++ b/third_party/highway/hwy/ops/x86_128-inl.h
@@ -57,7 +57,7 @@ namespace detail {
#undef HWY_AVX3_HAVE_F32_TO_BF16C
#if HWY_TARGET <= HWY_AVX3_ZEN4 && !HWY_COMPILER_CLANGCL && \
(HWY_COMPILER_GCC_ACTUAL >= 1000 || HWY_COMPILER_CLANG >= 900) && \
- !defined(HWY_AVX3_DISABLE_AVX512BF16)
+ HWY_AVX3_ENABLE_AVX512BF16
#define HWY_AVX3_HAVE_F32_TO_BF16C 1
#else
#define HWY_AVX3_HAVE_F32_TO_BF16C 0
@@ -70,6 +70,15 @@ namespace detail {
#define HWY_X86_GCC_INLINE_ASM_VEC_CONSTRAINT "x"
#endif
+#undef HWY_X86_HAVE_AVX10_2_OPS
+#if HWY_TARGET_IS_AVX10_2 && \
+ (HWY_COMPILER_GCC_ACTUAL >= 1501 || \
+ (HWY_COMPILER3_CLANG >= 200103 && HWY_COMPILER_CLANG != 2100))
+#define HWY_X86_HAVE_AVX10_2_OPS 1
+#else
+#define HWY_X86_HAVE_AVX10_2_OPS 0
+#endif
+
template <typename T>
struct Raw128 {
using type = __m128i;
@@ -707,83 +716,130 @@ HWY_API Vec128<double, N> Xor(Vec128<double, N> a, Vec128<double, N> b) {
return Vec128<double, N>{_mm_xor_pd(a.raw, b.raw)};
}
-// ------------------------------ Not
-template <typename T, size_t N>
-HWY_API Vec128<T, N> Not(const Vec128<T, N> v) {
- const DFromV<decltype(v)> d;
- const RebindToUnsigned<decltype(d)> du;
- using VU = VFromD<decltype(du)>;
+// ------------------------------ TernaryLogic
+
+#undef HWY_X86_HAVE_TERNARY_LOGIC
#if HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN
- const __m128i vu = BitCast(du, v).raw;
- return BitCast(d, VU{_mm_ternarylogic_epi32(vu, vu, vu, 0x55)});
+#define HWY_X86_HAVE_TERNARY_LOGIC 1
#else
- return Xor(v, BitCast(d, VU{_mm_set1_epi32(-1)}));
+#define HWY_X86_HAVE_TERNARY_LOGIC 0
+#endif
+
+#if HWY_X86_HAVE_TERNARY_LOGIC
+namespace detail {
+
+// Forward-declare the per-target implementations.
+template <uint8_t kTernLogOp, size_t kVectorBytes>
+struct TernaryLogicImpl;
+
+// Interface called from all targets. Without this, the compiler would only
+// examine one of the overloads, because each is templated on kTernLogOp.
+template <uint8_t kTernLogOp, class V>
+HWY_INLINE V TernaryLogic(V a, V b, V c) {
+ return TernaryLogicImpl<kTernLogOp, sizeof(V)>()(a, b, c);
+}
+
+// Per-target partial specialization.
+template <uint8_t kTernLogOp>
+struct TernaryLogicImpl<kTernLogOp, 16> {
+ template <class V>
+ HWY_INLINE V operator()(V a, V b, V c) const {
+ const DFromV<decltype(a)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ using VU = VFromD<decltype(du)>;
+ const __m128i ret = _mm_ternarylogic_epi64(
+ BitCast(du, a).raw, BitCast(du, b).raw, BitCast(du, c).raw, kTernLogOp);
+ return BitCast(d, VU{ret});
+ }
+};
+
+} // namespace detail
+#endif // HWY_X86_HAVE_TERNARY_LOGIC
+
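For reference, bit i of the kTernLogOp immediate is the desired output when the three operand bits are a=(i>>2)&1, b=(i>>1)&1, c=i&1. A standalone C++17 sketch (not part of the patch) deriving the constants used below:

    #include <cstdint>
    template <class F>
    constexpr uint8_t TernLogImm(F f) {
      uint8_t imm = 0;
      for (uint8_t i = 0; i < 8; ++i) {
        imm |= static_cast<uint8_t>(f((i >> 2) & 1, (i >> 1) & 1, i & 1) ? 1 : 0) << i;
      }
      return imm;
    }
    // Not passes v as all three operands, so negating any one of them works.
    static_assert(TernLogImm([](bool, bool, bool c) { return !c; }) == 0x55, "Not");
    static_assert(TernLogImm([](bool a, bool b, bool c) { return a ^ b ^ c; }) == 0x96, "Xor3");
    static_assert(TernLogImm([](bool a, bool b, bool c) { return a | b | c; }) == 0xFE, "Or3");
    static_assert(TernLogImm([](bool m, bool y, bool n) { return m ? y : n; }) == 0xCA, "IfVecThenElse");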
+// ------------------------------ Not
+template <class V> // generic for all vector lengths
+HWY_API V Not(const V v) {
+#if HWY_X86_HAVE_TERNARY_LOGIC
+ return detail::TernaryLogic<0x55>(v, v, v);
+#else
+ const DFromV<decltype(v)> d;
+ const RebindToSigned<decltype(d)> di;
+ return Xor(v, BitCast(d, Set(di, -1)));
#endif
}
+#if HWY_X86_HAVE_TERNARY_LOGIC
+
// ------------------------------ Xor3
-template <typename T, size_t N>
-HWY_API Vec128<T, N> Xor3(Vec128<T, N> x1, Vec128<T, N> x2, Vec128<T, N> x3) {
-#if HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN
- const DFromV<decltype(x1)> d;
- const RebindToUnsigned<decltype(d)> du;
- using VU = VFromD<decltype(du)>;
- const __m128i ret = _mm_ternarylogic_epi64(
- BitCast(du, x1).raw, BitCast(du, x2).raw, BitCast(du, x3).raw, 0x96);
- return BitCast(d, VU{ret});
+
+#ifdef HWY_NATIVE_XOR3
+#undef HWY_NATIVE_XOR3
+#else
+#define HWY_NATIVE_XOR3
+#endif
+
+template <class V> // generic for all vector lengths
+HWY_API V Xor3(V x1, V x2, V x3) {
+ return detail::TernaryLogic<0x96>(x1, x2, x3);
+}
+
+// ------------------------------ XorAndNot
+
+#ifdef HWY_NATIVE_BCAX
+#undef HWY_NATIVE_BCAX
+#else
+#define HWY_NATIVE_BCAX
+#endif
+#ifdef HWY_NATIVE_TERNLOG
+#undef HWY_NATIVE_TERNLOG
#else
- return Xor(x1, Xor(x2, x3));
+#define HWY_NATIVE_TERNLOG
#endif
+
+template <class V> // generic for all vector lengths
+HWY_API V XorAndNot(V x, V a1, V a2) {
+ return detail::TernaryLogic<0xD2>(x, a1, a2);
+}
+
+template <class V> // generic for all vector lengths
+HWY_API V AndXor(V a, V x1, V x2) {
+ return detail::TernaryLogic<0x60>(a, x1, x2);
}
+#endif // HWY_X86_HAVE_TERNARY_LOGIC
+
// ------------------------------ Or3
-template <typename T, size_t N>
-HWY_API Vec128<T, N> Or3(Vec128<T, N> o1, Vec128<T, N> o2, Vec128<T, N> o3) {
-#if HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN
- const DFromV<decltype(o1)> d;
- const RebindToUnsigned<decltype(d)> du;
- using VU = VFromD<decltype(du)>;
- const __m128i ret = _mm_ternarylogic_epi64(
- BitCast(du, o1).raw, BitCast(du, o2).raw, BitCast(du, o3).raw, 0xFE);
- return BitCast(d, VU{ret});
+template <class V> // generic for all vector lengths
+HWY_API V Or3(V o1, V o2, V o3) {
+#if HWY_X86_HAVE_TERNARY_LOGIC
+ return detail::TernaryLogic<0xFE>(o1, o2, o3);
#else
return Or(o1, Or(o2, o3));
#endif
}
// ------------------------------ OrAnd
-template <typename T, size_t N>
-HWY_API Vec128<T, N> OrAnd(Vec128<T, N> o, Vec128<T, N> a1, Vec128<T, N> a2) {
-#if HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN
- const DFromV<decltype(o)> d;
- const RebindToUnsigned<decltype(d)> du;
- using VU = VFromD<decltype(du)>;
- const __m128i ret = _mm_ternarylogic_epi64(
- BitCast(du, o).raw, BitCast(du, a1).raw, BitCast(du, a2).raw, 0xF8);
- return BitCast(d, VU{ret});
+template <class V> // generic for all vector lengths
+HWY_API V OrAnd(V o, V a1, V a2) {
+#if HWY_X86_HAVE_TERNARY_LOGIC
+ return detail::TernaryLogic<0xF8>(o, a1, a2);
#else
return Or(o, And(a1, a2));
#endif
}
// ------------------------------ IfVecThenElse
-template <typename T, size_t N>
-HWY_API Vec128<T, N> IfVecThenElse(Vec128<T, N> mask, Vec128<T, N> yes,
- Vec128<T, N> no) {
-#if HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN
- const DFromV<decltype(no)> d;
- const RebindToUnsigned<decltype(d)> du;
- using VU = VFromD<decltype(du)>;
- return BitCast(
- d, VU{_mm_ternarylogic_epi64(BitCast(du, mask).raw, BitCast(du, yes).raw,
- BitCast(du, no).raw, 0xCA)});
+template <class V> // generic for all vector lengths
+HWY_API V IfVecThenElse(V mask, V yes, V no) {
+#if HWY_X86_HAVE_TERNARY_LOGIC
+ return detail::TernaryLogic<0xCA>(mask, yes, no);
#else
return IfThenElse(MaskFromVec(mask), yes, no);
#endif
}
// ------------------------------ BitwiseIfThenElse
-#if HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN
+#if HWY_X86_HAVE_TERNARY_LOGIC
#ifdef HWY_NATIVE_BITWISE_IF_THEN_ELSE
#undef HWY_NATIVE_BITWISE_IF_THEN_ELSE
@@ -791,12 +847,12 @@ HWY_API Vec128<T, N> IfVecThenElse(Vec128<T, N> mask, Vec128<T, N> yes,
#define HWY_NATIVE_BITWISE_IF_THEN_ELSE
#endif
-template <class V>
+template <class V> // generic for all vector lengths
HWY_API V BitwiseIfThenElse(V mask, V yes, V no) {
return IfVecThenElse(mask, yes, no);
}
-#endif
+#endif // HWY_X86_HAVE_TERNARY_LOGIC
// ------------------------------ Operator overloads (internal-only if float)
@@ -996,6 +1052,23 @@ HWY_API MFromD<D> MaskFalse(D /*d*/) {
return MFromD<D>{static_cast<decltype(MFromD<D>().raw)>(0)};
}
+// ------------------------------ SetMask
+#ifdef HWY_NATIVE_SET_MASK
+#undef HWY_NATIVE_SET_MASK
+#else
+#define HWY_NATIVE_SET_MASK
+#endif
+
+template <class D>
+HWY_API MFromD<D> SetMask(D /*d*/, bool val) {
+ constexpr uint64_t kMask = (HWY_MAX_LANES_D(D) < 64)
+ ? ((1ULL << (HWY_MAX_LANES_D(D) & 63)) - 1ULL)
+ : LimitsMax<uint64_t>();
+
+ return MFromD<D>{static_cast<decltype(MFromD<D>().raw)>(
+ static_cast<uint64_t>(-static_cast<int64_t>(val)) & kMask)};
+}
+
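A standalone sketch of the bit math (not part of the patch): -int64_t(val) yields all-ones for true and zero for false, which kMask then trims to the lane count.

    #include <cstdint>
    constexpr uint64_t SetMaskBits(uint64_t max_lanes, bool val) {
      const uint64_t mask = (max_lanes < 64)
                                ? ((1ULL << (max_lanes & 63)) - 1ULL)
                                : ~uint64_t{0};
      return static_cast<uint64_t>(-static_cast<int64_t>(val)) & mask;
    }
    static_assert(SetMaskBits(4, true) == 0xF, "all four lanes set");
    static_assert(SetMaskBits(4, false) == 0, "no lanes set");
    static_assert(SetMaskBits(64, true) == ~uint64_t{0}, "full 64-lane mask");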
// ------------------------------ IsNegative (MFromD)
#ifdef HWY_NATIVE_IS_NEGATIVE
#undef HWY_NATIVE_IS_NEGATIVE
@@ -1931,6 +2004,45 @@ HWY_API Mask128<T, N> ExclusiveNeither(const Mask128<T, N> a, Mask128<T, N> b) {
#endif // HWY_TARGET <= HWY_AVX3
+// MaskedTernaryLogic depends on MFromD.
+#if HWY_X86_HAVE_TERNARY_LOGIC
+namespace detail {
+
+// Forward-declare implementation.
+template <uint8_t kTernLogOp, size_t kVectorBytes>
+struct MaskedTernaryLogicImpl;
+
+// Same as TernaryLogic, but with writemask. If !mask, returns a.
+template <uint8_t kTernLogOp, class V>
+HWY_INLINE V MaskedTernaryLogic(MFromD<DFromV<V>> mask, V a, V b, V c) {
+ return MaskedTernaryLogicImpl<kTernLogOp, sizeof(V)>()(mask, a, b, c);
+}
+
+template <uint8_t kTernLogOp>
+struct MaskedTernaryLogicImpl<kTernLogOp, 16> {
+ template <class V, class D = DFromV<V>, HWY_IF_T_SIZE_D(D, 4)>
+ HWY_INLINE V operator()(MFromD<D> mask, V a, V b, V c) const {
+ const D d;
+ const RebindToUnsigned<decltype(d)> du;
+ using VU = VFromD<decltype(du)>;
+ const __m128i ret =
+ _mm_mask_ternarylogic_epi32(a.raw, mask.raw, b.raw, c.raw, kTernLogOp);
+ return BitCast(d, VU{ret});
+ }
+ template <class V, class D = DFromV<V>, HWY_IF_T_SIZE_D(D, 8)>
+ HWY_INLINE V operator()(MFromD<D> mask, V a, V b, V c) const {
+ const D d;
+ const RebindToUnsigned<decltype(d)> du;
+ using VU = VFromD<decltype(du)>;
+ const __m128i ret =
+ _mm_mask_ternarylogic_epi64(a.raw, mask.raw, b.raw, c.raw, kTernLogOp);
+ return BitCast(d, VU{ret});
+ }
+};
+
+} // namespace detail
+#endif // HWY_X86_HAVE_TERNARY_LOGIC
+
// ------------------------------ ShiftLeft
template <int kBits, size_t N>
@@ -2037,13 +2149,13 @@ HWY_API Vec128<int8_t, N> ShiftRight(const Vec128<int8_t, N> v) {
// Clang static analysis claims the memory immediately after a partial vector
// store is uninitialized, and also flags the input to partial loads (at least
-// for loadl_pd) as "garbage". This is a false alarm because msan does not
-// raise errors. We work around this by using CopyBytes instead of intrinsics,
-// but only for the analyzer to avoid potentially bad code generation.
+// for loadl_pd) as "garbage". As of 2025-07, MSAN also raises errors. We
+// work around this by using CopyBytes instead of intrinsics, but only for MSAN
+// and static-analyzer builds to avoid potentially bad code generation.
// Unfortunately __clang_analyzer__ was not defined for clang-tidy prior to v7.
#ifndef HWY_SAFE_PARTIAL_LOAD_STORE
-#if defined(__clang_analyzer__) || \
- (HWY_COMPILER_CLANG != 0 && HWY_COMPILER_CLANG < 700)
+#if HWY_IS_MSAN || (defined(__clang_analyzer__) || \
+ (HWY_COMPILER_CLANG != 0 && HWY_COMPILER_CLANG < 700))
#define HWY_SAFE_PARTIAL_LOAD_STORE 1
#else
#define HWY_SAFE_PARTIAL_LOAD_STORE 0
@@ -4160,7 +4272,7 @@ HWY_API Vec128<int16_t, N> SaturatedAdd(const Vec128<int16_t, N> a,
return Vec128<int16_t, N>{_mm_adds_epi16(a.raw, b.raw)};
}
-#if HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN
+#if HWY_X86_HAVE_TERNARY_LOGIC
#ifdef HWY_NATIVE_I32_SATURATED_ADDSUB
#undef HWY_NATIVE_I32_SATURATED_ADDSUB
#else
@@ -4173,32 +4285,21 @@ HWY_API Vec128<int16_t, N> SaturatedAdd(const Vec128<int16_t, N> a,
#define HWY_NATIVE_I64_SATURATED_ADDSUB
#endif
-template <size_t N>
-HWY_API Vec128<int32_t, N> SaturatedAdd(Vec128<int32_t, N> a,
- Vec128<int32_t, N> b) {
- const DFromV<decltype(a)> d;
- const auto sum = a + b;
- const auto overflow_mask = MaskFromVec(
- Vec128<int32_t, N>{_mm_ternarylogic_epi32(a.raw, b.raw, sum.raw, 0x42)});
- const auto i32_max = Set(d, LimitsMax<int32_t>());
- const Vec128<int32_t, N> overflow_result{_mm_mask_ternarylogic_epi32(
- i32_max.raw, MaskFromVec(a).raw, i32_max.raw, i32_max.raw, 0x55)};
+// Generic for all vector lengths.
+template <class V, class D = DFromV<V>, HWY_IF_SIGNED_D(D),
+ HWY_IF_T_SIZE_ONE_OF_D(D, (1 << 4) | (1 << 8))>
+HWY_API V SaturatedAdd(V a, V b) {
+ const D d;
+ const V sum = a + b;
+ const MFromD<D> overflow_mask =
+ MaskFromVec(detail::TernaryLogic<0x42>(a, b, sum));
+ const V max = Set(d, LimitsMax<TFromD<D>>());
+ const V overflow_result =
+ detail::MaskedTernaryLogic<0x55>(MaskFromVec(a), max, max, max);
return IfThenElse(overflow_mask, overflow_result, sum);
}
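Why 0x42 works (standalone sketch, not from the patch): the truth table is applied bit-wise to the full values, and MaskFromVec then reads only the sign bits, where signed-add overflow means a and b share a sign that the sum does not.

    constexpr bool AddOverflows(bool sa, bool sb, bool ssum) {
      return sa == sb && ssum != sa;  // evaluated on the sign bits
    }
    // Truth-table indices (sa<<2)|(sb<<1)|ssum: only 1 (0,0,1) and 6 (1,1,0).
    static_assert(AddOverflows(0, 0, 1) && AddOverflows(1, 1, 0), "");
    static_assert(((1u << 1) | (1u << 6)) == 0x42, "the TernaryLogic immediate");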
-template <size_t N>
-HWY_API Vec128<int64_t, N> SaturatedAdd(Vec128<int64_t, N> a,
- Vec128<int64_t, N> b) {
- const DFromV<decltype(a)> d;
- const auto sum = a + b;
- const auto overflow_mask = MaskFromVec(
- Vec128<int64_t, N>{_mm_ternarylogic_epi64(a.raw, b.raw, sum.raw, 0x42)});
- const auto i64_max = Set(d, LimitsMax<int64_t>());
- const Vec128<int64_t, N> overflow_result{_mm_mask_ternarylogic_epi64(
- i64_max.raw, MaskFromVec(a).raw, i64_max.raw, i64_max.raw, 0x55)};
- return IfThenElse(overflow_mask, overflow_result, sum);
-}
-#endif // HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN
+#endif // HWY_X86_HAVE_TERNARY_LOGIC
// ------------------------------ SaturatedSub
@@ -4228,33 +4329,22 @@ HWY_API Vec128<int16_t, N> SaturatedSub(const Vec128<int16_t, N> a,
return Vec128<int16_t, N>{_mm_subs_epi16(a.raw, b.raw)};
}
-#if HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN
-template <size_t N>
-HWY_API Vec128<int32_t, N> SaturatedSub(Vec128<int32_t, N> a,
- Vec128<int32_t, N> b) {
- const DFromV<decltype(a)> d;
- const auto diff = a - b;
- const auto overflow_mask = MaskFromVec(
- Vec128<int32_t, N>{_mm_ternarylogic_epi32(a.raw, b.raw, diff.raw, 0x18)});
- const auto i32_max = Set(d, LimitsMax<int32_t>());
- const Vec128<int32_t, N> overflow_result{_mm_mask_ternarylogic_epi32(
- i32_max.raw, MaskFromVec(a).raw, i32_max.raw, i32_max.raw, 0x55)};
+#if HWY_X86_HAVE_TERNARY_LOGIC
+// Generic for all vector lengths.
+template <class V, class D = DFromV<V>, HWY_IF_SIGNED_D(D),
+ HWY_IF_T_SIZE_ONE_OF_D(D, (1 << 4) | (1 << 8))>
+HWY_API V SaturatedSub(V a, V b) {
+ const D d;
+ const V diff = a - b;
+ const MFromD<D> overflow_mask =
+ MaskFromVec(detail::TernaryLogic<0x18>(a, b, diff));
+ const V max = Set(d, LimitsMax<TFromD<D>>());
+ const V overflow_result =
+ detail::MaskedTernaryLogic<0x55>(MaskFromVec(a), max, max, max);
return IfThenElse(overflow_mask, overflow_result, diff);
}
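The companion derivation for subtraction (sketch): a - b overflows when a and b have different signs and the difference's sign differs from a's.

    constexpr bool SubOverflows(bool sa, bool sb, bool sdiff) {
      return sa != sb && sdiff != sa;
    }
    // Truth-table indices (sa<<2)|(sb<<1)|sdiff: only 3 (0,1,1) and 4 (1,0,0).
    static_assert(SubOverflows(0, 1, 1) && SubOverflows(1, 0, 0), "");
    static_assert(((1u << 3) | (1u << 4)) == 0x18, "the TernaryLogic immediate");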
-template <size_t N>
-HWY_API Vec128<int64_t, N> SaturatedSub(Vec128<int64_t, N> a,
- Vec128<int64_t, N> b) {
- const DFromV<decltype(a)> d;
- const auto diff = a - b;
- const auto overflow_mask = MaskFromVec(
- Vec128<int64_t, N>{_mm_ternarylogic_epi64(a.raw, b.raw, diff.raw, 0x18)});
- const auto i64_max = Set(d, LimitsMax<int64_t>());
- const Vec128<int64_t, N> overflow_result{_mm_mask_ternarylogic_epi64(
- i64_max.raw, MaskFromVec(a).raw, i64_max.raw, i64_max.raw, 0x55)};
- return IfThenElse(overflow_mask, overflow_result, diff);
-}
-#endif // HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN
+#endif // HWY_X86_HAVE_TERNARY_LOGIC
// ------------------------------ AverageRound
@@ -4476,7 +4566,7 @@ HWY_API Vec128<int64_t, N> operator*(Vec128<int64_t, N> a,
// ------------------------------ RotateRight (ShiftRight, Or)
-// U8 RotateRight implementation on AVX3_DL is now in x86_512-inl.h as U8
+// U8 RotateRight implementation on AVX3_DL is now in x86_avx3-inl.h as U8
// RotateRight uses detail::GaloisAffine on AVX3_DL
#if HWY_TARGET > HWY_AVX3_DL
@@ -6010,6 +6100,110 @@ HWY_API Vec128<double, N> Max(Vec128<double, N> a, Vec128<double, N> b) {
return Vec128<double, N>{_mm_max_pd(a.raw, b.raw)};
}
+// ------------------------------ MinNumber and MaxNumber
+
+#ifdef HWY_NATIVE_FLOAT_MIN_MAX_NUMBER
+#undef HWY_NATIVE_FLOAT_MIN_MAX_NUMBER
+#else
+#define HWY_NATIVE_FLOAT_MIN_MAX_NUMBER
+#endif
+
+#if HWY_X86_HAVE_AVX10_2_OPS
+
+#if HWY_HAVE_FLOAT16
+template <size_t N>
+HWY_API Vec128<float16_t, N> MinNumber(Vec128<float16_t, N> a,
+ Vec128<float16_t, N> b) {
+ return Vec128<float16_t, N>{_mm_minmax_ph(a.raw, b.raw, 0x14)};
+}
+#endif
+template <size_t N>
+HWY_API Vec128<float, N> MinNumber(Vec128<float, N> a, Vec128<float, N> b) {
+ return Vec128<float, N>{_mm_minmax_ps(a.raw, b.raw, 0x14)};
+}
+template <size_t N>
+HWY_API Vec128<double, N> MinNumber(Vec128<double, N> a, Vec128<double, N> b) {
+ return Vec128<double, N>{_mm_minmax_pd(a.raw, b.raw, 0x14)};
+}
+
+#if HWY_HAVE_FLOAT16
+template <size_t N>
+HWY_API Vec128<float16_t, N> MaxNumber(Vec128<float16_t, N> a,
+ Vec128<float16_t, N> b) {
+ return Vec128<float16_t, N>{_mm_minmax_ph(a.raw, b.raw, 0x15)};
+}
+#endif
+template <size_t N>
+HWY_API Vec128<float, N> MaxNumber(Vec128<float, N> a, Vec128<float, N> b) {
+ return Vec128<float, N>{_mm_minmax_ps(a.raw, b.raw, 0x15)};
+}
+template <size_t N>
+HWY_API Vec128<double, N> MaxNumber(Vec128<double, N> a, Vec128<double, N> b) {
+ return Vec128<double, N>{_mm_minmax_pd(a.raw, b.raw, 0x15)};
+}
+
+#else
+
+// MinNumber/MaxNumber are generic for all vector lengths on targets other
+// than AVX10.2.
+template <class V, HWY_IF_FLOAT_OR_SPECIAL_V(V)>
+HWY_API V MinNumber(V a, V b) {
+ return Min(a, IfThenElse(IsNaN(b), a, b));
+}
+
+template <class V, HWY_IF_FLOAT_OR_SPECIAL_V(V)>
+HWY_API V MaxNumber(V a, V b) {
+ return Max(a, IfThenElse(IsNaN(b), a, b));
+}
+
+#endif
+
+// ------------------------------ MinMagnitude and MaxMagnitude
+
+#if HWY_X86_HAVE_AVX10_2_OPS
+
+#ifdef HWY_NATIVE_FLOAT_MIN_MAX_MAGNITUDE
+#undef HWY_NATIVE_FLOAT_MIN_MAX_MAGNITUDE
+#else
+#define HWY_NATIVE_FLOAT_MIN_MAX_MAGNITUDE
+#endif
+
+#if HWY_HAVE_FLOAT16
+template <size_t N>
+HWY_API Vec128<float16_t, N> MinMagnitude(Vec128<float16_t, N> a,
+ Vec128<float16_t, N> b) {
+ return Vec128<float16_t, N>{_mm_minmax_ph(a.raw, b.raw, 0x16)};
+}
+#endif
+template <size_t N>
+HWY_API Vec128<float, N> MinMagnitude(Vec128<float, N> a, Vec128<float, N> b) {
+ return Vec128<float, N>{_mm_minmax_ps(a.raw, b.raw, 0x16)};
+}
+template <size_t N>
+HWY_API Vec128<double, N> MinMagnitude(Vec128<double, N> a,
+ Vec128<double, N> b) {
+ return Vec128<double, N>{_mm_minmax_pd(a.raw, b.raw, 0x16)};
+}
+
+#if HWY_HAVE_FLOAT16
+template <size_t N>
+HWY_API Vec128<float16_t, N> MaxMagnitude(Vec128<float16_t, N> a,
+ Vec128<float16_t, N> b) {
+ return Vec128<float16_t, N>{_mm_minmax_ph(a.raw, b.raw, 0x17)};
+}
+#endif
+template <size_t N>
+HWY_API Vec128<float, N> MaxMagnitude(Vec128<float, N> a, Vec128<float, N> b) {
+ return Vec128<float, N>{_mm_minmax_ps(a.raw, b.raw, 0x17)};
+}
+template <size_t N>
+HWY_API Vec128<double, N> MaxMagnitude(Vec128<double, N> a,
+ Vec128<double, N> b) {
+ return Vec128<double, N>{_mm_minmax_pd(a.raw, b.raw, 0x17)};
+}
+
+#endif
+
// ================================================== MEMORY (3)
// ------------------------------ Non-temporal stores
@@ -6850,7 +7044,13 @@ HWY_API Vec128<T, N> Broadcast(const Vec128<T, N> v) {
template <int kLane, typename T, size_t N, HWY_IF_UI32(T)>
HWY_API Vec128<T, N> Broadcast(const Vec128<T, N> v) {
static_assert(0 <= kLane && kLane < N, "Invalid lane");
- return Vec128<T, N>{_mm_shuffle_epi32(v.raw, 0x55 * kLane)};
+ HWY_IF_CONSTEXPR(N == 1) {
+ // Workaround for MSVC compiler bug on single lane integer broadcast
+ return Vec128<T, N>{v};
+ }
+ HWY_IF_CONSTEXPR(N != 1) {
+ return Vec128<T, N>{_mm_shuffle_epi32(v.raw, 0x55 * kLane)};
+ }
}
template <int kLane, typename T, size_t N, HWY_IF_UI64(T)>
@@ -7027,52 +7227,48 @@ HWY_API Vec128<float16_t, N> TableLookupLanes(Vec128<float16_t, N> v,
template <typename T, size_t N, HWY_IF_T_SIZE(T, 4)>
HWY_API Vec128<T, N> TableLookupLanes(Vec128<T, N> v, Indices128<T, N> idx) {
-#if HWY_TARGET <= HWY_AVX2
const DFromV<decltype(v)> d;
- const RebindToFloat<decltype(d)> df;
- const Vec128<float, N> perm{_mm_permutevar_ps(BitCast(df, v).raw, idx.raw)};
- return BitCast(d, perm);
+ const Full128<T> d_full;
+ const Vec128<T> v_full = ZeroExtendResizeBitCast(d_full, d, v);
+
+ const RebindToSigned<decltype(d)> di;
+ const Full128<MakeSigned<T>> di_full;
+ const VFromD<decltype(di_full)> vidx =
+ ZeroExtendResizeBitCast(di_full, di, VFromD<decltype(di)>{idx.raw});
+
+#if HWY_TARGET <= HWY_AVX2
+ // There is no permutevar for non-float; _mm256_permutevar8x32_epi32 is for
+ // 256-bit vectors, hence cast to float.
+ const Full128<float> df_full;
+ // Workaround for MSAN false positive.
+ HWY_IF_CONSTEXPR(HWY_IS_MSAN) PreventElision(GetLane(vidx));
+ const Vec128<float> perm{
+ _mm_permutevar_ps(BitCast(df_full, v_full).raw, vidx.raw)};
+ return ResizeBitCast(d, perm);
#elif HWY_TARGET == HWY_SSE2
#if HWY_COMPILER_GCC_ACTUAL && HWY_HAS_BUILTIN(__builtin_shuffle)
typedef uint32_t GccU32RawVectType __attribute__((__vector_size__(16)));
return Vec128<T, N>{reinterpret_cast<typename detail::Raw128<T>::type>(
- __builtin_shuffle(reinterpret_cast<GccU32RawVectType>(v.raw),
- reinterpret_cast<GccU32RawVectType>(idx.raw)))};
+ __builtin_shuffle(reinterpret_cast<GccU32RawVectType>(v_full.raw),
+ reinterpret_cast<GccU32RawVectType>(vidx.raw)))};
#else
- const Full128<T> d_full;
alignas(16) T src_lanes[4];
- alignas(16) uint32_t indices[4];
+ alignas(16) int32_t indices[4];
alignas(16) T result_lanes[4];
- Store(Vec128<T>{v.raw}, d_full, src_lanes);
- _mm_store_si128(reinterpret_cast<__m128i*>(indices), idx.raw);
+ Store(v_full, d_full, src_lanes);
+ Store(vidx, di_full, indices);
- for (int i = 0; i < 4; i++) {
- result_lanes[i] = src_lanes[indices[i] & 3u];
+ for (size_t i = 0; i < N; i++) {
+ result_lanes[i] = src_lanes[static_cast<size_t>(indices[i] & 3)];
}
-
- return Vec128<T, N>{Load(d_full, result_lanes).raw};
+ return Load(d, result_lanes);
#endif // HWY_COMPILER_GCC_ACTUAL && HWY_HAS_BUILTIN(__builtin_shuffle)
#else // SSSE3 or SSE4
- return TableLookupBytes(v, Vec128<T, N>{idx.raw});
+ return ResizeBitCast(d, TableLookupBytes(BitCast(di_full, v_full), vidx));
#endif
}
-#if HWY_TARGET <= HWY_SSSE3
-template <size_t N, HWY_IF_V_SIZE_GT(float, N, 4)>
-HWY_API Vec128<float, N> TableLookupLanes(Vec128<float, N> v,
- Indices128<float, N> idx) {
-#if HWY_TARGET <= HWY_AVX2
- return Vec128<float, N>{_mm_permutevar_ps(v.raw, idx.raw)};
-#else // SSSE3 or SSE4
- const DFromV<decltype(v)> df;
- const RebindToSigned<decltype(df)> di;
- return BitCast(df,
- TableLookupBytes(BitCast(di, v), Vec128<int32_t, N>{idx.raw}));
-#endif // HWY_TARGET <= HWY_AVX2
-}
-#endif // HWY_TARGET <= HWY_SSSE3
-
// Single lane: no change
template <typename T>
HWY_API Vec128<T, 1> TableLookupLanes(Vec128<T, 1> v,
@@ -7080,11 +7276,15 @@ HWY_API Vec128<T, 1> TableLookupLanes(Vec128<T, 1> v,
return v;
}
-template <typename T, HWY_IF_UI64(T)>
+template <typename T, HWY_IF_T_SIZE(T, 8)>
HWY_API Vec128<T> TableLookupLanes(Vec128<T> v, Indices128<T> idx) {
const DFromV<decltype(v)> d;
+ // No need for ZeroExtendResizeBitCast; we have full vectors.
Vec128<int64_t> vidx{idx.raw};
-#if HWY_TARGET <= HWY_AVX2
+
+ // Disable in MSAN builds due to false positive. Note that this affects
+ // CompressNot, which assumes upper index bits will be ignored.
+#if HWY_TARGET <= HWY_AVX2 && !HWY_IS_MSAN
// There is no _mm_permute[x]var_epi64.
vidx += vidx; // bit1 is the decider (unusual)
const RebindToFloat<decltype(d)> df;
@@ -7096,26 +7296,8 @@ HWY_API Vec128<T> TableLookupLanes(Vec128<T> v, Indices128<T> idx) {
// to obtain an all-zero or all-one mask.
const RebindToSigned<decltype(d)> di;
const Vec128<int64_t> same = (vidx ^ Iota(di, 0)) - Set(di, 1);
- const Mask128<T> mask_same = RebindMask(d, MaskFromVec(same));
- return IfThenElse(mask_same, v, Shuffle01(v));
-#endif
-}
-
-HWY_API Vec128<double> TableLookupLanes(Vec128<double> v,
- Indices128<double> idx) {
- Vec128<int64_t> vidx{idx.raw};
-#if HWY_TARGET <= HWY_AVX2
- vidx += vidx; // bit1 is the decider (unusual)
- return Vec128<double>{_mm_permutevar_pd(v.raw, vidx.raw)};
-#else
- // Only 2 lanes: can swap+blend. Choose v if vidx == iota. To avoid a 64-bit
- // comparison (expensive on SSSE3), just invert the upper lane and subtract 1
- // to obtain an all-zero or all-one mask.
- const DFromV<decltype(v)> d;
- const RebindToSigned<decltype(d)> di;
- const Vec128<int64_t> same = (vidx ^ Iota(di, 0)) - Set(di, 1);
- const Mask128<double> mask_same = RebindMask(d, MaskFromVec(same));
- return IfThenElse(mask_same, v, Shuffle01(v));
+ return BitCast(
+ d, IfVecThenElse(same, BitCast(di, v), Shuffle01(BitCast(di, v))));
#endif
}
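A standalone sketch of the swap+blend trick (not from the patch): with lane indices restricted to {0, 1}, XOR against the lane number is zero exactly on a match, and subtracting 1 turns that into an all-ones/all-zeros lane mask without a 64-bit compare.

    #include <cstdint>
    constexpr int64_t SameLaneMask(int64_t idx, int64_t lane) {
      return (idx ^ lane) - 1;  // -1 (keep v) iff idx == lane, else 0 (swap)
    }
    static_assert(SameLaneMask(0, 0) == -1 && SameLaneMask(1, 1) == -1, "");
    static_assert(SameLaneMask(1, 0) == 0 && SameLaneMask(0, 1) == 0, "");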
@@ -9021,6 +9203,17 @@ HWY_API V InterleaveOddBlocks(D, V a, V /*b*/) {
return a;
}
+// ------------------------------ InterleaveLowerBlocks
+template <class D, class V = VFromD<D>, HWY_IF_V_SIZE_LE_D(D, 16)>
+HWY_API V InterleaveLowerBlocks(D, V a, V /*b*/) {
+ return a;
+}
+// ------------------------------ InterleaveUpperBlocks
+template <class D, class V = VFromD<D>, HWY_IF_V_SIZE_LE_D(D, 16)>
+HWY_API V InterleaveUpperBlocks(D, V a, V /*b*/) {
+ return a;
+}
+
// ------------------------------ Shl (ZipLower, Mul)
// Use AVX2/3 variable shifts where available, otherwise multiply by powers of
@@ -9845,6 +10038,12 @@ HWY_API VFromD<DF> ReorderWidenMulAccumulate(DF /*df*/, VBF a, VBF b,
reinterpret_cast<__m128bh>(b.raw))};
}
+template <class VW, HWY_IF_FLOAT_V(VW)>
+HWY_API VW RearrangeToOddPlusEven(const VW sum0, const VW) {
+ // Sum1 is unused and the invariant already holds.
+ return sum0;
+}
+
#endif // HWY_NATIVE_DOT_BF16
// Even if N=1, the input is always at least 2 lanes, hence madd_epi16 is safe.
@@ -9871,21 +10070,10 @@ HWY_API VFromD<DU32> ReorderWidenMulAccumulate(DU32 d, VU16 a, VU16 b,
}
// ------------------------------ RearrangeToOddPlusEven
-template <size_t N>
-HWY_API Vec128<int32_t, N> RearrangeToOddPlusEven(const Vec128<int32_t, N> sum0,
- Vec128<int32_t, N> /*sum1*/) {
- return sum0; // invariant already holds
-}
-
-template <size_t N>
-HWY_API Vec128<uint32_t, N> RearrangeToOddPlusEven(
- const Vec128<uint32_t, N> sum0, Vec128<uint32_t, N> /*sum1*/) {
- return sum0; // invariant already holds
-}
-
-template <class VW>
-HWY_API VW RearrangeToOddPlusEven(const VW sum0, const VW sum1) {
- return Add(sum0, sum1);
+template <class VW, HWY_IF_NOT_FLOAT_V(VW)>
+HWY_API VW RearrangeToOddPlusEven(const VW sum0, const VW) {
+ // For integer types, sum1 is unused and the invariant already holds.
+ return sum0;
}
// ------------------------------ SumOfMulQuadAccumulate
@@ -9909,12 +10097,21 @@ HWY_API VFromD<DI32> SumOfMulQuadAccumulate(
#else
#define HWY_NATIVE_I8_I8_SUMOFMULQUADACCUMULATE
#endif
+
+#if HWY_X86_HAVE_AVX10_2_OPS
+template <class DI32, HWY_IF_I32_D(DI32), HWY_IF_V_SIZE_LE_D(DI32, 16)>
+HWY_API VFromD<DI32> SumOfMulQuadAccumulate(DI32 /*di32*/,
+ VFromD<Repartition<int8_t, DI32>> a,
+ VFromD<Repartition<int8_t, DI32>> b,
+ VFromD<DI32> sum) {
+ return VFromD<DI32>{_mm_dpbssd_epi32(sum.raw, a.raw, b.raw)};
+}
+#else // !HWY_X86_HAVE_AVX10_2_OPS
template <class DI32, HWY_IF_I32_D(DI32)>
HWY_API VFromD<DI32> SumOfMulQuadAccumulate(DI32 di32,
VFromD<Repartition<int8_t, DI32>> a,
VFromD<Repartition<int8_t, DI32>> b,
VFromD<DI32> sum) {
- // TODO(janwas): AVX-VNNI-INT8 has dpbssd.
const Repartition<uint8_t, decltype(di32)> du8;
const auto a_u = BitCast(du8, a);
@@ -9923,17 +10120,26 @@ HWY_API VFromD<DI32> SumOfMulQuadAccumulate(DI32 di32,
SumOfMulQuadAccumulate(di32, ShiftRight<7>(a_u), b, Zero(di32)));
return result_sum_0 - result_sum_1;
}
+#endif // HWY_X86_HAVE_AVX10_2_OPS
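A scalar model of the fallback above (sketch, per multiplicand): reinterpreting i8 a as u8 adds 256 whenever a is negative, so subtracting 256 times the sign-bit product (the ShiftRight<7> term) restores the signed result.

    #include <cstdint>
    constexpr int32_t SignedMulViaUnsigned(int8_t a, int8_t b) {
      const uint8_t a_u = static_cast<uint8_t>(a);  // == a + 256 * sign(a)
      const int32_t s = a_u >> 7;                   // the ShiftRight<7> term
      return static_cast<int32_t>(a_u) * b - (s << 8) * b;  // == a * b
    }
    static_assert(SignedMulViaUnsigned(-3, 5) == -15, "");
    static_assert(SignedMulViaUnsigned(100, -7) == -700, "");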
#ifdef HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE
#undef HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE
#else
#define HWY_NATIVE_U8_U8_SUMOFMULQUADACCUMULATE
#endif
+
+#if HWY_X86_HAVE_AVX10_2_OPS
+template <class DU32, HWY_IF_U32_D(DU32), HWY_IF_V_SIZE_LE_D(DU32, 16)>
+HWY_API VFromD<DU32> SumOfMulQuadAccumulate(
+ DU32 /*du32*/, VFromD<Repartition<uint8_t, DU32>> a,
+ VFromD<Repartition<uint8_t, DU32>> b, VFromD<DU32> sum) {
+ return VFromD<DU32>{_mm_dpbuud_epi32(sum.raw, a.raw, b.raw)};
+}
+#else // !HWY_X86_HAVE_AVX10_2_OPS
template <class DU32, HWY_IF_U32_D(DU32)>
HWY_API VFromD<DU32> SumOfMulQuadAccumulate(
DU32 du32, VFromD<Repartition<uint8_t, DU32>> a,
VFromD<Repartition<uint8_t, DU32>> b, VFromD<DU32> sum) {
- // TODO(janwas): AVX-VNNI-INT8 has dpbuud.
const Repartition<uint8_t, decltype(du32)> du8;
const RebindToSigned<decltype(du8)> di8;
const RebindToSigned<decltype(du32)> di32;
@@ -9946,6 +10152,7 @@ HWY_API VFromD<DU32> SumOfMulQuadAccumulate(
return BitCast(du32, result_sum_0 - result_sum_1);
}
+#endif // HWY_X86_HAVE_AVX10_2_OPS
#endif // HWY_TARGET <= HWY_AVX3_DL
@@ -10443,6 +10650,7 @@ X86ConvertScalarFromFloat(TF from_val) {
return X86ConvertScalarFromFloat<TTo>(hwy::TypeTag<RemoveCvRef<TTo>>(),
from_val);
}
+
#endif // HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD
} // namespace detail
@@ -10455,7 +10663,9 @@ X86ConvertScalarFromFloat(TF from_val) {
template <class D, HWY_IF_V_SIZE_LE_D(D, 8), HWY_IF_I32_D(D)>
HWY_API VFromD<D> DemoteInRangeTo(D /* tag */, VFromD<Rebind<double, D>> v) {
-#if HWY_COMPILER_GCC_ACTUAL
+#if HWY_X86_HAVE_AVX10_2_OPS
+ return VFromD<D>{_mm_cvtts_pd_epi32(v.raw)};
+#elif HWY_COMPILER_GCC_ACTUAL
// Workaround for undefined behavior in _mm_cvttpd_epi32 with GCC if any
// values of v[i] are not within the range of an int32_t
@@ -10492,7 +10702,9 @@ HWY_API VFromD<D> DemoteTo(D di32, VFromD<Rebind<double, D>> v) {
#if HWY_TARGET <= HWY_AVX3
template <class D, HWY_IF_V_SIZE_LE_D(D, 8), HWY_IF_U32_D(D)>
HWY_API VFromD<D> DemoteInRangeTo(D /* tag */, VFromD<Rebind<double, D>> v) {
-#if HWY_COMPILER_GCC_ACTUAL
+#if HWY_X86_HAVE_AVX10_2_OPS
+ return VFromD<D>{_mm_cvtts_pd_epu32(v.raw)};
+#elif HWY_COMPILER_GCC_ACTUAL
// Workaround for undefined behavior in _mm_cvttpd_epu32 with GCC if any
// values of v[i] are not within the range of an uint32_t
@@ -10520,8 +10732,12 @@ HWY_API VFromD<D> DemoteInRangeTo(D /* tag */, VFromD<Rebind<double, D>> v) {
// F64->U32 DemoteTo is generic for all vector lengths
template <class D, HWY_IF_U32_D(D)>
-HWY_API VFromD<D> DemoteTo(D /* tag */, VFromD<Rebind<double, D>> v) {
- return DemoteInRangeTo(D(), ZeroIfNegative(v));
+HWY_API VFromD<D> DemoteTo(D du32, VFromD<Rebind<double, D>> v) {
+#if HWY_X86_HAVE_AVX10_2_OPS
+ return DemoteInRangeTo(du32, v);
+#else
+ return DemoteInRangeTo(du32, ZeroIfNegative(v));
+#endif
}
#else // HWY_TARGET > HWY_AVX3
@@ -10649,7 +10865,9 @@ HWY_API Vec128<uint8_t, N> U8FromU32(const Vec128<uint32_t, N> v) {
#if HWY_TARGET <= HWY_AVX3
template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_I64_D(D)>
HWY_API VFromD<D> PromoteInRangeTo(D /*di64*/, VFromD<Rebind<float, D>> v) {
-#if HWY_COMPILER_GCC_ACTUAL
+#if HWY_X86_HAVE_AVX10_2_OPS
+ return VFromD<D>{_mm_cvtts_ps_epi64(v.raw)};
+#elif HWY_COMPILER_GCC_ACTUAL
// Workaround for undefined behavior with GCC if any values of v[i] are not
// within the range of an int64_t
@@ -10677,6 +10895,9 @@ HWY_API VFromD<D> PromoteInRangeTo(D /*di64*/, VFromD<Rebind<float, D>> v) {
// Generic for all vector lengths.
template <class D, HWY_IF_I64_D(D)>
HWY_API VFromD<D> PromoteTo(D di64, VFromD<Rebind<float, D>> v) {
+#if HWY_X86_HAVE_AVX10_2_OPS
+ return PromoteInRangeTo(di64, v);
+#else
const Rebind<float, decltype(di64)> df32;
const RebindToFloat<decltype(di64)> df64;
// We now avoid GCC UB in PromoteInRangeTo via assembly, see #2189 and
@@ -10689,14 +10910,21 @@ HWY_API VFromD<D> PromoteTo(D di64, VFromD<Rebind<float, D>> v) {
di64, PromoteMaskTo(df64, df32, Ge(v, Set(df32, 9.223372e18f))));
return IfThenElse(overflow, Set(di64, LimitsMax<int64_t>()),
PromoteInRangeTo(di64, v));
+#endif
}
template <class D, HWY_IF_U64_D(D)>
-HWY_API VFromD<D> PromoteTo(D /* tag */, VFromD<Rebind<float, D>> v) {
- return PromoteInRangeTo(D(), ZeroIfNegative(v));
+HWY_API VFromD<D> PromoteTo(D du64, VFromD<Rebind<float, D>> v) {
+#if HWY_X86_HAVE_AVX10_2_OPS
+ return PromoteInRangeTo(du64, v);
+#else
+ return PromoteInRangeTo(du64, ZeroIfNegative(v));
+#endif
}
template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_U64_D(D)>
HWY_API VFromD<D> PromoteInRangeTo(D /* tag */, VFromD<Rebind<float, D>> v) {
-#if HWY_COMPILER_GCC_ACTUAL
+#if HWY_X86_HAVE_AVX10_2_OPS
+ return VFromD<D>{_mm_cvtts_ps_epu64(v.raw)};
+#elif HWY_COMPILER_GCC_ACTUAL
// Workaround for undefined behavior with GCC if any values of v[i] are not
// within the range of an uint64_t
@@ -11375,7 +11603,9 @@ HWY_API VFromD<D> ConvertTo(D /* tag */, VFromD<RebindToFloat<D>> v) {
template <class D, HWY_IF_V_SIZE_LE_D(D, 16), HWY_IF_I32_D(D)>
HWY_API VFromD<D> ConvertInRangeTo(D /*di*/, VFromD<RebindToFloat<D>> v) {
-#if HWY_COMPILER_GCC_ACTUAL
+#if HWY_X86_HAVE_AVX10_2_OPS
+ return VFromD<D>{_mm_cvtts_ps_epi32(v.raw)};
+#elif HWY_COMPILER_GCC_ACTUAL
// Workaround for undefined behavior in _mm_cvttps_epi32 with GCC if any
// values of v[i] are not within the range of an int32_t
@@ -11405,17 +11635,23 @@ HWY_API VFromD<D> ConvertInRangeTo(D /*di*/, VFromD<RebindToFloat<D>> v) {
// F32 to I32 ConvertTo is generic for all vector lengths
template <class D, HWY_IF_I32_D(D)>
HWY_API VFromD<D> ConvertTo(D di, VFromD<RebindToFloat<D>> v) {
+#if HWY_X86_HAVE_AVX10_2_OPS
+ return ConvertInRangeTo(di, v);
+#else
const RebindToFloat<decltype(di)> df;
// See comment at the first occurrence of "IfThenElse(overflow,".
const MFromD<D> overflow = RebindMask(di, Ge(v, Set(df, 2147483648.0f)));
return IfThenElse(overflow, Set(di, LimitsMax<int32_t>()),
ConvertInRangeTo(di, v));
+#endif
}
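A scalar model of the clamp (sketch): 2147483648.0f is the smallest float too large for int32_t, so only positive overflow needs the saturated path; the underlying truncating conversion already yields INT32_MIN for negative overflow and NaN on these targets.

    #include <cstdint>
    #include <limits>
    int32_t ConvertToI32Clamped(float v) {
      if (v >= 2147483648.0f) return std::numeric_limits<int32_t>::max();
      return static_cast<int32_t>(v);  // stand-in for ConvertInRangeTo
    }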
#if HWY_TARGET <= HWY_AVX3
template <class DI, HWY_IF_V_SIZE_LE_D(DI, 16), HWY_IF_I64_D(DI)>
HWY_API VFromD<DI> ConvertInRangeTo(DI /*di*/, VFromD<RebindToFloat<DI>> v) {
-#if HWY_COMPILER_GCC_ACTUAL
+#if HWY_X86_HAVE_AVX10_2_OPS
+ return VFromD<DI>{_mm_cvtts_pd_epi64(v.raw)};
+#elif HWY_COMPILER_GCC_ACTUAL
// Workaround for undefined behavior in _mm_cvttpd_epi64 with GCC if any
// values of v[i] are not within the range of an int64_t
@@ -11443,17 +11679,23 @@ HWY_API VFromD<DI> ConvertInRangeTo(DI /*di*/, VFromD<RebindToFloat<DI>> v) {
// F64 to I64 ConvertTo is generic for all vector lengths on AVX3
template <class DI, HWY_IF_I64_D(DI)>
HWY_API VFromD<DI> ConvertTo(DI di, VFromD<RebindToFloat<DI>> v) {
+#if HWY_X86_HAVE_AVX10_2_OPS
+ return ConvertInRangeTo(di, v);
+#else
const RebindToFloat<decltype(di)> df;
// See comment at the first occurrence of "IfThenElse(overflow,".
const MFromD<DI> overflow =
RebindMask(di, Ge(v, Set(df, 9.223372036854776e18)));
return IfThenElse(overflow, Set(di, LimitsMax<int64_t>()),
ConvertInRangeTo(di, v));
+#endif
}
template <class DU, HWY_IF_V_SIZE_LE_D(DU, 16), HWY_IF_U32_D(DU)>
HWY_API VFromD<DU> ConvertInRangeTo(DU /*du*/, VFromD<RebindToFloat<DU>> v) {
-#if HWY_COMPILER_GCC_ACTUAL
+#if HWY_X86_HAVE_AVX10_2_OPS
+ return VFromD<DU>{_mm_cvtts_ps_epu32(v.raw)};
+#elif HWY_COMPILER_GCC_ACTUAL
// Workaround for undefined behavior in _mm_cvttps_epu32 with GCC if any
// values of v[i] are not within the range of an uint32_t
@@ -11482,13 +11724,19 @@ HWY_API VFromD<DU> ConvertInRangeTo(DU /*du*/, VFromD<RebindToFloat<DU>> v) {
// F32->U32 ConvertTo is generic for all vector lengths
template <class DU, HWY_IF_U32_D(DU)>
-HWY_API VFromD<DU> ConvertTo(DU /*du*/, VFromD<RebindToFloat<DU>> v) {
- return ConvertInRangeTo(DU(), ZeroIfNegative(v));
+HWY_API VFromD<DU> ConvertTo(DU du32, VFromD<RebindToFloat<DU>> v) {
+#if HWY_X86_HAVE_AVX10_2_OPS
+ return ConvertInRangeTo(du32, v);
+#else
+ return ConvertInRangeTo(du32, ZeroIfNegative(v));
+#endif
}
template <class DU, HWY_IF_V_SIZE_LE_D(DU, 16), HWY_IF_U64_D(DU)>
HWY_API VFromD<DU> ConvertInRangeTo(DU /*du*/, VFromD<RebindToFloat<DU>> v) {
-#if HWY_COMPILER_GCC_ACTUAL
+#if HWY_X86_HAVE_AVX10_2_OPS
+ return VFromD<DU>{_mm_cvtts_pd_epu64(v.raw)};
+#elif HWY_COMPILER_GCC_ACTUAL
// Workaround for undefined behavior in _mm_cvttpd_epu64 with GCC if any
// values of v[i] are not within the range of an uint64_t
@@ -11515,8 +11763,12 @@ HWY_API VFromD<DU> ConvertInRangeTo(DU /*du*/, VFromD<RebindToFloat<DU>> v) {
// F64->U64 ConvertTo is generic for all vector lengths
template <class DU, HWY_IF_U64_D(DU)>
-HWY_API VFromD<DU> ConvertTo(DU /*du*/, VFromD<RebindToFloat<DU>> v) {
- return ConvertInRangeTo(DU(), ZeroIfNegative(v));
+HWY_API VFromD<DU> ConvertTo(DU du64, VFromD<RebindToFloat<DU>> v) {
+#if HWY_X86_HAVE_AVX10_2_OPS
+ return ConvertInRangeTo(du64, v);
+#else
+ return ConvertInRangeTo(du64, ZeroIfNegative(v));
+#endif
}
#else // AVX2 or below
@@ -12669,14 +12921,16 @@ HWY_API Vec128<T> CompressNot(Vec128<T> v, Mask128<T> mask) {
alignas(16) static constexpr uint64_t packed_array[16] = {
0x00000010, 0x00000001, 0x00000010, 0x00000010};
- // For lane i, shift the i-th 4-bit index down to bits [0, 2) -
- // _mm_permutexvar_epi64 will ignore the upper bits.
+ // For lane i, shift the i-th 4-bit index down to bits [0, 2).
const DFromV<decltype(v)> d;
const RebindToUnsigned<decltype(d)> du64;
const auto packed = Set(du64, packed_array[mask.raw]);
- alignas(16) static constexpr uint64_t shifts[2] = {0, 4};
- const auto indices = Indices128<T>{(packed >> Load(du64, shifts)).raw};
- return TableLookupLanes(v, indices);
+ alignas(16) static constexpr uint64_t kShifts[2] = {0, 4};
+ Vec128<uint64_t> indices = packed >> Load(du64, kShifts);
+ // _mm_permutevar_pd will ignore the upper bits, but TableLookupLanes uses
+ // a fallback in MSAN builds, so mask there.
+ HWY_IF_CONSTEXPR(HWY_IS_MSAN) indices &= Set(du64, 1);
+ return TableLookupLanes(v, Indices128<T>{indices.raw});
}
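A sketch of the nibble unpacking (standalone, not from the patch): each lane's index occupies its own 4-bit nibble of the packed constant, so lane i shifts by 4*i. For example, mask.raw == 0 selects the entry 0x00000010, which encodes the identity indices {0, 1}:

    #include <cstdint>
    constexpr uint64_t kPacked = 0x00000010;  // packed_array[0]
    static_assert(((kPacked >> 0) & 1) == 0, "lane 0 index");
    static_assert(((kPacked >> 4) & 1) == 1, "lane 1 index");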
// ------------------------------ CompressBlocksNot
diff --git a/third_party/highway/hwy/ops/x86_256-inl.h b/third_party/highway/hwy/ops/x86_256-inl.h
index 32df08497e..9361150f18 100644
--- a/third_party/highway/hwy/ops/x86_256-inl.h
+++ b/third_party/highway/hwy/ops/x86_256-inl.h
@@ -204,7 +204,6 @@ struct Mask256 {
template <typename T>
using Full256 = Simd<T, 32 / sizeof(T), 0>;
-
// ------------------------------ Zero
// Cannot use VFromD here because it is defined in terms of Zero.
@@ -556,78 +555,48 @@ HWY_API Vec256<double> Xor(Vec256<double> a, Vec256<double> b) {
return Vec256<double>{_mm256_xor_pd(a.raw, b.raw)};
}
-// ------------------------------ Not
-template <typename T>
-HWY_API Vec256<T> Not(const Vec256<T> v) {
- const DFromV<decltype(v)> d;
- using TU = MakeUnsigned<T>;
-#if HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN
- const __m256i vu = BitCast(RebindToUnsigned<decltype(d)>(), v).raw;
- return BitCast(d, Vec256<TU>{_mm256_ternarylogic_epi32(vu, vu, vu, 0x55)});
-#else
- return Xor(v, BitCast(d, Vec256<TU>{_mm256_set1_epi32(-1)}));
-#endif
-}
-
-// ------------------------------ Xor3
-template <typename T>
-HWY_API Vec256<T> Xor3(Vec256<T> x1, Vec256<T> x2, Vec256<T> x3) {
-#if HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN
- const DFromV<decltype(x1)> d;
- const RebindToUnsigned<decltype(d)> du;
- using VU = VFromD<decltype(du)>;
- const __m256i ret = _mm256_ternarylogic_epi64(
- BitCast(du, x1).raw, BitCast(du, x2).raw, BitCast(du, x3).raw, 0x96);
- return BitCast(d, VU{ret});
-#else
- return Xor(x1, Xor(x2, x3));
-#endif
-}
+#if HWY_X86_HAVE_TERNARY_LOGIC
+namespace detail {
-// ------------------------------ Or3
-template <typename T>
-HWY_API Vec256<T> Or3(Vec256<T> o1, Vec256<T> o2, Vec256<T> o3) {
-#if HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN
- const DFromV<decltype(o1)> d;
- const RebindToUnsigned<decltype(d)> du;
- using VU = VFromD<decltype(du)>;
- const __m256i ret = _mm256_ternarylogic_epi64(
- BitCast(du, o1).raw, BitCast(du, o2).raw, BitCast(du, o3).raw, 0xFE);
- return BitCast(d, VU{ret});
-#else
- return Or(o1, Or(o2, o3));
-#endif
-}
+// Per-target partial specialization.
+template <uint8_t kTernLogOp>
+struct TernaryLogicImpl<kTernLogOp, 32> {
+ template <class V>
+ HWY_INLINE V operator()(V a, V b, V c) const {
+ const DFromV<decltype(a)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ using VU = VFromD<decltype(du)>;
+ const __m256i ret = _mm256_ternarylogic_epi64(
+ BitCast(du, a).raw, BitCast(du, b).raw, BitCast(du, c).raw, kTernLogOp);
+ return BitCast(d, VU{ret});
+ }
+};
-// ------------------------------ OrAnd
-template <typename T>
-HWY_API Vec256<T> OrAnd(Vec256<T> o, Vec256<T> a1, Vec256<T> a2) {
-#if HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN
- const DFromV<decltype(o)> d;
- const RebindToUnsigned<decltype(d)> du;
- using VU = VFromD<decltype(du)>;
- const __m256i ret = _mm256_ternarylogic_epi64(
- BitCast(du, o).raw, BitCast(du, a1).raw, BitCast(du, a2).raw, 0xF8);
- return BitCast(d, VU{ret});
-#else
- return Or(o, And(a1, a2));
-#endif
-}
+// Same, but with writemask. If !mask, returns a.
+template <uint8_t kTernLogOp>
+struct MaskedTernaryLogicImpl<kTernLogOp, 32> {
+ template <class V, class D = DFromV<V>, HWY_IF_T_SIZE_D(D, 4)>
+ HWY_INLINE V operator()(MFromD<D> mask, V a, V b, V c) const {
+ const D d;
+ const RebindToUnsigned<decltype(d)> du;
+ using VU = VFromD<decltype(du)>;
+ const __m256i ret = _mm256_mask_ternarylogic_epi32(a.raw, mask.raw, b.raw,
+ c.raw, kTernLogOp);
+ return BitCast(d, VU{ret});
+ }
+ template <class V, class D = DFromV<V>, HWY_IF_T_SIZE_D(D, 8)>
+ HWY_INLINE V operator()(MFromD<D> mask, V a, V b, V c) const {
+ const D d;
+ const RebindToUnsigned<decltype(d)> du;
+ using VU = VFromD<decltype(du)>;
+ const __m256i ret = _mm256_mask_ternarylogic_epi64(a.raw, mask.raw, b.raw,
+ c.raw, kTernLogOp);
+ return BitCast(d, VU{ret});
+ }
+};
-// ------------------------------ IfVecThenElse
-template <typename T>
-HWY_API Vec256<T> IfVecThenElse(Vec256<T> mask, Vec256<T> yes, Vec256<T> no) {
-#if HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN
- const DFromV<decltype(yes)> d;
- const RebindToUnsigned<decltype(d)> du;
- using VU = VFromD<decltype(du)>;
- return BitCast(d, VU{_mm256_ternarylogic_epi64(BitCast(du, mask).raw,
- BitCast(du, yes).raw,
- BitCast(du, no).raw, 0xCA)});
-#else
- return IfThenElse(MaskFromVec(mask), yes, no);
-#endif
-}
+} // namespace detail
+#endif // HWY_X86_HAVE_TERNARY_LOGIC
// ------------------------------ Operator overloads (internal-only if float)
@@ -1787,6 +1756,68 @@ HWY_API Vec256<double> Max(const Vec256<double> a, const Vec256<double> b) {
return Vec256<double>{_mm256_max_pd(a.raw, b.raw)};
}
+// ------------------------------ MinNumber and MaxNumber
+
+#if HWY_X86_HAVE_AVX10_2_OPS
+
+#if HWY_HAVE_FLOAT16
+HWY_API Vec256<float16_t> MinNumber(Vec256<float16_t> a, Vec256<float16_t> b) {
+ return Vec256<float16_t>{_mm256_minmax_ph(a.raw, b.raw, 0x14)};
+}
+#endif
+HWY_API Vec256<float> MinNumber(Vec256<float> a, Vec256<float> b) {
+ return Vec256<float>{_mm256_minmax_ps(a.raw, b.raw, 0x14)};
+}
+HWY_API Vec256<double> MinNumber(Vec256<double> a, Vec256<double> b) {
+ return Vec256<double>{_mm256_minmax_pd(a.raw, b.raw, 0x14)};
+}
+
+#if HWY_HAVE_FLOAT16
+HWY_API Vec256<float16_t> MaxNumber(Vec256<float16_t> a, Vec256<float16_t> b) {
+ return Vec256<float16_t>{_mm256_minmax_ph(a.raw, b.raw, 0x15)};
+}
+#endif
+HWY_API Vec256<float> MaxNumber(Vec256<float> a, Vec256<float> b) {
+ return Vec256<float>{_mm256_minmax_ps(a.raw, b.raw, 0x15)};
+}
+HWY_API Vec256<double> MaxNumber(Vec256<double> a, Vec256<double> b) {
+ return Vec256<double>{_mm256_minmax_pd(a.raw, b.raw, 0x15)};
+}
+
+#endif
+
+// ------------------------------ MinMagnitude and MaxMagnitude
+
+#if HWY_X86_HAVE_AVX10_2_OPS
+
+#if HWY_HAVE_FLOAT16
+HWY_API Vec256<float16_t> MinMagnitude(Vec256<float16_t> a,
+ Vec256<float16_t> b) {
+ return Vec256<float16_t>{_mm256_minmax_ph(a.raw, b.raw, 0x16)};
+}
+#endif
+HWY_API Vec256<float> MinMagnitude(Vec256<float> a, Vec256<float> b) {
+ return Vec256<float>{_mm256_minmax_ps(a.raw, b.raw, 0x16)};
+}
+HWY_API Vec256<double> MinMagnitude(Vec256<double> a, Vec256<double> b) {
+ return Vec256<double>{_mm256_minmax_pd(a.raw, b.raw, 0x16)};
+}
+
+#if HWY_HAVE_FLOAT16
+HWY_API Vec256<float16_t> MaxMagnitude(Vec256<float16_t> a,
+ Vec256<float16_t> b) {
+ return Vec256<float16_t>{_mm256_minmax_ph(a.raw, b.raw, 0x17)};
+}
+#endif
+HWY_API Vec256<float> MaxMagnitude(Vec256<float> a, Vec256<float> b) {
+ return Vec256<float>{_mm256_minmax_ps(a.raw, b.raw, 0x17)};
+}
+HWY_API Vec256<double> MaxMagnitude(Vec256<double> a, Vec256<double> b) {
+ return Vec256<double>{_mm256_minmax_pd(a.raw, b.raw, 0x17)};
+}
+
+#endif
+
// ------------------------------ Iota
namespace detail {
@@ -2052,8 +2083,8 @@ HWY_INLINE Vec256<uint32_t> SumsOf4(hwy::UnsignedTag /*type_tag*/,
// ------------------------------ SumsOfAdjQuadAbsDiff
template <int kAOffset, int kBOffset>
-static Vec256<uint16_t> SumsOfAdjQuadAbsDiff(Vec256<uint8_t> a,
- Vec256<uint8_t> b) {
+HWY_API Vec256<uint16_t> SumsOfAdjQuadAbsDiff(Vec256<uint8_t> a,
+ Vec256<uint8_t> b) {
static_assert(0 <= kAOffset && kAOffset <= 1,
"kAOffset must be between 0 and 1");
static_assert(0 <= kBOffset && kBOffset <= 3,
@@ -2098,30 +2129,6 @@ HWY_API Vec256<int16_t> SaturatedAdd(Vec256<int16_t> a, Vec256<int16_t> b) {
return Vec256<int16_t>{_mm256_adds_epi16(a.raw, b.raw)};
}
-#if HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN
-HWY_API Vec256<int32_t> SaturatedAdd(Vec256<int32_t> a, Vec256<int32_t> b) {
- const DFromV<decltype(a)> d;
- const auto sum = a + b;
- const auto overflow_mask = MaskFromVec(
- Vec256<int32_t>{_mm256_ternarylogic_epi32(a.raw, b.raw, sum.raw, 0x42)});
- const auto i32_max = Set(d, LimitsMax<int32_t>());
- const Vec256<int32_t> overflow_result{_mm256_mask_ternarylogic_epi32(
- i32_max.raw, MaskFromVec(a).raw, i32_max.raw, i32_max.raw, 0x55)};
- return IfThenElse(overflow_mask, overflow_result, sum);
-}
-
-HWY_API Vec256<int64_t> SaturatedAdd(Vec256<int64_t> a, Vec256<int64_t> b) {
- const DFromV<decltype(a)> d;
- const auto sum = a + b;
- const auto overflow_mask = MaskFromVec(
- Vec256<int64_t>{_mm256_ternarylogic_epi64(a.raw, b.raw, sum.raw, 0x42)});
- const auto i64_max = Set(d, LimitsMax<int64_t>());
- const Vec256<int64_t> overflow_result{_mm256_mask_ternarylogic_epi64(
- i64_max.raw, MaskFromVec(a).raw, i64_max.raw, i64_max.raw, 0x55)};
- return IfThenElse(overflow_mask, overflow_result, sum);
-}
-#endif // HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN
-
// ------------------------------ SaturatedSub
// Returns a - b clamped to the destination range.
@@ -2142,35 +2149,8 @@ HWY_API Vec256<int16_t> SaturatedSub(Vec256<int16_t> a, Vec256<int16_t> b) {
return Vec256<int16_t>{_mm256_subs_epi16(a.raw, b.raw)};
}
-#if HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN
-HWY_API Vec256<int32_t> SaturatedSub(Vec256<int32_t> a, Vec256<int32_t> b) {
- const DFromV<decltype(a)> d;
- const auto diff = a - b;
- const auto overflow_mask = MaskFromVec(
- Vec256<int32_t>{_mm256_ternarylogic_epi32(a.raw, b.raw, diff.raw, 0x18)});
- const auto i32_max = Set(d, LimitsMax<int32_t>());
- const Vec256<int32_t> overflow_result{_mm256_mask_ternarylogic_epi32(
- i32_max.raw, MaskFromVec(a).raw, i32_max.raw, i32_max.raw, 0x55)};
- return IfThenElse(overflow_mask, overflow_result, diff);
-}
-
-HWY_API Vec256<int64_t> SaturatedSub(Vec256<int64_t> a, Vec256<int64_t> b) {
- const DFromV<decltype(a)> d;
- const auto diff = a - b;
- const auto overflow_mask = MaskFromVec(
- Vec256<int64_t>{_mm256_ternarylogic_epi64(a.raw, b.raw, diff.raw, 0x18)});
- const auto i64_max = Set(d, LimitsMax<int64_t>());
- const Vec256<int64_t> overflow_result{_mm256_mask_ternarylogic_epi64(
- i64_max.raw, MaskFromVec(a).raw, i64_max.raw, i64_max.raw, 0x55)};
- return IfThenElse(overflow_mask, overflow_result, diff);
-}
-#endif // HWY_TARGET <= HWY_AVX3 && !HWY_IS_MSAN
-
// ------------------------------ Average
-// Returns (a + b + 1) / 2
-
-// Unsigned
HWY_API Vec256<uint8_t> AverageRound(Vec256<uint8_t> a, Vec256<uint8_t> b) {
return Vec256<uint8_t>{_mm256_avg_epu8(a.raw, b.raw)};
}
@@ -4876,6 +4856,17 @@ HWY_API V InterleaveOddBlocks(D d, V a, V b) {
return ConcatUpperUpper(d, b, a);
}
+// ------------------------------ InterleaveLowerBlocks
+template <class D, class V = VFromD<D>, HWY_IF_V_SIZE_D(D, 32)>
+HWY_API V InterleaveLowerBlocks(D d, V a, V b) {
+ return InterleaveEvenBlocks(d, a, b);
+}
+// ------------------------------ InterleaveUpperBlocks
+template <class D, class V = VFromD<D>, HWY_IF_V_SIZE_D(D, 32)>
+HWY_API V InterleaveUpperBlocks(D d, V a, V b) {
+ return InterleaveOddBlocks(d, a, b);
+}
+
// ------------------------------ Reverse (RotateRight)
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_T_SIZE_D(D, 4)>
@@ -6341,17 +6332,6 @@ HWY_API VFromD<D> ReorderWidenMulAccumulate(D d, Vec256<int16_t> a,
#endif
}
-// ------------------------------ RearrangeToOddPlusEven
-HWY_API Vec256<int32_t> RearrangeToOddPlusEven(const Vec256<int32_t> sum0,
- Vec256<int32_t> /*sum1*/) {
- return sum0; // invariant already holds
-}
-
-HWY_API Vec256<uint32_t> RearrangeToOddPlusEven(const Vec256<uint32_t> sum0,
- Vec256<uint32_t> /*sum1*/) {
- return sum0; // invariant already holds
-}
-
// ------------------------------ SumOfMulQuadAccumulate
#if HWY_TARGET <= HWY_AVX3_DL
@@ -6363,7 +6343,24 @@ HWY_API VFromD<DI32> SumOfMulQuadAccumulate(
return VFromD<DI32>{_mm256_dpbusd_epi32(sum.raw, a_u.raw, b_i.raw)};
}
-#endif
+#if HWY_X86_HAVE_AVX10_2_OPS
+template <class DI32, HWY_IF_I32_D(DI32), HWY_IF_V_SIZE_D(DI32, 32)>
+HWY_API VFromD<DI32> SumOfMulQuadAccumulate(DI32 /*di32*/,
+ VFromD<Repartition<int8_t, DI32>> a,
+ VFromD<Repartition<int8_t, DI32>> b,
+ VFromD<DI32> sum) {
+ return VFromD<DI32>{_mm256_dpbssd_epi32(sum.raw, a.raw, b.raw)};
+}
+
+template <class DU32, HWY_IF_U32_D(DU32), HWY_IF_V_SIZE_D(DU32, 32)>
+HWY_API VFromD<DU32> SumOfMulQuadAccumulate(
+ DU32 /*du32*/, VFromD<Repartition<uint8_t, DU32>> a,
+ VFromD<Repartition<uint8_t, DU32>> b, VFromD<DU32> sum) {
+ return VFromD<DU32>{_mm256_dpbuud_epi32(sum.raw, a.raw, b.raw)};
+}
+#endif // HWY_X86_HAVE_AVX10_2_OPS
+
+#endif // HWY_TARGET <= HWY_AVX3_DL
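
The `_mm256_dpbssd_epi32` / `_mm256_dpbuud_epi32` paths above accumulate, into each 32-bit lane, the dot product of the four co-located 8-bit lanes (the non-saturating VPDPBSSD/VPDPBUUD forms). A minimal scalar sketch of the signed variant, for reference only:

#include <cstddef>
#include <cstdint>

// Illustrative scalar model, not used by the implementation: each 32-bit sum
// lane accumulates a[4i]*b[4i] + ... + a[4i+3]*b[4i+3].
void SumOfMulQuadAccumulateScalar(const int8_t* a, const int8_t* b,
                                  int32_t* sum, size_t num_i32_lanes) {
  for (size_t i = 0; i < num_i32_lanes; ++i) {
    int32_t acc = sum[i];
    for (size_t j = 0; j < 4; ++j) {
      acc += int32_t{a[4 * i + j]} * int32_t{b[4 * i + j]};
    }
    sum[i] = acc;
  }
}
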
// ================================================== CONVERT
@@ -6446,7 +6443,9 @@ HWY_API VFromD<D> PromoteTo(D /* tag */, Vec32<int8_t> v) {
#if HWY_TARGET <= HWY_AVX3
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_I64_D(D)>
HWY_API VFromD<D> PromoteInRangeTo(D /*di64*/, VFromD<Rebind<float, D>> v) {
-#if HWY_COMPILER_GCC_ACTUAL
+#if HWY_X86_HAVE_AVX10_2_OPS
+ return VFromD<D>{_mm256_cvtts_ps_epi64(v.raw)};
+#elif HWY_COMPILER_GCC_ACTUAL
// Workaround for undefined behavior with GCC if any values of v[i] are not
// within the range of an int64_t
@@ -6474,7 +6473,9 @@ HWY_API VFromD<D> PromoteInRangeTo(D /*di64*/, VFromD<Rebind<float, D>> v) {
}
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_U64_D(D)>
HWY_API VFromD<D> PromoteInRangeTo(D /* tag */, VFromD<Rebind<float, D>> v) {
-#if HWY_COMPILER_GCC_ACTUAL
+#if HWY_X86_HAVE_AVX10_2_OPS
+ return VFromD<D>{_mm256_cvtts_ps_epu64(v.raw)};
+#elif HWY_COMPILER_GCC_ACTUAL
// Workaround for undefined behavior with GCC if any values of v[i] are not
// within the range of an uint64_t
#if HWY_COMPILER_GCC_ACTUAL >= 700 && !HWY_IS_DEBUG_BUILD
@@ -6853,7 +6854,9 @@ HWY_API VFromD<D> DemoteTo(D /* tag */, Vec256<double> v) {
template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_I32_D(D)>
HWY_API VFromD<D> DemoteInRangeTo(D /* tag */, Vec256<double> v) {
-#if HWY_COMPILER_GCC_ACTUAL
+#if HWY_X86_HAVE_AVX10_2_OPS
+ return VFromD<D>{_mm256_cvtts_pd_epi32(v.raw)};
+#elif HWY_COMPILER_GCC_ACTUAL
// Workaround for undefined behavior in _mm256_cvttpd_epi32 with GCC if any
// values of v[i] are not within the range of an int32_t
@@ -6883,7 +6886,9 @@ HWY_API VFromD<D> DemoteInRangeTo(D /* tag */, Vec256<double> v) {
#if HWY_TARGET <= HWY_AVX3
template <class D, HWY_IF_V_SIZE_D(D, 16), HWY_IF_U32_D(D)>
HWY_API VFromD<D> DemoteInRangeTo(D /* tag */, Vec256<double> v) {
-#if HWY_COMPILER_GCC_ACTUAL
+#if HWY_X86_HAVE_AVX10_2_OPS
+ return VFromD<D>{_mm256_cvtts_pd_epu32(v.raw)};
+#elif HWY_COMPILER_GCC_ACTUAL
// Workaround for undefined behavior in _mm256_cvttpd_epu32 with GCC if any
// values of v[i] are not within the range of an uint32_t
@@ -7178,7 +7183,9 @@ HWY_API VFromD<D> ConvertInRangeTo(D /* tag */, VFromD<RebindToFloat<D>> v) {
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_I32_D(D)>
HWY_API VFromD<D> ConvertInRangeTo(D /*d*/, Vec256<float> v) {
-#if HWY_COMPILER_GCC_ACTUAL
+#if HWY_X86_HAVE_AVX10_2_OPS
+ return VFromD<D>{_mm256_cvtts_ps_epi32(v.raw)};
+#elif HWY_COMPILER_GCC_ACTUAL
// Workaround for undefined behavior in _mm256_cvttps_epi32 with GCC if any
// values of v[i] are not within the range of an int32_t
@@ -7212,7 +7219,9 @@ HWY_API VFromD<D> ConvertInRangeTo(D /*d*/, Vec256<float> v) {
#if HWY_TARGET <= HWY_AVX3
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_I64_D(D)>
HWY_API VFromD<D> ConvertInRangeTo(D /*di*/, Vec256<double> v) {
-#if HWY_COMPILER_GCC_ACTUAL
+#if HWY_X86_HAVE_AVX10_2_OPS
+ return VFromD<D>{_mm256_cvtts_pd_epi64(v.raw)};
+#elif HWY_COMPILER_GCC_ACTUAL
// Workaround for undefined behavior in _mm256_cvttpd_epi64 with GCC if any
// values of v[i] are not within the range of an int64_t
@@ -7240,7 +7249,9 @@ HWY_API VFromD<D> ConvertInRangeTo(D /*di*/, Vec256<double> v) {
}
template <class DU, HWY_IF_V_SIZE_D(DU, 32), HWY_IF_U32_D(DU)>
HWY_API VFromD<DU> ConvertInRangeTo(DU /*du*/, VFromD<RebindToFloat<DU>> v) {
-#if HWY_COMPILER_GCC_ACTUAL
+#if HWY_X86_HAVE_AVX10_2_OPS
+ return VFromD<DU>{_mm256_cvtts_ps_epu32(v.raw)};
+#elif HWY_COMPILER_GCC_ACTUAL
// Workaround for undefined behavior in _mm256_cvttps_epu32 with GCC if any
// values of v[i] are not within the range of an uint32_t
@@ -7280,7 +7291,9 @@ HWY_API VFromD<DU> ConvertInRangeTo(DU /*du*/, VFromD<RebindToFloat<DU>> v) {
}
template <class DU, HWY_IF_V_SIZE_D(DU, 32), HWY_IF_U64_D(DU)>
HWY_API VFromD<DU> ConvertInRangeTo(DU /*du*/, VFromD<RebindToFloat<DU>> v) {
-#if HWY_COMPILER_GCC_ACTUAL
+#if HWY_X86_HAVE_AVX10_2_OPS
+ return VFromD<DU>{_mm256_cvtts_pd_epu64(v.raw)};
+#elif HWY_COMPILER_GCC_ACTUAL
// Workaround for undefined behavior in _mm256_cvttpd_epu64 with GCC if any
// values of v[i] are not within the range of an uint64_t
diff --git a/third_party/highway/hwy/ops/x86_512-inl.h b/third_party/highway/hwy/ops/x86_512-inl.h
index 9fc52d23f0..b821a1c2c1 100644
--- a/third_party/highway/hwy/ops/x86_512-inl.h
+++ b/third_party/highway/hwy/ops/x86_512-inl.h
@@ -637,16 +637,48 @@ HWY_API VFromD<D> Iota(D d, const T2 first) {
// ================================================== LOGICAL
-// ------------------------------ Not
+#if HWY_X86_HAVE_TERNARY_LOGIC
+namespace detail {
-template <typename T>
-HWY_API Vec512<T> Not(const Vec512<T> v) {
- const DFromV<decltype(v)> d;
- const RebindToUnsigned<decltype(d)> du;
- using VU = VFromD<decltype(du)>;
- const __m512i vu = BitCast(du, v).raw;
- return BitCast(d, VU{_mm512_ternarylogic_epi32(vu, vu, vu, 0x55)});
-}
+// Per-target partial specialization.
+template <uint8_t kTernLogOp>
+struct TernaryLogicImpl<kTernLogOp, 64> {
+ template <class V>
+ HWY_INLINE V operator()(V a, V b, V c) const {
+ const DFromV<decltype(a)> d;
+ const RebindToUnsigned<decltype(d)> du;
+ using VU = VFromD<decltype(du)>;
+ const __m512i ret = _mm512_ternarylogic_epi64(
+ BitCast(du, a).raw, BitCast(du, b).raw, BitCast(du, c).raw, kTernLogOp);
+ return BitCast(d, VU{ret});
+ }
+};
+
+// Same, but with writemask. If !mask, returns a.
+template <uint8_t kTernLogOp>
+struct MaskedTernaryLogicImpl<kTernLogOp, 64> {
+ template <class V, class D = DFromV<V>, HWY_IF_T_SIZE_D(D, 4)>
+ HWY_INLINE V operator()(MFromD<D> mask, V a, V b, V c) const {
+ const D d;
+ const RebindToUnsigned<decltype(d)> du;
+ using VU = VFromD<decltype(du)>;
+ const __m512i ret = _mm512_mask_ternarylogic_epi32(a.raw, mask.raw, b.raw,
+ c.raw, kTernLogOp);
+ return BitCast(d, VU{ret});
+ }
+ template <class V, class D = DFromV<V>, HWY_IF_T_SIZE_D(D, 8)>
+ HWY_INLINE V operator()(MFromD<D> mask, V a, V b, V c) const {
+ const D d;
+ const RebindToUnsigned<decltype(d)> du;
+ using VU = VFromD<decltype(du)>;
+ const __m512i ret = _mm512_mask_ternarylogic_epi64(a.raw, mask.raw, b.raw,
+ c.raw, kTernLogOp);
+ return BitCast(d, VU{ret});
+ }
+};
+
+} // namespace detail
+#endif // HWY_X86_HAVE_TERNARY_LOGIC
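
Background for `TernaryLogicImpl`: `_mm512_ternarylogic_epi32/epi64` evaluate, per bit position, an arbitrary three-input boolean function whose truth table is the 8-bit immediate; the output bit is bit `(a << 2) | (b << 1) | c` of `kTernLogOp`. This is where the constants in the helpers removed below come from (0x96 = Xor3, 0xFE = Or3, 0xF8 = OrAnd, 0xCA = IfVecThenElse). A scalar sketch of the lookup:

#include <cstdint>

// Illustrative scalar model: each result bit is looked up in the 8-entry
// truth table `imm`.
uint64_t TernaryLogicScalar(uint64_t a, uint64_t b, uint64_t c, uint8_t imm) {
  uint64_t out = 0;
  for (int bit = 0; bit < 64; ++bit) {
    const uint64_t idx =
        (((a >> bit) & 1) << 2) | (((b >> bit) & 1) << 1) | ((c >> bit) & 1);
    out |= ((imm >> idx) & uint64_t{1}) << bit;
  }
  return out;
}
// For example, TernaryLogicScalar(a, b, c, 0x96) == (a ^ b ^ c), because bit k
// of 0x96 is the parity of the three bits of k.
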
// ------------------------------ And
@@ -718,66 +750,6 @@ HWY_API Vec512<double> Xor(const Vec512<double> a, const Vec512<double> b) {
return Vec512<double>{_mm512_xor_pd(a.raw, b.raw)};
}
-// ------------------------------ Xor3
-template <typename T>
-HWY_API Vec512<T> Xor3(Vec512<T> x1, Vec512<T> x2, Vec512<T> x3) {
-#if !HWY_IS_MSAN
- const DFromV<decltype(x1)> d;
- const RebindToUnsigned<decltype(d)> du;
- using VU = VFromD<decltype(du)>;
- const __m512i ret = _mm512_ternarylogic_epi64(
- BitCast(du, x1).raw, BitCast(du, x2).raw, BitCast(du, x3).raw, 0x96);
- return BitCast(d, VU{ret});
-#else
- return Xor(x1, Xor(x2, x3));
-#endif
-}
-
-// ------------------------------ Or3
-template <typename T>
-HWY_API Vec512<T> Or3(Vec512<T> o1, Vec512<T> o2, Vec512<T> o3) {
-#if !HWY_IS_MSAN
- const DFromV<decltype(o1)> d;
- const RebindToUnsigned<decltype(d)> du;
- using VU = VFromD<decltype(du)>;
- const __m512i ret = _mm512_ternarylogic_epi64(
- BitCast(du, o1).raw, BitCast(du, o2).raw, BitCast(du, o3).raw, 0xFE);
- return BitCast(d, VU{ret});
-#else
- return Or(o1, Or(o2, o3));
-#endif
-}
-
-// ------------------------------ OrAnd
-template <typename T>
-HWY_API Vec512<T> OrAnd(Vec512<T> o, Vec512<T> a1, Vec512<T> a2) {
-#if !HWY_IS_MSAN
- const DFromV<decltype(o)> d;
- const RebindToUnsigned<decltype(d)> du;
- using VU = VFromD<decltype(du)>;
- const __m512i ret = _mm512_ternarylogic_epi64(
- BitCast(du, o).raw, BitCast(du, a1).raw, BitCast(du, a2).raw, 0xF8);
- return BitCast(d, VU{ret});
-#else
- return Or(o, And(a1, a2));
-#endif
-}
-
-// ------------------------------ IfVecThenElse
-template <typename T>
-HWY_API Vec512<T> IfVecThenElse(Vec512<T> mask, Vec512<T> yes, Vec512<T> no) {
-#if !HWY_IS_MSAN
- const DFromV<decltype(yes)> d;
- const RebindToUnsigned<decltype(d)> du;
- using VU = VFromD<decltype(du)>;
- return BitCast(d, VU{_mm512_ternarylogic_epi64(BitCast(du, mask).raw,
- BitCast(du, yes).raw,
- BitCast(du, no).raw, 0xCA)});
-#else
- return IfThenElse(MaskFromVec(mask), yes, no);
-#endif
-}
-
// ------------------------------ Operator overloads (internal-only if float)
template <typename T>
@@ -1734,6 +1706,68 @@ HWY_API Vec512<double> Max(Vec512<double> a, Vec512<double> b) {
return Vec512<double>{_mm512_max_pd(a.raw, b.raw)};
}
+// ------------------------------ MinNumber and MaxNumber
+
+#if HWY_X86_HAVE_AVX10_2_OPS
+
+#if HWY_HAVE_FLOAT16
+HWY_API Vec512<float16_t> MinNumber(Vec512<float16_t> a, Vec512<float16_t> b) {
+ return Vec512<float16_t>{_mm512_minmax_ph(a.raw, b.raw, 0x14)};
+}
+#endif
+HWY_API Vec512<float> MinNumber(Vec512<float> a, Vec512<float> b) {
+ return Vec512<float>{_mm512_minmax_ps(a.raw, b.raw, 0x14)};
+}
+HWY_API Vec512<double> MinNumber(Vec512<double> a, Vec512<double> b) {
+ return Vec512<double>{_mm512_minmax_pd(a.raw, b.raw, 0x14)};
+}
+
+#if HWY_HAVE_FLOAT16
+HWY_API Vec512<float16_t> MaxNumber(Vec512<float16_t> a, Vec512<float16_t> b) {
+ return Vec512<float16_t>{_mm512_minmax_ph(a.raw, b.raw, 0x15)};
+}
+#endif
+HWY_API Vec512<float> MaxNumber(Vec512<float> a, Vec512<float> b) {
+ return Vec512<float>{_mm512_minmax_ps(a.raw, b.raw, 0x15)};
+}
+HWY_API Vec512<double> MaxNumber(Vec512<double> a, Vec512<double> b) {
+ return Vec512<double>{_mm512_minmax_pd(a.raw, b.raw, 0x15)};
+}
+
+#endif
+
+// ------------------------------ MinMagnitude and MaxMagnitude
+
+#if HWY_X86_HAVE_AVX10_2_OPS
+
+#if HWY_HAVE_FLOAT16
+HWY_API Vec512<float16_t> MinMagnitude(Vec512<float16_t> a,
+ Vec512<float16_t> b) {
+ return Vec512<float16_t>{_mm512_minmax_ph(a.raw, b.raw, 0x16)};
+}
+#endif
+HWY_API Vec512<float> MinMagnitude(Vec512<float> a, Vec512<float> b) {
+ return Vec512<float>{_mm512_minmax_ps(a.raw, b.raw, 0x16)};
+}
+HWY_API Vec512<double> MinMagnitude(Vec512<double> a, Vec512<double> b) {
+ return Vec512<double>{_mm512_minmax_pd(a.raw, b.raw, 0x16)};
+}
+
+#if HWY_HAVE_FLOAT16
+HWY_API Vec512<float16_t> MaxMagnitude(Vec512<float16_t> a,
+ Vec512<float16_t> b) {
+ return Vec512<float16_t>{_mm512_minmax_ph(a.raw, b.raw, 0x17)};
+}
+#endif
+HWY_API Vec512<float> MaxMagnitude(Vec512<float> a, Vec512<float> b) {
+ return Vec512<float>{_mm512_minmax_ps(a.raw, b.raw, 0x17)};
+}
+HWY_API Vec512<double> MaxMagnitude(Vec512<double> a, Vec512<double> b) {
+ return Vec512<double>{_mm512_minmax_pd(a.raw, b.raw, 0x17)};
+}
+
+#endif
+
// ------------------------------ Integer multiplication
// Unsigned
@@ -1858,11 +1892,19 @@ HWY_API V GetExponent(V v) {
#endif
template <class V, HWY_IF_F32(TFromV<V>), HWY_IF_V_SIZE_V(V, 64)>
HWY_API V GetExponent(V v) {
+ // Work around warnings in the intrinsic definitions (passing -1 as a mask).
+ HWY_DIAGNOSTICS(push)
+ HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion")
return V{_mm512_getexp_ps(v.raw)};
+ HWY_DIAGNOSTICS(pop)
}
template <class V, HWY_IF_F64(TFromV<V>), HWY_IF_V_SIZE_V(V, 64)>
HWY_API V GetExponent(V v) {
+ // Work around warnings in the intrinsic definitions (passing -1 as a mask).
+ HWY_DIAGNOSTICS(push)
+ HWY_DIAGNOSTICS_OFF(disable : 4245 4365, ignored "-Wsign-conversion")
return V{_mm512_getexp_pd(v.raw)};
+ HWY_DIAGNOSTICS(pop)
}
// ------------------------------ MaskedMinOr
@@ -4786,6 +4828,65 @@ HWY_API Vec512<T> InterleaveOddBlocks(Full512<T> d, Vec512<T> a, Vec512<T> b) {
return OddEvenBlocks(b, SlideDownBlocks<1>(d, a));
}
+// ------------------------------ InterleaveLowerBlocks (TwoTablesLookupLanes)
+
+// Note that _mm512_shuffle_f32x4 etc. can only use `a` to populate the lower
+// half of the result, so we would require at least two instructions. We instead
+// use table lookups.
+
+template <typename T>
+HWY_API Vec512<T> InterleaveLowerBlocks(Full512<T> d, Vec512<T> a,
+ Vec512<T> b) {
+ const Repartition<uint64_t, decltype(d)> du64;
+ HWY_ALIGN static constexpr int64_t kIdx[8] = {0, 1, 8, 9, 2, 3, 10, 11};
+ const auto idx = SetTableIndices(du64, kIdx);
+ return BitCast(d,
+ TwoTablesLookupLanes(BitCast(du64, a), BitCast(du64, b), idx));
+}
+
+HWY_API Vec512<float> InterleaveLowerBlocks(Full512<float> d, Vec512<float> a,
+ Vec512<float> b) {
+ HWY_ALIGN static constexpr int32_t kIdx[16] = {0, 1, 2, 3, 16, 17, 18, 19,
+ 4, 5, 6, 7, 20, 21, 22, 23};
+ const auto idx = SetTableIndices(d, kIdx);
+ return TwoTablesLookupLanes(a, b, idx);
+}
+
+HWY_API Vec512<double> InterleaveLowerBlocks(Full512<double> d,
+ Vec512<double> a,
+ Vec512<double> b) {
+ HWY_ALIGN static constexpr int64_t kIdx[8] = {0, 1, 8, 9, 2, 3, 10, 11};
+ const auto idx = SetTableIndices(d, kIdx);
+ return TwoTablesLookupLanes(a, b, idx);
+}
+
+// ------------------------------ InterleaveUpperBlocks (TwoTablesLookupLanes)
+template <typename T>
+HWY_API Vec512<T> InterleaveUpperBlocks(Full512<T> d, Vec512<T> a,
+ Vec512<T> b) {
+ const Repartition<uint64_t, decltype(d)> du64;
+ HWY_ALIGN static constexpr int64_t kIdx[8] = {4, 5, 12, 13, 6, 7, 14, 15};
+ const auto idx = SetTableIndices(du64, kIdx);
+ return BitCast(
+ d, TwoTablesLookupLanes(du64, BitCast(du64, a), BitCast(du64, b), idx));
+}
+
+HWY_API Vec512<float> InterleaveUpperBlocks(Full512<float> d, Vec512<float> a,
+ Vec512<float> b) {
+ HWY_ALIGN static constexpr int32_t kIdx[16] = {
+ 8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31};
+ const auto idx = SetTableIndices(d, kIdx);
+ return TwoTablesLookupLanes(a, b, idx);
+}
+
+HWY_API Vec512<double> InterleaveUpperBlocks(Full512<double> d,
+ Vec512<double> a,
+ Vec512<double> b) {
+ HWY_ALIGN static constexpr int64_t kIdx[8] = {4, 5, 12, 13, 6, 7, 14, 15};
+ const auto idx = SetTableIndices(d, kIdx);
+ return TwoTablesLookupLanes(a, b, idx);
+}
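
To audit the index tables: viewing a 512-bit vector as four 128-bit blocks, with a = [A0 A1 A2 A3] and b = [B0 B1 B2 B3], InterleaveLowerBlocks returns [A0 B0 A1 B1] and InterleaveUpperBlocks returns [A2 B2 A3 B3]. In u64 lanes (two per block), block k of `a` occupies table lanes {2k, 2k+1} and block k of `b` occupies lanes {8+2k, 8+2k+1}, which yields exactly the {0,1,8,9,2,3,10,11} and {4,5,12,13,6,7,14,15} patterns used above.
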
+
// ------------------------------ ReverseBlocks
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_NOT_FLOAT3264_D(D)>
@@ -5503,7 +5604,9 @@ HWY_API VFromD<D> PromoteTo(D /* tag */, Vec256<uint32_t> v) {
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_I64_D(D)>
HWY_API VFromD<D> PromoteInRangeTo(D /*di64*/, VFromD<Rebind<float, D>> v) {
-#if HWY_COMPILER_GCC_ACTUAL
+#if HWY_X86_HAVE_AVX10_2_OPS
+ return VFromD<D>{_mm512_cvtts_ps_epi64(v.raw)};
+#elif HWY_COMPILER_GCC_ACTUAL
// Workaround for undefined behavior with GCC if any values of v[i] are not
// within the range of an int64_t
@@ -5535,7 +5638,9 @@ HWY_API VFromD<D> PromoteInRangeTo(D /*di64*/, VFromD<Rebind<float, D>> v) {
}
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_U64_D(D)>
HWY_API VFromD<D> PromoteInRangeTo(D /* tag */, VFromD<Rebind<float, D>> v) {
-#if HWY_COMPILER_GCC_ACTUAL
+#if HWY_X86_HAVE_AVX10_2_OPS
+ return VFromD<D>{_mm512_cvtts_ps_epu64(v.raw)};
+#elif HWY_COMPILER_GCC_ACTUAL
// Workaround for undefined behavior with GCC if any values of v[i] are not
// within the range of an uint64_t
@@ -5823,7 +5928,9 @@ HWY_API VFromD<D> DemoteTo(D /* tag */, Vec512<double> v) {
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_I32_D(D)>
HWY_API VFromD<D> DemoteInRangeTo(D /* tag */, Vec512<double> v) {
-#if HWY_COMPILER_GCC_ACTUAL
+#if HWY_X86_HAVE_AVX10_2_OPS
+ return VFromD<D>{_mm512_cvtts_pd_epi32(v.raw)};
+#elif HWY_COMPILER_GCC_ACTUAL
// Workaround for undefined behavior in _mm512_cvttpd_epi32 with GCC if any
// values of v[i] are not within the range of an int32_t
@@ -5831,7 +5938,8 @@ HWY_API VFromD<D> DemoteInRangeTo(D /* tag */, Vec512<double> v) {
if (detail::IsConstantX86VecForF2IConv<int32_t>(v)) {
typedef double GccF64RawVectType __attribute__((__vector_size__(64)));
const auto raw_v = reinterpret_cast<GccF64RawVectType>(v.raw);
- return VFromD<D>{_mm256_setr_epi32(
+ return VFromD<D>{
+ _mm256_setr_epi32(
detail::X86ConvertScalarFromFloat<int32_t>(raw_v[0]),
detail::X86ConvertScalarFromFloat<int32_t>(raw_v[1]),
detail::X86ConvertScalarFromFloat<int32_t>(raw_v[2]),
@@ -5839,7 +5947,8 @@ HWY_API VFromD<D> DemoteInRangeTo(D /* tag */, Vec512<double> v) {
detail::X86ConvertScalarFromFloat<int32_t>(raw_v[4]),
detail::X86ConvertScalarFromFloat<int32_t>(raw_v[5]),
detail::X86ConvertScalarFromFloat<int32_t>(raw_v[6]),
- detail::X86ConvertScalarFromFloat<int32_t>(raw_v[7]))};
+ detail::X86ConvertScalarFromFloat<int32_t>(raw_v[7]))
+ };
}
#endif
@@ -5856,7 +5965,9 @@ HWY_API VFromD<D> DemoteInRangeTo(D /* tag */, Vec512<double> v) {
template <class D, HWY_IF_V_SIZE_D(D, 32), HWY_IF_U32_D(D)>
HWY_API VFromD<D> DemoteInRangeTo(D /* tag */, Vec512<double> v) {
-#if HWY_COMPILER_GCC_ACTUAL
+#if HWY_X86_HAVE_AVX10_2_OPS
+ return VFromD<D>{_mm512_cvtts_pd_epu32(v.raw)};
+#elif HWY_COMPILER_GCC_ACTUAL
// Workaround for undefined behavior in _mm512_cvttpd_epu32 with GCC if any
// values of v[i] are not within the range of an uint32_t
@@ -6202,7 +6313,9 @@ HWY_API VFromD<D> ConvertInRangeTo(D /* tag */, VFromD<RebindToFloat<D>> v) {
#endif // HWY_HAVE_FLOAT16
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_I32_D(D)>
HWY_API VFromD<D> ConvertInRangeTo(D /*d*/, Vec512<float> v) {
-#if HWY_COMPILER_GCC_ACTUAL
+#if HWY_X86_HAVE_AVX10_2_OPS
+ return VFromD<D>{_mm512_cvtts_ps_epi32(v.raw)};
+#elif HWY_COMPILER_GCC_ACTUAL
// Workaround for undefined behavior in _mm512_cvttps_epi32 with GCC if any
// values of v[i] are not within the range of an int32_t
@@ -6242,7 +6355,9 @@ HWY_API VFromD<D> ConvertInRangeTo(D /*d*/, Vec512<float> v) {
}
template <class D, HWY_IF_V_SIZE_D(D, 64), HWY_IF_I64_D(D)>
HWY_API VFromD<D> ConvertInRangeTo(D /*di*/, Vec512<double> v) {
-#if HWY_COMPILER_GCC_ACTUAL
+#if HWY_X86_HAVE_AVX10_2_OPS
+ return VFromD<D>{_mm512_cvtts_pd_epi64(v.raw)};
+#elif HWY_COMPILER_GCC_ACTUAL
// Workaround for undefined behavior in _mm512_cvttpd_epi64 with GCC if any
// values of v[i] are not within the range of an int64_t
@@ -6274,7 +6389,9 @@ HWY_API VFromD<D> ConvertInRangeTo(D /*di*/, Vec512<double> v) {
}
template <class DU, HWY_IF_V_SIZE_D(DU, 64), HWY_IF_U32_D(DU)>
HWY_API VFromD<DU> ConvertInRangeTo(DU /*du*/, VFromD<RebindToFloat<DU>> v) {
-#if HWY_COMPILER_GCC_ACTUAL
+#if HWY_X86_HAVE_AVX10_2_OPS
+ return VFromD<DU>{_mm512_cvtts_ps_epu32(v.raw)};
+#elif HWY_COMPILER_GCC_ACTUAL
// Workaround for undefined behavior in _mm512_cvttps_epu32 with GCC if any
// values of v[i] are not within the range of an uint32_t
@@ -6330,7 +6447,9 @@ HWY_API VFromD<DU> ConvertInRangeTo(DU /*du*/, VFromD<RebindToFloat<DU>> v) {
}
template <class DU, HWY_IF_V_SIZE_D(DU, 64), HWY_IF_U64_D(DU)>
HWY_API VFromD<DU> ConvertInRangeTo(DU /*du*/, VFromD<RebindToFloat<DU>> v) {
-#if HWY_COMPILER_GCC_ACTUAL
+#if HWY_X86_HAVE_AVX10_2_OPS
+ return VFromD<DU>{_mm512_cvtts_pd_epu64(v.raw)};
+#elif HWY_COMPILER_GCC_ACTUAL
// Workaround for undefined behavior in _mm512_cvttpd_epu64 with GCC if any
// values of v[i] are not within the range of an uint64_t
@@ -6656,74 +6775,32 @@ HWY_API Vec512<uint64_t> CLMulUpper(Vec512<uint64_t> va, Vec512<uint64_t> vb) {
// SumsOfAdjShufQuadAbsDiff)
template <int kAOffset, int kBOffset>
-static Vec512<uint16_t> SumsOfAdjQuadAbsDiff(Vec512<uint8_t> a,
- Vec512<uint8_t> b) {
+HWY_API Vec512<uint16_t> SumsOfAdjQuadAbsDiff(Vec512<uint8_t> a,
+ Vec512<uint8_t> b) {
static_assert(0 <= kAOffset && kAOffset <= 1,
"kAOffset must be between 0 and 1");
static_assert(0 <= kBOffset && kBOffset <= 3,
"kBOffset must be between 0 and 3");
+#if HWY_X86_HAVE_AVX10_2_OPS
+ // AVX10.2 now has the _mm512_mpsadbw_epu8 intrinsic available
+ return Vec512<uint16_t>{_mm512_mpsadbw_epu8(
+ a.raw, b.raw,
+ (kAOffset << 5) | (kBOffset << 3) | (kAOffset << 2) | kBOffset)};
+#else
const DFromV<decltype(a)> d;
const RepartitionToWideX2<decltype(d)> du32;
- // While AVX3 does not have a _mm512_mpsadbw_epu8 intrinsic, the
- // SumsOfAdjQuadAbsDiff operation is implementable for 512-bit vectors on
- // AVX3 using SumsOfShuffledQuadAbsDiff and U32 Broadcast.
+ // The _mm512_mpsadbw_epu8 intrinsic is not available prior to AVX10.2.
+ // The SumsOfAdjQuadAbsDiff operation is implementable for 512-bit vectors on
+ // pre-AVX10.2 targets that support AVX3 using SumsOfShuffledQuadAbsDiff and
+ // U32 Broadcast.
return SumsOfShuffledQuadAbsDiff<kAOffset + 2, kAOffset + 1, kAOffset + 1,
kAOffset>(
a, BitCast(d, Broadcast<kBOffset>(BitCast(du32, b))));
+#endif
}
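
For reference, MPSADBW computes, within each 128-bit block, eight 16-bit sums of absolute differences between a sliding four-byte window of `a` (starting at byte `kAOffset*4 + i`) and one fixed four-byte group of `b` (selected by `kBOffset`). A scalar sketch of one 16-byte block, for illustration only:

#include <cstdint>
#include <cstdlib>

// Illustrative scalar model, not used by the implementation:
// out[i] = sum over j = 0..3 of |a[kAOffset*4 + i + j] - b[kBOffset*4 + j]|,
// with kAOffset in [0, 1] and kBOffset in [0, 3].
void MpsadbwBlockScalar(const uint8_t a[16], const uint8_t b[16], int kAOffset,
                        int kBOffset, uint16_t out[8]) {
  for (int i = 0; i < 8; ++i) {
    uint16_t sum = 0;
    for (int j = 0; j < 4; ++j) {
      sum += static_cast<uint16_t>(
          std::abs(int{a[kAOffset * 4 + i + j]} - int{b[kBOffset * 4 + j]}));
    }
    out[i] = sum;
  }
}
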
-#if !HWY_IS_MSAN
-// ------------------------------ I32/I64 SaturatedAdd (MaskFromVec)
-
-HWY_API Vec512<int32_t> SaturatedAdd(Vec512<int32_t> a, Vec512<int32_t> b) {
- const DFromV<decltype(a)> d;
- const auto sum = a + b;
- const auto overflow_mask = MaskFromVec(
- Vec512<int32_t>{_mm512_ternarylogic_epi32(a.raw, b.raw, sum.raw, 0x42)});
- const auto i32_max = Set(d, LimitsMax<int32_t>());
- const Vec512<int32_t> overflow_result{_mm512_mask_ternarylogic_epi32(
- i32_max.raw, MaskFromVec(a).raw, i32_max.raw, i32_max.raw, 0x55)};
- return IfThenElse(overflow_mask, overflow_result, sum);
-}
-
-HWY_API Vec512<int64_t> SaturatedAdd(Vec512<int64_t> a, Vec512<int64_t> b) {
- const DFromV<decltype(a)> d;
- const auto sum = a + b;
- const auto overflow_mask = MaskFromVec(
- Vec512<int64_t>{_mm512_ternarylogic_epi64(a.raw, b.raw, sum.raw, 0x42)});
- const auto i64_max = Set(d, LimitsMax<int64_t>());
- const Vec512<int64_t> overflow_result{_mm512_mask_ternarylogic_epi64(
- i64_max.raw, MaskFromVec(a).raw, i64_max.raw, i64_max.raw, 0x55)};
- return IfThenElse(overflow_mask, overflow_result, sum);
-}
-
-// ------------------------------ I32/I64 SaturatedSub (MaskFromVec)
-
-HWY_API Vec512<int32_t> SaturatedSub(Vec512<int32_t> a, Vec512<int32_t> b) {
- const DFromV<decltype(a)> d;
- const auto diff = a - b;
- const auto overflow_mask = MaskFromVec(
- Vec512<int32_t>{_mm512_ternarylogic_epi32(a.raw, b.raw, diff.raw, 0x18)});
- const auto i32_max = Set(d, LimitsMax<int32_t>());
- const Vec512<int32_t> overflow_result{_mm512_mask_ternarylogic_epi32(
- i32_max.raw, MaskFromVec(a).raw, i32_max.raw, i32_max.raw, 0x55)};
- return IfThenElse(overflow_mask, overflow_result, diff);
-}
-
-HWY_API Vec512<int64_t> SaturatedSub(Vec512<int64_t> a, Vec512<int64_t> b) {
- const DFromV<decltype(a)> d;
- const auto diff = a - b;
- const auto overflow_mask = MaskFromVec(
- Vec512<int64_t>{_mm512_ternarylogic_epi64(a.raw, b.raw, diff.raw, 0x18)});
- const auto i64_max = Set(d, LimitsMax<int64_t>());
- const Vec512<int64_t> overflow_result{_mm512_mask_ternarylogic_epi64(
- i64_max.raw, MaskFromVec(a).raw, i64_max.raw, i64_max.raw, 0x55)};
- return IfThenElse(overflow_mask, overflow_result, diff);
-}
-#endif // !HWY_IS_MSAN
-
// ------------------------------ Mask testing
// Beware: the suffix indicates the number of mask bits, not lane size!
@@ -7527,16 +7604,6 @@ HWY_API VFromD<D> ReorderWidenMulAccumulate(D d, Vec512<int16_t> a,
#endif
}
-HWY_API Vec512<int32_t> RearrangeToOddPlusEven(const Vec512<int32_t> sum0,
- Vec512<int32_t> /*sum1*/) {
- return sum0; // invariant already holds
-}
-
-HWY_API Vec512<uint32_t> RearrangeToOddPlusEven(const Vec512<uint32_t> sum0,
- Vec512<uint32_t> /*sum1*/) {
- return sum0; // invariant already holds
-}
-
// ------------------------------ SumOfMulQuadAccumulate
#if HWY_TARGET <= HWY_AVX3_DL
@@ -7548,6 +7615,23 @@ HWY_API VFromD<DI32> SumOfMulQuadAccumulate(
return VFromD<DI32>{_mm512_dpbusd_epi32(sum.raw, a_u.raw, b_i.raw)};
}
+#if HWY_X86_HAVE_AVX10_2_OPS
+template <class DI32, HWY_IF_I32_D(DI32), HWY_IF_V_SIZE_D(DI32, 64)>
+HWY_API VFromD<DI32> SumOfMulQuadAccumulate(DI32 /*di32*/,
+ VFromD<Repartition<int8_t, DI32>> a,
+ VFromD<Repartition<int8_t, DI32>> b,
+ VFromD<DI32> sum) {
+ return VFromD<DI32>{_mm512_dpbssd_epi32(sum.raw, a.raw, b.raw)};
+}
+
+template <class DU32, HWY_IF_U32_D(DU32), HWY_IF_V_SIZE_D(DU32, 64)>
+HWY_API VFromD<DU32> SumOfMulQuadAccumulate(
+ DU32 /*du32*/, VFromD<Repartition<uint8_t, DU32>> a,
+ VFromD<Repartition<uint8_t, DU32>> b, VFromD<DU32> sum) {
+ return VFromD<DU32>{_mm512_dpbuud_epi32(sum.raw, a.raw, b.raw)};
+}
+#endif // HWY_X86_HAVE_AVX10_2_OPS
+
#endif
// ------------------------------ Reductions
diff --git a/third_party/highway/hwy/ops/x86_avx3-inl.h b/third_party/highway/hwy/ops/x86_avx3-inl.h
index 80f9488c6e..7cd1d1774a 100644
--- a/third_party/highway/hwy/ops/x86_avx3-inl.h
+++ b/third_party/highway/hwy/ops/x86_avx3-inl.h
@@ -15,15 +15,9 @@
// External include guard in highway.h - see comment there.
-#if HWY_TARGET == HWY_AVX10_2
-// For AVX10 targets that only support 256-bit or smaller vectors. Already
-// includes base.h and shared-inl.h.
-#include "third_party/highway/hwy/ops/x86_256-inl.h"
-#else
// For AVX3/AVX10 targets that support 512-byte vectors. Already includes base.h
// and shared-inl.h.
#include "third_party/highway/hwy/ops/x86_512-inl.h"
-#endif
// AVX3/AVX10 ops that have dependencies on ops defined in x86_512-inl.h if
// HWY_MAX_BYTES >= 64 is true are defined below
diff --git a/third_party/highway/hwy/per_target.cc b/third_party/highway/hwy/per_target.cc
new file mode 100644
index 0000000000..da8867cb9f
--- /dev/null
+++ b/third_party/highway/hwy/per_target.cc
@@ -0,0 +1,77 @@
+// Copyright 2022 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Enable all targets so that calling Have* does not call into a null pointer.
+#ifndef HWY_COMPILE_ALL_ATTAINABLE
+#define HWY_COMPILE_ALL_ATTAINABLE
+#endif
+#include "third_party/highway/hwy/per_target.h"
+
+#include <stddef.h>
+#include <stdint.h>
+
+#undef HWY_TARGET_INCLUDE
+#define HWY_TARGET_INCLUDE "hwy/per_target.cc"
+#include "third_party/highway/hwy/foreach_target.h" // IWYU pragma: keep
+#include "third_party/highway/hwy/highway.h"
+
+HWY_BEFORE_NAMESPACE();
+namespace hwy {
+namespace HWY_NAMESPACE {
+namespace {
+int64_t GetTarget() { return HWY_TARGET; }
+size_t GetVectorBytes() { return Lanes(ScalableTag<uint8_t>()); }
+bool GetHaveInteger64() { return HWY_HAVE_INTEGER64 != 0; }
+bool GetHaveFloat16() { return HWY_HAVE_FLOAT16 != 0; }
+bool GetHaveFloat64() { return HWY_HAVE_FLOAT64 != 0; }
+} // namespace
+// NOLINTNEXTLINE(google-readability-namespace-comments)
+} // namespace HWY_NAMESPACE
+
+} // namespace hwy
+HWY_AFTER_NAMESPACE();
+
+#if HWY_ONCE
+namespace hwy {
+namespace {
+HWY_EXPORT(GetTarget);
+HWY_EXPORT(GetVectorBytes);
+HWY_EXPORT(GetHaveInteger64);
+HWY_EXPORT(GetHaveFloat16);
+HWY_EXPORT(GetHaveFloat64);
+} // namespace
+
+HWY_DLLEXPORT int64_t DispatchedTarget() {
+ return HWY_DYNAMIC_DISPATCH(GetTarget)();
+}
+
+HWY_DLLEXPORT size_t VectorBytes() {
+ return HWY_DYNAMIC_DISPATCH(GetVectorBytes)();
+}
+
+HWY_DLLEXPORT bool HaveInteger64() {
+ return HWY_DYNAMIC_DISPATCH(GetHaveInteger64)();
+}
+
+HWY_DLLEXPORT bool HaveFloat16() {
+ return HWY_DYNAMIC_DISPATCH(GetHaveFloat16)();
+}
+
+HWY_DLLEXPORT bool HaveFloat64() {
+ return HWY_DYNAMIC_DISPATCH(GetHaveFloat64)();
+}
+
+} // namespace hwy
+#endif // HWY_ONCE
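
A minimal usage sketch for the exports above, assuming only the declarations in hwy/per_target.h:

#include <stdio.h>

#include "third_party/highway/hwy/per_target.h"

int main() {
  // Report which target dynamic dispatch selected on this CPU, plus the
  // capabilities of that target.
  printf("target=%lld vector_bytes=%zu i64=%d f16=%d f64=%d\n",
         static_cast<long long>(hwy::DispatchedTarget()), hwy::VectorBytes(),
         hwy::HaveInteger64(), hwy::HaveFloat16(), hwy::HaveFloat64());
  return 0;
}
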
diff --git a/third_party/highway/hwy/perf_counters.cc b/third_party/highway/hwy/perf_counters.cc
new file mode 100644
index 0000000000..376f3ac3cd
--- /dev/null
+++ b/third_party/highway/hwy/perf_counters.cc
@@ -0,0 +1,377 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "third_party/highway/hwy/perf_counters.h"
+
+#include "third_party/highway/hwy/detect_compiler_arch.h" // HWY_OS_LINUX
+
+#if HWY_OS_LINUX || HWY_IDE
+#include <errno.h>
+#include <fcntl.h> // open
+#include <linux/perf_event.h>
+#include <stddef.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <string.h> // strcmp
+#include <sys/ioctl.h>
+#include <sys/prctl.h>
+#include <sys/stat.h> // O_RDONLY
+#include <sys/syscall.h>
+#include <sys/utsname.h>
+#include <unistd.h>
+
+#include <string>
+#include <vector>
+
+#include "third_party/highway/hwy/base.h" // HWY_ASSERT
+#include "third_party/highway/hwy/bit_set.h"
+#include "third_party/highway/hwy/timer.h"
+
+#endif // HWY_OS_LINUX || HWY_IDE
+
+namespace hwy {
+namespace platform {
+
+#if HWY_OS_LINUX || HWY_IDE
+
+namespace {
+
+bool PerfCountersSupported() {
+ // Documented way: per perf_event_open(2), checking for the existence of
+ // /proc/sys/kernel/perf_event_paranoid is how to detect support.
+ struct stat s;
+ return stat("/proc/sys/kernel/perf_event_paranoid", &s) == 0;
+}
+
+// If we detect Linux < 6.9 and AMD EPYC, use cycles instead of ref-cycles
+// because the latter is not supported and returns 0, see
+// https://lwn.net/Articles/967791/.
+uint64_t RefCyclesOrCycles() {
+ const uint32_t ref_cycles = PERF_COUNT_HW_REF_CPU_CYCLES;
+
+ utsname buf;
+ if (uname(&buf) != 0) return ref_cycles;
+ if (std::string(buf.sysname) != "Linux") return ref_cycles;
+ int major, minor;
+ if (sscanf(buf.release, "%d.%d", &major, &minor) != 2) return ref_cycles;
+ if (major > 6 || (major == 6 && minor >= 9)) return ref_cycles;
+
+ // AMD Zen4 CPU
+ char cpu100[100];
+ if (!GetCpuString(cpu100)) return ref_cycles;
+ if (std::string(cpu100).rfind("AMD EPYC", 0) != 0) return ref_cycles;
+
+ return PERF_COUNT_HW_CPU_CYCLES;
+}
+
+struct CounterConfig { // for perf_event_open
+ uint64_t config;
+ uint32_t type;
+ PerfCounters::Counter c;
+};
+
+std::vector<CounterConfig> AllCounterConfigs() {
+ constexpr uint32_t kHW = PERF_TYPE_HARDWARE;
+ constexpr uint32_t kSW = PERF_TYPE_SOFTWARE;
+ constexpr uint32_t kC = PERF_TYPE_HW_CACHE;
+ constexpr uint64_t kL3 = PERF_COUNT_HW_CACHE_LL;
+ constexpr uint64_t kLoad = uint64_t{PERF_COUNT_HW_CACHE_OP_READ} << 8;
+ constexpr uint64_t kStore = uint64_t{PERF_COUNT_HW_CACHE_OP_WRITE} << 8;
+ constexpr uint64_t kAcc = uint64_t{PERF_COUNT_HW_CACHE_RESULT_ACCESS} << 16;
+
+ // Order is important for bin-packing event groups. x86 can only handle two
+ // LLC-related events per group, so spread them out and arrange SW events
+ // such that they do not start a new group. This list of counters may change.
+ return {{RefCyclesOrCycles(), kHW, PerfCounters::kRefCycles},
+ {PERF_COUNT_HW_INSTRUCTIONS, kHW, PerfCounters::kInstructions},
+ {PERF_COUNT_SW_PAGE_FAULTS, kSW, PerfCounters::kPageFaults},
+ {kL3 | kLoad | kAcc, kC, PerfCounters::kL3Loads},
+ {kL3 | kStore | kAcc, kC, PerfCounters::kL3Stores},
+ {PERF_COUNT_HW_BRANCH_INSTRUCTIONS, kHW, PerfCounters::kBranches},
+ {PERF_COUNT_HW_BRANCH_MISSES, kHW, PerfCounters::kBranchMispredicts},
+ // Second group:
+ {PERF_COUNT_HW_BUS_CYCLES, kHW, PerfCounters::kBusCycles},
+ {PERF_COUNT_SW_CPU_MIGRATIONS, kSW, PerfCounters::kMigrations},
+ {PERF_COUNT_HW_CACHE_REFERENCES, kHW, PerfCounters::kCacheRefs},
+ {PERF_COUNT_HW_CACHE_MISSES, kHW, PerfCounters::kCacheMisses}};
+}
+
+size_t& PackedIdx(PerfCounters::Counter c) {
+ static size_t packed_idx[64];
+ return packed_idx[static_cast<size_t>(c)];
+}
+
+class PMU {
+ static perf_event_attr MakeAttr(const CounterConfig& cc) {
+ perf_event_attr attr = {};
+ attr.type = cc.type;
+ attr.size = sizeof(attr);
+ attr.config = cc.config;
+ // We request more counters than the HW may support. If so, they are
+ // multiplexed and only active for a fraction of the runtime. Recording the
+ // times lets us extrapolate. GROUP enables a single syscall to reduce the
+ // cost of reading.
+ attr.read_format = PERF_FORMAT_TOTAL_TIME_ENABLED |
+ PERF_FORMAT_TOTAL_TIME_RUNNING | PERF_FORMAT_GROUP;
+ // Do not set inherit=1 because that conflicts with PERF_FORMAT_GROUP.
+ // Do not set disable=1, so that perf_event_open verifies all events in the
+ // group can be scheduled together.
+ attr.exclude_kernel = 1; // required if perf_event_paranoid == 1
+ attr.exclude_hv = 1; // = hypervisor
+ return attr;
+ }
+
+ static int SysPerfEventOpen(const CounterConfig& cc, int leader_fd) {
+ perf_event_attr attr = MakeAttr(cc);
+ const int pid = 0; // current process (cannot also be -1)
+ const int cpu = -1; // any CPU
+ // Retry if interrupted by signals; this actually happens (b/64774091).
+ for (int retry = 0; retry < 10; ++retry) {
+ const int flags = 0;
+ const int fd = static_cast<int>(
+ syscall(__NR_perf_event_open, &attr, pid, cpu, leader_fd, flags));
+ if (!(fd == -1 && errno == EINTR)) return fd;
+ }
+ HWY_WARN("perf_event_open retries were insufficient.");
+ return -1;
+ }
+
+ // Reads from `fd`; recovers from interruptions before/during the read.
+ static bool ReadBytes(int fd, ssize_t size, void* to) {
+ uint8_t* bytes = reinterpret_cast<uint8_t*>(to);
+ ssize_t pos = 0;
+ for (int retry = 0; retry < 10; ++retry) {
+ const ssize_t bytes_read =
+ read(fd, bytes + pos, static_cast<size_t>(size - pos));
+ if (HWY_UNLIKELY(bytes_read <= 0)) {
+ if (errno == EINTR) continue;
+ HWY_WARN("perf read() failed, errno %d.", errno);
+ return false;
+ }
+ pos += bytes_read;
+ HWY_ASSERT(pos <= size);
+ if (HWY_LIKELY(pos == size)) return true; // success
+ }
+ HWY_WARN("perf read() wanted %d bytes, got %d.", static_cast<int>(size),
+ static_cast<int>(pos));
+ return false;
+ }
+
+ // Array size in Buf; this is another upper bound on group size. It should be
+ // loose because it only wastes a bit of stack space, whereas an unnecessary
+ // extra group decreases coverage. Most HW supports 4-8 counters per group.
+ static constexpr size_t kMaxEventsPerGroup = PerfCounters::kCapacity;
+
+#pragma pack(push, 1)
+ struct Buf {
+ uint64_t num_events;
+ uint64_t time_enabled;
+ uint64_t time_running;
+ uint64_t values[kMaxEventsPerGroup];
+ };
+#pragma pack(pop)
+
+ // Returns false on error, otherwise sets `extrapolate` and `values`.
+ static bool ReadAndExtrapolate(int fd, size_t num_events, double& extrapolate,
+ double* HWY_RESTRICT values) {
+ Buf buf;
+ const ssize_t want_bytes = // size of var-len `Buf`
+ static_cast<ssize_t>(24 + num_events * sizeof(uint64_t));
+ if (HWY_UNLIKELY(!ReadBytes(fd, want_bytes, &buf))) return false;
+
+ HWY_DASSERT(num_events == buf.num_events);
+ HWY_DASSERT(buf.time_running <= buf.time_enabled);
+ // If the group was not yet scheduled, we must avoid division by zero.
+ // In case counters were previously running and not reset, their current
+ // values may be nonzero. Returning zero could be interpreted as counters
+ // running backwards, so we instead treat this as a failure and mark the
+ // counters as invalid.
+ if (HWY_UNLIKELY(buf.time_running == 0)) return false;
+
+ // Extrapolate each value.
+ extrapolate = static_cast<double>(buf.time_enabled) /
+ static_cast<double>(buf.time_running);
+ for (size_t i = 0; i < buf.num_events; ++i) {
+ values[i] = static_cast<double>(buf.values[i]) * extrapolate;
+ }
+ return true;
+ }
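
A concrete instance of the extrapolation above: if a group was enabled for 10 ms but scheduled on the PMU for only 2.5 ms due to multiplexing, `extrapolate` is 10 / 2.5 = 4, so a raw count of 1000 is reported as 4000. Counter values are therefore estimates whenever `time_running < time_enabled`.
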
+
+ public:
+ bool Init() {
+ // Allow callers who do not know about each other to each call `Init`.
+ // If this already succeeded, we're done; if not, we will try again.
+ if (HWY_UNLIKELY(!fds_.empty())) return true;
+ if (HWY_UNLIKELY(!PerfCountersSupported())) {
+ HWY_WARN(
+ "This Linux does not support perf counters. The program will"
+ "continue, but counters will return zero.");
+ return false;
+ }
+
+ groups_.push_back(Group());
+ fds_.reserve(PerfCounters::kCapacity);
+
+ for (const CounterConfig& config : AllCounterConfigs()) {
+ // If the group is limited by our buffer size, add a new one.
+ if (HWY_UNLIKELY(groups_.back().num_events == kMaxEventsPerGroup)) {
+ groups_.push_back(Group());
+ }
+
+ int fd = SysPerfEventOpen(config, groups_.back().leader_fd);
+ // Retry in case the group is limited by HW capacity. Do not check
+ // errno because it is too inconsistent (ENOSPC, EINVAL, others?).
+ if (HWY_UNLIKELY(fd < 0)) {
+ fd = SysPerfEventOpen(config, /*leader_fd=*/-1);
+ if (fd >= 0 && groups_.back().num_events != 0) {
+ groups_.push_back(Group());
+ }
+ }
+
+ if (HWY_UNLIKELY(fd < 0)) {
+ HWY_WARN("perf_event_open %d errno %d for counter %s.", fd, errno,
+ PerfCounters::Name(config.c));
+ } else {
+ // Add to group and set as leader if empty.
+ if (groups_.back().leader_fd == -1) {
+ groups_.back().leader_fd = fd;
+
+ // Ensure the leader is not a SW event, because adding an HW
+ // event to a group with only SW events is slow, and starting
+ // with SW may trigger a bug, see
+ // https://lore.kernel.org/lkml/tip-a1150c202207cc8501bebc45b63c264f91959260@git.kernel.org/
+ if (HWY_UNLIKELY(config.type == PERF_TYPE_SOFTWARE)) {
+ HWY_WARN("SW event %s should not be leader.",
+ PerfCounters::Name(config.c));
+ }
+ }
+
+ PackedIdx(config.c) = fds_.size();
+ groups_.back().num_events += 1;
+ valid_.Set(static_cast<size_t>(config.c));
+ fds_.push_back(fd);
+ }
+ }
+
+ // If no counters are available, remove the empty group.
+ if (HWY_UNLIKELY(fds_.empty())) {
+ HWY_ASSERT(groups_.size() == 1);
+ HWY_ASSERT(groups_.back().num_events == 0);
+ HWY_ASSERT(groups_.back().leader_fd == -1);
+ groups_.clear();
+ }
+
+ size_t num_valid = 0;
+ for (const Group& group : groups_) {
+ num_valid += group.num_events;
+ // All groups have a leader and are not empty.
+ HWY_ASSERT(group.leader_fd >= 0);
+ HWY_ASSERT(0 != group.num_events &&
+ group.num_events <= kMaxEventsPerGroup);
+ }
+ // Total `num_events` matches `fds_` and `Valid()`.
+ HWY_ASSERT(num_valid == fds_.size());
+ HWY_ASSERT(num_valid == valid_.Count());
+ HWY_ASSERT(num_valid <= PerfCounters::kCapacity);
+
+ if (num_valid) {
+ StopAllAndReset();
+ return true;
+ } else {
+ HWY_WARN("No valid counters found.");
+ return true;
+ }
+ }
+
+ bool StartAll() {
+ if (HWY_UNLIKELY(fds_.empty())) return false;
+ HWY_ASSERT(prctl(PR_TASK_PERF_EVENTS_ENABLE) == 0);
+ return true;
+ }
+
+ void StopAllAndReset() {
+ HWY_ASSERT(prctl(PR_TASK_PERF_EVENTS_DISABLE) == 0);
+ for (int fd : fds_) {
+ HWY_ASSERT(ioctl(fd, PERF_EVENT_IOC_RESET, 0) == 0);
+ }
+ }
+
+ // Returns false on error, otherwise sets `valid`, `max_extrapolate`, and
+ // `values`.
+ bool Read(BitSet64& valid, double& max_extrapolate, double* values) {
+ if (HWY_UNLIKELY(!valid_.Any())) return false;
+
+ // Read all counters into buffer in the order in which they were opened.
+ max_extrapolate = 1.0;
+ double* pos = values;
+ for (const Group& group : groups_) {
+ double extrapolate;
+ if (HWY_UNLIKELY(!ReadAndExtrapolate(group.leader_fd, group.num_events,
+ extrapolate, pos))) {
+ return false;
+ }
+ max_extrapolate = HWY_MAX(max_extrapolate, extrapolate);
+ pos += group.num_events;
+ }
+
+ valid = valid_;
+ HWY_DASSERT(pos == values + valid.Count());
+ return true;
+ }
+
+ private:
+ std::vector<int> fds_; // one per valid_
+ BitSet64 valid_;
+
+ struct Group {
+ size_t num_events = 0;
+ int leader_fd = -1;
+ };
+ std::vector<Group> groups_;
+};
+
+// Monostate, see header.
+PMU& GetPMU() {
+ static PMU& pmu = *new PMU(); // avoids exit-dtor warning (no dtor required)
+ return pmu;
+}
+
+} // namespace
+
+HWY_DLLEXPORT bool PerfCounters::Init() { return GetPMU().Init(); }
+HWY_DLLEXPORT bool PerfCounters::StartAll() { return GetPMU().StartAll(); }
+HWY_DLLEXPORT void PerfCounters::StopAllAndReset() {
+ GetPMU().StopAllAndReset();
+}
+HWY_DLLEXPORT PerfCounters::PerfCounters() {
+ if (HWY_UNLIKELY(!GetPMU().Read(valid_, max_extrapolate_, values_))) {
+ valid_ = BitSet64();
+ max_extrapolate_ = 0.0;
+ hwy::ZeroBytes(values_, sizeof(values_));
+ }
+}
+HWY_DLLEXPORT size_t PerfCounters::IndexForCounter(Counter c) {
+ return PackedIdx(c);
+}
+#else
+HWY_DLLEXPORT bool PerfCounters::Init() { return false; }
+HWY_DLLEXPORT bool PerfCounters::StartAll() { return false; }
+HWY_DLLEXPORT void PerfCounters::StopAllAndReset() {}
+HWY_DLLEXPORT PerfCounters::PerfCounters()
+ : max_extrapolate_(1.0), values_{0.0} {}
+HWY_DLLEXPORT size_t PerfCounters::IndexForCounter(Counter) { return 0; }
+#endif // HWY_OS_LINUX || HWY_IDE
+
+} // namespace platform
+} // namespace hwy
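
A usage sketch based only on the members defined in this file (Init, StartAll, StopAllAndReset, and the reading constructor); accessors for individual counters are declared in perf_counters.h:

#include "third_party/highway/hwy/perf_counters.h"

void MeasureRegion() {
  using hwy::platform::PerfCounters;
  // Init may be called multiple times; it retries if a prior call failed.
  if (!PerfCounters::Init() || !PerfCounters::StartAll()) return;
  // ... code under measurement ...
  const PerfCounters counters;  // the constructor snapshots all valid counters
  PerfCounters::StopAllAndReset();
  (void)counters;  // inspect via the accessors declared in the header
}
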
diff --git a/third_party/highway/hwy/perf_counters.h b/third_party/highway/hwy/perf_counters.h
index 2764e985bd..32efd3fb3a 100644
--- a/third_party/highway/hwy/perf_counters.h
+++ b/third_party/highway/hwy/perf_counters.h
@@ -82,7 +82,7 @@ class PerfCounters {
case kMigrations:
return "migration";
default:
- HWY_ABORT("Bug: unknown counter %d", c);
+ HWY_UNREACHABLE;
}
}
diff --git a/third_party/highway/hwy/print-inl.h b/third_party/highway/hwy/print-inl.h
index 16cfa141eb..5ecb04948d 100644
--- a/third_party/highway/hwy/print-inl.h
+++ b/third_party/highway/hwy/print-inl.h
@@ -15,6 +15,8 @@
// Print() function
+#include <stddef.h>
+
#include "third_party/highway/hwy/highway.h"
#include "third_party/highway/hwy/print.h"
diff --git a/third_party/highway/hwy/print.cc b/third_party/highway/hwy/print.cc
new file mode 100644
index 0000000000..723950e363
--- /dev/null
+++ b/third_party/highway/hwy/print.cc
@@ -0,0 +1,138 @@
+// Copyright 2022 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "third_party/highway/hwy/print.h"
+
+#include <stdio.h>
+
+#include "third_party/highway/hwy/base.h"
+#include "third_party/highway/hwy/detect_compiler_arch.h"
+
+namespace hwy {
+namespace detail {
+
+HWY_DLLEXPORT void TypeName(const TypeInfo& info, size_t N, char* string100) {
+ const char prefix = info.is_float ? 'f' : (info.is_signed ? 'i' : 'u');
+ // Omit the xN suffix for scalars.
+ if (N == 1) {
+ // NOLINTNEXTLINE
+ snprintf(string100, 64, "%c%d", prefix,
+ static_cast<int>(info.sizeof_t * 8));
+ } else {
+ // NOLINTNEXTLINE
+ snprintf(string100, 64, "%c%dx%d", prefix,
+ static_cast<int>(info.sizeof_t * 8), static_cast<int>(N));
+ }
+}
+
+// The NOLINT comments suppress the warning about passing 100 instead of
+// `sizeof(string100)`, which here would be the size of a pointer, not 100.
+HWY_DLLEXPORT void ToString(const TypeInfo& info, const void* ptr,
+ char* string100) {
+ if (info.sizeof_t == 1) {
+ if (info.is_signed) {
+ int8_t byte;
+ CopyBytes<1>(ptr, &byte); // endian-safe: we ensured sizeof(T)=1.
+ snprintf(string100, 100, "%d", byte); // NOLINT
+ } else {
+ uint8_t byte;
+ CopyBytes<1>(ptr, &byte); // endian-safe: we ensured sizeof(T)=1.
+ snprintf(string100, 100, "0x%02X", byte); // NOLINT
+ }
+ } else if (info.sizeof_t == 2) {
+ if (info.is_bf16) {
+ const double value = static_cast<double>(F32FromBF16Mem(ptr));
+ // NOLINTNEXTLINE
+ snprintf(string100, 100, hwy::ScalarAbs(value) < 1E-3 ? "%.3E" : "%.3f",
+ value);
+ } else if (info.is_float) {
+ const double value = static_cast<double>(F32FromF16Mem(ptr));
+ // NOLINTNEXTLINE
+ snprintf(string100, 100, hwy::ScalarAbs(value) < 1E-4 ? "%.4E" : "%.4f",
+ value);
+ } else {
+ uint16_t bits;
+ CopyBytes<2>(ptr, &bits);
+ snprintf(string100, 100, "0x%04X", bits); // NOLINT
+ }
+ } else if (info.sizeof_t == 4) {
+ if (info.is_float) {
+ float value;
+ CopyBytes<4>(ptr, &value);
+ // NOLINTNEXTLINE
+ snprintf(string100, 100, hwy::ScalarAbs(value) < 1E-6f ? "%.9E" : "%.9f",
+ static_cast<double>(value));
+ } else if (info.is_signed) {
+ int32_t value;
+ CopyBytes<4>(ptr, &value);
+ snprintf(string100, 100, "%d", value); // NOLINT
+ } else {
+ uint32_t value;
+ CopyBytes<4>(ptr, &value);
+ snprintf(string100, 100, "%u", value); // NOLINT
+ }
+ } else if (info.sizeof_t == 8) {
+ if (info.is_float) {
+ double value;
+ CopyBytes<8>(ptr, &value);
+ // NOLINTNEXTLINE
+ snprintf(string100, 100, hwy::ScalarAbs(value) < 1E-9 ? "%.18E" : "%.18f",
+ value);
+ } else {
+ const uint8_t* ptr8 = reinterpret_cast<const uint8_t*>(ptr);
+ uint32_t lo, hi;
+ CopyBytes<4>(ptr8 + (HWY_IS_LITTLE_ENDIAN ? 0 : 4), &lo);
+ CopyBytes<4>(ptr8 + (HWY_IS_LITTLE_ENDIAN ? 4 : 0), &hi);
+ snprintf(string100, 100, "0x%08x%08x", hi, lo); // NOLINT
+ }
+ } else if (info.sizeof_t == 16) {
+ HWY_ASSERT(!info.is_float && !info.is_signed && !info.is_bf16);
+ const uint8_t* ptr8 = reinterpret_cast<const uint8_t*>(ptr);
+ uint32_t words[4];
+ CopyBytes<4>(ptr8 + (HWY_IS_LITTLE_ENDIAN ? 0 : 12), &words[0]);
+ CopyBytes<4>(ptr8 + (HWY_IS_LITTLE_ENDIAN ? 4 : 8), &words[1]);
+ CopyBytes<4>(ptr8 + (HWY_IS_LITTLE_ENDIAN ? 8 : 4), &words[2]);
+ CopyBytes<4>(ptr8 + (HWY_IS_LITTLE_ENDIAN ? 12 : 0), &words[3]);
+ // NOLINTNEXTLINE
+ snprintf(string100, 100, "0x%08x%08x_%08x%08x", words[3], words[2],
+ words[1], words[0]);
+ }
+}
+
+HWY_DLLEXPORT void PrintArray(const TypeInfo& info, const char* caption,
+ const void* array_void, size_t N, size_t lane_u,
+ size_t max_lanes) {
+ const uint8_t* array_bytes = reinterpret_cast<const uint8_t*>(array_void);
+
+ char type_name[100];
+ TypeName(info, N, type_name);
+
+ const intptr_t lane = intptr_t(lane_u);
+ const size_t begin = static_cast<size_t>(HWY_MAX(0, lane - 2));
+ const size_t end = HWY_MIN(begin + max_lanes, N);
+ fprintf(stderr, "%s %s [%d+ ->]:\n ", type_name, caption,
+ static_cast<int>(begin));
+ for (size_t i = begin; i < end; ++i) {
+ const void* ptr = array_bytes + i * info.sizeof_t;
+ char str[100];
+ ToString(info, ptr, str);
+ fprintf(stderr, "%s,", str);
+ }
+ if (begin >= end) fprintf(stderr, "(out of bounds)");
+ fprintf(stderr, "\n");
+}
+
+} // namespace detail
+} // namespace hwy
diff --git a/third_party/highway/hwy/profiler.cc b/third_party/highway/hwy/profiler.cc
new file mode 100644
index 0000000000..b17a32ae1f
--- /dev/null
+++ b/third_party/highway/hwy/profiler.cc
@@ -0,0 +1,146 @@
+// Copyright 2025 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "third_party/highway/hwy/profiler.h"
+
+#include "third_party/highway/hwy/highway_export.h" // HWY_DLLEXPORT
+
+#if PROFILER_ENABLED
+
+#include <stddef.h>
+#include <stdint.h>
+#include <stdio.h>
+
+#include "third_party/highway/hwy/base.h"
+#include "third_party/highway/hwy/robust_statistics.h"
+#include "third_party/highway/hwy/timer.h"
+
+#endif // PROFILER_ENABLED
+
+namespace hwy {
+
+#if PROFILER_ENABLED
+
+static constexpr bool kPrintOverhead = true;
+
+// Must zero-init because `ThreadFunc` calls `SetGlobalIdx()` potentially after
+// this is first used in the `pool::Worker` ctor.
+/*static*/ thread_local size_t Profiler::s_global_idx = 0;
+
+// Detects duration of a zero-length zone: timer plus packet overhead.
+static uint64_t DetectSelfOverhead(Profiler& profiler, size_t global_idx) {
+ static const profiler::ZoneHandle zone = profiler.AddZone("DetectSelf");
+ profiler::Results results;
+ const size_t kNumSamples = 25;
+ uint32_t samples[kNumSamples];
+ for (size_t idx_sample = 0; idx_sample < kNumSamples; ++idx_sample) {
+ // Enough for stable measurements, but only about 50 ms startup cost.
+ const size_t kNumDurations = 700;
+ uint32_t durations[kNumDurations];
+ for (size_t idx_duration = 0; idx_duration < kNumDurations;
+ ++idx_duration) {
+ {
+ PROFILER_ZONE3(profiler, global_idx, zone);
+ }
+ durations[idx_duration] =
+ static_cast<uint32_t>(profiler.GetFirstDurationAndReset(global_idx));
+ }
+ samples[idx_sample] = robust_statistics::Mode(durations, kNumDurations);
+ }
+ return robust_statistics::Mode(samples, kNumSamples);
+}
+
+// Detects average duration of a zero-length zone, after deducting self
+// overhead. This accounts for the delay before/after capturing start/end
+// timestamps, for example due to fence instructions in timer::Start/Stop.
+static uint64_t DetectChildOverhead(Profiler& profiler, size_t global_idx,
+ uint64_t self_overhead) {
+ static const profiler::ZoneHandle zone = profiler.AddZone("DetectChild");
+ // Enough for stable measurements, but only about 50 ms startup cost.
+ const size_t kMaxSamples = 30;
+ uint32_t samples[kMaxSamples];
+ size_t num_samples = 0;
+ // Upper bound because timer resolution might be too coarse to get nonzero.
+ for (size_t s = 0; s < 2 * kMaxSamples && num_samples < kMaxSamples; ++s) {
+ const size_t kNumDurations = 50;
+ uint32_t durations[kNumDurations];
+ for (size_t d = 0; d < kNumDurations; ++d) {
+ constexpr size_t kReps = 500;
+ HWY_FENCE;
+ const uint64_t t0 = timer::Start();
+ for (size_t r = 0; r < kReps; ++r) {
+ PROFILER_ZONE3(profiler, global_idx, zone);
+ }
+ const uint64_t t1 = timer::Stop();
+ HWY_FENCE;
+ // We are measuring the total, not individual zone durations, to include
+ // cross-zone overhead.
+ (void)profiler.GetFirstDurationAndReset(global_idx);
+
+ const uint64_t avg_duration = (t1 - t0 + kReps / 2) / kReps;
+ durations[d] = static_cast<uint32_t>(
+ profiler::PerWorker::ClampedSubtract(avg_duration, self_overhead));
+ }
+ samples[num_samples] = robust_statistics::Mode(durations, kNumDurations);
+ // Overhead is nonzero, but we often measure zero; skip zero samples to
+ // prevent the final result from being zero.
+ num_samples += (samples[num_samples] != 0);
+ }
+ return num_samples == 0 ? 0 : robust_statistics::Mode(samples, num_samples);
+}
+
+Profiler::Profiler() {
+ const uint64_t t0 = timer::Start();
+
+ char cpu[100];
+ if (HWY_UNLIKELY(!platform::HaveTimerStop(cpu))) {
+ HWY_ABORT("CPU %s is too old for PROFILER_ENABLED=1, exiting", cpu);
+ }
+
+ // `ThreadPool` calls `Profiler::Get()` before it creates threads, hence this
+ // is guaranteed to be running on the main thread.
+ constexpr size_t kMain = 0;
+ // Must be called before any use of `PROFILER_ZONE*/PROFILER_FUNC*`. This runs
+ // only once because `Profiler` is a singleton.
+ ReserveWorker(kMain);
+ SetGlobalIdx(kMain);
+
+ profiler::Overheads overheads;
+ // WARNING: must pass in `*this` and use `PROFILER_ZONE3` to avoid calling
+ // `Profiler::Get()`, because that would re-enter the magic static init.
+ overheads.self = DetectSelfOverhead(*this, kMain);
+ overheads.child = DetectChildOverhead(*this, kMain, overheads.self);
+ for (size_t worker = 0; worker < profiler::kMaxWorkers; ++worker) {
+ workers_[worker].SetOverheads(overheads);
+ }
+
+ HWY_IF_CONSTEXPR(kPrintOverhead) {
+ printf("Self overhead: %.0f; child: %.0f; elapsed %.1f ms\n",
+ static_cast<double>(overheads.self),
+ static_cast<double>(overheads.child),
+ static_cast<double>(timer::Stop() - t0) /
+ platform::InvariantTicksPerSecond() * 1E3);
+ }
+}
+
+#endif // PROFILER_ENABLED
+
+// Even if disabled, we want to export the symbol.
+HWY_DLLEXPORT Profiler& Profiler::Get() {
+ static Profiler* profiler = new Profiler();
+ return *profiler;
+}
+
+} // namespace hwy
diff --git a/third_party/highway/hwy/profiler.h b/third_party/highway/hwy/profiler.h
index a9c2813615..bede7d3288 100644
--- a/third_party/highway/hwy/profiler.h
+++ b/third_party/highway/hwy/profiler.h
@@ -15,658 +15,861 @@
#ifndef HIGHWAY_HWY_PROFILER_H_
#define HIGHWAY_HWY_PROFILER_H_
+#include <stddef.h>
+#include <stdint.h>
+#include <string.h> // strcmp, strlen
+
+#include <atomic>
+#include <functional>
+
+#include "third_party/highway/hwy/base.h"
+#include "third_party/highway/hwy/highway_export.h"
+
// High precision, low overhead time measurements. Returns exact call counts and
// total elapsed time for user-defined 'zones' (code regions, i.e. C++ scopes).
//
// Uses RAII to capture begin/end timestamps, with user-specified zone names:
-// { PROFILER_ZONE("name"); /*code*/ } or
-// the name of the current function:
-// void FuncToMeasure() { PROFILER_FUNC; /*code*/ }.
+// `{ PROFILER_ZONE("name"); /*code*/ }` or the name of the current function:
+// `void FuncToMeasure() { PROFILER_FUNC; /*code*/ }`.
+//
+// You can reduce the overhead by passing `global_idx`, which can be taken from
+// the argument to the `ThreadPool::Run` lambda (if the pool was constructed
+// with non-default `PoolWorkerMapping`), or from a saved copy of the
+// thread-local `Profiler::Thread`: `PROFILER_ZONE2(global_idx, name)`.
+//
+// The preferred API allows passing flags, such as requesting inclusive time:
+// `static const auto zone = profiler.AddZone("name", flags);` and then
+// `PROFILER_ZONE3(profiler, global_idx, zone)`.
//
-// After all threads have exited any zones, invoke PROFILER_PRINT_RESULTS() to
+// After all threads exit all zones, call `Profiler::Get().PrintResults()` to
// print call counts and average durations [CPU cycles] to stdout, sorted in
// descending order of total duration.
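
A minimal usage sketch of the macros described above (with PROFILER_ENABLED defined to 1; otherwise the mocks make these no-ops):

#include "third_party/highway/hwy/profiler.h"

void FuncToMeasure() {
  PROFILER_FUNC;  // zone named after the enclosing function
  // ... code to measure ...
}

int main() {
  {
    PROFILER_ZONE("setup");  // explicitly named zone; ends at scope exit
    // ... code to measure ...
  }
  FuncToMeasure();
  hwy::Profiler::Get().PrintResults();  // after all threads exit all zones
  return 0;
}
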
-//
-// The binary MUST be built with --dynamic_mode=off because we rely on the data
-// segments being nearby; if not, an assertion will likely fail.
-
-#include "third_party/highway/hwy/base.h"
-// Configuration settings:
-
-// If zero, this file has no effect and no measurements will be recorded.
+// If zero, mock `Profiler` and `profiler::Zone` will be defined.
#ifndef PROFILER_ENABLED
#define PROFILER_ENABLED 0
#endif
-// How many mebibytes to allocate (if PROFILER_ENABLED) per thread that
-// enters at least one zone. Once this buffer is full, the thread will analyze
-// and discard packets, thus temporarily adding some observer overhead.
-// Each zone occupies 16 bytes.
-#ifndef PROFILER_THREAD_STORAGE
-#define PROFILER_THREAD_STORAGE 200ULL
-#endif
-
-#if PROFILER_ENABLED || HWY_IDE
-
-#include <stddef.h>
-#include <stdint.h>
+#if PROFILER_ENABLED
#include <stdio.h>
-#include <string.h> // strcmp
-#include <atomic>
+#include <algorithm> // std::sort
+#include <utility>
+#include <vector>
#include "third_party/highway/hwy/aligned_allocator.h"
-#include "third_party/highway/hwy/cache_control.h" // FlushStream
-#include "third_party/highway/hwy/contrib/sort/vqsort.h"
-#include "third_party/highway/hwy/robust_statistics.h"
+#include "third_party/highway/hwy/bit_set.h"
#include "third_party/highway/hwy/timer.h"
-
-#define PROFILER_PRINT_OVERHEAD 0
+#endif // PROFILER_ENABLED
namespace hwy {
-// Upper bounds for fixed-size data structures (guarded via HWY_DASSERT):
+// Flags: we want type-safety (enum class) to catch mistakes such as confusing
+// a zone with flags. The base type (`uint32_t`) ensures it is safe to cast.
+// Defined outside the `#if` because callers pass them to `PROFILER_ZONE3`.
+// When adding flags, also update `kNumFlags` and `ChildTotalMask`.
+enum class ProfilerFlags : uint32_t {
+ kDefault = 0,
+ // The zone should report cumulative time, including all child zones. If not
+ // specified, zones report self-time, excluding child zones.
+ kInclusive = 1
+};
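+// Example: a zone created via `AddZone("Outer", ProfilerFlags::kInclusive)`
+// reports its total time including nested zones, and is marked "(I)" in the
+// printed results.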
-// How many threads can actually enter a zone (those that don't do not count).
-// Memory use is about kMaxThreads * PROFILER_THREAD_STORAGE MiB.
-// WARNING: a fiber library can spawn hundreds of threads.
-static constexpr size_t kMaxThreads = 256;
+// Called during `PrintResults` to print results from other modules.
+using ProfilerFunc = std::function<void(void)>;
-static constexpr size_t kMaxDepth = 64; // Maximum nesting of zones.
+template <size_t kMaxStrings>
+class StringTable {
+ static constexpr std::memory_order kRelaxed = std::memory_order_relaxed;
+ static constexpr std::memory_order kAcq = std::memory_order_acquire;
+ static constexpr std::memory_order kRel = std::memory_order_release;
-static constexpr size_t kMaxZones = 256; // Total number of zones.
+ public:
+  // Returns the stored copy of the `name` whose `Add` call returned the
+  // given `idx`.
+ const char* Name(size_t idx) const {
+    // `kAcq` so that the string contents are also visible after the pointer
+    // is published via the `kRel` store.
+ return ptrs_[idx].load(kAcq);
+ }
-#pragma pack(push, 1)
+ // Returns `idx < kMaxStrings`. Can be called concurrently. Calls with the
+ // same `name` return the same `idx`.
+ size_t Add(const char* name) {
+ // Linear search if it already exists. `kAcq` ensures we see prior stores.
+ const size_t num_strings = next_ptr_.load(kAcq);
+ HWY_ASSERT(num_strings < kMaxStrings);
+ for (size_t idx = 1; idx < num_strings; ++idx) {
+ const char* existing = ptrs_[idx].load(kAcq);
+      // `next_ptr_` is only published after the write to `ptrs_[idx]`, hence
+      // this is non-null.
+ HWY_ASSERT(existing != nullptr);
+ if (HWY_UNLIKELY(!strcmp(existing, name))) {
+ return idx;
+ }
+ }
-// Represents zone entry/exit events. Stores a full-resolution timestamp plus
-// an offset (representing zone name or identifying exit packets). POD.
-class Packet {
- public:
- // If offsets do not fit, UpdateOrAdd will overrun our heap allocation
- // (governed by kMaxZones). We have seen multi-megabyte offsets.
- static constexpr size_t kOffsetBits = 25;
- static constexpr uint64_t kOffsetBias = 1ULL << (kOffsetBits - 1);
+ // Copy `name` into `chars_` before publishing the pointer.
+ const size_t len = strlen(name) + 1;
+ const size_t pos = next_char_.fetch_add(len, kRelaxed);
+ HWY_ASSERT(pos + len <= sizeof(chars_));
+ strcpy(chars_ + pos, name); // NOLINT
+
+ for (;;) {
+ size_t idx = next_ptr_.load(kRelaxed);
+ HWY_ASSERT(idx < kMaxStrings);
+
+ // Attempt to claim the next `idx` via CAS.
+ const char* expected = nullptr;
+ if (HWY_LIKELY(ptrs_[idx].compare_exchange_weak(expected, chars_ + pos,
+ kRel, kRelaxed))) {
+ // Publish the new count and make the `ptrs_` write visible.
+ next_ptr_.store(idx + 1, kRel);
+ HWY_DASSERT(!strcmp(Name(idx), name));
+ return idx;
+ }
+
+ // We lost the race. `expected` has been updated.
+ if (HWY_UNLIKELY(!strcmp(expected, name))) {
+ // Done, another thread added the same name. Note that we waste the
+ // extra space in `chars_`, which is fine because it is rare.
+ HWY_DASSERT(!strcmp(Name(idx), name));
+ return idx;
+ }
+
+ // Other thread added a different name. Retry with the next slot.
+ }
+ }
+
+ private:
+ std::atomic<const char*> ptrs_[kMaxStrings];
+ std::atomic<size_t> next_ptr_{1}; // next idx
+ std::atomic<size_t> next_char_{0};
+ char chars_[kMaxStrings * 55];
+};
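+// Usage sketch (hypothetical): concurrent `Add` calls deduplicate by name.
+//   StringTable<8> table;
+//   const size_t a = table.Add("alpha");  // first name, returns 1
+//   const size_t b = table.Add("alpha");  // same name, so b == a
+//   // table.Name(a) compares equal to "alpha".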
+
+#if PROFILER_ENABLED
+
+// Implementation details.
+namespace profiler {
- // We need full-resolution timestamps; at an effective rate of 4 GHz,
- // this permits 1 minute zone durations (for longer durations, split into
- // multiple zones). Wraparound is handled by masking.
- static constexpr size_t kTimestampBits = 64 - kOffsetBits;
- static constexpr uint64_t kTimestampMask = (1ULL << kTimestampBits) - 1;
+HWY_INLINE_VAR constexpr size_t kNumFlags = 1;
- static Packet Make(const size_t biased_offset, const uint64_t timestamp) {
- HWY_DASSERT(biased_offset != 0);
- HWY_DASSERT(biased_offset < (1ULL << kOffsetBits));
+// Upper bounds for fixed-size data structures, guarded via HWY_DASSERT:
- Packet packet;
- packet.bits_ =
- (biased_offset << kTimestampBits) + (timestamp & kTimestampMask);
+// Maximum nesting of zones, chosen such that `PerWorker` is 256 bytes.
+HWY_INLINE_VAR constexpr size_t kMaxDepth = 13;
+// Reports with more than ~50 zones are difficult to read anyway.
+HWY_INLINE_VAR constexpr size_t kMaxZones = 128;
+// Upper bound on `global_idx` across all pools. Note that fiber libraries
+// can spawn hundreds of threads. Turin has 128-192 cores.
+HWY_INLINE_VAR constexpr size_t kMaxWorkers = 256;
- HWY_DASSERT(packet.BiasedOffset() == biased_offset);
- HWY_DASSERT(packet.Timestamp() == (timestamp & kTimestampMask));
- return packet;
+// Type-safe wrapper for zone index plus flags, returned by `AddZone`.
+class ZoneHandle {
+ public:
+ ZoneHandle() : bits_(0) {} // for Accumulator member initialization
+
+ ZoneHandle(size_t zone_idx, ProfilerFlags flags) {
+ HWY_DASSERT(0 != zone_idx && zone_idx < kMaxZones);
+ const uint32_t flags_u = static_cast<uint32_t>(flags);
+ HWY_DASSERT(flags_u < (1u << kNumFlags));
+ bits_ = (static_cast<uint32_t>(zone_idx) << kNumFlags) | flags_u;
+ HWY_DASSERT(ZoneIdx() == zone_idx);
}
- uint64_t Timestamp() const { return bits_ & kTimestampMask; }
+ ZoneHandle(const ZoneHandle& other) = default;
+ ZoneHandle& operator=(const ZoneHandle& other) = default;
+
+ bool operator==(const ZoneHandle other) const { return bits_ == other.bits_; }
+ bool operator!=(const ZoneHandle other) const { return bits_ != other.bits_; }
- size_t BiasedOffset() const {
- const size_t biased_offset = (bits_ >> kTimestampBits);
- HWY_DASSERT(biased_offset != 0);
- HWY_DASSERT(biased_offset < (1ULL << kOffsetBits));
- return biased_offset;
+ size_t ZoneIdx() const {
+ HWY_DASSERT(bits_ != 0);
+ const size_t zone_idx = bits_ >> kNumFlags;
+ HWY_DASSERT(0 != zone_idx && zone_idx < kMaxZones);
+ return zone_idx;
+ }
+
+ bool IsInclusive() const {
+ HWY_DASSERT(bits_ != 0);
+ return (bits_ & static_cast<uint32_t>(ProfilerFlags::kInclusive)) != 0;
+ }
+
+ // Returns a mask to zero/ignore child totals for inclusive zones.
+ uint64_t ChildTotalMask() const {
+ // With a ternary operator, clang tends to generate a branch.
+ // return IsInclusive() ? 0 : ~uint64_t{0};
+ const uint32_t bit =
+ bits_ & static_cast<uint32_t>(ProfilerFlags::kInclusive);
+ return uint64_t{bit} - 1;
}
private:
- uint64_t bits_;
+ uint32_t bits_;
};
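+// `ChildTotalMask` example: for a default (self-time) zone the kInclusive bit
+// is 0, and `uint64_t{0} - 1` wraps to all-ones, so the child total is kept
+// for subtraction; for an inclusive zone, `uint64_t{1} - 1 == 0` zeroes it.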
-static_assert(sizeof(Packet) == 8, "Wrong Packet size");
-
-// All translation units must use the same string origin. A static member
-// function ensures this without requiring a separate .cc file.
-struct StringOrigin {
- // Returns the address of a string literal. Assuming zone names are also
- // literals and stored nearby, we can represent them as offsets from this,
- // which is faster to compute than hashes or even a static index.
- static const char* Get() {
- // Chosen such that no zone name is a prefix nor suffix of this string
- // to ensure they aren't merged. Note zone exit packets use
- // `biased_offset == kOffsetBias`.
- static const char* string_origin = "__#__";
- return string_origin - Packet::kOffsetBias;
+
+// Storage for zone names.
+class Zones {
+ public:
+  // Returns the stored copy of the `name` whose `AddZone` call returned the
+  // given `zone`.
+ const char* Name(ZoneHandle zone) const {
+ return strings_.Name(zone.ZoneIdx());
+ }
+
+ // Can be called concurrently. Calls with the same `name` return the same
+ // `ZoneHandle.ZoneIdx()`.
+ ZoneHandle AddZone(const char* name, ProfilerFlags flags) {
+ return ZoneHandle(strings_.Add(name), flags);
}
-};
-// Representation of an active zone, stored in a stack. Used to deduct
-// child duration from the parent's self time. POD.
-struct Node {
- Packet packet;
- uint64_t child_total;
+ private:
+ StringTable<kMaxZones> strings_;
};
-static_assert(sizeof(Node) == 16, "Wrong Node size");
-// Holds statistics for all zones with the same name. POD.
-struct Accumulator {
- static constexpr size_t kNumCallBits = 64 - Packet::kOffsetBits;
+// Allows other classes such as `ThreadPool` to register/unregister a function
+// to call during `PrintResults`. This allows us to gather data from the worker
+// threads without having to wait until they exit, and decouples the profiler
+// from other modules. Thread-safe.
+class Funcs {
+ static constexpr auto kAcq = std::memory_order_acquire;
+ static constexpr auto kRel = std::memory_order_release;
+
+ public:
+ // Can be called concurrently with distinct keys.
+ void Add(intptr_t key, ProfilerFunc func) {
+ HWY_ASSERT(key != 0 && key != kPending); // reserved values
+ HWY_ASSERT(func); // not empty
+
+ for (size_t i = 0; i < kMaxFuncs; ++i) {
+ intptr_t expected = 0;
+      // If we lost a race with a concurrent `Add`, try the next slot.
+ if (!keys_[i].compare_exchange_strong(expected, kPending, kRel)) {
+ continue;
+ }
+ // We own the slot: move func there.
+ funcs_[i] = std::move(func);
+ keys_[i].store(key, kRel); // publishes the `func` write.
+ return;
+ }
- uint64_t BiasedOffset() const {
- const size_t biased_offset = u128.lo >> kNumCallBits;
- HWY_DASSERT(biased_offset != 0);
- HWY_DASSERT(biased_offset < (1ULL << Packet::kOffsetBits));
- return biased_offset;
+ HWY_ABORT("Funcs::Add: no free slot, increase kMaxFuncs.");
}
- uint64_t NumCalls() const { return u128.lo & ((1ULL << kNumCallBits) - 1); }
- uint64_t Duration() const { return u128.hi; }
- void Set(uint64_t biased_offset, uint64_t num_calls, uint64_t duration) {
- HWY_DASSERT(biased_offset != 0);
- HWY_DASSERT(biased_offset < (1ULL << Packet::kOffsetBits));
- HWY_DASSERT(num_calls < (1ULL << kNumCallBits));
+ // Can be called concurrently with distinct keys. It is an error to call this
+ // without a prior `Add` of the same key.
+ void Remove(intptr_t key) {
+ HWY_ASSERT(key != 0 && key != kPending); // reserved values
+
+ for (size_t i = 0; i < kMaxFuncs; ++i) {
+ intptr_t actual = keys_[i].load(kAcq);
+ if (actual == key) {
+ // In general, concurrent removal is fine, but in this specific context,
+ // owners are expected to remove their key exactly once, from the same
+ // thread that added it. In that case, CAS should not fail.
+ if (!keys_[i].compare_exchange_strong(actual, kPending, kRel)) {
+ HWY_WARN("Funcs: CAS failed, why is there a concurrent Remove?");
+ }
+ funcs_[i] = ProfilerFunc();
+ keys_[i].store(0, kRel); // publishes the `func` write.
+ return;
+ }
+ }
+ HWY_ABORT("Funcs::Remove: failed to find key %p.",
+ reinterpret_cast<void*>(key));
+ }
+
+ void CallAll() const {
+ for (size_t i = 0; i < kMaxFuncs; ++i) {
+ intptr_t key = keys_[i].load(kAcq); // ensures `funcs_` is visible.
+ // Safely handles concurrent Add/Remove.
+ if (key != 0 && key != kPending) {
+ funcs_[i]();
+ }
+ }
+ }
+
+ private:
+ static constexpr size_t kMaxFuncs = 64;
+ static constexpr intptr_t kPending = -1;
+
+ ProfilerFunc funcs_[kMaxFuncs]; // non-atomic
+ std::atomic<intptr_t> keys_[kMaxFuncs] = {};
+};
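+// Hypothetical usage by an owner object such as a pool, via the `Profiler`
+// wrappers defined below:
+//   hwy::Profiler::Get().AddFunc(this, [] { /* report per-pool stats */ });
+//   ...
+//   hwy::Profiler::Get().RemoveFunc(this);  // e.g. in the owner's destructor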
+
+// Holds total duration and number of calls. The worker index is implicit in
+// this instance's position within the `Accumulators` array.
+struct Accumulator {
+ void Add(ZoneHandle new_zone, uint64_t self_duration) {
+ duration += self_duration;
- u128.hi = duration;
- u128.lo = (biased_offset << kNumCallBits) + num_calls;
+ // Only called for valid zones.
+ HWY_DASSERT(new_zone != ZoneHandle());
+ // Our zone might not have been set yet.
+ HWY_DASSERT(zone == ZoneHandle() || zone == new_zone);
+ zone = new_zone;
- HWY_DASSERT(BiasedOffset() == biased_offset);
- HWY_DASSERT(NumCalls() == num_calls);
- HWY_DASSERT(Duration() == duration);
+ num_calls += 1;
}
- void Add(uint64_t num_calls, uint64_t duration) {
- const uint64_t biased_offset = BiasedOffset();
- (void)biased_offset;
+ void Take(Accumulator& other) {
+ duration += other.duration;
+ other.duration = 0;
- u128.lo += num_calls;
- u128.hi += duration;
+ // `ZoneSet` ensures we only call this for non-empty `other`.
+ HWY_DASSERT(other.zone != ZoneHandle());
+ // Our zone might not have been set yet.
+ HWY_DASSERT(zone == ZoneHandle() || zone == other.zone);
+ zone = other.zone;
- HWY_DASSERT(biased_offset == BiasedOffset());
+ num_calls += other.num_calls;
+ other.num_calls = 0;
}
- // For fast sorting by duration, which must therefore be the hi element.
- // lo holds BiasedOffset and NumCalls.
- uint128_t u128;
+ uint64_t duration = 0;
+ ZoneHandle zone; // flags are used by `Results::Print`
+ uint32_t num_calls = 0;
};
static_assert(sizeof(Accumulator) == 16, "Wrong Accumulator size");
-template <typename T>
-inline T ClampedSubtract(const T minuend, const T subtrahend) {
- if (subtrahend > minuend) {
- return 0;
+using ZoneSet = hwy::BitSet<kMaxZones>;
+using WorkerSet = hwy::BitSet<kMaxWorkers>;
+using AtomicWorkerSet = hwy::AtomicBitSet<kMaxWorkers>;
+
+// Durations are per-CPU, but end to end performance is defined by wall time.
+// Assuming fork-join parallelism, zones are entered by multiple threads
+// concurrently, which means the total number of unique threads is also the
+// degree of concurrency, so we can estimate wall time as CPU time divided by
+// the number of unique threads seen. This is facilitated by unique `global_idx`
+// passed in by callers, or taken from thread-local `GlobalIdx()`.
+//
+// We also want to support varying thread counts per call site, because the same
+// function/zone may be called from multiple pools. `EndRootRun` calls
+// `CountWorkersAndReset` after each top-level `ThreadPool::Run`, which
+// generates one data point summarized via descriptive statistics. Here we
+// implement a simpler version of `Stats` because we do not require
+// geomean/variance/kurtosis/skewness. Because concurrency is a small integer,
+// we can simply compute sums rather than online moments. There is also only one
+// instance across all threads, hence we do not require a `Take`.
+//
+// Note that subsequently discovered prior work estimates the number of active
+// and idle processors by updating atomic counters whenever they start/finish a
+// task: https://homes.cs.washington.edu/~tom/pubs/quartz.pdf and "Effective
+// performance measurement and analysis of multithreaded applications". We
+// instead accumulate zone durations into per-thread storage.
+// `CountWorkersAndReset` then checks how many were nonzero, which avoids
+// expensive atomic updates and ensures accurate counts per-zone, rather than
+// estimates of current activity at each sample.
+// D. Vyukov's https://github.com/dvyukov/perf-load, also integrated into Linux
+// perf, instead corrects for parallelism without atomic counters by tracing
+// context switches. Note that we often pin threads, which avoids migrations
+// but reduces the number of context switch events to mainly preemptions.
+class ConcurrencyStats {
+ public:
+ ConcurrencyStats() { Reset(); }
+
+ void Notify(const size_t x) {
+ sum_ += x;
+ ++n_;
+ min_ = HWY_MIN(min_, x);
+ max_ = HWY_MAX(max_, x);
+ }
+
+ size_t Count() const { return n_; }
+ size_t Min() const { return min_; }
+ size_t Max() const { return max_; }
+ double Mean() const {
+ return static_cast<double>(sum_) / static_cast<double>(n_);
}
- return minuend - subtrahend;
-}
-// Per-thread call graph (stack) and Accumulator for each zone.
+ void Reset() {
+ sum_ = 0;
+ n_ = 0;
+ min_ = hwy::HighestValue<size_t>();
+ max_ = hwy::LowestValue<size_t>();
+ }
+
+ private:
+ uint64_t sum_;
+ size_t n_;
+ size_t min_;
+ size_t max_;
+};
+static_assert(sizeof(ConcurrencyStats) == (8 + 3 * sizeof(size_t)), "");
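+// Example (single invocation): if a zone accumulates 8E9 cycles of CPU time
+// across 8 unique workers, the estimated wall time is 8E9 / 8 = 1E9 cycles.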
+
+// Holds the final results across all threads, including `ConcurrencyStats`
+// and `PoolStats`, updated/printed by the main thread.
class Results {
public:
- Results() {
- ZeroBytes(nodes_, sizeof(nodes_));
- ZeroBytes(zones_, sizeof(zones_));
+ void TakeAccumulator(const size_t global_idx, const size_t zone_idx,
+ Accumulator& other) {
+ HWY_DASSERT(global_idx < kMaxWorkers);
+ HWY_DASSERT(zone_idx < kMaxZones);
+ HWY_DASSERT(other.zone.ZoneIdx() == zone_idx);
+
+ visited_zones_.Set(zone_idx);
+ totals_[zone_idx].Take(other);
+ workers_[zone_idx].Set(global_idx);
}
- // Used for computing overhead when this thread encounters its first Zone.
- // This has no observable effect apart from increasing "analyze_elapsed_".
- uint64_t ZoneDuration(const Packet* packets) {
- HWY_DASSERT(depth_ == 0);
- HWY_DASSERT(num_zones_ == 0);
- AnalyzePackets(packets, 2);
- const uint64_t duration = zones_[0].Duration();
- zones_[0].Set(1, 0, 0); // avoids triggering biased_offset = 0 checks
- HWY_DASSERT(depth_ == 0);
- num_zones_ = 0;
- return duration;
+  // Converts the set of workers seen during the preceding root-level
+  // `ThreadPool::Run` into one data point for `ConcurrencyStats`.
+ void CountWorkersAndReset(const size_t zone_idx) {
+ HWY_DASSERT(zone_idx < kMaxZones);
+ const size_t num_workers = workers_[zone_idx].Count();
+    // `workers_[zone_idx]` was non-empty at some point, but it is reset
+    // below, so it may already be empty when `PrintResults` triggers a second
+    // call after one from `EndRootRun`. Do not add a data point if empty.
+ if (num_workers != 0) {
+ concurrency_[zone_idx].Notify(num_workers);
+ }
+ workers_[zone_idx] = WorkerSet();
}
- void SetSelfOverhead(const uint64_t self_overhead) {
- self_overhead_ = self_overhead;
+ void CountWorkersAndReset() {
+ visited_zones_.Foreach(
+ [&](size_t zone_idx) { CountWorkersAndReset(zone_idx); });
}
- void SetChildOverhead(const uint64_t child_overhead) {
- child_overhead_ = child_overhead;
+ void PrintAndReset(const Zones& zones) {
+ const double inv_freq = 1.0 / hwy::platform::InvariantTicksPerSecond();
+
+ // Sort by decreasing total (self) cost. `totals_` are sparse, so sort an
+ // index vector instead.
+ std::vector<uint32_t> indices;
+ indices.reserve(visited_zones_.Count());
+ visited_zones_.Foreach([&](size_t zone_idx) {
+ indices.push_back(static_cast<uint32_t>(zone_idx));
+ // In case the zone exited after `EndRootRun` and was not yet added.
+ CountWorkersAndReset(zone_idx);
+ });
+ std::sort(indices.begin(), indices.end(), [&](uint32_t a, uint32_t b) {
+ return totals_[a].duration > totals_[b].duration;
+ });
+ printf(" %-40s: %10s x %15s / %5s (%5s %3s-%3s) = %9s\n", "Zone", "Calls",
+ "Cycles/Call", "Avg Count", "Count", "Min", "Max", "Wall Time(s)");
+
+ for (uint32_t zone_idx : indices) {
+ Accumulator& total = totals_[zone_idx]; // cleared after printing
+ HWY_ASSERT(total.zone.ZoneIdx() == zone_idx);
+ HWY_ASSERT(total.num_calls != 0); // else visited_zones_ is wrong
+
+ ConcurrencyStats& concurrency = concurrency_[zone_idx];
+ const double duration = static_cast<double>(total.duration);
+ const double per_call =
+ static_cast<double>(total.duration) / total.num_calls;
+ // See comment on `ConcurrencyStats`.
+ const double avg_concurrency = concurrency.Mean();
+ // Avoid division by zero.
+ const double concurrency_divisor = HWY_MAX(1.0, avg_concurrency);
+ printf("%s%-40s: %10.0f x %15.0f / %5.1f (%5zu %3zu-%3zu) = %9.6f\n",
+ total.zone.IsInclusive() ? "(I)" : " ", zones.Name(total.zone),
+ static_cast<double>(total.num_calls), per_call, avg_concurrency,
+ concurrency.Count(), concurrency.Min(), concurrency.Max(),
+ duration * inv_freq / concurrency_divisor);
+
+ total = Accumulator();
+ concurrency.Reset();
+ // `workers_` was already reset by `CountWorkersAndReset`.
+ }
+ visited_zones_ = ZoneSet();
}
- // Draw all required information from the packets, which can be discarded
- // afterwards. Called whenever this thread's storage is full.
- void AnalyzePackets(const Packet* packets, const size_t num_packets) {
- const uint64_t t0 = timer::Start();
+ private:
+ // Indicates which of the array entries are in use.
+ ZoneSet visited_zones_;
+ Accumulator totals_[kMaxZones];
+ WorkerSet workers_[kMaxZones];
+ ConcurrencyStats concurrency_[kMaxZones];
+};
- for (size_t i = 0; i < num_packets; ++i) {
- const Packet p = packets[i];
- // Entering a zone
- if (p.BiasedOffset() != Packet::kOffsetBias) {
- HWY_DASSERT(depth_ < kMaxDepth);
- nodes_[depth_].packet = p;
- HWY_DASSERT(p.BiasedOffset() != 0);
- nodes_[depth_].child_total = 0;
- ++depth_;
- continue;
- }
+// Delays between capturing the timestamps before/after the actual zone runs.
+// Even with frequency throttling disabled, these have a multimodal
+// distribution, with modes including 32, 34, 48, 52, 59, 62 cycles.
+struct Overheads {
+ uint64_t self = 0;
+ uint64_t child = 0;
+};
+static_assert(sizeof(Overheads) == 16, "Wrong Overheads size");
- HWY_DASSERT(depth_ != 0);
- const Node& node = nodes_[depth_ - 1];
- // Masking correctly handles unsigned wraparound.
- const uint64_t duration =
- (p.Timestamp() - node.packet.Timestamp()) & Packet::kTimestampMask;
- const uint64_t self_duration = ClampedSubtract(
- duration, self_overhead_ + child_overhead_ + node.child_total);
+class Accumulators {
+ // We generally want to group threads together because they are often
+ // accessed together during a zone, but also want to avoid threads sharing a
+ // cache line. Hence interleave 8 zones per worker.
+ static constexpr size_t kPerLine = HWY_ALIGNMENT / sizeof(Accumulator);
- UpdateOrAdd(node.packet.BiasedOffset(), 1, self_duration);
- --depth_;
+ public:
+ Accumulator& Get(const size_t global_idx, const size_t zone_idx) {
+ HWY_DASSERT(global_idx < kMaxWorkers);
+ HWY_DASSERT(zone_idx < kMaxZones);
+ const size_t line = zone_idx / kPerLine;
+ const size_t offset = zone_idx % kPerLine;
+ return zones_[(line * kMaxWorkers + global_idx) * kPerLine + offset];
+ }
- // Deduct this nested node's time from its parent's self_duration.
- if (depth_ != 0) {
- nodes_[depth_ - 1].child_total += duration + child_overhead_;
- }
- }
+ private:
+ Accumulator zones_[kMaxZones * kMaxWorkers];
+};
- const uint64_t t1 = timer::Stop();
- analyze_elapsed_ += t1 - t0;
+// Reacts to zone enter/exit events. Builds a stack of active zones and
+// accumulates self/child duration for each.
+class PerWorker {
+ public:
+ template <typename T>
+ static T ClampedSubtract(const T minuend, const T subtrahend) {
+ static_assert(IsUnsigned<T>(), "");
+ const T difference = minuend - subtrahend;
+    // Clang output for this is verified to use CMOV rather than a branch.
+ const T no_underflow = (subtrahend > minuend) ? T{0} : ~T{0};
+ return difference & no_underflow;
}
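+  // (e.g. ClampedSubtract<uint64_t>(5, 7) == 0 rather than wrapping around.)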
- // Incorporates results from another thread. Call after all threads have
- // exited any zones.
- void Assimilate(Results& other) {
- const uint64_t t0 = timer::Start();
- HWY_DASSERT(depth_ == 0);
- HWY_DASSERT(other.depth_ == 0);
+ void SetOverheads(const Overheads& overheads) { overheads_ = overheads; }
- for (size_t i = 0; i < other.num_zones_; ++i) {
- const Accumulator& zone = other.zones_[i];
- UpdateOrAdd(zone.BiasedOffset(), zone.NumCalls(), zone.Duration());
- }
- other.num_zones_ = 0;
- const uint64_t t1 = timer::Stop();
- analyze_elapsed_ += t1 - t0 + other.analyze_elapsed_;
+ // Entering a zone: push onto stack.
+ void Enter(const uint64_t t_enter) {
+ const size_t depth = depth_;
+ HWY_DASSERT(depth < kMaxDepth);
+ t_enter_[depth] = t_enter;
+ child_total_[1 + depth] = 0;
+ depth_ = 1 + depth;
}
- // Single-threaded.
- void Print() {
- const uint64_t t0 = timer::Start();
- MergeDuplicates();
+ // Exiting the most recently entered zone (top of stack).
+ void Exit(const uint64_t t_exit, const size_t global_idx,
+ const ZoneHandle zone, Accumulators& accumulators) {
+ HWY_DASSERT(depth_ > 0);
+ const size_t depth = depth_ - 1;
+ const size_t zone_idx = zone.ZoneIdx();
+ const uint64_t duration = t_exit - t_enter_[depth];
+    // Clang output for this is verified not to branch: the result is 0 if
+    // inclusive, otherwise the child total.
+ const uint64_t child_total =
+ child_total_[1 + depth] & zone.ChildTotalMask();
+
+ const uint64_t self_duration = ClampedSubtract(
+ duration, overheads_.self + overheads_.child + child_total);
+ accumulators.Get(global_idx, zone_idx).Add(zone, self_duration);
+    // Speeds up `TakeAccumulator`, because not all zones are encountered.
+ visited_zones_.Set(zone_idx);
+
+ // Adding this nested time to the parent's `child_total` will
+ // cause it to be later subtracted from the parent's `self_duration`.
+ child_total_[1 + depth - 1] += duration + overheads_.child;
+
+ depth_ = depth;
+ }
- // Sort by decreasing total (self) cost.
- VQSort(&zones_[0].u128, num_zones_, SortDescending());
+ // Returns the duration of one enter/exit pair and resets all state. Called
+ // via `DetectSelfOverhead`.
+ uint64_t GetFirstDurationAndReset(size_t global_idx,
+ Accumulators& accumulators) {
+ HWY_DASSERT(depth_ == 0);
- const double inv_freq = 1.0 / platform::InvariantTicksPerSecond();
+ HWY_DASSERT(visited_zones_.Count() == 1);
+ const size_t zone_idx = visited_zones_.First();
+ HWY_DASSERT(zone_idx <= 3);
+ HWY_DASSERT(visited_zones_.Get(zone_idx));
+ visited_zones_.Clear(zone_idx);
- const char* string_origin = StringOrigin::Get();
- for (size_t i = 0; i < num_zones_; ++i) {
- const Accumulator& z = zones_[i];
- const size_t num_calls = z.NumCalls();
- const double duration = static_cast<double>(z.Duration());
- printf("%-40s: %10zu x %15.0f = %9.6f\n",
- string_origin + z.BiasedOffset(), num_calls, duration / num_calls,
- duration * inv_freq);
- }
- num_zones_ = 0;
+ Accumulator& zone = accumulators.Get(global_idx, zone_idx);
+ const uint64_t duration = zone.duration;
+ zone = Accumulator();
+ return duration;
+ }
- const uint64_t t1 = timer::Stop();
- analyze_elapsed_ += t1 - t0;
- printf("Total analysis [s]: %f\n",
- static_cast<double>(analyze_elapsed_) * inv_freq);
+  // Adds all data to `results` and resets our local state. Called from the
+  // main thread.
+ void MoveTo(const size_t global_idx, Accumulators& accumulators,
+ Results& results) {
+ visited_zones_.Foreach([&](size_t zone_idx) {
+ results.TakeAccumulator(global_idx, zone_idx,
+ accumulators.Get(global_idx, zone_idx));
+ });
+ // OK to reset even if we have active zones, because we set `visited_zones_`
+ // when exiting the zone.
+ visited_zones_ = ZoneSet();
}
private:
- // Updates an existing Accumulator (uniquely identified by biased_offset) or
- // adds one if this is the first time this thread analyzed that zone.
- // Uses a self-organizing list data structure, which avoids dynamic memory
- // allocations and is far faster than unordered_map.
- void UpdateOrAdd(const size_t biased_offset, const uint64_t num_calls,
- const uint64_t duration) {
- HWY_DASSERT(biased_offset != 0);
- HWY_DASSERT(biased_offset < (1ULL << Packet::kOffsetBits));
-
- // Special case for first zone: (maybe) update, without swapping.
- if (num_zones_ != 0 && zones_[0].BiasedOffset() == biased_offset) {
- zones_[0].Add(num_calls, duration);
- return;
- }
+ // 40 bytes:
+ ZoneSet visited_zones_; // Which `zones_` have been active on this worker.
+ uint64_t depth_ = 0; // Current nesting level for active zones.
+ Overheads overheads_;
+
+ uint64_t t_enter_[kMaxDepth];
+ // Used to deduct child duration from parent's self time (unless inclusive).
+ // Shifting by one avoids bounds-checks for depth_ = 0 (root zone).
+ uint64_t child_total_[1 + kMaxDepth] = {0};
+};
+// Enables shift rather than multiplication.
+static_assert(sizeof(PerWorker) == 256, "Wrong size");
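+// e.g. `&workers_[global_idx]` compiles to `workers_ + (global_idx << 8)`.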
- // Look for a zone with the same offset.
- for (size_t i = 1; i < num_zones_; ++i) {
- if (zones_[i].BiasedOffset() == biased_offset) {
- zones_[i].Add(num_calls, duration);
- // Swap with predecessor (more conservative than move to front,
- // but at least as successful).
- const Accumulator prev = zones_[i - 1];
- zones_[i - 1] = zones_[i];
- zones_[i] = prev;
- return;
- }
- }
+} // namespace profiler
- // Not found; create a new Accumulator.
- HWY_DASSERT(num_zones_ < kMaxZones);
- zones_[num_zones_].Set(biased_offset, num_calls, duration);
- ++num_zones_;
- }
-
- // Each instantiation of a function template seems to get its own copy of
- // __func__ and GCC doesn't merge them. An N^2 search for duplicates is
- // acceptable because we only expect a few dozen zones.
- void MergeDuplicates() {
- const char* string_origin = StringOrigin::Get();
- for (size_t i = 0; i < num_zones_; ++i) {
- const size_t biased_offset = zones_[i].BiasedOffset();
- const char* name = string_origin + biased_offset;
- // Separate num_calls from biased_offset so we can add them together.
- uint64_t num_calls = zones_[i].NumCalls();
-
- // Add any subsequent duplicates to num_calls and total_duration.
- for (size_t j = i + 1; j < num_zones_;) {
- if (!strcmp(name, string_origin + zones_[j].BiasedOffset())) {
- num_calls += zones_[j].NumCalls();
- zones_[i].Add(0, zones_[j].Duration());
- // j was the last zone, so we are done.
- if (j == num_zones_ - 1) break;
- // Replace current zone with the last one, and check it next.
- zones_[j] = zones_[--num_zones_];
- } else { // Name differed, try next Accumulator.
- ++j;
- }
- }
+class Profiler {
+ public:
+ static HWY_DLLEXPORT Profiler& Get();
+
+ // Returns `global_idx` from thread-local storage (0 for the main thread).
+ // Used by `PROFILER_ZONE/PROFILER_FUNC`. It is faster to instead pass the
+ // global_idx from `ThreadPool::Run` (if constructed with non-default
+ // `PoolWorkerMapping`) to `PROFILER_ZONE2/PROFILER_ZONE3`.
+ // DEPRECATED: use `GlobalIdx` instead.
+ static size_t Thread() { return s_global_idx; }
+ static size_t GlobalIdx() { return s_global_idx; }
+ // Must be called from all worker threads, and once also on the main thread,
+ // before any use of `PROFILER_ZONE/PROFILER_FUNC`.
+ static void SetGlobalIdx(size_t global_idx) { s_global_idx = global_idx; }
+
+ void ReserveWorker(size_t global_idx) {
+ HWY_ASSERT(!workers_reserved_.Get(global_idx));
+ workers_reserved_.Set(global_idx);
+ }
- // Re-pack regardless of whether any duplicates were found.
- zones_[i].Set(biased_offset, num_calls, zones_[i].Duration());
- }
+ void FreeWorker(size_t global_idx) {
+ HWY_ASSERT(workers_reserved_.Get(global_idx));
+ workers_reserved_.Clear(global_idx);
}
- uint64_t analyze_elapsed_ = 0;
- uint64_t self_overhead_ = 0;
- uint64_t child_overhead_ = 0;
+ // Called by `Zone` from any thread.
+ void Enter(uint64_t t_enter, size_t global_idx) {
+ GetWorker(global_idx).Enter(t_enter);
+ }
- size_t depth_ = 0; // Number of active zones.
- size_t num_zones_ = 0; // Number of retired zones.
+ // Called by `~Zone` from any thread.
+ void Exit(uint64_t t_exit, size_t global_idx, profiler::ZoneHandle zone) {
+ GetWorker(global_idx).Exit(t_exit, global_idx, zone, accumulators_);
+ }
- alignas(HWY_ALIGNMENT) Node nodes_[kMaxDepth]; // Stack
- alignas(HWY_ALIGNMENT) Accumulator zones_[kMaxZones]; // Self-organizing list
-};
+ uint64_t GetFirstDurationAndReset(size_t global_idx) {
+ return GetWorker(global_idx)
+ .GetFirstDurationAndReset(global_idx, accumulators_);
+ }
+
+ const char* Name(profiler::ZoneHandle zone) const {
+ return zones_.Name(zone);
+ }
-// Per-thread packet storage, dynamically allocated.
-class ThreadSpecific {
- static constexpr size_t kBufferCapacity = HWY_ALIGNMENT / sizeof(Packet);
+  // Copies `name` into the string table and returns its unique `zone`. Uses
+  // linear search, which is fine because this is typically called once per
+  // zone, via a static initializer; the result is passed to the `Zone` ctor.
+ profiler::ZoneHandle AddZone(const char* name,
+ ProfilerFlags flags = ProfilerFlags::kDefault) {
+ return zones_.AddZone(name, flags);
+ }
- public:
- // "name" is used to sanity-check offsets fit in kOffsetBits.
- explicit ThreadSpecific(const char* name)
- : max_packets_((PROFILER_THREAD_STORAGE << 20) / sizeof(Packet)),
- packets_(AllocateAligned<Packet>(max_packets_)),
- num_packets_(0),
- string_origin_(StringOrigin::Get()) {
- // Even in optimized builds, verify that this zone's name offset fits
- // within the allotted space. If not, UpdateOrAdd is likely to overrun
- // zones_[]. Checking here on the cold path (only reached once per thread)
- // is cheap, but it only covers one zone.
- const size_t biased_offset = name - string_origin_;
- HWY_ASSERT(biased_offset < (1ULL << Packet::kOffsetBits));
- }
-
- // Depends on Zone => defined below.
- void ComputeOverhead();
-
- void WriteEntry(const char* name, const uint64_t timestamp) {
- HWY_DASSERT(name >= string_origin_);
- const size_t biased_offset = static_cast<size_t>(name - string_origin_);
- Write(Packet::Make(biased_offset, timestamp));
- }
-
- void WriteExit(const uint64_t timestamp) {
- const size_t biased_offset = Packet::kOffsetBias;
- Write(Packet::Make(biased_offset, timestamp));
- }
-
- void AnalyzeRemainingPackets() {
- // Ensures prior weakly-ordered streaming stores are globally visible.
- FlushStream();
-
- // Storage full => empty it.
- if (num_packets_ + buffer_size_ > max_packets_) {
- results_.AnalyzePackets(packets_.get(), num_packets_);
- num_packets_ = 0;
- }
- CopyBytes(buffer_, packets_.get() + num_packets_,
- buffer_size_ * sizeof(Packet));
- num_packets_ += buffer_size_;
+ void AddFunc(void* owner, ProfilerFunc func) {
+ funcs_.Add(reinterpret_cast<intptr_t>(owner), func);
+ }
+ void RemoveFunc(void* owner) {
+ funcs_.Remove(reinterpret_cast<intptr_t>(owner));
+ }
- results_.AnalyzePackets(packets_.get(), num_packets_);
- num_packets_ = 0;
+  // For reporting average concurrency. Called by `ThreadPool::Run` on the
+  // main thread; returns true if this is the first call since the last
+  // `EndRootRun`.
+ //
+ // We want to report the concurrency of each separate 'invocation' of a zone.
+ // A unique per-call identifier (could be approximated with the line number
+ // and return address) is not sufficient because the caller may in turn be
+ // called from differing parallel sections. A per-`ThreadPool::Run` counter
+ // also under-reports concurrency because each pool in nested parallelism
+ // (over packages and CCXes) would be considered separate invocations.
+ //
+ // The alternative of detecting overlapping zones via timestamps is not 100%
+ // reliable because timers may not be synchronized across sockets or perhaps
+ // even cores. "Invariant" x86 TSCs are indeed synchronized across cores, but
+ // not across sockets unless the RESET# signal reaches each at the same time.
+ // Linux seems to make an effort to correct this, and Arm's "generic timer"
+ // broadcasts to "all cores", but there is no universal guarantee.
+ //
+ // Under the assumption that all concurrency is via our `ThreadPool`, we can
+ // record all `global_idx` for each outermost (root) `ThreadPool::Run`. This
+ // collapses all nested pools into one 'invocation'. We then compute per-zone
+ // concurrency as the number of unique `global_idx` seen per invocation.
+ bool IsRootRun() {
+ // We are not the root if a Run was already active.
+ return !run_active_.test_and_set(std::memory_order_acquire);
}
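+  // Typical (hypothetical) caller pattern, e.g. inside `ThreadPool::Run`:
+  //   const bool is_root = profiler.IsRootRun();
+  //   /* run tasks, possibly with nested pools */
+  //   if (is_root) profiler.EndRootRun();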
- Results& GetResults() { return results_; }
+ // Must be called if `IsRootRun` returned true. Resets the state so that the
+ // next call to `IsRootRun` will again return true. Called from main thread.
+ // Note that some zones may still be active. Their concurrency will be updated
+ // when `PrintResults` is called.
+ void EndRootRun() {
+ UpdateResults();
+ results_.CountWorkersAndReset();
- private:
- // Overwrites "to" while attempting to bypass the cache (read-for-ownership).
- // Both pointers must be aligned.
- static void StreamCacheLine(const uint64_t* HWY_RESTRICT from,
- uint64_t* HWY_RESTRICT to) {
-#if HWY_COMPILER_CLANG
- for (size_t i = 0; i < HWY_ALIGNMENT / sizeof(uint64_t); ++i) {
- __builtin_nontemporal_store(from[i], to + i);
- }
-#else
- hwy::CopyBytes(from, to, HWY_ALIGNMENT);
-#endif
+ run_active_.clear(std::memory_order_release);
}
- // Write packet to buffer/storage, emptying them as needed.
- void Write(const Packet packet) {
- // Buffer full => copy to storage.
- if (buffer_size_ == kBufferCapacity) {
- // Storage full => empty it.
- if (num_packets_ + kBufferCapacity > max_packets_) {
- results_.AnalyzePackets(packets_.get(), num_packets_);
- num_packets_ = 0;
- }
- // This buffering halves observer overhead and decreases the overall
- // runtime by about 3%. Casting is safe because the first member is u64.
- StreamCacheLine(
- reinterpret_cast<const uint64_t*>(buffer_),
- reinterpret_cast<uint64_t*>(packets_.get() + num_packets_));
- num_packets_ += kBufferCapacity;
- buffer_size_ = 0;
- }
- buffer_[buffer_size_] = packet;
- ++buffer_size_;
- }
-
- // Write-combining buffer to avoid cache pollution. Must be the first
- // non-static member to ensure cache-line alignment.
- Packet buffer_[kBufferCapacity];
- size_t buffer_size_ = 0;
-
- const size_t max_packets_;
- // Contiguous storage for zone enter/exit packets.
- AlignedFreeUniquePtr<Packet[]> packets_;
- size_t num_packets_;
- // Cached here because we already read this cache line on zone entry/exit.
- const char* string_origin_;
- Results results_;
-};
+ // Prints results. Call from main thread after all threads have exited all
+ // zones. Resets all state, can be called again after more zones.
+ void PrintResults() {
+ UpdateResults();
+    // `CountWorkersAndReset` is fused into `PrintAndReset`, so do not call it
+    // here.
-class ThreadList {
- public:
- // Called from any thread.
- ThreadSpecific* Add(const char* name) {
- const size_t index = num_threads_.fetch_add(1, std::memory_order_relaxed);
- HWY_DASSERT(index < kMaxThreads);
+ results_.PrintAndReset(zones_);
- ThreadSpecific* ts = MakeUniqueAligned<ThreadSpecific>(name).release();
- threads_[index].store(ts, std::memory_order_release);
- return ts;
+ funcs_.CallAll();
}
- // Single-threaded.
- void PrintResults() {
- const auto acq = std::memory_order_acquire;
- const size_t num_threads = num_threads_.load(acq);
+ // TODO: remove when no longer called.
+ void SetMaxThreads(size_t) {}
- ThreadSpecific* main = threads_[0].load(acq);
- main->AnalyzeRemainingPackets();
+ private:
+ // Sets main thread index, computes self-overhead, and checks timer support.
+ Profiler();
- for (size_t i = 1; i < num_threads; ++i) {
- ThreadSpecific* ts = threads_[i].load(acq);
- ts->AnalyzeRemainingPackets();
- main->GetResults().Assimilate(ts->GetResults());
- }
+ profiler::PerWorker& GetWorker(size_t global_idx) {
+ HWY_DASSERT(workers_reserved_.Get(global_idx));
+ return workers_[global_idx];
+ }
- if (num_threads != 0) {
- main->GetResults().Print();
- }
+ // Moves accumulators into Results. Called from the main thread.
+ void UpdateResults() {
+ // Ensure we see all writes from before the workers' release fence.
+ std::atomic_thread_fence(std::memory_order_acquire);
+
+ workers_reserved_.Foreach([&](size_t global_idx) {
+ workers_[global_idx].MoveTo(global_idx, accumulators_, results_);
+ });
}
- private:
- // Owning pointers.
- alignas(64) std::atomic<ThreadSpecific*> threads_[kMaxThreads];
- std::atomic<size_t> num_threads_{0};
+ static thread_local size_t s_global_idx;
+
+  // Atomic because `ThreadFunc` reserves its slot(s) and even
+  // `ThreadPool::ThreadPool` may be called concurrently. Bit `i` is set
+  // between calls to `ReserveWorker(i)` and `FreeWorker(i)`. The set is
+  // consulted in `UpdateResults` and to validate arguments in debug builds,
+  // and is only updated during pool/thread init/shutdown.
+ profiler::AtomicWorkerSet workers_reserved_;
+
+ std::atomic_flag run_active_ = ATOMIC_FLAG_INIT;
+
+ profiler::Funcs funcs_;
+
+ // To avoid locking, each worker has its own working set. We could access this
+ // through `thread_local` pointers, but that is slow to read on x86. Because
+ // our `ThreadPool` anyway passes a `global_idx` argument, we can instead pass
+ // that through the `PROFILER_ZONE2/PROFILER_ZONE3` macros.
+ profiler::PerWorker workers_[profiler::kMaxWorkers];
+
+ profiler::Accumulators accumulators_;
+
+ profiler::Results results_;
+
+ profiler::Zones zones_;
};
-// RAII zone enter/exit recorder constructed by the ZONE macro; also
-// responsible for initializing ThreadSpecific.
+namespace profiler {
+
+// RAII for zone entry/exit.
class Zone {
public:
- // "name" must be a string literal (see StringOrigin::Get).
- HWY_NOINLINE explicit Zone(const char* name) {
+ // Thread-compatible; must not call concurrently with the same `global_idx`,
+ // which is either:
+ // - passed from `ThreadPool::Run` (if it was constructed with non-default
+ // `PoolWorkerMapping`) to `PROFILER_ZONE2/PROFILER_ZONE3`;
+ // - obtained from `Profiler::GlobalIdx()`; or
+ // - 0 if running on the main thread.
+ Zone(Profiler& profiler, size_t global_idx, ZoneHandle zone)
+ : profiler_(profiler) {
HWY_FENCE;
- ThreadSpecific* HWY_RESTRICT thread_specific = StaticThreadSpecific();
- if (HWY_UNLIKELY(thread_specific == nullptr)) {
- // Ensure the CPU supports our timer.
- char cpu[100];
- if (!platform::HaveTimerStop(cpu)) {
- HWY_ABORT("CPU %s is too old for PROFILER_ENABLED=1, exiting", cpu);
- }
-
- thread_specific = StaticThreadSpecific() = Threads().Add(name);
- // Must happen after setting StaticThreadSpecific, because ComputeOverhead
- // also calls Zone().
- thread_specific->ComputeOverhead();
- }
-
- // (Capture timestamp ASAP, not inside WriteEntry.)
+ const uint64_t t_enter = timer::Start();
+ HWY_FENCE;
+ global_idx_ = static_cast<uint32_t>(global_idx);
+ zone_ = zone;
+ profiler.Enter(t_enter, global_idx);
HWY_FENCE;
- const uint64_t timestamp = timer::Start();
- thread_specific->WriteEntry(name, timestamp);
}
- HWY_NOINLINE ~Zone() {
+ ~Zone() {
HWY_FENCE;
- const uint64_t timestamp = timer::Stop();
- StaticThreadSpecific()->WriteExit(timestamp);
+ const uint64_t t_exit = timer::Stop();
+ profiler_.Exit(t_exit, static_cast<size_t>(global_idx_), zone_);
HWY_FENCE;
}
- // Call exactly once after all threads have exited all zones.
- static void PrintResults() { Threads().PrintResults(); }
-
private:
- // Returns reference to the thread's ThreadSpecific pointer (initially null).
- // Function-local static avoids needing a separate definition.
- static ThreadSpecific*& StaticThreadSpecific() {
- static thread_local ThreadSpecific* thread_specific;
- return thread_specific;
- }
+ Profiler& profiler_;
+ uint32_t global_idx_;
+ ZoneHandle zone_;
+};
- // Returns the singleton ThreadList. Non time-critical.
- static ThreadList& Threads() {
- static ThreadList threads_;
- return threads_;
+} // namespace profiler
+#else // profiler disabled: stub implementation
+
+namespace profiler {
+struct ZoneHandle {};
+} // namespace profiler
+
+struct Profiler {
+ static HWY_DLLEXPORT Profiler& Get();
+
+ // DEPRECATED: use `GlobalIdx` instead.
+ static size_t Thread() { return 0; }
+ static size_t GlobalIdx() { return 0; }
+ static void SetGlobalIdx(size_t) {}
+ void ReserveWorker(size_t) {}
+ void FreeWorker(size_t) {}
+ void Enter(uint64_t, size_t) {}
+ void Exit(uint64_t, size_t, profiler::ZoneHandle) {}
+ uint64_t GetFirstDurationAndReset(size_t) { return 0; }
+
+ const char* Name(profiler::ZoneHandle) const { return nullptr; }
+ profiler::ZoneHandle AddZone(const char*,
+ ProfilerFlags = ProfilerFlags::kDefault) {
+ return profiler::ZoneHandle();
}
-};
-// Creates a zone starting from here until the end of the current scope.
-// Timestamps will be recorded when entering and exiting the zone.
-// "name" must be a string literal, which is ensured by merging with "".
-#define PROFILER_ZONE(name) \
- HWY_FENCE; \
- const hwy::Zone zone("" name); \
- HWY_FENCE
+ void AddFunc(void*, ProfilerFunc) {}
+ void RemoveFunc(void*) {}
-// Creates a zone for an entire function (when placed at its beginning).
-// Shorter/more convenient than ZONE.
-#define PROFILER_FUNC \
- HWY_FENCE; \
- const hwy::Zone zone(__func__); \
- HWY_FENCE
+ bool IsRootRun() { return false; }
+ void EndRootRun() {}
+ void PrintResults() {}
-#define PROFILER_PRINT_RESULTS hwy::Zone::PrintResults
-
-inline void ThreadSpecific::ComputeOverhead() {
- // Delay after capturing timestamps before/after the actual zone runs. Even
- // with frequency throttling disabled, this has a multimodal distribution,
- // including 32, 34, 48, 52, 59, 62.
- uint64_t self_overhead;
- {
- const size_t kNumSamples = 32;
- uint32_t samples[kNumSamples];
- for (size_t idx_sample = 0; idx_sample < kNumSamples; ++idx_sample) {
- const size_t kNumDurations = 1024;
- uint32_t durations[kNumDurations];
-
- for (size_t idx_duration = 0; idx_duration < kNumDurations;
- ++idx_duration) {
- {
- PROFILER_ZONE("Dummy Zone (never shown)");
- }
- const uint64_t duration = results_.ZoneDuration(buffer_);
- buffer_size_ = 0;
- durations[idx_duration] = static_cast<uint32_t>(duration);
- HWY_DASSERT(num_packets_ == 0);
- }
- robust_statistics::CountingSort(durations, kNumDurations);
- samples[idx_sample] = robust_statistics::Mode(durations, kNumDurations);
- }
- // Median.
- robust_statistics::CountingSort(samples, kNumSamples);
- self_overhead = samples[kNumSamples / 2];
- if (PROFILER_PRINT_OVERHEAD) {
- printf("Overhead: %.0f\n", static_cast<double>(self_overhead));
- }
- results_.SetSelfOverhead(self_overhead);
- }
-
- // Delay before capturing start timestamp / after end timestamp.
- const size_t kNumSamples = 32;
- uint32_t samples[kNumSamples];
- for (size_t idx_sample = 0; idx_sample < kNumSamples; ++idx_sample) {
- const size_t kNumDurations = 16;
- uint32_t durations[kNumDurations];
- for (size_t idx_duration = 0; idx_duration < kNumDurations;
- ++idx_duration) {
- const size_t kReps = 10000;
- // Analysis time should not be included => must fit within buffer.
- HWY_DASSERT(kReps * 2 < max_packets_);
- std::atomic_thread_fence(std::memory_order_seq_cst);
- const uint64_t t0 = timer::Start();
- for (size_t i = 0; i < kReps; ++i) {
- PROFILER_ZONE("Dummy");
- }
- FlushStream();
- const uint64_t t1 = timer::Stop();
- HWY_DASSERT(num_packets_ + buffer_size_ == kReps * 2);
- buffer_size_ = 0;
- num_packets_ = 0;
- const uint64_t avg_duration = (t1 - t0 + kReps / 2) / kReps;
- durations[idx_duration] =
- static_cast<uint32_t>(ClampedSubtract(avg_duration, self_overhead));
- }
- robust_statistics::CountingSort(durations, kNumDurations);
- samples[idx_sample] = robust_statistics::Mode(durations, kNumDurations);
- }
- robust_statistics::CountingSort(samples, kNumSamples);
- const uint64_t child_overhead = samples[9 * kNumSamples / 10];
- if (PROFILER_PRINT_OVERHEAD) {
- printf("Child overhead: %.0f\n", static_cast<double>(child_overhead));
- }
- results_.SetChildOverhead(child_overhead);
-}
+ // TODO: remove when no longer called.
+ void SetMaxThreads(size_t) {}
+};
-#pragma pack(pop)
+namespace profiler {
+struct Zone {
+ Zone(Profiler&, size_t, ZoneHandle) {}
+};
+
+} // namespace profiler
+#endif  // PROFILER_ENABLED
} // namespace hwy
-#endif // PROFILER_ENABLED || HWY_IDE
+// Creates a `Zone` lvalue with a line-dependent name, which records the elapsed
+// time from here until the end of the current scope. `p` is from
+// `Profiler::Get()` or a cached reference. `global_idx < kMaxWorkers`. `zone`
+// is the return value of `AddZone`. Separating its static init from the `Zone`
+// may be more efficient than `PROFILER_ZONE2`.
+#define PROFILER_ZONE3(p, global_idx, zone) \
+ HWY_FENCE; \
+ const hwy::profiler::Zone HWY_CONCAT(Z, __LINE__)(p, global_idx, zone); \
+ HWY_FENCE
-#if !PROFILER_ENABLED && !HWY_IDE
-#define PROFILER_ZONE(name)
-#define PROFILER_FUNC
-#define PROFILER_PRINT_RESULTS()
-#endif
+// For compatibility with old callers that pass neither `p` nor `flags`; also
+// calls `AddZone`. Usage: `PROFILER_ZONE2(global_idx, "MyZone");`
+#define PROFILER_ZONE2(global_idx, name) \
+ static const hwy::profiler::ZoneHandle HWY_CONCAT(zone, __LINE__) = \
+ hwy::Profiler::Get().AddZone(name); \
+ PROFILER_ZONE3(hwy::Profiler::Get(), global_idx, HWY_CONCAT(zone, __LINE__))
+#define PROFILER_FUNC2(global_idx) PROFILER_ZONE2(global_idx, __func__)
+
+// OBSOLETE: it is more efficient to pass `global_idx` from `ThreadPool` to
+// `PROFILER_ZONE2/PROFILER_ZONE3`. Here we get it from thread_local storage.
+#define PROFILER_ZONE(name) PROFILER_ZONE2(hwy::Profiler::GlobalIdx(), name)
+#define PROFILER_FUNC PROFILER_FUNC2(hwy::Profiler::GlobalIdx())
+
+// DEPRECATED: Use `hwy::Profiler::Get()` directly instead.
+#define PROFILER_ADD_ZONE(name) hwy::Profiler::Get().AddZone(name)
+#define PROFILER_IS_ROOT_RUN() hwy::Profiler::Get().IsRootRun()
+#define PROFILER_END_ROOT_RUN() hwy::Profiler::Get().EndRootRun()
+#define PROFILER_PRINT_RESULTS() hwy::Profiler::Get().PrintResults()
#endif // HIGHWAY_HWY_PROFILER_H_
diff --git a/third_party/highway/hwy/robust_statistics.h b/third_party/highway/hwy/robust_statistics.h
index 5391cf5951..22689c4817 100644
--- a/third_party/highway/hwy/robust_statistics.h
+++ b/third_party/highway/hwy/robust_statistics.h
@@ -123,8 +123,10 @@ T Median(T* values, const size_t num_values) {
if (num_values % 2) {
return values[half];
}
+ // For integers, round rather than truncate.
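+  // e.g. Median of {2, 3} returns 3 for an integer T, but 2.5 for float T.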
+ const T bias = hwy::IsInteger<T>() ? T{1} : T{0};
// Even count: return average of middle two.
- return (values[half] + values[half - 1] + 1) / 2;
+ return (values[half] + values[half - 1] + bias) / 2;
}
// Returns a robust measure of variability.
diff --git a/third_party/highway/hwy/stats.cc b/third_party/highway/hwy/stats.cc
new file mode 100644
index 0000000000..6f660a45f0
--- /dev/null
+++ b/third_party/highway/hwy/stats.cc
@@ -0,0 +1,115 @@
+// Copyright 2024 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "third_party/highway/hwy/stats.h"
+
+#include <stdio.h>
+
+#include <algorithm> // std::min
+#include <string>
+
+#include "third_party/highway/hwy/base.h" // HWY_ASSERT
+
+namespace hwy {
+
+void Stats::Assimilate(const Stats& other) {
+ const int64_t total_n = n_ + other.n_;
+ if (total_n == 0) return; // Nothing to do; prevents div by zero.
+
+ min_ = std::min(min_, other.min_);
+ max_ = std::max(max_, other.max_);
+
+ sum_log_ += other.sum_log_;
+
+ const double product_n = n_ * other.n_;
+ const double n2 = n_ * n_;
+ const double other_n2 = other.n_ * other.n_;
+ const int64_t total_n2 = total_n * total_n;
+ const double total_n3 = static_cast<double>(total_n2) * total_n;
+ // Precompute reciprocal for speed - used at least twice.
+ const double inv_total_n = 1.0 / total_n;
+ const double inv_total_n2 = 1.0 / total_n2;
+
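+  // Pairwise merge of online moments (Chan et al. style): e.g. the second
+  // moment combines as M2 = M2_a + M2_b + delta^2 * n_a * n_b / total_n;
+  // the m3_/m4_ updates below are the analogous higher-moment terms.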
+ const double delta = other.m1_ - m1_;
+ const double delta2 = delta * delta;
+ const double delta3 = delta * delta2;
+ const double delta4 = delta2 * delta2;
+
+ m1_ = (n_ * m1_ + other.n_ * other.m1_) * inv_total_n;
+
+ const double new_m2 = m2_ + other.m2_ + delta2 * product_n * inv_total_n;
+
+ const double new_m3 =
+ m3_ + other.m3_ + delta3 * product_n * (n_ - other.n_) * inv_total_n2 +
+ 3.0 * delta * (n_ * other.m2_ - other.n_ * m2_) * inv_total_n;
+
+ m4_ += other.m4_ +
+ delta4 * product_n * (n2 - product_n + other_n2) / total_n3 +
+ 6.0 * delta2 * (n2 * other.m2_ + other_n2 * m2_) * inv_total_n2 +
+ 4.0 * delta * (n_ * other.m3_ - other.n_ * m3_) * inv_total_n;
+
+ m2_ = new_m2;
+ m3_ = new_m3;
+ n_ = total_n;
+}
+
+std::string Stats::ToString(int exclude) const {
+ if (Count() == 0) return std::string("(none)");
+
+ char buf[300];
+ size_t pos = 0;
+ int ret; // snprintf - bytes written or negative for error.
+
+ if ((exclude & kNoCount) == 0) {
+ ret = snprintf(buf + pos, sizeof(buf) - pos, "Count=%9zu ",
+ static_cast<size_t>(Count()));
+ HWY_ASSERT(ret > 0);
+ pos += ret;
+ }
+
+ if ((exclude & kNoMeanSD) == 0) {
+ const float sd = StandardDeviation();
+ ret = snprintf(buf + pos, sizeof(buf) - pos, "Mean=%10.3e SD=%8.2e ",
+ Mean(), sd);
+ HWY_ASSERT(ret > 0);
+ pos += ret;
+ }
+
+ if ((exclude & kNoMinMax) == 0) {
+ ret = snprintf(buf + pos, sizeof(buf) - pos, "Min=%10.3e Max=%10.3e ",
+ static_cast<double>(Min()), static_cast<double>(Max()));
+ HWY_ASSERT(ret > 0);
+ pos += ret;
+ }
+
+ if ((exclude & kNoSkewKurt) == 0) {
+ ret = snprintf(buf + pos, sizeof(buf) - pos, "Skew=%5.2f Kurt=%7.2f ",
+ Skewness(), Kurtosis());
+ HWY_ASSERT(ret > 0);
+ pos += ret;
+ }
+
+ if ((exclude & kNoGeomean) == 0) {
+ ret = snprintf(buf + pos, sizeof(buf) - pos, "GeoMean=%9.6f ",
+ GeometricMean());
+ HWY_ASSERT(ret > 0);
+ pos += ret;
+ }
+
+ HWY_ASSERT(pos < sizeof(buf));
+ return buf;
+}
+
+} // namespace hwy
diff --git a/third_party/highway/hwy/stats.h b/third_party/highway/hwy/stats.h
index b4b95719fb..92200ddf18 100644
--- a/third_party/highway/hwy/stats.h
+++ b/third_party/highway/hwy/stats.h
@@ -38,23 +38,70 @@ class Bins {
counts_[static_cast<int32_t>(bin)]++;
}
+ uint32_t Bin(size_t bin_idx) const {
+ HWY_DASSERT(bin_idx < N);
+ return counts_[bin_idx];
+ }
+
+ void ResetBin(size_t bin_idx) {
+ HWY_DASSERT(bin_idx < N);
+ counts_[bin_idx] = 0;
+ }
+
void Assimilate(const Bins<N>& other) {
for (size_t i = 0; i < N; ++i) {
counts_[i] += other.counts_[i];
}
}
- void Print(const char* caption) const {
- fprintf(stderr, "\n%s [%zu]\n", caption, N);
- size_t last_nonzero = 0;
+ size_t FirstNonzero() const {
+ for (size_t i = 0; i < N; ++i) {
+ if (counts_[i] != 0) return i;
+ }
+ return N;
+ }
+
+ size_t LastNonzero() const {
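+    // Counts downward; unsigned wraparound of `i` past 0 ends the loop.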
for (size_t i = N - 1; i < N; --i) {
- if (counts_[i] != 0) {
- last_nonzero = i;
- break;
+ if (counts_[i] != 0) return i;
+ }
+ return 0;
+ }
+
+ size_t NumNonzero() const {
+ size_t num_nonzero = 0;
+ for (size_t i = 0; i < N; ++i) {
+ num_nonzero += (counts_[i] != 0);
+ }
+ return num_nonzero;
+ }
+
+ size_t ModalBinIdx() const {
+ size_t max = 0;
+ size_t idx_max = 0;
+ for (size_t i = 0; i < N; ++i) {
+ if (counts_[i] > max) {
+ max = counts_[i];
+ idx_max = i;
}
}
- for (size_t i = 0; i <= last_nonzero; ++i) {
- fprintf(stderr, " %zu\n", counts_[i]);
+ return idx_max;
+ }
+
+ void Print(const char* caption, bool skip_zero = false) const {
+ fprintf(stderr, "\n%s [%zu, modal idx %zu]\n", caption, N, ModalBinIdx());
+ const size_t first_nonzero = FirstNonzero();
+ const size_t last_nonzero = LastNonzero();
+ if (skip_zero) {
+ for (size_t i = first_nonzero; i <= last_nonzero; ++i) {
+ if (counts_[i] != 0) {
+          fprintf(stderr, "  %3zu: %u\n", i, counts_[i]);
+ }
+ }
+ } else {
+ for (size_t i = first_nonzero; i <= last_nonzero; ++i) {
+        fprintf(stderr, "  %3zu: %u\n", i, counts_[i]);
+ }
}
}
@@ -65,7 +112,7 @@ class Bins {
}
private:
- size_t counts_[N];
+ uint32_t counts_[N];
};
// Descriptive statistics of a variable (4 moments). Thread-compatible.
@@ -82,13 +129,16 @@ class Stats {
// Logarithmic transform avoids/delays underflow and overflow.
sum_log_ += std::log(static_cast<double>(x));
- // Online moments. Reference: https://goo.gl/9ha694
+ // Online moments. Reference:
+ // https://www.thinkbrg.com/media/publication/720_McCrary_ImplementingAlgorithms_Whitepaper_20151119_WEB.pdf
const double d = x - m1_;
const double d_div_n = d / static_cast<double>(n_);
const double d2n1_div_n = d * (static_cast<double>(n_) - 1) * d_div_n;
const int64_t n_poly = n_ * n_ - 3 * n_ + 3;
m1_ += d_div_n;
- m4_ += d_div_n * (d_div_n * (d2n1_div_n * static_cast<double>(n_poly) + 6.0 * m2_) - 4.0 * m3_);
+ m4_ += d_div_n *
+ (d_div_n * (d2n1_div_n * static_cast<double>(n_poly) + 6.0 * m2_) -
+ 4.0 * m3_);
m3_ += d_div_n * (d2n1_div_n * (static_cast<double>(n_) - 2) - 3.0 * m2_);
m2_ += d2n1_div_n;
}
diff --git a/third_party/highway/hwy/targets.cc b/third_party/highway/hwy/targets.cc
new file mode 100644
index 0000000000..6e2860714c
--- /dev/null
+++ b/third_party/highway/hwy/targets.cc
@@ -0,0 +1,835 @@
+// Copyright 2019 Google LLC
+// Copyright 2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "third_party/highway/hwy/targets.h"
+
+#include <stdint.h>
+#include <stdio.h>
+
+#include "third_party/highway/hwy/base.h"
+#include "third_party/highway/hwy/detect_targets.h"
+#include "third_party/highway/hwy/highway.h"
+#include "third_party/highway/hwy/x86_cpuid.h"
+
+#if HWY_ARCH_X86
+#include <xmmintrin.h>
+
+#elif (HWY_ARCH_ARM || HWY_ARCH_PPC || HWY_ARCH_S390X || HWY_ARCH_RISCV || \
+ HWY_ARCH_LOONGARCH) && \
+ HWY_OS_LINUX
+// sys/auxv.h does not always include asm/hwcap.h, or define HWCAP*, hence we
+// still include this directly. See #1199.
+#if HWY_HAVE_ASM_HWCAP
+#include <asm/hwcap.h>
+#endif
+#if HWY_HAVE_AUXV
+#include <sys/auxv.h>
+#endif
+
+#endif // HWY_ARCH_*
+
+#if HWY_OS_APPLE
+#include <sys/sysctl.h>
+#include <sys/utsname.h>
+#endif // HWY_OS_APPLE
+
+namespace hwy {
+
+#if HWY_OS_APPLE
+static HWY_INLINE HWY_MAYBE_UNUSED bool HasCpuFeature(
+ const char* feature_name) {
+ int result = 0;
+ size_t len = sizeof(int);
+ return (sysctlbyname(feature_name, &result, &len, nullptr, 0) == 0 &&
+ result != 0);
+}
+
+static HWY_INLINE HWY_MAYBE_UNUSED bool ParseU32(const char*& ptr,
+ uint32_t& parsed_val) {
+ uint64_t parsed_u64 = 0;
+
+ const char* start_ptr = ptr;
+ for (char ch; (ch = (*ptr)) != '\0'; ++ptr) {
+ unsigned digit = static_cast<unsigned>(static_cast<unsigned char>(ch)) -
+ static_cast<unsigned>(static_cast<unsigned char>('0'));
+ if (digit > 9u) {
+ break;
+ }
+
+ parsed_u64 = (parsed_u64 * 10u) + digit;
+ if (parsed_u64 > 0xFFFFFFFFu) {
+ return false;
+ }
+ }
+
+ parsed_val = static_cast<uint32_t>(parsed_u64);
+ return (ptr != start_ptr);
+}
+
+static HWY_INLINE HWY_MAYBE_UNUSED bool IsMacOs12_2OrLater() {
+ utsname uname_buf;
+ ZeroBytes(&uname_buf, sizeof(utsname));
+
+ if ((uname(&uname_buf)) != 0) {
+ return false;
+ }
+
+ const char* ptr = uname_buf.release;
+ if (!ptr) {
+ return false;
+ }
+
+ uint32_t major;
+ uint32_t minor;
+ if (!ParseU32(ptr, major)) {
+ return false;
+ }
+
+ if (*ptr != '.') {
+ return false;
+ }
+
+ ++ptr;
+ if (!ParseU32(ptr, minor)) {
+ return false;
+ }
+
+ // We are running on macOS 12.2 or later if the Darwin kernel version is 21.3
+ // or later
+ return (major > 21 || (major == 21 && minor >= 3));
+}
+#endif // HWY_OS_APPLE
+
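Note: a hedged sketch of the version check implemented by the helpers above: ParseU32 consumes a run of digits through a by-reference pointer, so after reading the major version the caller can test for the '.' separator and continue with the minor version. Darwin 21.3 corresponds to macOS 12.2; the release string below is hypothetical.

#include <stdint.h>
#include <stdio.h>

static bool ParseU32(const char*& ptr, uint32_t& out) {
  uint64_t v = 0;
  const char* start = ptr;
  for (; *ptr >= '0' && *ptr <= '9'; ++ptr) {
    v = v * 10 + static_cast<uint64_t>(*ptr - '0');
    if (v > 0xFFFFFFFFu) return false;  // would overflow uint32_t
  }
  out = static_cast<uint32_t>(v);
  return ptr != start;  // false if no digits were consumed
}

int main() {
  const char* p = "21.3.0";  // hypothetical uname().release value
  uint32_t major, minor;
  if (!ParseU32(p, major) || *p != '.') return 1;
  ++p;  // skip '.'
  if (!ParseU32(p, minor)) return 1;
  printf("Darwin %u.%u -> macOS 12.2 or later: %d\n", major, minor,
         major > 21 || (major == 21 && minor >= 3));
  return 0;
}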
+#if HWY_ARCH_X86 && HWY_HAVE_RUNTIME_DISPATCH
+namespace x86 {
+
+// Returns the lower 32 bits of extended control register 0.
+// Requires CPU support for "OSXSAVE" (see below).
+static uint32_t ReadXCR0() {
+#if HWY_COMPILER_MSVC
+ return static_cast<uint32_t>(_xgetbv(0));
+#else // HWY_COMPILER_MSVC
+ uint32_t xcr0, xcr0_high;
+ const uint32_t index = 0;
+ asm volatile(".byte 0x0F, 0x01, 0xD0"
+ : "=a"(xcr0), "=d"(xcr0_high)
+ : "c"(index));
+ return xcr0;
+#endif // HWY_COMPILER_MSVC
+}
+
+// Arbitrary bit indices indicating which instruction set extensions are
+// supported. Use enum to ensure values are distinct.
+enum class FeatureIndex : uint32_t {
+ kSSE = 0,
+ kSSE2,
+ kSSE3,
+ kSSSE3,
+
+ kSSE41,
+ kSSE42,
+ kCLMUL,
+ kAES,
+
+ kAVX,
+ kAVX2,
+ kF16C,
+ kFMA,
+ kLZCNT,
+ kBMI,
+ kBMI2,
+
+ kAVX512F,
+ kAVX512VL,
+ kAVX512CD,
+ kAVX512DQ,
+ kAVX512BW,
+ kAVX512FP16,
+ kAVX512BF16,
+
+ kVNNI,
+ kVPCLMULQDQ,
+ kVBMI,
+ kVBMI2,
+ kVAES,
+ kPOPCNTDQ,
+ kBITALG,
+ kGFNI,
+
+ kAVX10,
+ kAPX,
+
+ kSentinel
+};
+static_assert(static_cast<size_t>(FeatureIndex::kSentinel) < 64,
+ "Too many bits for u64");
+
+static HWY_INLINE constexpr uint64_t Bit(FeatureIndex index) {
+ return 1ull << static_cast<size_t>(index);
+}
+
+// Returns bit array of FeatureIndex from CPUID feature flags.
+static uint64_t FlagsFromCPUID() {
+ uint64_t flags = 0; // return value
+ uint32_t abcd[4];
+ Cpuid(0, 0, abcd);
+ const uint32_t max_level = abcd[0];
+
+ // Standard feature flags
+ Cpuid(1, 0, abcd);
+ flags |= IsBitSet(abcd[3], 25) ? Bit(FeatureIndex::kSSE) : 0;
+ flags |= IsBitSet(abcd[3], 26) ? Bit(FeatureIndex::kSSE2) : 0;
+ flags |= IsBitSet(abcd[2], 0) ? Bit(FeatureIndex::kSSE3) : 0;
+ flags |= IsBitSet(abcd[2], 1) ? Bit(FeatureIndex::kCLMUL) : 0;
+ flags |= IsBitSet(abcd[2], 9) ? Bit(FeatureIndex::kSSSE3) : 0;
+ flags |= IsBitSet(abcd[2], 12) ? Bit(FeatureIndex::kFMA) : 0;
+ flags |= IsBitSet(abcd[2], 19) ? Bit(FeatureIndex::kSSE41) : 0;
+ flags |= IsBitSet(abcd[2], 20) ? Bit(FeatureIndex::kSSE42) : 0;
+ flags |= IsBitSet(abcd[2], 25) ? Bit(FeatureIndex::kAES) : 0;
+ flags |= IsBitSet(abcd[2], 28) ? Bit(FeatureIndex::kAVX) : 0;
+ flags |= IsBitSet(abcd[2], 29) ? Bit(FeatureIndex::kF16C) : 0;
+
+ // Extended feature flags
+ Cpuid(0x80000001U, 0, abcd);
+ flags |= IsBitSet(abcd[2], 5) ? Bit(FeatureIndex::kLZCNT) : 0;
+
+ // Extended features
+ if (max_level >= 7) {
+ Cpuid(7, 0, abcd);
+ flags |= IsBitSet(abcd[1], 3) ? Bit(FeatureIndex::kBMI) : 0;
+ flags |= IsBitSet(abcd[1], 5) ? Bit(FeatureIndex::kAVX2) : 0;
+ flags |= IsBitSet(abcd[1], 8) ? Bit(FeatureIndex::kBMI2) : 0;
+
+ flags |= IsBitSet(abcd[1], 16) ? Bit(FeatureIndex::kAVX512F) : 0;
+ flags |= IsBitSet(abcd[1], 17) ? Bit(FeatureIndex::kAVX512DQ) : 0;
+ flags |= IsBitSet(abcd[1], 28) ? Bit(FeatureIndex::kAVX512CD) : 0;
+ flags |= IsBitSet(abcd[1], 30) ? Bit(FeatureIndex::kAVX512BW) : 0;
+ flags |= IsBitSet(abcd[1], 31) ? Bit(FeatureIndex::kAVX512VL) : 0;
+
+ flags |= IsBitSet(abcd[2], 1) ? Bit(FeatureIndex::kVBMI) : 0;
+ flags |= IsBitSet(abcd[2], 6) ? Bit(FeatureIndex::kVBMI2) : 0;
+ flags |= IsBitSet(abcd[2], 8) ? Bit(FeatureIndex::kGFNI) : 0;
+ flags |= IsBitSet(abcd[2], 9) ? Bit(FeatureIndex::kVAES) : 0;
+ flags |= IsBitSet(abcd[2], 10) ? Bit(FeatureIndex::kVPCLMULQDQ) : 0;
+ flags |= IsBitSet(abcd[2], 11) ? Bit(FeatureIndex::kVNNI) : 0;
+ flags |= IsBitSet(abcd[2], 12) ? Bit(FeatureIndex::kBITALG) : 0;
+ flags |= IsBitSet(abcd[2], 14) ? Bit(FeatureIndex::kPOPCNTDQ) : 0;
+
+ flags |= IsBitSet(abcd[3], 23) ? Bit(FeatureIndex::kAVX512FP16) : 0;
+
+ Cpuid(7, 1, abcd);
+ flags |= IsBitSet(abcd[0], 5) ? Bit(FeatureIndex::kAVX512BF16) : 0;
+ flags |= IsBitSet(abcd[3], 19) ? Bit(FeatureIndex::kAVX10) : 0;
+ flags |= IsBitSet(abcd[3], 21) ? Bit(FeatureIndex::kAPX) : 0;
+ }
+
+ return flags;
+}
+
+// Each Highway target requires a 'group' of multiple features/flags.
+static constexpr uint64_t kGroupSSE2 =
+ Bit(FeatureIndex::kSSE) | Bit(FeatureIndex::kSSE2);
+
+static constexpr uint64_t kGroupSSSE3 =
+ Bit(FeatureIndex::kSSE3) | Bit(FeatureIndex::kSSSE3) | kGroupSSE2;
+
+#ifdef HWY_DISABLE_PCLMUL_AES
+static constexpr uint64_t kGroupSSE4 =
+ Bit(FeatureIndex::kSSE41) | Bit(FeatureIndex::kSSE42) | kGroupSSSE3;
+#else
+static constexpr uint64_t kGroupSSE4 =
+ Bit(FeatureIndex::kSSE41) | Bit(FeatureIndex::kSSE42) |
+ Bit(FeatureIndex::kCLMUL) | Bit(FeatureIndex::kAES) | kGroupSSSE3;
+#endif // HWY_DISABLE_PCLMUL_AES
+
+// We normally assume BMI/BMI2/FMA are available if AVX2 is. This allows us to
+// use BZHI and (compiler-generated) MULX. However, VirtualBox lacks them
+// [https://www.virtualbox.org/ticket/15471]. Thus we provide the option of
+// neither using nor requiring them, so that AVX2 can still be used.
+#ifdef HWY_DISABLE_BMI2_FMA
+static constexpr uint64_t kGroupBMI2_FMA = 0;
+#else
+static constexpr uint64_t kGroupBMI2_FMA = Bit(FeatureIndex::kBMI) |
+ Bit(FeatureIndex::kBMI2) |
+ Bit(FeatureIndex::kFMA);
+#endif
+
+#ifdef HWY_DISABLE_F16C
+static constexpr uint64_t kGroupF16C = 0;
+#else
+static constexpr uint64_t kGroupF16C = Bit(FeatureIndex::kF16C);
+#endif
+
+static constexpr uint64_t kGroupAVX2 =
+ Bit(FeatureIndex::kAVX) | Bit(FeatureIndex::kAVX2) |
+ Bit(FeatureIndex::kLZCNT) | kGroupBMI2_FMA | kGroupF16C | kGroupSSE4;
+
+static constexpr uint64_t kGroupAVX3 =
+ Bit(FeatureIndex::kAVX512F) | Bit(FeatureIndex::kAVX512VL) |
+ Bit(FeatureIndex::kAVX512DQ) | Bit(FeatureIndex::kAVX512BW) |
+ Bit(FeatureIndex::kAVX512CD) | kGroupAVX2;
+
+static constexpr uint64_t kGroupAVX3_DL =
+ Bit(FeatureIndex::kVNNI) | Bit(FeatureIndex::kVPCLMULQDQ) |
+ Bit(FeatureIndex::kVBMI) | Bit(FeatureIndex::kVBMI2) |
+ Bit(FeatureIndex::kVAES) | Bit(FeatureIndex::kPOPCNTDQ) |
+ Bit(FeatureIndex::kBITALG) | Bit(FeatureIndex::kGFNI) | kGroupAVX3;
+
+static constexpr uint64_t kGroupAVX3_ZEN4 =
+ Bit(FeatureIndex::kAVX512BF16) | kGroupAVX3_DL;
+
+static constexpr uint64_t kGroupAVX3_SPR =
+ Bit(FeatureIndex::kAVX512FP16) | kGroupAVX3_ZEN4;
+
+static constexpr uint64_t kGroupAVX10 =
+ Bit(FeatureIndex::kAVX10) | Bit(FeatureIndex::kAPX) |
+ Bit(FeatureIndex::kVPCLMULQDQ) | Bit(FeatureIndex::kVAES) |
+ Bit(FeatureIndex::kGFNI) | kGroupAVX2;
+
+static int64_t DetectTargets() {
+ int64_t bits = 0; // return value of supported targets.
+ HWY_IF_CONSTEXPR(HWY_ARCH_X86_64) {
+ bits |= HWY_SSE2; // always present in x64
+ }
+
+ const uint64_t flags = FlagsFromCPUID();
+ // Set target bit(s) if all their group's flags are all set.
+ if ((flags & kGroupAVX3_SPR) == kGroupAVX3_SPR) {
+ bits |= HWY_AVX3_SPR;
+ }
+ if ((flags & kGroupAVX3_DL) == kGroupAVX3_DL) {
+ bits |= HWY_AVX3_DL;
+ }
+ if ((flags & kGroupAVX3) == kGroupAVX3) {
+ bits |= HWY_AVX3;
+ }
+ if ((flags & kGroupAVX2) == kGroupAVX2) {
+ bits |= HWY_AVX2;
+ }
+ if ((flags & kGroupSSE4) == kGroupSSE4) {
+ bits |= HWY_SSE4;
+ }
+ if ((flags & kGroupSSSE3) == kGroupSSSE3) {
+ bits |= HWY_SSSE3;
+ }
+ HWY_IF_CONSTEXPR(HWY_ARCH_X86_32) {
+ if ((flags & kGroupSSE2) == kGroupSSE2) {
+ bits |= HWY_SSE2;
+ }
+ }
+
+ uint32_t abcd[4];
+
+ if ((flags & kGroupAVX10) == kGroupAVX10) {
+ Cpuid(0x24, 0, abcd);
+
+ // AVX10 version is in lower 8 bits of abcd[1]
+ const uint32_t avx10_ver = abcd[1] & 0xFFu;
+
+    // 512-bit vectors are supported if avx10_ver >= 1 and bit 18 of abcd[1]
+    // is set
+ const bool has_avx10_with_512bit_vectors =
+ (avx10_ver >= 1) && IsBitSet(abcd[1], 18);
+
+ if (has_avx10_with_512bit_vectors) {
+ // AVX10.1 or later with support for 512-bit vectors implies support for
+ // the AVX3/AVX3_DL/AVX3_SPR targets
+ bits |= (HWY_AVX3_SPR | HWY_AVX3_DL | HWY_AVX3);
+
+ if (avx10_ver >= 2) {
+        // AVX10.2 or later is supported
+ bits |= HWY_AVX10_2;
+ }
+ }
+ }
+
+ // Clear AVX2/AVX3 bits if the CPU or OS does not support XSAVE - otherwise,
+ // YMM/ZMM registers are not preserved across context switches.
+
+ // The lower 128 bits of XMM0-XMM15 are guaranteed to be preserved across
+ // context switches on x86_64
+
+ // The following OS's are known to preserve the lower 128 bits of XMM
+ // registers across context switches on x86 CPUs that support SSE (even in
+ // 32-bit mode):
+ // - Windows 2000 or later
+ // - Linux 2.4.0 or later
+ // - Mac OS X 10.4 or later
+ // - FreeBSD 4.4 or later
+ // - NetBSD 1.6 or later
+ // - OpenBSD 3.5 or later
+ // - UnixWare 7 Release 7.1.1 or later
+ // - Solaris 9 4/04 or later
+
+ Cpuid(1, 0, abcd);
+ const bool has_xsave = IsBitSet(abcd[2], 26);
+ const bool has_osxsave = IsBitSet(abcd[2], 27);
+ constexpr int64_t min_avx2 = HWY_AVX2 | (HWY_AVX2 - 1);
+
+ if (has_xsave && has_osxsave) {
+#if HWY_OS_APPLE
+    // On macOS, check for AVX3 XSAVE support by verifying that we are
+    // running on macOS 12.2 or later and that
+    // HasCpuFeature("hw.optional.avx512f") returns true.
+
+    // macOS 12.1 and earlier have a bug that can cause ZMM16-ZMM31, the
+    // upper 256 bits of the ZMM registers, and K0-K7 (the AVX512 mask
+    // registers) to not be properly preserved across a context switch.
+
+    // This bug on x86_64 CPUs with AVX3 support is described at
+    // https://community.intel.com/t5/Software-Tuning-Performance/MacOS-Darwin-kernel-bug-clobbers-AVX-512-opmask-register-state/m-p/1327259,
+    // https://github.com/golang/go/issues/49233, and
+    // https://github.com/simdutf/simdutf/pull/236.
+
+    // In addition, bits 5, 6, and 7 of XCR0 can read as 0 on x86_64 CPUs
+    // with AVX3 support until the first AVX512 instruction is executed,
+    // because macOS only preserves ZMM16-ZMM31, the upper 256 bits of the
+    // ZMM registers, and K0-K7 across a context switch on threads that have
+    // executed an AVX512 instruction.
+
+    // Checking HasCpuFeature("hw.optional.avx512f") therefore avoids false
+    // negatives on x86_64 CPUs that do have AVX3 support.
+ const bool have_avx3_xsave_support =
+ IsMacOs12_2OrLater() && HasCpuFeature("hw.optional.avx512f");
+#endif
+
+ const uint32_t xcr0 = ReadXCR0();
+ constexpr int64_t min_avx3 = HWY_AVX3 | (HWY_AVX3 - 1);
+ // XMM/YMM
+ if (!IsBitSet(xcr0, 1) || !IsBitSet(xcr0, 2)) {
+ // Clear the AVX2/AVX3 bits if XMM/YMM XSAVE is not enabled
+ bits &= ~min_avx2;
+ }
+
+#if !HWY_OS_APPLE
+ // On OS's other than macOS, check for AVX3 XSAVE support by checking that
+ // bits 5, 6, and 7 of XCR0 are set.
+ const bool have_avx3_xsave_support =
+ IsBitSet(xcr0, 5) && IsBitSet(xcr0, 6) && IsBitSet(xcr0, 7);
+#endif
+
+ // opmask, ZMM lo/hi
+ if (!have_avx3_xsave_support) {
+ bits &= ~min_avx3;
+ }
+ } else { // !has_xsave || !has_osxsave
+ // Clear the AVX2/AVX3 bits if the CPU or OS does not support XSAVE
+ bits &= ~min_avx2;
+ }
+
+ // This is mainly to work around the slow Zen4 CompressStore. It's unclear
+ // whether subsequent AMD models will be affected; assume yes.
+ if ((bits & HWY_AVX3_DL) && (flags & kGroupAVX3_ZEN4) == kGroupAVX3_ZEN4 &&
+ IsAMD()) {
+ bits |= HWY_AVX3_ZEN4;
+ }
+
+ return bits;
+}
+
+} // namespace x86
+#elif HWY_ARCH_ARM && HWY_HAVE_RUNTIME_DISPATCH
+namespace arm {
+
+#if HWY_ARCH_ARM_A64 && !HWY_OS_APPLE && \
+ (HWY_COMPILER_GCC || HWY_COMPILER_CLANG) && \
+ ((HWY_TARGETS & HWY_ALL_SVE) != 0)
+HWY_PUSH_ATTRIBUTES("+sve")
+static int64_t DetectAdditionalSveTargets(int64_t detected_targets) {
+ uint64_t sve_vec_len;
+
+ // Use inline assembly instead of svcntb_pat(SV_ALL) as GCC or Clang might
+ // possibly optimize a svcntb_pat(SV_ALL) call to a constant if the
+ // -msve-vector-bits option is specified
+ asm("cntb %0" : "=r"(sve_vec_len)::);
+
+ return ((sve_vec_len == 32)
+ ? HWY_SVE_256
+ : (((detected_targets & HWY_SVE2) != 0 && sve_vec_len == 16)
+ ? HWY_SVE2_128
+ : 0));
+}
+HWY_POP_ATTRIBUTES
+#endif
+
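Note: the inline assembly above reads the SVE vector length directly. The hedged intrinsic equivalent, which the comment explains is deliberately avoided because -msve-vector-bits can constant-fold it, would be:

#include <arm_sve.h>  // requires an SVE-enabled target/attribute
uint64_t sve_vec_len = svcntb();  // bytes per SVE vector; may be constant-folded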
+static int64_t DetectTargets() {
+ int64_t bits = 0; // return value of supported targets.
+
+ using CapBits = unsigned long; // NOLINT
+#if HWY_OS_APPLE
+ const CapBits hw = 0UL;
+#else
+ // For Android, this has been supported since API 20 (2014).
+ const CapBits hw = getauxval(AT_HWCAP);
+#endif
+ (void)hw;
+
+#if HWY_ARCH_ARM_A64
+ bits |= HWY_NEON_WITHOUT_AES; // aarch64 always has NEON and VFPv4..
+
+#if HWY_OS_APPLE
+ if (HasCpuFeature("hw.optional.arm.FEAT_AES")) {
+ bits |= HWY_NEON;
+
+ // Some macOS versions report AdvSIMD_HPFPCvt under a different key.
+ // Check both known variants for compatibility.
+ if ((HasCpuFeature("hw.optional.AdvSIMD_HPFPCvt") ||
+ HasCpuFeature("hw.optional.arm.AdvSIMD_HPFPCvt")) &&
+ HasCpuFeature("hw.optional.arm.FEAT_DotProd") &&
+ HasCpuFeature("hw.optional.arm.FEAT_BF16") &&
+ HasCpuFeature("hw.optional.arm.FEAT_I8MM")) {
+ bits |= HWY_NEON_BF16;
+ }
+ }
+#else // !HWY_OS_APPLE
+ // .. but not necessarily AES, which is required for HWY_NEON.
+#if defined(HWCAP_AES)
+ if (hw & HWCAP_AES) {
+ bits |= HWY_NEON;
+
+#if defined(HWCAP_ASIMDHP) && defined(HWCAP_ASIMDDP) && defined(HWCAP2_BF16)
+ const CapBits hw2 = getauxval(AT_HWCAP2);
+ constexpr CapBits kGroupF16Dot = HWCAP_ASIMDHP | HWCAP_ASIMDDP;
+ constexpr CapBits kGroupBF16 = HWCAP2_BF16;
+ if ((hw & kGroupF16Dot) == kGroupF16Dot &&
+ (hw2 & kGroupBF16) == kGroupBF16) {
+ bits |= HWY_NEON_BF16;
+ }
+#endif // HWCAP_ASIMDHP && HWCAP_ASIMDDP && HWCAP2_BF16
+ }
+#endif // HWCAP_AES
+
+#if defined(HWCAP_SVE)
+ if (hw & HWCAP_SVE) {
+ bits |= HWY_SVE;
+ }
+#endif
+
+#ifndef HWCAP2_SVE2
+#define HWCAP2_SVE2 (1 << 1)
+#endif
+#ifndef HWCAP2_SVEAES
+#define HWCAP2_SVEAES (1 << 2)
+#endif
+#ifndef HWCAP2_SVEI8MM
+#define HWCAP2_SVEI8MM (1 << 9)
+#endif
+#ifndef HWCAP2_SVEBF16
+#define HWCAP2_SVEBF16 (1 << 12)
+#endif
+
+ constexpr CapBits kGroupSVE2 = HWCAP2_SVE2 | HWCAP2_SVEAES;
+ const CapBits hw2 = getauxval(AT_HWCAP2);
+ if ((hw2 & kGroupSVE2) == kGroupSVE2) {
+ bits |= HWY_SVE2;
+ }
+
+#if (HWY_COMPILER_GCC || HWY_COMPILER_CLANG) && \
+ ((HWY_TARGETS & HWY_ALL_SVE) != 0)
+ if ((bits & HWY_ALL_SVE) != 0) {
+ bits |= DetectAdditionalSveTargets(bits);
+
+ // SVE2_128 implies I8MM and BF16, hence remove it if they are not present.
+ constexpr CapBits kGroupSVE2_128 = HWCAP2_SVEI8MM | HWCAP2_SVEBF16;
+ if ((hw2 & kGroupSVE2_128) != kGroupSVE2_128) {
+ bits &= ~HWY_SVE2_128;
+ }
+ }
+#endif // (HWY_COMPILER_GCC || HWY_COMPILER_CLANG) &&
+ // ((HWY_TARGETS & HWY_ALL_SVE) != 0)
+
+#endif // HWY_OS_APPLE
+
+#else // !HWY_ARCH_ARM_A64
+
+// Some old auxv.h / hwcap.h do not define these. If not, treat as unsupported.
+#if defined(HWCAP_NEON) && defined(HWCAP_VFPv4)
+ if ((hw & HWCAP_NEON) && (hw & HWCAP_VFPv4)) {
+ bits |= HWY_NEON_WITHOUT_AES;
+ }
+#endif
+
+ // aarch32 would check getauxval(AT_HWCAP2) & HWCAP2_AES, but we do not yet
+ // support that platform, and Armv7 lacks AES entirely. Because HWY_NEON
+ // requires native AES instructions, we do not enable that target here.
+
+#endif // HWY_ARCH_ARM_A64
+ return bits;
+}
+} // namespace arm
+#elif HWY_ARCH_PPC && HWY_HAVE_RUNTIME_DISPATCH
+namespace ppc {
+
+#ifndef PPC_FEATURE_HAS_ALTIVEC
+#define PPC_FEATURE_HAS_ALTIVEC 0x10000000
+#endif
+
+#ifndef PPC_FEATURE_HAS_VSX
+#define PPC_FEATURE_HAS_VSX 0x00000080
+#endif
+
+#ifndef PPC_FEATURE2_ARCH_2_07
+#define PPC_FEATURE2_ARCH_2_07 0x80000000
+#endif
+
+#ifndef PPC_FEATURE2_VEC_CRYPTO
+#define PPC_FEATURE2_VEC_CRYPTO 0x02000000
+#endif
+
+#ifndef PPC_FEATURE2_ARCH_3_00
+#define PPC_FEATURE2_ARCH_3_00 0x00800000
+#endif
+
+#ifndef PPC_FEATURE2_ARCH_3_1
+#define PPC_FEATURE2_ARCH_3_1 0x00040000
+#endif
+
+using CapBits = unsigned long; // NOLINT
+
+// kGroupVSX is checked against AT_HWCAP; the PPC8/PPC9/PPC10 groups below
+// are checked against AT_HWCAP2.
+static constexpr CapBits kGroupVSX =
+ PPC_FEATURE_HAS_ALTIVEC | PPC_FEATURE_HAS_VSX;
+
+#if defined(HWY_DISABLE_PPC8_CRYPTO)
+static constexpr CapBits kGroupPPC8 = PPC_FEATURE2_ARCH_2_07;
+#else
+static constexpr CapBits kGroupPPC8 =
+ PPC_FEATURE2_ARCH_2_07 | PPC_FEATURE2_VEC_CRYPTO;
+#endif
+static constexpr CapBits kGroupPPC9 = kGroupPPC8 | PPC_FEATURE2_ARCH_3_00;
+static constexpr CapBits kGroupPPC10 = kGroupPPC9 | PPC_FEATURE2_ARCH_3_1;
+
+static int64_t DetectTargets() {
+ int64_t bits = 0; // return value of supported targets.
+
+#if defined(AT_HWCAP) && defined(AT_HWCAP2)
+ const CapBits hw = getauxval(AT_HWCAP);
+
+ if ((hw & kGroupVSX) == kGroupVSX) {
+ const CapBits hw2 = getauxval(AT_HWCAP2);
+ if ((hw2 & kGroupPPC8) == kGroupPPC8) {
+ bits |= HWY_PPC8;
+ }
+ if ((hw2 & kGroupPPC9) == kGroupPPC9) {
+ bits |= HWY_PPC9;
+ }
+ if ((hw2 & kGroupPPC10) == kGroupPPC10) {
+ bits |= HWY_PPC10;
+ }
+ } // VSX
+#endif // defined(AT_HWCAP) && defined(AT_HWCAP2)
+
+ return bits;
+}
+} // namespace ppc
+#elif HWY_ARCH_S390X && HWY_HAVE_RUNTIME_DISPATCH
+namespace s390x {
+
+#ifndef HWCAP_S390_VX
+#define HWCAP_S390_VX 2048
+#endif
+
+#ifndef HWCAP_S390_VXE
+#define HWCAP_S390_VXE 8192
+#endif
+
+#ifndef HWCAP_S390_VXRS_EXT2
+#define HWCAP_S390_VXRS_EXT2 32768
+#endif
+
+using CapBits = unsigned long; // NOLINT
+
+static constexpr CapBits kGroupZ14 = HWCAP_S390_VX | HWCAP_S390_VXE;
+static constexpr CapBits kGroupZ15 =
+ HWCAP_S390_VX | HWCAP_S390_VXE | HWCAP_S390_VXRS_EXT2;
+
+static int64_t DetectTargets() {
+ int64_t bits = 0;
+
+#if defined(AT_HWCAP)
+ const CapBits hw = getauxval(AT_HWCAP);
+
+ if ((hw & kGroupZ14) == kGroupZ14) {
+ bits |= HWY_Z14;
+ }
+
+ if ((hw & kGroupZ15) == kGroupZ15) {
+ bits |= HWY_Z15;
+ }
+#endif
+
+ return bits;
+}
+} // namespace s390x
+#elif HWY_ARCH_RISCV && HWY_HAVE_RUNTIME_DISPATCH
+namespace rvv {
+
+// Not all hwcap.h headers define the V-extension bit; define it if missing.
+#ifndef COMPAT_HWCAP_ISA_V
+#define COMPAT_HWCAP_ISA_V (1 << ('V' - 'A'))
+#endif
+
+using CapBits = unsigned long; // NOLINT
+
+static int64_t DetectTargets() {
+ int64_t bits = 0;
+
+ const CapBits hw = getauxval(AT_HWCAP);
+
+ if ((hw & COMPAT_HWCAP_ISA_V) == COMPAT_HWCAP_ISA_V) {
+ size_t e8m1_vec_len;
+#if HWY_ARCH_RISCV_64
+ int64_t vtype_reg_val;
+#else
+ int32_t vtype_reg_val;
+#endif
+
+    // Check that a vuint8m1_t vector is at least 16 bytes and that the
+    // tail-agnostic and mask-agnostic modes are supported
+ asm volatile(
+ // Avoid compiler error on GCC or Clang if -march=rv64gcv1p0 or
+ // -march=rv32gcv1p0 option is not specified on the command line
+ ".option push\n\t"
+ ".option arch, +v\n\t"
+ "vsetvli %0, zero, e8, m1, ta, ma\n\t"
+ "csrr %1, vtype\n\t"
+ ".option pop"
+ : "=r"(e8m1_vec_len), "=r"(vtype_reg_val));
+
+    // The RVV target is supported if the VILL bit of VTYPE (its MSB) is not
+    // set and a vuint8m1_t vector is at least 16 bytes long
+ if (vtype_reg_val >= 0 && e8m1_vec_len >= 16) {
+ bits |= HWY_RVV;
+ }
+ }
+
+ return bits;
+}
+} // namespace rvv
+#elif HWY_ARCH_LOONGARCH && HWY_HAVE_RUNTIME_DISPATCH
+
+namespace loongarch {
+
+#ifndef LA_HWCAP_LSX
+#define LA_HWCAP_LSX (1u << 4)
+#endif
+#ifndef LA_HWCAP_LASX
+#define LA_HWCAP_LASX (1u << 5)
+#endif
+
+using CapBits = unsigned long; // NOLINT
+
+static int64_t DetectTargets() {
+ int64_t bits = 0;
+ const CapBits hw = getauxval(AT_HWCAP);
+ if (hw & LA_HWCAP_LSX) bits |= HWY_LSX;
+ if (hw & LA_HWCAP_LASX) bits |= HWY_LASX;
+ return bits;
+}
+} // namespace loongarch
+#endif // HWY_ARCH_*
+
+// Returns targets supported by the CPU, independently of DisableTargets.
+// Factored out of SupportedTargets to make its structure more obvious. Note
+// that x86 CPUID may take several hundred cycles.
+static int64_t DetectTargets() {
+ // Apps will use only one of these (the default is EMU128), but compile flags
+ // for this TU may differ from that of the app, so allow both.
+ int64_t bits = HWY_SCALAR | HWY_EMU128;
+
+#if HWY_ARCH_X86 && HWY_HAVE_RUNTIME_DISPATCH
+ bits |= x86::DetectTargets();
+#elif HWY_ARCH_ARM && HWY_HAVE_RUNTIME_DISPATCH
+ bits |= arm::DetectTargets();
+#elif HWY_ARCH_PPC && HWY_HAVE_RUNTIME_DISPATCH
+ bits |= ppc::DetectTargets();
+#elif HWY_ARCH_S390X && HWY_HAVE_RUNTIME_DISPATCH
+ bits |= s390x::DetectTargets();
+#elif HWY_ARCH_RISCV && HWY_HAVE_RUNTIME_DISPATCH
+ bits |= rvv::DetectTargets();
+#elif HWY_ARCH_LOONGARCH && HWY_HAVE_RUNTIME_DISPATCH
+ bits |= loongarch::DetectTargets();
+
+#else
+ // TODO(janwas): detect support for WASM.
+ // This file is typically compiled without HWY_IS_TEST, but targets_test has
+ // it set, and will expect all of its HWY_TARGETS (= all attainable) to be
+ // supported.
+ bits |= HWY_ENABLED_BASELINE;
+#endif // HWY_ARCH_*
+
+ if ((bits & HWY_ENABLED_BASELINE) != HWY_ENABLED_BASELINE) {
+ const uint64_t bits_u = static_cast<uint64_t>(bits);
+ const uint64_t enabled = static_cast<uint64_t>(HWY_ENABLED_BASELINE);
+ HWY_WARN("CPU supports 0x%08x%08x, software requires 0x%08x%08x\n",
+ static_cast<uint32_t>(bits_u >> 32),
+ static_cast<uint32_t>(bits_u & 0xFFFFFFFF),
+ static_cast<uint32_t>(enabled >> 32),
+ static_cast<uint32_t>(enabled & 0xFFFFFFFF));
+ }
+
+ return bits;
+}
+
+// When running tests, this value can be set to the mocked supported targets
+// mask. Only written to from a single thread before the test starts.
+static int64_t supported_targets_for_test_ = 0;
+
+// Mask of targets disabled at runtime with DisableTargets.
+static int64_t supported_mask_ = LimitsMax<int64_t>();
+
+HWY_DLLEXPORT void DisableTargets(int64_t disabled_targets) {
+ supported_mask_ = static_cast<int64_t>(~disabled_targets);
+ // This will take effect on the next call to SupportedTargets, which is
+ // called right before GetChosenTarget::Update. However, calling Update here
+ // would make it appear that HWY_DYNAMIC_DISPATCH was called, which we want
+ // to check in tests. We instead de-initialize such that the next
+ // HWY_DYNAMIC_DISPATCH calls GetChosenTarget::Update via FunctionCache.
+ GetChosenTarget().DeInit();
+}
+
+HWY_DLLEXPORT void SetSupportedTargetsForTest(int64_t targets) {
+ supported_targets_for_test_ = targets;
+ GetChosenTarget().DeInit(); // see comment above
+}
+
+HWY_DLLEXPORT int64_t SupportedTargets() {
+ int64_t targets = supported_targets_for_test_;
+ if (HWY_LIKELY(targets == 0)) {
+ // Mock not active. Re-detect instead of caching just in case we're on a
+ // heterogeneous ISA (also requires some app support to pin threads). This
+ // is only reached on the first HWY_DYNAMIC_DISPATCH or after each call to
+ // DisableTargets or SetSupportedTargetsForTest.
+ targets = DetectTargets();
+
+ // VectorBytes invokes HWY_DYNAMIC_DISPATCH. To prevent infinite recursion,
+ // first set up ChosenTarget. No need to Update() again afterwards with the
+ // final targets - that will be done by a caller of this function.
+ GetChosenTarget().Update(targets);
+ }
+
+ targets &= supported_mask_;
+ return targets == 0 ? HWY_STATIC_TARGET : targets;
+}
+
+HWY_DLLEXPORT ChosenTarget& GetChosenTarget() {
+ static ChosenTarget chosen_target;
+ return chosen_target;
+}
+
+} // namespace hwy
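Note: a hedged sketch of how an application might use these entry points; HWY_AVX3/HWY_AVX3_DL are illustrative choices, and any target bits from detect_targets.h work the same way.

#include <stdint.h>
#include <stdio.h>

#include "third_party/highway/hwy/targets.h"

int main() {
  // Walk the supported-target bits from lowest to highest and name each.
  for (int64_t targets = hwy::SupportedTargets(); targets != 0;
       targets &= targets - 1) {  // clear lowest set bit
    const int64_t target = targets & -targets;  // isolate lowest set bit
    printf("supported: %s\n", hwy::TargetName(target));
  }
  // Disable some AVX-512 code paths; takes effect on the next dynamic
  // dispatch, as described in the comment inside DisableTargets.
  hwy::DisableTargets(HWY_AVX3 | HWY_AVX3_DL);
  return 0;
}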
diff --git a/third_party/highway/hwy/targets.h b/third_party/highway/hwy/targets.h
index 6f34c890fe..a4f8a18faa 100644
--- a/third_party/highway/hwy/targets.h
+++ b/third_party/highway/hwy/targets.h
@@ -82,9 +82,17 @@ HWY_INLINE std::vector<int64_t> SupportedAndGeneratedTargets() {
#endif // HWY_NO_LIBCXX
+// Returns a string that satisfies gtest IsValidParamName(). No longer report
+// targets as "Unknown" if they are for a different architecture, because some
+// users unconditionally disable targets and we want to see which.
static inline HWY_MAYBE_UNUSED const char* TargetName(int64_t target) {
switch (target) {
-#if HWY_ARCH_X86
+ case HWY_EMU128:
+ return "EMU128";
+ case HWY_SCALAR:
+ return "SCALAR";
+
+ // X86
case HWY_SSE2:
return "SSE2";
case HWY_SSSE3:
@@ -99,15 +107,12 @@ static inline HWY_MAYBE_UNUSED const char* TargetName(int64_t target) {
return "AVX3_DL";
case HWY_AVX3_ZEN4:
return "AVX3_ZEN4";
- case HWY_AVX10_2:
- return "AVX10_2";
case HWY_AVX3_SPR:
return "AVX3_SPR";
- case HWY_AVX10_2_512:
- return "AVX10_2_512";
-#endif
+ case HWY_AVX10_2:
+ return "AVX10_2";
-#if HWY_ARCH_ARM
+ // ARM
case HWY_SVE2_128:
return "SVE2_128";
case HWY_SVE_256:
@@ -122,53 +127,71 @@ static inline HWY_MAYBE_UNUSED const char* TargetName(int64_t target) {
return "NEON";
case HWY_NEON_WITHOUT_AES:
return "NEON_WITHOUT_AES";
-#endif
-#if HWY_ARCH_PPC
+ // PPC
case HWY_PPC8:
return "PPC8";
case HWY_PPC9:
return "PPC9";
case HWY_PPC10:
return "PPC10";
-#endif
-#if HWY_ARCH_S390X
+ // S390X
case HWY_Z14:
return "Z14";
case HWY_Z15:
return "Z15";
-#endif
-#if HWY_ARCH_WASM
+ // WASM
case HWY_WASM:
return "WASM";
case HWY_WASM_EMU256:
return "WASM_EMU256";
-#endif
-#if HWY_ARCH_RISCV
+ // RISCV
case HWY_RVV:
return "RVV";
-#endif
-#if HWY_ARCH_LOONGARCH
+ // LOONGARCH
case HWY_LSX:
return "LSX";
case HWY_LASX:
return "LASX";
-#endif
-
- case HWY_EMU128:
- return "EMU128";
- case HWY_SCALAR:
- return "SCALAR";
-
- default:
- return "Unknown"; // must satisfy gtest IsValidParamName()
}
+
+ return "Unknown";
}
+// Invokes VISITOR(TARGET, NAMESPACE) for all enabled targets. Alphabetic order.
+#define HWY_VISIT_TARGETS(VISITOR) \
+ HWY_VISIT_AVX10_2(VISITOR) \
+ HWY_VISIT_AVX2(VISITOR) \
+ HWY_VISIT_AVX3(VISITOR) \
+ HWY_VISIT_AVX3_DL(VISITOR) \
+ HWY_VISIT_AVX3_SPR(VISITOR) \
+ HWY_VISIT_AVX3_ZEN4(VISITOR) \
+ HWY_VISIT_FALLBACK(VISITOR) \
+ HWY_VISIT_LASX(VISITOR) \
+ HWY_VISIT_LSX(VISITOR) \
+ HWY_VISIT_NEON(VISITOR) \
+ HWY_VISIT_NEON_BF16(VISITOR) \
+ HWY_VISIT_NEON_WITHOUT_AES(VISITOR) \
+ HWY_VISIT_PPC10(VISITOR) \
+ HWY_VISIT_PPC8(VISITOR) \
+ HWY_VISIT_PPC9(VISITOR) \
+ HWY_VISIT_RVV(VISITOR) \
+ HWY_VISIT_SSE2(VISITOR) \
+ HWY_VISIT_SSE4(VISITOR) \
+ HWY_VISIT_SSSE3(VISITOR) \
+ HWY_VISIT_SVE(VISITOR) \
+ HWY_VISIT_SVE2(VISITOR) \
+ HWY_VISIT_SVE2_128(VISITOR) \
+ HWY_VISIT_SVE_256(VISITOR) \
+ HWY_VISIT_WASM(VISITOR) \
+ HWY_VISIT_WASM_EMU256(VISITOR) \
+ HWY_VISIT_Z14(VISITOR) \
+ HWY_VISIT_Z15(VISITOR)
+
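Note: per the comment above, each per-target HWY_VISIT_* macro is expected to expand to VISITOR(TARGET, NAMESPACE) when that target is enabled, and to nothing otherwise. A hedged sketch of a visitor (the macro and function names are illustrative):

#include <stdio.h>

#include "third_party/highway/hwy/targets.h"

#define PRINT_COMPILED_TARGET(TARGET, NAMESPACE) \
  printf("compiled in: %s\n", hwy::TargetName(TARGET));

void ListCompiledTargets() { HWY_VISIT_TARGETS(PRINT_COMPILED_TARGET) }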
// The maximum number of dynamic targets on any architecture is defined by
// HWY_MAX_DYNAMIC_TARGETS and depends on the arch.
@@ -212,22 +235,22 @@ static inline HWY_MAYBE_UNUSED const char* TargetName(int64_t target) {
// HWY_MAX_DYNAMIC_TARGETS) bit. This list must contain exactly
// HWY_MAX_DYNAMIC_TARGETS elements and does not include SCALAR. The first entry
// corresponds to the best target. Don't include a "," at the end of the list.
-#define HWY_CHOOSE_TARGET_LIST(func_name) \
- nullptr, /* reserved */ \
- nullptr, /* reserved */ \
- nullptr, /* reserved */ \
- HWY_CHOOSE_AVX10_2_512(func_name), /* AVX10_2_512 */ \
- HWY_CHOOSE_AVX3_SPR(func_name), /* AVX3_SPR */ \
- HWY_CHOOSE_AVX10_2(func_name), /* reserved */ \
- HWY_CHOOSE_AVX3_ZEN4(func_name), /* AVX3_ZEN4 */ \
- HWY_CHOOSE_AVX3_DL(func_name), /* AVX3_DL */ \
- HWY_CHOOSE_AVX3(func_name), /* AVX3 */ \
- HWY_CHOOSE_AVX2(func_name), /* AVX2 */ \
- nullptr, /* AVX */ \
- HWY_CHOOSE_SSE4(func_name), /* SSE4 */ \
- HWY_CHOOSE_SSSE3(func_name), /* SSSE3 */ \
- nullptr, /* reserved - SSE3? */ \
- HWY_CHOOSE_SSE2(func_name) /* SSE2 */
+#define HWY_CHOOSE_TARGET_LIST(func_name) \
+ nullptr, /* reserved */ \
+ nullptr, /* reserved */ \
+ nullptr, /* reserved */ \
+ HWY_CHOOSE_AVX10_2(func_name), /* AVX10_2 */ \
+ HWY_CHOOSE_AVX3_SPR(func_name), /* AVX3_SPR */ \
+ nullptr, /* reserved */ \
+ HWY_CHOOSE_AVX3_ZEN4(func_name), /* AVX3_ZEN4 */ \
+ HWY_CHOOSE_AVX3_DL(func_name), /* AVX3_DL */ \
+ HWY_CHOOSE_AVX3(func_name), /* AVX3 */ \
+ HWY_CHOOSE_AVX2(func_name), /* AVX2 */ \
+ nullptr, /* AVX */ \
+ HWY_CHOOSE_SSE4(func_name), /* SSE4 */ \
+ HWY_CHOOSE_SSSE3(func_name), /* SSSE3 */ \
+ nullptr, /* reserved - SSE3? */ \
+ HWY_CHOOSE_SSE2(func_name) /* SSE2 */
#elif HWY_ARCH_ARM
// See HWY_ARCH_X86 above for details.
diff --git a/third_party/highway/hwy/timer.cc b/third_party/highway/hwy/timer.cc
new file mode 100644
index 0000000000..5fb613de30
--- /dev/null
+++ b/third_party/highway/hwy/timer.cc
@@ -0,0 +1,192 @@
+// Copyright 2019 Google LLC
+// SPDX-License-Identifier: Apache-2.0
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "third_party/highway/hwy/timer.h"
+
+#include <stdlib.h>
+
+#include <chrono> // NOLINT
+#include <ratio> // NOLINT
+
+#include "third_party/highway/hwy/base.h"
+#include "third_party/highway/hwy/robust_statistics.h"
+#include "third_party/highway/hwy/x86_cpuid.h"
+
+namespace hwy {
+
+#if HWY_ARCH_X86
+namespace x86 {
+
+static bool HasRDTSCP() {
+ uint32_t abcd[4];
+ Cpuid(0x80000001U, 0, abcd); // Extended feature flags
+ if ((abcd[3] & (1u << 27)) == 0) return false; // RDTSCP
+
+ Cpuid(0x80000007U, 0, abcd);
+ if ((abcd[3] & (1u << 8)) == 0) {
+ HWY_WARN("TSC not constant/invariant, may vary frequency or jump.");
+ }
+ return true;
+}
+
+} // namespace x86
+#endif // HWY_ARCH_X86
+
+// Measures the actual current frequency of Ticks. We cannot rely on the nominal
+// frequency encoded in x86 GetCpuString because it is misleading on M1 Rosetta,
+// and not reported by AMD. CPUID 0x15 is also not yet widely supported. Also
+// used on RISC-V and aarch64.
+static HWY_MAYBE_UNUSED double MeasureNominalClockRate() {
+ double max_ticks_per_sec = 0.0;
+ // Arbitrary, enough to ignore 2 outliers without excessive init time.
+ for (int rep = 0; rep < 3; ++rep) {
+ auto time0 = std::chrono::steady_clock::now();
+ using Time = decltype(time0);
+ const timer::Ticks ticks0 = timer::Start();
+ const Time time_min = time0 + std::chrono::milliseconds(10);
+
+ Time time1;
+ timer::Ticks ticks1;
+ for (;;) {
+ time1 = std::chrono::steady_clock::now();
+ // Ideally this would be Stop, but that requires RDTSCP on x86. To avoid
+ // another codepath, just use Start instead. now() presumably has its own
+ // fence-like behavior.
+ ticks1 = timer::Start(); // Do not use Stop, see comment above
+ if (time1 >= time_min) break;
+ }
+
+ const double dticks = static_cast<double>(ticks1 - ticks0);
+ std::chrono::duration<double, std::ratio<1>> dtime = time1 - time0;
+ const double ticks_per_sec = dticks / dtime.count();
+ max_ticks_per_sec = HWY_MAX(max_ticks_per_sec, ticks_per_sec);
+ }
+ return max_ticks_per_sec;
+}
+
+#if HWY_ARCH_PPC && defined(__GLIBC__) && defined(__powerpc64__)
+namespace ppc {
+
+static HWY_INLINE double GetTimebaseFreq() {
+ const auto timebase_freq = __ppc_get_timebase_freq();
+  // Return timebase_freq if it is greater than 0; otherwise fall back to
+  // MeasureNominalClockRate(). This works around __ppc_get_timebase_freq()
+  // reporting a non-positive value when running under QEMU on non-PPC CPUs.
+ return (timebase_freq > 0) ? static_cast<double>(timebase_freq)
+ : MeasureNominalClockRate();
+}
+
+} // namespace ppc
+#endif
+
+namespace platform {
+
+HWY_DLLEXPORT bool GetCpuString(char* cpu100) {
+#if HWY_ARCH_X86
+ uint32_t abcd[4];
+
+ // Check if brand string is supported (it is on all reasonable Intel/AMD)
+ x86::Cpuid(0x80000000U, 0, abcd);
+ if (abcd[0] < 0x80000004U) {
+ cpu100[0] = '\0';
+ return false;
+ }
+
+ for (size_t i = 0; i < 3; ++i) {
+ x86::Cpuid(static_cast<uint32_t>(0x80000002U + i), 0, abcd);
+ CopyBytes<sizeof(abcd)>(&abcd[0], cpu100 + i * 16); // not same size
+ }
+ cpu100[48] = '\0';
+ return true;
+#else
+ cpu100[0] = '?';
+ cpu100[1] = '\0';
+ return false;
+#endif
+}
+
+HWY_DLLEXPORT double Now() {
+ static const double mul = 1.0 / InvariantTicksPerSecond();
+ return static_cast<double>(timer::Start()) * mul;
+}
+
+HWY_DLLEXPORT bool HaveTimerStop(char* cpu100) {
+#if HWY_ARCH_X86
+ if (!x86::HasRDTSCP()) {
+ (void)GetCpuString(cpu100);
+ return false;
+ }
+#endif
+ *cpu100 = '\0';
+ return true;
+}
+
+HWY_DLLEXPORT double InvariantTicksPerSecond() {
+#if HWY_ARCH_PPC && defined(__GLIBC__) && defined(__powerpc64__)
+ static const double freq = ppc::GetTimebaseFreq();
+ return freq;
+#elif HWY_ARCH_X86 || HWY_ARCH_RISCV || (HWY_ARCH_ARM_A64 && !HWY_COMPILER_MSVC)
+ // We assume the x86 TSC is invariant; it is on all recent Intel/AMD CPUs.
+ static const double freq = MeasureNominalClockRate();
+ return freq;
+#elif defined(_WIN32) || defined(_WIN64)
+ LARGE_INTEGER freq;
+ (void)QueryPerformanceFrequency(&freq);
+ return static_cast<double>(freq.QuadPart);
+#elif defined(__APPLE__)
+ // https://developer.apple.com/library/mac/qa/qa1398/_index.html
+ mach_timebase_info_data_t timebase;
+ (void)mach_timebase_info(&timebase);
+ return static_cast<double>(timebase.denom) / timebase.numer * 1E9;
+#else
+ return 1E9; // Haiku and clock_gettime return nanoseconds.
+#endif
+}
+
+HWY_DLLEXPORT uint64_t TimerResolution() {
+ char cpu100[100];
+ bool can_use_stop = HaveTimerStop(cpu100);
+
+ // For measuring timer overhead/resolution. Used in a nested loop =>
+ // quadratic time, acceptable because we know timer overhead is "low".
+ // constexpr because this is used to define array bounds.
+ constexpr size_t kTimerSamples = 256;
+
+ // Nested loop avoids exceeding stack/L1 capacity.
+ timer::Ticks repetitions[kTimerSamples];
+ for (size_t rep = 0; rep < kTimerSamples; ++rep) {
+ timer::Ticks samples[kTimerSamples];
+ if (can_use_stop) {
+ for (size_t i = 0; i < kTimerSamples; ++i) {
+ const timer::Ticks t0 = timer::Start();
+ const timer::Ticks t1 = timer::Stop(); // we checked HasRDTSCP above
+ samples[i] = t1 - t0;
+ }
+ } else {
+ for (size_t i = 0; i < kTimerSamples; ++i) {
+ const timer::Ticks t0 = timer::Start();
+ const timer::Ticks t1 = timer::Start(); // do not use Stop, see above
+ samples[i] = t1 - t0;
+ }
+ }
+ repetitions[rep] = robust_statistics::Mode(samples);
+ }
+ return robust_statistics::Mode(repetitions);
+}
+
+} // namespace platform
+} // namespace hwy
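Note: a hedged usage sketch of the platform layer defined above; SomeWork is a placeholder for the code under measurement.

#include <stdint.h>
#include <stdio.h>

#include "third_party/highway/hwy/timer.h"

void SomeWork();  // placeholder

void Measure() {
  const uint64_t resolution = hwy::platform::TimerResolution();
  const double t0 = hwy::platform::Now();  // seconds since an arbitrary origin
  SomeWork();
  const double elapsed = hwy::platform::Now() - t0;
  printf("%.3f us (timer resolution: %llu ticks)\n", elapsed * 1E6,
         static_cast<unsigned long long>(resolution));
}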
diff --git a/third_party/highway/hwy/timer.h b/third_party/highway/hwy/timer.h
index 6d819c55bb..7063d5adb8 100644
--- a/third_party/highway/hwy/timer.h
+++ b/third_party/highway/hwy/timer.h
@@ -232,6 +232,50 @@ static HWY_INLINE Ticks Stop() {
} // namespace timer
+// Wrapper around Start/Stop that checks whether the CPU supports Stop.
+class Timer {
+ public:
+ Timer() {
+ char cpu100[100];
+ have_timer_stop_ = platform::HaveTimerStop(cpu100);
+ }
+
+ // Before/After have fences to prevent the measured code 'leaking out'.
+ timer::Ticks Before() const { return timer::Start(); }
+ timer::Ticks After() const {
+ return have_timer_stop_ ? timer::Stop() : timer::Start();
+ }
+
+ private:
+ bool have_timer_stop_;
+};
+
+static inline double Seconds(timer::Ticks ticks) {
+ return static_cast<double>(ticks) / platform::InvariantTicksPerSecond();
+}
+
+// Measures elapsed time since construction, with automatic reset.
+class Stopwatch {
+ public:
+ explicit Stopwatch(const Timer& timestamps) : timer_(timestamps) { Reset(); }
+
+ timer::Ticks Origin() const { return t0_; }
+ void Reset() { t0_ = timer_.Before(); }
+
+ // Also resets the start time to the current time to enable reuse without a
+ // second call to the timer.
+ timer::Ticks Elapsed() {
+ const timer::Ticks t1 = timer_.After();
+ const timer::Ticks elapsed = t1 - t0_;
+ t0_ = t1;
+ return elapsed;
+ }
+
+ private:
+ const Timer& timer_;
+ timer::Ticks t0_;
+};
+
} // namespace hwy
#endif // HIGHWAY_HWY_TIMER_H_
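Note: a hedged sketch of the new Timer/Stopwatch wrappers; because Elapsed() resets the origin to the current time, consecutive calls measure back-to-back intervals. StepOne/StepTwo are placeholders.

#include <stdio.h>

#include "third_party/highway/hwy/timer.h"

void StepOne();  // placeholder
void StepTwo();  // placeholder

void Profile() {
  hwy::Timer timer;  // probes once whether timer::Stop is usable
  hwy::Stopwatch watch(timer);  // the Timer must outlive the Stopwatch
  StepOne();
  const double s1 = hwy::Seconds(watch.Elapsed());
  StepTwo();
  const double s2 = hwy::Seconds(watch.Elapsed());  // starts where s1 ended
  printf("step1=%fs step2=%fs\n", s1, s2);
}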
diff --git a/third_party/highway/hwy/x86_cpuid.h b/third_party/highway/hwy/x86_cpuid.h
index 2fcdb3c654..9aeca9ac30 100644
--- a/third_party/highway/hwy/x86_cpuid.h
+++ b/third_party/highway/hwy/x86_cpuid.h
@@ -39,9 +39,9 @@ static inline void Cpuid(const uint32_t level, const uint32_t count,
uint32_t* HWY_RESTRICT abcd) {
#if HWY_COMPILER_MSVC || HWY_COMPILER_CLANGCL
int regs[4];
- __cpuidex(regs, level, count);
+ __cpuidex(regs, static_cast<int>(level), static_cast<int>(count));
for (int i = 0; i < 4; ++i) {
- abcd[i] = regs[i];
+ abcd[i] = static_cast<uint32_t>(regs[i]);
}
#else // HWY_COMPILER_MSVC || HWY_COMPILER_CLANGCL
uint32_t a;