mas_handlers/activity_tracker/
mod.rs

1// Copyright 2024 New Vector Ltd.
2// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
3//
4// SPDX-License-Identifier: AGPL-3.0-only
5// Please see LICENSE in the repository root for full details.
6
7mod bound;
8mod worker;
9
10use std::net::IpAddr;
11
12use chrono::{DateTime, Utc};
13use mas_data_model::{BrowserSession, CompatSession, Session};
14use mas_storage::{BoxRepositoryFactory, Clock};
15use tokio_util::{sync::CancellationToken, task::TaskTracker};
16use ulid::Ulid;
17
18pub use self::bound::Bound;
19use self::worker::Worker;
20
/// Capacity of the bounded channel between [`ActivityTracker`] handles and the
/// background worker. Once it is full, recording waits for space.
const MESSAGE_QUEUE_SIZE: usize = 1000;

// `tokio::sync::mpsc::channel` panics when given a capacity of zero, so reject
// that at compile time rather than at startup.
const _: () = assert!(MESSAGE_QUEUE_SIZE > 0);
22
/// The kind of session whose activity is being recorded.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Hash)]
enum SessionKind {
    /// An OAuth 2.0 session (`Session`).
    OAuth2,
    /// A compat session (`CompatSession`).
    Compat,
    /// A browser session (`BrowserSession`).
    Browser,
}

impl SessionKind {
    /// Returns the static lowercase label for this kind: `"oauth2"`,
    /// `"compat"` or `"browser"`.
    const fn as_str(self) -> &'static str {
        match self {
            Self::OAuth2 => "oauth2",
            Self::Compat => "compat",
            Self::Browser => "browser",
        }
    }
}
39
/// A message sent from [`ActivityTracker`] handles to the background worker
/// over the bounded mpsc channel.
enum Message {
    /// Record activity on a session.
    Record {
        /// Which kind of session `id` refers to.
        kind: SessionKind,
        /// The ID of the session.
        id: Ulid,
        /// Timestamp taken from the caller's `Clock` when the activity was
        /// recorded.
        date_time: DateTime<Utc>,
        /// The IP address the activity came from, if known.
        ip: Option<IpAddr>,
    },
    /// Ask the worker to flush.
    ///
    /// `ActivityTracker::flush` awaits the matching receiver, so the worker is
    /// expected to send on (or drop) this sender once the flush is done.
    Flush(tokio::sync::oneshot::Sender<()>),
}
49
/// Handle used to record session activity.
///
/// Records and flush requests are sent over a bounded channel to a background
/// worker spawned by [`ActivityTracker::new`]. Cloning the handle clones the
/// channel sender, so every clone feeds the same worker.
#[derive(Clone)]
pub struct ActivityTracker {
    /// Sending half of the channel to the background worker.
    channel: tokio::sync::mpsc::Sender<Message>,
}
54
55impl ActivityTracker {
56    /// Create a new activity tracker
57    ///
58    /// It will spawn the background worker and a loop to flush the tracker on
59    /// the task tracker, and both will shut themselves down, flushing one last
60    /// time, when the cancellation token is cancelled.
61    #[must_use]
62    pub fn new(
63        repository_factory: BoxRepositoryFactory,
64        flush_interval: std::time::Duration,
65        task_tracker: &TaskTracker,
66        cancellation_token: CancellationToken,
67    ) -> Self {
68        let worker = Worker::new(repository_factory);
69        let (sender, receiver) = tokio::sync::mpsc::channel(MESSAGE_QUEUE_SIZE);
70        let tracker = ActivityTracker { channel: sender };
71
72        // Spawn the flush loop and the worker
73        task_tracker.spawn(
74            tracker
75                .clone()
76                .flush_loop(flush_interval, cancellation_token.clone()),
77        );
78        task_tracker.spawn(worker.run(receiver, cancellation_token));
79
80        tracker
81    }
82
83    /// Bind the activity tracker to an IP address.
84    #[must_use]
85    pub fn bind(self, ip: Option<IpAddr>) -> Bound {
86        Bound::new(self, ip)
87    }
88
89    /// Record activity in an OAuth 2.0 session.
90    pub async fn record_oauth2_session(
91        &self,
92        clock: &dyn Clock,
93        session: &Session,
94        ip: Option<IpAddr>,
95    ) {
96        let res = self
97            .channel
98            .send(Message::Record {
99                kind: SessionKind::OAuth2,
100                id: session.id,
101                date_time: clock.now(),
102                ip,
103            })
104            .await;
105
106        if let Err(e) = res {
107            tracing::error!("Failed to record OAuth2 session: {}", e);
108        }
109    }
110
111    /// Record activity in a compat session.
112    pub async fn record_compat_session(
113        &self,
114        clock: &dyn Clock,
115        compat_session: &CompatSession,
116        ip: Option<IpAddr>,
117    ) {
118        let res = self
119            .channel
120            .send(Message::Record {
121                kind: SessionKind::Compat,
122                id: compat_session.id,
123                date_time: clock.now(),
124                ip,
125            })
126            .await;
127
128        if let Err(e) = res {
129            tracing::error!("Failed to record compat session: {}", e);
130        }
131    }
132
133    /// Record activity in a browser session.
134    pub async fn record_browser_session(
135        &self,
136        clock: &dyn Clock,
137        browser_session: &BrowserSession,
138        ip: Option<IpAddr>,
139    ) {
140        let res = self
141            .channel
142            .send(Message::Record {
143                kind: SessionKind::Browser,
144                id: browser_session.id,
145                date_time: clock.now(),
146                ip,
147            })
148            .await;
149
150        if let Err(e) = res {
151            tracing::error!("Failed to record browser session: {}", e);
152        }
153    }
154
155    /// Manually flush the activity tracker.
156    pub async fn flush(&self) {
157        let (tx, rx) = tokio::sync::oneshot::channel();
158        let res = self.channel.send(Message::Flush(tx)).await;
159
160        match res {
161            Ok(()) => {
162                if let Err(e) = rx.await {
163                    tracing::error!(
164                        error = &e as &dyn std::error::Error,
165                        "Failed to flush activity tracker"
166                    );
167                }
168            }
169            Err(e) => {
170                tracing::error!(
171                    error = &e as &dyn std::error::Error,
172                    "Failed to flush activity tracker"
173                );
174            }
175        }
176    }
177
178    /// Regularly flush the activity tracker.
179    async fn flush_loop(
180        self,
181        interval: std::time::Duration,
182        cancellation_token: CancellationToken,
183    ) {
184        // This guard on the shutdown token is to ensure that if this task crashes for
185        // any reason, the server will shut down
186        let _guard = cancellation_token.clone().drop_guard();
187        let mut interval = tokio::time::interval(interval);
188        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
189
190        loop {
191            tokio::select! {
192                biased;
193
194                () = cancellation_token.cancelled() => {
195                    // The cancellation token was cancelled, so we should exit
196                    return;
197                }
198
199                // First check if the channel is closed, then check if the timer expired
200                () = self.channel.closed() => {
201                    // The channel was closed, so we should exit
202                    return;
203                }
204
205
206                _ = interval.tick() => {
207                    self.flush().await;
208                }
209            }
210        }
211    }
212}