#!/usr/bin/env python

"""
Pegasus utility for removing files during workflow enactment

Usage: pegasus-cleanup [options]

"""

##
#  Copyright 2007-2011 University Of Southern California
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#  http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing,
#  software distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
##

import os
import re
import sys
import errno
import logging
import optparse
import tempfile
import subprocess
import signal
import string
import stat
import time
from collections import deque


__author__ = "Mats Rynge <rynge@isi.edu>"

# --- regular expressions -------------------------------------------------------------

# Matches a comment line of the form "# <word> <token>" and captures <token>
# (word characters and dashes). The main loop uses the captured token as the
# site name for the next URL it reads.
re_parse_comment = re.compile(r'^# +[\w]+ +([\w-]+)')
# Splits a URL into protocol (group 1), host part (group 2 - may be empty and
# may contain '@', ':' and '.' characters) and a path starting with "/"
# (group 3, runs up to the next whitespace).
re_parse_url = re.compile(r'([\w]+)://([\w\.\-:@]*)(/[\S]*)')

# --- classes -------------------------------------------------------------------------

class URL:
    """
    A single removal target: the URL split into protocol, host and path,
    plus the name of the site the URL belongs to.
    """

    # class-level defaults - the set_* methods override them per instance
    sitename   = ""
    proto      = ""
    host       = ""
    path       = ""

    def set_sitename(self, sitename):
        """
        stores the site name, with dashes turned into underscores
        """
        # the site name is used to match against shell variables, so we
        # have to replace dashes with underscores (as we do in the planner)
        self.sitename = sitename.replace("-", "_")

    def set_url(self, url):
        """
        parses the given url string and stores its proto, host and path
        """
        self.proto, self.host, self.path = self.parse_url(url)

    def parse_url(self, url):
        """
        splits a url string into a (proto, host, path) tuple

        file: and symlink: urls come back with proto "file", an empty host
        and a path in which environment variables have been expanded. All
        other urls have to match re_parse_url, otherwise a RuntimeError is
        raised.
        """
        proto = ""
        host  = ""
        path  = ""

        # default protocol is file://
        if ":" not in url:
            logger.debug("URL without protocol (" + url + ") - assuming file://")
            url = "file://" + url

        # translate symlink URLs to file:
        if url.startswith("symlink:"):
            url = re.sub(r"^symlink:", "file:", url)

        # file url is a special case as it can contain relative paths and env vars
        if url.startswith("file:"):
            proto = "file"
            # file urls can either start with file://[\w]*/ or file: (no //)
            path = re.sub(r"^file:(//[\w\.\-:@]*)?", "", url)
            path = expand_env_vars(path)
            return proto, host, path

        # other than file urls
        r = re_parse_url.search(url)
        if not r:
            raise RuntimeError("Unable to parse URL: %s" % (url))

        # Parse successful
        proto = r.group(1)
        host = r.group(2)
        path = r.group(3)

        # no double slashes in urls
        path = re.sub('//+', '/', path)

        return proto, host, path

    def url(self):
        """
        returns the url reassembled as proto://host/path
        """
        return "%s://%s%s" % (self.proto, self.host, self.path)

    def url_dirname(self):
        """
        returns the url of the directory containing the path
        """
        dn = os.path.dirname(self.path)
        return "%s://%s%s" % (self.proto, self.host, dn)


class Alarm(Exception):
    """
    raised from the SIGALRM handler to interrupt a command that runs too long
    """


# --- global variables ----------------------------------------------------------------

# name of this program, and the (normalized) directory it was started from
prog_base = os.path.basename(sys.argv[0])
prog_dir = os.path.normpath(os.path.dirname(sys.argv[0]))

logger = logging.getLogger("my_logger")

# maps the protocol of a URL to the name of the removal tool handling it
tool_map = {
    'file':    'rm',
    'ftp':     'gsiftp',
    'gsiftp':  'gsiftp',
    'irods':   'irods',
    's3':      's3',
    's3s':     's3',
    'scp':     'scp',
    'srm':     'srm',
    'symlink': 'rm',
}

# filled in by check_tool(): executable name -> dict with path/version details
tool_info = {}


# --- functions -----------------------------------------------------------------------


