import omero
import omero.scripts as scripts
import omero.constants
import omero.util.script_utils as scriptUtil
from omero.rtypes import *
import omero.gateway
import omero_api_Gateway_ice
import omero.util.imageUtil as imgUtil
import subprocess
import numpy
from numpy import *
from scipy.ndimage import gaussian_filter

#-------------------------------------------------
# Module-level handles to OMERO services, kept global so they are
# created once (in the __main__ block) and shared by the helpers below
#-------------------------------------------------
gateway = session = rawPixelStore = pixelsService = renderingEngine = None
COLOURS = scriptUtil.COLOURS

		
#-------------------------------------------------
# logging helper: echo a message and keep a copy of it
#-------------------------------------------------
logLines = []
def log(text):
	""" Prints text to stdout and records it in the module-level logLines list"""
	print(text)
	logLines.append(text)
    
def _copyPhysicalSizes(srcPixels, dstPixels):
	""" Copies the X/Y/Z physical pixel sizes of srcPixels onto dstPixels.

	An axis whose size is unset (None) on the source is left unset on the
	target instead of failing on None.getValue().
	"""
	accessors = (
		(srcPixels.getPhysicalSizeX, dstPixels.setPhysicalSizeX),
		(srcPixels.getPhysicalSizeY, dstPixels.setPhysicalSizeY),
		(srcPixels.getPhysicalSizeZ, dstPixels.setPhysicalSizeZ),
	)
	for getSize, setSize in accessors:
		size = getSize()
		if size is not None:
			setSize(rdouble(size.getValue()))

def _blurStack(srcPixels, dstPixelsId, stack, c, t, sigm):
	""" Blurs one (channel, timepoint) Z-stack and writes it to another pixel set.

	@param srcPixels	pixel description of the source image
	@param dstPixelsId	id of the pixel set that receives the blurred planes
	@param stack		scratch array of shape (sizeZ, sizeY, sizeX) used as download buffer
	@param c, t			channel and timepoint indices of the stack
	@param sigm			sigma for scipy.ndimage.gaussian_filter
	"""
	bypassOriginalFile = True
	zdim = stack.shape[0]

	# Point the remote RawPixelsStore at the source once for the whole stack
	# rather than once per plane.
	rawPixelStore.setPixelsId(srcPixels.getId().getValue(), bypassOriginalFile)
	for z in range(zdim):
		stack[z, :, :] = scriptUtil.downloadPlane(rawPixelStore, srcPixels, z, c, t)

	# 3D Gaussian; the result is a new array with the same dtype as the input
	blurred = gaussian_filter(stack, sigm)

	rawPixelStore.setPixelsId(dstPixelsId, bypassOriginalFile)
	for z in range(zdim):
		scriptUtil.uploadPlane(rawPixelStore, blurred[z, :, :], z, c, t)

