mas_storage_pg/personal/session.rs

1// Copyright 2025 New Vector Ltd.
2//
3// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
4// Please see LICENSE files in the repository root for full details.
5
6use std::net::IpAddr;
7
8use async_trait::async_trait;
9use chrono::{DateTime, Utc};
10use mas_data_model::{
11    Clock, User,
12    personal::{
13        PersonalAccessToken,
14        session::{PersonalSession, PersonalSessionOwner, SessionState},
15    },
16};
17use mas_storage::{
18    Page, Pagination,
19    pagination::Node,
20    personal::{PersonalSessionFilter, PersonalSessionRepository, PersonalSessionState},
21};
22use oauth2_types::scope::Scope;
23use opentelemetry_semantic_conventions::trace::DB_QUERY_TEXT;
24use rand::RngCore;
25use sea_query::{
26    Cond, Condition, Expr, PgFunc, PostgresQueryBuilder, Query, SimpleExpr, enum_def,
27    extension::postgres::PgExpr as _,
28};
29use sea_query_binder::SqlxBinder as _;
30use sqlx::PgConnection;
31use tracing::{Instrument as _, info_span};
32use ulid::Ulid;
33use uuid::Uuid;
34
35use crate::{
36    DatabaseError,
37    errors::DatabaseInconsistencyError,
38    filter::{Filter, StatementExt as _},
39    iden::{PersonalAccessTokens, PersonalSessions},
40    pagination::QueryBuilderExt as _,
41    tracing::ExecuteExt as _,
42};
43
44/// An implementation of [`PersonalSessionRepository`] for a PostgreSQL
45/// connection
46pub struct PgPersonalSessionRepository<'c> {
47    conn: &'c mut PgConnection,
48}
49
50impl<'c> PgPersonalSessionRepository<'c> {
51    /// Create a new [`PgPersonalSessionRepository`] from an active PostgreSQL
52    /// connection
53    pub fn new(conn: &'c mut PgConnection) -> Self {
54        Self { conn }
55    }
56}
57
58#[derive(sqlx::FromRow)]
59#[enum_def]
60struct PersonalSessionLookup {
61    personal_session_id: Uuid,
62    owner_user_id: Option<Uuid>,
63    owner_oauth2_client_id: Option<Uuid>,
64    actor_user_id: Uuid,
65    human_name: String,
66    scope_list: Vec<String>,
67    created_at: DateTime<Utc>,
68    revoked_at: Option<DateTime<Utc>>,
69    last_active_at: Option<DateTime<Utc>>,
70    last_active_ip: Option<IpAddr>,
71}
72
73impl Node<Ulid> for PersonalSessionLookup {
74    fn cursor(&self) -> Ulid {
75        self.personal_session_id.into()
76    }
77}
78
79impl TryFrom<PersonalSessionLookup> for PersonalSession {
80    type Error = DatabaseInconsistencyError;
81
82    fn try_from(value: PersonalSessionLookup) -> Result<Self, Self::Error> {
83        let id = Ulid::from(value.personal_session_id);
84        let scope: Result<Scope, _> = value.scope_list.iter().map(|s| s.parse()).collect();
85        let scope = scope.map_err(|e| {
86            DatabaseInconsistencyError::on("personal_sessions")
87                .column("scope")
88                .row(id)
89                .source(e)
90        })?;
91
92        let state = match value.revoked_at {
93            None => SessionState::Valid,
94            Some(revoked_at) => SessionState::Revoked { revoked_at },
95        };
96
97        let owner = match (value.owner_user_id, value.owner_oauth2_client_id) {
98            (Some(owner_user_id), None) => PersonalSessionOwner::User(Ulid::from(owner_user_id)),
99            (None, Some(owner_oauth2_client_id)) => {
100                PersonalSessionOwner::OAuth2Client(Ulid::from(owner_oauth2_client_id))
101            }
102            _ => {
103                // should be impossible (CHECK constraint in Postgres prevents it)
104                return Err(DatabaseInconsistencyError::on("personal_sessions")
105                    .column("owner_user_id, owner_oauth2_client_id")
106                    .row(id));
107            }
108        };
109
110        Ok(PersonalSession {
111            id,
112            state,
113            owner,
114            actor_user_id: Ulid::from(value.actor_user_id),
115            human_name: value.human_name,
116            scope,
117            created_at: value.created_at,
118            last_active_at: value.last_active_at,
119            last_active_ip: value.last_active_ip,
120        })
121    }
122}
123
124#[derive(sqlx::FromRow)]
125#[enum_def]
126struct PersonalSessionAndAccessTokenLookup {
127    personal_session_id: Uuid,
128    owner_user_id: Option<Uuid>,
129    owner_oauth2_client_id: Option<Uuid>,
130    actor_user_id: Uuid,
131    human_name: String,
132    scope_list: Vec<String>,
133    created_at: DateTime<Utc>,
134    revoked_at: Option<DateTime<Utc>>,
135    last_active_at: Option<DateTime<Utc>>,
136    last_active_ip: Option<IpAddr>,
137
138    // tokens
139    personal_access_token_id: Option<Uuid>,
140    token_created_at: Option<DateTime<Utc>>,
141    token_expires_at: Option<DateTime<Utc>>,
142}
143
144impl Node<Ulid> for PersonalSessionAndAccessTokenLookup {
145    fn cursor(&self) -> Ulid {
146        self.personal_session_id.into()
147    }
148}
149
150impl TryFrom<PersonalSessionAndAccessTokenLookup>
151    for (PersonalSession, Option<PersonalAccessToken>)
152{
153    type Error = DatabaseInconsistencyError;
154
155    fn try_from(value: PersonalSessionAndAccessTokenLookup) -> Result<Self, Self::Error> {
156        let session = PersonalSession::try_from(PersonalSessionLookup {
157            personal_session_id: value.personal_session_id,
158            owner_user_id: value.owner_user_id,
159            owner_oauth2_client_id: value.owner_oauth2_client_id,
160            actor_user_id: value.actor_user_id,
161            human_name: value.human_name,
162            scope_list: value.scope_list,
163            created_at: value.created_at,
164            revoked_at: value.revoked_at,
165            last_active_at: value.last_active_at,
166            last_active_ip: value.last_active_ip,
167        })?;
168
169        let token_opt = if let Some(id) = value.personal_access_token_id {
170            let id = Ulid::from(id);
171            Some(PersonalAccessToken {
172                id,
173                session_id: session.id,
174                // should not be possible
175                created_at: value.token_created_at.ok_or(
176                    DatabaseInconsistencyError::on("personal_sessions")
177                        .column("created_at")
178                        .row(id),
179                )?,
180                expires_at: value.token_expires_at,
181                revoked_at: None,
182            })
183        } else {
184            None
185        };
186
187        Ok((session, token_opt))
188    }
189}
190
191#[async_trait]
192impl PersonalSessionRepository for PgPersonalSessionRepository<'_> {
193    type Error = DatabaseError;
194
195    #[tracing::instrument(
196        name = "db.personal_session.lookup",
197        skip_all,
198        fields(
199            db.query.text,
200            session.id = %id,
201        ),
202        err,
203    )]
204    async fn lookup(&mut self, id: Ulid) -> Result<Option<PersonalSession>, Self::Error> {
205        let res = sqlx::query_as!(
206            PersonalSessionLookup,
207            r#"
208                SELECT personal_session_id
209                     , owner_user_id
210                     , owner_oauth2_client_id
211                     , actor_user_id
212                     , scope_list
213                     , created_at
214                     , revoked_at
215                     , human_name
216                     , last_active_at
217                     , last_active_ip as "last_active_ip: IpAddr"
218                FROM personal_sessions
219
220                WHERE personal_session_id = $1
221            "#,
222            Uuid::from(id),
223        )
224        .traced()
225        .fetch_optional(&mut *self.conn)
226        .await?;
227
228        let Some(session) = res else { return Ok(None) };
229
230        Ok(Some(session.try_into()?))
231    }
232
233    #[tracing::instrument(
234        name = "db.personal_session.add",
235        skip_all,
236        fields(
237            db.query.text,
238            session.id,
239            session.scope = %scope,
240        ),
241        err,
242    )]
243    async fn add(
244        &mut self,
245        rng: &mut (dyn RngCore + Send),
246        clock: &dyn Clock,
247        owner: PersonalSessionOwner,
248        actor_user: &User,
249        human_name: String,
250        scope: Scope,
251    ) -> Result<PersonalSession, Self::Error> {
252        let created_at = clock.now();
253        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
254        tracing::Span::current().record("session.id", tracing::field::display(id));
255
256        let scope_list: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
257
258        let (owner_user_id, owner_oauth2_client_id) = match owner {
259            PersonalSessionOwner::User(ulid) => (Some(Uuid::from(ulid)), None),
260            PersonalSessionOwner::OAuth2Client(ulid) => (None, Some(Uuid::from(ulid))),
261        };
262
263        sqlx::query!(
264            r#"
265                INSERT INTO personal_sessions
266                    ( personal_session_id
267                    , owner_user_id
268                    , owner_oauth2_client_id
269                    , actor_user_id
270                    , human_name
271                    , scope_list
272                    , created_at
273                    )
274                VALUES ($1, $2, $3, $4, $5, $6, $7)
275            "#,
276            Uuid::from(id),
277            owner_user_id,
278            owner_oauth2_client_id,
279            Uuid::from(actor_user.id),
280            &human_name,
281            &scope_list,
282            created_at,
283        )
284        .traced()
285        .execute(&mut *self.conn)
286        .await?;
287
288        Ok(PersonalSession {
289            id,
290            state: SessionState::Valid,
291            owner,
292            actor_user_id: actor_user.id,
293            human_name,
294            scope,
295            created_at,
296            last_active_at: None,
297            last_active_ip: None,
298        })
299    }
300
301    #[tracing::instrument(
302        name = "db.personal_session.revoke",
303        skip_all,
304        fields(
305            db.query.text,
306            %session.id,
307            %session.scope,
308        ),
309        err,
310    )]
311    async fn revoke(
312        &mut self,
313        clock: &dyn Clock,
314        session: PersonalSession,
315    ) -> Result<PersonalSession, Self::Error> {
316        let revoked_at = clock.now();
317
318        {
319            // Revoke dependent PATs
320            let span = info_span!(
321                "db.personal_session.revoke.tokens",
322                { DB_QUERY_TEXT } = tracing::field::Empty,
323            );
324
325            sqlx::query!(
326                r#"
327                    UPDATE personal_access_tokens
328                    SET revoked_at = $2
329                    WHERE personal_session_id = $1 AND revoked_at IS NULL
330                "#,
331                Uuid::from(session.id),
332                revoked_at,
333            )
334            .record(&span)
335            .execute(&mut *self.conn)
336            .instrument(span)
337            .await?;
338        }
339
340        let res = sqlx::query!(
341            r#"
342                UPDATE personal_sessions
343                SET revoked_at = $2
344                WHERE personal_session_id = $1
345            "#,
346            Uuid::from(session.id),
347            revoked_at,
348        )
349        .traced()
350        .execute(&mut *self.conn)
351        .await?;
352
353        DatabaseError::ensure_affected_rows(&res, 1)?;
354
355        session
356            .finish(revoked_at)
357            .map_err(DatabaseError::to_invalid_operation)
358    }
359
360    #[tracing::instrument(
361        name = "db.personal_session.list",
362        skip_all,
363        fields(
364            db.query.text,
365        ),
366        err,
367    )]
368    async fn list(
369        &mut self,
370        filter: PersonalSessionFilter<'_>,
371        pagination: Pagination,
372    ) -> Result<Page<(PersonalSession, Option<PersonalAccessToken>)>, Self::Error> {
373        let (sql, arguments) = Query::select()
374            .expr_as(
375                Expr::col((PersonalSessions::Table, PersonalSessions::PersonalSessionId)),
376                PersonalSessionAndAccessTokenLookupIden::PersonalSessionId,
377            )
378            .expr_as(
379                Expr::col((PersonalSessions::Table, PersonalSessions::OwnerUserId)),
380                PersonalSessionAndAccessTokenLookupIden::OwnerUserId,
381            )
382            .expr_as(
383                Expr::col((
384                    PersonalSessions::Table,
385                    PersonalSessions::OwnerOAuth2ClientId,
386                )),
387                PersonalSessionAndAccessTokenLookupIden::OwnerOauth2ClientId,
388            )
389            .expr_as(
390                Expr::col((PersonalSessions::Table, PersonalSessions::ActorUserId)),
391                PersonalSessionAndAccessTokenLookupIden::ActorUserId,
392            )
393            .expr_as(
394                Expr::col((PersonalSessions::Table, PersonalSessions::HumanName)),
395                PersonalSessionAndAccessTokenLookupIden::HumanName,
396            )
397            .expr_as(
398                Expr::col((PersonalSessions::Table, PersonalSessions::ScopeList)),
399                PersonalSessionAndAccessTokenLookupIden::ScopeList,
400            )
401            .expr_as(
402                Expr::col((PersonalSessions::Table, PersonalSessions::CreatedAt)),
403                PersonalSessionAndAccessTokenLookupIden::CreatedAt,
404            )
405            .expr_as(
406                Expr::col((PersonalSessions::Table, PersonalSessions::RevokedAt)),
407                PersonalSessionAndAccessTokenLookupIden::RevokedAt,
408            )
409            .expr_as(
410                Expr::col((PersonalSessions::Table, PersonalSessions::LastActiveAt)),
411                PersonalSessionAndAccessTokenLookupIden::LastActiveAt,
412            )
413            .expr_as(
414                Expr::col((PersonalSessions::Table, PersonalSessions::LastActiveIp)),
415                PersonalSessionAndAccessTokenLookupIden::LastActiveIp,
416            )
417            .expr_as(
418                Expr::col((
419                    PersonalAccessTokens::Table,
420                    PersonalAccessTokens::PersonalAccessTokenId,
421                )),
422                PersonalSessionAndAccessTokenLookupIden::PersonalAccessTokenId,
423            )
424            .expr_as(
425                Expr::col((PersonalAccessTokens::Table, PersonalAccessTokens::CreatedAt)),
426                PersonalSessionAndAccessTokenLookupIden::TokenCreatedAt,
427            )
428            .expr_as(
429                Expr::col((PersonalAccessTokens::Table, PersonalAccessTokens::ExpiresAt)),
430                PersonalSessionAndAccessTokenLookupIden::TokenExpiresAt,
431            )
432            .from(PersonalSessions::Table)
433            .left_join(
434                PersonalAccessTokens::Table,
435                Cond::all()
436                    .add(
437                        Expr::col((PersonalSessions::Table, PersonalSessions::PersonalSessionId))
438                            .eq(Expr::col((
439                                PersonalAccessTokens::Table,
440                                PersonalAccessTokens::PersonalSessionId,
441                            ))),
442                    )
443                    .add(
444                        Expr::col((PersonalAccessTokens::Table, PersonalAccessTokens::RevokedAt))
445                            .is_null(),
446                    ),
447            )
448            .apply_filter(filter)
449            .generate_pagination(
450                (PersonalSessions::Table, PersonalSessions::PersonalSessionId),
451                pagination,
452            )
453            .build_sqlx(PostgresQueryBuilder);
454
455        let edges: Vec<PersonalSessionAndAccessTokenLookup> = sqlx::query_as_with(&sql, arguments)
456            .traced()
457            .fetch_all(&mut *self.conn)
458            .await?;
459
460        let page = pagination.process(edges).try_map(TryFrom::try_from)?;
461
462        Ok(page)
463    }
464
465    #[tracing::instrument(
466        name = "db.personal_session.count",
467        skip_all,
468        fields(
469            db.query.text,
470        ),
471        err,
472    )]
473    async fn count(&mut self, filter: PersonalSessionFilter<'_>) -> Result<usize, Self::Error> {
474        let (sql, arguments) = Query::select()
475            .expr(Expr::col((PersonalSessions::Table, PersonalSessions::PersonalSessionId)).count())
476            .from(PersonalSessions::Table)
477            .left_join(
478                PersonalAccessTokens::Table,
479                Cond::all()
480                    .add(
481                        Expr::col((PersonalSessions::Table, PersonalSessions::PersonalSessionId))
482                            .eq(Expr::col((
483                                PersonalAccessTokens::Table,
484                                PersonalAccessTokens::PersonalSessionId,
485                            ))),
486                    )
487                    .add(
488                        Expr::col((PersonalAccessTokens::Table, PersonalAccessTokens::RevokedAt))
489                            .is_null(),
490                    ),
491            )
492            .apply_filter(filter)
493            .build_sqlx(PostgresQueryBuilder);
494
495        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
496            .traced()
497            .fetch_one(&mut *self.conn)
498            .await?;
499
500        count
501            .try_into()
502            .map_err(DatabaseError::to_invalid_operation)
503    }
504}
505
506impl Filter for PersonalSessionFilter<'_> {
507    fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
508        sea_query::Condition::all()
509            .add_option(self.owner_user().map(|user| {
510                Expr::col((PersonalSessions::Table, PersonalSessions::OwnerUserId))
511                    .eq(Uuid::from(user.id))
512            }))
513            .add_option(self.owner_oauth2_client().map(|client| {
514                Expr::col((
515                    PersonalSessions::Table,
516                    PersonalSessions::OwnerOAuth2ClientId,
517                ))
518                .eq(Uuid::from(client.id))
519            }))
520            .add_option(self.actor_user().map(|user| {
521                Expr::col((PersonalSessions::Table, PersonalSessions::ActorUserId))
522                    .eq(Uuid::from(user.id))
523            }))
524            .add_option(self.device().map(|device| -> SimpleExpr {
525                if let Ok([stable_scope_token, unstable_scope_token]) = device.to_scope_token() {
526                    Condition::any()
527                        .add(
528                            Expr::val(stable_scope_token.to_string()).eq(PgFunc::any(Expr::col((
529                                PersonalSessions::Table,
530                                PersonalSessions::ScopeList,
531                            )))),
532                        )
533                        .add(Expr::val(unstable_scope_token.to_string()).eq(PgFunc::any(
534                            Expr::col((PersonalSessions::Table, PersonalSessions::ScopeList)),
535                        )))
536                        .into()
537                } else {
538                    // If the device ID can't be encoded as a scope token, match no rows
539                    Expr::val(false).into()
540                }
541            }))
542            .add_option(self.state().map(|state| match state {
543                PersonalSessionState::Active => {
544                    Expr::col((PersonalSessions::Table, PersonalSessions::RevokedAt)).is_null()
545                }
546                PersonalSessionState::Revoked => {
547                    Expr::col((PersonalSessions::Table, PersonalSessions::RevokedAt)).is_not_null()
548                }
549            }))
550            .add_option(self.scope().map(|scope| {
551                let scope: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
552                Expr::col((PersonalSessions::Table, PersonalSessions::ScopeList)).contains(scope)
553            }))
554            .add_option(self.last_active_before().map(|last_active_before| {
555                Expr::col((PersonalSessions::Table, PersonalSessions::LastActiveAt))
556                    .lt(last_active_before)
557            }))
558            .add_option(self.last_active_after().map(|last_active_after| {
559                Expr::col((PersonalSessions::Table, PersonalSessions::LastActiveAt))
560                    .gt(last_active_after)
561            }))
562            .add_option(self.expires_before().map(|expires_before| {
563                Expr::col((PersonalAccessTokens::Table, PersonalAccessTokens::ExpiresAt))
564                    .lt(expires_before)
565            }))
566            .add_option(self.expires_after().map(|expires_after| {
567                Expr::col((PersonalAccessTokens::Table, PersonalAccessTokens::ExpiresAt))
568                    .gt(expires_after)
569            }))
570    }
571}