def setup_logger(debug_flag):
    """
    attaches a console handler to the module logger - the level is DEBUG
    when debug_flag is set and INFO otherwise
    """
    # logger and console handler always share the same level
    if debug_flag:
        level = logging.DEBUG
    else:
        level = logging.INFO

    # log to the console
    handler = logging.StreamHandler()
    handler.setLevel(level)
    handler.setFormatter(
        logging.Formatter("%(asctime)s %(levelname)7s:  %(message)s"))

    logger.setLevel(level)
    logger.addHandler(handler)
    logger.debug("Logger has been configured")


def prog_sigint_handler(signum, frame):
    """
    signal handler - logs the received signal number and exits with code 1
    """
    # logger.warn() is a deprecated alias - warning() is the supported name
    logger.warning("Exiting due to signal %d" % (signum))
    myexit(1)


def alarm_handler(signum, frame):
    """
    signal handler - turns the delivered signal into an Alarm exception
    """
    raise Alarm()


def expand_env_vars(s):
    """
    returns s with every $NAME reference (each of the surrounding curly
    braces is optional) replaced by whatever get_env_var() returns for it
    """
    env_ref = re.compile(r'\${?([a-zA-Z0-9_]+)}?')
    return env_ref.sub(get_env_var, s)


def get_env_var(match):
    """
    regex substitution callback - group 1 of the match is looked up in the
    environment; unset variables expand to the empty string
    """
    var_name = match.group(1)
    logger.debug("Looking up " + var_name)
    return os.environ.get(var_name, "")


def myexec(cmd_line, timeout_secs, should_log):
    """
    executes shell commands with the ability to time out if the command hangs

    cmd_line is run through the shell with stderr merged into stdout. The
    command line is logged when should_log is set or debug logging is on.
    Raises RuntimeError when the command does not finish within
    timeout_secs seconds or when it exits with a non-zero code.
    """
    if should_log or logger.isEnabledFor(logging.DEBUG):
        logger.info(cmd_line)
    sys.stdout.flush()

    # set up signal handler for timeout
    signal.signal(signal.SIGALRM, alarm_handler)
    signal.alarm(timeout_secs)

    try:
        p = subprocess.Popen(cmd_line + " 2>&1", shell=True)
        try:
            p.communicate()
        except Alarm:
            # terminate() only exists on Python 2.6 and later
            if sys.version_info >= (2, 6):
                p.terminate()
            raise RuntimeError("Command '%s' timed out after %s seconds" % (cmd_line, timeout_secs))
    finally:
        # always disarm the timer - otherwise a command which finished in
        # time leaves an alarm pending that fires later, in unrelated code
        signal.alarm(0)

    rc = p.returncode
    if rc != 0:
        raise RuntimeError("Command '%s' failed with error code %s" % (cmd_line, rc))


def backticks(cmd_line):
    """
    runs cmd_line through the shell and returns what it wrote to stdout
    (the perl backtick operator, more or less)
    """
    child = subprocess.Popen(cmd_line, shell=True, stdout=subprocess.PIPE)
    captured, _ = child.communicate()
    return captured


def check_tool(executable, version_arg, version_regex):
    """
    probes for an executable and records the findings in tool_info

    tool_info[executable] always gets the keys full_path, version,
    version_major, version_minor and version_patch, all None to start with.
    The full path comes from `which`. If version_regex is given, the
    executable is run with version_arg and group 1 of the first match in its
    output becomes the version (the whole output, newlines removed, if
    nothing matches). If version_regex is None, the recorded version stays
    None and "N/A" is only used in the log line.
    """
    # initialize the global tool info for this executable
    info = {
        'full_path': None,
        'version': None,
        'version_major': None,
        'version_minor': None,
        'version_patch': None,
    }
    tool_info[executable] = info

    # figure out the full path to the executable
    full_path = backticks("which " + executable + " 2>/dev/null")
    if not isinstance(full_path, str):
        # command output is bytes on Python 3
        full_path = full_path.decode("utf-8", "replace")
    full_path = full_path.rstrip('\n')
    if full_path == "":
        logger.info("Command '%s' not found in the current environment" %(executable))
        return
    info['full_path'] = full_path

    # version
    if version_regex is None:
        version = "N/A"
    else:
        version = backticks(executable + " " + version_arg + " 2>&1")
        if not isinstance(version, str):
            version = version.decode("utf-8", "replace")
        version = version.replace('\n', "")
        result = re.search(version_regex, version)
        if result:
            version = result.group(1)
        info['version'] = version

    # if possible, break up version into major, minor, patch
    result = re.search(r"([0-9]+)\.([0-9]+)(\.([0-9]+)){0,1}", version)
    if result:
        info['version_major'] = int(result.group(1))
        info['version_minor'] = int(result.group(2))
        # group 4 is None when the version has no patch component
        if result.group(4) is not None:
            info['version_patch'] = int(result.group(4))

    logger.info("  %-18s Version: %-7s Path: %s" % (executable, version, full_path))


