|
- # Copyright (c) Microsoft Corporation. All rights reserved.
- # Licensed under the MIT License.
-
- import ast
- import json
- import sys
-
-
class Visitor(ast.NodeVisitor):
    """Collect the classes, methods and functions defined in a module.

    Only definitions that appear directly in a module body or in a class
    body are recorded. Definitions nested inside functions, or inside
    compound statements such as ``if``/``try``, are not visited.

    Results accumulate in ``self.symbols``, a dict with the keys
    ``"classes"``, ``"methods"`` and ``"functions"``. Each key maps to a
    list of records built by :meth:`getDataObject`.
    """

    # ``ast.AsyncFunctionDef`` does not exist on older interpreters. Build the
    # tuple once instead of wrapping every isinstance check in a try/except
    # that would also hide genuine errors raised while recording a symbol.
    _FUNCTION_TYPES = (ast.FunctionDef,) + (
        (ast.AsyncFunctionDef,) if hasattr(ast, "AsyncFunctionDef") else ()
    )

    def __init__(self):
        self.symbols = {"classes": [], "methods": [], "functions": []}

    def visit_Module(self, node):
        # Dispatched by NodeVisitor.visit() for the parsed module: walk the
        # top-level statements with an empty namespace.
        self.visitChildren(node)

    def visitChildren(self, node, namespace=""):
        """Record every function and class defined directly in ``node.body``.

        ``namespace`` is the ``::``-joined chain of enclosing class names,
        or ``""`` at module level.
        """
        for child in node.body:
            if isinstance(child, self._FUNCTION_TYPES):
                self.visitDef(child, namespace)
            elif isinstance(child, ast.ClassDef):
                self.visitClassDef(child, namespace)

    def visitDef(self, node, namespace=""):
        """Record a (possibly async) function definition.

        The record goes under ``"functions"`` at module level and under
        ``"methods"`` when the definition sits inside a class body.
        """
        symbol = "functions" if namespace == "" else "methods"
        self.symbols[symbol].append(self.getDataObject(node, namespace))

    def visitClassDef(self, node, namespace=""):
        """Record a class, then the functions and classes in its body."""
        self.symbols["classes"].append(self.getDataObject(node, namespace))

        # Members of this class are reported under "<outer>::<this class>".
        if namespace:
            namespace = "{0}::{1}".format(namespace, node.name)
        else:
            namespace = node.name
        self.visitChildren(node, namespace)

    def getDataObject(self, node, namespace=""):
        """Build the JSON-serializable record for a class or function node.

        Line numbers in ``range`` are zero-based (``ast`` line numbers are
        one-based). The ``character`` values are the ``ast`` column offsets,
        which ``ast`` reports as UTF-8 byte offsets. They therefore differ
        from character counts on lines that contain non-ASCII text.
        """
        end_position = self.getEndPosition(node)
        return {
            "namespace": namespace,
            "name": node.name,
            "range": {
                "start": {
                    "line": node.lineno - 1,
                    "character": node.col_offset
                },
                "end": {
                    "line": end_position[0],
                    "character": end_position[1]
                }
            }
        }

    def getEndPosition(self, node):
        """Return the zero-based ``(line, character)`` at which ``node`` ends.

        When the parser records exact end positions (``end_lineno`` and
        ``end_col_offset``, Python 3.8+), they are used directly. Otherwise
        the result is the start of the textually last nested statement.
        The search considers ``else``, ``except`` and ``finally`` clauses,
        not only the main body.
        """
        end_lineno = getattr(node, "end_lineno", None)
        end_col_offset = getattr(node, "end_col_offset", None)
        if end_lineno is not None and end_col_offset is not None:
            return (end_lineno - 1, end_col_offset)

        # Statement-list fields in reverse source order: the first non-empty
        # one contains the last statement of ``node`` (e.g. try: body,
        # handlers, orelse, finalbody; if/for/while: body, orelse).
        for field in ("finalbody", "orelse", "handlers", "body"):
            children = getattr(node, field, None)
            if isinstance(children, list) and children:
                return self.getEndPosition(children[-1])
        return (node.lineno - 1, node.col_offset)
-
-
def provide_symbols(source):
    """Write the classes, methods and functions defined in ``source`` to stdout.

    ``source`` is Python source code, in any form ``ast.parse`` accepts.
    The symbols are written as a single JSON object with the keys
    ``"classes"``, ``"methods"`` and ``"functions"``. Each key maps to a list
    of objects holding a ``"namespace"``, a ``"name"`` and a zero-based
    ``"range"`` with ``"start"``/``"end"`` positions.

    Returns the same symbols dict that was written, so in-process callers
    can use it without re-parsing stdout. Any ``SyntaxError`` raised by
    ``ast.parse`` for invalid source propagates to the caller.
    """
    tree = ast.parse(source)
    visitor = Visitor()
    visitor.visit(tree)
    sys.stdout.write(json.dumps(visitor.symbols))
    # Flush so a parent process reading the pipe sees the complete payload.
    sys.stdout.flush()
    return visitor.symbols
-
-
if __name__ == "__main__":
    # Arguments: <path> [<contents>]. When <contents> is supplied it is
    # parsed instead of the file at <path>.
    if len(sys.argv) < 2:
        sys.stderr.write("usage: {0} <path> [<contents>]\n".format(sys.argv[0]))
        sys.exit(2)

    if len(sys.argv) == 3:
        contents = sys.argv[2]
        try:
            # On POSIX, Python 3 decodes argv with the 'surrogateescape'
            # handler, so undecodable bytes show up as lone surrogates.
            # Recover the raw bytes and decode them again with 'replace',
            # so the text contains U+FFFD instead of unparseable surrogates.
            default_encoding = sys.getdefaultencoding()
            encoded_contents = contents.encode(default_encoding, 'surrogateescape')
            contents = encoded_contents.decode(default_encoding, 'replace')
        except (UnicodeError, LookupError):
            # e.g. Python 2, where 'surrogateescape' is unavailable.
            pass
        if isinstance(contents, bytes):
            # Python 2: argv items are byte strings.
            contents = contents.decode('utf8')
    else:
        # Read raw bytes so decoding does not depend on the locale's default
        # text encoding (e.g. cp1252 on Windows). Given bytes, the Python 3
        # parser applies PEP 263 coding cookies and the UTF-8 BOM, and
        # defaults to UTF-8.
        with open(sys.argv[1], "rb") as source:
            contents = source.read()
        if sys.version_info[0] < 3:
            # Python 2 rejects non-ASCII bytes that lack a coding cookie,
            # so keep parsing the file as UTF-8 text there, as before.
            contents = contents.decode('utf8')
    provide_symbols(contents)
|