1use std::net::IpAddr;
7
8use async_trait::async_trait;
9use chrono::{DateTime, Utc};
10use mas_data_model::{
11 Clock, User,
12 personal::{
13 PersonalAccessToken,
14 session::{PersonalSession, PersonalSessionOwner, SessionState},
15 },
16};
17use mas_storage::{
18 Page, Pagination,
19 pagination::Node,
20 personal::{PersonalSessionFilter, PersonalSessionRepository, PersonalSessionState},
21};
22use oauth2_types::scope::Scope;
23use opentelemetry_semantic_conventions::trace::DB_QUERY_TEXT;
24use rand::RngCore;
25use sea_query::{
26 Cond, Condition, Expr, PgFunc, PostgresQueryBuilder, Query, SimpleExpr, enum_def,
27 extension::postgres::PgExpr as _,
28};
29use sea_query_binder::SqlxBinder as _;
30use sqlx::PgConnection;
31use tracing::{Instrument as _, info_span};
32use ulid::Ulid;
33use uuid::Uuid;
34
35use crate::{
36 DatabaseError,
37 errors::DatabaseInconsistencyError,
38 filter::{Filter, StatementExt as _},
39 iden::{PersonalAccessTokens, PersonalSessions},
40 pagination::QueryBuilderExt as _,
41 tracing::ExecuteExt as _,
42};
43
/// An implementation of [`PersonalSessionRepository`] backed by a PostgreSQL
/// connection.
pub struct PgPersonalSessionRepository<'c> {
    conn: &'c mut PgConnection,
}
49
50impl<'c> PgPersonalSessionRepository<'c> {
51 pub fn new(conn: &'c mut PgConnection) -> Self {
54 Self { conn }
55 }
56}
57
/// Raw row shape of a `personal_sessions` query, as selected by `lookup`.
///
/// It is turned into a [`PersonalSession`] through its `TryFrom`
/// implementation, which validates the scope and owner columns.
#[derive(sqlx::FromRow)]
#[enum_def]
struct PersonalSessionLookup {
    personal_session_id: Uuid,
    // Exactly one of the two owner columns is expected to be set; the
    // conversion into `PersonalSession` rejects any other combination.
    owner_user_id: Option<Uuid>,
    owner_oauth2_client_id: Option<Uuid>,
    actor_user_id: Uuid,
    human_name: String,
    // Unparsed scope tokens; parsed into a `Scope` during conversion
    scope_list: Vec<String>,
    created_at: DateTime<Utc>,
    // `None` maps to a valid session, `Some` to a revoked one
    revoked_at: Option<DateTime<Utc>>,
    last_active_at: Option<DateTime<Utc>>,
    last_active_ip: Option<IpAddr>,
}
72
73impl Node<Ulid> for PersonalSessionLookup {
74 fn cursor(&self) -> Ulid {
75 self.personal_session_id.into()
76 }
77}
78
79impl TryFrom<PersonalSessionLookup> for PersonalSession {
80 type Error = DatabaseInconsistencyError;
81
82 fn try_from(value: PersonalSessionLookup) -> Result<Self, Self::Error> {
83 let id = Ulid::from(value.personal_session_id);
84 let scope: Result<Scope, _> = value.scope_list.iter().map(|s| s.parse()).collect();
85 let scope = scope.map_err(|e| {
86 DatabaseInconsistencyError::on("personal_sessions")
87 .column("scope")
88 .row(id)
89 .source(e)
90 })?;
91
92 let state = match value.revoked_at {
93 None => SessionState::Valid,
94 Some(revoked_at) => SessionState::Revoked { revoked_at },
95 };
96
97 let owner = match (value.owner_user_id, value.owner_oauth2_client_id) {
98 (Some(owner_user_id), None) => PersonalSessionOwner::User(Ulid::from(owner_user_id)),
99 (None, Some(owner_oauth2_client_id)) => {
100 PersonalSessionOwner::OAuth2Client(Ulid::from(owner_oauth2_client_id))
101 }
102 _ => {
103 return Err(DatabaseInconsistencyError::on("personal_sessions")
105 .column("owner_user_id, owner_oauth2_client_id")
106 .row(id));
107 }
108 };
109
110 Ok(PersonalSession {
111 id,
112 state,
113 owner,
114 actor_user_id: Ulid::from(value.actor_user_id),
115 human_name: value.human_name,
116 scope,
117 created_at: value.created_at,
118 last_active_at: value.last_active_at,
119 last_active_ip: value.last_active_ip,
120 })
121 }
122}
123
/// Raw row shape of the `list` query: a `personal_sessions` row left-joined
/// with an optional `personal_access_tokens` row.
///
/// `#[enum_def]` generates `PersonalSessionAndAccessTokenLookupIden`. Its
/// variants are used as column aliases in the `list` query, so that the
/// selected columns line up with these field names when decoding via `FromRow`.
#[derive(sqlx::FromRow)]
#[enum_def]
struct PersonalSessionAndAccessTokenLookup {
    personal_session_id: Uuid,
    owner_user_id: Option<Uuid>,
    owner_oauth2_client_id: Option<Uuid>,
    actor_user_id: Uuid,
    human_name: String,
    scope_list: Vec<String>,
    created_at: DateTime<Utc>,
    revoked_at: Option<DateTime<Utc>>,
    last_active_at: Option<DateTime<Utc>>,
    last_active_ip: Option<IpAddr>,