def check_env_and_tools():
    """
    prepares the process environment and probes for the external tools

    Rebuilds PATH (making sure /usr/bin, /bin, /sw/bin if it exists, and the
    directory of this script are included, with PEGASUS_HOME/bin and
    GLOBUS_LOCATION/bin moved to the front), sets LD_LIBRARY_PATH and
    DYLD_LIBRARY_PATH (with GLOBUS_LOCATION/lib at the front), points
    irodsAuthFileName at .irodsA in the current directory, and finally calls
    check_tool() for every tool a removal might need.
    """

    def prepend(entries, entry):
        # move (or add) the entry to the front of the list
        if entry in entries:
            entries.remove(entry)
        entries.insert(0, entry)

    # PATH setup
    path_entries = os.environ.get('PATH', "/usr/bin:/bin").split(':')

    # is /usr/bin in the path?
    if "/usr/bin" not in path_entries:
        path_entries.append("/usr/bin")
        path_entries.append("/bin")

    # fink on macos x
    if os.path.exists("/sw/bin") and "/sw/bin" not in path_entries:
        path_entries.append("/sw/bin")

    # add our own path (basic way to get to other Pegasus tools)
    my_bin_dir = os.path.normpath(os.path.dirname(sys.argv[0]))
    if my_bin_dir not in path_entries:
        path_entries.append(my_bin_dir)

    # need LD_LIBRARY_PATH for Globus tools
    ld_library_path = os.environ.get('LD_LIBRARY_PATH', "")
    if ld_library_path:
        ld_library_path_entries = ld_library_path.split(':')
    else:
        # "".split(':') is [''] - that would add an empty entry, which the
        # dynamic linker reads as "the current directory"
        ld_library_path_entries = []

    # if PEGASUS_HOME is set, prepend it to the PATH (we want it early to override other cruft)
    if "PEGASUS_HOME" in os.environ:
        prepend(path_entries, os.environ['PEGASUS_HOME'] + "/bin")

    # if GLOBUS_LOCATION is set, prepend it to the PATH and LD_LIBRARY_PATH
    # (we want it early to override other cruft)
    if "GLOBUS_LOCATION" in os.environ:
        prepend(path_entries, os.environ['GLOBUS_LOCATION'] + "/bin")
        prepend(ld_library_path_entries, os.environ['GLOBUS_LOCATION'] + "/lib")

    os.environ['PATH'] = ":".join(path_entries)
    os.environ['LD_LIBRARY_PATH'] = ":".join(ld_library_path_entries)
    os.environ['DYLD_LIBRARY_PATH'] = ":".join(ld_library_path_entries)
    logger.info("PATH=" + os.environ['PATH'])
    logger.info("LD_LIBRARY_PATH=" + os.environ['LD_LIBRARY_PATH'])

    # irods requires a password hash file
    os.environ['irodsAuthFileName'] = os.getcwd() + "/.irodsA"

    # tools we might need later
    check_tool("lcg-del", "--version", r"lcg_util-([\.0-9a-zA-Z]+)")
    check_tool("srm-rm", "-version", r"srm-copy[ \t]+([\.0-9a-zA-Z]+)")
    check_tool("irm", "-h", r"Version[ \t]+([\.0-9a-zA-Z]+)")
    check_tool("pegasus-gridftp", "", None)
    check_tool("pegasus-s3", "help", None)

