manen

Fancy Lua REPL
Log | Files | Refs | README | LICENSE

completion.rs (14310B)


      1 use std::sync::Arc;
      2 
      3 use emmylua_parser::{
      4     LuaAst, LuaAstNode, LuaAstToken, LuaBlock, LuaExpr, LuaIndexExpr, LuaNameExpr, LuaParser,
      5     LuaSyntaxTree, LuaTokenKind,
      6 };
      7 use mlua::prelude::*;
      8 use reedline::{Completer, Span, Suggestion};
      9 use rowan::{TextRange, TextSize};
     10 
     11 use crate::{lua::LuaExecutor, parse};
     12 
/// A named local binding discovered while walking the syntax tree.
#[derive(Debug)]
struct Variable {
    /// Source range whose end marks where the name becomes visible; a position
    /// at or after `range.end()` can see the binding.
    range: TextRange,
    /// The identifier exactly as written in the source.
    name: String,
}
     18 
/// A Lua block together with the local bindings it introduces.
#[derive(Debug)]
struct Scope {
    /// Range of the block; positions inside it (end inclusive) may see `variables`.
    range: TextRange,
    /// Bindings introduced by the block's parent (closure parameters, `for`
    /// loop variables) followed by `local`/`local function` names declared
    /// directly in the block.
    variables: Vec<Variable>,
}
     24 
/// Reedline completer for Lua input.
///
/// Identifiers are completed from the locals in scope at the cursor plus the
/// executor's globals; index expressions (`t.` / `t.fi`) are completed from
/// the fields of the indexed table as found in the live Lua state.
pub struct LuaCompleter {
    /// Source of global names and table fields.
    lua_executor: Arc<dyn LuaExecutor>,
    /// Syntax tree of the most recently completed line.
    tree: LuaSyntaxTree,

