diff --git a/Source/Core/Common/Flag.h b/Source/Core/Common/Flag.h index 9b8fa3e28e..89b61c7805 100644 --- a/Source/Core/Common/Flag.h +++ b/Source/Core/Common/Flag.h @@ -3,10 +3,17 @@ // Refer to the license.txt file included. // Abstraction for a simple flag that can be toggled in a multithreaded way. -// It exposes a very simple API: +// +// Simple API: // * Set(bool = true): sets the Flag // * IsSet(): tests if the flag is set // * Clear(): clears the flag (equivalent to Set(false)). +// +// More advanced features: +// * TestAndSet(bool = true): sets the flag to the given value. If a change was +// needed (the flag did not already have this value) +// the function returns true. Else, false. +// * TestAndClear(): alias for TestAndSet(false). #pragma once @@ -37,6 +44,17 @@ public: return m_val.load(); } + bool TestAndSet(bool val = true) + { + bool expected = !val; + return m_val.compare_exchange_strong(expected, val); + } + + bool TestAndClear() + { + return TestAndSet(false); + } + private: // We are not using std::atomic_bool here because MSVC sucks as of VC++ // 2013 and does not implement the std::atomic_bool(bool) constructor. diff --git a/Source/UnitTests/Common/FlagTest.cpp b/Source/UnitTests/Common/FlagTest.cpp index ef73d27728..a3a0c3d929 100644 --- a/Source/UnitTests/Common/FlagTest.cpp +++ b/Source/UnitTests/Common/FlagTest.cpp @@ -2,6 +2,7 @@ // Licensed under GPLv2 // Refer to the license.txt file included. +#include #include #include @@ -23,6 +24,9 @@ TEST(Flag, Simple) f.Set(false); EXPECT_FALSE(f.IsSet()); + EXPECT_TRUE(f.TestAndSet()); + EXPECT_TRUE(f.TestAndClear()); + Flag f2(true); EXPECT_TRUE(f2.IsSet()); } @@ -58,3 +62,30 @@ TEST(Flag, MultiThreaded) EXPECT_EQ(ITERATIONS_COUNT, count); } + +TEST(Flag, SpinLock) +{ + // Uses a flag to implement basic spinlocking using TestAndSet. + Flag f; + int count = 0; + const int ITERATIONS_COUNT = 5000; + const int THREADS_COUNT = 50; + + auto adder_func = [&]() { + for (int i = 0; i < ITERATIONS_COUNT; ++i) + { + // Acquire the spinlock. + while (!f.TestAndSet()); + count++; + f.Clear(); + } + }; + + std::array threads; + for (auto& th : threads) + th = std::thread(adder_func); + for (auto& th : threads) + th.join(); + + EXPECT_EQ(ITERATIONS_COUNT * THREADS_COUNT, count); +}