Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 32 additions & 22 deletions bindings/py/cpp_src/bindings/math/py_Random.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,12 @@ namespace htm_ext {
py::class_<Random_t> Random(m, "Random");

Random.def(py::init<htm::UInt64>(), py::arg("seed") = 0)
.def("getUInt32", &Random_t::getUInt32, py::arg("max") = (htm::UInt32)-1l)
.def("getReal64", &Random_t::getReal64)
.def("getSeed", &Random_t::getSeed)
.def("max", &Random_t::max)
.def("min", &Random_t::min)
.def("__eq__", [](Random_t const & self, Random_t const & other) {//wrapping operator==
return self == other;
}, py::is_operator());
.def("getUInt32", &Random_t::getUInt32, py::arg("max") = (htm::UInt32)-1l)
.def("getReal64", &Random_t::getReal64)
.def("getSeed", &Random_t::getSeed)
.def("max", &Random_t::max)
.def("min", &Random_t::min)
.def("__eq__", [](Random_t const & self, Random_t const & other) { return self == other; }, py::is_operator()); //operator==

Random.def_property_readonly_static("MAX32", [](py::object) {
return Random_t::MAX32;
Expand Down Expand Up @@ -144,25 +142,37 @@ namespace htm_ext {
}, "load from a File, using BINARY, PORTABLE, JSON, or XML format.",
py::arg("name"), py::arg("fmt") = 0);


Random.def(py::pickle(
[](const Random_t& r)
[](const Random_t& r) //__getstate__
{
std::stringstream ss;
ss << r;
return ss.str();
},
[](const std::string& str)
{
if (str.empty())
{
throw std::runtime_error("Empty state");
}
r.save(ss); //save r's state to archive (stream) with cereal

std::stringstream ss(str);
Random_t r;
ss >> r;
/* The values in stringstream are binary so pickle will get confused
* trying to treat it as utf8 if you just return ss.str().
* So we must treat it as py::bytes. Some characters could be null values.
*/
return py::bytes( ss.str() );
},

return r;
[](py::bytes &s) // __setstate__
{
/* pybind11 will pass in the bytes array without conversion.
* so we should be able to just create a string to initalize the stringstream.
*/
std::stringstream ss( s.cast<std::string>() );
std::unique_ptr<htm::Random> r(new htm::Random());
r->load(ss);

/*
* The __setstate__ part of the py::pickle() is actually a py::init() with some options.
* So the return value can be the object returned by value, by pointer,
* or by container (meaning a unique_ptr). SP has a problem with the copy constructor
* and pointers have problems knowing who the owner is so lets use unique_ptr.
* See: https://pybind11.readthedocs.io/en/stable/advanced/classes.html#custom-constructors
*/
return r;
}
));

Expand Down
56 changes: 6 additions & 50 deletions src/htm/utils/Random.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@
/** @file
Random Number Generator implementation
*/
#include <iostream> // for istream, ostream
#include <chrono> // for random seeds

#include <htm/utils/Log.hpp>
#include <htm/utils/Random.hpp>

Expand All @@ -32,64 +29,23 @@ bool Random::operator==(const Random &o) const {
gen == o.gen;
}

bool static_gen_seeded = false; //used only for seeding seed if 0/auto is passed for seed
std::mt19937 static_gen;
std::random_device rd; //HW RNG, undeterministic, platform dependant. Use only for seeding rng if random seed wanted (seed=0)

Random::Random(UInt64 seed) {
Random::Random(const UInt64 seed) {
if (seed == 0) {
if( !static_gen_seeded ) {
#if NDEBUG
unsigned int static_seed = (unsigned int)std::chrono::system_clock::now().time_since_epoch().count();
#else
unsigned int static_seed = DEBUG_RANDOM_SEED;
#endif
static_gen.seed( static_seed );
static_gen_seeded = true;
NTA_INFO << "Random seed: " << static_seed;
}
std::mt19937 static_gen(rd());
seed_ = static_gen(); //generate random value from HW RNG
} else {
seed_ = seed;
}
// if seed is zero at this point, there is a logic error.
NTA_CHECK(seed_ != 0);
NTA_CHECK(seed_ != 0) << "Random: if seed is zero at this point, there is a logic error";
gen.seed(static_cast<unsigned int>(seed_)); //seed the generator
steps_ = 0;
}


namespace htm {
std::ostream &operator<<(std::ostream &outStream, const Random &r) {
outStream << "random-v2" << " ";
outStream << r.seed_ << " ";
outStream << r.steps_ << " ";
outStream << "endrandom-v2" << " ";
return outStream;
}


std::istream &operator>>(std::istream &inStream, Random &r) {
std::string version;

inStream >> version;
NTA_CHECK(version == "random-v2") << "Random() deserializer -- found unexpected version string '"
<< version << "'";
inStream >> r.seed_;
r.gen.seed(static_cast<unsigned int>(r.seed_)); //reseed
inStream >> r.steps_;
r.gen.discard(r.steps_); //advance n steps
//FIXME we could de/serialize directly RNG gen, it should be multi-platform according to standard,
//but on OSX CI it wasn't (25/11/2018). So "hacking" the above instead.
std::string endtag;
inStream >> endtag;
NTA_CHECK(endtag == "endrandom-v2") << "Random() deserializer -- found unexpected end tag '" << endtag << "'";
inStream.ignore(1);

return inStream;
}

// helper function for seeding RNGs across the plugin barrier
UInt32 GetRandomSeed() {
return htm::Random().getUInt32();
UInt32 GetRandomSeed(const UInt seed) {
return htm::Random(seed).getUInt32();
}
} // namespace htm
26 changes: 12 additions & 14 deletions src/htm/utils/Random.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,18 +71,23 @@ namespace htm {
*/
class Random : public Serializable {
public:
Random(UInt64 seed = 0);
Random(const UInt64 seed = 0);


// Serialization
CerealAdapter;
template<class Archive>
void save_ar(Archive & ar) const {
ar( cereal::make_nvp("seed", seed_), cereal::make_nvp("steps", steps_));
ar( CEREAL_NVP(seed_),
CEREAL_NVP(steps_)
);
}
template<class Archive>
void load_ar(Archive & ar) {
ar( seed_, steps_);
gen.seed(static_cast<unsigned int>(seed_)); //reseed
ar( CEREAL_NVP(seed_),
CEREAL_NVP(steps_)
);
gen.seed(static_cast<UInt64>(seed_)); //reseed
gen.discard(steps_); //advance n steps
}

Expand Down Expand Up @@ -160,9 +165,7 @@ class Random : public Serializable {

protected:
friend class RandomTest;
friend std::ostream &operator<<(std::ostream &, const Random &);
friend std::istream &operator>>(std::istream &, Random &);
friend UInt32 GetRandomSeed();
friend UInt32 GetRandomSeed(const UInt seed);
private:
UInt64 seed_;
UInt64 steps_ = 0; //step counter, used in serialization. It is important that steps_ is in sync with number of
Expand All @@ -180,23 +183,18 @@ class Random : public Serializable {
typename std::iterator_traits<RandomIt>::difference_type i, n;
n = last - first;
for (i = n-1; i > 0; --i) {
using std::swap;
swap(first[i], first[this->getUInt32(static_cast<UInt32>(i+1))]);
std::swap(first[i], first[this->getUInt32(static_cast<UInt32>(i+1))]);
}
}
};

// serialization/deserialization
std::ostream &operator<<(std::ostream &, const Random &);
std::istream &operator>>(std::istream &, Random &);

// This function returns seeds from the Random singleton in our
// "universe" (application, plugin, python module). If, when the
// Random constructor is called, seeder_ is NULL, then seeder_ is
// set to this function. The plugin framework can override this
// behavior by explicitly setting the seeder to the RandomSeeder
// function provided by the application.
UInt32 GetRandomSeed();
UInt32 GetRandomSeed(const UInt seed=0);

} // namespace htm

Expand Down
124 changes: 18 additions & 106 deletions src/test/unit/utils/RandomTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,10 @@ using namespace std;
TEST(RandomTest, Seeding) {
{
Random r;
ASSERT_TRUE(r.getSeed() != 0) << "Should initialize with randomized seed";

auto x = r.getUInt32();
ASSERT_TRUE(x != 0);
ASSERT_NE(x, 0u);
}

// test getSeed
Expand All @@ -64,6 +66,10 @@ TEST(RandomTest, Seeding) {
ASSERT_EQ(r(), 419326371u);
}

for(int i=0; i< 10; i++) {
ASSERT_NE(Random(0).getSeed(), Random(0).getSeed()) << "Randomly seeded generators should not be identical!";
}

}


Expand Down Expand Up @@ -104,87 +110,15 @@ TEST(RandomTest, OperatorEquals) {
}


TEST(RandomTest, SerializationDeserialization) {
// test serialization/deserialization
Random r1(862973);
for (int i = 0; i < 100; i++)
r1.getUInt32();
TEST(RandomTest, testSerialization) { // test serialization/deserialization using Cereal
const UInt SEED = 862973u;
Random r1(SEED);
ASSERT_EQ(r1.getSeed(), SEED) << "RNG seed not set properly";

//burn-in
for (int i = 0; i < 100; i++) r1.getUInt32();
EXPECT_EQ(r1.getUInt32(), 2276275187u) << "Before serialization must be same";
// serialize
std::stringstream ostream;
ostream << r1;

// print out serialization for debugging
std::string x(ostream.str());
// NTA_INFO << "random serialize string: '" << x << "'";
// Serialization should be deterministic and platform independent
const std::string expectedString = "random-v2 862973 101 endrandom-v2 ";
EXPECT_EQ(expectedString, x) << "De/serialization";

// deserialize into r2
std::string s(ostream.str());
std::stringstream ss(s);
Random r2;
ss >> r2;

// r1 and r2 should be identical
EXPECT_EQ(r1, r2) << "load from serialization";
EXPECT_EQ(r2.getUInt32(), 3537119063u) << "Deserialized is not deterministic";
r1.getUInt32(); //move the same number of steps

UInt32 v1, v2;
for (int i = 0; i < 100; i++) {
v1 = r1.getUInt32();
v2 = r2.getUInt32();
EXPECT_EQ(v1, v2) << "serialization";
}
}


TEST(RandomTest, testSerialization2) {
const UInt n=1000;
Random r1(7);
Random r2;

htm::Timer testTimer;
testTimer.start();
for (UInt i = 0; i < n; ++i) {
r1.getUInt32();

// Serialize
ofstream os("random3.stream", ofstream::binary);
os << r1;
os.flush();
os.close();

// Deserialize
ifstream is("random3.stream", ifstream::binary);
is >> r2;
is.close();

// Test
ASSERT_EQ(r1.getUInt32(), r2.getUInt32());
ASSERT_EQ(r1.getUInt32(), r2.getUInt32());
ASSERT_EQ(r1.getUInt32(), r2.getUInt32());
ASSERT_EQ(r1.getUInt32(), r2.getUInt32());
ASSERT_EQ(r1.getUInt32(), r2.getUInt32());
}
testTimer.stop();

remove("random3.stream");

cout << "Random serialization: " << testTimer.getElapsed() << endl;
}


TEST(RandomTest, testSerialization_ar) {
// test serialization/deserialization
Random r1(862973);
for (int i = 0; i < 100; i++)
r1.getUInt32();

EXPECT_EQ(r1.getUInt32(), 2276275187u) << "Before serialization must be same";
// serialize
std::stringstream ss;
r1.save(ss);
Expand All @@ -196,14 +130,14 @@ TEST(RandomTest, testSerialization_ar) {
// r1 and r2 should be identical
EXPECT_EQ(r1, r2) << "load from serialization";
EXPECT_EQ(r2.getUInt32(), 3537119063u) << "Deserialized is not deterministic";
r1.getUInt32(); //move the same number of steps
EXPECT_EQ(r1.getUInt32(), 3537119063u); //move the same number of steps
EXPECT_EQ(r1.getSeed(), r2.getSeed());
EXPECT_EQ(r2.getSeed(), SEED) << "Rng deserialized seed is not the same!";

UInt32 v1, v2;
for (int i = 0; i < 100; i++) {
v1 = r1.getUInt32();
v2 = r2.getUInt32();
EXPECT_EQ(v1, v2) << "serialization";
EXPECT_EQ(r1.getUInt32(), r2.getUInt32()) << "serialization";
}

}


Expand All @@ -223,28 +157,6 @@ TEST(RandomTest, ReturnInCorrectRange) {
}
}

/*
TEST(RandomTest, getUInt64) {
// tests for getUInt64
Random r1(1);
ASSERT_EQ(2469588189546311528u, r1.getUInt64())
<< "check getUInt64, seed 1, first call";
ASSERT_EQ(2516265689700432462u, r1.getUInt64())
<< "check getUInt64, seed 1, second call";

Random r2(2);
ASSERT_EQ(16668552215174154828u, r2.getUInt64())
<< "check getUInt64, seed 2, first call";
EXPECT_EQ(15684088468973760345u, r2.getUInt64())
<< "check getUInt64, seed 2, second call";

Random r3(7464235991977222558);
EXPECT_EQ(8035066300482877360u, r3.getUInt64())
<< "check getUInt64, big seed, first call";
EXPECT_EQ(623784303608610892u, r3.getUInt64())
<< "check getUInt64, big seed, second call";
}
*/

TEST(RandomTest, getUInt32) {
// tests for getUInt32
Expand Down