From 4180e3b4274ea921fde8f83453eaf15b4abbd01a Mon Sep 17 00:00:00 2001 From: redxef Date: Wed, 9 Nov 2022 11:54:59 +0100 Subject: [PATCH] Add branching operator, explicitly load values from input dict. --- README.md | 22 ++++++------ i3toolwait | 100 ++++++++++++++++++++++++++++++++++++++++++----------- 2 files changed, 91 insertions(+), 31 deletions(-) diff --git a/README.md b/README.md index 0dbb3ef..0e926ca 100644 --- a/README.md +++ b/README.md @@ -51,19 +51,19 @@ It is then possible to construct a filter for any program. Available Operators: -- and: `&` -- or: `|` -- eq: `=` -- neq: `!=` -- gt: `>` -- lt: `<` +- and: `&`: logical and, ungreedy +- or: `|`: logical or, ungreedy +- if: `?`: branch, if the first argument evaluates to `True` return the second, otherwise the third +- eq: `=`: equality +- neq: `!=`: inequality +- gt: `>`: greater than +- lt: `<`: less than +- load: `load`: load a key from the provided input `(load ".container.app_id")` +- has-key: `has-key`: check if a key is in the input: `(has-key ".container.app_id")` -The filter usually operates on the dictionary, and thus the *first* argument to every normal filter -is the dictionary element, in `.` notation, as might be customary in `jq`. +For example: `(> (load ".container.geometry.width") 300)` would match the first window where the width is greater than 300. -For example: `(> ".container.geometry.width" 300)` would match the first window where the width is greater than 300. - -Multiple filters are combined via nesting: `(& (> ".container.geometry.width" 300) (= ".container.window_properties.class" "discord"))`. +Multiple filters are combined via nesting: `(& (> (load ".container.geometry.width") 300) (= (load ".container.window_properties.class") "discord"))`. ## Starting tray programs in a specific order diff --git a/i3toolwait b/i3toolwait index 7a0166a..d67e1bf 100755 --- a/i3toolwait +++ b/i3toolwait @@ -25,12 +25,19 @@ class Expression: def __init__(self): pass def reduce(self, ipc_data): + if self.should_call: + return self.call(ipc_data) return functools.reduce(self.reduce_function(ipc_data), self.children) @property + def should_call(self): + return False + @property def children(self): raise NotImplemented('TODO: implement in subclass') def reduce_function(self, ipc_data): raise NotImplemented('TODO: implement in subclass') + def call(self, ipc_data): + raise NotImplementedError('TODO: implement in subclass') class LiteralExpression(Expression): def __init__(self, value): @@ -78,6 +85,26 @@ class OrExpression(Expression): def reduce_function(self, ipc_data): return lambda a, b: a.reduce(ipc_data) or b.reduce(ipc_data) +class IfExpression(Expression): + def __init__(self, children, *args, **kwargs): + self._children = children + super().__init__(*args, **kwargs) + def __repr__(self) -> str: + cs = ' '.join([repr(c) for c in self.children]) + return f'(? {cs})' + @property + def should_call(self): + return True + @property + def children(self): + return self._children + def call(self, ipc_data): + if self._children[0].reduce(ipc_data): + i = 1 + else: + i = 2 + return self._children[i].reduce(ipc_data) + class EqExpression(Expression): def __init__(self, children, *args, **kwargs): self._children = children @@ -89,11 +116,9 @@ class EqExpression(Expression): def children(self): return self._children def reduce_function(self, ipc_data): - def reduce(key, value): - ipc_value = ipc_data - for k in key.reduce(ipc_data).strip('.').split('.'): - ipc_value = ipc_value[k] - return ipc_value == value.reduce(ipc_data) + def reduce(v0, v1): + print(f'reducing: {repr(self)}') + return v0.reduce(ipc_data) == v1.reduce(ipc_data) return reduce class NeqExpression(Expression): @@ -107,11 +132,8 @@ class NeqExpression(Expression): def children(self): return self._children def reduce_function(self, ipc_data): - def reduce(key, value): - ipc_value = ipc_data - for k in key.reduce(ipc_data).strip('.').split('.'): - ipc_value = ipc_value[k] - return ipc_value != value.reduce(ipc_data) + def reduce(v0, v1): + return v0.reduce(ipc_data) != v1.reduce(ipc_data) return reduce class GtExpression(Expression): @@ -125,11 +147,8 @@ class GtExpression(Expression): def children(self): return self._children def reduce_function(self, ipc_data): - def reduce(key, value): - ipc_value = ipc_data - for k in key.reduce(ipc_data).strip('.').split('.'): - ipc_value = ipc_value[k] - return ipc_value > value.reduce(ipc_data) + def reduce(v0, v1): + return v0.reduce(ipc_data) > v1.reduce(ipc_data) return reduce class LtExpression(Expression): @@ -143,20 +162,60 @@ class LtExpression(Expression): def children(self): return self._children def reduce_function(self, ipc_data): - def reduce(key, value): - ipc_value = ipc_data - for k in key.reduce(ipc_data).strip('.').split('.'): - ipc_value = ipc_value[k] - return ipc_value < value.reduce(ipc_data) + def reduce(v0, v1): + return v0.reduce(ipc_data) < v1.reduce(ipc_data) return reduce +class LoadExpression(Expression): + def __init__(self, value, *args, **kwargs): + self._value = value + super().__init__(*args, **kwargs) + def __repr__(self) -> str: + return f'(load {self._value})' + @property + def should_call(self): + return True + @property + def children(self): + return [self._value] + def call(self, ipc_data): + ipc_value = ipc_data + for k in self._value[0].children[0].strip('.').split('.'): + ipc_value = ipc_value[k] + return ipc_value + +class HasKeyExpression(Expression): + def __init__(self, value, *args, **kwargs): + self._value = value + super().__init__(*args, **kwargs) + def __repr__(self) -> str: + return f'(has-key {self._value})' + @property + def should_call(self): + return True + @property + def children(self): + return [self._value] + def call(self, ipc_data): + ipc_value = ipc_data + for k in self._value[0].children[0].strip('.').split('.'): + try: + ipc_value = ipc_value[k] + except KeyError: + return False + return True + + expression_mapping = { '&': AndExpression, '|': OrExpression, + '?': IfExpression, '=': EqExpression, '!=': NeqExpression, '>': GtExpression, '<': LtExpression, + 'load': LoadExpression, + 'has-key': HasKeyExpression, } def group_tokens(tokens: list[str]) -> list[list[str]]: @@ -508,3 +567,4 @@ def config(ctx, config): if __name__ == '__main__': main() +