#include "stdafx.h"
#include "ReplaceTasks.h"
#include "Engine.h"
#include "VTable.h"
#include "Function.h"
#include "TypeTransform.h"
#include "Lib/PinnedSet.h"

namespace storm {

	// Creates an empty set of tasks that is not associated with any paused threads; 'findActive'
	// and 'replaceActive' therefore find nothing to patch.
	ReplaceTasks::ReplaceTasks() : activeFunctions(null) {
		this->exceptions = new (this) Array<Exception *>();
		this->functionsToUpdate = new (this) Array<Function *>();
		this->transforms = new (this) Map<Type *, TypeTransform *>();

		// Allocated with plain 'new'; freed again in the destructor.
		this->replaceMap = new RawObjMap(engine().gc);
		this->vtableMap = new RawObjMap(engine().gc);
	}

	// Creates a set of tasks bound to the active-function set of the threads paused by 'src', so
	// that 'findActive' and 'replaceActive' consult it.
	ReplaceTasks::ReplaceTasks(PauseThreads &src) : activeFunctions(src.activeFunctions()) {
		// Keep the set alive while we refer to it; released in the destructor. Guarded in the same
		// way as every other use of 'activeFunctions' in this class.
		if (activeFunctions)
			activeFunctions->addRef();
		replaceMap = new RawObjMap(engine().gc);
		vtableMap = new RawObjMap(engine().gc);
		exceptions = new (this) Array<Exception *>();
		functionsToUpdate = new (this) Array<Function *>();
		transforms = new (this) Map<Type *, TypeTransform *>();
	}

	// Like the constructor above, but collects errors reported through 'error' into the
	// caller-supplied array 'exceptions'. If no array is supplied (null), a fresh one is created so
	// that 'error' and 'throwErrors' always have somewhere to store errors.
	ReplaceTasks::ReplaceTasks(PauseThreads &src, Array<Exception *> *exceptions)
		: activeFunctions(src.activeFunctions()), exceptions(exceptions) {
		// Keep the set alive while we refer to it; released in the destructor.
		if (activeFunctions)
			activeFunctions->addRef();
		replaceMap = new RawObjMap(engine().gc);
		vtableMap = new RawObjMap(engine().gc);
		functionsToUpdate = new (this) Array<Function *>();
		transforms = new (this) Map<Type *, TypeTransform *>();

		// 'error' and 'throwErrors' dereference the member unconditionally.
		if (!this->exceptions)
			this->exceptions = new (this) Array<Exception *>();
	}

	// Drops the reference to the active-function set (when there is one) and frees the two raw
	// lookup maps that the constructors allocated with plain 'new'.
	ReplaceTasks::~ReplaceTasks() {
		if (activeFunctions) {
			activeFunctions->release();
		}

		delete replaceMap;
		delete vtableMap;
	}

	// Request that all references to the entity 'old' be redirected to 'with' when 'apply' runs.
	void ReplaceTasks::replace(Named *old, Named *with) {
		replaceMap->put(old, with);
	}

	// Same as above, for type handles. The map stores untyped, non-const pointers.
	void ReplaceTasks::replace(const Handle *old, const Handle *with) {
		void *from = const_cast<void *>(static_cast<const void *>(old));
		void *to = const_cast<void *>(static_cast<const void *>(with));
		replaceMap->put(from, to);
	}

	// Same as above, for GC type descriptions (this also updates object headers during the walk).
	void ReplaceTasks::replace(const GcType *old, const GcType *with) {
		void *from = const_cast<void *>(static_cast<const void *>(old));
		void *to = const_cast<void *>(static_cast<const void *>(with));
		replaceMap->put(from, to);
	}

	// Request that objects using the vtable of 'old' start using the vtable of 'with'. The vtable map
	// is keyed on 'pointer() - vtable::allocOffset()' for both sides, which is the same adjustment
	// the walker applies when looking up an object's vtable.
	void ReplaceTasks::replace(VTable *old, VTable *with) {
		size_t offset = vtable::allocOffset();
		byte *from = (byte *)(void *)old->pointer() - offset;
		byte *to = (byte *)(void *)with->pointer() - offset;
		vtableMap->put(from, to);
	}

