Alexandria 2.25.0
SDC-CH common library for the Euclid project
SOMTrainer.h
Go to the documentation of this file.
1/*
2 * Copyright (C) 2012-2022 Euclid Science Ground Segment
3 *
4 * This library is free software; you can redistribute it and/or modify it under
5 * the terms of the GNU Lesser General Public License as published by the Free
6 * Software Foundation; either version 3.0 of the License, or (at your option)
7 * any later version.
8 *
9 * This library is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
11 * FOR A PARTICULAR PURPOSE. See the GNU Lesser General Public License for more
12 * details.
13 *
14 * You should have received a copy of the GNU Lesser General Public License
15 * along with this library; if not, write to the Free Software Foundation, Inc.,
16 * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
17 */
18
19/*
20 * @file SOMTrainer.h
21 * @author nikoapos
22 */
23
24#ifndef SOM_SOMTRAINER_H
25#define SOM_SOMTRAINER_H
26
29#include "SOM/SOM.h"
30#include "SOM/SamplingPolicy.h"
31
32namespace Euclid {
33namespace SOM {
34
36
37public:
39 : m_neighborhood_func(neighborhood_func), m_learning_restraint_func(learning_restraint_func) {}
40
41 template <typename DistFunc, typename InputIter, typename InputToWeightFunc>
42 void train(SOM<DistFunc>& som, std::size_t iter_no, InputIter begin, InputIter end, InputToWeightFunc weight_func,
44
45 // We repeat the training for iter_no iterations
46 for (std::size_t i = 0; i < iter_no; ++i) {
47
48 // Compute the factor of the current iteration
49 auto learn_factor = m_learning_restraint_func(i, iter_no);
50 if (learn_factor == 0) {
51 continue;
52 }
53
54 // Go through the training sample of the iteration
55 for (auto it = sampling_policy.start(begin, end); it != end; it = sampling_policy.next(it)) {
56
57 // Get the weights of the input object
58 auto input_weights = weight_func(*it);
59
60 // Find the coordinates of the BMU for the input
61 std::size_t bmu_x, bmu_y;
62 double nd_distance;
63 std::tie(bmu_x, bmu_y, nd_distance) = som.findBMU(*it, weight_func);
64
65 // Now go through all the cells and update their values according their coordinates
66 std::size_t size_x, size_y;
67 std::tie(size_x, size_y) = som.getSize();
68
69 for (std::size_t cell_y = 0; cell_y < size_y; ++cell_y) {
70 for (std::size_t cell_x = 0; cell_x < size_x; ++cell_x) {
71 auto cell = som(cell_x, cell_y);
72
73 // Compute the factor based on the distance of the BMU and the cell
74 auto neighborhood_factor = m_neighborhood_func({bmu_x, bmu_y}, {cell_x, cell_y}, i, iter_no);
75
76 // Get the weights of the cell and update them
77 if (neighborhood_factor != 0) {
78 for (std::size_t wi = 0; wi < som.getDimensions(); ++wi) {
79 cell[wi] = cell[wi] + neighborhood_factor * learn_factor * (input_weights[wi] - cell[wi]);
80 }
81 }
82 }
83 }
84 }
85 }
86 }
87
88private:
91};
92
93} // namespace SOM
94} // namespace Euclid
95
96#endif /* SOM_SOMTRAINER_H */
LearningRestraintFunc::Signature m_learning_restraint_func
Definition: SOMTrainer.h:90
SOMTrainer(NeighborhoodFunc::Signature neighborhood_func, LearningRestraintFunc::Signature learning_restraint_func)
Definition: SOMTrainer.h:38
NeighborhoodFunc::Signature m_neighborhood_func
Definition: SOMTrainer.h:89
void train(SOM< DistFunc > &som, std::size_t iter_no, InputIter begin, InputIter end, InputToWeightFunc weight_func, const SamplingPolicy::Interface< InputIter > &sampling_policy=SamplingPolicy::FullSet< InputIter >{})
Definition: SOMTrainer.h:42
T tie(T... args)