
Interval based time series classification in aeon

Interval based approaches look at phase dependent intervals of the full series, calculating summary statistics from selected subseries to be used as classification features.

The following interval based approaches are implemented in aeon: the Time Series Forest (TSF) [1], the Random Interval Spectral Ensemble (RISE) [2], the Supervised Time Series Forest (STSF) [3] and its randomised extension RSTSF [4], the Canonical Interval Forest (CIF) [5], the Diverse Representation Canonical Interval Forest (DrCIF) and QUANT [6]. Most are capable of classifying multivariate series.

In this notebook, we will demonstrate how to use these classifiers on the ItalyPowerDemand and BasicMotions datasets.

1. Set up

First we import the classifiers used in this notebook. All classifiers in this category can be discovered with all_estimators; the full list is printed at the end of the notebook.

[3]:
import warnings

from sklearn import metrics

from aeon.classification.interval_based import (
    RSTSF,
    CanonicalIntervalForestClassifier,
    DrCIFClassifier,
    QUANTClassifier,
    RandomIntervalSpectralEnsembleClassifier,
    SupervisedTimeSeriesForest,
    TimeSeriesForestClassifier,
)
from aeon.datasets import load_basic_motions, load_italy_power_demand
from aeon.utils.discovery import all_estimators

warnings.filterwarnings("ignore")
all_estimators("classifier", tag_filter={"algorithm_type": "interval"})

2. Load data

[2]:
X_train, y_train = load_italy_power_demand(split="train")
X_test, y_test = load_italy_power_demand(split="test")
X_test = X_test[:50]
y_test = y_test[:50]

print(X_train.shape, y_train.shape, X_test.shape, y_test.shape)

X_train_mv, y_train_mv = load_basic_motions(split="train")
X_test_mv, y_test_mv = load_basic_motions(split="test")

X_train_mv = X_train_mv[:50]
y_train_mv = y_train_mv[:50]
X_test_mv = X_test_mv[:50]
y_test_mv = y_test_mv[:50]

print(X_train_mv.shape, y_train_mv.shape, X_test_mv.shape, y_test_mv.shape)
(67, 1) (67,) (50, 1) (50,)
(40, 6) (40,) (40, 6) (40,)

3. Time Series Forest (TSF)

TSF is an ensemble of tree classifiers built on the summary statistics of randomly selected intervals. For each tree, sqrt(n_timepoints) intervals are randomly selected. From each of these intervals the mean, standard deviation and slope are extracted from each time series and concatenated into a feature vector. These features are then used to build a tree, which is added to the ensemble.
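
Before running the aeon classifier, the sketch below illustrates this per-tree feature extraction with plain numpy. It is a simplification of the scheme described above, not the aeon implementation, and it assumes X_train is a 3D numpy array of shape (n_cases, n_channels, n_timepoints).

[ ]:
import numpy as np

# Simplified TSF-style feature extraction for a single tree (illustration only).
rng = np.random.default_rng(0)
n_timepoints = X_train.shape[-1]
n_intervals = int(np.sqrt(n_timepoints))

# Random intervals of length at least 3.
starts = rng.integers(0, n_timepoints - 3, size=n_intervals)
ends = np.array([rng.integers(s + 3, n_timepoints + 1) for s in starts])

def interval_features(series):
    """Mean, standard deviation and slope of each interval, concatenated."""
    feats = []
    for s, e in zip(starts, ends):
        sub = series[s:e]
        slope = np.polyfit(np.arange(len(sub)), sub, 1)[0]
        feats.extend([sub.mean(), sub.std(), slope])
    return feats

# One tree in the ensemble would be built on a feature matrix like this.
features = np.array([interval_features(x[0]) for x in X_train])
print(features.shape)  # (n_cases, 3 * n_intervals)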

[3]:
tsf = TimeSeriesForestClassifier(n_estimators=50, random_state=47)
tsf.fit(X_train, y_train)

tsf_preds = tsf.predict(X_test)
print("TSF Accuracy: " + str(metrics.accuracy_score(y_test, tsf_preds)))
TSF Accuracy: 0.98

4. Random Interval Spectral Ensemble (RISE)

RISE is a tree-based interval ensemble originally designed for classifying audio data. Unlike TSF, it uses a single interval per tree, and it extracts spectral features rather than summary statistics.
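
To give an intuition for what spectral features look like, the sketch below computes a power spectrum and a few autocorrelation values from one random interval with numpy. It is a simplified stand-in for the transforms RISE applies, not the library code, and assumes X_train is a 3D numpy array.

[ ]:
import numpy as np

rng = np.random.default_rng(1)
series = X_train[0, 0]  # one univariate series
n = len(series)

# A single random interval (minimum length 16 in this sketch).
min_len = min(16, n)
start = rng.integers(0, n - min_len + 1)
length = rng.integers(min_len, n - start + 1)
sub = series[start : start + length]

# Power spectrum of the interval.
power_spectrum = np.abs(np.fft.rfft(sub - sub.mean())) ** 2

# Autocorrelation of the interval at the first five lags.
centred = sub - sub.mean()
denom = np.dot(centred, centred)
acf = np.array([np.dot(centred[:-k], centred[k:]) / denom for k in range(1, 6)])

features = np.concatenate([power_spectrum, acf])
print(features.shape)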

[4]:
rise = RandomIntervalSpectralEnsembleClassifier(n_estimators=50, random_state=47)
rise.fit(X_train, y_train)

rise_preds = rise.predict(X_test)
print("RISE Accuracy: " + str(metrics.accuracy_score(y_test, rise_preds)))
RISE Accuracy: 1.0

5. Supervised Time Series Forest (STSF and RSTSF)

STSF and RSTSF make a number of adjustments to the original TSF algorithm. A supervised method of selecting intervals replaces random selection. Features are also extracted from intervals drawn on two additional representations of each series: the periodogram and the first-order differences. The median, min, max and interquartile range are added to the summary statistics extracted.
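
The additional representations are simple transforms of each series. The numpy sketch below builds them and computes the extended set of summary statistics; for brevity the statistics are shown on the full representations, whereas STSF computes them per selected interval. It assumes X_train is a 3D numpy array.

[ ]:
import numpy as np

x = X_train[0, 0]  # one univariate series

# The representations STSF/RSTSF draw intervals from.
representations = {
    "series": x,
    "periodogram": np.abs(np.fft.rfft(x)) ** 2,
    "first difference": np.diff(x),
}

def summary_stats(v):
    """Mean, std, slope, median, min, max and interquartile range."""
    slope = np.polyfit(np.arange(len(v)), v, 1)[0]
    q75, q25 = np.percentile(v, [75, 25])
    return [v.mean(), v.std(), slope, np.median(v), v.min(), v.max(), q75 - q25]

for name, rep in representations.items():
    print(name, np.round(summary_stats(rep), 3))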

[1]:
stsf = SupervisedTimeSeriesForest(n_estimators=50, random_state=47)
stsf.fit(X_train, y_train)

stsf_preds = stsf.predict(X_test)
print("STSF Accuracy: " + str(metrics.accuracy_score(y_test, stsf_preds)))

rstsf = RSTSF(n_estimators=20)
rstsf.fit(X_train, y_train)

rstsf_preds = rstsf.predict(X_test)
print("RSTSF Accuracy: " + str(metrics.accuracy_score(y_test, rstsf_preds)))

6. Canonical Interval Forest (CIF)

CIF extends the TSF algorithm. In addition to the three summary statistics used by TSF, CIF makes use of features from the Catch22 [7] transform. To increase the diversity of the ensemble, the pool of TSF and Catch22 attributes is randomly subsampled for each tree.
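
The Catch22 features can be computed with aeon's own transformer. The sketch below applies it to a fixed slice of every series to mimic one of CIF's intervals, then shows the per-tree attribute subsampling with numpy. The import path for Catch22 is assumed from recent aeon versions and may differ in yours; X_train is assumed to be a 3D numpy array.

[ ]:
import numpy as np

# Assumed import path for the Catch22 transformer; check your aeon version.
from aeon.transformations.collection.feature_based import Catch22

interval = X_train[:, :, 5:20]  # the same interval from every series
c22_features = Catch22().fit_transform(interval)
print(c22_features.shape)  # expected (n_cases, 22) for univariate input

# Per tree, CIF randomly subsamples att_subsample_size of the 25 candidate
# attributes (the 22 Catch22 features plus mean, std and slope).
rng = np.random.default_rng(0)
print(sorted(rng.choice(25, size=8, replace=False)))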

Univariate

[6]:
cif = CanonicalIntervalForestClassifier(
    n_estimators=50, att_subsample_size=8, random_state=47
)
cif.fit(X_train, y_train)

cif_preds = cif.predict(X_test)
print("CIF Accuracy: " + str(metrics.accuracy_score(y_test, cif_preds)))
CIF Accuracy: 0.98

Multivariate

[7]:
cif_m = CanonicalIntervalForestClassifier(
    n_estimators=50, att_subsample_size=8, random_state=47
)
cif_m.fit(X_train_mv, y_train_mv)

cif_m_preds = cif_m.predict(X_test_mv)
print("CIF Accuracy: " + str(metrics.accuracy_score(y_test_mv, cif_m_preds)))
CIF Accuracy: 1.0

7. Diverse Representation Canonical Interval Forest (DrCIF)

DrCIF makes use of the periodogram and first-order difference representations used by STSF, as well as additional summary statistics on top of those used by CIF.

Univariate

[8]:
drcif = DrCIFClassifier(n_estimators=5, att_subsample_size=10, random_state=47)
drcif.fit(X_train, y_train)

drcif_preds = drcif.predict(X_test)
print("DrCIF Accuracy: " + str(metrics.accuracy_score(y_test, drcif_preds)))
DrCIF Accuracy: 0.98

Multivariate

[9]:
drcif_m = DrCIFClassifier(n_estimators=5, att_subsample_size=10, random_state=47)
drcif_m.fit(X_train_mv, y_train_mv)

drcif_m_preds = drcif_m.predict(X_test_mv)
print("DrCIF Accuracy: " + str(metrics.accuracy_score(y_test_mv, drcif_m_preds)))
DrCIF Accuracy: 1.0

8. QUANT

QUANT [6] is a fast interval based classifier that uses quantiles extracted over fixed dyadic intervals as its features.
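
A minimal numpy sketch of the core idea: quantiles computed over dyadic intervals, i.e. the series repeatedly split in half. This only illustrates the feature type; aeon's QUANT also applies it to transformed versions of the series. It assumes X_train is a 3D numpy array.

[ ]:
import numpy as np

x = X_train[0, 0]  # one univariate series

def dyadic_intervals(n, depth):
    """All intervals obtained by splitting [0, n) into 2**d pieces, d = 0..depth."""
    intervals = []
    for d in range(depth + 1):
        edges = np.linspace(0, n, 2**d + 1).astype(int)
        intervals.extend(zip(edges[:-1], edges[1:]))
    return intervals

quantile_points = np.linspace(0, 1, 5)  # which quantiles to extract per interval
features = []
for s, e in dyadic_intervals(len(x), depth=2):
    features.extend(np.quantile(x[s:e], quantile_points))

print(len(features))  # (1 + 2 + 4) intervals * 5 quantiles = 35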

[ ]:
quant = QUANTClassifier(interval_depth=1)
quant.fit(X_train, y_train)
print("QUANT accuracy =", quant.score(X_test, y_test))

9. Performance on the UCR univariate datasets

We can list the interval based classifiers in aeon as follows.

[1]:
from aeon.utils.discovery import all_estimators

est = all_estimators("classifier", tag_filter={"algorithm_type": "interval"})
for c in est:
    print(c)
('CanonicalIntervalForestClassifier', <class 'aeon.classification.interval_based._cif.CanonicalIntervalForestClassifier'>)
('DrCIFClassifier', <class 'aeon.classification.interval_based._drcif.DrCIFClassifier'>)
('IntervalForestClassifier', <class 'aeon.classification.interval_based._interval_forest.IntervalForestClassifier'>)
('QUANTClassifier', <class 'aeon.classification.interval_based._quant.QUANTClassifier'>)
('RSTSF', <class 'aeon.classification.interval_based._rstsf.RSTSF'>)
('RandomIntervalClassifier', <class 'aeon.classification.interval_based._interval_pipelines.RandomIntervalClassifier'>)
('RandomIntervalSpectralEnsembleClassifier', <class 'aeon.classification.interval_based._rise.RandomIntervalSpectralEnsembleClassifier'>)
('SupervisedIntervalClassifier', <class 'aeon.classification.interval_based._interval_pipelines.SupervisedIntervalClassifier'>)
('SupervisedTimeSeriesForest', <class 'aeon.classification.interval_based._stsf.SupervisedTimeSeriesForest'>)
('TimeSeriesForestClassifier', <class 'aeon.classification.interval_based._tsf.TimeSeriesForestClassifier'>)
[2]:
from aeon.benchmarking.results_loaders import get_estimator_results_as_array
from aeon.datasets.tsc_datasets import univariate

names = [t[0].replace("Classifier", "") for t in est]
names.remove("IntervalForest")  # Base class
names.remove("RandomInterval")  # Pipeline
names.remove("SupervisedInterval")  # Pipeline
results, present_names = get_estimator_results_as_array(
    names, univariate, include_missing=False
)
results.shape
[2]:
(112, 7)
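
Since the columns of results are assumed to follow the order of present_names, mean accuracy per classifier over these datasets can be computed directly:

[ ]:
# Mean accuracy per classifier, sorted from best to worst (columns assumed to
# follow the order of present_names).
for name, acc in sorted(zip(present_names, results.mean(axis=0)), key=lambda t: -t[1]):
    print(f"{name}: {acc:.3f}")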
[3]:
from aeon.visualisation import plot_boxplot, plot_critical_difference

plot_critical_difference(results, names)
[3]:
(<Figure size 600x260 with 1 Axes>, <Axes: >)
[Image: critical difference diagram comparing the interval based classifiers on the UCR univariate datasets.]
[4]:
plot_boxplot(results, names, relative=True)
[4]:
(<Figure size 1000x600 with 1 Axes>, <Axes: >)
[Image: box plot of relative accuracy for the interval based classifiers on the UCR univariate datasets.]

References:

[1] Deng, H. et al. (2013). A time series forest for classification and feature extraction. Information Sciences, 239, 142-153.

[2] Flynn, M. et al. (2019). The contract random interval spectral ensemble (c-RISE): the effect of contracting a classifier on accuracy. In International Conference on Hybrid Artificial Intelligence Systems (pp. 381-392).

[3] Cabello, N. et al. (2020). Fast and Accurate Time Series Classification Through Supervised Interval Search. In IEEE International Conference on Data Mining.

[4] Cabello, N. et al. (2024). Fast, accurate and interpretable time series classification through randomization. Data Mining and Knowledge Discovery 38: https://link.springer.com/article/10.1007/s10618-023-00978-w

[5] Middlehurst, M. et al. (2020). The Canonical Interval Forest (CIF) Classifier for Time Series Classification. In IEEE International Conference on Data Mining. https://ieeexplore.ieee.org/document/9378424 (arXiv version: https://arxiv.org/abs/2008.09172)

[6] Dempster, A. et al. (2024). QUANT: a minimalist interval method for time series classification. Data Mining and Knowledge Discovery, 38. https://link.springer.com/article/10.1007/s10618-024-01036-9

[7] Lubba, C. et al. (2019). catch22: CAnonical Time-series CHaracteristics. Data Mining and Knowledge Discovery, 33(6), 1821-1852.