	// Registers 'transform', keyed by the type it converts instances from ('oldType()').
	void ReplaceTasks::transform(TypeTransform *transform) {
		Type *from = transform->oldType();
		transforms->put(from, transform);
	}

	// True if a transform whose old type is 'type' has been registered through 'transform'.
	Bool ReplaceTasks::hasTransformFor(Type *type) {
		Bool registered = transforms->has(type);
		return registered;
	}

	// Queues 'newFn'. Its 'replaceActive' is called at the end of 'apply', after all heap
	// references have been updated.
	void ReplaceTasks::replaceActive(Function *newFn) {
		Array<Function *> *pending = functionsToUpdate;
		pending->push(newFn);
	}

	// Heap walker that rewrites references according to two lookup tables:
	// - 'replace' maps old pointers (entities, handles, GcTypes, ...) to their replacements. It is
	//   consulted for exact pointers and for the first header of each allocation.
	// - 'vtables' maps old vtables to new ones. It is keyed on 'vtable pointer - vtable::allocOffset()'.
	// Ambiguous pointers are currently left untouched (see 'ambiguousPointer').
	class ReplaceWalker : public PtrWalker {
	public:
		ReplaceWalker(RawObjMap *replace, RawObjMap *vtables)
			: replace(replace), vtables(vtables), foundHeader(false) {
			flags = fObjects | fExactRoots | fClearWatch;
		}

		// Old -> new for ordinary pointers and headers.
		RawObjMap *replace;

		// Old -> new vtables, keyed as described above.
		RawObjMap *vtables;

		// Whether the first header of the allocation currently being scanned has been seen. Reset
		// before each allocation is scanned.
		bool foundHeader;

		virtual void prepare() {
			// Now, objects don't move anymore and we can sort the array for good lookup performance!
			replace->sort();
			vtables->sort();
		}

		virtual bool checkRoot(GcRoot *root) {
			// Don't modify the maps we're working with now!
			return !replace->hasRoot(root)
				&& !vtables->hasRoot(root);
		}

		virtual void object(RootObject *obj) {
			// Check the vtable.
			void *vt = (void *)vtable::from(obj);
			size_t offset = vtable::allocOffset();
			if (void *r = vtables->find((byte *)vt - offset))
				vtable::set((byte *)r + offset, obj);

			foundHeader = false;

			// Check all pointers.
			PtrWalker::object(obj);
		}

		virtual void fixed(void *obj) {
			foundHeader = false;
			PtrWalker::fixed(obj);
		}

		virtual void array(void *obj) {
			foundHeader = false;
			// Delegate to the array scanner (not 'fixed'), so that the allocation is scanned using
			// the array routine of the base walker.
			PtrWalker::array(obj);
		}

		virtual void header(GcType **ptr) {
			// Only scan the first header of each object (i.e. don't scan the 'myGcType' inside Type objects).
			if (foundHeader)
				return;
			foundHeader = true;

			if (void *r = replace->find(*ptr))
				*ptr = (GcType *)r;
		}

		virtual void exactPointer(void **ptr) {
			if (void *r = replace->find(*ptr))
				*ptr = r;
		}

		virtual void ambiguousPointer(void **ptr) {
			// TODO. We need more information on the objects to replace in this case.
		}
	};

	// Extension of the class above to also handle replacing objects on the heap.
	//
	// In addition to what ReplaceWalker does, this walker:
	// - redirects exact pointers to old instances listed in 'objects' to their transformed instances,
	// - translates ambiguous pointers that point into such an old instance to the corresponding
	//   offset inside the new instance, using the per-type summary from each TypeTransform,
	// - invalidates every PinnedSet it encounters, since pinned objects may change.
	class ReplaceTfmWalker : public ReplaceWalker {
	public:
		// Note: the mem-initializers are listed in declaration order. 'tfmMap' inside them refers to
		// the constructor parameter.
		ReplaceTfmWalker(RawObjMap *replace, RawObjMap *vtables, Map<Type *, TypeTransform *> *tfmMap)
			: ReplaceWalker(replace, vtables), tfmMap(tfmMap), objects(tfmMap->engine().gc), pinnedSetType(null) {

			flags |= fAmbiguousRoots;
		}