    // Columns from the left-joined `personal_access_tokens` row. They are all
    // NULL when no token matched the join; `token_expires_at` may also be NULL
    // for a matched token.
    personal_access_token_id: Option<Uuid>,
    token_created_at: Option<DateTime<Utc>>,
    token_expires_at: Option<DateTime<Utc>>,
}
143
144impl Node<Ulid> for PersonalSessionAndAccessTokenLookup {
145 fn cursor(&self) -> Ulid {
146 self.personal_session_id.into()
147 }
148}
149
150impl TryFrom<PersonalSessionAndAccessTokenLookup>
151 for (PersonalSession, Option<PersonalAccessToken>)
152{
153 type Error = DatabaseInconsistencyError;
154
155 fn try_from(value: PersonalSessionAndAccessTokenLookup) -> Result<Self, Self::Error> {
156 let session = PersonalSession::try_from(PersonalSessionLookup {
157 personal_session_id: value.personal_session_id,
158 owner_user_id: value.owner_user_id,
159 owner_oauth2_client_id: value.owner_oauth2_client_id,
160 actor_user_id: value.actor_user_id,
161 human_name: value.human_name,
162 scope_list: value.scope_list,
163 created_at: value.created_at,
164 revoked_at: value.revoked_at,
165 last_active_at: value.last_active_at,
166 last_active_ip: value.last_active_ip,
167 })?;
168
169 let token_opt = if let Some(id) = value.personal_access_token_id {
170 let id = Ulid::from(id);
171 Some(PersonalAccessToken {
172 id,
173 session_id: session.id,
174 created_at: value.token_created_at.ok_or(
176 DatabaseInconsistencyError::on("personal_sessions")
177 .column("created_at")
178 .row(id),
179 )?,
180 expires_at: value.token_expires_at,
181 revoked_at: None,
182 })
183 } else {
184 None
185 };
186
187 Ok((session, token_opt))
188 }
189}
190
// PostgreSQL implementation of the personal session repository.
// Single-row operations use the `sqlx` query macros. `list` and `count` build
// their SQL with `sea_query`, so that the filter and the pagination can be
// applied dynamically.
#[async_trait]
impl PersonalSessionRepository for PgPersonalSessionRepository<'_> {
    type Error = DatabaseError;

    // Looks up a single personal session by ID.
    // Returns `Ok(None)` when no row matches. Returns an error if the row
    // cannot be converted into a `PersonalSession`.
    #[tracing::instrument(
        name = "db.personal_session.lookup",
        skip_all,
        fields(
            db.query.text,
            session.id = %id,
        ),
        err,
    )]
    async fn lookup(&mut self, id: Ulid) -> Result<Option<PersonalSession>, Self::Error> {
        let res = sqlx::query_as!(
            PersonalSessionLookup,
            r#"
                SELECT personal_session_id
                     , owner_user_id
                     , owner_oauth2_client_id
                     , actor_user_id
                     , scope_list
                     , created_at
                     , revoked_at
                     , human_name
                     , last_active_at
                     , last_active_ip as "last_active_ip: IpAddr"
                FROM personal_sessions

                WHERE personal_session_id = $1
            "#,
            Uuid::from(id),
        )
        .traced()
        .fetch_optional(&mut *self.conn)
        .await?;

        let Some(session) = res else { return Ok(None) };

        Ok(Some(session.try_into()?))
    }

    // Inserts a new, valid personal session and returns it.
    // The session ID is a ULID built from the current clock time and `rng`.
    // The returned session has no activity recorded yet.
    #[tracing::instrument(
        name = "db.personal_session.add",
        skip_all,
        fields(
            db.query.text,
            session.id,
            session.scope = %scope,
        ),
        err,
    )]
    async fn add(
        &mut self,
        rng: &mut (dyn RngCore + Send),
        clock: &dyn Clock,
        owner: PersonalSessionOwner,
        actor_user: &User,
        human_name: String,
        scope: Scope,
    ) -> Result<PersonalSession, Self::Error> {
        let created_at = clock.now();
        let id = Ulid::from_datetime_with_source(created_at.into(), rng);
        tracing::Span::current().record("session.id", tracing::field::display(id));

        // The scope is stored as an array of strings, one per scope token
        let scope_list: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();

        // Exactly one of the two owner columns is set, depending on the owner kind
        let (owner_user_id, owner_oauth2_client_id) = match owner {
            PersonalSessionOwner::User(ulid) => (Some(Uuid::from(ulid)), None),
            PersonalSessionOwner::OAuth2Client(ulid) => (None, Some(Uuid::from(ulid))),
        };

        sqlx::query!(
            r#"
                INSERT INTO personal_sessions
                    ( personal_session_id
                    , owner_user_id
                    , owner_oauth2_client_id
                    , actor_user_id
                    , human_name
                    , scope_list
                    , created_at
                    )
                VALUES ($1, $2, $3, $4, $5, $6, $7)
            "#,
            Uuid::from(id),
            owner_user_id,
            owner_oauth2_client_id,
            Uuid::from(actor_user.id),
            &human_name,
            &scope_list,
            created_at,
        )
        .traced()
        .execute(&mut *self.conn)
        .await?;

        Ok(PersonalSession {
            id,
            state: SessionState::Valid,
            owner,
            actor_user_id: actor_user.id,
            human_name,
            scope,
            created_at,
            last_active_at: None,
            last_active_ip: None,
        })
    }

    // Revokes the session and all of its not-yet-revoked access tokens,
    // using the same `revoked_at` timestamp for both.
    //
    // NOTE(review): both UPDATEs run before `session.finish` validates the
    // state transition, and the session UPDATE has no `revoked_at IS NULL`
    // guard. Callers presumably rely on the surrounding transaction being
    // rolled back if `finish` fails — confirm.
    #[tracing::instrument(
        name = "db.personal_session.revoke",
        skip_all,
        fields(
            db.query.text,
            %session.id,
            %session.scope,
        ),
        err,
    )]
    async fn revoke(
        &mut self,
        clock: &dyn Clock,
        session: PersonalSession,
    ) -> Result<PersonalSession, Self::Error> {
        let revoked_at = clock.now();

        // Revoke the session's still-active tokens, traced in a dedicated child span
        {
            let span = info_span!(
                "db.personal_session.revoke.tokens",
                { DB_QUERY_TEXT } = tracing::field::Empty,
            );

            sqlx::query!(
                r#"
                    UPDATE personal_access_tokens
                    SET revoked_at = $2
                    WHERE personal_session_id = $1 AND revoked_at IS NULL
                "#,
                Uuid::from(session.id),
                revoked_at,
            )
            .record(&span)
            .execute(&mut *self.conn)
            .instrument(span)
            .await?;
        }

        let res = sqlx::query!(
            r#"
                UPDATE personal_sessions
                SET revoked_at = $2
                WHERE personal_session_id = $1
            "#,
            Uuid::from(session.id),
            revoked_at,
        )
        .traced()
        .execute(&mut *self.conn)
        .await?;

        // Exactly one session row is expected to have been updated
        DatabaseError::ensure_affected_rows(&res, 1)?;

        session
            .finish(revoked_at)
            .map_err(DatabaseError::to_invalid_operation)
    }

    // Lists sessions matching `filter`, paginated on the session ID.
    // Each session comes with its non-revoked access token, if it has one.
    //
    // NOTE(review): a session with several non-revoked tokens would be
    // returned once per token. This presumably relies on the schema allowing
    // at most one active token per session — confirm.
    #[tracing::instrument(
        name = "db.personal_session.list",
        skip_all,
        fields(
            db.query.text,
        ),
        err,
    )]
    async fn list(
        &mut self,
        filter: PersonalSessionFilter<'_>,
        pagination: Pagination,
    ) -> Result<Page<(PersonalSession, Option<PersonalAccessToken>)>, Self::Error> {
        // Every selected column is aliased with a
        // `PersonalSessionAndAccessTokenLookupIden` variant so the row decodes
        // into `PersonalSessionAndAccessTokenLookup` by field name.
        let (sql, arguments) = Query::select()
            .expr_as(
                Expr::col((PersonalSessions::Table, PersonalSessions::PersonalSessionId)),
                PersonalSessionAndAccessTokenLookupIden::PersonalSessionId,
            )
            .expr_as(
                Expr::col((PersonalSessions::Table, PersonalSessions::OwnerUserId)),
                PersonalSessionAndAccessTokenLookupIden::OwnerUserId,
            )
            .expr_as(
                Expr::col((
                    PersonalSessions::Table,
                    PersonalSessions::OwnerOAuth2ClientId,
                )),
                PersonalSessionAndAccessTokenLookupIden::OwnerOauth2ClientId,
            )
            .expr_as(
                Expr::col((PersonalSessions::Table, PersonalSessions::ActorUserId)),
                PersonalSessionAndAccessTokenLookupIden::ActorUserId,
            )
            .expr_as(
                Expr::col((PersonalSessions::Table, PersonalSessions::HumanName)),
                PersonalSessionAndAccessTokenLookupIden::HumanName,
            )
            .expr_as(
                Expr::col((PersonalSessions::Table, PersonalSessions::ScopeList)),
                PersonalSessionAndAccessTokenLookupIden::ScopeList,
            )
            .expr_as(
                Expr::col((PersonalSessions::Table, PersonalSessions::CreatedAt)),
                PersonalSessionAndAccessTokenLookupIden::CreatedAt,
            )
            .expr_as(
                Expr::col((PersonalSessions::Table, PersonalSessions::RevokedAt)),
                PersonalSessionAndAccessTokenLookupIden::RevokedAt,
            )
            .expr_as(
                Expr::col((PersonalSessions::Table, PersonalSessions::LastActiveAt)),
                PersonalSessionAndAccessTokenLookupIden::LastActiveAt,
            )
            .expr_as(
                Expr::col((PersonalSessions::Table, PersonalSessions::LastActiveIp)),
                PersonalSessionAndAccessTokenLookupIden::LastActiveIp,
            )
            .expr_as(
                Expr::col((
                    PersonalAccessTokens::Table,
                    PersonalAccessTokens::PersonalAccessTokenId,
                )),
                PersonalSessionAndAccessTokenLookupIden::PersonalAccessTokenId,
            )
            .expr_as(
                Expr::col((PersonalAccessTokens::Table, PersonalAccessTokens::CreatedAt)),
                PersonalSessionAndAccessTokenLookupIden::TokenCreatedAt,
            )
            .expr_as(
                Expr::col((PersonalAccessTokens::Table, PersonalAccessTokens::ExpiresAt)),
                PersonalSessionAndAccessTokenLookupIden::TokenExpiresAt,
            )
            .from(PersonalSessions::Table)
            // LEFT JOIN, so that sessions without an active token are still listed
            .left_join(
                PersonalAccessTokens::Table,
                Cond::all()
                    .add(
                        Expr::col((PersonalSessions::Table, PersonalSessions::PersonalSessionId))
                            .eq(Expr::col((
                                PersonalAccessTokens::Table,
                                PersonalAccessTokens::PersonalSessionId,
                            ))),
                    )
                    .add(
                        Expr::col((PersonalAccessTokens::Table, PersonalAccessTokens::RevokedAt))
                            .is_null(),
                    ),
            )
            .apply_filter(filter)
            .generate_pagination(
                (PersonalSessions::Table, PersonalSessions::PersonalSessionId),
                pagination,
            )
            .build_sqlx(PostgresQueryBuilder);

        let edges: Vec<PersonalSessionAndAccessTokenLookup> = sqlx::query_as_with(&sql, arguments)
            .traced()
            .fetch_all(&mut *self.conn)
            .await?;

        let page = pagination.process(edges).try_map(TryFrom::try_from)?;

        Ok(page)
    }

    // Counts sessions matching `filter`.
    // The query uses the same join as `list`, because the filter can reference
    // `personal_access_tokens` columns (`expires_before` / `expires_after`).
    #[tracing::instrument(
        name = "db.personal_session.count",
        skip_all,
        fields(
            db.query.text,
        ),
        err,
    )]
    async fn count(&mut self, filter: PersonalSessionFilter<'_>) -> Result<usize, Self::Error> {
        let (sql, arguments) = Query::select()
            .expr(Expr::col((PersonalSessions::Table, PersonalSessions::PersonalSessionId)).count())
            .from(PersonalSessions::Table)
            .left_join(
                PersonalAccessTokens::Table,
                Cond::all()
                    .add(
                        Expr::col((PersonalSessions::Table, PersonalSessions::PersonalSessionId))
                            .eq(Expr::col((
                                PersonalAccessTokens::Table,
                                PersonalAccessTokens::PersonalSessionId,
                            ))),
                    )
                    .add(
                        Expr::col((PersonalAccessTokens::Table, PersonalAccessTokens::RevokedAt))
                            .is_null(),
                    ),
            )
            .apply_filter(filter)
            .build_sqlx(PostgresQueryBuilder);

        let count: i64 = sqlx::query_scalar_with(&sql, arguments)
            .traced()
            .fetch_one(&mut *self.conn)
            .await?;

        // An i64 count that does not fit in `usize` is reported as an invalid operation
        count
            .try_into()
            .map_err(DatabaseError::to_invalid_operation)
    }
}
505
506impl Filter for PersonalSessionFilter<'_> {
507 fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
508 sea_query::Condition::all()
509 .add_option(self.owner_user().map(|user| {
510 Expr::col((PersonalSessions::Table, PersonalSessions::OwnerUserId))
511 .eq(Uuid::from(user.id))
512 }))
513 .add_option(self.owner_oauth2_client().map(|client| {
514 Expr::col((
515 PersonalSessions::Table,
516 PersonalSessions::OwnerOAuth2ClientId,
517 ))
518 .eq(Uuid::from(client.id))
519 }))
520 .add_option(self.actor_user().map(|user| {
521 Expr::col((PersonalSessions::Table, PersonalSessions::ActorUserId))
522 .eq(Uuid::from(user.id))
523 }))
524 .add_option(self.device().map(|device| -> SimpleExpr {
525 if let Ok([stable_scope_token, unstable_scope_token]) = device.to_scope_token() {
526 Condition::any()
527 .add(
528 Expr::val(stable_scope_token.to_string()).eq(PgFunc::any(Expr::col((
529 PersonalSessions::Table,
530 PersonalSessions::ScopeList,
531 )))),
532 )
533 .add(Expr::val(unstable_scope_token.to_string()).eq(PgFunc::any(
534 Expr::col((PersonalSessions::Table, PersonalSessions::ScopeList)),
535 )))
536 .into()
537 } else {
538 Expr::val(false).into()
540 }
541 }))
542 .add_option(self.state().map(|state| match state {
543 PersonalSessionState::Active => {
544 Expr::col((PersonalSessions::Table, PersonalSessions::RevokedAt)).is_null()
545 }
546 PersonalSessionState::Revoked => {
547 Expr::col((PersonalSessions::Table, PersonalSessions::RevokedAt)).is_not_null()
548 }
549 }))
550 .add_option(self.scope().map(|scope| {
551 let scope: Vec<String> = scope.iter().map(|s| s.as_str().to_owned()).collect();
552 Expr::col((PersonalSessions::Table, PersonalSessions::ScopeList)).contains(scope)
553 }))
554 .add_option(self.last_active_before().map(|last_active_before| {
555 Expr::col((PersonalSessions::Table, PersonalSessions::LastActiveAt))
556 .lt(last_active_before)
557 }))
558 .add_option(self.last_active_after().map(|last_active_after| {
559 Expr::col((PersonalSessions::Table, PersonalSessions::LastActiveAt))
560 .gt(last_active_after)
561 }))
562 .add_option(self.expires_before().map(|expires_before| {
563 Expr::col((PersonalAccessTokens::Table, PersonalAccessTokens::ExpiresAt))
564 .lt(expires_before)
565 }))
566 .add_option(self.expires_after().map(|expires_after| {
567 Expr::col((PersonalAccessTokens::Table, PersonalAccessTokens::ExpiresAt))
568 .gt(expires_after)
569 }))
570 }
571}