#!/usr/bin/env python3
"""
This application is a stress tester for libjio. It's not a traditional stress
test like fsx (which can be used to test libjio using the preloading library),
but uses fault injection to check how the library behaves under random
failures.
"""
import sys
import os
import time
import select
import random
import fcntl
import traceback
from optparse import OptionParser
import libjio
try:
	import fiu
except ImportError:
	print()
	print("Error: unable to load fiu module. This test needs libfiu")
	print("support. Please install libfiu and recompile libjio with FI=1.")
	print()
	raise
#
# Auxiliary stuff
#
# Rolling digit used by getbytes(); it cycles through 0..9 so consecutive
# calls produce visibly different payloads.
gbcount = 0
def getbytes(n):
	"""Return n bytes, all equal to the ASCII digit of the rolling counter.

	Each call first advances the module-level counter (modulo 10), then
	builds the payload from the new value.
	"""
	global gbcount
	gbcount = (gbcount + 1) % 10
	digit = str(gbcount).encode('ascii')
	return digit * n
def randfrange(maxend, maxsize):
	"""Pick a random (start, end) pair inside [0, maxend - 1].

	The start is drawn from the whole interval; the length is drawn from
	what is left after the start and then reduced modulo maxsize, so it is
	always smaller than maxsize.
	"""
	last = maxend - 1
	first = random.randint(0, last)
	length = random.randint(0, last - first) % maxsize
	return first, first + length
def randfloat(min, max):
	"""Return a random float between min and max.

	The value is min plus a random fraction of the (max - min) span. The
	previous implementation reduced random.random() modulo the span, which
	could never yield more than min + 1 and raised ZeroDivisionError when
	max == min.

	The parameter names shadow the builtins; they are kept so existing
	callers are not affected.
	"""
	return min + random.random() * (max - min)
class ConsistencyError(Exception):
	"""Raised when the data found in the file is not what was expected."""
def jfsck(fname, cleanup = False):
	"""Run libjio.jfsck() on fname and return its result.

	When cleanup is true, libjio.J_CLEANUP is passed as the flags;
	otherwise the flags are 0. An IOError whose first argument is
	libjio.J_ENOJOURNAL is turned into the result { 'total': 0 }; any
	other IOError is propagated.
	"""
	flags = libjio.J_CLEANUP if cleanup else 0
	try:
		return libjio.jfsck(fname, flags = flags)
	except IOError as e:
		if e.args[0] != libjio.J_ENOJOURNAL:
			raise
		return { 'total': 0 }
def comp_cont(bytes):
	"""Run-length encode a sequence.

	'aaaabbcc' -> [ ('a', 4), ('b', 2), ('c', 2) ]

	Each item of the result is (element, repetitions), in order of
	appearance. Iterating a bytes or bytearray object yields integers, so
	for those inputs the first item of every tuple is an int. An empty
	input gives an empty list.

	The parameter name shadows the builtin; it is kept so existing callers
	are not affected.
	"""
	runs = []
	prev = None
	count = 0
	for b in bytes:
		if count and b == prev:
			count += 1
			continue
		# a new run starts here; flush the previous one, if there was any
		if count:
			runs.append((prev, count))
		prev = b
		count = 1
	if count:
		runs.append((prev, count))
	return runs
def pread(fd, start, end):
	"""Read the bytes in [start, end) from fd, preserving its position.

	fd is a seekable file object opened in binary mode. The current
	position is saved before seeking to start and restored after reading.

	Returns a bytes object of exactly end - start bytes. If the file ends
	before that many bytes could be read, the final assertion fails
	(AssertionError).
	"""
	ppos = fd.tell()
	fd.seek(start, 0)
	total = end - start
	chunks = []
	c = 0
	while c < total:
		n = fd.read(total - c)
		# An empty read means EOF. Binary files return b'', which never
		# compares equal to '', so test for emptiness instead.
		if not n:
			break
		c += len(n)
		chunks.append(n)
	fd.seek(ppos, 0)
	assert c == total, "short read: got %d of %d bytes" % (c, total)
	return b''.join(chunks)