		// Transforms, keyed by Type.
		Map<Type *, TypeTransform *> *tfmMap;

		// Objects to replace: old instance -> transformed instance. Filled in by
		// 'ReplaceTasks::applyTransforms' before the walk starts.
		RawObjMap objects;

		virtual void prepare() {
			ReplaceWalker::prepare();
			objects.sort();

			// Cache the layout summary of each transformed type for 'ambiguousPointer'.
			typeSummary.clear();
			for (Map<Type *, TypeTransform *>::Iter i = tfmMap->begin(), end = tfmMap->end(); i != end; ++i) {
				typeSummary.insert(std::make_pair(i.k(), i.v()->summary()));
			}

			pinnedSetType = StormInfo<PinnedSet>::type(tfmMap->engine());
		}

		virtual void object(RootObject *inspect) {
			ReplaceWalker::object(inspect);

			// See if 'inspect' is a PinnedSet. We need to invalidate them since we may change pinned objects.
			const GcType *type = storm::Gc::typeOf(inspect);
			if (type->type == pinnedSetType) {
				((PinnedSet *)inspect)->invalidate();
			}
		}

		virtual bool checkRoot(GcRoot *root) {
			return ReplaceWalker::checkRoot(root)
				&& !objects.hasRoot(root);
		}

		virtual void exactPointer(void **ptr) {
			ReplaceWalker::exactPointer(ptr);

			if (void *to = objects.find(*ptr)) {
				*ptr = to;
			}
		}

		virtual void ambiguousPointer(void **ptr) {
			ReplaceWalker::ambiguousPointer(ptr);

			// Note that these pointers may point *into* elements!
			RawObjMap::Item found = objects.findBefore(*ptr);
			if (!found.from)
				return;

			// Check the size of the type, so we can determine if the pointer refers to the
			// allocation or not.
			// Note: We assume that code allocations are not added to the 'objects' array.
			// NOTE(review): this reads the header of 'found.from' mid-walk. If that header was
			// already rewritten by 'header()' (i.e. its GcType is also in 'replace'), the size and
			// type below are those of the replacement. Confirm that cannot happen for transformed types.
			const GcType *type = storm::Gc::typeOf(found.from);
			size_t fromSize = type->stride;
			if (type->kind == GcType::tArray || type->kind == GcType::tWeakArray)
				fromSize = sizeof(size_t) * 2 + fromSize * ((GcArray<void *> *)found.from)->count;

			size_t fromOffset = (byte *)*ptr - (byte *)found.from;
			if (fromOffset >= fromSize)
				return;

			// Now that we are sure that it points into the object, we can actually look up its type
			// and see how to replace it!
			SummaryMap::const_iterator ts = typeSummary.find(type->type);
			if (ts == typeSummary.end())
				return;

			const TypeTransform::Summary &summary = ts->second;
			// Note: We already checked the size, this is just an extra defense.
			if (fromOffset >= summary.size)
				return;

			size_t toOffset = summary.translate(fromOffset);
			*ptr = (byte *)found.to + toOffset;
		}

	private:
		// Data for replacing ambiguous pointers:

		// Type of the PinnedSet type, so that we can invalidate them. Set in 'prepare'.
		Type *pinnedSetType;

		// Map from object type to transformation information. Note: populated when the GC is
		// paused, so it is fine to use pointer based hashes in this way here.
		typedef hash_map<Type *, TypeTransform::Summary> SummaryMap;
		SummaryMap typeSummary;
	};

