scheduler.py
author Fabien Ninoles <fabien@tzone.org>
Sun, 21 Mar 2010 21:40:40 -0400
changeset 4 76ba9b3a9e1c
parent 3 00b6708d1852
child 5 eb1133af54ed
permissions -rwxr-xr-x
Add GrabResult and SetupEvent, for synchronicity.

#!/usr/bin/env python2.6
import sys
import logging
import threading
import Queue
from threadpool import ThreadPool

##
# A task, representing a series of linked pairs of callback and error
# handler. The concept is similar to Twisted's Deferred, but you need to
# put all your callbacks on the task before giving it to the Scheduler.
# For this reason, you shouldn't call the callback/errorback methods
# yourself (or at least, don't put the task back in the scheduler after
# that).

class _Callback(object):
    def __init__(self, callback, errorback, threaded = False):
        self.callback = callback
        self.errorback = errorback
        self.next = None
        self.threaded = threaded
    def Chain(self, next):
        # Can only be called once
        assert(self.next is None)
        self.next = next
        return next
    def Next(self):
        return self.next

class Task(object):
    """An ordered chain of callback/errorback pairs run by the Scheduler.

    Callbacks are kept as a singly linked list of _Callback nodes from
    ``head`` to ``tail``. The Scheduler pops nodes from ``head`` with
    _GetNext(). A callback receives the previous callback's result, and an
    errorback receives the pending error. Build the whole chain before
    handing the task to the Scheduler.
    """

    @staticmethod
    def DefaultCallback(result):
        """Return *result* unchanged."""
        return result

    @staticmethod
    def DefaultErrorback(error):
        """Return *error* unchanged so that it keeps propagating."""
        return error

    def __init__(self, func = None, *args, **kwargs):
        """Create a task, optionally starting with a call to *func*.

        When *func* is given, the first callback ignores the incoming
        result and returns ``func(*args, **kwargs)``.
        """
        super(Task, self).__init__()
        self.head = None
        self.tail = None
        if func:
            def callback(result):
                return func(*args, **kwargs)
            self.AddCallback(callback)

    def _AddCallback(self, callback):
        """Append the _Callback node *callback* to the end of the chain."""
        if self.head is None:
            self.head = self.tail = callback
        else:
            self.tail = self.tail.Chain(callback)

    def _GetNext(self):
        """Pop and return the next _Callback node, or None when drained."""
        head = self.head
        if head:
            self.head = head.Next()
            if self.head is None:
                # Drained: drop the stale tail so nothing can be chained
                # onto a node that is no longer reachable from head.
                self.tail = None
        return head

    def AddCallback(self, callback, errorback = None, threaded = False):
        """Append a callback/errorback pair and return self.

        errorback -- defaults to DefaultErrorback (the error passes through).
        threaded  -- if true, the Scheduler runs this pair on its thread
                     pool instead of inline.
        """
        if errorback is None:
            errorback = self.DefaultErrorback
        cb = _Callback(callback, errorback, threaded)
        self._AddCallback(cb)
        # permit chained calls
        return self

    def AddThreadedCallback(self, callback, errorback = None):
        """Same as AddCallback(callback, errorback, threaded=True)."""
        return self.AddCallback(callback, errorback, True)

    def ChainTask(self, task):
        """Append the remaining callbacks of *task* after this task's own.

        The callback nodes are shared, not copied, so *task* should not be
        scheduled on its own afterwards. Returns self.
        """
        if task.head is None:
            # Nothing left to append; keep our own tail intact.
            return self
        if self.head is None:
            # Empty (or fully consumed) task: adopt the other chain.
            self.head = task.head
        else:
            self.tail.Chain(task.head)
        self.tail = task.tail
        return self

    def GrabResult(self, data = None):
        """Append a pair that records the outcome of the chain into a dict.

        On success ``data["result"]`` is set; on failure ``data["error"]``
        is set. The value is passed through unchanged, so callbacks added
        afterwards still see the same result or error.

        Returns *data* (a new dict if None was given).
        """
        if data is None:
            data = {}
        def SetResult(result):
            data["result"] = result
            return result
        def SetError(error):
            data["error"] = error
            # Keep propagating: returning None would make the Scheduler
            # continue with the success path and a stale result.
            return error
        self.AddCallback(SetResult, SetError)
        return data

    def SetupEvent(self, event = None, data = None):
        """Append a pair that sets *event* once the chain reaches it.

        The event is set on both success and failure, and the incoming
        result or error is passed through unchanged. *data* is unused and
        only kept for backward compatibility.

        Returns the event (a new threading.Event if None was given).
        """
        if event is None:
            event = threading.Event()
        def SetEvent(value):
            event.set()
            return value
        self.AddCallback(SetEvent, SetEvent)
        return event