def rm(urls):
    """
    removes locally using /bin/rm

    One /bin/rm is run per url, on the url's path. -f is added when failures
    are to be ignored, -r for recursive removal. A failing command is logged
    and its RuntimeError is passed on to the caller.
    """
    for url in urls:
        cmd = "/bin/rm"
        if options.ignore_failures:
            cmd += " -f"
        if options.recursive:
            cmd += " -r"
        # single-quote the path so that the shell takes it literally - inside
        # double quotes characters like ", `, $ and \ would still be special
        cmd += " '%s'" % (url.path.replace("'", "'\\''"))
        try:
            myexec(cmd, 5*60, True)
        except RuntimeError:
            logger.error(sys.exc_info()[1])
            raise

def scp(urls):
    """
    removes on a remote host by running /bin/rm there over ssh - one ssh
    invocation per url. A site specific SSH_PRIVATE_KEY_<sitename> variable
    wins over the generic SSH_PRIVATE_KEY one.
    """
    for url in urls:
        pieces = ["/usr/bin/ssh"]

        # pick the identity file, if one is configured
        site_key = "SSH_PRIVATE_KEY_" + url.sitename
        if site_key in os.environ:
            pieces.append(" -i " + os.environ[site_key])
        elif "SSH_PRIVATE_KEY" in os.environ:
            pieces.append(" -i " + os.environ['SSH_PRIVATE_KEY'])

        pieces.append(" -o PasswordAuthentication=no")
        pieces.append(" -o StrictHostKeyChecking=no")
        pieces.append(" " + url.host + " ")

        # the remote command is passed as one double-quoted argument
        pieces.append(" \"/bin/rm")
        if options.ignore_failures:
            pieces.append(" -f")
        if options.recursive:
            pieces.append(" -r")
        pieces.append(" " + url.path + "\"")

        cmd = "".join(pieces)
        try:
            myexec(cmd, 5*60, True)
        except RuntimeError:
            logger.error(sys.exc_info()[1])
            raise

def gsiftp(urls):
    """
    remove files on gridftp servers - delegate run to pegasus-gridftp

    All urls are written to a temporary list file (one per line) which is
    handed to a single "pegasus-gridftp rm" run. The list file is removed
    again no matter how the run ends. Raises RuntimeError if the list file
    can not be created or if the command fails.
    """
    try:
        tmp_fd, tmp_name = tempfile.mkstemp(prefix="pegasus-cleanup-", suffix=".lst", dir="/tmp")
        # text is written to the file, so open it in text mode
        tmp_file = os.fdopen(tmp_fd, "w")
    except (IOError, OSError):
        raise RuntimeError("Unable to create tmp file for pegasus-gridftp cleanup")

    try:
        try:
            for url in urls:
                tmp_file.write("%s\n" %(url.url()))
        finally:
            tmp_file.close()

        # a site specific proxy overrides the default one
        key = "X509_USER_PROXY_" + urls[0].sitename
        if key in os.environ:
            os.environ["X509_USER_PROXY"] = os.environ[key]

        # use pegasus-gridftp from the directory this script lives in
        cmd = prog_dir + "/pegasus-gridftp rm"
        if options.ignore_failures:
            cmd += " -f"
        if options.recursive:
            cmd += " -r"

        # make output match our current log level
        if logger.isEnabledFor(logging.DEBUG):
            cmd += " -v"

        cmd += " -i " + tmp_name

        try:
            myexec(cmd, 1*60*60, True)
        except RuntimeError:
            logger.error(sys.exc_info()[1])
            raise
    finally:
        # single cleanup point - also covers a failed write of the list
        os.unlink(tmp_name)