    /// Lexical scopes resolved from `tree`.
    scopes: Vec<Scope>,
    /// The text `tree` was parsed from.
    text: String,
}
     32 
     33 impl LuaCompleter {
     34     pub fn new(lua_executor: Arc<dyn LuaExecutor>) -> Self {
     35         Self {
     36             lua_executor,
     37             tree: LuaParser::parse("", parse::config()),
     38             scopes: Vec::new(),
     39             text: String::new(),
     40         }
     41     }
     42 
     43     fn refresh_tree(&mut self, text: &str) {
     44         self.tree = LuaParser::parse(text, parse::config());
     45         self.text = text.to_string();
     46         self.scopes = self.resolve_scopes();
     47     }
     48 
     49     fn globals(&self) -> Vec<String> {
     50         if let Ok(globals) = self.lua_executor.globals() {
     51             globals
     52                 .pairs()
     53                 .flatten()
     54                 .map(|(k, _): (String, LuaValue)| k)
     55                 .collect()
     56         } else {
     57             Vec::new()
     58         }
     59     }
     60 
     61     fn resolve_scopes(&self) -> Vec<Scope> {
     62         let mut scopes = Vec::new();
     63 
     64         let chunk = self.tree.get_chunk_node();
     65 
     66         for scope in chunk.descendants::<LuaBlock>() {
     67             let mut variables = Vec::new();
     68 
     69             match scope.get_parent() {
     70                 Some(LuaAst::LuaClosureExpr(closure)) => {
     71                     if let Some(params) = closure.get_params_list() {
     72                         for param in params.get_params() {
     73                             if let Some(token) = param.get_name_token() {
     74                                 variables.push(Variable {
     75                                     range: param.get_range(),
     76                                     name: token.get_name_text().to_string(),
     77                                 });
     78                             }
     79                         }
     80                     }
     81                 }
     82                 Some(LuaAst::LuaForRangeStat(range)) => {
     83                     for token in range.get_var_name_list() {
     84                         variables.push(Variable {
     85                             range: token.get_range(),
     86                             name: token.get_name_text().to_string(),
     87                         })
     88                     }
     89                 }
     90                 Some(LuaAst::LuaForStat(stat)) => {
     91                     if let Some(token) = stat.get_var_name() {
     92                         variables.push(Variable {
     93                             range: token.get_range(),
     94                             name: token.get_name_text().to_string(),
     95                         });
     96                     }
     97                 }
     98                 _ => {}
     99             }
    100 
    101             for node in scope.children::<LuaAst>() {
    102                 match node {
    103                     LuaAst::LuaLocalFuncStat(stat) => {
    104                         if let Some(name) = stat.get_local_name() {
    105                             if let Some(token) = name.get_name_token() {
    106                                 variables.push(Variable {
    107                                     range: token.get_range(),
    108                                     name: token.get_name_text().to_string(),
    109                                 });
    110                             }
    111                         }
    112                     }
    113                     LuaAst::LuaLocalStat(stat) => {
    114                         for name in stat.get_local_name_list() {
    115                             if let Some(token) = name.get_name_token() {
    116                                 variables.push(Variable {
    117                                     range: stat.get_range(),
    118                                     name: token.get_name_text().to_string(),
    119                                 });
    120                             }
    121                         }
    122                     }
    123                     _ => {}
    124                 }
    125             }
    126 
    127             scopes.push(Scope {
    128                 range: scope.get_range(),
    129                 variables,
    130             });
    131         }
    132 
    133         scopes
    134     }
    135 
    136     fn locals(&self, position: u32) -> Vec<String> {
    137         let mut variables = Vec::new();
    138 
    139         for scope in self.scopes.iter() {
    140             if position >= scope.range.start().into() && position <= scope.range.end().into() {
    141                 for var in scope.variables.iter() {
    142                     if position >= var.range.end().into() {
    143                         variables.push(var.name.clone());
    144                     }
    145                 }
    146             }
    147         }
    148 
    149         variables
    150     }
    151 
    152     // okay not the correct terminology
    153     //
    154     // there are 3 kinds of variable
    155     // - local (current scope)
    156     // - global (_G/_ENV)
    157     // - upvalue (local of parent scope(s))
    158     //
    159     // well in 5.2+ its only local and upvalue since you upvalue _ENV
    160     // then you get the individual global variable
    161     //
    162     // in the code
    163     //
    164     // ```lua
    165     // local a = 1
    166     // b = 2
    167     //
    168     // local function _()
    169     //    local c = 3
    170     //    print(a, b, c)
    171     // end
    172     // ```
    173     //
    174     // the bytecode for the function is
    175     //
    176     // 1       [5]     LOADI           0 3
    177     // 2       [6]     GETTABUP        1 0 0   ; _ENV "print"
    178     // 3       [6]     GETUPVAL        2 1     ; a
    179     // 4       [6]     GETTABUP        3 0 1   ; _ENV "b"
    180     //
    181     // the local can be loaded with LOADI (load integer) while a and b
    182     // both have to be upvalued
    183     //
    184     // this is different in 5.1
    185     //
    186     // 1       [5]     LOADK           0 -1    ; 3
    187     // 2       [6]     GETGLOBAL       1 -2    ; print
    188     // 3       [6]     GETUPVAL        2 0     ; a
    189     // 4       [6]     GETGLOBAL       3 -3    ; b
    190     //
    191     // in 5.1, globals are treated uniquely and given their own opcode
    192     //
    193     // to summarize, this function is not properly named
    194     //
    195     // globals either exist or are an extension of _ENV
    196     fn autocomplete_upvalue(&self, query: &str, position: u32) -> Vec<String> {
    197         let mut upvalues = self.locals(position);
    198         upvalues.extend(self.globals());
    199         upvalues.sort();
    200 
    201         upvalues
    202             .into_iter()
    203             .filter(|s| s.starts_with(query))
    204             .collect()
    205     }
    206 
    207     fn table_index(&self, position: u32) -> Option<(TextRange, Vec<String>)> {
    208         let chunk = self.tree.get_chunk_node();
    209 
    210         for index in chunk.descendants::<LuaIndexExpr>() {
    211             let (range, name, is_dot) = index
    212                 .get_index_key()
    213                 .map(|k| k.get_range().map(|r| (r, k.get_path_part(), false)))
    214                 .unwrap_or_else(|| {
    215                     index.token_by_kind(LuaTokenKind::TkDot).map(|t| {
    216                         let range = t.get_range();
    217                         (
    218                             TextRange::new(range.start(), range.start() + TextSize::new(1)),
    219                             String::new(),
    220                             true,
    221                         )
    222                     })
    223                 })?;
    224 
    225             if position >= range.start().into() && position < range.end().into() {
    226                 let mut children: Vec<String> = Vec::new();
    227 
    228                 for parent_index in index.descendants::<LuaIndexExpr>() {
    229                     if let Some(token) = parent_index.get_name_token() {
    230                         children.push(token.get_name_text().to_string());
    231                     }
    232 
    233                     if let Some(LuaExpr::NameExpr(token)) = parent_index.get_prefix_expr() {
    234                         children.push(token.get_name_text()?);
    235                     }
    236                 }
    237 
    238                 if children.len() > 1 {
    239                     children.reverse();
    240                     children.pop();
    241                 }
    242 
    243                 let fields = if let Ok(globals) = self.lua_executor.globals() {
    244                     let mut var: LuaResult<LuaValue> = Ok(LuaValue::Table(globals));
    245 
    246                     for index in children.iter().rev() {
    247                         if let Ok(LuaValue::Table(tbl)) = var {
    248                             var = tbl.raw_get(index.as_str())
    249                         }
    250                     }
    251 
    252                     if let Ok(LuaValue::Table(tbl)) = var {
    253                         tbl.pairs()
    254                             .flatten()
    255                             .map(|(k, _): (String, LuaValue)| k)
    256                             .filter(|s| s.starts_with(&name))
    257                             .collect::<Vec<_>>()
    258                     } else {
    259                         Vec::new()
    260                     }
    261                 } else {
    262                     Vec::new()
    263                 };
    264 
    265                 if is_dot {
    266                     return Some((
    267                         TextRange::new(range.start() + TextSize::new(1), range.end()),
    268                         fields,
    269                     ));
    270                 } else {
    271                     return Some((range, fields));
    272                 }
    273             }
    274         }
    275 
    276         None
    277     }
    278 
    279     fn current_identifier(&self, position: u32) -> Option<(TextRange, String)> {
    280         let chunk = self.tree.get_chunk_node();
    281 
    282         for identifier in chunk.descendants::<LuaNameExpr>() {
    283             let range = identifier.get_range();
    284 
    285             if position >= range.start().into() && position < range.end().into() {
    286                 if let Some(name) = identifier.get_name_text() {
    287                     return Some((range, name));
    288                 } else {
    289                     return None;
    290                 }
    291             }
    292         }
    293 
    294         None
    295     }
    296 }
    297 
    298 impl Completer for LuaCompleter {
    299     fn complete(&mut self, line: &str, pos: usize) -> Vec<Suggestion> {
    300         let pos = pos as u32;
    301         self.refresh_tree(line);
    302 
    303         if let Some((range, current)) = self.current_identifier(pos.saturating_sub(1)) {
    304             return self
    305                 .autocomplete_upvalue(&current, pos)
    306                 .into_iter()
    307                 .map(|s| Suggestion {
    308                     value: s,
    309                     span: Span::new(range.start().into(), range.end().into()),
    310                     ..Default::default()
    311                 })
    312                 .collect();
    313         }
    314 
    315         if let Some((range, fields)) = self.table_index(pos.saturating_sub(1)) {
    316             return fields
    317                 .into_iter()
    318                 .map(|s| Suggestion {
    319                     value: s,
    320                     span: Span::new(range.start().into(), range.end().into()),
    321                     ..Default::default()
    322                 })
    323                 .collect();
    324         }
    325 
    326         Vec::new()
    327     }
    328 }
    329 
    330 #[cfg(test)]
    331 mod tests {
    332     use std::collections::HashMap;
    333 
    334     use crate::lua::MluaExecutor;
    335 
    336     use super::*;
    337 
    338     fn lua_executor() -> Arc<dyn LuaExecutor> {
    339         Arc::new(MluaExecutor::new())
    340     }
    341 
    342     fn line_to_position(line: usize, text: &str) -> u32 {
    343         let split = text.split("\n").collect::<Vec<_>>();
    344         split[0..line].join("\n").len() as u32
    345     }
    346 
    347     #[test]
    348     fn locals() {
    349         let mut completer = LuaCompleter::new(lua_executor());
    350 
    351         let text = r#"
    352         local function foo(a, b)
    353            -- 2: foo, a, b
    354            print(a, b)
    355         end
    356 
    357         -- 6: foo
    358 
    359         local function bar(c)
    360            -- 9: foo, bar, c
    361            print(c)
    362         end
    363 
    364         -- 13: foo, bar
    365 
    366         for i = 1, 10 do
    367            -- 16: foo, bar, i
    368            print(i)
    369         end
    370 
    371         -- 20: foo, bar
    372 
    373         for i, v in pairs(_G) do
    374            -- 23: foo, bar, i, v
    375            print(i, v)
    376         end
    377 
    378         -- 27: foo, bar
    379         "#;
    380 
    381         completer.refresh_tree(text);
    382 
    383         assert_eq!(
    384             &["foo", "a", "b"].as_slice(),
    385             &completer.locals(line_to_position(2, text)),
    386         );
    387 
    388         assert_eq!(
    389             &["foo"].as_slice(),
    390             &completer.locals(line_to_position(6, text)),
    391         );
    392 
    393         assert_eq!(
    394             &["foo", "bar", "c"].as_slice(),
    395             &completer.locals(line_to_position(9, text)),
    396         );
    397 
    398         assert_eq!(
    399             &["foo", "bar"].as_slice(),
    400             &completer.locals(line_to_position(13, text)),
    401         );
    402 
    403         assert_eq!(
    404             &["foo", "bar", "i"].as_slice(),
    405             &completer.locals(line_to_position(16, text)),
    406         );
    407 
    408         assert_eq!(
    409             &["foo", "bar"].as_slice(),
    410             &completer.locals(line_to_position(20, text)),
    411         );
    412 
    413         assert_eq!(
    414             &["foo", "bar", "i", "v"].as_slice(),
    415             &completer.locals(line_to_position(23, text)),
    416         );
    417 
    418         assert_eq!(
    419             &["foo", "bar"].as_slice(),
    420             &completer.locals(line_to_position(27, text)),
    421         );
    422     }
    423 
    424     #[test]
    425     fn upvalues() {
    426         let lua = lua_executor();
    427         lua.globals().unwrap().set("foobar", "").unwrap();
    428 
    429         let mut completer = LuaCompleter::new(lua);
    430 
    431         let text = r#"
    432         local function foo(a, fooing)
    433             local foobaz = 3
    434             -- 3: foo, foobar, fooing, foobaz
    435         end
    436         "#;
    437 
    438         completer.refresh_tree(text);
    439 
    440         assert_eq!(
    441             &["foo", "foobar", "foobaz", "fooing"]
    442                 .map(|s| s.to_string())
    443                 .as_slice(),
    444             &completer.autocomplete_upvalue("foo", line_to_position(3, text))
    445         );
    446     }
    447 
    448     #[test]
    449     fn table_index_query() {
    450         let lua = lua_executor();
    451 
    452         let mut completer = LuaCompleter::new(lua);
    453 
    454         completer.refresh_tree("print(table.ins");
    455 
    456         assert_eq!(
    457             &["insert"].map(|s| s.to_string()).as_slice(),
    458             &completer.table_index(14).map(|t| t.1).unwrap()
    459         );
    460     }
    461 
    462     #[test]
    463     fn table_index_all() {
    464         let lua = lua_executor();
    465 
    466         lua.globals()
    467             .unwrap()
    468             .set("foo", HashMap::from([("bar", 1), ("baz", 2), ("ipsum", 3)]))
    469             .unwrap();
    470 
    471         let mut completer = LuaCompleter::new(lua);
    472 
    473         completer.refresh_tree("print(foo.");
    474 
    475         let mut fields = completer.table_index(9).map(|t| t.1).unwrap();
    476         fields.sort();
    477 
    478         assert_eq!(
    479             &["bar", "baz", "ipsum"].map(|s| s.to_string()).as_slice(),
    480             &fields
    481         );
    482     }
    483 }