#
# Output handler, used to get a nice output when using multiple processes
#
class OutputHandler:
	"""Gather the progress of the children through pipes and print it.

	Expected use: the parent calls prefork() before each fork and parent()
	right after it, while the new child calls child() and then reports
	every transaction with feed(). Once all the children were launched,
	the parent runs output_loop(), which returns when every pipe reached
	EOF.
	"""

	# one-byte messages written by the children, one per transaction
	SUCCESS = bytes('1', encoding = 'ascii')
	FAILURE = bytes('0', encoding = 'ascii')

	def __init__(self, every):
		"""every: minimum number of seconds between periodic reports."""
		# fds to read from
		self.rs = []
		# we will report every this number of seconds
		self.every = every
		# how many transactions has each child processed; we use the
		# read end of the pipe to identify them
		self.ntrans = {}
		# like self.ntrans but counts only the failed ones
		self.nfailures = {}
		# fd to write to, only relevant in the child
		self.w = None
		# p = parent, c = child
		self.end = 'p'
		# last transaction number printed
		self.last_print = 0
		# time of the last print
		self.last_print_time = 0

	def prefork(self):
		"""Create the pipe for the child that is about to be forked."""
		r, w = os.pipe()
		self.rs.append(r)
		self.ntrans[r] = 0
		self.nfailures[r] = 0
		self.w = w

	def child(self):
		"""Child side of the fork: keep only the write end of our pipe."""
		self.end = 'c'
		# Besides our own read end, self.rs also holds the read ends of
		# the pipes created for the children forked before us; none of
		# them is of any use here, so close them all instead of leaking
		# them.
		for r in self.rs:
			os.close(r)
		self.rs = []

	def parent(self):
		"""Parent side of the fork: drop the write end of the pipe."""
		os.close(self.w)
		self.w = None

	def feed(self, success = True):
		"""Report one processed transaction (child only)."""
		if success:
			os.write(self.w, OutputHandler.SUCCESS)
		else:
			os.write(self.w, OutputHandler.FAILURE)

	def output_loop(self):
		"""Collect the reports until every child closed its pipe.

		Returns a tuple (total transactions, failed transactions).
		"""
		while self.rs:
			rr, rw, rx = select.select(self.rs, [], [], 1)
			for r in rr:
				# every byte is one transaction, so reading in
				# chunks only saves system calls
				d = os.read(r, 4096)
				if not d:
					# EOF: the child is gone, we are done
					# with this pipe
					self.rs.remove(r)
					os.close(r)
					continue
				self.ntrans[r] += len(d)
				self.nfailures[r] += d.count(OutputHandler.FAILURE)
			self.cond_print()
		self.print()
		return sum(self.ntrans.values()), sum(self.nfailures.values())

	def cond_print(self):
		"""Print a report if at least self.every seconds have passed."""
		if time.time() - self.last_print_time >= self.every:
			self.print()

	def print(self):
		"""Print one line with the transaction count of every child."""
		self.last_print_time = time.time()
		for r in sorted(self.ntrans):
			print("%4d" % self.ntrans[r], end = ' ')
		print()
#
# Lock manager, used to lock ranges between multiple processes
#
# We can't lock the real file because that would ruin libjio's locking, so we
# create a new file, remove it, and use fcntl locking. Not very elegant but it
# does the trick.
#
class VoidLockManager:
	"""Lock manager that does nothing, for runs without internal locking."""

	def __init__(self):
		"""Nothing to set up."""

	def lock(self, start, end):
		"""Accept the range and do nothing."""
		return None

	def unlock(self, start, end):
		"""Accept the range and do nothing."""
		return None
class LockManager:
	"""Range locks between processes, based on fcntl locks over a scratch file.

	The scratch file only exists to hang the locks on; nothing is ever
	written to it.
	"""

	def __init__(self):
		# Imported here so this class does not depend on the file-level
		# imports.
		import tempfile
		# TemporaryFile() creates the file securely and leaves no name
		# behind, unlike opening a predictable /tmp/<name>.<pid> path
		# (which could be pre-created or symlinked by another user) and
		# unlinking it afterwards.
		self.fd = tempfile.TemporaryFile(mode = 'w+',
				prefix = 'js-lock-tmp.')

	def lock(self, start, end):
		"""Take an exclusive lock on the bytes [start, end)."""
		#print(os.getpid(), '\tlock:', start, end)
		#sys.stdout.flush()
		fcntl.lockf(self.fd, fcntl.LOCK_EX, end - start, start)

	def unlock(self, start, end):
		"""Release the lock on the bytes [start, end)."""
		#print(os.getpid(), '\tunlock:', start, end)
		#sys.stdout.flush()
		fcntl.lockf(self.fd, fcntl.LOCK_UN, end - start, start)