##
# Helper class
class ThreadedTask(Task):
    """A Task whose initial call to *func* runs on the Scheduler's pool."""
    def __init__(self, func, *args, **kwargs):
        super(ThreadedTask, self).__init__(func, *args, **kwargs)
        # Task.__init__ wrapped func as the first callback node; flag it so
        # the Scheduler dispatches it to the thread pool.
        firstCallback = self.head
        firstCallback.threaded = True

class Scheduler(threading.Thread):
    """Thread that runs queued Tasks and offloads threaded callbacks to a pool.

    Non-threaded callbacks run inline on the scheduler thread. When a
    threaded callback is reached, it is handed to the ThreadPool together
    with the rest of its task. Once the callback finishes, the remainder
    is queued again as a new task (see _AddJob).
    """

    class _SchedulerStop(Exception):
        # Raised by the task queued in Stop() to end the Run() loop.
        pass

    def __init__(self, poolSize):
        threading.Thread.__init__(self, name = "Scheduler", target = self.Run)
        self.pool = ThreadPool(poolSize)
        self.tasks = Queue.Queue()

    @staticmethod
    def _RunCallback(cb, result, error, traceback):
        """Run one callback node and return the new (result, error, traceback).

        The errorback is called if an error is pending; otherwise the
        callback is called. Any exception raised becomes the pending error.
        """
        try:
            if error:
                error = cb.errorback(error)
            else:
                result = cb.callback(result)
        except:
            # Deliberately broad: every failure (including _SchedulerStop)
            # is routed to the next errorback, like any other error.
            errtype, error, traceback = sys.exc_info()
        return result, error, traceback

    def ExecuteOne(self, blocking = True):
        """Take one task off the queue and run its callbacks.

        blocking -- wait for a task if the queue is empty; otherwise
                    return None immediately.

        Returns the final result of the task. Returns None if no task was
        available or if the task was suspended on a threaded callback. If
        the chain ends with an unhandled error, that error is re-raised
        with its original traceback.
        """
        logging.debug("Looking for next task...")
        try:
            task = self.tasks.get(blocking)
        except Queue.Empty:
            logging.debug("No task to run")
            return None
        result = error = traceback = None
        cb = task._GetNext()
        while cb:
            if cb.threaded:
                # The rest of the task resumes once the pool job finishes.
                self._AddJob(task, cb, result, error, traceback)
                # don't pass Go, don't reclaim $200
                return None
            result, error, traceback = self._RunCallback(
                cb, result, error, traceback)
            cb = task._GetNext()
        if error:
            raise error, None, traceback
        return result

    def Run(self):
        """Thread body: execute tasks until a _SchedulerStop is raised."""
        logging.info("Scheduler start")
        while True:
            try:
                self.ExecuteOne()
            except self._SchedulerStop:
                break
            except:
                # Top-level boundary: log and keep the scheduler alive.
                logging.exception("Unhandled task exception")
        logging.info("Scheduler stop")

    def Start(self):
        """Start the thread pool, then the scheduler thread."""
        self.pool.Start()
        return self.start()

    def Stop(self, now = False):
        """Stop the pool, then stop the scheduler thread and wait for it.

        now -- passed to ThreadPool.Stop(). Also discards tasks still
               waiting in the queue; otherwise tasks queued before the
               stop request are executed first.
        """
        self.pool.Stop(now)
        if now:
            # Drain the queue in place. Replacing self.tasks with a new
            # Queue would strand the scheduler thread, which may already be
            # blocked in get() on the old queue, and join() would hang.
            try:
                while True:
                    self.tasks.get_nowait()
            except Queue.Empty:
                pass
        # We raise an exception to stop the scheduler thread. We could
        # have used a None task, but this makes it easier if we want to
        # make such a mechanism public or to stop on other exceptions.
        def RaiseSchedulerStop():
            raise self._SchedulerStop
        self.AddTask(Task(RaiseSchedulerStop))
        self.join()

    def AddTask(self, task):
        """Queue *task* for execution on the scheduler thread."""
        self.tasks.put(task)

    def _AddJob(self, task, cb, result, error, traceback):
        """Run threaded callback *cb* on the pool, then requeue *task*.

        The outcome of *cb* is wrapped in a one-callback Task (returning
        the result or re-raising the error), followed by the remaining
        callbacks of *task*, and put back on the scheduler queue.
        """
        def Job():
            # Use fresh names for the outcome: assigning to result/error/
            # traceback here would turn them into locals of Job and break
            # the closure over _AddJob's arguments.
            newResult, newError, newTraceback = self._RunCallback(
                cb, result, error, traceback)
            if newError:
                def RaiseError():
                    raise newError, None, newTraceback
                jobTask = Task(RaiseError)
            else:
                def ReturnResult():
                    return newResult
                jobTask = Task(ReturnResult)
            jobTask.ChainTask(task)
            self.AddTask(jobTask)
        self.pool.AddJob(Job)

