## Main code of our paper: the algorithm registry, the ERM baseline, and LFME.
import torch
import torch.nn as nn
import torch.nn.functional as F

from domainbed import networks


# Names of the algorithm classes defined in this module.
ALGORITHMS = [
    "ERM",
    "LFME",
]

def get_algorithm_class(algorithm_name):
    """Return the algorithm class with the given name.

    Only module-level names that are subclasses of ``Algorithm`` are
    returned; other globals (imported modules, helper functions,
    constants) are rejected rather than handed back to the caller.

    Raises:
        NotImplementedError: if ``algorithm_name`` does not name an
            ``Algorithm`` subclass in this module.
    """
    candidate = globals().get(algorithm_name)
    # isinstance() short-circuits before issubclass() so non-class globals
    # (modules, functions, constants) are rejected without a TypeError.
    if not (isinstance(candidate, type) and issubclass(candidate, Algorithm)):
        raise NotImplementedError("Algorithm not found: {}".format(algorithm_name))
    return candidate


class Algorithm(torch.nn.Module):
    """
    Base class for a domain generalization algorithm.

    Concrete algorithms are expected to override:
    - update()
    - predict()
    """

    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super().__init__()
        # Only the hyperparameters are kept here; the other constructor
        # arguments exist for subclasses to use.
        self.hparams = hparams

    def update(self, minibatches, unlabeled=None):
        """
        Perform a single training step.

        Args:
            minibatches: list of (x, y) tuples, one per environment.
            unlabeled: optional list of unlabeled minibatches from the test
                domains, supplied when the task is domain_adaptation.
        """
        raise NotImplementedError

    def predict(self, x):
        """Subclasses return the model output for the inputs ``x``."""
        raise NotImplementedError

class ERM(Algorithm):
    """
    Empirical Risk Minimization (ERM) baseline.

    Pools the minibatches of all training environments and minimizes a
    single cross-entropy loss with Adam.
    """

    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super().__init__(input_shape, num_classes, num_domains, hparams)
        hp = self.hparams
        self.featurizer = networks.Featurizer(input_shape, hp)
        self.classifier = networks.Classifier(
            self.featurizer.n_outputs, num_classes, hp['nonlinear_classifier'])
        # The combined module is what the optimizer trains.
        self.network = nn.Sequential(self.featurizer, self.classifier)
        self.optimizer = torch.optim.Adam(
            self.network.parameters(),
            lr=hp["lr"],
            weight_decay=hp['weight_decay'],
        )

    def update(self, minibatches, unlabeled=None):
        """Take one Adam step on the pooled cross-entropy; return {'loss': float}."""
        inputs = torch.cat([batch_x for batch_x, _ in minibatches])
        labels = torch.cat([batch_y for _, batch_y in minibatches])
        objective = F.cross_entropy(self.predict(inputs), labels)

        opt = self.optimizer
        opt.zero_grad()
        objective.backward()
        opt.step()
        return {'loss': objective.item()}

    def predict(self, x):
        """Return classifier logits for the featurized inputs ``x``."""
        return self.classifier(self.featurizer(x))


class LFME(Algorithm):
    """
    Learning from Experts for Domain Generalization (LFME).

    Trains one expert network per training domain (expert ``i`` sees only
    the ``i``-th minibatch) plus a target network, stored as the last entry
    of each list.  The target network is trained on the pooled minibatches
    with cross-entropy plus an MSE term, weighted by ``hparams['lfe_reg']``,
    that pulls its logits toward the experts' softmax outputs.  Only the
    target network is used by ``predict``.
    """

    def __init__(self, input_shape, num_classes, num_domains, hparams):
        super(LFME, self).__init__(input_shape, num_classes, num_domains, hparams)
        self.MSEloss = nn.MSELoss()
        # One expert per training domain, plus the target network (last index).
        self.expert_number = num_domains + 1
        self.num_classes = num_classes
        # Fall back to CPU so the algorithm can be built without a GPU; on
        # CUDA machines this matches the previous hard-coded 'cuda'.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # nn.ModuleList (instead of plain Python lists) registers every
        # sub-network with this module, so .to(), .train()/.eval(),
        # .parameters() and .state_dict() on the algorithm reach them.
        self.featurizer = nn.ModuleList()
        self.classifier = nn.ModuleList()
        self.network = nn.ModuleList()
        self.optimizer = []
        for _ in range(self.expert_number):
            featurizer = networks.Featurizer(input_shape, self.hparams).to(device)
            classifier = networks.Classifier(
                featurizer.n_outputs,
                num_classes,
                self.hparams['nonlinear_classifier']).to(device)
            network = nn.Sequential(featurizer, classifier)
            self.featurizer.append(featurizer)
            self.classifier.append(classifier)
            self.network.append(network)
            self.optimizer.append(torch.optim.Adam(
                network.parameters(),
                lr=self.hparams["lr"],
                weight_decay=self.hparams['weight_decay'],
            ))

    def update(self, minibatches, unlabeled=None):
        """
        Run one training step for every expert and then for the target network.

        Args:
            minibatches: sequence of ``(x, y)`` pairs, one per training
                domain; entry ``i`` trains expert ``i``.
            unlabeled: unused; accepted for interface compatibility.

        Returns:
            dict with the target network's total loss under ``'loss'``.
        """
        all_x = torch.cat([x for x, y in minibatches])
        all_y = torch.cat([y for x, y in minibatches])
        # Expert soft labels, aligned row-for-row with all_x and allocated on
        # the input's device rather than a hard-coded 'cuda'.
        expert = torch.zeros(all_y.shape[0], self.num_classes, device=all_x.device)
        start = 0
        for i in range(self.expert_number - 1):
            part_x, part_y = minibatches[i][0], minibatches[i][1]
            # Update expert i on its own domain only.
            result_expert = self.network[i](part_x)
            loss = F.cross_entropy(result_expert, part_y)
            self.optimizer[i].zero_grad()
            loss.backward()
            self.optimizer[i].step()
            # A running offset keeps rows aligned with the torch.cat above
            # even when the domain minibatches differ in size.
            end = start + part_y.shape[0]
            # Expert outputs are fixed targets, so no autograd graph is needed.
            expert[start:end, :] = F.softmax(result_expert.detach(), dim=1)
            start = end

        result_target = self.network[-1](all_x)
        loss_cla = F.cross_entropy(result_target, all_y)
        loss_guid = self.MSEloss(result_target, expert.detach())
        loss = loss_cla + loss_guid * self.hparams['lfe_reg']
        self.optimizer[-1].zero_grad()
        loss.backward()
        self.optimizer[-1].step()

        return {'loss': loss.item()}

    def predict(self, x):
        """Return the target network's (last entry's) output for ``x``."""
        result = self.network[-1](x)
        return result