#
# A range of bytes inside a file, used inside the transactions
#
# Note it can't "remember" the fd as it may change between prepare() and
# verify().
#
class Range:
	"""A random range of bytes inside the file, used by the transactions.

	The range proper is [start, end). It is surrounded by an "extended"
	range [ext_start, ext_end) of up to 32 extra bytes on each side, which
	is what gets locked and, for writes, what gets compared, so writes
	outside of the intended area are detected too.

	A range is either read ('r') or written ('w'); prepare_r() and
	prepare_w() select which, and verify() checks accordingly.
	"""

	def __init__(self, fsize, maxlen, lockmgr):
		# public
		self.start, self.end = randfrange(fsize, maxlen)
		self.new_data = None
		self.type = 'r'
		# private
		self.prev_data = None
		self.new_data_ctx = None
		self.read_buf = None
		self.lockmgr = lockmgr
		self.locked = False
		# read an extended range so we can check we
		# only wrote what we were supposed to
		self.ext_start = max(0, self.start - 32)
		self.ext_end = min(fsize, self.end + 32)

	def __lt__(self, other):
		# ranges sort by the beginning of their extended range
		return self.ext_start < other.ext_start

	def __del__(self):
		# getattr() keeps this safe for instances that were never fully
		# initialized
		if getattr(self, 'locked', False):
			self.lockmgr.unlock(self.ext_start, self.ext_end)

	def overlaps(self, other):
		"""True if the extended ranges share any point, ends included."""
		# two closed intervals intersect iff each one begins before
		# the other one ends
		return self.ext_start <= other.ext_end and \
		       other.ext_start <= self.ext_end

	def prepare_r(self):
		"""Make this a read range: allocate the buffer and lock."""
		self.type = 'r'
		self.read_buf = bytearray(self.end - self.start)
		self.lockmgr.lock(self.ext_start, self.ext_end)
		self.locked = True

	def verify_r(self, fd):
		"""Check that read_buf matches the contents of the file."""
		real_data = pread(fd, self.start, self.end)
		if real_data != self.read_buf:
			print('Corruption detected')
			self.show(fd)
			raise ConsistencyError

	def prepare_w(self, fd):
		"""Make this a write range: lock, and build the expected data.

		Returns (new data, offset to write it at).
		"""
		self.type = 'w'
		self.lockmgr.lock(self.ext_start, self.ext_end)
		self.locked = True
		self.prev_data = pread(fd, self.ext_start, self.ext_end)
		self.new_data = getbytes(self.end - self.start)
		# Expected content of the extended range after the write: the
		# old context, with the new data in the middle. The tail is
		# sliced with a positive index on purpose: a negative one
		# would return the whole buffer if its length were 0.
		head = self.prev_data[:self.start - self.ext_start]
		tail = self.prev_data[self.end - self.ext_start:]
		self.new_data_ctx = head + self.new_data + tail
		return self.new_data, self.start

	def verify_w(self, fd):
		"""Check the extended range holds either the old or the new data."""
		# NOTE: fd must be a real file
		real_data = pread(fd, self.ext_start, self.ext_end)
		if real_data not in (self.prev_data, self.new_data_ctx):
			print('Corruption detected')
			self.show(fd)
			raise ConsistencyError

	def verify(self, fd):
		"""Verify the range according to its type."""
		if self.type == 'r':
			self.verify_r(fd)
		else:
			self.verify_w(fd)

	def show(self, fd):
		"""Print the real and the expected contents, run-length encoded."""
		# Read the same span the expected data covers, so the lines
		# printed below can be compared with each other: the extended
		# range for writes, the range proper for reads.
		if self.type == 'w':
			real_data = pread(fd, self.ext_start, self.ext_end)
		else:
			real_data = pread(fd, self.start, self.end)
		print('Range:', self.ext_start, self.ext_end)
		print('Real:', comp_cont(real_data))
		if self.type == 'w':
			print('Prev:', comp_cont(self.prev_data))
			print('New: ', comp_cont(self.new_data_ctx))
		else:
			print('Buf:', comp_cont(self.read_buf))
		print()
#
# Transactions
#
class T_base:
	"""Interface for the transaction types.

	All the methods are no-ops here.
	"""

	def __init__(self, f, jf, fsize, lockmgr, do_verify):
		"""Takes the same arguments as every transaction type."""

	def prepare(self):
		"""Get the transaction ready to be applied."""

	def apply(self):
		"""Apply the transaction."""

	def verify(self, write_only = False):
		"""Verify the transaction."""
