"""
Deleting Headshape Points
=========================

One cause of bad co-registrations is the presence of misleading or erroneous headshape points. These can result from errors made while the experimenter was recording the headshape points with the Polhemus system.

In this notebook, we provide a custom function that can be used to delete the polhemus-derived headshape points.

Let's first define the function.

"""


import os
import os.path as op
import sys

import matplotlib.pyplot as plt
import numpy as np

from osl_ephys import utils, source_recon
from osl_ephys.source_recon.rhino.coreg import get_coreg_filenames

# Point osl_ephys at the FSL installation. Use the FSLDIR environment
# variable when it is set, so the tutorial runs on other machines; otherwise
# fall back to the original hard-coded path.
fsl_dir = os.environ.get('FSLDIR', '/Users/matsvanes/fsl')
source_recon.setup_fsl(fsl_dir)

def delete_headshape_points(recon_dir=None, subject=None, polhemus_headshape_file=None):
    """Interactively delete polhemus-derived headshape points.

    Shows an interactive 3D scatter plot of the polhemus-derived headshape
    points in polhemus space. Clicking on a point deletes it from the plot.
    Pressing "w" while the figure is focused writes the remaining points
    back to the headshape file. Closing the figure without pressing "w"
    leaves the file unchanged.

    Parameters
    ----------
    recon_dir : str, optional
        Directory containing the subject directories, in the directory
        structure used by RHINO. Must be given together with ``subject``.
    subject : str, optional
        Subject directory name, in the directory structure used by RHINO.
        Must be given together with ``recon_dir``.
    polhemus_headshape_file : str, optional
        Full path of the text file (read with ``np.loadtxt``) holding the
        (3 x num_headshapepoints) array of headshape points. Changes are
        written back to this file. Ignored when both ``recon_dir`` and
        ``subject`` are given.

    Raises
    ------
    ValueError
        If neither ``recon_dir`` and ``subject`` nor
        ``polhemus_headshape_file`` are specified.

    Notes
    -----
    We can call this in two different ways, either:

    1) Specify the recon_dir AND the subject directory in the
    directory structure used by RHINO:

    delete_headshape_points(recon_dir=recon_dir, subject=subject)

    or:

    2) Specify the full path to the text file containing the
    (3 x num_headshapepoints) array of headshape points:

    delete_headshape_points(polhemus_headshape_file=polhemus_headshape_file)
    """
    if recon_dir is not None and subject is not None:
        coreg_filenames = get_coreg_filenames(recon_dir, subject)
        polhemus_headshape_file = coreg_filenames["polhemus_headshape_file"]
    elif polhemus_headshape_file is None:
        raise ValueError('Invalid inputs. See function\'s documentation.')

    # ndmin=2 keeps the (3 x N) layout even when the file holds a single point.
    polhemus_headshape_polhemus = np.loadtxt(polhemus_headshape_file, ndmin=2)

    print("Num headshape points={}".format(polhemus_headshape_polhemus.shape[1]))
    print('Click on points to delete them.')
    print('Press "w" to write changes to the file')
    sys.stdout.flush()

    def scatter_headshapes(ax, x, y, z):
        """Draw the headshape points as pickable red markers and redraw."""
        ax.scatter(
            x, y, z,
            color="red",
            marker="o",
            s=8,
            alpha=0.7,
            picker=5,  # pick tolerance, in points
        )
        plt.draw()

    # Plain lists so that individual points can be removed with list.pop().
    x = list(polhemus_headshape_polhemus[0, :])
    y = list(polhemus_headshape_polhemus[1, :])
    z = list(polhemus_headshape_polhemus[2, :])

    fig = plt.figure()
    ax = fig.add_subplot(projection="3d")
    scatter_headshapes(ax, x, y, z)

    def on_click(event):
        """Delete the first picked point and redraw the remaining ones."""
        i = event.ind[0]
        print('Deleted: {}, {}, {}'.format(x[i], y[i], z[i]))
        sys.stdout.flush()

        x.pop(i)
        y.pop(i)
        z.pop(i)
        # Redraw from scratch so the plotted point indices match the list
        # indices again for the next pick event.
        ax.cla()
        scatter_headshapes(ax, x, y, z)

    def on_press(event):
        """Write the remaining points to the headshape file on the "w" key."""
        if event.key != 'w':
            return
        polhemus_headshape_polhemus_new = np.array([x, y, z])
        print("Num headshape points remaining={}".format(polhemus_headshape_polhemus_new.shape[1]))
        # Save to the resolved path, which is defined in both calling modes.
        np.savetxt(polhemus_headshape_file, polhemus_headshape_polhemus_new)
        print('Changes saved to file {}'.format(polhemus_headshape_file))
        sys.stdout.flush()

    fig.canvas.mpl_connect('pick_event', on_click)
    fig.canvas.mpl_connect('key_press_event', on_press)

    plt.show()

#%%
# Typically, you would use this function after running ``source_recon.rhino.coreg`` (either directly, or via the batch API),
# and after diagnosing a bad coreg (again, either directly using 
# ``source_recon.rhino.coreg_display`` , or via the html report generated by using the batch API).
# 
# To put ourselves in this situation we will first download the appropriate data and copy the headshape points to the appropriate paths in the assumed RHINO directory structure:
#
# 
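# Download files. This requires ``osfclient``, which provides the ``osf`` command used by ``get_data`` below.
# Install it with ``pip install osfclient`` (or ``!pip install osfclient`` from a notebook).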

import os
import os.path as op
from osl_ephys import utils

def get_data(name):
    """Download and unpack a dataset from OSF into the current directory.

    Parameters
    ----------
    name : str
        Dataset name. The archive ``SourceRecon/data/<name>.zip`` is fetched
        from OSF project ``zxb6c`` with the ``osf`` command-line tool and
        extracted into the current working directory. The download is
        skipped if a path called ``name`` already exists; this presumably
        assumes the archive unpacks to a directory of that name.

    Returns
    -------
    str
        A human-readable status message.

    Raises
    ------
    subprocess.CalledProcessError
        If the ``osf`` download command exits with a non-zero status.
    """
    import subprocess
    import zipfile

    print('Data will be in directory {}'.format(os.getcwd()))
    if os.path.exists(name):
        return f"{name} already downloaded. Skipping.."

    archive = f"{name}.zip"
    # Argument list (no shell) plus check=True, so a failed download raises
    # instead of being silently ignored and leaving no archive behind.
    subprocess.run(
        ["osf", "-p", "zxb6c", "fetch", f"SourceRecon/data/{name}.zip"],
        check=True,
    )
    # Stdlib equivalent of ``unzip -o``: extract, overwriting existing files,
    # without depending on an external unzip binary.
    with zipfile.ZipFile(archive) as zf:
        zf.extractall()
    os.remove(archive)
    return f"Data downloaded to: {name}"

# Download the dataset; get_data returns a status message, so show it.
print(get_data("notts_2subjects"))

## Setup file names
data_dir = './notts_2subjects'
recon_dir = './notts_2subjects/recon'

# '{subject}' is a wildcard for utils.Study; the matched values are
# presumably what fields['subject'] returns below (confirm in the osl_ephys docs).
subject = '{subject}'
fif_files_path = op.join(data_dir, subject, subject + '_task-resteyesopen_meg_preproc_raw.fif')
study = utils.Study(fif_files_path)
subjects = study.fields['subject']
fif_files = study.get()

## Copy polhemus files
import numpy as np

def copy_polhemus_files(recon_dir, subject, preproc_file, smri_file, logger):
    """Copy a subject's polhemus points into the RHINO coreg directory.

    Reads the headshape, nasion, RPA and LPA text files from
    ``<data_dir>/<subject>/polhemus/`` (``data_dir`` is the module-level
    global) and writes them to the paths returned by
    ``source_recon.rhino.get_coreg_filenames(recon_dir, subject)``.

    ``preproc_file``, ``smri_file`` and ``logger`` are not used. They are
    presumably there so the function matches the signature that the
    osl_ephys batch API expects for custom functions (TODO: confirm).
    """
    subject_dir = op.join(data_dir, subject)

    # Load all four files before writing anything.
    sources = (
        'polhemus/polhemus_headshape.txt',
        'polhemus/polhemus_nasion.txt',
        'polhemus/polhemus_rpa.txt',
        'polhemus/polhemus_lpa.txt',
    )
    points = {}
    for rel_path in sources:
        points[rel_path] = np.loadtxt(op.join(subject_dir, rel_path))

    coreg_files = source_recon.rhino.get_coreg_filenames(recon_dir, subject)

    # (RHINO filename key, source file) pairs: fiducials first, then headshape.
    targets = (
        ("polhemus_nasion_file", 'polhemus/polhemus_nasion.txt'),
        ("polhemus_rpa_file", 'polhemus/polhemus_rpa.txt'),
        ("polhemus_lpa_file", 'polhemus/polhemus_lpa.txt'),
        ("polhemus_headshape_file", 'polhemus/polhemus_headshape.txt'),
    )
    for key, rel_path in targets:
        np.savetxt(coreg_files[key], points[rel_path])

# The three trailing arguments are not used by copy_polhemus_files.
copy_polhemus_files(recon_dir, subjects[0], [], [], [])

sub1_polhemus_nasion = op.join(recon_dir, subjects[0], 'rhino/coreg/polhemus_nasion.txt')
print('E.g., the coordinates for the nasion for subject {} in Polhemus space are \n'.format(subjects[0]))
# Read the file in Python rather than shelling out to ``more``. This is
# portable, and the output goes through Python's stdout, so it is captured
# when the script is rendered (e.g. by sphinx-gallery).
with open(sub1_polhemus_nasion) as f:
    print(f.read())

#%%
# We can now call the *delete_headshape_points* function we have defined above. Note that we can call this in two different ways, either:
# 
# 1) Specify the recon_dir AND the subject directory, in the directory structure used by RHINO:
# ``delete_headshape_points(recon_dir=recon_dir, subject=subject)``
# 
# 2) Specify the full path to the text file containing the (3 x num_headshapepoints) array of headshape points:
# ``delete_headshape_points(polhemus_headshape_file=polhemus_headshape_file)``
# 
# Here, we want to use the first option. Let's now call the function we defined above:


# Option 1: locate the headshape file via the RHINO directory structure.
delete_headshape_points(recon_dir=recon_dir, subject=subjects[0])

#%%
# This brings up an interactive figure of the polhemus derived headshape points in polhemus space, as a scatter plot.
# 
# - The plot can be rotated (avoid clicking near an actual headshape point when doing this).
# - Headshape points can be clicked on to delete them.
# - Pressing "w" while the figure is focused saves the changes to the headshape file. Close the figure when you are done; deletions made after the last "w" are not saved.
# 
# Go ahead and delete some headshape points, press "w" to save the file, and then close the figure.
# Since we have worked on the headshape points file inside the RHINO directory structure, the saved file will be the one used by any subsequent coregistrations, e.g., when we call:
# 

# Re-run the coregistration for the first subject using the headshape
# points (including the nose) from the RHINO coreg directory edited above.
source_recon.rhino.coreg(
    fif_files[0], recon_dir, subjects[0], use_headshape=True, use_nose=True
)
