diff --git a/bindings/py/cpp_src/bindings/math/py_Random.cpp b/bindings/py/cpp_src/bindings/math/py_Random.cpp index 06eb033ac4..7fb01b0470 100644 --- a/bindings/py/cpp_src/bindings/math/py_Random.cpp +++ b/bindings/py/cpp_src/bindings/math/py_Random.cpp @@ -42,14 +42,12 @@ namespace htm_ext { py::class_ Random(m, "Random"); Random.def(py::init(), py::arg("seed") = 0) - .def("getUInt32", &Random_t::getUInt32, py::arg("max") = (htm::UInt32)-1l) - .def("getReal64", &Random_t::getReal64) - .def("getSeed", &Random_t::getSeed) - .def("max", &Random_t::max) - .def("min", &Random_t::min) - .def("__eq__", [](Random_t const & self, Random_t const & other) {//wrapping operator== - return self == other; - }, py::is_operator()); + .def("getUInt32", &Random_t::getUInt32, py::arg("max") = (htm::UInt32)-1l) + .def("getReal64", &Random_t::getReal64) + .def("getSeed", &Random_t::getSeed) + .def("max", &Random_t::max) + .def("min", &Random_t::min) + .def("__eq__", [](Random_t const & self, Random_t const & other) { return self == other; }, py::is_operator()); //operator== Random.def_property_readonly_static("MAX32", [](py::object) { return Random_t::MAX32; @@ -144,25 +142,37 @@ namespace htm_ext { }, "load from a File, using BINARY, PORTABLE, JSON, or XML format.", py::arg("name"), py::arg("fmt") = 0); + Random.def(py::pickle( - [](const Random_t& r) + [](const Random_t& r) //__getstate__ { std::stringstream ss; - ss << r; - return ss.str(); - }, - [](const std::string& str) - { - if (str.empty()) - { - throw std::runtime_error("Empty state"); - } + r.save(ss); //save r's state to archive (stream) with cereal - std::stringstream ss(str); - Random_t r; - ss >> r; + /* The values in stringstream are binary so pickle will get confused + * trying to treat it as utf8 if you just return ss.str(). + * So we must treat it as py::bytes. Some characters could be null values. + */ + return py::bytes( ss.str() ); + }, - return r; + [](py::bytes &s) // __setstate__ + { + /* pybind11 will pass in the bytes array without conversion. + * so we should be able to just create a string to initalize the stringstream. + */ + std::stringstream ss( s.cast() ); + std::unique_ptr r(new htm::Random()); + r->load(ss); + + /* + * The __setstate__ part of the py::pickle() is actually a py::init() with some options. + * So the return value can be the object returned by value, by pointer, + * or by container (meaning a unique_ptr). SP has a problem with the copy constructor + * and pointers have problems knowing who the owner is so lets use unique_ptr. + * See: https://pybind11.readthedocs.io/en/stable/advanced/classes.html#custom-constructors + */ + return r; } )); diff --git a/src/htm/utils/Random.cpp b/src/htm/utils/Random.cpp index 83d3933c13..e0bb45d0c2 100644 --- a/src/htm/utils/Random.cpp +++ b/src/htm/utils/Random.cpp @@ -18,9 +18,6 @@ /** @file Random Number Generator implementation */ -#include // for istream, ostream -#include // for random seeds - #include #include @@ -32,64 +29,23 @@ bool Random::operator==(const Random &o) const { gen == o.gen; } -bool static_gen_seeded = false; //used only for seeding seed if 0/auto is passed for seed -std::mt19937 static_gen; +std::random_device rd; //HW RNG, undeterministic, platform dependant. Use only for seeding rng if random seed wanted (seed=0) -Random::Random(UInt64 seed) { +Random::Random(const UInt64 seed) { if (seed == 0) { - if( !static_gen_seeded ) { - #if NDEBUG - unsigned int static_seed = (unsigned int)std::chrono::system_clock::now().time_since_epoch().count(); - #else - unsigned int static_seed = DEBUG_RANDOM_SEED; - #endif - static_gen.seed( static_seed ); - static_gen_seeded = true; - NTA_INFO << "Random seed: " << static_seed; - } + std::mt19937 static_gen(rd()); seed_ = static_gen(); //generate random value from HW RNG } else { seed_ = seed; } - // if seed is zero at this point, there is a logic error. - NTA_CHECK(seed_ != 0); + NTA_CHECK(seed_ != 0) << "Random: if seed is zero at this point, there is a logic error"; gen.seed(static_cast(seed_)); //seed the generator steps_ = 0; } - namespace htm { -std::ostream &operator<<(std::ostream &outStream, const Random &r) { - outStream << "random-v2" << " "; - outStream << r.seed_ << " "; - outStream << r.steps_ << " "; - outStream << "endrandom-v2" << " "; - return outStream; -} - - -std::istream &operator>>(std::istream &inStream, Random &r) { - std::string version; - - inStream >> version; - NTA_CHECK(version == "random-v2") << "Random() deserializer -- found unexpected version string '" - << version << "'"; - inStream >> r.seed_; - r.gen.seed(static_cast(r.seed_)); //reseed - inStream >> r.steps_; - r.gen.discard(r.steps_); //advance n steps - //FIXME we could de/serialize directly RNG gen, it should be multi-platform according to standard, - //but on OSX CI it wasn't (25/11/2018). So "hacking" the above instead. - std::string endtag; - inStream >> endtag; - NTA_CHECK(endtag == "endrandom-v2") << "Random() deserializer -- found unexpected end tag '" << endtag << "'"; - inStream.ignore(1); - - return inStream; -} - // helper function for seeding RNGs across the plugin barrier -UInt32 GetRandomSeed() { - return htm::Random().getUInt32(); +UInt32 GetRandomSeed(const UInt seed) { + return htm::Random(seed).getUInt32(); } } // namespace htm diff --git a/src/htm/utils/Random.hpp b/src/htm/utils/Random.hpp index a5c0bb8fa9..9cc7d57f65 100644 --- a/src/htm/utils/Random.hpp +++ b/src/htm/utils/Random.hpp @@ -71,18 +71,23 @@ namespace htm { */ class Random : public Serializable { public: - Random(UInt64 seed = 0); + Random(const UInt64 seed = 0); + // Serialization CerealAdapter; template void save_ar(Archive & ar) const { - ar( cereal::make_nvp("seed", seed_), cereal::make_nvp("steps", steps_)); + ar( CEREAL_NVP(seed_), + CEREAL_NVP(steps_) + ); } template void load_ar(Archive & ar) { - ar( seed_, steps_); - gen.seed(static_cast(seed_)); //reseed + ar( CEREAL_NVP(seed_), + CEREAL_NVP(steps_) + ); + gen.seed(static_cast(seed_)); //reseed gen.discard(steps_); //advance n steps } @@ -160,9 +165,7 @@ class Random : public Serializable { protected: friend class RandomTest; - friend std::ostream &operator<<(std::ostream &, const Random &); - friend std::istream &operator>>(std::istream &, Random &); - friend UInt32 GetRandomSeed(); + friend UInt32 GetRandomSeed(const UInt seed); private: UInt64 seed_; UInt64 steps_ = 0; //step counter, used in serialization. It is important that steps_ is in sync with number of @@ -180,23 +183,18 @@ class Random : public Serializable { typename std::iterator_traits::difference_type i, n; n = last - first; for (i = n-1; i > 0; --i) { - using std::swap; - swap(first[i], first[this->getUInt32(static_cast(i+1))]); + std::swap(first[i], first[this->getUInt32(static_cast(i+1))]); } } }; -// serialization/deserialization -std::ostream &operator<<(std::ostream &, const Random &); -std::istream &operator>>(std::istream &, Random &); - // This function returns seeds from the Random singleton in our // "universe" (application, plugin, python module). If, when the // Random constructor is called, seeder_ is NULL, then seeder_ is // set to this function. The plugin framework can override this // behavior by explicitly setting the seeder to the RandomSeeder // function provided by the application. -UInt32 GetRandomSeed(); +UInt32 GetRandomSeed(const UInt seed=0); } // namespace htm diff --git a/src/test/unit/utils/RandomTest.cpp b/src/test/unit/utils/RandomTest.cpp index 01e7972d06..7c32ec6eb6 100644 --- a/src/test/unit/utils/RandomTest.cpp +++ b/src/test/unit/utils/RandomTest.cpp @@ -37,8 +37,10 @@ using namespace std; TEST(RandomTest, Seeding) { { Random r; + ASSERT_TRUE(r.getSeed() != 0) << "Should initialize with randomized seed"; + auto x = r.getUInt32(); - ASSERT_TRUE(x != 0); + ASSERT_NE(x, 0u); } // test getSeed @@ -64,6 +66,10 @@ TEST(RandomTest, Seeding) { ASSERT_EQ(r(), 419326371u); } + for(int i=0; i< 10; i++) { + ASSERT_NE(Random(0).getSeed(), Random(0).getSeed()) << "Randomly seeded generators should not be identical!"; + } + } @@ -104,87 +110,15 @@ TEST(RandomTest, OperatorEquals) { } -TEST(RandomTest, SerializationDeserialization) { - // test serialization/deserialization - Random r1(862973); - for (int i = 0; i < 100; i++) - r1.getUInt32(); +TEST(RandomTest, testSerialization) { // test serialization/deserialization using Cereal + const UInt SEED = 862973u; + Random r1(SEED); + ASSERT_EQ(r1.getSeed(), SEED) << "RNG seed not set properly"; + //burn-in + for (int i = 0; i < 100; i++) r1.getUInt32(); EXPECT_EQ(r1.getUInt32(), 2276275187u) << "Before serialization must be same"; - // serialize - std::stringstream ostream; - ostream << r1; - - // print out serialization for debugging - std::string x(ostream.str()); -// NTA_INFO << "random serialize string: '" << x << "'"; - // Serialization should be deterministic and platform independent - const std::string expectedString = "random-v2 862973 101 endrandom-v2 "; - EXPECT_EQ(expectedString, x) << "De/serialization"; - - // deserialize into r2 - std::string s(ostream.str()); - std::stringstream ss(s); - Random r2; - ss >> r2; - - // r1 and r2 should be identical - EXPECT_EQ(r1, r2) << "load from serialization"; - EXPECT_EQ(r2.getUInt32(), 3537119063u) << "Deserialized is not deterministic"; - r1.getUInt32(); //move the same number of steps - - UInt32 v1, v2; - for (int i = 0; i < 100; i++) { - v1 = r1.getUInt32(); - v2 = r2.getUInt32(); - EXPECT_EQ(v1, v2) << "serialization"; - } -} - - -TEST(RandomTest, testSerialization2) { - const UInt n=1000; - Random r1(7); - Random r2; - - htm::Timer testTimer; - testTimer.start(); - for (UInt i = 0; i < n; ++i) { - r1.getUInt32(); - // Serialize - ofstream os("random3.stream", ofstream::binary); - os << r1; - os.flush(); - os.close(); - - // Deserialize - ifstream is("random3.stream", ifstream::binary); - is >> r2; - is.close(); - - // Test - ASSERT_EQ(r1.getUInt32(), r2.getUInt32()); - ASSERT_EQ(r1.getUInt32(), r2.getUInt32()); - ASSERT_EQ(r1.getUInt32(), r2.getUInt32()); - ASSERT_EQ(r1.getUInt32(), r2.getUInt32()); - ASSERT_EQ(r1.getUInt32(), r2.getUInt32()); - } - testTimer.stop(); - - remove("random3.stream"); - - cout << "Random serialization: " << testTimer.getElapsed() << endl; -} - - -TEST(RandomTest, testSerialization_ar) { - // test serialization/deserialization - Random r1(862973); - for (int i = 0; i < 100; i++) - r1.getUInt32(); - - EXPECT_EQ(r1.getUInt32(), 2276275187u) << "Before serialization must be same"; // serialize std::stringstream ss; r1.save(ss); @@ -196,14 +130,14 @@ TEST(RandomTest, testSerialization_ar) { // r1 and r2 should be identical EXPECT_EQ(r1, r2) << "load from serialization"; EXPECT_EQ(r2.getUInt32(), 3537119063u) << "Deserialized is not deterministic"; - r1.getUInt32(); //move the same number of steps + EXPECT_EQ(r1.getUInt32(), 3537119063u); //move the same number of steps + EXPECT_EQ(r1.getSeed(), r2.getSeed()); + EXPECT_EQ(r2.getSeed(), SEED) << "Rng deserialized seed is not the same!"; - UInt32 v1, v2; for (int i = 0; i < 100; i++) { - v1 = r1.getUInt32(); - v2 = r2.getUInt32(); - EXPECT_EQ(v1, v2) << "serialization"; + EXPECT_EQ(r1.getUInt32(), r2.getUInt32()) << "serialization"; } + } @@ -223,28 +157,6 @@ TEST(RandomTest, ReturnInCorrectRange) { } } -/* -TEST(RandomTest, getUInt64) { - // tests for getUInt64 - Random r1(1); - ASSERT_EQ(2469588189546311528u, r1.getUInt64()) - << "check getUInt64, seed 1, first call"; - ASSERT_EQ(2516265689700432462u, r1.getUInt64()) - << "check getUInt64, seed 1, second call"; - - Random r2(2); - ASSERT_EQ(16668552215174154828u, r2.getUInt64()) - << "check getUInt64, seed 2, first call"; - EXPECT_EQ(15684088468973760345u, r2.getUInt64()) - << "check getUInt64, seed 2, second call"; - - Random r3(7464235991977222558); - EXPECT_EQ(8035066300482877360u, r3.getUInt64()) - << "check getUInt64, big seed, first call"; - EXPECT_EQ(623784303608610892u, r3.getUInt64()) - << "check getUInt64, big seed, second call"; -} -*/ TEST(RandomTest, getUInt32) { // tests for getUInt32