# -*- coding: utf-8 -*-
# Created by mjribeiro at 18-10-2019

import numpy as np

class OrnsteinUhlenbeckActionNoise:
	"""Stateful noise source following a discrete Ornstein-Uhlenbeck-style update.

	Each call to :meth:`sample` advances the internal state ``X`` by

		X <- X + theta * (mu - X) + sigma * N(0, 1)

	where the Gaussian term is drawn independently for every element of ``X``
	from NumPy's global random state (``np.random``), so results are
	reproducible via ``np.random.seed``.

	Attributes:
		action_dim: Shape passed to ``np.ones`` to build the state
			(an int or a tuple of ints).
		mu: Value the state starts at and is pulled toward.
		theta: Coefficient of the ``(mu - X)`` pull-back term.
		sigma: Scale applied to the standard-normal draw.
		X: Current state as a NumPy array.
	"""

	def __init__(self, action_dim, mu=0, theta=0.15, sigma=0.2):
		"""Store the process parameters and initialise the state to ``mu``.

		Args:
			action_dim: Shape of the noise array (int or tuple of ints).
			mu: Starting value / attractor of the state. Default ``0``.
			theta: Pull-back coefficient. Default ``0.15``.
			sigma: Noise scale. Default ``0.2``.
		"""
		self.action_dim = action_dim
		self.mu = mu
		self.theta = theta
		self.sigma = sigma
		# Single source of truth for the initial state: same as a reset.
		self.X = None
		self.reset()

	def reset(self):
		"""Set the state back to an array of shape ``action_dim`` filled with ``mu``."""
		self.X = np.ones(self.action_dim) * self.mu

	def sample(self):
		"""Advance the process one step and return the new state.

		Returns:
			The updated state array. This is the same object now stored in
			``self.X``; copy it before modifying it in place if the internal
			state must stay untouched.
		"""
		drift = self.theta * (self.mu - self.X)
		# Draw one independent normal per element of the state. Sizing the
		# draw by the full shape (rather than the length of the first axis)
		# keeps multi-dimensional states correct instead of broadcasting one
		# short noise vector across the remaining axes.
		diffusion = self.sigma * np.random.standard_normal(self.X.shape)
		# A fresh array is bound on every step, so arrays returned by earlier
		# calls are not changed by later ones.
		self.X = self.X + (drift + diffusion)
		return self.X


# # Uncomment the block below to plot a sample Ornstein-Uhlenbeck noise trajectory.
# if __name__ == '__main__':
# 	ou = OrnsteinUhlenbeckActionNoise(1)
# 	states = []
# 	for i in range(1000):
# 		states.append(ou.sample())
# 	import matplotlib.pyplot as plt
#
# 	plt.plot(states)
# 	plt.show()