diff --git a/python/otbtf.py b/python/otbtf.py
index a23d5237b701b9ca245fa9b850162ca2ec52e064..ce84cbc784205e43280128e5c148d45d87fdba9c 100644
--- a/python/otbtf.py
+++ b/python/otbtf.py
@@ -21,17 +21,24 @@
 Contains stuff to help working with TensorFlow and geospatial data in the OTBTF framework.
 """
+import glob
+import json
+import os
 import threading
 import multiprocessing
 import time
 import logging
 from abc import ABC, abstractmethod
+from functools import partial
+from tqdm import tqdm
+
 
 import numpy as np
 import tensorflow as tf
 import gdal
+import system
 
 
 # ----------------------------------------------------- Helpers --------------------------------------------------------
 
 
 def gdal_open(filename):
@@ -167,13 +174,18 @@ class PatchesImagesReader(PatchesReaderBase):
     :see PatchesReaderBase
     """
 
-    def __init__(self, filenames_dict: dict, use_streaming=False):
+    def __init__(self, filenames_dict, scalar_dict=None, use_streaming=False):
         """
         :param filenames_dict: A dict() structured as follow:
             {src_name1: [src1_patches_image_1.tif, ..., src1_patches_image_N.tif],
             src_name2: [src2_patches_image_1.tif, ..., src2_patches_image_N.tif],
             ...
             src_nameM: [srcM_patches_image_1.tif, ..., srcM_patches_image_N.tif]}
+        :param scalar_dict: (optional) a dict containing lists of scalars (int, float, str) structured as follows:
+            {scalar_name1: ["value_1", ..., "value_N"],
+            scalar_name2: [value_1, ..., value_N],
+            ...
+            scalar_nameN: [value_1, ..., value_N]}
         :param use_streaming: if True, the patches are read on the fly from the disc, nothing is kept in memory.
         """
 
@@ -182,8 +194,13 @@ class PatchesImagesReader(PatchesReaderBase):
         # gdal_ds dict
         self.gdal_ds = {key: [gdal_open(src_fn) for src_fn in src_fns] for key, src_fns in filenames_dict.items()}
 
+        # Scalar parameters (e.g. metadata)
+        self.scalar_dict = scalar_dict
+        if scalar_dict is None:
+            self.scalar_dict = {}
+
         # check number of patches in each sources
-        if len({len(ds_list) for ds_list in self.gdal_ds.values()}) != 1:
+        if len({len(ds_list) for ds_list in list(self.gdal_ds.values()) + list(self.scalar_dict.values())}) != 1:
             raise Exception("Each source must have the same number of patches images")
 
         # streaming on/off
@@ -209,6 +226,13 @@ class PatchesImagesReader(PatchesReaderBase):
         self.ds_sizes = [self._get_nb_of_patches(ds) for ds in self.gdal_ds[src_key_0]]
         self.size = sum(self.ds_sizes)
 
+        # Create another scalars dict so that one scalar maps to one patch
+        self.scalar_buffer = {}
+        for src_key, scalars in self.scalar_dict.items():
+            self.scalar_buffer[src_key] = []
+            for scalar, ds_size in zip(scalars, self.ds_sizes):
+                self.scalar_buffer[src_key].extend([scalar] * ds_size)
+
         # if use_streaming is False, we store in memory all patches images
         if not self.use_streaming:
             patches_list = {src_key: [read_as_np_arr(ds) for ds in self.gdal_ds[src_key]] for src_key in self.gdal_ds}
@@ -257,6 +281,7 @@ class PatchesImagesReader(PatchesReaderBase):
         else:
             i, offset = self._get_ds_and_offset_from_index(index)
             res = {src_key: self._read_extract_as_np_arr(self.gdal_ds[src_key][i], offset) for src_key in self.gdal_ds}
+            res.update({key: np.asarray(scalars[index]) for key, scalars in self.scalar_buffer.items()})
 
         return res
 
@@ -362,7 +387,7 @@ class Dataset:
     """
 
     def __init__(self, patches_reader: PatchesReaderBase, buffer_length: int = 128,
-                 Iterator: IteratorBase = RandomIterator):
+                 Iterator=RandomIterator):
         """
         :param patches_reader: The patches reader instance
         :param buffer_length: The number of samples that are stored in the buffer
@@ -404,6 +429,14 @@ class Dataset:
             output_types=self.output_types,
             output_shapes=self.output_shapes).repeat(1)
 
+    def to_tfrecords(self, output_dir, n_samples_per_shard=100, drop_remainder=True):
+        """
+        Save the dataset into a set of TFRecord files (see the `TFRecords` class)
+        """
+        tfrecord = TFRecords(output_dir)
+        tfrecord.ds2tfrecord(self, n_samples_per_shard=n_samples_per_shard, drop_remainder=drop_remainder)
+
+
     def get_stats(self) -> dict:
         """
         :return: the dataset statistics, computed by the patches reader
@@ -518,3 +551,213 @@ class DatasetFromPatchesImages(Dataset):
 
         patches_reader = PatchesImagesReader(filenames_dict=filenames_dict, use_streaming=use_streaming)
         super().__init__(patches_reader=patches_reader, buffer_length=buffer_length, Iterator=Iterator)
+
+
+class TFRecords:
+    """
+    This class allows converting Dataset objects to TFRecords, and reading TFRecords back as TensorFlow datasets.
+    """
+
+    def __init__(self, path):
+        """
+        :param path: Can be a directory where TFRecords must be saved/loaded, or a single TFRecord path
+        """
+        if system.is_dir(path) or not os.path.exists(path):
+            self.dirpath = path
+            system.mkdir(self.dirpath)
+            self.tfrecords_pattern_path = "{}*.records".format(system.pathify(self.dirpath))
+        else:
+            self.dirpath = system.dirname(path)
+            self.tfrecords_pattern_path = path
+        self.output_types_file = "{}output_types.json".format(system.pathify(self.dirpath))
+        self.output_shape_file = "{}output_shape.json".format(system.pathify(self.dirpath))
+        self.output_shape = self.load(self.output_shape_file) if os.path.exists(self.output_shape_file) else None
+        self.output_types = self.load(self.output_types_file) if os.path.exists(self.output_types_file) else None
+
+    def _bytes_feature(self, value):
+        """
+        Used to convert a value to a type compatible with tf.train.Example.
+        :param value: value
+        :return: a bytes_list from a string / byte.
+        """
+        if isinstance(value, type(tf.constant(0))):
+            value = value.numpy()  # BytesList won't unpack a string from an EagerTensor.
+        return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
+
+    def ds2tfrecord(self, dataset, n_samples_per_shard=100, drop_remainder=True):
+        """
+        Convert and save samples from the dataset object to TFRecord files.
+        :param dataset: Dataset object to convert into a set of TFRecords
+        :param n_samples_per_shard: Number of samples per shard
+        :param drop_remainder: Whether leftover samples should be dropped. Advisable for multi-worker training.
+                               If True, all TFRecords will have exactly `n_samples_per_shard` samples
+        """
+        logging.info("%s samples", dataset.size)
+
+        nb_shards = (dataset.size // n_samples_per_shard)
+        if not drop_remainder and dataset.size % n_samples_per_shard > 0:
+            nb_shards += 1
+
+        self.convert_dataset_output_shapes(dataset)
+
+        def _convert_data(data):
+            """
+            Convert each output dtype to its string name, so that the dict is JSON-serializable
+            """
+            data_converted = {}
+
+            for k, d in data.items():
+                data_converted[k] = d.name
+
+            return data_converted
+
+        self.save(_convert_data(dataset.output_types), self.output_types_file)
+
+        for i in tqdm(range(nb_shards)):
+
+            if (i + 1) * n_samples_per_shard <= dataset.size:
+                nb_sample = n_samples_per_shard
+            else:
+                nb_sample = dataset.size - i * n_samples_per_shard
+
+            filepath = "{}{}.records".format(system.pathify(self.dirpath), i)
+
+            # Geographic info of all samples of the record
+            #geojson_path = "{}{}.geojson".format(system.pathify(self.dirpath), i)
+            #geojson_dic = {"type": "FeatureCollection",
+            #               "name": "{}_geoinfo".format(i),
+            #               "features": []}
+
+            with tf.io.TFRecordWriter(filepath) as writer:
+                for s in range(nb_sample):
+                    sample = dataset.read_one_sample()
+                    serialized_sample = {name: tf.io.serialize_tensor(fea) for name, fea in sample.items()}
+                    features = {name: self._bytes_feature(serialized_tensor) for name, serialized_tensor in
+                                serialized_sample.items()}
+                    tf_features = tf.train.Features(feature=features)
+                    example = tf.train.Example(features=tf_features)
+                    writer.write(example.SerializeToString())
+
+                    # write the geographic info of the sample inside the geojson dic
+                    #UL_lon, UL_lat, LR_lon, LR_lat = sample['geoinfo']
+                    #geojson_dic['features'].append({"type": "Feature", "properties": {"sample_id": s},
+                    #                                "geometry": {"type": "Polygon", "coordinates": [[[UL_lon, UL_lat],
+                    #                                                                                 [LR_lon, UL_lat],
+                    #                                                                                 [LR_lon, LR_lat],
+                    #                                                                                 [UL_lon, LR_lat],
+                    #                                                                                 [UL_lon,
+                    #                                                                                  UL_lat]]]}})
+            # TODO (or not)
+            #with open(geojson_path, 'w') as f:
+            #    json.dump(geojson_dic, f, indent=4)
+
+    @staticmethod
+    def save(data, filepath):
+        """
+        Save data to a JSON file.
+        :param data: Data to save in JSON format
+        :param filepath: Output file name
+        """
+
+        with open(filepath, 'w') as f:
+            json.dump(data, f, indent=4)
+
+    @staticmethod
+    def load(filepath):
+        """
+        Return data loaded from a JSON file.
+        :param filepath: Input file name
+        """
+        with open(filepath, 'r') as f:
+            return json.load(f)
+
+    def convert_dataset_output_shapes(self, dataset):
+        """
+        Convert the dataset (numpy) output shapes to batched TensorFlow shapes, and save them as JSON.
+        :param dataset: Dataset object containing output shapes
+        """
+        output_shapes = {}
+
+        for key in dataset.output_shapes.keys():
+            output_shapes[key] = (None,) + dataset.output_shapes[key]
+
+        self.save(output_shapes, self.output_shape_file)
+
+    @staticmethod
+    def parse_tfrecord(example, features_types, target_keys):
+        """
+        Parse an Example object into a sample dict.
+        :param example: Example object to parse
+        :param features_types: Dict of types for each feature
+        :param target_keys: List of keys of the targets
+        """
+        read_features = {key: tf.io.FixedLenFeature([], dtype=tf.string) for key in features_types}
+        example_parsed = tf.io.parse_single_example(example, read_features)
+
+        for key in read_features.keys():
+            example_parsed[key] = tf.io.parse_tensor(example_parsed[key], out_type=features_types[key])
+
+        # Differentiating inputs and outputs
+        input_parsed = {key: value for (key, value) in example_parsed.items() if key not in target_keys}
+        target_parsed = {key: value for (key, value) in example_parsed.items() if key in target_keys}
+
+        return input_parsed, target_parsed
+
+    def read(self, batch_size, target_keys, n_workers=1, drop_remainder=True, shuffle_buffer_size=None):
+        """
+        Read all TFRecord files matching the pattern and convert them to a TensorFlow dataset.
+        :param batch_size: Size of tensorflow batch
+        :param target_keys: List of keys of the targets, e.g. ['s2_out']
+        :param n_workers: number of workers, e.g. 4 if using 4 GPUs,
+                          e.g. 12 if using 3 nodes of 4 GPUs
+        :param drop_remainder: whether the last batch should be dropped in the case it has fewer than
+                               `batch_size` elements. True is advisable when training on multiple workers.
+                               False is advisable when evaluating metrics so that all samples are used
+        :param shuffle_buffer_size: if None, shuffle is not used. Else, blocks of shuffle_buffer_size
+                                    elements are shuffled using uniform random.
+        """
+        options = tf.data.Options()
+        if shuffle_buffer_size:
+            options.experimental_deterministic = False  # disable order, increase speed
+        options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.AUTO  # for multiworker
+        parse = partial(self.parse_tfrecord, features_types=self.output_types, target_keys=target_keys)
+
+        # TODO: to be investigated:
+        # 1/ num_parallel_reads useful? I/O bottleneck or not?
+        # 2/ num_parallel_calls=tf.data.experimental.AUTOTUNE useful?
+        # 3/ shuffle or not shuffle?
+        matching_files = glob.glob(self.tfrecords_pattern_path)
+        logging.info('Searching TFRecords in %s...', self.tfrecords_pattern_path)
+        logging.info('Number of matching TFRecords: %s', len(matching_files))
+        matching_files = matching_files[:n_workers * (len(matching_files) // n_workers)]  # files multiple of workers
+        nb_matching_files = len(matching_files)
+        if nb_matching_files == 0:
+            raise Exception("At least one worker has no TFRecord file in {}. Please ensure that the number of TFRecord "
+                            "files is greater than or equal to the number of workers!".format(self.tfrecords_pattern_path))
+        logging.info('Reducing number of records to: %s', nb_matching_files)
+        dataset = tf.data.TFRecordDataset(matching_files)  # , num_parallel_reads=2)  # interleaves reads from xxx files
+        dataset = dataset.with_options(options)  # uses data as soon as it streams in, rather than in its original order
+        dataset = dataset.map(parse, num_parallel_calls=tf.data.experimental.AUTOTUNE)
+        if shuffle_buffer_size:
+            dataset = dataset.shuffle(buffer_size=shuffle_buffer_size)
+        dataset = dataset.batch(batch_size, drop_remainder=drop_remainder)
+        dataset = dataset.prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
+        # TODO: check whether prefetch should be placed before batch, cf. https://keras.io/examples/keras_recipes/tfrecord/
+
+        return dataset
+
+    def read_one_sample(self, target_keys):
+        """
+        Read one sample from the first TFRecord file matching the pattern.
+        :param target_keys: List of keys of the targets, e.g. ['s2_out']
+        """
+        matching_files = glob.glob(self.tfrecords_pattern_path)
+        one_file = matching_files[0]
+        parse = partial(self.parse_tfrecord, features_types=self.output_types, target_keys=target_keys)
+        dataset = tf.data.TFRecordDataset(one_file)
+        dataset = dataset.map(parse)
+        dataset = dataset.batch(1)
+
+        sample = iter(dataset).get_next()
+        return sample
diff --git a/python/system.py b/python/system.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4810135a43bc0e960fb46704991e8a61fa77046
--- /dev/null
+++ b/python/system.py
@@ -0,0 +1,205 @@
+# -*- coding: utf-8 -*-
+"""
+Copyright (c) 2020-2022 INRAE
+
+Permission is hereby granted, free of charge, to any person obtaining a
+copy of this software and associated documentation files (the "Software"),
+to deal in the Software without restriction, including without limitation
+the rights to use, copy, modify, merge, publish, distribute, sublicense,
+and/or sell copies of the Software, and to permit persons to whom the
+Software is furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
+DEALINGS IN THE SOFTWARE.
+"""
+"""Various system operations"""
+import logging
+import zipfile
+import pathlib
+import os
+import sys
+import git
+
+# --------------------------------------------------- Constants --------------------------------------------------------
+
+
+COMPLETE_SUFFIX = ".complete"
+
+
+# ---------------------------------------------------- Helpers ---------------------------------------------------------
+
+
+def get_commit_hash():
+    """ Return the git hash of the repository """
+    repo = git.Repo(os.path.dirname(os.path.realpath(__file__)), search_parent_directories=True)
+
+    try:
+        commit_hash = repo.active_branch.name + "_" + repo.head.object.hexsha[0:5]
+    except TypeError:
+        commit_hash = 'DETACHED_' + repo.head.object.hexsha[0:5]
+
+    return commit_hash
+
+
+def get_directories(root):
+    """
+    List all directories in the root directory
+    :param root: root directory
+    :return: list of directories
+    """
+    return [pathify(root) + item for item in os.listdir(root)]
+
+
+def get_files(directory, ext=None):
+    """ List the files in directory, sorted
+    :param directory: input directory
+    :param ext: optional, end of filename to be matched
+    :return: sorted list of the filepaths
+    """
+    ret = []
+    for root, _, files in os.walk(directory, topdown=False):
+        for name in files:
+            filename = os.path.join(root, name)
+            if ext:
+                if filename.lower().endswith(ext.lower()):
+                    ret.append(filename)
+            else:
+                ret.append(filename)
+    return sorted(ret)
+
+
+def new_bname(filename, suffix):
+    """ Return a new basename (without path, without extension, + suffix) """
+    filename = filename[filename.rfind("/") + 1:]
+    filename = filename[:filename.rfind(".")]
+    return filename + "_" + suffix
+
+
+def pathify(pth):
+    """ Adds posix separator if needed """
+    if not pth.endswith("/"):
+        pth += "/"
+    return pth
+
+
+def mkdir(pth):
+    """ Create a directory """
+    path = pathlib.Path(pth)
+    path.mkdir(parents=True, exist_ok=True)
+
+
+def dirname(filename):
+    """ Returns the parent directory of the file """
+    return str(pathlib.Path(filename).parent)
+
+
+def basename(pth):
+    """ Returns the basename. Works with files and paths """
+    return str(pathlib.Path(pth).name)
+
+
+def join(*pthslist):
+    """ Returns the join of all paths """
+    return str(pathlib.PurePath(*pthslist))
+
+
+def list_files_in_zip(filename, endswith=None):
+    """ List files in zip archive
+    :param filename: path of the zip
+    :param endswith: optional, end of filename to be matched
+    :return: list of the filepaths
+    """
+    with zipfile.ZipFile(filename) as zip_file:
+        filelist = zip_file.namelist()
+    if endswith:
+        filelist = [f for f in filelist if f.endswith(endswith)]
+
+    return filelist
+
+
+def to_vsizip(zipfn, relpth):
+    """ Create a GDAL-readable path from a zip file """
+    return "/vsizip/{}/{}".format(zipfn, relpth)
+
+
+def remove_ext_filename(filename):
+    """ Remove OTB extended filenames (keep only the part before the "?") """
+    if "?" in filename:
+        return filename[:filename.rfind("?")]
+    return filename
+
+
+def declare_complete(filename):
+    """ Declare that a file has been completed, creating a small file """
+    filename = remove_ext_filename(filename)
+    filename += COMPLETE_SUFFIX
+    with open(filename, "w") as text_file:
+        text_file.write("ok")
+
+
+def file_exists(filename):
+    """ Check if file exists """
+    my_file = pathlib.Path(filename)
+    return my_file.is_file()
+
+
+def is_complete(filename):
+    """ Returns True if a file has been completed """
+    filename = remove_ext_filename(filename)
+    filename += COMPLETE_SUFFIX
+    return file_exists(filename)
+
+
+def set_env_var(var, value):
+    """ Set an environment variable """
+    os.environ[var] = value
+
+
+def get_env_var(var):
+    """ Return an environment variable (None, with a warning, if it is not set) """
+    value = os.environ.get(var)
+    if value is None:
+        logging.warning("Environment variable %s is not set. Returning value None.", var)
+    return value
+
+
+def basic_logging_init():
+    """ Basic logging initialization """
+    logging.basicConfig(
+        format='%(asctime)s %(levelname)-8s %(message)s',
+        level=logging.INFO,
+        datefmt='%Y-%m-%d %H:%M:%S')
+
+
+def logging_info(msg, verbose=True):
+    """
+    Prints log info only if required by `verbose`
+    :param msg: message to log
+    :param verbose: boolean. Whether to log msg or not. Default True
+    """
+    if verbose:
+        logging.info(msg)
+
+
+def is_dir(filename):
+    """ Return True if filename is the path to a directory """
+    return os.path.isdir(filename)
+
+
+def terminate():
+    """ Ends the running program """
+    sys.exit()
+
+
+def run_and_terminate(main):
+    """ Run the main function, then end the program """
+    sys.exit(main(args=sys.argv[1:]))
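
Example usage (not part of the patch): a minimal sketch of how the additions above (the scalar_dict parameter of PatchesImagesReader, Dataset.to_tfrecords() and TFRecords.read()) might be chained together. File names, the "scene_month" scalar source, the output directory and all parameter values are illustrative placeholders, not values prescribed by this changeset.

# Minimal sketch, assuming otbtf is importable and the patches images exist on disk.
from otbtf import PatchesImagesReader, Dataset, TFRecords

# 1. Read patches images, attaching one scalar (e.g. a metadata value) per patches image
reader = PatchesImagesReader(
    filenames_dict={"input_xs": ["xs_patches_1.tif", "xs_patches_2.tif"],
                    "labels": ["labels_patches_1.tif", "labels_patches_2.tif"]},
    scalar_dict={"scene_month": [3, 8]})  # hypothetical scalar source, one value per patches image

# 2. Wrap the reader in a Dataset and export it as sharded TFRecord files
dataset = Dataset(patches_reader=reader)
dataset.to_tfrecords(output_dir="/tmp/tfrecords", n_samples_per_shard=100, drop_remainder=True)

# 3. Read the TFRecords back as a batched tf.data.Dataset, using "labels" as the training target
tfrecords = TFRecords("/tmp/tfrecords")
tf_dataset = tfrecords.read(batch_size=8, target_keys=["labels"], shuffle_buffer_size=1000)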