def doSomething(images, ds, sigm):
	""" Applies a 3D Gaussian blur to each image and saves it as a new image.

	For every image a new image with the same dimensions, channel count and
	pixel type is created. Every (channel, timepoint) Z-stack is blurred with
	scipy.ndimage.gaussian_filter and uploaded into it, and each channel's
	rendering window is reset to the full range of the pixel type.

	Only uint8 and uint16 images are supported. Images of any other pixel
	type are logged and skipped, so the remaining images are still processed.

	@param images	iterable of omero.model.Image; None entries are skipped
	@param ds		omero.model.Dataset to link each new image into, or None
	@param sigm		Gaussian sigma, applied equally along Z, Y and X (in pixels)
	"""
	# pixel type name -> (maximum pixel value, numpy dtype)
	supportedTypes = {
		"uint8": (255, numpy.uint8),
		"uint16": (65535, numpy.uint16),
	}

	log("---")
	for image in images:
		if image is None:
			log("Skipping missing image")
			continue

		#-------------------------------------------------
		# source pixel description and dimensions
		#-------------------------------------------------
		pixelId = image.getPrimaryPixels().getId().getValue()
		pixels = pixelsService.retrievePixDescription(pixelId)

		xdim = pixels.getSizeX().getValue()
		ydim = pixels.getSizeY().getValue()
		zdim = pixels.getSizeZ().getValue()
		cdim = pixels.getSizeC().getValue()
		tdim = pixels.getSizeT().getValue()
		log("image dimensions: x:%d y:%d z:%d c:%d t:%d" % (xdim, ydim, zdim, cdim, tdim))

		pixtype = pixels.getPixelsType().getValue().getValue()
		if pixtype not in supportedTypes:
			# skip this image only; do not abort the rest of the batch
			log("Pixel type is unsupported! (%s, image ID %d) - skipping" % (pixtype, image.getId().getValue()))
			continue
		maxpix, numpydtype = supportedTypes[pixtype]

		#-------------------------------------------------
		# create a new image to store the results
		#-------------------------------------------------
		newimage_name = image.getName().getValue()
		newimage_desc = "3D Gaussian blur with Sigma: %f\n on parent Image:\n  Name: %s\n  Image ID: %d" % (sigm, newimage_name, image.getId().getValue())
		newimage_iId = pixelsService.createImage(xdim, ydim, zdim, tdim, list(range(cdim)), pixels.getPixelsType(), newimage_name, newimage_desc)
		newimage = gateway.getImage(newimage_iId.getValue())
		newimage_pixelsId = newimage.getPrimaryPixels().getId().getValue()
		newimage_pixels = pixelsService.retrievePixDescription(newimage_pixelsId)

		_copyPhysicalSizes(pixels, newimage_pixels)
		gateway.saveObject(newimage_pixels)

		#-------------------------------------------------
		# blur every (c, t) stack; reuse one download buffer
		#-------------------------------------------------
		stack = numpy.ndarray(shape=(zdim, ydim, xdim), dtype=numpydtype, order='F')
		for c in range(cdim):
			for t in range(tdim):
				_blurStack(pixels, newimage_pixelsId, stack, c, t, sigm)
			scriptUtil.resetRenderingSettings(renderingEngine, newimage_pixelsId, c, 0, maxpix, COLOURS["White"])

		#-------------------------------------------------
		# put the image in the dataset, if specified
		#-------------------------------------------------
		if ds:
			link = omero.model.DatasetImageLinkI()
			link.parent = omero.model.DatasetI(ds.getId().getValue(), False)
			link.child = omero.model.ImageI(newimage.getId().getValue(), False)
			gateway.saveAndReturnObject(link)

	log("---")

def _parseIds(rawIds, label):
	""" Converts a list of rlong-like wrappers into a list of Python longs.

	Entries that cannot be converted are logged (using label, e.g. "image")
	and skipped rather than silently dropped.
	"""
	ids = []
	for rawId in rawIds:
		try:
			ids.append(long(rawId.getValue()))
		except (AttributeError, TypeError, ValueError):
			log("Ignoring invalid %s ID: %r" % (label, rawId))
	return ids