def irods_login(sitename):
    """
    log in to irods by using the iinit command - if the file already exists,
    we are already logged in

    A site specific irodsEnvFile_<sitename> variable overrides irodsEnvFile.
    The password is read from the irodsPassword entry of that env file and
    fed to iinit through a short-lived file called .irodsAc in the current
    directory. Raises RuntimeError if the env file setting or the password
    is missing, or if iinit fails.
    """
    key = "irodsEnvFile_" + sitename
    if key in os.environ:
        os.environ["irodsEnvFile"] = os.environ[key]

    f = os.environ['irodsAuthFileName']
    if os.path.exists(f):
        return

    # read password from env file
    if "irodsEnvFile" not in os.environ:
        raise RuntimeError("Missing irodsEnvFile - unable to do irods transfers")
    password = None
    h = open(os.environ['irodsEnvFile'], 'r')
    try:
        for line in h:
            items = line.split(" ", 2)
            # an entry without a value is treated like a missing entry
            if len(items) > 1 and items[0].lower() == "irodspassword":
                password = items[1].strip(" \t'\"\r\n")
    finally:
        h.close()
    if password is None:
        raise RuntimeError("No irodsPassword specified in irods env file")

    # the file holds a clear text password: keep it readable by the owner
    # only, and make sure it is removed even if iinit fails
    owner_only = stat.S_IRUSR | stat.S_IWUSR
    fd = os.open(".irodsAc", os.O_WRONLY | os.O_CREAT | os.O_TRUNC, owner_only)
    try:
        h = os.fdopen(fd, "w")
        try:
            # the mode given to os.open() only applies to new files
            os.chmod(".irodsAc", owner_only)
            h.write(password + "\n")
        finally:
            h.close()

        cmd = "cat .irodsAc | iinit"
        myexec(cmd, 60*60, True)
    finally:
        os.unlink(".irodsAc")


def irods(urls):
    """
    irods - use the icommands to interact with irods

    Logs in using the site name of the first url, then runs one "irm -f"
    per url on the url's path (with -r for recursive removal). Raises
    RuntimeError if irm was not found or the login fails; a failing irm is
    logged and its exception is passed on to the caller.
    """

    # the tool used below is irm, so that is what the message has to name
    if tool_info['irm']['full_path'] is None:
        raise RuntimeError("Unable to do irods removals because irm could" +
                           " not be found in the current path")

    # log in to irods
    try:
        irods_login(urls[0].sitename)
    except Exception:
        logger.error(sys.exc_info()[1])
        raise RuntimeError("Unable to log into irods")

    for url in urls:
        cmd = "irm -f " + url.path
        if options.recursive:
            cmd += " -r"
        try:
            myexec(cmd, 5*60, True)
        except Exception:
            logger.error(sys.exc_info()[1])
            raise

def srm(urls):
    """
    removes from srm servers, one command per url - lcg-del is preferred,
    srm-rm is the fallback, and without either a RuntimeError is raised
    """

    # a site specific proxy (taken from the first url) overrides the default
    proxy_key = "X509_USER_PROXY_" + urls[0].sitename
    if proxy_key in os.environ:
        os.environ["X509_USER_PROXY"] = os.environ[proxy_key]

    if tool_info['lcg-del']['full_path'] is not None:
        base_cmd = "lcg-del -b -l -D srmv2 --force"
    elif tool_info['srm-rm']['full_path'] is not None:
        base_cmd = "srm-rm"
    else:
        raise RuntimeError("Unable to do srm remove becuase lcg-del/srm-rm could not be found")

    for target in urls:
        try:
            myexec(base_cmd + " " + target.url(), 5*60, True)
        except Exception:
            logger.error(sys.exc_info()[1])
            raise

def s3(urls):
    """
    removes from Amazon S3 by running "pegasus-s3 rm" once per url - a
    RuntimeError is raised if pegasus-s3 was not found
    """

    if tool_info['pegasus-s3']['full_path'] is None:
        raise RuntimeError("Unable to do S3 transfers becuase pegasus-s3 could not be found")

    for target in urls:

        # a site specific config overrides the default one
        cfg_key = "S3CFG_" + target.sitename
        if cfg_key in os.environ:
            os.environ["S3CFG"] = os.environ[cfg_key]

        pieces = ["pegasus-s3 rm"]
        if options.ignore_failures:
            pieces.append(" -f")
        pieces.append(" " + target.url())

        try:
            myexec("".join(pieces), 60, True)
        except Exception:
            logger.error(sys.exc_info()[1])
            raise


def urls_groupable(a, b):
    """
    tells whether two urls can be handed to one tool invocation together,
    which is the case when both site name and protocol are the same
    """
    same_site = a.sitename == b.sitename
    return same_site and a.proto == b.proto


