Skip to content

Add plugin hook for dynamic class definition #5875

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Nov 9, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions mypy/interpreted_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,8 @@ def get_customize_class_mro_hook(self, fullname: str
) -> Optional[Callable[['mypy.plugin.ClassDefContext'],
None]]:
return None

def get_dynamic_class_hook(self, fullname: str
) -> Optional[Callable[['mypy.plugin.DynamicClassDefContext'],
None]]:
return None
35 changes: 33 additions & 2 deletions mypy/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from mypy.nodes import (
Expression, StrExpr, IntExpr, UnaryExpr, Context, DictExpr, ClassDef,
TypeInfo, SymbolTableNode, MypyFile
TypeInfo, SymbolTableNode, MypyFile, CallExpr
)
from mypy.tvar_scope import TypeVarScope
from mypy.types import (
Expand Down Expand Up @@ -69,6 +69,7 @@ class SemanticAnalyzerPluginInterface:

modules = None # type: Dict[str, MypyFile]
options = None # type: Options
cur_mod_id = None # type: str
msg = None # type: MessageBuilder

@abstractmethod
Expand Down Expand Up @@ -117,6 +118,15 @@ def lookup_qualified(self, name: str, ctx: Context,
def add_plugin_dependency(self, trigger: str, target: Optional[str] = None) -> None:
raise NotImplementedError

@abstractmethod
def add_symbol_table_node(self, name: str, stnode: SymbolTableNode) -> None:
"""Add node to global symbol table (or to nearest class if there is one)."""
raise NotImplementedError

@abstractmethod
def qualified_name(self, n: str) -> str:
raise NotImplementedError


# A context for a function hook that infers the return type of a function with
# a special signature.
Expand Down Expand Up @@ -165,12 +175,21 @@ def add_plugin_dependency(self, trigger: str, target: Optional[str] = None) -> N

# A context for a class hook that modifies the class definition.
ClassDefContext = NamedTuple(
'ClassDecoratorContext', [
'ClassDefContext', [
('cls', ClassDef), # The class definition
('reason', Expression), # The expression being applied (decorator, metaclass, base class)
('api', SemanticAnalyzerPluginInterface)
])

# A context for dynamic class definitions like
# Base = declarative_base()
DynamicClassDefContext = NamedTuple(
'DynamicClassDefContext', [
('call', CallExpr), # The r.h.s. of dynamic class definition
('name', str), # The name this class is being assigned to
('api', SemanticAnalyzerPluginInterface)
])


class Plugin:
"""Base class of all type checker plugins.
Expand Down Expand Up @@ -225,6 +244,10 @@ def get_customize_class_mro_hook(self, fullname: str
) -> Optional[Callable[[ClassDefContext], None]]:
return None

def get_dynamic_class_hook(self, fullname: str
) -> Optional[Callable[[DynamicClassDefContext], None]]:
return None


T = TypeVar('T')

Expand Down Expand Up @@ -280,6 +303,10 @@ def get_customize_class_mro_hook(self, fullname: str
) -> Optional[Callable[[ClassDefContext], None]]:
return self.plugin.get_customize_class_mro_hook(fullname)

def get_dynamic_class_hook(self, fullname: str
) -> Optional[Callable[[DynamicClassDefContext], None]]:
return self.plugin.get_dynamic_class_hook(fullname)


class ChainedPlugin(Plugin):
"""A plugin that represents a sequence of chained plugins.
Expand Down Expand Up @@ -337,6 +364,10 @@ def get_customize_class_mro_hook(self, fullname: str
) -> Optional[Callable[[ClassDefContext], None]]:
return self._find_hook(lambda plugin: plugin.get_customize_class_mro_hook(fullname))

def get_dynamic_class_hook(self, fullname: str
) -> Optional[Callable[[DynamicClassDefContext], None]]:
return self._find_hook(lambda plugin: plugin.get_dynamic_class_hook(fullname))

def _find_hook(self, lookup: Callable[[Plugin], T]) -> Optional[T]:
for plugin in self._plugins:
hook = lookup(plugin)
Expand Down
21 changes: 20 additions & 1 deletion mypy/semanal.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,10 @@
from mypy.sametypes import is_same_type
from mypy.options import Options
from mypy import experiments
from mypy.plugin import Plugin, ClassDefContext, SemanticAnalyzerPluginInterface
from mypy.plugin import (
Plugin, ClassDefContext, SemanticAnalyzerPluginInterface,
DynamicClassDefContext
)
from mypy.util import get_prefix, correct_relative_import
from mypy.semanal_shared import SemanticAnalyzerInterface, set_callable_name
from mypy.scope import Scope
Expand Down Expand Up @@ -1729,6 +1732,7 @@ def final_cb(keep_final: bool) -> None:
# Store type into nodes.
for lvalue in s.lvalues:
self.store_declared_types(lvalue, s.type)
self.apply_dynamic_class_hook(s)
self.check_and_set_up_type_alias(s)
self.newtype_analyzer.process_newtype_declaration(s)
self.process_typevar_declaration(s)
Expand All @@ -1744,6 +1748,21 @@ def final_cb(keep_final: bool) -> None:
isinstance(s.rvalue, (ListExpr, TupleExpr))):
self.add_exports(s.rvalue.items)

def apply_dynamic_class_hook(self, s: AssignmentStmt) -> None:
if len(s.lvalues) > 1:
return
lval = s.lvalues[0]
if not isinstance(lval, NameExpr) or not isinstance(s.rvalue, CallExpr):
return
call = s.rvalue
if not isinstance(call.callee, RefExpr):
return
fname = call.callee.fullname
if fname:
hook = self.plugin.get_dynamic_class_hook(fname)
if hook:
hook(DynamicClassDefContext(call, lval.name, self))

def unwrap_final(self, s: AssignmentStmt) -> None:
"""Strip Final[...] if present in an assignment.

Expand Down
11 changes: 6 additions & 5 deletions mypy/test/testdiff.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from mypy.server.astdiff import snapshot_symbol_table, compare_symbol_table_snapshots
from mypy.test.config import test_temp_dir
from mypy.test.data import DataDrivenTestCase, DataSuite
from mypy.test.helpers import assert_string_arrays_equal
from mypy.test.helpers import assert_string_arrays_equal, parse_options


class ASTDiffSuite(DataSuite):
Expand All @@ -22,9 +22,10 @@ def run_case(self, testcase: DataDrivenTestCase) -> None:
first_src = '\n'.join(testcase.input)
files_dict = dict(testcase.files)
second_src = files_dict['tmp/next.py']
options = parse_options(first_src, testcase, 1)

messages1, files1 = self.build(first_src)
messages2, files2 = self.build(second_src)
messages1, files1 = self.build(first_src, options)
messages2, files2 = self.build(second_src, options)

a = []
if messages1:
Expand All @@ -47,8 +48,8 @@ def run_case(self, testcase: DataDrivenTestCase) -> None:
'Invalid output ({}, line {})'.format(testcase.file,
testcase.line))

def build(self, source: str) -> Tuple[List[str], Optional[Dict[str, MypyFile]]]:
options = Options()
def build(self, source: str,
options: Options) -> Tuple[List[str], Optional[Dict[str, MypyFile]]]:
options.use_builtins_fixtures = True
options.show_traceback = True
options.cache_dir = os.devnull
Expand Down
51 changes: 51 additions & 0 deletions test-data/unit/check-custom-plugin.test
Original file line number Diff line number Diff line change
Expand Up @@ -280,3 +280,54 @@ reveal_type(FullyQualifiedTestNamedTuple('')._asdict()) # E: Revealed type is 'b
[[mypy]
plugins=<ROOT>/test-data/unit/plugins/fully_qualified_test_hook.py
[builtins fixtures/classmethod.pyi]

[case testDynamicClassPlugin]
# flags: --config-file tmp/mypy.ini
from mod import declarative_base, Column, Instr

Base = declarative_base()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also test a few different assignment statements and ensure that the plugin doesn't get run them. Maybe things like A = B = ... and C = <other_function>(). If the plugin doesn't get run, the lvalue shouldn't be a valid base class.


class Model(Base):
x: Column[int]
class Other:
x: Column[int]

reveal_type(Model().x) # E: Revealed type is 'mod.Instr[builtins.int]'
reveal_type(Other().x) # E: Revealed type is 'mod.Column[builtins.int]'
[file mod.py]
from typing import Generic, TypeVar
def declarative_base(): ...

T = TypeVar('T')

class Column(Generic[T]): ...
class Instr(Generic[T]): ...

[file mypy.ini]
[[mypy]
plugins=<ROOT>/test-data/unit/plugins/dyn_class.py

[case testDynamicClassPluginNegatives]
# flags: --config-file tmp/mypy.ini
from mod import declarative_base, Column, Instr, non_declarative_base

Bad1 = non_declarative_base()
Bad2 = Bad3 = declarative_base()

class C1(Bad1): ... # E: Invalid base class
class C2(Bad2): ... # E: Invalid base class
class C3(Bad3): ... # E: Invalid base class

[file mod.py]
from typing import Generic, TypeVar
def declarative_base(): ...
def non_declarative_base(): ...

T = TypeVar('T')

class Column(Generic[T]): ...
class Instr(Generic[T]): ...

[file mypy.ini]
[[mypy]
plugins=<ROOT>/test-data/unit/plugins/dyn_class.py
39 changes: 39 additions & 0 deletions test-data/unit/diff.test
Original file line number Diff line number Diff line change
Expand Up @@ -1070,3 +1070,42 @@ class C:
pass
[out]
__main__.C.m

[case testDynamicBasePluginDiff]
# flags: --config-file tmp/mypy.ini
from mod import declarative_base, Column, Instr

Base = declarative_base()

class Model(Base):
x: Column[int]
class Other:
x: Column[int]
class Diff:
x: Column[int]
[file next.py]
from mod import declarative_base, Column, Instr

Base = declarative_base()

class Model(Base):
x: Column[int]
class Other:
x: Column[int]
class Diff(Base):
x: Column[int]
[file mod.py]
from typing import Generic, TypeVar
def declarative_base(): ...

T = TypeVar('T')

class Column(Generic[T]): ...
class Instr(Generic[T]): ...

[file mypy.ini]
[[mypy]
plugins=<ROOT>/test-data/unit/plugins/dyn_class.py
[out]
__main__.Diff
__main__.Diff.x
47 changes: 47 additions & 0 deletions test-data/unit/plugins/dyn_class.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from mypy.plugin import Plugin
from mypy.nodes import (
ClassDef, Block, TypeInfo, SymbolTable, SymbolTableNode, GDEF, Var
)
from mypy.types import Instance

DECL_BASES = set()

class DynPlugin(Plugin):
def get_dynamic_class_hook(self, fullname):
if fullname == 'mod.declarative_base':
return add_info_hook
return None

def get_base_class_hook(self, fullname: str):
if fullname in DECL_BASES:
return replace_col_hook
return None

def add_info_hook(ctx):
class_def = ClassDef(ctx.name, Block([]))
class_def.fullname = ctx.api.qualified_name(ctx.name)

info = TypeInfo(SymbolTable(), class_def, ctx.api.cur_mod_id)
class_def.info = info
obj = ctx.api.builtin_type('builtins.object')
info.mro = [info, obj.type]
info.bases = [obj]
ctx.api.add_symbol_table_node(ctx.name, SymbolTableNode(GDEF, info))
DECL_BASES.add(class_def.fullname)

def replace_col_hook(ctx):
info = ctx.cls.info
for sym in info.names.values():
node = sym.node
if isinstance(node, Var) and isinstance(node.type, Instance):
if node.type.type.fullname() == 'mod.Column':
new_sym = ctx.api.lookup_fully_qualified_or_none('mod.Instr')
if new_sym:
new_info = new_sym.node
assert isinstance(new_info, TypeInfo)
node.type = Instance(new_info, node.type.args.copy(),
node.type.line,
node.type.column)

def plugin(version):
return DynPlugin