def parseCommandArguments(commandArgs):
	""" Parses the script inputs and blurs every requested image.

	commandArgs is the dict of unwrapped client inputs built in __main__:
	  "Sigma"     - sigma for the 3D Gaussian
	  "Data_Type" - "Image", "Dataset" or "Project" (anything else is treated as Project)
	  "IDs"       - list of rlong IDs of that type
	Blurred copies of images from a Dataset are linked back into that Dataset;
	for Projects and Images the new images are not linked to any dataset.
	"""
	log("---")

	#-------------------------------------------------
	# sigma value for our gaussian blur
	#-------------------------------------------------
	sigma = commandArgs["Sigma"]
	log("sigma: %f" % sigma)
	log("---")

	#-------------------------------------------------
	# parse IDs for projects, datasets, images
	#-------------------------------------------------
	dataType = commandArgs["Data_Type"]
	log("dataType: %s" % dataType)
	log("---")
	rawIds = commandArgs.get("IDs", [])
	imageIds = []
	datasetIds = []
	projectIds = []
	if dataType == "Image":
		imageIds = _parseIds(rawIds, "image")
	elif dataType == "Dataset":
		datasetIds = _parseIds(rawIds, "dataset")
	else:
		projectIds = _parseIds(rawIds, "project")
	log("---")

	log("Found: %d projects, %d datasets, %d images" % (len(projectIds), len(datasetIds), len(imageIds)))
	if not (imageIds or datasetIds or projectIds):
		# nothing to process
		log("No image IDs, dataset IDs or project IDs found.")
		return

	#-------------------------------------------------
	# PROJECTS: blur every image in each project
	#-------------------------------------------------
	if projectIds:
		projects = [p for p in gateway.getProjects(projectIds, False) if p is not None]
		foundIds = set(p.getId().getValue() for p in projects)
		for projectId in projectIds:
			if projectId not in foundIds:
				log("No project found for ID: %s" % projectId)
		for project in projects:
			projectId = project.getId().getValue()
			images = gateway.getImages(omero.api.ContainerClass.Project, [projectId])
			log("Project: %s     ID: %d     Images: %d" % (project.getName().getValue(), projectId, len(images)))
			doSomething(images, None, sigma)

	#-------------------------------------------------
	# DATASETS: blur every image and link results back into the dataset
	#-------------------------------------------------
	for datasetId in datasetIds:
		dataset = gateway.getDataset(datasetId, False)
		if dataset is None:
			log("No dataset found for ID: %s" % datasetId)
			continue
		images = gateway.getImages(omero.api.ContainerClass.Dataset, [datasetId])
		log("Dataset: %s     ID: %d     Images: %d" % (dataset.getName().getValue(), datasetId, len(images)))
		doSomething(images, dataset, sigma)

	#-------------------------------------------------
	# IMAGES: blur the individually selected images
	#-------------------------------------------------
	if imageIds:
		images = []
		for imageId in imageIds:
			image = gateway.getImage(imageId)
			if image is None:
				log("No image found for ID: %s" % imageId)
				continue
			images.append(image)
		doSomething(images, None, sigma)
	
if __name__ == '__main__':

	#-------------------------------------------------
	# Initialize our script parameters
	#-------------------------------------------------
	dataTypes = [rstring('Project'), rstring('Dataset'), rstring('Image')]

	client = scripts.client('3D_Gaussian_Blur',
		"""Applies a 3D Gaussian blur to the selected images, or to all images in the selected datasets or projects, saving each result as a new image""",

		scripts.String("Data_Type", optional=False, grouping="1",
			description="The data you want to work with.", values=dataTypes, default="Image"),

		scripts.List("IDs", optional=False, grouping="2",
			description="List of Project IDs, Dataset IDs or Image IDs").ofType(rlong(0)),

		# Sigma gets its own grouping rather than sharing "2" with IDs
		scripts.Float("Sigma", optional=False, grouping="3",
			description="The sigma value for the 3D Gaussian", default=0.5),

		version = "4.2.1",
		authors = ["Jerome Avondo", "JIC"],
		institutions = ["John Innes Centre"],
		contact = "jerome.avondo@bbsrc.ac.uk",
	)

	#-------------------------------------------------
	# Run our script
	#-------------------------------------------------
	try:
		# these assignments populate the module-level service globals
		# used by doSomething / parseCommandArguments
		session = client.getSession()
		gateway = session.createGateway()
		rawPixelStore = session.createRawPixelsStore()
		pixelsService = session.getPixelsService()
		renderingEngine = session.createRenderingEngine()

		# unwrap the rtype inputs into plain Python values
		commandArgs = {}
		for key in client.getInputKeys():
			value = client.getInput(key)
			if value:
				commandArgs[key] = value.getValue()

		log("---")
		log("commands: %s" % commandArgs)

		parseCommandArguments(commandArgs)

		client.setOutput("Message", rstring("Script was successful"))

	finally:
		client.closeSession()