	void ReplaceTasks::apply() {
		Engine &e = engine();

		// See if we have any objects to transforms:
		if (transforms->any()) {
			ReplaceTfmWalker walker(replaceMap, vtableMap, transforms);
			applyTransforms(walker);
			e.gc.walk(walker);
		} else if (replaceMap->any() || vtableMap->any()) {
			// Only replace pointers, etc.
			ReplaceWalker walker(replaceMap, vtableMap);
			e.gc.walk(walker);
		}

		// Update active functions afterwards. This may result in compilation of code, and that
		// needs to happen after we have updated all references to replaced types!
		for (Nat i = 0; i < functionsToUpdate->count(); i++)
			functionsToUpdate->at(i)->replaceActive(this);
	}

	// Heap walker that collects every object whose type has a registered transform in 'src'.
	// Matches are gathered in 'tempFound' during the walk and copied to the rooted array 'found' in
	// 'finalize'.
	class FindTransformsWalker : public Walker {
	public:
		// Note: the mem-initializers are listed in declaration order.
		FindTransformsWalker(Map<Type *, TypeTransform *> *transforms)
			: found(transforms->engine().gc), src(transforms) {

			flags = fObjects;
		}

		// Found objects during the walk.
		RootArray<RootObject> found;

		virtual void prepare() {
			// Start from a clean slate, in case the walker is prepared more than once.
			types.clear();
			tempFound.clear();

			for (Map<Type *, TypeTransform *>::Iter i = src->begin(), end = src->end(); i != end; ++i) {
				types.push_back(i.k());
			}

			// Sorted, so that 'object' can use binary search.
			std::sort(types.begin(), types.end());
		}

		virtual void finalize() {
			// Copy 'tempFound' to 'found' so that we retain the objects properly.
			// Note: This probably does *not* work properly with SMM, due to lock handling there.
			found.resize(Nat(tempFound.size()));
			for (nat i = 0; i < nat(tempFound.size()); i++)
				found[i] = tempFound[i];
		}

		virtual void object(RootObject *inspect) {
			const GcType *type = storm::Gc::typeOf(inspect);
			if (std::binary_search(types.begin(), types.end(), type->type))
				tempFound.push_back(inspect);
		}

	private:
		Map<Type *, TypeTransform *> *src;

		// Quick lookup during the heap walk.
		vector<Type *> types;

		// Temporary storage of all objects we need to update. We can not touch 'found' during the
		// walk itself.
		vector<RootObject *> tempFound;
	};

	// Locates every object whose type has a registered transform and creates its transformed
	// counterpart using the matching TypeTransform. Each old -> new pair is recorded in
	// 'out.objects', so that a subsequent walk with 'out' redirects references to the new instances.
	void ReplaceTasks::applyTransforms(ReplaceTfmWalker &out) {
		if (!transforms->any())
			return;

		FindTransformsWalker finder(transforms);
		engine().gc.walk(finder);

		for (nat i = 0; i < finder.found.count(); i++) {
			RootObject *original = finder.found[i];

			// Should always be found, since 'finder' only collects objects of transformed types.
			// If it is missing anyway, the object is skipped.
			TypeTransform *tfm = transforms->get(runtime::typeOf(original), null);
			if (tfm)
				out.objects.put(original, tfm->apply(original));
		}
	}

	// Records 'ex' so that 'throwErrors' can report it later.
	void ReplaceTasks::error(Exception *ex) {
		Array<Exception *> *errors = exceptions;
		errors->push(ex);
	}

	// Throws a MultiException wrapping all errors recorded through 'error'. Does nothing if none
	// were recorded.
	void ReplaceTasks::throwErrors() {
		if (!exceptions->any())
			return;

		throw new (this) MultiException(exceptions);
	}

	// Looks up 'function' in the active-function set. Returns an empty result when these tasks
	// were created without paused threads.
	vector<ActiveOffset> ReplaceTasks::findActive(const void *function) const {
		if (!activeFunctions)
			return vector<ActiveOffset>();

		return activeFunctions->find(function);
	}

	// Forwards to the active-function set's 'replace'. Returns 0 when these tasks were created
	// without paused threads.
	size_t ReplaceTasks::replaceActive(const void *function, size_t fOff, const void *replace, size_t rOff) const {
		if (!activeFunctions)
			return 0;

		return activeFunctions->replace(function, fOff, replace, rOff);
	}

}