def handle_removes(urls):
    """
    removes the files with the given urls

    The protocol of the first url selects the tool (the list is expected to
    contain only urls which can be grouped, see urls_groupable). A protocol
    or tool without a handler is logged as critical and ends the program
    through myexit(1). Exceptions raised by the handlers are passed on.
    """
    # the protocol has to come from the list we were given - this function
    # has no single "url" of its own to report on
    proto = urls[0].proto

    if proto not in tool_map:
        logger.critical("Error: This tool does not know how to remove from %s://" % (proto))
        myexit(1)
    else:
        handlers = {
            "rm": rm,
            "scp": scp,
            "gsiftp": gsiftp,
            "irods": irods,
            "srm": srm,
            "s3": s3,
        }
        tool = tool_map[proto]
        if tool in handlers:
            handlers[tool](urls)
        else:
            logger.critical("Error: No mapping for the tool '%s'" %(tool))
            myexit(1)


def myexit(rc):
    """
    ends the program with the given exit code by raising SystemExit
    """
    sys.exit(rc)


# --- main ----------------------------------------------------------------------------

# Overall flow: parse the command line, read the list of URLs (from a file or
# from stdin), then remove them - consecutive URLs with the same site name and
# protocol are handed to the matching tool as one group.
#
# Input format: one URL per line. A line starting with '#' is a comment which
# names the site of the next URL. Lines of 3 characters or less (counting the
# newline) are ignored.

# dup stderr onto stdout
sys.stderr = sys.stdout

# Configure command line option parser
prog_usage = "usage: %s [options]" % (prog_base)
parser = optparse.OptionParser(usage=prog_usage)
parser.add_option("-f", "--file", action = "store", dest = "file",
                  help = "File containing URLs to be removed. If not given, list is read from stdin.")
parser.add_option("-i", "--ignore-failures", action = "store_true", dest = "ignore_failures",
                  help = "When specified, the tool ignores any failures and always exits with 0.")
parser.add_option("-r", "--recursive", action = "store_true", dest = "recursive",
                  help = "Enable recursive removal for tools that support it.")
parser.add_option("-d", "--debug", action = "store_true", dest = "debug",
                  help = "Enables debugging ouput.")

# Parse command line options
(options, args) = parser.parse_args()
setup_logger(options.debug)

# Die nicely when asked to (Ctrl+C, system shutdown)
signal.signal(signal.SIGINT, prog_sigint_handler)

# stdin or file input?
if options.file is None:
    logger.info("Reading URL pairs from stdin")
    input_file = sys.stdin
else:
    logger.info("Reading URL pairs from %s" % (options.file))
    try:
        input_file = open(options.file, 'r')
    except Exception:
        logger.critical('Error reading url pair list: %s' % (sys.exc_info()[1]))
        myexit(1)

# check environment and tools
try:
    check_env_and_tools()
except Exception:
    logger.critical(sys.exc_info()[1])
    myexit(1)

# list of work
url_q = deque()

# fill the url queue with user provided entries
line_nr = 0
sitename = ""
try:
    try:
        for line in input_file.readlines():
            line_nr += 1
            if len(line) > 3:
                # strip the line ending, including the \r of DOS style files
                line = line.rstrip('\r\n')
                if line[0] == '#':
                    r = re_parse_comment.search(line)
                    if r:
                        sitename =  r.group(1)
                    else:
                        logger.critical('Unable to parse comment on line %d'
                                        %(line_nr))
                        myexit(1)
                else:
                    url = URL()
                    url.set_sitename(sitename)
                    url.set_url(line)
                    url_q.append(url)
                    # a site name only applies to the URL right after it
                    sitename = ""
    finally:
        # do not leave a file we opened ourselves open
        if input_file is not sys.stdin:
            input_file.close()
except Exception:
    logger.critical('Error handling url: %s' % (sys.exc_info()[1]))
    myexit(1)

# do the removals
exit_code = 0
while url_q:

    u_main = url_q.popleft()

    # create a list of urls to pass to underlying tool - it takes all the
    # directly following urls which can be grouped with the first one
    u_list = [u_main]
    while url_q and urls_groupable(u_main, url_q[0]):
        u_list.append(url_q.popleft())

    # magic!
    try:
        handle_removes(u_list)
    except (KeyboardInterrupt, SystemExit):
        # a requested exit (Ctrl+C handler, fatal error) must end the
        # program instead of being counted as one more failed removal
        raise
    except Exception:
        # the failure has been logged by the tool function already
        exit_code = 1

# sometimes we just want to ignore the failures
if options.ignore_failures:
    exit_code = 0

myexit(exit_code)


