diff --git a/edb/edgeql-parser/src/keywords.rs b/edb/edgeql-parser/src/keywords.rs index 3d55fd69955..a2d760c8749 100644 --- a/edb/edgeql-parser/src/keywords.rs +++ b/edb/edgeql-parser/src/keywords.rs @@ -20,6 +20,7 @@ pub const UNRESERVED_KEYWORDS: &[&str] = &[ "cube", "current", "database", + "namespace", "ddl", "declare", "default", @@ -99,6 +100,8 @@ pub const UNRESERVED_KEYWORDS: &[&str] = &[ "version", "view", "write", + "use", + "show", ]; diff --git a/edb/edgeql/ast.py b/edb/edgeql/ast.py index dff05cf2491..775a1f74d26 100644 --- a/edb/edgeql/ast.py +++ b/edb/edgeql/ast.py @@ -226,6 +226,10 @@ class SessionResetAllAliases(BaseSessionReset): pass +class UseNameSpaceCommand(BaseSessionCommand): + name: str + + class BaseObjectRef(Base): __abstract_node__ = True @@ -836,6 +840,20 @@ class DropDatabase(DropObject, DatabaseCommand): pass +class NameSpaceCommand(ExternalObjectCommand): + __abstract_node__ = True + object_class: qltypes.SchemaObjectClass = ( + qltypes.SchemaObjectClass.NAMESPACE) + + +class CreateNameSpace(CreateObject, NameSpaceCommand): + pass + + +class DropNameSpace(DropObject, NameSpaceCommand): + pass + + class ExtensionPackageCommand(GlobalObjectCommand): __abstract_node__ = True object_class: qltypes.SchemaObjectClass = ( diff --git a/edb/edgeql/codegen.py b/edb/edgeql/codegen.py index 55ddb58a8fc..65589a4e240 100644 --- a/edb/edgeql/codegen.py +++ b/edb/edgeql/codegen.py @@ -1033,6 +1033,12 @@ def visit_AlterDatabase(self, node: qlast.AlterDatabase) -> None: def visit_DropDatabase(self, node: qlast.DropDatabase) -> None: self._visit_DropObject(node, 'DATABASE') + def visit_CreateNameSpace(self, node: qlast.CreateNameSpace) -> None: + self._visit_CreateObject(node, 'NAMESPACE') + + def visit_DropNameSpace(self, node: qlast.DropNameSpace) -> None: + self._visit_DropObject(node, 'NAMESPACE') + def visit_CreateRole(self, node: qlast.CreateRole) -> None: after_name = lambda: self._ddl_visit_bases(node) keywords = [] @@ -2185,6 +2191,13 @@ def visit_SessionResetAliasDecl( self._write_keywords('RESET ALIAS ') self.write(node.alias) + def visit_UseNameSpaceCommand( + self, + node: qlast.UseNameSpaceCommand + ) -> None: + self._write_keywords('USE NAMESPACE ') + self.write(node.name) + def visit_StartTransaction(self, node: qlast.StartTransaction) -> None: self._write_keywords('START TRANSACTION') diff --git a/edb/edgeql/parser/grammar/ddl.py b/edb/edgeql/parser/grammar/ddl.py index a616ce3ac8a..d319f386449 100644 --- a/edb/edgeql/parser/grammar/ddl.py +++ b/edb/edgeql/parser/grammar/ddl.py @@ -56,6 +56,9 @@ class DDLStmt(Nonterm): def reduce_DatabaseStmt(self, *kids): self.val = kids[0].val + def reduce_NameSpaceStmt(self, *kids): + self.val = kids[0].val + def reduce_RoleStmt(self, *kids): self.val = kids[0].val @@ -665,6 +668,38 @@ def reduce_DROP_DATABASE_DatabaseName(self, *kids): self.val = qlast.DropDatabase(name=kids[2].val) +# +# NAMESPACE +# +class NameSpaceStmt(Nonterm): + + def reduce_CreateNameSpaceStmt(self, *kids): + self.val = kids[0].val + + def reduce_DropNameSpaceStmt(self, *kids): + self.val = kids[0].val + + +class CreateNameSpaceStmt(Nonterm): + def reduce_CREATE_NAMESPACE_Identifier(self, *kids): + self.val = qlast.CreateNameSpace( + name=qlast.ObjectRef( + module=None, + name=kids[2].val + ) + ) + + +class DropNameSpaceStmt(Nonterm): + def reduce_DROP_NAMESPACE_Identifier(self, *kids): + self.val = qlast.DropNameSpace( + name=qlast.ObjectRef( + module=None, + name=kids[2].val + ) + ) + + # # EXTENSION PACKAGE # diff --git 
a/edb/edgeql/parser/grammar/session.py b/edb/edgeql/parser/grammar/session.py index 4d7e932201f..1fc66ea8982 100644 --- a/edb/edgeql/parser/grammar/session.py +++ b/edb/edgeql/parser/grammar/session.py @@ -18,9 +18,7 @@ from __future__ import annotations -from edb.edgeql import ast as qlast - -from .expressions import Nonterm +from edb.pgsql import common as pg_common from .tokens import * # NOQA from .expressions import * # NOQA @@ -32,6 +30,12 @@ def reduce_SetStmt(self, *kids): def reduce_ResetStmt(self, *kids): self.val = kids[0].val + def reduce_UseNameSpaceStmt(self, *kids): + self.val = kids[0].val + + def reduce_ShowNameSpaceStmt(self, *kids): + self.val = kids[0].val + class SetStmt(Nonterm): def reduce_SET_ALIAS_Identifier_AS_MODULE_ModuleName(self, *kids): @@ -54,3 +58,15 @@ def reduce_RESET_MODULE(self, *kids): def reduce_RESET_ALIAS_STAR(self, *kids): self.val = qlast.SessionResetAllAliases() + + +class UseNameSpaceStmt(Nonterm): + def reduce_USE_NAMESPACE_Identifier(self, *kids): + self.val = qlast.UseNameSpaceCommand(name=kids[2].val) + + +class ShowNameSpaceStmt(Nonterm): + def reduce_SHOW_NAMESPACE(self, *kids): + self.val = qlast.SelectQuery( + result=qlast.StringConstant(value=pg_common.NAMESPACE) + ) diff --git a/edb/edgeql/qltypes.py b/edb/edgeql/qltypes.py index 090c8e8a651..42048e5041a 100644 --- a/edb/edgeql/qltypes.py +++ b/edb/edgeql/qltypes.py @@ -237,6 +237,7 @@ class SchemaObjectClass(s_enum.StrEnum): SCALAR_TYPE = 'SCALAR TYPE' TUPLE_TYPE = 'TUPLE TYPE' TYPE = 'TYPE' + NAMESPACE = 'NAMESPACE' class LinkTargetDeleteAction(s_enum.StrEnum): diff --git a/edb/errors/__init__.py b/edb/errors/__init__.py index 68847a18cf6..4fcf1d9ec54 100644 --- a/edb/errors/__init__.py +++ b/edb/errors/__init__.py @@ -213,6 +213,10 @@ class UnknownParameterError(InvalidReferenceError): _code = 0x_04_03_00_06 +class UnknownSchemaError(InvalidReferenceError): + _code = 0x_04_03_00_07 + + class SchemaError(QueryError): _code = 0x_04_04_00_00 @@ -309,6 +313,10 @@ class DuplicateCastDefinitionError(DuplicateDefinitionError): _code = 0x_04_05_02_0A +class DuplicateNameSpaceDefinitionError(DuplicateDefinitionError): + _code = 0x_04_05_02_0B + + class SessionTimeoutError(QueryError): _code = 0x_04_06_00_00 diff --git a/edb/graphql/compiler.py b/edb/graphql/compiler.py index 781a9edc82f..a5fbf16e7ce 100644 --- a/edb/graphql/compiler.py +++ b/edb/graphql/compiler.py @@ -30,7 +30,7 @@ GQLCoreCache: Dict[ - str, + Tuple[str, str], Dict[ (s_schema.FlatSchema, uuid.UUID, s_schema.FlatSchema, str), graphql.GQLCoreSchema @@ -40,19 +40,20 @@ def _get_gqlcore( dbname: str, + namespace: str, std_schema: s_schema.FlatSchema, user_schema: s_schema.FlatSchema, global_schema: s_schema.FlatSchema, module: str = None ) -> graphql.GQLCoreSchema: key = (std_schema, user_schema.version_id, global_schema, module) - if cache := GQLCoreCache.get(dbname): + if cache := GQLCoreCache.get((dbname, namespace)): if key in cache: return cache[key] else: cache.clear() else: - cache = GQLCoreCache.setdefault(dbname, {}) + cache = GQLCoreCache.setdefault((dbname, namespace), {}) core = graphql.GQLCoreSchema( s_schema.ChainedSchema( @@ -68,6 +69,7 @@ def _get_gqlcore( def compile_graphql( dbname: str, + namespace: str, std_schema: s_schema.FlatSchema, user_schema: s_schema.FlatSchema, global_schema: s_schema.FlatSchema, @@ -88,7 +90,7 @@ def compile_graphql( else: ast = graphql.parse_tokens(gql, tokens) - gqlcore = _get_gqlcore(dbname, std_schema, user_schema, global_schema, module) + gqlcore = _get_gqlcore(dbname, 
namespace, std_schema, user_schema, global_schema, module) return graphql.translate_ast( gqlcore, diff --git a/edb/graphql/extension.pyx b/edb/graphql/extension.pyx index e4d19a15352..15b293ed1b5 100644 --- a/edb/graphql/extension.pyx +++ b/edb/graphql/extension.pyx @@ -97,6 +97,7 @@ async def handle_request( globals = None query = None module = None + namespace = edbdef.DEFAULT_NS limit = 0 try: @@ -111,6 +112,7 @@ async def handle_request( variables = body.get('variables') module = body.get('module') limit = body.get('limit', 0) + namespace = body.get('namespace', edbdef.DEFAULT_NS) globals = body.get('globals') elif request.content_type == 'application/graphql': query = request.body.decode('utf-8') @@ -157,6 +159,12 @@ async def handle_request( else: limit = 0 + namespace = qs.get('namespace') + if namespace is not None: + namespace = namespace[0] + else: + namespace = edbdef.DEFAULT_NS + else: raise TypeError('expected a GET or a POST request') @@ -186,7 +194,7 @@ async def handle_request( response.content_type = b'application/json' try: result = await _execute( - db, server, query, + db, namespace, server, query, operation_name, variables, globals, query_only, module or None, limit ) @@ -216,6 +224,7 @@ async def handle_request( async def compile( db, + ns, server, query: str, tokens: Optional[List[Tuple[int, int, int, str]]], @@ -229,9 +238,10 @@ async def compile( compiler_pool = server.get_compiler_pool() return await compiler_pool.compile_graphql( db.name, - db.user_schema, + ns.name, + ns.user_schema, server.get_global_schema(), - db.reflection_cache, + ns.reflection_cache, db.db_config, server.get_compilation_system_config(), query, @@ -246,9 +256,17 @@ async def compile( async def _execute( - db, server, query, operation_name, variables, + db, namespace, server, query, operation_name, variables, globals, query_only, module, limit ): + + if namespace not in db.ns_map: + raise errors.QueryError( + f'NameSpace: [{namespace}] not in current db [{db.name}](ver:{db.dbver}).' 
+ f'Current NameSpace(s): [{", ".join(db.ns_map.keys())}]' + ) + ns = db.ns_map[namespace] + dbver = db.dbver query_cache = server._http_query_cache @@ -298,7 +316,7 @@ async def _execute( print(f'key_vars: {key_var_names}') print(f'variables: {vars}') - cache_key = ('graphql', prepared_query, key_vars, operation_name, dbver, query_only, module, limit) + cache_key = ('graphql', prepared_query, key_vars, operation_name, dbver, query_only, namespace, module, limit) use_prep_stmt = False entry: CacheEntry = None @@ -307,13 +325,14 @@ async def _execute( if isinstance(entry, CacheRedirect): key_vars2 = tuple(vars[k] for k in entry.key_vars) - cache_key2 = (prepared_query, key_vars2, operation_name, dbver, query_only, module, limit) + cache_key2 = (prepared_query, key_vars2, operation_name, dbver, query_only, namespace, module, limit) entry = query_cache.get(cache_key2, None) if entry is None: if rewritten is not None: qug, gql_op = await compile( db, + ns, server, query, rewritten.tokens(gql_lexer.TokenKind), @@ -327,6 +346,7 @@ async def _execute( else: qug, gql_op = await compile( db, + ns, server, query, None, @@ -349,7 +369,7 @@ async def _execute( query_cache[cache_key] = redir key_vars2 = tuple(vars[k] for k in key_var_names) cache_key2 = ( - 'graphql', prepared_query, key_vars2, operation_name, dbver, query_only, module, limit + 'graphql', prepared_query, key_vars2, operation_name, dbver, query_only, namespace, module, limit ) query_cache[cache_key2] = qug, gql_op if gql_op.is_introspection: @@ -372,7 +392,7 @@ async def _execute( dbv = await server.new_dbview( dbname=db.name, query_cache=False, - protocol_version=edbdef.CURRENT_PROTOCOL, + protocol_version=edbdef.CURRENT_PROTOCOL ) pgcon = await server.acquire_pgcon(db.name) diff --git a/edb/ir/ast.py b/edb/ir/ast.py index 44851564843..b8b05615964 100644 --- a/edb/ir/ast.py +++ b/edb/ir/ast.py @@ -162,6 +162,8 @@ class TypeRef(ImmutableBase): is_opaque_union: bool = False # True, if this describes an sequnce type is_sequence: bool = False + # True, if this contains enums + has_enum: bool = False def __repr__(self) -> str: return f'' diff --git a/edb/ir/typeutils.py b/edb/ir/typeutils.py index 4d5dc81e5a5..b9a54eedbbf 100644 --- a/edb/ir/typeutils.py +++ b/edb/ir/typeutils.py @@ -322,6 +322,11 @@ def type_to_typeref( else: ancestors = None + if isinstance(t, s_scalars.ScalarType): + has_enum = (t.get_enum_values(schema) is not None) + else: + has_enum = False + result = irast.TypeRef( id=t.id, name_hint=name, @@ -338,6 +343,7 @@ def type_to_typeref( is_abstract=t.get_abstract(schema), is_view=t.is_view(schema), is_opaque_union=t.get_is_opaque_union(schema), + has_enum=has_enum, ) elif isinstance(t, s_types.Tuple) and t.is_named(schema): schema, material_type = t.material_type(schema) diff --git a/edb/lib/sys.edgeql b/edb/lib/sys.edgeql index 98738ca4d59..fc371c10562 100644 --- a/edb/lib/sys.edgeql +++ b/edb/lib/sys.edgeql @@ -37,6 +37,8 @@ CREATE TYPE sys::Database EXTENDING sys::SystemObject { }; }; +CREATE TYPE sys::NameSpace EXTENDING sys::SystemObject; + CREATE TYPE sys::ExtensionPackage EXTENDING sys::SystemObject { CREATE REQUIRED PROPERTY script -> str; diff --git a/edb/pgsql/codegen.py b/edb/pgsql/codegen.py index 586aeac7ab9..536ee05da3e 100644 --- a/edb/pgsql/codegen.py +++ b/edb/pgsql/codegen.py @@ -26,6 +26,7 @@ from edb.common.ast import codegen from edb.common import exceptions from edb.common import markup +from edb.schema import defines class SQLSourceGeneratorContext(markup.MarkupExceptionContext): @@ -118,7 +119,7 @@ 
def visit_Relation(self, node): if node.schemaname is None: self.write(common.qname(node.name)) else: - self.write(common.qname(node.schemaname, node.name)) + self.write(common.qname(common.actual_schemaname(node.schemaname), node.name)) def _visit_values_expr(self, node): self.new_lines = 1 diff --git a/edb/pgsql/common.py b/edb/pgsql/common.py index f63aa93bec5..30ac73d88d9 100644 --- a/edb/pgsql/common.py +++ b/edb/pgsql/common.py @@ -26,7 +26,7 @@ import re from edb.common import uuidgen -from edb.schema import abc as s_abc +from edb.schema import abc as s_abc, defines from edb.schema import casts as s_casts from edb.schema import constraints as s_constr from edb.schema import defines as s_def @@ -45,6 +45,7 @@ RE_LINK_TRIGGER = re.compile(r'(source|target)-del-(def|imm)-(inl|otl)-(f|t)') RE_DUNDER_TYPE_LINK_TRIGGER = re.compile(r'dunder-type-link-[ft]') +NAMESPACE = defines.DEFAULT_NS def quote_e_literal(string): @@ -127,9 +128,23 @@ def quote_type(type_): return first + last -def get_module_backend_name(module: s_name.Name) -> str: +def get_module_backend_name(module: s_name.Name, ignore_ns=False) -> str: # standard modules go into "edgedbstd", user ones into "edgedbpub" - return "edgedbstd" if module in s_schema.STD_MODULES else "edgedbpub" + if ignore_ns: + return "edgedbstd" if module in s_schema.STD_MODULES else "edgedbpub" + return actual_schemaname("edgedbstd") if module in s_schema.STD_MODULES else actual_schemaname("edgedbpub") + + +def actual_schemaname(name: str) -> str: + global NAMESPACE + if name not in defines.EDGEDB_OWNED_DBS: + return name + + if NAMESPACE == defines.DEFAULT_NS: + ns_prefix = '' + else: + ns_prefix = NAMESPACE + '_' + return f"{ns_prefix}{name}" def get_unique_random_name() -> str: @@ -175,8 +190,8 @@ def edgedb_name_to_pg_name(name: str, prefix_length: int = 0) -> str: return _edgedb_name_to_pg_name(name, prefix_length) -def convert_name(name, suffix='', catenate=True): - schema = get_module_backend_name(name.get_module_name()) +def convert_name(name, suffix='', catenate=True, ignore_ns=False): + schema = get_module_backend_name(name.get_module_name(), ignore_ns) if suffix: sname = f'{name.name}_{suffix}' else: @@ -210,7 +225,7 @@ def update_aspect(name, aspect): return (name[0], stripped) -def get_scalar_backend_name(id, module_name, catenate=True, *, aspect=None): +def get_scalar_backend_name(id, module_name, catenate=True, *, aspect=None, ignore_ns=False): if aspect is None: aspect = 'domain' if aspect not in ( @@ -220,7 +235,7 @@ def get_scalar_backend_name(id, module_name, catenate=True, *, aspect=None): raise ValueError( f'unexpected aspect for scalar backend name: {aspect!r}') name = s_name.QualName(module=module_name, name=str(id)) - return convert_name(name, aspect, catenate) + return convert_name(name, aspect, catenate, ignore_ns) def get_aspect_suffix(aspect): @@ -355,7 +370,7 @@ def get_index_backend_name(id, module_name, catenate=True, *, aspect=None): def get_tuple_backend_name(id, catenate=True, *, aspect=None): name = s_name.QualName(module='edgedb', name=f'{id}_t') - return convert_name(name, aspect, catenate) + return convert_name(name, aspect, catenate, ignore_ns=True) def get_backend_name(schema, obj, catenate=True, *, aspect=None): diff --git a/edb/pgsql/compiler/astutils.py b/edb/pgsql/compiler/astutils.py index 15af6760b91..8aa9db5b326 100644 --- a/edb/pgsql/compiler/astutils.py +++ b/edb/pgsql/compiler/astutils.py @@ -26,7 +26,7 @@ from edb.ir import typeutils as irtyputils -from edb.pgsql import ast as pgast +from edb.pgsql 
import ast as pgast, common from edb.pgsql import types as pg_types if TYPE_CHECKING: @@ -234,7 +234,7 @@ def safe_array_expr( ) if any(el.nullable for el in elements): result = pgast.FuncCall( - name=('edgedb', '_nullif_array_nulls'), + name=(common.actual_schemaname('edgedb'), '_nullif_array_nulls'), args=[result], ser_safe=ser_safe, ) diff --git a/edb/pgsql/compiler/config.py b/edb/pgsql/compiler/config.py index 98a6fe72eb6..15541e25ab0 100644 --- a/edb/pgsql/compiler/config.py +++ b/edb/pgsql/compiler/config.py @@ -412,7 +412,7 @@ def _rewrite_config_insert( overwrite_query = pgast.SelectStmt() id_expr = pgast.FuncCall( - name=('edgedbext', 'uuid_generate_v1mc',), + name=(common.actual_schemaname('edgedbext'), 'uuid_generate_v1mc',), args=[], ) pathctx.put_path_identity_var( diff --git a/edb/pgsql/compiler/dml.py b/edb/pgsql/compiler/dml.py index 981ab348155..58e45743707 100644 --- a/edb/pgsql/compiler/dml.py +++ b/edb/pgsql/compiler/dml.py @@ -1564,7 +1564,7 @@ def check_update_type( # also the (dynamic) type of the argument, so that we can produce # a good error message. check_result = pgast.FuncCall( - name=('edgedb', 'issubclass'), + name=(common.actual_schemaname('edgedb'), 'issubclass'), args=[typ, typeref_val], ) maybe_null = pgast.CaseExpr( @@ -2372,7 +2372,7 @@ def process_link_values( ): if src_prop.out_target.is_sequence: seq_backend_name = pgast.StringConstant( - val=f'"edgedbpub"."{src_prop.out_target.id}_sequence"' + val=f'"{common.actual_schemaname("edgedbpub")}"."{src_prop.out_target.id}_sequence"' ) source_val = pgast.FuncCall( name=('currval', ), diff --git a/edb/pgsql/compiler/expr.py b/edb/pgsql/compiler/expr.py index 3314443ea7c..65c96f22636 100644 --- a/edb/pgsql/compiler/expr.py +++ b/edb/pgsql/compiler/expr.py @@ -560,7 +560,7 @@ def compile_TypeCheckOp( right = dispatch.compile(expr.right, ctx=newctx) result = pgast.FuncCall( - name=('edgedb', 'issubclass'), + name=(common.actual_schemaname('edgedb'), 'issubclass'), args=[left, right]) if negated: diff --git a/edb/pgsql/compiler/output.py b/edb/pgsql/compiler/output.py index 37b5d603df3..e607bc2be8e 100644 --- a/edb/pgsql/compiler/output.py +++ b/edb/pgsql/compiler/output.py @@ -560,7 +560,7 @@ def serialize_expr_to_json( elif irtyputils.is_range(styperef) and not expr.ser_safe: val = pgast.FuncCall( # Use the actual generic helper for converting anyrange to jsonb - name=('edgedb', 'range_to_jsonb'), + name=(common.actual_schemaname('edgedb'), 'range_to_jsonb'), args=[expr], null_safe=True, ser_safe=True) elif irtyputils.is_collection(styperef) and not expr.ser_safe: diff --git a/edb/pgsql/compiler/relctx.py b/edb/pgsql/compiler/relctx.py index b01e3c9b289..552e5325322 100644 --- a/edb/pgsql/compiler/relctx.py +++ b/edb/pgsql/compiler/relctx.py @@ -441,7 +441,7 @@ def new_free_object_rvar( qry = subctx.rel id_expr = pgast.FuncCall( - name=('edgedbext', 'uuid_generate_v4',), args=[] + name=(common.actual_schemaname('edgedbext'), 'uuid_generate_v4',), args=[] ) pathctx.put_path_identity_var(qry, path_id, id_expr, env=ctx.env) @@ -824,7 +824,7 @@ def ensure_transient_identity_for_path( ) -> None: id_expr = pgast.FuncCall( - name=('edgedbext', 'uuid_generate_v4',), + name=(common.actual_schemaname('edgedbext'), 'uuid_generate_v4',), args=[], ) diff --git a/edb/pgsql/dbops/__init__.py b/edb/pgsql/dbops/__init__.py index dc8fdb4a0b4..6e9b1212b7b 100644 --- a/edb/pgsql/dbops/__init__.py +++ b/edb/pgsql/dbops/__init__.py @@ -40,3 +40,4 @@ from .triggers import * # NOQA from .types import * # NOQA from .views import * # 
NOQA +from .namespace import * # NOQA diff --git a/edb/pgsql/dbops/ddl.py b/edb/pgsql/dbops/ddl.py index 4618b17a2fd..5d092e5c1ec 100644 --- a/edb/pgsql/dbops/ddl.py +++ b/edb/pgsql/dbops/ddl.py @@ -27,6 +27,7 @@ from ..common import quote_ident as qi from ..common import quote_literal as ql +from ..common import actual_schemaname as actual from . import base @@ -115,7 +116,7 @@ def code(self, block: base.PLBlock) -> str: if is_shared: return textwrap.dedent(f'''\ SELECT - edgedb.shobj_metadata( + {actual("edgedb")}.shobj_metadata( {objoid}, {classoid}::regclass::text ) @@ -123,7 +124,7 @@ def code(self, block: base.PLBlock) -> str: elif objsubid: return textwrap.dedent(f'''\ SELECT - edgedb.col_metadata( + {actual("edgedb")}.col_metadata( {objoid}, {objsubid} ) @@ -131,7 +132,7 @@ def code(self, block: base.PLBlock) -> str: else: return textwrap.dedent(f'''\ SELECT - edgedb.obj_metadata( + {actual("edgedb")}.obj_metadata( {objoid}, {classoid}::regclass::text, ) @@ -149,7 +150,7 @@ def code(self, block: base.PLBlock) -> str: SELECT json FROM - edgedbinstdata.instdata + {actual("edgedbinstdata")}.instdata WHERE key = {ql(key)} ''') @@ -211,7 +212,7 @@ def code(self, block: base.PLBlock) -> str: metadata = ql(json.dumps(self.metadata)) return textwrap.dedent(f'''\ UPDATE - edgedbinstdata.instdata + {actual("edgedbinstdata")}.instdata SET json = {metadata} WHERE @@ -260,7 +261,7 @@ def code(self, block: base.PLBlock) -> str: return textwrap.dedent(f'''\ UPDATE - edgedbinstdata.instdata + {actual("edgedbinstdata")}.instdata SET json = {json_v} || {meta_v} WHERE @@ -329,7 +330,7 @@ def code(self, block: base.PLBlock) -> str: json_v, meta_v = self._merge(block) return textwrap.dedent(f'''\ UPDATE - edgedbinstdata.instdata + {actual("edgedbinstdata")}.instdata SET json = {json_v} || {meta_v} WHERE diff --git a/edb/pgsql/dbops/namespace.py b/edb/pgsql/dbops/namespace.py new file mode 100644 index 00000000000..678528886d6 --- /dev/null +++ b/edb/pgsql/dbops/namespace.py @@ -0,0 +1,69 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2008-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +from __future__ import annotations + +from typing import Optional, Mapping, Any + +from . import base +from . 
import ddl +from edb.pgsql.common import quote_ident as qi +from edb.schema.defines import DEFAULT_NS, EDGEDB_OWNED_DBS + + +class NameSpace(base.DBObject): + def __init__( + self, + name: str, + metadata: Optional[Mapping[str, Any]] = None, + ): + super().__init__(metadata=metadata) + self.name = name + + def get_type(self): + return 'SCHEMA' + + def get_id(self): + return qi(f"{self.name}_edgedb") if self.name != DEFAULT_NS else qi("edgedb") + + def is_shared(self) -> bool: + return False + + +class CreateNameSpace(ddl.CreateObject, ddl.NonTransactionalDDLOperation): + def __init__(self, object, **kwargs): + super().__init__(object, **kwargs) + + def code(self, block: base.PLBlock) -> str: + return '' + + +class DropNameSpace( + ddl.SchemaObjectOperation, + ddl.NonTransactionalDDLOperation +): + + def code(self, block: base.PLBlock) -> str: + schemas = ",".join( + [ + qi(f"{self.name}_{schema}") + for schema in EDGEDB_OWNED_DBS + ] + ) + return f'DROP SCHEMA {schemas} CASCADE;' diff --git a/edb/pgsql/dbops/roles.py b/edb/pgsql/dbops/roles.py index a5505869737..555a45a1994 100644 --- a/edb/pgsql/dbops/roles.py +++ b/edb/pgsql/dbops/roles.py @@ -25,6 +25,7 @@ from ..common import quote_ident as qi from ..common import quote_literal as ql +from ..common import actual_schemaname as actual from . import base from . import ddl @@ -153,7 +154,7 @@ def generate_extra(self, block: base.PLBlock) -> None: value = json.dumps(self.object.single_role_metadata) query = base.Query( f''' - UPDATE edgedbinstdata.instdata + UPDATE {actual("edgedbinstdata")}.instdata SET json = {ql(value)}::jsonb WHERE key = 'single_role_metadata' ''' diff --git a/edb/pgsql/delta.py b/edb/pgsql/delta.py index f34d1be3455..aa0b5dd54d3 100644 --- a/edb/pgsql/delta.py +++ b/edb/pgsql/delta.py @@ -52,6 +52,7 @@ from edb.schema import properties as s_props from edb.schema import migrations as s_migrations from edb.schema import modules as s_mod +from edb.schema import namespace as s_ns from edb.schema import name as sn from edb.schema import objects as so from edb.schema import operators as s_opers @@ -361,7 +362,7 @@ def apply( (SELECT version::text FROM - edgedb."_SchemaSchemaVersion" + {common.actual_schemaname('edgedb')}."_SchemaSchemaVersion" FOR UPDATE), {ql(str(expected_ver))} )), @@ -369,7 +370,7 @@ def apply( msg => ( 'Cannot serialize DDL: ' || (SELECT version::text FROM - edgedb."_SchemaSchemaVersion") + {common.actual_schemaname('edgedb')}."_SchemaSchemaVersion") ) ) INTO _dummy_text @@ -1170,7 +1171,7 @@ def compile_edgeql_overloaded_function_body( target AS ancestor, index FROM - edgedb."_SchemaObjectType__ancestors" + {common.actual_schemaname('edgedb')}."_SchemaObjectType__ancestors" WHERE source = {qi(type_param_name)} ) a WHERE ancestor IN ({impl_ids}) @@ -3540,8 +3541,9 @@ def apply( self.apply_scheduled_inhview_updates(schema, context) if is_external: - view_name = ('edgedbpub', str(objtype.id)) - view_name_t = ('edgedbpub', str(objtype.id) + '_t') + schema_name = common.actual_schemaname('edgedbpub') + view_name = (schema_name, str(objtype.id)) + view_name_t = (schema_name, str(objtype.id) + '_t') self.pgops.add( dbops.DropView( name=view_name, @@ -5159,8 +5161,9 @@ def apply( self.apply_scheduled_inhview_updates(schema, context) if has_extern_table: - view_name = ('edgedbpub', str(link.id)) - view_name_t = ('edgedbpub', str(link.id) + '_t') + schema_name = common.actual_schemaname('edgedbpub') + view_name = (schema_name, str(link.id)) + view_name_t = (schema_name, str(link.id) + '_t') self.pgops.add( 
dbops.DropView( name=view_name, @@ -5808,41 +5811,47 @@ def get_trigger_proc_text(self, target, links, *, target, links, disposition=disposition, schema=schema) def _get_dunder_type_trigger_proc_text(self, target, *, schema): - body = textwrap.dedent('''\ - SELECT - CASE WHEN tp.builtin - THEN 'edgedbstd' - ELSE 'edgedbpub' - END AS sname - INTO schema_name - FROM edgedb."_SchemaType" as tp - WHERE tp.id = OLD.id; - - SELECT EXISTS ( - SELECT FROM pg_tables - WHERE schemaname = "schema_name" - AND tablename = OLD.id::text - ) INTO table_exists; - - IF table_exists THEN - target_sql = format('SELECT EXISTS (SELECT FROM %I.%I LIMIT 1)', "schema_name", OLD.id::text); - EXECUTE target_sql into del_prohibited; - ELSE - del_prohibited = FALSE; - END IF; - - IF del_prohibited THEN - RAISE foreign_key_violation - USING - TABLE = TG_TABLE_NAME, - SCHEMA = TG_TABLE_SCHEMA, - MESSAGE = 'deletion of {tgtname} (' || OLD.id - || ') is prohibited by link target policy', - DETAIL = 'Object is still referenced in link __type__' - || ' of ' || edgedb._get_schema_object_name(OLD.id) || ' (' - || OLD.id || ').'; - END IF; - '''.format(tgtname=target.get_displayname(schema))) + body = textwrap.dedent( + '''SELECT + CASE WHEN tp.builtin + THEN '{std}' + ELSE '{pub}' + END AS sname + INTO schema_name + FROM {edb}."_SchemaType" as tp + WHERE tp.id = OLD.id; + + SELECT EXISTS ( + SELECT FROM pg_tables + WHERE schemaname = "schema_name" + AND tablename = OLD.id::text + ) INTO table_exists; + + IF table_exists THEN + target_sql = format('SELECT EXISTS (SELECT FROM %I.%I LIMIT 1)', "schema_name", OLD.id::text); + EXECUTE target_sql into del_prohibited; + ELSE + del_prohibited = FALSE; + END IF; + + IF del_prohibited THEN + RAISE foreign_key_violation + USING + TABLE = TG_TABLE_NAME, + SCHEMA = TG_TABLE_SCHEMA, + MESSAGE = 'deletion of {tgtname} (' || OLD.id + || ') is prohibited by link target policy', + DETAIL = 'Object is still referenced in link __type__' + || ' of ' || {edb}._get_schema_object_name(OLD.id) || ' (' + || OLD.id || ').'; + END IF; + '''.format( + tgtname=target.get_displayname(schema), + std=common.actual_schemaname('edgedbstd'), + pub=common.actual_schemaname('edgedbpub'), + edb=common.actual_schemaname('edgedb') + ) + ) text = textwrap.dedent('''\ DECLARE @@ -5917,11 +5926,11 @@ def _declare_var(var_prefix, index, var_type): IF FOUND THEN SELECT - edgedb.shortname_from_fullname(link.name), - edgedb._get_schema_object_name(link.{far_endpoint}) + {edb}.shortname_from_fullname(link.name), + {edb}._get_schema_object_name(link.{far_endpoint}) INTO linkname, endname FROM - edgedb."_SchemaLink" AS link + {edb}."_SchemaLink" AS link WHERE link.id = link_type_id; RAISE foreign_key_violation @@ -5942,6 +5951,7 @@ def _declare_var(var_prefix, index, var_type): tgtname=target.get_displayname(schema), near_endpoint=near_endpoint, far_endpoint=far_endpoint, + edb=common.actual_schemaname('edgedb') ) chunks.append(text) @@ -6237,11 +6247,11 @@ def _get_inline_link_trigger_proc_text( IF FOUND THEN SELECT - edgedb.shortname_from_fullname(link.name), - edgedb._get_schema_object_name(link.{far_endpoint}) + {edb}.shortname_from_fullname(link.name), + {edb}._get_schema_object_name(link.{far_endpoint}) INTO linkname, endname FROM - edgedb."_SchemaLink" AS link + {edb}."_SchemaLink" AS link WHERE link.id = link_type_id; RAISE foreign_key_violation @@ -6262,6 +6272,7 @@ def _get_inline_link_trigger_proc_text( tgtname=target.get_displayname(schema), near_endpoint=near_endpoint, far_endpoint=far_endpoint, + 
edb=common.actual_schemaname('edgedb') ) chunks.append(text) @@ -6762,8 +6773,9 @@ def collect_external_objects( view_def = context.external_view[key] if context.restoring_external: - self.external_views.append(dbops.View(query=view_def, name=('edgedbpub', str(obj.id)))) - self.external_views.append(dbops.View(query=view_def, name=('edgedbpub', str(obj.id) + '_t'))) + schema_name = common.actual_schemaname('edgedbpub') + self.external_views.append(dbops.View(query=view_def, name=(schema_name, str(obj.id)))) + self.external_views.append(dbops.View(query=view_def, name=(schema_name, str(obj.id) + '_t'))) return columns = [] @@ -6784,7 +6796,7 @@ def collect_external_objects( ptrname = ptr.get_shortname(schema).name if ptrname == 'id': - columns.append("edgedbext.uuid_generate_v1mc() AS id") + columns.append(f"{common.actual_schemaname('edgedbext')}.uuid_generate_v1mc() AS id") elif ptrname == '__type__': columns.append(f"'{(str(obj.id))}'::uuid AS __type__") elif has_link_table: @@ -6813,8 +6825,9 @@ def collect_external_objects( if join_link_table is not None: query += f", (SELECT * FROM {join_link_table.relation}) AS INNER_T " \ f"where INNER_T.{join_link_table.columns['source']} = SOURCE_T.{source_identity}" - self.external_views.append(dbops.View(query=query, name=('edgedbpub', str(obj.id)))) - self.external_views.append(dbops.View(query=query, name=('edgedbpub', str(obj.id) + '_t'))) + schema_name = common.actual_schemaname('edgedbpub') + self.external_views.append(dbops.View(query=query, name=(schema_name, str(obj.id)))) + self.external_views.append(dbops.View(query=query, name=(schema_name, str(obj.id) + '_t'))) def apply( @@ -6843,6 +6856,46 @@ class DeleteModule(ModuleMetaCommand, adapts=s_mod.DeleteModule): pass +class NameSpaceMetaCommand(MetaCommand): + pass + + +class CreateNameSpace(NameSpaceMetaCommand, adapts=s_ns.CreateNameSpace): + def apply( + self, + schema: s_schema.Schema, + context: sd.CommandContext, + ) -> s_schema.Schema: + schema = super().apply(schema, context) + self.pgops.add( + dbops.CreateNameSpace( + dbops.NameSpace( + str(self.classname), + metadata=dict( + id=str(self.scls.id), + builtin=self.get_attribute_value('builtin'), + name=str(self.classname), + internal=False + ), + ), + ) + ) + return schema + + +class DeleteNameSpace(NameSpaceMetaCommand, adapts=s_ns.DeleteNameSpace): + def apply( + self, + schema: s_schema.Schema, + context: sd.CommandContext, + ) -> s_schema.Schema: + schema = super().apply(schema, context) + self.pgops.add( + dbops.DropNameSpace(self.classname) + ) + return schema + + class DatabaseMixin: def ensure_has_create_database(self, backend_params): if not backend_params.has_create_database: diff --git a/edb/pgsql/metaschema.py b/edb/pgsql/metaschema.py index 0928373c9f0..90c869c0d4a 100644 --- a/edb/pgsql/metaschema.py +++ b/edb/pgsql/metaschema.py @@ -4008,8 +4008,18 @@ async def bootstrap( config_spec: edbconfig.Spec ) -> None: commands = dbops.CommandGroup() + default_ns = dbops.NameSpace( + defines.DEFAULT_NS, + metadata=dict( + id=str(uuidgen.uuid1mc()), + builtin=False, + name=defines.DEFAULT_NS, + internal=False + ), + ) commands.add_commands([ dbops.CreateSchema(name='edgedb'), + dbops.SetMetadata(default_ns, default_ns.metadata), dbops.CreateSchema(name='edgedbss'), dbops.CreateSchema(name='edgedbpub'), dbops.CreateSchema(name='edgedbstd'), @@ -4803,6 +4813,102 @@ def _generate_schema_ver_views(schema: s_schema.Schema) -> List[dbops.View]: return views +def _generate_namespace_views(schema: s_schema.Schema) -> 
List[dbops.View]: + NameSpace = schema.get('sys::NameSpace', type=s_objtypes.ObjectType) + annos = NameSpace.getptr( + schema, s_name.UnqualName('annotations'), type=s_links.Link + ) + int_annos = NameSpace.getptr( + schema, s_name.UnqualName('annotations__internal'), type=s_links.Link + ) + + view_query = f''' + SELECT + (ns.description->>'id')::uuid + AS {qi(ptr_col_name(schema, NameSpace, 'id'))}, + (SELECT id FROM edgedb."_SchemaObjectType" + WHERE name = 'sys::NameSpace') + AS {qi(ptr_col_name(schema, NameSpace, '__type__'))}, + (ns.description->>'name') + AS {qi(ptr_col_name(schema, NameSpace, 'name'))}, + (ns.description->>'name') + AS {qi(ptr_col_name(schema, NameSpace, 'name__internal'))}, + ARRAY[]::text[] + AS {qi(ptr_col_name(schema, NameSpace, 'computed_fields'))}, + (ns.description->>'builtin')::bool + AS {qi(ptr_col_name(schema, NameSpace, 'builtin'))}, + (ns.description->>'internal')::bool + AS {qi(ptr_col_name(schema, NameSpace, 'internal'))}, + (ns.description->>'module_name') + AS {qi(ptr_col_name(schema, NameSpace, 'module_name'))}, + ((ns.description)->>'external')::bool + AS {qi(ptr_col_name(schema, NameSpace, 'external'))} + FROM + information_schema.schemata as s + CROSS JOIN LATERAL ( + select edgedb.obj_metadata(s.schema_name::regnamespace, 'pg_namespace') as DESCRIPTION + ) as ns + where ns.description ->> 'id' is not null + ''' + + annos_link_query = f''' + SELECT + (ns.value->>'id')::uuid + AS {qi(ptr_col_name(schema, annos, 'source'))}, + (annotations->>'id')::uuid + AS {qi(ptr_col_name(schema, annos, 'target'))}, + (annotations->>'value')::text + AS {qi(ptr_col_name(schema, annos, 'value'))}, + (annotations->>'is_owned')::bool + AS {qi(ptr_col_name(schema, annos, 'owned'))} + FROM + jsonb_each( + edgedb.get_database_metadata( + current_database() + ) -> 'NameSpace' + ) AS ns + CROSS JOIN LATERAL + ROWS FROM ( + jsonb_array_elements(ns.value->'annotations') + ) AS annotations + ''' + + int_annos_link_query = f''' + SELECT + (ns.value->>'id')::uuid + AS {qi(ptr_col_name(schema, int_annos, 'source'))}, + (annotations->>'id')::uuid + AS {qi(ptr_col_name(schema, int_annos, 'target'))}, + (annotations->>'is_owned')::bool + AS {qi(ptr_col_name(schema, int_annos, 'owned'))} + FROM + jsonb_each( + edgedb.get_database_metadata( + current_database() + ) -> 'NameSpace' + ) AS ns + CROSS JOIN LATERAL + ROWS FROM ( + jsonb_array_elements(ns.value->'annotations__internal') + ) AS annotations + ''' + + objects = { + NameSpace: view_query, + annos: annos_link_query, + int_annos: int_annos_link_query, + } + + views = [] + for obj, query in objects.items(): + tabview = dbops.View(name=tabname(schema, obj), query=query) + inhview = dbops.View(name=inhviewname(schema, obj), query=query) + views.append(tabview) + views.append(inhview) + + return views + + def _make_json_caster( schema: s_schema.Schema, stype: s_types.Type, @@ -4976,6 +5082,9 @@ async def generate_support_views( for verview in _generate_schema_ver_views(schema): commands.add_command(dbops.CreateView(verview, or_replace=True)) + for nsview in _generate_namespace_views(schema): + commands.add_command(dbops.CreateView(nsview, or_replace=True)) + sys_alias_views = _generate_schema_alias_views( schema, s_name.UnqualName('sys')) for alias_view in sys_alias_views: diff --git a/edb/pgsql/types.py b/edb/pgsql/types.py index c9d75d4a573..bc529ca425b 100644 --- a/edb/pgsql/types.py +++ b/edb/pgsql/types.py @@ -307,9 +307,15 @@ def pg_type_from_ir_typeref( else: pg_type = base_type_name_map.get(material.id) if pg_type is None: 
+ builtin_extending_enum = ( + material.has_enum + and str(material.name_hint.module) in s_schema.STD_MODULES_STR + ) # User-defined scalar type pg_type = common.get_scalar_backend_name( - material.id, material.name_hint.module, catenate=False) + material.id, material.name_hint.module, catenate=False, + ignore_ns=builtin_extending_enum + ) return pg_type diff --git a/edb/schema/defines.py b/edb/schema/defines.py index 81d448d0c2d..581f57f9f00 100644 --- a/edb/schema/defines.py +++ b/edb/schema/defines.py @@ -35,3 +35,6 @@ EDGEDB_SYSTEM_DB = '__edgedbsys__' EDGEDB_SPECIAL_DBS = {EDGEDB_TEMPLATE_DB, EDGEDB_SYSTEM_DB} +EDGEDB_OWNED_DBS = ['edgedbext', 'edgedb', 'edgedbss', 'edgedbpub', 'edgedbstd', 'edgedbinstdata'] + +DEFAULT_NS = 'default' diff --git a/edb/schema/namespace.py b/edb/schema/namespace.py new file mode 100644 index 00000000000..7b714c5da18 --- /dev/null +++ b/edb/schema/namespace.py @@ -0,0 +1,95 @@ +# +# This source file is part of the EdgeDB open source project. +# +# Copyright 2008-present MagicStack Inc. and the EdgeDB authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + + +from __future__ import annotations + +from edb import errors +from edb.edgeql import ast as qlast +from edb.edgeql import qltypes +from . import annos as s_anno +from . import delta as sd +from . import objects as so +from . import schema as s_schema +from . 
import defines


+class NameSpace(
+    so.ExternalObject,
+    s_anno.AnnotationSubject,
+    qlkind=qltypes.SchemaObjectClass.NAMESPACE,
+    data_safe=False,
+):
+    pass
+
+
+class NameSpaceCommandContext(sd.ObjectCommandContext[NameSpace]):
+    pass
+
+
+class NameSpaceCommand(
+    sd.ExternalObjectCommand[NameSpace],
+    context_class=NameSpaceCommandContext,
+):
+    def _validate_name(
+        self,
+        schema: s_schema.Schema,
+        context: sd.CommandContext,
+    ) -> None:
+        name = self.get_attribute_value('name')
+        if str(name).startswith('pg_'):
+            source_context = self.get_attribute_source_context('name')
+            raise errors.SchemaDefinitionError(
+                f'NameSpace names cannot start with \'pg_\', '
+                f'as such names are reserved for system schemas',
+                context=source_context,
+            )
+        if str(name) == defines.DEFAULT_NS:
+            source_context = self.get_attribute_source_context('name')
+            raise errors.SchemaDefinitionError(
+                f'\'{defines.DEFAULT_NS}\' is reserved as the name of the '
+                f'default namespace, use another name instead.',
+                context=source_context,
+            )
+
+
+class CreateNameSpace(NameSpaceCommand, sd.CreateExternalObject[NameSpace]):
+    astnode = qlast.CreateNameSpace
+
+    def validate_create(
+        self,
+        schema: s_schema.Schema,
+        context: sd.CommandContext,
+    ) -> None:
+        super().validate_create(schema, context)
+        self._validate_name(schema, context)
+
+
+class DeleteNameSpace(NameSpaceCommand, sd.DeleteExternalObject[NameSpace]):
+    astnode = qlast.DropNameSpace
+
+    def _validate_legal_command(
+        self,
+        schema: s_schema.Schema,
+        context: sd.CommandContext,
+    ) -> None:
+        super()._validate_legal_command(schema, context)
+        if self.classname.name == defines.DEFAULT_NS:
+            raise errors.ExecutionError(
+                f"namespace {self.classname.name!r} cannot be dropped"
+            )
diff --git a/edb/schema/reflection/structure.py b/edb/schema/reflection/structure.py
index f4dce49b924..1eefe954974 100644
--- a/edb/schema/reflection/structure.py
+++ b/edb/schema/reflection/structure.py
@@ -36,6 +36,7 @@
 from edb.schema import inheriting as s_inh
 from edb.schema import links as s_links
 from edb.schema import name as sn
+from edb.schema import namespace as s_ns
 from edb.schema import objects as s_obj
 from edb.schema import objtypes as s_objtypes
 from edb.schema import schema as s_schema
@@ -794,7 +795,8 @@ def generate_structure(schema: s_schema.Schema) -> SchemaReflectionParts:
             qry += ' FILTER NOT .builtin'

         if issubclass(py_cls, s_obj.GlobalObject):
-            global_parts.append(qry)
+            if not issubclass(py_cls, s_ns.NameSpace):
+                global_parts.append(qry)
         else:
             local_parts.append(qry)
diff --git a/edb/schema/schema.py b/edb/schema/schema.py
index d5c6571b222..90cc4b30776 100644
--- a/edb/schema/schema.py
+++ b/edb/schema/schema.py
@@ -66,7 +66,7 @@
 BUILTIN_MODULES = STD_MODULES + (sn.UnqualName('builtin'), )

-STD_MODULES_STR = {'sys', 'schema', 'cal', 'math'}
+STD_MODULES_STR = {'std', 'sys', 'schema', 'cal', 'math', 'cfg', 'builtin'}

 # Specifies the order of processing of files and directories in lib/
 STD_SOURCES = (
diff --git a/edb/server/bootstrap.py b/edb/server/bootstrap.py
index 817a827148f..27d71791ba4 100644
--- a/edb/server/bootstrap.py
+++ b/edb/server/bootstrap.py
@@ -998,6 +998,81 @@ async def _init_stdlib(
     return stdlib, config_spec, compiler


+async def store_tpl_sql(tpldbdump: bytes, conn: pgcon.PGConnection):
+    text = f"""\
+        INSERT INTO edgedbinstdata.instdata (key, text)
+        VALUES(
+            {pg_common.quote_literal('tpl_sql')},
+            {pg_common.quote_literal(tpldbdump.decode('utf-8'))}::text
+        )
+    """
+
+    await _execute(conn, text)
+
+
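+# NOTE: the '{ns_prefix}' placeholders in gen_tpl_dump() below are
+# intentionally *not* interpolated: they are stored verbatim so the dump
+# can be replayed once per namespace. A plausible consumer (hypothetical
+# sketch; the substitution site is not part of this diff) would fetch the
+# stored template and prefix every EdgeDB-owned schema:
+#
+#     sql = (await get_tpl_sql(conn)).decode('utf-8')
+#     # use str.replace, not str.format -- a pg_dump is full of braces
+#     sql = sql.replace('{ns_prefix}', 'tenant1_')  # '' for the default NS
+#     await _execute(conn, sql)
+
+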
+async def gen_tpl_dump(cluster: pgcluster.BaseCluster):
+    tpl_db_name = edbdef.EDGEDB_TEMPLATE_DB
+    tpl_pg_db_name = cluster.get_db_name(tpl_db_name)
+    tpldbdump = await cluster.dump_database(
+        tpl_pg_db_name,
+        exclude_schemas=['edgedbext'],
+        dump_object_owners=False,
+    )
+    # exclude CREATE TYPE & CREATE DOMAIN statements
+    tpldbdump = re.sub(
+        rb'CREATE (?:(TYPE|DOMAIN))[^;]*;',
+        rb'',
+        tpldbdump,
+        flags=re.DOTALL
+    )
+
+    commands = [dbops.CreateSchema(name='{ns_prefix}edgedbext')]
+    for uuid_func in [
+        'uuid_generate_v1',
+        'uuid_generate_v1mc',
+        'uuid_generate_v4',
+        'uuid_nil',
+        'uuid_ns_dns',
+        'uuid_ns_oid',
+        'uuid_ns_url',
+        'uuid_ns_x500',
+    ]:
+        commands.append(
+            dbops.CreateOrReplaceFunction(
+                dbops.Function(
+                    name=('{ns_prefix}edgedbext', uuid_func),
+                    returns=('pg_catalog', 'uuid'), language='plpgsql',
+                    text=f"""
+                    BEGIN
+                    RETURN edgedbext.{uuid_func}();
+                    END;
+                    """
+                )
+            )
+        )
+
+    for uuid_func in ['uuid_generate_v3', 'uuid_generate_v5']:
+        commands.append(
+            dbops.CreateOrReplaceFunction(
+                dbops.Function(
+                    name=('{ns_prefix}edgedbext', uuid_func),
+                    returns=('pg_catalog', 'uuid'), language='plpgsql',
+                    args=[('namespace', 'uuid'), ('name', 'text')],
+                    text=f"""
+                    BEGIN
+                    RETURN edgedbext.{uuid_func}(namespace, name);
+                    END;
+                    """
+                )
+            )
+        )
+    command_group = dbops.CommandGroup()
+    command_group.add_commands(commands)
+    block = dbops.PLTopBlock()
+    command_group.generate(block)
+    return block.to_string().encode('utf-8') + tpldbdump
+
+
 async def _init_defaults(schema, compiler, conn):
     script = '''
         CREATE MODULE default;
@@ -1105,6 +1180,17 @@ async def _compile_sys_queries(

     queries['listdbs'] = sql

+    _, sql = compile_bootstrap_script(
+        compiler,
+        schema,
+        f"""SELECT (
+            SELECT sys::NameSpace
+        ).name""",
+        expected_cardinality_one=False,
+    )
+
+    queries['listns'] = sql
+
     role_query = '''
         SELECT sys::Role {
             name,
@@ -1364,6 +1450,17 @@ async def _get_instance_data(conn: pgcon.PGConnection) -> Dict[str, Any]:
     return json.loads(data)


+async def get_tpl_sql(conn: pgcon.PGConnection) -> bytes:
+    data = await conn.sql_fetch_val(
+        b"""
+        SELECT text
+        FROM edgedbinstdata.instdata
+        WHERE key = 'tpl_sql'
+        """,
+    )
+    return data
+
+
 async def _check_catalog_compatibility(
     ctx: BootstrapContext,
 ) -> pgcon.PGConnection:
@@ -1509,6 +1606,8 @@ async def _start(ctx: BootstrapContext) -> None:
         # Initialize global config
         config.set_settings(config_spec)
+        if ctx.cluster._pg_bin_dir is None:
+            await ctx.cluster.lookup_postgres()
     finally:
         conn.terminate()
diff --git a/edb/server/compiler/compiler.py b/edb/server/compiler/compiler.py
index 4e82efd3a13..a9bd1a7174d 100644
--- a/edb/server/compiler/compiler.py
+++ b/edb/server/compiler/compiler.py
@@ -55,6 +55,7 @@
 from edb.ir import ast as irast

 from edb.schema import database as s_db
+from edb.schema import namespace as s_ns
 from edb.schema import extensions as s_ext
 from edb.schema import roles as s_roles
 from edb.schema import ddl as s_ddl
@@ -143,6 +144,8 @@ class CompileContext:
     restoring_external: Optional[bool] = False
     # If is test mode from http
     testmode: Optional[bool] = False
+    # NameSpace for current compile
+    namespace: str = defines.DEFAULT_NS


 DEFAULT_MODULE_ALIASES_MAP = immutables.Map(
@@ -415,7 +418,12 @@ def _process_delta(self, ctx: CompileContext, delta):
             for c in pgdelta.get_subcommands()
         )

-        if db_cmd:
+        ns_cmd = any(
+            isinstance(c, s_ns.NameSpaceCommand)
+            for c in pgdelta.get_subcommands()
+        )
+
+        if db_cmd or ns_cmd:
             block = pg_dbops.SQLBlock()
             new_be_types = new_types = frozenset()
         else:
@@ -439,6 +447,7 @@ def 
may_has_backend_id(_id): # schema persistence asynchronizable schema_peristence_async = ( not db_cmd + and not ns_cmd and not new_be_types and not any( isinstance(c, (s_ext.ExtensionCommand, @@ -462,23 +471,23 @@ def may_has_backend_id(_id): ctx, pgdelta, subblock, context=context, schema_persist_block=refl_block ) - + instdata_schemaname = pg_common.actual_schemaname("edgedbinstdata") if schema_peristence_async: if debug.flags.keep_schema_persistence_history: invalid_persist_his = f"""\ - UPDATE "edgedbinstdata"."schema_persist_history" + UPDATE "{instdata_schemaname}"."schema_persist_history" SET active = false WHERE version_id = '{str(ver_id)}'::uuid;\ """ else: invalid_persist_his = f"""\ - DELETE FROM "edgedbinstdata"."schema_persist_history" + DELETE FROM "{instdata_schemaname}"."schema_persist_history" WHERE version_id = '{str(ver_id)}'::uuid;\ """ refl_block.add_command(textwrap.dedent(invalid_persist_his)) main_block_sub = block.add_block() main_block_sub.add_command(textwrap.dedent(f"""\ - INSERT INTO "edgedbinstdata"."schema_persist_history" + INSERT INTO "{instdata_schemaname}"."schema_persist_history" ("version_id", "sql") values ( '{str(ver_id)}'::uuid, {pg_common.quote_bytea_literal(refl_block.to_string().encode())} @@ -533,7 +542,7 @@ def _compile_schema_storage_in_delta( with cache.mutate() as cache_mm: for eql, args in meta_blocks: eql_hash = hashlib.sha1(eql.encode()).hexdigest() - fname = ('edgedb', f'__rh_{eql_hash}') + fname = (pg_common.actual_schemaname('edgedb'), f'__rh_{eql_hash}') if eql_hash in cache_mm: argnames = cache_mm[eql_hash] @@ -570,7 +579,7 @@ def _compile_schema_storage_in_delta( or '__script' in args ): # schema version and migration - # update should always goes to main block + # update should always go to main block block.add_command(cmd) else: sp_block.add_command(cmd) @@ -605,6 +614,7 @@ def _compile_schema_storage_stmt( expected_cardinality_one=False, bootstrap_mode=ctx.bootstrap_mode, protocol_version=ctx.protocol_version, + namespace=ctx.namespace ) source = edgeql.Source.from_string(eql) @@ -1096,7 +1106,7 @@ def _compile_and_apply_ddl_stmt( "backend_id" ) FROM - edgedb."_SchemaType" + {pg_common.actual_schemaname('edgedb')}."_SchemaType" WHERE "id" = any(ARRAY[ {', '.join(new_type_ids)} @@ -1135,11 +1145,17 @@ def _compile_and_apply_ddl_stmt( create_db = None drop_db = None create_db_template = None + create_ns = None + drop_ns = None if isinstance(stmt, qlast.DropDatabase): drop_db = stmt.name.name elif isinstance(stmt, qlast.CreateDatabase): create_db = stmt.name.name create_db_template = stmt.template.name if stmt.template else None + elif isinstance(stmt, qlast.CreateNameSpace): + create_ns = stmt.name.name + elif isinstance(stmt, qlast.DropNameSpace): + drop_ns = stmt.name.name if debug.flags.delta_execute: debug.header('Delta Script') @@ -1156,6 +1172,8 @@ def _compile_and_apply_ddl_stmt( ), create_db=create_db, drop_db=drop_db, + create_ns=create_ns, + drop_ns=drop_ns, create_db_template=create_db_template, has_role_ddl=isinstance(stmt, qlast.RoleCommand), ddl_stmt_id=ddl_stmt_id, @@ -1881,7 +1899,11 @@ def _compile_dispatch_ql( self._compile_ql_sess_state(ctx, ql), enums.Capability.SESSION_CONFIG, ) - + elif isinstance(ql, qlast.UseNameSpaceCommand): + return ( + dbstate.NameSpaceSwitchQuery(new_ns=ql.name, sql=()), + enums.Capability.SESSION_CONFIG, + ) elif isinstance(ql, qlast.ConfigOp): if ql.scope is qltypes.ConfigScope.SESSION: capability = enums.Capability.SESSION_CONFIG @@ -1911,6 +1933,7 @@ def _compile( source: 
edgeql.Source,
     ) -> dbstate.QueryUnitGroup:
         current_tx = ctx.state.current_tx()
+        pg_common.NAMESPACE = ctx.namespace
         if current_tx.get_migration_state() is not None:
             original = edgeql.Source.from_string(source.text())
             ctx = dataclasses.replace(
@@ -1950,6 +1973,17 @@ def _try_compile(
         default_cardinality = enums.Cardinality.NO_RESULT
         statements = edgeql.parse_block(source)
         statements_len = len(statements)
+        is_script = statements_len > 1
+
+        if is_script and any(isinstance(stmt, qlast.UseNameSpaceCommand) for stmt in statements):
+            raise errors.QueryError(
+                'USE NAMESPACE is not allowed in a script.'
+            )
+
+        if statements and isinstance(statements[0], qlast.UseNameSpaceCommand) and ctx.in_tx:
+            raise errors.QueryError(
+                'cannot execute USE NAMESPACE in a transaction'
+            )

         if ctx.skip_first:
             statements = statements[1:]
@@ -1964,8 +1998,7 @@
             raise errors.ProtocolError('nothing to compile')

         rv = dbstate.QueryUnitGroup()
-
-        is_script = statements_len > 1
+        rv.namespace = ctx.namespace
         script_info = None
         if is_script:
             if ctx.expect_rollback:
@@ -2002,6 +2035,7 @@
                 cardinality=default_cardinality,
                 capabilities=capabilities,
                 output_format=stmt_ctx.output_format,
+                namespace=ctx.namespace
             )

             if not comp.is_transactional:
@@ -2067,6 +2101,8 @@
                 unit.create_db = comp.create_db
                 unit.drop_db = comp.drop_db
                 unit.create_db_template = comp.create_db_template
+                unit.create_ns = comp.create_ns
+                unit.drop_ns = comp.drop_ns
                 unit.has_role_ddl = comp.has_role_ddl
                 unit.ddl_stmt_id = comp.ddl_stmt_id
                 if comp.user_schema is not None:
@@ -2190,7 +2226,8 @@
                     unit.config_ops.append(comp.config_op)

                 unit.has_set = True
-
+            elif isinstance(comp, dbstate.NameSpaceSwitchQuery):
+                unit.ns_to_switch = comp.new_ns
             elif isinstance(comp, dbstate.NullQuery):
                 pass

@@ -2424,6 +2461,7 @@ def compile_notebook(

     def compile(
         self,
+        namespace: str,
         user_schema: s_schema.Schema,
         global_schema: s_schema.Schema,
         reflection_cache: Mapping[str, Tuple[str, ...]],
@@ -2493,7 +2531,8 @@
             module=module,
             external_view=external_view,
             restoring_external=restoring_external,
-            testmode=testmode
+            testmode=testmode,
+            namespace=namespace
         )

         unit_group = self._compile(ctx=ctx, source=source)
@@ -2511,6 +2550,7 @@
     def compile_in_tx(
         self,
         state: dbstate.CompilerConnectionState,
+        namespace: str,
         txid: int,
         source: edgeql.Source,
         output_format: enums.OutputFormat,
@@ -2558,18 +2598,21 @@
             module=module,
             in_tx=True,
             external_view=external_view,
-            restoring_external=restoring_external
+            restoring_external=restoring_external,
+            namespace=namespace
         )

         return self._compile(ctx=ctx, source=source), ctx.state

     def describe_database_dump(
         self,
+        namespace: str,
         user_schema: s_schema.Schema,
         global_schema: s_schema.Schema,
         database_config: immutables.Map[str, config.SettingValue],
         protocol_version: Tuple[int, int],
     ) -> DumpDescriptor:
+        pg_common.NAMESPACE = namespace
         schema = s_schema.ChainedSchema(
             self._std_schema,
             user_schema,
@@ -2830,6 +2873,7 @@ def _check_dump_layout(

     def describe_database_restore(
         self,
+        namespace,
         user_schema: s_schema.Schema,
         global_schema: s_schema.Schema,
         dump_server_ver_str: Optional[str],
@@ -2839,6 +2883,7 @@
         protocol_version: Tuple[int, int],
         external_view: Dict[str, str]
     ) -> RestoreDescriptor:
+        pg_common.NAMESPACE = namespace
         schema_object_ids = {
             (
                 s_name.name_from_string(name),
@@ -2889,7 +2934,9 @@
             log_ddl_as_migrations=False, 
protocol_version=protocol_version, external_view=external_view, - restoring_external=True + restoring_external=True, + namespace=namespace, + bootstrap_mode=True ) else: ctx = CompileContext( @@ -2900,6 +2947,8 @@ def describe_database_restore( schema_object_ids=schema_object_ids, log_ddl_as_migrations=False, protocol_version=protocol_version, + namespace=namespace, + bootstrap_mode=True ) ctx.state.start_tx() @@ -3193,3 +3242,10 @@ class RestoreBlockDescriptor(NamedTuple): #: this will contain the recursive descriptor on which parts of #: each datum need mending. data_mending_desc: Tuple[Optional[DataMendingDescriptor], ...] + + +class RestoreSchemaInfo(NamedTuple): + schema_ddl: bytes + schema_ids: List[Tuple] + blocks: List + external_views: List[Tuple] diff --git a/edb/server/compiler/dbstate.py b/edb/server/compiler/dbstate.py index 25667df076b..6b7e0491e53 100644 --- a/edb/server/compiler/dbstate.py +++ b/edb/server/compiler/dbstate.py @@ -39,7 +39,7 @@ from edb.schema import objects as s_obj from edb.schema import schema as s_schema -from edb.server import config +from edb.server import config, defines from . import enums from . import sertypes @@ -80,6 +80,13 @@ class NullQuery(BaseQuery): has_dml: bool = False +@dataclasses.dataclass(frozen=True) +class NameSpaceSwitchQuery(BaseQuery): + new_ns: str + is_transactional: bool = False + single_unit: bool = True + + @dataclasses.dataclass(frozen=True) class Query(BaseQuery): @@ -135,7 +142,9 @@ class DDLQuery(BaseQuery): is_transactional: bool = True single_unit: bool = False create_db: Optional[str] = None + create_ns: Optional[str] = None drop_db: Optional[str] = None + drop_ns: Optional[str] = None create_db_template: Optional[str] = None has_role_ddl: bool = False ddl_stmt_id: Optional[str] = None @@ -260,6 +269,14 @@ class QueryUnit: # close all inactive unused pooled connections to the template db. create_db_template: Optional[str] = None + # If non-None, contains a name of the NameSpace that is about to be + # created. + create_ns: Optional[str] = None + + # If non-None, contains a name of the NameSpace that is about to be + # deleted. + drop_ns: Optional[str] = None + # If non-None, the DDL statement will emit data packets marked # with the indicated ID. ddl_stmt_id: Optional[str] = None @@ -306,6 +323,10 @@ class QueryUnit: # schema reflection sqls, only available if this is a ddl stmt. schema_refl_sqls: Tuple[bytes, ...] = None stdview_sqls: Tuple[bytes, ...] 
= None + # NameSpace to use for current compile + namespace: str = defines.DEFAULT_NS + # NameSpace to switch for connection + ns_to_switch: str = None @property def has_ddl(self) -> bool: @@ -361,6 +382,8 @@ class QueryUnitGroup: ref_ids: Optional[Set[uuid.UUID]] = None # Record affected object ids for cache clear affected_obj_ids: Optional[Set[uuid.UUID]] = None + # NameSpace to use for current compile + namespace: str = defines.DEFAULT_NS def __iter__(self): return iter(self.units) diff --git a/edb/server/compiler/errormech.py b/edb/server/compiler/errormech.py index 6291a0bac88..37e37528d69 100644 --- a/edb/server/compiler/errormech.py +++ b/edb/server/compiler/errormech.py @@ -81,8 +81,10 @@ class ErrorDetails(NamedTuple): pgerrors.ERROR_SERIALIZATION_FAILURE: errors.TransactionSerializationError, pgerrors.ERROR_DEADLOCK_DETECTED: errors.TransactionDeadlockError, pgerrors.ERROR_INVALID_CATALOG_NAME: errors.UnknownDatabaseError, + pgerrors.ERROR_INVALID_SCHEMA_NAME: errors.UnknownSchemaError, pgerrors.ERROR_OBJECT_IN_USE: errors.ExecutionError, pgerrors.ERROR_DUPLICATE_DATABASE: errors.DuplicateDatabaseDefinitionError, + pgerrors.ERROR_DUPLICATE_SCHEMA: errors.DuplicateNameSpaceDefinitionError, pgerrors.ERROR_IDLE_IN_TRANSACTION_TIMEOUT: errors.IdleTransactionTimeoutError, pgerrors.ERROR_QUERY_CANCELLED: errors.QueryTimeoutError, @@ -120,7 +122,7 @@ class ErrorDetails(NamedTuple): pgtype_re = re.compile( '|'.join(fr'\b{key}\b' for key in types.base_type_name_map_r)) enum_re = re.compile( - r'(?P
<p>enum) (?P<v>edgedb([\w-]+)."(?P<id>[\w-]+)_domain")')
+    r'(?P<p>enum) (?P<v>(?:(.*)_)?edgedb([\w-]+)."(?P<id>[\w-]+)_domain")')


 def translate_pgtype(schema, msg):
diff --git a/edb/server/compiler/status.py b/edb/server/compiler/status.py
index 5b54b67afa7..50febc28cc2 100644
--- a/edb/server/compiler/status.py
+++ b/edb/server/compiler/status.py
@@ -155,6 +155,11 @@ def _sess_reset_alias(ql):
     return b'RESET ALIAS'


+@get_status.register(qlast.UseNameSpaceCommand)
+def _sess_use_ns(ql):
+    return f'USE NAMESPACE {ql.name}'.encode()
+
+
 @get_status.register(qlast.ConfigOp)
 def _sess_set_config(ql):
     if ql.scope == qltypes.ConfigScope.GLOBAL:
diff --git a/edb/server/compiler_pool/pool.py b/edb/server/compiler_pool/pool.py
index 8262c4788de..6ea2adc623d 100644
--- a/edb/server/compiler_pool/pool.py
+++ b/edb/server/compiler_pool/pool.py
@@ -91,11 +91,12 @@ def __repr__(self):


 class MutationHistory:
-    def __init__(self, dbname: str):
+    def __init__(self, dbname: str, namespace: str):
         self._history: List[_SchemaMutation] = []
         self._index: Dict[uuid.UUID, int] = {}
         self._cursor: Dict[uuid.UUID, int] = {}
         self._db = dbname
+        self._namespace = namespace

     @property
     def latest_ver(self):
@@ -109,7 +110,7 @@ def clear(self):
         self._cursor.clear()

     def get_pickled_mutation(self, worker: BaseWorker) -> Optional[bytes]:
-        start = self._cursor.get(worker.get_user_schema_id(self._db))
+        start = self._cursor.get(worker.get_user_schema_id(self._db, self._namespace))
         if start is None:
             return
@@ -117,13 +118,13 @@
             mut_bytes = self._history[start].bytes
             if logger.isEnabledFor(logging.DEBUG):
                 logger.debug(
-                    f"::CPOOL:: WOKER<{worker.identifier}> | DB<{self._db}> - "
+                    f"::CPOOL:: WORKER<{worker.identifier}> | DB<{self._db}({self._namespace})> - "
                     f"Using stored {self._history[start]} to update."
                 )
         else:
             mut = s_schema.SchemaMutationLogger.merge([m.obj for m in self._history[start:]])
             logger.info(
-                f"::CPOOL:: WOKER<{worker.identifier}> | DB<{self._db}> - "
+                f"::CPOOL:: WORKER<{worker.identifier}> | DB<{self._db}({self._namespace})> - "
                 f"Using merged {_trim_uuid(mut.target)}> to update."
             )
             mut_bytes = pickle.dumps(mut)
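# get_pickled_mutation() keys the mutation cursor by the worker's current
# schema version for the (database, namespace) pair; when the lookup misses,
# no incremental mutation chain exists and the caller falls back to a full
# schema sync. Sketch of the lookup it performs (names as defined above):
#
#     ver = worker.get_user_schema_id(self._db, self._namespace)
#     start = self._cursor.get(ver)    # None -> full update required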
) mut_bytes = pickle.dumps(mut) @@ -188,11 +189,11 @@ def __init__( self._last_used = time.monotonic() self._closed = False - def get_user_schema_id(self, dbname: str) -> uuid.UUID: - if dbname not in self._dbs: + def get_user_schema_id(self, dbname: str, namespace: str) -> uuid.UUID: + if self._dbs.get(dbname, {}).get(namespace) is None: return UNKNOW_VER_ID - return self._dbs[dbname].user_schema_version + return self._dbs[dbname][namespace].user_schema_version @functools.cached_property def identifier(self): @@ -292,7 +293,7 @@ def __init__( self._std_schema = std_schema self._refl_schema = refl_schema self._schema_class_layout = schema_class_layout - self._mut_history: Dict[str, MutationHistory] = {} + self._mut_history: Dict[str, Dict[str, MutationHistory]] = {} @functools.lru_cache(maxsize=None) def _get_init_args(self): @@ -303,18 +304,22 @@ def _get_init_args_uncached(self): dbs: state.DatabasesState = immutables.Map() for db in self._dbindex.iter_dbs(): - db_user_schema = db.user_schema - version_id = UNKNOW_VER_ID if db_user_schema is None else db_user_schema.version_id - dbs = dbs.set( - db.name, - state.DatabaseState( - name=db.name, - user_schema=db_user_schema, - user_schema_version=version_id, - reflection_cache=db.reflection_cache, - database_config=db.db_config, + namespace = immutables.Map() + for ns_name, ns_db in db.ns_map.items(): + db_user_schema = ns_db.user_schema + version_id = UNKNOW_VER_ID if db_user_schema is None else db_user_schema.version_id + namespace = namespace.set( + ns_name, + state.DatabaseState( + name=ns_db.name, + namespace=ns_name, + user_schema=db_user_schema, + user_schema_version=version_id, + reflection_cache=ns_db.reflection_cache, + database_config=db.db_config, + ) ) - ) + dbs = dbs.set(db.name, namespace) init_args = ( dbs, @@ -337,7 +342,7 @@ async def start(self): async def stop(self): raise NotImplementedError - def collect_worker_schema_ids(self, dbname) -> List[uuid.UUID]: + def collect_worker_schema_ids(self, dbname, namespace) -> List[uuid.UUID]: raise NotImplementedError def get_template_pid(self): @@ -346,6 +351,7 @@ def get_template_pid(self): async def sync_user_schema( self, dbname, + namespace, user_schema, reflection_cache, global_schema, @@ -359,6 +365,7 @@ async def sync_user_schema( preargs, sync_state = await self._compute_compile_preargs( worker, dbname, + namespace, user_schema, global_schema, reflection_cache, @@ -369,7 +376,7 @@ async def sync_user_schema( if preargs[2] is not None: logger.debug(f"[W::{worker.identifier}] Sync user schema.") else: - if worker.get_user_schema_id(dbname) is not UNKNOW_VER_ID: + if worker.get_user_schema_id(dbname, namespace) is not UNKNOW_VER_ID: logger.warning(f"[W::{worker.identifier}] Attempt to sync user schema failed.") logger.info(f"[W::{worker.identifier}] Initialize user schema.") @@ -382,6 +389,7 @@ async def _compute_compile_preargs( self, worker, dbname, + namespace, user_schema, global_schema, reflection_cache, @@ -393,27 +401,34 @@ def sync_worker_state_cb( *, worker, dbname, + namespace, user_schema=None, global_schema=None, reflection_cache=None, database_config=None, system_config=None, ): - worker_db = worker._dbs.get(dbname) + worker_db = worker._dbs.get(dbname, {}).get(namespace) if worker_db is None: assert user_schema is not None assert reflection_cache is not None assert global_schema is not None assert database_config is not None assert system_config is not None + ns = worker._dbs.get(dbname, immutables.Map()) + ns = ns.set( + namespace, +
state.DatabaseState( + name=dbname, + namespace=namespace, + user_schema=user_schema, + user_schema_version=user_schema.version_id, + reflection_cache=reflection_cache, + database_config=database_config, + ) + ) - worker._dbs = worker._dbs.set(dbname, state.DatabaseState( - name=dbname, - user_schema=user_schema, - user_schema_version=user_schema.version_id, - reflection_cache=reflection_cache, - database_config=database_config, - )) + worker._dbs = worker._dbs.set(dbname, ns) worker._global_schema = global_schema worker._system_config = system_config else: @@ -423,24 +438,31 @@ def sync_worker_state_cb( or database_config is not None ): new_user_schema = user_schema or worker_db.user_schema - worker._dbs = worker._dbs.set(dbname, state.DatabaseState( - name=dbname, - user_schema=new_user_schema, - user_schema_version=new_user_schema.version_id, - reflection_cache=( - reflection_cache or worker_db.reflection_cache), - database_config=( - database_config if database_config is not None - else worker_db.database_config), - )) + ns = worker._dbs[dbname] + ns = ns.set( + namespace, + state.DatabaseState( + name=dbname, + namespace=namespace, + user_schema=new_user_schema, + user_schema_version=new_user_schema.version_id, + reflection_cache=( + reflection_cache or worker_db.reflection_cache), + database_config=( + database_config if database_config is not None + else worker_db.database_config), + ) + ) + + worker._dbs = worker._dbs.set(dbname, ns) if global_schema is not None: worker._global_schema = global_schema if system_config is not None: worker._system_config = system_config - worker_db: state.DatabaseState = worker._dbs.get(dbname) - preargs = (dbname,) + worker_db: state.DatabaseState = worker._dbs.get(dbname, {}).get(namespace) + preargs = (dbname, namespace) to_update = {} if worker_db is None: @@ -468,16 +490,16 @@ def sync_worker_state_cb( f"Initialize db <{dbname}> schema version to: [{user_schema.version_id}]" ) else: - if dbname not in self._mut_history: + if self._mut_history.get(dbname, {}).get(namespace) is None: # No DDL has run on this instance since it started; if DDL then # happens on another instance, it triggers this instance's # introspect_db and invalidates the worker's schema version, so # _mut_history here may not contain dbname yet. mutation_pickled = None else: - mutation_pickled = self._mut_history[dbname].get_pickled_mutation(worker) + mutation_pickled = self._mut_history[dbname][namespace].get_pickled_mutation(worker) if mutation_pickled is None: logger.warning( - f"::CPOOL:: WOKER<{worker.identifier}> | DB<{dbname}> - " + f"::CPOOL:: WORKER<{worker.identifier}> | DB<{dbname}({namespace})> - " f"No schema mutation available. " f"Schema <{worker_db.user_schema_version}> is outdated, will issue a full update."
) @@ -525,6 +547,7 @@ def sync_worker_state_cb( sync_worker_state_cb, worker=worker, dbname=dbname, + namespace=namespace, **to_update ) else: @@ -541,6 +564,7 @@ def _release_worker(self, worker, *, put_in_front: bool = True): def append_schema_mutation( self, dbname, + namespace, mut_bytes, mutation: s_schema.SchemaMutationLogger, user_schema, @@ -549,10 +573,12 @@ def append_schema_mutation( database_config, system_config, ): - if is_fresh := (dbname not in self._mut_history): - self._mut_history[dbname] = MutationHistory(dbname) + if is_fresh := (self._mut_history.get(dbname, {}).get(namespace) is None): + ns_map = self._mut_history.get(dbname, {}) + ns_map[namespace] = MutationHistory(dbname, namespace) + self._mut_history[dbname] = ns_map - hist = self._mut_history[dbname] + hist = self._mut_history[dbname][namespace] hist.append(_SchemaMutation( base=mutation.id, target=user_schema.version_id, @@ -561,7 +587,7 @@ def append_schema_mutation( )) if not is_fresh: - usids = self.collect_worker_schema_ids(dbname) + usids = self.collect_worker_schema_ids(dbname, namespace) hist.try_trim_history(usids) if ( @@ -570,18 +596,22 @@ def append_schema_mutation( ): logger.debug(f"Schedule {n} tasks to sync worker's user schema.") for _ in range(n): - asyncio.create_task(self.sync_user_schema( - dbname, - user_schema, - reflection_cache, - global_schema, - database_config, - system_config, - )) + asyncio.create_task( + self.sync_user_schema( + dbname, + namespace, + user_schema, + reflection_cache, + global_schema, + database_config, + system_config, + ) + ) async def compile( self, dbname, + namespace, user_schema, global_schema, reflection_cache, @@ -594,6 +624,7 @@ async def compile( preargs, sync_state = await self._compute_compile_preargs( worker, dbname, + namespace, user_schema, global_schema, reflection_cache, @@ -619,6 +650,7 @@ async def compile( async def compile_in_tx( self, dbname, + namespace, txid, pickled_state, state_id, @@ -658,7 +690,7 @@ async def compile_in_tx( pickled_state = state.REUSE_LAST_STATE_MARKER user_schema = None else: - usid = worker.get_user_schema_id(dbname) + usid = worker.get_user_schema_id(dbname, namespace) if state_id == 0: if base_user_schema.version_id != usid: user_schema = _pickle_memoized(base_user_schema) @@ -682,6 +714,7 @@ async def compile_in_tx( 'compile_in_tx', pickled_state, dbname, + namespace, user_schema, txid, *compile_args @@ -700,6 +733,7 @@ async def compile_notebook( self, dbname, user_schema, + namespace, global_schema, reflection_cache, database_config, @@ -711,6 +745,7 @@ async def compile_notebook( preargs, sync_state = await self._compute_compile_preargs( worker, dbname, + namespace, user_schema, global_schema, reflection_cache, @@ -746,6 +781,7 @@ async def try_compile_rollback( async def compile_graphql( self, dbname, + namespace, user_schema, global_schema, reflection_cache, @@ -758,6 +794,7 @@ async def compile_graphql( preargs, sync_state = await self._compute_compile_preargs( worker, dbname, + namespace, user_schema, global_schema, reflection_cache, @@ -778,6 +815,7 @@ async def compile_graphql( async def infer_expr( self, dbname, + namespace, user_schema, global_schema, reflection_cache, @@ -790,6 +828,7 @@ async def infer_expr( preargs, sync_state = await self._compute_compile_preargs( worker, dbname, + namespace, user_schema, global_schema, reflection_cache, @@ -1033,8 +1072,8 @@ def _release_worker(self, worker, *, put_in_front: bool = True): if worker.get_pid() in self._workers: self._workers_queue.release(worker, 
put_in_front=put_in_front) - def collect_worker_schema_ids(self, dbname) -> Iterable[uuid.UUID]: - return [w.get_user_schema_id(dbname) for w in self._workers.values()] + def collect_worker_schema_ids(self, dbname, namespace) -> Iterable[uuid.UUID]: + return [w.get_user_schema_id(dbname, namespace) for w in self._workers.values()] @srvargs.CompilerPoolMode.Fixed.assign_implementation @@ -1126,8 +1165,8 @@ class DebugWorker: _last_pickled_state = None connected = False - def get_user_schema_id(self, dbname): - return BaseWorker.get_user_schema_id(self, dbname) # noqa + def get_user_schema_id(self, dbname, namespace): + return BaseWorker.get_user_schema_id(self, dbname, namespace) # noqa async def call(self, method_name, *args, sync_state=None): from . import worker @@ -1156,18 +1195,22 @@ def __init__(self, **kwargs): def _get_init_args(self): dbs: state.DatabasesState = immutables.Map() for db in self._dbindex.iter_dbs(): - db_user_schema = db.user_schema - version_id = UNKNOW_VER_ID if db_user_schema is None else db_user_schema.version_id - dbs = dbs.set( - db.name, - state.DatabaseState( - name=db.name, - user_schema=db_user_schema, - user_schema_version=version_id, - reflection_cache=db.reflection_cache, - database_config=db.db_config, + namespace = immutables.Map() + for ns_name, ns_db in db.ns_map.items(): + db_user_schema = ns_db.user_schema + version_id = UNKNOW_VER_ID if db_user_schema is None else db_user_schema.version_id + namespace = namespace.set( + ns_name, + state.DatabaseState( + name=db.name, + namespace=ns_name, + user_schema=db_user_schema, + user_schema_version=version_id, + reflection_cache=ns_db.reflection_cache, + database_config=db.db_config, + ) ) - ) + dbs = dbs.set(db.name, namespace) self._worker._dbs = dbs self._worker._backend_runtime_params = self._backend_runtime_params self._worker._std_schema = self._std_schema @@ -1228,8 +1271,8 @@ async def stop(self): if self._worker.connected: self.worker_disconnected(os.getpid()) - def collect_worker_schema_ids(self, dbname) -> Iterable[uuid.UUID]: - return [self._worker.get_user_schema_id(dbname)] + def collect_worker_schema_ids(self, dbname, namespace) -> Iterable[uuid.UUID]: + return [self._worker.get_user_schema_id(dbname, namespace)] @srvargs.CompilerPoolMode.OnDemand.assign_implementation @@ -1520,7 +1563,7 @@ async def _compute_compile_preargs(self, *args): self._sync_lock.release() return preargs, callback - def collect_worker_schema_ids(self, dbname) -> Iterable[uuid.UUID]: + def collect_worker_schema_ids(self, dbname, namespace) -> Iterable[uuid.UUID]: return [] diff --git a/edb/server/compiler_pool/state.py b/edb/server/compiler_pool/state.py index f5444cd6fad..3b7795229bc 100644 --- a/edb/server/compiler_pool/state.py +++ b/edb/server/compiler_pool/state.py @@ -31,13 +31,14 @@ class DatabaseState(typing.NamedTuple): name: str + namespace: str user_schema: typing.Optional[schema.FlatSchema] user_schema_version: typing.Optional[uuid.UUID] reflection_cache: ReflectionCache database_config: immutables.Map[str, config.SettingValue] -DatabasesState = immutables.Map[str, DatabaseState] +DatabasesState = immutables.Map[str, immutables.Map[str, DatabaseState]] class FailedStateSync(Exception): diff --git a/edb/server/compiler_pool/worker.py b/edb/server/compiler_pool/worker.py index add80ad95a7..a26a0e37baa 100644 --- a/edb/server/compiler_pool/worker.py +++ b/edb/server/compiler_pool/worker.py @@ -109,6 +109,7 @@ def __init_worker__( def __sync__( dbname: str, + namespace: str, user_schema: Optional[bytes],
user_schema_mutation: Optional[bytes], reflection_cache: Optional[bytes], @@ -122,22 +123,25 @@ def __sync__( global INSTANCE_CONFIG try: - db = DBS.get(dbname) - if db is None: + ns_db = DBS.get(dbname, {}).get(namespace) + if ns_db is None: assert user_schema is not None assert reflection_cache is not None assert database_config is not None user_schema_unpacked = pickle.loads(user_schema) reflection_cache_unpacked = pickle.loads(reflection_cache) database_config_unpacked = pickle.loads(database_config) - db = state.DatabaseState( + ns = DBS.get(dbname, immutables.Map()) + ns_db = state.DatabaseState( dbname, + namespace, user_schema_unpacked, user_schema_unpacked.version_id, reflection_cache_unpacked, database_config_unpacked, ) - DBS = DBS.set(dbname, db) + ns = ns.set(namespace, ns_db) + DBS = DBS.set(dbname, ns) else: updates = {} @@ -156,8 +160,10 @@ def __sync__( updates['database_config'] = pickle.loads(database_config) if updates: - db = db._replace(**updates) - DBS = DBS.set(dbname, db) + ns_db = ns_db._replace(**updates) + ns = DBS[dbname] + ns = ns.set(namespace, ns_db) + DBS = DBS.set(dbname, ns) if global_schema is not None: GLOBAL_SCHEMA = pickle.loads(global_schema) @@ -170,11 +176,12 @@ def __sync__( f'failed to sync worker state: {type(ex).__name__}({ex})') from ex if need_return: - return db + return ns_db def compile( dbname: str, + namespace: str, user_schema: Optional[bytes], user_schema_mutation: Optional[bytes], reflection_cache: Optional[bytes], @@ -187,6 +194,7 @@ def compile( with util.disable_gc(): db = __sync__( dbname, + namespace, user_schema, user_schema_mutation, reflection_cache, @@ -196,6 +204,7 @@ def compile( ) units, cstate = COMPILER.compile( + namespace, db.user_schema, GLOBAL_SCHEMA, db.reflection_cache, @@ -214,7 +223,7 @@ def compile( return units, pickled_state -def compile_in_tx(cstate, dbname, user_schema_pickled, *args, **kwargs): +def compile_in_tx(cstate, dbname, namespace, user_schema_pickled, *args, **kwargs): global LAST_STATE global DBS @@ -227,17 +236,18 @@ def compile_in_tx(cstate, dbname, user_schema_pickled, *args, **kwargs): if user_schema_pickled is not None: user_schema: s_schema.FlatSchema = pickle.loads(user_schema_pickled) else: - user_schema = DBS.get(dbname).user_schema + user_schema = DBS.get(dbname).get(namespace).user_schema cstate = cstate.restore(user_schema) - units, cstate = COMPILER.compile_in_tx(cstate, *args, **kwargs) + units, cstate = COMPILER.compile_in_tx(cstate, namespace, *args, **kwargs) LAST_STATE = cstate return units, pickle.dumps(cstate.compress(), -1), cstate.base_user_schema_id def compile_notebook( dbname: str, + namespace: str, user_schema: Optional[bytes], user_schema_mutation: Optional[bytes], reflection_cache: Optional[bytes], @@ -249,6 +259,7 @@ def compile_notebook( ): db = __sync__( dbname, + namespace, user_schema, user_schema_mutation, reflection_cache, @@ -270,6 +281,7 @@ def compile_notebook( def infer_expr( dbname: str, + namespace: str, user_schema: Optional[bytes], user_schema_mutation: Optional[bytes], reflection_cache: Optional[bytes], @@ -281,6 +293,7 @@ def infer_expr( ): db = __sync__( dbname, + namespace, user_schema, user_schema_mutation, reflection_cache, @@ -321,6 +334,7 @@ def describe_database_restore( def compile_graphql( dbname: str, + namespace: str, user_schema: Optional[bytes], user_schema_mutation: Optional[bytes], reflection_cache: Optional[bytes], @@ -340,6 +354,7 @@ def compile_graphql( db = __sync__( dbname, + namespace, user_schema, user_schema_mutation, reflection_cache, @@
-350,6 +365,7 @@ def compile_graphql( gql_op = graphql.compile_graphql( dbname, + namespace, STD_SCHEMA, db.user_schema, GLOBAL_SCHEMA, @@ -364,6 +380,7 @@ def compile_graphql( ) unit_group, _ = COMPILER.compile( + namespace=namespace, user_schema=db.user_schema, global_schema=GLOBAL_SCHEMA, reflection_cache=db.reflection_cache, diff --git a/edb/server/dbview/dbview.pxd b/edb/server/dbview/dbview.pxd index f9022a6dbef..d863b01bc29 100644 --- a/edb/server/dbview/dbview.pxd +++ b/edb/server/dbview/dbview.pxd @@ -27,9 +27,11 @@ cpdef enum SideEffects: SchemaChanges = 1 << 0 DatabaseConfigChanges = 1 << 1 - InstanceConfigChanges = 1 << 2 - RoleChanges = 1 << 3 - GlobalSchemaChanges = 1 << 4 + DatabaseDrop = 1 << 2 + DatabaseCreate = 1 << 3 + InstanceConfigChanges = 1 << 4 + RoleChanges = 1 << 5 + GlobalSchemaChanges = 1 << 6 @cython.final @@ -46,6 +48,7 @@ cdef class QueryRequestInfo: cdef public bint inline_objectids cdef public uint64_t allow_capabilities cdef public object module + cdef public object namespace cdef public bint read_only cdef public object external_view cdef public bint testmode @@ -71,12 +74,26 @@ cdef class DatabaseIndex: object _factory +cdef class NameSpace: + cdef: + public object _eql_to_compiled + public object _eql_to_compiled_disk + public object _object_id_to_eql + DatabaseIndex _dbindex + object _state_serializers + str _sql_bak_dir + bint _log_cache + + readonly str name + public object user_schema + public object reflection_cache + public object backend_ids + public object extensions + + cdef class Database: cdef: - object _eql_to_compiled - object _eql_to_compiled_disk - object _object_id_to_eql DatabaseIndex _index object _views object _introspection_lock @@ -85,29 +102,27 @@ cdef class Database: bint _log_cache readonly str name + public object ns_map readonly object dbver readonly object db_config - readonly object user_schema - readonly object reflection_cache - readonly object backend_ids - readonly object extensions cdef schedule_config_update(self) - cdef _invalidate_caches(self, drop_ids) + cdef _invalidate_caches(self) cdef _cache_compiled_query(self, key, query_unit) cdef _new_view(self, query_cache, protocol_version) cdef _remove_view(self, view) - cdef _update_backend_ids(self, new_types) + cdef _update_backend_ids(self, namespace, new_types) cdef _set_and_signal_new_user_schema( self, + namespace, new_schema, reflection_cache=?, backend_ids=?, db_config=?, affecting_ids=?, ) - cdef get_state_serializer(self, protocol_version) + cdef get_state_serializer(self, namespace, protocol_version) cdef class DatabaseConnectionView: @@ -137,8 +152,6 @@ cdef class DatabaseConnectionView: tuple _session_state_db_cache tuple _session_state_cache - object _eql_to_compiled - object _txid object _in_tx_db_config object _in_tx_savepoints @@ -166,14 +179,13 @@ cdef class DatabaseConnectionView: object __weakref__ - cdef _invalidate_local_cache(self) cdef _reset_tx_state(self) cdef clear_tx_error(self) cdef rollback_tx_to_savepoint(self, name) - cdef declare_savepoint(self, name, spid) + cdef declare_savepoint(self, namespace, name, spid) cdef recover_aliases_and_config(self, modaliases, config, globals) - cdef abort_tx(self) + cpdef abort_tx(self) cpdef in_tx(self) cpdef in_tx_error(self) @@ -183,13 +195,13 @@ cdef class DatabaseConnectionView: cdef tx_error(self) - cdef start(self, query_unit) - cdef _start_tx(self) + cpdef start(self, query_unit) + cdef _start_tx(self, namespace) cdef _apply_in_tx(self, query_unit) cdef start_implicit(self, query_unit) - cdef 
on_error(self) + cpdef on_error(self) cdef commit_implicit_tx( - self, user_schema, user_schema_unpacked, + self, namespace, user_schema, user_schema_unpacked, user_schema_mutation, global_schema, cached_reflection, affecting_ids, ) @@ -200,7 +212,7 @@ cdef class DatabaseConnectionView: cpdef get_globals(self) cpdef set_globals(self, new_globals) - cdef get_state_serializer(self) + cdef get_state_serializer(self, namespace) cdef set_state_serializer(self, new_serializer) cdef update_database_config(self) @@ -214,8 +226,8 @@ cdef class DatabaseConnectionView: cpdef get_modaliases(self) cdef bytes serialize_state(self) - cdef bint is_state_desc_changed(self) - cdef describe_state(self) - cdef encode_state(self) - cdef decode_state(self, type_id, data) - cdef inline recode_global(self, serializer, k, v) + cpdef bint is_state_desc_changed(self, namespace) + cdef describe_state(self, namespace) + cpdef encode_state(self) + cpdef decode_state(self, type_id, data, namespace) + cdef inline recode_global(self, serializer, namespace, k, v) diff --git a/edb/server/dbview/dbview.pyx b/edb/server/dbview/dbview.pyx index 6f20e96a5fe..6c071ef9a45 100644 --- a/edb/server/dbview/dbview.pyx +++ b/edb/server/dbview/dbview.pyx @@ -88,6 +88,7 @@ cdef class QueryRequestInfo: read_only: bint = False, testmode: bint = False, external_view: object = immutables.Map(), + namespace: str = defines.DEFAULT_NS, ): self.source = source self.protocol_version = protocol_version @@ -104,6 +105,7 @@ cdef class QueryRequestInfo: self.read_only = read_only self.testmode = testmode self.external_view = external_view + self.namespace = namespace self.cached_hash = hash(( self.source.cache_key(), @@ -118,7 +120,8 @@ cdef class QueryRequestInfo: self.inline_objectids, self.module, self.read_only, - self.testmode + self.testmode, + self.namespace )) def __hash__(self): @@ -138,7 +141,8 @@ cdef class QueryRequestInfo: self.inline_objectids == other.inline_objectids and self.module == other.module and self.read_only == other.read_only and - self.testmode == other.testmode + self.testmode == other.testmode and + self.namespace == other.namespace ) @@ -263,8 +267,7 @@ cdef format_eqls(raw_eqls): return "\n".join(msg) -cdef class Database: - +cdef class NameSpace: # Global LRU cache of compiled anonymous queries _eql_to_compiled: typing.Mapping[ typing.Tuple[QueryRequestInfo, @@ -289,33 +292,29 @@ cdef class Database: typing.Optional[immutables.Map] ] ] - + # Dict for object id to eql + _object_id_to_eql: EqlDict[ + uuid.UUID, + typing.Tuple[QueryRequestInfo, + typing.Optional[immutables.Map], + typing.Optional[immutables.Map] + ] + ] def __init__( self, - DatabaseIndex index, str name, + DatabaseIndex dbindex, *, object user_schema, - object db_config, object reflection_cache, object backend_ids, object extensions, ): self.name = name - - self.dbver = next_dbver() - - self._index = index - self._views = weakref.WeakSet() self._state_serializers = {} - - self._introspection_lock = asyncio.Lock() - self._eql_to_compiled = lru.LRUMapping(maxsize=defines._MAX_QUERIES_CACHE) self._eql_to_compiled_disk = RankedDiskCache() self._object_id_to_eql = EqlDict() - - self.db_config = db_config self.user_schema = user_schema self.reflection_cache = reflection_cache self.backend_ids = backend_ids @@ -326,58 +325,12 @@ cdef class Database: } else: self.extensions = extensions - - self._sql_bak_dir = os.path.join(self.server._runstate_dir, 'sql_bak') + self._dbindex = dbindex + self._sql_bak_dir = os.path.join(dbindex._server._runstate_dir, 'sql_bak', 
name) self._log_cache = logger.isEnabledFor(logging.DEBUG) and debug.flags.show_cache_info - @property - def server(self): - return self._index._server - - cdef schedule_config_update(self): - self._index._server._on_local_database_config_change(self.name) - - cdef _set_and_signal_new_user_schema( - self, - new_schema, - reflection_cache=None, - backend_ids=None, - db_config=None, - affecting_ids: typing.Set[uuid.UUID]=None - ): - if new_schema is None: - raise AssertionError('new_schema is not supposed to be None') - - self.dbver = next_dbver() - - self.user_schema = new_schema - - self.extensions = { - ext.get_name(new_schema).name - for ext in new_schema.get_objects(type=s_ext.Extension) - } - - if backend_ids is not None: - self.backend_ids = backend_ids - if reflection_cache is not None: - self.reflection_cache = reflection_cache - if db_config is not None: - self.db_config = db_config - - drop_ids = {DROP_IN_SCHEMA_DELTA} - - if affecting_ids: - drop_ids.update(affecting_ids.intersection(self._object_id_to_eql.keys())) - - self._invalidate_caches(drop_ids) - - cdef _update_backend_ids(self, new_types): - self.backend_ids.update(new_types) - - cdef _invalidate_caches(self, drop_ids: typing.Set[uuid.UUID]): + def invalidate_caches(self, drop_ids: typing.Set[uuid.UUID]): self._state_serializers.clear() - self._clear_http_cache() - if self._log_cache: logger.debug(f'Ids to drop: {drop_ids}.') @@ -388,13 +341,17 @@ cdef class Database: for eql in list(self._object_id_to_eql[obj_id]): if eql in self._eql_to_compiled: if self._log_cache: - logger.debug(f"Eql with sql:{format_eqls((eql,))} " - f"will be dropped for change of object with id <{obj_id}> in LRU cache.") + logger.debug( + f"Eql with sql:{format_eqls((eql,))} " + f"will be dropped for change of object with id <{obj_id}> in LRU cache." + ) del self._eql_to_compiled[eql] if eql in self._eql_to_compiled_disk: if self._log_cache: - logger.debug(f"Eql with sql:{format_eqls((eql,))} " - f"will be dropped for change of object with id <{obj_id}> in Disk cache.") + logger.debug( + f"Eql with sql:{format_eqls((eql,))} " + f"will be dropped for change of object with id <{obj_id}> in Disk cache." 
+ ) del self._eql_to_compiled_disk[eql] del self._object_id_to_eql[obj_id] @@ -402,34 +359,21 @@ cdef class Database: if self._log_cache: logger.debug('After invalidate, LRU Cache: \n' + format_eqls(self._eql_to_compiled._dict.keys())) logger.debug('Disk Cache: \n' + format_eqls(self._eql_to_compiled_disk.keys())) - logger.debug(f'Http CACHE: \n{self.server._http_query_cache._dict.keys()}') logger.debug(f'Obj id to Eql: \n{self._object_id_to_eql}') - def _clear_http_cache(self): - query_cache = self.server._http_query_cache - for cache_key in self.server.remove_on_ddl: - if cache_key in query_cache: - del query_cache[cache_key] - self.server.remove_on_ddl.clear() - def clear_caches(self): self._eql_to_compiled.clear() self._eql_to_compiled_disk.clear() self._object_id_to_eql.clear() - query_cache = self.server._http_query_cache - for cache_key in dict(query_cache._dict): - del query_cache[cache_key] - self.server.remove_on_ddl.clear() def view_caches(self): return '\n\n'.join([ f'LRU CACHE: \n{format_eqls(self._eql_to_compiled._dict.keys())}', f'Disk CACHE: \n{format_eqls(self._eql_to_compiled_disk.keys())}', - f'Http CACHE: \n{self.server._http_query_cache._dict.keys()}', f'Obj id to Eql: \n{self._object_id_to_eql}', ]) - cdef _cache_compiled_query(self, key, compiled: dbstate.QueryUnitGroup): + def _cache_compiled_query(self, key, compiled: dbstate.QueryUnitGroup): assert compiled.cacheable existing = self._eql_to_compiled.get(key) @@ -475,6 +419,150 @@ cdef class Database: for obj_id in compiled.ref_ids: self._object_id_to_eql.add(obj_id, key) + def get_state_serializer(self, protocol_version): + if protocol_version not in self._state_serializers: + self._state_serializers[protocol_version] = self._dbindex._factory.make( + self.user_schema, + self._dbindex._global_schema, + protocol_version, + ) + return self._state_serializers[protocol_version] + + def get_query_cache_size(self): + return len(self._eql_to_compiled) + + + +cdef class Database: + + def __init__( + self, + DatabaseIndex index, + str name, + str namespace, + *, + object user_schema, + object db_config, + object reflection_cache, + object backend_ids, + object extensions, + ): + self.name = name + self.dbver = next_dbver() + + self._index = index + self._views = weakref.WeakSet() + self._state_serializers = {} + + self._introspection_lock = asyncio.Lock() + + self.ns_map: typing.Dict[str, NameSpace] = { + namespace: NameSpace( + name=namespace, + dbindex=index, + user_schema=user_schema, + reflection_cache=reflection_cache, + backend_ids=backend_ids, + extensions=extensions + ) + } + self.db_config = db_config + self._log_cache = logger.isEnabledFor(logging.DEBUG) and debug.flags.show_cache_info + + @property + def server(self): + return self._index._server + + cdef schedule_config_update(self): + self._index._server._on_local_database_config_change(self.name) + + cdef _set_and_signal_new_user_schema( + self, + namespace, + new_schema, + reflection_cache=None, + backend_ids=None, + db_config=None, + affecting_ids: typing.Set[uuid.UUID]=None + ): + if new_schema is None: + raise AssertionError('new_schema is not supposed to be None') + + self.dbver = next_dbver() + if db_config is not None: + self.db_config = db_config + + extensions = { + ext.get_name(new_schema).name + for ext in new_schema.get_objects(type=s_ext.Extension) + } + if namespace not in self.ns_map: + ns = self.ns_map[namespace] = NameSpace( + name=namespace, + dbindex=self._index, + user_schema=new_schema, + reflection_cache=reflection_cache, + 
backend_ids=backend_ids, + extensions=extensions + ) + else: + ns = self.ns_map[namespace] + ns.user_schema = new_schema + ns.extensions = extensions + if reflection_cache is not None: + ns.reflection_cache = reflection_cache + if backend_ids is not None: + ns.backend_ids = backend_ids + drop_ids = {DROP_IN_SCHEMA_DELTA} + + if affecting_ids: + drop_ids.update(affecting_ids.intersection(ns._object_id_to_eql.keys())) + + ns.invalidate_caches(drop_ids) + + self._invalidate_caches() + + cdef _update_backend_ids(self, namespace, new_types): + self.ns_map[namespace].backend_ids.update(new_types) + + cdef _invalidate_caches(self): + self._clear_http_cache() + if self._log_cache: + logger.debug(f'Http CACHE: \n{self.server._http_query_cache._dict.keys()}') + + def _clear_http_cache(self): + query_cache = self.server._http_query_cache + for cache_key in self.server.remove_on_ddl: + if cache_key in query_cache: + del query_cache[cache_key] + self.server.remove_on_ddl.clear() + + def clear_caches(self): + for ns in self.ns_map.values(): + ns.clear_caches() + query_cache = self.server._http_query_cache + for cache_key in dict(query_cache._dict): + del query_cache[cache_key] + self.server.remove_on_ddl.clear() + + def view_caches(self): + return '\n\n'.join( + [f'Http CACHE: \n{self.server._http_query_cache._dict.keys()}'] + + [ + f"NameSpace({ns.name}): \n{ns.view_caches()}" + for ns in self.ns_map.values() + ] + ) + + cdef _cache_compiled_query(self, key, compiled: dbstate.QueryUnitGroup): + assert compiled.cacheable + + if compiled.namespace not in self.ns_map: + return + + ns = self.ns_map[compiled.namespace] + return ns._cache_compiled_query(key, compiled) + cdef _new_view(self, query_cache, protocol_version): view = DatabaseConnectionView( self, query_cache=query_cache, protocol_version=protocol_version @@ -485,33 +573,23 @@ cdef class Database: cdef _remove_view(self, view): self._views.remove(view) - cdef get_state_serializer(self, protocol_version): - if protocol_version not in self._state_serializers: - self._state_serializers[protocol_version] = self._index._factory.make( - self.user_schema, - self._index._global_schema, - protocol_version, - ) - return self._state_serializers[protocol_version] + cdef get_state_serializer(self, namespace, protocol_version): - return self.ns_map[namespace].get_state_serializer(protocol_version) def iter_views(self): yield from self._views - def get_query_cache_size(self): - return len(self._eql_to_compiled) - async def introspection(self): - if self.user_schema is None: + if any(ns.user_schema is None for ns in self.ns_map.values()): async with self._introspection_lock: - if self.user_schema is None: - await self._index._server.introspect_db(self.name) + await self._index._server.introspect(self.name) - async def persist_schema(self): + async def persist_schema(self, namespace): async with self._introspection_lock: - await self._index._server.persist_user_schema(self.name) + await self._index._server.persist_user_schema(self.name, namespace) - def schedule_schema_persistence(self): - asyncio.create_task(self.persist_schema()) + def schedule_schema_persistence(self, namespace): + asyncio.create_task(self.persist_schema(namespace)) def schedule_stdobj_inhview_update(self, sql): asyncio.create_task( @@ -520,13 +598,6 @@ cdef class Database: ) cdef class DatabaseConnectionView: - - _eql_to_compiled: typing.Mapping[ - typing.Tuple[QueryRequestInfo, - typing.Optional[immutables.Map], - typing.Optional[immutables.Map]], - dbstate.QueryUnitGroup] - def __init__(self,
db: Database, *, query_cache, protocol_version): self._db = db @@ -555,15 +626,8 @@ cdef class DatabaseConnectionView: self._last_comp_state = None self._last_comp_state_id = None - # Whenever we are in a transaction that had executed a - # DDL command, we use this cache for compiled queries. - self._eql_to_compiled = lru.LRUMapping(maxsize=defines._MAX_QUERIES_CACHE) - self._reset_tx_state() - cdef _invalidate_local_cache(self): - self._eql_to_compiled.clear() - cdef _reset_tx_state(self): self._txid = None self._in_tx = False @@ -588,7 +652,6 @@ cdef class DatabaseConnectionView: self._in_tx_dbver = 0 self._in_tx_stdview_sqls = None self._in_tx_sp_sqls = [] - self._invalidate_local_cache() cdef clear_tx_error(self): self._tx_error = False @@ -613,14 +676,13 @@ cdef class DatabaseConnectionView: self.set_session_config(config) self.set_globals(globals) self.set_state_serializer(state_serializer) - self._invalidate_local_cache() - cdef declare_savepoint(self, name, spid): + cdef declare_savepoint(self, namespace, name, spid): state = ( self.get_modaliases(), self.get_session_config(), self.get_globals(), - self.get_state_serializer(), + self.get_state_serializer(namespace), ) self._in_tx_savepoints.append((name, spid, state)) @@ -630,7 +692,7 @@ cdef class DatabaseConnectionView: self.set_session_config(config) self.set_globals(globals) - cdef abort_tx(self): + cpdef abort_tx(self): if not self.in_tx(): raise errors.InternalServerError('abort_tx(): not in transaction') self._reset_tx_state() @@ -647,20 +709,22 @@ cdef class DatabaseConnectionView: else: return self._globals - cdef get_state_serializer(self): + cdef get_state_serializer(self, namespace): if self._in_tx: if self._in_tx_state_serializer is None: # DDL in transaction, recalculate the state descriptor self._in_tx_state_serializer = self._db._index._factory.make( - self.get_user_schema(), + self.get_user_schema(namespace), self.get_global_schema(), self._protocol_version, ) return self._in_tx_state_serializer else: if self._state_serializer is None: + self.valid_namespace(namespace) # Executed a DDL, recalculate the state descriptor self._state_serializer = self._db.get_state_serializer( + namespace, self._protocol_version ) return self._state_serializer @@ -741,7 +805,7 @@ cdef class DatabaseConnectionView: else: return self._modaliases - def get_user_schema(self): + def get_user_schema(self, namespace: str): if self._in_tx: if self._in_tx_user_schema_mut_pickled: mutation = pickle.loads(self._in_tx_user_schema_mut_pickled) @@ -749,7 +813,19 @@ cdef class DatabaseConnectionView: self._in_tx_user_schema_mut_pickled = None return self._in_tx_user_schema else: - return self._db.user_schema + self.valid_namespace(namespace) + return self._db.ns_map[namespace].user_schema + + def get_reflection_cache(self, namespace: str): + self.valid_namespace(namespace) + return self._db.ns_map[namespace].reflection_cache + + def valid_namespace(self, namespace: str): + if namespace not in self._db.ns_map: + raise errors.QueryError( + f'NameSpace: [{namespace}] not in current db [{self._db.name}](ver:{self._db.dbver}).' 
+ f'Current NameSpace(s): [{", ".join(self._db.ns_map.keys())}]' + ) def get_global_schema(self): if self._in_tx: @@ -762,15 +838,15 @@ cdef class DatabaseConnectionView: else: return self._db._index._global_schema - def get_schema(self): - user_schema = self.get_user_schema() + def get_schema(self, namespace): + user_schema = self.get_user_schema(namespace) return s_schema.ChainedSchema( self._db._index._std_schema, user_schema, self._db._index._global_schema, ) - def resolve_backend_type_id(self, type_id): + def resolve_backend_type_id(self, type_id, namespace): type_id = str(type_id) if self._in_tx: @@ -779,7 +855,8 @@ cdef class DatabaseConnectionView: except KeyError: pass - tid = self._db.backend_ids.get(type_id) + self.valid_namespace(namespace) + tid = self._db.ns_map[namespace].backend_ids.get(type_id) if tid is None: raise RuntimeError( f'cannot resolve backend OID for type {type_id}') @@ -810,8 +887,8 @@ cdef class DatabaseConnectionView: self._session_state_db_cache = (self._config, spec) return spec - cdef bint is_state_desc_changed(self): - serializer = self.get_state_serializer() + cpdef bint is_state_desc_changed(self, namespace): + serializer = self.get_state_serializer(namespace) if not self._in_tx: # We may have executed a query, or COMMIT/ROLLBACK - just use # the serializer we preserved before. NOTE: the schema might @@ -838,10 +915,10 @@ cdef class DatabaseConnectionView: return True - cdef describe_state(self): - return self.get_state_serializer().describe() + cdef describe_state(self, namespace): + return self.get_state_serializer(namespace).describe() - cdef encode_state(self): + cpdef encode_state(self): modaliases = self.get_modaliases() session_config = self.get_session_config() globals_ = self.get_globals() @@ -881,11 +958,11 @@ cdef class DatabaseConnectionView: state['globals'] = {k: v.value for k, v in globals_.items()} return serializer.type_id, serializer.encode(state) - cdef decode_state(self, type_id, data): + cpdef decode_state(self, type_id, data, namespace): if not self._in_tx: # make sure we start clean self._state_serializer = None - serializer = self.get_state_serializer() + serializer = self.get_state_serializer(namespace) self._command_state_serializer = serializer if type_id == sertypes.NULL_TYPE_ID.bytes: @@ -921,7 +998,7 @@ cdef class DatabaseConnectionView: globals_ = immutables.Map({ k: config.SettingValue( name=k, - value=self.recode_global(serializer, k, v), + value=self.recode_global(serializer, namespace, k, v), source='global', scope=qltypes.ConfigScope.GLOBAL, ) for k, v in state.get('globals', {}).items() @@ -933,13 +1010,13 @@ cdef class DatabaseConnectionView: aliases, session_config, globals_, type_id, data ) - cdef inline recode_global(self, serializer, k, v): + cdef inline recode_global(self, serializer, namespace, k, v): if v and v[:4] == b'\x00\x00\x00\x01': array_type_id = serializer.get_global_array_type_id(k) if array_type_id: va = bytearray(v) va[8:12] = INT32_PACKER( - self.resolve_backend_type_id(array_type_id) + self.resolve_backend_type_id(array_type_id, namespace) ) v = bytes(va) return v @@ -952,10 +1029,6 @@ cdef class DatabaseConnectionView: def __get__(self): return self._db.name - property reflection_cache: - def __get__(self): - return self._db.reflection_cache - property dbver: def __get__(self): if self._in_tx and self._in_tx_dbver: @@ -966,6 +1039,9 @@ cdef class DatabaseConnectionView: def server(self): return self._db._index._server + def iter_ns_name(self): + return iter(self._db.ns_map.keys()) + cpdef 
in_tx(self): return self._in_tx @@ -977,9 +1053,7 @@ cdef class DatabaseConnectionView: key = (key, self.get_modaliases(), self.get_session_config()) - if self._in_tx_with_ddl: - self._eql_to_compiled[key] = query_unit_group - else: + if not self._in_tx_with_ddl: self._db._cache_compiled_query(key, query_unit_group) cdef lookup_compiled_query(self, object key): @@ -988,12 +1062,15 @@ cdef class DatabaseConnectionView: self._in_tx_with_ddl): return None + self.valid_namespace(key.namespace) + ns = self._db.ns_map[key.namespace] + key = (key, self.get_modaliases(), self.get_session_config()) - query_unit_group = self._db._eql_to_compiled.get(key) + query_unit_group = ns._eql_to_compiled.get(key) if query_unit_group is None: - disk_filepath = self._db._eql_to_compiled_disk.get(key) + disk_filepath = ns._eql_to_compiled_disk.get(key) if disk_filepath is None: return None @@ -1002,9 +1079,9 @@ cdef class DatabaseConnectionView: if logger.isEnabledFor(logging.DEBUG): logger.debug(f'Find dumped sql bytes in disk deleted, ' f'drop Eql for Sql: {key[0].source.text()}.') - self._db._eql_to_compiled_disk.delete_with_cb( + ns._eql_to_compiled_disk.delete_with_cb( key, - self._db._object_id_to_eql.maybe_drop_with_eqls + ns._object_id_to_eql.maybe_drop_with_eqls ) return None @@ -1013,7 +1090,7 @@ cdef class DatabaseConnectionView: query_unit_group = pickle.load(disk_file) metrics.edgeql_cache_pickle_load_duration.observe(time.monotonic() - started_at) - self._db._eql_to_compiled[key] = query_unit_group + ns._eql_to_compiled[key] = query_unit_group return query_unit_group @@ -1021,13 +1098,13 @@ cdef class DatabaseConnectionView: if self._in_tx: self._tx_error = True - cdef start(self, query_unit): + cpdef start(self, query_unit): if self._tx_error: self.raise_in_tx_error() if query_unit.tx_id is not None: self._txid = query_unit.tx_id - self._start_tx() + self._start_tx(query_unit.namespace) if self._in_tx and not self._txid: raise errors.InternalServerError('unset txid in transaction') @@ -1035,22 +1112,17 @@ cdef class DatabaseConnectionView: if self._in_tx: self._apply_in_tx(query_unit) - cdef _start_tx(self): + cdef _start_tx(self, namespace): self._in_tx = True self._in_tx_config = self._config self._in_tx_globals = self._globals self._in_tx_db_config = self._db.db_config self._in_tx_modaliases = self._modaliases - self._in_tx_base_user_schema = self._db.user_schema - self._in_tx_user_schema = self._db.user_schema + self._in_tx_base_user_schema = self._db.ns_map[namespace].user_schema + self._in_tx_user_schema = self._db.ns_map[namespace].user_schema self._in_tx_global_schema = self._db._index._global_schema self._in_tx_state_serializer = self._state_serializer - def sync_tx_base_schema(self): - if self._db.user_schema is self._in_tx_base_user_schema: - return - self._in_tx_base_user_schema = self._db.user_schema - cdef _apply_in_tx(self, query_unit): if query_unit.has_ddl: self._in_tx_with_ddl = True @@ -1079,11 +1151,11 @@ cdef class DatabaseConnectionView: self.raise_in_tx_error() if not self._in_tx: - self._start_tx() + self._start_tx(query_unit.namespace) self._apply_in_tx(query_unit) - cdef on_error(self): + cpdef on_error(self): self.tx_error() async def in_tx_persist_schema(self, be_conn): @@ -1093,14 +1165,15 @@ cdef class DatabaseConnectionView: await be_conn.sql_execute(sqls) self._in_tx_sp_sqls.clear() - def save_schema_mutation(self, mut, mut_bytes): + def save_schema_mutation(self, namespace, mut, mut_bytes): self._db._index._server.get_compiler_pool().append_schema_mutation( 
self.dbname, + namespace, mut_bytes, mut, - self.get_user_schema(), + self.get_user_schema(namespace), self.get_global_schema(), - self.reflection_cache, + self.get_reflection_cache(namespace), self.get_database_config(), self.get_compilation_system_config(), ) @@ -1113,8 +1186,10 @@ cdef class DatabaseConnectionView: not self._in_tx and side_effects and (side_effects & SideEffects.SchemaChanges) + and not (query_unit.create_ns or query_unit.drop_ns) ): self.save_schema_mutation( + query_unit.namespace, query_unit.user_schema_mutation_obj, query_unit.user_schema_mutation, ) @@ -1123,19 +1198,15 @@ cdef class DatabaseConnectionView: def _on_success(self, query_unit, new_types): side_effects = 0 - if query_unit.tx_savepoint_rollback: - # Need to invalidate the cache in case there were - # SET ALIAS or CONFIGURE or DDL commands. - self._invalidate_local_cache() - if not self._in_tx: if new_types: - self._db._update_backend_ids(new_types) + self._db._update_backend_ids(query_unit.namespace, new_types) if query_unit.user_schema_mutation is not None: self._in_tx_dbver = next_dbver() self._state_serializer = None self._db._set_and_signal_new_user_schema( - query_unit.update_user_schema(self._db.user_schema), + query_unit.namespace, + query_unit.update_user_schema(self.get_user_schema(query_unit.namespace)), pickle.loads(query_unit.cached_reflection) if query_unit.cached_reflection is not None else None, @@ -1143,10 +1214,19 @@ cdef class DatabaseConnectionView: None, query_unit.affected_obj_ids ) - self._db.schedule_schema_persistence() + self._db.schedule_schema_persistence(query_unit.namespace) if query_unit.stdview_sqls: self._db.schedule_stdobj_inhview_update(query_unit.stdview_sqls) side_effects |= SideEffects.SchemaChanges + + if query_unit.create_ns or query_unit.drop_ns: + self._db.dbver = next_dbver() + side_effects |= SideEffects.SchemaChanges + if query_unit.create_db: + side_effects |= SideEffects.DatabaseCreate + if query_unit.drop_db: + side_effects |= SideEffects.DatabaseDrop + if query_unit.system_config: side_effects |= SideEffects.InstanceConfigChanges if query_unit.database_config: @@ -1177,10 +1257,11 @@ cdef class DatabaseConnectionView: self._globals = self._in_tx_globals if self._in_tx_new_types: - self._db._update_backend_ids(self._in_tx_new_types) + self._db._update_backend_ids(query_unit.namespace, self._in_tx_new_types) if query_unit.user_schema_mutation is not None: self._state_serializer = None self._db._set_and_signal_new_user_schema( + query_unit.namespace, query_unit.update_user_schema(self._in_tx_base_user_schema), pickle.loads(query_unit.cached_reflection) if query_unit.cached_reflection is not None @@ -1189,7 +1270,7 @@ cdef class DatabaseConnectionView: None, query_unit.affected_obj_ids ) - self._db.schedule_schema_persistence() + self._db.schedule_schema_persistence(query_unit.namespace) if self._in_tx_stdview_sqls: self._db.schedule_stdobj_inhview_update(self._in_tx_stdview_sqls) side_effects |= SideEffects.SchemaChanges @@ -1218,7 +1299,7 @@ cdef class DatabaseConnectionView: return side_effects cdef commit_implicit_tx( - self, user_schema, user_schema_unpacked, + self, namespace, user_schema, user_schema_unpacked, user_schema_mutation, global_schema, cached_reflection, affecting_ids ): @@ -1230,7 +1311,7 @@ cdef class DatabaseConnectionView: self._globals = self._in_tx_globals if self._in_tx_new_types: - self._db._update_backend_ids(self._in_tx_new_types) + self._db._update_backend_ids(namespace, self._in_tx_new_types) if ( user_schema is not None @@ 
-1240,13 +1321,14 @@ cdef class DatabaseConnectionView: if user_schema_unpacked is not None: user_schema = user_schema_unpacked elif user_schema_mutation is not None: - base_user_schema = self._db.user_schema + base_user_schema = self.get_user_schema(namespace) user_schema = user_schema_mutation.apply(base_user_schema) else: user_schema = pickle.loads(user_schema) self._state_serializer = None self._db._set_and_signal_new_user_schema( + namespace, user_schema, pickle.loads(cached_reflection) if cached_reflection is not None @@ -1373,6 +1455,7 @@ cdef class DatabaseConnectionView: if self.in_tx(): result = await compiler_pool.compile_in_tx( self.dbname, + query_req.namespace, self.txid, self._last_comp_state, self._last_comp_state_id, @@ -1395,9 +1478,10 @@ cdef class DatabaseConnectionView: else: result = await compiler_pool.compile( self.dbname, - self.get_user_schema(), + query_req.namespace, + self.get_user_schema(query_req.namespace), self.get_global_schema(), - self.reflection_cache, + self.get_reflection_cache(query_req.namespace), self.get_database_config(), self.get_compilation_system_config(), query_req.source, @@ -1480,8 +1564,7 @@ cdef class DatabaseIndex: try: return self._dbs[dbname] except KeyError: - raise errors.UnknownDatabaseError( - f'database {dbname!r} does not exist') + raise errors.UnknownDatabaseError(f'database {dbname!r} does not exist') def maybe_get_db(self, dbname): return self._dbs.get(dbname) @@ -1492,9 +1575,10 @@ cdef class DatabaseIndex: def update_global_schema(self, global_schema): self._global_schema = global_schema - def register_db( + def register_ns( self, dbname, + namespace, *, user_schema, db_config, @@ -1505,19 +1589,35 @@ cdef class DatabaseIndex: cdef Database db db = self._dbs.get(dbname) if db is not None: - db._set_and_signal_new_user_schema( - user_schema, reflection_cache, backend_ids, db_config) + if namespace not in db.ns_map: + db.ns_map[namespace] = NameSpace( + name=namespace, + dbindex=db._index, + user_schema=user_schema, + reflection_cache=reflection_cache, + backend_ids=backend_ids, + extensions=extensions + ) + else: + db._set_and_signal_new_user_schema( + namespace, user_schema, reflection_cache, backend_ids, db_config + ) else: - db = Database( + self._dbs[dbname] = Database( self, dbname, + namespace=namespace, user_schema=user_schema, db_config=db_config, reflection_cache=reflection_cache, backend_ids=backend_ids, extensions=extensions, ) - self._dbs[dbname] = db + + def unregister_ns(self, dbname, namespace): + if dbname not in self._dbs: + return + self._dbs[dbname].ns_map.pop(namespace, None) def unregister_db(self, dbname): self._dbs.pop(dbname) diff --git a/edb/server/defines.py b/edb/server/defines.py index 9877a5fa4cb..96f4cfbbc90 100644 --- a/edb/server/defines.py +++ b/edb/server/defines.py @@ -28,6 +28,7 @@ EDGEDB_SUPERGROUP = 'edgedb_supergroup' EDGEDB_SUPERUSER = s_def.EDGEDB_SUPERUSER EDGEDB_TEMPLATE_DB = s_def.EDGEDB_TEMPLATE_DB +DEFAULT_NS = s_def.DEFAULT_NS EDGEDB_SUPERUSER_DB = 'edgedb' EDGEDB_SYSTEM_DB = s_def.EDGEDB_SYSTEM_DB EDGEDB_ENCODING = 'utf-8' diff --git a/edb/server/pgcluster.py b/edb/server/pgcluster.py index 306ac1146d6..6c65180f085 100644 --- a/edb/server/pgcluster.py +++ b/edb/server/pgcluster.py @@ -1174,7 +1174,7 @@ async def _start_logged_subprocess( asyncio.subprocess.PIPE if log_stderr or capture_stderr else asyncio.subprocess.DEVNULL ), - limit=2 ** 20, # 1 MiB + limit=2 ** 25, # 32 MiB **kwargs, ) diff --git a/edb/server/pgcon/errors.py b/edb/server/pgcon/errors.py index 
c1dadd885c0..c3922b826f3 100644 --- a/edb/server/pgcon/errors.py +++ b/edb/server/pgcon/errors.py @@ -53,6 +53,7 @@ ERROR_INVALID_PASSWORD = '28P01' ERROR_INVALID_CATALOG_NAME = '3D000' +ERROR_INVALID_SCHEMA_NAME = '3F000' ERROR_SERIALIZATION_FAILURE = '40001' ERROR_DEADLOCK_DETECTED = '40P01' @@ -60,6 +61,7 @@ ERROR_WRONG_OBJECT_TYPE = '42809' ERROR_INSUFFICIENT_PRIVILEGE = '42501' ERROR_DUPLICATE_DATABASE = '42P04' +ERROR_DUPLICATE_SCHEMA = '42P06' ERROR_PROGRAM_LIMIT_EXCEEDED = '54000' diff --git a/edb/server/pgcon/pgcon.pyx b/edb/server/pgcon/pgcon.pyx index bf1187da24a..93ade27859d 100644 --- a/edb/server/pgcon/pgcon.pyx +++ b/edb/server/pgcon/pgcon.pyx @@ -1939,7 +1939,15 @@ cdef class PGConnection: event_payload = event_data.get('args') if event == 'schema-changes': dbname = event_payload['dbname'] - self.server._on_remote_ddl(dbname) + namespace = event_payload['namespace'] + drop_ns = event_payload['drop_ns'] + self.server._on_remote_ddl(dbname, namespace, drop_ns) + elif event == 'database-create': + dbname = event_payload['dbname'] + self.server._on_remote_ddl(dbname, namespace=None) + elif event == 'database-drop': + dbname = event_payload['dbname'] + self.server._on_after_drop_db(dbname) elif event == 'database-config-changes': dbname = event_payload['dbname'] self.server._on_remote_database_config_change(dbname) diff --git a/edb/server/protocol/args_ser.pyx b/edb/server/protocol/args_ser.pyx index b55d904d5b0..0505cc2f9ed 100644 --- a/edb/server/protocol/args_ser.pyx +++ b/edb/server/protocol/args_ser.pyx @@ -171,7 +171,7 @@ cdef WriteBuffer recode_bind_args( if param.array_type_id is not None: # ndimensions + flags array_tid = dbv.resolve_backend_type_id( - param.array_type_id) + param.array_type_id, qug.namespace) out_buf.write_cstr(data, 8) out_buf.write_int32(array_tid) out_buf.write_cstr(&data[12], in_len - 12) diff --git a/edb/server/protocol/binary.pxd b/edb/server/protocol/binary.pxd index cf2c267642e..5a255d22d9c 100644 --- a/edb/server/protocol/binary.pxd +++ b/edb/server/protocol/binary.pxd @@ -64,6 +64,7 @@ cdef class EdgeConnection(frontend.FrontendConnection): object loop readonly dbview.DatabaseConnectionView _dbview str dbname + str namespace ReadBuffer buffer @@ -132,7 +133,7 @@ cdef class EdgeConnection(frontend.FrontendConnection): cdef WriteBuffer make_command_data_description_msg( self, dbview.CompiledQuery query ) - cdef WriteBuffer make_state_data_description_msg(self) + cdef WriteBuffer make_state_data_description_msg(self, namespace=?) 
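For reference, the two SQLSTATEs added above pair with the errormech mapping earlier in this diff, so failures against the backing Postgres schema surface as namespace errors. A minimal sketch of that lookup, not taken from the patch itself: interpret_sqlstate is a hypothetical helper, the real dispatch goes through errormech.interpret_backend_error.

from edb import errors

# SQLSTATE values mirroring edb/server/pgcon/errors.py above.
ERROR_INVALID_SCHEMA_NAME = '3F000'  # Postgres: schema does not exist
ERROR_DUPLICATE_SCHEMA = '42P06'     # Postgres: schema already exists

# Mirrors the static mapping this patch adds to errormech.py.
SQLSTATE_TO_ERROR = {
    ERROR_INVALID_SCHEMA_NAME: errors.UnknownSchemaError,
    ERROR_DUPLICATE_SCHEMA: errors.DuplicateNameSpaceDefinitionError,
}

def interpret_sqlstate(code: str, message: str) -> errors.EdgeDBError:
    # Hypothetical helper: unmapped codes fall back to a generic error.
    cls = SQLSTATE_TO_ERROR.get(code, errors.InternalServerError)
    return cls(message)

# E.g. CREATE NAMESPACE over an already-existing Postgres schema:
exc = interpret_sqlstate('42P06', 'schema "foo" already exists')
assert isinstance(exc, errors.DuplicateNameSpaceDefinitionError)

Note that binary.pyx below additionally rewrites such messages ('schema' becomes 'namespace') before they reach the client.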
cdef WriteBuffer make_command_complete_msg(self, capabilities, status) cdef inline reject_headers(self) diff --git a/edb/server/protocol/binary.pyx b/edb/server/protocol/binary.pyx index e1bcd6462d2..a424d2f306b 100644 --- a/edb/server/protocol/binary.pyx +++ b/edb/server/protocol/binary.pyx @@ -67,6 +67,7 @@ from edb.server import defines as edbdef from edb.server.compiler import errormech from edb.server.compiler import enums from edb.server.compiler import sertypes +from edb.server.compiler.compiler import RestoreSchemaInfo from edb.server.protocol import execute from edb.server.protocol cimport frontend from edb.server.pgcon cimport pgcon @@ -117,6 +118,7 @@ DEF QUERY_HEADER_ALLOW_CAPABILITIES = 0xFF04 DEF QUERY_HEADER_EXPLICIT_OBJECTIDS = 0xFF05 DEF QUERY_HEADER_EXPLICIT_MODULE = 0xFF06 DEF QUERY_HEADER_PROHIBIT_MUTATION = 0xFF07 +DEF QUERY_HEADER_EXPLICIT_NS = 0xFF08 DEF SERVER_HEADER_CAPABILITIES = 0x1001 @@ -165,6 +167,7 @@ cdef class EdgeConnection(frontend.FrontendConnection): self.loop = server.get_loop() self._dbview = None self.dbname = None + self.namespace = None self._transport = None self.buffer = ReadBuffer() @@ -475,7 +478,9 @@ cdef class EdgeConnection(frontend.FrontendConnection): f'accept connections' ) - await self._start_connection(database) + namespace = params.get('namespace', edbdef.DEFAULT_NS) + + await self._start_connection(database, namespace) # The user has already been authenticated by other means # (such as the ability to write to a protected socket). @@ -586,7 +591,7 @@ cdef class EdgeConnection(frontend.FrontendConnection): return params - async def _start_connection(self, database: str) -> None: + async def _start_connection(self, database: str, namespace: str) -> None: dbv = await self.server.new_dbview( dbname=database, query_cache=self.query_cache_enabled, @@ -595,6 +600,8 @@ cdef class EdgeConnection(frontend.FrontendConnection): assert type(dbv) is dbview.DatabaseConnectionView self._dbview = dbv self.dbname = database + dbv.valid_namespace(namespace) + self.namespace = namespace self._con_status = EDGECON_STARTED @@ -1004,10 +1011,10 @@ cdef class EdgeConnection(frontend.FrontendConnection): msg.end_message() return msg - cdef WriteBuffer make_state_data_description_msg(self): + cdef WriteBuffer make_state_data_description_msg(self, namespace=None): cdef WriteBuffer msg - type_id, type_data = self.get_dbview().describe_state() + type_id, type_data = self.get_dbview().describe_state(namespace or self.namespace) msg = WriteBuffer.new_message(b's') msg.write_bytes(type_id.bytes) @@ -1091,7 +1098,8 @@ cdef class EdgeConnection(frontend.FrontendConnection): EdgeSeverity.EDGE_SEVERITY_NOTICE, errors.LogMessage.get_code(), 'server restart is required for the configuration ' - 'change to take effect') + 'change to take effect' + ) cdef dbview.QueryRequestInfo parse_execute_request(self): cdef: @@ -1139,7 +1147,7 @@ cdef class EdgeConnection(frontend.FrontendConnection): state_tid = self.buffer.read_bytes(16) state_data = self.buffer.read_len_prefixed_bytes() try: - self.get_dbview().decode_state(state_tid, state_data) + self.get_dbview().decode_state(state_tid, state_data, self.namespace) except errors.StateMismatchError: self.write(self.make_state_data_description_msg()) raise @@ -1154,6 +1162,7 @@ cdef class EdgeConnection(frontend.FrontendConnection): inline_typenames=inline_typenames, inline_objectids=inline_objectids, allow_capabilities=allow_capabilities, + namespace=self.namespace ) async def parse(self): @@ -1264,16 +1273,21 @@ cdef class 
EdgeConnection(frontend.FrontendConnection): elif len(query_unit_group) > 1: await self._execute_script(compiled, args) else: - use_prep = ( - len(query_unit_group) == 1 - and bool(query_unit_group[0].sql_hash) - ) - await self._execute(compiled, args, use_prep) + if len(query_unit_group) == 1 and query_unit_group[0].ns_to_switch is not None: + new_ns = query_unit_group[0].ns_to_switch + self.get_dbview().valid_namespace(new_ns) + self.namespace = query_unit_group[0].ns_to_switch + else: + use_prep = ( + len(query_unit_group) == 1 + and bool(query_unit_group[0].sql_hash) + ) + await self._execute(compiled, args, use_prep) if self._cancelled: raise ConnectionAbortedError - if _dbview.is_state_desc_changed(): + if _dbview.is_state_desc_changed(self.namespace): self.write(self.make_state_data_description_msg()) self.write( self.make_command_complete_msg( @@ -1438,6 +1452,11 @@ cdef class EdgeConnection(frontend.FrontendConnection): await self.recover_from_error() + try: + self.get_dbview().valid_namespace(self.namespace) + except Exception: + self.namespace = edbdef.DEFAULT_NS + else: self.buffer.finish_message() @@ -1575,7 +1594,7 @@ cdef class EdgeConnection(frontend.FrontendConnection): # only use the backend if schema is required if static_exc is errormech.SchemaRequired: exc = errormech.interpret_backend_error( - self.get_dbview().get_schema(), + self.get_dbview().get_schema(self.namespace), exc.fields ) elif isinstance(static_exc, ( @@ -1584,6 +1603,11 @@ cdef class EdgeConnection(frontend.FrontendConnection): tenant_id = self.server.get_tenant_id() message = static_exc.args[0].replace(f'{tenant_id}_', '') exc = type(static_exc)(message) + elif isinstance(static_exc, + (errors.DuplicateNameSpaceDefinitionError, errors.UnknownSchemaError) + ): + message = static_exc.args[0].replace('schema', 'namespace').replace('_edgedbext', '') + exc = type(static_exc)(message) else: exc = static_exc @@ -1793,6 +1817,18 @@ cdef class EdgeConnection(frontend.FrontendConnection): self._write_waiter.set_result(True) async def dump(self): + await self._dump() + + msg_buf = WriteBuffer.new_message(b'C') + msg_buf.write_int16(0) # no headers + msg_buf.write_int64(0) # capabilities + msg_buf.write_len_prefixed_bytes(b'DUMP') + msg_buf.write_bytes(sertypes.NULL_TYPE_ID.bytes) + msg_buf.write_len_prefixed_bytes(b'') + self.write(msg_buf.end_message()) + self.flush() + + async def _dump(self): cdef: WriteBuffer msg_buf dbview.DatabaseConnectionView _dbview @@ -1825,7 +1861,11 @@ cdef class EdgeConnection(frontend.FrontendConnection): # # This guarantees that every pg connection and the compiler work # with the same DB state. 
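Stepping back to the ns_to_switch branch earlier in this file: USE NAMESPACE compiles to a single query unit that carries ns_to_switch and no SQL, so switching is purely a frontend state change on the connection. A condensed sketch of that control flow, illustrative only (apply_ns_switch is not a function in this patch):

def apply_ns_switch(conn, query_unit_group):
    # A USE NAMESPACE unit executes nothing on the backend; the
    # connection validates the target and re-points its namespace.
    unit = query_unit_group[0]
    if len(query_unit_group) == 1 and unit.ns_to_switch is not None:
        conn.get_dbview().valid_namespace(unit.ns_to_switch)  # raises on unknown ns
        conn.namespace = unit.ns_to_switch
        return True   # handled; skip _execute()
    return False

Because is_state_desc_changed(self.namespace) is checked right afterwards, the client receives a fresh state descriptor for the new namespace's schema on the same round trip.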
- user_schema = await server.introspect_user_schema(dbname, pgcon) + + global_schema = await server.introspect_global_schema(pgcon) + db_config = await server.introspect_db_config(pgcon) + dump_protocol = self.max_protocol + await pgcon.sql_execute( b'''START TRANSACTION @@ -1840,27 +1880,10 @@ cdef class EdgeConnection(frontend.FrontendConnection): SET statement_timeout = 0; ''', ) - global_schema = await server.introspect_global_schema(pgcon) - db_config = await server.introspect_db_config(pgcon) - dump_protocol = self.max_protocol - - schema_ddl, schema_dynamic_ddl, schema_ids, blocks, external_ids = ( - await compiler_pool.describe_database_dump( - user_schema, - global_schema, - db_config, - dump_protocol, - ) - ) - if schema_dynamic_ddl: - for query in schema_dynamic_ddl: - result = await pgcon.sql_fetch_val(query.encode('utf-8')) - if result: - schema_ddl += '\n' + result.decode('utf-8') + namespaces = list(_dbview.iter_ns_name()) msg_buf = WriteBuffer.new_message(b'@') - msg_buf.write_int16(4) # number of headers msg_buf.write_int16(DUMP_HEADER_BLOCK_TYPE) msg_buf.write_len_prefixed_bytes(DUMP_HEADER_BLOCK_TYPE_INFO) @@ -1868,48 +1891,73 @@ cdef class EdgeConnection(frontend.FrontendConnection): msg_buf.write_len_prefixed_utf8(str(buildmeta.get_version())) msg_buf.write_int16(DUMP_HEADER_SERVER_TIME) msg_buf.write_len_prefixed_utf8(str(int(time.time()))) - - # adding external ddl & external ids - msg_buf.write_int16(DUMP_EXTERNAL_VIEW) - external_views = await self.external_views(external_ids, pgcon) - msg_buf.write_int32(len(external_views)) - for name, view_sql in external_views: - if isinstance(name, tuple): - msg_buf.write_int16(DUMP_EXTERNAL_KEY_LINK) - msg_buf.write_len_prefixed_utf8(name[0]) - msg_buf.write_len_prefixed_utf8(name[1]) - else: - msg_buf.write_int16(DUMP_EXTERNAL_KEY_OBJ) - msg_buf.write_len_prefixed_utf8(name) - msg_buf.write_len_prefixed_utf8(view_sql) + msg_buf.write_int16(DUMP_NAMESPACE_COUNT) + msg_buf.write_int32(len(namespaces)) msg_buf.write_int16(dump_protocol[0]) msg_buf.write_int16(dump_protocol[1]) - msg_buf.write_len_prefixed_utf8(schema_ddl) + all_blocks = [] - msg_buf.write_int32(len(schema_ids)) - for (tn, td, tid) in schema_ids: - msg_buf.write_len_prefixed_utf8(tn) - msg_buf.write_len_prefixed_utf8(td) - assert len(tid) == 16 - msg_buf.write_bytes(tid) # uuid + for ns in namespaces: + user_schema = await server.introspect_user_schema(dbname, ns, pgcon) - msg_buf.write_int32(len(blocks)) - for block in blocks: - assert len(block.schema_object_id.bytes) == 16 - msg_buf.write_bytes(block.schema_object_id.bytes) # uuid - msg_buf.write_len_prefixed_bytes(block.type_desc) + schema_ddl, schema_dynamic_ddl, schema_ids, blocks, external_ids = ( + await compiler_pool.describe_database_dump( + ns, + user_schema, + global_schema, + db_config, + dump_protocol, + ) + ) + + if schema_dynamic_ddl: + for query in schema_dynamic_ddl: + result = await pgcon.sql_fetch_val(query.encode('utf-8')) + if result: + schema_ddl += '\n' + result.decode('utf-8') + + all_blocks.extend(blocks) - msg_buf.write_int16(len(block.schema_deps)) - for depid in block.schema_deps: - assert len(depid.bytes) == 16 - msg_buf.write_bytes(depid.bytes) # uuid + msg_buf.write_len_prefixed_utf8(ns) + + external_views = await self.external_views(external_ids, pgcon) + msg_buf.write_int32(len(external_views)) + for name, view_sql in external_views: + if isinstance(name, tuple): + msg_buf.write_int16(DUMP_EXTERNAL_KEY_LINK) + msg_buf.write_len_prefixed_utf8(name[0]) + 
msg_buf.write_len_prefixed_utf8(name[1]) + else: + msg_buf.write_int16(DUMP_EXTERNAL_KEY_OBJ) + msg_buf.write_len_prefixed_utf8(name) + msg_buf.write_len_prefixed_utf8(view_sql) + + msg_buf.write_len_prefixed_utf8(schema_ddl) + + msg_buf.write_int32(len(schema_ids)) + for (tn, td, tid) in schema_ids: + msg_buf.write_len_prefixed_utf8(tn) + msg_buf.write_len_prefixed_utf8(td) + assert len(tid) == 16 + msg_buf.write_bytes(tid) # uuid + + msg_buf.write_int32(len(blocks)) + for block in blocks: + assert len(block.schema_object_id.bytes) == 16 + msg_buf.write_bytes(block.schema_object_id.bytes) # uuid + msg_buf.write_len_prefixed_bytes(block.type_desc) + + msg_buf.write_int16(len(block.schema_deps)) + for depid in block.schema_deps: + assert len(depid.bytes) == 16 + msg_buf.write_bytes(depid.bytes) # uuid self._transport.write(memoryview(msg_buf.end_message())) self.flush() - blocks_queue = collections.deque(blocks) + blocks_queue = collections.deque(all_blocks) output_queue = asyncio.Queue(maxsize=2) async with taskgroup.TaskGroup() as g: @@ -1958,15 +2006,6 @@ cdef class EdgeConnection(frontend.FrontendConnection): self._in_dump_restore = False server.release_pgcon(dbname, pgcon) - msg_buf = WriteBuffer.new_message(b'C') - msg_buf.write_int16(0) # no headers - msg_buf.write_int64(0) # capabilities - msg_buf.write_len_prefixed_bytes(b'DUMP') - msg_buf.write_bytes(sertypes.NULL_TYPE_ID.bytes) - msg_buf.write_len_prefixed_bytes(b'') - self.write(msg_buf.end_message()) - self.flush() - async def external_views(self, external_ids: List[Tuple[str, str]], pgcon): views = [] for ext_name, ext_id in external_ids: @@ -2007,7 +2046,75 @@ cdef class EdgeConnection(frontend.FrontendConnection): else: _dbview.on_success(query_unit, {}) + def restore_external_views(self): + external_views = [] + external_view_num = self.buffer.read_int32() + for _ in range(external_view_num): + key_flag = self.buffer.read_int16() + if key_flag == DUMP_EXTERNAL_KEY_LINK: + obj_name = self.buffer.read_len_prefixed_utf8() + link_name = self.buffer.read_len_prefixed_utf8() + sql = self.buffer.read_len_prefixed_utf8() + external_views.append(((obj_name, link_name), sql)) + else: + name = self.buffer.read_len_prefixed_utf8() + sql = self.buffer.read_len_prefixed_utf8() + external_views.append((name, sql)) + return external_views + + def restore_schema_info(self, external_views=None): + if external_views is None: + external_views = self.restore_external_views() + + schema_ddl = self.buffer.read_len_prefixed_bytes() + + ids_num = self.buffer.read_int32() + schema_ids = [] + for _ in range(ids_num): + schema_ids.append( + ( + self.buffer.read_len_prefixed_utf8(), + self.buffer.read_len_prefixed_utf8(), + self.buffer.read_bytes(16), + ) + ) + + block_num = self.buffer.read_int32() + blocks = [] + for _ in range(block_num): + blocks.append( + ( + self.buffer.read_bytes(16), + self.buffer.read_len_prefixed_bytes(), + ) + ) + + # Ignore deps info + for _ in range(self.buffer.read_int16()): + self.buffer.read_bytes(16) + + return RestoreSchemaInfo( + schema_ddl=schema_ddl, schema_ids=schema_ids, blocks=blocks, external_views=external_views + ) + + async def restore(self): + await self._restore() + + state_tid, state_data = self.get_dbview().encode_state() + + msg = WriteBuffer.new_message(b'C') + msg.write_int16(0) # no headers + msg.write_int64(0) # capabilities + msg.write_len_prefixed_bytes(b'RESTORE') + msg.write_bytes(state_tid.bytes) + msg.write_len_prefixed_bytes(state_data) +
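# Note: state_tid/state_data above come from encode_state() taken after + # _restore(), so the client can resynchronize with any namespace changes + # introduced by the restored dump. +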
self.write(msg.end_message()) + self.flush() + + async def _restore(self): cdef: WriteBuffer msg_buf char mtype @@ -2026,33 +2133,26 @@ cdef class EdgeConnection(frontend.FrontendConnection): server = self.server compiler_pool = server.get_compiler_pool() - global_schema = _dbview.get_global_schema() - user_schema = _dbview.get_user_schema() dump_server_ver_str = None headers_num = self.buffer.read_int16() - external_views = [] + ns_count = 0 + schema_info_by_ns: Dict[str, RestoreSchemaInfo] = {} + external_views=[] + default_ns = edbdef.DEFAULT_NS + for _ in range(headers_num): hdrname = self.buffer.read_int16() - if hdrname != DUMP_EXTERNAL_VIEW: + + if hdrname not in [DUMP_EXTERNAL_VIEW, DUMP_NAMESPACE_COUNT]: hdrval = self.buffer.read_len_prefixed_bytes() if hdrname == DUMP_HEADER_SERVER_VER: dump_server_ver_str = hdrval.decode('utf-8') - # getting external ddl & external ids - if hdrname == DUMP_EXTERNAL_VIEW: - external_view_num = self.buffer.read_int32() - for _ in range(external_view_num): - key_flag = self.buffer.read_int16() - if key_flag == DUMP_EXTERNAL_KEY_LINK: - obj_name = self.buffer.read_len_prefixed_utf8() - link_name = self.buffer.read_len_prefixed_utf8() - sql = self.buffer.read_len_prefixed_utf8() - external_views.append(((obj_name, link_name), sql)) - else: - name = self.buffer.read_len_prefixed_utf8() - sql = self.buffer.read_len_prefixed_utf8() - external_views.append((name, sql)) + elif hdrname == DUMP_EXTERNAL_VIEW: + external_views = self.restore_external_views() + elif hdrname == DUMP_NAMESPACE_COUNT: + ns_count = self.buffer.read_int32() proto_major = self.buffer.read_int16() proto_minor = self.buffer.read_int16() @@ -2061,36 +2161,30 @@ cdef class EdgeConnection(frontend.FrontendConnection): raise errors.ProtocolError( f'unsupported dump version {proto_major}.{proto_minor}') - schema_ddl = self.buffer.read_len_prefixed_bytes() - - ids_num = self.buffer.read_int32() - schema_ids = [] - for _ in range(ids_num): - schema_ids.append(( - self.buffer.read_len_prefixed_utf8(), - self.buffer.read_len_prefixed_utf8(), - self.buffer.read_bytes(16), - )) - - block_num = self.buffer.read_int32() - blocks = [] - for _ in range(block_num): - blocks.append(( - self.buffer.read_bytes(16), - self.buffer.read_len_prefixed_bytes(), - )) - - # Ignore deps info - for _ in range(self.buffer.read_int16()): - self.buffer.read_bytes(16) + if ns_count > 0: + for _ in range(ns_count): + ns = self.buffer.read_len_prefixed_utf8() + schema_info_by_ns[ns] = self.restore_schema_info() + else: + schema_info_by_ns[default_ns] = self.restore_schema_info(external_views=external_views) self.buffer.finish_message() dbname = _dbview.dbname pgcon = await server.acquire_pgcon(dbname) self._in_dump_restore = True + + for ns in schema_info_by_ns: + if ns == edbdef.DEFAULT_NS: + continue + await server.create_namespace(pgcon, ns) + await self._execute_utility_stmt(f'CREATE NAMESPACE {ns}', pgcon) + await server.introspect(dbname, ns) + try: - _dbview.decode_state(sertypes.NULL_TYPE_ID.bytes, b'') + all_restore_blocks = [] + all_tables = [] + await self._execute_utility_stmt( 'START TRANSACTION ISOLATION SERIALIZABLE', pgcon, @@ -2105,56 +2199,63 @@ cdef class EdgeConnection(frontend.FrontendConnection): SET statement_timeout = 0; ''', ) - - schema_sql_units, restore_blocks, tables = \ - await compiler_pool.describe_database_restore( - user_schema, - global_schema, - dump_server_ver_str, - schema_ddl, - schema_ids, - blocks, - proto, - dict(external_views) - ) - - for query_unit in schema_sql_units: - 
new_types = None - _dbview.start(query_unit) - - try: - if query_unit.config_ops: - for op in query_unit.config_ops: - if op.scope is config.ConfigScope.INSTANCE: - raise errors.ProtocolError( - 'CONFIGURE INSTANCE cannot be executed' - ' in dump restore' - ) - - if query_unit.sql: - if query_unit.ddl_stmt_id: - ddl_ret = await pgcon.run_ddl(query_unit) - if ddl_ret and ddl_ret['new_types']: - new_types = ddl_ret['new_types'] - if query_unit.schema_refl_sqls: - # no performance optimization - await pgcon.sql_execute(query_unit.schema_refl_sqls) - else: - await pgcon.sql_execute(query_unit.sql) - except Exception: - _dbview.on_error() - raise - else: - _dbview.on_success(query_unit, new_types) + for ns, (schema_ddl, schema_ids, blocks, external_views) in schema_info_by_ns.items(): + logger.info(f"Restoring namespace: {ns}...") + user_schema = _dbview.get_user_schema(ns) + _dbview.decode_state(sertypes.NULL_TYPE_ID.bytes, b'', ns) + + schema_sql_units, restore_blocks, tables = \ + await compiler_pool.describe_database_restore( + ns, + user_schema, + global_schema, + dump_server_ver_str, + schema_ddl, + schema_ids, + blocks, + proto, + dict(external_views) + ) + all_restore_blocks.extend(restore_blocks) + all_tables.extend(tables) + + for query_unit in schema_sql_units: + new_types = None + _dbview.start(query_unit) + + try: + if query_unit.config_ops: + for op in query_unit.config_ops: + if op.scope is config.ConfigScope.INSTANCE: + raise errors.ProtocolError( + 'CONFIGURE INSTANCE cannot be executed' + ' in dump restore' + ) + + if query_unit.sql: + if query_unit.ddl_stmt_id: + ddl_ret = await pgcon.run_ddl(query_unit) + if ddl_ret and ddl_ret['new_types']: + new_types = ddl_ret['new_types'] + if query_unit.schema_refl_sqls: + # no performance optimization + await pgcon.sql_execute(query_unit.schema_refl_sqls) + else: + await pgcon.sql_execute(query_unit.sql) + except Exception: + _dbview.on_error() + raise + else: + _dbview.on_success(query_unit, new_types) restore_blocks = { b.schema_object_id: b - for b in restore_blocks + for b in all_restore_blocks } disable_trigger_q = '' enable_trigger_q = '' - for table in tables: + for table in all_tables: disable_trigger_q += ( f'ALTER TABLE {table} DISABLE TRIGGER ALL;' ) @@ -2230,21 +2331,11 @@ cdef class EdgeConnection(frontend.FrontendConnection): self._in_dump_restore = False server.release_pgcon(dbname, pgcon) - await server.introspect_db(dbname) + await server.introspect(dbname) - if _dbview.is_state_desc_changed(): - self.write(self.make_state_data_description_msg()) - - state_tid, state_data = _dbview.encode_state() - - msg = WriteBuffer.new_message(b'C') - msg.write_int16(0) # no headers - msg.write_int64(0) # capabilities - msg.write_len_prefixed_bytes(b'RESTORE') - msg.write_bytes(state_tid.bytes) - msg.write_len_prefixed_bytes(state_data) - self.write(msg.end_message()) - self.flush() + for ns in schema_info_by_ns: + if _dbview.is_state_desc_changed(ns): + self.write(self.make_state_data_description_msg()) def _build_type_id_map_for_restore_mending(self, restore_block): type_map = {} @@ -2261,6 +2352,7 @@ cdef class EdgeConnection(frontend.FrontendConnection): type_map[desc.schema_type_id] = ( self.get_dbview().resolve_backend_type_id( desc.schema_type_id, + self.namespace ) ) @@ -2368,7 +2460,7 @@ async def run_script( EdgeConnection conn dbview.CompiledQuery compiled conn = new_edge_connection(server) - await conn._start_connection(database) + await conn._start_connection(database, edbdef.DEFAULT_NS) try: compiled = await 
conn.get_dbview().parse( dbview.QueryRequestInfo( diff --git a/edb/server/protocol/binary_v0.pyx b/edb/server/protocol/binary_v0.pyx index 43f6c0d1d88..65d5e89d17b 100644 --- a/edb/server/protocol/binary_v0.pyx +++ b/edb/server/protocol/binary_v0.pyx @@ -20,7 +20,7 @@ import asyncio cdef tuple MIN_LEGACY_PROTOCOL = edbdef.MIN_LEGACY_PROTOCOL -from edb.server import args as srvargs +from edb.server import args as srvargs, defines from edb.server.protocol cimport args_ser from edb.server.protocol import execute @@ -197,7 +197,9 @@ cdef class EdgeConnectionBackwardsCompatible(EdgeConnection): f'accept connections' ) - await self._start_connection(database) + namespace = params.get('namespace', edbdef.DEFAULT_NS) + + await self._start_connection(database, namespace) # The user has already been authenticated by other means # (such as the ability to write to a protected socket). @@ -265,170 +267,7 @@ cdef class EdgeConnectionBackwardsCompatible(EdgeConnection): self.flush() async def legacy_dump(self): - cdef: - WriteBuffer msg_buf - dbview.DatabaseConnectionView _dbview - - self.reject_headers() - self.buffer.finish_message() - - _dbview = self.get_dbview() - if _dbview.txid: - raise errors.ProtocolError( - 'DUMP must not be executed while in transaction' - ) - - server = self.server - compiler_pool = server.get_compiler_pool() - - dbname = _dbview.dbname - pgcon = await server.acquire_pgcon(dbname) - self._in_dump_restore = True - try: - # To avoid having races, we want to: - # - # 1. start a transaction; - # - # 2. in the compiler process we connect to that transaction - # and re-introspect the schema in it. - # - # 3. all dump worker pg connection would work on the same - # connection. - # - # This guarantees that every pg connection and the compiler work - # with the same DB state. - - await pgcon.sql_execute( - b'''START TRANSACTION - ISOLATION LEVEL SERIALIZABLE - READ ONLY - DEFERRABLE; - - -- Disable transaction or query execution timeout - -- limits. Both clients and the server can be slow - -- during the dump/restore process. 
- SET idle_in_transaction_session_timeout = 0; - SET statement_timeout = 0; - ''', - ) - - user_schema = await server.introspect_user_schema(dbname, pgcon) - global_schema = await server.introspect_global_schema(pgcon) - db_config = await server.introspect_db_config(pgcon) - dump_protocol = self.max_protocol - - schema_ddl, schema_dynamic_ddl, schema_ids, blocks, external_ids = ( - await compiler_pool.describe_database_dump( - user_schema, - global_schema, - db_config, - dump_protocol, - ) - ) - - if schema_dynamic_ddl: - for query in schema_dynamic_ddl: - result = await pgcon.sql_fetch_val(query.encode('utf-8')) - if result: - schema_ddl += '\n' + result.decode('utf-8') - - msg_buf = WriteBuffer.new_message(b'@') - - msg_buf.write_int16(4) # number of headers - msg_buf.write_int16(DUMP_HEADER_BLOCK_TYPE) - msg_buf.write_len_prefixed_bytes(DUMP_HEADER_BLOCK_TYPE_INFO) - msg_buf.write_int16(DUMP_HEADER_SERVER_VER) - msg_buf.write_len_prefixed_utf8(str(buildmeta.get_version())) - msg_buf.write_int16(DUMP_HEADER_SERVER_TIME) - msg_buf.write_len_prefixed_utf8(str(int(time.time()))) - - # adding external ddl & external ids - msg_buf.write_int16(DUMP_EXTERNAL_VIEW) - external_views = await self.external_views(external_ids, pgcon) - msg_buf.write_int32(len(external_views)) - for name, view_sql in external_views: - if isinstance(name, tuple): - msg_buf.write_int16(DUMP_EXTERNAL_KEY_LINK) - msg_buf.write_len_prefixed_utf8(name[0]) - msg_buf.write_len_prefixed_utf8(name[1]) - else: - msg_buf.write_int16(DUMP_EXTERNAL_KEY_OBJ) - msg_buf.write_len_prefixed_utf8(name) - msg_buf.write_len_prefixed_utf8(view_sql) - - msg_buf.write_int16(dump_protocol[0]) - msg_buf.write_int16(dump_protocol[1]) - msg_buf.write_len_prefixed_utf8(schema_ddl) - - msg_buf.write_int32(len(schema_ids)) - for (tn, td, tid) in schema_ids: - msg_buf.write_len_prefixed_utf8(tn) - msg_buf.write_len_prefixed_utf8(td) - assert len(tid) == 16 - msg_buf.write_bytes(tid) # uuid - - msg_buf.write_int32(len(blocks)) - for block in blocks: - assert len(block.schema_object_id.bytes) == 16 - msg_buf.write_bytes(block.schema_object_id.bytes) # uuid - msg_buf.write_len_prefixed_bytes(block.type_desc) - - msg_buf.write_int16(len(block.schema_deps)) - for depid in block.schema_deps: - assert len(depid.bytes) == 16 - msg_buf.write_bytes(depid.bytes) # uuid - - self._transport.write(memoryview(msg_buf.end_message())) - self.flush() - - blocks_queue = collections.deque(blocks) - output_queue = asyncio.Queue(maxsize=2) - - async with taskgroup.TaskGroup() as g: - g.create_task(pgcon.dump( - blocks_queue, - output_queue, - DUMP_BLOCK_SIZE, - )) - - nstops = 0 - while True: - if self._cancelled: - raise ConnectionAbortedError - - out = await output_queue.get() - if out is None: - nstops += 1 - if nstops == 1: - # we only have one worker right now - break - else: - block, block_num, data = out - - msg_buf = WriteBuffer.new_message(b'=') - msg_buf.write_int16(4) # number of headers - - msg_buf.write_int16(DUMP_HEADER_BLOCK_TYPE) - msg_buf.write_len_prefixed_bytes( - DUMP_HEADER_BLOCK_TYPE_DATA) - msg_buf.write_int16(DUMP_HEADER_BLOCK_ID) - msg_buf.write_len_prefixed_bytes( - block.schema_object_id.bytes) - msg_buf.write_int16(DUMP_HEADER_BLOCK_NUM) - msg_buf.write_len_prefixed_bytes( - str(block_num).encode()) - msg_buf.write_int16(DUMP_HEADER_BLOCK_DATA) - msg_buf.write_len_prefixed_buffer(data) - - self._transport.write(memoryview(msg_buf.end_message())) - if self._write_waiter: - await self._write_waiter - - await pgcon.sql_execute(b"ROLLBACK;") - - 
finally: - self._in_dump_restore = False - server.release_pgcon(dbname, pgcon) + await self._dump() msg_buf = WriteBuffer.new_message(b'C') msg_buf.write_int16(0) # no headers @@ -437,229 +276,7 @@ cdef class EdgeConnectionBackwardsCompatible(EdgeConnection): self.flush() async def legacy_restore(self): - cdef: - WriteBuffer msg_buf - char mtype - dbview.DatabaseConnectionView _dbview - - _dbview = self.get_dbview() - if _dbview.txid: - raise errors.ProtocolError( - 'RESTORE must not be executed while in transaction' - ) - - self.reject_headers() - self.buffer.read_int16() # discard -j level - - # Now parse the embedded dump header message: - - server = self.server - compiler_pool = server.get_compiler_pool() - - global_schema = _dbview.get_global_schema() - user_schema = _dbview.get_user_schema() - - dump_server_ver_str = None - headers_num = self.buffer.read_int16() - external_views = [] - for _ in range(headers_num): - hdrname = self.buffer.read_int16() - if hdrname != DUMP_EXTERNAL_VIEW: - hdrval = self.buffer.read_len_prefixed_bytes() - if hdrname == DUMP_HEADER_SERVER_VER: - dump_server_ver_str = hdrval.decode('utf-8') - # getting external ddl & external ids - if hdrname == DUMP_EXTERNAL_VIEW: - external_view_num = self.buffer.read_int32() - for _ in range(external_view_num): - key_flag = self.buffer.read_int16() - if key_flag == DUMP_EXTERNAL_KEY_LINK: - obj_name = self.buffer.read_len_prefixed_utf8() - link_name = self.buffer.read_len_prefixed_utf8() - sql = self.buffer.read_len_prefixed_utf8() - external_views.append(((obj_name, link_name), sql)) - else: - name = self.buffer.read_len_prefixed_utf8() - sql = self.buffer.read_len_prefixed_utf8() - external_views.append((name, sql)) - - proto_major = self.buffer.read_int16() - proto_minor = self.buffer.read_int16() - proto = (proto_major, proto_minor) - if proto > DUMP_VER_MAX or proto < DUMP_VER_MIN: - raise errors.ProtocolError( - f'unsupported dump version {proto_major}.{proto_minor}') - - schema_ddl = self.buffer.read_len_prefixed_bytes() - - ids_num = self.buffer.read_int32() - schema_ids = [] - for _ in range(ids_num): - schema_ids.append(( - self.buffer.read_len_prefixed_utf8(), - self.buffer.read_len_prefixed_utf8(), - self.buffer.read_bytes(16), - )) - - block_num = self.buffer.read_int32() - blocks = [] - for _ in range(block_num): - blocks.append(( - self.buffer.read_bytes(16), - self.buffer.read_len_prefixed_bytes(), - )) - - # Ignore deps info - for _ in range(self.buffer.read_int16()): - self.buffer.read_bytes(16) - - self.buffer.finish_message() - dbname = _dbview.dbname - pgcon = await server.acquire_pgcon(dbname) - - self._in_dump_restore = True - try: - _dbview.decode_state(sertypes.NULL_TYPE_ID.bytes, b'') - await self._execute_utility_stmt( - 'START TRANSACTION ISOLATION SERIALIZABLE', - pgcon, - ) - - await pgcon.sql_execute( - b''' - -- Disable transaction or query execution timeout - -- limits. Both clients and the server can be slow - -- during the dump/restore process. 
- SET idle_in_transaction_session_timeout = 0; - SET statement_timeout = 0; - ''', - ) - - schema_sql_units, restore_blocks, tables = \ - await compiler_pool.describe_database_restore( - user_schema, - global_schema, - dump_server_ver_str, - schema_ddl, - schema_ids, - blocks, - proto, - dict(external_views) - ) - - for query_unit in schema_sql_units: - new_types = None - _dbview.start(query_unit) - - try: - if query_unit.config_ops: - for op in query_unit.config_ops: - if op.scope is config.ConfigScope.INSTANCE: - raise errors.ProtocolError( - 'CONFIGURE INSTANCE cannot be executed' - ' in dump restore' - ) - - if query_unit.sql: - if query_unit.ddl_stmt_id: - ddl_ret = await pgcon.run_ddl(query_unit) - if ddl_ret and ddl_ret['new_types']: - new_types = ddl_ret['new_types'] - if query_unit.schema_refl_sqls: - # no performance optimization - await pgcon.sql_execute(query_unit.schema_refl_sqls) - else: - await pgcon.sql_execute(query_unit.sql) - except Exception: - _dbview.on_error() - raise - else: - _dbview.on_success(query_unit, new_types) - - restore_blocks = { - b.schema_object_id: b - for b in restore_blocks - } - - disable_trigger_q = '' - enable_trigger_q = '' - for table in tables: - disable_trigger_q += ( - f'ALTER TABLE {table} DISABLE TRIGGER ALL;' - ) - enable_trigger_q += ( - f'ALTER TABLE {table} ENABLE TRIGGER ALL;' - ) - - await pgcon.sql_execute(disable_trigger_q.encode()) - - # Send "RestoreReadyMessage" - msg = WriteBuffer.new_message(b'+') - msg.write_int16(0) # no headers - msg.write_int16(1) # -j1 - self.write(msg.end_message()) - self.flush() - - while True: - if not self.buffer.take_message(): - # Don't report idling when restoring a dump. - # This is an edge case and the client might be - # legitimately slow. - await self.wait_for_message(report_idling=False) - mtype = self.buffer.get_message_type() - - if mtype == b'=': - block_type = None - block_id = None - block_num = None - block_data = None - - num_headers = self.buffer.read_int16() - for _ in range(num_headers): - header = self.buffer.read_int16() - if header == DUMP_HEADER_BLOCK_TYPE: - block_type = self.buffer.read_len_prefixed_bytes() - elif header == DUMP_HEADER_BLOCK_ID: - block_id = self.buffer.read_len_prefixed_bytes() - block_id = pg_UUID(block_id) - elif header == DUMP_HEADER_BLOCK_NUM: - block_num = self.buffer.read_len_prefixed_bytes() - elif header == DUMP_HEADER_BLOCK_DATA: - block_data = self.buffer.read_len_prefixed_bytes() - - self.buffer.finish_message() - - if (block_type is None or block_id is None - or block_num is None or block_data is None): - raise errors.ProtocolError('incomplete data block') - - restore_block = restore_blocks[block_id] - type_id_map = self._build_type_id_map_for_restore_mending( - restore_block) - await pgcon.restore(restore_block, block_data, type_id_map) - - elif mtype == b'.': - self.buffer.finish_message() - break - - else: - self.fallthrough() - - await pgcon.sql_execute(enable_trigger_q.encode()) - - except Exception: - await pgcon.sql_execute(b'ROLLBACK') - _dbview.abort_tx() - raise - - else: - await self._execute_utility_stmt('COMMIT', pgcon) - - finally: - self._in_dump_restore = False - server.release_pgcon(dbname, pgcon) - - await server.introspect_db(dbname) + await self._restore() msg = WriteBuffer.new_message(b'C') msg.write_int16(0) # no headers @@ -986,7 +603,8 @@ cdef class EdgeConnectionBackwardsCompatible(EdgeConnection): inline_objectids=inline_objectids, allow_capabilities=allow_capabilities, module=module, - read_only=read_only + 
read_only=read_only, + namespace=self.namespace ) return eql, query_req, stmt_name @@ -1085,7 +703,8 @@ cdef class EdgeConnectionBackwardsCompatible(EdgeConnection): protocol_version=self.protocol_version, output_format=FMT_NONE, module=module, - read_only=read_only + read_only=read_only, + namespace=self.namespace, ) return await self.get_dbview()._compile( @@ -1205,7 +824,8 @@ cdef class EdgeConnectionBackwardsCompatible(EdgeConnection): pgcon.PGConnection conn unit_group = await self._legacy_compile_script( - eql, skip_first=skip_first, module=module, read_only=read_only) + eql, skip_first=skip_first, module=module, read_only=read_only + ) if self._cancelled: raise ConnectionAbortedError @@ -1238,12 +858,13 @@ cdef class EdgeConnectionBackwardsCompatible(EdgeConnection): query_unit.create_db_template, _dbview.dbname ) if query_unit.drop_db: - await self.server._on_before_drop_db( - query_unit.drop_db, _dbview.dbname) - + await self.server._on_before_drop_db(query_unit.drop_db, _dbview.dbname) + if query_unit.create_ns: + await self.server.create_namespace(conn, query_unit.create_ns) + if query_unit.drop_ns: + await self.server._on_before_drop_ns(query_unit.drop_ns, self.namespace) if query_unit.system_config: - await execute.execute_system_config( - conn, _dbview, query_unit) + await execute.execute_system_config(conn, _dbview, query_unit) else: if query_unit.sql: if query_unit.ddl_stmt_id: @@ -1251,10 +872,7 @@ cdef class EdgeConnectionBackwardsCompatible(EdgeConnection): if ddl_ret and ddl_ret['new_types']: new_types = ddl_ret['new_types'] elif query_unit.is_transactional: - await conn.sql_execute( - query_unit.sql, - state=state, - ) + await conn.sql_execute(query_unit.sql, state=state) else: i = 0 for sql in query_unit.sql: @@ -1270,18 +888,19 @@ cdef class EdgeConnectionBackwardsCompatible(EdgeConnection): orig_state = None if query_unit.create_db: - await self.server.introspect_db( - query_unit.create_db - ) + await self.server.introspect(query_unit.create_db) + + if query_unit.create_ns: + await self.server.introspect(_dbview.dbname, query_unit.create_ns) if query_unit.drop_db: - self.server._on_after_drop_db( - query_unit.drop_db) + self.server._on_after_drop_db(query_unit.drop_db) + + if query_unit.drop_ns: + self.server._on_after_drop_ns(_dbview.dbname, query_unit.drop_ns) if query_unit.config_ops: - await _dbview.apply_config_ops( - conn, - query_unit.config_ops) + await _dbview.apply_config_ops(conn, query_unit.config_ops) except Exception: _dbview.on_error() if not conn.in_tx() and _dbview.in_tx(): @@ -1293,7 +912,7 @@ cdef class EdgeConnectionBackwardsCompatible(EdgeConnection): else: side_effects = _dbview.on_success(query_unit, new_types) if side_effects: - execute.signal_side_effects(_dbview, side_effects) + execute.signal_side_effects(_dbview, query_unit, side_effects) if not _dbview.in_tx(): state = _dbview.serialize_state() if state is not orig_state: diff --git a/edb/server/protocol/consts.pxi b/edb/server/protocol/consts.pxi index 09cc41e1eee..b02acf02800 100644 --- a/edb/server/protocol/consts.pxi +++ b/edb/server/protocol/consts.pxi @@ -27,6 +27,8 @@ DEF DUMP_HEADER_SERVER_TIME = 102 DEF DUMP_HEADER_SERVER_VER = 103 DEF DUMP_HEADER_BLOCKS_INFO = 104 DEF DUMP_EXTERNAL_VIEW = 105 +DEF DUMP_NAMESPACE_COUNT = 106 +DEF DUMP_NAMESPACE_NAME = 107 DEF DUMP_HEADER_BLOCK_ID = 110 DEF DUMP_HEADER_BLOCK_NUM = 111 diff --git a/edb/server/protocol/edgeql_ext.pyx b/edb/server/protocol/edgeql_ext.pyx index 58881a2060e..5fac18076fd 100644 --- a/edb/server/protocol/edgeql_ext.pyx +++
b/edb/server/protocol/edgeql_ext.pyx @@ -64,6 +64,7 @@ async def handle_request( query = None module = None limit = 0 + namespace = edbdef.DEFAULT_NS try: if request.method == b'POST': @@ -76,6 +77,7 @@ async def handle_request( variables = body.get('variables') globals_ = body.get('globals') module = body.get('module') + namespace = body.get('namespace', edbdef.DEFAULT_NS) limit = body.get('limit', 0) else: raise TypeError( @@ -110,6 +112,12 @@ async def handle_request( if module is not None: module = module[0] + namespace = qs.get('namespace') + if namespace is not None: + namespace = namespace[0] + else: + namespace = edbdef.DEFAULT_NS + limit = qs.get('limit') if limit is not None: limit = int(limit[0]) @@ -130,6 +138,9 @@ async def handle_request( if module is not None and not isinstance(module, str): raise TypeError('"module" must be a str object') + if namespace is not None and not isinstance(namespace, str): + raise TypeError('"namespace" must be a str object') + if limit is not None and not isinstance(limit, int): raise TypeError('"limit" must be an integer object') @@ -147,6 +158,7 @@ async def handle_request( try: result = await execute.parse_execute_json( db, + namespace, query, variables=variables or {}, globals_=globals_ or {}, diff --git a/edb/server/protocol/execute.pyx b/edb/server/protocol/execute.pyx index da9a255fcf8..ea81c08f7da 100644 --- a/edb/server/protocol/execute.pyx +++ b/edb/server/protocol/execute.pyx @@ -82,6 +82,10 @@ async def execute( ) if query_unit.drop_db: await server._on_before_drop_db(query_unit.drop_db, dbv.dbname) + if query_unit.create_ns: + await server.create_namespace(be_conn, query_unit.create_ns) + if query_unit.drop_ns: + await server._on_before_drop_ns(query_unit.drop_ns, query_unit.namespace) if query_unit.system_config: await execute_system_config(be_conn, dbv, query_unit) else: @@ -127,14 +131,21 @@ async def execute( if query_unit.tx_savepoint_declare: dbv.declare_savepoint( - query_unit.sp_name, query_unit.sp_id) + query_unit.namespace, query_unit.sp_name, query_unit.sp_id + ) if query_unit.create_db: - await server.introspect_db(query_unit.create_db) + await server.introspect(query_unit.create_db) + + if query_unit.create_ns: + await server.introspect(dbv.dbname, query_unit.create_ns) if query_unit.drop_db: server._on_after_drop_db(query_unit.drop_db) + if query_unit.drop_ns: + server._on_after_drop_ns(dbv.dbname, query_unit.drop_ns) + if config_ops: await dbv.apply_config_ops(be_conn, config_ops) @@ -150,7 +161,7 @@ async def execute( else: side_effects = dbv.on_success(query_unit, new_types) if side_effects: - signal_side_effects(dbv, side_effects) + signal_side_effects(dbv, query_unit, side_effects) if not dbv.in_tx(): state = dbv.serialize_state() if state is not orig_state: @@ -277,7 +288,7 @@ async def execute_script( and query_unit.user_schema_mutation ): if user_schema_unpacked is None: - base_user_schema = user_schema or dbv.get_user_schema() + base_user_schema = user_schema or dbv.get_user_schema(query_unit.namespace) else: base_user_schema = user_schema_unpacked @@ -312,16 +323,16 @@ async def execute_script( gmut_unpickled = pickle.loads(group_mutation) side_effects = dbv.commit_implicit_tx( - user_schema, user_schema_unpacked, gmut_unpickled, + unit_group.namespace, user_schema, user_schema_unpacked, gmut_unpickled, global_schema, cached_reflection, unit_group.affected_obj_ids ) if side_effects: - signal_side_effects(dbv, side_effects) + signal_side_effects(dbv, query_unit, side_effects) if ( side_effects & 
dbview.SideEffects.SchemaChanges and group_mutation is not None ): - dbv.save_schema_mutation(gmut_unpickled, group_mutation) + dbv.save_schema_mutation(query_unit.namespace, gmut_unpickled, group_mutation) state = dbv.serialize_state() if state is not orig_state: @@ -368,16 +379,38 @@ async def execute_system_config( await conn.sql_execute(b'SELECT pg_reload_conf()') -def signal_side_effects(dbv, side_effects): +def signal_side_effects(dbv, query_unit, side_effects): server = dbv.server if not server._accept_new_tasks: return if side_effects & dbview.SideEffects.SchemaChanges: + if query_unit.create_ns: + namespace = query_unit.create_ns + else: + namespace = query_unit.namespace server.create_task( server._signal_sysevent( 'schema-changes', dbname=dbv.dbname, + namespace=namespace, + drop_ns=query_unit.drop_ns + ), + interruptable=False, + ) + if side_effects & dbview.SideEffects.DatabaseCreate: + server.create_task( + server._signal_sysevent( + 'database-create', + dbname=query_unit.create_db, + ), + interruptable=False, + ) + if side_effects & dbview.SideEffects.DatabaseDrop: + server.create_task( + server._signal_sysevent( + 'database-drop', + dbname=query_unit.drop_db, ), interruptable=False, ) @@ -410,6 +443,7 @@ def signal_side_effects(dbv, side_effects): async def parse_execute( db: dbview.Database, + namespace: str, query: str, *, external_view: Mapping = immutables.Map(), @@ -429,7 +463,8 @@ async def parse_execute( output_format=compiler.OutputFormat.NONE, allow_capabilities=compiler.Capability.MODIFICATIONS | compiler.Capability.DDL, external_view=external_view, - testmode=testmode + testmode=testmode, + namespace=namespace ) compiled = await dbv.parse(query_req) @@ -448,6 +483,7 @@ async def parse_execute( async def parse_execute_json( db: dbview.Database, + namespace: str, query: str, *, variables: Mapping[str, Any] = immutables.Map(), @@ -479,6 +515,7 @@ async def parse_execute_json( allow_capabilities=allow_cap, read_only=read_only, module=module, + namespace=namespace, force_limit=limit ) diff --git a/edb/server/protocol/extern_obj.py b/edb/server/protocol/extern_obj.py index 972a1fb6b81..38fa540202c 100644 --- a/edb/server/protocol/extern_obj.py +++ b/edb/server/protocol/extern_obj.py @@ -25,6 +25,7 @@ from edb import errors +from edb.server import defines from edb.server.protocol import execute from edb.pgsql.types import base_type_name_map_r @@ -305,6 +306,7 @@ def _unknown_path(): try: if request.content_type and b'json' in request.content_type: body = json.loads(request.body) + namespace = body.pop('namespace', defines.DEFAULT_NS) if not isinstance(body, dict): raise TypeError( 'the body of the request must be a JSON object') @@ -334,6 +336,7 @@ def _unknown_path(): try: await execute.parse_execute( db, + namespace, req.to_ddl(), external_view=req.resolve_view(), testmode=bool(request.testmode) diff --git a/edb/server/protocol/infer_expr.py b/edb/server/protocol/infer_expr.py index c990b4509a1..403a5ab4dd7 100644 --- a/edb/server/protocol/infer_expr.py +++ b/edb/server/protocol/infer_expr.py @@ -23,6 +23,7 @@ from edb import errors from edb.common import debug from edb.common import markup +from edb.server import defines async def handle_request( @@ -52,6 +53,7 @@ async def handle_request( 'the body of the request must be a JSON object' ) module = body.get('module') + namespace = body.get('namespace', defines.DEFAULT_NS) objname = body.get('object') expr = body.get('expression') else: @@ -68,6 +70,8 @@ async def handle_request( if not isinstance(module, str): raise 
TypeError("Field 'module' must be a string.") + if not isinstance(namespace, str): + raise TypeError("Field 'namespace' must be a string.") if not isinstance(objname, str): raise TypeError("Field 'object' must be a string.") if not isinstance(expr, str): @@ -88,7 +92,7 @@ async def handle_request( await db.introspection() try: - result = await execute(db, server, module, objname, expr) + result = await execute(db, server, namespace, module, objname, expr) except Exception as ex: if debug.flags.server: markup.dump(ex) @@ -108,13 +112,18 @@ async def handle_request( response.body = json.dumps(result).encode() -async def execute(db, server, module: str, objname: str, expression: str): +async def execute(db, server, namespace: str, module: str, objname: str, expression: str): + if namespace not in db.ns_map: + raise errors.InternalServerError( + f'NameSpace: [{namespace}] not in current db [{db.name}](ver:{db.dbver})' + ) + ns = db.ns_map[namespace] dbver = db.dbver query_cache = server._http_query_cache name_str = f"{module}::{objname}" - cache_key = ('infer_expr', name_str, expression, dbver, module) + cache_key = ('infer_expr', name_str, expression, dbver, module, namespace) entry = query_cache.get(cache_key, None) @@ -124,9 +133,10 @@ async def execute(db, server, module: str, objname: str, expression: str): compiler_pool = server.get_compiler_pool() result = await compiler_pool.infer_expr( db.name, - db.user_schema, + namespace, + ns.user_schema, server.get_global_schema(), - db.reflection_cache, + ns.reflection_cache, db.db_config, server.get_compilation_system_config(), name_str, diff --git a/edb/server/protocol/notebook_ext.pyx b/edb/server/protocol/notebook_ext.pyx index 259cb42edc8..fc1baa034dd 100644 --- a/edb/server/protocol/notebook_ext.pyx +++ b/edb/server/protocol/notebook_ext.pyx @@ -175,7 +175,7 @@ async def execute(db, server, queries: list): dbv = await server.new_dbview( dbname=db.name, query_cache=False, - protocol_version=edbdef.CURRENT_PROTOCOL, + protocol_version=edbdef.CURRENT_PROTOCOL ) bind_data = None diff --git a/edb/server/protocol/schema_info.py b/edb/server/protocol/schema_info.py index 7f3acdce8cd..dfcc0572a38 100644 --- a/edb/server/protocol/schema_info.py +++ b/edb/server/protocol/schema_info.py @@ -25,6 +25,7 @@ from edb import errors from edb.common import debug from edb.common import markup +from edb.server import defines async def handle_request( @@ -41,6 +42,7 @@ async def handle_request( return query_uuid = None + namespace = defines.DEFAULT_NS try: if request.method == b'POST': @@ -50,6 +52,7 @@ async def handle_request( raise TypeError( 'the body of the request must be a JSON object') query_uuid = body.get('uuid') + namespace = body.get('namespace', defines.DEFAULT_NS) else: raise TypeError( 'unable to interpret SchemaInfo POST request') @@ -61,6 +64,11 @@ async def handle_request( query_uuid = qs.get('uuid') if query_uuid is not None: query_uuid = query_uuid[0] + namespace = qs.get('namespace') + if namespace is not None: + namespace = namespace[0] + else: + namespace = defines.DEFAULT_NS else: raise TypeError('expected a GET or a POST request') @@ -80,7 +88,7 @@ async def handle_request( response.content_type = b'application/json' await db.introspection() try: - result = await execute(db, server, query_uuid) + result = await execute(db, server, namespace, query_uuid) except Exception as ex: if debug.flags.server: markup.dump(ex) @@ -101,8 +109,12 @@ async def handle_request( response.body = b'{"data":' + result + b'}' -async def execute(db, server, 
query_uuid: str): - user_schema = db.user_schema +async def execute(db, server, namespace: str, query_uuid: str): + if namespace not in db.ns_map: + raise errors.InternalServerError( + f'NameSpace: [{namespace}] not in current db [{db.name}](ver:{db.dbver})' + ) + user_schema = db.ns_map[namespace].user_schema global_schema = server.get_global_schema() obj_id = uuid.UUID(query_uuid) diff --git a/edb/server/protocol/system_api.py b/edb/server/protocol/system_api.py index 4a590b73894..e674a1f048a 100644 --- a/edb/server/protocol/system_api.py +++ b/edb/server/protocol/system_api.py @@ -25,7 +25,7 @@ from edb.common import debug from edb.common import markup -from edb.server import compiler +from edb.server import compiler, defines from edb.server import defines as edbdef from . import execute # type: ignore @@ -90,6 +90,7 @@ async def handle_status_request( db = server.get_db(dbname=edbdef.EDGEDB_SYSTEM_DB) result = await execute.parse_execute_json( db, + defines.DEFAULT_NS, query="SELECT 'OK'", output_format=compiler.OutputFormat.JSON_ELEMENTS, # Disable query cache because we need to ensure that the compiled diff --git a/edb/server/server.py b/edb/server/server.py index 067ea23cdb7..3328f9e36ab 100644 --- a/edb/server/server.py +++ b/edb/server/server.py @@ -22,6 +22,7 @@ import contextlib import functools import hashlib +import re from typing import * import asyncio @@ -66,7 +67,7 @@ from edb.server import metrics from edb.server import pgcon from edb.server.pgcon import errors as pgcon_errors - +from edb.server.bootstrap import get_tpl_sql, gen_tpl_dump, store_tpl_sql from . import dbview if TYPE_CHECKING: @@ -77,6 +78,29 @@ ADMIN_PLACEHOLDER = "" logger = logging.getLogger('edb.server') log_metrics = logging.getLogger('edb.server.metrics') +_RE_STR_REPL_NS = re.compile( + r'(current_setting\([\']+)?' + r'(edgedb)(\.|instdata|pub\.|pub;|pub\'|ss|std\.|std\'|std;|;)([\"a-z0-9_\-]+)?', +) + + +def repl_ignore_setting(match_obj): + maybe_setting, schema_name, tailing, maybe_domain_name = match_obj.groups() + # skip changing pg_catalog.current_setting('edgedb.xxx') + if maybe_setting: + return maybe_setting + schema_name + tailing + (maybe_domain_name or '') + if maybe_domain_name: + # skip create type/domain in builtin: + # Type ends with '_t' in edgedb + # Type ends with '_t' in edgedbpub + # Domain ends with '_domain' in edgedbstd + if ( + (tailing == '.' and maybe_domain_name.strip('"').endswith('_t')) + or (tailing == 'pub.' and maybe_domain_name.strip('"').endswith('_t')) + or (tailing == 'std.' 
and maybe_domain_name.strip('"').endswith('_domain')) + ): + return schema_name + tailing + maybe_domain_name + return "{ns_prefix}" + schema_name + tailing + (maybe_domain_name or '') class RoleDescriptor(TypedDict): @@ -96,10 +120,11 @@ class Server(ha_base.ClusterProtocol): _roles: Mapping[str, RoleDescriptor] _instance_data: Mapping[str, str] _sys_queries: Mapping[str, str] - _local_intro_query: bytes + _local_intro_query: str _global_intro_query: bytes _report_config_typedesc: bytes _report_config_data: bytes + _ns_tpl_sql: Optional[str] _std_schema: s_schema.Schema _refl_schema: s_schema.Schema @@ -271,13 +296,14 @@ def __init__( self._session_idle_timeout = None self._admin_ui = admin_ui + self._ns_tpl_sql = None @contextlib.asynccontextmanager - async def aquire_distributed_lock(self, dbname, conn): + async def aquire_distributed_lock(self, dbname, namespace, conn): try: - logger.debug(f'Aquiring advisory lock for <{dbname}>') + logger.debug(f'Acquiring advisory lock for <{dbname}({namespace})>') await conn.sql_execute('select pg_advisory_lock(202304241756)'.encode()) - logger.debug(f'Advisory lock for <{dbname}> aquired') + logger.debug(f'Advisory lock for <{dbname}({namespace})> acquired') yield finally: await conn.sql_execute('select pg_advisory_unlock(202304241756)'.encode()) @@ -666,10 +692,14 @@ async def _reintrospect_global_schema(self): self._dbindex.update_global_schema(new_global_schema) self._fetch_roles() - async def introspect_user_schema(self, dbname, conn): - await self._persist_user_schema(dbname, conn) - - json_data = await conn.sql_fetch_val(self._local_intro_query) + async def introspect_user_schema(self, dbname, namespace, conn): + await self._persist_user_schema(dbname, namespace, conn) + if namespace == defines.DEFAULT_NS: + ns_prefix = '' + else: + ns_prefix = namespace + '_' + ns_intro_query = self._local_intro_query.format(ns_prefix=ns_prefix).encode('utf-8') + json_data = await conn.sql_fetch_val(ns_intro_query) base_schema = s_schema.ChainedSchema( self._std_schema, @@ -702,8 +732,8 @@ async def _acquire_intro_pgcon(self, dbname): raise return conn - async def introspect_db(self, dbname): - """Use this method to (re-)introspect a DB. + async def introspect(self, dbname, namespace: Optional[str] = None): + """Use this method to (re-)introspect a DB or namespace. If the DB is already registered in self._dbindex, its schema, config, etc. would simply be updated.
If it's missing @@ -724,64 +754,81 @@ async def introspect_db(self, dbname): return try: - user_schema = await self.introspect_user_schema(dbname, conn) - - reflection_cache_json = await conn.sql_fetch_val( - b''' - SELECT json_agg(o.c) - FROM ( - SELECT - json_build_object( - 'eql_hash', t.eql_hash, - 'argnames', array_to_json(t.argnames) - ) AS c - FROM - ROWS FROM(edgedb._get_cached_reflection()) - AS t(eql_hash text, argnames text[]) - ) AS o; - ''', - ) - - reflection_cache = immutables.Map({ - r['eql_hash']: tuple(r['argnames']) - for r in json.loads(reflection_cache_json) - }) - - backend_ids_json = await conn.sql_fetch_val( - b''' - SELECT - json_object_agg( - "id"::text, - "backend_id" - )::text - FROM - edgedb."_SchemaType" - ''', - ) - backend_ids = json.loads(backend_ids_json) - - db_config = await self.introspect_db_config(conn) + if namespace is None: + ns_query = self.get_sys_query('listns') + json_data = await conn.sql_fetch_val(ns_query) + ns_list = json.loads(json_data) + else: + ns_list = [namespace] - assert self._dbindex is not None - self._dbindex.register_db( - dbname, - user_schema=user_schema, - db_config=db_config, - reflection_cache=reflection_cache, - backend_ids=backend_ids, - ) + for ns in ns_list: + await self._introspect_ns(conn, dbname, ns) finally: self.release_pgcon(dbname, conn) - async def _persist_user_schema(self, dbname, conn): - async with self.aquire_distributed_lock(dbname, conn): + async def _introspect_ns(self, conn, dbname, namespace): + user_schema = await self.introspect_user_schema(dbname, namespace, conn) + if namespace == defines.DEFAULT_NS: + schema_name = 'edgedb' + else: + schema_name = f"{namespace}_edgedb" + reflection_cache_json = await conn.sql_fetch_val( + f''' + SELECT json_agg(o.c) + FROM ( + SELECT + json_build_object( + 'eql_hash', t.eql_hash, + 'argnames', array_to_json(t.argnames) + ) AS c + FROM + ROWS FROM({schema_name}._get_cached_reflection()) + AS t(eql_hash text, argnames text[]) + ) AS o; + '''.encode('utf-8'), + ) + reflection_cache = immutables.Map( + { + r['eql_hash']: tuple(r['argnames']) + for r in json.loads(reflection_cache_json) + } + ) + backend_ids_json = await conn.sql_fetch_val( + f''' + SELECT + json_object_agg( + "id"::text, + "backend_id" + )::text + FROM + {schema_name}."_SchemaType" + '''.encode('utf-8'), + ) + backend_ids = json.loads(backend_ids_json) + db_config = await self.introspect_db_config(conn) + assert self._dbindex is not None + self._dbindex.register_ns( + dbname, + namespace=namespace, + user_schema=user_schema, + db_config=db_config, + reflection_cache=reflection_cache, + backend_ids=backend_ids, + ) + + async def _persist_user_schema(self, dbname, namespace, conn): + if namespace == defines.DEFAULT_NS: + schema_name = 'edgedbinstdata' + else: + schema_name = f"{namespace}_edgedbinstdata" + async with self.aquire_distributed_lock(dbname, namespace, conn): persist_sqls = await conn.sql_fetch( - b'''\ + f'''\ SELECT "version_id", convert_from("sql", 'utf8') from - edgedbinstdata.schema_persist_history + {schema_name}.schema_persist_history WHERE active ORDER BY "timestamp" - ''' + '''.encode('utf-8') ) if not persist_sqls: logger.debug(f"No schema persistence to do.") @@ -789,15 +836,15 @@ async def _persist_user_schema(self, dbname, conn): for vid, sql in persist_sqls: await conn.sql_execute(sql) - logger.debug(f"Finish schema persistence for <{dbname}: {uuid.UUID(bytes=vid)}>") + logger.debug(f"Finish schema persistence for <{dbname}({namespace}): {uuid.UUID(bytes=vid)}>") - async def 
persist_user_schema(self, dbname): + async def persist_user_schema(self, dbname, namespace): conn = await self._acquire_intro_pgcon(dbname) if not conn: return try: - await self._persist_user_schema(dbname, conn) + await self._persist_user_schema(dbname, namespace, conn) finally: self.release_pgcon(dbname, conn) @@ -821,7 +868,7 @@ async def introspect_db_config(self, conn): async def _early_introspect_db(self, dbname): """We need to always introspect the extensions for each database. - Otherwise we won't know to accept connections for graphql or + Otherwise, we won't know to accept connections for graphql or http, for example, until a native connection is made. """ logger.info("introspecting extensions for database '%s'", dbname) @@ -831,25 +878,34 @@ async def _early_introspect_db(self, dbname): return try: - extension_names_json = await conn.sql_fetch_val( - b''' - SELECT json_agg(name) FROM edgedb."_SchemaExtension"; - ''', - ) - if extension_names_json: - extensions = set(json.loads(extension_names_json)) - else: - extensions = set() + ns_query = self.get_sys_query('listns') + json_data = await conn.sql_fetch_val(ns_query) + ns_list = json.loads(json_data) + for ns in ns_list: + if ns == defines.DEFAULT_NS: + schema_name = 'edgedb' + else: + schema_name = f"{ns}_edgedb" + extension_names_json = await conn.sql_fetch_val( + f''' + SELECT json_agg(name) FROM {schema_name}."_SchemaExtension"; + '''.encode('utf-8') + ) + if extension_names_json: + extensions = set(json.loads(extension_names_json)) + else: + extensions = set() - assert self._dbindex is not None - self._dbindex.register_db( - dbname, - user_schema=None, - db_config=None, - reflection_cache=None, - backend_ids=None, - extensions=extensions, - ) + assert self._dbindex is not None + self._dbindex.register_ns( + dbname, + namespace=ns, + user_schema=None, + db_config=None, + reflection_cache=None, + backend_ids=None, + extensions=extensions, + ) finally: self.release_pgcon(dbname, conn) @@ -900,11 +956,16 @@ async def _load_instance_data(self): self._sys_queries = immutables.Map( {k: q.encode() for k, q in queries.items()}) - self._local_intro_query = await syscon.sql_fetch_val(b'''\ + local_intro_query = await syscon.sql_fetch_val(b'''\ SELECT text FROM edgedbinstdata.instdata WHERE key = 'local_intro_query'; ''') + self._local_intro_query = _RE_STR_REPL_NS.sub( + repl_ignore_setting, + local_intro_query.decode('utf-8'), + ) + self._global_intro_query = await syscon.sql_fetch_val(b'''\ SELECT text FROM edgedbinstdata.instdata WHERE key = 'global_intro_query'; @@ -945,6 +1006,15 @@ async def _load_instance_data(self): WHERE key = 'report_configs_typedesc'; ''') + if (tpldbdump := await get_tpl_sql(syscon)) is None: + tpldbdump = await gen_tpl_dump(self._cluster) + await store_tpl_sql(tpldbdump, syscon) + ns_tpl_sql = tpldbdump.decode() + + self._ns_tpl_sql = _RE_STR_REPL_NS.sub(repl_ignore_setting, ns_tpl_sql) + finally: self._release_sys_pgcon() @@ -1057,6 +1127,16 @@ async def _on_before_drop_db( await self._ensure_database_not_connected(dbname) + async def _on_before_drop_ns( + self, + namespace: str, + current_namespace: str + ) -> None: + if current_namespace == namespace: + raise errors.ExecutionError( + f'cannot drop the currently open namespace {namespace!r}' + ) + async def _on_before_create_db_from_template( self, dbname: str, @@ -1090,6 +1170,10 @@ def _on_after_drop_db(self, dbname: str): metrics.background_errors.inc(1.0, 'on_after_drop_db') raise + def
_on_after_drop_ns(self, dbname: str, namespace: str): + assert self._dbindex is not None + self._dbindex.unregister_ns(dbname, namespace) + async def _on_system_config_add(self, setting_name, value): # CONFIGURE INSTANCE INSERT ConfigObject; pass @@ -1239,7 +1323,7 @@ async def _signal_sysevent(self, event, **kwargs): metrics.background_errors.inc(1.0, 'signal_sysevent') raise - def _on_remote_ddl(self, dbname): + def _on_remote_ddl(self, dbname, namespace, drop_ns=None): if not self._accept_new_tasks: return @@ -1247,7 +1331,11 @@ def _on_remote_ddl(self, dbname): # on the __edgedb_sysevent__ channel async def task(): try: - await self.introspect_db(dbname) + if drop_ns: + assert self._dbindex is not None + self._dbindex.unregister_ns(dbname, drop_ns) + else: + await self.introspect(dbname, namespace) except Exception: metrics.background_errors.inc(1.0, 'on_remote_ddl') raise @@ -1262,7 +1350,7 @@ def _on_remote_database_config_change(self, dbname): # on the __edgedb_sysevent__ channel async def task(): try: - await self.introspect_db(dbname) + await self.introspect(dbname) except Exception: metrics.background_errors.inc( 1.0, 'on_remote_database_config_change') @@ -1279,7 +1367,7 @@ def _on_local_database_config_change(self, dbname): # of the DB and update all components of it. async def task(): try: - await self.introspect_db(dbname) + await self.introspect(dbname) except Exception: metrics.background_errors.inc( 1.0, 'on_local_database_config_change') @@ -1893,6 +1981,10 @@ def on_switch_over(self): call_on_switch_over=False ) + async def create_namespace(self, be_conn: pgcon.PGConnection, name: str): + tpl_sql = self._ns_tpl_sql.replace("{ns_prefix}", f"{name}_") + await be_conn.sql_execute(tpl_sql.encode('utf-8')) + def get_active_pgcon_num(self) -> int: return ( self._pg_pool.current_capacity - self._pg_pool.get_pending_conns() @@ -1933,12 +2025,18 @@ def serialize_config(cfg): if db.name in defines.EDGEDB_SPECIAL_DBS: continue + ns = {} + for ns_name, ns_db in db.ns_map.items(): + ns[ns_name] = dict( + namespace=ns_name, + extensions=sorted(ns_db.extensions), + query_cache_size=ns_db.get_query_cache_size() + ) + dbs[db.name] = dict( name=db.name, dbver=db.dbver, config=serialize_config(db.db_config), - extensions=sorted(db.extensions), - query_cache_size=db.get_query_cache_size(), connections=[ dict( in_tx=view.in_tx(), @@ -1948,6 +2046,7 @@ def serialize_config(cfg): ) for view in db.iter_views() ], + namespace=ns ) obj['databases'] = dbs diff --git a/edb/testbase/http.py b/edb/testbase/http.py index 7e0d855826d..78d3b50bd1e 100644 --- a/edb/testbase/http.py +++ b/edb/testbase/http.py @@ -22,6 +22,7 @@ import contextlib import http.client import json +import os import ssl import urllib.parse import urllib.request @@ -34,6 +35,7 @@ from . 
import server from .server import PGConnMixin +from edb.server import defines class StubbornHttpConnection(http.client.HTTPSConnection): @@ -148,7 +150,8 @@ def edgeql_query( variables=None, globals=None, module=None, limit=None ): req_data = { - 'query': query + 'query': query, + 'namespace': os.environ.get('EDGEDB_TEST_CASES_NAMESPACE', defines.DEFAULT_NS) } if use_http_post: @@ -226,7 +229,8 @@ def graphql_query(self, query, *, operation_name=None, variables=None, globals=None): req_data = { - 'query': query + 'query': query, + 'namespace': self.test_ns } if operation_name is not None: @@ -319,7 +323,8 @@ def infer_expr(self, objname, module, expression): req_data = { 'object': objname, 'module': module, - 'expression': expression + 'expression': expression, + 'namespace': os.environ.get('EDGEDB_TEST_CASES_NAMESPACE', defines.DEFAULT_NS) } req = urllib.request.Request(self.http_addr, method='POST') @@ -373,7 +378,7 @@ def get_api_path(cls): def create_type(self, body): req_data = body.as_dict() - + req_data['namespace'] = self.test_ns req = urllib.request.Request(self.http_addr, method='POST') req.add_header('Content-Type', 'application/json') req.add_header('testmode', '1') diff --git a/edb/testbase/server.py b/edb/testbase/server.py index b2fd415e6dc..276be7c0c98 100644 --- a/edb/testbase/server.py +++ b/edb/testbase/server.py @@ -48,7 +48,7 @@ from edb.edgeql import quote as qlquote from edb.pgsql import common as pgcommon from edb.pgsql import params as pgparams -from edb.server import args as edgedb_args, pgcon, pgconnparams +from edb.server import args as edgedb_args, pgcon, pgconnparams, defines from edb.server import cluster as edgedb_cluster from edb.server import defines as edgedb_defines from edb.server import main as edgedb_main @@ -1089,11 +1089,6 @@ async def create_db(): cls.con = cls.loop.run_until_complete(cls.connect(database=dbname)) - if class_set_up != 'skip': - script = cls.get_setup_script() - if script: - cls.loop.run_until_complete(cls.con.execute(script)) - @classmethod def tearDownClass(cls): script = '' @@ -1259,6 +1254,40 @@ def shape(self): class BaseQueryTestCase(DatabaseTestCase): BASE_TEST_CLASS = True + test_ns: str = None + + @classmethod + def setUpClass(cls): + super().setUpClass() + class_set_up = os.environ.get('EDGEDB_TEST_CASES_SET_UP', 'run') + if cls.test_ns is None: + cls.test_ns = os.environ.get('EDGEDB_TEST_CASES_NAMESPACE', defines.DEFAULT_NS) + + if class_set_up != 'skip': + if cls.test_ns != defines.DEFAULT_NS: + cls.loop.run_until_complete( + cls.con.execute(f'CREATE NAMESPACE {cls.test_ns}') + ) + cls.loop.run_until_complete( + cls.con.execute(f'use namespace {cls.test_ns}') + ) + script = cls.get_setup_script() + if script: + cls.loop.run_until_complete(cls.con.execute(script)) + + def setUp(self): + if self.test_ns != defines.DEFAULT_NS: + self.loop.run_until_complete( + self.con.execute(f'use namespace {self.test_ns}') + ) + self.loop.run_until_complete( + self.assert_query_result( + r'show namespace', + [self.test_ns] + ) + ) + + super().setUp() class DDLTestCase(BaseQueryTestCase): diff --git a/tests/test_edgeql_data_migration.py b/tests/test_edgeql_data_migration.py index e980be0eb0b..b1473bfd57b 100644 --- a/tests/test_edgeql_data_migration.py +++ b/tests/test_edgeql_data_migration.py @@ -11224,6 +11224,7 @@ async def test_edgeql_migration_recovery_in_script(self): async def test_edgeql_migration_recovery_commit_fail(self): con2 = await self.connect(database=self.con.dbname) try: + await con2.execute(f'USE NAMESPACE 
diff --git a/tests/test_http_create_type.py b/tests/test_http_create_type.py
index 64648503352..da9b251bae6 100644
--- a/tests/test_http_create_type.py
+++ b/tests/test_http_create_type.py
@@ -4,6 +4,7 @@

 import edgedb

+from edb.server import defines
 from edb.testbase import http as http_tb
 from edb.testbase import server as server_tb
@@ -996,6 +997,8 @@ async def test_dml_reject(self):

 class TestHttpCreateTypeDumpRestore(TestHttpCreateType,
                                     server_tb.StableDumpTestCase):

+    test_ns = defines.DEFAULT_NS
+
     async def prepare(self):
         await self.prepare_external_db(
             dbname=f"{self.get_database_name()}_restored")
diff --git a/tests/test_http_graphql_query.py b/tests/test_http_graphql_query.py
index 7602bb18b2f..18cdb8e22c2 100644
--- a/tests/test_http_graphql_query.py
+++ b/tests/test_http_graphql_query.py
@@ -50,7 +50,8 @@ def test_graphql_http_keepalive_01(self):
                     value
                 }
             }
-            '''
+            ''',
+            'namespace': self.test_ns
         }
         data, headers, status = self.http_con_request(con, req1_data)
         self.assertEqual(status, 200)
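Both HTTP test helpers now ship the target namespace inside the JSON request body rather than in the URL. A sketch of the payload the EdgeQL-over-HTTP helper produces; 'default' here is a placeholder standing in for `defines.DEFAULT_NS`, whose actual value lives in the server package:

import json
import os

def build_edgeql_payload(query: str) -> bytes:
    # mirrors edgeql_query() above: the namespace falls back to the
    # default one when the test-suite variable is unset
    body = {
        'query': query,
        'namespace': os.environ.get('EDGEDB_TEST_CASES_NAMESPACE', 'default'),
    }
    return json.dumps(body).encode('utf-8')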
diff --git a/tests/test_namespace.py b/tests/test_namespace.py
new file mode 100644
index 00000000000..479b71ea664
--- /dev/null
+++ b/tests/test_namespace.py
@@ -0,0 +1,313 @@
+import json
+
+import edgedb
+
+from edb.schema import defines as s_def
+from edb.testbase import server as tb
+
+
+class TestNameSpace(tb.DatabaseTestCase):
+    TRANSACTION_ISOLATION = False
+
+    async def assert_query_in_conn(self, conn, query, exp_result):
+        res = await conn.query_json(query)
+        self.assertEqual(json.loads(res), exp_result)
+
+    async def test_create_drop_namespace(self):
+        await self.con.execute("create namespace ns1;")
+        await self.assert_query_result(
+            r"select sys::NameSpace{name} order by .name",
+            [{'name': s_def.DEFAULT_NS}, {'name': 'ns1'}]
+        )
+        await self.con.execute("drop namespace ns1;")
+        await self.assert_query_result(
+            r"select sys::NameSpace{name} order by .name",
+            [{'name': s_def.DEFAULT_NS}]
+        )
+
+    async def test_create_drop_namespace_invalid(self):
+        await self.con.execute("START TRANSACTION")
+
+        with self.assertRaisesRegex(
+            edgedb.QueryError,
+            'cannot execute CREATE NAMESPACE in a transaction',
+        ):
+            await self.con.execute("create namespace ns;")
+
+        await self.con.execute("ROLLBACK")
+
+        await self.con.execute("START TRANSACTION")
+
+        with self.assertRaisesRegex(
+            edgedb.QueryError,
+            'cannot execute DROP NAMESPACE in a transaction',
+        ):
+            await self.con.execute("drop namespace ns;")
+
+        await self.con.execute("ROLLBACK")
+
+    async def test_create_namespace_invalid(self):
+        with self.assertRaisesRegex(
+            edgedb.SchemaDefinitionError,
+            'NameSpace names can not be started with \'pg_\', '
+            'as such names are reserved for system schemas',
+        ):
+            await self.con.execute("create namespace pg_ns1;")
+
+        with self.assertRaisesRegex(
+            edgedb.SchemaDefinitionError,
+            f'\'{s_def.DEFAULT_NS}\' is reserved as name for '
+            f'default namespace, use others instead.'
+        ):
+            await self.con.execute(f"create namespace {s_def.DEFAULT_NS};")
+
+    async def test_create_namespace_exists(self):
+        await self.con.execute("create namespace ns2;")
+
+        with self.assertRaisesRegex(
+            edgedb.EdgeDBError,
+            'namespace "ns2" already exists',
+        ):
+            await self.con.execute("create namespace ns2;")
+
+        await self.con.execute("drop namespace ns2;")
+
+    async def test_drop_namespace_invalid(self):
+        with self.assertRaisesRegex(
+            edgedb.EdgeDBError,
+            'namespace "ns3" does not exist',
+        ):
+            await self.con.execute("drop namespace ns3;")
+
+        with self.assertRaisesRegex(
+            edgedb.ExecutionError,
+            f"namespace '{s_def.DEFAULT_NS}' cannot be dropped",
+        ):
+            await self.con.execute(f"drop namespace {s_def.DEFAULT_NS};")
+
+        await self.con.execute("create namespace n1;")
+        await self.con.execute("use namespace n1;")
+        with self.assertRaisesRegex(
+            edgedb.ExecutionError,
+            "cannot drop the currently open current_namespace 'n1'",
+        ):
+            await self.con.execute("drop namespace n1;")
+
+    async def test_use_show_namespace(self):
+        await self.con.execute("create namespace temp1;")
+        # new connections start in the default namespace
+        conn1 = await self.connect(database=self.get_database_name())
+        conn2 = await self.connect(database=self.get_database_name())
+        try:
+            await self.assert_query_in_conn(conn1, 'show namespace;', [s_def.DEFAULT_NS])
+            await self.assert_query_in_conn(conn2, 'show namespace;', [s_def.DEFAULT_NS])
+
+            # namespace state is separate between connections
+            await conn1.execute('use namespace temp1;')
+            await self.assert_query_in_conn(conn1, 'show namespace;', ['temp1'])
+            await self.assert_query_in_conn(conn2, 'show namespace;', [s_def.DEFAULT_NS])
+
+            # DDL executed after `use namespace` lands in that namespace only
+            await conn1.execute(
+                'CONFIGURE SESSION SET __internal_testmode := true;'
+                'create type A;'
+                'CONFIGURE SESSION SET __internal_testmode := false;'
+            )
+
+            await self.assert_query_in_conn(
+                conn1,
+                'select count((select schema::ObjectType filter .name="default::A"))',
+                [1]
+            )
+
+            await self.assert_query_in_conn(
+                conn2,
+                'select count((select schema::ObjectType filter .name="default::A"))',
+                [0]
+            )
+
+            # dropping the namespace from another connection invalidates
+            # it here: the next query fails and the session falls back
+            # to the default namespace
+            await conn2.execute('drop namespace temp1;')
+
+            with self.assertRaises(edgedb.QueryError):
+                await conn1.query("select 1")
+
+            await self.assert_query_in_conn(conn1, 'show namespace;', [s_def.DEFAULT_NS])
+        finally:
+            await conn1.aclose()
+            await conn2.aclose()
+
+    async def test_use_namespace_invalid(self):
+        await self.con.execute("create namespace ns4;")
+        try:
+            with self.assertRaises(edgedb.QueryError):
+                await self.con.execute("use namespace ns5;")
+
+            with self.assertRaisesRegex(
+                edgedb.QueryError,
+                'USE NAMESPACE statement is not allowed to be used in script.',
+            ):
+                await self.con.execute("use namespace ns4;select 1;")
+
+            await self.con.execute("START TRANSACTION")
+
+            with self.assertRaisesRegex(
+                edgedb.QueryError,
+                'cannot execute USE NAMESPACE in a transaction',
+            ):
+                await self.con.execute("use namespace ns4;")
+
+            await self.con.execute("ROLLBACK")
+
+        finally:
+            await self.con.execute("drop namespace ns4;")
+
+    async def test_concurrent_schema_version_change_between_ns(self):
+        await self.con.execute("create namespace temp1;")
+        await self.con.execute("create namespace temp2;")
+
+        conn1 = await self.connect(database=self.get_database_name())
+        conn2 = await self.connect(database=self.get_database_name())
+
+        try:
+            await conn1.execute('use namespace temp1;')
+            await conn2.execute('use namespace temp2;')
+
+            await conn1.execute(
+                '''
+                START MIGRATION TO {
+                    module default {
+                        type A5;
+                        type Object5 {
+                            required link a -> default::A5;
+                        };
+                    };
+                };
+                '''
+            )
+            await conn1.execute('POPULATE MIGRATION')
+            async with conn2.transaction():
+                await conn2.execute(
+                    '''
+                    START MIGRATION TO {
+                        module default {
+                            type A6;
+                            type Object6 {
+                                required link a -> default::A6;
+                            };
+                        };
+                    };
+                    POPULATE MIGRATION;
+                    COMMIT MIGRATION;
+                    '''
+                )
+
+            # the two migrations target different namespaces, so both commit
+            await conn1.execute("COMMIT MIGRATION")
+
+            await self.assert_query_in_conn(
+                conn1,
+                r"""
+                SELECT schema::ObjectType {
+                    name,
+                    links: {
+                        target: {name}
+                    }
+                    FILTER .name = 'a'
+                    ORDER BY .name
+                }
+                FILTER .name in {'default::Object5', 'default::Object6'};
+                """,
+                [
+                    {
+                        "name": "default::Object5",
+                        "links": [
+                            {
+                                "target": {
+                                    "name": "default::A5"
+                                }
+                            }
+                        ]
+                    }
+                ],
+            )
+
+            await self.assert_query_in_conn(
+                conn2,
+                r"""
+                SELECT schema::ObjectType {
+                    name,
+                    links: {
+                        target: {name}
+                    }
+                    FILTER .name = 'a'
+                    ORDER BY .name
+                }
+                FILTER .name in {'default::Object5', 'default::Object6'};
+                """,
+                [
+                    {
+                        "name": "default::Object6",
+                        "links": [
+                            {
+                                "target": {
+                                    "name": "default::A6"
+                                }
+                            }
+                        ]
+                    }
+                ],
+            )
+
+        finally:
+            await conn1.aclose()
+            await conn2.aclose()
+            await self.con.execute('drop namespace temp1;')
+            await self.con.execute('drop namespace temp2;')
+
+    async def test_concurrent_schema_version_change_in_one_ns(self):
+        await self.con.execute("create namespace temp1;")
+
+        conn1 = await self.connect(database=self.get_database_name())
+        conn2 = await self.connect(database=self.get_database_name())
+
+        try:
+            await conn1.execute('use namespace temp1;')
+            await conn2.execute('use namespace temp1;')
+
+            await conn1.execute(
+                '''
+                START MIGRATION TO {
+                    module default {
+                        type A5;
+                        type Object5 {
+                            required link a -> default::A5;
+                        };
+                    };
+                };
+                '''
+            )
+            await conn1.execute('POPULATE MIGRATION')
+            async with conn2.transaction():
+                await conn2.execute(
+                    '''
+                    START MIGRATION TO {
+                        module default {
+                            type A6;
+                            type Object6 {
+                                required link a -> default::A6;
+                            };
+                        };
+                    };
+                    POPULATE MIGRATION;
+                    COMMIT MIGRATION;
+                    '''
+                )
+
+            # conn2 bumped the schema version of the same namespace, so
+            # conn1's in-progress migration can no longer commit
+            with self.assertRaises(edgedb.TransactionError):
+                await conn1.execute("COMMIT MIGRATION")
+
+        finally:
+            await conn1.aclose()
+            await conn2.aclose()
+            await self.con.execute('drop namespace temp1;')
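Taken together, these tests pin down the client-visible contract: `use namespace` is per-connection session state, is rejected inside transactions and scripts, and `show namespace` reports the current one; dropping a namespace out from under a connection resets it to the default. A compact sketch of that contract, assuming `con` is a dedicated (non-pooled) connection exposing the usual `execute`/`query_json` methods:

async def namespace_demo(con):
    # DDL: must run outside a transaction
    await con.execute('create namespace app;')
    # session state: must be a standalone statement, not part of a script
    await con.execute('use namespace app;')
    assert await con.query_json('show namespace;') == '["app"]'
    # subsequent schema changes are scoped to "app"; the default
    # namespace keeps its own schema and its own schema version
    await con.execute('create type Ticket;')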
diff --git a/tests/test_server_config.py b/tests/test_server_config.py
index 8793a0693e9..1823e48f308 100644
--- a/tests/test_server_config.py
+++ b/tests/test_server_config.py
@@ -709,57 +709,60 @@ async def test_server_proto_configure_03(self):
         )

     async def test_server_proto_configure_04(self):
-        with self.assertRaisesRegex(
-                edgedb.UnsupportedFeatureError,
-                'CONFIGURE SESSION INSERT is not supported'):
-            await self.con.query('''
-                CONFIGURE SESSION INSERT TestSessionConfig {name := 'test_04'}
-            ''')
-
-        with self.assertRaisesRegex(
-                edgedb.ConfigurationError,
-                "unrecognized configuration object 'Unrecognized'"):
-            await self.con.query('''
-                CONFIGURE INSTANCE INSERT Unrecognized {name := 'test_04'}
-            ''')
-
-        with self.assertRaisesRegex(
-                edgedb.QueryError,
-                "must not have a FILTER clause"):
-            await self.con.query('''
-                CONFIGURE INSTANCE RESET __internal_testvalue FILTER 1 = 1;
-            ''')
-
-        with self.assertRaisesRegex(
-                edgedb.QueryError,
-                "non-constant expression"):
-            await self.con.query('''
-                CONFIGURE SESSION SET __internal_testmode := (random() = 0);
-            ''')
-
-        with self.assertRaisesRegex(
-                edgedb.ConfigurationError,
-                "'Subclass1' cannot be configured directly"):
-            await self.con.query('''
-                CONFIGURE INSTANCE INSERT Subclass1 {
-                    name := 'foo'
-                };
-            ''')
-
-        await self.con.query('''
-            CONFIGURE INSTANCE INSERT TestInstanceConfig {
-                name := 'test_04',
-            }
-        ''')
-
-        with self.assertRaisesRegex(
-                edgedb.ConstraintViolationError,
-                "TestInstanceConfig.name violates exclusivity constraint"):
-            await self.con.query('''
-                CONFIGURE INSTANCE INSERT TestInstanceConfig {
-                    name := 'test_04',
-                }
-            ''')
+        try:
+            with self.assertRaisesRegex(
+                    edgedb.UnsupportedFeatureError,
+                    'CONFIGURE SESSION INSERT is not supported'):
+                await self.con.query('''
+                    CONFIGURE SESSION INSERT TestSessionConfig {name := 'test_04'}
+                ''')
+
+            with self.assertRaisesRegex(
+                    edgedb.ConfigurationError,
+                    "unrecognized configuration object 'Unrecognized'"):
+                await self.con.query('''
+                    CONFIGURE INSTANCE INSERT Unrecognized {name := 'test_04'}
+                ''')
+
+            with self.assertRaisesRegex(
+                    edgedb.QueryError,
+                    "must not have a FILTER clause"):
+                await self.con.query('''
+                    CONFIGURE INSTANCE RESET __internal_testvalue FILTER 1 = 1;
+                ''')
+
+            with self.assertRaisesRegex(
+                    edgedb.QueryError,
+                    "non-constant expression"):
+                await self.con.query('''
+                    CONFIGURE SESSION SET __internal_testmode := (random() = 0);
+                ''')
+
+            with self.assertRaisesRegex(
+                    edgedb.ConfigurationError,
+                    "'Subclass1' cannot be configured directly"):
+                await self.con.query('''
+                    CONFIGURE INSTANCE INSERT Subclass1 {
+                        name := 'foo'
+                    };
+                ''')
+
+            await self.con.query('''
+                CONFIGURE INSTANCE INSERT TestInstanceConfig {
+                    name := 'test_04',
+                }
+            ''')
+
+            with self.assertRaisesRegex(
+                    edgedb.ConstraintViolationError,
+                    "TestInstanceConfig.name violates exclusivity constraint"):
+                await self.con.query('''
+                    CONFIGURE INSTANCE INSERT TestInstanceConfig {
+                        name := 'test_04',
+                    }
+                ''')
+        finally:
+            await self.con.execute('''
+                CONFIGURE INSTANCE RESET TestInstanceConfig;
+            ''')

     async def test_server_proto_configure_05(self):
         await self.con.execute('''
diff --git a/tests/test_server_proto.py b/tests/test_server_proto.py
index 99998d9eddc..570a314607e 100644
--- a/tests/test_server_proto.py
+++ b/tests/test_server_proto.py
@@ -859,6 +859,7 @@ async def test_server_proto_wait_cancel_01(self):
         lock_key = tb.gen_lock_key()

         con2 = await self.connect(database=self.con.dbname)
+        await con2.execute(f"use namespace {self.test_ns}")

         await self.con.query('START TRANSACTION')
         await self.con.query(
@@ -1416,6 +1417,7 @@ async def test_server_proto_tx_02(self):
         # to make sure that Opportunistic Execute isn't used.

         con2 = await self.connect(database=self.con.dbname)
+        await con2.execute(f"use namespace {self.test_ns}")

         try:
             with self.assertRaises(edgedb.DivisionByZeroError):
@@ -1448,6 +1450,7 @@ async def test_server_proto_tx_03(self):
         # to make sure that "ROLLBACK" is cached.

         con2 = await self.connect(database=self.con.dbname)
+        await con2.execute(f"use namespace {self.test_ns}")

         try:
             for _ in range(5):
@@ -1521,6 +1524,7 @@ async def test_server_proto_tx_06(self):
         query = 'SELECT 1'

         con2 = await self.connect(database=self.con.dbname)
+        await con2.execute(f"use namespace {self.test_ns}")

         try:
             for _ in range(5):
                 self.assertEqual(
@@ -1867,6 +1871,7 @@ async def test_server_proto_tx_16(self):
     async def test_server_proto_tx_17(self):
         con1 = self.con
         con2 = await self.connect(database=con1.dbname)
+        await con2.execute(f"use namespace {self.test_ns}")

         tx1 = con1.transaction()
         tx2 = con2.transaction()
@@ -2183,6 +2188,10 @@ class TestServerProtoDDL(tb.DDLTestCase):

     TRANSACTION_ISOLATION = False

+    SETUP = '''
+        CONFIGURE SESSION SET __internal_testmode := true;
+    '''
+
     async def test_server_proto_create_db_01(self):
         if not self.has_create_database:
             self.skipTest('create database is not supported by the backend')
@@ -2224,6 +2233,8 @@ async def test_server_proto_query_cache_invalidate_01(self):
         con1 = self.con
         con2 = await self.connect(database=con1.dbname)
         try:
+            await con2.execute(f"use namespace {self.test_ns}")
+            await con2.execute("CONFIGURE SESSION SET __internal_testmode := true;")
             await con2.execute(f'''
                 CREATE TYPE {typename} {{
                     CREATE REQUIRED PROPERTY prop1 -> std::str;
@@ -2271,6 +2282,8 @@ async def test_server_proto_query_cache_invalidate_02(self):
         con1 = self.con
         con2 = await self.connect(database=con1.dbname)
         try:
+            await con2.execute(f"use namespace {self.test_ns}")
+            await con2.execute("CONFIGURE SESSION SET __internal_testmode := true;")
             await con2.query(f'''
                 CREATE TYPE {typename} {{
                     CREATE REQUIRED PROPERTY prop1 -> std::str;
@@ -2326,6 +2339,8 @@ async def test_server_proto_query_cache_invalidate_03(self):
         con1 = self.con
         con2 = await self.connect(database=con1.dbname)
         try:
+            await con2.execute(f"use namespace {self.test_ns}")
+            await con2.execute("CONFIGURE SESSION SET __internal_testmode := true;")
             await con2.execute(f'''
                 CREATE TYPE {typename} {{
                     CREATE REQUIRED PROPERTY prop1 -> array<std::int64>;
@@ -2364,6 +2379,8 @@ async def test_server_proto_query_cache_invalidate_03(self):
                 await con1.query(query),
                 edgedb.Set([[1, 23]]))

+            await con2.execute("CONFIGURE SESSION SET __internal_testmode := false;")
+
         finally:
             await con2.aclose()
@@ -2373,6 +2390,8 @@ async def test_server_proto_query_cache_invalidate_04(self):
         con1 = self.con
         con2 = await self.connect(database=con1.dbname)
         try:
+            await con2.execute(f"use namespace {self.test_ns}")
+            await con2.execute("CONFIGURE SESSION SET __internal_testmode := true;")
             await con2.execute(f'''
                 CREATE TYPE {typename} {{
                     CREATE REQUIRED PROPERTY prop1 -> std::str;
@@ -2420,6 +2439,8 @@ async def test_server_proto_query_cache_invalidate_05(self):
         con1 = self.con
         con2 = await self.connect(database=con1.dbname)
         try:
+            await con2.execute(f"use namespace {self.test_ns}")
+            await con2.execute("CONFIGURE SESSION SET __internal_testmode := true;")
             await con2.execute(f'''
                 CREATE TYPE {typename} {{
                     CREATE REQUIRED PROPERTY prop1 -> std::str;
@@ -2477,6 +2498,8 @@ async def test_server_proto_query_cache_invalidate_06(self):
         con1 = self.con
         con2 = await self.connect(database=con1.dbname)
         try:
+            await con2.execute(f"use namespace {self.test_ns}")
+            await con2.execute("CONFIGURE SESSION SET __internal_testmode := true;")
             await con2.execute(f'''
                 CREATE TYPE Foo{typename};
@@ -2534,6 +2557,8 @@ async def test_server_proto_query_cache_invalidate_07(self):
         con1 = self.con
         con2 = await self.connect(database=con1.dbname)
         try:
+            await con2.execute(f"use namespace {self.test_ns}")
+            await con2.execute("CONFIGURE SESSION SET __internal_testmode := true;")
             await con2.execute(f'''
                 CREATE TYPE Foo{typename};
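Each secondary connection in the protocol tests has to re-issue `use namespace` by hand, because namespace state never outlives a connection. If this pattern keeps spreading, a small testbase helper would keep it in one place; a sketch, with `connect_ns` as a hypothetical name that is not part of this patch:

class NamespacedConnectMixin:
    # hypothetical helper, not part of this patch
    @classmethod
    async def connect_ns(cls, **kwargs):
        # open a connection and pin it to the class's test namespace,
        # mirroring what BaseQueryTestCase.setUp does for self.con
        con = await cls.connect(**kwargs)
        await con.execute(f'use namespace {cls.test_ns}')
        return con

With that, `con2 = await self.connect_ns(database=self.con.dbname)` would replace the two-step connect-then-switch dance at the top of each test.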