class T_jwrite (T_base):
	"""Transaction made of a single pwrite() over one random range."""

	def __init__(self, f, jf, fsize, lockmgr, do_verify):
		self.f, self.jf = f, jf
		self.fsize = fsize
		self.do_verify = do_verify
		# up to 1/256th of the file, but never above 2 MiB
		size_cap = 2 * 1024 * 1024
		self.maxoplen = min(int(fsize / 256), size_cap)
		self.range = Range(self.fsize, self.maxoplen, lockmgr)

	def prepare(self):
		"""Lock the range and generate the data to write."""
		self.range.prepare_w(self.f)

	def apply(self):
		"""Write the new data at the beginning of the range."""
		rng = self.range
		self.jf.pwrite(rng.new_data, rng.start)

	def verify(self, write_only = False):
		"""Verify the range, unless verification is disabled."""
		if self.do_verify:
			self.range.verify(self.f)
class T_writeonly (T_base):
	"""Transaction made of several write operations over disjoint ranges."""

	def __init__(self, f, jf, fsize, lockmgr, do_verify):
		self.f, self.jf = f, jf
		self.fsize = fsize
		self.do_verify = do_verify
		# favour many small ops
		self.maxoplen = 512 * 1024
		self.nops = random.randint(1, 26)
		self.ranges = []
		# Collect up to nops ranges that do not overlap each other,
		# giving up after a bounded number of attempts.
		attempts = 0
		max_attempts = self.nops * 1.25
		while len(self.ranges) < self.nops and attempts < max_attempts:
			candidate = Range(self.fsize, self.maxoplen, lockmgr)
			if not any(candidate.overlaps(r) for r in self.ranges):
				self.ranges.append(candidate)
			attempts += 1
		# sort the transactions so there's no risk of internal
		# deadlocks via the lock manager
		self.ranges.sort()

	def prepare(self):
		"""Prepare every range for writing."""
		for rng in self.ranges:
			rng.prepare_w(self.f)

	def apply(self):
		"""Add all the writes to a new transaction and commit it."""
		t = self.jf.new_trans()
		for rng in self.ranges:
			t.add_w(rng.new_data, rng.start)
		t.commit()

	def verify(self, write_only = False):
		"""Verify every range, unless verification is disabled.

		On a ConsistencyError all the ranges are shown before the
		exception is re-raised.
		"""
		if self.do_verify:
			try:
				for rng in self.ranges:
					rng.verify(self.f)
			except ConsistencyError:
				# show context on errors
				separator = "-" * 50
				print(separator)
				for rng in self.ranges:
					rng.show(self.f)
				print(separator)
				raise
class T_readwrite (T_writeonly):
	"""Like T_writeonly, but each range is randomly a read or a write."""

	def __init__(self, f, jf, fsize, lockmgr, do_verify):
		super().__init__(f, jf, fsize, lockmgr, do_verify)
		self.read_ranges = []

	def prepare(self):
		"""Prepare each range as a write or as a read, chosen at random."""
		for rng in self.ranges:
			is_write = random.choice((True, False))
			if is_write:
				rng.prepare_w(self.f)
			else:
				rng.prepare_r()

	def apply(self):
		"""Add every operation to a new transaction and commit it."""
		t = self.jf.new_trans()
		for rng in self.ranges:
			if rng.type == 'r':
				t.add_r(rng.read_buf, rng.start)
			else:
				t.add_w(rng.new_data, rng.start)
		t.commit()

	def verify(self, write_only = False):
		"""Verify the ranges, unless verification is disabled.

		With write_only set, the read ranges are skipped. On a
		ConsistencyError all the ranges are shown before the exception
		is re-raised.
		"""
		if self.do_verify:
			try:
				for rng in self.ranges:
					if not (write_only and rng.type == 'r'):
						rng.verify(self.f)
			except ConsistencyError:
				# show context on errors
				separator = "-" * 50
				print(separator)
				for rng in self.ranges:
					rng.show(self.f)
				print(separator)
				raise