# Module-level Scheduler instance, managed by StartScheduler() and
# StopScheduler(); None while no scheduler is running.
scheduler = None

def StartScheduler(size):
    """(Re)create and start the module-level scheduler.

    size -- pool size passed to Scheduler(). Any scheduler already
            running is first stopped with StopScheduler()'s defaults.
    """
    global scheduler
    if scheduler:
        StopScheduler()
    fresh = Scheduler(size)
    scheduler = fresh
    fresh.Start()

def StopScheduler(now = False):
    """Stop the module-level scheduler (if any) and clear the global.

    now -- forwarded to Scheduler.Stop().
    """
    global scheduler
    current = scheduler
    if current:
        current.Stop(now)
    scheduler = None

if __name__ == '__main__':
    from time import sleep
    logging.getLogger().setLevel(logging.INFO)
    # This function is a sample and shouldn't know about the scheduler
    count = 0
    def AsyncCall(name, seconds):
        global count
        count += 1
        
        # Probably a bad example, since the callback
        # doesn't return the exact same type...
        def Initialize(name, seconds):
            print "Here", name
            return name, seconds
        def Blocking(args):
            name, time = args
            print name, "goes to bed"
            sleep(time)
            print name, ": ZZZ..."
            return name
        def Finalize(name):
            global count
            print name, "wakes up!"
            count -= 1
            return name

        task = Task(Initialize, name, seconds)
        task.AddThreadedCallback(Blocking)
        task.AddCallback(Finalize)
        return task

    logging.info("Starting scheduler with 10 workers")
    StartScheduler(10)
    logging.info("Adding asynccall task")
    for x in xrange(int(sys.argv[1])):
        task = AsyncCall("Toto%d" % (x+1), (x % 10)/10.0)
        scheduler.AddTask(task)
    while count > 0:
        logging.debug("Count = %d", count)
        sleep(1)

    # Check for King Toto sleep
    task = AsyncCall("King Toto", 5)
    data = task.GrabResult()
    event = task.SetupEvent()
    scheduler.AddTask(task)
    try:
        event.wait(10)
        print "data = %r" % (data,)
    except:
        logging.exception("Error occured on wait")
    logging.info("Stopping scheduler")
    StopScheduler()
    logging.info("The End.")