risingwave_frontend/test_utils.rs

// Copyright 2022 RisingWave Labs
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

use std::collections::{BTreeMap, HashMap, HashSet};
use std::io::Write;
use std::net::{IpAddr, Ipv4Addr, SocketAddr};
use std::sync::Arc;
use std::sync::atomic::{AtomicU32, Ordering};

use futures_async_stream::for_await;
use parking_lot::RwLock;
use pgwire::net::{Address, AddressRef};
use pgwire::pg_response::StatementType;
use pgwire::pg_server::{SessionId, SessionManager, UserAuthenticator};
use pgwire::types::Row;
use risingwave_common::catalog::{
    AlterDatabaseParam, DEFAULT_DATABASE_NAME, DEFAULT_SCHEMA_NAME, DEFAULT_SUPER_USER,
    DEFAULT_SUPER_USER_FOR_ADMIN, DEFAULT_SUPER_USER_FOR_ADMIN_ID, DEFAULT_SUPER_USER_ID,
    FunctionId, IndexId, NON_RESERVED_USER_ID, ObjectId, PG_CATALOG_SCHEMA_NAME,
    RW_CATALOG_SCHEMA_NAME, TableId,
};
use risingwave_common::hash::{VirtualNode, VnodeCount, VnodeCountCompat};
use risingwave_common::id::{ConnectionId, JobId, SourceId, SubscriptionId, ViewId, WorkerId};
use risingwave_common::session_config::SessionConfig;
use risingwave_common::system_param::reader::SystemParamsReader;
use risingwave_common::util::cluster_limit::ClusterLimit;
use risingwave_common::util::worker_util::DEFAULT_RESOURCE_GROUP;
use risingwave_hummock_sdk::version::{HummockVersion, HummockVersionDelta};
use risingwave_hummock_sdk::{CompactionGroupId, HummockVersionId, INVALID_VERSION_ID};
use risingwave_pb::backup_service::MetaSnapshotMetadata;
use risingwave_pb::catalog::{
    PbComment, PbDatabase, PbFunction, PbIndex, PbSchema, PbSink, PbSource, PbStreamJobStatus,
    PbSubscription, PbTable, PbView, Table,
};
use risingwave_pb::common::{PbObjectType, WorkerNode};
use risingwave_pb::ddl_service::alter_owner_request::Object;
use risingwave_pb::ddl_service::create_iceberg_table_request::{PbSinkJobInfo, PbTableJobInfo};
use risingwave_pb::ddl_service::{
    DdlProgress, PbTableJobType, TableJobType, alter_name_request, alter_set_schema_request,
    alter_swap_rename_request, create_connection_request, streaming_job_resource_type,
};
use risingwave_pb::hummock::write_limits::WriteLimit;
use risingwave_pb::hummock::{
    BranchedObject, CompactTaskAssignment, CompactTaskProgress, CompactionGroupInfo,
};
use risingwave_pb::id::ActorId;
use risingwave_pb::meta::cancel_creating_jobs_request::PbJobs;
use risingwave_pb::meta::list_actor_splits_response::ActorSplit;
use risingwave_pb::meta::list_actor_states_response::ActorState;
use risingwave_pb::meta::list_cdc_progress_response::PbCdcProgress;
use risingwave_pb::meta::list_iceberg_tables_response::IcebergTable;
use risingwave_pb::meta::list_rate_limits_response::RateLimitInfo;
use risingwave_pb::meta::list_refresh_table_states_response::RefreshTableState;
use risingwave_pb::meta::list_streaming_job_states_response::StreamingJobState;
use risingwave_pb::meta::list_table_fragments_response::TableFragmentInfo;
use risingwave_pb::meta::{
    EventLog, FragmentDistribution, ObjectDependency as PbObjectDependency, PbTableParallelism,
    PbThrottleTarget, RecoveryStatus, RefreshRequest, RefreshResponse, SystemParams,
    list_sink_log_store_tables_response,
};
use risingwave_pb::secret::PbSecretRef;
use risingwave_pb::stream_plan::StreamFragmentGraph;
use risingwave_pb::user::alter_default_privilege_request::Operation as AlterDefaultPrivilegeOperation;
use risingwave_pb::user::update_user_request::UpdateField;
use risingwave_pb::user::{GrantPrivilege, UserInfo};
use risingwave_rpc_client::error::Result as RpcResult;
use tempfile::{Builder, NamedTempFile};

use crate::FrontendOpts;
use crate::catalog::catalog_service::CatalogWriter;
use crate::catalog::root_catalog::Catalog;
use crate::catalog::{DatabaseId, FragmentId, SchemaId, SecretId, SinkId};
use crate::error::{ErrorCode, Result, RwError};
use crate::handler::RwPgResponse;
use crate::meta_client::FrontendMetaClient;
use crate::scheduler::HummockSnapshotManagerRef;
use crate::session::{AuthContext, FrontendEnv, SessionImpl};
use crate::user::UserId;
use crate::user::user_manager::UserInfoManager;
use crate::user::user_service::UserInfoWriter;

/// An embedded frontend for tests: it neither starts a meta node nor serves as a TCP server.
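///
/// A minimal usage sketch (assumes a Tokio test runtime; whether `FrontendOpts`
/// provides a usable `Default` here is an assumption of this example):
///
/// ```ignore
/// let frontend = LocalFrontend::new(FrontendOpts::default()).await;
/// // Run an arbitrary statement against the mocked environment.
/// let _response = frontend.run_sql("CREATE TABLE t (v1 int)").await.unwrap();
/// ```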
pub struct LocalFrontend {
    pub opts: FrontendOpts,
    env: FrontendEnv,
}

impl SessionManager for LocalFrontend {
    type Error = RwError;
    type Session = SessionImpl;

    fn create_dummy_session(
        &self,
        _database_id: DatabaseId,
    ) -> std::result::Result<Arc<Self::Session>, Self::Error> {
        unreachable!()
    }

    fn connect(
        &self,
        _database: &str,
        _user_name: &str,
        _peer_addr: AddressRef,
    ) -> std::result::Result<Arc<Self::Session>, Self::Error> {
        Ok(self.session_ref())
    }

    fn cancel_queries_in_session(&self, _session_id: SessionId) {
        unreachable!()
    }

    fn cancel_creating_jobs_in_session(&self, _session_id: SessionId) {
        unreachable!()
    }

    fn end_session(&self, _session: &Self::Session) {
        unreachable!()
    }
}

impl LocalFrontend {
    #[expect(clippy::unused_async)]
    pub async fn new(opts: FrontendOpts) -> Self {
        let env = FrontendEnv::mock();
        Self { opts, env }
    }

    pub async fn run_sql(
        &self,
        sql: impl Into<String>,
    ) -> std::result::Result<RwPgResponse, Box<dyn std::error::Error + Send + Sync>> {
        let sql: Arc<str> = Arc::from(sql.into());
        self.session_ref().run_statement(sql, vec![]).await
    }

    pub async fn run_sql_with_session(
        &self,
        session_ref: Arc<SessionImpl>,
        sql: impl Into<String>,
    ) -> std::result::Result<RwPgResponse, Box<dyn std::error::Error + Send + Sync>> {
        let sql: Arc<str> = Arc::from(sql.into());
        session_ref.run_statement(sql, vec![]).await
    }

    pub async fn run_user_sql(
        &self,
        sql: impl Into<String>,
        database: String,
        user_name: String,
        user_id: UserId,
    ) -> std::result::Result<RwPgResponse, Box<dyn std::error::Error + Send + Sync>> {
        let sql: Arc<str> = Arc::from(sql.into());
        self.session_user_ref(database, user_name, user_id)
            .run_statement(sql, vec![])
            .await
    }

    pub async fn query_formatted_result(&self, sql: impl Into<String>) -> Vec<String> {
        let mut rsp = self.run_sql(sql).await.unwrap();
        let mut res = vec![];
        #[for_await]
        for row_set in rsp.values_stream() {
            for row in row_set.unwrap() {
                res.push(format!("{:?}", row));
            }
        }
        res
    }

    pub async fn get_explain_output(&self, sql: impl Into<String>) -> String {
        let mut rsp = self.run_sql(sql).await.unwrap();
        assert_eq!(rsp.stmt_type(), StatementType::EXPLAIN);
        let mut res = String::new();
        #[for_await]
        for row_set in rsp.values_stream() {
            for row in row_set.unwrap() {
                let row: Row = row;
                let row = row.values()[0].as_ref().unwrap();
                res += std::str::from_utf8(row).unwrap();
                res += "\n";
            }
        }
        res
    }
    /// Creates a new session as the default super user in the default database.
    pub fn session_ref(&self) -> Arc<SessionImpl> {
        self.session_user_ref(
            DEFAULT_DATABASE_NAME.to_owned(),
            DEFAULT_SUPER_USER.to_owned(),
            DEFAULT_SUPER_USER_ID,
        )
    }

    pub fn session_user_ref(
        &self,
        database: String,
        user_name: String,
        user_id: UserId,
    ) -> Arc<SessionImpl> {
        Arc::new(SessionImpl::new(
            self.env.clone(),
            AuthContext::new(database, user_name, user_id),
            UserAuthenticator::None,
            // The local frontend uses a dummy session id.
            (0, 0),
            Address::Tcp(SocketAddr::new(
                IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
                6666,
            ))
            .into(),
            Default::default(),
        ))
    }
}

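/// Extracts the plan text from an `EXPLAIN` response, panicking if the response
/// comes from any other statement type.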
pub async fn get_explain_output(mut rsp: RwPgResponse) -> String {
    if rsp.stmt_type() != StatementType::EXPLAIN {
        panic!("RESPONSE INVALID: {rsp:?}");
    }
    let mut res = String::new();
    #[for_await]
    for row_set in rsp.values_stream() {
        for row in row_set.unwrap() {
            let row: Row = row;
            let row = row.values()[0].as_ref().unwrap();
            res += std::str::from_utf8(row).unwrap();
            res += "\n";
        }
    }
    res
}

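/// A mock [`CatalogWriter`] that applies DDL directly to an in-memory [`Catalog`]
/// instead of issuing RPCs to a meta node.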
pub struct MockCatalogWriter {
    catalog: Arc<RwLock<Catalog>>,
    id: AtomicU32,
    table_id_to_schema_id: RwLock<HashMap<u32, SchemaId>>,
    schema_id_to_database_id: RwLock<HashMap<SchemaId, DatabaseId>>,
    hummock_snapshot_manager: HummockSnapshotManagerRef,
}

#[async_trait::async_trait]
impl CatalogWriter for MockCatalogWriter {
    async fn create_database(
        &self,
        db_name: &str,
        owner: UserId,
        resource_group: &str,
        barrier_interval_ms: Option<u32>,
        checkpoint_frequency: Option<u64>,
    ) -> Result<()> {
        let database_id = DatabaseId::new(self.gen_id());
        self.catalog.write().create_database(&PbDatabase {
            name: db_name.to_owned(),
            id: database_id,
            owner,
            resource_group: resource_group.to_owned(),
            barrier_interval_ms,
            checkpoint_frequency,
        });
        self.create_schema(database_id, DEFAULT_SCHEMA_NAME, owner)
            .await?;
        self.create_schema(database_id, PG_CATALOG_SCHEMA_NAME, owner)
            .await?;
        self.create_schema(database_id, RW_CATALOG_SCHEMA_NAME, owner)
            .await?;
        Ok(())
    }

    async fn create_schema(
        &self,
        db_id: DatabaseId,
        schema_name: &str,
        owner: UserId,
    ) -> Result<()> {
        let id = self.gen_id();
        self.catalog.write().create_schema(&PbSchema {
            id,
            name: schema_name.to_owned(),
            database_id: db_id,
            owner,
        });
        self.add_schema_id(id, db_id);
        Ok(())
    }

    async fn create_materialized_view(
        &self,
        mut table: PbTable,
        _graph: StreamFragmentGraph,
        dependencies: HashSet<ObjectId>,
        _resource_type: streaming_job_resource_type::ResourceType,
        _if_not_exists: bool,
    ) -> Result<()> {
        table.id = self.gen_id();
        table.stream_job_status = PbStreamJobStatus::Created as _;
        table.maybe_vnode_count = VnodeCount::for_test().to_protobuf();
        self.catalog.write().create_table(&table);
        self.add_table_or_source_id(table.id.as_raw_id(), table.schema_id, table.database_id);
        self.insert_object_dependencies(table.id.as_object_id(), dependencies);
        self.hummock_snapshot_manager.add_table_for_test(table.id);
        Ok(())
    }

    async fn replace_materialized_view(
        &self,
        mut table: PbTable,
        _graph: StreamFragmentGraph,
    ) -> Result<()> {
        table.stream_job_status = PbStreamJobStatus::Created as _;
        assert_eq!(table.vnode_count(), VirtualNode::COUNT_FOR_TEST);
        self.catalog.write().update_table(&table);
        Ok(())
    }

    async fn create_view(&self, mut view: PbView, dependencies: HashSet<ObjectId>) -> Result<()> {
        view.id = self.gen_id();
        self.catalog.write().create_view(&view);
        self.add_table_or_source_id(view.id.as_raw_id(), view.schema_id, view.database_id);
        self.insert_object_dependencies(view.id.as_object_id(), dependencies);
        Ok(())
    }

    async fn create_table(
        &self,
        source: Option<PbSource>,
        mut table: PbTable,
        graph: StreamFragmentGraph,
        _job_type: PbTableJobType,
        if_not_exists: bool,
        dependencies: HashSet<ObjectId>,
    ) -> Result<()> {
        if let Some(source) = source {
            let source_id = self.create_source_inner(source)?;
            table.optional_associated_source_id = Some(source_id.into());
        }
        self.create_materialized_view(
            table,
            graph,
            dependencies,
            streaming_job_resource_type::ResourceType::Regular(true),
            if_not_exists,
        )
        .await?;
        Ok(())
    }

    async fn replace_table(
        &self,
        _source: Option<PbSource>,
        mut table: PbTable,
        _graph: StreamFragmentGraph,
        _job_type: TableJobType,
    ) -> Result<()> {
        table.stream_job_status = PbStreamJobStatus::Created as _;
        assert_eq!(table.vnode_count(), VirtualNode::COUNT_FOR_TEST);
        self.catalog.write().update_table(&table);
        Ok(())
    }

    async fn replace_source(&self, source: PbSource, _graph: StreamFragmentGraph) -> Result<()> {
        self.catalog.write().update_source(&source);
        Ok(())
    }

    async fn create_source(
        &self,
        source: PbSource,
        _graph: Option<StreamFragmentGraph>,
        _if_not_exists: bool,
    ) -> Result<()> {
        self.create_source_inner(source).map(|_| ())
    }

    async fn create_sink(
        &self,
        sink: PbSink,
        graph: StreamFragmentGraph,
        dependencies: HashSet<ObjectId>,
        _if_not_exists: bool,
    ) -> Result<()> {
        let sink_id = self.create_sink_inner(sink, graph)?;
        self.insert_object_dependencies(sink_id.as_object_id(), dependencies);
        Ok(())
    }

    async fn create_subscription(&self, subscription: PbSubscription) -> Result<()> {
        self.create_subscription_inner(subscription)
    }

    async fn create_index(
        &self,
        mut index: PbIndex,
        mut index_table: PbTable,
        _graph: StreamFragmentGraph,
        _if_not_exists: bool,
    ) -> Result<()> {
        index_table.id = self.gen_id();
        index_table.stream_job_status = PbStreamJobStatus::Created as _;
        index_table.maybe_vnode_count = VnodeCount::for_test().to_protobuf();
        self.catalog.write().create_table(&index_table);
        self.add_table_or_index_id(
            index_table.id.as_raw_id(),
            index_table.schema_id,
            index_table.database_id,
        );

        index.id = index_table.id.as_raw_id().into();
        index.index_table_id = index_table.id;
        self.catalog.write().create_index(&index);
        Ok(())
    }

    async fn create_function(&self, _function: PbFunction) -> Result<()> {
        unreachable!()
    }

    async fn create_connection(
        &self,
        _connection_name: String,
        _database_id: DatabaseId,
        _schema_id: SchemaId,
        _owner_id: UserId,
        _connection: create_connection_request::Payload,
    ) -> Result<()> {
        unreachable!()
    }

    async fn create_secret(
        &self,
        _secret_name: String,
        _database_id: DatabaseId,
        _schema_id: SchemaId,
        _owner_id: UserId,
        _payload: Vec<u8>,
    ) -> Result<()> {
        unreachable!()
    }

    async fn comment_on(&self, _comment: PbComment) -> Result<()> {
        unreachable!()
    }

    async fn drop_table(
        &self,
        source_id: Option<SourceId>,
        table_id: TableId,
        cascade: bool,
    ) -> Result<()> {
        if cascade {
            return Err(ErrorCode::NotSupported(
                "drop cascade in MockCatalogWriter is unsupported".to_owned(),
                "use drop instead".to_owned(),
            )
            .into());
        }
        if let Some(source_id) = source_id {
            self.drop_table_or_source_id(source_id.as_raw_id());
        }
        let (database_id, schema_id) = self.drop_table_or_source_id(table_id.as_raw_id());
        let indexes =
            self.catalog
                .read()
                .get_all_indexes_related_to_object(database_id, schema_id, table_id);
        for index in indexes {
            self.drop_index(index.id, cascade).await?;
        }
        self.catalog
            .write()
            .drop_table(database_id, schema_id, table_id);
        if let Some(source_id) = source_id {
            self.catalog
                .write()
                .drop_source(database_id, schema_id, source_id);
        }
        Ok(())
    }

    async fn drop_view(&self, _view_id: ViewId, _cascade: bool) -> Result<()> {
        unreachable!()
    }

    async fn drop_materialized_view(&self, table_id: TableId, cascade: bool) -> Result<()> {
        if cascade {
            return Err(ErrorCode::NotSupported(
                "drop cascade in MockCatalogWriter is unsupported".to_owned(),
                "use drop instead".to_owned(),
            )
            .into());
        }
        let (database_id, schema_id) = self.drop_table_or_source_id(table_id.as_raw_id());
        let indexes =
            self.catalog
                .read()
                .get_all_indexes_related_to_object(database_id, schema_id, table_id);
        for index in indexes {
            self.drop_index(index.id, cascade).await?;
        }
        self.catalog
            .write()
            .drop_table(database_id, schema_id, table_id);
        Ok(())
    }

    async fn drop_source(&self, source_id: SourceId, cascade: bool) -> Result<()> {
        if cascade {
            return Err(ErrorCode::NotSupported(
                "drop cascade in MockCatalogWriter is unsupported".to_owned(),
                "use drop instead".to_owned(),
            )
            .into());
        }
        let (database_id, schema_id) = self.drop_table_or_source_id(source_id.as_raw_id());
        self.catalog
            .write()
            .drop_source(database_id, schema_id, source_id);
        Ok(())
    }

    async fn reset_source(&self, _source_id: SourceId) -> Result<()> {
        Ok(())
    }

    async fn drop_sink(&self, sink_id: SinkId, cascade: bool) -> Result<()> {
        if cascade {
            return Err(ErrorCode::NotSupported(
                "drop cascade in MockCatalogWriter is unsupported".to_owned(),
                "use drop instead".to_owned(),
            )
            .into());
        }
        let (database_id, schema_id) = self.drop_table_or_sink_id(sink_id.as_raw_id());
        self.catalog
            .write()
            .drop_sink(database_id, schema_id, sink_id);
        Ok(())
    }

    async fn drop_subscription(
        &self,
        subscription_id: SubscriptionId,
        cascade: bool,
    ) -> Result<()> {
        if cascade {
            return Err(ErrorCode::NotSupported(
                "drop cascade in MockCatalogWriter is unsupported".to_owned(),
                "use drop instead".to_owned(),
            )
            .into());
        }
        let (database_id, schema_id) =
            self.drop_table_or_subscription_id(subscription_id.as_raw_id());
        self.catalog
            .write()
            .drop_subscription(database_id, schema_id, subscription_id);
        Ok(())
    }

    async fn drop_index(&self, index_id: IndexId, cascade: bool) -> Result<()> {
        if cascade {
            return Err(ErrorCode::NotSupported(
                "drop cascade in MockCatalogWriter is unsupported".to_owned(),
                "use drop instead".to_owned(),
            )
            .into());
        }
        let &schema_id = self
            .table_id_to_schema_id
            .read()
            .get(&index_id.as_raw_id())
            .unwrap();
        let database_id = self.get_database_id_by_schema(schema_id);

        let index = {
            let catalog_reader = self.catalog.read();
            let schema_catalog = catalog_reader
                .get_schema_by_id(database_id, schema_id)
                .unwrap();
            schema_catalog.get_index_by_id(index_id).unwrap().clone()
        };

        let index_table_id = index.index_table().id;
        let (database_id, schema_id) = self.drop_table_or_index_id(index_id.as_raw_id());
        self.catalog
            .write()
            .drop_index(database_id, schema_id, index_id);
        self.catalog
            .write()
            .drop_table(database_id, schema_id, index_table_id);
        Ok(())
    }

    async fn drop_function(&self, _function_id: FunctionId, _cascade: bool) -> Result<()> {
        unreachable!()
    }

    async fn drop_connection(&self, _connection_id: ConnectionId, _cascade: bool) -> Result<()> {
        unreachable!()
    }

    async fn drop_secret(&self, _secret_id: SecretId, _cascade: bool) -> Result<()> {
        unreachable!()
    }

    async fn drop_database(&self, database_id: DatabaseId) -> Result<()> {
        self.catalog.write().drop_database(database_id);
        Ok(())
    }

    async fn drop_schema(&self, schema_id: SchemaId, _cascade: bool) -> Result<()> {
        let database_id = self.drop_schema_id(schema_id);
        self.catalog.write().drop_schema(database_id, schema_id);
        Ok(())
    }

    async fn alter_name(
        &self,
        object_id: alter_name_request::Object,
        object_name: &str,
    ) -> Result<()> {
        match object_id {
            alter_name_request::Object::TableId(table_id) => {
                self.catalog
                    .write()
                    .alter_table_name_by_id(table_id, object_name);
                Ok(())
            }
            _ => {
                unimplemented!()
            }
        }
    }

    async fn alter_source(&self, source: PbSource) -> Result<()> {
        self.catalog.write().update_source(&source);
        Ok(())
    }

    async fn alter_owner(&self, object: Object, owner_id: UserId) -> Result<()> {
        for database in self.catalog.read().iter_databases() {
            for schema in database.iter_schemas() {
                match object {
                    Object::TableId(table_id) => {
                        if let Some(table) = schema.get_created_table_by_id(TableId::from(table_id))
                        {
                            let mut pb_table = table.to_prost();
                            pb_table.owner = owner_id;
                            self.catalog.write().update_table(&pb_table);
                            return Ok(());
                        }
                    }
                    _ => unreachable!(),
                }
            }
        }

        Err(ErrorCode::ItemNotFound(format!("object not found: {:?}", object)).into())
    }

    async fn alter_set_schema(
        &self,
        object: alter_set_schema_request::Object,
        new_schema_id: SchemaId,
    ) -> Result<()> {
        match object {
            alter_set_schema_request::Object::TableId(table_id) => {
                let mut pb_table = {
                    let reader = self.catalog.read();
                    let table = reader.get_any_table_by_id(table_id)?.to_owned();
                    table.to_prost()
                };
                pb_table.schema_id = new_schema_id;
                self.catalog.write().update_table(&pb_table);
                self.table_id_to_schema_id
                    .write()
                    .insert(table_id.as_raw_id(), new_schema_id);
                Ok(())
            }
            _ => unreachable!(),
        }
    }

    async fn alter_parallelism(
        &self,
        _job_id: JobId,
        _parallelism: PbTableParallelism,
        _deferred: bool,
    ) -> Result<()> {
        todo!()
    }

    async fn alter_backfill_parallelism(
        &self,
        _job_id: JobId,
        _parallelism: Option<PbTableParallelism>,
        _deferred: bool,
    ) -> Result<()> {
        todo!()
    }

    async fn alter_config(
        &self,
        _job_id: JobId,
        _entries_to_add: HashMap<String, String>,
        _keys_to_remove: Vec<String>,
    ) -> Result<()> {
        todo!()
    }

    async fn alter_swap_rename(&self, _object: alter_swap_rename_request::Object) -> Result<()> {
        todo!()
    }

    async fn alter_secret(
        &self,
        _secret_id: SecretId,
        _secret_name: String,
        _database_id: DatabaseId,
        _schema_id: SchemaId,
        _owner_id: UserId,
        _payload: Vec<u8>,
    ) -> Result<()> {
        unreachable!()
    }

    async fn alter_resource_group(
        &self,
        _table_id: TableId,
        _resource_group: Option<String>,
        _deferred: bool,
    ) -> Result<()> {
        todo!()
    }

    async fn alter_database_param(
        &self,
        database_id: DatabaseId,
        param: AlterDatabaseParam,
    ) -> Result<()> {
        let mut pb_database = {
            let reader = self.catalog.read();
            let database = reader.get_database_by_id(database_id)?.to_owned();
            database.to_prost()
        };
        match param {
            AlterDatabaseParam::BarrierIntervalMs(interval) => {
                pb_database.barrier_interval_ms = interval;
            }
            AlterDatabaseParam::CheckpointFrequency(frequency) => {
                pb_database.checkpoint_frequency = frequency;
            }
        }
        self.catalog.write().update_database(&pb_database);
        Ok(())
    }

    async fn create_iceberg_table(
        &self,
        _table_job_info: PbTableJobInfo,
        _sink_job_info: PbSinkJobInfo,
        _iceberg_source: PbSource,
        _if_not_exists: bool,
    ) -> Result<()> {
        todo!()
    }

    async fn wait(&self) -> Result<()> {
        Ok(())
    }
}

impl MockCatalogWriter {
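    /// Creates a writer whose catalog is pre-seeded with the default database
    /// (id 0) and its default, `pg_catalog`, and `rw_catalog` schemas (ids 1-3).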
    pub fn new(
        catalog: Arc<RwLock<Catalog>>,
        hummock_snapshot_manager: HummockSnapshotManagerRef,
    ) -> Self {
        catalog.write().create_database(&PbDatabase {
            id: 0.into(),
            name: DEFAULT_DATABASE_NAME.to_owned(),
            owner: DEFAULT_SUPER_USER_ID,
            resource_group: DEFAULT_RESOURCE_GROUP.to_owned(),
            barrier_interval_ms: None,
            checkpoint_frequency: None,
        });
        catalog.write().create_schema(&PbSchema {
            id: 1.into(),
            name: DEFAULT_SCHEMA_NAME.to_owned(),
            database_id: 0.into(),
            owner: DEFAULT_SUPER_USER_ID,
        });
        catalog.write().create_schema(&PbSchema {
            id: 2.into(),
            name: PG_CATALOG_SCHEMA_NAME.to_owned(),
            database_id: 0.into(),
            owner: DEFAULT_SUPER_USER_ID,
        });
        catalog.write().create_schema(&PbSchema {
            id: 3.into(),
            name: RW_CATALOG_SCHEMA_NAME.to_owned(),
            database_id: 0.into(),
            owner: DEFAULT_SUPER_USER_ID,
        });
        let mut map: HashMap<SchemaId, DatabaseId> = HashMap::new();
        map.insert(1_u32.into(), 0_u32.into());
        map.insert(2_u32.into(), 0_u32.into());
        map.insert(3_u32.into(), 0_u32.into());
        Self {
            catalog,
            id: AtomicU32::new(3),
            table_id_to_schema_id: Default::default(),
            schema_id_to_database_id: RwLock::new(map),
            hummock_snapshot_manager,
        }
    }

    fn gen_id<T: From<u32>>(&self) -> T {
        // Id 0 is taken by the `dev` database and schema, so skip past it.
        (self.id.fetch_add(1, Ordering::SeqCst) + 1).into()
    }

    fn add_table_or_source_id(&self, table_id: u32, schema_id: SchemaId, _database_id: DatabaseId) {
        self.table_id_to_schema_id
            .write()
            .insert(table_id, schema_id);
    }

    fn drop_table_or_source_id(&self, table_id: u32) -> (DatabaseId, SchemaId) {
        let schema_id = self
            .table_id_to_schema_id
            .write()
            .remove(&table_id)
            .unwrap();
        (self.get_database_id_by_schema(schema_id), schema_id)
    }

    fn add_table_or_sink_id(&self, table_id: u32, schema_id: SchemaId, _database_id: DatabaseId) {
        self.table_id_to_schema_id
            .write()
            .insert(table_id, schema_id);
    }

    fn add_table_or_subscription_id(
        &self,
        table_id: u32,
        schema_id: SchemaId,
        _database_id: DatabaseId,
    ) {
        self.table_id_to_schema_id
            .write()
            .insert(table_id, schema_id);
    }

    fn add_table_or_index_id(&self, table_id: u32, schema_id: SchemaId, _database_id: DatabaseId) {
        self.table_id_to_schema_id
            .write()
            .insert(table_id, schema_id);
    }

    fn drop_table_or_sink_id(&self, table_id: u32) -> (DatabaseId, SchemaId) {
        let schema_id = self
            .table_id_to_schema_id
            .write()
            .remove(&table_id)
            .unwrap();
        (self.get_database_id_by_schema(schema_id), schema_id)
    }

    fn drop_table_or_subscription_id(&self, table_id: u32) -> (DatabaseId, SchemaId) {
        let schema_id = self
            .table_id_to_schema_id
            .write()
            .remove(&table_id)
            .unwrap();
        (self.get_database_id_by_schema(schema_id), schema_id)
    }

    fn drop_table_or_index_id(&self, table_id: u32) -> (DatabaseId, SchemaId) {
        let schema_id = self
            .table_id_to_schema_id
            .write()
            .remove(&table_id)
            .unwrap();
        (self.get_database_id_by_schema(schema_id), schema_id)
    }

    fn add_schema_id(&self, schema_id: SchemaId, database_id: DatabaseId) {
        self.schema_id_to_database_id
            .write()
            .insert(schema_id, database_id);
    }

    fn drop_schema_id(&self, schema_id: SchemaId) -> DatabaseId {
        self.schema_id_to_database_id
            .write()
            .remove(&schema_id)
            .unwrap()
    }

    fn create_source_inner(&self, mut source: PbSource) -> Result<SourceId> {
        source.id = self.gen_id();
        self.catalog.write().create_source(&source);
        self.add_table_or_source_id(source.id.as_raw_id(), source.schema_id, source.database_id);
        Ok(source.id)
    }

    fn create_sink_inner(&self, mut sink: PbSink, _graph: StreamFragmentGraph) -> Result<SinkId> {
        sink.id = self.gen_id();
        sink.stream_job_status = PbStreamJobStatus::Created as _;
        self.catalog.write().create_sink(&sink);
        self.add_table_or_sink_id(sink.id.as_raw_id(), sink.schema_id, sink.database_id);
        Ok(sink.id)
    }

    fn create_subscription_inner(&self, mut subscription: PbSubscription) -> Result<()> {
        subscription.id = self.gen_id();
        self.catalog.write().create_subscription(&subscription);
        self.add_table_or_subscription_id(
            subscription.id.as_raw_id(),
            subscription.schema_id,
            subscription.database_id,
        );
        Ok(())
    }

    fn get_database_id_by_schema(&self, schema_id: SchemaId) -> DatabaseId {
        *self
            .schema_id_to_database_id
            .read()
            .get(&schema_id)
            .unwrap()
    }

    fn get_object_type(&self, object_id: ObjectId) -> PbObjectType {
        let catalog = self.catalog.read();
        for database in catalog.iter_databases() {
            for schema in database.iter_schemas() {
                if let Some(table) = schema.get_created_table_by_id(object_id.as_table_id()) {
                    return if table.is_mview() {
                        PbObjectType::Mview
                    } else {
                        PbObjectType::Table
                    };
                }
                if schema.get_source_by_id(object_id.as_source_id()).is_some() {
                    return PbObjectType::Source;
                }
                if schema.get_view_by_id(object_id.as_view_id()).is_some() {
                    return PbObjectType::View;
                }
                if schema.get_index_by_id(object_id.as_index_id()).is_some() {
                    return PbObjectType::Index;
                }
            }
        }
        PbObjectType::Unspecified
    }

    fn insert_object_dependencies(&self, object_id: ObjectId, dependencies: HashSet<ObjectId>) {
        if dependencies.is_empty() {
            return;
        }
        let dependencies = dependencies
            .into_iter()
            .map(|referenced_object_id| PbObjectDependency {
                object_id,
                referenced_object_id,
                referenced_object_type: self.get_object_type(referenced_object_id) as i32,
            })
            .collect();
        self.catalog
            .write()
            .insert_object_dependencies(dependencies);
    }
}

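/// A mock [`UserInfoWriter`] backed by an in-memory [`UserInfoManager`];
/// [`MockUserInfoWriter::new`] pre-populates it with the default super users.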
pub struct MockUserInfoWriter {
    id: AtomicU32,
    user_info: Arc<RwLock<UserInfoManager>>,
}

#[async_trait::async_trait]
impl UserInfoWriter for MockUserInfoWriter {
    async fn create_user(&self, mut user: UserInfo) -> Result<()> {
        user.id = self.gen_id().into();
        self.user_info.write().create_user(user);
        Ok(())
    }

    async fn drop_user(&self, id: UserId) -> Result<()> {
        self.user_info.write().drop_user(id);
        Ok(())
    }

    async fn update_user(
        &self,
        update_user: UserInfo,
        update_fields: Vec<UpdateField>,
    ) -> Result<()> {
        let mut lock = self.user_info.write();
        let id = update_user.get_id();
        let Some(old_name) = lock.get_user_name_by_id(id) else {
            return Ok(());
        };
        let mut user_info = lock.get_user_by_name(&old_name).unwrap().to_prost();
        update_fields.into_iter().for_each(|field| match field {
            UpdateField::Super => user_info.is_super = update_user.is_super,
            UpdateField::Login => user_info.can_login = update_user.can_login,
            UpdateField::CreateDb => user_info.can_create_db = update_user.can_create_db,
            UpdateField::CreateUser => user_info.can_create_user = update_user.can_create_user,
            UpdateField::AuthInfo => user_info.auth_info.clone_from(&update_user.auth_info),
            UpdateField::Rename => user_info.name.clone_from(&update_user.name),
            UpdateField::Admin => user_info.is_admin = update_user.is_admin,
            UpdateField::Unspecified => unreachable!(),
        });
        // Apply the merged `user_info` (old info plus the requested fields),
        // not the raw `update_user` payload, so partial updates take effect.
        lock.update_user(user_info);
        Ok(())
    }

    /// `MockUserInfoWriter` does not support expanding privileges via `GrantAllTables`
    /// or `GrantAllSources` when granting privileges to a user.
    async fn grant_privilege(
        &self,
        users: Vec<UserId>,
        privileges: Vec<GrantPrivilege>,
        with_grant_option: bool,
        _grantor: UserId,
    ) -> Result<()> {
        let privileges = privileges
            .into_iter()
            .map(|mut p| {
                p.action_with_opts
                    .iter_mut()
                    .for_each(|ao| ao.with_grant_option = with_grant_option);
                p
            })
            .collect::<Vec<_>>();
        for user_id in users {
            if let Some(u) = self.user_info.write().get_user_mut(user_id) {
                u.extend_privileges(privileges.clone());
            }
        }
        Ok(())
    }

    /// `MockUserInfoWriter` does not support expanding privileges via `RevokeAllTables`
    /// or `RevokeAllSources` when revoking privileges from a user.
    async fn revoke_privilege(
        &self,
        users: Vec<UserId>,
        privileges: Vec<GrantPrivilege>,
        _granted_by: UserId,
        _revoke_by: UserId,
        revoke_grant_option: bool,
        _cascade: bool,
    ) -> Result<()> {
        for user_id in users {
            if let Some(u) = self.user_info.write().get_user_mut(user_id) {
                u.revoke_privileges(privileges.clone(), revoke_grant_option);
            }
        }
        Ok(())
    }

    async fn alter_default_privilege(
        &self,
        _users: Vec<UserId>,
        _database_id: DatabaseId,
        _schemas: Vec<SchemaId>,
        _operation: AlterDefaultPrivilegeOperation,
        _operated_by: UserId,
    ) -> Result<()> {
        todo!()
    }
}

impl MockUserInfoWriter {
    pub fn new(user_info: Arc<RwLock<UserInfoManager>>) -> Self {
        user_info.write().create_user(UserInfo {
            id: DEFAULT_SUPER_USER_ID,
            name: DEFAULT_SUPER_USER.to_owned(),
            is_super: true,
            can_create_db: true,
            can_create_user: true,
            can_login: true,
            ..Default::default()
        });
        user_info.write().create_user(UserInfo {
            id: DEFAULT_SUPER_USER_FOR_ADMIN_ID,
            name: DEFAULT_SUPER_USER_FOR_ADMIN.to_owned(),
            is_super: true,
            can_create_db: true,
            can_create_user: true,
            can_login: true,
            is_admin: true,
            ..Default::default()
        });
        Self {
            user_info,
            id: AtomicU32::new(NON_RESERVED_USER_ID.as_raw_id()),
        }
    }

    fn gen_id(&self) -> u32 {
        self.id.fetch_add(1, Ordering::SeqCst)
    }
}

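/// A mock [`FrontendMetaClient`] that returns canned responses; RPCs that tests
/// are not expected to exercise are left `unimplemented!()`.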
pub struct MockFrontendMetaClient {}

#[async_trait::async_trait]
impl FrontendMetaClient for MockFrontendMetaClient {
    async fn try_unregister(&self) {}

    async fn flush(&self, _database_id: DatabaseId) -> RpcResult<HummockVersionId> {
        Ok(INVALID_VERSION_ID)
    }

    async fn cancel_creating_jobs(&self, _infos: PbJobs) -> RpcResult<Vec<u32>> {
        Ok(vec![])
    }

    async fn list_table_fragments(
        &self,
        _table_ids: &[JobId],
    ) -> RpcResult<HashMap<JobId, TableFragmentInfo>> {
        Ok(HashMap::default())
    }

    async fn list_streaming_job_states(&self) -> RpcResult<Vec<StreamingJobState>> {
        Ok(vec![])
    }

    async fn list_fragment_distribution(
        &self,
        _include_node: bool,
    ) -> RpcResult<Vec<FragmentDistribution>> {
        Ok(vec![])
    }

    async fn list_creating_fragment_distribution(&self) -> RpcResult<Vec<FragmentDistribution>> {
        Ok(vec![])
    }

    async fn list_actor_states(&self) -> RpcResult<Vec<ActorState>> {
        Ok(vec![])
    }

    async fn list_actor_splits(&self) -> RpcResult<Vec<ActorSplit>> {
        Ok(vec![])
    }

    async fn list_meta_snapshots(&self) -> RpcResult<Vec<MetaSnapshotMetadata>> {
        Ok(vec![])
    }

    async fn list_sink_log_store_tables(
        &self,
    ) -> RpcResult<Vec<list_sink_log_store_tables_response::SinkLogStoreTable>> {
        Ok(vec![])
    }

    async fn set_system_param(
        &self,
        _param: String,
        _value: Option<String>,
    ) -> RpcResult<Option<SystemParamsReader>> {
        Ok(Some(SystemParams::default().into()))
    }

    async fn get_session_params(&self) -> RpcResult<SessionConfig> {
        Ok(Default::default())
    }

    async fn set_session_param(&self, _param: String, _value: Option<String>) -> RpcResult<String> {
        Ok("".to_owned())
    }

    async fn get_ddl_progress(&self) -> RpcResult<Vec<DdlProgress>> {
        Ok(vec![])
    }

    async fn get_tables(
        &self,
        _table_ids: Vec<crate::catalog::TableId>,
        _include_dropped_tables: bool,
    ) -> RpcResult<HashMap<crate::catalog::TableId, Table>> {
        Ok(HashMap::new())
    }

    async fn list_hummock_pinned_versions(&self) -> RpcResult<Vec<(WorkerId, HummockVersionId)>> {
        unimplemented!()
    }

    async fn list_refresh_table_states(&self) -> RpcResult<Vec<RefreshTableState>> {
        unimplemented!()
    }

    async fn get_hummock_current_version(&self) -> RpcResult<HummockVersion> {
        Ok(HummockVersion::default())
    }

    async fn get_hummock_checkpoint_version(&self) -> RpcResult<HummockVersion> {
        unimplemented!()
    }

    async fn list_version_deltas(&self) -> RpcResult<Vec<HummockVersionDelta>> {
        unimplemented!()
    }

    async fn list_branched_objects(&self) -> RpcResult<Vec<BranchedObject>> {
        unimplemented!()
    }

    async fn list_hummock_compaction_group_configs(&self) -> RpcResult<Vec<CompactionGroupInfo>> {
        unimplemented!()
    }

    async fn list_hummock_active_write_limits(
        &self,
    ) -> RpcResult<HashMap<CompactionGroupId, WriteLimit>> {
        unimplemented!()
    }

    async fn list_hummock_meta_configs(&self) -> RpcResult<HashMap<String, String>> {
        unimplemented!()
    }

    async fn list_event_log(&self) -> RpcResult<Vec<EventLog>> {
        Ok(vec![])
    }

    async fn list_compact_task_assignment(&self) -> RpcResult<Vec<CompactTaskAssignment>> {
        unimplemented!()
    }

    async fn list_all_nodes(&self) -> RpcResult<Vec<WorkerNode>> {
        Ok(vec![])
    }

    async fn list_compact_task_progress(&self) -> RpcResult<Vec<CompactTaskProgress>> {
        unimplemented!()
    }

    async fn recover(&self) -> RpcResult<()> {
        unimplemented!()
    }

    async fn apply_throttle(
        &self,
        _throttle_target: PbThrottleTarget,
        _throttle_type: risingwave_pb::common::PbThrottleType,
        _id: u32,
        _rate_limit: Option<u32>,
    ) -> RpcResult<()> {
        unimplemented!()
    }

    async fn alter_fragment_parallelism(
        &self,
        _fragment_ids: Vec<FragmentId>,
        _parallelism: Option<PbTableParallelism>,
    ) -> RpcResult<()> {
        unimplemented!()
    }

    async fn get_cluster_recovery_status(&self) -> RpcResult<RecoveryStatus> {
        Ok(RecoveryStatus::StatusRunning)
    }

    async fn get_cluster_limits(&self) -> RpcResult<Vec<ClusterLimit>> {
        Ok(vec![])
    }

    async fn list_rate_limits(&self) -> RpcResult<Vec<RateLimitInfo>> {
        Ok(vec![])
    }

    async fn list_cdc_progress(&self) -> RpcResult<HashMap<JobId, PbCdcProgress>> {
        Ok(HashMap::default())
    }

    async fn get_meta_store_endpoint(&self) -> RpcResult<String> {
        unimplemented!()
    }

    async fn alter_sink_props(
        &self,
        _sink_id: SinkId,
        _changed_props: BTreeMap<String, String>,
        _changed_secret_refs: BTreeMap<String, PbSecretRef>,
        _connector_conn_ref: Option<ConnectionId>,
    ) -> RpcResult<()> {
        unimplemented!()
    }

    async fn alter_iceberg_table_props(
        &self,
        _table_id: TableId,
        _sink_id: SinkId,
        _source_id: SourceId,
        _changed_props: BTreeMap<String, String>,
        _changed_secret_refs: BTreeMap<String, PbSecretRef>,
        _connector_conn_ref: Option<ConnectionId>,
    ) -> RpcResult<()> {
        unimplemented!()
    }

    async fn alter_source_connector_props(
        &self,
        _source_id: SourceId,
        _changed_props: BTreeMap<String, String>,
        _changed_secret_refs: BTreeMap<String, PbSecretRef>,
        _connector_conn_ref: Option<ConnectionId>,
    ) -> RpcResult<()> {
        unimplemented!()
    }

    async fn alter_connection_connector_props(
        &self,
        _connection_id: u32,
        _changed_props: BTreeMap<String, String>,
        _changed_secret_refs: BTreeMap<String, PbSecretRef>,
    ) -> RpcResult<()> {
        Ok(())
    }

    async fn list_hosted_iceberg_tables(&self) -> RpcResult<Vec<IcebergTable>> {
        unimplemented!()
    }

    async fn get_fragment_by_id(
        &self,
        _fragment_id: FragmentId,
    ) -> RpcResult<Option<FragmentDistribution>> {
        unimplemented!()
    }

    async fn get_fragment_vnodes(
        &self,
        _fragment_id: FragmentId,
    ) -> RpcResult<Vec<(ActorId, Vec<u32>)>> {
        unimplemented!()
    }

    async fn get_actor_vnodes(&self, _actor_id: ActorId) -> RpcResult<Vec<u32>> {
        unimplemented!()
    }

    fn worker_id(&self) -> WorkerId {
        0.into()
    }

    async fn set_sync_log_store_aligned(&self, _job_id: JobId, _aligned: bool) -> RpcResult<()> {
        Ok(())
    }

    async fn compact_iceberg_table(&self, _sink_id: SinkId) -> RpcResult<u64> {
        Ok(1)
    }

    async fn expire_iceberg_table_snapshots(&self, _sink_id: SinkId) -> RpcResult<()> {
        Ok(())
    }

    async fn refresh(&self, _request: RefreshRequest) -> RpcResult<RefreshResponse> {
        Ok(RefreshResponse { status: None })
    }

    fn cluster_id(&self) -> &str {
        "test-cluster-uuid"
    }

    async fn list_unmigrated_tables(&self) -> RpcResult<HashMap<crate::catalog::TableId, String>> {
        unimplemented!()
    }
}

#[cfg(test)]
pub static PROTO_FILE_DATA: &str = r#"
    syntax = "proto3";
    package test;
    message TestRecord {
      int32 id = 1;
      Country country = 3;
      int64 zipcode = 4;
      float rate = 5;
    }
    message TestRecordAlterType {
      string id = 1;
      Country country = 3;
      int32 zipcode = 4;
      float rate = 5;
    }
    message TestRecordExt {
      int32 id = 1;
      Country country = 3;
      int64 zipcode = 4;
      float rate = 5;
      string name = 6;
    }
    message Country {
      string address = 1;
      City city = 2;
      string zipcode = 3;
    }
    message City {
      string address = 1;
      string zipcode = 2;
    }"#;

/// Compiles `proto_data` with `protoc` and returns the resulting descriptor file.
/// (The returned `NamedTempFile` deletes itself when it goes out of scope.)
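///
/// A usage sketch (assumes `protoc` is available on `PATH`; note that
/// `PROTO_FILE_DATA` is only compiled under `#[cfg(test)]`):
///
/// ```ignore
/// let file = create_proto_file(PROTO_FILE_DATA);
/// // Keep `file` alive while the descriptor is in use.
/// let descriptor_path = file.path().to_string_lossy().into_owned();
/// ```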
pub fn create_proto_file(proto_data: &str) -> NamedTempFile {
    let in_file = Builder::new()
        .prefix("temp")
        .suffix(".proto")
        .rand_bytes(8)
        .tempfile()
        .unwrap();

    let out_file = Builder::new()
        .prefix("temp")
        .suffix(".pb")
        .rand_bytes(8)
        .tempfile()
        .unwrap();

    let mut file = in_file.as_file();
    file.write_all(proto_data.as_ref())
        .expect("writing binary to test file");
    file.flush().expect("flush temp file failed");
    let include_path = in_file
        .path()
        .parent()
        .unwrap()
        .to_string_lossy()
        .into_owned();
    let out_path = out_file.path().to_string_lossy().into_owned();
    let in_path = in_file.path().to_string_lossy().into_owned();
    let mut compile = std::process::Command::new("protoc");

    let out = compile
        .arg("--include_imports")
        .arg("-I")
        .arg(include_path)
        .arg(format!("--descriptor_set_out={}", out_path))
        .arg(in_path)
        .output()
        .expect("failed to compile proto");
    if !out.status.success() {
        panic!("failed to compile proto\noutput: {:?}", out);
    }
    out_file
}