From f947353fa0df9b5f856ed367264b375a8dfa1c1e Mon Sep 17 00:00:00 2001
From: Jerome Humbert <djeedai@gmail.com>
Date: Mon, 14 Feb 2022 21:33:31 +0000
Subject: [PATCH] Add `Entity` to the completed callback

Provide the callback owner with the `Entity` the tween and the animator
are attached to, as a parameter to the callback when invoked.
---
 src/plugin.rs    | 12 ++++++------
 src/tweenable.rs | 48 +++++++++++++++++++++++++++++++-----------------
 2 files changed, 37 insertions(+), 23 deletions(-)

diff --git a/src/plugin.rs b/src/plugin.rs
index 84d4323..9cc510d 100644
--- a/src/plugin.rs
+++ b/src/plugin.rs
@@ -47,12 +47,12 @@ impl Plugin for TweeningPlugin {
 /// and tick the animator to animate the component.
 pub fn component_animator_system<T: Component>(
     time: Res<Time>,
-    mut query: Query<(&mut T, &mut Animator<T>)>,
+    mut query: Query<(Entity, &mut T, &mut Animator<T>)>,
 ) {
-    for (ref mut target, ref mut animator) in query.iter_mut() {
+    for (entity, ref mut target, ref mut animator) in query.iter_mut() {
         if animator.state != AnimatorState::Paused {
             if let Some(tweenable) = animator.tweenable_mut() {
-                tweenable.tick(time.delta(), target);
+                tweenable.tick(time.delta(), target, entity);
             }
         }
     }
@@ -64,13 +64,13 @@ pub fn component_animator_system<T: Component>(
 pub fn asset_animator_system<T: Asset>(
     time: Res<Time>,
     mut assets: ResMut<Assets<T>>,
-    mut query: Query<&mut AssetAnimator<T>>,
+    mut query: Query<(Entity, &mut AssetAnimator<T>)>,
 ) {
-    for ref mut animator in query.iter_mut() {
+    for (entity, ref mut animator) in query.iter_mut() {
         if animator.state != AnimatorState::Paused {
             if let Some(target) = assets.get_mut(animator.handle()) {
                 if let Some(tweenable) = animator.tweenable_mut() {
-                    tweenable.tick(time.delta(), target);
+                    tweenable.tick(time.delta(), target, entity);
                 }
             }
         }
diff --git a/src/tweenable.rs b/src/tweenable.rs
index 13309db..a58dbb3 100644
--- a/src/tweenable.rs
+++ b/src/tweenable.rs
@@ -59,7 +59,7 @@ pub trait Tweenable<T>: Send + Sync {
     ///
     /// [`rewind()`]: Tweenable::rewind
     /// [`set_progress()`]: Tweenable::set_progress
-    fn tick(&mut self, delta: Duration, target: &mut T) -> TweenState;
+    fn tick(&mut self, delta: Duration, target: &mut T, entity: Entity) -> TweenState;
 
     /// Get the number of times this tweenable completed.
     ///
@@ -85,8 +85,8 @@ impl<T> Tweenable<T> for Box<dyn Tweenable<T> + Send + Sync + 'static> {
     fn progress(&self) -> f32 {
         self.as_ref().progress()
     }
-    fn tick(&mut self, delta: Duration, target: &mut T) -> TweenState {
-        self.as_mut().tick(delta, target)
+    fn tick(&mut self, delta: Duration, target: &mut T, entity: Entity) -> TweenState {
+        self.as_mut().tick(delta, target, entity)
     }
     fn times_completed(&self) -> u32 {
         self.as_ref().times_completed()
@@ -116,7 +116,7 @@ pub struct Tween<T> {
     direction: TweeningDirection,
     times_completed: u32,
     lens: Box<dyn Lens<T> + Send + Sync + 'static>,
-    on_completed: Option<Box<dyn Fn(&Tween<T>) + Send + Sync + 'static>>,
+    on_completed: Option<Box<dyn Fn(Entity, &Tween<T>) + Send + Sync + 'static>>,
 }
 
 impl<T: 'static> Tween<T> {
@@ -199,10 +199,13 @@ impl<T> Tween<T> {
 
     /// Set a callback invoked when the animation completed.
     ///
+    /// The callback when invoked receives as parameters the [`Entity`] on which the target and the
+    /// animator are, as well as a reference to the current [`Tween`].
+    ///
     /// Only non-looping tweenables can complete.
     pub fn set_completed<C>(&mut self, callback: C)
     where
-        C: Fn(&Tween<T>) + Send + Sync + 'static,
+        C: Fn(Entity, &Tween<T>) + Send + Sync + 'static,
     {
         self.on_completed = Some(Box::new(callback));
     }
@@ -237,7 +240,7 @@ impl<T> Tweenable<T> for Tween<T> {
         }
     }
 
-    fn tick(&mut self, delta: Duration, target: &mut T) -> TweenState {
+    fn tick(&mut self, delta: Duration, target: &mut T, entity: Entity) -> TweenState {
         if !self.is_looping() && self.timer.finished() {
             return TweenState::Completed;
         }
@@ -266,7 +269,7 @@ impl<T> Tweenable<T> for Tween<T> {
             self.times_completed += self.timer.times_finished();
 
             if let Some(cb) = &self.on_completed {
-                cb(&self);
+                cb(entity, &self);
             }
         }
 
@@ -388,12 +391,12 @@ impl<T> Tweenable<T> for Sequence<T> {
         self.time.as_secs_f32() / self.duration.as_secs_f32()
     }
 
-    fn tick(&mut self, delta: Duration, target: &mut T) -> TweenState {
+    fn tick(&mut self, delta: Duration, target: &mut T, entity: Entity) -> TweenState {
         if self.index < self.tweens.len() {
             let mut state = TweenState::Active;
             self.time = (self.time + delta).min(self.duration);
             let tween = &mut self.tweens[self.index];
-            let tween_state = tween.tick(delta, target);
+            let tween_state = tween.tick(delta, target, entity);
             if tween_state == TweenState::Completed {
                 tween.rewind();
                 self.index += 1;
@@ -468,11 +471,11 @@ impl<T> Tweenable<T> for Tracks<T> {
         self.time.as_secs_f32() / self.duration.as_secs_f32()
     }
 
-    fn tick(&mut self, delta: Duration, target: &mut T) -> TweenState {
+    fn tick(&mut self, delta: Duration, target: &mut T, entity: Entity) -> TweenState {
         self.time = (self.time + delta).min(self.duration);
         let mut any_active = false;
         for tweenable in &mut self.tracks {
-            let state = tweenable.tick(delta, target);
+            let state = tweenable.tick(delta, target, entity);
             any_active = any_active || (state == TweenState::Active);
         }
         if any_active {
@@ -539,7 +542,7 @@ impl<T> Tweenable<T> for Delay {
         self.timer.percent()
     }
 
-    fn tick(&mut self, delta: Duration, _: &mut T) -> TweenState {
+    fn tick(&mut self, delta: Duration, _: &mut T, _entity: Entity) -> TweenState {
         self.timer.tick(delta);
         if self.timer.finished() {
             TweenState::Completed
@@ -600,10 +603,13 @@ mod tests {
                 },
             );
 
+            let dummy_entity = Entity::from_raw(42);
+
             // Register callbacks to count started/ended events
             let callback_monitor = Arc::new(Mutex::new(CallbackMonitor::default()));
             let cb_mon_ptr = Arc::clone(&callback_monitor);
-            tween.set_completed(move |tween| {
+            tween.set_completed(move |entity, tween| {
+                assert_eq!(dummy_entity, entity);
                 let mut cb_mon = cb_mon_ptr.lock().unwrap();
                 cb_mon.invoke_count += 1;
                 cb_mon.last_reported_count = tween.times_completed();
@@ -658,7 +664,7 @@ mod tests {
                 );
 
                 // Tick the tween
-                let actual_state = tween.tick(tick_duration, &mut transform);
+                let actual_state = tween.tick(tick_duration, &mut transform, dummy_entity);
 
                 // Check actual values
                 assert_eq!(tween.direction(), direction);
@@ -683,7 +689,7 @@ mod tests {
             assert_eq!(tween.times_completed(), 0);
 
             // Dummy tick to update target
-            let actual_state = tween.tick(Duration::ZERO, &mut transform);
+            let actual_state = tween.tick(Duration::ZERO, &mut transform, Entity::from_raw(0));
             assert_eq!(actual_state, TweenState::Active);
             assert!(transform.translation.abs_diff_eq(Vec3::ZERO, 1e-5));
             assert!(transform.rotation.abs_diff_eq(Quat::IDENTITY, 1e-5));
@@ -714,7 +720,11 @@ mod tests {
         let mut seq = tween1.then(tween2);
         let mut transform = Transform::default();
         for i in 1..=16 {
-            let state = seq.tick(Duration::from_secs_f32(0.2), &mut transform);
+            let state = seq.tick(
+                Duration::from_secs_f32(0.2),
+                &mut transform,
+                Entity::from_raw(0),
+            );
             if i < 5 {
                 assert_eq!(state, TweenState::Active);
                 let r = i as f32 * 0.2;
@@ -760,7 +770,11 @@ mod tests {
         let mut tracks = Tracks::new([tween1, tween2]);
         let mut transform = Transform::default();
         for i in 1..=6 {
-            let state = tracks.tick(Duration::from_secs_f32(0.2), &mut transform);
+            let state = tracks.tick(
+                Duration::from_secs_f32(0.2),
+                &mut transform,
+                Entity::from_raw(0),
+            );
             if i < 5 {
                 assert_eq!(state, TweenState::Active);
                 let r = i as f32 * 0.2;
-- 
GitLab