# Transaction types available to the stresser, which picks one at random for
# each operation.
t_list = [T_jwrite, T_writeonly, T_readwrite]
#
# The test itself
#
class Stresser:
	"""Runs random transactions against one file, optionally under fault
	injection.

	The file is opened twice: through libjio (self.jf), which is what the
	transactions operate on, and as a regular read-only binary file
	(self.f), which is what the verifications read from.
	"""

	# Failure points handled by fiu_enable() and fiu_disable(), as
	# (name, minimum probability, maximum probability). Both methods go
	# through this list, so everything that gets enabled also gets
	# disabled.
	FAILURE_POINTS = (
		('jio/*', 0.0005, 0.005),
		('linux/*', 0.005, 0.03),
		('posix/*', 0.005, 0.03),
		('libc/mm/*', 0.003, 0.07),
		('libc/str/*', 0.005, 0.07),
	)

	def __init__(self, fname, fsize, nops, use_fi, use_as, output,
			lockmgr, do_verify):
		"""
		fname, fsize: name of the file and size to truncate it to
		nops: number of transactions run() will perform
		use_fi: apply the transactions under fault injection
		use_as: open with J_LINGER and start autosync
		output: object whose feed() gets the result of each transaction
		lockmgr: lock manager handed to the transactions
		do_verify: whether the transactions verify the file
		"""
		self.fname = fname
		self.fsize = fsize
		self.nops = nops
		self.use_fi = use_fi
		self.use_as = use_as
		self.output = output
		self.lockmgr = lockmgr
		self.do_verify = do_verify
		jflags = 0
		if use_as:
			jflags = libjio.J_LINGER
		self.jf = libjio.open(fname, libjio.O_RDWR | libjio.O_CREAT,
				0o600, jflags)
		self.f = open(fname, mode = 'rb')
		self.jf.truncate(fsize)
		if use_as:
			self.jf.autosync_start(5, 10 * 1024 * 1024)

	def apply(self, trans):
		"""Prepare, apply and verify trans in this process; returns True."""
		trans.prepare()
		trans.apply()
		trans.verify()
		return True

	def apply_fork(self, trans):
		"""Apply trans in a forked child, with fault injection enabled.

		Returns True if the child exited with status 0 and False if it
		exited with any other status. Raises RuntimeError if the child
		did not exit normally.
		"""
		# do the prep before the fork so we can verify() afterwards
		trans.prepare()
		sys.stdout.flush()
		pid = os.fork()
		if pid == 0:
			# child
			try:
				self.fiu_enable()
				trans.apply()
				self.fiu_disable()
			except (IOError, MemoryError):
				# the kind of failure we inject: try to
				# recover, and report it via the exit status
				try:
					self.reopen(trans)
				except (IOError, MemoryError):
					pass
				except:
					self.fiu_disable()
					traceback.print_exc()
				self.fiu_disable()
				sys.exit(1)
			except:
				# anything else is unexpected, so show it
				self.fiu_disable()
				traceback.print_exc()
				sys.exit(1)
			trans.verify()
			sys.exit(0)
		else:
			# parent
			_, status = os.waitpid(pid, 0)
			if not os.WIFEXITED(status):
				i = (status,
					os.WIFSIGNALED(status),
					os.WTERMSIG(status))
				raise RuntimeError(i)
			return os.WEXITSTATUS(status) == 0

	def reopen(self, trans):
		"""Drop the libjio file, run jfsck, verify the writes of trans
		and open the file again. Returns the result of jfsck."""
		self.jf = None
		r = jfsck(self.fname)
		trans.verify(write_only = True)
		self.jf = libjio.open(self.fname,
			libjio.O_RDWR | libjio.O_CREAT, 0o600)
		return r

	def fiu_enable(self):
		"""Enable the failure points; does nothing without use_fi."""
		if not self.use_fi:
			return
		# To improve code coverage, we randomize the probability each
		# time we enable failure points
		for name, pmin, pmax in self.FAILURE_POINTS:
			fiu.enable_random(name,
					probability = randfloat(pmin, pmax))

	def fiu_disable(self):
		"""Disable the failure points; does nothing without use_fi."""
		if not self.use_fi:
			return
		for name, pmin, pmax in self.FAILURE_POINTS:
			fiu.disable(name)

	def run(self):
		"""Run self.nops random transactions, feeding each result to
		self.output. Returns the number of failed transactions."""
		nfailures = 0
		for _ in range(self.nops):
			trans = random.choice(t_list)(self.f, self.jf,
					self.fsize, self.lockmgr,
					self.do_verify)
			if self.use_fi:
				r = self.apply_fork(trans)
			else:
				r = self.apply(trans)
			if r:
				self.output.feed(success = True)
			else:
				self.output.feed(success = False)
				nfailures += 1
				# recover, then check again with the file
				# reopened
				self.reopen(trans)
				trans.verify(write_only = True)
		return nfailures
