Interval-based time series classification in aeon¶
Interval-based approaches look at phase-dependent intervals of the full series. They select subseries at fixed positions, calculate summary statistics from them, and use these statistics as features for classification.
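A phase-dependent interval covers the same time indices in every series, so a statistic computed on it lines up across cases and can serve as one feature column. The following NumPy sketch illustrates this with two toy series that contain the same bump at different positions. The helper interval_mean is hypothetical (it is not part of aeon), and the sketch assumes the (n_cases, n_channels, n_timepoints) array layout that aeon uses for collections.

```python
import numpy as np


def interval_mean(X, start, end):
    """Hypothetical helper: mean of the same [start, end) window in every series.

    X has shape (n_cases, n_channels, n_timepoints); the result has shape
    (n_cases, n_channels), i.e. one feature per channel.
    """
    return X[:, :, start:end].mean(axis=2)


# Two toy univariate series of length 24 with the same bump at different times
X = np.zeros((2, 1, 24))
X[0, 0, 4:8] = 1.0    # bump early in the series
X[1, 0, 16:20] = 1.0  # bump late in the series

whole_series_mean = X.mean(axis=2)           # identical for both cases (4 / 24)
early_window_mean = interval_mean(X, 0, 12)  # 4 / 12 for case 0, 0 for case 1
print(whole_series_mean.ravel(), early_window_mean.ravel())
```

Because the window is fixed in time, the interval statistic separates the two series even though the whole-series mean does not.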
aeon currently implements the following interval-based approaches: Time Series Forest (TSF) [1], the Random Interval Spectral Ensemble (RISE) [2], Supervised Time Series Forest (STSF) [3] and Random STSF (RSTSF) [4], the Canonical Interval Forest (CIF) [5], the Diverse Representation Canonical Interval Forest (DrCIF) and QUANT [6]. Most of them can also classify multivariate series.
In this notebook, we will demonstrate how to use these classifiers on the ItalyPowerDemand and BasicMotions datasets.
1. Set up¶
After importing the classifiers used in this notebook, we can list all classifiers in this category with all_estimators:
[3]:
import warnings
from sklearn import metrics
from aeon.classification.interval_based import (
RSTSF,
CanonicalIntervalForestClassifier,
DrCIFClassifier,
QUANTClassifier,
RandomIntervalSpectralEnsembleClassifier,
SupervisedTimeSeriesForest,
TimeSeriesForestClassifier,
)
from aeon.datasets import load_basic_motions, load_italy_power_demand
from aeon.utils.discovery import all_estimators
warnings.filterwarnings("ignore")
all_estimators("classifier", tag_filter={"algorithm_type": "interval"})
2. Load data¶
[2]:
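# ItalyPowerDemand is univariate and is used for the univariate examples
X_train, y_train = load_italy_power_demand(split="train")
X_test, y_test = load_italy_power_demand(split="test")
# Use only the first 50 test cases to keep run times short
X_test = X_test[:50]
y_test = y_test[:50]
print(X_train.shape, y_train.shape, X_test.shape, y_test.shape)
# BasicMotions is multivariate (6 channels) and is used for the multivariate examples.
# It has only 40 train and 40 test cases, so the [:50] slices keep every case.
X_train_mv, y_train_mv = load_basic_motions(split="train")
X_test_mv, y_test_mv = load_basic_motions(split="test")
X_train_mv = X_train_mv[:50]
y_train_mv = y_train_mv[:50]
X_test_mv = X_test_mv[:50]
y_test_mv = y_test_mv[:50]
print(X_train_mv.shape, y_train_mv.shape, X_test_mv.shape, y_test_mv.shape)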
(67, 1) (67,) (50, 1) (50,)
(40, 6) (40,) (40, 6) (40,)
3. Time Series Forest (TSF)¶
TSF is an ensemble of tree classifiers built on summary statistics of randomly selected intervals. For each tree, sqrt(n_timepoints) intervals are selected at random. The mean, standard deviation and slope of each interval are extracted from every time series and concatenated into a feature vector. These features are then used to build a tree, which is added to the ensemble.
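The steps above can be sketched in a few lines of NumPy and scikit-learn. This is a simplified illustration and not aeon's implementation. TinyTSF, random_intervals and mean_std_slope are hypothetical names. The sketch assumes univariate data of shape (n_cases, n_timepoints) and a minimum interval length of 3 (our choice). scikit-learn's DecisionTreeClassifier stands in for whatever tree aeon uses internally.

```python
import numpy as np
from sklearn.tree import DecisionTreeClassifier


def random_intervals(n_timepoints, rng, min_length=3):
    """Draw int(sqrt(n_timepoints)) random [start, end) intervals."""
    intervals = []
    for _ in range(int(np.sqrt(n_timepoints))):
        start = int(rng.integers(0, n_timepoints - min_length + 1))
        end = int(rng.integers(start + min_length, n_timepoints + 1))
        intervals.append((start, end))
    return intervals


def mean_std_slope(X, intervals):
    """Concatenate the mean, std and least-squares slope of every interval."""
    feats = []
    for start, end in intervals:
        sub = X[:, start:end]
        t = np.arange(end - start) - (end - start - 1) / 2.0  # centred time
        slope = (sub * t).sum(axis=1) / (t**2).sum()
        feats.extend([sub.mean(axis=1), sub.std(axis=1), slope])
    return np.column_stack(feats)


class TinyTSF:
    """Hypothetical minimal TSF-style ensemble for univariate data."""

    def __init__(self, n_estimators=10, random_state=0):
        self.n_estimators = n_estimators
        self.random_state = random_state

    def fit(self, X, y):
        rng = np.random.default_rng(self.random_state)
        self.classes_ = np.unique(y)
        self.members_ = []
        for _ in range(self.n_estimators):
            # Each tree gets its own random intervals
            intervals = random_intervals(X.shape[1], rng)
            tree = DecisionTreeClassifier(random_state=int(rng.integers(2**31)))
            tree.fit(mean_std_slope(X, intervals), y)
            self.members_.append((intervals, tree))
        return self

    def predict(self, X):
        # Each tree votes using features from its own intervals
        votes = np.stack(
            [tree.predict(mean_std_slope(X, iv)) for iv, tree in self.members_]
        )
        counts = np.array([(votes == c).sum(axis=0) for c in self.classes_])
        return self.classes_[counts.argmax(axis=0)]


# Toy data: class 0 has a bump early, class 1 has a bump late
rng = np.random.default_rng(1)
y = np.repeat([0, 1], 20)
X = rng.normal(scale=0.1, size=(40, 24))
X[y == 0, 4:8] += 1.0
X[y == 1, 16:20] += 1.0
clf = TinyTSF(n_estimators=10).fit(X, y)
train_acc = (clf.predict(X) == y).mean()
print("training accuracy:", train_acc)
```

Training accuracy on toy data says nothing about generalisation. The point is only to show how the per-tree intervals, features and trees fit together.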
[3]:
tsf = TimeSeriesForestClassifier(n_estimators=50, random_state=47)
tsf.fit(X_train, y_train)
tsf_preds = tsf.predict(X_test)
print("TSF Accuracy: " + str(metrics.accuracy_score(y_test, tsf_preds)))
TSF Accuracy: 0.98
4. Random Interval Spectral Ensemble (RISE)¶
RISE is a tree-based interval ensemble aimed at classifying audio data. Unlike TSF, it uses a single interval for each tree, and it extracts spectral features rather than summary statistics.
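As a rough illustration of spectral features, the sketch below draws one random interval, as RISE does for each tree, and computes its power spectrum (the squared magnitude of the discrete Fourier transform) with numpy.fft.rfft. This is a simplification and not aeon's RISE. power_spectrum and rise_tree_features are hypothetical names, and the minimum interval length of 16 is arbitrary. RISE may also compute other spectral or autocorrelation-based features.

```python
import numpy as np


def power_spectrum(X):
    """Squared magnitude of the real FFT of each row of X (n_cases, length)."""
    return np.abs(np.fft.rfft(X, axis=1)) ** 2


def rise_tree_features(X, rng, min_length=16):
    """Hypothetical helper: one random interval, then its power spectrum.

    Assumes X.shape[1] >= min_length.
    """
    n_timepoints = X.shape[1]
    length = int(rng.integers(min_length, n_timepoints + 1))
    start = int(rng.integers(0, n_timepoints - length + 1))
    return power_spectrum(X[:, start:start + length]), (start, length)


# A sine and a cosine with 4 cycles over 64 points: same frequency, different phase
t = np.arange(64)
sine = np.sin(2 * np.pi * 4 * t / 64)[None, :]
cosine = np.cos(2 * np.pi * 4 * t / 64)[None, :]
ps_sine = power_spectrum(sine)
ps_cosine = power_spectrum(cosine)
print(ps_sine.argmax())  # the energy sits in frequency bin 4

feats, (start, length) = rise_tree_features(
    np.vstack([sine, cosine]), np.random.default_rng(0)
)
print(feats.shape, start, length)
```

The sine and the cosine have the same power spectrum, so within an interval these features describe frequency content rather than exact timing.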
[4]:
rise = RandomIntervalSpectralEnsembleClassifier(n_estimators=50, random_state=47)
rise.fit(X_train, y_train)
rise_preds = rise.predict(X_test)
print("RISE Accuracy: " + str(metrics.accuracy_score(y_test, rise_preds)))
RISE Accuracy: 1.0
5. Supervised Time Series Forest (STSF and RSTSF)¶
STSF and RSTSF make several changes to the original TSF algorithm. Intervals are selected with a supervised method instead of at random. Features are extracted not only from intervals of the original series but also from intervals of two additional representations: the periodogram and the first-order differences. The median, minimum, maximum and interquartile range are added to the summary statistics that are extracted.
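The extra representations and statistics can be sketched with NumPy as follows. The sketch covers only feature extraction. It does not show the supervised interval search that distinguishes STSF and RSTSF from TSF. The helper names are hypothetical. The periodogram is approximated here as the squared magnitude of numpy.fft.rfft, and the slope is a least-squares fit. Both may differ in detail from aeon's code.

```python
import numpy as np


def representations(X):
    """Base series, first-order differences and periodogram (X: (n, m))."""
    return {
        "base": X,
        "diff": np.diff(X, axis=1),
        "periodogram": np.abs(np.fft.rfft(X, axis=1)) ** 2,
    }


def stsf_stats(sub):
    """Summary statistics named in the text for one interval (rows = cases)."""
    q75, q25 = np.percentile(sub, [75, 25], axis=1)
    t = np.arange(sub.shape[1]) - (sub.shape[1] - 1) / 2.0  # centred time
    return np.column_stack([
        sub.mean(axis=1),
        sub.std(axis=1),
        (sub * t).sum(axis=1) / (t**2).sum(),  # least-squares slope
        np.median(sub, axis=1),
        sub.min(axis=1),
        sub.max(axis=1),
        q75 - q25,  # interquartile range
    ])


X = np.arange(12, dtype=float).reshape(2, 6)
reps = representations(X)
# Statistics of the interval [1, 5) of the base series; row 0 is [1, 2, 3, 4]
stats = stsf_stats(reps["base"][:, 1:5])
print(stats[0])  # mean, std, slope, median, min, max, IQR
```

In an STSF-style forest the same statistics would be computed on intervals of each of the three representations and concatenated into one feature vector per case.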
[1]:
stsf = SupervisedTimeSeriesForest(n_estimators=50, random_state=47)
stsf.fit(X_train, y_train)
stsf_preds = stsf.predict(X_test)
print("STSF Accuracy: " + str(metrics.accuracy_score(y_test, stsf_preds)))
rstsf = RSTSF(n_estimators=20)
rstsf.fit(X_train, y_train)
rstsf_preds = rstsf.predict(X_test)
print("RSTSF Accuracy: " + str(metrics.accuracy_score(y_test, rstsf_preds)))
6. Canonical Interval Forest (CIF)¶
CIF extends the TSF algorithm. In addition to the three summary statistics used by TSF, CIF uses the features of the catch22 [7] transform. To increase the diversity of the ensemble, a random subset of the TSF and catch22 attributes is selected for each tree. The size of this subset is set by att_subsample_size.
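Per-tree attribute subsampling can be sketched as follows. Only the selection step is shown. The attribute names are placeholders, and the catch22 features themselves are not computed. sample_tree_attributes is a hypothetical helper, and its subset size plays the role of the att_subsample_size parameter used in the examples that follow.

```python
import numpy as np

# 25 candidate attributes: TSF's 3 statistics plus the 22 catch22 features.
# The catch22 entries are placeholder names, not aeon or catch22 identifiers.
ATTRIBUTES = ["mean", "std", "slope"] + [f"catch22_{i}" for i in range(22)]


def sample_tree_attributes(rng, att_subsample_size=8):
    """Hypothetical helper: pick the attributes one tree will compute."""
    idx = rng.choice(len(ATTRIBUTES), size=att_subsample_size, replace=False)
    return [ATTRIBUTES[i] for i in sorted(idx)]


rng = np.random.default_rng(47)
for tree in range(3):
    # Each tree draws its own subset, which diversifies the ensemble
    print(tree, sample_tree_attributes(rng))
```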
Univariate¶
[6]:
cif = CanonicalIntervalForestClassifier(
n_estimators=50, att_subsample_size=8, random_state=47
)
cif.fit(X_train, y_train)
cif_preds = cif.predict(X_test)
print("CIF Accuracy: " + str(metrics.accuracy_score(y_test, cif_preds)))
CIF Accuracy: 0.98
Multivariate¶
[7]:
cif_m = CanonicalIntervalForestClassifier(
n_estimators=50, att_subsample_size=8, random_state=47
)
cif_m.fit(X_train_mv, y_train_mv)
cif_m_preds = cif_m.predict(X_test_mv)
print("CIF Accuracy: " + str(metrics.accuracy_score(y_test_mv, cif_m_preds)))
CIF Accuracy: 1.0
7. Diverse Representation Canonical Interval Forest (DrCIF)¶
DrCIF makes use of the periodogram and first-order differences representations used by STSF, as well as the additional summary statistics used in CIF.
Univariate¶
[8]:
drcif = DrCIFClassifier(n_estimators=5, att_subsample_size=10, random_state=47)
drcif.fit(X_train, y_train)
drcif_preds = drcif.predict(X_test)
print("DrCIF Accuracy: " + str(metrics.accuracy_score(y_test, drcif_preds)))
DrCIF Accuracy: 0.98
Multivariate¶
[9]:
drcif_m = DrCIFClassifier(n_estimators=5, att_subsample_size=10, random_state=47)
drcif_m.fit(X_train_mv, y_train_mv)
drcif_m_preds = drcif_m.predict(X_test_mv)
print("DrCIF Accuracy: " + str(metrics.accuracy_score(y_test_mv, drcif_m_preds)))
DrCIF Accuracy: 1.0
8. QUANT¶
QUANT is a fast interval-based classifier built on quantile features extracted from intervals.
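The sketch below is a heavily simplified version of the idea, loosely following the QUANT paper [6]. It splits the series into progressively finer equal-width intervals, takes quantiles of each interval (more quantiles for longer intervals), and trains an extremely randomised trees classifier on the result. It is not aeon's implementation. The helper names are hypothetical, and the depth argument is our own convention (level j uses 2**j intervals for j = 0, ..., depth - 1), so it may not match QUANTClassifier's interval_depth exactly. The quantile divisor of 4 is an assumption, and details of the published method, such as additional representations of the series, are omitted.

```python
import numpy as np
from sklearn.ensemble import ExtraTreesClassifier


def dyadic_intervals(n_timepoints, depth):
    """Level j splits the series into 2**j equal-width [start, end) intervals."""
    intervals = []
    for j in range(depth):
        if 2**j > n_timepoints:  # finer levels would produce empty intervals
            break
        edges = np.linspace(0, n_timepoints, 2**j + 1).astype(int)
        intervals.extend(zip(edges[:-1], edges[1:]))
    return intervals


def quantile_features(X, depth=3, divisor=4):
    """Quantiles of every interval of X (shape (n_cases, n_timepoints))."""
    feats = []
    for start, end in dyadic_intervals(X.shape[1], depth):
        sub = X[:, start:end]
        # Longer intervals yield more quantiles
        n_q = 1 + (sub.shape[1] - 1) // divisor
        levels = np.linspace(0, 1, n_q + 2)[1:-1]  # interior levels, e.g. [0.5]
        feats.append(np.quantile(sub, levels, axis=1).T)
    return np.hstack(feats)


# Toy data: class 1 has a level shift in its second half
rng = np.random.default_rng(0)
y = np.repeat([0, 1], 20)
X = rng.normal(size=(40, 24))
X[y == 1, 12:] += 2.0

F = quantile_features(X, depth=3)
clf = ExtraTreesClassifier(n_estimators=100, random_state=0).fit(F, y)
print(F.shape, clf.score(F, y))
```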
[ ]:
quant = QUANTClassifier(interval_depth=1)
quant.fit(X_train, y_train)
print("QUANT accuracy =", quant.score(X_test, y_test))
Performance on the UCR univariate datasets¶
The interval-based classifiers can be listed as follows.
[1]:
from aeon.utils.discovery import all_estimators
est = all_estimators("classifier", tag_filter={"algorithm_type": "interval"})
for c in est:
    print(c)
('CanonicalIntervalForestClassifier', <class 'aeon.classification.interval_based._cif.CanonicalIntervalForestClassifier'>)
('DrCIFClassifier', <class 'aeon.classification.interval_based._drcif.DrCIFClassifier'>)
('IntervalForestClassifier', <class 'aeon.classification.interval_based._interval_forest.IntervalForestClassifier'>)
('QUANTClassifier', <class 'aeon.classification.interval_based._quant.QUANTClassifier'>)
('RSTSF', <class 'aeon.classification.interval_based._rstsf.RSTSF'>)
('RandomIntervalClassifier', <class 'aeon.classification.interval_based._interval_pipelines.RandomIntervalClassifier'>)
('RandomIntervalSpectralEnsembleClassifier', <class 'aeon.classification.interval_based._rise.RandomIntervalSpectralEnsembleClassifier'>)
('SupervisedIntervalClassifier', <class 'aeon.classification.interval_based._interval_pipelines.SupervisedIntervalClassifier'>)
('SupervisedTimeSeriesForest', <class 'aeon.classification.interval_based._stsf.SupervisedTimeSeriesForest'>)
('TimeSeriesForestClassifier', <class 'aeon.classification.interval_based._tsf.TimeSeriesForestClassifier'>)
[2]:
from aeon.benchmarking.results_loaders import get_estimator_results_as_array
from aeon.datasets.tsc_datasets import univariate
names = [t[0].replace("Classifier", "") for t in est]
names.remove("IntervalForest") # Base class
names.remove("RandomInterval") # Pipeline
names.remove("SupervisedInterval") # Pipeline
results, present_names = get_estimator_results_as_array(
names, univariate, include_missing=False
)
results.shape
[2]:
(112, 7)
[3]:
from aeon.visualisation import plot_boxplot, plot_critical_difference
plot_critical_difference(results, names)
[3]:
[Figure: critical difference diagram of the interval-based classifiers over the UCR univariate datasets]
[4]:
plot_boxplot(results, names, relative=True)
[4]:
[Figure: boxplot of the relative results of the interval-based classifiers over the UCR univariate datasets]
References:¶
[1] Deng, H. et al. (2013). A time series forest for classification and feature extraction. Information Sciences, 239, 142-153.
[2] Flynn, M et al. (2019). The contract random interval spectral ensemble (c-RISE): the effect of contracting a classifier on accuracy. In International Conference on Hybrid Artificial Intelligence Systems (pp. 381-392).
[3] Cabello, N. et al. (2020). Fast and Accurate Time Series Classification Through Supervised Interval Search. In IEEE International Conference on Data Mining.
[4] Cabello, N. et al. (2024). Fast, accurate and interpretable time series classification through randomization. Data Mining and Knowledge Discovery 38: https://link.springer.com/article/10.1007/s10618-023-00978-w
[5] Middlehurst, M. et al. (2020). The Canonical Interval Forest (CIF) Classifier for Time Series Classification. IEEE International Conference on Data Mining. https://ieeexplore.ieee.org/document/9378424 (arXiv version: https://arxiv.org/abs/2008.09172)
[6] Dempster, A. et al. (2024). QUANT: a minimalist interval method for time series classification. Data Mining and Knowledge Discovery 38: https://link.springer.com/article/10.1007/s10618-024-01036-9
[7] Lubba, C. et al. (2019). catch22: CAnonical Time-series CHaracteristics. Data Mining and Knowledge Discovery, 33(6), 1821-1852.
Generated using nbsphinx.