diff --git a/mypy/interpreted_plugin.py b/mypy/interpreted_plugin.py index 207e46f2b660..57fbf0a5c58d 100644 --- a/mypy/interpreted_plugin.py +++ b/mypy/interpreted_plugin.py @@ -67,3 +67,8 @@ def get_customize_class_mro_hook(self, fullname: str ) -> Optional[Callable[['mypy.plugin.ClassDefContext'], None]]: return None + + def get_dynamic_class_hook(self, fullname: str + ) -> Optional[Callable[['mypy.plugin.DynamicClassDefContext'], + None]]: + return None diff --git a/mypy/plugin.py b/mypy/plugin.py index a74d963ba43d..65acf5643136 100644 --- a/mypy/plugin.py +++ b/mypy/plugin.py @@ -9,7 +9,7 @@ from mypy.nodes import ( Expression, StrExpr, IntExpr, UnaryExpr, Context, DictExpr, ClassDef, - TypeInfo, SymbolTableNode, MypyFile + TypeInfo, SymbolTableNode, MypyFile, CallExpr ) from mypy.tvar_scope import TypeVarScope from mypy.types import ( @@ -69,6 +69,7 @@ class SemanticAnalyzerPluginInterface: modules = None # type: Dict[str, MypyFile] options = None # type: Options + cur_mod_id = None # type: str msg = None # type: MessageBuilder @abstractmethod @@ -117,6 +118,15 @@ def lookup_qualified(self, name: str, ctx: Context, def add_plugin_dependency(self, trigger: str, target: Optional[str] = None) -> None: raise NotImplementedError + @abstractmethod + def add_symbol_table_node(self, name: str, stnode: SymbolTableNode) -> None: + """Add node to global symbol table (or to nearest class if there is one).""" + raise NotImplementedError + + @abstractmethod + def qualified_name(self, n: str) -> str: + raise NotImplementedError + # A context for a function hook that infers the return type of a function with # a special signature. @@ -165,12 +175,21 @@ def add_plugin_dependency(self, trigger: str, target: Optional[str] = None) -> N # A context for a class hook that modifies the class definition. ClassDefContext = NamedTuple( - 'ClassDecoratorContext', [ + 'ClassDefContext', [ ('cls', ClassDef), # The class definition ('reason', Expression), # The expression being applied (decorator, metaclass, base class) ('api', SemanticAnalyzerPluginInterface) ]) +# A context for dynamic class definitions like +# Base = declarative_base() +DynamicClassDefContext = NamedTuple( + 'DynamicClassDefContext', [ + ('call', CallExpr), # The r.h.s. of dynamic class definition + ('name', str), # The name this class is being assigned to + ('api', SemanticAnalyzerPluginInterface) + ]) + class Plugin: """Base class of all type checker plugins. @@ -225,6 +244,10 @@ def get_customize_class_mro_hook(self, fullname: str ) -> Optional[Callable[[ClassDefContext], None]]: return None + def get_dynamic_class_hook(self, fullname: str + ) -> Optional[Callable[[DynamicClassDefContext], None]]: + return None + T = TypeVar('T') @@ -280,6 +303,10 @@ def get_customize_class_mro_hook(self, fullname: str ) -> Optional[Callable[[ClassDefContext], None]]: return self.plugin.get_customize_class_mro_hook(fullname) + def get_dynamic_class_hook(self, fullname: str + ) -> Optional[Callable[[DynamicClassDefContext], None]]: + return self.plugin.get_dynamic_class_hook(fullname) + class ChainedPlugin(Plugin): """A plugin that represents a sequence of chained plugins. @@ -337,6 +364,10 @@ def get_customize_class_mro_hook(self, fullname: str ) -> Optional[Callable[[ClassDefContext], None]]: return self._find_hook(lambda plugin: plugin.get_customize_class_mro_hook(fullname)) + def get_dynamic_class_hook(self, fullname: str + ) -> Optional[Callable[[DynamicClassDefContext], None]]: + return self._find_hook(lambda plugin: plugin.get_dynamic_class_hook(fullname)) + def _find_hook(self, lookup: Callable[[Plugin], T]) -> Optional[T]: for plugin in self._plugins: hook = lookup(plugin) diff --git a/mypy/semanal.py b/mypy/semanal.py index a4215073bf83..47a797504df1 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -80,7 +80,10 @@ from mypy.sametypes import is_same_type from mypy.options import Options from mypy import experiments -from mypy.plugin import Plugin, ClassDefContext, SemanticAnalyzerPluginInterface +from mypy.plugin import ( + Plugin, ClassDefContext, SemanticAnalyzerPluginInterface, + DynamicClassDefContext +) from mypy.util import get_prefix, correct_relative_import from mypy.semanal_shared import SemanticAnalyzerInterface, set_callable_name from mypy.scope import Scope @@ -1729,6 +1732,7 @@ def final_cb(keep_final: bool) -> None: # Store type into nodes. for lvalue in s.lvalues: self.store_declared_types(lvalue, s.type) + self.apply_dynamic_class_hook(s) self.check_and_set_up_type_alias(s) self.newtype_analyzer.process_newtype_declaration(s) self.process_typevar_declaration(s) @@ -1744,6 +1748,21 @@ def final_cb(keep_final: bool) -> None: isinstance(s.rvalue, (ListExpr, TupleExpr))): self.add_exports(s.rvalue.items) + def apply_dynamic_class_hook(self, s: AssignmentStmt) -> None: + if len(s.lvalues) > 1: + return + lval = s.lvalues[0] + if not isinstance(lval, NameExpr) or not isinstance(s.rvalue, CallExpr): + return + call = s.rvalue + if not isinstance(call.callee, RefExpr): + return + fname = call.callee.fullname + if fname: + hook = self.plugin.get_dynamic_class_hook(fname) + if hook: + hook(DynamicClassDefContext(call, lval.name, self)) + def unwrap_final(self, s: AssignmentStmt) -> None: """Strip Final[...] if present in an assignment. diff --git a/mypy/test/testdiff.py b/mypy/test/testdiff.py index 6e839b228b18..d4617c299b86 100644 --- a/mypy/test/testdiff.py +++ b/mypy/test/testdiff.py @@ -12,7 +12,7 @@ from mypy.server.astdiff import snapshot_symbol_table, compare_symbol_table_snapshots from mypy.test.config import test_temp_dir from mypy.test.data import DataDrivenTestCase, DataSuite -from mypy.test.helpers import assert_string_arrays_equal +from mypy.test.helpers import assert_string_arrays_equal, parse_options class ASTDiffSuite(DataSuite): @@ -22,9 +22,10 @@ def run_case(self, testcase: DataDrivenTestCase) -> None: first_src = '\n'.join(testcase.input) files_dict = dict(testcase.files) second_src = files_dict['tmp/next.py'] + options = parse_options(first_src, testcase, 1) - messages1, files1 = self.build(first_src) - messages2, files2 = self.build(second_src) + messages1, files1 = self.build(first_src, options) + messages2, files2 = self.build(second_src, options) a = [] if messages1: @@ -47,8 +48,8 @@ def run_case(self, testcase: DataDrivenTestCase) -> None: 'Invalid output ({}, line {})'.format(testcase.file, testcase.line)) - def build(self, source: str) -> Tuple[List[str], Optional[Dict[str, MypyFile]]]: - options = Options() + def build(self, source: str, + options: Options) -> Tuple[List[str], Optional[Dict[str, MypyFile]]]: options.use_builtins_fixtures = True options.show_traceback = True options.cache_dir = os.devnull diff --git a/test-data/unit/check-custom-plugin.test b/test-data/unit/check-custom-plugin.test index e01a004c9569..15ea5fb74775 100644 --- a/test-data/unit/check-custom-plugin.test +++ b/test-data/unit/check-custom-plugin.test @@ -280,3 +280,54 @@ reveal_type(FullyQualifiedTestNamedTuple('')._asdict()) # E: Revealed type is 'b [[mypy] plugins=/test-data/unit/plugins/fully_qualified_test_hook.py [builtins fixtures/classmethod.pyi] + +[case testDynamicClassPlugin] +# flags: --config-file tmp/mypy.ini +from mod import declarative_base, Column, Instr + +Base = declarative_base() + +class Model(Base): + x: Column[int] +class Other: + x: Column[int] + +reveal_type(Model().x) # E: Revealed type is 'mod.Instr[builtins.int]' +reveal_type(Other().x) # E: Revealed type is 'mod.Column[builtins.int]' +[file mod.py] +from typing import Generic, TypeVar +def declarative_base(): ... + +T = TypeVar('T') + +class Column(Generic[T]): ... +class Instr(Generic[T]): ... + +[file mypy.ini] +[[mypy] +plugins=/test-data/unit/plugins/dyn_class.py + +[case testDynamicClassPluginNegatives] +# flags: --config-file tmp/mypy.ini +from mod import declarative_base, Column, Instr, non_declarative_base + +Bad1 = non_declarative_base() +Bad2 = Bad3 = declarative_base() + +class C1(Bad1): ... # E: Invalid base class +class C2(Bad2): ... # E: Invalid base class +class C3(Bad3): ... # E: Invalid base class + +[file mod.py] +from typing import Generic, TypeVar +def declarative_base(): ... +def non_declarative_base(): ... + +T = TypeVar('T') + +class Column(Generic[T]): ... +class Instr(Generic[T]): ... + +[file mypy.ini] +[[mypy] +plugins=/test-data/unit/plugins/dyn_class.py diff --git a/test-data/unit/diff.test b/test-data/unit/diff.test index d5531118de65..ef3718522ce3 100644 --- a/test-data/unit/diff.test +++ b/test-data/unit/diff.test @@ -1070,3 +1070,42 @@ class C: pass [out] __main__.C.m + +[case testDynamicBasePluginDiff] +# flags: --config-file tmp/mypy.ini +from mod import declarative_base, Column, Instr + +Base = declarative_base() + +class Model(Base): + x: Column[int] +class Other: + x: Column[int] +class Diff: + x: Column[int] +[file next.py] +from mod import declarative_base, Column, Instr + +Base = declarative_base() + +class Model(Base): + x: Column[int] +class Other: + x: Column[int] +class Diff(Base): + x: Column[int] +[file mod.py] +from typing import Generic, TypeVar +def declarative_base(): ... + +T = TypeVar('T') + +class Column(Generic[T]): ... +class Instr(Generic[T]): ... + +[file mypy.ini] +[[mypy] +plugins=/test-data/unit/plugins/dyn_class.py +[out] +__main__.Diff +__main__.Diff.x diff --git a/test-data/unit/plugins/dyn_class.py b/test-data/unit/plugins/dyn_class.py new file mode 100644 index 000000000000..a1785c65d6c4 --- /dev/null +++ b/test-data/unit/plugins/dyn_class.py @@ -0,0 +1,47 @@ +from mypy.plugin import Plugin +from mypy.nodes import ( + ClassDef, Block, TypeInfo, SymbolTable, SymbolTableNode, GDEF, Var +) +from mypy.types import Instance + +DECL_BASES = set() + +class DynPlugin(Plugin): + def get_dynamic_class_hook(self, fullname): + if fullname == 'mod.declarative_base': + return add_info_hook + return None + + def get_base_class_hook(self, fullname: str): + if fullname in DECL_BASES: + return replace_col_hook + return None + +def add_info_hook(ctx): + class_def = ClassDef(ctx.name, Block([])) + class_def.fullname = ctx.api.qualified_name(ctx.name) + + info = TypeInfo(SymbolTable(), class_def, ctx.api.cur_mod_id) + class_def.info = info + obj = ctx.api.builtin_type('builtins.object') + info.mro = [info, obj.type] + info.bases = [obj] + ctx.api.add_symbol_table_node(ctx.name, SymbolTableNode(GDEF, info)) + DECL_BASES.add(class_def.fullname) + +def replace_col_hook(ctx): + info = ctx.cls.info + for sym in info.names.values(): + node = sym.node + if isinstance(node, Var) and isinstance(node.type, Instance): + if node.type.type.fullname() == 'mod.Column': + new_sym = ctx.api.lookup_fully_qualified_or_none('mod.Instr') + if new_sym: + new_info = new_sym.node + assert isinstance(new_info, TypeInfo) + node.type = Instance(new_info, node.type.args.copy(), + node.type.line, + node.type.column) + +def plugin(version): + return DynPlugin