#
# Main
#
def run_stressers(nproc, fname, fsize, nops, use_fi, use_as, output, lockmgr,
		do_verify):
	"""Fork nproc children running a Stresser each and wait for them.

	The nops operations are split among the children. While they run, the
	parent stays in output.output_loop(); afterwards it prints a summary.

	Returns True if every child exited with status 0. A child that exits
	with another status, or that does not exit normally (e.g. it was
	killed by a signal), makes it return False.
	"""
	pids = []
	print("Launching stress test")
	for i in range(nproc):
		# Calculate how many operations will this child perform. The
		# last one will work a little more so we get exactly nops.
		# Note we prefer to work extra in the end rather than having
		# the last process with 0 child_nops, that's why we use int()
		# instead of round() or ceil().
		child_nops = int(nops / nproc)
		if i == nproc - 1:
			child_nops = nops - int(nops / nproc) * i
		output.prefork()
		sys.stdout.flush()
		pid = os.fork()
		if pid == 0:
			# child
			output.child()
			s = Stresser(fname, fsize, child_nops, use_fi, use_as,
					output, lockmgr, do_verify)
			s.run()
			sys.exit(0)
		else:
			output.parent()
			pids.append(pid)
	print("Launched stress tests")
	totalops, nfailures = output.output_loop()
	print("Stress test completed, waiting for children")
	nerrors = 0
	for pid in pids:
		_, status = os.waitpid(pid, 0)
		# WEXITSTATUS() is only meaningful if the child exited
		# normally; one that was killed by a signal must count as an
		# error too, instead of being taken as a success.
		if not os.WIFEXITED(status) or os.WEXITSTATUS(status) != 0:
			nerrors += 1
	print("  %d operations" % totalops)
	print("  %d simulated failures" % nfailures)
	print("  %d processes ended with errors" % nerrors)
	return nerrors == 0
def main():
	"""Parse the command line, run the stress test and the final check.

	Returns 0 if the test succeeded and 1 if it failed. Exits with status
	1 on invalid command line arguments.
	"""
	usage = "Use: %prog [options] <file name> <file size in Mb>"
	parser = OptionParser(usage = usage)
	parser.add_option("-n", "--nops", dest = "nops", type = "int",
		default = 100,
		help = "number of operations (defaults to %default)")
	parser.add_option("-p", "--nproc", dest = "nproc", type = "int",
		default = 1,
		help = "number of processes (defaults to %default)")
	parser.add_option("", "--fi", dest = "use_fi",
		action = "store_true", default = False,
		help = "use fault injection (conflicts with --as)")
	parser.add_option("", "--as", dest = "use_as",
		action = "store_true", default = False,
		help = "use J_LINGER + autosync (conflicts with --fi)")
	parser.add_option("", "--no-internal-lock",
		dest = "use_internal_locks", action = "store_false",
		default = True,
		help = "do not lock internally, disables verification")
	parser.add_option("", "--no-verify", dest = "do_verify",
		action = "store_false", default = True,
		help = "do not perform verifications")
	parser.add_option("", "--keep", dest = "keep",
		action = "store_true", default = False,
		help = "keep the file after completing the test")
	options, args = parser.parse_args()
	if len(args) != 2:
		parser.print_help()
		sys.exit(1)
	fname = args[0]
	try:
		fsize = int(args[1]) * 1024 * 1024
	except ValueError:
		print("Error: the size of the file must be numeric")
		sys.exit(1)
	# An empty file leaves no room to pick ranges from, so every
	# transaction would fail on creation; reject it up front.
	if fsize <= 0:
		print("Error: the size of the file must be greater than 0")
		sys.exit(1)
	# Without processes nothing would be tested, yet the run would be
	# reported as successful.
	if options.nproc < 1:
		print("Error: the number of processes must be at least 1")
		sys.exit(1)
	if options.use_fi and options.use_as:
		print("Error: --fi and --as cannot be used together")
		sys.exit(1)
	# verification relies on the ranges being locked
	if not options.use_internal_locks:
		options.do_verify = False
	output = OutputHandler(every = 2)
	if options.use_internal_locks:
		lockmgr = LockManager()
	else:
		lockmgr = VoidLockManager()
	success = run_stressers(options.nproc, fname, fsize, options.nops,
			options.use_fi, options.use_as, output, lockmgr,
			options.do_verify)
	# final check; its result is not used, but errors propagate
	jfsck(fname)
	print("Final check completed")
	if success and not options.keep:
		jfsck(fname, cleanup = True)
		os.unlink(fname)
	if not success:
		print("Test failed")
		return 1
	return 0
# When run as a script, the value returned by main() is the exit status.
if __name__ == '__main__':
	sys.exit(main())